fix: add service domain blocklist and entity_id validation to HA tools
Block dangerous HA service domains (shell_command, command_line, python_script, pyscript, hassio, rest_command) that allow arbitrary code execution or SSRF. Add regex validation for entity_id to prevent path traversal attacks. 17 new tests covering both security features.
This commit is contained in:
@@ -16,6 +16,8 @@ from tools.homeassistant_tool import (
|
||||
_get_headers,
|
||||
_handle_get_state,
|
||||
_handle_call_service,
|
||||
_BLOCKED_DOMAINS,
|
||||
_ENTITY_ID_RE,
|
||||
)
|
||||
|
||||
|
||||
@@ -211,6 +213,96 @@ class TestHandlerValidation:
|
||||
assert "error" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Security: domain blocklist
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDomainBlocklist:
|
||||
"""Verify dangerous HA service domains are blocked."""
|
||||
|
||||
@pytest.mark.parametrize("domain", sorted(_BLOCKED_DOMAINS))
|
||||
def test_blocked_domain_rejected(self, domain):
|
||||
result = json.loads(_handle_call_service({
|
||||
"domain": domain, "service": "any_service"
|
||||
}))
|
||||
assert "error" in result
|
||||
assert "blocked" in result["error"].lower()
|
||||
|
||||
def test_safe_domain_not_blocked(self):
|
||||
"""Safe domains like 'light' should not be blocked (will fail on network, not blocklist)."""
|
||||
# This will try to make a real HTTP call and fail, but the important thing
|
||||
# is it does NOT return a "blocked" error
|
||||
result = json.loads(_handle_call_service({
|
||||
"domain": "light", "service": "turn_on", "entity_id": "light.test"
|
||||
}))
|
||||
# Should fail with a network/connection error, not a "blocked" error
|
||||
if "error" in result:
|
||||
assert "blocked" not in result["error"].lower()
|
||||
|
||||
def test_blocked_domains_include_shell_command(self):
|
||||
assert "shell_command" in _BLOCKED_DOMAINS
|
||||
|
||||
def test_blocked_domains_include_hassio(self):
|
||||
assert "hassio" in _BLOCKED_DOMAINS
|
||||
|
||||
def test_blocked_domains_include_rest_command(self):
|
||||
assert "rest_command" in _BLOCKED_DOMAINS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Security: entity_id validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEntityIdValidation:
|
||||
"""Verify entity_id format validation prevents path traversal."""
|
||||
|
||||
def test_valid_entity_id_accepted(self):
|
||||
assert _ENTITY_ID_RE.match("light.bedroom")
|
||||
assert _ENTITY_ID_RE.match("sensor.temperature_1")
|
||||
assert _ENTITY_ID_RE.match("binary_sensor.motion")
|
||||
assert _ENTITY_ID_RE.match("climate.main_thermostat")
|
||||
|
||||
def test_path_traversal_rejected(self):
|
||||
assert _ENTITY_ID_RE.match("../../config") is None
|
||||
assert _ENTITY_ID_RE.match("light/../../../etc/passwd") is None
|
||||
assert _ENTITY_ID_RE.match("../api/config") is None
|
||||
|
||||
def test_special_chars_rejected(self):
|
||||
assert _ENTITY_ID_RE.match("light.bed room") is None # space
|
||||
assert _ENTITY_ID_RE.match("light.bed;rm -rf") is None # semicolon
|
||||
assert _ENTITY_ID_RE.match("light.bed/room") is None # slash
|
||||
assert _ENTITY_ID_RE.match("LIGHT.BEDROOM") is None # uppercase
|
||||
|
||||
def test_missing_domain_rejected(self):
|
||||
assert _ENTITY_ID_RE.match(".bedroom") is None
|
||||
assert _ENTITY_ID_RE.match("bedroom") is None
|
||||
|
||||
def test_get_state_rejects_invalid_entity_id(self):
|
||||
result = json.loads(_handle_get_state({"entity_id": "../../config"}))
|
||||
assert "error" in result
|
||||
assert "Invalid entity_id" in result["error"]
|
||||
|
||||
def test_call_service_rejects_invalid_entity_id(self):
|
||||
result = json.loads(_handle_call_service({
|
||||
"domain": "light",
|
||||
"service": "turn_on",
|
||||
"entity_id": "../../../etc/passwd",
|
||||
}))
|
||||
assert "error" in result
|
||||
assert "Invalid entity_id" in result["error"]
|
||||
|
||||
def test_call_service_allows_no_entity_id(self):
|
||||
"""Some services (like scene.turn_on) don't need entity_id."""
|
||||
# Will fail on network, but should NOT fail on entity_id validation
|
||||
result = json.loads(_handle_call_service({
|
||||
"domain": "scene", "service": "turn_on"
|
||||
}))
|
||||
if "error" in result:
|
||||
assert "Invalid entity_id" not in result["error"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Availability check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -13,6 +13,7 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -24,6 +25,21 @@ logger = logging.getLogger(__name__)
|
||||
_HASS_URL: str = os.getenv("HASS_URL", "http://homeassistant.local:8123").rstrip("/")
|
||||
_HASS_TOKEN: str = os.getenv("HASS_TOKEN", "")
|
||||
|
||||
# Regex for valid HA entity_id format (e.g. "light.living_room", "sensor.temperature_1")
|
||||
_ENTITY_ID_RE = re.compile(r"^[a-z_][a-z0-9_]*\.[a-z0-9_]+$")
|
||||
|
||||
# Service domains blocked for security -- these allow arbitrary code/command
|
||||
# execution on the HA host or enable SSRF attacks on the local network.
|
||||
# HA provides zero service-level access control; all safety must be in our layer.
|
||||
_BLOCKED_DOMAINS = frozenset({
|
||||
"shell_command", # arbitrary shell commands as root in HA container
|
||||
"command_line", # sensors/switches that execute shell commands
|
||||
"python_script", # sandboxed but can escalate via hass.services.call()
|
||||
"pyscript", # scripting integration with broader access
|
||||
"hassio", # addon control, host shutdown/reboot, stdin to containers
|
||||
"rest_command", # HTTP requests from HA server (SSRF vector)
|
||||
})
|
||||
|
||||
|
||||
def _get_headers() -> Dict[str, str]:
|
||||
"""Return authorization headers for HA REST API."""
|
||||
@@ -198,6 +214,8 @@ def _handle_get_state(args: dict, **kw) -> str:
|
||||
entity_id = args.get("entity_id", "")
|
||||
if not entity_id:
|
||||
return json.dumps({"error": "Missing required parameter: entity_id"})
|
||||
if not _ENTITY_ID_RE.match(entity_id):
|
||||
return json.dumps({"error": f"Invalid entity_id format: {entity_id}"})
|
||||
try:
|
||||
result = _run_async(_async_get_state(entity_id))
|
||||
return json.dumps({"result": result})
|
||||
@@ -213,7 +231,16 @@ def _handle_call_service(args: dict, **kw) -> str:
|
||||
if not domain or not service:
|
||||
return json.dumps({"error": "Missing required parameters: domain and service"})
|
||||
|
||||
if domain in _BLOCKED_DOMAINS:
|
||||
return json.dumps({
|
||||
"error": f"Service domain '{domain}' is blocked for security. "
|
||||
f"Blocked domains: {', '.join(sorted(_BLOCKED_DOMAINS))}"
|
||||
})
|
||||
|
||||
entity_id = args.get("entity_id")
|
||||
if entity_id and not _ENTITY_ID_RE.match(entity_id):
|
||||
return json.dumps({"error": f"Invalid entity_id format: {entity_id}"})
|
||||
|
||||
data = args.get("data")
|
||||
try:
|
||||
result = _run_async(_async_call_service(domain, service, entity_id, data))
|
||||
|
||||
Reference in New Issue
Block a user