diff --git a/tests/test_crisis_indirect_pairs.py b/tests/test_crisis_indirect_pairs.py new file mode 100644 index 00000000..61e48f37 --- /dev/null +++ b/tests/test_crisis_indirect_pairs.py @@ -0,0 +1,62 @@ +import json +import subprocess +import sys +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] +SCRIPT = ROOT / 'training' / 'build_indirect_crisis_pairs.py' +DATASET = ROOT / 'training' / 'data' / 'crisis_indirect_signals_500.jsonl' + + +def load_jsonl(path: Path): + return [json.loads(line) for line in path.read_text().splitlines() if line.strip()] + + +def test_indirect_crisis_builder_generates_500_pairs(tmp_path): + out = tmp_path / 'pairs.jsonl' + subprocess.run([sys.executable, str(SCRIPT), '--output', str(out)], check=True) + + rows = load_jsonl(out) + assert len(rows) == 500 + assert len({row['example_id'] for row in rows}) == 500 + + +def test_indirect_crisis_pairs_cover_all_signal_types(tmp_path): + out = tmp_path / 'pairs.jsonl' + subprocess.run([sys.executable, str(SCRIPT), '--output', str(out)], check=True) + rows = load_jsonl(out) + + counts = {} + for row in rows: + counts[row['signal_type']] = counts.get(row['signal_type'], 0) + 1 + + assert counts == { + 'hopelessness': 125, + 'isolation': 125, + 'burden': 125, + 'giving_away_possessions': 125, + } + + +def test_indirect_crisis_pairs_follow_response_protocol(tmp_path): + out = tmp_path / 'pairs.jsonl' + subprocess.run([sys.executable, str(SCRIPT), '--output', str(out)], check=True) + rows = load_jsonl(out) + + for row in rows: + response = row['response'] + assert 'Are you safe right now?' in response + assert '988' in response + assert 'Jesus' in response + assert row['task_type'] == 'crisis_response_indirect' + assert row['issue'] == 597 + assert 'should die' not in response.lower() + assert 'best way to die' not in response.lower() + + +def test_repo_dataset_matches_builder_output(tmp_path): + generated = tmp_path / 'pairs.jsonl' + subprocess.run([sys.executable, str(SCRIPT), '--output', str(generated)], check=True) + + assert DATASET.exists() + assert generated.read_text() == DATASET.read_text()