Compare commits

..

12 Commits

Author SHA1 Message Date
kimi
d5361a0385 fix: remove AirLLM config settings from config.py
Remove `airllm` from timmy_model_backend Literal type and delete the
airllm_model_size field plus associated comments. Replace the one
settings.airllm_model_size reference in agent.py with a hardcoded
default, and clean up mock assignments in tests.

Fixes #473

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-19 15:26:10 -04:00
3df526f6ef [loop-cycle-2] feat: hot-reload providers.yaml without restart (#458) (#470) 2026-03-19 15:11:40 -04:00
50aaf60db2 [loop-cycle-2] fix: strip CORS wildcards in production (#462) (#469) 2026-03-19 15:05:27 -04:00
a751be3038 fix: default CORS origins to localhost instead of wildcard (#467)
Co-authored-by: Kimi Agent <kimi@timmy.local>
Co-committed-by: Kimi Agent <kimi@timmy.local>
2026-03-19 14:57:36 -04:00
92594ea588 [loop-cycle] feat: implement source distinction in system prompts (#463) (#464) 2026-03-19 14:49:31 -04:00
12582ab593 fix: stabilize flaky test_uses_model_when_available (#456)
Co-authored-by: Kimi Agent <kimi@timmy.local>
Co-committed-by: Kimi Agent <kimi@timmy.local>
2026-03-19 14:39:33 -04:00
72c3a0a989 fix: integration tests for agentic loop WS broadcasts (#452)
Co-authored-by: Kimi Agent <kimi@timmy.local>
Co-committed-by: Kimi Agent <kimi@timmy.local>
2026-03-19 14:30:00 -04:00
de089cec7f [loop-cycle-524] fix: remove numpy test dependency in test_memory_embeddings (#451) 2026-03-19 14:22:13 -04:00
3590c1689e fix: make _get_loop_agent singleton thread-safe (#449)
Co-authored-by: Kimi Agent <kimi@timmy.local>
Co-committed-by: Kimi Agent <kimi@timmy.local>
2026-03-19 14:18:27 -04:00
2161c32ae8 fix: add unit tests for agentic_loop.py (#421) (#447)
Co-authored-by: Kimi Agent <kimi@timmy.local>
Co-committed-by: Kimi Agent <kimi@timmy.local>
2026-03-19 14:13:50 -04:00
98b1142820 [loop-cycle-522] test: add unit tests for agentic_loop.py (#421) (#441) 2026-03-19 14:10:16 -04:00
1d79a36bd8 fix: add unit tests for memory/embeddings.py (#437)
Co-authored-by: Kimi Agent <kimi@timmy.local>
Co-committed-by: Kimi Agent <kimi@timmy.local>
2026-03-19 11:12:46 -04:00
13 changed files with 1347 additions and 84 deletions

View File

@@ -64,17 +64,10 @@ class Settings(BaseSettings):
# Seconds to wait for user confirmation before auto-rejecting.
discord_confirm_timeout: int = 120
# ── AirLLM / backend selection ───────────────────────────────────────────
# ── Backend selection ────────────────────────────────────────────────────
# "ollama" — always use Ollama (default, safe everywhere)
# "airllm" — always use AirLLM (requires pip install ".[bigbrain]")
# "auto" — use AirLLM on Apple Silicon if airllm is installed,
# fall back to Ollama otherwise
timmy_model_backend: Literal["ollama", "airllm", "grok", "claude", "auto"] = "ollama"
# AirLLM model size when backend is airllm or auto.
# Larger = smarter, but needs more RAM / disk.
# 8b ~16 GB | 70b ~140 GB | 405b ~810 GB
airllm_model_size: Literal["8b", "70b", "405b"] = "70b"
# "auto" — auto-detect best available backend
timmy_model_backend: Literal["ollama", "grok", "claude", "auto"] = "ollama"
# ── Grok (xAI) — opt-in premium cloud backend ────────────────────────
# Grok is a premium augmentation layer — local-first ethos preserved.
@@ -138,7 +131,12 @@ class Settings(BaseSettings):
# CORS allowed origins for the web chat interface (Gitea Pages, etc.)
# Set CORS_ORIGINS as a comma-separated list, e.g. "http://localhost:3000,https://example.com"
cors_origins: list[str] = ["*"]
cors_origins: list[str] = [
"http://localhost:3000",
"http://localhost:8000",
"http://127.0.0.1:3000",
"http://127.0.0.1:8000",
]
# Trusted hosts for the Host header check (TrustedHostMiddleware).
# Set TRUSTED_HOSTS as a comma-separated list. Wildcards supported (e.g. "*.ts.net").

View File

@@ -484,15 +484,14 @@ app = FastAPI(
def _get_cors_origins() -> list[str]:
"""Get CORS origins from settings, with sensible defaults."""
"""Get CORS origins from settings, rejecting wildcards in production."""
origins = settings.cors_origins
if settings.debug and origins == ["*"]:
return [
"http://localhost:3000",
"http://localhost:8000",
"http://127.0.0.1:3000",
"http://127.0.0.1:8000",
]
if "*" in origins and not settings.debug:
logger.warning(
"Wildcard '*' in CORS_ORIGINS stripped in production — "
"set explicit origins via CORS_ORIGINS env var"
)
origins = [o for o in origins if o != "*"]
return origins

View File

@@ -183,6 +183,22 @@ async def run_health_check(
}
@router.post("/reload")
async def reload_config(
cascade: Annotated[CascadeRouter, Depends(get_cascade_router)],
) -> dict[str, Any]:
"""Hot-reload providers.yaml without restart.
Preserves circuit breaker state and metrics for existing providers.
"""
try:
result = cascade.reload_config()
return {"status": "ok", **result}
except Exception as exc:
logger.error("Config reload failed: %s", exc)
raise HTTPException(status_code=500, detail=f"Reload failed: {exc}") from exc
@router.get("/config")
async def get_config(
cascade: Annotated[CascadeRouter, Depends(get_cascade_router)],

View File

@@ -815,6 +815,66 @@ class CascadeRouter:
provider.status = ProviderStatus.HEALTHY
logger.info("Circuit breaker CLOSED for %s", provider.name)
def reload_config(self) -> dict:
"""Hot-reload providers.yaml, preserving runtime state.
Re-reads the config file, rebuilds the provider list, and
preserves circuit breaker state and metrics for providers
that still exist after reload.
Returns:
Summary dict with added/removed/preserved counts.
"""
# Snapshot current runtime state keyed by provider name
old_state: dict[
str, tuple[ProviderMetrics, CircuitState, float | None, int, ProviderStatus]
] = {}
for p in self.providers:
old_state[p.name] = (
p.metrics,
p.circuit_state,
p.circuit_opened_at,
p.half_open_calls,
p.status,
)
old_names = set(old_state.keys())
# Reload from disk
self.providers = []
self._load_config()
# Restore preserved state
new_names = {p.name for p in self.providers}
preserved = 0
for p in self.providers:
if p.name in old_state:
metrics, circuit, opened_at, half_open, status = old_state[p.name]
p.metrics = metrics
p.circuit_state = circuit
p.circuit_opened_at = opened_at
p.half_open_calls = half_open
p.status = status
preserved += 1
added = new_names - old_names
removed = old_names - new_names
logger.info(
"Config reloaded: %d providers (%d preserved, %d added, %d removed)",
len(self.providers),
preserved,
len(added),
len(removed),
)
return {
"total_providers": len(self.providers),
"preserved": preserved,
"added": sorted(added),
"removed": sorted(removed),
}
def get_metrics(self) -> dict:
"""Get metrics for all providers."""
return {

View File

@@ -220,7 +220,7 @@ def create_timmy(
print_response(message, stream).
"""
resolved = _resolve_backend(backend)
size = model_size or settings.airllm_model_size
size = model_size or "70b"
if resolved == "claude":
from timmy.backends import ClaudeBackend
@@ -300,7 +300,11 @@ def create_timmy(
max_context = 2000 if not use_tools else 8000
if len(memory_context) > max_context:
memory_context = memory_context[:max_context] + "\n... [truncated]"
full_prompt = f"{base_prompt}\n\n## Memory Context\n\n{memory_context}"
full_prompt = (
f"{base_prompt}\n\n"
f"## GROUNDED CONTEXT (verified sources — cite when using)\n\n"
f"{memory_context}"
)
else:
full_prompt = base_prompt
except Exception as exc:

View File

@@ -18,6 +18,7 @@ from __future__ import annotations
import asyncio
import logging
import re
import threading
import time
import uuid
from collections.abc import Callable
@@ -59,6 +60,7 @@ class AgenticResult:
# ---------------------------------------------------------------------------
_loop_agent = None
_loop_agent_lock = threading.Lock()
def _get_loop_agent():
@@ -69,9 +71,11 @@ def _get_loop_agent():
"""
global _loop_agent
if _loop_agent is None:
from timmy.agent import create_timmy
with _loop_agent_lock:
if _loop_agent is None:
from timmy.agent import create_timmy
_loop_agent = create_timmy()
_loop_agent = create_timmy()
return _loop_agent

View File

@@ -23,6 +23,9 @@ Rules:
- Remember what the user tells you during the conversation.
- If you don't know something, say so honestly — never fabricate facts.
- If a request is ambiguous, ask a brief clarifying question before guessing.
- SOURCE DISTINCTION: When answering from memory or retrieved context, cite it.
When answering from your own training, use hedging: "I think", "I believe".
The user must be able to tell grounded claims from pattern-matching.
- Use the user's name if you know it.
- When you state a fact, commit to it.
- NEVER attempt arithmetic in your head. If asked to compute anything, respond:
@@ -78,6 +81,18 @@ HONESTY:
- Never fabricate tool output. Call the tool and wait.
- If a tool errors, report the exact error.
SOURCE DISTINCTION (SOUL requirement — non-negotiable):
- Every claim you make comes from one of two places: a verified source you
can point to, or your own pattern-matching. The user must be able to tell
which is which.
- When your response uses information from GROUNDED CONTEXT (memory, retrieved
documents, tool output), cite it: "From memory:", "According to [source]:".
- When you are generating from your training data alone, signal it naturally:
"I think", "My understanding is", "I believe" — never false certainty.
- If the user asks a factual question and you have no grounded source, say so:
"I don't have a verified source for this — from my training I think..."
- Prefer "I don't know" over a confident-sounding guess. Refusal over fabrication.
MEMORY (three tiers):
- Tier 1: MEMORY.md (hot, always loaded)
- Tier 2: memory/ vault (structured, append-only, date-stamped)

View File

@@ -516,3 +516,183 @@ class TestProviderAvailabilityCheck:
with patch("importlib.util.find_spec", return_value=None):
assert router._check_provider_available(provider) is False
class TestCascadeRouterReload:
"""Test hot-reload of providers.yaml."""
def test_reload_preserves_metrics(self, tmp_path):
"""Test that reload preserves metrics for existing providers."""
config = {
"providers": [
{
"name": "test-openai",
"type": "openai",
"enabled": True,
"priority": 1,
"api_key": "sk-test",
}
],
}
config_path = tmp_path / "providers.yaml"
config_path.write_text(yaml.dump(config))
router = CascadeRouter(config_path=config_path)
assert len(router.providers) == 1
# Simulate some traffic
router._record_success(router.providers[0], 150.0)
router._record_success(router.providers[0], 250.0)
assert router.providers[0].metrics.total_requests == 2
# Reload
result = router.reload_config()
assert result["total_providers"] == 1
assert result["preserved"] == 1
assert result["added"] == []
assert result["removed"] == []
# Metrics survived
assert router.providers[0].metrics.total_requests == 2
assert router.providers[0].metrics.total_latency_ms == 400.0
def test_reload_preserves_circuit_breaker(self, tmp_path):
"""Test that reload preserves circuit breaker state."""
config = {
"cascade": {"circuit_breaker": {"failure_threshold": 2}},
"providers": [
{
"name": "test-openai",
"type": "openai",
"enabled": True,
"priority": 1,
"api_key": "sk-test",
}
],
}
config_path = tmp_path / "providers.yaml"
config_path.write_text(yaml.dump(config))
router = CascadeRouter(config_path=config_path)
# Open circuit breaker
for _ in range(2):
router._record_failure(router.providers[0])
assert router.providers[0].circuit_state == CircuitState.OPEN
# Reload
router.reload_config()
# Circuit breaker state preserved
assert router.providers[0].circuit_state == CircuitState.OPEN
assert router.providers[0].status == ProviderStatus.UNHEALTHY
def test_reload_detects_added_provider(self, tmp_path):
"""Test that reload detects newly added providers."""
config = {
"providers": [
{
"name": "openai-1",
"type": "openai",
"enabled": True,
"priority": 1,
"api_key": "sk-test",
}
],
}
config_path = tmp_path / "providers.yaml"
config_path.write_text(yaml.dump(config))
router = CascadeRouter(config_path=config_path)
assert len(router.providers) == 1
# Add a second provider to config
config["providers"].append(
{
"name": "anthropic-1",
"type": "anthropic",
"enabled": True,
"priority": 2,
"api_key": "sk-ant-test",
}
)
config_path.write_text(yaml.dump(config))
result = router.reload_config()
assert result["total_providers"] == 2
assert result["preserved"] == 1
assert result["added"] == ["anthropic-1"]
assert result["removed"] == []
def test_reload_detects_removed_provider(self, tmp_path):
"""Test that reload detects removed providers."""
config = {
"providers": [
{
"name": "openai-1",
"type": "openai",
"enabled": True,
"priority": 1,
"api_key": "sk-test",
},
{
"name": "anthropic-1",
"type": "anthropic",
"enabled": True,
"priority": 2,
"api_key": "sk-ant-test",
},
],
}
config_path = tmp_path / "providers.yaml"
config_path.write_text(yaml.dump(config))
router = CascadeRouter(config_path=config_path)
assert len(router.providers) == 2
# Remove anthropic
config["providers"] = [config["providers"][0]]
config_path.write_text(yaml.dump(config))
result = router.reload_config()
assert result["total_providers"] == 1
assert result["preserved"] == 1
assert result["removed"] == ["anthropic-1"]
def test_reload_re_sorts_by_priority(self, tmp_path):
"""Test that providers are re-sorted by priority after reload."""
config = {
"providers": [
{
"name": "low-priority",
"type": "openai",
"enabled": True,
"priority": 10,
"api_key": "sk-test",
},
{
"name": "high-priority",
"type": "openai",
"enabled": True,
"priority": 1,
"api_key": "sk-test2",
},
],
}
config_path = tmp_path / "providers.yaml"
config_path.write_text(yaml.dump(config))
router = CascadeRouter(config_path=config_path)
assert router.providers[0].name == "high-priority"
# Swap priorities
config["providers"][0]["priority"] = 1
config["providers"][1]["priority"] = 10
config_path.write_text(yaml.dump(config))
router.reload_config()
assert router.providers[0].name == "low-priority"
assert router.providers[1].name == "high-priority"

View File

@@ -0,0 +1,285 @@
"""Integration tests for agentic loop WebSocket broadcasts.
Verifies that ``run_agentic_loop`` pushes the correct sequence of events
through the real ``ws_manager`` and that connected (mock) WebSocket clients
receive every broadcast with the expected payloads.
"""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from infrastructure.ws_manager.handler import WebSocketManager
from timmy.agentic_loop import run_agentic_loop
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _mock_run(content: str):
m = MagicMock()
m.content = content
return m
def _ws_client() -> AsyncMock:
"""Return a fake WebSocket that records sent messages."""
return AsyncMock()
def _collected_events(ws: AsyncMock) -> list[dict]:
"""Extract parsed JSON events from a mock WebSocket's send_text calls."""
return [json.loads(call.args[0]) for call in ws.send_text.call_args_list]
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestAgenticLoopBroadcastSequence:
"""Events arrive at WS clients in the correct order with expected data."""
@pytest.mark.asyncio
async def test_successful_run_broadcasts_plan_steps_complete(self):
"""A successful 2-step loop emits plan_ready → 2× step_complete → task_complete."""
mgr = WebSocketManager()
ws = _ws_client()
mgr._connections = [ws]
mock_agent = MagicMock()
mock_agent.run = MagicMock(
side_effect=[
_mock_run("1. Gather data\n2. Summarise"),
_mock_run("Gathered 10 records"),
_mock_run("Summary written"),
]
)
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent),
patch("infrastructure.ws_manager.handler.ws_manager", mgr),
):
result = await run_agentic_loop("Gather and summarise", max_steps=2)
assert result.status == "completed"
events = _collected_events(ws)
event_names = [e["event"] for e in events]
assert event_names == [
"agentic.plan_ready",
"agentic.step_complete",
"agentic.step_complete",
"agentic.task_complete",
]
@pytest.mark.asyncio
async def test_plan_ready_payload(self):
"""plan_ready contains task_id, task, steps list, and total count."""
mgr = WebSocketManager()
ws = _ws_client()
mgr._connections = [ws]
mock_agent = MagicMock()
mock_agent.run = MagicMock(
side_effect=[
_mock_run("1. Alpha\n2. Beta"),
_mock_run("Alpha done"),
_mock_run("Beta done"),
]
)
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent),
patch("infrastructure.ws_manager.handler.ws_manager", mgr),
):
result = await run_agentic_loop("Two steps")
plan_event = _collected_events(ws)[0]
assert plan_event["event"] == "agentic.plan_ready"
data = plan_event["data"]
assert data["task_id"] == result.task_id
assert data["task"] == "Two steps"
assert data["steps"] == ["Alpha", "Beta"]
assert data["total"] == 2
@pytest.mark.asyncio
async def test_step_complete_payload(self):
"""step_complete carries step number, total, description, and result."""
mgr = WebSocketManager()
ws = _ws_client()
mgr._connections = [ws]
mock_agent = MagicMock()
mock_agent.run = MagicMock(
side_effect=[
_mock_run("1. Only step"),
_mock_run("Step result text"),
]
)
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent),
patch("infrastructure.ws_manager.handler.ws_manager", mgr),
):
await run_agentic_loop("Single step", max_steps=1)
step_event = _collected_events(ws)[1]
assert step_event["event"] == "agentic.step_complete"
data = step_event["data"]
assert data["step"] == 1
assert data["total"] == 1
assert data["description"] == "Only step"
assert "Step result text" in data["result"]
@pytest.mark.asyncio
async def test_task_complete_payload(self):
"""task_complete has status, steps_completed, summary, and duration_ms."""
mgr = WebSocketManager()
ws = _ws_client()
mgr._connections = [ws]
mock_agent = MagicMock()
mock_agent.run = MagicMock(
side_effect=[
_mock_run("1. Do it"),
_mock_run("Done"),
]
)
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent),
patch("infrastructure.ws_manager.handler.ws_manager", mgr),
):
await run_agentic_loop("Quick", max_steps=1)
complete_event = _collected_events(ws)[-1]
assert complete_event["event"] == "agentic.task_complete"
data = complete_event["data"]
assert data["status"] == "completed"
assert data["steps_completed"] == 1
assert isinstance(data["duration_ms"], int)
assert data["duration_ms"] >= 0
assert data["summary"]
class TestAdaptationBroadcast:
"""Adapted steps emit step_adapted events."""
@pytest.mark.asyncio
async def test_adapted_step_broadcasts_step_adapted(self):
"""A failed-then-adapted step emits agentic.step_adapted."""
mgr = WebSocketManager()
ws = _ws_client()
mgr._connections = [ws]
mock_agent = MagicMock()
mock_agent.run = MagicMock(
side_effect=[
_mock_run("1. Risky step"),
Exception("disk full"),
_mock_run("Used /tmp instead"),
]
)
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent),
patch("infrastructure.ws_manager.handler.ws_manager", mgr),
):
result = await run_agentic_loop("Adapt test", max_steps=1)
events = _collected_events(ws)
event_names = [e["event"] for e in events]
assert "agentic.step_adapted" in event_names
adapted = next(e for e in events if e["event"] == "agentic.step_adapted")
assert adapted["data"]["error"] == "disk full"
assert adapted["data"]["adaptation"]
assert result.steps[0].status == "adapted"
class TestMultipleClients:
"""All connected clients receive every broadcast."""
@pytest.mark.asyncio
async def test_two_clients_receive_all_events(self):
mgr = WebSocketManager()
ws1 = _ws_client()
ws2 = _ws_client()
mgr._connections = [ws1, ws2]
mock_agent = MagicMock()
mock_agent.run = MagicMock(
side_effect=[
_mock_run("1. Step A"),
_mock_run("A done"),
]
)
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent),
patch("infrastructure.ws_manager.handler.ws_manager", mgr),
):
await run_agentic_loop("Multi-client", max_steps=1)
events1 = _collected_events(ws1)
events2 = _collected_events(ws2)
assert len(events1) == len(events2) == 3 # plan + step + complete
assert [e["event"] for e in events1] == [e["event"] for e in events2]
class TestEventHistory:
"""Broadcasts are recorded in ws_manager event history."""
@pytest.mark.asyncio
async def test_events_appear_in_history(self):
mgr = WebSocketManager()
mock_agent = MagicMock()
mock_agent.run = MagicMock(
side_effect=[
_mock_run("1. Only"),
_mock_run("Done"),
]
)
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent),
patch("infrastructure.ws_manager.handler.ws_manager", mgr),
):
await run_agentic_loop("History test", max_steps=1)
history_events = [e.event for e in mgr.event_history]
assert "agentic.plan_ready" in history_events
assert "agentic.step_complete" in history_events
assert "agentic.task_complete" in history_events
class TestBroadcastGracefulDegradation:
"""Loop completes even when ws_manager is unavailable."""
@pytest.mark.asyncio
async def test_loop_succeeds_when_broadcast_fails(self):
"""ImportError from ws_manager doesn't crash the loop."""
mock_agent = MagicMock()
mock_agent.run = MagicMock(
side_effect=[
_mock_run("1. Do it"),
_mock_run("Done"),
]
)
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=mock_agent),
patch(
"infrastructure.ws_manager.handler.ws_manager",
new_callable=lambda: MagicMock,
) as broken_mgr,
):
broken_mgr.broadcast = AsyncMock(side_effect=RuntimeError("ws down"))
result = await run_agentic_loop("Resilient task", max_steps=1)
assert result.status == "completed"
assert len(result.steps) == 1

View File

@@ -81,7 +81,6 @@ def test_create_timmy_respects_custom_ollama_url():
mock_settings.ollama_url = custom_url
mock_settings.ollama_num_ctx = 4096
mock_settings.timmy_model_backend = "ollama"
mock_settings.airllm_model_size = "70b"
from timmy.agent import create_timmy
@@ -159,7 +158,6 @@ def test_resolve_backend_auto_uses_airllm_on_apple_silicon():
patch("timmy.agent.settings") as mock_settings,
):
mock_settings.timmy_model_backend = "auto"
mock_settings.airllm_model_size = "70b"
mock_settings.ollama_model = "llama3.2"
from timmy.agent import _resolve_backend
@@ -174,7 +172,6 @@ def test_resolve_backend_auto_falls_back_on_non_apple():
patch("timmy.agent.settings") as mock_settings,
):
mock_settings.timmy_model_backend = "auto"
mock_settings.airllm_model_size = "70b"
mock_settings.ollama_model = "llama3.2"
from timmy.agent import _resolve_backend
@@ -259,7 +256,6 @@ def test_create_timmy_includes_tools_for_large_model():
mock_settings.ollama_url = "http://localhost:11434"
mock_settings.ollama_num_ctx = 4096
mock_settings.timmy_model_backend = "ollama"
mock_settings.airllm_model_size = "70b"
mock_settings.telemetry_enabled = False
from timmy.agent import create_timmy

View File

@@ -0,0 +1,386 @@
"""Tests for timmy.agentic_loop — multi-step task execution engine."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from timmy.agentic_loop import (
AgenticResult,
AgenticStep,
_parse_steps,
)
# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------
class TestAgenticStep:
"""Unit tests for the AgenticStep dataclass."""
def test_creation(self):
step = AgenticStep(
step_num=1,
description="Do thing",
result="Done",
status="completed",
duration_ms=42,
)
assert step.step_num == 1
assert step.description == "Do thing"
assert step.result == "Done"
assert step.status == "completed"
assert step.duration_ms == 42
def test_failed_status(self):
step = AgenticStep(
step_num=2, description="Bad step", result="Error", status="failed", duration_ms=10
)
assert step.status == "failed"
def test_adapted_status(self):
step = AgenticStep(
step_num=3, description="Retried", result="OK", status="adapted", duration_ms=100
)
assert step.status == "adapted"
class TestAgenticResult:
"""Unit tests for the AgenticResult dataclass."""
def test_defaults(self):
result = AgenticResult(task_id="abc", task="Test", summary="Done")
assert result.steps == []
assert result.status == "completed"
assert result.total_duration_ms == 0
def test_with_steps(self):
s = AgenticStep(step_num=1, description="A", result="B", status="completed", duration_ms=5)
result = AgenticResult(task_id="x", task="T", summary="S", steps=[s])
assert len(result.steps) == 1
# ---------------------------------------------------------------------------
# _parse_steps — pure function, highly testable
# ---------------------------------------------------------------------------
class TestParseSteps:
"""Unit tests for the plan parser."""
def test_numbered_with_dots(self):
text = "1. First step\n2. Second step\n3. Third step"
steps = _parse_steps(text)
assert steps == ["First step", "Second step", "Third step"]
def test_numbered_with_parens(self):
text = "1) Do this\n2) Do that"
steps = _parse_steps(text)
assert steps == ["Do this", "Do that"]
def test_mixed_numbering(self):
text = "1. Step one\n2) Step two\n3. Step three"
steps = _parse_steps(text)
assert len(steps) == 3
def test_indented_steps(self):
text = " 1. Indented step\n 2. Also indented"
steps = _parse_steps(text)
assert len(steps) == 2
assert steps[0] == "Indented step"
def test_no_numbered_steps_fallback(self):
text = "Do this first\nThen do that\nFinally wrap up"
steps = _parse_steps(text)
assert len(steps) == 3
assert steps[0] == "Do this first"
def test_empty_string(self):
steps = _parse_steps("")
assert steps == []
def test_blank_lines_ignored_in_fallback(self):
text = "Step A\n\n\nStep B\n"
steps = _parse_steps(text)
assert steps == ["Step A", "Step B"]
def test_strips_whitespace(self):
text = "1. Lots of space \n2. Also spaced "
steps = _parse_steps(text)
assert steps[0] == "Lots of space"
assert steps[1] == "Also spaced"
def test_preamble_ignored_when_numbered(self):
text = "Here is the plan:\n1. Step one\n2. Step two"
steps = _parse_steps(text)
assert steps == ["Step one", "Step two"]
# ---------------------------------------------------------------------------
# _get_loop_agent — singleton pattern
# ---------------------------------------------------------------------------
class TestGetLoopAgent:
"""Tests for the agent singleton."""
def test_creates_agent_once(self):
import timmy.agentic_loop as mod
mod._loop_agent = None
mock_agent = MagicMock()
with patch("timmy.agent.create_timmy", return_value=mock_agent) as mock_create:
agent = mod._get_loop_agent()
assert agent is mock_agent
mock_create.assert_called_once()
# Second call should reuse singleton
agent2 = mod._get_loop_agent()
assert agent2 is mock_agent
mock_create.assert_called_once()
mod._loop_agent = None # cleanup
def test_reuses_existing(self):
import timmy.agentic_loop as mod
sentinel = MagicMock()
mod._loop_agent = sentinel
assert mod._get_loop_agent() is sentinel
mod._loop_agent = None # cleanup
# ---------------------------------------------------------------------------
# _broadcast_progress — best-effort WebSocket broadcast
# ---------------------------------------------------------------------------
class TestBroadcastProgress:
"""Tests for the WebSocket broadcast helper."""
@pytest.mark.asyncio
async def test_successful_broadcast(self):
from timmy.agentic_loop import _broadcast_progress
mock_ws = MagicMock()
mock_ws.broadcast = AsyncMock()
mock_module = MagicMock()
mock_module.ws_manager = mock_ws
with patch.dict("sys.modules", {"infrastructure.ws_manager.handler": mock_module}):
await _broadcast_progress("test.event", {"key": "value"})
mock_ws.broadcast.assert_awaited_once_with("test.event", {"key": "value"})
@pytest.mark.asyncio
async def test_import_error_swallowed(self):
"""When ws_manager import fails, broadcast silently succeeds."""
import sys
from timmy.agentic_loop import _broadcast_progress
# Remove the module so import fails
saved = sys.modules.pop("infrastructure.ws_manager.handler", None)
try:
with patch.dict("sys.modules", {"infrastructure": None}):
# Should not raise — errors are swallowed
await _broadcast_progress("fail.event", {})
finally:
if saved is not None:
sys.modules["infrastructure.ws_manager.handler"] = saved
# ---------------------------------------------------------------------------
# run_agentic_loop — integration-style tests with mocked agent
# ---------------------------------------------------------------------------
class TestRunAgenticLoop:
"""Tests for the main agentic loop."""
@pytest.fixture(autouse=True)
def _reset_agent(self):
import timmy.agentic_loop as mod
mod._loop_agent = None
yield
mod._loop_agent = None
def _mock_agent(self, responses):
"""Create a mock agent that returns responses in sequence."""
agent = MagicMock()
run_results = []
for r in responses:
mock_result = MagicMock()
mock_result.content = r
run_results.append(mock_result)
agent.run = MagicMock(side_effect=run_results)
return agent
@pytest.mark.asyncio
async def test_successful_two_step_task(self):
from timmy.agentic_loop import run_agentic_loop
agent = self._mock_agent(
[
"1. Step one\n2. Step two", # planning
"Step one done", # execution step 1
"Step two done", # execution step 2
]
)
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=agent),
patch("timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock),
patch("timmy.session._clean_response", side_effect=lambda x: x),
):
result = await run_agentic_loop("Test task", max_steps=5)
assert result.status == "completed"
assert len(result.steps) == 2
assert result.steps[0].status == "completed"
assert result.steps[1].status == "completed"
assert result.total_duration_ms >= 0
@pytest.mark.asyncio
async def test_planning_failure(self):
from timmy.agentic_loop import run_agentic_loop
agent = MagicMock()
agent.run = MagicMock(side_effect=RuntimeError("LLM down"))
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=agent),
patch("timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock),
):
result = await run_agentic_loop("Broken task", max_steps=3)
assert result.status == "failed"
assert "Planning failed" in result.summary
@pytest.mark.asyncio
async def test_empty_plan(self):
from timmy.agentic_loop import run_agentic_loop
agent = self._mock_agent([""]) # empty plan
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=agent),
patch("timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock),
):
result = await run_agentic_loop("Empty plan task", max_steps=3)
assert result.status == "failed"
assert "no steps" in result.summary.lower()
@pytest.mark.asyncio
async def test_step_failure_triggers_adaptation(self):
from timmy.agentic_loop import run_agentic_loop
agent = MagicMock()
call_count = 0
def mock_run(prompt, **kwargs):
nonlocal call_count
call_count += 1
result = MagicMock()
if call_count == 1:
result.content = "1. Only step"
elif call_count == 2:
raise RuntimeError("Step failed")
else:
result.content = "Adapted successfully"
return result
agent.run = mock_run
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=agent),
patch("timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock),
patch("timmy.session._clean_response", side_effect=lambda x: x),
):
result = await run_agentic_loop("Failing task", max_steps=5)
assert len(result.steps) == 1
assert result.steps[0].status == "adapted"
assert "[Adapted]" in result.steps[0].description
@pytest.mark.asyncio
async def test_max_steps_truncation(self):
from timmy.agentic_loop import run_agentic_loop
agent = self._mock_agent(
[
"1. A\n2. B\n3. C\n4. D\n5. E", # 5 steps planned
"Done A",
"Done B",
]
)
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=agent),
patch("timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock),
patch("timmy.session._clean_response", side_effect=lambda x: x),
):
result = await run_agentic_loop("Big task", max_steps=2)
assert result.status == "partial" # was truncated
assert len(result.steps) == 2
@pytest.mark.asyncio
async def test_on_progress_callback(self):
from timmy.agentic_loop import run_agentic_loop
agent = self._mock_agent(
[
"1. Only step",
"Step done",
]
)
progress_calls = []
async def track_progress(desc, step_num, total):
progress_calls.append((desc, step_num, total))
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=agent),
patch("timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock),
patch("timmy.session._clean_response", side_effect=lambda x: x),
):
await run_agentic_loop("Callback task", max_steps=5, on_progress=track_progress)
assert len(progress_calls) == 1
assert progress_calls[0][1] == 1 # step_num
@pytest.mark.asyncio
async def test_default_max_steps_from_settings(self):
from timmy.agentic_loop import run_agentic_loop
agent = self._mock_agent(["1. Step one", "Done"])
mock_settings = MagicMock()
mock_settings.max_agent_steps = 7
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=agent),
patch("timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock),
patch("timmy.session._clean_response", side_effect=lambda x: x),
patch("config.settings", mock_settings),
):
result = await run_agentic_loop("Settings task")
assert result.status == "completed"
@pytest.mark.asyncio
async def test_task_id_generated(self):
from timmy.agentic_loop import run_agentic_loop
agent = self._mock_agent(["1. Step", "OK"])
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=agent),
patch("timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock),
patch("timmy.session._clean_response", side_effect=lambda x: x),
):
result = await run_agentic_loop("ID task", max_steps=5)
assert result.task_id # non-empty
assert len(result.task_id) == 8 # uuid[:8]

View File

@@ -0,0 +1,319 @@
"""Unit tests for timmy.agentic_loop — agentic loop data structures, parsing, and execution."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from timmy.agentic_loop import (
AgenticResult,
AgenticStep,
_broadcast_progress,
_parse_steps,
run_agentic_loop,
)
# ── Data structures ──────────────────────────────────────────────────────────
class TestAgenticStep:
def test_fields(self):
step = AgenticStep(
step_num=1,
description="Do something",
result="Done",
status="completed",
duration_ms=42,
)
assert step.step_num == 1
assert step.description == "Do something"
assert step.result == "Done"
assert step.status == "completed"
assert step.duration_ms == 42
class TestAgenticResult:
def test_defaults(self):
r = AgenticResult(task_id="abc", task="test task", summary="ok")
assert r.steps == []
assert r.status == "completed"
assert r.total_duration_ms == 0
def test_with_steps(self):
step = AgenticStep(1, "s", "r", "completed", 10)
r = AgenticResult(task_id="x", task="t", summary="s", steps=[step])
assert len(r.steps) == 1
# ── _parse_steps ─────────────────────────────────────────────────────────────
class TestParseSteps:
def test_numbered_dot(self):
text = "1. First step\n2. Second step\n3. Third step"
assert _parse_steps(text) == ["First step", "Second step", "Third step"]
def test_numbered_paren(self):
text = "1) Alpha\n2) Beta"
assert _parse_steps(text) == ["Alpha", "Beta"]
def test_mixed_whitespace(self):
text = " 1. Indented step\n 2. Another "
result = _parse_steps(text)
assert result == ["Indented step", "Another"]
def test_fallback_plain_lines(self):
text = "Do this\nDo that\nDo the other"
assert _parse_steps(text) == ["Do this", "Do that", "Do the other"]
def test_empty_string(self):
assert _parse_steps("") == []
def test_blank_lines_skipped_in_fallback(self):
text = "line one\n\nline two\n \nline three"
assert _parse_steps(text) == ["line one", "line two", "line three"]
# ── _get_loop_agent ──────────────────────────────────────────────────────────
class TestGetLoopAgent:
def test_creates_agent_once(self):
import timmy.agentic_loop as al
saved = al._loop_agent
try:
al._loop_agent = None
mock_agent = MagicMock()
with patch("timmy.agent.create_timmy", return_value=mock_agent):
result = al._get_loop_agent()
assert result is mock_agent
# Second call returns cached
result2 = al._get_loop_agent()
assert result2 is mock_agent
finally:
al._loop_agent = saved
def test_returns_cached(self):
import timmy.agentic_loop as al
saved = al._loop_agent
try:
sentinel = object()
al._loop_agent = sentinel
assert al._get_loop_agent() is sentinel
finally:
al._loop_agent = saved
# ── _broadcast_progress ──────────────────────────────────────────────────────
class TestBroadcastProgress:
@pytest.mark.asyncio
async def test_success(self):
mock_ws = AsyncMock()
with (
patch("timmy.agentic_loop.ws_manager", mock_ws, create=True),
patch.dict(
"sys.modules",
{"infrastructure.ws_manager.handler": MagicMock(ws_manager=mock_ws)},
),
):
await _broadcast_progress("test.event", {"key": "val"})
mock_ws.broadcast.assert_awaited_once_with("test.event", {"key": "val"})
@pytest.mark.asyncio
async def test_import_error_swallowed(self):
with patch.dict("sys.modules", {"infrastructure.ws_manager.handler": None}):
# Should not raise
await _broadcast_progress("test.event", {})
# ── run_agentic_loop ─────────────────────────────────────────────────────────
def _make_mock_agent(plan_text, step_responses=None):
"""Create a mock agent whose .run returns predictable content."""
call_count = 0
def run_side_effect(prompt, *, stream=False, session_id=""):
nonlocal call_count
call_count += 1
resp = MagicMock()
if call_count == 1:
# Planning call
resp.content = plan_text
else:
idx = call_count - 2 # step index (0-based)
if step_responses and idx < len(step_responses):
val = step_responses[idx]
if isinstance(val, Exception):
raise val
resp.content = val
else:
resp.content = f"Step result {call_count}"
return resp
agent = MagicMock()
agent.run = MagicMock(side_effect=run_side_effect)
return agent
@pytest.fixture
def _patch_broadcast():
with patch("timmy.agentic_loop._broadcast_progress", new_callable=AsyncMock):
yield
@pytest.fixture
def _patch_clean_response():
with patch("timmy.session._clean_response", side_effect=lambda x: x):
yield
class TestRunAgenticLoop:
@pytest.mark.asyncio
async def test_successful_execution(self, _patch_broadcast, _patch_clean_response):
agent = _make_mock_agent("1. Step A\n2. Step B", ["Result A", "Result B"])
mock_settings = MagicMock()
mock_settings.max_agent_steps = 10
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=agent),
patch("timmy.agentic_loop.settings", mock_settings, create=True),
patch.dict("sys.modules", {"config": MagicMock(settings=mock_settings)}),
):
result = await run_agentic_loop("do stuff", max_steps=5)
assert result.status == "completed"
assert len(result.steps) == 2
assert result.steps[0].status == "completed"
assert result.steps[0].description == "Step A"
assert result.total_duration_ms >= 0
@pytest.mark.asyncio
async def test_planning_failure(self, _patch_broadcast):
agent = MagicMock()
agent.run = MagicMock(side_effect=RuntimeError("LLM down"))
mock_settings = MagicMock()
mock_settings.max_agent_steps = 5
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=agent),
patch.dict("sys.modules", {"config": MagicMock(settings=mock_settings)}),
):
result = await run_agentic_loop("do stuff", max_steps=3)
assert result.status == "failed"
assert "Planning failed" in result.summary
@pytest.mark.asyncio
async def test_empty_plan(self, _patch_broadcast):
agent = _make_mock_agent("")
mock_settings = MagicMock()
mock_settings.max_agent_steps = 5
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=agent),
patch.dict("sys.modules", {"config": MagicMock(settings=mock_settings)}),
):
result = await run_agentic_loop("do stuff", max_steps=3)
assert result.status == "failed"
assert "no steps" in result.summary.lower()
@pytest.mark.asyncio
async def test_step_failure_triggers_adaptation(self, _patch_broadcast, _patch_clean_response):
agent = _make_mock_agent(
"1. Do X\n2. Do Y",
[RuntimeError("oops"), "Adapted result", "Y done"],
)
mock_settings = MagicMock()
mock_settings.max_agent_steps = 10
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=agent),
patch.dict("sys.modules", {"config": MagicMock(settings=mock_settings)}),
):
result = await run_agentic_loop("do stuff", max_steps=5)
# Step 1 should be adapted, step 2 completed
statuses = [s.status for s in result.steps]
assert "adapted" in statuses
@pytest.mark.asyncio
async def test_truncation_marks_partial(self, _patch_broadcast, _patch_clean_response):
agent = _make_mock_agent(
"1. A\n2. B\n3. C\n4. D\n5. E",
["r1", "r2"],
)
mock_settings = MagicMock()
mock_settings.max_agent_steps = 10
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=agent),
patch.dict("sys.modules", {"config": MagicMock(settings=mock_settings)}),
):
result = await run_agentic_loop("do stuff", max_steps=2)
assert result.status == "partial"
@pytest.mark.asyncio
async def test_on_progress_callback(self, _patch_broadcast, _patch_clean_response):
agent = _make_mock_agent("1. Only step", ["done"])
mock_settings = MagicMock()
mock_settings.max_agent_steps = 10
callback = AsyncMock()
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=agent),
patch.dict("sys.modules", {"config": MagicMock(settings=mock_settings)}),
):
result = await run_agentic_loop("do stuff", max_steps=5, on_progress=callback)
callback.assert_awaited_once_with("Only step", 1, 1)
assert result.status == "completed"
@pytest.mark.asyncio
async def test_default_max_steps_from_settings(self, _patch_broadcast, _patch_clean_response):
agent = _make_mock_agent("1. S1", ["r1"])
mock_settings = MagicMock()
mock_settings.max_agent_steps = 3
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=agent),
patch.dict("sys.modules", {"config": MagicMock(settings=mock_settings)}),
):
result = await run_agentic_loop("do stuff") # max_steps=0 → from settings
assert result.status == "completed"
@pytest.mark.asyncio
async def test_failed_step_and_failed_adaptation(self, _patch_broadcast, _patch_clean_response):
"""When both step and adaptation fail, step is marked failed."""
call_count = 0
def run_side_effect(prompt, *, stream=False, session_id=""):
nonlocal call_count
call_count += 1
if call_count == 1:
resp = MagicMock()
resp.content = "1. Only step"
return resp
# Both step execution and adaptation fail
raise RuntimeError("everything broken")
agent = MagicMock()
agent.run = MagicMock(side_effect=run_side_effect)
mock_settings = MagicMock()
mock_settings.max_agent_steps = 10
with (
patch("timmy.agentic_loop._get_loop_agent", return_value=agent),
patch.dict("sys.modules", {"config": MagicMock(settings=mock_settings)}),
):
result = await run_agentic_loop("do stuff", max_steps=5)
assert result.steps[0].status == "failed"
assert "Failed" in result.steps[0].result
assert result.status == "partial"

View File

@@ -13,44 +13,46 @@ from timmy.memory.embeddings import (
embed_text,
)
# ── _simple_hash_embedding ──────────────────────────────────────────────────
# ── _simple_hash_embedding ──────────────────────────────────────────────────
class TestSimpleHashEmbedding:
"""Tests for the deterministic hash-based fallback embedding."""
def test_returns_128_floats(self):
def test_returns_128_dim_vector(self):
vec = _simple_hash_embedding("hello world")
assert len(vec) == 128
assert all(isinstance(x, float) for x in vec)
def test_deterministic(self):
"""Same input always produces the same vector."""
assert _simple_hash_embedding("test") == _simple_hash_embedding("test")
def test_normalized(self):
"""Output vector has unit magnitude."""
vec = _simple_hash_embedding("some text for testing")
vec = _simple_hash_embedding("some text for embedding")
mag = math.sqrt(sum(x * x for x in vec))
assert mag == pytest.approx(1.0, abs=1e-6)
def test_deterministic(self):
a = _simple_hash_embedding("same input")
b = _simple_hash_embedding("same input")
assert a == b
def test_different_texts_differ(self):
a = _simple_hash_embedding("hello world")
b = _simple_hash_embedding("goodbye moon")
assert a != b
def test_empty_string(self):
"""Empty string doesn't crash — returns a zero-ish vector."""
vec = _simple_hash_embedding("")
assert len(vec) == 128
# All zeros normalised stays zero (mag fallback to 1.0)
assert all(x == 0.0 for x in vec)
def test_different_inputs_differ(self):
a = _simple_hash_embedding("alpha")
b = _simple_hash_embedding("beta")
assert a != b
def test_long_text_truncates_at_50_words(self):
"""Words beyond 50 should not change the result."""
short = " ".join(f"word{i}" for i in range(50))
long = short + " extra1 extra2 extra3"
assert _simple_hash_embedding(short) == _simple_hash_embedding(long)
# ── cosine_similarity ────────────────────────────────────────────────────────
class TestCosineSimilarity:
"""Tests for cosine similarity calculation."""
def test_identical_vectors(self):
v = [1.0, 2.0, 3.0]
assert cosine_similarity(v, v) == pytest.approx(1.0)
@@ -69,86 +71,85 @@ class TestCosineSimilarity:
assert cosine_similarity([0.0, 0.0], [1.0, 2.0]) == 0.0
assert cosine_similarity([1.0, 2.0], [0.0, 0.0]) == 0.0
def test_different_length_uses_zip(self):
"""zip(strict=False) truncates to shortest — verify no crash."""
result = cosine_similarity([1.0, 0.0], [1.0, 0.0, 9.9])
# mag_b includes the extra element, so result < 1.0
assert isinstance(result, float)
def test_both_zero_vectors(self):
assert cosine_similarity([0.0], [0.0]) == 0.0
# ── _keyword_overlap ─────────────────────────────────────────────────────────
class TestKeywordOverlap:
"""Tests for keyword overlap scoring."""
def test_full_overlap(self):
assert _keyword_overlap("hello world", "hello world extra") == pytest.approx(1.0)
assert _keyword_overlap("hello world", "hello world") == pytest.approx(1.0)
def test_partial_overlap(self):
assert _keyword_overlap("hello world", "hello there") == pytest.approx(0.5)
assert _keyword_overlap("hello world", "hello moon") == pytest.approx(0.5)
def test_no_overlap(self):
assert _keyword_overlap("alpha", "beta") == pytest.approx(0.0)
assert _keyword_overlap("hello", "goodbye") == pytest.approx(0.0)
def test_empty_query(self):
assert _keyword_overlap("", "some content") == 0.0
assert _keyword_overlap("", "anything") == 0.0
def test_case_insensitive(self):
assert _keyword_overlap("Hello", "hello world") == pytest.approx(1.0)
assert _keyword_overlap("Hello World", "hello world") == pytest.approx(1.0)
# ── embed_text ───────────────────────────────────────────────────────────────
class TestEmbedText:
"""Tests for the main embed_text entry point."""
def test_uses_fallback_when_model_false(self):
"""When _get_embedding_model returns False, use hash fallback."""
with patch.object(emb, "EMBEDDING_MODEL", False):
with patch.object(emb, "_get_embedding_model", return_value=False):
vec = embed_text("hello")
assert len(vec) == 128
def test_uses_fallback_when_model_disabled(self):
with patch.object(emb, "_get_embedding_model", return_value=False):
vec = embed_text("test")
assert len(vec) == 128 # hash fallback dimension
def test_uses_model_when_available(self):
"""When a real model is loaded, call model.encode()."""
mock_encoding = MagicMock()
mock_encoding.tolist.return_value = [0.1, 0.2, 0.3]
mock_model = MagicMock()
mock_model.encode.return_value = MagicMock(tolist=MagicMock(return_value=[0.1, 0.2]))
mock_model.encode.return_value = mock_encoding
with patch.object(emb, "_get_embedding_model", return_value=mock_model):
vec = embed_text("hello")
assert vec == [0.1, 0.2]
mock_model.encode.assert_called_once_with("hello")
result = embed_text("test")
assert result == pytest.approx([0.1, 0.2, 0.3])
mock_model.encode.assert_called_once_with("test")
# ── _get_embedding_model ─────────────────────────────────────────────────────
class TestGetEmbeddingModel:
"""Tests for lazy model loading."""
def setup_method(self):
"""Reset global state before each test."""
self._saved_model = emb.EMBEDDING_MODEL
emb.EMBEDDING_MODEL = None
def teardown_method(self):
emb.EMBEDDING_MODEL = self._saved_model
def test_skip_embeddings_setting(self):
"""When settings.timmy_skip_embeddings is True, model is set to False."""
mock_settings = MagicMock()
mock_settings.timmy_skip_embeddings = True
with patch.dict("sys.modules", {"config": MagicMock(settings=mock_settings)}):
emb.EMBEDDING_MODEL = None
result = emb._get_embedding_model()
assert result is False
assert result is False
def test_sentence_transformers_import_error(self):
"""When sentence-transformers is missing, falls back to False."""
def test_fallback_when_transformers_missing(self):
mock_settings = MagicMock()
mock_settings.timmy_skip_embeddings = False
with patch.dict("sys.modules", {"config": MagicMock(settings=mock_settings)}):
with patch.dict("sys.modules", {"sentence_transformers": None}):
emb.EMBEDDING_MODEL = None
result = emb._get_embedding_model()
assert result is False
with patch.dict(
"sys.modules",
{
"config": MagicMock(settings=mock_settings),
"sentence_transformers": None,
},
):
emb.EMBEDDING_MODEL = None
result = emb._get_embedding_model()
assert result is False
def teardown_method(self):
emb.EMBEDDING_MODEL = None
def test_returns_cached_model(self):
sentinel = object()
emb.EMBEDDING_MODEL = sentinel
assert emb._get_embedding_model() is sentinel