hermes-agent/tools/session_ab_testing.py
Alexander Whitestone 67fa881227 feat(research): Add A/B testing framework for warm vs cold sessions
Addresses #327 research questions. Framework for comparing session performance with statistical analysis.
2026-04-14 01:44:00 +00:00

518 lines
18 KiB
Python

"""
Warm Session A/B Testing Framework
Framework for comparing warm vs cold session performance.
Addresses research questions from issue #327.
Issue: #327
"""
import json
import logging
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from dataclasses import dataclass, asdict, field
from enum import Enum
import statistics
logger = logging.getLogger(__name__)
class SessionType(Enum):
"""Type of session for A/B testing."""
COLD = "cold" # Fresh session, no warm-up
WARM = "warm" # Session with warm-up context
@dataclass
class TestTask:
"""A task for A/B testing."""
task_id: str
description: str
prompt: str
expected_tools: List[str] = field(default_factory=list)
success_criteria: Dict[str, Any] = field(default_factory=dict)
category: str = "general"
difficulty: str = "medium" # easy, medium, hard
@dataclass
class SessionResult:
"""Result from a session test."""
session_id: str
session_type: SessionType
task_id: str
start_time: str
end_time: Optional[str] = None
message_count: int = 0
tool_calls: int = 0
successful_tool_calls: int = 0
errors: List[str] = field(default_factory=list)
completion_time_seconds: float = 0.0
user_corrections: int = 0
success: bool = False
notes: str = ""
@property
def error_rate(self) -> float:
"""Calculate error rate."""
if self.tool_calls == 0:
return 0.0
return (self.tool_calls - self.successful_tool_calls) / self.tool_calls
@property
def success_rate(self) -> float:
"""Calculate success rate."""
if self.tool_calls == 0:
return 0.0
return self.successful_tool_calls / self.tool_calls
def to_dict(self) -> Dict[str, Any]:
return {
"session_id": self.session_id,
"session_type": self.session_type.value,
"task_id": self.task_id,
"start_time": self.start_time,
"end_time": self.end_time,
"message_count": self.message_count,
"tool_calls": self.tool_calls,
"successful_tool_calls": self.successful_tool_calls,
"errors": self.errors,
"completion_time_seconds": self.completion_time_seconds,
"user_corrections": self.user_corrections,
"success": self.success,
"error_rate": self.error_rate,
"success_rate": self.success_rate,
"notes": self.notes
}
@dataclass
class ABTestResult:
"""Results from an A/B test."""
test_id: str
task: TestTask
cold_results: List[SessionResult] = field(default_factory=list)
warm_results: List[SessionResult] = field(default_factory=list)
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
def add_result(self, result: SessionResult):
"""Add a session result."""
if result.session_type == SessionType.COLD:
self.cold_results.append(result)
else:
self.warm_results.append(result)
def get_summary(self) -> Dict[str, Any]:
"""Get summary statistics."""
def calc_stats(results: List[SessionResult]) -> Dict[str, Any]:
if not results:
return {"count": 0}
error_rates = [r.error_rate for r in results]
success_rates = [r.success_rate for r in results]
completion_times = [r.completion_time_seconds for r in results if r.completion_time_seconds > 0]
message_counts = [r.message_count for r in results]
return {
"count": len(results),
"avg_error_rate": statistics.mean(error_rates) if error_rates else 0,
"avg_success_rate": statistics.mean(success_rates) if success_rates else 0,
"avg_completion_time": statistics.mean(completion_times) if completion_times else 0,
"avg_messages": statistics.mean(message_counts) if message_counts else 0,
"success_count": sum(1 for r in results if r.success)
}
cold_stats = calc_stats(self.cold_results)
warm_stats = calc_stats(self.warm_results)
# Calculate improvement
improvement = {}
if cold_stats.get("count", 0) > 0 and warm_stats.get("count", 0) > 0:
cold_error = cold_stats.get("avg_error_rate", 0)
warm_error = warm_stats.get("avg_error_rate", 0)
if cold_error > 0:
improvement["error_rate"] = (cold_error - warm_error) / cold_error
cold_success = cold_stats.get("avg_success_rate", 0)
warm_success = warm_stats.get("avg_success_rate", 0)
if cold_success > 0:
improvement["success_rate"] = (warm_success - cold_success) / cold_success
return {
"task_id": self.task.task_id,
"cold": cold_stats,
"warm": warm_stats,
"improvement": improvement,
"recommendation": self._get_recommendation(cold_stats, warm_stats)
}
def _get_recommendation(self, cold_stats: Dict, warm_stats: Dict) -> str:
"""Generate recommendation based on results."""
if cold_stats.get("count", 0) < 3 or warm_stats.get("count", 0) < 3:
return "Insufficient data (need at least 3 tests each)"
cold_error = cold_stats.get("avg_error_rate", 0)
warm_error = warm_stats.get("avg_error_rate", 0)
if warm_error < cold_error * 0.8: # 20% improvement
return "WARM recommended: Significant error reduction"
elif warm_error > cold_error * 1.2: # 20% worse
return "COLD recommended: Warm sessions performed worse"
else:
return "No significant difference detected"
def to_dict(self) -> Dict[str, Any]:
return {
"test_id": self.test_id,
"task": asdict(self.task),
"cold_results": [r.to_dict() for r in self.cold_results],
"warm_results": [r.to_dict() for r in self.warm_results],
"created_at": self.created_at,
"summary": self.get_summary()
}
class ABTestManager:
"""Manage A/B tests."""
    def __init__(self, test_dir: Optional[Path] = None):
self.test_dir = test_dir or Path.home() / ".hermes" / "ab_tests"
self.test_dir.mkdir(parents=True, exist_ok=True)
def create_test(self, task: TestTask) -> ABTestResult:
"""Create a new A/B test."""
test_id = f"test_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{task.task_id}"
result = ABTestResult(
test_id=test_id,
task=task
)
self.save_test(result)
return result
def save_test(self, test: ABTestResult):
"""Save test results."""
path = self.test_dir / f"{test.test_id}.json"
with open(path, 'w') as f:
json.dump(test.to_dict(), f, indent=2)
def load_test(self, test_id: str) -> Optional[ABTestResult]:
"""Load test results."""
path = self.test_dir / f"{test_id}.json"
if not path.exists():
return None
try:
with open(path, 'r') as f:
data = json.load(f)
task = TestTask(**data["task"])
test = ABTestResult(
test_id=data["test_id"],
task=task,
created_at=data.get("created_at", "")
)
for r in data.get("cold_results", []):
r["session_type"] = SessionType(r["session_type"])
test.cold_results.append(SessionResult(**r))
for r in data.get("warm_results", []):
r["session_type"] = SessionType(r["session_type"])
test.warm_results.append(SessionResult(**r))
return test
except Exception as e:
logger.error(f"Failed to load test: {e}")
return None
def list_tests(self) -> List[Dict[str, Any]]:
"""List all tests."""
tests = []
for path in self.test_dir.glob("*.json"):
try:
with open(path, 'r') as f:
data = json.load(f)
tests.append({
"test_id": data.get("test_id"),
"task_id": data.get("task", {}).get("task_id"),
"description": data.get("task", {}).get("description", ""),
"cold_count": len(data.get("cold_results", [])),
"warm_count": len(data.get("warm_results", [])),
"created_at": data.get("created_at")
})
            except (OSError, json.JSONDecodeError) as e:
                logger.warning(f"Skipping unreadable test file {path}: {e}")
return tests
def delete_test(self, test_id: str) -> bool:
"""Delete a test."""
path = self.test_dir / f"{test_id}.json"
if path.exists():
path.unlink()
return True
return False
class ABTestRunner:
"""Run A/B tests."""
    def __init__(self, manager: Optional[ABTestManager] = None):
self.manager = manager or ABTestManager()
def run_comparison(
self,
task: TestTask,
cold_messages: List[Dict],
warm_messages: List[Dict],
session_db=None
) -> Tuple[SessionResult, SessionResult]:
"""
Run a comparison between cold and warm sessions.
Returns:
Tuple of (cold_result, warm_result)
"""
        # This is a framework stub; actual execution depends on
        # integration with the agent system.
cold_result = SessionResult(
session_id=f"cold_{task.task_id}_{int(time.time())}",
session_type=SessionType.COLD,
task_id=task.task_id,
start_time=datetime.now().isoformat()
)
warm_result = SessionResult(
session_id=f"warm_{task.task_id}_{int(time.time())}",
session_type=SessionType.WARM,
task_id=task.task_id,
start_time=datetime.now().isoformat()
)
# In a real implementation, this would:
# 1. Start a cold session with cold_messages
# 2. Execute the task and collect metrics
# 3. Start a warm session with warm_messages
# 4. Execute the same task and collect metrics
# 5. Return both results
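        # A hypothetical integration (AgentSession and its attributes below are
        # placeholders, not an existing API) might populate the metrics like:
        #   agent = AgentSession(context=cold_messages, db=session_db)
        #   reply = agent.run(task.prompt)
        #   cold_result.message_count = reply.message_count
        #   cold_result.tool_calls = reply.tool_call_count
        #   cold_result.successful_tool_calls = reply.successful_tool_calls
        #   cold_result.end_time = datetime.now().isoformat()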
return cold_result, warm_result
def analyze_results(self, test_id: str) -> Dict[str, Any]:
"""Analyze test results."""
test = self.manager.load_test(test_id)
if not test:
return {"error": "Test not found"}
summary = test.get_summary()
# Add statistical significance check
if (summary["cold"].get("count", 0) >= 3 and
summary["warm"].get("count", 0) >= 3):
            # Rough variability check (standard deviations only, not a formal t-test)
cold_errors = [r.error_rate for r in test.cold_results]
warm_errors = [r.error_rate for r in test.warm_results]
if len(cold_errors) >= 2 and len(warm_errors) >= 2:
cold_std = statistics.stdev(cold_errors) if len(cold_errors) > 1 else 0
warm_std = statistics.stdev(warm_errors) if len(warm_errors) > 1 else 0
summary["statistical_notes"] = {
"cold_std_dev": cold_std,
"warm_std_dev": warm_std,
"significance": "low" if max(cold_std, warm_std) > 0.2 else "medium"
}
return summary
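# Optional helper (not called by analyze_results): a minimal Welch's t-statistic
# sketch for comparing per-session error rates, if a slightly more formal check
# than the std-dev heuristic above is wanted. Interpreting the statistic still
# requires a t-distribution table or scipy; this only computes the statistic.
def welch_t_statistic(sample_a: List[float], sample_b: List[float]) -> Optional[float]:
    """Return Welch's t-statistic for two independent samples, or None if undefined."""
    if len(sample_a) < 2 or len(sample_b) < 2:
        return None
    mean_a, mean_b = statistics.mean(sample_a), statistics.mean(sample_b)
    var_a, var_b = statistics.variance(sample_a), statistics.variance(sample_b)
    denom = (var_a / len(sample_a) + var_b / len(sample_b)) ** 0.5
    if denom == 0:
        return None
    return (mean_a - mean_b) / denom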
# CLI Interface
def ab_test_cli(args: List[str]) -> int:
"""CLI interface for A/B testing."""
import argparse
parser = argparse.ArgumentParser(description="Warm session A/B testing")
subparsers = parser.add_subparsers(dest="command")
# Create test
create_parser = subparsers.add_parser("create", help="Create a new test")
create_parser.add_argument("--task-id", required=True, help="Task ID")
create_parser.add_argument("--description", required=True, help="Task description")
create_parser.add_argument("--prompt", required=True, help="Test prompt")
create_parser.add_argument("--category", default="general", help="Task category")
create_parser.add_argument("--difficulty", default="medium", choices=["easy", "medium", "hard"])
# List tests
subparsers.add_parser("list", help="List all tests")
# Show test results
show_parser = subparsers.add_parser("show", help="Show test results")
show_parser.add_argument("test_id", help="Test ID")
# Analyze test
analyze_parser = subparsers.add_parser("analyze", help="Analyze test results")
analyze_parser.add_argument("test_id", help="Test ID")
# Delete test
delete_parser = subparsers.add_parser("delete", help="Delete a test")
delete_parser.add_argument("test_id", help="Test ID")
# Add result
add_parser = subparsers.add_parser("add-result", help="Add a test result")
add_parser.add_argument("test_id", help="Test ID")
add_parser.add_argument("--session-type", required=True, choices=["cold", "warm"])
add_parser.add_argument("--session-id", required=True, help="Session ID")
add_parser.add_argument("--tool-calls", type=int, default=0)
add_parser.add_argument("--successful-calls", type=int, default=0)
add_parser.add_argument("--completion-time", type=float, default=0.0)
add_parser.add_argument("--success", action="store_true")
add_parser.add_argument("--notes", default="")
parsed = parser.parse_args(args)
if not parsed.command:
parser.print_help()
return 1
manager = ABTestManager()
runner = ABTestRunner(manager)
if parsed.command == "create":
task = TestTask(
task_id=parsed.task_id,
description=parsed.description,
prompt=parsed.prompt,
category=parsed.category,
difficulty=parsed.difficulty
)
test = manager.create_test(task)
print(f"Created test: {test.test_id}")
print(f"Task: {task.description}")
return 0
elif parsed.command == "list":
tests = manager.list_tests()
if not tests:
print("No tests found.")
return 0
print("\n=== A/B Tests ===\n")
for t in tests:
print(f"ID: {t['test_id']}")
print(f" Task: {t['description']}")
print(f" Cold tests: {t['cold_count']}, Warm tests: {t['warm_count']}")
print(f" Created: {t['created_at']}")
print()
return 0
elif parsed.command == "show":
test = manager.load_test(parsed.test_id)
if not test:
print(f"Test {parsed.test_id} not found")
return 1
print(f"\n=== Test: {test.test_id} ===\n")
print(f"Task: {test.task.description}")
print(f"Prompt: {test.task.prompt}")
print(f"Category: {test.task.category}, Difficulty: {test.task.difficulty}")
print(f"\nCold sessions: {len(test.cold_results)}")
for r in test.cold_results:
print(f" {r.session_id}: {r.success_rate:.0%} success, {r.error_rate:.0%} errors")
print(f"\nWarm sessions: {len(test.warm_results)}")
for r in test.warm_results:
print(f" {r.session_id}: {r.success_rate:.0%} success, {r.error_rate:.0%} errors")
return 0
elif parsed.command == "analyze":
analysis = runner.analyze_results(parsed.test_id)
if "error" in analysis:
print(f"Error: {analysis['error']}")
return 1
print(f"\n=== Analysis: {parsed.test_id} ===\n")
cold = analysis.get("cold", {})
warm = analysis.get("warm", {})
print("Cold Sessions:")
print(f" Count: {cold.get('count', 0)}")
print(f" Avg error rate: {cold.get('avg_error_rate', 0):.1%}")
print(f" Avg success rate: {cold.get('avg_success_rate', 0):.1%}")
print(f" Avg completion time: {cold.get('avg_completion_time', 0):.1f}s")
print("\nWarm Sessions:")
print(f" Count: {warm.get('count', 0)}")
print(f" Avg error rate: {warm.get('avg_error_rate', 0):.1%}")
print(f" Avg success rate: {warm.get('avg_success_rate', 0):.1%}")
print(f" Avg completion time: {warm.get('avg_completion_time', 0):.1f}s")
improvement = analysis.get("improvement", {})
if improvement:
print("\nImprovement:")
if "error_rate" in improvement:
print(f" Error rate: {improvement['error_rate']:+.1%}")
if "success_rate" in improvement:
print(f" Success rate: {improvement['success_rate']:+.1%}")
print(f"\nRecommendation: {analysis.get('recommendation', 'N/A')}")
return 0
elif parsed.command == "delete":
if manager.delete_test(parsed.test_id):
print(f"Deleted test: {parsed.test_id}")
return 0
else:
print(f"Test {parsed.test_id} not found")
return 1
elif parsed.command == "add-result":
test = manager.load_test(parsed.test_id)
if not test:
print(f"Test {parsed.test_id} not found")
return 1
result = SessionResult(
session_id=parsed.session_id,
session_type=SessionType(parsed.session_type),
task_id=test.task.task_id,
start_time=datetime.now().isoformat(),
end_time=datetime.now().isoformat(),
tool_calls=parsed.tool_calls,
successful_tool_calls=parsed.successful_calls,
completion_time_seconds=parsed.completion_time,
success=parsed.success,
notes=parsed.notes
)
test.add_result(result)
manager.save_test(test)
print(f"Added {parsed.session_type} result to test {parsed.test_id}")
print(f" Session: {parsed.session_id}")
print(f" Success rate: {result.success_rate:.0%}")
return 0
return 1
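# Illustrative programmatic usage (a sketch, not exercised by the CLI): shows how
# the classes above fit together when recording results by hand. The task text and
# metric values below are made up for demonstration.
def _example_programmatic_usage() -> None:
    manager = ABTestManager()
    task = TestTask(
        task_id="demo-task",
        description="Demo: summarize a log file",
        prompt="Summarize the attached log file and list the top 3 errors.",
    )
    test = manager.create_test(task)
    test.add_result(SessionResult(
        session_id="cold-demo-1",
        session_type=SessionType.COLD,
        task_id=task.task_id,
        start_time=datetime.now().isoformat(),
        tool_calls=5,
        successful_tool_calls=3,
        success=False,
    ))
    test.add_result(SessionResult(
        session_id="warm-demo-1",
        session_type=SessionType.WARM,
        task_id=task.task_id,
        start_time=datetime.now().isoformat(),
        tool_calls=5,
        successful_tool_calls=5,
        success=True,
    ))
    manager.save_test(test)
    print(json.dumps(test.get_summary(), indent=2))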
if __name__ == "__main__":
import sys
sys.exit(ab_test_cli(sys.argv[1:]))
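# Example CLI invocations (illustrative; IDs and values are placeholders):
#   python session_ab_testing.py create --task-id fix-bug --description "Fix a failing test" \
#       --prompt "Find and fix the failing unit test" --difficulty medium
#   python session_ab_testing.py add-result <test_id> --session-type warm \
#       --session-id warm-001 --tool-calls 6 --successful-calls 6 --completion-time 42.5 --success
#   python session_ab_testing.py analyze <test_id>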