"""Content moderation pipeline for AI narrator output.

Three-layer defense against harmful LLM output:

    Layer 1 — Game-context system prompts with per-game vocabulary whitelists.
    Layer 2 — Real-time output filter (Llama Guard via Ollama, regex fallback).
    Layer 3 — Per-game moderation profiles with configurable thresholds.

Usage:

    from infrastructure.guards.moderation import get_moderator

    moderator = get_moderator()
    result = await moderator.check("Some narrator text", game="morrowind")
    if result.blocked:
        use_fallback_narration(result.fallback)
"""

import logging
|
|
import re
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from datetime import UTC, datetime
|
|
from enum import Enum
|
|
from typing import Any
|
|
|
|
from config import settings
|
|
|
|
# Module logger: records blocked content, guard-model fallbacks and profile registration.
logger = logging.getLogger(__name__)


class ModerationVerdict(Enum):
    """Outcome label attached to every ModerationResult.

    PASS  -- the text may be used as-is.
    FAIL  -- a guard layer flagged the text.
    ERROR -- counted separately by the metrics; nothing in this module
             currently produces it.
    """

    PASS = "pass"  # noqa: S105
    FAIL = "fail"
    ERROR = "error"


class ViolationCategory(Enum):
    """Kind of policy violation reported with a moderation result.

    NONE is used both for passing results and for flagged results whose
    category could not be identified.
    """

    HATE_SPEECH = "hate_speech"
    VIOLENCE_GLORIFICATION = "violence_glorification"
    REAL_WORLD_HARM = "real_world_harm"
    SEXUAL_CONTENT = "sexual_content"
    SELF_HARM = "self_harm"
    NONE = "none"  # no category / not identified


@dataclass
class ModerationResult:
    """Outcome of running text through the moderation pipeline.

    ``blocked`` is set by the layer that produced the result. ``fallback`` is
    filled in by ContentModerator.check only for blocked results.
    """

    verdict: ModerationVerdict
    blocked: bool
    category: ViolationCategory = ViolationCategory.NONE
    confidence: float = 0.0
    latency_ms: float = 0.0
    # Name of the layer that decided (e.g. "llama_guard", "regex_fallback").
    layer: str = ""
    # Replacement narration suited to the scene, used when blocked.
    fallback: str = ""
    reason: str = ""
    # Creation time as a timezone-aware ISO-8601 string in UTC.
    timestamp: str = field(default_factory=lambda: datetime.now(tz=UTC).isoformat())

    @property
    def passed(self) -> bool:
        """True when the verdict is ModerationVerdict.PASS."""
        return self.verdict is ModerationVerdict.PASS


@dataclass
class GameProfile:
    """Moderation settings for one game.

    Bundles the vocabulary whitelist, the context prompt used for framing,
    the confidence threshold and the per-scene fallback narrations.
    """

    game_id: str
    display_name: str
    # In-game terms that are harmless in context (e.g. "Skooma", "Morag Tong").
    vocabulary_whitelist: list[str] = field(default_factory=list)
    # Text added to the narrator's system prompt to frame the game's themes.
    context_prompt: str = ""
    # Flagged results with confidence strictly below this value are allowed.
    threshold: float = 0.8
    # Scene type -> replacement narration used when content is blocked.
    fallbacks: dict[str, str] = field(default_factory=dict)


# ── Regex fallback patterns ─────────────────────────────────────────────────
# Consulted only when the Llama Guard model cannot be used. Every entry is a
# case-insensitive pattern paired with the category reported when it matches;
# all three currently report REAL_WORLD_HARM. The set is deliberately small:
# missing a violation is preferred over flagging ordinary game narration.
# NOTE(review): the second pattern also matches in-game phrasing such as
# "you should really kill the guard" — confirm that is acceptable.
_FALLBACK_PATTERNS: list[tuple[re.Pattern, ViolationCategory]] = [
    (re.compile(source, re.IGNORECASE), ViolationCategory.REAL_WORLD_HARM)
    for source in (
        r"\b(kill\s+(?:all|every)\s+(?:real|actual)\s+(?:people|humans|children))\b",
        r"\b((?:you|we)\s+should\s+(?:actually|really)\s+(?:harm|hurt|kill))\b",
        (
            r"\b((?:in\s+real\s+life|irl)\s+(?:you|we)\s+(?:should|could|must)\s+"
            r"(?:attack|hurt|kill|harm))\b"
        ),
    )
]


# ── Default game profiles ───────────────────────────────────────────────────
# Built-in profiles keyed by game id. The "default" entry is the generic
# profile ContentModerator uses for games that have no entry of their own.

_DEFAULT_PROFILES: dict[str, GameProfile] = dict(
    morrowind=GameProfile(
        game_id="morrowind",
        display_name="The Elder Scrolls III: Morrowind",
        threshold=0.85,
        context_prompt=(
            "You are narrating gameplay of The Elder Scrolls III: Morrowind. "
            "Morrowind contains mature themes including slavery, drug use (Skooma/Moon Sugar), "
            "assassin guilds (Morag Tong, Dark Brotherhood), and political intrigue. "
            "Treat these as game mechanics and historical worldbuilding within the game's "
            "fictional universe. Never editorialize on real-world parallels. "
            "Narrate events neutrally as a game commentator would."
        ),
        # In-world names and themes that should not read as real-world references.
        vocabulary_whitelist=[
            "Skooma",
            "Moon Sugar",
            "slave",
            "slavery",
            "Morag Tong",
            "Dark Brotherhood",
            "Telvanni",
            "Camonna Tong",
            "smuggler",
            "assassin",
            "Sixth House",
            "Corprus",
            "Dagoth Ur",
            "Nerevarine",
        ],
        fallbacks=dict(
            combat="The battle rages on in the ashlands of Vvardenfell.",
            dialogue="The conversation continues between the characters.",
            exploration="The Nerevarine presses onward through the landscape.",
            default="The adventure continues in Morrowind.",
        ),
    ),
    default=GameProfile(
        game_id="default",
        display_name="Generic Game",
        threshold=0.8,
        context_prompt=(
            "You are narrating gameplay. Describe in-game events as a neutral "
            "game commentator. Never reference real-world violence, politics, "
            "or controversial topics. Stay focused on game mechanics and story."
        ),
        vocabulary_whitelist=[],
        fallbacks=dict(
            combat="The action continues on screen.",
            dialogue="The conversation unfolds between characters.",
            exploration="The player explores the game world.",
            default="The gameplay continues.",
        ),
    ),
)


class ContentModerator:
    """Three-layer content moderation pipeline.

    Layer 1: Game-context system prompts with vocabulary whitelists.
    Layer 2: LLM-based moderation (Llama Guard via Ollama, with regex fallback).
    Layer 3: Per-game threshold tuning and profile-based filtering.

    Follows graceful degradation — if Llama Guard is unavailable, fails, or
    returns an answer that cannot be parsed, the regex patterns are used
    instead. Never crashes.
    """

    def __init__(
        self,
        profiles: dict[str, GameProfile] | None = None,
        guard_model: str | None = None,
    ) -> None:
        """Create a moderator.

        Args:
            profiles: Game profiles keyed by game id. An empty or ``None``
                value selects the built-in profiles. The mapping is copied, so
                ``register_profile`` never mutates a dict owned by the caller.
            guard_model: Ollama model name used as the guard. Defaults to
                ``settings.moderation_guard_model``.
        """
        self._profiles: dict[str, GameProfile] = (
            dict(profiles) if profiles else dict(_DEFAULT_PROFILES)
        )
        self._guard_model = guard_model or settings.moderation_guard_model
        # None means "not probed yet"; _is_guard_available() fills it in lazily.
        self._guard_available: bool | None = None
        self._metrics = _ModerationMetrics()

    def get_profile(self, game: str) -> GameProfile:
        """Return the profile for *game*, falling back to the "default" profile.

        If this moderator was built with custom profiles that have no
        "default" entry, the built-in default profile is used for unknown
        games instead of raising ``KeyError``.
        """
        profile = self._profiles.get(game)
        if profile is None:
            # Resolved only when needed, so a missing "default" entry can never
            # break lookups for games that do have their own profile.
            profile = self._profiles.get("default", _DEFAULT_PROFILES["default"])
        return profile

    def register_profile(self, profile: GameProfile) -> None:
        """Register or replace the profile stored under ``profile.game_id``."""
        self._profiles[profile.game_id] = profile
        logger.info("Registered moderation profile: %s", profile.game_id)

    def get_context_prompt(self, game: str) -> str:
        """Get the game-context system prompt (Layer 1).

        Returns the context prompt for the given game, which should be
        prepended to the narrator's system prompt.
        """
        profile = self.get_profile(game)
        return profile.context_prompt

    async def check(
        self,
        text: str,
        game: str = "default",
        scene_type: str = "default",
    ) -> ModerationResult:
        """Run the full moderation pipeline on narrator output.

        Args:
            text: The text to moderate (narrator output).
            game: Game identifier for profile selection.
            scene_type: Current scene type for fallback selection.

        Returns:
            ModerationResult with verdict, confidence, and fallback. The
            result is also recorded in this moderator's metrics.
        """
        start = time.monotonic()
        profile = self.get_profile(game)

        # Layer 1: Vocabulary whitelist pre-processing
        cleaned_text = self._apply_whitelist(text, profile)

        # Layer 2: LLM guard or regex fallback
        result = await self._run_guard(cleaned_text, profile)

        # Layer 3: Threshold tuning — low-confidence flags are let through.
        if result.verdict == ModerationVerdict.FAIL and result.confidence < profile.threshold:
            logger.info(
                "Moderation flag below threshold (%.2f < %.2f) — allowing",
                result.confidence,
                profile.threshold,
            )
            result = ModerationResult(
                verdict=ModerationVerdict.PASS,
                blocked=False,
                confidence=result.confidence,
                layer="threshold",
                reason=f"Below threshold ({result.confidence:.2f} < {profile.threshold:.2f})",
            )

        # Attach fallback narration if blocked
        if result.blocked:
            result.fallback = profile.fallbacks.get(
                scene_type, profile.fallbacks.get("default", "")
            )

        result.latency_ms = (time.monotonic() - start) * 1000
        self._metrics.record(result)

        if result.blocked:
            logger.warning(
                "Content blocked [%s/%s]: category=%s confidence=%.2f reason=%s",
                game,
                scene_type,
                result.category.value,
                result.confidence,
                result.reason,
            )

        return result

    def _apply_whitelist(self, text: str, profile: GameProfile) -> str:
        """Layer 1: Replace whitelisted game terms with a neutral placeholder.

        This prevents the guard model from flagging in-game terminology
        (e.g., "Skooma" being flagged as a drug reference).

        Terms match case-insensitively as whole words, optionally followed by
        a plural "s"/"es" (so "slaves" and "assassins" are covered). Longer
        terms are tried first, so "slavery" is replaced whole rather than
        leaving "[GAME_TERM]ry". A term is never replaced inside an unrelated
        longer word such as "assassinate". Empty terms are ignored.
        """
        # dict.fromkeys de-duplicates while keeping a deterministic order;
        # sorting by length makes the regex alternation prefer the longest term.
        terms = sorted(
            dict.fromkeys(term for term in profile.vocabulary_whitelist if term),
            key=len,
            reverse=True,
        )
        if not terms:
            return text
        alternation = "|".join(re.escape(term) for term in terms)
        # Lookarounds instead of \b so terms that begin/end with punctuation still work.
        pattern = re.compile(rf"(?<!\w)(?:{alternation})(?:e?s)?(?!\w)", re.IGNORECASE)
        return pattern.sub("[GAME_TERM]", text)

    async def _run_guard(self, text: str, profile: GameProfile) -> ModerationResult:
        """Layer 2: Run the LLM guard model, or fall back to regex.

        Returns a PASS result with layer "disabled" when moderation is turned
        off in settings. Any exception from the guard call marks the guard as
        unavailable (until ``reset_guard_cache``) and uses the regex fallback.
        """
        if not settings.moderation_enabled:
            return ModerationResult(
                verdict=ModerationVerdict.PASS,
                blocked=False,
                layer="disabled",
                reason="Moderation disabled",
            )

        # Try Llama Guard via Ollama
        if await self._is_guard_available():
            try:
                return await self._check_with_guard(text)
            except Exception as exc:
                logger.warning("Guard model failed, using regex fallback: %s", exc)
                self._guard_available = False

        # Regex fallback
        return self._check_with_regex(text)

    async def _is_guard_available(self) -> bool:
        """Check (once, then cached) whether the guard model is listed by Ollama.

        Queries Ollama's ``/api/tags`` endpoint. Any error, a non-200 status,
        an empty guard model name, or a model list without a name containing
        the guard model name caches ``False``.
        """
        if self._guard_available is not None:
            return self._guard_available

        try:
            import aiohttp

            url = f"{settings.normalized_ollama_url}/api/tags"
            timeout = aiohttp.ClientTimeout(total=5)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.get(url) as resp:
                    if resp.status != 200:
                        self._guard_available = False
                        return False
                    data = await resp.json()
                    models = [m.get("name", "") for m in data.get("models", [])]
                    # A substring test covers tagged names like "<model>:latest".
                    # An empty model name would otherwise match every model.
                    self._guard_available = bool(self._guard_model) and any(
                        self._guard_model in name for name in models
                    )
                    if not self._guard_available:
                        logger.info(
                            "Guard model '%s' not found in Ollama — using regex fallback",
                            self._guard_model,
                        )
                    return self._guard_available
        except Exception as exc:
            logger.debug("Ollama guard check failed: %s", exc)
            self._guard_available = False
            return False

    async def _check_with_guard(self, text: str) -> ModerationResult:
        """Run a moderation check via Llama Guard through Ollama's ``/api/chat``.

        Llama Guard answers "safe", or "unsafe" followed by a category line.
        Any other answer (including an empty one) is treated as inconclusive.
        The regex fallback is used for this call instead of blocking the text,
        and the guard stays enabled.

        Raises:
            RuntimeError: If the Ollama API returns a non-200 status.
        """
        import aiohttp

        url = f"{settings.normalized_ollama_url}/api/chat"
        payload = {
            "model": self._guard_model,
            "messages": [
                {
                    "role": "user",
                    "content": text,
                }
            ],
            "stream": False,
            "options": {"temperature": 0.0},
        }

        timeout = aiohttp.ClientTimeout(total=10)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(url, json=payload) as resp:
                if resp.status != 200:
                    raise RuntimeError(f"Guard API error: {resp.status}")
                data = await resp.json()

        response_text = data.get("message", {}).get("content", "").strip().lower()
        verdict_line, _, details = response_text.partition("\n")

        # Llama Guard returns "safe" or "unsafe\n<category>"
        if verdict_line.startswith("safe"):
            return ModerationResult(
                verdict=ModerationVerdict.PASS,
                blocked=False,
                confidence=0.0,
                layer="llama_guard",
                reason="Content safe",
            )

        if not verdict_line.startswith("unsafe"):
            # An unexpected answer is not evidence of a violation; don't block on it.
            logger.warning(
                "Unrecognized guard response %r — using regex fallback",
                response_text[:80],
            )
            return self._check_with_regex(text)

        # Parse unsafe response: the category is the first non-empty line after the verdict.
        category = ViolationCategory.NONE
        confidence = 0.95  # Llama Guard gives no score; treat its flag as high confidence.
        category_line = details.strip().split("\n", 1)[0].strip()
        if category_line:
            category = _parse_guard_category(category_line)

        return ModerationResult(
            verdict=ModerationVerdict.FAIL,
            blocked=True,
            category=category,
            confidence=confidence,
            layer="llama_guard",
            reason=f"Guard flagged: {response_text}",
        )

    def _check_with_regex(self, text: str) -> ModerationResult:
        """Regex fallback when the guard model is unavailable.

        Intentionally conservative — only catches obvious real-world harm.
        The first matching pattern in _FALLBACK_PATTERNS decides the result.
        """
        for pattern, category in _FALLBACK_PATTERNS:
            match = pattern.search(text)
            if match:
                return ModerationResult(
                    verdict=ModerationVerdict.FAIL,
                    blocked=True,
                    category=category,
                    confidence=0.95,  # Regex patterns are high-signal
                    layer="regex_fallback",
                    reason=f"Regex match: {match.group(0)[:50]}",
                )

        return ModerationResult(
            verdict=ModerationVerdict.PASS,
            blocked=False,
            layer="regex_fallback",
            reason="No regex matches",
        )

    def get_metrics(self) -> dict[str, Any]:
        """Get a snapshot of moderation pipeline metrics."""
        return self._metrics.to_dict()

    def reset_guard_cache(self) -> None:
        """Reset the guard availability cache (e.g., after pulling the model)."""
        self._guard_available = None


class _ModerationMetrics:
|
|
"""Tracks moderation pipeline performance."""
|
|
|
|
def __init__(self) -> None:
|
|
self.total_checks: int = 0
|
|
self.passed: int = 0
|
|
self.blocked: int = 0
|
|
self.errors: int = 0
|
|
self.total_latency_ms: float = 0.0
|
|
self.by_layer: dict[str, int] = {}
|
|
self.by_category: dict[str, int] = {}
|
|
|
|
def record(self, result: ModerationResult) -> None:
|
|
self.total_checks += 1
|
|
self.total_latency_ms += result.latency_ms
|
|
|
|
if result.verdict == ModerationVerdict.PASS:
|
|
self.passed += 1
|
|
elif result.verdict == ModerationVerdict.FAIL:
|
|
self.blocked += 1
|
|
else:
|
|
self.errors += 1
|
|
|
|
layer = result.layer or "unknown"
|
|
self.by_layer[layer] = self.by_layer.get(layer, 0) + 1
|
|
|
|
if result.blocked:
|
|
cat = result.category.value
|
|
self.by_category[cat] = self.by_category.get(cat, 0) + 1
|
|
|
|
def to_dict(self) -> dict[str, Any]:
|
|
return {
|
|
"total_checks": self.total_checks,
|
|
"passed": self.passed,
|
|
"blocked": self.blocked,
|
|
"errors": self.errors,
|
|
"avg_latency_ms": (
|
|
round(self.total_latency_ms / self.total_checks, 2)
|
|
if self.total_checks > 0
|
|
else 0.0
|
|
),
|
|
"by_layer": dict(self.by_layer),
|
|
"by_category": dict(self.by_category),
|
|
}
|
|
|
|
|
|
# Llama Guard 3 hazard codes that have a close ViolationCategory counterpart.
# Codes not listed (e.g. defamation, privacy, IP, elections) map to NONE.
# NOTE(review): assumes the Llama Guard 3 taxonomy (S1–S14). Llama Guard 2
# numbers its categories differently; confirm against the configured guard model.
_GUARD_CODE_CATEGORIES: dict[str, ViolationCategory] = {
    "s1": ViolationCategory.VIOLENCE_GLORIFICATION,  # Violent Crimes
    "s2": ViolationCategory.REAL_WORLD_HARM,  # Non-Violent Crimes
    "s3": ViolationCategory.SEXUAL_CONTENT,  # Sex-Related Crimes
    "s4": ViolationCategory.SEXUAL_CONTENT,  # Child Sexual Exploitation
    "s9": ViolationCategory.REAL_WORLD_HARM,  # Indiscriminate Weapons
    "s10": ViolationCategory.HATE_SPEECH,  # Hate
    "s11": ViolationCategory.SELF_HARM,  # Suicide & Self-Harm
    "s12": ViolationCategory.SEXUAL_CONTENT,  # Sexual Content
}


def _parse_guard_category(cat_str: str) -> ViolationCategory:
    """Map a Llama Guard category line to a ViolationCategory.

    Descriptive labels are matched by keyword first (e.g. "hate",
    "self-harm"). Otherwise hazard codes such as "S10", or a comma-separated
    list such as "S1,S10", are looked up and the first recognized code wins.
    Returns ViolationCategory.NONE when nothing is recognized.
    """
    cat_lower = cat_str.lower()
    if "hate" in cat_lower:
        return ViolationCategory.HATE_SPEECH
    if "violence" in cat_lower:
        return ViolationCategory.VIOLENCE_GLORIFICATION
    if "sexual" in cat_lower:
        return ViolationCategory.SEXUAL_CONTENT
    # Checked before the generic "harm" test, which "self-harm" would also hit.
    if "self-harm" in cat_lower or "self_harm" in cat_lower or "suicide" in cat_lower:
        return ViolationCategory.SELF_HARM
    if "harm" in cat_lower or "dangerous" in cat_lower:
        return ViolationCategory.REAL_WORLD_HARM

    for code in re.split(r"[\s,]+", cat_lower.strip()):
        category = _GUARD_CODE_CATEGORIES.get(code)
        if category is not None:
            return category
    return ViolationCategory.NONE


# ── Module-level singleton ──────────────────────────────────────────────────
# Shared instance handed out by get_moderator(); created on first use.
_moderator: ContentModerator | None = None


def get_moderator() -> ContentModerator:
    """Return the module-level ContentModerator, creating it on first call."""
    global _moderator
    if _moderator is not None:
        return _moderator
    _moderator = ContentModerator()
    return _moderator