Files
hermes-agent/tests/test_context_budget.py
Alexander Whitestone d3d2ce9ea1
Some checks failed
Nix / nix (macos-latest) (pull_request) Has been cancelled
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Nix / nix (ubuntu-latest) (pull_request) Failing after 6s
Contributor Attribution Check / check-attribution (pull_request) Failing after 29s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 38s
Tests / e2e (pull_request) Successful in 1m30s
Tests / test (pull_request) Failing after 35m7s
feat: poka-yoke context budget tracker — integrated into agent loop (#838)
2026-04-16 04:15:25 +00:00

188 lines
6.0 KiB
Python

"""Tests for ContextBudgetTracker — poka-yoke context window safety."""
import time
from unittest.mock import patch
import pytest
from agent.context_budget import (
ContextBudgetTracker,
THRESHOLD_CAUTION,
THRESHOLD_ELEVATED,
THRESHOLD_CRITICAL,
PREFLIGHT_SAFETY_MARGIN,
WARNING_COOLDOWN,
)
@pytest.fixture
def tracker():
    """Build the baseline tracker shared by most tests in this module.

    A 128K-token window compressed at 50% yields a 64K-token compression
    threshold, which the percentage-based tests measure against.
    """
    window, compress_at = 128_000, 0.50
    return ContextBudgetTracker(context_length=window, threshold_percent=compress_at)
class TestThresholds:
    """Warning tiers expressed as a fraction of the compression threshold.

    Token counts are computed with exact integer arithmetic. The previous
    ``int(64_000 * 0.69)`` form only produced 44_160 because the float
    product (44159.99999999999659...) happened to round up; ``int()``
    truncates, so a product a hair below the integer would silently test
    the wrong side of a tier boundary.
    """

    # Compression threshold of the standard ``tracker`` fixture (128K * 0.50).
    THRESHOLD = 64_000

    @classmethod
    def _tokens_at(cls, percent):
        """Return exactly ``percent``% of THRESHOLD as an int token count."""
        return cls.THRESHOLD * percent // 100

    def test_threshold_tokens_computed(self, tracker):
        assert tracker.threshold_tokens == self.THRESHOLD

    def test_caution_at_70_percent(self, tracker):
        # The first (70%) tier is reported under the level name "elevated".
        tracker.update(self._tokens_at(70))
        assert tracker.warning_level == "elevated"

    def test_no_warning_below_caution(self, tracker):
        tracker.update(self._tokens_at(69))
        assert tracker.warning_level is None

    def test_critical_at_85_percent(self, tracker):
        tracker.update(self._tokens_at(85))
        assert tracker.warning_level == "critical"

    def test_emergency_at_95_percent(self, tracker):
        tracker.update(self._tokens_at(95))
        assert tracker.warning_level == "emergency"
class TestWarningMessages:
    """Human-readable ``warning_message`` text for each tier of the standard tracker."""

    # Compression threshold of the standard ``tracker`` fixture (128K * 0.50).
    THRESHOLD = 64_000

    @classmethod
    def _tokens_at(cls, percent):
        """Return exactly ``percent``% of THRESHOLD (avoids int(float) truncation)."""
        return cls.THRESHOLD * percent // 100

    def test_elevated_message_contains_70(self, tracker):
        tracker.update(self._tokens_at(70))
        msg = tracker.warning_message
        assert msg is not None
        assert "CONTEXT WARNING" in msg

    def test_critical_message(self, tracker):
        tracker.update(self._tokens_at(85))
        msg = tracker.warning_message
        # Guard first so a missing message fails as an assertion rather than
        # a TypeError from evaluating `"..." in None`.
        assert msg is not None
        assert "compaction approaching" in msg

    def test_emergency_message(self, tracker):
        tracker.update(self._tokens_at(95))
        msg = tracker.warning_message
        assert msg is not None
        assert "CONTEXT CRITICAL" in msg

    def test_no_message_below_caution(self, tracker):
        tracker.update(10_000)
        assert tracker.warning_message is None
class TestWarningDedup:
    """Return value of ``update()``: warnings are throttled by WARNING_COOLDOWN
    within a tier, but crossing into a higher tier always emits."""

    # Compression threshold of the standard ``tracker`` fixture (128K * 0.50).
    THRESHOLD = 64_000

    @classmethod
    def _tokens_at(cls, percent):
        """Return exactly ``percent``% of THRESHOLD (avoids int(float) truncation)."""
        return cls.THRESHOLD * percent // 100

    def test_repeated_update_same_tier_suppressed(self, tracker):
        """Same tier within cooldown should not re-emit."""
        first = tracker.update(self._tokens_at(71))
        # Without this, the suppression check below would also pass if
        # update() never returned a warning at all.
        assert first is not None
        repeat = tracker.update(self._tokens_at(72))
        assert repeat is None  # suppressed by cooldown

    def test_higher_tier_breaks_through_cooldown(self, tracker):
        """Crossing to a higher tier should always emit."""
        # The first warning starts the cooldown that the next update must break.
        assert tracker.update(self._tokens_at(71)) is not None
        msg = tracker.update(self._tokens_at(86))
        assert msg is not None
        assert "compaction approaching" in msg.lower()

    def test_cooldown_expires_allows_reemit(self, tracker):
        """Once WARNING_COOLDOWN has elapsed, the same tier warns again."""
        assert tracker.update(self._tokens_at(71)) is not None
        # Backdate the last warning instead of sleeping through the real cooldown.
        tracker._last_warned_time = time.time() - WARNING_COOLDOWN - 1
        msg = tracker.update(self._tokens_at(72))
        assert msg is not None
class TestCheckpoint:
    """``should_checkpoint()`` fires around the 85% tier and then cools down."""

    # Compression threshold of the standard ``tracker`` fixture (128K * 0.50).
    THRESHOLD = 64_000

    @classmethod
    def _tokens_at(cls, percent):
        """Return exactly ``percent``% of THRESHOLD (avoids int(float) truncation)."""
        return cls.THRESHOLD * percent // 100

    def test_should_checkpoint_at_85(self, tracker):
        tracker.update(self._tokens_at(86))
        assert tracker.should_checkpoint() is True

    def test_no_checkpoint_below_85(self, tracker):
        tracker.update(self._tokens_at(84))
        assert tracker.should_checkpoint() is False

    def test_checkpoint_cooldown(self, tracker):
        tracker.update(self._tokens_at(86))
        # The first call must actually fire (saves); otherwise the cooldown
        # assertion below would pass even if checkpoints never triggered.
        assert tracker.should_checkpoint() is True
        assert tracker.should_checkpoint() is False  # cooldown
class TestPreflight:
    """Pre-flight budget checks made before adding content to the context."""

    # Compression threshold of the standard ``tracker`` fixture (128K * 0.50).
    THRESHOLD = 64_000

    def test_can_fit_small_addition(self, tracker):
        tracker.update(30_000)
        assert tracker.can_fit(5_000) is True

    def test_cannot_fit_overflow(self, tracker):
        # Exactly 88% of the threshold, computed without float truncation.
        tracker.update(self.THRESHOLD * 88 // 100)
        assert tracker.can_fit(10_000) is False

    def test_estimate_file_tokens(self, tracker):
        assert tracker.estimate_file_tokens(4_000) == 1_000
        # At the 4:1 ratio above, an input of 100 maps well above 1 and never
        # reaches the floor; a 1-unit input actually exercises the clamp.
        assert tracker.estimate_file_tokens(1) >= 1  # minimum 1

    def test_tokens_remaining(self, tracker):
        tracker.update(30_000)
        remaining = tracker.tokens_remaining()
        safe_limit = int(self.THRESHOLD * PREFLIGHT_SAFETY_MARGIN)
        assert remaining == safe_limit - 30_000
class TestTrend:
    """``trend()`` reports the direction of token usage across updates."""

    def test_growing_trend(self, tracker):
        # 10K, 15K, ..., 55K -- ten strictly increasing readings.
        for tokens in range(10_000, 60_000, 5_000):
            tracker.update(tokens)
        assert tracker.trend() == "growing"

    def test_shrinking_trend(self, tracker):
        # 60K, 55K, ..., 15K -- ten strictly decreasing readings.
        for tokens in range(60_000, 10_000, -5_000):
            tracker.update(tokens)
        assert tracker.trend() == "shrinking"

    def test_stable_trend(self, tracker):
        for tokens in [30_000] * 10:
            tracker.update(tokens)
        assert tracker.trend() == "stable"
class TestSummary:
    """Structured (``summary``) and display (``format_status``) reporting."""

    def test_summary_keys(self, tracker):
        tracker.update(40_000)
        snapshot = tracker.summary()
        for key in ("context_length", "current_tokens", "warning_level", "trend"):
            assert key in snapshot
        assert snapshot["current_tokens"] == 40_000

    def test_format_status(self, tracker):
        tracker.update(30_000)
        line = tracker.format_status()
        # "[" marks the progress bar; "%" the percentage readout.
        for marker in ("Context:", "[", "%"):
            assert marker in line
class TestEdgeCases:
    """Degenerate configurations, over-threshold usage and bookkeeping counters."""

    def test_zero_context_length(self):
        t = ContextBudgetTracker(context_length=0)
        assert t.threshold_tokens == 0
        assert t.progress == 0.0
        assert t.warning_level is None

    def test_different_threshold_percent(self):
        t = ContextBudgetTracker(context_length=100_000, threshold_percent=0.80)
        assert t.threshold_tokens == 80_000
        # Exactly 70% of the threshold. int(80_000 * 0.70) only yields 56_000
        # because the float product (55999.99999999999644...) happens to round
        # up; int() truncation would otherwise land just below the boundary.
        t.update(80_000 * 70 // 100)
        assert t.warning_level == "elevated"

    def test_over_threshold_progress(self, tracker):
        """Progress can exceed 1.0 (past compression threshold)."""
        tracker.update(70_000)
        assert tracker.progress > 1.0
        assert tracker.warning_level == "emergency"

    def test_peak_tracking(self, tracker):
        # Peak must survive a later drop in usage.
        tracker.update(10_000)
        tracker.update(50_000)
        tracker.update(30_000)
        assert tracker.peak_tokens == 50_000

    def test_turn_count(self, tracker):
        assert tracker.turn_count == 0
        tracker.update(10_000)
        tracker.update(20_000)
        assert tracker.turn_count == 2