Compare commits

...

1 Commits

Author SHA1 Message Date
08e015d14d test: Add integration tests for warm session provider
Some checks failed
Forge CI / smoke-and-build (pull_request) Failing after 19s
Closes #594. Comprehensive tests for warm session provisioning system.
2026-04-14 15:38:04 +00:00

View File

@@ -0,0 +1,540 @@
"""
Integration tests for warm session provider.
Tests the warm session provisioning system end-to-end.
Addresses issue #594.
Issue: #327, #594
"""
import json
import os
import tempfile
import unittest
from datetime import datetime
from pathlib import Path
from unittest.mock import Mock, patch, MagicMock
# Add parent directory to path for imports
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))
from tools.warm_session_provider import (
WarmContext,
WarmProfile,
WarmSessionProvider,
WarmSessionMiddleware
)
class TestWarmContext(unittest.TestCase):
"""Test WarmContext dataclass."""
def test_creation(self):
"""Test WarmContext creation."""
context = WarmContext(
system_prompt_extension="Test system context",
successful_patterns=[{"tool": "terminal", "arguments": "{}"}],
user_preferences={"message_style": "concise"},
known_files=["test.py", "readme.md"],
known_tools=["terminal", "file_operations"]
)
self.assertEqual(context.system_prompt_extension, "Test system context")
self.assertEqual(len(context.successful_patterns), 1)
self.assertEqual(context.user_preferences["message_style"], "concise")
self.assertEqual(len(context.known_files), 2)
self.assertEqual(len(context.known_tools), 2)
def test_to_dict(self):
"""Test WarmContext to_dict conversion."""
context = WarmContext(
system_prompt_extension="Test",
successful_patterns=[{"tool": "test"}],
user_preferences={"style": "test"},
known_files=["test.py"],
known_tools=["test_tool"]
)
data = context.to_dict()
self.assertEqual(data["system_prompt_extension"], "Test")
self.assertEqual(len(data["successful_patterns"]), 1)
self.assertEqual(data["user_preferences"]["style"], "test")
self.assertEqual(len(data["known_files"]), 1)
self.assertEqual(len(data["known_tools"]), 1)
def test_from_dict(self):
"""Test WarmContext from_dict creation."""
data = {
"system_prompt_extension": "Test",
"successful_patterns": [{"tool": "test"}],
"user_preferences": {"style": "test"},
"known_files": ["test.py"],
"known_tools": ["test_tool"]
}
context = WarmContext.from_dict(data)
self.assertEqual(context.system_prompt_extension, "Test")
self.assertEqual(len(context.successful_patterns), 1)
self.assertEqual(context.user_preferences["style"], "test")
class TestWarmProfile(unittest.TestCase):
"""Test WarmProfile dataclass."""
def test_creation(self):
"""Test WarmProfile creation."""
context = WarmContext(system_prompt_extension="Test")
profile = WarmProfile(
profile_id="test_001",
name="Test Profile",
description="Test description",
context=context,
created_from_session="session_123",
usage_count=5,
success_rate=0.8
)
self.assertEqual(profile.profile_id, "test_001")
self.assertEqual(profile.name, "Test Profile")
self.assertEqual(profile.description, "Test description")
self.assertEqual(profile.context.system_prompt_extension, "Test")
self.assertEqual(profile.created_from_session, "session_123")
self.assertEqual(profile.usage_count, 5)
self.assertEqual(profile.success_rate, 0.8)
def test_to_dict(self):
"""Test WarmProfile to_dict conversion."""
context = WarmContext(system_prompt_extension="Test")
profile = WarmProfile(
profile_id="test_001",
name="Test Profile",
description="Test description",
context=context
)
data = profile.to_dict()
self.assertEqual(data["profile_id"], "test_001")
self.assertEqual(data["name"], "Test Profile")
self.assertEqual(data["description"], "Test description")
self.assertEqual(data["context"]["system_prompt_extension"], "Test")
def test_from_dict(self):
"""Test WarmProfile from_dict creation."""
data = {
"profile_id": "test_001",
"name": "Test Profile",
"description": "Test description",
"context": {
"system_prompt_extension": "Test",
"successful_patterns": [],
"user_preferences": {},
"known_files": [],
"known_tools": []
},
"created_from_session": "session_123",
"usage_count": 5,
"success_rate": 0.8
}
profile = WarmProfile.from_dict(data)
self.assertEqual(profile.profile_id, "test_001")
self.assertEqual(profile.name, "Test Profile")
self.assertEqual(profile.context.system_prompt_extension, "Test")
self.assertEqual(profile.usage_count, 5)
self.assertEqual(profile.success_rate, 0.8)
class TestWarmSessionProvider(unittest.TestCase):
"""Test WarmSessionProvider."""
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.provider = WarmSessionProvider(Path(self.temp_dir))
def tearDown(self):
"""Clean up test fixtures."""
import shutil
shutil.rmtree(self.temp_dir, ignore_errors=True)
def test_extract_profile(self):
"""Test profile extraction from session."""
# Mock session DB
mock_session_db = Mock()
# Mock messages
mock_messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Help me with Python"},
{"role": "assistant", "content": "I'll help you with Python.", "tool_calls": [
{"id": "call_1", "type": "function", "function": {"name": "terminal", "arguments": '{"command": "python --version"}'}}
]},
{"role": "tool", "tool_call_id": "call_1", "content": "Python 3.11.0"},
{"role": "user", "content": "Thanks!"}
]
mock_session_db.get_messages.return_value = mock_messages
# Extract profile
profile = self.provider.extract_profile(mock_session_db, "session_123", "Test Profile")
self.assertIsNotNone(profile)
self.assertEqual(profile.name, "Test Profile")
self.assertEqual(profile.created_from_session, "session_123")
self.assertEqual(len(profile.context.successful_patterns), 1)
self.assertEqual(profile.context.successful_patterns[0]["tool"], "terminal")
self.assertIn("terminal", profile.context.known_tools)
def test_save_and_load_profile(self):
"""Test saving and loading profiles."""
context = WarmContext(
system_prompt_extension="Test context",
successful_patterns=[{"tool": "test"}],
user_preferences={"style": "test"},
known_files=["test.py"],
known_tools=["test_tool"]
)
profile = WarmProfile(
profile_id="test_save_load",
name="Test Save Load",
description="Test description",
context=context
)
# Save profile
self.provider.save_profile(profile)
# Load profile
loaded_profile = self.provider.load_profile("test_save_load")
self.assertIsNotNone(loaded_profile)
self.assertEqual(loaded_profile.profile_id, "test_save_load")
self.assertEqual(loaded_profile.name, "Test Save Load")
self.assertEqual(loaded_profile.context.system_prompt_extension, "Test context")
def test_list_profiles(self):
"""Test listing profiles."""
# Create multiple profiles
for i in range(3):
context = WarmContext(system_prompt_extension=f"Context {i}")
profile = WarmProfile(
profile_id=f"test_list_{i}",
name=f"Test Profile {i}",
description=f"Description {i}",
context=context
)
self.provider.save_profile(profile)
# List profiles
profiles = self.provider.list_profiles()
self.assertEqual(len(profiles), 3)
profile_ids = [p["profile_id"] for p in profiles]
self.assertIn("test_list_0", profile_ids)
self.assertIn("test_list_1", profile_ids)
self.assertIn("test_list_2", profile_ids)
def test_delete_profile(self):
"""Test deleting profiles."""
context = WarmContext(system_prompt_extension="Test")
profile = WarmProfile(
profile_id="test_delete",
name="Test Delete",
description="Test description",
context=context
)
# Save profile
self.provider.save_profile(profile)
# Verify it exists
self.assertIsNotNone(self.provider.load_profile("test_delete"))
# Delete profile
result = self.provider.delete_profile("test_delete")
self.assertTrue(result)
self.assertIsNone(self.provider.load_profile("test_delete"))
def test_activate_deactivate_profile(self):
"""Test activating and deactivating profiles."""
context = WarmContext(system_prompt_extension="Test")
profile = WarmProfile(
profile_id="test_activate",
name="Test Activate",
description="Test description",
context=context
)
# Save profile
self.provider.save_profile(profile)
# Activate profile
result = self.provider.activate_profile("test_activate")
self.assertTrue(result)
self.assertIsNotNone(self.provider.active_profile)
self.assertEqual(self.provider.active_profile.profile_id, "test_activate")
# Deactivate profile
self.provider.deactivate_profile()
self.assertIsNone(self.provider.active_profile)
def test_get_session_context(self):
"""Test getting session context."""
context = WarmContext(
system_prompt_extension="Test system context",
successful_patterns=[
{"tool": "terminal", "arguments": '{"command": "ls"}'},
{"tool": "file_operations", "arguments": '{"operation": "read"}'}
],
user_preferences={"message_style": "concise"},
known_files=["test.py", "readme.md"],
known_tools=["terminal", "file_operations"]
)
profile = WarmProfile(
profile_id="test_context",
name="Test Context",
description="Test description",
context=context
)
# Save and activate profile
self.provider.save_profile(profile)
self.provider.activate_profile("test_context")
# Get session context
session_context = self.provider.get_session_context("Test user message")
self.assertIsNotNone(session_context)
self.assertEqual(session_context["profile_id"], "test_context")
self.assertEqual(session_context["profile_name"], "Test Context")
self.assertIn("Test system context", session_context["system_extension"])
self.assertEqual(len(session_context["example_messages"]), 6) # 2 patterns * 3 messages each
self.assertEqual(session_context["user_message"], "Test user message")
def test_update_profile_success(self):
"""Test updating profile success rate."""
context = WarmContext(system_prompt_extension="Test")
profile = WarmProfile(
profile_id="test_success",
name="Test Success",
description="Test description",
context=context,
usage_count=0,
success_rate=0.0
)
# Save profile
self.provider.save_profile(profile)
# Update success
self.provider.update_profile_success("test_success", True)
# Reload and check
updated_profile = self.provider.load_profile("test_success")
self.assertEqual(updated_profile.usage_count, 1)
self.assertEqual(updated_profile.success_rate, 1.0)
# Update again
self.provider.update_profile_success("test_success", False)
# Reload and check
updated_profile = self.provider.load_profile("test_success")
self.assertEqual(updated_profile.usage_count, 2)
self.assertEqual(updated_profile.success_rate, 0.5)
class TestWarmSessionMiddleware(unittest.TestCase):
"""Test WarmSessionMiddleware."""
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.provider = WarmSessionProvider(Path(self.temp_dir))
self.middleware = WarmSessionMiddleware(self.provider)
def tearDown(self):
"""Clean up test fixtures."""
import shutil
shutil.rmtree(self.temp_dir, ignore_errors=True)
def test_prepare_session_without_profile(self):
"""Test preparing session without active profile."""
result = self.middleware.prepare_session("Test message")
self.assertFalse(result["warm"])
self.assertEqual(len(result["messages"]), 1)
self.assertEqual(result["messages"][0]["role"], "user")
self.assertEqual(result["messages"][0]["content"], "Test message")
def test_prepare_session_with_profile(self):
"""Test preparing session with active profile."""
# Create and activate profile
context = WarmContext(
system_prompt_extension="Test context",
successful_patterns=[{"tool": "terminal", "arguments": "{}"}],
user_preferences={"message_style": "concise"},
known_files=["test.py"],
known_tools=["terminal"]
)
profile = WarmProfile(
profile_id="test_middleware",
name="Test Middleware",
description="Test description",
context=context
)
self.provider.save_profile(profile)
self.provider.activate_profile("test_middleware")
# Prepare session
result = self.middleware.prepare_session("Test user message")
self.assertTrue(result["warm"])
self.assertEqual(result["profile_id"], "test_middleware")
self.assertEqual(result["profile_name"], "Test Middleware")
self.assertGreater(len(result["messages"]), 1) # Should have system + examples + user message
# Check system message
system_messages = [m for m in result["messages"] if m["role"] == "system"]
self.assertEqual(len(system_messages), 1)
self.assertIn("Test context", system_messages[0]["content"])
# Check user message
user_messages = [m for m in result["messages"] if m["role"] == "user"]
self.assertEqual(len(user_messages), 1)
self.assertEqual(user_messages[0]["content"], "Test user message")
def test_record_result(self):
"""Test recording session result."""
# Create profile
context = WarmContext(system_prompt_extension="Test")
profile = WarmProfile(
profile_id="test_record",
name="Test Record",
description="Test description",
context=context,
usage_count=0,
success_rate=0.0
)
self.provider.save_profile(profile)
# Record result
self.middleware.record_result("test_record", True)
# Check updated profile
updated_profile = self.provider.load_profile("test_record")
self.assertEqual(updated_profile.usage_count, 1)
self.assertEqual(updated_profile.success_rate, 1.0)
class TestIntegration(unittest.TestCase):
"""Integration tests for the complete warm session flow."""
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.provider = WarmSessionProvider(Path(self.temp_dir))
self.middleware = WarmSessionMiddleware(self.provider)
def tearDown(self):
"""Clean up test fixtures."""
import shutil
shutil.rmtree(self.temp_dir, ignore_errors=True)
def test_complete_flow(self):
"""Test complete warm session flow."""
# 1. Mock session DB with successful session
mock_session_db = Mock()
mock_messages = [
{"role": "system", "content": "You are a coding assistant."},
{"role": "user", "content": "Help me write a Python function"},
{"role": "assistant", "content": "I'll help you write a Python function.", "tool_calls": [
{"id": "call_1", "type": "function", "function": {"name": "terminal", "arguments": '{"command": "python -c \"print(\'Hello\')\""}'}}
]},
{"role": "tool", "tool_call_id": "call_1", "content": "Hello"},
{"role": "user", "content": "Now write a function to add two numbers"},
{"role": "assistant", "content": "Here's a function to add two numbers:", "tool_calls": [
{"id": "call_2", "type": "function", "function": {"name": "file_operations", "arguments": '{"operation": "write", "path": "add.py", "content": "def add(a, b):\\n return a + b"}'}}
]},
{"role": "tool", "tool_call_id": "call_2", "content": "File written successfully"},
{"role": "user", "content": "Thanks!"}
]
mock_session_db.get_messages.return_value = mock_messages
# 2. Extract profile from successful session
profile = self.provider.extract_profile(mock_session_db, "session_123", "Coding Assistant")
self.assertIsNotNone(profile)
self.assertEqual(profile.name, "Coding Assistant")
self.assertEqual(len(profile.context.successful_patterns), 2) # terminal and file_operations
self.assertIn("terminal", profile.context.known_tools)
self.assertIn("file_operations", profile.context.known_tools)
# 3. Activate profile
result = self.provider.activate_profile(profile.profile_id)
self.assertTrue(result)
self.assertIsNotNone(self.provider.active_profile)
# 4. Prepare warm session
session_data = self.middleware.prepare_session("Help me write a function to multiply numbers")
self.assertTrue(session_data["warm"])
self.assertEqual(session_data["profile_id"], profile.profile_id)
self.assertGreater(len(session_data["messages"]), 1)
# 5. Simulate session completion and record result
self.middleware.record_result(profile.profile_id, True)
# 6. Verify profile was updated
updated_profile = self.provider.load_profile(profile.profile_id)
self.assertEqual(updated_profile.usage_count, 1)
self.assertEqual(updated_profile.success_rate, 1.0)
# 7. List profiles
profiles = self.provider.list_profiles()
self.assertEqual(len(profiles), 1)
self.assertEqual(profiles[0]["profile_id"], profile.profile_id)
# 8. Get context
context = self.provider.get_session_context("Test message")
self.assertIsNotNone(context)
self.assertEqual(context["profile_id"], profile.profile_id)
self.assertIn("coding assistant", context["system_extension"].lower())
# 9. Deactivate profile
self.provider.deactivate_profile()
self.assertIsNone(self.provider.active_profile)
# 10. Delete profile
result = self.provider.delete_profile(profile.profile_id)
self.assertTrue(result)
self.assertIsNone(self.provider.load_profile(profile.profile_id))
if __name__ == "__main__":
unittest.main()