Compare commits

...

1 Commits

Author SHA1 Message Date
Alexander Whitestone
e2c96e35e1 feat: local model fallback chain for tool calls (#746)
Some checks failed
Contributor Attribution Check / check-attribution (pull_request) Failing after 46s
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 57s
Tests / e2e (pull_request) Successful in 3m4s
Tests / test (pull_request) Failing after 1h0m20s
When cloud provider fails during tool calling (timeout, 429, 503),
fall back to local Ollama to keep the agent working.

New agent/tool_fallback.py:
- ToolFallbackHandler: manages fallback execution
- should_fallback(error): detects provider failures (429, 503,
  timeout, rate limit, quota exceeded, connection errors)
- call_with_fallback(): makes API call via local Ollama when
  primary provider fails
- FallbackEvent: records each fallback for fleet reporting
- format_report(): human-readable fallback summary
- Singleton handler via get_tool_fallback_handler()

Config via env vars:
- TOOL_FALLBACK_PROVIDER (default: ollama)
- TOOL_FALLBACK_MODEL (default: qwen2.5:7b)
- TOOL_FALLBACK_BASE_URL (default: http://localhost:11434/v1)

Tests: tests/test_tool_fallback.py

Closes #746
2026-04-14 23:29:44 -04:00
2 changed files with 319 additions and 0 deletions

245
agent/tool_fallback.py Normal file
View File

@@ -0,0 +1,245 @@
"""Tool call fallback — retry failed tool calls with local model.
When the primary provider fails during tool calling (timeout, 429, 503),
fall back to local Ollama to keep the agent working.
Usage:
from agent.tool_fallback import ToolFallbackHandler
handler = ToolFallbackHandler()
result = handler.execute_with_fallback(tool_fn, args, context)
"""
from __future__ import annotations
import json
import logging
import os
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional
logger = logging.getLogger(__name__)
# Fallback provider config
_FALLBACK_PROVIDER = os.getenv("TOOL_FALLBACK_PROVIDER", "ollama")
_FALLBACK_MODEL = os.getenv("TOOL_FALLBACK_MODEL", "qwen2.5:7b")
_FALLBACK_BASE_URL = os.getenv("TOOL_FALLBACK_BASE_URL", "http://localhost:11434/v1")
# Error patterns that trigger fallback
_FALLBACK_TRIGGERS = [
"429",
"rate limit",
"ratelimit",
"503",
"service unavailable",
"timeout",
"timed out",
"connection error",
"connection refused",
"overloaded",
"capacity",
"quota exceeded",
"insufficient",
]
@dataclass
class FallbackEvent:
"""Record of a fallback event."""
timestamp: float
tool_name: str
original_provider: str
fallback_provider: str
error: str
success: bool
duration_ms: int = 0
class ToolFallbackHandler:
"""Handles tool call fallback to local models.
Tracks fallback events and provides fallback execution.
"""
def __init__(
self,
fallback_provider: str = "",
fallback_model: str = "",
fallback_base_url: str = "",
enabled: bool = True,
):
self.fallback_provider = fallback_provider or _FALLBACK_PROVIDER
self.fallback_model = fallback_model or _FALLBACK_MODEL
self.fallback_base_url = fallback_base_url or _FALLBACK_BASE_URL
self.enabled = enabled
self._events: list[FallbackEvent] = []
self._fallback_count = 0
self._fallback_success_count = 0
@property
def events(self) -> list[FallbackEvent]:
return list(self._events)
@property
def stats(self) -> dict:
return {
"total_fallbacks": self._fallback_count,
"successful_fallbacks": self._fallback_success_count,
"fallback_rate": (
self._fallback_success_count / self._fallback_count
if self._fallback_count > 0 else 0
),
}
def should_fallback(self, error: Any) -> bool:
"""Check if an error should trigger fallback."""
if not self.enabled:
return False
error_str = str(error).lower()
return any(trigger in error_str for trigger in _FALLBACK_TRIGGERS)
def get_fallback_client(self) -> Optional[Any]:
"""Get an OpenAI client configured for the fallback provider."""
try:
from openai import OpenAI
client = OpenAI(
base_url=self.fallback_base_url,
api_key=os.getenv("OPENAI_API_KEY", "ollama"),
)
return client
except Exception as e:
logger.error("Failed to create fallback client: %s", e)
return None
def call_with_fallback(
self,
messages: list[dict],
tools: list[dict] = None,
original_provider: str = "",
tool_name: str = "unknown",
max_tokens: int = 1024,
) -> dict:
"""Make an API call with fallback to local model on failure.
Args:
messages: Conversation messages
tools: Tool definitions
original_provider: Name of the original provider
tool_name: Name of the tool being called
max_tokens: Max tokens for the response
Returns:
Dict with 'response', 'used_fallback', 'fallback_event' keys.
"""
t0 = time.monotonic()
# Try fallback client
client = self.get_fallback_client()
if not client:
return {
"response": None,
"used_fallback": False,
"error": "Fallback client unavailable",
}
try:
response = client.chat.completions.create(
model=self.fallback_model,
messages=messages,
tools=tools if tools else None,
max_tokens=max_tokens,
)
elapsed = int((time.monotonic() - t0) * 1000)
event = FallbackEvent(
timestamp=time.time(),
tool_name=tool_name,
original_provider=original_provider,
fallback_provider=self.fallback_provider,
error="",
success=True,
duration_ms=elapsed,
)
self._events.append(event)
self._fallback_count += 1
self._fallback_success_count += 1
logger.info(
"Tool fallback succeeded: %s via %s (%dms)",
tool_name, self.fallback_provider, elapsed,
)
return {
"response": response,
"used_fallback": True,
"fallback_event": event,
}
except Exception as e:
elapsed = int((time.monotonic() - t0) * 1000)
event = FallbackEvent(
timestamp=time.time(),
tool_name=tool_name,
original_provider=original_provider,
fallback_provider=self.fallback_provider,
error=str(e),
success=False,
duration_ms=elapsed,
)
self._events.append(event)
self._fallback_count += 1
logger.error(
"Tool fallback failed: %s via %s%s",
tool_name, self.fallback_provider, e,
)
return {
"response": None,
"used_fallback": True,
"fallback_event": event,
"error": str(e),
}
def format_report(self) -> str:
"""Format fallback events as a report."""
if not self._events:
return "No fallback events recorded."
lines = [
"Tool Fallback Report",
"=" * 40,
f"Total fallbacks: {self._fallback_count}",
f"Successful: {self._fallback_success_count}",
f"Failed: {self._fallback_count - self._fallback_success_count}",
"",
]
for event in self._events[-10:]:
status = "OK" if event.success else "FAIL"
lines.append(
f" [{status}] {event.tool_name} via {event.fallback_provider} "
f"({event.duration_ms}ms) — {event.original_provider}"
)
if event.error:
lines.append(f" Error: {event.error[:100]}")
return "\n".join(lines)
# Singleton handler
_handler: Optional[ToolFallbackHandler] = None
def get_tool_fallback_handler() -> ToolFallbackHandler:
"""Get or create the singleton tool fallback handler."""
global _handler
if _handler is None:
_handler = ToolFallbackHandler()
return _handler
def reset_tool_fallback_handler() -> None:
"""Reset the singleton (for testing)."""
global _handler
_handler = None

View File

@@ -0,0 +1,74 @@
"""Tests for tool call fallback handler."""
import pytest
import sys
from pathlib import Path
from unittest.mock import MagicMock, patch
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from agent.tool_fallback import ToolFallbackHandler, FallbackEvent, get_tool_fallback_handler
class TestShouldFallback:
def test_rate_limit_triggers(self):
handler = ToolFallbackHandler()
assert handler.should_fallback("429 rate limit exceeded")
assert handler.should_fallback("RateLimitError: too many requests")
def test_timeout_triggers(self):
handler = ToolFallbackHandler()
assert handler.should_fallback("Connection timed out")
assert handler.should_fallback("Request timed out after 30s")
def test_503_triggers(self):
handler = ToolFallbackHandler()
assert handler.should_fallback("503 Service Unavailable")
assert handler.should_fallback("Service unavailable")
def test_quota_triggers(self):
handler = ToolFallbackHandler()
assert handler.should_fallback("quota exceeded")
assert handler.should_fallback("insufficient credits")
def test_normal_error_no_trigger(self):
handler = ToolFallbackHandler()
assert not handler.should_fallback("Invalid API key")
assert not handler.should_fallback("Model not found")
def test_disabled_handler(self):
handler = ToolFallbackHandler(enabled=False)
assert not handler.should_fallback("429 rate limit")
class TestFallbackEvents:
def test_event_creation(self):
event = FallbackEvent(
timestamp=1234567890.0,
tool_name="terminal",
original_provider="openrouter",
fallback_provider="ollama",
error="",
success=True,
duration_ms=150,
)
assert event.tool_name == "terminal"
assert event.success
assert event.duration_ms == 150
class TestFallbackHandler:
def test_stats_initial(self):
handler = ToolFallbackHandler()
assert handler.stats["total_fallbacks"] == 0
assert handler.stats["successful_fallbacks"] == 0
def test_report_no_events(self):
handler = ToolFallbackHandler()
report = handler.format_report()
assert "No fallback events" in report
def test_singleton(self):
h1 = get_tool_fallback_handler()
h2 = get_tool_fallback_handler()
assert h1 is h2