Merge PR #638: tests/test_crisis_indirect_pairs.py (added)
This commit is contained in:
62
tests/test_crisis_indirect_pairs.py
Normal file
62
tests/test_crisis_indirect_pairs.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Directory one level above the folder that contains this test module.
ROOT = Path(__file__).resolve().parents[1]
# Builder script that the tests below run as a subprocess.
SCRIPT = ROOT.joinpath('training', 'build_indirect_crisis_pairs.py')
# Dataset file in the repository that is compared against fresh builder output.
DATASET = ROOT.joinpath('training', 'data', 'crisis_indirect_signals_500.jsonl')
|
||||
|
||||
|
||||
def load_jsonl(path: Path):
    """Parse a JSON Lines file into a list of decoded values.

    Args:
        path: Location of the JSONL file to read.

    Returns:
        One decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Decode explicitly as UTF-8 instead of the locale default, and iterate
    # the text handle rather than calling ``str.splitlines()``: the latter
    # also splits on characters such as U+2028 and U+0085, which may appear
    # unescaped inside JSON strings and would cut a record in half.
    with path.open(encoding='utf-8') as handle:
        return [json.loads(line) for line in handle if line.strip()]
|
||||
|
||||
|
||||
def test_indirect_crisis_builder_generates_500_pairs(tmp_path):
    """The builder writes exactly 500 rows, each with a distinct ``example_id``."""
    destination = tmp_path / 'pairs.jsonl'
    command = [sys.executable, str(SCRIPT), '--output', str(destination)]
    # check=True makes a non-zero exit from the builder fail the test.
    subprocess.run(command, check=True)

    records = load_jsonl(destination)
    assert len(records) == 500

    unique_ids = set()
    for record in records:
        unique_ids.add(record['example_id'])
    assert len(unique_ids) == 500
|
||||
|
||||
|
||||
def test_indirect_crisis_pairs_cover_all_signal_types(tmp_path):
    """Each of the four expected signal types appears in exactly 125 rows."""
    destination = tmp_path / 'pairs.jsonl'
    command = [sys.executable, str(SCRIPT), '--output', str(destination)]
    subprocess.run(command, check=True)
    records = load_jsonl(destination)

    # Tally how many rows carry each ``signal_type`` value.
    tally = {}
    for record in records:
        label = record['signal_type']
        tally.setdefault(label, 0)
        tally[label] += 1

    expected = dict.fromkeys(
        ('hopelessness', 'isolation', 'burden', 'giving_away_possessions'),
        125,
    )
    assert tally == expected
|
||||
|
||||
|
||||
def test_indirect_crisis_pairs_follow_response_protocol(tmp_path):
    """Every generated row passes the per-row response and metadata checks.

    For each row: the response must contain three required phrases, the
    ``task_type`` and ``issue`` fields must hold the expected values, and two
    forbidden phrases must not occur in the response (case-insensitively).
    """
    destination = tmp_path / 'pairs.jsonl'
    command = [sys.executable, str(SCRIPT), '--output', str(destination)]
    subprocess.run(command, check=True)
    records = load_jsonl(destination)

    required_phrases = ('Are you safe right now?', '988', 'Jesus')
    forbidden_phrases = ('should die', 'best way to die')

    for record in records:
        reply = record['response']
        for phrase in required_phrases:
            assert phrase in reply

        assert record['task_type'] == 'crisis_response_indirect'
        assert record['issue'] == 597

        # Forbidden phrases are matched against the lower-cased response.
        lowered = reply.lower()
        for phrase in forbidden_phrases:
            assert phrase not in lowered
|
||||
|
||||
|
||||
def test_repo_dataset_matches_builder_output(tmp_path):
    """The dataset file in the repository equals freshly generated builder output."""
    fresh = tmp_path / 'pairs.jsonl'
    command = [sys.executable, str(SCRIPT), '--output', str(fresh)]
    subprocess.run(command, check=True)

    assert DATASET.exists()

    # Compare the decoded text of both files in full.
    fresh_text = fresh.read_text()
    committed_text = DATASET.read_text()
    assert fresh_text == committed_text
|
||||
Reference in New Issue
Block a user