Compare commits

...

28 Commits

Author SHA1 Message Date
20159cd286 Merge branch 'main' into kimi/issue-627
Some checks failed
Tests / lint (pull_request) Has been cancelled
Tests / test (pull_request) Has been cancelled
2026-03-21 18:01:19 +00:00
dc9f0c04eb [kimi] Add rate limiting middleware for Matrix API endpoints (#683) (#746)
2026-03-21 16:23:16 +00:00
815933953c [kimi] Add WebSocket authentication for Matrix connections (#682) (#744)
2026-03-21 16:14:05 +00:00
d54493a87b [kimi] Add /api/matrix/health endpoint (#685) (#745)
2026-03-21 15:51:29 +00:00
f7404f67ec [kimi] Add system_status message producer (#681) (#743)
2026-03-21 15:13:01 +00:00
5f4580f98d [kimi] Add matrix config loader utility (#680) (#742)
2026-03-21 15:05:06 +00:00
695d1401fd [kimi] Add CORS config for Matrix frontend origin (#679) (#741)
2026-03-21 14:56:43 +00:00
ddadc95e55 [kimi] Add /api/matrix/memory/search endpoint (#678) (#740)
2026-03-21 14:52:31 +00:00
8fc8e0fc3d [kimi] Add /api/matrix/thoughts endpoint for recent thought stream (#677) (#739)
2026-03-21 14:44:46 +00:00
ada0774ca6 [kimi] Add Pip familiar state to agent_state messages (#676) (#738)
2026-03-21 14:37:39 +00:00
2a7b6d5708 [kimi] Add /api/matrix/bark endpoint — HTTP fallback for bark messages (#675) (#737)
2026-03-21 14:32:04 +00:00
9d4ac8e7cc [kimi] Add /api/matrix/config endpoint for world configuration (#674) (#736)
2026-03-21 14:25:19 +00:00
c9601ba32c [kimi] Add /api/matrix/agents endpoint for Matrix visualization (#673) (#735)
2026-03-21 14:18:46 +00:00
646eaefa3e [kimi] Add produce_thought() to stream thinking to Matrix (#672) (#734)
2026-03-21 14:09:19 +00:00
2fa5b23c0c [kimi] Add bark message producer for Matrix bark messages (#671) (#732)
2026-03-21 14:01:42 +00:00
9b57774282 [kimi] feat: pre-cycle state validation for stale cycle_result.json (#661) (#666)
Co-authored-by: Kimi Agent <kimi@timmy.local>
Co-committed-by: Kimi Agent <kimi@timmy.local>
2026-03-21 13:53:11 +00:00
6fc9ac09d2 Merge branch 'main' into kimi/issue-627
2026-03-21 13:52:04 +00:00
62bde03f9e [kimi] feat: add agent_state message producer (#669) (#698)
2026-03-21 13:46:10 +00:00
3474eeb4eb [kimi] refactor: extract presence state serializer from workshop heartbeat (#668) (#697)
2026-03-21 13:41:42 +00:00
e92e151dc3 [kimi] refactor: extract WebSocket message types into shared protocol module (#667) (#696)
2026-03-21 13:37:28 +00:00
1f1bc222e4 [kimi] test: add comprehensive tests for spark modules (#659) (#695)
2026-03-21 13:32:53 +00:00
cc30bdb391 [kimi] test: add comprehensive tests for multimodal.py (#658) (#694)
2026-03-21 04:00:53 +00:00
6f0863b587 [kimi] test: add comprehensive tests for config.py (#648) (#693)
2026-03-21 03:54:54 +00:00
e3d425483d [kimi] fix: add logging to silent except Exception handlers (#646) (#692)
2026-03-21 03:50:26 +00:00
c9445e3056 [kimi] refactor: extract helpers from CSRFMiddleware.dispatch (#628) (#691)
2026-03-21 03:41:09 +00:00
11cd2e3372 [kimi] refactor: extract helpers from chat() (#627) (#686)
2026-03-21 03:33:16 +00:00
kimi
917bf3e404 refactor: extract _read_message_input and _resolve_session_id from chat()
Break up the 72-line chat() function into a thin orchestrator by
extracting stdin reading logic into _read_message_input() and session
resolution into _resolve_session_id().

Fixes #627

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-20 23:33:14 -04:00
9d0f5c778e [loop-cycle-2] fix: resolve endpoint before execution in CSRF middleware (#626) (#656)
2026-03-20 23:05:09 +00:00
34 changed files with 7038 additions and 90 deletions

config/matrix.yaml (new file, 33 lines)

@@ -0,0 +1,33 @@
# Matrix World Configuration
# Serves lighting, environment, and feature settings to the Matrix frontend.
lighting:
ambient_color: "#FFAA55" # Warm amber (Workshop warmth)
ambient_intensity: 0.5
point_lights:
- color: "#FFAA55" # Warm amber (Workshop center light)
intensity: 1.2
position: { x: 0, y: 5, z: 0 }
- color: "#3B82F6" # Cool blue (Matrix accent)
intensity: 0.8
position: { x: -5, y: 3, z: -5 }
- color: "#A855F7" # Purple accent
intensity: 0.6
position: { x: 5, y: 3, z: 5 }
environment:
rain_enabled: false
starfield_enabled: true # Cool blue starfield (Matrix feel)
fog_color: "#0f0f23"
fog_density: 0.02
features:
chat_enabled: true
visitor_avatars: true
pip_familiar: true
workshop_portal: true
agents:
default_count: 5
max_count: 20
agents: []


@@ -27,11 +27,15 @@ from pathlib import Path
REPO_ROOT = Path(__file__).resolve().parent.parent
QUEUE_FILE = REPO_ROOT / ".loop" / "queue.json"
IDLE_STATE_FILE = REPO_ROOT / ".loop" / "idle_state.json"
CYCLE_RESULT_FILE = REPO_ROOT / ".loop" / "cycle_result.json"
TOKEN_FILE = Path.home() / ".hermes" / "gitea_token"
GITEA_API = os.environ.get("GITEA_API", "http://localhost:3000/api/v1")
REPO_SLUG = os.environ.get("REPO_SLUG", "rockachopa/Timmy-time-dashboard")
# Default cycle duration in seconds (5 min); stale threshold = 2× this
CYCLE_DURATION = int(os.environ.get("CYCLE_DURATION", "300"))
# Backoff sequence: 60s, 120s, 240s, 600s max
BACKOFF_BASE = 60
BACKOFF_MAX = 600
@@ -77,6 +81,89 @@ def _fetch_open_issue_numbers() -> set[int] | None:
return None
def _load_cycle_result() -> dict:
"""Read cycle_result.json, handling markdown-fenced JSON."""
if not CYCLE_RESULT_FILE.exists():
return {}
try:
raw = CYCLE_RESULT_FILE.read_text().strip()
if raw.startswith("```"):
lines = raw.splitlines()
lines = [ln for ln in lines if not ln.startswith("```")]
raw = "\n".join(lines)
return json.loads(raw)
except (json.JSONDecodeError, OSError):
return {}
def _is_issue_open(issue_number: int) -> bool | None:
"""Check if a single issue is open. Returns None on API failure."""
token = _get_token()
if not token:
return None
try:
url = f"{GITEA_API}/repos/{REPO_SLUG}/issues/{issue_number}"
req = urllib.request.Request(
url,
headers={
"Authorization": f"token {token}",
"Accept": "application/json",
},
)
with urllib.request.urlopen(req, timeout=10) as resp:
data = json.loads(resp.read())
return data.get("state") == "open"
except Exception:
return None
def validate_cycle_result() -> bool:
"""Pre-cycle validation: remove stale or invalid cycle_result.json.
Checks:
1. Age — if older than 2× CYCLE_DURATION, delete it.
2. Issue — if the referenced issue is closed, delete it.
Returns True if the file was removed, False otherwise.
"""
if not CYCLE_RESULT_FILE.exists():
return False
# Age check
try:
age = time.time() - CYCLE_RESULT_FILE.stat().st_mtime
except OSError:
return False
stale_threshold = CYCLE_DURATION * 2
if age > stale_threshold:
print(
f"[loop-guard] cycle_result.json is {int(age)}s old "
f"(threshold {stale_threshold}s) — removing stale file"
)
CYCLE_RESULT_FILE.unlink(missing_ok=True)
return True
# Issue check
cr = _load_cycle_result()
issue_num = cr.get("issue")
if issue_num is not None:
try:
issue_num = int(issue_num)
except (ValueError, TypeError):
return False
is_open = _is_issue_open(issue_num)
if is_open is False:
print(
f"[loop-guard] cycle_result.json references closed "
f"issue #{issue_num} — removing"
)
CYCLE_RESULT_FILE.unlink(missing_ok=True)
return True
# is_open is None (API failure) or True — keep file
return False
def load_queue() -> list[dict]:
"""Load queue.json and return ready items, filtering out closed issues."""
if not QUEUE_FILE.exists():
@@ -150,6 +237,9 @@ def main() -> int:
}, indent=2))
return 0
# Pre-cycle validation: remove stale cycle_result.json
validate_cycle_result()
ready = load_queue()
if ready:
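The `_load_cycle_result()` helper above tolerates a cycle_result.json that was written wrapped in markdown code fences. A minimal standalone sketch of the same fence-stripping approach (function name is illustrative, not from the repo):

```python
import json

def load_fenced_json(raw: str) -> dict:
    """Parse JSON that may be wrapped in markdown code fences.

    Mirrors _load_cycle_result above: drop any ``` lines,
    then parse what remains; return {} on failure.
    """
    raw = raw.strip()
    if raw.startswith("```"):
        raw = "\n".join(
            ln for ln in raw.splitlines() if not ln.startswith("```")
        )
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        return {}

# Plain and fenced payloads parse to the same dict:
assert load_fenced_json('{"issue": 627}') == {"issue": 627}
assert load_fenced_json('```json\n{"issue": 627}\n```') == {"issue": 627}
```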


@@ -149,6 +149,18 @@ class Settings(BaseSettings):
"http://127.0.0.1:8000",
]
# ── Matrix Frontend Integration ────────────────────────────────────────
# URL of the Matrix frontend (Replit/Tailscale) for CORS.
# When set, this origin is added to CORS allowed_origins.
# Example: "http://100.124.176.28:8080" or "https://alexanderwhitestone.com"
matrix_frontend_url: str = "" # Empty = disabled
# WebSocket authentication token for Matrix connections.
# When set, clients must provide this token via ?token= query param
# or in the first message as {"type": "auth", "token": "..."}.
# Empty/unset = auth disabled (dev mode).
matrix_ws_token: str = ""
# Trusted hosts for the Host header check (TrustedHostMiddleware).
# Set TRUSTED_HOSTS as a comma-separated list. Wildcards supported (e.g. "*.ts.net").
# Defaults include localhost + Tailscale MagicDNS. Add your Tailscale IP if needed.


@@ -10,6 +10,7 @@ Key improvements:
import asyncio
import json
import logging
import re
from contextlib import asynccontextmanager
from pathlib import Path
@@ -23,6 +24,7 @@ from config import settings
# Import dedicated middleware
from dashboard.middleware.csrf import CSRFMiddleware
from dashboard.middleware.rate_limit import RateLimitMiddleware
from dashboard.middleware.request_logging import RequestLoggingMiddleware
from dashboard.middleware.security_headers import SecurityHeadersMiddleware
from dashboard.routes.agents import router as agents_router
@@ -49,6 +51,7 @@ from dashboard.routes.tools import router as tools_router
from dashboard.routes.tower import router as tower_router
from dashboard.routes.voice import router as voice_router
from dashboard.routes.work_orders import router as work_orders_router
from dashboard.routes.world import matrix_router
from dashboard.routes.world import router as world_router
from timmy.workshop_state import PRESENCE_FILE
@@ -519,25 +522,55 @@ app = FastAPI(
def _get_cors_origins() -> list[str]:
"""Get CORS origins from settings, rejecting wildcards in production."""
origins = settings.cors_origins
"""Get CORS origins from settings, rejecting wildcards in production.
Adds matrix_frontend_url when configured. Always allows Tailscale IPs
(100.x.x.x range) for development convenience.
"""
origins = list(settings.cors_origins)
# Strip wildcards in production (security)
if "*" in origins and not settings.debug:
logger.warning(
"Wildcard '*' in CORS_ORIGINS stripped in production — "
"set explicit origins via CORS_ORIGINS env var"
)
origins = [o for o in origins if o != "*"]
# Add Matrix frontend URL if configured
if settings.matrix_frontend_url:
url = settings.matrix_frontend_url.strip()
if url and url not in origins:
origins.append(url)
logger.debug("Added Matrix frontend to CORS: %s", url)
return origins
# Pattern to match Tailscale IPs (100.x.x.x) for CORS origin regex
_TAILSCALE_IP_PATTERN = re.compile(r"^https?://100\.\d{1,3}\.\d{1,3}\.\d{1,3}(?::\d+)?$")
def _is_tailscale_origin(origin: str) -> bool:
"""Check if origin is a Tailscale IP (100.x.x.x range)."""
return bool(_TAILSCALE_IP_PATTERN.match(origin))
# Add dedicated middleware in correct order
# 1. Logging (outermost to capture everything)
app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health"])
# 2. Security Headers
# 2. Rate Limiting (before security to prevent abuse early)
app.add_middleware(
RateLimitMiddleware,
path_prefixes=["/api/matrix/"],
requests_per_minute=30,
)
# 3. Security Headers
app.add_middleware(SecurityHeadersMiddleware, production=not settings.debug)
# 3. CSRF Protection
# 4. CSRF Protection
app.add_middleware(CSRFMiddleware)
# 4. Standard FastAPI middleware
@@ -551,6 +584,7 @@ app.add_middleware(
app.add_middleware(
CORSMiddleware,
allow_origins=_get_cors_origins(),
allow_origin_regex=r"https?://100\.\d{1,3}\.\d{1,3}\.\d{1,3}(:\d+)?",
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["Content-Type", "Authorization"],
@@ -589,6 +623,7 @@ app.include_router(system_router)
app.include_router(experiments_router)
app.include_router(db_explorer_router)
app.include_router(world_router)
app.include_router(matrix_router)
app.include_router(tower_router)
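The `_TAILSCALE_IP_PATTERN` added above (and reused as `allow_origin_regex` in the CORS middleware) matches an http(s) origin whose host is in the 100.x.x.x range, with an optional port. A self-contained sketch of the same check:

```python
import re

# Same shape as _TAILSCALE_IP_PATTERN in the diff: scheme,
# a 100.x.x.x address, optional port, nothing else.
TAILSCALE_ORIGIN = re.compile(
    r"^https?://100\.\d{1,3}\.\d{1,3}\.\d{1,3}(?::\d+)?$"
)

def is_tailscale_origin(origin: str) -> bool:
    return bool(TAILSCALE_ORIGIN.match(origin))

assert is_tailscale_origin("http://100.124.176.28:8080")
assert not is_tailscale_origin("https://evil.example.com")
assert not is_tailscale_origin("http://10.0.0.1")
```

Note the pattern is deliberately loose: `\d{1,3}` admits octets above 255, which is acceptable for a development convenience but not a strict CGNAT (100.64.0.0/10) check.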


@@ -1,6 +1,7 @@
"""Dashboard middleware package."""
from .csrf import CSRFMiddleware, csrf_exempt, generate_csrf_token, validate_csrf_token
from .rate_limit import RateLimiter, RateLimitMiddleware
from .request_logging import RequestLoggingMiddleware
from .security_headers import SecurityHeadersMiddleware
@@ -9,6 +10,8 @@ __all__ = [
"csrf_exempt",
"generate_csrf_token",
"validate_csrf_token",
"RateLimiter",
"RateLimitMiddleware",
"SecurityHeadersMiddleware",
"RequestLoggingMiddleware",
]


@@ -131,7 +131,6 @@ class CSRFMiddleware(BaseHTTPMiddleware):
For safe methods: Set a CSRF token cookie if not present.
For unsafe methods: Validate the CSRF token or check if exempt.
"""
# Bypass CSRF if explicitly disabled (e.g. in tests)
from config import settings
if settings.timmy_disable_csrf:
@@ -141,52 +140,55 @@ class CSRFMiddleware(BaseHTTPMiddleware):
if request.headers.get("upgrade", "").lower() == "websocket":
return await call_next(request)
# Get existing CSRF token from cookie
csrf_cookie = request.cookies.get(self.cookie_name)
# For safe methods, just ensure a token exists
if request.method in self.SAFE_METHODS:
response = await call_next(request)
return await self._handle_safe_method(request, call_next, csrf_cookie)
# Set CSRF token cookie if not present
if not csrf_cookie:
new_token = generate_csrf_token()
response.set_cookie(
key=self.cookie_name,
value=new_token,
httponly=False, # Must be readable by JavaScript
secure=settings.csrf_cookie_secure,
samesite="Lax",
max_age=86400, # 24 hours
)
return await self._handle_unsafe_method(request, call_next, csrf_cookie)
return response
async def _handle_safe_method(
self, request: Request, call_next, csrf_cookie: str | None
) -> Response:
"""Handle safe HTTP methods (GET, HEAD, OPTIONS, TRACE).
# For unsafe methods, we need to validate or check if exempt
# First, try to validate the CSRF token
if await self._validate_request(request, csrf_cookie):
# Token is valid, allow the request
return await call_next(request)
Forwards the request and sets a CSRF token cookie if not present.
"""
from config import settings
# Token validation failed, check if the path is exempt
path = request.url.path
if self._is_likely_exempt(path):
# Path is exempt, allow the request
return await call_next(request)
# Token validation failed and path is not exempt
# We still need to call the app to check if the endpoint is decorated
# with @csrf_exempt, so we'll let it through and check after routing
response = await call_next(request)
# After routing, check if the endpoint is marked as exempt
endpoint = request.scope.get("endpoint")
if endpoint and is_csrf_exempt(endpoint):
# Endpoint is marked as exempt, allow the response
return response
if not csrf_cookie:
new_token = generate_csrf_token()
response.set_cookie(
key=self.cookie_name,
value=new_token,
httponly=False, # Must be readable by JavaScript
secure=settings.csrf_cookie_secure,
samesite="Lax",
max_age=86400, # 24 hours
)
return response
async def _handle_unsafe_method(
self, request: Request, call_next, csrf_cookie: str | None
) -> Response:
"""Handle unsafe HTTP methods (POST, PUT, DELETE, PATCH).
Validates the CSRF token, checks path and endpoint exemptions,
or returns a 403 error.
"""
if await self._validate_request(request, csrf_cookie):
return await call_next(request)
if self._is_likely_exempt(request.url.path):
return await call_next(request)
endpoint = self._resolve_endpoint(request)
if endpoint and is_csrf_exempt(endpoint):
return await call_next(request)
# Endpoint is not exempt and token validation failed
# Return 403 error
return JSONResponse(
status_code=403,
content={
@@ -196,6 +198,41 @@ class CSRFMiddleware(BaseHTTPMiddleware):
},
)
def _resolve_endpoint(self, request: Request) -> Callable | None:
"""Resolve the route endpoint without executing it.
Walks the Starlette/FastAPI router to find which endpoint function
handles this request, so we can check @csrf_exempt before any
side effects occur.
Returns:
The endpoint callable, or None if no route matched.
"""
# If routing already happened (endpoint in scope), use it
endpoint = request.scope.get("endpoint")
if endpoint:
return endpoint
# Walk the middleware/app chain to find something with routes
from starlette.routing import Match
app = self.app
while app is not None:
if hasattr(app, "routes"):
for route in app.routes:
match, _ = route.matches(request.scope)
if match == Match.FULL:
return getattr(route, "endpoint", None)
# Try .router (FastAPI stores routes on app.router)
if hasattr(app, "router") and hasattr(app.router, "routes"):
for route in app.router.routes:
match, _ = route.matches(request.scope)
if match == Match.FULL:
return getattr(route, "endpoint", None)
app = getattr(app, "app", None)
return None
def _is_likely_exempt(self, path: str) -> bool:
"""Check if a path is likely to be CSRF exempt.


@@ -0,0 +1,209 @@
"""Rate limiting middleware for FastAPI.
Simple in-memory rate limiter for API endpoints. Tracks requests per IP
with configurable limits and automatic cleanup of stale entries.
"""
import logging
import time
from collections import deque
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
logger = logging.getLogger(__name__)
class RateLimiter:
"""In-memory rate limiter for tracking requests per IP.
Stores request timestamps in a dict keyed by client IP.
Automatically cleans up stale entries every 60 seconds.
Attributes:
requests_per_minute: Maximum requests allowed per minute per IP.
cleanup_interval_seconds: How often to clean stale entries.
"""
def __init__(
self,
requests_per_minute: int = 30,
cleanup_interval_seconds: int = 60,
):
self.requests_per_minute = requests_per_minute
self.cleanup_interval_seconds = cleanup_interval_seconds
self._storage: dict[str, deque[float]] = {}
self._last_cleanup: float = time.time()
self._window_seconds: float = 60.0 # 1 minute window
def _get_client_ip(self, request: Request) -> str:
"""Extract client IP from request, respecting X-Forwarded-For header.
Args:
request: The incoming request.
Returns:
Client IP address string.
"""
# Check for forwarded IP (when behind proxy/load balancer)
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
# Take the first IP in the chain
return forwarded.split(",")[0].strip()
real_ip = request.headers.get("x-real-ip")
if real_ip:
return real_ip
# Fall back to direct connection
if request.client:
return request.client.host
return "unknown"
def _cleanup_if_needed(self) -> None:
"""Remove stale entries older than the cleanup interval."""
now = time.time()
if now - self._last_cleanup < self.cleanup_interval_seconds:
return
cutoff = now - self._window_seconds
stale_ips: list[str] = []
for ip, timestamps in self._storage.items():
# Remove timestamps older than the window
while timestamps and timestamps[0] < cutoff:
timestamps.popleft()
# Mark IP for removal if no recent requests
if not timestamps:
stale_ips.append(ip)
# Remove stale IP entries
for ip in stale_ips:
del self._storage[ip]
self._last_cleanup = now
if stale_ips:
logger.debug("Rate limiter cleanup: removed %d stale IPs", len(stale_ips))
def is_allowed(self, client_ip: str) -> tuple[bool, float]:
"""Check if a request from the given IP is allowed.
Args:
client_ip: The client's IP address.
Returns:
Tuple of (allowed: bool, retry_after: float).
retry_after is seconds until next allowed request, 0 if allowed now.
"""
now = time.time()
cutoff = now - self._window_seconds
# Get or create timestamp deque for this IP
if client_ip not in self._storage:
self._storage[client_ip] = deque()
timestamps = self._storage[client_ip]
# Remove timestamps outside the window
while timestamps and timestamps[0] < cutoff:
timestamps.popleft()
# Check if limit exceeded
if len(timestamps) >= self.requests_per_minute:
# Calculate retry after time
oldest = timestamps[0]
retry_after = self._window_seconds - (now - oldest)
return False, max(0.0, retry_after)
# Record this request
timestamps.append(now)
return True, 0.0
def check_request(self, request: Request) -> tuple[bool, float]:
"""Check if the request is allowed under rate limits.
Args:
request: The incoming request.
Returns:
Tuple of (allowed: bool, retry_after: float).
"""
self._cleanup_if_needed()
client_ip = self._get_client_ip(request)
return self.is_allowed(client_ip)
class RateLimitMiddleware(BaseHTTPMiddleware):
"""Middleware to apply rate limiting to specific routes.
Usage:
# Apply to all routes (not recommended for public static files)
app.add_middleware(RateLimitMiddleware)
# Apply only to specific paths
app.add_middleware(
RateLimitMiddleware,
path_prefixes=["/api/matrix/"],
requests_per_minute=30,
)
Attributes:
path_prefixes: List of URL path prefixes to rate limit.
If empty, applies to all paths.
requests_per_minute: Maximum requests per minute per IP.
"""
def __init__(
self,
app,
path_prefixes: list[str] | None = None,
requests_per_minute: int = 30,
):
super().__init__(app)
self.path_prefixes = path_prefixes or []
self.limiter = RateLimiter(requests_per_minute=requests_per_minute)
def _should_rate_limit(self, path: str) -> bool:
"""Check if the given path should be rate limited.
Args:
path: The request URL path.
Returns:
True if path matches any configured prefix.
"""
if not self.path_prefixes:
return True
return any(path.startswith(prefix) for prefix in self.path_prefixes)
async def dispatch(self, request: Request, call_next) -> Response:
"""Apply rate limiting to configured paths.
Args:
request: The incoming request.
call_next: Callable to get the response from downstream.
Returns:
Response from downstream, or 429 if rate limited.
"""
# Skip if path doesn't match configured prefixes
if not self._should_rate_limit(request.url.path):
return await call_next(request)
# Check rate limit
allowed, retry_after = self.limiter.check_request(request)
if not allowed:
return JSONResponse(
status_code=429,
content={
"error": "Rate limit exceeded. Try again later.",
"retry_after": int(retry_after) + 1,
},
headers={"Retry-After": str(int(retry_after) + 1)},
)
# Process the request
return await call_next(request)
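The core of `RateLimiter.is_allowed` is a sliding-window check over a deque of timestamps. A minimal sketch of just that window logic (function name is illustrative):

```python
from collections import deque

def sliding_window_allowed(timestamps: deque, now: float,
                           limit: int = 3, window: float = 60.0) -> bool:
    """Evict timestamps older than the window, then admit the
    request only if fewer than `limit` remain — the same shape
    as RateLimiter.is_allowed above."""
    while timestamps and timestamps[0] < now - window:
        timestamps.popleft()
    if len(timestamps) >= limit:
        return False
    timestamps.append(now)
    return True

hits = deque()
t0 = 1000.0
results = [sliding_window_allowed(hits, t0 + i) for i in range(4)]
assert results == [True, True, True, False]  # 4th request in window rejected
assert sliding_window_allowed(hits, t0 + 61)  # old timestamps age out
```

Because rejected requests are not recorded, a client hammering the endpoint does not push its own window forward, so `retry_after` in the middleware stays bounded by the age of the oldest admitted request.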


@@ -116,7 +116,7 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
},
)
except Exception as exc:
logger.debug("Escalation logging error: %s", exc)
logger.warning("Escalation logging error: %s", exc)
pass # never let escalation break the request
# Re-raise the exception


@@ -75,6 +75,7 @@ def _query_database(db_path: str) -> dict:
"truncated": count > MAX_ROWS,
}
except Exception as exc:
logger.exception("Failed to query table %s", table_name)
result["tables"][table_name] = {
"error": str(exc),
"columns": [],
@@ -83,6 +84,7 @@ def _query_database(db_path: str) -> dict:
"truncated": False,
}
except Exception as exc:
logger.exception("Failed to query database %s", db_path)
result["error"] = str(exc)
return result


@@ -135,6 +135,7 @@ def _run_grok_query(message: str) -> dict:
result = backend.run(message)
return {"response": f"**[Grok]{invoice_note}:** {result.content}", "error": None}
except Exception as exc:
logger.exception("Grok query failed")
return {"response": None, "error": f"Grok error: {exc}"}
@@ -193,6 +194,7 @@ async def grok_stats():
"model": settings.grok_default_model,
}
except Exception as exc:
logger.exception("Failed to load Grok stats")
return {"error": str(exc)}


@@ -148,6 +148,7 @@ def _check_sqlite() -> DependencyStatus:
details={"path": str(db_path)},
)
except Exception as exc:
logger.exception("SQLite health check failed")
return DependencyStatus(
name="SQLite Database",
status="unavailable",


@@ -59,6 +59,7 @@ async def tts_speak(text: str = Form(...)):
voice_tts.speak(text)
return {"spoken": True, "text": text}
except Exception as exc:
logger.exception("TTS speak failed")
return {"spoken": False, "reason": str(exc)}


@@ -17,16 +17,221 @@ or missing.
import asyncio
import json
import logging
import math
import re
import time
from collections import deque
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from fastapi import APIRouter, WebSocket
import yaml
from fastapi import APIRouter, Request, WebSocket
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from config import settings
from infrastructure.presence import produce_bark, serialize_presence
from timmy.memory_system import search_memories
from timmy.workshop_state import PRESENCE_FILE
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/world", tags=["world"])
matrix_router = APIRouter(prefix="/api/matrix", tags=["matrix"])
# ---------------------------------------------------------------------------
# Matrix Bark Endpoint — HTTP fallback for bark messages
# ---------------------------------------------------------------------------
# Rate limiting: 1 request per 3 seconds per visitor_id
_BARK_RATE_LIMIT_SECONDS = 3
_bark_last_request: dict[str, float] = {}
class BarkRequest(BaseModel):
"""Request body for POST /api/matrix/bark."""
text: str
visitor_id: str
@matrix_router.post("/bark")
async def post_matrix_bark(request: BarkRequest) -> JSONResponse:
"""Generate a bark response for a visitor message.
HTTP fallback for when WebSocket isn't available. The Matrix frontend
can POST a message and get Timmy's bark response back as JSON.
Rate limited to 1 request per 3 seconds per visitor_id.
Request body:
- text: The visitor's message text
- visitor_id: Unique identifier for the visitor (used for rate limiting)
Returns:
- 200: Bark message in produce_bark() format
- 429: Rate limit exceeded (try again later)
- 422: Invalid request (missing/invalid fields)
"""
# Validate inputs
text = request.text.strip() if request.text else ""
visitor_id = request.visitor_id.strip() if request.visitor_id else ""
if not text:
return JSONResponse(
status_code=422,
content={"error": "text is required"},
)
if not visitor_id:
return JSONResponse(
status_code=422,
content={"error": "visitor_id is required"},
)
# Rate limiting check
now = time.time()
last_request = _bark_last_request.get(visitor_id, 0)
time_since_last = now - last_request
if time_since_last < _BARK_RATE_LIMIT_SECONDS:
retry_after = _BARK_RATE_LIMIT_SECONDS - time_since_last
return JSONResponse(
status_code=429,
content={"error": "Rate limit exceeded. Try again later."},
headers={"Retry-After": str(int(retry_after) + 1)},
)
# Record this request
_bark_last_request[visitor_id] = now
# Generate bark response
try:
reply = await _generate_bark(text)
except Exception as exc:
logger.warning("Bark generation failed: %s", exc)
reply = "Hmm, my thoughts are a bit tangled right now."
# Build bark response using produce_bark format
bark = produce_bark(agent_id="timmy", text=reply, style="speech")
return JSONResponse(
content=bark,
headers={"Cache-Control": "no-cache, no-store"},
)
# ---------------------------------------------------------------------------
# Matrix Agent Registry — serves agents to the Matrix visualization
# ---------------------------------------------------------------------------
# Agent color mapping — consistent with Matrix visual identity
_AGENT_COLORS: dict[str, str] = {
"timmy": "#FFD700", # Gold
"orchestrator": "#FFD700", # Gold
"perplexity": "#3B82F6", # Blue
"replit": "#F97316", # Orange
"kimi": "#06B6D4", # Cyan
"claude": "#A855F7", # Purple
"researcher": "#10B981", # Emerald
"coder": "#EF4444", # Red
"writer": "#EC4899", # Pink
"memory": "#8B5CF6", # Violet
"experimenter": "#14B8A6", # Teal
"forge": "#EF4444", # Red (coder alias)
"seer": "#10B981", # Emerald (researcher alias)
"quill": "#EC4899", # Pink (writer alias)
"echo": "#8B5CF6", # Violet (memory alias)
"lab": "#14B8A6", # Teal (experimenter alias)
}
# Agent shape mapping for 3D visualization
_AGENT_SHAPES: dict[str, str] = {
"timmy": "sphere",
"orchestrator": "sphere",
"perplexity": "cube",
"replit": "cylinder",
"kimi": "dodecahedron",
"claude": "octahedron",
"researcher": "icosahedron",
"coder": "cube",
"writer": "cone",
"memory": "torus",
"experimenter": "tetrahedron",
"forge": "cube",
"seer": "icosahedron",
"quill": "cone",
"echo": "torus",
"lab": "tetrahedron",
}
# Default fallback values
_DEFAULT_COLOR = "#9CA3AF" # Gray
_DEFAULT_SHAPE = "sphere"
_DEFAULT_STATUS = "available"
def _get_agent_color(agent_id: str) -> str:
"""Get the Matrix color for an agent."""
return _AGENT_COLORS.get(agent_id.lower(), _DEFAULT_COLOR)
def _get_agent_shape(agent_id: str) -> str:
"""Get the Matrix shape for an agent."""
return _AGENT_SHAPES.get(agent_id.lower(), _DEFAULT_SHAPE)
def _compute_circular_positions(count: int, radius: float = 3.0) -> list[dict[str, float]]:
"""Compute circular positions for agents in the Matrix.
Agents are arranged in a circle on the XZ plane at y=0.
"""
positions = []
for i in range(count):
angle = (2 * math.pi * i) / count
x = radius * math.cos(angle)
z = radius * math.sin(angle)
positions.append({"x": round(x, 2), "y": 0.0, "z": round(z, 2)})
return positions
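`_compute_circular_positions` spaces agents evenly on a radius-3 circle in the XZ plane. A self-contained copy of the layout, with a check that every point lands on the circle (within the 2-decimal rounding):

```python
import math

def circular_positions(count: int, radius: float = 3.0) -> list[dict[str, float]]:
    """Evenly space `count` points on a circle in the XZ plane at y=0,
    matching _compute_circular_positions above."""
    positions = []
    for i in range(count):
        angle = (2 * math.pi * i) / count
        positions.append({
            "x": round(radius * math.cos(angle), 2),
            "y": 0.0,
            "z": round(radius * math.sin(angle), 2),
        })
    return positions

pts = circular_positions(4)
assert pts[0] == {"x": 3.0, "y": 0.0, "z": 0.0}
# every agent sits (within rounding) on the radius-3 circle:
assert all(abs(math.hypot(p["x"], p["z"]) - 3.0) < 0.02 for p in pts)
```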
def _build_matrix_agents_response() -> list[dict[str, Any]]:
"""Build the Matrix agent registry response.
Reads from agents.yaml and returns agents with Matrix-compatible
formatting including colors, shapes, and positions.
"""
try:
from timmy.agents.loader import list_agents
agents = list_agents()
if not agents:
return []
positions = _compute_circular_positions(len(agents))
result = []
for i, agent in enumerate(agents):
agent_id = agent.get("id", "")
result.append(
{
"id": agent_id,
"display_name": agent.get("name", agent_id.title()),
"role": agent.get("role", "general"),
"color": _get_agent_color(agent_id),
"position": positions[i],
"shape": _get_agent_shape(agent_id),
"status": agent.get("status", _DEFAULT_STATUS),
}
)
return result
except Exception as exc:
logger.warning("Failed to load agents for Matrix: %s", exc)
return []
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/world", tags=["world"])
@@ -149,21 +354,7 @@ def _read_presence_file() -> dict | None:
def _build_world_state(presence: dict) -> dict:
"""Transform presence dict into the world/state API response."""
return serialize_presence(presence)
def _get_current_state() -> dict:
@@ -224,6 +415,50 @@ async def _heartbeat(websocket: WebSocket) -> None:
logger.debug("Heartbeat stopped — connection gone")
async def _authenticate_ws(websocket: WebSocket) -> bool:
"""Authenticate WebSocket connection using matrix_ws_token.
Checks for token in query param ?token= first. If no query param,
accepts the connection and waits for first message with
{"type": "auth", "token": "..."}.
Returns True if authenticated (or if auth is disabled).
Returns False and closes connection with code 4001 if invalid.
"""
token_setting = settings.matrix_ws_token
# Auth disabled in dev mode (empty/unset token)
if not token_setting:
return True
# Check query param first (can validate before accept)
query_token = websocket.query_params.get("token", "")
if query_token:
if query_token == token_setting:
return True
# Invalid token in query param - we need to accept to close properly
await websocket.accept()
await websocket.close(code=4001, reason="Invalid token")
return False
# No query token - accept and wait for auth message
await websocket.accept()
# Wait for auth message as first message
try:
raw = await websocket.receive_text()
data = json.loads(raw)
if data.get("type") == "auth" and data.get("token") == token_setting:
return True
# Invalid auth message
await websocket.close(code=4001, reason="Invalid token")
return False
except (json.JSONDecodeError, TypeError):
# Non-JSON first message without valid token
await websocket.close(code=4001, reason="Authentication required")
return False
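From the client side, the contract above allows two handshakes. A sketch of both payloads (the host, port, and token value are hypothetical, not taken from the repo):

```python
import json
from urllib.parse import urlencode

TOKEN = "dev-matrix-token"  # hypothetical value of matrix_ws_token

# Option 1: token as a query parameter, validated before accept().
ws_url = "ws://localhost:8000/api/world/ws?" + urlencode({"token": TOKEN})

# Option 2: connect without a token, then send an auth frame as the
# very first message after the server accepts.
auth_frame = json.dumps({"type": "auth", "token": TOKEN})
```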
@router.websocket("/ws")
async def world_ws(websocket: WebSocket) -> None:
"""Accept a Workshop client and keep it alive for state broadcasts.
@@ -232,8 +467,28 @@ async def world_ws(websocket: WebSocket) -> None:
client never starts from a blank slate. Incoming frames are parsed
as JSON — ``visitor_message`` triggers a bark response. A background
heartbeat ping runs every 15 s to detect dead connections early.
Authentication:
- If matrix_ws_token is configured, clients must provide it via
?token= query param or in the first message as
{"type": "auth", "token": "..."}.
- Invalid token results in close code 4001.
- Valid token receives a connection_ack message.
"""
# Authenticate before accepting. _authenticate_ws accepts the socket
# itself whenever it needs to exchange frames (waiting for an auth
# message or sending a 4001 close); accepting here first would
# double-accept and defeat pre-accept rejection of bad query tokens.
is_authed = await _authenticate_ws(websocket)
if not is_authed:
logger.info("World WS connection rejected — invalid token")
return
# Auth passed - accept if not already accepted
if websocket.client_state.name != "CONNECTED":
await websocket.accept()
# Send connection_ack if auth was required
if settings.matrix_ws_token:
await websocket.send_text(json.dumps({"type": "connection_ack"}))
_ws_clients.append(websocket)
logger.info("World WS connected — %d clients", len(_ws_clients))
@@ -383,3 +638,428 @@ async def _generate_bark(visitor_text: str) -> str:
except Exception as exc:
logger.warning("Bark generation failed: %s", exc)
return "Hmm, my thoughts are a bit tangled right now."
# ---------------------------------------------------------------------------
# Matrix Configuration Endpoint
# ---------------------------------------------------------------------------
# Default Matrix configuration (fallback when matrix.yaml is missing/corrupt)
_DEFAULT_MATRIX_CONFIG: dict[str, Any] = {
"lighting": {
"ambient_color": "#1a1a2e",
"ambient_intensity": 0.4,
"point_lights": [
{"color": "#FFD700", "intensity": 1.2, "position": {"x": 0, "y": 5, "z": 0}},
{"color": "#3B82F6", "intensity": 0.8, "position": {"x": -5, "y": 3, "z": -5}},
{"color": "#A855F7", "intensity": 0.6, "position": {"x": 5, "y": 3, "z": 5}},
],
},
"environment": {
"rain_enabled": False,
"starfield_enabled": True,
"fog_color": "#0f0f23",
"fog_density": 0.02,
},
"features": {
"chat_enabled": True,
"visitor_avatars": True,
"pip_familiar": True,
"workshop_portal": True,
},
}
def _load_matrix_config() -> dict[str, Any]:
"""Load Matrix world configuration from matrix.yaml with fallback to defaults.
Returns a dict with sections: lighting, environment, features.
If the config file is missing or invalid, returns sensible defaults.
"""
try:
config_path = Path(settings.repo_root) / "config" / "matrix.yaml"
if not config_path.exists():
logger.debug("matrix.yaml not found, using default config")
return _DEFAULT_MATRIX_CONFIG.copy()
raw = config_path.read_text()
config = yaml.safe_load(raw)
if not isinstance(config, dict):
logger.warning("matrix.yaml invalid format, using defaults")
return _DEFAULT_MATRIX_CONFIG.copy()
# Merge with defaults to ensure all required fields exist
result: dict[str, Any] = {
"lighting": {
**_DEFAULT_MATRIX_CONFIG["lighting"],
**config.get("lighting", {}),
},
"environment": {
**_DEFAULT_MATRIX_CONFIG["environment"],
**config.get("environment", {}),
},
"features": {
**_DEFAULT_MATRIX_CONFIG["features"],
**config.get("features", {}),
},
}
# Ensure point_lights is a list
if "point_lights" in config.get("lighting", {}):
result["lighting"]["point_lights"] = config["lighting"]["point_lights"]
else:
result["lighting"]["point_lights"] = _DEFAULT_MATRIX_CONFIG["lighting"]["point_lights"]
return result
except Exception as exc:
logger.warning("Failed to load matrix config: %s, using defaults", exc)
return _DEFAULT_MATRIX_CONFIG.copy()
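The merge above is shallow: top-level keys from each matrix.yaml section override the defaults, but nested values are replaced wholesale, which is why point_lights gets special-cased. Illustrated with plain dicts:

```python
defaults = {"ambient_color": "#1a1a2e", "ambient_intensity": 0.4}
user_section = {"ambient_intensity": 0.9}

# Later keys win: unspecified keys keep their defaults.
merged = {**defaults, **user_section}
```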
@matrix_router.get("/config")
async def get_matrix_config() -> JSONResponse:
"""Return Matrix world configuration.
Serves lighting presets, environment settings, and feature flags
to the Matrix frontend so it can be config-driven rather than
hardcoded. Reads from config/matrix.yaml with sensible defaults.
Response structure:
- lighting: ambient_color, ambient_intensity, point_lights[]
- environment: rain_enabled, starfield_enabled, fog_color, fog_density
- features: chat_enabled, visitor_avatars, pip_familiar, workshop_portal
"""
config = _load_matrix_config()
return JSONResponse(
content=config,
headers={"Cache-Control": "no-cache, no-store"},
)
# ---------------------------------------------------------------------------
# Matrix Agent Registry Endpoint
# ---------------------------------------------------------------------------
@matrix_router.get("/agents")
async def get_matrix_agents() -> JSONResponse:
"""Return the agent registry for Matrix visualization.
Serves agents from agents.yaml with Matrix-compatible formatting:
- id: agent identifier
- display_name: human-readable name
- role: functional role
- color: hex color code for visualization
- position: {x, y, z} coordinates in 3D space
- shape: 3D shape type
- status: availability status
Agents are arranged in a circular layout by default.
Returns 200 with empty list if no agents configured.
"""
agents = _build_matrix_agents_response()
return JSONResponse(
content=agents,
headers={"Cache-Control": "no-cache, no-store"},
)
# ---------------------------------------------------------------------------
# Matrix Thoughts Endpoint — Timmy's recent thought stream for Matrix display
# ---------------------------------------------------------------------------
_MAX_THOUGHT_LIMIT = 50 # Maximum thoughts allowed per request
_DEFAULT_THOUGHT_LIMIT = 10 # Default number of thoughts to return
_MAX_THOUGHT_TEXT_LEN = 500 # Max characters for thought text
def _build_matrix_thoughts_response(limit: int = _DEFAULT_THOUGHT_LIMIT) -> list[dict[str, Any]]:
"""Build the Matrix thoughts response from the thinking engine.
Returns recent thoughts formatted for Matrix display:
- id: thought UUID
- text: thought content (truncated to 500 chars)
- created_at: ISO-8601 timestamp
- chain_id: parent thought ID (or null if root thought)
Returns empty list if thinking engine is disabled or fails.
"""
try:
from timmy.thinking import thinking_engine
thoughts = thinking_engine.get_recent_thoughts(limit=limit)
return [
{
"id": t.id,
"text": t.content[:_MAX_THOUGHT_TEXT_LEN],
"created_at": t.created_at,
"chain_id": t.parent_id,
}
for t in thoughts
]
except Exception as exc:
logger.warning("Failed to load thoughts for Matrix: %s", exc)
return []
@matrix_router.get("/thoughts")
async def get_matrix_thoughts(limit: int = _DEFAULT_THOUGHT_LIMIT) -> JSONResponse:
"""Return Timmy's recent thoughts formatted for Matrix display.
This is the REST companion to the thought WebSocket messages,
allowing the Matrix frontend to display what Timmy is actually
thinking about rather than canned contextual lines.
Query params:
- limit: Number of thoughts to return (default 10, max 50)
Response: JSON array of thought objects:
- id: thought UUID
- text: thought content (truncated to 500 chars)
- created_at: ISO-8601 timestamp
- chain_id: parent thought ID (null if root thought)
Returns empty array if thinking engine is disabled or fails.
"""
# Clamp limit to valid range
if limit < 1:
limit = 1
elif limit > _MAX_THOUGHT_LIMIT:
limit = _MAX_THOUGHT_LIMIT
thoughts = _build_matrix_thoughts_response(limit=limit)
return JSONResponse(
content=thoughts,
headers={"Cache-Control": "no-cache, no-store"},
)
# ---------------------------------------------------------------------------
# Matrix Health Endpoint — backend capability discovery
# ---------------------------------------------------------------------------
# Health check cache (5-second TTL for capability checks)
_health_cache: dict | None = None
_health_cache_ts: float = 0.0
_HEALTH_CACHE_TTL = 5.0
def _check_capability_thinking() -> bool:
"""Check if thinking engine is available."""
try:
from timmy.thinking import thinking_engine
# Check if the engine has been initialized (has a db path)
return hasattr(thinking_engine, "_db") and thinking_engine._db is not None
except Exception:
return False
def _check_capability_memory() -> bool:
"""Check if memory system is available."""
try:
from timmy.memory_system import HOT_MEMORY_PATH
return HOT_MEMORY_PATH.exists()
except Exception:
return False
def _check_capability_bark() -> bool:
"""Check if bark production is available."""
try:
from infrastructure.presence import produce_bark
return callable(produce_bark)
except Exception:
return False
def _check_capability_familiar() -> bool:
"""Check if familiar (Pip) is available."""
try:
from timmy.familiar import pip_familiar
return pip_familiar is not None
except Exception:
return False
def _check_capability_lightning() -> bool:
"""Check if Lightning payments are available."""
# Lightning is currently disabled per health.py
# Returns False until properly re-implemented
return False
def _build_matrix_health_response() -> dict[str, Any]:
"""Build the Matrix health response with capability checks.
Performs lightweight checks (<100ms total) to determine which features
are available. Returns 200 even if some capabilities are degraded.
"""
capabilities = {
"thinking": _check_capability_thinking(),
"memory": _check_capability_memory(),
"bark": _check_capability_bark(),
"familiar": _check_capability_familiar(),
"lightning": _check_capability_lightning(),
}
# Status is ok if core capabilities (thinking, memory, bark) are available
core_caps = ["thinking", "memory", "bark"]
core_available = all(capabilities[c] for c in core_caps)
status = "ok" if core_available else "degraded"
return {
"status": status,
"version": "1.0.0",
"capabilities": capabilities,
}
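The status derivation is deliberately simple: any missing core capability flips the response to "degraded", while optional capabilities (familiar, lightning) never do. Distilled:

```python
capabilities = {"thinking": True, "memory": True, "bark": False,
                "familiar": True, "lightning": False}
core_caps = ["thinking", "memory", "bark"]

# Optional capabilities are reported but never affect the status.
status = "ok" if all(capabilities[c] for c in core_caps) else "degraded"
status_ok = "ok" if all({**capabilities, "bark": True}[c] for c in core_caps) else "degraded"
```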
@matrix_router.get("/health")
async def get_matrix_health() -> JSONResponse:
"""Return health status and capability availability for Matrix frontend.
This endpoint allows the Matrix frontend to discover what backend
capabilities are available so it can show/hide UI elements:
- thinking: Show thought bubbles if enabled
- memory: Show crystal ball memory search if available
- bark: Enable visitor chat responses
- familiar: Show Pip the familiar
- lightning: Enable payment features
Response time is <100ms (no heavy checks). Returns 200 even if
some capabilities are degraded.
Response:
- status: "ok" or "degraded"
- version: API version string
- capabilities: dict of feature:bool
"""
response = _build_matrix_health_response()
status_code = 200 # Always 200, even if degraded
return JSONResponse(
content=response,
status_code=status_code,
headers={"Cache-Control": "no-cache, no-store"},
)
# ---------------------------------------------------------------------------
# Matrix Memory Search Endpoint — visitors query Timmy's memory
# ---------------------------------------------------------------------------
# Rate limiting: 1 search per 5 seconds per IP
_MEMORY_SEARCH_RATE_LIMIT_SECONDS = 5
_memory_search_last_request: dict[str, float] = {}
_MAX_MEMORY_RESULTS = 5
_MAX_MEMORY_TEXT_LENGTH = 200
def _get_client_ip(request) -> str:
"""Extract client IP from request, respecting X-Forwarded-For header."""
# Check for forwarded IP (when behind proxy)
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
# Take the first IP in the chain
return forwarded.split(",")[0].strip()
# Fall back to direct client IP
if request.client:
return request.client.host
return "unknown"
def _build_matrix_memory_response(
memories: list,
) -> list[dict[str, Any]]:
"""Build the Matrix memory search response.
Formats memory entries for Matrix display:
- text: truncated to 200 characters
- relevance: 0-1 score from relevance_score
- created_at: ISO-8601 timestamp
- context_type: the memory type
Results are capped at _MAX_MEMORY_RESULTS.
"""
results = []
for mem in memories[:_MAX_MEMORY_RESULTS]:
text = mem.content
if len(text) > _MAX_MEMORY_TEXT_LENGTH:
text = text[:_MAX_MEMORY_TEXT_LENGTH] + "..."
results.append(
{
"text": text,
"relevance": round(mem.relevance_score or 0.0, 4),
"created_at": mem.timestamp,
"context_type": mem.context_type,
}
)
return results
@matrix_router.get("/memory/search")
async def get_matrix_memory_search(request: Request, q: str | None = None) -> JSONResponse:
"""Search Timmy's memory for relevant snippets.
Allows Matrix visitors to query Timmy's memory ("what do you remember
about sovereignty?"). Results appear as floating crystal-ball text
in the Workshop room.
Query params:
- q: Search query text (required)
Response: JSON array of memory objects:
- text: Memory content (truncated to 200 chars)
- relevance: Similarity score 0-1
- created_at: ISO-8601 timestamp
- context_type: Memory type (conversation, fact, etc.)
Rate limited to 1 search per 5 seconds per IP.
Returns:
- 200: JSON array of memory results (max 5)
- 400: Missing or empty query parameter
- 429: Rate limit exceeded
"""
# Validate query parameter
query = q.strip() if q else ""
if not query:
return JSONResponse(
status_code=400,
content={"error": "Query parameter 'q' is required"},
)
# Rate limiting check by IP
client_ip = _get_client_ip(request)
now = time.time()
last_request = _memory_search_last_request.get(client_ip, 0)
time_since_last = now - last_request
if time_since_last < _MEMORY_SEARCH_RATE_LIMIT_SECONDS:
retry_after = _MEMORY_SEARCH_RATE_LIMIT_SECONDS - time_since_last
return JSONResponse(
status_code=429,
content={"error": "Rate limit exceeded. Try again later."},
headers={"Retry-After": str(int(retry_after) + 1)},
)
# Record this request
_memory_search_last_request[client_ip] = now
# Search memories
try:
memories = search_memories(query, limit=_MAX_MEMORY_RESULTS)
results = _build_matrix_memory_response(memories)
except Exception as exc:
logger.warning("Memory search failed: %s", exc)
results = []
return JSONResponse(
content=results,
headers={"Cache-Control": "no-cache, no-store"},
)


@@ -0,0 +1,266 @@
"""Matrix configuration loader utility.
Provides a typed dataclass for Matrix world configuration and a loader
that fetches settings from YAML with sensible defaults.
"""
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import yaml
logger = logging.getLogger(__name__)
@dataclass
class PointLight:
"""A single point light in the Matrix world."""
color: str = "#FFFFFF"
intensity: float = 1.0
position: dict[str, float] = field(default_factory=lambda: {"x": 0, "y": 0, "z": 0})
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "PointLight":
"""Create a PointLight from a dictionary with defaults."""
return cls(
color=data.get("color", "#FFFFFF"),
intensity=data.get("intensity", 1.0),
position=data.get("position", {"x": 0, "y": 0, "z": 0}),
)
def _default_point_lights_factory() -> list[PointLight]:
"""Factory function for default point lights."""
return [
PointLight(
color="#FFAA55", # Warm amber (Workshop)
intensity=1.2,
position={"x": 0, "y": 5, "z": 0},
),
PointLight(
color="#3B82F6", # Cool blue (Matrix)
intensity=0.8,
position={"x": -5, "y": 3, "z": -5},
),
PointLight(
color="#A855F7", # Purple accent
intensity=0.6,
position={"x": 5, "y": 3, "z": 5},
),
]
@dataclass
class LightingConfig:
"""Lighting configuration for the Matrix world."""
ambient_color: str = "#FFAA55" # Warm amber (Workshop warmth)
ambient_intensity: float = 0.5
point_lights: list[PointLight] = field(default_factory=_default_point_lights_factory)
@classmethod
def from_dict(cls, data: dict[str, Any] | None) -> "LightingConfig":
"""Create a LightingConfig from a dictionary with defaults."""
if data is None:
data = {}
point_lights_data = data.get("point_lights", [])
point_lights = (
[PointLight.from_dict(pl) for pl in point_lights_data]
if point_lights_data
else _default_point_lights_factory()
)
return cls(
ambient_color=data.get("ambient_color", "#FFAA55"),
ambient_intensity=data.get("ambient_intensity", 0.5),
point_lights=point_lights,
)
@dataclass
class EnvironmentConfig:
"""Environment settings for the Matrix world."""
rain_enabled: bool = False
starfield_enabled: bool = True
fog_color: str = "#0f0f23"
fog_density: float = 0.02
@classmethod
def from_dict(cls, data: dict[str, Any] | None) -> "EnvironmentConfig":
"""Create an EnvironmentConfig from a dictionary with defaults."""
if data is None:
data = {}
return cls(
rain_enabled=data.get("rain_enabled", False),
starfield_enabled=data.get("starfield_enabled", True),
fog_color=data.get("fog_color", "#0f0f23"),
fog_density=data.get("fog_density", 0.02),
)
@dataclass
class FeaturesConfig:
"""Feature toggles for the Matrix world."""
chat_enabled: bool = True
visitor_avatars: bool = True
pip_familiar: bool = True
workshop_portal: bool = True
@classmethod
def from_dict(cls, data: dict[str, Any] | None) -> "FeaturesConfig":
"""Create a FeaturesConfig from a dictionary with defaults."""
if data is None:
data = {}
return cls(
chat_enabled=data.get("chat_enabled", True),
visitor_avatars=data.get("visitor_avatars", True),
pip_familiar=data.get("pip_familiar", True),
workshop_portal=data.get("workshop_portal", True),
)
@dataclass
class AgentConfig:
"""Configuration for a single Matrix agent."""
name: str = ""
role: str = ""
enabled: bool = True
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "AgentConfig":
"""Create an AgentConfig from a dictionary with defaults."""
return cls(
name=data.get("name", ""),
role=data.get("role", ""),
enabled=data.get("enabled", True),
)
@dataclass
class AgentsConfig:
"""Agent registry configuration."""
default_count: int = 5
max_count: int = 20
agents: list[AgentConfig] = field(default_factory=list)
@classmethod
def from_dict(cls, data: dict[str, Any] | None) -> "AgentsConfig":
"""Create an AgentsConfig from a dictionary with defaults."""
if data is None:
data = {}
agents_data = data.get("agents", [])
agents = [AgentConfig.from_dict(a) for a in agents_data] if agents_data else []
return cls(
default_count=data.get("default_count", 5),
max_count=data.get("max_count", 20),
agents=agents,
)
@dataclass
class MatrixConfig:
"""Complete Matrix world configuration.
Combines lighting, environment, features, and agent settings
into a single configuration object.
"""
lighting: LightingConfig = field(default_factory=LightingConfig)
environment: EnvironmentConfig = field(default_factory=EnvironmentConfig)
features: FeaturesConfig = field(default_factory=FeaturesConfig)
agents: AgentsConfig = field(default_factory=AgentsConfig)
@classmethod
def from_dict(cls, data: dict[str, Any] | None) -> "MatrixConfig":
"""Create a MatrixConfig from a dictionary with defaults for missing sections."""
if data is None:
data = {}
return cls(
lighting=LightingConfig.from_dict(data.get("lighting")),
environment=EnvironmentConfig.from_dict(data.get("environment")),
features=FeaturesConfig.from_dict(data.get("features")),
agents=AgentsConfig.from_dict(data.get("agents")),
)
def to_dict(self) -> dict[str, Any]:
"""Convert the configuration to a plain dictionary."""
return {
"lighting": {
"ambient_color": self.lighting.ambient_color,
"ambient_intensity": self.lighting.ambient_intensity,
"point_lights": [
{
"color": pl.color,
"intensity": pl.intensity,
"position": pl.position,
}
for pl in self.lighting.point_lights
],
},
"environment": {
"rain_enabled": self.environment.rain_enabled,
"starfield_enabled": self.environment.starfield_enabled,
"fog_color": self.environment.fog_color,
"fog_density": self.environment.fog_density,
},
"features": {
"chat_enabled": self.features.chat_enabled,
"visitor_avatars": self.features.visitor_avatars,
"pip_familiar": self.features.pip_familiar,
"workshop_portal": self.features.workshop_portal,
},
"agents": {
"default_count": self.agents.default_count,
"max_count": self.agents.max_count,
"agents": [
{"name": a.name, "role": a.role, "enabled": a.enabled}
for a in self.agents.agents
],
},
}
def load_from_yaml(path: str | Path) -> MatrixConfig:
"""Load Matrix configuration from a YAML file.
Missing keys are filled with sensible defaults. If the file
cannot be read or parsed, returns a fully default configuration.
Args:
path: Path to the YAML configuration file.
Returns:
A MatrixConfig instance with loaded or default values.
"""
path = Path(path)
if not path.exists():
logger.warning("Matrix config file not found: %s, using defaults", path)
return MatrixConfig()
try:
with open(path, encoding="utf-8") as f:
raw_data = yaml.safe_load(f)
if not isinstance(raw_data, dict):
logger.warning("Matrix config invalid format, using defaults")
return MatrixConfig()
return MatrixConfig.from_dict(raw_data)
except yaml.YAMLError as exc:
logger.warning("Matrix config YAML parse error: %s, using defaults", exc)
return MatrixConfig()
except OSError as exc:
logger.warning("Matrix config read error: %s, using defaults", exc)
return MatrixConfig()


@@ -0,0 +1,333 @@
"""Presence state serializer — transforms ADR-023 presence dicts for consumers.
Converts the raw presence schema (version, liveness, mood, energy, etc.)
into the camelCase world-state payload consumed by the Workshop 3D renderer
and WebSocket gateway.
"""
import logging
import time
from datetime import UTC, datetime
logger = logging.getLogger(__name__)
# Default Pip familiar state (used when familiar module unavailable)
DEFAULT_PIP_STATE = {
"name": "Pip",
"mood": "sleepy",
"energy": 0.5,
"color": "0x00b450", # emerald green
"trail_color": "0xdaa520", # gold
}
def _get_familiar_state() -> dict:
"""Get Pip familiar state from familiar module, with graceful fallback.
Returns a dict with name, mood, energy, color, and trail_color.
Falls back to default state if familiar module unavailable or raises.
"""
try:
from timmy.familiar import pip_familiar
snapshot = pip_familiar.snapshot()
# Map PipSnapshot fields to the expected agent_state format
return {
"name": snapshot.name,
"mood": snapshot.state,
"energy": DEFAULT_PIP_STATE["energy"], # Pip doesn't track energy yet
"color": DEFAULT_PIP_STATE["color"],
"trail_color": DEFAULT_PIP_STATE["trail_color"],
}
except Exception as exc:
logger.warning("Familiar state unavailable, using default: %s", exc)
return DEFAULT_PIP_STATE.copy()
# Valid bark styles for Matrix protocol
BARK_STYLES = {"speech", "thought", "whisper", "shout"}
def produce_bark(agent_id: str, text: str, reply_to: str | None = None, style: str = "speech") -> dict:
"""Format a chat response as a Matrix bark message.
Barks appear as floating text above agents in the Matrix 3D world with
typing animation. This function formats the text for the Matrix protocol.
Parameters
----------
agent_id:
Unique identifier for the agent (e.g. ``"timmy"``).
text:
The chat response text to display as a bark.
reply_to:
Optional message ID or reference this bark is replying to.
style:
Visual style of the bark. One of: "speech" (default), "thought",
"whisper", "shout". Invalid styles fall back to "speech".
Returns
-------
dict
Bark message with keys ``type``, ``agent_id``, ``data`` (containing
``text``, ``reply_to``, ``style``), and ``ts``.
Examples
--------
>>> produce_bark("timmy", "Hello world!")
{
"type": "bark",
"agent_id": "timmy",
"data": {"text": "Hello world!", "reply_to": None, "style": "speech"},
"ts": 1742529600,
}
"""
# Validate and normalize style
if style not in BARK_STYLES:
style = "speech"
# Truncate text to 280 characters (bark, not essay)
truncated_text = text[:280] if text else ""
return {
"type": "bark",
"agent_id": agent_id,
"data": {
"text": truncated_text,
"reply_to": reply_to,
"style": style,
},
"ts": int(time.time()),
}
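Condensed for demonstration, the bark envelope's data section behaves like this (the body mirrors the validation and truncation above, minus the timestamp):

```python
BARK_STYLES = {"speech", "thought", "whisper", "shout"}

def bark_payload(text: str, style: str = "speech") -> dict:
    if style not in BARK_STYLES:
        style = "speech"  # invalid styles fall back to the default
    return {"text": (text or "")[:280], "style": style}

payload = bark_payload("x" * 300, style="yell")
```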
def produce_thought(
agent_id: str, thought_text: str, thought_id: int, chain_id: str | None = None
) -> dict:
"""Format a thinking engine thought as a Matrix thought message.
Thoughts appear as subtle floating text in the 3D world, streaming from
Timmy's thinking engine (/thinking/api). This function wraps thoughts in
Matrix protocol format.
Parameters
----------
agent_id:
Unique identifier for the agent (e.g. ``"timmy"``).
thought_text:
The thought text to display. Truncated to 500 characters.
thought_id:
Unique identifier for this thought (sequence number).
chain_id:
Optional chain identifier grouping related thoughts.
Returns
-------
dict
Thought message with keys ``type``, ``agent_id``, ``data`` (containing
``text``, ``thought_id``, ``chain_id``), and ``ts``.
Examples
--------
>>> produce_thought("timmy", "Considering the options...", 42, "chain-123")
{
"type": "thought",
"agent_id": "timmy",
"data": {"text": "Considering the options...", "thought_id": 42, "chain_id": "chain-123"},
"ts": 1742529600,
}
"""
# Truncate text to 500 characters (thoughts can be longer than barks)
truncated_text = thought_text[:500] if thought_text else ""
return {
"type": "thought",
"agent_id": agent_id,
"data": {
"text": truncated_text,
"thought_id": thought_id,
"chain_id": chain_id,
},
"ts": int(time.time()),
}
def serialize_presence(presence: dict) -> dict:
"""Transform an ADR-023 presence dict into the world-state API shape.
Parameters
----------
presence:
Raw presence dict as written by
:func:`~timmy.workshop_state.get_state_dict` or read from
``~/.timmy/presence.json``.
Returns
-------
dict
CamelCase world-state payload with ``timmyState``, ``familiar``,
``activeThreads``, ``recentEvents``, ``concerns``, ``visitorPresent``,
``updatedAt``, and ``version`` keys.
"""
return {
"timmyState": {
"mood": presence.get("mood", "calm"),
"activity": presence.get("current_focus", "idle"),
"energy": presence.get("energy", 0.5),
"confidence": presence.get("confidence", 0.7),
},
"familiar": presence.get("familiar"),
"activeThreads": presence.get("active_threads", []),
"recentEvents": presence.get("recent_events", []),
"concerns": presence.get("concerns", []),
"visitorPresent": False,
"updatedAt": presence.get("liveness", datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")),
"version": presence.get("version", 1),
}
# ---------------------------------------------------------------------------
# Status mapping: ADR-023 current_focus → Matrix agent status
# ---------------------------------------------------------------------------
_STATUS_KEYWORDS: dict[str, str] = {
"thinking": "thinking",
"speaking": "speaking",
"talking": "speaking",
"idle": "idle",
}
def _derive_status(current_focus: str) -> str:
"""Map a free-text current_focus value to a Matrix status enum.
Returns one of: online, idle, thinking, speaking.
"""
focus_lower = current_focus.lower()
for keyword, status in _STATUS_KEYWORDS.items():
if keyword in focus_lower:
return status
if current_focus and current_focus != "idle":
return "online"
return "idle"
def produce_agent_state(agent_id: str, presence: dict) -> dict:
"""Build a Matrix-compatible ``agent_state`` message from presence data.
Parameters
----------
agent_id:
Unique identifier for the agent (e.g. ``"timmy"``).
presence:
Raw ADR-023 presence dict.
Returns
-------
dict
Message with keys ``type``, ``agent_id``, ``data``, and ``ts``.
"""
return {
"type": "agent_state",
"agent_id": agent_id,
"data": {
"display_name": presence.get("display_name", agent_id.title()),
"role": presence.get("role", "assistant"),
"status": _derive_status(presence.get("current_focus", "idle")),
"mood": presence.get("mood", "calm"),
"energy": presence.get("energy", 0.5),
"bark": presence.get("bark", ""),
"familiar": _get_familiar_state(),
},
"ts": int(time.time()),
}
def produce_system_status() -> dict:
"""Generate a system_status message for the Matrix.
Returns a dict with system health metrics including agent count,
visitor count, uptime, thinking engine status, and memory count.
Returns
-------
dict
Message with keys ``type``, ``data`` (containing ``agents_online``,
``visitors``, ``uptime_seconds``, ``thinking_active``, ``memory_count``),
and ``ts``.
Examples
--------
>>> produce_system_status()
{
"type": "system_status",
"data": {
"agents_online": 5,
"visitors": 2,
"uptime_seconds": 3600,
"thinking_active": True,
"memory_count": 150,
},
"ts": 1742529600,
}
"""
# Count agents with status != offline
agents_online = 0
try:
from timmy.agents.loader import list_agents
agents = list_agents()
agents_online = sum(1 for a in agents if a.get("status", "") not in ("offline", ""))
except Exception as exc:
logger.debug("Failed to count agents: %s", exc)
# Count visitors from WebSocket clients
visitors = 0
try:
from dashboard.routes.world import _ws_clients
visitors = len(_ws_clients)
except Exception as exc:
logger.debug("Failed to count visitors: %s", exc)
# Calculate uptime
uptime_seconds = 0
try:
from datetime import UTC
from config import APP_START_TIME
uptime_seconds = int((datetime.now(UTC) - APP_START_TIME).total_seconds())
except Exception as exc:
logger.debug("Failed to calculate uptime: %s", exc)
# Check thinking engine status
thinking_active = False
try:
from config import settings
from timmy.thinking import thinking_engine
thinking_active = settings.thinking_enabled and thinking_engine is not None
except Exception as exc:
logger.debug("Failed to check thinking status: %s", exc)
# Count memories in vector store
memory_count = 0
try:
from timmy.memory_system import get_memory_stats
stats = get_memory_stats()
memory_count = stats.get("total_entries", 0)
except Exception as exc:
logger.debug("Failed to count memories: %s", exc)
return {
"type": "system_status",
"data": {
"agents_online": agents_online,
"visitors": visitors,
"uptime_seconds": uptime_seconds,
"thinking_active": thinking_active,
"memory_count": memory_count,
},
"ts": int(time.time()),
}


@@ -0,0 +1,261 @@
"""Shared WebSocket message protocol for the Matrix frontend.
Defines all WebSocket message types as an enum and typed dataclasses
with ``to_json()`` / ``from_json()`` helpers so every producer and the
gateway speak the same language.
Message wire format
-------------------
.. code-block:: json
{"type": "agent_state", "agent_id": "timmy", "data": {...}, "ts": 1234567890}
"""
import json
import logging
import time
from dataclasses import asdict, dataclass, field
from enum import StrEnum
from typing import Any
logger = logging.getLogger(__name__)
class MessageType(StrEnum):
"""All WebSocket message types defined by the Matrix PROTOCOL.md."""
AGENT_STATE = "agent_state"
VISITOR_STATE = "visitor_state"
BARK = "bark"
THOUGHT = "thought"
SYSTEM_STATUS = "system_status"
CONNECTION_ACK = "connection_ack"
ERROR = "error"
TASK_UPDATE = "task_update"
MEMORY_FLASH = "memory_flash"
# ---------------------------------------------------------------------------
# Base message
# ---------------------------------------------------------------------------
@dataclass
class WSMessage:
"""Base WebSocket message with common envelope fields."""
type: str
ts: float = field(default_factory=time.time)
def to_json(self) -> str:
"""Serialise the message to a JSON string."""
return json.dumps(asdict(self))
@classmethod
def from_json(cls, raw: str) -> "WSMessage":
"""Deserialise a JSON string into the correct message subclass.
Falls back to the base ``WSMessage`` when the ``type`` field is
unrecognised.
"""
data = json.loads(raw)
msg_type = data.get("type")
sub = _REGISTRY.get(msg_type)
if sub is not None:
return sub.from_json(raw)
return cls(**data)
# ---------------------------------------------------------------------------
# Concrete message types
# ---------------------------------------------------------------------------
@dataclass
class AgentStateMessage(WSMessage):
"""State update for a single agent."""
type: str = field(default=MessageType.AGENT_STATE)
agent_id: str = ""
data: dict[str, Any] = field(default_factory=dict)
@classmethod
def from_json(cls, raw: str) -> "AgentStateMessage":
payload = json.loads(raw)
return cls(
type=payload.get("type", MessageType.AGENT_STATE),
ts=payload.get("ts", time.time()),
agent_id=payload.get("agent_id", ""),
data=payload.get("data", {}),
)
@dataclass
class VisitorStateMessage(WSMessage):
"""State update for a visitor / user session."""
type: str = field(default=MessageType.VISITOR_STATE)
visitor_id: str = ""
data: dict[str, Any] = field(default_factory=dict)
@classmethod
def from_json(cls, raw: str) -> "VisitorStateMessage":
payload = json.loads(raw)
return cls(
type=payload.get("type", MessageType.VISITOR_STATE),
ts=payload.get("ts", time.time()),
visitor_id=payload.get("visitor_id", ""),
data=payload.get("data", {}),
)
@dataclass
class BarkMessage(WSMessage):
"""A bark (chat-like utterance) from an agent."""
type: str = field(default=MessageType.BARK)
agent_id: str = ""
content: str = ""
@classmethod
def from_json(cls, raw: str) -> "BarkMessage":
payload = json.loads(raw)
return cls(
type=payload.get("type", MessageType.BARK),
ts=payload.get("ts", time.time()),
agent_id=payload.get("agent_id", ""),
content=payload.get("content", ""),
)
@dataclass
class ThoughtMessage(WSMessage):
"""An inner thought from an agent."""
type: str = field(default=MessageType.THOUGHT)
agent_id: str = ""
content: str = ""
@classmethod
def from_json(cls, raw: str) -> "ThoughtMessage":
payload = json.loads(raw)
return cls(
type=payload.get("type", MessageType.THOUGHT),
ts=payload.get("ts", time.time()),
agent_id=payload.get("agent_id", ""),
content=payload.get("content", ""),
)
@dataclass
class SystemStatusMessage(WSMessage):
"""System-wide status broadcast."""
type: str = field(default=MessageType.SYSTEM_STATUS)
status: str = ""
data: dict[str, Any] = field(default_factory=dict)
@classmethod
def from_json(cls, raw: str) -> "SystemStatusMessage":
payload = json.loads(raw)
return cls(
type=payload.get("type", MessageType.SYSTEM_STATUS),
ts=payload.get("ts", time.time()),
status=payload.get("status", ""),
data=payload.get("data", {}),
)
@dataclass
class ConnectionAckMessage(WSMessage):
"""Acknowledgement sent when a client connects."""
type: str = field(default=MessageType.CONNECTION_ACK)
client_id: str = ""
@classmethod
def from_json(cls, raw: str) -> "ConnectionAckMessage":
payload = json.loads(raw)
return cls(
type=payload.get("type", MessageType.CONNECTION_ACK),
ts=payload.get("ts", time.time()),
client_id=payload.get("client_id", ""),
)
@dataclass
class ErrorMessage(WSMessage):
"""Error message sent to a client."""
type: str = field(default=MessageType.ERROR)
code: str = ""
message: str = ""
@classmethod
def from_json(cls, raw: str) -> "ErrorMessage":
payload = json.loads(raw)
return cls(
type=payload.get("type", MessageType.ERROR),
ts=payload.get("ts", time.time()),
code=payload.get("code", ""),
message=payload.get("message", ""),
)
@dataclass
class TaskUpdateMessage(WSMessage):
"""Update about a task (created, assigned, completed, etc.)."""
type: str = field(default=MessageType.TASK_UPDATE)
task_id: str = ""
status: str = ""
data: dict[str, Any] = field(default_factory=dict)
@classmethod
def from_json(cls, raw: str) -> "TaskUpdateMessage":
payload = json.loads(raw)
return cls(
type=payload.get("type", MessageType.TASK_UPDATE),
ts=payload.get("ts", time.time()),
task_id=payload.get("task_id", ""),
status=payload.get("status", ""),
data=payload.get("data", {}),
)
@dataclass
class MemoryFlashMessage(WSMessage):
"""A flash of memory — a recalled or stored memory event."""
type: str = field(default=MessageType.MEMORY_FLASH)
agent_id: str = ""
memory_key: str = ""
content: str = ""
@classmethod
def from_json(cls, raw: str) -> "MemoryFlashMessage":
payload = json.loads(raw)
return cls(
type=payload.get("type", MessageType.MEMORY_FLASH),
ts=payload.get("ts", time.time()),
agent_id=payload.get("agent_id", ""),
memory_key=payload.get("memory_key", ""),
content=payload.get("content", ""),
)
# ---------------------------------------------------------------------------
# Registry for from_json dispatch
# ---------------------------------------------------------------------------
_REGISTRY: dict[str, type[WSMessage]] = {
MessageType.AGENT_STATE: AgentStateMessage,
MessageType.VISITOR_STATE: VisitorStateMessage,
MessageType.BARK: BarkMessage,
MessageType.THOUGHT: ThoughtMessage,
MessageType.SYSTEM_STATUS: SystemStatusMessage,
MessageType.CONNECTION_ACK: ConnectionAckMessage,
MessageType.ERROR: ErrorMessage,
MessageType.TASK_UPDATE: TaskUpdateMessage,
MessageType.MEMORY_FLASH: MemoryFlashMessage,
}
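The `_REGISTRY` dispatch above can be exercised end-to-end. A trimmed-down, self-contained re-declaration (only `BarkMessage`, to keep it short) showing the serialise/deserialise round trip through the base class:

```python
import json
import time
from dataclasses import asdict, dataclass, field

@dataclass
class WSMessage:
    """Trimmed-down copy of the base envelope, for demonstration only."""
    type: str
    ts: float = field(default_factory=time.time)

    def to_json(self) -> str:
        return json.dumps(asdict(self))

    @classmethod
    def from_json(cls, raw: str) -> "WSMessage":
        data = json.loads(raw)
        sub = _REGISTRY.get(data.get("type"))
        if sub is not None:
            return sub.from_json(raw)  # dispatch to the concrete subclass
        return cls(**data)

@dataclass
class BarkMessage(WSMessage):
    type: str = field(default="bark")
    agent_id: str = ""
    content: str = ""

    @classmethod
    def from_json(cls, raw: str) -> "BarkMessage":
        p = json.loads(raw)
        return cls(type=p.get("type", "bark"), ts=p.get("ts", time.time()),
                   agent_id=p.get("agent_id", ""), content=p.get("content", ""))

_REGISTRY = {"bark": BarkMessage}

# Producer side: build and serialise one message.
wire = BarkMessage(agent_id="timmy", content="woof").to_json()
# Gateway side: a single entry point recovers the concrete subclass.
msg = WSMessage.from_json(wire)
assert isinstance(msg, BarkMessage) and msg.content == "woof"
```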


@@ -99,11 +99,11 @@ class GrokBackend:
def _get_client(self):
"""Create OpenAI client configured for xAI endpoint."""
from config import settings
import httpx
from openai import OpenAI
from config import settings
return OpenAI(
api_key=self._api_key,
base_url=settings.xai_base_url,
@@ -112,11 +112,11 @@ class GrokBackend:
async def _get_async_client(self):
"""Create async OpenAI client configured for xAI endpoint."""
from config import settings
import httpx
from openai import AsyncOpenAI
from config import settings
return AsyncOpenAI(
api_key=self._api_key,
base_url=settings.xai_base_url,
@@ -264,6 +264,7 @@ class GrokBackend:
},
}
except Exception as exc:
logger.exception("Grok health check failed")
return {
"ok": False,
"error": str(exc),
@@ -430,6 +431,7 @@ class ClaudeBackend:
)
return {"ok": True, "error": None, "backend": "claude", "model": self._model}
except Exception as exc:
logger.exception("Claude health check failed")
return {"ok": False, "error": str(exc), "backend": "claude", "model": self._model}
# ── Private helpers ───────────────────────────────────────────────────


@@ -37,6 +37,39 @@ def _is_interactive() -> bool:
return hasattr(sys.stdin, "isatty") and sys.stdin.isatty()
def _read_message_input(message: list[str]) -> str:
"""Join CLI args into a message, reading from stdin when requested.
Returns the final message string. Raises ``typer.Exit(1)`` when
stdin is explicitly requested (``-``) but empty.
"""
message_str = " ".join(message)
if message_str == "-" or not _is_interactive():
try:
stdin_content = sys.stdin.read().strip()
except (KeyboardInterrupt, EOFError):
stdin_content = ""
if stdin_content:
message_str = stdin_content
elif message_str == "-":
typer.echo("No input provided via stdin.", err=True)
raise typer.Exit(1)
return message_str
def _resolve_session_id(session_id: str | None, new_session: bool) -> str:
"""Return the effective session ID for a chat invocation."""
import uuid
if session_id is not None:
return session_id
if new_session:
return str(uuid.uuid4())
return _CLI_SESSION_ID
def _prompt_interactive(req, tool_name: str, tool_args: dict) -> None:
"""Display tool details and prompt the human for approval."""
description = format_action_description(tool_name, tool_args)
@@ -143,6 +176,35 @@ def think(
timmy.print_response(f"Think carefully about: {topic}", stream=True, session_id=_CLI_SESSION_ID)
def _read_message_input(message: list[str]) -> str:
"""Join CLI arguments and read from stdin when appropriate."""
message_str = " ".join(message)
if message_str == "-" or not _is_interactive():
try:
stdin_content = sys.stdin.read().strip()
except (KeyboardInterrupt, EOFError):
stdin_content = ""
if stdin_content:
message_str = stdin_content
elif message_str == "-":
typer.echo("No input provided via stdin.", err=True)
raise typer.Exit(1)
return message_str
def _resolve_session_id(session_id: str | None, new_session: bool) -> str:
"""Return the effective session ID based on CLI flags."""
import uuid
if session_id is not None:
return session_id
if new_session:
return str(uuid.uuid4())
return _CLI_SESSION_ID
@app.command()
def chat(
message: list[str] = typer.Argument(
@@ -179,38 +241,13 @@ def chat(
Read from stdin by passing "-" as the message or piping input.
"""
import uuid
# Join multiple arguments into a single message string
message_str = " ".join(message)
# Handle stdin input if "-" is passed or stdin is not a tty
if message_str == "-" or not _is_interactive():
try:
stdin_content = sys.stdin.read().strip()
except (KeyboardInterrupt, EOFError):
stdin_content = ""
if stdin_content:
message_str = stdin_content
elif message_str == "-":
typer.echo("No input provided via stdin.", err=True)
raise typer.Exit(1)
if session_id is not None:
pass # use the provided value
elif new_session:
session_id = str(uuid.uuid4())
else:
session_id = _CLI_SESSION_ID
message_str = _read_message_input(message)
session_id = _resolve_session_id(session_id, new_session)
timmy = create_timmy(backend=backend, session_id=session_id)
# Use agent.run() so we can intercept paused runs for tool confirmation.
run_output = timmy.run(message_str, stream=False, session_id=session_id)
# Handle paused runs — dangerous tools need user approval
run_output = _handle_tool_confirmation(timmy, run_output, session_id, autonomous=autonomous)
# Print the final response
content = run_output.content if hasattr(run_output, "content") else str(run_output)
if content:
from timmy.session import _clean_response


@@ -97,6 +97,7 @@ async def probe_tool_use() -> dict:
"error_type": "empty_result",
}
except Exception as exc:
logger.exception("Tool use probe failed")
return {
"success": False,
"capability": cap,
@@ -129,6 +130,7 @@ async def probe_multistep_planning() -> dict:
"error_type": "verification_failed",
}
except Exception as exc:
logger.exception("Multistep planning probe failed")
return {
"success": False,
"capability": cap,
@@ -151,6 +153,7 @@ async def probe_memory_write() -> dict:
"error_type": None,
}
except Exception as exc:
logger.exception("Memory write probe failed")
return {
"success": False,
"capability": cap,
@@ -179,6 +182,7 @@ async def probe_memory_read() -> dict:
"error_type": "empty_result",
}
except Exception as exc:
logger.exception("Memory read probe failed")
return {
"success": False,
"capability": cap,
@@ -214,6 +218,7 @@ async def probe_self_coding() -> dict:
"error_type": "verification_failed",
}
except Exception as exc:
logger.exception("Self-coding probe failed")
return {
"success": False,
"capability": cap,
@@ -325,6 +330,7 @@ class LoopQAOrchestrator:
result = await probe_fn()
except Exception as exc:
# Probe itself crashed — record failure and report
logger.exception("Loop QA probe %s crashed", cap.value)
capture_error(exc, source="loop_qa", context={"capability": cap.value})
result = {
"success": False,


@@ -139,6 +139,7 @@ def _run_kimi(cmd: list[str], workdir: str) -> dict[str, Any]:
"error": "Kimi timed out after 300s. Task may be too broad — try breaking it into smaller pieces.",
}
except Exception as exc:
logger.exception("Failed to run Kimi subprocess")
return {
"success": False,
"error": f"Failed to run Kimi: {exc}",


@@ -122,6 +122,7 @@ def check_ollama_health() -> dict[str, Any]:
models = response.json().get("models", [])
result["available_models"] = [m.get("name", "") for m in models]
except Exception as e:
logger.exception("Ollama health check failed")
result["error"] = str(e)
return result
@@ -289,6 +290,7 @@ def get_live_system_status() -> dict[str, Any]:
try:
result["system"] = get_system_info()
except Exception as exc:
logger.exception("Failed to get system info")
result["system"] = {"error": str(exc)}
# Task queue
@@ -301,6 +303,7 @@ def get_live_system_status() -> dict[str, Any]:
try:
result["memory"] = get_memory_status()
except Exception as exc:
logger.exception("Failed to get memory status")
result["memory"] = {"error": str(exc)}
# Uptime
@@ -406,4 +409,5 @@ def run_self_tests(scope: str = "fast", _repo_root: str | None = None) -> dict[s
except subprocess.TimeoutExpired:
return {"success": False, "error": "Test run timed out (120s limit)"}
except Exception as exc:
logger.exception("Self-test run failed")
return {"success": False, "error": str(exc)}


@@ -31,6 +31,8 @@ for _mod in [
"pyzbar.pyzbar",
"pyttsx3",
"sentence_transformers",
"swarm",
"swarm.event_log",
]:
sys.modules.setdefault(_mod, MagicMock())
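The `sys.modules.setdefault(..., MagicMock())` shim above lets tests import modules that depend on heavy optional packages without installing them. A self-contained sketch of the pattern — `fake_heavy_dep` is an invented name used purely for illustration:

```python
import sys
from unittest.mock import MagicMock

# Registering a MagicMock under a module name makes any later
# `import fake_heavy_dep` resolve to the mock instead of raising
# ModuleNotFoundError; setdefault leaves a real install untouched.
sys.modules.setdefault("fake_heavy_dep", MagicMock())

import fake_heavy_dep  # resolves to the MagicMock placed above

fake_heavy_dep.do_something()  # any attribute/call just returns more mocks
assert isinstance(fake_heavy_dep, MagicMock)
```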


@@ -120,3 +120,50 @@ class TestCSRFDecoratorSupport:
# Protected endpoint should be 403
response2 = client.post("/protected")
assert response2.status_code == 403
def test_csrf_exempt_endpoint_not_executed_before_check(self):
"""Regression test for #626: endpoint must NOT execute before CSRF check.
Previously the middleware called call_next() first, executing the endpoint
and its side effects, then checked @csrf_exempt afterward. This meant
non-exempt endpoints would execute even when CSRF validation failed.
"""
app = FastAPI()
app.add_middleware(CSRFMiddleware)
side_effect_log: list[str] = []
@app.post("/protected-with-side-effects")
def protected_with_side_effects():
side_effect_log.append("executed")
return {"message": "should not run"}
client = TestClient(app)
# POST without CSRF token — should be blocked with 403
response = client.post("/protected-with-side-effects")
assert response.status_code == 403
# The critical assertion: the endpoint must NOT have executed
assert side_effect_log == [], (
"Endpoint executed before CSRF validation! Side effects occurred "
"despite CSRF failure (see issue #626)."
)
def test_csrf_exempt_endpoint_does_execute(self):
"""Ensure @csrf_exempt endpoints still execute normally."""
app = FastAPI()
app.add_middleware(CSRFMiddleware)
side_effect_log: list[str] = []
@app.post("/exempt-webhook")
@csrf_exempt
def exempt_webhook():
side_effect_log.append("executed")
return {"message": "webhook ok"}
client = TestClient(app)
response = client.post("/exempt-webhook")
assert response.status_code == 200
assert side_effect_log == ["executed"]
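The ordering bug these tests guard against (#626) comes down to where the CSRF check sits relative to `call_next()`. A framework-free sketch of the fixed ordering — the real middleware is Starlette-based, and all names here (`csrf_middleware`, `is_exempt`, `has_valid_token`) are illustrative:

```python
def csrf_middleware(handler, is_exempt, has_valid_token):
    """Wrap an endpoint so CSRF is validated BEFORE the endpoint runs."""
    def wrapped(request):
        if not is_exempt(handler) and not has_valid_token(request):
            return {"status": 403}  # reject first — handler never executes
        return handler(request)      # only now run the endpoint
    return wrapped

side_effects = []

def endpoint(request):
    side_effects.append("executed")
    return {"status": 200}

# Non-exempt endpoint, invalid token: blocked with no side effects.
guarded = csrf_middleware(endpoint, is_exempt=lambda h: False,
                          has_valid_token=lambda r: False)
resp = guarded({"method": "POST"})
assert resp["status"] == 403
assert side_effects == []

# Exempt endpoint: runs normally even without a token.
exempt = csrf_middleware(endpoint, is_exempt=lambda h: True,
                         has_valid_token=lambda r: False)
assert exempt({"method": "POST"})["status"] == 200
assert side_effects == ["executed"]
```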

File diff suppressed because it is too large.


@@ -0,0 +1,509 @@
"""Tests for infrastructure.models.multimodal — multi-modal model management."""
import json
from unittest.mock import MagicMock, patch
from infrastructure.models.multimodal import (
DEFAULT_FALLBACK_CHAINS,
KNOWN_MODEL_CAPABILITIES,
ModelCapability,
ModelInfo,
MultiModalManager,
get_model_for_capability,
model_supports_tools,
model_supports_vision,
pull_model_with_fallback,
)
# ---------------------------------------------------------------------------
# ModelCapability enum
# ---------------------------------------------------------------------------
class TestModelCapability:
def test_members_exist(self):
assert ModelCapability.TEXT
assert ModelCapability.VISION
assert ModelCapability.AUDIO
assert ModelCapability.TOOLS
assert ModelCapability.JSON
assert ModelCapability.STREAMING
def test_all_members_unique(self):
values = [m.value for m in ModelCapability]
assert len(values) == len(set(values))
# ---------------------------------------------------------------------------
# ModelInfo dataclass
# ---------------------------------------------------------------------------
class TestModelInfo:
def test_defaults(self):
info = ModelInfo(name="test-model")
assert info.name == "test-model"
assert info.capabilities == set()
assert info.is_available is False
assert info.is_pulled is False
assert info.size_mb is None
assert info.description == ""
def test_supports_true(self):
info = ModelInfo(name="m", capabilities={ModelCapability.TEXT, ModelCapability.VISION})
assert info.supports(ModelCapability.TEXT) is True
assert info.supports(ModelCapability.VISION) is True
def test_supports_false(self):
info = ModelInfo(name="m", capabilities={ModelCapability.TEXT})
assert info.supports(ModelCapability.VISION) is False
# ---------------------------------------------------------------------------
# Known model capabilities lookup table
# ---------------------------------------------------------------------------
class TestKnownModelCapabilities:
def test_vision_models_have_vision(self):
vision_names = [
"llama3.2-vision",
"llava",
"moondream",
"qwen2.5-vl",
]
for name in vision_names:
assert ModelCapability.VISION in KNOWN_MODEL_CAPABILITIES[name], name
def test_text_models_lack_vision(self):
text_only = ["deepseek-r1", "gemma2", "phi3"]
for name in text_only:
assert ModelCapability.VISION not in KNOWN_MODEL_CAPABILITIES[name], name
def test_all_models_have_text(self):
for name, caps in KNOWN_MODEL_CAPABILITIES.items():
assert ModelCapability.TEXT in caps, f"{name} should have TEXT"
# ---------------------------------------------------------------------------
# Default fallback chains
# ---------------------------------------------------------------------------
class TestDefaultFallbackChains:
def test_vision_chain_non_empty(self):
assert len(DEFAULT_FALLBACK_CHAINS[ModelCapability.VISION]) > 0
def test_tools_chain_non_empty(self):
assert len(DEFAULT_FALLBACK_CHAINS[ModelCapability.TOOLS]) > 0
def test_audio_chain_empty(self):
assert DEFAULT_FALLBACK_CHAINS[ModelCapability.AUDIO] == []
# ---------------------------------------------------------------------------
# Helpers to build a manager without hitting the network
# ---------------------------------------------------------------------------
def _fake_ollama_tags(*model_names: str) -> bytes:
"""Build a JSON response mimicking Ollama /api/tags."""
models = []
for name in model_names:
models.append({"name": name, "size": 4 * 1024 * 1024 * 1024, "details": {"family": "test"}})
return json.dumps({"models": models}).encode()
def _make_manager(model_names: list[str] | None = None) -> MultiModalManager:
"""Create a MultiModalManager with mocked Ollama responses."""
if model_names is None:
# No models available — Ollama unreachable
with patch("urllib.request.urlopen", side_effect=ConnectionError("no ollama")):
return MultiModalManager(ollama_url="http://localhost:11434")
resp = MagicMock()
resp.__enter__ = MagicMock(return_value=resp)
resp.__exit__ = MagicMock(return_value=False)
resp.read.return_value = _fake_ollama_tags(*model_names)
resp.status = 200
with patch("urllib.request.urlopen", return_value=resp):
return MultiModalManager(ollama_url="http://localhost:11434")
# ---------------------------------------------------------------------------
# MultiModalManager — init & refresh
# ---------------------------------------------------------------------------
class TestMultiModalManagerInit:
def test_init_no_ollama(self):
mgr = _make_manager(None)
assert mgr.list_available_models() == []
def test_init_with_models(self):
mgr = _make_manager(["llama3.1:8b", "llava:7b"])
names = {m.name for m in mgr.list_available_models()}
assert names == {"llama3.1:8b", "llava:7b"}
def test_refresh_updates_models(self):
mgr = _make_manager([])
assert mgr.list_available_models() == []
resp = MagicMock()
resp.__enter__ = MagicMock(return_value=resp)
resp.__exit__ = MagicMock(return_value=False)
resp.read.return_value = _fake_ollama_tags("gemma2:9b")
resp.status = 200
with patch("urllib.request.urlopen", return_value=resp):
mgr.refresh()
names = {m.name for m in mgr.list_available_models()}
assert "gemma2:9b" in names
# ---------------------------------------------------------------------------
# _detect_capabilities
# ---------------------------------------------------------------------------
class TestDetectCapabilities:
def test_exact_match(self):
mgr = _make_manager(None)
caps = mgr._detect_capabilities("llava:7b")
assert ModelCapability.VISION in caps
def test_base_name_match(self):
mgr = _make_manager(None)
caps = mgr._detect_capabilities("llava:99b")
# "llava:99b" not in table, but "llava" is
assert ModelCapability.VISION in caps
def test_unknown_model_defaults_to_text(self):
mgr = _make_manager(None)
caps = mgr._detect_capabilities("totally-unknown-model:1b")
assert caps == {ModelCapability.TEXT, ModelCapability.STREAMING}
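The base-name fallback these tests exercise ("llava:99b" is unknown, but "llava" is in the table) can be sketched in a few lines. A hypothetical, self-contained re-implementation, not the real `_detect_capabilities`:

```python
TEXT, VISION, STREAMING = "text", "vision", "streaming"

KNOWN = {
    "llava": {TEXT, VISION, STREAMING},
    "deepseek-r1": {TEXT, STREAMING},
}

def detect_capabilities(name: str) -> set[str]:
    if name in KNOWN:                      # exact match first
        return set(KNOWN[name])
    base = name.split(":", 1)[0]           # strip the ":tag" suffix
    if base in KNOWN:
        return set(KNOWN[base])
    return {TEXT, STREAMING}               # unknown models: text-only default

assert VISION in detect_capabilities("llava:99b")
assert detect_capabilities("totally-unknown-model:1b") == {TEXT, STREAMING}
```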
# ---------------------------------------------------------------------------
# get_model_capabilities / model_supports
# ---------------------------------------------------------------------------
class TestGetModelCapabilities:
def test_available_model(self):
mgr = _make_manager(["llava:7b"])
caps = mgr.get_model_capabilities("llava:7b")
assert ModelCapability.VISION in caps
def test_unavailable_model_uses_detection(self):
mgr = _make_manager([])
caps = mgr.get_model_capabilities("llava:7b")
assert ModelCapability.VISION in caps
class TestModelSupports:
def test_supports_true(self):
mgr = _make_manager(["llava:7b"])
assert mgr.model_supports("llava:7b", ModelCapability.VISION) is True
def test_supports_false(self):
mgr = _make_manager(["deepseek-r1:7b"])
assert mgr.model_supports("deepseek-r1:7b", ModelCapability.VISION) is False
# ---------------------------------------------------------------------------
# get_models_with_capability
# ---------------------------------------------------------------------------
class TestGetModelsWithCapability:
def test_returns_vision_models(self):
mgr = _make_manager(["llava:7b", "deepseek-r1:7b"])
vision = mgr.get_models_with_capability(ModelCapability.VISION)
names = {m.name for m in vision}
assert "llava:7b" in names
assert "deepseek-r1:7b" not in names
def test_empty_when_none_available(self):
mgr = _make_manager(["deepseek-r1:7b"])
vision = mgr.get_models_with_capability(ModelCapability.VISION)
assert vision == []
# ---------------------------------------------------------------------------
# get_best_model_for
# ---------------------------------------------------------------------------
class TestGetBestModelFor:
def test_preferred_model_with_capability(self):
mgr = _make_manager(["llava:7b", "llama3.1:8b"])
result = mgr.get_best_model_for(ModelCapability.VISION, preferred_model="llava:7b")
assert result == "llava:7b"
def test_preferred_model_without_capability_uses_fallback(self):
mgr = _make_manager(["deepseek-r1:7b", "llava:7b"])
# preferred doesn't have VISION, fallback chain has llava:7b
result = mgr.get_best_model_for(ModelCapability.VISION, preferred_model="deepseek-r1:7b")
assert result == "llava:7b"
def test_fallback_chain_order(self):
# First in chain: llama3.2:3b
mgr = _make_manager(["llama3.2:3b", "llava:7b"])
result = mgr.get_best_model_for(ModelCapability.VISION)
assert result == "llama3.2:3b"
def test_any_capable_model_when_no_fallback(self):
mgr = _make_manager(["moondream:1.8b"])
mgr._fallback_chains[ModelCapability.VISION] = [] # clear chain
result = mgr.get_best_model_for(ModelCapability.VISION)
assert result == "moondream:1.8b"
def test_none_when_no_capable_model(self):
mgr = _make_manager(["deepseek-r1:7b"])
result = mgr.get_best_model_for(ModelCapability.VISION)
assert result is None
def test_preferred_model_not_available_skipped(self):
mgr = _make_manager(["llava:7b"])
# preferred_model "llava:13b" is not in available_models
result = mgr.get_best_model_for(ModelCapability.VISION, preferred_model="llava:13b")
assert result == "llava:7b"
# ---------------------------------------------------------------------------
# pull_model_with_fallback (manager method)
# ---------------------------------------------------------------------------
class TestPullModelWithFallback:
def test_already_available(self):
mgr = _make_manager(["llama3.1:8b"])
model, is_fallback = mgr.pull_model_with_fallback("llama3.1:8b")
assert model == "llama3.1:8b"
assert is_fallback is False
def test_pull_succeeds(self):
mgr = _make_manager([])
pull_resp = MagicMock()
pull_resp.__enter__ = MagicMock(return_value=pull_resp)
pull_resp.__exit__ = MagicMock(return_value=False)
pull_resp.status = 200
# After pull, refresh returns the model
refresh_resp = MagicMock()
refresh_resp.__enter__ = MagicMock(return_value=refresh_resp)
refresh_resp.__exit__ = MagicMock(return_value=False)
refresh_resp.read.return_value = _fake_ollama_tags("llama3.1:8b")
refresh_resp.status = 200
with patch("urllib.request.urlopen", side_effect=[pull_resp, refresh_resp]):
model, is_fallback = mgr.pull_model_with_fallback("llama3.1:8b")
assert model == "llama3.1:8b"
assert is_fallback is False
def test_pull_fails_uses_capability_fallback(self):
mgr = _make_manager(["llava:7b"])
with patch("urllib.request.urlopen", side_effect=ConnectionError("fail")):
model, is_fallback = mgr.pull_model_with_fallback(
"nonexistent-vision:1b",
capability=ModelCapability.VISION,
)
assert model == "llava:7b"
assert is_fallback is True
def test_pull_fails_uses_default_model(self):
default_model = "llama3.1:8b"
mgr = _make_manager([default_model])
with (
patch("urllib.request.urlopen", side_effect=ConnectionError("fail")),
patch("infrastructure.models.multimodal.settings") as mock_settings,
):
mock_settings.ollama_model = default_model
mock_settings.ollama_url = "http://localhost:11434"
model, is_fallback = mgr.pull_model_with_fallback("missing-model:99b")
assert model == "llama3.1:8b"
assert is_fallback is True
def test_auto_pull_false_skips_pull(self):
mgr = _make_manager([])
with patch("infrastructure.models.multimodal.settings") as mock_settings:
mock_settings.ollama_model = "default"
model, is_fallback = mgr.pull_model_with_fallback("missing:1b", auto_pull=False)
# Falls through to absolute last resort
assert model == "missing:1b"
assert is_fallback is False
def test_absolute_last_resort(self):
mgr = _make_manager([])
with (
patch("urllib.request.urlopen", side_effect=ConnectionError("fail")),
patch("infrastructure.models.multimodal.settings") as mock_settings,
):
mock_settings.ollama_model = "not-available"
model, is_fallback = mgr.pull_model_with_fallback("primary:1b")
assert model == "primary:1b"
assert is_fallback is False
# ---------------------------------------------------------------------------
# _pull_model
# ---------------------------------------------------------------------------
class TestPullModel:
def test_pull_success(self):
mgr = _make_manager([])
pull_resp = MagicMock()
pull_resp.__enter__ = MagicMock(return_value=pull_resp)
pull_resp.__exit__ = MagicMock(return_value=False)
pull_resp.status = 200
refresh_resp = MagicMock()
refresh_resp.__enter__ = MagicMock(return_value=refresh_resp)
refresh_resp.__exit__ = MagicMock(return_value=False)
refresh_resp.read.return_value = _fake_ollama_tags("new-model:1b")
refresh_resp.status = 200
with patch("urllib.request.urlopen", side_effect=[pull_resp, refresh_resp]):
assert mgr._pull_model("new-model:1b") is True
def test_pull_network_error(self):
mgr = _make_manager([])
with patch("urllib.request.urlopen", side_effect=ConnectionError("offline")):
assert mgr._pull_model("any-model:1b") is False
# ---------------------------------------------------------------------------
# configure_fallback_chain / get_fallback_chain
# ---------------------------------------------------------------------------
class TestFallbackChainConfig:
def test_configure_and_get(self):
mgr = _make_manager(None)
mgr.configure_fallback_chain(ModelCapability.VISION, ["model-a", "model-b"])
assert mgr.get_fallback_chain(ModelCapability.VISION) == ["model-a", "model-b"]
def test_get_returns_copy(self):
mgr = _make_manager(None)
chain = mgr.get_fallback_chain(ModelCapability.VISION)
chain.append("mutated")
assert "mutated" not in mgr.get_fallback_chain(ModelCapability.VISION)
def test_get_empty_for_unknown(self):
mgr = _make_manager(None)
# AUDIO has an empty chain by default
assert mgr.get_fallback_chain(ModelCapability.AUDIO) == []
# ---------------------------------------------------------------------------
# get_model_for_content
# ---------------------------------------------------------------------------
class TestGetModelForContent:
def test_image_content(self):
mgr = _make_manager(["llava:7b"])
model, _ = mgr.get_model_for_content("image")
assert model == "llava:7b"
def test_vision_content(self):
mgr = _make_manager(["llava:7b"])
model, _ = mgr.get_model_for_content("vision")
assert model == "llava:7b"
def test_multimodal_content(self):
mgr = _make_manager(["llava:7b"])
model, _ = mgr.get_model_for_content("multimodal")
assert model == "llava:7b"
def test_audio_content(self):
mgr = _make_manager(["llama3.1:8b"])
with patch("infrastructure.models.multimodal.settings") as mock_settings:
mock_settings.ollama_model = "llama3.1:8b"
mock_settings.ollama_url = "http://localhost:11434"
model, _ = mgr.get_model_for_content("audio")
assert model == "llama3.1:8b"
def test_text_content(self):
mgr = _make_manager(["llama3.1:8b"])
with patch("infrastructure.models.multimodal.settings") as mock_settings:
mock_settings.ollama_model = "llama3.1:8b"
mock_settings.ollama_url = "http://localhost:11434"
model, _ = mgr.get_model_for_content("text")
assert model == "llama3.1:8b"
def test_preferred_model_respected(self):
mgr = _make_manager(["llama3.2:3b", "llava:7b"])
model, _ = mgr.get_model_for_content("image", preferred_model="llama3.2:3b")
assert model == "llama3.2:3b"
def test_case_insensitive(self):
mgr = _make_manager(["llava:7b"])
model, _ = mgr.get_model_for_content("IMAGE")
assert model == "llava:7b"
# ---------------------------------------------------------------------------
# Module-level convenience functions
# ---------------------------------------------------------------------------
class TestConvenienceFunctions:
def _patch_manager(self, mgr):
return patch(
"infrastructure.models.multimodal._multimodal_manager",
mgr,
)
def test_get_model_for_capability(self):
mgr = _make_manager(["llava:7b"])
with self._patch_manager(mgr):
result = get_model_for_capability(ModelCapability.VISION)
assert result == "llava:7b"
def test_pull_model_with_fallback_convenience(self):
mgr = _make_manager(["llama3.1:8b"])
with self._patch_manager(mgr):
model, is_fb = pull_model_with_fallback("llama3.1:8b")
assert model == "llama3.1:8b"
assert is_fb is False
def test_model_supports_vision_true(self):
mgr = _make_manager(["llava:7b"])
with self._patch_manager(mgr):
assert model_supports_vision("llava:7b") is True
def test_model_supports_vision_false(self):
mgr = _make_manager(["llama3.1:8b"])
with self._patch_manager(mgr):
assert model_supports_vision("llama3.1:8b") is False
def test_model_supports_tools_true(self):
mgr = _make_manager(["llama3.1:8b"])
with self._patch_manager(mgr):
assert model_supports_tools("llama3.1:8b") is True
def test_model_supports_tools_false(self):
mgr = _make_manager(["deepseek-r1:7b"])
with self._patch_manager(mgr):
assert model_supports_tools("deepseek-r1:7b") is False
# ---------------------------------------------------------------------------
# ModelInfo in available_models — size_mb and description populated
# ---------------------------------------------------------------------------
class TestModelInfoPopulation:
def test_size_and_description(self):
mgr = _make_manager(["llama3.1:8b"])
info = mgr._available_models["llama3.1:8b"]
assert info.is_available is True
assert info.is_pulled is True
assert info.size_mb == 4 * 1024 # 4 GiB in MiB
assert info.description == "test"


@@ -0,0 +1,163 @@
"""Tests for cycle_result.json validation in loop_guard.
Covers validate_cycle_result(), _load_cycle_result(), and _is_issue_open().
"""
from __future__ import annotations
import json
import time
from pathlib import Path
from unittest.mock import patch
import pytest
import scripts.loop_guard as lg
@pytest.fixture(autouse=True)
def _isolate(tmp_path, monkeypatch):
"""Redirect loop_guard paths to tmp_path for isolation."""
monkeypatch.setattr(lg, "CYCLE_RESULT_FILE", tmp_path / "cycle_result.json")
monkeypatch.setattr(lg, "CYCLE_DURATION", 300)
monkeypatch.setattr(lg, "GITEA_API", "http://test:3000/api/v1")
monkeypatch.setattr(lg, "REPO_SLUG", "owner/repo")
def _write_cr(tmp_path, data: dict, age_seconds: float = 0) -> Path:
"""Write a cycle_result.json and optionally backdate it."""
p = tmp_path / "cycle_result.json"
p.write_text(json.dumps(data))
if age_seconds:
mtime = time.time() - age_seconds
import os
os.utime(p, (mtime, mtime))
return p
# --- _load_cycle_result ---
def test_load_cycle_result_missing(tmp_path):
assert lg._load_cycle_result() == {}
def test_load_cycle_result_valid(tmp_path):
_write_cr(tmp_path, {"issue": 42, "type": "fix"})
assert lg._load_cycle_result() == {"issue": 42, "type": "fix"}
def test_load_cycle_result_markdown_fenced(tmp_path):
p = tmp_path / "cycle_result.json"
p.write_text('```json\n{"issue": 99}\n```')
assert lg._load_cycle_result() == {"issue": 99}
def test_load_cycle_result_malformed(tmp_path):
p = tmp_path / "cycle_result.json"
p.write_text("not json at all")
assert lg._load_cycle_result() == {}
# --- _is_issue_open ---
def test_is_issue_open_true(monkeypatch):
monkeypatch.setattr(lg, "_get_token", lambda: "tok")
resp_data = json.dumps({"state": "open"}).encode()
class FakeResp:
def read(self):
return resp_data
def __enter__(self):
return self
def __exit__(self, *a):
pass
with patch("urllib.request.urlopen", return_value=FakeResp()):
assert lg._is_issue_open(42) is True
def test_is_issue_open_closed(monkeypatch):
monkeypatch.setattr(lg, "_get_token", lambda: "tok")
resp_data = json.dumps({"state": "closed"}).encode()
class FakeResp:
def read(self):
return resp_data
def __enter__(self):
return self
def __exit__(self, *a):
pass
with patch("urllib.request.urlopen", return_value=FakeResp()):
assert lg._is_issue_open(42) is False
def test_is_issue_open_no_token(monkeypatch):
monkeypatch.setattr(lg, "_get_token", lambda: "")
assert lg._is_issue_open(42) is None
def test_is_issue_open_api_error(monkeypatch):
monkeypatch.setattr(lg, "_get_token", lambda: "tok")
with patch("urllib.request.urlopen", side_effect=OSError("timeout")):
assert lg._is_issue_open(42) is None
# --- validate_cycle_result ---
def test_validate_no_file(tmp_path):
"""No file → returns False, no crash."""
assert lg.validate_cycle_result() is False
def test_validate_fresh_file_open_issue(tmp_path, monkeypatch):
"""Fresh file with open issue → kept."""
_write_cr(tmp_path, {"issue": 10})
monkeypatch.setattr(lg, "_is_issue_open", lambda n: True)
assert lg.validate_cycle_result() is False
assert (tmp_path / "cycle_result.json").exists()
def test_validate_stale_file_removed(tmp_path):
"""File older than 2× CYCLE_DURATION → removed."""
_write_cr(tmp_path, {"issue": 10}, age_seconds=700)
assert lg.validate_cycle_result() is True
assert not (tmp_path / "cycle_result.json").exists()
def test_validate_fresh_file_closed_issue(tmp_path, monkeypatch):
"""Fresh file referencing closed issue → removed."""
_write_cr(tmp_path, {"issue": 10})
monkeypatch.setattr(lg, "_is_issue_open", lambda n: False)
assert lg.validate_cycle_result() is True
assert not (tmp_path / "cycle_result.json").exists()
def test_validate_api_failure_keeps_file(tmp_path, monkeypatch):
"""API failure → file kept (graceful degradation)."""
_write_cr(tmp_path, {"issue": 10})
monkeypatch.setattr(lg, "_is_issue_open", lambda n: None)
assert lg.validate_cycle_result() is False
assert (tmp_path / "cycle_result.json").exists()
def test_validate_no_issue_field(tmp_path):
"""File without issue field → kept (only age check applies)."""
_write_cr(tmp_path, {"type": "fix"})
assert lg.validate_cycle_result() is False
assert (tmp_path / "cycle_result.json").exists()
def test_validate_stale_threshold_boundary(tmp_path, monkeypatch):
"""File just under threshold → kept (not stale yet)."""
_write_cr(tmp_path, {"issue": 10}, age_seconds=599)
monkeypatch.setattr(lg, "_is_issue_open", lambda n: True)
assert lg.validate_cycle_result() is False
assert (tmp_path / "cycle_result.json").exists()

tests/spark/test_advisor.py Normal file

@@ -0,0 +1,327 @@
"""Comprehensive tests for spark.advisor module.
Covers all advisory-generation helpers:
- _check_failure_patterns (grouped agent failures)
- _check_agent_performance (top / struggling agents)
- _check_bid_patterns (spread + high average)
- _check_prediction_accuracy (low / high accuracy)
- _check_system_activity (idle / tasks-posted-but-no-completions)
- generate_advisories (integration, sorting, min-events guard)
"""
import json
from spark.advisor import (
_MIN_EVENTS,
Advisory,
_check_agent_performance,
_check_bid_patterns,
_check_failure_patterns,
_check_prediction_accuracy,
_check_system_activity,
generate_advisories,
)
from spark.memory import record_event
# ── Advisory dataclass ─────────────────────────────────────────────────────
class TestAdvisoryDataclass:
def test_defaults(self):
a = Advisory(
category="test",
priority=0.5,
title="T",
detail="D",
suggested_action="A",
)
assert a.subject is None
assert a.evidence_count == 0
def test_all_fields(self):
a = Advisory(
category="c",
priority=0.9,
title="T",
detail="D",
suggested_action="A",
subject="agent-1",
evidence_count=7,
)
assert a.subject == "agent-1"
assert a.evidence_count == 7
# ── _check_failure_patterns ────────────────────────────────────────────────
class TestCheckFailurePatterns:
def test_no_failures_returns_empty(self):
assert _check_failure_patterns() == []
def test_single_failure_not_enough(self):
record_event("task_failed", "once", agent_id="a1", task_id="t1")
assert _check_failure_patterns() == []
def test_two_failures_triggers_advisory(self):
for i in range(2):
record_event("task_failed", f"fail {i}", agent_id="agent-abc", task_id=f"t{i}")
results = _check_failure_patterns()
assert len(results) == 1
assert results[0].category == "failure_prevention"
assert results[0].subject == "agent-abc"
assert results[0].evidence_count == 2
def test_priority_scales_with_count(self):
for i in range(5):
record_event("task_failed", f"fail {i}", agent_id="agent-x", task_id=f"f{i}")
results = _check_failure_patterns()
assert len(results) == 1
assert results[0].priority > 0.5
def test_priority_capped_at_one(self):
for i in range(20):
record_event("task_failed", f"fail {i}", agent_id="agent-y", task_id=f"ff{i}")
results = _check_failure_patterns()
assert results[0].priority <= 1.0
def test_multiple_agents_separate_advisories(self):
for i in range(3):
record_event("task_failed", f"a fail {i}", agent_id="agent-a", task_id=f"a{i}")
record_event("task_failed", f"b fail {i}", agent_id="agent-b", task_id=f"b{i}")
results = _check_failure_patterns()
assert len(results) == 2
subjects = {r.subject for r in results}
assert subjects == {"agent-a", "agent-b"}
def test_events_without_agent_id_skipped(self):
for i in range(3):
record_event("task_failed", f"no-agent {i}", task_id=f"na{i}")
assert _check_failure_patterns() == []
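The grouping behavior asserted above reduces to a small pure function (a hypothetical sketch over in-memory event pairs; the real `_check_failure_patterns` reads from the spark.memory store, and the priority formula here is an assumption):

```python
from collections import Counter


def check_failure_patterns(events):
    """Group task_failed events by agent; two or more yields an advisory.

    events: iterable of (event_type, agent_id) pairs, a pure-function
    stand-in for querying the spark.memory event store.
    """
    failures = Counter(
        agent for kind, agent in events if kind == "task_failed" and agent
    )
    advisories = []
    for agent, count in failures.items():
        if count >= 2:
            advisories.append({
                "category": "failure_prevention",
                "subject": agent,
                "evidence_count": count,
                # Assumed formula: priority grows with count, capped at 1.0.
                "priority": min(1.0, 0.4 + 0.1 * count),
            })
    return advisories
```

Events with a missing `agent_id` fall through the filter, which is exactly what `test_events_without_agent_id_skipped` requires.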
# ── _check_agent_performance ───────────────────────────────────────────────
class TestCheckAgentPerformance:
def test_no_events_returns_empty(self):
assert _check_agent_performance() == []
def test_too_few_tasks_skipped(self):
record_event("task_completed", "done", agent_id="agent-1", task_id="t1")
assert _check_agent_performance() == []
def test_high_performer_detected(self):
for i in range(4):
record_event("task_completed", f"done {i}", agent_id="agent-star", task_id=f"s{i}")
results = _check_agent_performance()
perf = [r for r in results if r.category == "agent_performance"]
assert len(perf) == 1
assert "excels" in perf[0].title
assert perf[0].subject == "agent-star"
def test_struggling_agent_detected(self):
# 1 success, 4 failures = 20% rate
record_event("task_completed", "ok", agent_id="agent-bad", task_id="ok1")
for i in range(4):
record_event("task_failed", f"nope {i}", agent_id="agent-bad", task_id=f"bad{i}")
results = _check_agent_performance()
struggling = [r for r in results if "struggling" in r.title]
assert len(struggling) == 1
assert struggling[0].priority > 0.5
def test_middling_agent_no_advisory(self):
# 50% success rate — neither excelling nor struggling
for i in range(3):
record_event("task_completed", f"ok {i}", agent_id="agent-mid", task_id=f"m{i}")
for i in range(3):
record_event("task_failed", f"nope {i}", agent_id="agent-mid", task_id=f"mf{i}")
results = _check_agent_performance()
mid_advisories = [r for r in results if r.subject == "agent-mid"]
assert mid_advisories == []
def test_events_without_agent_id_skipped(self):
for i in range(5):
record_event("task_completed", f"done {i}", task_id=f"no-agent-{i}")
assert _check_agent_performance() == []
# ── _check_bid_patterns ────────────────────────────────────────────────────
class TestCheckBidPatterns:
def _record_bids(self, amounts):
for i, sats in enumerate(amounts):
record_event(
"bid_submitted",
f"bid {i}",
agent_id=f"a{i}",
task_id=f"bt{i}",
data=json.dumps({"bid_sats": sats}),
)
def test_too_few_bids_returns_empty(self):
self._record_bids([10, 20, 30])
assert _check_bid_patterns() == []
def test_wide_spread_detected(self):
# avg=50, spread=90 > 50*1.5=75
self._record_bids([5, 10, 50, 90, 95])
results = _check_bid_patterns()
spread_advisories = [r for r in results if "spread" in r.title.lower()]
assert len(spread_advisories) == 1
def test_high_average_detected(self):
self._record_bids([80, 85, 90, 95, 100])
results = _check_bid_patterns()
high_avg = [r for r in results if "High average" in r.title]
assert len(high_avg) == 1
def test_normal_bids_no_advisory(self):
# Tight spread, low average
self._record_bids([30, 32, 28, 31, 29])
results = _check_bid_patterns()
assert results == []
def test_invalid_json_data_skipped(self):
for i in range(6):
record_event(
"bid_submitted",
f"bid {i}",
agent_id=f"a{i}",
task_id=f"inv{i}",
data="not-json",
)
results = _check_bid_patterns()
assert results == []
def test_zero_bid_sats_skipped(self):
for i in range(6):
record_event(
"bid_submitted",
f"bid {i}",
data=json.dumps({"bid_sats": 0}),
)
assert _check_bid_patterns() == []
def test_both_spread_and_high_avg(self):
# Wide spread AND high average: avg=86, spread=150 > 86*1.5=129
self._record_bids([5, 80, 90, 100, 155])
results = _check_bid_patterns()
assert len(results) == 2
# ── _check_prediction_accuracy ─────────────────────────────────────────────
class TestCheckPredictionAccuracy:
def test_too_few_evaluations(self):
assert _check_prediction_accuracy() == []
def test_low_accuracy_advisory(self):
from spark.eidos import evaluate_prediction, predict_task_outcome
for i in range(4):
predict_task_outcome(f"pa-{i}", "task", ["agent-a"])
evaluate_prediction(f"pa-{i}", "agent-wrong", task_succeeded=False, winning_bid=999)
results = _check_prediction_accuracy()
low = [r for r in results if "Low prediction" in r.title]
assert len(low) == 1
assert low[0].priority > 0.5
def test_high_accuracy_advisory(self):
from spark.eidos import evaluate_prediction, predict_task_outcome
for i in range(4):
predict_task_outcome(f"ph-{i}", "task", ["agent-a"])
evaluate_prediction(f"ph-{i}", "agent-a", task_succeeded=True, winning_bid=30)
results = _check_prediction_accuracy()
high = [r for r in results if "Strong prediction" in r.title]
assert len(high) == 1
def test_middling_accuracy_no_advisory(self):
from spark.eidos import evaluate_prediction, predict_task_outcome
# Mix of correct and incorrect to get ~0.5 accuracy
for i in range(3):
predict_task_outcome(f"pm-{i}", "task", ["agent-a"])
evaluate_prediction(f"pm-{i}", "agent-a", task_succeeded=True, winning_bid=30)
for i in range(3):
predict_task_outcome(f"pmx-{i}", "task", ["agent-a"])
evaluate_prediction(f"pmx-{i}", "agent-wrong", task_succeeded=False, winning_bid=999)
results = _check_prediction_accuracy()
# avg should be middling — neither low nor high advisory
low = [r for r in results if "Low" in r.title]
high = [r for r in results if "Strong" in r.title]
# At least one side should be empty (depends on exact accuracy)
assert not (low and high)
# ── _check_system_activity ─────────────────────────────────────────────────
class TestCheckSystemActivity:
def test_no_events_idle_advisory(self):
results = _check_system_activity()
assert len(results) == 1
assert "No swarm activity" in results[0].title
def test_has_events_no_idle_advisory(self):
record_event("task_completed", "done", task_id="t1")
results = _check_system_activity()
idle = [r for r in results if "No swarm activity" in r.title]
assert idle == []
def test_tasks_posted_but_none_completing(self):
for i in range(5):
record_event("task_posted", f"posted {i}", task_id=f"tp{i}")
results = _check_system_activity()
stalled = [r for r in results if "none completing" in r.title.lower()]
assert len(stalled) == 1
assert stalled[0].evidence_count >= 4
def test_posts_with_completions_no_stalled_advisory(self):
for i in range(5):
record_event("task_posted", f"posted {i}", task_id=f"tpx{i}")
record_event("task_completed", "done", task_id="tpx0")
results = _check_system_activity()
stalled = [r for r in results if "none completing" in r.title.lower()]
assert stalled == []
# ── generate_advisories (integration) ──────────────────────────────────────
class TestGenerateAdvisories:
def test_below_min_events_returns_insufficient(self):
advisories = generate_advisories()
assert len(advisories) >= 1
assert advisories[0].title == "Insufficient data"
assert advisories[0].evidence_count == 0
def test_exactly_at_min_events_proceeds(self):
for i in range(_MIN_EVENTS):
record_event("task_posted", f"ev {i}", task_id=f"min{i}")
advisories = generate_advisories()
insufficient = [a for a in advisories if a.title == "Insufficient data"]
assert insufficient == []
def test_results_sorted_by_priority_descending(self):
for i in range(5):
record_event("task_posted", f"posted {i}", task_id=f"sp{i}")
for i in range(3):
record_event("task_failed", f"fail {i}", agent_id="agent-fail", task_id=f"sf{i}")
advisories = generate_advisories()
if len(advisories) >= 2:
for i in range(len(advisories) - 1):
assert advisories[i].priority >= advisories[i + 1].priority
def test_multiple_categories_produced(self):
# Create failures + posted-no-completions
for i in range(5):
record_event("task_failed", f"fail {i}", agent_id="agent-bad", task_id=f"mf{i}")
for i in range(5):
record_event("task_posted", f"posted {i}", task_id=f"mp{i}")
advisories = generate_advisories()
categories = {a.category for a in advisories}
assert len(categories) >= 2

tests/spark/test_eidos.py Normal file

@@ -0,0 +1,299 @@
"""Comprehensive tests for spark.eidos module.
Covers:
- _get_conn (schema creation, WAL, busy timeout)
- predict_task_outcome (baseline, with history, edge cases)
- evaluate_prediction (correct, wrong, missing, double-eval)
- _compute_accuracy (all components, edge cases)
- get_predictions (filters: task_id, evaluated_only, limit)
- get_accuracy_stats (empty, after evaluations)
"""
import pytest
from spark.eidos import (
Prediction,
_compute_accuracy,
evaluate_prediction,
get_accuracy_stats,
get_predictions,
predict_task_outcome,
)
# ── Prediction dataclass ──────────────────────────────────────────────────
class TestPredictionDataclass:
def test_defaults(self):
p = Prediction(
id="1",
task_id="t1",
prediction_type="outcome",
predicted_value="{}",
actual_value=None,
accuracy=None,
created_at="2026-01-01",
evaluated_at=None,
)
assert p.actual_value is None
assert p.accuracy is None
# ── predict_task_outcome ──────────────────────────────────────────────────
class TestPredictTaskOutcome:
def test_baseline_no_history(self):
result = predict_task_outcome("t-base", "Do stuff", ["a1", "a2"])
assert result["likely_winner"] == "a1"
assert result["success_probability"] == 0.7
assert result["estimated_bid_range"] == [20, 80]
assert "baseline" in result["reasoning"]
assert "prediction_id" in result
def test_empty_candidates(self):
result = predict_task_outcome("t-empty", "Nothing", [])
assert result["likely_winner"] is None
def test_history_selects_best_agent(self):
history = {
"a1": {"success_rate": 0.3, "avg_winning_bid": 40},
"a2": {"success_rate": 0.95, "avg_winning_bid": 50},
}
result = predict_task_outcome("t-hist", "Task", ["a1", "a2"], agent_history=history)
assert result["likely_winner"] == "a2"
assert result["success_probability"] > 0.7
def test_history_agent_not_in_candidates_ignored(self):
history = {
"a-outside": {"success_rate": 0.99, "avg_winning_bid": 10},
}
result = predict_task_outcome("t-out", "Task", ["a1"], agent_history=history)
# a-outside not in candidates, so falls back to baseline
assert result["likely_winner"] == "a1"
def test_history_adjusts_bid_range(self):
history = {
"a1": {"success_rate": 0.5, "avg_winning_bid": 100},
"a2": {"success_rate": 0.8, "avg_winning_bid": 200},
}
result = predict_task_outcome("t-bid", "Task", ["a1", "a2"], agent_history=history)
low, high = result["estimated_bid_range"]
assert low == max(1, int(100 * 0.8))
assert high == int(200 * 1.2)
def test_history_with_zero_avg_bid_skipped(self):
history = {
"a1": {"success_rate": 0.8, "avg_winning_bid": 0},
}
result = predict_task_outcome("t-zero-bid", "Task", ["a1"], agent_history=history)
# Zero avg_winning_bid should be skipped, keep default range
assert result["estimated_bid_range"] == [20, 80]
def test_prediction_stored_in_db(self):
result = predict_task_outcome("t-db", "Store me", ["a1"])
preds = get_predictions(task_id="t-db")
assert len(preds) == 1
assert preds[0].id == result["prediction_id"]
assert preds[0].prediction_type == "outcome"
def test_success_probability_clamped(self):
history = {
"a1": {"success_rate": 1.5, "avg_winning_bid": 50},
}
result = predict_task_outcome("t-clamp", "Task", ["a1"], agent_history=history)
assert result["success_probability"] <= 1.0
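A baseline-plus-history heuristic consistent with these assertions could look like this (a hypothetical sketch; the database persistence behind `prediction_id` is omitted, and the exact probability blending in the real module may differ):

```python
import uuid


def predict_task_outcome(task_id, description, candidates, agent_history=None):
    """Predict likely winner, success probability, and bid range (sketch)."""
    winner = candidates[0] if candidates else None
    probability, bid_range = 0.7, [20, 80]
    reasoning = "baseline heuristic (no agent history)"
    # Only history for agents actually in the candidate pool counts.
    history = {a: h for a, h in (agent_history or {}).items() if a in candidates}
    if history:
        winner = max(history, key=lambda a: history[a]["success_rate"])
        probability = min(1.0, history[winner]["success_rate"])
        bids = [h["avg_winning_bid"] for h in history.values() if h["avg_winning_bid"]]
        if bids:
            # Widen the observed bid window by 20% on each side.
            bid_range = [max(1, int(min(bids) * 0.8)), int(max(bids) * 1.2)]
        reasoning = "history-weighted"
    return {
        "prediction_id": str(uuid.uuid4()),
        "likely_winner": winner,
        "success_probability": probability,
        "estimated_bid_range": bid_range,
        "reasoning": reasoning,
    }
```

Note how zero `avg_winning_bid` entries are filtered out before the range is adjusted, preserving the `[20, 80]` default as `test_history_with_zero_avg_bid_skipped` expects.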
# ── evaluate_prediction ───────────────────────────────────────────────────
class TestEvaluatePrediction:
def test_correct_prediction(self):
predict_task_outcome("t-eval-ok", "Task", ["a1"])
result = evaluate_prediction("t-eval-ok", "a1", task_succeeded=True, winning_bid=30)
assert result is not None
assert 0.0 <= result["accuracy"] <= 1.0
assert result["actual"]["winner"] == "a1"
assert result["actual"]["succeeded"] is True
def test_wrong_prediction(self):
predict_task_outcome("t-eval-wrong", "Task", ["a1"])
result = evaluate_prediction("t-eval-wrong", "a2", task_succeeded=False)
assert result is not None
assert result["accuracy"] < 1.0
def test_no_prediction_returns_none(self):
result = evaluate_prediction("nonexistent", "a1", task_succeeded=True)
assert result is None
def test_double_evaluation_returns_none(self):
predict_task_outcome("t-double", "Task", ["a1"])
evaluate_prediction("t-double", "a1", task_succeeded=True)
result = evaluate_prediction("t-double", "a1", task_succeeded=True)
assert result is None
def test_evaluation_updates_db(self):
predict_task_outcome("t-upd", "Task", ["a1"])
evaluate_prediction("t-upd", "a1", task_succeeded=True, winning_bid=50)
preds = get_predictions(task_id="t-upd", evaluated_only=True)
assert len(preds) == 1
assert preds[0].accuracy is not None
assert preds[0].actual_value is not None
assert preds[0].evaluated_at is not None
def test_winning_bid_none(self):
predict_task_outcome("t-nobid", "Task", ["a1"])
result = evaluate_prediction("t-nobid", "a1", task_succeeded=True)
assert result is not None
assert result["actual"]["winning_bid"] is None
# ── _compute_accuracy ─────────────────────────────────────────────────────
class TestComputeAccuracy:
def test_perfect_match(self):
predicted = {
"likely_winner": "a1",
"success_probability": 1.0,
"estimated_bid_range": [20, 40],
}
actual = {"winner": "a1", "succeeded": True, "winning_bid": 30}
assert _compute_accuracy(predicted, actual) == pytest.approx(1.0, abs=0.01)
def test_all_wrong(self):
predicted = {
"likely_winner": "a1",
"success_probability": 1.0,
"estimated_bid_range": [10, 20],
}
actual = {"winner": "a2", "succeeded": False, "winning_bid": 100}
assert _compute_accuracy(predicted, actual) < 0.3
def test_no_winner_in_predicted(self):
predicted = {"success_probability": 0.5, "estimated_bid_range": [20, 40]}
actual = {"winner": "a1", "succeeded": True, "winning_bid": 30}
acc = _compute_accuracy(predicted, actual)
# Winner component skipped, success + bid counted
assert 0.0 <= acc <= 1.0
def test_no_winner_in_actual(self):
predicted = {"likely_winner": "a1", "success_probability": 0.5}
actual = {"succeeded": True}
acc = _compute_accuracy(predicted, actual)
assert 0.0 <= acc <= 1.0
def test_bid_outside_range_partial_credit(self):
predicted = {
"likely_winner": "a1",
"success_probability": 1.0,
"estimated_bid_range": [20, 40],
}
# Bid just outside range
actual = {"winner": "a1", "succeeded": True, "winning_bid": 45}
acc = _compute_accuracy(predicted, actual)
assert 0.5 < acc < 1.0
def test_bid_far_outside_range(self):
predicted = {
"likely_winner": "a1",
"success_probability": 1.0,
"estimated_bid_range": [20, 40],
}
actual = {"winner": "a1", "succeeded": True, "winning_bid": 500}
acc = _compute_accuracy(predicted, actual)
assert acc < 1.0
def test_no_actual_bid(self):
predicted = {
"likely_winner": "a1",
"success_probability": 0.7,
"estimated_bid_range": [20, 40],
}
actual = {"winner": "a1", "succeeded": True, "winning_bid": None}
acc = _compute_accuracy(predicted, actual)
# Bid component skipped — only winner + success
assert 0.0 <= acc <= 1.0
def test_failed_prediction_low_probability(self):
predicted = {"success_probability": 0.1}
actual = {"succeeded": False}
acc = _compute_accuracy(predicted, actual)
# Predicted low success and task failed → high accuracy
assert acc > 0.8
# ── get_predictions ───────────────────────────────────────────────────────
class TestGetPredictions:
def test_empty_db(self):
assert get_predictions() == []
def test_filter_by_task_id(self):
predict_task_outcome("t-filter1", "A", ["a1"])
predict_task_outcome("t-filter2", "B", ["a2"])
preds = get_predictions(task_id="t-filter1")
assert len(preds) == 1
assert preds[0].task_id == "t-filter1"
def test_evaluated_only(self):
predict_task_outcome("t-eo1", "A", ["a1"])
predict_task_outcome("t-eo2", "B", ["a1"])
evaluate_prediction("t-eo1", "a1", task_succeeded=True)
preds = get_predictions(evaluated_only=True)
assert len(preds) == 1
assert preds[0].task_id == "t-eo1"
def test_limit(self):
for i in range(10):
predict_task_outcome(f"t-lim{i}", "X", ["a1"])
preds = get_predictions(limit=3)
assert len(preds) == 3
def test_combined_filters(self):
predict_task_outcome("t-combo", "A", ["a1"])
evaluate_prediction("t-combo", "a1", task_succeeded=True)
predict_task_outcome("t-combo2", "B", ["a1"])
preds = get_predictions(task_id="t-combo", evaluated_only=True)
assert len(preds) == 1
def test_order_by_created_desc(self):
for i in range(3):
predict_task_outcome(f"t-ord{i}", f"Task {i}", ["a1"])
preds = get_predictions()
# Most recent first
assert preds[0].task_id == "t-ord2"
# ── get_accuracy_stats ────────────────────────────────────────────────────
class TestGetAccuracyStats:
def test_empty(self):
stats = get_accuracy_stats()
assert stats["total_predictions"] == 0
assert stats["evaluated"] == 0
assert stats["pending"] == 0
assert stats["avg_accuracy"] == 0.0
assert stats["min_accuracy"] == 0.0
assert stats["max_accuracy"] == 0.0
def test_with_unevaluated(self):
predict_task_outcome("t-uneval", "X", ["a1"])
stats = get_accuracy_stats()
assert stats["total_predictions"] == 1
assert stats["evaluated"] == 0
assert stats["pending"] == 1
def test_with_evaluations(self):
for i in range(3):
predict_task_outcome(f"t-stats{i}", "X", ["a1"])
evaluate_prediction(f"t-stats{i}", "a1", task_succeeded=True, winning_bid=30)
stats = get_accuracy_stats()
assert stats["total_predictions"] == 3
assert stats["evaluated"] == 3
assert stats["pending"] == 0
assert stats["avg_accuracy"] > 0.0
assert stats["min_accuracy"] <= stats["avg_accuracy"] <= stats["max_accuracy"]

tests/spark/test_memory.py Normal file

@@ -0,0 +1,389 @@
"""Comprehensive tests for spark.memory module.
Covers:
- SparkEvent / SparkMemory dataclasses
- _get_conn (schema creation, WAL, busy timeout, idempotent indexes)
- score_importance (all event types, boosts, edge cases)
- record_event (auto-importance, explicit importance, invalid JSON, swarm bridge)
- get_events (all filters, ordering, limit)
- count_events (total, by type)
- store_memory (with/without expiry)
- get_memories (all filters)
- count_memories (total, by type)
"""
import json
import pytest
from spark.memory import (
IMPORTANCE_HIGH,
IMPORTANCE_LOW,
IMPORTANCE_MEDIUM,
SparkEvent,
SparkMemory,
_get_conn,
count_events,
count_memories,
get_events,
get_memories,
record_event,
score_importance,
store_memory,
)
# ── Constants ─────────────────────────────────────────────────────────────
class TestConstants:
def test_importance_ordering(self):
assert IMPORTANCE_LOW < IMPORTANCE_MEDIUM < IMPORTANCE_HIGH
# ── Dataclasses ───────────────────────────────────────────────────────────
class TestSparkEventDataclass:
def test_all_fields(self):
ev = SparkEvent(
id="1",
event_type="task_posted",
agent_id="a1",
task_id="t1",
description="Test",
data="{}",
importance=0.5,
created_at="2026-01-01",
)
assert ev.event_type == "task_posted"
assert ev.agent_id == "a1"
def test_nullable_fields(self):
ev = SparkEvent(
id="2",
event_type="task_posted",
agent_id=None,
task_id=None,
description="",
data="{}",
importance=0.5,
created_at="2026-01-01",
)
assert ev.agent_id is None
assert ev.task_id is None
class TestSparkMemoryDataclass:
def test_all_fields(self):
mem = SparkMemory(
id="1",
memory_type="pattern",
subject="system",
content="Test insight",
confidence=0.8,
source_events=5,
created_at="2026-01-01",
expires_at="2026-12-31",
)
assert mem.memory_type == "pattern"
assert mem.expires_at == "2026-12-31"
def test_nullable_expires(self):
mem = SparkMemory(
id="2",
memory_type="anomaly",
subject="agent-1",
content="Odd behavior",
confidence=0.6,
source_events=3,
created_at="2026-01-01",
expires_at=None,
)
assert mem.expires_at is None
# ── _get_conn ─────────────────────────────────────────────────────────────
class TestGetConn:
def test_creates_tables(self):
with _get_conn() as conn:
tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
names = {r["name"] for r in tables}
assert "spark_events" in names
assert "spark_memories" in names
def test_wal_mode(self):
with _get_conn() as conn:
mode = conn.execute("PRAGMA journal_mode").fetchone()[0]
assert mode == "wal"
def test_busy_timeout(self):
with _get_conn() as conn:
timeout = conn.execute("PRAGMA busy_timeout").fetchone()[0]
assert timeout == 5000
def test_idempotent(self):
# Calling _get_conn twice should not raise
with _get_conn():
pass
with _get_conn():
pass
# ── score_importance ──────────────────────────────────────────────────────
class TestScoreImportance:
@pytest.mark.parametrize(
"event_type,expected_min,expected_max",
[
("task_posted", 0.3, 0.5),
("bid_submitted", 0.1, 0.3),
("task_assigned", 0.4, 0.6),
("task_completed", 0.5, 0.7),
("task_failed", 0.9, 1.0),
("agent_joined", 0.4, 0.6),
("prediction_result", 0.6, 0.8),
],
)
def test_base_scores(self, event_type, expected_min, expected_max):
score = score_importance(event_type, {})
assert expected_min <= score <= expected_max
def test_unknown_event_default(self):
assert score_importance("never_heard_of_this", {}) == 0.5
def test_failure_boost(self):
score = score_importance("task_failed", {})
assert score == 1.0
def test_high_bid_boost(self):
low = score_importance("bid_submitted", {"bid_sats": 10})
high = score_importance("bid_submitted", {"bid_sats": 100})
assert high > low
assert high <= 1.0
def test_high_bid_on_failure(self):
score = score_importance("task_failed", {"bid_sats": 100})
assert score == 1.0 # capped at 1.0
def test_score_always_rounded(self):
score = score_importance("bid_submitted", {"bid_sats": 100})
assert score == round(score, 2)
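The base table and boost behavior these parametrized cases constrain could look like this (hypothetical values chosen to satisfy the asserted ranges; the boost increment of `bid_sats / 500`, capped at +0.2, is an assumption):

```python
_BASE_IMPORTANCE = {
    "task_posted": 0.4,
    "bid_submitted": 0.2,
    "task_assigned": 0.5,
    "task_completed": 0.6,
    "task_failed": 1.0,
    "agent_joined": 0.5,
    "prediction_result": 0.7,
}


def score_importance(event_type: str, data: dict) -> float:
    """Base score per event type, boosted by bid size, capped at 1.0."""
    score = _BASE_IMPORTANCE.get(event_type, 0.5)
    bid_sats = data.get("bid_sats", 0) or 0
    if bid_sats:
        # Assumed boost: larger bids matter more, up to +0.2.
        score += min(bid_sats / 500, 0.2)
    return round(min(score, 1.0), 2)
```

The `min(score, 1.0)` cap is what keeps a high bid on a `task_failed` event pinned at exactly 1.0, as `test_high_bid_on_failure` asserts.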
# ── record_event ──────────────────────────────────────────────────────────
class TestRecordEvent:
def test_basic_record(self):
eid = record_event("task_posted", "New task", task_id="t1")
assert isinstance(eid, str)
assert len(eid) > 0
def test_auto_importance(self):
record_event("task_failed", "Failed", task_id="t-auto")
events = get_events(task_id="t-auto")
assert events[0].importance >= 0.9
def test_explicit_importance(self):
record_event("task_posted", "Custom", task_id="t-expl", importance=0.1)
events = get_events(task_id="t-expl")
assert events[0].importance == 0.1
def test_with_agent_and_data(self):
data = json.dumps({"bid_sats": 42})
record_event("bid_submitted", "Bid", agent_id="a1", task_id="t-data", data=data)
events = get_events(task_id="t-data")
assert events[0].agent_id == "a1"
parsed = json.loads(events[0].data)
assert parsed["bid_sats"] == 42
def test_invalid_json_data_uses_default_importance(self):
record_event("task_posted", "Bad data", task_id="t-bad", data="not-json")
events = get_events(task_id="t-bad")
assert events[0].importance == 0.4 # base for task_posted
def test_returns_unique_ids(self):
id1 = record_event("task_posted", "A")
id2 = record_event("task_posted", "B")
assert id1 != id2
# ── get_events ────────────────────────────────────────────────────────────
class TestGetEvents:
def test_empty_db(self):
assert get_events() == []
def test_filter_by_type(self):
record_event("task_posted", "A")
record_event("task_completed", "B")
events = get_events(event_type="task_posted")
assert len(events) == 1
assert events[0].event_type == "task_posted"
def test_filter_by_agent(self):
record_event("task_posted", "A", agent_id="a1")
record_event("task_posted", "B", agent_id="a2")
events = get_events(agent_id="a1")
assert len(events) == 1
assert events[0].agent_id == "a1"
def test_filter_by_task(self):
record_event("task_posted", "A", task_id="t1")
record_event("task_posted", "B", task_id="t2")
events = get_events(task_id="t1")
assert len(events) == 1
def test_filter_by_min_importance(self):
record_event("task_posted", "Low", importance=0.1)
record_event("task_failed", "High", importance=0.9)
events = get_events(min_importance=0.5)
assert len(events) == 1
assert events[0].importance >= 0.5
def test_limit(self):
for i in range(10):
record_event("task_posted", f"ev{i}")
events = get_events(limit=3)
assert len(events) == 3
def test_order_by_created_desc(self):
record_event("task_posted", "first", task_id="ord1")
record_event("task_posted", "second", task_id="ord2")
events = get_events()
# Most recent first
assert events[0].task_id == "ord2"
def test_combined_filters(self):
record_event("task_failed", "A", agent_id="a1", task_id="t1", importance=0.9)
record_event("task_posted", "B", agent_id="a1", task_id="t2", importance=0.4)
record_event("task_failed", "C", agent_id="a2", task_id="t3", importance=0.9)
events = get_events(event_type="task_failed", agent_id="a1", min_importance=0.5)
assert len(events) == 1
assert events[0].task_id == "t1"
# ── count_events ──────────────────────────────────────────────────────────
class TestCountEvents:
def test_empty(self):
assert count_events() == 0
def test_total(self):
record_event("task_posted", "A")
record_event("task_failed", "B")
assert count_events() == 2
def test_by_type(self):
record_event("task_posted", "A")
record_event("task_posted", "B")
record_event("task_failed", "C")
assert count_events("task_posted") == 2
assert count_events("task_failed") == 1
assert count_events("task_completed") == 0
# ── store_memory ──────────────────────────────────────────────────────────
class TestStoreMemory:
def test_basic_store(self):
mid = store_memory("pattern", "system", "Test insight")
assert isinstance(mid, str)
assert len(mid) > 0
def test_returns_unique_ids(self):
id1 = store_memory("pattern", "a", "X")
id2 = store_memory("pattern", "b", "Y")
assert id1 != id2
def test_with_all_params(self):
store_memory(
"anomaly",
"agent-1",
"Odd pattern",
confidence=0.9,
source_events=10,
expires_at="2026-12-31",
)
mems = get_memories(subject="agent-1")
assert len(mems) == 1
assert mems[0].confidence == 0.9
assert mems[0].source_events == 10
assert mems[0].expires_at == "2026-12-31"
def test_default_values(self):
store_memory("insight", "sys", "Default test")
mems = get_memories(subject="sys")
assert mems[0].confidence == 0.5
assert mems[0].source_events == 0
assert mems[0].expires_at is None
# ── get_memories ──────────────────────────────────────────────────────────
class TestGetMemories:
def test_empty(self):
assert get_memories() == []
def test_filter_by_type(self):
store_memory("pattern", "a", "X")
store_memory("anomaly", "a", "Y")
mems = get_memories(memory_type="pattern")
assert len(mems) == 1
assert mems[0].memory_type == "pattern"
def test_filter_by_subject(self):
store_memory("pattern", "a", "X")
store_memory("pattern", "b", "Y")
mems = get_memories(subject="a")
assert len(mems) == 1
def test_filter_by_min_confidence(self):
store_memory("pattern", "a", "Low", confidence=0.2)
store_memory("pattern", "b", "High", confidence=0.9)
mems = get_memories(min_confidence=0.5)
assert len(mems) == 1
assert mems[0].content == "High"
def test_limit(self):
for i in range(10):
store_memory("pattern", "a", f"M{i}")
mems = get_memories(limit=3)
assert len(mems) == 3
def test_combined_filters(self):
store_memory("pattern", "a", "Target", confidence=0.9)
store_memory("anomaly", "a", "Wrong type", confidence=0.9)
store_memory("pattern", "b", "Wrong subject", confidence=0.9)
store_memory("pattern", "a", "Low conf", confidence=0.1)
mems = get_memories(memory_type="pattern", subject="a", min_confidence=0.5)
assert len(mems) == 1
assert mems[0].content == "Target"
# ── count_memories ────────────────────────────────────────────────────────
class TestCountMemories:
def test_empty(self):
assert count_memories() == 0
def test_total(self):
store_memory("pattern", "a", "X")
store_memory("anomaly", "b", "Y")
assert count_memories() == 2
def test_by_type(self):
store_memory("pattern", "a", "X")
store_memory("pattern", "b", "Y")
store_memory("anomaly", "c", "Z")
assert count_memories("pattern") == 2
assert count_memories("anomaly") == 1
assert count_memories("insight") == 0

tests/test_config_module.py Normal file
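A `normalize_ollama_url` that satisfies all four tests above is essentially a global string substitution (sketch; the real function may additionally handle scheme or port edge cases):

```python
def normalize_ollama_url(url: str) -> str:
    """Replace every occurrence of 'localhost' with 127.0.0.1.

    Ollama binds to the IPv4 loopback by default; on some systems
    'localhost' resolves to ::1 first, so forcing 127.0.0.1 avoids
    spurious connection failures.
    """
    return url.replace("localhost", "127.0.0.1")
```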

@@ -0,0 +1,470 @@
"""Tests for src/config.py — Settings, validation, and helper functions."""
import os
from unittest.mock import patch
import pytest
class TestNormalizeOllamaUrl:
"""normalize_ollama_url replaces localhost with 127.0.0.1."""
def test_replaces_localhost(self):
from config import normalize_ollama_url
assert normalize_ollama_url("http://localhost:11434") == "http://127.0.0.1:11434"
def test_preserves_ip(self):
from config import normalize_ollama_url
assert normalize_ollama_url("http://192.168.1.5:11434") == "http://192.168.1.5:11434"
def test_preserves_non_localhost_hostname(self):
from config import normalize_ollama_url
assert normalize_ollama_url("http://ollama.local:11434") == "http://ollama.local:11434"
def test_replaces_multiple_occurrences(self):
from config import normalize_ollama_url
result = normalize_ollama_url("http://localhost:11434/localhost")
assert result == "http://127.0.0.1:11434/127.0.0.1"
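A `normalize_ollama_url` that satisfies all four tests above is essentially a global string substitution (sketch; the real function may additionally handle scheme or port edge cases):

```python
def normalize_ollama_url(url: str) -> str:
    """Replace every occurrence of 'localhost' with 127.0.0.1.

    Ollama binds to the IPv4 loopback by default; on some systems
    'localhost' resolves to ::1 first, so forcing 127.0.0.1 avoids
    spurious connection failures.
    """
    return url.replace("localhost", "127.0.0.1")
```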
class TestSettingsDefaults:
"""Settings instantiation produces correct defaults."""
def _make_settings(self, **env_overrides):
"""Create a fresh Settings instance with given env overrides."""
from config import Settings
clean_env = {
k: v
for k, v in os.environ.items()
if not k.startswith(("OLLAMA_", "TIMMY_", "AGENT_", "DEBUG"))
}
clean_env.update(env_overrides)
with patch.dict(os.environ, clean_env, clear=True):
return Settings()
def test_default_agent_name(self):
s = self._make_settings()
assert s.agent_name == "Agent"
def test_default_ollama_url(self):
s = self._make_settings()
assert s.ollama_url == "http://localhost:11434"
def test_default_ollama_model(self):
s = self._make_settings()
assert s.ollama_model == "qwen3:30b"
def test_default_ollama_num_ctx(self):
s = self._make_settings()
assert s.ollama_num_ctx == 4096
def test_default_debug_false(self):
s = self._make_settings()
assert s.debug is False
def test_default_timmy_env(self):
s = self._make_settings()
assert s.timmy_env == "development"
def test_default_timmy_test_mode(self):
s = self._make_settings()
assert s.timmy_test_mode is False
def test_default_spark_enabled(self):
s = self._make_settings()
assert s.spark_enabled is True
def test_default_lightning_backend(self):
s = self._make_settings()
assert s.lightning_backend == "mock"
def test_default_max_agent_steps(self):
s = self._make_settings()
assert s.max_agent_steps == 10
def test_default_memory_prune_days(self):
s = self._make_settings()
assert s.memory_prune_days == 90
def test_default_fallback_models_is_list(self):
s = self._make_settings()
assert isinstance(s.fallback_models, list)
assert len(s.fallback_models) > 0
def test_default_cors_origins_is_list(self):
s = self._make_settings()
assert isinstance(s.cors_origins, list)
def test_default_trusted_hosts_is_list(self):
s = self._make_settings()
assert isinstance(s.trusted_hosts, list)
assert "localhost" in s.trusted_hosts
def test_normalized_ollama_url_property(self):
s = self._make_settings()
assert "127.0.0.1" in s.normalized_ollama_url
assert "localhost" not in s.normalized_ollama_url
class TestSettingsEnvOverrides:
"""Environment variables override default values."""
def _make_settings(self, **env_overrides):
from config import Settings
clean_env = {
k: v
for k, v in os.environ.items()
if not k.startswith(("OLLAMA_", "TIMMY_", "AGENT_", "DEBUG"))
}
clean_env.update(env_overrides)
with patch.dict(os.environ, clean_env, clear=True):
return Settings()
def test_agent_name_override(self):
s = self._make_settings(AGENT_NAME="Timmy")
assert s.agent_name == "Timmy"
def test_ollama_url_override(self):
s = self._make_settings(OLLAMA_URL="http://10.0.0.1:11434")
assert s.ollama_url == "http://10.0.0.1:11434"
def test_ollama_model_override(self):
s = self._make_settings(OLLAMA_MODEL="llama3.1")
assert s.ollama_model == "llama3.1"
def test_debug_true_from_string(self):
s = self._make_settings(DEBUG="true")
assert s.debug is True
def test_debug_false_from_string(self):
s = self._make_settings(DEBUG="false")
assert s.debug is False
def test_numeric_override(self):
s = self._make_settings(OLLAMA_NUM_CTX="8192")
assert s.ollama_num_ctx == 8192
def test_max_agent_steps_override(self):
s = self._make_settings(MAX_AGENT_STEPS="25")
assert s.max_agent_steps == 25
def test_timmy_env_production(self):
s = self._make_settings(TIMMY_ENV="production")
assert s.timmy_env == "production"
def test_timmy_test_mode_true(self):
s = self._make_settings(TIMMY_TEST_MODE="true")
assert s.timmy_test_mode is True
def test_grok_enabled_override(self):
s = self._make_settings(GROK_ENABLED="true")
assert s.grok_enabled is True
def test_spark_enabled_override(self):
s = self._make_settings(SPARK_ENABLED="false")
assert s.spark_enabled is False
def test_memory_prune_days_override(self):
s = self._make_settings(MEMORY_PRUNE_DAYS="30")
assert s.memory_prune_days == 30
class TestSettingsTypeValidation:
"""Pydantic correctly parses and validates types from string env vars."""
def _make_settings(self, **env_overrides):
from config import Settings
clean_env = {
k: v
for k, v in os.environ.items()
if not k.startswith(("OLLAMA_", "TIMMY_", "AGENT_", "DEBUG"))
}
clean_env.update(env_overrides)
with patch.dict(os.environ, clean_env, clear=True):
return Settings()
def test_bool_from_1(self):
s = self._make_settings(DEBUG="1")
assert s.debug is True
def test_bool_from_0(self):
s = self._make_settings(DEBUG="0")
assert s.debug is False
def test_int_field_rejects_non_numeric(self):
from pydantic import ValidationError
with pytest.raises(ValidationError):
self._make_settings(OLLAMA_NUM_CTX="not_a_number")
def test_literal_field_rejects_invalid(self):
from pydantic import ValidationError
with pytest.raises(ValidationError):
self._make_settings(TIMMY_ENV="staging")
def test_literal_backend_rejects_invalid(self):
from pydantic import ValidationError
with pytest.raises(ValidationError):
self._make_settings(TIMMY_MODEL_BACKEND="openai")
def test_literal_backend_accepts_valid(self):
for backend in ("ollama", "grok", "claude", "auto"):
s = self._make_settings(TIMMY_MODEL_BACKEND=backend)
assert s.timmy_model_backend == backend
def test_extra_fields_ignored(self):
# model_config has extra="ignore"
s = self._make_settings(TOTALLY_UNKNOWN_FIELD="hello")
assert not hasattr(s, "totally_unknown_field")
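The type-validation behaviour above is what pydantic provides out of the box. A sketch of the implied field declarations, using a plain `BaseModel` for illustration (the real `Settings` is presumably a pydantic-settings `BaseSettings` that reads these values from environment variables; field names and defaults are inferred from the tests):

```python
from typing import Literal

from pydantic import BaseModel, ConfigDict


class Settings(BaseModel):
    """Illustrative subset of the fields the tests exercise."""

    model_config = ConfigDict(extra="ignore")  # unknown fields are dropped

    debug: bool = False                # "1"/"true" coerce to True, "0"/"false" to False
    ollama_num_ctx: int = 4096         # non-numeric strings raise ValidationError
    timmy_env: Literal["development", "production"] = "development"
    timmy_model_backend: Literal["ollama", "grok", "claude", "auto"] = "auto"
```

String-to-type coercion happens in pydantic's lax mode, so `"8192"` becomes the int `8192`, while a `Literal` field rejects any value outside its allowed set.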
class TestSettingsEdgeCases:
"""Edge cases: empty strings, missing vars, boundary values."""
def _make_settings(self, **env_overrides):
from config import Settings
clean_env = {
k: v
for k, v in os.environ.items()
if not k.startswith(("OLLAMA_", "TIMMY_", "AGENT_", "DEBUG"))
}
clean_env.update(env_overrides)
with patch.dict(os.environ, clean_env, clear=True):
return Settings()
def test_empty_string_tokens_stay_empty(self):
s = self._make_settings(TELEGRAM_TOKEN="", DISCORD_TOKEN="")
assert s.telegram_token == ""
assert s.discord_token == ""
def test_zero_int_fields(self):
s = self._make_settings(OLLAMA_NUM_CTX="0", MEMORY_PRUNE_DAYS="0")
assert s.ollama_num_ctx == 0
assert s.memory_prune_days == 0
def test_large_int_value(self):
s = self._make_settings(CHAT_API_MAX_BODY_BYTES="104857600")
assert s.chat_api_max_body_bytes == 104857600
def test_negative_int_accepted(self):
# Pydantic doesn't constrain these to positive
s = self._make_settings(MAX_AGENT_STEPS="-1")
assert s.max_agent_steps == -1
class TestComputeRepoRoot:
"""_compute_repo_root auto-detects .git directory."""
def test_returns_string(self):
from config import Settings
s = Settings()
result = s._compute_repo_root()
assert isinstance(result, str)
assert len(result) > 0
def test_explicit_repo_root_used(self):
from config import Settings
with patch.dict(os.environ, {"REPO_ROOT": "/tmp/myrepo"}, clear=False):
s = Settings()
s.repo_root = "/tmp/myrepo"
assert s._compute_repo_root() == "/tmp/myrepo"
class TestModelPostInit:
"""model_post_init resolves gitea_token from file fallback."""
def test_gitea_token_from_env(self):
from config import Settings
with patch.dict(os.environ, {"GITEA_TOKEN": "test-token-123"}, clear=False):
s = Settings()
assert s.gitea_token == "test-token-123"
def test_gitea_token_stays_empty_when_no_file(self):
from config import Settings
env = {k: v for k, v in os.environ.items() if k != "GITEA_TOKEN"}
with patch.dict(os.environ, env, clear=True):
with patch("os.path.isfile", return_value=False):
s = Settings()
assert s.gitea_token == ""
class TestCheckOllamaModelAvailable:
"""check_ollama_model_available handles network responses and errors."""
def test_returns_false_on_network_error(self):
from config import check_ollama_model_available
with patch("urllib.request.urlopen", side_effect=OSError("Connection refused")):
assert check_ollama_model_available("llama3.1") is False
def test_returns_true_when_model_found(self):
import json
from unittest.mock import MagicMock
from config import check_ollama_model_available
response_data = json.dumps({"models": [{"name": "llama3.1:8b-instruct"}]}).encode()
mock_response = MagicMock()
mock_response.read.return_value = response_data
mock_response.__enter__ = lambda s: s
mock_response.__exit__ = MagicMock(return_value=False)
with patch("urllib.request.urlopen", return_value=mock_response):
assert check_ollama_model_available("llama3.1") is True
def test_returns_false_when_model_not_found(self):
import json
from unittest.mock import MagicMock
from config import check_ollama_model_available
response_data = json.dumps({"models": [{"name": "qwen2.5:7b"}]}).encode()
mock_response = MagicMock()
mock_response.read.return_value = response_data
mock_response.__enter__ = lambda s: s
mock_response.__exit__ = MagicMock(return_value=False)
with patch("urllib.request.urlopen", return_value=mock_response):
assert check_ollama_model_available("llama3.1") is False
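The mocked payloads above follow the shape of Ollama's `/api/tags` response. A sketch of the checker, assuming a prefix match against installed tag names and a swallow-all-network-errors policy (both inferred from the tests):

```python
import json
import urllib.request


def check_ollama_model_available(model: str,
                                 base_url: str = "http://127.0.0.1:11434") -> bool:
    """Return True if any installed Ollama tag starts with `model`.

    Network and HTTP failures are treated as "not available" rather
    than propagated, so startup can fall back gracefully.
    """
    try:
        with urllib.request.urlopen(f"{base_url}/api/tags", timeout=2) as resp:
            payload = json.loads(resp.read())
    except OSError:  # URLError subclasses OSError
        return False
    return any(m.get("name", "").startswith(model)
               for m in payload.get("models", []))
```

The prefix match is what lets the test treat `"llama3.1:8b-instruct"` as a hit for the query `"llama3.1"`.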
class TestGetEffectiveOllamaModel:
"""get_effective_ollama_model walks fallback chain."""
def test_returns_primary_when_available(self):
from config import get_effective_ollama_model
with patch("config.check_ollama_model_available", return_value=True):
result = get_effective_ollama_model()
assert result == "qwen3:30b"
def test_falls_back_when_primary_unavailable(self):
from config import get_effective_ollama_model
def side_effect(model):
return model == "llama3.1:8b-instruct"
with patch("config.check_ollama_model_available", side_effect=side_effect):
result = get_effective_ollama_model()
assert result == "llama3.1:8b-instruct"
def test_returns_user_model_when_nothing_available(self):
from config import get_effective_ollama_model
with patch("config.check_ollama_model_available", return_value=False):
result = get_effective_ollama_model()
assert result == "qwen3:30b"
class TestValidateStartup:
"""validate_startup enforces security in production, warns in dev."""
def setup_method(self):
import config
config._startup_validated = False
def test_skips_in_test_mode(self):
import config
with patch.dict(os.environ, {"TIMMY_TEST_MODE": "1"}):
config.validate_startup()
assert config._startup_validated is True
def test_dev_mode_warns_but_does_not_exit(self, caplog):
import logging
import config
config._startup_validated = False
env = {k: v for k, v in os.environ.items() if k != "TIMMY_TEST_MODE"}
env["TIMMY_ENV"] = "development"
with patch.dict(os.environ, env, clear=True):
with caplog.at_level(logging.WARNING, logger="config"):
config.validate_startup()
assert config._startup_validated is True
def test_production_exits_without_secrets(self):
import config
config._startup_validated = False
env = {k: v for k, v in os.environ.items() if k != "TIMMY_TEST_MODE"}
env["TIMMY_ENV"] = "production"
env.pop("L402_HMAC_SECRET", None)
env.pop("L402_MACAROON_SECRET", None)
with patch.dict(os.environ, env, clear=True):
with patch.object(config.settings, "timmy_env", "production"):
with patch.object(config.settings, "l402_hmac_secret", ""):
with patch.object(config.settings, "l402_macaroon_secret", ""):
with pytest.raises(SystemExit):
config.validate_startup(force=True)
def test_production_exits_with_cors_wildcard(self):
import config
config._startup_validated = False
env = {k: v for k, v in os.environ.items() if k != "TIMMY_TEST_MODE"}
env["TIMMY_ENV"] = "production"
with patch.dict(os.environ, env, clear=True):
with patch.object(config.settings, "timmy_env", "production"):
with patch.object(config.settings, "l402_hmac_secret", "secret1"):
with patch.object(config.settings, "l402_macaroon_secret", "secret2"):
with patch.object(config.settings, "cors_origins", ["*"]):
with pytest.raises(SystemExit):
config.validate_startup(force=True)
def test_production_passes_with_all_secrets(self):
import config
config._startup_validated = False
env = {k: v for k, v in os.environ.items() if k != "TIMMY_TEST_MODE"}
env["TIMMY_ENV"] = "production"
with patch.dict(os.environ, env, clear=True):
with patch.object(config.settings, "timmy_env", "production"):
with patch.object(config.settings, "l402_hmac_secret", "secret1"):
with patch.object(config.settings, "l402_macaroon_secret", "secret2"):
with patch.object(
config.settings,
"cors_origins",
["http://localhost:3000"],
):
config.validate_startup(force=True)
assert config._startup_validated is True
def test_idempotent_without_force(self):
import config
config._startup_validated = True
# Should return immediately without doing anything
config.validate_startup()
assert config._startup_validated is True
class TestAppStartTime:
"""APP_START_TIME is set at module load."""
def test_app_start_time_is_datetime(self):
from datetime import datetime
from config import APP_START_TIME
assert isinstance(APP_START_TIME, datetime)
def test_app_start_time_has_timezone(self):
from config import APP_START_TIME
assert APP_START_TIME.tzinfo is not None


@@ -0,0 +1,331 @@
"""Tests for the matrix configuration loader utility."""
from pathlib import Path
import pytest
import yaml
from infrastructure.matrix_config import (
AgentConfig,
AgentsConfig,
EnvironmentConfig,
FeaturesConfig,
LightingConfig,
MatrixConfig,
PointLight,
load_from_yaml,
)
class TestPointLight:
"""Tests for PointLight dataclass."""
def test_default_values(self):
"""PointLight has correct defaults."""
pl = PointLight()
assert pl.color == "#FFFFFF"
assert pl.intensity == 1.0
assert pl.position == {"x": 0, "y": 0, "z": 0}
def test_from_dict_full(self):
"""PointLight.from_dict loads all fields."""
data = {
"color": "#FF0000",
"intensity": 2.5,
"position": {"x": 1, "y": 2, "z": 3},
}
pl = PointLight.from_dict(data)
assert pl.color == "#FF0000"
assert pl.intensity == 2.5
assert pl.position == {"x": 1, "y": 2, "z": 3}
def test_from_dict_partial(self):
"""PointLight.from_dict fills missing fields with defaults."""
data = {"color": "#00FF00"}
pl = PointLight.from_dict(data)
assert pl.color == "#00FF00"
assert pl.intensity == 1.0
assert pl.position == {"x": 0, "y": 0, "z": 0}
class TestLightingConfig:
"""Tests for LightingConfig dataclass."""
def test_default_values(self):
"""LightingConfig has correct Workshop+Matrix blend defaults."""
cfg = LightingConfig()
assert cfg.ambient_color == "#FFAA55" # Warm amber (Workshop)
assert cfg.ambient_intensity == 0.5
assert len(cfg.point_lights) == 3
# First light is warm amber center
assert cfg.point_lights[0].color == "#FFAA55"
# Second light is cool blue (Matrix)
assert cfg.point_lights[1].color == "#3B82F6"
def test_from_dict_full(self):
"""LightingConfig.from_dict loads all fields."""
data = {
"ambient_color": "#123456",
"ambient_intensity": 0.8,
"point_lights": [
{"color": "#ABCDEF", "intensity": 1.5, "position": {"x": 1, "y": 1, "z": 1}}
],
}
cfg = LightingConfig.from_dict(data)
assert cfg.ambient_color == "#123456"
assert cfg.ambient_intensity == 0.8
assert len(cfg.point_lights) == 1
assert cfg.point_lights[0].color == "#ABCDEF"
def test_from_dict_empty_list_uses_defaults(self):
"""Empty point_lights list triggers default lights."""
data = {"ambient_color": "#000000", "point_lights": []}
cfg = LightingConfig.from_dict(data)
assert cfg.ambient_color == "#000000"
assert len(cfg.point_lights) == 3 # Default lights
def test_from_dict_none(self):
"""LightingConfig.from_dict handles None."""
cfg = LightingConfig.from_dict(None)
assert cfg.ambient_color == "#FFAA55"
assert len(cfg.point_lights) == 3
class TestEnvironmentConfig:
"""Tests for EnvironmentConfig dataclass."""
def test_default_values(self):
"""EnvironmentConfig has correct defaults."""
cfg = EnvironmentConfig()
assert cfg.rain_enabled is False
assert cfg.starfield_enabled is True # Matrix starfield
assert cfg.fog_color == "#0f0f23"
assert cfg.fog_density == 0.02
def test_from_dict_full(self):
"""EnvironmentConfig.from_dict loads all fields."""
data = {
"rain_enabled": True,
"starfield_enabled": False,
"fog_color": "#FFFFFF",
"fog_density": 0.1,
}
cfg = EnvironmentConfig.from_dict(data)
assert cfg.rain_enabled is True
assert cfg.starfield_enabled is False
assert cfg.fog_color == "#FFFFFF"
assert cfg.fog_density == 0.1
def test_from_dict_partial(self):
"""EnvironmentConfig.from_dict fills missing fields."""
data = {"rain_enabled": True}
cfg = EnvironmentConfig.from_dict(data)
assert cfg.rain_enabled is True
assert cfg.starfield_enabled is True # Default
assert cfg.fog_color == "#0f0f23"
class TestFeaturesConfig:
"""Tests for FeaturesConfig dataclass."""
def test_default_values_all_enabled(self):
"""FeaturesConfig defaults to all features enabled."""
cfg = FeaturesConfig()
assert cfg.chat_enabled is True
assert cfg.visitor_avatars is True
assert cfg.pip_familiar is True
assert cfg.workshop_portal is True
def test_from_dict_full(self):
"""FeaturesConfig.from_dict loads all fields."""
data = {
"chat_enabled": False,
"visitor_avatars": False,
"pip_familiar": False,
"workshop_portal": False,
}
cfg = FeaturesConfig.from_dict(data)
assert cfg.chat_enabled is False
assert cfg.visitor_avatars is False
assert cfg.pip_familiar is False
assert cfg.workshop_portal is False
def test_from_dict_partial(self):
"""FeaturesConfig.from_dict fills missing fields."""
data = {"chat_enabled": False}
cfg = FeaturesConfig.from_dict(data)
assert cfg.chat_enabled is False
assert cfg.visitor_avatars is True # Default
assert cfg.pip_familiar is True
assert cfg.workshop_portal is True
class TestAgentConfig:
"""Tests for AgentConfig dataclass."""
def test_default_values(self):
"""AgentConfig has correct defaults."""
cfg = AgentConfig()
assert cfg.name == ""
assert cfg.role == ""
assert cfg.enabled is True
def test_from_dict_full(self):
"""AgentConfig.from_dict loads all fields."""
data = {"name": "Timmy", "role": "guide", "enabled": False}
cfg = AgentConfig.from_dict(data)
assert cfg.name == "Timmy"
assert cfg.role == "guide"
assert cfg.enabled is False
class TestAgentsConfig:
"""Tests for AgentsConfig dataclass."""
def test_default_values(self):
"""AgentsConfig has correct defaults."""
cfg = AgentsConfig()
assert cfg.default_count == 5
assert cfg.max_count == 20
assert cfg.agents == []
def test_from_dict_with_agents(self):
"""AgentsConfig.from_dict loads agent list."""
data = {
"default_count": 10,
"max_count": 50,
"agents": [
{"name": "Timmy", "role": "guide", "enabled": True},
{"name": "Helper", "role": "assistant"},
],
}
cfg = AgentsConfig.from_dict(data)
assert cfg.default_count == 10
assert cfg.max_count == 50
assert len(cfg.agents) == 2
assert cfg.agents[0].name == "Timmy"
assert cfg.agents[1].enabled is True # Default
class TestMatrixConfig:
"""Tests for MatrixConfig dataclass."""
def test_default_values(self):
"""MatrixConfig has correct composite defaults."""
cfg = MatrixConfig()
assert isinstance(cfg.lighting, LightingConfig)
assert isinstance(cfg.environment, EnvironmentConfig)
assert isinstance(cfg.features, FeaturesConfig)
assert isinstance(cfg.agents, AgentsConfig)
# Check the blend
assert cfg.lighting.ambient_color == "#FFAA55"
assert cfg.environment.starfield_enabled is True
assert cfg.features.chat_enabled is True
def test_from_dict_full(self):
"""MatrixConfig.from_dict loads all sections."""
data = {
"lighting": {"ambient_color": "#000000"},
"environment": {"rain_enabled": True},
"features": {"chat_enabled": False},
"agents": {"default_count": 3},
}
cfg = MatrixConfig.from_dict(data)
assert cfg.lighting.ambient_color == "#000000"
assert cfg.environment.rain_enabled is True
assert cfg.features.chat_enabled is False
assert cfg.agents.default_count == 3
def test_from_dict_partial(self):
"""MatrixConfig.from_dict fills missing sections with defaults."""
data = {"lighting": {"ambient_color": "#111111"}}
cfg = MatrixConfig.from_dict(data)
assert cfg.lighting.ambient_color == "#111111"
assert cfg.environment.starfield_enabled is True # Default
assert cfg.features.pip_familiar is True # Default
def test_from_dict_none(self):
"""MatrixConfig.from_dict handles None."""
cfg = MatrixConfig.from_dict(None)
assert cfg.lighting.ambient_color == "#FFAA55"
assert cfg.features.chat_enabled is True
def test_to_dict_roundtrip(self):
"""MatrixConfig.to_dict produces serializable output."""
cfg = MatrixConfig()
data = cfg.to_dict()
assert isinstance(data, dict)
assert "lighting" in data
assert "environment" in data
assert "features" in data
assert "agents" in data
# Verify point lights are included
assert len(data["lighting"]["point_lights"]) == 3
class TestLoadFromYaml:
"""Tests for load_from_yaml function."""
def test_loads_valid_yaml(self, tmp_path: Path):
"""load_from_yaml reads a valid YAML file."""
config_path = tmp_path / "matrix.yaml"
data = {
"lighting": {"ambient_color": "#TEST11"},
"features": {"chat_enabled": False},
}
config_path.write_text(yaml.safe_dump(data))
cfg = load_from_yaml(config_path)
assert cfg.lighting.ambient_color == "#TEST11"
assert cfg.features.chat_enabled is False
def test_missing_file_returns_defaults(self, tmp_path: Path):
"""load_from_yaml returns defaults when file doesn't exist."""
config_path = tmp_path / "nonexistent.yaml"
cfg = load_from_yaml(config_path)
assert cfg.lighting.ambient_color == "#FFAA55"
assert cfg.features.chat_enabled is True
def test_empty_file_returns_defaults(self, tmp_path: Path):
"""load_from_yaml returns defaults for empty file."""
config_path = tmp_path / "empty.yaml"
config_path.write_text("")
cfg = load_from_yaml(config_path)
assert cfg.lighting.ambient_color == "#FFAA55"
def test_invalid_yaml_returns_defaults(self, tmp_path: Path):
"""load_from_yaml returns defaults for invalid YAML."""
config_path = tmp_path / "invalid.yaml"
config_path.write_text("not: valid: yaml: [")
cfg = load_from_yaml(config_path)
assert cfg.lighting.ambient_color == "#FFAA55"
assert cfg.features.chat_enabled is True
def test_non_dict_yaml_returns_defaults(self, tmp_path: Path):
"""load_from_yaml returns defaults when YAML is not a dict."""
config_path = tmp_path / "list.yaml"
config_path.write_text("- item1\n- item2")
cfg = load_from_yaml(config_path)
assert cfg.lighting.ambient_color == "#FFAA55"
def test_loads_actual_config_file(self):
"""load_from_yaml can load the project's config/matrix.yaml."""
repo_root = Path(__file__).parent.parent.parent
config_path = repo_root / "config" / "matrix.yaml"
if not config_path.exists():
pytest.skip("config/matrix.yaml not found")
cfg = load_from_yaml(config_path)
# Verify it loaded with expected values
assert cfg.lighting.ambient_color == "#FFAA55"
assert len(cfg.lighting.point_lights) == 3
assert cfg.environment.starfield_enabled is True
assert cfg.features.workshop_portal is True
def test_str_path_accepted(self, tmp_path: Path):
"""load_from_yaml accepts string path."""
config_path = tmp_path / "matrix.yaml"
config_path.write_text(yaml.safe_dump({"lighting": {"ambient_intensity": 0.9}}))
cfg = load_from_yaml(str(config_path))
assert cfg.lighting.ambient_intensity == 0.9

tests/unit/test_presence.py Normal file

@@ -0,0 +1,502 @@
"""Tests for infrastructure.presence — presence state serializer."""
from unittest.mock import patch
import pytest
from infrastructure.presence import (
DEFAULT_PIP_STATE,
_get_familiar_state,
produce_agent_state,
produce_bark,
produce_system_status,
produce_thought,
serialize_presence,
)
class TestSerializePresence:
"""Round-trip and edge-case tests for serialize_presence()."""
@pytest.fixture()
def full_presence(self):
"""A complete ADR-023 presence dict."""
return {
"version": 1,
"liveness": "2026-03-21T12:00:00Z",
"current_focus": "writing tests",
"mood": "focused",
"energy": 0.9,
"confidence": 0.85,
"active_threads": [
{"type": "thinking", "ref": "refactor presence", "status": "active"}
],
"recent_events": ["committed code"],
"concerns": ["test coverage"],
"familiar": {"name": "Pip", "state": "alert"},
}
def test_full_round_trip(self, full_presence):
"""All ADR-023 fields map to the expected camelCase keys."""
result = serialize_presence(full_presence)
assert result["timmyState"]["mood"] == "focused"
assert result["timmyState"]["activity"] == "writing tests"
assert result["timmyState"]["energy"] == 0.9
assert result["timmyState"]["confidence"] == 0.85
assert result["familiar"] == {"name": "Pip", "state": "alert"}
assert result["activeThreads"] == full_presence["active_threads"]
assert result["recentEvents"] == ["committed code"]
assert result["concerns"] == ["test coverage"]
assert result["visitorPresent"] is False
assert result["updatedAt"] == "2026-03-21T12:00:00Z"
assert result["version"] == 1
def test_defaults_on_empty_dict(self):
"""Missing fields fall back to safe defaults."""
result = serialize_presence({})
assert result["timmyState"]["mood"] == "calm"
assert result["timmyState"]["activity"] == "idle"
assert result["timmyState"]["energy"] == 0.5
assert result["timmyState"]["confidence"] == 0.7
assert result["familiar"] is None
assert result["activeThreads"] == []
assert result["recentEvents"] == []
assert result["concerns"] == []
assert result["visitorPresent"] is False
assert result["version"] == 1
# updatedAt should be an ISO timestamp string
assert "T" in result["updatedAt"]
def test_partial_presence(self):
"""Only some fields provided — others get defaults."""
result = serialize_presence({"mood": "excited", "energy": 0.3})
assert result["timmyState"]["mood"] == "excited"
assert result["timmyState"]["energy"] == 0.3
assert result["timmyState"]["confidence"] == 0.7 # default
assert result["activeThreads"] == [] # default
def test_return_type_is_dict(self, full_presence):
"""serialize_presence always returns a plain dict."""
result = serialize_presence(full_presence)
assert isinstance(result, dict)
assert isinstance(result["timmyState"], dict)
def test_visitor_present_always_false(self, full_presence):
"""visitorPresent is always False — set by the WS layer, not here."""
assert serialize_presence(full_presence)["visitorPresent"] is False
assert serialize_presence({})["visitorPresent"] is False
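The round-trip and defaults tests above fully determine the snake_case-to-camelCase mapping. A sketch of a `serialize_presence` that satisfies them (defaults and key names inferred from the assertions):

```python
from datetime import datetime, timezone


def serialize_presence(presence: dict) -> dict:
    """Map an ADR-023 snake_case presence dict onto the camelCase
    wire shape the frontend expects (illustrative)."""
    return {
        "timmyState": {
            "mood": presence.get("mood", "calm"),
            "activity": presence.get("current_focus", "idle"),
            "energy": presence.get("energy", 0.5),
            "confidence": presence.get("confidence", 0.7),
        },
        "familiar": presence.get("familiar"),
        "activeThreads": presence.get("active_threads", []),
        "recentEvents": presence.get("recent_events", []),
        "concerns": presence.get("concerns", []),
        "visitorPresent": False,  # set later by the WebSocket layer
        "updatedAt": presence.get(
            "liveness", datetime.now(timezone.utc).isoformat()
        ),
        "version": presence.get("version", 1),
    }
```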
class TestProduceAgentState:
"""Tests for produce_agent_state() — Matrix agent_state message producer."""
@pytest.fixture()
def full_presence(self):
"""A presence dict with all agent_state-relevant fields."""
return {
"display_name": "Timmy",
"role": "companion",
"current_focus": "thinking about tests",
"mood": "focused",
"energy": 0.9,
"bark": "Running test suite...",
}
@patch("infrastructure.presence.time")
def test_full_message_structure(self, mock_time, full_presence):
"""Returns dict with type, agent_id, data, and ts keys."""
mock_time.time.return_value = 1742529600
result = produce_agent_state("timmy", full_presence)
assert result["type"] == "agent_state"
assert result["agent_id"] == "timmy"
assert result["ts"] == 1742529600
assert isinstance(result["data"], dict)
def test_data_fields(self, full_presence):
"""data dict contains all required presence fields."""
data = produce_agent_state("timmy", full_presence)["data"]
assert data["display_name"] == "Timmy"
assert data["role"] == "companion"
assert data["status"] == "thinking"
assert data["mood"] == "focused"
assert data["energy"] == 0.9
assert data["bark"] == "Running test suite..."
def test_defaults_on_empty_presence(self):
"""Missing fields get sensible defaults."""
result = produce_agent_state("timmy", {})
data = result["data"]
assert data["display_name"] == "Timmy" # agent_id.title()
assert data["role"] == "assistant"
assert data["status"] == "idle"
assert data["mood"] == "calm"
assert data["energy"] == 0.5
assert data["bark"] == ""
def test_ts_is_unix_timestamp(self):
"""ts should be an integer Unix timestamp."""
result = produce_agent_state("timmy", {})
assert isinstance(result["ts"], int)
assert result["ts"] > 0
@pytest.mark.parametrize(
("focus", "expected_status"),
[
("thinking about code", "thinking"),
("speaking to user", "speaking"),
("talking with agent", "speaking"),
("idle", "idle"),
("", "idle"),
("writing tests", "online"),
("reviewing PR", "online"),
],
)
def test_status_derivation(self, focus, expected_status):
"""current_focus maps to the correct Matrix status."""
data = produce_agent_state("t", {"current_focus": focus})["data"]
assert data["status"] == expected_status
def test_agent_id_passed_through(self):
"""agent_id appears in the top-level message."""
result = produce_agent_state("spark", {})
assert result["agent_id"] == "spark"
def test_display_name_from_agent_id(self):
"""When display_name is missing, it's derived from agent_id.title()."""
data = produce_agent_state("spark", {})["data"]
assert data["display_name"] == "Spark"
def test_familiar_in_data(self):
"""agent_state.data includes familiar field with required keys."""
data = produce_agent_state("timmy", {})["data"]
assert "familiar" in data
familiar = data["familiar"]
assert familiar["name"] == "Pip"
assert "mood" in familiar
assert "energy" in familiar
assert familiar["color"] == "0x00b450"
assert familiar["trail_color"] == "0xdaa520"
def test_familiar_has_all_required_fields(self):
"""familiar dict contains all required fields per acceptance criteria."""
data = produce_agent_state("timmy", {})["data"]
familiar = data["familiar"]
required_fields = {"name", "mood", "energy", "color", "trail_color"}
assert set(familiar.keys()) >= required_fields
class TestFamiliarState:
"""Tests for _get_familiar_state() — Pip familiar state retrieval."""
def test_get_familiar_state_returns_dict(self):
"""_get_familiar_state returns a dict."""
result = _get_familiar_state()
assert isinstance(result, dict)
def test_get_familiar_state_has_required_fields(self):
"""Result contains name, mood, energy, color, trail_color."""
result = _get_familiar_state()
assert result["name"] == "Pip"
assert "mood" in result
assert isinstance(result["energy"], (int, float))
assert result["color"] == "0x00b450"
assert result["trail_color"] == "0xdaa520"
def test_default_pip_state_constant(self):
"""DEFAULT_PIP_STATE has expected values."""
assert DEFAULT_PIP_STATE["name"] == "Pip"
assert DEFAULT_PIP_STATE["mood"] == "sleepy"
assert DEFAULT_PIP_STATE["energy"] == 0.5
assert DEFAULT_PIP_STATE["color"] == "0x00b450"
assert DEFAULT_PIP_STATE["trail_color"] == "0xdaa520"
@patch("infrastructure.presence.logger")
def test_get_familiar_state_fallback_on_exception(self, mock_logger):
"""When familiar module raises, falls back to default and logs warning."""
# Patch inside the function where pip_familiar is imported
with patch("timmy.familiar.pip_familiar.snapshot") as mock_snapshot:
mock_snapshot.side_effect = RuntimeError("Pip is napping")
result = _get_familiar_state()
assert result["name"] == "Pip"
assert result["mood"] == "sleepy"
mock_logger.warning.assert_called_once()
assert "Pip is napping" in str(mock_logger.warning.call_args)
class TestProduceBark:
"""Tests for produce_bark() — Matrix bark message producer."""
@patch("infrastructure.presence.time")
def test_full_message_structure(self, mock_time):
"""Returns dict with type, agent_id, data, and ts keys."""
mock_time.time.return_value = 1742529600
result = produce_bark("timmy", "Hello world!")
assert result["type"] == "bark"
assert result["agent_id"] == "timmy"
assert result["ts"] == 1742529600
assert isinstance(result["data"], dict)
def test_data_fields(self):
"""data dict contains text, reply_to, and style."""
result = produce_bark("timmy", "Hello world!", reply_to="msg-123", style="shout")
data = result["data"]
assert data["text"] == "Hello world!"
assert data["reply_to"] == "msg-123"
assert data["style"] == "shout"
def test_default_style_is_speech(self):
"""When style is not provided, defaults to 'speech'."""
result = produce_bark("timmy", "Hello!")
assert result["data"]["style"] == "speech"
def test_default_reply_to_is_none(self):
"""When reply_to is not provided, defaults to None."""
result = produce_bark("timmy", "Hello!")
assert result["data"]["reply_to"] is None
def test_text_truncated_to_280_chars(self):
"""Text longer than 280 chars is truncated."""
long_text = "A" * 500
result = produce_bark("timmy", long_text)
assert len(result["data"]["text"]) == 280
assert result["data"]["text"] == "A" * 280
def test_text_exactly_280_chars_not_truncated(self):
"""Text exactly 280 chars is not truncated."""
text = "B" * 280
result = produce_bark("timmy", text)
assert result["data"]["text"] == text
def test_text_shorter_than_280_not_padded(self):
"""Text shorter than 280 chars is not padded."""
result = produce_bark("timmy", "Short")
assert result["data"]["text"] == "Short"
@pytest.mark.parametrize(
("style", "expected_style"),
[
("speech", "speech"),
("thought", "thought"),
("whisper", "whisper"),
("shout", "shout"),
],
)
def test_valid_styles_preserved(self, style, expected_style):
"""Valid style values are preserved."""
result = produce_bark("timmy", "Hello!", style=style)
assert result["data"]["style"] == expected_style
@pytest.mark.parametrize(
"invalid_style",
["yell", "scream", "", "SPEECH", "Speech", None, 123],
)
def test_invalid_style_defaults_to_speech(self, invalid_style):
"""Invalid style values fall back to 'speech'."""
result = produce_bark("timmy", "Hello!", style=invalid_style)
assert result["data"]["style"] == "speech"
def test_empty_text_handled(self):
"""Empty text is handled gracefully."""
result = produce_bark("timmy", "")
assert result["data"]["text"] == ""
def test_ts_is_unix_timestamp(self):
"""ts should be an integer Unix timestamp."""
result = produce_bark("timmy", "Hello!")
assert isinstance(result["ts"], int)
assert result["ts"] > 0
def test_agent_id_passed_through(self):
"""agent_id appears in the top-level message."""
result = produce_bark("spark", "Hello!")
assert result["agent_id"] == "spark"
def test_with_all_parameters(self):
"""Full parameter set produces expected output."""
result = produce_bark(
agent_id="timmy",
text="Running test suite...",
reply_to="parent-msg-456",
style="thought",
)
assert result["type"] == "bark"
assert result["agent_id"] == "timmy"
assert result["data"]["text"] == "Running test suite..."
assert result["data"]["reply_to"] == "parent-msg-456"
assert result["data"]["style"] == "thought"
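Taken together, these assertions fix produce_bark's contract: truncate text to 280 characters, validate style against a small whitelist (case-sensitive, falling back to "speech"), default reply_to to None, and stamp an integer Unix timestamp. A sketch satisfying that contract, not the project's actual implementation:

```python
import time

_VALID_STYLES = {"speech", "thought", "whisper", "shout"}

def produce_bark(agent_id, text, reply_to=None, style="speech"):
    """Build a Matrix bark message; invalid styles fall back to 'speech'."""
    if style not in _VALID_STYLES:
        style = "speech"
    return {
        "type": "bark",
        "agent_id": agent_id,
        "ts": int(time.time()),
        "data": {
            "text": text[:280],  # hard cap per the truncation tests
            "reply_to": reply_to,
            "style": style,
        },
    }
```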
class TestProduceThought:
"""Tests for produce_thought() — Matrix thought message producer."""
@patch("infrastructure.presence.time")
def test_full_message_structure(self, mock_time):
"""Returns dict with type, agent_id, data, and ts keys."""
mock_time.time.return_value = 1742529600
result = produce_thought("timmy", "Considering the options...", 42)
assert result["type"] == "thought"
assert result["agent_id"] == "timmy"
assert result["ts"] == 1742529600
assert isinstance(result["data"], dict)
def test_data_fields(self):
"""data dict contains text, thought_id, and chain_id."""
result = produce_thought("timmy", "Considering...", 42, chain_id="chain-123")
data = result["data"]
assert data["text"] == "Considering..."
assert data["thought_id"] == 42
assert data["chain_id"] == "chain-123"
def test_default_chain_id_is_none(self):
"""When chain_id is not provided, defaults to None."""
result = produce_thought("timmy", "Thinking...", 1)
assert result["data"]["chain_id"] is None
def test_text_truncated_to_500_chars(self):
"""Text longer than 500 chars is truncated."""
long_text = "A" * 600
result = produce_thought("timmy", long_text, 1)
assert len(result["data"]["text"]) == 500
assert result["data"]["text"] == "A" * 500
def test_text_exactly_500_chars_not_truncated(self):
"""Text exactly 500 chars is not truncated."""
text = "B" * 500
result = produce_thought("timmy", text, 1)
assert result["data"]["text"] == text
def test_text_shorter_than_500_not_padded(self):
"""Text shorter than 500 chars is not padded."""
result = produce_thought("timmy", "Short thought", 1)
assert result["data"]["text"] == "Short thought"
def test_empty_text_handled(self):
"""Empty text is handled gracefully."""
result = produce_thought("timmy", "", 1)
assert result["data"]["text"] == ""
def test_ts_is_unix_timestamp(self):
"""ts should be an integer Unix timestamp."""
result = produce_thought("timmy", "Hello!", 1)
assert isinstance(result["ts"], int)
assert result["ts"] > 0
def test_agent_id_passed_through(self):
"""agent_id appears in the top-level message."""
result = produce_thought("spark", "Hello!", 1)
assert result["agent_id"] == "spark"
def test_thought_id_passed_through(self):
"""thought_id appears in the data."""
result = produce_thought("timmy", "Hello!", 999)
assert result["data"]["thought_id"] == 999
def test_with_all_parameters(self):
"""Full parameter set produces expected output."""
result = produce_thought(
agent_id="timmy",
thought_text="Analyzing the situation...",
thought_id=42,
chain_id="chain-abc",
)
assert result["type"] == "thought"
assert result["agent_id"] == "timmy"
assert result["data"]["text"] == "Analyzing the situation..."
assert result["data"]["thought_id"] == 42
assert result["data"]["chain_id"] == "chain-abc"
class TestProduceSystemStatus:
"""Tests for produce_system_status() — Matrix system_status message producer."""
@patch("infrastructure.presence.time")
def test_full_message_structure(self, mock_time):
"""Returns dict with type, data, and ts keys."""
mock_time.time.return_value = 1742529600
result = produce_system_status()
assert result["type"] == "system_status"
assert result["ts"] == 1742529600
assert isinstance(result["data"], dict)
def test_data_has_required_fields(self):
"""data dict contains all required system status fields."""
result = produce_system_status()
data = result["data"]
assert "agents_online" in data
assert "visitors" in data
assert "uptime_seconds" in data
assert "thinking_active" in data
assert "memory_count" in data
def test_data_field_types(self):
"""All data fields have correct types."""
result = produce_system_status()
data = result["data"]
assert isinstance(data["agents_online"], int)
assert isinstance(data["visitors"], int)
assert isinstance(data["uptime_seconds"], int)
assert isinstance(data["thinking_active"], bool)
assert isinstance(data["memory_count"], int)
def test_agents_online_is_non_negative(self):
"""agents_online is never negative."""
result = produce_system_status()
assert result["data"]["agents_online"] >= 0
def test_visitors_is_non_negative(self):
"""visitors is never negative."""
result = produce_system_status()
assert result["data"]["visitors"] >= 0
def test_uptime_seconds_is_non_negative(self):
"""uptime_seconds is never negative."""
result = produce_system_status()
assert result["data"]["uptime_seconds"] >= 0
def test_memory_count_is_non_negative(self):
"""memory_count is never negative."""
result = produce_system_status()
assert result["data"]["memory_count"] >= 0
@patch("infrastructure.presence.time")
def test_ts_is_unix_timestamp(self, mock_time):
"""ts should be an integer Unix timestamp."""
mock_time.time.return_value = 1742529600
result = produce_system_status()
assert isinstance(result["ts"], int)
assert result["ts"] == 1742529600
@patch("infrastructure.presence.logger")
def test_graceful_degradation_on_import_errors(self, mock_logger):
"""Function returns valid dict even when imports fail."""
# This test verifies the function handles failures gracefully
# by checking it always returns the expected structure
result = produce_system_status()
assert result["type"] == "system_status"
assert isinstance(result["data"], dict)
assert isinstance(result["ts"], int)
def test_returns_dict(self):
"""produce_system_status always returns a plain dict."""
result = produce_system_status()
assert isinstance(result, dict)

tests/unit/test_protocol.py (new file)
@@ -0,0 +1,173 @@
"""Tests for infrastructure.protocol — WebSocket message types."""
import json
import pytest
from infrastructure.protocol import (
AgentStateMessage,
BarkMessage,
ConnectionAckMessage,
ErrorMessage,
MemoryFlashMessage,
MessageType,
SystemStatusMessage,
TaskUpdateMessage,
ThoughtMessage,
VisitorStateMessage,
WSMessage,
)
# ---------------------------------------------------------------------------
# MessageType enum
# ---------------------------------------------------------------------------
class TestMessageType:
"""MessageType enum covers all 9 Matrix PROTOCOL.md types."""
def test_has_all_nine_types(self):
assert len(MessageType) == 9
@pytest.mark.parametrize(
"member,value",
[
(MessageType.AGENT_STATE, "agent_state"),
(MessageType.VISITOR_STATE, "visitor_state"),
(MessageType.BARK, "bark"),
(MessageType.THOUGHT, "thought"),
(MessageType.SYSTEM_STATUS, "system_status"),
(MessageType.CONNECTION_ACK, "connection_ack"),
(MessageType.ERROR, "error"),
(MessageType.TASK_UPDATE, "task_update"),
(MessageType.MEMORY_FLASH, "memory_flash"),
],
)
def test_enum_values(self, member, value):
assert member.value == value
def test_str_comparison(self):
"""MessageType is a str enum so it can be compared to plain strings."""
assert MessageType.BARK == "bark"
# ---------------------------------------------------------------------------
# to_json / from_json round-trip
# ---------------------------------------------------------------------------
class TestAgentStateMessage:
def test_defaults(self):
msg = AgentStateMessage()
assert msg.type == "agent_state"
assert msg.agent_id == ""
assert msg.data == {}
def test_round_trip(self):
msg = AgentStateMessage(agent_id="timmy", data={"mood": "happy"}, ts=1000.0)
raw = msg.to_json()
restored = AgentStateMessage.from_json(raw)
assert restored.agent_id == "timmy"
assert restored.data == {"mood": "happy"}
assert restored.ts == 1000.0
def test_to_json_structure(self):
msg = AgentStateMessage(agent_id="timmy", data={"x": 1}, ts=123.0)
parsed = json.loads(msg.to_json())
assert parsed["type"] == "agent_state"
assert parsed["agent_id"] == "timmy"
assert parsed["data"] == {"x": 1}
assert parsed["ts"] == 123.0
class TestVisitorStateMessage:
def test_round_trip(self):
msg = VisitorStateMessage(visitor_id="v1", data={"page": "/"}, ts=1.0)
restored = VisitorStateMessage.from_json(msg.to_json())
assert restored.visitor_id == "v1"
assert restored.data == {"page": "/"}
class TestBarkMessage:
def test_round_trip(self):
msg = BarkMessage(agent_id="timmy", content="woof!", ts=1.0)
restored = BarkMessage.from_json(msg.to_json())
assert restored.agent_id == "timmy"
assert restored.content == "woof!"
class TestThoughtMessage:
def test_round_trip(self):
msg = ThoughtMessage(agent_id="timmy", content="hmm...", ts=1.0)
restored = ThoughtMessage.from_json(msg.to_json())
assert restored.content == "hmm..."
class TestSystemStatusMessage:
def test_round_trip(self):
msg = SystemStatusMessage(status="healthy", data={"uptime": 3600}, ts=1.0)
restored = SystemStatusMessage.from_json(msg.to_json())
assert restored.status == "healthy"
assert restored.data == {"uptime": 3600}
class TestConnectionAckMessage:
def test_round_trip(self):
msg = ConnectionAckMessage(client_id="abc-123", ts=1.0)
restored = ConnectionAckMessage.from_json(msg.to_json())
assert restored.client_id == "abc-123"
class TestErrorMessage:
def test_round_trip(self):
msg = ErrorMessage(code="INVALID", message="bad request", ts=1.0)
restored = ErrorMessage.from_json(msg.to_json())
assert restored.code == "INVALID"
assert restored.message == "bad request"
class TestTaskUpdateMessage:
def test_round_trip(self):
msg = TaskUpdateMessage(task_id="t1", status="completed", data={"result": "ok"}, ts=1.0)
restored = TaskUpdateMessage.from_json(msg.to_json())
assert restored.task_id == "t1"
assert restored.status == "completed"
assert restored.data == {"result": "ok"}
class TestMemoryFlashMessage:
def test_round_trip(self):
msg = MemoryFlashMessage(agent_id="timmy", memory_key="fav_food", content="kibble", ts=1.0)
restored = MemoryFlashMessage.from_json(msg.to_json())
assert restored.memory_key == "fav_food"
assert restored.content == "kibble"
# ---------------------------------------------------------------------------
# WSMessage.from_json dispatch
# ---------------------------------------------------------------------------
class TestWSMessageDispatch:
"""WSMessage.from_json dispatches to the correct subclass."""
def test_dispatch_to_bark(self):
raw = json.dumps({"type": "bark", "agent_id": "t", "content": "woof", "ts": 1.0})
msg = WSMessage.from_json(raw)
assert isinstance(msg, BarkMessage)
assert msg.content == "woof"
def test_dispatch_to_error(self):
raw = json.dumps({"type": "error", "code": "E1", "message": "oops", "ts": 1.0})
msg = WSMessage.from_json(raw)
assert isinstance(msg, ErrorMessage)
def test_unknown_type_returns_base(self):
raw = json.dumps({"type": "unknown_future_type", "ts": 1.0})
msg = WSMessage.from_json(raw)
assert type(msg) is WSMessage
assert msg.type == "unknown_future_type"
def test_invalid_json_raises(self):
with pytest.raises(json.JSONDecodeError):
WSMessage.from_json("not json")
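The dispatch behavior tested above (known type strings map to subclasses, unknown types fall back to the base class, invalid JSON raises) is typically implemented with a type registry. A self-contained sketch using simplified dataclass messages rather than the project's actual classes:

```python
import json
from dataclasses import dataclass

@dataclass
class WSMessage:
    type: str = ""
    ts: float = 0.0

    @classmethod
    def from_json(cls, raw: str) -> "WSMessage":
        payload = json.loads(raw)  # invalid JSON raises JSONDecodeError here
        target = _REGISTRY.get(payload.get("type"), WSMessage)
        # Keep only fields the target class actually declares.
        fields = target.__dataclass_fields__
        return target(**{k: v for k, v in payload.items() if k in fields})

@dataclass
class BarkMessage(WSMessage):
    type: str = "bark"
    agent_id: str = ""
    content: str = ""

@dataclass
class ErrorMessage(WSMessage):
    type: str = "error"
    code: str = ""
    message: str = ""

# Known types dispatch to subclasses; anything else stays a base WSMessage.
_REGISTRY = {"bark": BarkMessage, "error": ErrorMessage}
```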


@@ -0,0 +1,446 @@
"""Tests for rate limiting middleware.
Tests the RateLimiter class and RateLimitMiddleware for correct
rate limiting behavior, cleanup, and edge cases.
"""
import time
from unittest.mock import Mock
import pytest
from starlette.requests import Request
from starlette.responses import JSONResponse
from dashboard.middleware.rate_limit import RateLimiter, RateLimitMiddleware
class TestRateLimiter:
"""Tests for the RateLimiter class."""
def test_init_defaults(self):
"""RateLimiter initializes with default values."""
limiter = RateLimiter()
assert limiter.requests_per_minute == 30
assert limiter.cleanup_interval_seconds == 60
assert limiter._storage == {}
def test_init_custom_values(self):
"""RateLimiter accepts custom configuration."""
limiter = RateLimiter(requests_per_minute=60, cleanup_interval_seconds=120)
assert limiter.requests_per_minute == 60
assert limiter.cleanup_interval_seconds == 120
def test_is_allowed_first_request(self):
"""First request from an IP is always allowed."""
limiter = RateLimiter(requests_per_minute=5)
allowed, retry_after = limiter.is_allowed("192.168.1.1")
assert allowed is True
assert retry_after == 0.0
assert "192.168.1.1" in limiter._storage
assert len(limiter._storage["192.168.1.1"]) == 1
def test_is_allowed_under_limit(self):
"""Requests under the limit are allowed."""
limiter = RateLimiter(requests_per_minute=5)
# Make 4 requests (under limit of 5)
for _ in range(4):
allowed, _ = limiter.is_allowed("192.168.1.1")
assert allowed is True
assert len(limiter._storage["192.168.1.1"]) == 4
def test_is_allowed_at_limit(self):
"""Request at the limit is allowed."""
limiter = RateLimiter(requests_per_minute=5)
# Make exactly 5 requests
for _ in range(5):
allowed, _ = limiter.is_allowed("192.168.1.1")
assert allowed is True
assert len(limiter._storage["192.168.1.1"]) == 5
def test_is_allowed_over_limit(self):
"""Request over the limit is denied."""
limiter = RateLimiter(requests_per_minute=5)
# Make 5 requests to hit the limit
for _ in range(5):
limiter.is_allowed("192.168.1.1")
# 6th request should be denied
allowed, retry_after = limiter.is_allowed("192.168.1.1")
assert allowed is False
assert retry_after > 0
def test_is_allowed_different_ips(self):
"""Rate limiting is per-IP, not global."""
limiter = RateLimiter(requests_per_minute=5)
# Hit limit for IP 1
for _ in range(5):
limiter.is_allowed("192.168.1.1")
# IP 1 is now rate limited
allowed, _ = limiter.is_allowed("192.168.1.1")
assert allowed is False
# IP 2 should still be allowed
allowed, _ = limiter.is_allowed("192.168.1.2")
assert allowed is True
def test_window_expiration_allows_new_requests(self):
"""After window expires, new requests are allowed."""
limiter = RateLimiter(requests_per_minute=5)
# Hit the limit
for _ in range(5):
limiter.is_allowed("192.168.1.1")
# Should be rate limited
allowed, _ = limiter.is_allowed("192.168.1.1")
assert allowed is False
# Simulate time passing by clearing timestamps manually
# (we can't wait 60 seconds in a test)
limiter._storage["192.168.1.1"].clear()
# Should now be allowed again
allowed, _ = limiter.is_allowed("192.168.1.1")
assert allowed is True
def test_cleanup_removes_stale_entries(self):
"""Cleanup removes IPs with no recent requests."""
limiter = RateLimiter(
requests_per_minute=5,
cleanup_interval_seconds=1, # Short interval for testing
)
# Add some requests
limiter.is_allowed("192.168.1.1")
limiter.is_allowed("192.168.1.2")
# Both IPs should be in storage
assert "192.168.1.1" in limiter._storage
assert "192.168.1.2" in limiter._storage
# Manually clear timestamps to simulate stale data
limiter._storage["192.168.1.1"].clear()
limiter._last_cleanup = time.time() - 2 # Force cleanup
# Trigger cleanup via check_request with a mock
mock_request = Mock()
mock_request.headers = {}
mock_request.client = Mock()
mock_request.client.host = "192.168.1.3"
mock_request.url.path = "/api/matrix/test"
limiter.check_request(mock_request)
# Stale IP should be removed
assert "192.168.1.1" not in limiter._storage
# IP 2 still has a recent request, so it is preserved
assert "192.168.1.2" in limiter._storage
def test_get_client_ip_direct(self):
"""Extract client IP from direct connection."""
limiter = RateLimiter()
mock_request = Mock()
mock_request.headers = {}
mock_request.client = Mock()
mock_request.client.host = "192.168.1.100"
ip = limiter._get_client_ip(mock_request)
assert ip == "192.168.1.100"
def test_get_client_ip_x_forwarded_for(self):
"""Extract client IP from X-Forwarded-For header."""
limiter = RateLimiter()
mock_request = Mock()
mock_request.headers = {"x-forwarded-for": "10.0.0.1, 192.168.1.1"}
mock_request.client = Mock()
mock_request.client.host = "192.168.1.100"
ip = limiter._get_client_ip(mock_request)
assert ip == "10.0.0.1"
def test_get_client_ip_x_real_ip(self):
"""Extract client IP from X-Real-IP header."""
limiter = RateLimiter()
mock_request = Mock()
mock_request.headers = {"x-real-ip": "10.0.0.5"}
mock_request.client = Mock()
mock_request.client.host = "192.168.1.100"
ip = limiter._get_client_ip(mock_request)
assert ip == "10.0.0.5"
def test_get_client_ip_no_client(self):
"""Return 'unknown' when no client info available."""
limiter = RateLimiter()
mock_request = Mock()
mock_request.headers = {}
mock_request.client = None
ip = limiter._get_client_ip(mock_request)
assert ip == "unknown"
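The sliding-window behavior these tests describe — a per-IP deque of timestamps, pruned on every check — can be sketched as below. The `(allowed, retry_after)` return shape and the `_storage` attribute match the tests; the 60-second window constant is inferred from `requests_per_minute`:

```python
import time
from collections import deque

class RateLimiter:
    """Per-IP sliding window: at most N requests in any 60-second span."""

    def __init__(self, requests_per_minute: int = 30):
        self.requests_per_minute = requests_per_minute
        self._storage: dict[str, deque] = {}

    def is_allowed(self, ip: str) -> tuple[bool, float]:
        now = time.time()
        window = self._storage.setdefault(ip, deque())
        # Drop timestamps that have slid out of the 60-second window.
        while window and window[0] <= now - 60:
            window.popleft()
        if len(window) >= self.requests_per_minute:
            # Allowed again once the oldest timestamp ages out.
            return False, (window[0] + 60) - now
        window.append(now)
        return True, 0.0
```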
class TestRateLimitMiddleware:
"""Tests for the RateLimitMiddleware class."""
@pytest.fixture
def mock_app(self):
"""Create a mock ASGI app."""
async def app(scope, receive, send):
response = JSONResponse({"status": "ok"})
await response(scope, receive, send)
return app
@pytest.fixture
def mock_request(self):
"""Create a mock Request object."""
request = Mock(spec=Request)
request.url.path = "/api/matrix/test"
request.headers = {}
request.client = Mock()
request.client.host = "192.168.1.1"
return request
def test_init_defaults(self, mock_app):
"""Middleware initializes with default values."""
middleware = RateLimitMiddleware(mock_app)
assert middleware.path_prefixes == []
assert middleware.limiter.requests_per_minute == 30
def test_init_custom_values(self, mock_app):
"""Middleware accepts custom configuration."""
middleware = RateLimitMiddleware(
mock_app,
path_prefixes=["/api/matrix/"],
requests_per_minute=60,
)
assert middleware.path_prefixes == ["/api/matrix/"]
assert middleware.limiter.requests_per_minute == 60
def test_should_rate_limit_no_prefixes(self, mock_app):
"""With no prefixes, all paths are rate limited."""
middleware = RateLimitMiddleware(mock_app)
assert middleware._should_rate_limit("/api/matrix/test") is True
assert middleware._should_rate_limit("/api/other/test") is True
assert middleware._should_rate_limit("/health") is True
def test_should_rate_limit_with_prefixes(self, mock_app):
"""With prefixes, only matching paths are rate limited."""
middleware = RateLimitMiddleware(
mock_app,
path_prefixes=["/api/matrix/", "/api/public/"],
)
assert middleware._should_rate_limit("/api/matrix/test") is True
assert middleware._should_rate_limit("/api/matrix/") is True
assert middleware._should_rate_limit("/api/public/data") is True
assert middleware._should_rate_limit("/api/other/test") is False
assert middleware._should_rate_limit("/health") is False
@pytest.mark.asyncio
async def test_dispatch_allows_matching_path_under_limit(self, mock_app):
"""Request to matching path under limit is allowed."""
middleware = RateLimitMiddleware(
mock_app,
path_prefixes=["/api/matrix/"],
requests_per_minute=5,
)
# Create a proper ASGI scope
scope = {
"type": "http",
"method": "GET",
"path": "/api/matrix/test",
"headers": [],
}
async def receive():
return {"type": "http.request", "body": b""}
response_body = []
async def send(message):
response_body.append(message)
await middleware(scope, receive, send)
# Should have sent response messages
assert len(response_body) > 0
# Check for 200 status in the response start message
start_message = next(
(m for m in response_body if m.get("type") == "http.response.start"), None
)
assert start_message is not None
assert start_message["status"] == 200
@pytest.mark.asyncio
async def test_dispatch_skips_non_matching_path(self, mock_app):
"""Request to non-matching path bypasses rate limiting."""
middleware = RateLimitMiddleware(
mock_app,
path_prefixes=["/api/matrix/"],
requests_per_minute=5,
)
scope = {
"type": "http",
"method": "GET",
"path": "/api/other/test", # Doesn't match /api/matrix/
"headers": [],
}
async def receive():
return {"type": "http.request", "body": b""}
response_body = []
async def send(message):
response_body.append(message)
await middleware(scope, receive, send)
# Should have sent response messages
assert len(response_body) > 0
start_message = next(
(m for m in response_body if m.get("type") == "http.response.start"), None
)
assert start_message is not None
assert start_message["status"] == 200
@pytest.mark.asyncio
async def test_dispatch_returns_429_when_rate_limited(self, mock_app):
"""Request over limit returns 429 status."""
middleware = RateLimitMiddleware(
mock_app,
path_prefixes=["/api/matrix/"],
requests_per_minute=2, # Low limit for testing
)
# First request - allowed
test_scope = {
"type": "http",
"method": "GET",
"path": "/api/matrix/test",
"headers": [],
}
async def receive():
return {"type": "http.request", "body": b""}
# Helper to capture response
def make_send(captured):
async def send(message):
captured.append(message)
return send
# Make requests to hit the limit
for _ in range(2):
response_body = []
await middleware(test_scope, receive, make_send(response_body))
start_message = next(
(m for m in response_body if m.get("type") == "http.response.start"),
None,
)
assert start_message["status"] == 200
# 3rd request should be rate limited
response_body = []
await middleware(test_scope, receive, make_send(response_body))
start_message = next(
(m for m in response_body if m.get("type") == "http.response.start"), None
)
assert start_message["status"] == 429
# Check for Retry-After header
headers = dict(start_message.get("headers", []))
assert b"retry-after" in headers or b"Retry-After" in headers
class TestRateLimiterIntegration:
"""Integration-style tests for rate limiter behavior."""
def test_multiple_ips_independent_limits(self):
"""Each IP has its own independent rate limit."""
limiter = RateLimiter(requests_per_minute=3)
# Use up limit for IP 1
for _ in range(3):
limiter.is_allowed("10.0.0.1")
# Use up limit for IP 2
for _ in range(3):
limiter.is_allowed("10.0.0.2")
# Both should now be rate limited
assert limiter.is_allowed("10.0.0.1")[0] is False
assert limiter.is_allowed("10.0.0.2")[0] is False
# IP 3 should still be allowed
assert limiter.is_allowed("10.0.0.3")[0] is True
def test_timestamp_window_sliding(self):
"""Rate limit window slides correctly as time passes."""
from collections import deque
limiter = RateLimiter(requests_per_minute=3)
# Add 3 timestamps manually (simulating old requests)
now = time.time()
limiter._storage["test-ip"] = deque(
[
now - 100, # 100 seconds ago (outside 60s window)
now - 50, # 50 seconds ago (inside window)
now - 10, # 10 seconds ago (inside window)
]
)
# Currently have 2 requests in window, so 1 more allowed
allowed, _ = limiter.is_allowed("test-ip")
assert allowed is True
# Now 3 in window, should be rate limited
allowed, _ = limiter.is_allowed("test-ip")
assert allowed is False
def test_cleanup_preserves_active_ips(self):
"""Cleanup only removes IPs with no recent requests."""
from collections import deque
limiter = RateLimiter(
requests_per_minute=3,
cleanup_interval_seconds=1,
)
now = time.time()
# IP 1: active recently
limiter._storage["active-ip"] = deque([now - 10])
# IP 2: no timestamps (stale)
limiter._storage["stale-ip"] = deque()
# IP 3: old timestamps only
limiter._storage["old-ip"] = deque([now - 100])
limiter._last_cleanup = now - 2 # Force cleanup
# Run cleanup
limiter._cleanup_if_needed()
# Active IP should remain
assert "active-ip" in limiter._storage
# Stale IPs should be removed
assert "stale-ip" not in limiter._storage
assert "old-ip" not in limiter._storage
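These last tests fix the cleanup contract: on a throttled interval, drop any IP whose deque is empty or whose newest timestamp has left the 60-second window, and keep everything else. A sketch assuming the same `_storage` / `_last_cleanup` attributes the tests poke directly:

```python
import time
from collections import deque

class RateLimiter:
    def __init__(self, requests_per_minute=30, cleanup_interval_seconds=60):
        self.requests_per_minute = requests_per_minute
        self.cleanup_interval_seconds = cleanup_interval_seconds
        self._storage: dict[str, deque] = {}
        self._last_cleanup = time.time()

    def _cleanup_if_needed(self):
        """Drop IPs with no timestamps inside the 60-second window."""
        now = time.time()
        if now - self._last_cleanup < self.cleanup_interval_seconds:
            return  # too soon since the last sweep
        stale = [
            ip for ip, window in self._storage.items()
            if not window or window[-1] <= now - 60
        ]
        for ip in stale:
            del self._storage[ip]
        self._last_cleanup = now
```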