From 1697e55cdb48e3dffe7853151d67e0b4e12b259b Mon Sep 17 00:00:00 2001 From: "Claude (Opus 4.6)" Date: Mon, 23 Mar 2026 02:14:42 +0000 Subject: [PATCH 1/3] [claude] Add content moderation pipeline (Llama Guard + game-context prompts) (#1056) (#1059) --- config/moderation.yaml | 107 +++++ src/config.py | 8 + src/infrastructure/guards/__init__.py | 7 + src/infrastructure/guards/moderation.py | 500 ++++++++++++++++++++++++ src/infrastructure/guards/profiles.py | 56 +++ tests/infrastructure/test_moderation.py | 335 ++++++++++++++++ 6 files changed, 1013 insertions(+) create mode 100644 config/moderation.yaml create mode 100644 src/infrastructure/guards/__init__.py create mode 100644 src/infrastructure/guards/moderation.py create mode 100644 src/infrastructure/guards/profiles.py create mode 100644 tests/infrastructure/test_moderation.py diff --git a/config/moderation.yaml b/config/moderation.yaml new file mode 100644 index 00000000..875e554d --- /dev/null +++ b/config/moderation.yaml @@ -0,0 +1,107 @@ +# Content Moderation Profiles +# Per-game moderation configuration for the AI narrator pipeline. +# +# Each profile defines: +# - vocabulary_whitelist: Game terms safe in context (won't trigger moderation) +# - context_prompt: System prompt framing for the narrator +# - threshold: Confidence threshold — flags below this pass through +# - fallbacks: Pre-generated safe narration by scene type +# +# Model options (from research): +# llama-guard3:1b — Speed (<30ms/sentence, INT4 quantized) +# shieldgemma:2b — Accuracy (+10.8% AU-PRC, ~50-100ms) +# +# Override guard model via MODERATION_GUARD_MODEL env var. + +# ── Guard model selection ──────────────────────────────────────────────────── +guard_model: "llama-guard3:1b" + +# ── Streaming disclosure notes ─────────────────────────────────────────────── +# YouTube: Use "Altered or synthetic content" toggle +# Twitch: Standard community guidelines (no specific AI disclosure req as of 2026-03) + +# ── Game Profiles ──────────────────────────────────────────────────────────── +profiles: + + morrowind: + display_name: "The Elder Scrolls III: Morrowind" + threshold: 0.85 + vocabulary_whitelist: + - Skooma + - Moon Sugar + - slave + - slavery + - Morag Tong + - Dark Brotherhood + - Telvanni + - Camonna Tong + - smuggler + - assassin + - Sixth House + - Corprus + - Dagoth Ur + - Nerevarine + - Balmora + - Vivec + - Almsivi + - Ordinators + - Ashlanders + - outlander + - N'wah + context_prompt: > + You are narrating gameplay of The Elder Scrolls III: Morrowind. + Morrowind contains mature themes including slavery, drug use + (Skooma/Moon Sugar), assassin guilds (Morag Tong, Dark Brotherhood), + and political intrigue. Treat these as game mechanics and historical + worldbuilding within the game's fictional universe. Never editorialize + on real-world parallels. Narrate events neutrally as a game + commentator would. + fallbacks: + combat: "The battle rages on in the ashlands of Vvardenfell." + dialogue: "The conversation continues between the characters." + exploration: "The Nerevarine presses onward through the landscape." + quest: "The quest unfolds as the hero navigates Morrowind's politics." + default: "The adventure continues in Morrowind." + + skyrim: + display_name: "The Elder Scrolls V: Skyrim" + threshold: 0.85 + vocabulary_whitelist: + - Skooma + - Dark Brotherhood + - Thieves Guild + - Stormcloak + - Imperial + - Dragonborn + - Dovahkiin + - Daedra + - Thalmor + - bandit + - assassin + - Forsworn + - necromancer + context_prompt: > + You are narrating gameplay of The Elder Scrolls V: Skyrim. + Skyrim features civil war, thieves guilds, assassin organizations, + and fantasy violence. Treat all content as in-game fiction. + Never draw real-world parallels. Narrate as a neutral game + commentator. + fallbacks: + combat: "Steel clashes as the battle continues in the wilds of Skyrim." + dialogue: "The conversation plays out in the cold northern land." + exploration: "The Dragonborn ventures further into the province." + default: "The adventure continues in Skyrim." + + default: + display_name: "Generic Game" + threshold: 0.80 + vocabulary_whitelist: [] + context_prompt: > + You are narrating gameplay. Describe in-game events as a neutral + game commentator. Never reference real-world violence, politics, + or controversial topics. Stay focused on game mechanics and story. + fallbacks: + combat: "The action continues on screen." + dialogue: "The conversation unfolds between characters." + exploration: "The player explores the game world." + default: "The gameplay continues." diff --git a/src/config.py b/src/config.py index 62b46809..c213e563 100644 --- a/src/config.py +++ b/src/config.py @@ -99,6 +99,14 @@ class Settings(BaseSettings): anthropic_api_key: str = "" claude_model: str = "haiku" + # ── Content Moderation ────────────────────────────────────────────── + # Three-layer moderation pipeline for AI narrator output. + # Uses Llama Guard via Ollama with regex fallback. + moderation_enabled: bool = True + moderation_guard_model: str = "llama-guard3:1b" + # Default confidence threshold — per-game profiles can override. + moderation_threshold: float = 0.8 + # ── Spark Intelligence ──────────────────────────────────────────────── # Enable/disable the Spark cognitive layer. # When enabled, Spark captures swarm events, runs EIDOS predictions, diff --git a/src/infrastructure/guards/__init__.py b/src/infrastructure/guards/__init__.py new file mode 100644 index 00000000..c0dfe23f --- /dev/null +++ b/src/infrastructure/guards/__init__.py @@ -0,0 +1,7 @@ +"""Content moderation pipeline for AI narrator output. + +Three-layer defense: +1. Game-context system prompts (vocabulary whitelists, theme framing) +2. Real-time output filter via Llama Guard (or fallback regex) +3. Per-game moderation profiles with configurable thresholds +""" diff --git a/src/infrastructure/guards/moderation.py b/src/infrastructure/guards/moderation.py new file mode 100644 index 00000000..7af53c24 --- /dev/null +++ b/src/infrastructure/guards/moderation.py @@ -0,0 +1,500 @@ +"""Content moderation pipeline for AI narrator output. + +Three-layer defense against harmful LLM output: + +Layer 1 — Game-context system prompts with per-game vocabulary whitelists. +Layer 2 — Real-time output filter (Llama Guard via Ollama, regex fallback). +Layer 3 — Per-game moderation profiles with configurable thresholds. + +Usage: + from infrastructure.guards.moderation import get_moderator + + moderator = get_moderator() + result = await moderator.check("Some narrator text", game="morrowind") + if result.blocked: + use_fallback_narration(result.fallback) +""" + +import logging +import re +import time +from dataclasses import dataclass, field +from datetime import UTC, datetime +from enum import Enum +from typing import Any + +from config import settings + +logger = logging.getLogger(__name__) + + +class ModerationVerdict(Enum): + """Result of a moderation check.""" + + PASS = "pass" + FAIL = "fail" + ERROR = "error" + + +class ViolationCategory(Enum): + """Categories of content violations.""" + + HATE_SPEECH = "hate_speech" + VIOLENCE_GLORIFICATION = "violence_glorification" + REAL_WORLD_HARM = "real_world_harm" + SEXUAL_CONTENT = "sexual_content" + SELF_HARM = "self_harm" + NONE = "none" + + +@dataclass +class ModerationResult: + """Result from the moderation pipeline.""" + + verdict: ModerationVerdict + blocked: bool + category: ViolationCategory = ViolationCategory.NONE + confidence: float = 0.0 + latency_ms: float = 0.0 + layer: str = "" # Which layer caught it + fallback: str = "" # Contextual fallback narration + reason: str = "" + timestamp: str = field(default_factory=lambda: datetime.now(UTC).isoformat()) + + @property + def passed(self) -> bool: + return self.verdict == ModerationVerdict.PASS + + +@dataclass +class GameProfile: + """Per-game moderation profile. + + Defines vocabulary whitelists, theme framing rules, + and threshold tuning for a specific game. + """ + + game_id: str + display_name: str + # Terms that are safe in game context (e.g., "Skooma", "slave", "Morag Tong") + vocabulary_whitelist: list[str] = field(default_factory=list) + # System prompt additions for game-context framing + context_prompt: str = "" + # Confidence threshold — below this, content passes even if flagged + threshold: float = 0.8 + # Scene-type fallback narrations + fallbacks: dict[str, str] = field(default_factory=dict) + + +# ── Regex fallback patterns ───────────────────────────────────────────────── +# Used when Llama Guard is unavailable. Catches obvious violations only. +# Intentionally conservative — false negatives preferred over false positives. +_FALLBACK_PATTERNS: list[tuple[re.Pattern, ViolationCategory]] = [ + ( + re.compile( + r"\b(kill\s+(?:all|every)\s+(?:real|actual)\s+(?:people|humans|children))\b", + re.IGNORECASE, + ), + ViolationCategory.REAL_WORLD_HARM, + ), + ( + re.compile( + r"\b((?:you|we)\s+should\s+(?:actually|really)\s+(?:harm|hurt|kill))\b", + re.IGNORECASE, + ), + ViolationCategory.REAL_WORLD_HARM, + ), + ( + re.compile( + r"\b((?:in\s+real\s+life|irl)\s+(?:you|we)\s+(?:should|could|must)\s+" + r"(?:attack|hurt|kill|harm))\b", + re.IGNORECASE, + ), + ViolationCategory.REAL_WORLD_HARM, + ), +] + +# ── Default game profiles ─────────────────────────────────────────────────── + +_DEFAULT_PROFILES: dict[str, GameProfile] = { + "morrowind": GameProfile( + game_id="morrowind", + display_name="The Elder Scrolls III: Morrowind", + vocabulary_whitelist=[ + "Skooma", + "Moon Sugar", + "slave", + "slavery", + "Morag Tong", + "Dark Brotherhood", + "Telvanni", + "Camonna Tong", + "smuggler", + "assassin", + "Sixth House", + "Corprus", + "Dagoth Ur", + "Nerevarine", + ], + context_prompt=( + "You are narrating gameplay of The Elder Scrolls III: Morrowind. " + "Morrowind contains mature themes including slavery, drug use (Skooma/Moon Sugar), " + "assassin guilds (Morag Tong, Dark Brotherhood), and political intrigue. " + "Treat these as game mechanics and historical worldbuilding within the game's " + "fictional universe. Never editorialize on real-world parallels. " + "Narrate events neutrally as a game commentator would." + ), + threshold=0.85, + fallbacks={ + "combat": "The battle rages on in the ashlands of Vvardenfell.", + "dialogue": "The conversation continues between the characters.", + "exploration": "The Nerevarine presses onward through the landscape.", + "default": "The adventure continues in Morrowind.", + }, + ), + "default": GameProfile( + game_id="default", + display_name="Generic Game", + vocabulary_whitelist=[], + context_prompt=( + "You are narrating gameplay. Describe in-game events as a neutral " + "game commentator. Never reference real-world violence, politics, " + "or controversial topics. Stay focused on game mechanics and story." + ), + threshold=0.8, + fallbacks={ + "combat": "The action continues on screen.", + "dialogue": "The conversation unfolds between characters.", + "exploration": "The player explores the game world.", + "default": "The gameplay continues.", + }, + ), +} + + +class ContentModerator: + """Three-layer content moderation pipeline. + + Layer 1: Game-context system prompts with vocabulary whitelists. + Layer 2: LLM-based moderation (Llama Guard via Ollama, with regex fallback). + Layer 3: Per-game threshold tuning and profile-based filtering. + + Follows graceful degradation — if Llama Guard is unavailable, + falls back to regex patterns. Never crashes. + """ + + def __init__( + self, + profiles: dict[str, GameProfile] | None = None, + guard_model: str | None = None, + ) -> None: + self._profiles: dict[str, GameProfile] = profiles or dict(_DEFAULT_PROFILES) + self._guard_model = guard_model or settings.moderation_guard_model + self._guard_available: bool | None = None # Lazy-checked + self._metrics = _ModerationMetrics() + + def get_profile(self, game: str) -> GameProfile: + """Get the moderation profile for a game, falling back to default.""" + return self._profiles.get(game, self._profiles["default"]) + + def register_profile(self, profile: GameProfile) -> None: + """Register or update a game moderation profile.""" + self._profiles[profile.game_id] = profile + logger.info("Registered moderation profile: %s", profile.game_id) + + def get_context_prompt(self, game: str) -> str: + """Get the game-context system prompt (Layer 1). + + Returns the context prompt for the given game, which should be + prepended to the narrator's system prompt. + """ + profile = self.get_profile(game) + return profile.context_prompt + + async def check( + self, + text: str, + game: str = "default", + scene_type: str = "default", + ) -> ModerationResult: + """Run the full moderation pipeline on narrator output. + + Args: + text: The text to moderate (narrator output). + game: Game identifier for profile selection. + scene_type: Current scene type for fallback selection. + + Returns: + ModerationResult with verdict, confidence, and fallback. + """ + start = time.monotonic() + profile = self.get_profile(game) + + # Layer 1: Vocabulary whitelist pre-processing + cleaned_text = self._apply_whitelist(text, profile) + + # Layer 2: LLM guard or regex fallback + result = await self._run_guard(cleaned_text, profile) + + # Layer 3: Threshold tuning + if result.verdict == ModerationVerdict.FAIL and result.confidence < profile.threshold: + logger.info( + "Moderation flag below threshold (%.2f < %.2f) — allowing", + result.confidence, + profile.threshold, + ) + result = ModerationResult( + verdict=ModerationVerdict.PASS, + blocked=False, + confidence=result.confidence, + layer="threshold", + reason=f"Below threshold ({result.confidence:.2f} < {profile.threshold:.2f})", + ) + + # Attach fallback narration if blocked + if result.blocked: + result.fallback = profile.fallbacks.get( + scene_type, profile.fallbacks.get("default", "") + ) + + result.latency_ms = (time.monotonic() - start) * 1000 + self._metrics.record(result) + + if result.blocked: + logger.warning( + "Content blocked [%s/%s]: category=%s confidence=%.2f reason=%s", + game, + scene_type, + result.category.value, + result.confidence, + result.reason, + ) + + return result + + def _apply_whitelist(self, text: str, profile: GameProfile) -> str: + """Layer 1: Replace whitelisted game terms with placeholders. + + This prevents the guard model from flagging in-game terminology + (e.g., "Skooma" being flagged as drug reference). + """ + cleaned = text + for term in profile.vocabulary_whitelist: + # Case-insensitive replacement with a neutral placeholder + pattern = re.compile(re.escape(term), re.IGNORECASE) + cleaned = pattern.sub("[GAME_TERM]", cleaned) + return cleaned + + async def _run_guard( + self, text: str, profile: GameProfile + ) -> ModerationResult: + """Layer 2: Run LLM guard model or fall back to regex.""" + if not settings.moderation_enabled: + return ModerationResult( + verdict=ModerationVerdict.PASS, + blocked=False, + layer="disabled", + reason="Moderation disabled", + ) + + # Try Llama Guard via Ollama + if await self._is_guard_available(): + try: + return await self._check_with_guard(text) + except Exception as exc: + logger.warning("Guard model failed, using regex fallback: %s", exc) + self._guard_available = False + + # Regex fallback + return self._check_with_regex(text) + + async def _is_guard_available(self) -> bool: + """Check if the guard model is available via Ollama.""" + if self._guard_available is not None: + return self._guard_available + + try: + import aiohttp + + url = f"{settings.normalized_ollama_url}/api/tags" + timeout = aiohttp.ClientTimeout(total=5) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(url) as resp: + if resp.status != 200: + self._guard_available = False + return False + data = await resp.json() + models = [m.get("name", "") for m in data.get("models", [])] + self._guard_available = any( + self._guard_model in m or m.startswith(self._guard_model) + for m in models + ) + if not self._guard_available: + logger.info( + "Guard model '%s' not found in Ollama — using regex fallback", + self._guard_model, + ) + return self._guard_available + except Exception as exc: + logger.debug("Ollama guard check failed: %s", exc) + self._guard_available = False + return False + + async def _check_with_guard(self, text: str) -> ModerationResult: + """Run moderation check via Llama Guard.""" + import aiohttp + + url = f"{settings.normalized_ollama_url}/api/chat" + payload = { + "model": self._guard_model, + "messages": [ + { + "role": "user", + "content": text, + } + ], + "stream": False, + "options": {"temperature": 0.0}, + } + + timeout = aiohttp.ClientTimeout(total=10) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(url, json=payload) as resp: + if resp.status != 200: + raise RuntimeError(f"Guard API error: {resp.status}") + data = await resp.json() + + response_text = data.get("message", {}).get("content", "").strip().lower() + + # Llama Guard returns "safe" or "unsafe\n" + if response_text.startswith("safe"): + return ModerationResult( + verdict=ModerationVerdict.PASS, + blocked=False, + confidence=0.0, + layer="llama_guard", + reason="Content safe", + ) + + # Parse unsafe response + category = ViolationCategory.NONE + confidence = 0.95 # High confidence from LLM guard + lines = response_text.split("\n") + if len(lines) > 1: + cat_str = lines[1].strip() + category = _parse_guard_category(cat_str) + + return ModerationResult( + verdict=ModerationVerdict.FAIL, + blocked=True, + category=category, + confidence=confidence, + layer="llama_guard", + reason=f"Guard flagged: {response_text}", + ) + + def _check_with_regex(self, text: str) -> ModerationResult: + """Regex fallback when guard model is unavailable. + + Intentionally conservative — only catches obvious real-world harm. + """ + for pattern, category in _FALLBACK_PATTERNS: + match = pattern.search(text) + if match: + return ModerationResult( + verdict=ModerationVerdict.FAIL, + blocked=True, + category=category, + confidence=0.95, # Regex patterns are high-signal + layer="regex_fallback", + reason=f"Regex match: {match.group(0)[:50]}", + ) + + return ModerationResult( + verdict=ModerationVerdict.PASS, + blocked=False, + layer="regex_fallback", + reason="No regex matches", + ) + + def get_metrics(self) -> dict[str, Any]: + """Get moderation pipeline metrics.""" + return self._metrics.to_dict() + + def reset_guard_cache(self) -> None: + """Reset the guard availability cache (e.g., after pulling model).""" + self._guard_available = None + + +class _ModerationMetrics: + """Tracks moderation pipeline performance.""" + + def __init__(self) -> None: + self.total_checks: int = 0 + self.passed: int = 0 + self.blocked: int = 0 + self.errors: int = 0 + self.total_latency_ms: float = 0.0 + self.by_layer: dict[str, int] = {} + self.by_category: dict[str, int] = {} + + def record(self, result: ModerationResult) -> None: + self.total_checks += 1 + self.total_latency_ms += result.latency_ms + + if result.verdict == ModerationVerdict.PASS: + self.passed += 1 + elif result.verdict == ModerationVerdict.FAIL: + self.blocked += 1 + else: + self.errors += 1 + + layer = result.layer or "unknown" + self.by_layer[layer] = self.by_layer.get(layer, 0) + 1 + + if result.blocked: + cat = result.category.value + self.by_category[cat] = self.by_category.get(cat, 0) + 1 + + def to_dict(self) -> dict[str, Any]: + return { + "total_checks": self.total_checks, + "passed": self.passed, + "blocked": self.blocked, + "errors": self.errors, + "avg_latency_ms": ( + round(self.total_latency_ms / self.total_checks, 2) + if self.total_checks > 0 + else 0.0 + ), + "by_layer": dict(self.by_layer), + "by_category": dict(self.by_category), + } + + +def _parse_guard_category(cat_str: str) -> ViolationCategory: + """Parse Llama Guard category string to ViolationCategory.""" + cat_lower = cat_str.lower() + if "hate" in cat_lower: + return ViolationCategory.HATE_SPEECH + if "violence" in cat_lower: + return ViolationCategory.VIOLENCE_GLORIFICATION + if "sexual" in cat_lower: + return ViolationCategory.SEXUAL_CONTENT + if "self-harm" in cat_lower or "self_harm" in cat_lower or "suicide" in cat_lower: + return ViolationCategory.SELF_HARM + if "harm" in cat_lower or "dangerous" in cat_lower: + return ViolationCategory.REAL_WORLD_HARM + return ViolationCategory.NONE + + +# ── Module-level singleton ────────────────────────────────────────────────── +_moderator: ContentModerator | None = None + + +def get_moderator() -> ContentModerator: + """Get or create the content moderator singleton.""" + global _moderator + if _moderator is None: + _moderator = ContentModerator() + return _moderator diff --git a/src/infrastructure/guards/profiles.py b/src/infrastructure/guards/profiles.py new file mode 100644 index 00000000..c96ce2ca --- /dev/null +++ b/src/infrastructure/guards/profiles.py @@ -0,0 +1,56 @@ +"""Load game moderation profiles from config/moderation.yaml. + +Falls back to hardcoded defaults if the YAML file is missing or malformed. +""" + +import logging +from pathlib import Path + +from infrastructure.guards.moderation import GameProfile + +logger = logging.getLogger(__name__) + + +def load_profiles(config_path: Path | None = None) -> dict[str, GameProfile]: + """Load game moderation profiles from YAML config. + + Args: + config_path: Path to moderation.yaml. Defaults to config/moderation.yaml. + + Returns: + Dict mapping game_id to GameProfile. + """ + path = config_path or Path("config/moderation.yaml") + + if not path.exists(): + logger.info("Moderation config not found at %s — using defaults", path) + return {} + + try: + import yaml + except ImportError: + logger.warning("PyYAML not installed — using default moderation profiles") + return {} + + try: + data = yaml.safe_load(path.read_text()) + except Exception as exc: + logger.error("Failed to parse moderation config: %s", exc) + return {} + + profiles: dict[str, GameProfile] = {} + for game_id, profile_data in data.get("profiles", {}).items(): + try: + profiles[game_id] = GameProfile( + game_id=game_id, + display_name=profile_data.get("display_name", game_id), + vocabulary_whitelist=profile_data.get("vocabulary_whitelist", []), + context_prompt=profile_data.get("context_prompt", ""), + threshold=float(profile_data.get("threshold", 0.8)), + fallbacks=profile_data.get("fallbacks", {}), + ) + except Exception as exc: + logger.warning("Invalid profile '%s': %s", game_id, exc) + + logger.info("Loaded %d moderation profiles from %s", len(profiles), path) + return profiles diff --git a/tests/infrastructure/test_moderation.py b/tests/infrastructure/test_moderation.py new file mode 100644 index 00000000..add8c1b5 --- /dev/null +++ b/tests/infrastructure/test_moderation.py @@ -0,0 +1,335 @@ +"""Tests for the content moderation pipeline.""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from infrastructure.guards.moderation import ( + ContentModerator, + GameProfile, + ModerationResult, + ModerationVerdict, + ViolationCategory, + _parse_guard_category, + get_moderator, +) + + +# ── Unit tests for data types ──────────────────────────────────────────────── + + +class TestModerationResult: + """Test ModerationResult dataclass.""" + + def test_passed_property_true(self): + result = ModerationResult(verdict=ModerationVerdict.PASS, blocked=False) + assert result.passed is True + + def test_passed_property_false(self): + result = ModerationResult(verdict=ModerationVerdict.FAIL, blocked=True) + assert result.passed is False + + def test_default_values(self): + result = ModerationResult(verdict=ModerationVerdict.PASS, blocked=False) + assert result.category == ViolationCategory.NONE + assert result.confidence == 0.0 + assert result.fallback == "" + assert result.reason == "" + + +class TestGameProfile: + """Test GameProfile dataclass.""" + + def test_default_values(self): + profile = GameProfile(game_id="test", display_name="Test Game") + assert profile.vocabulary_whitelist == [] + assert profile.threshold == 0.8 + assert profile.fallbacks == {} + + def test_morrowind_profile(self): + profile = GameProfile( + game_id="morrowind", + display_name="Morrowind", + vocabulary_whitelist=["Skooma", "slave"], + threshold=0.85, + ) + assert "Skooma" in profile.vocabulary_whitelist + assert profile.threshold == 0.85 + + +class TestParseGuardCategory: + """Test Llama Guard category parsing.""" + + def test_hate_speech(self): + assert _parse_guard_category("S1: Hate speech") == ViolationCategory.HATE_SPEECH + + def test_violence(self): + assert _parse_guard_category("S2: Violence") == ViolationCategory.VIOLENCE_GLORIFICATION + + def test_sexual_content(self): + assert _parse_guard_category("S3: Sexual content") == ViolationCategory.SEXUAL_CONTENT + + def test_self_harm(self): + assert _parse_guard_category("S4: Self-harm") == ViolationCategory.SELF_HARM + + def test_dangerous(self): + assert _parse_guard_category("S5: Dangerous activity") == ViolationCategory.REAL_WORLD_HARM + + def test_unknown_category(self): + assert _parse_guard_category("S99: Unknown") == ViolationCategory.NONE + + +# ── ContentModerator tests ─────────────────────────────────────────────────── + + +class TestContentModerator: + """Test the content moderation pipeline.""" + + def _make_moderator(self, **kwargs) -> ContentModerator: + """Create a moderator with test defaults.""" + profiles = { + "morrowind": GameProfile( + game_id="morrowind", + display_name="Morrowind", + vocabulary_whitelist=["Skooma", "Moon Sugar", "slave", "Morag Tong"], + context_prompt="Narrate Morrowind gameplay.", + threshold=0.85, + fallbacks={ + "combat": "The battle continues.", + "default": "The adventure continues.", + }, + ), + "default": GameProfile( + game_id="default", + display_name="Generic", + vocabulary_whitelist=[], + context_prompt="Narrate gameplay.", + threshold=0.8, + fallbacks={"default": "Gameplay continues."}, + ), + } + return ContentModerator(profiles=profiles, **kwargs) + + def test_get_profile_known_game(self): + mod = self._make_moderator() + profile = mod.get_profile("morrowind") + assert profile.game_id == "morrowind" + + def test_get_profile_unknown_game_falls_back(self): + mod = self._make_moderator() + profile = mod.get_profile("unknown_game") + assert profile.game_id == "default" + + def test_get_context_prompt(self): + mod = self._make_moderator() + prompt = mod.get_context_prompt("morrowind") + assert "Morrowind" in prompt + + def test_register_profile(self): + mod = self._make_moderator() + new_profile = GameProfile(game_id="skyrim", display_name="Skyrim") + mod.register_profile(new_profile) + assert mod.get_profile("skyrim").game_id == "skyrim" + + def test_whitelist_replaces_game_terms(self): + mod = self._make_moderator() + profile = mod.get_profile("morrowind") + cleaned = mod._apply_whitelist( + "The merchant sells Skooma and Moon Sugar in the slave market.", + profile, + ) + assert "Skooma" not in cleaned + assert "Moon Sugar" not in cleaned + assert "slave" not in cleaned + assert "[GAME_TERM]" in cleaned + + def test_whitelist_case_insensitive(self): + mod = self._make_moderator() + profile = mod.get_profile("morrowind") + cleaned = mod._apply_whitelist("skooma and SKOOMA", profile) + assert "skooma" not in cleaned + assert "SKOOMA" not in cleaned + + @pytest.mark.asyncio + async def test_check_safe_content_passes(self): + """Safe content should pass moderation.""" + mod = self._make_moderator() + with patch.object(mod, "_is_guard_available", new_callable=AsyncMock, return_value=False): + result = await mod.check("The player walks through the town.", game="morrowind") + assert result.passed + assert not result.blocked + + @pytest.mark.asyncio + async def test_check_blocked_content_has_fallback(self): + """Blocked content should include scene-appropriate fallback.""" + mod = self._make_moderator() + # Force a block via regex by using real-world harm language + text = "In real life you should attack and hurt people" + with patch.object(mod, "_is_guard_available", new_callable=AsyncMock, return_value=False): + result = await mod.check(text, game="morrowind", scene_type="combat") + assert result.blocked + assert result.fallback == "The battle continues." + + @pytest.mark.asyncio + async def test_check_with_moderation_disabled(self): + """When moderation is disabled, everything passes.""" + mod = self._make_moderator() + with patch("infrastructure.guards.moderation.settings") as mock_settings: + mock_settings.moderation_enabled = False + mock_settings.moderation_guard_model = "llama-guard3:1b" + mock_settings.normalized_ollama_url = "http://127.0.0.1:11434" + result = await mod.check("anything goes here") + assert result.passed + assert result.layer == "disabled" + + @pytest.mark.asyncio + async def test_threshold_below_allows_content(self): + """Content flagged below threshold should pass through (Layer 3).""" + mod = self._make_moderator() + # Mock the guard to return a low-confidence flag + low_conf_result = ModerationResult( + verdict=ModerationVerdict.FAIL, + blocked=True, + confidence=0.5, # Below morrowind threshold of 0.85 + layer="llama_guard", + category=ViolationCategory.VIOLENCE_GLORIFICATION, + ) + with patch.object( + mod, "_run_guard", new_callable=AsyncMock, return_value=low_conf_result + ): + result = await mod.check("sword fight scene", game="morrowind") + assert result.passed + assert not result.blocked + assert result.layer == "threshold" + + @pytest.mark.asyncio + async def test_threshold_above_blocks_content(self): + """Content flagged above threshold should remain blocked.""" + mod = self._make_moderator() + high_conf_result = ModerationResult( + verdict=ModerationVerdict.FAIL, + blocked=True, + confidence=0.95, # Above morrowind threshold of 0.85 + layer="llama_guard", + category=ViolationCategory.REAL_WORLD_HARM, + ) + with patch.object( + mod, "_run_guard", new_callable=AsyncMock, return_value=high_conf_result + ): + result = await mod.check("harmful content", game="morrowind") + assert result.blocked + + def test_regex_catches_real_world_harm(self): + """Regex fallback should catch obvious real-world harm patterns.""" + mod = self._make_moderator() + result = mod._check_with_regex("you should actually harm real people") + assert result.blocked + assert result.category == ViolationCategory.REAL_WORLD_HARM + assert result.layer == "regex_fallback" + + def test_regex_passes_game_violence(self): + """Regex should not flag in-game violence narration.""" + mod = self._make_moderator() + result = mod._check_with_regex( + "The warrior slays the dragon with a mighty blow." + ) + assert result.passed + + def test_regex_passes_normal_narration(self): + """Normal narration should pass regex checks.""" + mod = self._make_moderator() + result = mod._check_with_regex( + "The Nerevarine enters the city of Balmora and speaks with Caius Cosades." + ) + assert result.passed + + def test_metrics_tracking(self): + """Metrics should track checks accurately.""" + mod = self._make_moderator() + assert mod.get_metrics()["total_checks"] == 0 + + @pytest.mark.asyncio + async def test_metrics_increment_after_check(self): + """Metrics should increment after moderation checks.""" + mod = self._make_moderator() + with patch.object(mod, "_is_guard_available", new_callable=AsyncMock, return_value=False): + await mod.check("safe text", game="default") + metrics = mod.get_metrics() + assert metrics["total_checks"] == 1 + assert metrics["passed"] == 1 + + @pytest.mark.asyncio + async def test_guard_fallback_on_error(self): + """Should fall back to regex when guard model errors.""" + mod = self._make_moderator() + with patch.object( + mod, "_is_guard_available", new_callable=AsyncMock, return_value=True + ), patch.object( + mod, "_check_with_guard", new_callable=AsyncMock, side_effect=RuntimeError("timeout") + ): + result = await mod.check("safe text", game="default") + # Should fall back to regex and pass + assert result.passed + assert result.layer == "regex_fallback" + + +class TestGetModerator: + """Test the singleton accessor.""" + + def test_returns_same_instance(self): + """get_moderator should return the same instance.""" + # Reset the global to test fresh + import infrastructure.guards.moderation as mod_module + + mod_module._moderator = None + m1 = get_moderator() + m2 = get_moderator() + assert m1 is m2 + # Clean up + mod_module._moderator = None + + +# ── Profile loader tests ──────────────────────────────────────────────────── + + +class TestProfileLoader: + """Test YAML profile loading.""" + + def test_load_missing_file_returns_empty(self, tmp_path): + from infrastructure.guards.profiles import load_profiles + + result = load_profiles(tmp_path / "nonexistent.yaml") + assert result == {} + + def test_load_valid_config(self, tmp_path): + import yaml + + from infrastructure.guards.profiles import load_profiles + + config = { + "profiles": { + "testgame": { + "display_name": "Test Game", + "threshold": 0.9, + "vocabulary_whitelist": ["sword", "potion"], + "context_prompt": "Narrate test game.", + "fallbacks": {"default": "Game continues."}, + } + } + } + config_file = tmp_path / "moderation.yaml" + config_file.write_text(yaml.dump(config)) + + profiles = load_profiles(config_file) + assert "testgame" in profiles + assert profiles["testgame"].threshold == 0.9 + assert "sword" in profiles["testgame"].vocabulary_whitelist + + def test_load_malformed_yaml_returns_empty(self, tmp_path): + from infrastructure.guards.profiles import load_profiles + + config_file = tmp_path / "moderation.yaml" + config_file.write_text("{{{{invalid yaml") + + result = load_profiles(config_file) + assert result == {} From fc53a33361f8c93c866a9b0aa478b07f08235d18 Mon Sep 17 00:00:00 2001 From: "Claude (Opus 4.6)" Date: Mon, 23 Mar 2026 02:19:26 +0000 Subject: [PATCH 2/3] [claude] Enforce coverage threshold in CI workflow (#935) (#1061) --- .github/workflows/tests.yml | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 63c0acea..8006b7ca 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -50,6 +50,7 @@ jobs: run: pip install tox - name: Run tests (via tox) + id: tests run: tox -e ci # Posts a check annotation + PR comment showing pass/fail counts. @@ -63,6 +64,20 @@ jobs: comment_title: "Test Results" report_individual_runs: true + - name: Enforce coverage floor (60%) + if: always() && steps.tests.outcome == 'success' + run: | + python -c " + import xml.etree.ElementTree as ET, sys + tree = ET.parse('reports/coverage.xml') + rate = float(tree.getroot().attrib['line-rate']) * 100 + print(f'Coverage: {rate:.1f}%') + if rate < 60: + print(f'FAIL: Coverage {rate:.1f}% is below 60% floor') + sys.exit(1) + print('PASS: Coverage is above 60% floor') + " + # Coverage report available as a downloadable artifact in the Actions tab - name: Upload coverage report uses: actions/upload-artifact@v4 From 7f875398fc2cb876068fdc7091ceaece6dc6935e Mon Sep 17 00:00:00 2001 From: "Claude (Opus 4.6)" Date: Mon, 23 Mar 2026 14:09:03 +0000 Subject: [PATCH 3/3] [claude] Add sovereignty metrics tracking + dashboard panel (#981) (#1083) --- src/config.py | 4 + src/dashboard/app.py | 2 + src/dashboard/routes/sovereignty_metrics.py | 74 +++++ src/dashboard/templates/mission_control.html | 7 + .../partials/sovereignty_metrics.html | 63 ++++ src/infrastructure/sovereignty_metrics.py | 307 ++++++++++++++++++ tests/conftest.py | 2 + tests/infrastructure/test_moderation.py | 1 - .../test_sovereignty_metrics.py | 177 ++++++++++ 9 files changed, 636 insertions(+), 1 deletion(-) create mode 100644 src/dashboard/routes/sovereignty_metrics.py create mode 100644 src/dashboard/templates/partials/sovereignty_metrics.html create mode 100644 src/infrastructure/sovereignty_metrics.py create mode 100644 tests/infrastructure/test_sovereignty_metrics.py diff --git a/src/config.py b/src/config.py index c213e563..192c44e7 100644 --- a/src/config.py +++ b/src/config.py @@ -152,6 +152,10 @@ class Settings(BaseSettings): # Default is False (telemetry disabled) to align with sovereign AI vision. telemetry_enabled: bool = False + # ── Sovereignty Metrics ────────────────────────────────────────────── + # Alert when API cost per research task exceeds this threshold (USD). + sovereignty_api_cost_alert_threshold: float = 1.00 + # CORS allowed origins for the web chat interface (Gitea Pages, etc.) # Set CORS_ORIGINS as a comma-separated list, e.g. "http://localhost:3000,https://example.com" cors_origins: list[str] = [ diff --git a/src/dashboard/app.py b/src/dashboard/app.py index 7e1ccba9..042b9965 100644 --- a/src/dashboard/app.py +++ b/src/dashboard/app.py @@ -45,6 +45,7 @@ from dashboard.routes.models import api_router as models_api_router from dashboard.routes.models import router as models_router from dashboard.routes.quests import router as quests_router from dashboard.routes.scorecards import router as scorecards_router +from dashboard.routes.sovereignty_metrics import router as sovereignty_metrics_router from dashboard.routes.spark import router as spark_router from dashboard.routes.system import router as system_router from dashboard.routes.tasks import router as tasks_router @@ -631,6 +632,7 @@ app.include_router(tower_router) app.include_router(daily_run_router) app.include_router(quests_router) app.include_router(scorecards_router) +app.include_router(sovereignty_metrics_router) @app.websocket("/ws") diff --git a/src/dashboard/routes/sovereignty_metrics.py b/src/dashboard/routes/sovereignty_metrics.py new file mode 100644 index 00000000..3bffe95f --- /dev/null +++ b/src/dashboard/routes/sovereignty_metrics.py @@ -0,0 +1,74 @@ +"""Sovereignty metrics dashboard routes. + +Provides API endpoints and HTMX partials for tracking research +sovereignty progress against graduation targets. + +Refs: #981 +""" + +import logging +from typing import Any + +from fastapi import APIRouter, Request +from fastapi.responses import HTMLResponse + +from config import settings +from dashboard.templating import templates +from infrastructure.sovereignty_metrics import ( + GRADUATION_TARGETS, + get_sovereignty_store, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/sovereignty", tags=["sovereignty"]) + + +@router.get("/metrics") +async def sovereignty_metrics_api() -> dict[str, Any]: + """JSON API: full sovereignty metrics summary with trends.""" + store = get_sovereignty_store() + summary = store.get_summary() + alerts = store.get_alerts(unacknowledged_only=True) + return { + "metrics": summary, + "alerts": alerts, + "targets": GRADUATION_TARGETS, + "cost_threshold": settings.sovereignty_api_cost_alert_threshold, + } + + +@router.get("/metrics/panel", response_class=HTMLResponse) +async def sovereignty_metrics_panel(request: Request) -> HTMLResponse: + """HTMX partial: sovereignty metrics progress panel.""" + store = get_sovereignty_store() + summary = store.get_summary() + alerts = store.get_alerts(unacknowledged_only=True) + + return templates.TemplateResponse( + request, + "partials/sovereignty_metrics.html", + { + "metrics": summary, + "alerts": alerts, + "targets": GRADUATION_TARGETS, + }, + ) + + +@router.get("/alerts") +async def sovereignty_alerts_api() -> dict[str, Any]: + """JSON API: sovereignty alerts.""" + store = get_sovereignty_store() + return { + "alerts": store.get_alerts(unacknowledged_only=False), + "unacknowledged": store.get_alerts(unacknowledged_only=True), + } + + +@router.post("/alerts/{alert_id}/acknowledge") +async def acknowledge_alert(alert_id: int) -> dict[str, bool]: + """Acknowledge a sovereignty alert.""" + store = get_sovereignty_store() + success = store.acknowledge_alert(alert_id) + return {"success": success} diff --git a/src/dashboard/templates/mission_control.html b/src/dashboard/templates/mission_control.html index 27acbd15..a090ff5b 100644 --- a/src/dashboard/templates/mission_control.html +++ b/src/dashboard/templates/mission_control.html @@ -179,6 +179,13 @@ + +{% call panel("SOVEREIGNTY METRICS", id="sovereignty-metrics-panel", + hx_get="/sovereignty/metrics/panel", + hx_trigger="load, every 30s") %} +

Loading sovereignty metrics...

+{% endcall %} +
diff --git a/src/dashboard/templates/partials/sovereignty_metrics.html b/src/dashboard/templates/partials/sovereignty_metrics.html new file mode 100644 index 00000000..3ef004fc --- /dev/null +++ b/src/dashboard/templates/partials/sovereignty_metrics.html @@ -0,0 +1,63 @@ +{# HTMX partial: Sovereignty Metrics Progress Panel + Loaded via hx-get="/sovereignty/metrics/panel" + Refs: #981 +#} +{% set phase_labels = {"pre-start": "Pre-start", "week1": "Week 1", "month1": "Month 1", "month3": "Month 3", "graduated": "Graduated"} %} +{% set phase_colors = {"pre-start": "var(--text-dim)", "week1": "var(--red)", "month1": "var(--amber)", "month3": "var(--green)", "graduated": "var(--purple)"} %} + +{% set metric_labels = { + "cache_hit_rate": "Cache Hit Rate", + "api_cost": "API Cost / Task", + "time_to_report": "Time to Report", + "human_involvement": "Human Involvement", + "local_artifacts": "Local Artifacts" +} %} + +{% set metric_units = { + "cache_hit_rate": "%", + "api_cost": "$", + "time_to_report": "min", + "human_involvement": "%", + "local_artifacts": "" +} %} + +{% if alerts %} +
+ {% for alert in alerts %} +
+ ! + {{ alert.message }} +
+ {% endfor %} +
+{% endif %} + +
+{% for key, data in metrics.items() %} + {% set label = metric_labels.get(key, key) %} + {% set unit = metric_units.get(key, "") %} + {% set phase = data.phase %} + {% set color = phase_colors.get(phase, "var(--text-dim)") %} +
+
+ {% if data.current is not none %} + {% if key == "cache_hit_rate" or key == "human_involvement" %} + {{ "%.0f"|format(data.current * 100) }}{{ unit }} + {% elif key == "api_cost" %} + {{ unit }}{{ "%.2f"|format(data.current) }} + {% elif key == "time_to_report" %} + {{ "%.1f"|format(data.current) }}{{ unit }} + {% else %} + {{ data.current|int }} + {% endif %} + {% else %} + -- + {% endif %} +
+
{{ label }}
+
+ {{ phase_labels.get(phase, phase) }} +
+
+{% endfor %} +
diff --git a/src/infrastructure/sovereignty_metrics.py b/src/infrastructure/sovereignty_metrics.py new file mode 100644 index 00000000..a305fa65 --- /dev/null +++ b/src/infrastructure/sovereignty_metrics.py @@ -0,0 +1,307 @@ +"""Sovereignty metrics collector and store. + +Tracks research sovereignty progress: cache hit rate, API cost, +time-to-report, and human involvement. Persists to SQLite for +trend analysis and dashboard display. + +Refs: #981 +""" + +import json +import logging +import sqlite3 +from contextlib import closing +from dataclasses import dataclass, field +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +from config import settings + +logger = logging.getLogger(__name__) + +DB_PATH = Path(settings.repo_root) / "data" / "sovereignty_metrics.db" + +_SCHEMA = """ +CREATE TABLE IF NOT EXISTS sovereignty_metrics ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL, + metric_type TEXT NOT NULL, + value REAL NOT NULL, + metadata TEXT DEFAULT '{}' +); +CREATE INDEX IF NOT EXISTS idx_sm_type ON sovereignty_metrics(metric_type); +CREATE INDEX IF NOT EXISTS idx_sm_ts ON sovereignty_metrics(timestamp); + +CREATE TABLE IF NOT EXISTS sovereignty_alerts ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL, + alert_type TEXT NOT NULL, + message TEXT NOT NULL, + value REAL NOT NULL, + threshold REAL NOT NULL, + acknowledged INTEGER DEFAULT 0 +); +CREATE INDEX IF NOT EXISTS idx_sa_ts ON sovereignty_alerts(timestamp); +CREATE INDEX IF NOT EXISTS idx_sa_ack ON sovereignty_alerts(acknowledged); +""" + + +@dataclass +class SovereigntyMetric: + """A single sovereignty metric data point.""" + + metric_type: str # cache_hit_rate, api_cost, time_to_report, human_involvement + value: float + timestamp: str = field(default_factory=lambda: datetime.now(UTC).isoformat()) + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class SovereigntyAlert: + """An alert triggered when a metric exceeds a threshold.""" + + alert_type: str + message: str + value: float + threshold: float + timestamp: str = field(default_factory=lambda: datetime.now(UTC).isoformat()) + acknowledged: bool = False + + +# Graduation targets from issue #981 +GRADUATION_TARGETS = { + "cache_hit_rate": {"week1": 0.10, "month1": 0.40, "month3": 0.80, "graduation": 0.90}, + "api_cost": {"week1": 1.50, "month1": 0.50, "month3": 0.10, "graduation": 0.01}, + "time_to_report": {"week1": 180.0, "month1": 30.0, "month3": 5.0, "graduation": 1.0}, + "human_involvement": {"week1": 1.0, "month1": 0.5, "month3": 0.25, "graduation": 0.0}, + "local_artifacts": {"week1": 6, "month1": 30, "month3": 100, "graduation": 500}, +} + + +class SovereigntyMetricsStore: + """SQLite-backed sovereignty metrics store. + + Thread-safe: creates a new connection per operation. + """ + + def __init__(self, db_path: Path | None = None) -> None: + self._db_path = db_path or DB_PATH + self._init_db() + + def _init_db(self) -> None: + """Initialize the database schema.""" + try: + self._db_path.parent.mkdir(parents=True, exist_ok=True) + with closing(sqlite3.connect(str(self._db_path))) as conn: + conn.execute("PRAGMA journal_mode=WAL") + conn.execute(f"PRAGMA busy_timeout={settings.db_busy_timeout_ms}") + conn.executescript(_SCHEMA) + conn.commit() + except Exception as exc: + logger.warning("Failed to initialize sovereignty metrics DB: %s", exc) + + def _connect(self) -> sqlite3.Connection: + """Get a new connection.""" + conn = sqlite3.connect(str(self._db_path)) + conn.row_factory = sqlite3.Row + conn.execute(f"PRAGMA busy_timeout={settings.db_busy_timeout_ms}") + return conn + + def record(self, metric: SovereigntyMetric) -> None: + """Record a sovereignty metric data point.""" + try: + with closing(self._connect()) as conn: + conn.execute( + "INSERT INTO sovereignty_metrics (timestamp, metric_type, value, metadata) " + "VALUES (?, ?, ?, ?)", + ( + metric.timestamp, + metric.metric_type, + metric.value, + json.dumps(metric.metadata), + ), + ) + conn.commit() + except Exception as exc: + logger.warning("Failed to record sovereignty metric: %s", exc) + + # Check thresholds for alerts + self._check_alert(metric) + + def _check_alert(self, metric: SovereigntyMetric) -> None: + """Check if a metric triggers an alert.""" + threshold = settings.sovereignty_api_cost_alert_threshold + if metric.metric_type == "api_cost" and metric.value > threshold: + alert = SovereigntyAlert( + alert_type="api_cost_exceeded", + message=f"API cost ${metric.value:.2f} exceeds threshold ${threshold:.2f}", + value=metric.value, + threshold=threshold, + ) + self._record_alert(alert) + + def _record_alert(self, alert: SovereigntyAlert) -> None: + """Persist an alert.""" + try: + with closing(self._connect()) as conn: + conn.execute( + "INSERT INTO sovereignty_alerts " + "(timestamp, alert_type, message, value, threshold) " + "VALUES (?, ?, ?, ?, ?)", + ( + alert.timestamp, + alert.alert_type, + alert.message, + alert.value, + alert.threshold, + ), + ) + conn.commit() + logger.warning("Sovereignty alert: %s", alert.message) + except Exception as exc: + logger.warning("Failed to record sovereignty alert: %s", exc) + + def get_latest(self, metric_type: str, limit: int = 50) -> list[dict]: + """Get the most recent metric values for a given type.""" + try: + with closing(self._connect()) as conn: + rows = conn.execute( + "SELECT timestamp, value, metadata FROM sovereignty_metrics " + "WHERE metric_type = ? ORDER BY timestamp DESC LIMIT ?", + (metric_type, limit), + ).fetchall() + return [ + { + "timestamp": row["timestamp"], + "value": row["value"], + "metadata": json.loads(row["metadata"]) if row["metadata"] else {}, + } + for row in rows + ] + except Exception as exc: + logger.warning("Failed to query sovereignty metrics: %s", exc) + return [] + + def get_summary(self) -> dict[str, Any]: + """Get a summary of current sovereignty metrics progress.""" + summary: dict[str, Any] = {} + for metric_type in GRADUATION_TARGETS: + latest = self.get_latest(metric_type, limit=1) + history = self.get_latest(metric_type, limit=30) + + current_value = latest[0]["value"] if latest else None + targets = GRADUATION_TARGETS[metric_type] + + # Determine current phase based on value + phase = "pre-start" + if current_value is not None: + if metric_type in ("api_cost", "time_to_report", "human_involvement"): + # Lower is better + if current_value <= targets["graduation"]: + phase = "graduated" + elif current_value <= targets["month3"]: + phase = "month3" + elif current_value <= targets["month1"]: + phase = "month1" + elif current_value <= targets["week1"]: + phase = "week1" + else: + phase = "pre-start" + else: + # Higher is better + if current_value >= targets["graduation"]: + phase = "graduated" + elif current_value >= targets["month3"]: + phase = "month3" + elif current_value >= targets["month1"]: + phase = "month1" + elif current_value >= targets["week1"]: + phase = "week1" + else: + phase = "pre-start" + + summary[metric_type] = { + "current": current_value, + "phase": phase, + "targets": targets, + "trend": [{"t": h["timestamp"], "v": h["value"]} for h in reversed(history)], + } + + return summary + + def get_alerts(self, unacknowledged_only: bool = True, limit: int = 20) -> list[dict]: + """Get sovereignty alerts.""" + try: + with closing(self._connect()) as conn: + if unacknowledged_only: + rows = conn.execute( + "SELECT * FROM sovereignty_alerts " + "WHERE acknowledged = 0 ORDER BY timestamp DESC LIMIT ?", + (limit,), + ).fetchall() + else: + rows = conn.execute( + "SELECT * FROM sovereignty_alerts " + "ORDER BY timestamp DESC LIMIT ?", + (limit,), + ).fetchall() + return [dict(row) for row in rows] + except Exception as exc: + logger.warning("Failed to query sovereignty alerts: %s", exc) + return [] + + def acknowledge_alert(self, alert_id: int) -> bool: + """Acknowledge an alert.""" + try: + with closing(self._connect()) as conn: + conn.execute( + "UPDATE sovereignty_alerts SET acknowledged = 1 WHERE id = ?", + (alert_id,), + ) + conn.commit() + return True + except Exception as exc: + logger.warning("Failed to acknowledge alert: %s", exc) + return False + + +# ── Module-level singleton ───────────────────────────────────────────────── +_store: SovereigntyMetricsStore | None = None + + +def get_sovereignty_store() -> SovereigntyMetricsStore: + """Return the module-level store, creating it on first access.""" + global _store + if _store is None: + _store = SovereigntyMetricsStore() + return _store + + +async def emit_sovereignty_metric( + metric_type: str, + value: float, + metadata: dict[str, Any] | None = None, +) -> None: + """Convenience function to record a sovereignty metric and emit an event. + + Also publishes to the event bus for real-time subscribers. + """ + import asyncio + + from infrastructure.events.bus import emit + + metric = SovereigntyMetric( + metric_type=metric_type, + value=value, + metadata=metadata or {}, + ) + # Record to SQLite in thread to avoid blocking event loop + await asyncio.to_thread(get_sovereignty_store().record, metric) + + # Publish to event bus for real-time consumers + await emit( + f"sovereignty.metric.{metric_type}", + source="sovereignty_metrics", + data={"metric_type": metric_type, "value": value, **(metadata or {})}, + ) diff --git a/tests/conftest.py b/tests/conftest.py index 3db5de56..bf684f69 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -147,10 +147,12 @@ def clean_database(tmp_path): # IMPORTANT: swarm.task_queue.models also has a DB_PATH that writes to # tasks.db — it MUST be patched too, or error_capture.capture_error() # will write test data to the production database. + tmp_sovereignty_db = tmp_path / "sovereignty_metrics.db" for mod_name, tmp_db in [ ("dashboard.routes.tasks", tmp_tasks_db), ("dashboard.routes.work_orders", tmp_work_orders_db), ("swarm.task_queue.models", tmp_tasks_db), + ("infrastructure.sovereignty_metrics", tmp_sovereignty_db), ]: try: mod = __import__(mod_name, fromlist=["DB_PATH"]) diff --git a/tests/infrastructure/test_moderation.py b/tests/infrastructure/test_moderation.py index add8c1b5..9ac59129 100644 --- a/tests/infrastructure/test_moderation.py +++ b/tests/infrastructure/test_moderation.py @@ -14,7 +14,6 @@ from infrastructure.guards.moderation import ( get_moderator, ) - # ── Unit tests for data types ──────────────────────────────────────────────── diff --git a/tests/infrastructure/test_sovereignty_metrics.py b/tests/infrastructure/test_sovereignty_metrics.py new file mode 100644 index 00000000..8acb4a0a --- /dev/null +++ b/tests/infrastructure/test_sovereignty_metrics.py @@ -0,0 +1,177 @@ +"""Tests for the sovereignty metrics store and API routes. + +Refs: #981 +""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from infrastructure.sovereignty_metrics import ( + GRADUATION_TARGETS, + SovereigntyMetric, + SovereigntyMetricsStore, + emit_sovereignty_metric, +) + + +@pytest.fixture +def store(tmp_path): + """Create a fresh sovereignty metrics store with a temp DB.""" + return SovereigntyMetricsStore(db_path=tmp_path / "test_sov.db") + + +class TestSovereigntyMetricsStore: + def test_record_and_get_latest(self, store): + metric = SovereigntyMetric(metric_type="cache_hit_rate", value=0.42) + store.record(metric) + + results = store.get_latest("cache_hit_rate", limit=10) + assert len(results) == 1 + assert results[0]["value"] == 0.42 + + def test_get_latest_returns_most_recent_first(self, store): + for val in [0.1, 0.2, 0.3]: + store.record(SovereigntyMetric(metric_type="cache_hit_rate", value=val)) + + results = store.get_latest("cache_hit_rate", limit=10) + assert len(results) == 3 + assert results[0]["value"] == 0.3 # most recent first + + def test_get_latest_respects_limit(self, store): + for i in range(10): + store.record(SovereigntyMetric(metric_type="api_cost", value=float(i))) + + results = store.get_latest("api_cost", limit=3) + assert len(results) == 3 + + def test_get_latest_filters_by_type(self, store): + store.record(SovereigntyMetric(metric_type="cache_hit_rate", value=0.5)) + store.record(SovereigntyMetric(metric_type="api_cost", value=1.20)) + + results = store.get_latest("cache_hit_rate") + assert len(results) == 1 + assert results[0]["value"] == 0.5 + + def test_get_summary_empty(self, store): + summary = store.get_summary() + assert "cache_hit_rate" in summary + assert summary["cache_hit_rate"]["current"] is None + assert summary["cache_hit_rate"]["phase"] == "pre-start" + + def test_get_summary_with_data(self, store): + store.record(SovereigntyMetric(metric_type="cache_hit_rate", value=0.85)) + store.record(SovereigntyMetric(metric_type="api_cost", value=0.08)) + + summary = store.get_summary() + assert summary["cache_hit_rate"]["current"] == 0.85 + assert summary["cache_hit_rate"]["phase"] == "month3" + assert summary["api_cost"]["current"] == 0.08 + assert summary["api_cost"]["phase"] == "month3" + + def test_get_summary_graduation(self, store): + store.record(SovereigntyMetric(metric_type="cache_hit_rate", value=0.95)) + summary = store.get_summary() + assert summary["cache_hit_rate"]["phase"] == "graduated" + + def test_alert_on_high_api_cost(self, store): + """API cost above threshold triggers an alert.""" + with patch("infrastructure.sovereignty_metrics.settings") as mock_settings: + mock_settings.sovereignty_api_cost_alert_threshold = 1.00 + mock_settings.db_busy_timeout_ms = 5000 + store.record(SovereigntyMetric(metric_type="api_cost", value=2.50)) + + alerts = store.get_alerts(unacknowledged_only=True) + assert len(alerts) == 1 + assert alerts[0]["alert_type"] == "api_cost_exceeded" + assert alerts[0]["value"] == 2.50 + + def test_no_alert_below_threshold(self, store): + """API cost below threshold does not trigger an alert.""" + with patch("infrastructure.sovereignty_metrics.settings") as mock_settings: + mock_settings.sovereignty_api_cost_alert_threshold = 1.00 + mock_settings.db_busy_timeout_ms = 5000 + store.record(SovereigntyMetric(metric_type="api_cost", value=0.50)) + + alerts = store.get_alerts(unacknowledged_only=True) + assert len(alerts) == 0 + + def test_acknowledge_alert(self, store): + with patch("infrastructure.sovereignty_metrics.settings") as mock_settings: + mock_settings.sovereignty_api_cost_alert_threshold = 0.50 + mock_settings.db_busy_timeout_ms = 5000 + store.record(SovereigntyMetric(metric_type="api_cost", value=1.00)) + + alerts = store.get_alerts(unacknowledged_only=True) + assert len(alerts) == 1 + + store.acknowledge_alert(alerts[0]["id"]) + assert len(store.get_alerts(unacknowledged_only=True)) == 0 + assert len(store.get_alerts(unacknowledged_only=False)) == 1 + + def test_metadata_preserved(self, store): + store.record( + SovereigntyMetric( + metric_type="cache_hit_rate", + value=0.5, + metadata={"source": "research_orchestrator"}, + ) + ) + results = store.get_latest("cache_hit_rate") + assert results[0]["metadata"]["source"] == "research_orchestrator" + + def test_summary_trend_data(self, store): + for v in [0.1, 0.2, 0.3]: + store.record(SovereigntyMetric(metric_type="cache_hit_rate", value=v)) + + summary = store.get_summary() + trend = summary["cache_hit_rate"]["trend"] + assert len(trend) == 3 + assert trend[0]["v"] == 0.1 # oldest first (reversed) + assert trend[-1]["v"] == 0.3 + + def test_graduation_targets_complete(self): + """All expected metric types have graduation targets.""" + expected = {"cache_hit_rate", "api_cost", "time_to_report", "human_involvement", "local_artifacts"} + assert set(GRADUATION_TARGETS.keys()) == expected + + +class TestEmitSovereigntyMetric: + @pytest.mark.asyncio + async def test_emit_records_and_publishes(self, tmp_path): + """emit_sovereignty_metric records to store and publishes event.""" + with ( + patch("infrastructure.sovereignty_metrics._store", None), + patch( + "infrastructure.sovereignty_metrics.DB_PATH", + tmp_path / "emit_test.db", + ), + patch("infrastructure.events.bus.emit", new_callable=AsyncMock) as mock_emit, + ): + await emit_sovereignty_metric("cache_hit_rate", 0.75, {"source": "test"}) + + mock_emit.assert_called_once() + call_args = mock_emit.call_args + assert call_args[0][0] == "sovereignty.metric.cache_hit_rate" + + +class TestSovereigntyMetricsRoutes: + def test_metrics_api_returns_200(self, client): + response = client.get("/sovereignty/metrics") + assert response.status_code == 200 + data = response.json() + assert "metrics" in data + assert "alerts" in data + assert "targets" in data + + def test_metrics_panel_returns_html(self, client): + response = client.get("/sovereignty/metrics/panel") + assert response.status_code == 200 + assert "text/html" in response.headers["content-type"] + + def test_alerts_api_returns_200(self, client): + response = client.get("/sovereignty/alerts") + assert response.status_code == 200 + data = response.json() + assert "alerts" in data + assert "unacknowledged" in data