Merge branch 'main' into kimi/issue-627
This commit is contained in:
33
config/matrix.yaml
Normal file
33
config/matrix.yaml
Normal file
@@ -0,0 +1,33 @@
|
||||
# Matrix World Configuration
|
||||
# Serves lighting, environment, and feature settings to the Matrix frontend.
|
||||
|
||||
lighting:
|
||||
ambient_color: "#FFAA55" # Warm amber (Workshop warmth)
|
||||
ambient_intensity: 0.5
|
||||
point_lights:
|
||||
- color: "#FFAA55" # Warm amber (Workshop center light)
|
||||
intensity: 1.2
|
||||
position: { x: 0, y: 5, z: 0 }
|
||||
- color: "#3B82F6" # Cool blue (Matrix accent)
|
||||
intensity: 0.8
|
||||
position: { x: -5, y: 3, z: -5 }
|
||||
- color: "#A855F7" # Purple accent
|
||||
intensity: 0.6
|
||||
position: { x: 5, y: 3, z: 5 }
|
||||
|
||||
environment:
|
||||
rain_enabled: false
|
||||
starfield_enabled: true # Cool blue starfield (Matrix feel)
|
||||
fog_color: "#0f0f23"
|
||||
fog_density: 0.02
|
||||
|
||||
features:
|
||||
chat_enabled: true
|
||||
visitor_avatars: true
|
||||
pip_familiar: true
|
||||
workshop_portal: true
|
||||
|
||||
agents:
|
||||
default_count: 5
|
||||
max_count: 20
|
||||
agents: []
|
||||
@@ -27,11 +27,15 @@ from pathlib import Path
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent
|
||||
QUEUE_FILE = REPO_ROOT / ".loop" / "queue.json"
|
||||
IDLE_STATE_FILE = REPO_ROOT / ".loop" / "idle_state.json"
|
||||
CYCLE_RESULT_FILE = REPO_ROOT / ".loop" / "cycle_result.json"
|
||||
TOKEN_FILE = Path.home() / ".hermes" / "gitea_token"
|
||||
|
||||
GITEA_API = os.environ.get("GITEA_API", "http://localhost:3000/api/v1")
|
||||
REPO_SLUG = os.environ.get("REPO_SLUG", "rockachopa/Timmy-time-dashboard")
|
||||
|
||||
# Default cycle duration in seconds (5 min); stale threshold = 2× this
|
||||
CYCLE_DURATION = int(os.environ.get("CYCLE_DURATION", "300"))
|
||||
|
||||
# Backoff sequence: 60s, 120s, 240s, 600s max
|
||||
BACKOFF_BASE = 60
|
||||
BACKOFF_MAX = 600
|
||||
@@ -77,6 +81,89 @@ def _fetch_open_issue_numbers() -> set[int] | None:
|
||||
return None
|
||||
|
||||
|
||||
def _load_cycle_result() -> dict:
    """Parse cycle_result.json into a dict, tolerating ```-fenced JSON.

    Returns an empty dict when the file is absent, unreadable, or does
    not contain valid JSON.
    """
    if not CYCLE_RESULT_FILE.exists():
        return {}
    try:
        text = CYCLE_RESULT_FILE.read_text().strip()
        if text.startswith("```"):
            # Drop the markdown fence lines, keep the payload between them.
            text = "\n".join(
                line for line in text.splitlines() if not line.startswith("```")
            )
        return json.loads(text)
    except (json.JSONDecodeError, OSError):
        return {}
|
||||
|
||||
|
||||
def _is_issue_open(issue_number: int) -> bool | None:
    """Query the Gitea API for a single issue's state.

    Returns True/False for open/closed, or None when no token is
    available or the API call fails for any reason.
    """
    token = _get_token()
    if not token:
        return None
    req = urllib.request.Request(
        f"{GITEA_API}/repos/{REPO_SLUG}/issues/{issue_number}",
        headers={
            "Authorization": f"token {token}",
            "Accept": "application/json",
        },
    )
    try:
        with urllib.request.urlopen(req, timeout=10) as resp:
            payload = json.loads(resp.read())
    except Exception:
        # Network/API failure — callers treat None as "unknown".
        return None
    return payload.get("state") == "open"
|
||||
|
||||
|
||||
def validate_cycle_result() -> bool:
    """Pre-cycle validation: remove stale or invalid cycle_result.json.

    Two checks run in order:
      1. Age   — delete the file when older than 2x CYCLE_DURATION.
      2. Issue — delete the file when its referenced issue is closed.

    Returns True if the file was removed, False otherwise.
    """
    if not CYCLE_RESULT_FILE.exists():
        return False

    # --- Age check ---
    try:
        mtime = CYCLE_RESULT_FILE.stat().st_mtime
    except OSError:
        return False
    age = time.time() - mtime
    stale_threshold = CYCLE_DURATION * 2
    if age > stale_threshold:
        print(
            f"[loop-guard] cycle_result.json is {int(age)}s old "
            f"(threshold {stale_threshold}s) — removing stale file"
        )
        CYCLE_RESULT_FILE.unlink(missing_ok=True)
        return True

    # --- Issue check ---
    issue_num = _load_cycle_result().get("issue")
    if issue_num is None:
        return False
    try:
        issue_num = int(issue_num)
    except (ValueError, TypeError):
        # Malformed issue reference — leave the file alone.
        return False
    if _is_issue_open(issue_num) is False:
        print(
            f"[loop-guard] cycle_result.json references closed "
            f"issue #{issue_num} — removing"
        )
        CYCLE_RESULT_FILE.unlink(missing_ok=True)
        return True
    # _is_issue_open returned None (API failure) or True — keep the file.
    return False
|
||||
|
||||
|
||||
def load_queue() -> list[dict]:
|
||||
"""Load queue.json and return ready items, filtering out closed issues."""
|
||||
if not QUEUE_FILE.exists():
|
||||
@@ -150,6 +237,9 @@ def main() -> int:
|
||||
}, indent=2))
|
||||
return 0
|
||||
|
||||
# Pre-cycle validation: remove stale cycle_result.json
|
||||
validate_cycle_result()
|
||||
|
||||
ready = load_queue()
|
||||
|
||||
if ready:
|
||||
|
||||
@@ -149,6 +149,18 @@ class Settings(BaseSettings):
|
||||
"http://127.0.0.1:8000",
|
||||
]
|
||||
|
||||
# ── Matrix Frontend Integration ────────────────────────────────────────
|
||||
# URL of the Matrix frontend (Replit/Tailscale) for CORS.
|
||||
# When set, this origin is added to CORS allowed_origins.
|
||||
# Example: "http://100.124.176.28:8080" or "https://alexanderwhitestone.com"
|
||||
matrix_frontend_url: str = "" # Empty = disabled
|
||||
|
||||
# WebSocket authentication token for Matrix connections.
|
||||
# When set, clients must provide this token via ?token= query param
|
||||
# or in the first message as {"type": "auth", "token": "..."}.
|
||||
# Empty/unset = auth disabled (dev mode).
|
||||
matrix_ws_token: str = ""
|
||||
|
||||
# Trusted hosts for the Host header check (TrustedHostMiddleware).
|
||||
# Set TRUSTED_HOSTS as a comma-separated list. Wildcards supported (e.g. "*.ts.net").
|
||||
# Defaults include localhost + Tailscale MagicDNS. Add your Tailscale IP if needed.
|
||||
|
||||
@@ -10,6 +10,7 @@ Key improvements:
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
@@ -23,6 +24,7 @@ from config import settings
|
||||
|
||||
# Import dedicated middleware
|
||||
from dashboard.middleware.csrf import CSRFMiddleware
|
||||
from dashboard.middleware.rate_limit import RateLimitMiddleware
|
||||
from dashboard.middleware.request_logging import RequestLoggingMiddleware
|
||||
from dashboard.middleware.security_headers import SecurityHeadersMiddleware
|
||||
from dashboard.routes.agents import router as agents_router
|
||||
@@ -49,6 +51,7 @@ from dashboard.routes.tools import router as tools_router
|
||||
from dashboard.routes.tower import router as tower_router
|
||||
from dashboard.routes.voice import router as voice_router
|
||||
from dashboard.routes.work_orders import router as work_orders_router
|
||||
from dashboard.routes.world import matrix_router
|
||||
from dashboard.routes.world import router as world_router
|
||||
from timmy.workshop_state import PRESENCE_FILE
|
||||
|
||||
@@ -519,25 +522,55 @@ app = FastAPI(
|
||||
|
||||
|
||||
def _get_cors_origins() -> list[str]:
    """Get CORS origins from settings, rejecting wildcards in production.

    Adds matrix_frontend_url when configured. Tailscale IPs (100.x.x.x)
    are handled separately via the CORS allow_origin_regex, not here.

    Returns:
        A fresh list of allowed origins (never the settings list itself).
    """
    # Fix: the merge left a duplicate pre-merge docstring and a dead
    # `origins = settings.cors_origins` assignment here — removed. Copy
    # via list() so the appends below never mutate the settings object.
    origins = list(settings.cors_origins)

    # Strip wildcards in production (security)
    if "*" in origins and not settings.debug:
        logger.warning(
            "Wildcard '*' in CORS_ORIGINS stripped in production — "
            "set explicit origins via CORS_ORIGINS env var"
        )
        origins = [o for o in origins if o != "*"]

    # Add Matrix frontend URL if configured
    if settings.matrix_frontend_url:
        url = settings.matrix_frontend_url.strip()
        if url and url not in origins:
            origins.append(url)
            logger.debug("Added Matrix frontend to CORS: %s", url)

    return origins
|
||||
|
||||
|
||||
# Pattern to match Tailscale IPs (100.x.x.x) for CORS origin regex
|
||||
_TAILSCALE_IP_PATTERN = re.compile(r"^https?://100\.\d{1,3}\.\d{1,3}\.\d{1,3}(?::\d+)?$")
|
||||
|
||||
|
||||
def _is_tailscale_origin(origin: str) -> bool:
|
||||
"""Check if origin is a Tailscale IP (100.x.x.x range)."""
|
||||
return bool(_TAILSCALE_IP_PATTERN.match(origin))
|
||||
|
||||
|
||||
# Add dedicated middleware in correct order
|
||||
# 1. Logging (outermost to capture everything)
|
||||
app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health"])
|
||||
|
||||
# 2. Security Headers
|
||||
# 2. Rate Limiting (before security to prevent abuse early)
|
||||
app.add_middleware(
|
||||
RateLimitMiddleware,
|
||||
path_prefixes=["/api/matrix/"],
|
||||
requests_per_minute=30,
|
||||
)
|
||||
|
||||
# 3. Security Headers
|
||||
app.add_middleware(SecurityHeadersMiddleware, production=not settings.debug)
|
||||
|
||||
# 3. CSRF Protection
|
||||
# 4. CSRF Protection
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
# 4. Standard FastAPI middleware
|
||||
@@ -551,6 +584,7 @@ app.add_middleware(
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=_get_cors_origins(),
|
||||
allow_origin_regex=r"https?://100\.\d{1,3}\.\d{1,3}\.\d{1,3}(:\d+)?",
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
allow_headers=["Content-Type", "Authorization"],
|
||||
@@ -589,6 +623,7 @@ app.include_router(system_router)
|
||||
app.include_router(experiments_router)
|
||||
app.include_router(db_explorer_router)
|
||||
app.include_router(world_router)
|
||||
app.include_router(matrix_router)
|
||||
app.include_router(tower_router)
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Dashboard middleware package."""
|
||||
|
||||
from .csrf import CSRFMiddleware, csrf_exempt, generate_csrf_token, validate_csrf_token
|
||||
from .rate_limit import RateLimiter, RateLimitMiddleware
|
||||
from .request_logging import RequestLoggingMiddleware
|
||||
from .security_headers import SecurityHeadersMiddleware
|
||||
|
||||
@@ -9,6 +10,8 @@ __all__ = [
|
||||
"csrf_exempt",
|
||||
"generate_csrf_token",
|
||||
"validate_csrf_token",
|
||||
"RateLimiter",
|
||||
"RateLimitMiddleware",
|
||||
"SecurityHeadersMiddleware",
|
||||
"RequestLoggingMiddleware",
|
||||
]
|
||||
|
||||
209
src/dashboard/middleware/rate_limit.py
Normal file
209
src/dashboard/middleware/rate_limit.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""Rate limiting middleware for FastAPI.
|
||||
|
||||
Simple in-memory rate limiter for API endpoints. Tracks requests per IP
|
||||
with configurable limits and automatic cleanup of stale entries.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimiter:
    """Sliding-window, in-memory request limiter keyed by client IP.

    Each IP maps to a deque of request timestamps inside the last minute.
    Entries for idle IPs are purged every ``cleanup_interval_seconds``.

    Attributes:
        requests_per_minute: Maximum requests allowed per minute per IP.
        cleanup_interval_seconds: How often to clean stale entries.
    """

    def __init__(
        self,
        requests_per_minute: int = 30,
        cleanup_interval_seconds: int = 60,
    ):
        self.requests_per_minute = requests_per_minute
        self.cleanup_interval_seconds = cleanup_interval_seconds
        self._storage: dict[str, deque[float]] = {}
        self._last_cleanup: float = time.time()
        self._window_seconds: float = 60.0  # 1 minute window

    def _get_client_ip(self, request: "Request") -> str:
        """Extract the client IP, preferring proxy-forwarded headers.

        NOTE(review): X-Forwarded-For / X-Real-IP are client-controlled
        unless a trusted proxy strips them — confirm the deployment sits
        behind such a proxy, otherwise the limit key is spoofable.

        Args:
            request: The incoming request.

        Returns:
            Client IP address string ("unknown" when no peer info exists).
        """
        forwarded = request.headers.get("x-forwarded-for")
        if forwarded:
            # First entry in the chain is the originating client.
            return forwarded.split(",")[0].strip()

        real_ip = request.headers.get("x-real-ip")
        if real_ip:
            return real_ip

        # Fall back to the direct connection, if any.
        return request.client.host if request.client else "unknown"

    def _cleanup_if_needed(self) -> None:
        """Purge per-IP entries with no requests inside the window."""
        now = time.time()
        if now - self._last_cleanup < self.cleanup_interval_seconds:
            return

        cutoff = now - self._window_seconds
        idle_ips = []
        for ip, stamps in self._storage.items():
            # Drop timestamps that fell out of the window.
            while stamps and stamps[0] < cutoff:
                stamps.popleft()
            if not stamps:
                idle_ips.append(ip)

        for ip in idle_ips:
            del self._storage[ip]

        self._last_cleanup = now
        if idle_ips:
            logger.debug("Rate limiter cleanup: removed %d stale IPs", len(idle_ips))

    def is_allowed(self, client_ip: str) -> tuple[bool, float]:
        """Record a request attempt for *client_ip* and report the verdict.

        Args:
            client_ip: The client's IP address.

        Returns:
            (allowed, retry_after) — retry_after is seconds until the next
            request would be accepted, 0.0 when allowed now.
        """
        now = time.time()
        cutoff = now - self._window_seconds

        stamps = self._storage.setdefault(client_ip, deque())

        # Evict timestamps outside the sliding window.
        while stamps and stamps[0] < cutoff:
            stamps.popleft()

        if len(stamps) >= self.requests_per_minute:
            # Next slot frees when the oldest stamp leaves the window.
            retry_after = self._window_seconds - (now - stamps[0])
            return False, max(0.0, retry_after)

        stamps.append(now)
        return True, 0.0

    def check_request(self, request: "Request") -> tuple[bool, float]:
        """Run periodic cleanup, then apply the limit to *request*'s IP.

        Args:
            request: The incoming request.

        Returns:
            (allowed, retry_after) as in :meth:`is_allowed`.
        """
        self._cleanup_if_needed()
        return self.is_allowed(self._get_client_ip(request))
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Apply per-IP rate limiting to selected URL path prefixes.

    Usage:
        # Apply to all routes (not recommended for public static files)
        app.add_middleware(RateLimitMiddleware)

        # Apply only to specific paths
        app.add_middleware(
            RateLimitMiddleware,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=30,
        )

    Attributes:
        path_prefixes: URL prefixes to limit; empty means every path.
        requests_per_minute: Maximum requests per minute per IP.
    """

    def __init__(
        self,
        app,
        path_prefixes: list[str] | None = None,
        requests_per_minute: int = 30,
    ):
        super().__init__(app)
        self.path_prefixes = path_prefixes or []
        self.limiter = RateLimiter(requests_per_minute=requests_per_minute)

    def _should_rate_limit(self, path: str) -> bool:
        """Return True when *path* falls under a configured prefix.

        An empty prefix list means every path is rate limited.
        """
        if not self.path_prefixes:
            return True
        return any(path.startswith(prefix) for prefix in self.path_prefixes)

    async def dispatch(self, request: Request, call_next) -> Response:
        """Return a 429 with Retry-After when the caller exceeds the limit.

        Args:
            request: The incoming request.
            call_next: Callable to get the response from downstream.

        Returns:
            Downstream response, or a 429 JSON response when limited.
        """
        if not self._should_rate_limit(request.url.path):
            return await call_next(request)

        allowed, retry_after = self.limiter.check_request(request)
        if allowed:
            return await call_next(request)

        wait = int(retry_after) + 1
        return JSONResponse(
            status_code=429,
            content={
                "error": "Rate limit exceeded. Try again later.",
                "retry_after": wait,
            },
            headers={"Retry-After": str(wait)},
        )
|
||||
@@ -17,17 +17,221 @@ or missing.
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
import time
|
||||
from collections import deque
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, WebSocket
|
||||
import yaml
|
||||
from fastapi import APIRouter, Request, WebSocket
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from infrastructure.presence import serialize_presence
|
||||
from config import settings
|
||||
from infrastructure.presence import produce_bark, serialize_presence
|
||||
from timmy.memory_system import search_memories
|
||||
from timmy.workshop_state import PRESENCE_FILE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/world", tags=["world"])
|
||||
matrix_router = APIRouter(prefix="/api/matrix", tags=["matrix"])
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Matrix Bark Endpoint — HTTP fallback for bark messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Rate limiting: 1 request per 3 seconds per visitor_id
_BARK_RATE_LIMIT_SECONDS = 3
_bark_last_request: dict[str, float] = {}
# Prune the per-visitor table once it grows past this many entries so it
# cannot grow without bound as new visitor_ids appear over time.
_BARK_TABLE_MAX = 1024


class BarkRequest(BaseModel):
    """Request body for POST /api/matrix/bark."""

    # The visitor's message text (must be non-empty after stripping).
    text: str
    # Unique visitor identifier; used only as the rate-limit key.
    visitor_id: str


@matrix_router.post("/bark")
async def post_matrix_bark(request: BarkRequest) -> JSONResponse:
    """Generate a bark response for a visitor message.

    HTTP fallback for when WebSocket isn't available. The Matrix frontend
    can POST a message and get Timmy's bark response back as JSON.

    Rate limited to 1 request per 3 seconds per visitor_id.

    Request body:
        - text: The visitor's message text
        - visitor_id: Unique identifier for the visitor (used for rate limiting)

    Returns:
        - 200: Bark message in produce_bark() format
        - 429: Rate limit exceeded (try again later)
        - 422: Invalid request (missing/invalid fields)
    """
    # Validate inputs
    text = request.text.strip() if request.text else ""
    visitor_id = request.visitor_id.strip() if request.visitor_id else ""

    if not text:
        return JSONResponse(
            status_code=422,
            content={"error": "text is required"},
        )

    if not visitor_id:
        return JSONResponse(
            status_code=422,
            content={"error": "visitor_id is required"},
        )

    now = time.time()

    # Fix: previously _bark_last_request kept one entry per visitor_id ever
    # seen (unbounded growth). Entries older than the rate-limit window are
    # useless, so drop them once the table gets large.
    if len(_bark_last_request) > _BARK_TABLE_MAX:
        cutoff = now - _BARK_RATE_LIMIT_SECONDS
        for vid in [v for v, t in _bark_last_request.items() if t < cutoff]:
            del _bark_last_request[vid]

    # Rate limiting check
    time_since_last = now - _bark_last_request.get(visitor_id, 0)
    if time_since_last < _BARK_RATE_LIMIT_SECONDS:
        retry_after = _BARK_RATE_LIMIT_SECONDS - time_since_last
        return JSONResponse(
            status_code=429,
            content={"error": "Rate limit exceeded. Try again later."},
            headers={"Retry-After": str(int(retry_after) + 1)},
        )

    # Record this request
    _bark_last_request[visitor_id] = now

    # Generate bark response; degrade gracefully on failure.
    try:
        reply = await _generate_bark(text)
    except Exception as exc:
        logger.warning("Bark generation failed: %s", exc)
        reply = "Hmm, my thoughts are a bit tangled right now."

    # Build bark response using produce_bark format
    bark = produce_bark(agent_id="timmy", text=reply, style="speech")

    return JSONResponse(
        content=bark,
        headers={"Cache-Control": "no-cache, no-store"},
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Matrix Agent Registry — serves agents to the Matrix visualization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Agent color mapping — consistent with Matrix visual identity
|
||||
_AGENT_COLORS: dict[str, str] = {
|
||||
"timmy": "#FFD700", # Gold
|
||||
"orchestrator": "#FFD700", # Gold
|
||||
"perplexity": "#3B82F6", # Blue
|
||||
"replit": "#F97316", # Orange
|
||||
"kimi": "#06B6D4", # Cyan
|
||||
"claude": "#A855F7", # Purple
|
||||
"researcher": "#10B981", # Emerald
|
||||
"coder": "#EF4444", # Red
|
||||
"writer": "#EC4899", # Pink
|
||||
"memory": "#8B5CF6", # Violet
|
||||
"experimenter": "#14B8A6", # Teal
|
||||
"forge": "#EF4444", # Red (coder alias)
|
||||
"seer": "#10B981", # Emerald (researcher alias)
|
||||
"quill": "#EC4899", # Pink (writer alias)
|
||||
"echo": "#8B5CF6", # Violet (memory alias)
|
||||
"lab": "#14B8A6", # Teal (experimenter alias)
|
||||
}
|
||||
|
||||
# Agent shape mapping for 3D visualization
|
||||
_AGENT_SHAPES: dict[str, str] = {
|
||||
"timmy": "sphere",
|
||||
"orchestrator": "sphere",
|
||||
"perplexity": "cube",
|
||||
"replit": "cylinder",
|
||||
"kimi": "dodecahedron",
|
||||
"claude": "octahedron",
|
||||
"researcher": "icosahedron",
|
||||
"coder": "cube",
|
||||
"writer": "cone",
|
||||
"memory": "torus",
|
||||
"experimenter": "tetrahedron",
|
||||
"forge": "cube",
|
||||
"seer": "icosahedron",
|
||||
"quill": "cone",
|
||||
"echo": "torus",
|
||||
"lab": "tetrahedron",
|
||||
}
|
||||
|
||||
# Default fallback values
|
||||
_DEFAULT_COLOR = "#9CA3AF" # Gray
|
||||
_DEFAULT_SHAPE = "sphere"
|
||||
_DEFAULT_STATUS = "available"
|
||||
|
||||
|
||||
def _get_agent_color(agent_id: str) -> str:
|
||||
"""Get the Matrix color for an agent."""
|
||||
return _AGENT_COLORS.get(agent_id.lower(), _DEFAULT_COLOR)
|
||||
|
||||
|
||||
def _get_agent_shape(agent_id: str) -> str:
|
||||
"""Get the Matrix shape for an agent."""
|
||||
return _AGENT_SHAPES.get(agent_id.lower(), _DEFAULT_SHAPE)
|
||||
|
||||
|
||||
def _compute_circular_positions(count: int, radius: float = 3.0) -> list[dict[str, float]]:
|
||||
"""Compute circular positions for agents in the Matrix.
|
||||
|
||||
Agents are arranged in a circle on the XZ plane at y=0.
|
||||
"""
|
||||
positions = []
|
||||
for i in range(count):
|
||||
angle = (2 * math.pi * i) / count
|
||||
x = radius * math.cos(angle)
|
||||
z = radius * math.sin(angle)
|
||||
positions.append({"x": round(x, 2), "y": 0.0, "z": round(z, 2)})
|
||||
return positions
|
||||
|
||||
|
||||
def _build_matrix_agents_response() -> list[dict[str, Any]]:
    """Build the Matrix agent registry response.

    Loads agents via timmy.agents.loader and decorates each entry with a
    Matrix color, a 3D shape, and a circular-layout position. Any failure
    (including the loader being unavailable) yields an empty list.
    """
    try:
        from timmy.agents.loader import list_agents

        agents = list_agents()
        if not agents:
            return []

        positions = _compute_circular_positions(len(agents))

        entries: list[dict[str, Any]] = []
        for agent, position in zip(agents, positions):
            agent_id = agent.get("id", "")
            entries.append(
                {
                    "id": agent_id,
                    "display_name": agent.get("name", agent_id.title()),
                    "role": agent.get("role", "general"),
                    "color": _get_agent_color(agent_id),
                    "position": position,
                    "shape": _get_agent_shape(agent_id),
                    "status": agent.get("status", _DEFAULT_STATUS),
                }
            )
        return entries
    except Exception as exc:
        # Best-effort endpoint: never propagate loader failures.
        logger.warning("Failed to load agents for Matrix: %s", exc)
        return []
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/world", tags=["world"])
|
||||
@@ -211,6 +415,50 @@ async def _heartbeat(websocket: WebSocket) -> None:
|
||||
logger.debug("Heartbeat stopped — connection gone")
|
||||
|
||||
|
||||
async def _authenticate_ws(websocket: WebSocket) -> bool:
    """Authenticate a WebSocket connection using settings.matrix_ws_token.

    Order of checks:
      1. Auth disabled (empty token setting) — allow immediately.
      2. ?token= query param — validated before accepting the socket.
      3. Otherwise accept the socket and require a first message of
         {"type": "auth", "token": "..."}.

    Returns True when authenticated (or auth is disabled); returns False
    after closing the connection with code 4001 on any failure.
    """
    # Fix: use a constant-time comparison for the shared token instead of
    # `==`, which leaks match length/position through timing.
    from hmac import compare_digest

    token_setting = settings.matrix_ws_token

    # Auth disabled in dev mode (empty/unset token)
    if not token_setting:
        return True

    def _valid(candidate: object) -> bool:
        # compare_digest requires str/bytes; non-str tokens are invalid.
        return isinstance(candidate, str) and compare_digest(candidate, token_setting)

    # Check query param first (can validate before accept)
    query_token = websocket.query_params.get("token", "")
    if query_token:
        if _valid(query_token):
            return True
        # Invalid token in query param — must accept before closing cleanly.
        await websocket.accept()
        await websocket.close(code=4001, reason="Invalid token")
        return False

    # No query token — accept and wait for an auth message
    await websocket.accept()
    try:
        data = json.loads(await websocket.receive_text())
        if data.get("type") == "auth" and _valid(data.get("token")):
            return True
        # Invalid auth message
        await websocket.close(code=4001, reason="Invalid token")
        return False
    except (json.JSONDecodeError, TypeError):
        # Non-JSON first message without valid token
        await websocket.close(code=4001, reason="Authentication required")
        return False
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def world_ws(websocket: WebSocket) -> None:
|
||||
"""Accept a Workshop client and keep it alive for state broadcasts.
|
||||
@@ -219,8 +467,28 @@ async def world_ws(websocket: WebSocket) -> None:
|
||||
client never starts from a blank slate. Incoming frames are parsed
|
||||
as JSON — ``visitor_message`` triggers a bark response. A background
|
||||
heartbeat ping runs every 15 s to detect dead connections early.
|
||||
|
||||
Authentication:
|
||||
- If matrix_ws_token is configured, clients must provide it via
|
||||
?token= query param or in the first message as
|
||||
{"type": "auth", "token": "..."}.
|
||||
- Invalid token results in close code 4001.
|
||||
- Valid token receives a connection_ack message.
|
||||
"""
|
||||
await websocket.accept()
|
||||
# Authenticate (may accept connection internally)
|
||||
is_authed = await _authenticate_ws(websocket)
|
||||
if not is_authed:
|
||||
logger.info("World WS connection rejected — invalid token")
|
||||
return
|
||||
|
||||
# Auth passed - accept if not already accepted
|
||||
if websocket.client_state.name != "CONNECTED":
|
||||
await websocket.accept()
|
||||
|
||||
# Send connection_ack if auth was required
|
||||
if settings.matrix_ws_token:
|
||||
await websocket.send_text(json.dumps({"type": "connection_ack"}))
|
||||
|
||||
_ws_clients.append(websocket)
|
||||
logger.info("World WS connected — %d clients", len(_ws_clients))
|
||||
|
||||
@@ -370,3 +638,428 @@ async def _generate_bark(visitor_text: str) -> str:
|
||||
except Exception as exc:
|
||||
logger.warning("Bark generation failed: %s", exc)
|
||||
return "Hmm, my thoughts are a bit tangled right now."
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Matrix Configuration Endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Default Matrix configuration (fallback when matrix.yaml is missing/corrupt)
_DEFAULT_MATRIX_CONFIG: dict[str, Any] = {
    "lighting": {
        "ambient_color": "#1a1a2e",
        "ambient_intensity": 0.4,
        "point_lights": [
            {"color": "#FFD700", "intensity": 1.2, "position": {"x": 0, "y": 5, "z": 0}},
            {"color": "#3B82F6", "intensity": 0.8, "position": {"x": -5, "y": 3, "z": -5}},
            {"color": "#A855F7", "intensity": 0.6, "position": {"x": 5, "y": 3, "z": 5}},
        ],
    },
    "environment": {
        "rain_enabled": False,
        "starfield_enabled": True,
        "fog_color": "#0f0f23",
        "fog_density": 0.02,
    },
    "features": {
        "chat_enabled": True,
        "visitor_avatars": True,
        "pip_familiar": True,
        "workshop_portal": True,
    },
}


def _load_matrix_config() -> dict[str, Any]:
    """Load Matrix world configuration from matrix.yaml with fallback to defaults.

    Returns a dict with sections: lighting, environment, features — each
    the defaults merged with the file's section. The result never aliases
    _DEFAULT_MATRIX_CONFIG, so callers may freely mutate it.
    """
    # Fix: the previous `.copy()` calls were shallow, so the nested section
    # dicts and the default point_lights list were shared with the returned
    # config — any caller mutation corrupted the module-level defaults.
    from copy import deepcopy

    try:
        config_path = Path(settings.repo_root) / "config" / "matrix.yaml"
        if not config_path.exists():
            logger.debug("matrix.yaml not found, using default config")
            return deepcopy(_DEFAULT_MATRIX_CONFIG)

        config = yaml.safe_load(config_path.read_text())
        if not isinstance(config, dict):
            logger.warning("matrix.yaml invalid format, using defaults")
            return deepcopy(_DEFAULT_MATRIX_CONFIG)

        # Merge with defaults so every required field exists.
        result: dict[str, Any] = {
            section: {
                **deepcopy(_DEFAULT_MATRIX_CONFIG[section]),
                **config.get(section, {}),
            }
            for section in ("lighting", "environment", "features")
        }

        # point_lights is a list, not a scalar: take the file's list verbatim
        # when present, otherwise a copy of the default list.
        file_lighting = config.get("lighting", {})
        if "point_lights" in file_lighting:
            result["lighting"]["point_lights"] = file_lighting["point_lights"]
        else:
            result["lighting"]["point_lights"] = deepcopy(
                _DEFAULT_MATRIX_CONFIG["lighting"]["point_lights"]
            )

        return result
    except Exception as exc:
        logger.warning("Failed to load matrix config: %s, using defaults", exc)
        return deepcopy(_DEFAULT_MATRIX_CONFIG)
|
||||
|
||||
|
||||
@matrix_router.get("/config")
async def get_matrix_config() -> JSONResponse:
    """Return Matrix world configuration.

    Serves lighting presets, environment settings, and feature flags to
    the Matrix frontend so it can be config-driven rather than hardcoded.
    Reads from config/matrix.yaml with sensible defaults.

    Response structure:
        - lighting: ambient_color, ambient_intensity, point_lights[]
        - environment: rain_enabled, starfield_enabled, fog_color, fog_density
        - features: chat_enabled, visitor_avatars, pip_familiar, workshop_portal
    """
    return JSONResponse(
        content=_load_matrix_config(),
        headers={"Cache-Control": "no-cache, no-store"},
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Matrix Agent Registry Endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@matrix_router.get("/agents")
async def get_matrix_agents() -> JSONResponse:
    """Return the agent registry for Matrix visualization.

    Serves agents from agents.yaml in Matrix-compatible form:
      - id: agent identifier
      - display_name: human-readable name
      - role: functional role
      - color: hex color code for visualization
      - position: {x, y, z} coordinates in 3D space
      - shape: 3D shape type
      - status: availability status

    Agents are arranged in a circular layout by default; responds 200 with
    an empty list when no agents are configured.
    """
    payload = _build_matrix_agents_response()
    return JSONResponse(
        content=payload,
        headers={"Cache-Control": "no-cache, no-store"},
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Matrix Thoughts Endpoint — Timmy's recent thought stream for Matrix display
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_MAX_THOUGHT_LIMIT = 50  # Maximum thoughts allowed per request
_DEFAULT_THOUGHT_LIMIT = 10  # Default number of thoughts to return
_MAX_THOUGHT_TEXT_LEN = 500  # Max characters for thought text


def _build_matrix_thoughts_response(limit: int = _DEFAULT_THOUGHT_LIMIT) -> list[dict[str, Any]]:
    """Build the Matrix thoughts response from the thinking engine.

    Each entry has:
      - id: thought UUID
      - text: thought content (truncated to 500 chars)
      - created_at: ISO-8601 timestamp
      - chain_id: parent thought ID (or None for a root thought)

    Returns an empty list if the thinking engine is disabled or fails.
    """
    try:
        from timmy.thinking import thinking_engine

        formatted: list[dict[str, Any]] = []
        for thought in thinking_engine.get_recent_thoughts(limit=limit):
            formatted.append(
                {
                    "id": thought.id,
                    "text": thought.content[:_MAX_THOUGHT_TEXT_LEN],
                    "created_at": thought.created_at,
                    "chain_id": thought.parent_id,
                }
            )
        return formatted
    except Exception as exc:
        logger.warning("Failed to load thoughts for Matrix: %s", exc)
        return []
|
||||
|
||||
|
||||
@matrix_router.get("/thoughts")
async def get_matrix_thoughts(limit: int = _DEFAULT_THOUGHT_LIMIT) -> JSONResponse:
    """Return Timmy's recent thoughts formatted for Matrix display.

    REST companion to the thought WebSocket messages, so the Matrix frontend
    can display what Timmy is actually thinking about rather than canned
    contextual lines.

    Query params:
      - limit: number of thoughts to return (default 10, clamped to 1..50)

    Response: JSON array of objects with id, text (truncated to 500 chars),
    created_at (ISO-8601), and chain_id (None for a root thought). Empty
    array if the thinking engine is disabled or fails.
    """
    # Clamp the caller-supplied limit into the valid [1, _MAX_THOUGHT_LIMIT] range.
    clamped = min(max(limit, 1), _MAX_THOUGHT_LIMIT)

    return JSONResponse(
        content=_build_matrix_thoughts_response(limit=clamped),
        headers={"Cache-Control": "no-cache, no-store"},
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Matrix Health Endpoint — backend capability discovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Health check cache (5-second TTL for capability checks)
|
||||
_health_cache: dict | None = None
|
||||
_health_cache_ts: float = 0.0
|
||||
_HEALTH_CACHE_TTL = 5.0
|
||||
|
||||
|
||||
def _check_capability_thinking() -> bool:
|
||||
"""Check if thinking engine is available."""
|
||||
try:
|
||||
from timmy.thinking import thinking_engine
|
||||
|
||||
# Check if the engine has been initialized (has a db path)
|
||||
return hasattr(thinking_engine, "_db") and thinking_engine._db is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _check_capability_memory() -> bool:
|
||||
"""Check if memory system is available."""
|
||||
try:
|
||||
from timmy.memory_system import HOT_MEMORY_PATH
|
||||
|
||||
return HOT_MEMORY_PATH.exists()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _check_capability_bark() -> bool:
|
||||
"""Check if bark production is available."""
|
||||
try:
|
||||
from infrastructure.presence import produce_bark
|
||||
|
||||
return callable(produce_bark)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _check_capability_familiar() -> bool:
|
||||
"""Check if familiar (Pip) is available."""
|
||||
try:
|
||||
from timmy.familiar import pip_familiar
|
||||
|
||||
return pip_familiar is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _check_capability_lightning() -> bool:
|
||||
"""Check if Lightning payments are available."""
|
||||
# Lightning is currently disabled per health.py
|
||||
# Returns False until properly re-implemented
|
||||
return False
|
||||
|
||||
|
||||
def _build_matrix_health_response() -> dict[str, Any]:
    """Build the Matrix health response with capability checks.

    Performs lightweight checks (<100ms total) to determine which features
    are available. Returns 200 even if some capabilities are degraded.
    """
    probes = {
        "thinking": _check_capability_thinking,
        "memory": _check_capability_memory,
        "bark": _check_capability_bark,
        "familiar": _check_capability_familiar,
        "lightning": _check_capability_lightning,
    }
    capabilities = {name: probe() for name, probe in probes.items()}

    # "ok" only when the core capabilities (thinking, memory, bark) are up;
    # familiar/lightning are optional extras.
    core_ok = (
        capabilities["thinking"] and capabilities["memory"] and capabilities["bark"]
    )

    return {
        "status": "ok" if core_ok else "degraded",
        "version": "1.0.0",
        "capabilities": capabilities,
    }
|
||||
|
||||
|
||||
@matrix_router.get("/health")
async def get_matrix_health() -> JSONResponse:
    """Return health status and capability availability for the Matrix frontend.

    Lets the frontend discover which backend capabilities exist so it can
    show/hide UI elements:
      - thinking: thought bubbles
      - memory: crystal-ball memory search
      - bark: visitor chat responses
      - familiar: Pip the familiar
      - lightning: payment features

    Checks are lightweight (<100ms); the endpoint always answers 200 — a
    degraded backend is signalled in the body, not the status code.

    Response:
      - status: "ok" or "degraded"
      - version: API version string
      - capabilities: dict of feature:bool
    """
    payload = _build_matrix_health_response()
    return JSONResponse(
        content=payload,
        status_code=200,  # Always 200, even if degraded
        headers={"Cache-Control": "no-cache, no-store"},
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Matrix Memory Search Endpoint — visitors query Timmy's memory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Rate limiting: 1 search per 5 seconds per IP
|
||||
_MEMORY_SEARCH_RATE_LIMIT_SECONDS = 5
|
||||
_memory_search_last_request: dict[str, float] = {}
|
||||
_MAX_MEMORY_RESULTS = 5
|
||||
_MAX_MEMORY_TEXT_LENGTH = 200
|
||||
|
||||
|
||||
def _get_client_ip(request) -> str:
|
||||
"""Extract client IP from request, respecting X-Forwarded-For header."""
|
||||
# Check for forwarded IP (when behind proxy)
|
||||
forwarded = request.headers.get("X-Forwarded-For")
|
||||
if forwarded:
|
||||
# Take the first IP in the chain
|
||||
return forwarded.split(",")[0].strip()
|
||||
# Fall back to direct client IP
|
||||
if request.client:
|
||||
return request.client.host
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _build_matrix_memory_response(
|
||||
memories: list,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build the Matrix memory search response.
|
||||
|
||||
Formats memory entries for Matrix display:
|
||||
- text: truncated to 200 characters
|
||||
- relevance: 0-1 score from relevance_score
|
||||
- created_at: ISO-8601 timestamp
|
||||
- context_type: the memory type
|
||||
|
||||
Results are capped at _MAX_MEMORY_RESULTS.
|
||||
"""
|
||||
results = []
|
||||
for mem in memories[:_MAX_MEMORY_RESULTS]:
|
||||
text = mem.content
|
||||
if len(text) > _MAX_MEMORY_TEXT_LENGTH:
|
||||
text = text[:_MAX_MEMORY_TEXT_LENGTH] + "..."
|
||||
|
||||
results.append(
|
||||
{
|
||||
"text": text,
|
||||
"relevance": round(mem.relevance_score or 0.0, 4),
|
||||
"created_at": mem.timestamp,
|
||||
"context_type": mem.context_type,
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
@matrix_router.get("/memory/search")
async def get_matrix_memory_search(request: Request, q: str | None = None) -> JSONResponse:
    """Search Timmy's memory for relevant snippets.

    Allows Matrix visitors to query Timmy's memory ("what do you remember
    about sovereignty?"). Results appear as floating crystal-ball text in
    the Workshop room.

    Query params:
      - q: search query text (required)

    Response: JSON array of memory objects with text (truncated to 200
    chars), relevance (0-1), created_at (ISO-8601), and context_type.

    Rate limited to 1 search per 5 seconds per IP.

    Returns:
      - 200: JSON array of memory results (max 5)
      - 400: missing or empty query parameter
      - 429: rate limit exceeded
    """
    # Guard clause: reject a missing or blank query up front.
    query = (q or "").strip()
    if not query:
        return JSONResponse(
            status_code=400,
            content={"error": "Query parameter 'q' is required"},
        )

    # Per-IP rate limit: one search per window.
    client_ip = _get_client_ip(request)
    now = time.time()
    elapsed = now - _memory_search_last_request.get(client_ip, 0)
    if elapsed < _MEMORY_SEARCH_RATE_LIMIT_SECONDS:
        remaining = _MEMORY_SEARCH_RATE_LIMIT_SECONDS - elapsed
        return JSONResponse(
            status_code=429,
            content={"error": "Rate limit exceeded. Try again later."},
            headers={"Retry-After": str(int(remaining) + 1)},
        )

    _memory_search_last_request[client_ip] = now

    # Best-effort search: any backend failure degrades to an empty result set.
    try:
        hits = search_memories(query, limit=_MAX_MEMORY_RESULTS)
        results = _build_matrix_memory_response(hits)
    except Exception as exc:
        logger.warning("Memory search failed: %s", exc)
        results = []

    return JSONResponse(
        content=results,
        headers={"Cache-Control": "no-cache, no-store"},
    )
|
||||
|
||||
266
src/infrastructure/matrix_config.py
Normal file
266
src/infrastructure/matrix_config.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""Matrix configuration loader utility.
|
||||
|
||||
Provides a typed dataclass for Matrix world configuration and a loader
|
||||
that fetches settings from YAML with sensible defaults.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class PointLight:
    """A single point light in the Matrix world."""

    color: str = "#FFFFFF"
    intensity: float = 1.0
    position: dict[str, float] = field(default_factory=lambda: {"x": 0, "y": 0, "z": 0})

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "PointLight":
        """Create a PointLight from a dictionary with defaults."""
        fallback = cls()
        return cls(
            color=data.get("color", fallback.color),
            intensity=data.get("intensity", fallback.intensity),
            position=data.get("position", fallback.position),
        )
|
||||
|
||||
|
||||
def _default_point_lights_factory() -> list[PointLight]:
    """Factory function for default point lights."""
    # (color, intensity, position) triples for the default three-light rig.
    specs = [
        ("#FFAA55", 1.2, {"x": 0, "y": 5, "z": 0}),  # Warm amber (Workshop)
        ("#3B82F6", 0.8, {"x": -5, "y": 3, "z": -5}),  # Cool blue (Matrix)
        ("#A855F7", 0.6, {"x": 5, "y": 3, "z": 5}),  # Purple accent
    ]
    return [
        PointLight(color=color, intensity=intensity, position=position)
        for color, intensity, position in specs
    ]
|
||||
|
||||
|
||||
@dataclass
class LightingConfig:
    """Lighting configuration for the Matrix world."""

    ambient_color: str = "#FFAA55"  # Warm amber (Workshop warmth)
    ambient_intensity: float = 0.5
    point_lights: list[PointLight] = field(default_factory=_default_point_lights_factory)

    @classmethod
    def from_dict(cls, data: dict[str, Any] | None) -> "LightingConfig":
        """Create a LightingConfig from a dictionary with defaults."""
        source = data if data is not None else {}

        # An absent or empty point_lights list falls back to the default rig.
        raw_lights = source.get("point_lights", [])
        if raw_lights:
            lights = [PointLight.from_dict(item) for item in raw_lights]
        else:
            lights = _default_point_lights_factory()

        return cls(
            ambient_color=source.get("ambient_color", "#FFAA55"),
            ambient_intensity=source.get("ambient_intensity", 0.5),
            point_lights=lights,
        )
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnvironmentConfig:
|
||||
"""Environment settings for the Matrix world."""
|
||||
|
||||
rain_enabled: bool = False
|
||||
starfield_enabled: bool = True
|
||||
fog_color: str = "#0f0f23"
|
||||
fog_density: float = 0.02
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any] | None) -> "EnvironmentConfig":
|
||||
"""Create an EnvironmentConfig from a dictionary with defaults."""
|
||||
if data is None:
|
||||
data = {}
|
||||
return cls(
|
||||
rain_enabled=data.get("rain_enabled", False),
|
||||
starfield_enabled=data.get("starfield_enabled", True),
|
||||
fog_color=data.get("fog_color", "#0f0f23"),
|
||||
fog_density=data.get("fog_density", 0.02),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeaturesConfig:
|
||||
"""Feature toggles for the Matrix world."""
|
||||
|
||||
chat_enabled: bool = True
|
||||
visitor_avatars: bool = True
|
||||
pip_familiar: bool = True
|
||||
workshop_portal: bool = True
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any] | None) -> "FeaturesConfig":
|
||||
"""Create a FeaturesConfig from a dictionary with defaults."""
|
||||
if data is None:
|
||||
data = {}
|
||||
return cls(
|
||||
chat_enabled=data.get("chat_enabled", True),
|
||||
visitor_avatars=data.get("visitor_avatars", True),
|
||||
pip_familiar=data.get("pip_familiar", True),
|
||||
workshop_portal=data.get("workshop_portal", True),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class AgentConfig:
    """Configuration for a single Matrix agent."""

    name: str = ""
    role: str = ""
    enabled: bool = True

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "AgentConfig":
        """Create an AgentConfig from a dictionary with defaults."""
        fallback = cls()
        return cls(
            name=data.get("name", fallback.name),
            role=data.get("role", fallback.role),
            enabled=data.get("enabled", fallback.enabled),
        )
|
||||
|
||||
|
||||
@dataclass
class AgentsConfig:
    """Agent registry configuration."""

    default_count: int = 5
    max_count: int = 20
    agents: list[AgentConfig] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: dict[str, Any] | None) -> "AgentsConfig":
        """Create an AgentsConfig from a dictionary with defaults."""
        source = data if data is not None else {}

        # Empty or missing agent list stays empty — no synthetic agents.
        roster = [AgentConfig.from_dict(entry) for entry in source.get("agents") or []]

        return cls(
            default_count=source.get("default_count", 5),
            max_count=source.get("max_count", 20),
            agents=roster,
        )
|
||||
|
||||
|
||||
@dataclass
class MatrixConfig:
    """Complete Matrix world configuration.

    Bundles lighting, environment, features, and agent settings into a
    single configuration object.
    """

    lighting: LightingConfig = field(default_factory=LightingConfig)
    environment: EnvironmentConfig = field(default_factory=EnvironmentConfig)
    features: FeaturesConfig = field(default_factory=FeaturesConfig)
    agents: AgentsConfig = field(default_factory=AgentsConfig)

    @classmethod
    def from_dict(cls, data: dict[str, Any] | None) -> "MatrixConfig":
        """Create a MatrixConfig from a dictionary with defaults for missing sections."""
        source = data if data is not None else {}
        return cls(
            lighting=LightingConfig.from_dict(source.get("lighting")),
            environment=EnvironmentConfig.from_dict(source.get("environment")),
            features=FeaturesConfig.from_dict(source.get("features")),
            agents=AgentsConfig.from_dict(source.get("agents")),
        )

    def to_dict(self) -> dict[str, Any]:
        """Convert the configuration to a plain dictionary."""
        point_lights = [
            {"color": light.color, "intensity": light.intensity, "position": light.position}
            for light in self.lighting.point_lights
        ]
        agent_rows = [
            {"name": agent.name, "role": agent.role, "enabled": agent.enabled}
            for agent in self.agents.agents
        ]
        env = self.environment
        feats = self.features
        return {
            "lighting": {
                "ambient_color": self.lighting.ambient_color,
                "ambient_intensity": self.lighting.ambient_intensity,
                "point_lights": point_lights,
            },
            "environment": {
                "rain_enabled": env.rain_enabled,
                "starfield_enabled": env.starfield_enabled,
                "fog_color": env.fog_color,
                "fog_density": env.fog_density,
            },
            "features": {
                "chat_enabled": feats.chat_enabled,
                "visitor_avatars": feats.visitor_avatars,
                "pip_familiar": feats.pip_familiar,
                "workshop_portal": feats.workshop_portal,
            },
            "agents": {
                "default_count": self.agents.default_count,
                "max_count": self.agents.max_count,
                "agents": agent_rows,
            },
        }
|
||||
|
||||
|
||||
def load_from_yaml(path: str | Path) -> MatrixConfig:
    """Load Matrix configuration from a YAML file.

    Missing keys are filled with sensible defaults. If the file cannot be
    read or parsed, returns a fully default configuration.

    Args:
        path: Path to the YAML configuration file.

    Returns:
        A MatrixConfig instance with loaded or default values.
    """
    config_path = Path(path)

    if not config_path.exists():
        logger.warning("Matrix config file not found: %s, using defaults", config_path)
        return MatrixConfig()

    try:
        raw_data = yaml.safe_load(config_path.read_text(encoding="utf-8"))
    except yaml.YAMLError as exc:
        logger.warning("Matrix config YAML parse error: %s, using defaults", exc)
        return MatrixConfig()
    except OSError as exc:
        logger.warning("Matrix config read error: %s, using defaults", exc)
        return MatrixConfig()

    # A YAML document that isn't a mapping (list, scalar, empty) is unusable.
    if not isinstance(raw_data, dict):
        logger.warning("Matrix config invalid format, using defaults")
        return MatrixConfig()

    return MatrixConfig.from_dict(raw_data)
|
||||
@@ -5,9 +5,152 @@ into the camelCase world-state payload consumed by the Workshop 3D renderer
|
||||
and WebSocket gateway.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default Pip familiar state (used when familiar module unavailable)
DEFAULT_PIP_STATE = {
    "name": "Pip",
    "mood": "sleepy",
    "energy": 0.5,
    "color": "0x00b450",  # emerald green
    "trail_color": "0xdaa520",  # gold
}


def _get_familiar_state() -> dict:
    """Get Pip familiar state from familiar module, with graceful fallback.

    Returns a dict with name, mood, energy, color, and trail_color.
    Falls back to default state if familiar module unavailable or raises.
    """
    try:
        from timmy.familiar import pip_familiar

        snapshot = pip_familiar.snapshot()
        # Map PipSnapshot fields onto the agent_state shape. Pip doesn't
        # track energy or colors yet, so those stay at the defaults.
        state = DEFAULT_PIP_STATE.copy()
        state["name"] = snapshot.name
        state["mood"] = snapshot.state
        return state
    except Exception as exc:
        logger.warning("Familiar state unavailable, using default: %s", exc)
        return DEFAULT_PIP_STATE.copy()
|
||||
|
||||
|
||||
# Valid bark styles for Matrix protocol
|
||||
BARK_STYLES = {"speech", "thought", "whisper", "shout"}
|
||||
|
||||
|
||||
def produce_bark(agent_id: str, text: str, reply_to: str = None, style: str = "speech") -> dict:
|
||||
"""Format a chat response as a Matrix bark message.
|
||||
|
||||
Barks appear as floating text above agents in the Matrix 3D world with
|
||||
typing animation. This function formats the text for the Matrix protocol.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
agent_id:
|
||||
Unique identifier for the agent (e.g. ``"timmy"``).
|
||||
text:
|
||||
The chat response text to display as a bark.
|
||||
reply_to:
|
||||
Optional message ID or reference this bark is replying to.
|
||||
style:
|
||||
Visual style of the bark. One of: "speech" (default), "thought",
|
||||
"whisper", "shout". Invalid styles fall back to "speech".
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
Bark message with keys ``type``, ``agent_id``, ``data`` (containing
|
||||
``text``, ``reply_to``, ``style``), and ``ts``.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> produce_bark("timmy", "Hello world!")
|
||||
{
|
||||
"type": "bark",
|
||||
"agent_id": "timmy",
|
||||
"data": {"text": "Hello world!", "reply_to": None, "style": "speech"},
|
||||
"ts": 1742529600,
|
||||
}
|
||||
"""
|
||||
# Validate and normalize style
|
||||
if style not in BARK_STYLES:
|
||||
style = "speech"
|
||||
|
||||
# Truncate text to 280 characters (bark, not essay)
|
||||
truncated_text = text[:280] if text else ""
|
||||
|
||||
return {
|
||||
"type": "bark",
|
||||
"agent_id": agent_id,
|
||||
"data": {
|
||||
"text": truncated_text,
|
||||
"reply_to": reply_to,
|
||||
"style": style,
|
||||
},
|
||||
"ts": int(time.time()),
|
||||
}
|
||||
|
||||
|
||||
def produce_thought(
|
||||
agent_id: str, thought_text: str, thought_id: int, chain_id: str = None
|
||||
) -> dict:
|
||||
"""Format a thinking engine thought as a Matrix thought message.
|
||||
|
||||
Thoughts appear as subtle floating text in the 3D world, streaming from
|
||||
Timmy's thinking engine (/thinking/api). This function wraps thoughts in
|
||||
Matrix protocol format.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
agent_id:
|
||||
Unique identifier for the agent (e.g. ``"timmy"``).
|
||||
thought_text:
|
||||
The thought text to display. Truncated to 500 characters.
|
||||
thought_id:
|
||||
Unique identifier for this thought (sequence number).
|
||||
chain_id:
|
||||
Optional chain identifier grouping related thoughts.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
Thought message with keys ``type``, ``agent_id``, ``data`` (containing
|
||||
``text``, ``thought_id``, ``chain_id``), and ``ts``.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> produce_thought("timmy", "Considering the options...", 42, "chain-123")
|
||||
{
|
||||
"type": "thought",
|
||||
"agent_id": "timmy",
|
||||
"data": {"text": "Considering the options...", "thought_id": 42, "chain_id": "chain-123"},
|
||||
"ts": 1742529600,
|
||||
}
|
||||
"""
|
||||
# Truncate text to 500 characters (thoughts can be longer than barks)
|
||||
truncated_text = thought_text[:500] if thought_text else ""
|
||||
|
||||
return {
|
||||
"type": "thought",
|
||||
"agent_id": agent_id,
|
||||
"data": {
|
||||
"text": truncated_text,
|
||||
"thought_id": thought_id,
|
||||
"chain_id": chain_id,
|
||||
},
|
||||
"ts": int(time.time()),
|
||||
}
|
||||
|
||||
|
||||
def serialize_presence(presence: dict) -> dict:
|
||||
"""Transform an ADR-023 presence dict into the world-state API shape.
|
||||
@@ -93,6 +236,98 @@ def produce_agent_state(agent_id: str, presence: dict) -> dict:
|
||||
"mood": presence.get("mood", "calm"),
|
||||
"energy": presence.get("energy", 0.5),
|
||||
"bark": presence.get("bark", ""),
|
||||
"familiar": _get_familiar_state(),
|
||||
},
|
||||
"ts": int(time.time()),
|
||||
}
|
||||
|
||||
|
||||
def produce_system_status() -> dict:
    """Generate a system_status message for the Matrix.

    Gathers best-effort health metrics — agents online, visitor count,
    uptime, thinking-engine status, and memory count. Each probe degrades
    independently to a zero/False default if its subsystem is unavailable.

    Returns
    -------
    dict
        Message with keys ``type``, ``data`` (containing ``agents_online``,
        ``visitors``, ``uptime_seconds``, ``thinking_active``,
        ``memory_count``), and ``ts``.
    """
    # Agents whose status is neither "offline" nor empty count as online.
    online = 0
    try:
        from timmy.agents.loader import list_agents

        online = sum(
            1 for agent in list_agents() if agent.get("status", "") not in ("offline", "")
        )
    except Exception as exc:
        logger.debug("Failed to count agents: %s", exc)

    # Visitors = currently connected WebSocket clients.
    visitor_count = 0
    try:
        from dashboard.routes.world import _ws_clients

        visitor_count = len(_ws_clients)
    except Exception as exc:
        logger.debug("Failed to count visitors: %s", exc)

    # Uptime relative to the recorded application start time.
    uptime = 0
    try:
        from datetime import UTC

        from config import APP_START_TIME

        uptime = int((datetime.now(UTC) - APP_START_TIME).total_seconds())
    except Exception as exc:
        logger.debug("Failed to calculate uptime: %s", exc)

    # Thinking is active when enabled in settings and the engine exists.
    thinking = False
    try:
        from config import settings
        from timmy.thinking import thinking_engine

        thinking = settings.thinking_enabled and thinking_engine is not None
    except Exception as exc:
        logger.debug("Failed to check thinking status: %s", exc)

    # Total entries in the vector memory store.
    memories = 0
    try:
        from timmy.memory_system import get_memory_stats

        memories = get_memory_stats().get("total_entries", 0)
    except Exception as exc:
        logger.debug("Failed to count memories: %s", exc)

    return {
        "type": "system_status",
        "data": {
            "agents_online": online,
            "visitors": visitor_count,
            "uptime_seconds": uptime,
            "thinking_active": thinking,
            "memory_count": memories,
        },
        "ts": int(time.time()),
    }
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
163
tests/loop/test_loop_guard_validate.py
Normal file
163
tests/loop/test_loop_guard_validate.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""Tests for cycle_result.json validation in loop_guard.
|
||||
|
||||
Covers validate_cycle_result(), _load_cycle_result(), and _is_issue_open().
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import scripts.loop_guard as lg
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _isolate(tmp_path, monkeypatch):
    """Redirect loop_guard paths to tmp_path for isolation."""
    # Point module-level constants at per-test locations so no test ever
    # touches the real .loop directory or a live Gitea instance.
    monkeypatch.setattr(lg, "CYCLE_RESULT_FILE", tmp_path / "cycle_result.json")
    monkeypatch.setattr(lg, "CYCLE_DURATION", 300)  # stale threshold = 2x this
    monkeypatch.setattr(lg, "GITEA_API", "http://test:3000/api/v1")
    monkeypatch.setattr(lg, "REPO_SLUG", "owner/repo")
|
||||
|
||||
|
||||
def _write_cr(tmp_path, data: dict, age_seconds: float = 0) -> Path:
|
||||
"""Write a cycle_result.json and optionally backdate it."""
|
||||
p = tmp_path / "cycle_result.json"
|
||||
p.write_text(json.dumps(data))
|
||||
if age_seconds:
|
||||
mtime = time.time() - age_seconds
|
||||
import os
|
||||
|
||||
os.utime(p, (mtime, mtime))
|
||||
return p
|
||||
|
||||
|
||||
# --- _load_cycle_result ---


def test_load_cycle_result_missing(tmp_path):
    # No cycle_result.json on disk -> loader returns an empty dict.
    assert lg._load_cycle_result() == {}


def test_load_cycle_result_valid(tmp_path):
    # Well-formed JSON round-trips unchanged.
    _write_cr(tmp_path, {"issue": 42, "type": "fix"})
    assert lg._load_cycle_result() == {"issue": 42, "type": "fix"}


def test_load_cycle_result_markdown_fenced(tmp_path):
    # Content wrapped in ```json fences must still parse (fences stripped).
    p = tmp_path / "cycle_result.json"
    p.write_text('```json\n{"issue": 99}\n```')
    assert lg._load_cycle_result() == {"issue": 99}


def test_load_cycle_result_malformed(tmp_path):
    # Unparseable content degrades to an empty dict rather than raising.
    p = tmp_path / "cycle_result.json"
    p.write_text("not json at all")
    assert lg._load_cycle_result() == {}
|
||||
|
||||
|
||||
# --- _is_issue_open ---
|
||||
|
||||
|
||||
def test_is_issue_open_true(monkeypatch):
    """An API payload with state == "open" maps to True."""
    import io

    monkeypatch.setattr(lg, "_get_token", lambda: "tok")
    # io.BytesIO already provides read() plus the context-manager protocol,
    # so the hand-rolled FakeResp class is unnecessary.
    body = io.BytesIO(json.dumps({"state": "open"}).encode())
    with patch("urllib.request.urlopen", return_value=body):
        assert lg._is_issue_open(42) is True
|
||||
|
||||
|
||||
def test_is_issue_open_closed(monkeypatch):
    """An API payload with state == "closed" maps to False."""
    import io

    monkeypatch.setattr(lg, "_get_token", lambda: "tok")
    # io.BytesIO already provides read() plus the context-manager protocol,
    # so the hand-rolled FakeResp class is unnecessary.
    body = io.BytesIO(json.dumps({"state": "closed"}).encode())
    with patch("urllib.request.urlopen", return_value=body):
        assert lg._is_issue_open(42) is False
|
||||
|
||||
|
||||
def test_is_issue_open_no_token(monkeypatch):
    """Without a token the check is indeterminate (None)."""
    monkeypatch.setattr(lg, "_get_token", lambda: "")
    assert lg._is_issue_open(42) is None
|
||||
|
||||
|
||||
def test_is_issue_open_api_error(monkeypatch):
    """Network-level failures are reported as indeterminate (None)."""
    monkeypatch.setattr(lg, "_get_token", lambda: "tok")
    with patch("urllib.request.urlopen", side_effect=OSError("timeout")):
        assert lg._is_issue_open(42) is None
|
||||
|
||||
|
||||
# --- validate_cycle_result ---
|
||||
|
||||
|
||||
def test_validate_no_file(tmp_path):
    """Absent file: nothing to prune, returns False without raising."""
    assert lg.validate_cycle_result() is False
|
||||
|
||||
|
||||
def test_validate_fresh_file_open_issue(tmp_path, monkeypatch):
    """A recent file tied to a still-open issue survives validation."""
    monkeypatch.setattr(lg, "_is_issue_open", lambda n: True)
    cr = _write_cr(tmp_path, {"issue": 10})
    assert lg.validate_cycle_result() is False
    assert cr.exists()
|
||||
|
||||
|
||||
def test_validate_stale_file_removed(tmp_path):
    """Anything older than 2x CYCLE_DURATION gets pruned."""
    cr = _write_cr(tmp_path, {"issue": 10}, age_seconds=700)
    assert lg.validate_cycle_result() is True
    assert not cr.exists()
|
||||
|
||||
|
||||
def test_validate_fresh_file_closed_issue(tmp_path, monkeypatch):
    """A recent file pointing at a closed issue gets pruned."""
    monkeypatch.setattr(lg, "_is_issue_open", lambda n: False)
    cr = _write_cr(tmp_path, {"issue": 10})
    assert lg.validate_cycle_result() is True
    assert not cr.exists()
|
||||
|
||||
|
||||
def test_validate_api_failure_keeps_file(tmp_path, monkeypatch):
    """Indeterminate issue state (API down) must not destroy the file."""
    monkeypatch.setattr(lg, "_is_issue_open", lambda n: None)
    cr = _write_cr(tmp_path, {"issue": 10})
    assert lg.validate_cycle_result() is False
    assert cr.exists()
|
||||
|
||||
|
||||
def test_validate_no_issue_field(tmp_path):
    """Without an issue number only the age check can prune the file."""
    cr = _write_cr(tmp_path, {"type": "fix"})
    assert lg.validate_cycle_result() is False
    assert cr.exists()
|
||||
|
||||
|
||||
def test_validate_stale_threshold_boundary(tmp_path, monkeypatch):
    """A file younger than the 2x CYCLE_DURATION (600 s) threshold is kept."""
    monkeypatch.setattr(lg, "_is_issue_open", lambda n: True)
    # Stay comfortably below the 600 s threshold: the original 599 s left
    # only ~1 s of slack, so a slow test run could tip the file over into
    # "stale" and make this test flaky.
    cr = _write_cr(tmp_path, {"issue": 10}, age_seconds=590)
    assert lg.validate_cycle_result() is False
    assert cr.exists()
|
||||
331
tests/unit/test_matrix_config.py
Normal file
331
tests/unit/test_matrix_config.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""Tests for the matrix configuration loader utility."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from infrastructure.matrix_config import (
|
||||
AgentConfig,
|
||||
AgentsConfig,
|
||||
EnvironmentConfig,
|
||||
FeaturesConfig,
|
||||
LightingConfig,
|
||||
MatrixConfig,
|
||||
PointLight,
|
||||
load_from_yaml,
|
||||
)
|
||||
|
||||
|
||||
class TestPointLight:
    """Behavioural tests for the PointLight dataclass."""

    def test_default_values(self):
        """A bare PointLight is white, full intensity, at the origin."""
        light = PointLight()
        assert light.color == "#FFFFFF"
        assert light.intensity == 1.0
        assert light.position == {"x": 0, "y": 0, "z": 0}

    def test_from_dict_full(self):
        """Every field supplied in the mapping lands on the instance."""
        light = PointLight.from_dict(
            {
                "color": "#FF0000",
                "intensity": 2.5,
                "position": {"x": 1, "y": 2, "z": 3},
            }
        )
        assert light.color == "#FF0000"
        assert light.intensity == 2.5
        assert light.position == {"x": 1, "y": 2, "z": 3}

    def test_from_dict_partial(self):
        """Omitted fields fall back to the dataclass defaults."""
        light = PointLight.from_dict({"color": "#00FF00"})
        assert light.color == "#00FF00"
        assert light.intensity == 1.0
        assert light.position == {"x": 0, "y": 0, "z": 0}
|
||||
|
||||
|
||||
class TestLightingConfig:
    """Behavioural tests for the LightingConfig dataclass."""

    def test_default_values(self):
        """Defaults blend warm Workshop amber with Matrix accent lights."""
        lighting = LightingConfig()
        assert lighting.ambient_color == "#FFAA55"  # warm amber (Workshop)
        assert lighting.ambient_intensity == 0.5
        assert len(lighting.point_lights) == 3
        # Center light is warm amber; second is the cool Matrix blue.
        assert lighting.point_lights[0].color == "#FFAA55"
        assert lighting.point_lights[1].color == "#3B82F6"

    def test_from_dict_full(self):
        """All supplied fields, including nested lights, are loaded."""
        lighting = LightingConfig.from_dict(
            {
                "ambient_color": "#123456",
                "ambient_intensity": 0.8,
                "point_lights": [
                    {"color": "#ABCDEF", "intensity": 1.5, "position": {"x": 1, "y": 1, "z": 1}}
                ],
            }
        )
        assert lighting.ambient_color == "#123456"
        assert lighting.ambient_intensity == 0.8
        assert len(lighting.point_lights) == 1
        assert lighting.point_lights[0].color == "#ABCDEF"

    def test_from_dict_empty_list_uses_defaults(self):
        """An explicitly empty point_lights list falls back to defaults."""
        lighting = LightingConfig.from_dict({"ambient_color": "#000000", "point_lights": []})
        assert lighting.ambient_color == "#000000"
        assert len(lighting.point_lights) == 3  # default trio

    def test_from_dict_none(self):
        """Passing None yields the stock configuration."""
        lighting = LightingConfig.from_dict(None)
        assert lighting.ambient_color == "#FFAA55"
        assert len(lighting.point_lights) == 3
|
||||
|
||||
|
||||
class TestEnvironmentConfig:
    """Behavioural tests for the EnvironmentConfig dataclass."""

    def test_default_values(self):
        """Defaults: no rain, starfield on, dark blue fog."""
        env = EnvironmentConfig()
        assert env.rain_enabled is False
        assert env.starfield_enabled is True  # Matrix starfield
        assert env.fog_color == "#0f0f23"
        assert env.fog_density == 0.02

    def test_from_dict_full(self):
        """Every supplied field overrides its default."""
        env = EnvironmentConfig.from_dict(
            {
                "rain_enabled": True,
                "starfield_enabled": False,
                "fog_color": "#FFFFFF",
                "fog_density": 0.1,
            }
        )
        assert env.rain_enabled is True
        assert env.starfield_enabled is False
        assert env.fog_color == "#FFFFFF"
        assert env.fog_density == 0.1

    def test_from_dict_partial(self):
        """Keys absent from the mapping keep their defaults."""
        env = EnvironmentConfig.from_dict({"rain_enabled": True})
        assert env.rain_enabled is True
        assert env.starfield_enabled is True  # default
        assert env.fog_color == "#0f0f23"
|
||||
|
||||
|
||||
class TestFeaturesConfig:
    """Behavioural tests for the FeaturesConfig dataclass."""

    def test_default_values_all_enabled(self):
        """Every feature flag defaults to on."""
        features = FeaturesConfig()
        for flag in (
            features.chat_enabled,
            features.visitor_avatars,
            features.pip_familiar,
            features.workshop_portal,
        ):
            assert flag is True

    def test_from_dict_full(self):
        """All flags can be switched off via the mapping."""
        features = FeaturesConfig.from_dict(
            {
                "chat_enabled": False,
                "visitor_avatars": False,
                "pip_familiar": False,
                "workshop_portal": False,
            }
        )
        assert features.chat_enabled is False
        assert features.visitor_avatars is False
        assert features.pip_familiar is False
        assert features.workshop_portal is False

    def test_from_dict_partial(self):
        """Only the supplied flag changes; the rest stay enabled."""
        features = FeaturesConfig.from_dict({"chat_enabled": False})
        assert features.chat_enabled is False
        assert features.visitor_avatars is True  # default
        assert features.pip_familiar is True
        assert features.workshop_portal is True
|
||||
|
||||
|
||||
class TestAgentConfig:
    """Behavioural tests for the AgentConfig dataclass."""

    def test_default_values(self):
        """A bare AgentConfig is an enabled, unnamed agent."""
        agent = AgentConfig()
        assert agent.name == ""
        assert agent.role == ""
        assert agent.enabled is True

    def test_from_dict_full(self):
        """All mapping fields land on the instance."""
        agent = AgentConfig.from_dict({"name": "Timmy", "role": "guide", "enabled": False})
        assert agent.name == "Timmy"
        assert agent.role == "guide"
        assert agent.enabled is False
|
||||
|
||||
|
||||
class TestAgentsConfig:
    """Behavioural tests for the AgentsConfig dataclass."""

    def test_default_values(self):
        """Defaults: five agents, capped at twenty, empty roster."""
        roster = AgentsConfig()
        assert roster.default_count == 5
        assert roster.max_count == 20
        assert roster.agents == []

    def test_from_dict_with_agents(self):
        """Nested agent mappings become AgentConfig entries."""
        roster = AgentsConfig.from_dict(
            {
                "default_count": 10,
                "max_count": 50,
                "agents": [
                    {"name": "Timmy", "role": "guide", "enabled": True},
                    {"name": "Helper", "role": "assistant"},
                ],
            }
        )
        assert roster.default_count == 10
        assert roster.max_count == 50
        assert len(roster.agents) == 2
        assert roster.agents[0].name == "Timmy"
        assert roster.agents[1].enabled is True  # enabled defaults on
|
||||
|
||||
|
||||
class TestMatrixConfig:
    """Behavioural tests for the top-level MatrixConfig dataclass."""

    def test_default_values(self):
        """Each section is the right type with blended defaults."""
        config = MatrixConfig()
        assert isinstance(config.lighting, LightingConfig)
        assert isinstance(config.environment, EnvironmentConfig)
        assert isinstance(config.features, FeaturesConfig)
        assert isinstance(config.agents, AgentsConfig)
        # Workshop warmth + Matrix starfield blend.
        assert config.lighting.ambient_color == "#FFAA55"
        assert config.environment.starfield_enabled is True
        assert config.features.chat_enabled is True

    def test_from_dict_full(self):
        """Every section in the mapping is routed to its dataclass."""
        config = MatrixConfig.from_dict(
            {
                "lighting": {"ambient_color": "#000000"},
                "environment": {"rain_enabled": True},
                "features": {"chat_enabled": False},
                "agents": {"default_count": 3},
            }
        )
        assert config.lighting.ambient_color == "#000000"
        assert config.environment.rain_enabled is True
        assert config.features.chat_enabled is False
        assert config.agents.default_count == 3

    def test_from_dict_partial(self):
        """Sections absent from the mapping keep their defaults."""
        config = MatrixConfig.from_dict({"lighting": {"ambient_color": "#111111"}})
        assert config.lighting.ambient_color == "#111111"
        assert config.environment.starfield_enabled is True  # default
        assert config.features.pip_familiar is True  # default

    def test_from_dict_none(self):
        """None produces a fully-defaulted configuration."""
        config = MatrixConfig.from_dict(None)
        assert config.lighting.ambient_color == "#FFAA55"
        assert config.features.chat_enabled is True

    def test_to_dict_roundtrip(self):
        """to_dict emits a plain dict containing all four sections."""
        serialized = MatrixConfig().to_dict()
        assert isinstance(serialized, dict)
        for section in ("lighting", "environment", "features", "agents"):
            assert section in serialized
        # Nested point lights survive serialization.
        assert len(serialized["lighting"]["point_lights"]) == 3
|
||||
|
||||
|
||||
class TestLoadFromYaml:
    """Behavioural tests for load_from_yaml()."""

    def test_loads_valid_yaml(self, tmp_path: Path):
        """A well-formed YAML file is parsed into a MatrixConfig."""
        target = tmp_path / "matrix.yaml"
        target.write_text(
            yaml.safe_dump(
                {
                    "lighting": {"ambient_color": "#TEST11"},
                    "features": {"chat_enabled": False},
                }
            )
        )

        config = load_from_yaml(target)
        assert config.lighting.ambient_color == "#TEST11"
        assert config.features.chat_enabled is False

    def test_missing_file_returns_defaults(self, tmp_path: Path):
        """A nonexistent path yields the default configuration."""
        config = load_from_yaml(tmp_path / "nonexistent.yaml")
        assert config.lighting.ambient_color == "#FFAA55"
        assert config.features.chat_enabled is True

    def test_empty_file_returns_defaults(self, tmp_path: Path):
        """A zero-byte file yields the default configuration."""
        target = tmp_path / "empty.yaml"
        target.write_text("")
        assert load_from_yaml(target).lighting.ambient_color == "#FFAA55"

    def test_invalid_yaml_returns_defaults(self, tmp_path: Path):
        """Syntactically broken YAML falls back to defaults."""
        target = tmp_path / "invalid.yaml"
        target.write_text("not: valid: yaml: [")
        config = load_from_yaml(target)
        assert config.lighting.ambient_color == "#FFAA55"
        assert config.features.chat_enabled is True

    def test_non_dict_yaml_returns_defaults(self, tmp_path: Path):
        """YAML that parses to a non-mapping falls back to defaults."""
        target = tmp_path / "list.yaml"
        target.write_text("- item1\n- item2")
        assert load_from_yaml(target).lighting.ambient_color == "#FFAA55"

    def test_loads_actual_config_file(self):
        """The checked-in config/matrix.yaml parses with expected values."""
        config_path = Path(__file__).parent.parent.parent / "config" / "matrix.yaml"
        if not config_path.exists():
            pytest.skip("config/matrix.yaml not found")

        config = load_from_yaml(config_path)
        assert config.lighting.ambient_color == "#FFAA55"
        assert len(config.lighting.point_lights) == 3
        assert config.environment.starfield_enabled is True
        assert config.features.workshop_portal is True

    def test_str_path_accepted(self, tmp_path: Path):
        """A str path works the same as a Path object."""
        target = tmp_path / "matrix.yaml"
        target.write_text(yaml.safe_dump({"lighting": {"ambient_intensity": 0.9}}))

        assert load_from_yaml(str(target)).lighting.ambient_intensity == 0.9
|
||||
@@ -4,7 +4,15 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.presence import produce_agent_state, serialize_presence
|
||||
from infrastructure.presence import (
|
||||
DEFAULT_PIP_STATE,
|
||||
_get_familiar_state,
|
||||
produce_agent_state,
|
||||
produce_bark,
|
||||
produce_system_status,
|
||||
produce_thought,
|
||||
serialize_presence,
|
||||
)
|
||||
|
||||
|
||||
class TestSerializePresence:
|
||||
@@ -163,3 +171,332 @@ class TestProduceAgentState:
|
||||
"""When display_name is missing, it's derived from agent_id.title()."""
|
||||
data = produce_agent_state("spark", {})["data"]
|
||||
assert data["display_name"] == "Spark"
|
||||
|
||||
def test_familiar_in_data(self):
|
||||
"""agent_state.data includes familiar field with required keys."""
|
||||
data = produce_agent_state("timmy", {})["data"]
|
||||
|
||||
assert "familiar" in data
|
||||
familiar = data["familiar"]
|
||||
assert familiar["name"] == "Pip"
|
||||
assert "mood" in familiar
|
||||
assert "energy" in familiar
|
||||
assert familiar["color"] == "0x00b450"
|
||||
assert familiar["trail_color"] == "0xdaa520"
|
||||
|
||||
def test_familiar_has_all_required_fields(self):
|
||||
"""familiar dict contains all required fields per acceptance criteria."""
|
||||
data = produce_agent_state("timmy", {})["data"]
|
||||
familiar = data["familiar"]
|
||||
|
||||
required_fields = {"name", "mood", "energy", "color", "trail_color"}
|
||||
assert set(familiar.keys()) >= required_fields
|
||||
|
||||
|
||||
class TestFamiliarState:
    """Tests for _get_familiar_state() — Pip familiar state retrieval."""

    def test_get_familiar_state_returns_dict(self):
        """The helper always hands back a dict."""
        assert isinstance(_get_familiar_state(), dict)

    def test_get_familiar_state_has_required_fields(self):
        """name/mood/energy/color/trail_color are all populated."""
        state = _get_familiar_state()

        assert state["name"] == "Pip"
        assert state["color"] == "0x00b450"
        assert state["trail_color"] == "0xdaa520"
        assert "mood" in state
        assert isinstance(state["energy"], (int, float))

    def test_default_pip_state_constant(self):
        """DEFAULT_PIP_STATE ships the documented fallback values."""
        expected = {
            "name": "Pip",
            "mood": "sleepy",
            "energy": 0.5,
            "color": "0x00b450",
            "trail_color": "0xdaa520",
        }
        for key, value in expected.items():
            assert DEFAULT_PIP_STATE[key] == value

    @patch("infrastructure.presence.logger")
    def test_get_familiar_state_fallback_on_exception(self, mock_logger):
        """A failing familiar module degrades to defaults and logs a warning."""
        # The familiar module is imported inside the function, so patch it there.
        with patch("timmy.familiar.pip_familiar.snapshot") as snap:
            snap.side_effect = RuntimeError("Pip is napping")
            state = _get_familiar_state()

        assert state["name"] == "Pip"
        assert state["mood"] == "sleepy"
        mock_logger.warning.assert_called_once()
        assert "Pip is napping" in str(mock_logger.warning.call_args)
|
||||
|
||||
|
||||
class TestProduceBark:
    """Tests for produce_bark() — Matrix bark message producer."""

    @patch("infrastructure.presence.time")
    def test_full_message_structure(self, mock_time):
        """Envelope carries type, agent_id, data, and ts."""
        mock_time.time.return_value = 1742529600
        msg = produce_bark("timmy", "Hello world!")

        assert msg["type"] == "bark"
        assert msg["agent_id"] == "timmy"
        assert msg["ts"] == 1742529600
        assert isinstance(msg["data"], dict)

    def test_data_fields(self):
        """Payload carries text, reply_to, and style."""
        payload = produce_bark("timmy", "Hello world!", reply_to="msg-123", style="shout")["data"]

        assert payload["text"] == "Hello world!"
        assert payload["reply_to"] == "msg-123"
        assert payload["style"] == "shout"

    def test_default_style_is_speech(self):
        """Omitted style falls back to 'speech'."""
        assert produce_bark("timmy", "Hello!")["data"]["style"] == "speech"

    def test_default_reply_to_is_none(self):
        """Omitted reply_to falls back to None."""
        assert produce_bark("timmy", "Hello!")["data"]["reply_to"] is None

    def test_text_truncated_to_280_chars(self):
        """Overlong text is clipped at 280 characters."""
        clipped = produce_bark("timmy", "A" * 500)["data"]["text"]
        assert clipped == "A" * 280

    def test_text_exactly_280_chars_not_truncated(self):
        """Text at exactly the 280-char cap passes through untouched."""
        boundary = "B" * 280
        assert produce_bark("timmy", boundary)["data"]["text"] == boundary

    def test_text_shorter_than_280_not_padded(self):
        """Short text is neither clipped nor padded."""
        assert produce_bark("timmy", "Short")["data"]["text"] == "Short"

    @pytest.mark.parametrize(
        ("style", "expected_style"),
        [
            ("speech", "speech"),
            ("thought", "thought"),
            ("whisper", "whisper"),
            ("shout", "shout"),
        ],
    )
    def test_valid_styles_preserved(self, style, expected_style):
        """Recognised styles pass through unchanged."""
        assert produce_bark("timmy", "Hello!", style=style)["data"]["style"] == expected_style

    @pytest.mark.parametrize(
        "invalid_style",
        ["yell", "scream", "", "SPEECH", "Speech", None, 123],
    )
    def test_invalid_style_defaults_to_speech(self, invalid_style):
        """Unrecognised styles (wrong case, wrong type, empty) become 'speech'."""
        assert produce_bark("timmy", "Hello!", style=invalid_style)["data"]["style"] == "speech"

    def test_empty_text_handled(self):
        """An empty string survives as-is."""
        assert produce_bark("timmy", "")["data"]["text"] == ""

    def test_ts_is_unix_timestamp(self):
        """ts is a positive integer Unix timestamp."""
        stamp = produce_bark("timmy", "Hello!")["ts"]
        assert isinstance(stamp, int)
        assert stamp > 0

    def test_agent_id_passed_through(self):
        """The caller's agent_id appears on the envelope."""
        assert produce_bark("spark", "Hello!")["agent_id"] == "spark"

    def test_with_all_parameters(self):
        """All keyword arguments land in the expected slots."""
        msg = produce_bark(
            agent_id="timmy",
            text="Running test suite...",
            reply_to="parent-msg-456",
            style="thought",
        )

        assert msg["type"] == "bark"
        assert msg["agent_id"] == "timmy"
        assert msg["data"]["text"] == "Running test suite..."
        assert msg["data"]["reply_to"] == "parent-msg-456"
        assert msg["data"]["style"] == "thought"
|
||||
|
||||
|
||||
class TestProduceThought:
    """Tests for produce_thought() — Matrix thought message producer."""

    @patch("infrastructure.presence.time")
    def test_full_message_structure(self, mock_time):
        """Envelope carries type, agent_id, data, and ts."""
        mock_time.time.return_value = 1742529600
        msg = produce_thought("timmy", "Considering the options...", 42)

        assert msg["type"] == "thought"
        assert msg["agent_id"] == "timmy"
        assert msg["ts"] == 1742529600
        assert isinstance(msg["data"], dict)

    def test_data_fields(self):
        """Payload carries text, thought_id, and chain_id."""
        payload = produce_thought("timmy", "Considering...", 42, chain_id="chain-123")["data"]

        assert payload["text"] == "Considering..."
        assert payload["thought_id"] == 42
        assert payload["chain_id"] == "chain-123"

    def test_default_chain_id_is_none(self):
        """Omitted chain_id falls back to None."""
        assert produce_thought("timmy", "Thinking...", 1)["data"]["chain_id"] is None

    def test_text_truncated_to_500_chars(self):
        """Overlong text is clipped at 500 characters."""
        clipped = produce_thought("timmy", "A" * 600, 1)["data"]["text"]
        assert clipped == "A" * 500

    def test_text_exactly_500_chars_not_truncated(self):
        """Text at exactly the 500-char cap passes through untouched."""
        boundary = "B" * 500
        assert produce_thought("timmy", boundary, 1)["data"]["text"] == boundary

    def test_text_shorter_than_500_not_padded(self):
        """Short text is neither clipped nor padded."""
        assert produce_thought("timmy", "Short thought", 1)["data"]["text"] == "Short thought"

    def test_empty_text_handled(self):
        """An empty string survives as-is."""
        assert produce_thought("timmy", "", 1)["data"]["text"] == ""

    def test_ts_is_unix_timestamp(self):
        """ts is a positive integer Unix timestamp."""
        stamp = produce_thought("timmy", "Hello!", 1)["ts"]
        assert isinstance(stamp, int)
        assert stamp > 0

    def test_agent_id_passed_through(self):
        """The caller's agent_id appears on the envelope."""
        assert produce_thought("spark", "Hello!", 1)["agent_id"] == "spark"

    def test_thought_id_passed_through(self):
        """The caller's thought_id appears in the payload."""
        assert produce_thought("timmy", "Hello!", 999)["data"]["thought_id"] == 999

    def test_with_all_parameters(self):
        """All keyword arguments land in the expected slots."""
        msg = produce_thought(
            agent_id="timmy",
            thought_text="Analyzing the situation...",
            thought_id=42,
            chain_id="chain-abc",
        )

        assert msg["type"] == "thought"
        assert msg["agent_id"] == "timmy"
        assert msg["data"]["text"] == "Analyzing the situation..."
        assert msg["data"]["thought_id"] == 42
        assert msg["data"]["chain_id"] == "chain-abc"
|
||||
|
||||
|
||||
class TestProduceSystemStatus:
    """Tests for produce_system_status() — Matrix system_status message producer."""

    @patch("infrastructure.presence.time")
    def test_full_message_structure(self, mock_time):
        """Envelope carries type, data, and ts."""
        mock_time.time.return_value = 1742529600
        msg = produce_system_status()

        assert msg["type"] == "system_status"
        assert msg["ts"] == 1742529600
        assert isinstance(msg["data"], dict)

    def test_data_has_required_fields(self):
        """Payload exposes every documented status field."""
        payload = produce_system_status()["data"]

        for field in (
            "agents_online",
            "visitors",
            "uptime_seconds",
            "thinking_active",
            "memory_count",
        ):
            assert field in payload

    def test_data_field_types(self):
        """Each payload field has its expected type."""
        payload = produce_system_status()["data"]

        assert isinstance(payload["agents_online"], int)
        assert isinstance(payload["visitors"], int)
        assert isinstance(payload["uptime_seconds"], int)
        assert isinstance(payload["thinking_active"], bool)
        assert isinstance(payload["memory_count"], int)

    def test_agents_online_is_non_negative(self):
        """agents_online never drops below zero."""
        assert produce_system_status()["data"]["agents_online"] >= 0

    def test_visitors_is_non_negative(self):
        """visitors never drops below zero."""
        assert produce_system_status()["data"]["visitors"] >= 0

    def test_uptime_seconds_is_non_negative(self):
        """uptime_seconds never drops below zero."""
        assert produce_system_status()["data"]["uptime_seconds"] >= 0

    def test_memory_count_is_non_negative(self):
        """memory_count never drops below zero."""
        assert produce_system_status()["data"]["memory_count"] >= 0

    @patch("infrastructure.presence.time")
    def test_ts_is_unix_timestamp(self, mock_time):
        """ts is the integer value of time.time()."""
        mock_time.time.return_value = 1742529600
        stamp = produce_system_status()["ts"]

        assert isinstance(stamp, int)
        assert stamp == 1742529600

    @patch("infrastructure.presence.logger")
    def test_graceful_degradation_on_import_errors(self, mock_logger):
        """Even when internals fail, a well-formed envelope comes back."""
        # NOTE(review): nothing here actually forces an import failure; the
        # logger patch only keeps warnings quiet. Consider injecting a real
        # failure to exercise the fallback path — confirm with the author.
        msg = produce_system_status()

        assert msg["type"] == "system_status"
        assert isinstance(msg["data"], dict)
        assert isinstance(msg["ts"], int)

    def test_returns_dict(self):
        """The producer always returns a plain dict."""
        assert isinstance(produce_system_status(), dict)
|
||||
|
||||
446
tests/unit/test_rate_limit.py
Normal file
446
tests/unit/test_rate_limit.py
Normal file
@@ -0,0 +1,446 @@
|
||||
"""Tests for rate limiting middleware.
|
||||
|
||||
Tests the RateLimiter class and RateLimitMiddleware for correct
|
||||
rate limiting behavior, cleanup, and edge cases.
|
||||
"""
|
||||
|
||||
import time
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from dashboard.middleware.rate_limit import RateLimiter, RateLimitMiddleware
|
||||
|
||||
|
||||
class TestRateLimiter:
    """Unit tests for the RateLimiter class.

    These are deliberately white-box: they inspect the private
    ``_storage`` mapping and call private helpers (``_get_client_ip``,
    ``check_request``) in addition to the public ``is_allowed`` API.
    """

    def test_init_defaults(self):
        """RateLimiter initializes with default values."""
        limiter = RateLimiter()
        assert limiter.requests_per_minute == 30
        assert limiter.cleanup_interval_seconds == 60
        assert limiter._storage == {}

    def test_init_custom_values(self):
        """RateLimiter accepts custom configuration."""
        limiter = RateLimiter(requests_per_minute=60, cleanup_interval_seconds=120)
        assert limiter.requests_per_minute == 60
        assert limiter.cleanup_interval_seconds == 120

    def test_is_allowed_first_request(self):
        """First request from an IP is always allowed."""
        limiter = RateLimiter(requests_per_minute=5)
        allowed, retry_after = limiter.is_allowed("192.168.1.1")
        assert allowed is True
        assert retry_after == 0.0
        # The request is recorded so later calls count against the limit.
        assert "192.168.1.1" in limiter._storage
        assert len(limiter._storage["192.168.1.1"]) == 1

    def test_is_allowed_under_limit(self):
        """Requests under the limit are allowed."""
        limiter = RateLimiter(requests_per_minute=5)

        # Make 4 requests (under limit of 5)
        for _ in range(4):
            allowed, _ = limiter.is_allowed("192.168.1.1")
            assert allowed is True

        assert len(limiter._storage["192.168.1.1"]) == 4

    def test_is_allowed_at_limit(self):
        """Request at the limit is allowed (the limit is inclusive)."""
        limiter = RateLimiter(requests_per_minute=5)

        # Make exactly 5 requests
        for _ in range(5):
            allowed, _ = limiter.is_allowed("192.168.1.1")
            assert allowed is True

        assert len(limiter._storage["192.168.1.1"]) == 5

    def test_is_allowed_over_limit(self):
        """Request over the limit is denied with a positive retry hint."""
        limiter = RateLimiter(requests_per_minute=5)

        # Make 5 requests to hit the limit
        for _ in range(5):
            limiter.is_allowed("192.168.1.1")

        # 6th request should be denied
        allowed, retry_after = limiter.is_allowed("192.168.1.1")
        assert allowed is False
        assert retry_after > 0

    def test_is_allowed_different_ips(self):
        """Rate limiting is per-IP, not global."""
        limiter = RateLimiter(requests_per_minute=5)

        # Hit limit for IP 1
        for _ in range(5):
            limiter.is_allowed("192.168.1.1")

        # IP 1 is now rate limited
        allowed, _ = limiter.is_allowed("192.168.1.1")
        assert allowed is False

        # IP 2 should still be allowed
        allowed, _ = limiter.is_allowed("192.168.1.2")
        assert allowed is True

    def test_window_expiration_allows_new_requests(self):
        """After the window expires, new requests are allowed again."""
        limiter = RateLimiter(requests_per_minute=5)

        # Hit the limit
        for _ in range(5):
            limiter.is_allowed("192.168.1.1")

        # Should be rate limited
        allowed, _ = limiter.is_allowed("192.168.1.1")
        assert allowed is False

        # Simulate time passing by clearing timestamps manually
        # (we can't wait 60 seconds in a test)
        limiter._storage["192.168.1.1"].clear()

        # Should now be allowed again
        allowed, _ = limiter.is_allowed("192.168.1.1")
        assert allowed is True

    def test_cleanup_removes_stale_entries(self):
        """Cleanup removes IPs with no recent requests."""
        limiter = RateLimiter(
            requests_per_minute=5,
            cleanup_interval_seconds=1,  # Short interval for testing
        )

        # Add some requests
        limiter.is_allowed("192.168.1.1")
        limiter.is_allowed("192.168.1.2")

        # Both IPs should be in storage
        assert "192.168.1.1" in limiter._storage
        assert "192.168.1.2" in limiter._storage

        # Manually clear timestamps to simulate stale data
        limiter._storage["192.168.1.1"].clear()
        limiter._last_cleanup = time.time() - 2  # Force cleanup

        # Trigger cleanup via check_request with a mock
        mock_request = Mock()
        mock_request.headers = {}
        mock_request.client = Mock()
        mock_request.client.host = "192.168.1.3"
        mock_request.url.path = "/api/matrix/test"

        limiter.check_request(mock_request)

        # The cleared (stale) IP should be removed...
        assert "192.168.1.1" not in limiter._storage
        # ...while the IP with a recent timestamp survives cleanup.
        assert "192.168.1.2" in limiter._storage

    def test_get_client_ip_direct(self):
        """Extract client IP from a direct connection (request.client.host)."""
        limiter = RateLimiter()

        mock_request = Mock()
        mock_request.headers = {}
        mock_request.client = Mock()
        mock_request.client.host = "192.168.1.100"

        ip = limiter._get_client_ip(mock_request)
        assert ip == "192.168.1.100"

    def test_get_client_ip_x_forwarded_for(self):
        """Extract client IP from the X-Forwarded-For header.

        The first address in the list (the original client) wins over
        the direct connection address.
        """
        limiter = RateLimiter()

        mock_request = Mock()
        mock_request.headers = {"x-forwarded-for": "10.0.0.1, 192.168.1.1"}
        mock_request.client = Mock()
        mock_request.client.host = "192.168.1.100"

        ip = limiter._get_client_ip(mock_request)
        assert ip == "10.0.0.1"

    def test_get_client_ip_x_real_ip(self):
        """Extract client IP from the X-Real-IP header (beats client.host)."""
        limiter = RateLimiter()

        mock_request = Mock()
        mock_request.headers = {"x-real-ip": "10.0.0.5"}
        mock_request.client = Mock()
        mock_request.client.host = "192.168.1.100"

        ip = limiter._get_client_ip(mock_request)
        assert ip == "10.0.0.5"

    def test_get_client_ip_no_client(self):
        """Return 'unknown' when no client info is available."""
        limiter = RateLimiter()

        mock_request = Mock()
        mock_request.headers = {}
        mock_request.client = None

        ip = limiter._get_client_ip(mock_request)
        assert ip == "unknown"
|
||||
|
||||
|
||||
class TestRateLimitMiddleware:
    """Tests for the RateLimitMiddleware class.

    The middleware is exercised directly as an ASGI callable
    (scope/receive/send) rather than through a TestClient; the tests
    capture the raw ASGI messages to inspect status codes and headers.
    """

    @pytest.fixture
    def mock_app(self):
        """Create a minimal ASGI app that always answers 200 with a JSON body."""

        async def app(scope, receive, send):
            response = JSONResponse({"status": "ok"})
            await response(scope, receive, send)

        return app

    @pytest.fixture
    def mock_request(self):
        """Create a mock Request object.

        NOTE(review): not used by the tests currently in this class (they
        build raw ASGI scopes instead) — confirm before removing.
        """
        request = Mock(spec=Request)
        request.url.path = "/api/matrix/test"
        request.headers = {}
        request.client = Mock()
        request.client.host = "192.168.1.1"
        return request

    def test_init_defaults(self, mock_app):
        """Middleware initializes with default values."""
        middleware = RateLimitMiddleware(mock_app)
        assert middleware.path_prefixes == []
        assert middleware.limiter.requests_per_minute == 30

    def test_init_custom_values(self, mock_app):
        """Middleware accepts custom configuration."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=60,
        )
        assert middleware.path_prefixes == ["/api/matrix/"]
        assert middleware.limiter.requests_per_minute == 60

    def test_should_rate_limit_no_prefixes(self, mock_app):
        """With no prefixes configured, all paths are rate limited."""
        middleware = RateLimitMiddleware(mock_app)
        assert middleware._should_rate_limit("/api/matrix/test") is True
        assert middleware._should_rate_limit("/api/other/test") is True
        assert middleware._should_rate_limit("/health") is True

    def test_should_rate_limit_with_prefixes(self, mock_app):
        """With prefixes, only matching paths are rate limited."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/", "/api/public/"],
        )
        assert middleware._should_rate_limit("/api/matrix/test") is True
        assert middleware._should_rate_limit("/api/matrix/") is True
        assert middleware._should_rate_limit("/api/public/data") is True
        assert middleware._should_rate_limit("/api/other/test") is False
        assert middleware._should_rate_limit("/health") is False

    @pytest.mark.asyncio
    async def test_dispatch_allows_matching_path_under_limit(self, mock_app):
        """Request to a matching path under the limit passes through (200)."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=5,
        )

        # Create a proper ASGI scope
        scope = {
            "type": "http",
            "method": "GET",
            "path": "/api/matrix/test",
            "headers": [],
        }

        async def receive():
            return {"type": "http.request", "body": b""}

        # Collect every ASGI message the middleware emits.
        response_body = []

        async def send(message):
            response_body.append(message)

        await middleware(scope, receive, send)

        # Should have sent response messages
        assert len(response_body) > 0
        # Check for 200 status in the response start message
        start_message = next(
            (m for m in response_body if m.get("type") == "http.response.start"), None
        )
        assert start_message is not None
        assert start_message["status"] == 200

    @pytest.mark.asyncio
    async def test_dispatch_skips_non_matching_path(self, mock_app):
        """Request to a non-matching path bypasses rate limiting entirely."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=5,
        )

        scope = {
            "type": "http",
            "method": "GET",
            "path": "/api/other/test",  # Doesn't match /api/matrix/
            "headers": [],
        }

        async def receive():
            return {"type": "http.request", "body": b""}

        response_body = []

        async def send(message):
            response_body.append(message)

        await middleware(scope, receive, send)

        # Should have sent response messages
        assert len(response_body) > 0
        start_message = next(
            (m for m in response_body if m.get("type") == "http.response.start"), None
        )
        assert start_message is not None
        assert start_message["status"] == 200

    @pytest.mark.asyncio
    async def test_dispatch_returns_429_when_rate_limited(self, mock_app):
        """Request over the limit returns a 429 status with a Retry-After header."""
        middleware = RateLimitMiddleware(
            mock_app,
            path_prefixes=["/api/matrix/"],
            requests_per_minute=2,  # Low limit for testing
        )

        # First request - allowed
        test_scope = {
            "type": "http",
            "method": "GET",
            "path": "/api/matrix/test",
            "headers": [],
        }

        async def receive():
            return {"type": "http.request", "body": b""}

        # Helper to capture response messages into the given list.
        def make_send(captured):
            async def send(message):
                captured.append(message)

            return send

        # Make requests to hit the limit
        for _ in range(2):
            response_body = []
            await middleware(test_scope, receive, make_send(response_body))

            start_message = next(
                (m for m in response_body if m.get("type") == "http.response.start"),
                None,
            )
            assert start_message["status"] == 200

        # 3rd request should be rate limited
        response_body = []
        await middleware(test_scope, receive, make_send(response_body))

        start_message = next(
            (m for m in response_body if m.get("type") == "http.response.start"), None
        )
        assert start_message["status"] == 429

        # Check for Retry-After header (raw ASGI headers are byte pairs;
        # accept either casing since casing is implementation-defined here).
        headers = dict(start_message.get("headers", []))
        assert b"retry-after" in headers or b"Retry-After" in headers
|
||||
|
||||
|
||||
class TestRateLimiterIntegration:
    """Integration-style tests for rate limiter behavior."""

    def test_multiple_ips_independent_limits(self):
        """Each IP has its own independent rate limit."""
        limiter = RateLimiter(requests_per_minute=3)

        # Exhaust the per-client budget for two distinct addresses.
        for client in ("10.0.0.1", "10.0.0.2"):
            for _ in range(3):
                limiter.is_allowed(client)

        # Both exhausted clients are now rejected...
        assert limiter.is_allowed("10.0.0.1")[0] is False
        assert limiter.is_allowed("10.0.0.2")[0] is False

        # ...while a fresh client is unaffected.
        assert limiter.is_allowed("10.0.0.3")[0] is True

    def test_timestamp_window_sliding(self):
        """Rate limit window slides correctly as time passes."""
        from collections import deque

        limiter = RateLimiter(requests_per_minute=3)

        # Seed the store with one expired and two in-window timestamps:
        # 100s ago falls outside the 60s window; 50s and 10s are inside it.
        current = time.time()
        limiter._storage["test-ip"] = deque(
            current - age for age in (100, 50, 10)
        )

        # Two live entries remain, so one more request still fits...
        first_allowed, _ = limiter.is_allowed("test-ip")
        assert first_allowed is True

        # ...after which the window holds three and the limit trips.
        second_allowed, _ = limiter.is_allowed("test-ip")
        assert second_allowed is False

    def test_cleanup_preserves_active_ips(self):
        """Cleanup only removes IPs with no recent requests."""
        from collections import deque

        limiter = RateLimiter(
            requests_per_minute=3,
            cleanup_interval_seconds=1,
        )

        current = time.time()
        limiter._storage["active-ip"] = deque([current - 10])  # recent traffic
        limiter._storage["stale-ip"] = deque()  # no timestamps at all
        limiter._storage["old-ip"] = deque([current - 100])  # only expired traffic
        limiter._last_cleanup = current - 2  # make the cleanup interval elapse

        limiter._cleanup_if_needed()

        # Only the entry with in-window activity survives the sweep.
        assert "active-ip" in limiter._storage
        assert "stale-ip" not in limiter._storage
        assert "old-ip" not in limiter._storage
|
||||
Reference in New Issue
Block a user