Compare commits
1 Commits
am/296-177
...
dispatch/3
| Author | SHA1 | Date | |
|---|---|---|---|
| 08e015d14d |
540
tests/test_warm_session_provider.py
Normal file
540
tests/test_warm_session_provider.py
Normal file
@@ -0,0 +1,540 @@
|
||||
"""
|
||||
Integration tests for warm session provider.
|
||||
|
||||
Tests the warm session provisioning system end-to-end.
|
||||
Addresses issue #594.
|
||||
|
||||
Issue: #327, #594
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
# Add parent directory to path for imports
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from tools.warm_session_provider import (
|
||||
WarmContext,
|
||||
WarmProfile,
|
||||
WarmSessionProvider,
|
||||
WarmSessionMiddleware
|
||||
)
|
||||
|
||||
|
||||
class TestWarmContext(unittest.TestCase):
|
||||
"""Test WarmContext dataclass."""
|
||||
|
||||
def test_creation(self):
|
||||
"""Test WarmContext creation."""
|
||||
context = WarmContext(
|
||||
system_prompt_extension="Test system context",
|
||||
successful_patterns=[{"tool": "terminal", "arguments": "{}"}],
|
||||
user_preferences={"message_style": "concise"},
|
||||
known_files=["test.py", "readme.md"],
|
||||
known_tools=["terminal", "file_operations"]
|
||||
)
|
||||
|
||||
self.assertEqual(context.system_prompt_extension, "Test system context")
|
||||
self.assertEqual(len(context.successful_patterns), 1)
|
||||
self.assertEqual(context.user_preferences["message_style"], "concise")
|
||||
self.assertEqual(len(context.known_files), 2)
|
||||
self.assertEqual(len(context.known_tools), 2)
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test WarmContext to_dict conversion."""
|
||||
context = WarmContext(
|
||||
system_prompt_extension="Test",
|
||||
successful_patterns=[{"tool": "test"}],
|
||||
user_preferences={"style": "test"},
|
||||
known_files=["test.py"],
|
||||
known_tools=["test_tool"]
|
||||
)
|
||||
|
||||
data = context.to_dict()
|
||||
|
||||
self.assertEqual(data["system_prompt_extension"], "Test")
|
||||
self.assertEqual(len(data["successful_patterns"]), 1)
|
||||
self.assertEqual(data["user_preferences"]["style"], "test")
|
||||
self.assertEqual(len(data["known_files"]), 1)
|
||||
self.assertEqual(len(data["known_tools"]), 1)
|
||||
|
||||
def test_from_dict(self):
|
||||
"""Test WarmContext from_dict creation."""
|
||||
data = {
|
||||
"system_prompt_extension": "Test",
|
||||
"successful_patterns": [{"tool": "test"}],
|
||||
"user_preferences": {"style": "test"},
|
||||
"known_files": ["test.py"],
|
||||
"known_tools": ["test_tool"]
|
||||
}
|
||||
|
||||
context = WarmContext.from_dict(data)
|
||||
|
||||
self.assertEqual(context.system_prompt_extension, "Test")
|
||||
self.assertEqual(len(context.successful_patterns), 1)
|
||||
self.assertEqual(context.user_preferences["style"], "test")
|
||||
|
||||
|
||||
class TestWarmProfile(unittest.TestCase):
|
||||
"""Test WarmProfile dataclass."""
|
||||
|
||||
def test_creation(self):
|
||||
"""Test WarmProfile creation."""
|
||||
context = WarmContext(system_prompt_extension="Test")
|
||||
|
||||
profile = WarmProfile(
|
||||
profile_id="test_001",
|
||||
name="Test Profile",
|
||||
description="Test description",
|
||||
context=context,
|
||||
created_from_session="session_123",
|
||||
usage_count=5,
|
||||
success_rate=0.8
|
||||
)
|
||||
|
||||
self.assertEqual(profile.profile_id, "test_001")
|
||||
self.assertEqual(profile.name, "Test Profile")
|
||||
self.assertEqual(profile.description, "Test description")
|
||||
self.assertEqual(profile.context.system_prompt_extension, "Test")
|
||||
self.assertEqual(profile.created_from_session, "session_123")
|
||||
self.assertEqual(profile.usage_count, 5)
|
||||
self.assertEqual(profile.success_rate, 0.8)
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test WarmProfile to_dict conversion."""
|
||||
context = WarmContext(system_prompt_extension="Test")
|
||||
|
||||
profile = WarmProfile(
|
||||
profile_id="test_001",
|
||||
name="Test Profile",
|
||||
description="Test description",
|
||||
context=context
|
||||
)
|
||||
|
||||
data = profile.to_dict()
|
||||
|
||||
self.assertEqual(data["profile_id"], "test_001")
|
||||
self.assertEqual(data["name"], "Test Profile")
|
||||
self.assertEqual(data["description"], "Test description")
|
||||
self.assertEqual(data["context"]["system_prompt_extension"], "Test")
|
||||
|
||||
def test_from_dict(self):
|
||||
"""Test WarmProfile from_dict creation."""
|
||||
data = {
|
||||
"profile_id": "test_001",
|
||||
"name": "Test Profile",
|
||||
"description": "Test description",
|
||||
"context": {
|
||||
"system_prompt_extension": "Test",
|
||||
"successful_patterns": [],
|
||||
"user_preferences": {},
|
||||
"known_files": [],
|
||||
"known_tools": []
|
||||
},
|
||||
"created_from_session": "session_123",
|
||||
"usage_count": 5,
|
||||
"success_rate": 0.8
|
||||
}
|
||||
|
||||
profile = WarmProfile.from_dict(data)
|
||||
|
||||
self.assertEqual(profile.profile_id, "test_001")
|
||||
self.assertEqual(profile.name, "Test Profile")
|
||||
self.assertEqual(profile.context.system_prompt_extension, "Test")
|
||||
self.assertEqual(profile.usage_count, 5)
|
||||
self.assertEqual(profile.success_rate, 0.8)
|
||||
|
||||
|
||||
class TestWarmSessionProvider(unittest.TestCase):
|
||||
"""Test WarmSessionProvider."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.provider = WarmSessionProvider(Path(self.temp_dir))
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
import shutil
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
def test_extract_profile(self):
|
||||
"""Test profile extraction from session."""
|
||||
# Mock session DB
|
||||
mock_session_db = Mock()
|
||||
|
||||
# Mock messages
|
||||
mock_messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Help me with Python"},
|
||||
{"role": "assistant", "content": "I'll help you with Python.", "tool_calls": [
|
||||
{"id": "call_1", "type": "function", "function": {"name": "terminal", "arguments": '{"command": "python --version"}'}}
|
||||
]},
|
||||
{"role": "tool", "tool_call_id": "call_1", "content": "Python 3.11.0"},
|
||||
{"role": "user", "content": "Thanks!"}
|
||||
]
|
||||
|
||||
mock_session_db.get_messages.return_value = mock_messages
|
||||
|
||||
# Extract profile
|
||||
profile = self.provider.extract_profile(mock_session_db, "session_123", "Test Profile")
|
||||
|
||||
self.assertIsNotNone(profile)
|
||||
self.assertEqual(profile.name, "Test Profile")
|
||||
self.assertEqual(profile.created_from_session, "session_123")
|
||||
self.assertEqual(len(profile.context.successful_patterns), 1)
|
||||
self.assertEqual(profile.context.successful_patterns[0]["tool"], "terminal")
|
||||
self.assertIn("terminal", profile.context.known_tools)
|
||||
|
||||
def test_save_and_load_profile(self):
|
||||
"""Test saving and loading profiles."""
|
||||
context = WarmContext(
|
||||
system_prompt_extension="Test context",
|
||||
successful_patterns=[{"tool": "test"}],
|
||||
user_preferences={"style": "test"},
|
||||
known_files=["test.py"],
|
||||
known_tools=["test_tool"]
|
||||
)
|
||||
|
||||
profile = WarmProfile(
|
||||
profile_id="test_save_load",
|
||||
name="Test Save Load",
|
||||
description="Test description",
|
||||
context=context
|
||||
)
|
||||
|
||||
# Save profile
|
||||
self.provider.save_profile(profile)
|
||||
|
||||
# Load profile
|
||||
loaded_profile = self.provider.load_profile("test_save_load")
|
||||
|
||||
self.assertIsNotNone(loaded_profile)
|
||||
self.assertEqual(loaded_profile.profile_id, "test_save_load")
|
||||
self.assertEqual(loaded_profile.name, "Test Save Load")
|
||||
self.assertEqual(loaded_profile.context.system_prompt_extension, "Test context")
|
||||
|
||||
def test_list_profiles(self):
|
||||
"""Test listing profiles."""
|
||||
# Create multiple profiles
|
||||
for i in range(3):
|
||||
context = WarmContext(system_prompt_extension=f"Context {i}")
|
||||
profile = WarmProfile(
|
||||
profile_id=f"test_list_{i}",
|
||||
name=f"Test Profile {i}",
|
||||
description=f"Description {i}",
|
||||
context=context
|
||||
)
|
||||
self.provider.save_profile(profile)
|
||||
|
||||
# List profiles
|
||||
profiles = self.provider.list_profiles()
|
||||
|
||||
self.assertEqual(len(profiles), 3)
|
||||
profile_ids = [p["profile_id"] for p in profiles]
|
||||
self.assertIn("test_list_0", profile_ids)
|
||||
self.assertIn("test_list_1", profile_ids)
|
||||
self.assertIn("test_list_2", profile_ids)
|
||||
|
||||
def test_delete_profile(self):
|
||||
"""Test deleting profiles."""
|
||||
context = WarmContext(system_prompt_extension="Test")
|
||||
profile = WarmProfile(
|
||||
profile_id="test_delete",
|
||||
name="Test Delete",
|
||||
description="Test description",
|
||||
context=context
|
||||
)
|
||||
|
||||
# Save profile
|
||||
self.provider.save_profile(profile)
|
||||
|
||||
# Verify it exists
|
||||
self.assertIsNotNone(self.provider.load_profile("test_delete"))
|
||||
|
||||
# Delete profile
|
||||
result = self.provider.delete_profile("test_delete")
|
||||
|
||||
self.assertTrue(result)
|
||||
self.assertIsNone(self.provider.load_profile("test_delete"))
|
||||
|
||||
def test_activate_deactivate_profile(self):
|
||||
"""Test activating and deactivating profiles."""
|
||||
context = WarmContext(system_prompt_extension="Test")
|
||||
profile = WarmProfile(
|
||||
profile_id="test_activate",
|
||||
name="Test Activate",
|
||||
description="Test description",
|
||||
context=context
|
||||
)
|
||||
|
||||
# Save profile
|
||||
self.provider.save_profile(profile)
|
||||
|
||||
# Activate profile
|
||||
result = self.provider.activate_profile("test_activate")
|
||||
|
||||
self.assertTrue(result)
|
||||
self.assertIsNotNone(self.provider.active_profile)
|
||||
self.assertEqual(self.provider.active_profile.profile_id, "test_activate")
|
||||
|
||||
# Deactivate profile
|
||||
self.provider.deactivate_profile()
|
||||
|
||||
self.assertIsNone(self.provider.active_profile)
|
||||
|
||||
def test_get_session_context(self):
|
||||
"""Test getting session context."""
|
||||
context = WarmContext(
|
||||
system_prompt_extension="Test system context",
|
||||
successful_patterns=[
|
||||
{"tool": "terminal", "arguments": '{"command": "ls"}'},
|
||||
{"tool": "file_operations", "arguments": '{"operation": "read"}'}
|
||||
],
|
||||
user_preferences={"message_style": "concise"},
|
||||
known_files=["test.py", "readme.md"],
|
||||
known_tools=["terminal", "file_operations"]
|
||||
)
|
||||
|
||||
profile = WarmProfile(
|
||||
profile_id="test_context",
|
||||
name="Test Context",
|
||||
description="Test description",
|
||||
context=context
|
||||
)
|
||||
|
||||
# Save and activate profile
|
||||
self.provider.save_profile(profile)
|
||||
self.provider.activate_profile("test_context")
|
||||
|
||||
# Get session context
|
||||
session_context = self.provider.get_session_context("Test user message")
|
||||
|
||||
self.assertIsNotNone(session_context)
|
||||
self.assertEqual(session_context["profile_id"], "test_context")
|
||||
self.assertEqual(session_context["profile_name"], "Test Context")
|
||||
self.assertIn("Test system context", session_context["system_extension"])
|
||||
self.assertEqual(len(session_context["example_messages"]), 6) # 2 patterns * 3 messages each
|
||||
self.assertEqual(session_context["user_message"], "Test user message")
|
||||
|
||||
def test_update_profile_success(self):
|
||||
"""Test updating profile success rate."""
|
||||
context = WarmContext(system_prompt_extension="Test")
|
||||
profile = WarmProfile(
|
||||
profile_id="test_success",
|
||||
name="Test Success",
|
||||
description="Test description",
|
||||
context=context,
|
||||
usage_count=0,
|
||||
success_rate=0.0
|
||||
)
|
||||
|
||||
# Save profile
|
||||
self.provider.save_profile(profile)
|
||||
|
||||
# Update success
|
||||
self.provider.update_profile_success("test_success", True)
|
||||
|
||||
# Reload and check
|
||||
updated_profile = self.provider.load_profile("test_success")
|
||||
|
||||
self.assertEqual(updated_profile.usage_count, 1)
|
||||
self.assertEqual(updated_profile.success_rate, 1.0)
|
||||
|
||||
# Update again
|
||||
self.provider.update_profile_success("test_success", False)
|
||||
|
||||
# Reload and check
|
||||
updated_profile = self.provider.load_profile("test_success")
|
||||
|
||||
self.assertEqual(updated_profile.usage_count, 2)
|
||||
self.assertEqual(updated_profile.success_rate, 0.5)
|
||||
|
||||
|
||||
class TestWarmSessionMiddleware(unittest.TestCase):
|
||||
"""Test WarmSessionMiddleware."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.provider = WarmSessionProvider(Path(self.temp_dir))
|
||||
self.middleware = WarmSessionMiddleware(self.provider)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
import shutil
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
def test_prepare_session_without_profile(self):
|
||||
"""Test preparing session without active profile."""
|
||||
result = self.middleware.prepare_session("Test message")
|
||||
|
||||
self.assertFalse(result["warm"])
|
||||
self.assertEqual(len(result["messages"]), 1)
|
||||
self.assertEqual(result["messages"][0]["role"], "user")
|
||||
self.assertEqual(result["messages"][0]["content"], "Test message")
|
||||
|
||||
def test_prepare_session_with_profile(self):
|
||||
"""Test preparing session with active profile."""
|
||||
# Create and activate profile
|
||||
context = WarmContext(
|
||||
system_prompt_extension="Test context",
|
||||
successful_patterns=[{"tool": "terminal", "arguments": "{}"}],
|
||||
user_preferences={"message_style": "concise"},
|
||||
known_files=["test.py"],
|
||||
known_tools=["terminal"]
|
||||
)
|
||||
|
||||
profile = WarmProfile(
|
||||
profile_id="test_middleware",
|
||||
name="Test Middleware",
|
||||
description="Test description",
|
||||
context=context
|
||||
)
|
||||
|
||||
self.provider.save_profile(profile)
|
||||
self.provider.activate_profile("test_middleware")
|
||||
|
||||
# Prepare session
|
||||
result = self.middleware.prepare_session("Test user message")
|
||||
|
||||
self.assertTrue(result["warm"])
|
||||
self.assertEqual(result["profile_id"], "test_middleware")
|
||||
self.assertEqual(result["profile_name"], "Test Middleware")
|
||||
self.assertGreater(len(result["messages"]), 1) # Should have system + examples + user message
|
||||
|
||||
# Check system message
|
||||
system_messages = [m for m in result["messages"] if m["role"] == "system"]
|
||||
self.assertEqual(len(system_messages), 1)
|
||||
self.assertIn("Test context", system_messages[0]["content"])
|
||||
|
||||
# Check user message
|
||||
user_messages = [m for m in result["messages"] if m["role"] == "user"]
|
||||
self.assertEqual(len(user_messages), 1)
|
||||
self.assertEqual(user_messages[0]["content"], "Test user message")
|
||||
|
||||
def test_record_result(self):
|
||||
"""Test recording session result."""
|
||||
# Create profile
|
||||
context = WarmContext(system_prompt_extension="Test")
|
||||
profile = WarmProfile(
|
||||
profile_id="test_record",
|
||||
name="Test Record",
|
||||
description="Test description",
|
||||
context=context,
|
||||
usage_count=0,
|
||||
success_rate=0.0
|
||||
)
|
||||
|
||||
self.provider.save_profile(profile)
|
||||
|
||||
# Record result
|
||||
self.middleware.record_result("test_record", True)
|
||||
|
||||
# Check updated profile
|
||||
updated_profile = self.provider.load_profile("test_record")
|
||||
|
||||
self.assertEqual(updated_profile.usage_count, 1)
|
||||
self.assertEqual(updated_profile.success_rate, 1.0)
|
||||
|
||||
|
||||
class TestIntegration(unittest.TestCase):
|
||||
"""Integration tests for the complete warm session flow."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.provider = WarmSessionProvider(Path(self.temp_dir))
|
||||
self.middleware = WarmSessionMiddleware(self.provider)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
import shutil
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
def test_complete_flow(self):
|
||||
"""Test complete warm session flow."""
|
||||
# 1. Mock session DB with successful session
|
||||
mock_session_db = Mock()
|
||||
|
||||
mock_messages = [
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
{"role": "user", "content": "Help me write a Python function"},
|
||||
{"role": "assistant", "content": "I'll help you write a Python function.", "tool_calls": [
|
||||
{"id": "call_1", "type": "function", "function": {"name": "terminal", "arguments": '{"command": "python -c \"print(\'Hello\')\""}'}}
|
||||
]},
|
||||
{"role": "tool", "tool_call_id": "call_1", "content": "Hello"},
|
||||
{"role": "user", "content": "Now write a function to add two numbers"},
|
||||
{"role": "assistant", "content": "Here's a function to add two numbers:", "tool_calls": [
|
||||
{"id": "call_2", "type": "function", "function": {"name": "file_operations", "arguments": '{"operation": "write", "path": "add.py", "content": "def add(a, b):\\n return a + b"}'}}
|
||||
]},
|
||||
{"role": "tool", "tool_call_id": "call_2", "content": "File written successfully"},
|
||||
{"role": "user", "content": "Thanks!"}
|
||||
]
|
||||
|
||||
mock_session_db.get_messages.return_value = mock_messages
|
||||
|
||||
# 2. Extract profile from successful session
|
||||
profile = self.provider.extract_profile(mock_session_db, "session_123", "Coding Assistant")
|
||||
|
||||
self.assertIsNotNone(profile)
|
||||
self.assertEqual(profile.name, "Coding Assistant")
|
||||
self.assertEqual(len(profile.context.successful_patterns), 2) # terminal and file_operations
|
||||
self.assertIn("terminal", profile.context.known_tools)
|
||||
self.assertIn("file_operations", profile.context.known_tools)
|
||||
|
||||
# 3. Activate profile
|
||||
result = self.provider.activate_profile(profile.profile_id)
|
||||
|
||||
self.assertTrue(result)
|
||||
self.assertIsNotNone(self.provider.active_profile)
|
||||
|
||||
# 4. Prepare warm session
|
||||
session_data = self.middleware.prepare_session("Help me write a function to multiply numbers")
|
||||
|
||||
self.assertTrue(session_data["warm"])
|
||||
self.assertEqual(session_data["profile_id"], profile.profile_id)
|
||||
self.assertGreater(len(session_data["messages"]), 1)
|
||||
|
||||
# 5. Simulate session completion and record result
|
||||
self.middleware.record_result(profile.profile_id, True)
|
||||
|
||||
# 6. Verify profile was updated
|
||||
updated_profile = self.provider.load_profile(profile.profile_id)
|
||||
|
||||
self.assertEqual(updated_profile.usage_count, 1)
|
||||
self.assertEqual(updated_profile.success_rate, 1.0)
|
||||
|
||||
# 7. List profiles
|
||||
profiles = self.provider.list_profiles()
|
||||
|
||||
self.assertEqual(len(profiles), 1)
|
||||
self.assertEqual(profiles[0]["profile_id"], profile.profile_id)
|
||||
|
||||
# 8. Get context
|
||||
context = self.provider.get_session_context("Test message")
|
||||
|
||||
self.assertIsNotNone(context)
|
||||
self.assertEqual(context["profile_id"], profile.profile_id)
|
||||
self.assertIn("coding assistant", context["system_extension"].lower())
|
||||
|
||||
# 9. Deactivate profile
|
||||
self.provider.deactivate_profile()
|
||||
|
||||
self.assertIsNone(self.provider.active_profile)
|
||||
|
||||
# 10. Delete profile
|
||||
result = self.provider.delete_profile(profile.profile_id)
|
||||
|
||||
self.assertTrue(result)
|
||||
self.assertIsNone(self.provider.load_profile(profile.profile_id))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user