# Integrates with the existing agent system. Provides pre-contextualized
# sessions based on extracted patterns. Part of #327.
# (517 lines, 18 KiB, Python)
"""
|
|
Warm Session Provider
|
|
|
|
Production-ready warm session provisioning that integrates with the
|
|
existing agent system. Provides pre-contextualized sessions based on
|
|
extracted patterns from successful sessions.
|
|
|
|
Issue: #327
|
|
"""
|
|
|
|
import json
import logging
import os
import re
from dataclasses import asdict, dataclass, field, fields
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class WarmContext:
    """Context for warming up a session.

    Attributes:
        system_prompt_extension: Extra system-prompt text.
        successful_patterns: Records of tool calls, one dict per call.
        user_preferences: Preference values keyed by name.
        known_files: File names associated with the session.
        known_tools: Tool names associated with the session.
    """

    system_prompt_extension: str = ""
    successful_patterns: List[Dict[str, Any]] = field(default_factory=list)
    user_preferences: Dict[str, Any] = field(default_factory=dict)
    known_files: List[str] = field(default_factory=list)
    known_tools: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return the context as a plain dict (via ``dataclasses.asdict``)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Optional[Dict[str, Any]]) -> 'WarmContext':
        """Build a context from a dict, tolerating missing or extra keys.

        Args:
            data: Mapping of field names to values. ``None`` or an empty
                mapping yields a default context.

        Returns:
            A new ``WarmContext``. Keys that are not dataclass fields are
            ignored, and ``None`` values fall back to the field default, so
            data carrying unknown keys or ``null`` values still loads.
        """
        if not data:
            return cls()
        field_names = {f.name for f in fields(cls)}
        return cls(**{
            key: value
            for key, value in data.items()
            if key in field_names and value is not None
        })
|
|
|
|
|
|
@dataclass
class WarmProfile:
    """A named warm-session profile: a context plus usage bookkeeping."""

    profile_id: str
    name: str
    description: str
    context: WarmContext
    created_from_session: Optional[str] = None
    usage_count: int = 0
    success_rate: float = 0.0
    last_used: Optional[str] = None
    # ISO-8601 timestamp taken when the instance is created.
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the profile to a dict, with the context nested as a dict."""
        payload: Dict[str, Any] = {
            key: getattr(self, key)
            for key in ("profile_id", "name", "description")
        }
        payload["context"] = self.context.to_dict()
        for key in (
            "created_from_session",
            "usage_count",
            "success_rate",
            "last_used",
            "created_at",
        ):
            payload[key] = getattr(self, key)
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'WarmProfile':
        """Rebuild a profile from a dict such as the one ``to_dict`` produces.

        ``profile_id``, ``name`` and ``description`` are required (a missing
        one raises ``KeyError``); every other key falls back to a default.
        """
        # The context is restored first, before the required keys are read.
        restored_context = WarmContext.from_dict(data.get("context", {}))
        kwargs: Dict[str, Any] = {
            key: data[key] for key in ("profile_id", "name", "description")
        }
        kwargs["context"] = restored_context
        kwargs["created_from_session"] = data.get("created_from_session")
        kwargs["usage_count"] = data.get("usage_count", 0)
        kwargs["success_rate"] = data.get("success_rate", 0.0)
        kwargs["last_used"] = data.get("last_used")
        kwargs["created_at"] = data.get("created_at", datetime.now().isoformat())
        return cls(**kwargs)
|
|
|
|
|
|
class WarmSessionProvider:
    """Store warm profiles as JSON files and build context from the active one.

    Each profile is persisted as ``<profile_id>.json`` inside ``profile_dir``.
    The active profile is held in memory only, so it lasts no longer than the
    provider instance.
    """

    # Limits applied while extracting a context from a session.
    MAX_SYSTEM_PROMPT_CHARS = 1000
    MAX_PATTERNS = 10
    MAX_FILES = 20
    MAX_TOOLS = 10
    MAX_FILE_NAME_CHARS = 50  # exclusive: kept tokens are shorter than this
    TOOL_RESULT_LOOKAHEAD = 2  # messages scanned after a tool-calling message

    # Limits applied while rendering context for a new session.
    MAX_PROMPT_FILES = 10
    MAX_EXAMPLES = 3

    # Heuristic for file-like tokens: word/slash/dot characters ending in a
    # dotted suffix. It also matches things like version numbers.
    _FILE_PATTERN = re.compile(r'[\w/\.]+\.[\w]+')

    def __init__(self, profile_dir: Optional[Path] = None):
        """Create the provider and make sure the profile directory exists.

        Args:
            profile_dir: Directory holding profile JSON files. Defaults to
                ``~/.hermes/warm_profiles`` when omitted or falsy.
        """
        if profile_dir:
            self.profile_dir = Path(profile_dir)
        else:
            self.profile_dir = Path.home() / ".hermes" / "warm_profiles"
        self.profile_dir.mkdir(parents=True, exist_ok=True)
        self.active_profile: "Optional[WarmProfile]" = None

    # ------------------------------------------------------------------
    # Extraction
    # ------------------------------------------------------------------

    def extract_profile(
        self, session_db, session_id: str, name: Optional[str] = None
    ) -> "Optional[WarmProfile]":
        """Extract a warm profile from an existing session and save it.

        Args:
            session_db: Object exposing ``get_messages(session_id)``.
            session_id: Session to read messages from.
            name: Profile name; defaults to one derived from the session id.

        Returns:
            The saved profile, or ``None`` when the session has no messages
            or any step fails (the failure is logged).
        """
        try:
            messages = session_db.get_messages(session_id)
            if not messages:
                return None

            profile = WarmProfile(
                profile_id=f"warm_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                name=name or f"Profile from {session_id[:8]}",
                description=f"Extracted from session {session_id}",
                context=self._extract_context(messages),
                created_from_session=session_id,
            )
            self.save_profile(profile)
            return profile
        except Exception as e:
            # Best effort: extraction problems must not crash the caller.
            logger.error("Failed to extract profile: %s", e, exc_info=True)
            return None

    @staticmethod
    def _text_of(message: Dict[str, Any]) -> str:
        """Return a message's content as text; non-string content yields ''.

        Messages can carry ``content: None`` (the key exists, so a ``.get``
        default does not apply), which must not reach ``len`` or ``re``.
        """
        content = message.get("content")
        return content if isinstance(content, str) else ""

    @classmethod
    def _extract_tool_patterns(cls, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Collect tool calls whose following tool result does not look like an error."""
        patterns: List[Dict[str, Any]] = []
        for index, message in enumerate(messages):
            tool_calls = message.get("tool_calls")
            if message.get("role") != "assistant" or not tool_calls:
                continue

            # Only the first tool result within the lookahead window decides
            # whether the whole batch of calls is treated as successful.
            window = messages[index + 1:index + 1 + cls.TOOL_RESULT_LOOKAHEAD]
            result = next((m for m in window if m.get("role") == "tool"), None)
            if result is None:
                continue
            result_text = cls._text_of(result)
            if not result_text or "error" in result_text.lower()[:100]:
                continue

            for tool_call in tool_calls:
                func = tool_call.get("function") or {}
                tool_name = func.get("name")
                if not tool_name:
                    # A nameless call cannot be replayed as an example.
                    continue
                patterns.append({
                    "tool": tool_name,
                    "arguments": func.get("arguments", "{}"),
                    "success": True,
                })
        return patterns

    @classmethod
    def _extract_user_preferences(cls, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Derive message-style statistics from the user messages."""
        user_texts = [cls._text_of(m) for m in messages if m.get("role") == "user"]
        if not user_texts:
            return {}
        avg_length = sum(map(len, user_texts)) / len(user_texts)
        questions = sum(1 for text in user_texts if "?" in text)
        return {
            "message_style": "detailed" if avg_length > 100 else "concise",
            "question_ratio": questions / len(user_texts),
            "avg_message_length": avg_length,
        }

    @classmethod
    def _extract_known_files(cls, messages: List[Dict[str, Any]]) -> List[str]:
        """Return file-like tokens from all messages, first-seen order, de-duplicated."""
        seen: Dict[str, None] = {}
        for message in messages:
            for token in cls._FILE_PATTERN.findall(cls._text_of(message)):
                if len(token) < cls.MAX_FILE_NAME_CHARS:
                    seen.setdefault(token, None)
        return list(seen)

    def _extract_context(self, messages: List[Dict]) -> "WarmContext":
        """Extract a ``WarmContext`` from a session's messages.

        The system prompt comes from the first system message with non-empty
        text. Patterns, files and tools keep first-seen order before being
        truncated, so the result is the same on every run.
        """
        system_prompt_extension = ""
        for message in messages:
            if message.get("role") != "system":
                continue
            text = self._text_of(message)
            if text:
                system_prompt_extension = text[:self.MAX_SYSTEM_PROMPT_CHARS]
                break

        patterns = self._extract_tool_patterns(messages)
        # Tools are taken from every successful pattern, not only the kept ones.
        tools = list(dict.fromkeys(pattern["tool"] for pattern in patterns))

        return WarmContext(
            system_prompt_extension=system_prompt_extension,
            successful_patterns=patterns[:self.MAX_PATTERNS],
            user_preferences=self._extract_user_preferences(messages),
            known_files=self._extract_known_files(messages)[:self.MAX_FILES],
            known_tools=tools[:self.MAX_TOOLS],
        )

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def _profile_path(self, profile_id: str) -> Optional[Path]:
        """Return the JSON path for *profile_id*, or ``None`` if the id is unsafe.

        Ids are used as file names, so anything that could point outside
        ``profile_dir`` (path separators, ``.``/``..``, empty) is rejected.
        """
        if not isinstance(profile_id, str) or not profile_id:
            return None
        if profile_id in (".", "..") or "/" in profile_id or "\\" in profile_id:
            return None
        return self.profile_dir / f"{profile_id}.json"

    def save_profile(self, profile: "WarmProfile") -> None:
        """Write *profile* to its JSON file without modifying the profile.

        Raises:
            ValueError: If the profile id cannot be used as a file name.
        """
        profile_path = self._profile_path(profile.profile_id)
        if profile_path is None:
            raise ValueError(f"Invalid profile id: {profile.profile_id!r}")
        # Serialize before opening the file so a serialization error cannot
        # leave a truncated profile behind.
        payload = json.dumps(profile.to_dict(), indent=2)
        profile_path.write_text(payload, encoding="utf-8")

    def load_profile(self, profile_id: str) -> "Optional[WarmProfile]":
        """Load a profile by id; return ``None`` if missing, unsafe or unreadable."""
        profile_path = self._profile_path(profile_id)
        if profile_path is None or not profile_path.exists():
            return None

        try:
            with open(profile_path, 'r', encoding="utf-8") as f:
                data = json.load(f)
            return WarmProfile.from_dict(data)
        except Exception as e:
            logger.error("Failed to load profile: %s", e)
            return None

    def list_profiles(self) -> List[Dict[str, Any]]:
        """Return a summary dict for every readable profile, sorted by file name.

        Files that cannot be read or do not contain a JSON object are skipped
        with a warning instead of aborting the listing.
        """
        profiles: List[Dict[str, Any]] = []
        for profile_path in sorted(self.profile_dir.glob("*.json")):
            try:
                with open(profile_path, 'r', encoding="utf-8") as f:
                    data = json.load(f)
            except (OSError, ValueError) as e:
                logger.warning("Skipping unreadable profile %s: %s", profile_path, e)
                continue
            if not isinstance(data, dict):
                logger.warning("Skipping malformed profile %s", profile_path)
                continue
            profiles.append({
                "profile_id": data.get("profile_id"),
                "name": data.get("name"),
                "description": data.get("description"),
                "usage_count": data.get("usage_count", 0),
                "success_rate": data.get("success_rate", 0.0),
                "last_used": data.get("last_used"),
            })
        return profiles

    def delete_profile(self, profile_id: str) -> bool:
        """Delete a profile file; return ``True`` only if a file was removed."""
        profile_path = self._profile_path(profile_id)
        if profile_path is None:
            return False
        try:
            profile_path.unlink()
        except FileNotFoundError:
            return False
        return True

    # ------------------------------------------------------------------
    # Activation and context
    # ------------------------------------------------------------------

    def activate_profile(self, profile_id: str) -> bool:
        """Load *profile_id* and make it the active profile; ``False`` if it cannot be loaded."""
        profile = self.load_profile(profile_id)
        if not profile:
            return False

        self.active_profile = profile
        return True

    def deactivate_profile(self):
        """Clear the active profile."""
        self.active_profile = None

    def get_session_context(self, user_message: str = "") -> Dict[str, Any]:
        """Build the warm-start context from the active profile.

        Args:
            user_message: Echoed back under ``"user_message"``.

        Returns:
            ``{}`` when no profile is active. Otherwise a dict with
            ``system_extension`` (newline-joined text), ``example_messages``
            (user/assistant/tool triples built from stored patterns),
            ``user_message``, ``profile_id`` and ``profile_name``.
        """
        profile = self.active_profile
        if not profile:
            return {}

        context = profile.context

        system_parts: List[str] = []
        if context.system_prompt_extension:
            system_parts.append(context.system_prompt_extension)

        # Stored profiles may contain null entries; joining them would raise.
        files = [name for name in context.known_files if isinstance(name, str)]
        if files:
            system_parts.append(f"Known files: {', '.join(files[:self.MAX_PROMPT_FILES])}")

        tools = [name for name in context.known_tools if isinstance(name, str)]
        if tools:
            system_parts.append(f"Familiar tools: {', '.join(tools)}")

        if context.user_preferences:
            style = context.user_preferences.get("message_style", "balanced")
            system_parts.append(f"User prefers {style} responses.")

        # Patterns without a tool name cannot form a valid example tool call.
        usable_patterns = [p for p in context.successful_patterns if p.get("tool")]

        example_messages: List[Dict[str, Any]] = []
        for i, pattern in enumerate(usable_patterns[:self.MAX_EXAMPLES]):
            tool_name = pattern["tool"]
            call_id = f"example_{i}"
            example_messages.append({
                "role": "user",
                "content": f"[Example {i+1}] Use {tool_name}",
            })
            example_messages.append({
                "role": "assistant",
                "content": f"I'll use {tool_name}.",
                "tool_calls": [{
                    "id": call_id,
                    "type": "function",
                    "function": {
                        "name": tool_name,
                        "arguments": pattern.get("arguments", "{}"),
                    },
                }],
            })
            example_messages.append({
                "role": "tool",
                "tool_call_id": call_id,
                "content": "Success",
            })

        return {
            "system_extension": "\n".join(system_parts),
            "example_messages": example_messages,
            "user_message": user_message,
            "profile_id": profile.profile_id,
            "profile_name": profile.name,
        }

    # ------------------------------------------------------------------
    # Usage statistics
    # ------------------------------------------------------------------

    @staticmethod
    def _record_outcome(profile: "WarmProfile", success: bool) -> None:
        """Fold one session outcome into the profile's usage statistics.

        The use is counted first, so ``success_rate`` stays the mean of all
        recorded outcomes, including this one.
        """
        profile.usage_count += 1
        outcome = 1.0 if success else 0.0
        profile.success_rate = (
            (profile.success_rate * (profile.usage_count - 1) + outcome)
            / profile.usage_count
        )
        profile.last_used = datetime.now().isoformat()

    def update_profile_success(self, profile_id: str, success: bool):
        """Record one session outcome for *profile_id* and persist the result.

        Does nothing when the profile cannot be loaded. The usage count, the
        running success rate and the last-used timestamp are all updated
        before saving, so they are written to disk.
        """
        profile = self.load_profile(profile_id)
        if not profile:
            return

        self._record_outcome(profile, success)
        self.save_profile(profile)
|
|
|
|
|
|
class WarmSessionMiddleware:
    """Middleware that turns a provider's warm context into a message list."""

    def __init__(self, provider: "Optional[WarmSessionProvider]" = None):
        """Wrap *provider*, creating a default ``WarmSessionProvider`` if omitted."""
        self.provider = provider if provider is not None else WarmSessionProvider()

    def prepare_session(self, user_message: str, profile_id: Optional[str] = None) -> Dict[str, Any]:
        """Prepare the opening messages for a session.

        Args:
            user_message: The user's first message; always the last entry
                of the returned ``messages``.
            profile_id: Profile to activate first. When omitted, whatever
                profile is already active on the provider is used.

        Returns:
            ``{"warm": False, "messages": [...]}`` with only the user message
            when no warm context is available, otherwise a dict with
            ``warm=True``, ``profile_id``, ``profile_name`` and ``messages``
            (optional system message, example messages, then the user
            message).
        """
        cold_session = {
            "warm": False,
            "messages": [{"role": "user", "content": user_message}],
        }

        # If the requested profile cannot be activated, start cold rather
        # than silently reusing a different, previously activated profile.
        if profile_id and not self.provider.activate_profile(profile_id):
            return cold_session

        context = self.provider.get_session_context(user_message)
        if not context:
            return cold_session

        messages: List[Dict[str, Any]] = []
        if context.get("system_extension"):
            messages.append({
                "role": "system",
                "content": context["system_extension"],
            })
        messages.extend(context.get("example_messages", []))
        messages.append({"role": "user", "content": user_message})

        return {
            "warm": True,
            "profile_id": context.get("profile_id"),
            "profile_name": context.get("profile_name"),
            "messages": messages,
        }

    def record_result(self, profile_id: str, success: bool):
        """Forward a session outcome to the provider's success tracking."""
        self.provider.update_profile_success(profile_id, success)
|
|
|
|
|
|
# CLI Interface
def warm_provider_cli(args: List[str]) -> int:
    """Run the warm-session command line and return a process exit code.

    Sub-commands: ``extract``, ``list``, ``activate``, ``deactivate``,
    ``context`` and ``delete``. Returns 1 when no command is given or the
    command fails, otherwise 0.

    Each call builds its own ``WarmSessionProvider``, so ``activate`` and
    ``deactivate`` act on that provider instance only.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Warm session provider")
    commands = cli.add_subparsers(dest="command")

    extract_cmd = commands.add_parser("extract", help="Extract profile from session")
    extract_cmd.add_argument("session_id", help="Session ID")
    extract_cmd.add_argument("--name", "-n", required=True, help="Profile name")

    commands.add_parser("list", help="List profiles")

    activate_cmd = commands.add_parser("activate", help="Activate a profile")
    activate_cmd.add_argument("profile_id", help="Profile ID")

    commands.add_parser("deactivate", help="Deactivate current profile")

    context_cmd = commands.add_parser("context", help="Show current warm context")
    context_cmd.add_argument("--message", "-m", default="", help="User message")

    delete_cmd = commands.add_parser("delete", help="Delete a profile")
    delete_cmd.add_argument("profile_id", help="Profile ID")

    options = cli.parse_args(args)
    if not options.command:
        cli.print_help()
        return 1

    provider = WarmSessionProvider()

    def run_extract() -> int:
        # The session database is imported lazily; only this command needs it.
        try:
            from hermes_state import SessionDB
            session_db = SessionDB()
        except ImportError:
            print("Error: Cannot import SessionDB")
            return 1

        profile = provider.extract_profile(session_db, options.session_id, options.name)
        if not profile:
            print(f"Failed to extract profile from session {options.session_id}")
            return 1

        print(f"Extracted profile: {profile.profile_id}")
        print(f"Name: {profile.name}")
        print(f"Known tools: {len(profile.context.known_tools)}")
        print(f"Known files: {len(profile.context.known_files)}")
        print(f"Successful patterns: {len(profile.context.successful_patterns)}")
        return 0

    def run_list() -> int:
        summaries = provider.list_profiles()
        if not summaries:
            print("No profiles found.")
            return 0

        print("\n=== Warm Session Profiles ===\n")
        for p in summaries:
            print(f"ID: {p['profile_id']}")
            print(f" Name: {p['name']}")
            print(f" Description: {p['description']}")
            print(f" Usage: {p['usage_count']} times, {p['success_rate']:.0%} success")
            if p['last_used']:
                print(f" Last used: {p['last_used']}")
            print()
        return 0

    def run_activate() -> int:
        if not provider.activate_profile(options.profile_id):
            print(f"Profile {options.profile_id} not found")
            return 1
        print(f"Activated profile: {options.profile_id}")
        return 0

    def run_deactivate() -> int:
        provider.deactivate_profile()
        print("Deactivated current profile")
        return 0

    def run_context() -> int:
        context = provider.get_session_context(options.message)
        if not context:
            print("No active warm profile")
            return 0

        print(f"\n=== Warm Context: {context.get('profile_name')} ===\n")

        if context.get("system_extension"):
            print("System Extension:")
            print(context["system_extension"][:500])
            print()

        examples = context.get("example_messages", [])
        if examples:
            print(f"Example messages: {len(examples)}")
            # Examples are consumed as complete user/assistant/tool triples;
            # an incomplete trailing group is not printed.
            for number, start in enumerate(range(0, len(examples) - 2, 3), start=1):
                user_msg, assistant_msg, tool_msg = examples[start:start + 3]
                print(f" Example {number}:")
                print(f" User: {user_msg.get('content', '')}")
                print(f" Assistant: {assistant_msg.get('content', '')}")
                if assistant_msg.get("tool_calls"):
                    for tc in assistant_msg["tool_calls"]:
                        func = tc.get("function", {})
                        print(f" Tool: {func.get('name')}()")
                print(f" Result: {tool_msg.get('content', '')[:50]}...")

        if options.message:
            print(f"\nUser message: {options.message}")
        return 0

    def run_delete() -> int:
        if not provider.delete_profile(options.profile_id):
            print(f"Profile {options.profile_id} not found")
            return 1
        print(f"Deleted profile: {options.profile_id}")
        return 0

    handlers = {
        "extract": run_extract,
        "list": run_list,
        "activate": run_activate,
        "deactivate": run_deactivate,
        "context": run_context,
        "delete": run_delete,
    }
    handler = handlers.get(options.command)
    if handler is None:
        return 1
    return handler()
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: pass the command-line arguments (without the
    # program name) to the CLI and exit with the code it returns.
    import sys

    exit_code = warm_provider_cli(sys.argv[1:])
    sys.exit(exit_code)
|