Compare commits

..

1 Commits

Author SHA1 Message Date
kimi
9578330c87 fix: replace wildcard CORS default with explicit localhost origins
The cors_origins setting defaulted to ["*"], which passed through
unchanged in production (non-debug) mode. Now defaults to explicit
localhost origins, and _get_cors_origins() strips any wildcards in
production with a warning.

Fixes #462

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-19 14:58:10 -04:00
10 changed files with 78 additions and 337 deletions

View File

@@ -54,11 +54,25 @@ providers:
context_window: 2048
capabilities: [text, vision, streaming]
# Secondary: OpenAI (if API key available)
# Secondary: Local AirLLM (if installed)
- name: airllm-local
type: airllm
enabled: false # Enable if pip install airllm
priority: 2
models:
- name: 70b
default: true
capabilities: [text, tools, json, streaming]
- name: 8b
capabilities: [text, tools, json, streaming]
- name: 405b
capabilities: [text, tools, json, streaming]
# Tertiary: OpenAI (if API key available)
- name: openai-backup
type: openai
enabled: false # Enable by setting OPENAI_API_KEY
priority: 2
priority: 3
api_key: "${OPENAI_API_KEY}" # Loaded from environment
base_url: null # Use default OpenAI endpoint
models:
@@ -70,11 +84,11 @@ providers:
context_window: 128000
capabilities: [text, vision, tools, json, streaming]
# Tertiary: Anthropic (if API key available)
# Quaternary: Anthropic (if API key available)
- name: anthropic-backup
type: anthropic
enabled: false # Enable by setting ANTHROPIC_API_KEY
priority: 3
priority: 4
api_key: "${ANTHROPIC_API_KEY}"
models:
- name: claude-3-haiku-20240307

View File

@@ -64,10 +64,17 @@ class Settings(BaseSettings):
# Seconds to wait for user confirmation before auto-rejecting.
discord_confirm_timeout: int = 120
# ── Backend selection ────────────────────────────────────────────────────
# ── AirLLM / backend selection ───────────────────────────────────────────
# "ollama" — always use Ollama (default, safe everywhere)
# "auto" — pick best available local backend, fall back to Ollama
timmy_model_backend: Literal["ollama", "grok", "claude", "auto"] = "ollama"
# "airllm" — always use AirLLM (requires pip install ".[bigbrain]")
# "auto" — use AirLLM on Apple Silicon if airllm is installed,
# fall back to Ollama otherwise
timmy_model_backend: Literal["ollama", "airllm", "grok", "claude", "auto"] = "ollama"
# AirLLM model size when backend is airllm or auto.
# Larger = smarter, but needs more RAM / disk.
# 8b ~16 GB | 70b ~140 GB | 405b ~810 GB
airllm_model_size: Literal["8b", "70b", "405b"] = "70b"
# ── Grok (xAI) — opt-in premium cloud backend ────────────────────────
# Grok is a premium augmentation layer — local-first ethos preserved.
@@ -469,19 +476,8 @@ def validate_startup(*, force: bool = False) -> None:
", ".join(_missing),
)
sys.exit(1)
if "*" in settings.cors_origins:
_startup_logger.error(
"PRODUCTION SECURITY ERROR: CORS wildcard '*' is not allowed "
"in production. Set CORS_ORIGINS to explicit origins."
)
sys.exit(1)
_startup_logger.info("Production mode: security secrets validated ✓")
else:
if "*" in settings.cors_origins:
_startup_logger.warning(
"SEC: CORS_ORIGINS contains wildcard '*'"
"restrict to explicit origins before deploying to production."
)
if not settings.l402_hmac_secret:
_startup_logger.warning(
"SEC: L402_HMAC_SECRET is not set — "

View File

@@ -486,12 +486,17 @@ app = FastAPI(
def _get_cors_origins() -> list[str]:
"""Get CORS origins from settings, rejecting wildcards in production."""
origins = settings.cors_origins
if "*" in origins and not settings.debug:
if not settings.debug and "*" in origins:
logger.warning(
"Wildcard '*' in CORS_ORIGINS stripped in production — "
"Wildcard '*' in CORS_ORIGINS ignored in production — "
"set explicit origins via CORS_ORIGINS env var"
)
origins = [o for o in origins if o != "*"]
if not origins:
origins = [
"http://localhost:3000",
"http://localhost:8000",
]
return origins

View File

@@ -183,22 +183,6 @@ async def run_health_check(
}
@router.post("/reload")
async def reload_config(
cascade: Annotated[CascadeRouter, Depends(get_cascade_router)],
) -> dict[str, Any]:
"""Hot-reload providers.yaml without restart.
Preserves circuit breaker state and metrics for existing providers.
"""
try:
result = cascade.reload_config()
return {"status": "ok", **result}
except Exception as exc:
logger.error("Config reload failed: %s", exc)
raise HTTPException(status_code=500, detail=f"Reload failed: {exc}") from exc
@router.get("/config")
async def get_config(
cascade: Annotated[CascadeRouter, Depends(get_cascade_router)],

View File

@@ -100,7 +100,7 @@ class Provider:
"""LLM provider configuration and state."""
name: str
type: str # ollama, openai, anthropic
type: str # ollama, openai, anthropic, airllm
enabled: bool
priority: int
url: str | None = None
@@ -308,6 +308,15 @@ class CascadeRouter:
logger.debug("Ollama provider check error: %s", exc)
return False
elif provider.type == "airllm":
# Check if airllm is installed
try:
import importlib.util
return importlib.util.find_spec("airllm") is not None
except (ImportError, ModuleNotFoundError):
return False
elif provider.type in ("openai", "anthropic", "grok"):
# Check if API key is set
return provider.api_key is not None and provider.api_key != ""
@@ -806,66 +815,6 @@ class CascadeRouter:
provider.status = ProviderStatus.HEALTHY
logger.info("Circuit breaker CLOSED for %s", provider.name)
def reload_config(self) -> dict:
"""Hot-reload providers.yaml, preserving runtime state.
Re-reads the config file, rebuilds the provider list, and
preserves circuit breaker state and metrics for providers
that still exist after reload.
Returns:
Summary dict with added/removed/preserved counts.
"""
# Snapshot current runtime state keyed by provider name
old_state: dict[
str, tuple[ProviderMetrics, CircuitState, float | None, int, ProviderStatus]
] = {}
for p in self.providers:
old_state[p.name] = (
p.metrics,
p.circuit_state,
p.circuit_opened_at,
p.half_open_calls,
p.status,
)
old_names = set(old_state.keys())
# Reload from disk
self.providers = []
self._load_config()
# Restore preserved state
new_names = {p.name for p in self.providers}
preserved = 0
for p in self.providers:
if p.name in old_state:
metrics, circuit, opened_at, half_open, status = old_state[p.name]
p.metrics = metrics
p.circuit_state = circuit
p.circuit_opened_at = opened_at
p.half_open_calls = half_open
p.status = status
preserved += 1
added = new_names - old_names
removed = old_names - new_names
logger.info(
"Config reloaded: %d providers (%d preserved, %d added, %d removed)",
len(self.providers),
preserved,
len(added),
len(removed),
)
return {
"total_providers": len(self.providers),
"preserved": preserved,
"added": sorted(added),
"removed": sorted(removed),
}
def get_metrics(self) -> dict:
"""Get metrics for all providers."""
return {

View File

@@ -220,7 +220,7 @@ def create_timmy(
print_response(message, stream).
"""
resolved = _resolve_backend(backend)
size = model_size or "70b"
size = model_size or settings.airllm_model_size
if resolved == "claude":
from timmy.backends import ClaudeBackend

View File

@@ -75,8 +75,6 @@ def create_timmy_serve_app() -> FastAPI:
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("Timmy Serve starting")
app.state.timmy = create_timmy()
logger.info("Timmy agent cached in app state")
yield
logger.info("Timmy Serve shutting down")
@@ -103,7 +101,7 @@ def create_timmy_serve_app() -> FastAPI:
async def serve_chat(request: Request, body: ChatRequest):
"""Process a chat request."""
try:
timmy = request.app.state.timmy
timmy = create_timmy()
result = timmy.run(body.message, stream=False)
response_text = result.content if hasattr(result, "content") else str(result)

View File

@@ -2,7 +2,7 @@
import time
from pathlib import Path
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import yaml
@@ -489,182 +489,30 @@ class TestProviderAvailabilityCheck:
assert router._check_provider_available(provider) is False
def test_check_airllm_installed(self):
"""Test AirLLM when installed."""
router = CascadeRouter(config_path=Path("/nonexistent"))
class TestCascadeRouterReload:
"""Test hot-reload of providers.yaml."""
def test_reload_preserves_metrics(self, tmp_path):
"""Test that reload preserves metrics for existing providers."""
config = {
"providers": [
{
"name": "test-openai",
"type": "openai",
"enabled": True,
"priority": 1,
"api_key": "sk-test",
}
],
}
config_path = tmp_path / "providers.yaml"
config_path.write_text(yaml.dump(config))
router = CascadeRouter(config_path=config_path)
assert len(router.providers) == 1
# Simulate some traffic
router._record_success(router.providers[0], 150.0)
router._record_success(router.providers[0], 250.0)
assert router.providers[0].metrics.total_requests == 2
# Reload
result = router.reload_config()
assert result["total_providers"] == 1
assert result["preserved"] == 1
assert result["added"] == []
assert result["removed"] == []
# Metrics survived
assert router.providers[0].metrics.total_requests == 2
assert router.providers[0].metrics.total_latency_ms == 400.0
def test_reload_preserves_circuit_breaker(self, tmp_path):
"""Test that reload preserves circuit breaker state."""
config = {
"cascade": {"circuit_breaker": {"failure_threshold": 2}},
"providers": [
{
"name": "test-openai",
"type": "openai",
"enabled": True,
"priority": 1,
"api_key": "sk-test",
}
],
}
config_path = tmp_path / "providers.yaml"
config_path.write_text(yaml.dump(config))
router = CascadeRouter(config_path=config_path)
# Open circuit breaker
for _ in range(2):
router._record_failure(router.providers[0])
assert router.providers[0].circuit_state == CircuitState.OPEN
# Reload
router.reload_config()
# Circuit breaker state preserved
assert router.providers[0].circuit_state == CircuitState.OPEN
assert router.providers[0].status == ProviderStatus.UNHEALTHY
def test_reload_detects_added_provider(self, tmp_path):
"""Test that reload detects newly added providers."""
config = {
"providers": [
{
"name": "openai-1",
"type": "openai",
"enabled": True,
"priority": 1,
"api_key": "sk-test",
}
],
}
config_path = tmp_path / "providers.yaml"
config_path.write_text(yaml.dump(config))
router = CascadeRouter(config_path=config_path)
assert len(router.providers) == 1
# Add a second provider to config
config["providers"].append(
{
"name": "anthropic-1",
"type": "anthropic",
"enabled": True,
"priority": 2,
"api_key": "sk-ant-test",
}
provider = Provider(
name="airllm",
type="airllm",
enabled=True,
priority=1,
)
config_path.write_text(yaml.dump(config))
result = router.reload_config()
with patch("importlib.util.find_spec", return_value=MagicMock()):
assert router._check_provider_available(provider) is True
assert result["total_providers"] == 2
assert result["preserved"] == 1
assert result["added"] == ["anthropic-1"]
assert result["removed"] == []
def test_check_airllm_not_installed(self):
"""Test AirLLM when not installed."""
router = CascadeRouter(config_path=Path("/nonexistent"))
def test_reload_detects_removed_provider(self, tmp_path):
"""Test that reload detects removed providers."""
config = {
"providers": [
{
"name": "openai-1",
"type": "openai",
"enabled": True,
"priority": 1,
"api_key": "sk-test",
},
{
"name": "anthropic-1",
"type": "anthropic",
"enabled": True,
"priority": 2,
"api_key": "sk-ant-test",
},
],
}
config_path = tmp_path / "providers.yaml"
config_path.write_text(yaml.dump(config))
provider = Provider(
name="airllm",
type="airllm",
enabled=True,
priority=1,
)
router = CascadeRouter(config_path=config_path)
assert len(router.providers) == 2
# Remove anthropic
config["providers"] = [config["providers"][0]]
config_path.write_text(yaml.dump(config))
result = router.reload_config()
assert result["total_providers"] == 1
assert result["preserved"] == 1
assert result["removed"] == ["anthropic-1"]
def test_reload_re_sorts_by_priority(self, tmp_path):
"""Test that providers are re-sorted by priority after reload."""
config = {
"providers": [
{
"name": "low-priority",
"type": "openai",
"enabled": True,
"priority": 10,
"api_key": "sk-test",
},
{
"name": "high-priority",
"type": "openai",
"enabled": True,
"priority": 1,
"api_key": "sk-test2",
},
],
}
config_path = tmp_path / "providers.yaml"
config_path.write_text(yaml.dump(config))
router = CascadeRouter(config_path=config_path)
assert router.providers[0].name == "high-priority"
# Swap priorities
config["providers"][0]["priority"] = 1
config["providers"][1]["priority"] = 10
config_path.write_text(yaml.dump(config))
router.reload_config()
assert router.providers[0].name == "low-priority"
assert router.providers[1].name == "high-priority"
with patch("importlib.util.find_spec", return_value=None):
assert router._check_provider_available(provider) is False

View File

@@ -49,34 +49,6 @@ class TestConfigLazyValidation:
# Should not raise
validate_startup(force=True)
def test_validate_startup_exits_on_cors_wildcard_in_production(self):
"""validate_startup() should exit in production when CORS has wildcard."""
from config import settings, validate_startup
with (
patch.object(settings, "timmy_env", "production"),
patch.object(settings, "l402_hmac_secret", "test-secret-hex-value-32"),
patch.object(settings, "l402_macaroon_secret", "test-macaroon-hex-value-32"),
patch.object(settings, "cors_origins", ["*"]),
pytest.raises(SystemExit),
):
validate_startup(force=True)
def test_validate_startup_warns_cors_wildcard_in_dev(self):
"""validate_startup() should warn in dev when CORS has wildcard."""
from config import settings, validate_startup
with (
patch.object(settings, "timmy_env", "development"),
patch.object(settings, "cors_origins", ["*"]),
patch("config._startup_logger") as mock_logger,
):
validate_startup(force=True)
mock_logger.warning.assert_any_call(
"SEC: CORS_ORIGINS contains wildcard '*'"
"restrict to explicit origins before deploying to production."
)
def test_validate_startup_skips_in_test_mode(self):
"""validate_startup() should be a no-op in test mode."""
from config import validate_startup

View File

@@ -8,14 +8,11 @@ from fastapi.testclient import TestClient
@pytest.fixture
def serve_client():
"""Create a TestClient for the timmy-serve app with mocked Timmy agent."""
with patch("timmy_serve.app.create_timmy") as mock_create:
mock_create.return_value = MagicMock()
from timmy_serve.app import create_timmy_serve_app
"""Create a TestClient for the timmy-serve app."""
from timmy_serve.app import create_timmy_serve_app
app = create_timmy_serve_app()
with TestClient(app) as client:
yield client
app = create_timmy_serve_app()
return TestClient(app)
class TestHealthEndpoint:
@@ -37,40 +34,18 @@ class TestServeStatus:
class TestServeChatEndpoint:
@patch("timmy_serve.app.create_timmy")
def test_chat_returns_response(self, mock_create):
def test_chat_returns_response(self, mock_create, serve_client):
mock_agent = MagicMock()
mock_result = MagicMock()
mock_result.content = "I am Timmy."
mock_agent.run.return_value = mock_result
mock_create.return_value = mock_agent
from timmy_serve.app import create_timmy_serve_app
app = create_timmy_serve_app()
with TestClient(app) as client:
resp = client.post(
"/serve/chat",
json={"message": "Who are you?"},
)
resp = serve_client.post(
"/serve/chat",
json={"message": "Who are you?"},
)
assert resp.status_code == 200
data = resp.json()
assert data["response"] == "I am Timmy."
mock_agent.run.assert_called_once_with("Who are you?", stream=False)
@patch("timmy_serve.app.create_timmy")
def test_agent_cached_at_startup(self, mock_create):
"""Verify create_timmy is called once at startup, not per request."""
mock_agent = MagicMock()
mock_result = MagicMock()
mock_result.content = "reply"
mock_agent.run.return_value = mock_result
mock_create.return_value = mock_agent
from timmy_serve.app import create_timmy_serve_app
app = create_timmy_serve_app()
with TestClient(app) as client:
# Two requests — create_timmy should only be called once (at startup)
client.post("/serve/chat", json={"message": "hello"})
client.post("/serve/chat", json={"message": "world"})
mock_create.assert_called_once()