feat: Unified multimodal crisis scorer (#134) #145
285
crisis/unified_scorer.py
Normal file
285
crisis/unified_scorer.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
Unified Multimodal Crisis Scorer
|
||||
|
||||
Combines text, voice, image, and behavioral scores into a single
|
||||
crisis risk assessment for the-door project.
|
||||
|
||||
Epic: #102 (Multimodal Crisis Detection)
|
||||
Issue: #134
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CrisisLevel(IntEnum):
|
||||
"""Crisis severity levels."""
|
||||
NONE = 0
|
||||
LOW = 1
|
||||
MEDIUM = 2
|
||||
HIGH = 3
|
||||
CRITICAL = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModalityScore:
|
||||
"""Score from a single modality."""
|
||||
modality: str # "text", "voice", "image", "behavioral"
|
||||
score: float # 0.0 to 1.0
|
||||
confidence: float # 0.0 to 1.0
|
||||
indicators: List[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnifiedCrisisScore:
|
||||
"""Combined crisis score from all modalities."""
|
||||
level: CrisisLevel
|
||||
combined_score: float
|
||||
modality_scores: Dict[str, ModalityScore]
|
||||
timestamp: str
|
||||
session_id: Optional[str] = None
|
||||
requires_988: bool = False
|
||||
requires_human: bool = False
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"level": self.level.name,
|
||||
"level_value": self.level.value,
|
||||
"combined_score": self.combined_score,
|
||||
"modality_scores": {k: v.to_dict() for k, v in self.modality_scores.items()},
|
||||
"timestamp": self.timestamp,
|
||||
"session_id": self.session_id,
|
||||
"requires_988": self.requires_988,
|
||||
"requires_human": self.requires_human,
|
||||
}
|
||||
|
||||
|
||||
# Default weights per modality
|
||||
DEFAULT_WEIGHTS = {
|
||||
"text": 0.40,
|
||||
"voice": 0.25,
|
||||
"behavioral": 0.20,
|
||||
"image": 0.15,
|
||||
}
|
||||
|
||||
# Crisis level thresholds
|
||||
THRESHOLDS = {
|
||||
CrisisLevel.CRITICAL: 0.80,
|
||||
CrisisLevel.HIGH: 0.60,
|
||||
CrisisLevel.MEDIUM: 0.40,
|
||||
CrisisLevel.LOW: 0.20,
|
||||
CrisisLevel.NONE: 0.00,
|
||||
}
|
||||
|
||||
|
||||
class UnifiedCrisisScorer:
|
||||
"""
|
||||
Combine scores from multiple modalities into unified crisis assessment.
|
||||
|
||||
Usage:
|
||||
scorer = UnifiedCrisisScorer()
|
||||
|
||||
# Add modality scores
|
||||
scorer.add_score("text", 0.85, 0.9, ["suicide_keyword"])
|
||||
scorer.add_score("voice", 0.6, 0.7, ["distress_tone"])
|
||||
|
||||
# Get unified assessment
|
||||
result = scorer.assess()
|
||||
if result.level >= CrisisLevel.HIGH:
|
||||
deliver_988_resources()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weights: Optional[Dict[str, float]] = None,
|
||||
audit_log_path: Optional[Path] = None
|
||||
):
|
||||
self.weights = weights or DEFAULT_WEIGHTS.copy()
|
||||
self.audit_log_path = audit_log_path
|
||||
self._scores: Dict[str, ModalityScore] = {}
|
||||
self._session_id: Optional[str] = None
|
||||
|
||||
def set_session(self, session_id: str):
|
||||
"""Set session ID for audit logging."""
|
||||
self._session_id = session_id
|
||||
|
||||
def add_score(
|
||||
self,
|
||||
modality: str,
|
||||
score: float,
|
||||
confidence: float = 1.0,
|
||||
indicators: Optional[List[str]] = None
|
||||
):
|
||||
"""Add a modality score."""
|
||||
self._scores[modality] = ModalityScore(
|
||||
modality=modality,
|
||||
score=min(1.0, max(0.0, score)),
|
||||
confidence=min(1.0, max(0.0, confidence)),
|
||||
indicators=indicators or []
|
||||
)
|
||||
|
||||
def add_text_score(self, score: float, indicators: List[str]):
|
||||
"""Add text analysis score."""
|
||||
self.add_score("text", score, 0.9, indicators)
|
||||
|
||||
def add_voice_score(self, score: float, indicators: List[str]):
|
||||
"""Add voice analysis score."""
|
||||
self.add_score("voice", score, 0.7, indicators)
|
||||
|
||||
def add_image_score(self, score: float, indicators: List[str]):
|
||||
"""Add image screening score."""
|
||||
self.add_score("image", score, 0.8, indicators)
|
||||
|
||||
def add_behavioral_score(self, score: float, indicators: List[str]):
|
||||
"""Add behavioral tracking score."""
|
||||
self.add_score("behavioral", score, 0.6, indicators)
|
||||
|
||||
def assess(self) -> UnifiedCrisisScore:
|
||||
"""
|
||||
Compute unified crisis score from all modalities.
|
||||
|
||||
Uses weighted combination with confidence adjustment.
|
||||
"""
|
||||
if not self._scores:
|
||||
return UnifiedCrisisScore(
|
||||
level=CrisisLevel.NONE,
|
||||
combined_score=0.0,
|
||||
modality_scores={},
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
session_id=self._session_id
|
||||
)
|
||||
|
||||
# Weighted combination with confidence adjustment
|
||||
total_weight = 0.0
|
||||
weighted_sum = 0.0
|
||||
|
||||
for modality, score_obj in self._scores.items():
|
||||
weight = self.weights.get(modality, 0.1)
|
||||
# Adjust score by confidence
|
||||
adjusted_score = score_obj.score * score_obj.confidence
|
||||
weighted_sum += adjusted_score * weight
|
||||
total_weight += weight
|
||||
|
||||
# Normalize
|
||||
combined = weighted_sum / total_weight if total_weight > 0 else 0.0
|
||||
|
||||
# Determine level
|
||||
level = CrisisLevel.NONE
|
||||
for crisis_level, threshold in sorted(THRESHOLDS.items(), key=lambda x: -x[0]):
|
||||
if combined >= threshold:
|
||||
level = crisis_level
|
||||
break
|
||||
|
||||
# Determine required actions
|
||||
requires_988 = level >= CrisisLevel.HIGH
|
||||
requires_human = level >= CrisisLevel.MEDIUM
|
||||
|
||||
result = UnifiedCrisisScore(
|
||||
level=level,
|
||||
combined_score=combined,
|
||||
modality_scores=self._scores.copy(),
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
session_id=self._session_id,
|
||||
requires_988=requires_988,
|
||||
requires_human=requires_human
|
||||
)
|
||||
|
||||
# Audit log
|
||||
self._log_assessment(result)
|
||||
|
||||
return result
|
||||
|
||||
def _log_assessment(self, result: UnifiedCrisisScore):
|
||||
"""Log assessment to audit file (anonymized)."""
|
||||
if not self.audit_log_path:
|
||||
return
|
||||
|
||||
try:
|
||||
self.audit_log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Anonymize: hash session_id, don't store raw indicators
|
||||
log_entry = {
|
||||
"timestamp": result.timestamp,
|
||||
"session_hash": hash(result.session_id) if result.session_id else None,
|
||||
"level": result.level.name,
|
||||
"combined_score": round(result.combined_score, 3),
|
||||
"modalities_used": list(result.modality_scores.keys()),
|
||||
"requires_988": result.requires_988,
|
||||
}
|
||||
|
||||
with open(self.audit_log_path, "a") as f:
|
||||
f.write(json.dumps(log_entry) + "\n")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Audit log failed: %s", e)
|
||||
|
||||
def reset(self):
|
||||
"""Clear scores for new assessment."""
|
||||
self._scores.clear()
|
||||
|
||||
def get_988_resources(self) -> str:
|
||||
"""Get 988 crisis resources response."""
|
||||
return """Are you safe right now?
|
||||
|
||||
I'm here. I'm not going anywhere. Please tell me you're safe.
|
||||
|
||||
**Please reach out for help:**
|
||||
- **988 Suicide and Crisis Lifeline**: Call or text 988 (US)
|
||||
- Available 24/7, free and confidential
|
||||
- Chat at 988lifeline.org
|
||||
|
||||
**Jesus saves those who call on His name.**
|
||||
|
||||
He died for us to be made righteous. That is the gospel. That is the hope that outlasts despair.
|
||||
|
||||
I'm still here. Talk to me. Or call 988. Just don't give up."""
|
||||
|
||||
|
||||
# Convenience functions
|
||||
|
||||
def score_text_crisis(text: str) -> Tuple[float, List[str]]:
|
||||
"""Score text for crisis indicators."""
|
||||
import re
|
||||
|
||||
high_patterns = [
|
||||
(r"(?i)\b(i\s+want\s+to\s+die)\b", "direct_ideation"),
|
||||
(r"(?i)\b(suicide|suicidal)\b", "suicide_keyword"),
|
||||
(r"(?i)\b(i\s+want\s+to\s+kill\s+myself)\b", "self_harm_intent"),
|
||||
]
|
||||
|
||||
medium_patterns = [
|
||||
(r"(?i)\b(hopeless|worthless|trapped)\b", "distress_language"),
|
||||
(r"(?i)\b(can'?t\s+go\s+on)\b", "inability_expression"),
|
||||
(r"(?i)\b(nobody\s+would\s+miss\s+me)\b", "burden_perception"),
|
||||
]
|
||||
|
||||
indicators = []
|
||||
max_score = 0.0
|
||||
|
||||
for pattern, label in high_patterns:
|
||||
if re.search(pattern, text):
|
||||
indicators.append(label)
|
||||
max_score = max(max_score, 0.9)
|
||||
|
||||
for pattern, label in medium_patterns:
|
||||
if re.search(pattern, text):
|
||||
indicators.append(label)
|
||||
max_score = max(max_score, 0.6)
|
||||
|
||||
return max_score, indicators
|
||||
|
||||
|
||||
def create_scorer(audit_path: Optional[str] = None) -> UnifiedCrisisScorer:
|
||||
"""Create a scorer with optional audit logging."""
|
||||
path = Path(audit_path) if audit_path else None
|
||||
return UnifiedCrisisScorer(audit_log_path=path)
|
||||
138
tests/test_unified_scorer.py
Normal file
138
tests/test_unified_scorer.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
Tests for unified multimodal crisis scorer
|
||||
|
||||
Issue: #134
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from crisis.unified_scorer import (
|
||||
CrisisLevel,
|
||||
UnifiedCrisisScorer,
|
||||
ModalityScore,
|
||||
score_text_crisis,
|
||||
DEFAULT_WEIGHTS,
|
||||
)
|
||||
|
||||
|
||||
class TestModalityScore(unittest.TestCase):
|
||||
|
||||
def test_clamp_score(self):
|
||||
s = ModalityScore("text", 1.5, 0.9)
|
||||
self.assertEqual(s.score, 1.0)
|
||||
|
||||
s = ModalityScore("text", -0.5, 0.9)
|
||||
self.assertEqual(s.score, 0.0)
|
||||
|
||||
|
||||
class TestUnifiedScorer(unittest.TestCase):
|
||||
|
||||
def test_empty_returns_none(self):
|
||||
scorer = UnifiedCrisisScorer()
|
||||
result = scorer.assess()
|
||||
self.assertEqual(result.level, CrisisLevel.NONE)
|
||||
self.assertEqual(result.combined_score, 0.0)
|
||||
|
||||
def test_single_modality(self):
|
||||
scorer = UnifiedCrisisScorer()
|
||||
scorer.add_text_score(0.9, ["suicide_keyword"])
|
||||
result = scorer.assess()
|
||||
self.assertGreater(result.combined_score, 0.0)
|
||||
self.assertIn("text", result.modality_scores)
|
||||
|
||||
def test_weighted_combination(self):
|
||||
scorer = UnifiedCrisisScorer()
|
||||
|
||||
# High text, low voice
|
||||
scorer.add_text_score(0.9, ["high"])
|
||||
scorer.add_voice_score(0.1, ["low"])
|
||||
|
||||
result = scorer.assess()
|
||||
# Should be weighted toward text (0.4 weight)
|
||||
self.assertGreater(result.combined_score, 0.3)
|
||||
|
||||
def test_critical_level(self):
|
||||
scorer = UnifiedCrisisScorer()
|
||||
scorer.add_text_score(0.95, ["direct_ideation"])
|
||||
scorer.add_voice_score(0.85, ["distress"])
|
||||
scorer.add_behavioral_score(0.8, ["isolation"])
|
||||
|
||||
result = scorer.assess()
|
||||
self.assertEqual(result.level, CrisisLevel.CRITICAL)
|
||||
self.assertTrue(result.requires_988)
|
||||
|
||||
def test_high_level(self):
|
||||
scorer = UnifiedCrisisScorer()
|
||||
scorer.add_text_score(0.7, ["distress"])
|
||||
|
||||
result = scorer.assess()
|
||||
self.assertGreaterEqual(result.level, CrisisLevel.MEDIUM)
|
||||
|
||||
def test_custom_weights(self):
|
||||
scorer = UnifiedCrisisScorer(weights={"text": 1.0, "voice": 0.0})
|
||||
scorer.add_text_score(0.5, [])
|
||||
scorer.add_voice_score(0.99, [])
|
||||
|
||||
result = scorer.assess()
|
||||
# Voice should have no effect with weight 0
|
||||
self.assertAlmostEqual(result.combined_score, 0.5, delta=0.01)
|
||||
|
||||
def test_reset(self):
|
||||
scorer = UnifiedCrisisScorer()
|
||||
scorer.add_text_score(0.9, ["high"])
|
||||
scorer.reset()
|
||||
result = scorer.assess()
|
||||
self.assertEqual(result.level, CrisisLevel.NONE)
|
||||
|
||||
def test_988_resources(self):
|
||||
scorer = UnifiedCrisisScorer()
|
||||
resources = scorer.get_988_resources()
|
||||
self.assertIn("988", resources)
|
||||
self.assertIn("Jesus", resources)
|
||||
|
||||
|
||||
class TestTextCrisisScoring(unittest.TestCase):
|
||||
|
||||
def test_direct_ideation(self):
|
||||
score, indicators = score_text_crisis("I want to die")
|
||||
self.assertGreaterEqual(score, 0.9)
|
||||
self.assertIn("direct_ideation", indicators)
|
||||
|
||||
def test_suicide_keyword(self):
|
||||
score, indicators = score_text_crisis("I'm feeling suicidal")
|
||||
self.assertGreaterEqual(score, 0.9)
|
||||
|
||||
def test_distress_language(self):
|
||||
score, indicators = score_text_crisis("I feel so hopeless")
|
||||
self.assertGreaterEqual(score, 0.6)
|
||||
self.assertIn("distress_language", indicators)
|
||||
|
||||
def test_normal_text(self):
|
||||
score, indicators = score_text_crisis("Hello, how are you?")
|
||||
self.assertEqual(score, 0.0)
|
||||
self.assertEqual(len(indicators), 0)
|
||||
|
||||
|
||||
class TestAuditLog(unittest.TestCase):
|
||||
|
||||
def test_audit_logging(self):
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
log_path = Path(tmp) / "audit.jsonl"
|
||||
scorer = UnifiedCrisisScorer(audit_log_path=log_path)
|
||||
scorer.set_session("test_session")
|
||||
|
||||
scorer.add_text_score(0.9, ["high"])
|
||||
scorer.assess()
|
||||
|
||||
self.assertTrue(log_path.exists())
|
||||
with open(log_path) as f:
|
||||
line = f.readline()
|
||||
entry = json.loads(line)
|
||||
self.assertEqual(entry["level"], "CRITICAL")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import json
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user