Replace the 5-file sliding window cross-run dedup with a persistent hash index that covers ALL historical training data. Overfitting risk compounds across the full dataset — a 5-file window lets old duplicates leak back into training after enough overnight runs. New module: dedup_index.py (DedupIndex) - Persistent JSON index (.dpo_dedup_index.json) alongside JSONL files - Append-on-export: new prompt hashes registered after each successful export — no full rescan needed for normal operations - Incremental sync: on load, detects JSONL files not yet indexed and ingests them automatically (handles files from other tools) - Full rebuild: rebuild() scans ALL deepdive_*.jsonl + pairs_*.jsonl to reconstruct from scratch (first run, corruption recovery) - Atomic writes (write-to-tmp + rename) to prevent index corruption - Standalone CLI: python3 dedup_index.py <dir> --rebuild --stats Modified: dpo_quality.py - Imports DedupIndex with graceful degradation - Replaces _load_history_hashes() with persistent index lookup - Fallback: if index unavailable, scans ALL files in-memory (not just 5) - New register_exported_hashes() method called after export - Config key: dedup_full_history (replaces dedup_history_files) Modified: dpo_generator.py - Calls validator.register_exported_hashes() after successful export to keep the persistent index current without rescanning Modified: config.yaml - Replaced dedup_history_files: 5 with dedup_full_history: true Tested — 7 integration tests: ✓ Fresh index build from empty directory ✓ Build from 3 existing JSONL files (15 unique hashes) ✓ Incremental sync when new file appears between runs ✓ Append after export + persistence across reloads ✓ Rebuild from scratch (recovers from corruption) ✓ Validator catches day-1 dupe from 20-day history (5-file window miss) ✓ Full pipeline: generate → validate → export → register → re-run detects
534 lines
20 KiB
Python
534 lines
20 KiB
Python
#!/usr/bin/env python3
|
|
"""DPO Pair Quality Validator — Gate before overnight training.
|
|
|
|
Catches bad training pairs before they enter the tightening loop:
|
|
|
|
1. Near-duplicate chosen/rejected (low contrast) — model learns nothing
|
|
2. Near-duplicate prompts across pairs (low diversity) — wasted compute
|
|
3. Too-short or empty fields — malformed pairs
|
|
4. Chosen not meaningfully richer than rejected — inverted signal
|
|
5. Cross-run deduplication — don't retrain on yesterday's pairs
|
|
|
|
Sits between DPOPairGenerator.generate() and .export().
|
|
Pairs that fail validation get flagged, not silently dropped —
|
|
the generator decides whether to export flagged pairs or filter them.
|
|
|
|
Usage standalone:
|
|
python3 dpo_quality.py ~/.timmy/training-data/dpo-pairs/deepdive_20260413.jsonl
|
|
"""
|
|
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import re
|
|
from collections import Counter
|
|
from dataclasses import dataclass, field, asdict
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Set
|
|
|
|
# Optional dependency: the persistent full-history dedup index.  When the
# module cannot be imported, DedupIndex is left as None and the validator
# falls back to scanning history files in memory.
try:
    from dedup_index import DedupIndex
except ImportError:
    DedupIndex = None

# Availability flag derived from the outcome of the import above.
HAS_DEDUP_INDEX = DedupIndex is not None

# Module-level logger shared by everything in this file.
logger = logging.getLogger("deepdive.dpo_quality")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Configuration defaults (overridable via config dict)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
DEFAULT_CONFIG: Dict[str, Any] = dict(
    # --- Minimum field lengths, in characters ---
    min_prompt_chars=40,
    min_chosen_chars=80,
    min_rejected_chars=30,

    # len(chosen) / len(rejected) must reach at least this value
    min_chosen_rejected_ratio=1.3,

    # --- Word-level Jaccard similarity ceilings ---
    max_chosen_rejected_similarity=0.70,  # above this, chosen ≈ rejected
    max_prompt_prompt_similarity=0.85,    # above this, two prompts are near-dupes

    # Cross-run dedup against the full-history persistent index
    # (this key replaces the old sliding-window setting)
    dedup_full_history=True,

    # Handling of pairs that fail validation:
    #   "drop" = remove from export entirely
    #   "flag" = add warning to safety_flags but still export
    flagged_pair_action="drop",
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Data structures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@dataclass
class PairReport:
    """Outcome of the quality checks for one DPO pair.

    ``index`` is the pair's position in the validated batch, ``passed``
    tells whether the pair is considered clean, ``warnings`` collects the
    human-readable reasons attached to the pair and ``scores`` holds the
    numeric measurements recorded while checking it.
    """

    index: int
    passed: bool
    warnings: List[str] = field(default_factory=list)
    scores: Dict[str, float] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return the report as a plain dictionary (via ``dataclasses.asdict``)."""
        as_mapping = asdict(self)
        return as_mapping
|
|
|
|
|
|
@dataclass
class BatchReport:
    """Aggregated validation outcome for a whole batch of DPO pairs.

    Holds the pass/drop/flag counters, the number of in-batch and
    cross-run duplicates, the per-pair reports and any batch-level
    warnings.
    """

    total_pairs: int
    passed_pairs: int
    dropped_pairs: int
    flagged_pairs: int
    duplicate_prompts_found: int
    cross_run_duplicates_found: int
    pair_reports: List[PairReport] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)

    @property
    def pass_rate(self) -> float:
        """Fraction of pairs that passed; an empty batch yields 0.0."""
        # Clamp the denominator to 1 so an empty batch never divides by zero.
        denominator = self.total_pairs if self.total_pairs > 1 else 1
        return self.passed_pairs / denominator

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a dict, plus ``pass_rate`` rounded to 3 places."""
        return {**asdict(self), "pass_rate": round(self.pass_rate, 3)}

    def summary(self) -> str:
        """Render a short multi-line, human-readable summary of the batch."""
        headline = (
            f"DPO Quality: {self.passed_pairs}/{self.total_pairs} passed "
            f"({self.pass_rate:.0%})"
        )
        out = [
            headline,
            f" Dropped: {self.dropped_pairs}, Flagged: {self.flagged_pairs}",
        ]
        # The duplicate counters are only mentioned when non-zero.
        if self.duplicate_prompts_found:
            out.append(f" Duplicate prompts: {self.duplicate_prompts_found}")
        if self.cross_run_duplicates_found:
            out.append(f" Cross-run dupes: {self.cross_run_duplicates_found}")
        # One line per batch-level warning (no-op when there are none).
        out.extend(f" ⚠ {w}" for w in self.warnings)
        return "\n".join(out)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Core validator
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class DPOQualityValidator:
    """Validate DPO pairs for quality before overnight training export.

    Call validate() with a list of pair dicts to get a BatchReport
    and a filtered list of pairs that passed validation.

    Configuration is ``DEFAULT_CONFIG`` overlaid with the ``config`` dict
    given to the constructor.  Three optional keys tune checks that
    otherwise use built-in defaults:

    * ``min_chosen_vocab_diversity`` (default 0.3)
    * ``min_substance_markers`` (default 2)
    * ``substance_markers`` (default: the built-in keyword list)
    """

    # Word pattern shared by the tokenizer and the vocabulary-diversity
    # helper so the two can never drift apart.
    _WORD_RE = re.compile(r'\b\w+\b')

    # Built-in defaults for the checks that have no DEFAULT_CONFIG entry.
    _MIN_CHOSEN_VOCAB_DIVERSITY = 0.3
    _MIN_SUBSTANCE_MARKERS = 2
    _SUBSTANCE_MARKERS = (
        "relevance", "implication", "training", "agent", "fleet",
        "hermes", "deploy", "architecture", "pipeline", "score",
        "technique", "approach", "recommend", "review", "action",
    )

    def __init__(self, config: Optional[Dict[str, Any]] = None,
                 output_dir: Optional[Path] = None):
        """Build a validator.

        Args:
            config: Overrides merged on top of ``DEFAULT_CONFIG``.
            output_dir: Directory holding historical JSONL exports (and the
                persistent dedup index).  Defaults to
                ``~/.timmy/training-data/dpo-pairs``.
        """
        self.cfg = {**DEFAULT_CONFIG, **(config or {})}
        self.output_dir = (
            Path(output_dir) if output_dir
            else Path.home() / ".timmy" / "training-data" / "dpo-pairs"
        )

        # Persistent full-history dedup index.  Any failure while loading it
        # is deliberately non-fatal: we log and use the in-memory fallback.
        self._dedup_index = None
        if HAS_DEDUP_INDEX and self.cfg.get("dedup_full_history", True):
            try:
                self._dedup_index = DedupIndex(self.output_dir)
                logger.info(
                    f"Full-history dedup index: {self._dedup_index.size} prompts, "
                    f"{self._dedup_index.files_indexed} files"
                )
            except Exception as e:
                logger.warning(f"Failed to load dedup index, falling back to in-memory: {e}")
                self._dedup_index = None

        # Fallback: in-memory hash cache (used if index unavailable).
        # Populated lazily on the first cross-run check.
        self._history_hashes: Optional[Set[str]] = None

        # Identity check, not truthiness: an index object that happens to
        # evaluate falsy (e.g. an empty one) is still a usable index.
        dedup_mode = (
            'full-history index' if self._dedup_index is not None
            else 'in-memory fallback'
        )
        logger.info(
            f"DPOQualityValidator: action={self.cfg['flagged_pair_action']}, "
            f"max_cr_sim={self.cfg['max_chosen_rejected_similarity']}, "
            f"max_pp_sim={self.cfg['max_prompt_prompt_similarity']}, "
            f"dedup={dedup_mode}"
        )

    # -------------------------------------------------------------------
    # Text analysis helpers
    # -------------------------------------------------------------------

    @staticmethod
    def _text_field(pair: Dict[str, Any], key: str) -> str:
        """Return ``pair[key]`` as a string; missing or ``None`` becomes ``""``.

        JSONL records may carry ``null`` or non-string values; coercing them
        here lets every check report a warning instead of raising.
        """
        value = pair.get(key)
        if value is None:
            return ""
        return value if isinstance(value, str) else str(value)

    @staticmethod
    def _tokenize(text: str) -> List[str]:
        """Simple whitespace + punctuation tokenizer (lower-cased words)."""
        return DPOQualityValidator._WORD_RE.findall(text.lower())

    @staticmethod
    def _jaccard(tokens_a: List[str], tokens_b: List[str]) -> float:
        """Word-level Jaccard similarity.

        Two empty inputs count as identical (1.0); exactly one empty input
        yields 0.0.
        """
        set_a = set(tokens_a)
        set_b = set(tokens_b)
        if not set_a and not set_b:
            return 1.0
        if not set_a or not set_b:
            return 0.0
        return len(set_a & set_b) / len(set_a | set_b)

    @staticmethod
    def _content_hash(text: str) -> str:
        """Stable hash of normalized text for deduplication.

        Normalization lower-cases the text and collapses whitespace runs;
        the result is the first 16 hex digits of the SHA-256 digest.
        """
        normalized = " ".join(text.lower().split())
        return hashlib.sha256(normalized.encode("utf-8")).hexdigest()[:16]

    @staticmethod
    def _unique_word_ratio(text: str) -> float:
        """Ratio of unique words to total words (vocabulary diversity)."""
        words = DPOQualityValidator._WORD_RE.findall(text.lower())
        if not words:
            return 0.0
        return len(set(words)) / len(words)

    # -------------------------------------------------------------------
    # Single-pair validation
    # -------------------------------------------------------------------

    def _validate_pair(self, pair: Dict[str, Any], index: int) -> PairReport:
        """Run all single-pair quality checks.

        Args:
            pair: Dict with ``prompt``, ``chosen`` and ``rejected`` entries.
            index: Position of the pair in the batch (copied into the report).

        Returns:
            A PairReport whose ``passed`` is True only when no check warned.
        """
        warnings: List[str] = []
        scores: Dict[str, float] = {}

        prompt = self._text_field(pair, "prompt")
        chosen = self._text_field(pair, "chosen")
        rejected = self._text_field(pair, "rejected")

        # --- Check 1: Field lengths ---
        length_checks = (
            ("prompt", prompt, "min_prompt_chars"),
            ("chosen", chosen, "min_chosen_chars"),
            ("rejected", rejected, "min_rejected_chars"),
        )
        for label, text, cfg_key in length_checks:
            minimum = self.cfg[cfg_key]
            if len(text) < minimum:
                warnings.append(
                    f"{label} too short ({len(text)} chars, min {minimum})"
                )

        # --- Check 2: Chosen-Rejected length ratio ---
        if rejected:
            ratio = len(chosen) / len(rejected)
            scores["chosen_rejected_ratio"] = round(ratio, 2)
            if ratio < self.cfg["min_chosen_rejected_ratio"]:
                warnings.append(
                    f"chosen/rejected ratio too low ({ratio:.2f}, "
                    f"min {self.cfg['min_chosen_rejected_ratio']})"
                )
        else:
            scores["chosen_rejected_ratio"] = 0.0
            warnings.append("rejected is empty")

        # --- Check 3: Chosen-Rejected content similarity ---
        cr_sim = self._jaccard(self._tokenize(chosen), self._tokenize(rejected))
        scores["chosen_rejected_similarity"] = round(cr_sim, 3)
        if cr_sim > self.cfg["max_chosen_rejected_similarity"]:
            warnings.append(
                f"chosen≈rejected (Jaccard {cr_sim:.2f}, "
                f"max {self.cfg['max_chosen_rejected_similarity']})"
            )

        # --- Check 4: Vocabulary diversity in chosen ---
        min_diversity = self.cfg.get(
            "min_chosen_vocab_diversity", self._MIN_CHOSEN_VOCAB_DIVERSITY
        )
        chosen_diversity = self._unique_word_ratio(chosen)
        scores["chosen_vocab_diversity"] = round(chosen_diversity, 3)
        if chosen_diversity < min_diversity:
            warnings.append(
                f"low vocabulary diversity in chosen ({chosen_diversity:.2f})"
            )

        # --- Check 5: Chosen should contain substantive content markers ---
        # Substring match on the lower-cased text, so "deploy" also counts
        # "deployment", "deployed", etc.
        markers = self.cfg.get("substance_markers", self._SUBSTANCE_MARKERS)
        min_markers = self.cfg.get(
            "min_substance_markers", self._MIN_SUBSTANCE_MARKERS
        )
        chosen_lower = chosen.lower()
        marker_hits = sum(1 for m in markers if m in chosen_lower)
        scores["substance_markers"] = marker_hits
        if marker_hits < min_markers:
            warnings.append(
                f"chosen lacks substance markers ({marker_hits} found, min {min_markers})"
            )

        return PairReport(index=index, passed=not warnings,
                          warnings=warnings, scores=scores)

    # -------------------------------------------------------------------
    # Batch-level validation (cross-pair checks)
    # -------------------------------------------------------------------

    def _check_prompt_duplicates(self, pairs: List[Dict[str, Any]]) -> Dict[int, str]:
        """Find near-duplicate prompts within the batch.

        The first pair of each duplicate group is kept clean; only the later
        members receive a warning.

        Returns dict mapping pair index → warning string for duplicates.
        """
        token_lists = [
            self._tokenize(self._text_field(pair, "prompt")) for pair in pairs
        ]
        threshold = self.cfg["max_prompt_prompt_similarity"]

        dupe_warnings: Dict[int, str] = {}
        grouped: Set[int] = set()  # indices already assigned to a dupe group

        for i, base_tokens in enumerate(token_lists):
            # Skip if already in a dupe group
            if i in grouped:
                continue
            members: Set[int] = set()
            for j in range(i + 1, len(token_lists)):
                sim = self._jaccard(base_tokens, token_lists[j])
                if sim > threshold:
                    members.add(j)
                    dupe_warnings[j] = (
                        f"near-duplicate prompt (Jaccard {sim:.2f} with pair {i})"
                    )
            if members:
                grouped.add(i)
                grouped |= members

        return dupe_warnings

    def _scan_history_files(self) -> Set[str]:
        """Hash the prompt of every record in every history JSONL file.

        Scans ``deepdive_*.jsonl`` and ``pairs_*.jsonl`` in ``output_dir``.
        Best-effort: unreadable files and malformed lines are logged and
        skipped, so one bad line no longer hides the rest of its file.
        """
        hashes: Set[str] = set()
        if not self.output_dir.exists():
            return hashes

        jsonl_files = sorted(self.output_dir.glob("deepdive_*.jsonl"))
        jsonl_files.extend(sorted(self.output_dir.glob("pairs_*.jsonl")))

        for path in jsonl_files:
            try:
                with open(path, encoding="utf-8") as f:
                    for lineno, raw in enumerate(f, 1):
                        line = raw.strip()
                        if not line:
                            continue
                        try:
                            record = json.loads(line)
                        except ValueError as e:  # includes json.JSONDecodeError
                            logger.warning(
                                f"Skipping malformed line {lineno} in history file {path}: {e}"
                            )
                            continue
                        if not isinstance(record, dict):
                            logger.warning(
                                f"Skipping non-object line {lineno} in history file {path}"
                            )
                            continue
                        hashes.add(
                            self._content_hash(self._text_field(record, "prompt"))
                        )
            except Exception as e:
                # Deliberately broad: history loading must never abort a run.
                logger.warning(f"Failed to read history file {path}: {e}")

        logger.info(
            f"Fallback dedup: loaded {len(hashes)} hashes "
            f"from {len(jsonl_files)} files"
        )
        return hashes

    def _check_cross_run_dupes(self, pairs: List[Dict[str, Any]]) -> Dict[int, str]:
        """Check if any pair prompts exist in full training history.

        Uses persistent DedupIndex when available (covers all historical
        JSONL files). Falls back to in-memory scan of ALL files if index
        module is unavailable.

        Returns dict mapping pair index → warning string for duplicates.
        """
        dupe_warnings: Dict[int, str] = {}
        prompt_hashes = [
            self._content_hash(self._text_field(pair, "prompt")) for pair in pairs
        ]

        # Identity check: a falsy-but-present index must still be consulted.
        if self._dedup_index is not None:
            # Full-history lookup via persistent index
            index = self._dedup_index
            for i, prompt_hash in enumerate(prompt_hashes):
                if index.contains(prompt_hash):
                    dupe_warnings[i] = (
                        f"cross-run duplicate (prompt seen in full history — "
                        f"{index.size} indexed prompts)"
                    )
            return dupe_warnings

        # Fallback: scan all JSONL files in output_dir (no sliding window).
        # The scan runs once and is cached for the validator's lifetime.
        if self._history_hashes is None:
            self._history_hashes = self._scan_history_files()

        for i, prompt_hash in enumerate(prompt_hashes):
            if prompt_hash in self._history_hashes:
                dupe_warnings[i] = "cross-run duplicate (prompt seen in full history)"

        return dupe_warnings

    def register_exported_hashes(self, pairs: List[Dict[str, Any]],
                                 filename: str) -> None:
        """After successful export, register new prompt hashes in the index.

        Called by DPOPairGenerator after writing the JSONL file.

        Args:
            pairs: The pairs that were written.
            filename: Name of the exported file, passed through to the index.
        """
        hashes = [
            self._content_hash(self._text_field(p, "prompt")) for p in pairs
        ]

        if self._dedup_index is not None:
            added = self._dedup_index.add_hashes_and_register(hashes, filename)
            logger.info(
                f"Registered {added} new hashes in dedup index "
                f"(total: {self._dedup_index.size})"
            )
        else:
            # Update in-memory fallback
            if self._history_hashes is None:
                self._history_hashes = set()
            self._history_hashes.update(hashes)

    # -------------------------------------------------------------------
    # Main validation entry point
    # -------------------------------------------------------------------

    def validate(self, pairs: List[Dict[str, Any]]) -> tuple:
        """Validate a batch of DPO pairs.

        Args:
            pairs: List of pair dicts with {prompt, chosen, rejected, ...}.
                Non-dict items are converted through their ``to_dict()``.

        Returns:
            (filtered_pairs, report): Tuple of filtered pair list and BatchReport.
            If flagged_pair_action="drop", filtered_pairs excludes bad pairs.
            Any other value behaves as "flag": all pairs are returned with
            safety_flags updated on the failing ones.
        """
        if not pairs:
            report = BatchReport(
                total_pairs=0, passed_pairs=0, dropped_pairs=0,
                flagged_pairs=0, duplicate_prompts_found=0,
                cross_run_duplicates_found=0,
                warnings=["Empty pair batch"],
            )
            return [], report

        action = self.cfg["flagged_pair_action"]
        pair_dicts = [p if isinstance(p, dict) else p.to_dict() for p in pairs]

        # Single-pair checks
        pair_reports = [
            self._validate_pair(pair, i) for i, pair in enumerate(pair_dicts)
        ]

        # Cross-pair checks (prompt diversity), then cross-run dedup.  Both
        # attach their warning to the affected report and fail the pair.
        prompt_dupe_warnings = self._check_prompt_duplicates(pair_dicts)
        crossrun_dupe_warnings = self._check_cross_run_dupes(pair_dicts)
        for extra_warnings in (prompt_dupe_warnings, crossrun_dupe_warnings):
            for idx, warning in extra_warnings.items():
                pair_reports[idx].warnings.append(warning)
                pair_reports[idx].passed = False

        # Build filtered output
        filtered = []
        dropped = 0
        flagged = 0

        for i, (pair, report) in enumerate(zip(pair_dicts, pair_reports)):
            if report.passed:
                filtered.append(pair)
            elif action == "drop":
                dropped += 1
                logger.debug(f"Dropping pair {i}: {report.warnings}")
            else:  # "flag"
                # Add warnings to safety_flags.  Work on a copy so the
                # caller's original list is not extended in place, and
                # tolerate a missing or None value.
                flags = list(pair.get("safety_flags") or [])
                flags.append("quality-flagged")
                flags.extend(f"qv:{w[:60]}" for w in report.warnings)
                pair["safety_flags"] = flags
                filtered.append(pair)
                flagged += 1

        passed = sum(1 for r in pair_reports if r.passed)

        batch_warnings = []
        if passed == 0:
            batch_warnings.append("ALL pairs failed validation — no training data produced")
        if len(prompt_dupe_warnings) > len(pairs) * 0.5:
            batch_warnings.append(
                f"High prompt duplication: {len(prompt_dupe_warnings)}/{len(pairs)} pairs are near-duplicates"
            )

        # Task type diversity check
        task_types = Counter(p.get("task_type", "unknown") for p in filtered)
        if len(task_types) == 1 and len(filtered) > 3:
            batch_warnings.append(
                f"Low task-type diversity: all {len(filtered)} pairs are '{list(task_types.keys())[0]}'"
            )

        batch_report = BatchReport(
            total_pairs=len(pairs),
            passed_pairs=passed,
            dropped_pairs=dropped,
            flagged_pairs=flagged,
            duplicate_prompts_found=len(prompt_dupe_warnings),
            cross_run_duplicates_found=len(crossrun_dupe_warnings),
            pair_reports=pair_reports,
            warnings=batch_warnings,
        )

        logger.info(batch_report.summary())
        return filtered, batch_report
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CLI for standalone validation of existing JSONL files
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def main():
    """CLI entry point: validate an existing JSONL file of DPO pairs.

    Returns:
        Process exit code: 0 when at least one pair passed validation,
        1 when the input file is missing, unreadable or contains invalid
        JSON, and 2 when no pair passed.
    """
    import argparse
    import sys

    parser = argparse.ArgumentParser(description="Validate DPO pair quality")
    parser.add_argument("jsonl_file", type=Path, help="Path to JSONL file with DPO pairs")
    parser.add_argument("--json", action="store_true", help="Output JSON report")
    parser.add_argument("--strict", action="store_true",
                        help="Drop flagged pairs (default: flag only)")
    args = parser.parse_args()

    # Errors go to stderr so that --json output on stdout stays parseable.
    if not args.jsonl_file.exists():
        print(f"Error: file not found: {args.jsonl_file}", file=sys.stderr)
        return 1

    pairs = []
    try:
        # Explicit UTF-8: do not depend on the platform's locale encoding.
        with open(args.jsonl_file, encoding="utf-8") as f:
            for lineno, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    pairs.append(json.loads(line))
                except json.JSONDecodeError as e:
                    print(
                        f"Error: {args.jsonl_file}:{lineno}: invalid JSON: {e}",
                        file=sys.stderr,
                    )
                    return 1
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error: cannot read {args.jsonl_file}: {e}", file=sys.stderr)
        return 1

    config = {"flagged_pair_action": "drop" if args.strict else "flag"}

    # Use parent dir of input file as output_dir for history scanning.
    # NOTE(review): when the input file itself matches deepdive_*.jsonl or
    # pairs_*.jsonl, the in-memory fallback scan reads it as history, so
    # every pair is reported as a cross-run duplicate of itself.  The
    # persistent index presumably behaves the same once the file has been
    # ingested — confirm and consider excluding the input file.
    output_dir = args.jsonl_file.parent
    validator = DPOQualityValidator(config=config, output_dir=output_dir)
    filtered, report = validator.validate(pairs)

    if args.json:
        print(json.dumps(report.to_dict(), indent=2))
    else:
        print("=" * 60)
        print(" DPO PAIR QUALITY VALIDATION REPORT")
        print("=" * 60)
        print(report.summary())
        print("-" * 60)
        for pr in report.pair_reports:
            status = "✓" if pr.passed else "✗"
            print(f" [{status}] Pair {pr.index}: ", end="")
            if pr.passed:
                print("OK")
            else:
                print(", ".join(pr.warnings))
        print("=" * 60)
        print(f"\nFiltered output: {len(filtered)} pairs "
              f"({'strict/drop' if args.strict else 'flag'} mode)")

    return 0 if report.passed_pairs > 0 else 2
|
|
|
|
|
|
# Script entry point.  SystemExit is raised directly instead of calling the
# builtin exit(): that helper is injected by the `site` module for
# interactive use and is not guaranteed to exist (e.g. under `python -S`).
if __name__ == "__main__":
    raise SystemExit(main())
|