fix(nous): add 3-minute TTL cache to free-tier detection
check_nous_free_tier() now caches its result for 180 seconds to avoid redundant Portal API calls during a session (auxiliary client init, model selection, login flow all call it independently). The TTL is short enough that an account upgrade from free to paid is reflected within 3 minutes. clear_nous_free_tier_cache() is exposed for explicit invalidation on login/logout. Adds 4 tests for cache hit, TTL expiry, explicit clear, and TTL bound.
This commit is contained in:
@@ -404,13 +404,38 @@ def partition_nous_models_by_tier(
|
||||
return (selectable, unavailable)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TTL cache for free-tier detection — avoids repeated API calls within a
|
||||
# session while still picking up upgrades quickly.
|
||||
# ---------------------------------------------------------------------------
|
||||
_FREE_TIER_CACHE_TTL: int = 180 # seconds (3 minutes)
|
||||
_free_tier_cache: tuple[bool, float] | None = None # (result, timestamp)
|
||||
|
||||
|
||||
def clear_nous_free_tier_cache() -> None:
    """Forget any stored free-tier result.

    Resets the module-level cache slot to ``None`` so that the next
    ``check_nous_free_tier()`` call recomputes the answer instead of
    reusing a stored one. Intended for explicit invalidation, e.g. right
    after a login or logout.
    """
    global _free_tier_cache
    _free_tier_cache = None
|
||||
|
||||
|
||||
def check_nous_free_tier() -> bool:
|
||||
"""Check if the current Nous Portal user is on a free (unpaid) tier.
|
||||
|
||||
Resolves the OAuth access token from the auth store, calls the
|
||||
portal account endpoint, and returns True if the account has no
|
||||
paid subscription.
|
||||
Results are cached for ``_FREE_TIER_CACHE_TTL`` seconds to avoid
|
||||
hitting the Portal API on every call. The cache is short-lived so
|
||||
that an account upgrade is reflected within a few minutes.
|
||||
|
||||
Returns False (assume paid) on any error — never blocks paying users.
|
||||
"""
|
||||
global _free_tier_cache
|
||||
import time
|
||||
|
||||
now = time.monotonic()
|
||||
if _free_tier_cache is not None:
|
||||
cached_result, cached_at = _free_tier_cache
|
||||
if now - cached_at < _FREE_TIER_CACHE_TTL:
|
||||
return cached_result
|
||||
|
||||
try:
|
||||
from hermes_cli.auth import get_provider_auth_state, resolve_nous_runtime_credentials
|
||||
|
||||
@@ -419,15 +444,20 @@ def check_nous_free_tier() -> bool:
|
||||
|
||||
state = get_provider_auth_state("nous")
|
||||
if not state:
|
||||
_free_tier_cache = (False, now)
|
||||
return False
|
||||
access_token = state.get("access_token", "")
|
||||
portal_url = state.get("portal_base_url", "")
|
||||
if not access_token:
|
||||
_free_tier_cache = (False, now)
|
||||
return False
|
||||
|
||||
account_info = fetch_nous_account_tier(access_token, portal_url)
|
||||
return is_nous_free_tier(account_info)
|
||||
result = is_nous_free_tier(account_info)
|
||||
_free_tier_cache = (result, now)
|
||||
return result
|
||||
except Exception:
|
||||
_free_tier_cache = (False, now)
|
||||
return False # default to paid on error — don't block users
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
"""Tests for the hermes_cli models module."""
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from hermes_cli.models import (
|
||||
OPENROUTER_MODELS, menu_labels, model_ids, detect_provider_for_model,
|
||||
filter_nous_free_models, _NOUS_ALLOWED_FREE_MODELS,
|
||||
is_nous_free_tier, partition_nous_models_by_tier,
|
||||
check_nous_free_tier, clear_nous_free_tier_cache,
|
||||
_FREE_TIER_CACHE_TTL,
|
||||
)
|
||||
import hermes_cli.models as _models_mod
|
||||
|
||||
|
||||
class TestModelIds:
|
||||
@@ -291,3 +296,63 @@ class TestPartitionNousModelsByTier:
|
||||
sel, unav = partition_nous_models_by_tier(models, pricing, free_tier=True)
|
||||
assert sel == []
|
||||
assert unav == models
|
||||
|
||||
|
||||
class TestCheckNousFreeTierCache:
    """Exercise the time-to-live cache in front of check_nous_free_tier()."""

    def setup_method(self):
        """Start every test from an empty cache."""
        clear_nous_free_tier_cache()

    def teardown_method(self):
        """Leave no cached value behind for later tests."""
        clear_nous_free_tier_cache()

    @patch("hermes_cli.models.fetch_nous_account_tier")
    @patch("hermes_cli.models.is_nous_free_tier", return_value=True)
    def test_result_is_cached(self, mock_is_free, mock_fetch):
        """A repeat call inside the TTL window is answered without a new fetch."""
        mock_fetch.return_value = {"subscription": {"monthly_charge": 0}}
        auth_state_patch = patch(
            "hermes_cli.auth.get_provider_auth_state",
            return_value={"access_token": "tok"},
        )
        credentials_patch = patch("hermes_cli.auth.resolve_nous_runtime_credentials")

        with auth_state_patch, credentials_patch:
            first = check_nous_free_tier()
            second = check_nous_free_tier()

        assert first is True
        assert second is True
        # The second answer must come from the cache, so only one fetch happened.
        assert mock_fetch.call_count == 1

    @patch("hermes_cli.models.fetch_nous_account_tier")
    @patch("hermes_cli.models.is_nous_free_tier", return_value=False)
    def test_cache_expires_after_ttl(self, mock_is_free, mock_fetch):
        """Once the cached entry is older than the TTL, the fetch runs again."""
        mock_fetch.return_value = {"subscription": {"monthly_charge": 20}}
        auth_state_patch = patch(
            "hermes_cli.auth.get_provider_auth_state",
            return_value={"access_token": "tok"},
        )
        credentials_patch = patch("hermes_cli.auth.resolve_nous_runtime_credentials")

        with auth_state_patch, credentials_patch:
            first = check_nous_free_tier()
            assert mock_fetch.call_count == 1

            # Age the stored entry past the TTL instead of sleeping.
            stored_value, stored_at = _models_mod._free_tier_cache
            stale_stamp = stored_at - _FREE_TIER_CACHE_TTL - 1
            _models_mod._free_tier_cache = (stored_value, stale_stamp)

            second = check_nous_free_tier()
            assert mock_fetch.call_count == 2

        assert first is False
        assert second is False

    def test_clear_cache_forces_refresh(self):
        """clear_nous_free_tier_cache() empties a populated cache slot."""
        import time

        # Put a fresh entry in the cache by hand.
        seeded_entry = (True, time.monotonic())
        _models_mod._free_tier_cache = seeded_entry

        clear_nous_free_tier_cache()
        assert _models_mod._free_tier_cache is None

    def test_cache_ttl_is_short(self):
        """The TTL stays at or below five minutes so upgrades show up quickly."""
        five_minutes = 300
        assert _FREE_TIER_CACHE_TTL <= five_minutes
|
||||
|
||||
Reference in New Issue
Block a user