diff --git a/tests/tools/test_homeassistant_tool.py b/tests/tools/test_homeassistant_tool.py
index b57df069d..b136b5653 100644
--- a/tests/tools/test_homeassistant_tool.py
+++ b/tests/tools/test_homeassistant_tool.py
@@ -16,6 +16,8 @@ from tools.homeassistant_tool import (
     _get_headers,
     _handle_get_state,
     _handle_call_service,
+    _BLOCKED_DOMAINS,
+    _ENTITY_ID_RE,
 )
 
 
@@ -211,6 +213,97 @@ class TestHandlerValidation:
         assert "error" in result
 
 
+# ---------------------------------------------------------------------------
+# Security: domain blocklist
+# ---------------------------------------------------------------------------
+
+
+class TestDomainBlocklist:
+    """Verify dangerous HA service domains are blocked."""
+
+    @pytest.mark.parametrize("domain", sorted(_BLOCKED_DOMAINS))
+    def test_blocked_domain_rejected(self, domain):
+        result = json.loads(_handle_call_service({
+            "domain": domain, "service": "any_service"
+        }))
+        assert "error" in result
+        assert "blocked" in result["error"].lower()
+
+    def test_safe_domain_not_blocked(self):
+        """Safe domains like 'light' should not be blocked (will fail on network, not blocklist)."""
+        # This will try to make a real HTTP call and fail, but the important thing
+        # is it does NOT return a "blocked" error
+        result = json.loads(_handle_call_service({
+            "domain": "light", "service": "turn_on", "entity_id": "light.test"
+        }))
+        # Should fail with a network/connection error, not a "blocked" error
+        if "error" in result:
+            assert "blocked" not in result["error"].lower()
+
+    def test_blocked_domains_include_shell_command(self):
+        assert "shell_command" in _BLOCKED_DOMAINS
+
+    def test_blocked_domains_include_hassio(self):
+        assert "hassio" in _BLOCKED_DOMAINS
+
+    def test_blocked_domains_include_rest_command(self):
+        assert "rest_command" in _BLOCKED_DOMAINS
+
+
+# ---------------------------------------------------------------------------
+# Security: entity_id validation
+# ---------------------------------------------------------------------------
+
+
+class TestEntityIdValidation:
+    """Verify entity_id format validation prevents path traversal."""
+
+    def test_valid_entity_id_accepted(self):
+        assert _ENTITY_ID_RE.match("light.bedroom")
+        assert _ENTITY_ID_RE.match("sensor.temperature_1")
+        assert _ENTITY_ID_RE.match("binary_sensor.motion")
+        assert _ENTITY_ID_RE.match("climate.main_thermostat")
+
+    def test_path_traversal_rejected(self):
+        assert _ENTITY_ID_RE.match("../../config") is None
+        assert _ENTITY_ID_RE.match("light/../../../etc/passwd") is None
+        assert _ENTITY_ID_RE.match("../api/config") is None
+
+    def test_special_chars_rejected(self):
+        assert _ENTITY_ID_RE.match("light.bed room") is None  # space
+        assert _ENTITY_ID_RE.match("light.bed;rm -rf") is None  # semicolon
+        assert _ENTITY_ID_RE.match("light.bed/room") is None  # slash
+        assert _ENTITY_ID_RE.match("LIGHT.BEDROOM") is None  # uppercase
+        assert _ENTITY_ID_RE.match("light.bedroom\n") is None  # trailing newline
+
+    def test_missing_domain_rejected(self):
+        assert _ENTITY_ID_RE.match(".bedroom") is None
+        assert _ENTITY_ID_RE.match("bedroom") is None
+
+    def test_get_state_rejects_invalid_entity_id(self):
+        result = json.loads(_handle_get_state({"entity_id": "../../config"}))
+        assert "error" in result
+        assert "Invalid entity_id" in result["error"]
+
+    def test_call_service_rejects_invalid_entity_id(self):
+        result = json.loads(_handle_call_service({
+            "domain": "light",
+            "service": "turn_on",
+            "entity_id": "../../../etc/passwd",
+        }))
+        assert "error" in result
+        assert "Invalid entity_id" in result["error"]
+
+    def test_call_service_allows_no_entity_id(self):
+        """Some services (like scene.turn_on) don't need entity_id."""
+        # Will fail on network, but should NOT fail on entity_id validation
+        result = json.loads(_handle_call_service({
+            "domain": "scene", "service": "turn_on"
+        }))
+        if "error" in result:
+            assert "Invalid entity_id" not in result["error"]
+
+
 # ---------------------------------------------------------------------------
 # Availability check
 # ---------------------------------------------------------------------------
diff --git a/tools/homeassistant_tool.py b/tools/homeassistant_tool.py
index b351cfecf..177296108 100644
--- a/tools/homeassistant_tool.py
+++ b/tools/homeassistant_tool.py
@@ -13,6 +13,7 @@ import asyncio
 import json
 import logging
 import os
+import re
 from typing import Any, Dict, Optional
 
 logger = logging.getLogger(__name__)
@@ -24,6 +25,21 @@ logger = logging.getLogger(__name__)
 _HASS_URL: str = os.getenv("HASS_URL", "http://homeassistant.local:8123").rstrip("/")
 _HASS_TOKEN: str = os.getenv("HASS_TOKEN", "")
 
+# Valid HA entity_id, e.g. "light.living_room"; \Z (not $) so a trailing newline is rejected.
+_ENTITY_ID_RE = re.compile(r"^[a-z_][a-z0-9_]*\.[a-z0-9_]+\Z")
+
+# Service domains blocked for security -- these allow arbitrary code/command
+# execution on the HA host or enable SSRF attacks on the local network.
+# HA provides zero service-level access control; all safety must be in our layer.
+_BLOCKED_DOMAINS = frozenset({
+    "shell_command",  # arbitrary shell commands as root in HA container
+    "command_line",   # sensors/switches that execute shell commands
+    "python_script",  # sandboxed but can escalate via hass.services.call()
+    "pyscript",       # scripting integration with broader access
+    "hassio",         # addon control, host shutdown/reboot, stdin to containers
+    "rest_command",   # HTTP requests from HA server (SSRF vector)
+})
+
 
 def _get_headers() -> Dict[str, str]:
     """Return authorization headers for HA REST API."""
@@ -198,6 +214,8 @@ def _handle_get_state(args: dict, **kw) -> str:
     entity_id = args.get("entity_id", "")
     if not entity_id:
         return json.dumps({"error": "Missing required parameter: entity_id"})
+    if not _ENTITY_ID_RE.match(entity_id):
+        return json.dumps({"error": f"Invalid entity_id format: {entity_id}"})
     try:
         result = _run_async(_async_get_state(entity_id))
         return json.dumps({"result": result})
@@ -213,7 +231,16 @@ def _handle_call_service(args: dict, **kw) -> str:
     if not domain or not service:
         return json.dumps({"error": "Missing required parameters: domain and service"})
 
+    if domain in _BLOCKED_DOMAINS:
+        return json.dumps({
+            "error": f"Service domain '{domain}' is blocked for security. "
+                     f"Blocked domains: {', '.join(sorted(_BLOCKED_DOMAINS))}"
+        })
+
     entity_id = args.get("entity_id")
+    if entity_id and not _ENTITY_ID_RE.match(entity_id):
+        return json.dumps({"error": f"Invalid entity_id format: {entity_id}"})
+
     data = args.get("data")
     try:
         result = _run_async(_async_call_service(domain, service, entity_id, data))