Block dangerous HA service domains (shell_command, command_line, python_script, pyscript, hassio, rest_command) that allow arbitrary code execution or SSRF. Add regex validation for entity_id to prevent path traversal attacks. 17 new tests covering both security features.
393 lines
13 KiB
Python
393 lines
13 KiB
Python
"""Home Assistant tool for controlling smart home devices via REST API.
|
|
|
|
Registers three LLM-callable tools:
|
|
- ``ha_list_entities`` -- list/filter entities by domain or area
|
|
- ``ha_get_state`` -- get detailed state of a single entity
|
|
- ``ha_call_service`` -- call a HA service (turn_on, turn_off, set_temperature, etc.)
|
|
|
|
Authentication uses a Long-Lived Access Token via ``HASS_TOKEN`` env var.
|
|
The HA instance URL is read from ``HASS_URL`` (default: http://homeassistant.local:8123).
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
from typing import Any, Dict, Optional
|
|
|
|
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Both settings are read from the environment once, when this module is
# imported; later changes to the environment do not affect them.

# Base URL of the Home Assistant instance, stored without trailing slashes so
# request paths can be appended with a single "/".
_HASS_URL: str = os.environ.get("HASS_URL", "http://homeassistant.local:8123").rstrip("/")

# Long-Lived Access Token used as the Bearer credential; empty string when unset.
_HASS_TOKEN: str = os.environ.get("HASS_TOKEN", "")
|
|
|
|
# Regex for valid HA entity_id format (e.g. "light.living_room", "sensor.temperature_1").
# The end anchor is \Z rather than $: in Python, $ also matches just before a
# trailing newline, so "light.living_room\n" would otherwise pass validation
# and end up inside a request URL.
_ENTITY_ID_RE = re.compile(r"^[a-z_][a-z0-9_]*\.[a-z0-9_]+\Z")
|
|
|
|
# Service domains that this tool refuses to call. Each of them allows
# arbitrary code/command execution on the HA host or enables SSRF attacks on
# the local network. HA provides zero service-level access control, so all
# safety must be enforced in our layer.
_BLOCKED_DOMAINS = frozenset((
    "shell_command",  # arbitrary shell commands as root in HA container
    "command_line",   # sensors/switches that execute shell commands
    "python_script",  # sandboxed but can escalate via hass.services.call()
    "pyscript",       # scripting integration with broader access
    "hassio",         # addon control, host shutdown/reboot, stdin to containers
    "rest_command",   # HTTP requests from HA server (SSRF vector)
))
|
|
|
|
|
|
def _get_headers() -> Dict[str, str]:
    """Build the HTTP headers sent with every HA REST API request.

    Returns:
        A new dict holding a Bearer ``Authorization`` header built from the
        module-level ``_HASS_TOKEN`` and a JSON ``Content-Type`` header.
    """
    headers: Dict[str, str] = {}
    headers["Authorization"] = f"Bearer {_HASS_TOKEN}"
    headers["Content-Type"] = "application/json"
    return headers
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Async helpers (called from the sync handlers via _run_async / asyncio.run)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _filter_and_summarize(
    states: list,
    domain: Optional[str] = None,
    area: Optional[str] = None,
) -> Dict[str, Any]:
    """Narrow raw HA state dicts by domain and/or area, then summarize them.

    Args:
        states: State dicts as returned by the HA ``/api/states`` endpoint.
        domain: When truthy, keep only entities whose ``entity_id`` starts
            with ``"<domain>."``.
        area: When truthy, keep only entities whose ``friendly_name`` or
            ``area`` attribute contains this text (case-insensitive
            substring match).

    Returns:
        ``{"count": <int>, "entities": [...]}`` where every entity entry
        carries ``entity_id``, ``state`` and ``friendly_name``.

    Raises:
        KeyError: If a state that survives filtering lacks ``entity_id`` or
            ``state``.
    """
    selected = states

    if domain:
        prefix = f"{domain}."
        selected = [
            item for item in selected
            if item.get("entity_id", "").startswith(prefix)
        ]

    if area:
        needle = area.lower()

        def _mentions_area(item) -> bool:
            # A null attribute value is treated like an empty string.
            label = (item.get("attributes", {}).get("friendly_name", "") or "").lower()
            if needle in label:
                return True
            room = (item.get("attributes", {}).get("area", "") or "").lower()
            return needle in room

        selected = [item for item in selected if _mentions_area(item)]

    entities = [
        {
            "entity_id": item["entity_id"],
            "state": item["state"],
            "friendly_name": item.get("attributes", {}).get("friendly_name", ""),
        }
        for item in selected
    ]

    return {"count": len(entities), "entities": entities}
|
|
|
|
|
|
async def _async_list_entities(
    domain: Optional[str] = None,
    area: Optional[str] = None,
) -> Dict[str, Any]:
    """Download all entity states from HA and summarize the matching ones.

    Args:
        domain: Optional entity-domain filter, passed to
            ``_filter_and_summarize``.
        area: Optional area/room filter, passed to ``_filter_and_summarize``.

    Returns:
        The ``{"count", "entities"}`` summary built by
        ``_filter_and_summarize``.

    Raises:
        aiohttp.ClientResponseError: If HA answers with an error status
            (via ``raise_for_status``).
    """
    # Imported inside the function, not at module import time.
    import aiohttp

    endpoint = f"{_HASS_URL}/api/states"
    limit = aiohttp.ClientTimeout(total=15)
    async with aiohttp.ClientSession() as http, http.get(endpoint, headers=_get_headers(), timeout=limit) as response:
        response.raise_for_status()
        raw_states = await response.json()

    return _filter_and_summarize(raw_states, domain, area)
|
|
|
|
|
|
async def _async_get_state(entity_id: str) -> Dict[str, Any]:
    """Retrieve the full state record of one entity from HA.

    Args:
        entity_id: Entity to look up. It is interpolated into the request
            path as-is, so callers are expected to validate it first.

    Returns:
        A dict with ``entity_id``, ``state``, ``attributes`` (``{}`` when
        absent), ``last_changed`` and ``last_updated`` (``None`` when absent).

    Raises:
        aiohttp.ClientResponseError: If HA answers with an error status.
        KeyError: If the response lacks ``entity_id`` or ``state``.
    """
    # Imported inside the function, not at module import time.
    import aiohttp

    endpoint = f"{_HASS_URL}/api/states/{entity_id}"
    async with aiohttp.ClientSession() as http:
        request = http.get(
            endpoint,
            headers=_get_headers(),
            timeout=aiohttp.ClientTimeout(total=10),
        )
        async with request as response:
            response.raise_for_status()
            record = await response.json()

    summary: Dict[str, Any] = {
        "entity_id": record["entity_id"],
        "state": record["state"],
    }
    summary["attributes"] = record.get("attributes", {})
    for stamp in ("last_changed", "last_updated"):
        summary[stamp] = record.get(stamp)
    return summary
|
|
|
|
|
|
def _build_service_payload(
    entity_id: Optional[str] = None,
    data: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Assemble the JSON body for a HA service call.

    Args:
        entity_id: Target entity; when truthy it overrides any ``entity_id``
            key present in ``data``.
        data: Extra service data to copy into the payload.

    Returns:
        A new dict; ``data`` itself is never modified.
    """
    # Start from a copy of the caller's data so it is not mutated.
    payload: Dict[str, Any] = dict(data) if data else {}
    if entity_id:
        # The explicit parameter wins over data["entity_id"].
        payload["entity_id"] = entity_id
    return payload
|
|
|
|
|
|
def _parse_service_response(
    domain: str,
    service: str,
    result: Any,
) -> Dict[str, Any]:
    """Convert a HA service-call response into a structured result.

    Args:
        domain: Service domain that was called.
        service: Service name that was called.
        result: Decoded JSON response body. Only a list is inspected; any
            other value yields an empty ``affected_entities`` list.

    Returns:
        ``{"success": True, "service": "<domain>.<service>",
        "affected_entities": [{"entity_id", "state"}, ...]}``.
    """
    changed = result if isinstance(result, list) else []
    affected = [
        {"entity_id": entry.get("entity_id", ""), "state": entry.get("state", "")}
        for entry in changed
    ]
    return {
        "success": True,
        "service": f"{domain}.{service}",
        "affected_entities": affected,
    }
|
|
|
|
|
|
async def _async_call_service(
    domain: str,
    service: str,
    entity_id: Optional[str] = None,
    data: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """POST a service call to HA and summarize the response.

    Args:
        domain: Service domain; interpolated into the request path as-is.
        service: Service name; interpolated into the request path as-is.
        entity_id: Optional target entity, merged into the JSON body.
        data: Optional extra service data, merged into the JSON body.

    Returns:
        The structured result produced by ``_parse_service_response``.

    Raises:
        aiohttp.ClientResponseError: If HA answers with an error status.
    """
    # Imported inside the function, not at module import time.
    import aiohttp

    endpoint = f"{_HASS_URL}/api/services/{domain}/{service}"
    body = _build_service_payload(entity_id, data)

    async with aiohttp.ClientSession() as http:
        request = http.post(
            endpoint,
            headers=_get_headers(),
            json=body,
            timeout=aiohttp.ClientTimeout(total=15),
        )
        async with request as response:
            response.raise_for_status()
            decoded = await response.json()

    return _parse_service_response(domain, service, decoded)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Sync wrappers (handler signature: (args, **kw) -> str)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _run_async(coro, timeout: float = 30):
    """Run an async coroutine to completion from a synchronous handler.

    When no event loop is running in the current thread the coroutine is
    executed directly with ``asyncio.run``. When the caller is already inside
    a running loop, the coroutine is executed with ``asyncio.run`` on a
    helper thread and this function blocks until it finishes or ``timeout``
    elapses.

    Args:
        coro: The coroutine object to execute.
        timeout: Maximum number of seconds to wait for the helper thread.
            Only applies when a loop is already running in the calling
            thread; the direct ``asyncio.run`` path is not time-limited.

    Returns:
        Whatever the coroutine returns.

    Raises:
        concurrent.futures.TimeoutError: If the helper thread does not finish
            within ``timeout`` seconds.
        Exception: Any exception raised by the coroutine propagates.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        inside_loop = False
    else:
        # get_running_loop() only succeeds when a loop is actually running.
        inside_loop = True

    if not inside_loop:
        return asyncio.run(coro)

    # Already inside an event loop -- asyncio.run() cannot be nested in this
    # thread, so give the coroutine its own loop on a worker thread.
    import concurrent.futures

    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        return pool.submit(asyncio.run, coro).result(timeout=timeout)
    finally:
        # Do not wait for the worker here. Using the executor as a context
        # manager would call shutdown(wait=True), which blocks until the
        # coroutine finishes and therefore defeats the timeout above.
        pool.shutdown(wait=False)
|
|
|
|
|
|
def _handle_list_entities(args: dict, **kw) -> str:
    """Tool handler for ``ha_list_entities``.

    Args:
        args: Tool arguments; the optional ``domain`` and ``area`` keys are
            forwarded as filters.
        **kw: Accepted for handler-signature compatibility; unused.

    Returns:
        A JSON string: ``{"result": ...}`` on success or ``{"error": ...}``
        when anything raises.
    """
    filters = {"domain": args.get("domain"), "area": args.get("area")}
    try:
        summary = _run_async(_async_list_entities(**filters))
        return json.dumps({"result": summary})
    except Exception as exc:
        # Handler boundary: report the failure to the caller as JSON.
        logger.error("ha_list_entities error: %s", exc)
        return json.dumps({"error": f"Failed to list entities: {exc}"})
|
|
|
|
|
|
def _handle_get_state(args: dict, **kw) -> str:
    """Tool handler for ``ha_get_state``.

    Args:
        args: Tool arguments; ``entity_id`` is required and must be a string
            matching ``_ENTITY_ID_RE`` in full.
        **kw: Accepted for handler-signature compatibility; unused.

    Returns:
        A JSON string: ``{"result": ...}`` on success, or ``{"error": ...}``
        when the argument is missing/invalid or the request fails.
    """
    entity_id = args.get("entity_id", "")
    if not entity_id:
        return json.dumps({"error": "Missing required parameter: entity_id"})
    # The type check comes first so a non-string argument yields a JSON error
    # instead of a TypeError from the regex engine. fullmatch() is used
    # because match() with a "$" anchor would accept a trailing newline, and
    # entity_id is interpolated into the request URL.
    if not isinstance(entity_id, str) or not _ENTITY_ID_RE.fullmatch(entity_id):
        return json.dumps({"error": f"Invalid entity_id format: {entity_id}"})
    try:
        result = _run_async(_async_get_state(entity_id))
        return json.dumps({"result": result})
    except Exception as e:
        # Handler boundary: report the failure to the caller as JSON.
        logger.error("ha_get_state error: %s", e)
        return json.dumps({"error": f"Failed to get state for {entity_id}: {e}"})
|
|
|
|
|
|
# Lower-case slug accepted for the ``domain`` and ``service`` URL path segments.
_SERVICE_SLUG_RE = re.compile(r"[a-z0-9_]+")


def _handle_call_service(args: dict, **kw) -> str:
    """Tool handler for ``ha_call_service``.

    ``domain`` and ``service`` are lower-cased and must each consist only of
    ``a-z``, ``0-9`` and ``_`` before the blocklist is consulted. Both values
    are interpolated into the request path by ``_async_call_service``, so a
    value containing path separators (e.g. ``service="../shell_command/x"``)
    or a case variant (e.g. ``"Shell_Command"``) would not compare equal to
    any ``_BLOCKED_DOMAINS`` entry yet may still reach a blocked service once
    the URL is interpreted.

    Args:
        args: Tool arguments: required ``domain`` and ``service``, optional
            ``entity_id`` (string matching ``_ENTITY_ID_RE`` in full) and
            optional ``data`` (extra service data).
        **kw: Accepted for handler-signature compatibility; unused.

    Returns:
        A JSON string: ``{"result": ...}`` on success, or ``{"error": ...}``
        when validation fails, the domain is blocked, or the request fails.
    """
    domain = args.get("domain", "")
    service = args.get("service", "")
    if not domain or not service:
        return json.dumps({"error": "Missing required parameters: domain and service"})

    if not isinstance(domain, str) or not isinstance(service, str):
        return json.dumps({"error": "Invalid domain or service: both must be strings"})

    # Normalize case first so the blocklist comparison below sees the same
    # value that is sent to HA.
    domain = domain.lower()
    service = service.lower()
    if not _SERVICE_SLUG_RE.fullmatch(domain) or not _SERVICE_SLUG_RE.fullmatch(service):
        return json.dumps({"error": f"Invalid domain or service format: {domain}.{service}"})

    if domain in _BLOCKED_DOMAINS:
        return json.dumps({
            "error": f"Service domain '{domain}' is blocked for security. "
                     f"Blocked domains: {', '.join(sorted(_BLOCKED_DOMAINS))}"
        })

    entity_id = args.get("entity_id")
    # The type check comes first so a non-string value yields a JSON error
    # instead of a TypeError; fullmatch() rejects a trailing newline that
    # match() with "$" would let through.
    if entity_id and (
        not isinstance(entity_id, str) or not _ENTITY_ID_RE.fullmatch(entity_id)
    ):
        return json.dumps({"error": f"Invalid entity_id format: {entity_id}"})

    data = args.get("data")
    try:
        result = _run_async(_async_call_service(domain, service, entity_id, data))
        return json.dumps({"result": result})
    except Exception as e:
        # Handler boundary: report the failure to the caller as JSON.
        logger.error("ha_call_service error: %s", e)
        return json.dumps({"error": f"Failed to call {domain}.{service}: {e}"})
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Availability check
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _check_ha_available() -> bool:
    """Report whether the Home Assistant tools should be offered.

    Returns:
        ``True`` when the ``HASS_TOKEN`` environment variable is set to a
        non-empty value at call time, ``False`` otherwise.
    """
    token = os.environ.get("HASS_TOKEN")
    return token is not None and token != ""
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tool schemas
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Tool schema for ``ha_list_entities``: both filters are optional strings.
HA_LIST_ENTITIES_SCHEMA = dict(
    name="ha_list_entities",
    description=(
        "List Home Assistant entities. Optionally filter by domain "
        "(light, switch, climate, sensor, binary_sensor, cover, fan, etc.) "
        "or by area name (living room, kitchen, bedroom, etc.)."
    ),
    parameters=dict(
        type="object",
        properties=dict(
            domain=dict(
                type="string",
                description=(
                    "Entity domain to filter by (e.g. 'light', 'switch', 'climate', "
                    "'sensor', 'binary_sensor', 'cover', 'fan', 'media_player'). "
                    "Omit to list all entities."
                ),
            ),
            area=dict(
                type="string",
                description=(
                    "Area/room name to filter by (e.g. 'living room', 'kitchen'). "
                    "Matches against entity friendly names. Omit to list all."
                ),
            ),
        ),
        required=[],
    ),
)
|
|
|
|
# Tool schema for ``ha_get_state``: a single required ``entity_id`` string.
HA_GET_STATE_SCHEMA = dict(
    name="ha_get_state",
    description=(
        "Get the detailed state of a single Home Assistant entity, including all "
        "attributes (brightness, color, temperature setpoint, sensor readings, etc.)."
    ),
    parameters=dict(
        type="object",
        properties=dict(
            entity_id=dict(
                type="string",
                description=(
                    "The entity ID to query (e.g. 'light.living_room', "
                    "'climate.thermostat', 'sensor.temperature')."
                ),
            ),
        ),
        required=["entity_id"],
    ),
)
|
|
|
|
# Tool schema for ``ha_call_service``: ``domain`` and ``service`` are
# required; ``entity_id`` and ``data`` are optional.
HA_CALL_SERVICE_SCHEMA = dict(
    name="ha_call_service",
    description=(
        "Call a Home Assistant service to control a device. Common examples: "
        "turn_on/turn_off lights and switches, set_temperature for climate, "
        "open_cover/close_cover for blinds, set_volume_level for media players."
    ),
    parameters=dict(
        type="object",
        properties=dict(
            domain=dict(
                type="string",
                description=(
                    "Service domain (e.g. 'light', 'switch', 'climate', "
                    "'cover', 'media_player', 'fan', 'scene', 'script')."
                ),
            ),
            service=dict(
                type="string",
                description=(
                    "Service name (e.g. 'turn_on', 'turn_off', 'toggle', "
                    "'set_temperature', 'set_hvac_mode', 'open_cover', "
                    "'close_cover', 'set_volume_level')."
                ),
            ),
            entity_id=dict(
                type="string",
                description=(
                    "Target entity ID (e.g. 'light.living_room'). "
                    "Some services (like scene.turn_on) may not need this."
                ),
            ),
            data=dict(
                type="object",
                description=(
                    "Additional service data. Examples: "
                    '{"brightness": 255, "color_name": "blue"} for lights, '
                    '{"temperature": 22, "hvac_mode": "heat"} for climate, '
                    '{"volume_level": 0.5} for media players.'
                ),
            ),
        ),
        required=["domain", "service"],
    ),
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Registration
|
|
# ---------------------------------------------------------------------------
|
|
|
|
from tools.registry import registry
|
|
|
|
# Register the three tools, in the same order as before, under the
# "homeassistant" toolset. All of them share the same availability check.
for _tool_name, _tool_schema, _tool_handler in (
    ("ha_list_entities", HA_LIST_ENTITIES_SCHEMA, _handle_list_entities),
    ("ha_get_state", HA_GET_STATE_SCHEMA, _handle_get_state),
    ("ha_call_service", HA_CALL_SERVICE_SCHEMA, _handle_call_service),
):
    registry.register(
        name=_tool_name,
        toolset="homeassistant",
        schema=_tool_schema,
        handler=_tool_handler,
        check_fn=_check_ha_available,
    )

# Keep the loop variables out of the module namespace.
del _tool_name, _tool_schema, _tool_handler
|