- Add logger + debug log to read_nous_access_token() catch-all so token refresh failures are observable instead of silently swallowed - Tighten _is_nous_auxiliary_client() domain check to use proper URL hostname parsing instead of substring match, preventing false-positives on domains like not-nousresearch.com or nousresearch.com.evil.com
168 lines
5.3 KiB
Python
168 lines
5.3 KiB
Python
"""Generic managed-tool gateway helpers for Nous-hosted vendor passthroughs."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
from datetime import datetime, timezone
|
|
from dataclasses import dataclass
|
|
from typing import Callable, Optional
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
from hermes_constants import get_hermes_home
|
|
from tools.tool_backend_helpers import managed_nous_tools_enabled
|
|
|
|
_DEFAULT_TOOL_GATEWAY_DOMAIN = "nousresearch.com"
|
|
_DEFAULT_TOOL_GATEWAY_SCHEME = "https"
|
|
_NOUS_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class ManagedToolGatewayConfig:
|
|
vendor: str
|
|
gateway_origin: str
|
|
nous_user_token: str
|
|
managed_mode: bool
|
|
|
|
|
|
def auth_json_path():
|
|
"""Return the Hermes auth store path, respecting HERMES_HOME overrides."""
|
|
return get_hermes_home() / "auth.json"
|
|
|
|
|
|
def _read_nous_provider_state() -> Optional[dict]:
|
|
try:
|
|
path = auth_json_path()
|
|
if not path.is_file():
|
|
return None
|
|
data = json.loads(path.read_text())
|
|
providers = data.get("providers", {})
|
|
if not isinstance(providers, dict):
|
|
return None
|
|
nous_provider = providers.get("nous", {})
|
|
if isinstance(nous_provider, dict):
|
|
return nous_provider
|
|
except Exception:
|
|
pass
|
|
return None
|
|
|
|
|
|
def _parse_timestamp(value: object) -> Optional[datetime]:
|
|
if not isinstance(value, str) or not value.strip():
|
|
return None
|
|
normalized = value.strip()
|
|
if normalized.endswith("Z"):
|
|
normalized = normalized[:-1] + "+00:00"
|
|
try:
|
|
parsed = datetime.fromisoformat(normalized)
|
|
except ValueError:
|
|
return None
|
|
if parsed.tzinfo is None:
|
|
parsed = parsed.replace(tzinfo=timezone.utc)
|
|
return parsed.astimezone(timezone.utc)
|
|
|
|
|
|
def _access_token_is_expiring(expires_at: object, skew_seconds: int) -> bool:
|
|
expires = _parse_timestamp(expires_at)
|
|
if expires is None:
|
|
return True
|
|
remaining = (expires - datetime.now(timezone.utc)).total_seconds()
|
|
return remaining <= max(0, int(skew_seconds))
|
|
|
|
|
|
def read_nous_access_token() -> Optional[str]:
|
|
"""Read a Nous Subscriber OAuth access token from auth store or env override."""
|
|
explicit = os.getenv("TOOL_GATEWAY_USER_TOKEN")
|
|
if isinstance(explicit, str) and explicit.strip():
|
|
return explicit.strip()
|
|
|
|
nous_provider = _read_nous_provider_state() or {}
|
|
access_token = nous_provider.get("access_token")
|
|
cached_token = access_token.strip() if isinstance(access_token, str) and access_token.strip() else None
|
|
|
|
if cached_token and not _access_token_is_expiring(
|
|
nous_provider.get("expires_at"),
|
|
_NOUS_ACCESS_TOKEN_REFRESH_SKEW_SECONDS,
|
|
):
|
|
return cached_token
|
|
|
|
try:
|
|
from hermes_cli.auth import resolve_nous_access_token
|
|
|
|
refreshed_token = resolve_nous_access_token(
|
|
refresh_skew_seconds=_NOUS_ACCESS_TOKEN_REFRESH_SKEW_SECONDS,
|
|
)
|
|
if isinstance(refreshed_token, str) and refreshed_token.strip():
|
|
return refreshed_token.strip()
|
|
except Exception as exc:
|
|
logger.debug("Nous access token refresh failed: %s", exc)
|
|
|
|
return cached_token
|
|
|
|
|
|
def get_tool_gateway_scheme() -> str:
|
|
"""Return configured shared gateway URL scheme."""
|
|
scheme = os.getenv("TOOL_GATEWAY_SCHEME", "").strip().lower()
|
|
if not scheme:
|
|
return _DEFAULT_TOOL_GATEWAY_SCHEME
|
|
|
|
if scheme in {"http", "https"}:
|
|
return scheme
|
|
|
|
raise ValueError("TOOL_GATEWAY_SCHEME must be 'http' or 'https'")
|
|
|
|
|
|
def build_vendor_gateway_url(vendor: str) -> str:
|
|
"""Return the gateway origin for a specific vendor."""
|
|
vendor_key = f"{vendor.upper().replace('-', '_')}_GATEWAY_URL"
|
|
explicit_vendor_url = os.getenv(vendor_key, "").strip().rstrip("/")
|
|
if explicit_vendor_url:
|
|
return explicit_vendor_url
|
|
|
|
shared_scheme = get_tool_gateway_scheme()
|
|
shared_domain = os.getenv("TOOL_GATEWAY_DOMAIN", "").strip().strip("/")
|
|
if shared_domain:
|
|
return f"{shared_scheme}://{vendor}-gateway.{shared_domain}"
|
|
|
|
return f"{shared_scheme}://{vendor}-gateway.{_DEFAULT_TOOL_GATEWAY_DOMAIN}"
|
|
|
|
|
|
def resolve_managed_tool_gateway(
|
|
vendor: str,
|
|
gateway_builder: Optional[Callable[[str], str]] = None,
|
|
token_reader: Optional[Callable[[], Optional[str]]] = None,
|
|
) -> Optional[ManagedToolGatewayConfig]:
|
|
"""Resolve shared managed-tool gateway config for a vendor."""
|
|
if not managed_nous_tools_enabled():
|
|
return None
|
|
|
|
resolved_gateway_builder = gateway_builder or build_vendor_gateway_url
|
|
resolved_token_reader = token_reader or read_nous_access_token
|
|
|
|
gateway_origin = resolved_gateway_builder(vendor)
|
|
nous_user_token = resolved_token_reader()
|
|
if not gateway_origin or not nous_user_token:
|
|
return None
|
|
|
|
return ManagedToolGatewayConfig(
|
|
vendor=vendor,
|
|
gateway_origin=gateway_origin,
|
|
nous_user_token=nous_user_token,
|
|
managed_mode=True,
|
|
)
|
|
|
|
|
|
def is_managed_tool_gateway_ready(
|
|
vendor: str,
|
|
gateway_builder: Optional[Callable[[str], str]] = None,
|
|
token_reader: Optional[Callable[[], Optional[str]]] = None,
|
|
) -> bool:
|
|
"""Return True when gateway URL and Nous access token are available."""
|
|
return resolve_managed_tool_gateway(
|
|
vendor,
|
|
gateway_builder=gateway_builder,
|
|
token_reader=token_reader,
|
|
) is not None
|