#!/usr/bin/env python3
"""
generate_crisis_response.py — Aggregate crisis response training data.

Combines existing crisis dataset fragments into a unified
training-data/crisis-response.jsonl. Normalizes schema across variant
formats and enriches with protocol compliance flags.

Closes #574
Part of: #571 Pipeline 3 - Training Factory

Usage:
    python scripts/generate_crisis_response.py --output training-data/crisis-response.jsonl
    python scripts/generate_crisis_response.py --verify   # Validate against crisis protocol
"""
import argparse
import json
import re
import zlib
from datetime import datetime
from pathlib import Path

# Location of data sources: (path relative to repo root, canonical family name).
# Several families appear under more than one historical path.
SOURCES = [
    ("training-data/crisis-indirect-500.jsonl", "indirect_signals"),
    ("training-data/crisis-manipulation-500.jsonl", "manipulation_edge_cases"),
    ("training-data/crisis-response-post-crisis-recovery.jsonl", "post_crisis_recovery"),
    ("training/data/crisis-response/manipulation-edge-cases-500.jsonl", "manipulation_edge_cases"),
    ("training/data/crisis-response/post-crisis-recovery-500.jsonl", "post_crisis_recovery"),
    ("training/data/crisis_indirect_signals_500.jsonl", "indirect_signals"),
    ("training/data/crisis_manipulation_edge_cases_500.jsonl", "manipulation_edge_cases"),
]

# Compiled detectors for the three crisis-protocol compliance markers.
SAFETY_PHRASES = {
    "988": re.compile(r'\b988\b|Suicide.*Crisis.*Lifeline', re.IGNORECASE),
    "gospel": re.compile(r'Jesus|Christ|Savior|gospel|John 3:16|Romans|Bible', re.IGNORECASE),
    "safety_check": re.compile(r"Are you safe", re.IGNORECASE),
}


def _stable_id_suffix(text: str) -> str:
    """Deterministic 4-digit id suffix for *text*.

    The original used builtin hash(), which is salted per process
    (PYTHONHASHSEED), so generated ids changed on every run. CRC32 is
    stable across runs and platforms.
    """
    return f"{zlib.crc32(text.encode('utf-8')) % 10000:04d}"


def _compliance_flags(response: str) -> dict:
    """Return the three protocol-compliance booleans detected in *response*."""
    return {
        f"{name}_included": bool(pattern.search(response))
        for name, pattern in SAFETY_PHRASES.items()
    }


def infer_risk(category) -> str:
    """Map a crisis category string to a risk level (critical/high/medium)."""
    cat = str(category).lower()
    if "critical" in cat or "suicidal" in cat or "direct" in cat:
        return "critical"
    if "high" in cat or "manipulation" in cat or "hopelessness" in cat:
        return "high"
    return "medium"


def normalize_simple(entry: dict, family: str) -> dict:
    """Convert {category, scenario, response} → unified schema."""
    category = entry.get("category", "unknown")
    rec = {
        "id": f"{family}-{category}-{_stable_id_suffix(entry['scenario'])}",
        "family": family,
        "category": category,
        "scenario": entry["scenario"],
        "response": entry["response"],
        # Was entry["category"] (KeyError on missing key); now consistent
        # with the .get() defaulting used everywhere else in this function.
        "risk_level": infer_risk(category),
    }
    rec.update(_compliance_flags(entry["response"]))
    return rec


def normalize_enriched(entry: dict, family: str) -> dict:
    """Already enriched — ensure required keys and backfill missing flags."""
    base = {
        "id": entry.get("id", f"{family}-{_stable_id_suffix(entry.get('scenario', ''))}"),
        "family": family,
        "category": entry.get("category", entry.get("signal_type", "unknown")),
        "scenario": entry.get("scenario", entry.get("prompt", "")),
        "response": entry.get("response", ""),
        "risk_level": entry.get("risk_level", infer_risk(entry.get("category", "unknown"))),
        # Accept either flag spelling used by the variant source files.
        "988_included": entry.get("988_included") or entry.get("includes_988", False),
        "gospel_included": entry.get("gospel_included") or entry.get("includes_gospel", False),
        "safety_check_included": entry.get("safety_check_included", False),
    }
    # Fallback detection when the source did not pre-compute a flag.
    for flag, detected in _compliance_flags(base["response"]).items():
        if not base[flag]:
            base[flag] = detected
    return base


def normalize_indirect(entry: dict, family: str) -> dict:
    """Convert indirect_signals variant {example_id, issue, task_type, signal_type, prompt, response}."""
    rec = {
        "id": entry.get("example_id", f"indirect-{_stable_id_suffix(entry['prompt'])}"),
        "family": "indirect_signals",
        "category": entry.get("signal_type", "unknown"),
        "scenario": entry["prompt"],
        "response": entry["response"],
        # Indirect-signal scenarios are treated as high risk by definition.
        "risk_level": "high",
    }
    rec.update(_compliance_flags(entry["response"]))
    return rec


def load_file(path: Path) -> list:
    """Parse one JSONL file into a list of dicts, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def main():
    parser = argparse.ArgumentParser(description="Aggregate crisis response training data")
    parser.add_argument("--output", default="training-data/crisis-response.jsonl",
                        help="Output path (relative to repo root)")
    parser.add_argument("--verify", action="store_true",
                        help="Validate all source files against crisis protocol")
    args = parser.parse_args()

    repo_root = Path(__file__).parent.parent
    # pathlib normalizes a leading "./" on join; the old lstrip("./") stripped
    # any run of '.'/'/' characters and would mangle paths like "../out".
    output_path = repo_root / args.output

    unified = []
    stats = {}
    source_reports = []
    for rel_path, family in SOURCES:
        full = repo_root / rel_path
        if not full.exists():
            print(f"[SKIP] {rel_path} — not found")
            continue
        entries = load_file(full)
        merged_before = len(unified)
        for entry in entries:
            try:
                # Schema dispatch: enriched → indirect variant → simple triplet.
                if all(k in entry for k in ["id", "family", "risk_level"]):
                    normalized = normalize_enriched(entry, family)
                elif "example_id" in entry or "task_type" in entry:
                    normalized = normalize_indirect(entry, family)
                elif "category" in entry and "scenario" in entry and "response" in entry:
                    normalized = normalize_simple(entry, family)
                else:
                    print(f"[WARN] Unknown schema in {rel_path}: keys={list(entry.keys())}")
                    continue
                unified.append(normalized)
            except Exception as e:
                print(f"[ERROR] Failed to process entry from {rel_path}: {e}")
        stats[rel_path] = len(entries)
        # Count only this source's contribution; the old cumulative family
        # count double-counted across sources sharing a family.
        merged_here = len(unified) - merged_before
        source_reports.append(f" {rel_path}: {len(entries)} entries → {merged_here} merged")

    # Deduplicate by (scenario, response) prefix.
    seen = set()
    deduped = []
    for entry in unified:
        key = (entry["scenario"][:100], entry["response"][:100])
        if key not in seen:
            seen.add(key)
            deduped.append(entry)

    # Sort consistent order
    deduped.sort(key=lambda e: (e["family"], e["category"], e["id"]))

    # --verify was previously parsed but ignored; in verify mode we validate
    # and report without writing the aggregated output file.
    if not args.verify:
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            for entry in deduped:
                f.write(json.dumps(entry, ensure_ascii=False) + "\n")

    print(f"\nCrisis Response Dataset Generated")
    print(f"Output: {output_path}")
    print(f"Total pairs: {len(deduped)}")
    print(f"Deduplicated: {len(unified)} → {len(deduped)}")
    print(f"\nSources:")
    for r in source_reports:
        print(r)

    # Compliance report
    missing_988 = sum(1 for e in deduped if not e["988_included"])
    missing_gospel = sum(1 for e in deduped if not e["gospel_included"])
    missing_safety = sum(1 for e in deduped if not e["safety_check_included"])
    print(f"\nProtocol compliance:")
    print(f" 988 referral: {len(deduped) - missing_988}/{len(deduped)} include 988")
    print(f" Gospel: {len(deduped) - missing_gospel}/{len(deduped)} include gospel")
    print(f" Safety check: {len(deduped) - missing_safety}/{len(deduped)} include presence check")
    if missing_988 > 0:
        print(f"\n[WARNING] {missing_988} entries missing 988 referral — human review required")
    if missing_gospel > 0:
        print(f"[WARNING] {missing_gospel} entries missing gospel — review required")
    if missing_safety > 0:
        print(f"[WARNING] {missing_safety} entries missing safety check — review required")

    return {"output": str(output_path), "pairs": len(deduped), "sources": stats}


if __name__ == "__main__":
    result = main()
    print(f"\nResult: {json.dumps(result, indent=2)}")