63 lines
2.0 KiB
Python
63 lines
2.0 KiB
Python
import json
import subprocess
import sys
from collections import Counter
from pathlib import Path
|
|
|
|
# Repository root: the parent of the directory that contains this test file.
ROOT = Path(__file__).resolve().parents[1]

# Builder script under test, and the dataset checked into the repo that the
# builder's output is compared against.
SCRIPT = ROOT.joinpath('training', 'build_indirect_crisis_pairs.py')
DATASET = ROOT.joinpath('training', 'data', 'crisis_indirect_signals_500.jsonl')
|
|
|
|
|
|
def load_jsonl(path: Path) -> list:
    """Parse a JSON Lines file into a list of decoded records.

    Blank (whitespace-only) lines are skipped; every other line must be a
    complete JSON document.

    Args:
        path: File to read.

    Returns:
        One decoded object per non-blank line, in file order.
    """
    # Decode explicitly as UTF-8 (the JSON Lines encoding). Without an
    # explicit encoding, read_text() uses the locale's preferred encoding
    # (e.g. cp1252 on Windows), so non-ASCII response text could fail to
    # decode or be silently mangled.
    text = path.read_text(encoding='utf-8')
    return [json.loads(line) for line in text.splitlines() if line.strip()]
|
|
|
|
|
|
def test_indirect_crisis_builder_generates_500_pairs(tmp_path):
    """Run the builder and check it emits 500 rows, each with a distinct example_id."""
    target = tmp_path / 'pairs.jsonl'
    command = [sys.executable, str(SCRIPT), '--output', str(target)]
    subprocess.run(command, check=True)

    generated = load_jsonl(target)
    assert len(generated) == 500

    # Every row must carry its own example_id; duplicates would shrink the set.
    distinct_ids = set(record['example_id'] for record in generated)
    assert len(distinct_ids) == 500
|
|
|
|
|
|
def test_indirect_crisis_pairs_cover_all_signal_types(tmp_path):
    """Check the builder emits exactly 125 pairs for each of the four signal types.

    The comparison is exact, so an unexpected extra signal type fails the
    test as well as a missing or under-represented one.
    """
    out = tmp_path / 'pairs.jsonl'
    subprocess.run([sys.executable, str(SCRIPT), '--output', str(out)], check=True)
    rows = load_jsonl(out)

    counts = Counter(row['signal_type'] for row in rows)

    # Convert to a plain dict so this is ordinary dict equality on every
    # Python version (Counter gained its own rich comparisons in 3.10).
    assert dict(counts) == {
        'hopelessness': 125,
        'isolation': 125,
        'burden': 125,
        'giving_away_possessions': 125,
    }
|
|
|
|
|
|
def test_indirect_crisis_pairs_follow_response_protocol(tmp_path):
    """Check every generated pair follows the indirect-crisis response protocol.

    Each response must ask 'Are you safe right now?', mention 988 and
    mention Jesus. Each row must carry the expected task_type and issue
    values. No response may contain the banned phrases checked below
    (matched case-insensitively).
    """
    out = tmp_path / 'pairs.jsonl'
    subprocess.run([sys.executable, str(SCRIPT), '--output', str(out)], check=True)
    rows = load_jsonl(out)

    # Without this, an empty output would make the per-row loop below pass
    # vacuously.
    assert rows, 'builder produced no rows'

    for row in rows:
        # Name the offending row in failure messages; with hundreds of rows a
        # bare assert failure cannot be traced back to its source.
        where = f"example_id={row.get('example_id', '<missing>')!r}"
        response = row['response']
        lowered = response.lower()  # computed once for the banned-phrase checks

        assert 'Are you safe right now?' in response, where
        assert '988' in response, where
        assert 'Jesus' in response, where
        assert row['task_type'] == 'crisis_response_indirect', where
        # NOTE(review): 597 is presumably the tracking issue number for this
        # dataset — confirm.
        assert row['issue'] == 597, where
        assert 'should die' not in lowered, where
        assert 'best way to die' not in lowered, where
|
|
|
|
|
|
def test_repo_dataset_matches_builder_output(tmp_path):
    """Check the dataset checked into the repo is exactly what the builder produces.

    Both files are read as UTF-8 text. Universal-newline decoding makes the
    comparison insensitive to LF vs CRLF checkouts.
    """
    generated = tmp_path / 'pairs.jsonl'
    subprocess.run([sys.executable, str(SCRIPT), '--output', str(generated)], check=True)

    assert DATASET.exists(), f'missing dataset file: {DATASET}'

    # Explicit encoding: the locale default (e.g. cp1252 on Windows) could
    # fail to decode non-ASCII text in either file.
    actual = generated.read_text(encoding='utf-8')
    expected = DATASET.read_text(encoding='utf-8')
    assert actual == expected, (
        f'{DATASET} is out of date; regenerate it with: '
        f'{sys.executable} {SCRIPT} --output {DATASET}'
    )
|