Fix nous refresh token rotation failure in case where api key mint/retrieval fails

This commit is contained in:
Robin Fernandes
2026-03-02 17:18:15 +11:00
parent 7b38afc179
commit 5e5e0efc60
2 changed files with 292 additions and 6 deletions

View File

@@ -21,8 +21,10 @@ import os
import shutil
import stat
import base64
import hashlib
import subprocess
import time
import uuid
import webbrowser
from contextlib import contextmanager
from dataclasses import dataclass, field
@@ -147,6 +149,31 @@ def format_auth_error(error: Exception) -> str:
return str(error)
def _token_fingerprint(token: Any) -> Optional[str]:
"""Return a short hash fingerprint for telemetry without leaking token bytes."""
if not isinstance(token, str):
return None
cleaned = token.strip()
if not cleaned:
return None
return hashlib.sha256(cleaned.encode("utf-8")).hexdigest()[:12]
def _oauth_trace_enabled() -> bool:
raw = os.getenv("HERMES_OAUTH_TRACE", "").strip().lower()
return raw in {"1", "true", "yes", "on"}
def _oauth_trace(event: str, *, sequence_id: Optional[str] = None, **fields: Any) -> None:
if not _oauth_trace_enabled():
return
payload: Dict[str, Any] = {"event": event}
if sequence_id:
payload["sequence_id"] = sequence_id
payload.update(fields)
logger.info("oauth_trace %s", json.dumps(payload, sort_keys=True, ensure_ascii=False))
# =============================================================================
# Auth Store — persistence layer for ~/.hermes/auth.json
# =============================================================================
@@ -216,7 +243,29 @@ def _save_auth_store(auth_store: Dict[str, Any]) -> Path:
auth_file.parent.mkdir(parents=True, exist_ok=True)
auth_store["version"] = AUTH_STORE_VERSION
auth_store["updated_at"] = datetime.now(timezone.utc).isoformat()
auth_file.write_text(json.dumps(auth_store, indent=2) + "\n")
payload = json.dumps(auth_store, indent=2) + "\n"
tmp_path = auth_file.with_name(f"{auth_file.name}.tmp.{os.getpid()}.{uuid.uuid4().hex}")
try:
with tmp_path.open("w", encoding="utf-8") as handle:
handle.write(payload)
handle.flush()
os.fsync(handle.fileno())
os.replace(tmp_path, auth_file)
try:
dir_fd = os.open(str(auth_file.parent), os.O_RDONLY)
except OSError:
dir_fd = None
if dir_fd is not None:
try:
os.fsync(dir_fd)
finally:
os.close(dir_fd)
finally:
try:
if tmp_path.exists():
tmp_path.unlink()
except OSError:
pass
# Restrict file permissions to owner only
try:
auth_file.chmod(stat.S_IRUSR | stat.S_IWUSR)
@@ -906,6 +955,7 @@ def resolve_nous_runtime_credentials(
expires_in, source ("cache" or "portal").
"""
min_key_ttl_seconds = max(60, int(min_key_ttl_seconds))
sequence_id = uuid.uuid4().hex[:12]
with _auth_store_lock():
auth_store = _load_auth_store()
@@ -928,8 +978,35 @@ def resolve_nous_runtime_credentials(
).rstrip("/")
client_id = str(state.get("client_id") or DEFAULT_NOUS_CLIENT_ID)
def _persist_state(reason: str) -> None:
try:
_save_provider_state(auth_store, "nous", state)
_save_auth_store(auth_store)
except Exception as exc:
_oauth_trace(
"nous_state_persist_failed",
sequence_id=sequence_id,
reason=reason,
error_type=type(exc).__name__,
)
raise
_oauth_trace(
"nous_state_persisted",
sequence_id=sequence_id,
reason=reason,
refresh_token_fp=_token_fingerprint(state.get("refresh_token")),
access_token_fp=_token_fingerprint(state.get("access_token")),
)
verify = _resolve_verify(insecure=insecure, ca_bundle=ca_bundle, auth_state=state)
timeout = httpx.Timeout(timeout_seconds if timeout_seconds else 15.0)
_oauth_trace(
"nous_runtime_credentials_start",
sequence_id=sequence_id,
force_mint=bool(force_mint),
min_key_ttl_seconds=min_key_ttl_seconds,
refresh_token_fp=_token_fingerprint(state.get("refresh_token")),
)
with httpx.Client(timeout=timeout, headers={"Accept": "application/json"}, verify=verify) as client:
access_token = state.get("access_token")
@@ -945,12 +1022,19 @@ def resolve_nous_runtime_credentials(
raise AuthError("Session expired and no refresh token is available.",
provider="nous", relogin_required=True)
_oauth_trace(
"refresh_start",
sequence_id=sequence_id,
reason="access_expiring",
refresh_token_fp=_token_fingerprint(refresh_token),
)
refreshed = _refresh_access_token(
client=client, portal_base_url=portal_base_url,
client_id=client_id, refresh_token=refresh_token,
)
now = datetime.now(timezone.utc)
access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in"))
previous_refresh_token = refresh_token
state["access_token"] = refreshed["access_token"]
state["refresh_token"] = refreshed.get("refresh_token") or refresh_token
state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer"
@@ -964,6 +1048,16 @@ def resolve_nous_runtime_credentials(
now.timestamp() + access_ttl, tz=timezone.utc
).isoformat()
access_token = state["access_token"]
refresh_token = state["refresh_token"]
_oauth_trace(
"refresh_success",
sequence_id=sequence_id,
reason="access_expiring",
previous_refresh_token_fp=_token_fingerprint(previous_refresh_token),
new_refresh_token_fp=_token_fingerprint(refresh_token),
)
# Persist immediately so downstream mint failures cannot drop rotated refresh tokens.
_persist_state("post_refresh_access_expiring")
# Step 2: mint agent key if missing/expiring
used_cached_key = False
@@ -971,23 +1065,45 @@ def resolve_nous_runtime_credentials(
if not force_mint and _agent_key_is_usable(state, min_key_ttl_seconds):
used_cached_key = True
_oauth_trace("agent_key_reuse", sequence_id=sequence_id)
else:
try:
_oauth_trace(
"mint_start",
sequence_id=sequence_id,
access_token_fp=_token_fingerprint(access_token),
)
mint_payload = _mint_agent_key(
client=client, portal_base_url=portal_base_url,
access_token=access_token, min_ttl_seconds=min_key_ttl_seconds,
)
except AuthError as exc:
_oauth_trace(
"mint_error",
sequence_id=sequence_id,
code=exc.code,
)
# Retry path: access token may be stale server-side despite local checks
if exc.code in {"invalid_token", "invalid_grant"} and isinstance(refresh_token, str) and refresh_token:
latest_refresh_token = state.get("refresh_token")
if (
exc.code in {"invalid_token", "invalid_grant"}
and isinstance(latest_refresh_token, str)
and latest_refresh_token
):
_oauth_trace(
"refresh_start",
sequence_id=sequence_id,
reason="mint_retry_after_invalid_token",
refresh_token_fp=_token_fingerprint(latest_refresh_token),
)
refreshed = _refresh_access_token(
client=client, portal_base_url=portal_base_url,
client_id=client_id, refresh_token=refresh_token,
client_id=client_id, refresh_token=latest_refresh_token,
)
now = datetime.now(timezone.utc)
access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in"))
state["access_token"] = refreshed["access_token"]
state["refresh_token"] = refreshed.get("refresh_token") or refresh_token
state["refresh_token"] = refreshed.get("refresh_token") or latest_refresh_token
state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer"
state["scope"] = refreshed.get("scope") or state.get("scope")
refreshed_url = _optional_base_url(refreshed.get("inference_base_url"))
@@ -999,6 +1115,16 @@ def resolve_nous_runtime_credentials(
now.timestamp() + access_ttl, tz=timezone.utc
).isoformat()
access_token = state["access_token"]
refresh_token = state["refresh_token"]
_oauth_trace(
"refresh_success",
sequence_id=sequence_id,
reason="mint_retry_after_invalid_token",
previous_refresh_token_fp=_token_fingerprint(latest_refresh_token),
new_refresh_token_fp=_token_fingerprint(refresh_token),
)
# Persist retry refresh immediately for crash safety and cross-process visibility.
_persist_state("post_refresh_mint_retry")
mint_payload = _mint_agent_key(
client=client, portal_base_url=portal_base_url,
@@ -1018,6 +1144,11 @@ def resolve_nous_runtime_credentials(
minted_url = _optional_base_url(mint_payload.get("inference_base_url"))
if minted_url:
inference_base_url = minted_url
_oauth_trace(
"mint_success",
sequence_id=sequence_id,
reused=bool(mint_payload.get("reused", False)),
)
# Persist routing and TLS metadata for non-interactive refresh/mint
state["portal_base_url"] = portal_base_url
@@ -1028,8 +1159,7 @@ def resolve_nous_runtime_credentials(
"ca_bundle": verify if isinstance(verify, str) else None,
}
_save_provider_state(auth_store, "nous", state)
_save_auth_store(auth_store)
_persist_state("resolve_nous_runtime_credentials_final")
api_key = state.get("agent_key")
if not isinstance(api_key, str) or not api_key:

View File

@@ -0,0 +1,156 @@
"""Regression tests for Nous OAuth refresh + agent-key mint interactions."""
import json
from datetime import datetime, timezone
from pathlib import Path
import httpx
import pytest
from hermes_cli.auth import AuthError, get_provider_auth_state, resolve_nous_runtime_credentials
def _setup_nous_auth(
hermes_home: Path,
*,
access_token: str = "access-old",
refresh_token: str = "refresh-old",
) -> None:
hermes_home.mkdir(parents=True, exist_ok=True)
auth_store = {
"version": 1,
"active_provider": "nous",
"providers": {
"nous": {
"portal_base_url": "https://portal.example.com",
"inference_base_url": "https://inference.example.com/v1",
"client_id": "hermes-cli",
"token_type": "Bearer",
"scope": "inference:mint_agent_key",
"access_token": access_token,
"refresh_token": refresh_token,
"obtained_at": "2026-02-01T00:00:00+00:00",
"expires_in": 0,
"expires_at": "2026-02-01T00:00:00+00:00",
"agent_key": None,
"agent_key_id": None,
"agent_key_expires_at": None,
"agent_key_expires_in": None,
"agent_key_reused": None,
"agent_key_obtained_at": None,
}
},
}
(hermes_home / "auth.json").write_text(json.dumps(auth_store, indent=2))
def _mint_payload(api_key: str = "agent-key") -> dict:
return {
"api_key": api_key,
"key_id": "key-id-1",
"expires_at": datetime.now(timezone.utc).isoformat(),
"expires_in": 1800,
"reused": False,
}
def test_refresh_token_persisted_when_mint_returns_insufficient_credits(tmp_path, monkeypatch):
hermes_home = tmp_path / "hermes"
_setup_nous_auth(hermes_home, refresh_token="refresh-old")
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
refresh_calls = []
mint_calls = {"count": 0}
def _fake_refresh_access_token(*, client, portal_base_url, client_id, refresh_token):
refresh_calls.append(refresh_token)
idx = len(refresh_calls)
return {
"access_token": f"access-{idx}",
"refresh_token": f"refresh-{idx}",
"expires_in": 0,
"token_type": "Bearer",
}
def _fake_mint_agent_key(*, client, portal_base_url, access_token, min_ttl_seconds):
mint_calls["count"] += 1
if mint_calls["count"] == 1:
raise AuthError("credits exhausted", provider="nous", code="insufficient_credits")
return _mint_payload(api_key="agent-key-2")
monkeypatch.setattr("hermes_cli.auth._refresh_access_token", _fake_refresh_access_token)
monkeypatch.setattr("hermes_cli.auth._mint_agent_key", _fake_mint_agent_key)
with pytest.raises(AuthError) as exc:
resolve_nous_runtime_credentials(min_key_ttl_seconds=300)
assert exc.value.code == "insufficient_credits"
state_after_failure = get_provider_auth_state("nous")
assert state_after_failure is not None
assert state_after_failure["refresh_token"] == "refresh-1"
assert state_after_failure["access_token"] == "access-1"
creds = resolve_nous_runtime_credentials(min_key_ttl_seconds=300)
assert creds["api_key"] == "agent-key-2"
assert refresh_calls == ["refresh-old", "refresh-1"]
def test_refresh_token_persisted_when_mint_times_out(tmp_path, monkeypatch):
hermes_home = tmp_path / "hermes"
_setup_nous_auth(hermes_home, refresh_token="refresh-old")
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
def _fake_refresh_access_token(*, client, portal_base_url, client_id, refresh_token):
return {
"access_token": "access-1",
"refresh_token": "refresh-1",
"expires_in": 0,
"token_type": "Bearer",
}
def _fake_mint_agent_key(*, client, portal_base_url, access_token, min_ttl_seconds):
raise httpx.ReadTimeout("mint timeout")
monkeypatch.setattr("hermes_cli.auth._refresh_access_token", _fake_refresh_access_token)
monkeypatch.setattr("hermes_cli.auth._mint_agent_key", _fake_mint_agent_key)
with pytest.raises(httpx.ReadTimeout):
resolve_nous_runtime_credentials(min_key_ttl_seconds=300)
state_after_failure = get_provider_auth_state("nous")
assert state_after_failure is not None
assert state_after_failure["refresh_token"] == "refresh-1"
assert state_after_failure["access_token"] == "access-1"
def test_mint_retry_uses_latest_rotated_refresh_token(tmp_path, monkeypatch):
hermes_home = tmp_path / "hermes"
_setup_nous_auth(hermes_home, refresh_token="refresh-old")
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
refresh_calls = []
mint_calls = {"count": 0}
def _fake_refresh_access_token(*, client, portal_base_url, client_id, refresh_token):
refresh_calls.append(refresh_token)
idx = len(refresh_calls)
return {
"access_token": f"access-{idx}",
"refresh_token": f"refresh-{idx}",
"expires_in": 0,
"token_type": "Bearer",
}
def _fake_mint_agent_key(*, client, portal_base_url, access_token, min_ttl_seconds):
mint_calls["count"] += 1
if mint_calls["count"] == 1:
raise AuthError("stale access token", provider="nous", code="invalid_token")
return _mint_payload(api_key="agent-key")
monkeypatch.setattr("hermes_cli.auth._refresh_access_token", _fake_refresh_access_token)
monkeypatch.setattr("hermes_cli.auth._mint_agent_key", _fake_mint_agent_key)
creds = resolve_nous_runtime_credentials(min_key_ttl_seconds=300)
assert creds["api_key"] == "agent-key"
assert refresh_calls == ["refresh-old", "refresh-1"]