fix: guard config.get() against YAML null values to prevent AttributeError (#3377)

dict.get(key, default) returns None — not the default — when the key IS
present but explicitly set to null/~ in YAML.  Calling .lower() on that
raises AttributeError.

Use (config.get(key) or fallback) so both missing keys and explicit nulls
coalesce to the intended default.

Files fixed:
- tools/tts_tool.py — _get_provider()
- tools/web_tools.py — _get_backend()
- tools/mcp_tool.py — MCPServerTask auth config
- trajectory_compressor.py — _detect_provider() and config loading

Co-authored-by: dieutx <dangtc94@gmail.com>
This commit is contained in:
Teknium
2026-03-27 04:03:00 -07:00
committed by GitHub
parent b8b1f24fd7
commit be416cdfa9
5 changed files with 116 additions and 5 deletions

View File

@@ -0,0 +1,111 @@
"""Tests for config.get() null-coalescing in tool configuration.
YAML ``null`` values (or ``~``) for a present key make ``dict.get(key, default)``
return ``None`` instead of the default — calling ``.lower()`` on that raises
``AttributeError``. These tests verify the ``or`` coalescing guards.
"""
from unittest.mock import patch
import pytest
# ── TTS tool ──────────────────────────────────────────────────────────────
class TestTTSProviderNullGuard:
"""tools/tts_tool.py — _get_provider()"""
def test_explicit_null_provider_returns_default(self):
"""YAML ``tts: {provider: null}`` should fall back to default."""
from tools.tts_tool import _get_provider, DEFAULT_PROVIDER
result = _get_provider({"provider": None})
assert result == DEFAULT_PROVIDER.lower().strip()
def test_missing_provider_returns_default(self):
"""No ``provider`` key at all should also return default."""
from tools.tts_tool import _get_provider, DEFAULT_PROVIDER
result = _get_provider({})
assert result == DEFAULT_PROVIDER.lower().strip()
def test_valid_provider_passed_through(self):
from tools.tts_tool import _get_provider
result = _get_provider({"provider": "OPENAI"})
assert result == "openai"
# ── Web tools ─────────────────────────────────────────────────────────────
class TestWebBackendNullGuard:
"""tools/web_tools.py — _get_backend()"""
@patch("tools.web_tools._load_web_config", return_value={"backend": None})
def test_explicit_null_backend_does_not_crash(self, _cfg):
"""YAML ``web: {backend: null}`` should not raise AttributeError."""
from tools.web_tools import _get_backend
# Should not raise — the exact return depends on env key fallback
result = _get_backend()
assert isinstance(result, str)
@patch("tools.web_tools._load_web_config", return_value={})
def test_missing_backend_does_not_crash(self, _cfg):
from tools.web_tools import _get_backend
result = _get_backend()
assert isinstance(result, str)
# ── MCP tool ──────────────────────────────────────────────────────────────
class TestMCPAuthNullGuard:
"""tools/mcp_tool.py — MCPServerTask.__init__() auth config line"""
def test_explicit_null_auth_does_not_crash(self):
"""YAML ``auth: null`` in MCP server config should not raise."""
# Test the expression directly — MCPServerTask.__init__ has many deps
config = {"auth": None, "timeout": 30}
auth_type = (config.get("auth") or "").lower().strip()
assert auth_type == ""
def test_missing_auth_defaults_to_empty(self):
config = {"timeout": 30}
auth_type = (config.get("auth") or "").lower().strip()
assert auth_type == ""
def test_valid_auth_passed_through(self):
config = {"auth": "OAUTH", "timeout": 30}
auth_type = (config.get("auth") or "").lower().strip()
assert auth_type == "oauth"
# ── Trajectory compressor ─────────────────────────────────────────────────
class TestTrajectoryCompressorNullGuard:
"""trajectory_compressor.py — _detect_provider() and config loading"""
def test_null_base_url_does_not_crash(self):
"""base_url=None should not crash _detect_provider()."""
from trajectory_compressor import CompressionConfig, TrajectoryCompressor
config = CompressionConfig()
config.base_url = None
compressor = TrajectoryCompressor.__new__(TrajectoryCompressor)
compressor.config = config
# Should not raise AttributeError; returns empty string (no match)
result = compressor._detect_provider()
assert result == ""
def test_config_loading_null_base_url_keeps_default(self):
"""YAML ``summarization: {base_url: null}`` should keep default."""
from trajectory_compressor import CompressionConfig
from hermes_constants import OPENROUTER_BASE_URL
config = CompressionConfig()
data = {"summarization": {"base_url": None}}
config.base_url = data["summarization"].get("base_url") or config.base_url
assert config.base_url == OPENROUTER_BASE_URL

View File

@@ -797,7 +797,7 @@ class MCPServerTask:
"""
self._config = config
self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT)
self._auth_type = config.get("auth", "").lower().strip()
self._auth_type = (config.get("auth") or "").lower().strip()
# Set up sampling handler if enabled and SDK types are available
sampling_config = config.get("sampling", {})

View File

@@ -102,7 +102,7 @@ def _load_tts_config() -> Dict[str, Any]:
def _get_provider(tts_config: Dict[str, Any]) -> str:
"""Get the configured TTS provider name."""
return tts_config.get("provider", DEFAULT_PROVIDER).lower().strip()
return (tts_config.get("provider") or DEFAULT_PROVIDER).lower().strip()
# ===========================================================================

View File

@@ -73,7 +73,7 @@ def _get_backend() -> str:
Falls back to whichever API key is present for users who configured
keys manually without running setup.
"""
configured = _load_web_config().get("backend", "").lower().strip()
configured = (_load_web_config().get("backend") or "").lower().strip()
if configured in ("parallel", "firecrawl", "tavily"):
return configured

View File

@@ -123,7 +123,7 @@ class CompressionConfig:
# Summarization
if 'summarization' in data:
config.summarization_model = data['summarization'].get('model', config.summarization_model)
config.base_url = data['summarization'].get('base_url', config.base_url)
config.base_url = data['summarization'].get('base_url') or config.base_url
config.api_key_env = data['summarization'].get('api_key_env', config.api_key_env)
config.temperature = data['summarization'].get('temperature', config.temperature)
config.max_retries = data['summarization'].get('max_retries', config.max_retries)
@@ -386,7 +386,7 @@ class TrajectoryCompressor:
def _detect_provider(self) -> str:
"""Detect the provider name from the configured base_url."""
url = self.config.base_url.lower()
url = (self.config.base_url or "").lower()
if "openrouter" in url:
return "openrouter"
if "nousresearch.com" in url: