Some checks failed
Contributor Attribution Check / check-attribution (pull_request) Failing after 46s
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 57s
Tests / e2e (pull_request) Successful in 3m4s
Tests / test (pull_request) Failing after 1h0m20s
When the cloud provider fails during tool calling (timeout, 429, 503), fall back to a local Ollama instance to keep the agent working.

New `agent/tool_fallback.py`:
- `ToolFallbackHandler`: manages fallback execution
- `should_fallback(error)`: detects provider failures (429, 503, timeout, rate limit, quota exceeded, connection errors)
- `call_with_fallback()`: makes the API call via local Ollama when the primary provider fails
- `FallbackEvent`: records each fallback for fleet reporting
- `format_report()`: human-readable fallback summary
- Singleton handler via `get_tool_fallback_handler()`

Config via env vars:
- `TOOL_FALLBACK_PROVIDER` (default: `ollama`)
- `TOOL_FALLBACK_MODEL` (default: `qwen2.5:7b`)
- `TOOL_FALLBACK_BASE_URL` (default: `http://localhost:11434/v1`)

Tests: `tests/test_tool_fallback.py`

Closes #746
75 lines
2.4 KiB
Python
75 lines
2.4 KiB
Python
"""Tests for tool call fallback handler."""
|
|
|
|
import pytest
|
|
import sys
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
|
|
from agent.tool_fallback import ToolFallbackHandler, FallbackEvent, get_tool_fallback_handler
|
|
|
|
|
|
class TestShouldFallback:
    """Check which provider error messages ``should_fallback`` accepts.

    Each test iterates over its sample messages and passes the message as the
    assertion text, so a failure names the exact input that was misclassified.
    """

    def test_rate_limit_triggers(self):
        """HTTP 429 / rate-limit errors trigger a fallback."""
        handler = ToolFallbackHandler()
        for message in (
            "429 rate limit exceeded",
            "RateLimitError: too many requests",
        ):
            assert handler.should_fallback(message), message

    def test_timeout_triggers(self):
        """Timeout errors trigger a fallback."""
        handler = ToolFallbackHandler()
        for message in (
            "Connection timed out",
            "Request timed out after 30s",
        ):
            assert handler.should_fallback(message), message

    def test_503_triggers(self):
        """HTTP 503 / service-unavailable errors trigger a fallback."""
        handler = ToolFallbackHandler()
        for message in (
            "503 Service Unavailable",
            "Service unavailable",
        ):
            assert handler.should_fallback(message), message

    def test_quota_triggers(self):
        """Quota and credit exhaustion errors trigger a fallback."""
        handler = ToolFallbackHandler()
        for message in (
            "quota exceeded",
            "insufficient credits",
        ):
            assert handler.should_fallback(message), message

    def test_normal_error_no_trigger(self):
        """Errors a local model cannot fix (bad key, unknown model) do not."""
        handler = ToolFallbackHandler()
        for message in (
            "Invalid API key",
            "Model not found",
        ):
            assert not handler.should_fallback(message), message

    def test_disabled_handler(self):
        """A disabled handler never falls back, even on a triggering error."""
        message = "429 rate limit"
        # Positive control: without it this test would also pass if the
        # message simply never triggered, proving nothing about ``enabled``.
        assert ToolFallbackHandler(enabled=True).should_fallback(message), message
        assert not ToolFallbackHandler(enabled=False).should_fallback(message), message
|
|
|
|
|
|
class TestFallbackEvents:
    """Construction of ``FallbackEvent`` records."""

    def test_event_creation(self):
        """Every constructor argument is stored on the event unchanged.

        The original test checked only three of the seven fields; iterating
        over the full keyword set also covers the timestamp, both provider
        names and the error text used in fleet reports.
        """
        fields = {
            "timestamp": 1234567890.0,
            "tool_name": "terminal",
            "original_provider": "openrouter",
            "fallback_provider": "ollama",
            "error": "",
            "success": True,
            "duration_ms": 150,
        }
        event = FallbackEvent(**fields)
        for name, expected in fields.items():
            # Assertion text names the field so a mismatch is self-explaining.
            assert getattr(event, name) == expected, name
|
|
|
|
|
|
class TestFallbackHandler:
    """Statistics, reporting and singleton access of the fallback handler."""

    def test_stats_initial(self):
        """A freshly built handler reports zero fallbacks of either kind."""
        fresh = ToolFallbackHandler()
        for counter in ("total_fallbacks", "successful_fallbacks"):
            assert fresh.stats[counter] == 0

    def test_report_no_events(self):
        """With no recorded events the report says so explicitly."""
        text = ToolFallbackHandler().format_report()
        assert "No fallback events" in text

    def test_singleton(self):
        """Repeated calls to the module accessor yield the very same object."""
        first, second = get_tool_fallback_handler(), get_tool_fallback_handler()
        assert first is second
|