From 5e5e0efc60884649f3d4e53fc73c9687176db36f Mon Sep 17 00:00:00 2001 From: Robin Fernandes Date: Mon, 2 Mar 2026 17:18:15 +1100 Subject: [PATCH] Fix nous refresh token rotation failure in case where api key mint/retrieval fails --- hermes_cli/auth.py | 142 ++++++++++++++++++++++++++-- tests/test_auth_nous_provider.py | 156 +++++++++++++++++++++++++++++++ 2 files changed, 292 insertions(+), 6 deletions(-) create mode 100644 tests/test_auth_nous_provider.py diff --git a/hermes_cli/auth.py b/hermes_cli/auth.py index 34b07b71b..7a2fba0a9 100644 --- a/hermes_cli/auth.py +++ b/hermes_cli/auth.py @@ -21,8 +21,10 @@ import os import shutil import stat import base64 +import hashlib import subprocess import time +import uuid import webbrowser from contextlib import contextmanager from dataclasses import dataclass, field @@ -147,6 +149,31 @@ def format_auth_error(error: Exception) -> str: return str(error) +def _token_fingerprint(token: Any) -> Optional[str]: + """Return a short hash fingerprint for telemetry without leaking token bytes.""" + if not isinstance(token, str): + return None + cleaned = token.strip() + if not cleaned: + return None + return hashlib.sha256(cleaned.encode("utf-8")).hexdigest()[:12] + + +def _oauth_trace_enabled() -> bool: + raw = os.getenv("HERMES_OAUTH_TRACE", "").strip().lower() + return raw in {"1", "true", "yes", "on"} + + +def _oauth_trace(event: str, *, sequence_id: Optional[str] = None, **fields: Any) -> None: + if not _oauth_trace_enabled(): + return + payload: Dict[str, Any] = {"event": event} + if sequence_id: + payload["sequence_id"] = sequence_id + payload.update(fields) + logger.info("oauth_trace %s", json.dumps(payload, sort_keys=True, ensure_ascii=False)) + + # ============================================================================= # Auth Store — persistence layer for ~/.hermes/auth.json # ============================================================================= @@ -216,7 +243,29 @@ def _save_auth_store(auth_store: Dict[str, Any]) -> Path: auth_file.parent.mkdir(parents=True, exist_ok=True) auth_store["version"] = AUTH_STORE_VERSION auth_store["updated_at"] = datetime.now(timezone.utc).isoformat() - auth_file.write_text(json.dumps(auth_store, indent=2) + "\n") + payload = json.dumps(auth_store, indent=2) + "\n" + tmp_path = auth_file.with_name(f"{auth_file.name}.tmp.{os.getpid()}.{uuid.uuid4().hex}") + try: + with tmp_path.open("w", encoding="utf-8") as handle: + handle.write(payload) + handle.flush() + os.fsync(handle.fileno()) + os.replace(tmp_path, auth_file) + try: + dir_fd = os.open(str(auth_file.parent), os.O_RDONLY) + except OSError: + dir_fd = None + if dir_fd is not None: + try: + os.fsync(dir_fd) + finally: + os.close(dir_fd) + finally: + try: + if tmp_path.exists(): + tmp_path.unlink() + except OSError: + pass # Restrict file permissions to owner only try: auth_file.chmod(stat.S_IRUSR | stat.S_IWUSR) @@ -906,6 +955,7 @@ def resolve_nous_runtime_credentials( expires_in, source ("cache" or "portal"). """ min_key_ttl_seconds = max(60, int(min_key_ttl_seconds)) + sequence_id = uuid.uuid4().hex[:12] with _auth_store_lock(): auth_store = _load_auth_store() @@ -928,8 +978,35 @@ def resolve_nous_runtime_credentials( ).rstrip("/") client_id = str(state.get("client_id") or DEFAULT_NOUS_CLIENT_ID) + def _persist_state(reason: str) -> None: + try: + _save_provider_state(auth_store, "nous", state) + _save_auth_store(auth_store) + except Exception as exc: + _oauth_trace( + "nous_state_persist_failed", + sequence_id=sequence_id, + reason=reason, + error_type=type(exc).__name__, + ) + raise + _oauth_trace( + "nous_state_persisted", + sequence_id=sequence_id, + reason=reason, + refresh_token_fp=_token_fingerprint(state.get("refresh_token")), + access_token_fp=_token_fingerprint(state.get("access_token")), + ) + verify = _resolve_verify(insecure=insecure, ca_bundle=ca_bundle, auth_state=state) timeout = httpx.Timeout(timeout_seconds if timeout_seconds else 15.0) + _oauth_trace( + "nous_runtime_credentials_start", + sequence_id=sequence_id, + force_mint=bool(force_mint), + min_key_ttl_seconds=min_key_ttl_seconds, + refresh_token_fp=_token_fingerprint(state.get("refresh_token")), + ) with httpx.Client(timeout=timeout, headers={"Accept": "application/json"}, verify=verify) as client: access_token = state.get("access_token") @@ -945,12 +1022,19 @@ def resolve_nous_runtime_credentials( raise AuthError("Session expired and no refresh token is available.", provider="nous", relogin_required=True) + _oauth_trace( + "refresh_start", + sequence_id=sequence_id, + reason="access_expiring", + refresh_token_fp=_token_fingerprint(refresh_token), + ) refreshed = _refresh_access_token( client=client, portal_base_url=portal_base_url, client_id=client_id, refresh_token=refresh_token, ) now = datetime.now(timezone.utc) access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in")) + previous_refresh_token = refresh_token state["access_token"] = refreshed["access_token"] state["refresh_token"] = refreshed.get("refresh_token") or refresh_token state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer" @@ -964,6 +1048,16 @@ def resolve_nous_runtime_credentials( now.timestamp() + access_ttl, tz=timezone.utc ).isoformat() access_token = state["access_token"] + refresh_token = state["refresh_token"] + _oauth_trace( + "refresh_success", + sequence_id=sequence_id, + reason="access_expiring", + previous_refresh_token_fp=_token_fingerprint(previous_refresh_token), + new_refresh_token_fp=_token_fingerprint(refresh_token), + ) + # Persist immediately so downstream mint failures cannot drop rotated refresh tokens. + _persist_state("post_refresh_access_expiring") # Step 2: mint agent key if missing/expiring used_cached_key = False @@ -971,23 +1065,45 @@ def resolve_nous_runtime_credentials( if not force_mint and _agent_key_is_usable(state, min_key_ttl_seconds): used_cached_key = True + _oauth_trace("agent_key_reuse", sequence_id=sequence_id) else: try: + _oauth_trace( + "mint_start", + sequence_id=sequence_id, + access_token_fp=_token_fingerprint(access_token), + ) mint_payload = _mint_agent_key( client=client, portal_base_url=portal_base_url, access_token=access_token, min_ttl_seconds=min_key_ttl_seconds, ) except AuthError as exc: + _oauth_trace( + "mint_error", + sequence_id=sequence_id, + code=exc.code, + ) # Retry path: access token may be stale server-side despite local checks - if exc.code in {"invalid_token", "invalid_grant"} and isinstance(refresh_token, str) and refresh_token: + latest_refresh_token = state.get("refresh_token") + if ( + exc.code in {"invalid_token", "invalid_grant"} + and isinstance(latest_refresh_token, str) + and latest_refresh_token + ): + _oauth_trace( + "refresh_start", + sequence_id=sequence_id, + reason="mint_retry_after_invalid_token", + refresh_token_fp=_token_fingerprint(latest_refresh_token), + ) refreshed = _refresh_access_token( client=client, portal_base_url=portal_base_url, - client_id=client_id, refresh_token=refresh_token, + client_id=client_id, refresh_token=latest_refresh_token, ) now = datetime.now(timezone.utc) access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in")) state["access_token"] = refreshed["access_token"] - state["refresh_token"] = refreshed.get("refresh_token") or refresh_token + state["refresh_token"] = refreshed.get("refresh_token") or latest_refresh_token state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer" state["scope"] = refreshed.get("scope") or state.get("scope") refreshed_url = _optional_base_url(refreshed.get("inference_base_url")) @@ -999,6 +1115,16 @@ def resolve_nous_runtime_credentials( now.timestamp() + access_ttl, tz=timezone.utc ).isoformat() access_token = state["access_token"] + refresh_token = state["refresh_token"] + _oauth_trace( + "refresh_success", + sequence_id=sequence_id, + reason="mint_retry_after_invalid_token", + previous_refresh_token_fp=_token_fingerprint(latest_refresh_token), + new_refresh_token_fp=_token_fingerprint(refresh_token), + ) + # Persist retry refresh immediately for crash safety and cross-process visibility. + _persist_state("post_refresh_mint_retry") mint_payload = _mint_agent_key( client=client, portal_base_url=portal_base_url, @@ -1018,6 +1144,11 @@ def resolve_nous_runtime_credentials( minted_url = _optional_base_url(mint_payload.get("inference_base_url")) if minted_url: inference_base_url = minted_url + _oauth_trace( + "mint_success", + sequence_id=sequence_id, + reused=bool(mint_payload.get("reused", False)), + ) # Persist routing and TLS metadata for non-interactive refresh/mint state["portal_base_url"] = portal_base_url @@ -1028,8 +1159,7 @@ def resolve_nous_runtime_credentials( "ca_bundle": verify if isinstance(verify, str) else None, } - _save_provider_state(auth_store, "nous", state) - _save_auth_store(auth_store) + _persist_state("resolve_nous_runtime_credentials_final") api_key = state.get("agent_key") if not isinstance(api_key, str) or not api_key: diff --git a/tests/test_auth_nous_provider.py b/tests/test_auth_nous_provider.py new file mode 100644 index 000000000..c449fe3b4 --- /dev/null +++ b/tests/test_auth_nous_provider.py @@ -0,0 +1,156 @@ +"""Regression tests for Nous OAuth refresh + agent-key mint interactions.""" + +import json +from datetime import datetime, timezone +from pathlib import Path + +import httpx +import pytest + +from hermes_cli.auth import AuthError, get_provider_auth_state, resolve_nous_runtime_credentials + + +def _setup_nous_auth( + hermes_home: Path, + *, + access_token: str = "access-old", + refresh_token: str = "refresh-old", +) -> None: + hermes_home.mkdir(parents=True, exist_ok=True) + auth_store = { + "version": 1, + "active_provider": "nous", + "providers": { + "nous": { + "portal_base_url": "https://portal.example.com", + "inference_base_url": "https://inference.example.com/v1", + "client_id": "hermes-cli", + "token_type": "Bearer", + "scope": "inference:mint_agent_key", + "access_token": access_token, + "refresh_token": refresh_token, + "obtained_at": "2026-02-01T00:00:00+00:00", + "expires_in": 0, + "expires_at": "2026-02-01T00:00:00+00:00", + "agent_key": None, + "agent_key_id": None, + "agent_key_expires_at": None, + "agent_key_expires_in": None, + "agent_key_reused": None, + "agent_key_obtained_at": None, + } + }, + } + (hermes_home / "auth.json").write_text(json.dumps(auth_store, indent=2)) + + +def _mint_payload(api_key: str = "agent-key") -> dict: + return { + "api_key": api_key, + "key_id": "key-id-1", + "expires_at": datetime.now(timezone.utc).isoformat(), + "expires_in": 1800, + "reused": False, + } + + +def test_refresh_token_persisted_when_mint_returns_insufficient_credits(tmp_path, monkeypatch): + hermes_home = tmp_path / "hermes" + _setup_nous_auth(hermes_home, refresh_token="refresh-old") + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + + refresh_calls = [] + mint_calls = {"count": 0} + + def _fake_refresh_access_token(*, client, portal_base_url, client_id, refresh_token): + refresh_calls.append(refresh_token) + idx = len(refresh_calls) + return { + "access_token": f"access-{idx}", + "refresh_token": f"refresh-{idx}", + "expires_in": 0, + "token_type": "Bearer", + } + + def _fake_mint_agent_key(*, client, portal_base_url, access_token, min_ttl_seconds): + mint_calls["count"] += 1 + if mint_calls["count"] == 1: + raise AuthError("credits exhausted", provider="nous", code="insufficient_credits") + return _mint_payload(api_key="agent-key-2") + + monkeypatch.setattr("hermes_cli.auth._refresh_access_token", _fake_refresh_access_token) + monkeypatch.setattr("hermes_cli.auth._mint_agent_key", _fake_mint_agent_key) + + with pytest.raises(AuthError) as exc: + resolve_nous_runtime_credentials(min_key_ttl_seconds=300) + assert exc.value.code == "insufficient_credits" + + state_after_failure = get_provider_auth_state("nous") + assert state_after_failure is not None + assert state_after_failure["refresh_token"] == "refresh-1" + assert state_after_failure["access_token"] == "access-1" + + creds = resolve_nous_runtime_credentials(min_key_ttl_seconds=300) + assert creds["api_key"] == "agent-key-2" + assert refresh_calls == ["refresh-old", "refresh-1"] + + +def test_refresh_token_persisted_when_mint_times_out(tmp_path, monkeypatch): + hermes_home = tmp_path / "hermes" + _setup_nous_auth(hermes_home, refresh_token="refresh-old") + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + + def _fake_refresh_access_token(*, client, portal_base_url, client_id, refresh_token): + return { + "access_token": "access-1", + "refresh_token": "refresh-1", + "expires_in": 0, + "token_type": "Bearer", + } + + def _fake_mint_agent_key(*, client, portal_base_url, access_token, min_ttl_seconds): + raise httpx.ReadTimeout("mint timeout") + + monkeypatch.setattr("hermes_cli.auth._refresh_access_token", _fake_refresh_access_token) + monkeypatch.setattr("hermes_cli.auth._mint_agent_key", _fake_mint_agent_key) + + with pytest.raises(httpx.ReadTimeout): + resolve_nous_runtime_credentials(min_key_ttl_seconds=300) + + state_after_failure = get_provider_auth_state("nous") + assert state_after_failure is not None + assert state_after_failure["refresh_token"] == "refresh-1" + assert state_after_failure["access_token"] == "access-1" + + +def test_mint_retry_uses_latest_rotated_refresh_token(tmp_path, monkeypatch): + hermes_home = tmp_path / "hermes" + _setup_nous_auth(hermes_home, refresh_token="refresh-old") + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + + refresh_calls = [] + mint_calls = {"count": 0} + + def _fake_refresh_access_token(*, client, portal_base_url, client_id, refresh_token): + refresh_calls.append(refresh_token) + idx = len(refresh_calls) + return { + "access_token": f"access-{idx}", + "refresh_token": f"refresh-{idx}", + "expires_in": 0, + "token_type": "Bearer", + } + + def _fake_mint_agent_key(*, client, portal_base_url, access_token, min_ttl_seconds): + mint_calls["count"] += 1 + if mint_calls["count"] == 1: + raise AuthError("stale access token", provider="nous", code="invalid_token") + return _mint_payload(api_key="agent-key") + + monkeypatch.setattr("hermes_cli.auth._refresh_access_token", _fake_refresh_access_token) + monkeypatch.setattr("hermes_cli.auth._mint_agent_key", _fake_mint_agent_key) + + creds = resolve_nous_runtime_credentials(min_key_ttl_seconds=300) + assert creds["api_key"] == "agent-key" + assert refresh_calls == ["refresh-old", "refresh-1"] +