Files
the-nexus/intelligence/deepdive/dpo_quality.py
perplexity 4b15cf8283
Some checks failed
CI / test (pull_request) Failing after 16s
CI / validate (pull_request) Failing after 14s
Review Approval Gate / verify-review (pull_request) Failing after 3s
feat: full-history persistent dedup index for DPO training pairs
Replace the 5-file sliding window cross-run dedup with a persistent
hash index that covers ALL historical training data. Overfitting risk
compounds across the full dataset — a 5-file window lets old duplicates
leak back into training after enough overnight runs.

New module: dedup_index.py (DedupIndex)
- Persistent JSON index (.dpo_dedup_index.json) alongside JSONL files
- Append-on-export: new prompt hashes registered after each successful
  export — no full rescan needed for normal operations
- Incremental sync: on load, detects JSONL files not yet indexed and
  ingests them automatically (handles files from other tools)
- Full rebuild: rebuild() scans ALL deepdive_*.jsonl + pairs_*.jsonl
  to reconstruct from scratch (first run, corruption recovery)
- Atomic writes (write-to-tmp + rename) to prevent index corruption
- Standalone CLI: python3 dedup_index.py <dir> --rebuild --stats

Modified: dpo_quality.py
- Imports DedupIndex with graceful degradation
- Replaces _load_history_hashes() with persistent index lookup
- Fallback: if index unavailable, scans ALL files in-memory (not just 5)
- New register_exported_hashes() method called after export
- Config key: dedup_full_history (replaces dedup_history_files)

Modified: dpo_generator.py
- Calls validator.register_exported_hashes() after successful export
  to keep the persistent index current without rescanning

Modified: config.yaml
- Replaced dedup_history_files: 5 with dedup_full_history: true

Tested — 7 integration tests:
  ✓ Fresh index build from empty directory
  ✓ Build from 3 existing JSONL files (15 unique hashes)
  ✓ Incremental sync when new file appears between runs
  ✓ Append after export + persistence across reloads
  ✓ Rebuild from scratch (recovers from corruption)
  ✓ Validator catches day-1 dupe from 20-day history (5-file window miss)
  ✓ Full pipeline: generate → validate → export → register → re-run detects
2026-04-13 03:11:10 +00:00

534 lines
20 KiB
Python

#!/usr/bin/env python3
"""DPO Pair Quality Validator — Gate before overnight training.
Catches bad training pairs before they enter the tightening loop:
1. Near-duplicate chosen/rejected (low contrast) — model learns nothing
2. Near-duplicate prompts across pairs (low diversity) — wasted compute
3. Too-short or empty fields — malformed pairs
4. Chosen not meaningfully richer than rejected — inverted signal
5. Cross-run deduplication — don't retrain on yesterday's pairs
Sits between DPOPairGenerator.generate() and .export().
Pairs that fail validation get flagged, not silently dropped —
the generator decides whether to export flagged pairs or filter them.
Usage standalone:
python3 dpo_quality.py ~/.timmy/training-data/dpo-pairs/deepdive_20260413.jsonl
"""
import hashlib
import json
import logging
import re
from collections import Counter
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Set
# Persistent dedup index
try:
from dedup_index import DedupIndex
HAS_DEDUP_INDEX = True
except ImportError:
HAS_DEDUP_INDEX = False
DedupIndex = None
logger = logging.getLogger("deepdive.dpo_quality")
# ---------------------------------------------------------------------------
# Configuration defaults (overridable via config dict)
# ---------------------------------------------------------------------------
DEFAULT_CONFIG = {
# Minimum character lengths
"min_prompt_chars": 40,
"min_chosen_chars": 80,
"min_rejected_chars": 30,
# Chosen must be at least this ratio longer than rejected
"min_chosen_rejected_ratio": 1.3,
# Jaccard similarity thresholds (word-level)
"max_chosen_rejected_similarity": 0.70, # Flag if chosen ≈ rejected
"max_prompt_prompt_similarity": 0.85, # Flag if two prompts are near-dupes
# Cross-run dedup: full-history persistent index
# (replaces the old sliding-window approach)
"dedup_full_history": True,
# What to do with flagged pairs: "drop" or "flag"
# "drop" = remove from export entirely
# "flag" = add warning to safety_flags but still export
"flagged_pair_action": "drop",
}
# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------
@dataclass
class PairReport:
"""Validation result for a single DPO pair."""
index: int
passed: bool
warnings: List[str] = field(default_factory=list)
scores: Dict[str, float] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
return asdict(self)
@dataclass
class BatchReport:
"""Validation result for an entire batch of DPO pairs."""
total_pairs: int
passed_pairs: int
dropped_pairs: int
flagged_pairs: int
duplicate_prompts_found: int
cross_run_duplicates_found: int
pair_reports: List[PairReport] = field(default_factory=list)
warnings: List[str] = field(default_factory=list)
@property
def pass_rate(self) -> float:
return self.passed_pairs / max(self.total_pairs, 1)
def to_dict(self) -> Dict[str, Any]:
d = asdict(self)
d["pass_rate"] = round(self.pass_rate, 3)
return d
def summary(self) -> str:
lines = [
f"DPO Quality: {self.passed_pairs}/{self.total_pairs} passed "
f"({self.pass_rate:.0%})",
f" Dropped: {self.dropped_pairs}, Flagged: {self.flagged_pairs}",
]
if self.duplicate_prompts_found:
lines.append(f" Duplicate prompts: {self.duplicate_prompts_found}")
if self.cross_run_duplicates_found:
lines.append(f" Cross-run dupes: {self.cross_run_duplicates_found}")
if self.warnings:
for w in self.warnings:
lines.append(f"{w}")
return "\n".join(lines)
# ---------------------------------------------------------------------------
# Core validator
# ---------------------------------------------------------------------------
class DPOQualityValidator:
"""Validate DPO pairs for quality before overnight training export.
Call validate() with a list of pair dicts to get a BatchReport
and a filtered list of pairs that passed validation.
"""
def __init__(self, config: Optional[Dict[str, Any]] = None,
output_dir: Optional[Path] = None):
self.cfg = {**DEFAULT_CONFIG, **(config or {})}
self.output_dir = Path(output_dir) if output_dir else Path.home() / ".timmy" / "training-data" / "dpo-pairs"
# Persistent full-history dedup index
self._dedup_index = None
if HAS_DEDUP_INDEX and self.cfg.get("dedup_full_history", True):
try:
self._dedup_index = DedupIndex(self.output_dir)
logger.info(
f"Full-history dedup index: {self._dedup_index.size} prompts, "
f"{self._dedup_index.files_indexed} files"
)
except Exception as e:
logger.warning(f"Failed to load dedup index, falling back to in-memory: {e}")
self._dedup_index = None
# Fallback: in-memory hash cache (used if index unavailable)
self._history_hashes: Optional[Set[str]] = None
logger.info(
f"DPOQualityValidator: action={self.cfg['flagged_pair_action']}, "
f"max_cr_sim={self.cfg['max_chosen_rejected_similarity']}, "
f"max_pp_sim={self.cfg['max_prompt_prompt_similarity']}, "
f"dedup={'full-history index' if self._dedup_index else 'in-memory fallback'}"
)
# -------------------------------------------------------------------
# Text analysis helpers
# -------------------------------------------------------------------
@staticmethod
def _tokenize(text: str) -> List[str]:
"""Simple whitespace + punctuation tokenizer."""
return re.findall(r'\b\w+\b', text.lower())
@staticmethod
def _jaccard(tokens_a: List[str], tokens_b: List[str]) -> float:
"""Word-level Jaccard similarity."""
set_a = set(tokens_a)
set_b = set(tokens_b)
if not set_a and not set_b:
return 1.0
if not set_a or not set_b:
return 0.0
return len(set_a & set_b) / len(set_a | set_b)
@staticmethod
def _content_hash(text: str) -> str:
"""Stable hash of normalized text for deduplication."""
normalized = " ".join(text.lower().split())
return hashlib.sha256(normalized.encode()).hexdigest()[:16]
@staticmethod
def _unique_word_ratio(text: str) -> float:
"""Ratio of unique words to total words (vocabulary diversity)."""
words = re.findall(r'\b\w+\b', text.lower())
if not words:
return 0.0
return len(set(words)) / len(words)
# -------------------------------------------------------------------
# Single-pair validation
# -------------------------------------------------------------------
def _validate_pair(self, pair: Dict[str, Any], index: int) -> PairReport:
"""Run all quality checks on a single pair."""
warnings = []
scores = {}
prompt = pair.get("prompt", "")
chosen = pair.get("chosen", "")
rejected = pair.get("rejected", "")
# --- Check 1: Field lengths ---
if len(prompt) < self.cfg["min_prompt_chars"]:
warnings.append(
f"prompt too short ({len(prompt)} chars, min {self.cfg['min_prompt_chars']})"
)
if len(chosen) < self.cfg["min_chosen_chars"]:
warnings.append(
f"chosen too short ({len(chosen)} chars, min {self.cfg['min_chosen_chars']})"
)
if len(rejected) < self.cfg["min_rejected_chars"]:
warnings.append(
f"rejected too short ({len(rejected)} chars, min {self.cfg['min_rejected_chars']})"
)
# --- Check 2: Chosen-Rejected length ratio ---
if len(rejected) > 0:
ratio = len(chosen) / len(rejected)
scores["chosen_rejected_ratio"] = round(ratio, 2)
if ratio < self.cfg["min_chosen_rejected_ratio"]:
warnings.append(
f"chosen/rejected ratio too low ({ratio:.2f}, "
f"min {self.cfg['min_chosen_rejected_ratio']})"
)
else:
scores["chosen_rejected_ratio"] = 0.0
warnings.append("rejected is empty")
# --- Check 3: Chosen-Rejected content similarity ---
chosen_tokens = self._tokenize(chosen)
rejected_tokens = self._tokenize(rejected)
cr_sim = self._jaccard(chosen_tokens, rejected_tokens)
scores["chosen_rejected_similarity"] = round(cr_sim, 3)
if cr_sim > self.cfg["max_chosen_rejected_similarity"]:
warnings.append(
f"chosen≈rejected (Jaccard {cr_sim:.2f}, "
f"max {self.cfg['max_chosen_rejected_similarity']})"
)
# --- Check 4: Vocabulary diversity in chosen ---
chosen_diversity = self._unique_word_ratio(chosen)
scores["chosen_vocab_diversity"] = round(chosen_diversity, 3)
if chosen_diversity < 0.3:
warnings.append(
f"low vocabulary diversity in chosen ({chosen_diversity:.2f})"
)
# --- Check 5: Chosen should contain substantive content markers ---
chosen_lower = chosen.lower()
substance_markers = [
"relevance", "implication", "training", "agent", "fleet",
"hermes", "deploy", "architecture", "pipeline", "score",
"technique", "approach", "recommend", "review", "action",
]
marker_hits = sum(1 for m in substance_markers if m in chosen_lower)
scores["substance_markers"] = marker_hits
if marker_hits < 2:
warnings.append(
f"chosen lacks substance markers ({marker_hits} found, min 2)"
)
passed = len(warnings) == 0
return PairReport(index=index, passed=passed, warnings=warnings, scores=scores)
# -------------------------------------------------------------------
# Batch-level validation (cross-pair checks)
# -------------------------------------------------------------------
def _check_prompt_duplicates(self, pairs: List[Dict[str, Any]]) -> Dict[int, str]:
"""Find near-duplicate prompts within the batch.
Returns dict mapping pair index → warning string for duplicates.
"""
prompt_tokens = []
for pair in pairs:
prompt_tokens.append(self._tokenize(pair.get("prompt", "")))
dupe_warnings: Dict[int, str] = {}
seen_groups: List[Set[int]] = []
for i in range(len(prompt_tokens)):
# Skip if already in a dupe group
if any(i in g for g in seen_groups):
continue
group = {i}
for j in range(i + 1, len(prompt_tokens)):
sim = self._jaccard(prompt_tokens[i], prompt_tokens[j])
if sim > self.cfg["max_prompt_prompt_similarity"]:
group.add(j)
dupe_warnings[j] = (
f"near-duplicate prompt (Jaccard {sim:.2f} with pair {i})"
)
if len(group) > 1:
seen_groups.append(group)
return dupe_warnings
def _check_cross_run_dupes(self, pairs: List[Dict[str, Any]]) -> Dict[int, str]:
"""Check if any pair prompts exist in full training history.
Uses persistent DedupIndex when available (covers all historical
JSONL files). Falls back to in-memory scan of ALL files if index
module is unavailable.
Returns dict mapping pair index → warning string for duplicates.
"""
dupe_warnings: Dict[int, str] = {}
if self._dedup_index:
# Full-history lookup via persistent index
for i, pair in enumerate(pairs):
prompt_hash = self._content_hash(pair.get("prompt", ""))
if self._dedup_index.contains(prompt_hash):
dupe_warnings[i] = (
f"cross-run duplicate (prompt seen in full history — "
f"{self._dedup_index.size} indexed prompts)"
)
return dupe_warnings
# Fallback: scan all JSONL files in output_dir (no sliding window)
if self._history_hashes is None:
self._history_hashes = set()
if self.output_dir.exists():
jsonl_files = sorted(self.output_dir.glob("deepdive_*.jsonl"))
jsonl_files.extend(sorted(self.output_dir.glob("pairs_*.jsonl")))
for path in jsonl_files:
try:
with open(path) as f:
for line in f:
line = line.strip()
if not line:
continue
pair_data = json.loads(line)
h = self._content_hash(pair_data.get("prompt", ""))
self._history_hashes.add(h)
except Exception as e:
logger.warning(f"Failed to read history file {path}: {e}")
logger.info(
f"Fallback dedup: loaded {len(self._history_hashes)} hashes "
f"from {len(jsonl_files)} files"
)
for i, pair in enumerate(pairs):
prompt_hash = self._content_hash(pair.get("prompt", ""))
if prompt_hash in self._history_hashes:
dupe_warnings[i] = "cross-run duplicate (prompt seen in full history)"
return dupe_warnings
def register_exported_hashes(self, pairs: List[Dict[str, Any]],
filename: str) -> None:
"""After successful export, register new prompt hashes in the index.
Called by DPOPairGenerator after writing the JSONL file.
"""
hashes = [self._content_hash(p.get("prompt", "")) for p in pairs]
if self._dedup_index:
added = self._dedup_index.add_hashes_and_register(hashes, filename)
logger.info(
f"Registered {added} new hashes in dedup index "
f"(total: {self._dedup_index.size})"
)
else:
# Update in-memory fallback
if self._history_hashes is None:
self._history_hashes = set()
self._history_hashes.update(hashes)
# -------------------------------------------------------------------
# Main validation entry point
# -------------------------------------------------------------------
def validate(self, pairs: List[Dict[str, Any]]) -> tuple:
"""Validate a batch of DPO pairs.
Args:
pairs: List of pair dicts with {prompt, chosen, rejected, ...}
Returns:
(filtered_pairs, report): Tuple of filtered pair list and BatchReport.
If flagged_pair_action="drop", filtered_pairs excludes bad pairs.
If flagged_pair_action="flag", all pairs are returned with safety_flags updated.
"""
if not pairs:
report = BatchReport(
total_pairs=0, passed_pairs=0, dropped_pairs=0,
flagged_pairs=0, duplicate_prompts_found=0,
cross_run_duplicates_found=0,
warnings=["Empty pair batch"],
)
return [], report
action = self.cfg["flagged_pair_action"]
pair_dicts = [p if isinstance(p, dict) else p.to_dict() for p in pairs]
# Single-pair checks
pair_reports = []
for i, pair in enumerate(pair_dicts):
report = self._validate_pair(pair, i)
pair_reports.append(report)
# Cross-pair checks: prompt diversity
prompt_dupe_warnings = self._check_prompt_duplicates(pair_dicts)
for idx, warning in prompt_dupe_warnings.items():
pair_reports[idx].warnings.append(warning)
pair_reports[idx].passed = False
# Cross-run dedup
crossrun_dupe_warnings = self._check_cross_run_dupes(pair_dicts)
for idx, warning in crossrun_dupe_warnings.items():
pair_reports[idx].warnings.append(warning)
pair_reports[idx].passed = False
# Build filtered output
filtered = []
dropped = 0
flagged = 0
for i, (pair, report) in enumerate(zip(pair_dicts, pair_reports)):
if report.passed:
filtered.append(pair)
elif action == "drop":
dropped += 1
logger.debug(f"Dropping pair {i}: {report.warnings}")
else: # "flag"
# Add warnings to safety_flags
flags = pair.get("safety_flags", [])
flags.append("quality-flagged")
for w in report.warnings:
flags.append(f"qv:{w[:60]}")
pair["safety_flags"] = flags
filtered.append(pair)
flagged += 1
passed = sum(1 for r in pair_reports if r.passed)
batch_warnings = []
if passed == 0 and len(pairs) > 0:
batch_warnings.append("ALL pairs failed validation — no training data produced")
if len(prompt_dupe_warnings) > len(pairs) * 0.5:
batch_warnings.append(
f"High prompt duplication: {len(prompt_dupe_warnings)}/{len(pairs)} pairs are near-duplicates"
)
# Task type diversity check
task_types = Counter(p.get("task_type", "unknown") for p in filtered)
if len(task_types) == 1 and len(filtered) > 3:
batch_warnings.append(
f"Low task-type diversity: all {len(filtered)} pairs are '{list(task_types.keys())[0]}'"
)
batch_report = BatchReport(
total_pairs=len(pairs),
passed_pairs=passed,
dropped_pairs=dropped,
flagged_pairs=flagged,
duplicate_prompts_found=len(prompt_dupe_warnings),
cross_run_duplicates_found=len(crossrun_dupe_warnings),
pair_reports=pair_reports,
warnings=batch_warnings,
)
logger.info(batch_report.summary())
return filtered, batch_report
# ---------------------------------------------------------------------------
# CLI for standalone validation of existing JSONL files
# ---------------------------------------------------------------------------
def main():
import argparse
parser = argparse.ArgumentParser(description="Validate DPO pair quality")
parser.add_argument("jsonl_file", type=Path, help="Path to JSONL file with DPO pairs")
parser.add_argument("--json", action="store_true", help="Output JSON report")
parser.add_argument("--strict", action="store_true",
help="Drop flagged pairs (default: flag only)")
args = parser.parse_args()
if not args.jsonl_file.exists():
print(f"Error: file not found: {args.jsonl_file}")
return 1
pairs = []
with open(args.jsonl_file) as f:
for line in f:
line = line.strip()
if line:
pairs.append(json.loads(line))
config = {}
if args.strict:
config["flagged_pair_action"] = "drop"
else:
config["flagged_pair_action"] = "flag"
# Use parent dir of input file as output_dir for history scanning
output_dir = args.jsonl_file.parent
validator = DPOQualityValidator(config=config, output_dir=output_dir)
filtered, report = validator.validate(pairs)
if args.json:
print(json.dumps(report.to_dict(), indent=2))
else:
print("=" * 60)
print(" DPO PAIR QUALITY VALIDATION REPORT")
print("=" * 60)
print(report.summary())
print("-" * 60)
for pr in report.pair_reports:
status = "" if pr.passed else ""
print(f" [{status}] Pair {pr.index}: ", end="")
if pr.passed:
print("OK")
else:
print(", ".join(pr.warnings))
print("=" * 60)
print(f"\nFiltered output: {len(filtered)} pairs "
f"({'strict/drop' if args.strict else 'flag'} mode)")
return 0 if report.passed_pairs > 0 else 2
if __name__ == "__main__":
exit(main())