diff --git a/intelligence/deepdive/config.yaml b/intelligence/deepdive/config.yaml
index 3f6706ae..b65cdd48 100644
--- a/intelligence/deepdive/config.yaml
+++ b/intelligence/deepdive/config.yaml
@@ -99,6 +99,16 @@ deepdive:
       - "summarize"    # Paper summary → fleet-grounded analysis
       - "relevance"    # Relevance analysis → scored fleet context
       - "implication"  # Implications → actionable insight
+    validation:
+      enabled: true
+      flagged_pair_action: "drop"            # "drop" = remove bad pairs, "flag" = export with warning
+      min_prompt_chars: 40                   # Minimum prompt length
+      min_chosen_chars: 80                   # Minimum chosen response length
+      min_rejected_chars: 30                 # Minimum rejected response length
+      min_chosen_rejected_ratio: 1.3         # Chosen must be ≥1.3x longer than rejected
+      max_chosen_rejected_similarity: 0.70   # Max Jaccard overlap between chosen/rejected
+      max_prompt_prompt_similarity: 0.85     # Max Jaccard overlap between prompts (dedup)
+      dedup_history_files: 5                 # How many recent JSONL files to scan for cross-run dedup
 
   # Phase 0: Fleet Context Grounding
   fleet_context:
diff --git a/intelligence/deepdive/dpo_generator.py b/intelligence/deepdive/dpo_generator.py
index 988468df..3a9bc33f 100644
--- a/intelligence/deepdive/dpo_generator.py
+++ b/intelligence/deepdive/dpo_generator.py
@@ -22,6 +22,14 @@
 from datetime import datetime, timezone
 from pathlib import Path
 from typing import Any, Dict, List, Optional
 
+# Quality validation gate
+try:
+    from dpo_quality import DPOQualityValidator
+    HAS_DPO_QUALITY = True
+except ImportError:
+    HAS_DPO_QUALITY = False
+    DPOQualityValidator = None
+
 logger = logging.getLogger("deepdive.dpo_generator")
 
@@ -69,6 +77,20 @@ class DPOPairGenerator:
         self.max_pairs_per_run = cfg.get("max_pairs_per_run", 30)
         self.pair_types = cfg.get("pair_types", ["summarize", "relevance", "implication"])
+        # Quality validator
+        self.validator = None
+        validation_cfg = cfg.get("validation", {})
+        if HAS_DPO_QUALITY and validation_cfg.get("enabled", True):
+            self.validator = DPOQualityValidator(
+                config=validation_cfg,
+                output_dir=self.output_dir,
+            )
+            logger.info("DPO quality validator enabled")
+        elif not HAS_DPO_QUALITY:
+            logger.info("DPO quality validator not available (dpo_quality module not found)")
+        else:
+            logger.info("DPO quality validator disabled in config")
+
         logger.info(
             f"DPOPairGenerator: output_dir={self.output_dir}, "
             f"pair_types={self.pair_types}, max_pairs={self.max_pairs_per_run}"
         )
@@ -339,7 +361,7 @@ class DPOPairGenerator:
         fleet_context_text: str = "",
         session_id: Optional[str] = None,
     ) -> Dict[str, Any]:
-        """Full Phase 3.5: generate + export DPO pairs.
+        """Full Phase 3.5: generate → validate → export DPO pairs.
 
         Returns summary dict for pipeline result aggregation.
         """
@@ -349,9 +371,46 @@ class DPOPairGenerator:
             return {
                 "status": "skipped",
                 "pairs_generated": 0,
+                "pairs_validated": 0,
                 "output_path": None,
             }
 
+        # Quality gate: validate before export
+        quality_report = None
+        if self.validator:
+            pair_dicts = [p.to_dict() for p in pairs]
+            filtered_dicts, quality_report = self.validator.validate(pair_dicts)
+
+            logger.info(
+                f"Quality gate: {quality_report.passed_pairs}/{quality_report.total_pairs} "
+                f"passed, {quality_report.dropped_pairs} dropped, "
+                f"{quality_report.flagged_pairs} flagged"
+            )
+
+            if not filtered_dicts:
+                return {
+                    "status": "all_filtered",
+                    "pairs_generated": len(pairs),
+                    "pairs_validated": 0,
+                    "output_path": None,
+                    "quality": quality_report.to_dict(),
+                }
+
+            # Rebuild DPOPair objects from filtered dicts
+            pairs = [
+                DPOPair(
+                    prompt=d["prompt"],
+                    chosen=d["chosen"],
+                    rejected=d["rejected"],
+                    task_type=d.get("task_type", "unknown"),
+                    evidence_ids=d.get("evidence_ids", []),
+                    source_session=d.get("source_session", {}),
+                    safety_flags=d.get("safety_flags", []),
+                    metadata=d.get("metadata", {}),
+                )
+                for d in filtered_dicts
+            ]
+
         output_path = self.export(pairs, session_id)
 
         # Summary by task type
@@ -359,10 +418,14 @@
         for p in pairs:
             type_counts[p.task_type] = type_counts.get(p.task_type, 0) + 1
 
-        return {
+        result = {
             "status": "success",
-            "pairs_generated": len(pairs),
+            "pairs_generated": len(pairs) + (quality_report.dropped_pairs if quality_report else 0),
+            "pairs_validated": len(pairs),
             "output_path": str(output_path),
             "pair_types": type_counts,
             "output_dir": str(self.output_dir),
         }
+        if quality_report:
+            result["quality"] = quality_report.to_dict()
+        return result
diff --git a/intelligence/deepdive/dpo_quality.py b/intelligence/deepdive/dpo_quality.py
new file mode 100644
index 00000000..0f9a6991
--- /dev/null
+++ b/intelligence/deepdive/dpo_quality.py
@@ -0,0 +1,489 @@
+#!/usr/bin/env python3
+"""DPO Pair Quality Validator — Gate before overnight training.
+
+Catches bad training pairs before they enter the tightening loop:
+
+1. Near-duplicate chosen/rejected (low contrast) — model learns nothing
+2. Near-duplicate prompts across pairs (low diversity) — wasted compute
+3. Too-short or empty fields — malformed pairs
+4. Chosen not meaningfully richer than rejected — inverted signal
+5. Cross-run deduplication — don't retrain on yesterday's pairs
+
+Sits between DPOPairGenerator.generate() and .export().
+Pairs that fail validation get flagged, not silently dropped —
+the generator decides whether to export flagged pairs or filter them.
+
+Usage standalone:
+    python3 dpo_quality.py ~/.timmy/training-data/dpo-pairs/deepdive_20260413.jsonl
+"""
+
+import hashlib
+import json
+import logging
+import re
+from collections import Counter
+from dataclasses import dataclass, field, asdict
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Set
+
+logger = logging.getLogger("deepdive.dpo_quality")
+
+
+# ---------------------------------------------------------------------------
+# Configuration defaults (overridable via config dict)
+# ---------------------------------------------------------------------------
+
+DEFAULT_CONFIG = {
+    # Minimum character lengths
+    "min_prompt_chars": 40,
+    "min_chosen_chars": 80,
+    "min_rejected_chars": 30,
+
+    # Chosen must be at least this ratio longer than rejected
+    "min_chosen_rejected_ratio": 1.3,
+
+    # Jaccard similarity thresholds (word-level)
+    "max_chosen_rejected_similarity": 0.70,  # Flag if chosen ≈ rejected
+    "max_prompt_prompt_similarity": 0.85,    # Flag if two prompts are near-dupes
+
+    # Cross-run dedup: hash window (how many recent JSONL files to scan)
+    "dedup_history_files": 5,
+
+    # What to do with flagged pairs: "drop" or "flag"
+    # "drop" = remove from export entirely
+    # "flag" = add warning to safety_flags but still export
+    "flagged_pair_action": "drop",
+}
+
+
+# ---------------------------------------------------------------------------
+# Data structures
+# ---------------------------------------------------------------------------
+
+@dataclass
+class PairReport:
+    """Validation result for a single DPO pair."""
+    index: int
+    passed: bool
+    warnings: List[str] = field(default_factory=list)
+    scores: Dict[str, float] = field(default_factory=dict)
+
+    def to_dict(self) -> Dict[str, Any]:
+        return asdict(self)
+
+
+@dataclass
+class BatchReport:
+    """Validation result for an entire batch of DPO pairs."""
+    total_pairs: int
+    passed_pairs: int
+    dropped_pairs: int
+    flagged_pairs: int
+    duplicate_prompts_found: int
+    cross_run_duplicates_found: int
+    pair_reports: List[PairReport] = field(default_factory=list)
+    warnings: List[str] = field(default_factory=list)
+
+    @property
+    def pass_rate(self) -> float:
+        return self.passed_pairs / max(self.total_pairs, 1)
+
+    def to_dict(self) -> Dict[str, Any]:
+        d = asdict(self)
+        d["pass_rate"] = round(self.pass_rate, 3)
+        return d
+
+    def summary(self) -> str:
+        lines = [
+            f"DPO Quality: {self.passed_pairs}/{self.total_pairs} passed "
+            f"({self.pass_rate:.0%})",
+            f"  Dropped: {self.dropped_pairs}, Flagged: {self.flagged_pairs}",
+        ]
+        if self.duplicate_prompts_found:
+            lines.append(f"  Duplicate prompts: {self.duplicate_prompts_found}")
+        if self.cross_run_duplicates_found:
+            lines.append(f"  Cross-run dupes: {self.cross_run_duplicates_found}")
+        if self.warnings:
+            for w in self.warnings:
+                lines.append(f"  ⚠ {w}")
+        return "\n".join(lines)
+
+
+# ---------------------------------------------------------------------------
+# Core validator
+# ---------------------------------------------------------------------------
+
+class DPOQualityValidator:
+    """Validate DPO pairs for quality before overnight training export.
+
+    Call validate() with a list of pair dicts to get a BatchReport
+    and a filtered list of pairs that passed validation.
+    """
+
+    def __init__(self, config: Optional[Dict[str, Any]] = None,
+                 output_dir: Optional[Path] = None):
+        self.cfg = {**DEFAULT_CONFIG, **(config or {})}
+        self.output_dir = Path(output_dir) if output_dir else Path.home() / ".timmy" / "training-data" / "dpo-pairs"
+
+        # Cache of prompt hashes from previous runs (loaded lazily)
+        self._history_hashes: Optional[Set[str]] = None
+
+        logger.info(
+            f"DPOQualityValidator: action={self.cfg['flagged_pair_action']}, "
+            f"max_cr_sim={self.cfg['max_chosen_rejected_similarity']}, "
+            f"max_pp_sim={self.cfg['max_prompt_prompt_similarity']}"
+        )
+
+    # -------------------------------------------------------------------
+    # Text analysis helpers
+    # -------------------------------------------------------------------
+
+    @staticmethod
+    def _tokenize(text: str) -> List[str]:
+        """Simple whitespace + punctuation tokenizer."""
+        return re.findall(r'\b\w+\b', text.lower())
+
+    @staticmethod
+    def _jaccard(tokens_a: List[str], tokens_b: List[str]) -> float:
+        """Word-level Jaccard similarity."""
+        set_a = set(tokens_a)
+        set_b = set(tokens_b)
+        if not set_a and not set_b:
+            return 1.0
+        if not set_a or not set_b:
+            return 0.0
+        return len(set_a & set_b) / len(set_a | set_b)
+
+    @staticmethod
+    def _content_hash(text: str) -> str:
+        """Stable hash of normalized text for deduplication."""
+        normalized = " ".join(text.lower().split())
+        return hashlib.sha256(normalized.encode()).hexdigest()[:16]
+
+    @staticmethod
+    def _unique_word_ratio(text: str) -> float:
+        """Ratio of unique words to total words (vocabulary diversity)."""
+        words = re.findall(r'\b\w+\b', text.lower())
+        if not words:
+            return 0.0
+        return len(set(words)) / len(words)
+
+    # -------------------------------------------------------------------
+    # Single-pair validation
+    # -------------------------------------------------------------------
+
+    def _validate_pair(self, pair: Dict[str, Any], index: int) -> PairReport:
+        """Run all quality checks on a single pair."""
+        warnings = []
+        scores = {}
+
+        prompt = pair.get("prompt", "")
+        chosen = pair.get("chosen", "")
+        rejected = pair.get("rejected", "")
+
+        # --- Check 1: Field lengths ---
+        if len(prompt) < self.cfg["min_prompt_chars"]:
+            warnings.append(
+                f"prompt too short ({len(prompt)} chars, min {self.cfg['min_prompt_chars']})"
+            )
+        if len(chosen) < self.cfg["min_chosen_chars"]:
+            warnings.append(
+                f"chosen too short ({len(chosen)} chars, min {self.cfg['min_chosen_chars']})"
+            )
+        if len(rejected) < self.cfg["min_rejected_chars"]:
+            warnings.append(
+                f"rejected too short ({len(rejected)} chars, min {self.cfg['min_rejected_chars']})"
+            )
+
+        # --- Check 2: Chosen-Rejected length ratio ---
+        if len(rejected) > 0:
+            ratio = len(chosen) / len(rejected)
+            scores["chosen_rejected_ratio"] = round(ratio, 2)
+            if ratio < self.cfg["min_chosen_rejected_ratio"]:
+                warnings.append(
+                    f"chosen/rejected ratio too low ({ratio:.2f}, "
+                    f"min {self.cfg['min_chosen_rejected_ratio']})"
+                )
+        else:
+            scores["chosen_rejected_ratio"] = 0.0
+            warnings.append("rejected is empty")
+
+        # --- Check 3: Chosen-Rejected content similarity ---
+        chosen_tokens = self._tokenize(chosen)
+        rejected_tokens = self._tokenize(rejected)
+        cr_sim = self._jaccard(chosen_tokens, rejected_tokens)
+        scores["chosen_rejected_similarity"] = round(cr_sim, 3)
+
+        if cr_sim > self.cfg["max_chosen_rejected_similarity"]:
+            warnings.append(
+                f"chosen≈rejected (Jaccard {cr_sim:.2f}, "
+                f"max {self.cfg['max_chosen_rejected_similarity']})"
+            )
+
+        # --- Check 4: Vocabulary diversity in chosen ---
+        chosen_diversity = self._unique_word_ratio(chosen)
+        scores["chosen_vocab_diversity"] = round(chosen_diversity, 3)
+        if chosen_diversity < 0.3:
+            warnings.append(
+                f"low vocabulary diversity in chosen ({chosen_diversity:.2f})"
+            )
+
+        # --- Check 5: Chosen should contain substantive content markers ---
+        chosen_lower = chosen.lower()
+        substance_markers = [
+            "relevance", "implication", "training", "agent", "fleet",
+            "hermes", "deploy", "architecture", "pipeline", "score",
+            "technique", "approach", "recommend", "review", "action",
+        ]
+        marker_hits = sum(1 for m in substance_markers if m in chosen_lower)
+        scores["substance_markers"] = marker_hits
+        if marker_hits < 2:
+            warnings.append(
+                f"chosen lacks substance markers ({marker_hits} found, min 2)"
+            )
+
+        passed = len(warnings) == 0
+        return PairReport(index=index, passed=passed, warnings=warnings, scores=scores)
+
+    # -------------------------------------------------------------------
+    # Batch-level validation (cross-pair checks)
+    # -------------------------------------------------------------------
+
+    def _check_prompt_duplicates(self, pairs: List[Dict[str, Any]]) -> Dict[int, str]:
+        """Find near-duplicate prompts within the batch.
+
+        Returns dict mapping pair index → warning string for duplicates.
+        """
+        prompt_tokens = []
+        for pair in pairs:
+            prompt_tokens.append(self._tokenize(pair.get("prompt", "")))
+
+        dupe_warnings: Dict[int, str] = {}
+        seen_groups: List[Set[int]] = []
+
+        for i in range(len(prompt_tokens)):
+            # Skip if already in a dupe group
+            if any(i in g for g in seen_groups):
+                continue
+            group = {i}
+            for j in range(i + 1, len(prompt_tokens)):
+                sim = self._jaccard(prompt_tokens[i], prompt_tokens[j])
+                if sim > self.cfg["max_prompt_prompt_similarity"]:
+                    group.add(j)
+                    dupe_warnings[j] = (
+                        f"near-duplicate prompt (Jaccard {sim:.2f} with pair {i})"
+                    )
+            if len(group) > 1:
+                seen_groups.append(group)
+
+        return dupe_warnings
+
+    def _load_history_hashes(self) -> Set[str]:
+        """Load prompt hashes from recent JSONL files for cross-run dedup."""
+        if self._history_hashes is not None:
+            return self._history_hashes
+
+        hashes = set()
+        if not self.output_dir.exists():
+            self._history_hashes = hashes
+            return hashes
+
+        jsonl_files = sorted(
+            self.output_dir.glob("deepdive_*.jsonl"),
+            key=lambda p: p.stat().st_mtime,
+            reverse=True,
+        )
+
+        files_to_scan = jsonl_files[:self.cfg["dedup_history_files"]]
+        for path in files_to_scan:
+            try:
+                with open(path) as f:
+                    for line in f:
+                        line = line.strip()
+                        if not line:
+                            continue
+                        pair = json.loads(line)
+                        prompt_hash = self._content_hash(pair.get("prompt", ""))
+                        hashes.add(prompt_hash)
+            except Exception as e:
+                logger.warning(f"Failed to read history file {path}: {e}")
+
+        logger.info(f"Loaded {len(hashes)} prompt hashes from {len(files_to_scan)} history files")
+        self._history_hashes = hashes
+        return hashes
+
+    def _check_cross_run_dupes(self, pairs: List[Dict[str, Any]]) -> Dict[int, str]:
+        """Check if any pair prompts already exist in recent history.
+
+        Returns dict mapping pair index → warning string for duplicates.
+        """
+        history = self._load_history_hashes()
+        if not history:
+            return {}
+
+        dupe_warnings: Dict[int, str] = {}
+        for i, pair in enumerate(pairs):
+            prompt_hash = self._content_hash(pair.get("prompt", ""))
+            if prompt_hash in history:
+                dupe_warnings[i] = "cross-run duplicate (prompt seen in recent history)"
+
+        return dupe_warnings
+
+    # -------------------------------------------------------------------
+    # Main validation entry point
+    # -------------------------------------------------------------------
+
+    def validate(self, pairs: List[Dict[str, Any]]) -> tuple:
+        """Validate a batch of DPO pairs.
+
+        Args:
+            pairs: List of pair dicts with {prompt, chosen, rejected, ...}
+
+        Returns:
+            (filtered_pairs, report): Tuple of filtered pair list and BatchReport.
+            If flagged_pair_action="drop", filtered_pairs excludes bad pairs.
+            If flagged_pair_action="flag", all pairs are returned with safety_flags updated.
+        """
+        if not pairs:
+            report = BatchReport(
+                total_pairs=0, passed_pairs=0, dropped_pairs=0,
+                flagged_pairs=0, duplicate_prompts_found=0,
+                cross_run_duplicates_found=0,
+                warnings=["Empty pair batch"],
+            )
+            return [], report
+
+        action = self.cfg["flagged_pair_action"]
+        pair_dicts = [p if isinstance(p, dict) else p.to_dict() for p in pairs]
+
+        # Single-pair checks
+        pair_reports = []
+        for i, pair in enumerate(pair_dicts):
+            report = self._validate_pair(pair, i)
+            pair_reports.append(report)
+
+        # Cross-pair checks: prompt diversity
+        prompt_dupe_warnings = self._check_prompt_duplicates(pair_dicts)
+        for idx, warning in prompt_dupe_warnings.items():
+            pair_reports[idx].warnings.append(warning)
+            pair_reports[idx].passed = False
+
+        # Cross-run dedup
+        crossrun_dupe_warnings = self._check_cross_run_dupes(pair_dicts)
+        for idx, warning in crossrun_dupe_warnings.items():
+            pair_reports[idx].warnings.append(warning)
+            pair_reports[idx].passed = False
+
+        # Build filtered output
+        filtered = []
+        dropped = 0
+        flagged = 0
+
+        for i, (pair, report) in enumerate(zip(pair_dicts, pair_reports)):
+            if report.passed:
+                filtered.append(pair)
+            elif action == "drop":
+                dropped += 1
+                logger.debug(f"Dropping pair {i}: {report.warnings}")
+            else:  # "flag"
+                # Add warnings to safety_flags
+                flags = pair.get("safety_flags", [])
+                flags.append("quality-flagged")
+                for w in report.warnings:
+                    flags.append(f"qv:{w[:60]}")
+                pair["safety_flags"] = flags
+                filtered.append(pair)
+                flagged += 1
+
+        passed = sum(1 for r in pair_reports if r.passed)
+
+        batch_warnings = []
+        if passed == 0 and len(pairs) > 0:
+            batch_warnings.append("ALL pairs failed validation — no training data produced")
+        if len(prompt_dupe_warnings) > len(pairs) * 0.5:
+            batch_warnings.append(
+                f"High prompt duplication: {len(prompt_dupe_warnings)}/{len(pairs)} pairs are near-duplicates"
+            )
+
+        # Task type diversity check
+        task_types = Counter(p.get("task_type", "unknown") for p in filtered)
+        if len(task_types) == 1 and len(filtered) > 3:
+            batch_warnings.append(
+                f"Low task-type diversity: all {len(filtered)} pairs are '{list(task_types.keys())[0]}'"
+            )
+
+        batch_report = BatchReport(
+            total_pairs=len(pairs),
+            passed_pairs=passed,
+            dropped_pairs=dropped,
+            flagged_pairs=flagged,
+            duplicate_prompts_found=len(prompt_dupe_warnings),
+            cross_run_duplicates_found=len(crossrun_dupe_warnings),
+            pair_reports=pair_reports,
+            warnings=batch_warnings,
+        )
+
+        logger.info(batch_report.summary())
+        return filtered, batch_report
+
+
+# ---------------------------------------------------------------------------
+# CLI for standalone validation of existing JSONL files
+# ---------------------------------------------------------------------------
+
+def main():
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Validate DPO pair quality")
+    parser.add_argument("jsonl_file", type=Path, help="Path to JSONL file with DPO pairs")
+    parser.add_argument("--json", action="store_true", help="Output JSON report")
+    parser.add_argument("--strict", action="store_true",
+                        help="Drop flagged pairs (default: flag only)")
+    args = parser.parse_args()
+
+    if not args.jsonl_file.exists():
+        print(f"Error: file not found: {args.jsonl_file}")
+        return 1
+
+    pairs = []
+    with open(args.jsonl_file) as f:
+        for line in f:
+            line = line.strip()
+            if line:
+                pairs.append(json.loads(line))
+
+    config = {}
+    if args.strict:
+        config["flagged_pair_action"] = "drop"
+    else:
+        config["flagged_pair_action"] = "flag"
+
+    # Use parent dir of input file as output_dir for history scanning
+    output_dir = args.jsonl_file.parent
+    validator = DPOQualityValidator(config=config, output_dir=output_dir)
+    filtered, report = validator.validate(pairs)
+
+    if args.json:
+        print(json.dumps(report.to_dict(), indent=2))
+    else:
+        print("=" * 60)
+        print(" DPO PAIR QUALITY VALIDATION REPORT")
+        print("=" * 60)
+        print(report.summary())
+        print("-" * 60)
+        for pr in report.pair_reports:
+            status = "✓" if pr.passed else "✗"
+            print(f"  [{status}] Pair {pr.index}: ", end="")
+            if pr.passed:
+                print("OK")
+            else:
+                print(", ".join(pr.warnings))
+        print("=" * 60)
+        print(f"\nFiltered output: {len(filtered)} pairs "
+              f"({'strict/drop' if args.strict else 'flag'} mode)")
+
+    return 0 if report.passed_pairs > 0 else 2
+
+
+if __name__ == "__main__":
+    exit(main())
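
Example usage (illustrative sketch, not part of the patch): calling the validator directly on a hand-written pair dict, using only the API the new module defines, namely DPOQualityValidator(config=..., output_dir=...), validate() returning a filtered list plus a BatchReport, and BatchReport.summary(). The pair contents and prompt text are placeholders; in the pipeline the dicts come from DPOPairGenerator.generate() via to_dict(), and output_dir defaults to the same directory the generator exports to.

    from pathlib import Path

    from dpo_quality import DPOQualityValidator

    # Placeholder pair; real pairs are produced by DPOPairGenerator.generate().
    pairs = [
        {
            "prompt": "Summarize this paper in terms of fleet relevance and deployment impact ...",
            "chosen": "A fleet-grounded summary: relevance score, deployment implications, recommended actions ...",
            "rejected": "A short generic summary.",
            "task_type": "summarize",
            "safety_flags": [],
        },
    ]

    validator = DPOQualityValidator(
        config={"flagged_pair_action": "flag"},                          # keep failing pairs, mark them in safety_flags
        output_dir=Path.home() / ".timmy" / "training-data" / "dpo-pairs",  # history scanned for cross-run dedup
    )
    filtered, report = validator.validate(pairs)

    print(report.summary())            # pass/drop/flag counts plus batch-level warnings
    for pr in report.pair_reports:     # per-pair scores and warnings
        print(pr.index, pr.passed, pr.scores, pr.warnings)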