Some checks failed
Nix / nix (macos-latest) (pull_request) Has been cancelled
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Nix / nix (ubuntu-latest) (pull_request) Failing after 6s
Contributor Attribution Check / check-attribution (pull_request) Failing after 29s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 38s
Tests / e2e (pull_request) Successful in 1m30s
Tests / test (pull_request) Failing after 35m7s
188 lines
6.0 KiB
Python
188 lines
6.0 KiB
Python
"""Tests for ContextBudgetTracker — poka-yoke context window safety."""
|
|
import time
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
|
|
from agent.context_budget import (
|
|
ContextBudgetTracker,
|
|
THRESHOLD_CAUTION,
|
|
THRESHOLD_ELEVATED,
|
|
THRESHOLD_CRITICAL,
|
|
PREFLIGHT_SAFETY_MARGIN,
|
|
WARNING_COOLDOWN,
|
|
)
|
|
|
|
|
|
@pytest.fixture
def tracker():
    """Provide a tracker for a 128K-token window that compresses at 50%.

    With ``threshold_percent=0.50`` the compression threshold works out to
    64K tokens; the percentages used throughout these tests are fractions
    of that 64K figure.
    """
    window = 128_000
    return ContextBudgetTracker(context_length=window, threshold_percent=0.50)
|
|
|
|
|
|
class TestThresholds:
    """Tier classification as a fraction of the compression threshold."""

    # Compression threshold produced by the ``tracker`` fixture
    # (128K context at 50%); every tier below is a fraction of it.
    _BUDGET = 64_000

    def _fill_to(self, tracker, fraction):
        """Report a token count equal to ``fraction`` of the threshold."""
        tracker.update(int(self._BUDGET * fraction))

    def test_threshold_tokens_computed(self, tracker):
        expected = self._BUDGET
        assert tracker.threshold_tokens == expected

    def test_caution_at_70_percent(self, tracker):
        """At 70% of the threshold the tracker reports the "elevated" tier."""
        self._fill_to(tracker, 0.70)
        assert tracker.warning_level == "elevated"

    def test_no_warning_below_caution(self, tracker):
        """Just under 70% no tier is reported at all."""
        self._fill_to(tracker, 0.69)
        assert tracker.warning_level is None

    def test_critical_at_85_percent(self, tracker):
        self._fill_to(tracker, 0.85)
        assert tracker.warning_level == "critical"

    def test_emergency_at_95_percent(self, tracker):
        self._fill_to(tracker, 0.95)
        assert tracker.warning_level == "emergency"
|
|
|
|
|
|
class TestWarningMessages:
    """Content of ``warning_message`` at each warning tier."""

    def test_elevated_message_contains_70(self, tracker):
        tracker.update(int(64_000 * 0.70))
        msg = tracker.warning_message
        assert msg is not None
        assert "CONTEXT WARNING" in msg

    def test_critical_message(self, tracker):
        tracker.update(int(64_000 * 0.85))
        msg = tracker.warning_message
        # Guard first so a missing message fails as a readable assertion
        # instead of a TypeError from ``"..." in None``.
        assert msg is not None
        assert "compaction approaching" in msg

    def test_emergency_message(self, tracker):
        tracker.update(int(64_000 * 0.95))
        msg = tracker.warning_message
        # Same guard as above: fail clearly when no message is produced.
        assert msg is not None
        assert "CONTEXT CRITICAL" in msg

    def test_no_message_below_caution(self, tracker):
        tracker.update(10_000)
        assert tracker.warning_message is None
|
|
|
|
|
|
class TestWarningDedup:
    """Suppression of repeated warnings within ``WARNING_COOLDOWN``."""

    def test_repeated_update_same_tier_suppressed(self, tracker):
        """Same tier within cooldown should not re-emit."""
        first = tracker.update(int(64_000 * 0.71))
        # Entering a tier from no tier is a crossing to a higher tier, which
        # must emit. Without this check the suppression assertion below
        # would also pass if update() never returned a warning at all.
        assert first is not None
        msg1 = tracker.update(int(64_000 * 0.72))
        assert msg1 is None  # suppressed by cooldown

    def test_higher_tier_breaks_through_cooldown(self, tracker):
        """Crossing to a higher tier should always emit."""
        tracker.update(int(64_000 * 0.71))
        msg = tracker.update(int(64_000 * 0.86))
        assert msg is not None
        assert "compaction approaching" in msg.lower()

    def test_cooldown_expires_allows_reemit(self, tracker):
        tracker.update(int(64_000 * 0.71))
        # Fast-forward cooldown by back-dating the last warning timestamp.
        tracker._last_warned_time = time.time() - WARNING_COOLDOWN - 1
        msg = tracker.update(int(64_000 * 0.72))
        assert msg is not None
|
|
|
|
|
|
class TestCheckpoint:
    """``should_checkpoint()`` around 85% of the threshold, and its cooldown."""

    def test_should_checkpoint_at_85(self, tracker):
        tracker.update(int(64_000 * 0.86))
        assert tracker.should_checkpoint() is True

    def test_no_checkpoint_below_85(self, tracker):
        tracker.update(int(64_000 * 0.84))
        assert tracker.should_checkpoint() is False

    def test_checkpoint_cooldown(self, tracker):
        """A second call right after a checkpoint is suppressed."""
        tracker.update(int(64_000 * 0.86))
        # Pin the first call: without this, the cooldown assertion below
        # would also pass if checkpointing never triggered at all.
        assert tracker.should_checkpoint() is True  # saves
        assert tracker.should_checkpoint() is False  # cooldown
|
|
|
|
|
|
class TestPreflight:
    """Pre-flight checks: fitting new content and estimating its size."""

    def test_can_fit_small_addition(self, tracker):
        tracker.update(30_000)
        assert tracker.can_fit(5_000) is True

    def test_cannot_fit_overflow(self, tracker):
        # 56_320 tokens already used; 10_000 more would exceed 64K.
        tracker.update(int(64_000 * 0.88))
        assert tracker.can_fit(10_000) is False

    def test_estimate_file_tokens(self, tracker):
        assert tracker.estimate_file_tokens(4_000) == 1_000
        assert tracker.estimate_file_tokens(100) >= 1
        # 100 bytes estimates well above 1 token, so it never touches the
        # floor; a 1-byte file actually exercises the minimum-of-1 rule.
        assert tracker.estimate_file_tokens(1) >= 1  # minimum 1

    def test_tokens_remaining(self, tracker):
        tracker.update(30_000)
        remaining = tracker.tokens_remaining()
        safe_limit = int(64_000 * PREFLIGHT_SAFETY_MARGIN)
        assert remaining == safe_limit - 30_000
|
|
|
|
|
|
class TestTrend:
    """Direction reported by ``trend()`` after a run of updates."""

    def test_growing_trend(self, tracker):
        # Ten strictly increasing samples: 10K, 15K, ..., 55K.
        for tokens in range(10_000, 60_000, 5_000):
            tracker.update(tokens)
        assert tracker.trend() == "growing"

    def test_shrinking_trend(self, tracker):
        # Ten strictly decreasing samples: 60K, 55K, ..., 15K.
        for tokens in range(60_000, 10_000, -5_000):
            tracker.update(tokens)
        assert tracker.trend() == "shrinking"

    def test_stable_trend(self, tracker):
        # Ten identical samples.
        for tokens in [30_000] * 10:
            tracker.update(tokens)
        assert tracker.trend() == "stable"
|
|
|
|
|
|
class TestSummary:
    """Structured ``summary()`` and the ``format_status()`` display line."""

    def test_summary_keys(self, tracker):
        tracker.update(40_000)
        summary = tracker.summary()
        for key in ("context_length", "current_tokens", "warning_level", "trend"):
            assert key in summary
        assert summary["current_tokens"] == 40_000

    def test_format_status(self, tracker):
        tracker.update(30_000)
        line = tracker.format_status()
        assert "Context:" in line
        assert "[" in line  # progress bar
        assert "%" in line
|
|
|
|
|
|
class TestEdgeCases:
    """Degenerate configurations and the tracker's bookkeeping counters."""

    def test_zero_context_length(self):
        empty = ContextBudgetTracker(context_length=0)
        assert empty.threshold_tokens == 0
        assert empty.progress == 0.0
        assert empty.warning_level is None

    def test_different_threshold_percent(self):
        custom = ContextBudgetTracker(context_length=100_000, threshold_percent=0.80)
        assert custom.threshold_tokens == 80_000
        custom.update(int(80_000 * 0.70))
        assert custom.warning_level == "elevated"

    def test_over_threshold_progress(self, tracker):
        """Progress may go above 1.0 once usage passes the compression threshold."""
        tracker.update(70_000)  # 70K against the fixture's 64K threshold
        assert tracker.progress > 1.0
        assert tracker.warning_level == "emergency"

    def test_peak_tracking(self, tracker):
        for tokens in (10_000, 50_000, 30_000):
            tracker.update(tokens)
        assert tracker.peak_tokens == 50_000

    def test_turn_count(self, tracker):
        assert tracker.turn_count == 0
        for tokens in (10_000, 20_000):
            tracker.update(tokens)
        assert tracker.turn_count == 2
|