From ffec21236d21745f4530fa1866d9cf858f82529f Mon Sep 17 00:00:00 2001 From: teknium1 Date: Tue, 3 Mar 2026 05:16:53 -0800 Subject: [PATCH] feat: enhance Home Assistant integration with service discovery and setup Improvements to the HA integration merged from PR #184: - Add ha_list_services tool: discovers available services (actions) per domain with descriptions and parameter fields. Tells the model what it can do with each device type (e.g. light.turn_on accepts brightness, color_name, transition). Closes the gap where the model had to guess available actions. - Add HA to hermes tools config: users can enable/disable the homeassistant toolset and configure HASS_TOKEN + HASS_URL through 'hermes tools' setup flow instead of manually editing .env. - Fix should-fix items from code review: - Remove sys.path.insert hack from gateway adapter - Replace all print() calls with proper logger (info/warning/error) - Move env var reads from import-time to handler-time via _get_config() - Add dedicated REST session reuse in gateway send() - Update ha_call_service description to reference ha_list_services for action discovery. - Update tests for new ha_list_services tool in toolset resolution. --- gateway/platforms/homeassistant.py | 42 +++++++--- hermes_cli/tools_config.py | 3 + tests/gateway/test_homeassistant.py | 6 +- tools/homeassistant_tool.py | 122 ++++++++++++++++++++++++---- toolsets.py | 4 +- 5 files changed, 145 insertions(+), 32 deletions(-) diff --git a/gateway/platforms/homeassistant.py b/gateway/platforms/homeassistant.py index 08dfa0992..a900ef3b7 100644 --- a/gateway/platforms/homeassistant.py +++ b/gateway/platforms/homeassistant.py @@ -28,10 +28,6 @@ except ImportError: AIOHTTP_AVAILABLE = False aiohttp = None # type: ignore[assignment] -import sys -from pathlib import Path as _Path -sys.path.insert(0, str(_Path(__file__).resolve().parents[2])) - from gateway.config import Platform, PlatformConfig from gateway.platforms.base import ( BasePlatformAdapter, @@ -72,6 +68,7 @@ class HomeAssistantAdapter(BasePlatformAdapter): # Connection state self._session: Optional["aiohttp.ClientSession"] = None self._ws: Optional["aiohttp.ClientWebSocketResponse"] = None + self._rest_session: Optional["aiohttp.ClientSession"] = None self._listen_task: Optional[asyncio.Task] = None self._msg_id: int = 0 @@ -103,11 +100,11 @@ class HomeAssistantAdapter(BasePlatformAdapter): async def connect(self) -> bool: """Connect to HA WebSocket API and subscribe to events.""" if not AIOHTTP_AVAILABLE: - print(f"[{self.name}] aiohttp not installed. Run: pip install aiohttp") + logger.warning("[%s] aiohttp not installed. Run: pip install aiohttp", self.name) return False if not self._hass_token: - print(f"[{self.name}] No HASS_TOKEN configured") + logger.warning("[%s] No HASS_TOKEN configured", self.name) return False try: @@ -115,14 +112,17 @@ class HomeAssistantAdapter(BasePlatformAdapter): if not success: return False + # Dedicated REST session for send() calls + self._rest_session = aiohttp.ClientSession() + # Start background listener self._listen_task = asyncio.create_task(self._listen_loop()) self._running = True - print(f"[{self.name}] Connected to {self._hass_url}") + logger.info("[%s] Connected to %s", self.name, self._hass_url) return True except Exception as e: - print(f"[{self.name}] Failed to connect: {e}") + logger.error("[%s] Failed to connect: %s", self.name, e) return False async def _ws_connect(self) -> bool: @@ -191,7 +191,10 @@ class HomeAssistantAdapter(BasePlatformAdapter): self._listen_task = None await self._cleanup_ws() - print(f"[{self.name}] Disconnected") + if self._rest_session and not self._rest_session.closed: + await self._rest_session.close() + self._rest_session = None + logger.info("[%s] Disconnected", self.name) # ------------------------------------------------------------------ # Event listener @@ -214,7 +217,7 @@ class HomeAssistantAdapter(BasePlatformAdapter): # Reconnect with backoff delay = self._BACKOFF_STEPS[min(backoff_idx, len(self._BACKOFF_STEPS) - 1)] - print(f"[{self.name}] Reconnecting in {delay}s...") + logger.info("[%s] Reconnecting in %ds...", self.name, delay) await asyncio.sleep(delay) backoff_idx += 1 @@ -223,7 +226,7 @@ class HomeAssistantAdapter(BasePlatformAdapter): success = await self._ws_connect() if success: backoff_idx = 0 # Reset on successful reconnect - print(f"[{self.name}] Reconnected") + logger.info("[%s] Reconnected", self.name) except Exception as e: logger.warning("[%s] Reconnection failed: %s", self.name, e) @@ -385,8 +388,8 @@ class HomeAssistantAdapter(BasePlatformAdapter): } try: - async with aiohttp.ClientSession() as session: - async with session.post( + if self._rest_session: + async with self._rest_session.post( url, headers=headers, json=payload, @@ -397,6 +400,19 @@ class HomeAssistantAdapter(BasePlatformAdapter): else: body = await resp.text() return SendResult(success=False, error=f"HTTP {resp.status}: {body}") + else: + async with aiohttp.ClientSession() as session: + async with session.post( + url, + headers=headers, + json=payload, + timeout=aiohttp.ClientTimeout(total=10), + ) as resp: + if resp.status < 300: + return SendResult(success=True, message_id=uuid.uuid4().hex[:12]) + else: + body = await resp.text() + return SendResult(success=False, error=f"HTTP {resp.status}: {body}") except asyncio.TimeoutError: return SendResult(success=False, error="Timeout sending notification to HA") diff --git a/hermes_cli/tools_config.py b/hermes_cli/tools_config.py index 8462d6b8b..6cfe34923 100644 --- a/hermes_cli/tools_config.py +++ b/hermes_cli/tools_config.py @@ -36,6 +36,7 @@ CONFIGURABLE_TOOLSETS = [ ("delegation", "๐Ÿ‘ฅ Task Delegation", "delegate_task"), ("cronjob", "โฐ Cron Jobs", "schedule, list, remove"), ("rl", "๐Ÿงช RL Training", "Tinker-Atropos training tools"), + ("homeassistant", "๐Ÿ  Home Assistant", "smart home device control"), ] # Platform display config @@ -312,6 +313,8 @@ TOOLSET_ENV_REQUIREMENTS = { "tts": [], # Edge TTS is free, no key needed "rl": [("TINKER_API_KEY", "https://tinker-console.thinkingmachines.ai/keys"), ("WANDB_API_KEY", "https://wandb.ai/authorize")], + "homeassistant": [("HASS_TOKEN", "Home Assistant > Profile > Long-Lived Access Tokens"), + ("HASS_URL", None)], } diff --git a/tests/gateway/test_homeassistant.py b/tests/gateway/test_homeassistant.py index f8bf7844d..8701ef14a 100644 --- a/tests/gateway/test_homeassistant.py +++ b/tests/gateway/test_homeassistant.py @@ -569,19 +569,19 @@ class TestToolsetIntegration: from toolsets import resolve_toolset tools = resolve_toolset("homeassistant") - assert set(tools) == {"ha_list_entities", "ha_get_state", "ha_call_service"} + assert set(tools) == {"ha_list_entities", "ha_get_state", "ha_call_service", "ha_list_services"} def test_gateway_toolset_includes_ha_tools(self): from toolsets import resolve_toolset gateway_tools = resolve_toolset("hermes-gateway") - for tool in ("ha_list_entities", "ha_get_state", "ha_call_service"): + for tool in ("ha_list_entities", "ha_get_state", "ha_call_service", "ha_list_services"): assert tool in gateway_tools def test_hermes_core_tools_includes_ha(self): from toolsets import _HERMES_CORE_TOOLS - for tool in ("ha_list_entities", "ha_get_state", "ha_call_service"): + for tool in ("ha_list_entities", "ha_get_state", "ha_call_service", "ha_list_services"): assert tool in _HERMES_CORE_TOOLS diff --git a/tools/homeassistant_tool.py b/tools/homeassistant_tool.py index 177296108..a9077cff3 100644 --- a/tools/homeassistant_tool.py +++ b/tools/homeassistant_tool.py @@ -1,8 +1,9 @@ """Home Assistant tool for controlling smart home devices via REST API. -Registers three LLM-callable tools: +Registers four LLM-callable tools: - ``ha_list_entities`` -- list/filter entities by domain or area - ``ha_get_state`` -- get detailed state of a single entity +- ``ha_list_services`` -- list available services (actions) per domain - ``ha_call_service`` -- call a HA service (turn_on, turn_off, set_temperature, etc.) Authentication uses a Long-Lived Access Token via ``HASS_TOKEN`` env var. @@ -22,8 +23,17 @@ logger = logging.getLogger(__name__) # Configuration # --------------------------------------------------------------------------- -_HASS_URL: str = os.getenv("HASS_URL", "http://homeassistant.local:8123").rstrip("/") -_HASS_TOKEN: str = os.getenv("HASS_TOKEN", "") +# Kept for backward compatibility (e.g. test monkeypatching); prefer _get_config(). +_HASS_URL: str = "" +_HASS_TOKEN: str = "" + + +def _get_config(): + """Return (hass_url, hass_token) from env vars at call time.""" + return ( + (_HASS_URL or os.getenv("HASS_URL", "http://homeassistant.local:8123")).rstrip("/"), + _HASS_TOKEN or os.getenv("HASS_TOKEN", ""), + ) # Regex for valid HA entity_id format (e.g. "light.living_room", "sensor.temperature_1") _ENTITY_ID_RE = re.compile(r"^[a-z_][a-z0-9_]*\.[a-z0-9_]+$") @@ -41,10 +51,12 @@ _BLOCKED_DOMAINS = frozenset({ }) -def _get_headers() -> Dict[str, str]: +def _get_headers(token: str = "") -> Dict[str, str]: """Return authorization headers for HA REST API.""" + if not token: + _, token = _get_config() return { - "Authorization": f"Bearer {_HASS_TOKEN}", + "Authorization": f"Bearer {token}", "Content-Type": "application/json", } @@ -88,9 +100,10 @@ async def _async_list_entities( """Fetch entity states from HA and optionally filter by domain/area.""" import aiohttp - url = f"{_HASS_URL}/api/states" + hass_url, hass_token = _get_config() + url = f"{hass_url}/api/states" async with aiohttp.ClientSession() as session: - async with session.get(url, headers=_get_headers(), timeout=aiohttp.ClientTimeout(total=15)) as resp: + async with session.get(url, headers=_get_headers(hass_token), timeout=aiohttp.ClientTimeout(total=15)) as resp: resp.raise_for_status() states = await resp.json() @@ -101,9 +114,10 @@ async def _async_get_state(entity_id: str) -> Dict[str, Any]: """Fetch detailed state of a single entity.""" import aiohttp - url = f"{_HASS_URL}/api/states/{entity_id}" + hass_url, hass_token = _get_config() + url = f"{hass_url}/api/states/{entity_id}" async with aiohttp.ClientSession() as session: - async with session.get(url, headers=_get_headers(), timeout=aiohttp.ClientTimeout(total=10)) as resp: + async with session.get(url, headers=_get_headers(hass_token), timeout=aiohttp.ClientTimeout(total=10)) as resp: resp.raise_for_status() data = await resp.json() @@ -160,13 +174,14 @@ async def _async_call_service( """Call a Home Assistant service.""" import aiohttp - url = f"{_HASS_URL}/api/services/{domain}/{service}" + hass_url, hass_token = _get_config() + url = f"{hass_url}/api/services/{domain}/{service}" payload = _build_service_payload(entity_id, data) async with aiohttp.ClientSession() as session: async with session.post( url, - headers=_get_headers(), + headers=_get_headers(hass_token), json=payload, timeout=aiohttp.ClientTimeout(total=15), ) as resp: @@ -250,6 +265,55 @@ def _handle_call_service(args: dict, **kw) -> str: return json.dumps({"error": f"Failed to call {domain}.{service}: {e}"}) +# --------------------------------------------------------------------------- +# List services +# --------------------------------------------------------------------------- + +async def _async_list_services(domain: Optional[str] = None) -> Dict[str, Any]: + """Fetch available services from HA and optionally filter by domain.""" + import aiohttp + + hass_url, hass_token = _get_config() + url = f"{hass_url}/api/services" + headers = {"Authorization": f"Bearer {hass_token}", "Content-Type": "application/json"} + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=15)) as resp: + resp.raise_for_status() + services = await resp.json() + + if domain: + services = [s for s in services if s.get("domain") == domain] + + # Compact the output for context efficiency + result = [] + for svc_domain in services: + d = svc_domain.get("domain", "") + domain_services = {} + for svc_name, svc_info in svc_domain.get("services", {}).items(): + svc_entry: Dict[str, Any] = {"description": svc_info.get("description", "")} + fields = svc_info.get("fields", {}) + if fields: + svc_entry["fields"] = { + k: v.get("description", "") for k, v in fields.items() + if isinstance(v, dict) + } + domain_services[svc_name] = svc_entry + result.append({"domain": d, "services": domain_services}) + + return {"count": len(result), "domains": result} + + +def _handle_list_services(args: dict, **kw) -> str: + """Handler for ha_list_services tool.""" + domain = args.get("domain") + try: + result = _run_async(_async_list_services(domain=domain)) + return json.dumps({"result": result}) + except Exception as e: + logger.error("ha_list_services error: %s", e) + return json.dumps({"error": f"Failed to list services: {e}"}) + + # --------------------------------------------------------------------------- # Availability check # --------------------------------------------------------------------------- @@ -314,12 +378,34 @@ HA_GET_STATE_SCHEMA = { }, } +HA_LIST_SERVICES_SCHEMA = { + "name": "ha_list_services", + "description": ( + "List available Home Assistant services (actions) for device control. " + "Shows what actions can be performed on each device type and what " + "parameters they accept. Use this to discover how to control devices " + "found via ha_list_entities." + ), + "parameters": { + "type": "object", + "properties": { + "domain": { + "type": "string", + "description": ( + "Filter by domain (e.g. 'light', 'climate', 'switch'). " + "Omit to list services for all domains." + ), + }, + }, + "required": [], + }, +} + HA_CALL_SERVICE_SCHEMA = { "name": "ha_call_service", "description": ( - "Call a Home Assistant service to control a device. Common examples: " - "turn_on/turn_off lights and switches, set_temperature for climate, " - "open_cover/close_cover for blinds, set_volume_level for media players." + "Call a Home Assistant service to control a device. Use ha_list_services " + "to discover available services and their parameters for each domain." ), "parameters": { "type": "object", @@ -383,6 +469,14 @@ registry.register( check_fn=_check_ha_available, ) +registry.register( + name="ha_list_services", + toolset="homeassistant", + schema=HA_LIST_SERVICES_SCHEMA, + handler=_handle_list_services, + check_fn=_check_ha_available, +) + registry.register( name="ha_call_service", toolset="homeassistant", diff --git a/toolsets.py b/toolsets.py index 44b814498..8589a35ea 100644 --- a/toolsets.py +++ b/toolsets.py @@ -63,7 +63,7 @@ _HERMES_CORE_TOOLS = [ # Honcho user context (gated on honcho being active via check_fn) "query_user_context", # Home Assistant smart home control (gated on HASS_TOKEN via check_fn) - "ha_list_entities", "ha_get_state", "ha_call_service", + "ha_list_entities", "ha_get_state", "ha_list_services", "ha_call_service", ] @@ -198,7 +198,7 @@ TOOLSETS = { "homeassistant": { "description": "Home Assistant smart home control and monitoring", - "tools": ["ha_list_entities", "ha_get_state", "ha_call_service"], + "tools": ["ha_list_entities", "ha_get_state", "ha_list_services", "ha_call_service"], "includes": [] },