#!/usr/bin/env python3
"""Test script to verify performance optimizations in run_agent.py"""
import sys
import time
import threading
import json
from unittest.mock import MagicMock, patch, mock_open


def test_session_log_batching():
    """Test that session logging uses batching."""
    print("Testing session log batching...")
    from run_agent import AIAgent

    # Create agent with mocked client
    with patch('run_agent.OpenAI'):
        agent = AIAgent(
            base_url="http://localhost:8000/v1",
            api_key="test-key",
            model="gpt-4",
            quiet_mode=True,
        )
        # Mock the file operations
        with patch('run_agent.atomic_json_write') as mock_write:
            # Simulate multiple rapid calls to _save_session_log
            messages = [{"role": "user", "content": "test"}]
            start = time.time()
            for _ in range(10):  # loop index unused; `_` signals intent
                agent._save_session_log(messages)
            elapsed = time.time() - start
            # Give batching time to process
            time.sleep(0.1)
            # The batching should have deferred most writes
            # With batching, we expect fewer actual writes than calls
            write_calls = mock_write.call_count
            print(f" 10 save calls resulted in {write_calls} actual writes")
            print(f" Time for 10 calls: {elapsed*1000:.2f}ms")
            # Should be significantly faster with batching
            # NOTE(review): timing assertions can flake on loaded CI machines
            assert elapsed < 0.1, f"Batching setup too slow: {elapsed}s"
        # Cleanup: flush/stop the background batcher thread
        agent._shutdown_session_log_batcher()
    print(" ✓ Session log batching test passed\n")


def test_hydrate_todo_caching():
    """Test that _hydrate_todo_store caches results."""
    print("Testing todo store hydration caching...")
    from run_agent import AIAgent

    with patch('run_agent.OpenAI'):
        agent = AIAgent(
            base_url="http://localhost:8000/v1",
            api_key="test-key",
            model="gpt-4",
            quiet_mode=True,
        )
        # Create a history with a todo response
        history = [
            {"role": "tool", "content": json.dumps({"todos": [{"id": 1, "text": "Test"}]})}
        ] * 50  # 50 messages
        # First call - should scan
        agent._hydrate_todo_store(history)
        # Idiomatic truthiness check instead of `== True`
        assert agent._todo_store_hydrated, "Should mark as hydrated"
        # Second call - should skip due to caching
        start = time.time()
        agent._hydrate_todo_store(history)
        elapsed = time.time() - start
        print(f" Cached call took {elapsed*1000:.3f}ms")
        # NOTE(review): sub-millisecond timing bound is tight; may flake under load
        assert elapsed < 0.001, f"Cached call too slow: {elapsed}s"
    print(" ✓ Todo hydration caching test passed\n")


def test_api_call_timeout():
    """Test that API calls have proper timeout handling."""
    print("Testing API call timeout handling...")
    from run_agent import AIAgent

    with patch('run_agent.OpenAI'):
        agent = AIAgent(
            base_url="http://localhost:8000/v1",
            api_key="test-key",
            model="gpt-4",
            quiet_mode=True,
        )
        # Check that _interruptible_api_call accepts timeout parameter
        import inspect
        sig = inspect.signature(agent._interruptible_api_call)
        assert 'timeout' in sig.parameters, "Should accept timeout parameter"
        # Check default timeout value
        timeout_param = sig.parameters['timeout']
        assert timeout_param.default == 300.0, f"Default timeout should be 300s, got {timeout_param.default}"
        # Check _anthropic_messages_create has timeout
        sig2 = inspect.signature(agent._anthropic_messages_create)
        assert 'timeout' in sig2.parameters, "Anthropic messages should accept timeout"
    print(" ✓ API call timeout test passed\n")


def test_concurrent_session_writes():
    """Test that concurrent session writes are handled properly."""
    print("Testing concurrent session write handling...")
    from run_agent import AIAgent

    with patch('run_agent.OpenAI'):
        agent = AIAgent(
            base_url="http://localhost:8000/v1",
            api_key="test-key",
            model="gpt-4",
            quiet_mode=True,
        )
        with patch('run_agent.atomic_json_write') as mock_write:
            messages = [{"role": "user", "content": f"test {i}"} for i in range(5)]
            # Simulate concurrent calls from multiple threads
            errors = []

            def save_msg(msg):
                # Record (rather than raise) so thread failures surface in the
                # main thread's final assertion.
                try:
                    agent._save_session_log(msg)
                except Exception as e:
                    errors.append(e)

            threads = []
            for msg in messages:
                t = threading.Thread(target=save_msg, args=(msg,))
                threads.append(t)
                t.start()
            for t in threads:
                t.join(timeout=1.0)
            # Cleanup
            agent._shutdown_session_log_batcher()
            # Should have no errors
            assert len(errors) == 0, f"Concurrent writes caused errors: {errors}"
    print(" ✓ Concurrent session write test passed\n")


if __name__ == "__main__":
    print("=" * 60)
    print("Performance Optimizations Test Suite")
    print("=" * 60 + "\n")
    try:
        test_session_log_batching()
        test_hydrate_todo_caching()
        test_api_call_timeout()
        test_concurrent_session_writes()
        print("=" * 60)
        print("All tests passed! ✓")
        print("=" * 60)
    except Exception as e:
        print(f"\n✗ Test failed: {e}")
        import traceback
        traceback.print_exc()
        # sys.exit is the documented way to exit a script with a status code;
        # bare exit() is a site-module convenience not guaranteed to exist.
        sys.exit(1)