Compare commits

..

1 Commits

Author SHA1 Message Date
Alexander Whitestone
c22cdcaa8e fix: add _validate_gateway_config tests and API_SERVER_KEY network binding warning
Some checks failed
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Contributor Attribution Check / check-attribution (pull_request) Failing after 23s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 27s
Tests / e2e (pull_request) Successful in 1m51s
Tests / test (pull_request) Failing after 37m0s
Refs #892 - Gateway config debt: missing keys and broken fallbacks

Changes:
- Add `_is_network_accessible()` helper to gateway/config.py (avoids circular
  import with gateway.platforms.base which imports from gateway.config)
- Add API_SERVER_KEY warning in `_validate_gateway_config`: when the API server
  is enabled on a network-accessible address (0.0.0.0, public IP, hostname) but
  no key is configured, log a warning at config-load time so operators see the
  issue before any adapter initialisation runs
- Add `TestValidateGatewayConfig` in tests/gateway/test_config.py covering:
  - idle_minutes <= 0 and None are corrected to 1440 (default)
  - at_hour outside 0-23 is corrected to 4 (default)
  - Boundary hours 0 and 23 are accepted unchanged
  - Empty platform token triggers a warning log
  - Disabled platform with empty token produces no warning
  - API server on 0.0.0.0 without key logs a warning
  - API server on 127.0.0.1 without key is silent (loopback is allowed)
  - API server with a key set logs no warning regardless of bind address

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-17 02:18:02 -04:00
4 changed files with 186 additions and 232 deletions

View File

@@ -1,156 +0,0 @@
"""Tool fixation detection — break repetitive tool calling loops.
Detects when the agent latches onto one tool and calls it repeatedly
without making progress. Injects a nudge prompt to break the loop.
Usage:
from agent.tool_fixation_detector import ToolFixationDetector
detector = ToolFixationDetector()
nudge = detector.record("execute_code")
if nudge:
# Inject nudge into conversation
messages.append({"role": "system", "content": nudge})
"""
from __future__ import annotations
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional
# Default thresholds
_DEFAULT_THRESHOLD = int(os.getenv("TOOL_FIXATION_THRESHOLD", "5"))
_DEFAULT_WINDOW = int(os.getenv("TOOL_FIXATION_WINDOW", "10"))
@dataclass
class FixationEvent:
"""Record of a fixation detection."""
tool_name: str
streak_length: int
threshold: int
nudge_sent: bool = False
class ToolFixationDetector:
"""Detects and breaks tool fixation loops.
Tracks the sequence of tool calls and detects when the same tool
is called N times consecutively. When detected, returns a nudge
prompt to inject into the conversation.
"""
def __init__(self, threshold: int = 0, window: int = 0):
self.threshold = threshold or _DEFAULT_THRESHOLD
self.window = window or _DEFAULT_WINDOW
self._history: List[str] = []
self._current_streak: str = ""
self._streak_count: int = 0
self._nudges_sent: int = 0
self._events: List[FixationEvent] = []
@property
def nudges_sent(self) -> int:
return self._nudges_sent
@property
def events(self) -> List[FixationEvent]:
return list(self._events)
def record(self, tool_name: str) -> Optional[str]:
"""Record a tool call and return nudge prompt if fixation detected.
Args:
tool_name: Name of the tool that was called.
Returns:
Nudge prompt string if fixation detected, None otherwise.
"""
self._history.append(tool_name)
# Trim history to window
if len(self._history) > self.window:
self._history = self._history[-self.window:]
# Update streak
if tool_name == self._current_streak:
self._streak_count += 1
else:
self._current_streak = tool_name
self._streak_count = 1
# Check for fixation
if self._streak_count >= self.threshold:
event = FixationEvent(
tool_name=tool_name,
streak_length=self._streak_count,
threshold=self.threshold,
nudge_sent=True,
)
self._events.append(event)
self._nudges_sent += 1
return self._build_nudge(tool_name, self._streak_count)
return None
def _build_nudge(self, tool_name: str, count: int) -> str:
"""Build a nudge prompt to break the fixation loop."""
return (
f"[SYSTEM: You have called `{tool_name}` {count} times in a row "
f"without switching tools. This suggests a fixation loop. "
f"Consider:\n"
f"1. Is the tool returning an error? Read the error carefully.\n"
f"2. Is there a different tool that could help?\n"
f"3. Should you ask the user for clarification?\n"
f"4. Is the task actually complete?\n"
f"Break the loop by trying a different approach.]"
)
def reset(self) -> None:
"""Reset the detector state."""
self._history.clear()
self._current_streak = ""
self._streak_count = 0
def get_streak_info(self) -> dict:
"""Get current streak information."""
return {
"current_tool": self._current_streak,
"streak_count": self._streak_count,
"threshold": self.threshold,
"at_threshold": self._streak_count >= self.threshold,
"nudges_sent": self._nudges_sent,
}
def format_report(self) -> str:
"""Format fixation events as a report."""
if not self._events:
return "No tool fixation detected."
lines = [
f"Tool Fixation Report ({len(self._events)} events)",
"=" * 40,
]
for e in self._events:
lines.append(f" {e.tool_name}: {e.streak_length} consecutive calls (threshold: {e.threshold})")
return "\n".join(lines)
# Singleton
_detector: Optional[ToolFixationDetector] = None
def get_fixation_detector() -> ToolFixationDetector:
"""Get or create the singleton detector."""
global _detector
if _detector is None:
_detector = ToolFixationDetector()
return _detector
def reset_fixation_detector() -> None:
"""Reset the singleton."""
global _detector
_detector = None

View File

@@ -8,6 +8,7 @@ Handles loading and validating configuration for:
- Delivery preferences
"""
import ipaddress
import logging
import os
import json
@@ -679,6 +680,26 @@ def load_gateway_config() -> GatewayConfig:
return config
def _is_network_accessible(host: str) -> bool:
"""Return True if *host* would expose a server beyond the loopback interface.
Duplicates the logic in ``gateway.platforms.base.is_network_accessible``
without creating a circular import (base.py imports from this module).
"""
try:
addr = ipaddress.ip_address(host)
if addr.is_loopback:
return False
# ::ffff:127.x.x.x — Python's is_loopback returns False for
# IPv4-mapped loopback; unwrap and check the underlying IPv4.
if getattr(addr, "ipv4_mapped", None) and addr.ipv4_mapped.is_loopback:
return False
return True
except ValueError:
# Hostname: assume it could be network-accessible.
return True
def _validate_gateway_config(config: "GatewayConfig") -> None:
"""Validate and sanitize a loaded GatewayConfig in place.
@@ -747,6 +768,22 @@ def _validate_gateway_config(config: "GatewayConfig") -> None:
)
pconfig.enabled = False
# Warn when the API server is enabled on a network-accessible address
# without an auth key. The adapter will refuse to start anyway, but
# surfacing this at config-load time lets operators see the problem in
# the startup log before any platform adapter initialisation runs.
api_cfg = config.platforms.get(Platform.API_SERVER)
if api_cfg and api_cfg.enabled:
key = api_cfg.extra.get("key", "")
host = api_cfg.extra.get("host", "127.0.0.1")
if not key and _is_network_accessible(host):
logger.warning(
"API Server is enabled on %s but API_SERVER_KEY is not set. "
"The adapter will refuse to start on a network-accessible address. "
"Set API_SERVER_KEY or bind to 127.0.0.1 for local-only access.",
host,
)
def _apply_env_overrides(config: GatewayConfig) -> None:
"""Apply environment variable overrides to config."""

View File

@@ -10,6 +10,7 @@ from gateway.config import (
PlatformConfig,
SessionResetPolicy,
_apply_env_overrides,
_validate_gateway_config,
load_gateway_config,
)
@@ -294,3 +295,151 @@ class TestHomeChannelEnvOverrides:
home = config.platforms[platform].home_channel
assert home is not None, f"{platform.value}: home_channel should not be None"
assert (home.chat_id, home.name) == expected, platform.value
class TestValidateGatewayConfig:
"""Tests for _validate_gateway_config — in-place sanitisation of loaded config."""
# -- idle_minutes validation --
def test_idle_minutes_zero_is_corrected_to_default(self):
config = GatewayConfig()
config.default_reset_policy.idle_minutes = 0
_validate_gateway_config(config)
assert config.default_reset_policy.idle_minutes == 1440
def test_idle_minutes_negative_is_corrected_to_default(self):
config = GatewayConfig()
config.default_reset_policy.idle_minutes = -60
_validate_gateway_config(config)
assert config.default_reset_policy.idle_minutes == 1440
def test_idle_minutes_none_is_corrected_to_default(self):
config = GatewayConfig()
config.default_reset_policy.idle_minutes = None # type: ignore[assignment]
_validate_gateway_config(config)
assert config.default_reset_policy.idle_minutes == 1440
def test_valid_idle_minutes_is_unchanged(self):
config = GatewayConfig()
config.default_reset_policy.idle_minutes = 90
_validate_gateway_config(config)
assert config.default_reset_policy.idle_minutes == 90
# -- at_hour validation --
def test_at_hour_too_high_is_corrected_to_default(self):
config = GatewayConfig()
config.default_reset_policy.at_hour = 24
_validate_gateway_config(config)
assert config.default_reset_policy.at_hour == 4
def test_at_hour_negative_is_corrected_to_default(self):
config = GatewayConfig()
config.default_reset_policy.at_hour = -1
_validate_gateway_config(config)
assert config.default_reset_policy.at_hour == 4
def test_valid_at_hour_is_unchanged(self):
config = GatewayConfig()
config.default_reset_policy.at_hour = 3
_validate_gateway_config(config)
assert config.default_reset_policy.at_hour == 3
def test_at_hour_boundary_values_are_valid(self):
for valid_hour in (0, 23):
config = GatewayConfig()
config.default_reset_policy.at_hour = valid_hour
_validate_gateway_config(config)
assert config.default_reset_policy.at_hour == valid_hour
# -- empty-token warning (enabled platforms) --
def test_empty_string_token_logs_warning(self, caplog):
import logging
config = GatewayConfig(
platforms={
Platform.TELEGRAM: PlatformConfig(enabled=True, token=""),
}
)
with caplog.at_level(logging.WARNING, logger="gateway.config"):
_validate_gateway_config(config)
assert any(
"TELEGRAM_BOT_TOKEN" in r.message and "empty" in r.message
for r in caplog.records
)
def test_disabled_platform_with_empty_token_no_warning(self, caplog):
import logging
config = GatewayConfig(
platforms={
Platform.TELEGRAM: PlatformConfig(enabled=False, token=""),
}
)
with caplog.at_level(logging.WARNING, logger="gateway.config"):
_validate_gateway_config(config)
assert not any("TELEGRAM_BOT_TOKEN" in r.message for r in caplog.records)
# -- API Server key / binding warnings --
def test_api_server_network_binding_without_key_logs_warning(self, caplog):
import logging
config = GatewayConfig(
platforms={
Platform.API_SERVER: PlatformConfig(
enabled=True,
extra={"host": "0.0.0.0"},
),
}
)
with caplog.at_level(logging.WARNING, logger="gateway.config"):
_validate_gateway_config(config)
assert any(
"API_SERVER_KEY" in r.message for r in caplog.records
)
def test_api_server_loopback_without_key_no_warning(self, caplog):
import logging
config = GatewayConfig(
platforms={
Platform.API_SERVER: PlatformConfig(
enabled=True,
extra={"host": "127.0.0.1"},
),
}
)
with caplog.at_level(logging.WARNING, logger="gateway.config"):
_validate_gateway_config(config)
assert not any(
"API_SERVER_KEY" in r.message for r in caplog.records
)
def test_api_server_network_binding_with_key_no_warning(self, caplog):
import logging
config = GatewayConfig(
platforms={
Platform.API_SERVER: PlatformConfig(
enabled=True,
extra={"host": "0.0.0.0", "key": "sk-real-key-here"},
),
}
)
with caplog.at_level(logging.WARNING, logger="gateway.config"):
_validate_gateway_config(config)
assert not any(
"API_SERVER_KEY" in r.message for r in caplog.records
)
def test_api_server_default_loopback_without_key_no_warning(self, caplog):
"""API server with no explicit host defaults to 127.0.0.1 — no warning."""
import logging
config = GatewayConfig(
platforms={
Platform.API_SERVER: PlatformConfig(enabled=True),
}
)
with caplog.at_level(logging.WARNING, logger="gateway.config"):
_validate_gateway_config(config)
assert not any(
"API_SERVER_KEY" in r.message for r in caplog.records
)

View File

@@ -1,76 +0,0 @@
"""Tests for tool fixation detection."""
import pytest
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from agent.tool_fixation_detector import ToolFixationDetector, get_fixation_detector
class TestFixationDetection:
def test_no_fixation_below_threshold(self):
d = ToolFixationDetector(threshold=5)
for i in range(4):
assert d.record("execute_code") is None
def test_fixation_at_threshold(self):
d = ToolFixationDetector(threshold=3)
d.record("execute_code")
d.record("execute_code")
nudge = d.record("execute_code")
assert nudge is not None
assert "execute_code" in nudge
assert "3 times" in nudge
def test_fixation_above_threshold(self):
d = ToolFixationDetector(threshold=3)
d.record("execute_code")
d.record("execute_code")
d.record("execute_code") # threshold hit
nudge = d.record("execute_code") # still nudging
assert nudge is not None
def test_streak_resets_on_different_tool(self):
d = ToolFixationDetector(threshold=3)
d.record("execute_code")
d.record("execute_code")
d.record("terminal") # breaks streak
assert d._streak_count == 1
assert d._current_streak == "terminal"
def test_nudges_sent_counter(self):
d = ToolFixationDetector(threshold=2)
d.record("a")
d.record("a") # nudge 1
d.record("a") # nudge 2
assert d.nudges_sent == 2
def test_events_recorded(self):
d = ToolFixationDetector(threshold=2)
d.record("x")
d.record("x")
assert len(d.events) == 1
assert d.events[0].tool_name == "x"
assert d.events[0].streak_length == 2
def test_report(self):
d = ToolFixationDetector(threshold=2)
d.record("x")
d.record("x")
report = d.format_report()
assert "x" in report
def test_reset(self):
d = ToolFixationDetector(threshold=2)
d.record("x")
d.record("x")
d.reset()
assert d._streak_count == 0
assert d._current_streak == ""
def test_singleton(self):
d1 = get_fixation_detector()
d2 = get_fixation_detector()
assert d1 is d2