feat: enhance Home Assistant integration with service discovery and setup

Improvements to the HA integration merged from PR #184:

- Add ha_list_services tool: discovers available services (actions) per
  domain with descriptions and parameter fields. Tells the model what
  it can do with each device type (e.g. light.turn_on accepts brightness,
  color_name, transition). Closes the gap where the model had to guess
  available actions.

- Add HA to hermes tools config: users can enable/disable the homeassistant
  toolset and configure HASS_TOKEN + HASS_URL through 'hermes tools' setup
  flow instead of manually editing .env.

- Fix should-fix items from code review:
  - Remove sys.path.insert hack from gateway adapter
  - Replace all print() calls with proper logger (info/warning/error)
  - Move env var reads from import-time to handler-time via _get_config()
  - Add dedicated REST session reuse in gateway send()

- Update ha_call_service description to reference ha_list_services for
  action discovery.

- Update tests for new ha_list_services tool in toolset resolution.
This commit is contained in:
teknium1
2026-03-03 05:16:53 -08:00
parent db0521ce0e
commit ffec21236d
5 changed files with 145 additions and 32 deletions

View File

@@ -28,10 +28,6 @@ except ImportError:
AIOHTTP_AVAILABLE = False
aiohttp = None # type: ignore[assignment]
import sys
from pathlib import Path as _Path
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import (
BasePlatformAdapter,
@@ -72,6 +68,7 @@ class HomeAssistantAdapter(BasePlatformAdapter):
# Connection state
self._session: Optional["aiohttp.ClientSession"] = None
self._ws: Optional["aiohttp.ClientWebSocketResponse"] = None
self._rest_session: Optional["aiohttp.ClientSession"] = None
self._listen_task: Optional[asyncio.Task] = None
self._msg_id: int = 0
@@ -103,11 +100,11 @@ class HomeAssistantAdapter(BasePlatformAdapter):
async def connect(self) -> bool:
"""Connect to HA WebSocket API and subscribe to events."""
if not AIOHTTP_AVAILABLE:
print(f"[{self.name}] aiohttp not installed. Run: pip install aiohttp")
logger.warning("[%s] aiohttp not installed. Run: pip install aiohttp", self.name)
return False
if not self._hass_token:
print(f"[{self.name}] No HASS_TOKEN configured")
logger.warning("[%s] No HASS_TOKEN configured", self.name)
return False
try:
@@ -115,14 +112,17 @@ class HomeAssistantAdapter(BasePlatformAdapter):
if not success:
return False
# Dedicated REST session for send() calls
self._rest_session = aiohttp.ClientSession()
# Start background listener
self._listen_task = asyncio.create_task(self._listen_loop())
self._running = True
print(f"[{self.name}] Connected to {self._hass_url}")
logger.info("[%s] Connected to %s", self.name, self._hass_url)
return True
except Exception as e:
print(f"[{self.name}] Failed to connect: {e}")
logger.error("[%s] Failed to connect: %s", self.name, e)
return False
async def _ws_connect(self) -> bool:
@@ -191,7 +191,10 @@ class HomeAssistantAdapter(BasePlatformAdapter):
self._listen_task = None
await self._cleanup_ws()
print(f"[{self.name}] Disconnected")
if self._rest_session and not self._rest_session.closed:
await self._rest_session.close()
self._rest_session = None
logger.info("[%s] Disconnected", self.name)
# ------------------------------------------------------------------
# Event listener
@@ -214,7 +217,7 @@ class HomeAssistantAdapter(BasePlatformAdapter):
# Reconnect with backoff
delay = self._BACKOFF_STEPS[min(backoff_idx, len(self._BACKOFF_STEPS) - 1)]
print(f"[{self.name}] Reconnecting in {delay}s...")
logger.info("[%s] Reconnecting in %ds...", self.name, delay)
await asyncio.sleep(delay)
backoff_idx += 1
@@ -223,7 +226,7 @@ class HomeAssistantAdapter(BasePlatformAdapter):
success = await self._ws_connect()
if success:
backoff_idx = 0 # Reset on successful reconnect
print(f"[{self.name}] Reconnected")
logger.info("[%s] Reconnected", self.name)
except Exception as e:
logger.warning("[%s] Reconnection failed: %s", self.name, e)
@@ -385,8 +388,8 @@ class HomeAssistantAdapter(BasePlatformAdapter):
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(
if self._rest_session:
async with self._rest_session.post(
url,
headers=headers,
json=payload,
@@ -397,6 +400,19 @@ class HomeAssistantAdapter(BasePlatformAdapter):
else:
body = await resp.text()
return SendResult(success=False, error=f"HTTP {resp.status}: {body}")
else:
async with aiohttp.ClientSession() as session:
async with session.post(
url,
headers=headers,
json=payload,
timeout=aiohttp.ClientTimeout(total=10),
) as resp:
if resp.status < 300:
return SendResult(success=True, message_id=uuid.uuid4().hex[:12])
else:
body = await resp.text()
return SendResult(success=False, error=f"HTTP {resp.status}: {body}")
except asyncio.TimeoutError:
return SendResult(success=False, error="Timeout sending notification to HA")

View File

@@ -36,6 +36,7 @@ CONFIGURABLE_TOOLSETS = [
("delegation", "👥 Task Delegation", "delegate_task"),
("cronjob", "⏰ Cron Jobs", "schedule, list, remove"),
("rl", "🧪 RL Training", "Tinker-Atropos training tools"),
("homeassistant", "🏠 Home Assistant", "smart home device control"),
]
# Platform display config
@@ -312,6 +313,8 @@ TOOLSET_ENV_REQUIREMENTS = {
"tts": [], # Edge TTS is free, no key needed
"rl": [("TINKER_API_KEY", "https://tinker-console.thinkingmachines.ai/keys"),
("WANDB_API_KEY", "https://wandb.ai/authorize")],
"homeassistant": [("HASS_TOKEN", "Home Assistant > Profile > Long-Lived Access Tokens"),
("HASS_URL", None)],
}

View File

@@ -569,19 +569,19 @@ class TestToolsetIntegration:
from toolsets import resolve_toolset
tools = resolve_toolset("homeassistant")
assert set(tools) == {"ha_list_entities", "ha_get_state", "ha_call_service"}
assert set(tools) == {"ha_list_entities", "ha_get_state", "ha_call_service", "ha_list_services"}
def test_gateway_toolset_includes_ha_tools(self):
from toolsets import resolve_toolset
gateway_tools = resolve_toolset("hermes-gateway")
for tool in ("ha_list_entities", "ha_get_state", "ha_call_service"):
for tool in ("ha_list_entities", "ha_get_state", "ha_call_service", "ha_list_services"):
assert tool in gateway_tools
def test_hermes_core_tools_includes_ha(self):
from toolsets import _HERMES_CORE_TOOLS
for tool in ("ha_list_entities", "ha_get_state", "ha_call_service"):
for tool in ("ha_list_entities", "ha_get_state", "ha_call_service", "ha_list_services"):
assert tool in _HERMES_CORE_TOOLS

View File

@@ -1,8 +1,9 @@
"""Home Assistant tool for controlling smart home devices via REST API.
Registers three LLM-callable tools:
Registers four LLM-callable tools:
- ``ha_list_entities`` -- list/filter entities by domain or area
- ``ha_get_state`` -- get detailed state of a single entity
- ``ha_list_services`` -- list available services (actions) per domain
- ``ha_call_service`` -- call a HA service (turn_on, turn_off, set_temperature, etc.)
Authentication uses a Long-Lived Access Token via ``HASS_TOKEN`` env var.
@@ -22,8 +23,17 @@ logger = logging.getLogger(__name__)
# Configuration
# ---------------------------------------------------------------------------
_HASS_URL: str = os.getenv("HASS_URL", "http://homeassistant.local:8123").rstrip("/")
_HASS_TOKEN: str = os.getenv("HASS_TOKEN", "")
# Kept for backward compatibility (e.g. test monkeypatching); prefer _get_config().
_HASS_URL: str = ""
_HASS_TOKEN: str = ""
def _get_config():
"""Return (hass_url, hass_token) from env vars at call time."""
return (
(_HASS_URL or os.getenv("HASS_URL", "http://homeassistant.local:8123")).rstrip("/"),
_HASS_TOKEN or os.getenv("HASS_TOKEN", ""),
)
# Regex for valid HA entity_id format (e.g. "light.living_room", "sensor.temperature_1")
_ENTITY_ID_RE = re.compile(r"^[a-z_][a-z0-9_]*\.[a-z0-9_]+$")
@@ -41,10 +51,12 @@ _BLOCKED_DOMAINS = frozenset({
})
def _get_headers() -> Dict[str, str]:
def _get_headers(token: str = "") -> Dict[str, str]:
"""Return authorization headers for HA REST API."""
if not token:
_, token = _get_config()
return {
"Authorization": f"Bearer {_HASS_TOKEN}",
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
@@ -88,9 +100,10 @@ async def _async_list_entities(
"""Fetch entity states from HA and optionally filter by domain/area."""
import aiohttp
url = f"{_HASS_URL}/api/states"
hass_url, hass_token = _get_config()
url = f"{hass_url}/api/states"
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=_get_headers(), timeout=aiohttp.ClientTimeout(total=15)) as resp:
async with session.get(url, headers=_get_headers(hass_token), timeout=aiohttp.ClientTimeout(total=15)) as resp:
resp.raise_for_status()
states = await resp.json()
@@ -101,9 +114,10 @@ async def _async_get_state(entity_id: str) -> Dict[str, Any]:
"""Fetch detailed state of a single entity."""
import aiohttp
url = f"{_HASS_URL}/api/states/{entity_id}"
hass_url, hass_token = _get_config()
url = f"{hass_url}/api/states/{entity_id}"
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=_get_headers(), timeout=aiohttp.ClientTimeout(total=10)) as resp:
async with session.get(url, headers=_get_headers(hass_token), timeout=aiohttp.ClientTimeout(total=10)) as resp:
resp.raise_for_status()
data = await resp.json()
@@ -160,13 +174,14 @@ async def _async_call_service(
"""Call a Home Assistant service."""
import aiohttp
url = f"{_HASS_URL}/api/services/{domain}/{service}"
hass_url, hass_token = _get_config()
url = f"{hass_url}/api/services/{domain}/{service}"
payload = _build_service_payload(entity_id, data)
async with aiohttp.ClientSession() as session:
async with session.post(
url,
headers=_get_headers(),
headers=_get_headers(hass_token),
json=payload,
timeout=aiohttp.ClientTimeout(total=15),
) as resp:
@@ -250,6 +265,55 @@ def _handle_call_service(args: dict, **kw) -> str:
return json.dumps({"error": f"Failed to call {domain}.{service}: {e}"})
# ---------------------------------------------------------------------------
# List services
# ---------------------------------------------------------------------------
async def _async_list_services(domain: Optional[str] = None) -> Dict[str, Any]:
"""Fetch available services from HA and optionally filter by domain."""
import aiohttp
hass_url, hass_token = _get_config()
url = f"{hass_url}/api/services"
headers = {"Authorization": f"Bearer {hass_token}", "Content-Type": "application/json"}
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=15)) as resp:
resp.raise_for_status()
services = await resp.json()
if domain:
services = [s for s in services if s.get("domain") == domain]
# Compact the output for context efficiency
result = []
for svc_domain in services:
d = svc_domain.get("domain", "")
domain_services = {}
for svc_name, svc_info in svc_domain.get("services", {}).items():
svc_entry: Dict[str, Any] = {"description": svc_info.get("description", "")}
fields = svc_info.get("fields", {})
if fields:
svc_entry["fields"] = {
k: v.get("description", "") for k, v in fields.items()
if isinstance(v, dict)
}
domain_services[svc_name] = svc_entry
result.append({"domain": d, "services": domain_services})
return {"count": len(result), "domains": result}
def _handle_list_services(args: dict, **kw) -> str:
"""Handler for ha_list_services tool."""
domain = args.get("domain")
try:
result = _run_async(_async_list_services(domain=domain))
return json.dumps({"result": result})
except Exception as e:
logger.error("ha_list_services error: %s", e)
return json.dumps({"error": f"Failed to list services: {e}"})
# ---------------------------------------------------------------------------
# Availability check
# ---------------------------------------------------------------------------
@@ -314,12 +378,34 @@ HA_GET_STATE_SCHEMA = {
},
}
HA_LIST_SERVICES_SCHEMA = {
"name": "ha_list_services",
"description": (
"List available Home Assistant services (actions) for device control. "
"Shows what actions can be performed on each device type and what "
"parameters they accept. Use this to discover how to control devices "
"found via ha_list_entities."
),
"parameters": {
"type": "object",
"properties": {
"domain": {
"type": "string",
"description": (
"Filter by domain (e.g. 'light', 'climate', 'switch'). "
"Omit to list services for all domains."
),
},
},
"required": [],
},
}
HA_CALL_SERVICE_SCHEMA = {
"name": "ha_call_service",
"description": (
"Call a Home Assistant service to control a device. Common examples: "
"turn_on/turn_off lights and switches, set_temperature for climate, "
"open_cover/close_cover for blinds, set_volume_level for media players."
"Call a Home Assistant service to control a device. Use ha_list_services "
"to discover available services and their parameters for each domain."
),
"parameters": {
"type": "object",
@@ -383,6 +469,14 @@ registry.register(
check_fn=_check_ha_available,
)
registry.register(
name="ha_list_services",
toolset="homeassistant",
schema=HA_LIST_SERVICES_SCHEMA,
handler=_handle_list_services,
check_fn=_check_ha_available,
)
registry.register(
name="ha_call_service",
toolset="homeassistant",

View File

@@ -63,7 +63,7 @@ _HERMES_CORE_TOOLS = [
# Honcho user context (gated on honcho being active via check_fn)
"query_user_context",
# Home Assistant smart home control (gated on HASS_TOKEN via check_fn)
"ha_list_entities", "ha_get_state", "ha_call_service",
"ha_list_entities", "ha_get_state", "ha_list_services", "ha_call_service",
]
@@ -198,7 +198,7 @@ TOOLSETS = {
"homeassistant": {
"description": "Home Assistant smart home control and monitoring",
"tools": ["ha_list_entities", "ha_get_state", "ha_call_service"],
"tools": ["ha_list_entities", "ha_get_state", "ha_list_services", "ha_call_service"],
"includes": []
},