Compare commits
1 Commits
claude/iss
...
whip/327-1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
eaeb2eb614 |
210
tests/test_warm_session.py
Normal file
210
tests/test_warm_session.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Tests for warm session provisioning (#327)."""
|
||||
import json
|
||||
import os
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from tempfile import mkdtemp
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
def _make_messages(tool_names=None, error_rate=0.0, count=50):
|
||||
"""Generate fake session messages for testing."""
|
||||
msgs = [{"role": "system", "content": "You are helpful."}]
|
||||
tools = tool_names or ["terminal", "read_file", "write_file"]
|
||||
|
||||
for i in range(count):
|
||||
msgs.append({"role": "user", "content": f"Do task {i}"})
|
||||
tool_name = tools[i % len(tools)]
|
||||
tc_id = f"call_{i}"
|
||||
msgs.append({
|
||||
"role": "assistant",
|
||||
"content": f"Working on task {i}",
|
||||
"tool_calls": [{"id": tc_id, "function": {"name": tool_name, "arguments": "{}"}}],
|
||||
})
|
||||
if i / count < error_rate:
|
||||
msgs.append({"role": "tool", "tool_call_id": tc_id, "content": "Error: something went wrong"})
|
||||
else:
|
||||
msgs.append({"role": "tool", "tool_call_id": tc_id, "content": f"Result for task {i}: " + "x" * 100})
|
||||
|
||||
return msgs
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def warm_dir(tmp_path, monkeypatch):
    """Point HERMES_HOME at a temp dir so templates never touch the real home."""
    template_dir = tmp_path / "warm_sessions" / "templates"
    template_dir.mkdir(parents=True)
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    return template_dir
|
||||
|
||||
|
||||
class TestClassifySession:
    """classify_session buckets sessions by their dominant tool family."""

    @staticmethod
    def _label(tools):
        # Small helper: classify a synthetic session built from `tools`.
        from tools.warm_session import classify_session
        return classify_session(_make_messages(tools))

    def test_coding(self):
        assert self._label(["execute_code", "write_file", "read_file", "patch", "terminal"]) == "coding"

    def test_research(self):
        assert self._label(["web_search", "web_extract", "session_search", "read_file"]) == "research"

    def test_ops(self):
        assert self._label(["terminal", "terminal", "terminal", "cronjob", "process"]) == "ops"

    def test_general(self):
        assert self._label(["read_file", "web_search", "execute_code", "cronjob", "send_message"]) == "general"

    def test_no_tools(self):
        from tools.warm_session import classify_session
        assert classify_session([{"role": "user", "content": "hi"}]) == "general"
|
||||
|
||||
|
||||
class TestScoreSession:
    """score_session flags proficiency, success, and tool diversity."""

    def test_good_session(self):
        from tools.warm_session import score_session
        result = score_session(_make_messages(error_rate=0.0, count=50))
        assert result["is_proficient"] is True
        assert result["is_successful"] is True
        assert result["error_rate"] == 0.0
        assert result["total_messages"] == 151

    def test_error_session(self):
        from tools.warm_session import score_session
        result = score_session(_make_messages(error_rate=0.5, count=50))
        assert result["is_successful"] is False
        assert result["error_rate"] > 0.3

    def test_short_session(self):
        from tools.warm_session import score_session
        assert score_session(_make_messages(count=5))["is_proficient"] is False

    def test_diverse_tools(self):
        from tools.warm_session import score_session
        result = score_session(_make_messages(list("abcdefgh")))
        assert result["unique_tools"] == 8
|
||||
|
||||
|
||||
class TestExtractWarmSeed:
    """extract_warm_seed curates a bounded, error-free message subset."""

    def test_extracts_messages(self):
        from tools.warm_session import extract_warm_seed
        seed = extract_warm_seed(_make_messages(count=30), max_messages=20)
        assert 0 < len(seed) <= 20

    def test_includes_system(self):
        from tools.warm_session import extract_warm_seed
        seed = extract_warm_seed(_make_messages(count=20))
        assert sum(1 for m in seed if m.get("role") == "system") >= 1

    def test_skips_error_results(self):
        from tools.warm_session import extract_warm_seed
        seed = extract_warm_seed(_make_messages(error_rate=0.5, count=20))
        tool_results = (m for m in seed if m.get("role") == "tool")
        assert all("Error" not in m.get("content", "") for m in tool_results)
|
||||
|
||||
|
||||
class TestCaptureTemplate:
    """capture_template persists good sessions and rejects bad ones."""

    def test_captures_good_session(self):
        from tools.warm_session import capture_template
        template = capture_template("sess_001", _make_messages(count=50), name="test_coding")
        assert template is not None
        assert template["category"] == "coding"
        assert template["source_session"] == "sess_001"

    def test_rejects_bad_session(self):
        from tools.warm_session import capture_template
        assert capture_template("sess_bad", _make_messages(error_rate=0.8, count=10)) is None

    def test_saves_to_disk(self):
        from tools.warm_session import capture_template, _warm_sessions_dir
        capture_template("sess_disk", _make_messages(count=50), name="test_disk")
        assert (_warm_sessions_dir() / "test_disk.json").exists()
|
||||
|
||||
|
||||
class TestListAndLoad:
    """Round-trip tests for list_templates / load_template."""

    def test_list_templates(self):
        from tools.warm_session import capture_template, list_templates
        capture_template("s1", _make_messages(count=50), name="list_test")
        assert any(t["name"] == "list_test" for t in list_templates())

    def test_list_by_category(self):
        from tools.warm_session import capture_template, list_templates
        research_msgs = _make_messages(["web_search", "web_extract", "read_file"], count=50)
        capture_template("s2", research_msgs, name="research_test")
        assert any(t["name"] == "research_test" for t in list_templates("research"))
        assert all(t["name"] != "research_test" for t in list_templates("coding"))

    def test_load_template(self):
        from tools.warm_session import capture_template, load_template
        capture_template("s3", _make_messages(count=50), name="load_test")
        loaded = load_template("load_test")
        assert loaded is not None
        assert loaded["source_session"] == "s3"

    def test_load_nonexistent(self):
        from tools.warm_session import load_template
        assert load_template("does_not_exist") is None
|
||||
|
||||
|
||||
class TestProvisionSession:
    """provision_session selects a template and returns its seed messages."""

    def test_provision(self):
        from tools.warm_session import capture_template, provision_session
        capture_template("sp1", _make_messages(count=50), name="prov_test")
        ok, seed, name = provision_session()
        assert ok is True
        assert len(seed) > 0
        assert name == "prov_test"

    def test_provision_empty(self):
        from tools.warm_session import provision_session
        ok, _seed, msg = provision_session()
        assert ok is False
        assert "No warm templates" in msg

    def test_provision_by_category(self):
        from tools.warm_session import capture_template, provision_session
        ops_msgs = _make_messages(["terminal", "cronjob", "process"], count=50)
        capture_template("sp2", ops_msgs, name="ops_prov")
        ok, _seed, _name = provision_session(category="ops")
        assert ok is True
|
||||
|
||||
|
||||
class TestGetProvisionStats:
    """get_provision_stats reflects the metrics file contents."""

    def test_empty(self):
        from tools.warm_session import get_provision_stats
        stats = get_provision_stats()
        assert stats["total_templates"] == 0
        assert stats["total_provisions"] == 0

    def test_after_capture(self):
        from tools.warm_session import capture_template, get_provision_stats
        capture_template("stat1", _make_messages(count=50), name="stat_test")
        assert get_provision_stats()["total_templates"] == 1
|
||||
423
tools/warm_session.py
Normal file
423
tools/warm_session.py
Normal file
@@ -0,0 +1,423 @@
|
||||
"""Warm Session Provisioning — Pre-proficient agent sessions.
|
||||
|
||||
Based on empirical finding: marathon sessions (100+ msgs) have LOWER per-tool
|
||||
error rates (5.7%) than mid-length sessions (9.0% at 51-100 msgs). Agents
|
||||
improve with experience in a session.
|
||||
|
||||
This module captures successful session patterns as "warm templates" and
|
||||
provisions new sessions from them, giving users a pre-proficient starting
|
||||
point instead of a cold start.
|
||||
|
||||
Architecture:
|
||||
warm_sessions/
|
||||
├── templates/ # Saved warm templates
|
||||
│ ├── coding.json # Code-heavy session template
|
||||
│ ├── research.json # Research/analysis template
|
||||
│ ├── ops.json # DevOps/infrastructure template
|
||||
│ └── general.json # General-purpose template
|
||||
└── metrics.json # Template performance metrics
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)

# Quality-bar tuning knobs (used by score_session / capture_template).
# Minimum messages for a session to be considered "proficient"
_MIN_PROFICIENT_MESSAGES = 20
# Maximum error rate for a session to be considered "successful"
_MAX_ERROR_RATE = 0.10
# Maximum templates to keep per category
# NOTE(review): not referenced anywhere in the code visible here — per-category
# pruning appears unimplemented; confirm before relying on this cap.
_MAX_TEMPLATES_PER_CATEGORY = 5
|
||||
|
||||
|
||||
def _warm_sessions_dir() -> Path:
|
||||
"""Get the warm sessions directory."""
|
||||
home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
||||
d = home / "warm_sessions" / "templates"
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
return d
|
||||
|
||||
|
||||
def _metrics_path() -> Path:
|
||||
"""Get the metrics file path."""
|
||||
home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
||||
d = home / "warm_sessions"
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
return d / "metrics.json"
|
||||
|
||||
|
||||
def _load_metrics() -> Dict[str, Any]:
    """Read template performance metrics; fall back to a fresh structure.

    Any read/parse failure is treated the same as a missing file.
    """
    path = _metrics_path()
    try:
        if path.exists():
            return json.loads(path.read_text())
    except Exception:
        pass  # corrupt or unreadable metrics -> start fresh
    return {"templates": {}, "provisions": [], "version": 1}
|
||||
|
||||
|
||||
def _save_metrics(metrics: Dict[str, Any]) -> None:
    """Persist metrics as indented JSON (non-serializable values via str())."""
    payload = json.dumps(metrics, indent=2, default=str)
    _metrics_path().write_text(payload)
|
||||
|
||||
|
||||
def classify_session(messages: List[Dict[str, Any]]) -> str:
    """Classify a session into a category based on tool usage patterns.

    The label is the first category whose tools account for more than 40% of
    all tool calls, checked in the order coding -> research -> ops.

    Returns one of: 'coding', 'research', 'ops', 'general'
    """
    names = [
        tc.get("function", {}).get("name", "")
        for msg in messages
        if msg.get("role") == "assistant" and msg.get("tool_calls")
        for tc in msg.get("tool_calls", [])
    ]
    if not names:
        return "general"

    # Category membership sets; 'read_file' deliberately appears in both
    # coding and research (matching the original classification).
    categories = [
        ("coding", {"execute_code", "write_file", "read_file", "patch", "search_files"}),
        ("research", {"web_search", "web_extract", "session_search", "fact_store", "read_file"}),
        ("ops", {"terminal", "cronjob", "process", "send_message"}),
    ]

    total = len(names)
    for label, toolset in categories:
        if sum(name in toolset for name in names) / total > 0.4:
            return label
    return "general"
|
||||
|
||||
|
||||
def score_session(messages: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Score a session for template-worthiness.

    Returns dict with:
        total_messages: int
        tool_calls: int
        tool_results: int
        error_count: int (tool results containing error indicators)
        error_rate: float (0.0-1.0)
        unique_tools: int
        diverse_tools: list
        is_proficient: bool (meets minimum quality bar)
        is_successful: bool (low error rate)
        score: float (0.0-1.0 overall quality score)
        category: str
    """
    # Substrings that mark a tool result as an error (matched lowercase).
    error_markers = (
        "error", "traceback", "exception", "failed", "permission denied",
        "not found", "command not found", "no such file",
    )

    call_names: List[str] = []
    result_count = 0
    error_count = 0

    for msg in messages:
        role = msg.get("role", "")
        if role == "assistant" and msg.get("tool_calls"):
            call_names.extend(
                tc.get("function", {}).get("name", "")
                for tc in msg.get("tool_calls", [])
            )
        elif role == "tool":
            text = str(msg.get("content", "")).lower()
            result_count += 1
            if any(marker in text for marker in error_markers):
                error_count += 1

    total = len(messages)
    distinct = list(set(call_names))
    error_rate = error_count / max(result_count, 1)

    # Weighted blend: volume 20% (100 msgs = perfect), diversity 30%
    # (8 unique tools = perfect), success 50%.
    quality = (
        min(1.0, total / 100) * 0.2
        + min(1.0, len(distinct) / 8) * 0.3
        + (1.0 - error_rate) * 0.5
    )

    return {
        "total_messages": total,
        "tool_calls": len(call_names),
        "tool_results": result_count,
        "error_count": error_count,
        "error_rate": round(error_rate, 3),
        "unique_tools": len(distinct),
        "diverse_tools": distinct,
        "is_proficient": total >= _MIN_PROFICIENT_MESSAGES and len(call_names) >= 5,
        "is_successful": error_rate <= _MAX_ERROR_RATE and len(call_names) >= 3,
        "score": round(quality, 3),
        "category": classify_session(messages),
    }
|
||||
|
||||
|
||||
def extract_warm_seed(messages: List[Dict[str, Any]], max_messages: int = 30) -> List[Dict[str, Any]]:
    """Extract a warm seed from a session's messages.

    Selects the most instructive messages — prioritizing:
        1. System/meta messages that set context
        2. Successful tool calls with meaningful results
        3. User-agent exchanges that established patterns

    Pairs are always kept whole: the returned seed never contains an
    assistant message whose tool call lacks its matching tool result.
    (The previous version could both overfill the seed when the budget
    went negative — a negative slice bound on the exchange list — and
    split an assistant/tool pair at the final truncation, leaving a
    dangling tool_call.)

    Args:
        messages: Full session transcript.
        max_messages: Upper bound on the number of seed messages returned.

    Returns a curated subset suitable for seeding a new session.
    """
    error_markers = ("error", "traceback", "exception", "failed")

    seed: List[Dict[str, Any]] = []
    tool_pairs = []  # (assistant msg with tool_calls, matching tool result)
    exchanges = []   # (user msg, immediate assistant reply)

    for i, msg in enumerate(messages):
        role = msg.get("role", "")
        if role == "system":
            # Always carry context-setting messages.
            seed.append(msg)
        elif role == "user":
            if i + 1 < len(messages) and messages[i + 1].get("role") == "assistant":
                exchanges.append((msg, messages[i + 1]))
        elif role == "assistant" and msg.get("tool_calls"):
            for tc in msg.get("tool_calls", []):
                tc_id = tc.get("id", "")
                # Search a small window ahead for the matching tool result.
                for j in range(i + 1, min(i + 5, len(messages))):
                    candidate = messages[j]
                    if candidate.get("role") == "tool" and candidate.get("tool_call_id") == tc_id:
                        content = str(candidate.get("content", ""))
                        is_error = any(m in content.lower() for m in error_markers)
                        if not is_error and len(content) > 50:
                            tool_pairs.append((msg, candidate))
                        break

    # Successful tool examples take up to half the budget, added in whole
    # pairs; the guard keeps the seed within max_messages so the final slice
    # never has to cut a pair in half.
    for pair in tool_pairs[: max_messages // 2]:
        if len(seed) + 2 > max_messages:
            break
        seed.extend(pair)

    # Fill any remaining budget with user/assistant exchanges (whole pairs;
    # clamped so a negative remainder can no longer over-fill the seed).
    remaining = max(0, max_messages - len(seed))
    for pair in exchanges[: remaining // 2]:
        seed.extend(pair)

    return seed[:max_messages]
|
||||
|
||||
|
||||
def capture_template(
    session_id: str,
    messages: List[Dict[str, Any]],
    name: str = "",
    description: str = "",
) -> Optional[Dict[str, Any]]:
    """Capture a session as a warm template.

    The session must clear three quality gates: a low error rate, at least
    three tool calls, and an extracted seed of at least five messages.

    Returns the template dict if successful, None if the session doesn't
    meet the quality bar.
    """
    score_info = score_session(messages)

    # Gate 1: overall success (error rate + minimum activity).
    if not score_info["is_successful"]:
        logger.info(
            "Session %s not suitable for template: error_rate=%.1f%% (max %.0f%%)",
            session_id, score_info["error_rate"] * 100, _MAX_ERROR_RATE * 100,
        )
        return None

    # Gate 2: enough tool activity to be instructive.
    if score_info["tool_calls"] < 3:
        logger.info("Session %s not suitable: only %d tool calls", session_id, score_info["tool_calls"])
        return None

    # Gate 3: the curated seed must have substance.
    seed = extract_warm_seed(messages)
    if len(seed) < 5:
        logger.info("Session %s: extracted seed too small (%d msgs)", session_id, len(seed))
        return None

    category = score_info["category"]
    template_name = name or f"{category}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    template = {
        "name": template_name,
        "description": description or f"Warm template from session {session_id}",
        "category": category,
        "source_session": session_id,
        "created_at": datetime.now().isoformat(),
        "score_info": score_info,
        "seed_messages": seed,
        "seed_message_count": len(seed),
        "version": 1,
    }

    # Persist the template itself.
    out_path = _warm_sessions_dir() / f"{template_name}.json"
    out_path.write_text(json.dumps(template, indent=2, ensure_ascii=False))

    # Register it in the metrics file.
    # NOTE(review): _MAX_TEMPLATES_PER_CATEGORY is never enforced here —
    # confirm whether per-category pruning was intended.
    metrics = _load_metrics()
    metrics["templates"][template_name] = {
        "category": category,
        "created_at": template["created_at"],
        "score": score_info["score"],
        "source_session": session_id,
        "seed_messages": len(seed),
        "provision_count": 0,
    }
    _save_metrics(metrics)

    logger.info(
        "Captured warm template '%s' (category=%s, score=%.2f, %d seed msgs from session %s)",
        template_name, category, score_info["score"], len(seed), session_id,
    )

    return template
|
||||
|
||||
|
||||
def list_templates(category: str = "") -> List[Dict[str, Any]]:
    """List available warm templates, optionally filtered by category.

    Returns summary dicts sorted by score, best first. Template files that
    cannot be read or parsed are skipped (logged at debug level).
    """
    summaries: List[Dict[str, Any]] = []

    for template_path in sorted(_warm_sessions_dir().glob("*.json")):
        try:
            data = json.loads(template_path.read_text())
        except Exception as e:
            logger.debug("Failed to load template %s: %s", template_path, e)
            continue
        if category and data.get("category") != category:
            continue
        summaries.append({
            "name": data.get("name", template_path.stem),
            "category": data.get("category", "unknown"),
            "description": data.get("description", ""),
            "score": data.get("score_info", {}).get("score", 0),
            "seed_messages": data.get("seed_message_count", 0),
            "created_at": data.get("created_at", ""),
            "source_session": data.get("source_session", ""),
            "path": str(template_path),
        })

    summaries.sort(key=lambda t: t.get("score", 0), reverse=True)
    return summaries
|
||||
|
||||
|
||||
def load_template(name: str) -> Optional[Dict[str, Any]]:
    """Load a warm template by exact name, falling back to a category match."""
    template_dir = _warm_sessions_dir()

    # Exact filename match takes priority.
    exact = template_dir / f"{name}.json"
    if exact.exists():
        try:
            return json.loads(exact.read_text())
        except Exception as e:
            logger.warning("Failed to load template %s: %s", name, e)

    # Otherwise treat `name` as a category and return the first parseable match.
    for candidate in template_dir.glob("*.json"):
        try:
            data = json.loads(candidate.read_text())
        except Exception:
            continue
        if data.get("category") == name:
            return data

    return None
|
||||
|
||||
|
||||
def provision_session(
    category: str = "",
    task_hint: str = "",
) -> Tuple[bool, List[Dict[str, Any]], str]:
    """Provision a warm session by loading the best matching template.

    Args:
        category: Desired template category (coding, research, ops, general).
            If empty, selects the highest-scored template overall.
        task_hint: Free-text task description; recorded in provision metrics
            (it does not currently influence template selection).

    Returns:
        (success, seed_messages, template_name_or_error)
    """
    candidates = list_templates(category) if category else list_templates()
    if not candidates:
        return False, [], "No warm templates available"

    # list_templates sorts by score descending, so the best is first.
    best = candidates[0]
    template = load_template(best["name"])
    if not template:
        return False, [], f"Failed to load template: {best['name']}"

    seed_messages = template.get("seed_messages", [])

    # Record the provision event and bump the per-template counter.
    metrics = _load_metrics()
    metrics.setdefault("provisions", []).append({
        "template": best["name"],
        "category": best["category"],
        "timestamp": datetime.now().isoformat(),
        "task_hint": task_hint[:200] if task_hint else "",
    })
    entry = metrics.get("templates", {}).get(best["name"])
    if entry is not None:
        entry["provision_count"] = entry.get("provision_count", 0) + 1
    _save_metrics(metrics)

    logger.info(
        "Provisioned warm session from template '%s' (category=%s, %d seed msgs)",
        best["name"], best["category"], len(seed_messages),
    )

    return True, seed_messages, best["name"]
|
||||
|
||||
|
||||
def get_provision_stats() -> Dict[str, Any]:
    """Summarize warm-session provisioning activity from the metrics file."""
    metrics = _load_metrics()
    templates = metrics.get("templates", {})
    provisions = metrics.get("provisions", [])

    # Per-category template counts.
    by_category: Dict[str, int] = {}
    for info in templates.values():
        label = info.get("category", "unknown")
        by_category[label] = by_category.get(label, 0) + 1

    return {
        "total_templates": len(templates),
        "total_provisions": len(provisions),
        "categories": by_category,
        "recent_provisions": provisions[-10:] if provisions else [],
    }
|
||||
Reference in New Issue
Block a user