Merge PR #269: Fix nous refresh token rotation failure on key mint failure

Fixes a bug where the refresh token was not persisted when the API key
mint failed (e.g., 402 insufficient credits, timeout). The rotated
refresh token was lost, causing subsequent auth attempts to fail with
a stale token.

Changes:
- Persist auth state immediately after each successful token refresh,
  before attempting the mint
- Use latest in-memory refresh token on mint-retry paths (was using
  the stale original)
- Atomic durable writes for auth.json (temp file + fsync + replace)
- Opt-in OAuth trace logging (HERMES_OAUTH_TRACE=1, fingerprint-only)
- 3 regression tests covering refresh+402, refresh+timeout, and
  invalid-token retry behavior

Author: Robin Fernandes <rewbs>
This commit is contained in:
teknium1
2026-03-04 17:52:10 -08:00
2 changed files with 292 additions and 6 deletions

View File

@@ -21,8 +21,10 @@ import os
import shutil
import stat
import base64
import hashlib
import subprocess
import time
import uuid
import webbrowser
from contextlib import contextmanager
from dataclasses import dataclass, field
@@ -147,6 +149,31 @@ def format_auth_error(error: Exception) -> str:
return str(error)
def _token_fingerprint(token: Any) -> Optional[str]:
"""Return a short hash fingerprint for telemetry without leaking token bytes."""
if not isinstance(token, str):
return None
cleaned = token.strip()
if not cleaned:
return None
return hashlib.sha256(cleaned.encode("utf-8")).hexdigest()[:12]
def _oauth_trace_enabled() -> bool:
raw = os.getenv("HERMES_OAUTH_TRACE", "").strip().lower()
return raw in {"1", "true", "yes", "on"}
def _oauth_trace(event: str, *, sequence_id: Optional[str] = None, **fields: Any) -> None:
    """Emit one structured OAuth trace record when tracing is enabled.

    Does nothing unless ``_oauth_trace_enabled()`` returns true. Otherwise the
    record is logged at INFO level as ``oauth_trace <json>``, with keys sorted
    so that lines are stable and easy to diff.

    Args:
        event: Name of the traced step; stored under the ``event`` key.
        sequence_id: Optional correlation id; included only when truthy.
        **fields: Extra key/value pairs merged into the record. Callers are
            expected to pass fingerprints rather than raw token values.
    """
    if not _oauth_trace_enabled():
        return
    payload: Dict[str, Any] = {"event": event}
    if sequence_id:
        payload["sequence_id"] = sequence_id
    payload.update(fields)
    # default=str is deliberate: tracing is diagnostics-only and is also called
    # from exception handlers, so a field value that is not JSON-serializable
    # must be stringified instead of raising TypeError and masking the real
    # error or aborting the auth flow.
    logger.info(
        "oauth_trace %s",
        json.dumps(payload, sort_keys=True, ensure_ascii=False, default=str),
    )
# =============================================================================
# Auth Store — persistence layer for ~/.hermes/auth.json
# =============================================================================
@@ -216,7 +243,29 @@ def _save_auth_store(auth_store: Dict[str, Any]) -> Path:
auth_file.parent.mkdir(parents=True, exist_ok=True)
auth_store["version"] = AUTH_STORE_VERSION
auth_store["updated_at"] = datetime.now(timezone.utc).isoformat()
auth_file.write_text(json.dumps(auth_store, indent=2) + "\n")
payload = json.dumps(auth_store, indent=2) + "\n"
tmp_path = auth_file.with_name(f"{auth_file.name}.tmp.{os.getpid()}.{uuid.uuid4().hex}")
try:
with tmp_path.open("w", encoding="utf-8") as handle:
handle.write(payload)
handle.flush()
os.fsync(handle.fileno())
os.replace(tmp_path, auth_file)
try:
dir_fd = os.open(str(auth_file.parent), os.O_RDONLY)
except OSError:
dir_fd = None
if dir_fd is not None:
try:
os.fsync(dir_fd)
finally:
os.close(dir_fd)
finally:
try:
if tmp_path.exists():
tmp_path.unlink()
except OSError:
pass
# Restrict file permissions to owner only
try:
auth_file.chmod(stat.S_IRUSR | stat.S_IWUSR)
@@ -906,6 +955,7 @@ def resolve_nous_runtime_credentials(
expires_in, source ("cache" or "portal").
"""
min_key_ttl_seconds = max(60, int(min_key_ttl_seconds))
sequence_id = uuid.uuid4().hex[:12]
with _auth_store_lock():
auth_store = _load_auth_store()
@@ -928,8 +978,35 @@ def resolve_nous_runtime_credentials(
).rstrip("/")
client_id = str(state.get("client_id") or DEFAULT_NOUS_CLIENT_ID)
def _persist_state(reason: str) -> None:
try:
_save_provider_state(auth_store, "nous", state)
_save_auth_store(auth_store)
except Exception as exc:
_oauth_trace(
"nous_state_persist_failed",
sequence_id=sequence_id,
reason=reason,
error_type=type(exc).__name__,
)
raise
_oauth_trace(
"nous_state_persisted",
sequence_id=sequence_id,
reason=reason,
refresh_token_fp=_token_fingerprint(state.get("refresh_token")),
access_token_fp=_token_fingerprint(state.get("access_token")),
)
verify = _resolve_verify(insecure=insecure, ca_bundle=ca_bundle, auth_state=state)
timeout = httpx.Timeout(timeout_seconds if timeout_seconds else 15.0)
_oauth_trace(
"nous_runtime_credentials_start",
sequence_id=sequence_id,
force_mint=bool(force_mint),
min_key_ttl_seconds=min_key_ttl_seconds,
refresh_token_fp=_token_fingerprint(state.get("refresh_token")),
)
with httpx.Client(timeout=timeout, headers={"Accept": "application/json"}, verify=verify) as client:
access_token = state.get("access_token")
@@ -945,12 +1022,19 @@ def resolve_nous_runtime_credentials(
raise AuthError("Session expired and no refresh token is available.",
provider="nous", relogin_required=True)
_oauth_trace(
"refresh_start",
sequence_id=sequence_id,
reason="access_expiring",
refresh_token_fp=_token_fingerprint(refresh_token),
)
refreshed = _refresh_access_token(
client=client, portal_base_url=portal_base_url,
client_id=client_id, refresh_token=refresh_token,
)
now = datetime.now(timezone.utc)
access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in"))
previous_refresh_token = refresh_token
state["access_token"] = refreshed["access_token"]
state["refresh_token"] = refreshed.get("refresh_token") or refresh_token
state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer"
@@ -964,6 +1048,16 @@ def resolve_nous_runtime_credentials(
now.timestamp() + access_ttl, tz=timezone.utc
).isoformat()
access_token = state["access_token"]
refresh_token = state["refresh_token"]
_oauth_trace(
"refresh_success",
sequence_id=sequence_id,
reason="access_expiring",
previous_refresh_token_fp=_token_fingerprint(previous_refresh_token),
new_refresh_token_fp=_token_fingerprint(refresh_token),
)
# Persist immediately so downstream mint failures cannot drop rotated refresh tokens.
_persist_state("post_refresh_access_expiring")
# Step 2: mint agent key if missing/expiring
used_cached_key = False
@@ -971,23 +1065,45 @@ def resolve_nous_runtime_credentials(
if not force_mint and _agent_key_is_usable(state, min_key_ttl_seconds):
used_cached_key = True
_oauth_trace("agent_key_reuse", sequence_id=sequence_id)
else:
try:
_oauth_trace(
"mint_start",
sequence_id=sequence_id,
access_token_fp=_token_fingerprint(access_token),
)
mint_payload = _mint_agent_key(
client=client, portal_base_url=portal_base_url,
access_token=access_token, min_ttl_seconds=min_key_ttl_seconds,
)
except AuthError as exc:
_oauth_trace(
"mint_error",
sequence_id=sequence_id,
code=exc.code,
)
# Retry path: access token may be stale server-side despite local checks
if exc.code in {"invalid_token", "invalid_grant"} and isinstance(refresh_token, str) and refresh_token:
latest_refresh_token = state.get("refresh_token")
if (
exc.code in {"invalid_token", "invalid_grant"}
and isinstance(latest_refresh_token, str)
and latest_refresh_token
):
_oauth_trace(
"refresh_start",
sequence_id=sequence_id,
reason="mint_retry_after_invalid_token",
refresh_token_fp=_token_fingerprint(latest_refresh_token),
)
refreshed = _refresh_access_token(
client=client, portal_base_url=portal_base_url,
client_id=client_id, refresh_token=refresh_token,
client_id=client_id, refresh_token=latest_refresh_token,
)
now = datetime.now(timezone.utc)
access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in"))
state["access_token"] = refreshed["access_token"]
state["refresh_token"] = refreshed.get("refresh_token") or refresh_token
state["refresh_token"] = refreshed.get("refresh_token") or latest_refresh_token
state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer"
state["scope"] = refreshed.get("scope") or state.get("scope")
refreshed_url = _optional_base_url(refreshed.get("inference_base_url"))
@@ -999,6 +1115,16 @@ def resolve_nous_runtime_credentials(
now.timestamp() + access_ttl, tz=timezone.utc
).isoformat()
access_token = state["access_token"]
refresh_token = state["refresh_token"]
_oauth_trace(
"refresh_success",
sequence_id=sequence_id,
reason="mint_retry_after_invalid_token",
previous_refresh_token_fp=_token_fingerprint(latest_refresh_token),
new_refresh_token_fp=_token_fingerprint(refresh_token),
)
# Persist retry refresh immediately for crash safety and cross-process visibility.
_persist_state("post_refresh_mint_retry")
mint_payload = _mint_agent_key(
client=client, portal_base_url=portal_base_url,
@@ -1018,6 +1144,11 @@ def resolve_nous_runtime_credentials(
minted_url = _optional_base_url(mint_payload.get("inference_base_url"))
if minted_url:
inference_base_url = minted_url
_oauth_trace(
"mint_success",
sequence_id=sequence_id,
reused=bool(mint_payload.get("reused", False)),
)
# Persist routing and TLS metadata for non-interactive refresh/mint
state["portal_base_url"] = portal_base_url
@@ -1028,8 +1159,7 @@ def resolve_nous_runtime_credentials(
"ca_bundle": verify if isinstance(verify, str) else None,
}
_save_provider_state(auth_store, "nous", state)
_save_auth_store(auth_store)
_persist_state("resolve_nous_runtime_credentials_final")
api_key = state.get("agent_key")
if not isinstance(api_key, str) or not api_key:

View File

@@ -0,0 +1,156 @@
"""Regression tests for Nous OAuth refresh + agent-key mint interactions."""
import json
from datetime import datetime, timezone
from pathlib import Path
import httpx
import pytest
from hermes_cli.auth import AuthError, get_provider_auth_state, resolve_nous_runtime_credentials
def _setup_nous_auth(
hermes_home: Path,
*,
access_token: str = "access-old",
refresh_token: str = "refresh-old",
) -> None:
hermes_home.mkdir(parents=True, exist_ok=True)
auth_store = {
"version": 1,
"active_provider": "nous",
"providers": {
"nous": {
"portal_base_url": "https://portal.example.com",
"inference_base_url": "https://inference.example.com/v1",
"client_id": "hermes-cli",
"token_type": "Bearer",
"scope": "inference:mint_agent_key",
"access_token": access_token,
"refresh_token": refresh_token,
"obtained_at": "2026-02-01T00:00:00+00:00",
"expires_in": 0,
"expires_at": "2026-02-01T00:00:00+00:00",
"agent_key": None,
"agent_key_id": None,
"agent_key_expires_at": None,
"agent_key_expires_in": None,
"agent_key_reused": None,
"agent_key_obtained_at": None,
}
},
}
(hermes_home / "auth.json").write_text(json.dumps(auth_store, indent=2))
def _mint_payload(api_key: str = "agent-key") -> dict:
return {
"api_key": api_key,
"key_id": "key-id-1",
"expires_at": datetime.now(timezone.utc).isoformat(),
"expires_in": 1800,
"reused": False,
}
def test_refresh_token_persisted_when_mint_returns_insufficient_credits(tmp_path, monkeypatch):
    """A rotated refresh token is saved even when the first mint is rejected.

    The first mint raises ``AuthError`` with code ``insufficient_credits``.
    The stored state must already hold the tokens from the first refresh, and
    a second resolve must present that rotated refresh token and succeed.
    """
    home_dir = tmp_path / "hermes"
    _setup_nous_auth(home_dir, refresh_token="refresh-old")
    monkeypatch.setenv("HERMES_HOME", str(home_dir))

    presented_refresh_tokens = []
    mint_attempts = []

    def _rotating_refresh(*, client, portal_base_url, client_id, refresh_token):
        # Each refresh rotates both tokens; expires_in=0 forces a refresh on
        # the next resolve as well.
        presented_refresh_tokens.append(refresh_token)
        generation = len(presented_refresh_tokens)
        return {
            "access_token": f"access-{generation}",
            "refresh_token": f"refresh-{generation}",
            "expires_in": 0,
            "token_type": "Bearer",
        }

    def _mint_rejected_once(*, client, portal_base_url, access_token, min_ttl_seconds):
        mint_attempts.append(access_token)
        if len(mint_attempts) == 1:
            raise AuthError("credits exhausted", provider="nous", code="insufficient_credits")
        return _mint_payload(api_key="agent-key-2")

    monkeypatch.setattr("hermes_cli.auth._refresh_access_token", _rotating_refresh)
    monkeypatch.setattr("hermes_cli.auth._mint_agent_key", _mint_rejected_once)

    with pytest.raises(AuthError) as raised:
        resolve_nous_runtime_credentials(min_key_ttl_seconds=300)
    assert raised.value.code == "insufficient_credits"

    persisted = get_provider_auth_state("nous")
    assert persisted is not None
    assert persisted["refresh_token"] == "refresh-1"
    assert persisted["access_token"] == "access-1"

    recovered = resolve_nous_runtime_credentials(min_key_ttl_seconds=300)
    assert recovered["api_key"] == "agent-key-2"
    assert presented_refresh_tokens == ["refresh-old", "refresh-1"]
def test_refresh_token_persisted_when_mint_times_out(tmp_path, monkeypatch):
    """A rotated refresh token is saved even when the mint call times out.

    The mint raises ``httpx.ReadTimeout``, which must propagate; the stored
    state must nevertheless hold the tokens returned by the refresh.
    """
    home_dir = tmp_path / "hermes"
    _setup_nous_auth(home_dir, refresh_token="refresh-old")
    monkeypatch.setenv("HERMES_HOME", str(home_dir))

    rotated_tokens = {
        "access_token": "access-1",
        "refresh_token": "refresh-1",
        "expires_in": 0,
        "token_type": "Bearer",
    }

    def _refresh_once(*, client, portal_base_url, client_id, refresh_token):
        # Hand back a fresh dict per call, as the original fake did.
        return dict(rotated_tokens)

    def _mint_times_out(*, client, portal_base_url, access_token, min_ttl_seconds):
        raise httpx.ReadTimeout("mint timeout")

    monkeypatch.setattr("hermes_cli.auth._refresh_access_token", _refresh_once)
    monkeypatch.setattr("hermes_cli.auth._mint_agent_key", _mint_times_out)

    with pytest.raises(httpx.ReadTimeout):
        resolve_nous_runtime_credentials(min_key_ttl_seconds=300)

    persisted = get_provider_auth_state("nous")
    assert persisted is not None
    assert persisted["refresh_token"] == "refresh-1"
    assert persisted["access_token"] == "access-1"
def test_mint_retry_uses_latest_rotated_refresh_token(tmp_path, monkeypatch):
    """The invalid-token retry refresh presents the newest refresh token.

    The first mint raises ``AuthError`` with code ``invalid_token``. The
    follow-up refresh must be called with the token rotated by the first
    refresh ("refresh-1"), not the original "refresh-old".
    """
    home_dir = tmp_path / "hermes"
    _setup_nous_auth(home_dir, refresh_token="refresh-old")
    monkeypatch.setenv("HERMES_HOME", str(home_dir))

    presented_refresh_tokens = []
    mint_attempts = []

    def _rotating_refresh(*, client, portal_base_url, client_id, refresh_token):
        presented_refresh_tokens.append(refresh_token)
        generation = len(presented_refresh_tokens)
        return {
            "access_token": f"access-{generation}",
            "refresh_token": f"refresh-{generation}",
            "expires_in": 0,
            "token_type": "Bearer",
        }

    def _mint_stale_once(*, client, portal_base_url, access_token, min_ttl_seconds):
        mint_attempts.append(access_token)
        if len(mint_attempts) == 1:
            raise AuthError("stale access token", provider="nous", code="invalid_token")
        return _mint_payload(api_key="agent-key")

    monkeypatch.setattr("hermes_cli.auth._refresh_access_token", _rotating_refresh)
    monkeypatch.setattr("hermes_cli.auth._mint_agent_key", _mint_stale_once)

    resolved = resolve_nous_runtime_credentials(min_key_ttl_seconds=300)
    assert resolved["api_key"] == "agent-key"
    assert presented_refresh_tokens == ["refresh-old", "refresh-1"]