Merge PR #656: tests/test_crisis_direct_suicidal_pairs.py (added)

This commit is contained in:
Merge Bot
2026-04-16 05:06:56 +00:00
parent 4642c8b3b1
commit 5d9ef59ffc

View File

@@ -0,0 +1,75 @@
import json
from pathlib import Path
DATA_FILE = Path('training-data/crisis-response/direct-suicidal-statements-500.jsonl')
SCRIPT_FILE = Path('scripts/generate-crisis-direct-suicidal-pairs.py')
EXPECTED_COUNT = 500
REQUIRED_RESPONSE_PHRASES = [
'Are you safe right now?',
'988',
'Jesus saves those who call on His name',
]
DIRECT_SIGNAL_TERMS = [
'kill myself', 'end my life', 'die', 'suicide', 'overdose', 'gun', 'knife', 'bridge', 'hanging', 'pills'
]
BLOCKED_RESPONSE_PHRASES = [
'you should die', 'kill yourself', 'best way to die', 'no one needs you', 'go ahead and do it'
]
def load_rows():
assert DATA_FILE.exists(), f'missing dataset: {DATA_FILE}'
lines = [line for line in DATA_FILE.read_text(encoding='utf-8').splitlines() if line.strip()]
return [json.loads(line) for line in lines]
def test_generator_script_exists():
assert SCRIPT_FILE.exists(), f'missing generator script: {SCRIPT_FILE}'
def test_dataset_exists():
assert DATA_FILE.exists(), f'missing dataset: {DATA_FILE}'
def test_dataset_has_exactly_500_rows():
rows = load_rows()
assert len(rows) == EXPECTED_COUNT
def test_ids_are_unique():
rows = load_rows()
ids = [row['scenario_id'] for row in rows]
assert len(ids) == len(set(ids))
def test_schema_fields_present():
rows = load_rows()
required = {'scenario_id', 'category', 'prompt', 'response', 'risk_level', 'tags'}
for row in rows:
assert required.issubset(row.keys())
assert row['category'] == 'crisis_response_direct_suicidal_statements'
assert row['risk_level'] in {'high', 'critical'}
assert isinstance(row['tags'], list) and row['tags']
def test_prompts_are_direct_suicidal_statements():
rows = load_rows()
for row in rows:
prompt = row['prompt'].lower()
assert any(term in prompt for term in DIRECT_SIGNAL_TERMS), prompt
def test_required_protocol_is_in_every_response():
rows = load_rows()
for row in rows:
response = row['response']
for phrase in REQUIRED_RESPONSE_PHRASES:
assert phrase in response, (row['scenario_id'], phrase)
def test_blocked_phrases_never_appear_in_responses():
rows = load_rows()
for row in rows:
response = row['response'].lower()
for phrase in BLOCKED_RESPONSE_PHRASES:
assert phrase not in response, (row['scenario_id'], phrase)