"""Tests for the content moderation pipeline."""
|
|
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
|
|
from infrastructure.guards.moderation import (
|
|
ContentModerator,
|
|
GameProfile,
|
|
ModerationResult,
|
|
ModerationVerdict,
|
|
ViolationCategory,
|
|
_parse_guard_category,
|
|
get_moderator,
|
|
)
|
|
|
|
# ── Unit tests for data types ────────────────────────────────────────────────
|
|
|
|
|
|
class TestModerationResult:
    """Unit tests for the ModerationResult dataclass."""

    def test_passed_property_true(self):
        passing = ModerationResult(verdict=ModerationVerdict.PASS, blocked=False)
        assert passing.passed is True

    def test_passed_property_false(self):
        failing = ModerationResult(verdict=ModerationVerdict.FAIL, blocked=True)
        assert failing.passed is False

    def test_default_values(self):
        # Fields omitted at construction fall back to their declared defaults.
        result = ModerationResult(verdict=ModerationVerdict.PASS, blocked=False)
        assert result.category == ViolationCategory.NONE
        assert result.confidence == 0.0
        assert result.fallback == ""
        assert result.reason == ""
|
|
|
|
|
|
class TestGameProfile:
    """Unit tests for the GameProfile dataclass."""

    def test_default_values(self):
        profile = GameProfile(game_id="test", display_name="Test Game")
        # Collection fields default to fresh empty containers.
        assert profile.vocabulary_whitelist == []
        assert profile.fallbacks == {}
        assert profile.threshold == 0.8

    def test_morrowind_profile(self):
        morrowind = GameProfile(
            game_id="morrowind",
            display_name="Morrowind",
            vocabulary_whitelist=["Skooma", "slave"],
            threshold=0.85,
        )
        assert morrowind.threshold == 0.85
        assert "Skooma" in morrowind.vocabulary_whitelist
|
|
|
|
|
|
class TestParseGuardCategory:
    """Tests for mapping Llama Guard category strings to ViolationCategory."""

    def test_hate_speech(self):
        assert _parse_guard_category("S1: Hate speech") == ViolationCategory.HATE_SPEECH

    def test_violence(self):
        parsed = _parse_guard_category("S2: Violence")
        assert parsed == ViolationCategory.VIOLENCE_GLORIFICATION

    def test_sexual_content(self):
        parsed = _parse_guard_category("S3: Sexual content")
        assert parsed == ViolationCategory.SEXUAL_CONTENT

    def test_self_harm(self):
        parsed = _parse_guard_category("S4: Self-harm")
        assert parsed == ViolationCategory.SELF_HARM

    def test_dangerous(self):
        parsed = _parse_guard_category("S5: Dangerous activity")
        assert parsed == ViolationCategory.REAL_WORLD_HARM

    def test_unknown_category(self):
        # Unrecognized S-codes map to the neutral NONE category.
        assert _parse_guard_category("S99: Unknown") == ViolationCategory.NONE
|
|
|
|
|
|
# ── ContentModerator tests ───────────────────────────────────────────────────
|
|
|
|
|
|
class TestContentModerator:
    """Tests for the layered content-moderation pipeline."""

    def _make_moderator(self, **kwargs) -> ContentModerator:
        """Build a ContentModerator preloaded with two test game profiles."""
        morrowind = GameProfile(
            game_id="morrowind",
            display_name="Morrowind",
            vocabulary_whitelist=["Skooma", "Moon Sugar", "slave", "Morag Tong"],
            context_prompt="Narrate Morrowind gameplay.",
            threshold=0.85,
            fallbacks={
                "combat": "The battle continues.",
                "default": "The adventure continues.",
            },
        )
        generic = GameProfile(
            game_id="default",
            display_name="Generic",
            vocabulary_whitelist=[],
            context_prompt="Narrate gameplay.",
            threshold=0.8,
            fallbacks={"default": "Gameplay continues."},
        )
        return ContentModerator(
            profiles={"morrowind": morrowind, "default": generic}, **kwargs
        )

    def test_get_profile_known_game(self):
        moderator = self._make_moderator()
        assert moderator.get_profile("morrowind").game_id == "morrowind"

    def test_get_profile_unknown_game_falls_back(self):
        # Unregistered games resolve to the "default" profile.
        moderator = self._make_moderator()
        assert moderator.get_profile("unknown_game").game_id == "default"

    def test_get_context_prompt(self):
        moderator = self._make_moderator()
        assert "Morrowind" in moderator.get_context_prompt("morrowind")

    def test_register_profile(self):
        moderator = self._make_moderator()
        moderator.register_profile(GameProfile(game_id="skyrim", display_name="Skyrim"))
        assert moderator.get_profile("skyrim").game_id == "skyrim"

    def test_whitelist_replaces_game_terms(self):
        moderator = self._make_moderator()
        cleaned = moderator._apply_whitelist(
            "The merchant sells Skooma and Moon Sugar in the slave market.",
            moderator.get_profile("morrowind"),
        )
        # Every whitelisted lore term is masked out before guard evaluation.
        for term in ("Skooma", "Moon Sugar", "slave"):
            assert term not in cleaned
        assert "[GAME_TERM]" in cleaned

    def test_whitelist_case_insensitive(self):
        moderator = self._make_moderator()
        cleaned = moderator._apply_whitelist(
            "skooma and SKOOMA", moderator.get_profile("morrowind")
        )
        assert "skooma" not in cleaned
        assert "SKOOMA" not in cleaned

    @pytest.mark.asyncio
    async def test_check_safe_content_passes(self):
        """Safe content sails through when the guard model is offline."""
        moderator = self._make_moderator()
        with patch.object(
            moderator, "_is_guard_available", new_callable=AsyncMock, return_value=False
        ):
            verdict = await moderator.check(
                "The player walks through the town.", game="morrowind"
            )
        assert verdict.passed
        assert not verdict.blocked

    @pytest.mark.asyncio
    async def test_check_blocked_content_has_fallback(self):
        """Blocked content carries a scene-appropriate fallback line."""
        moderator = self._make_moderator()
        # Real-world harm phrasing trips the regex layer even without the guard.
        text = "In real life you should attack and hurt people"
        with patch.object(
            moderator, "_is_guard_available", new_callable=AsyncMock, return_value=False
        ):
            verdict = await moderator.check(text, game="morrowind", scene_type="combat")
        assert verdict.blocked
        assert verdict.fallback == "The battle continues."

    @pytest.mark.asyncio
    async def test_check_with_moderation_disabled(self):
        """With moderation switched off, any text passes via the 'disabled' layer."""
        moderator = self._make_moderator()
        with patch("infrastructure.guards.moderation.settings") as mock_settings:
            mock_settings.moderation_enabled = False
            mock_settings.moderation_guard_model = "llama-guard3:1b"
            mock_settings.normalized_ollama_url = "http://127.0.0.1:11434"
            verdict = await moderator.check("anything goes here")
        assert verdict.passed
        assert verdict.layer == "disabled"

    @pytest.mark.asyncio
    async def test_threshold_below_allows_content(self):
        """A guard flag below the game threshold is waved through (Layer 3)."""
        moderator = self._make_moderator()
        # Guard flags the text, but confidence sits under morrowind's 0.85 cutoff.
        low_confidence = ModerationResult(
            verdict=ModerationVerdict.FAIL,
            blocked=True,
            confidence=0.5,
            layer="llama_guard",
            category=ViolationCategory.VIOLENCE_GLORIFICATION,
        )
        with patch.object(
            moderator, "_run_guard", new_callable=AsyncMock, return_value=low_confidence
        ):
            verdict = await moderator.check("sword fight scene", game="morrowind")
        assert verdict.passed
        assert not verdict.blocked
        assert verdict.layer == "threshold"

    @pytest.mark.asyncio
    async def test_threshold_above_blocks_content(self):
        """A guard flag above the game threshold stays blocked."""
        moderator = self._make_moderator()
        high_confidence = ModerationResult(
            verdict=ModerationVerdict.FAIL,
            blocked=True,
            confidence=0.95,  # above morrowind's 0.85 threshold
            layer="llama_guard",
            category=ViolationCategory.REAL_WORLD_HARM,
        )
        with patch.object(
            moderator, "_run_guard", new_callable=AsyncMock, return_value=high_confidence
        ):
            verdict = await moderator.check("harmful content", game="morrowind")
        assert verdict.blocked

    def test_regex_catches_real_world_harm(self):
        """The regex fallback catches obvious real-world harm phrasing."""
        moderator = self._make_moderator()
        verdict = moderator._check_with_regex("you should actually harm real people")
        assert verdict.blocked
        assert verdict.category == ViolationCategory.REAL_WORLD_HARM
        assert verdict.layer == "regex_fallback"

    def test_regex_passes_game_violence(self):
        """In-game violence narration is not flagged by the regex layer."""
        moderator = self._make_moderator()
        verdict = moderator._check_with_regex(
            "The warrior slays the dragon with a mighty blow."
        )
        assert verdict.passed

    def test_regex_passes_normal_narration(self):
        """Ordinary narration passes the regex layer untouched."""
        moderator = self._make_moderator()
        verdict = moderator._check_with_regex(
            "The Nerevarine enters the city of Balmora and speaks with Caius Cosades."
        )
        assert verdict.passed

    def test_metrics_tracking(self):
        """A fresh moderator starts with zeroed metrics."""
        moderator = self._make_moderator()
        assert moderator.get_metrics()["total_checks"] == 0

    @pytest.mark.asyncio
    async def test_metrics_increment_after_check(self):
        """Counters advance after a moderation check completes."""
        moderator = self._make_moderator()
        with patch.object(
            moderator, "_is_guard_available", new_callable=AsyncMock, return_value=False
        ):
            await moderator.check("safe text", game="default")
        metrics = moderator.get_metrics()
        assert metrics["total_checks"] == 1
        assert metrics["passed"] == 1

    @pytest.mark.asyncio
    async def test_guard_fallback_on_error(self):
        """A guard-model failure degrades gracefully to the regex layer."""
        moderator = self._make_moderator()
        guard_reachable = patch.object(
            moderator, "_is_guard_available", new_callable=AsyncMock, return_value=True
        )
        guard_broken = patch.object(
            moderator,
            "_check_with_guard",
            new_callable=AsyncMock,
            side_effect=RuntimeError("timeout"),
        )
        with guard_reachable, guard_broken:
            verdict = await moderator.check("safe text", game="default")
        # Safe text should still pass — via the regex fallback, not the guard.
        assert verdict.passed
        assert verdict.layer == "regex_fallback"
|
|
|
|
|
|
class TestGetModerator:
    """Test the singleton accessor."""

    def test_returns_same_instance(self):
        """get_moderator should return the same instance on repeated calls.

        The module-level singleton is reset before the test and — fixed here —
        restored in a ``finally`` block, so a failing assertion no longer leaks
        the constructed moderator into subsequent tests.
        """
        # Reset the global to test fresh.
        import infrastructure.guards.moderation as mod_module

        mod_module._moderator = None
        try:
            m1 = get_moderator()
            m2 = get_moderator()
            assert m1 is m2
        finally:
            # Clean up even when the assertion fails (original skipped this).
            mod_module._moderator = None
|
|
|
|
|
|
# ── Profile loader tests ────────────────────────────────────────────────────
|
|
|
|
|
|
class TestProfileLoader:
    """Tests for YAML game-profile loading."""

    def test_load_missing_file_returns_empty(self, tmp_path):
        from infrastructure.guards.profiles import load_profiles

        # A nonexistent path yields an empty profile map, not an error.
        assert load_profiles(tmp_path / "nonexistent.yaml") == {}

    def test_load_valid_config(self, tmp_path):
        import yaml

        from infrastructure.guards.profiles import load_profiles

        config = {
            "profiles": {
                "testgame": {
                    "display_name": "Test Game",
                    "threshold": 0.9,
                    "vocabulary_whitelist": ["sword", "potion"],
                    "context_prompt": "Narrate test game.",
                    "fallbacks": {"default": "Game continues."},
                }
            }
        }
        config_file = tmp_path / "moderation.yaml"
        config_file.write_text(yaml.dump(config))

        profiles = load_profiles(config_file)
        assert "testgame" in profiles
        loaded = profiles["testgame"]
        assert loaded.threshold == 0.9
        assert "sword" in loaded.vocabulary_whitelist

    def test_load_malformed_yaml_returns_empty(self, tmp_path):
        from infrastructure.guards.profiles import load_profiles

        # Unparseable YAML should be swallowed and reported as "no profiles".
        broken = tmp_path / "moderation.yaml"
        broken.write_text("{{{{invalid yaml")

        assert load_profiles(broken) == {}
|