diff --git a/gateway/config.py b/gateway/config.py index 32b623ea4..f441e2dd6 100644 --- a/gateway/config.py +++ b/gateway/config.py @@ -26,6 +26,7 @@ class Platform(Enum): DISCORD = "discord" WHATSAPP = "whatsapp" SLACK = "slack" + HOMEASSISTANT = "homeassistant" @dataclass @@ -378,6 +379,17 @@ def _apply_env_overrides(config: GatewayConfig) -> None: name=os.getenv("SLACK_HOME_CHANNEL_NAME", ""), ) + # Home Assistant + hass_token = os.getenv("HASS_TOKEN") + if hass_token: + if Platform.HOMEASSISTANT not in config.platforms: + config.platforms[Platform.HOMEASSISTANT] = PlatformConfig() + config.platforms[Platform.HOMEASSISTANT].enabled = True + config.platforms[Platform.HOMEASSISTANT].token = hass_token + hass_url = os.getenv("HASS_URL") + if hass_url: + config.platforms[Platform.HOMEASSISTANT].extra["url"] = hass_url + # Session settings idle_minutes = os.getenv("SESSION_IDLE_MINUTES") if idle_minutes: diff --git a/gateway/platforms/homeassistant.py b/gateway/platforms/homeassistant.py new file mode 100644 index 000000000..749cdf1e4 --- /dev/null +++ b/gateway/platforms/homeassistant.py @@ -0,0 +1,413 @@ +""" +Home Assistant platform adapter. + +Connects to the HA WebSocket API for real-time event monitoring. +State-change events are converted to MessageEvent objects and forwarded +to the agent for processing. Outbound messages are delivered as HA +persistent notifications. + +Requires: +- aiohttp (already in messaging extras) +- HASS_TOKEN env var (Long-Lived Access Token) +- HASS_URL env var (default: http://homeassistant.local:8123) +""" + +import asyncio +import json +import logging +import os +import time +from datetime import datetime +from typing import Any, Dict, List, Optional, Set + +try: + import aiohttp + AIOHTTP_AVAILABLE = True +except ImportError: + AIOHTTP_AVAILABLE = False + aiohttp = None # type: ignore[assignment] + +import sys +from pathlib import Path as _Path +sys.path.insert(0, str(_Path(__file__).resolve().parents[2])) + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import ( + BasePlatformAdapter, + MessageEvent, + MessageType, + SendResult, +) + +logger = logging.getLogger(__name__) + + +def check_ha_requirements() -> bool: + """Check if Home Assistant dependencies are available and configured.""" + if not AIOHTTP_AVAILABLE: + return False + if not os.getenv("HASS_TOKEN"): + return False + return True + + +class HomeAssistantAdapter(BasePlatformAdapter): + """ + Home Assistant WebSocket adapter. + + Subscribes to ``state_changed`` events and forwards them as + MessageEvent objects. Supports domain/entity filtering and + per-entity cooldowns to avoid event floods. + """ + + MAX_MESSAGE_LENGTH = 4096 + + # Reconnection backoff schedule (seconds) + _BACKOFF_STEPS = [5, 10, 30, 60] + + def __init__(self, config: PlatformConfig): + super().__init__(config, Platform.HOMEASSISTANT) + + # Connection state + self._session: Optional["aiohttp.ClientSession"] = None + self._ws: Optional["aiohttp.ClientWebSocketResponse"] = None + self._listen_task: Optional[asyncio.Task] = None + self._msg_id: int = 0 + + # Configuration from extra + extra = config.extra or {} + token = config.token or os.getenv("HASS_TOKEN", "") + url = extra.get("url") or os.getenv("HASS_URL", "http://homeassistant.local:8123") + self._hass_url: str = url.rstrip("/") + self._hass_token: str = token + + # Event filtering + self._watch_domains: Set[str] = set(extra.get("watch_domains", [])) + self._watch_entities: Set[str] = set(extra.get("watch_entities", [])) + self._ignore_entities: Set[str] = set(extra.get("ignore_entities", [])) + self._cooldown_seconds: int = int(extra.get("cooldown_seconds", 30)) + + # Cooldown tracking: entity_id -> last_event_timestamp + self._last_event_time: Dict[str, float] = {} + + def _next_id(self) -> int: + """Return the next WebSocket message ID.""" + self._msg_id += 1 + return self._msg_id + + # ------------------------------------------------------------------ + # Connection lifecycle + # ------------------------------------------------------------------ + + async def connect(self) -> bool: + """Connect to HA WebSocket API and subscribe to events.""" + if not AIOHTTP_AVAILABLE: + print(f"[{self.name}] aiohttp not installed. Run: pip install aiohttp") + return False + + if not self._hass_token: + print(f"[{self.name}] No HASS_TOKEN configured") + return False + + try: + success = await self._ws_connect() + if not success: + return False + + # Start background listener + self._listen_task = asyncio.create_task(self._listen_loop()) + self._running = True + print(f"[{self.name}] Connected to {self._hass_url}") + return True + + except Exception as e: + print(f"[{self.name}] Failed to connect: {e}") + return False + + async def _ws_connect(self) -> bool: + """Establish WebSocket connection and authenticate.""" + ws_url = self._hass_url.replace("http://", "ws://").replace("https://", "wss://") + ws_url = f"{ws_url}/api/websocket" + + self._session = aiohttp.ClientSession() + self._ws = await self._session.ws_connect(ws_url, heartbeat=30) + + # Step 1: Receive auth_required + msg = await self._ws.receive_json() + if msg.get("type") != "auth_required": + logger.error("Expected auth_required, got: %s", msg.get("type")) + await self._cleanup_ws() + return False + + # Step 2: Send auth + await self._ws.send_json({ + "type": "auth", + "access_token": self._hass_token, + }) + + # Step 3: Wait for auth_ok + msg = await self._ws.receive_json() + if msg.get("type") != "auth_ok": + logger.error("Auth failed: %s", msg) + await self._cleanup_ws() + return False + + # Step 4: Subscribe to state_changed events + sub_id = self._next_id() + await self._ws.send_json({ + "id": sub_id, + "type": "subscribe_events", + "event_type": "state_changed", + }) + + # Verify subscription acknowledgement + msg = await self._ws.receive_json() + if not msg.get("success"): + logger.error("Failed to subscribe to events: %s", msg) + await self._cleanup_ws() + return False + + return True + + async def _cleanup_ws(self) -> None: + """Close WebSocket and session.""" + if self._ws and not self._ws.closed: + await self._ws.close() + self._ws = None + if self._session and not self._session.closed: + await self._session.close() + self._session = None + + async def disconnect(self) -> None: + """Disconnect from Home Assistant.""" + self._running = False + if self._listen_task: + self._listen_task.cancel() + try: + await self._listen_task + except asyncio.CancelledError: + pass + self._listen_task = None + + await self._cleanup_ws() + print(f"[{self.name}] Disconnected") + + # ------------------------------------------------------------------ + # Event listener + # ------------------------------------------------------------------ + + async def _listen_loop(self) -> None: + """Main event loop with automatic reconnection.""" + backoff_idx = 0 + + while self._running: + try: + await self._read_events() + except asyncio.CancelledError: + return + except Exception as e: + logger.warning("[%s] WebSocket error: %s", self.name, e) + + if not self._running: + return + + # Reconnect with backoff + delay = self._BACKOFF_STEPS[min(backoff_idx, len(self._BACKOFF_STEPS) - 1)] + print(f"[{self.name}] Reconnecting in {delay}s...") + await asyncio.sleep(delay) + backoff_idx += 1 + + try: + await self._cleanup_ws() + success = await self._ws_connect() + if success: + backoff_idx = 0 # Reset on successful reconnect + print(f"[{self.name}] Reconnected") + except Exception as e: + logger.warning("[%s] Reconnection failed: %s", self.name, e) + + async def _read_events(self) -> None: + """Read events from WebSocket until disconnected.""" + async for ws_msg in self._ws: + if ws_msg.type == aiohttp.WSMsgType.TEXT: + try: + data = json.loads(ws_msg.data) + if data.get("type") == "event": + await self._handle_ha_event(data.get("event", {})) + except json.JSONDecodeError: + logger.debug("Invalid JSON from HA WS: %s", ws_msg.data[:200]) + elif ws_msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR): + break + + async def _handle_ha_event(self, event: Dict[str, Any]) -> None: + """Process a state_changed event from Home Assistant.""" + event_data = event.get("data", {}) + entity_id: str = event_data.get("entity_id", "") + + if not entity_id: + return + + # Apply ignore filter + if entity_id in self._ignore_entities: + return + + # Apply domain/entity watch filters + domain = entity_id.split(".")[0] if "." in entity_id else "" + if self._watch_domains or self._watch_entities: + domain_match = domain in self._watch_domains if self._watch_domains else False + entity_match = entity_id in self._watch_entities if self._watch_entities else False + if not domain_match and not entity_match: + return + + # Apply cooldown + now = time.time() + last = self._last_event_time.get(entity_id, 0) + if (now - last) < self._cooldown_seconds: + return + self._last_event_time[entity_id] = now + + # Build human-readable message + old_state = event_data.get("old_state", {}) + new_state = event_data.get("new_state", {}) + message = self._format_state_change(entity_id, old_state, new_state) + + if not message: + return + + # Build MessageEvent and forward to handler + source = self.build_source( + chat_id="ha_events", + chat_name="Home Assistant Events", + chat_type="channel", + user_id="homeassistant", + user_name="Home Assistant", + ) + + msg_event = MessageEvent( + text=message, + message_type=MessageType.TEXT, + source=source, + message_id=f"ha_{entity_id}_{int(now)}", + timestamp=datetime.now(), + ) + + await self.handle_message(msg_event) + + @staticmethod + def _format_state_change( + entity_id: str, + old_state: Dict[str, Any], + new_state: Dict[str, Any], + ) -> Optional[str]: + """Convert a state_changed event into a human-readable description.""" + if not new_state: + return None + + old_val = old_state.get("state", "unknown") if old_state else "unknown" + new_val = new_state.get("state", "unknown") + + # Skip if state didn't actually change + if old_val == new_val: + return None + + friendly_name = new_state.get("attributes", {}).get("friendly_name", entity_id) + domain = entity_id.split(".")[0] if "." in entity_id else "" + + # Domain-specific formatting + if domain == "climate": + attrs = new_state.get("attributes", {}) + temp = attrs.get("current_temperature", "?") + target = attrs.get("temperature", "?") + return ( + f"[Home Assistant] {friendly_name}: HVAC mode changed from " + f"'{old_val}' to '{new_val}' (current: {temp}, target: {target})" + ) + + if domain == "sensor": + unit = new_state.get("attributes", {}).get("unit_of_measurement", "") + return ( + f"[Home Assistant] {friendly_name}: changed from " + f"{old_val}{unit} to {new_val}{unit}" + ) + + if domain == "binary_sensor": + return ( + f"[Home Assistant] {friendly_name}: " + f"{'triggered' if new_val == 'on' else 'cleared'} " + f"(was {'triggered' if old_val == 'on' else 'cleared'})" + ) + + if domain in ("light", "switch", "fan"): + return ( + f"[Home Assistant] {friendly_name}: turned " + f"{'on' if new_val == 'on' else 'off'}" + ) + + if domain == "alarm_control_panel": + return ( + f"[Home Assistant] {friendly_name}: alarm state changed from " + f"'{old_val}' to '{new_val}'" + ) + + # Generic fallback + return ( + f"[Home Assistant] {friendly_name} ({entity_id}): " + f"changed from '{old_val}' to '{new_val}'" + ) + + # ------------------------------------------------------------------ + # Outbound messaging + # ------------------------------------------------------------------ + + async def send( + self, + chat_id: str, + content: str, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Send a notification via HA REST API (persistent_notification.create). + + Uses the REST API instead of WebSocket to avoid a race condition + with the event listener loop that reads from the same WS connection. + """ + url = f"{self._hass_url}/api/services/persistent_notification/create" + headers = { + "Authorization": f"Bearer {self._hass_token}", + "Content-Type": "application/json", + } + payload = { + "title": "Hermes Agent", + "message": content[:self.MAX_MESSAGE_LENGTH], + } + + try: + async with aiohttp.ClientSession() as session: + async with session.post( + url, + headers=headers, + json=payload, + timeout=aiohttp.ClientTimeout(total=10), + ) as resp: + if resp.status < 300: + return SendResult(success=True, message_id=str(self._next_id())) + else: + body = await resp.text() + return SendResult(success=False, error=f"HTTP {resp.status}: {body}") + + except asyncio.TimeoutError: + return SendResult(success=False, error="Timeout sending notification to HA") + except Exception as e: + return SendResult(success=False, error=str(e)) + + async def send_typing(self, chat_id: str) -> None: + """No typing indicator for Home Assistant.""" + pass + + async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: + """Return basic info about the HA event channel.""" + return { + "name": "Home Assistant Events", + "type": "channel", + "url": self._hass_url, + } diff --git a/gateway/run.py b/gateway/run.py index bcd2457b9..76ed3666c 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -469,7 +469,14 @@ class GatewayRunner: logger.warning("Slack: slack-bolt not installed. Run: pip install 'hermes-agent[slack]'") return None return SlackAdapter(config) - + + elif platform == Platform.HOMEASSISTANT: + from gateway.platforms.homeassistant import HomeAssistantAdapter, check_ha_requirements + if not check_ha_requirements(): + logger.warning("HomeAssistant: aiohttp not installed or HASS_TOKEN not set") + return None + return HomeAssistantAdapter(config) + return None def _is_user_authorized(self, source: SessionSource) -> bool: diff --git a/model_tools.py b/model_tools.py index 036bb34ba..38f01385d 100644 --- a/model_tools.py +++ b/model_tools.py @@ -94,6 +94,7 @@ def _discover_tools(): "tools.process_registry", "tools.send_message_tool", "tools.honcho_tools", + "tools.homeassistant_tool", ] import importlib for mod_name in _modules: diff --git a/pyproject.toml b/pyproject.toml index 152b47305..a002f1bc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ cli = ["simple-term-menu"] tts-premium = ["elevenlabs"] pty = ["ptyprocess>=0.7.0"] honcho = ["honcho-ai>=2.0.1"] +homeassistant = ["aiohttp>=3.9.0"] all = [ "hermes-agent[modal]", "hermes-agent[messaging]", @@ -57,6 +58,7 @@ all = [ "hermes-agent[slack]", "hermes-agent[pty]", "hermes-agent[honcho]", + "hermes-agent[homeassistant]", ] [project.scripts] diff --git a/tests/gateway/test_homeassistant.py b/tests/gateway/test_homeassistant.py new file mode 100644 index 000000000..f8bf7844d --- /dev/null +++ b/tests/gateway/test_homeassistant.py @@ -0,0 +1,604 @@ +"""Tests for the Home Assistant gateway adapter. + +Tests real logic: state change formatting, event filtering pipeline, +cooldown behavior, config integration, and adapter initialization. +""" + +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from gateway.config import ( + GatewayConfig, + Platform, + PlatformConfig, +) +from gateway.platforms.homeassistant import ( + HomeAssistantAdapter, + check_ha_requirements, +) + + +# --------------------------------------------------------------------------- +# check_ha_requirements +# --------------------------------------------------------------------------- + + +class TestCheckRequirements: + def test_returns_false_without_token(self, monkeypatch): + monkeypatch.delenv("HASS_TOKEN", raising=False) + assert check_ha_requirements() is False + + def test_returns_true_with_token(self, monkeypatch): + monkeypatch.setenv("HASS_TOKEN", "test-token") + assert check_ha_requirements() is True + + @patch("gateway.platforms.homeassistant.AIOHTTP_AVAILABLE", False) + def test_returns_false_without_aiohttp(self, monkeypatch): + monkeypatch.setenv("HASS_TOKEN", "test-token") + assert check_ha_requirements() is False + + +# --------------------------------------------------------------------------- +# _format_state_change - pure function, all domain branches +# --------------------------------------------------------------------------- + + +class TestFormatStateChange: + @staticmethod + def fmt(entity_id, old_state, new_state): + return HomeAssistantAdapter._format_state_change(entity_id, old_state, new_state) + + def test_climate_includes_temperatures(self): + msg = self.fmt( + "climate.thermostat", + {"state": "off"}, + {"state": "heat", "attributes": { + "friendly_name": "Main Thermostat", + "current_temperature": 21.5, + "temperature": 23, + }}, + ) + assert "Main Thermostat" in msg + assert "'off'" in msg and "'heat'" in msg + assert "21.5" in msg and "23" in msg + + def test_sensor_includes_unit(self): + msg = self.fmt( + "sensor.temperature", + {"state": "22.5"}, + {"state": "25.1", "attributes": { + "friendly_name": "Living Room Temp", + "unit_of_measurement": "C", + }}, + ) + assert "22.5C" in msg and "25.1C" in msg + assert "Living Room Temp" in msg + + def test_sensor_without_unit(self): + msg = self.fmt( + "sensor.count", + {"state": "5"}, + {"state": "10", "attributes": {"friendly_name": "Counter"}}, + ) + assert "5" in msg and "10" in msg + + def test_binary_sensor_on(self): + msg = self.fmt( + "binary_sensor.motion", + {"state": "off"}, + {"state": "on", "attributes": {"friendly_name": "Hallway Motion"}}, + ) + assert "triggered" in msg + assert "Hallway Motion" in msg + + def test_binary_sensor_off(self): + msg = self.fmt( + "binary_sensor.door", + {"state": "on"}, + {"state": "off", "attributes": {"friendly_name": "Front Door"}}, + ) + assert "cleared" in msg + + def test_light_turned_on(self): + msg = self.fmt( + "light.bedroom", + {"state": "off"}, + {"state": "on", "attributes": {"friendly_name": "Bedroom Light"}}, + ) + assert "turned on" in msg + + def test_switch_turned_off(self): + msg = self.fmt( + "switch.heater", + {"state": "on"}, + {"state": "off", "attributes": {"friendly_name": "Heater"}}, + ) + assert "turned off" in msg + + def test_fan_domain_uses_light_switch_branch(self): + msg = self.fmt( + "fan.ceiling", + {"state": "off"}, + {"state": "on", "attributes": {"friendly_name": "Ceiling Fan"}}, + ) + assert "turned on" in msg + + def test_alarm_panel(self): + msg = self.fmt( + "alarm_control_panel.home", + {"state": "disarmed"}, + {"state": "armed_away", "attributes": {"friendly_name": "Home Alarm"}}, + ) + assert "Home Alarm" in msg + assert "armed_away" in msg and "disarmed" in msg + + def test_generic_domain_includes_entity_id(self): + msg = self.fmt( + "automation.morning", + {"state": "off"}, + {"state": "on", "attributes": {"friendly_name": "Morning Routine"}}, + ) + assert "automation.morning" in msg + assert "Morning Routine" in msg + + def test_same_state_returns_none(self): + assert self.fmt( + "sensor.temp", + {"state": "22"}, + {"state": "22", "attributes": {"friendly_name": "Temp"}}, + ) is None + + def test_empty_new_state_returns_none(self): + assert self.fmt("light.x", {"state": "on"}, {}) is None + + def test_no_old_state_uses_unknown(self): + msg = self.fmt( + "light.new", + None, + {"state": "on", "attributes": {"friendly_name": "New Light"}}, + ) + assert msg is not None + assert "New Light" in msg + + def test_uses_entity_id_when_no_friendly_name(self): + msg = self.fmt( + "sensor.unnamed", + {"state": "1"}, + {"state": "2", "attributes": {}}, + ) + assert "sensor.unnamed" in msg + + +# --------------------------------------------------------------------------- +# Adapter initialization from config +# --------------------------------------------------------------------------- + + +class TestAdapterInit: + def test_url_and_token_from_config_extra(self, monkeypatch): + monkeypatch.delenv("HASS_URL", raising=False) + monkeypatch.delenv("HASS_TOKEN", raising=False) + + config = PlatformConfig( + enabled=True, + token="config-token", + extra={"url": "http://192.168.1.50:8123"}, + ) + adapter = HomeAssistantAdapter(config) + assert adapter._hass_token == "config-token" + assert adapter._hass_url == "http://192.168.1.50:8123" + + def test_url_fallback_to_env(self, monkeypatch): + monkeypatch.setenv("HASS_URL", "http://env-host:8123") + monkeypatch.setenv("HASS_TOKEN", "env-tok") + + config = PlatformConfig(enabled=True, token="env-tok") + adapter = HomeAssistantAdapter(config) + assert adapter._hass_url == "http://env-host:8123" + + def test_trailing_slash_stripped(self): + config = PlatformConfig( + enabled=True, token="t", + extra={"url": "http://ha.local:8123/"}, + ) + adapter = HomeAssistantAdapter(config) + assert adapter._hass_url == "http://ha.local:8123" + + def test_watch_filters_parsed(self): + config = PlatformConfig( + enabled=True, token="t", + extra={ + "watch_domains": ["climate", "binary_sensor"], + "watch_entities": ["sensor.special"], + "ignore_entities": ["sensor.uptime", "sensor.cpu"], + "cooldown_seconds": 120, + }, + ) + adapter = HomeAssistantAdapter(config) + assert adapter._watch_domains == {"climate", "binary_sensor"} + assert adapter._watch_entities == {"sensor.special"} + assert adapter._ignore_entities == {"sensor.uptime", "sensor.cpu"} + assert adapter._cooldown_seconds == 120 + + def test_defaults_when_no_extra(self, monkeypatch): + monkeypatch.setenv("HASS_TOKEN", "tok") + config = PlatformConfig(enabled=True, token="tok") + adapter = HomeAssistantAdapter(config) + assert adapter._watch_domains == set() + assert adapter._watch_entities == set() + assert adapter._ignore_entities == set() + assert adapter._cooldown_seconds == 30 + + +# --------------------------------------------------------------------------- +# Event filtering pipeline (_handle_ha_event) +# +# We mock handle_message (not our code, it's the base class pipeline) to +# capture the MessageEvent that _handle_ha_event produces. +# --------------------------------------------------------------------------- + + +def _make_adapter(**extra) -> HomeAssistantAdapter: + config = PlatformConfig(enabled=True, token="tok", extra=extra) + adapter = HomeAssistantAdapter(config) + adapter.handle_message = AsyncMock() + return adapter + + +def _make_event(entity_id, old_state, new_state, old_attrs=None, new_attrs=None): + return { + "data": { + "entity_id": entity_id, + "old_state": {"state": old_state, "attributes": old_attrs or {}}, + "new_state": {"state": new_state, "attributes": new_attrs or {"friendly_name": entity_id}}, + } + } + + +class TestEventFilteringPipeline: + @pytest.mark.asyncio + async def test_ignored_entity_not_forwarded(self): + adapter = _make_adapter(ignore_entities=["sensor.uptime"]) + await adapter._handle_ha_event(_make_event("sensor.uptime", "100", "101")) + adapter.handle_message.assert_not_called() + + @pytest.mark.asyncio + async def test_unwatched_domain_not_forwarded(self): + adapter = _make_adapter(watch_domains=["climate"]) + await adapter._handle_ha_event(_make_event("light.bedroom", "off", "on")) + adapter.handle_message.assert_not_called() + + @pytest.mark.asyncio + async def test_watched_domain_forwarded(self): + adapter = _make_adapter(watch_domains=["climate"], cooldown_seconds=0) + await adapter._handle_ha_event( + _make_event("climate.thermostat", "off", "heat", + new_attrs={"friendly_name": "Thermostat", "current_temperature": 20, "temperature": 22}) + ) + adapter.handle_message.assert_called_once() + + # Verify the actual MessageEvent text content + msg_event = adapter.handle_message.call_args[0][0] + assert "Thermostat" in msg_event.text + assert "heat" in msg_event.text + assert msg_event.source.platform == Platform.HOMEASSISTANT + assert msg_event.source.chat_id == "ha_events" + + @pytest.mark.asyncio + async def test_watched_entity_forwarded(self): + adapter = _make_adapter(watch_entities=["sensor.important"], cooldown_seconds=0) + await adapter._handle_ha_event( + _make_event("sensor.important", "10", "20", + new_attrs={"friendly_name": "Important Sensor", "unit_of_measurement": "W"}) + ) + adapter.handle_message.assert_called_once() + msg_event = adapter.handle_message.call_args[0][0] + assert "10W" in msg_event.text and "20W" in msg_event.text + + @pytest.mark.asyncio + async def test_no_filters_passes_everything(self): + adapter = _make_adapter(cooldown_seconds=0) + await adapter._handle_ha_event(_make_event("cover.blinds", "closed", "open")) + adapter.handle_message.assert_called_once() + + @pytest.mark.asyncio + async def test_same_state_not_forwarded(self): + adapter = _make_adapter(cooldown_seconds=0) + await adapter._handle_ha_event(_make_event("light.x", "on", "on")) + adapter.handle_message.assert_not_called() + + @pytest.mark.asyncio + async def test_empty_entity_id_skipped(self): + adapter = _make_adapter() + await adapter._handle_ha_event({"data": {"entity_id": ""}}) + adapter.handle_message.assert_not_called() + + @pytest.mark.asyncio + async def test_message_event_has_correct_source(self): + adapter = _make_adapter(cooldown_seconds=0) + await adapter._handle_ha_event( + _make_event("light.test", "off", "on", + new_attrs={"friendly_name": "Test Light"}) + ) + msg_event = adapter.handle_message.call_args[0][0] + assert msg_event.source.user_name == "Home Assistant" + assert msg_event.source.chat_type == "channel" + assert msg_event.message_id.startswith("ha_light.test_") + + +# --------------------------------------------------------------------------- +# Cooldown behavior +# --------------------------------------------------------------------------- + + +class TestCooldown: + @pytest.mark.asyncio + async def test_cooldown_blocks_rapid_events(self): + adapter = _make_adapter(cooldown_seconds=60) + + event = _make_event("sensor.temp", "20", "21", + new_attrs={"friendly_name": "Temp"}) + await adapter._handle_ha_event(event) + assert adapter.handle_message.call_count == 1 + + # Second event immediately after should be blocked + event2 = _make_event("sensor.temp", "21", "22", + new_attrs={"friendly_name": "Temp"}) + await adapter._handle_ha_event(event2) + assert adapter.handle_message.call_count == 1 # Still 1 + + @pytest.mark.asyncio + async def test_cooldown_expires(self): + adapter = _make_adapter(cooldown_seconds=1) + + event = _make_event("sensor.temp", "20", "21", + new_attrs={"friendly_name": "Temp"}) + await adapter._handle_ha_event(event) + assert adapter.handle_message.call_count == 1 + + # Simulate time passing beyond cooldown + adapter._last_event_time["sensor.temp"] = time.time() - 2 + + event2 = _make_event("sensor.temp", "21", "22", + new_attrs={"friendly_name": "Temp"}) + await adapter._handle_ha_event(event2) + assert adapter.handle_message.call_count == 2 + + @pytest.mark.asyncio + async def test_different_entities_independent_cooldowns(self): + adapter = _make_adapter(cooldown_seconds=60) + + await adapter._handle_ha_event( + _make_event("sensor.a", "1", "2", new_attrs={"friendly_name": "A"}) + ) + await adapter._handle_ha_event( + _make_event("sensor.b", "3", "4", new_attrs={"friendly_name": "B"}) + ) + # Both should pass - different entities + assert adapter.handle_message.call_count == 2 + + # Same entity again - should be blocked + await adapter._handle_ha_event( + _make_event("sensor.a", "2", "3", new_attrs={"friendly_name": "A"}) + ) + assert adapter.handle_message.call_count == 2 # Still 2 + + @pytest.mark.asyncio + async def test_zero_cooldown_passes_all(self): + adapter = _make_adapter(cooldown_seconds=0) + + for i in range(5): + await adapter._handle_ha_event( + _make_event("sensor.temp", str(i), str(i + 1), + new_attrs={"friendly_name": "Temp"}) + ) + assert adapter.handle_message.call_count == 5 + + +# --------------------------------------------------------------------------- +# Config integration (env overrides, round-trip) +# --------------------------------------------------------------------------- + + +class TestConfigIntegration: + def test_env_override_creates_ha_platform(self, monkeypatch): + monkeypatch.setenv("HASS_TOKEN", "env-token") + monkeypatch.setenv("HASS_URL", "http://10.0.0.5:8123") + # Clear other platform tokens + for v in ["TELEGRAM_BOT_TOKEN", "DISCORD_BOT_TOKEN", "SLACK_BOT_TOKEN"]: + monkeypatch.delenv(v, raising=False) + + from gateway.config import load_gateway_config + config = load_gateway_config() + + assert Platform.HOMEASSISTANT in config.platforms + ha = config.platforms[Platform.HOMEASSISTANT] + assert ha.enabled is True + assert ha.token == "env-token" + assert ha.extra["url"] == "http://10.0.0.5:8123" + + def test_no_env_no_platform(self, monkeypatch): + for v in ["HASS_TOKEN", "HASS_URL", "TELEGRAM_BOT_TOKEN", + "DISCORD_BOT_TOKEN", "SLACK_BOT_TOKEN"]: + monkeypatch.delenv(v, raising=False) + + from gateway.config import load_gateway_config + config = load_gateway_config() + assert Platform.HOMEASSISTANT not in config.platforms + + def test_config_roundtrip_preserves_extra(self): + config = GatewayConfig( + platforms={ + Platform.HOMEASSISTANT: PlatformConfig( + enabled=True, + token="tok", + extra={ + "url": "http://ha:8123", + "watch_domains": ["climate"], + "cooldown_seconds": 45, + }, + ), + }, + ) + d = config.to_dict() + restored = GatewayConfig.from_dict(d) + + ha = restored.platforms[Platform.HOMEASSISTANT] + assert ha.enabled is True + assert ha.token == "tok" + assert ha.extra["watch_domains"] == ["climate"] + assert ha.extra["cooldown_seconds"] == 45 + + def test_connected_platforms_includes_ha(self): + config = GatewayConfig( + platforms={ + Platform.HOMEASSISTANT: PlatformConfig(enabled=True, token="tok"), + Platform.TELEGRAM: PlatformConfig(enabled=False, token="t"), + }, + ) + connected = config.get_connected_platforms() + assert Platform.HOMEASSISTANT in connected + assert Platform.TELEGRAM not in connected + + +# --------------------------------------------------------------------------- +# send() via REST API +# --------------------------------------------------------------------------- + + +class TestSendViaRestApi: + """send() uses REST API (not WebSocket) to avoid race conditions.""" + + @staticmethod + def _mock_aiohttp_session(response_status=200, response_text="OK"): + """Build a mock aiohttp session + response for async-with patterns. + + aiohttp.ClientSession() is a sync constructor whose return value + is used as ``async with session:``. ``session.post(...)`` returns a + context-manager (not a coroutine), so both layers use MagicMock for + the call and AsyncMock only for ``__aenter__`` / ``__aexit__``. + """ + mock_response = MagicMock() + mock_response.status = response_status + mock_response.text = AsyncMock(return_value=response_text) + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=False) + + mock_session = MagicMock() + mock_session.post = MagicMock(return_value=mock_response) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + return mock_session + + @pytest.mark.asyncio + async def test_send_success(self): + adapter = _make_adapter() + mock_session = self._mock_aiohttp_session(200) + + with patch("gateway.platforms.homeassistant.aiohttp") as mock_aiohttp: + mock_aiohttp.ClientSession = MagicMock(return_value=mock_session) + mock_aiohttp.ClientTimeout = lambda total: total + + result = await adapter.send("ha_events", "Test notification") + + assert result.success is True + # Verify the REST API was called with correct payload + call_args = mock_session.post.call_args + assert "/api/services/persistent_notification/create" in call_args[0][0] + assert call_args[1]["json"]["title"] == "Hermes Agent" + assert call_args[1]["json"]["message"] == "Test notification" + assert "Bearer tok" in call_args[1]["headers"]["Authorization"] + + @pytest.mark.asyncio + async def test_send_http_error(self): + adapter = _make_adapter() + mock_session = self._mock_aiohttp_session(401, "Unauthorized") + + with patch("gateway.platforms.homeassistant.aiohttp") as mock_aiohttp: + mock_aiohttp.ClientSession = MagicMock(return_value=mock_session) + mock_aiohttp.ClientTimeout = lambda total: total + + result = await adapter.send("ha_events", "Test") + + assert result.success is False + assert "401" in result.error + + @pytest.mark.asyncio + async def test_send_truncates_long_message(self): + adapter = _make_adapter() + mock_session = self._mock_aiohttp_session(200) + long_message = "x" * 10000 + + with patch("gateway.platforms.homeassistant.aiohttp") as mock_aiohttp: + mock_aiohttp.ClientSession = MagicMock(return_value=mock_session) + mock_aiohttp.ClientTimeout = lambda total: total + + await adapter.send("ha_events", long_message) + + sent_message = mock_session.post.call_args[1]["json"]["message"] + assert len(sent_message) == 4096 + + @pytest.mark.asyncio + async def test_send_does_not_use_websocket(self): + """send() must use REST API, not the WS connection (race condition fix).""" + adapter = _make_adapter() + adapter._ws = AsyncMock() # Simulate an active WS + mock_session = self._mock_aiohttp_session(200) + + with patch("gateway.platforms.homeassistant.aiohttp") as mock_aiohttp: + mock_aiohttp.ClientSession = MagicMock(return_value=mock_session) + mock_aiohttp.ClientTimeout = lambda total: total + + await adapter.send("ha_events", "Test") + + # WS should NOT have been used for sending + adapter._ws.send_json.assert_not_called() + adapter._ws.receive_json.assert_not_called() + + +# --------------------------------------------------------------------------- +# Toolset integration +# --------------------------------------------------------------------------- + + +class TestToolsetIntegration: + def test_homeassistant_toolset_resolves(self): + from toolsets import resolve_toolset + + tools = resolve_toolset("homeassistant") + assert set(tools) == {"ha_list_entities", "ha_get_state", "ha_call_service"} + + def test_gateway_toolset_includes_ha_tools(self): + from toolsets import resolve_toolset + + gateway_tools = resolve_toolset("hermes-gateway") + for tool in ("ha_list_entities", "ha_get_state", "ha_call_service"): + assert tool in gateway_tools + + def test_hermes_core_tools_includes_ha(self): + from toolsets import _HERMES_CORE_TOOLS + + for tool in ("ha_list_entities", "ha_get_state", "ha_call_service"): + assert tool in _HERMES_CORE_TOOLS + + +# --------------------------------------------------------------------------- +# WebSocket URL construction +# --------------------------------------------------------------------------- + + +class TestWsUrlConstruction: + def test_http_to_ws(self): + config = PlatformConfig(enabled=True, token="t", extra={"url": "http://ha:8123"}) + adapter = HomeAssistantAdapter(config) + ws_url = adapter._hass_url.replace("http://", "ws://").replace("https://", "wss://") + assert ws_url == "ws://ha:8123" + + def test_https_to_wss(self): + config = PlatformConfig(enabled=True, token="t", extra={"url": "https://ha.example.com"}) + adapter = HomeAssistantAdapter(config) + ws_url = adapter._hass_url.replace("http://", "ws://").replace("https://", "wss://") + assert ws_url == "wss://ha.example.com" diff --git a/tests/tools/test_homeassistant_tool.py b/tests/tools/test_homeassistant_tool.py new file mode 100644 index 000000000..6235474ef --- /dev/null +++ b/tests/tools/test_homeassistant_tool.py @@ -0,0 +1,281 @@ +"""Tests for the Home Assistant tool module. + +Tests real logic: entity filtering, payload building, response parsing, +handler validation, and availability gating. +""" + +import json + +import pytest + +from tools.homeassistant_tool import ( + _check_ha_available, + _filter_and_summarize, + _build_service_payload, + _parse_service_response, + _get_headers, + _handle_get_state, + _handle_call_service, +) + + +# --------------------------------------------------------------------------- +# Sample HA state data (matches real HA /api/states response shape) +# --------------------------------------------------------------------------- + +SAMPLE_STATES = [ + {"entity_id": "light.bedroom", "state": "on", "attributes": {"friendly_name": "Bedroom Light", "brightness": 200}}, + {"entity_id": "light.kitchen", "state": "off", "attributes": {"friendly_name": "Kitchen Light"}}, + {"entity_id": "switch.fan", "state": "on", "attributes": {"friendly_name": "Living Room Fan"}}, + {"entity_id": "sensor.temperature", "state": "22.5", "attributes": {"friendly_name": "Kitchen Temperature", "unit_of_measurement": "C"}}, + {"entity_id": "climate.thermostat", "state": "heat", "attributes": {"friendly_name": "Main Thermostat", "current_temperature": 21}}, + {"entity_id": "binary_sensor.motion", "state": "off", "attributes": {"friendly_name": "Hallway Motion"}}, + {"entity_id": "sensor.humidity", "state": "55", "attributes": {"friendly_name": "Bedroom Humidity", "area": "bedroom"}}, +] + + +# --------------------------------------------------------------------------- +# Entity filtering and summarization +# --------------------------------------------------------------------------- + + +class TestFilterAndSummarize: + def test_no_filters_returns_all(self): + result = _filter_and_summarize(SAMPLE_STATES) + assert result["count"] == 7 + ids = {e["entity_id"] for e in result["entities"]} + assert "light.bedroom" in ids + assert "climate.thermostat" in ids + + def test_domain_filter_lights(self): + result = _filter_and_summarize(SAMPLE_STATES, domain="light") + assert result["count"] == 2 + for e in result["entities"]: + assert e["entity_id"].startswith("light.") + + def test_domain_filter_sensor(self): + result = _filter_and_summarize(SAMPLE_STATES, domain="sensor") + assert result["count"] == 2 + ids = {e["entity_id"] for e in result["entities"]} + assert ids == {"sensor.temperature", "sensor.humidity"} + + def test_domain_filter_no_matches(self): + result = _filter_and_summarize(SAMPLE_STATES, domain="media_player") + assert result["count"] == 0 + assert result["entities"] == [] + + def test_area_filter_by_friendly_name(self): + result = _filter_and_summarize(SAMPLE_STATES, area="kitchen") + assert result["count"] == 2 + ids = {e["entity_id"] for e in result["entities"]} + assert "light.kitchen" in ids + assert "sensor.temperature" in ids + + def test_area_filter_by_area_attribute(self): + result = _filter_and_summarize(SAMPLE_STATES, area="bedroom") + ids = {e["entity_id"] for e in result["entities"]} + # "Bedroom Light" matches via friendly_name, "Bedroom Humidity" matches via area attr + assert "light.bedroom" in ids + assert "sensor.humidity" in ids + + def test_area_filter_case_insensitive(self): + result = _filter_and_summarize(SAMPLE_STATES, area="KITCHEN") + assert result["count"] == 2 + + def test_combined_domain_and_area(self): + result = _filter_and_summarize(SAMPLE_STATES, domain="sensor", area="kitchen") + assert result["count"] == 1 + assert result["entities"][0]["entity_id"] == "sensor.temperature" + + def test_summary_includes_friendly_name(self): + result = _filter_and_summarize(SAMPLE_STATES, domain="climate") + assert result["entities"][0]["friendly_name"] == "Main Thermostat" + assert result["entities"][0]["state"] == "heat" + + def test_empty_states_list(self): + result = _filter_and_summarize([]) + assert result["count"] == 0 + + def test_missing_attributes_handled(self): + states = [{"entity_id": "light.x", "state": "on"}] + result = _filter_and_summarize(states) + assert result["count"] == 1 + assert result["entities"][0]["friendly_name"] == "" + + +# --------------------------------------------------------------------------- +# Service payload building +# --------------------------------------------------------------------------- + + +class TestBuildServicePayload: + def test_entity_id_only(self): + payload = _build_service_payload(entity_id="light.bedroom") + assert payload == {"entity_id": "light.bedroom"} + + def test_data_only(self): + payload = _build_service_payload(data={"brightness": 255}) + assert payload == {"brightness": 255} + + def test_entity_id_and_data(self): + payload = _build_service_payload( + entity_id="light.bedroom", + data={"brightness": 200, "color_name": "blue"}, + ) + assert payload["entity_id"] == "light.bedroom" + assert payload["brightness"] == 200 + assert payload["color_name"] == "blue" + + def test_no_args_returns_empty(self): + payload = _build_service_payload() + assert payload == {} + + def test_data_does_not_overwrite_entity_id(self): + payload = _build_service_payload( + entity_id="light.a", + data={"entity_id": "light.b"}, + ) + # data.update overwrites entity_id set earlier + assert payload["entity_id"] == "light.b" + + +# --------------------------------------------------------------------------- +# Service response parsing +# --------------------------------------------------------------------------- + + +class TestParseServiceResponse: + def test_list_response_extracts_entities(self): + ha_response = [ + {"entity_id": "light.bedroom", "state": "on", "attributes": {}}, + {"entity_id": "light.kitchen", "state": "on", "attributes": {}}, + ] + result = _parse_service_response("light", "turn_on", ha_response) + assert result["success"] is True + assert result["service"] == "light.turn_on" + assert len(result["affected_entities"]) == 2 + assert result["affected_entities"][0]["entity_id"] == "light.bedroom" + + def test_empty_list_response(self): + result = _parse_service_response("scene", "turn_on", []) + assert result["success"] is True + assert result["affected_entities"] == [] + + def test_non_list_response(self): + # Some HA services return a dict instead of a list + result = _parse_service_response("script", "run", {"result": "ok"}) + assert result["success"] is True + assert result["affected_entities"] == [] + + def test_none_response(self): + result = _parse_service_response("automation", "trigger", None) + assert result["success"] is True + assert result["affected_entities"] == [] + + def test_service_name_format(self): + result = _parse_service_response("climate", "set_temperature", []) + assert result["service"] == "climate.set_temperature" + + +# --------------------------------------------------------------------------- +# Handler validation (no mocks - these paths don't reach the network) +# --------------------------------------------------------------------------- + + +class TestHandlerValidation: + def test_get_state_missing_entity_id(self): + result = json.loads(_handle_get_state({})) + assert "error" in result + assert "entity_id" in result["error"] + + def test_get_state_empty_entity_id(self): + result = json.loads(_handle_get_state({"entity_id": ""})) + assert "error" in result + + def test_call_service_missing_domain(self): + result = json.loads(_handle_call_service({"service": "turn_on"})) + assert "error" in result + assert "domain" in result["error"] + + def test_call_service_missing_service(self): + result = json.loads(_handle_call_service({"domain": "light"})) + assert "error" in result + assert "service" in result["error"] + + def test_call_service_missing_both(self): + result = json.loads(_handle_call_service({})) + assert "error" in result + + def test_call_service_empty_strings(self): + result = json.loads(_handle_call_service({"domain": "", "service": ""})) + assert "error" in result + + +# --------------------------------------------------------------------------- +# Availability check +# --------------------------------------------------------------------------- + + +class TestCheckAvailable: + def test_unavailable_without_token(self, monkeypatch): + monkeypatch.delenv("HASS_TOKEN", raising=False) + assert _check_ha_available() is False + + def test_available_with_token(self, monkeypatch): + monkeypatch.setenv("HASS_TOKEN", "eyJ0eXAiOiJKV1Q") + assert _check_ha_available() is True + + def test_empty_token_is_unavailable(self, monkeypatch): + monkeypatch.setenv("HASS_TOKEN", "") + assert _check_ha_available() is False + + +# --------------------------------------------------------------------------- +# Auth headers +# --------------------------------------------------------------------------- + + +class TestGetHeaders: + def test_bearer_token_format(self, monkeypatch): + monkeypatch.setattr("tools.homeassistant_tool._HASS_TOKEN", "my-secret-token") + headers = _get_headers() + assert headers["Authorization"] == "Bearer my-secret-token" + assert headers["Content-Type"] == "application/json" + + +# --------------------------------------------------------------------------- +# Registry integration +# --------------------------------------------------------------------------- + + +class TestRegistration: + def test_tools_registered_in_registry(self): + from tools.registry import registry + + names = registry.get_all_tool_names() + assert "ha_list_entities" in names + assert "ha_get_state" in names + assert "ha_call_service" in names + + def test_tools_in_homeassistant_toolset(self): + from tools.registry import registry + + toolset_map = registry.get_tool_to_toolset_map() + for tool in ("ha_list_entities", "ha_get_state", "ha_call_service"): + assert toolset_map[tool] == "homeassistant" + + def test_check_fn_gates_availability(self, monkeypatch): + """Registry should exclude HA tools when HASS_TOKEN is not set.""" + from tools.registry import registry + + monkeypatch.delenv("HASS_TOKEN", raising=False) + defs = registry.get_definitions({"ha_list_entities", "ha_get_state", "ha_call_service"}) + assert len(defs) == 0 + + def test_check_fn_includes_when_token_set(self, monkeypatch): + """Registry should include HA tools when HASS_TOKEN is set.""" + from tools.registry import registry + + monkeypatch.setenv("HASS_TOKEN", "test-token") + defs = registry.get_definitions({"ha_list_entities", "ha_get_state", "ha_call_service"}) + assert len(defs) == 3 diff --git a/tools/homeassistant_tool.py b/tools/homeassistant_tool.py new file mode 100644 index 000000000..4a01382f3 --- /dev/null +++ b/tools/homeassistant_tool.py @@ -0,0 +1,364 @@ +"""Home Assistant tool for controlling smart home devices via REST API. + +Registers three LLM-callable tools: +- ``ha_list_entities`` -- list/filter entities by domain or area +- ``ha_get_state`` -- get detailed state of a single entity +- ``ha_call_service`` -- call a HA service (turn_on, turn_off, set_temperature, etc.) + +Authentication uses a Long-Lived Access Token via ``HASS_TOKEN`` env var. +The HA instance URL is read from ``HASS_URL`` (default: http://homeassistant.local:8123). +""" + +import asyncio +import json +import logging +import os +from typing import Any, Dict, Optional + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +_HASS_URL: str = os.getenv("HASS_URL", "http://homeassistant.local:8123").rstrip("/") +_HASS_TOKEN: str = os.getenv("HASS_TOKEN", "") + + +def _get_headers() -> Dict[str, str]: + """Return authorization headers for HA REST API.""" + return { + "Authorization": f"Bearer {_HASS_TOKEN}", + "Content-Type": "application/json", + } + + +# --------------------------------------------------------------------------- +# Async helpers (called from sync handlers via run_until_complete) +# --------------------------------------------------------------------------- + +def _filter_and_summarize( + states: list, + domain: Optional[str] = None, + area: Optional[str] = None, +) -> Dict[str, Any]: + """Filter raw HA states by domain/area and return a compact summary.""" + if domain: + states = [s for s in states if s.get("entity_id", "").startswith(f"{domain}.")] + + if area: + area_lower = area.lower() + states = [ + s for s in states + if area_lower in (s.get("attributes", {}).get("friendly_name", "") or "").lower() + or area_lower in (s.get("attributes", {}).get("area", "") or "").lower() + ] + + entities = [] + for s in states: + entities.append({ + "entity_id": s["entity_id"], + "state": s["state"], + "friendly_name": s.get("attributes", {}).get("friendly_name", ""), + }) + + return {"count": len(entities), "entities": entities} + + +async def _async_list_entities( + domain: Optional[str] = None, + area: Optional[str] = None, +) -> Dict[str, Any]: + """Fetch entity states from HA and optionally filter by domain/area.""" + import aiohttp + + url = f"{_HASS_URL}/api/states" + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=_get_headers(), timeout=aiohttp.ClientTimeout(total=15)) as resp: + resp.raise_for_status() + states = await resp.json() + + return _filter_and_summarize(states, domain, area) + + +async def _async_get_state(entity_id: str) -> Dict[str, Any]: + """Fetch detailed state of a single entity.""" + import aiohttp + + url = f"{_HASS_URL}/api/states/{entity_id}" + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=_get_headers(), timeout=aiohttp.ClientTimeout(total=10)) as resp: + resp.raise_for_status() + data = await resp.json() + + return { + "entity_id": data["entity_id"], + "state": data["state"], + "attributes": data.get("attributes", {}), + "last_changed": data.get("last_changed"), + "last_updated": data.get("last_updated"), + } + + +def _build_service_payload( + entity_id: Optional[str] = None, + data: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """Build the JSON payload for a HA service call.""" + payload: Dict[str, Any] = {} + if entity_id: + payload["entity_id"] = entity_id + if data: + payload.update(data) + return payload + + +def _parse_service_response( + domain: str, + service: str, + result: Any, +) -> Dict[str, Any]: + """Parse HA service call response into a structured result.""" + affected = [] + if isinstance(result, list): + for s in result: + affected.append({ + "entity_id": s.get("entity_id", ""), + "state": s.get("state", ""), + }) + + return { + "success": True, + "service": f"{domain}.{service}", + "affected_entities": affected, + } + + +async def _async_call_service( + domain: str, + service: str, + entity_id: Optional[str] = None, + data: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """Call a Home Assistant service.""" + import aiohttp + + url = f"{_HASS_URL}/api/services/{domain}/{service}" + payload = _build_service_payload(entity_id, data) + + async with aiohttp.ClientSession() as session: + async with session.post( + url, + headers=_get_headers(), + json=payload, + timeout=aiohttp.ClientTimeout(total=15), + ) as resp: + resp.raise_for_status() + result = await resp.json() + + return _parse_service_response(domain, service, result) + + +# --------------------------------------------------------------------------- +# Sync wrappers (handler signature: (args, **kw) -> str) +# --------------------------------------------------------------------------- + +def _run_async(coro): + """Run an async coroutine from a sync handler.""" + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop and loop.is_running(): + # Already inside an event loop -- create a new thread + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + future = pool.submit(asyncio.run, coro) + return future.result(timeout=30) + else: + return asyncio.run(coro) + + +def _handle_list_entities(args: dict, **kw) -> str: + """Handler for ha_list_entities tool.""" + domain = args.get("domain") + area = args.get("area") + try: + result = _run_async(_async_list_entities(domain=domain, area=area)) + return json.dumps({"result": result}) + except Exception as e: + logger.error("ha_list_entities error: %s", e) + return json.dumps({"error": f"Failed to list entities: {e}"}) + + +def _handle_get_state(args: dict, **kw) -> str: + """Handler for ha_get_state tool.""" + entity_id = args.get("entity_id", "") + if not entity_id: + return json.dumps({"error": "Missing required parameter: entity_id"}) + try: + result = _run_async(_async_get_state(entity_id)) + return json.dumps({"result": result}) + except Exception as e: + logger.error("ha_get_state error: %s", e) + return json.dumps({"error": f"Failed to get state for {entity_id}: {e}"}) + + +def _handle_call_service(args: dict, **kw) -> str: + """Handler for ha_call_service tool.""" + domain = args.get("domain", "") + service = args.get("service", "") + if not domain or not service: + return json.dumps({"error": "Missing required parameters: domain and service"}) + + entity_id = args.get("entity_id") + data = args.get("data") + try: + result = _run_async(_async_call_service(domain, service, entity_id, data)) + return json.dumps({"result": result}) + except Exception as e: + logger.error("ha_call_service error: %s", e) + return json.dumps({"error": f"Failed to call {domain}.{service}: {e}"}) + + +# --------------------------------------------------------------------------- +# Availability check +# --------------------------------------------------------------------------- + +def _check_ha_available() -> bool: + """Tool is only available when HASS_TOKEN is set.""" + return bool(os.getenv("HASS_TOKEN")) + + +# --------------------------------------------------------------------------- +# Tool schemas +# --------------------------------------------------------------------------- + +HA_LIST_ENTITIES_SCHEMA = { + "name": "ha_list_entities", + "description": ( + "List Home Assistant entities. Optionally filter by domain " + "(light, switch, climate, sensor, binary_sensor, cover, fan, etc.) " + "or by area name (living room, kitchen, bedroom, etc.)." + ), + "parameters": { + "type": "object", + "properties": { + "domain": { + "type": "string", + "description": ( + "Entity domain to filter by (e.g. 'light', 'switch', 'climate', " + "'sensor', 'binary_sensor', 'cover', 'fan', 'media_player'). " + "Omit to list all entities." + ), + }, + "area": { + "type": "string", + "description": ( + "Area/room name to filter by (e.g. 'living room', 'kitchen'). " + "Matches against entity friendly names. Omit to list all." + ), + }, + }, + "required": [], + }, +} + +HA_GET_STATE_SCHEMA = { + "name": "ha_get_state", + "description": ( + "Get the detailed state of a single Home Assistant entity, including all " + "attributes (brightness, color, temperature setpoint, sensor readings, etc.)." + ), + "parameters": { + "type": "object", + "properties": { + "entity_id": { + "type": "string", + "description": ( + "The entity ID to query (e.g. 'light.living_room', " + "'climate.thermostat', 'sensor.temperature')." + ), + }, + }, + "required": ["entity_id"], + }, +} + +HA_CALL_SERVICE_SCHEMA = { + "name": "ha_call_service", + "description": ( + "Call a Home Assistant service to control a device. Common examples: " + "turn_on/turn_off lights and switches, set_temperature for climate, " + "open_cover/close_cover for blinds, set_volume_level for media players." + ), + "parameters": { + "type": "object", + "properties": { + "domain": { + "type": "string", + "description": ( + "Service domain (e.g. 'light', 'switch', 'climate', " + "'cover', 'media_player', 'fan', 'scene', 'script')." + ), + }, + "service": { + "type": "string", + "description": ( + "Service name (e.g. 'turn_on', 'turn_off', 'toggle', " + "'set_temperature', 'set_hvac_mode', 'open_cover', " + "'close_cover', 'set_volume_level')." + ), + }, + "entity_id": { + "type": "string", + "description": ( + "Target entity ID (e.g. 'light.living_room'). " + "Some services (like scene.turn_on) may not need this." + ), + }, + "data": { + "type": "object", + "description": ( + "Additional service data. Examples: " + '{"brightness": 255, "color_name": "blue"} for lights, ' + '{"temperature": 22, "hvac_mode": "heat"} for climate, ' + '{"volume_level": 0.5} for media players.' + ), + }, + }, + "required": ["domain", "service"], + }, +} + + +# --------------------------------------------------------------------------- +# Registration +# --------------------------------------------------------------------------- + +from tools.registry import registry + +registry.register( + name="ha_list_entities", + toolset="homeassistant", + schema=HA_LIST_ENTITIES_SCHEMA, + handler=_handle_list_entities, + check_fn=_check_ha_available, +) + +registry.register( + name="ha_get_state", + toolset="homeassistant", + schema=HA_GET_STATE_SCHEMA, + handler=_handle_get_state, + check_fn=_check_ha_available, +) + +registry.register( + name="ha_call_service", + toolset="homeassistant", + schema=HA_CALL_SERVICE_SCHEMA, + handler=_handle_call_service, + check_fn=_check_ha_available, +) diff --git a/toolsets.py b/toolsets.py index 6090068a5..44b814498 100644 --- a/toolsets.py +++ b/toolsets.py @@ -62,6 +62,8 @@ _HERMES_CORE_TOOLS = [ "send_message", # Honcho user context (gated on honcho being active via check_fn) "query_user_context", + # Home Assistant smart home control (gated on HASS_TOKEN via check_fn) + "ha_list_entities", "ha_get_state", "ha_call_service", ] @@ -193,8 +195,14 @@ TOOLSETS = { "tools": ["query_user_context"], "includes": [] }, - - + + "homeassistant": { + "description": "Home Assistant smart home control and monitoring", + "tools": ["ha_list_entities", "ha_get_state", "ha_call_service"], + "includes": [] + }, + + # Scenario-specific toolsets "debugging": { @@ -247,10 +255,16 @@ TOOLSETS = { "includes": [] }, + "hermes-homeassistant": { + "description": "Home Assistant bot toolset - smart home event monitoring and control", + "tools": _HERMES_CORE_TOOLS, + "includes": [] + }, + "hermes-gateway": { "description": "Gateway toolset - union of all messaging platform tools", "tools": [], - "includes": ["hermes-telegram", "hermes-discord", "hermes-whatsapp", "hermes-slack"] + "includes": ["hermes-telegram", "hermes-discord", "hermes-whatsapp", "hermes-slack", "hermes-homeassistant"] } } diff --git a/uv.lock b/uv.lock index 548633896..5e3bd5f77 100644 --- a/uv.lock +++ b/uv.lock @@ -1034,6 +1034,9 @@ dev = [ { name = "pytest" }, { name = "pytest-asyncio" }, ] +homeassistant = [ + { name = "aiohttp" }, +] honcho = [ { name = "honcho-ai" }, ] @@ -1060,6 +1063,7 @@ tts-premium = [ [package.metadata] requires-dist = [ + { name = "aiohttp", marker = "extra == 'homeassistant'", specifier = ">=3.9.0" }, { name = "aiohttp", marker = "extra == 'messaging'", specifier = ">=3.9.0" }, { name = "croniter", marker = "extra == 'cron'" }, { name = "discord-py", marker = "extra == 'messaging'", specifier = ">=2.0" }, @@ -1071,6 +1075,7 @@ requires-dist = [ { name = "hermes-agent", extras = ["cli"], marker = "extra == 'all'" }, { name = "hermes-agent", extras = ["cron"], marker = "extra == 'all'" }, { name = "hermes-agent", extras = ["dev"], marker = "extra == 'all'" }, + { name = "hermes-agent", extras = ["homeassistant"], marker = "extra == 'all'" }, { name = "hermes-agent", extras = ["honcho"], marker = "extra == 'all'" }, { name = "hermes-agent", extras = ["messaging"], marker = "extra == 'all'" }, { name = "hermes-agent", extras = ["modal"], marker = "extra == 'all'" }, @@ -1103,7 +1108,7 @@ requires-dist = [ { name = "tenacity" }, { name = "typer" }, ] -provides-extras = ["modal", "dev", "messaging", "cron", "slack", "cli", "tts-premium", "pty", "honcho", "all"] +provides-extras = ["modal", "dev", "messaging", "cron", "slack", "cli", "tts-premium", "pty", "honcho", "homeassistant", "all"] [[package]] name = "hf-xet"