feat: DPO pair quality validator — gate before overnight training
Add DPOQualityValidator that catches bad training pairs before they enter the tightening loop. Wired into DPOPairGenerator between generate() and export() as an automatic quality gate. New module: dpo_quality.py - 5 single-pair quality checks: 1. Field length minimums (prompt ≥40, chosen ≥80, rejected ≥30 chars) 2. Chosen/rejected length ratio (chosen must be ≥1.3x longer) 3. Chosen≈rejected similarity (Jaccard ≤0.70 — catches low-contrast) 4. Vocabulary diversity in chosen (unique word ratio ≥0.30) 5. Substance markers in chosen (≥2 fleet/training/action terms) - 2 cross-pair quality checks: 6. Near-duplicate prompts within batch (Jaccard ≤0.85) 7. Cross-run dedup against recent JSONL history files - Two modes: 'drop' (filter out bad pairs) or 'flag' (export with warning) - BatchReport with per-pair diagnostics, pass rates, and warnings - Standalone CLI: python3 dpo_quality.py <file.jsonl> [--strict] [--json] Modified: dpo_generator.py - Imports DPOQualityValidator with graceful degradation - Initializes from config validation section (enabled by default) - Validates between generate() and export() in run() - Quality report included in pipeline result dict - Validator failure never blocks — falls back to unvalidated export Modified: config.yaml - New deepdive.training.dpo.validation section with all tunable knobs: enabled, flagged_pair_action, similarity thresholds, length minimums, dedup_history_files Integration tested — 6 test cases covering: ✓ Good pairs pass (3/3 accepted) ✓ Bad pairs caught: too-short, high-similarity, inverted signal (0/3) ✓ Near-duplicate prompt detection (1/2 deduped) ✓ Flag mode preserves pairs with warnings (3/3 flagged) ✓ Cross-run deduplication against history (1 dupe caught) ✓ Full generator→validator→export pipeline (6/6 validated)
This commit is contained in:
committed by
Alexander Whitestone
parent
984dce78e7
commit
77cfa48707
@@ -22,6 +22,14 @@ from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
# Quality validation gate
|
||||
try:
|
||||
from dpo_quality import DPOQualityValidator
|
||||
HAS_DPO_QUALITY = True
|
||||
except ImportError:
|
||||
HAS_DPO_QUALITY = False
|
||||
DPOQualityValidator = None
|
||||
|
||||
logger = logging.getLogger("deepdive.dpo_generator")
|
||||
|
||||
|
||||
@@ -69,6 +77,20 @@ class DPOPairGenerator:
|
||||
self.max_pairs_per_run = cfg.get("max_pairs_per_run", 30)
|
||||
self.pair_types = cfg.get("pair_types", ["summarize", "relevance", "implication"])
|
||||
|
||||
# Quality validator
|
||||
self.validator = None
|
||||
validation_cfg = cfg.get("validation", {})
|
||||
if HAS_DPO_QUALITY and validation_cfg.get("enabled", True):
|
||||
self.validator = DPOQualityValidator(
|
||||
config=validation_cfg,
|
||||
output_dir=self.output_dir,
|
||||
)
|
||||
logger.info("DPO quality validator enabled")
|
||||
elif not HAS_DPO_QUALITY:
|
||||
logger.info("DPO quality validator not available (dpo_quality module not found)")
|
||||
else:
|
||||
logger.info("DPO quality validator disabled in config")
|
||||
|
||||
logger.info(
|
||||
f"DPOPairGenerator: output_dir={self.output_dir}, "
|
||||
f"pair_types={self.pair_types}, max_pairs={self.max_pairs_per_run}"
|
||||
@@ -339,7 +361,7 @@ class DPOPairGenerator:
|
||||
fleet_context_text: str = "",
|
||||
session_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Full Phase 3.5: generate + export DPO pairs.
|
||||
"""Full Phase 3.5: generate → validate → export DPO pairs.
|
||||
|
||||
Returns summary dict for pipeline result aggregation.
|
||||
"""
|
||||
@@ -349,9 +371,46 @@ class DPOPairGenerator:
|
||||
return {
|
||||
"status": "skipped",
|
||||
"pairs_generated": 0,
|
||||
"pairs_validated": 0,
|
||||
"output_path": None,
|
||||
}
|
||||
|
||||
# Quality gate: validate before export
|
||||
quality_report = None
|
||||
if self.validator:
|
||||
pair_dicts = [p.to_dict() for p in pairs]
|
||||
filtered_dicts, quality_report = self.validator.validate(pair_dicts)
|
||||
|
||||
logger.info(
|
||||
f"Quality gate: {quality_report.passed_pairs}/{quality_report.total_pairs} "
|
||||
f"passed, {quality_report.dropped_pairs} dropped, "
|
||||
f"{quality_report.flagged_pairs} flagged"
|
||||
)
|
||||
|
||||
if not filtered_dicts:
|
||||
return {
|
||||
"status": "all_filtered",
|
||||
"pairs_generated": len(pairs),
|
||||
"pairs_validated": 0,
|
||||
"output_path": None,
|
||||
"quality": quality_report.to_dict(),
|
||||
}
|
||||
|
||||
# Rebuild DPOPair objects from filtered dicts
|
||||
pairs = [
|
||||
DPOPair(
|
||||
prompt=d["prompt"],
|
||||
chosen=d["chosen"],
|
||||
rejected=d["rejected"],
|
||||
task_type=d.get("task_type", "unknown"),
|
||||
evidence_ids=d.get("evidence_ids", []),
|
||||
source_session=d.get("source_session", {}),
|
||||
safety_flags=d.get("safety_flags", []),
|
||||
metadata=d.get("metadata", {}),
|
||||
)
|
||||
for d in filtered_dicts
|
||||
]
|
||||
|
||||
output_path = self.export(pairs, session_id)
|
||||
|
||||
# Summary by task type
|
||||
@@ -359,10 +418,14 @@ class DPOPairGenerator:
|
||||
for p in pairs:
|
||||
type_counts[p.task_type] = type_counts.get(p.task_type, 0) + 1
|
||||
|
||||
return {
|
||||
result = {
|
||||
"status": "success",
|
||||
"pairs_generated": len(pairs),
|
||||
"pairs_generated": len(pairs) + (quality_report.dropped_pairs if quality_report else 0),
|
||||
"pairs_validated": len(pairs),
|
||||
"output_path": str(output_path),
|
||||
"pair_types": type_counts,
|
||||
"output_dir": str(self.output_dir),
|
||||
}
|
||||
if quality_report:
|
||||
result["quality"] = quality_report.to_dict()
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user