"""Validation tests for the direct-suicidal-statements crisis-response dataset.

The tests read a JSONL file of prompt/response pairs and check its size,
schema, id uniqueness, that every prompt contains a direct signal term, and
that every response contains the required phrases and none of the blocked
ones. They also check that the generator script is present in the repository.
"""
import json
from collections import Counter
from pathlib import Path

# This module lives in ``tests/``, so the repository root is the parent of its
# directory. Anchoring the paths there keeps the tests independent of the
# directory pytest happens to be invoked from.
_REPO_ROOT = Path(__file__).resolve().parent.parent

DATA_FILE = _REPO_ROOT / 'training-data' / 'crisis-response' / 'direct-suicidal-statements-500.jsonl'
SCRIPT_FILE = _REPO_ROOT / 'scripts' / 'generate-crisis-direct-suicidal-pairs.py'
EXPECTED_COUNT = 500

# Matched case-sensitively against the raw response text.
REQUIRED_RESPONSE_PHRASES = [
    'Are you safe right now?',
    '988',
    'Jesus saves those who call on His name',
]

# Matched as plain substrings of the lower-cased prompt.
DIRECT_SIGNAL_TERMS = [
    'kill myself', 'end my life', 'die', 'suicide', 'overdose', 'gun', 'knife', 'bridge', 'hanging', 'pills'
]

# Matched as plain substrings of the lower-cased response.
BLOCKED_RESPONSE_PHRASES = [
    'you should die', 'kill yourself', 'best way to die', 'no one needs you', 'go ahead and do it'
]


def load_rows():
    """Parse the dataset and return its rows as a list of decoded JSON values.

    Blank (whitespace-only) lines are skipped. Fails with an ``AssertionError``
    when the file is missing or when a line is not valid JSON; the latter
    message carries the 1-based line number so the bad record can be found.
    """
    assert DATA_FILE.exists(), f'missing dataset: {DATA_FILE}'
    rows = []
    text = DATA_FILE.read_text(encoding='utf-8')
    for line_number, line in enumerate(text.splitlines(), start=1):
        if not line.strip():
            continue
        try:
            rows.append(json.loads(line))
        except json.JSONDecodeError as err:
            raise AssertionError(
                f'invalid JSON on line {line_number} of {DATA_FILE}: {err}'
            ) from err
    return rows


def test_generator_script_exists():
    """The script that generates the dataset must be present."""
    assert SCRIPT_FILE.exists(), f'missing generator script: {SCRIPT_FILE}'


def test_dataset_exists():
    """The dataset file must be present."""
    assert DATA_FILE.exists(), f'missing dataset: {DATA_FILE}'


def test_dataset_has_exactly_500_rows():
    """The dataset must contain exactly ``EXPECTED_COUNT`` non-blank rows."""
    rows = load_rows()
    assert len(rows) == EXPECTED_COUNT, (
        f'expected {EXPECTED_COUNT} rows, found {len(rows)}'
    )


def test_ids_are_unique():
    """No two rows may share a ``scenario_id``."""
    rows = load_rows()
    counts = Counter(row['scenario_id'] for row in rows)
    duplicates = [scenario_id for scenario_id, seen in counts.items() if seen > 1]
    assert not duplicates, f'duplicate scenario_id values: {duplicates}'


def test_schema_fields_present():
    """Every row must carry the required fields with acceptable values."""
    rows = load_rows()
    required = {'scenario_id', 'category', 'prompt', 'response', 'risk_level', 'tags'}
    for index, row in enumerate(rows):
        # A non-object row would otherwise surface as an opaque AttributeError.
        assert isinstance(row, dict), f'row {index} is not a JSON object: {row!r}'
        missing = required.difference(row)
        assert not missing, (row.get('scenario_id', index), sorted(missing))
        scenario_id = row['scenario_id']
        assert row['category'] == 'crisis_response_direct_suicidal_statements', (
            scenario_id, row['category']
        )
        assert row['risk_level'] in {'high', 'critical'}, (scenario_id, row['risk_level'])
        # ``tags`` must be a list and must not be empty.
        assert isinstance(row['tags'], list) and row['tags'], (scenario_id, row['tags'])


def test_prompts_are_direct_suicidal_statements():
    """Every prompt must contain at least one direct signal term."""
    rows = load_rows()
    for row in rows:
        prompt = row['prompt'].lower()
        assert any(term in prompt for term in DIRECT_SIGNAL_TERMS), prompt


def test_required_protocol_is_in_every_response():
    """Every response must contain each required phrase verbatim."""
    rows = load_rows()
    for row in rows:
        response = row['response']
        for phrase in REQUIRED_RESPONSE_PHRASES:
            assert phrase in response, (row['scenario_id'], phrase)


def test_blocked_phrases_never_appear_in_responses():
    """No response may contain a blocked phrase, ignoring letter case."""
    rows = load_rows()
    for row in rows:
        response = row['response'].lower()
        for phrase in BLOCKED_RESPONSE_PHRASES:
            assert phrase not in response, (row['scenario_id'], phrase)