Some checks failed
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 19s
Contributor Attribution Check / check-attribution (pull_request) Failing after 16s
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Tests / test (pull_request) Failing after 18m30s
Tests / e2e (pull_request) Successful in 1m16s
128 lines
4.1 KiB
Python
"""
|
|
Tests for context budget tracker
|
|
|
|
Issue: #838
|
|
"""
|
|
|
|
import json
|
|
import tempfile
|
|
import unittest
|
|
from pathlib import Path
|
|
from unittest.mock import patch
|
|
|
|
from agent.context_budget import (
|
|
ContextBudget,
|
|
ContextBudgetTracker,
|
|
estimate_tokens,
|
|
estimate_messages_tokens,
|
|
check_context_budget,
|
|
preflight_token_check,
|
|
THRESHOLD_WARNING,
|
|
THRESHOLD_CRITICAL,
|
|
THRESHOLD_DANGER,
|
|
)
|
|
|
|
|
|
class TestContextBudget(unittest.TestCase):
|
|
|
|
def test_basic_budget(self):
|
|
b = ContextBudget(context_limit=10000)
|
|
self.assertEqual(b.available, 8000) # 10000 - 2000 reserved
|
|
self.assertEqual(b.remaining, 8000)
|
|
self.assertEqual(b.utilization, 0.0)
|
|
|
|
def test_utilization(self):
|
|
b = ContextBudget(context_limit=10000, used_tokens=4000)
|
|
self.assertEqual(b.utilization, 0.5)
|
|
self.assertEqual(b.remaining, 4000)
|
|
|
|
|
|
class TestTokenEstimation(unittest.TestCase):
|
|
|
|
def test_estimate_tokens(self):
|
|
self.assertEqual(estimate_tokens(""), 0)
|
|
self.assertEqual(estimate_tokens("a" * 4), 1)
|
|
self.assertEqual(estimate_tokens("a" * 400), 100)
|
|
|
|
def test_estimate_messages(self):
|
|
messages = [
|
|
{"role": "user", "content": "a" * 400},
|
|
{"role": "assistant", "content": "b" * 800},
|
|
]
|
|
tokens = estimate_messages_tokens(messages)
|
|
self.assertEqual(tokens, 300) # 100 + 200
|
|
|
|
|
|
class TestContextBudgetTracker(unittest.TestCase):
|
|
|
|
def test_warning_at_70_percent(self):
|
|
tracker = ContextBudgetTracker(context_limit=10000)
|
|
tracker.budget.used_tokens = 5600 # 70% of 8000 available
|
|
warning = tracker.get_warning()
|
|
self.assertIsNotNone(warning)
|
|
self.assertIn("70", warning)
|
|
|
|
def test_critical_at_85_percent(self):
|
|
with tempfile.TemporaryDirectory() as tmp:
|
|
with patch("agent.context_budget.CHECKPOINT_DIR", Path(tmp)):
|
|
tracker = ContextBudgetTracker(context_limit=10000, session_id="test")
|
|
tracker.budget.used_tokens = 6800 # 85% of 8000
|
|
warning = tracker.get_warning()
|
|
self.assertIsNotNone(warning)
|
|
self.assertIn("85", warning)
|
|
|
|
def test_danger_at_95_percent(self):
|
|
tracker = ContextBudgetTracker(context_limit=10000)
|
|
tracker.budget.used_tokens = 7600 # 95% of 8000
|
|
warning = tracker.get_warning()
|
|
self.assertIsNotNone(warning)
|
|
self.assertIn("CRITICAL", warning)
|
|
|
|
def test_can_fit(self):
|
|
tracker = ContextBudgetTracker(context_limit=10000)
|
|
tracker.budget.used_tokens = 5000
|
|
self.assertTrue(tracker.can_fit(1000))
|
|
self.assertFalse(tracker.can_fit(5000))
|
|
|
|
def test_preflight_check(self):
|
|
tracker = ContextBudgetTracker(context_limit=10000)
|
|
tracker.budget.used_tokens = 5000
|
|
|
|
can_fit, msg = tracker.preflight_check("a" * 400) # 100 tokens
|
|
self.assertTrue(can_fit)
|
|
self.assertEqual(msg, "")
|
|
|
|
|
|
class TestCheckContextBudget(unittest.TestCase):
|
|
|
|
def test_no_warning_under_threshold(self):
|
|
with patch("agent.context_budget._tracker", None):
|
|
messages = [{"role": "user", "content": "short"}]
|
|
warning = check_context_budget(messages)
|
|
self.assertIsNone(warning)
|
|
|
|
def test_warning_over_threshold(self):
|
|
with patch("agent.context_budget._tracker", None):
|
|
# Create messages that exceed 70% of default 128k context
|
|
messages = [{"role": "user", "content": "x" * 350000}] # ~87500 tokens
|
|
warning = check_context_budget(messages)
|
|
self.assertIsNotNone(warning)
|
|
|
|
|
|
class TestStatusLine(unittest.TestCase):
|
|
|
|
def test_green_status(self):
|
|
tracker = ContextBudgetTracker(context_limit=10000)
|
|
line = tracker.get_status_line()
|
|
self.assertIn("GREEN", line)
|
|
|
|
def test_red_status(self):
|
|
tracker = ContextBudgetTracker(context_limit=10000)
|
|
tracker.budget.used_tokens = 7600
|
|
line = tracker.get_status_line()
|
|
self.assertIn("RED", line)
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main(verbosity=1)
|