Files
timmy-config/scripts/generate_crisis_response.py
Step35 Burn Worker 9b2f09ae95
Some checks failed
Architecture Lint / Linter Tests (pull_request) Successful in 23s
Smoke Test / smoke (pull_request) Failing after 20s
Validate Config / YAML Lint (pull_request) Failing after 15s
Validate Config / JSON Validate (pull_request) Successful in 18s
Validate Config / Python Syntax & Import Check (pull_request) Failing after 55s
Validate Config / Python Test Suite (pull_request) Has been skipped
Validate Config / Cron Syntax Check (pull_request) Successful in 11s
Validate Config / Shell Script Lint (pull_request) Failing after 57s
Validate Config / Deploy Script Dry Run (pull_request) Successful in 12s
Validate Config / Playbook Schema Validation (pull_request) Successful in 27s
Validate Training Data / validate (pull_request) Successful in 26s
Architecture Lint / Lint Repository (pull_request) Failing after 12s
PR Checklist / pr-checklist (pull_request) Successful in 3m45s
feat(training): add Crisis Response dataset generator (#574)
Adds `scripts/generate_crisis_response.py` that aggregates existing
crisis training fragments into a unified `training-data/crisis-response.jsonl`
dataset (3,143 pairs, exceeds 2K target).

- Normalizes schema across 7 source files into unified format
- Reports crisis protocol compliance (988 referral, gospel, presence check) and warns about entries that need human review
- Deduplicates entries (3500 → 3143 pairs)
- Includes smoke tests (`test_generate_crisis_response.py`)
- Documentation: `training-data/CRISIS-README.md`

Part of #571 Pipeline 3 — Training Factory.

Closes #574
2026-04-29 03:25:33 -04:00

188 lines
8.1 KiB
Python

#!/usr/bin/env python3
"""
generate_crisis_response.py — Aggregate crisis response training data.
Combines existing crisis dataset fragments into a unified training-data/crisis-response.jsonl.
Normalizes schema across variant formats and enriches with protocol compliance flags.
Closes #574
Part of: #571 Pipeline 3 - Training Factory
Usage:
python scripts/generate_crisis_response.py --output training-data/crisis-response.jsonl
python scripts/generate_crisis_response.py --verify # Validate against crisis protocol
"""
import argparse
import hashlib
import json
import re
from datetime import datetime
from pathlib import Path
# Crisis data fragments to merge, as (path relative to the repo root, family).
# main() walks this list in order and deduplication keeps the first copy of a
# pair, so earlier files win over later copies of the same content.
SOURCES = [
    # training-data/ (hyphenated names)
    ("training-data/crisis-indirect-500.jsonl", "indirect_signals"),
    ("training-data/crisis-manipulation-500.jsonl", "manipulation_edge_cases"),
    ("training-data/crisis-response-post-crisis-recovery.jsonl", "post_crisis_recovery"),
    # training/data/crisis-response/
    ("training/data/crisis-response/manipulation-edge-cases-500.jsonl", "manipulation_edge_cases"),
    ("training/data/crisis-response/post-crisis-recovery-500.jsonl", "post_crisis_recovery"),
    # training/data/ (underscored names)
    ("training/data/crisis_indirect_signals_500.jsonl", "indirect_signals"),
    ("training/data/crisis_manipulation_edge_cases_500.jsonl", "manipulation_edge_cases"),
]

# Raw patterns for the protocol elements searched for in each response:
#   988          - the literal number 988, or "Suicide ... Crisis ... Lifeline"
#   gospel       - Jesus / Christ / Savior / gospel / John 3:16 / Romans / Bible
#   safety_check - the "Are you safe" presence question
_SAFETY_PATTERNS = {
    "988": r'\b988\b|Suicide.*Crisis.*Lifeline',
    "gospel": r'Jesus|Christ|Savior|gospel|John 3:16|Romans|Bible',
    "safety_check": r"Are you safe",
}

# Compiled once, all case-insensitive; keys match _SAFETY_PATTERNS.
SAFETY_PHRASES = {
    name: re.compile(pattern, re.IGNORECASE)
    for name, pattern in _SAFETY_PATTERNS.items()
}
def normalize_simple(entry, family):
    """Convert a {category, scenario, response} record into the unified schema.

    Args:
        entry: source record. "scenario" and "response" are required
            (KeyError otherwise); "category" defaults to "unknown".
        family: family label of the source file, stamped on the record.

    Returns:
        dict with keys id, family, category, scenario, response, risk_level,
        988_included, gospel_included, safety_check_included. The id is
        "<family>-<category>-<NNNN>", where NNNN comes from a SHA-1 of the
        scenario text. The id is therefore identical on every run; the builtin
        hash() of a str is salted per process and is not.
    """
    # Resolve the category once so the id and risk level agree.
    category = entry.get("category", "unknown")
    scenario = entry["scenario"]
    response = entry["response"]
    digest = hashlib.sha1(str(scenario).encode("utf-8")).hexdigest()
    bucket = int(digest, 16) % 10000
    return {
        "id": f"{family}-{category}-{bucket:04d}",
        "family": family,
        "category": category,
        "scenario": scenario,
        "response": response,
        "risk_level": infer_risk(category),
        "988_included": bool(SAFETY_PHRASES["988"].search(response)),
        "gospel_included": bool(SAFETY_PHRASES["gospel"].search(response)),
        "safety_check_included": bool(SAFETY_PHRASES["safety_check"].search(response)),
    }
def normalize_enriched(entry, family):
    """Normalize a record that is already (mostly) in the unified schema.

    main() routes records here when they carry "id", "family" and
    "risk_level". Missing fields are filled from legacy aliases:
    "prompt" for "scenario", "signal_type" for "category", and
    "includes_988" / "includes_gospel" for the 988 / gospel flags.

    Args:
        entry: source record (dict).
        family: family label of the source file; it replaces any "family"
            value inside the record.

    Returns:
        dict in the unified schema. The three *_included flags are always
        real booleans. A truthy declared flag counts as True; otherwise the
        response text is scanned with SAFETY_PHRASES.
    """
    scenario = entry.get("scenario", entry.get("prompt", ""))
    response = entry.get("response", "")
    category = entry.get("category", entry.get("signal_type", "unknown"))

    if "id" in entry:
        entry_id = entry["id"]
    else:
        # hashlib rather than hash(): str hashes are salted per process,
        # which would make fallback ids differ between runs.
        digest = hashlib.sha1(str(scenario).encode("utf-8")).hexdigest()
        entry_id = f"{family}-{int(digest, 16) % 10000:04d}"

    if "risk_level" in entry:
        risk_level = entry["risk_level"]
    else:
        # Infer from the resolved category (which may come from signal_type).
        risk_level = infer_risk(category)

    # Declared flags take precedence; the regex scan runs only when the
    # declared value is missing or falsy (short-circuit `or`).
    has_988 = bool(
        entry.get("988_included")
        or entry.get("includes_988")
        or SAFETY_PHRASES["988"].search(response)
    )
    has_gospel = bool(
        entry.get("gospel_included")
        or entry.get("includes_gospel")
        or SAFETY_PHRASES["gospel"].search(response)
    )
    has_safety_check = bool(
        entry.get("safety_check_included")
        or SAFETY_PHRASES["safety_check"].search(response)
    )

    return {
        "id": entry_id,
        "family": family,
        "category": category,
        "scenario": scenario,
        "response": response,
        "risk_level": risk_level,
        "988_included": has_988,
        "gospel_included": has_gospel,
        "safety_check_included": has_safety_check,
    }
def normalize_indirect(entry, family):
    """Convert an indirect-signal style record into the unified schema.

    Expected shape: {example_id, issue, task_type, signal_type, prompt,
    response}. "prompt" and "response" are required (KeyError otherwise).

    Args:
        entry: source record.
        family: family label of the source file. It is used as given, so a
            record in this shape from a manipulation or recovery source keeps
            that source's family and is not relabeled "indirect_signals".

    Returns:
        dict in the unified schema. risk_level is the record's own
        "risk_level" when present, otherwise "high". When "example_id" is
        missing, the fallback id is built from a SHA-1 of the prompt, so it
        stays the same across runs.
    """
    prompt = entry["prompt"]
    response = entry["response"]
    if "example_id" in entry:
        entry_id = entry["example_id"]
    else:
        # Stable across runs, unlike the per-process salted hash() of a str.
        digest = hashlib.sha1(str(prompt).encode("utf-8")).hexdigest()
        entry_id = f"{family}-{int(digest, 16) % 10000:04d}"
    return {
        "id": entry_id,
        "family": family,
        "category": entry.get("signal_type", "unknown"),
        "scenario": prompt,
        "response": response,
        "risk_level": entry.get("risk_level", "high"),
        "988_included": bool(SAFETY_PHRASES["988"].search(response)),
        "gospel_included": bool(SAFETY_PHRASES["gospel"].search(response)),
        "safety_check_included": bool(SAFETY_PHRASES["safety_check"].search(response)),
    }
def infer_risk(category):
    """Map a crisis category label to a coarse risk level.

    Matching is case-insensitive substring matching on ``str(category)``, so
    None or non-string labels are accepted. They fall through to "medium"
    unless their string form contains a keyword.

    "direct" only counts when it is not the tail of "indirect". Otherwise
    every indirect-signal category would be escalated to "critical".

    Returns:
        "critical", "high", or "medium".
    """
    cat = str(category).lower()
    if "critical" in cat or "suicidal" in cat or re.search(r"(?<!in)direct", cat):
        return "critical"
    if "high" in cat or "manipulation" in cat or "hopelessness" in cat:
        return "high"
    return "medium"
def load_file(path: Path) -> list:
    """Read a JSON Lines file and return its records as a list.

    Blank (whitespace-only) lines are skipped. The file is decoded as UTF-8
    explicitly, so non-ASCII text in the crisis data is read the same way
    regardless of the platform's default locale encoding.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
def _resolve_output_path(repo_root, output):
    """Return --output as a Path: absolute paths unchanged, relative ones under repo_root."""
    candidate = Path(output)
    return candidate if candidate.is_absolute() else repo_root / candidate


def _collect_sources(repo_root):
    """Load and normalize every SOURCES file present under repo_root.

    Returns:
        (unified, stats, source_reports):
        - unified: normalized entries, in source order.
        - stats: maps each found source path to its raw entry count.
        - source_reports: one printable line per found source.
    """
    unified = []
    stats = {}
    source_reports = []
    for rel_path, family in SOURCES:
        full = repo_root / rel_path
        if not full.exists():
            print(f"[SKIP] {rel_path} — not found")
            continue
        entries = load_file(full)
        count_before = len(unified)
        for entry in entries:
            # Best-effort per record: a malformed entry is reported and
            # skipped rather than aborting the whole aggregation.
            try:
                if all(k in entry for k in ("id", "family", "risk_level")):
                    normalized = normalize_enriched(entry, family)
                elif "example_id" in entry or "task_type" in entry:
                    normalized = normalize_indirect(entry, family)
                elif "category" in entry and "scenario" in entry and "response" in entry:
                    normalized = normalize_simple(entry, family)
                else:
                    print(f"[WARN] Unknown schema in {rel_path}: keys={list(entry.keys())}")
                    continue
            except Exception as e:
                print(f"[ERROR] Failed to process entry from {rel_path}: {e}")
                continue
            unified.append(normalized)
        stats[rel_path] = len(entries)
        # Count what THIS file contributed. Several sources share a family
        # label, so counting by family would include other files' entries.
        merged = len(unified) - count_before
        source_reports.append(f" {rel_path}: {len(entries)} entries → {merged} merged")
    return unified, stats, source_reports


def _dedupe(entries):
    """Keep the first entry for each exact (scenario, response) pair, preserving order."""
    seen = set()
    unique = []
    for entry in entries:
        # Full strings, not prefixes: pairs sharing their first N characters
        # are still distinct training examples.
        key = (entry["scenario"], entry["response"])
        if key not in seen:
            seen.add(key)
            unique.append(entry)
    return unique


def _print_compliance(deduped):
    """Print how many entries include each protocol element and warn about gaps."""
    total = len(deduped)
    missing_988 = sum(1 for e in deduped if not e["988_included"])
    missing_gospel = sum(1 for e in deduped if not e["gospel_included"])
    missing_safety = sum(1 for e in deduped if not e["safety_check_included"])
    print("\nProtocol compliance:")
    print(f" 988 referral: {total - missing_988}/{total} include 988")
    print(f" Gospel: {total - missing_gospel}/{total} include gospel")
    print(f" Safety check: {total - missing_safety}/{total} include presence check")
    if missing_988 > 0:
        print(f"\n[WARNING] {missing_988} entries missing 988 referral — human review required")
    if missing_gospel > 0:
        print(f"[WARNING] {missing_gospel} entries missing gospel — review required")
    if missing_safety > 0:
        print(f"[WARNING] {missing_safety} entries missing safety check — review required")


def main():
    """Aggregate the crisis source fragments into one JSONL dataset.

    Reads every SOURCES file that exists under the repository root (the
    parent of this script's directory). Each record is normalized to the
    unified schema, exact (scenario, response) duplicates are dropped, and
    the result is sorted by (family, category, id). The output is written as
    UTF-8 JSON Lines, and a protocol-compliance summary is printed.

    --output is resolved relative to the repo root unless it is absolute.
    With --verify, the same pipeline runs and the report is printed, but no
    file is written.

    Returns:
        dict with:
        - "output": the resolved output path, as a str.
        - "pairs": the number of deduplicated records.
        - "sources": the raw entry count for each source path that was found.
    """
    parser = argparse.ArgumentParser(description="Aggregate crisis response training data")
    parser.add_argument("--output", default="training-data/crisis-response.jsonl",
                        help="Output path (relative to repo root)")
    parser.add_argument("--verify", action="store_true",
                        help="Report crisis protocol compliance without writing the output file")
    args = parser.parse_args()

    repo_root = Path(__file__).parent.parent
    output_path = _resolve_output_path(repo_root, args.output)

    unified, stats, source_reports = _collect_sources(repo_root)
    deduped = _dedupe(unified)
    # str() keeps the sort from crashing when a field mixes None/int with str.
    deduped.sort(key=lambda e: (str(e["family"]), str(e["category"]), str(e["id"])))

    if args.verify:
        # --verify is read-only: report on the sources, leave the dataset alone.
        print("\nCrisis Response Dataset Verified (--verify: no file written)")
    else:
        output_path.parent.mkdir(parents=True, exist_ok=True)
        # Explicit UTF-8 because ensure_ascii=False emits raw non-ASCII text.
        with open(output_path, "w", encoding="utf-8") as f:
            for entry in deduped:
                f.write(json.dumps(entry, ensure_ascii=False) + "\n")
        print("\nCrisis Response Dataset Generated")
        print(f"Output: {output_path}")

    print(f"Total pairs: {len(deduped)}")
    print(f"Deduplicated: {len(unified)} → {len(deduped)}")
    print("\nSources:")
    for line in source_reports:
        print(line)

    _print_compliance(deduped)
    return {"output": str(output_path), "pairs": len(deduped), "sources": stats}
if __name__ == "__main__":
    # Run the aggregation, then echo its summary dict as indented JSON.
    summary = main()
    print("\nResult: " + json.dumps(summary, indent=2))