Compare commits
1 Commits
kimi/issue
...
kimi/issue
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9578330c87 |
@@ -54,11 +54,25 @@ providers:
|
|||||||
context_window: 2048
|
context_window: 2048
|
||||||
capabilities: [text, vision, streaming]
|
capabilities: [text, vision, streaming]
|
||||||
|
|
||||||
# Secondary: OpenAI (if API key available)
|
# Secondary: Local AirLLM (if installed)
|
||||||
|
- name: airllm-local
|
||||||
|
type: airllm
|
||||||
|
enabled: false # Enable if pip install airllm
|
||||||
|
priority: 2
|
||||||
|
models:
|
||||||
|
- name: 70b
|
||||||
|
default: true
|
||||||
|
capabilities: [text, tools, json, streaming]
|
||||||
|
- name: 8b
|
||||||
|
capabilities: [text, tools, json, streaming]
|
||||||
|
- name: 405b
|
||||||
|
capabilities: [text, tools, json, streaming]
|
||||||
|
|
||||||
|
# Tertiary: OpenAI (if API key available)
|
||||||
- name: openai-backup
|
- name: openai-backup
|
||||||
type: openai
|
type: openai
|
||||||
enabled: false # Enable by setting OPENAI_API_KEY
|
enabled: false # Enable by setting OPENAI_API_KEY
|
||||||
priority: 2
|
priority: 3
|
||||||
api_key: "${OPENAI_API_KEY}" # Loaded from environment
|
api_key: "${OPENAI_API_KEY}" # Loaded from environment
|
||||||
base_url: null # Use default OpenAI endpoint
|
base_url: null # Use default OpenAI endpoint
|
||||||
models:
|
models:
|
||||||
@@ -70,11 +84,11 @@ providers:
|
|||||||
context_window: 128000
|
context_window: 128000
|
||||||
capabilities: [text, vision, tools, json, streaming]
|
capabilities: [text, vision, tools, json, streaming]
|
||||||
|
|
||||||
# Tertiary: Anthropic (if API key available)
|
# Quaternary: Anthropic (if API key available)
|
||||||
- name: anthropic-backup
|
- name: anthropic-backup
|
||||||
type: anthropic
|
type: anthropic
|
||||||
enabled: false # Enable by setting ANTHROPIC_API_KEY
|
enabled: false # Enable by setting ANTHROPIC_API_KEY
|
||||||
priority: 3
|
priority: 4
|
||||||
api_key: "${ANTHROPIC_API_KEY}"
|
api_key: "${ANTHROPIC_API_KEY}"
|
||||||
models:
|
models:
|
||||||
- name: claude-3-haiku-20240307
|
- name: claude-3-haiku-20240307
|
||||||
|
|||||||
@@ -64,10 +64,17 @@ class Settings(BaseSettings):
|
|||||||
# Seconds to wait for user confirmation before auto-rejecting.
|
# Seconds to wait for user confirmation before auto-rejecting.
|
||||||
discord_confirm_timeout: int = 120
|
discord_confirm_timeout: int = 120
|
||||||
|
|
||||||
# ── Backend selection ────────────────────────────────────────────────────
|
# ── AirLLM / backend selection ───────────────────────────────────────────
|
||||||
# "ollama" — always use Ollama (default, safe everywhere)
|
# "ollama" — always use Ollama (default, safe everywhere)
|
||||||
# "auto" — pick best available local backend, fall back to Ollama
|
# "airllm" — always use AirLLM (requires pip install ".[bigbrain]")
|
||||||
timmy_model_backend: Literal["ollama", "grok", "claude", "auto"] = "ollama"
|
# "auto" — use AirLLM on Apple Silicon if airllm is installed,
|
||||||
|
# fall back to Ollama otherwise
|
||||||
|
timmy_model_backend: Literal["ollama", "airllm", "grok", "claude", "auto"] = "ollama"
|
||||||
|
|
||||||
|
# AirLLM model size when backend is airllm or auto.
|
||||||
|
# Larger = smarter, but needs more RAM / disk.
|
||||||
|
# 8b ~16 GB | 70b ~140 GB | 405b ~810 GB
|
||||||
|
airllm_model_size: Literal["8b", "70b", "405b"] = "70b"
|
||||||
|
|
||||||
# ── Grok (xAI) — opt-in premium cloud backend ────────────────────────
|
# ── Grok (xAI) — opt-in premium cloud backend ────────────────────────
|
||||||
# Grok is a premium augmentation layer — local-first ethos preserved.
|
# Grok is a premium augmentation layer — local-first ethos preserved.
|
||||||
@@ -469,19 +476,8 @@ def validate_startup(*, force: bool = False) -> None:
|
|||||||
", ".join(_missing),
|
", ".join(_missing),
|
||||||
)
|
)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
if "*" in settings.cors_origins:
|
|
||||||
_startup_logger.error(
|
|
||||||
"PRODUCTION SECURITY ERROR: CORS wildcard '*' is not allowed "
|
|
||||||
"in production. Set CORS_ORIGINS to explicit origins."
|
|
||||||
)
|
|
||||||
sys.exit(1)
|
|
||||||
_startup_logger.info("Production mode: security secrets validated ✓")
|
_startup_logger.info("Production mode: security secrets validated ✓")
|
||||||
else:
|
else:
|
||||||
if "*" in settings.cors_origins:
|
|
||||||
_startup_logger.warning(
|
|
||||||
"SEC: CORS_ORIGINS contains wildcard '*' — "
|
|
||||||
"restrict to explicit origins before deploying to production."
|
|
||||||
)
|
|
||||||
if not settings.l402_hmac_secret:
|
if not settings.l402_hmac_secret:
|
||||||
_startup_logger.warning(
|
_startup_logger.warning(
|
||||||
"SEC: L402_HMAC_SECRET is not set — "
|
"SEC: L402_HMAC_SECRET is not set — "
|
||||||
|
|||||||
@@ -486,12 +486,17 @@ app = FastAPI(
|
|||||||
def _get_cors_origins() -> list[str]:
|
def _get_cors_origins() -> list[str]:
|
||||||
"""Get CORS origins from settings, rejecting wildcards in production."""
|
"""Get CORS origins from settings, rejecting wildcards in production."""
|
||||||
origins = settings.cors_origins
|
origins = settings.cors_origins
|
||||||
if "*" in origins and not settings.debug:
|
if not settings.debug and "*" in origins:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Wildcard '*' in CORS_ORIGINS stripped in production — "
|
"Wildcard '*' in CORS_ORIGINS ignored in production — "
|
||||||
"set explicit origins via CORS_ORIGINS env var"
|
"set explicit origins via CORS_ORIGINS env var"
|
||||||
)
|
)
|
||||||
origins = [o for o in origins if o != "*"]
|
origins = [o for o in origins if o != "*"]
|
||||||
|
if not origins:
|
||||||
|
origins = [
|
||||||
|
"http://localhost:3000",
|
||||||
|
"http://localhost:8000",
|
||||||
|
]
|
||||||
return origins
|
return origins
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -183,22 +183,6 @@ async def run_health_check(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/reload")
|
|
||||||
async def reload_config(
|
|
||||||
cascade: Annotated[CascadeRouter, Depends(get_cascade_router)],
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Hot-reload providers.yaml without restart.
|
|
||||||
|
|
||||||
Preserves circuit breaker state and metrics for existing providers.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = cascade.reload_config()
|
|
||||||
return {"status": "ok", **result}
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("Config reload failed: %s", exc)
|
|
||||||
raise HTTPException(status_code=500, detail=f"Reload failed: {exc}") from exc
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/config")
|
@router.get("/config")
|
||||||
async def get_config(
|
async def get_config(
|
||||||
cascade: Annotated[CascadeRouter, Depends(get_cascade_router)],
|
cascade: Annotated[CascadeRouter, Depends(get_cascade_router)],
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ class Provider:
|
|||||||
"""LLM provider configuration and state."""
|
"""LLM provider configuration and state."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
type: str # ollama, openai, anthropic
|
type: str # ollama, openai, anthropic, airllm
|
||||||
enabled: bool
|
enabled: bool
|
||||||
priority: int
|
priority: int
|
||||||
url: str | None = None
|
url: str | None = None
|
||||||
@@ -308,6 +308,15 @@ class CascadeRouter:
|
|||||||
logger.debug("Ollama provider check error: %s", exc)
|
logger.debug("Ollama provider check error: %s", exc)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
elif provider.type == "airllm":
|
||||||
|
# Check if airllm is installed
|
||||||
|
try:
|
||||||
|
import importlib.util
|
||||||
|
|
||||||
|
return importlib.util.find_spec("airllm") is not None
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
return False
|
||||||
|
|
||||||
elif provider.type in ("openai", "anthropic", "grok"):
|
elif provider.type in ("openai", "anthropic", "grok"):
|
||||||
# Check if API key is set
|
# Check if API key is set
|
||||||
return provider.api_key is not None and provider.api_key != ""
|
return provider.api_key is not None and provider.api_key != ""
|
||||||
@@ -806,66 +815,6 @@ class CascadeRouter:
|
|||||||
provider.status = ProviderStatus.HEALTHY
|
provider.status = ProviderStatus.HEALTHY
|
||||||
logger.info("Circuit breaker CLOSED for %s", provider.name)
|
logger.info("Circuit breaker CLOSED for %s", provider.name)
|
||||||
|
|
||||||
def reload_config(self) -> dict:
|
|
||||||
"""Hot-reload providers.yaml, preserving runtime state.
|
|
||||||
|
|
||||||
Re-reads the config file, rebuilds the provider list, and
|
|
||||||
preserves circuit breaker state and metrics for providers
|
|
||||||
that still exist after reload.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Summary dict with added/removed/preserved counts.
|
|
||||||
"""
|
|
||||||
# Snapshot current runtime state keyed by provider name
|
|
||||||
old_state: dict[
|
|
||||||
str, tuple[ProviderMetrics, CircuitState, float | None, int, ProviderStatus]
|
|
||||||
] = {}
|
|
||||||
for p in self.providers:
|
|
||||||
old_state[p.name] = (
|
|
||||||
p.metrics,
|
|
||||||
p.circuit_state,
|
|
||||||
p.circuit_opened_at,
|
|
||||||
p.half_open_calls,
|
|
||||||
p.status,
|
|
||||||
)
|
|
||||||
|
|
||||||
old_names = set(old_state.keys())
|
|
||||||
|
|
||||||
# Reload from disk
|
|
||||||
self.providers = []
|
|
||||||
self._load_config()
|
|
||||||
|
|
||||||
# Restore preserved state
|
|
||||||
new_names = {p.name for p in self.providers}
|
|
||||||
preserved = 0
|
|
||||||
for p in self.providers:
|
|
||||||
if p.name in old_state:
|
|
||||||
metrics, circuit, opened_at, half_open, status = old_state[p.name]
|
|
||||||
p.metrics = metrics
|
|
||||||
p.circuit_state = circuit
|
|
||||||
p.circuit_opened_at = opened_at
|
|
||||||
p.half_open_calls = half_open
|
|
||||||
p.status = status
|
|
||||||
preserved += 1
|
|
||||||
|
|
||||||
added = new_names - old_names
|
|
||||||
removed = old_names - new_names
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Config reloaded: %d providers (%d preserved, %d added, %d removed)",
|
|
||||||
len(self.providers),
|
|
||||||
preserved,
|
|
||||||
len(added),
|
|
||||||
len(removed),
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total_providers": len(self.providers),
|
|
||||||
"preserved": preserved,
|
|
||||||
"added": sorted(added),
|
|
||||||
"removed": sorted(removed),
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_metrics(self) -> dict:
|
def get_metrics(self) -> dict:
|
||||||
"""Get metrics for all providers."""
|
"""Get metrics for all providers."""
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -220,7 +220,7 @@ def create_timmy(
|
|||||||
print_response(message, stream).
|
print_response(message, stream).
|
||||||
"""
|
"""
|
||||||
resolved = _resolve_backend(backend)
|
resolved = _resolve_backend(backend)
|
||||||
size = model_size or "70b"
|
size = model_size or settings.airllm_model_size
|
||||||
|
|
||||||
if resolved == "claude":
|
if resolved == "claude":
|
||||||
from timmy.backends import ClaudeBackend
|
from timmy.backends import ClaudeBackend
|
||||||
|
|||||||
@@ -75,8 +75,6 @@ def create_timmy_serve_app() -> FastAPI:
|
|||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
logger.info("Timmy Serve starting")
|
logger.info("Timmy Serve starting")
|
||||||
app.state.timmy = create_timmy()
|
|
||||||
logger.info("Timmy agent cached in app state")
|
|
||||||
yield
|
yield
|
||||||
logger.info("Timmy Serve shutting down")
|
logger.info("Timmy Serve shutting down")
|
||||||
|
|
||||||
@@ -103,7 +101,7 @@ def create_timmy_serve_app() -> FastAPI:
|
|||||||
async def serve_chat(request: Request, body: ChatRequest):
|
async def serve_chat(request: Request, body: ChatRequest):
|
||||||
"""Process a chat request."""
|
"""Process a chat request."""
|
||||||
try:
|
try:
|
||||||
timmy = request.app.state.timmy
|
timmy = create_timmy()
|
||||||
result = timmy.run(body.message, stream=False)
|
result = timmy.run(body.message, stream=False)
|
||||||
response_text = result.content if hasattr(result, "content") else str(result)
|
response_text = result.content if hasattr(result, "content") else str(result)
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
@@ -489,182 +489,30 @@ class TestProviderAvailabilityCheck:
|
|||||||
|
|
||||||
assert router._check_provider_available(provider) is False
|
assert router._check_provider_available(provider) is False
|
||||||
|
|
||||||
|
def test_check_airllm_installed(self):
|
||||||
|
"""Test AirLLM when installed."""
|
||||||
|
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||||
|
|
||||||
class TestCascadeRouterReload:
|
provider = Provider(
|
||||||
"""Test hot-reload of providers.yaml."""
|
name="airllm",
|
||||||
|
type="airllm",
|
||||||
def test_reload_preserves_metrics(self, tmp_path):
|
enabled=True,
|
||||||
"""Test that reload preserves metrics for existing providers."""
|
priority=1,
|
||||||
config = {
|
|
||||||
"providers": [
|
|
||||||
{
|
|
||||||
"name": "test-openai",
|
|
||||||
"type": "openai",
|
|
||||||
"enabled": True,
|
|
||||||
"priority": 1,
|
|
||||||
"api_key": "sk-test",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
config_path = tmp_path / "providers.yaml"
|
|
||||||
config_path.write_text(yaml.dump(config))
|
|
||||||
|
|
||||||
router = CascadeRouter(config_path=config_path)
|
|
||||||
assert len(router.providers) == 1
|
|
||||||
|
|
||||||
# Simulate some traffic
|
|
||||||
router._record_success(router.providers[0], 150.0)
|
|
||||||
router._record_success(router.providers[0], 250.0)
|
|
||||||
assert router.providers[0].metrics.total_requests == 2
|
|
||||||
|
|
||||||
# Reload
|
|
||||||
result = router.reload_config()
|
|
||||||
|
|
||||||
assert result["total_providers"] == 1
|
|
||||||
assert result["preserved"] == 1
|
|
||||||
assert result["added"] == []
|
|
||||||
assert result["removed"] == []
|
|
||||||
# Metrics survived
|
|
||||||
assert router.providers[0].metrics.total_requests == 2
|
|
||||||
assert router.providers[0].metrics.total_latency_ms == 400.0
|
|
||||||
|
|
||||||
def test_reload_preserves_circuit_breaker(self, tmp_path):
|
|
||||||
"""Test that reload preserves circuit breaker state."""
|
|
||||||
config = {
|
|
||||||
"cascade": {"circuit_breaker": {"failure_threshold": 2}},
|
|
||||||
"providers": [
|
|
||||||
{
|
|
||||||
"name": "test-openai",
|
|
||||||
"type": "openai",
|
|
||||||
"enabled": True,
|
|
||||||
"priority": 1,
|
|
||||||
"api_key": "sk-test",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
config_path = tmp_path / "providers.yaml"
|
|
||||||
config_path.write_text(yaml.dump(config))
|
|
||||||
|
|
||||||
router = CascadeRouter(config_path=config_path)
|
|
||||||
|
|
||||||
# Open circuit breaker
|
|
||||||
for _ in range(2):
|
|
||||||
router._record_failure(router.providers[0])
|
|
||||||
assert router.providers[0].circuit_state == CircuitState.OPEN
|
|
||||||
|
|
||||||
# Reload
|
|
||||||
router.reload_config()
|
|
||||||
|
|
||||||
# Circuit breaker state preserved
|
|
||||||
assert router.providers[0].circuit_state == CircuitState.OPEN
|
|
||||||
assert router.providers[0].status == ProviderStatus.UNHEALTHY
|
|
||||||
|
|
||||||
def test_reload_detects_added_provider(self, tmp_path):
|
|
||||||
"""Test that reload detects newly added providers."""
|
|
||||||
config = {
|
|
||||||
"providers": [
|
|
||||||
{
|
|
||||||
"name": "openai-1",
|
|
||||||
"type": "openai",
|
|
||||||
"enabled": True,
|
|
||||||
"priority": 1,
|
|
||||||
"api_key": "sk-test",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
config_path = tmp_path / "providers.yaml"
|
|
||||||
config_path.write_text(yaml.dump(config))
|
|
||||||
|
|
||||||
router = CascadeRouter(config_path=config_path)
|
|
||||||
assert len(router.providers) == 1
|
|
||||||
|
|
||||||
# Add a second provider to config
|
|
||||||
config["providers"].append(
|
|
||||||
{
|
|
||||||
"name": "anthropic-1",
|
|
||||||
"type": "anthropic",
|
|
||||||
"enabled": True,
|
|
||||||
"priority": 2,
|
|
||||||
"api_key": "sk-ant-test",
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
config_path.write_text(yaml.dump(config))
|
|
||||||
|
|
||||||
result = router.reload_config()
|
with patch("importlib.util.find_spec", return_value=MagicMock()):
|
||||||
|
assert router._check_provider_available(provider) is True
|
||||||
|
|
||||||
assert result["total_providers"] == 2
|
def test_check_airllm_not_installed(self):
|
||||||
assert result["preserved"] == 1
|
"""Test AirLLM when not installed."""
|
||||||
assert result["added"] == ["anthropic-1"]
|
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||||
assert result["removed"] == []
|
|
||||||
|
|
||||||
def test_reload_detects_removed_provider(self, tmp_path):
|
provider = Provider(
|
||||||
"""Test that reload detects removed providers."""
|
name="airllm",
|
||||||
config = {
|
type="airllm",
|
||||||
"providers": [
|
enabled=True,
|
||||||
{
|
priority=1,
|
||||||
"name": "openai-1",
|
)
|
||||||
"type": "openai",
|
|
||||||
"enabled": True,
|
|
||||||
"priority": 1,
|
|
||||||
"api_key": "sk-test",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "anthropic-1",
|
|
||||||
"type": "anthropic",
|
|
||||||
"enabled": True,
|
|
||||||
"priority": 2,
|
|
||||||
"api_key": "sk-ant-test",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
config_path = tmp_path / "providers.yaml"
|
|
||||||
config_path.write_text(yaml.dump(config))
|
|
||||||
|
|
||||||
router = CascadeRouter(config_path=config_path)
|
with patch("importlib.util.find_spec", return_value=None):
|
||||||
assert len(router.providers) == 2
|
assert router._check_provider_available(provider) is False
|
||||||
|
|
||||||
# Remove anthropic
|
|
||||||
config["providers"] = [config["providers"][0]]
|
|
||||||
config_path.write_text(yaml.dump(config))
|
|
||||||
|
|
||||||
result = router.reload_config()
|
|
||||||
|
|
||||||
assert result["total_providers"] == 1
|
|
||||||
assert result["preserved"] == 1
|
|
||||||
assert result["removed"] == ["anthropic-1"]
|
|
||||||
|
|
||||||
def test_reload_re_sorts_by_priority(self, tmp_path):
|
|
||||||
"""Test that providers are re-sorted by priority after reload."""
|
|
||||||
config = {
|
|
||||||
"providers": [
|
|
||||||
{
|
|
||||||
"name": "low-priority",
|
|
||||||
"type": "openai",
|
|
||||||
"enabled": True,
|
|
||||||
"priority": 10,
|
|
||||||
"api_key": "sk-test",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "high-priority",
|
|
||||||
"type": "openai",
|
|
||||||
"enabled": True,
|
|
||||||
"priority": 1,
|
|
||||||
"api_key": "sk-test2",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
config_path = tmp_path / "providers.yaml"
|
|
||||||
config_path.write_text(yaml.dump(config))
|
|
||||||
|
|
||||||
router = CascadeRouter(config_path=config_path)
|
|
||||||
assert router.providers[0].name == "high-priority"
|
|
||||||
|
|
||||||
# Swap priorities
|
|
||||||
config["providers"][0]["priority"] = 1
|
|
||||||
config["providers"][1]["priority"] = 10
|
|
||||||
config_path.write_text(yaml.dump(config))
|
|
||||||
|
|
||||||
router.reload_config()
|
|
||||||
|
|
||||||
assert router.providers[0].name == "low-priority"
|
|
||||||
assert router.providers[1].name == "high-priority"
|
|
||||||
|
|||||||
@@ -49,34 +49,6 @@ class TestConfigLazyValidation:
|
|||||||
# Should not raise
|
# Should not raise
|
||||||
validate_startup(force=True)
|
validate_startup(force=True)
|
||||||
|
|
||||||
def test_validate_startup_exits_on_cors_wildcard_in_production(self):
|
|
||||||
"""validate_startup() should exit in production when CORS has wildcard."""
|
|
||||||
from config import settings, validate_startup
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(settings, "timmy_env", "production"),
|
|
||||||
patch.object(settings, "l402_hmac_secret", "test-secret-hex-value-32"),
|
|
||||||
patch.object(settings, "l402_macaroon_secret", "test-macaroon-hex-value-32"),
|
|
||||||
patch.object(settings, "cors_origins", ["*"]),
|
|
||||||
pytest.raises(SystemExit),
|
|
||||||
):
|
|
||||||
validate_startup(force=True)
|
|
||||||
|
|
||||||
def test_validate_startup_warns_cors_wildcard_in_dev(self):
|
|
||||||
"""validate_startup() should warn in dev when CORS has wildcard."""
|
|
||||||
from config import settings, validate_startup
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(settings, "timmy_env", "development"),
|
|
||||||
patch.object(settings, "cors_origins", ["*"]),
|
|
||||||
patch("config._startup_logger") as mock_logger,
|
|
||||||
):
|
|
||||||
validate_startup(force=True)
|
|
||||||
mock_logger.warning.assert_any_call(
|
|
||||||
"SEC: CORS_ORIGINS contains wildcard '*' — "
|
|
||||||
"restrict to explicit origins before deploying to production."
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_validate_startup_skips_in_test_mode(self):
|
def test_validate_startup_skips_in_test_mode(self):
|
||||||
"""validate_startup() should be a no-op in test mode."""
|
"""validate_startup() should be a no-op in test mode."""
|
||||||
from config import validate_startup
|
from config import validate_startup
|
||||||
|
|||||||
@@ -8,14 +8,11 @@ from fastapi.testclient import TestClient
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def serve_client():
|
def serve_client():
|
||||||
"""Create a TestClient for the timmy-serve app with mocked Timmy agent."""
|
"""Create a TestClient for the timmy-serve app."""
|
||||||
with patch("timmy_serve.app.create_timmy") as mock_create:
|
from timmy_serve.app import create_timmy_serve_app
|
||||||
mock_create.return_value = MagicMock()
|
|
||||||
from timmy_serve.app import create_timmy_serve_app
|
|
||||||
|
|
||||||
app = create_timmy_serve_app()
|
app = create_timmy_serve_app()
|
||||||
with TestClient(app) as client:
|
return TestClient(app)
|
||||||
yield client
|
|
||||||
|
|
||||||
|
|
||||||
class TestHealthEndpoint:
|
class TestHealthEndpoint:
|
||||||
@@ -37,40 +34,18 @@ class TestServeStatus:
|
|||||||
|
|
||||||
class TestServeChatEndpoint:
|
class TestServeChatEndpoint:
|
||||||
@patch("timmy_serve.app.create_timmy")
|
@patch("timmy_serve.app.create_timmy")
|
||||||
def test_chat_returns_response(self, mock_create):
|
def test_chat_returns_response(self, mock_create, serve_client):
|
||||||
mock_agent = MagicMock()
|
mock_agent = MagicMock()
|
||||||
mock_result = MagicMock()
|
mock_result = MagicMock()
|
||||||
mock_result.content = "I am Timmy."
|
mock_result.content = "I am Timmy."
|
||||||
mock_agent.run.return_value = mock_result
|
mock_agent.run.return_value = mock_result
|
||||||
mock_create.return_value = mock_agent
|
mock_create.return_value = mock_agent
|
||||||
|
|
||||||
from timmy_serve.app import create_timmy_serve_app
|
resp = serve_client.post(
|
||||||
|
"/serve/chat",
|
||||||
app = create_timmy_serve_app()
|
json={"message": "Who are you?"},
|
||||||
with TestClient(app) as client:
|
)
|
||||||
resp = client.post(
|
|
||||||
"/serve/chat",
|
|
||||||
json={"message": "Who are you?"},
|
|
||||||
)
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
assert data["response"] == "I am Timmy."
|
assert data["response"] == "I am Timmy."
|
||||||
mock_agent.run.assert_called_once_with("Who are you?", stream=False)
|
mock_agent.run.assert_called_once_with("Who are you?", stream=False)
|
||||||
|
|
||||||
@patch("timmy_serve.app.create_timmy")
|
|
||||||
def test_agent_cached_at_startup(self, mock_create):
|
|
||||||
"""Verify create_timmy is called once at startup, not per request."""
|
|
||||||
mock_agent = MagicMock()
|
|
||||||
mock_result = MagicMock()
|
|
||||||
mock_result.content = "reply"
|
|
||||||
mock_agent.run.return_value = mock_result
|
|
||||||
mock_create.return_value = mock_agent
|
|
||||||
|
|
||||||
from timmy_serve.app import create_timmy_serve_app
|
|
||||||
|
|
||||||
app = create_timmy_serve_app()
|
|
||||||
with TestClient(app) as client:
|
|
||||||
# Two requests — create_timmy should only be called once (at startup)
|
|
||||||
client.post("/serve/chat", json={"message": "hello"})
|
|
||||||
client.post("/serve/chat", json={"message": "world"})
|
|
||||||
mock_create.assert_called_once()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user