diff --git a/image_screening.py b/image_screening.py new file mode 100644 index 0000000..f62f255 --- /dev/null +++ b/image_screening.py @@ -0,0 +1,170 @@ +""" +image_screening.py — local image crisis screening slice for epic #130. + +Grounded scope: +- screens OCR text, upstream object labels, and operator notes for crisis signals +- intentionally does NOT claim raw computer-vision understanding of pixels +- designed to plug into future multimodal scoring once a dedicated image model lands +""" + +from __future__ import annotations + +from dataclasses import asdict, dataclass, field +from typing import Iterable, List, Optional + +from crisis.detect import detect_crisis + + +DIRECT_SELF_HARM_LABELS = { + "blood", + "blade", + "razor", + "knife", + "scissors", + "noose", + "ligature", + "hanging", + "pills", + "pill bottle", + "overdose", + "gun", + "firearm", + "rope", + "cuts", + "self-harm", +} + +INJURY_LABELS = { + "wound", + "wounds", + "bruise", + "bruises", + "bandage", + "bandages", + "injury", + "injuries", + "scar", + "scars", + "burn", + "burns", + "bleeding", +} + +HIGH_RISK_SCENE_LABELS = { + "bridge edge", + "rooftop edge", + "train tracks", + "ledge", + "cliff edge", + "dark room", + "bathroom floor", + "hospital wristband", + "suicide note", + "goodbye letter", +} + + +@dataclass +class ImageScreeningResult: + ocr_text: str = "" + labels: List[str] = field(default_factory=list) + visual_flags: List[str] = field(default_factory=list) + distress_score: float = 0.0 + requires_human_review: bool = False + signals_detected: List[str] = field(default_factory=list) + grounded_scope: str = ( + "heuristic screening over OCR text, upstream labels, and operator notes; " + "raw vision-model inference is not implemented in this slice" + ) + + def to_dict(self) -> dict: + return asdict(self) + + +def _normalize_items(values: Optional[Iterable[str]]) -> List[str]: + if not values: + return [] + normalized = [] + for value in values: + text = str(value).strip().lower() + if text: + normalized.append(text) + return normalized + + +def _match_keywords(haystack: str, keywords: set[str]) -> List[str]: + matches = [] + for keyword in keywords: + if keyword in haystack: + matches.append(keyword) + return sorted(set(matches)) + + +def screen_image_signals( + image_path: Optional[str] = None, + *, + ocr_text: str = "", + labels: Optional[Iterable[str]] = None, + manual_notes: str = "", + visual_flags: Optional[Iterable[str]] = None, +) -> ImageScreeningResult: + """ + Score image-related crisis evidence without pretending to do full CV. + + Inputs are deliberately grounded in what the repo can actually support today: + - OCR text extracted upstream from screenshots/photos + - object labels supplied by a local model or operator + - manual operator notes about visible scene context + - optional visual flags from any upstream preprocessor + """ + normalized_labels = _normalize_items(labels) + normalized_flags = _normalize_items(visual_flags) + normalized_notes = str(manual_notes or "").strip().lower() + normalized_ocr = str(ocr_text or "").strip() + + combined_label_text = " ".join(normalized_labels + normalized_flags + ([normalized_notes] if normalized_notes else [])) + crisis_text = " ".join(part for part in [normalized_ocr, normalized_notes] if part).strip() + + direct_matches = _match_keywords(combined_label_text, DIRECT_SELF_HARM_LABELS) + injury_matches = _match_keywords(combined_label_text, INJURY_LABELS) + scene_matches = _match_keywords(combined_label_text, HIGH_RISK_SCENE_LABELS) + text_detection = detect_crisis(crisis_text) if crisis_text else None + + signals: List[str] = [] + score = 0.0 + + if direct_matches: + score = max(score, 0.85) + for match in direct_matches: + signals.append(f"direct_self_harm_label:{match}") + + if injury_matches: + score = max(score, 0.55) + for match in injury_matches: + signals.append(f"injury_indicator:{match}") + + if scene_matches: + score = max(score, 0.4) + for match in scene_matches: + signals.append(f"high_risk_scene:{match}") + + if text_detection and text_detection.level != "NONE": + score = max(score, min(1.0, text_detection.score)) + signals.append(f"ocr_crisis_level:{text_detection.level}") + for indicator in text_detection.indicators[:3]: + signals.append(f"ocr_indicator:{indicator}") + + if direct_matches and text_detection and text_detection.level in {"HIGH", "CRITICAL"}: + score = min(1.0, max(score, 0.95)) + signals.append("cross_modal_confirmation:text_plus_visual") + + requires_human_review = score >= 0.4 or bool(direct_matches) + + return ImageScreeningResult( + ocr_text=normalized_ocr, + labels=list(normalized_labels), + visual_flags=list(normalized_flags), + distress_score=round(score, 4), + requires_human_review=requires_human_review, + signals_detected=signals, + ) diff --git a/tests/test_image_screening.py b/tests/test_image_screening.py new file mode 100644 index 0000000..f9a4111 --- /dev/null +++ b/tests/test_image_screening.py @@ -0,0 +1,62 @@ +"""Tests for local image crisis screening slice under epic #130.""" + +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from image_screening import ImageScreeningResult, screen_image_signals + + +class TestImageScreeningResult(unittest.TestCase): + def test_to_dict_preserves_core_fields(self): + result = ImageScreeningResult( + ocr_text="help me", + labels=["blood"], + visual_flags=["dark_scene"], + distress_score=0.8, + requires_human_review=True, + signals_detected=["direct_self_harm_label:blood"], + ) + + data = result.to_dict() + self.assertEqual(data["ocr_text"], "help me") + self.assertEqual(data["labels"], ["blood"]) + self.assertTrue(data["requires_human_review"]) + + +class TestScreenImageSignals(unittest.TestCase): + def test_direct_self_harm_labels_trigger_high_risk(self): + result = screen_image_signals( + labels=["razor blade", "blood droplets"], + manual_notes="photo of fresh cuts on forearm", + ) + + self.assertGreaterEqual(result.distress_score, 0.8) + self.assertTrue(result.requires_human_review) + self.assertTrue(any("self_harm" in signal for signal in result.signals_detected)) + + def test_ocr_text_uses_existing_crisis_detector(self): + result = screen_image_signals( + ocr_text="I want to kill myself tonight", + labels=["handwritten note"], + ) + + self.assertGreaterEqual(result.distress_score, 0.7) + self.assertTrue(result.requires_human_review) + self.assertTrue(any(signal.startswith("ocr_crisis_level:") for signal in result.signals_detected)) + + def test_neutral_image_stays_low_risk(self): + result = screen_image_signals( + labels=["dog", "park", "sunlight"], + manual_notes="family outing in daylight", + ) + + self.assertLess(result.distress_score, 0.2) + self.assertFalse(result.requires_human_review) + self.assertEqual(result.signals_detected, []) + + +if __name__ == "__main__": + unittest.main()