Fix nous refresh token rotation failure in case where api key mint/retrieval fails
This commit is contained in:
@@ -21,8 +21,10 @@ import os
|
||||
import shutil
|
||||
import stat
|
||||
import base64
|
||||
import hashlib
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
import webbrowser
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
@@ -147,6 +149,31 @@ def format_auth_error(error: Exception) -> str:
|
||||
return str(error)
|
||||
|
||||
|
||||
def _token_fingerprint(token: Any) -> Optional[str]:
|
||||
"""Return a short hash fingerprint for telemetry without leaking token bytes."""
|
||||
if not isinstance(token, str):
|
||||
return None
|
||||
cleaned = token.strip()
|
||||
if not cleaned:
|
||||
return None
|
||||
return hashlib.sha256(cleaned.encode("utf-8")).hexdigest()[:12]
|
||||
|
||||
|
||||
def _oauth_trace_enabled() -> bool:
|
||||
raw = os.getenv("HERMES_OAUTH_TRACE", "").strip().lower()
|
||||
return raw in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _oauth_trace(event: str, *, sequence_id: Optional[str] = None, **fields: Any) -> None:
|
||||
if not _oauth_trace_enabled():
|
||||
return
|
||||
payload: Dict[str, Any] = {"event": event}
|
||||
if sequence_id:
|
||||
payload["sequence_id"] = sequence_id
|
||||
payload.update(fields)
|
||||
logger.info("oauth_trace %s", json.dumps(payload, sort_keys=True, ensure_ascii=False))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Auth Store — persistence layer for ~/.hermes/auth.json
|
||||
# =============================================================================
|
||||
@@ -216,7 +243,29 @@ def _save_auth_store(auth_store: Dict[str, Any]) -> Path:
|
||||
auth_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
auth_store["version"] = AUTH_STORE_VERSION
|
||||
auth_store["updated_at"] = datetime.now(timezone.utc).isoformat()
|
||||
auth_file.write_text(json.dumps(auth_store, indent=2) + "\n")
|
||||
payload = json.dumps(auth_store, indent=2) + "\n"
|
||||
tmp_path = auth_file.with_name(f"{auth_file.name}.tmp.{os.getpid()}.{uuid.uuid4().hex}")
|
||||
try:
|
||||
with tmp_path.open("w", encoding="utf-8") as handle:
|
||||
handle.write(payload)
|
||||
handle.flush()
|
||||
os.fsync(handle.fileno())
|
||||
os.replace(tmp_path, auth_file)
|
||||
try:
|
||||
dir_fd = os.open(str(auth_file.parent), os.O_RDONLY)
|
||||
except OSError:
|
||||
dir_fd = None
|
||||
if dir_fd is not None:
|
||||
try:
|
||||
os.fsync(dir_fd)
|
||||
finally:
|
||||
os.close(dir_fd)
|
||||
finally:
|
||||
try:
|
||||
if tmp_path.exists():
|
||||
tmp_path.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
# Restrict file permissions to owner only
|
||||
try:
|
||||
auth_file.chmod(stat.S_IRUSR | stat.S_IWUSR)
|
||||
@@ -906,6 +955,7 @@ def resolve_nous_runtime_credentials(
|
||||
expires_in, source ("cache" or "portal").
|
||||
"""
|
||||
min_key_ttl_seconds = max(60, int(min_key_ttl_seconds))
|
||||
sequence_id = uuid.uuid4().hex[:12]
|
||||
|
||||
with _auth_store_lock():
|
||||
auth_store = _load_auth_store()
|
||||
@@ -928,8 +978,35 @@ def resolve_nous_runtime_credentials(
|
||||
).rstrip("/")
|
||||
client_id = str(state.get("client_id") or DEFAULT_NOUS_CLIENT_ID)
|
||||
|
||||
def _persist_state(reason: str) -> None:
|
||||
try:
|
||||
_save_provider_state(auth_store, "nous", state)
|
||||
_save_auth_store(auth_store)
|
||||
except Exception as exc:
|
||||
_oauth_trace(
|
||||
"nous_state_persist_failed",
|
||||
sequence_id=sequence_id,
|
||||
reason=reason,
|
||||
error_type=type(exc).__name__,
|
||||
)
|
||||
raise
|
||||
_oauth_trace(
|
||||
"nous_state_persisted",
|
||||
sequence_id=sequence_id,
|
||||
reason=reason,
|
||||
refresh_token_fp=_token_fingerprint(state.get("refresh_token")),
|
||||
access_token_fp=_token_fingerprint(state.get("access_token")),
|
||||
)
|
||||
|
||||
verify = _resolve_verify(insecure=insecure, ca_bundle=ca_bundle, auth_state=state)
|
||||
timeout = httpx.Timeout(timeout_seconds if timeout_seconds else 15.0)
|
||||
_oauth_trace(
|
||||
"nous_runtime_credentials_start",
|
||||
sequence_id=sequence_id,
|
||||
force_mint=bool(force_mint),
|
||||
min_key_ttl_seconds=min_key_ttl_seconds,
|
||||
refresh_token_fp=_token_fingerprint(state.get("refresh_token")),
|
||||
)
|
||||
|
||||
with httpx.Client(timeout=timeout, headers={"Accept": "application/json"}, verify=verify) as client:
|
||||
access_token = state.get("access_token")
|
||||
@@ -945,12 +1022,19 @@ def resolve_nous_runtime_credentials(
|
||||
raise AuthError("Session expired and no refresh token is available.",
|
||||
provider="nous", relogin_required=True)
|
||||
|
||||
_oauth_trace(
|
||||
"refresh_start",
|
||||
sequence_id=sequence_id,
|
||||
reason="access_expiring",
|
||||
refresh_token_fp=_token_fingerprint(refresh_token),
|
||||
)
|
||||
refreshed = _refresh_access_token(
|
||||
client=client, portal_base_url=portal_base_url,
|
||||
client_id=client_id, refresh_token=refresh_token,
|
||||
)
|
||||
now = datetime.now(timezone.utc)
|
||||
access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in"))
|
||||
previous_refresh_token = refresh_token
|
||||
state["access_token"] = refreshed["access_token"]
|
||||
state["refresh_token"] = refreshed.get("refresh_token") or refresh_token
|
||||
state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer"
|
||||
@@ -964,6 +1048,16 @@ def resolve_nous_runtime_credentials(
|
||||
now.timestamp() + access_ttl, tz=timezone.utc
|
||||
).isoformat()
|
||||
access_token = state["access_token"]
|
||||
refresh_token = state["refresh_token"]
|
||||
_oauth_trace(
|
||||
"refresh_success",
|
||||
sequence_id=sequence_id,
|
||||
reason="access_expiring",
|
||||
previous_refresh_token_fp=_token_fingerprint(previous_refresh_token),
|
||||
new_refresh_token_fp=_token_fingerprint(refresh_token),
|
||||
)
|
||||
# Persist immediately so downstream mint failures cannot drop rotated refresh tokens.
|
||||
_persist_state("post_refresh_access_expiring")
|
||||
|
||||
# Step 2: mint agent key if missing/expiring
|
||||
used_cached_key = False
|
||||
@@ -971,23 +1065,45 @@ def resolve_nous_runtime_credentials(
|
||||
|
||||
if not force_mint and _agent_key_is_usable(state, min_key_ttl_seconds):
|
||||
used_cached_key = True
|
||||
_oauth_trace("agent_key_reuse", sequence_id=sequence_id)
|
||||
else:
|
||||
try:
|
||||
_oauth_trace(
|
||||
"mint_start",
|
||||
sequence_id=sequence_id,
|
||||
access_token_fp=_token_fingerprint(access_token),
|
||||
)
|
||||
mint_payload = _mint_agent_key(
|
||||
client=client, portal_base_url=portal_base_url,
|
||||
access_token=access_token, min_ttl_seconds=min_key_ttl_seconds,
|
||||
)
|
||||
except AuthError as exc:
|
||||
_oauth_trace(
|
||||
"mint_error",
|
||||
sequence_id=sequence_id,
|
||||
code=exc.code,
|
||||
)
|
||||
# Retry path: access token may be stale server-side despite local checks
|
||||
if exc.code in {"invalid_token", "invalid_grant"} and isinstance(refresh_token, str) and refresh_token:
|
||||
latest_refresh_token = state.get("refresh_token")
|
||||
if (
|
||||
exc.code in {"invalid_token", "invalid_grant"}
|
||||
and isinstance(latest_refresh_token, str)
|
||||
and latest_refresh_token
|
||||
):
|
||||
_oauth_trace(
|
||||
"refresh_start",
|
||||
sequence_id=sequence_id,
|
||||
reason="mint_retry_after_invalid_token",
|
||||
refresh_token_fp=_token_fingerprint(latest_refresh_token),
|
||||
)
|
||||
refreshed = _refresh_access_token(
|
||||
client=client, portal_base_url=portal_base_url,
|
||||
client_id=client_id, refresh_token=refresh_token,
|
||||
client_id=client_id, refresh_token=latest_refresh_token,
|
||||
)
|
||||
now = datetime.now(timezone.utc)
|
||||
access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in"))
|
||||
state["access_token"] = refreshed["access_token"]
|
||||
state["refresh_token"] = refreshed.get("refresh_token") or refresh_token
|
||||
state["refresh_token"] = refreshed.get("refresh_token") or latest_refresh_token
|
||||
state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer"
|
||||
state["scope"] = refreshed.get("scope") or state.get("scope")
|
||||
refreshed_url = _optional_base_url(refreshed.get("inference_base_url"))
|
||||
@@ -999,6 +1115,16 @@ def resolve_nous_runtime_credentials(
|
||||
now.timestamp() + access_ttl, tz=timezone.utc
|
||||
).isoformat()
|
||||
access_token = state["access_token"]
|
||||
refresh_token = state["refresh_token"]
|
||||
_oauth_trace(
|
||||
"refresh_success",
|
||||
sequence_id=sequence_id,
|
||||
reason="mint_retry_after_invalid_token",
|
||||
previous_refresh_token_fp=_token_fingerprint(latest_refresh_token),
|
||||
new_refresh_token_fp=_token_fingerprint(refresh_token),
|
||||
)
|
||||
# Persist retry refresh immediately for crash safety and cross-process visibility.
|
||||
_persist_state("post_refresh_mint_retry")
|
||||
|
||||
mint_payload = _mint_agent_key(
|
||||
client=client, portal_base_url=portal_base_url,
|
||||
@@ -1018,6 +1144,11 @@ def resolve_nous_runtime_credentials(
|
||||
minted_url = _optional_base_url(mint_payload.get("inference_base_url"))
|
||||
if minted_url:
|
||||
inference_base_url = minted_url
|
||||
_oauth_trace(
|
||||
"mint_success",
|
||||
sequence_id=sequence_id,
|
||||
reused=bool(mint_payload.get("reused", False)),
|
||||
)
|
||||
|
||||
# Persist routing and TLS metadata for non-interactive refresh/mint
|
||||
state["portal_base_url"] = portal_base_url
|
||||
@@ -1028,8 +1159,7 @@ def resolve_nous_runtime_credentials(
|
||||
"ca_bundle": verify if isinstance(verify, str) else None,
|
||||
}
|
||||
|
||||
_save_provider_state(auth_store, "nous", state)
|
||||
_save_auth_store(auth_store)
|
||||
_persist_state("resolve_nous_runtime_credentials_final")
|
||||
|
||||
api_key = state.get("agent_key")
|
||||
if not isinstance(api_key, str) or not api_key:
|
||||
|
||||
Reference in New Issue
Block a user