feat: add Home Assistant integration (REST tools + WebSocket gateway)
- Add ha_list_entities, ha_get_state, ha_call_service tools via REST API - Add WebSocket gateway adapter for real-time state_changed event monitoring - Support domain/entity filtering, cooldown, and auto-reconnect with backoff - Use REST API for outbound notifications to avoid WS race condition - Gate tool availability on HASS_TOKEN env var - Add 82 unit tests covering real logic (filtering, payload building, event pipeline)
This commit is contained in:
@@ -26,6 +26,7 @@ class Platform(Enum):
|
||||
DISCORD = "discord"
|
||||
WHATSAPP = "whatsapp"
|
||||
SLACK = "slack"
|
||||
HOMEASSISTANT = "homeassistant"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -378,6 +379,17 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
name=os.getenv("SLACK_HOME_CHANNEL_NAME", ""),
|
||||
)
|
||||
|
||||
# Home Assistant
|
||||
hass_token = os.getenv("HASS_TOKEN")
|
||||
if hass_token:
|
||||
if Platform.HOMEASSISTANT not in config.platforms:
|
||||
config.platforms[Platform.HOMEASSISTANT] = PlatformConfig()
|
||||
config.platforms[Platform.HOMEASSISTANT].enabled = True
|
||||
config.platforms[Platform.HOMEASSISTANT].token = hass_token
|
||||
hass_url = os.getenv("HASS_URL")
|
||||
if hass_url:
|
||||
config.platforms[Platform.HOMEASSISTANT].extra["url"] = hass_url
|
||||
|
||||
# Session settings
|
||||
idle_minutes = os.getenv("SESSION_IDLE_MINUTES")
|
||||
if idle_minutes:
|
||||
|
||||
413
gateway/platforms/homeassistant.py
Normal file
413
gateway/platforms/homeassistant.py
Normal file
@@ -0,0 +1,413 @@
|
||||
"""
|
||||
Home Assistant platform adapter.
|
||||
|
||||
Connects to the HA WebSocket API for real-time event monitoring.
|
||||
State-change events are converted to MessageEvent objects and forwarded
|
||||
to the agent for processing. Outbound messages are delivered as HA
|
||||
persistent notifications.
|
||||
|
||||
Requires:
|
||||
- aiohttp (already in messaging extras)
|
||||
- HASS_TOKEN env var (Long-Lived Access Token)
|
||||
- HASS_URL env var (default: http://homeassistant.local:8123)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
AIOHTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIOHTTP_AVAILABLE = False
|
||||
aiohttp = None # type: ignore[assignment]
|
||||
|
||||
import sys
|
||||
from pathlib import Path as _Path
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_ha_requirements() -> bool:
|
||||
"""Check if Home Assistant dependencies are available and configured."""
|
||||
if not AIOHTTP_AVAILABLE:
|
||||
return False
|
||||
if not os.getenv("HASS_TOKEN"):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class HomeAssistantAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
Home Assistant WebSocket adapter.
|
||||
|
||||
Subscribes to ``state_changed`` events and forwards them as
|
||||
MessageEvent objects. Supports domain/entity filtering and
|
||||
per-entity cooldowns to avoid event floods.
|
||||
"""
|
||||
|
||||
MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
# Reconnection backoff schedule (seconds)
|
||||
_BACKOFF_STEPS = [5, 10, 30, 60]
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.HOMEASSISTANT)
|
||||
|
||||
# Connection state
|
||||
self._session: Optional["aiohttp.ClientSession"] = None
|
||||
self._ws: Optional["aiohttp.ClientWebSocketResponse"] = None
|
||||
self._listen_task: Optional[asyncio.Task] = None
|
||||
self._msg_id: int = 0
|
||||
|
||||
# Configuration from extra
|
||||
extra = config.extra or {}
|
||||
token = config.token or os.getenv("HASS_TOKEN", "")
|
||||
url = extra.get("url") or os.getenv("HASS_URL", "http://homeassistant.local:8123")
|
||||
self._hass_url: str = url.rstrip("/")
|
||||
self._hass_token: str = token
|
||||
|
||||
# Event filtering
|
||||
self._watch_domains: Set[str] = set(extra.get("watch_domains", []))
|
||||
self._watch_entities: Set[str] = set(extra.get("watch_entities", []))
|
||||
self._ignore_entities: Set[str] = set(extra.get("ignore_entities", []))
|
||||
self._cooldown_seconds: int = int(extra.get("cooldown_seconds", 30))
|
||||
|
||||
# Cooldown tracking: entity_id -> last_event_timestamp
|
||||
self._last_event_time: Dict[str, float] = {}
|
||||
|
||||
def _next_id(self) -> int:
|
||||
"""Return the next WebSocket message ID."""
|
||||
self._msg_id += 1
|
||||
return self._msg_id
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Connection lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to HA WebSocket API and subscribe to events."""
|
||||
if not AIOHTTP_AVAILABLE:
|
||||
print(f"[{self.name}] aiohttp not installed. Run: pip install aiohttp")
|
||||
return False
|
||||
|
||||
if not self._hass_token:
|
||||
print(f"[{self.name}] No HASS_TOKEN configured")
|
||||
return False
|
||||
|
||||
try:
|
||||
success = await self._ws_connect()
|
||||
if not success:
|
||||
return False
|
||||
|
||||
# Start background listener
|
||||
self._listen_task = asyncio.create_task(self._listen_loop())
|
||||
self._running = True
|
||||
print(f"[{self.name}] Connected to {self._hass_url}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to connect: {e}")
|
||||
return False
|
||||
|
||||
async def _ws_connect(self) -> bool:
|
||||
"""Establish WebSocket connection and authenticate."""
|
||||
ws_url = self._hass_url.replace("http://", "ws://").replace("https://", "wss://")
|
||||
ws_url = f"{ws_url}/api/websocket"
|
||||
|
||||
self._session = aiohttp.ClientSession()
|
||||
self._ws = await self._session.ws_connect(ws_url, heartbeat=30)
|
||||
|
||||
# Step 1: Receive auth_required
|
||||
msg = await self._ws.receive_json()
|
||||
if msg.get("type") != "auth_required":
|
||||
logger.error("Expected auth_required, got: %s", msg.get("type"))
|
||||
await self._cleanup_ws()
|
||||
return False
|
||||
|
||||
# Step 2: Send auth
|
||||
await self._ws.send_json({
|
||||
"type": "auth",
|
||||
"access_token": self._hass_token,
|
||||
})
|
||||
|
||||
# Step 3: Wait for auth_ok
|
||||
msg = await self._ws.receive_json()
|
||||
if msg.get("type") != "auth_ok":
|
||||
logger.error("Auth failed: %s", msg)
|
||||
await self._cleanup_ws()
|
||||
return False
|
||||
|
||||
# Step 4: Subscribe to state_changed events
|
||||
sub_id = self._next_id()
|
||||
await self._ws.send_json({
|
||||
"id": sub_id,
|
||||
"type": "subscribe_events",
|
||||
"event_type": "state_changed",
|
||||
})
|
||||
|
||||
# Verify subscription acknowledgement
|
||||
msg = await self._ws.receive_json()
|
||||
if not msg.get("success"):
|
||||
logger.error("Failed to subscribe to events: %s", msg)
|
||||
await self._cleanup_ws()
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _cleanup_ws(self) -> None:
|
||||
"""Close WebSocket and session."""
|
||||
if self._ws and not self._ws.closed:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from Home Assistant."""
|
||||
self._running = False
|
||||
if self._listen_task:
|
||||
self._listen_task.cancel()
|
||||
try:
|
||||
await self._listen_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._listen_task = None
|
||||
|
||||
await self._cleanup_ws()
|
||||
print(f"[{self.name}] Disconnected")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Event listener
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _listen_loop(self) -> None:
|
||||
"""Main event loop with automatic reconnection."""
|
||||
backoff_idx = 0
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
await self._read_events()
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning("[%s] WebSocket error: %s", self.name, e)
|
||||
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
# Reconnect with backoff
|
||||
delay = self._BACKOFF_STEPS[min(backoff_idx, len(self._BACKOFF_STEPS) - 1)]
|
||||
print(f"[{self.name}] Reconnecting in {delay}s...")
|
||||
await asyncio.sleep(delay)
|
||||
backoff_idx += 1
|
||||
|
||||
try:
|
||||
await self._cleanup_ws()
|
||||
success = await self._ws_connect()
|
||||
if success:
|
||||
backoff_idx = 0 # Reset on successful reconnect
|
||||
print(f"[{self.name}] Reconnected")
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Reconnection failed: %s", self.name, e)
|
||||
|
||||
async def _read_events(self) -> None:
|
||||
"""Read events from WebSocket until disconnected."""
|
||||
async for ws_msg in self._ws:
|
||||
if ws_msg.type == aiohttp.WSMsgType.TEXT:
|
||||
try:
|
||||
data = json.loads(ws_msg.data)
|
||||
if data.get("type") == "event":
|
||||
await self._handle_ha_event(data.get("event", {}))
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Invalid JSON from HA WS: %s", ws_msg.data[:200])
|
||||
elif ws_msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR):
|
||||
break
|
||||
|
||||
async def _handle_ha_event(self, event: Dict[str, Any]) -> None:
|
||||
"""Process a state_changed event from Home Assistant."""
|
||||
event_data = event.get("data", {})
|
||||
entity_id: str = event_data.get("entity_id", "")
|
||||
|
||||
if not entity_id:
|
||||
return
|
||||
|
||||
# Apply ignore filter
|
||||
if entity_id in self._ignore_entities:
|
||||
return
|
||||
|
||||
# Apply domain/entity watch filters
|
||||
domain = entity_id.split(".")[0] if "." in entity_id else ""
|
||||
if self._watch_domains or self._watch_entities:
|
||||
domain_match = domain in self._watch_domains if self._watch_domains else False
|
||||
entity_match = entity_id in self._watch_entities if self._watch_entities else False
|
||||
if not domain_match and not entity_match:
|
||||
return
|
||||
|
||||
# Apply cooldown
|
||||
now = time.time()
|
||||
last = self._last_event_time.get(entity_id, 0)
|
||||
if (now - last) < self._cooldown_seconds:
|
||||
return
|
||||
self._last_event_time[entity_id] = now
|
||||
|
||||
# Build human-readable message
|
||||
old_state = event_data.get("old_state", {})
|
||||
new_state = event_data.get("new_state", {})
|
||||
message = self._format_state_change(entity_id, old_state, new_state)
|
||||
|
||||
if not message:
|
||||
return
|
||||
|
||||
# Build MessageEvent and forward to handler
|
||||
source = self.build_source(
|
||||
chat_id="ha_events",
|
||||
chat_name="Home Assistant Events",
|
||||
chat_type="channel",
|
||||
user_id="homeassistant",
|
||||
user_name="Home Assistant",
|
||||
)
|
||||
|
||||
msg_event = MessageEvent(
|
||||
text=message,
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
message_id=f"ha_{entity_id}_{int(now)}",
|
||||
timestamp=datetime.now(),
|
||||
)
|
||||
|
||||
await self.handle_message(msg_event)
|
||||
|
||||
@staticmethod
|
||||
def _format_state_change(
|
||||
entity_id: str,
|
||||
old_state: Dict[str, Any],
|
||||
new_state: Dict[str, Any],
|
||||
) -> Optional[str]:
|
||||
"""Convert a state_changed event into a human-readable description."""
|
||||
if not new_state:
|
||||
return None
|
||||
|
||||
old_val = old_state.get("state", "unknown") if old_state else "unknown"
|
||||
new_val = new_state.get("state", "unknown")
|
||||
|
||||
# Skip if state didn't actually change
|
||||
if old_val == new_val:
|
||||
return None
|
||||
|
||||
friendly_name = new_state.get("attributes", {}).get("friendly_name", entity_id)
|
||||
domain = entity_id.split(".")[0] if "." in entity_id else ""
|
||||
|
||||
# Domain-specific formatting
|
||||
if domain == "climate":
|
||||
attrs = new_state.get("attributes", {})
|
||||
temp = attrs.get("current_temperature", "?")
|
||||
target = attrs.get("temperature", "?")
|
||||
return (
|
||||
f"[Home Assistant] {friendly_name}: HVAC mode changed from "
|
||||
f"'{old_val}' to '{new_val}' (current: {temp}, target: {target})"
|
||||
)
|
||||
|
||||
if domain == "sensor":
|
||||
unit = new_state.get("attributes", {}).get("unit_of_measurement", "")
|
||||
return (
|
||||
f"[Home Assistant] {friendly_name}: changed from "
|
||||
f"{old_val}{unit} to {new_val}{unit}"
|
||||
)
|
||||
|
||||
if domain == "binary_sensor":
|
||||
return (
|
||||
f"[Home Assistant] {friendly_name}: "
|
||||
f"{'triggered' if new_val == 'on' else 'cleared'} "
|
||||
f"(was {'triggered' if old_val == 'on' else 'cleared'})"
|
||||
)
|
||||
|
||||
if domain in ("light", "switch", "fan"):
|
||||
return (
|
||||
f"[Home Assistant] {friendly_name}: turned "
|
||||
f"{'on' if new_val == 'on' else 'off'}"
|
||||
)
|
||||
|
||||
if domain == "alarm_control_panel":
|
||||
return (
|
||||
f"[Home Assistant] {friendly_name}: alarm state changed from "
|
||||
f"'{old_val}' to '{new_val}'"
|
||||
)
|
||||
|
||||
# Generic fallback
|
||||
return (
|
||||
f"[Home Assistant] {friendly_name} ({entity_id}): "
|
||||
f"changed from '{old_val}' to '{new_val}'"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Outbound messaging
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Send a notification via HA REST API (persistent_notification.create).
|
||||
|
||||
Uses the REST API instead of WebSocket to avoid a race condition
|
||||
with the event listener loop that reads from the same WS connection.
|
||||
"""
|
||||
url = f"{self._hass_url}/api/services/persistent_notification/create"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._hass_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {
|
||||
"title": "Hermes Agent",
|
||||
"message": content[:self.MAX_MESSAGE_LENGTH],
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=10),
|
||||
) as resp:
|
||||
if resp.status < 300:
|
||||
return SendResult(success=True, message_id=str(self._next_id()))
|
||||
else:
|
||||
body = await resp.text()
|
||||
return SendResult(success=False, error=f"HTTP {resp.status}: {body}")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return SendResult(success=False, error="Timeout sending notification to HA")
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""No typing indicator for Home Assistant."""
|
||||
pass
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""Return basic info about the HA event channel."""
|
||||
return {
|
||||
"name": "Home Assistant Events",
|
||||
"type": "channel",
|
||||
"url": self._hass_url,
|
||||
}
|
||||
@@ -469,7 +469,14 @@ class GatewayRunner:
|
||||
logger.warning("Slack: slack-bolt not installed. Run: pip install 'hermes-agent[slack]'")
|
||||
return None
|
||||
return SlackAdapter(config)
|
||||
|
||||
|
||||
elif platform == Platform.HOMEASSISTANT:
|
||||
from gateway.platforms.homeassistant import HomeAssistantAdapter, check_ha_requirements
|
||||
if not check_ha_requirements():
|
||||
logger.warning("HomeAssistant: aiohttp not installed or HASS_TOKEN not set")
|
||||
return None
|
||||
return HomeAssistantAdapter(config)
|
||||
|
||||
return None
|
||||
|
||||
def _is_user_authorized(self, source: SessionSource) -> bool:
|
||||
|
||||
@@ -94,6 +94,7 @@ def _discover_tools():
|
||||
"tools.process_registry",
|
||||
"tools.send_message_tool",
|
||||
"tools.honcho_tools",
|
||||
"tools.homeassistant_tool",
|
||||
]
|
||||
import importlib
|
||||
for mod_name in _modules:
|
||||
|
||||
@@ -47,6 +47,7 @@ cli = ["simple-term-menu"]
|
||||
tts-premium = ["elevenlabs"]
|
||||
pty = ["ptyprocess>=0.7.0"]
|
||||
honcho = ["honcho-ai>=2.0.1"]
|
||||
homeassistant = ["aiohttp>=3.9.0"]
|
||||
all = [
|
||||
"hermes-agent[modal]",
|
||||
"hermes-agent[messaging]",
|
||||
@@ -57,6 +58,7 @@ all = [
|
||||
"hermes-agent[slack]",
|
||||
"hermes-agent[pty]",
|
||||
"hermes-agent[honcho]",
|
||||
"hermes-agent[homeassistant]",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
604
tests/gateway/test_homeassistant.py
Normal file
604
tests/gateway/test_homeassistant.py
Normal file
@@ -0,0 +1,604 @@
|
||||
"""Tests for the Home Assistant gateway adapter.
|
||||
|
||||
Tests real logic: state change formatting, event filtering pipeline,
|
||||
cooldown behavior, config integration, and adapter initialization.
|
||||
"""
|
||||
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import (
|
||||
GatewayConfig,
|
||||
Platform,
|
||||
PlatformConfig,
|
||||
)
|
||||
from gateway.platforms.homeassistant import (
|
||||
HomeAssistantAdapter,
|
||||
check_ha_requirements,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_ha_requirements
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckRequirements:
|
||||
def test_returns_false_without_token(self, monkeypatch):
|
||||
monkeypatch.delenv("HASS_TOKEN", raising=False)
|
||||
assert check_ha_requirements() is False
|
||||
|
||||
def test_returns_true_with_token(self, monkeypatch):
|
||||
monkeypatch.setenv("HASS_TOKEN", "test-token")
|
||||
assert check_ha_requirements() is True
|
||||
|
||||
@patch("gateway.platforms.homeassistant.AIOHTTP_AVAILABLE", False)
|
||||
def test_returns_false_without_aiohttp(self, monkeypatch):
|
||||
monkeypatch.setenv("HASS_TOKEN", "test-token")
|
||||
assert check_ha_requirements() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _format_state_change - pure function, all domain branches
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFormatStateChange:
|
||||
@staticmethod
|
||||
def fmt(entity_id, old_state, new_state):
|
||||
return HomeAssistantAdapter._format_state_change(entity_id, old_state, new_state)
|
||||
|
||||
def test_climate_includes_temperatures(self):
|
||||
msg = self.fmt(
|
||||
"climate.thermostat",
|
||||
{"state": "off"},
|
||||
{"state": "heat", "attributes": {
|
||||
"friendly_name": "Main Thermostat",
|
||||
"current_temperature": 21.5,
|
||||
"temperature": 23,
|
||||
}},
|
||||
)
|
||||
assert "Main Thermostat" in msg
|
||||
assert "'off'" in msg and "'heat'" in msg
|
||||
assert "21.5" in msg and "23" in msg
|
||||
|
||||
def test_sensor_includes_unit(self):
|
||||
msg = self.fmt(
|
||||
"sensor.temperature",
|
||||
{"state": "22.5"},
|
||||
{"state": "25.1", "attributes": {
|
||||
"friendly_name": "Living Room Temp",
|
||||
"unit_of_measurement": "C",
|
||||
}},
|
||||
)
|
||||
assert "22.5C" in msg and "25.1C" in msg
|
||||
assert "Living Room Temp" in msg
|
||||
|
||||
def test_sensor_without_unit(self):
|
||||
msg = self.fmt(
|
||||
"sensor.count",
|
||||
{"state": "5"},
|
||||
{"state": "10", "attributes": {"friendly_name": "Counter"}},
|
||||
)
|
||||
assert "5" in msg and "10" in msg
|
||||
|
||||
def test_binary_sensor_on(self):
|
||||
msg = self.fmt(
|
||||
"binary_sensor.motion",
|
||||
{"state": "off"},
|
||||
{"state": "on", "attributes": {"friendly_name": "Hallway Motion"}},
|
||||
)
|
||||
assert "triggered" in msg
|
||||
assert "Hallway Motion" in msg
|
||||
|
||||
def test_binary_sensor_off(self):
|
||||
msg = self.fmt(
|
||||
"binary_sensor.door",
|
||||
{"state": "on"},
|
||||
{"state": "off", "attributes": {"friendly_name": "Front Door"}},
|
||||
)
|
||||
assert "cleared" in msg
|
||||
|
||||
def test_light_turned_on(self):
|
||||
msg = self.fmt(
|
||||
"light.bedroom",
|
||||
{"state": "off"},
|
||||
{"state": "on", "attributes": {"friendly_name": "Bedroom Light"}},
|
||||
)
|
||||
assert "turned on" in msg
|
||||
|
||||
def test_switch_turned_off(self):
|
||||
msg = self.fmt(
|
||||
"switch.heater",
|
||||
{"state": "on"},
|
||||
{"state": "off", "attributes": {"friendly_name": "Heater"}},
|
||||
)
|
||||
assert "turned off" in msg
|
||||
|
||||
def test_fan_domain_uses_light_switch_branch(self):
|
||||
msg = self.fmt(
|
||||
"fan.ceiling",
|
||||
{"state": "off"},
|
||||
{"state": "on", "attributes": {"friendly_name": "Ceiling Fan"}},
|
||||
)
|
||||
assert "turned on" in msg
|
||||
|
||||
def test_alarm_panel(self):
|
||||
msg = self.fmt(
|
||||
"alarm_control_panel.home",
|
||||
{"state": "disarmed"},
|
||||
{"state": "armed_away", "attributes": {"friendly_name": "Home Alarm"}},
|
||||
)
|
||||
assert "Home Alarm" in msg
|
||||
assert "armed_away" in msg and "disarmed" in msg
|
||||
|
||||
def test_generic_domain_includes_entity_id(self):
|
||||
msg = self.fmt(
|
||||
"automation.morning",
|
||||
{"state": "off"},
|
||||
{"state": "on", "attributes": {"friendly_name": "Morning Routine"}},
|
||||
)
|
||||
assert "automation.morning" in msg
|
||||
assert "Morning Routine" in msg
|
||||
|
||||
def test_same_state_returns_none(self):
|
||||
assert self.fmt(
|
||||
"sensor.temp",
|
||||
{"state": "22"},
|
||||
{"state": "22", "attributes": {"friendly_name": "Temp"}},
|
||||
) is None
|
||||
|
||||
def test_empty_new_state_returns_none(self):
|
||||
assert self.fmt("light.x", {"state": "on"}, {}) is None
|
||||
|
||||
def test_no_old_state_uses_unknown(self):
|
||||
msg = self.fmt(
|
||||
"light.new",
|
||||
None,
|
||||
{"state": "on", "attributes": {"friendly_name": "New Light"}},
|
||||
)
|
||||
assert msg is not None
|
||||
assert "New Light" in msg
|
||||
|
||||
def test_uses_entity_id_when_no_friendly_name(self):
|
||||
msg = self.fmt(
|
||||
"sensor.unnamed",
|
||||
{"state": "1"},
|
||||
{"state": "2", "attributes": {}},
|
||||
)
|
||||
assert "sensor.unnamed" in msg
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Adapter initialization from config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAdapterInit:
|
||||
def test_url_and_token_from_config_extra(self, monkeypatch):
|
||||
monkeypatch.delenv("HASS_URL", raising=False)
|
||||
monkeypatch.delenv("HASS_TOKEN", raising=False)
|
||||
|
||||
config = PlatformConfig(
|
||||
enabled=True,
|
||||
token="config-token",
|
||||
extra={"url": "http://192.168.1.50:8123"},
|
||||
)
|
||||
adapter = HomeAssistantAdapter(config)
|
||||
assert adapter._hass_token == "config-token"
|
||||
assert adapter._hass_url == "http://192.168.1.50:8123"
|
||||
|
||||
def test_url_fallback_to_env(self, monkeypatch):
|
||||
monkeypatch.setenv("HASS_URL", "http://env-host:8123")
|
||||
monkeypatch.setenv("HASS_TOKEN", "env-tok")
|
||||
|
||||
config = PlatformConfig(enabled=True, token="env-tok")
|
||||
adapter = HomeAssistantAdapter(config)
|
||||
assert adapter._hass_url == "http://env-host:8123"
|
||||
|
||||
def test_trailing_slash_stripped(self):
|
||||
config = PlatformConfig(
|
||||
enabled=True, token="t",
|
||||
extra={"url": "http://ha.local:8123/"},
|
||||
)
|
||||
adapter = HomeAssistantAdapter(config)
|
||||
assert adapter._hass_url == "http://ha.local:8123"
|
||||
|
||||
def test_watch_filters_parsed(self):
|
||||
config = PlatformConfig(
|
||||
enabled=True, token="t",
|
||||
extra={
|
||||
"watch_domains": ["climate", "binary_sensor"],
|
||||
"watch_entities": ["sensor.special"],
|
||||
"ignore_entities": ["sensor.uptime", "sensor.cpu"],
|
||||
"cooldown_seconds": 120,
|
||||
},
|
||||
)
|
||||
adapter = HomeAssistantAdapter(config)
|
||||
assert adapter._watch_domains == {"climate", "binary_sensor"}
|
||||
assert adapter._watch_entities == {"sensor.special"}
|
||||
assert adapter._ignore_entities == {"sensor.uptime", "sensor.cpu"}
|
||||
assert adapter._cooldown_seconds == 120
|
||||
|
||||
def test_defaults_when_no_extra(self, monkeypatch):
|
||||
monkeypatch.setenv("HASS_TOKEN", "tok")
|
||||
config = PlatformConfig(enabled=True, token="tok")
|
||||
adapter = HomeAssistantAdapter(config)
|
||||
assert adapter._watch_domains == set()
|
||||
assert adapter._watch_entities == set()
|
||||
assert adapter._ignore_entities == set()
|
||||
assert adapter._cooldown_seconds == 30
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Event filtering pipeline (_handle_ha_event)
|
||||
#
|
||||
# We mock handle_message (not our code, it's the base class pipeline) to
|
||||
# capture the MessageEvent that _handle_ha_event produces.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_adapter(**extra) -> HomeAssistantAdapter:
|
||||
config = PlatformConfig(enabled=True, token="tok", extra=extra)
|
||||
adapter = HomeAssistantAdapter(config)
|
||||
adapter.handle_message = AsyncMock()
|
||||
return adapter
|
||||
|
||||
|
||||
def _make_event(entity_id, old_state, new_state, old_attrs=None, new_attrs=None):
|
||||
return {
|
||||
"data": {
|
||||
"entity_id": entity_id,
|
||||
"old_state": {"state": old_state, "attributes": old_attrs or {}},
|
||||
"new_state": {"state": new_state, "attributes": new_attrs or {"friendly_name": entity_id}},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class TestEventFilteringPipeline:
|
||||
@pytest.mark.asyncio
|
||||
async def test_ignored_entity_not_forwarded(self):
|
||||
adapter = _make_adapter(ignore_entities=["sensor.uptime"])
|
||||
await adapter._handle_ha_event(_make_event("sensor.uptime", "100", "101"))
|
||||
adapter.handle_message.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unwatched_domain_not_forwarded(self):
|
||||
adapter = _make_adapter(watch_domains=["climate"])
|
||||
await adapter._handle_ha_event(_make_event("light.bedroom", "off", "on"))
|
||||
adapter.handle_message.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_watched_domain_forwarded(self):
|
||||
adapter = _make_adapter(watch_domains=["climate"], cooldown_seconds=0)
|
||||
await adapter._handle_ha_event(
|
||||
_make_event("climate.thermostat", "off", "heat",
|
||||
new_attrs={"friendly_name": "Thermostat", "current_temperature": 20, "temperature": 22})
|
||||
)
|
||||
adapter.handle_message.assert_called_once()
|
||||
|
||||
# Verify the actual MessageEvent text content
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert "Thermostat" in msg_event.text
|
||||
assert "heat" in msg_event.text
|
||||
assert msg_event.source.platform == Platform.HOMEASSISTANT
|
||||
assert msg_event.source.chat_id == "ha_events"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_watched_entity_forwarded(self):
|
||||
adapter = _make_adapter(watch_entities=["sensor.important"], cooldown_seconds=0)
|
||||
await adapter._handle_ha_event(
|
||||
_make_event("sensor.important", "10", "20",
|
||||
new_attrs={"friendly_name": "Important Sensor", "unit_of_measurement": "W"})
|
||||
)
|
||||
adapter.handle_message.assert_called_once()
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert "10W" in msg_event.text and "20W" in msg_event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_filters_passes_everything(self):
|
||||
adapter = _make_adapter(cooldown_seconds=0)
|
||||
await adapter._handle_ha_event(_make_event("cover.blinds", "closed", "open"))
|
||||
adapter.handle_message.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_state_not_forwarded(self):
|
||||
adapter = _make_adapter(cooldown_seconds=0)
|
||||
await adapter._handle_ha_event(_make_event("light.x", "on", "on"))
|
||||
adapter.handle_message.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_entity_id_skipped(self):
|
||||
adapter = _make_adapter()
|
||||
await adapter._handle_ha_event({"data": {"entity_id": ""}})
|
||||
adapter.handle_message.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_event_has_correct_source(self):
|
||||
adapter = _make_adapter(cooldown_seconds=0)
|
||||
await adapter._handle_ha_event(
|
||||
_make_event("light.test", "off", "on",
|
||||
new_attrs={"friendly_name": "Test Light"})
|
||||
)
|
||||
msg_event = adapter.handle_message.call_args[0][0]
|
||||
assert msg_event.source.user_name == "Home Assistant"
|
||||
assert msg_event.source.chat_type == "channel"
|
||||
assert msg_event.message_id.startswith("ha_light.test_")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cooldown behavior
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCooldown:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cooldown_blocks_rapid_events(self):
|
||||
adapter = _make_adapter(cooldown_seconds=60)
|
||||
|
||||
event = _make_event("sensor.temp", "20", "21",
|
||||
new_attrs={"friendly_name": "Temp"})
|
||||
await adapter._handle_ha_event(event)
|
||||
assert adapter.handle_message.call_count == 1
|
||||
|
||||
# Second event immediately after should be blocked
|
||||
event2 = _make_event("sensor.temp", "21", "22",
|
||||
new_attrs={"friendly_name": "Temp"})
|
||||
await adapter._handle_ha_event(event2)
|
||||
assert adapter.handle_message.call_count == 1 # Still 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cooldown_expires(self):
|
||||
adapter = _make_adapter(cooldown_seconds=1)
|
||||
|
||||
event = _make_event("sensor.temp", "20", "21",
|
||||
new_attrs={"friendly_name": "Temp"})
|
||||
await adapter._handle_ha_event(event)
|
||||
assert adapter.handle_message.call_count == 1
|
||||
|
||||
# Simulate time passing beyond cooldown
|
||||
adapter._last_event_time["sensor.temp"] = time.time() - 2
|
||||
|
||||
event2 = _make_event("sensor.temp", "21", "22",
|
||||
new_attrs={"friendly_name": "Temp"})
|
||||
await adapter._handle_ha_event(event2)
|
||||
assert adapter.handle_message.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_entities_independent_cooldowns(self):
|
||||
adapter = _make_adapter(cooldown_seconds=60)
|
||||
|
||||
await adapter._handle_ha_event(
|
||||
_make_event("sensor.a", "1", "2", new_attrs={"friendly_name": "A"})
|
||||
)
|
||||
await adapter._handle_ha_event(
|
||||
_make_event("sensor.b", "3", "4", new_attrs={"friendly_name": "B"})
|
||||
)
|
||||
# Both should pass - different entities
|
||||
assert adapter.handle_message.call_count == 2
|
||||
|
||||
# Same entity again - should be blocked
|
||||
await adapter._handle_ha_event(
|
||||
_make_event("sensor.a", "2", "3", new_attrs={"friendly_name": "A"})
|
||||
)
|
||||
assert adapter.handle_message.call_count == 2 # Still 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_cooldown_passes_all(self):
|
||||
adapter = _make_adapter(cooldown_seconds=0)
|
||||
|
||||
for i in range(5):
|
||||
await adapter._handle_ha_event(
|
||||
_make_event("sensor.temp", str(i), str(i + 1),
|
||||
new_attrs={"friendly_name": "Temp"})
|
||||
)
|
||||
assert adapter.handle_message.call_count == 5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config integration (env overrides, round-trip)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConfigIntegration:
|
||||
def test_env_override_creates_ha_platform(self, monkeypatch):
|
||||
monkeypatch.setenv("HASS_TOKEN", "env-token")
|
||||
monkeypatch.setenv("HASS_URL", "http://10.0.0.5:8123")
|
||||
# Clear other platform tokens
|
||||
for v in ["TELEGRAM_BOT_TOKEN", "DISCORD_BOT_TOKEN", "SLACK_BOT_TOKEN"]:
|
||||
monkeypatch.delenv(v, raising=False)
|
||||
|
||||
from gateway.config import load_gateway_config
|
||||
config = load_gateway_config()
|
||||
|
||||
assert Platform.HOMEASSISTANT in config.platforms
|
||||
ha = config.platforms[Platform.HOMEASSISTANT]
|
||||
assert ha.enabled is True
|
||||
assert ha.token == "env-token"
|
||||
assert ha.extra["url"] == "http://10.0.0.5:8123"
|
||||
|
||||
def test_no_env_no_platform(self, monkeypatch):
|
||||
for v in ["HASS_TOKEN", "HASS_URL", "TELEGRAM_BOT_TOKEN",
|
||||
"DISCORD_BOT_TOKEN", "SLACK_BOT_TOKEN"]:
|
||||
monkeypatch.delenv(v, raising=False)
|
||||
|
||||
from gateway.config import load_gateway_config
|
||||
config = load_gateway_config()
|
||||
assert Platform.HOMEASSISTANT not in config.platforms
|
||||
|
||||
def test_config_roundtrip_preserves_extra(self):
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.HOMEASSISTANT: PlatformConfig(
|
||||
enabled=True,
|
||||
token="tok",
|
||||
extra={
|
||||
"url": "http://ha:8123",
|
||||
"watch_domains": ["climate"],
|
||||
"cooldown_seconds": 45,
|
||||
},
|
||||
),
|
||||
},
|
||||
)
|
||||
d = config.to_dict()
|
||||
restored = GatewayConfig.from_dict(d)
|
||||
|
||||
ha = restored.platforms[Platform.HOMEASSISTANT]
|
||||
assert ha.enabled is True
|
||||
assert ha.token == "tok"
|
||||
assert ha.extra["watch_domains"] == ["climate"]
|
||||
assert ha.extra["cooldown_seconds"] == 45
|
||||
|
||||
def test_connected_platforms_includes_ha(self):
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.HOMEASSISTANT: PlatformConfig(enabled=True, token="tok"),
|
||||
Platform.TELEGRAM: PlatformConfig(enabled=False, token="t"),
|
||||
},
|
||||
)
|
||||
connected = config.get_connected_platforms()
|
||||
assert Platform.HOMEASSISTANT in connected
|
||||
assert Platform.TELEGRAM not in connected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# send() via REST API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSendViaRestApi:
|
||||
"""send() uses REST API (not WebSocket) to avoid race conditions."""
|
||||
|
||||
@staticmethod
|
||||
def _mock_aiohttp_session(response_status=200, response_text="OK"):
|
||||
"""Build a mock aiohttp session + response for async-with patterns.
|
||||
|
||||
aiohttp.ClientSession() is a sync constructor whose return value
|
||||
is used as ``async with session:``. ``session.post(...)`` returns a
|
||||
context-manager (not a coroutine), so both layers use MagicMock for
|
||||
the call and AsyncMock only for ``__aenter__`` / ``__aexit__``.
|
||||
"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = response_status
|
||||
mock_response.text = AsyncMock(return_value=response_text)
|
||||
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_response.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.post = MagicMock(return_value=mock_response)
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
return mock_session
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_success(self):
|
||||
adapter = _make_adapter()
|
||||
mock_session = self._mock_aiohttp_session(200)
|
||||
|
||||
with patch("gateway.platforms.homeassistant.aiohttp") as mock_aiohttp:
|
||||
mock_aiohttp.ClientSession = MagicMock(return_value=mock_session)
|
||||
mock_aiohttp.ClientTimeout = lambda total: total
|
||||
|
||||
result = await adapter.send("ha_events", "Test notification")
|
||||
|
||||
assert result.success is True
|
||||
# Verify the REST API was called with correct payload
|
||||
call_args = mock_session.post.call_args
|
||||
assert "/api/services/persistent_notification/create" in call_args[0][0]
|
||||
assert call_args[1]["json"]["title"] == "Hermes Agent"
|
||||
assert call_args[1]["json"]["message"] == "Test notification"
|
||||
assert "Bearer tok" in call_args[1]["headers"]["Authorization"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_http_error(self):
|
||||
adapter = _make_adapter()
|
||||
mock_session = self._mock_aiohttp_session(401, "Unauthorized")
|
||||
|
||||
with patch("gateway.platforms.homeassistant.aiohttp") as mock_aiohttp:
|
||||
mock_aiohttp.ClientSession = MagicMock(return_value=mock_session)
|
||||
mock_aiohttp.ClientTimeout = lambda total: total
|
||||
|
||||
result = await adapter.send("ha_events", "Test")
|
||||
|
||||
assert result.success is False
|
||||
assert "401" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_truncates_long_message(self):
|
||||
adapter = _make_adapter()
|
||||
mock_session = self._mock_aiohttp_session(200)
|
||||
long_message = "x" * 10000
|
||||
|
||||
with patch("gateway.platforms.homeassistant.aiohttp") as mock_aiohttp:
|
||||
mock_aiohttp.ClientSession = MagicMock(return_value=mock_session)
|
||||
mock_aiohttp.ClientTimeout = lambda total: total
|
||||
|
||||
await adapter.send("ha_events", long_message)
|
||||
|
||||
sent_message = mock_session.post.call_args[1]["json"]["message"]
|
||||
assert len(sent_message) == 4096
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_does_not_use_websocket(self):
|
||||
"""send() must use REST API, not the WS connection (race condition fix)."""
|
||||
adapter = _make_adapter()
|
||||
adapter._ws = AsyncMock() # Simulate an active WS
|
||||
mock_session = self._mock_aiohttp_session(200)
|
||||
|
||||
with patch("gateway.platforms.homeassistant.aiohttp") as mock_aiohttp:
|
||||
mock_aiohttp.ClientSession = MagicMock(return_value=mock_session)
|
||||
mock_aiohttp.ClientTimeout = lambda total: total
|
||||
|
||||
await adapter.send("ha_events", "Test")
|
||||
|
||||
# WS should NOT have been used for sending
|
||||
adapter._ws.send_json.assert_not_called()
|
||||
adapter._ws.receive_json.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Toolset integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToolsetIntegration:
|
||||
def test_homeassistant_toolset_resolves(self):
|
||||
from toolsets import resolve_toolset
|
||||
|
||||
tools = resolve_toolset("homeassistant")
|
||||
assert set(tools) == {"ha_list_entities", "ha_get_state", "ha_call_service"}
|
||||
|
||||
def test_gateway_toolset_includes_ha_tools(self):
|
||||
from toolsets import resolve_toolset
|
||||
|
||||
gateway_tools = resolve_toolset("hermes-gateway")
|
||||
for tool in ("ha_list_entities", "ha_get_state", "ha_call_service"):
|
||||
assert tool in gateway_tools
|
||||
|
||||
def test_hermes_core_tools_includes_ha(self):
|
||||
from toolsets import _HERMES_CORE_TOOLS
|
||||
|
||||
for tool in ("ha_list_entities", "ha_get_state", "ha_call_service"):
|
||||
assert tool in _HERMES_CORE_TOOLS
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WebSocket URL construction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWsUrlConstruction:
|
||||
def test_http_to_ws(self):
|
||||
config = PlatformConfig(enabled=True, token="t", extra={"url": "http://ha:8123"})
|
||||
adapter = HomeAssistantAdapter(config)
|
||||
ws_url = adapter._hass_url.replace("http://", "ws://").replace("https://", "wss://")
|
||||
assert ws_url == "ws://ha:8123"
|
||||
|
||||
def test_https_to_wss(self):
|
||||
config = PlatformConfig(enabled=True, token="t", extra={"url": "https://ha.example.com"})
|
||||
adapter = HomeAssistantAdapter(config)
|
||||
ws_url = adapter._hass_url.replace("http://", "ws://").replace("https://", "wss://")
|
||||
assert ws_url == "wss://ha.example.com"
|
||||
281
tests/tools/test_homeassistant_tool.py
Normal file
281
tests/tools/test_homeassistant_tool.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""Tests for the Home Assistant tool module.
|
||||
|
||||
Tests real logic: entity filtering, payload building, response parsing,
|
||||
handler validation, and availability gating.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.homeassistant_tool import (
|
||||
_check_ha_available,
|
||||
_filter_and_summarize,
|
||||
_build_service_payload,
|
||||
_parse_service_response,
|
||||
_get_headers,
|
||||
_handle_get_state,
|
||||
_handle_call_service,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sample HA state data (matches real HA /api/states response shape)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SAMPLE_STATES = [
|
||||
{"entity_id": "light.bedroom", "state": "on", "attributes": {"friendly_name": "Bedroom Light", "brightness": 200}},
|
||||
{"entity_id": "light.kitchen", "state": "off", "attributes": {"friendly_name": "Kitchen Light"}},
|
||||
{"entity_id": "switch.fan", "state": "on", "attributes": {"friendly_name": "Living Room Fan"}},
|
||||
{"entity_id": "sensor.temperature", "state": "22.5", "attributes": {"friendly_name": "Kitchen Temperature", "unit_of_measurement": "C"}},
|
||||
{"entity_id": "climate.thermostat", "state": "heat", "attributes": {"friendly_name": "Main Thermostat", "current_temperature": 21}},
|
||||
{"entity_id": "binary_sensor.motion", "state": "off", "attributes": {"friendly_name": "Hallway Motion"}},
|
||||
{"entity_id": "sensor.humidity", "state": "55", "attributes": {"friendly_name": "Bedroom Humidity", "area": "bedroom"}},
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entity filtering and summarization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFilterAndSummarize:
|
||||
def test_no_filters_returns_all(self):
|
||||
result = _filter_and_summarize(SAMPLE_STATES)
|
||||
assert result["count"] == 7
|
||||
ids = {e["entity_id"] for e in result["entities"]}
|
||||
assert "light.bedroom" in ids
|
||||
assert "climate.thermostat" in ids
|
||||
|
||||
def test_domain_filter_lights(self):
|
||||
result = _filter_and_summarize(SAMPLE_STATES, domain="light")
|
||||
assert result["count"] == 2
|
||||
for e in result["entities"]:
|
||||
assert e["entity_id"].startswith("light.")
|
||||
|
||||
def test_domain_filter_sensor(self):
|
||||
result = _filter_and_summarize(SAMPLE_STATES, domain="sensor")
|
||||
assert result["count"] == 2
|
||||
ids = {e["entity_id"] for e in result["entities"]}
|
||||
assert ids == {"sensor.temperature", "sensor.humidity"}
|
||||
|
||||
def test_domain_filter_no_matches(self):
|
||||
result = _filter_and_summarize(SAMPLE_STATES, domain="media_player")
|
||||
assert result["count"] == 0
|
||||
assert result["entities"] == []
|
||||
|
||||
def test_area_filter_by_friendly_name(self):
|
||||
result = _filter_and_summarize(SAMPLE_STATES, area="kitchen")
|
||||
assert result["count"] == 2
|
||||
ids = {e["entity_id"] for e in result["entities"]}
|
||||
assert "light.kitchen" in ids
|
||||
assert "sensor.temperature" in ids
|
||||
|
||||
def test_area_filter_by_area_attribute(self):
|
||||
result = _filter_and_summarize(SAMPLE_STATES, area="bedroom")
|
||||
ids = {e["entity_id"] for e in result["entities"]}
|
||||
# "Bedroom Light" matches via friendly_name, "Bedroom Humidity" matches via area attr
|
||||
assert "light.bedroom" in ids
|
||||
assert "sensor.humidity" in ids
|
||||
|
||||
def test_area_filter_case_insensitive(self):
|
||||
result = _filter_and_summarize(SAMPLE_STATES, area="KITCHEN")
|
||||
assert result["count"] == 2
|
||||
|
||||
def test_combined_domain_and_area(self):
|
||||
result = _filter_and_summarize(SAMPLE_STATES, domain="sensor", area="kitchen")
|
||||
assert result["count"] == 1
|
||||
assert result["entities"][0]["entity_id"] == "sensor.temperature"
|
||||
|
||||
def test_summary_includes_friendly_name(self):
|
||||
result = _filter_and_summarize(SAMPLE_STATES, domain="climate")
|
||||
assert result["entities"][0]["friendly_name"] == "Main Thermostat"
|
||||
assert result["entities"][0]["state"] == "heat"
|
||||
|
||||
def test_empty_states_list(self):
|
||||
result = _filter_and_summarize([])
|
||||
assert result["count"] == 0
|
||||
|
||||
def test_missing_attributes_handled(self):
|
||||
states = [{"entity_id": "light.x", "state": "on"}]
|
||||
result = _filter_and_summarize(states)
|
||||
assert result["count"] == 1
|
||||
assert result["entities"][0]["friendly_name"] == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service payload building
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildServicePayload:
|
||||
def test_entity_id_only(self):
|
||||
payload = _build_service_payload(entity_id="light.bedroom")
|
||||
assert payload == {"entity_id": "light.bedroom"}
|
||||
|
||||
def test_data_only(self):
|
||||
payload = _build_service_payload(data={"brightness": 255})
|
||||
assert payload == {"brightness": 255}
|
||||
|
||||
def test_entity_id_and_data(self):
|
||||
payload = _build_service_payload(
|
||||
entity_id="light.bedroom",
|
||||
data={"brightness": 200, "color_name": "blue"},
|
||||
)
|
||||
assert payload["entity_id"] == "light.bedroom"
|
||||
assert payload["brightness"] == 200
|
||||
assert payload["color_name"] == "blue"
|
||||
|
||||
def test_no_args_returns_empty(self):
|
||||
payload = _build_service_payload()
|
||||
assert payload == {}
|
||||
|
||||
def test_data_does_not_overwrite_entity_id(self):
|
||||
payload = _build_service_payload(
|
||||
entity_id="light.a",
|
||||
data={"entity_id": "light.b"},
|
||||
)
|
||||
# data.update overwrites entity_id set earlier
|
||||
assert payload["entity_id"] == "light.b"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service response parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseServiceResponse:
|
||||
def test_list_response_extracts_entities(self):
|
||||
ha_response = [
|
||||
{"entity_id": "light.bedroom", "state": "on", "attributes": {}},
|
||||
{"entity_id": "light.kitchen", "state": "on", "attributes": {}},
|
||||
]
|
||||
result = _parse_service_response("light", "turn_on", ha_response)
|
||||
assert result["success"] is True
|
||||
assert result["service"] == "light.turn_on"
|
||||
assert len(result["affected_entities"]) == 2
|
||||
assert result["affected_entities"][0]["entity_id"] == "light.bedroom"
|
||||
|
||||
def test_empty_list_response(self):
|
||||
result = _parse_service_response("scene", "turn_on", [])
|
||||
assert result["success"] is True
|
||||
assert result["affected_entities"] == []
|
||||
|
||||
def test_non_list_response(self):
|
||||
# Some HA services return a dict instead of a list
|
||||
result = _parse_service_response("script", "run", {"result": "ok"})
|
||||
assert result["success"] is True
|
||||
assert result["affected_entities"] == []
|
||||
|
||||
def test_none_response(self):
|
||||
result = _parse_service_response("automation", "trigger", None)
|
||||
assert result["success"] is True
|
||||
assert result["affected_entities"] == []
|
||||
|
||||
def test_service_name_format(self):
|
||||
result = _parse_service_response("climate", "set_temperature", [])
|
||||
assert result["service"] == "climate.set_temperature"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Handler validation (no mocks - these paths don't reach the network)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHandlerValidation:
|
||||
def test_get_state_missing_entity_id(self):
|
||||
result = json.loads(_handle_get_state({}))
|
||||
assert "error" in result
|
||||
assert "entity_id" in result["error"]
|
||||
|
||||
def test_get_state_empty_entity_id(self):
|
||||
result = json.loads(_handle_get_state({"entity_id": ""}))
|
||||
assert "error" in result
|
||||
|
||||
def test_call_service_missing_domain(self):
|
||||
result = json.loads(_handle_call_service({"service": "turn_on"}))
|
||||
assert "error" in result
|
||||
assert "domain" in result["error"]
|
||||
|
||||
def test_call_service_missing_service(self):
|
||||
result = json.loads(_handle_call_service({"domain": "light"}))
|
||||
assert "error" in result
|
||||
assert "service" in result["error"]
|
||||
|
||||
def test_call_service_missing_both(self):
|
||||
result = json.loads(_handle_call_service({}))
|
||||
assert "error" in result
|
||||
|
||||
def test_call_service_empty_strings(self):
|
||||
result = json.loads(_handle_call_service({"domain": "", "service": ""}))
|
||||
assert "error" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Availability check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCheckAvailable:
|
||||
def test_unavailable_without_token(self, monkeypatch):
|
||||
monkeypatch.delenv("HASS_TOKEN", raising=False)
|
||||
assert _check_ha_available() is False
|
||||
|
||||
def test_available_with_token(self, monkeypatch):
|
||||
monkeypatch.setenv("HASS_TOKEN", "eyJ0eXAiOiJKV1Q")
|
||||
assert _check_ha_available() is True
|
||||
|
||||
def test_empty_token_is_unavailable(self, monkeypatch):
|
||||
monkeypatch.setenv("HASS_TOKEN", "")
|
||||
assert _check_ha_available() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth headers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetHeaders:
|
||||
def test_bearer_token_format(self, monkeypatch):
|
||||
monkeypatch.setattr("tools.homeassistant_tool._HASS_TOKEN", "my-secret-token")
|
||||
headers = _get_headers()
|
||||
assert headers["Authorization"] == "Bearer my-secret-token"
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegistration:
|
||||
def test_tools_registered_in_registry(self):
|
||||
from tools.registry import registry
|
||||
|
||||
names = registry.get_all_tool_names()
|
||||
assert "ha_list_entities" in names
|
||||
assert "ha_get_state" in names
|
||||
assert "ha_call_service" in names
|
||||
|
||||
def test_tools_in_homeassistant_toolset(self):
|
||||
from tools.registry import registry
|
||||
|
||||
toolset_map = registry.get_tool_to_toolset_map()
|
||||
for tool in ("ha_list_entities", "ha_get_state", "ha_call_service"):
|
||||
assert toolset_map[tool] == "homeassistant"
|
||||
|
||||
def test_check_fn_gates_availability(self, monkeypatch):
|
||||
"""Registry should exclude HA tools when HASS_TOKEN is not set."""
|
||||
from tools.registry import registry
|
||||
|
||||
monkeypatch.delenv("HASS_TOKEN", raising=False)
|
||||
defs = registry.get_definitions({"ha_list_entities", "ha_get_state", "ha_call_service"})
|
||||
assert len(defs) == 0
|
||||
|
||||
def test_check_fn_includes_when_token_set(self, monkeypatch):
|
||||
"""Registry should include HA tools when HASS_TOKEN is set."""
|
||||
from tools.registry import registry
|
||||
|
||||
monkeypatch.setenv("HASS_TOKEN", "test-token")
|
||||
defs = registry.get_definitions({"ha_list_entities", "ha_get_state", "ha_call_service"})
|
||||
assert len(defs) == 3
|
||||
364
tools/homeassistant_tool.py
Normal file
364
tools/homeassistant_tool.py
Normal file
@@ -0,0 +1,364 @@
|
||||
"""Home Assistant tool for controlling smart home devices via REST API.
|
||||
|
||||
Registers three LLM-callable tools:
|
||||
- ``ha_list_entities`` -- list/filter entities by domain or area
|
||||
- ``ha_get_state`` -- get detailed state of a single entity
|
||||
- ``ha_call_service`` -- call a HA service (turn_on, turn_off, set_temperature, etc.)
|
||||
|
||||
Authentication uses a Long-Lived Access Token via ``HASS_TOKEN`` env var.
|
||||
The HA instance URL is read from ``HASS_URL`` (default: http://homeassistant.local:8123).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_HASS_URL: str = os.getenv("HASS_URL", "http://homeassistant.local:8123").rstrip("/")
|
||||
_HASS_TOKEN: str = os.getenv("HASS_TOKEN", "")
|
||||
|
||||
|
||||
def _get_headers() -> Dict[str, str]:
|
||||
"""Return authorization headers for HA REST API."""
|
||||
return {
|
||||
"Authorization": f"Bearer {_HASS_TOKEN}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Async helpers (called from sync handlers via run_until_complete)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _filter_and_summarize(
|
||||
states: list,
|
||||
domain: Optional[str] = None,
|
||||
area: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Filter raw HA states by domain/area and return a compact summary."""
|
||||
if domain:
|
||||
states = [s for s in states if s.get("entity_id", "").startswith(f"{domain}.")]
|
||||
|
||||
if area:
|
||||
area_lower = area.lower()
|
||||
states = [
|
||||
s for s in states
|
||||
if area_lower in (s.get("attributes", {}).get("friendly_name", "") or "").lower()
|
||||
or area_lower in (s.get("attributes", {}).get("area", "") or "").lower()
|
||||
]
|
||||
|
||||
entities = []
|
||||
for s in states:
|
||||
entities.append({
|
||||
"entity_id": s["entity_id"],
|
||||
"state": s["state"],
|
||||
"friendly_name": s.get("attributes", {}).get("friendly_name", ""),
|
||||
})
|
||||
|
||||
return {"count": len(entities), "entities": entities}
|
||||
|
||||
|
||||
async def _async_list_entities(
|
||||
domain: Optional[str] = None,
|
||||
area: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Fetch entity states from HA and optionally filter by domain/area."""
|
||||
import aiohttp
|
||||
|
||||
url = f"{_HASS_URL}/api/states"
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=_get_headers(), timeout=aiohttp.ClientTimeout(total=15)) as resp:
|
||||
resp.raise_for_status()
|
||||
states = await resp.json()
|
||||
|
||||
return _filter_and_summarize(states, domain, area)
|
||||
|
||||
|
||||
async def _async_get_state(entity_id: str) -> Dict[str, Any]:
|
||||
"""Fetch detailed state of a single entity."""
|
||||
import aiohttp
|
||||
|
||||
url = f"{_HASS_URL}/api/states/{entity_id}"
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=_get_headers(), timeout=aiohttp.ClientTimeout(total=10)) as resp:
|
||||
resp.raise_for_status()
|
||||
data = await resp.json()
|
||||
|
||||
return {
|
||||
"entity_id": data["entity_id"],
|
||||
"state": data["state"],
|
||||
"attributes": data.get("attributes", {}),
|
||||
"last_changed": data.get("last_changed"),
|
||||
"last_updated": data.get("last_updated"),
|
||||
}
|
||||
|
||||
|
||||
def _build_service_payload(
|
||||
entity_id: Optional[str] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build the JSON payload for a HA service call."""
|
||||
payload: Dict[str, Any] = {}
|
||||
if entity_id:
|
||||
payload["entity_id"] = entity_id
|
||||
if data:
|
||||
payload.update(data)
|
||||
return payload
|
||||
|
||||
|
||||
def _parse_service_response(
|
||||
domain: str,
|
||||
service: str,
|
||||
result: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""Parse HA service call response into a structured result."""
|
||||
affected = []
|
||||
if isinstance(result, list):
|
||||
for s in result:
|
||||
affected.append({
|
||||
"entity_id": s.get("entity_id", ""),
|
||||
"state": s.get("state", ""),
|
||||
})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"service": f"{domain}.{service}",
|
||||
"affected_entities": affected,
|
||||
}
|
||||
|
||||
|
||||
async def _async_call_service(
|
||||
domain: str,
|
||||
service: str,
|
||||
entity_id: Optional[str] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Call a Home Assistant service."""
|
||||
import aiohttp
|
||||
|
||||
url = f"{_HASS_URL}/api/services/{domain}/{service}"
|
||||
payload = _build_service_payload(entity_id, data)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url,
|
||||
headers=_get_headers(),
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=15),
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
result = await resp.json()
|
||||
|
||||
return _parse_service_response(domain, service, result)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync wrappers (handler signature: (args, **kw) -> str)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _run_async(coro):
|
||||
"""Run an async coroutine from a sync handler."""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
if loop and loop.is_running():
|
||||
# Already inside an event loop -- create a new thread
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
future = pool.submit(asyncio.run, coro)
|
||||
return future.result(timeout=30)
|
||||
else:
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def _handle_list_entities(args: dict, **kw) -> str:
|
||||
"""Handler for ha_list_entities tool."""
|
||||
domain = args.get("domain")
|
||||
area = args.get("area")
|
||||
try:
|
||||
result = _run_async(_async_list_entities(domain=domain, area=area))
|
||||
return json.dumps({"result": result})
|
||||
except Exception as e:
|
||||
logger.error("ha_list_entities error: %s", e)
|
||||
return json.dumps({"error": f"Failed to list entities: {e}"})
|
||||
|
||||
|
||||
def _handle_get_state(args: dict, **kw) -> str:
|
||||
"""Handler for ha_get_state tool."""
|
||||
entity_id = args.get("entity_id", "")
|
||||
if not entity_id:
|
||||
return json.dumps({"error": "Missing required parameter: entity_id"})
|
||||
try:
|
||||
result = _run_async(_async_get_state(entity_id))
|
||||
return json.dumps({"result": result})
|
||||
except Exception as e:
|
||||
logger.error("ha_get_state error: %s", e)
|
||||
return json.dumps({"error": f"Failed to get state for {entity_id}: {e}"})
|
||||
|
||||
|
||||
def _handle_call_service(args: dict, **kw) -> str:
|
||||
"""Handler for ha_call_service tool."""
|
||||
domain = args.get("domain", "")
|
||||
service = args.get("service", "")
|
||||
if not domain or not service:
|
||||
return json.dumps({"error": "Missing required parameters: domain and service"})
|
||||
|
||||
entity_id = args.get("entity_id")
|
||||
data = args.get("data")
|
||||
try:
|
||||
result = _run_async(_async_call_service(domain, service, entity_id, data))
|
||||
return json.dumps({"result": result})
|
||||
except Exception as e:
|
||||
logger.error("ha_call_service error: %s", e)
|
||||
return json.dumps({"error": f"Failed to call {domain}.{service}: {e}"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Availability check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _check_ha_available() -> bool:
|
||||
"""Tool is only available when HASS_TOKEN is set."""
|
||||
return bool(os.getenv("HASS_TOKEN"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
HA_LIST_ENTITIES_SCHEMA = {
|
||||
"name": "ha_list_entities",
|
||||
"description": (
|
||||
"List Home Assistant entities. Optionally filter by domain "
|
||||
"(light, switch, climate, sensor, binary_sensor, cover, fan, etc.) "
|
||||
"or by area name (living room, kitchen, bedroom, etc.)."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"domain": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Entity domain to filter by (e.g. 'light', 'switch', 'climate', "
|
||||
"'sensor', 'binary_sensor', 'cover', 'fan', 'media_player'). "
|
||||
"Omit to list all entities."
|
||||
),
|
||||
},
|
||||
"area": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Area/room name to filter by (e.g. 'living room', 'kitchen'). "
|
||||
"Matches against entity friendly names. Omit to list all."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
}
|
||||
|
||||
HA_GET_STATE_SCHEMA = {
|
||||
"name": "ha_get_state",
|
||||
"description": (
|
||||
"Get the detailed state of a single Home Assistant entity, including all "
|
||||
"attributes (brightness, color, temperature setpoint, sensor readings, etc.)."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"entity_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The entity ID to query (e.g. 'light.living_room', "
|
||||
"'climate.thermostat', 'sensor.temperature')."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["entity_id"],
|
||||
},
|
||||
}
|
||||
|
||||
HA_CALL_SERVICE_SCHEMA = {
|
||||
"name": "ha_call_service",
|
||||
"description": (
|
||||
"Call a Home Assistant service to control a device. Common examples: "
|
||||
"turn_on/turn_off lights and switches, set_temperature for climate, "
|
||||
"open_cover/close_cover for blinds, set_volume_level for media players."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"domain": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Service domain (e.g. 'light', 'switch', 'climate', "
|
||||
"'cover', 'media_player', 'fan', 'scene', 'script')."
|
||||
),
|
||||
},
|
||||
"service": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Service name (e.g. 'turn_on', 'turn_off', 'toggle', "
|
||||
"'set_temperature', 'set_hvac_mode', 'open_cover', "
|
||||
"'close_cover', 'set_volume_level')."
|
||||
),
|
||||
},
|
||||
"entity_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Target entity ID (e.g. 'light.living_room'). "
|
||||
"Some services (like scene.turn_on) may not need this."
|
||||
),
|
||||
},
|
||||
"data": {
|
||||
"type": "object",
|
||||
"description": (
|
||||
"Additional service data. Examples: "
|
||||
'{"brightness": 255, "color_name": "blue"} for lights, '
|
||||
'{"temperature": 22, "hvac_mode": "heat"} for climate, '
|
||||
'{"volume_level": 0.5} for media players.'
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": ["domain", "service"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from tools.registry import registry
|
||||
|
||||
registry.register(
|
||||
name="ha_list_entities",
|
||||
toolset="homeassistant",
|
||||
schema=HA_LIST_ENTITIES_SCHEMA,
|
||||
handler=_handle_list_entities,
|
||||
check_fn=_check_ha_available,
|
||||
)
|
||||
|
||||
registry.register(
|
||||
name="ha_get_state",
|
||||
toolset="homeassistant",
|
||||
schema=HA_GET_STATE_SCHEMA,
|
||||
handler=_handle_get_state,
|
||||
check_fn=_check_ha_available,
|
||||
)
|
||||
|
||||
registry.register(
|
||||
name="ha_call_service",
|
||||
toolset="homeassistant",
|
||||
schema=HA_CALL_SERVICE_SCHEMA,
|
||||
handler=_handle_call_service,
|
||||
check_fn=_check_ha_available,
|
||||
)
|
||||
20
toolsets.py
20
toolsets.py
@@ -62,6 +62,8 @@ _HERMES_CORE_TOOLS = [
|
||||
"send_message",
|
||||
# Honcho user context (gated on honcho being active via check_fn)
|
||||
"query_user_context",
|
||||
# Home Assistant smart home control (gated on HASS_TOKEN via check_fn)
|
||||
"ha_list_entities", "ha_get_state", "ha_call_service",
|
||||
]
|
||||
|
||||
|
||||
@@ -193,8 +195,14 @@ TOOLSETS = {
|
||||
"tools": ["query_user_context"],
|
||||
"includes": []
|
||||
},
|
||||
|
||||
|
||||
|
||||
"homeassistant": {
|
||||
"description": "Home Assistant smart home control and monitoring",
|
||||
"tools": ["ha_list_entities", "ha_get_state", "ha_call_service"],
|
||||
"includes": []
|
||||
},
|
||||
|
||||
|
||||
# Scenario-specific toolsets
|
||||
|
||||
"debugging": {
|
||||
@@ -247,10 +255,16 @@ TOOLSETS = {
|
||||
"includes": []
|
||||
},
|
||||
|
||||
"hermes-homeassistant": {
|
||||
"description": "Home Assistant bot toolset - smart home event monitoring and control",
|
||||
"tools": _HERMES_CORE_TOOLS,
|
||||
"includes": []
|
||||
},
|
||||
|
||||
"hermes-gateway": {
|
||||
"description": "Gateway toolset - union of all messaging platform tools",
|
||||
"tools": [],
|
||||
"includes": ["hermes-telegram", "hermes-discord", "hermes-whatsapp", "hermes-slack"]
|
||||
"includes": ["hermes-telegram", "hermes-discord", "hermes-whatsapp", "hermes-slack", "hermes-homeassistant"]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
7
uv.lock
generated
7
uv.lock
generated
@@ -1034,6 +1034,9 @@ dev = [
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-asyncio" },
|
||||
]
|
||||
homeassistant = [
|
||||
{ name = "aiohttp" },
|
||||
]
|
||||
honcho = [
|
||||
{ name = "honcho-ai" },
|
||||
]
|
||||
@@ -1060,6 +1063,7 @@ tts-premium = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "aiohttp", marker = "extra == 'homeassistant'", specifier = ">=3.9.0" },
|
||||
{ name = "aiohttp", marker = "extra == 'messaging'", specifier = ">=3.9.0" },
|
||||
{ name = "croniter", marker = "extra == 'cron'" },
|
||||
{ name = "discord-py", marker = "extra == 'messaging'", specifier = ">=2.0" },
|
||||
@@ -1071,6 +1075,7 @@ requires-dist = [
|
||||
{ name = "hermes-agent", extras = ["cli"], marker = "extra == 'all'" },
|
||||
{ name = "hermes-agent", extras = ["cron"], marker = "extra == 'all'" },
|
||||
{ name = "hermes-agent", extras = ["dev"], marker = "extra == 'all'" },
|
||||
{ name = "hermes-agent", extras = ["homeassistant"], marker = "extra == 'all'" },
|
||||
{ name = "hermes-agent", extras = ["honcho"], marker = "extra == 'all'" },
|
||||
{ name = "hermes-agent", extras = ["messaging"], marker = "extra == 'all'" },
|
||||
{ name = "hermes-agent", extras = ["modal"], marker = "extra == 'all'" },
|
||||
@@ -1103,7 +1108,7 @@ requires-dist = [
|
||||
{ name = "tenacity" },
|
||||
{ name = "typer" },
|
||||
]
|
||||
provides-extras = ["modal", "dev", "messaging", "cron", "slack", "cli", "tts-premium", "pty", "honcho", "all"]
|
||||
provides-extras = ["modal", "dev", "messaging", "cron", "slack", "cli", "tts-premium", "pty", "honcho", "homeassistant", "all"]
|
||||
|
||||
[[package]]
|
||||
name = "hf-xet"
|
||||
|
||||
Reference in New Issue
Block a user