Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7cef18fdcb | ||
|
|
706024e11e |
@@ -6,8 +6,9 @@ Stands between a broken man and a machine that would tell him to die.
|
||||
|
||||
from .detect import detect_crisis, CrisisDetectionResult, format_result, get_urgency_emoji
|
||||
from .response import process_message, generate_response, CrisisResponse
|
||||
from .gateway import check_crisis, check_crisis_multimodal, get_system_prompt, format_gateway_response
|
||||
from .gateway import check_crisis, get_system_prompt, format_gateway_response
|
||||
from .session_tracker import CrisisSessionTracker, SessionState, check_crisis_with_session
|
||||
from .ab_testing import ABTestCrisisDetector, VariantRecord
|
||||
|
||||
__all__ = [
|
||||
"detect_crisis",
|
||||
@@ -16,7 +17,6 @@ __all__ = [
|
||||
"generate_response",
|
||||
"CrisisResponse",
|
||||
"check_crisis",
|
||||
"check_crisis_multimodal",
|
||||
"get_system_prompt",
|
||||
"format_result",
|
||||
"format_gateway_response",
|
||||
@@ -24,4 +24,6 @@ __all__ = [
|
||||
"CrisisSessionTracker",
|
||||
"SessionState",
|
||||
"check_crisis_with_session",
|
||||
"ABTestCrisisDetector",
|
||||
"VariantRecord",
|
||||
]
|
||||
|
||||
112
crisis/ab_testing.py
Normal file
112
crisis/ab_testing.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""A/B test framework for crisis detection in the-door."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from .detect import CrisisDetectionResult
|
||||
|
||||
|
||||
def _get_variant_override() -> Optional[str]:
    """Return env override for deterministic testing/debugging."""
    # CRISIS_AB_VARIANT=a/b (any case) pins routing; anything else is ignored.
    forced = os.environ.get("CRISIS_AB_VARIANT", "").strip().upper()
    return forced if forced in {"A", "B"} else None


@dataclass
class VariantRecord:
    """Single crisis detection event record with no user text or PII."""

    variant: str  # "A" or "B" — which detector handled the message
    level: str  # crisis level label reported by the detector
    latency_ms: float  # wall-clock detection time in milliseconds
    indicator_count: int  # number of indicators the detector matched
    false_positive: Optional[bool] = None  # filled in later by human review


class ABTestCrisisDetector:
    """Route crisis detection between two variants and collect comparison stats."""

    def __init__(
        self,
        variant_a: Callable[[str], CrisisDetectionResult],
        variant_b: Callable[[str], CrisisDetectionResult],
        split: float = 0.5,
    ):
        self.variant_a = variant_a
        self.variant_b = variant_b
        # Clamp the A-traffic share into [0, 1] so a bad config can never
        # route negative or more-than-all traffic to one arm.
        self.split = max(0.0, min(1.0, float(split)))
        self.records: List[VariantRecord] = []

    def _select_variant(self) -> str:
        """Pick "A" or "B", honoring the env override when one is set."""
        forced = _get_variant_override()
        if forced is not None:
            return forced
        if random.random() < self.split:
            return "A"
        return "B"

    def detect(self, text: str) -> Tuple[CrisisDetectionResult, str, int]:
        """Run one detection through the selected variant and log a record.

        Returns (result, variant, record_id); record_id can later be passed
        to record_outcome() once the event has been human-reviewed.
        """
        chosen = self._select_variant()
        run = self.variant_a if chosen == "A" else self.variant_b

        started = time.perf_counter()
        outcome = run(text)
        elapsed_ms = (time.perf_counter() - started) * 1000.0

        # Only anonymized metadata is kept — never the message text itself.
        self.records.append(
            VariantRecord(
                variant=chosen,
                level=outcome.level,
                latency_ms=elapsed_ms,
                indicator_count=len(outcome.indicators),
            )
        )
        return outcome, chosen, len(self.records) - 1

    def record_outcome(self, record_id: int, *, false_positive: bool) -> None:
        """Attach a human-review verdict to a previously logged record."""
        if not 0 <= record_id < len(self.records):
            raise IndexError(f"Unknown record id: {record_id}")
        self.records[record_id].false_positive = bool(false_positive)

    def get_stats(self) -> Dict[str, dict]:
        """Summarize latency, level mix, and false-positive rate per variant."""
        return {variant: self._variant_stats(variant) for variant in ("A", "B")}

    def _variant_stats(self, variant: str) -> dict:
        """Build the stats dict for one variant's logged records."""
        bucket = [rec for rec in self.records if rec.variant == variant]
        if not bucket:
            return {
                "count": 0,
                "reviewed_count": 0,
                "false_positive_rate": None,
            }

        # Frequency of each reported crisis level.
        levels: Dict[str, int] = {}
        for rec in bucket:
            levels[rec.level] = levels.get(rec.level, 0) + 1

        # False-positive rate only covers records that have been reviewed.
        labeled = [rec for rec in bucket if rec.false_positive is not None]
        fp_rate = None
        if labeled:
            fp_rate = round(
                sum(1 for rec in labeled if rec.false_positive) / len(labeled),
                4,
            )

        latencies = [rec.latency_ms for rec in bucket]
        return {
            "count": len(bucket),
            "avg_latency_ms": round(sum(latencies) / len(bucket), 4),
            "max_latency_ms": round(max(latencies), 4),
            "min_latency_ms": round(min(latencies), 4),
            "avg_indicator_count": round(sum(rec.indicator_count for rec in bucket) / len(bucket), 4),
            "levels": levels,
            "reviewed_count": len(labeled),
            "false_positive_rate": fp_rate,
        }

    def reset(self) -> None:
        """Drop all collected records."""
        self.records.clear()
|
||||
@@ -2,21 +2,18 @@
|
||||
Crisis Gateway Module for the-door.
|
||||
|
||||
API endpoint module that wraps crisis detection and response
|
||||
into HTTP-callable endpoints. Integrates detect.py, unified_scorer.py, and response.py.
|
||||
into HTTP-callable endpoints. Integrates detect.py and response.py.
|
||||
|
||||
Usage:
|
||||
from crisis.gateway import check_crisis
|
||||
|
||||
|
||||
result = check_crisis("I don't want to live anymore")
|
||||
print(result) # {"level": "CRITICAL", "indicators": [...], "response": {...}}
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from unified_scorer import UnifiedCrisisScorer, UnifiedScoreAuditLog, behavioral_score_from_session
|
||||
|
||||
from .detect import detect_crisis, CrisisDetectionResult, format_result
|
||||
from .compassion_router import router
|
||||
from .response import (
|
||||
@@ -53,74 +50,6 @@ def check_crisis(text: str) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def check_crisis_multimodal(
    text: str,
    *,
    tracker: Optional[CrisisSessionTracker] = None,
    voice_score: Optional[float] = None,
    image_score: Optional[float] = None,
    behavioral_score: Optional[float] = None,
    audit_log_path: Optional[Path] = None,
    weights: Optional[dict] = None,
) -> dict:
    """Combine text, voice, image, and behavioral signals into one crisis assessment.

    Args:
        text: User message to analyze; always drives the text modality.
        tracker: Optional session tracker; when given, this message is recorded
            and a "session" section is added to the result.
        voice_score: Optional crisis score from a voice model — assumed 0..1;
            TODO confirm range against the voice pipeline.
        image_score: Optional crisis score from an image model (same assumption).
        behavioral_score: Optional behavioral-risk score; if omitted, it is
            derived from the tracked session state when a tracker is supplied.
        audit_log_path: When set, each unified score is appended to an audit log
            at this path via UnifiedScoreAuditLog.
        weights: Optional per-modality weight overrides for UnifiedCrisisScorer.

    Returns:
        dict with top-level level/score/indicators, a "ui" section of display
        flags, a "unified" section with the combined-score breakdown, and —
        only when a tracker was supplied — a "session" trajectory section.
    """
    detection = detect_crisis(text)
    # Record against the session first so escalation trends include this message.
    session_state = tracker.record(detection) if tracker is not None else None
    # Derive a behavioral score from the session only if the caller did not
    # supply one explicitly — an explicit argument always wins.
    if behavioral_score is None and session_state is not None:
        behavioral_score = behavioral_score_from_session(session_state)

    scorer = UnifiedCrisisScorer(
        weights=weights,
        audit_log=UnifiedScoreAuditLog(audit_log_path) if audit_log_path else None,
    )
    assessment = scorer.score(
        text_score=detection.score,
        voice_score=voice_score,
        image_score=image_score,
        behavioral_score=behavioral_score,
        source_text=text,
    )

    # Re-wrap the unified level/score as a CrisisDetectionResult so the
    # existing text-only response generator can be reused unchanged.
    unified_detection = CrisisDetectionResult(
        level=assessment.level.value,
        indicators=detection.indicators,
        recommended_action=detection.recommended_action,
        score=assessment.combined_score,
        matches=detection.matches,
    )
    response = generate_response(unified_detection)

    result = {
        "level": unified_detection.level,
        "score": unified_detection.score,
        "indicators": detection.indicators,
        "recommended_action": unified_detection.recommended_action,
        "timmy_message": response.timmy_message,
        "ui": {
            "show_crisis_panel": response.show_crisis_panel,
            "show_overlay": response.show_overlay,
            "provide_988": response.provide_988,
        },
        "escalate": response.escalate,
        "unified": {
            "level": assessment.level.value,
            "combined_score": assessment.combined_score,
            "weights": assessment.weights,
            "modalities": assessment.modalities,
            "present_modalities": assessment.present_modalities,
        },
    }
    # Session trajectory is only reported when a tracker was supplied.
    if session_state is not None:
        result["session"] = {
            "current_level": session_state.current_level,
            "peak_level": session_state.peak_level,
            "message_count": session_state.message_count,
            "is_escalating": session_state.is_escalating,
            "is_deescalating": session_state.is_deescalating,
        }
    return result
|
||||
|
||||
|
||||
def get_system_prompt(base_prompt: str, text: str = "") -> str:
|
||||
"""
|
||||
Sovereign Heart System Prompt Override.
|
||||
|
||||
138
tests/test_ab_testing.py
Normal file
138
tests/test_ab_testing.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Tests for crisis.ab_testing — A/B test framework for crisis detection (#101)."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crisis.ab_testing import ABTestCrisisDetector
|
||||
from crisis.detect import CrisisDetectionResult, detect_crisis
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_variant_override():
|
||||
old = os.environ.pop("CRISIS_AB_VARIANT", None)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if old is not None:
|
||||
os.environ["CRISIS_AB_VARIANT"] = old
|
||||
else:
|
||||
os.environ.pop("CRISIS_AB_VARIANT", None)
|
||||
|
||||
|
||||
def _make_variant(level: str, indicators=None):
|
||||
indicators = indicators or [f"mock_{level.lower()}"]
|
||||
|
||||
def fn(text: str) -> CrisisDetectionResult:
|
||||
return CrisisDetectionResult(level=level, indicators=list(indicators))
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
def test_detect_returns_result_variant_and_logged_record():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
|
||||
with patch.object(detector, "_select_variant", return_value="A"):
|
||||
result, variant, record_id = detector.detect("test message")
|
||||
|
||||
assert isinstance(result, CrisisDetectionResult)
|
||||
assert variant == "A"
|
||||
assert record_id == 0
|
||||
assert len(detector.records) == 1
|
||||
assert detector.records[0].variant == "A"
|
||||
assert detector.records[0].level == "LOW"
|
||||
|
||||
|
||||
def test_env_override_forces_variant_b():
|
||||
os.environ["CRISIS_AB_VARIANT"] = "b"
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
|
||||
result, variant, _ = detector.detect("test")
|
||||
|
||||
assert variant == "B"
|
||||
assert result.level == "HIGH"
|
||||
|
||||
|
||||
def test_get_stats_reports_latency_counts_and_level_breakdown():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("CRITICAL"),
|
||||
)
|
||||
|
||||
with patch.object(detector, "_select_variant", side_effect=["A", "A", "B"]):
|
||||
detector.detect("first")
|
||||
detector.detect("second")
|
||||
detector.detect("third")
|
||||
|
||||
stats = detector.get_stats()
|
||||
assert stats["A"]["count"] == 2
|
||||
assert stats["B"]["count"] == 1
|
||||
assert stats["A"]["levels"]["LOW"] == 2
|
||||
assert stats["B"]["levels"]["CRITICAL"] == 1
|
||||
assert "avg_latency_ms" in stats["A"]
|
||||
assert "avg_indicator_count" in stats["B"]
|
||||
|
||||
|
||||
def test_false_positive_rate_is_computed_from_reviewed_outcomes():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
|
||||
with patch.object(detector, "_select_variant", side_effect=["A", "A", "B"]):
|
||||
_, _, a0 = detector.detect("first")
|
||||
_, _, a1 = detector.detect("second")
|
||||
_, _, b0 = detector.detect("third")
|
||||
|
||||
detector.record_outcome(a0, false_positive=True)
|
||||
detector.record_outcome(a1, false_positive=False)
|
||||
detector.record_outcome(b0, false_positive=False)
|
||||
|
||||
stats = detector.get_stats()
|
||||
assert stats["A"]["reviewed_count"] == 2
|
||||
assert stats["A"]["false_positive_rate"] == 0.5
|
||||
assert stats["B"]["false_positive_rate"] == 0.0
|
||||
|
||||
|
||||
def test_record_outcome_rejects_unknown_record():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
|
||||
with pytest.raises(IndexError):
|
||||
detector.record_outcome(99, false_positive=True)
|
||||
|
||||
|
||||
def test_reset_clears_records_and_stats():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
detector.detect("test")
|
||||
detector.reset()
|
||||
|
||||
assert detector.records == []
|
||||
stats = detector.get_stats()
|
||||
assert stats["A"]["count"] == 0
|
||||
assert stats["B"]["count"] == 0
|
||||
|
||||
|
||||
def test_with_real_detector_integration():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=detect_crisis,
|
||||
variant_b=detect_crisis,
|
||||
)
|
||||
|
||||
result, variant, record_id = detector.detect("I want to kill myself")
|
||||
|
||||
assert result.level == "CRITICAL"
|
||||
assert variant in ("A", "B")
|
||||
assert record_id == 0
|
||||
@@ -1,19 +0,0 @@
|
||||
from crisis.gateway import check_crisis_multimodal
from crisis.session_tracker import CrisisSessionTracker


def test_multimodal_gateway_uses_unified_score_for_988_ui(tmp_path):
    """A critical multimodal message must surface the 988 overlay UI flags."""
    session = CrisisSessionTracker()
    payload = check_crisis_multimodal(
        "I want to kill myself tonight",
        tracker=session,
        voice_score=0.92,
        image_score=0.6,
        audit_log_path=tmp_path / "audit.jsonl",
    )

    unified = payload["unified"]
    assert unified["level"] == "CRITICAL"
    assert unified["modalities"]["voice"] == 0.92
    # A tracker was supplied, so a behavioral score must have been derived.
    assert unified["modalities"]["behavioral"] >= 0.0
    assert payload["ui"]["provide_988"] is True
    assert payload["ui"]["show_overlay"] is True
|
||||
@@ -1,51 +0,0 @@
|
||||
from pathlib import Path

from unified_scorer import (
    CrisisLevel,
    UnifiedCrisisScorer,
    UnifiedScoreAuditLog,
    behavioral_score_from_session,
)
from crisis.session_tracker import SessionState


def test_unified_scorer_renormalizes_available_modalities_and_escalates():
    """A missing image modality is renormalized away; the result still escalates."""
    result = UnifiedCrisisScorer().score(
        text_score=1.0,
        voice_score=0.8,
        image_score=None,
        behavioral_score=0.7,
    )

    assert result.level is CrisisLevel.CRITICAL
    assert result.combined_score > 0.8
    assert result.present_modalities == ["text", "voice", "behavioral"]


def test_behavioral_score_rises_for_escalating_session_state():
    """A session trending from LOW toward CRITICAL yields a high behavioral score."""
    state = SessionState(
        current_level="HIGH",
        peak_level="CRITICAL",
        message_count=4,
        level_history=["LOW", "MEDIUM", "HIGH", "CRITICAL"],
        is_escalating=True,
        is_deescalating=False,
        escalation_rate=1.0,
        consecutive_low_messages=0,
    )

    assert behavioral_score_from_session(state) >= 0.8


def test_audit_log_persists_anonymized_score_entries(tmp_path):
    """Audit entries carry a fingerprint and score but never the raw text."""
    target = tmp_path / "unified-score-audit.jsonl"
    scorer = UnifiedCrisisScorer(audit_log=UnifiedScoreAuditLog(target))
    scorer.score(
        text_score=0.75,
        voice_score=0.2,
        image_score=0.1,
        behavioral_score=0.6,
        source_text="I feel trapped and hopeless",
    )

    written = target.read_text().strip().splitlines()
    assert len(written) == 1
    first = written[0]
    assert "trapped and hopeless" not in first
    assert '"text_fingerprint"' in first
    assert '"combined_score"' in first
|
||||
@@ -1,126 +0,0 @@
|
||||
"""Unified multimodal crisis scoring for the-door."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crisis.session_tracker import SessionState
|
||||
|
||||
|
||||
SCORE_BY_LEVEL = {"NONE": 0.0, "LOW": 0.25, "MEDIUM": 0.5, "HIGH": 0.75, "CRITICAL": 1.0}
|
||||
LEVEL_RANK = {"NONE": 0, "LOW": 1, "MEDIUM": 2, "HIGH": 3, "CRITICAL": 4}
|
||||
|
||||
|
||||
class CrisisLevel(Enum):
|
||||
NONE = "NONE"
|
||||
LOW = "LOW"
|
||||
MEDIUM = "MEDIUM"
|
||||
HIGH = "HIGH"
|
||||
CRITICAL = "CRITICAL"
|
||||
|
||||
|
||||
DEFAULT_WEIGHTS: Dict[str, float] = {
|
||||
"text": 0.4,
|
||||
"voice": 0.25,
|
||||
"behavioral": 0.2,
|
||||
"image": 0.15,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnifiedAssessment:
|
||||
level: CrisisLevel
|
||||
combined_score: float
|
||||
weights: Dict[str, float]
|
||||
modalities: Dict[str, Optional[float]]
|
||||
present_modalities: List[str]
|
||||
|
||||
|
||||
class UnifiedScoreAuditLog:
|
||||
def __init__(self, path: Path | str):
|
||||
self.path = Path(path)
|
||||
|
||||
def record(self, assessment: UnifiedAssessment, source_text: str = "") -> None:
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fingerprint = hashlib.sha256(source_text.encode("utf-8")).hexdigest()[:12] if source_text else None
|
||||
payload = {
|
||||
"level": assessment.level.value,
|
||||
"combined_score": round(assessment.combined_score, 4),
|
||||
"weights": assessment.weights,
|
||||
"modalities": assessment.modalities,
|
||||
"present_modalities": assessment.present_modalities,
|
||||
"text_fingerprint": fingerprint,
|
||||
}
|
||||
with self.path.open("a", encoding="utf-8") as fh:
|
||||
fh.write(json.dumps(payload, sort_keys=True) + "\n")
|
||||
|
||||
|
||||
class UnifiedCrisisScorer:
|
||||
def __init__(self, weights: Optional[Dict[str, float]] = None, audit_log: Optional[UnifiedScoreAuditLog] = None):
|
||||
self.weights = dict(DEFAULT_WEIGHTS)
|
||||
if weights:
|
||||
self.weights.update(weights)
|
||||
self.audit_log = audit_log
|
||||
|
||||
def _normalize(self, modalities: Dict[str, Optional[float]]) -> Dict[str, float]:
|
||||
present = [name for name, score in modalities.items() if score is not None]
|
||||
if not present:
|
||||
return {}
|
||||
total = sum(self.weights[name] for name in present)
|
||||
return {name: self.weights[name] / total for name in present}
|
||||
|
||||
def _level_for_score(self, score: float) -> CrisisLevel:
|
||||
if score > 0.8:
|
||||
return CrisisLevel.CRITICAL
|
||||
if score > 0.6:
|
||||
return CrisisLevel.HIGH
|
||||
if score > 0.4:
|
||||
return CrisisLevel.MEDIUM
|
||||
if score > 0.0:
|
||||
return CrisisLevel.LOW
|
||||
return CrisisLevel.NONE
|
||||
|
||||
def score(
|
||||
self,
|
||||
*,
|
||||
text_score: Optional[float],
|
||||
voice_score: Optional[float] = None,
|
||||
image_score: Optional[float] = None,
|
||||
behavioral_score: Optional[float] = None,
|
||||
source_text: str = "",
|
||||
) -> UnifiedAssessment:
|
||||
modalities = {
|
||||
"text": text_score,
|
||||
"voice": voice_score,
|
||||
"behavioral": behavioral_score,
|
||||
"image": image_score,
|
||||
}
|
||||
normalized = self._normalize(modalities)
|
||||
combined = 0.0
|
||||
for name, weight in normalized.items():
|
||||
combined += float(modalities[name]) * weight
|
||||
assessment = UnifiedAssessment(
|
||||
level=self._level_for_score(combined),
|
||||
combined_score=combined,
|
||||
weights=normalized,
|
||||
modalities=modalities,
|
||||
present_modalities=[name for name, score in modalities.items() if score is not None],
|
||||
)
|
||||
if self.audit_log:
|
||||
self.audit_log.record(assessment, source_text=source_text)
|
||||
return assessment
|
||||
|
||||
|
||||
def behavioral_score_from_session(session: 'SessionState') -> float:
|
||||
current = SCORE_BY_LEVEL.get(session.current_level, 0.0)
|
||||
peak_bonus = 0.1 if LEVEL_RANK.get(session.peak_level, 0) >= LEVEL_RANK["HIGH"] else 0.0
|
||||
escalation_bonus = 0.15 if session.is_escalating else 0.0
|
||||
rate_bonus = min(max(session.escalation_rate, 0.0), 1.0) * 0.1
|
||||
deescalation_penalty = 0.15 if session.is_deescalating else 0.0
|
||||
return max(0.0, min(1.0, current + peak_bonus + escalation_bonus + rate_bonus - deescalation_penalty))
|
||||
Reference in New Issue
Block a user