Compare commits
1 Commits
fix/468-17
...
dispatch/3
| Author | SHA1 | Date | |
|---|---|---|---|
| 08e015d14d |
@@ -157,82 +157,6 @@ _KNOWN_DELIVERY_PLATFORMS = frozenset({
|
||||
|
||||
from cron.jobs import get_due_jobs, mark_job_run, save_job_output, advance_next_run
|
||||
|
||||
# Patterns for detecting local service references in cron job prompts
|
||||
_LOCAL_SERVICE_PATTERNS = [
|
||||
# Localhost patterns
|
||||
r'localhost:\d+',
|
||||
r'127\.0\.0\.1:\d+',
|
||||
r'\[::1\]:\d+',
|
||||
|
||||
# Local service references
|
||||
r'Check\s+Ollama',
|
||||
r'Ollama\s+is\s+running',
|
||||
r'curl\s+localhost',
|
||||
r'wget\s+localhost',
|
||||
r'fetch\s+localhost',
|
||||
|
||||
# Local development patterns
|
||||
r'http://localhost',
|
||||
r'https://localhost',
|
||||
r'http://127\.0\.0\.1',
|
||||
r'https://127\.0\.0\.1',
|
||||
|
||||
# Common local services
|
||||
r':3000\b', # Common dev server port
|
||||
r':5000\b', # Common dev server port
|
||||
r':8000\b', # Common dev server port
|
||||
r':8080\b', # Common dev server port
|
||||
r':8888\b', # Jupyter port
|
||||
r':11434\b', # Ollama port
|
||||
]
|
||||
|
||||
# Compile patterns for efficiency
|
||||
_LOCAL_SERVICE_PATTERNS_COMPILED = [re.compile(pattern, re.IGNORECASE) for pattern in _LOCAL_SERVICE_PATTERNS]
|
||||
|
||||
|
||||
def _detect_local_service_refs(prompt: str) -> list[str]:
|
||||
"""
|
||||
Detect references to local services in a prompt.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to scan
|
||||
|
||||
Returns:
|
||||
List of matched patterns (empty if none found)
|
||||
"""
|
||||
matches = []
|
||||
for pattern in _LOCAL_SERVICE_PATTERNS_COMPILED:
|
||||
if pattern.search(prompt):
|
||||
matches.append(pattern.pattern)
|
||||
return matches
|
||||
|
||||
|
||||
def _inject_cloud_context(prompt: str, local_refs: list[str]) -> str:
|
||||
"""
|
||||
Inject a cloud context warning when local service references are detected.
|
||||
|
||||
Args:
|
||||
prompt: The original prompt
|
||||
local_refs: List of detected local service references
|
||||
|
||||
Returns:
|
||||
Modified prompt with cloud context warning
|
||||
"""
|
||||
if not local_refs:
|
||||
return prompt
|
||||
|
||||
# Create warning message
|
||||
warning = (
|
||||
"[SYSTEM NOTE: You are running on a cloud endpoint and cannot access "
|
||||
"local services. References to localhost, Ollama, or other local services "
|
||||
"in your prompt will not work. Please report this limitation to the user "
|
||||
"instead of attempting to connect to local services.]\n\n"
|
||||
)
|
||||
|
||||
# Prepend warning to prompt
|
||||
return warning + prompt
|
||||
|
||||
|
||||
# Sentinel: when a cron agent has nothing new to report, it can start its
|
||||
# response with this marker to suppress delivery. Output is still saved
|
||||
# locally for audit.
|
||||
@@ -744,23 +668,6 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
job_id = job["id"]
|
||||
job_name = job["name"]
|
||||
prompt = _build_job_prompt(job)
|
||||
|
||||
# Inject cloud context warning if running on cloud endpoint
|
||||
# and prompt references local services
|
||||
try:
|
||||
_runtime_base_url = turn_route['runtime'].get('base_url', '')
|
||||
_is_cloud = not is_local_endpoint(_runtime_base_url)
|
||||
if _is_cloud:
|
||||
_local_refs = _detect_local_service_refs(prompt)
|
||||
if _local_refs:
|
||||
prompt = _inject_cloud_context(prompt, _local_refs)
|
||||
logger.info(
|
||||
"Job '%s': injected cloud context warning for local service refs: %s",
|
||||
job_id, _local_refs
|
||||
)
|
||||
except Exception as _e:
|
||||
logger.debug("Job '%s': cloud context injection skipped: %s", job_id, _e)
|
||||
|
||||
origin = _resolve_origin(job)
|
||||
_cron_session_id = f"cron_{job_id}_{_hermes_now().strftime('%Y%m%d_%H%M%S')}"
|
||||
|
||||
|
||||
@@ -1,181 +0,0 @@
|
||||
"""
|
||||
Test cloud context injection for cron jobs.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from cron.scheduler import (
|
||||
_detect_local_service_refs,
|
||||
_inject_cloud_context,
|
||||
_LOCAL_SERVICE_PATTERNS_COMPILED
|
||||
)
|
||||
|
||||
|
||||
class TestLocalServiceDetection:
|
||||
"""Test detection of local service references."""
|
||||
|
||||
def test_localhost_with_port(self):
|
||||
"""Test detection of localhost with port."""
|
||||
prompt = "Check if Ollama is running on localhost:11434"
|
||||
refs = _detect_local_service_refs(prompt)
|
||||
assert len(refs) > 0
|
||||
assert any('localhost:\d+' in ref for ref in refs)
|
||||
|
||||
def test_127_0_0_1_with_port(self):
|
||||
"""Test detection of 127.0.0.1 with port."""
|
||||
prompt = "Connect to http://127.0.0.1:8080/api"
|
||||
refs = _detect_local_service_refs(prompt)
|
||||
assert len(refs) > 0
|
||||
assert any('127\.0\.0\.1' in ref for ref in refs)
|
||||
|
||||
def test_ollama_reference(self):
|
||||
"""Test detection of Ollama reference."""
|
||||
prompt = "Check Ollama status"
|
||||
refs = _detect_local_service_refs(prompt)
|
||||
assert len(refs) > 0
|
||||
assert any('Check\s+Ollama' in ref for ref in refs)
|
||||
|
||||
def test_curl_localhost(self):
|
||||
"""Test detection of curl localhost."""
|
||||
prompt = "Run curl localhost:3000 to test the server"
|
||||
refs = _detect_local_service_refs(prompt)
|
||||
assert len(refs) > 0
|
||||
assert any('curl\s+localhost' in ref for ref in refs)
|
||||
|
||||
def test_no_local_refs(self):
|
||||
"""Test no detection when no local references."""
|
||||
prompt = "Check the weather in New York"
|
||||
refs = _detect_local_service_refs(prompt)
|
||||
assert len(refs) == 0
|
||||
|
||||
def test_multiple_refs(self):
|
||||
"""Test detection of multiple local references."""
|
||||
prompt = "Check localhost:3000 and also Ollama on 127.0.0.1:11434"
|
||||
refs = _detect_local_service_refs(prompt)
|
||||
assert len(refs) >= 2
|
||||
|
||||
|
||||
class TestCloudContextInjection:
|
||||
"""Test cloud context warning injection."""
|
||||
|
||||
def test_inject_warning(self):
|
||||
"""Test warning injection when local refs detected."""
|
||||
prompt = "Check Ollama status"
|
||||
local_refs = ["Check\s+Ollama"]
|
||||
|
||||
result = _inject_cloud_context(prompt, local_refs)
|
||||
|
||||
assert "[SYSTEM NOTE:" in result
|
||||
assert "cloud endpoint" in result
|
||||
assert "cannot access local services" in result
|
||||
assert prompt in result # Original prompt preserved
|
||||
|
||||
def test_no_injection_without_refs(self):
|
||||
"""Test no injection when no local refs."""
|
||||
prompt = "Check the weather"
|
||||
local_refs = []
|
||||
|
||||
result = _inject_cloud_context(prompt, local_refs)
|
||||
|
||||
assert result == prompt
|
||||
assert "[SYSTEM NOTE:" not in result
|
||||
|
||||
def test_preserves_original_prompt(self):
|
||||
"""Test that original prompt is preserved."""
|
||||
original_prompt = "This is my original prompt with localhost:3000"
|
||||
local_refs = ["localhost:\d+"]
|
||||
|
||||
result = _inject_cloud_context(original_prompt, local_refs)
|
||||
|
||||
assert original_prompt in result
|
||||
assert result.startswith("[SYSTEM NOTE:")
|
||||
|
||||
def test_warning_content(self):
|
||||
"""Test warning content is appropriate."""
|
||||
prompt = "Test prompt"
|
||||
local_refs = ["test"]
|
||||
|
||||
result = _inject_cloud_context(prompt, local_refs)
|
||||
|
||||
assert "report this limitation to the user" in result
|
||||
assert "instead of attempting to connect" in result
|
||||
|
||||
|
||||
class TestPatternMatching:
|
||||
"""Test individual pattern matching."""
|
||||
|
||||
def test_common_ports(self):
|
||||
"""Test detection of common development ports."""
|
||||
common_ports = [3000, 5000, 8000, 8080, 8888, 11434]
|
||||
|
||||
for port in common_ports:
|
||||
prompt = f"Check localhost:{port}"
|
||||
refs = _detect_local_service_refs(prompt)
|
||||
assert len(refs) > 0, f"Failed to detect port {port}"
|
||||
|
||||
def test_http_protocols(self):
|
||||
"""Test detection of HTTP/HTTPS protocols."""
|
||||
protocols = ["http://localhost", "https://localhost",
|
||||
"http://127.0.0.1", "https://127.0.0.1"]
|
||||
|
||||
for protocol in protocols:
|
||||
prompt = f"Connect to {protocol}:8080"
|
||||
refs = _detect_local_service_refs(prompt)
|
||||
assert len(refs) > 0, f"Failed to detect {protocol}"
|
||||
|
||||
def test_ipv6_localhost(self):
|
||||
"""Test detection of IPv6 localhost."""
|
||||
prompt = "Connect to [::1]:8080"
|
||||
refs = _detect_local_service_refs(prompt)
|
||||
assert len(refs) > 0
|
||||
assert any('\[::1\]' in ref for ref in refs)
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Test edge cases and false positives."""
|
||||
|
||||
def test_case_insensitive(self):
|
||||
"""Test case insensitive matching."""
|
||||
prompts = [
|
||||
"CHECK LOCALHOST:3000",
|
||||
"check Localhost:3000",
|
||||
"Check LOCALHOST:3000"
|
||||
]
|
||||
|
||||
for prompt in prompts:
|
||||
refs = _detect_local_service_refs(prompt)
|
||||
assert len(refs) > 0, f"Failed case insensitive: {prompt}"
|
||||
|
||||
def test_no_false_positives(self):
|
||||
"""Test no false positives for similar patterns."""
|
||||
safe_prompts = [
|
||||
"Check the localhost documentation",
|
||||
"Read about 127.0.0.1 in the manual",
|
||||
"The Ollama project is interesting",
|
||||
"Port 3000 is commonly used",
|
||||
"The localhost file is in /etc/hosts"
|
||||
]
|
||||
|
||||
for prompt in safe_prompts:
|
||||
refs = _detect_local_service_refs(prompt)
|
||||
# These might still match due to pattern design, but that's acceptable
|
||||
# The important thing is that they don't crash
|
||||
assert isinstance(refs, list)
|
||||
|
||||
def test_empty_prompt(self):
|
||||
"""Test empty prompt handling."""
|
||||
refs = _detect_local_service_refs("")
|
||||
assert refs == []
|
||||
|
||||
def test_none_handling(self):
|
||||
"""Test None prompt handling."""
|
||||
# The function should handle None gracefully
|
||||
try:
|
||||
refs = _detect_local_service_refs(None)
|
||||
assert refs == []
|
||||
except Exception as e:
|
||||
# If it raises an exception, that's also acceptable
|
||||
assert isinstance(e, (TypeError, AttributeError))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
540
tests/test_warm_session_provider.py
Normal file
540
tests/test_warm_session_provider.py
Normal file
@@ -0,0 +1,540 @@
|
||||
"""
|
||||
Integration tests for warm session provider.
|
||||
|
||||
Tests the warm session provisioning system end-to-end.
|
||||
Addresses issue #594.
|
||||
|
||||
Issue: #327, #594
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
# Add parent directory to path for imports
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from tools.warm_session_provider import (
|
||||
WarmContext,
|
||||
WarmProfile,
|
||||
WarmSessionProvider,
|
||||
WarmSessionMiddleware
|
||||
)
|
||||
|
||||
|
||||
class TestWarmContext(unittest.TestCase):
|
||||
"""Test WarmContext dataclass."""
|
||||
|
||||
def test_creation(self):
|
||||
"""Test WarmContext creation."""
|
||||
context = WarmContext(
|
||||
system_prompt_extension="Test system context",
|
||||
successful_patterns=[{"tool": "terminal", "arguments": "{}"}],
|
||||
user_preferences={"message_style": "concise"},
|
||||
known_files=["test.py", "readme.md"],
|
||||
known_tools=["terminal", "file_operations"]
|
||||
)
|
||||
|
||||
self.assertEqual(context.system_prompt_extension, "Test system context")
|
||||
self.assertEqual(len(context.successful_patterns), 1)
|
||||
self.assertEqual(context.user_preferences["message_style"], "concise")
|
||||
self.assertEqual(len(context.known_files), 2)
|
||||
self.assertEqual(len(context.known_tools), 2)
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test WarmContext to_dict conversion."""
|
||||
context = WarmContext(
|
||||
system_prompt_extension="Test",
|
||||
successful_patterns=[{"tool": "test"}],
|
||||
user_preferences={"style": "test"},
|
||||
known_files=["test.py"],
|
||||
known_tools=["test_tool"]
|
||||
)
|
||||
|
||||
data = context.to_dict()
|
||||
|
||||
self.assertEqual(data["system_prompt_extension"], "Test")
|
||||
self.assertEqual(len(data["successful_patterns"]), 1)
|
||||
self.assertEqual(data["user_preferences"]["style"], "test")
|
||||
self.assertEqual(len(data["known_files"]), 1)
|
||||
self.assertEqual(len(data["known_tools"]), 1)
|
||||
|
||||
def test_from_dict(self):
|
||||
"""Test WarmContext from_dict creation."""
|
||||
data = {
|
||||
"system_prompt_extension": "Test",
|
||||
"successful_patterns": [{"tool": "test"}],
|
||||
"user_preferences": {"style": "test"},
|
||||
"known_files": ["test.py"],
|
||||
"known_tools": ["test_tool"]
|
||||
}
|
||||
|
||||
context = WarmContext.from_dict(data)
|
||||
|
||||
self.assertEqual(context.system_prompt_extension, "Test")
|
||||
self.assertEqual(len(context.successful_patterns), 1)
|
||||
self.assertEqual(context.user_preferences["style"], "test")
|
||||
|
||||
|
||||
class TestWarmProfile(unittest.TestCase):
|
||||
"""Test WarmProfile dataclass."""
|
||||
|
||||
def test_creation(self):
|
||||
"""Test WarmProfile creation."""
|
||||
context = WarmContext(system_prompt_extension="Test")
|
||||
|
||||
profile = WarmProfile(
|
||||
profile_id="test_001",
|
||||
name="Test Profile",
|
||||
description="Test description",
|
||||
context=context,
|
||||
created_from_session="session_123",
|
||||
usage_count=5,
|
||||
success_rate=0.8
|
||||
)
|
||||
|
||||
self.assertEqual(profile.profile_id, "test_001")
|
||||
self.assertEqual(profile.name, "Test Profile")
|
||||
self.assertEqual(profile.description, "Test description")
|
||||
self.assertEqual(profile.context.system_prompt_extension, "Test")
|
||||
self.assertEqual(profile.created_from_session, "session_123")
|
||||
self.assertEqual(profile.usage_count, 5)
|
||||
self.assertEqual(profile.success_rate, 0.8)
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test WarmProfile to_dict conversion."""
|
||||
context = WarmContext(system_prompt_extension="Test")
|
||||
|
||||
profile = WarmProfile(
|
||||
profile_id="test_001",
|
||||
name="Test Profile",
|
||||
description="Test description",
|
||||
context=context
|
||||
)
|
||||
|
||||
data = profile.to_dict()
|
||||
|
||||
self.assertEqual(data["profile_id"], "test_001")
|
||||
self.assertEqual(data["name"], "Test Profile")
|
||||
self.assertEqual(data["description"], "Test description")
|
||||
self.assertEqual(data["context"]["system_prompt_extension"], "Test")
|
||||
|
||||
def test_from_dict(self):
|
||||
"""Test WarmProfile from_dict creation."""
|
||||
data = {
|
||||
"profile_id": "test_001",
|
||||
"name": "Test Profile",
|
||||
"description": "Test description",
|
||||
"context": {
|
||||
"system_prompt_extension": "Test",
|
||||
"successful_patterns": [],
|
||||
"user_preferences": {},
|
||||
"known_files": [],
|
||||
"known_tools": []
|
||||
},
|
||||
"created_from_session": "session_123",
|
||||
"usage_count": 5,
|
||||
"success_rate": 0.8
|
||||
}
|
||||
|
||||
profile = WarmProfile.from_dict(data)
|
||||
|
||||
self.assertEqual(profile.profile_id, "test_001")
|
||||
self.assertEqual(profile.name, "Test Profile")
|
||||
self.assertEqual(profile.context.system_prompt_extension, "Test")
|
||||
self.assertEqual(profile.usage_count, 5)
|
||||
self.assertEqual(profile.success_rate, 0.8)
|
||||
|
||||
|
||||
class TestWarmSessionProvider(unittest.TestCase):
|
||||
"""Test WarmSessionProvider."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.provider = WarmSessionProvider(Path(self.temp_dir))
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
import shutil
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
def test_extract_profile(self):
|
||||
"""Test profile extraction from session."""
|
||||
# Mock session DB
|
||||
mock_session_db = Mock()
|
||||
|
||||
# Mock messages
|
||||
mock_messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Help me with Python"},
|
||||
{"role": "assistant", "content": "I'll help you with Python.", "tool_calls": [
|
||||
{"id": "call_1", "type": "function", "function": {"name": "terminal", "arguments": '{"command": "python --version"}'}}
|
||||
]},
|
||||
{"role": "tool", "tool_call_id": "call_1", "content": "Python 3.11.0"},
|
||||
{"role": "user", "content": "Thanks!"}
|
||||
]
|
||||
|
||||
mock_session_db.get_messages.return_value = mock_messages
|
||||
|
||||
# Extract profile
|
||||
profile = self.provider.extract_profile(mock_session_db, "session_123", "Test Profile")
|
||||
|
||||
self.assertIsNotNone(profile)
|
||||
self.assertEqual(profile.name, "Test Profile")
|
||||
self.assertEqual(profile.created_from_session, "session_123")
|
||||
self.assertEqual(len(profile.context.successful_patterns), 1)
|
||||
self.assertEqual(profile.context.successful_patterns[0]["tool"], "terminal")
|
||||
self.assertIn("terminal", profile.context.known_tools)
|
||||
|
||||
def test_save_and_load_profile(self):
|
||||
"""Test saving and loading profiles."""
|
||||
context = WarmContext(
|
||||
system_prompt_extension="Test context",
|
||||
successful_patterns=[{"tool": "test"}],
|
||||
user_preferences={"style": "test"},
|
||||
known_files=["test.py"],
|
||||
known_tools=["test_tool"]
|
||||
)
|
||||
|
||||
profile = WarmProfile(
|
||||
profile_id="test_save_load",
|
||||
name="Test Save Load",
|
||||
description="Test description",
|
||||
context=context
|
||||
)
|
||||
|
||||
# Save profile
|
||||
self.provider.save_profile(profile)
|
||||
|
||||
# Load profile
|
||||
loaded_profile = self.provider.load_profile("test_save_load")
|
||||
|
||||
self.assertIsNotNone(loaded_profile)
|
||||
self.assertEqual(loaded_profile.profile_id, "test_save_load")
|
||||
self.assertEqual(loaded_profile.name, "Test Save Load")
|
||||
self.assertEqual(loaded_profile.context.system_prompt_extension, "Test context")
|
||||
|
||||
def test_list_profiles(self):
|
||||
"""Test listing profiles."""
|
||||
# Create multiple profiles
|
||||
for i in range(3):
|
||||
context = WarmContext(system_prompt_extension=f"Context {i}")
|
||||
profile = WarmProfile(
|
||||
profile_id=f"test_list_{i}",
|
||||
name=f"Test Profile {i}",
|
||||
description=f"Description {i}",
|
||||
context=context
|
||||
)
|
||||
self.provider.save_profile(profile)
|
||||
|
||||
# List profiles
|
||||
profiles = self.provider.list_profiles()
|
||||
|
||||
self.assertEqual(len(profiles), 3)
|
||||
profile_ids = [p["profile_id"] for p in profiles]
|
||||
self.assertIn("test_list_0", profile_ids)
|
||||
self.assertIn("test_list_1", profile_ids)
|
||||
self.assertIn("test_list_2", profile_ids)
|
||||
|
||||
def test_delete_profile(self):
|
||||
"""Test deleting profiles."""
|
||||
context = WarmContext(system_prompt_extension="Test")
|
||||
profile = WarmProfile(
|
||||
profile_id="test_delete",
|
||||
name="Test Delete",
|
||||
description="Test description",
|
||||
context=context
|
||||
)
|
||||
|
||||
# Save profile
|
||||
self.provider.save_profile(profile)
|
||||
|
||||
# Verify it exists
|
||||
self.assertIsNotNone(self.provider.load_profile("test_delete"))
|
||||
|
||||
# Delete profile
|
||||
result = self.provider.delete_profile("test_delete")
|
||||
|
||||
self.assertTrue(result)
|
||||
self.assertIsNone(self.provider.load_profile("test_delete"))
|
||||
|
||||
def test_activate_deactivate_profile(self):
|
||||
"""Test activating and deactivating profiles."""
|
||||
context = WarmContext(system_prompt_extension="Test")
|
||||
profile = WarmProfile(
|
||||
profile_id="test_activate",
|
||||
name="Test Activate",
|
||||
description="Test description",
|
||||
context=context
|
||||
)
|
||||
|
||||
# Save profile
|
||||
self.provider.save_profile(profile)
|
||||
|
||||
# Activate profile
|
||||
result = self.provider.activate_profile("test_activate")
|
||||
|
||||
self.assertTrue(result)
|
||||
self.assertIsNotNone(self.provider.active_profile)
|
||||
self.assertEqual(self.provider.active_profile.profile_id, "test_activate")
|
||||
|
||||
# Deactivate profile
|
||||
self.provider.deactivate_profile()
|
||||
|
||||
self.assertIsNone(self.provider.active_profile)
|
||||
|
||||
def test_get_session_context(self):
|
||||
"""Test getting session context."""
|
||||
context = WarmContext(
|
||||
system_prompt_extension="Test system context",
|
||||
successful_patterns=[
|
||||
{"tool": "terminal", "arguments": '{"command": "ls"}'},
|
||||
{"tool": "file_operations", "arguments": '{"operation": "read"}'}
|
||||
],
|
||||
user_preferences={"message_style": "concise"},
|
||||
known_files=["test.py", "readme.md"],
|
||||
known_tools=["terminal", "file_operations"]
|
||||
)
|
||||
|
||||
profile = WarmProfile(
|
||||
profile_id="test_context",
|
||||
name="Test Context",
|
||||
description="Test description",
|
||||
context=context
|
||||
)
|
||||
|
||||
# Save and activate profile
|
||||
self.provider.save_profile(profile)
|
||||
self.provider.activate_profile("test_context")
|
||||
|
||||
# Get session context
|
||||
session_context = self.provider.get_session_context("Test user message")
|
||||
|
||||
self.assertIsNotNone(session_context)
|
||||
self.assertEqual(session_context["profile_id"], "test_context")
|
||||
self.assertEqual(session_context["profile_name"], "Test Context")
|
||||
self.assertIn("Test system context", session_context["system_extension"])
|
||||
self.assertEqual(len(session_context["example_messages"]), 6) # 2 patterns * 3 messages each
|
||||
self.assertEqual(session_context["user_message"], "Test user message")
|
||||
|
||||
def test_update_profile_success(self):
|
||||
"""Test updating profile success rate."""
|
||||
context = WarmContext(system_prompt_extension="Test")
|
||||
profile = WarmProfile(
|
||||
profile_id="test_success",
|
||||
name="Test Success",
|
||||
description="Test description",
|
||||
context=context,
|
||||
usage_count=0,
|
||||
success_rate=0.0
|
||||
)
|
||||
|
||||
# Save profile
|
||||
self.provider.save_profile(profile)
|
||||
|
||||
# Update success
|
||||
self.provider.update_profile_success("test_success", True)
|
||||
|
||||
# Reload and check
|
||||
updated_profile = self.provider.load_profile("test_success")
|
||||
|
||||
self.assertEqual(updated_profile.usage_count, 1)
|
||||
self.assertEqual(updated_profile.success_rate, 1.0)
|
||||
|
||||
# Update again
|
||||
self.provider.update_profile_success("test_success", False)
|
||||
|
||||
# Reload and check
|
||||
updated_profile = self.provider.load_profile("test_success")
|
||||
|
||||
self.assertEqual(updated_profile.usage_count, 2)
|
||||
self.assertEqual(updated_profile.success_rate, 0.5)
|
||||
|
||||
|
||||
class TestWarmSessionMiddleware(unittest.TestCase):
|
||||
"""Test WarmSessionMiddleware."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.provider = WarmSessionProvider(Path(self.temp_dir))
|
||||
self.middleware = WarmSessionMiddleware(self.provider)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
import shutil
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
def test_prepare_session_without_profile(self):
|
||||
"""Test preparing session without active profile."""
|
||||
result = self.middleware.prepare_session("Test message")
|
||||
|
||||
self.assertFalse(result["warm"])
|
||||
self.assertEqual(len(result["messages"]), 1)
|
||||
self.assertEqual(result["messages"][0]["role"], "user")
|
||||
self.assertEqual(result["messages"][0]["content"], "Test message")
|
||||
|
||||
def test_prepare_session_with_profile(self):
|
||||
"""Test preparing session with active profile."""
|
||||
# Create and activate profile
|
||||
context = WarmContext(
|
||||
system_prompt_extension="Test context",
|
||||
successful_patterns=[{"tool": "terminal", "arguments": "{}"}],
|
||||
user_preferences={"message_style": "concise"},
|
||||
known_files=["test.py"],
|
||||
known_tools=["terminal"]
|
||||
)
|
||||
|
||||
profile = WarmProfile(
|
||||
profile_id="test_middleware",
|
||||
name="Test Middleware",
|
||||
description="Test description",
|
||||
context=context
|
||||
)
|
||||
|
||||
self.provider.save_profile(profile)
|
||||
self.provider.activate_profile("test_middleware")
|
||||
|
||||
# Prepare session
|
||||
result = self.middleware.prepare_session("Test user message")
|
||||
|
||||
self.assertTrue(result["warm"])
|
||||
self.assertEqual(result["profile_id"], "test_middleware")
|
||||
self.assertEqual(result["profile_name"], "Test Middleware")
|
||||
self.assertGreater(len(result["messages"]), 1) # Should have system + examples + user message
|
||||
|
||||
# Check system message
|
||||
system_messages = [m for m in result["messages"] if m["role"] == "system"]
|
||||
self.assertEqual(len(system_messages), 1)
|
||||
self.assertIn("Test context", system_messages[0]["content"])
|
||||
|
||||
# Check user message
|
||||
user_messages = [m for m in result["messages"] if m["role"] == "user"]
|
||||
self.assertEqual(len(user_messages), 1)
|
||||
self.assertEqual(user_messages[0]["content"], "Test user message")
|
||||
|
||||
def test_record_result(self):
|
||||
"""Test recording session result."""
|
||||
# Create profile
|
||||
context = WarmContext(system_prompt_extension="Test")
|
||||
profile = WarmProfile(
|
||||
profile_id="test_record",
|
||||
name="Test Record",
|
||||
description="Test description",
|
||||
context=context,
|
||||
usage_count=0,
|
||||
success_rate=0.0
|
||||
)
|
||||
|
||||
self.provider.save_profile(profile)
|
||||
|
||||
# Record result
|
||||
self.middleware.record_result("test_record", True)
|
||||
|
||||
# Check updated profile
|
||||
updated_profile = self.provider.load_profile("test_record")
|
||||
|
||||
self.assertEqual(updated_profile.usage_count, 1)
|
||||
self.assertEqual(updated_profile.success_rate, 1.0)
|
||||
|
||||
|
||||
class TestIntegration(unittest.TestCase):
|
||||
"""Integration tests for the complete warm session flow."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.provider = WarmSessionProvider(Path(self.temp_dir))
|
||||
self.middleware = WarmSessionMiddleware(self.provider)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
import shutil
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
def test_complete_flow(self):
|
||||
"""Test complete warm session flow."""
|
||||
# 1. Mock session DB with successful session
|
||||
mock_session_db = Mock()
|
||||
|
||||
mock_messages = [
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
{"role": "user", "content": "Help me write a Python function"},
|
||||
{"role": "assistant", "content": "I'll help you write a Python function.", "tool_calls": [
|
||||
{"id": "call_1", "type": "function", "function": {"name": "terminal", "arguments": '{"command": "python -c \"print(\'Hello\')\""}'}}
|
||||
]},
|
||||
{"role": "tool", "tool_call_id": "call_1", "content": "Hello"},
|
||||
{"role": "user", "content": "Now write a function to add two numbers"},
|
||||
{"role": "assistant", "content": "Here's a function to add two numbers:", "tool_calls": [
|
||||
{"id": "call_2", "type": "function", "function": {"name": "file_operations", "arguments": '{"operation": "write", "path": "add.py", "content": "def add(a, b):\\n return a + b"}'}}
|
||||
]},
|
||||
{"role": "tool", "tool_call_id": "call_2", "content": "File written successfully"},
|
||||
{"role": "user", "content": "Thanks!"}
|
||||
]
|
||||
|
||||
mock_session_db.get_messages.return_value = mock_messages
|
||||
|
||||
# 2. Extract profile from successful session
|
||||
profile = self.provider.extract_profile(mock_session_db, "session_123", "Coding Assistant")
|
||||
|
||||
self.assertIsNotNone(profile)
|
||||
self.assertEqual(profile.name, "Coding Assistant")
|
||||
self.assertEqual(len(profile.context.successful_patterns), 2) # terminal and file_operations
|
||||
self.assertIn("terminal", profile.context.known_tools)
|
||||
self.assertIn("file_operations", profile.context.known_tools)
|
||||
|
||||
# 3. Activate profile
|
||||
result = self.provider.activate_profile(profile.profile_id)
|
||||
|
||||
self.assertTrue(result)
|
||||
self.assertIsNotNone(self.provider.active_profile)
|
||||
|
||||
# 4. Prepare warm session
|
||||
session_data = self.middleware.prepare_session("Help me write a function to multiply numbers")
|
||||
|
||||
self.assertTrue(session_data["warm"])
|
||||
self.assertEqual(session_data["profile_id"], profile.profile_id)
|
||||
self.assertGreater(len(session_data["messages"]), 1)
|
||||
|
||||
# 5. Simulate session completion and record result
|
||||
self.middleware.record_result(profile.profile_id, True)
|
||||
|
||||
# 6. Verify profile was updated
|
||||
updated_profile = self.provider.load_profile(profile.profile_id)
|
||||
|
||||
self.assertEqual(updated_profile.usage_count, 1)
|
||||
self.assertEqual(updated_profile.success_rate, 1.0)
|
||||
|
||||
# 7. List profiles
|
||||
profiles = self.provider.list_profiles()
|
||||
|
||||
self.assertEqual(len(profiles), 1)
|
||||
self.assertEqual(profiles[0]["profile_id"], profile.profile_id)
|
||||
|
||||
# 8. Get context
|
||||
context = self.provider.get_session_context("Test message")
|
||||
|
||||
self.assertIsNotNone(context)
|
||||
self.assertEqual(context["profile_id"], profile.profile_id)
|
||||
self.assertIn("coding assistant", context["system_extension"].lower())
|
||||
|
||||
# 9. Deactivate profile
|
||||
self.provider.deactivate_profile()
|
||||
|
||||
self.assertIsNone(self.provider.active_profile)
|
||||
|
||||
# 10. Delete profile
|
||||
result = self.provider.delete_profile(profile.profile_id)
|
||||
|
||||
self.assertTrue(result)
|
||||
self.assertIsNone(self.provider.load_profile(profile.profile_id))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user