forked from Rockachopa/Timmy-time-dashboard
Compare commits
7 Commits
chore/migr
...
fix/csrf-c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
59e789e2da | ||
| d2a5866650 | |||
| 2381d0b6d0 | |||
| 03ad2027a4 | |||
| 2bfc44ea1b | |||
| fe1fa78ef1 | |||
| 3c46a1b202 |
@@ -84,6 +84,7 @@ class Settings(BaseSettings):
|
||||
# Only used when explicitly enabled and query complexity warrants it.
|
||||
grok_enabled: bool = False
|
||||
xai_api_key: str = ""
|
||||
xai_base_url: str = "https://api.x.ai/v1"
|
||||
grok_default_model: str = "grok-3-fast"
|
||||
grok_max_sats_per_query: int = 200
|
||||
grok_free: bool = False # Skip Lightning invoice when user has own API key
|
||||
|
||||
@@ -377,73 +377,78 @@ def _startup_background_tasks() -> list[asyncio.Task]:
|
||||
]
|
||||
|
||||
|
||||
def _try_prune(label: str, prune_fn, days: int) -> None:
|
||||
"""Run a prune function, log results, swallow errors."""
|
||||
try:
|
||||
pruned = prune_fn()
|
||||
if pruned:
|
||||
logger.info(
|
||||
"%s auto-prune: removed %d entries older than %d days",
|
||||
label,
|
||||
pruned,
|
||||
days,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("%s auto-prune skipped: %s", label, exc)
|
||||
|
||||
|
||||
def _check_vault_size() -> None:
|
||||
"""Warn if the memory vault exceeds the configured size limit."""
|
||||
try:
|
||||
vault_path = Path(settings.repo_root) / "memory" / "notes"
|
||||
if vault_path.exists():
|
||||
total_bytes = sum(f.stat().st_size for f in vault_path.rglob("*") if f.is_file())
|
||||
total_mb = total_bytes / (1024 * 1024)
|
||||
if total_mb > settings.memory_vault_max_mb:
|
||||
logger.warning(
|
||||
"Memory vault (%.1f MB) exceeds limit (%d MB) — consider archiving old notes",
|
||||
total_mb,
|
||||
settings.memory_vault_max_mb,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("Vault size check skipped: %s", exc)
|
||||
|
||||
|
||||
def _startup_pruning() -> None:
|
||||
"""Auto-prune old memories, thoughts, and events on startup."""
|
||||
if settings.memory_prune_days > 0:
|
||||
try:
|
||||
from timmy.memory_system import prune_memories
|
||||
from timmy.memory_system import prune_memories
|
||||
|
||||
pruned = prune_memories(
|
||||
_try_prune(
|
||||
"Memory",
|
||||
lambda: prune_memories(
|
||||
older_than_days=settings.memory_prune_days,
|
||||
keep_facts=settings.memory_prune_keep_facts,
|
||||
)
|
||||
if pruned:
|
||||
logger.info(
|
||||
"Memory auto-prune: removed %d entries older than %d days",
|
||||
pruned,
|
||||
settings.memory_prune_days,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("Memory auto-prune skipped: %s", exc)
|
||||
),
|
||||
settings.memory_prune_days,
|
||||
)
|
||||
|
||||
if settings.thoughts_prune_days > 0:
|
||||
try:
|
||||
from timmy.thinking import thinking_engine
|
||||
from timmy.thinking import thinking_engine
|
||||
|
||||
pruned = thinking_engine.prune_old_thoughts(
|
||||
_try_prune(
|
||||
"Thought",
|
||||
lambda: thinking_engine.prune_old_thoughts(
|
||||
keep_days=settings.thoughts_prune_days,
|
||||
keep_min=settings.thoughts_prune_keep_min,
|
||||
)
|
||||
if pruned:
|
||||
logger.info(
|
||||
"Thought auto-prune: removed %d entries older than %d days",
|
||||
pruned,
|
||||
settings.thoughts_prune_days,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("Thought auto-prune skipped: %s", exc)
|
||||
),
|
||||
settings.thoughts_prune_days,
|
||||
)
|
||||
|
||||
if settings.events_prune_days > 0:
|
||||
try:
|
||||
from swarm.event_log import prune_old_events
|
||||
from swarm.event_log import prune_old_events
|
||||
|
||||
pruned = prune_old_events(
|
||||
_try_prune(
|
||||
"Event",
|
||||
lambda: prune_old_events(
|
||||
keep_days=settings.events_prune_days,
|
||||
keep_min=settings.events_prune_keep_min,
|
||||
)
|
||||
if pruned:
|
||||
logger.info(
|
||||
"Event auto-prune: removed %d entries older than %d days",
|
||||
pruned,
|
||||
settings.events_prune_days,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("Event auto-prune skipped: %s", exc)
|
||||
),
|
||||
settings.events_prune_days,
|
||||
)
|
||||
|
||||
if settings.memory_vault_max_mb > 0:
|
||||
try:
|
||||
vault_path = Path(settings.repo_root) / "memory" / "notes"
|
||||
if vault_path.exists():
|
||||
total_bytes = sum(f.stat().st_size for f in vault_path.rglob("*") if f.is_file())
|
||||
total_mb = total_bytes / (1024 * 1024)
|
||||
if total_mb > settings.memory_vault_max_mb:
|
||||
logger.warning(
|
||||
"Memory vault (%.1f MB) exceeds limit (%d MB) — consider archiving old notes",
|
||||
total_mb,
|
||||
settings.memory_vault_max_mb,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("Vault size check skipped: %s", exc)
|
||||
_check_vault_size()
|
||||
|
||||
|
||||
async def _shutdown_cleanup(
|
||||
|
||||
@@ -175,18 +175,12 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
return await call_next(request)
|
||||
|
||||
# Token validation failed and path is not exempt
|
||||
# We still need to call the app to check if the endpoint is decorated
|
||||
# with @csrf_exempt, so we'll let it through and check after routing
|
||||
response = await call_next(request)
|
||||
|
||||
# After routing, check if the endpoint is marked as exempt
|
||||
endpoint = request.scope.get("endpoint")
|
||||
# Resolve the endpoint from routes BEFORE executing to avoid side effects
|
||||
endpoint = self._resolve_endpoint(request)
|
||||
if endpoint and is_csrf_exempt(endpoint):
|
||||
# Endpoint is marked as exempt, allow the response
|
||||
return response
|
||||
return await call_next(request)
|
||||
|
||||
# Endpoint is not exempt and token validation failed
|
||||
# Return 403 error
|
||||
# Endpoint is not exempt and token validation failed — reject without executing
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={
|
||||
@@ -196,6 +190,42 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
},
|
||||
)
|
||||
|
||||
def _resolve_endpoint(self, request: Request) -> Callable | None:
|
||||
"""Resolve the endpoint for a request without executing it.
|
||||
|
||||
Walks the app chain to find routes, then matches against the request
|
||||
scope. This allows checking @csrf_exempt before the handler runs
|
||||
(avoiding side effects on CSRF rejection).
|
||||
|
||||
Returns:
|
||||
The endpoint callable if found, None otherwise.
|
||||
"""
|
||||
try:
|
||||
from starlette.routing import Match
|
||||
|
||||
# Walk the middleware/app chain to find something with routes
|
||||
routes = None
|
||||
current = self.app
|
||||
for _ in range(10): # Safety limit
|
||||
routes = getattr(current, "routes", None)
|
||||
if routes:
|
||||
break
|
||||
current = getattr(current, "app", None)
|
||||
if current is None:
|
||||
break
|
||||
|
||||
if not routes:
|
||||
return None
|
||||
|
||||
scope = dict(request.scope)
|
||||
for route in routes:
|
||||
match, child_scope = route.matches(scope)
|
||||
if match == Match.FULL:
|
||||
return child_scope.get("endpoint")
|
||||
except Exception:
|
||||
logger.debug("Failed to resolve endpoint for CSRF check")
|
||||
return None
|
||||
|
||||
def _is_likely_exempt(self, path: str) -> bool:
|
||||
"""Check if a path is likely to be CSRF exempt.
|
||||
|
||||
|
||||
@@ -149,6 +149,52 @@ def _log_error_event(
|
||||
logger.debug("Failed to log error event: %s", log_exc)
|
||||
|
||||
|
||||
def _build_report_description(
|
||||
exc: Exception,
|
||||
source: str,
|
||||
context: dict | None,
|
||||
error_hash: str,
|
||||
tb_str: str,
|
||||
affected_file: str,
|
||||
affected_line: int,
|
||||
git_ctx: dict,
|
||||
) -> str:
|
||||
"""Build the markdown description for a bug report task."""
|
||||
parts = [
|
||||
f"**Error:** {type(exc).__name__}: {str(exc)}",
|
||||
f"**Source:** {source}",
|
||||
f"**File:** {affected_file}:{affected_line}",
|
||||
f"**Git:** {git_ctx.get('branch', '?')} @ {git_ctx.get('commit', '?')}",
|
||||
f"**Time:** {datetime.now(UTC).isoformat()}",
|
||||
f"**Hash:** {error_hash}",
|
||||
]
|
||||
|
||||
if context:
|
||||
ctx_str = ", ".join(f"{k}={v}" for k, v in context.items())
|
||||
parts.append(f"**Context:** {ctx_str}")
|
||||
|
||||
parts.append(f"\n**Stack Trace:**\n```\n{tb_str[:2000]}\n```")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _log_bug_report_created(source: str, task_id: str, error_hash: str, title: str) -> None:
|
||||
"""Log a BUG_REPORT_CREATED event (best-effort)."""
|
||||
try:
|
||||
from swarm.event_log import EventType, log_event
|
||||
|
||||
log_event(
|
||||
EventType.BUG_REPORT_CREATED,
|
||||
source=source,
|
||||
task_id=task_id,
|
||||
data={
|
||||
"error_hash": error_hash,
|
||||
"title": title[:100],
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("Bug report event log error: %s", exc)
|
||||
|
||||
|
||||
def _create_bug_report(
|
||||
exc: Exception,
|
||||
source: str,
|
||||
@@ -164,25 +210,20 @@ def _create_bug_report(
|
||||
from swarm.task_queue.models import create_task
|
||||
|
||||
title = f"[BUG] {type(exc).__name__}: {str(exc)[:80]}"
|
||||
|
||||
description_parts = [
|
||||
f"**Error:** {type(exc).__name__}: {str(exc)}",
|
||||
f"**Source:** {source}",
|
||||
f"**File:** {affected_file}:{affected_line}",
|
||||
f"**Git:** {git_ctx.get('branch', '?')} @ {git_ctx.get('commit', '?')}",
|
||||
f"**Time:** {datetime.now(UTC).isoformat()}",
|
||||
f"**Hash:** {error_hash}",
|
||||
]
|
||||
|
||||
if context:
|
||||
ctx_str = ", ".join(f"{k}={v}" for k, v in context.items())
|
||||
description_parts.append(f"**Context:** {ctx_str}")
|
||||
|
||||
description_parts.append(f"\n**Stack Trace:**\n```\n{tb_str[:2000]}\n```")
|
||||
description = _build_report_description(
|
||||
exc,
|
||||
source,
|
||||
context,
|
||||
error_hash,
|
||||
tb_str,
|
||||
affected_file,
|
||||
affected_line,
|
||||
git_ctx,
|
||||
)
|
||||
|
||||
task = create_task(
|
||||
title=title,
|
||||
description="\n".join(description_parts),
|
||||
description=description,
|
||||
assigned_to="default",
|
||||
created_by="system",
|
||||
priority="normal",
|
||||
@@ -190,24 +231,9 @@ def _create_bug_report(
|
||||
auto_approve=True,
|
||||
task_type="bug_report",
|
||||
)
|
||||
task_id = task.id
|
||||
|
||||
try:
|
||||
from swarm.event_log import EventType, log_event
|
||||
|
||||
log_event(
|
||||
EventType.BUG_REPORT_CREATED,
|
||||
source=source,
|
||||
task_id=task_id,
|
||||
data={
|
||||
"error_hash": error_hash,
|
||||
"title": title[:100],
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("Bug report screenshot error: %s", exc)
|
||||
|
||||
return task_id
|
||||
_log_bug_report_created(source, task.id, error_hash, title)
|
||||
return task.id
|
||||
|
||||
except Exception as task_exc:
|
||||
logger.debug("Failed to create bug report task: %s", task_exc)
|
||||
|
||||
@@ -64,7 +64,7 @@ class EventBus:
|
||||
|
||||
@bus.subscribe("agent.task.*")
|
||||
async def handle_task(event: Event):
|
||||
logger.debug(f"Task event: {event.data}")
|
||||
logger.debug("Task event: %s", event.data)
|
||||
|
||||
await bus.publish(Event(
|
||||
type="agent.task.assigned",
|
||||
|
||||
@@ -221,65 +221,56 @@ class CascadeRouter:
|
||||
raise RuntimeError("PyYAML not installed")
|
||||
|
||||
content = self.config_path.read_text()
|
||||
# Expand environment variables
|
||||
content = self._expand_env_vars(content)
|
||||
data = yaml.safe_load(content)
|
||||
|
||||
# Load cascade settings
|
||||
cascade = data.get("cascade", {})
|
||||
|
||||
# Load fallback chains
|
||||
fallback_chains = data.get("fallback_chains", {})
|
||||
|
||||
# Load multi-modal settings
|
||||
multimodal = data.get("multimodal", {})
|
||||
|
||||
self.config = RouterConfig(
|
||||
timeout_seconds=cascade.get("timeout_seconds", 30),
|
||||
max_retries_per_provider=cascade.get("max_retries_per_provider", 2),
|
||||
retry_delay_seconds=cascade.get("retry_delay_seconds", 1),
|
||||
circuit_breaker_failure_threshold=cascade.get("circuit_breaker", {}).get(
|
||||
"failure_threshold", 5
|
||||
),
|
||||
circuit_breaker_recovery_timeout=cascade.get("circuit_breaker", {}).get(
|
||||
"recovery_timeout", 60
|
||||
),
|
||||
circuit_breaker_half_open_max_calls=cascade.get("circuit_breaker", {}).get(
|
||||
"half_open_max_calls", 2
|
||||
),
|
||||
auto_pull_models=multimodal.get("auto_pull", True),
|
||||
fallback_chains=fallback_chains,
|
||||
)
|
||||
|
||||
# Load providers
|
||||
for p_data in data.get("providers", []):
|
||||
# Skip disabled providers
|
||||
if not p_data.get("enabled", False):
|
||||
continue
|
||||
|
||||
provider = Provider(
|
||||
name=p_data["name"],
|
||||
type=p_data["type"],
|
||||
enabled=p_data.get("enabled", True),
|
||||
priority=p_data.get("priority", 99),
|
||||
url=p_data.get("url"),
|
||||
api_key=p_data.get("api_key"),
|
||||
base_url=p_data.get("base_url"),
|
||||
models=p_data.get("models", []),
|
||||
)
|
||||
|
||||
# Check if provider is actually available
|
||||
if self._check_provider_available(provider):
|
||||
self.providers.append(provider)
|
||||
else:
|
||||
logger.warning("Provider %s not available, skipping", provider.name)
|
||||
|
||||
# Sort by priority
|
||||
self.providers.sort(key=lambda p: p.priority)
|
||||
self.config = self._parse_router_config(data)
|
||||
self._load_providers(data)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Failed to load config: %s", exc)
|
||||
|
||||
def _parse_router_config(self, data: dict) -> RouterConfig:
|
||||
"""Build a RouterConfig from parsed YAML data."""
|
||||
cascade = data.get("cascade", {})
|
||||
cb = cascade.get("circuit_breaker", {})
|
||||
multimodal = data.get("multimodal", {})
|
||||
|
||||
return RouterConfig(
|
||||
timeout_seconds=cascade.get("timeout_seconds", 30),
|
||||
max_retries_per_provider=cascade.get("max_retries_per_provider", 2),
|
||||
retry_delay_seconds=cascade.get("retry_delay_seconds", 1),
|
||||
circuit_breaker_failure_threshold=cb.get("failure_threshold", 5),
|
||||
circuit_breaker_recovery_timeout=cb.get("recovery_timeout", 60),
|
||||
circuit_breaker_half_open_max_calls=cb.get("half_open_max_calls", 2),
|
||||
auto_pull_models=multimodal.get("auto_pull", True),
|
||||
fallback_chains=data.get("fallback_chains", {}),
|
||||
)
|
||||
|
||||
def _load_providers(self, data: dict) -> None:
|
||||
"""Load, filter, and sort providers from parsed YAML data."""
|
||||
for p_data in data.get("providers", []):
|
||||
if not p_data.get("enabled", False):
|
||||
continue
|
||||
|
||||
provider = Provider(
|
||||
name=p_data["name"],
|
||||
type=p_data["type"],
|
||||
enabled=p_data.get("enabled", True),
|
||||
priority=p_data.get("priority", 99),
|
||||
url=p_data.get("url"),
|
||||
api_key=p_data.get("api_key"),
|
||||
base_url=p_data.get("base_url"),
|
||||
models=p_data.get("models", []),
|
||||
)
|
||||
|
||||
if self._check_provider_available(provider):
|
||||
self.providers.append(provider)
|
||||
else:
|
||||
logger.warning("Provider %s not available, skipping", provider.name)
|
||||
|
||||
self.providers.sort(key=lambda p: p.priority)
|
||||
|
||||
def _expand_env_vars(self, content: str) -> str:
|
||||
"""Expand ${VAR} syntax in YAML content.
|
||||
|
||||
@@ -768,7 +759,7 @@ class CascadeRouter:
|
||||
|
||||
client = openai.AsyncOpenAI(
|
||||
api_key=provider.api_key,
|
||||
base_url=provider.base_url or "https://api.x.ai/v1",
|
||||
base_url=provider.base_url or settings.xai_base_url,
|
||||
timeout=httpx.Timeout(300.0),
|
||||
)
|
||||
|
||||
|
||||
@@ -99,23 +99,27 @@ class GrokBackend:
|
||||
|
||||
def _get_client(self):
|
||||
"""Create OpenAI client configured for xAI endpoint."""
|
||||
from config import settings
|
||||
|
||||
import httpx
|
||||
from openai import OpenAI
|
||||
|
||||
return OpenAI(
|
||||
api_key=self._api_key,
|
||||
base_url="https://api.x.ai/v1",
|
||||
base_url=settings.xai_base_url,
|
||||
timeout=httpx.Timeout(300.0),
|
||||
)
|
||||
|
||||
async def _get_async_client(self):
|
||||
"""Create async OpenAI client configured for xAI endpoint."""
|
||||
from config import settings
|
||||
|
||||
import httpx
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
return AsyncOpenAI(
|
||||
api_key=self._api_key,
|
||||
base_url="https://api.x.ai/v1",
|
||||
base_url=settings.xai_base_url,
|
||||
timeout=httpx.Timeout(300.0),
|
||||
)
|
||||
|
||||
|
||||
@@ -46,6 +46,64 @@ DB_PATH = PROJECT_ROOT / "data" / "memory.db"
|
||||
# ───────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
_DEFAULT_HOT_MEMORY_TEMPLATE = """\
|
||||
# Timmy Hot Memory
|
||||
|
||||
> Working RAM — always loaded, ~300 lines max, pruned monthly
|
||||
> Last updated: {date}
|
||||
|
||||
---
|
||||
|
||||
## Current Status
|
||||
|
||||
**Agent State:** Operational
|
||||
**Mode:** Development
|
||||
**Active Tasks:** 0
|
||||
**Pending Decisions:** None
|
||||
|
||||
---
|
||||
|
||||
## Standing Rules
|
||||
|
||||
1. **Sovereignty First** — No cloud dependencies
|
||||
2. **Local-Only Inference** — Ollama on localhost
|
||||
3. **Privacy by Design** — Telemetry disabled
|
||||
4. **Tool Minimalism** — Use tools only when necessary
|
||||
5. **Memory Discipline** — Write handoffs at session end
|
||||
|
||||
---
|
||||
|
||||
## Agent Roster
|
||||
|
||||
| Agent | Role | Status |
|
||||
|-------|------|--------|
|
||||
| Timmy | Core | Active |
|
||||
|
||||
---
|
||||
|
||||
## User Profile
|
||||
|
||||
**Name:** (not set)
|
||||
**Interests:** (to be learned)
|
||||
|
||||
---
|
||||
|
||||
## Key Decisions
|
||||
|
||||
(none yet)
|
||||
|
||||
---
|
||||
|
||||
## Pending Actions
|
||||
|
||||
- [ ] Learn user's name
|
||||
|
||||
---
|
||||
|
||||
*Prune date: {prune_date}*
|
||||
"""
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_connection() -> Generator[sqlite3.Connection, None, None]:
|
||||
"""Get database connection to unified memory database."""
|
||||
@@ -732,66 +790,12 @@ class HotMemory:
|
||||
logger.debug(
|
||||
"HotMemory._create_default() - creating default MEMORY.md for backward compatibility"
|
||||
)
|
||||
default_content = """# Timmy Hot Memory
|
||||
|
||||
> Working RAM — always loaded, ~300 lines max, pruned monthly
|
||||
> Last updated: {date}
|
||||
|
||||
---
|
||||
|
||||
## Current Status
|
||||
|
||||
**Agent State:** Operational
|
||||
**Mode:** Development
|
||||
**Active Tasks:** 0
|
||||
**Pending Decisions:** None
|
||||
|
||||
---
|
||||
|
||||
## Standing Rules
|
||||
|
||||
1. **Sovereignty First** — No cloud dependencies
|
||||
2. **Local-Only Inference** — Ollama on localhost
|
||||
3. **Privacy by Design** — Telemetry disabled
|
||||
4. **Tool Minimalism** — Use tools only when necessary
|
||||
5. **Memory Discipline** — Write handoffs at session end
|
||||
|
||||
---
|
||||
|
||||
## Agent Roster
|
||||
|
||||
| Agent | Role | Status |
|
||||
|-------|------|--------|
|
||||
| Timmy | Core | Active |
|
||||
|
||||
---
|
||||
|
||||
## User Profile
|
||||
|
||||
**Name:** (not set)
|
||||
**Interests:** (to be learned)
|
||||
|
||||
---
|
||||
|
||||
## Key Decisions
|
||||
|
||||
(none yet)
|
||||
|
||||
---
|
||||
|
||||
## Pending Actions
|
||||
|
||||
- [ ] Learn user's name
|
||||
|
||||
---
|
||||
|
||||
*Prune date: {prune_date}*
|
||||
""".format(
|
||||
date=datetime.now(UTC).strftime("%Y-%m-%d"),
|
||||
prune_date=(datetime.now(UTC).replace(day=25)).strftime("%Y-%m-%d"),
|
||||
now = datetime.now(UTC)
|
||||
content = _DEFAULT_HOT_MEMORY_TEMPLATE.format(
|
||||
date=now.strftime("%Y-%m-%d"),
|
||||
prune_date=now.replace(day=25).strftime("%Y-%m-%d"),
|
||||
)
|
||||
|
||||
self.path.write_text(default_content)
|
||||
self.path.write_text(content)
|
||||
logger.info("HotMemory: Created default MEMORY.md")
|
||||
|
||||
|
||||
|
||||
100
tests/dashboard/middleware/test_csrf_no_side_effects.py
Normal file
100
tests/dashboard/middleware/test_csrf_no_side_effects.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Tests that CSRF rejection does NOT execute the endpoint handler.
|
||||
|
||||
Regression test for #626: the middleware was calling call_next() before
|
||||
checking @csrf_exempt, causing side effects even on CSRF-rejected requests.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from dashboard.middleware.csrf import CSRFMiddleware, csrf_exempt
|
||||
|
||||
|
||||
class TestCSRFNoSideEffects:
|
||||
"""Verify endpoints are NOT executed when CSRF validation fails."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_csrf(self):
|
||||
"""Re-enable CSRF for these tests."""
|
||||
from config import settings
|
||||
|
||||
original = settings.timmy_disable_csrf
|
||||
settings.timmy_disable_csrf = False
|
||||
yield
|
||||
settings.timmy_disable_csrf = original
|
||||
|
||||
def test_protected_endpoint_not_executed_on_csrf_failure(self):
|
||||
"""A protected endpoint must NOT run when CSRF token is missing.
|
||||
|
||||
Before the fix, the middleware called call_next() to resolve the
|
||||
endpoint, executing its side effects before returning 403.
|
||||
"""
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
side_effect_log = []
|
||||
|
||||
@app.post("/transfer")
|
||||
def transfer_money():
|
||||
side_effect_log.append("money_transferred")
|
||||
return {"message": "transferred"}
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.post("/transfer")
|
||||
|
||||
assert response.status_code == 403
|
||||
assert side_effect_log == [], (
|
||||
"Endpoint was executed despite CSRF failure — side effects occurred!"
|
||||
)
|
||||
|
||||
def test_csrf_exempt_endpoint_still_executes(self):
|
||||
"""A @csrf_exempt endpoint should still execute without a CSRF token."""
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
side_effect_log = []
|
||||
|
||||
@app.post("/webhook-handler")
|
||||
@csrf_exempt
|
||||
def webhook_handler():
|
||||
side_effect_log.append("webhook_processed")
|
||||
return {"message": "processed"}
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.post("/webhook-handler")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert side_effect_log == ["webhook_processed"]
|
||||
|
||||
def test_exempt_and_protected_no_cross_contamination(self):
|
||||
"""Mixed exempt/protected: only exempt endpoints execute without tokens."""
|
||||
app = FastAPI()
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
execution_log = []
|
||||
|
||||
@app.post("/safe-webhook")
|
||||
@csrf_exempt
|
||||
def safe_webhook():
|
||||
execution_log.append("safe")
|
||||
return {"message": "safe"}
|
||||
|
||||
@app.post("/dangerous-action")
|
||||
def dangerous_action():
|
||||
execution_log.append("dangerous")
|
||||
return {"message": "danger"}
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Exempt endpoint runs
|
||||
resp1 = client.post("/safe-webhook")
|
||||
assert resp1.status_code == 200
|
||||
|
||||
# Protected endpoint blocked WITHOUT executing
|
||||
resp2 = client.post("/dangerous-action")
|
||||
assert resp2.status_code == 403
|
||||
|
||||
assert execution_log == ["safe"], (
|
||||
f"Expected only 'safe' execution, got: {execution_log}"
|
||||
)
|
||||
@@ -5,11 +5,13 @@ from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
from infrastructure.error_capture import (
|
||||
_build_report_description,
|
||||
_create_bug_report,
|
||||
_dedup_cache,
|
||||
_extract_traceback_info,
|
||||
_get_git_context,
|
||||
_is_duplicate,
|
||||
_log_bug_report_created,
|
||||
_log_error_event,
|
||||
_notify_bug_report,
|
||||
_record_to_session,
|
||||
@@ -231,6 +233,68 @@ class TestLogErrorEvent:
|
||||
_log_error_event(e, "test", "abc123", "file.py", 42, {"branch": "main"})
|
||||
|
||||
|
||||
class TestBuildReportDescription:
|
||||
"""Test _build_report_description helper."""
|
||||
|
||||
def test_includes_error_info(self):
|
||||
try:
|
||||
raise RuntimeError("desc test")
|
||||
except RuntimeError as e:
|
||||
desc = _build_report_description(
|
||||
e,
|
||||
"test_src",
|
||||
None,
|
||||
"hash1",
|
||||
"tb...",
|
||||
"file.py",
|
||||
10,
|
||||
{"branch": "main"},
|
||||
)
|
||||
assert "RuntimeError" in desc
|
||||
assert "test_src" in desc
|
||||
assert "file.py:10" in desc
|
||||
assert "hash1" in desc
|
||||
|
||||
def test_includes_context_when_provided(self):
|
||||
try:
|
||||
raise RuntimeError("ctx desc")
|
||||
except RuntimeError as e:
|
||||
desc = _build_report_description(
|
||||
e,
|
||||
"src",
|
||||
{"path": "/api"},
|
||||
"h",
|
||||
"tb",
|
||||
"f.py",
|
||||
1,
|
||||
{},
|
||||
)
|
||||
assert "path=/api" in desc
|
||||
|
||||
def test_omits_context_when_none(self):
|
||||
try:
|
||||
raise RuntimeError("no ctx")
|
||||
except RuntimeError as e:
|
||||
desc = _build_report_description(
|
||||
e,
|
||||
"src",
|
||||
None,
|
||||
"h",
|
||||
"tb",
|
||||
"f.py",
|
||||
1,
|
||||
{},
|
||||
)
|
||||
assert "**Context:**" not in desc
|
||||
|
||||
|
||||
class TestLogBugReportCreated:
|
||||
"""Test _log_bug_report_created helper."""
|
||||
|
||||
def test_does_not_crash_on_missing_deps(self):
|
||||
_log_bug_report_created("test", "task-1", "hash1", "title")
|
||||
|
||||
|
||||
class TestCreateBugReport:
|
||||
"""Test _create_bug_report helper."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user