#!/usr/bin/env python3 """DPO Pair Quality Validator — Gate before overnight training. Catches bad training pairs before they enter the tightening loop: 1. Near-duplicate chosen/rejected (low contrast) — model learns nothing 2. Near-duplicate prompts across pairs (low diversity) — wasted compute 3. Too-short or empty fields — malformed pairs 4. Chosen not meaningfully richer than rejected — inverted signal 5. Cross-run deduplication — don't retrain on yesterday's pairs Sits between DPOPairGenerator.generate() and .export(). Pairs that fail validation get flagged, not silently dropped — the generator decides whether to export flagged pairs or filter them. Usage standalone: python3 dpo_quality.py ~/.timmy/training-data/dpo-pairs/deepdive_20260413.jsonl """ import hashlib import json import logging import re from collections import Counter from dataclasses import dataclass, field, asdict from pathlib import Path from typing import Any, Dict, List, Optional, Set # Persistent dedup index try: from dedup_index import DedupIndex HAS_DEDUP_INDEX = True except ImportError: HAS_DEDUP_INDEX = False DedupIndex = None logger = logging.getLogger("deepdive.dpo_quality") # --------------------------------------------------------------------------- # Configuration defaults (overridable via config dict) # --------------------------------------------------------------------------- DEFAULT_CONFIG = { # Minimum character lengths "min_prompt_chars": 40, "min_chosen_chars": 80, "min_rejected_chars": 30, # Chosen must be at least this ratio longer than rejected "min_chosen_rejected_ratio": 1.3, # Jaccard similarity thresholds (word-level) "max_chosen_rejected_similarity": 0.70, # Flag if chosen ≈ rejected "max_prompt_prompt_similarity": 0.85, # Flag if two prompts are near-dupes # Cross-run dedup: full-history persistent index # (replaces the old sliding-window approach) "dedup_full_history": True, # What to do with flagged pairs: "drop" or "flag" # "drop" = remove from export entirely # "flag" = add warning to safety_flags but still export "flagged_pair_action": "drop", } # --------------------------------------------------------------------------- # Data structures # --------------------------------------------------------------------------- @dataclass class PairReport: """Validation result for a single DPO pair.""" index: int passed: bool warnings: List[str] = field(default_factory=list) scores: Dict[str, float] = field(default_factory=dict) def to_dict(self) -> Dict[str, Any]: return asdict(self) @dataclass class BatchReport: """Validation result for an entire batch of DPO pairs.""" total_pairs: int passed_pairs: int dropped_pairs: int flagged_pairs: int duplicate_prompts_found: int cross_run_duplicates_found: int pair_reports: List[PairReport] = field(default_factory=list) warnings: List[str] = field(default_factory=list) @property def pass_rate(self) -> float: return self.passed_pairs / max(self.total_pairs, 1) def to_dict(self) -> Dict[str, Any]: d = asdict(self) d["pass_rate"] = round(self.pass_rate, 3) return d def summary(self) -> str: lines = [ f"DPO Quality: {self.passed_pairs}/{self.total_pairs} passed " f"({self.pass_rate:.0%})", f" Dropped: {self.dropped_pairs}, Flagged: {self.flagged_pairs}", ] if self.duplicate_prompts_found: lines.append(f" Duplicate prompts: {self.duplicate_prompts_found}") if self.cross_run_duplicates_found: lines.append(f" Cross-run dupes: {self.cross_run_duplicates_found}") if self.warnings: for w in self.warnings: lines.append(f" ⚠ {w}") return "\n".join(lines) # --------------------------------------------------------------------------- # Core validator # --------------------------------------------------------------------------- class DPOQualityValidator: """Validate DPO pairs for quality before overnight training export. Call validate() with a list of pair dicts to get a BatchReport and a filtered list of pairs that passed validation. """ def __init__(self, config: Optional[Dict[str, Any]] = None, output_dir: Optional[Path] = None): self.cfg = {**DEFAULT_CONFIG, **(config or {})} self.output_dir = Path(output_dir) if output_dir else Path.home() / ".timmy" / "training-data" / "dpo-pairs" # Persistent full-history dedup index self._dedup_index = None if HAS_DEDUP_INDEX and self.cfg.get("dedup_full_history", True): try: self._dedup_index = DedupIndex(self.output_dir) logger.info( f"Full-history dedup index: {self._dedup_index.size} prompts, " f"{self._dedup_index.files_indexed} files" ) except Exception as e: logger.warning(f"Failed to load dedup index, falling back to in-memory: {e}") self._dedup_index = None # Fallback: in-memory hash cache (used if index unavailable) self._history_hashes: Optional[Set[str]] = None logger.info( f"DPOQualityValidator: action={self.cfg['flagged_pair_action']}, " f"max_cr_sim={self.cfg['max_chosen_rejected_similarity']}, " f"max_pp_sim={self.cfg['max_prompt_prompt_similarity']}, " f"dedup={'full-history index' if self._dedup_index else 'in-memory fallback'}" ) # ------------------------------------------------------------------- # Text analysis helpers # ------------------------------------------------------------------- @staticmethod def _tokenize(text: str) -> List[str]: """Simple whitespace + punctuation tokenizer.""" return re.findall(r'\b\w+\b', text.lower()) @staticmethod def _jaccard(tokens_a: List[str], tokens_b: List[str]) -> float: """Word-level Jaccard similarity.""" set_a = set(tokens_a) set_b = set(tokens_b) if not set_a and not set_b: return 1.0 if not set_a or not set_b: return 0.0 return len(set_a & set_b) / len(set_a | set_b) @staticmethod def _content_hash(text: str) -> str: """Stable hash of normalized text for deduplication.""" normalized = " ".join(text.lower().split()) return hashlib.sha256(normalized.encode()).hexdigest()[:16] @staticmethod def _unique_word_ratio(text: str) -> float: """Ratio of unique words to total words (vocabulary diversity).""" words = re.findall(r'\b\w+\b', text.lower()) if not words: return 0.0 return len(set(words)) / len(words) # ------------------------------------------------------------------- # Single-pair validation # ------------------------------------------------------------------- def _validate_pair(self, pair: Dict[str, Any], index: int) -> PairReport: """Run all quality checks on a single pair.""" warnings = [] scores = {} prompt = pair.get("prompt", "") chosen = pair.get("chosen", "") rejected = pair.get("rejected", "") # --- Check 1: Field lengths --- if len(prompt) < self.cfg["min_prompt_chars"]: warnings.append( f"prompt too short ({len(prompt)} chars, min {self.cfg['min_prompt_chars']})" ) if len(chosen) < self.cfg["min_chosen_chars"]: warnings.append( f"chosen too short ({len(chosen)} chars, min {self.cfg['min_chosen_chars']})" ) if len(rejected) < self.cfg["min_rejected_chars"]: warnings.append( f"rejected too short ({len(rejected)} chars, min {self.cfg['min_rejected_chars']})" ) # --- Check 2: Chosen-Rejected length ratio --- if len(rejected) > 0: ratio = len(chosen) / len(rejected) scores["chosen_rejected_ratio"] = round(ratio, 2) if ratio < self.cfg["min_chosen_rejected_ratio"]: warnings.append( f"chosen/rejected ratio too low ({ratio:.2f}, " f"min {self.cfg['min_chosen_rejected_ratio']})" ) else: scores["chosen_rejected_ratio"] = 0.0 warnings.append("rejected is empty") # --- Check 3: Chosen-Rejected content similarity --- chosen_tokens = self._tokenize(chosen) rejected_tokens = self._tokenize(rejected) cr_sim = self._jaccard(chosen_tokens, rejected_tokens) scores["chosen_rejected_similarity"] = round(cr_sim, 3) if cr_sim > self.cfg["max_chosen_rejected_similarity"]: warnings.append( f"chosen≈rejected (Jaccard {cr_sim:.2f}, " f"max {self.cfg['max_chosen_rejected_similarity']})" ) # --- Check 4: Vocabulary diversity in chosen --- chosen_diversity = self._unique_word_ratio(chosen) scores["chosen_vocab_diversity"] = round(chosen_diversity, 3) if chosen_diversity < 0.3: warnings.append( f"low vocabulary diversity in chosen ({chosen_diversity:.2f})" ) # --- Check 5: Chosen should contain substantive content markers --- chosen_lower = chosen.lower() substance_markers = [ "relevance", "implication", "training", "agent", "fleet", "hermes", "deploy", "architecture", "pipeline", "score", "technique", "approach", "recommend", "review", "action", ] marker_hits = sum(1 for m in substance_markers if m in chosen_lower) scores["substance_markers"] = marker_hits if marker_hits < 2: warnings.append( f"chosen lacks substance markers ({marker_hits} found, min 2)" ) passed = len(warnings) == 0 return PairReport(index=index, passed=passed, warnings=warnings, scores=scores) # ------------------------------------------------------------------- # Batch-level validation (cross-pair checks) # ------------------------------------------------------------------- def _check_prompt_duplicates(self, pairs: List[Dict[str, Any]]) -> Dict[int, str]: """Find near-duplicate prompts within the batch. Returns dict mapping pair index → warning string for duplicates. """ prompt_tokens = [] for pair in pairs: prompt_tokens.append(self._tokenize(pair.get("prompt", ""))) dupe_warnings: Dict[int, str] = {} seen_groups: List[Set[int]] = [] for i in range(len(prompt_tokens)): # Skip if already in a dupe group if any(i in g for g in seen_groups): continue group = {i} for j in range(i + 1, len(prompt_tokens)): sim = self._jaccard(prompt_tokens[i], prompt_tokens[j]) if sim > self.cfg["max_prompt_prompt_similarity"]: group.add(j) dupe_warnings[j] = ( f"near-duplicate prompt (Jaccard {sim:.2f} with pair {i})" ) if len(group) > 1: seen_groups.append(group) return dupe_warnings def _check_cross_run_dupes(self, pairs: List[Dict[str, Any]]) -> Dict[int, str]: """Check if any pair prompts exist in full training history. Uses persistent DedupIndex when available (covers all historical JSONL files). Falls back to in-memory scan of ALL files if index module is unavailable. Returns dict mapping pair index → warning string for duplicates. """ dupe_warnings: Dict[int, str] = {} if self._dedup_index: # Full-history lookup via persistent index for i, pair in enumerate(pairs): prompt_hash = self._content_hash(pair.get("prompt", "")) if self._dedup_index.contains(prompt_hash): dupe_warnings[i] = ( f"cross-run duplicate (prompt seen in full history — " f"{self._dedup_index.size} indexed prompts)" ) return dupe_warnings # Fallback: scan all JSONL files in output_dir (no sliding window) if self._history_hashes is None: self._history_hashes = set() if self.output_dir.exists(): jsonl_files = sorted(self.output_dir.glob("deepdive_*.jsonl")) jsonl_files.extend(sorted(self.output_dir.glob("pairs_*.jsonl"))) for path in jsonl_files: try: with open(path) as f: for line in f: line = line.strip() if not line: continue pair_data = json.loads(line) h = self._content_hash(pair_data.get("prompt", "")) self._history_hashes.add(h) except Exception as e: logger.warning(f"Failed to read history file {path}: {e}") logger.info( f"Fallback dedup: loaded {len(self._history_hashes)} hashes " f"from {len(jsonl_files)} files" ) for i, pair in enumerate(pairs): prompt_hash = self._content_hash(pair.get("prompt", "")) if prompt_hash in self._history_hashes: dupe_warnings[i] = "cross-run duplicate (prompt seen in full history)" return dupe_warnings def register_exported_hashes(self, pairs: List[Dict[str, Any]], filename: str) -> None: """After successful export, register new prompt hashes in the index. Called by DPOPairGenerator after writing the JSONL file. """ hashes = [self._content_hash(p.get("prompt", "")) for p in pairs] if self._dedup_index: added = self._dedup_index.add_hashes_and_register(hashes, filename) logger.info( f"Registered {added} new hashes in dedup index " f"(total: {self._dedup_index.size})" ) else: # Update in-memory fallback if self._history_hashes is None: self._history_hashes = set() self._history_hashes.update(hashes) # ------------------------------------------------------------------- # Main validation entry point # ------------------------------------------------------------------- def validate(self, pairs: List[Dict[str, Any]]) -> tuple: """Validate a batch of DPO pairs. Args: pairs: List of pair dicts with {prompt, chosen, rejected, ...} Returns: (filtered_pairs, report): Tuple of filtered pair list and BatchReport. If flagged_pair_action="drop", filtered_pairs excludes bad pairs. If flagged_pair_action="flag", all pairs are returned with safety_flags updated. """ if not pairs: report = BatchReport( total_pairs=0, passed_pairs=0, dropped_pairs=0, flagged_pairs=0, duplicate_prompts_found=0, cross_run_duplicates_found=0, warnings=["Empty pair batch"], ) return [], report action = self.cfg["flagged_pair_action"] pair_dicts = [p if isinstance(p, dict) else p.to_dict() for p in pairs] # Single-pair checks pair_reports = [] for i, pair in enumerate(pair_dicts): report = self._validate_pair(pair, i) pair_reports.append(report) # Cross-pair checks: prompt diversity prompt_dupe_warnings = self._check_prompt_duplicates(pair_dicts) for idx, warning in prompt_dupe_warnings.items(): pair_reports[idx].warnings.append(warning) pair_reports[idx].passed = False # Cross-run dedup crossrun_dupe_warnings = self._check_cross_run_dupes(pair_dicts) for idx, warning in crossrun_dupe_warnings.items(): pair_reports[idx].warnings.append(warning) pair_reports[idx].passed = False # Build filtered output filtered = [] dropped = 0 flagged = 0 for i, (pair, report) in enumerate(zip(pair_dicts, pair_reports)): if report.passed: filtered.append(pair) elif action == "drop": dropped += 1 logger.debug(f"Dropping pair {i}: {report.warnings}") else: # "flag" # Add warnings to safety_flags flags = pair.get("safety_flags", []) flags.append("quality-flagged") for w in report.warnings: flags.append(f"qv:{w[:60]}") pair["safety_flags"] = flags filtered.append(pair) flagged += 1 passed = sum(1 for r in pair_reports if r.passed) batch_warnings = [] if passed == 0 and len(pairs) > 0: batch_warnings.append("ALL pairs failed validation — no training data produced") if len(prompt_dupe_warnings) > len(pairs) * 0.5: batch_warnings.append( f"High prompt duplication: {len(prompt_dupe_warnings)}/{len(pairs)} pairs are near-duplicates" ) # Task type diversity check task_types = Counter(p.get("task_type", "unknown") for p in filtered) if len(task_types) == 1 and len(filtered) > 3: batch_warnings.append( f"Low task-type diversity: all {len(filtered)} pairs are '{list(task_types.keys())[0]}'" ) batch_report = BatchReport( total_pairs=len(pairs), passed_pairs=passed, dropped_pairs=dropped, flagged_pairs=flagged, duplicate_prompts_found=len(prompt_dupe_warnings), cross_run_duplicates_found=len(crossrun_dupe_warnings), pair_reports=pair_reports, warnings=batch_warnings, ) logger.info(batch_report.summary()) return filtered, batch_report # --------------------------------------------------------------------------- # CLI for standalone validation of existing JSONL files # --------------------------------------------------------------------------- def main(): import argparse parser = argparse.ArgumentParser(description="Validate DPO pair quality") parser.add_argument("jsonl_file", type=Path, help="Path to JSONL file with DPO pairs") parser.add_argument("--json", action="store_true", help="Output JSON report") parser.add_argument("--strict", action="store_true", help="Drop flagged pairs (default: flag only)") args = parser.parse_args() if not args.jsonl_file.exists(): print(f"Error: file not found: {args.jsonl_file}") return 1 pairs = [] with open(args.jsonl_file) as f: for line in f: line = line.strip() if line: pairs.append(json.loads(line)) config = {} if args.strict: config["flagged_pair_action"] = "drop" else: config["flagged_pair_action"] = "flag" # Use parent dir of input file as output_dir for history scanning output_dir = args.jsonl_file.parent validator = DPOQualityValidator(config=config, output_dir=output_dir) filtered, report = validator.validate(pairs) if args.json: print(json.dumps(report.to_dict(), indent=2)) else: print("=" * 60) print(" DPO PAIR QUALITY VALIDATION REPORT") print("=" * 60) print(report.summary()) print("-" * 60) for pr in report.pair_reports: status = "✓" if pr.passed else "✗" print(f" [{status}] Pair {pr.index}: ", end="") if pr.passed: print("OK") else: print(", ".join(pr.warnings)) print("=" * 60) print(f"\nFiltered output: {len(filtered)} pairs " f"({'strict/drop' if args.strict else 'flag'} mode)") return 0 if report.passed_pairs > 0 else 2 if __name__ == "__main__": exit(main())