security: implement rate limiting for chat API and fix calm tests (#122)

This commit is contained in:
Alexander Whitestone
2026-03-03 08:16:36 -05:00
committed by GitHub
parent 584eeb679e
commit eb2b34876b
3 changed files with 84 additions and 4 deletions

View File

@@ -9,10 +9,15 @@ Endpoints:
from __future__ import annotations
import logging
import time
from collections import defaultdict
from contextlib import asynccontextmanager
from typing import Dict, List
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from config import settings
from timmy.agent import create_timmy
@@ -34,6 +39,39 @@ class StatusResponse(BaseModel):
backend: str
class RateLimitMiddleware(BaseHTTPMiddleware):
"""Simple in-memory rate limiting middleware."""
def __init__(self, app, limit: int = 10, window: int = 60):
super().__init__(app)
self.limit = limit
self.window = window
self.requests: Dict[str, List[float]] = defaultdict(list)
async def dispatch(self, request: Request, call_next):
# Only rate limit chat endpoint
if request.url.path == "/serve/chat" and request.method == "POST":
client_ip = request.client.host if request.client else "unknown"
now = time.time()
# Clean up old requests
self.requests[client_ip] = [
t for t in self.requests[client_ip]
if now - t < self.window
]
if len(self.requests[client_ip]) >= self.limit:
logger.warning("Rate limit exceeded for %s", client_ip)
return JSONResponse(
status_code=429,
content={"error": "Rate limit exceeded. Try again later."}
)
self.requests[client_ip].append(now)
return await call_next(request)
def create_timmy_serve_app() -> FastAPI:
"""Create the Timmy Serve FastAPI application."""
@@ -51,6 +89,9 @@ def create_timmy_serve_app() -> FastAPI:
redoc_url="/redoc" if settings.debug else None,
)
# Add rate limiting middleware (10 requests per minute)
app.add_middleware(RateLimitMiddleware, limit=10, window=60)
@app.get("/serve/status", response_model=StatusResponse)
async def serve_status():
"""Get service status."""