diff --git a/.env.example b/.env.example index 638610e4d..d273a6966 100644 --- a/.env.example +++ b/.env.example @@ -65,10 +65,15 @@ OPENCODE_GO_API_KEY= # TOOL API KEYS # ============================================================================= +# Parallel API Key - AI-native web search and extract +# Get at: https://parallel.ai +PARALLEL_API_KEY= + # Firecrawl API Key - Web search, extract, and crawl # Get at: https://firecrawl.dev/ FIRECRAWL_API_KEY= + # FAL.ai API Key - Image generation # Get at: https://fal.ai/ FAL_KEY= diff --git a/AGENTS.md b/AGENTS.md index c1fa098bf..13998fe1d 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -44,7 +44,7 @@ hermes-agent/ │ ├── terminal_tool.py # Terminal orchestration │ ├── process_registry.py # Background process management │ ├── file_tools.py # File read/write/search/patch -│ ├── web_tools.py # Firecrawl search/extract +│ ├── web_tools.py # Web search/extract (Parallel + Firecrawl) │ ├── browser_tool.py # Browserbase browser automation │ ├── code_execution_tool.py # execute_code sandbox │ ├── delegate_tool.py # Subagent delegation diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d866539ad..25cddde6e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -147,7 +147,7 @@ hermes-agent/ │ ├── approval.py # Dangerous command detection + per-session approval │ ├── terminal_tool.py # Terminal orchestration (sudo, env lifecycle, backends) │ ├── file_operations.py # read_file, write_file, search, patch, etc. -│ ├── web_tools.py # web_search, web_extract (Firecrawl + Gemini summarization) +│ ├── web_tools.py # web_search, web_extract (Parallel/Firecrawl + Gemini summarization) │ ├── vision_tools.py # Image analysis via multimodal models │ ├── delegate_tool.py # Subagent spawning and parallel task execution │ ├── code_execution_tool.py # Sandboxed Python with RPC tool access diff --git a/agent/anthropic_adapter.py b/agent/anthropic_adapter.py index 3e1bd85bb..30958f0f5 100644 --- a/agent/anthropic_adapter.py +++ b/agent/anthropic_adapter.py @@ -963,8 +963,12 @@ def convert_messages_to_anthropic( elif isinstance(prev_blocks, str) and isinstance(curr_blocks, str): fixed[-1]["content"] = prev_blocks + "\n" + curr_blocks else: - # Keep the later message - fixed[-1] = m + # Mixed types — normalize both to list and merge + if isinstance(prev_blocks, str): + prev_blocks = [{"type": "text", "text": prev_blocks}] + if isinstance(curr_blocks, str): + curr_blocks = [{"type": "text", "text": curr_blocks}] + fixed[-1]["content"] = prev_blocks + curr_blocks else: fixed.append(m) result = fixed @@ -1049,7 +1053,8 @@ def build_anthropic_kwargs( elif tool_choice == "required": kwargs["tool_choice"] = {"type": "any"} elif tool_choice == "none": - pass # Don't send tool_choice — Anthropic will use tools if needed + # Anthropic has no tool_choice "none" — omit tools entirely to prevent use + kwargs.pop("tools", None) elif isinstance(tool_choice, str): # Specific tool name kwargs["tool_choice"] = {"type": "tool", "name": tool_choice} diff --git a/agent/auxiliary_client.py b/agent/auxiliary_client.py index d008361b5..a0807d8ab 100644 --- a/agent/auxiliary_client.py +++ b/agent/auxiliary_client.py @@ -39,6 +39,7 @@ custom OpenAI-compatible endpoint without touching the main model settings. import json import logging import os +import threading from pathlib import Path from types import SimpleNamespace from typing import Any, Dict, List, Optional, Tuple @@ -705,6 +706,8 @@ def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[st def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]: """Full auto-detection chain: OpenRouter → Nous → custom → Codex → API-key → None.""" + global auxiliary_is_nous + auxiliary_is_nous = False # Reset — _try_nous() will set True if it wins for try_fn in (_try_openrouter, _try_nous, _try_custom_endpoint, _try_codex, _resolve_api_key_provider): client, model = try_fn() @@ -1171,6 +1174,7 @@ def auxiliary_max_tokens_param(value: int) -> dict: # Client cache: (provider, async_mode, base_url, api_key) -> (client, default_model) _client_cache: Dict[tuple, tuple] = {} +_client_cache_lock = threading.Lock() def _get_cached_client( @@ -1182,9 +1186,11 @@ def _get_cached_client( ) -> Tuple[Optional[Any], Optional[str]]: """Get or create a cached client for the given provider.""" cache_key = (provider, async_mode, base_url or "", api_key or "") - if cache_key in _client_cache: - cached_client, cached_default = _client_cache[cache_key] - return cached_client, model or cached_default + with _client_cache_lock: + if cache_key in _client_cache: + cached_client, cached_default = _client_cache[cache_key] + return cached_client, model or cached_default + # Build outside the lock client, default_model = resolve_provider_client( provider, model, @@ -1193,7 +1199,11 @@ def _get_cached_client( explicit_api_key=api_key, ) if client is not None: - _client_cache[cache_key] = (client, default_model) + with _client_cache_lock: + if cache_key not in _client_cache: + _client_cache[cache_key] = (client, default_model) + else: + client, default_model = _client_cache[cache_key] return client, model or default_model diff --git a/agent/context_compressor.py b/agent/context_compressor.py index aa05a8daa..22ce32f34 100644 --- a/agent/context_compressor.py +++ b/agent/context_compressor.py @@ -313,7 +313,19 @@ Write only the summary body. Do not include any preamble or prefix; the system w if summary: last_head_role = messages[compress_start - 1].get("role", "user") if compress_start > 0 else "user" - summary_role = "user" if last_head_role in ("assistant", "tool") else "assistant" + first_tail_role = messages[compress_end].get("role", "user") if compress_end < n_messages else "user" + # Pick a role that avoids consecutive same-role with both neighbors. + # Priority: avoid colliding with head (already committed), then tail. + if last_head_role in ("assistant", "tool"): + summary_role = "user" + else: + summary_role = "assistant" + # If the chosen role collides with the tail AND flipping wouldn't + # collide with the head, flip it. + if summary_role == first_tail_role: + flipped = "assistant" if summary_role == "user" else "user" + if flipped != last_head_role: + summary_role = flipped compressed.append({"role": summary_role, "content": summary}) else: if not self.quiet_mode: diff --git a/agent/insights.py b/agent/insights.py index 8fc55e043..64a37f11b 100644 --- a/agent/insights.py +++ b/agent/insights.py @@ -22,14 +22,21 @@ from collections import Counter, defaultdict from datetime import datetime from typing import Any, Dict, List -from agent.usage_pricing import DEFAULT_PRICING, estimate_cost_usd, format_duration_compact, get_pricing, has_known_pricing +from agent.usage_pricing import ( + CanonicalUsage, + DEFAULT_PRICING, + estimate_usage_cost, + format_duration_compact, + get_pricing, + has_known_pricing, +) _DEFAULT_PRICING = DEFAULT_PRICING -def _has_known_pricing(model_name: str) -> bool: +def _has_known_pricing(model_name: str, provider: str = None, base_url: str = None) -> bool: """Check if a model has known pricing (vs unknown/custom endpoint).""" - return has_known_pricing(model_name) + return has_known_pricing(model_name, provider=provider, base_url=base_url) def _get_pricing(model_name: str) -> Dict[str, float]: @@ -41,9 +48,43 @@ def _get_pricing(model_name: str) -> Dict[str, float]: return get_pricing(model_name) -def _estimate_cost(model: str, input_tokens: int, output_tokens: int) -> float: - """Estimate the USD cost for a given model and token counts.""" - return estimate_cost_usd(model, input_tokens, output_tokens) +def _estimate_cost( + session_or_model: Dict[str, Any] | str, + input_tokens: int = 0, + output_tokens: int = 0, + *, + cache_read_tokens: int = 0, + cache_write_tokens: int = 0, + provider: str = None, + base_url: str = None, +) -> tuple[float, str]: + """Estimate the USD cost for a session row or a model/token tuple.""" + if isinstance(session_or_model, dict): + session = session_or_model + model = session.get("model") or "" + usage = CanonicalUsage( + input_tokens=session.get("input_tokens") or 0, + output_tokens=session.get("output_tokens") or 0, + cache_read_tokens=session.get("cache_read_tokens") or 0, + cache_write_tokens=session.get("cache_write_tokens") or 0, + ) + provider = session.get("billing_provider") + base_url = session.get("billing_base_url") + else: + model = session_or_model or "" + usage = CanonicalUsage( + input_tokens=input_tokens, + output_tokens=output_tokens, + cache_read_tokens=cache_read_tokens, + cache_write_tokens=cache_write_tokens, + ) + result = estimate_usage_cost( + model, + usage, + provider=provider, + base_url=base_url, + ) + return float(result.amount_usd or 0.0), result.status def _format_duration(seconds: float) -> str: @@ -135,7 +176,10 @@ class InsightsEngine: # Columns we actually need (skip system_prompt, model_config blobs) _SESSION_COLS = ("id, source, model, started_at, ended_at, " - "message_count, tool_call_count, input_tokens, output_tokens") + "message_count, tool_call_count, input_tokens, output_tokens, " + "cache_read_tokens, cache_write_tokens, billing_provider, " + "billing_base_url, billing_mode, estimated_cost_usd, " + "actual_cost_usd, cost_status, cost_source") def _get_sessions(self, cutoff: float, source: str = None) -> List[Dict]: """Fetch sessions within the time window.""" @@ -287,21 +331,30 @@ class InsightsEngine: """Compute high-level overview statistics.""" total_input = sum(s.get("input_tokens") or 0 for s in sessions) total_output = sum(s.get("output_tokens") or 0 for s in sessions) - total_tokens = total_input + total_output + total_cache_read = sum(s.get("cache_read_tokens") or 0 for s in sessions) + total_cache_write = sum(s.get("cache_write_tokens") or 0 for s in sessions) + total_tokens = total_input + total_output + total_cache_read + total_cache_write total_tool_calls = sum(s.get("tool_call_count") or 0 for s in sessions) total_messages = sum(s.get("message_count") or 0 for s in sessions) # Cost estimation (weighted by model) total_cost = 0.0 + actual_cost = 0.0 models_with_pricing = set() models_without_pricing = set() + unknown_cost_sessions = 0 + included_cost_sessions = 0 for s in sessions: model = s.get("model") or "" - inp = s.get("input_tokens") or 0 - out = s.get("output_tokens") or 0 - total_cost += _estimate_cost(model, inp, out) + estimated, status = _estimate_cost(s) + total_cost += estimated + actual_cost += s.get("actual_cost_usd") or 0.0 display = model.split("/")[-1] if "/" in model else (model or "unknown") - if _has_known_pricing(model): + if status == "included": + included_cost_sessions += 1 + elif status == "unknown": + unknown_cost_sessions += 1 + if _has_known_pricing(model, s.get("billing_provider"), s.get("billing_base_url")): models_with_pricing.add(display) else: models_without_pricing.add(display) @@ -328,8 +381,11 @@ class InsightsEngine: "total_tool_calls": total_tool_calls, "total_input_tokens": total_input, "total_output_tokens": total_output, + "total_cache_read_tokens": total_cache_read, + "total_cache_write_tokens": total_cache_write, "total_tokens": total_tokens, "estimated_cost": total_cost, + "actual_cost": actual_cost, "total_hours": total_hours, "avg_session_duration": avg_duration, "avg_messages_per_session": total_messages / len(sessions) if sessions else 0, @@ -341,12 +397,15 @@ class InsightsEngine: "date_range_end": date_range_end, "models_with_pricing": sorted(models_with_pricing), "models_without_pricing": sorted(models_without_pricing), + "unknown_cost_sessions": unknown_cost_sessions, + "included_cost_sessions": included_cost_sessions, } def _compute_model_breakdown(self, sessions: List[Dict]) -> List[Dict]: """Break down usage by model.""" model_data = defaultdict(lambda: { "sessions": 0, "input_tokens": 0, "output_tokens": 0, + "cache_read_tokens": 0, "cache_write_tokens": 0, "total_tokens": 0, "tool_calls": 0, "cost": 0.0, }) @@ -358,12 +417,18 @@ class InsightsEngine: d["sessions"] += 1 inp = s.get("input_tokens") or 0 out = s.get("output_tokens") or 0 + cache_read = s.get("cache_read_tokens") or 0 + cache_write = s.get("cache_write_tokens") or 0 d["input_tokens"] += inp d["output_tokens"] += out - d["total_tokens"] += inp + out + d["cache_read_tokens"] += cache_read + d["cache_write_tokens"] += cache_write + d["total_tokens"] += inp + out + cache_read + cache_write d["tool_calls"] += s.get("tool_call_count") or 0 - d["cost"] += _estimate_cost(model, inp, out) - d["has_pricing"] = _has_known_pricing(model) + estimate, status = _estimate_cost(s) + d["cost"] += estimate + d["has_pricing"] = _has_known_pricing(model, s.get("billing_provider"), s.get("billing_base_url")) + d["cost_status"] = status result = [ {"model": model, **data} @@ -377,7 +442,8 @@ class InsightsEngine: """Break down usage by platform/source.""" platform_data = defaultdict(lambda: { "sessions": 0, "messages": 0, "input_tokens": 0, - "output_tokens": 0, "total_tokens": 0, "tool_calls": 0, + "output_tokens": 0, "cache_read_tokens": 0, + "cache_write_tokens": 0, "total_tokens": 0, "tool_calls": 0, }) for s in sessions: @@ -387,9 +453,13 @@ class InsightsEngine: d["messages"] += s.get("message_count") or 0 inp = s.get("input_tokens") or 0 out = s.get("output_tokens") or 0 + cache_read = s.get("cache_read_tokens") or 0 + cache_write = s.get("cache_write_tokens") or 0 d["input_tokens"] += inp d["output_tokens"] += out - d["total_tokens"] += inp + out + d["cache_read_tokens"] += cache_read + d["cache_write_tokens"] += cache_write + d["total_tokens"] += inp + out + cache_read + cache_write d["tool_calls"] += s.get("tool_call_count") or 0 result = [ diff --git a/agent/model_metadata.py b/agent/model_metadata.py index 2f9ea666c..c578acf50 100644 --- a/agent/model_metadata.py +++ b/agent/model_metadata.py @@ -266,8 +266,10 @@ def get_model_context_length(model: str, base_url: str = "") -> int: if model in metadata: return metadata[model].get("context_length", 128000) - # 3. Hardcoded defaults (fuzzy match) - for default_model, length in DEFAULT_CONTEXT_LENGTHS.items(): + # 3. Hardcoded defaults (fuzzy match — longest key first for specificity) + for default_model, length in sorted( + DEFAULT_CONTEXT_LENGTHS.items(), key=lambda x: len(x[0]), reverse=True + ): if default_model in model or model in default_model: return length diff --git a/agent/prompt_builder.py b/agent/prompt_builder.py index 48e23eefb..4ce84473f 100644 --- a/agent/prompt_builder.py +++ b/agent/prompt_builder.py @@ -212,16 +212,15 @@ PLATFORM_HINTS = { "the scheduled destination, put it directly in your final response. Use " "send_message only for additional or different targets." ), - "sms": ( - "You are communicating via SMS text messaging. Keep responses concise " - "and plain text only -- no markdown, no formatting. SMS has a 1600 " - "character limit per message (10 segments). Longer replies are split " - "across multiple messages. Be brief and direct." - ), "cli": ( "You are a CLI AI Agent. Try not to use markdown but simple text " "renderable inside a terminal." ), + "sms": ( + "You are communicating via SMS. Keep responses concise and use plain text " + "only — no markdown, no formatting. SMS messages are limited to ~1600 " + "characters, so be brief and direct." + ), } CONTEXT_FILE_MAX_CHARS = 20_000 diff --git a/agent/usage_pricing.py b/agent/usage_pricing.py index 5bfba25d4..29e7df254 100644 --- a/agent/usage_pricing.py +++ b/agent/usage_pricing.py @@ -1,101 +1,593 @@ from __future__ import annotations +from dataclasses import dataclass +from datetime import datetime, timezone from decimal import Decimal -from typing import Dict +from typing import Any, Dict, Literal, Optional - -MODEL_PRICING = { - "gpt-4o": {"input": 2.50, "output": 10.00}, - "gpt-4o-mini": {"input": 0.15, "output": 0.60}, - "gpt-4.1": {"input": 2.00, "output": 8.00}, - "gpt-4.1-mini": {"input": 0.40, "output": 1.60}, - "gpt-4.1-nano": {"input": 0.10, "output": 0.40}, - "gpt-4.5-preview": {"input": 75.00, "output": 150.00}, - "gpt-5": {"input": 10.00, "output": 30.00}, - "gpt-5.4": {"input": 10.00, "output": 30.00}, - "o3": {"input": 10.00, "output": 40.00}, - "o3-mini": {"input": 1.10, "output": 4.40}, - "o4-mini": {"input": 1.10, "output": 4.40}, - "claude-opus-4-20250514": {"input": 15.00, "output": 75.00}, - "claude-sonnet-4-20250514": {"input": 3.00, "output": 15.00}, - "claude-3-5-sonnet-20241022": {"input": 3.00, "output": 15.00}, - "claude-3-5-haiku-20241022": {"input": 0.80, "output": 4.00}, - "claude-3-opus-20240229": {"input": 15.00, "output": 75.00}, - "claude-3-haiku-20240307": {"input": 0.25, "output": 1.25}, - "deepseek-chat": {"input": 0.14, "output": 0.28}, - "deepseek-reasoner": {"input": 0.55, "output": 2.19}, - "gemini-2.5-pro": {"input": 1.25, "output": 10.00}, - "gemini-2.5-flash": {"input": 0.15, "output": 0.60}, - "gemini-2.0-flash": {"input": 0.10, "output": 0.40}, - "llama-4-maverick": {"input": 0.50, "output": 0.70}, - "llama-4-scout": {"input": 0.20, "output": 0.30}, - "glm-5": {"input": 0.0, "output": 0.0}, - "glm-4.7": {"input": 0.0, "output": 0.0}, - "glm-4.5": {"input": 0.0, "output": 0.0}, - "glm-4.5-flash": {"input": 0.0, "output": 0.0}, - "kimi-k2.5": {"input": 0.0, "output": 0.0}, - "kimi-k2-thinking": {"input": 0.0, "output": 0.0}, - "kimi-k2-turbo-preview": {"input": 0.0, "output": 0.0}, - "kimi-k2-0905-preview": {"input": 0.0, "output": 0.0}, - "MiniMax-M2.5": {"input": 0.0, "output": 0.0}, - "MiniMax-M2.5-highspeed": {"input": 0.0, "output": 0.0}, - "MiniMax-M2.1": {"input": 0.0, "output": 0.0}, -} +from agent.model_metadata import fetch_model_metadata DEFAULT_PRICING = {"input": 0.0, "output": 0.0} +_ZERO = Decimal("0") +_ONE_MILLION = Decimal("1000000") -def get_pricing(model_name: str) -> Dict[str, float]: - if not model_name: - return DEFAULT_PRICING - - bare = model_name.split("/")[-1].lower() - if bare in MODEL_PRICING: - return MODEL_PRICING[bare] - - best_match = None - best_len = 0 - for key, price in MODEL_PRICING.items(): - if bare.startswith(key) and len(key) > best_len: - best_match = price - best_len = len(key) - if best_match: - return best_match - - if "opus" in bare: - return {"input": 15.00, "output": 75.00} - if "sonnet" in bare: - return {"input": 3.00, "output": 15.00} - if "haiku" in bare: - return {"input": 0.80, "output": 4.00} - if "gpt-4o-mini" in bare: - return {"input": 0.15, "output": 0.60} - if "gpt-4o" in bare: - return {"input": 2.50, "output": 10.00} - if "gpt-5" in bare: - return {"input": 10.00, "output": 30.00} - if "deepseek" in bare: - return {"input": 0.14, "output": 0.28} - if "gemini" in bare: - return {"input": 0.15, "output": 0.60} - - return DEFAULT_PRICING +CostStatus = Literal["actual", "estimated", "included", "unknown"] +CostSource = Literal[ + "provider_cost_api", + "provider_generation_api", + "provider_models_api", + "official_docs_snapshot", + "user_override", + "custom_contract", + "none", +] -def has_known_pricing(model_name: str) -> bool: - pricing = get_pricing(model_name) - return pricing is not DEFAULT_PRICING and any( - float(value) > 0 for value in pricing.values() +@dataclass(frozen=True) +class CanonicalUsage: + input_tokens: int = 0 + output_tokens: int = 0 + cache_read_tokens: int = 0 + cache_write_tokens: int = 0 + reasoning_tokens: int = 0 + request_count: int = 1 + raw_usage: Optional[dict[str, Any]] = None + + @property + def prompt_tokens(self) -> int: + return self.input_tokens + self.cache_read_tokens + self.cache_write_tokens + + @property + def total_tokens(self) -> int: + return self.prompt_tokens + self.output_tokens + + +@dataclass(frozen=True) +class BillingRoute: + provider: str + model: str + base_url: str = "" + billing_mode: str = "unknown" + + +@dataclass(frozen=True) +class PricingEntry: + input_cost_per_million: Optional[Decimal] = None + output_cost_per_million: Optional[Decimal] = None + cache_read_cost_per_million: Optional[Decimal] = None + cache_write_cost_per_million: Optional[Decimal] = None + request_cost: Optional[Decimal] = None + source: CostSource = "none" + source_url: Optional[str] = None + pricing_version: Optional[str] = None + fetched_at: Optional[datetime] = None + + +@dataclass(frozen=True) +class CostResult: + amount_usd: Optional[Decimal] + status: CostStatus + source: CostSource + label: str + fetched_at: Optional[datetime] = None + pricing_version: Optional[str] = None + notes: tuple[str, ...] = () + + +_UTC_NOW = lambda: datetime.now(timezone.utc) + + +# Official docs snapshot entries. Models whose published pricing and cache +# semantics are stable enough to encode exactly. +_OFFICIAL_DOCS_PRICING: Dict[tuple[str, str], PricingEntry] = { + ( + "anthropic", + "claude-opus-4-20250514", + ): PricingEntry( + input_cost_per_million=Decimal("15.00"), + output_cost_per_million=Decimal("75.00"), + cache_read_cost_per_million=Decimal("1.50"), + cache_write_cost_per_million=Decimal("18.75"), + source="official_docs_snapshot", + source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching", + pricing_version="anthropic-prompt-caching-2026-03-16", + ), + ( + "anthropic", + "claude-sonnet-4-20250514", + ): PricingEntry( + input_cost_per_million=Decimal("3.00"), + output_cost_per_million=Decimal("15.00"), + cache_read_cost_per_million=Decimal("0.30"), + cache_write_cost_per_million=Decimal("3.75"), + source="official_docs_snapshot", + source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching", + pricing_version="anthropic-prompt-caching-2026-03-16", + ), + # OpenAI + ( + "openai", + "gpt-4o", + ): PricingEntry( + input_cost_per_million=Decimal("2.50"), + output_cost_per_million=Decimal("10.00"), + cache_read_cost_per_million=Decimal("1.25"), + source="official_docs_snapshot", + source_url="https://openai.com/api/pricing/", + pricing_version="openai-pricing-2026-03-16", + ), + ( + "openai", + "gpt-4o-mini", + ): PricingEntry( + input_cost_per_million=Decimal("0.15"), + output_cost_per_million=Decimal("0.60"), + cache_read_cost_per_million=Decimal("0.075"), + source="official_docs_snapshot", + source_url="https://openai.com/api/pricing/", + pricing_version="openai-pricing-2026-03-16", + ), + ( + "openai", + "gpt-4.1", + ): PricingEntry( + input_cost_per_million=Decimal("2.00"), + output_cost_per_million=Decimal("8.00"), + cache_read_cost_per_million=Decimal("0.50"), + source="official_docs_snapshot", + source_url="https://openai.com/api/pricing/", + pricing_version="openai-pricing-2026-03-16", + ), + ( + "openai", + "gpt-4.1-mini", + ): PricingEntry( + input_cost_per_million=Decimal("0.40"), + output_cost_per_million=Decimal("1.60"), + cache_read_cost_per_million=Decimal("0.10"), + source="official_docs_snapshot", + source_url="https://openai.com/api/pricing/", + pricing_version="openai-pricing-2026-03-16", + ), + ( + "openai", + "gpt-4.1-nano", + ): PricingEntry( + input_cost_per_million=Decimal("0.10"), + output_cost_per_million=Decimal("0.40"), + cache_read_cost_per_million=Decimal("0.025"), + source="official_docs_snapshot", + source_url="https://openai.com/api/pricing/", + pricing_version="openai-pricing-2026-03-16", + ), + ( + "openai", + "o3", + ): PricingEntry( + input_cost_per_million=Decimal("10.00"), + output_cost_per_million=Decimal("40.00"), + cache_read_cost_per_million=Decimal("2.50"), + source="official_docs_snapshot", + source_url="https://openai.com/api/pricing/", + pricing_version="openai-pricing-2026-03-16", + ), + ( + "openai", + "o3-mini", + ): PricingEntry( + input_cost_per_million=Decimal("1.10"), + output_cost_per_million=Decimal("4.40"), + cache_read_cost_per_million=Decimal("0.55"), + source="official_docs_snapshot", + source_url="https://openai.com/api/pricing/", + pricing_version="openai-pricing-2026-03-16", + ), + # Anthropic older models (pre-4.6 generation) + ( + "anthropic", + "claude-3-5-sonnet-20241022", + ): PricingEntry( + input_cost_per_million=Decimal("3.00"), + output_cost_per_million=Decimal("15.00"), + cache_read_cost_per_million=Decimal("0.30"), + cache_write_cost_per_million=Decimal("3.75"), + source="official_docs_snapshot", + source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching", + pricing_version="anthropic-pricing-2026-03-16", + ), + ( + "anthropic", + "claude-3-5-haiku-20241022", + ): PricingEntry( + input_cost_per_million=Decimal("0.80"), + output_cost_per_million=Decimal("4.00"), + cache_read_cost_per_million=Decimal("0.08"), + cache_write_cost_per_million=Decimal("1.00"), + source="official_docs_snapshot", + source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching", + pricing_version="anthropic-pricing-2026-03-16", + ), + ( + "anthropic", + "claude-3-opus-20240229", + ): PricingEntry( + input_cost_per_million=Decimal("15.00"), + output_cost_per_million=Decimal("75.00"), + cache_read_cost_per_million=Decimal("1.50"), + cache_write_cost_per_million=Decimal("18.75"), + source="official_docs_snapshot", + source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching", + pricing_version="anthropic-pricing-2026-03-16", + ), + ( + "anthropic", + "claude-3-haiku-20240307", + ): PricingEntry( + input_cost_per_million=Decimal("0.25"), + output_cost_per_million=Decimal("1.25"), + cache_read_cost_per_million=Decimal("0.03"), + cache_write_cost_per_million=Decimal("0.30"), + source="official_docs_snapshot", + source_url="https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching", + pricing_version="anthropic-pricing-2026-03-16", + ), + # DeepSeek + ( + "deepseek", + "deepseek-chat", + ): PricingEntry( + input_cost_per_million=Decimal("0.14"), + output_cost_per_million=Decimal("0.28"), + source="official_docs_snapshot", + source_url="https://api-docs.deepseek.com/quick_start/pricing", + pricing_version="deepseek-pricing-2026-03-16", + ), + ( + "deepseek", + "deepseek-reasoner", + ): PricingEntry( + input_cost_per_million=Decimal("0.55"), + output_cost_per_million=Decimal("2.19"), + source="official_docs_snapshot", + source_url="https://api-docs.deepseek.com/quick_start/pricing", + pricing_version="deepseek-pricing-2026-03-16", + ), + # Google Gemini + ( + "google", + "gemini-2.5-pro", + ): PricingEntry( + input_cost_per_million=Decimal("1.25"), + output_cost_per_million=Decimal("10.00"), + source="official_docs_snapshot", + source_url="https://ai.google.dev/pricing", + pricing_version="google-pricing-2026-03-16", + ), + ( + "google", + "gemini-2.5-flash", + ): PricingEntry( + input_cost_per_million=Decimal("0.15"), + output_cost_per_million=Decimal("0.60"), + source="official_docs_snapshot", + source_url="https://ai.google.dev/pricing", + pricing_version="google-pricing-2026-03-16", + ), + ( + "google", + "gemini-2.0-flash", + ): PricingEntry( + input_cost_per_million=Decimal("0.10"), + output_cost_per_million=Decimal("0.40"), + source="official_docs_snapshot", + source_url="https://ai.google.dev/pricing", + pricing_version="google-pricing-2026-03-16", + ), +} + + +def _to_decimal(value: Any) -> Optional[Decimal]: + if value is None: + return None + try: + return Decimal(str(value)) + except Exception: + return None + + +def _to_int(value: Any) -> int: + try: + return int(value or 0) + except Exception: + return 0 + + +def resolve_billing_route( + model_name: str, + provider: Optional[str] = None, + base_url: Optional[str] = None, +) -> BillingRoute: + provider_name = (provider or "").strip().lower() + base = (base_url or "").strip().lower() + model = (model_name or "").strip() + if not provider_name and "/" in model: + inferred_provider, bare_model = model.split("/", 1) + if inferred_provider in {"anthropic", "openai", "google"}: + provider_name = inferred_provider + model = bare_model + + if provider_name == "openai-codex": + return BillingRoute(provider="openai-codex", model=model, base_url=base_url or "", billing_mode="subscription_included") + if provider_name == "openrouter" or "openrouter.ai" in base: + return BillingRoute(provider="openrouter", model=model, base_url=base_url or "", billing_mode="official_models_api") + if provider_name == "anthropic": + return BillingRoute(provider="anthropic", model=model.split("/")[-1], base_url=base_url or "", billing_mode="official_docs_snapshot") + if provider_name == "openai": + return BillingRoute(provider="openai", model=model.split("/")[-1], base_url=base_url or "", billing_mode="official_docs_snapshot") + if provider_name in {"custom", "local"} or (base and "localhost" in base): + return BillingRoute(provider=provider_name or "custom", model=model, base_url=base_url or "", billing_mode="unknown") + return BillingRoute(provider=provider_name or "unknown", model=model.split("/")[-1] if model else "", base_url=base_url or "", billing_mode="unknown") + + +def _lookup_official_docs_pricing(route: BillingRoute) -> Optional[PricingEntry]: + return _OFFICIAL_DOCS_PRICING.get((route.provider, route.model.lower())) + + +def _openrouter_pricing_entry(route: BillingRoute) -> Optional[PricingEntry]: + metadata = fetch_model_metadata() + model_id = route.model + if model_id not in metadata: + return None + pricing = metadata[model_id].get("pricing") or {} + prompt = _to_decimal(pricing.get("prompt")) + completion = _to_decimal(pricing.get("completion")) + request = _to_decimal(pricing.get("request")) + cache_read = _to_decimal( + pricing.get("cache_read") + or pricing.get("cached_prompt") + or pricing.get("input_cache_read") + ) + cache_write = _to_decimal( + pricing.get("cache_write") + or pricing.get("cache_creation") + or pricing.get("input_cache_write") + ) + if prompt is None and completion is None and request is None: + return None + def _per_token_to_per_million(value: Optional[Decimal]) -> Optional[Decimal]: + if value is None: + return None + return value * _ONE_MILLION + + return PricingEntry( + input_cost_per_million=_per_token_to_per_million(prompt), + output_cost_per_million=_per_token_to_per_million(completion), + cache_read_cost_per_million=_per_token_to_per_million(cache_read), + cache_write_cost_per_million=_per_token_to_per_million(cache_write), + request_cost=request, + source="provider_models_api", + source_url="https://openrouter.ai/docs/api/api-reference/models/get-models", + pricing_version="openrouter-models-api", + fetched_at=_UTC_NOW(), ) -def estimate_cost_usd(model: str, input_tokens: int, output_tokens: int) -> float: - pricing = get_pricing(model) - total = ( - Decimal(input_tokens) * Decimal(str(pricing["input"])) - + Decimal(output_tokens) * Decimal(str(pricing["output"])) - ) / Decimal("1000000") - return float(total) +def get_pricing_entry( + model_name: str, + provider: Optional[str] = None, + base_url: Optional[str] = None, +) -> Optional[PricingEntry]: + route = resolve_billing_route(model_name, provider=provider, base_url=base_url) + if route.billing_mode == "subscription_included": + return PricingEntry( + input_cost_per_million=_ZERO, + output_cost_per_million=_ZERO, + cache_read_cost_per_million=_ZERO, + cache_write_cost_per_million=_ZERO, + source="none", + pricing_version="included-route", + ) + if route.provider == "openrouter": + return _openrouter_pricing_entry(route) + return _lookup_official_docs_pricing(route) + + +def normalize_usage( + response_usage: Any, + *, + provider: Optional[str] = None, + api_mode: Optional[str] = None, +) -> CanonicalUsage: + """Normalize raw API response usage into canonical token buckets. + + Handles three API shapes: + - Anthropic: input_tokens/output_tokens/cache_read_input_tokens/cache_creation_input_tokens + - Codex Responses: input_tokens includes cache tokens; input_tokens_details.cached_tokens separates them + - OpenAI Chat Completions: prompt_tokens includes cache tokens; prompt_tokens_details.cached_tokens separates them + + In both Codex and OpenAI modes, input_tokens is derived by subtracting cache + tokens from the total — the API contract is that input/prompt totals include + cached tokens and the details object breaks them out. + """ + if not response_usage: + return CanonicalUsage() + + provider_name = (provider or "").strip().lower() + mode = (api_mode or "").strip().lower() + + if mode == "anthropic_messages" or provider_name == "anthropic": + input_tokens = _to_int(getattr(response_usage, "input_tokens", 0)) + output_tokens = _to_int(getattr(response_usage, "output_tokens", 0)) + cache_read_tokens = _to_int(getattr(response_usage, "cache_read_input_tokens", 0)) + cache_write_tokens = _to_int(getattr(response_usage, "cache_creation_input_tokens", 0)) + elif mode == "codex_responses": + input_total = _to_int(getattr(response_usage, "input_tokens", 0)) + output_tokens = _to_int(getattr(response_usage, "output_tokens", 0)) + details = getattr(response_usage, "input_tokens_details", None) + cache_read_tokens = _to_int(getattr(details, "cached_tokens", 0) if details else 0) + cache_write_tokens = _to_int( + getattr(details, "cache_creation_tokens", 0) if details else 0 + ) + input_tokens = max(0, input_total - cache_read_tokens - cache_write_tokens) + else: + prompt_total = _to_int(getattr(response_usage, "prompt_tokens", 0)) + output_tokens = _to_int(getattr(response_usage, "completion_tokens", 0)) + details = getattr(response_usage, "prompt_tokens_details", None) + cache_read_tokens = _to_int(getattr(details, "cached_tokens", 0) if details else 0) + cache_write_tokens = _to_int( + getattr(details, "cache_write_tokens", 0) if details else 0 + ) + input_tokens = max(0, prompt_total - cache_read_tokens - cache_write_tokens) + + reasoning_tokens = 0 + output_details = getattr(response_usage, "output_tokens_details", None) + if output_details: + reasoning_tokens = _to_int(getattr(output_details, "reasoning_tokens", 0)) + + return CanonicalUsage( + input_tokens=input_tokens, + output_tokens=output_tokens, + cache_read_tokens=cache_read_tokens, + cache_write_tokens=cache_write_tokens, + reasoning_tokens=reasoning_tokens, + ) + + +def estimate_usage_cost( + model_name: str, + usage: CanonicalUsage, + *, + provider: Optional[str] = None, + base_url: Optional[str] = None, +) -> CostResult: + route = resolve_billing_route(model_name, provider=provider, base_url=base_url) + if route.billing_mode == "subscription_included": + return CostResult( + amount_usd=_ZERO, + status="included", + source="none", + label="included", + pricing_version="included-route", + ) + + entry = get_pricing_entry(model_name, provider=provider, base_url=base_url) + if not entry: + return CostResult(amount_usd=None, status="unknown", source="none", label="n/a") + + notes: list[str] = [] + amount = _ZERO + + if usage.input_tokens and entry.input_cost_per_million is None: + return CostResult(amount_usd=None, status="unknown", source=entry.source, label="n/a") + if usage.output_tokens and entry.output_cost_per_million is None: + return CostResult(amount_usd=None, status="unknown", source=entry.source, label="n/a") + if usage.cache_read_tokens: + if entry.cache_read_cost_per_million is None: + return CostResult( + amount_usd=None, + status="unknown", + source=entry.source, + label="n/a", + notes=("cache-read pricing unavailable for route",), + ) + if usage.cache_write_tokens: + if entry.cache_write_cost_per_million is None: + return CostResult( + amount_usd=None, + status="unknown", + source=entry.source, + label="n/a", + notes=("cache-write pricing unavailable for route",), + ) + + if entry.input_cost_per_million is not None: + amount += Decimal(usage.input_tokens) * entry.input_cost_per_million / _ONE_MILLION + if entry.output_cost_per_million is not None: + amount += Decimal(usage.output_tokens) * entry.output_cost_per_million / _ONE_MILLION + if entry.cache_read_cost_per_million is not None: + amount += Decimal(usage.cache_read_tokens) * entry.cache_read_cost_per_million / _ONE_MILLION + if entry.cache_write_cost_per_million is not None: + amount += Decimal(usage.cache_write_tokens) * entry.cache_write_cost_per_million / _ONE_MILLION + if entry.request_cost is not None and usage.request_count: + amount += Decimal(usage.request_count) * entry.request_cost + + status: CostStatus = "estimated" + label = f"~${amount:.2f}" + if entry.source == "none" and amount == _ZERO: + status = "included" + label = "included" + + if route.provider == "openrouter": + notes.append("OpenRouter cost is estimated from the models API until reconciled.") + + return CostResult( + amount_usd=amount, + status=status, + source=entry.source, + label=label, + fetched_at=entry.fetched_at, + pricing_version=entry.pricing_version, + notes=tuple(notes), + ) + + +def has_known_pricing( + model_name: str, + provider: Optional[str] = None, + base_url: Optional[str] = None, +) -> bool: + """Check whether we have pricing data for this model+route. + + Uses direct lookup instead of routing through the full estimation + pipeline — avoids creating dummy usage objects just to check status. + """ + route = resolve_billing_route(model_name, provider=provider, base_url=base_url) + if route.billing_mode == "subscription_included": + return True + entry = get_pricing_entry(model_name, provider=provider, base_url=base_url) + return entry is not None + + +def get_pricing( + model_name: str, + provider: Optional[str] = None, + base_url: Optional[str] = None, +) -> Dict[str, float]: + """Backward-compatible thin wrapper for legacy callers. + + Returns only non-cache input/output fields when a pricing entry exists. + Unknown routes return zeroes. + """ + entry = get_pricing_entry(model_name, provider=provider, base_url=base_url) + if not entry: + return {"input": 0.0, "output": 0.0} + return { + "input": float(entry.input_cost_per_million or _ZERO), + "output": float(entry.output_cost_per_million or _ZERO), + } + + +def estimate_cost_usd( + model: str, + input_tokens: int, + output_tokens: int, + *, + provider: Optional[str] = None, + base_url: Optional[str] = None, +) -> float: + """Backward-compatible helper for legacy callers. + + This uses non-cached input/output only. New code should call + `estimate_usage_cost()` with canonical usage buckets. + """ + result = estimate_usage_cost( + model, + CanonicalUsage(input_tokens=input_tokens, output_tokens=output_tokens), + provider=provider, + base_url=base_url, + ) + return float(result.amount_usd or _ZERO) def format_duration_compact(seconds: float) -> str: diff --git a/cli.py b/cli.py index 8a51cd315..703b85e77 100755 --- a/cli.py +++ b/cli.py @@ -58,7 +58,12 @@ except (ImportError, AttributeError): import threading import queue -from agent.usage_pricing import estimate_cost_usd, format_duration_compact, format_token_count_compact, has_known_pricing +from agent.usage_pricing import ( + CanonicalUsage, + estimate_usage_cost, + format_duration_compact, + format_token_count_compact, +) from hermes_cli.banner import _format_context_length _COMMAND_SPINNER_FRAMES = ("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏") @@ -212,7 +217,7 @@ def load_cli_config() -> Dict[str, Any]: "resume_display": "full", "show_reasoning": False, "streaming": False, - "show_cost": False, + "skin": "default", "theme_mode": "auto", }, @@ -1034,8 +1039,7 @@ class HermesCLI: self.bell_on_complete = CLI_CONFIG["display"].get("bell_on_complete", False) # show_reasoning: display model thinking/reasoning before the response self.show_reasoning = CLI_CONFIG["display"].get("show_reasoning", False) - # show_cost: display $ cost in the status bar (off by default) - self.show_cost = CLI_CONFIG["display"].get("show_cost", False) + self.verbose = verbose if verbose is not None else (self.tool_progress_mode == "verbose") # streaming: stream tokens to the terminal as they arrive (display.streaming in config.yaml) @@ -1260,12 +1264,14 @@ class HermesCLI: "context_tokens": 0, "context_length": None, "context_percent": None, + "session_input_tokens": 0, + "session_output_tokens": 0, + "session_cache_read_tokens": 0, + "session_cache_write_tokens": 0, "session_prompt_tokens": 0, "session_completion_tokens": 0, "session_total_tokens": 0, "session_api_calls": 0, - "session_cost": 0.0, - "pricing_known": has_known_pricing(model_name), "compressions": 0, } @@ -1273,15 +1279,14 @@ class HermesCLI: if not agent: return snapshot + snapshot["session_input_tokens"] = getattr(agent, "session_input_tokens", 0) or 0 + snapshot["session_output_tokens"] = getattr(agent, "session_output_tokens", 0) or 0 + snapshot["session_cache_read_tokens"] = getattr(agent, "session_cache_read_tokens", 0) or 0 + snapshot["session_cache_write_tokens"] = getattr(agent, "session_cache_write_tokens", 0) or 0 snapshot["session_prompt_tokens"] = getattr(agent, "session_prompt_tokens", 0) or 0 snapshot["session_completion_tokens"] = getattr(agent, "session_completion_tokens", 0) or 0 snapshot["session_total_tokens"] = getattr(agent, "session_total_tokens", 0) or 0 snapshot["session_api_calls"] = getattr(agent, "session_api_calls", 0) or 0 - snapshot["session_cost"] = estimate_cost_usd( - model_name, - snapshot["session_prompt_tokens"], - snapshot["session_completion_tokens"], - ) compressor = getattr(agent, "context_compressor", None) if compressor: @@ -1302,19 +1307,11 @@ class HermesCLI: percent = snapshot["context_percent"] percent_label = f"{percent}%" if percent is not None else "--" duration_label = snapshot["duration"] - show_cost = getattr(self, "show_cost", False) - - if show_cost: - cost_label = f"${snapshot['session_cost']:.2f}" if snapshot["pricing_known"] else "cost n/a" - else: - cost_label = None if width < 52: return f"⚕ {snapshot['model_short']} · {duration_label}" if width < 76: parts = [f"⚕ {snapshot['model_short']}", percent_label] - if cost_label: - parts.append(cost_label) parts.append(duration_label) return " · ".join(parts) @@ -1326,8 +1323,6 @@ class HermesCLI: context_label = "ctx --" parts = [f"⚕ {snapshot['model_short']}", context_label, percent_label] - if cost_label: - parts.append(cost_label) parts.append(duration_label) return " │ ".join(parts) except Exception: @@ -1338,12 +1333,6 @@ class HermesCLI: snapshot = self._get_status_bar_snapshot() width = shutil.get_terminal_size((80, 24)).columns duration_label = snapshot["duration"] - show_cost = getattr(self, "show_cost", False) - - if show_cost: - cost_label = f"${snapshot['session_cost']:.2f}" if snapshot["pricing_known"] else "cost n/a" - else: - cost_label = None if width < 52: return [ @@ -1363,11 +1352,6 @@ class HermesCLI: ("class:status-bar-dim", " · "), (self._status_bar_context_style(percent), percent_label), ] - if cost_label: - frags.extend([ - ("class:status-bar-dim", " · "), - ("class:status-bar-dim", cost_label), - ]) frags.extend([ ("class:status-bar-dim", " · "), ("class:status-bar-dim", duration_label), @@ -1393,11 +1377,6 @@ class HermesCLI: ("class:status-bar-dim", " "), (bar_style, percent_label), ] - if cost_label: - frags.extend([ - ("class:status-bar-dim", " │ "), - ("class:status-bar-dim", cost_label), - ]) frags.extend([ ("class:status-bar-dim", " │ "), ("class:status-bar-dim", duration_label), @@ -3653,8 +3632,17 @@ class HermesCLI: self.console.print(f"[bold red]Quick command error: {e}[/]") else: self.console.print(f"[bold red]Quick command '{base_cmd}' has no command defined[/]") + elif qcmd.get("type") == "alias": + target = qcmd.get("target", "").strip() + if target: + target = target if target.startswith("/") else f"/{target}" + user_args = cmd_original[len(base_cmd):].strip() + aliased_command = f"{target} {user_args}".strip() + return self.process_command(aliased_command) + else: + self.console.print(f"[bold red]Quick command '{base_cmd}' has no target defined[/]") else: - self.console.print(f"[bold red]Quick command '{base_cmd}' has unsupported type (only 'exec' is supported)[/]") + self.console.print(f"[bold red]Quick command '{base_cmd}' has unsupported type (supported: 'exec', 'alias')[/]") # Check for skill slash commands (/gif-search, /axolotl, etc.) elif base_cmd in _skill_commands: user_instruction = cmd_original[len(base_cmd):].strip() @@ -4242,6 +4230,10 @@ class HermesCLI: return agent = self.agent + input_tokens = getattr(agent, "session_input_tokens", 0) or 0 + output_tokens = getattr(agent, "session_output_tokens", 0) or 0 + cache_read_tokens = getattr(agent, "session_cache_read_tokens", 0) or 0 + cache_write_tokens = getattr(agent, "session_cache_write_tokens", 0) or 0 prompt = agent.session_prompt_tokens completion = agent.session_completion_tokens total = agent.session_total_tokens @@ -4259,33 +4251,45 @@ class HermesCLI: compressions = compressor.compression_count msg_count = len(self.conversation_history) - cost = estimate_cost_usd(agent.model, prompt, completion) - prompt_cost = estimate_cost_usd(agent.model, prompt, 0) - completion_cost = estimate_cost_usd(agent.model, 0, completion) - pricing_known = has_known_pricing(agent.model) + cost_result = estimate_usage_cost( + agent.model, + CanonicalUsage( + input_tokens=input_tokens, + output_tokens=output_tokens, + cache_read_tokens=cache_read_tokens, + cache_write_tokens=cache_write_tokens, + ), + provider=getattr(agent, "provider", None), + base_url=getattr(agent, "base_url", None), + ) elapsed = format_duration_compact((datetime.now() - self.session_start).total_seconds()) print(f" 📊 Session Token Usage") print(f" {'─' * 40}") print(f" Model: {agent.model}") - print(f" Prompt tokens (input): {prompt:>10,}") - print(f" Completion tokens (output): {completion:>9,}") + print(f" Input tokens: {input_tokens:>10,}") + print(f" Cache read tokens: {cache_read_tokens:>10,}") + print(f" Cache write tokens: {cache_write_tokens:>10,}") + print(f" Output tokens: {output_tokens:>10,}") + print(f" Prompt tokens (total): {prompt:>10,}") + print(f" Completion tokens: {completion:>10,}") print(f" Total tokens: {total:>10,}") print(f" API calls: {calls:>10,}") print(f" Session duration: {elapsed:>10}") - if pricing_known: - print(f" Input cost: ${prompt_cost:>10.4f}") - print(f" Output cost: ${completion_cost:>10.4f}") - print(f" Total cost: ${cost:>10.4f}") + print(f" Cost status: {cost_result.status:>10}") + print(f" Cost source: {cost_result.source:>10}") + if cost_result.amount_usd is not None: + prefix = "~" if cost_result.status == "estimated" else "" + print(f" Total cost: {prefix}${float(cost_result.amount_usd):>10.4f}") + elif cost_result.status == "included": + print(f" Total cost: {'included':>10}") else: - print(f" Input cost: {'n/a':>10}") - print(f" Output cost: {'n/a':>10}") print(f" Total cost: {'n/a':>10}") print(f" {'─' * 40}") print(f" Current context: {last_prompt:,} / {ctx_len:,} ({pct:.0f}%)") print(f" Messages: {msg_count}") print(f" Compressions: {compressions}") - if not pricing_known: + if cost_result.status == "unknown": print(f" Note: Pricing unknown for {agent.model}") if self.verbose: diff --git a/cron/jobs.py b/cron/jobs.py index b749c51f0..30d20f1e3 100644 --- a/cron/jobs.py +++ b/cron/jobs.py @@ -5,6 +5,7 @@ Jobs are stored in ~/.hermes/cron/jobs.json Output is saved to ~/.hermes/cron/output/{job_id}/{timestamp}.md """ +import copy import json import logging import tempfile @@ -167,6 +168,10 @@ def parse_schedule(schedule: str) -> Dict[str, Any]: try: # Parse and validate dt = datetime.fromisoformat(schedule.replace('Z', '+00:00')) + # Make naive timestamps timezone-aware at parse time so the stored + # value doesn't depend on the system timezone matching at check time. + if dt.tzinfo is None: + dt = dt.astimezone() # Interpret as local timezone return { "kind": "once", "run_at": dt.isoformat(), @@ -539,8 +544,8 @@ def get_due_jobs() -> List[Dict[str, Any]]: immediately. This prevents a burst of missed jobs on gateway restart. """ now = _hermes_now() - jobs = [_apply_skill_fields(j) for j in load_jobs()] - raw_jobs = load_jobs() # For saving updates + raw_jobs = load_jobs() + jobs = [_apply_skill_fields(j) for j in copy.deepcopy(raw_jobs)] due = [] needs_save = False diff --git a/docs/plans/2026-03-16-pricing-accuracy-architecture-design.md b/docs/plans/2026-03-16-pricing-accuracy-architecture-design.md new file mode 100644 index 000000000..a75f14ff5 --- /dev/null +++ b/docs/plans/2026-03-16-pricing-accuracy-architecture-design.md @@ -0,0 +1,608 @@ +# Pricing Accuracy Architecture + +Date: 2026-03-16 + +## Goal + +Hermes should only show dollar costs when they are backed by an official source for the user's actual billing path. + +This design replaces the current static, heuristic pricing flow in: + +- `run_agent.py` +- `agent/usage_pricing.py` +- `agent/insights.py` +- `cli.py` + +with a provider-aware pricing system that: + +- handles cache billing correctly +- distinguishes `actual` vs `estimated` vs `included` vs `unknown` +- reconciles post-hoc costs when providers expose authoritative billing data +- supports direct providers, OpenRouter, subscriptions, enterprise pricing, and custom endpoints + +## Problems In The Current Design + +Current Hermes behavior has four structural issues: + +1. It stores only `prompt_tokens` and `completion_tokens`, which is insufficient for providers that bill cache reads and cache writes separately. +2. It uses a static model price table and fuzzy heuristics, which can drift from current official pricing. +3. It assumes public API list pricing matches the user's real billing path. +4. It has no distinction between live estimates and reconciled billed cost. + +## Design Principles + +1. Normalize usage before pricing. +2. Never fold cached tokens into plain input cost. +3. Track certainty explicitly. +4. Treat the billing path as part of the model identity. +5. Prefer official machine-readable sources over scraped docs. +6. Use post-hoc provider cost APIs when available. +7. Show `n/a` rather than inventing precision. + +## High-Level Architecture + +The new system has four layers: + +1. `usage_normalization` + Converts raw provider usage into a canonical usage record. +2. `pricing_source_resolution` + Determines the billing path, source of truth, and applicable pricing source. +3. `cost_estimation_and_reconciliation` + Produces an immediate estimate when possible, then replaces or annotates it with actual billed cost later. +4. `presentation` + `/usage`, `/insights`, and the status bar display cost with certainty metadata. + +## Canonical Usage Record + +Add a canonical usage model that every provider path maps into before any pricing math happens. + +Suggested structure: + +```python +@dataclass +class CanonicalUsage: + provider: str + billing_provider: str + model: str + billing_route: str + + input_tokens: int = 0 + output_tokens: int = 0 + cache_read_tokens: int = 0 + cache_write_tokens: int = 0 + reasoning_tokens: int = 0 + request_count: int = 1 + + raw_usage: dict[str, Any] | None = None + raw_usage_fields: dict[str, str] | None = None + computed_fields: set[str] | None = None + + provider_request_id: str | None = None + provider_generation_id: str | None = None + provider_response_id: str | None = None +``` + +Rules: + +- `input_tokens` means non-cached input only. +- `cache_read_tokens` and `cache_write_tokens` are never merged into `input_tokens`. +- `output_tokens` excludes cache metrics. +- `reasoning_tokens` is telemetry unless a provider officially bills it separately. + +This is the same normalization pattern used by `opencode`, extended with provenance and reconciliation ids. + +## Provider Normalization Rules + +### OpenAI Direct + +Source usage fields: + +- `prompt_tokens` +- `completion_tokens` +- `prompt_tokens_details.cached_tokens` + +Normalization: + +- `cache_read_tokens = cached_tokens` +- `input_tokens = prompt_tokens - cached_tokens` +- `cache_write_tokens = 0` unless OpenAI exposes it in the relevant route +- `output_tokens = completion_tokens` + +### Anthropic Direct + +Source usage fields: + +- `input_tokens` +- `output_tokens` +- `cache_read_input_tokens` +- `cache_creation_input_tokens` + +Normalization: + +- `input_tokens = input_tokens` +- `output_tokens = output_tokens` +- `cache_read_tokens = cache_read_input_tokens` +- `cache_write_tokens = cache_creation_input_tokens` + +### OpenRouter + +Estimate-time usage normalization should use the response usage payload with the same rules as the underlying provider when possible. + +Reconciliation-time records should also store: + +- OpenRouter generation id +- native token fields when available +- `total_cost` +- `cache_discount` +- `upstream_inference_cost` +- `is_byok` + +### Gemini / Vertex + +Use official Gemini or Vertex usage fields where available. + +If cached content tokens are exposed: + +- map them to `cache_read_tokens` + +If a route exposes no cache creation metric: + +- store `cache_write_tokens = 0` +- preserve the raw usage payload for later extension + +### DeepSeek And Other Direct Providers + +Normalize only the fields that are officially exposed. + +If a provider does not expose cache buckets: + +- do not infer them unless the provider explicitly documents how to derive them + +### Subscription / Included-Cost Routes + +These still use the canonical usage model. + +Tokens are tracked normally. Cost depends on billing mode, not on whether usage exists. + +## Billing Route Model + +Hermes must stop keying pricing solely by `model`. + +Introduce a billing route descriptor: + +```python +@dataclass +class BillingRoute: + provider: str + base_url: str | None + model: str + billing_mode: str + organization_hint: str | None = None +``` + +`billing_mode` values: + +- `official_cost_api` +- `official_generation_api` +- `official_models_api` +- `official_docs_snapshot` +- `subscription_included` +- `user_override` +- `custom_contract` +- `unknown` + +Examples: + +- OpenAI direct API with Costs API access: `official_cost_api` +- Anthropic direct API with Usage & Cost API access: `official_cost_api` +- OpenRouter request before reconciliation: `official_models_api` +- OpenRouter request after generation lookup: `official_generation_api` +- GitHub Copilot style subscription route: `subscription_included` +- local OpenAI-compatible server: `unknown` +- enterprise contract with configured rates: `custom_contract` + +## Cost Status Model + +Every displayed cost should have: + +```python +@dataclass +class CostResult: + amount_usd: Decimal | None + status: Literal["actual", "estimated", "included", "unknown"] + source: Literal[ + "provider_cost_api", + "provider_generation_api", + "provider_models_api", + "official_docs_snapshot", + "user_override", + "custom_contract", + "none", + ] + label: str + fetched_at: datetime | None + pricing_version: str | None + notes: list[str] +``` + +Presentation rules: + +- `actual`: show dollar amount as final +- `estimated`: show dollar amount with estimate labeling +- `included`: show `included` or `$0.00 (included)` depending on UX choice +- `unknown`: show `n/a` + +## Official Source Hierarchy + +Resolve cost using this order: + +1. Request-level or account-level official billed cost +2. Official machine-readable model pricing +3. Official docs snapshot +4. User override or custom contract +5. Unknown + +The system must never skip to a lower level if a higher-confidence source exists for the current billing route. + +## Provider-Specific Truth Rules + +### OpenAI Direct + +Preferred truth: + +1. Costs API for reconciled spend +2. Official pricing page for live estimate + +### Anthropic Direct + +Preferred truth: + +1. Usage & Cost API for reconciled spend +2. Official pricing docs for live estimate + +### OpenRouter + +Preferred truth: + +1. `GET /api/v1/generation` for reconciled `total_cost` +2. `GET /api/v1/models` pricing for live estimate + +Do not use underlying provider public pricing as the source of truth for OpenRouter billing. + +### Gemini / Vertex + +Preferred truth: + +1. official billing export or billing API for reconciled spend when available for the route +2. official pricing docs for estimate + +### DeepSeek + +Preferred truth: + +1. official machine-readable cost source if available in the future +2. official pricing docs snapshot today + +### Subscription-Included Routes + +Preferred truth: + +1. explicit route config marking the model as included in subscription + +These should display `included`, not an API list-price estimate. + +### Custom Endpoint / Local Model + +Preferred truth: + +1. user override +2. custom contract config +3. unknown + +These should default to `unknown`. + +## Pricing Catalog + +Replace the current `MODEL_PRICING` dict with a richer pricing catalog. + +Suggested record: + +```python +@dataclass +class PricingEntry: + provider: str + route_pattern: str + model_pattern: str + + input_cost_per_million: Decimal | None = None + output_cost_per_million: Decimal | None = None + cache_read_cost_per_million: Decimal | None = None + cache_write_cost_per_million: Decimal | None = None + request_cost: Decimal | None = None + image_cost: Decimal | None = None + + source: str = "official_docs_snapshot" + source_url: str | None = None + fetched_at: datetime | None = None + pricing_version: str | None = None +``` + +The catalog should be route-aware: + +- `openai:gpt-5` +- `anthropic:claude-opus-4-6` +- `openrouter:anthropic/claude-opus-4.6` +- `copilot:gpt-4o` + +This avoids conflating direct-provider billing with aggregator billing. + +## Pricing Sync Architecture + +Introduce a pricing sync subsystem instead of manually maintaining a single hardcoded table. + +Suggested modules: + +- `agent/pricing/catalog.py` +- `agent/pricing/sources.py` +- `agent/pricing/sync.py` +- `agent/pricing/reconcile.py` +- `agent/pricing/types.py` + +### Sync Sources + +- OpenRouter models API +- official provider docs snapshots where no API exists +- user overrides from config + +### Sync Output + +Cache pricing entries locally with: + +- source URL +- fetch timestamp +- version/hash +- confidence/source type + +### Sync Frequency + +- startup warm cache +- background refresh every 6 to 24 hours depending on source +- manual `hermes pricing sync` + +## Reconciliation Architecture + +Live requests may produce only an estimate initially. Hermes should reconcile them later when a provider exposes actual billed cost. + +Suggested flow: + +1. Agent call completes. +2. Hermes stores canonical usage plus reconciliation ids. +3. Hermes computes an immediate estimate if a pricing source exists. +4. A reconciliation worker fetches actual cost when supported. +5. Session and message records are updated with `actual` cost. + +This can run: + +- inline for cheap lookups +- asynchronously for delayed provider accounting + +## Persistence Changes + +Session storage should stop storing only aggregate prompt/completion totals. + +Add fields for both usage and cost certainty: + +- `input_tokens` +- `output_tokens` +- `cache_read_tokens` +- `cache_write_tokens` +- `reasoning_tokens` +- `estimated_cost_usd` +- `actual_cost_usd` +- `cost_status` +- `cost_source` +- `pricing_version` +- `billing_provider` +- `billing_mode` + +If schema expansion is too large for one PR, add a new pricing events table: + +```text +session_cost_events + id + session_id + request_id + provider + model + billing_mode + input_tokens + output_tokens + cache_read_tokens + cache_write_tokens + estimated_cost_usd + actual_cost_usd + cost_status + cost_source + pricing_version + created_at + updated_at +``` + +## Hermes Touchpoints + +### `run_agent.py` + +Current responsibility: + +- parse raw provider usage +- update session token counters + +New responsibility: + +- build `CanonicalUsage` +- update canonical counters +- store reconciliation ids +- emit usage event to pricing subsystem + +### `agent/usage_pricing.py` + +Current responsibility: + +- static lookup table +- direct cost arithmetic + +New responsibility: + +- move or replace with pricing catalog facade +- no fuzzy model-family heuristics +- no direct pricing without billing-route context + +### `cli.py` + +Current responsibility: + +- compute session cost directly from prompt/completion totals + +New responsibility: + +- display `CostResult` +- show status badges: + - `actual` + - `estimated` + - `included` + - `n/a` + +### `agent/insights.py` + +Current responsibility: + +- recompute historical estimates from static pricing + +New responsibility: + +- aggregate stored pricing events +- prefer actual cost over estimate +- surface estimates only when reconciliation is unavailable + +## UX Rules + +### Status Bar + +Show one of: + +- `$1.42` +- `~$1.42` +- `included` +- `cost n/a` + +Where: + +- `$1.42` means `actual` +- `~$1.42` means `estimated` +- `included` means subscription-backed or explicitly zero-cost route +- `cost n/a` means unknown + +### `/usage` + +Show: + +- token buckets +- estimated cost +- actual cost if available +- cost status +- pricing source + +### `/insights` + +Aggregate: + +- actual cost totals +- estimated-only totals +- unknown-cost sessions count +- included-cost sessions count + +## Config And Overrides + +Add user-configurable pricing overrides in config: + +```yaml +pricing: + mode: hybrid + sync_on_startup: true + sync_interval_hours: 12 + overrides: + - provider: openrouter + model: anthropic/claude-opus-4.6 + billing_mode: custom_contract + input_cost_per_million: 4.25 + output_cost_per_million: 22.0 + cache_read_cost_per_million: 0.5 + cache_write_cost_per_million: 6.0 + included_routes: + - provider: copilot + model: "*" + - provider: codex-subscription + model: "*" +``` + +Overrides must win over catalog defaults for the matching billing route. + +## Rollout Plan + +### Phase 1 + +- add canonical usage model +- split cache token buckets in `run_agent.py` +- stop pricing cache-inflated prompt totals +- preserve current UI with improved backend math + +### Phase 2 + +- add route-aware pricing catalog +- integrate OpenRouter models API sync +- add `estimated` vs `included` vs `unknown` + +### Phase 3 + +- add reconciliation for OpenRouter generation cost +- add actual cost persistence +- update `/insights` to prefer actual cost + +### Phase 4 + +- add direct OpenAI and Anthropic reconciliation paths +- add user overrides and contract pricing +- add pricing sync CLI command + +## Testing Strategy + +Add tests for: + +- OpenAI cached token subtraction +- Anthropic cache read/write separation +- OpenRouter estimated vs actual reconciliation +- subscription-backed models showing `included` +- custom endpoints showing `n/a` +- override precedence +- stale catalog fallback behavior + +Current tests that assume heuristic pricing should be replaced with route-aware expectations. + +## Non-Goals + +- exact enterprise billing reconstruction without an official source or user override +- backfilling perfect historical cost for old sessions that lack cache bucket data +- scraping arbitrary provider web pages at request time + +## Recommendation + +Do not expand the existing `MODEL_PRICING` dict. + +That path cannot satisfy the product requirement. Hermes should instead migrate to: + +- canonical usage normalization +- route-aware pricing sources +- estimate-then-reconcile cost lifecycle +- explicit certainty states in the UI + +This is the minimum architecture that makes the statement "Hermes pricing is backed by official sources where possible, and otherwise clearly labeled" defensible. diff --git a/gateway/config.py b/gateway/config.py index cf8fc1fae..e43af65aa 100644 --- a/gateway/config.py +++ b/gateway/config.py @@ -40,9 +40,12 @@ class Platform(Enum): WHATSAPP = "whatsapp" SLACK = "slack" SIGNAL = "signal" + MATTERMOST = "mattermost" + MATRIX = "matrix" HOMEASSISTANT = "homeassistant" EMAIL = "email" SMS = "sms" + DINGTALK = "dingtalk" @dataclass @@ -226,15 +229,15 @@ class GatewayConfig: # WhatsApp uses enabled flag only (bridge handles auth) elif platform == Platform.WHATSAPP: connected.append(platform) - # SMS uses api_key from env (checked via extra or env var) - elif platform == Platform.SMS and os.getenv("TELNYX_API_KEY"): - connected.append(platform) # Signal uses extra dict for config (http_url + account) elif platform == Platform.SIGNAL and config.extra.get("http_url"): connected.append(platform) # Email uses extra dict for config (address + imap_host + smtp_host) elif platform == Platform.EMAIL and config.extra.get("address"): connected.append(platform) + # SMS uses api_key (Twilio auth token) — SID checked via env + elif platform == Platform.SMS and os.getenv("TWILIO_ACCOUNT_SID"): + connected.append(platform) return connected def get_home_channel(self, platform: Platform) -> Optional[HomeChannel]: @@ -441,6 +444,8 @@ def load_gateway_config() -> GatewayConfig: Platform.TELEGRAM: "TELEGRAM_BOT_TOKEN", Platform.DISCORD: "DISCORD_BOT_TOKEN", Platform.SLACK: "SLACK_BOT_TOKEN", + Platform.MATTERMOST: "MATTERMOST_TOKEN", + Platform.MATRIX: "MATRIX_ACCESS_TOKEN", } for platform, pconfig in config.platforms.items(): if not pconfig.enabled: @@ -534,6 +539,53 @@ def _apply_env_overrides(config: GatewayConfig) -> None: name=os.getenv("SIGNAL_HOME_CHANNEL_NAME", "Home"), ) + # Mattermost + mattermost_token = os.getenv("MATTERMOST_TOKEN") + if mattermost_token: + mattermost_url = os.getenv("MATTERMOST_URL", "") + if not mattermost_url: + logger.warning("MATTERMOST_TOKEN set but MATTERMOST_URL is missing") + if Platform.MATTERMOST not in config.platforms: + config.platforms[Platform.MATTERMOST] = PlatformConfig() + config.platforms[Platform.MATTERMOST].enabled = True + config.platforms[Platform.MATTERMOST].token = mattermost_token + config.platforms[Platform.MATTERMOST].extra["url"] = mattermost_url + mattermost_home = os.getenv("MATTERMOST_HOME_CHANNEL") + if mattermost_home: + config.platforms[Platform.MATTERMOST].home_channel = HomeChannel( + platform=Platform.MATTERMOST, + chat_id=mattermost_home, + name=os.getenv("MATTERMOST_HOME_CHANNEL_NAME", "Home"), + ) + + # Matrix + matrix_token = os.getenv("MATRIX_ACCESS_TOKEN") + matrix_homeserver = os.getenv("MATRIX_HOMESERVER", "") + if matrix_token or os.getenv("MATRIX_PASSWORD"): + if not matrix_homeserver: + logger.warning("MATRIX_ACCESS_TOKEN/MATRIX_PASSWORD set but MATRIX_HOMESERVER is missing") + if Platform.MATRIX not in config.platforms: + config.platforms[Platform.MATRIX] = PlatformConfig() + config.platforms[Platform.MATRIX].enabled = True + if matrix_token: + config.platforms[Platform.MATRIX].token = matrix_token + config.platforms[Platform.MATRIX].extra["homeserver"] = matrix_homeserver + matrix_user = os.getenv("MATRIX_USER_ID", "") + if matrix_user: + config.platforms[Platform.MATRIX].extra["user_id"] = matrix_user + matrix_password = os.getenv("MATRIX_PASSWORD", "") + if matrix_password: + config.platforms[Platform.MATRIX].extra["password"] = matrix_password + matrix_e2ee = os.getenv("MATRIX_ENCRYPTION", "").lower() in ("true", "1", "yes") + config.platforms[Platform.MATRIX].extra["encryption"] = matrix_e2ee + matrix_home = os.getenv("MATRIX_HOME_ROOM") + if matrix_home: + config.platforms[Platform.MATRIX].home_channel = HomeChannel( + platform=Platform.MATRIX, + chat_id=matrix_home, + name=os.getenv("MATRIX_HOME_ROOM_NAME", "Home"), + ) + # Home Assistant hass_token = os.getenv("HASS_TOKEN") if hass_token: @@ -567,13 +619,13 @@ def _apply_env_overrides(config: GatewayConfig) -> None: name=os.getenv("EMAIL_HOME_ADDRESS_NAME", "Home"), ) - # SMS (Telnyx) - telnyx_key = os.getenv("TELNYX_API_KEY") - if telnyx_key: + # SMS (Twilio) + twilio_sid = os.getenv("TWILIO_ACCOUNT_SID") + if twilio_sid: if Platform.SMS not in config.platforms: config.platforms[Platform.SMS] = PlatformConfig() config.platforms[Platform.SMS].enabled = True - config.platforms[Platform.SMS].api_key = telnyx_key + config.platforms[Platform.SMS].api_key = os.getenv("TWILIO_AUTH_TOKEN", "") sms_home = os.getenv("SMS_HOME_CHANNEL") if sms_home: config.platforms[Platform.SMS].home_channel = HomeChannel( diff --git a/gateway/hooks.py b/gateway/hooks.py index 2274b5b91..657c2e449 100644 --- a/gateway/hooks.py +++ b/gateway/hooks.py @@ -8,8 +8,9 @@ Hooks are discovered from ~/.hermes/hooks/ directories, each containing: Events: - gateway:startup -- Gateway process starts - - session:start -- New session created - - session:reset -- User ran /new or /reset + - session:start -- New session created (first message of a new session) + - session:end -- Session ends (user ran /new or /reset) + - session:reset -- Session reset completed (new session entry created) - agent:start -- Agent begins processing a message - agent:step -- Each turn in the tool-calling loop - agent:end -- Agent finishes processing diff --git a/gateway/platforms/dingtalk.py b/gateway/platforms/dingtalk.py new file mode 100644 index 000000000..8ed376962 --- /dev/null +++ b/gateway/platforms/dingtalk.py @@ -0,0 +1,340 @@ +""" +DingTalk platform adapter using Stream Mode. + +Uses dingtalk-stream SDK for real-time message reception without webhooks. +Responses are sent via DingTalk's session webhook (markdown format). + +Requires: + pip install dingtalk-stream httpx + DINGTALK_CLIENT_ID and DINGTALK_CLIENT_SECRET env vars + +Configuration in config.yaml: + platforms: + dingtalk: + enabled: true + extra: + client_id: "your-app-key" # or DINGTALK_CLIENT_ID env var + client_secret: "your-secret" # or DINGTALK_CLIENT_SECRET env var +""" + +import asyncio +import logging +import os +import time +import uuid +from datetime import datetime, timezone +from typing import Any, Dict, Optional + +try: + import dingtalk_stream + from dingtalk_stream import ChatbotHandler, ChatbotMessage + DINGTALK_STREAM_AVAILABLE = True +except ImportError: + DINGTALK_STREAM_AVAILABLE = False + dingtalk_stream = None # type: ignore[assignment] + +try: + import httpx + HTTPX_AVAILABLE = True +except ImportError: + HTTPX_AVAILABLE = False + httpx = None # type: ignore[assignment] + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import ( + BasePlatformAdapter, + MessageEvent, + MessageType, + SendResult, +) + +logger = logging.getLogger(__name__) + +MAX_MESSAGE_LENGTH = 20000 +DEDUP_WINDOW_SECONDS = 300 +DEDUP_MAX_SIZE = 1000 +RECONNECT_BACKOFF = [2, 5, 10, 30, 60] + + +def check_dingtalk_requirements() -> bool: + """Check if DingTalk dependencies are available and configured.""" + if not DINGTALK_STREAM_AVAILABLE or not HTTPX_AVAILABLE: + return False + if not os.getenv("DINGTALK_CLIENT_ID") or not os.getenv("DINGTALK_CLIENT_SECRET"): + return False + return True + + +class DingTalkAdapter(BasePlatformAdapter): + """DingTalk chatbot adapter using Stream Mode. + + The dingtalk-stream SDK maintains a long-lived WebSocket connection. + Incoming messages arrive via a ChatbotHandler callback. Replies are + sent via the incoming message's session_webhook URL using httpx. + """ + + MAX_MESSAGE_LENGTH = MAX_MESSAGE_LENGTH + + def __init__(self, config: PlatformConfig): + super().__init__(config, Platform.DINGTALK) + + extra = config.extra or {} + self._client_id: str = extra.get("client_id") or os.getenv("DINGTALK_CLIENT_ID", "") + self._client_secret: str = extra.get("client_secret") or os.getenv("DINGTALK_CLIENT_SECRET", "") + + self._stream_client: Any = None + self._stream_task: Optional[asyncio.Task] = None + self._http_client: Optional["httpx.AsyncClient"] = None + + # Message deduplication: msg_id -> timestamp + self._seen_messages: Dict[str, float] = {} + # Map chat_id -> session_webhook for reply routing + self._session_webhooks: Dict[str, str] = {} + + # -- Connection lifecycle ----------------------------------------------- + + async def connect(self) -> bool: + """Connect to DingTalk via Stream Mode.""" + if not DINGTALK_STREAM_AVAILABLE: + logger.warning("[%s] dingtalk-stream not installed. Run: pip install dingtalk-stream", self.name) + return False + if not HTTPX_AVAILABLE: + logger.warning("[%s] httpx not installed. Run: pip install httpx", self.name) + return False + if not self._client_id or not self._client_secret: + logger.warning("[%s] DINGTALK_CLIENT_ID and DINGTALK_CLIENT_SECRET required", self.name) + return False + + try: + self._http_client = httpx.AsyncClient(timeout=30.0) + + credential = dingtalk_stream.Credential(self._client_id, self._client_secret) + self._stream_client = dingtalk_stream.DingTalkStreamClient(credential) + + # Capture the current event loop for cross-thread dispatch + loop = asyncio.get_running_loop() + handler = _IncomingHandler(self, loop) + self._stream_client.register_callback_handler( + dingtalk_stream.ChatbotMessage.TOPIC, handler + ) + + self._stream_task = asyncio.create_task(self._run_stream()) + self._mark_connected() + logger.info("[%s] Connected via Stream Mode", self.name) + return True + except Exception as e: + logger.error("[%s] Failed to connect: %s", self.name, e) + return False + + async def _run_stream(self) -> None: + """Run the blocking stream client with auto-reconnection.""" + backoff_idx = 0 + while self._running: + try: + logger.debug("[%s] Starting stream client...", self.name) + await asyncio.to_thread(self._stream_client.start) + except asyncio.CancelledError: + return + except Exception as e: + if not self._running: + return + logger.warning("[%s] Stream client error: %s", self.name, e) + + if not self._running: + return + + delay = RECONNECT_BACKOFF[min(backoff_idx, len(RECONNECT_BACKOFF) - 1)] + logger.info("[%s] Reconnecting in %ds...", self.name, delay) + await asyncio.sleep(delay) + backoff_idx += 1 + + async def disconnect(self) -> None: + """Disconnect from DingTalk.""" + self._running = False + self._mark_disconnected() + + if self._stream_task: + self._stream_task.cancel() + try: + await self._stream_task + except asyncio.CancelledError: + pass + self._stream_task = None + + if self._http_client: + await self._http_client.aclose() + self._http_client = None + + self._stream_client = None + self._session_webhooks.clear() + self._seen_messages.clear() + logger.info("[%s] Disconnected", self.name) + + # -- Inbound message processing ----------------------------------------- + + async def _on_message(self, message: "ChatbotMessage") -> None: + """Process an incoming DingTalk chatbot message.""" + msg_id = getattr(message, "message_id", None) or uuid.uuid4().hex + if self._is_duplicate(msg_id): + logger.debug("[%s] Duplicate message %s, skipping", self.name, msg_id) + return + + text = self._extract_text(message) + if not text: + logger.debug("[%s] Empty message, skipping", self.name) + return + + # Chat context + conversation_id = getattr(message, "conversation_id", "") or "" + conversation_type = getattr(message, "conversation_type", "1") + is_group = str(conversation_type) == "2" + sender_id = getattr(message, "sender_id", "") or "" + sender_nick = getattr(message, "sender_nick", "") or sender_id + sender_staff_id = getattr(message, "sender_staff_id", "") or "" + + chat_id = conversation_id or sender_id + chat_type = "group" if is_group else "dm" + + # Store session webhook for reply routing + session_webhook = getattr(message, "session_webhook", None) or "" + if session_webhook and chat_id: + self._session_webhooks[chat_id] = session_webhook + + source = self.build_source( + chat_id=chat_id, + chat_name=getattr(message, "conversation_title", None), + chat_type=chat_type, + user_id=sender_id, + user_name=sender_nick, + user_id_alt=sender_staff_id if sender_staff_id else None, + ) + + # Parse timestamp + create_at = getattr(message, "create_at", None) + try: + timestamp = datetime.fromtimestamp(int(create_at) / 1000, tz=timezone.utc) if create_at else datetime.now(tz=timezone.utc) + except (ValueError, OSError, TypeError): + timestamp = datetime.now(tz=timezone.utc) + + event = MessageEvent( + text=text, + message_type=MessageType.TEXT, + source=source, + message_id=msg_id, + raw_message=message, + timestamp=timestamp, + ) + + logger.debug("[%s] Message from %s in %s: %s", + self.name, sender_nick, chat_id[:20] if chat_id else "?", text[:50]) + await self.handle_message(event) + + @staticmethod + def _extract_text(message: "ChatbotMessage") -> str: + """Extract plain text from a DingTalk chatbot message.""" + text = getattr(message, "text", None) or "" + if isinstance(text, dict): + content = text.get("content", "").strip() + else: + content = str(text).strip() + + # Fall back to rich text if present + if not content: + rich_text = getattr(message, "rich_text", None) + if rich_text and isinstance(rich_text, list): + parts = [item["text"] for item in rich_text + if isinstance(item, dict) and item.get("text")] + content = " ".join(parts).strip() + return content + + # -- Deduplication ------------------------------------------------------ + + def _is_duplicate(self, msg_id: str) -> bool: + """Check and record a message ID. Returns True if already seen.""" + now = time.time() + if len(self._seen_messages) > DEDUP_MAX_SIZE: + cutoff = now - DEDUP_WINDOW_SECONDS + self._seen_messages = {k: v for k, v in self._seen_messages.items() if v > cutoff} + + if msg_id in self._seen_messages: + return True + self._seen_messages[msg_id] = now + return False + + # -- Outbound messaging ------------------------------------------------- + + async def send( + self, + chat_id: str, + content: str, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Send a markdown reply via DingTalk session webhook.""" + metadata = metadata or {} + + session_webhook = metadata.get("session_webhook") or self._session_webhooks.get(chat_id) + if not session_webhook: + return SendResult(success=False, + error="No session_webhook available. Reply must follow an incoming message.") + + if not self._http_client: + return SendResult(success=False, error="HTTP client not initialized") + + payload = { + "msgtype": "markdown", + "markdown": {"title": "Hermes", "text": content[:self.MAX_MESSAGE_LENGTH]}, + } + + try: + resp = await self._http_client.post(session_webhook, json=payload, timeout=15.0) + if resp.status_code < 300: + return SendResult(success=True, message_id=uuid.uuid4().hex[:12]) + body = resp.text + logger.warning("[%s] Send failed HTTP %d: %s", self.name, resp.status_code, body[:200]) + return SendResult(success=False, error=f"HTTP {resp.status_code}: {body[:200]}") + except httpx.TimeoutException: + return SendResult(success=False, error="Timeout sending message to DingTalk") + except Exception as e: + logger.error("[%s] Send error: %s", self.name, e) + return SendResult(success=False, error=str(e)) + + async def send_typing(self, chat_id: str, metadata=None) -> None: + """DingTalk does not support typing indicators.""" + pass + + async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: + """Return basic info about a DingTalk conversation.""" + return {"name": chat_id, "type": "group" if "group" in chat_id.lower() else "dm"} + + +# --------------------------------------------------------------------------- +# Internal stream handler +# --------------------------------------------------------------------------- + +class _IncomingHandler(ChatbotHandler if DINGTALK_STREAM_AVAILABLE else object): + """dingtalk-stream ChatbotHandler that forwards messages to the adapter.""" + + def __init__(self, adapter: DingTalkAdapter, loop: asyncio.AbstractEventLoop): + if DINGTALK_STREAM_AVAILABLE: + super().__init__() + self._adapter = adapter + self._loop = loop + + def process(self, message: "ChatbotMessage"): + """Called by dingtalk-stream in its thread when a message arrives. + + Schedules the async handler on the main event loop. + """ + loop = self._loop + if loop is None or loop.is_closed(): + logger.error("[DingTalk] Event loop unavailable, cannot dispatch message") + return dingtalk_stream.AckMessage.STATUS_OK, "OK" + + future = asyncio.run_coroutine_threadsafe(self._adapter._on_message(message), loop) + try: + future.result(timeout=60) + except Exception: + logger.exception("[DingTalk] Error processing incoming message") + + return dingtalk_stream.AckMessage.STATUS_OK, "OK" diff --git a/gateway/platforms/matrix.py b/gateway/platforms/matrix.py new file mode 100644 index 000000000..8431e31d6 --- /dev/null +++ b/gateway/platforms/matrix.py @@ -0,0 +1,842 @@ +"""Matrix gateway adapter. + +Connects to any Matrix homeserver (self-hosted or matrix.org) via the +matrix-nio Python SDK. Supports optional end-to-end encryption (E2EE) +when installed with ``pip install "matrix-nio[e2e]"``. + +Environment variables: + MATRIX_HOMESERVER Homeserver URL (e.g. https://matrix.example.org) + MATRIX_ACCESS_TOKEN Access token (preferred auth method) + MATRIX_USER_ID Full user ID (@bot:server) — required for password login + MATRIX_PASSWORD Password (alternative to access token) + MATRIX_ENCRYPTION Set "true" to enable E2EE + MATRIX_ALLOWED_USERS Comma-separated Matrix user IDs (@user:server) + MATRIX_HOME_ROOM Room ID for cron/notification delivery +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import mimetypes +import os +import re +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Set + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import ( + BasePlatformAdapter, + MessageEvent, + MessageType, + SendResult, +) + +logger = logging.getLogger(__name__) + +# Matrix message size limit (4000 chars practical, spec has no hard limit +# but clients render poorly above this). +MAX_MESSAGE_LENGTH = 4000 + +# Store directory for E2EE keys and sync state. +_STORE_DIR = Path.home() / ".hermes" / "matrix" / "store" + +# Grace period: ignore messages older than this many seconds before startup. +_STARTUP_GRACE_SECONDS = 5 + + +def check_matrix_requirements() -> bool: + """Return True if the Matrix adapter can be used.""" + token = os.getenv("MATRIX_ACCESS_TOKEN", "") + password = os.getenv("MATRIX_PASSWORD", "") + homeserver = os.getenv("MATRIX_HOMESERVER", "") + + if not token and not password: + logger.debug("Matrix: neither MATRIX_ACCESS_TOKEN nor MATRIX_PASSWORD set") + return False + if not homeserver: + logger.warning("Matrix: MATRIX_HOMESERVER not set") + return False + try: + import nio # noqa: F401 + return True + except ImportError: + logger.warning( + "Matrix: matrix-nio not installed. " + "Run: pip install 'matrix-nio[e2e]'" + ) + return False + + +class MatrixAdapter(BasePlatformAdapter): + """Gateway adapter for Matrix (any homeserver).""" + + def __init__(self, config: PlatformConfig): + super().__init__(config, Platform.MATRIX) + + self._homeserver: str = ( + config.extra.get("homeserver", "") + or os.getenv("MATRIX_HOMESERVER", "") + ).rstrip("/") + self._access_token: str = config.token or os.getenv("MATRIX_ACCESS_TOKEN", "") + self._user_id: str = ( + config.extra.get("user_id", "") + or os.getenv("MATRIX_USER_ID", "") + ) + self._password: str = ( + config.extra.get("password", "") + or os.getenv("MATRIX_PASSWORD", "") + ) + self._encryption: bool = config.extra.get( + "encryption", + os.getenv("MATRIX_ENCRYPTION", "").lower() in ("true", "1", "yes"), + ) + + self._client: Any = None # nio.AsyncClient + self._sync_task: Optional[asyncio.Task] = None + self._closing = False + self._startup_ts: float = 0.0 + + # Cache: room_id → bool (is DM) + self._dm_rooms: Dict[str, bool] = {} + # Set of room IDs we've joined + self._joined_rooms: Set[str] = set() + + # ------------------------------------------------------------------ + # Required overrides + # ------------------------------------------------------------------ + + async def connect(self) -> bool: + """Connect to the Matrix homeserver and start syncing.""" + import nio + + if not self._homeserver: + logger.error("Matrix: homeserver URL not configured") + return False + + # Determine store path and ensure it exists. + store_path = str(_STORE_DIR) + _STORE_DIR.mkdir(parents=True, exist_ok=True) + + # Create the client. + if self._encryption: + try: + client = nio.AsyncClient( + self._homeserver, + self._user_id or "", + store_path=store_path, + ) + logger.info("Matrix: E2EE enabled (store: %s)", store_path) + except Exception as exc: + logger.warning( + "Matrix: failed to create E2EE client (%s), " + "falling back to plain client. Install: " + "pip install 'matrix-nio[e2e]'", + exc, + ) + client = nio.AsyncClient(self._homeserver, self._user_id or "") + else: + client = nio.AsyncClient(self._homeserver, self._user_id or "") + + self._client = client + + # Authenticate. + if self._access_token: + client.access_token = self._access_token + # Resolve user_id if not set. + if not self._user_id: + resp = await client.whoami() + if isinstance(resp, nio.WhoamiResponse): + self._user_id = resp.user_id + client.user_id = resp.user_id + logger.info("Matrix: authenticated as %s", self._user_id) + else: + logger.error( + "Matrix: whoami failed — check MATRIX_ACCESS_TOKEN and MATRIX_HOMESERVER" + ) + await client.close() + return False + else: + client.user_id = self._user_id + logger.info("Matrix: using access token for %s", self._user_id) + elif self._password and self._user_id: + resp = await client.login( + self._password, + device_name="Hermes Agent", + ) + if isinstance(resp, nio.LoginResponse): + logger.info("Matrix: logged in as %s", self._user_id) + else: + logger.error("Matrix: login failed — %s", getattr(resp, "message", resp)) + await client.close() + return False + else: + logger.error("Matrix: need MATRIX_ACCESS_TOKEN or MATRIX_USER_ID + MATRIX_PASSWORD") + await client.close() + return False + + # If E2EE is enabled, load the crypto store. + if self._encryption and hasattr(client, "olm"): + try: + if client.should_upload_keys: + await client.keys_upload() + logger.info("Matrix: E2EE crypto initialized") + except Exception as exc: + logger.warning("Matrix: crypto init issue: %s", exc) + + # Register event callbacks. + client.add_event_callback(self._on_room_message, nio.RoomMessageText) + client.add_event_callback(self._on_room_message_media, nio.RoomMessageMedia) + client.add_event_callback(self._on_room_message_media, nio.RoomMessageImage) + client.add_event_callback(self._on_room_message_media, nio.RoomMessageAudio) + client.add_event_callback(self._on_room_message_media, nio.RoomMessageVideo) + client.add_event_callback(self._on_room_message_media, nio.RoomMessageFile) + client.add_event_callback(self._on_invite, nio.InviteMemberEvent) + + # If E2EE: handle encrypted events. + if self._encryption and hasattr(client, "olm"): + client.add_event_callback( + self._on_room_message, nio.MegolmEvent + ) + + # Initial sync to catch up, then start background sync. + self._startup_ts = time.time() + self._closing = False + + # Do an initial sync to populate room state. + resp = await client.sync(timeout=10000, full_state=True) + if isinstance(resp, nio.SyncResponse): + self._joined_rooms = set(resp.rooms.join.keys()) + logger.info( + "Matrix: initial sync complete, joined %d rooms", + len(self._joined_rooms), + ) + # Build DM room cache from m.direct account data. + await self._refresh_dm_cache() + else: + logger.warning("Matrix: initial sync returned %s", type(resp).__name__) + + # Start the sync loop. + self._sync_task = asyncio.create_task(self._sync_loop()) + self._mark_connected() + return True + + async def disconnect(self) -> None: + """Disconnect from Matrix.""" + self._closing = True + + if self._sync_task and not self._sync_task.done(): + self._sync_task.cancel() + try: + await self._sync_task + except (asyncio.CancelledError, Exception): + pass + + if self._client: + await self._client.close() + self._client = None + + logger.info("Matrix: disconnected") + + async def send( + self, + chat_id: str, + content: str, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Send a message to a Matrix room.""" + import nio + + if not content: + return SendResult(success=True) + + formatted = self.format_message(content) + chunks = self.truncate_message(formatted, MAX_MESSAGE_LENGTH) + + last_event_id = None + for chunk in chunks: + msg_content: Dict[str, Any] = { + "msgtype": "m.text", + "body": chunk, + } + + # Convert markdown to HTML for rich rendering. + html = self._markdown_to_html(chunk) + if html and html != chunk: + msg_content["format"] = "org.matrix.custom.html" + msg_content["formatted_body"] = html + + # Reply-to support. + if reply_to: + msg_content["m.relates_to"] = { + "m.in_reply_to": {"event_id": reply_to} + } + + # Thread support: if metadata has thread_id, send as threaded reply. + thread_id = (metadata or {}).get("thread_id") + if thread_id: + relates_to = msg_content.get("m.relates_to", {}) + relates_to["rel_type"] = "m.thread" + relates_to["event_id"] = thread_id + relates_to["is_falling_back"] = True + if reply_to and "m.in_reply_to" not in relates_to: + relates_to["m.in_reply_to"] = {"event_id": reply_to} + msg_content["m.relates_to"] = relates_to + + resp = await self._client.room_send( + chat_id, + "m.room.message", + msg_content, + ) + if isinstance(resp, nio.RoomSendResponse): + last_event_id = resp.event_id + else: + err = getattr(resp, "message", str(resp)) + logger.error("Matrix: failed to send to %s: %s", chat_id, err) + return SendResult(success=False, error=err) + + return SendResult(success=True, message_id=last_event_id) + + async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: + """Return room name and type (dm/group).""" + name = chat_id + chat_type = "group" + + if self._client: + room = self._client.rooms.get(chat_id) + if room: + name = room.display_name or room.canonical_alias or chat_id + # Use DM cache. + if self._dm_rooms.get(chat_id, False): + chat_type = "dm" + elif room.member_count == 2: + chat_type = "dm" + + return {"name": name, "type": chat_type} + + # ------------------------------------------------------------------ + # Optional overrides + # ------------------------------------------------------------------ + + async def send_typing( + self, chat_id: str, metadata: Optional[Dict[str, Any]] = None + ) -> None: + """Send a typing indicator.""" + if self._client: + try: + await self._client.room_typing(chat_id, typing_state=True, timeout=30000) + except Exception: + pass + + async def edit_message( + self, chat_id: str, message_id: str, content: str + ) -> SendResult: + """Edit an existing message (via m.replace).""" + import nio + + formatted = self.format_message(content) + msg_content: Dict[str, Any] = { + "msgtype": "m.text", + "body": f"* {formatted}", + "m.new_content": { + "msgtype": "m.text", + "body": formatted, + }, + "m.relates_to": { + "rel_type": "m.replace", + "event_id": message_id, + }, + } + + html = self._markdown_to_html(formatted) + if html and html != formatted: + msg_content["m.new_content"]["format"] = "org.matrix.custom.html" + msg_content["m.new_content"]["formatted_body"] = html + msg_content["format"] = "org.matrix.custom.html" + msg_content["formatted_body"] = f"* {html}" + + resp = await self._client.room_send(chat_id, "m.room.message", msg_content) + if isinstance(resp, nio.RoomSendResponse): + return SendResult(success=True, message_id=resp.event_id) + return SendResult(success=False, error=getattr(resp, "message", str(resp))) + + async def send_image( + self, + chat_id: str, + image_url: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Download an image URL and upload it to Matrix.""" + try: + # Try aiohttp first (always available), fall back to httpx + try: + import aiohttp as _aiohttp + async with _aiohttp.ClientSession() as http: + async with http.get(image_url, timeout=_aiohttp.ClientTimeout(total=30)) as resp: + resp.raise_for_status() + data = await resp.read() + ct = resp.content_type or "image/png" + fname = image_url.rsplit("/", 1)[-1].split("?")[0] or "image.png" + except ImportError: + import httpx + async with httpx.AsyncClient() as http: + resp = await http.get(image_url, follow_redirects=True, timeout=30) + resp.raise_for_status() + data = resp.content + ct = resp.headers.get("content-type", "image/png") + fname = image_url.rsplit("/", 1)[-1].split("?")[0] or "image.png" + except Exception as exc: + logger.warning("Matrix: failed to download image %s: %s", image_url, exc) + return await self.send(chat_id, f"{caption or ''}\n{image_url}".strip(), reply_to) + + return await self._upload_and_send(chat_id, data, fname, ct, "m.image", caption, reply_to, metadata) + + async def send_image_file( + self, + chat_id: str, + image_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Upload a local image file to Matrix.""" + return await self._send_local_file(chat_id, image_path, "m.image", caption, reply_to, metadata=metadata) + + async def send_document( + self, + chat_id: str, + file_path: str, + caption: Optional[str] = None, + file_name: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Upload a local file as a document.""" + return await self._send_local_file(chat_id, file_path, "m.file", caption, reply_to, file_name, metadata) + + async def send_voice( + self, + chat_id: str, + audio_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Upload an audio file as a voice message.""" + return await self._send_local_file(chat_id, audio_path, "m.audio", caption, reply_to, metadata=metadata) + + async def send_video( + self, + chat_id: str, + video_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Upload a video file.""" + return await self._send_local_file(chat_id, video_path, "m.video", caption, reply_to, metadata=metadata) + + def format_message(self, content: str) -> str: + """Pass-through — Matrix supports standard Markdown natively.""" + # Strip image markdown; media is uploaded separately. + content = re.sub(r"!\[([^\]]*)\]\(([^)]+)\)", r"\2", content) + return content + + # ------------------------------------------------------------------ + # File helpers + # ------------------------------------------------------------------ + + async def _upload_and_send( + self, + room_id: str, + data: bytes, + filename: str, + content_type: str, + msgtype: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Upload bytes to Matrix and send as a media message.""" + import nio + + # Upload to homeserver. + resp = await self._client.upload( + data, + content_type=content_type, + filename=filename, + ) + if not isinstance(resp, nio.UploadResponse): + err = getattr(resp, "message", str(resp)) + logger.error("Matrix: upload failed: %s", err) + return SendResult(success=False, error=err) + + mxc_url = resp.content_uri + + # Build media message content. + msg_content: Dict[str, Any] = { + "msgtype": msgtype, + "body": caption or filename, + "url": mxc_url, + "info": { + "mimetype": content_type, + "size": len(data), + }, + } + + if reply_to: + msg_content["m.relates_to"] = { + "m.in_reply_to": {"event_id": reply_to} + } + + thread_id = (metadata or {}).get("thread_id") + if thread_id: + relates_to = msg_content.get("m.relates_to", {}) + relates_to["rel_type"] = "m.thread" + relates_to["event_id"] = thread_id + relates_to["is_falling_back"] = True + msg_content["m.relates_to"] = relates_to + + resp2 = await self._client.room_send(room_id, "m.room.message", msg_content) + if isinstance(resp2, nio.RoomSendResponse): + return SendResult(success=True, message_id=resp2.event_id) + return SendResult(success=False, error=getattr(resp2, "message", str(resp2))) + + async def _send_local_file( + self, + room_id: str, + file_path: str, + msgtype: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + file_name: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Read a local file and upload it.""" + p = Path(file_path) + if not p.exists(): + return await self.send( + room_id, f"{caption or ''}\n(file not found: {file_path})", reply_to + ) + + fname = file_name or p.name + ct = mimetypes.guess_type(fname)[0] or "application/octet-stream" + data = p.read_bytes() + + return await self._upload_and_send(room_id, data, fname, ct, msgtype, caption, reply_to, metadata) + + # ------------------------------------------------------------------ + # Sync loop + # ------------------------------------------------------------------ + + async def _sync_loop(self) -> None: + """Continuously sync with the homeserver.""" + while not self._closing: + try: + await self._client.sync(timeout=30000) + except asyncio.CancelledError: + return + except Exception as exc: + if self._closing: + return + logger.warning("Matrix: sync error: %s — retrying in 5s", exc) + await asyncio.sleep(5) + + # ------------------------------------------------------------------ + # Event callbacks + # ------------------------------------------------------------------ + + async def _on_room_message(self, room: Any, event: Any) -> None: + """Handle incoming text messages (and decrypted megolm events).""" + import nio + + # Ignore own messages. + if event.sender == self._user_id: + return + + # Startup grace: ignore old messages from initial sync. + event_ts = getattr(event, "server_timestamp", 0) / 1000.0 + if event_ts and event_ts < self._startup_ts - _STARTUP_GRACE_SECONDS: + return + + # Handle decrypted MegolmEvents — extract the inner event. + if isinstance(event, nio.MegolmEvent): + # Failed to decrypt. + logger.warning( + "Matrix: could not decrypt event %s in %s", + event.event_id, room.room_id, + ) + return + + # Skip edits (m.replace relation). + source_content = getattr(event, "source", {}).get("content", {}) + relates_to = source_content.get("m.relates_to", {}) + if relates_to.get("rel_type") == "m.replace": + return + + body = getattr(event, "body", "") or "" + if not body: + return + + # Determine chat type. + is_dm = self._dm_rooms.get(room.room_id, False) + if not is_dm and room.member_count == 2: + is_dm = True + chat_type = "dm" if is_dm else "group" + + # Thread support. + thread_id = None + if relates_to.get("rel_type") == "m.thread": + thread_id = relates_to.get("event_id") + + # Reply-to detection. + reply_to = None + in_reply_to = relates_to.get("m.in_reply_to", {}) + if in_reply_to: + reply_to = in_reply_to.get("event_id") + + # Strip reply fallback from body (Matrix prepends "> ..." lines). + if reply_to and body.startswith("> "): + lines = body.split("\n") + stripped = [] + past_fallback = False + for line in lines: + if not past_fallback: + if line.startswith("> ") or line == ">": + continue + if line == "": + past_fallback = True + continue + past_fallback = True + stripped.append(line) + body = "\n".join(stripped) if stripped else body + + # Message type. + msg_type = MessageType.TEXT + if body.startswith("!") or body.startswith("/"): + msg_type = MessageType.COMMAND + + source = self.build_source( + chat_id=room.room_id, + chat_type=chat_type, + user_id=event.sender, + user_name=self._get_display_name(room, event.sender), + thread_id=thread_id, + ) + + msg_event = MessageEvent( + text=body, + message_type=msg_type, + source=source, + raw_message=getattr(event, "source", {}), + message_id=event.event_id, + reply_to=reply_to, + ) + + await self.handle_message(msg_event) + + async def _on_room_message_media(self, room: Any, event: Any) -> None: + """Handle incoming media messages (images, audio, video, files).""" + import nio + + # Ignore own messages. + if event.sender == self._user_id: + return + + # Startup grace. + event_ts = getattr(event, "server_timestamp", 0) / 1000.0 + if event_ts and event_ts < self._startup_ts - _STARTUP_GRACE_SECONDS: + return + + body = getattr(event, "body", "") or "" + url = getattr(event, "url", "") + + # Convert mxc:// to HTTP URL for downstream processing. + http_url = "" + if url and url.startswith("mxc://"): + http_url = self._mxc_to_http(url) + + # Determine message type from event class. + media_type = "document" + msg_type = MessageType.DOCUMENT + if isinstance(event, nio.RoomMessageImage): + msg_type = MessageType.PHOTO + media_type = "image" + elif isinstance(event, nio.RoomMessageAudio): + msg_type = MessageType.AUDIO + media_type = "audio" + elif isinstance(event, nio.RoomMessageVideo): + msg_type = MessageType.VIDEO + media_type = "video" + + is_dm = self._dm_rooms.get(room.room_id, False) + if not is_dm and room.member_count == 2: + is_dm = True + chat_type = "dm" if is_dm else "group" + + # Thread/reply detection. + source_content = getattr(event, "source", {}).get("content", {}) + relates_to = source_content.get("m.relates_to", {}) + thread_id = None + if relates_to.get("rel_type") == "m.thread": + thread_id = relates_to.get("event_id") + + source = self.build_source( + chat_id=room.room_id, + chat_type=chat_type, + user_id=event.sender, + user_name=self._get_display_name(room, event.sender), + thread_id=thread_id, + ) + + msg_event = MessageEvent( + text=body, + message_type=msg_type, + source=source, + raw_message=getattr(event, "source", {}), + message_id=event.event_id, + media_urls=[http_url] if http_url else None, + media_types=[media_type] if http_url else None, + ) + + await self.handle_message(msg_event) + + async def _on_invite(self, room: Any, event: Any) -> None: + """Auto-join rooms when invited.""" + import nio + + if not isinstance(event, nio.InviteMemberEvent): + return + + # Only process invites directed at us. + if event.state_key != self._user_id: + return + + if event.membership != "invite": + return + + logger.info( + "Matrix: invited to %s by %s — joining", + room.room_id, event.sender, + ) + try: + resp = await self._client.join(room.room_id) + if isinstance(resp, nio.JoinResponse): + self._joined_rooms.add(room.room_id) + logger.info("Matrix: joined %s", room.room_id) + # Refresh DM cache since new room may be a DM. + await self._refresh_dm_cache() + else: + logger.warning( + "Matrix: failed to join %s: %s", + room.room_id, getattr(resp, "message", resp), + ) + except Exception as exc: + logger.warning("Matrix: error joining %s: %s", room.room_id, exc) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + async def _refresh_dm_cache(self) -> None: + """Refresh the DM room cache from m.direct account data. + + Tries the account_data API first, then falls back to parsing + the sync response's account_data for robustness. + """ + if not self._client: + return + + dm_data: Optional[Dict] = None + + # Primary: try the dedicated account data endpoint. + try: + resp = await self._client.get_account_data("m.direct") + if hasattr(resp, "content"): + dm_data = resp.content + elif isinstance(resp, dict): + dm_data = resp + except Exception as exc: + logger.debug("Matrix: get_account_data('m.direct') failed: %s — trying sync fallback", exc) + + # Fallback: parse from the client's account_data store (populated by sync). + if dm_data is None: + try: + # matrix-nio stores account data events on the client object + ad = getattr(self._client, "account_data", None) + if ad and isinstance(ad, dict) and "m.direct" in ad: + event = ad["m.direct"] + if hasattr(event, "content"): + dm_data = event.content + elif isinstance(event, dict): + dm_data = event + except Exception: + pass + + if dm_data is None: + return + + dm_room_ids: Set[str] = set() + for user_id, rooms in dm_data.items(): + if isinstance(rooms, list): + dm_room_ids.update(rooms) + + self._dm_rooms = { + rid: (rid in dm_room_ids) + for rid in self._joined_rooms + } + + def _get_display_name(self, room: Any, user_id: str) -> str: + """Get a user's display name in a room, falling back to user_id.""" + if room and hasattr(room, "users"): + user = room.users.get(user_id) + if user and getattr(user, "display_name", None): + return user.display_name + # Strip the @...:server format to just the localpart. + if user_id.startswith("@") and ":" in user_id: + return user_id[1:].split(":")[0] + return user_id + + def _mxc_to_http(self, mxc_url: str) -> str: + """Convert mxc://server/media_id to an HTTP download URL.""" + # mxc://matrix.org/abc123 → https://matrix.org/_matrix/client/v1/media/download/matrix.org/abc123 + # Uses the authenticated client endpoint (spec v1.11+) instead of the + # deprecated /_matrix/media/v3/download/ path. + if not mxc_url.startswith("mxc://"): + return mxc_url + parts = mxc_url[6:] # strip mxc:// + # Use our homeserver for download (federation handles the rest). + return f"{self._homeserver}/_matrix/client/v1/media/download/{parts}" + + def _markdown_to_html(self, text: str) -> str: + """Convert Markdown to Matrix-compatible HTML. + + Uses a simple conversion for common patterns. For full fidelity + a markdown-it style library could be used, but this covers the + common cases without an extra dependency. + """ + try: + import markdown + html = markdown.markdown( + text, + extensions=["fenced_code", "tables", "nl2br"], + ) + # Strip wrapping

tags for single-paragraph messages. + if html.count("

") == 1: + html = html.replace("

", "").replace("

", "") + return html + except ImportError: + pass + + # Minimal fallback: just handle bold, italic, code. + html = text + html = re.sub(r"\*\*(.+?)\*\*", r"\1", html) + html = re.sub(r"\*(.+?)\*", r"\1", html) + html = re.sub(r"`([^`]+)`", r"\1", html) + html = re.sub(r"\n", r"
", html) + return html diff --git a/gateway/platforms/mattermost.py b/gateway/platforms/mattermost.py new file mode 100644 index 000000000..ef1d5b838 --- /dev/null +++ b/gateway/platforms/mattermost.py @@ -0,0 +1,664 @@ +"""Mattermost gateway adapter. + +Connects to a self-hosted (or cloud) Mattermost instance via its REST API +(v4) and WebSocket for real-time events. No external Mattermost library +required — uses aiohttp which is already a Hermes dependency. + +Environment variables: + MATTERMOST_URL Server URL (e.g. https://mm.example.com) + MATTERMOST_TOKEN Bot token or personal-access token + MATTERMOST_ALLOWED_USERS Comma-separated user IDs + MATTERMOST_HOME_CHANNEL Channel ID for cron/notification delivery +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import os +import re +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import ( + BasePlatformAdapter, + MessageEvent, + MessageType, + SendResult, +) + +logger = logging.getLogger(__name__) + +# Mattermost post size limit (server default is 16383, but 4000 is the +# practical limit for readable messages — matching OpenClaw's choice). +MAX_POST_LENGTH = 4000 + +# Channel type codes returned by the Mattermost API. +_CHANNEL_TYPE_MAP = { + "D": "dm", + "G": "group", + "P": "group", # private channel → treat as group + "O": "channel", +} + +# Reconnect parameters (exponential backoff). +_RECONNECT_BASE_DELAY = 2.0 +_RECONNECT_MAX_DELAY = 60.0 +_RECONNECT_JITTER = 0.2 + + +def check_mattermost_requirements() -> bool: + """Return True if the Mattermost adapter can be used.""" + token = os.getenv("MATTERMOST_TOKEN", "") + url = os.getenv("MATTERMOST_URL", "") + if not token: + logger.debug("Mattermost: MATTERMOST_TOKEN not set") + return False + if not url: + logger.warning("Mattermost: MATTERMOST_URL not set") + return False + try: + import aiohttp # noqa: F401 + return True + except ImportError: + logger.warning("Mattermost: aiohttp not installed") + return False + + +class MattermostAdapter(BasePlatformAdapter): + """Gateway adapter for Mattermost (self-hosted or cloud).""" + + def __init__(self, config: PlatformConfig): + super().__init__(config, Platform.MATTERMOST) + + self._base_url: str = ( + config.extra.get("url", "") + or os.getenv("MATTERMOST_URL", "") + ).rstrip("/") + self._token: str = config.token or os.getenv("MATTERMOST_TOKEN", "") + + self._bot_user_id: str = "" + self._bot_username: str = "" + + # aiohttp session + websocket handle + self._session: Any = None # aiohttp.ClientSession + self._ws: Any = None # aiohttp.ClientWebSocketResponse + self._ws_task: Optional[asyncio.Task] = None + self._reconnect_task: Optional[asyncio.Task] = None + self._closing = False + + # Reply mode: "thread" to nest replies, "off" for flat messages. + self._reply_mode: str = ( + config.extra.get("reply_mode", "") + or os.getenv("MATTERMOST_REPLY_MODE", "off") + ).lower() + + # Dedup cache: post_id → timestamp (prevent reprocessing) + self._seen_posts: Dict[str, float] = {} + self._SEEN_MAX = 2000 + self._SEEN_TTL = 300 # 5 minutes + + # ------------------------------------------------------------------ + # HTTP helpers + # ------------------------------------------------------------------ + + def _headers(self) -> Dict[str, str]: + return { + "Authorization": f"Bearer {self._token}", + "Content-Type": "application/json", + } + + async def _api_get(self, path: str) -> Dict[str, Any]: + """GET /api/v4/{path}.""" + import aiohttp + url = f"{self._base_url}/api/v4/{path.lstrip('/')}" + try: + async with self._session.get(url, headers=self._headers()) as resp: + if resp.status >= 400: + body = await resp.text() + logger.error("MM API GET %s → %s: %s", path, resp.status, body[:200]) + return {} + return await resp.json() + except aiohttp.ClientError as exc: + logger.error("MM API GET %s network error: %s", path, exc) + return {} + + async def _api_post( + self, path: str, payload: Dict[str, Any] + ) -> Dict[str, Any]: + """POST /api/v4/{path} with JSON body.""" + import aiohttp + url = f"{self._base_url}/api/v4/{path.lstrip('/')}" + try: + async with self._session.post( + url, headers=self._headers(), json=payload + ) as resp: + if resp.status >= 400: + body = await resp.text() + logger.error("MM API POST %s → %s: %s", path, resp.status, body[:200]) + return {} + return await resp.json() + except aiohttp.ClientError as exc: + logger.error("MM API POST %s network error: %s", path, exc) + return {} + + async def _api_put( + self, path: str, payload: Dict[str, Any] + ) -> Dict[str, Any]: + """PUT /api/v4/{path} with JSON body.""" + import aiohttp + url = f"{self._base_url}/api/v4/{path.lstrip('/')}" + try: + async with self._session.put( + url, headers=self._headers(), json=payload + ) as resp: + if resp.status >= 400: + body = await resp.text() + logger.error("MM API PUT %s → %s: %s", path, resp.status, body[:200]) + return {} + return await resp.json() + except aiohttp.ClientError as exc: + logger.error("MM API PUT %s network error: %s", path, exc) + return {} + + async def _upload_file( + self, channel_id: str, file_data: bytes, filename: str, content_type: str = "application/octet-stream" + ) -> Optional[str]: + """Upload a file and return its file ID, or None on failure.""" + import aiohttp + + url = f"{self._base_url}/api/v4/files" + form = aiohttp.FormData() + form.add_field("channel_id", channel_id) + form.add_field( + "files", + file_data, + filename=filename, + content_type=content_type, + ) + headers = {"Authorization": f"Bearer {self._token}"} + async with self._session.post(url, headers=headers, data=form) as resp: + if resp.status >= 400: + body = await resp.text() + logger.error("MM file upload → %s: %s", resp.status, body[:200]) + return None + data = await resp.json() + infos = data.get("file_infos", []) + return infos[0]["id"] if infos else None + + # ------------------------------------------------------------------ + # Required overrides + # ------------------------------------------------------------------ + + async def connect(self) -> bool: + """Connect to Mattermost and start the WebSocket listener.""" + import aiohttp + + if not self._base_url or not self._token: + logger.error("Mattermost: URL or token not configured") + return False + + self._session = aiohttp.ClientSession() + self._closing = False + + # Verify credentials and fetch bot identity. + me = await self._api_get("users/me") + if not me or "id" not in me: + logger.error("Mattermost: failed to authenticate — check MATTERMOST_TOKEN and MATTERMOST_URL") + await self._session.close() + return False + + self._bot_user_id = me["id"] + self._bot_username = me.get("username", "") + logger.info( + "Mattermost: authenticated as @%s (%s) on %s", + self._bot_username, + self._bot_user_id, + self._base_url, + ) + + # Start WebSocket in background. + self._ws_task = asyncio.create_task(self._ws_loop()) + self._mark_connected() + return True + + async def disconnect(self) -> None: + """Disconnect from Mattermost.""" + self._closing = True + + if self._ws_task and not self._ws_task.done(): + self._ws_task.cancel() + try: + await self._ws_task + except (asyncio.CancelledError, Exception): + pass + + if self._reconnect_task and not self._reconnect_task.done(): + self._reconnect_task.cancel() + + if self._ws: + await self._ws.close() + self._ws = None + + if self._session and not self._session.closed: + await self._session.close() + + logger.info("Mattermost: disconnected") + + async def send( + self, + chat_id: str, + content: str, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Send a message (or multiple chunks) to a channel.""" + if not content: + return SendResult(success=True) + + formatted = self.format_message(content) + chunks = self.truncate_message(formatted, MAX_POST_LENGTH) + + last_id = None + for chunk in chunks: + payload: Dict[str, Any] = { + "channel_id": chat_id, + "message": chunk, + } + # Thread support: reply_to is the root post ID. + if reply_to and self._reply_mode == "thread": + payload["root_id"] = reply_to + + data = await self._api_post("posts", payload) + if not data or "id" not in data: + return SendResult(success=False, error="Failed to create post") + last_id = data["id"] + + return SendResult(success=True, message_id=last_id) + + async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: + """Return channel name and type.""" + data = await self._api_get(f"channels/{chat_id}") + if not data: + return {"name": chat_id, "type": "channel"} + + ch_type = _CHANNEL_TYPE_MAP.get(data.get("type", "O"), "channel") + display_name = data.get("display_name") or data.get("name") or chat_id + return {"name": display_name, "type": ch_type} + + # ------------------------------------------------------------------ + # Optional overrides + # ------------------------------------------------------------------ + + async def send_typing( + self, chat_id: str, metadata: Optional[Dict[str, Any]] = None + ) -> None: + """Send a typing indicator.""" + await self._api_post( + f"users/{self._bot_user_id}/typing", + {"channel_id": chat_id}, + ) + + async def edit_message( + self, chat_id: str, message_id: str, content: str + ) -> SendResult: + """Edit an existing post.""" + formatted = self.format_message(content) + data = await self._api_put( + f"posts/{message_id}/patch", + {"message": formatted}, + ) + if not data or "id" not in data: + return SendResult(success=False, error="Failed to edit post") + return SendResult(success=True, message_id=data["id"]) + + async def send_image( + self, + chat_id: str, + image_url: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Download an image and upload it as a file attachment.""" + return await self._send_url_as_file( + chat_id, image_url, caption, reply_to, "image" + ) + + async def send_image_file( + self, + chat_id: str, + image_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Upload a local image file.""" + return await self._send_local_file( + chat_id, image_path, caption, reply_to + ) + + async def send_document( + self, + chat_id: str, + file_path: str, + caption: Optional[str] = None, + file_name: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Upload a local file as a document.""" + return await self._send_local_file( + chat_id, file_path, caption, reply_to, file_name + ) + + async def send_voice( + self, + chat_id: str, + audio_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Upload an audio file.""" + return await self._send_local_file( + chat_id, audio_path, caption, reply_to + ) + + async def send_video( + self, + chat_id: str, + video_path: str, + caption: Optional[str] = None, + reply_to: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Upload a video file.""" + return await self._send_local_file( + chat_id, video_path, caption, reply_to + ) + + def format_message(self, content: str) -> str: + """Mattermost uses standard Markdown — mostly pass through. + + Strip image markdown into plain links (files are uploaded separately). + """ + # Convert ![alt](url) to just the URL — Mattermost renders + # image URLs as inline previews automatically. + content = re.sub(r"!\[([^\]]*)\]\(([^)]+)\)", r"\2", content) + return content + + # ------------------------------------------------------------------ + # File helpers + # ------------------------------------------------------------------ + + async def _send_url_as_file( + self, + chat_id: str, + url: str, + caption: Optional[str], + reply_to: Optional[str], + kind: str = "file", + ) -> SendResult: + """Download a URL and upload it as a file attachment.""" + import aiohttp + try: + async with self._session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp: + if resp.status >= 400: + # Fall back to sending the URL as text. + return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to) + file_data = await resp.read() + ct = resp.content_type or "application/octet-stream" + # Derive filename from URL. + fname = url.rsplit("/", 1)[-1].split("?")[0] or f"{kind}.png" + except Exception as exc: + logger.warning("Mattermost: failed to download %s: %s", url, exc) + return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to) + + file_id = await self._upload_file(chat_id, file_data, fname, ct) + if not file_id: + return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to) + + payload: Dict[str, Any] = { + "channel_id": chat_id, + "message": caption or "", + "file_ids": [file_id], + } + if reply_to and self._reply_mode == "thread": + payload["root_id"] = reply_to + + data = await self._api_post("posts", payload) + if not data or "id" not in data: + return SendResult(success=False, error="Failed to post with file") + return SendResult(success=True, message_id=data["id"]) + + async def _send_local_file( + self, + chat_id: str, + file_path: str, + caption: Optional[str], + reply_to: Optional[str], + file_name: Optional[str] = None, + ) -> SendResult: + """Upload a local file and attach it to a post.""" + import mimetypes + + p = Path(file_path) + if not p.exists(): + return await self.send( + chat_id, f"{caption or ''}\n(file not found: {file_path})", reply_to + ) + + fname = file_name or p.name + ct = mimetypes.guess_type(fname)[0] or "application/octet-stream" + file_data = p.read_bytes() + + file_id = await self._upload_file(chat_id, file_data, fname, ct) + if not file_id: + return SendResult(success=False, error="File upload failed") + + payload: Dict[str, Any] = { + "channel_id": chat_id, + "message": caption or "", + "file_ids": [file_id], + } + if reply_to and self._reply_mode == "thread": + payload["root_id"] = reply_to + + data = await self._api_post("posts", payload) + if not data or "id" not in data: + return SendResult(success=False, error="Failed to post with file") + return SendResult(success=True, message_id=data["id"]) + + # ------------------------------------------------------------------ + # WebSocket + # ------------------------------------------------------------------ + + async def _ws_loop(self) -> None: + """Connect to the WebSocket and listen for events, reconnecting on failure.""" + delay = _RECONNECT_BASE_DELAY + while not self._closing: + try: + await self._ws_connect_and_listen() + # Clean disconnect — reset delay. + delay = _RECONNECT_BASE_DELAY + except asyncio.CancelledError: + return + except Exception as exc: + if self._closing: + return + logger.warning("Mattermost WS error: %s — reconnecting in %.0fs", exc, delay) + + if self._closing: + return + + # Exponential backoff with jitter. + import random + jitter = delay * _RECONNECT_JITTER * random.random() + await asyncio.sleep(delay + jitter) + delay = min(delay * 2, _RECONNECT_MAX_DELAY) + + async def _ws_connect_and_listen(self) -> None: + """Single WebSocket session: connect, authenticate, process events.""" + # Build WS URL: https:// → wss://, http:// → ws:// + ws_url = re.sub(r"^http", "ws", self._base_url) + "/api/v4/websocket" + logger.info("Mattermost: connecting to %s", ws_url) + + self._ws = await self._session.ws_connect(ws_url, heartbeat=30.0) + + # Authenticate via the WebSocket. + auth_msg = { + "seq": 1, + "action": "authentication_challenge", + "data": {"token": self._token}, + } + await self._ws.send_json(auth_msg) + logger.info("Mattermost: WebSocket connected and authenticated") + + async for raw_msg in self._ws: + if self._closing: + return + + if raw_msg.type in ( + raw_msg.type.TEXT, + raw_msg.type.BINARY, + ): + try: + event = json.loads(raw_msg.data) + except (json.JSONDecodeError, TypeError): + continue + await self._handle_ws_event(event) + elif raw_msg.type in ( + raw_msg.type.ERROR, + raw_msg.type.CLOSE, + raw_msg.type.CLOSING, + raw_msg.type.CLOSED, + ): + logger.info("Mattermost: WebSocket closed (%s)", raw_msg.type) + break + + async def _handle_ws_event(self, event: Dict[str, Any]) -> None: + """Process a single WebSocket event.""" + event_type = event.get("event") + if event_type != "posted": + return + + data = event.get("data", {}) + raw_post_str = data.get("post") + if not raw_post_str: + return + + try: + post = json.loads(raw_post_str) + except (json.JSONDecodeError, TypeError): + return + + # Ignore own messages. + if post.get("user_id") == self._bot_user_id: + return + + # Ignore system posts. + if post.get("type"): + return + + post_id = post.get("id", "") + + # Dedup. + self._prune_seen() + if post_id in self._seen_posts: + return + self._seen_posts[post_id] = time.time() + + # Build message event. + channel_id = post.get("channel_id", "") + channel_type_raw = data.get("channel_type", "O") + chat_type = _CHANNEL_TYPE_MAP.get(channel_type_raw, "channel") + + # For DMs, user_id is sufficient. For channels, check for @mention. + message_text = post.get("message", "") + + # Resolve sender info. + sender_id = post.get("user_id", "") + sender_name = data.get("sender_name", "").lstrip("@") or sender_id + + # Thread support: if the post is in a thread, use root_id. + thread_id = post.get("root_id") or None + + # Determine message type. + file_ids = post.get("file_ids") or [] + msg_type = MessageType.TEXT + if message_text.startswith("/"): + msg_type = MessageType.COMMAND + + # Download file attachments immediately (URLs require auth headers + # that downstream tools won't have). + media_urls: List[str] = [] + media_types: List[str] = [] + for fid in file_ids: + try: + file_info = await self._api_get(f"files/{fid}/info") + fname = file_info.get("name", f"file_{fid}") + ext = Path(fname).suffix or "" + mime = file_info.get("mime_type", "application/octet-stream") + + import aiohttp + dl_url = f"{self._base_url}/api/v4/files/{fid}" + async with self._session.get( + dl_url, + headers={"Authorization": f"Bearer {self._token}"}, + timeout=aiohttp.ClientTimeout(total=30), + ) as resp: + if resp.status < 400: + file_data = await resp.read() + from gateway.platforms.base import cache_image_from_bytes, cache_document_from_bytes + if mime.startswith("image/"): + local_path = cache_image_from_bytes(file_data, ext or ".png") + media_urls.append(local_path) + media_types.append("image") + elif mime.startswith("audio/"): + from gateway.platforms.base import cache_audio_from_bytes + local_path = cache_audio_from_bytes(file_data, ext or ".ogg") + media_urls.append(local_path) + media_types.append("audio") + else: + local_path = cache_document_from_bytes(file_data, fname) + media_urls.append(local_path) + media_types.append("document") + else: + logger.warning("Mattermost: failed to download file %s: HTTP %s", fid, resp.status) + except Exception as exc: + logger.warning("Mattermost: error downloading file %s: %s", fid, exc) + + source = self.build_source( + chat_id=channel_id, + chat_type=chat_type, + user_id=sender_id, + user_name=sender_name, + thread_id=thread_id, + ) + + msg_event = MessageEvent( + text=message_text, + message_type=msg_type, + source=source, + raw_message=post, + message_id=post_id, + media_urls=media_urls if media_urls else None, + media_types=media_types if media_types else None, + ) + + await self.handle_message(msg_event) + + def _prune_seen(self) -> None: + """Remove expired entries from the dedup cache.""" + if len(self._seen_posts) < self._SEEN_MAX: + return + now = time.time() + self._seen_posts = { + pid: ts + for pid, ts in self._seen_posts.items() + if now - ts < self._SEEN_TTL + } diff --git a/gateway/platforms/sms.py b/gateway/platforms/sms.py index f83ecaf97..03e2475e7 100644 --- a/gateway/platforms/sms.py +++ b/gateway/platforms/sms.py @@ -1,19 +1,27 @@ -"""SMS (Telnyx) platform adapter. +"""SMS (Twilio) platform adapter. -Connects to the Telnyx REST API for outbound SMS and runs an aiohttp +Connects to the Twilio REST API for outbound SMS and runs an aiohttp webhook server to receive inbound messages. -Requires: - - aiohttp installed: pip install 'hermes-agent[sms]' - - TELNYX_API_KEY environment variable set - - TELNYX_FROM_NUMBERS: comma-separated E.164 numbers (e.g. +15551234567) +Shares credentials with the optional telephony skill — same env vars: + - TWILIO_ACCOUNT_SID + - TWILIO_AUTH_TOKEN + - TWILIO_PHONE_NUMBER (E.164 from-number, e.g. +15551234567) + +Gateway-specific env vars: + - SMS_WEBHOOK_PORT (default 8080) + - SMS_ALLOWED_USERS (comma-separated E.164 phone numbers) + - SMS_ALLOW_ALL_USERS (true/false) + - SMS_HOME_CHANNEL (phone number for cron delivery) """ import asyncio +import base64 import json import logging import os import re +import urllib.parse from typing import Any, Dict, List, Optional from gateway.config import Platform, PlatformConfig @@ -26,7 +34,7 @@ from gateway.platforms.base import ( logger = logging.getLogger(__name__) -TELNYX_BASE = "https://api.telnyx.com/v2" +TWILIO_API_BASE = "https://api.twilio.com/2010-04-01/Accounts" MAX_SMS_LENGTH = 1600 # ~10 SMS segments DEFAULT_WEBHOOK_PORT = 8080 @@ -35,17 +43,12 @@ _PHONE_RE = re.compile(r"\+[1-9]\d{6,14}") def _redact_phone(phone: str) -> str: - """Redact a phone number for logging: +15551234567 -> +155****4567.""" + """Redact a phone number for logging: +15551234567 -> +1555***4567.""" if not phone: return "" if len(phone) <= 8: - return phone[:2] + "****" + phone[-2:] if len(phone) > 4 else "****" - return phone[:4] + "****" + phone[-4:] - - -def _parse_comma_list(value: str) -> List[str]: - """Split a comma-separated string into a list, stripping whitespace.""" - return [v.strip() for v in value.split(",") if v.strip()] + return phone[:2] + "***" + phone[-2:] if len(phone) > 4 else "****" + return phone[:5] + "***" + phone[-4:] def check_sms_requirements() -> bool: @@ -54,32 +57,35 @@ def check_sms_requirements() -> bool: import aiohttp # noqa: F401 except ImportError: return False - return bool(os.getenv("TELNYX_API_KEY")) + return bool(os.getenv("TWILIO_ACCOUNT_SID") and os.getenv("TWILIO_AUTH_TOKEN")) class SmsAdapter(BasePlatformAdapter): """ - Telnyx SMS <-> Hermes gateway adapter. + Twilio SMS <-> Hermes gateway adapter. Each inbound phone number gets its own Hermes session (multi-tenant). - Tracks which owned number received each user's message to reply from - the same number. + Replies are always sent from the configured TWILIO_PHONE_NUMBER. """ + MAX_MESSAGE_LENGTH = MAX_SMS_LENGTH + def __init__(self, config: PlatformConfig): super().__init__(config, Platform.SMS) - self._api_key: str = os.environ["TELNYX_API_KEY"] + self._account_sid: str = os.environ["TWILIO_ACCOUNT_SID"] + self._auth_token: str = os.environ["TWILIO_AUTH_TOKEN"] + self._from_number: str = os.getenv("TWILIO_PHONE_NUMBER", "") self._webhook_port: int = int( os.getenv("SMS_WEBHOOK_PORT", str(DEFAULT_WEBHOOK_PORT)) ) - # Set of owned numbers - self._from_numbers: set = set( - _parse_comma_list(os.getenv("TELNYX_FROM_NUMBERS", "")) - ) - # Runtime map: user phone -> which owned number to reply from - self._reply_from: Dict[str, str] = {} self._runner = None + def _basic_auth_header(self) -> str: + """Build HTTP Basic auth header value for Twilio.""" + creds = f"{self._account_sid}:{self._auth_token}" + encoded = base64.b64encode(creds.encode("ascii")).decode("ascii") + return f"Basic {encoded}" + # ------------------------------------------------------------------ # Required abstract methods # ------------------------------------------------------------------ @@ -88,8 +94,12 @@ class SmsAdapter(BasePlatformAdapter): import aiohttp from aiohttp import web + if not self._from_number: + logger.error("[sms] TWILIO_PHONE_NUMBER not set — cannot send replies") + return False + app = web.Application() - app.router.add_post("/webhooks/telnyx", self._handle_webhook) + app.router.add_post("/webhooks/twilio", self._handle_webhook) app.router.add_get("/health", lambda _: web.Response(text="ok")) self._runner = web.AppRunner(app) @@ -98,11 +108,10 @@ class SmsAdapter(BasePlatformAdapter): await site.start() self._running = True - from_display = ", ".join(_redact_phone(n) for n in self._from_numbers) or "(none)" logger.info( - "[sms] Webhook server listening on port %d, from numbers: %s", + "[sms] Twilio webhook server listening on port %d, from: %s", self._webhook_port, - from_display, + _redact_phone(self._from_number), ) return True @@ -122,40 +131,41 @@ class SmsAdapter(BasePlatformAdapter): ) -> SendResult: import aiohttp - from_number = self._get_reply_from(chat_id, metadata) formatted = self.format_message(content) chunks = self.truncate_message(formatted) last_result = SendResult(success=True) + url = f"{TWILIO_API_BASE}/{self._account_sid}/Messages.json" + headers = { + "Authorization": self._basic_auth_header(), + } + async with aiohttp.ClientSession() as session: - for i, chunk in enumerate(chunks): - payload = {"from": from_number, "to": chat_id, "text": chunk} - headers = { - "Authorization": f"Bearer {self._api_key}", - "Content-Type": "application/json", - } + for chunk in chunks: + form_data = aiohttp.FormData() + form_data.add_field("From", self._from_number) + form_data.add_field("To", chat_id) + form_data.add_field("Body", chunk) + try: - async with session.post( - f"{TELNYX_BASE}/messages", - json=payload, - headers=headers, - ) as resp: + async with session.post(url, data=form_data, headers=headers) as resp: body = await resp.json() if resp.status >= 400: + error_msg = body.get("message", str(body)) logger.error( - "[sms] send failed %s: %s %s", + "[sms] send failed to %s: %s %s", _redact_phone(chat_id), resp.status, - body, + error_msg, ) return SendResult( success=False, - error=f"Telnyx {resp.status}: {body}", + error=f"Twilio {resp.status}: {error_msg}", ) - msg_id = body.get("data", {}).get("id", "") - last_result = SendResult(success=True, message_id=msg_id) + msg_sid = body.get("sid", "") + last_result = SendResult(success=True, message_id=msg_sid) except Exception as e: - logger.error("[sms] send error %s: %s", _redact_phone(chat_id), e) + logger.error("[sms] send error to %s: %s", _redact_phone(chat_id), e) return SendResult(success=False, error=str(e)) return last_result @@ -168,7 +178,7 @@ class SmsAdapter(BasePlatformAdapter): # ------------------------------------------------------------------ def format_message(self, content: str) -> str: - """Strip markdown -- SMS renders it as literal characters.""" + """Strip markdown — SMS renders it as literal characters.""" content = re.sub(r"\*\*(.+?)\*\*", r"\1", content, flags=re.DOTALL) content = re.sub(r"\*(.+?)\*", r"\1", content, flags=re.DOTALL) content = re.sub(r"__(.+?)__", r"\1", content, flags=re.DOTALL) @@ -180,28 +190,8 @@ class SmsAdapter(BasePlatformAdapter): content = re.sub(r"\n{3,}", "\n\n", content) return content.strip() - def truncate_message( - self, content: str, max_length: int = MAX_SMS_LENGTH - ) -> List[str]: - """Split into <=1600-char chunks (10 SMS segments).""" - if len(content) <= max_length: - return [content] - chunks: List[str] = [] - while content: - if len(content) <= max_length: - chunks.append(content) - break - split_at = content.rfind("\n", 0, max_length) - if split_at < max_length // 2: - split_at = content.rfind(" ", 0, max_length) - if split_at < 1: - split_at = max_length - chunks.append(content[:split_at].strip()) - content = content[split_at:].strip() - return chunks - # ------------------------------------------------------------------ - # Telnyx webhook handler + # Twilio webhook handler # ------------------------------------------------------------------ async def _handle_webhook(self, request) -> "aiohttp.web.Response": @@ -209,32 +199,35 @@ class SmsAdapter(BasePlatformAdapter): try: raw = await request.read() - body = json.loads(raw.decode("utf-8")) + # Twilio sends form-encoded data, not JSON + form = urllib.parse.parse_qs(raw.decode("utf-8")) except Exception as e: logger.error("[sms] webhook parse error: %s", e) - return web.json_response({"error": "invalid json"}, status=400) + return web.Response( + text='', + content_type="application/xml", + status=400, + ) - # Only handle inbound messages - if body.get("data", {}).get("event_type") != "message.received": - return web.json_response({"received": True}) - - payload = body["data"]["payload"] - from_number: str = payload.get("from", {}).get("phone_number", "") - to_list = payload.get("to", []) - to_number: str = to_list[0].get("phone_number", "") if to_list else "" - text: str = payload.get("text", "").strip() + # Extract fields (parse_qs returns lists) + from_number = (form.get("From", [""]))[0].strip() + to_number = (form.get("To", [""]))[0].strip() + text = (form.get("Body", [""]))[0].strip() + message_sid = (form.get("MessageSid", [""]))[0].strip() if not from_number or not text: - return web.json_response({"received": True}) + return web.Response( + text='', + content_type="application/xml", + ) - # Ignore messages sent FROM one of our own numbers (echo loop prevention) - if from_number in self._from_numbers: + # Ignore messages from our own number (echo prevention) + if from_number == self._from_number: logger.debug("[sms] ignoring echo from own number %s", _redact_phone(from_number)) - return web.json_response({"received": True}) - - # Remember which owned number received this user's message - if to_number and to_number in self._from_numbers: - self._reply_from[from_number] = to_number + return web.Response( + text='', + content_type="application/xml", + ) logger.info( "[sms] inbound from %s -> %s: %s", @@ -254,29 +247,15 @@ class SmsAdapter(BasePlatformAdapter): text=text, message_type=MessageType.TEXT, source=source, - raw_message=body, - message_id=payload.get("id"), + raw_message=form, + message_id=message_sid, ) - # Non-blocking: Telnyx expects a fast 200 + # Non-blocking: Twilio expects a fast response asyncio.create_task(self.handle_message(event)) - return web.json_response({"received": True}) - # ------------------------------------------------------------------ - # Internal helpers - # ------------------------------------------------------------------ - - def _get_reply_from( - self, user_phone: str, metadata: Optional[Dict] = None - ) -> str: - """Determine which owned number to send from.""" - if metadata and "from_number" in metadata: - return metadata["from_number"] - if user_phone in self._reply_from: - return self._reply_from[user_phone] - if self._from_numbers: - return next(iter(self._from_numbers)) - raise RuntimeError( - "No FROM number configured (TELNYX_FROM_NUMBERS) and no prior " - "reply_from mapping for this user" + # Return empty TwiML — we send replies via the REST API, not inline TwiML + return web.Response( + text='', + content_type="application/xml", ) diff --git a/gateway/run.py b/gateway/run.py index ceb7d92fd..c820f2b06 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -848,7 +848,8 @@ class GatewayRunner: os.getenv(v) for v in ("TELEGRAM_ALLOWED_USERS", "DISCORD_ALLOWED_USERS", "WHATSAPP_ALLOWED_USERS", "SLACK_ALLOWED_USERS", - "SMS_ALLOWED_USERS", "GATEWAY_ALLOWED_USERS") + "SMS_ALLOWED_USERS", + "GATEWAY_ALLOWED_USERS") ) _allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes") if not _any_allowlist and not _allow_all: @@ -983,6 +984,16 @@ class GatewayRunner: ): self._schedule_update_notification_watch() + # Drain any recovered process watchers (from crash recovery checkpoint) + try: + from tools.process_registry import process_registry + while process_registry.pending_watchers: + watcher = process_registry.pending_watchers.pop(0) + asyncio.create_task(self._run_process_watcher(watcher)) + logger.info("Resumed watcher for recovered process %s", watcher.get("session_id")) + except Exception as e: + logger.error("Recovered watcher setup error: %s", e) + # Start background session expiry watcher for proactive memory flushing asyncio.create_task(self._session_expiry_watcher()) @@ -1135,10 +1146,31 @@ class GatewayRunner: elif platform == Platform.SMS: from gateway.platforms.sms import SmsAdapter, check_sms_requirements if not check_sms_requirements(): - logger.warning("SMS: aiohttp not installed or TELNYX_API_KEY not set. Run: pip install 'hermes-agent[sms]'") + logger.warning("SMS: aiohttp not installed or TWILIO_ACCOUNT_SID/TWILIO_AUTH_TOKEN not set") return None return SmsAdapter(config) + elif platform == Platform.DINGTALK: + from gateway.platforms.dingtalk import DingTalkAdapter, check_dingtalk_requirements + if not check_dingtalk_requirements(): + logger.warning("DingTalk: dingtalk-stream not installed or DINGTALK_CLIENT_ID/SECRET not set") + return None + return DingTalkAdapter(config) + + elif platform == Platform.MATTERMOST: + from gateway.platforms.mattermost import MattermostAdapter, check_mattermost_requirements + if not check_mattermost_requirements(): + logger.warning("Mattermost: MATTERMOST_TOKEN or MATTERMOST_URL not set, or aiohttp missing") + return None + return MattermostAdapter(config) + + elif platform == Platform.MATRIX: + from gateway.platforms.matrix import MatrixAdapter, check_matrix_requirements + if not check_matrix_requirements(): + logger.warning("Matrix: matrix-nio not installed or credentials not set. Run: pip install 'matrix-nio[e2e]'") + return None + return MatrixAdapter(config) + return None def _is_user_authorized(self, source: SessionSource) -> bool: @@ -1170,6 +1202,9 @@ class GatewayRunner: Platform.SIGNAL: "SIGNAL_ALLOWED_USERS", Platform.EMAIL: "EMAIL_ALLOWED_USERS", Platform.SMS: "SMS_ALLOWED_USERS", + Platform.MATTERMOST: "MATTERMOST_ALLOWED_USERS", + Platform.MATRIX: "MATRIX_ALLOWED_USERS", + Platform.DINGTALK: "DINGTALK_ALLOWED_USERS", } platform_allow_all_map = { Platform.TELEGRAM: "TELEGRAM_ALLOW_ALL_USERS", @@ -1179,6 +1214,9 @@ class GatewayRunner: Platform.SIGNAL: "SIGNAL_ALLOW_ALL_USERS", Platform.EMAIL: "EMAIL_ALLOW_ALL_USERS", Platform.SMS: "SMS_ALLOW_ALL_USERS", + Platform.MATTERMOST: "MATTERMOST_ALLOW_ALL_USERS", + Platform.MATRIX: "MATRIX_ALLOW_ALL_USERS", + Platform.DINGTALK: "DINGTALK_ALLOW_ALL_USERS", } # Per-platform allow-all flag (e.g., DISCORD_ALLOW_ALL_USERS=true) @@ -1430,8 +1468,19 @@ class GatewayRunner: return f"Quick command error: {e}" else: return f"Quick command '/{command}' has no command defined." + elif qcmd.get("type") == "alias": + target = qcmd.get("target", "").strip() + if target: + target = target if target.startswith("/") else f"/{target}" + target_command = target.lstrip("/") + user_args = event.get_command_args().strip() + event.text = f"{target} {user_args}".strip() + command = target_command + # Fall through to normal command dispatch below + else: + return f"Quick command '/{command}' has no target defined." else: - return f"Quick command '/{command}' has unsupported type (only 'exec' is supported)." + return f"Quick command '/{command}' has unsupported type (supported: 'exec', 'alias')." # Skill slash commands: /skill-name loads the skill and sends to agent if command: @@ -1442,7 +1491,7 @@ class GatewayRunner: if cmd_key in skill_cmds: user_instruction = event.get_command_args().strip() msg = build_skill_invocation_message( - cmd_key, user_instruction, task_id=session_key + cmd_key, user_instruction, task_id=_quick_key ) if msg: event.text = msg @@ -1503,8 +1552,9 @@ class GatewayRunner: # Read privacy.redact_pii from config (re-read per message) _redact_pii = False try: + import yaml as _pii_yaml with open(_config_path, encoding="utf-8") as _pf: - _pcfg = yaml.safe_load(_pf) or {} + _pcfg = _pii_yaml.safe_load(_pf) or {} _redact_pii = bool((_pcfg.get("privacy") or {}).get("redact_pii", False)) except Exception: pass @@ -2050,8 +2100,15 @@ class GatewayRunner: session_entry.session_key, input_tokens=agent_result.get("input_tokens", 0), output_tokens=agent_result.get("output_tokens", 0), + cache_read_tokens=agent_result.get("cache_read_tokens", 0), + cache_write_tokens=agent_result.get("cache_write_tokens", 0), last_prompt_tokens=agent_result.get("last_prompt_tokens", 0), model=agent_result.get("model"), + estimated_cost_usd=agent_result.get("estimated_cost_usd"), + cost_status=agent_result.get("cost_status"), + cost_source=agent_result.get("cost_source"), + provider=agent_result.get("provider"), + base_url=agent_result.get("base_url"), ) # Auto voice reply: send TTS audio before the text response @@ -2121,7 +2178,14 @@ class GatewayRunner: # Reset the session new_entry = self.session_store.reset_session(session_key) - + + # Emit session:end hook (session is ending) + await self.hooks.emit("session:end", { + "platform": source.platform.value if source.platform else "", + "user_id": source.user_id, + "session_key": session_key, + }) + # Emit session:reset hook await self.hooks.emit("session:reset", { "platform": source.platform.value if source.platform else "", @@ -3027,6 +3091,7 @@ class GatewayRunner: Platform.SIGNAL: "hermes-signal", Platform.HOMEASSISTANT: "hermes-homeassistant", Platform.EMAIL: "hermes-email", + Platform.DINGTALK: "hermes-dingtalk", } platform_toolsets_config = {} try: @@ -3048,6 +3113,7 @@ class GatewayRunner: Platform.SIGNAL: "signal", Platform.HOMEASSISTANT: "homeassistant", Platform.EMAIL: "email", + Platform.DINGTALK: "dingtalk", }.get(source.platform, "telegram") config_toolsets = platform_toolsets_config.get(platform_config_key) @@ -4045,6 +4111,7 @@ class GatewayRunner: Platform.SIGNAL: "hermes-signal", Platform.HOMEASSISTANT: "hermes-homeassistant", Platform.EMAIL: "hermes-email", + Platform.DINGTALK: "hermes-dingtalk", } # Try to load platform_toolsets from config @@ -4069,6 +4136,7 @@ class GatewayRunner: Platform.SIGNAL: "signal", Platform.HOMEASSISTANT: "homeassistant", Platform.EMAIL: "email", + Platform.DINGTALK: "dingtalk", }.get(source.platform, "telegram") # Use config override if present (list of toolsets), otherwise hardcoded default diff --git a/gateway/session.py b/gateway/session.py index d0bf0cfe4..e58b6d689 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -343,7 +343,11 @@ class SessionEntry: # Token tracking input_tokens: int = 0 output_tokens: int = 0 + cache_read_tokens: int = 0 + cache_write_tokens: int = 0 total_tokens: int = 0 + estimated_cost_usd: float = 0.0 + cost_status: str = "unknown" # Last API-reported prompt tokens (for accurate compression pre-check) last_prompt_tokens: int = 0 @@ -363,8 +367,12 @@ class SessionEntry: "chat_type": self.chat_type, "input_tokens": self.input_tokens, "output_tokens": self.output_tokens, + "cache_read_tokens": self.cache_read_tokens, + "cache_write_tokens": self.cache_write_tokens, "total_tokens": self.total_tokens, "last_prompt_tokens": self.last_prompt_tokens, + "estimated_cost_usd": self.estimated_cost_usd, + "cost_status": self.cost_status, } if self.origin: result["origin"] = self.origin.to_dict() @@ -394,8 +402,12 @@ class SessionEntry: chat_type=data.get("chat_type", "dm"), input_tokens=data.get("input_tokens", 0), output_tokens=data.get("output_tokens", 0), + cache_read_tokens=data.get("cache_read_tokens", 0), + cache_write_tokens=data.get("cache_write_tokens", 0), total_tokens=data.get("total_tokens", 0), last_prompt_tokens=data.get("last_prompt_tokens", 0), + estimated_cost_usd=data.get("estimated_cost_usd", 0.0), + cost_status=data.get("cost_status", "unknown"), ) @@ -696,8 +708,15 @@ class SessionStore: session_key: str, input_tokens: int = 0, output_tokens: int = 0, + cache_read_tokens: int = 0, + cache_write_tokens: int = 0, last_prompt_tokens: int = None, model: str = None, + estimated_cost_usd: Optional[float] = None, + cost_status: Optional[str] = None, + cost_source: Optional[str] = None, + provider: Optional[str] = None, + base_url: Optional[str] = None, ) -> None: """Update a session's metadata after an interaction.""" self._ensure_loaded() @@ -707,15 +726,35 @@ class SessionStore: entry.updated_at = datetime.now() entry.input_tokens += input_tokens entry.output_tokens += output_tokens + entry.cache_read_tokens += cache_read_tokens + entry.cache_write_tokens += cache_write_tokens if last_prompt_tokens is not None: entry.last_prompt_tokens = last_prompt_tokens - entry.total_tokens = entry.input_tokens + entry.output_tokens + if estimated_cost_usd is not None: + entry.estimated_cost_usd += estimated_cost_usd + if cost_status: + entry.cost_status = cost_status + entry.total_tokens = ( + entry.input_tokens + + entry.output_tokens + + entry.cache_read_tokens + + entry.cache_write_tokens + ) self._save() if self._db: try: self._db.update_token_counts( - entry.session_id, input_tokens, output_tokens, + entry.session_id, + input_tokens=input_tokens, + output_tokens=output_tokens, + cache_read_tokens=cache_read_tokens, + cache_write_tokens=cache_write_tokens, + estimated_cost_usd=estimated_cost_usd, + cost_status=cost_status, + cost_source=cost_source, + billing_provider=provider, + billing_base_url=base_url, model=model, ) except Exception as e: diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 62d8a19a7..8c914034c 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -34,8 +34,11 @@ _EXTRA_ENV_KEYS = frozenset({ "DISCORD_HOME_CHANNEL", "TELEGRAM_HOME_CHANNEL", "SIGNAL_ACCOUNT", "SIGNAL_HTTP_URL", "SIGNAL_ALLOWED_USERS", "SIGNAL_GROUP_ALLOWED_USERS", + "DINGTALK_CLIENT_ID", "DINGTALK_CLIENT_SECRET", "TERMINAL_ENV", "TERMINAL_SSH_KEY", "TERMINAL_SSH_PORT", "WHATSAPP_MODE", "WHATSAPP_ENABLED", + "MATTERMOST_HOME_CHANNEL", "MATTERMOST_REPLY_MODE", + "MATRIX_PASSWORD", "MATRIX_ENCRYPTION", "MATRIX_HOME_ROOM", }) import yaml @@ -354,6 +357,11 @@ DEFAULT_CONFIG = { "tirith_path": "tirith", "tirith_timeout": 5, "tirith_fail_open": True, + "website_blocklist": { + "enabled": False, + "domains": [], + "shared_files": [], + }, }, # Config schema version - bump this when adding new required fields @@ -371,6 +379,7 @@ ENV_VARS_BY_VERSION: Dict[int, List[str]] = { 4: ["VOICE_TOOLS_OPENAI_KEY", "ELEVENLABS_API_KEY"], 5: ["WHATSAPP_ENABLED", "WHATSAPP_MODE", "WHATSAPP_ALLOWED_USERS", "SLACK_BOT_TOKEN", "SLACK_APP_TOKEN", "SLACK_ALLOWED_USERS"], + 10: ["TAVILY_API_KEY"], } # Required environment variables with metadata for migration prompts. @@ -542,6 +551,14 @@ OPTIONAL_ENV_VARS = { }, # ── Tool API keys ── + "PARALLEL_API_KEY": { + "description": "Parallel API key for AI-native web search and extract", + "prompt": "Parallel API key", + "url": "https://parallel.ai/", + "tools": ["web_search", "web_extract"], + "password": True, + "category": "tool", + }, "FIRECRAWL_API_KEY": { "description": "Firecrawl API key for web search and scraping", "prompt": "Firecrawl API key", @@ -558,6 +575,14 @@ OPTIONAL_ENV_VARS = { "category": "tool", "advanced": True, }, + "TAVILY_API_KEY": { + "description": "Tavily API key for AI-native web search, extract, and crawl", + "prompt": "Tavily API key", + "url": "https://app.tavily.com/home", + "tools": ["web_search", "web_extract", "web_crawl"], + "password": True, + "category": "tool", + }, "BROWSERBASE_API_KEY": { "description": "Browserbase API key for cloud browser (optional — local browser works without this)", "prompt": "Browserbase API key", @@ -686,6 +711,55 @@ OPTIONAL_ENV_VARS = { "password": True, "category": "messaging", }, + "MATTERMOST_URL": { + "description": "Mattermost server URL (e.g. https://mm.example.com)", + "prompt": "Mattermost server URL", + "url": "https://mattermost.com/deploy/", + "password": False, + "category": "messaging", + }, + "MATTERMOST_TOKEN": { + "description": "Mattermost bot token or personal access token", + "prompt": "Mattermost bot token", + "url": None, + "password": True, + "category": "messaging", + }, + "MATTERMOST_ALLOWED_USERS": { + "description": "Comma-separated Mattermost user IDs allowed to use the bot", + "prompt": "Allowed Mattermost user IDs (comma-separated)", + "url": None, + "password": False, + "category": "messaging", + }, + "MATRIX_HOMESERVER": { + "description": "Matrix homeserver URL (e.g. https://matrix.example.org)", + "prompt": "Matrix homeserver URL", + "url": "https://matrix.org/ecosystem/servers/", + "password": False, + "category": "messaging", + }, + "MATRIX_ACCESS_TOKEN": { + "description": "Matrix access token (preferred over password login)", + "prompt": "Matrix access token", + "url": None, + "password": True, + "category": "messaging", + }, + "MATRIX_USER_ID": { + "description": "Matrix user ID (e.g. @hermes:example.org)", + "prompt": "Matrix user ID (@user:server)", + "url": None, + "password": False, + "category": "messaging", + }, + "MATRIX_ALLOWED_USERS": { + "description": "Comma-separated Matrix user IDs allowed to use the bot (@user:server format)", + "prompt": "Allowed Matrix user IDs (comma-separated)", + "url": None, + "password": False, + "category": "messaging", + }, "GATEWAY_ALLOW_ALL_USERS": { "description": "Allow all users to interact with messaging bots (true/false). Default: false.", "prompt": "Allow all users (true/false)", @@ -1449,7 +1523,9 @@ def show_config(): keys = [ ("OPENROUTER_API_KEY", "OpenRouter"), ("VOICE_TOOLS_OPENAI_KEY", "OpenAI (STT/TTS)"), + ("PARALLEL_API_KEY", "Parallel"), ("FIRECRAWL_API_KEY", "Firecrawl"), + ("TAVILY_API_KEY", "Tavily"), ("BROWSERBASE_API_KEY", "Browserbase"), ("BROWSER_USE_API_KEY", "Browser Use"), ("FAL_KEY", "FAL"), @@ -1598,7 +1674,8 @@ def set_config_value(key: str, value: str): # Check if it's an API key (goes to .env) api_keys = [ 'OPENROUTER_API_KEY', 'OPENAI_API_KEY', 'ANTHROPIC_API_KEY', 'VOICE_TOOLS_OPENAI_KEY', - 'FIRECRAWL_API_KEY', 'FIRECRAWL_API_URL', 'BROWSERBASE_API_KEY', 'BROWSERBASE_PROJECT_ID', 'BROWSER_USE_API_KEY', + 'PARALLEL_API_KEY', 'FIRECRAWL_API_KEY', 'FIRECRAWL_API_URL', 'TAVILY_API_KEY', + 'BROWSERBASE_API_KEY', 'BROWSERBASE_PROJECT_ID', 'BROWSER_USE_API_KEY', 'FAL_KEY', 'TELEGRAM_BOT_TOKEN', 'DISCORD_BOT_TOKEN', 'TERMINAL_SSH_HOST', 'TERMINAL_SSH_USER', 'TERMINAL_SSH_KEY', 'SUDO_PASSWORD', 'SLACK_BOT_TOKEN', 'SLACK_APP_TOKEN', diff --git a/hermes_cli/gateway.py b/hermes_cli/gateway.py index 3f63a1d18..138eeb25d 100644 --- a/hermes_cli/gateway.py +++ b/hermes_cli/gateway.py @@ -1001,6 +1001,64 @@ _PLATFORMS = [ "help": "Paste your member ID from step 7 above."}, ], }, + { + "key": "matrix", + "label": "Matrix", + "emoji": "🔐", + "token_var": "MATRIX_ACCESS_TOKEN", + "setup_instructions": [ + "1. Works with any Matrix homeserver (self-hosted Synapse/Conduit/Dendrite or matrix.org)", + "2. Create a bot user on your homeserver, or use your own account", + "3. Get an access token: Element → Settings → Help & About → Access Token", + " Or via API: curl -X POST https://your-server/_matrix/client/v3/login \\", + " -d '{\"type\":\"m.login.password\",\"user\":\"@bot:server\",\"password\":\"...\"}'", + "4. Alternatively, provide user ID + password and Hermes will log in directly", + "5. For E2EE: set MATRIX_ENCRYPTION=true (requires pip install 'matrix-nio[e2e]')", + "6. To find your user ID: it's @username:your-server (shown in Element profile)", + ], + "vars": [ + {"name": "MATRIX_HOMESERVER", "prompt": "Homeserver URL (e.g. https://matrix.example.org)", "password": False, + "help": "Your Matrix homeserver URL. Works with any self-hosted instance."}, + {"name": "MATRIX_ACCESS_TOKEN", "prompt": "Access token (leave empty to use password login instead)", "password": True, + "help": "Paste your access token, or leave empty and provide user ID + password below."}, + {"name": "MATRIX_USER_ID", "prompt": "User ID (@bot:server — required for password login)", "password": False, + "help": "Full Matrix user ID, e.g. @hermes:matrix.example.org"}, + {"name": "MATRIX_ALLOWED_USERS", "prompt": "Allowed user IDs (comma-separated, e.g. @you:server)", "password": False, + "is_allowlist": True, + "help": "Matrix user IDs who can interact with the bot."}, + {"name": "MATRIX_HOME_ROOM", "prompt": "Home room ID (for cron/notification delivery, or empty to set later with /set-home)", "password": False, + "help": "Room ID (e.g. !abc123:server) for delivering cron results and notifications."}, + ], + }, + { + "key": "mattermost", + "label": "Mattermost", + "emoji": "💬", + "token_var": "MATTERMOST_TOKEN", + "setup_instructions": [ + "1. In Mattermost: Integrations → Bot Accounts → Add Bot Account", + " (System Console → Integrations → Bot Accounts must be enabled)", + "2. Give it a username (e.g. hermes) and copy the bot token", + "3. Works with any self-hosted Mattermost instance — enter your server URL", + "4. To find your user ID: click your avatar (top-left) → Profile", + " Your user ID is displayed there — click it to copy.", + " ⚠ This is NOT your username — it's a 26-character alphanumeric ID.", + "5. To get a channel ID: click the channel name → View Info → copy the ID", + ], + "vars": [ + {"name": "MATTERMOST_URL", "prompt": "Server URL (e.g. https://mm.example.com)", "password": False, + "help": "Your Mattermost server URL. Works with any self-hosted instance."}, + {"name": "MATTERMOST_TOKEN", "prompt": "Bot token", "password": True, + "help": "Paste the bot token from step 2 above."}, + {"name": "MATTERMOST_ALLOWED_USERS", "prompt": "Allowed user IDs (comma-separated)", "password": False, + "is_allowlist": True, + "help": "Your Mattermost user ID from step 4 above."}, + {"name": "MATTERMOST_HOME_CHANNEL", "prompt": "Home channel ID (for cron/notification delivery, or empty to set later with /set-home)", "password": False, + "help": "Channel ID where Hermes delivers cron results and notifications."}, + {"name": "MATTERMOST_REPLY_MODE", "prompt": "Reply mode — 'off' for flat messages, 'thread' for threaded replies (default: off)", "password": False, + "help": "off = flat channel messages, thread = replies nest under your message."}, + ], + }, { "key": "whatsapp", "label": "WhatsApp", @@ -1013,30 +1071,6 @@ _PLATFORMS = [ "emoji": "📡", "token_var": "SIGNAL_HTTP_URL", }, - { - "key": "sms", - "label": "SMS (Telnyx)", - "emoji": "📱", - "token_var": "TELNYX_API_KEY", - "setup_instructions": [ - "1. Create a Telnyx account at https://portal.telnyx.com/", - "2. Buy a phone number with SMS capability", - "3. Create an API key: API Keys → Create API Key", - "4. Set up a Messaging Profile and assign your number to it", - "5. Configure the webhook URL: https://your-server/webhooks/telnyx", - ], - "vars": [ - {"name": "TELNYX_API_KEY", "prompt": "Telnyx API key", "password": True, - "help": "Paste the API key from step 3 above."}, - {"name": "TELNYX_FROM_NUMBERS", "prompt": "From numbers (comma-separated E.164, e.g. +15551234567)", "password": False, - "help": "The Telnyx phone number(s) Hermes will send SMS from."}, - {"name": "SMS_ALLOWED_USERS", "prompt": "Allowed phone numbers (comma-separated E.164)", "password": False, - "is_allowlist": True, - "help": "Only messages from these phone numbers will be processed."}, - {"name": "SMS_HOME_CHANNEL", "prompt": "Home channel phone (for cron/notification delivery, or empty)", "password": False, - "help": "A phone number where cron job outputs are delivered."}, - ], - }, { "key": "email", "label": "Email", @@ -1063,6 +1097,51 @@ _PLATFORMS = [ "help": "Only emails from these addresses will be processed."}, ], }, + { + "key": "sms", + "label": "SMS (Twilio)", + "emoji": "📱", + "token_var": "TWILIO_ACCOUNT_SID", + "setup_instructions": [ + "1. Create a Twilio account at https://www.twilio.com/", + "2. Get your Account SID and Auth Token from the Twilio Console dashboard", + "3. Buy or configure a phone number capable of sending SMS", + "4. Set up your webhook URL for inbound SMS:", + " Twilio Console → Phone Numbers → Active Numbers → your number", + " → Messaging → A MESSAGE COMES IN → Webhook → https://your-server:8080/webhooks/twilio", + ], + "vars": [ + {"name": "TWILIO_ACCOUNT_SID", "prompt": "Twilio Account SID", "password": False, + "help": "Found on the Twilio Console dashboard."}, + {"name": "TWILIO_AUTH_TOKEN", "prompt": "Twilio Auth Token", "password": True, + "help": "Found on the Twilio Console dashboard (click to reveal)."}, + {"name": "TWILIO_PHONE_NUMBER", "prompt": "Twilio phone number (E.164 format, e.g. +15551234567)", "password": False, + "help": "The Twilio phone number to send SMS from."}, + {"name": "SMS_ALLOWED_USERS", "prompt": "Allowed phone numbers (comma-separated, E.164 format)", "password": False, + "is_allowlist": True, + "help": "Only messages from these phone numbers will be processed."}, + {"name": "SMS_HOME_CHANNEL", "prompt": "Home channel phone number (for cron/notification delivery, or empty)", "password": False, + "help": "Phone number to deliver cron job results and notifications to."}, + ], + }, + { + "key": "dingtalk", + "label": "DingTalk", + "emoji": "💬", + "token_var": "DINGTALK_CLIENT_ID", + "setup_instructions": [ + "1. Go to https://open-dev.dingtalk.com → Create Application", + "2. Under 'Credentials', copy the AppKey (Client ID) and AppSecret (Client Secret)", + "3. Enable 'Stream Mode' under the bot settings", + "4. Add the bot to a group chat or message it directly", + ], + "vars": [ + {"name": "DINGTALK_CLIENT_ID", "prompt": "AppKey (Client ID)", "password": False, + "help": "The AppKey from your DingTalk application credentials."}, + {"name": "DINGTALK_CLIENT_SECRET", "prompt": "AppSecret (Client Secret)", "password": True, + "help": "The AppSecret from your DingTalk application credentials."}, + ], + }, ] @@ -1097,6 +1176,16 @@ def _platform_status(platform: dict) -> str: if any([val, pwd, imap, smtp]): return "partially configured" return "not configured" + if platform.get("key") == "matrix": + homeserver = get_env_value("MATRIX_HOMESERVER") + password = get_env_value("MATRIX_PASSWORD") + if (val or password) and homeserver: + e2ee = get_env_value("MATRIX_ENCRYPTION") + suffix = " + E2EE" if e2ee and e2ee.lower() in ("true", "1", "yes") else "" + return f"configured{suffix}" + if val or password or homeserver: + return "partially configured" + return "not configured" if val: return "configured" return "not configured" diff --git a/hermes_cli/main.py b/hermes_cli/main.py index 32d90ac6a..d5d4885a7 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -784,6 +784,7 @@ def cmd_model(args): "opencode-go": "OpenCode Go", "ai-gateway": "AI Gateway", "kilocode": "Kilo Code", + "alibaba": "Alibaba Cloud (DashScope)", "custom": "Custom endpoint", } active_label = provider_labels.get(active, active) @@ -807,6 +808,7 @@ def cmd_model(args): ("opencode-zen", "OpenCode Zen (35+ curated models, pay-as-you-go)"), ("opencode-go", "OpenCode Go (open models, $10/month subscription)"), ("ai-gateway", "AI Gateway (Vercel — 200+ models, pay-per-use)"), + ("alibaba", "Alibaba Cloud / DashScope (Qwen models, Anthropic-compatible)"), ] # Add user-defined custom providers from config.yaml @@ -875,7 +877,7 @@ def cmd_model(args): _model_flow_anthropic(config, current_model) elif selected_provider == "kimi-coding": _model_flow_kimi(config, current_model) - elif selected_provider in ("zai", "minimax", "minimax-cn", "kilocode", "opencode-zen", "opencode-go", "ai-gateway"): + elif selected_provider in ("zai", "minimax", "minimax-cn", "kilocode", "opencode-zen", "opencode-go", "ai-gateway", "alibaba"): _model_flow_api_key_provider(config, selected_provider, current_model) @@ -1994,20 +1996,32 @@ def _update_via_zip(args): print(f"✗ ZIP update failed: {e}") sys.exit(1) - # Reinstall Python dependencies + # Reinstall Python dependencies (try .[all] first for optional extras, + # fall back to . if extras fail — mirrors the install script behavior) print("→ Updating Python dependencies...") import subprocess uv_bin = shutil.which("uv") if uv_bin: - subprocess.run( - [uv_bin, "pip", "install", "-e", ".", "--quiet"], - cwd=PROJECT_ROOT, check=True, - env={**os.environ, "VIRTUAL_ENV": str(PROJECT_ROOT / "venv")} - ) + uv_env = {**os.environ, "VIRTUAL_ENV": str(PROJECT_ROOT / "venv")} + try: + subprocess.run( + [uv_bin, "pip", "install", "-e", ".[all]", "--quiet"], + cwd=PROJECT_ROOT, check=True, env=uv_env, + ) + except subprocess.CalledProcessError: + print(" ⚠ Optional extras failed, installing base dependencies...") + subprocess.run( + [uv_bin, "pip", "install", "-e", ".", "--quiet"], + cwd=PROJECT_ROOT, check=True, env=uv_env, + ) else: venv_pip = PROJECT_ROOT / "venv" / ("Scripts" if sys.platform == "win32" else "bin") / "pip" - if venv_pip.exists(): - subprocess.run([str(venv_pip), "install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True) + pip_cmd = [str(venv_pip)] if venv_pip.exists() else ["pip"] + try: + subprocess.run(pip_cmd + ["install", "-e", ".[all]", "--quiet"], cwd=PROJECT_ROOT, check=True) + except subprocess.CalledProcessError: + print(" ⚠ Optional extras failed, installing base dependencies...") + subprocess.run(pip_cmd + ["install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True) # Sync skills try: @@ -2255,21 +2269,31 @@ def cmd_update(args): _invalidate_update_cache() - # Reinstall Python dependencies (prefer uv for speed, fall back to pip) + # Reinstall Python dependencies (try .[all] first for optional extras, + # fall back to . if extras fail — mirrors the install script behavior) print("→ Updating Python dependencies...") uv_bin = shutil.which("uv") if uv_bin: - subprocess.run( - [uv_bin, "pip", "install", "-e", ".", "--quiet"], - cwd=PROJECT_ROOT, check=True, - env={**os.environ, "VIRTUAL_ENV": str(PROJECT_ROOT / "venv")} - ) + uv_env = {**os.environ, "VIRTUAL_ENV": str(PROJECT_ROOT / "venv")} + try: + subprocess.run( + [uv_bin, "pip", "install", "-e", ".[all]", "--quiet"], + cwd=PROJECT_ROOT, check=True, env=uv_env, + ) + except subprocess.CalledProcessError: + print(" ⚠ Optional extras failed, installing base dependencies...") + subprocess.run( + [uv_bin, "pip", "install", "-e", ".", "--quiet"], + cwd=PROJECT_ROOT, check=True, env=uv_env, + ) else: venv_pip = PROJECT_ROOT / "venv" / ("Scripts" if sys.platform == "win32" else "bin") / "pip" - if venv_pip.exists(): - subprocess.run([str(venv_pip), "install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True) - else: - subprocess.run(["pip", "install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True) + pip_cmd = [str(venv_pip)] if venv_pip.exists() else ["pip"] + try: + subprocess.run(pip_cmd + ["install", "-e", ".[all]", "--quiet"], cwd=PROJECT_ROOT, check=True) + except subprocess.CalledProcessError: + print(" ⚠ Optional extras failed, installing base dependencies...") + subprocess.run(pip_cmd + ["install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True) # Check for Node.js deps if (PROJECT_ROOT / "package.json").exists(): diff --git a/hermes_cli/models.py b/hermes_cli/models.py index 25c9eea54..174aa9475 100644 --- a/hermes_cli/models.py +++ b/hermes_cli/models.py @@ -473,7 +473,7 @@ def provider_model_ids(provider: Optional[str]) -> list[str]: from hermes_cli.auth import fetch_nous_models, resolve_nous_runtime_credentials creds = resolve_nous_runtime_credentials() if creds: - live = fetch_nous_models(creds.get("api_key", ""), creds.get("base_url", "")) + live = fetch_nous_models(api_key=creds.get("api_key", ""), inference_base_url=creds.get("base_url", "")) if live: return live except Exception: diff --git a/hermes_cli/setup.py b/hermes_cli/setup.py index 70bad2ef4..c9a117c5d 100644 --- a/hermes_cli/setup.py +++ b/hermes_cli/setup.py @@ -444,11 +444,11 @@ def _print_setup_summary(config: dict, hermes_home): else: tool_status.append(("Mixture of Agents", False, "OPENROUTER_API_KEY")) - # Firecrawl (web tools) - if get_env_value("FIRECRAWL_API_KEY") or get_env_value("FIRECRAWL_API_URL"): + # Web tools (Parallel, Firecrawl, or Tavily) + if get_env_value("PARALLEL_API_KEY") or get_env_value("FIRECRAWL_API_KEY") or get_env_value("FIRECRAWL_API_URL") or get_env_value("TAVILY_API_KEY"): tool_status.append(("Web Search & Extract", True, None)) else: - tool_status.append(("Web Search & Extract", False, "FIRECRAWL_API_KEY")) + tool_status.append(("Web Search & Extract", False, "PARALLEL_API_KEY, FIRECRAWL_API_KEY, or TAVILY_API_KEY")) # Browser tools (local Chromium or Browserbase cloud) import shutil @@ -738,6 +738,7 @@ def setup_model_provider(config: dict): "Kilo Code (Kilo Gateway API)", "Anthropic (Claude models — API key or Claude Code subscription)", "AI Gateway (Vercel — 200+ models, pay-per-use)", + "Alibaba Cloud / DashScope (Qwen models via Anthropic-compatible API)", "OpenCode Zen (35+ curated models, pay-as-you-go)", "OpenCode Go (open models, $10/month subscription)", ] @@ -1313,7 +1314,39 @@ def setup_model_provider(config: dict): _update_config_for_provider("ai-gateway", pconfig.inference_base_url, default_model="anthropic/claude-opus-4.6") _set_model_provider(config, "ai-gateway", pconfig.inference_base_url) - elif provider_idx == 11: # OpenCode Zen + elif provider_idx == 11: # Alibaba Cloud / DashScope + selected_provider = "alibaba" + print() + print_header("Alibaba Cloud / DashScope API Key") + pconfig = PROVIDER_REGISTRY["alibaba"] + print_info(f"Provider: {pconfig.name}") + print_info("Get your API key at: https://modelstudio.console.alibabacloud.com/") + print() + + existing_key = get_env_value("DASHSCOPE_API_KEY") + if existing_key: + print_info(f"Current: {existing_key[:8]}... (configured)") + if prompt_yes_no("Update API key?", False): + new_key = prompt(" DashScope API key", password=True) + if new_key: + save_env_value("DASHSCOPE_API_KEY", new_key) + print_success("DashScope API key updated") + else: + new_key = prompt(" DashScope API key", password=True) + if new_key: + save_env_value("DASHSCOPE_API_KEY", new_key) + print_success("DashScope API key saved") + else: + print_warning("Skipped - agent won't work without an API key") + + # Clear custom endpoint vars if switching + if existing_custom: + save_env_value("OPENAI_BASE_URL", "") + save_env_value("OPENAI_API_KEY", "") + _update_config_for_provider("alibaba", pconfig.inference_base_url, default_model="qwen3.5-plus") + _set_model_provider(config, "alibaba", pconfig.inference_base_url) + + elif provider_idx == 12: # OpenCode Zen selected_provider = "opencode-zen" print() print_header("OpenCode Zen API Key") @@ -1346,7 +1379,7 @@ def setup_model_provider(config: dict): _set_model_provider(config, "opencode-zen", pconfig.inference_base_url) selected_base_url = pconfig.inference_base_url - elif provider_idx == 12: # OpenCode Go + elif provider_idx == 13: # OpenCode Go selected_provider = "opencode-go" print() print_header("OpenCode Go API Key") @@ -1379,7 +1412,7 @@ def setup_model_provider(config: dict): _set_model_provider(config, "opencode-go", pconfig.inference_base_url) selected_base_url = pconfig.inference_base_url - # else: provider_idx == 13 (Keep current) — only shown when a provider already exists + # else: provider_idx == 14 (Keep current) — only shown when a provider already exists # Normalize "keep current" to an explicit provider so downstream logic # doesn't fall back to the generic OpenRouter/static-model path. if selected_provider is None: @@ -2486,6 +2519,119 @@ def setup_gateway(config: dict): " Set SLACK_ALLOW_ALL_USERS=true or GATEWAY_ALLOW_ALL_USERS=true only if you intentionally want open workspace access." ) + # ── Matrix ── + existing_matrix = get_env_value("MATRIX_ACCESS_TOKEN") or get_env_value("MATRIX_PASSWORD") + if existing_matrix: + print_info("Matrix: already configured") + if prompt_yes_no("Reconfigure Matrix?", False): + existing_matrix = None + + if not existing_matrix and prompt_yes_no("Set up Matrix?", False): + print_info("Works with any Matrix homeserver (Synapse, Conduit, Dendrite, or matrix.org).") + print_info(" 1. Create a bot user on your homeserver, or use your own account") + print_info(" 2. Get an access token from Element, or provide user ID + password") + print() + homeserver = prompt("Homeserver URL (e.g. https://matrix.example.org)") + if homeserver: + save_env_value("MATRIX_HOMESERVER", homeserver.rstrip("/")) + + print() + print_info("Auth: provide an access token (recommended), or user ID + password.") + token = prompt("Access token (leave empty for password login)", password=True) + if token: + save_env_value("MATRIX_ACCESS_TOKEN", token) + user_id = prompt("User ID (@bot:server — optional, will be auto-detected)") + if user_id: + save_env_value("MATRIX_USER_ID", user_id) + print_success("Matrix access token saved") + else: + user_id = prompt("User ID (@bot:server)") + if user_id: + save_env_value("MATRIX_USER_ID", user_id) + password = prompt("Password", password=True) + if password: + save_env_value("MATRIX_PASSWORD", password) + print_success("Matrix credentials saved") + + if token or get_env_value("MATRIX_PASSWORD"): + # E2EE + print() + if prompt_yes_no("Enable end-to-end encryption (E2EE)?", False): + save_env_value("MATRIX_ENCRYPTION", "true") + print_success("E2EE enabled") + print_info(" Requires: pip install 'matrix-nio[e2e]'") + + # Allowed users + print() + print_info("🔒 Security: Restrict who can use your bot") + print_info(" Matrix user IDs look like @username:server") + print() + allowed_users = prompt( + "Allowed user IDs (comma-separated, leave empty for open access)" + ) + if allowed_users: + save_env_value("MATRIX_ALLOWED_USERS", allowed_users.replace(" ", "")) + print_success("Matrix allowlist configured") + else: + print_info( + "⚠️ No allowlist set - anyone who can message the bot can use it!" + ) + + # Home room + print() + print_info("📬 Home Room: where Hermes delivers cron job results and notifications.") + print_info(" Room IDs look like !abc123:server (shown in Element room settings)") + print_info(" You can also set this later by typing /set-home in a Matrix room.") + home_room = prompt("Home room ID (leave empty to set later with /set-home)") + if home_room: + save_env_value("MATRIX_HOME_ROOM", home_room) + + # ── Mattermost ── + existing_mattermost = get_env_value("MATTERMOST_TOKEN") + if existing_mattermost: + print_info("Mattermost: already configured") + if prompt_yes_no("Reconfigure Mattermost?", False): + existing_mattermost = None + + if not existing_mattermost and prompt_yes_no("Set up Mattermost?", False): + print_info("Works with any self-hosted Mattermost instance.") + print_info(" 1. In Mattermost: Integrations → Bot Accounts → Add Bot Account") + print_info(" 2. Copy the bot token") + print() + mm_url = prompt("Mattermost server URL (e.g. https://mm.example.com)") + if mm_url: + save_env_value("MATTERMOST_URL", mm_url.rstrip("/")) + token = prompt("Bot token", password=True) + if token: + save_env_value("MATTERMOST_TOKEN", token) + print_success("Mattermost token saved") + + # Allowed users + print() + print_info("🔒 Security: Restrict who can use your bot") + print_info(" To find your user ID: click your avatar → Profile") + print_info(" or use the API: GET /api/v4/users/me") + print() + allowed_users = prompt( + "Allowed user IDs (comma-separated, leave empty for open access)" + ) + if allowed_users: + save_env_value("MATTERMOST_ALLOWED_USERS", allowed_users.replace(" ", "")) + print_success("Mattermost allowlist configured") + else: + print_info( + "⚠️ No allowlist set - anyone who can message the bot can use it!" + ) + + # Home channel + print() + print_info("📬 Home Channel: where Hermes delivers cron job results and notifications.") + print_info(" To get a channel ID: click channel name → View Info → copy the ID") + print_info(" You can also set this later by typing /set-home in a Mattermost channel.") + home_channel = prompt("Home channel ID (leave empty to set later with /set-home)") + if home_channel: + save_env_value("MATTERMOST_HOME_CHANNEL", home_channel) + # ── WhatsApp ── existing_whatsapp = get_env_value("WHATSAPP_ENABLED") if not existing_whatsapp and prompt_yes_no("Set up WhatsApp?", False): @@ -2503,6 +2649,9 @@ def setup_gateway(config: dict): get_env_value("TELEGRAM_BOT_TOKEN") or get_env_value("DISCORD_BOT_TOKEN") or get_env_value("SLACK_BOT_TOKEN") + or get_env_value("MATTERMOST_TOKEN") + or get_env_value("MATRIX_ACCESS_TOKEN") + or get_env_value("MATRIX_PASSWORD") or get_env_value("WHATSAPP_ENABLED") ) if any_messaging: diff --git a/hermes_cli/status.py b/hermes_cli/status.py index ccdeca4d0..e8db90cf2 100644 --- a/hermes_cli/status.py +++ b/hermes_cli/status.py @@ -120,6 +120,7 @@ def show_status(args): "MiniMax": "MINIMAX_API_KEY", "MiniMax-CN": "MINIMAX_CN_API_KEY", "Firecrawl": "FIRECRAWL_API_KEY", + "Tavily": "TAVILY_API_KEY", "Browserbase": "BROWSERBASE_API_KEY", # Optional — local browser works without this "FAL": "FAL_KEY", "Tinker": "TINKER_API_KEY", @@ -252,7 +253,7 @@ def show_status(args): "Signal": ("SIGNAL_HTTP_URL", "SIGNAL_HOME_CHANNEL"), "Slack": ("SLACK_BOT_TOKEN", None), "Email": ("EMAIL_ADDRESS", "EMAIL_HOME_ADDRESS"), - "SMS": ("TELNYX_API_KEY", "SMS_HOME_CHANNEL"), + "SMS": ("TWILIO_ACCOUNT_SID", "SMS_HOME_CHANNEL"), } for name, (token_var, home_var) in platforms.items(): diff --git a/hermes_cli/tools_config.py b/hermes_cli/tools_config.py index d106d0c47..1d6783a2d 100644 --- a/hermes_cli/tools_config.py +++ b/hermes_cli/tools_config.py @@ -110,6 +110,7 @@ PLATFORMS = { "whatsapp": {"label": "📱 WhatsApp", "default_toolset": "hermes-whatsapp"}, "signal": {"label": "📡 Signal", "default_toolset": "hermes-signal"}, "email": {"label": "📧 Email", "default_toolset": "hermes-email"}, + "dingtalk": {"label": "💬 DingTalk", "default_toolset": "hermes-dingtalk"}, } @@ -150,19 +151,37 @@ TOOL_CATEGORIES = { "web": { "name": "Web Search & Extract", "setup_title": "Select Search Provider", - "setup_note": "A free DuckDuckGo search skill is also included — skip this if you don't need Firecrawl.", + "setup_note": "A free DuckDuckGo search skill is also included — skip this if you don't need a premium provider.", "icon": "🔍", "providers": [ { "name": "Firecrawl Cloud", - "tag": "Recommended - hosted service", + "tag": "Hosted service - search, extract, and crawl", + "web_backend": "firecrawl", "env_vars": [ {"key": "FIRECRAWL_API_KEY", "prompt": "Firecrawl API key", "url": "https://firecrawl.dev"}, ], }, + { + "name": "Parallel", + "tag": "AI-native search and extract", + "web_backend": "parallel", + "env_vars": [ + {"key": "PARALLEL_API_KEY", "prompt": "Parallel API key", "url": "https://parallel.ai"}, + ], + }, + { + "name": "Tavily", + "tag": "AI-native search, extract, and crawl", + "web_backend": "tavily", + "env_vars": [ + {"key": "TAVILY_API_KEY", "prompt": "Tavily API key", "url": "https://app.tavily.com/home"}, + ], + }, { "name": "Firecrawl Self-Hosted", "tag": "Free - run your own instance", + "web_backend": "firecrawl", "env_vars": [ {"key": "FIRECRAWL_API_URL", "prompt": "Your Firecrawl instance URL (e.g., http://localhost:3002)"}, ], @@ -617,6 +636,9 @@ def _is_provider_active(provider: dict, config: dict) -> bool: if "browser_provider" in provider: current = config.get("browser", {}).get("cloud_provider") return provider["browser_provider"] == current + if provider.get("web_backend"): + current = config.get("web", {}).get("backend") + return current == provider["web_backend"] return False @@ -649,6 +671,11 @@ def _configure_provider(provider: dict, config: dict): else: config.get("browser", {}).pop("cloud_provider", None) + # Set web search backend in config if applicable + if provider.get("web_backend"): + config.setdefault("web", {})["backend"] = provider["web_backend"] + _print_success(f" Web backend set to: {provider['web_backend']}") + if not env_vars: _print_success(f" {provider['name']} - no configuration needed!") return @@ -832,6 +859,11 @@ def _reconfigure_provider(provider: dict, config: dict): config.get("browser", {}).pop("cloud_provider", None) _print_success(f" Browser set to local mode") + # Set web search backend in config if applicable + if provider.get("web_backend"): + config.setdefault("web", {})["backend"] = provider["web_backend"] + _print_success(f" Web backend set to: {provider['web_backend']}") + if not env_vars: _print_success(f" {provider['name']} - no configuration needed!") return @@ -984,12 +1016,19 @@ def tools_command(args=None, first_install: bool = False, config: dict = None): if len(platform_keys) > 1: platform_choices.append("Configure all platforms (global)") platform_choices.append("Reconfigure an existing tool's provider or API key") + + # Show MCP option if any MCP servers are configured + _has_mcp = bool(config.get("mcp_servers")) + if _has_mcp: + platform_choices.append("Configure MCP server tools") + platform_choices.append("Done") # Index offsets for the extra options after per-platform entries _global_idx = len(platform_keys) if len(platform_keys) > 1 else -1 _reconfig_idx = len(platform_keys) + (1 if len(platform_keys) > 1 else 0) - _done_idx = _reconfig_idx + 1 + _mcp_idx = (_reconfig_idx + 1) if _has_mcp else -1 + _done_idx = _reconfig_idx + (2 if _has_mcp else 1) while True: idx = _prompt_choice("Select an option:", platform_choices, default=0) @@ -1004,6 +1043,12 @@ def tools_command(args=None, first_install: bool = False, config: dict = None): print() continue + # "Configure MCP tools" selected + if idx == _mcp_idx: + _configure_mcp_tools_interactive(config) + print() + continue + # "Configure all platforms (global)" selected if idx == _global_idx: # Use the union of all platforms' current tools as the starting state @@ -1090,6 +1135,137 @@ def tools_command(args=None, first_install: bool = False, config: dict = None): print() +# ─── MCP Tools Interactive Configuration ───────────────────────────────────── + + +def _configure_mcp_tools_interactive(config: dict): + """Probe MCP servers for available tools and let user toggle them on/off. + + Connects to each configured MCP server, discovers tools, then shows + a per-server curses checklist. Writes changes back as ``tools.exclude`` + entries in config.yaml. + """ + from hermes_cli.curses_ui import curses_checklist + + mcp_servers = config.get("mcp_servers") or {} + if not mcp_servers: + _print_info("No MCP servers configured.") + return + + # Count enabled servers + enabled_names = [ + k for k, v in mcp_servers.items() + if v.get("enabled", True) not in (False, "false", "0", "no", "off") + ] + if not enabled_names: + _print_info("All MCP servers are disabled.") + return + + print() + print(color(" Discovering tools from MCP servers...", Colors.YELLOW)) + print(color(f" Connecting to {len(enabled_names)} server(s): {', '.join(enabled_names)}", Colors.DIM)) + + try: + from tools.mcp_tool import probe_mcp_server_tools + server_tools = probe_mcp_server_tools() + except Exception as exc: + _print_error(f"Failed to probe MCP servers: {exc}") + return + + if not server_tools: + _print_warning("Could not discover tools from any MCP server.") + _print_info("Check that server commands/URLs are correct and dependencies are installed.") + return + + # Report discovery results + failed = [n for n in enabled_names if n not in server_tools] + if failed: + for name in failed: + _print_warning(f" Could not connect to '{name}'") + + total_tools = sum(len(tools) for tools in server_tools.values()) + print(color(f" Found {total_tools} tool(s) across {len(server_tools)} server(s)", Colors.GREEN)) + print() + + any_changes = False + + for server_name, tools in server_tools.items(): + if not tools: + _print_info(f" {server_name}: no tools found") + continue + + srv_cfg = mcp_servers.get(server_name, {}) + tools_cfg = srv_cfg.get("tools") or {} + include_list = tools_cfg.get("include") or [] + exclude_list = tools_cfg.get("exclude") or [] + + # Build checklist labels + labels = [] + for tool_name, description in tools: + desc_short = description[:70] + "..." if len(description) > 70 else description + if desc_short: + labels.append(f"{tool_name} ({desc_short})") + else: + labels.append(tool_name) + + # Determine which tools are currently enabled + pre_selected: Set[int] = set() + tool_names = [t[0] for t in tools] + for i, tool_name in enumerate(tool_names): + if include_list: + # Include mode: only included tools are selected + if tool_name in include_list: + pre_selected.add(i) + elif exclude_list: + # Exclude mode: everything except excluded + if tool_name not in exclude_list: + pre_selected.add(i) + else: + # No filter: all enabled + pre_selected.add(i) + + chosen = curses_checklist( + f"MCP Server: {server_name} ({len(tools)} tools)", + labels, + pre_selected, + cancel_returns=pre_selected, + ) + + if chosen == pre_selected: + _print_info(f" {server_name}: no changes") + continue + + # Compute new exclude list based on unchecked tools + new_exclude = [tool_names[i] for i in range(len(tool_names)) if i not in chosen] + + # Update config + srv_cfg = mcp_servers.setdefault(server_name, {}) + tools_cfg = srv_cfg.setdefault("tools", {}) + + if new_exclude: + tools_cfg["exclude"] = new_exclude + # Remove include if present — we're switching to exclude mode + tools_cfg.pop("include", None) + else: + # All tools enabled — clear filters + tools_cfg.pop("exclude", None) + tools_cfg.pop("include", None) + + enabled_count = len(chosen) + disabled_count = len(tools) - enabled_count + _print_success( + f" {server_name}: {enabled_count} enabled, {disabled_count} disabled" + ) + any_changes = True + + if any_changes: + save_config(config) + print() + print(color(" ✓ MCP tool configuration saved", Colors.GREEN)) + else: + print(color(" No changes to MCP tools", Colors.DIM)) + + # ─── Non-interactive disable/enable ────────────────────────────────────────── diff --git a/hermes_state.py b/hermes_state.py index 3f4715067..396c4dbf9 100644 --- a/hermes_state.py +++ b/hermes_state.py @@ -18,6 +18,7 @@ import json import os import re import sqlite3 +import threading import time from pathlib import Path from typing import Dict, Any, List, Optional @@ -25,7 +26,7 @@ from typing import Dict, Any, List, Optional DEFAULT_DB_PATH = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "state.db" -SCHEMA_VERSION = 4 +SCHEMA_VERSION = 5 SCHEMA_SQL = """ CREATE TABLE IF NOT EXISTS schema_version ( @@ -47,6 +48,17 @@ CREATE TABLE IF NOT EXISTS sessions ( tool_call_count INTEGER DEFAULT 0, input_tokens INTEGER DEFAULT 0, output_tokens INTEGER DEFAULT 0, + cache_read_tokens INTEGER DEFAULT 0, + cache_write_tokens INTEGER DEFAULT 0, + reasoning_tokens INTEGER DEFAULT 0, + billing_provider TEXT, + billing_base_url TEXT, + billing_mode TEXT, + estimated_cost_usd REAL, + actual_cost_usd REAL, + cost_status TEXT, + cost_source TEXT, + pricing_version TEXT, title TEXT, FOREIGN KEY (parent_session_id) REFERENCES sessions(id) ); @@ -104,6 +116,7 @@ class SessionDB: self.db_path = db_path or DEFAULT_DB_PATH self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._lock = threading.Lock() self._conn = sqlite3.connect( str(self.db_path), check_same_thread=False, @@ -152,6 +165,26 @@ class SessionDB: except sqlite3.OperationalError: pass # Index already exists cursor.execute("UPDATE schema_version SET version = 4") + if current_version < 5: + new_columns = [ + ("cache_read_tokens", "INTEGER DEFAULT 0"), + ("cache_write_tokens", "INTEGER DEFAULT 0"), + ("reasoning_tokens", "INTEGER DEFAULT 0"), + ("billing_provider", "TEXT"), + ("billing_base_url", "TEXT"), + ("billing_mode", "TEXT"), + ("estimated_cost_usd", "REAL"), + ("actual_cost_usd", "REAL"), + ("cost_status", "TEXT"), + ("cost_source", "TEXT"), + ("pricing_version", "TEXT"), + ] + for name, column_type in new_columns: + try: + cursor.execute(f"ALTER TABLE sessions ADD COLUMN {name} {column_type}") + except sqlite3.OperationalError: + pass + cursor.execute("UPDATE schema_version SET version = 5") # Unique title index — always ensure it exists (safe to run after migrations # since the title column is guaranteed to exist at this point) @@ -173,9 +206,10 @@ class SessionDB: def close(self): """Close the database connection.""" - if self._conn: - self._conn.close() - self._conn = None + with self._lock: + if self._conn: + self._conn.close() + self._conn = None # ========================================================================= # Session lifecycle @@ -192,61 +226,111 @@ class SessionDB: parent_session_id: str = None, ) -> str: """Create a new session record. Returns the session_id.""" - self._conn.execute( - """INSERT INTO sessions (id, source, user_id, model, model_config, - system_prompt, parent_session_id, started_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", - ( - session_id, - source, - user_id, - model, - json.dumps(model_config) if model_config else None, - system_prompt, - parent_session_id, - time.time(), - ), - ) - self._conn.commit() + with self._lock: + self._conn.execute( + """INSERT INTO sessions (id, source, user_id, model, model_config, + system_prompt, parent_session_id, started_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", + ( + session_id, + source, + user_id, + model, + json.dumps(model_config) if model_config else None, + system_prompt, + parent_session_id, + time.time(), + ), + ) + self._conn.commit() return session_id def end_session(self, session_id: str, end_reason: str) -> None: """Mark a session as ended.""" - self._conn.execute( - "UPDATE sessions SET ended_at = ?, end_reason = ? WHERE id = ?", - (time.time(), end_reason, session_id), - ) - self._conn.commit() + with self._lock: + self._conn.execute( + "UPDATE sessions SET ended_at = ?, end_reason = ? WHERE id = ?", + (time.time(), end_reason, session_id), + ) + self._conn.commit() def update_system_prompt(self, session_id: str, system_prompt: str) -> None: """Store the full assembled system prompt snapshot.""" - self._conn.execute( - "UPDATE sessions SET system_prompt = ? WHERE id = ?", - (system_prompt, session_id), - ) - self._conn.commit() + with self._lock: + self._conn.execute( + "UPDATE sessions SET system_prompt = ? WHERE id = ?", + (system_prompt, session_id), + ) + self._conn.commit() def update_token_counts( - self, session_id: str, input_tokens: int = 0, output_tokens: int = 0, + self, + session_id: str, + input_tokens: int = 0, + output_tokens: int = 0, model: str = None, + cache_read_tokens: int = 0, + cache_write_tokens: int = 0, + reasoning_tokens: int = 0, + estimated_cost_usd: Optional[float] = None, + actual_cost_usd: Optional[float] = None, + cost_status: Optional[str] = None, + cost_source: Optional[str] = None, + pricing_version: Optional[str] = None, + billing_provider: Optional[str] = None, + billing_base_url: Optional[str] = None, + billing_mode: Optional[str] = None, ) -> None: """Increment token counters and backfill model if not already set.""" - self._conn.execute( - """UPDATE sessions SET - input_tokens = input_tokens + ?, - output_tokens = output_tokens + ?, - model = COALESCE(model, ?) - WHERE id = ?""", - (input_tokens, output_tokens, model, session_id), - ) - self._conn.commit() + with self._lock: + self._conn.execute( + """UPDATE sessions SET + input_tokens = input_tokens + ?, + output_tokens = output_tokens + ?, + cache_read_tokens = cache_read_tokens + ?, + cache_write_tokens = cache_write_tokens + ?, + reasoning_tokens = reasoning_tokens + ?, + estimated_cost_usd = COALESCE(estimated_cost_usd, 0) + COALESCE(?, 0), + actual_cost_usd = CASE + WHEN ? IS NULL THEN actual_cost_usd + ELSE COALESCE(actual_cost_usd, 0) + ? + END, + cost_status = COALESCE(?, cost_status), + cost_source = COALESCE(?, cost_source), + pricing_version = COALESCE(?, pricing_version), + billing_provider = COALESCE(billing_provider, ?), + billing_base_url = COALESCE(billing_base_url, ?), + billing_mode = COALESCE(billing_mode, ?), + model = COALESCE(model, ?) + WHERE id = ?""", + ( + input_tokens, + output_tokens, + cache_read_tokens, + cache_write_tokens, + reasoning_tokens, + estimated_cost_usd, + actual_cost_usd, + actual_cost_usd, + cost_status, + cost_source, + pricing_version, + billing_provider, + billing_base_url, + billing_mode, + model, + session_id, + ), + ) + self._conn.commit() def get_session(self, session_id: str) -> Optional[Dict[str, Any]]: """Get a session by ID.""" - cursor = self._conn.execute( - "SELECT * FROM sessions WHERE id = ?", (session_id,) - ) - row = cursor.fetchone() + with self._lock: + cursor = self._conn.execute( + "SELECT * FROM sessions WHERE id = ?", (session_id,) + ) + row = cursor.fetchone() return dict(row) if row else None def resolve_session_id(self, session_id_or_prefix: str) -> Optional[str]: @@ -331,38 +415,42 @@ class SessionDB: Empty/whitespace-only strings are normalized to None (clearing the title). """ title = self.sanitize_title(title) - if title: - # Check uniqueness (allow the same session to keep its own title) + with self._lock: + if title: + # Check uniqueness (allow the same session to keep its own title) + cursor = self._conn.execute( + "SELECT id FROM sessions WHERE title = ? AND id != ?", + (title, session_id), + ) + conflict = cursor.fetchone() + if conflict: + raise ValueError( + f"Title '{title}' is already in use by session {conflict['id']}" + ) cursor = self._conn.execute( - "SELECT id FROM sessions WHERE title = ? AND id != ?", + "UPDATE sessions SET title = ? WHERE id = ?", (title, session_id), ) - conflict = cursor.fetchone() - if conflict: - raise ValueError( - f"Title '{title}' is already in use by session {conflict['id']}" - ) - cursor = self._conn.execute( - "UPDATE sessions SET title = ? WHERE id = ?", - (title, session_id), - ) - self._conn.commit() - return cursor.rowcount > 0 + self._conn.commit() + rowcount = cursor.rowcount + return rowcount > 0 def get_session_title(self, session_id: str) -> Optional[str]: """Get the title for a session, or None.""" - cursor = self._conn.execute( - "SELECT title FROM sessions WHERE id = ?", (session_id,) - ) - row = cursor.fetchone() + with self._lock: + cursor = self._conn.execute( + "SELECT title FROM sessions WHERE id = ?", (session_id,) + ) + row = cursor.fetchone() return row["title"] if row else None def get_session_by_title(self, title: str) -> Optional[Dict[str, Any]]: """Look up a session by exact title. Returns session dict or None.""" - cursor = self._conn.execute( - "SELECT * FROM sessions WHERE title = ?", (title,) - ) - row = cursor.fetchone() + with self._lock: + cursor = self._conn.execute( + "SELECT * FROM sessions WHERE title = ?", (title,) + ) + row = cursor.fetchone() return dict(row) if row else None def resolve_session_by_title(self, title: str) -> Optional[str]: @@ -379,12 +467,13 @@ class SessionDB: # Also search for numbered variants: "title #2", "title #3", etc. # Escape SQL LIKE wildcards (%, _) in the title to prevent false matches escaped = title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") - cursor = self._conn.execute( - "SELECT id, title, started_at FROM sessions " - "WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC", - (f"{escaped} #%",), - ) - numbered = cursor.fetchall() + with self._lock: + cursor = self._conn.execute( + "SELECT id, title, started_at FROM sessions " + "WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC", + (f"{escaped} #%",), + ) + numbered = cursor.fetchall() if numbered: # Return the most recent numbered variant @@ -409,11 +498,12 @@ class SessionDB: # Find all existing numbered variants # Escape SQL LIKE wildcards (%, _) in the base to prevent false matches escaped = base.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") - cursor = self._conn.execute( - "SELECT title FROM sessions WHERE title = ? OR title LIKE ? ESCAPE '\\'", - (base, f"{escaped} #%"), - ) - existing = [row["title"] for row in cursor.fetchall()] + with self._lock: + cursor = self._conn.execute( + "SELECT title FROM sessions WHERE title = ? OR title LIKE ? ESCAPE '\\'", + (base, f"{escaped} #%"), + ) + existing = [row["title"] for row in cursor.fetchall()] if not existing: return base # No conflict, use the base name as-is @@ -461,9 +551,11 @@ class SessionDB: LIMIT ? OFFSET ? """ params = (source, limit, offset) if source else (limit, offset) - cursor = self._conn.execute(query, params) + with self._lock: + cursor = self._conn.execute(query, params) + rows = cursor.fetchall() sessions = [] - for row in cursor.fetchall(): + for row in rows: s = dict(row) # Build the preview from the raw substring raw = s.pop("_preview_raw", "").strip() @@ -497,52 +589,54 @@ class SessionDB: Also increments the session's message_count (and tool_call_count if role is 'tool' or tool_calls is present). """ - cursor = self._conn.execute( - """INSERT INTO messages (session_id, role, content, tool_call_id, - tool_calls, tool_name, timestamp, token_count, finish_reason) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", - ( - session_id, - role, - content, - tool_call_id, - json.dumps(tool_calls) if tool_calls else None, - tool_name, - time.time(), - token_count, - finish_reason, - ), - ) - msg_id = cursor.lastrowid - - # Update counters - # Count actual tool calls from the tool_calls list (not from tool responses). - # A single assistant message can contain multiple parallel tool calls. - num_tool_calls = 0 - if tool_calls is not None: - num_tool_calls = len(tool_calls) if isinstance(tool_calls, list) else 1 - if num_tool_calls > 0: - self._conn.execute( - """UPDATE sessions SET message_count = message_count + 1, - tool_call_count = tool_call_count + ? WHERE id = ?""", - (num_tool_calls, session_id), - ) - else: - self._conn.execute( - "UPDATE sessions SET message_count = message_count + 1 WHERE id = ?", - (session_id,), + with self._lock: + cursor = self._conn.execute( + """INSERT INTO messages (session_id, role, content, tool_call_id, + tool_calls, tool_name, timestamp, token_count, finish_reason) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + session_id, + role, + content, + tool_call_id, + json.dumps(tool_calls) if tool_calls else None, + tool_name, + time.time(), + token_count, + finish_reason, + ), ) + msg_id = cursor.lastrowid - self._conn.commit() + # Update counters + # Count actual tool calls from the tool_calls list (not from tool responses). + # A single assistant message can contain multiple parallel tool calls. + num_tool_calls = 0 + if tool_calls is not None: + num_tool_calls = len(tool_calls) if isinstance(tool_calls, list) else 1 + if num_tool_calls > 0: + self._conn.execute( + """UPDATE sessions SET message_count = message_count + 1, + tool_call_count = tool_call_count + ? WHERE id = ?""", + (num_tool_calls, session_id), + ) + else: + self._conn.execute( + "UPDATE sessions SET message_count = message_count + 1 WHERE id = ?", + (session_id,), + ) + + self._conn.commit() return msg_id def get_messages(self, session_id: str) -> List[Dict[str, Any]]: """Load all messages for a session, ordered by timestamp.""" - cursor = self._conn.execute( - "SELECT * FROM messages WHERE session_id = ? ORDER BY timestamp, id", - (session_id,), - ) - rows = cursor.fetchall() + with self._lock: + cursor = self._conn.execute( + "SELECT * FROM messages WHERE session_id = ? ORDER BY timestamp, id", + (session_id,), + ) + rows = cursor.fetchall() result = [] for row in rows: msg = dict(row) @@ -559,13 +653,15 @@ class SessionDB: Load messages in the OpenAI conversation format (role + content dicts). Used by the gateway to restore conversation history. """ - cursor = self._conn.execute( - "SELECT role, content, tool_call_id, tool_calls, tool_name " - "FROM messages WHERE session_id = ? ORDER BY timestamp, id", - (session_id,), - ) + with self._lock: + cursor = self._conn.execute( + "SELECT role, content, tool_call_id, tool_calls, tool_name " + "FROM messages WHERE session_id = ? ORDER BY timestamp, id", + (session_id,), + ) + rows = cursor.fetchall() messages = [] - for row in cursor.fetchall(): + for row in rows: msg = {"role": row["role"], "content": row["content"]} if row["tool_call_id"]: msg["tool_call_id"] = row["tool_call_id"] @@ -675,31 +771,33 @@ class SessionDB: LIMIT ? OFFSET ? """ - try: - cursor = self._conn.execute(sql, params) - except sqlite3.OperationalError: - # FTS5 query syntax error despite sanitization — return empty - return [] - matches = [dict(row) for row in cursor.fetchall()] - - # Add surrounding context (1 message before + after each match) - for match in matches: + with self._lock: try: - ctx_cursor = self._conn.execute( - """SELECT role, content FROM messages - WHERE session_id = ? AND id >= ? - 1 AND id <= ? + 1 - ORDER BY id""", - (match["session_id"], match["id"], match["id"]), - ) - context_msgs = [ - {"role": r["role"], "content": (r["content"] or "")[:200]} - for r in ctx_cursor.fetchall() - ] - match["context"] = context_msgs - except Exception: - match["context"] = [] + cursor = self._conn.execute(sql, params) + except sqlite3.OperationalError: + # FTS5 query syntax error despite sanitization — return empty + return [] + matches = [dict(row) for row in cursor.fetchall()] - # Remove full content from result (snippet is enough, saves tokens) + # Add surrounding context (1 message before + after each match) + for match in matches: + try: + ctx_cursor = self._conn.execute( + """SELECT role, content FROM messages + WHERE session_id = ? AND id >= ? - 1 AND id <= ? + 1 + ORDER BY id""", + (match["session_id"], match["id"], match["id"]), + ) + context_msgs = [ + {"role": r["role"], "content": (r["content"] or "")[:200]} + for r in ctx_cursor.fetchall() + ] + match["context"] = context_msgs + except Exception: + match["context"] = [] + + # Remove full content from result (snippet is enough, saves tokens) + for match in matches: match.pop("content", None) return matches @@ -711,17 +809,18 @@ class SessionDB: offset: int = 0, ) -> List[Dict[str, Any]]: """List sessions, optionally filtered by source.""" - if source: - cursor = self._conn.execute( - "SELECT * FROM sessions WHERE source = ? ORDER BY started_at DESC LIMIT ? OFFSET ?", - (source, limit, offset), - ) - else: - cursor = self._conn.execute( - "SELECT * FROM sessions ORDER BY started_at DESC LIMIT ? OFFSET ?", - (limit, offset), - ) - return [dict(row) for row in cursor.fetchall()] + with self._lock: + if source: + cursor = self._conn.execute( + "SELECT * FROM sessions WHERE source = ? ORDER BY started_at DESC LIMIT ? OFFSET ?", + (source, limit, offset), + ) + else: + cursor = self._conn.execute( + "SELECT * FROM sessions ORDER BY started_at DESC LIMIT ? OFFSET ?", + (limit, offset), + ) + return [dict(row) for row in cursor.fetchall()] # ========================================================================= # Utility @@ -773,26 +872,28 @@ class SessionDB: def clear_messages(self, session_id: str) -> None: """Delete all messages for a session and reset its counters.""" - self._conn.execute( - "DELETE FROM messages WHERE session_id = ?", (session_id,) - ) - self._conn.execute( - "UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?", - (session_id,), - ) - self._conn.commit() + with self._lock: + self._conn.execute( + "DELETE FROM messages WHERE session_id = ?", (session_id,) + ) + self._conn.execute( + "UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?", + (session_id,), + ) + self._conn.commit() def delete_session(self, session_id: str) -> bool: """Delete a session and all its messages. Returns True if found.""" - cursor = self._conn.execute( - "SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,) - ) - if cursor.fetchone()[0] == 0: - return False - self._conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,)) - self._conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,)) - self._conn.commit() - return True + with self._lock: + cursor = self._conn.execute( + "SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,) + ) + if cursor.fetchone()[0] == 0: + return False + self._conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,)) + self._conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,)) + self._conn.commit() + return True def prune_sessions(self, older_than_days: int = 90, source: str = None) -> int: """ @@ -802,22 +903,23 @@ class SessionDB: import time as _time cutoff = _time.time() - (older_than_days * 86400) - if source: - cursor = self._conn.execute( - """SELECT id FROM sessions - WHERE started_at < ? AND ended_at IS NOT NULL AND source = ?""", - (cutoff, source), - ) - else: - cursor = self._conn.execute( - "SELECT id FROM sessions WHERE started_at < ? AND ended_at IS NOT NULL", - (cutoff,), - ) - session_ids = [row["id"] for row in cursor.fetchall()] + with self._lock: + if source: + cursor = self._conn.execute( + """SELECT id FROM sessions + WHERE started_at < ? AND ended_at IS NOT NULL AND source = ?""", + (cutoff, source), + ) + else: + cursor = self._conn.execute( + "SELECT id FROM sessions WHERE started_at < ? AND ended_at IS NOT NULL", + (cutoff,), + ) + session_ids = [row["id"] for row in cursor.fetchall()] - for sid in session_ids: - self._conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,)) - self._conn.execute("DELETE FROM sessions WHERE id = ?", (sid,)) + for sid in session_ids: + self._conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,)) + self._conn.execute("DELETE FROM sessions WHERE id = ?", (sid,)) - self._conn.commit() + self._conn.commit() return len(session_ids) diff --git a/honcho_integration/client.py b/honcho_integration/client.py index ccc2f6f25..759576ada 100644 --- a/honcho_integration/client.py +++ b/honcho_integration/client.py @@ -69,6 +69,8 @@ class HonchoClientConfig: workspace_id: str = "hermes" api_key: str | None = None environment: str = "production" + # Optional base URL for self-hosted Honcho (overrides environment mapping) + base_url: str | None = None # Identity peer_name: str | None = None ai_peer: str = "hermes" @@ -361,13 +363,34 @@ def get_honcho_client(config: HonchoClientConfig | None = None) -> Honcho: "Install it with: pip install honcho-ai" ) - logger.info("Initializing Honcho client (host: %s, workspace: %s)", config.host, config.workspace_id) + # Allow config.yaml honcho.base_url to override the SDK's environment + # mapping, enabling remote self-hosted Honcho deployments without + # requiring the server to live on localhost. + resolved_base_url = config.base_url + if not resolved_base_url: + try: + from hermes_cli.config import load_config + hermes_cfg = load_config() + honcho_cfg = hermes_cfg.get("honcho", {}) + if isinstance(honcho_cfg, dict): + resolved_base_url = honcho_cfg.get("base_url", "").strip() or None + except Exception: + pass - _honcho_client = Honcho( - workspace_id=config.workspace_id, - api_key=config.api_key, - environment=config.environment, - ) + if resolved_base_url: + logger.info("Initializing Honcho client (base_url: %s, workspace: %s)", resolved_base_url, config.workspace_id) + else: + logger.info("Initializing Honcho client (host: %s, workspace: %s)", config.host, config.workspace_id) + + kwargs: dict = { + "workspace_id": config.workspace_id, + "api_key": config.api_key, + "environment": config.environment, + } + if resolved_base_url: + kwargs["base_url"] = resolved_base_url + + _honcho_client = Honcho(**kwargs) return _honcho_client diff --git a/pyproject.toml b/pyproject.toml index b7b1f167d..7e92f9078 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "prompt_toolkit", # Tools "firecrawl-py", + "parallel-web>=0.4.2", "fal-client", # Text-to-speech (Edge TTS is free, no API key needed) "edge-tts", @@ -46,6 +47,7 @@ dev = ["pytest", "pytest-asyncio", "pytest-xdist", "mcp>=1.2.0"] messaging = ["python-telegram-bot>=20.0", "discord.py[voice]>=2.0", "aiohttp>=3.9.0", "slack-bolt>=1.18.0", "slack-sdk>=3.27.0"] cron = ["croniter"] slack = ["slack-bolt>=1.18.0", "slack-sdk>=3.27.0"] +matrix = ["matrix-nio[e2e]>=0.24.0"] cli = ["simple-term-menu"] tts-premium = ["elevenlabs"] voice = ["sounddevice>=0.4.6", "numpy>=1.24.0"] @@ -79,9 +81,9 @@ all = [ "hermes-agent[honcho]", "hermes-agent[mcp]", "hermes-agent[homeassistant]", + "hermes-agent[sms]", "hermes-agent[acp]", "hermes-agent[voice]", - "hermes-agent[sms]", ] [project.scripts] diff --git a/requirements.txt b/requirements.txt index 030c84656..67b05659a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,6 +18,7 @@ PyJWT[crypto] # Web tools firecrawl-py +parallel-web>=0.4.2 # Image generation fal-client diff --git a/run_agent.py b/run_agent.py index 2c8fad0b8..bfe62e04c 100644 --- a/run_agent.py +++ b/run_agent.py @@ -86,6 +86,7 @@ from agent.model_metadata import ( from agent.context_compressor import ContextCompressor from agent.prompt_caching import apply_anthropic_cache_control from agent.prompt_builder import build_skills_system_prompt, build_context_files_prompt +from agent.usage_pricing import estimate_usage_cost, normalize_usage from agent.display import ( KawaiiSpinner, build_tool_preview as _build_tool_preview, get_cute_tool_message as _get_cute_tool_message_impl, @@ -391,6 +392,15 @@ class AIAgent: else: self.api_mode = "chat_completions" + # Pre-warm OpenRouter model metadata cache in a background thread. + # fetch_model_metadata() is cached for 1 hour; this avoids a blocking + # HTTP request on the first API response when pricing is estimated. + if self.provider == "openrouter" or "openrouter" in self.base_url.lower(): + threading.Thread( + target=lambda: fetch_model_metadata(), + daemon=True, + ).start() + self.tool_progress_callback = tool_progress_callback self.thinking_callback = thinking_callback self.reasoning_callback = reasoning_callback @@ -407,6 +417,7 @@ class AIAgent: # Subagent delegation state self._delegate_depth = 0 # 0 = top-level agent, incremented for children self._active_children = [] # Running child AIAgents (for interrupt propagation) + self._active_children_lock = threading.Lock() # Store OpenRouter provider preferences self.providers_allowed = providers_allowed @@ -456,8 +467,8 @@ class AIAgent: and Path(getattr(handler, "baseFilename", "")).resolve() == resolved_error_log_path for handler in root_logger.handlers ) + from agent.redact import RedactingFormatter if not has_errors_log_handler: - from agent.redact import RedactingFormatter error_log_dir.mkdir(parents=True, exist_ok=True) error_file_handler = RotatingFileHandler( error_log_path, maxBytes=2 * 1024 * 1024, backupCount=2, @@ -849,6 +860,14 @@ class AIAgent: self.session_completion_tokens = 0 self.session_total_tokens = 0 self.session_api_calls = 0 + self.session_input_tokens = 0 + self.session_output_tokens = 0 + self.session_cache_read_tokens = 0 + self.session_cache_write_tokens = 0 + self.session_reasoning_tokens = 0 + self.session_estimated_cost_usd = 0.0 + self.session_cost_status = "unknown" + self.session_cost_source = "none" if not self.quiet_mode: if compression_enabled: @@ -1526,7 +1545,9 @@ class AIAgent: # Signal all tools to abort any in-flight operations immediately _set_interrupt(True) # Propagate interrupt to any running child agents (subagent delegation) - for child in self._active_children: + with self._active_children_lock: + children_copy = list(self._active_children) + for child in children_copy: try: child.interrupt(message) except Exception as e: @@ -1936,7 +1957,124 @@ class AIAgent: prompt_parts.append(PLATFORM_HINTS[platform_key]) return "\n\n".join(prompt_parts) - + + # ========================================================================= + # Pre/post-call guardrails (inspired by PR #1321 — @alireza78a) + # ========================================================================= + + @staticmethod + def _get_tool_call_id_static(tc) -> str: + """Extract call ID from a tool_call entry (dict or object).""" + if isinstance(tc, dict): + return tc.get("id", "") or "" + return getattr(tc, "id", "") or "" + + @staticmethod + def _sanitize_api_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Fix orphaned tool_call / tool_result pairs before every LLM call. + + Runs unconditionally — not gated on whether the context compressor + is present — so orphans from session loading or manual message + manipulation are always caught. + """ + surviving_call_ids: set = set() + for msg in messages: + if msg.get("role") == "assistant": + for tc in msg.get("tool_calls") or []: + cid = AIAgent._get_tool_call_id_static(tc) + if cid: + surviving_call_ids.add(cid) + + result_call_ids: set = set() + for msg in messages: + if msg.get("role") == "tool": + cid = msg.get("tool_call_id") + if cid: + result_call_ids.add(cid) + + # 1. Drop tool results with no matching assistant call + orphaned_results = result_call_ids - surviving_call_ids + if orphaned_results: + messages = [ + m for m in messages + if not (m.get("role") == "tool" and m.get("tool_call_id") in orphaned_results) + ] + logger.debug( + "Pre-call sanitizer: removed %d orphaned tool result(s)", + len(orphaned_results), + ) + + # 2. Inject stub results for calls whose result was dropped + missing_results = surviving_call_ids - result_call_ids + if missing_results: + patched: List[Dict[str, Any]] = [] + for msg in messages: + patched.append(msg) + if msg.get("role") == "assistant": + for tc in msg.get("tool_calls") or []: + cid = AIAgent._get_tool_call_id_static(tc) + if cid in missing_results: + patched.append({ + "role": "tool", + "content": "[Result unavailable — see context summary above]", + "tool_call_id": cid, + }) + messages = patched + logger.debug( + "Pre-call sanitizer: added %d stub tool result(s)", + len(missing_results), + ) + + return messages + + @staticmethod + def _cap_delegate_task_calls(tool_calls: list) -> list: + """Truncate excess delegate_task calls to MAX_CONCURRENT_CHILDREN. + + The delegate_tool caps the task list inside a single call, but the + model can emit multiple separate delegate_task tool_calls in one + turn. This truncates the excess, preserving all non-delegate calls. + + Returns the original list if no truncation was needed. + """ + from tools.delegate_tool import MAX_CONCURRENT_CHILDREN + delegate_count = sum(1 for tc in tool_calls if tc.function.name == "delegate_task") + if delegate_count <= MAX_CONCURRENT_CHILDREN: + return tool_calls + kept_delegates = 0 + truncated = [] + for tc in tool_calls: + if tc.function.name == "delegate_task": + if kept_delegates < MAX_CONCURRENT_CHILDREN: + truncated.append(tc) + kept_delegates += 1 + else: + truncated.append(tc) + logger.warning( + "Truncated %d excess delegate_task call(s) to enforce " + "MAX_CONCURRENT_CHILDREN=%d limit", + delegate_count - MAX_CONCURRENT_CHILDREN, MAX_CONCURRENT_CHILDREN, + ) + return truncated + + @staticmethod + def _deduplicate_tool_calls(tool_calls: list) -> list: + """Remove duplicate (tool_name, arguments) pairs within a single turn. + + Only the first occurrence of each unique pair is kept. + Returns the original list if no duplicates were found. + """ + seen: set = set() + unique: list = [] + for tc in tool_calls: + key = (tc.function.name, tc.function.arguments) + if key not in seen: + seen.add(key) + unique.append(tc) + else: + logger.warning("Removed duplicate tool call: %s", tc.function.name) + return unique if len(unique) < len(tool_calls) else tool_calls + def _repair_tool_call(self, tool_name: str) -> str | None: """Attempt to repair a mismatched tool name before aborting. @@ -4863,6 +5001,7 @@ class AIAgent: codex_ack_continuations = 0 length_continue_retries = 0 truncated_response_prefix = "" + compression_attempts = 0 # Clear any stale interrupt state at start self.clear_interrupt() @@ -4970,11 +5109,10 @@ class AIAgent: api_messages = apply_anthropic_cache_control(api_messages, cache_ttl=self._cache_ttl) # Safety net: strip orphaned tool results / add stubs for missing - # results before sending to the API. The compressor handles this - # during compression, but orphans can also sneak in from session - # loading or manual message manipulation. - if hasattr(self, 'context_compressor') and self.context_compressor: - api_messages = self.context_compressor._sanitize_tool_pairs(api_messages) + # results before sending to the API. Runs unconditionally — not + # gated on context_compressor — so orphans from session loading or + # manual message manipulation are always caught. + api_messages = self._sanitize_api_messages(api_messages) # Calculate approximate request size for logging total_chars = sum(len(str(msg)) for msg in api_messages) @@ -5008,7 +5146,6 @@ class AIAgent: api_start_time = time.time() retry_count = 0 max_retries = 3 - compression_attempts = 0 max_compression_attempts = 3 codex_auth_retry_attempted = False anthropic_auth_retry_attempted = False @@ -5111,6 +5248,13 @@ class AIAgent: # This is often rate limiting or provider returning malformed response retry_count += 1 + # Eager fallback: empty/malformed responses are a common + # rate-limit symptom. Switch to fallback immediately + # rather than retrying with extended backoff. + if not self._fallback_activated and self._try_activate_fallback(): + retry_count = 0 + continue + # Check for error field in response (some providers include this) error_msg = "Unknown" provider_name = "Unknown" @@ -5269,26 +5413,14 @@ class AIAgent: # Track actual token usage from response for context management if hasattr(response, 'usage') and response.usage: - if self.api_mode in ("codex_responses", "anthropic_messages"): - prompt_tokens = getattr(response.usage, 'input_tokens', 0) or 0 - if self.api_mode == "anthropic_messages": - # Anthropic splits input into cache_read + cache_creation - # + non-cached input_tokens. Without adding the cached - # portions, the context bar shows only the tiny non-cached - # portion (e.g. 3 tokens) instead of the real total (~18K). - # Other providers (OpenAI/Codex) already include cached - # tokens in their input_tokens/prompt_tokens field. - prompt_tokens += getattr(response.usage, 'cache_read_input_tokens', 0) or 0 - prompt_tokens += getattr(response.usage, 'cache_creation_input_tokens', 0) or 0 - completion_tokens = getattr(response.usage, 'output_tokens', 0) or 0 - total_tokens = ( - getattr(response.usage, 'total_tokens', None) - or (prompt_tokens + completion_tokens) - ) - else: - prompt_tokens = getattr(response.usage, 'prompt_tokens', 0) or 0 - completion_tokens = getattr(response.usage, 'completion_tokens', 0) or 0 - total_tokens = getattr(response.usage, 'total_tokens', 0) or 0 + canonical_usage = normalize_usage( + response.usage, + provider=self.provider, + api_mode=self.api_mode, + ) + prompt_tokens = canonical_usage.prompt_tokens + completion_tokens = canonical_usage.output_tokens + total_tokens = canonical_usage.total_tokens usage_dict = { "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, @@ -5307,6 +5439,22 @@ class AIAgent: self.session_completion_tokens += completion_tokens self.session_total_tokens += total_tokens self.session_api_calls += 1 + self.session_input_tokens += canonical_usage.input_tokens + self.session_output_tokens += canonical_usage.output_tokens + self.session_cache_read_tokens += canonical_usage.cache_read_tokens + self.session_cache_write_tokens += canonical_usage.cache_write_tokens + self.session_reasoning_tokens += canonical_usage.reasoning_tokens + + cost_result = estimate_usage_cost( + self.model, + canonical_usage, + provider=self.provider, + base_url=self.base_url, + ) + if cost_result.amount_usd is not None: + self.session_estimated_cost_usd += float(cost_result.amount_usd) + self.session_cost_status = cost_result.status + self.session_cost_source = cost_result.source # Persist token counts to session DB for /insights. # Gateway sessions persist via session_store.update_session() @@ -5317,8 +5465,19 @@ class AIAgent: try: self._session_db.update_token_counts( self.session_id, - input_tokens=prompt_tokens, - output_tokens=completion_tokens, + input_tokens=canonical_usage.input_tokens, + output_tokens=canonical_usage.output_tokens, + cache_read_tokens=canonical_usage.cache_read_tokens, + cache_write_tokens=canonical_usage.cache_write_tokens, + reasoning_tokens=canonical_usage.reasoning_tokens, + estimated_cost_usd=float(cost_result.amount_usd) + if cost_result.amount_usd is not None else None, + cost_status=cost_result.status, + cost_source=cost_result.source, + billing_provider=self.provider, + billing_base_url=self.base_url, + billing_mode="subscription_included" + if cost_result.status == "included" else None, model=self.model, ) except Exception: @@ -5449,6 +5608,24 @@ class AIAgent: # A 413 is a payload-size error — the correct response is to # compress history and retry, not abort immediately. status_code = getattr(api_error, "status_code", None) + + # Eager fallback for rate-limit errors (429 or quota exhaustion). + # When a fallback model is configured, switch immediately instead + # of burning through retries with exponential backoff -- the + # primary provider won't recover within the retry window. + is_rate_limited = ( + status_code == 429 + or "rate limit" in error_msg + or "too many requests" in error_msg + or "rate_limit" in error_msg + or "usage limit" in error_msg + or "quota" in error_msg + ) + if is_rate_limited and not self._fallback_activated: + if self._try_activate_fallback(): + retry_count = 0 + continue + is_payload_too_large = ( status_code == 413 or 'request entity too large' in error_msg @@ -5935,24 +6112,45 @@ class AIAgent: # Don't add anything to messages, just retry the API call continue else: - # Instead of returning partial, inject a helpful message and let model recover - self._vprint(f"{self.log_prefix}⚠️ Injecting recovery message for invalid JSON...") + # Instead of returning partial, inject tool error results so the model can recover. + # Using tool results (not user messages) preserves role alternation. + self._vprint(f"{self.log_prefix}⚠️ Injecting recovery tool results for invalid JSON...") self._invalid_json_retries = 0 # Reset for next attempt - # Add a user message explaining the issue - recovery_msg = ( - f"Your tool call to '{tool_name}' had invalid JSON arguments. " - f"Error: {error_msg}. " - f"For tools with no required parameters, use an empty object: {{}}. " - f"Please either retry the tool call with valid JSON, or respond without using that tool." - ) - recovery_dict = {"role": "user", "content": recovery_msg} - messages.append(recovery_dict) + # Append the assistant message with its (broken) tool_calls + recovery_assistant = self._build_assistant_message(assistant_message, finish_reason) + messages.append(recovery_assistant) + + # Respond with tool error results for each tool call + invalid_names = {name for name, _ in invalid_json_args} + for tc in assistant_message.tool_calls: + if tc.function.name in invalid_names: + err = next(e for n, e in invalid_json_args if n == tc.function.name) + tool_result = ( + f"Error: Invalid JSON arguments. {err}. " + f"For tools with no required parameters, use an empty object: {{}}. " + f"Please retry with valid JSON." + ) + else: + tool_result = "Skipped: other tool call in this response had invalid JSON." + messages.append({ + "role": "tool", + "tool_call_id": tc.id, + "content": tool_result, + }) continue # Reset retry counter on successful JSON validation self._invalid_json_retries = 0 - + + # ── Post-call guardrails ────────────────────────── + assistant_message.tool_calls = self._cap_delegate_task_calls( + assistant_message.tool_calls + ) + assistant_message.tool_calls = self._deduplicate_tool_calls( + assistant_message.tool_calls + ) + assistant_msg = self._build_assistant_message(assistant_message, finish_reason) # If this turn has both content AND tool_calls, capture the content @@ -6133,6 +6331,8 @@ class AIAgent: if truncated_response_prefix: final_response = truncated_response_prefix + final_response + truncated_response_prefix = "" + length_continue_retries = 0 # Strip blocks from user-facing response (keep raw in messages for trajectory) final_response = self._strip_think_blocks(final_response).strip() @@ -6184,10 +6384,11 @@ class AIAgent: if not pending_handled: # Error happened before tool processing (e.g. response parsing). - # Use a user-role message so the model can see what went wrong - # without confusing the API with a fabricated assistant turn. + # Choose role to avoid consecutive same-role messages. + last_role = messages[-1].get("role") if messages else None + err_role = "assistant" if last_role == "user" else "user" sys_err_msg = { - "role": "user", + "role": err_role, "content": f"[System error during processing: {error_msg}]", } messages.append(sys_err_msg) @@ -6239,6 +6440,21 @@ class AIAgent: "partial": False, # True only when stopped due to invalid tool calls "interrupted": interrupted, "response_previewed": getattr(self, "_response_was_previewed", False), + "model": self.model, + "provider": self.provider, + "base_url": self.base_url, + "input_tokens": self.session_input_tokens, + "output_tokens": self.session_output_tokens, + "cache_read_tokens": self.session_cache_read_tokens, + "cache_write_tokens": self.session_cache_write_tokens, + "reasoning_tokens": self.session_reasoning_tokens, + "prompt_tokens": self.session_prompt_tokens, + "completion_tokens": self.session_completion_tokens, + "total_tokens": self.session_total_tokens, + "last_prompt_tokens": getattr(self.context_compressor, "last_prompt_tokens", 0) or 0, + "estimated_cost_usd": self.session_estimated_cost_usd, + "cost_status": self.session_cost_status, + "cost_source": self.session_cost_source, } self._response_was_previewed = False diff --git a/skills/inference-sh/DESCRIPTION.md b/skills/inference-sh/DESCRIPTION.md new file mode 100644 index 000000000..011ede4c1 --- /dev/null +++ b/skills/inference-sh/DESCRIPTION.md @@ -0,0 +1,19 @@ +# inference.sh + +Run 150+ AI applications in the cloud via the [inference.sh](https://inference.sh) platform. + +**One API key for everything** — access image generation, video creation, LLMs, search, 3D, and more through a single account. No need to manage separate API keys for each provider. + +## Available Skills + +- **cli**: Use the inference.sh CLI (`infsh`) via the terminal tool + +## What's Included + +- **Image Generation**: FLUX, Reve, Seedream, Grok Imagine, Gemini +- **Video Generation**: Veo, Wan, Seedance, OmniHuman, HunyuanVideo +- **LLMs**: Claude, Gemini, Kimi, GLM-4 (via OpenRouter) +- **Search**: Tavily, Exa +- **3D**: Rodin +- **Social**: Twitter/X automation +- **Audio**: TTS, voice cloning diff --git a/skills/inference-sh/cli/SKILL.md b/skills/inference-sh/cli/SKILL.md new file mode 100644 index 000000000..79183f61c --- /dev/null +++ b/skills/inference-sh/cli/SKILL.md @@ -0,0 +1,155 @@ +--- +name: inference-sh-cli +description: "Run 150+ AI apps via inference.sh CLI (infsh) — image generation, video creation, LLMs, search, 3D, social automation. Uses the terminal tool. Triggers: inference.sh, infsh, ai apps, flux, veo, image generation, video generation, seedream, seedance, tavily" +version: 1.0.0 +author: okaris +license: MIT +metadata: + hermes: + tags: [AI, image-generation, video, LLM, search, inference, FLUX, Veo, Claude] + related_skills: [] +--- + +# inference.sh CLI + +Run 150+ AI apps in the cloud with a simple CLI. No GPU required. + +All commands use the **terminal tool** to run `infsh` commands. + +## When to Use + +- User asks to generate images (FLUX, Reve, Seedream, Grok, Gemini image) +- User asks to generate video (Veo, Wan, Seedance, OmniHuman) +- User asks about inference.sh or infsh +- User wants to run AI apps without managing individual provider APIs +- User asks for AI-powered search (Tavily, Exa) +- User needs avatar/lipsync generation + +## Prerequisites + +The `infsh` CLI must be installed and authenticated. Check with: + +```bash +infsh me +``` + +If not installed: + +```bash +curl -fsSL https://cli.inference.sh | sh +infsh login +``` + +See `references/authentication.md` for full setup details. + +## Workflow + +### 1. Always Search First + +Never guess app names — always search to find the correct app ID: + +```bash +infsh app list --search flux +infsh app list --search video +infsh app list --search image +``` + +### 2. Run an App + +Use the exact app ID from the search results. Always use `--json` for machine-readable output: + +```bash +infsh app run --input '{"prompt": "your prompt here"}' --json +``` + +### 3. Parse the Output + +The JSON output contains URLs to generated media. Present these to the user with `MEDIA:` for inline display. + +## Common Commands + +### Image Generation + +```bash +# Search for image apps +infsh app list --search image + +# FLUX Dev with LoRA +infsh app run falai/flux-dev-lora --input '{"prompt": "sunset over mountains", "num_images": 1}' --json + +# Gemini image generation +infsh app run google/gemini-2-5-flash-image --input '{"prompt": "futuristic city", "num_images": 1}' --json + +# Seedream (ByteDance) +infsh app run bytedance/seedream-5-lite --input '{"prompt": "nature scene"}' --json + +# Grok Imagine (xAI) +infsh app run xai/grok-imagine-image --input '{"prompt": "abstract art"}' --json +``` + +### Video Generation + +```bash +# Search for video apps +infsh app list --search video + +# Veo 3.1 (Google) +infsh app run google/veo-3-1-fast --input '{"prompt": "drone shot of coastline"}' --json + +# Seedance (ByteDance) +infsh app run bytedance/seedance-1-5-pro --input '{"prompt": "dancing figure", "resolution": "1080p"}' --json + +# Wan 2.5 +infsh app run falai/wan-2-5 --input '{"prompt": "person walking through city"}' --json +``` + +### Local File Uploads + +The CLI automatically uploads local files when you provide a path: + +```bash +# Upscale a local image +infsh app run falai/topaz-image-upscaler --input '{"image": "/path/to/photo.jpg", "upscale_factor": 2}' --json + +# Image-to-video from local file +infsh app run falai/wan-2-5-i2v --input '{"image": "/path/to/image.png", "prompt": "make it move"}' --json + +# Avatar with audio +infsh app run bytedance/omnihuman-1-5 --input '{"audio": "/path/to/audio.mp3", "image": "/path/to/face.jpg"}' --json +``` + +### Search & Research + +```bash +infsh app list --search search +infsh app run tavily/tavily-search --input '{"query": "latest AI news"}' --json +infsh app run exa/exa-search --input '{"query": "machine learning papers"}' --json +``` + +### Other Categories + +```bash +# 3D generation +infsh app list --search 3d + +# Audio / TTS +infsh app list --search tts + +# Twitter/X automation +infsh app list --search twitter +``` + +## Pitfalls + +1. **Never guess app IDs** — always run `infsh app list --search ` first. App IDs change and new apps are added frequently. +2. **Always use `--json`** — raw output is hard to parse. The `--json` flag gives structured output with URLs. +3. **Check authentication** — if commands fail with auth errors, run `infsh login` or verify `INFSH_API_KEY` is set. +4. **Long-running apps** — video generation can take 30-120 seconds. The terminal tool timeout should be sufficient, but warn the user it may take a moment. +5. **Input format** — the `--input` flag takes a JSON string. Make sure to properly escape quotes. + +## Reference Docs + +- `references/authentication.md` — Setup, login, API keys +- `references/app-discovery.md` — Searching and browsing the app catalog +- `references/running-apps.md` — Running apps, input formats, output handling +- `references/cli-reference.md` — Complete CLI command reference diff --git a/skills/inference-sh/cli/references/app-discovery.md b/skills/inference-sh/cli/references/app-discovery.md new file mode 100644 index 000000000..adcac8c5d --- /dev/null +++ b/skills/inference-sh/cli/references/app-discovery.md @@ -0,0 +1,112 @@ +# Discovering Apps + +## List All Apps + +```bash +infsh app list +``` + +## Pagination + +```bash +infsh app list --page 2 +``` + +## Filter by Category + +```bash +infsh app list --category image +infsh app list --category video +infsh app list --category audio +infsh app list --category text +infsh app list --category other +``` + +## Search + +```bash +infsh app search "flux" +infsh app search "video generation" +infsh app search "tts" -l +infsh app search "image" --category image +``` + +Or use the flag form: + +```bash +infsh app list --search "flux" +infsh app list --search "video generation" +infsh app list --search "tts" +``` + +## Featured Apps + +```bash +infsh app list --featured +``` + +## Newest First + +```bash +infsh app list --new +``` + +## Detailed View + +```bash +infsh app list -l +``` + +Shows table with app name, category, description, and featured status. + +## Save to File + +```bash +infsh app list --save apps.json +``` + +## Your Apps + +List apps you've deployed: + +```bash +infsh app my +infsh app my -l # detailed +``` + +## Get App Details + +```bash +infsh app get falai/flux-dev-lora +infsh app get falai/flux-dev-lora --json +``` + +Shows full app info including input/output schema. + +## Popular Apps by Category + +### Image Generation +- `falai/flux-dev-lora` - FLUX.2 Dev (high quality) +- `falai/flux-2-klein-lora` - FLUX.2 Klein (fastest) +- `infsh/sdxl` - Stable Diffusion XL +- `google/gemini-3-pro-image-preview` - Gemini 3 Pro +- `xai/grok-imagine-image` - Grok image generation + +### Video Generation +- `google/veo-3-1-fast` - Veo 3.1 Fast +- `google/veo-3` - Veo 3 +- `bytedance/seedance-1-5-pro` - Seedance 1.5 Pro +- `infsh/ltx-video-2` - LTX Video 2 (with audio) +- `bytedance/omnihuman-1-5` - OmniHuman avatar + +### Audio +- `infsh/dia-tts` - Conversational TTS +- `infsh/kokoro-tts` - Kokoro TTS +- `infsh/fast-whisper-large-v3` - Fast transcription +- `infsh/diffrythm` - Music generation + +## Documentation + +- [Browsing the Grid](https://inference.sh/docs/apps/browsing-grid) - Visual app browsing +- [Apps Overview](https://inference.sh/docs/apps/overview) - Understanding apps +- [Running Apps](https://inference.sh/docs/apps/running) - How to run apps diff --git a/skills/inference-sh/cli/references/authentication.md b/skills/inference-sh/cli/references/authentication.md new file mode 100644 index 000000000..3b6519d3d --- /dev/null +++ b/skills/inference-sh/cli/references/authentication.md @@ -0,0 +1,59 @@ +# Authentication & Setup + +## Install the CLI + +```bash +curl -fsSL https://cli.inference.sh | sh +``` + +## Login + +```bash +infsh login +``` + +This opens a browser for authentication. After login, credentials are stored locally. + +## Check Authentication + +```bash +infsh me +``` + +Shows your user info if authenticated. + +## Environment Variable + +For CI/CD or scripts, set your API key: + +```bash +export INFSH_API_KEY=your-api-key +``` + +The environment variable overrides the config file. + +## Update CLI + +```bash +infsh update +``` + +Or reinstall: + +```bash +curl -fsSL https://cli.inference.sh | sh +``` + +## Troubleshooting + +| Error | Solution | +|-------|----------| +| "not authenticated" | Run `infsh login` | +| "command not found" | Reinstall CLI or add to PATH | +| "API key invalid" | Check `INFSH_API_KEY` or re-login | + +## Documentation + +- [CLI Setup](https://inference.sh/docs/extend/cli-setup) - Complete CLI installation guide +- [API Authentication](https://inference.sh/docs/api/authentication) - API key management +- [Secrets](https://inference.sh/docs/secrets/overview) - Managing credentials diff --git a/skills/inference-sh/cli/references/cli-reference.md b/skills/inference-sh/cli/references/cli-reference.md new file mode 100644 index 000000000..50825825f --- /dev/null +++ b/skills/inference-sh/cli/references/cli-reference.md @@ -0,0 +1,104 @@ +# CLI Reference + +## Installation + +```bash +curl -fsSL https://cli.inference.sh | sh +``` + +## Global Commands + +| Command | Description | +|---------|-------------| +| `infsh help` | Show help | +| `infsh version` | Show CLI version | +| `infsh update` | Update CLI to latest | +| `infsh login` | Authenticate | +| `infsh me` | Show current user | + +## App Commands + +### Discovery + +| Command | Description | +|---------|-------------| +| `infsh app list` | List available apps | +| `infsh app list --category ` | Filter by category (image, video, audio, text, other) | +| `infsh app search ` | Search apps | +| `infsh app list --search ` | Search apps (flag form) | +| `infsh app list --featured` | Show featured apps | +| `infsh app list --new` | Sort by newest | +| `infsh app list --page ` | Pagination | +| `infsh app list -l` | Detailed table view | +| `infsh app list --save ` | Save to JSON file | +| `infsh app my` | List your deployed apps | +| `infsh app get ` | Get app details | +| `infsh app get --json` | Get app details as JSON | + +### Execution + +| Command | Description | +|---------|-------------| +| `infsh app run --input ` | Run app with input file | +| `infsh app run --input ''` | Run with inline JSON | +| `infsh app run --input --no-wait` | Run without waiting for completion | +| `infsh app sample ` | Show sample input | +| `infsh app sample --save ` | Save sample to file | + +## Task Commands + +| Command | Description | +|---------|-------------| +| `infsh task get ` | Get task status and result | +| `infsh task get --json` | Get task as JSON | +| `infsh task get --save ` | Save task result to file | + +### Development + +| Command | Description | +|---------|-------------| +| `infsh app init` | Create new app (interactive) | +| `infsh app init ` | Create new app with name | +| `infsh app test --input ` | Test app locally | +| `infsh app deploy` | Deploy app | +| `infsh app deploy --dry-run` | Validate without deploying | +| `infsh app pull ` | Pull app source | +| `infsh app pull --all` | Pull all your apps | + +## Environment Variables + +| Variable | Description | +|----------|-------------| +| `INFSH_API_KEY` | API key (overrides config) | + +## Shell Completions + +```bash +# Bash +infsh completion bash > /etc/bash_completion.d/infsh + +# Zsh +infsh completion zsh > "${fpath[1]}/_infsh" + +# Fish +infsh completion fish > ~/.config/fish/completions/infsh.fish +``` + +## App Name Format + +Apps use the format `namespace/app-name`: + +- `falai/flux-dev-lora` - fal.ai's FLUX 2 Dev +- `google/veo-3` - Google's Veo 3 +- `infsh/sdxl` - inference.sh's SDXL +- `bytedance/seedance-1-5-pro` - ByteDance's Seedance +- `xai/grok-imagine-image` - xAI's Grok + +Version pinning: `namespace/app-name@version` + +## Documentation + +- [CLI Setup](https://inference.sh/docs/extend/cli-setup) - Complete CLI installation guide +- [Running Apps](https://inference.sh/docs/apps/running) - How to run apps via CLI +- [Creating an App](https://inference.sh/docs/extend/creating-app) - Build your own apps +- [Deploying](https://inference.sh/docs/extend/deploying) - Deploy apps to the cloud diff --git a/skills/inference-sh/cli/references/running-apps.md b/skills/inference-sh/cli/references/running-apps.md new file mode 100644 index 000000000..e930d5cfb --- /dev/null +++ b/skills/inference-sh/cli/references/running-apps.md @@ -0,0 +1,171 @@ +# Running Apps + +## Basic Run + +```bash +infsh app run user/app-name --input input.json +``` + +## Inline JSON + +```bash +infsh app run falai/flux-dev-lora --input '{"prompt": "a sunset over mountains"}' +``` + +## Version Pinning + +```bash +infsh app run user/app-name@1.0.0 --input input.json +``` + +## Local File Uploads + +The CLI automatically uploads local files when you provide a file path instead of a URL. Any field that accepts a URL also accepts a local path: + +```bash +# Upscale a local image +infsh app run falai/topaz-image-upscaler --input '{"image": "/path/to/photo.jpg", "upscale_factor": 2}' + +# Image-to-video from local file +infsh app run falai/wan-2-5-i2v --input '{"image": "./my-image.png", "prompt": "make it move"}' + +# Avatar with local audio and image +infsh app run bytedance/omnihuman-1-5 --input '{"audio": "/path/to/speech.mp3", "image": "/path/to/face.jpg"}' + +# Post tweet with local media +infsh app run x/post-create --input '{"text": "Check this out!", "media": "./screenshot.png"}' +``` + +Supported paths: +- Absolute paths: `/home/user/images/photo.jpg` +- Relative paths: `./image.png`, `../data/video.mp4` +- Home directory: `~/Pictures/photo.jpg` + +## Generate Sample Input + +Before running, generate a sample input file: + +```bash +infsh app sample falai/flux-dev-lora +``` + +Save to file: + +```bash +infsh app sample falai/flux-dev-lora --save input.json +``` + +Then edit `input.json` and run: + +```bash +infsh app run falai/flux-dev-lora --input input.json +``` + +## Workflow Example + +### Image Generation with FLUX + +```bash +# 1. Get app details +infsh app get falai/flux-dev-lora + +# 2. Generate sample input +infsh app sample falai/flux-dev-lora --save input.json + +# 3. Edit input.json +# { +# "prompt": "a cat astronaut floating in space", +# "num_images": 1, +# "image_size": "landscape_16_9" +# } + +# 4. Run +infsh app run falai/flux-dev-lora --input input.json +``` + +### Video Generation with Veo + +```bash +# 1. Generate sample +infsh app sample google/veo-3-1-fast --save input.json + +# 2. Edit prompt +# { +# "prompt": "A drone shot flying over a forest at sunset" +# } + +# 3. Run +infsh app run google/veo-3-1-fast --input input.json +``` + +### Text-to-Speech + +```bash +# Quick inline run +infsh app run falai/kokoro-tts --input '{"text": "Hello, this is a test."}' +``` + +## Task Tracking + +When you run an app, the CLI shows the task ID: + +``` +Running falai/flux-dev-lora +Task ID: abc123def456 +``` + +For long-running tasks, you can check status anytime: + +```bash +# Check task status +infsh task get abc123def456 + +# Get result as JSON +infsh task get abc123def456 --json + +# Save result to file +infsh task get abc123def456 --save result.json +``` + +### Run Without Waiting + +For very long tasks, run in background: + +```bash +# Submit and return immediately +infsh app run google/veo-3 --input input.json --no-wait + +# Check later +infsh task get +``` + +## Output + +The CLI returns the app output directly. For file outputs (images, videos, audio), you'll receive URLs to download. + +Example output: + +```json +{ + "images": [ + { + "url": "https://cloud.inference.sh/...", + "content_type": "image/png" + } + ] +} +``` + +## Error Handling + +| Error | Cause | Solution | +|-------|-------|----------| +| "invalid input" | Schema mismatch | Check `infsh app get` for required fields | +| "app not found" | Wrong app name | Check `infsh app list --search` | +| "quota exceeded" | Out of credits | Check account balance | + +## Documentation + +- [Running Apps](https://inference.sh/docs/apps/running) - Complete running apps guide +- [Streaming Results](https://inference.sh/docs/api/sdk/streaming) - Real-time progress updates +- [Setup Parameters](https://inference.sh/docs/apps/setup-parameters) - Configuring app inputs diff --git a/tests/agent/test_model_metadata.py b/tests/agent/test_model_metadata.py index 8c899d2bb..75570e343 100644 --- a/tests/agent/test_model_metadata.py +++ b/tests/agent/test_model_metadata.py @@ -113,11 +113,13 @@ class TestDefaultContextLengths: def test_gpt4_models_128k_or_1m(self): # gpt-4.1 and gpt-4.1-mini have 1M context; other gpt-4* have 128k for key, value in DEFAULT_CONTEXT_LENGTHS.items(): - if "gpt-4" in key: - if "gpt-4.1" in key: - assert value == 1047576, f"{key} should be 1047576 (1M)" - else: - assert value == 128000, f"{key} should be 128000" + if "gpt-4" in key and "gpt-4.1" not in key: + assert value == 128000, f"{key} should be 128000" + + def test_gpt41_models_1m(self): + for key, value in DEFAULT_CONTEXT_LENGTHS.items(): + if "gpt-4.1" in key: + assert value == 1047576, f"{key} should be 1047576" def test_gemini_models_1m(self): for key, value in DEFAULT_CONTEXT_LENGTHS.items(): diff --git a/tests/agent/test_usage_pricing.py b/tests/agent/test_usage_pricing.py new file mode 100644 index 000000000..6d972dfa7 --- /dev/null +++ b/tests/agent/test_usage_pricing.py @@ -0,0 +1,101 @@ +from types import SimpleNamespace + +from agent.usage_pricing import ( + CanonicalUsage, + estimate_usage_cost, + get_pricing_entry, + normalize_usage, +) + + +def test_normalize_usage_anthropic_keeps_cache_buckets_separate(): + usage = SimpleNamespace( + input_tokens=1000, + output_tokens=500, + cache_read_input_tokens=2000, + cache_creation_input_tokens=400, + ) + + normalized = normalize_usage(usage, provider="anthropic", api_mode="anthropic_messages") + + assert normalized.input_tokens == 1000 + assert normalized.output_tokens == 500 + assert normalized.cache_read_tokens == 2000 + assert normalized.cache_write_tokens == 400 + assert normalized.prompt_tokens == 3400 + + +def test_normalize_usage_openai_subtracts_cached_prompt_tokens(): + usage = SimpleNamespace( + prompt_tokens=3000, + completion_tokens=700, + prompt_tokens_details=SimpleNamespace(cached_tokens=1800), + ) + + normalized = normalize_usage(usage, provider="openai", api_mode="chat_completions") + + assert normalized.input_tokens == 1200 + assert normalized.cache_read_tokens == 1800 + assert normalized.output_tokens == 700 + + +def test_openrouter_models_api_pricing_is_converted_from_per_token_to_per_million(monkeypatch): + monkeypatch.setattr( + "agent.usage_pricing.fetch_model_metadata", + lambda: { + "anthropic/claude-opus-4.6": { + "pricing": { + "prompt": "0.000005", + "completion": "0.000025", + "input_cache_read": "0.0000005", + "input_cache_write": "0.00000625", + } + } + }, + ) + + entry = get_pricing_entry( + "anthropic/claude-opus-4.6", + provider="openrouter", + base_url="https://openrouter.ai/api/v1", + ) + + assert float(entry.input_cost_per_million) == 5.0 + assert float(entry.output_cost_per_million) == 25.0 + assert float(entry.cache_read_cost_per_million) == 0.5 + assert float(entry.cache_write_cost_per_million) == 6.25 + + +def test_estimate_usage_cost_marks_subscription_routes_included(): + result = estimate_usage_cost( + "gpt-5.3-codex", + CanonicalUsage(input_tokens=1000, output_tokens=500), + provider="openai-codex", + base_url="https://chatgpt.com/backend-api/codex", + ) + + assert result.status == "included" + assert float(result.amount_usd) == 0.0 + + +def test_estimate_usage_cost_refuses_cache_pricing_without_official_cache_rate(monkeypatch): + monkeypatch.setattr( + "agent.usage_pricing.fetch_model_metadata", + lambda: { + "google/gemini-2.5-pro": { + "pricing": { + "prompt": "0.00000125", + "completion": "0.00001", + } + } + }, + ) + + result = estimate_usage_cost( + "google/gemini-2.5-pro", + CanonicalUsage(input_tokens=1000, output_tokens=500, cache_read_tokens=100), + provider="openrouter", + base_url="https://openrouter.ai/api/v1", + ) + + assert result.status == "unknown" diff --git a/tests/gateway/test_background_process_notifications.py b/tests/gateway/test_background_process_notifications.py index 10069fe9c..9c1404f89 100644 --- a/tests/gateway/test_background_process_notifications.py +++ b/tests/gateway/test_background_process_notifications.py @@ -50,13 +50,16 @@ def _build_runner(monkeypatch, tmp_path, mode: str) -> GatewayRunner: return runner -def _watcher_dict(session_id="proc_test"): - return { +def _watcher_dict(session_id="proc_test", thread_id=""): + d = { "session_id": session_id, "check_interval": 0, "platform": "telegram", "chat_id": "123", } + if thread_id: + d["thread_id"] = thread_id + return d # --------------------------------------------------------------------------- @@ -196,3 +199,47 @@ async def test_run_process_watcher_respects_notification_mode( if expected_fragment is not None: sent_message = adapter.send.await_args.args[1] assert expected_fragment in sent_message + + +@pytest.mark.asyncio +async def test_thread_id_passed_to_send(monkeypatch, tmp_path): + """thread_id from watcher dict is forwarded as metadata to adapter.send().""" + import tools.process_registry as pr_module + + sessions = [SimpleNamespace(output_buffer="done\n", exited=True, exit_code=0)] + monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions)) + + async def _instant_sleep(*_a, **_kw): + pass + monkeypatch.setattr(asyncio, "sleep", _instant_sleep) + + runner = _build_runner(monkeypatch, tmp_path, "all") + adapter = runner.adapters[Platform.TELEGRAM] + + await runner._run_process_watcher(_watcher_dict(thread_id="42")) + + assert adapter.send.await_count == 1 + _, kwargs = adapter.send.call_args + assert kwargs["metadata"] == {"thread_id": "42"} + + +@pytest.mark.asyncio +async def test_no_thread_id_sends_no_metadata(monkeypatch, tmp_path): + """When thread_id is empty, metadata should be None (general topic).""" + import tools.process_registry as pr_module + + sessions = [SimpleNamespace(output_buffer="done\n", exited=True, exit_code=0)] + monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions)) + + async def _instant_sleep(*_a, **_kw): + pass + monkeypatch.setattr(asyncio, "sleep", _instant_sleep) + + runner = _build_runner(monkeypatch, tmp_path, "all") + adapter = runner.adapters[Platform.TELEGRAM] + + await runner._run_process_watcher(_watcher_dict()) + + assert adapter.send.await_count == 1 + _, kwargs = adapter.send.call_args + assert kwargs["metadata"] is None diff --git a/tests/gateway/test_dingtalk.py b/tests/gateway/test_dingtalk.py new file mode 100644 index 000000000..5c73253fb --- /dev/null +++ b/tests/gateway/test_dingtalk.py @@ -0,0 +1,274 @@ +"""Tests for DingTalk platform adapter.""" +import asyncio +import json +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch, PropertyMock + +import pytest + +from gateway.config import Platform, PlatformConfig + + +# --------------------------------------------------------------------------- +# Requirements check +# --------------------------------------------------------------------------- + + +class TestDingTalkRequirements: + + def test_returns_false_when_sdk_missing(self, monkeypatch): + with patch.dict("sys.modules", {"dingtalk_stream": None}): + monkeypatch.setattr( + "gateway.platforms.dingtalk.DINGTALK_STREAM_AVAILABLE", False + ) + from gateway.platforms.dingtalk import check_dingtalk_requirements + assert check_dingtalk_requirements() is False + + def test_returns_false_when_env_vars_missing(self, monkeypatch): + monkeypatch.setattr( + "gateway.platforms.dingtalk.DINGTALK_STREAM_AVAILABLE", True + ) + monkeypatch.setattr("gateway.platforms.dingtalk.HTTPX_AVAILABLE", True) + monkeypatch.delenv("DINGTALK_CLIENT_ID", raising=False) + monkeypatch.delenv("DINGTALK_CLIENT_SECRET", raising=False) + from gateway.platforms.dingtalk import check_dingtalk_requirements + assert check_dingtalk_requirements() is False + + def test_returns_true_when_all_available(self, monkeypatch): + monkeypatch.setattr( + "gateway.platforms.dingtalk.DINGTALK_STREAM_AVAILABLE", True + ) + monkeypatch.setattr("gateway.platforms.dingtalk.HTTPX_AVAILABLE", True) + monkeypatch.setenv("DINGTALK_CLIENT_ID", "test-id") + monkeypatch.setenv("DINGTALK_CLIENT_SECRET", "test-secret") + from gateway.platforms.dingtalk import check_dingtalk_requirements + assert check_dingtalk_requirements() is True + + +# --------------------------------------------------------------------------- +# Adapter construction +# --------------------------------------------------------------------------- + + +class TestDingTalkAdapterInit: + + def test_reads_config_from_extra(self): + from gateway.platforms.dingtalk import DingTalkAdapter + config = PlatformConfig( + enabled=True, + extra={"client_id": "cfg-id", "client_secret": "cfg-secret"}, + ) + adapter = DingTalkAdapter(config) + assert adapter._client_id == "cfg-id" + assert adapter._client_secret == "cfg-secret" + assert adapter.name == "Dingtalk" # base class uses .title() + + def test_falls_back_to_env_vars(self, monkeypatch): + monkeypatch.setenv("DINGTALK_CLIENT_ID", "env-id") + monkeypatch.setenv("DINGTALK_CLIENT_SECRET", "env-secret") + from gateway.platforms.dingtalk import DingTalkAdapter + config = PlatformConfig(enabled=True) + adapter = DingTalkAdapter(config) + assert adapter._client_id == "env-id" + assert adapter._client_secret == "env-secret" + + +# --------------------------------------------------------------------------- +# Message text extraction +# --------------------------------------------------------------------------- + + +class TestExtractText: + + def test_extracts_dict_text(self): + from gateway.platforms.dingtalk import DingTalkAdapter + msg = MagicMock() + msg.text = {"content": " hello world "} + msg.rich_text = None + assert DingTalkAdapter._extract_text(msg) == "hello world" + + def test_extracts_string_text(self): + from gateway.platforms.dingtalk import DingTalkAdapter + msg = MagicMock() + msg.text = "plain text" + msg.rich_text = None + assert DingTalkAdapter._extract_text(msg) == "plain text" + + def test_falls_back_to_rich_text(self): + from gateway.platforms.dingtalk import DingTalkAdapter + msg = MagicMock() + msg.text = "" + msg.rich_text = [{"text": "part1"}, {"text": "part2"}, {"image": "url"}] + assert DingTalkAdapter._extract_text(msg) == "part1 part2" + + def test_returns_empty_for_no_content(self): + from gateway.platforms.dingtalk import DingTalkAdapter + msg = MagicMock() + msg.text = "" + msg.rich_text = None + assert DingTalkAdapter._extract_text(msg) == "" + + +# --------------------------------------------------------------------------- +# Deduplication +# --------------------------------------------------------------------------- + + +class TestDeduplication: + + def test_first_message_not_duplicate(self): + from gateway.platforms.dingtalk import DingTalkAdapter + adapter = DingTalkAdapter(PlatformConfig(enabled=True)) + assert adapter._is_duplicate("msg-1") is False + + def test_second_same_message_is_duplicate(self): + from gateway.platforms.dingtalk import DingTalkAdapter + adapter = DingTalkAdapter(PlatformConfig(enabled=True)) + adapter._is_duplicate("msg-1") + assert adapter._is_duplicate("msg-1") is True + + def test_different_messages_not_duplicate(self): + from gateway.platforms.dingtalk import DingTalkAdapter + adapter = DingTalkAdapter(PlatformConfig(enabled=True)) + adapter._is_duplicate("msg-1") + assert adapter._is_duplicate("msg-2") is False + + def test_cache_cleanup_on_overflow(self): + from gateway.platforms.dingtalk import DingTalkAdapter, DEDUP_MAX_SIZE + adapter = DingTalkAdapter(PlatformConfig(enabled=True)) + # Fill beyond max + for i in range(DEDUP_MAX_SIZE + 10): + adapter._is_duplicate(f"msg-{i}") + # Cache should have been pruned + assert len(adapter._seen_messages) <= DEDUP_MAX_SIZE + 10 + + +# --------------------------------------------------------------------------- +# Send +# --------------------------------------------------------------------------- + + +class TestSend: + + @pytest.mark.asyncio + async def test_send_posts_to_webhook(self): + from gateway.platforms.dingtalk import DingTalkAdapter + adapter = DingTalkAdapter(PlatformConfig(enabled=True)) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = "OK" + + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + adapter._http_client = mock_client + + result = await adapter.send( + "chat-123", "Hello!", + metadata={"session_webhook": "https://dingtalk.example/webhook"} + ) + assert result.success is True + mock_client.post.assert_called_once() + call_args = mock_client.post.call_args + assert call_args[0][0] == "https://dingtalk.example/webhook" + payload = call_args[1]["json"] + assert payload["msgtype"] == "markdown" + assert payload["markdown"]["title"] == "Hermes" + assert payload["markdown"]["text"] == "Hello!" + + @pytest.mark.asyncio + async def test_send_fails_without_webhook(self): + from gateway.platforms.dingtalk import DingTalkAdapter + adapter = DingTalkAdapter(PlatformConfig(enabled=True)) + adapter._http_client = AsyncMock() + + result = await adapter.send("chat-123", "Hello!") + assert result.success is False + assert "session_webhook" in result.error + + @pytest.mark.asyncio + async def test_send_uses_cached_webhook(self): + from gateway.platforms.dingtalk import DingTalkAdapter + adapter = DingTalkAdapter(PlatformConfig(enabled=True)) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + adapter._http_client = mock_client + adapter._session_webhooks["chat-123"] = "https://cached.example/webhook" + + result = await adapter.send("chat-123", "Hello!") + assert result.success is True + assert mock_client.post.call_args[0][0] == "https://cached.example/webhook" + + @pytest.mark.asyncio + async def test_send_handles_http_error(self): + from gateway.platforms.dingtalk import DingTalkAdapter + adapter = DingTalkAdapter(PlatformConfig(enabled=True)) + + mock_response = MagicMock() + mock_response.status_code = 400 + mock_response.text = "Bad Request" + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + adapter._http_client = mock_client + + result = await adapter.send( + "chat-123", "Hello!", + metadata={"session_webhook": "https://example/webhook"} + ) + assert result.success is False + assert "400" in result.error + + +# --------------------------------------------------------------------------- +# Connect / disconnect +# --------------------------------------------------------------------------- + + +class TestConnect: + + @pytest.mark.asyncio + async def test_connect_fails_without_sdk(self, monkeypatch): + monkeypatch.setattr( + "gateway.platforms.dingtalk.DINGTALK_STREAM_AVAILABLE", False + ) + from gateway.platforms.dingtalk import DingTalkAdapter + adapter = DingTalkAdapter(PlatformConfig(enabled=True)) + result = await adapter.connect() + assert result is False + + @pytest.mark.asyncio + async def test_connect_fails_without_credentials(self): + from gateway.platforms.dingtalk import DingTalkAdapter + adapter = DingTalkAdapter(PlatformConfig(enabled=True)) + adapter._client_id = "" + adapter._client_secret = "" + result = await adapter.connect() + assert result is False + + @pytest.mark.asyncio + async def test_disconnect_cleans_up(self): + from gateway.platforms.dingtalk import DingTalkAdapter + adapter = DingTalkAdapter(PlatformConfig(enabled=True)) + adapter._session_webhooks["a"] = "http://x" + adapter._seen_messages["b"] = 1.0 + adapter._http_client = AsyncMock() + adapter._stream_task = None + + await adapter.disconnect() + assert len(adapter._session_webhooks) == 0 + assert len(adapter._seen_messages) == 0 + assert adapter._http_client is None + + +# --------------------------------------------------------------------------- +# Platform enum +# --------------------------------------------------------------------------- + + +class TestPlatformEnum: + + def test_dingtalk_in_platform_enum(self): + assert Platform.DINGTALK.value == "dingtalk" diff --git a/tests/gateway/test_matrix.py b/tests/gateway/test_matrix.py new file mode 100644 index 000000000..31e59caeb --- /dev/null +++ b/tests/gateway/test_matrix.py @@ -0,0 +1,448 @@ +"""Tests for Matrix platform adapter.""" +import json +import re +import pytest +from unittest.mock import MagicMock, patch, AsyncMock + +from gateway.config import Platform, PlatformConfig + + +# --------------------------------------------------------------------------- +# Platform & Config +# --------------------------------------------------------------------------- + +class TestMatrixPlatformEnum: + def test_matrix_enum_exists(self): + assert Platform.MATRIX.value == "matrix" + + def test_matrix_in_platform_list(self): + platforms = [p.value for p in Platform] + assert "matrix" in platforms + + +class TestMatrixConfigLoading: + def test_apply_env_overrides_with_access_token(self, monkeypatch): + monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123") + monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org") + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + assert Platform.MATRIX in config.platforms + mc = config.platforms[Platform.MATRIX] + assert mc.enabled is True + assert mc.token == "syt_abc123" + assert mc.extra.get("homeserver") == "https://matrix.example.org" + + def test_apply_env_overrides_with_password(self, monkeypatch): + monkeypatch.delenv("MATRIX_ACCESS_TOKEN", raising=False) + monkeypatch.setenv("MATRIX_PASSWORD", "secret123") + monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org") + monkeypatch.setenv("MATRIX_USER_ID", "@bot:example.org") + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + assert Platform.MATRIX in config.platforms + mc = config.platforms[Platform.MATRIX] + assert mc.enabled is True + assert mc.extra.get("password") == "secret123" + assert mc.extra.get("user_id") == "@bot:example.org" + + def test_matrix_not_loaded_without_creds(self, monkeypatch): + monkeypatch.delenv("MATRIX_ACCESS_TOKEN", raising=False) + monkeypatch.delenv("MATRIX_PASSWORD", raising=False) + monkeypatch.delenv("MATRIX_HOMESERVER", raising=False) + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + assert Platform.MATRIX not in config.platforms + + def test_matrix_encryption_flag(self, monkeypatch): + monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123") + monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org") + monkeypatch.setenv("MATRIX_ENCRYPTION", "true") + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + mc = config.platforms[Platform.MATRIX] + assert mc.extra.get("encryption") is True + + def test_matrix_encryption_default_off(self, monkeypatch): + monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123") + monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org") + monkeypatch.delenv("MATRIX_ENCRYPTION", raising=False) + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + mc = config.platforms[Platform.MATRIX] + assert mc.extra.get("encryption") is False + + def test_matrix_home_room(self, monkeypatch): + monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123") + monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org") + monkeypatch.setenv("MATRIX_HOME_ROOM", "!room123:example.org") + monkeypatch.setenv("MATRIX_HOME_ROOM_NAME", "Bot Room") + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + home = config.get_home_channel(Platform.MATRIX) + assert home is not None + assert home.chat_id == "!room123:example.org" + assert home.name == "Bot Room" + + def test_matrix_user_id_stored_in_extra(self, monkeypatch): + monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123") + monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org") + monkeypatch.setenv("MATRIX_USER_ID", "@hermes:example.org") + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + mc = config.platforms[Platform.MATRIX] + assert mc.extra.get("user_id") == "@hermes:example.org" + + +# --------------------------------------------------------------------------- +# Adapter helpers +# --------------------------------------------------------------------------- + +def _make_adapter(): + """Create a MatrixAdapter with mocked config.""" + from gateway.platforms.matrix import MatrixAdapter + config = PlatformConfig( + enabled=True, + token="syt_test_token", + extra={ + "homeserver": "https://matrix.example.org", + "user_id": "@bot:example.org", + }, + ) + adapter = MatrixAdapter(config) + return adapter + + +# --------------------------------------------------------------------------- +# mxc:// URL conversion +# --------------------------------------------------------------------------- + +class TestMatrixMxcToHttp: + def setup_method(self): + self.adapter = _make_adapter() + + def test_basic_mxc_conversion(self): + """mxc://server/media_id should become an authenticated HTTP URL.""" + mxc = "mxc://matrix.org/abc123" + result = self.adapter._mxc_to_http(mxc) + assert result == "https://matrix.example.org/_matrix/client/v1/media/download/matrix.org/abc123" + + def test_mxc_with_different_server(self): + """mxc:// from a different server should still use our homeserver.""" + mxc = "mxc://other.server/media456" + result = self.adapter._mxc_to_http(mxc) + assert result.startswith("https://matrix.example.org/") + assert "other.server/media456" in result + + def test_non_mxc_url_passthrough(self): + """Non-mxc URLs should be returned unchanged.""" + url = "https://example.com/image.png" + assert self.adapter._mxc_to_http(url) == url + + def test_mxc_uses_client_v1_endpoint(self): + """Should use /_matrix/client/v1/media/download/ not the deprecated path.""" + mxc = "mxc://example.com/test123" + result = self.adapter._mxc_to_http(mxc) + assert "/_matrix/client/v1/media/download/" in result + assert "/_matrix/media/v3/download/" not in result + + +# --------------------------------------------------------------------------- +# DM detection +# --------------------------------------------------------------------------- + +class TestMatrixDmDetection: + def setup_method(self): + self.adapter = _make_adapter() + + def test_room_in_m_direct_is_dm(self): + """A room listed in m.direct should be detected as DM.""" + self.adapter._joined_rooms = {"!dm_room:ex.org", "!group_room:ex.org"} + self.adapter._dm_rooms = { + "!dm_room:ex.org": True, + "!group_room:ex.org": False, + } + + assert self.adapter._dm_rooms.get("!dm_room:ex.org") is True + assert self.adapter._dm_rooms.get("!group_room:ex.org") is False + + def test_unknown_room_not_in_cache(self): + """Unknown rooms should not be in the DM cache.""" + self.adapter._dm_rooms = {} + assert self.adapter._dm_rooms.get("!unknown:ex.org") is None + + @pytest.mark.asyncio + async def test_refresh_dm_cache_with_m_direct(self): + """_refresh_dm_cache should populate _dm_rooms from m.direct data.""" + self.adapter._joined_rooms = {"!room_a:ex.org", "!room_b:ex.org", "!room_c:ex.org"} + + mock_client = MagicMock() + mock_resp = MagicMock() + mock_resp.content = { + "@alice:ex.org": ["!room_a:ex.org"], + "@bob:ex.org": ["!room_b:ex.org"], + } + mock_client.get_account_data = AsyncMock(return_value=mock_resp) + self.adapter._client = mock_client + + await self.adapter._refresh_dm_cache() + + assert self.adapter._dm_rooms["!room_a:ex.org"] is True + assert self.adapter._dm_rooms["!room_b:ex.org"] is True + assert self.adapter._dm_rooms["!room_c:ex.org"] is False + + +# --------------------------------------------------------------------------- +# Reply fallback stripping +# --------------------------------------------------------------------------- + +class TestMatrixReplyFallbackStripping: + """Test that Matrix reply fallback lines ('> ' prefix) are stripped.""" + + def setup_method(self): + self.adapter = _make_adapter() + self.adapter._user_id = "@bot:example.org" + self.adapter._startup_ts = 0.0 + self.adapter._dm_rooms = {} + self.adapter._message_handler = AsyncMock() + + def _strip_fallback(self, body: str, has_reply: bool = True) -> str: + """Simulate the reply fallback stripping logic from _on_room_message.""" + reply_to = "some_event_id" if has_reply else None + if reply_to and body.startswith("> "): + lines = body.split("\n") + stripped = [] + past_fallback = False + for line in lines: + if not past_fallback: + if line.startswith("> ") or line == ">": + continue + if line == "": + past_fallback = True + continue + past_fallback = True + stripped.append(line) + body = "\n".join(stripped) if stripped else body + return body + + def test_simple_reply_fallback(self): + body = "> <@alice:ex.org> Original message\n\nActual reply" + result = self._strip_fallback(body) + assert result == "Actual reply" + + def test_multiline_reply_fallback(self): + body = "> <@alice:ex.org> Line 1\n> Line 2\n\nMy response" + result = self._strip_fallback(body) + assert result == "My response" + + def test_no_reply_fallback_preserved(self): + body = "Just a normal message" + result = self._strip_fallback(body, has_reply=False) + assert result == "Just a normal message" + + def test_quote_without_reply_preserved(self): + """'> ' lines without a reply_to context should be preserved.""" + body = "> This is a blockquote" + result = self._strip_fallback(body, has_reply=False) + assert result == "> This is a blockquote" + + def test_empty_fallback_separator(self): + """The blank line between fallback and actual content should be stripped.""" + body = "> <@alice:ex.org> hi\n>\n\nResponse" + result = self._strip_fallback(body) + assert result == "Response" + + def test_multiline_response_after_fallback(self): + body = "> <@alice:ex.org> Original\n\nLine 1\nLine 2\nLine 3" + result = self._strip_fallback(body) + assert result == "Line 1\nLine 2\nLine 3" + + +# --------------------------------------------------------------------------- +# Thread detection +# --------------------------------------------------------------------------- + +class TestMatrixThreadDetection: + def test_thread_id_from_m_relates_to(self): + """m.relates_to with rel_type=m.thread should extract the event_id.""" + relates_to = { + "rel_type": "m.thread", + "event_id": "$thread_root_event", + "is_falling_back": True, + "m.in_reply_to": {"event_id": "$some_event"}, + } + # Simulate the extraction logic from _on_room_message + thread_id = None + if relates_to.get("rel_type") == "m.thread": + thread_id = relates_to.get("event_id") + assert thread_id == "$thread_root_event" + + def test_no_thread_for_reply(self): + """m.in_reply_to without m.thread should not set thread_id.""" + relates_to = { + "m.in_reply_to": {"event_id": "$reply_event"}, + } + thread_id = None + if relates_to.get("rel_type") == "m.thread": + thread_id = relates_to.get("event_id") + assert thread_id is None + + def test_no_thread_for_edit(self): + """m.replace relation should not set thread_id.""" + relates_to = { + "rel_type": "m.replace", + "event_id": "$edited_event", + } + thread_id = None + if relates_to.get("rel_type") == "m.thread": + thread_id = relates_to.get("event_id") + assert thread_id is None + + def test_empty_relates_to(self): + """Empty m.relates_to should not set thread_id.""" + relates_to = {} + thread_id = None + if relates_to.get("rel_type") == "m.thread": + thread_id = relates_to.get("event_id") + assert thread_id is None + + +# --------------------------------------------------------------------------- +# Format message +# --------------------------------------------------------------------------- + +class TestMatrixFormatMessage: + def setup_method(self): + self.adapter = _make_adapter() + + def test_image_markdown_stripped(self): + """![alt](url) should be converted to just the URL.""" + result = self.adapter.format_message("![cat](https://img.example.com/cat.png)") + assert result == "https://img.example.com/cat.png" + + def test_regular_markdown_preserved(self): + """Standard markdown should be preserved (Matrix supports it).""" + content = "**bold** and *italic* and `code`" + assert self.adapter.format_message(content) == content + + def test_plain_text_unchanged(self): + content = "Hello, world!" + assert self.adapter.format_message(content) == content + + def test_multiple_images_stripped(self): + content = "![a](http://a.com/1.png) and ![b](http://b.com/2.png)" + result = self.adapter.format_message(content) + assert "![" not in result + assert "http://a.com/1.png" in result + assert "http://b.com/2.png" in result + + +# --------------------------------------------------------------------------- +# Markdown to HTML conversion +# --------------------------------------------------------------------------- + +class TestMatrixMarkdownToHtml: + def setup_method(self): + self.adapter = _make_adapter() + + def test_bold_conversion(self): + """**bold** should produce tags.""" + result = self.adapter._markdown_to_html("**bold**") + assert "" in result or "" in result + assert "bold" in result + + def test_italic_conversion(self): + """*italic* should produce tags.""" + result = self.adapter._markdown_to_html("*italic*") + assert "" in result or "" in result + + def test_inline_code(self): + """`code` should produce tags.""" + result = self.adapter._markdown_to_html("`code`") + assert "" in result + + def test_plain_text_returns_html(self): + """Plain text should still be returned (possibly with
or

).""" + result = self.adapter._markdown_to_html("Hello world") + assert "Hello world" in result + + +# --------------------------------------------------------------------------- +# Helper: display name extraction +# --------------------------------------------------------------------------- + +class TestMatrixDisplayName: + def setup_method(self): + self.adapter = _make_adapter() + + def test_get_display_name_from_room_users(self): + """Should get display name from room's users dict.""" + mock_room = MagicMock() + mock_user = MagicMock() + mock_user.display_name = "Alice" + mock_room.users = {"@alice:ex.org": mock_user} + + name = self.adapter._get_display_name(mock_room, "@alice:ex.org") + assert name == "Alice" + + def test_get_display_name_fallback_to_localpart(self): + """Should extract localpart from @user:server format.""" + mock_room = MagicMock() + mock_room.users = {} + + name = self.adapter._get_display_name(mock_room, "@bob:example.org") + assert name == "bob" + + def test_get_display_name_no_room(self): + """Should handle None room gracefully.""" + name = self.adapter._get_display_name(None, "@charlie:ex.org") + assert name == "charlie" + + +# --------------------------------------------------------------------------- +# Requirements check +# --------------------------------------------------------------------------- + +class TestMatrixRequirements: + def test_check_requirements_with_token(self, monkeypatch): + monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_test") + monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org") + from gateway.platforms.matrix import check_matrix_requirements + try: + import nio # noqa: F401 + assert check_matrix_requirements() is True + except ImportError: + assert check_matrix_requirements() is False + + def test_check_requirements_without_creds(self, monkeypatch): + monkeypatch.delenv("MATRIX_ACCESS_TOKEN", raising=False) + monkeypatch.delenv("MATRIX_PASSWORD", raising=False) + monkeypatch.delenv("MATRIX_HOMESERVER", raising=False) + from gateway.platforms.matrix import check_matrix_requirements + assert check_matrix_requirements() is False + + def test_check_requirements_without_homeserver(self, monkeypatch): + monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_test") + monkeypatch.delenv("MATRIX_HOMESERVER", raising=False) + from gateway.platforms.matrix import check_matrix_requirements + assert check_matrix_requirements() is False diff --git a/tests/gateway/test_mattermost.py b/tests/gateway/test_mattermost.py new file mode 100644 index 000000000..6b0fbd899 --- /dev/null +++ b/tests/gateway/test_mattermost.py @@ -0,0 +1,574 @@ +"""Tests for Mattermost platform adapter.""" +import json +import time +import pytest +from unittest.mock import MagicMock, patch, AsyncMock + +from gateway.config import Platform, PlatformConfig + + +# --------------------------------------------------------------------------- +# Platform & Config +# --------------------------------------------------------------------------- + +class TestMattermostPlatformEnum: + def test_mattermost_enum_exists(self): + assert Platform.MATTERMOST.value == "mattermost" + + def test_mattermost_in_platform_list(self): + platforms = [p.value for p in Platform] + assert "mattermost" in platforms + + +class TestMattermostConfigLoading: + def test_apply_env_overrides_mattermost(self, monkeypatch): + monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123") + monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com") + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + assert Platform.MATTERMOST in config.platforms + mc = config.platforms[Platform.MATTERMOST] + assert mc.enabled is True + assert mc.token == "mm-tok-abc123" + assert mc.extra.get("url") == "https://mm.example.com" + + def test_mattermost_not_loaded_without_token(self, monkeypatch): + monkeypatch.delenv("MATTERMOST_TOKEN", raising=False) + monkeypatch.delenv("MATTERMOST_URL", raising=False) + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + assert Platform.MATTERMOST not in config.platforms + + def test_connected_platforms_includes_mattermost(self, monkeypatch): + monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123") + monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com") + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + connected = config.get_connected_platforms() + assert Platform.MATTERMOST in connected + + def test_mattermost_home_channel(self, monkeypatch): + monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123") + monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com") + monkeypatch.setenv("MATTERMOST_HOME_CHANNEL", "ch_abc123") + monkeypatch.setenv("MATTERMOST_HOME_CHANNEL_NAME", "General") + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + home = config.get_home_channel(Platform.MATTERMOST) + assert home is not None + assert home.chat_id == "ch_abc123" + assert home.name == "General" + + def test_mattermost_url_warning_without_url(self, monkeypatch): + """MATTERMOST_TOKEN set but MATTERMOST_URL missing should still load.""" + monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123") + monkeypatch.delenv("MATTERMOST_URL", raising=False) + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + assert Platform.MATTERMOST in config.platforms + assert config.platforms[Platform.MATTERMOST].extra.get("url") == "" + + +# --------------------------------------------------------------------------- +# Adapter format / truncate +# --------------------------------------------------------------------------- + +def _make_adapter(): + """Create a MattermostAdapter with mocked config.""" + from gateway.platforms.mattermost import MattermostAdapter + config = PlatformConfig( + enabled=True, + token="test-token", + extra={"url": "https://mm.example.com"}, + ) + adapter = MattermostAdapter(config) + return adapter + + +class TestMattermostFormatMessage: + def setup_method(self): + self.adapter = _make_adapter() + + def test_image_markdown_to_url(self): + """![alt](url) should be converted to just the URL.""" + result = self.adapter.format_message("![cat](https://img.example.com/cat.png)") + assert result == "https://img.example.com/cat.png" + + def test_image_markdown_strips_alt_text(self): + result = self.adapter.format_message("Here: ![my image](https://x.com/a.jpg) done") + assert "![" not in result + assert "https://x.com/a.jpg" in result + + def test_regular_markdown_preserved(self): + """Regular markdown (bold, italic, code) should be kept as-is.""" + content = "**bold** and *italic* and `code`" + assert self.adapter.format_message(content) == content + + def test_regular_links_preserved(self): + """Non-image links should be preserved.""" + content = "[click](https://example.com)" + assert self.adapter.format_message(content) == content + + def test_plain_text_unchanged(self): + content = "Hello, world!" + assert self.adapter.format_message(content) == content + + def test_multiple_images(self): + content = "![a](http://a.com/1.png) text ![b](http://b.com/2.png)" + result = self.adapter.format_message(content) + assert "![" not in result + assert "http://a.com/1.png" in result + assert "http://b.com/2.png" in result + + +class TestMattermostTruncateMessage: + def setup_method(self): + self.adapter = _make_adapter() + + def test_short_message_single_chunk(self): + msg = "Hello, world!" + chunks = self.adapter.truncate_message(msg, 4000) + assert len(chunks) == 1 + assert chunks[0] == msg + + def test_long_message_splits(self): + msg = "a " * 2500 # 5000 chars + chunks = self.adapter.truncate_message(msg, 4000) + assert len(chunks) >= 2 + for chunk in chunks: + assert len(chunk) <= 4000 + + def test_custom_max_length(self): + msg = "Hello " * 20 + chunks = self.adapter.truncate_message(msg, max_length=50) + assert all(len(c) <= 50 for c in chunks) + + def test_exactly_at_limit(self): + msg = "x" * 4000 + chunks = self.adapter.truncate_message(msg, 4000) + assert len(chunks) == 1 + + +# --------------------------------------------------------------------------- +# Send +# --------------------------------------------------------------------------- + +class TestMattermostSend: + def setup_method(self): + self.adapter = _make_adapter() + self.adapter._session = MagicMock() + + @pytest.mark.asyncio + async def test_send_calls_api_post(self): + """send() should POST to /api/v4/posts with channel_id and message.""" + mock_resp = AsyncMock() + mock_resp.status = 200 + mock_resp.json = AsyncMock(return_value={"id": "post123"}) + mock_resp.text = AsyncMock(return_value="") + mock_resp.__aenter__ = AsyncMock(return_value=mock_resp) + mock_resp.__aexit__ = AsyncMock(return_value=False) + + self.adapter._session.post = MagicMock(return_value=mock_resp) + + result = await self.adapter.send("channel_1", "Hello!") + + assert result.success is True + assert result.message_id == "post123" + + # Verify post was called with correct URL + call_args = self.adapter._session.post.call_args + assert "/api/v4/posts" in call_args[0][0] + # Verify payload + payload = call_args[1]["json"] + assert payload["channel_id"] == "channel_1" + assert payload["message"] == "Hello!" + + @pytest.mark.asyncio + async def test_send_empty_content_succeeds(self): + """Empty content should return success without calling the API.""" + result = await self.adapter.send("channel_1", "") + assert result.success is True + + @pytest.mark.asyncio + async def test_send_with_thread_reply(self): + """When reply_mode is 'thread', reply_to should become root_id.""" + self.adapter._reply_mode = "thread" + + mock_resp = AsyncMock() + mock_resp.status = 200 + mock_resp.json = AsyncMock(return_value={"id": "post456"}) + mock_resp.text = AsyncMock(return_value="") + mock_resp.__aenter__ = AsyncMock(return_value=mock_resp) + mock_resp.__aexit__ = AsyncMock(return_value=False) + + self.adapter._session.post = MagicMock(return_value=mock_resp) + + result = await self.adapter.send("channel_1", "Reply!", reply_to="root_post") + + assert result.success is True + payload = self.adapter._session.post.call_args[1]["json"] + assert payload["root_id"] == "root_post" + + @pytest.mark.asyncio + async def test_send_without_thread_no_root_id(self): + """When reply_mode is 'off', reply_to should NOT set root_id.""" + self.adapter._reply_mode = "off" + + mock_resp = AsyncMock() + mock_resp.status = 200 + mock_resp.json = AsyncMock(return_value={"id": "post789"}) + mock_resp.text = AsyncMock(return_value="") + mock_resp.__aenter__ = AsyncMock(return_value=mock_resp) + mock_resp.__aexit__ = AsyncMock(return_value=False) + + self.adapter._session.post = MagicMock(return_value=mock_resp) + + result = await self.adapter.send("channel_1", "Reply!", reply_to="root_post") + + assert result.success is True + payload = self.adapter._session.post.call_args[1]["json"] + assert "root_id" not in payload + + @pytest.mark.asyncio + async def test_send_api_failure(self): + """When API returns error, send should return failure.""" + mock_resp = AsyncMock() + mock_resp.status = 500 + mock_resp.json = AsyncMock(return_value={}) + mock_resp.text = AsyncMock(return_value="Internal Server Error") + mock_resp.__aenter__ = AsyncMock(return_value=mock_resp) + mock_resp.__aexit__ = AsyncMock(return_value=False) + + self.adapter._session.post = MagicMock(return_value=mock_resp) + + result = await self.adapter.send("channel_1", "Hello!") + + assert result.success is False + + +# --------------------------------------------------------------------------- +# WebSocket event parsing +# --------------------------------------------------------------------------- + +class TestMattermostWebSocketParsing: + def setup_method(self): + self.adapter = _make_adapter() + self.adapter._bot_user_id = "bot_user_id" + # Mock handle_message to capture the MessageEvent without processing + self.adapter.handle_message = AsyncMock() + + @pytest.mark.asyncio + async def test_parse_posted_event(self): + """'posted' events should extract message from double-encoded post JSON.""" + post_data = { + "id": "post_abc", + "user_id": "user_123", + "channel_id": "chan_456", + "message": "Hello from Matrix!", + } + event = { + "event": "posted", + "data": { + "post": json.dumps(post_data), # double-encoded JSON string + "channel_type": "O", + "sender_name": "@alice", + }, + } + + await self.adapter._handle_ws_event(event) + assert self.adapter.handle_message.called + msg_event = self.adapter.handle_message.call_args[0][0] + assert msg_event.text == "Hello from Matrix!" + assert msg_event.message_id == "post_abc" + + @pytest.mark.asyncio + async def test_ignore_own_messages(self): + """Messages from the bot's own user_id should be ignored.""" + post_data = { + "id": "post_self", + "user_id": "bot_user_id", # same as bot + "channel_id": "chan_456", + "message": "Bot echo", + } + event = { + "event": "posted", + "data": { + "post": json.dumps(post_data), + "channel_type": "O", + }, + } + + await self.adapter._handle_ws_event(event) + assert not self.adapter.handle_message.called + + @pytest.mark.asyncio + async def test_ignore_non_posted_events(self): + """Non-'posted' events should be ignored.""" + event = { + "event": "typing", + "data": {"user_id": "user_123"}, + } + + await self.adapter._handle_ws_event(event) + assert not self.adapter.handle_message.called + + @pytest.mark.asyncio + async def test_ignore_system_posts(self): + """Posts with a 'type' field (system messages) should be ignored.""" + post_data = { + "id": "sys_post", + "user_id": "user_123", + "channel_id": "chan_456", + "message": "user joined", + "type": "system_join_channel", + } + event = { + "event": "posted", + "data": { + "post": json.dumps(post_data), + "channel_type": "O", + }, + } + + await self.adapter._handle_ws_event(event) + assert not self.adapter.handle_message.called + + @pytest.mark.asyncio + async def test_channel_type_mapping(self): + """channel_type 'D' should map to 'dm'.""" + post_data = { + "id": "post_dm", + "user_id": "user_123", + "channel_id": "chan_dm", + "message": "DM message", + } + event = { + "event": "posted", + "data": { + "post": json.dumps(post_data), + "channel_type": "D", + "sender_name": "@bob", + }, + } + + await self.adapter._handle_ws_event(event) + assert self.adapter.handle_message.called + msg_event = self.adapter.handle_message.call_args[0][0] + assert msg_event.source.chat_type == "dm" + + @pytest.mark.asyncio + async def test_thread_id_from_root_id(self): + """Post with root_id should have thread_id set.""" + post_data = { + "id": "post_reply", + "user_id": "user_123", + "channel_id": "chan_456", + "message": "Thread reply", + "root_id": "root_post_123", + } + event = { + "event": "posted", + "data": { + "post": json.dumps(post_data), + "channel_type": "O", + "sender_name": "@alice", + }, + } + + await self.adapter._handle_ws_event(event) + assert self.adapter.handle_message.called + msg_event = self.adapter.handle_message.call_args[0][0] + assert msg_event.source.thread_id == "root_post_123" + + @pytest.mark.asyncio + async def test_invalid_post_json_ignored(self): + """Invalid JSON in data.post should be silently ignored.""" + event = { + "event": "posted", + "data": { + "post": "not-valid-json{{{", + "channel_type": "O", + }, + } + + await self.adapter._handle_ws_event(event) + assert not self.adapter.handle_message.called + + +# --------------------------------------------------------------------------- +# File upload (send_image) +# --------------------------------------------------------------------------- + +class TestMattermostFileUpload: + def setup_method(self): + self.adapter = _make_adapter() + self.adapter._session = MagicMock() + + @pytest.mark.asyncio + async def test_send_image_downloads_and_uploads(self): + """send_image should download the URL, upload via /api/v4/files, then post.""" + # Mock the download (GET) + mock_dl_resp = AsyncMock() + mock_dl_resp.status = 200 + mock_dl_resp.read = AsyncMock(return_value=b"\x89PNG\x00fake-image-data") + mock_dl_resp.content_type = "image/png" + mock_dl_resp.__aenter__ = AsyncMock(return_value=mock_dl_resp) + mock_dl_resp.__aexit__ = AsyncMock(return_value=False) + + # Mock the upload (POST to /files) + mock_upload_resp = AsyncMock() + mock_upload_resp.status = 200 + mock_upload_resp.json = AsyncMock(return_value={ + "file_infos": [{"id": "file_abc123"}] + }) + mock_upload_resp.text = AsyncMock(return_value="") + mock_upload_resp.__aenter__ = AsyncMock(return_value=mock_upload_resp) + mock_upload_resp.__aexit__ = AsyncMock(return_value=False) + + # Mock the post (POST to /posts) + mock_post_resp = AsyncMock() + mock_post_resp.status = 200 + mock_post_resp.json = AsyncMock(return_value={"id": "post_with_file"}) + mock_post_resp.text = AsyncMock(return_value="") + mock_post_resp.__aenter__ = AsyncMock(return_value=mock_post_resp) + mock_post_resp.__aexit__ = AsyncMock(return_value=False) + + # Route calls: first GET (download), then POST (upload), then POST (create post) + self.adapter._session.get = MagicMock(return_value=mock_dl_resp) + post_call_count = 0 + original_post_returns = [mock_upload_resp, mock_post_resp] + + def post_side_effect(*args, **kwargs): + nonlocal post_call_count + resp = original_post_returns[min(post_call_count, len(original_post_returns) - 1)] + post_call_count += 1 + return resp + + self.adapter._session.post = MagicMock(side_effect=post_side_effect) + + result = await self.adapter.send_image( + "channel_1", "https://img.example.com/cat.png", caption="A cat" + ) + + assert result.success is True + assert result.message_id == "post_with_file" + + +# --------------------------------------------------------------------------- +# Dedup cache +# --------------------------------------------------------------------------- + +class TestMattermostDedup: + def setup_method(self): + self.adapter = _make_adapter() + self.adapter._bot_user_id = "bot_user_id" + # Mock handle_message to capture calls without processing + self.adapter.handle_message = AsyncMock() + + @pytest.mark.asyncio + async def test_duplicate_post_ignored(self): + """The same post_id within the TTL window should be ignored.""" + post_data = { + "id": "post_dup", + "user_id": "user_123", + "channel_id": "chan_456", + "message": "Hello!", + } + event = { + "event": "posted", + "data": { + "post": json.dumps(post_data), + "channel_type": "O", + "sender_name": "@alice", + }, + } + + # First time: should process + await self.adapter._handle_ws_event(event) + assert self.adapter.handle_message.call_count == 1 + + # Second time (same post_id): should be deduped + await self.adapter._handle_ws_event(event) + assert self.adapter.handle_message.call_count == 1 # still 1 + + @pytest.mark.asyncio + async def test_different_post_ids_both_processed(self): + """Different post IDs should both be processed.""" + for i, pid in enumerate(["post_a", "post_b"]): + post_data = { + "id": pid, + "user_id": "user_123", + "channel_id": "chan_456", + "message": f"Message {i}", + } + event = { + "event": "posted", + "data": { + "post": json.dumps(post_data), + "channel_type": "O", + "sender_name": "@alice", + }, + } + await self.adapter._handle_ws_event(event) + + assert self.adapter.handle_message.call_count == 2 + + def test_prune_seen_clears_expired(self): + """_prune_seen should remove entries older than _SEEN_TTL.""" + now = time.time() + # Fill with enough expired entries to trigger pruning + for i in range(self.adapter._SEEN_MAX + 10): + self.adapter._seen_posts[f"old_{i}"] = now - 600 # 10 min ago + + # Add a fresh one + self.adapter._seen_posts["fresh"] = now + + self.adapter._prune_seen() + + # Old entries should be pruned, fresh one kept + assert "fresh" in self.adapter._seen_posts + assert len(self.adapter._seen_posts) < self.adapter._SEEN_MAX + + def test_seen_cache_tracks_post_ids(self): + """Posts are tracked in _seen_posts dict.""" + self.adapter._seen_posts["test_post"] = time.time() + assert "test_post" in self.adapter._seen_posts + + +# --------------------------------------------------------------------------- +# Requirements check +# --------------------------------------------------------------------------- + +class TestMattermostRequirements: + def test_check_requirements_with_token_and_url(self, monkeypatch): + monkeypatch.setenv("MATTERMOST_TOKEN", "test-token") + monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com") + from gateway.platforms.mattermost import check_mattermost_requirements + assert check_mattermost_requirements() is True + + def test_check_requirements_without_token(self, monkeypatch): + monkeypatch.delenv("MATTERMOST_TOKEN", raising=False) + monkeypatch.delenv("MATTERMOST_URL", raising=False) + from gateway.platforms.mattermost import check_mattermost_requirements + assert check_mattermost_requirements() is False + + def test_check_requirements_without_url(self, monkeypatch): + monkeypatch.setenv("MATTERMOST_TOKEN", "test-token") + monkeypatch.delenv("MATTERMOST_URL", raising=False) + from gateway.platforms.mattermost import check_mattermost_requirements + assert check_mattermost_requirements() is False diff --git a/tests/gateway/test_session.py b/tests/gateway/test_session.py index e29a9583d..afe436870 100644 --- a/tests/gateway/test_session.py +++ b/tests/gateway/test_session.py @@ -703,5 +703,15 @@ class TestLastPromptTokens: store.update_session("k1", model="openai/gpt-5.4") store._db.update_token_counts.assert_called_once_with( - "s1", 0, 0, model="openai/gpt-5.4" + "s1", + input_tokens=0, + output_tokens=0, + cache_read_tokens=0, + cache_write_tokens=0, + estimated_cost_usd=None, + cost_status=None, + cost_source=None, + billing_provider=None, + billing_base_url=None, + model="openai/gpt-5.4", ) diff --git a/tests/gateway/test_sms.py b/tests/gateway/test_sms.py index e3d927bb3..54c1edf23 100644 --- a/tests/gateway/test_sms.py +++ b/tests/gateway/test_sms.py @@ -1,240 +1,215 @@ -"""Tests for SMS (Telnyx) platform adapter.""" -import json +"""Tests for SMS (Twilio) platform integration. + +Covers config loading, format/truncate, echo prevention, +requirements check, and toolset verification. +""" + +import os +from unittest.mock import patch + import pytest -from unittest.mock import MagicMock, patch, AsyncMock -from gateway.config import Platform, PlatformConfig +from gateway.config import Platform, PlatformConfig, HomeChannel -# --------------------------------------------------------------------------- -# Platform & Config -# --------------------------------------------------------------------------- - -class TestSmsPlatformEnum: - def test_sms_enum_exists(self): - assert Platform.SMS.value == "sms" - - def test_sms_in_platform_list(self): - platforms = [p.value for p in Platform] - assert "sms" in platforms - +# ── Config loading ────────────────────────────────────────────────── class TestSmsConfigLoading: - def test_apply_env_overrides_sms(self, monkeypatch): - monkeypatch.setenv("TELNYX_API_KEY", "KEY_test123") + """Verify _apply_env_overrides wires SMS correctly.""" - from gateway.config import GatewayConfig, _apply_env_overrides - config = GatewayConfig() - _apply_env_overrides(config) + def test_sms_platform_enum_exists(self): + assert Platform.SMS.value == "sms" - assert Platform.SMS in config.platforms - sc = config.platforms[Platform.SMS] - assert sc.enabled is True - assert sc.api_key == "KEY_test123" + def test_env_overrides_create_sms_config(self): + from gateway.config import load_gateway_config - def test_sms_not_loaded_without_key(self, monkeypatch): - monkeypatch.delenv("TELNYX_API_KEY", raising=False) + env = { + "TWILIO_ACCOUNT_SID": "ACtest123", + "TWILIO_AUTH_TOKEN": "token_abc", + "TWILIO_PHONE_NUMBER": "+15551234567", + } + with patch.dict(os.environ, env, clear=False): + config = load_gateway_config() + assert Platform.SMS in config.platforms + pc = config.platforms[Platform.SMS] + assert pc.enabled is True + assert pc.api_key == "token_abc" - from gateway.config import GatewayConfig, _apply_env_overrides - config = GatewayConfig() - _apply_env_overrides(config) + def test_env_overrides_set_home_channel(self): + from gateway.config import load_gateway_config - assert Platform.SMS not in config.platforms + env = { + "TWILIO_ACCOUNT_SID": "ACtest123", + "TWILIO_AUTH_TOKEN": "token_abc", + "TWILIO_PHONE_NUMBER": "+15551234567", + "SMS_HOME_CHANNEL": "+15559876543", + "SMS_HOME_CHANNEL_NAME": "My Phone", + } + with patch.dict(os.environ, env, clear=False): + config = load_gateway_config() + hc = config.platforms[Platform.SMS].home_channel + assert hc is not None + assert hc.chat_id == "+15559876543" + assert hc.name == "My Phone" + assert hc.platform == Platform.SMS - def test_connected_platforms_includes_sms(self, monkeypatch): - monkeypatch.setenv("TELNYX_API_KEY", "KEY_test123") + def test_sms_in_connected_platforms(self): + from gateway.config import load_gateway_config - from gateway.config import GatewayConfig, _apply_env_overrides - config = GatewayConfig() - _apply_env_overrides(config) - - connected = config.get_connected_platforms() - assert Platform.SMS in connected - - def test_sms_home_channel(self, monkeypatch): - monkeypatch.setenv("TELNYX_API_KEY", "KEY_test123") - monkeypatch.setenv("SMS_HOME_CHANNEL", "+15559876543") - monkeypatch.setenv("SMS_HOME_CHANNEL_NAME", "Owner") - - from gateway.config import GatewayConfig, _apply_env_overrides - config = GatewayConfig() - _apply_env_overrides(config) - - home = config.get_home_channel(Platform.SMS) - assert home is not None - assert home.chat_id == "+15559876543" - assert home.name == "Owner" + env = { + "TWILIO_ACCOUNT_SID": "ACtest123", + "TWILIO_AUTH_TOKEN": "token_abc", + } + with patch.dict(os.environ, env, clear=False): + config = load_gateway_config() + connected = config.get_connected_platforms() + assert Platform.SMS in connected -# --------------------------------------------------------------------------- -# Adapter format / truncate -# --------------------------------------------------------------------------- +# ── Format / truncate ─────────────────────────────────────────────── -class TestSmsFormatMessage: - def setup_method(self): +class TestSmsFormatAndTruncate: + """Test SmsAdapter.format_message strips markdown.""" + + def _make_adapter(self): from gateway.platforms.sms import SmsAdapter - config = PlatformConfig(enabled=True, api_key="test_key") - with patch.dict("os.environ", {"TELNYX_API_KEY": "test_key"}): - self.adapter = SmsAdapter(config) - def test_strip_bold(self): - assert self.adapter.format_message("**bold**") == "bold" + env = { + "TWILIO_ACCOUNT_SID": "ACtest", + "TWILIO_AUTH_TOKEN": "tok", + "TWILIO_PHONE_NUMBER": "+15550001111", + } + with patch.dict(os.environ, env): + pc = PlatformConfig(enabled=True, api_key="tok") + adapter = object.__new__(SmsAdapter) + adapter.config = pc + adapter._platform = Platform.SMS + adapter._account_sid = "ACtest" + adapter._auth_token = "tok" + adapter._from_number = "+15550001111" + return adapter - def test_strip_italic(self): - assert self.adapter.format_message("*italic*") == "italic" + def test_strips_bold(self): + adapter = self._make_adapter() + assert adapter.format_message("**hello**") == "hello" - def test_strip_code_block(self): - result = self.adapter.format_message("```python\ncode\n```") + def test_strips_italic(self): + adapter = self._make_adapter() + assert adapter.format_message("*world*") == "world" + + def test_strips_code_blocks(self): + adapter = self._make_adapter() + result = adapter.format_message("```python\nprint('hi')\n```") assert "```" not in result - assert "code" in result + assert "print('hi')" in result - def test_strip_inline_code(self): - assert self.adapter.format_message("`code`") == "code" + def test_strips_inline_code(self): + adapter = self._make_adapter() + assert adapter.format_message("`code`") == "code" - def test_strip_headers(self): - assert self.adapter.format_message("## Header") == "Header" + def test_strips_headers(self): + adapter = self._make_adapter() + assert adapter.format_message("## Title") == "Title" - def test_strip_links(self): - assert self.adapter.format_message("[click](http://example.com)") == "click" + def test_strips_links(self): + adapter = self._make_adapter() + assert adapter.format_message("[click](https://example.com)") == "click" - def test_collapse_newlines(self): - result = self.adapter.format_message("a\n\n\n\nb") + def test_collapses_newlines(self): + adapter = self._make_adapter() + result = adapter.format_message("a\n\n\n\nb") assert result == "a\n\nb" -class TestSmsTruncateMessage: - def setup_method(self): +# ── Echo prevention ──────────────────────────────────────────────── + +class TestSmsEchoPrevention: + """Adapter should ignore messages from its own number.""" + + def test_own_number_detection(self): + """The adapter stores _from_number for echo prevention.""" from gateway.platforms.sms import SmsAdapter - config = PlatformConfig(enabled=True, api_key="test_key") - with patch.dict("os.environ", {"TELNYX_API_KEY": "test_key"}): - self.adapter = SmsAdapter(config) - def test_short_message_single_chunk(self): - msg = "Hello, world!" - chunks = self.adapter.truncate_message(msg) - assert len(chunks) == 1 - assert chunks[0] == msg - - def test_long_message_splits(self): - msg = "a " * 1000 # 2000 chars - chunks = self.adapter.truncate_message(msg) - assert len(chunks) >= 2 - for chunk in chunks: - assert len(chunk) <= 1600 - - def test_custom_max_length(self): - msg = "Hello " * 20 - chunks = self.adapter.truncate_message(msg, max_length=50) - assert all(len(c) <= 50 for c in chunks) + env = { + "TWILIO_ACCOUNT_SID": "ACtest", + "TWILIO_AUTH_TOKEN": "tok", + "TWILIO_PHONE_NUMBER": "+15550001111", + } + with patch.dict(os.environ, env): + pc = PlatformConfig(enabled=True, api_key="tok") + adapter = SmsAdapter(pc) + assert adapter._from_number == "+15550001111" -# --------------------------------------------------------------------------- -# Echo loop prevention -# --------------------------------------------------------------------------- - -class TestSmsEchoLoop: - def test_own_number_ignored(self): - from gateway.platforms.sms import SmsAdapter - config = PlatformConfig(enabled=True, api_key="test_key") - with patch.dict("os.environ", { - "TELNYX_API_KEY": "test_key", - "TELNYX_FROM_NUMBERS": "+15551234567,+15559876543", - }): - adapter = SmsAdapter(config) - assert "+15551234567" in adapter._from_numbers - assert "+15559876543" in adapter._from_numbers - - -# --------------------------------------------------------------------------- -# Auth maps -# --------------------------------------------------------------------------- - -class TestSmsAuthMaps: - def test_sms_in_allowed_users_map(self): - """SMS should be in the platform auth maps in run.py.""" - # Verify the env var names are consistent - import os - os.environ.setdefault("SMS_ALLOWED_USERS", "+15551234567") - assert os.getenv("SMS_ALLOWED_USERS") == "+15551234567" - - def test_sms_allow_all_env_var(self): - """SMS_ALLOW_ALL_USERS should be recognized.""" - import os - os.environ.setdefault("SMS_ALLOW_ALL_USERS", "true") - assert os.getenv("SMS_ALLOW_ALL_USERS") == "true" - - -# --------------------------------------------------------------------------- -# Requirements check -# --------------------------------------------------------------------------- +# ── Requirements check ───────────────────────────────────────────── class TestSmsRequirements: - def test_check_sms_requirements_with_key(self, monkeypatch): - monkeypatch.setenv("TELNYX_API_KEY", "KEY_test123") + def test_check_sms_requirements_missing_sid(self): from gateway.platforms.sms import check_sms_requirements - # aiohttp is available in test environment - assert check_sms_requirements() is True - def test_check_sms_requirements_without_key(self, monkeypatch): - monkeypatch.delenv("TELNYX_API_KEY", raising=False) + env = {"TWILIO_AUTH_TOKEN": "tok"} + with patch.dict(os.environ, env, clear=True): + assert check_sms_requirements() is False + + def test_check_sms_requirements_missing_token(self): from gateway.platforms.sms import check_sms_requirements - assert check_sms_requirements() is False + + env = {"TWILIO_ACCOUNT_SID": "ACtest"} + with patch.dict(os.environ, env, clear=True): + assert check_sms_requirements() is False + + def test_check_sms_requirements_both_set(self): + from gateway.platforms.sms import check_sms_requirements + + env = { + "TWILIO_ACCOUNT_SID": "ACtest", + "TWILIO_AUTH_TOKEN": "tok", + } + with patch.dict(os.environ, env, clear=False): + # Only returns True if aiohttp is also importable + result = check_sms_requirements() + try: + import aiohttp # noqa: F401 + assert result is True + except ImportError: + assert result is False -# --------------------------------------------------------------------------- -# Toolset & integration points -# --------------------------------------------------------------------------- +# ── Toolset verification ─────────────────────────────────────────── class TestSmsToolset: def test_hermes_sms_toolset_exists(self): from toolsets import get_toolset + ts = get_toolset("hermes-sms") assert ts is not None - assert "hermes-sms" in ts.get("description", "").lower() or "sms" in ts.get("description", "").lower() + assert "tools" in ts - def test_hermes_gateway_includes_sms(self): + def test_hermes_sms_in_gateway_includes(self): from toolsets import get_toolset + gw = get_toolset("hermes-gateway") + assert gw is not None assert "hermes-sms" in gw["includes"] - -class TestSmsPlatformHints: - def test_sms_in_platform_hints(self): + def test_sms_platform_hint_exists(self): from agent.prompt_builder import PLATFORM_HINTS + assert "sms" in PLATFORM_HINTS - assert "SMS" in PLATFORM_HINTS["sms"] or "sms" in PLATFORM_HINTS["sms"].lower() + assert "concise" in PLATFORM_HINTS["sms"].lower() - -class TestSmsCronDelivery: - def test_sms_in_cron_platform_map(self): - """Verify the cron scheduler can resolve 'sms' platform.""" - # The platform_map in _deliver_result should include sms - from gateway.config import Platform + def test_sms_in_scheduler_platform_map(self): + """Verify cron scheduler recognizes 'sms' as a valid platform.""" + # Just check the Platform enum has SMS — the scheduler imports it dynamically assert Platform.SMS.value == "sms" - -class TestSmsSendMessageTool: def test_sms_in_send_message_platform_map(self): - """The send_message tool should recognize 'sms' as a valid platform.""" - # We verify by checking that SMS is in the Platform enum - # and the code path exists - from gateway.config import Platform + """Verify send_message_tool recognizes 'sms'.""" + # The platform_map is built inside _handle_send; verify SMS enum exists assert hasattr(Platform, "SMS") - -class TestSmsChannelDirectory: - def test_sms_in_session_discovery(self): - """Verify SMS is included in session-based channel discovery.""" - import inspect - from gateway.channel_directory import build_channel_directory - source = inspect.getsource(build_channel_directory) - assert '"sms"' in source - - -class TestSmsStatus: - def test_sms_in_status_platforms(self): - """Verify SMS appears in the status command platforms dict.""" - import inspect - from hermes_cli.status import show_status - source = inspect.getsource(show_status) - assert '"SMS"' in source or "'SMS'" in source + def test_sms_in_cronjob_deliver_description(self): + """Verify cronjob_tools mentions sms in deliver description.""" + from tools.cronjob_tools import CRONJOB_SCHEMA + deliver_desc = CRONJOB_SCHEMA["parameters"]["properties"]["deliver"]["description"] + assert "sms" in deliver_desc.lower() diff --git a/tests/gateway/test_status_command.py b/tests/gateway/test_status_command.py index 1c22543f7..1378ff1cb 100644 --- a/tests/gateway/test_status_command.py +++ b/tests/gateway/test_status_command.py @@ -128,6 +128,13 @@ async def test_handle_message_persists_agent_token_counts(monkeypatch): session_entry.session_key, input_tokens=120, output_tokens=45, + cache_read_tokens=0, + cache_write_tokens=0, last_prompt_tokens=80, model="openai/test-model", + estimated_cost_usd=None, + cost_status=None, + cost_source=None, + provider=None, + base_url=None, ) diff --git a/tests/hermes_cli/test_config.py b/tests/hermes_cli/test_config.py index ba4f5c844..82cb99c64 100644 --- a/tests/hermes_cli/test_config.py +++ b/tests/hermes_cli/test_config.py @@ -316,6 +316,38 @@ class TestSanitizeEnvLines: assert fixes == 0 +class TestOptionalEnvVarsRegistry: + """Verify that key env vars are registered in OPTIONAL_ENV_VARS.""" + + def test_tavily_api_key_registered(self): + """TAVILY_API_KEY is listed in OPTIONAL_ENV_VARS.""" + from hermes_cli.config import OPTIONAL_ENV_VARS + assert "TAVILY_API_KEY" in OPTIONAL_ENV_VARS + + def test_tavily_api_key_is_tool_category(self): + """TAVILY_API_KEY is in the 'tool' category.""" + from hermes_cli.config import OPTIONAL_ENV_VARS + assert OPTIONAL_ENV_VARS["TAVILY_API_KEY"]["category"] == "tool" + + def test_tavily_api_key_is_password(self): + """TAVILY_API_KEY is marked as password.""" + from hermes_cli.config import OPTIONAL_ENV_VARS + assert OPTIONAL_ENV_VARS["TAVILY_API_KEY"]["password"] is True + + def test_tavily_api_key_has_url(self): + """TAVILY_API_KEY has a URL.""" + from hermes_cli.config import OPTIONAL_ENV_VARS + assert OPTIONAL_ENV_VARS["TAVILY_API_KEY"]["url"] == "https://app.tavily.com/home" + + def test_tavily_in_env_vars_by_version(self): + """TAVILY_API_KEY is listed in ENV_VARS_BY_VERSION.""" + from hermes_cli.config import ENV_VARS_BY_VERSION + all_vars = [] + for vars_list in ENV_VARS_BY_VERSION.values(): + all_vars.extend(vars_list) + assert "TAVILY_API_KEY" in all_vars + + class TestAnthropicTokenMigration: """Test that config version 8→9 clears ANTHROPIC_TOKEN.""" diff --git a/tests/hermes_cli/test_mcp_tools_config.py b/tests/hermes_cli/test_mcp_tools_config.py new file mode 100644 index 000000000..d7be938ad --- /dev/null +++ b/tests/hermes_cli/test_mcp_tools_config.py @@ -0,0 +1,291 @@ +"""Tests for MCP tools interactive configuration in hermes_cli.tools_config.""" + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from hermes_cli.tools_config import _configure_mcp_tools_interactive + +# Patch targets: imports happen inside the function body, so patch at source +_PROBE = "tools.mcp_tool.probe_mcp_server_tools" +_CHECKLIST = "hermes_cli.curses_ui.curses_checklist" +_SAVE = "hermes_cli.tools_config.save_config" + + +def test_no_mcp_servers_prints_info(capsys): + """Returns immediately when no MCP servers are configured.""" + config = {} + _configure_mcp_tools_interactive(config) + captured = capsys.readouterr() + assert "No MCP servers configured" in captured.out + + +def test_all_servers_disabled_prints_info(capsys): + """Returns immediately when all configured servers have enabled=false.""" + config = { + "mcp_servers": { + "github": {"command": "npx", "enabled": False}, + "slack": {"command": "npx", "enabled": "false"}, + } + } + _configure_mcp_tools_interactive(config) + captured = capsys.readouterr() + assert "disabled" in captured.out + + +def test_probe_failure_shows_warning(capsys): + """Shows warning when probe returns no tools.""" + config = {"mcp_servers": {"github": {"command": "npx"}}} + with patch(_PROBE, return_value={}): + _configure_mcp_tools_interactive(config) + captured = capsys.readouterr() + assert "Could not discover" in captured.out + + +def test_probe_exception_shows_error(capsys): + """Shows error when probe raises an exception.""" + config = {"mcp_servers": {"github": {"command": "npx"}}} + with patch(_PROBE, side_effect=RuntimeError("MCP not installed")): + _configure_mcp_tools_interactive(config) + captured = capsys.readouterr() + assert "Failed to probe" in captured.out + + +def test_no_changes_when_checklist_cancelled(capsys): + """No config changes when user cancels (ESC) the checklist.""" + config = { + "mcp_servers": { + "github": {"command": "npx", "args": ["-y", "server-github"]}, + } + } + tools = [("create_issue", "Create an issue"), ("search_repos", "Search repos")] + + with patch(_PROBE, return_value={"github": tools}), \ + patch(_CHECKLIST, return_value={0, 1}), \ + patch(_SAVE) as mock_save: + _configure_mcp_tools_interactive(config) + mock_save.assert_not_called() + captured = capsys.readouterr() + assert "no changes" in captured.out.lower() + + +def test_disabling_tool_writes_exclude_list(capsys): + """Unchecking a tool adds it to the exclude list.""" + config = { + "mcp_servers": { + "github": {"command": "npx"}, + } + } + tools = [ + ("create_issue", "Create an issue"), + ("delete_repo", "Delete a repo"), + ("search_repos", "Search repos"), + ] + + # User unchecks delete_repo (index 1) + with patch(_PROBE, return_value={"github": tools}), \ + patch(_CHECKLIST, return_value={0, 2}), \ + patch(_SAVE) as mock_save: + _configure_mcp_tools_interactive(config) + + mock_save.assert_called_once() + tools_cfg = config["mcp_servers"]["github"]["tools"] + assert tools_cfg["exclude"] == ["delete_repo"] + assert "include" not in tools_cfg + + +def test_enabling_all_clears_filters(capsys): + """Checking all tools clears both include and exclude lists.""" + config = { + "mcp_servers": { + "github": { + "command": "npx", + "tools": {"exclude": ["delete_repo"], "include": ["create_issue"]}, + }, + } + } + tools = [("create_issue", "Create"), ("delete_repo", "Delete")] + + # User checks all tools — pre_selected would be {0} (include mode), + # so returning {0, 1} is a change + with patch(_PROBE, return_value={"github": tools}), \ + patch(_CHECKLIST, return_value={0, 1}), \ + patch(_SAVE) as mock_save: + _configure_mcp_tools_interactive(config) + + mock_save.assert_called_once() + tools_cfg = config["mcp_servers"]["github"]["tools"] + assert "exclude" not in tools_cfg + assert "include" not in tools_cfg + + +def test_pre_selection_respects_existing_exclude(capsys): + """Tools in exclude list start unchecked.""" + config = { + "mcp_servers": { + "github": { + "command": "npx", + "tools": {"exclude": ["delete_repo"]}, + }, + } + } + tools = [("create_issue", "Create"), ("delete_repo", "Delete"), ("search", "Search")] + captured_pre_selected = {} + + def fake_checklist(title, labels, pre_selected, **kwargs): + captured_pre_selected["value"] = set(pre_selected) + return pre_selected # No changes + + with patch(_PROBE, return_value={"github": tools}), \ + patch(_CHECKLIST, side_effect=fake_checklist), \ + patch(_SAVE): + _configure_mcp_tools_interactive(config) + + # create_issue (0) and search (2) should be pre-selected, delete_repo (1) should not + assert captured_pre_selected["value"] == {0, 2} + + +def test_pre_selection_respects_existing_include(capsys): + """Only tools in include list start checked.""" + config = { + "mcp_servers": { + "github": { + "command": "npx", + "tools": {"include": ["search"]}, + }, + } + } + tools = [("create_issue", "Create"), ("delete_repo", "Delete"), ("search", "Search")] + captured_pre_selected = {} + + def fake_checklist(title, labels, pre_selected, **kwargs): + captured_pre_selected["value"] = set(pre_selected) + return pre_selected # No changes + + with patch(_PROBE, return_value={"github": tools}), \ + patch(_CHECKLIST, side_effect=fake_checklist), \ + patch(_SAVE): + _configure_mcp_tools_interactive(config) + + # Only search (2) should be pre-selected + assert captured_pre_selected["value"] == {2} + + +def test_multiple_servers_each_get_checklist(capsys): + """Each server gets its own checklist.""" + config = { + "mcp_servers": { + "github": {"command": "npx"}, + "slack": {"url": "https://mcp.example.com"}, + } + } + checklist_calls = [] + + def fake_checklist(title, labels, pre_selected, **kwargs): + checklist_calls.append(title) + return pre_selected # No changes + + with patch( + _PROBE, + return_value={ + "github": [("create_issue", "Create")], + "slack": [("send_message", "Send")], + }, + ), patch(_CHECKLIST, side_effect=fake_checklist), \ + patch(_SAVE): + _configure_mcp_tools_interactive(config) + + assert len(checklist_calls) == 2 + assert any("github" in t for t in checklist_calls) + assert any("slack" in t for t in checklist_calls) + + +def test_failed_server_shows_warning(capsys): + """Servers that fail to connect show warnings.""" + config = { + "mcp_servers": { + "github": {"command": "npx"}, + "broken": {"command": "nonexistent"}, + } + } + + # Only github succeeds + with patch( + _PROBE, return_value={"github": [("create_issue", "Create")]}, + ), patch(_CHECKLIST, return_value={0}), \ + patch(_SAVE): + _configure_mcp_tools_interactive(config) + + captured = capsys.readouterr() + assert "broken" in captured.out + + +def test_description_truncation_in_labels(): + """Long descriptions are truncated in checklist labels.""" + config = { + "mcp_servers": { + "github": {"command": "npx"}, + } + } + long_desc = "A" * 100 + captured_labels = {} + + def fake_checklist(title, labels, pre_selected, **kwargs): + captured_labels["value"] = labels + return pre_selected + + with patch( + _PROBE, return_value={"github": [("my_tool", long_desc)]}, + ), patch(_CHECKLIST, side_effect=fake_checklist), \ + patch(_SAVE): + _configure_mcp_tools_interactive(config) + + label = captured_labels["value"][0] + assert "..." in label + assert len(label) < len(long_desc) + 30 # truncated + tool name + parens + + +def test_switching_from_include_to_exclude(capsys): + """When user modifies selection, include list is replaced by exclude list.""" + config = { + "mcp_servers": { + "github": { + "command": "npx", + "tools": {"include": ["create_issue"]}, + }, + } + } + tools = [("create_issue", "Create"), ("search", "Search"), ("delete", "Delete")] + + # User selects create_issue and search (deselects delete) + # pre_selected would be {0} (only create_issue from include), so {0, 1} is a change + with patch(_PROBE, return_value={"github": tools}), \ + patch(_CHECKLIST, return_value={0, 1}), \ + patch(_SAVE): + _configure_mcp_tools_interactive(config) + + tools_cfg = config["mcp_servers"]["github"]["tools"] + assert tools_cfg["exclude"] == ["delete"] + assert "include" not in tools_cfg + + +def test_empty_tools_server_skipped(capsys): + """Server with no tools shows info message and skips checklist.""" + config = { + "mcp_servers": { + "empty": {"command": "npx"}, + } + } + checklist_calls = [] + + def fake_checklist(title, labels, pre_selected, **kwargs): + checklist_calls.append(title) + return pre_selected + + with patch(_PROBE, return_value={"empty": []}), \ + patch(_CHECKLIST, side_effect=fake_checklist), \ + patch(_SAVE): + _configure_mcp_tools_interactive(config) + + assert len(checklist_calls) == 0 + captured = capsys.readouterr() + assert "no tools found" in captured.out diff --git a/tests/hermes_cli/test_setup.py b/tests/hermes_cli/test_setup.py index 11e633306..bc19e7bbf 100644 --- a/tests/hermes_cli/test_setup.py +++ b/tests/hermes_cli/test_setup.py @@ -5,6 +5,13 @@ from hermes_cli.config import load_config, save_config from hermes_cli.setup import setup_model_provider +def _maybe_keep_current_tts(question, choices): + if question != "Select TTS provider:": + return None + assert choices[-1].startswith("Keep current (") + return len(choices) - 1 + + def _clear_provider_env(monkeypatch): for key in ( "NOUS_API_KEY", @@ -25,16 +32,22 @@ def test_nous_oauth_setup_keeps_current_model_when_syncing_disk_provider( config = load_config() - # Provider selection always comes first. Depending on available vision - # backends, setup may either skip the optional vision step or prompt for - # it before the default-model choice. Provide enough selections for both - # paths while still ending on "keep current model". - prompt_choices = iter([0, 2, 2]) - monkeypatch.setattr( - "hermes_cli.setup.prompt_choice", - lambda *args, **kwargs: next(prompt_choices), - ) + def fake_prompt_choice(question, choices, default=0): + if question == "Select your inference provider:": + return 0 + if question == "Configure vision:": + return len(choices) - 1 + if question == "Select default model:": + assert choices[-1] == "Keep current (anthropic/claude-opus-4.6)" + return len(choices) - 1 + tts_idx = _maybe_keep_current_tts(question, choices) + if tts_idx is not None: + return tts_idx + raise AssertionError(f"Unexpected prompt_choice call: {question}") + + monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice) monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "") + monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: []) def _fake_login_nous(*args, **kwargs): auth_path = tmp_path / "auth.json" @@ -53,7 +66,6 @@ def test_nous_oauth_setup_keeps_current_model_when_syncing_disk_provider( "hermes_cli.auth.fetch_nous_models", lambda *args, **kwargs: ["gemini-3-flash"], ) - monkeypatch.setattr("hermes_cli.setup._setup_tts_provider", lambda config: None) setup_model_provider(config) save_config(config) @@ -75,21 +87,29 @@ def test_custom_setup_clears_active_oauth_provider(tmp_path, monkeypatch): config = load_config() - monkeypatch.setattr("hermes_cli.setup.prompt_choice", lambda *args, **kwargs: 3) + def fake_prompt_choice(question, choices, default=0): + if question == "Select your inference provider:": + return 3 + tts_idx = _maybe_keep_current_tts(question, choices) + if tts_idx is not None: + return tts_idx + raise AssertionError(f"Unexpected prompt_choice call: {question}") + + monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice) prompt_values = iter( [ "https://custom.example/v1", "custom-api-key", "custom/model", - "", ] ) monkeypatch.setattr( "hermes_cli.setup.prompt", lambda *args, **kwargs: next(prompt_values), ) - monkeypatch.setattr("hermes_cli.setup._setup_tts_provider", lambda config: None) + monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False) + monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: []) setup_model_provider(config) save_config(config) @@ -111,11 +131,17 @@ def test_codex_setup_uses_runtime_access_token_for_live_model_list(tmp_path, mon config = load_config() - prompt_choices = iter([1, 0]) - monkeypatch.setattr( - "hermes_cli.setup.prompt_choice", - lambda *args, **kwargs: next(prompt_choices), - ) + def fake_prompt_choice(question, choices, default=0): + if question == "Select your inference provider:": + return 1 + if question == "Select default model:": + return 0 + tts_idx = _maybe_keep_current_tts(question, choices) + if tts_idx is not None: + return tts_idx + raise AssertionError(f"Unexpected prompt_choice call: {question}") + + monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice) monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "") monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: []) monkeypatch.setattr("hermes_cli.auth._login_openai_codex", lambda *args, **kwargs: None) @@ -137,7 +163,6 @@ def test_codex_setup_uses_runtime_access_token_for_live_model_list(tmp_path, mon "hermes_cli.codex_models.get_codex_model_ids", _fake_get_codex_model_ids, ) - monkeypatch.setattr("hermes_cli.setup._setup_tts_provider", lambda config: None) setup_model_provider(config) save_config(config) diff --git a/tests/hermes_cli/test_setup_model_provider.py b/tests/hermes_cli/test_setup_model_provider.py index 9b44f6bcd..671bb9ba3 100644 --- a/tests/hermes_cli/test_setup_model_provider.py +++ b/tests/hermes_cli/test_setup_model_provider.py @@ -6,6 +6,13 @@ from hermes_cli.config import load_config, save_config, save_env_value from hermes_cli.setup import _print_setup_summary, setup_model_provider +def _maybe_keep_current_tts(question, choices): + if question != "Select TTS provider:": + return None + assert choices[-1].startswith("Keep current (") + return len(choices) - 1 + + def _read_env(home): env_path = home / ".env" data = {} @@ -50,19 +57,18 @@ def test_setup_keep_current_custom_from_config_does_not_fall_through(tmp_path, m } save_config(config) - calls = {"count": 0} - def fake_prompt_choice(question, choices, default=0): - calls["count"] += 1 - if calls["count"] == 1: + if question == "Select your inference provider:": assert choices[-1] == "Keep current (Custom: https://example.invalid/v1)" return len(choices) - 1 + tts_idx = _maybe_keep_current_tts(question, choices) + if tts_idx is not None: + return tts_idx raise AssertionError("Model menu should not appear for keep-current custom") monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice) monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "") monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False) - monkeypatch.setattr("hermes_cli.setup._setup_tts_provider", lambda config: None) monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None) monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: []) @@ -73,7 +79,6 @@ def test_setup_keep_current_custom_from_config_does_not_fall_through(tmp_path, m assert reloaded["model"]["provider"] == "custom" assert reloaded["model"]["default"] == "custom/model" assert reloaded["model"]["base_url"] == "https://example.invalid/v1" - assert calls["count"] == 1 def test_setup_custom_endpoint_saves_working_v1_base_url(tmp_path, monkeypatch): @@ -87,8 +92,9 @@ def test_setup_custom_endpoint_saves_working_v1_base_url(tmp_path, monkeypatch): return 3 # Custom endpoint if question == "Configure vision:": return len(choices) - 1 # Skip - if question == "Select TTS provider:": - return len(choices) - 1 # Keep current + tts_idx = _maybe_keep_current_tts(question, choices) + if tts_idx is not None: + return tts_idx raise AssertionError(f"Unexpected prompt_choice call: {question}") def fake_prompt(message, current=None, **kwargs): @@ -103,7 +109,6 @@ def test_setup_custom_endpoint_saves_working_v1_base_url(tmp_path, monkeypatch): monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice) monkeypatch.setattr("hermes_cli.setup.prompt", fake_prompt) monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False) - monkeypatch.setattr("hermes_cli.setup._setup_tts_provider", lambda config: None) monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None) monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: []) monkeypatch.setattr("agent.auxiliary_client.get_available_vision_backends", lambda: []) @@ -144,25 +149,23 @@ def test_setup_keep_current_config_provider_uses_provider_specific_model_menu(tm save_config(config) captured = {"provider_choices": None, "model_choices": None} - calls = {"count": 0} def fake_prompt_choice(question, choices, default=0): - calls["count"] += 1 - if calls["count"] == 1: + if question == "Select your inference provider:": captured["provider_choices"] = list(choices) assert choices[-1] == "Keep current (Anthropic)" return len(choices) - 1 - if calls["count"] == 2: + if question == "Configure vision:": assert question == "Configure vision:" assert choices[-1] == "Skip for now" return len(choices) - 1 - if calls["count"] == 3: + if question == "Select default model:": captured["model_choices"] = list(choices) return len(choices) - 1 # keep current model - if calls["count"] == 4: - assert question == "Select TTS provider:" - return len(choices) - 1 # Keep current - raise AssertionError("Unexpected extra prompt_choice call") + tts_idx = _maybe_keep_current_tts(question, choices) + if tts_idx is not None: + return tts_idx + raise AssertionError(f"Unexpected prompt_choice call: {question}") monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice) monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "") @@ -179,7 +182,6 @@ def test_setup_keep_current_config_provider_uses_provider_specific_model_menu(tm assert captured["model_choices"] is not None assert captured["model_choices"][0] == "claude-opus-4-6" assert "anthropic/claude-opus-4.6 (recommended)" not in captured["model_choices"] - assert calls["count"] == 4 # provider, vision, model, TTS def test_setup_keep_current_anthropic_can_configure_openai_vision_default(tmp_path, monkeypatch): @@ -193,15 +195,24 @@ def test_setup_keep_current_anthropic_can_configure_openai_vision_default(tmp_pa } save_config(config) - picks = iter([ - 10, # keep current provider (shifted +1 by kilocode insertion) - 1, # configure vision with OpenAI - 5, # use default gpt-4o-mini vision model - 4, # keep current Anthropic model - 4, # TTS: Keep current - ]) + def fake_prompt_choice(question, choices, default=0): + if question == "Select your inference provider:": + assert choices[-1] == "Keep current (Anthropic)" + return len(choices) - 1 + if question == "Configure vision:": + return 1 + if question == "Select vision model:": + assert choices[-1] == "Use default (gpt-4o-mini)" + return len(choices) - 1 + if question == "Select default model:": + assert choices[-1] == "Keep current (claude-opus-4-6)" + return len(choices) - 1 + tts_idx = _maybe_keep_current_tts(question, choices) + if tts_idx is not None: + return tts_idx + raise AssertionError(f"Unexpected prompt_choice call: {question}") - monkeypatch.setattr("hermes_cli.setup.prompt_choice", lambda *args, **kwargs: next(picks)) + monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice) monkeypatch.setattr( "hermes_cli.setup.prompt", lambda message, *args, **kwargs: "sk-openai" if "OpenAI API key" in message else "", @@ -237,8 +248,17 @@ def test_setup_switch_custom_to_codex_clears_custom_endpoint_and_updates_config( } save_config(config) - picks = iter([1, 0, 4]) # provider, model; 4 = TTS Keep current - monkeypatch.setattr("hermes_cli.setup.prompt_choice", lambda *args, **kwargs: next(picks)) + def fake_prompt_choice(question, choices, default=0): + if question == "Select your inference provider:": + return 1 + if question == "Select default model:": + return 0 + tts_idx = _maybe_keep_current_tts(question, choices) + if tts_idx is not None: + return tts_idx + raise AssertionError(f"Unexpected prompt_choice call: {question}") + + monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice) monkeypatch.setattr("hermes_cli.setup.prompt", lambda *args, **kwargs: "") monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False) monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None) diff --git a/tests/hermes_cli/test_status.py b/tests/hermes_cli/test_status.py new file mode 100644 index 000000000..374e57b29 --- /dev/null +++ b/tests/hermes_cli/test_status.py @@ -0,0 +1,14 @@ +from types import SimpleNamespace + +from hermes_cli.status import show_status + + +def test_show_status_includes_tavily_key(monkeypatch, capsys, tmp_path): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + monkeypatch.setenv("TAVILY_API_KEY", "tvly-1234567890abcdef") + + show_status(SimpleNamespace(all=False, deep=False)) + + output = capsys.readouterr().out + assert "Tavily" in output + assert "tvly...cdef" in output diff --git a/tests/hermes_cli/test_update_autostash.py b/tests/hermes_cli/test_update_autostash.py index 85523e8df..c03b6bf37 100644 --- a/tests/hermes_cli/test_update_autostash.py +++ b/tests/hermes_cli/test_update_autostash.py @@ -4,6 +4,7 @@ from types import SimpleNamespace import pytest +from hermes_cli import config as hermes_config from hermes_cli import main as hermes_main @@ -235,3 +236,82 @@ def test_stash_local_changes_if_needed_raises_when_stash_ref_missing(monkeypatch with pytest.raises(CalledProcessError): hermes_main._stash_local_changes_if_needed(["git"], Path(tmp_path)) + + +# --------------------------------------------------------------------------- +# Update uses .[all] with fallback to . +# --------------------------------------------------------------------------- + +def _setup_update_mocks(monkeypatch, tmp_path): + """Common setup for cmd_update tests.""" + (tmp_path / ".git").mkdir() + monkeypatch.setattr(hermes_main, "PROJECT_ROOT", tmp_path) + monkeypatch.setattr(hermes_main, "_stash_local_changes_if_needed", lambda *a, **kw: None) + monkeypatch.setattr(hermes_main, "_restore_stashed_changes", lambda *a, **kw: True) + monkeypatch.setattr(hermes_config, "get_missing_env_vars", lambda required_only=True: []) + monkeypatch.setattr(hermes_config, "get_missing_config_fields", lambda: []) + monkeypatch.setattr(hermes_config, "check_config_version", lambda: (5, 5)) + monkeypatch.setattr(hermes_config, "migrate_config", lambda **kw: {"env_added": [], "config_added": []}) + + +def test_cmd_update_tries_extras_first_then_falls_back(monkeypatch, tmp_path): + """When .[all] fails, update should fall back to . instead of aborting.""" + _setup_update_mocks(monkeypatch, tmp_path) + monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/uv" if name == "uv" else None) + + recorded = [] + + def fake_run(cmd, **kwargs): + recorded.append(cmd) + if cmd == ["git", "fetch", "origin"]: + return SimpleNamespace(stdout="", stderr="", returncode=0) + if cmd == ["git", "rev-parse", "--abbrev-ref", "HEAD"]: + return SimpleNamespace(stdout="main\n", stderr="", returncode=0) + if cmd == ["git", "rev-list", "HEAD..origin/main", "--count"]: + return SimpleNamespace(stdout="1\n", stderr="", returncode=0) + if cmd == ["git", "pull", "origin", "main"]: + return SimpleNamespace(stdout="Updating\n", stderr="", returncode=0) + # .[all] fails + if ".[all]" in cmd: + raise CalledProcessError(returncode=1, cmd=cmd) + # bare . succeeds + if cmd == ["/usr/bin/uv", "pip", "install", "-e", ".", "--quiet"]: + return SimpleNamespace(returncode=0) + return SimpleNamespace(returncode=0) + + monkeypatch.setattr(hermes_main.subprocess, "run", fake_run) + + hermes_main.cmd_update(SimpleNamespace()) + + install_cmds = [c for c in recorded if "pip" in c and "install" in c] + assert len(install_cmds) == 2 + assert ".[all]" in install_cmds[0] + assert "." in install_cmds[1] and ".[all]" not in install_cmds[1] + + +def test_cmd_update_succeeds_with_extras(monkeypatch, tmp_path): + """When .[all] succeeds, no fallback should be attempted.""" + _setup_update_mocks(monkeypatch, tmp_path) + monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/uv" if name == "uv" else None) + + recorded = [] + + def fake_run(cmd, **kwargs): + recorded.append(cmd) + if cmd == ["git", "fetch", "origin"]: + return SimpleNamespace(stdout="", stderr="", returncode=0) + if cmd == ["git", "rev-parse", "--abbrev-ref", "HEAD"]: + return SimpleNamespace(stdout="main\n", stderr="", returncode=0) + if cmd == ["git", "rev-list", "HEAD..origin/main", "--count"]: + return SimpleNamespace(stdout="1\n", stderr="", returncode=0) + if cmd == ["git", "pull", "origin", "main"]: + return SimpleNamespace(stdout="Updating\n", stderr="", returncode=0) + return SimpleNamespace(returncode=0) + + monkeypatch.setattr(hermes_main.subprocess, "run", fake_run) + + hermes_main.cmd_update(SimpleNamespace()) + + install_cmds = [c for c in recorded if "pip" in c and "install" in c] + assert len(install_cmds) == 1 + assert ".[all]" in install_cmds[0] diff --git a/tests/honcho_integration/test_client.py b/tests/honcho_integration/test_client.py index d779d9a63..b1ae29c54 100644 --- a/tests/honcho_integration/test_client.py +++ b/tests/honcho_integration/test_client.py @@ -63,11 +63,13 @@ class TestFromEnv: class TestFromGlobalConfig: def test_missing_config_falls_back_to_env(self, tmp_path): - config = HonchoClientConfig.from_global_config( - config_path=tmp_path / "nonexistent.json" - ) + with patch.dict(os.environ, {}, clear=True): + config = HonchoClientConfig.from_global_config( + config_path=tmp_path / "nonexistent.json" + ) # Should fall back to from_env - assert config.enabled is True or config.api_key is None # depends on env + assert config.enabled is False + assert config.api_key is None def test_reads_full_config(self, tmp_path): config_file = tmp_path / "config.json" diff --git a/tests/integration/test_web_tools.py b/tests/integration/test_web_tools.py index fb2ea9da0..fe96b3adb 100644 --- a/tests/integration/test_web_tools.py +++ b/tests/integration/test_web_tools.py @@ -3,7 +3,7 @@ Comprehensive Test Suite for Web Tools Module This script tests all web tools functionality to ensure they work correctly. -Run this after any updates to the web_tools.py module or Firecrawl library. +Run this after any updates to the web_tools.py module or backend libraries. Usage: python test_web_tools.py # Run all tests @@ -11,7 +11,7 @@ Usage: python test_web_tools.py --verbose # Show detailed output Requirements: - - FIRECRAWL_API_KEY environment variable must be set + - PARALLEL_API_KEY or FIRECRAWL_API_KEY environment variable must be set - An auxiliary LLM provider (OPENROUTER_API_KEY or Nous Portal auth) (optional, for LLM tests) """ @@ -28,12 +28,14 @@ from typing import List # Import the web tools to test (updated path after moving tools/) from tools.web_tools import ( - web_search_tool, - web_extract_tool, + web_search_tool, + web_extract_tool, web_crawl_tool, check_firecrawl_api_key, + check_web_api_key, check_auxiliary_model, - get_debug_session_info + get_debug_session_info, + _get_backend, ) @@ -121,12 +123,13 @@ class WebToolsTester: """Test environment setup and API keys""" print_section("Environment Check") - # Check Firecrawl API key - if not check_firecrawl_api_key(): - self.log_result("Firecrawl API Key", "failed", "FIRECRAWL_API_KEY not set") + # Check web backend API key (Parallel or Firecrawl) + if not check_web_api_key(): + self.log_result("Web Backend API Key", "failed", "PARALLEL_API_KEY or FIRECRAWL_API_KEY not set") return False else: - self.log_result("Firecrawl API Key", "passed", "Found") + backend = _get_backend() + self.log_result("Web Backend API Key", "passed", f"Using {backend} backend") # Check auxiliary LLM provider (optional) if not check_auxiliary_model(): @@ -578,7 +581,9 @@ class WebToolsTester: }, "results": self.test_results, "environment": { + "web_backend": _get_backend() if check_web_api_key() else None, "firecrawl_api_key": check_firecrawl_api_key(), + "parallel_api_key": bool(os.getenv("PARALLEL_API_KEY")), "auxiliary_model": check_auxiliary_model(), "debug_mode": get_debug_session_info()["enabled"] } diff --git a/tests/run_interrupt_test.py b/tests/run_interrupt_test.py index 845060ffa..a539c6ca9 100644 --- a/tests/run_interrupt_test.py +++ b/tests/run_interrupt_test.py @@ -24,6 +24,7 @@ def main() -> int: parent._interrupt_requested = False parent._interrupt_message = None parent._active_children = [] + parent._active_children_lock = threading.Lock() parent.quiet_mode = True parent.model = "test/model" parent.base_url = "http://localhost:1" diff --git a/tests/test_agent_guardrails.py b/tests/test_agent_guardrails.py new file mode 100644 index 000000000..706b1daf8 --- /dev/null +++ b/tests/test_agent_guardrails.py @@ -0,0 +1,263 @@ +"""Unit tests for AIAgent pre/post-LLM-call guardrails. + +Covers three static methods on AIAgent (inspired by PR #1321 — @alireza78a): + - _sanitize_api_messages() — Phase 1: orphaned tool pair repair + - _cap_delegate_task_calls() — Phase 2a: subagent concurrency limit + - _deduplicate_tool_calls() — Phase 2b: identical call deduplication +""" + +import types + +from run_agent import AIAgent +from tools.delegate_tool import MAX_CONCURRENT_CHILDREN + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def make_tc(name: str, arguments: str = "{}") -> types.SimpleNamespace: + """Create a minimal tool_call SimpleNamespace mirroring the OpenAI SDK object.""" + tc = types.SimpleNamespace() + tc.function = types.SimpleNamespace(name=name, arguments=arguments) + return tc + + +def tool_result(call_id: str, content: str = "ok") -> dict: + return {"role": "tool", "tool_call_id": call_id, "content": content} + + +def assistant_dict_call(call_id: str, name: str = "terminal") -> dict: + """Dict-style tool_call (as stored in message history).""" + return {"id": call_id, "function": {"name": name, "arguments": "{}"}} + + +# --------------------------------------------------------------------------- +# Phase 1 — _sanitize_api_messages +# --------------------------------------------------------------------------- + +class TestSanitizeApiMessages: + + def test_orphaned_result_removed(self): + msgs = [ + {"role": "assistant", "tool_calls": [assistant_dict_call("c1")]}, + tool_result("c1"), + tool_result("c_ORPHAN"), + ] + out = AIAgent._sanitize_api_messages(msgs) + assert len(out) == 2 + assert all(m.get("tool_call_id") != "c_ORPHAN" for m in out) + + def test_orphaned_call_gets_stub_result(self): + msgs = [ + {"role": "assistant", "tool_calls": [assistant_dict_call("c2")]}, + ] + out = AIAgent._sanitize_api_messages(msgs) + assert len(out) == 2 + stub = out[1] + assert stub["role"] == "tool" + assert stub["tool_call_id"] == "c2" + assert stub["content"] + + def test_clean_messages_pass_through(self): + msgs = [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "tool_calls": [assistant_dict_call("c3")]}, + tool_result("c3"), + {"role": "assistant", "content": "done"}, + ] + out = AIAgent._sanitize_api_messages(msgs) + assert out == msgs + + def test_mixed_orphaned_result_and_orphaned_call(self): + msgs = [ + {"role": "assistant", "tool_calls": [ + assistant_dict_call("c4"), + assistant_dict_call("c5"), + ]}, + tool_result("c4"), + tool_result("c_DANGLING"), + ] + out = AIAgent._sanitize_api_messages(msgs) + ids = [m.get("tool_call_id") for m in out if m.get("role") == "tool"] + assert "c_DANGLING" not in ids + assert "c4" in ids + assert "c5" in ids + + def test_empty_list_is_safe(self): + assert AIAgent._sanitize_api_messages([]) == [] + + def test_no_tool_messages(self): + msgs = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + out = AIAgent._sanitize_api_messages(msgs) + assert out == msgs + + def test_sdk_object_tool_calls(self): + tc_obj = types.SimpleNamespace(id="c6", function=types.SimpleNamespace( + name="terminal", arguments="{}" + )) + msgs = [ + {"role": "assistant", "tool_calls": [tc_obj]}, + ] + out = AIAgent._sanitize_api_messages(msgs) + assert len(out) == 2 + assert out[1]["tool_call_id"] == "c6" + + +# --------------------------------------------------------------------------- +# Phase 2a — _cap_delegate_task_calls +# --------------------------------------------------------------------------- + +class TestCapDelegateTaskCalls: + + def test_excess_delegates_truncated(self): + tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 2)] + out = AIAgent._cap_delegate_task_calls(tcs) + delegate_count = sum(1 for tc in out if tc.function.name == "delegate_task") + assert delegate_count == MAX_CONCURRENT_CHILDREN + + def test_non_delegate_calls_preserved(self): + tcs = ( + [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 1)] + + [make_tc("terminal"), make_tc("web_search")] + ) + out = AIAgent._cap_delegate_task_calls(tcs) + names = [tc.function.name for tc in out] + assert "terminal" in names + assert "web_search" in names + + def test_at_limit_passes_through(self): + tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN)] + out = AIAgent._cap_delegate_task_calls(tcs) + assert out is tcs + + def test_below_limit_passes_through(self): + tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN - 1)] + out = AIAgent._cap_delegate_task_calls(tcs) + assert out is tcs + + def test_no_delegate_calls_unchanged(self): + tcs = [make_tc("terminal"), make_tc("web_search")] + out = AIAgent._cap_delegate_task_calls(tcs) + assert out is tcs + + def test_empty_list_safe(self): + assert AIAgent._cap_delegate_task_calls([]) == [] + + def test_original_list_not_mutated(self): + tcs = [make_tc("delegate_task") for _ in range(MAX_CONCURRENT_CHILDREN + 2)] + original_len = len(tcs) + AIAgent._cap_delegate_task_calls(tcs) + assert len(tcs) == original_len + + def test_interleaved_order_preserved(self): + delegates = [make_tc("delegate_task", f'{{"task":"{i}"}}') + for i in range(MAX_CONCURRENT_CHILDREN + 1)] + t1 = make_tc("terminal", '{"cmd":"ls"}') + w1 = make_tc("web_search", '{"q":"x"}') + tcs = [delegates[0], t1, delegates[1], w1] + delegates[2:] + out = AIAgent._cap_delegate_task_calls(tcs) + expected = [delegates[0], t1, delegates[1], w1] + delegates[2:MAX_CONCURRENT_CHILDREN] + assert len(out) == len(expected) + for i, (actual, exp) in enumerate(zip(out, expected)): + assert actual is exp, f"mismatch at index {i}" + + +# --------------------------------------------------------------------------- +# Phase 2b — _deduplicate_tool_calls +# --------------------------------------------------------------------------- + +class TestDeduplicateToolCalls: + + def test_duplicate_pair_deduplicated(self): + tcs = [ + make_tc("web_search", '{"query":"foo"}'), + make_tc("web_search", '{"query":"foo"}'), + ] + out = AIAgent._deduplicate_tool_calls(tcs) + assert len(out) == 1 + + def test_multiple_duplicates(self): + tcs = [ + make_tc("web_search", '{"q":"a"}'), + make_tc("web_search", '{"q":"a"}'), + make_tc("terminal", '{"cmd":"ls"}'), + make_tc("terminal", '{"cmd":"ls"}'), + make_tc("terminal", '{"cmd":"pwd"}'), + ] + out = AIAgent._deduplicate_tool_calls(tcs) + assert len(out) == 3 + + def test_same_tool_different_args_kept(self): + tcs = [ + make_tc("terminal", '{"cmd":"ls"}'), + make_tc("terminal", '{"cmd":"pwd"}'), + ] + out = AIAgent._deduplicate_tool_calls(tcs) + assert out is tcs + + def test_different_tools_same_args_kept(self): + tcs = [ + make_tc("tool_a", '{"x":1}'), + make_tc("tool_b", '{"x":1}'), + ] + out = AIAgent._deduplicate_tool_calls(tcs) + assert out is tcs + + def test_clean_list_unchanged(self): + tcs = [ + make_tc("web_search", '{"q":"x"}'), + make_tc("terminal", '{"cmd":"ls"}'), + ] + out = AIAgent._deduplicate_tool_calls(tcs) + assert out is tcs + + def test_empty_list_safe(self): + assert AIAgent._deduplicate_tool_calls([]) == [] + + def test_first_occurrence_kept(self): + tc1 = make_tc("terminal", '{"cmd":"ls"}') + tc2 = make_tc("terminal", '{"cmd":"ls"}') + out = AIAgent._deduplicate_tool_calls([tc1, tc2]) + assert len(out) == 1 + assert out[0] is tc1 + + def test_original_list_not_mutated(self): + tcs = [ + make_tc("web_search", '{"q":"dup"}'), + make_tc("web_search", '{"q":"dup"}'), + ] + original_len = len(tcs) + AIAgent._deduplicate_tool_calls(tcs) + assert len(tcs) == original_len + + +# --------------------------------------------------------------------------- +# _get_tool_call_id_static +# --------------------------------------------------------------------------- + +class TestGetToolCallIdStatic: + + def test_dict_with_valid_id(self): + assert AIAgent._get_tool_call_id_static({"id": "call_123"}) == "call_123" + + def test_dict_with_none_id(self): + assert AIAgent._get_tool_call_id_static({"id": None}) == "" + + def test_dict_without_id_key(self): + assert AIAgent._get_tool_call_id_static({"function": {}}) == "" + + def test_object_with_valid_id(self): + tc = types.SimpleNamespace(id="call_456") + assert AIAgent._get_tool_call_id_static(tc) == "call_456" + + def test_object_with_none_id(self): + tc = types.SimpleNamespace(id=None) + assert AIAgent._get_tool_call_id_static(tc) == "" + + def test_object_without_id_attr(self): + tc = types.SimpleNamespace() + assert AIAgent._get_tool_call_id_static(tc) == "" diff --git a/tests/test_api_key_providers.py b/tests/test_api_key_providers.py index deb55734d..98f27d103 100644 --- a/tests/test_api_key_providers.py +++ b/tests/test_api_key_providers.py @@ -98,11 +98,14 @@ class TestProviderRegistry: # ============================================================================= PROVIDER_ENV_VARS = ( - "OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY", + "OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY", "ANTHROPIC_TOKEN", + "CLAUDE_CODE_OAUTH_TOKEN", "GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY", "KIMI_API_KEY", "KIMI_BASE_URL", "MINIMAX_API_KEY", "MINIMAX_CN_API_KEY", "AI_GATEWAY_API_KEY", "AI_GATEWAY_BASE_URL", "KILOCODE_API_KEY", "KILOCODE_BASE_URL", + "DASHSCOPE_API_KEY", "OPENCODE_ZEN_API_KEY", "OPENCODE_GO_API_KEY", + "NOUS_API_KEY", "OPENAI_BASE_URL", ) @@ -111,6 +114,7 @@ PROVIDER_ENV_VARS = ( def _clear_provider_env(monkeypatch): for key in PROVIDER_ENV_VARS: monkeypatch.delenv(key, raising=False) + monkeypatch.setattr("hermes_cli.auth._load_auth_store", lambda: {}) class TestResolveProvider: diff --git a/tests/test_cli_interrupt_subagent.py b/tests/test_cli_interrupt_subagent.py index b91a7b654..f4322ea6b 100644 --- a/tests/test_cli_interrupt_subagent.py +++ b/tests/test_cli_interrupt_subagent.py @@ -43,6 +43,7 @@ class TestCLISubagentInterrupt(unittest.TestCase): parent._interrupt_requested = False parent._interrupt_message = None parent._active_children = [] + parent._active_children_lock = threading.Lock() parent.quiet_mode = True parent.model = "test/model" parent.base_url = "http://localhost:1" @@ -112,21 +113,21 @@ class TestCLISubagentInterrupt(unittest.TestCase): mock_instance._interrupt_requested = False mock_instance._interrupt_message = None mock_instance._active_children = [] + mock_instance._active_children_lock = threading.Lock() mock_instance.quiet_mode = True mock_instance.run_conversation = mock_child_run_conversation mock_instance.interrupt = lambda msg=None: setattr(mock_instance, '_interrupt_requested', True) or setattr(mock_instance, '_interrupt_message', msg) mock_instance.tools = [] MockAgent.return_value = mock_instance - + + # Register child manually (normally done by _build_child_agent) + parent._active_children.append(mock_instance) + result = _run_single_child( task_index=0, goal="Do something slow", - context=None, - toolsets=["terminal"], - model=None, - max_iterations=50, + child=mock_instance, parent_agent=parent, - task_count=1, ) delegate_result[0] = result except Exception as e: diff --git a/tests/test_cli_status_bar.py b/tests/test_cli_status_bar.py index 4e281ffa8..c1dd4b35b 100644 --- a/tests/test_cli_status_bar.py +++ b/tests/test_cli_status_bar.py @@ -16,6 +16,10 @@ def _make_cli(model: str = "anthropic/claude-sonnet-4-20250514"): def _attach_agent( cli_obj, *, + input_tokens: int | None = None, + output_tokens: int | None = None, + cache_read_tokens: int = 0, + cache_write_tokens: int = 0, prompt_tokens: int, completion_tokens: int, total_tokens: int, @@ -26,6 +30,12 @@ def _attach_agent( ): cli_obj.agent = SimpleNamespace( model=cli_obj.model, + provider="anthropic" if cli_obj.model.startswith("anthropic/") else None, + base_url="", + session_input_tokens=input_tokens if input_tokens is not None else prompt_tokens, + session_output_tokens=output_tokens if output_tokens is not None else completion_tokens, + session_cache_read_tokens=cache_read_tokens, + session_cache_write_tokens=cache_write_tokens, session_prompt_tokens=prompt_tokens, session_completion_tokens=completion_tokens, session_total_tokens=total_tokens, @@ -68,20 +78,19 @@ class TestCLIStatusBar: assert "$0.06" not in text # cost hidden by default assert "15m" in text - def test_build_status_bar_text_shows_cost_when_enabled(self): + def test_build_status_bar_text_no_cost_in_status_bar(self): cli_obj = _attach_agent( _make_cli(), prompt_tokens=10000, - completion_tokens=2400, - total_tokens=12400, + completion_tokens=5000, + total_tokens=15000, api_calls=7, - context_tokens=12400, + context_tokens=50000, context_length=200_000, ) - cli_obj.show_cost = True text = cli_obj._build_status_bar_text(width=120) - assert "$" in text # cost is shown when enabled + assert "$" not in text # cost is never shown in status bar def test_build_status_bar_text_collapses_for_narrow_terminal(self): cli_obj = _attach_agent( @@ -128,8 +137,8 @@ class TestCLIUsageReport: output = capsys.readouterr().out assert "Model:" in output - assert "Input cost:" in output - assert "Output cost:" in output + assert "Cost status:" in output + assert "Cost source:" in output assert "Total cost:" in output assert "$" in output assert "0.064" in output diff --git a/tests/test_hermes_state.py b/tests/test_hermes_state.py index d77247936..01d9c37ca 100644 --- a/tests/test_hermes_state.py +++ b/tests/test_hermes_state.py @@ -657,7 +657,7 @@ class TestSchemaInit: def test_schema_version(self, db): cursor = db._conn.execute("SELECT version FROM schema_version") version = cursor.fetchone()[0] - assert version == 4 + assert version == 5 def test_title_column_exists(self, db): """Verify the title column was created in the sessions table.""" @@ -713,12 +713,12 @@ class TestSchemaInit: conn.commit() conn.close() - # Open with SessionDB — should migrate to v4 + # Open with SessionDB — should migrate to v5 migrated_db = SessionDB(db_path=db_path) # Verify migration cursor = migrated_db._conn.execute("SELECT version FROM schema_version") - assert cursor.fetchone()[0] == 4 + assert cursor.fetchone()[0] == 5 # Verify title column exists and is NULL for existing sessions session = migrated_db.get_session("existing") diff --git a/tests/test_insights.py b/tests/test_insights.py index 6f6280a1d..af4f59829 100644 --- a/tests/test_insights.py +++ b/tests/test_insights.py @@ -123,28 +123,16 @@ def populated_db(db): # ========================================================================= class TestPricing: - def test_exact_match(self): - pricing = _get_pricing("gpt-4o") - assert pricing["input"] == 2.50 - assert pricing["output"] == 10.00 - def test_provider_prefix_stripped(self): pricing = _get_pricing("anthropic/claude-sonnet-4-20250514") assert pricing["input"] == 3.00 assert pricing["output"] == 15.00 - def test_prefix_match(self): - pricing = _get_pricing("claude-3-5-sonnet-20241022") - assert pricing["input"] == 3.00 - - def test_keyword_heuristic_opus(self): + def test_unknown_models_do_not_use_heuristics(self): pricing = _get_pricing("some-new-opus-model") - assert pricing["input"] == 15.00 - assert pricing["output"] == 75.00 - - def test_keyword_heuristic_haiku(self): + assert pricing == _DEFAULT_PRICING pricing = _get_pricing("anthropic/claude-haiku-future") - assert pricing["input"] == 0.80 + assert pricing == _DEFAULT_PRICING def test_unknown_model_returns_zero_cost(self): """Unknown/custom models should NOT have fabricated costs.""" @@ -168,40 +156,12 @@ class TestPricing: pricing = _get_pricing("") assert pricing == _DEFAULT_PRICING - def test_deepseek_heuristic(self): - pricing = _get_pricing("deepseek-v3") - assert pricing["input"] == 0.14 - - def test_gemini_heuristic(self): - pricing = _get_pricing("gemini-3.0-ultra") - assert pricing["input"] == 0.15 - - def test_dated_model_gpt4o_mini(self): - """gpt-4o-mini-2024-07-18 should match gpt-4o-mini, NOT gpt-4o.""" - pricing = _get_pricing("gpt-4o-mini-2024-07-18") - assert pricing["input"] == 0.15 # gpt-4o-mini price, not gpt-4o's 2.50 - - def test_dated_model_o3_mini(self): - """o3-mini-2025-01-31 should match o3-mini, NOT o3.""" - pricing = _get_pricing("o3-mini-2025-01-31") - assert pricing["input"] == 1.10 # o3-mini price, not o3's 10.00 - - def test_dated_model_gpt41_mini(self): - """gpt-4.1-mini-2025-04-14 should match gpt-4.1-mini, NOT gpt-4.1.""" - pricing = _get_pricing("gpt-4.1-mini-2025-04-14") - assert pricing["input"] == 0.40 # gpt-4.1-mini, not gpt-4.1's 2.00 - - def test_dated_model_gpt41_nano(self): - """gpt-4.1-nano-2025-04-14 should match gpt-4.1-nano, NOT gpt-4.1.""" - pricing = _get_pricing("gpt-4.1-nano-2025-04-14") - assert pricing["input"] == 0.10 # gpt-4.1-nano, not gpt-4.1's 2.00 - class TestHasKnownPricing: def test_known_commercial_model(self): - assert _has_known_pricing("gpt-4o") is True + assert _has_known_pricing("gpt-4o", provider="openai") is True assert _has_known_pricing("anthropic/claude-sonnet-4-20250514") is True - assert _has_known_pricing("deepseek-chat") is True + assert _has_known_pricing("gpt-4.1", provider="openai") is True def test_unknown_custom_model(self): assert _has_known_pricing("FP16_Hermes_4.5") is False @@ -210,26 +170,39 @@ class TestHasKnownPricing: assert _has_known_pricing("") is False assert _has_known_pricing(None) is False - def test_heuristic_matched_models(self): - """Models matched by keyword heuristics should be considered known.""" - assert _has_known_pricing("some-opus-model") is True - assert _has_known_pricing("future-sonnet-v2") is True + def test_heuristic_matched_models_are_not_considered_known(self): + assert _has_known_pricing("some-opus-model") is False + assert _has_known_pricing("future-sonnet-v2") is False class TestEstimateCost: def test_basic_cost(self): - # gpt-4o: 2.50/M input, 10.00/M output - cost = _estimate_cost("gpt-4o", 1_000_000, 1_000_000) - assert cost == pytest.approx(12.50, abs=0.01) + cost, status = _estimate_cost( + "anthropic/claude-sonnet-4-20250514", + 1_000_000, + 1_000_000, + provider="anthropic", + ) + assert status == "estimated" + assert cost == pytest.approx(18.0, abs=0.01) def test_zero_tokens(self): - cost = _estimate_cost("gpt-4o", 0, 0) + cost, status = _estimate_cost("gpt-4o", 0, 0, provider="openai") + assert status == "estimated" assert cost == 0.0 - def test_small_usage(self): - cost = _estimate_cost("gpt-4o", 1000, 500) - # 1000 * 2.50/1M + 500 * 10.00/1M = 0.0025 + 0.005 = 0.0075 - assert cost == pytest.approx(0.0075, abs=0.0001) + def test_cache_aware_usage(self): + cost, status = _estimate_cost( + "anthropic/claude-sonnet-4-20250514", + 1000, + 500, + cache_read_tokens=2000, + cache_write_tokens=400, + provider="anthropic", + ) + assert status == "estimated" + expected = (1000 * 3.0 + 500 * 15.0 + 2000 * 0.30 + 400 * 3.75) / 1_000_000 + assert cost == pytest.approx(expected, abs=0.0001) # ========================================================================= @@ -660,8 +633,13 @@ class TestEdgeCases: def test_mixed_commercial_and_custom_models(self, db): """Mix of commercial and custom models: only commercial ones get costs.""" - db.create_session(session_id="s1", source="cli", model="gpt-4o") - db.update_token_counts("s1", input_tokens=10000, output_tokens=5000) + db.create_session(session_id="s1", source="cli", model="anthropic/claude-sonnet-4-20250514") + db.update_token_counts( + "s1", + input_tokens=10000, + output_tokens=5000, + billing_provider="anthropic", + ) db.create_session(session_id="s2", source="cli", model="my-local-llama") db.update_token_counts("s2", input_tokens=10000, output_tokens=5000) db._conn.commit() @@ -672,13 +650,13 @@ class TestEdgeCases: # Cost should only come from gpt-4o, not from the custom model overview = report["overview"] assert overview["estimated_cost"] > 0 - assert "gpt-4o" in overview["models_with_pricing"] # list now, not set + assert "claude-sonnet-4-20250514" in overview["models_with_pricing"] # list now, not set assert "my-local-llama" in overview["models_without_pricing"] # Verify individual model entries - gpt = next(m for m in report["models"] if m["model"] == "gpt-4o") - assert gpt["has_pricing"] is True - assert gpt["cost"] > 0 + claude = next(m for m in report["models"] if m["model"] == "claude-sonnet-4-20250514") + assert claude["has_pricing"] is True + assert claude["cost"] > 0 llama = next(m for m in report["models"] if m["model"] == "my-local-llama") assert llama["has_pricing"] is False diff --git a/tests/test_interactive_interrupt.py b/tests/test_interactive_interrupt.py index c01404e1c..8c0d328c2 100644 --- a/tests/test_interactive_interrupt.py +++ b/tests/test_interactive_interrupt.py @@ -57,6 +57,7 @@ def main() -> int: parent._interrupt_requested = False parent._interrupt_message = None parent._active_children = [] + parent._active_children_lock = threading.Lock() parent.quiet_mode = True parent.model = "test/model" parent.base_url = "http://localhost:1" diff --git a/tests/test_interrupt_propagation.py b/tests/test_interrupt_propagation.py index ff1cafdc8..7f8cb01c3 100644 --- a/tests/test_interrupt_propagation.py +++ b/tests/test_interrupt_propagation.py @@ -30,12 +30,14 @@ class TestInterruptPropagationToChild(unittest.TestCase): parent._interrupt_requested = False parent._interrupt_message = None parent._active_children = [] + parent._active_children_lock = threading.Lock() parent.quiet_mode = True child = AIAgent.__new__(AIAgent) child._interrupt_requested = False child._interrupt_message = None child._active_children = [] + child._active_children_lock = threading.Lock() child.quiet_mode = True parent._active_children.append(child) @@ -60,6 +62,7 @@ class TestInterruptPropagationToChild(unittest.TestCase): child._interrupt_message = "msg" child.quiet_mode = True child._active_children = [] + child._active_children_lock = threading.Lock() # Global is set set_interrupt(True) @@ -78,6 +81,7 @@ class TestInterruptPropagationToChild(unittest.TestCase): child._interrupt_requested = False child._interrupt_message = None child._active_children = [] + child._active_children_lock = threading.Lock() child.quiet_mode = True child.api_mode = "chat_completions" child.log_prefix = "" @@ -119,12 +123,14 @@ class TestInterruptPropagationToChild(unittest.TestCase): parent._interrupt_requested = False parent._interrupt_message = None parent._active_children = [] + parent._active_children_lock = threading.Lock() parent.quiet_mode = True child = AIAgent.__new__(AIAgent) child._interrupt_requested = False child._interrupt_message = None child._active_children = [] + child._active_children_lock = threading.Lock() child.quiet_mode = True # Register child (simulating what _run_single_child does) diff --git a/tests/test_quick_commands.py b/tests/test_quick_commands.py index 9708b1fb3..7a89d4ca2 100644 --- a/tests/test_quick_commands.py +++ b/tests/test_quick_commands.py @@ -47,6 +47,28 @@ class TestCLIQuickCommands: args = cli.console.print.call_args[0][0] assert "no output" in args.lower() + def test_alias_command_routes_to_target(self): + """Alias quick commands rewrite to the target command.""" + cli = self._make_cli({"shortcut": {"type": "alias", "target": "/help"}}) + with patch.object(cli, "process_command", wraps=cli.process_command) as spy: + cli.process_command("/shortcut") + # Should recursively call process_command with /help + spy.assert_any_call("/help") + + def test_alias_command_passes_args(self): + """Alias quick commands forward user arguments to the target.""" + cli = self._make_cli({"sc": {"type": "alias", "target": "/context"}}) + with patch.object(cli, "process_command", wraps=cli.process_command) as spy: + cli.process_command("/sc some args") + spy.assert_any_call("/context some args") + + def test_alias_no_target_shows_error(self): + cli = self._make_cli({"broken": {"type": "alias", "target": ""}}) + cli.process_command("/broken") + cli.console.print.assert_called_once() + args = cli.console.print.call_args[0][0] + assert "no target defined" in args.lower() + def test_unsupported_type_shows_error(self): cli = self._make_cli({"bad": {"type": "prompt", "command": "echo hi"}}) cli.process_command("/bad") diff --git a/tests/test_real_interrupt_subagent.py b/tests/test_real_interrupt_subagent.py index f1a16753a..e0e681cdf 100644 --- a/tests/test_real_interrupt_subagent.py +++ b/tests/test_real_interrupt_subagent.py @@ -55,6 +55,7 @@ class TestRealSubagentInterrupt(unittest.TestCase): parent._interrupt_requested = False parent._interrupt_message = None parent._active_children = [] + parent._active_children_lock = threading.Lock() parent.quiet_mode = True parent.model = "test/model" parent.base_url = "http://localhost:1" @@ -103,19 +104,28 @@ class TestRealSubagentInterrupt(unittest.TestCase): return original_run(self_agent, *args, **kwargs) with patch.object(AIAgent, 'run_conversation', patched_run): + # Build a real child agent (AIAgent is NOT patched here, + # only run_conversation and _build_system_prompt are) + child = AIAgent( + base_url="http://localhost:1", + api_key="test-key", + model="test/model", + provider="test", + api_mode="chat_completions", + max_iterations=5, + enabled_toolsets=["terminal"], + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + platform="cli", + ) + child._delegate_depth = 1 + parent._active_children.append(child) result = _run_single_child( task_index=0, goal="Test task", - context=None, - toolsets=["terminal"], - model="test/model", - max_iterations=5, + child=child, parent_agent=parent, - task_count=1, - override_provider="test", - override_base_url="http://localhost:1", - override_api_key="test", - override_api_mode="chat_completions", ) result_holder[0] = result except Exception as e: diff --git a/tests/tools/test_delegate.py b/tests/tools/test_delegate.py index a29560b2c..476a2401b 100644 --- a/tests/tools/test_delegate.py +++ b/tests/tools/test_delegate.py @@ -12,6 +12,7 @@ Run with: python -m pytest tests/test_delegate.py -v import json import os import sys +import threading import unittest from unittest.mock import MagicMock, patch @@ -44,6 +45,7 @@ def _make_mock_parent(depth=0): parent._session_db = None parent._delegate_depth = depth parent._active_children = [] + parent._active_children_lock = threading.Lock() return parent @@ -722,7 +724,12 @@ class TestDelegationProviderIntegration(unittest.TestCase): } parent = _make_mock_parent(depth=0) - with patch("tools.delegate_tool._run_single_child") as mock_run: + # Patch _build_child_agent since credentials are now passed there + # (agents are built in the main thread before being handed to workers) + with patch("tools.delegate_tool._build_child_agent") as mock_build, \ + patch("tools.delegate_tool._run_single_child") as mock_run: + mock_child = MagicMock() + mock_build.return_value = mock_child mock_run.return_value = { "task_index": 0, "status": "completed", "summary": "Done", "api_calls": 1, "duration_seconds": 1.0 @@ -731,7 +738,8 @@ class TestDelegationProviderIntegration(unittest.TestCase): tasks = [{"goal": "Task A"}, {"goal": "Task B"}] delegate_task(tasks=tasks, parent_agent=parent) - for call in mock_run.call_args_list: + self.assertEqual(mock_build.call_count, 2) + for call in mock_build.call_args_list: self.assertEqual(call.kwargs.get("model"), "meta-llama/llama-4-scout") self.assertEqual(call.kwargs.get("override_provider"), "openrouter") self.assertEqual(call.kwargs.get("override_base_url"), "https://openrouter.ai/api/v1") diff --git a/tests/tools/test_mcp_probe.py b/tests/tools/test_mcp_probe.py new file mode 100644 index 000000000..a592c5dca --- /dev/null +++ b/tests/tools/test_mcp_probe.py @@ -0,0 +1,210 @@ +"""Tests for probe_mcp_server_tools() in tools.mcp_tool.""" + +import asyncio +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +@pytest.fixture(autouse=True) +def _reset_mcp_state(): + """Ensure clean MCP module state before/after each test.""" + import tools.mcp_tool as mcp + old_loop = mcp._mcp_loop + old_thread = mcp._mcp_thread + old_servers = dict(mcp._servers) + yield + mcp._servers.clear() + mcp._servers.update(old_servers) + mcp._mcp_loop = old_loop + mcp._mcp_thread = old_thread + + +class TestProbeMcpServerTools: + """Tests for the lightweight probe_mcp_server_tools function.""" + + def test_returns_empty_when_mcp_not_available(self): + with patch("tools.mcp_tool._MCP_AVAILABLE", False): + from tools.mcp_tool import probe_mcp_server_tools + result = probe_mcp_server_tools() + assert result == {} + + def test_returns_empty_when_no_config(self): + with patch("tools.mcp_tool._load_mcp_config", return_value={}): + from tools.mcp_tool import probe_mcp_server_tools + result = probe_mcp_server_tools() + assert result == {} + + def test_returns_empty_when_all_servers_disabled(self): + config = { + "github": {"command": "npx", "enabled": False}, + "slack": {"command": "npx", "enabled": "off"}, + } + with patch("tools.mcp_tool._load_mcp_config", return_value=config): + from tools.mcp_tool import probe_mcp_server_tools + result = probe_mcp_server_tools() + assert result == {} + + def test_returns_tools_from_successful_server(self): + """Successfully probed server returns its tools list.""" + config = { + "github": {"command": "npx", "connect_timeout": 5}, + } + mock_tool_1 = SimpleNamespace(name="create_issue", description="Create a new issue") + mock_tool_2 = SimpleNamespace(name="search_repos", description="Search repositories") + + mock_server = MagicMock() + mock_server._tools = [mock_tool_1, mock_tool_2] + mock_server.shutdown = AsyncMock() + + async def fake_connect(name, cfg): + return mock_server + + with patch("tools.mcp_tool._load_mcp_config", return_value=config), \ + patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ + patch("tools.mcp_tool._ensure_mcp_loop"), \ + patch("tools.mcp_tool._run_on_mcp_loop") as mock_run, \ + patch("tools.mcp_tool._stop_mcp_loop"): + + # Simulate running the async probe + def run_coro(coro, timeout=120): + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + mock_run.side_effect = run_coro + + from tools.mcp_tool import probe_mcp_server_tools + result = probe_mcp_server_tools() + + assert "github" in result + assert len(result["github"]) == 2 + assert result["github"][0] == ("create_issue", "Create a new issue") + assert result["github"][1] == ("search_repos", "Search repositories") + mock_server.shutdown.assert_awaited_once() + + def test_failed_server_omitted_from_results(self): + """Servers that fail to connect are silently skipped.""" + config = { + "github": {"command": "npx", "connect_timeout": 5}, + "broken": {"command": "nonexistent", "connect_timeout": 5}, + } + mock_tool = SimpleNamespace(name="create_issue", description="Create") + mock_server = MagicMock() + mock_server._tools = [mock_tool] + mock_server.shutdown = AsyncMock() + + async def fake_connect(name, cfg): + if name == "broken": + raise ConnectionError("Server not found") + return mock_server + + with patch("tools.mcp_tool._load_mcp_config", return_value=config), \ + patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ + patch("tools.mcp_tool._ensure_mcp_loop"), \ + patch("tools.mcp_tool._run_on_mcp_loop") as mock_run, \ + patch("tools.mcp_tool._stop_mcp_loop"): + + def run_coro(coro, timeout=120): + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + mock_run.side_effect = run_coro + + from tools.mcp_tool import probe_mcp_server_tools + result = probe_mcp_server_tools() + + assert "github" in result + assert "broken" not in result + + def test_handles_tool_without_description(self): + """Tools without descriptions get empty string.""" + config = {"github": {"command": "npx", "connect_timeout": 5}} + mock_tool = SimpleNamespace(name="my_tool") # no description attribute + + mock_server = MagicMock() + mock_server._tools = [mock_tool] + mock_server.shutdown = AsyncMock() + + async def fake_connect(name, cfg): + return mock_server + + with patch("tools.mcp_tool._load_mcp_config", return_value=config), \ + patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ + patch("tools.mcp_tool._ensure_mcp_loop"), \ + patch("tools.mcp_tool._run_on_mcp_loop") as mock_run, \ + patch("tools.mcp_tool._stop_mcp_loop"): + + def run_coro(coro, timeout=120): + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + mock_run.side_effect = run_coro + + from tools.mcp_tool import probe_mcp_server_tools + result = probe_mcp_server_tools() + + assert result["github"][0] == ("my_tool", "") + + def test_cleanup_called_even_on_failure(self): + """_stop_mcp_loop is called even when probe fails.""" + config = {"github": {"command": "npx", "connect_timeout": 5}} + + with patch("tools.mcp_tool._load_mcp_config", return_value=config), \ + patch("tools.mcp_tool._ensure_mcp_loop"), \ + patch("tools.mcp_tool._run_on_mcp_loop", side_effect=RuntimeError("boom")), \ + patch("tools.mcp_tool._stop_mcp_loop") as mock_stop: + + from tools.mcp_tool import probe_mcp_server_tools + result = probe_mcp_server_tools() + + assert result == {} + mock_stop.assert_called_once() + + def test_skips_disabled_servers(self): + """Disabled servers are not probed.""" + config = { + "github": {"command": "npx", "connect_timeout": 5}, + "disabled_one": {"command": "npx", "enabled": False}, + } + mock_tool = SimpleNamespace(name="create_issue", description="Create") + mock_server = MagicMock() + mock_server._tools = [mock_tool] + mock_server.shutdown = AsyncMock() + + connect_calls = [] + + async def fake_connect(name, cfg): + connect_calls.append(name) + return mock_server + + with patch("tools.mcp_tool._load_mcp_config", return_value=config), \ + patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ + patch("tools.mcp_tool._ensure_mcp_loop"), \ + patch("tools.mcp_tool._run_on_mcp_loop") as mock_run, \ + patch("tools.mcp_tool._stop_mcp_loop"): + + def run_coro(coro, timeout=120): + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + mock_run.side_effect = run_coro + + from tools.mcp_tool import probe_mcp_server_tools + result = probe_mcp_server_tools() + + assert "github" in result + assert "disabled_one" not in result + assert "disabled_one" not in connect_calls diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index 3796d8ced..9c49bd2c2 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -2596,17 +2596,19 @@ class TestMCPSelectiveToolLoading: async def run(): with patch("tools.mcp_tool._connect_server", side_effect=fake_connect), \ + patch.dict("tools.mcp_tool._servers", {}, clear=True), \ patch("tools.registry.registry", mock_registry), \ patch("toolsets.create_custom_toolset"): - return await _discover_and_register_server( + registered = await _discover_and_register_server( "ink_existing", {"url": "https://mcp.example.com", "tools": {"include": ["create_service"]}}, ) + return registered, _existing_tool_names() try: - registered = asyncio.run(run()) + registered, existing = asyncio.run(run()) assert registered == ["mcp_ink_existing_create_service"] - assert _existing_tool_names() == ["mcp_ink_existing_create_service"] + assert existing == ["mcp_ink_existing_create_service"] finally: _servers.pop("ink_existing", None) diff --git a/tests/tools/test_process_registry.py b/tests/tools/test_process_registry.py index 7ebe94c04..e6cfa40e7 100644 --- a/tests/tools/test_process_registry.py +++ b/tests/tools/test_process_registry.py @@ -294,6 +294,61 @@ class TestCheckpoint: recovered = registry.recover_from_checkpoint() assert recovered == 0 + def test_write_checkpoint_includes_watcher_metadata(self, registry, tmp_path): + with patch("tools.process_registry.CHECKPOINT_PATH", tmp_path / "procs.json"): + s = _make_session() + s.watcher_platform = "telegram" + s.watcher_chat_id = "999" + s.watcher_thread_id = "42" + s.watcher_interval = 60 + registry._running[s.id] = s + registry._write_checkpoint() + + data = json.loads((tmp_path / "procs.json").read_text()) + assert len(data) == 1 + assert data[0]["watcher_platform"] == "telegram" + assert data[0]["watcher_chat_id"] == "999" + assert data[0]["watcher_thread_id"] == "42" + assert data[0]["watcher_interval"] == 60 + + def test_recover_enqueues_watchers(self, registry, tmp_path): + checkpoint = tmp_path / "procs.json" + checkpoint.write_text(json.dumps([{ + "session_id": "proc_live", + "command": "sleep 999", + "pid": os.getpid(), # current process — guaranteed alive + "task_id": "t1", + "session_key": "sk1", + "watcher_platform": "telegram", + "watcher_chat_id": "123", + "watcher_thread_id": "42", + "watcher_interval": 60, + }])) + with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint): + recovered = registry.recover_from_checkpoint() + assert recovered == 1 + assert len(registry.pending_watchers) == 1 + w = registry.pending_watchers[0] + assert w["session_id"] == "proc_live" + assert w["platform"] == "telegram" + assert w["chat_id"] == "123" + assert w["thread_id"] == "42" + assert w["check_interval"] == 60 + + def test_recover_skips_watcher_when_no_interval(self, registry, tmp_path): + checkpoint = tmp_path / "procs.json" + checkpoint.write_text(json.dumps([{ + "session_id": "proc_live", + "command": "sleep 999", + "pid": os.getpid(), + "task_id": "t1", + "watcher_interval": 0, + }])) + with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint): + recovered = registry.recover_from_checkpoint() + assert recovered == 1 + assert len(registry.pending_watchers) == 0 + # ========================================================================= # Kill process diff --git a/tests/tools/test_send_message_tool.py b/tests/tools/test_send_message_tool.py index 7ef9b149d..2b03847e5 100644 --- a/tests/tools/test_send_message_tool.py +++ b/tests/tools/test_send_message_tool.py @@ -25,7 +25,7 @@ def _make_config(): def _install_telegram_mock(monkeypatch, bot): - parse_mode = SimpleNamespace(MARKDOWN_V2="MarkdownV2") + parse_mode = SimpleNamespace(MARKDOWN_V2="MarkdownV2", HTML="HTML") constants_mod = SimpleNamespace(ParseMode=parse_mode) telegram_mod = SimpleNamespace(Bot=lambda token: bot, constants=constants_mod) monkeypatch.setitem(sys.modules, "telegram", telegram_mod) @@ -391,3 +391,97 @@ class TestSendToPlatformChunking: assert len(sent_calls) >= 3 assert all(call == [] for call in sent_calls[:-1]) assert sent_calls[-1] == media + + +# --------------------------------------------------------------------------- +# HTML auto-detection in Telegram send +# --------------------------------------------------------------------------- + + +class TestSendTelegramHtmlDetection: + """Verify that messages containing HTML tags are sent with parse_mode=HTML + and that plain / markdown messages use MarkdownV2.""" + + def _make_bot(self): + bot = MagicMock() + bot.send_message = AsyncMock(return_value=SimpleNamespace(message_id=1)) + bot.send_photo = AsyncMock() + bot.send_video = AsyncMock() + bot.send_voice = AsyncMock() + bot.send_audio = AsyncMock() + bot.send_document = AsyncMock() + return bot + + def test_html_message_uses_html_parse_mode(self, monkeypatch): + bot = self._make_bot() + _install_telegram_mock(monkeypatch, bot) + + asyncio.run( + _send_telegram("tok", "123", "Hello world") + ) + + bot.send_message.assert_awaited_once() + kwargs = bot.send_message.await_args.kwargs + assert kwargs["parse_mode"] == "HTML" + assert kwargs["text"] == "Hello world" + + def test_plain_text_uses_markdown_v2(self, monkeypatch): + bot = self._make_bot() + _install_telegram_mock(monkeypatch, bot) + + asyncio.run( + _send_telegram("tok", "123", "Just plain text, no tags") + ) + + bot.send_message.assert_awaited_once() + kwargs = bot.send_message.await_args.kwargs + assert kwargs["parse_mode"] == "MarkdownV2" + + def test_html_with_code_and_pre_tags(self, monkeypatch): + bot = self._make_bot() + _install_telegram_mock(monkeypatch, bot) + + html = "

code block
and inline" + asyncio.run(_send_telegram("tok", "123", html)) + + kwargs = bot.send_message.await_args.kwargs + assert kwargs["parse_mode"] == "HTML" + + def test_closing_tag_detected(self, monkeypatch): + bot = self._make_bot() + _install_telegram_mock(monkeypatch, bot) + + asyncio.run(_send_telegram("tok", "123", "text more")) + + kwargs = bot.send_message.await_args.kwargs + assert kwargs["parse_mode"] == "HTML" + + def test_angle_brackets_in_math_not_detected(self, monkeypatch): + """Expressions like 'x < 5' or '3 > 2' should not trigger HTML mode.""" + bot = self._make_bot() + _install_telegram_mock(monkeypatch, bot) + + asyncio.run(_send_telegram("tok", "123", "if x < 5 then y > 2")) + + kwargs = bot.send_message.await_args.kwargs + assert kwargs["parse_mode"] == "MarkdownV2" + + def test_html_parse_failure_falls_back_to_plain(self, monkeypatch): + """If Telegram rejects the HTML, fall back to plain text.""" + bot = self._make_bot() + bot.send_message = AsyncMock( + side_effect=[ + Exception("Bad Request: can't parse entities: unsupported html tag"), + SimpleNamespace(message_id=2), # plain fallback succeeds + ] + ) + _install_telegram_mock(monkeypatch, bot) + + result = asyncio.run( + _send_telegram("tok", "123", "broken html") + ) + + assert result["success"] is True + assert bot.send_message.await_count == 2 + second_call = bot.send_message.await_args_list[1].kwargs + assert second_call["parse_mode"] is None diff --git a/tests/tools/test_web_tools_config.py b/tests/tools/test_web_tools_config.py index 4bc49166f..d291a005b 100644 --- a/tests/tools/test_web_tools_config.py +++ b/tests/tools/test_web_tools_config.py @@ -1,8 +1,11 @@ -"""Tests for Firecrawl client configuration and singleton behavior. +"""Tests for web backend client configuration and singleton behavior. Coverage: _get_firecrawl_client() — configuration matrix, singleton caching, constructor failure recovery, return value verification, edge cases. + _get_backend() — backend selection logic with env var combinations. + _get_parallel_client() — Parallel client configuration, singleton caching. + check_web_api_key() — unified availability check. """ import os @@ -117,3 +120,212 @@ class TestFirecrawlClientConfig: from tools.web_tools import _get_firecrawl_client with pytest.raises(ValueError): _get_firecrawl_client() + + +class TestBackendSelection: + """Test suite for _get_backend() backend selection logic. + + The backend is configured via config.yaml (web.backend), set by + ``hermes tools``. Falls back to key-based detection for legacy/manual + setups. + """ + + _ENV_KEYS = ("PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "FIRECRAWL_API_URL", "TAVILY_API_KEY") + + def setup_method(self): + for key in self._ENV_KEYS: + os.environ.pop(key, None) + + def teardown_method(self): + for key in self._ENV_KEYS: + os.environ.pop(key, None) + + # ── Config-based selection (web.backend in config.yaml) ─────────── + + def test_config_parallel(self): + """web.backend=parallel in config → 'parallel' regardless of keys.""" + from tools.web_tools import _get_backend + with patch("tools.web_tools._load_web_config", return_value={"backend": "parallel"}): + assert _get_backend() == "parallel" + + def test_config_firecrawl(self): + """web.backend=firecrawl in config → 'firecrawl' even if Parallel key set.""" + from tools.web_tools import _get_backend + with patch("tools.web_tools._load_web_config", return_value={"backend": "firecrawl"}), \ + patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}): + assert _get_backend() == "firecrawl" + + def test_config_tavily(self): + """web.backend=tavily in config → 'tavily' regardless of other keys.""" + from tools.web_tools import _get_backend + with patch("tools.web_tools._load_web_config", return_value={"backend": "tavily"}): + assert _get_backend() == "tavily" + + def test_config_tavily_overrides_env_keys(self): + """web.backend=tavily in config → 'tavily' even if Firecrawl key set.""" + from tools.web_tools import _get_backend + with patch("tools.web_tools._load_web_config", return_value={"backend": "tavily"}), \ + patch.dict(os.environ, {"FIRECRAWL_API_KEY": "fc-test"}): + assert _get_backend() == "tavily" + + def test_config_case_insensitive(self): + """web.backend=Parallel (mixed case) → 'parallel'.""" + from tools.web_tools import _get_backend + with patch("tools.web_tools._load_web_config", return_value={"backend": "Parallel"}): + assert _get_backend() == "parallel" + + def test_config_tavily_case_insensitive(self): + """web.backend=Tavily (mixed case) → 'tavily'.""" + from tools.web_tools import _get_backend + with patch("tools.web_tools._load_web_config", return_value={"backend": "Tavily"}): + assert _get_backend() == "tavily" + + # ── Fallback (no web.backend in config) ─────────────────────────── + + def test_fallback_parallel_only_key(self): + """Only PARALLEL_API_KEY set → 'parallel'.""" + from tools.web_tools import _get_backend + with patch("tools.web_tools._load_web_config", return_value={}), \ + patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}): + assert _get_backend() == "parallel" + + def test_fallback_tavily_only_key(self): + """Only TAVILY_API_KEY set → 'tavily'.""" + from tools.web_tools import _get_backend + with patch("tools.web_tools._load_web_config", return_value={}), \ + patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}): + assert _get_backend() == "tavily" + + def test_fallback_tavily_with_firecrawl_prefers_firecrawl(self): + """Tavily + Firecrawl keys, no config → 'firecrawl' (backward compat).""" + from tools.web_tools import _get_backend + with patch("tools.web_tools._load_web_config", return_value={}), \ + patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test", "FIRECRAWL_API_KEY": "fc-test"}): + assert _get_backend() == "firecrawl" + + def test_fallback_tavily_with_parallel_prefers_parallel(self): + """Tavily + Parallel keys, no config → 'parallel' (Parallel takes priority over Tavily).""" + from tools.web_tools import _get_backend + with patch("tools.web_tools._load_web_config", return_value={}), \ + patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test", "PARALLEL_API_KEY": "par-test"}): + # Parallel + no Firecrawl → parallel + assert _get_backend() == "parallel" + + def test_fallback_both_keys_defaults_to_firecrawl(self): + """Both keys set, no config → 'firecrawl' (backward compat).""" + from tools.web_tools import _get_backend + with patch("tools.web_tools._load_web_config", return_value={}), \ + patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key", "FIRECRAWL_API_KEY": "fc-test"}): + assert _get_backend() == "firecrawl" + + def test_fallback_firecrawl_only_key(self): + """Only FIRECRAWL_API_KEY set → 'firecrawl'.""" + from tools.web_tools import _get_backend + with patch("tools.web_tools._load_web_config", return_value={}), \ + patch.dict(os.environ, {"FIRECRAWL_API_KEY": "fc-test"}): + assert _get_backend() == "firecrawl" + + def test_fallback_no_keys_defaults_to_firecrawl(self): + """No keys, no config → 'firecrawl' (will fail at client init).""" + from tools.web_tools import _get_backend + with patch("tools.web_tools._load_web_config", return_value={}): + assert _get_backend() == "firecrawl" + + def test_invalid_config_falls_through_to_fallback(self): + """web.backend=invalid → ignored, uses key-based fallback.""" + from tools.web_tools import _get_backend + with patch("tools.web_tools._load_web_config", return_value={"backend": "nonexistent"}), \ + patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}): + assert _get_backend() == "parallel" + + +class TestParallelClientConfig: + """Test suite for Parallel client initialization.""" + + def setup_method(self): + import tools.web_tools + tools.web_tools._parallel_client = None + os.environ.pop("PARALLEL_API_KEY", None) + + def teardown_method(self): + import tools.web_tools + tools.web_tools._parallel_client = None + os.environ.pop("PARALLEL_API_KEY", None) + + def test_creates_client_with_key(self): + """PARALLEL_API_KEY set → creates Parallel client.""" + with patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}): + from tools.web_tools import _get_parallel_client + from parallel import Parallel + client = _get_parallel_client() + assert client is not None + assert isinstance(client, Parallel) + + def test_no_key_raises_with_helpful_message(self): + """No PARALLEL_API_KEY → ValueError with guidance.""" + from tools.web_tools import _get_parallel_client + with pytest.raises(ValueError, match="PARALLEL_API_KEY"): + _get_parallel_client() + + def test_singleton_returns_same_instance(self): + """Second call returns cached client.""" + with patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}): + from tools.web_tools import _get_parallel_client + client1 = _get_parallel_client() + client2 = _get_parallel_client() + assert client1 is client2 + + +class TestCheckWebApiKey: + """Test suite for check_web_api_key() unified availability check.""" + + _ENV_KEYS = ("PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "FIRECRAWL_API_URL", "TAVILY_API_KEY") + + def setup_method(self): + for key in self._ENV_KEYS: + os.environ.pop(key, None) + + def teardown_method(self): + for key in self._ENV_KEYS: + os.environ.pop(key, None) + + def test_parallel_key_only(self): + with patch.dict(os.environ, {"PARALLEL_API_KEY": "test-key"}): + from tools.web_tools import check_web_api_key + assert check_web_api_key() is True + + def test_firecrawl_key_only(self): + with patch.dict(os.environ, {"FIRECRAWL_API_KEY": "fc-test"}): + from tools.web_tools import check_web_api_key + assert check_web_api_key() is True + + def test_firecrawl_url_only(self): + with patch.dict(os.environ, {"FIRECRAWL_API_URL": "http://localhost:3002"}): + from tools.web_tools import check_web_api_key + assert check_web_api_key() is True + + def test_tavily_key_only(self): + with patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}): + from tools.web_tools import check_web_api_key + assert check_web_api_key() is True + + def test_no_keys_returns_false(self): + from tools.web_tools import check_web_api_key + assert check_web_api_key() is False + + def test_both_keys_returns_true(self): + with patch.dict(os.environ, { + "PARALLEL_API_KEY": "test-key", + "FIRECRAWL_API_KEY": "fc-test", + }): + from tools.web_tools import check_web_api_key + assert check_web_api_key() is True + + def test_all_three_keys_returns_true(self): + with patch.dict(os.environ, { + "PARALLEL_API_KEY": "test-key", + "FIRECRAWL_API_KEY": "fc-test", + "TAVILY_API_KEY": "tvly-test", + }): + from tools.web_tools import check_web_api_key + assert check_web_api_key() is True diff --git a/tests/tools/test_web_tools_tavily.py b/tests/tools/test_web_tools_tavily.py new file mode 100644 index 000000000..2e49b72f1 --- /dev/null +++ b/tests/tools/test_web_tools_tavily.py @@ -0,0 +1,255 @@ +"""Tests for Tavily web backend integration. + +Coverage: + _tavily_request() — API key handling, endpoint construction, error propagation. + _normalize_tavily_search_results() — search response normalization. + _normalize_tavily_documents() — extract/crawl response normalization, failed_results. + web_search_tool / web_extract_tool / web_crawl_tool — Tavily dispatch paths. +""" + +import json +import os +import asyncio +import pytest +from unittest.mock import patch, MagicMock + + +# ─── _tavily_request ───────────────────────────────────────────────────────── + +class TestTavilyRequest: + """Test suite for the _tavily_request helper.""" + + def test_raises_without_api_key(self): + """No TAVILY_API_KEY → ValueError with guidance.""" + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("TAVILY_API_KEY", None) + from tools.web_tools import _tavily_request + with pytest.raises(ValueError, match="TAVILY_API_KEY"): + _tavily_request("search", {"query": "test"}) + + def test_posts_with_api_key_in_body(self): + """api_key is injected into the JSON payload.""" + mock_response = MagicMock() + mock_response.json.return_value = {"results": []} + mock_response.raise_for_status = MagicMock() + + with patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test-key"}): + with patch("tools.web_tools.httpx.post", return_value=mock_response) as mock_post: + from tools.web_tools import _tavily_request + result = _tavily_request("search", {"query": "hello"}) + + mock_post.assert_called_once() + call_kwargs = mock_post.call_args + payload = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json") + assert payload["api_key"] == "tvly-test-key" + assert payload["query"] == "hello" + assert "api.tavily.com/search" in call_kwargs.args[0] + + def test_raises_on_http_error(self): + """Non-2xx responses propagate as httpx.HTTPStatusError.""" + import httpx as _httpx + mock_response = MagicMock() + mock_response.raise_for_status.side_effect = _httpx.HTTPStatusError( + "401 Unauthorized", request=MagicMock(), response=mock_response + ) + + with patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-bad-key"}): + with patch("tools.web_tools.httpx.post", return_value=mock_response): + from tools.web_tools import _tavily_request + with pytest.raises(_httpx.HTTPStatusError): + _tavily_request("search", {"query": "test"}) + + +# ─── _normalize_tavily_search_results ───────────────────────────────────────── + +class TestNormalizeTavilySearchResults: + """Test search result normalization.""" + + def test_basic_normalization(self): + from tools.web_tools import _normalize_tavily_search_results + raw = { + "results": [ + {"title": "Python Docs", "url": "https://docs.python.org", "content": "Official docs", "score": 0.9}, + {"title": "Tutorial", "url": "https://example.com", "content": "A tutorial", "score": 0.8}, + ] + } + result = _normalize_tavily_search_results(raw) + assert result["success"] is True + web = result["data"]["web"] + assert len(web) == 2 + assert web[0]["title"] == "Python Docs" + assert web[0]["url"] == "https://docs.python.org" + assert web[0]["description"] == "Official docs" + assert web[0]["position"] == 1 + assert web[1]["position"] == 2 + + def test_empty_results(self): + from tools.web_tools import _normalize_tavily_search_results + result = _normalize_tavily_search_results({"results": []}) + assert result["success"] is True + assert result["data"]["web"] == [] + + def test_missing_fields(self): + from tools.web_tools import _normalize_tavily_search_results + result = _normalize_tavily_search_results({"results": [{}]}) + web = result["data"]["web"] + assert web[0]["title"] == "" + assert web[0]["url"] == "" + assert web[0]["description"] == "" + + +# ─── _normalize_tavily_documents ────────────────────────────────────────────── + +class TestNormalizeTavilyDocuments: + """Test extract/crawl document normalization.""" + + def test_basic_document(self): + from tools.web_tools import _normalize_tavily_documents + raw = { + "results": [{ + "url": "https://example.com", + "title": "Example", + "raw_content": "Full page content here", + }] + } + docs = _normalize_tavily_documents(raw) + assert len(docs) == 1 + assert docs[0]["url"] == "https://example.com" + assert docs[0]["title"] == "Example" + assert docs[0]["content"] == "Full page content here" + assert docs[0]["raw_content"] == "Full page content here" + assert docs[0]["metadata"]["sourceURL"] == "https://example.com" + + def test_falls_back_to_content_when_no_raw_content(self): + from tools.web_tools import _normalize_tavily_documents + raw = {"results": [{"url": "https://example.com", "content": "Snippet"}]} + docs = _normalize_tavily_documents(raw) + assert docs[0]["content"] == "Snippet" + + def test_failed_results_included(self): + from tools.web_tools import _normalize_tavily_documents + raw = { + "results": [], + "failed_results": [ + {"url": "https://fail.com", "error": "timeout"}, + ], + } + docs = _normalize_tavily_documents(raw) + assert len(docs) == 1 + assert docs[0]["url"] == "https://fail.com" + assert docs[0]["error"] == "timeout" + assert docs[0]["content"] == "" + + def test_failed_urls_included(self): + from tools.web_tools import _normalize_tavily_documents + raw = { + "results": [], + "failed_urls": ["https://bad.com"], + } + docs = _normalize_tavily_documents(raw) + assert len(docs) == 1 + assert docs[0]["url"] == "https://bad.com" + assert docs[0]["error"] == "extraction failed" + + def test_fallback_url(self): + from tools.web_tools import _normalize_tavily_documents + raw = {"results": [{"content": "data"}]} + docs = _normalize_tavily_documents(raw, fallback_url="https://fallback.com") + assert docs[0]["url"] == "https://fallback.com" + + +# ─── web_search_tool (Tavily dispatch) ──────────────────────────────────────── + +class TestWebSearchTavily: + """Test web_search_tool dispatch to Tavily.""" + + def test_search_dispatches_to_tavily(self): + mock_response = MagicMock() + mock_response.json.return_value = { + "results": [{"title": "Result", "url": "https://r.com", "content": "desc", "score": 0.9}] + } + mock_response.raise_for_status = MagicMock() + + with patch("tools.web_tools._get_backend", return_value="tavily"), \ + patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}), \ + patch("tools.web_tools.httpx.post", return_value=mock_response), \ + patch("tools.interrupt.is_interrupted", return_value=False): + from tools.web_tools import web_search_tool + result = json.loads(web_search_tool("test query", limit=3)) + assert result["success"] is True + assert len(result["data"]["web"]) == 1 + assert result["data"]["web"][0]["title"] == "Result" + + +# ─── web_extract_tool (Tavily dispatch) ─────────────────────────────────────── + +class TestWebExtractTavily: + """Test web_extract_tool dispatch to Tavily.""" + + def test_extract_dispatches_to_tavily(self): + mock_response = MagicMock() + mock_response.json.return_value = { + "results": [{"url": "https://example.com", "raw_content": "Extracted content", "title": "Page"}] + } + mock_response.raise_for_status = MagicMock() + + with patch("tools.web_tools._get_backend", return_value="tavily"), \ + patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}), \ + patch("tools.web_tools.httpx.post", return_value=mock_response), \ + patch("tools.web_tools.process_content_with_llm", return_value=None): + from tools.web_tools import web_extract_tool + result = json.loads(asyncio.get_event_loop().run_until_complete( + web_extract_tool(["https://example.com"], use_llm_processing=False) + )) + assert "results" in result + assert len(result["results"]) == 1 + assert result["results"][0]["url"] == "https://example.com" + + +# ─── web_crawl_tool (Tavily dispatch) ───────────────────────────────────────── + +class TestWebCrawlTavily: + """Test web_crawl_tool dispatch to Tavily.""" + + def test_crawl_dispatches_to_tavily(self): + mock_response = MagicMock() + mock_response.json.return_value = { + "results": [ + {"url": "https://example.com/page1", "raw_content": "Page 1 content", "title": "Page 1"}, + {"url": "https://example.com/page2", "raw_content": "Page 2 content", "title": "Page 2"}, + ] + } + mock_response.raise_for_status = MagicMock() + + with patch("tools.web_tools._get_backend", return_value="tavily"), \ + patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}), \ + patch("tools.web_tools.httpx.post", return_value=mock_response), \ + patch("tools.web_tools.check_website_access", return_value=None), \ + patch("tools.interrupt.is_interrupted", return_value=False): + from tools.web_tools import web_crawl_tool + result = json.loads(asyncio.get_event_loop().run_until_complete( + web_crawl_tool("https://example.com", use_llm_processing=False) + )) + assert "results" in result + assert len(result["results"]) == 2 + assert result["results"][0]["title"] == "Page 1" + + def test_crawl_sends_instructions(self): + """Instructions are included in the Tavily crawl payload.""" + mock_response = MagicMock() + mock_response.json.return_value = {"results": []} + mock_response.raise_for_status = MagicMock() + + with patch("tools.web_tools._get_backend", return_value="tavily"), \ + patch.dict(os.environ, {"TAVILY_API_KEY": "tvly-test"}), \ + patch("tools.web_tools.httpx.post", return_value=mock_response) as mock_post, \ + patch("tools.web_tools.check_website_access", return_value=None), \ + patch("tools.interrupt.is_interrupted", return_value=False): + from tools.web_tools import web_crawl_tool + asyncio.get_event_loop().run_until_complete( + web_crawl_tool("https://example.com", instructions="Find docs", use_llm_processing=False) + ) + call_kwargs = mock_post.call_args + payload = call_kwargs.kwargs.get("json") or call_kwargs[1].get("json") + assert payload["instructions"] == "Find docs" + assert payload["url"] == "https://example.com" diff --git a/tests/tools/test_website_policy.py b/tests/tools/test_website_policy.py new file mode 100644 index 000000000..9d620b59a --- /dev/null +++ b/tests/tools/test_website_policy.py @@ -0,0 +1,495 @@ +import json +from pathlib import Path + +import pytest +import yaml + +from tools.website_policy import WebsitePolicyError, check_website_access, load_website_blocklist + + +def test_load_website_blocklist_merges_config_and_shared_file(tmp_path): + shared = tmp_path / "community-blocklist.txt" + shared.write_text("# comment\nexample.org\nsub.bad.net\n", encoding="utf-8") + + config_path = tmp_path / "config.yaml" + config_path.write_text( + yaml.safe_dump( + { + "security": { + "website_blocklist": { + "enabled": True, + "domains": ["example.com", "https://www.evil.test/path"], + "shared_files": [str(shared)], + } + } + }, + sort_keys=False, + ), + encoding="utf-8", + ) + + policy = load_website_blocklist(config_path) + + assert policy["enabled"] is True + assert {rule["pattern"] for rule in policy["rules"]} == { + "example.com", + "evil.test", + "example.org", + "sub.bad.net", + } + + +def test_check_website_access_matches_parent_domain_subdomains(tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text( + yaml.safe_dump( + { + "security": { + "website_blocklist": { + "enabled": True, + "domains": ["example.com"], + } + } + }, + sort_keys=False, + ), + encoding="utf-8", + ) + + blocked = check_website_access("https://docs.example.com/page", config_path=config_path) + + assert blocked is not None + assert blocked["host"] == "docs.example.com" + assert blocked["rule"] == "example.com" + + +def test_check_website_access_supports_wildcard_subdomains_only(tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text( + yaml.safe_dump( + { + "security": { + "website_blocklist": { + "enabled": True, + "domains": ["*.tracking.example"], + } + } + }, + sort_keys=False, + ), + encoding="utf-8", + ) + + assert check_website_access("https://a.tracking.example", config_path=config_path) is not None + assert check_website_access("https://www.tracking.example", config_path=config_path) is not None + assert check_website_access("https://tracking.example", config_path=config_path) is None + + +def test_default_config_exposes_website_blocklist_shape(): + from hermes_cli.config import DEFAULT_CONFIG + + website_blocklist = DEFAULT_CONFIG["security"]["website_blocklist"] + assert website_blocklist["enabled"] is False + assert website_blocklist["domains"] == [] + assert website_blocklist["shared_files"] == [] + + +def test_load_website_blocklist_uses_enabled_default_when_section_missing(tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text(yaml.safe_dump({"display": {"tool_progress": "all"}}, sort_keys=False), encoding="utf-8") + + policy = load_website_blocklist(config_path) + + assert policy == {"enabled": False, "rules": []} + + +def test_load_website_blocklist_raises_clean_error_for_invalid_domains_type(tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text( + yaml.safe_dump( + { + "security": { + "website_blocklist": { + "enabled": True, + "domains": "example.com", + } + } + }, + sort_keys=False, + ), + encoding="utf-8", + ) + + with pytest.raises(WebsitePolicyError, match="security.website_blocklist.domains must be a list"): + load_website_blocklist(config_path) + + +def test_load_website_blocklist_raises_clean_error_for_invalid_shared_files_type(tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text( + yaml.safe_dump( + { + "security": { + "website_blocklist": { + "enabled": True, + "shared_files": "community-blocklist.txt", + } + } + }, + sort_keys=False, + ), + encoding="utf-8", + ) + + with pytest.raises(WebsitePolicyError, match="security.website_blocklist.shared_files must be a list"): + load_website_blocklist(config_path) + + +def test_load_website_blocklist_raises_clean_error_for_invalid_top_level_config_type(tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text(yaml.safe_dump(["not", "a", "mapping"], sort_keys=False), encoding="utf-8") + + with pytest.raises(WebsitePolicyError, match="config root must be a mapping"): + load_website_blocklist(config_path) + + +def test_load_website_blocklist_raises_clean_error_for_invalid_security_type(tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text(yaml.safe_dump({"security": []}, sort_keys=False), encoding="utf-8") + + with pytest.raises(WebsitePolicyError, match="security must be a mapping"): + load_website_blocklist(config_path) + + +def test_load_website_blocklist_raises_clean_error_for_invalid_website_blocklist_type(tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text( + yaml.safe_dump( + { + "security": { + "website_blocklist": "block everything", + } + }, + sort_keys=False, + ), + encoding="utf-8", + ) + + with pytest.raises(WebsitePolicyError, match="security.website_blocklist must be a mapping"): + load_website_blocklist(config_path) + + +def test_load_website_blocklist_raises_clean_error_for_invalid_enabled_type(tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text( + yaml.safe_dump( + { + "security": { + "website_blocklist": { + "enabled": "false", + } + } + }, + sort_keys=False, + ), + encoding="utf-8", + ) + + with pytest.raises(WebsitePolicyError, match="security.website_blocklist.enabled must be a boolean"): + load_website_blocklist(config_path) + + +def test_load_website_blocklist_raises_clean_error_for_malformed_yaml(tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text("security: [oops\n", encoding="utf-8") + + with pytest.raises(WebsitePolicyError, match="Invalid config YAML"): + load_website_blocklist(config_path) + + +def test_load_website_blocklist_wraps_shared_file_read_errors(tmp_path, monkeypatch): + shared = tmp_path / "community-blocklist.txt" + shared.write_text("example.org\n", encoding="utf-8") + + config_path = tmp_path / "config.yaml" + config_path.write_text( + yaml.safe_dump( + { + "security": { + "website_blocklist": { + "enabled": True, + "shared_files": [str(shared)], + } + } + }, + sort_keys=False, + ), + encoding="utf-8", + ) + + def failing_read_text(self, *args, **kwargs): + raise PermissionError("no permission") + + monkeypatch.setattr(Path, "read_text", failing_read_text) + + # Unreadable shared files are now warned and skipped (not raised), + # so the blocklist loads successfully but without those rules. + result = load_website_blocklist(config_path) + assert result["enabled"] is True + assert result["rules"] == [] # shared file rules skipped + + +def test_check_website_access_uses_dynamic_hermes_home(monkeypatch, tmp_path): + hermes_home = tmp_path / "hermes-home" + hermes_home.mkdir() + (hermes_home / "config.yaml").write_text( + yaml.safe_dump( + { + "security": { + "website_blocklist": { + "enabled": True, + "domains": ["dynamic.example"], + } + } + }, + sort_keys=False, + ), + encoding="utf-8", + ) + + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + + blocked = check_website_access("https://dynamic.example/path") + + assert blocked is not None + assert blocked["rule"] == "dynamic.example" + + +def test_check_website_access_blocks_scheme_less_urls(tmp_path): + config_path = tmp_path / "config.yaml" + config_path.write_text( + yaml.safe_dump( + { + "security": { + "website_blocklist": { + "enabled": True, + "domains": ["blocked.test"], + } + } + }, + sort_keys=False, + ), + encoding="utf-8", + ) + + blocked = check_website_access("www.blocked.test/path", config_path=config_path) + + assert blocked is not None + assert blocked["host"] == "www.blocked.test" + assert blocked["rule"] == "blocked.test" + + +def test_browser_navigate_returns_policy_block(monkeypatch): + from tools import browser_tool + + monkeypatch.setattr( + browser_tool, + "check_website_access", + lambda url: { + "host": "blocked.test", + "rule": "blocked.test", + "source": "config", + "message": "Blocked by website policy", + }, + ) + monkeypatch.setattr( + browser_tool, + "_run_browser_command", + lambda *args, **kwargs: pytest.fail("browser command should not run for blocked URL"), + ) + + result = json.loads(browser_tool.browser_navigate("https://blocked.test")) + + assert result["success"] is False + assert result["blocked_by_policy"]["rule"] == "blocked.test" + + +def test_browser_navigate_allows_when_shared_file_missing(monkeypatch, tmp_path): + """Missing shared blocklist files are warned and skipped, not fatal.""" + from tools import browser_tool + + config_path = tmp_path / "config.yaml" + config_path.write_text( + yaml.safe_dump( + { + "security": { + "website_blocklist": { + "enabled": True, + "shared_files": ["missing-blocklist.txt"], + } + } + }, + sort_keys=False, + ), + encoding="utf-8", + ) + + # check_website_access should return None (allow) — missing file is skipped + result = check_website_access("https://allowed.test", config_path=config_path) + assert result is None + + +@pytest.mark.asyncio +async def test_web_extract_short_circuits_blocked_url(monkeypatch): + from tools import web_tools + + monkeypatch.setattr( + web_tools, + "check_website_access", + lambda url: { + "host": "blocked.test", + "rule": "blocked.test", + "source": "config", + "message": "Blocked by website policy", + }, + ) + monkeypatch.setattr( + web_tools, + "_get_firecrawl_client", + lambda: pytest.fail("firecrawl should not run for blocked URL"), + ) + monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False) + + result = json.loads(await web_tools.web_extract_tool(["https://blocked.test"], use_llm_processing=False)) + + assert result["results"][0]["url"] == "https://blocked.test" + assert "Blocked by website policy" in result["results"][0]["error"] + + +def test_check_website_access_fails_open_on_malformed_config(tmp_path, monkeypatch): + """Malformed config with default path should fail open (return None), not crash.""" + config_path = tmp_path / "config.yaml" + config_path.write_text("security: [oops\n", encoding="utf-8") + + # With explicit config_path (test mode), errors propagate + with pytest.raises(WebsitePolicyError): + check_website_access("https://example.com", config_path=config_path) + + # Simulate default path by pointing HERMES_HOME to tmp_path + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from tools import website_policy + website_policy.invalidate_cache() + + # With default path, errors are caught and fail open + result = check_website_access("https://example.com") + assert result is None # allowed, not crashed + + +@pytest.mark.asyncio +async def test_web_extract_blocks_redirected_final_url(monkeypatch): + from tools import web_tools + + def fake_check(url): + if url == "https://allowed.test": + return None + if url == "https://blocked.test/final": + return { + "host": "blocked.test", + "rule": "blocked.test", + "source": "config", + "message": "Blocked by website policy", + } + pytest.fail(f"unexpected URL checked: {url}") + + class FakeFirecrawlClient: + def scrape(self, url, formats): + return { + "markdown": "secret content", + "metadata": { + "title": "Redirected", + "sourceURL": "https://blocked.test/final", + }, + } + + monkeypatch.setattr(web_tools, "check_website_access", fake_check) + monkeypatch.setattr(web_tools, "_get_firecrawl_client", lambda: FakeFirecrawlClient()) + monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False) + + result = json.loads(await web_tools.web_extract_tool(["https://allowed.test"], use_llm_processing=False)) + + assert result["results"][0]["url"] == "https://blocked.test/final" + assert result["results"][0]["content"] == "" + assert result["results"][0]["blocked_by_policy"]["rule"] == "blocked.test" + + +@pytest.mark.asyncio +async def test_web_crawl_short_circuits_blocked_url(monkeypatch): + from tools import web_tools + + # web_crawl_tool checks for Firecrawl env before website policy + monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key") + monkeypatch.setattr( + web_tools, + "check_website_access", + lambda url: { + "host": "blocked.test", + "rule": "blocked.test", + "source": "config", + "message": "Blocked by website policy", + }, + ) + monkeypatch.setattr( + web_tools, + "_get_firecrawl_client", + lambda: pytest.fail("firecrawl should not run for blocked crawl URL"), + ) + monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False) + + result = json.loads(await web_tools.web_crawl_tool("https://blocked.test", use_llm_processing=False)) + + assert result["results"][0]["url"] == "https://blocked.test" + assert result["results"][0]["blocked_by_policy"]["rule"] == "blocked.test" + + +@pytest.mark.asyncio +async def test_web_crawl_blocks_redirected_final_url(monkeypatch): + from tools import web_tools + + # web_crawl_tool checks for Firecrawl env before website policy + monkeypatch.setenv("FIRECRAWL_API_KEY", "fake-key") + + def fake_check(url): + if url == "https://allowed.test": + return None + if url == "https://blocked.test/final": + return { + "host": "blocked.test", + "rule": "blocked.test", + "source": "config", + "message": "Blocked by website policy", + } + pytest.fail(f"unexpected URL checked: {url}") + + class FakeCrawlClient: + def crawl(self, url, **kwargs): + return { + "data": [ + { + "markdown": "secret crawl content", + "metadata": { + "title": "Redirected crawl page", + "sourceURL": "https://blocked.test/final", + }, + } + ] + } + + monkeypatch.setattr(web_tools, "check_website_access", fake_check) + monkeypatch.setattr(web_tools, "_get_firecrawl_client", lambda: FakeCrawlClient()) + monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False) + + result = json.loads(await web_tools.web_crawl_tool("https://allowed.test", use_llm_processing=False)) + + assert result["results"][0]["content"] == "" + assert result["results"][0]["error"] == "Blocked by website policy" + assert result["results"][0]["blocked_by_policy"]["rule"] == "blocked.test" diff --git a/tools/browser_tool.py b/tools/browser_tool.py index d57eedee8..9760cf302 100644 --- a/tools/browser_tool.py +++ b/tools/browser_tool.py @@ -65,6 +65,11 @@ import requests from typing import Dict, Any, Optional, List from pathlib import Path from agent.auxiliary_client import call_llm + +try: + from tools.website_policy import check_website_access +except Exception: + check_website_access = lambda url: None # noqa: E731 — fail-open if policy module unavailable from tools.browser_providers.base import CloudBrowserProvider from tools.browser_providers.browserbase import BrowserbaseProvider from tools.browser_providers.browser_use import BrowserUseProvider @@ -550,6 +555,11 @@ def _get_session_info(task_id: Optional[str] = None) -> Dict[str, str]: session_info = provider.create_session(task_id) with _cleanup_lock: + # Double-check: another thread may have created a session while we + # were doing the network call. Use the existing one to avoid leaking + # orphan cloud sessions. + if task_id in _active_sessions: + return _active_sessions[task_id] _active_sessions[task_id] = session_info return session_info @@ -901,6 +911,15 @@ def browser_navigate(url: str, task_id: Optional[str] = None) -> str: Returns: JSON string with navigation result (includes stealth features info on first nav) """ + # Website policy check — block before navigating + blocked = check_website_access(url) + if blocked: + return json.dumps({ + "success": False, + "error": blocked["message"], + "blocked_by_policy": {"host": blocked["host"], "rule": blocked["rule"], "source": blocked["source"]}, + }) + effective_task_id = task_id or "default" # Get session info to check if this is a new session diff --git a/tools/delegate_tool.py b/tools/delegate_tool.py index 1ac75ea88..2ef505dab 100644 --- a/tools/delegate_tool.py +++ b/tools/delegate_tool.py @@ -16,13 +16,10 @@ The parent's context only sees the delegation call and the summary result, never the child's intermediate tool calls or reasoning. """ -import contextlib -import io import json import logging logger = logging.getLogger(__name__) import os -import sys import time from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any, Dict, List, Optional @@ -150,7 +147,7 @@ def _build_child_progress_callback(task_index: int, parent_agent, task_count: in return _callback -def _run_single_child( +def _build_child_agent( task_index: int, goal: str, context: Optional[str], @@ -158,16 +155,15 @@ def _run_single_child( model: Optional[str], max_iterations: int, parent_agent, - task_count: int = 1, # Credential overrides from delegation config (provider:model resolution) override_provider: Optional[str] = None, override_base_url: Optional[str] = None, override_api_key: Optional[str] = None, override_api_mode: Optional[str] = None, -) -> Dict[str, Any]: +): """ - Spawn and run a single child agent. Called from within a thread. - Returns a structured result dict. + Build a child AIAgent on the main thread (thread-safe construction). + Returns the constructed child agent without running it. When override_* params are set (from delegation config), the child uses those credentials instead of inheriting from the parent. This enables @@ -176,8 +172,6 @@ def _run_single_child( """ from run_agent import AIAgent - child_start = time.monotonic() - # When no explicit toolsets given, inherit from parent's enabled toolsets # so disabled tools (e.g. web) don't leak to subagents. if toolsets: @@ -188,65 +182,84 @@ def _run_single_child( child_toolsets = _strip_blocked_tools(DEFAULT_TOOLSETS) child_prompt = _build_child_system_prompt(goal, context) + # Extract parent's API key so subagents inherit auth (e.g. Nous Portal). + parent_api_key = getattr(parent_agent, "api_key", None) + if (not parent_api_key) and hasattr(parent_agent, "_client_kwargs"): + parent_api_key = parent_agent._client_kwargs.get("api_key") - try: - # Extract parent's API key so subagents inherit auth (e.g. Nous Portal). - parent_api_key = getattr(parent_agent, "api_key", None) - if (not parent_api_key) and hasattr(parent_agent, "_client_kwargs"): - parent_api_key = parent_agent._client_kwargs.get("api_key") + # Build progress callback to relay tool calls to parent display + child_progress_cb = _build_child_progress_callback(task_index, parent_agent) - # Build progress callback to relay tool calls to parent display - child_progress_cb = _build_child_progress_callback(task_index, parent_agent, task_count) + # Share the parent's iteration budget so subagent tool calls + # count toward the session-wide limit. + shared_budget = getattr(parent_agent, "iteration_budget", None) - # Share the parent's iteration budget so subagent tool calls - # count toward the session-wide limit. - shared_budget = getattr(parent_agent, "iteration_budget", None) + # Resolve effective credentials: config override > parent inherit + effective_model = model or parent_agent.model + effective_provider = override_provider or getattr(parent_agent, "provider", None) + effective_base_url = override_base_url or parent_agent.base_url + effective_api_key = override_api_key or parent_api_key + effective_api_mode = override_api_mode or getattr(parent_agent, "api_mode", None) - # Resolve effective credentials: config override > parent inherit - effective_model = model or parent_agent.model - effective_provider = override_provider or getattr(parent_agent, "provider", None) - effective_base_url = override_base_url or parent_agent.base_url - effective_api_key = override_api_key or parent_api_key - effective_api_mode = override_api_mode or getattr(parent_agent, "api_mode", None) + child = AIAgent( + base_url=effective_base_url, + api_key=effective_api_key, + model=effective_model, + provider=effective_provider, + api_mode=effective_api_mode, + max_iterations=max_iterations, + max_tokens=getattr(parent_agent, "max_tokens", None), + reasoning_config=getattr(parent_agent, "reasoning_config", None), + prefill_messages=getattr(parent_agent, "prefill_messages", None), + enabled_toolsets=child_toolsets, + quiet_mode=True, + ephemeral_system_prompt=child_prompt, + log_prefix=f"[subagent-{task_index}]", + platform=parent_agent.platform, + skip_context_files=True, + skip_memory=True, + clarify_callback=None, + session_db=getattr(parent_agent, '_session_db', None), + providers_allowed=parent_agent.providers_allowed, + providers_ignored=parent_agent.providers_ignored, + providers_order=parent_agent.providers_order, + provider_sort=parent_agent.provider_sort, + tool_progress_callback=child_progress_cb, + iteration_budget=shared_budget, + ) - child = AIAgent( - base_url=effective_base_url, - api_key=effective_api_key, - model=effective_model, - provider=effective_provider, - api_mode=effective_api_mode, - max_iterations=max_iterations, - max_tokens=getattr(parent_agent, "max_tokens", None), - reasoning_config=getattr(parent_agent, "reasoning_config", None), - prefill_messages=getattr(parent_agent, "prefill_messages", None), - enabled_toolsets=child_toolsets, - quiet_mode=True, - ephemeral_system_prompt=child_prompt, - log_prefix=f"[subagent-{task_index}]", - platform=parent_agent.platform, - skip_context_files=True, - skip_memory=True, - clarify_callback=None, - session_db=getattr(parent_agent, '_session_db', None), - providers_allowed=parent_agent.providers_allowed, - providers_ignored=parent_agent.providers_ignored, - providers_order=parent_agent.providers_order, - provider_sort=parent_agent.provider_sort, - tool_progress_callback=child_progress_cb, - iteration_budget=shared_budget, - ) + # Set delegation depth so children can't spawn grandchildren + child._delegate_depth = getattr(parent_agent, '_delegate_depth', 0) + 1 - # Set delegation depth so children can't spawn grandchildren - child._delegate_depth = getattr(parent_agent, '_delegate_depth', 0) + 1 - - # Register child for interrupt propagation - if hasattr(parent_agent, '_active_children'): + # Register child for interrupt propagation + if hasattr(parent_agent, '_active_children'): + lock = getattr(parent_agent, '_active_children_lock', None) + if lock: + with lock: + parent_agent._active_children.append(child) + else: parent_agent._active_children.append(child) - # Run with stdout/stderr suppressed to prevent interleaved output - devnull = io.StringIO() - with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull): - result = child.run_conversation(user_message=goal) + return child + +def _run_single_child( + task_index: int, + goal: str, + child=None, + parent_agent=None, + **_kwargs, +) -> Dict[str, Any]: + """ + Run a pre-built child agent. Called from within a thread. + Returns a structured result dict. + """ + child_start = time.monotonic() + + # Get the progress callback from the child agent + child_progress_cb = getattr(child, 'tool_progress_callback', None) + + try: + result = child.run_conversation(user_message=goal) # Flush any remaining batched progress to gateway if child_progress_cb and hasattr(child_progress_cb, '_flush'): @@ -355,11 +368,15 @@ def _run_single_child( # Unregister child from interrupt propagation if hasattr(parent_agent, '_active_children'): try: - parent_agent._active_children.remove(child) + lock = getattr(parent_agent, '_active_children_lock', None) + if lock: + with lock: + parent_agent._active_children.remove(child) + else: + parent_agent._active_children.remove(child) except (ValueError, UnboundLocalError) as e: logger.debug("Could not remove child from active_children: %s", e) - def delegate_task( goal: Optional[str] = None, context: Optional[str] = None, @@ -428,51 +445,38 @@ def delegate_task( # Track goal labels for progress display (truncated for readability) task_labels = [t["goal"][:40] for t in task_list] - if n_tasks == 1: - # Single task -- run directly (no thread pool overhead) - t = task_list[0] - result = _run_single_child( - task_index=0, - goal=t["goal"], - context=t.get("context"), - toolsets=t.get("toolsets") or toolsets, - model=creds["model"], - max_iterations=effective_max_iter, - parent_agent=parent_agent, - task_count=1, - override_provider=creds["provider"], - override_base_url=creds["base_url"], + # Build all child agents on the main thread (thread-safe construction) + children = [] + for i, t in enumerate(task_list): + child = _build_child_agent( + task_index=i, goal=t["goal"], context=t.get("context"), + toolsets=t.get("toolsets") or toolsets, model=creds["model"], + max_iterations=effective_max_iter, parent_agent=parent_agent, + override_provider=creds["provider"], override_base_url=creds["base_url"], override_api_key=creds["api_key"], override_api_mode=creds["api_mode"], ) + children.append((i, t, child)) + + if n_tasks == 1: + # Single task -- run directly (no thread pool overhead) + _i, _t, child = children[0] + result = _run_single_child(0, _t["goal"], child, parent_agent) results.append(result) else: # Batch -- run in parallel with per-task progress lines completed_count = 0 spinner_ref = getattr(parent_agent, '_delegate_spinner', None) - # Save stdout/stderr before the executor — redirect_stdout in child - # threads races on sys.stdout and can leave it as devnull permanently. - _saved_stdout = sys.stdout - _saved_stderr = sys.stderr - with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_CHILDREN) as executor: futures = {} - for i, t in enumerate(task_list): + for i, t, child in children: future = executor.submit( _run_single_child, task_index=i, goal=t["goal"], - context=t.get("context"), - toolsets=t.get("toolsets") or toolsets, - model=creds["model"], - max_iterations=effective_max_iter, + child=child, parent_agent=parent_agent, - task_count=n_tasks, - override_provider=creds["provider"], - override_base_url=creds["base_url"], - override_api_key=creds["api_key"], - override_api_mode=creds["api_mode"], ) futures[future] = i @@ -515,10 +519,6 @@ def delegate_task( except Exception as e: logger.debug("Spinner update_text failed: %s", e) - # Restore stdout/stderr in case redirect_stdout race left them as devnull - sys.stdout = _saved_stdout - sys.stderr = _saved_stderr - # Sort by task_index so results match input order results.sort(key=lambda r: r["task_index"]) diff --git a/tools/environments/local.py b/tools/environments/local.py index dc753b410..914192f2d 100644 --- a/tools/environments/local.py +++ b/tools/environments/local.py @@ -82,6 +82,9 @@ def _build_provider_env_blocklist() -> frozenset: "FIREWORKS_API_KEY", # Fireworks AI "XAI_API_KEY", # xAI (Grok) "HELICONE_API_KEY", # LLM Observability proxy + "PARALLEL_API_KEY", + "FIRECRAWL_API_KEY", + "FIRECRAWL_API_URL", # Gateway/runtime config not represented in OPTIONAL_ENV_VARS. "TELEGRAM_HOME_CHANNEL", "TELEGRAM_HOME_CHANNEL_NAME", diff --git a/tools/file_operations.py b/tools/file_operations.py index 7f39a0277..56ed1319f 100644 --- a/tools/file_operations.py +++ b/tools/file_operations.py @@ -94,7 +94,7 @@ def _get_safe_write_root() -> Optional[str]: def _is_write_denied(path: str) -> bool: """Return True if path is on the write deny list.""" - resolved = os.path.realpath(os.path.expanduser(path)) + resolved = os.path.realpath(os.path.expanduser(str(path))) # 1) Static deny list if resolved in WRITE_DENIED_PATHS: diff --git a/tools/fuzzy_match.py b/tools/fuzzy_match.py index f53451c63..ddcdf4274 100644 --- a/tools/fuzzy_match.py +++ b/tools/fuzzy_match.py @@ -254,10 +254,9 @@ def _strategy_trimmed_boundary(content: str, pattern: str) -> List[Tuple[int, in if '\n'.join(check_lines) == modified_pattern: # Found match - calculate original positions - start_pos = sum(len(line) + 1 for line in content_lines[:i]) - end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1 - if end_pos >= len(content): - end_pos = len(content) + start_pos, end_pos = _calculate_line_positions( + content_lines, i, i + pattern_line_count, len(content) + ) matches.append((start_pos, end_pos)) return matches @@ -309,9 +308,10 @@ def _strategy_block_anchor(content: str, pattern: str) -> List[Tuple[int, int]]: if similarity >= threshold: # Calculate positions using ORIGINAL lines to ensure correct character offsets in the file - start_pos = sum(len(line) + 1 for line in orig_content_lines[:i]) - end_pos = sum(len(line) + 1 for line in orig_content_lines[:i + pattern_line_count]) - 1 - matches.append((start_pos, min(end_pos, len(content)))) + start_pos, end_pos = _calculate_line_positions( + orig_content_lines, i, i + pattern_line_count, len(content) + ) + matches.append((start_pos, end_pos)) return matches @@ -343,10 +343,9 @@ def _strategy_context_aware(content: str, pattern: str) -> List[Tuple[int, int]] # Need at least 50% of lines to have high similarity if high_similarity_count >= len(pattern_lines) * 0.5: - start_pos = sum(len(line) + 1 for line in content_lines[:i]) - end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1 - if end_pos >= len(content): - end_pos = len(content) + start_pos, end_pos = _calculate_line_positions( + content_lines, i, i + pattern_line_count, len(content) + ) matches.append((start_pos, end_pos)) return matches @@ -356,6 +355,26 @@ def _strategy_context_aware(content: str, pattern: str) -> List[Tuple[int, int]] # Helper Functions # ============================================================================= +def _calculate_line_positions(content_lines: List[str], start_line: int, + end_line: int, content_length: int) -> Tuple[int, int]: + """Calculate start and end character positions from line indices. + + Args: + content_lines: List of lines (without newlines) + start_line: Starting line index (0-based) + end_line: Ending line index (exclusive, 0-based) + content_length: Total length of the original content string + + Returns: + Tuple of (start_pos, end_pos) in the original content + """ + start_pos = sum(len(line) + 1 for line in content_lines[:start_line]) + end_pos = sum(len(line) + 1 for line in content_lines[:end_line]) - 1 + if end_pos >= content_length: + end_pos = content_length + return start_pos, end_pos + + def _find_normalized_matches(content: str, content_lines: List[str], content_normalized_lines: List[str], pattern: str, pattern_normalized: str) -> List[Tuple[int, int]]: @@ -383,13 +402,9 @@ def _find_normalized_matches(content: str, content_lines: List[str], if block == pattern_normalized: # Found a match - calculate original positions - start_pos = sum(len(line) + 1 for line in content_lines[:i]) - end_pos = sum(len(line) + 1 for line in content_lines[:i + num_pattern_lines]) - 1 - - # Handle case where end is past content - if end_pos >= len(content): - end_pos = len(content) - + start_pos, end_pos = _calculate_line_positions( + content_lines, i, i + num_pattern_lines, len(content) + ) matches.append((start_pos, end_pos)) return matches diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index 7294e8be5..7ff8103b2 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -1624,6 +1624,72 @@ def get_mcp_status() -> List[dict]: return result +def probe_mcp_server_tools() -> Dict[str, List[tuple]]: + """Temporarily connect to configured MCP servers and list their tools. + + Designed for ``hermes tools`` interactive configuration — connects to each + enabled server, grabs tool names and descriptions, then disconnects. + Does NOT register tools in the Hermes registry. + + Returns: + Dict mapping server name to list of (tool_name, description) tuples. + Servers that fail to connect are omitted from the result. + """ + if not _MCP_AVAILABLE: + return {} + + servers_config = _load_mcp_config() + if not servers_config: + return {} + + enabled = { + k: v for k, v in servers_config.items() + if _parse_boolish(v.get("enabled", True), default=True) + } + if not enabled: + return {} + + _ensure_mcp_loop() + + result: Dict[str, List[tuple]] = {} + probed_servers: List[MCPServerTask] = [] + + async def _probe_all(): + names = list(enabled.keys()) + coros = [] + for name, cfg in enabled.items(): + ct = cfg.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT) + coros.append(asyncio.wait_for(_connect_server(name, cfg), timeout=ct)) + + outcomes = await asyncio.gather(*coros, return_exceptions=True) + + for name, outcome in zip(names, outcomes): + if isinstance(outcome, Exception): + logger.debug("Probe: failed to connect to '%s': %s", name, outcome) + continue + probed_servers.append(outcome) + tools = [] + for t in outcome._tools: + desc = getattr(t, "description", "") or "" + tools.append((t.name, desc)) + result[name] = tools + + # Shut down all probed connections + await asyncio.gather( + *(s.shutdown() for s in probed_servers), + return_exceptions=True, + ) + + try: + _run_on_mcp_loop(_probe_all(), timeout=120) + except Exception as exc: + logger.debug("MCP probe failed: %s", exc) + finally: + _stop_mcp_loop() + + return result + + def shutdown_mcp_servers(): """Close all MCP server connections and stop the background loop. diff --git a/tools/memory_tool.py b/tools/memory_tool.py index d7950d38c..241c17f8f 100644 --- a/tools/memory_tool.py +++ b/tools/memory_tool.py @@ -23,11 +23,13 @@ Design: - Frozen snapshot pattern: system prompt is stable, tool responses show live state """ +import fcntl import json import logging import os import re import tempfile +from contextlib import contextmanager from pathlib import Path from typing import Dict, Any, List, Optional @@ -120,14 +122,43 @@ class MemoryStore: "user": self._render_block("user", self.user_entries), } + @staticmethod + @contextmanager + def _file_lock(path: Path): + """Acquire an exclusive file lock for read-modify-write safety. + + Uses a separate .lock file so the memory file itself can still be + atomically replaced via os.replace(). + """ + lock_path = path.with_suffix(path.suffix + ".lock") + lock_path.parent.mkdir(parents=True, exist_ok=True) + fd = open(lock_path, "w") + try: + fcntl.flock(fd, fcntl.LOCK_EX) + yield + finally: + fcntl.flock(fd, fcntl.LOCK_UN) + fd.close() + + @staticmethod + def _path_for(target: str) -> Path: + if target == "user": + return MEMORY_DIR / "USER.md" + return MEMORY_DIR / "MEMORY.md" + + def _reload_target(self, target: str): + """Re-read entries from disk into in-memory state. + + Called under file lock to get the latest state before mutating. + """ + fresh = self._read_file(self._path_for(target)) + fresh = list(dict.fromkeys(fresh)) # deduplicate + self._set_entries(target, fresh) + def save_to_disk(self, target: str): """Persist entries to the appropriate file. Called after every mutation.""" MEMORY_DIR.mkdir(parents=True, exist_ok=True) - - if target == "memory": - self._write_file(MEMORY_DIR / "MEMORY.md", self.memory_entries) - elif target == "user": - self._write_file(MEMORY_DIR / "USER.md", self.user_entries) + self._write_file(self._path_for(target), self._entries_for(target)) def _entries_for(self, target: str) -> List[str]: if target == "user": @@ -162,33 +193,37 @@ class MemoryStore: if scan_error: return {"success": False, "error": scan_error} - entries = self._entries_for(target) - limit = self._char_limit(target) + with self._file_lock(self._path_for(target)): + # Re-read from disk under lock to pick up writes from other sessions + self._reload_target(target) - # Reject exact duplicates - if content in entries: - return self._success_response(target, "Entry already exists (no duplicate added).") + entries = self._entries_for(target) + limit = self._char_limit(target) - # Calculate what the new total would be - new_entries = entries + [content] - new_total = len(ENTRY_DELIMITER.join(new_entries)) + # Reject exact duplicates + if content in entries: + return self._success_response(target, "Entry already exists (no duplicate added).") - if new_total > limit: - current = self._char_count(target) - return { - "success": False, - "error": ( - f"Memory at {current:,}/{limit:,} chars. " - f"Adding this entry ({len(content)} chars) would exceed the limit. " - f"Replace or remove existing entries first." - ), - "current_entries": entries, - "usage": f"{current:,}/{limit:,}", - } + # Calculate what the new total would be + new_entries = entries + [content] + new_total = len(ENTRY_DELIMITER.join(new_entries)) - entries.append(content) - self._set_entries(target, entries) - self.save_to_disk(target) + if new_total > limit: + current = self._char_count(target) + return { + "success": False, + "error": ( + f"Memory at {current:,}/{limit:,} chars. " + f"Adding this entry ({len(content)} chars) would exceed the limit. " + f"Replace or remove existing entries first." + ), + "current_entries": entries, + "usage": f"{current:,}/{limit:,}", + } + + entries.append(content) + self._set_entries(target, entries) + self.save_to_disk(target) return self._success_response(target, "Entry added.") @@ -206,44 +241,47 @@ class MemoryStore: if scan_error: return {"success": False, "error": scan_error} - entries = self._entries_for(target) - matches = [(i, e) for i, e in enumerate(entries) if old_text in e] + with self._file_lock(self._path_for(target)): + self._reload_target(target) - if len(matches) == 0: - return {"success": False, "error": f"No entry matched '{old_text}'."} + entries = self._entries_for(target) + matches = [(i, e) for i, e in enumerate(entries) if old_text in e] - if len(matches) > 1: - # If all matches are identical (exact duplicates), operate on the first one - unique_texts = set(e for _, e in matches) - if len(unique_texts) > 1: - previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches] + if len(matches) == 0: + return {"success": False, "error": f"No entry matched '{old_text}'."} + + if len(matches) > 1: + # If all matches are identical (exact duplicates), operate on the first one + unique_texts = set(e for _, e in matches) + if len(unique_texts) > 1: + previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches] + return { + "success": False, + "error": f"Multiple entries matched '{old_text}'. Be more specific.", + "matches": previews, + } + # All identical -- safe to replace just the first + + idx = matches[0][0] + limit = self._char_limit(target) + + # Check that replacement doesn't blow the budget + test_entries = entries.copy() + test_entries[idx] = new_content + new_total = len(ENTRY_DELIMITER.join(test_entries)) + + if new_total > limit: return { "success": False, - "error": f"Multiple entries matched '{old_text}'. Be more specific.", - "matches": previews, + "error": ( + f"Replacement would put memory at {new_total:,}/{limit:,} chars. " + f"Shorten the new content or remove other entries first." + ), } - # All identical -- safe to replace just the first - idx = matches[0][0] - limit = self._char_limit(target) - - # Check that replacement doesn't blow the budget - test_entries = entries.copy() - test_entries[idx] = new_content - new_total = len(ENTRY_DELIMITER.join(test_entries)) - - if new_total > limit: - return { - "success": False, - "error": ( - f"Replacement would put memory at {new_total:,}/{limit:,} chars. " - f"Shorten the new content or remove other entries first." - ), - } - - entries[idx] = new_content - self._set_entries(target, entries) - self.save_to_disk(target) + entries[idx] = new_content + self._set_entries(target, entries) + self.save_to_disk(target) return self._success_response(target, "Entry replaced.") @@ -253,28 +291,31 @@ class MemoryStore: if not old_text: return {"success": False, "error": "old_text cannot be empty."} - entries = self._entries_for(target) - matches = [(i, e) for i, e in enumerate(entries) if old_text in e] + with self._file_lock(self._path_for(target)): + self._reload_target(target) - if len(matches) == 0: - return {"success": False, "error": f"No entry matched '{old_text}'."} + entries = self._entries_for(target) + matches = [(i, e) for i, e in enumerate(entries) if old_text in e] - if len(matches) > 1: - # If all matches are identical (exact duplicates), remove the first one - unique_texts = set(e for _, e in matches) - if len(unique_texts) > 1: - previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches] - return { - "success": False, - "error": f"Multiple entries matched '{old_text}'. Be more specific.", - "matches": previews, - } - # All identical -- safe to remove just the first + if len(matches) == 0: + return {"success": False, "error": f"No entry matched '{old_text}'."} - idx = matches[0][0] - entries.pop(idx) - self._set_entries(target, entries) - self.save_to_disk(target) + if len(matches) > 1: + # If all matches are identical (exact duplicates), remove the first one + unique_texts = set(e for _, e in matches) + if len(unique_texts) > 1: + previews = [e[:80] + ("..." if len(e) > 80 else "") for _, e in matches] + return { + "success": False, + "error": f"Multiple entries matched '{old_text}'. Be more specific.", + "matches": previews, + } + # All identical -- safe to remove just the first + + idx = matches[0][0] + entries.pop(idx) + self._set_entries(target, entries) + self.save_to_disk(target) return self._success_response(target, "Entry removed.") diff --git a/tools/process_registry.py b/tools/process_registry.py index ceb45ab27..c6ee9ceb6 100644 --- a/tools/process_registry.py +++ b/tools/process_registry.py @@ -78,6 +78,11 @@ class ProcessSession: output_buffer: str = "" # Rolling output (last MAX_OUTPUT_CHARS) max_output_chars: int = MAX_OUTPUT_CHARS detached: bool = False # True if recovered from crash (no pipe) + # Watcher/notification metadata (persisted for crash recovery) + watcher_platform: str = "" + watcher_chat_id: str = "" + watcher_thread_id: str = "" + watcher_interval: int = 0 # 0 = no watcher configured _lock: threading.Lock = field(default_factory=threading.Lock) _reader_thread: Optional[threading.Thread] = field(default=None, repr=False) _pty: Any = field(default=None, repr=False) # ptyprocess handle (when use_pty=True) @@ -709,6 +714,10 @@ class ProcessRegistry: "started_at": s.started_at, "task_id": s.task_id, "session_key": s.session_key, + "watcher_platform": s.watcher_platform, + "watcher_chat_id": s.watcher_chat_id, + "watcher_thread_id": s.watcher_thread_id, + "watcher_interval": s.watcher_interval, }) # Atomic write to avoid corruption on crash @@ -755,12 +764,27 @@ class ProcessRegistry: cwd=entry.get("cwd"), started_at=entry.get("started_at", time.time()), detached=True, # Can't read output, but can report status + kill + watcher_platform=entry.get("watcher_platform", ""), + watcher_chat_id=entry.get("watcher_chat_id", ""), + watcher_thread_id=entry.get("watcher_thread_id", ""), + watcher_interval=entry.get("watcher_interval", 0), ) with self._lock: self._running[session.id] = session recovered += 1 logger.info("Recovered detached process: %s (pid=%d)", session.command[:60], pid) + # Re-enqueue watcher so gateway can resume notifications + if session.watcher_interval > 0: + self.pending_watchers.append({ + "session_id": session.id, + "check_interval": session.watcher_interval, + "session_key": session.session_key, + "platform": session.watcher_platform, + "chat_id": session.watcher_chat_id, + "thread_id": session.watcher_thread_id, + }) + # Clear the checkpoint (will be rewritten as processes finish) try: from utils import atomic_json_write diff --git a/tools/send_message_tool.py b/tools/send_message_tool.py index 2f0f014ab..4b0c4815f 100644 --- a/tools/send_message_tool.py +++ b/tools/send_message_tool.py @@ -355,20 +355,31 @@ async def _send_telegram(token, chat_id, message, media_files=None, thread_id=No """Send via Telegram Bot API (one-shot, no polling needed). Applies markdown→MarkdownV2 formatting (same as the gateway adapter) - so that bold, links, and headers render correctly. + so that bold, links, and headers render correctly. If the message + already contains HTML tags, it is sent with ``parse_mode='HTML'`` + instead, bypassing MarkdownV2 conversion. """ try: from telegram import Bot from telegram.constants import ParseMode - # Reuse the gateway adapter's format_message for markdown→MarkdownV2 - try: - from gateway.platforms.telegram import TelegramAdapter, _escape_mdv2, _strip_mdv2 - _adapter = TelegramAdapter.__new__(TelegramAdapter) - formatted = _adapter.format_message(message) - except Exception: - # Fallback: send as-is if formatting unavailable + # Auto-detect HTML tags — if present, skip MarkdownV2 and send as HTML. + # Inspired by github.com/ashaney — PR #1568. + _has_html = bool(re.search(r'<[a-zA-Z/][^>]*>', message)) + + if _has_html: formatted = message + send_parse_mode = ParseMode.HTML + else: + # Reuse the gateway adapter's format_message for markdown→MarkdownV2 + try: + from gateway.platforms.telegram import TelegramAdapter, _escape_mdv2, _strip_mdv2 + _adapter = TelegramAdapter.__new__(TelegramAdapter) + formatted = _adapter.format_message(message) + except Exception: + # Fallback: send as-is if formatting unavailable + formatted = message + send_parse_mode = ParseMode.MARKDOWN_V2 bot = Bot(token=token) int_chat_id = int(chat_id) @@ -384,16 +395,19 @@ async def _send_telegram(token, chat_id, message, media_files=None, thread_id=No try: last_msg = await bot.send_message( chat_id=int_chat_id, text=formatted, - parse_mode=ParseMode.MARKDOWN_V2, **thread_kwargs + parse_mode=send_parse_mode, **thread_kwargs ) except Exception as md_error: - # MarkdownV2 failed, fall back to plain text - if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower(): - logger.warning("MarkdownV2 parse failed in _send_telegram, falling back to plain text: %s", md_error) - try: - from gateway.platforms.telegram import _strip_mdv2 - plain = _strip_mdv2(formatted) - except Exception: + # Parse failed, fall back to plain text + if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower() or "html" in str(md_error).lower(): + logger.warning("Parse mode %s failed in _send_telegram, falling back to plain text: %s", send_parse_mode, md_error) + if not _has_html: + try: + from gateway.platforms.telegram import _strip_mdv2 + plain = _strip_mdv2(formatted) + except Exception: + plain = message + else: plain = message last_msg = await bot.send_message( chat_id=int_chat_id, text=plain, @@ -565,50 +579,55 @@ async def _send_email(extra, chat_id, message): return {"error": f"Email send failed: {e}"} -async def _send_sms(api_key, chat_id, message): - """Send via Telnyx SMS REST API (one-shot, no persistent connection needed).""" +async def _send_sms(auth_token, chat_id, message): + """Send a single SMS via Twilio REST API. + + Uses HTTP Basic auth (Account SID : Auth Token) and form-encoded POST. + Chunking is handled by _send_to_platform() before this is called. + """ try: import aiohttp except ImportError: return {"error": "aiohttp not installed. Run: pip install aiohttp"} + + import base64 + + account_sid = os.getenv("TWILIO_ACCOUNT_SID", "") + from_number = os.getenv("TWILIO_PHONE_NUMBER", "") + if not account_sid or not auth_token or not from_number: + return {"error": "SMS not configured (TWILIO_ACCOUNT_SID, TWILIO_AUTH_TOKEN, TWILIO_PHONE_NUMBER required)"} + + # Strip markdown — SMS renders it as literal characters + message = re.sub(r"\*\*(.+?)\*\*", r"\1", message, flags=re.DOTALL) + message = re.sub(r"\*(.+?)\*", r"\1", message, flags=re.DOTALL) + message = re.sub(r"__(.+?)__", r"\1", message, flags=re.DOTALL) + message = re.sub(r"_(.+?)_", r"\1", message, flags=re.DOTALL) + message = re.sub(r"```[a-z]*\n?", "", message) + message = re.sub(r"`(.+?)`", r"\1", message) + message = re.sub(r"^#{1,6}\s+", "", message, flags=re.MULTILINE) + message = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", message) + message = re.sub(r"\n{3,}", "\n\n", message) + message = message.strip() + try: - from_number = os.getenv("TELNYX_FROM_NUMBERS", "").split(",")[0].strip() - if not from_number: - return {"error": "TELNYX_FROM_NUMBERS not configured"} - if not api_key: - api_key = os.getenv("TELNYX_API_KEY", "") - if not api_key: - return {"error": "TELNYX_API_KEY not configured"} + creds = f"{account_sid}:{auth_token}" + encoded = base64.b64encode(creds.encode("ascii")).decode("ascii") + url = f"https://api.twilio.com/2010-04-01/Accounts/{account_sid}/Messages.json" + headers = {"Authorization": f"Basic {encoded}"} - # Strip markdown for SMS - text = re.sub(r"\*\*(.+?)\*\*", r"\1", message, flags=re.DOTALL) - text = re.sub(r"\*(.+?)\*", r"\1", text, flags=re.DOTALL) - text = re.sub(r"```[a-z]*\n?", "", text) - text = re.sub(r"`(.+?)`", r"\1", text) - text = re.sub(r"^#{1,6}\s+", "", text, flags=re.MULTILINE) - text = text.strip() - - # Chunk to 1600 chars - chunks = [text[i:i+1600] for i in range(0, len(text), 1600)] if len(text) > 1600 else [text] - - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - } - message_ids = [] async with aiohttp.ClientSession() as session: - for chunk in chunks: - payload = {"from": from_number, "to": chat_id, "text": chunk} - async with session.post( - "https://api.telnyx.com/v2/messages", - json=payload, - headers=headers, - ) as resp: - body = await resp.json() - if resp.status >= 400: - return {"error": f"Telnyx API error ({resp.status}): {body}"} - message_ids.append(body.get("data", {}).get("id", "")) - return {"success": True, "platform": "sms", "chat_id": chat_id, "message_ids": message_ids} + form_data = aiohttp.FormData() + form_data.add_field("From", from_number) + form_data.add_field("To", chat_id) + form_data.add_field("Body", message) + + async with session.post(url, data=form_data, headers=headers) as resp: + body = await resp.json() + if resp.status >= 400: + error_msg = body.get("message", str(body)) + return {"error": f"Twilio API error ({resp.status}): {error_msg}"} + msg_sid = body.get("sid", "") + return {"success": True, "platform": "sms", "chat_id": chat_id, "message_id": msg_sid} except Exception as e: return {"error": f"SMS send failed: {e}"} diff --git a/tools/terminal_tool.py b/tools/terminal_tool.py index 3cc541b58..424bf6514 100644 --- a/tools/terminal_tool.py +++ b/tools/terminal_tool.py @@ -1082,13 +1082,23 @@ def terminal_tool( result_data["check_interval_note"] = ( f"Requested {check_interval}s raised to minimum 30s" ) + watcher_platform = os.getenv("HERMES_SESSION_PLATFORM", "") + watcher_chat_id = os.getenv("HERMES_SESSION_CHAT_ID", "") + watcher_thread_id = os.getenv("HERMES_SESSION_THREAD_ID", "") + + # Store on session for checkpoint persistence + proc_session.watcher_platform = watcher_platform + proc_session.watcher_chat_id = watcher_chat_id + proc_session.watcher_thread_id = watcher_thread_id + proc_session.watcher_interval = effective_interval + process_registry.pending_watchers.append({ "session_id": proc_session.id, "check_interval": effective_interval, "session_key": session_key, - "platform": os.getenv("HERMES_SESSION_PLATFORM", ""), - "chat_id": os.getenv("HERMES_SESSION_CHAT_ID", ""), - "thread_id": os.getenv("HERMES_SESSION_THREAD_ID", ""), + "platform": watcher_platform, + "chat_id": watcher_chat_id, + "thread_id": watcher_thread_id, }) return json.dumps(result_data, ensure_ascii=False) diff --git a/tools/web_tools.py b/tools/web_tools.py index ede1adb03..79444d72b 100644 --- a/tools/web_tools.py +++ b/tools/web_tools.py @@ -3,16 +3,16 @@ Standalone Web Tools Module This module provides generic web tools that work with multiple backend providers. -Currently uses Firecrawl as the backend, and the interface makes it easy to swap -providers without changing the function signatures. +Backend is selected during ``hermes tools`` setup (web.backend in config.yaml). Available tools: - web_search_tool: Search the web for information - web_extract_tool: Extract content from specific web pages -- web_crawl_tool: Crawl websites with specific instructions +- web_crawl_tool: Crawl websites with specific instructions (Firecrawl only) Backend compatibility: -- Firecrawl: https://docs.firecrawl.dev/introduction +- Firecrawl: https://docs.firecrawl.dev/introduction (search, extract, crawl) +- Parallel: https://docs.parallel.ai (search, extract) LLM Processing: - Uses OpenRouter API with Gemini 3 Flash Preview for intelligent content extraction @@ -46,12 +46,50 @@ import os import re import asyncio from typing import List, Dict, Any, Optional +import httpx from firecrawl import Firecrawl from agent.auxiliary_client import async_call_llm from tools.debug_helpers import DebugSession +from tools.website_policy import check_website_access logger = logging.getLogger(__name__) + +# ─── Backend Selection ──────────────────────────────────────────────────────── + +def _load_web_config() -> dict: + """Load the ``web:`` section from ~/.hermes/config.yaml.""" + try: + from hermes_cli.config import load_config + return load_config().get("web", {}) + except (ImportError, Exception): + return {} + + +def _get_backend() -> str: + """Determine which web backend to use. + + Reads ``web.backend`` from config.yaml (set by ``hermes tools``). + Falls back to whichever API key is present for users who configured + keys manually without running setup. + """ + configured = _load_web_config().get("backend", "").lower().strip() + if configured in ("parallel", "firecrawl", "tavily"): + return configured + # Fallback for manual / legacy config — use whichever key is present. + has_firecrawl = bool(os.getenv("FIRECRAWL_API_KEY") or os.getenv("FIRECRAWL_API_URL")) + has_parallel = bool(os.getenv("PARALLEL_API_KEY")) + has_tavily = bool(os.getenv("TAVILY_API_KEY")) + if has_tavily and not has_firecrawl and not has_parallel: + return "tavily" + if has_parallel and not has_firecrawl: + return "parallel" + # Default to firecrawl (backward compat, or when both are set) + return "firecrawl" + + +# ─── Firecrawl Client ──────────────────────────────────────────────────────── + _firecrawl_client = None def _get_firecrawl_client(): @@ -80,6 +118,129 @@ def _get_firecrawl_client(): _firecrawl_client = Firecrawl(**kwargs) return _firecrawl_client + +# ─── Parallel Client ───────────────────────────────────────────────────────── + +_parallel_client = None +_async_parallel_client = None + +def _get_parallel_client(): + """Get or create the Parallel sync client (lazy initialization). + + Requires PARALLEL_API_KEY environment variable. + """ + from parallel import Parallel + global _parallel_client + if _parallel_client is None: + api_key = os.getenv("PARALLEL_API_KEY") + if not api_key: + raise ValueError( + "PARALLEL_API_KEY environment variable not set. " + "Get your API key at https://parallel.ai" + ) + _parallel_client = Parallel(api_key=api_key) + return _parallel_client + + +def _get_async_parallel_client(): + """Get or create the Parallel async client (lazy initialization). + + Requires PARALLEL_API_KEY environment variable. + """ + from parallel import AsyncParallel + global _async_parallel_client + if _async_parallel_client is None: + api_key = os.getenv("PARALLEL_API_KEY") + if not api_key: + raise ValueError( + "PARALLEL_API_KEY environment variable not set. " + "Get your API key at https://parallel.ai" + ) + _async_parallel_client = AsyncParallel(api_key=api_key) + return _async_parallel_client + +# ─── Tavily Client ─────────────────────────────────────────────────────────── + +_TAVILY_BASE_URL = "https://api.tavily.com" + + +def _tavily_request(endpoint: str, payload: dict) -> dict: + """Send a POST request to the Tavily API. + + Auth is provided via ``api_key`` in the JSON body (no header-based auth). + Raises ``ValueError`` if ``TAVILY_API_KEY`` is not set. + """ + api_key = os.getenv("TAVILY_API_KEY") + if not api_key: + raise ValueError( + "TAVILY_API_KEY environment variable not set. " + "Get your API key at https://app.tavily.com/home" + ) + payload["api_key"] = api_key + url = f"{_TAVILY_BASE_URL}/{endpoint.lstrip('/')}" + logger.info("Tavily %s request to %s", endpoint, url) + response = httpx.post(url, json=payload, timeout=60) + response.raise_for_status() + return response.json() + + +def _normalize_tavily_search_results(response: dict) -> dict: + """Normalize Tavily /search response to the standard web search format. + + Tavily returns ``{results: [{title, url, content, score, ...}]}``. + We map to ``{success, data: {web: [{title, url, description, position}]}}``. + """ + web_results = [] + for i, result in enumerate(response.get("results", [])): + web_results.append({ + "title": result.get("title", ""), + "url": result.get("url", ""), + "description": result.get("content", ""), + "position": i + 1, + }) + return {"success": True, "data": {"web": web_results}} + + +def _normalize_tavily_documents(response: dict, fallback_url: str = "") -> List[Dict[str, Any]]: + """Normalize Tavily /extract or /crawl response to the standard document format. + + Maps results to ``{url, title, content, raw_content, metadata}`` and + includes any ``failed_results`` / ``failed_urls`` as error entries. + """ + documents: List[Dict[str, Any]] = [] + for result in response.get("results", []): + url = result.get("url", fallback_url) + raw = result.get("raw_content", "") or result.get("content", "") + documents.append({ + "url": url, + "title": result.get("title", ""), + "content": raw, + "raw_content": raw, + "metadata": {"sourceURL": url, "title": result.get("title", "")}, + }) + # Handle failed results + for fail in response.get("failed_results", []): + documents.append({ + "url": fail.get("url", fallback_url), + "title": "", + "content": "", + "raw_content": "", + "error": fail.get("error", "extraction failed"), + "metadata": {"sourceURL": fail.get("url", fallback_url)}, + }) + for fail_url in response.get("failed_urls", []): + url_str = fail_url if isinstance(fail_url, str) else str(fail_url) + documents.append({ + "url": url_str, + "title": "", + "content": "", + "raw_content": "", + "error": "extraction failed", + "metadata": {"sourceURL": url_str}, + }) + return documents + + DEFAULT_MIN_LENGTH_FOR_SUMMARIZATION = 5000 # Allow per-task override via env var @@ -427,13 +588,89 @@ def clean_base64_images(text: str) -> str: return cleaned_text +# ─── Parallel Search & Extract Helpers ──────────────────────────────────────── + +def _parallel_search(query: str, limit: int = 5) -> dict: + """Search using the Parallel SDK and return results as a dict.""" + from tools.interrupt import is_interrupted + if is_interrupted(): + return {"error": "Interrupted", "success": False} + + mode = os.getenv("PARALLEL_SEARCH_MODE", "agentic").lower().strip() + if mode not in ("fast", "one-shot", "agentic"): + mode = "agentic" + + logger.info("Parallel search: '%s' (mode=%s, limit=%d)", query, mode, limit) + response = _get_parallel_client().beta.search( + search_queries=[query], + objective=query, + mode=mode, + max_results=min(limit, 20), + ) + + web_results = [] + for i, result in enumerate(response.results or []): + excerpts = result.excerpts or [] + web_results.append({ + "url": result.url or "", + "title": result.title or "", + "description": " ".join(excerpts) if excerpts else "", + "position": i + 1, + }) + + return {"success": True, "data": {"web": web_results}} + + +async def _parallel_extract(urls: List[str]) -> List[Dict[str, Any]]: + """Extract content from URLs using the Parallel async SDK. + + Returns a list of result dicts matching the structure expected by the + LLM post-processing pipeline (url, title, content, metadata). + """ + from tools.interrupt import is_interrupted + if is_interrupted(): + return [{"url": u, "error": "Interrupted", "title": ""} for u in urls] + + logger.info("Parallel extract: %d URL(s)", len(urls)) + response = await _get_async_parallel_client().beta.extract( + urls=urls, + full_content=True, + ) + + results = [] + for result in response.results or []: + content = result.full_content or "" + if not content: + content = "\n\n".join(result.excerpts or []) + url = result.url or "" + title = result.title or "" + results.append({ + "url": url, + "title": title, + "content": content, + "raw_content": content, + "metadata": {"sourceURL": url, "title": title}, + }) + + for error in response.errors or []: + results.append({ + "url": error.url or "", + "title": "", + "content": "", + "error": error.content or error.error_type or "extraction failed", + "metadata": {"sourceURL": error.url or ""}, + }) + + return results + + def web_search_tool(query: str, limit: int = 5) -> str: """ Search the web for information using available search API backend. - + This function provides a generic interface for web search that can work - with multiple backends. Currently uses Firecrawl. - + with multiple backends (Parallel or Firecrawl). + Note: This function returns search result metadata only (URLs, titles, descriptions). Use web_extract_tool to get full content from specific URLs. @@ -477,17 +714,44 @@ def web_search_tool(query: str, limit: int = 5) -> str: if is_interrupted(): return json.dumps({"error": "Interrupted", "success": False}) + # Dispatch to the configured backend + backend = _get_backend() + if backend == "parallel": + response_data = _parallel_search(query, limit) + debug_call_data["results_count"] = len(response_data.get("data", {}).get("web", [])) + result_json = json.dumps(response_data, indent=2, ensure_ascii=False) + debug_call_data["final_response_size"] = len(result_json) + _debug.log_call("web_search_tool", debug_call_data) + _debug.save() + return result_json + + if backend == "tavily": + logger.info("Tavily search: '%s' (limit: %d)", query, limit) + raw = _tavily_request("search", { + "query": query, + "max_results": min(limit, 20), + "include_raw_content": False, + "include_images": False, + }) + response_data = _normalize_tavily_search_results(raw) + debug_call_data["results_count"] = len(response_data.get("data", {}).get("web", [])) + result_json = json.dumps(response_data, indent=2, ensure_ascii=False) + debug_call_data["final_response_size"] = len(result_json) + _debug.log_call("web_search_tool", debug_call_data) + _debug.save() + return result_json + logger.info("Searching the web for: '%s' (limit: %d)", query, limit) - + response = _get_firecrawl_client().search( query=query, limit=limit ) - + # The response is a SearchData object with web, news, and images attributes # When not scraping, the results are directly in these attributes web_results = [] - + # Check if response has web attribute (SearchData object) if hasattr(response, 'web'): # Response is a SearchData object with web attribute @@ -595,100 +859,137 @@ async def web_extract_tool( try: logger.info("Extracting content from %d URL(s)", len(urls)) - - # Determine requested formats for Firecrawl v2 - formats: List[str] = [] - if format == "markdown": - formats = ["markdown"] - elif format == "html": - formats = ["html"] - else: - # Default: request markdown for LLM-readiness and include html as backup - formats = ["markdown", "html"] - - # Always use individual scraping for simplicity and reliability - # Batch scraping adds complexity without much benefit for small numbers of URLs - results: List[Dict[str, Any]] = [] - - from tools.interrupt import is_interrupted as _is_interrupted - for url in urls: - if _is_interrupted(): - results.append({"url": url, "error": "Interrupted", "title": ""}) - continue - try: - logger.info("Scraping: %s", url) - scrape_result = _get_firecrawl_client().scrape( - url=url, - formats=formats - ) - - # Process the result - properly handle object serialization - metadata = {} - title = "" - content_markdown = None - content_html = None - - # Extract data from the scrape result - if hasattr(scrape_result, 'model_dump'): - # Pydantic model - use model_dump to get dict - result_dict = scrape_result.model_dump() - content_markdown = result_dict.get('markdown') - content_html = result_dict.get('html') - metadata = result_dict.get('metadata', {}) - elif hasattr(scrape_result, '__dict__'): - # Regular object with attributes - content_markdown = getattr(scrape_result, 'markdown', None) - content_html = getattr(scrape_result, 'html', None) - - # Handle metadata - convert to dict if it's an object - metadata_obj = getattr(scrape_result, 'metadata', {}) - if hasattr(metadata_obj, 'model_dump'): - metadata = metadata_obj.model_dump() - elif hasattr(metadata_obj, '__dict__'): - metadata = metadata_obj.__dict__ - elif isinstance(metadata_obj, dict): - metadata = metadata_obj - else: - metadata = {} - elif isinstance(scrape_result, dict): - # Already a dictionary - content_markdown = scrape_result.get('markdown') - content_html = scrape_result.get('html') - metadata = scrape_result.get('metadata', {}) - - # Ensure metadata is a dict (not an object) - if not isinstance(metadata, dict): - if hasattr(metadata, 'model_dump'): - metadata = metadata.model_dump() - elif hasattr(metadata, '__dict__'): - metadata = metadata.__dict__ - else: - metadata = {} - - # Get title from metadata - title = metadata.get("title", "") - - # Choose content based on requested format - chosen_content = content_markdown if (format == "markdown" or (format is None and content_markdown)) else content_html or content_markdown or "" - - results.append({ - "url": metadata.get("sourceURL", url), - "title": title, - "content": chosen_content, - "raw_content": chosen_content, - "metadata": metadata # Now guaranteed to be a dict - }) - - except Exception as scrape_err: - logger.debug("Scrape failed for %s: %s", url, scrape_err) - results.append({ - "url": url, - "title": "", - "content": "", - "raw_content": "", - "error": str(scrape_err) - }) + # Dispatch to the configured backend + backend = _get_backend() + + if backend == "parallel": + results = await _parallel_extract(urls) + elif backend == "tavily": + logger.info("Tavily extract: %d URL(s)", len(urls)) + raw = _tavily_request("extract", { + "urls": urls, + "include_images": False, + }) + results = _normalize_tavily_documents(raw, fallback_url=urls[0] if urls else "") + else: + # ── Firecrawl extraction ── + # Determine requested formats for Firecrawl v2 + formats: List[str] = [] + if format == "markdown": + formats = ["markdown"] + elif format == "html": + formats = ["html"] + else: + # Default: request markdown for LLM-readiness and include html as backup + formats = ["markdown", "html"] + + # Always use individual scraping for simplicity and reliability + # Batch scraping adds complexity without much benefit for small numbers of URLs + results: List[Dict[str, Any]] = [] + + from tools.interrupt import is_interrupted as _is_interrupted + for url in urls: + if _is_interrupted(): + results.append({"url": url, "error": "Interrupted", "title": ""}) + continue + + # Website policy check — block before fetching + blocked = check_website_access(url) + if blocked: + logger.info("Blocked web_extract for %s by rule %s", blocked["host"], blocked["rule"]) + results.append({ + "url": url, "title": "", "content": "", + "error": blocked["message"], + "blocked_by_policy": {"host": blocked["host"], "rule": blocked["rule"], "source": blocked["source"]}, + }) + continue + + try: + logger.info("Scraping: %s", url) + scrape_result = _get_firecrawl_client().scrape( + url=url, + formats=formats + ) + + # Process the result - properly handle object serialization + metadata = {} + title = "" + content_markdown = None + content_html = None + + # Extract data from the scrape result + if hasattr(scrape_result, 'model_dump'): + # Pydantic model - use model_dump to get dict + result_dict = scrape_result.model_dump() + content_markdown = result_dict.get('markdown') + content_html = result_dict.get('html') + metadata = result_dict.get('metadata', {}) + elif hasattr(scrape_result, '__dict__'): + # Regular object with attributes + content_markdown = getattr(scrape_result, 'markdown', None) + content_html = getattr(scrape_result, 'html', None) + + # Handle metadata - convert to dict if it's an object + metadata_obj = getattr(scrape_result, 'metadata', {}) + if hasattr(metadata_obj, 'model_dump'): + metadata = metadata_obj.model_dump() + elif hasattr(metadata_obj, '__dict__'): + metadata = metadata_obj.__dict__ + elif isinstance(metadata_obj, dict): + metadata = metadata_obj + else: + metadata = {} + elif isinstance(scrape_result, dict): + # Already a dictionary + content_markdown = scrape_result.get('markdown') + content_html = scrape_result.get('html') + metadata = scrape_result.get('metadata', {}) + + # Ensure metadata is a dict (not an object) + if not isinstance(metadata, dict): + if hasattr(metadata, 'model_dump'): + metadata = metadata.model_dump() + elif hasattr(metadata, '__dict__'): + metadata = metadata.__dict__ + else: + metadata = {} + + # Get title from metadata + title = metadata.get("title", "") + + # Re-check final URL after redirect + final_url = metadata.get("sourceURL", url) + final_blocked = check_website_access(final_url) + if final_blocked: + logger.info("Blocked redirected web_extract for %s by rule %s", final_blocked["host"], final_blocked["rule"]) + results.append({ + "url": final_url, "title": title, "content": "", "raw_content": "", + "error": final_blocked["message"], + "blocked_by_policy": {"host": final_blocked["host"], "rule": final_blocked["rule"], "source": final_blocked["source"]}, + }) + continue + + # Choose content based on requested format + chosen_content = content_markdown if (format == "markdown" or (format is None and content_markdown)) else content_html or content_markdown or "" + + results.append({ + "url": final_url, + "title": title, + "content": chosen_content, + "raw_content": chosen_content, + "metadata": metadata # Now guaranteed to be a dict + }) + + except Exception as scrape_err: + logger.debug("Scrape failed for %s: %s", url, scrape_err) + results.append({ + "url": url, + "title": "", + "content": "", + "raw_content": "", + "error": str(scrape_err) + }) response = {"results": results} @@ -778,6 +1079,7 @@ async def web_extract_tool( "title": r.get("title", ""), "content": r.get("content", ""), "error": r.get("error"), + **({ "blocked_by_policy": r["blocked_by_policy"]} if "blocked_by_policy" in r else {}), } for r in response.get("results", []) ] @@ -862,6 +1164,91 @@ async def web_crawl_tool( } try: + backend = _get_backend() + + # Tavily supports crawl via its /crawl endpoint + if backend == "tavily": + # Ensure URL has protocol + if not url.startswith(('http://', 'https://')): + url = f'https://{url}' + + # Website policy check + blocked = check_website_access(url) + if blocked: + logger.info("Blocked web_crawl for %s by rule %s", blocked["host"], blocked["rule"]) + return json.dumps({"results": [{"url": url, "title": "", "content": "", "error": blocked["message"], + "blocked_by_policy": {"host": blocked["host"], "rule": blocked["rule"], "source": blocked["source"]}}]}, ensure_ascii=False) + + from tools.interrupt import is_interrupted as _is_int + if _is_int(): + return json.dumps({"error": "Interrupted", "success": False}) + + logger.info("Tavily crawl: %s", url) + payload: Dict[str, Any] = { + "url": url, + "limit": 20, + "extract_depth": depth, + } + if instructions: + payload["instructions"] = instructions + raw = _tavily_request("crawl", payload) + results = _normalize_tavily_documents(raw, fallback_url=url) + + response = {"results": results} + # Fall through to the shared LLM processing and trimming below + # (skip the Firecrawl-specific crawl logic) + pages_crawled = len(response.get('results', [])) + logger.info("Crawled %d pages", pages_crawled) + debug_call_data["pages_crawled"] = pages_crawled + debug_call_data["original_response_size"] = len(json.dumps(response)) + + # Process each result with LLM if enabled + if use_llm_processing: + logger.info("Processing crawled content with LLM (parallel)...") + debug_call_data["processing_applied"].append("llm_processing") + + async def _process_tavily_crawl(result): + page_url = result.get('url', 'Unknown URL') + title = result.get('title', '') + content = result.get('content', '') + if not content: + return result, None, "no_content" + original_size = len(content) + processed = await process_content_with_llm(content, page_url, title, model, min_length) + if processed: + result['raw_content'] = content + result['content'] = processed + metrics = {"url": page_url, "original_size": original_size, "processed_size": len(processed), + "compression_ratio": len(processed) / original_size if original_size else 1.0, "model_used": model} + return result, metrics, "processed" + metrics = {"url": page_url, "original_size": original_size, "processed_size": original_size, + "compression_ratio": 1.0, "model_used": None, "reason": "content_too_short"} + return result, metrics, "too_short" + + tasks = [_process_tavily_crawl(r) for r in response.get('results', [])] + processed_results = await asyncio.gather(*tasks) + for result, metrics, status in processed_results: + if status == "processed": + debug_call_data["compression_metrics"].append(metrics) + debug_call_data["pages_processed_with_llm"] += 1 + + trimmed_results = [{"url": r.get("url", ""), "title": r.get("title", ""), "content": r.get("content", ""), "error": r.get("error"), + **({ "blocked_by_policy": r["blocked_by_policy"]} if "blocked_by_policy" in r else {})} for r in response.get("results", [])] + result_json = json.dumps({"results": trimmed_results}, indent=2, ensure_ascii=False) + cleaned_result = clean_base64_images(result_json) + debug_call_data["final_response_size"] = len(cleaned_result) + _debug.log_call("web_crawl_tool", debug_call_data) + _debug.save() + return cleaned_result + + # web_crawl requires Firecrawl — Parallel has no crawl API + if not (os.getenv("FIRECRAWL_API_KEY") or os.getenv("FIRECRAWL_API_URL")): + return json.dumps({ + "error": "web_crawl requires Firecrawl. Set FIRECRAWL_API_KEY, " + "or use web_search + web_extract instead.", + "success": False, + }, ensure_ascii=False) + # Ensure URL has protocol if not url.startswith(('http://', 'https://')): url = f'https://{url}' @@ -870,6 +1257,13 @@ async def web_crawl_tool( instructions_text = f" with instructions: '{instructions}'" if instructions else "" logger.info("Crawling %s%s", url, instructions_text) + # Website policy check — block before crawling + blocked = check_website_access(url) + if blocked: + logger.info("Blocked web_crawl for %s by rule %s", blocked["host"], blocked["rule"]) + return json.dumps({"results": [{"url": url, "title": "", "content": "", "error": blocked["message"], + "blocked_by_policy": {"host": blocked["host"], "rule": blocked["rule"], "source": blocked["source"]}}]}, ensure_ascii=False) + # Use Firecrawl's v2 crawl functionality # Docs: https://docs.firecrawl.dev/features/crawl # The crawl() method automatically waits for completion and returns all data @@ -975,6 +1369,17 @@ async def web_crawl_tool( page_url = metadata.get("sourceURL", metadata.get("url", "Unknown URL")) title = metadata.get("title", "") + # Re-check crawled page URL against policy + page_blocked = check_website_access(page_url) + if page_blocked: + logger.info("Blocked crawled page %s by rule %s", page_blocked["host"], page_blocked["rule"]) + pages.append({ + "url": page_url, "title": title, "content": "", "raw_content": "", + "error": page_blocked["message"], + "blocked_by_policy": {"host": page_blocked["host"], "rule": page_blocked["rule"], "source": page_blocked["source"]}, + }) + continue + # Choose content (prefer markdown) content = content_markdown or content_html or "" @@ -1070,9 +1475,11 @@ async def web_crawl_tool( # Trim output to minimal fields per entry: title, content, error trimmed_results = [ { + "url": r.get("url", ""), "title": r.get("title", ""), "content": r.get("content", ""), - "error": r.get("error") + "error": r.get("error"), + **({ "blocked_by_policy": r["blocked_by_policy"]} if "blocked_by_policy" in r else {}), } for r in response.get("results", []) ] @@ -1106,13 +1513,23 @@ async def web_crawl_tool( def check_firecrawl_api_key() -> bool: """ Check if the Firecrawl API key is available in environment variables. - + Returns: bool: True if API key is set, False otherwise """ return bool(os.getenv("FIRECRAWL_API_KEY")) +def check_web_api_key() -> bool: + """Check if any web backend API key is available (Parallel, Firecrawl, or Tavily).""" + return bool( + os.getenv("PARALLEL_API_KEY") + or os.getenv("FIRECRAWL_API_KEY") + or os.getenv("FIRECRAWL_API_URL") + or os.getenv("TAVILY_API_KEY") + ) + + def check_auxiliary_model() -> bool: """Check if an auxiliary text model is available for LLM content processing.""" try: @@ -1139,26 +1556,32 @@ if __name__ == "__main__": print("=" * 40) # Check if API keys are available - firecrawl_available = check_firecrawl_api_key() + web_available = check_web_api_key() nous_available = check_auxiliary_model() - - if not firecrawl_available: - print("❌ FIRECRAWL_API_KEY environment variable not set") - print("Please set your API key: export FIRECRAWL_API_KEY='your-key-here'") - print("Get API key at: https://firecrawl.dev/") + + if web_available: + backend = _get_backend() + print(f"✅ Web backend: {backend}") + if backend == "parallel": + print(" Using Parallel API (https://parallel.ai)") + elif backend == "tavily": + print(" Using Tavily API (https://tavily.com)") + else: + print(" Using Firecrawl API (https://firecrawl.dev)") else: - print("✅ Firecrawl API key found") - + print("❌ No web search backend configured") + print("Set PARALLEL_API_KEY, TAVILY_API_KEY, or FIRECRAWL_API_KEY") + if not nous_available: print("❌ No auxiliary model available for LLM content processing") print("Set OPENROUTER_API_KEY, configure Nous Portal, or set OPENAI_BASE_URL + OPENAI_API_KEY") print("⚠️ Without an auxiliary model, LLM content processing will be disabled") else: print(f"✅ Auxiliary model available: {DEFAULT_SUMMARIZER_MODEL}") - - if not firecrawl_available: + + if not web_available: exit(1) - + print("🛠️ Web tools ready for use!") if nous_available: @@ -1256,8 +1679,8 @@ registry.register( toolset="web", schema=WEB_SEARCH_SCHEMA, handler=lambda args, **kw: web_search_tool(args.get("query", ""), limit=5), - check_fn=check_firecrawl_api_key, - requires_env=["FIRECRAWL_API_KEY"], + check_fn=check_web_api_key, + requires_env=["PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "TAVILY_API_KEY"], emoji="🔍", ) registry.register( @@ -1266,8 +1689,8 @@ registry.register( schema=WEB_EXTRACT_SCHEMA, handler=lambda args, **kw: web_extract_tool( args.get("urls", [])[:5] if isinstance(args.get("urls"), list) else [], "markdown"), - check_fn=check_firecrawl_api_key, - requires_env=["FIRECRAWL_API_KEY"], + check_fn=check_web_api_key, + requires_env=["PARALLEL_API_KEY", "FIRECRAWL_API_KEY", "TAVILY_API_KEY"], is_async=True, emoji="📄", ) diff --git a/tools/website_policy.py b/tools/website_policy.py new file mode 100644 index 000000000..2a3d2470f --- /dev/null +++ b/tools/website_policy.py @@ -0,0 +1,285 @@ +"""Website access policy helpers for URL-capable tools. + +This module loads a user-managed website blocklist from ~/.hermes/config.yaml +and optional shared list files. It is intentionally lightweight so web/browser +tools can enforce URL policy without pulling in the heavier CLI config stack. + +Policy is cached in memory with a short TTL so config changes take effect +quickly without re-reading the file on every URL check. +""" + +from __future__ import annotations + +import fnmatch +import logging +import os +import threading +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple +from urllib.parse import urlparse + +logger = logging.getLogger(__name__) + +_DEFAULT_WEBSITE_BLOCKLIST = { + "enabled": False, + "domains": [], + "shared_files": [], +} + +# Cache: parsed policy + timestamp. Avoids re-reading config.yaml on every +# URL check (a web_crawl with 50 pages would otherwise mean 51 YAML parses). +_CACHE_TTL_SECONDS = 30.0 +_cache_lock = threading.Lock() +_cached_policy: Optional[Dict[str, Any]] = None +_cached_policy_path: Optional[str] = None +_cached_policy_time: float = 0.0 + + +def _get_hermes_home() -> Path: + return Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) + + +def _get_default_config_path() -> Path: + return _get_hermes_home() / "config.yaml" + + +class WebsitePolicyError(Exception): + """Raised when a website policy file is malformed.""" + + +def _normalize_host(host: str) -> str: + return (host or "").strip().lower().rstrip(".") + + +def _normalize_rule(rule: Any) -> Optional[str]: + if not isinstance(rule, str): + return None + value = rule.strip().lower() + if not value or value.startswith("#"): + return None + if "://" in value: + parsed = urlparse(value) + value = parsed.netloc or parsed.path + value = value.split("/", 1)[0].strip().rstrip(".") + if value.startswith("www."): + value = value[4:] + return value or None + + +def _iter_blocklist_file_rules(path: Path) -> List[str]: + """Load rules from a shared blocklist file. + + Missing or unreadable files log a warning and return an empty list + rather than raising — a bad file path should not disable all web tools. + """ + try: + raw = path.read_text(encoding="utf-8") + except FileNotFoundError: + logger.warning("Shared blocklist file not found (skipping): %s", path) + return [] + except (OSError, UnicodeDecodeError) as exc: + logger.warning("Failed to read shared blocklist file %s (skipping): %s", path, exc) + return [] + + rules: List[str] = [] + for line in raw.splitlines(): + stripped = line.strip() + if not stripped or stripped.startswith("#"): + continue + normalized = _normalize_rule(stripped) + if normalized: + rules.append(normalized) + return rules + + +def _load_policy_config(config_path: Optional[Path] = None) -> Dict[str, Any]: + config_path = config_path or _get_default_config_path() + if not config_path.exists(): + return dict(_DEFAULT_WEBSITE_BLOCKLIST) + + try: + import yaml + except ImportError: + logger.debug("PyYAML not installed — website blocklist disabled") + return dict(_DEFAULT_WEBSITE_BLOCKLIST) + + try: + with open(config_path, encoding="utf-8") as f: + config = yaml.safe_load(f) or {} + except yaml.YAMLError as exc: + raise WebsitePolicyError(f"Invalid config YAML at {config_path}: {exc}") from exc + except OSError as exc: + raise WebsitePolicyError(f"Failed to read config file {config_path}: {exc}") from exc + if not isinstance(config, dict): + raise WebsitePolicyError("config root must be a mapping") + + security = config.get("security", {}) + if security is None: + security = {} + if not isinstance(security, dict): + raise WebsitePolicyError("security must be a mapping") + + website_blocklist = security.get("website_blocklist", {}) + if website_blocklist is None: + website_blocklist = {} + if not isinstance(website_blocklist, dict): + raise WebsitePolicyError("security.website_blocklist must be a mapping") + + policy = dict(_DEFAULT_WEBSITE_BLOCKLIST) + policy.update(website_blocklist) + return policy + + +def load_website_blocklist(config_path: Optional[Path] = None) -> Dict[str, Any]: + """Load and return the parsed website blocklist policy. + + Results are cached for ``_CACHE_TTL_SECONDS`` to avoid re-reading + config.yaml on every URL check. Pass an explicit ``config_path`` + to bypass the cache (used by tests). + """ + global _cached_policy, _cached_policy_path, _cached_policy_time + + resolved_path = str(config_path) if config_path else "__default__" + now = time.monotonic() + + # Return cached policy if still fresh and same path + if config_path is None: + with _cache_lock: + if ( + _cached_policy is not None + and _cached_policy_path == resolved_path + and (now - _cached_policy_time) < _CACHE_TTL_SECONDS + ): + return _cached_policy + + config_path = config_path or _get_default_config_path() + policy = _load_policy_config(config_path) + + raw_domains = policy.get("domains", []) or [] + if not isinstance(raw_domains, list): + raise WebsitePolicyError("security.website_blocklist.domains must be a list") + + raw_shared_files = policy.get("shared_files", []) or [] + if not isinstance(raw_shared_files, list): + raise WebsitePolicyError("security.website_blocklist.shared_files must be a list") + + enabled = policy.get("enabled", True) + if not isinstance(enabled, bool): + raise WebsitePolicyError("security.website_blocklist.enabled must be a boolean") + + rules: List[Dict[str, str]] = [] + seen: set[Tuple[str, str]] = set() + + for raw_rule in raw_domains: + normalized = _normalize_rule(raw_rule) + if normalized and ("config", normalized) not in seen: + rules.append({"pattern": normalized, "source": "config"}) + seen.add(("config", normalized)) + + for shared_file in raw_shared_files: + if not isinstance(shared_file, str) or not shared_file.strip(): + continue + path = Path(shared_file).expanduser() + if not path.is_absolute(): + path = (_get_hermes_home() / path).resolve() + for normalized in _iter_blocklist_file_rules(path): + key = (str(path), normalized) + if key in seen: + continue + rules.append({"pattern": normalized, "source": str(path)}) + seen.add(key) + + result = {"enabled": enabled, "rules": rules} + + # Cache the result (only for the default path — explicit paths are tests) + if config_path == _get_default_config_path(): + with _cache_lock: + _cached_policy = result + _cached_policy_path = "__default__" + _cached_policy_time = now + + return result + + +def invalidate_cache() -> None: + """Force the next ``check_website_access`` call to re-read config.""" + global _cached_policy + with _cache_lock: + _cached_policy = None + + +def _match_host_against_rule(host: str, pattern: str) -> bool: + if not host or not pattern: + return False + if pattern.startswith("*."): + return fnmatch.fnmatch(host, pattern) + return host == pattern or host.endswith(f".{pattern}") + + +def _extract_host_from_urlish(url: str) -> str: + parsed = urlparse(url) + host = _normalize_host(parsed.hostname or parsed.netloc) + if host: + return host + + if "://" not in url: + schemeless = urlparse(f"//{url}") + host = _normalize_host(schemeless.hostname or schemeless.netloc) + if host: + return host + + return "" + + +def check_website_access(url: str, config_path: Optional[Path] = None) -> Optional[Dict[str, str]]: + """Check whether a URL is allowed by the website blocklist policy. + + Returns ``None`` if access is allowed, or a dict with block metadata + (``host``, ``rule``, ``source``, ``message``) if blocked. + + Never raises on policy errors — logs a warning and returns ``None`` + (fail-open) so a config typo doesn't break all web tools. Pass + ``config_path`` explicitly (tests) to get strict error propagation. + """ + # Fast path: if no explicit config_path and the cached policy is disabled + # or empty, skip all work (no YAML read, no host extraction). + if config_path is None: + with _cache_lock: + if _cached_policy is not None and not _cached_policy.get("enabled"): + return None + + host = _extract_host_from_urlish(url) + if not host: + return None + + try: + policy = load_website_blocklist(config_path) + except WebsitePolicyError as exc: + if config_path is not None: + raise # Tests pass explicit paths — let errors propagate + logger.warning("Website policy config error (failing open): %s", exc) + return None + except Exception as exc: + logger.warning("Unexpected error loading website policy (failing open): %s", exc) + return None + + if not policy.get("enabled"): + return None + + for rule in policy.get("rules", []): + pattern = rule.get("pattern", "") + if _match_host_against_rule(host, pattern): + logger.info("Blocked URL %s — matched rule '%s' from %s", + url, pattern, rule.get("source", "config")) + return { + "url": url, + "host": host, + "rule": pattern, + "source": rule.get("source", "config"), + "message": ( + f"Blocked by website policy: '{host}' matched rule '{pattern}'" + f" from {rule.get('source', 'config')}" + ), + } + return None diff --git a/toolsets.py b/toolsets.py index b7b2e48fb..212b6ea22 100644 --- a/toolsets.py +++ b/toolsets.py @@ -130,6 +130,12 @@ TOOLSETS = { "includes": [] }, + "messaging": { + "description": "Cross-platform messaging: send messages to Telegram, Discord, Slack, SMS, etc.", + "tools": ["send_message"], + "includes": [] + }, + "rl": { "description": "RL training tools for running reinforcement learning on Tinker-Atropos", "tools": [ @@ -293,7 +299,7 @@ TOOLSETS = { }, "hermes-sms": { - "description": "SMS bot toolset - interact with Hermes via SMS (Telnyx)", + "description": "SMS bot toolset - interact with Hermes via SMS (Twilio)", "tools": _HERMES_CORE_TOOLS, "includes": [] }, diff --git a/website/docs/getting-started/quickstart.md b/website/docs/getting-started/quickstart.md index 66be25fd6..0418d473c 100644 --- a/website/docs/getting-started/quickstart.md +++ b/website/docs/getting-started/quickstart.md @@ -49,6 +49,9 @@ hermes setup # Or configure everything at once | **Kimi / Moonshot** | Moonshot-hosted coding and chat models | Set `KIMI_API_KEY` | | **MiniMax** | International MiniMax endpoint | Set `MINIMAX_API_KEY` | | **MiniMax China** | China-region MiniMax endpoint | Set `MINIMAX_CN_API_KEY` | +| **Alibaba Cloud** | Qwen models via DashScope | Set `DASHSCOPE_API_KEY` | +| **Kilo Code** | KiloCode-hosted models | Set `KILOCODE_API_KEY` | +| **Vercel AI Gateway** | Vercel AI Gateway routing | Set `AI_GATEWAY_API_KEY` | | **Custom Endpoint** | VLLM, SGLang, or any OpenAI-compatible API | Set base URL + API key | :::tip diff --git a/website/docs/reference/environment-variables.md b/website/docs/reference/environment-variables.md index 0b5afa4b8..a594b7a60 100644 --- a/website/docs/reference/environment-variables.md +++ b/website/docs/reference/environment-variables.md @@ -32,6 +32,8 @@ All variables go in `~/.hermes/.env`. You can also set them with `hermes config | `KILOCODE_BASE_URL` | Override Kilo Code base URL (default: `https://api.kilo.ai/api/gateway`) | | `ANTHROPIC_API_KEY` | Anthropic Console API key ([console.anthropic.com](https://console.anthropic.com/)) | | `ANTHROPIC_TOKEN` | Manual or legacy Anthropic OAuth/setup-token override | +| `DASHSCOPE_API_KEY` | Alibaba Cloud DashScope API key for Qwen models ([modelstudio.console.alibabacloud.com](https://modelstudio.console.alibabacloud.com/)) | +| `DASHSCOPE_BASE_URL` | Custom DashScope base URL (default: international endpoint) | | `CLAUDE_CODE_OAUTH_TOKEN` | Explicit Claude Code token override if you export one manually | | `HERMES_MODEL` | Preferred model name (checked before `LLM_MODEL`, used by gateway) | | `LLM_MODEL` | Default model name (fallback when not set in config.yaml) | @@ -46,7 +48,7 @@ For native Anthropic auth, Hermes prefers Claude Code's own credential files whe | Variable | Description | |----------|-------------| -| `HERMES_INFERENCE_PROVIDER` | Override provider selection: `auto`, `openrouter`, `nous`, `openai-codex`, `anthropic`, `zai`, `kimi-coding`, `minimax`, `minimax-cn`, `kilocode` (default: `auto`) | +| `HERMES_INFERENCE_PROVIDER` | Override provider selection: `auto`, `openrouter`, `nous`, `openai-codex`, `anthropic`, `zai`, `kimi-coding`, `minimax`, `minimax-cn`, `kilocode`, `alibaba` (default: `auto`) | | `HERMES_PORTAL_BASE_URL` | Override Nous Portal URL (for development/testing) | | `NOUS_INFERENCE_BASE_URL` | Override Nous inference API URL | | `HERMES_NOUS_MIN_KEY_TTL_SECONDS` | Min agent key TTL before re-mint (default: 1800 = 30min) | @@ -59,10 +61,13 @@ For native Anthropic auth, Hermes prefers Claude Code's own credential files whe | Variable | Description | |----------|-------------| +| `PARALLEL_API_KEY` | AI-native web search ([parallel.ai](https://parallel.ai/)) | | `FIRECRAWL_API_KEY` | Web scraping ([firecrawl.dev](https://firecrawl.dev/)) | | `FIRECRAWL_API_URL` | Custom Firecrawl API endpoint for self-hosted instances (optional) | | `BROWSERBASE_API_KEY` | Browser automation ([browserbase.com](https://browserbase.com/)) | | `BROWSERBASE_PROJECT_ID` | Browserbase project ID | +| `BROWSER_USE_API_KEY` | Browser Use cloud browser API key ([browser-use.com](https://browser-use.com/)) | +| `BROWSER_CDP_URL` | Chrome DevTools Protocol URL for local browser (set via `/browser connect`, e.g. `ws://localhost:9222`) | | `BROWSER_INACTIVITY_TIMEOUT` | Browser session inactivity timeout in seconds | | `FAL_KEY` | Image generation ([fal.ai](https://fal.ai/)) | | `GROQ_API_KEY` | Groq Whisper STT API key ([groq.com](https://groq.com/)) | @@ -151,6 +156,14 @@ For native Anthropic auth, Hermes prefers Claude Code's own credential files whe | `SIGNAL_HOME_CHANNEL_NAME` | Display name for the Signal home channel | | `SIGNAL_IGNORE_STORIES` | Ignore Signal stories/status updates | | `SIGNAL_ALLOW_ALL_USERS` | Allow all Signal users without an allowlist | +| `TWILIO_ACCOUNT_SID` | Twilio Account SID (shared with telephony skill) | +| `TWILIO_AUTH_TOKEN` | Twilio Auth Token (shared with telephony skill) | +| `TWILIO_PHONE_NUMBER` | Twilio phone number in E.164 format (shared with telephony skill) | +| `SMS_WEBHOOK_PORT` | Webhook listener port for inbound SMS (default: `8080`) | +| `SMS_ALLOWED_USERS` | Comma-separated E.164 phone numbers allowed to chat | +| `SMS_ALLOW_ALL_USERS` | Allow all SMS senders without an allowlist | +| `SMS_HOME_CHANNEL` | Phone number for cron job / notification delivery | +| `SMS_HOME_CHANNEL_NAME` | Display name for the SMS home channel | | `EMAIL_ADDRESS` | Email address for the Email gateway adapter | | `EMAIL_PASSWORD` | Password or app password for the email account | | `EMAIL_IMAP_HOST` | IMAP hostname for the email adapter | @@ -162,6 +175,21 @@ For native Anthropic auth, Hermes prefers Claude Code's own credential files whe | `EMAIL_HOME_ADDRESS_NAME` | Display name for the email home target | | `EMAIL_POLL_INTERVAL` | Email polling interval in seconds | | `EMAIL_ALLOW_ALL_USERS` | Allow all inbound email senders | +| `DINGTALK_CLIENT_ID` | DingTalk bot AppKey from developer portal ([open.dingtalk.com](https://open.dingtalk.com)) | +| `DINGTALK_CLIENT_SECRET` | DingTalk bot AppSecret from developer portal | +| `DINGTALK_ALLOWED_USERS` | Comma-separated DingTalk user IDs allowed to message the bot | +| `MATTERMOST_URL` | Mattermost server URL (e.g. `https://mm.example.com`) | +| `MATTERMOST_TOKEN` | Bot token or personal access token for Mattermost | +| `MATTERMOST_ALLOWED_USERS` | Comma-separated Mattermost user IDs allowed to message the bot | +| `MATTERMOST_HOME_CHANNEL` | Channel ID for proactive message delivery (cron, notifications) | +| `MATTERMOST_REPLY_MODE` | Reply style: `thread` (threaded replies) or `off` (flat messages, default) | +| `MATRIX_HOMESERVER` | Matrix homeserver URL (e.g. `https://matrix.org`) | +| `MATRIX_ACCESS_TOKEN` | Matrix access token for bot authentication | +| `MATRIX_USER_ID` | Matrix user ID (e.g. `@hermes:matrix.org`) — required for password login, optional with access token | +| `MATRIX_PASSWORD` | Matrix password (alternative to access token) | +| `MATRIX_ALLOWED_USERS` | Comma-separated Matrix user IDs allowed to message the bot (e.g. `@alice:matrix.org`) | +| `MATRIX_HOME_ROOM` | Room ID for proactive message delivery (e.g. `!abc123:matrix.org`) | +| `MATRIX_ENCRYPTION` | Enable end-to-end encryption (`true`/`false`, default: `false`) | | `HASS_TOKEN` | Home Assistant Long-Lived Access Token (enables HA platform + tools) | | `HASS_URL` | Home Assistant URL (default: `http://homeassistant.local:8123`) | | `MESSAGING_CWD` | Working directory for terminal commands in messaging mode (default: `~`) | diff --git a/website/docs/reference/slash-commands.md b/website/docs/reference/slash-commands.md index c3de04697..3c8ee77d3 100644 --- a/website/docs/reference/slash-commands.md +++ b/website/docs/reference/slash-commands.md @@ -52,8 +52,9 @@ Type `/` in the CLI to open the autocomplete menu. Built-in commands are case-in | Command | Description | |---------|-------------| -| `/tools` | List available tools | +| `/tools [list\|disable\|enable] [name...]` | Manage tools: list available tools, or disable/enable specific tools for the current session. Disabling a tool removes it from the agent's toolset and triggers a session reset. | | `/toolsets` | List available toolsets | +| `/browser [connect\|disconnect\|status]` | Manage local Chrome CDP connection. `connect` attaches browser tools to a running Chrome instance (default: `ws://localhost:9222`). `disconnect` detaches. `status` shows current connection. Auto-launches Chrome if no debugger is detected. | | `/skills` | Search, install, inspect, or manage skills from online registries | | `/cron` | Manage scheduled tasks (list, add/create, edit, pause, resume, run, remove) | | `/reload-mcp` | Reload MCP servers from config.yaml | @@ -118,7 +119,7 @@ The messaging gateway supports the following built-in commands inside Telegram, ## Notes -- `/skin`, `/tools`, `/toolsets`, `/config`, `/prompt`, `/cron`, `/skills`, `/platforms`, `/paste`, and `/verbose` are **CLI-only** commands. +- `/skin`, `/tools`, `/toolsets`, `/browser`, `/config`, `/prompt`, `/cron`, `/skills`, `/platforms`, `/paste`, and `/verbose` are **CLI-only** commands. - `/status`, `/stop`, `/sethome`, `/resume`, and `/update` are **messaging-only** commands. - `/background`, `/voice`, `/reload-mcp`, and `/rollback` work in **both** the CLI and the messaging gateway. - `/voice join`, `/voice channel`, and `/voice leave` are only meaningful on Discord. diff --git a/website/docs/user-guide/configuration.md b/website/docs/user-guide/configuration.md index f18d803d4..032b46179 100644 --- a/website/docs/user-guide/configuration.md +++ b/website/docs/user-guide/configuration.md @@ -70,7 +70,9 @@ You need at least one way to connect to an LLM. Use `hermes model` to switch pro | **Kimi / Moonshot** | `KIMI_API_KEY` in `~/.hermes/.env` (provider: `kimi-coding`) | | **MiniMax** | `MINIMAX_API_KEY` in `~/.hermes/.env` (provider: `minimax`) | | **MiniMax China** | `MINIMAX_CN_API_KEY` in `~/.hermes/.env` (provider: `minimax-cn`) | +| **Alibaba Cloud** | `DASHSCOPE_API_KEY` in `~/.hermes/.env` (provider: `alibaba`, aliases: `dashscope`, `qwen`) | | **Kilo Code** | `KILOCODE_API_KEY` in `~/.hermes/.env` (provider: `kilocode`) | +| **Alibaba Cloud** | `DASHSCOPE_API_KEY` in `~/.hermes/.env` (provider: `alibaba`) | | **Custom Endpoint** | `hermes model` (saved in `config.yaml`) or `OPENAI_BASE_URL` + `OPENAI_API_KEY` in `~/.hermes/.env` | :::info Codex Note @@ -135,16 +137,20 @@ hermes chat --provider minimax --model MiniMax-Text-01 # MiniMax (China endpoint) hermes chat --provider minimax-cn --model MiniMax-Text-01 # Requires: MINIMAX_CN_API_KEY in ~/.hermes/.env + +# Alibaba Cloud / DashScope (Qwen models) +hermes chat --provider alibaba --model qwen-plus +# Requires: DASHSCOPE_API_KEY in ~/.hermes/.env ``` Or set the provider permanently in `config.yaml`: ```yaml model: - provider: "zai" # or: kimi-coding, minimax, minimax-cn + provider: "zai" # or: kimi-coding, minimax, minimax-cn, alibaba default: "glm-4-plus" ``` -Base URLs can be overridden with `GLM_BASE_URL`, `KIMI_BASE_URL`, `MINIMAX_BASE_URL`, or `MINIMAX_CN_BASE_URL` environment variables. +Base URLs can be overridden with `GLM_BASE_URL`, `KIMI_BASE_URL`, `MINIMAX_BASE_URL`, `MINIMAX_CN_BASE_URL`, or `DASHSCOPE_BASE_URL` environment variables. ## Custom & Self-Hosted LLM Providers @@ -872,6 +878,7 @@ This controls both the `text_to_speech` tool and spoken replies in voice mode (` display: tool_progress: all # off | new | all | verbose skin: default # Built-in or custom CLI skin (see user-guide/features/skins) + theme_mode: auto # auto | light | dark — color scheme for skin-aware rendering personality: "kawaii" # Legacy cosmetic field still surfaced in some summaries compact: false # Compact output mode (less whitespace) resume_display: full # full (show previous messages on resume) | minimal (one-liner only) @@ -881,6 +888,18 @@ display: background_process_notifications: all # all | result | error | off (gateway only) ``` +### Theme mode + +The `theme_mode` setting controls whether skins render in light or dark mode: + +| Mode | Behavior | +|------|----------| +| `auto` (default) | Detects your terminal's background color automatically. Falls back to `dark` if detection fails. | +| `light` | Forces light-mode skin colors. Skins that define a `colors_light` override use those colors instead of the default dark-mode palette. | +| `dark` | Forces dark-mode skin colors. | + +This works with any skin — built-in or custom. Skin authors can provide `colors_light` in their skin definition for optimal light-terminal appearance. + | Mode | What you see | |------|-------------| | `off` | Silent — just the final response | @@ -1055,6 +1074,54 @@ browser: record_sessions: false # Auto-record browser sessions as WebM videos to ~/.hermes/browser_recordings/ ``` +The browser toolset supports multiple providers. See the [Browser feature page](/docs/user-guide/features/browser) for details on Browserbase, Browser Use, and local Chrome CDP setup. + +## Website Blocklist + +Block specific domains from being accessed by the agent's web and browser tools: + +```yaml +website_blocklist: + enabled: false # Enable URL blocking (default: false) + domains: # List of blocked domain patterns + - "*.internal.company.com" + - "admin.example.com" + - "*.local" + shared_files: # Load additional rules from external files + - "/etc/hermes/blocked-sites.txt" +``` + +When enabled, any URL matching a blocked domain pattern is rejected before the web or browser tool executes. This applies to `web_search`, `web_extract`, `browser_navigate`, and any tool that accesses URLs. + +Domain rules support: +- Exact domains: `admin.example.com` +- Wildcard subdomains: `*.internal.company.com` (blocks all subdomains) +- TLD wildcards: `*.local` + +Shared files contain one domain rule per line (blank lines and `#` comments are ignored). Missing or unreadable files log a warning but don't disable other web tools. + +The policy is cached for 30 seconds, so config changes take effect quickly without restart. + +## Smart Approvals + +Control how Hermes handles potentially dangerous commands: + +```yaml +approval_mode: ask # ask | smart | off +``` + +| Mode | Behavior | +|------|----------| +| `ask` (default) | Prompt the user before executing any flagged command. In the CLI, shows an interactive approval dialog. In messaging, queues a pending approval request. | +| `smart` | Use an auxiliary LLM to assess whether a flagged command is actually dangerous. Low-risk commands are auto-approved with session-level persistence. Genuinely risky commands are escalated to the user. | +| `off` | Skip all approval checks. Equivalent to `HERMES_YOLO_MODE=true`. **Use with caution.** | + +Smart mode is particularly useful for reducing approval fatigue — it lets the agent work more autonomously on safe operations while still catching genuinely destructive commands. + +:::warning +Setting `approval_mode: off` disables all safety checks for terminal commands. Only use this in trusted, sandboxed environments. +::: + ## Checkpoints Automatic filesystem snapshots before destructive file operations. See the [Checkpoints feature page](/docs/user-guide/features/checkpoints) for details. diff --git a/website/docs/user-guide/features/browser.md b/website/docs/user-guide/features/browser.md index ad6e6df81..0f7b2570c 100644 --- a/website/docs/user-guide/features/browser.md +++ b/website/docs/user-guide/features/browser.md @@ -1,27 +1,30 @@ --- title: Browser Automation -description: Control cloud browsers with Browserbase integration for web interaction, form filling, scraping, and more. +description: Control browsers with multiple providers, local Chrome via CDP, or cloud browsers for web interaction, form filling, scraping, and more. sidebar_label: Browser sidebar_position: 5 --- # Browser Automation -Hermes Agent includes a full browser automation toolset that can run in two modes: +Hermes Agent includes a full browser automation toolset with multiple backend options: - **Browserbase cloud mode** via [Browserbase](https://browserbase.com) for managed cloud browsers and anti-bot tooling +- **Browser Use cloud mode** via [Browser Use](https://browser-use.com) as an alternative cloud browser provider +- **Local Chrome via CDP** — connect browser tools to your own Chrome instance using `/browser connect` - **Local browser mode** via the `agent-browser` CLI and a local Chromium installation -In both modes, the agent can navigate websites, interact with page elements, fill forms, and extract information. +In all modes, the agent can navigate websites, interact with page elements, fill forms, and extract information. ## Overview -The browser tools use the `agent-browser` CLI. In Browserbase mode, `agent-browser` connects to Browserbase cloud sessions. In local mode, it drives a local Chromium installation. Pages are represented as **accessibility trees** (text-based snapshots), making them ideal for LLM agents. Interactive elements get ref IDs (like `@e1`, `@e2`) that the agent uses for clicking and typing. +Pages are represented as **accessibility trees** (text-based snapshots), making them ideal for LLM agents. Interactive elements get ref IDs (like `@e1`, `@e2`) that the agent uses for clicking and typing. Key capabilities: -- **Cloud execution** — no local browser needed -- **Built-in stealth** — random fingerprints, CAPTCHA solving, residential proxies +- **Multi-provider cloud execution** — Browserbase or Browser Use, no local browser needed +- **Local Chrome integration** — attach to your running Chrome via CDP for hands-on browsing +- **Built-in stealth** — random fingerprints, CAPTCHA solving, residential proxies (Browserbase) - **Session isolation** — each task gets its own browser session - **Automatic cleanup** — inactive sessions are closed after a timeout - **Vision analysis** — screenshot + AI analysis for visual understanding @@ -40,9 +43,48 @@ BROWSERBASE_PROJECT_ID=your-project-id-here Get your credentials at [browserbase.com](https://browserbase.com). +### Browser Use cloud mode + +To use Browser Use as your cloud browser provider, add: + +```bash +# Add to ~/.hermes/.env +BROWSER_USE_API_KEY=*** +``` + +Get your API key at [browser-use.com](https://browser-use.com). Browser Use provides a cloud browser via its REST API. If both Browserbase and Browser Use credentials are set, Browserbase takes priority. + +### Local Chrome via CDP (`/browser connect`) + +Instead of a cloud provider, you can attach Hermes browser tools to your own running Chrome instance via the Chrome DevTools Protocol (CDP). This is useful when you want to see what the agent is doing in real-time, interact with pages that require your own cookies/sessions, or avoid cloud browser costs. + +In the CLI, use: + +``` +/browser connect # Connect to Chrome at ws://localhost:9222 +/browser connect ws://host:port # Connect to a specific CDP endpoint +/browser status # Check current connection +/browser disconnect # Detach and return to cloud/local mode +``` + +If Chrome isn't already running with remote debugging, Hermes will attempt to auto-launch it with `--remote-debugging-port=9222`. + +:::tip +To start Chrome manually with CDP enabled: +```bash +# Linux +google-chrome --remote-debugging-port=9222 + +# macOS +"/Applications/Google Chrome.app/Contents/MacOS/Google Chrome" --remote-debugging-port=9222 +``` +::: + +When connected via CDP, all browser tools (`browser_navigate`, `browser_click`, etc.) operate on your live Chrome instance instead of spinning up a cloud session. + ### Local browser mode -If you do **not** set Browserbase credentials, Hermes can still use the browser tools through a local Chromium install driven by `agent-browser`. +If you do **not** set any cloud credentials and don't use `/browser connect`, Hermes can still use the browser tools through a local Chromium install driven by `agent-browser`. ### Optional Environment Variables @@ -232,10 +274,8 @@ If paid features aren't available on your plan, Hermes automatically falls back ## Limitations -- **Requires Browserbase account** — no local browser fallback -- **Requires `agent-browser` CLI** — must be installed via npm - **Text-based interaction** — relies on accessibility tree, not pixel coordinates - **Snapshot size** — large pages may be truncated or LLM-summarized at 8000 characters -- **Session timeout** — sessions expire based on your Browserbase plan settings -- **Cost** — each session consumes Browserbase credits; use `browser_close` when done +- **Session timeout** — cloud sessions expire based on your provider's plan settings +- **Cost** — cloud sessions consume provider credits; use `browser_close` when done. Use `/browser connect` for free local browsing. - **No file downloads** — cannot download files from the browser diff --git a/website/docs/user-guide/messaging/dingtalk.md b/website/docs/user-guide/messaging/dingtalk.md new file mode 100644 index 000000000..f7f5a00d2 --- /dev/null +++ b/website/docs/user-guide/messaging/dingtalk.md @@ -0,0 +1,192 @@ +--- +sidebar_position: 10 +title: "DingTalk" +description: "Set up Hermes Agent as a DingTalk chatbot" +--- + +# DingTalk Setup + +Hermes Agent integrates with DingTalk (钉钉) as a chatbot, letting you chat with your AI assistant through direct messages or group chats. The bot connects via DingTalk's Stream Mode — a long-lived WebSocket connection that requires no public URL or webhook server — and replies using markdown-formatted messages through DingTalk's session webhook API. + +Before setup, here's the part most people want to know: how Hermes behaves once it's in your DingTalk workspace. + +## How Hermes Behaves + +| Context | Behavior | +|---------|----------| +| **DMs (1:1 chat)** | Hermes responds to every message. No `@mention` needed. Each DM has its own session. | +| **Group chats** | Hermes responds when you `@mention` it. Without a mention, Hermes ignores the message. | +| **Shared groups with multiple users** | By default, Hermes isolates session history per user inside the group. Two people talking in the same group do not share one transcript unless you explicitly disable that. | + +### Session Model in DingTalk + +By default: + +- each DM gets its own session +- each user in a shared group chat gets their own session inside that group + +This is controlled by `config.yaml`: + +```yaml +group_sessions_per_user: true +``` + +Set it to `false` only if you explicitly want one shared conversation for the entire group: + +```yaml +group_sessions_per_user: false +``` + +This guide walks you through the full setup process — from creating your DingTalk bot to sending your first message. + +## Prerequisites + +Install the required Python packages: + +```bash +pip install dingtalk-stream httpx +``` + +- `dingtalk-stream` — DingTalk's official SDK for Stream Mode (WebSocket-based real-time messaging) +- `httpx` — async HTTP client used for sending replies via session webhooks + +## Step 1: Create a DingTalk App + +1. Go to the [DingTalk Developer Console](https://open-dev.dingtalk.com/). +2. Log in with your DingTalk admin account. +3. Click **Application Development** → **Custom Apps** → **Create App via H5 Micro-App** (or **Robot** depending on your console version). +4. Fill in: + - **App Name**: e.g., `Hermes Agent` + - **Description**: optional +5. After creating, navigate to **Credentials & Basic Info** to find your **Client ID** (AppKey) and **Client Secret** (AppSecret). Copy both. + +:::warning[Credentials shown only once] +The Client Secret is only displayed once when you create the app. If you lose it, you'll need to regenerate it. Never share these credentials publicly or commit them to Git. +::: + +## Step 2: Enable the Robot Capability + +1. In your app's settings page, go to **Add Capability** → **Robot**. +2. Enable the robot capability. +3. Under **Message Reception Mode**, select **Stream Mode** (recommended — no public URL needed). + +:::tip +Stream Mode is the recommended setup. It uses a long-lived WebSocket connection initiated from your machine, so you don't need a public IP, domain name, or webhook endpoint. This works behind NAT, firewalls, and on local machines. +::: + +## Step 3: Find Your DingTalk User ID + +Hermes Agent uses your DingTalk User ID to control who can interact with the bot. DingTalk User IDs are alphanumeric strings set by your organization's admin. + +To find yours: + +1. Ask your DingTalk organization admin — User IDs are configured in the DingTalk admin console under **Contacts** → **Members**. +2. Alternatively, the bot logs the `sender_id` for each incoming message. Start the gateway, send the bot a message, then check the logs for your ID. + +## Step 4: Configure Hermes Agent + +### Option A: Interactive Setup (Recommended) + +Run the guided setup command: + +```bash +hermes gateway setup +``` + +Select **DingTalk** when prompted, then paste your Client ID, Client Secret, and allowed user IDs when asked. + +### Option B: Manual Configuration + +Add the following to your `~/.hermes/.env` file: + +```bash +# Required +DINGTALK_CLIENT_ID=your-app-key +DINGTALK_CLIENT_SECRET=your-app-secret + +# Security: restrict who can interact with the bot +DINGTALK_ALLOWED_USERS=user-id-1 + +# Multiple allowed users (comma-separated) +# DINGTALK_ALLOWED_USERS=user-id-1,user-id-2 +``` + +Optional behavior settings in `~/.hermes/config.yaml`: + +```yaml +group_sessions_per_user: true +``` + +- `group_sessions_per_user: true` keeps each participant's context isolated inside shared group chats + +### Start the Gateway + +Once configured, start the DingTalk gateway: + +```bash +hermes gateway +``` + +The bot should connect to DingTalk's Stream Mode within a few seconds. Send it a message — either a DM or in a group where it's been added — to test. + +:::tip +You can run `hermes gateway` in the background or as a systemd service for persistent operation. See the deployment docs for details. +::: + +## Troubleshooting + +### Bot is not responding to messages + +**Cause**: The robot capability isn't enabled, or `DINGTALK_ALLOWED_USERS` doesn't include your User ID. + +**Fix**: Verify the robot capability is enabled in your app settings and that Stream Mode is selected. Check that your User ID is in `DINGTALK_ALLOWED_USERS`. Restart the gateway. + +### "dingtalk-stream not installed" error + +**Cause**: The `dingtalk-stream` Python package is not installed. + +**Fix**: Install it: + +```bash +pip install dingtalk-stream httpx +``` + +### "DINGTALK_CLIENT_ID and DINGTALK_CLIENT_SECRET required" + +**Cause**: The credentials aren't set in your environment or `.env` file. + +**Fix**: Verify `DINGTALK_CLIENT_ID` and `DINGTALK_CLIENT_SECRET` are set correctly in `~/.hermes/.env`. The Client ID is your AppKey, and the Client Secret is your AppSecret from the DingTalk Developer Console. + +### Stream disconnects / reconnection loops + +**Cause**: Network instability, DingTalk platform maintenance, or credential issues. + +**Fix**: The adapter automatically reconnects with exponential backoff (2s → 5s → 10s → 30s → 60s). Check that your credentials are valid and your app hasn't been deactivated. Verify your network allows outbound WebSocket connections. + +### Bot is offline + +**Cause**: The Hermes gateway isn't running, or it failed to connect. + +**Fix**: Check that `hermes gateway` is running. Look at the terminal output for error messages. Common issues: wrong credentials, app deactivated, `dingtalk-stream` or `httpx` not installed. + +### "No session_webhook available" + +**Cause**: The bot tried to reply but doesn't have a session webhook URL. This typically happens if the webhook expired or the bot was restarted between receiving the message and sending the reply. + +**Fix**: Send a new message to the bot — each incoming message provides a fresh session webhook for replies. This is a normal DingTalk limitation; the bot can only reply to messages it has received recently. + +## Security + +:::warning +Always set `DINGTALK_ALLOWED_USERS` to restrict who can interact with the bot. Without it, the gateway denies all users by default as a safety measure. Only add User IDs of people you trust — authorized users have full access to the agent's capabilities, including tool use and system access. +::: + +For more information on securing your Hermes Agent deployment, see the [Security Guide](../security.md). + +## Notes + +- **Stream Mode**: No public URL, domain name, or webhook server needed. The connection is initiated from your machine via WebSocket, so it works behind NAT and firewalls. +- **Markdown responses**: Replies are formatted in DingTalk's markdown format for rich text display. +- **Message deduplication**: The adapter deduplicates messages with a 5-minute window to prevent processing the same message twice. +- **Auto-reconnection**: If the stream connection drops, the adapter automatically reconnects with exponential backoff. +- **Message length limit**: Responses are capped at 20,000 characters per message. Longer responses are truncated. diff --git a/website/docs/user-guide/messaging/index.md b/website/docs/user-guide/messaging/index.md index 0c17e65e6..c969b451d 100644 --- a/website/docs/user-guide/messaging/index.md +++ b/website/docs/user-guide/messaging/index.md @@ -1,12 +1,12 @@ --- sidebar_position: 1 title: "Messaging Gateway" -description: "Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, Email, Home Assistant, or your browser — architecture and setup overview" +description: "Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, SMS, Email, Home Assistant, Mattermost, Matrix, DingTalk, or your browser — architecture and setup overview" --- # Messaging Gateway -Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, Email, Home Assistant, or your browser. The gateway is a single background process that connects to all your configured platforms, handles sessions, runs cron jobs, and delivers voice messages. +Chat with Hermes from Telegram, Discord, Slack, WhatsApp, Signal, SMS, Email, Home Assistant, Mattermost, Matrix, DingTalk, or your browser. The gateway is a single background process that connects to all your configured platforms, handles sessions, runs cron jobs, and delivers voice messages. For the full voice feature set — including CLI microphone mode, spoken replies in messaging, and Discord voice-channel conversations — see [Voice Mode](/docs/user-guide/features/voice-mode) and [Use Voice Mode with Hermes](/docs/guides/use-voice-mode-with-hermes). @@ -21,8 +21,12 @@ flowchart TB wa[WhatsApp] sl[Slack] sig[Signal] + sms[SMS] em[Email] ha[Home Assistant] + mm[Mattermost] + mx[Matrix] + dt[DingTalk] end store["Session store
per chat"] @@ -35,8 +39,12 @@ flowchart TB wa --> store sl --> store sig --> store + sms --> store em --> store ha --> store + mm --> store + mx --> store + dt --> store store --> agent cron --> store ``` @@ -129,7 +137,11 @@ Configure per-platform overrides in `~/.hermes/gateway.json`: TELEGRAM_ALLOWED_USERS=123456789,987654321 DISCORD_ALLOWED_USERS=123456789012345678 SIGNAL_ALLOWED_USERS=+155****4567,+155****6543 +SMS_ALLOWED_USERS=+155****4567,+155****6543 EMAIL_ALLOWED_USERS=trusted@example.com,colleague@work.com +MATTERMOST_ALLOWED_USERS=3uo8dkh1p7g1mfk49ear5fzs5c +MATRIX_ALLOWED_USERS=@alice:matrix.org +DINGTALK_ALLOWED_USERS=user-id-1 # Or allow GATEWAY_ALLOWED_USERS=123456789,987654321 @@ -288,8 +300,12 @@ Each platform has its own toolset: | WhatsApp | `hermes-whatsapp` | Full tools including terminal | | Slack | `hermes-slack` | Full tools including terminal | | Signal | `hermes-signal` | Full tools including terminal | +| SMS | `hermes-sms` | Full tools including terminal | | Email | `hermes-email` | Full tools including terminal | | Home Assistant | `hermes-homeassistant` | Full tools + HA device control (ha_list_entities, ha_get_state, ha_call_service, ha_list_services) | +| Mattermost | `hermes-mattermost` | Full tools including terminal | +| Matrix | `hermes-matrix` | Full tools including terminal | +| DingTalk | `hermes-dingtalk` | Full tools including terminal | ## Next Steps @@ -298,5 +314,9 @@ Each platform has its own toolset: - [Slack Setup](slack.md) - [WhatsApp Setup](whatsapp.md) - [Signal Setup](signal.md) +- [SMS Setup (Twilio)](sms.md) - [Email Setup](email.md) - [Home Assistant Integration](homeassistant.md) +- [Mattermost Setup](mattermost.md) +- [Matrix Setup](matrix.md) +- [DingTalk Setup](dingtalk.md) diff --git a/website/docs/user-guide/messaging/matrix.md b/website/docs/user-guide/messaging/matrix.md new file mode 100644 index 000000000..020e15bd6 --- /dev/null +++ b/website/docs/user-guide/messaging/matrix.md @@ -0,0 +1,354 @@ +--- +sidebar_position: 9 +title: "Matrix" +description: "Set up Hermes Agent as a Matrix bot" +--- + +# Matrix Setup + +Hermes Agent integrates with Matrix, the open, federated messaging protocol. Matrix lets you run your own homeserver or use a public one like matrix.org — either way, you keep control of your communications. The bot connects via the `matrix-nio` Python SDK, processes messages through the Hermes Agent pipeline (including tool use, memory, and reasoning), and responds in real time. It supports text, file attachments, images, audio, video, and optional end-to-end encryption (E2EE). + +Hermes works with any Matrix homeserver — Synapse, Conduit, Dendrite, or matrix.org. + +Before setup, here's the part most people want to know: how Hermes behaves once it's connected. + +## How Hermes Behaves + +| Context | Behavior | +|---------|----------| +| **DMs** | Hermes responds to every message. No `@mention` needed. Each DM has its own session. | +| **Rooms** | Hermes responds to all messages in rooms it has joined. Room invites are auto-accepted. | +| **Threads** | Hermes supports Matrix threads (MSC3440). If you reply in a thread, Hermes keeps the thread context isolated from the main room timeline. | +| **Shared rooms with multiple users** | By default, Hermes isolates session history per user inside the room. Two people talking in the same room do not share one transcript unless you explicitly disable that. | + +:::tip +The bot automatically joins rooms when invited. Just invite the bot's Matrix user to any room and it will join and start responding. +::: + +### Session Model in Matrix + +By default: + +- each DM gets its own session +- each thread gets its own session namespace +- each user in a shared room gets their own session inside that room + +This is controlled by `config.yaml`: + +```yaml +group_sessions_per_user: true +``` + +Set it to `false` only if you explicitly want one shared conversation for the entire room: + +```yaml +group_sessions_per_user: false +``` + +Shared sessions can be useful for a collaborative room, but they also mean: + +- users share context growth and token costs +- one person's long tool-heavy task can bloat everyone else's context +- one person's in-flight run can interrupt another person's follow-up in the same room + +This guide walks you through the full setup process — from creating your bot account to sending your first message. + +## Step 1: Create a Bot Account + +You need a Matrix user account for the bot. There are several ways to do this: + +### Option A: Register on Your Homeserver (Recommended) + +If you run your own homeserver (Synapse, Conduit, Dendrite): + +1. Use the admin API or registration tool to create a new user: + +```bash +# Synapse example +register_new_matrix_user -c /etc/synapse/homeserver.yaml http://localhost:8008 +``` + +2. Choose a username like `hermes` — the full user ID will be `@hermes:your-server.org`. + +### Option B: Use matrix.org or Another Public Homeserver + +1. Go to [Element Web](https://app.element.io) and create a new account. +2. Pick a username for your bot (e.g., `hermes-bot`). + +### Option C: Use Your Own Account + +You can also run Hermes as your own user. This means the bot posts as you — useful for personal assistants. + +## Step 2: Get an Access Token + +Hermes needs an access token to authenticate with the homeserver. You have two options: + +### Option A: Access Token (Recommended) + +The most reliable way to get a token: + +**Via Element:** +1. Log in to [Element](https://app.element.io) with the bot account. +2. Go to **Settings** → **Help & About**. +3. Scroll down and expand **Advanced** — the access token is displayed there. +4. **Copy it immediately.** + +**Via the API:** + +```bash +curl -X POST https://your-server/_matrix/client/v3/login \ + -H "Content-Type: application/json" \ + -d '{ + "type": "m.login.password", + "user": "@hermes:your-server.org", + "password": "your-password" + }' +``` + +The response includes an `access_token` field — copy it. + +:::warning[Keep your access token safe] +The access token gives full access to the bot's Matrix account. Never share it publicly or commit it to Git. If compromised, revoke it by logging out all sessions for that user. +::: + +### Option B: Password Login + +Instead of providing an access token, you can give Hermes the bot's user ID and password. Hermes will log in automatically on startup. This is simpler but means the password is stored in your `.env` file. + +```bash +MATRIX_USER_ID=@hermes:your-server.org +MATRIX_PASSWORD=your-password +``` + +## Step 3: Find Your Matrix User ID + +Hermes Agent uses your Matrix User ID to control who can interact with the bot. Matrix User IDs follow the format `@username:server`. + +To find yours: + +1. Open [Element](https://app.element.io) (or your preferred Matrix client). +2. Click your avatar → **Settings**. +3. Your User ID is displayed at the top of the profile (e.g., `@alice:matrix.org`). + +:::tip +Matrix User IDs always start with `@` and contain a `:` followed by the server name. For example: `@alice:matrix.org`, `@bob:your-server.com`. +::: + +## Step 4: Configure Hermes Agent + +### Option A: Interactive Setup (Recommended) + +Run the guided setup command: + +```bash +hermes gateway setup +``` + +Select **Matrix** when prompted, then provide your homeserver URL, access token (or user ID + password), and allowed user IDs when asked. + +### Option B: Manual Configuration + +Add the following to your `~/.hermes/.env` file: + +**Using an access token:** + +```bash +# Required +MATRIX_HOMESERVER=https://matrix.example.org +MATRIX_ACCESS_TOKEN=*** + +# Optional: user ID (auto-detected from token if omitted) +# MATRIX_USER_ID=@hermes:matrix.example.org + +# Security: restrict who can interact with the bot +MATRIX_ALLOWED_USERS=@alice:matrix.example.org + +# Multiple allowed users (comma-separated) +# MATRIX_ALLOWED_USERS=@alice:matrix.example.org,@bob:matrix.example.org +``` + +**Using password login:** + +```bash +# Required +MATRIX_HOMESERVER=https://matrix.example.org +MATRIX_USER_ID=@hermes:matrix.example.org +MATRIX_PASSWORD=*** + +# Security +MATRIX_ALLOWED_USERS=@alice:matrix.example.org +``` + +Optional behavior settings in `~/.hermes/config.yaml`: + +```yaml +group_sessions_per_user: true +``` + +- `group_sessions_per_user: true` keeps each participant's context isolated inside shared rooms + +### Start the Gateway + +Once configured, start the Matrix gateway: + +```bash +hermes gateway +``` + +The bot should connect to your homeserver and start syncing within a few seconds. Send it a message — either a DM or in a room it has joined — to test. + +:::tip +You can run `hermes gateway` in the background or as a systemd service for persistent operation. See the deployment docs for details. +::: + +## End-to-End Encryption (E2EE) + +Hermes supports Matrix end-to-end encryption, so you can chat with your bot in encrypted rooms. + +### Requirements + +E2EE requires the `matrix-nio` library with encryption extras and the `libolm` C library: + +```bash +# Install matrix-nio with E2EE support +pip install 'matrix-nio[e2e]' + +# Or install with hermes extras +pip install 'hermes-agent[matrix]' +``` + +You also need `libolm` installed on your system: + +```bash +# Debian/Ubuntu +sudo apt install libolm-dev + +# macOS +brew install libolm + +# Fedora +sudo dnf install libolm-devel +``` + +### Enable E2EE + +Add to your `~/.hermes/.env`: + +```bash +MATRIX_ENCRYPTION=true +``` + +When E2EE is enabled, Hermes: + +- Stores encryption keys in `~/.hermes/matrix/store/` +- Uploads device keys on first connection +- Decrypts incoming messages and encrypts outgoing messages automatically +- Auto-joins encrypted rooms when invited + +:::warning +If you delete the `~/.hermes/matrix/store/` directory, the bot loses its encryption keys. You'll need to verify the device again in your Matrix client. Back up this directory if you want to preserve encrypted sessions. +::: + +:::info +If `matrix-nio[e2e]` is not installed or `libolm` is missing, the bot falls back to a plain (unencrypted) client automatically. You'll see a warning in the logs. +::: + +## Home Room + +You can designate a "home room" where the bot sends proactive messages (such as cron job output, reminders, and notifications). There are two ways to set it: + +### Using the Slash Command + +Type `/sethome` in any Matrix room where the bot is present. That room becomes the home room. + +### Manual Configuration + +Add this to your `~/.hermes/.env`: + +```bash +MATRIX_HOME_ROOM=!abc123def456:matrix.example.org +``` + +:::tip +To find a Room ID: in Element, go to the room → **Settings** → **Advanced** → the **Internal room ID** is shown there (starts with `!`). +::: + +## Troubleshooting + +### Bot is not responding to messages + +**Cause**: The bot hasn't joined the room, or `MATRIX_ALLOWED_USERS` doesn't include your User ID. + +**Fix**: Invite the bot to the room — it auto-joins on invite. Verify your User ID is in `MATRIX_ALLOWED_USERS` (use the full `@user:server` format). Restart the gateway. + +### "Failed to authenticate" / "whoami failed" on startup + +**Cause**: The access token or homeserver URL is incorrect. + +**Fix**: Verify `MATRIX_HOMESERVER` points to your homeserver (include `https://`, no trailing slash). Check that `MATRIX_ACCESS_TOKEN` is valid — try it with curl: + +```bash +curl -H "Authorization: Bearer YOUR_TOKEN" \ + https://your-server/_matrix/client/v3/account/whoami +``` + +If this returns your user info, the token is valid. If it returns an error, generate a new token. + +### "matrix-nio not installed" error + +**Cause**: The `matrix-nio` Python package is not installed. + +**Fix**: Install it: + +```bash +pip install 'matrix-nio[e2e]' +``` + +Or with Hermes extras: + +```bash +pip install 'hermes-agent[matrix]' +``` + +### Encryption errors / "could not decrypt event" + +**Cause**: Missing encryption keys, `libolm` not installed, or the bot's device isn't trusted. + +**Fix**: +1. Verify `libolm` is installed on your system (see the E2EE section above). +2. Make sure `MATRIX_ENCRYPTION=true` is set in your `.env`. +3. In your Matrix client (Element), go to the bot's profile → **Sessions** → verify/trust the bot's device. +4. If the bot just joined an encrypted room, it can only decrypt messages sent *after* it joined. Older messages are inaccessible. + +### Sync issues / bot falls behind + +**Cause**: Long-running tool executions can delay the sync loop, or the homeserver is slow. + +**Fix**: The sync loop automatically retries every 5 seconds on error. Check the Hermes logs for sync-related warnings. If the bot consistently falls behind, ensure your homeserver has adequate resources. + +### Bot is offline + +**Cause**: The Hermes gateway isn't running, or it failed to connect. + +**Fix**: Check that `hermes gateway` is running. Look at the terminal output for error messages. Common issues: wrong homeserver URL, expired access token, homeserver unreachable. + +### "User not allowed" / Bot ignores you + +**Cause**: Your User ID isn't in `MATRIX_ALLOWED_USERS`. + +**Fix**: Add your User ID to `MATRIX_ALLOWED_USERS` in `~/.hermes/.env` and restart the gateway. Use the full `@user:server` format. + +## Security + +:::warning +Always set `MATRIX_ALLOWED_USERS` to restrict who can interact with the bot. Without it, the gateway denies all users by default as a safety measure. Only add User IDs of people you trust — authorized users have full access to the agent's capabilities, including tool use and system access. +::: + +For more information on securing your Hermes Agent deployment, see the [Security Guide](../security.md). + +## Notes + +- **Any homeserver**: Works with Synapse, Conduit, Dendrite, matrix.org, or any spec-compliant Matrix homeserver. No specific homeserver software required. +- **Federation**: If you're on a federated homeserver, the bot can communicate with users from other servers — just add their full `@user:server` IDs to `MATRIX_ALLOWED_USERS`. +- **Auto-join**: The bot automatically accepts room invites and joins. It starts responding immediately after joining. +- **Media support**: Hermes can send and receive images, audio, video, and file attachments. Media is uploaded to your homeserver using the Matrix content repository API. diff --git a/website/docs/user-guide/messaging/mattermost.md b/website/docs/user-guide/messaging/mattermost.md new file mode 100644 index 000000000..f959bb872 --- /dev/null +++ b/website/docs/user-guide/messaging/mattermost.md @@ -0,0 +1,277 @@ +--- +sidebar_position: 8 +title: "Mattermost" +description: "Set up Hermes Agent as a Mattermost bot" +--- + +# Mattermost Setup + +Hermes Agent integrates with Mattermost as a bot, letting you chat with your AI assistant through direct messages or team channels. Mattermost is a self-hosted, open-source Slack alternative — you run it on your own infrastructure, keeping full control of your data. The bot connects via Mattermost's REST API (v4) and WebSocket for real-time events, processes messages through the Hermes Agent pipeline (including tool use, memory, and reasoning), and responds in real time. It supports text, file attachments, images, and slash commands. + +No external Mattermost library is required — the adapter uses `aiohttp`, which is already a Hermes dependency. + +Before setup, here's the part most people want to know: how Hermes behaves once it's in your Mattermost instance. + +## How Hermes Behaves + +| Context | Behavior | +|---------|----------| +| **DMs** | Hermes responds to every message. No `@mention` needed. Each DM has its own session. | +| **Public/private channels** | Hermes responds when you `@mention` it. Without a mention, Hermes ignores the message. | +| **Threads** | If `MATTERMOST_REPLY_MODE=thread`, Hermes replies in a thread under your message. Thread context stays isolated from the parent channel. | +| **Shared channels with multiple users** | By default, Hermes isolates session history per user inside the channel. Two people talking in the same channel do not share one transcript unless you explicitly disable that. | + +:::tip +If you want Hermes to reply as threaded conversations (nested under your original message), set `MATTERMOST_REPLY_MODE=thread`. The default is `off`, which sends flat messages in the channel. +::: + +### Session Model in Mattermost + +By default: + +- each DM gets its own session +- each thread gets its own session namespace +- each user in a shared channel gets their own session inside that channel + +This is controlled by `config.yaml`: + +```yaml +group_sessions_per_user: true +``` + +Set it to `false` only if you explicitly want one shared conversation for the entire channel: + +```yaml +group_sessions_per_user: false +``` + +Shared sessions can be useful for a collaborative channel, but they also mean: + +- users share context growth and token costs +- one person's long tool-heavy task can bloat everyone else's context +- one person's in-flight run can interrupt another person's follow-up in the same channel + +This guide walks you through the full setup process — from creating your bot on Mattermost to sending your first message. + +## Step 1: Enable Bot Accounts + +Bot accounts must be enabled on your Mattermost server before you can create one. + +1. Log in to Mattermost as a **System Admin**. +2. Go to **System Console** → **Integrations** → **Bot Accounts**. +3. Set **Enable Bot Account Creation** to **true**. +4. Click **Save**. + +:::info +If you don't have System Admin access, ask your Mattermost administrator to enable bot accounts and create one for you. +::: + +## Step 2: Create a Bot Account + +1. In Mattermost, click the **☰** menu (top-left) → **Integrations** → **Bot Accounts**. +2. Click **Add Bot Account**. +3. Fill in the details: + - **Username**: e.g., `hermes` + - **Display Name**: e.g., `Hermes Agent` + - **Description**: optional + - **Role**: `Member` is sufficient +4. Click **Create Bot Account**. +5. Mattermost will display the **bot token**. **Copy it immediately.** + +:::warning[Token shown only once] +The bot token is only displayed once when you create the bot account. If you lose it, you'll need to regenerate it from the bot account settings. Never share your token publicly or commit it to Git — anyone with this token has full control of the bot. +::: + +Store the token somewhere safe (a password manager, for example). You'll need it in Step 5. + +:::tip +You can also use a **personal access token** instead of a bot account. Go to **Profile** → **Security** → **Personal Access Tokens** → **Create Token**. This is useful if you want Hermes to post as your own user rather than a separate bot user. +::: + +## Step 3: Add the Bot to Channels + +The bot needs to be a member of any channel where you want it to respond: + +1. Open the channel where you want the bot. +2. Click the channel name → **Add Members**. +3. Search for your bot username (e.g., `hermes`) and add it. + +For DMs, simply open a direct message with the bot — it will be able to respond immediately. + +## Step 4: Find Your Mattermost User ID + +Hermes Agent uses your Mattermost User ID to control who can interact with the bot. To find it: + +1. Click your **avatar** (top-left corner) → **Profile**. +2. Your User ID is displayed in the profile dialog — click it to copy. + +Your User ID is a 26-character alphanumeric string like `3uo8dkh1p7g1mfk49ear5fzs5c`. + +:::warning +Your User ID is **not** your username. The username is what appears after `@` (e.g., `@alice`). The User ID is a long alphanumeric identifier that Mattermost uses internally. +::: + +**Alternative**: You can also get your User ID via the API: + +```bash +curl -H "Authorization: Bearer YOUR_TOKEN" \ + https://your-mattermost-server/api/v4/users/me | jq .id +``` + +:::tip +To get a **Channel ID**: click the channel name → **View Info**. The Channel ID is shown in the info panel. You'll need this if you want to set a home channel manually. +::: + +## Step 5: Configure Hermes Agent + +### Option A: Interactive Setup (Recommended) + +Run the guided setup command: + +```bash +hermes gateway setup +``` + +Select **Mattermost** when prompted, then paste your server URL, bot token, and user ID when asked. + +### Option B: Manual Configuration + +Add the following to your `~/.hermes/.env` file: + +```bash +# Required +MATTERMOST_URL=https://mm.example.com +MATTERMOST_TOKEN=*** +MATTERMOST_ALLOWED_USERS=3uo8dkh1p7g1mfk49ear5fzs5c + +# Multiple allowed users (comma-separated) +# MATTERMOST_ALLOWED_USERS=3uo8dkh1p7g1mfk49ear5fzs5c,8fk2jd9s0a7bncm1xqw4tp6r3e + +# Optional: reply mode (thread or off, default: off) +# MATTERMOST_REPLY_MODE=thread +``` + +Optional behavior settings in `~/.hermes/config.yaml`: + +```yaml +group_sessions_per_user: true +``` + +- `group_sessions_per_user: true` keeps each participant's context isolated inside shared channels and threads + +### Start the Gateway + +Once configured, start the Mattermost gateway: + +```bash +hermes gateway +``` + +The bot should connect to your Mattermost server within a few seconds. Send it a message — either a DM or in a channel where it's been added — to test. + +:::tip +You can run `hermes gateway` in the background or as a systemd service for persistent operation. See the deployment docs for details. +::: + +## Home Channel + +You can designate a "home channel" where the bot sends proactive messages (such as cron job output, reminders, and notifications). There are two ways to set it: + +### Using the Slash Command + +Type `/sethome` in any Mattermost channel where the bot is present. That channel becomes the home channel. + +### Manual Configuration + +Add this to your `~/.hermes/.env`: + +```bash +MATTERMOST_HOME_CHANNEL=abc123def456ghi789jkl012mn +``` + +Replace the ID with the actual channel ID (click the channel name → View Info → copy the ID). + +## Reply Mode + +The `MATTERMOST_REPLY_MODE` setting controls how Hermes posts responses: + +| Mode | Behavior | +|------|----------| +| `off` (default) | Hermes posts flat messages in the channel, like a normal user. | +| `thread` | Hermes replies in a thread under your original message. Keeps channels clean when there's lots of back-and-forth. | + +Set it in your `~/.hermes/.env`: + +```bash +MATTERMOST_REPLY_MODE=thread +``` + +## Troubleshooting + +### Bot is not responding to messages + +**Cause**: The bot is not a member of the channel, or `MATTERMOST_ALLOWED_USERS` doesn't include your User ID. + +**Fix**: Add the bot to the channel (channel name → Add Members → search for the bot). Verify your User ID is in `MATTERMOST_ALLOWED_USERS`. Restart the gateway. + +### 403 Forbidden errors + +**Cause**: The bot token is invalid, or the bot doesn't have permission to post in the channel. + +**Fix**: Check that `MATTERMOST_TOKEN` in your `.env` file is correct. Make sure the bot account hasn't been deactivated. Verify the bot has been added to the channel. If using a personal access token, ensure your account has the required permissions. + +### WebSocket disconnects / reconnection loops + +**Cause**: Network instability, Mattermost server restarts, or firewall/proxy issues with WebSocket connections. + +**Fix**: The adapter automatically reconnects with exponential backoff (2s → 60s). Check your server's WebSocket configuration — reverse proxies (nginx, Apache) need WebSocket upgrade headers configured. Verify no firewall is blocking WebSocket connections on your Mattermost server. + +For nginx, ensure your config includes: + +```nginx +location /api/v4/websocket { + proxy_pass http://mattermost-backend; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "upgrade"; + proxy_read_timeout 600s; +} +``` + +### "Failed to authenticate" on startup + +**Cause**: The token or server URL is incorrect. + +**Fix**: Verify `MATTERMOST_URL` points to your Mattermost server (include `https://`, no trailing slash). Check that `MATTERMOST_TOKEN` is valid — try it with curl: + +```bash +curl -H "Authorization: Bearer YOUR_TOKEN" \ + https://your-server/api/v4/users/me +``` + +If this returns your bot's user info, the token is valid. If it returns an error, regenerate the token. + +### Bot is offline + +**Cause**: The Hermes gateway isn't running, or it failed to connect. + +**Fix**: Check that `hermes gateway` is running. Look at the terminal output for error messages. Common issues: wrong URL, expired token, Mattermost server unreachable. + +### "User not allowed" / Bot ignores you + +**Cause**: Your User ID isn't in `MATTERMOST_ALLOWED_USERS`. + +**Fix**: Add your User ID to `MATTERMOST_ALLOWED_USERS` in `~/.hermes/.env` and restart the gateway. Remember: the User ID is a 26-character alphanumeric string, not your `@username`. + +## Security + +:::warning +Always set `MATTERMOST_ALLOWED_USERS` to restrict who can interact with the bot. Without it, the gateway denies all users by default as a safety measure. Only add User IDs of people you trust — authorized users have full access to the agent's capabilities, including tool use and system access. +::: + +For more information on securing your Hermes Agent deployment, see the [Security Guide](../security.md). + +## Notes + +- **Self-hosted friendly**: Works with any self-hosted Mattermost instance. No Mattermost Cloud account or subscription required. +- **No extra dependencies**: The adapter uses `aiohttp` for HTTP and WebSocket, which is already included with Hermes Agent. +- **Team Edition compatible**: Works with both Mattermost Team Edition (free) and Enterprise Edition. diff --git a/website/docs/user-guide/messaging/sms.md b/website/docs/user-guide/messaging/sms.md new file mode 100644 index 000000000..0aa835ffe --- /dev/null +++ b/website/docs/user-guide/messaging/sms.md @@ -0,0 +1,175 @@ +--- +sidebar_position: 8 +title: "SMS (Twilio)" +description: "Set up Hermes Agent as an SMS chatbot via Twilio" +--- + +# SMS Setup (Twilio) + +Hermes connects to SMS through the [Twilio](https://www.twilio.com/) API. People text your Twilio phone number and get AI responses back — same conversational experience as Telegram or Discord, but over standard text messages. + +:::info Shared Credentials +The SMS gateway shares credentials with the optional [telephony skill](/docs/reference/skills-catalog). If you've already set up Twilio for voice calls or one-off SMS, the gateway works with the same `TWILIO_ACCOUNT_SID`, `TWILIO_AUTH_TOKEN`, and `TWILIO_PHONE_NUMBER`. +::: + +--- + +## Prerequisites + +- **Twilio account** — [Sign up at twilio.com](https://www.twilio.com/try-twilio) (free trial available) +- **A Twilio phone number** with SMS capability +- **A publicly accessible server** — Twilio sends webhooks to your server when SMS arrives +- **aiohttp** — `pip install 'hermes-agent[sms]'` + +--- + +## Step 1: Get Your Twilio Credentials + +1. Go to the [Twilio Console](https://console.twilio.com/) +2. Copy your **Account SID** and **Auth Token** from the dashboard +3. Go to **Phone Numbers → Manage → Active Numbers** — note your phone number in E.164 format (e.g., `+15551234567`) + +--- + +## Step 2: Configure Hermes + +### Interactive setup (recommended) + +```bash +hermes gateway setup +``` + +Select **SMS (Twilio)** from the platform list. The wizard will prompt for your credentials. + +### Manual setup + +Add to `~/.hermes/.env`: + +```bash +TWILIO_ACCOUNT_SID=ACxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx +TWILIO_AUTH_TOKEN=your_auth_token_here +TWILIO_PHONE_NUMBER=+15551234567 + +# Security: restrict to specific phone numbers (recommended) +SMS_ALLOWED_USERS=+15559876543,+15551112222 + +# Optional: set a home channel for cron job delivery +SMS_HOME_CHANNEL=+15559876543 +``` + +--- + +## Step 3: Configure Twilio Webhook + +Twilio needs to know where to send incoming messages. In the [Twilio Console](https://console.twilio.com/): + +1. Go to **Phone Numbers → Manage → Active Numbers** +2. Click your phone number +3. Under **Messaging → A MESSAGE COMES IN**, set: + - **Webhook**: `https://your-server:8080/webhooks/twilio` + - **HTTP Method**: `POST` + +:::tip Exposing Your Webhook +If you're running Hermes locally, use a tunnel to expose the webhook: + +```bash +# Using cloudflared +cloudflared tunnel --url http://localhost:8080 + +# Using ngrok +ngrok http 8080 +``` + +Set the resulting public URL as your Twilio webhook. +::: + +The webhook port defaults to `8080`. Override with: + +```bash +SMS_WEBHOOK_PORT=3000 +``` + +--- + +## Step 4: Start the Gateway + +```bash +hermes gateway +``` + +You should see: + +``` +[sms] Twilio webhook server listening on port 8080, from: +1555***4567 +``` + +Text your Twilio number — Hermes will respond via SMS. + +--- + +## Environment Variables + +| Variable | Required | Description | +|----------|----------|-------------| +| `TWILIO_ACCOUNT_SID` | Yes | Twilio Account SID (starts with `AC`) | +| `TWILIO_AUTH_TOKEN` | Yes | Twilio Auth Token | +| `TWILIO_PHONE_NUMBER` | Yes | Your Twilio phone number (E.164 format) | +| `SMS_WEBHOOK_PORT` | No | Webhook listener port (default: `8080`) | +| `SMS_ALLOWED_USERS` | No | Comma-separated E.164 phone numbers allowed to chat | +| `SMS_ALLOW_ALL_USERS` | No | Set to `true` to allow anyone (not recommended) | +| `SMS_HOME_CHANNEL` | No | Phone number for cron job / notification delivery | +| `SMS_HOME_CHANNEL_NAME` | No | Display name for the home channel (default: `Home`) | + +--- + +## SMS-Specific Behavior + +- **Plain text only** — Markdown is automatically stripped since SMS renders it as literal characters +- **1600 character limit** — Longer responses are split across multiple messages at natural boundaries (newlines, then spaces) +- **Echo prevention** — Messages from your own Twilio number are ignored to prevent loops +- **Phone number redaction** — Phone numbers are redacted in logs for privacy + +--- + +## Security + +**The gateway denies all users by default.** Configure an allowlist: + +```bash +# Recommended: restrict to specific phone numbers +SMS_ALLOWED_USERS=+15559876543,+15551112222 + +# Or allow all (NOT recommended for bots with terminal access) +SMS_ALLOW_ALL_USERS=true +``` + +:::warning +SMS has no built-in encryption. Don't use SMS for sensitive operations unless you understand the security implications. For sensitive use cases, prefer Signal or Telegram. +::: + +--- + +## Troubleshooting + +### Messages not arriving + +1. Check your Twilio webhook URL is correct and publicly accessible +2. Verify `TWILIO_ACCOUNT_SID` and `TWILIO_AUTH_TOKEN` are correct +3. Check the Twilio Console → **Monitor → Logs → Messaging** for delivery errors +4. Ensure your phone number is in `SMS_ALLOWED_USERS` (or `SMS_ALLOW_ALL_USERS=true`) + +### Replies not sending + +1. Check `TWILIO_PHONE_NUMBER` is set correctly (E.164 format with `+`) +2. Verify your Twilio account has SMS-capable numbers +3. Check Hermes gateway logs for Twilio API errors + +### Webhook port conflicts + +If port 8080 is already in use, change it: + +```bash +SMS_WEBHOOK_PORT=3001 +``` + +Update the webhook URL in Twilio Console to match. diff --git a/website/docs/user-guide/security.md b/website/docs/user-guide/security.md index d31cc1757..d6d14db8d 100644 --- a/website/docs/user-guide/security.md +++ b/website/docs/user-guide/security.md @@ -277,6 +277,25 @@ Error messages from MCP tools are sanitized before being returned to the LLM. Th - Bearer tokens - `token=`, `key=`, `API_KEY=`, `password=`, `secret=` parameters +### Website Access Policy + +You can restrict which websites the agent can access through its web and browser tools. This is useful for preventing the agent from accessing internal services, admin panels, or other sensitive URLs. + +```yaml +# In ~/.hermes/config.yaml +website_blocklist: + enabled: true + domains: + - "*.internal.company.com" + - "admin.example.com" + shared_files: + - "/etc/hermes/blocked-sites.txt" +``` + +When a blocked URL is requested, the tool returns an error explaining the domain is blocked by policy. The blocklist is enforced across `web_search`, `web_extract`, `browser_navigate`, and all URL-capable tools. + +See [Website Blocklist](/docs/user-guide/configuration#website-blocklist) in the configuration guide for full details. + ### Context File Injection Protection Context files (AGENTS.md, .cursorrules, SOUL.md) are scanned for prompt injection before being included in the system prompt. The scanner checks for: diff --git a/website/sidebars.ts b/website/sidebars.ts index ac46028b4..935cdaffe 100644 --- a/website/sidebars.ts +++ b/website/sidebars.ts @@ -48,6 +48,9 @@ const sidebars: SidebarsConfig = { 'user-guide/messaging/signal', 'user-guide/messaging/email', 'user-guide/messaging/homeassistant', + 'user-guide/messaging/mattermost', + 'user-guide/messaging/matrix', + 'user-guide/messaging/dingtalk', ], }, {