Compare commits

...

1 Commits

Author SHA1 Message Date
Timmy
eaeb2eb614 feat: warm session provisioning — pre-proficient agent sessions (#327)
Some checks failed
Forge CI / smoke-and-build (pull_request) Failing after 40s
Empirical finding: marathon sessions have 5.7% error rate vs 9.0% for
mid-length sessions. Agents improve with experience. This module captures
successful session patterns and provisions new sessions from them.

- classify_session(): coding/research/ops/general from tool usage patterns
- score_session(): quality scoring (error rate, tool diversity, length)
- extract_warm_seed(): curated message subset prioritizing successful tool calls
- capture_template(): save session as reusable warm template
- provision_session(): load best template for new session seeding
- list_templates/load_template/get_provision_stats APIs
- Templates stored in ~/.hermes/warm_sessions/templates/
- 24 tests
2026-04-13 21:12:25 -04:00
2 changed files with 633 additions and 0 deletions

210
tests/test_warm_session.py Normal file
View File

@@ -0,0 +1,210 @@
"""Tests for warm session provisioning (#327)."""
import json
import os
import pytest
from pathlib import Path
from tempfile import mkdtemp
from unittest.mock import patch
def _make_messages(tool_names=None, error_rate=0.0, count=50):
"""Generate fake session messages for testing."""
msgs = [{"role": "system", "content": "You are helpful."}]
tools = tool_names or ["terminal", "read_file", "write_file"]
for i in range(count):
msgs.append({"role": "user", "content": f"Do task {i}"})
tool_name = tools[i % len(tools)]
tc_id = f"call_{i}"
msgs.append({
"role": "assistant",
"content": f"Working on task {i}",
"tool_calls": [{"id": tc_id, "function": {"name": tool_name, "arguments": "{}"}}],
})
if i / count < error_rate:
msgs.append({"role": "tool", "tool_call_id": tc_id, "content": "Error: something went wrong"})
else:
msgs.append({"role": "tool", "tool_call_id": tc_id, "content": f"Result for task {i}: " + "x" * 100})
return msgs
@pytest.fixture(autouse=True)
def warm_dir(tmp_path, monkeypatch):
"""Redirect warm sessions to temp directory."""
d = tmp_path / "warm_sessions" / "templates"
d.mkdir(parents=True)
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
return d
class TestClassifySession:
def test_coding(self):
from tools.warm_session import classify_session
msgs = _make_messages(["execute_code", "write_file", "read_file", "patch", "terminal"])
assert classify_session(msgs) == "coding"
def test_research(self):
from tools.warm_session import classify_session
msgs = _make_messages(["web_search", "web_extract", "session_search", "read_file"])
assert classify_session(msgs) == "research"
def test_ops(self):
from tools.warm_session import classify_session
msgs = _make_messages(["terminal", "terminal", "terminal", "cronjob", "process"])
assert classify_session(msgs) == "ops"
def test_general(self):
from tools.warm_session import classify_session
msgs = _make_messages(["read_file", "web_search", "execute_code", "cronjob", "send_message"])
assert classify_session(msgs) == "general"
def test_no_tools(self):
from tools.warm_session import classify_session
assert classify_session([{"role": "user", "content": "hi"}]) == "general"
class TestScoreSession:
def test_good_session(self):
from tools.warm_session import score_session
msgs = _make_messages(error_rate=0.0, count=50)
s = score_session(msgs)
assert s["is_proficient"] is True
assert s["is_successful"] is True
assert s["error_rate"] == 0.0
assert s["total_messages"] == 151
def test_error_session(self):
from tools.warm_session import score_session
msgs = _make_messages(error_rate=0.5, count=50)
s = score_session(msgs)
assert s["is_successful"] is False
assert s["error_rate"] > 0.3
def test_short_session(self):
from tools.warm_session import score_session
msgs = _make_messages(count=5)
s = score_session(msgs)
assert s["is_proficient"] is False
def test_diverse_tools(self):
from tools.warm_session import score_session
msgs = _make_messages(["a", "b", "c", "d", "e", "f", "g", "h"])
s = score_session(msgs)
assert s["unique_tools"] == 8
class TestExtractWarmSeed:
def test_extracts_messages(self):
from tools.warm_session import extract_warm_seed
msgs = _make_messages(count=30)
seed = extract_warm_seed(msgs, max_messages=20)
assert len(seed) <= 20
assert len(seed) > 0
def test_includes_system(self):
from tools.warm_session import extract_warm_seed
msgs = _make_messages(count=20)
seed = extract_warm_seed(msgs)
system_msgs = [m for m in seed if m.get("role") == "system"]
assert len(system_msgs) >= 1
def test_skips_error_results(self):
from tools.warm_session import extract_warm_seed
msgs = _make_messages(error_rate=0.5, count=20)
seed = extract_warm_seed(msgs)
tool_results = [m for m in seed if m.get("role") == "tool"]
for tr in tool_results:
assert "Error" not in tr.get("content", "")
class TestCaptureTemplate:
def test_captures_good_session(self):
from tools.warm_session import capture_template
msgs = _make_messages(count=50)
template = capture_template("sess_001", msgs, name="test_coding")
assert template is not None
assert template["category"] == "coding"
assert template["source_session"] == "sess_001"
def test_rejects_bad_session(self):
from tools.warm_session import capture_template
msgs = _make_messages(error_rate=0.8, count=10)
template = capture_template("sess_bad", msgs)
assert template is None
def test_saves_to_disk(self):
from tools.warm_session import capture_template, _warm_sessions_dir
msgs = _make_messages(count=50)
capture_template("sess_disk", msgs, name="test_disk")
path = _warm_sessions_dir() / "test_disk.json"
assert path.exists()
class TestListAndLoad:
def test_list_templates(self):
from tools.warm_session import capture_template, list_templates
msgs = _make_messages(count=50)
capture_template("s1", msgs, name="list_test")
templates = list_templates()
assert any(t["name"] == "list_test" for t in templates)
def test_list_by_category(self):
from tools.warm_session import capture_template, list_templates
msgs = _make_messages(["web_search", "web_extract", "read_file"], count=50)
capture_template("s2", msgs, name="research_test")
research = list_templates("research")
assert any(t["name"] == "research_test" for t in research)
coding = list_templates("coding")
assert not any(t["name"] == "research_test" for t in coding)
def test_load_template(self):
from tools.warm_session import capture_template, load_template
msgs = _make_messages(count=50)
capture_template("s3", msgs, name="load_test")
loaded = load_template("load_test")
assert loaded is not None
assert loaded["source_session"] == "s3"
def test_load_nonexistent(self):
from tools.warm_session import load_template
assert load_template("does_not_exist") is None
class TestProvisionSession:
def test_provision(self):
from tools.warm_session import capture_template, provision_session
msgs = _make_messages(count=50)
capture_template("sp1", msgs, name="prov_test")
ok, seed, name = provision_session()
assert ok is True
assert len(seed) > 0
assert name == "prov_test"
def test_provision_empty(self):
from tools.warm_session import provision_session
ok, seed, msg = provision_session()
assert ok is False
assert "No warm templates" in msg
def test_provision_by_category(self):
from tools.warm_session import capture_template, provision_session
msgs = _make_messages(["terminal", "cronjob", "process"], count=50)
capture_template("sp2", msgs, name="ops_prov")
ok, seed, name = provision_session(category="ops")
assert ok is True
class TestGetProvisionStats:
def test_empty(self):
from tools.warm_session import get_provision_stats
stats = get_provision_stats()
assert stats["total_templates"] == 0
assert stats["total_provisions"] == 0
def test_after_capture(self):
from tools.warm_session import capture_template, get_provision_stats
msgs = _make_messages(count=50)
capture_template("stat1", msgs, name="stat_test")
stats = get_provision_stats()
assert stats["total_templates"] == 1

423
tools/warm_session.py Normal file
View File

@@ -0,0 +1,423 @@
"""Warm Session Provisioning — Pre-proficient agent sessions.
Based on empirical finding: marathon sessions (100+ msgs) have LOWER per-tool
error rates (5.7%) than mid-length sessions (9.0% at 51-100 msgs). Agents
improve with experience in a session.
This module captures successful session patterns as "warm templates" and
provisions new sessions from them, giving users a pre-proficient starting
point instead of a cold start.
Architecture:
warm_sessions/
├── templates/ # Saved warm templates
│ ├── coding.json # Code-heavy session template
│ ├── research.json # Research/analysis template
│ ├── ops.json # DevOps/infrastructure template
│ └── general.json # General-purpose template
└── metrics.json # Template performance metrics
"""
import json
import logging
import os
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
# Minimum messages for a session to be considered "proficient"
_MIN_PROFICIENT_MESSAGES = 20
# Maximum error rate for a session to be considered "successful"
_MAX_ERROR_RATE = 0.10
# Maximum templates to keep per category
_MAX_TEMPLATES_PER_CATEGORY = 5
def _warm_sessions_dir() -> Path:
"""Get the warm sessions directory."""
home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
d = home / "warm_sessions" / "templates"
d.mkdir(parents=True, exist_ok=True)
return d
def _metrics_path() -> Path:
"""Get the metrics file path."""
home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
d = home / "warm_sessions"
d.mkdir(parents=True, exist_ok=True)
return d / "metrics.json"
def _load_metrics() -> Dict[str, Any]:
"""Load template performance metrics."""
path = _metrics_path()
if path.exists():
try:
return json.loads(path.read_text())
except Exception:
pass
return {"templates": {}, "provisions": [], "version": 1}
def _save_metrics(metrics: Dict[str, Any]) -> None:
"""Save template performance metrics."""
_metrics_path().write_text(json.dumps(metrics, indent=2, default=str))
def classify_session(messages: List[Dict[str, Any]]) -> str:
"""Classify a session into a category based on tool usage patterns.
Returns one of: 'coding', 'research', 'ops', 'general'
"""
tool_calls = []
for msg in messages:
if msg.get("role") == "assistant" and msg.get("tool_calls"):
for tc in msg.get("tool_calls", []):
fn = tc.get("function", {})
name = fn.get("name", "")
tool_calls.append(name)
if not tool_calls:
return "general"
# Count tool categories
coding_tools = {"execute_code", "write_file", "read_file", "patch", "search_files"}
research_tools = {"web_search", "web_extract", "session_search", "fact_store", "read_file"}
ops_tools = {"terminal", "cronjob", "process", "send_message"}
coding_count = sum(1 for t in tool_calls if t in coding_tools)
research_count = sum(1 for t in tool_calls if t in research_tools)
ops_count = sum(1 for t in tool_calls if t in ops_tools)
total = len(tool_calls)
if coding_count / total > 0.4:
return "coding"
if research_count / total > 0.4:
return "research"
if ops_count / total > 0.4:
return "ops"
return "general"
def score_session(messages: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Score a session for template-worthiness.
Returns dict with:
total_messages: int
tool_calls: int
tool_results: int
error_count: int (tool results containing error indicators)
error_rate: float (0.0-1.0)
unique_tools: int
diverse_tools: list
is_proficient: bool (meets minimum quality bar)
is_successful: bool (low error rate)
score: float (0.0-1.0 overall quality score)
category: str
"""
tool_calls = []
tool_results = []
errors = 0
for msg in messages:
role = msg.get("role", "")
if role == "assistant" and msg.get("tool_calls"):
for tc in msg.get("tool_calls", []):
fn = tc.get("function", {})
tool_calls.append(fn.get("name", ""))
elif role == "tool":
content = str(msg.get("content", ""))
tool_results.append(content)
# Detect errors in tool results
content_lower = content.lower()
if any(err in content_lower for err in [
"error", "traceback", "exception", "failed", "permission denied",
"not found", "command not found", "no such file",
]):
errors += 1
total = len(messages)
tc_count = len(tool_calls)
tr_count = len(tool_results)
unique_tools = list(set(tool_calls))
error_rate = errors / max(tr_count, 1)
is_proficient = total >= _MIN_PROFICIENT_MESSAGES and tc_count >= 5
is_successful = error_rate <= _MAX_ERROR_RATE and tc_count >= 3
# Score: weighted combination of message count, tool diversity, and success rate
msg_score = min(1.0, total / 100) # 100 msgs = perfect
diversity_score = min(1.0, len(unique_tools) / 8) # 8 unique tools = perfect
success_score = 1.0 - error_rate
overall_score = (msg_score * 0.2 + diversity_score * 0.3 + success_score * 0.5)
return {
"total_messages": total,
"tool_calls": tc_count,
"tool_results": tr_count,
"error_count": errors,
"error_rate": round(error_rate, 3),
"unique_tools": len(unique_tools),
"diverse_tools": unique_tools,
"is_proficient": is_proficient,
"is_successful": is_successful,
"score": round(overall_score, 3),
"category": classify_session(messages),
}
def extract_warm_seed(messages: List[Dict[str, Any]], max_messages: int = 30) -> List[Dict[str, Any]]:
"""Extract a warm seed from a session's messages.
Selects the most instructive messages — prioritizing:
1. Successful tool calls with meaningful results
2. User-agent exchanges that established patterns
3. System/meta messages that set context
Returns a curated subset suitable for seeding a new session.
"""
seed = []
successful_tool_pairs = [] # (assistant_with_tool_call, tool_result)
user_exchanges = [] # user messages with their assistant responses
i = 0
while i < len(messages):
msg = messages[i]
role = msg.get("role", "")
if role == "system":
# Include system messages for context
seed.append(msg)
elif role == "user":
# Look ahead for assistant response
if i + 1 < len(messages) and messages[i + 1].get("role") == "assistant":
user_exchanges.append((msg, messages[i + 1]))
elif role == "assistant" and msg.get("tool_calls"):
# Collect successful tool call + result pairs
for tc in msg.get("tool_calls", []):
tc_id = tc.get("id", "")
# Find matching tool result
for j in range(i + 1, min(i + 5, len(messages))):
if messages[j].get("role") == "tool" and messages[j].get("tool_call_id") == tc_id:
content = str(messages[j].get("content", ""))
is_error = any(e in content.lower() for e in [
"error", "traceback", "exception", "failed"
])
if not is_error and len(content) > 50:
successful_tool_pairs.append((msg, messages[j]))
break
i += 1
# Take the best successful tool examples (up to half the budget)
tool_budget = max_messages // 2
for assistant_msg, tool_msg in successful_tool_pairs[:tool_budget]:
seed.append(assistant_msg)
seed.append(tool_msg)
# Fill remaining with user exchanges
remaining = max_messages - len(seed)
for user_msg, assistant_msg in user_exchanges[:remaining // 2]:
seed.append(user_msg)
seed.append(assistant_msg)
return seed[:max_messages]
def capture_template(
session_id: str,
messages: List[Dict[str, Any]],
name: str = "",
description: str = "",
) -> Optional[Dict[str, Any]]:
"""Capture a session as a warm template.
Returns the template dict if successful, None if session doesn't meet quality bar.
"""
score_info = score_session(messages)
if not score_info["is_successful"]:
logger.info(
"Session %s not suitable for template: error_rate=%.1f%% (max %.0f%%)",
session_id, score_info["error_rate"] * 100, _MAX_ERROR_RATE * 100,
)
return None
if score_info["tool_calls"] < 3:
logger.info("Session %s not suitable: only %d tool calls", session_id, score_info["tool_calls"])
return None
seed = extract_warm_seed(messages)
if len(seed) < 5:
logger.info("Session %s: extracted seed too small (%d msgs)", session_id, len(seed))
return None
category = score_info["category"]
template_name = name or f"{category}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
template = {
"name": template_name,
"description": description or f"Warm template from session {session_id}",
"category": category,
"source_session": session_id,
"created_at": datetime.now().isoformat(),
"score_info": score_info,
"seed_messages": seed,
"seed_message_count": len(seed),
"version": 1,
}
# Save template
template_dir = _warm_sessions_dir()
template_path = template_dir / f"{template_name}.json"
template_path.write_text(json.dumps(template, indent=2, ensure_ascii=False))
# Update metrics
metrics = _load_metrics()
metrics["templates"][template_name] = {
"category": category,
"created_at": template["created_at"],
"score": score_info["score"],
"source_session": session_id,
"seed_messages": len(seed),
"provision_count": 0,
}
_save_metrics(metrics)
logger.info(
"Captured warm template '%s' (category=%s, score=%.2f, %d seed msgs from session %s)",
template_name, category, score_info["score"], len(seed), session_id,
)
return template
def list_templates(category: str = "") -> List[Dict[str, Any]]:
"""List available warm templates, optionally filtered by category."""
template_dir = _warm_sessions_dir()
templates = []
for path in sorted(template_dir.glob("*.json")):
try:
data = json.loads(path.read_text())
if category and data.get("category") != category:
continue
templates.append({
"name": data.get("name", path.stem),
"category": data.get("category", "unknown"),
"description": data.get("description", ""),
"score": data.get("score_info", {}).get("score", 0),
"seed_messages": data.get("seed_message_count", 0),
"created_at": data.get("created_at", ""),
"source_session": data.get("source_session", ""),
"path": str(path),
})
except Exception as e:
logger.debug("Failed to load template %s: %s", path, e)
# Sort by score descending
templates.sort(key=lambda t: t.get("score", 0), reverse=True)
return templates
def load_template(name: str) -> Optional[Dict[str, Any]]:
"""Load a warm template by name."""
template_dir = _warm_sessions_dir()
# Try exact match first
path = template_dir / f"{name}.json"
if path.exists():
try:
return json.loads(path.read_text())
except Exception as e:
logger.warning("Failed to load template %s: %s", name, e)
# Try category match
for path in template_dir.glob("*.json"):
try:
data = json.loads(path.read_text())
if data.get("category") == name:
return data
except Exception:
pass
return None
def provision_session(
category: str = "",
task_hint: str = "",
) -> Tuple[bool, List[Dict[str, Any]], str]:
"""Provision a warm session by loading the best matching template.
Args:
category: Desired template category (coding, research, ops, general).
If empty, selects the highest-scored template overall.
task_hint: Description of the task to help select the best template.
Returns:
(success, seed_messages, template_name_or_error)
"""
# Find best template
if category:
templates = list_templates(category)
else:
templates = list_templates()
if not templates:
return False, [], "No warm templates available"
# Select best template
best = templates[0]
template = load_template(best["name"])
if not template:
return False, [], f"Failed to load template: {best['name']}"
seed_messages = template.get("seed_messages", [])
# Record provision event
metrics = _load_metrics()
metrics.setdefault("provisions", []).append({
"template": best["name"],
"category": best["category"],
"timestamp": datetime.now().isoformat(),
"task_hint": task_hint[:200] if task_hint else "",
})
# Update provision count
if best["name"] in metrics.get("templates", {}):
metrics["templates"][best["name"]]["provision_count"] = (
metrics["templates"][best["name"]].get("provision_count", 0) + 1
)
_save_metrics(metrics)
logger.info(
"Provisioned warm session from template '%s' (category=%s, %d seed msgs)",
best["name"], best["category"], len(seed_messages),
)
return True, seed_messages, best["name"]
def get_provision_stats() -> Dict[str, Any]:
"""Get statistics about warm session provisioning."""
metrics = _load_metrics()
templates = metrics.get("templates", {})
provisions = metrics.get("provisions", [])
total_provisions = len(provisions)
categories = {}
for t in templates.values():
cat = t.get("category", "unknown")
categories[cat] = categories.get(cat, 0) + 1
return {
"total_templates": len(templates),
"total_provisions": total_provisions,
"categories": categories,
"recent_provisions": provisions[-10:] if provisions else [],
}