forked from Rockachopa/Timmy-time-dashboard
Compare commits
1 Commits
claude/iss
...
claude/iss
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ab4b2f938d |
@@ -1,35 +0,0 @@
|
||||
# Research Report: Task #1341
|
||||
|
||||
**Date:** 2026-03-23
|
||||
**Issue:** [#1341](http://143.198.27.163:3000/Rockachopa/Timmy-time-dashboard/issues/1341)
|
||||
**Priority:** normal
|
||||
**Delegated by:** Timmy via Kimi delegation pipeline
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
This issue was submitted as a placeholder via the Kimi delegation pipeline with unfilled template fields:
|
||||
|
||||
- **Research Question:** `Q?` (template default — no actual question provided)
|
||||
- **Background / Context:** `ctx` (template default — no context provided)
|
||||
- **Task:** `Task` (template default — no task specified)
|
||||
|
||||
## Findings
|
||||
|
||||
No actionable research question was specified. The issue appears to be a test or
|
||||
accidental submission of an unfilled delegation template.
|
||||
|
||||
## Recommendations
|
||||
|
||||
1. **Re-open with a real question** if there is a specific topic to research.
|
||||
2. **Review the delegation pipeline** to add validation that prevents empty/template-default
|
||||
submissions from reaching the backlog (e.g. reject issues where the body contains
|
||||
literal placeholder strings like `Q?` or `ctx`).
|
||||
3. **Add a pipeline guard** in the Kimi delegation script to require non-empty, non-default
|
||||
values for `Research Question` and `Background / Context` before creating an issue.
|
||||
|
||||
## Next Steps
|
||||
|
||||
- [ ] Add input validation to Kimi delegation pipeline
|
||||
- [ ] Re-file with a concrete research question if needed
|
||||
File diff suppressed because it is too large
Load Diff
123
src/infrastructure/router/config_loader.py
Normal file
123
src/infrastructure/router/config_loader.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Config loading helpers for the Cascade LLM Router.
|
||||
|
||||
Parses providers.yaml, expands env vars, and checks provider availability.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from infrastructure.router.models import Provider, RouterConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import yaml
|
||||
except ImportError:
|
||||
yaml = None # type: ignore
|
||||
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
requests = None # type: ignore
|
||||
|
||||
|
||||
def expand_env_vars(content: str) -> str:
    """Expand ``${VAR}`` references in YAML text from the process environment.

    Reads os.environ directly (not project settings) because this is a
    generic config loader that must resolve arbitrary variable names.
    Unresolved references are left as the literal ``${VAR}`` text.
    """
    import os
    import re

    pattern = re.compile(r"\$\{(\w+)\}")

    def substitute(m: "re.Match[str]") -> str:
        # Fall back to the original text so unresolved refs stay visible.
        return os.environ.get(m.group(1), m.group(0))

    return pattern.sub(substitute, content)
|
||||
|
||||
|
||||
def parse_router_config(data: dict) -> RouterConfig:
    """Build a RouterConfig from parsed YAML data.

    Missing sections or keys fall back to hard-coded defaults that match
    the RouterConfig field defaults.
    """
    cascade_cfg = data.get("cascade", {})
    breaker_cfg = cascade_cfg.get("circuit_breaker", {})
    mm_cfg = data.get("multimodal", {})

    settings = {
        "timeout_seconds": cascade_cfg.get("timeout_seconds", 30),
        "max_retries_per_provider": cascade_cfg.get("max_retries_per_provider", 2),
        "retry_delay_seconds": cascade_cfg.get("retry_delay_seconds", 1),
        "circuit_breaker_failure_threshold": breaker_cfg.get("failure_threshold", 5),
        "circuit_breaker_recovery_timeout": breaker_cfg.get("recovery_timeout", 60),
        "circuit_breaker_half_open_max_calls": breaker_cfg.get("half_open_max_calls", 2),
        "auto_pull_models": mm_cfg.get("auto_pull", True),
        "fallback_chains": data.get("fallback_chains", {}),
    }
    return RouterConfig(**settings)
|
||||
|
||||
|
||||
def load_providers(data: dict) -> list[Provider]:
    """Load and filter providers from parsed YAML data (unsorted).

    Entries not explicitly enabled are skipped outright; enabled entries
    are kept only when check_provider_available() reports them reachable.
    """
    available: list[Provider] = []

    for entry in data.get("providers", []):
        if not entry.get("enabled", False):
            continue

        candidate = Provider(
            name=entry["name"],
            type=entry["type"],
            enabled=entry.get("enabled", True),
            priority=entry.get("priority", 99),
            tier=entry.get("tier"),
            url=entry.get("url"),
            api_key=entry.get("api_key"),
            base_url=entry.get("base_url"),
            models=entry.get("models", []),
        )

        if not check_provider_available(candidate):
            logger.warning("Provider %s not available, skipping", candidate.name)
            continue
        available.append(candidate)

    return available
|
||||
|
||||
|
||||
def check_provider_available(provider: Provider) -> bool:
|
||||
"""Check if a provider is actually available."""
|
||||
from config import settings
|
||||
|
||||
if provider.type == "ollama":
|
||||
# Check if Ollama is running
|
||||
if requests is None:
|
||||
# Can't check without requests, assume available
|
||||
return True
|
||||
try:
|
||||
url = provider.url or settings.ollama_url
|
||||
response = requests.get(f"{url}/api/tags", timeout=5)
|
||||
return response.status_code == 200
|
||||
except Exception as exc:
|
||||
logger.debug("Ollama provider check error: %s", exc)
|
||||
return False
|
||||
|
||||
elif provider.type == "vllm_mlx":
|
||||
# Check if local vllm-mlx server is running (OpenAI-compatible)
|
||||
if requests is None:
|
||||
return True
|
||||
try:
|
||||
base_url = provider.base_url or provider.url or "http://localhost:8000"
|
||||
# Strip /v1 suffix — health endpoint is at the root
|
||||
server_root = base_url.rstrip("/")
|
||||
if server_root.endswith("/v1"):
|
||||
server_root = server_root[:-3]
|
||||
response = requests.get(f"{server_root}/health", timeout=5)
|
||||
return response.status_code == 200
|
||||
except Exception as exc:
|
||||
logger.debug("vllm-mlx provider check error: %s", exc)
|
||||
return False
|
||||
|
||||
elif provider.type in ("openai", "anthropic", "grok"):
|
||||
# Check if API key is set
|
||||
return provider.api_key is not None and provider.api_key != ""
|
||||
|
||||
return True
|
||||
129
src/infrastructure/router/content.py
Normal file
129
src/infrastructure/router/content.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""Content-type detection and model selection for the Cascade LLM Router."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from infrastructure.router.models import ContentType, Provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def detect_content_type(messages: list[dict]) -> ContentType:
    """Classify the request payload as text / vision / audio / multimodal.

    Scans every message for image or audio markers: explicit "images" /
    "audio" keys, image file extensions or ``data:image/`` URIs inside
    string content, and OpenAI-style structured content parts.
    """
    image_exts = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
    saw_image = False
    saw_audio = False

    for message in messages:
        body = message.get("content", "")

        if message.get("images"):
            saw_image = True
        if message.get("audio"):
            saw_audio = True

        if isinstance(body, str):
            # Heuristic: an image extension anywhere in the text, or a
            # base64 data URI, marks the request as vision.
            lowered = body.lower()
            if any(ext in lowered for ext in image_exts):
                saw_image = True
            if body.startswith("data:image/"):
                saw_image = True
        elif isinstance(body, list):
            # OpenAI-style multimodal content parts.
            for part in body:
                if not isinstance(part, dict):
                    continue
                kind = part.get("type")
                if kind == "image_url":
                    saw_image = True
                elif kind == "audio":
                    saw_audio = True

    if saw_image and saw_audio:
        return ContentType.MULTIMODAL
    if saw_image:
        return ContentType.VISION
    if saw_audio:
        return ContentType.AUDIO
    return ContentType.TEXT
|
||||
|
||||
|
||||
def get_fallback_model(
    provider: Provider,
    original_model: str,
    content_type: ContentType,
    fallback_chains: dict,
) -> str | None:
    """Find a model able to handle *content_type*, or None.

    Prefers a capable model declared on the provider itself, then walks
    the configured fallback chain for the required capability.
    """
    # MULTIMODAL maps to "vision" because vision models often handle both.
    required = {
        ContentType.VISION: "vision",
        ContentType.AUDIO: "audio",
        ContentType.MULTIMODAL: "vision",
    }.get(content_type)
    if not required:
        return None

    # First choice: a model on this provider that declares the capability.
    candidate = provider.get_model_with_capability(required)
    if candidate and candidate != original_model:
        return candidate

    # Otherwise walk the configured fallback chain for the capability.
    for name in fallback_chains.get(required, []):
        if provider.model_has_capability(name, required):
            return name

    return None
|
||||
|
||||
|
||||
def select_model(
    provider: Provider,
    model: str | None,
    content_type: ContentType,
    mm_manager: Any,
    fallback_chains: dict,
) -> tuple[str | None, bool]:
    """Pick the model for the request, swapping in a vision fallback if needed.

    Returns:
        Tuple of (selected_model, is_fallback_model).
    """
    chosen = model or provider.get_default_model()
    swapped = False

    # Capability-aware fallback only applies to non-text requests on
    # Ollama providers that have a multimodal manager attached.
    if content_type != ContentType.TEXT and chosen:
        if provider.type == "ollama" and mm_manager:
            from infrastructure.models.multimodal import ModelCapability

            if content_type == ContentType.VISION and not mm_manager.model_supports(
                chosen, ModelCapability.VISION
            ):
                replacement = get_fallback_model(
                    provider, chosen, content_type, fallback_chains
                )
                if replacement:
                    logger.info(
                        "Model %s doesn't support vision, falling back to %s",
                        chosen,
                        replacement,
                    )
                    chosen = replacement
                    swapped = True
                else:
                    logger.warning(
                        "No vision-capable model found on %s, trying anyway",
                        provider.name,
                    )

    return chosen, swapped
|
||||
79
src/infrastructure/router/health.py
Normal file
79
src/infrastructure/router/health.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Circuit-breaker and health tracking for the Cascade LLM Router.
|
||||
|
||||
Standalone functions that mutate Provider state in place.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from infrastructure.router.models import CircuitState, Provider, ProviderStatus, RouterConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def record_success(provider: Provider, latency_ms: float, config: RouterConfig) -> None:
    """Record a successful request against *provider*'s metrics.

    Resets the consecutive-failure counter, advances a half-open circuit
    breaker toward closing, and refreshes the coarse health status from
    the rolling error rate.
    """
    stats = provider.metrics
    stats.total_requests += 1
    stats.successful_requests += 1
    stats.total_latency_ms += latency_ms
    stats.last_request_time = datetime.now(UTC).isoformat()
    stats.consecutive_failures = 0

    # A half-open breaker closes after enough consecutive successes.
    if provider.circuit_state == CircuitState.HALF_OPEN:
        provider.half_open_calls += 1
        if provider.half_open_calls >= config.circuit_breaker_half_open_max_calls:
            close_circuit(provider)

    # Re-derive health: < 10% errors is HEALTHY, < 30% is DEGRADED.
    if stats.error_rate < 0.1:
        provider.status = ProviderStatus.HEALTHY
    elif stats.error_rate < 0.3:
        provider.status = ProviderStatus.DEGRADED
|
||||
|
||||
|
||||
def record_failure(provider: Provider, config: RouterConfig) -> None:
    """Record a failed request and update breaker / health state."""
    stats = provider.metrics
    stats.total_requests += 1
    stats.failed_requests += 1
    stats.last_error_time = datetime.now(UTC).isoformat()
    stats.consecutive_failures += 1

    # Enough consecutive failures trips the circuit breaker.
    if stats.consecutive_failures >= config.circuit_breaker_failure_threshold:
        open_circuit(provider)

    # Degrade health as the error rate climbs; > 50% means UNHEALTHY.
    if stats.error_rate > 0.3:
        provider.status = ProviderStatus.DEGRADED
    if stats.error_rate > 0.5:
        provider.status = ProviderStatus.UNHEALTHY
|
||||
|
||||
|
||||
def open_circuit(provider: Provider) -> None:
    """Open the circuit breaker for a provider.

    Marks the provider UNHEALTHY and records the trip time so that
    can_close_circuit() can later decide when recovery may be probed.
    """
    provider.circuit_state = CircuitState.OPEN
    provider.circuit_opened_at = time.time()
    provider.status = ProviderStatus.UNHEALTHY
    logger.warning("Circuit breaker OPEN for %s", provider.name)
|
||||
|
||||
|
||||
def can_close_circuit(provider: Provider, config: RouterConfig) -> bool:
|
||||
"""Check if circuit breaker can transition to half-open."""
|
||||
if provider.circuit_opened_at is None:
|
||||
return False
|
||||
elapsed = time.time() - provider.circuit_opened_at
|
||||
return elapsed >= config.circuit_breaker_recovery_timeout
|
||||
|
||||
|
||||
def close_circuit(provider: Provider) -> None:
    """Close the circuit breaker (provider healthy again).

    Resets all breaker bookkeeping (open timestamp, half-open call count,
    consecutive failures) and restores HEALTHY status.
    """
    provider.circuit_state = CircuitState.CLOSED
    provider.circuit_opened_at = None
    provider.half_open_calls = 0
    provider.metrics.consecutive_failures = 0
    provider.status = ProviderStatus.HEALTHY
    logger.info("Circuit breaker CLOSED for %s", provider.name)
|
||||
141
src/infrastructure/router/models.py
Normal file
141
src/infrastructure/router/models.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Data models for the Cascade LLM Router.
|
||||
|
||||
Enums, dataclasses, and provider configuration shared across
|
||||
router sub-modules.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ProviderStatus(Enum):
    """Health status of a provider."""

    HEALTHY = "healthy"  # Operating normally (low rolling error rate)
    DEGRADED = "degraded"  # Working but slow or occasional errors
    UNHEALTHY = "unhealthy"  # Circuit breaker open
    DISABLED = "disabled"  # NOTE(review): never assigned in the visible router code — confirm where it is set
|
||||
|
||||
|
||||
class CircuitState(Enum):
    """Circuit breaker state (transitions live in the health helpers)."""

    CLOSED = "closed"  # Normal operation
    OPEN = "open"  # Failing, rejecting requests
    HALF_OPEN = "half_open"  # Testing if recovered
|
||||
|
||||
|
||||
class ContentType(Enum):
    """Type of content in the request (as classified by detect_content_type)."""

    TEXT = "text"  # Plain text only
    VISION = "vision"  # Contains images
    AUDIO = "audio"  # Contains audio
    MULTIMODAL = "multimodal"  # Multiple content types (e.g. image + audio)
|
||||
|
||||
|
||||
@dataclass
class ProviderMetrics:
    """Rolling request counters and latency totals for one provider."""

    total_requests: int = 0
    successful_requests: int = 0
    failed_requests: int = 0
    total_latency_ms: float = 0.0
    last_request_time: str | None = None
    last_error_time: str | None = None
    consecutive_failures: int = 0

    @property
    def avg_latency_ms(self) -> float:
        """Mean latency across all requests; 0.0 when none recorded."""
        return self.total_latency_ms / self.total_requests if self.total_requests else 0.0

    @property
    def error_rate(self) -> float:
        """Fraction of requests that failed; 0.0 when none recorded."""
        return self.failed_requests / self.total_requests if self.total_requests else 0.0
|
||||
|
||||
|
||||
@dataclass
class ModelCapability:
    """Capabilities a model supports.

    NOTE(review): a different ``ModelCapability`` (apparently an enum) is
    imported from ``infrastructure.models.multimodal`` elsewhere in the
    router (see content.select_model); this dataclass is not referenced by
    the visible router code — confirm which one callers expect.
    """

    name: str
    supports_vision: bool = False
    supports_audio: bool = False
    supports_tools: bool = False
    supports_json: bool = False
    supports_streaming: bool = True
    context_window: int = 4096  # presumably tokens — confirm against model configs
|
||||
|
||||
|
||||
@dataclass
class Provider:
    """LLM provider configuration and state."""

    name: str
    type: str  # ollama, openai, anthropic
    enabled: bool
    priority: int
    tier: str | None = None  # e.g., "local", "standard_cloud", "frontier"
    url: str | None = None
    api_key: str | None = None
    base_url: str | None = None
    models: list[dict] = field(default_factory=list)

    # Runtime state
    status: ProviderStatus = ProviderStatus.HEALTHY
    metrics: ProviderMetrics = field(default_factory=ProviderMetrics)
    circuit_state: CircuitState = CircuitState.CLOSED
    circuit_opened_at: float | None = None
    half_open_calls: int = 0

    def get_default_model(self) -> str | None:
        """Return the model marked ``default``, else the first model, else None."""
        for entry in self.models:
            if entry.get("default"):
                return entry["name"]
        return self.models[0]["name"] if self.models else None

    def get_model_with_capability(self, capability: str) -> str | None:
        """Return the first model advertising *capability*, else the default model."""
        for entry in self.models:
            if capability in entry.get("capabilities", []):
                return entry["name"]
        # No capable model declared — fall back to the provider default.
        return self.get_default_model()

    def model_has_capability(self, model_name: str, capability: str) -> bool:
        """True iff the named model declares *capability* in its config entry."""
        for entry in self.models:
            if entry["name"] == model_name:
                return capability in entry.get("capabilities", [])
        return False
|
||||
|
||||
|
||||
@dataclass
class RouterConfig:
    """Cascade router configuration.

    Field defaults apply when providers.yaml omits the corresponding key.
    """

    timeout_seconds: int = 30  # per-request timeout handed to provider calls
    max_retries_per_provider: int = 2
    retry_delay_seconds: int = 1
    # Circuit breaker: consecutive failures needed to trip, seconds before
    # recovery may be probed, and half-open successes required to close.
    circuit_breaker_failure_threshold: int = 5
    circuit_breaker_recovery_timeout: int = 60
    circuit_breaker_half_open_max_calls: int = 2
    cost_tracking_enabled: bool = True
    budget_daily_usd: float = 10.0  # NOTE(review): not read by the visible router code — confirm consumer
    # Multi-modal settings
    auto_pull_models: bool = True
    fallback_chains: dict = field(default_factory=dict)  # capability name -> list of model names
|
||||
1
src/infrastructure/router/providers/__init__.py
Normal file
1
src/infrastructure/router/providers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Provider implementations
|
||||
56
src/infrastructure/router/providers/anthropic.py
Normal file
56
src/infrastructure/router/providers/anthropic.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Anthropic provider implementation for the Cascade LLM Router."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from infrastructure.router.models import Provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def call_anthropic(
    provider: Provider,
    messages: list[dict],
    model: str,
    temperature: float,
    max_tokens: int | None,
    timeout_seconds: int,
) -> dict:
    """Call Anthropic API.

    Splits any system message out of the conversation (Anthropic takes it
    as a separate ``system`` parameter) and returns a dict with
    ``content`` and ``model`` keys.
    """
    import anthropic

    client = anthropic.AsyncAnthropic(
        api_key=provider.api_key,
        timeout=timeout_seconds,
    )

    # Anthropic wants the system prompt outside the message list.
    system_msg = None
    conversation = []
    for msg in messages:
        if msg["role"] == "system":
            system_msg = msg["content"]
            continue
        conversation.append({"role": msg["role"], "content": msg["content"]})

    request: dict = {
        "model": model,
        "messages": conversation,
        "temperature": temperature,
        "max_tokens": max_tokens or 1024,  # fixed cap when the caller passes None
    }
    if system_msg:
        request["system"] = system_msg

    response = await client.messages.create(**request)

    return {
        "content": response.content[0].text,
        "model": response.model,
    }
|
||||
80
src/infrastructure/router/providers/dispatch.py
Normal file
80
src/infrastructure/router/providers/dispatch.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Provider dispatch — routes a single request to the correct provider module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
from infrastructure.router.models import ContentType, Provider
|
||||
|
||||
|
||||
async def call_provider(
    provider: Provider,
    messages: list[dict],
    model: str,
    temperature: float,
    max_tokens: int | None,
    timeout_seconds: int,
    content_type: ContentType = ContentType.TEXT,
) -> dict:
    """Dispatch a request to the correct provider implementation.

    Returns a result dict with ``content``, ``model``, and ``latency_ms`` keys.
    Raises ValueError for unknown provider types.
    """
    # Imported lazily to avoid pulling every provider SDK at module load.
    from infrastructure.router.providers import anthropic as _anthropic
    from infrastructure.router.providers import grok as _grok
    from infrastructure.router.providers import ollama as _ollama
    from infrastructure.router.providers import openai_compat as _openai_compat

    started = time.time()

    # Arguments shared by every provider implementation.
    common = {
        "provider": provider,
        "messages": messages,
        "model": model or provider.get_default_model(),
        "temperature": temperature,
        "max_tokens": max_tokens,
    }

    if provider.type == "ollama":
        result = await _ollama.call_ollama(
            **common,
            content_type=content_type,
            timeout_seconds=timeout_seconds,
        )
    elif provider.type == "openai":
        result = await _openai_compat.call_openai(**common, timeout_seconds=timeout_seconds)
    elif provider.type == "anthropic":
        result = await _anthropic.call_anthropic(**common, timeout_seconds=timeout_seconds)
    elif provider.type == "grok":
        # call_grok manages its own timeout internally.
        result = await _grok.call_grok(**common)
    elif provider.type == "vllm_mlx":
        result = await _openai_compat.call_vllm_mlx(**common, timeout_seconds=timeout_seconds)
    else:
        raise ValueError(f"Unknown provider type: {provider.type}")

    result["latency_ms"] = (time.time() - started) * 1000
    return result
|
||||
44
src/infrastructure/router/providers/grok.py
Normal file
44
src/infrastructure/router/providers/grok.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Grok (xAI) provider implementation for the Cascade LLM Router."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from infrastructure.router.models import Provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def call_grok(
    provider: Provider,
    messages: list[dict],
    model: str,
    temperature: float,
    max_tokens: int | None,
    timeout_seconds: int | None = None,
) -> dict:
    """Call xAI Grok API via OpenAI-compatible SDK.

    Args:
        timeout_seconds: optional request timeout. Defaults to the previous
            hard-coded 300 seconds when not given, so existing callers
            (which pass no timeout) keep their behavior. Added for
            consistency with the other provider implementations.

    Returns a dict with ``content`` and ``model`` keys.
    """
    import httpx
    import openai

    from config import settings

    client = openai.AsyncOpenAI(
        api_key=provider.api_key,
        base_url=provider.base_url or settings.xai_base_url,
        timeout=httpx.Timeout(float(timeout_seconds) if timeout_seconds is not None else 300.0),
    )

    kwargs: dict = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
    }
    if max_tokens:
        kwargs["max_tokens"] = max_tokens

    response = await client.chat.completions.create(**kwargs)

    return {
        "content": response.choices[0].message.content,
        "model": response.model,
    }
|
||||
92
src/infrastructure/router/providers/ollama.py
Normal file
92
src/infrastructure/router/providers/ollama.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Ollama provider implementation for the Cascade LLM Router."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
|
||||
from infrastructure.router.models import ContentType, Provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def call_ollama(
    provider: Provider,
    messages: list[dict],
    model: str,
    temperature: float,
    max_tokens: int | None,
    content_type: ContentType,
    timeout_seconds: int,
) -> dict:
    """Call Ollama API with multi-modal support.

    Posts a non-streaming request to /api/chat and returns a dict with
    ``content`` and ``model`` keys. Raises RuntimeError on a non-200 reply.
    """
    from config import settings

    endpoint = f"{provider.url or settings.ollama_url}/api/chat"

    generation_options: dict = {"temperature": temperature}
    if max_tokens:
        generation_options["num_predict"] = max_tokens

    payload = {
        "model": model,
        # Normalize message shape (including image attachments) for Ollama.
        "messages": transform_messages_for_ollama(messages),
        "stream": False,
        "options": generation_options,
    }

    client_timeout = aiohttp.ClientTimeout(total=timeout_seconds)
    async with aiohttp.ClientSession(timeout=client_timeout) as session:
        async with session.post(endpoint, json=payload) as response:
            if response.status != 200:
                text = await response.text()
                raise RuntimeError(f"Ollama error {response.status}: {text}")
            data = await response.json()

    return {
        "content": data["message"]["content"],
        "model": model,
    }
|
||||
|
||||
|
||||
def transform_messages_for_ollama(messages: list[dict]) -> list[dict]:
    """Transform messages to Ollama format, handling images.

    Each message keeps its role/content; entries under "images" are
    normalized to raw base64 strings:

    - ``data:image/...`` URIs: the base64 payload after the comma is kept
      (malformed URIs with no comma are logged and skipped instead of
      raising IndexError, which the previous ``split(",")[1]`` did).
    - http(s) URLs: not supported yet — logged and skipped.
    - local file paths: read and base64-encoded; read errors (including
      invalid path strings) are logged and skipped.
    """
    transformed = []

    for msg in messages:
        new_msg = {
            "role": msg.get("role", "user"),
            "content": msg.get("content", ""),
        }

        # Handle images
        images = msg.get("images", [])
        if images:
            encoded: list[str] = []
            for img in images:
                if not isinstance(img, str):
                    continue
                if img.startswith("data:image/"):
                    # Base64 data URI: keep only the payload after the comma.
                    _, sep, payload = img.partition(",")
                    if sep:
                        encoded.append(payload)
                    else:
                        logger.warning("Malformed data URI image, skipping")
                elif img.startswith(("http://", "https://")):
                    # URL - would need to download, skip for now
                    logger.warning("Image URLs not yet supported, skipping: %s", img)
                else:
                    # Treat as a local file path — read and encode. The
                    # exists() check sits inside the try because invalid
                    # path strings can raise ValueError.
                    try:
                        if Path(img).exists():
                            with open(img, "rb") as f:
                                encoded.append(base64.b64encode(f.read()).decode())
                    except (OSError, ValueError) as exc:
                        logger.error("Failed to read image %s: %s", img, exc)
            new_msg["images"] = encoded

        transformed.append(new_msg)

    return transformed
|
||||
88
src/infrastructure/router/providers/openai_compat.py
Normal file
88
src/infrastructure/router/providers/openai_compat.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""OpenAI-compatible provider implementations for the Cascade LLM Router.
|
||||
|
||||
Covers the ``openai`` and ``vllm_mlx`` provider types.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from infrastructure.router.models import Provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def call_openai(
    provider: Provider,
    messages: list[dict],
    model: str,
    temperature: float,
    max_tokens: int | None,
    timeout_seconds: int,
) -> dict:
    """Call OpenAI API.

    Returns a dict with ``content`` and ``model`` keys.
    """
    import openai

    client = openai.AsyncOpenAI(
        api_key=provider.api_key,
        base_url=provider.base_url,
        timeout=timeout_seconds,
    )

    request: dict = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
    }
    if max_tokens:
        request["max_tokens"] = max_tokens

    completion = await client.chat.completions.create(**request)

    return {
        "content": completion.choices[0].message.content,
        "model": completion.model,
    }
|
||||
|
||||
|
||||
async def call_vllm_mlx(
    provider: Provider,
    messages: list[dict],
    model: str,
    temperature: float,
    max_tokens: int | None,
    timeout_seconds: int,
) -> dict:
    """Call vllm-mlx via its OpenAI-compatible API.

    vllm-mlx exposes the same /v1/chat/completions endpoint as OpenAI,
    so we reuse the OpenAI client pointed at the local server.
    No API key is required for local deployments.
    """
    import openai

    base_url = provider.base_url or provider.url or "http://localhost:8000"
    # The OpenAI client expects the /v1 prefix on the base URL; append it
    # only when it is not already present.
    if not base_url.rstrip("/").endswith("/v1"):
        base_url = base_url.rstrip("/") + "/v1"

    client = openai.AsyncOpenAI(
        api_key=provider.api_key or "no-key-required",
        base_url=base_url,
        timeout=timeout_seconds,
    )

    request: dict = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
    }
    if max_tokens:
        request["max_tokens"] = max_tokens

    completion = await client.chat.completions.create(**request)

    return {
        "content": completion.choices[0].message.content,
        "model": completion.model,
    }
|
||||
89
src/infrastructure/router/reporting.py
Normal file
89
src/infrastructure/router/reporting.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Metrics, status, and config-reload helpers for the Cascade LLM Router."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from infrastructure.router.models import (
|
||||
CircuitState,
|
||||
Provider,
|
||||
ProviderMetrics,
|
||||
ProviderStatus,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_metrics(providers: list[Provider]) -> dict:
|
||||
"""Build a metrics summary dict for all providers."""
|
||||
return {
|
||||
"providers": [
|
||||
{
|
||||
"name": p.name,
|
||||
"type": p.type,
|
||||
"status": p.status.value,
|
||||
"circuit_state": p.circuit_state.value,
|
||||
"metrics": {
|
||||
"total_requests": p.metrics.total_requests,
|
||||
"successful": p.metrics.successful_requests,
|
||||
"failed": p.metrics.failed_requests,
|
||||
"error_rate": round(p.metrics.error_rate, 3),
|
||||
"avg_latency_ms": round(p.metrics.avg_latency_ms, 2),
|
||||
},
|
||||
}
|
||||
for p in providers
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def build_status(providers: list[Provider]) -> dict:
|
||||
"""Build a status summary dict for all providers."""
|
||||
healthy = sum(1 for p in providers if p.status == ProviderStatus.HEALTHY)
|
||||
return {
|
||||
"total_providers": len(providers),
|
||||
"healthy_providers": healthy,
|
||||
"degraded_providers": sum(1 for p in providers if p.status == ProviderStatus.DEGRADED),
|
||||
"unhealthy_providers": sum(1 for p in providers if p.status == ProviderStatus.UNHEALTHY),
|
||||
"providers": [
|
||||
{
|
||||
"name": p.name,
|
||||
"type": p.type,
|
||||
"status": p.status.value,
|
||||
"priority": p.priority,
|
||||
"default_model": p.get_default_model(),
|
||||
}
|
||||
for p in providers
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def snapshot_provider_state(
|
||||
providers: list[Provider],
|
||||
) -> dict[str, tuple[ProviderMetrics, CircuitState, float | None, int, ProviderStatus]]:
|
||||
"""Capture current runtime state keyed by provider name."""
|
||||
return {
|
||||
p.name: (p.metrics, p.circuit_state, p.circuit_opened_at, p.half_open_calls, p.status)
|
||||
for p in providers
|
||||
}
|
||||
|
||||
|
||||
def restore_provider_state(
|
||||
providers: list[Provider],
|
||||
old_state: dict[str, tuple[ProviderMetrics, CircuitState, float | None, int, ProviderStatus]],
|
||||
) -> int:
|
||||
"""Restore saved runtime state to matching providers. Returns count of restored providers."""
|
||||
preserved = 0
|
||||
for p in providers:
|
||||
if p.name in old_state:
|
||||
metrics, circuit, opened_at, half_open, status = old_state[p.name]
|
||||
p.metrics = metrics
|
||||
p.circuit_state = circuit
|
||||
p.circuit_opened_at = opened_at
|
||||
p.half_open_calls = half_open
|
||||
p.status = status
|
||||
preserved += 1
|
||||
return preserved
|
||||
@@ -9,7 +9,7 @@ import re
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
from timmy.memory.crud import recall_last_activity_time, recall_last_reflection, recall_personal_facts
|
||||
from timmy.memory.crud import recall_last_reflection, recall_personal_facts
|
||||
from timmy.memory.db import HOT_MEMORY_PATH, VAULT_PATH
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -89,41 +89,25 @@ class HotMemory:
|
||||
"""Read hot memory — computed view of top facts + last reflection from DB."""
|
||||
try:
|
||||
facts = recall_personal_facts()
|
||||
lines = ["# Timmy Hot Memory\n"]
|
||||
|
||||
if facts:
|
||||
lines.append("## Known Facts\n")
|
||||
for f in facts[:15]:
|
||||
lines.append(f"- {f}")
|
||||
|
||||
# Include the last reflection if available
|
||||
reflection = recall_last_reflection()
|
||||
if reflection:
|
||||
lines.append("\n## Last Reflection\n")
|
||||
lines.append(reflection)
|
||||
|
||||
if facts or reflection:
|
||||
last_ts = recall_last_activity_time()
|
||||
try:
|
||||
updated_date = datetime.fromisoformat(last_ts).strftime("%Y-%m-%d %H:%M UTC")
|
||||
except (TypeError, ValueError):
|
||||
updated_date = datetime.now(UTC).strftime("%Y-%m-%d %H:%M UTC")
|
||||
|
||||
lines = [
|
||||
"# Timmy Hot Memory",
|
||||
"",
|
||||
"> Working RAM — always loaded, ~300 lines max, pruned monthly",
|
||||
f"> Last updated: {updated_date}",
|
||||
"",
|
||||
]
|
||||
|
||||
if facts:
|
||||
lines.append("## Known Facts")
|
||||
lines.append("")
|
||||
for f in facts[:15]:
|
||||
lines.append(f"- {f}")
|
||||
|
||||
if reflection:
|
||||
lines.append("")
|
||||
lines.append("## Last Reflection")
|
||||
lines.append("")
|
||||
lines.append(reflection)
|
||||
|
||||
if len(lines) > 1:
|
||||
return "\n".join(lines)
|
||||
|
||||
except Exception:
|
||||
logger.debug("DB context read failed, falling back to file")
|
||||
|
||||
# Fallback to file if DB unavailable or empty
|
||||
# Fallback to file if DB unavailable
|
||||
if self.path.exists():
|
||||
return self.path.read_text()
|
||||
|
||||
|
||||
@@ -393,12 +393,3 @@ def recall_last_reflection() -> str | None:
|
||||
"ORDER BY created_at DESC LIMIT 1"
|
||||
).fetchone()
|
||||
return row["content"] if row else None
|
||||
|
||||
|
||||
def recall_last_activity_time() -> str | None:
|
||||
"""Return the ISO timestamp of the most recently stored memory, or None."""
|
||||
with get_connection() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT created_at FROM memories ORDER BY created_at DESC LIMIT 1"
|
||||
).fetchone()
|
||||
return row["created_at"] if row else None
|
||||
|
||||
@@ -27,7 +27,6 @@ from timmy.memory.crud import ( # noqa: F401
|
||||
get_memory_context,
|
||||
get_memory_stats,
|
||||
prune_memories,
|
||||
recall_last_activity_time,
|
||||
recall_last_reflection,
|
||||
recall_personal_facts,
|
||||
recall_personal_facts_with_ids,
|
||||
|
||||
@@ -1,598 +0,0 @@
|
||||
"""Unit tests for models/budget.py — comprehensive coverage for budget management.
|
||||
|
||||
Tests budget allocation, tracking, limit enforcement, and edge cases including:
|
||||
- Zero budget scenarios
|
||||
- Over-budget handling
|
||||
- Budget reset behavior
|
||||
- In-memory fallback when DB is unavailable
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from datetime import UTC, date, datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.models.budget import (
|
||||
BudgetTracker,
|
||||
SpendRecord,
|
||||
estimate_cost_usd,
|
||||
get_budget_tracker,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.unit
|
||||
|
||||
|
||||
# ── Test SpendRecord dataclass ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSpendRecord:
|
||||
"""Tests for the SpendRecord dataclass."""
|
||||
|
||||
def test_spend_record_creation(self):
|
||||
"""Test creating a SpendRecord with all fields."""
|
||||
ts = time.time()
|
||||
record = SpendRecord(
|
||||
ts=ts,
|
||||
provider="anthropic",
|
||||
model="claude-haiku-4-5",
|
||||
tokens_in=100,
|
||||
tokens_out=200,
|
||||
cost_usd=0.001,
|
||||
tier="cloud",
|
||||
)
|
||||
assert record.ts == ts
|
||||
assert record.provider == "anthropic"
|
||||
assert record.model == "claude-haiku-4-5"
|
||||
assert record.tokens_in == 100
|
||||
assert record.tokens_out == 200
|
||||
assert record.cost_usd == 0.001
|
||||
assert record.tier == "cloud"
|
||||
|
||||
def test_spend_record_with_zero_tokens(self):
|
||||
"""Test SpendRecord with zero tokens."""
|
||||
ts = time.time()
|
||||
record = SpendRecord(ts=ts, provider="openai", model="gpt-4o", tokens_in=0, tokens_out=0, cost_usd=0.0, tier="cloud")
|
||||
assert record.tokens_in == 0
|
||||
assert record.tokens_out == 0
|
||||
|
||||
|
||||
# ── Test estimate_cost_usd function ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestEstimateCostUsd:
|
||||
"""Tests for the estimate_cost_usd function."""
|
||||
|
||||
def test_haiku_cheaper_than_sonnet(self):
|
||||
"""Haiku should be cheaper than Sonnet for same tokens."""
|
||||
haiku_cost = estimate_cost_usd("claude-haiku-4-5", 1000, 1000)
|
||||
sonnet_cost = estimate_cost_usd("claude-sonnet-4-5", 1000, 1000)
|
||||
assert haiku_cost < sonnet_cost
|
||||
|
||||
def test_zero_tokens_is_zero_cost(self):
|
||||
"""Zero tokens should result in zero cost."""
|
||||
assert estimate_cost_usd("gpt-4o", 0, 0) == 0.0
|
||||
|
||||
def test_only_input_tokens(self):
|
||||
"""Cost calculation with only input tokens."""
|
||||
cost = estimate_cost_usd("gpt-4o", 1000, 0)
|
||||
expected = (1000 * 0.0025) / 1000.0 # $0.0025 per 1K input tokens
|
||||
assert cost == pytest.approx(expected)
|
||||
|
||||
def test_only_output_tokens(self):
|
||||
"""Cost calculation with only output tokens."""
|
||||
cost = estimate_cost_usd("gpt-4o", 0, 1000)
|
||||
expected = (1000 * 0.01) / 1000.0 # $0.01 per 1K output tokens
|
||||
assert cost == pytest.approx(expected)
|
||||
|
||||
def test_unknown_model_uses_default(self):
|
||||
"""Unknown model should use conservative default cost."""
|
||||
cost = estimate_cost_usd("some-unknown-model-xyz", 1000, 1000)
|
||||
assert cost > 0 # Uses conservative default, not zero
|
||||
# Default is 0.003 input, 0.015 output per 1K
|
||||
expected = (1000 * 0.003 + 1000 * 0.015) / 1000.0
|
||||
assert cost == pytest.approx(expected)
|
||||
|
||||
def test_versioned_model_name_matches(self):
|
||||
"""Versioned model names should match base model rates."""
|
||||
cost1 = estimate_cost_usd("claude-haiku-4-5-20251001", 1000, 0)
|
||||
cost2 = estimate_cost_usd("claude-haiku-4-5", 1000, 0)
|
||||
assert cost1 == cost2
|
||||
|
||||
def test_gpt4o_mini_cheaper_than_gpt4o(self):
|
||||
"""GPT-4o mini should be cheaper than GPT-4o."""
|
||||
mini = estimate_cost_usd("gpt-4o-mini", 1000, 1000)
|
||||
full = estimate_cost_usd("gpt-4o", 1000, 1000)
|
||||
assert mini < full
|
||||
|
||||
def test_opus_most_expensive_claude(self):
|
||||
"""Opus should be the most expensive Claude model."""
|
||||
opus = estimate_cost_usd("claude-opus-4-5", 1000, 1000)
|
||||
sonnet = estimate_cost_usd("claude-sonnet-4-5", 1000, 1000)
|
||||
haiku = estimate_cost_usd("claude-haiku-4-5", 1000, 1000)
|
||||
assert opus > sonnet > haiku
|
||||
|
||||
def test_grok_variants(self):
|
||||
"""Test Grok model cost estimation."""
|
||||
cost = estimate_cost_usd("grok-3", 1000, 1000)
|
||||
assert cost > 0
|
||||
cost_fast = estimate_cost_usd("grok-3-fast", 1000, 1000)
|
||||
assert cost_fast > 0
|
||||
|
||||
def test_case_insensitive_matching(self):
|
||||
"""Model name matching should be case insensitive."""
|
||||
cost_lower = estimate_cost_usd("claude-haiku-4-5", 1000, 0)
|
||||
cost_upper = estimate_cost_usd("CLAUDE-HAIKU-4-5", 1000, 0)
|
||||
cost_mixed = estimate_cost_usd("Claude-Haiku-4-5", 1000, 0)
|
||||
assert cost_lower == cost_upper == cost_mixed
|
||||
|
||||
def test_returns_float(self):
|
||||
"""Function should always return a float."""
|
||||
assert isinstance(estimate_cost_usd("haiku", 100, 200), float)
|
||||
assert isinstance(estimate_cost_usd("unknown-model", 100, 200), float)
|
||||
assert isinstance(estimate_cost_usd("haiku", 0, 0), float)
|
||||
|
||||
|
||||
# ── Test BudgetTracker initialization ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBudgetTrackerInit:
|
||||
"""Tests for BudgetTracker initialization."""
|
||||
|
||||
def test_creates_with_memory_db(self):
|
||||
"""Tracker should initialize with in-memory database."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
assert tracker._db_ok is True
|
||||
|
||||
def test_in_memory_fallback_empty_on_creation(self):
|
||||
"""In-memory fallback should start empty."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
assert tracker._in_memory == []
|
||||
|
||||
def test_custom_db_path(self, tmp_path):
|
||||
"""Tracker should use custom database path."""
|
||||
db_file = tmp_path / "custom_budget.db"
|
||||
tracker = BudgetTracker(db_path=str(db_file))
|
||||
assert tracker._db_ok is True
|
||||
assert tracker._db_path == str(db_file)
|
||||
assert db_file.exists()
|
||||
|
||||
def test_db_path_directory_creation(self, tmp_path):
|
||||
"""Tracker should create parent directories if needed."""
|
||||
db_file = tmp_path / "nested" / "dirs" / "budget.db"
|
||||
tracker = BudgetTracker(db_path=str(db_file))
|
||||
assert tracker._db_ok is True
|
||||
assert db_file.parent.exists()
|
||||
|
||||
def test_invalid_db_path_fallback(self):
|
||||
"""Tracker should fallback to in-memory on invalid path."""
|
||||
# Use a path that cannot be created (e.g., permission denied simulation)
|
||||
tracker = BudgetTracker.__new__(BudgetTracker)
|
||||
tracker._db_path = "/nonexistent/invalid/path/budget.db"
|
||||
tracker._lock = threading.Lock()
|
||||
tracker._in_memory = []
|
||||
tracker._db_ok = False
|
||||
# Should still work with in-memory fallback
|
||||
cost = tracker.record_spend("test", "model", cost_usd=0.01)
|
||||
assert cost == 0.01
|
||||
|
||||
|
||||
# ── Test BudgetTracker record_spend ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBudgetTrackerRecordSpend:
|
||||
"""Tests for recording spend events."""
|
||||
|
||||
def test_record_spend_returns_cost(self):
|
||||
"""record_spend should return the calculated cost."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
cost = tracker.record_spend("anthropic", "claude-haiku-4-5", 100, 200)
|
||||
assert cost > 0
|
||||
|
||||
def test_record_spend_explicit_cost(self):
|
||||
"""record_spend should use explicit cost when provided."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
cost = tracker.record_spend("anthropic", "model", cost_usd=1.23)
|
||||
assert cost == pytest.approx(1.23)
|
||||
|
||||
def test_record_spend_accumulates(self):
|
||||
"""Multiple spend records should accumulate correctly."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker.record_spend("openai", "gpt-4o", cost_usd=0.01)
|
||||
tracker.record_spend("openai", "gpt-4o", cost_usd=0.02)
|
||||
assert tracker.get_daily_spend() == pytest.approx(0.03, abs=1e-9)
|
||||
|
||||
def test_record_spend_with_tier_label(self):
|
||||
"""record_spend should accept custom tier labels."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
cost = tracker.record_spend("anthropic", "haiku", tier="cloud_api")
|
||||
assert cost >= 0
|
||||
|
||||
def test_record_spend_with_provider(self):
|
||||
"""record_spend should track provider correctly."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker.record_spend("openai", "gpt-4o", cost_usd=0.01)
|
||||
tracker.record_spend("anthropic", "claude-haiku", cost_usd=0.02)
|
||||
assert tracker.get_daily_spend() == pytest.approx(0.03, abs=1e-9)
|
||||
|
||||
def test_record_zero_cost(self):
|
||||
"""Recording zero cost should work correctly."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
cost = tracker.record_spend("test", "model", cost_usd=0.0)
|
||||
assert cost == 0.0
|
||||
assert tracker.get_daily_spend() == 0.0
|
||||
|
||||
def test_record_negative_cost(self):
|
||||
"""Recording negative cost (refund) should work."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
cost = tracker.record_spend("test", "model", cost_usd=-0.50)
|
||||
assert cost == -0.50
|
||||
assert tracker.get_daily_spend() == -0.50
|
||||
|
||||
|
||||
# ── Test BudgetTracker daily/monthly spend queries ────────────────────────────
|
||||
|
||||
|
||||
class TestBudgetTrackerSpendQueries:
|
||||
"""Tests for daily and monthly spend queries."""
|
||||
|
||||
def test_monthly_spend_includes_daily(self):
|
||||
"""Monthly spend should be >= daily spend."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker.record_spend("anthropic", "haiku", cost_usd=5.00)
|
||||
assert tracker.get_monthly_spend() >= tracker.get_daily_spend()
|
||||
|
||||
def test_get_daily_spend_empty(self):
|
||||
"""Daily spend should be zero when no records."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
assert tracker.get_daily_spend() == 0.0
|
||||
|
||||
def test_get_monthly_spend_empty(self):
|
||||
"""Monthly spend should be zero when no records."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
assert tracker.get_monthly_spend() == 0.0
|
||||
|
||||
def test_daily_spend_isolation(self):
|
||||
"""Daily spend should only include today's records, not old ones."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
# Force use of in-memory fallback
|
||||
tracker._db_ok = False
|
||||
|
||||
# Add record for today
|
||||
today_ts = datetime.combine(date.today(), datetime.min.time(), tzinfo=UTC).timestamp()
|
||||
tracker._in_memory.append(
|
||||
SpendRecord(today_ts + 3600, "test", "model", 0, 0, 1.0, "cloud")
|
||||
)
|
||||
|
||||
# Add old record (2 days ago)
|
||||
old_ts = (datetime.now(UTC) - timedelta(days=2)).timestamp()
|
||||
tracker._in_memory.append(
|
||||
SpendRecord(old_ts, "test", "old_model", 0, 0, 2.0, "cloud")
|
||||
)
|
||||
|
||||
# Daily should only include today's 1.0
|
||||
assert tracker.get_daily_spend() == pytest.approx(1.0, abs=1e-9)
|
||||
# Monthly should include both (both are in current month)
|
||||
assert tracker.get_monthly_spend() == pytest.approx(3.0, abs=1e-9)
|
||||
|
||||
|
||||
# ── Test BudgetTracker cloud_allowed ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBudgetTrackerCloudAllowed:
|
||||
"""Tests for cloud budget limit enforcement."""
|
||||
|
||||
def test_allowed_when_no_spend(self):
|
||||
"""Cloud should be allowed when no spend recorded."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
assert tracker.cloud_allowed() is True
|
||||
|
||||
def test_blocked_when_daily_limit_exceeded(self):
|
||||
"""Cloud should be blocked when daily limit exceeded."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker.record_spend("anthropic", "haiku", cost_usd=999.0)
|
||||
# With default daily limit of 5.0, 999 should block
|
||||
assert tracker.cloud_allowed() is False
|
||||
|
||||
def test_allowed_when_daily_limit_zero(self):
|
||||
"""Cloud should be allowed when daily limit is 0 (disabled)."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker.record_spend("anthropic", "haiku", cost_usd=999.0)
|
||||
with patch("infrastructure.models.budget.settings") as mock_settings:
|
||||
mock_settings.tier_cloud_daily_budget_usd = 0 # disabled
|
||||
mock_settings.tier_cloud_monthly_budget_usd = 0 # disabled
|
||||
assert tracker.cloud_allowed() is True
|
||||
|
||||
def test_blocked_when_monthly_limit_exceeded(self):
|
||||
"""Cloud should be blocked when monthly limit exceeded."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker.record_spend("anthropic", "haiku", cost_usd=999.0)
|
||||
with patch("infrastructure.models.budget.settings") as mock_settings:
|
||||
mock_settings.tier_cloud_daily_budget_usd = 0 # daily disabled
|
||||
mock_settings.tier_cloud_monthly_budget_usd = 10.0
|
||||
assert tracker.cloud_allowed() is False
|
||||
|
||||
def test_allowed_at_exact_daily_limit(self):
|
||||
"""Cloud should be allowed when exactly at daily limit."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
with patch("infrastructure.models.budget.settings") as mock_settings:
|
||||
mock_settings.tier_cloud_daily_budget_usd = 5.0
|
||||
mock_settings.tier_cloud_monthly_budget_usd = 0
|
||||
# Record exactly at limit
|
||||
tracker.record_spend("test", "model", cost_usd=5.0)
|
||||
# At exactly the limit, it should return False (blocked)
|
||||
# because spend >= limit
|
||||
assert tracker.cloud_allowed() is False
|
||||
|
||||
def test_allowed_below_daily_limit(self):
|
||||
"""Cloud should be allowed when below daily limit."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
with patch("infrastructure.models.budget.settings") as mock_settings:
|
||||
mock_settings.tier_cloud_daily_budget_usd = 5.0
|
||||
mock_settings.tier_cloud_monthly_budget_usd = 0
|
||||
tracker.record_spend("test", "model", cost_usd=4.99)
|
||||
assert tracker.cloud_allowed() is True
|
||||
|
||||
def test_zero_budget_blocks_all(self):
|
||||
"""Zero budget should block all cloud usage."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
with patch("infrastructure.models.budget.settings") as mock_settings:
|
||||
mock_settings.tier_cloud_daily_budget_usd = 0.01 # Very small budget
|
||||
mock_settings.tier_cloud_monthly_budget_usd = 0
|
||||
tracker.record_spend("test", "model", cost_usd=0.02)
|
||||
# Over the tiny budget, should be blocked
|
||||
assert tracker.cloud_allowed() is False
|
||||
|
||||
def test_both_limits_checked(self):
|
||||
"""Both daily and monthly limits should be checked."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
with patch("infrastructure.models.budget.settings") as mock_settings:
|
||||
mock_settings.tier_cloud_daily_budget_usd = 100.0
|
||||
mock_settings.tier_cloud_monthly_budget_usd = 10.0
|
||||
tracker.record_spend("test", "model", cost_usd=15.0)
|
||||
# Under daily but over monthly
|
||||
assert tracker.cloud_allowed() is False
|
||||
|
||||
|
||||
# ── Test BudgetTracker summary ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBudgetTrackerSummary:
|
||||
"""Tests for budget summary functionality."""
|
||||
|
||||
def test_summary_keys_present(self):
|
||||
"""Summary should contain all expected keys."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
summary = tracker.get_summary()
|
||||
assert "daily_usd" in summary
|
||||
assert "monthly_usd" in summary
|
||||
assert "daily_limit_usd" in summary
|
||||
assert "monthly_limit_usd" in summary
|
||||
assert "daily_ok" in summary
|
||||
assert "monthly_ok" in summary
|
||||
|
||||
def test_summary_daily_ok_true_on_empty(self):
|
||||
"""daily_ok and monthly_ok should be True when empty."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
summary = tracker.get_summary()
|
||||
assert summary["daily_ok"] is True
|
||||
assert summary["monthly_ok"] is True
|
||||
|
||||
def test_summary_daily_ok_false_when_exceeded(self):
|
||||
"""daily_ok should be False when daily limit exceeded."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker.record_spend("openai", "gpt-4o", cost_usd=999.0)
|
||||
summary = tracker.get_summary()
|
||||
assert summary["daily_ok"] is False
|
||||
|
||||
def test_summary_monthly_ok_false_when_exceeded(self):
|
||||
"""monthly_ok should be False when monthly limit exceeded."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
with patch("infrastructure.models.budget.settings") as mock_settings:
|
||||
mock_settings.tier_cloud_daily_budget_usd = 0
|
||||
mock_settings.tier_cloud_monthly_budget_usd = 10.0
|
||||
tracker.record_spend("openai", "gpt-4o", cost_usd=15.0)
|
||||
summary = tracker.get_summary()
|
||||
assert summary["monthly_ok"] is False
|
||||
|
||||
def test_summary_values_rounded(self):
|
||||
"""Summary values should be rounded appropriately."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker.record_spend("test", "model", cost_usd=1.123456789)
|
||||
summary = tracker.get_summary()
|
||||
# daily_usd should be rounded to 6 decimal places
|
||||
assert summary["daily_usd"] == 1.123457
|
||||
|
||||
def test_summary_with_disabled_limits(self):
|
||||
"""Summary should handle disabled limits (0)."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
with patch("infrastructure.models.budget.settings") as mock_settings:
|
||||
mock_settings.tier_cloud_daily_budget_usd = 0
|
||||
mock_settings.tier_cloud_monthly_budget_usd = 0
|
||||
tracker.record_spend("test", "model", cost_usd=100.0)
|
||||
summary = tracker.get_summary()
|
||||
assert summary["daily_limit_usd"] == 0
|
||||
assert summary["monthly_limit_usd"] == 0
|
||||
assert summary["daily_ok"] is True
|
||||
assert summary["monthly_ok"] is True
|
||||
|
||||
|
||||
# ── Test BudgetTracker in-memory fallback ─────────────────────────────────────
|
||||
|
||||
|
||||
class TestBudgetTrackerInMemoryFallback:
|
||||
"""Tests for in-memory fallback when DB is unavailable."""
|
||||
|
||||
def test_in_memory_records_persisted(self):
|
||||
"""Records should be stored in memory when DB is unavailable."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
# Force DB to appear unavailable
|
||||
tracker._db_ok = False
|
||||
tracker.record_spend("test", "model", cost_usd=0.01)
|
||||
assert len(tracker._in_memory) == 1
|
||||
assert tracker._in_memory[0].cost_usd == 0.01
|
||||
|
||||
def test_in_memory_query_spend(self):
|
||||
"""Query spend should work with in-memory fallback."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker._db_ok = False
|
||||
tracker.record_spend("test", "model", cost_usd=0.01)
|
||||
# Query should work from in-memory
|
||||
since_ts = (datetime.now(UTC) - timedelta(hours=1)).timestamp()
|
||||
result = tracker._query_spend(since_ts)
|
||||
assert result == 0.01
|
||||
|
||||
def test_in_memory_older_records_not_counted(self):
|
||||
"""In-memory records older than since_ts should not be counted."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker._db_ok = False
|
||||
old_ts = (datetime.now(UTC) - timedelta(days=2)).timestamp()
|
||||
tracker._in_memory.append(
|
||||
SpendRecord(old_ts, "test", "model", 0, 0, 1.0, "cloud")
|
||||
)
|
||||
# Query for records in last day
|
||||
since_ts = (datetime.now(UTC) - timedelta(days=1)).timestamp()
|
||||
result = tracker._query_spend(since_ts)
|
||||
assert result == 0.0
|
||||
|
||||
|
||||
# ── Test BudgetTracker thread safety ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBudgetTrackerThreadSafety:
|
||||
"""Tests for thread-safe operations."""
|
||||
|
||||
def test_concurrent_record_spend(self):
|
||||
"""Multiple threads should safely record spend concurrently."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
results = []
|
||||
errors = []
|
||||
|
||||
def record_spends():
|
||||
try:
|
||||
for _ in range(10):
|
||||
cost = tracker.record_spend("test", "model", cost_usd=0.01)
|
||||
results.append(cost)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
threads = [threading.Thread(target=record_spends) for _ in range(5)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert len(errors) == 0
|
||||
assert len(results) == 50
|
||||
assert tracker.get_daily_spend() == pytest.approx(0.50, abs=1e-9)
|
||||
|
||||
|
||||
# ── Test BudgetTracker edge cases ─────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBudgetTrackerEdgeCases:
|
||||
"""Tests for edge cases and boundary conditions."""
|
||||
|
||||
def test_very_small_cost(self):
|
||||
"""Tracker should handle very small costs."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker.record_spend("test", "model", cost_usd=0.000001)
|
||||
assert tracker.get_daily_spend() == pytest.approx(0.000001, abs=1e-9)
|
||||
|
||||
def test_very_large_cost(self):
|
||||
"""Tracker should handle very large costs."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
tracker.record_spend("test", "model", cost_usd=1_000_000.0)
|
||||
assert tracker.get_daily_spend() == pytest.approx(1_000_000.0, abs=1e-9)
|
||||
|
||||
def test_many_records(self):
|
||||
"""Tracker should handle many records efficiently."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
for i in range(100):
|
||||
tracker.record_spend(f"provider_{i}", f"model_{i}", cost_usd=0.01)
|
||||
assert tracker.get_daily_spend() == pytest.approx(1.0, abs=1e-9)
|
||||
|
||||
def test_empty_provider_name(self):
|
||||
"""Tracker should handle empty provider name."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
cost = tracker.record_spend("", "model", cost_usd=0.01)
|
||||
assert cost == 0.01
|
||||
|
||||
def test_empty_model_name(self):
|
||||
"""Tracker should handle empty model name."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
cost = tracker.record_spend("provider", "", cost_usd=0.01)
|
||||
assert cost == 0.01
|
||||
|
||||
|
||||
# ── Test get_budget_tracker singleton ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGetBudgetTrackerSingleton:
|
||||
"""Tests for the module-level BudgetTracker singleton."""
|
||||
|
||||
def test_returns_budget_tracker(self):
|
||||
"""Singleton should return a BudgetTracker instance."""
|
||||
import infrastructure.models.budget as bmod
|
||||
|
||||
bmod._budget_tracker = None
|
||||
tracker = get_budget_tracker()
|
||||
assert isinstance(tracker, BudgetTracker)
|
||||
|
||||
def test_returns_same_instance(self):
|
||||
"""Singleton should return the same instance."""
|
||||
import infrastructure.models.budget as bmod
|
||||
|
||||
bmod._budget_tracker = None
|
||||
t1 = get_budget_tracker()
|
||||
t2 = get_budget_tracker()
|
||||
assert t1 is t2
|
||||
|
||||
def test_singleton_persists_state(self):
|
||||
"""Singleton should persist state across calls."""
|
||||
import infrastructure.models.budget as bmod
|
||||
|
||||
bmod._budget_tracker = None
|
||||
tracker1 = get_budget_tracker()
|
||||
# Record spend
|
||||
tracker1.record_spend("test", "model", cost_usd=1.0)
|
||||
# Get singleton again
|
||||
tracker2 = get_budget_tracker()
|
||||
assert tracker1 is tracker2
|
||||
|
||||
|
||||
# ── Test BudgetTracker with mocked settings ───────────────────────────────────
|
||||
|
||||
|
||||
class TestBudgetTrackerWithMockedSettings:
|
||||
"""Tests using mocked settings for different scenarios."""
|
||||
|
||||
def test_high_daily_limit(self):
|
||||
"""Test with high daily limit."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
with patch("infrastructure.models.budget.settings") as mock_settings:
|
||||
mock_settings.tier_cloud_daily_budget_usd = 1000.0
|
||||
mock_settings.tier_cloud_monthly_budget_usd = 10000.0
|
||||
tracker.record_spend("test", "model", cost_usd=500.0)
|
||||
assert tracker.cloud_allowed() is True
|
||||
|
||||
def test_low_daily_limit(self):
|
||||
"""Test with low daily limit."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
with patch("infrastructure.models.budget.settings") as mock_settings:
|
||||
mock_settings.tier_cloud_daily_budget_usd = 1.0
|
||||
mock_settings.tier_cloud_monthly_budget_usd = 100.0
|
||||
tracker.record_spend("test", "model", cost_usd=2.0)
|
||||
assert tracker.cloud_allowed() is False
|
||||
|
||||
def test_only_monthly_limit_enabled(self):
|
||||
"""Test with only monthly limit enabled."""
|
||||
tracker = BudgetTracker(db_path=":memory:")
|
||||
with patch("infrastructure.models.budget.settings") as mock_settings:
|
||||
mock_settings.tier_cloud_daily_budget_usd = 0 # Disabled
|
||||
mock_settings.tier_cloud_monthly_budget_usd = 50.0
|
||||
tracker.record_spend("test", "model", cost_usd=30.0)
|
||||
assert tracker.cloud_allowed() is True
|
||||
tracker.record_spend("test", "model", cost_usd=25.0)
|
||||
assert tracker.cloud_allowed() is False
|
||||
@@ -287,148 +287,6 @@ class TestJotNote:
|
||||
assert "body is empty" in jot_note("title", " ")
|
||||
|
||||
|
||||
class TestHotMemoryTimestamp:
|
||||
"""Tests for Working RAM auto-updating timestamp (issue #10)."""
|
||||
|
||||
def test_read_includes_last_updated_when_facts_exist(self, tmp_path):
|
||||
"""HotMemory.read() includes a 'Last updated' timestamp when DB has facts."""
|
||||
db_path = tmp_path / "memory.db"
|
||||
|
||||
with (
|
||||
patch("timmy.memory.db.DB_PATH", db_path),
|
||||
patch("timmy.memory.crud.get_connection") as mock_conn,
|
||||
):
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
|
||||
real_conn = sqlite3.connect(str(db_path))
|
||||
real_conn.row_factory = sqlite3.Row
|
||||
real_conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS memories (
|
||||
id TEXT PRIMARY KEY,
|
||||
content TEXT NOT NULL,
|
||||
memory_type TEXT NOT NULL DEFAULT 'fact',
|
||||
source TEXT NOT NULL DEFAULT 'agent',
|
||||
embedding TEXT, metadata TEXT, source_hash TEXT,
|
||||
agent_id TEXT, task_id TEXT, session_id TEXT,
|
||||
confidence REAL NOT NULL DEFAULT 0.8,
|
||||
tags TEXT NOT NULL DEFAULT '[]',
|
||||
created_at TEXT NOT NULL,
|
||||
last_accessed TEXT,
|
||||
access_count INTEGER NOT NULL DEFAULT 0
|
||||
)
|
||||
""")
|
||||
real_conn.execute(
|
||||
"INSERT INTO memories (id, content, memory_type, source, created_at) "
|
||||
"VALUES ('1', 'User prefers dark mode', 'fact', 'system', '2026-03-20T10:00:00+00:00')"
|
||||
)
|
||||
real_conn.commit()
|
||||
|
||||
@contextmanager
|
||||
def fake_get_connection():
|
||||
yield real_conn
|
||||
|
||||
mock_conn.side_effect = fake_get_connection
|
||||
|
||||
hot = HotMemory()
|
||||
result = hot.read()
|
||||
|
||||
assert "> Last updated:" in result
|
||||
assert "2026-03-20" in result
|
||||
assert "User prefers dark mode" in result
|
||||
|
||||
def test_read_timestamp_reflects_most_recent_memory(self, tmp_path):
    """The timestamp in HotMemory.read() matches the latest memory's created_at.

    Inserts one older and one newer fact and asserts the rendered output's
    'Last updated' line reflects the newer ``created_at`` date.
    """
    import sqlite3
    from contextlib import closing, contextmanager

    db_path = tmp_path / "memory.db"

    with patch("timmy.memory.crud.get_connection") as mock_conn:
        # closing() fixes the original leak: the sqlite3 connection was
        # never closed after the test body finished.
        with closing(sqlite3.connect(str(db_path))) as real_conn:
            real_conn.row_factory = sqlite3.Row
            real_conn.execute("""
                CREATE TABLE IF NOT EXISTS memories (
                    id TEXT PRIMARY KEY,
                    content TEXT NOT NULL,
                    memory_type TEXT NOT NULL DEFAULT 'fact',
                    source TEXT NOT NULL DEFAULT 'agent',
                    embedding TEXT, metadata TEXT, source_hash TEXT,
                    agent_id TEXT, task_id TEXT, session_id TEXT,
                    confidence REAL NOT NULL DEFAULT 0.8,
                    tags TEXT NOT NULL DEFAULT '[]',
                    created_at TEXT NOT NULL,
                    last_accessed TEXT,
                    access_count INTEGER NOT NULL DEFAULT 0
                )
            """)
            # Older fact
            real_conn.execute(
                "INSERT INTO memories (id, content, memory_type, source, created_at) "
                "VALUES ('1', 'old fact', 'fact', 'system', '2026-03-15T08:00:00+00:00')"
            )
            # Newer fact — this should be reflected in the timestamp
            real_conn.execute(
                "INSERT INTO memories (id, content, memory_type, source, created_at) "
                "VALUES ('2', 'new fact', 'fact', 'system', '2026-03-23T14:30:00+00:00')"
            )
            real_conn.commit()

            # Yield the shared connection without closing it so read() can
            # enter the context manager more than once if it needs to.
            @contextmanager
            def fake_get_connection():
                yield real_conn

            mock_conn.side_effect = fake_get_connection

            hot = HotMemory()
            result = hot.read()

        assert "2026-03-23" in result
        assert "> Last updated:" in result
def test_read_falls_back_to_file_when_db_empty(self, tmp_path):
    """HotMemory.read() falls back to MEMORY.md when DB has no facts or reflections.

    The DB exists but holds zero rows, so read() should return the MEMORY.md
    file contents verbatim — with no generated 'Last updated' line.
    """
    import sqlite3
    from contextlib import closing, contextmanager

    mem_file = tmp_path / "MEMORY.md"
    mem_file.write_text("# Timmy Hot Memory\n\n## Current Status\n\nOperational\n")

    with patch("timmy.memory.crud.get_connection") as mock_conn:
        db_path = tmp_path / "empty.db"
        # closing() fixes the original leak: the sqlite3 connection was
        # never closed after the test body finished.
        with closing(sqlite3.connect(str(db_path))) as real_conn:
            real_conn.row_factory = sqlite3.Row
            real_conn.execute("""
                CREATE TABLE IF NOT EXISTS memories (
                    id TEXT PRIMARY KEY,
                    content TEXT NOT NULL,
                    memory_type TEXT NOT NULL DEFAULT 'fact',
                    source TEXT NOT NULL DEFAULT 'agent',
                    embedding TEXT, metadata TEXT, source_hash TEXT,
                    agent_id TEXT, task_id TEXT, session_id TEXT,
                    confidence REAL NOT NULL DEFAULT 0.8,
                    tags TEXT NOT NULL DEFAULT '[]',
                    created_at TEXT NOT NULL,
                    last_accessed TEXT,
                    access_count INTEGER NOT NULL DEFAULT 0
                )
            """)
            real_conn.commit()

            # Yield the shared (empty) connection without closing it.
            @contextmanager
            def fake_get_connection():
                yield real_conn

            mock_conn.side_effect = fake_get_connection

            hot = HotMemory()
            # Point the fallback file path at our fixture MEMORY.md.
            hot.path = mem_file
            result = hot.read()

        assert "Operational" in result
        assert "> Last updated:" not in result
||||
|
||||
class TestLogDecision:
|
||||
"""Tests for log_decision() artifact tool."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user