From be416cdfa94e17009b56e5dc292d2a426e57e91c Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Fri, 27 Mar 2026 04:03:00 -0700 Subject: [PATCH] fix: guard config.get() against YAML null values to prevent AttributeError (#3377) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit dict.get(key, default) returns None — not the default — when the key IS present but explicitly set to null/~ in YAML. Calling .lower() on that raises AttributeError. Use (config.get(key) or fallback) so both missing keys and explicit nulls coalesce to the intended default. Files fixed: - tools/tts_tool.py — _get_provider() - tools/web_tools.py — _get_backend() - tools/mcp_tool.py — MCPServerTask auth config - trajectory_compressor.py — _detect_provider() and config loading Co-authored-by: dieutx --- tests/tools/test_config_null_guard.py | 111 ++++++++++++++++++++++++++ tools/mcp_tool.py | 2 +- tools/tts_tool.py | 2 +- tools/web_tools.py | 2 +- trajectory_compressor.py | 4 +- 5 files changed, 116 insertions(+), 5 deletions(-) create mode 100644 tests/tools/test_config_null_guard.py diff --git a/tests/tools/test_config_null_guard.py b/tests/tools/test_config_null_guard.py new file mode 100644 index 00000000..a6ab6400 --- /dev/null +++ b/tests/tools/test_config_null_guard.py @@ -0,0 +1,111 @@ +"""Tests for config.get() null-coalescing in tool configuration. + +YAML ``null`` values (or ``~``) for a present key make ``dict.get(key, default)`` +return ``None`` instead of the default — calling ``.lower()`` on that raises +``AttributeError``. These tests verify the ``or`` coalescing guards. +""" + +from unittest.mock import patch +import pytest + + +# ── TTS tool ────────────────────────────────────────────────────────────── + +class TestTTSProviderNullGuard: + """tools/tts_tool.py — _get_provider()""" + + def test_explicit_null_provider_returns_default(self): + """YAML ``tts: {provider: null}`` should fall back to default.""" + from tools.tts_tool import _get_provider, DEFAULT_PROVIDER + + result = _get_provider({"provider": None}) + assert result == DEFAULT_PROVIDER.lower().strip() + + def test_missing_provider_returns_default(self): + """No ``provider`` key at all should also return default.""" + from tools.tts_tool import _get_provider, DEFAULT_PROVIDER + + result = _get_provider({}) + assert result == DEFAULT_PROVIDER.lower().strip() + + def test_valid_provider_passed_through(self): + from tools.tts_tool import _get_provider + + result = _get_provider({"provider": "OPENAI"}) + assert result == "openai" + + +# ── Web tools ───────────────────────────────────────────────────────────── + +class TestWebBackendNullGuard: + """tools/web_tools.py — _get_backend()""" + + @patch("tools.web_tools._load_web_config", return_value={"backend": None}) + def test_explicit_null_backend_does_not_crash(self, _cfg): + """YAML ``web: {backend: null}`` should not raise AttributeError.""" + from tools.web_tools import _get_backend + + # Should not raise — the exact return depends on env key fallback + result = _get_backend() + assert isinstance(result, str) + + @patch("tools.web_tools._load_web_config", return_value={}) + def test_missing_backend_does_not_crash(self, _cfg): + from tools.web_tools import _get_backend + + result = _get_backend() + assert isinstance(result, str) + + +# ── MCP tool ────────────────────────────────────────────────────────────── + +class TestMCPAuthNullGuard: + """tools/mcp_tool.py — MCPServerTask.__init__() auth config line""" + + def test_explicit_null_auth_does_not_crash(self): + """YAML ``auth: null`` in MCP server config should not raise.""" + # Test the expression directly — MCPServerTask.__init__ has many deps + config = {"auth": None, "timeout": 30} + auth_type = (config.get("auth") or "").lower().strip() + assert auth_type == "" + + def test_missing_auth_defaults_to_empty(self): + config = {"timeout": 30} + auth_type = (config.get("auth") or "").lower().strip() + assert auth_type == "" + + def test_valid_auth_passed_through(self): + config = {"auth": "OAUTH", "timeout": 30} + auth_type = (config.get("auth") or "").lower().strip() + assert auth_type == "oauth" + + +# ── Trajectory compressor ───────────────────────────────────────────────── + +class TestTrajectoryCompressorNullGuard: + """trajectory_compressor.py — _detect_provider() and config loading""" + + def test_null_base_url_does_not_crash(self): + """base_url=None should not crash _detect_provider().""" + from trajectory_compressor import CompressionConfig, TrajectoryCompressor + + config = CompressionConfig() + config.base_url = None + + compressor = TrajectoryCompressor.__new__(TrajectoryCompressor) + compressor.config = config + + # Should not raise AttributeError; returns empty string (no match) + result = compressor._detect_provider() + assert result == "" + + def test_config_loading_null_base_url_keeps_default(self): + """YAML ``summarization: {base_url: null}`` should keep default.""" + from trajectory_compressor import CompressionConfig + from hermes_constants import OPENROUTER_BASE_URL + + config = CompressionConfig() + data = {"summarization": {"base_url": None}} + + config.base_url = data["summarization"].get("base_url") or config.base_url + assert config.base_url == OPENROUTER_BASE_URL diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index f539586e..2b68ff4b 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -797,7 +797,7 @@ class MCPServerTask: """ self._config = config self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT) - self._auth_type = config.get("auth", "").lower().strip() + self._auth_type = (config.get("auth") or "").lower().strip() # Set up sampling handler if enabled and SDK types are available sampling_config = config.get("sampling", {}) diff --git a/tools/tts_tool.py b/tools/tts_tool.py index eed3961d..879634cf 100644 --- a/tools/tts_tool.py +++ b/tools/tts_tool.py @@ -102,7 +102,7 @@ def _load_tts_config() -> Dict[str, Any]: def _get_provider(tts_config: Dict[str, Any]) -> str: """Get the configured TTS provider name.""" - return tts_config.get("provider", DEFAULT_PROVIDER).lower().strip() + return (tts_config.get("provider") or DEFAULT_PROVIDER).lower().strip() # =========================================================================== diff --git a/tools/web_tools.py b/tools/web_tools.py index d4afc06a..38ad8cf0 100644 --- a/tools/web_tools.py +++ b/tools/web_tools.py @@ -73,7 +73,7 @@ def _get_backend() -> str: Falls back to whichever API key is present for users who configured keys manually without running setup. """ - configured = _load_web_config().get("backend", "").lower().strip() + configured = (_load_web_config().get("backend") or "").lower().strip() if configured in ("parallel", "firecrawl", "tavily"): return configured diff --git a/trajectory_compressor.py b/trajectory_compressor.py index 1bfed6bf..fd69cd18 100644 --- a/trajectory_compressor.py +++ b/trajectory_compressor.py @@ -123,7 +123,7 @@ class CompressionConfig: # Summarization if 'summarization' in data: config.summarization_model = data['summarization'].get('model', config.summarization_model) - config.base_url = data['summarization'].get('base_url', config.base_url) + config.base_url = data['summarization'].get('base_url') or config.base_url config.api_key_env = data['summarization'].get('api_key_env', config.api_key_env) config.temperature = data['summarization'].get('temperature', config.temperature) config.max_retries = data['summarization'].get('max_retries', config.max_retries) @@ -386,7 +386,7 @@ class TrajectoryCompressor: def _detect_provider(self) -> str: """Detect the provider name from the configured base_url.""" - url = self.config.base_url.lower() + url = (self.config.base_url or "").lower() if "openrouter" in url: return "openrouter" if "nousresearch.com" in url: