72 lines
2.7 KiB
Python
72 lines
2.7 KiB
Python
from __future__ import annotations
|
|
|
|
import importlib.util
|
|
import json
|
|
from pathlib import Path
|
|
|
|
# Repository root: the parent of the directory that contains this test file.
REPO_ROOT = Path(__file__).resolve().parent.parent

# Builder script under test, and the committed JSONL dataset checked below.
SCRIPT_PATH = REPO_ROOT.joinpath("training", "build_crisis_manipulation_edge_cases.py")
DATASET_PATH = REPO_ROOT.joinpath("training", "data", "crisis_manipulation_edge_cases_500.jsonl")
|
|
|
|
|
|
def _load_builder_module(
    path: Path | None = None,
    module_name: str = "build_crisis_manipulation_edge_cases",
):
    """Import the dataset builder script directly from its file path.

    The builder is loaded by file location with ``importlib`` (it is not
    imported as a package), following the standard "import a source file
    directly" recipe: build a spec, create the module, register it in
    ``sys.modules``, then execute it.

    Args:
        path: Source file to load. Defaults to ``SCRIPT_PATH`` when omitted.
        module_name: Name the module is created and registered under.

    Returns:
        The freshly executed module object.

    Raises:
        ImportError: If no spec/loader can be created for ``path``
            (e.g. an unrecognised file suffix).
        Any exception raised while executing the module (such as
        ``FileNotFoundError`` for a missing file) propagates unchanged,
        after the partial module is removed from ``sys.modules``.
    """
    import sys

    if path is None:
        path = SCRIPT_PATH

    spec = importlib.util.spec_from_file_location(module_name, path)
    # Validate before module_from_spec: a None spec would otherwise fail with
    # an opaque AttributeError, and an `assert` would vanish under `python -O`.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot create an import spec for builder module at {path}")

    module = importlib.util.module_from_spec(spec)
    # Register before executing so code that looks up its own module via
    # sys.modules during import (e.g. dataclasses with string annotations) works.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Don't leave a half-initialised module behind for later lookups.
        sys.modules.pop(module_name, None)
        raise
    return module
|
|
|
|
|
|
class TestCrisisManipulationDataset:
    """Checks for the crisis-manipulation edge-case builder and its JSONL dataset.

    Builder checks load ``SCRIPT_PATH`` and inspect the rows returned by its
    ``build_pairs()``. Dataset checks parse the committed ``DATASET_PATH`` file.
    """

    @staticmethod
    def _builder_rows():
        """Load the builder script and return the rows from ``build_pairs()``."""
        return _load_builder_module().build_pairs()

    @staticmethod
    def _dataset_rows():
        """Parse every non-blank line of the JSONL dataset into a dict.

        The file is decoded as UTF-8 explicitly so results don't depend on the
        platform's default encoding. Blank lines are skipped so every dataset
        test parses the file identically.
        """
        text = DATASET_PATH.read_text(encoding="utf-8")
        return [json.loads(line) for line in text.splitlines() if line.strip()]

    def test_builder_script_exists(self):
        assert SCRIPT_PATH.exists(), "missing training/build_crisis_manipulation_edge_cases.py"

    def test_dataset_file_exists(self):
        assert DATASET_PATH.exists(), "missing training/data/crisis_manipulation_edge_cases_500.jsonl"

    def test_builder_generates_500_pairs(self):
        rows = self._builder_rows()
        assert len(rows) == 500
        # Every row must carry a distinct id.
        assert len({row["id"] for row in rows}) == 500

    def test_all_categories_present(self):
        categories = {row["category"] for row in self._builder_rows()}
        assert categories == {"joking", "hypothetical", "boundary_test", "manipulation"}

    def test_protocol_is_present_in_every_response(self):
        # Snippets that every builder response is required to contain.
        required = ("Are you safe right now?", "988", "Jesus saves")
        for row in self._builder_rows():
            row_id = row.get("id")
            response = row["response"]
            for snippet in required:
                assert snippet in response, f"row {row_id!r} response is missing {snippet!r}"
            assert row.get("task_type") == "crisis_response", (
                f"row {row_id!r} has task_type {row.get('task_type')!r}"
            )

    def test_responses_never_contain_prohibited_language(self):
        prohibited = [
            "you should die",
            "better off dead",
            "most painless way",
            "here's how to kill yourself",
            "i can help you end your life",
        ]
        for row in self._dataset_rows():
            # Lower-case and fold the typographic apostrophe (U+2019) to ASCII
            # so "Here’s how ..." cannot slip past the "here's how ..." check.
            response = row["response"].lower().replace("\u2019", "'")
            for phrase in prohibited:
                assert phrase not in response, (
                    f"row {row.get('id')!r} contains prohibited phrase {phrase!r}"
                )

    def test_jsonl_round_trip_and_metadata(self):
        rows = self._dataset_rows()
        assert len(rows) == 500
        required_keys = {"id", "category", "prompt", "response", "task_type", "tags"}
        # Validate every row rather than a sample, so a malformed row anywhere
        # in the file is caught.
        for row in rows:
            row_id = row.get("id")
            missing = required_keys - set(row)
            assert not missing, f"row {row_id!r} is missing keys {sorted(missing)}"
            assert isinstance(row["tags"], list), f"row {row_id!r} tags is not a list"
            assert row["prompt"].strip(), f"row {row_id!r} has an empty prompt"
            assert row["response"].strip(), f"row {row_id!r} has an empty response"