fix: guard config.get() against YAML null values to prevent AttributeError (#3377)

dict.get(key, default) returns None — not the default — when the key IS
present but explicitly set to null/~ in YAML.  Calling .lower() on that
raises AttributeError.

Use (config.get(key) or fallback) so both missing keys and explicit nulls
coalesce to the intended default.

Files fixed:
- tools/tts_tool.py — _get_provider()
- tools/web_tools.py — _get_backend()
- tools/mcp_tool.py — MCPServerTask auth config
- trajectory_compressor.py — _detect_provider() and config loading

Co-authored-by: dieutx <dangtc94@gmail.com>
This commit is contained in:
Teknium
2026-03-27 04:03:00 -07:00
committed by GitHub
parent b8b1f24fd7
commit be416cdfa9
5 changed files with 116 additions and 5 deletions

View File

@@ -0,0 +1,111 @@
"""Tests for config.get() null-coalescing in tool configuration.
YAML ``null`` values (or ``~``) for a present key make ``dict.get(key, default)``
return ``None`` instead of the default — calling ``.lower()`` on that raises
``AttributeError``. These tests verify the ``or`` coalescing guards.
"""
from unittest.mock import patch
import pytest
# ── TTS tool ──────────────────────────────────────────────────────────────
class TestTTSProviderNullGuard:
"""tools/tts_tool.py — _get_provider()"""
def test_explicit_null_provider_returns_default(self):
"""YAML ``tts: {provider: null}`` should fall back to default."""
from tools.tts_tool import _get_provider, DEFAULT_PROVIDER
result = _get_provider({"provider": None})
assert result == DEFAULT_PROVIDER.lower().strip()
def test_missing_provider_returns_default(self):
"""No ``provider`` key at all should also return default."""
from tools.tts_tool import _get_provider, DEFAULT_PROVIDER
result = _get_provider({})
assert result == DEFAULT_PROVIDER.lower().strip()
def test_valid_provider_passed_through(self):
from tools.tts_tool import _get_provider
result = _get_provider({"provider": "OPENAI"})
assert result == "openai"
# ── Web tools ─────────────────────────────────────────────────────────────
class TestWebBackendNullGuard:
"""tools/web_tools.py — _get_backend()"""
@patch("tools.web_tools._load_web_config", return_value={"backend": None})
def test_explicit_null_backend_does_not_crash(self, _cfg):
"""YAML ``web: {backend: null}`` should not raise AttributeError."""
from tools.web_tools import _get_backend
# Should not raise — the exact return depends on env key fallback
result = _get_backend()
assert isinstance(result, str)
@patch("tools.web_tools._load_web_config", return_value={})
def test_missing_backend_does_not_crash(self, _cfg):
from tools.web_tools import _get_backend
result = _get_backend()
assert isinstance(result, str)
# ── MCP tool ──────────────────────────────────────────────────────────────
class TestMCPAuthNullGuard:
"""tools/mcp_tool.py — MCPServerTask.__init__() auth config line"""
def test_explicit_null_auth_does_not_crash(self):
"""YAML ``auth: null`` in MCP server config should not raise."""
# Test the expression directly — MCPServerTask.__init__ has many deps
config = {"auth": None, "timeout": 30}
auth_type = (config.get("auth") or "").lower().strip()
assert auth_type == ""
def test_missing_auth_defaults_to_empty(self):
config = {"timeout": 30}
auth_type = (config.get("auth") or "").lower().strip()
assert auth_type == ""
def test_valid_auth_passed_through(self):
config = {"auth": "OAUTH", "timeout": 30}
auth_type = (config.get("auth") or "").lower().strip()
assert auth_type == "oauth"
# ── Trajectory compressor ─────────────────────────────────────────────────
class TestTrajectoryCompressorNullGuard:
"""trajectory_compressor.py — _detect_provider() and config loading"""
def test_null_base_url_does_not_crash(self):
"""base_url=None should not crash _detect_provider()."""
from trajectory_compressor import CompressionConfig, TrajectoryCompressor
config = CompressionConfig()
config.base_url = None
compressor = TrajectoryCompressor.__new__(TrajectoryCompressor)
compressor.config = config
# Should not raise AttributeError; returns empty string (no match)
result = compressor._detect_provider()
assert result == ""
def test_config_loading_null_base_url_keeps_default(self):
"""YAML ``summarization: {base_url: null}`` should keep default."""
from trajectory_compressor import CompressionConfig
from hermes_constants import OPENROUTER_BASE_URL
config = CompressionConfig()
data = {"summarization": {"base_url": None}}
config.base_url = data["summarization"].get("base_url") or config.base_url
assert config.base_url == OPENROUTER_BASE_URL

View File

@@ -797,7 +797,7 @@ class MCPServerTask:
""" """
self._config = config self._config = config
self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT) self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT)
self._auth_type = config.get("auth", "").lower().strip() self._auth_type = (config.get("auth") or "").lower().strip()
# Set up sampling handler if enabled and SDK types are available # Set up sampling handler if enabled and SDK types are available
sampling_config = config.get("sampling", {}) sampling_config = config.get("sampling", {})

View File

@@ -102,7 +102,7 @@ def _load_tts_config() -> Dict[str, Any]:
def _get_provider(tts_config: Dict[str, Any]) -> str: def _get_provider(tts_config: Dict[str, Any]) -> str:
"""Get the configured TTS provider name.""" """Get the configured TTS provider name."""
return tts_config.get("provider", DEFAULT_PROVIDER).lower().strip() return (tts_config.get("provider") or DEFAULT_PROVIDER).lower().strip()
# =========================================================================== # ===========================================================================

View File

@@ -73,7 +73,7 @@ def _get_backend() -> str:
Falls back to whichever API key is present for users who configured Falls back to whichever API key is present for users who configured
keys manually without running setup. keys manually without running setup.
""" """
configured = _load_web_config().get("backend", "").lower().strip() configured = (_load_web_config().get("backend") or "").lower().strip()
if configured in ("parallel", "firecrawl", "tavily"): if configured in ("parallel", "firecrawl", "tavily"):
return configured return configured

View File

@@ -123,7 +123,7 @@ class CompressionConfig:
# Summarization # Summarization
if 'summarization' in data: if 'summarization' in data:
config.summarization_model = data['summarization'].get('model', config.summarization_model) config.summarization_model = data['summarization'].get('model', config.summarization_model)
config.base_url = data['summarization'].get('base_url', config.base_url) config.base_url = data['summarization'].get('base_url') or config.base_url
config.api_key_env = data['summarization'].get('api_key_env', config.api_key_env) config.api_key_env = data['summarization'].get('api_key_env', config.api_key_env)
config.temperature = data['summarization'].get('temperature', config.temperature) config.temperature = data['summarization'].get('temperature', config.temperature)
config.max_retries = data['summarization'].get('max_retries', config.max_retries) config.max_retries = data['summarization'].get('max_retries', config.max_retries)
@@ -386,7 +386,7 @@ class TrajectoryCompressor:
def _detect_provider(self) -> str: def _detect_provider(self) -> str:
"""Detect the provider name from the configured base_url.""" """Detect the provider name from the configured base_url."""
url = self.config.base_url.lower() url = (self.config.base_url or "").lower()
if "openrouter" in url: if "openrouter" in url:
return "openrouter" return "openrouter"
if "nousresearch.com" in url: if "nousresearch.com" in url: