179 lines
5.8 KiB
Python
179 lines
5.8 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""Test script to verify performance optimizations in run_agent.py"""
|
||
|
|
|
||
|
|
import time
|
||
|
|
import threading
|
||
|
|
import json
|
||
|
|
from unittest.mock import MagicMock, patch, mock_open
|
||
|
|
|
||
|
|
def test_session_log_batching():
    """Test that session logging uses batching."""
    print("Testing session log batching...")

    from run_agent import AIAgent

    # Build an agent against a mocked OpenAI client.
    with patch('run_agent.OpenAI'):
        agent = AIAgent(
            base_url="http://localhost:8000/v1",
            api_key="test-key",
            model="gpt-4",
            quiet_mode=True,
        )

    # Intercept the on-disk writes so only the batching behaviour is observed.
    with patch('run_agent.atomic_json_write') as mock_write:
        messages = [{"role": "user", "content": "test"}]

        # Fire ten rapid save requests and time how long they take to enqueue.
        t0 = time.time()
        for _ in range(10):
            agent._save_session_log(messages)
        elapsed = time.time() - t0

        # Give batching time to process
        time.sleep(0.1)

        # With batching, we expect fewer actual writes than save calls.
        write_calls = mock_write.call_count

        print(f" 10 save calls resulted in {write_calls} actual writes")
        print(f" Time for 10 calls: {elapsed*1000:.2f}ms")

        # Enqueueing should be near-instant when the writes are deferred.
        assert elapsed < 0.1, f"Batching setup too slow: {elapsed}s"

        # Cleanup
        agent._shutdown_session_log_batcher()

    print(" ✓ Session log batching test passed\n")
|
||
|
|
|
||
|
|
|
||
|
|
def test_hydrate_todo_caching():
    """Test that _hydrate_todo_store caches results.

    Hydrates once from a 50-message history, then verifies the second call
    short-circuits via the agent's _todo_store_hydrated flag and returns
    almost instantly.
    """
    print("Testing todo store hydration caching...")

    from run_agent import AIAgent

    with patch('run_agent.OpenAI'):
        agent = AIAgent(
            base_url="http://localhost:8000/v1",
            api_key="test-key",
            model="gpt-4",
            quiet_mode=True,
        )

    # Create a history with a todo response
    history = [
        {"role": "tool", "content": json.dumps({"todos": [{"id": 1, "text": "Test"}]})}
    ] * 50  # 50 messages

    # First call - should scan
    agent._hydrate_todo_store(history)
    # `is True` rather than `== True` (E712): the flag must be the real
    # boolean, not merely any truthy value.
    assert agent._todo_store_hydrated is True, "Should mark as hydrated"

    # Second call - should skip due to caching
    start = time.time()
    agent._hydrate_todo_store(history)
    elapsed = time.time() - start

    print(f" Cached call took {elapsed*1000:.3f}ms")
    # NOTE(review): a 1 ms wall-clock bound can be flaky on loaded CI
    # machines; preserved as-is to keep the test's original strictness.
    assert elapsed < 0.001, f"Cached call too slow: {elapsed}s"

    print(" ✓ Todo hydration caching test passed\n")
|
||
|
|
|
||
|
|
|
||
|
|
def test_api_call_timeout():
    """Test that API calls have proper timeout handling."""
    print("Testing API call timeout handling...")

    from run_agent import AIAgent

    with patch('run_agent.OpenAI'):
        agent = AIAgent(
            base_url="http://localhost:8000/v1",
            api_key="test-key",
            model="gpt-4",
            quiet_mode=True,
        )

    import inspect

    # The interruptible call wrapper must expose a `timeout` keyword...
    call_sig = inspect.signature(agent._interruptible_api_call)
    assert 'timeout' in call_sig.parameters, "Should accept timeout parameter"

    # ...and its default must be the expected 300 seconds.
    default_timeout = call_sig.parameters['timeout'].default
    assert default_timeout == 300.0, f"Default timeout should be 300s, got {default_timeout}"

    # The Anthropic-style entry point must take a timeout as well.
    anthropic_sig = inspect.signature(agent._anthropic_messages_create)
    assert 'timeout' in anthropic_sig.parameters, "Anthropic messages should accept timeout"

    print(" ✓ API call timeout test passed\n")
|
||
|
|
|
||
|
|
|
||
|
|
def test_concurrent_session_writes():
    """Test that concurrent session writes are handled properly.

    Spawns one writer thread per message and verifies every thread finishes
    and that no thread raised while calling _save_session_log.
    """
    print("Testing concurrent session write handling...")

    from run_agent import AIAgent

    with patch('run_agent.OpenAI'):
        agent = AIAgent(
            base_url="http://localhost:8000/v1",
            api_key="test-key",
            model="gpt-4",
            quiet_mode=True,
        )

    # Stub out the actual disk write; its call count is irrelevant here.
    with patch('run_agent.atomic_json_write'):
        messages = [{"role": "user", "content": f"test {i}"} for i in range(5)]

        # Simulate concurrent calls from multiple threads
        errors = []

        def save_msg(msg):
            try:
                agent._save_session_log(msg)
            except Exception as e:
                errors.append(e)

        threads = []
        for msg in messages:
            t = threading.Thread(target=save_msg, args=(msg,))
            threads.append(t)
            t.start()

        for t in threads:
            t.join(timeout=1.0)

        # join(timeout=...) returns even when the thread is still running;
        # without this check a hung writer would silently pass the test.
        stuck = [t for t in threads if t.is_alive()]
        assert not stuck, f"{len(stuck)} writer thread(s) did not finish within 1s"

        # Cleanup
        agent._shutdown_session_log_batcher()

        # Should have no errors
        assert len(errors) == 0, f"Concurrent writes caused errors: {errors}"

    print(" ✓ Concurrent session write test passed\n")
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    print("=" * 60)
    print("Performance Optimizations Test Suite")
    print("=" * 60 + "\n")

    try:
        test_session_log_batching()
        test_hydrate_todo_caching()
        test_api_call_timeout()
        test_concurrent_session_writes()

        print("=" * 60)
        print("All tests passed! ✓")
        print("=" * 60)
    except Exception as e:
        print(f"\n✗ Test failed: {e}")
        import traceback
        traceback.print_exc()
        # `exit()` is injected by the `site` module and is not guaranteed to
        # exist (e.g. under `python -S`); raising SystemExit is the portable
        # way for a script to set a non-zero exit status.
        raise SystemExit(1)