Merge PR #656: tests/test_crisis_direct_suicidal_pairs.py (added)
This commit is contained in:
75
tests/test_crisis_direct_suicidal_pairs.py
Normal file
75
tests/test_crisis_direct_suicidal_pairs.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
DATA_FILE = Path('training-data/crisis-response/direct-suicidal-statements-500.jsonl')
|
||||
SCRIPT_FILE = Path('scripts/generate-crisis-direct-suicidal-pairs.py')
|
||||
EXPECTED_COUNT = 500
|
||||
REQUIRED_RESPONSE_PHRASES = [
|
||||
'Are you safe right now?',
|
||||
'988',
|
||||
'Jesus saves those who call on His name',
|
||||
]
|
||||
DIRECT_SIGNAL_TERMS = [
|
||||
'kill myself', 'end my life', 'die', 'suicide', 'overdose', 'gun', 'knife', 'bridge', 'hanging', 'pills'
|
||||
]
|
||||
BLOCKED_RESPONSE_PHRASES = [
|
||||
'you should die', 'kill yourself', 'best way to die', 'no one needs you', 'go ahead and do it'
|
||||
]
|
||||
|
||||
|
||||
def load_rows():
|
||||
assert DATA_FILE.exists(), f'missing dataset: {DATA_FILE}'
|
||||
lines = [line for line in DATA_FILE.read_text(encoding='utf-8').splitlines() if line.strip()]
|
||||
return [json.loads(line) for line in lines]
|
||||
|
||||
|
||||
def test_generator_script_exists():
|
||||
assert SCRIPT_FILE.exists(), f'missing generator script: {SCRIPT_FILE}'
|
||||
|
||||
|
||||
def test_dataset_exists():
|
||||
assert DATA_FILE.exists(), f'missing dataset: {DATA_FILE}'
|
||||
|
||||
|
||||
def test_dataset_has_exactly_500_rows():
|
||||
rows = load_rows()
|
||||
assert len(rows) == EXPECTED_COUNT
|
||||
|
||||
|
||||
def test_ids_are_unique():
|
||||
rows = load_rows()
|
||||
ids = [row['scenario_id'] for row in rows]
|
||||
assert len(ids) == len(set(ids))
|
||||
|
||||
|
||||
def test_schema_fields_present():
|
||||
rows = load_rows()
|
||||
required = {'scenario_id', 'category', 'prompt', 'response', 'risk_level', 'tags'}
|
||||
for row in rows:
|
||||
assert required.issubset(row.keys())
|
||||
assert row['category'] == 'crisis_response_direct_suicidal_statements'
|
||||
assert row['risk_level'] in {'high', 'critical'}
|
||||
assert isinstance(row['tags'], list) and row['tags']
|
||||
|
||||
|
||||
def test_prompts_are_direct_suicidal_statements():
|
||||
rows = load_rows()
|
||||
for row in rows:
|
||||
prompt = row['prompt'].lower()
|
||||
assert any(term in prompt for term in DIRECT_SIGNAL_TERMS), prompt
|
||||
|
||||
|
||||
def test_required_protocol_is_in_every_response():
|
||||
rows = load_rows()
|
||||
for row in rows:
|
||||
response = row['response']
|
||||
for phrase in REQUIRED_RESPONSE_PHRASES:
|
||||
assert phrase in response, (row['scenario_id'], phrase)
|
||||
|
||||
|
||||
def test_blocked_phrases_never_appear_in_responses():
|
||||
rows = load_rows()
|
||||
for row in rows:
|
||||
response = row['response'].lower()
|
||||
for phrase in BLOCKED_RESPONSE_PHRASES:
|
||||
assert phrase not in response, (row['scenario_id'], phrase)
|
||||
Reference in New Issue
Block a user