Compare commits

..

2 Commits

Author SHA1 Message Date
Timmy
7cef18fdcb feat: add crisis ab testing for #101
All checks were successful
Sanity Checks / sanity-test (pull_request) Successful in 7s
Smoke Test / smoke (pull_request) Successful in 14s
2026-04-20 21:43:37 -04:00
Timmy
706024e11e test: define crisis ab testing for #101 2026-04-20 21:41:31 -04:00
5 changed files with 254 additions and 306 deletions

View File

@@ -8,6 +8,7 @@ from .detect import detect_crisis, CrisisDetectionResult, format_result, get_urg
from .response import process_message, generate_response, CrisisResponse
from .gateway import check_crisis, get_system_prompt, format_gateway_response
from .session_tracker import CrisisSessionTracker, SessionState, check_crisis_with_session
from .ab_testing import ABTestCrisisDetector, VariantRecord
__all__ = [
"detect_crisis",
@@ -23,4 +24,6 @@ __all__ = [
"CrisisSessionTracker",
"SessionState",
"check_crisis_with_session",
"ABTestCrisisDetector",
"VariantRecord",
]

112
crisis/ab_testing.py Normal file
View File

@@ -0,0 +1,112 @@
"""A/B test framework for crisis detection in the-door."""
from __future__ import annotations
import os
import random
import time
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple
from .detect import CrisisDetectionResult
def _get_variant_override() -> Optional[str]:
"""Return env override for deterministic testing/debugging."""
value = os.environ.get("CRISIS_AB_VARIANT", "").strip().upper()
if value in {"A", "B"}:
return value
return None
@dataclass
class VariantRecord:
"""Single crisis detection event record with no user text or PII."""
variant: str
level: str
latency_ms: float
indicator_count: int
false_positive: Optional[bool] = None
class ABTestCrisisDetector:
"""Route crisis detection between two variants and collect comparison stats."""
def __init__(
self,
variant_a: Callable[[str], CrisisDetectionResult],
variant_b: Callable[[str], CrisisDetectionResult],
split: float = 0.5,
):
self.variant_a = variant_a
self.variant_b = variant_b
self.split = max(0.0, min(1.0, float(split)))
self.records: List[VariantRecord] = []
def _select_variant(self) -> str:
override = _get_variant_override()
if override:
return override
return "A" if random.random() < self.split else "B"
def detect(self, text: str) -> Tuple[CrisisDetectionResult, str, int]:
variant = self._select_variant()
detector = self.variant_a if variant == "A" else self.variant_b
start = time.perf_counter()
result = detector(text)
latency_ms = (time.perf_counter() - start) * 1000.0
record = VariantRecord(
variant=variant,
level=result.level,
latency_ms=latency_ms,
indicator_count=len(result.indicators),
)
self.records.append(record)
return result, variant, len(self.records) - 1
def record_outcome(self, record_id: int, *, false_positive: bool) -> None:
if record_id < 0 or record_id >= len(self.records):
raise IndexError(f"Unknown record id: {record_id}")
self.records[record_id].false_positive = bool(false_positive)
def get_stats(self) -> Dict[str, dict]:
stats: Dict[str, dict] = {}
for variant in ("A", "B"):
records = [record for record in self.records if record.variant == variant]
if not records:
stats[variant] = {
"count": 0,
"reviewed_count": 0,
"false_positive_rate": None,
}
continue
levels: Dict[str, int] = {}
for record in records:
levels[record.level] = levels.get(record.level, 0) + 1
reviewed = [record for record in records if record.false_positive is not None]
false_positive_rate = None
if reviewed:
false_positive_rate = round(
sum(1 for record in reviewed if record.false_positive) / len(reviewed),
4,
)
stats[variant] = {
"count": len(records),
"avg_latency_ms": round(sum(record.latency_ms for record in records) / len(records), 4),
"max_latency_ms": round(max(record.latency_ms for record in records), 4),
"min_latency_ms": round(min(record.latency_ms for record in records), 4),
"avg_indicator_count": round(sum(record.indicator_count for record in records) / len(records), 4),
"levels": levels,
"reviewed_count": len(reviewed),
"false_positive_rate": false_positive_rate,
}
return stats
def reset(self) -> None:
self.records.clear()

View File

@@ -1,195 +1 @@
"""Crisis synthesizer — learn from anonymized crisis interactions.
This is deliberately simple and privacy-preserving. It does not train a model or
modify detection rules automatically. It only logs metadata, summarizes patterns,
and suggests human-reviewed keyword weight adjustments.
"""
from __future__ import annotations
import argparse
import json
import time
from collections import Counter, defaultdict
from pathlib import Path
from typing import Iterable
DEFAULT_LOG_PATH = Path.home() / ".the-door" / "crisis-interactions.jsonl"
LEVELS = ("NONE", "LOW", "MEDIUM", "HIGH", "CRITICAL")
def build_interaction_event(
level: str,
indicators: list[str],
response_given: str,
continued_conversation: bool,
false_positive: bool,
*,
now: float | None = None,
) -> dict:
return {
"timestamp": float(time.time() if now is None else now),
"level": level,
"indicators": list(indicators),
"indicator_count": len(indicators),
"response_given": response_given,
"continued_conversation": bool(continued_conversation),
"false_positive": bool(false_positive),
}
def append_interaction_event(
log_path: str | Path,
*,
level: str,
indicators: list[str],
response_given: str,
continued_conversation: bool,
false_positive: bool,
now: float | None = None,
) -> dict:
event = build_interaction_event(
level,
indicators,
response_given,
continued_conversation,
false_positive,
now=now,
)
path = Path(log_path)
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("a", encoding="utf-8") as handle:
handle.write(json.dumps(event) + "\n")
return event
def load_interaction_events(log_path: str | Path) -> list[dict]:
path = Path(log_path)
if not path.exists():
return []
events = []
for line in path.read_text(encoding="utf-8").splitlines():
if not line.strip():
continue
events.append(json.loads(line))
return events
def summarize_keywords(events: Iterable[dict]) -> list[dict]:
counts: Counter[str] = Counter()
for event in events:
counts.update(event.get("indicators", []))
return [{"keyword": keyword, "count": count} for keyword, count in counts.most_common(10)]
def suggest_keyword_adjustments(events: Iterable[dict], *, min_observations: int = 5) -> list[dict]:
stats: dict[str, dict[str, int]] = defaultdict(lambda: {
"observations": 0,
"true_positive_count": 0,
"false_positive_count": 0,
"continued_conversation_count": 0,
})
for event in events:
for keyword in event.get("indicators", []):
bucket = stats[keyword]
bucket["observations"] += 1
if event.get("false_positive"):
bucket["false_positive_count"] += 1
else:
bucket["true_positive_count"] += 1
if event.get("continued_conversation"):
bucket["continued_conversation_count"] += 1
suggestions = []
for keyword, bucket in sorted(stats.items()):
if bucket["observations"] < min_observations:
continue
fp = bucket["false_positive_count"]
tp = bucket["true_positive_count"]
if fp >= min_observations and tp == 0:
adjustment = "lower_weight"
rationale = "Observed only false positives across the sample window."
elif tp >= min_observations and fp == 0:
adjustment = "raise_weight"
rationale = "Observed repeated genuine crises with no false positives."
else:
adjustment = "observe"
rationale = "Mixed evidence; keep monitoring before changing weights."
suggestions.append(
{
"keyword": keyword,
**bucket,
"suggested_adjustment": adjustment,
"rationale": rationale,
}
)
return suggestions
def build_weekly_report(
events: Iterable[dict],
*,
now: float | None = None,
window_days: int = 7,
min_observations: int = 3,
) -> dict:
current_time = float(time.time() if now is None else now)
cutoff = current_time - (window_days * 86400)
filtered = [event for event in events if float(event.get("timestamp", 0)) >= cutoff]
detections_per_level = {level: 0 for level in LEVELS}
detected_events = []
continued_after_intervention = 0
for event in filtered:
level = event.get("level", "NONE")
detections_per_level[level] = detections_per_level.get(level, 0) + 1
if level != "NONE":
detected_events.append(event)
if event.get("continued_conversation"):
continued_after_intervention += 1
false_positive_count = sum(1 for event in detected_events if event.get("false_positive"))
false_positive_estimate = false_positive_count / len(detected_events) if detected_events else 0.0
return {
"window_days": window_days,
"total_events": len(filtered),
"detections_per_level": detections_per_level,
"most_common_keywords": summarize_keywords(filtered),
"false_positive_estimate": false_positive_estimate,
"continued_after_intervention": continued_after_intervention,
"keyword_weight_suggestions": suggest_keyword_adjustments(filtered, min_observations=min_observations),
}
def render_weekly_report(summary: dict) -> str:
return json.dumps(summary, indent=2)
def write_weekly_report(output_path: str | Path, summary: dict) -> Path:
path = Path(output_path)
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(render_weekly_report(summary) + "\n", encoding="utf-8")
return path
def main(argv: list[str] | None = None) -> int:
parser = argparse.ArgumentParser(description="Summarize anonymized crisis interactions")
parser.add_argument("--log-path", default=str(DEFAULT_LOG_PATH), help="JSONL crisis interaction log")
parser.add_argument("--days", type=int, default=7, help="Lookback window in days")
parser.add_argument("--min-observations", type=int, default=3, help="Minimum observations before suggesting keyword adjustments")
parser.add_argument("--output", help="Optional file to write the weekly report JSON")
args = parser.parse_args(argv)
events = load_interaction_events(args.log_path)
summary = build_weekly_report(events, window_days=args.days, min_observations=args.min_observations)
rendered = render_weekly_report(summary)
print(rendered)
if args.output:
write_weekly_report(args.output, summary)
return 0
if __name__ == "__main__":
raise SystemExit(main())
...

138
tests/test_ab_testing.py Normal file
View File

@@ -0,0 +1,138 @@
"""Tests for crisis.ab_testing — A/B test framework for crisis detection (#101)."""
import os
from unittest.mock import patch
import pytest
from crisis.ab_testing import ABTestCrisisDetector
from crisis.detect import CrisisDetectionResult, detect_crisis
@pytest.fixture(autouse=True)
def clear_variant_override():
old = os.environ.pop("CRISIS_AB_VARIANT", None)
try:
yield
finally:
if old is not None:
os.environ["CRISIS_AB_VARIANT"] = old
else:
os.environ.pop("CRISIS_AB_VARIANT", None)
def _make_variant(level: str, indicators=None):
indicators = indicators or [f"mock_{level.lower()}"]
def fn(text: str) -> CrisisDetectionResult:
return CrisisDetectionResult(level=level, indicators=list(indicators))
return fn
def test_detect_returns_result_variant_and_logged_record():
detector = ABTestCrisisDetector(
variant_a=_make_variant("LOW"),
variant_b=_make_variant("HIGH"),
)
with patch.object(detector, "_select_variant", return_value="A"):
result, variant, record_id = detector.detect("test message")
assert isinstance(result, CrisisDetectionResult)
assert variant == "A"
assert record_id == 0
assert len(detector.records) == 1
assert detector.records[0].variant == "A"
assert detector.records[0].level == "LOW"
def test_env_override_forces_variant_b():
os.environ["CRISIS_AB_VARIANT"] = "b"
detector = ABTestCrisisDetector(
variant_a=_make_variant("LOW"),
variant_b=_make_variant("HIGH"),
)
result, variant, _ = detector.detect("test")
assert variant == "B"
assert result.level == "HIGH"
def test_get_stats_reports_latency_counts_and_level_breakdown():
detector = ABTestCrisisDetector(
variant_a=_make_variant("LOW"),
variant_b=_make_variant("CRITICAL"),
)
with patch.object(detector, "_select_variant", side_effect=["A", "A", "B"]):
detector.detect("first")
detector.detect("second")
detector.detect("third")
stats = detector.get_stats()
assert stats["A"]["count"] == 2
assert stats["B"]["count"] == 1
assert stats["A"]["levels"]["LOW"] == 2
assert stats["B"]["levels"]["CRITICAL"] == 1
assert "avg_latency_ms" in stats["A"]
assert "avg_indicator_count" in stats["B"]
def test_false_positive_rate_is_computed_from_reviewed_outcomes():
detector = ABTestCrisisDetector(
variant_a=_make_variant("LOW"),
variant_b=_make_variant("HIGH"),
)
with patch.object(detector, "_select_variant", side_effect=["A", "A", "B"]):
_, _, a0 = detector.detect("first")
_, _, a1 = detector.detect("second")
_, _, b0 = detector.detect("third")
detector.record_outcome(a0, false_positive=True)
detector.record_outcome(a1, false_positive=False)
detector.record_outcome(b0, false_positive=False)
stats = detector.get_stats()
assert stats["A"]["reviewed_count"] == 2
assert stats["A"]["false_positive_rate"] == 0.5
assert stats["B"]["false_positive_rate"] == 0.0
def test_record_outcome_rejects_unknown_record():
detector = ABTestCrisisDetector(
variant_a=_make_variant("LOW"),
variant_b=_make_variant("HIGH"),
)
with pytest.raises(IndexError):
detector.record_outcome(99, false_positive=True)
def test_reset_clears_records_and_stats():
detector = ABTestCrisisDetector(
variant_a=_make_variant("LOW"),
variant_b=_make_variant("HIGH"),
)
detector.detect("test")
detector.reset()
assert detector.records == []
stats = detector.get_stats()
assert stats["A"]["count"] == 0
assert stats["B"]["count"] == 0
def test_with_real_detector_integration():
detector = ABTestCrisisDetector(
variant_a=detect_crisis,
variant_b=detect_crisis,
)
result, variant, record_id = detector.detect("I want to kill myself")
assert result.level == "CRITICAL"
assert variant in ("A", "B")
assert record_id == 0

View File

@@ -1,111 +0,0 @@
"""Tests for evolution/crisis_synthesizer.py (issue #36)."""
from __future__ import annotations
import importlib.util
import json
import pathlib
import sys
import tempfile
import unittest
ROOT = pathlib.Path(__file__).resolve().parents[1]
SCRIPT = ROOT / 'evolution' / 'crisis_synthesizer.py'
spec = importlib.util.spec_from_file_location('crisis_synthesizer', str(SCRIPT))
mod = importlib.util.module_from_spec(spec)
sys.modules['crisis_synthesizer'] = mod
spec.loader.exec_module(mod)
class TestCrisisSynthesizerEvent(unittest.TestCase):
def test_build_interaction_event_is_privacy_preserving(self):
event = mod.build_interaction_event(
level='CRITICAL',
indicators=['want_to_die', 'no_way_out'],
response_given='guardian',
continued_conversation=True,
false_positive=False,
now=1700000000,
)
self.assertEqual(event['timestamp'], 1700000000)
self.assertEqual(event['level'], 'CRITICAL')
self.assertEqual(event['response_given'], 'guardian')
self.assertTrue(event['continued_conversation'])
self.assertFalse(event['false_positive'])
self.assertEqual(event['indicators'], ['want_to_die', 'no_way_out'])
for forbidden in ['text', 'message', 'content', 'ip', 'session_id', 'user_id']:
self.assertNotIn(forbidden, event)
class TestCrisisSynthesizerStorage(unittest.TestCase):
def test_append_and_load_events_round_trip(self):
with tempfile.TemporaryDirectory() as tmp:
log_path = pathlib.Path(tmp) / 'crisis-events.jsonl'
mod.append_interaction_event(
log_path,
level='HIGH',
indicators=['hopeless'],
response_given='companion',
continued_conversation=False,
false_positive=True,
now=1700000100,
)
events = mod.load_interaction_events(log_path)
self.assertEqual(len(events), 1)
self.assertEqual(events[0]['level'], 'HIGH')
self.assertEqual(events[0]['indicators'], ['hopeless'])
class TestCrisisSynthesizerSummary(unittest.TestCase):
def test_weekly_report_contains_required_metrics(self):
events = [
mod.build_interaction_event('CRITICAL', ['want_to_die'], 'guardian', True, False, now=1700000000),
mod.build_interaction_event('HIGH', ['hopeless'], 'companion', False, True, now=1700000100),
mod.build_interaction_event('LOW', ['rough_day'], 'friend', False, False, now=1700000200),
mod.build_interaction_event('CRITICAL', ['want_to_die'], 'guardian', False, False, now=1700000300),
mod.build_interaction_event('NONE', [], 'friend', False, False, now=1700000400),
]
summary = mod.build_weekly_report(events, now=1700000500, window_days=7)
self.assertEqual(summary['detections_per_level']['CRITICAL'], 2)
self.assertEqual(summary['detections_per_level']['HIGH'], 1)
self.assertEqual(summary['detections_per_level']['LOW'], 1)
self.assertEqual(summary['detections_per_level']['NONE'], 1)
self.assertEqual(summary['continued_after_intervention'], 1)
self.assertAlmostEqual(summary['false_positive_estimate'], 0.25)
self.assertEqual(summary['most_common_keywords'][0]['keyword'], 'want_to_die')
self.assertEqual(summary['most_common_keywords'][0]['count'], 2)
class TestCrisisSynthesizerSuggestions(unittest.TestCase):
def test_suggests_weight_adjustments_from_interactions(self):
events = []
for ts in range(3):
events.append(mod.build_interaction_event('CRITICAL', ['want_to_die'], 'guardian', True, False, now=1700000000 + ts))
for ts in range(3):
events.append(mod.build_interaction_event('LOW', ['rough_day'], 'friend', False, True, now=1700000100 + ts))
suggestions = mod.suggest_keyword_adjustments(events, min_observations=3)
by_keyword = {s['keyword']: s for s in suggestions}
self.assertEqual(by_keyword['want_to_die']['suggested_adjustment'], 'raise_weight')
self.assertEqual(by_keyword['rough_day']['suggested_adjustment'], 'lower_weight')
class TestCrisisSynthesizerRendering(unittest.TestCase):
def test_render_weekly_report_outputs_json(self):
summary = {
'detections_per_level': {'NONE': 0, 'LOW': 1, 'MEDIUM': 0, 'HIGH': 0, 'CRITICAL': 0},
'most_common_keywords': [{'keyword': 'rough_day', 'count': 1}],
'false_positive_estimate': 0.0,
'continued_after_intervention': 0,
'keyword_weight_suggestions': [],
'window_days': 7,
'total_events': 1,
}
rendered = mod.render_weekly_report(summary)
parsed = json.loads(rendered)
self.assertEqual(parsed['window_days'], 7)
self.assertEqual(parsed['most_common_keywords'][0]['keyword'], 'rough_day')
if __name__ == '__main__':
unittest.main()