Compare commits
1 Commits
burn/687-1
...
fix/624
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
49296d538e |
331
pipeline/nightly_scheduler.py
Executable file
331
pipeline/nightly_scheduler.py
Executable file
@@ -0,0 +1,331 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
nightly_scheduler.py — Nightly Pipeline Scheduler
|
||||
|
||||
Auto-starts batch pipelines when inference is available, respecting
|
||||
priority ordering, token budgets, and peak-hour pausing.
|
||||
|
||||
Usage:
|
||||
python3 nightly_scheduler.py # run scheduler
|
||||
python3 nightly_scheduler.py --check # dry-run: show what would start
|
||||
python3 nightly_scheduler.py --status # show pipeline status
|
||||
python3 nightly_scheduler.py --reset # reset daily budget
|
||||
|
||||
Crontab:
|
||||
# Run every 30 minutes during off-peak hours (10pm-6am)
|
||||
*/30 22-5 * * * cd /path/to/timmy-config && python3 pipeline/nightly_scheduler.py >> ~/.hermes/pipeline-logs/nightly.log 2>&1
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
# --- Config ---
# On-disk scheduler state: budget counters and per-pipeline status.
STATE_FILE = Path.home() / ".hermes" / "pipeline_state.json"
# Root directory for the scheduler log and per-pipeline log files.
LOG_DIR = Path.home() / ".hermes" / "pipeline-logs"
DAILY_TOKEN_BUDGET = 5_000_000  # 5M tokens per day
PEAK_HOURS = list(range(8, 22))  # 8am-10pm = peak interactive usage
CHECK_INTERVAL = 1800  # 30 minutes
# NOTE(review): CHECK_INTERVAL is not referenced anywhere in this file —
# presumably it documents the cron cadence; confirm before removing.

# Inference providers probed by check_provider(), in declaration order.
INFERENCE_ENDPOINTS = [
    {"name": "local_ollama", "url": "http://localhost:11434/v1/models", "type": "local"},
    {"name": "runpod", "url": "https://8lfr3j47a5r3gn-11434.proxy.runpod.net/v1/models", "type": "gpu"},
    {"name": "openrouter", "url": "https://openrouter.ai/api/v1/models", "type": "cloud"},
]

# Pipeline priority order (highest first)
PIPELINE_PRIORITY = [
    {"name": "playground_factory", "script": "pipeline/playground_factory.py", "priority": 1},
    {"name": "training_factory", "script": "pipeline/training_factory.py", "priority": 2},
    {"name": "knowledge_mine", "script": "pipeline/knowledge_mine.py", "priority": 3},
    {"name": "adversary", "script": "pipeline/adversary_runner.py", "priority": 4},
    {"name": "codebase_genome", "script": "pipeline/codebase_genome.py", "priority": 5},
]

# Dependency rules: some pipelines only start after others are running
DEPENDENCY_RULES = {
    "playground_factory": [],  # no deps, start immediately
    "training_factory": [],  # no deps, start in parallel
    "knowledge_mine": ["training_factory"],  # start after training is running
    "adversary": ["knowledge_mine"],  # start after knowledge is halfway
    "codebase_genome": [],  # continuous, one repo per night
}
|
||||
|
||||
|
||||
def load_state():
    """Read scheduler state from STATE_FILE, or return a fresh default.

    The default carries every key the rest of the module expects:
    budget counters, per-pipeline status map, and session list.
    """
    if not STATE_FILE.exists():
        return {
            "last_run": None,
            "daily_tokens_used": 0,
            "budget_reset_date": None,
            "pipelines": {},
            "active_sessions": [],
        }
    with open(STATE_FILE) as fh:
        return json.load(fh)
|
||||
|
||||
|
||||
def save_state(state):
    """Persist *state* to STATE_FILE as pretty-printed JSON.

    Creates the parent directory on first use.
    """
    STATE_FILE.parent.mkdir(parents=True, exist_ok=True)
    with STATE_FILE.open("w") as fh:
        json.dump(state, fh, indent=2)
|
||||
|
||||
|
||||
def check_provider(endpoint):
    """Return True if *endpoint*'s URL answers HTTP 200 within 10s.

    Any failure (DNS, timeout, HTTP error, bad URL) counts as unavailable.
    NOTE(review): the same "Bearer ollama" header is sent to every
    provider, including cloud ones — confirm that is intentional.
    """
    try:
        probe = urllib.request.Request(
            endpoint["url"],
            headers={"Authorization": "Bearer ollama"},
        )
        response = urllib.request.urlopen(probe, timeout=10)
        with response:
            return response.status == 200
    except Exception:
        return False
|
||||
|
||||
|
||||
def get_available_providers():
    """Return the names of all configured endpoints that respond."""
    return [ep["name"] for ep in INFERENCE_ENDPOINTS if check_provider(ep)]
|
||||
|
||||
|
||||
def is_peak_hours():
    """Return True when the current local hour falls inside PEAK_HOURS."""
    return datetime.now().hour in PEAK_HOURS
|
||||
|
||||
|
||||
def check_token_budget(state):
    """Return True while today's token spend is under DAILY_TOKEN_BUDGET.

    On the first call of a new calendar day the counter is zeroed and the
    state is persisted immediately.
    """
    today_key = datetime.now().strftime("%Y-%m-%d")
    stale = state.get("budget_reset_date") != today_key
    if stale:
        # New day: roll the counter over and persist.
        state["daily_tokens_used"] = 0
        state["budget_reset_date"] = today_key
        save_state(state)
    return state["daily_tokens_used"] < DAILY_TOKEN_BUDGET
|
||||
|
||||
|
||||
def get_pipeline_status(state, pipeline_name):
    """Return the stored status dict for *pipeline_name*.

    Unknown pipelines get a fresh "not_started" placeholder (the state
    itself is not modified).
    """
    placeholder = {
        "status": "not_started",
        "last_run": None,
        "last_success": None,
        "progress": 0,
    }
    return state.get("pipelines", {}).get(pipeline_name, placeholder)
|
||||
|
||||
|
||||
def check_dependencies(state, pipeline_name):
    """Return True when every dependency of *pipeline_name* is satisfied.

    A dependency counts as satisfied once its status is "running" or
    "completed"; pipelines with no rules are always eligible.
    """
    return all(
        get_pipeline_status(state, dep)["status"] in ("running", "completed")
        for dep in DEPENDENCY_RULES.get(pipeline_name, [])
    )
|
||||
|
||||
|
||||
def start_pipeline(pipeline, state, dry_run=False):
    """Launch *pipeline*'s script as a detached subprocess.

    Args:
        pipeline: dict with "name" and "script" keys (see PIPELINE_PRIORITY).
        state: mutable scheduler state; updated and persisted on every path.
        dry_run: when True, only log what would happen.

    Returns:
        True if the pipeline started (or would start in dry-run mode),
        False when the script is missing or the spawn failed.
    """
    name = pipeline["name"]
    script = pipeline["script"]

    log(f"Starting pipeline: {name}")

    if dry_run:
        log(f" DRY RUN — would run: python3 {script}")
        return True

    # Check if script exists
    script_path = Path(script)
    if not script_path.exists():
        log(f" Script not found: {script_path}")
        # Update state anyway so we track the attempt.
        # Fix: setdefault guards against state files missing "pipelines".
        state.setdefault("pipelines", {})[name] = {
            "status": "script_missing",
            "last_run": datetime.now(timezone.utc).isoformat(),
            "progress": 0,
        }
        save_state(state)
        return False

    # Run the pipeline script
    import subprocess
    log_dir = LOG_DIR / name
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"

    try:
        # Fix: open the log in a `with` so the parent's copy of the handle
        # is closed after spawn (the child keeps its inherited fd); the
        # original `stdout=open(...)` leaked one handle per launch.
        with open(log_file, "w") as log_fh:
            proc = subprocess.Popen(
                ["python3", str(script_path)],
                stdout=log_fh,
                stderr=subprocess.STDOUT,
                cwd=str(Path(script).parent.parent),
            )

        state.setdefault("pipelines", {})[name] = {
            "status": "running",
            "pid": proc.pid,
            "last_run": datetime.now(timezone.utc).isoformat(),
            "log_file": str(log_file),
            "progress": 0,
        }
        save_state(state)
        log(f" Started PID {proc.pid}, log: {log_file}")
        return True
    except Exception as e:
        log(f" Failed to start: {e}")
        state.setdefault("pipelines", {})[name] = {
            "status": "failed",
            "last_run": datetime.now(timezone.utc).isoformat(),
            "error": str(e),
        }
        save_state(state)
        return False
|
||||
|
||||
|
||||
def check_running_pipelines(state):
    """Poll pipelines marked "running" and mark exited ones "completed".

    Liveness is probed with `os.kill(pid, 0)`. NOTE(review): a recycled
    PID would make a finished pipeline still look alive — likely acceptable
    at a 30-minute cadence, but worth confirming.
    """
    # Fix: the original imported subprocess here but never used it.
    for name, info in state.get("pipelines", {}).items():
        if info.get("status") != "running":
            continue
        pid = info.get("pid")
        if not pid:
            continue
        try:
            os.kill(pid, 0)  # signal 0: existence check, no signal sent
        except ProcessLookupError:
            # Process finished since the last check.
            info["status"] = "completed"
            info["completed_at"] = datetime.now(timezone.utc).isoformat()
            log(f"Pipeline {name} completed (PID {pid} exited)")
            save_state(state)
        except PermissionError:
            # Fix: the PID exists but belongs to another user; the original
            # let this propagate and crash the scheduler. Treat as running.
            pass
|
||||
|
||||
|
||||
def run_scheduler(dry_run=False, check_only=False):
    """Run one scheduler pass.

    Gates on peak hours, daily token budget, and provider availability,
    then starts at most one eligible pipeline (highest priority first).

    Args:
        dry_run: log what would start without spawning anything.
        check_only: report readiness only; never calls start_pipeline().
    """
    state = load_state()

    # Fix: record this run so show_status()'s "Last run" line is
    # meaningful — the original displayed state["last_run"] but never set it.
    state["last_run"] = datetime.now().isoformat()
    save_state(state)

    log("=" * 50)
    log(f"Pipeline Scheduler — {datetime.now().isoformat()}")
    log(f"Mode: {'CHECK' if check_only else 'DRY RUN' if dry_run else 'LIVE'}")

    # Check peak hours
    if is_peak_hours():
        log("Peak hours detected. Pausing pipeline starts.")
        log("Pipelines will resume at 10pm.")
        return

    # Check token budget
    if not check_token_budget(state):
        log(f"Daily token budget exhausted ({state['daily_tokens_used']}/{DAILY_TOKEN_BUDGET})")
        return
    log(f"Token budget: {state['daily_tokens_used']}/{DAILY_TOKEN_BUDGET}")

    # Check providers
    providers = get_available_providers()
    if not providers:
        log("No inference providers available. Skipping.")
        return
    log(f"Available providers: {', '.join(providers)}")

    # Refresh status of anything we previously launched
    check_running_pipelines(state)

    # Find next pipeline to start, in priority order
    started = 0
    for pipeline in sorted(PIPELINE_PRIORITY, key=lambda p: p["priority"]):
        name = pipeline["name"]
        status = get_pipeline_status(state, name)

        # Skip if already running or completed
        if status["status"] in ("running", "completed"):
            log(f" {name}: {status['status']} (skipping)")
            continue

        # Check dependencies
        if not check_dependencies(state, name):
            deps = DEPENDENCY_RULES.get(name, [])
            log(f" {name}: waiting for dependencies: {deps}")
            continue

        # Start the pipeline (or just report readiness in check mode)
        if check_only:
            log(f" {name}: READY to start (priority {pipeline['priority']})")
        else:
            if start_pipeline(pipeline, state, dry_run):
                started += 1
                # Only start one pipeline per run to avoid overload
                if started >= 1:
                    log("Started 1 pipeline. Will check again next cycle.")
                    break

    if started == 0 and not check_only:
        log("No pipelines to start. All are running, completed, or blocked.")

    log("=" * 50)
|
||||
|
||||
|
||||
def show_status():
    """Print a one-line status summary for every configured pipeline."""
    state = load_state()

    stamp = datetime.now().strftime("%Y-%m-%d %H:%M")
    print(f"\nPipeline Status — {stamp}")
    print(f"Token budget: {state.get('daily_tokens_used', 0)}/{DAILY_TOKEN_BUDGET}")
    print(f"Last run: {state.get('last_run', 'never')}")
    print()

    # Map each known status to a display glyph; unknown statuses show "?".
    icons = {
        "running": "●",
        "completed": "✓",
        "failed": "✗",
        "not_started": "○",
        "script_missing": "?",
    }
    for entry in sorted(PIPELINE_PRIORITY, key=lambda p: p["priority"]):
        name = entry["name"]
        info = get_pipeline_status(state, name)
        st = info["status"]
        icon = icons.get(st, "?")
        last = (info.get("last_run") or "never")[:19]
        print(f" {icon} {name:25} {st:15} last={last}")
|
||||
|
||||
|
||||
def reset_budget():
    """Zero today's token counter, stamp the reset date, and persist."""
    state = load_state()
    state.update(
        daily_tokens_used=0,
        budget_reset_date=datetime.now().strftime("%Y-%m-%d"),
    )
    save_state(state)
    print("Budget reset.")
|
||||
|
||||
|
||||
def log(msg):
    """Echo *msg* to stdout and append it to LOG_DIR/nightly.log.

    Each line is prefixed with a HH:MM:SS timestamp; the log directory is
    created on first use.
    """
    line = f"[{datetime.now().strftime('%H:%M:%S')}] {msg}"
    print(line)
    LOG_DIR.mkdir(parents=True, exist_ok=True)
    with (LOG_DIR / "nightly.log").open("a") as fh:
        fh.write(line + "\n")
|
||||
|
||||
|
||||
def main():
    """CLI entry point: dispatch on --status / --reset / scheduler flags."""
    import argparse
    parser = argparse.ArgumentParser(description="Nightly Pipeline Scheduler")
    parser.add_argument("--check", action="store_true", help="Dry-run: show what would start")
    parser.add_argument("--status", action="store_true", help="Show pipeline status")
    parser.add_argument("--reset", action="store_true", help="Reset daily token budget")
    parser.add_argument("--dry-run", action="store_true", help="Dry-run mode")
    args = parser.parse_args()

    # Status and reset are exclusive short-circuit actions; everything
    # else falls through to a scheduler pass (--check implies dry-run).
    if args.status:
        show_status()
        return
    if args.reset:
        reset_budget()
        return
    run_scheduler(dry_run=args.dry_run or args.check, check_only=args.check)


if __name__ == "__main__":
    main()
|
||||
@@ -1,389 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Training Data Quality Filter (#687)
|
||||
|
||||
Scores and removes low-quality training pairs from JSONL files.
|
||||
Supports: ShareGPT format, preference pairs, generic JSONL.
|
||||
|
||||
Usage:
|
||||
python3 scripts/filter_training_data.py <input.jsonl> [--output filtered.jsonl]
|
||||
python3 scripts/filter_training_data.py training/data/preference_pairs.jsonl
|
||||
python3 scripts/filter_training_data.py training/data/curated_dataset.jsonl --threshold 0.3
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
# ============================================================
# QUALITY SCORING
# ============================================================

# Generic filler phrases that indicate low-quality responses.
# All entries are lowercase: score_specificity() substring-matches them
# against text.lower().
FILLER_PHRASES = [
    "as an ai", "i'm an ai", "as a language model", "i don't have personal",
    "i cannot", "i can't", "it's important to note", "please note that",
    "in conclusion", "to summarize", "in summary", "hope this helps",
    "let me know if", "feel free to", "i'd be happy to", "certainly!",
    "of course!", "absolutely!", "great question!", "that's a great",
    "i understand your", "i appreciate your", "thank you for asking",
    "it depends", "there are many ways", "various factors",
]

# Vague/generic short responses; score_specificity() applies a heavy
# penalty when the whole (lowercased, stripped) text equals one of these.
VAGUE_RESPONSES = [
    "ok", "okay", "sure", "yes", "no", "maybe", "idk", "i don't know",
    "thanks", "thank you", "got it", "understood", "right", "correct",
    "hello", "hi", "hey", "goodbye", "bye",
]

# Captures the body of a fenced ``` code block (optional language tag).
CODE_BLOCK_PATTERN = re.compile(r"```(?:\w+)?\n(.+?)```", re.DOTALL)
# Captures `inline code` spans.
# NOTE(review): INLINE_CODE_PATTERN is not referenced anywhere in this
# file — confirm whether it is still needed.
INLINE_CODE_PATTERN = re.compile(r"`([^`]+)`")
|
||||
|
||||
|
||||
def detect_format(record: dict) -> str:
    """Classify a training record by which signature keys it carries.

    Signatures are tried in priority order; "generic" is the fallback.
    """
    signatures = (
        ("sharegpt", ("conversations",)),
        ("preference", ("prompt", "chosen")),
        ("scene", ("scene", "lyric_line")),
        ("pairs", ("terse", "rich")),
    )
    for fmt, keys in signatures:
        if all(key in record for key in keys):
            return fmt
    return "generic"
|
||||
|
||||
|
||||
def extract_text_fields(record: dict, fmt: str) -> Tuple[str, str]:
    """Return (input_text, output_text) for *record* in format *fmt*.

    For sharegpt, the LAST human and LAST gpt turns are used; other
    formats map fixed fields. Missing fields yield empty strings.
    """
    if fmt == "sharegpt":
        # Track the most recent message per role in one pass.
        last = {"human": "", "gpt": ""}
        for turn in record.get("conversations", []):
            role = turn.get("from")
            if role in last:
                last[role] = turn["value"]
        return last["human"], last["gpt"]

    if fmt == "preference":
        return record.get("prompt", ""), record.get("chosen", "")

    if fmt == "scene":
        return record.get("lyric_line", ""), record.get("scene", {}).get("description", "")

    if fmt == "pairs":
        return record.get("terse", ""), record.get("rich", "")

    # Generic: probe common field names in fixed precedence.
    inp = record.get("input", record.get("prompt", record.get("question", "")))
    out = record.get("output", record.get("response", record.get("answer", "")))
    return str(inp), str(out)
|
||||
|
||||
|
||||
def score_specificity(text: str) -> float:
    """Score 0–1: how specific/detailed *text* is versus generic filler.

    Starts from a 0.5 baseline; filler phrases, shortness, and one-word
    brush-offs subtract, while concreteness markers and code blocks add.
    The result is clamped to [0, 1].
    """
    if not text or not text.strip():
        return 0.0

    lowered = text.lower().strip()
    score = 0.5  # neutral baseline

    # Each filler phrase found costs 0.08.
    filler_hits = sum(1 for phrase in FILLER_PHRASES if phrase in lowered)
    score -= filler_hits * 0.08

    # Length adjustment: terse answers penalized, long ones rewarded.
    n_words = len(text.split())
    if n_words < 5:
        score -= 0.3
    elif n_words < 10:
        score -= 0.15
    elif n_words > 30:
        score += 0.1

    # One-word brush-offs ("ok", "sure", ...) take a heavy extra penalty.
    if lowered.strip() in VAGUE_RESPONSES:
        score -= 0.4

    # Markers of concreteness: numbers, code fences, URLs, code-like
    # tokens, precision words, enumerated steps.
    markers = (
        r"\d+",
        r"```",
        r"https?://",
        r"\$\{",
        r"\w+\.\w+",
        r"(?:specifically|exactly|precisely|in particular)",
        r"(?:step \d|first,|second,|third,|finally,)",
    )
    for marker in markers:
        if re.search(marker, text):
            score += 0.05

    # Code blocks are a strong concreteness signal on top of the marker.
    if "```" in text:
        score += 0.15

    return max(0.0, min(1.0, score))
|
||||
|
||||
|
||||
def score_length_ratio(input_text: str, output_text: str) -> float:
    """Score 0–1 for how sensible the output/input word-count ratio is.

    Very short outputs relative to the input are penalized hard; long
    outputs are acceptable (detailed answers) up to a point, after which
    the score tapers as possible noise.
    """
    n_in = len(input_text.split())
    n_out = len(output_text.split())

    # An empty output is worthless regardless of the input.
    if n_out == 0:
        return 0.0

    # Word-count ratio; with no input, normalize against 10 words.
    ratio = n_out / n_in if n_in else n_out / 10

    # Ascending bands: (upper bound, score) for too-short outputs.
    for upper, band_score in ((0.05, 0.1), (0.2, 0.3), (0.5, 0.6)):
        if ratio < upper:
            return band_score
    if ratio <= 15:
        return 1.0  # sweet spot
    if ratio <= 50:
        return 0.8
    return 0.5  # extremely long output, maybe noise
|
||||
|
||||
|
||||
def score_code_correctness(text: str) -> float:
    """Score 0–1: average syntactic plausibility of fenced code blocks.

    Per block: 1.0 for valid Python, 0.8 for balanced brackets (JS-ish),
    1.0 for valid JSON, 0.5 for non-trivial multi-line text (shell/YAML),
    else 0. Text without code blocks scores a neutral 1.0.
    """
    blocks = CODE_BLOCK_PATTERN.findall(text)
    if not blocks:
        return 1.0  # no code, not penalized

    def _credit(code: str) -> float:
        # Python syntax check wins full credit.
        try:
            ast.parse(code)
            return 1.0
        except SyntaxError:
            pass
        # Balanced brackets/parens: plausible JS/C-like snippet.
        if _check_brackets_balanced(code):
            return 0.8
        # JSON also wins full credit.
        try:
            json.loads(code)
            return 1.0
        except (json.JSONDecodeError, ValueError):
            pass
        # Shell/YAML: just check it's not empty garbage.
        if len(code.strip()) > 10 and "\n" in code:
            return 0.5
        return 0.0

    return sum(_credit(block) for block in blocks) / len(blocks)
|
||||
|
||||
|
||||
def _check_brackets_balanced(code: str) -> bool:
|
||||
"""Check if brackets are balanced in code."""
|
||||
stack = []
|
||||
pairs = {"(": ")", "[": "]", "{": "}"}
|
||||
for ch in code:
|
||||
if ch in pairs:
|
||||
stack.append(pairs[ch])
|
||||
elif ch in pairs.values():
|
||||
if not stack or stack[-1] != ch:
|
||||
return False
|
||||
stack.pop()
|
||||
return len(stack) == 0
|
||||
|
||||
|
||||
def score_record(record: dict, fmt: str) -> Dict[str, float]:
    """Score one training record.

    Returns the three component scores plus a weighted "composite"
    (45% specificity, 25% length ratio, 30% code correctness), each
    rounded to 3 decimals.
    """
    input_text, output_text = extract_text_fields(record, fmt)

    specificity = score_specificity(output_text)
    length_ratio = score_length_ratio(input_text, output_text)
    code_correctness = score_code_correctness(output_text)

    composite = (
        specificity * 0.45 +
        length_ratio * 0.25 +
        code_correctness * 0.30
    )

    result = {
        "specificity": specificity,
        "length_ratio": length_ratio,
        "code_correctness": code_correctness,
        "composite": composite,
    }
    return {key: round(value, 3) for key, value in result.items()}
|
||||
|
||||
|
||||
# ============================================================
|
||||
# FILTERING
|
||||
# ============================================================
|
||||
|
||||
def filter_jsonl(
    input_path: str,
    output_path: Optional[str] = None,
    threshold: float = 0.3,
    dry_run: bool = False,
    verbose: bool = False,
) -> Dict[str, Any]:
    """Filter a JSONL file, removing low-quality records.

    Args:
        input_path: path to the JSONL file to score.
        output_path: destination for kept records; defaults to
            "<input stem>_filtered.jsonl" next to the input.
        threshold: minimum composite score (0-1) a record must reach.
        dry_run: score and report only; write no output file.
        verbose: also print the five worst removed records.

    Returns:
        A report dict (counts, score distribution, removal breakdown),
        or ``{"error": ..., "total": 0}`` when no line parses as JSON.
    """

    if output_path is None:
        stem = Path(input_path).stem
        output_path = str(Path(input_path).parent / f"{stem}_filtered.jsonl")

    # Parse line by line; blank lines are skipped silently, invalid JSON
    # lines are warned about on stderr and skipped.
    records = []
    with open(input_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            line = line.strip()
            if not line:
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as e:
                print(f" [WARN] Line {i+1}: invalid JSON, skipping: {e}", file=sys.stderr)

    if not records:
        return {"error": "No valid records found", "total": 0}

    # Detect format from first record
    # NOTE(review): assumes the file is homogeneous — a mixed-format file
    # would be scored entirely with the first record's format; confirm.
    fmt = detect_format(records[0])
    print(f" Detected format: {fmt}")
    print(f" Total records: {len(records)}")

    # Score all records, remembering each record's original index.
    scored = []
    for i, record in enumerate(records):
        scores = score_record(record, fmt)
        scored.append((record, scores, i))

    # Sort by composite score (ascending, so worst records come first).
    scored.sort(key=lambda x: x[1]["composite"])

    # Filter
    kept = [(r, s, i) for r, s, i in scored if s["composite"] >= threshold]
    removed = [(r, s, i) for r, s, i in scored if s["composite"] < threshold]

    # Report
    report = {
        "input_file": input_path,
        "output_file": output_path,
        "format": fmt,
        "total_records": len(records),
        "kept": len(kept),
        "removed": len(removed),
        "threshold": threshold,
        "removal_rate": f"{len(removed) / len(records) * 100:.1f}%",
        "score_distribution": {
            "min": scored[0][1]["composite"] if scored else 0,
            "max": scored[-1][1]["composite"] if scored else 0,
            "median": scored[len(scored)//2][1]["composite"] if scored else 0,
            "mean": round(sum(s["composite"] for _, s, _ in scored) / len(scored), 3) if scored else 0,
        },
        "removed_score_breakdown": {
            "specificity_below_0.3": sum(1 for _, s, _ in removed if s["specificity"] < 0.3),
            "length_ratio_below_0.3": sum(1 for _, s, _ in removed if s["length_ratio"] < 0.3),
            "code_correctness_below_0.5": sum(1 for _, s, _ in removed if s["code_correctness"] < 0.5),
        },
    }

    # Show worst offenders if verbose ("removed" inherits scored's
    # ascending order, so the first five are the lowest composites).
    if verbose and removed:
        print(f"\n Worst 5 records (by composite score):")
        for r, s, i in removed[:5]:
            _, output_text = extract_text_fields(r, fmt)
            preview = output_text[:80].replace("\n", " ") if output_text else "(empty)"
            print(f" [{s['composite']:.3f}] {preview}...")

    # Write output (unless dry run)
    if not dry_run:
        # Preserve original order, only keeping filtered records
        kept_indices = {i for _, _, i in kept}
        with open(output_path, "w", encoding="utf-8") as f:
            for i, record in enumerate(records):
                if i in kept_indices:
                    f.write(json.dumps(record, ensure_ascii=False) + "\n")
        print(f"\n Written: {output_path}")

    return report
|
||||
|
||||
|
||||
# ============================================================
|
||||
# CLI
|
||||
# ============================================================
|
||||
|
||||
def main():
    """CLI entry point: parse args, run the filter, print a summary.

    Exits with status 1 when the input file does not exist. Optionally
    dumps the full report as JSON via --report-json.
    """
    parser = argparse.ArgumentParser(
        description="Training data quality filter — remove low-quality pairs (#687)"
    )
    parser.add_argument("input", help="Input JSONL file path")
    parser.add_argument("--output", "-o", help="Output file path (default: <input>_filtered.jsonl)")
    parser.add_argument("--threshold", "-t", type=float, default=0.3,
                        help="Minimum composite score to keep (default: 0.3)")
    parser.add_argument("--dry-run", "-n", action="store_true",
                        help="Score only, don't write output")
    parser.add_argument("--verbose", "-v", action="store_true",
                        help="Show worst offenders")
    parser.add_argument("--report-json", "-j", help="Write report as JSON to file")

    args = parser.parse_args()

    if not os.path.exists(args.input):
        print(f"Error: {args.input} not found", file=sys.stderr)
        sys.exit(1)

    print(f"Filtering: {args.input}")
    print(f"Threshold: {args.threshold}")
    print()

    report = filter_jsonl(
        args.input,
        output_path=args.output,
        threshold=args.threshold,
        dry_run=args.dry_run,
        verbose=args.verbose,
    )

    # NOTE(review): filter_jsonl() can return {"error": ..., "total": 0}
    # for a file with no valid records; the prints below would then
    # KeyError — confirm whether that path is reachable here.
    print(f"\n{'=' * 50}")
    print(f" RESULTS")
    print(f"{'=' * 50}")
    print(f" Format: {report['format']}")
    print(f" Total: {report['total_records']}")
    print(f" Kept: {report['kept']}")
    print(f" Removed: {report['removed']} ({report['removal_rate']})")
    print(f" Threshold: {report['threshold']}")
    print(f" Score range: {report['score_distribution']['min']:.3f} - {report['score_distribution']['max']:.3f}")
    print(f" Mean score: {report['score_distribution']['mean']:.3f}")

    if args.report_json:
        with open(args.report_json, "w") as f:
            json.dump(report, f, indent=2)
        print(f"\n Report saved: {args.report_json}")


if __name__ == "__main__":
    main()
|
||||
@@ -1,192 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tests for training data quality filter (#687).
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
# Import from the script
|
||||
import sys
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "scripts"))
|
||||
from filter_training_data import (
|
||||
detect_format,
|
||||
extract_text_fields,
|
||||
score_specificity,
|
||||
score_length_ratio,
|
||||
score_code_correctness,
|
||||
score_record,
|
||||
filter_jsonl,
|
||||
FILLER_PHRASES,
|
||||
VAGUE_RESPONSES,
|
||||
)
|
||||
|
||||
|
||||
class TestFormatDetection(unittest.TestCase):
    """detect_format() recognizes each supported record shape."""

    def test_sharegpt_format(self):
        sample = {"conversations": [{"from": "human", "value": "hi"}]}
        self.assertEqual(detect_format(sample), "sharegpt")

    def test_preference_format(self):
        sample = {"prompt": "do X", "chosen": "done", "rejected": "no"}
        self.assertEqual(detect_format(sample), "preference")

    def test_scene_format(self):
        sample = {"lyric_line": "test", "scene": {"description": "desc"}}
        self.assertEqual(detect_format(sample), "scene")

    def test_pairs_format(self):
        sample = {"terse": "short", "rich": "detailed"}
        self.assertEqual(detect_format(sample), "pairs")

    def test_generic_format(self):
        sample = {"input": "q", "output": "a"}
        self.assertEqual(detect_format(sample), "generic")
|
||||
|
||||
|
||||
class TestExtractTextFields(unittest.TestCase):
    """extract_text_fields() pulls the right (input, output) pair."""

    def test_sharegpt_extraction(self):
        # The system turn must be ignored; last human/gpt turns win.
        sample = {
            "conversations": [
                {"from": "system", "value": "system prompt"},
                {"from": "human", "value": "hello"},
                {"from": "gpt", "value": "hi there"},
            ]
        }
        result = extract_text_fields(sample, "sharegpt")
        self.assertEqual(result, ("hello", "hi there"))

    def test_preference_extraction(self):
        sample = {"prompt": "question", "chosen": "good answer"}
        result = extract_text_fields(sample, "preference")
        self.assertEqual(result, ("question", "good answer"))
|
||||
|
||||
|
||||
class TestSpecificityScoring(unittest.TestCase):
    """score_specificity() separates filler from concrete answers."""

    def test_empty_text(self):
        self.assertEqual(score_specificity(""), 0.0)

    def test_filler_heavy(self):
        filler = "As an AI, I cannot provide that. It's important to note that I'm an AI."
        self.assertLess(score_specificity(filler), 0.3)

    def test_vague_response(self):
        self.assertLess(score_specificity("ok"), 0.2)

    def test_specific_response(self):
        detailed = "Here are the steps:\n1. First, install Python 3.12\n2. Run `pip install numpy`\n3. Execute main.py"
        self.assertGreater(score_specificity(detailed), 0.5)

    def test_code_response(self):
        snippet = "Use this:\n```python\ndef hello():\n print('world')\n```"
        self.assertGreater(score_specificity(snippet), 0.6)
|
||||
|
||||
|
||||
class TestLengthRatio(unittest.TestCase):
    """score_length_ratio() rewards proportionate answers."""

    def test_both_empty(self):
        self.assertEqual(score_length_ratio("", ""), 0.0)

    def test_empty_output(self):
        self.assertEqual(score_length_ratio("hello world", ""), 0.0)

    def test_good_ratio(self):
        question = "short question"
        answer = "This is a reasonable length answer that addresses the question."
        self.assertGreater(score_length_ratio(question, answer), 0.7)

    def test_too_short_output(self):
        question = "This is a very long question with many words that expects a detailed answer"
        self.assertLess(score_length_ratio(question, "ok"), 0.5)
|
||||
|
||||
|
||||
class TestCodeCorrectness(unittest.TestCase):
    """score_code_correctness() validates fenced code blocks."""

    def test_no_code(self):
        # Text without code blocks is never penalized.
        self.assertEqual(score_code_correctness("plain text"), 1.0)

    def test_valid_python(self):
        block = "```python\ndef foo():\n return 42\n```"
        self.assertEqual(score_code_correctness(block), 1.0)

    def test_invalid_python(self):
        block = "```python\ndef foo(\n return 42\n```"
        self.assertLess(score_code_correctness(block), 1.0)

    def test_valid_json(self):
        block = "```json\n{\"key\": \"value\"}\n```"
        self.assertEqual(score_code_correctness(block), 1.0)
|
||||
|
||||
|
||||
class TestFilterJsonl(unittest.TestCase):
    """End-to-end filtering through a temporary JSONL file."""

    def _write_temp_jsonl(self, records):
        """Write *records* to a temp .jsonl file and return its path."""
        f = tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False)
        for r in records:
            f.write(json.dumps(r) + "\n")
        f.close()
        return f.name

    def _cleanup(self, path, report):
        """Remove the input file and, if present, the report's output file."""
        os.unlink(path)
        out = (report or {}).get("output_file", "")
        if out and os.path.exists(out):
            os.unlink(out)

    def test_filter_removes_low_quality(self):
        records = [
            {"conversations": [
                {"from": "human", "value": "How do I sort a list in Python?"},
                {"from": "gpt", "value": "Use `sorted()` or `list.sort()`.\n```python\nnums = [3,1,2]\nnums.sort()\nprint(nums) # [1, 2, 3]\n```"},
            ]},
            {"conversations": [
                {"from": "human", "value": "What is Python?"},
                {"from": "gpt", "value": "ok"},
            ]},
            {"conversations": [
                {"from": "human", "value": "Tell me about databases."},
                {"from": "gpt", "value": "As an AI, I cannot. It's important to note."},
            ]},
        ]
        path = self._write_temp_jsonl(records)
        # Fix: initialize report so the finally block can't raise NameError
        # (masking the real failure) when filter_jsonl itself raises.
        report = None
        try:
            report = filter_jsonl(path, threshold=0.3)
            self.assertEqual(report["total_records"], 3)
            self.assertGreater(report["kept"], 0)
            self.assertGreater(report["removed"], 0)
            self.assertEqual(report["format"], "sharegpt")
        finally:
            self._cleanup(path, report)

    def test_dry_run_no_output(self):
        records = [
            {"prompt": "test", "chosen": "good detailed answer with code: `print(1)`", "rejected": "no"},
        ]
        path = self._write_temp_jsonl(records)
        try:
            out_path = path.replace(".jsonl", "_filtered.jsonl")
            report = filter_jsonl(path, threshold=0.3, dry_run=True)
            self.assertFalse(os.path.exists(out_path))
            self.assertEqual(report["total_records"], 1)
        finally:
            os.unlink(path)

    def test_preference_format(self):
        records = [
            {"prompt": "Write a function", "chosen": "```python\ndef f(): pass\n```", "rejected": ""},
            {"prompt": "Hi", "chosen": "ok", "rejected": "no"},
        ]
        path = self._write_temp_jsonl(records)
        report = None  # see test_filter_removes_low_quality
        try:
            report = filter_jsonl(path, threshold=0.3)
            self.assertEqual(report["format"], "preference")
            self.assertEqual(report["total_records"], 2)
        finally:
            self._cleanup(path, report)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly without a test runner.
    unittest.main()
|
||||
Reference in New Issue
Block a user