Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4d084654d8 |
@@ -8,7 +8,13 @@ from .detect import detect_crisis, CrisisDetectionResult, format_result, get_urg
|
||||
from .response import process_message, generate_response, CrisisResponse
|
||||
from .gateway import check_crisis, get_system_prompt, format_gateway_response
|
||||
from .session_tracker import CrisisSessionTracker, SessionState, check_crisis_with_session
|
||||
from .ab_testing import ABTestCrisisDetector, VariantRecord
|
||||
from .metrics import (
|
||||
build_metrics_event,
|
||||
append_metrics_event,
|
||||
load_metrics_events,
|
||||
build_weekly_summary,
|
||||
render_weekly_summary,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"detect_crisis",
|
||||
@@ -24,6 +30,9 @@ __all__ = [
|
||||
"CrisisSessionTracker",
|
||||
"SessionState",
|
||||
"check_crisis_with_session",
|
||||
"ABTestCrisisDetector",
|
||||
"VariantRecord",
|
||||
"build_metrics_event",
|
||||
"append_metrics_event",
|
||||
"load_metrics_events",
|
||||
"build_weekly_summary",
|
||||
"render_weekly_summary",
|
||||
]
|
||||
|
||||
@@ -1,112 +0,0 @@
|
||||
"""A/B test framework for crisis detection in the-door."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from .detect import CrisisDetectionResult
|
||||
|
||||
|
||||
def _get_variant_override() -> Optional[str]:
|
||||
"""Return env override for deterministic testing/debugging."""
|
||||
value = os.environ.get("CRISIS_AB_VARIANT", "").strip().upper()
|
||||
if value in {"A", "B"}:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class VariantRecord:
|
||||
"""Single crisis detection event record with no user text or PII."""
|
||||
|
||||
variant: str
|
||||
level: str
|
||||
latency_ms: float
|
||||
indicator_count: int
|
||||
false_positive: Optional[bool] = None
|
||||
|
||||
|
||||
class ABTestCrisisDetector:
|
||||
"""Route crisis detection between two variants and collect comparison stats."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
variant_a: Callable[[str], CrisisDetectionResult],
|
||||
variant_b: Callable[[str], CrisisDetectionResult],
|
||||
split: float = 0.5,
|
||||
):
|
||||
self.variant_a = variant_a
|
||||
self.variant_b = variant_b
|
||||
self.split = max(0.0, min(1.0, float(split)))
|
||||
self.records: List[VariantRecord] = []
|
||||
|
||||
def _select_variant(self) -> str:
|
||||
override = _get_variant_override()
|
||||
if override:
|
||||
return override
|
||||
return "A" if random.random() < self.split else "B"
|
||||
|
||||
def detect(self, text: str) -> Tuple[CrisisDetectionResult, str, int]:
|
||||
variant = self._select_variant()
|
||||
detector = self.variant_a if variant == "A" else self.variant_b
|
||||
|
||||
start = time.perf_counter()
|
||||
result = detector(text)
|
||||
latency_ms = (time.perf_counter() - start) * 1000.0
|
||||
|
||||
record = VariantRecord(
|
||||
variant=variant,
|
||||
level=result.level,
|
||||
latency_ms=latency_ms,
|
||||
indicator_count=len(result.indicators),
|
||||
)
|
||||
self.records.append(record)
|
||||
return result, variant, len(self.records) - 1
|
||||
|
||||
def record_outcome(self, record_id: int, *, false_positive: bool) -> None:
|
||||
if record_id < 0 or record_id >= len(self.records):
|
||||
raise IndexError(f"Unknown record id: {record_id}")
|
||||
self.records[record_id].false_positive = bool(false_positive)
|
||||
|
||||
def get_stats(self) -> Dict[str, dict]:
|
||||
stats: Dict[str, dict] = {}
|
||||
for variant in ("A", "B"):
|
||||
records = [record for record in self.records if record.variant == variant]
|
||||
if not records:
|
||||
stats[variant] = {
|
||||
"count": 0,
|
||||
"reviewed_count": 0,
|
||||
"false_positive_rate": None,
|
||||
}
|
||||
continue
|
||||
|
||||
levels: Dict[str, int] = {}
|
||||
for record in records:
|
||||
levels[record.level] = levels.get(record.level, 0) + 1
|
||||
|
||||
reviewed = [record for record in records if record.false_positive is not None]
|
||||
false_positive_rate = None
|
||||
if reviewed:
|
||||
false_positive_rate = round(
|
||||
sum(1 for record in reviewed if record.false_positive) / len(reviewed),
|
||||
4,
|
||||
)
|
||||
|
||||
stats[variant] = {
|
||||
"count": len(records),
|
||||
"avg_latency_ms": round(sum(record.latency_ms for record in records) / len(records), 4),
|
||||
"max_latency_ms": round(max(record.latency_ms for record in records), 4),
|
||||
"min_latency_ms": round(min(record.latency_ms for record in records), 4),
|
||||
"avg_indicator_count": round(sum(record.indicator_count for record in records) / len(records), 4),
|
||||
"levels": levels,
|
||||
"reviewed_count": len(reviewed),
|
||||
"false_positive_rate": false_positive_rate,
|
||||
}
|
||||
return stats
|
||||
|
||||
def reset(self) -> None:
|
||||
self.records.clear()
|
||||
@@ -23,9 +23,17 @@ from .response import (
|
||||
CrisisResponse,
|
||||
)
|
||||
from .session_tracker import CrisisSessionTracker
|
||||
from .metrics import build_metrics_event, append_metrics_event
|
||||
|
||||
|
||||
def check_crisis(text: str) -> dict:
|
||||
def check_crisis(
|
||||
text: str,
|
||||
metrics_log_path: Optional[str] = None,
|
||||
*,
|
||||
continued_conversation: bool = False,
|
||||
false_positive: bool = False,
|
||||
now: Optional[float] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Full crisis check returning structured data.
|
||||
|
||||
@@ -35,7 +43,7 @@ def check_crisis(text: str) -> dict:
|
||||
detection = detect_crisis(text)
|
||||
response = generate_response(detection)
|
||||
|
||||
return {
|
||||
result = {
|
||||
"level": detection.level,
|
||||
"score": detection.score,
|
||||
"indicators": detection.indicators,
|
||||
@@ -49,6 +57,23 @@ def check_crisis(text: str) -> dict:
|
||||
"escalate": response.escalate,
|
||||
}
|
||||
|
||||
metrics_event = build_metrics_event(
|
||||
detection,
|
||||
continued_conversation=continued_conversation,
|
||||
false_positive=false_positive,
|
||||
now=now,
|
||||
)
|
||||
if metrics_log_path:
|
||||
metrics_event = append_metrics_event(
|
||||
metrics_log_path,
|
||||
detection,
|
||||
continued_conversation=continued_conversation,
|
||||
false_positive=false_positive,
|
||||
now=now,
|
||||
)
|
||||
result["metrics_event"] = metrics_event
|
||||
return result
|
||||
|
||||
|
||||
def get_system_prompt(base_prompt: str, text: str = "") -> str:
|
||||
"""
|
||||
|
||||
166
crisis/metrics.py
Normal file
166
crisis/metrics.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Privacy-preserving crisis analytics metrics for the-door.
|
||||
|
||||
Stores only timestamps, crisis levels, indicator categories, and operator
|
||||
feedback flags. No raw message text or PII is persisted.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
from .detect import CrisisDetectionResult, detect_crisis
|
||||
|
||||
LEVELS = ("NONE", "LOW", "MEDIUM", "HIGH", "CRITICAL")
|
||||
|
||||
|
||||
def normalize_indicator(indicator: str) -> str:
|
||||
"""Return a stable privacy-safe keyword/category identifier."""
|
||||
return indicator
|
||||
|
||||
|
||||
def build_metrics_event(
|
||||
detection: CrisisDetectionResult,
|
||||
*,
|
||||
continued_conversation: bool = False,
|
||||
false_positive: bool = False,
|
||||
now: float | None = None,
|
||||
) -> dict:
|
||||
timestamp = float(time.time() if now is None else now)
|
||||
indicators = [normalize_indicator(indicator) for indicator in detection.indicators]
|
||||
return {
|
||||
"timestamp": timestamp,
|
||||
"level": detection.level,
|
||||
"indicator_count": len(indicators),
|
||||
"indicators": indicators,
|
||||
"continued_conversation": bool(continued_conversation),
|
||||
"false_positive": bool(false_positive),
|
||||
}
|
||||
|
||||
|
||||
def append_metrics_event(
|
||||
log_path: str | Path,
|
||||
detection: CrisisDetectionResult,
|
||||
*,
|
||||
continued_conversation: bool = False,
|
||||
false_positive: bool = False,
|
||||
now: float | None = None,
|
||||
) -> dict:
|
||||
event = build_metrics_event(
|
||||
detection,
|
||||
continued_conversation=continued_conversation,
|
||||
false_positive=false_positive,
|
||||
now=now,
|
||||
)
|
||||
path = Path(log_path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with path.open("a", encoding="utf-8") as handle:
|
||||
handle.write(json.dumps(event) + "\n")
|
||||
return event
|
||||
|
||||
|
||||
def load_metrics_events(log_path: str | Path) -> list[dict]:
|
||||
path = Path(log_path)
|
||||
if not path.exists():
|
||||
return []
|
||||
events = []
|
||||
for line in path.read_text(encoding="utf-8").splitlines():
|
||||
if not line.strip():
|
||||
continue
|
||||
events.append(json.loads(line))
|
||||
return events
|
||||
|
||||
|
||||
def build_weekly_summary(
|
||||
events: Iterable[dict],
|
||||
*,
|
||||
now: float | None = None,
|
||||
window_days: int = 7,
|
||||
) -> dict:
|
||||
current_time = float(time.time() if now is None else now)
|
||||
cutoff = current_time - (window_days * 86400)
|
||||
filtered = [event for event in events if float(event.get("timestamp", 0)) >= cutoff]
|
||||
|
||||
detections_per_level = {level: 0 for level in LEVELS}
|
||||
keyword_counts: Counter[str] = Counter()
|
||||
detections = []
|
||||
continued_after_intervention = 0
|
||||
|
||||
for event in filtered:
|
||||
level = event.get("level", "NONE")
|
||||
detections_per_level[level] = detections_per_level.get(level, 0) + 1
|
||||
keyword_counts.update(event.get("indicators", []))
|
||||
if level != "NONE":
|
||||
detections.append(event)
|
||||
if event.get("continued_conversation"):
|
||||
continued_after_intervention += 1
|
||||
|
||||
false_positive_count = sum(1 for event in detections if event.get("false_positive"))
|
||||
false_positive_estimate = (
|
||||
false_positive_count / len(detections) if detections else 0.0
|
||||
)
|
||||
|
||||
return {
|
||||
"window_days": window_days,
|
||||
"total_events": len(filtered),
|
||||
"detections_per_level": detections_per_level,
|
||||
"most_common_keywords": [
|
||||
{"keyword": keyword, "count": count}
|
||||
for keyword, count in keyword_counts.most_common(10)
|
||||
],
|
||||
"false_positive_estimate": false_positive_estimate,
|
||||
"continued_after_intervention": continued_after_intervention,
|
||||
}
|
||||
|
||||
|
||||
def render_weekly_summary(summary: dict) -> str:
|
||||
return json.dumps(summary, indent=2)
|
||||
|
||||
|
||||
def write_weekly_summary(path: str | Path, summary: dict) -> Path:
|
||||
output_path = Path(path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_path.write_text(render_weekly_summary(summary) + "\n", encoding="utf-8")
|
||||
return output_path
|
||||
|
||||
|
||||
def record_text_event(
|
||||
text: str,
|
||||
log_path: str | Path,
|
||||
*,
|
||||
continued_conversation: bool = False,
|
||||
false_positive: bool = False,
|
||||
now: float | None = None,
|
||||
) -> dict:
|
||||
detection = detect_crisis(text)
|
||||
return append_metrics_event(
|
||||
log_path,
|
||||
detection,
|
||||
continued_conversation=continued_conversation,
|
||||
false_positive=false_positive,
|
||||
now=now,
|
||||
)
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
|
||||
parser = argparse.ArgumentParser(description="Privacy-preserving crisis metrics summary")
|
||||
parser.add_argument("--log-path", required=True, help="JSONL event log path")
|
||||
parser.add_argument("--days", type=int, default=7, help="Summary window in days")
|
||||
parser.add_argument("--output", help="Optional file to write summary JSON")
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
events = load_metrics_events(args.log_path)
|
||||
summary = build_weekly_summary(events, window_days=args.days)
|
||||
rendered = render_weekly_summary(summary)
|
||||
print(rendered)
|
||||
if args.output:
|
||||
write_weekly_summary(args.output, summary)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -1,138 +0,0 @@
|
||||
"""Tests for crisis.ab_testing — A/B test framework for crisis detection (#101)."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crisis.ab_testing import ABTestCrisisDetector
|
||||
from crisis.detect import CrisisDetectionResult, detect_crisis
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_variant_override():
|
||||
old = os.environ.pop("CRISIS_AB_VARIANT", None)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if old is not None:
|
||||
os.environ["CRISIS_AB_VARIANT"] = old
|
||||
else:
|
||||
os.environ.pop("CRISIS_AB_VARIANT", None)
|
||||
|
||||
|
||||
def _make_variant(level: str, indicators=None):
|
||||
indicators = indicators or [f"mock_{level.lower()}"]
|
||||
|
||||
def fn(text: str) -> CrisisDetectionResult:
|
||||
return CrisisDetectionResult(level=level, indicators=list(indicators))
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
def test_detect_returns_result_variant_and_logged_record():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
|
||||
with patch.object(detector, "_select_variant", return_value="A"):
|
||||
result, variant, record_id = detector.detect("test message")
|
||||
|
||||
assert isinstance(result, CrisisDetectionResult)
|
||||
assert variant == "A"
|
||||
assert record_id == 0
|
||||
assert len(detector.records) == 1
|
||||
assert detector.records[0].variant == "A"
|
||||
assert detector.records[0].level == "LOW"
|
||||
|
||||
|
||||
def test_env_override_forces_variant_b():
|
||||
os.environ["CRISIS_AB_VARIANT"] = "b"
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
|
||||
result, variant, _ = detector.detect("test")
|
||||
|
||||
assert variant == "B"
|
||||
assert result.level == "HIGH"
|
||||
|
||||
|
||||
def test_get_stats_reports_latency_counts_and_level_breakdown():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("CRITICAL"),
|
||||
)
|
||||
|
||||
with patch.object(detector, "_select_variant", side_effect=["A", "A", "B"]):
|
||||
detector.detect("first")
|
||||
detector.detect("second")
|
||||
detector.detect("third")
|
||||
|
||||
stats = detector.get_stats()
|
||||
assert stats["A"]["count"] == 2
|
||||
assert stats["B"]["count"] == 1
|
||||
assert stats["A"]["levels"]["LOW"] == 2
|
||||
assert stats["B"]["levels"]["CRITICAL"] == 1
|
||||
assert "avg_latency_ms" in stats["A"]
|
||||
assert "avg_indicator_count" in stats["B"]
|
||||
|
||||
|
||||
def test_false_positive_rate_is_computed_from_reviewed_outcomes():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
|
||||
with patch.object(detector, "_select_variant", side_effect=["A", "A", "B"]):
|
||||
_, _, a0 = detector.detect("first")
|
||||
_, _, a1 = detector.detect("second")
|
||||
_, _, b0 = detector.detect("third")
|
||||
|
||||
detector.record_outcome(a0, false_positive=True)
|
||||
detector.record_outcome(a1, false_positive=False)
|
||||
detector.record_outcome(b0, false_positive=False)
|
||||
|
||||
stats = detector.get_stats()
|
||||
assert stats["A"]["reviewed_count"] == 2
|
||||
assert stats["A"]["false_positive_rate"] == 0.5
|
||||
assert stats["B"]["false_positive_rate"] == 0.0
|
||||
|
||||
|
||||
def test_record_outcome_rejects_unknown_record():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
|
||||
with pytest.raises(IndexError):
|
||||
detector.record_outcome(99, false_positive=True)
|
||||
|
||||
|
||||
def test_reset_clears_records_and_stats():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=_make_variant("LOW"),
|
||||
variant_b=_make_variant("HIGH"),
|
||||
)
|
||||
detector.detect("test")
|
||||
detector.reset()
|
||||
|
||||
assert detector.records == []
|
||||
stats = detector.get_stats()
|
||||
assert stats["A"]["count"] == 0
|
||||
assert stats["B"]["count"] == 0
|
||||
|
||||
|
||||
def test_with_real_detector_integration():
|
||||
detector = ABTestCrisisDetector(
|
||||
variant_a=detect_crisis,
|
||||
variant_b=detect_crisis,
|
||||
)
|
||||
|
||||
result, variant, record_id = detector.detect("I want to kill myself")
|
||||
|
||||
assert result.level == "CRITICAL"
|
||||
assert variant in ("A", "B")
|
||||
assert record_id == 0
|
||||
100
tests/test_crisis_metrics.py
Normal file
100
tests/test_crisis_metrics.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Tests for privacy-preserving crisis metrics aggregation (issue #37)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from crisis.detect import detect_crisis
|
||||
from crisis.gateway import check_crisis
|
||||
from crisis.metrics import (
|
||||
append_metrics_event,
|
||||
build_metrics_event,
|
||||
build_weekly_summary,
|
||||
load_metrics_events,
|
||||
render_weekly_summary,
|
||||
)
|
||||
|
||||
|
||||
class TestMetricsEvent(unittest.TestCase):
|
||||
def test_event_is_privacy_preserving(self):
|
||||
detection = detect_crisis("I want to kill myself")
|
||||
event = build_metrics_event(
|
||||
detection,
|
||||
continued_conversation=True,
|
||||
false_positive=False,
|
||||
now=1_700_000_000,
|
||||
)
|
||||
self.assertEqual(event["timestamp"], 1_700_000_000)
|
||||
self.assertEqual(event["level"], "CRITICAL")
|
||||
self.assertTrue(event["continued_conversation"])
|
||||
self.assertFalse(event["false_positive"])
|
||||
self.assertNotIn("text", event)
|
||||
self.assertNotIn("message", event)
|
||||
self.assertGreaterEqual(event["indicator_count"], 1)
|
||||
self.assertTrue(event["indicators"])
|
||||
|
||||
|
||||
class TestMetricsLogAndSummary(unittest.TestCase):
|
||||
def test_append_and_load_metrics_events(self):
|
||||
log_path = pathlib.Path(self._testMethodName).with_suffix(".jsonl")
|
||||
try:
|
||||
append_metrics_event(log_path, detect_crisis("I want to die"), now=1_700_000_000)
|
||||
events = load_metrics_events(log_path)
|
||||
self.assertEqual(len(events), 1)
|
||||
self.assertEqual(events[0]["level"], "CRITICAL")
|
||||
finally:
|
||||
if log_path.exists():
|
||||
log_path.unlink()
|
||||
|
||||
def test_weekly_summary_counts_levels_keywords_and_false_positives(self):
|
||||
events = [
|
||||
build_metrics_event(detect_crisis("I want to die"), continued_conversation=True, false_positive=False, now=1_700_000_000),
|
||||
build_metrics_event(detect_crisis("I'm having a rough day"), continued_conversation=False, false_positive=False, now=1_700_000_100),
|
||||
build_metrics_event(detect_crisis("I want to die"), continued_conversation=False, false_positive=True, now=1_700_000_200),
|
||||
build_metrics_event(detect_crisis("Hello there"), continued_conversation=False, false_positive=False, now=1_700_000_300),
|
||||
]
|
||||
summary = build_weekly_summary(events, now=1_700_000_400, window_days=7)
|
||||
|
||||
self.assertEqual(summary["detections_per_level"]["CRITICAL"], 2)
|
||||
self.assertEqual(summary["detections_per_level"]["LOW"], 1)
|
||||
self.assertEqual(summary["detections_per_level"]["NONE"], 1)
|
||||
self.assertEqual(summary["continued_after_intervention"], 1)
|
||||
self.assertAlmostEqual(summary["false_positive_estimate"], 1 / 3, places=4)
|
||||
self.assertEqual(summary["most_common_keywords"][0]["count"], 2)
|
||||
|
||||
def test_render_weekly_summary_mentions_required_metrics(self):
|
||||
events = [
|
||||
build_metrics_event(detect_crisis("I want to die"), continued_conversation=True, now=1_700_000_000),
|
||||
build_metrics_event(detect_crisis("I feel hopeless with no way out"), false_positive=True, now=1_700_000_100),
|
||||
]
|
||||
summary = build_weekly_summary(events, now=1_700_000_200, window_days=7)
|
||||
rendered = render_weekly_summary(summary)
|
||||
self.assertIn("detections_per_level", rendered)
|
||||
self.assertIn("most_common_keywords", rendered)
|
||||
self.assertIn("false_positive_estimate", rendered)
|
||||
self.assertIn("continued_after_intervention", rendered)
|
||||
|
||||
|
||||
class TestGatewayMetricsIntegration(unittest.TestCase):
|
||||
def test_check_crisis_can_emit_metrics_event(self):
|
||||
result = check_crisis(
|
||||
"I want to die",
|
||||
metrics_log_path=None,
|
||||
continued_conversation=True,
|
||||
false_positive=False,
|
||||
now=1_700_000_000,
|
||||
)
|
||||
self.assertEqual(result["level"], "CRITICAL")
|
||||
self.assertIn("metrics_event", result)
|
||||
self.assertEqual(result["metrics_event"]["timestamp"], 1_700_000_000)
|
||||
self.assertTrue(result["metrics_event"]["continued_conversation"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user