Compare commits
2 Commits
feature/de
...
feature/dp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bb4922adeb | ||
| c19000de03 |
@@ -99,6 +99,16 @@ deepdive:
|
||||
- "summarize" # Paper summary → fleet-grounded analysis
|
||||
- "relevance" # Relevance analysis → scored fleet context
|
||||
- "implication" # Implications → actionable insight
|
||||
validation:
|
||||
enabled: true
|
||||
flagged_pair_action: "drop" # "drop" = remove bad pairs, "flag" = export with warning
|
||||
min_prompt_chars: 40 # Minimum prompt length
|
||||
min_chosen_chars: 80 # Minimum chosen response length
|
||||
min_rejected_chars: 30 # Minimum rejected response length
|
||||
min_chosen_rejected_ratio: 1.3 # Chosen must be ≥1.3x longer than rejected
|
||||
max_chosen_rejected_similarity: 0.70 # Max Jaccard overlap between chosen/rejected
|
||||
max_prompt_prompt_similarity: 0.85 # Max Jaccard overlap between prompts (dedup)
|
||||
dedup_history_files: 5 # How many recent JSONL files to scan for cross-run dedup
|
||||
|
||||
# Phase 0: Fleet Context Grounding
|
||||
fleet_context:
|
||||
|
||||
@@ -22,6 +22,14 @@ from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
# Quality validation gate
|
||||
try:
|
||||
from dpo_quality import DPOQualityValidator
|
||||
HAS_DPO_QUALITY = True
|
||||
except ImportError:
|
||||
HAS_DPO_QUALITY = False
|
||||
DPOQualityValidator = None
|
||||
|
||||
logger = logging.getLogger("deepdive.dpo_generator")
|
||||
|
||||
|
||||
@@ -69,6 +77,20 @@ class DPOPairGenerator:
|
||||
self.max_pairs_per_run = cfg.get("max_pairs_per_run", 30)
|
||||
self.pair_types = cfg.get("pair_types", ["summarize", "relevance", "implication"])
|
||||
|
||||
# Quality validator
|
||||
self.validator = None
|
||||
validation_cfg = cfg.get("validation", {})
|
||||
if HAS_DPO_QUALITY and validation_cfg.get("enabled", True):
|
||||
self.validator = DPOQualityValidator(
|
||||
config=validation_cfg,
|
||||
output_dir=self.output_dir,
|
||||
)
|
||||
logger.info("DPO quality validator enabled")
|
||||
elif not HAS_DPO_QUALITY:
|
||||
logger.info("DPO quality validator not available (dpo_quality module not found)")
|
||||
else:
|
||||
logger.info("DPO quality validator disabled in config")
|
||||
|
||||
logger.info(
|
||||
f"DPOPairGenerator: output_dir={self.output_dir}, "
|
||||
f"pair_types={self.pair_types}, max_pairs={self.max_pairs_per_run}"
|
||||
@@ -339,7 +361,7 @@ class DPOPairGenerator:
|
||||
fleet_context_text: str = "",
|
||||
session_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Full Phase 3.5: generate + export DPO pairs.
|
||||
"""Full Phase 3.5: generate → validate → export DPO pairs.
|
||||
|
||||
Returns summary dict for pipeline result aggregation.
|
||||
"""
|
||||
@@ -349,9 +371,46 @@ class DPOPairGenerator:
|
||||
return {
|
||||
"status": "skipped",
|
||||
"pairs_generated": 0,
|
||||
"pairs_validated": 0,
|
||||
"output_path": None,
|
||||
}
|
||||
|
||||
# Quality gate: validate before export
|
||||
quality_report = None
|
||||
if self.validator:
|
||||
pair_dicts = [p.to_dict() for p in pairs]
|
||||
filtered_dicts, quality_report = self.validator.validate(pair_dicts)
|
||||
|
||||
logger.info(
|
||||
f"Quality gate: {quality_report.passed_pairs}/{quality_report.total_pairs} "
|
||||
f"passed, {quality_report.dropped_pairs} dropped, "
|
||||
f"{quality_report.flagged_pairs} flagged"
|
||||
)
|
||||
|
||||
if not filtered_dicts:
|
||||
return {
|
||||
"status": "all_filtered",
|
||||
"pairs_generated": len(pairs),
|
||||
"pairs_validated": 0,
|
||||
"output_path": None,
|
||||
"quality": quality_report.to_dict(),
|
||||
}
|
||||
|
||||
# Rebuild DPOPair objects from filtered dicts
|
||||
pairs = [
|
||||
DPOPair(
|
||||
prompt=d["prompt"],
|
||||
chosen=d["chosen"],
|
||||
rejected=d["rejected"],
|
||||
task_type=d.get("task_type", "unknown"),
|
||||
evidence_ids=d.get("evidence_ids", []),
|
||||
source_session=d.get("source_session", {}),
|
||||
safety_flags=d.get("safety_flags", []),
|
||||
metadata=d.get("metadata", {}),
|
||||
)
|
||||
for d in filtered_dicts
|
||||
]
|
||||
|
||||
output_path = self.export(pairs, session_id)
|
||||
|
||||
# Summary by task type
|
||||
@@ -359,10 +418,14 @@ class DPOPairGenerator:
|
||||
for p in pairs:
|
||||
type_counts[p.task_type] = type_counts.get(p.task_type, 0) + 1
|
||||
|
||||
return {
|
||||
result = {
|
||||
"status": "success",
|
||||
"pairs_generated": len(pairs),
|
||||
"pairs_generated": len(pairs) + (quality_report.dropped_pairs if quality_report else 0),
|
||||
"pairs_validated": len(pairs),
|
||||
"output_path": str(output_path),
|
||||
"pair_types": type_counts,
|
||||
"output_dir": str(self.output_dir),
|
||||
}
|
||||
if quality_report:
|
||||
result["quality"] = quality_report.to_dict()
|
||||
return result
|
||||
|
||||
489
intelligence/deepdive/dpo_quality.py
Normal file
489
intelligence/deepdive/dpo_quality.py
Normal file
@@ -0,0 +1,489 @@
|
||||
#!/usr/bin/env python3
|
||||
"""DPO Pair Quality Validator — Gate before overnight training.
|
||||
|
||||
Catches bad training pairs before they enter the tightening loop:
|
||||
|
||||
1. Near-duplicate chosen/rejected (low contrast) — model learns nothing
|
||||
2. Near-duplicate prompts across pairs (low diversity) — wasted compute
|
||||
3. Too-short or empty fields — malformed pairs
|
||||
4. Chosen not meaningfully richer than rejected — inverted signal
|
||||
5. Cross-run deduplication — don't retrain on yesterday's pairs
|
||||
|
||||
Sits between DPOPairGenerator.generate() and .export().
|
||||
Pairs that fail validation get flagged, not silently dropped —
|
||||
the generator decides whether to export flagged pairs or filter them.
|
||||
|
||||
Usage standalone:
|
||||
python3 dpo_quality.py ~/.timmy/training-data/dpo-pairs/deepdive_20260413.jsonl
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
logger = logging.getLogger("deepdive.dpo_quality")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration defaults (overridable via config dict)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
# Minimum character lengths
|
||||
"min_prompt_chars": 40,
|
||||
"min_chosen_chars": 80,
|
||||
"min_rejected_chars": 30,
|
||||
|
||||
# Chosen must be at least this ratio longer than rejected
|
||||
"min_chosen_rejected_ratio": 1.3,
|
||||
|
||||
# Jaccard similarity thresholds (word-level)
|
||||
"max_chosen_rejected_similarity": 0.70, # Flag if chosen ≈ rejected
|
||||
"max_prompt_prompt_similarity": 0.85, # Flag if two prompts are near-dupes
|
||||
|
||||
# Cross-run dedup: hash window (how many recent JSONL files to scan)
|
||||
"dedup_history_files": 5,
|
||||
|
||||
# What to do with flagged pairs: "drop" or "flag"
|
||||
# "drop" = remove from export entirely
|
||||
# "flag" = add warning to safety_flags but still export
|
||||
"flagged_pair_action": "drop",
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data structures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class PairReport:
|
||||
"""Validation result for a single DPO pair."""
|
||||
index: int
|
||||
passed: bool
|
||||
warnings: List[str] = field(default_factory=list)
|
||||
scores: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchReport:
|
||||
"""Validation result for an entire batch of DPO pairs."""
|
||||
total_pairs: int
|
||||
passed_pairs: int
|
||||
dropped_pairs: int
|
||||
flagged_pairs: int
|
||||
duplicate_prompts_found: int
|
||||
cross_run_duplicates_found: int
|
||||
pair_reports: List[PairReport] = field(default_factory=list)
|
||||
warnings: List[str] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def pass_rate(self) -> float:
|
||||
return self.passed_pairs / max(self.total_pairs, 1)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
d = asdict(self)
|
||||
d["pass_rate"] = round(self.pass_rate, 3)
|
||||
return d
|
||||
|
||||
def summary(self) -> str:
|
||||
lines = [
|
||||
f"DPO Quality: {self.passed_pairs}/{self.total_pairs} passed "
|
||||
f"({self.pass_rate:.0%})",
|
||||
f" Dropped: {self.dropped_pairs}, Flagged: {self.flagged_pairs}",
|
||||
]
|
||||
if self.duplicate_prompts_found:
|
||||
lines.append(f" Duplicate prompts: {self.duplicate_prompts_found}")
|
||||
if self.cross_run_duplicates_found:
|
||||
lines.append(f" Cross-run dupes: {self.cross_run_duplicates_found}")
|
||||
if self.warnings:
|
||||
for w in self.warnings:
|
||||
lines.append(f" ⚠ {w}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core validator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class DPOQualityValidator:
|
||||
"""Validate DPO pairs for quality before overnight training export.
|
||||
|
||||
Call validate() with a list of pair dicts to get a BatchReport
|
||||
and a filtered list of pairs that passed validation.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[Dict[str, Any]] = None,
|
||||
output_dir: Optional[Path] = None):
|
||||
self.cfg = {**DEFAULT_CONFIG, **(config or {})}
|
||||
self.output_dir = Path(output_dir) if output_dir else Path.home() / ".timmy" / "training-data" / "dpo-pairs"
|
||||
|
||||
# Cache of prompt hashes from previous runs (loaded lazily)
|
||||
self._history_hashes: Optional[Set[str]] = None
|
||||
|
||||
logger.info(
|
||||
f"DPOQualityValidator: action={self.cfg['flagged_pair_action']}, "
|
||||
f"max_cr_sim={self.cfg['max_chosen_rejected_similarity']}, "
|
||||
f"max_pp_sim={self.cfg['max_prompt_prompt_similarity']}"
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Text analysis helpers
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _tokenize(text: str) -> List[str]:
|
||||
"""Simple whitespace + punctuation tokenizer."""
|
||||
return re.findall(r'\b\w+\b', text.lower())
|
||||
|
||||
@staticmethod
|
||||
def _jaccard(tokens_a: List[str], tokens_b: List[str]) -> float:
|
||||
"""Word-level Jaccard similarity."""
|
||||
set_a = set(tokens_a)
|
||||
set_b = set(tokens_b)
|
||||
if not set_a and not set_b:
|
||||
return 1.0
|
||||
if not set_a or not set_b:
|
||||
return 0.0
|
||||
return len(set_a & set_b) / len(set_a | set_b)
|
||||
|
||||
@staticmethod
|
||||
def _content_hash(text: str) -> str:
|
||||
"""Stable hash of normalized text for deduplication."""
|
||||
normalized = " ".join(text.lower().split())
|
||||
return hashlib.sha256(normalized.encode()).hexdigest()[:16]
|
||||
|
||||
@staticmethod
|
||||
def _unique_word_ratio(text: str) -> float:
|
||||
"""Ratio of unique words to total words (vocabulary diversity)."""
|
||||
words = re.findall(r'\b\w+\b', text.lower())
|
||||
if not words:
|
||||
return 0.0
|
||||
return len(set(words)) / len(words)
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Single-pair validation
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
def _validate_pair(self, pair: Dict[str, Any], index: int) -> PairReport:
|
||||
"""Run all quality checks on a single pair."""
|
||||
warnings = []
|
||||
scores = {}
|
||||
|
||||
prompt = pair.get("prompt", "")
|
||||
chosen = pair.get("chosen", "")
|
||||
rejected = pair.get("rejected", "")
|
||||
|
||||
# --- Check 1: Field lengths ---
|
||||
if len(prompt) < self.cfg["min_prompt_chars"]:
|
||||
warnings.append(
|
||||
f"prompt too short ({len(prompt)} chars, min {self.cfg['min_prompt_chars']})"
|
||||
)
|
||||
if len(chosen) < self.cfg["min_chosen_chars"]:
|
||||
warnings.append(
|
||||
f"chosen too short ({len(chosen)} chars, min {self.cfg['min_chosen_chars']})"
|
||||
)
|
||||
if len(rejected) < self.cfg["min_rejected_chars"]:
|
||||
warnings.append(
|
||||
f"rejected too short ({len(rejected)} chars, min {self.cfg['min_rejected_chars']})"
|
||||
)
|
||||
|
||||
# --- Check 2: Chosen-Rejected length ratio ---
|
||||
if len(rejected) > 0:
|
||||
ratio = len(chosen) / len(rejected)
|
||||
scores["chosen_rejected_ratio"] = round(ratio, 2)
|
||||
if ratio < self.cfg["min_chosen_rejected_ratio"]:
|
||||
warnings.append(
|
||||
f"chosen/rejected ratio too low ({ratio:.2f}, "
|
||||
f"min {self.cfg['min_chosen_rejected_ratio']})"
|
||||
)
|
||||
else:
|
||||
scores["chosen_rejected_ratio"] = 0.0
|
||||
warnings.append("rejected is empty")
|
||||
|
||||
# --- Check 3: Chosen-Rejected content similarity ---
|
||||
chosen_tokens = self._tokenize(chosen)
|
||||
rejected_tokens = self._tokenize(rejected)
|
||||
cr_sim = self._jaccard(chosen_tokens, rejected_tokens)
|
||||
scores["chosen_rejected_similarity"] = round(cr_sim, 3)
|
||||
|
||||
if cr_sim > self.cfg["max_chosen_rejected_similarity"]:
|
||||
warnings.append(
|
||||
f"chosen≈rejected (Jaccard {cr_sim:.2f}, "
|
||||
f"max {self.cfg['max_chosen_rejected_similarity']})"
|
||||
)
|
||||
|
||||
# --- Check 4: Vocabulary diversity in chosen ---
|
||||
chosen_diversity = self._unique_word_ratio(chosen)
|
||||
scores["chosen_vocab_diversity"] = round(chosen_diversity, 3)
|
||||
if chosen_diversity < 0.3:
|
||||
warnings.append(
|
||||
f"low vocabulary diversity in chosen ({chosen_diversity:.2f})"
|
||||
)
|
||||
|
||||
# --- Check 5: Chosen should contain substantive content markers ---
|
||||
chosen_lower = chosen.lower()
|
||||
substance_markers = [
|
||||
"relevance", "implication", "training", "agent", "fleet",
|
||||
"hermes", "deploy", "architecture", "pipeline", "score",
|
||||
"technique", "approach", "recommend", "review", "action",
|
||||
]
|
||||
marker_hits = sum(1 for m in substance_markers if m in chosen_lower)
|
||||
scores["substance_markers"] = marker_hits
|
||||
if marker_hits < 2:
|
||||
warnings.append(
|
||||
f"chosen lacks substance markers ({marker_hits} found, min 2)"
|
||||
)
|
||||
|
||||
passed = len(warnings) == 0
|
||||
return PairReport(index=index, passed=passed, warnings=warnings, scores=scores)
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Batch-level validation (cross-pair checks)
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
def _check_prompt_duplicates(self, pairs: List[Dict[str, Any]]) -> Dict[int, str]:
|
||||
"""Find near-duplicate prompts within the batch.
|
||||
|
||||
Returns dict mapping pair index → warning string for duplicates.
|
||||
"""
|
||||
prompt_tokens = []
|
||||
for pair in pairs:
|
||||
prompt_tokens.append(self._tokenize(pair.get("prompt", "")))
|
||||
|
||||
dupe_warnings: Dict[int, str] = {}
|
||||
seen_groups: List[Set[int]] = []
|
||||
|
||||
for i in range(len(prompt_tokens)):
|
||||
# Skip if already in a dupe group
|
||||
if any(i in g for g in seen_groups):
|
||||
continue
|
||||
group = {i}
|
||||
for j in range(i + 1, len(prompt_tokens)):
|
||||
sim = self._jaccard(prompt_tokens[i], prompt_tokens[j])
|
||||
if sim > self.cfg["max_prompt_prompt_similarity"]:
|
||||
group.add(j)
|
||||
dupe_warnings[j] = (
|
||||
f"near-duplicate prompt (Jaccard {sim:.2f} with pair {i})"
|
||||
)
|
||||
if len(group) > 1:
|
||||
seen_groups.append(group)
|
||||
|
||||
return dupe_warnings
|
||||
|
||||
def _load_history_hashes(self) -> Set[str]:
|
||||
"""Load prompt hashes from recent JSONL files for cross-run dedup."""
|
||||
if self._history_hashes is not None:
|
||||
return self._history_hashes
|
||||
|
||||
hashes = set()
|
||||
if not self.output_dir.exists():
|
||||
self._history_hashes = hashes
|
||||
return hashes
|
||||
|
||||
jsonl_files = sorted(
|
||||
self.output_dir.glob("deepdive_*.jsonl"),
|
||||
key=lambda p: p.stat().st_mtime,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
files_to_scan = jsonl_files[:self.cfg["dedup_history_files"]]
|
||||
for path in files_to_scan:
|
||||
try:
|
||||
with open(path) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
pair = json.loads(line)
|
||||
prompt_hash = self._content_hash(pair.get("prompt", ""))
|
||||
hashes.add(prompt_hash)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read history file {path}: {e}")
|
||||
|
||||
logger.info(f"Loaded {len(hashes)} prompt hashes from {len(files_to_scan)} history files")
|
||||
self._history_hashes = hashes
|
||||
return hashes
|
||||
|
||||
def _check_cross_run_dupes(self, pairs: List[Dict[str, Any]]) -> Dict[int, str]:
|
||||
"""Check if any pair prompts already exist in recent history.
|
||||
|
||||
Returns dict mapping pair index → warning string for duplicates.
|
||||
"""
|
||||
history = self._load_history_hashes()
|
||||
if not history:
|
||||
return {}
|
||||
|
||||
dupe_warnings: Dict[int, str] = {}
|
||||
for i, pair in enumerate(pairs):
|
||||
prompt_hash = self._content_hash(pair.get("prompt", ""))
|
||||
if prompt_hash in history:
|
||||
dupe_warnings[i] = "cross-run duplicate (prompt seen in recent history)"
|
||||
|
||||
return dupe_warnings
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Main validation entry point
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
def validate(self, pairs: List[Dict[str, Any]]) -> tuple:
|
||||
"""Validate a batch of DPO pairs.
|
||||
|
||||
Args:
|
||||
pairs: List of pair dicts with {prompt, chosen, rejected, ...}
|
||||
|
||||
Returns:
|
||||
(filtered_pairs, report): Tuple of filtered pair list and BatchReport.
|
||||
If flagged_pair_action="drop", filtered_pairs excludes bad pairs.
|
||||
If flagged_pair_action="flag", all pairs are returned with safety_flags updated.
|
||||
"""
|
||||
if not pairs:
|
||||
report = BatchReport(
|
||||
total_pairs=0, passed_pairs=0, dropped_pairs=0,
|
||||
flagged_pairs=0, duplicate_prompts_found=0,
|
||||
cross_run_duplicates_found=0,
|
||||
warnings=["Empty pair batch"],
|
||||
)
|
||||
return [], report
|
||||
|
||||
action = self.cfg["flagged_pair_action"]
|
||||
pair_dicts = [p if isinstance(p, dict) else p.to_dict() for p in pairs]
|
||||
|
||||
# Single-pair checks
|
||||
pair_reports = []
|
||||
for i, pair in enumerate(pair_dicts):
|
||||
report = self._validate_pair(pair, i)
|
||||
pair_reports.append(report)
|
||||
|
||||
# Cross-pair checks: prompt diversity
|
||||
prompt_dupe_warnings = self._check_prompt_duplicates(pair_dicts)
|
||||
for idx, warning in prompt_dupe_warnings.items():
|
||||
pair_reports[idx].warnings.append(warning)
|
||||
pair_reports[idx].passed = False
|
||||
|
||||
# Cross-run dedup
|
||||
crossrun_dupe_warnings = self._check_cross_run_dupes(pair_dicts)
|
||||
for idx, warning in crossrun_dupe_warnings.items():
|
||||
pair_reports[idx].warnings.append(warning)
|
||||
pair_reports[idx].passed = False
|
||||
|
||||
# Build filtered output
|
||||
filtered = []
|
||||
dropped = 0
|
||||
flagged = 0
|
||||
|
||||
for i, (pair, report) in enumerate(zip(pair_dicts, pair_reports)):
|
||||
if report.passed:
|
||||
filtered.append(pair)
|
||||
elif action == "drop":
|
||||
dropped += 1
|
||||
logger.debug(f"Dropping pair {i}: {report.warnings}")
|
||||
else: # "flag"
|
||||
# Add warnings to safety_flags
|
||||
flags = pair.get("safety_flags", [])
|
||||
flags.append("quality-flagged")
|
||||
for w in report.warnings:
|
||||
flags.append(f"qv:{w[:60]}")
|
||||
pair["safety_flags"] = flags
|
||||
filtered.append(pair)
|
||||
flagged += 1
|
||||
|
||||
passed = sum(1 for r in pair_reports if r.passed)
|
||||
|
||||
batch_warnings = []
|
||||
if passed == 0 and len(pairs) > 0:
|
||||
batch_warnings.append("ALL pairs failed validation — no training data produced")
|
||||
if len(prompt_dupe_warnings) > len(pairs) * 0.5:
|
||||
batch_warnings.append(
|
||||
f"High prompt duplication: {len(prompt_dupe_warnings)}/{len(pairs)} pairs are near-duplicates"
|
||||
)
|
||||
|
||||
# Task type diversity check
|
||||
task_types = Counter(p.get("task_type", "unknown") for p in filtered)
|
||||
if len(task_types) == 1 and len(filtered) > 3:
|
||||
batch_warnings.append(
|
||||
f"Low task-type diversity: all {len(filtered)} pairs are '{list(task_types.keys())[0]}'"
|
||||
)
|
||||
|
||||
batch_report = BatchReport(
|
||||
total_pairs=len(pairs),
|
||||
passed_pairs=passed,
|
||||
dropped_pairs=dropped,
|
||||
flagged_pairs=flagged,
|
||||
duplicate_prompts_found=len(prompt_dupe_warnings),
|
||||
cross_run_duplicates_found=len(crossrun_dupe_warnings),
|
||||
pair_reports=pair_reports,
|
||||
warnings=batch_warnings,
|
||||
)
|
||||
|
||||
logger.info(batch_report.summary())
|
||||
return filtered, batch_report
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI for standalone validation of existing JSONL files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Validate DPO pair quality")
|
||||
parser.add_argument("jsonl_file", type=Path, help="Path to JSONL file with DPO pairs")
|
||||
parser.add_argument("--json", action="store_true", help="Output JSON report")
|
||||
parser.add_argument("--strict", action="store_true",
|
||||
help="Drop flagged pairs (default: flag only)")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.jsonl_file.exists():
|
||||
print(f"Error: file not found: {args.jsonl_file}")
|
||||
return 1
|
||||
|
||||
pairs = []
|
||||
with open(args.jsonl_file) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
pairs.append(json.loads(line))
|
||||
|
||||
config = {}
|
||||
if args.strict:
|
||||
config["flagged_pair_action"] = "drop"
|
||||
else:
|
||||
config["flagged_pair_action"] = "flag"
|
||||
|
||||
# Use parent dir of input file as output_dir for history scanning
|
||||
output_dir = args.jsonl_file.parent
|
||||
validator = DPOQualityValidator(config=config, output_dir=output_dir)
|
||||
filtered, report = validator.validate(pairs)
|
||||
|
||||
if args.json:
|
||||
print(json.dumps(report.to_dict(), indent=2))
|
||||
else:
|
||||
print("=" * 60)
|
||||
print(" DPO PAIR QUALITY VALIDATION REPORT")
|
||||
print("=" * 60)
|
||||
print(report.summary())
|
||||
print("-" * 60)
|
||||
for pr in report.pair_reports:
|
||||
status = "✓" if pr.passed else "✗"
|
||||
print(f" [{status}] Pair {pr.index}: ", end="")
|
||||
if pr.passed:
|
||||
print("OK")
|
||||
else:
|
||||
print(", ".join(pr.warnings))
|
||||
print("=" * 60)
|
||||
print(f"\nFiltered output: {len(filtered)} pairs "
|
||||
f"({'strict/drop' if args.strict else 'flag'} mode)")
|
||||
|
||||
return 0 if report.passed_pairs > 0 else 2
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
Reference in New Issue
Block a user