diff --git a/tests/tools/test_session_templates.py b/tests/tools/test_session_templates.py new file mode 100644 index 000000000..16b98246b --- /dev/null +++ b/tests/tools/test_session_templates.py @@ -0,0 +1,188 @@ +"""Tests for session templates (code-first seeding).""" + +import json +import tempfile +from pathlib import Path +from unittest.mock import patch + +import pytest + +from tools.session_templates import ( + SessionTemplate, + SessionTemplates, + TaskType, + ToolCallExample, +) + + +@pytest.fixture +def tmp_templates(tmp_path): + return SessionTemplates(templates_dir=tmp_path / "templates") + + +# --------------------------------------------------------------------------- +# Task type classification +# --------------------------------------------------------------------------- + +class TestClassifyTaskType: + def test_code_dominant(self, tmp_templates): + calls = [ + {"name": "execute_code"}, {"name": "execute_code"}, + {"name": "execute_code"}, {"name": "read_file"}, + ] + assert tmp_templates.classify_task_type(calls) == TaskType.CODE + + def test_file_dominant(self, tmp_templates): + calls = [ + {"name": "read_file"}, {"name": "write_file"}, + {"name": "patch"}, {"name": "read_file"}, + {"name": "execute_code"}, + ] + assert tmp_templates.classify_task_type(calls) == TaskType.FILE + + def test_research_dominant(self, tmp_templates): + calls = [ + {"name": "web_search"}, {"name": "web_fetch"}, + {"name": "web_search"}, {"name": "read_file"}, + ] + assert tmp_templates.classify_task_type(calls) == TaskType.RESEARCH + + def test_mixed_no_dominant(self, tmp_templates): + calls = [ + {"name": "execute_code"}, {"name": "read_file"}, + {"name": "web_search"}, + ] + assert tmp_templates.classify_task_type(calls) == TaskType.MIXED + + def test_empty_returns_mixed(self, tmp_templates): + assert tmp_templates.classify_task_type([]) == TaskType.MIXED + + def test_threshold_is_60_percent(self, tmp_templates): + # 59% code (5/9) should be MIXED + calls = [{"name": "execute_code"}] * 5 + [{"name": "read_file"}] * 4 + assert tmp_templates.classify_task_type(calls) == TaskType.MIXED + + # 60% code (6/10) should be CODE + calls = [{"name": "execute_code"}] * 6 + [{"name": "read_file"}] * 4 + assert tmp_templates.classify_task_type(calls) == TaskType.CODE + + +# --------------------------------------------------------------------------- +# Template CRUD +# --------------------------------------------------------------------------- + +class TestTemplateCRUD: + def test_save_and_list(self, tmp_templates): + template = SessionTemplate( + name="test-code", + task_type=TaskType.CODE, + examples=[ + ToolCallExample(tool_name="execute_code", args={"code": "print('hi')"}, success=True), + ], + created_at="2026-01-01T00:00:00Z", + ) + tmp_templates.save_template(template) + + templates = tmp_templates.list_templates() + assert len(templates) == 1 + assert templates[0].name == "test-code" + assert templates[0].task_type == TaskType.CODE + + def test_list_filter_by_type(self, tmp_templates): + tmp_templates.save_template(SessionTemplate(name="t1", task_type=TaskType.CODE, examples=[])) + tmp_templates.save_template(SessionTemplate(name="t2", task_type=TaskType.FILE, examples=[])) + + code_templates = tmp_templates.list_templates(TaskType.CODE) + assert len(code_templates) == 1 + assert code_templates[0].name == "t1" + + def test_delete(self, tmp_templates): + tmp_templates.save_template(SessionTemplate(name="delete-me", task_type=TaskType.CODE, examples=[])) + assert tmp_templates.delete_template("delete-me") is True + assert len(tmp_templates.list_templates()) == 0 + + def test_delete_nonexistent(self, tmp_templates): + assert tmp_templates.delete_template("nope") is False + + def test_get_template_returns_best(self, tmp_templates): + tmp_templates.save_template(SessionTemplate( + name="low-usage", task_type=TaskType.CODE, examples=[], usage_count=1, + )) + tmp_templates.save_template(SessionTemplate( + name="high-usage", task_type=TaskType.CODE, examples=[], usage_count=5, + )) + best = tmp_templates.get_template(TaskType.CODE) + assert best.name == "high-usage" + + def test_get_template_returns_none_if_empty(self, tmp_templates): + assert tmp_templates.get_template(TaskType.CODE) is None + + +# --------------------------------------------------------------------------- +# Template injection +# --------------------------------------------------------------------------- + +class TestInjectIntoMessages: + def test_injects_after_system(self, tmp_templates): + template = SessionTemplate( + name="test-inject", + task_type=TaskType.CODE, + examples=[ + ToolCallExample( + tool_name="execute_code", + args={"code": "x=1"}, + result_preview="1", + success=True, + ), + ], + ) + messages = [ + {"role": "system", "content": "You are Timmy."}, + {"role": "user", "content": "Hello"}, + ] + result = tmp_templates.inject_into_messages(template, messages) + + # Should have: system, template system note, assistant tool call, tool result, user + assert len(result) == 5 + assert result[0]["role"] == "system" + assert "Session Template" in result[1]["content"] + assert result[2]["role"] == "assistant" + assert result[3]["role"] == "tool" + assert result[4]["role"] == "user" + + def test_skips_failed_examples(self, tmp_templates): + template = SessionTemplate( + name="test-fail", + task_type=TaskType.CODE, + examples=[ + ToolCallExample(tool_name="execute_code", args={}, success=False), + ToolCallExample(tool_name="read_file", args={"path": "x"}, success=True), + ], + ) + messages = [{"role": "system", "content": "sys"}] + result = tmp_templates.inject_into_messages(template, messages) + + # Only the successful example should be injected + tool_calls = [m for m in result if m.get("role") == "assistant" and m.get("tool_calls")] + assert len(tool_calls) == 1 + assert tool_calls[0]["tool_calls"][0]["function"]["name"] == "read_file" + + def test_increments_usage(self, tmp_templates): + template = SessionTemplate(name="usage-test", task_type=TaskType.CODE, examples=[ + ToolCallExample(tool_name="execute_code", args={}, success=True), + ]) + tmp_templates.save_template(template) + + tmp_templates.inject_into_messages(template, [{"role": "system", "content": "x"}]) + assert template.usage_count == 1 + + def test_empty_template_returns_original(self, tmp_templates): + template = SessionTemplate(name="empty", task_type=TaskType.CODE, examples=[]) + messages = [{"role": "user", "content": "hi"}] + result = tmp_templates.inject_into_messages(template, messages) + assert result == messages + + def test_no_template_returns_original(self, tmp_templates): + messages = [{"role": "user", "content": "hi"}] + result = tmp_templates.inject_into_messages(None, messages) + assert result == messages diff --git a/tools/session_templates.py b/tools/session_templates.py new file mode 100644 index 000000000..cc2f93753 --- /dev/null +++ b/tools/session_templates.py @@ -0,0 +1,418 @@ +""" +Session templates for code-first seeding. + +Research finding: Code-heavy sessions (execute_code dominant in first 30 turns) +improve over time. File-heavy sessions degrade. The key is deterministic +feedback loops, not arbitrary context. + +This module provides: +1. Task type classification (CODE, FILE, RESEARCH, MIXED) +2. Template extraction from completed sessions +3. Template storage (~/.hermes/session-templates/) +4. Template injection into new sessions +5. CLI interface for template management + +Closes #329. +""" + +from __future__ import annotations + +import json +import os +import sqlite3 +from dataclasses import asdict, dataclass, field +from enum import Enum +from pathlib import Path +from typing import Any, Optional + + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +HERMES_HOME = Path(os.environ.get("HERMES_HOME", str(Path.home() / ".hermes"))) +TEMPLATES_DIR = HERMES_HOME / "session-templates" +SESSIONS_DB = HERMES_HOME / "state.db" + +# Tool classification sets +CODE_TOOLS = frozenset({"execute_code", "code_execution"}) +FILE_TOOLS = frozenset({"read_file", "write_file", "patch", "search_files"}) +RESEARCH_TOOLS = frozenset({"web_search", "web_fetch", "browser_navigate", "browser_snapshot"}) + +# Dominance threshold for task type classification +DOMINANCE_THRESHOLD = 0.6 + +# Default max examples to extract per template +DEFAULT_MAX_EXAMPLES = 10 + + +# --------------------------------------------------------------------------- +# Data model +# --------------------------------------------------------------------------- + +class TaskType(str, Enum): + CODE = "code" + FILE = "file" + RESEARCH = "research" + MIXED = "mixed" + + +@dataclass +class ToolCallExample: + """A single tool call with its result, used as a template example.""" + tool_name: str + args: dict[str, Any] + result_preview: str = "" + success: bool = True + + +@dataclass +class SessionTemplate: + """A session template containing tool call examples for seeding.""" + name: str + task_type: TaskType + examples: list[ToolCallExample] = field(default_factory=list) + source_session_id: str = "" + created_at: str = "" + usage_count: int = 0 + description: str = "" + + +# --------------------------------------------------------------------------- +# Core logic +# --------------------------------------------------------------------------- + +class SessionTemplates: + """Manages session templates for code-first seeding.""" + + def __init__(self, templates_dir: Optional[Path] = None): + self.templates_dir = templates_dir or TEMPLATES_DIR + self.templates_dir.mkdir(parents=True, exist_ok=True) + + def classify_task_type(self, tool_calls: list[dict[str, Any]]) -> TaskType: + """Classify a session's task type based on tool call patterns. + + Uses 60% threshold for dominant type. + """ + if not tool_calls: + return TaskType.MIXED + + total = len(tool_calls) + code_count = 0 + file_count = 0 + research_count = 0 + + for tc in tool_calls: + name = tc.get("name", tc.get("tool_name", "")).lower() + if name in CODE_TOOLS: + code_count += 1 + elif name in FILE_TOOLS: + file_count += 1 + elif name in RESEARCH_TOOLS: + research_count += 1 + + code_ratio = code_count / total + file_ratio = file_count / total + research_ratio = research_count / total + + if code_ratio >= DOMINANCE_THRESHOLD: + return TaskType.CODE + if file_ratio >= DOMINANCE_THRESHOLD: + return TaskType.FILE + if research_ratio >= DOMINANCE_THRESHOLD: + return TaskType.RESEARCH + return TaskType.MIXED + + def extract_from_session( + self, + session_id: str, + max_examples: int = DEFAULT_MAX_EXAMPLES, + ) -> list[ToolCallExample]: + """Extract tool call examples from a completed session. + + Reads from the SQLite session database. + """ + examples: list[ToolCallExample] = [] + + db_path = SESSIONS_DB + if not db_path.exists(): + return examples + + try: + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + + rows = conn.execute( + "SELECT messages FROM sessions WHERE session_id = ? ORDER BY created_at DESC LIMIT 1", + (session_id,), + ).fetchone() + + if not rows: + conn.close() + return examples + + messages = json.loads(rows["messages"]) + + # Extract tool calls from assistant messages + for msg in messages: + if msg.get("role") != "assistant": + continue + tool_calls = msg.get("tool_calls", []) + if not tool_calls: + continue + + for tc in tool_calls: + if len(examples) >= max_examples: + break + + fn = tc.get("function", {}) + name = fn.get("name", "") + if not name: + continue + + try: + args = json.loads(fn.get("arguments", "{}")) + except (json.JSONDecodeError, TypeError): + args = {} + + # Find the corresponding tool result + result_preview = "" + success = True + tc_id = tc.get("id", "") + + for result_msg in messages: + if (result_msg.get("role") == "tool" + and result_msg.get("tool_call_id") == tc_id): + content = result_msg.get("content", "") + result_preview = str(content)[:200] + # Heuristic: errors contain common failure markers + if any(marker in result_preview.lower() for marker in ("error", "failed", "traceback", "exception")): + success = False + break + + examples.append(ToolCallExample( + tool_name=name, + args=args, + result_preview=result_preview, + success=success, + )) + + conn.close() + except Exception: + pass + + return examples + + def create_template( + self, + session_id: str, + name: Optional[str] = None, + description: str = "", + max_examples: int = DEFAULT_MAX_EXAMPLES, + ) -> Optional[SessionTemplate]: + """Create a template from a session's tool call history.""" + examples = self.extract_from_session(session_id, max_examples) + if not examples: + return None + + tool_calls_for_type = [{"name": e.tool_name} for e in examples] + task_type = self.classify_task_type(tool_calls_for_type) + + template_name = name or f"{task_type.value}_{session_id[:8]}" + + from datetime import datetime + template = SessionTemplate( + name=template_name, + task_type=task_type, + examples=examples, + source_session_id=session_id, + created_at=datetime.utcnow().isoformat() + "Z", + description=description or f"Auto-extracted from {session_id}", + ) + + self.save_template(template) + return template + + def save_template(self, template: SessionTemplate) -> Path: + """Save a template to disk.""" + path = self.templates_dir / f"{template.name}.json" + data = { + "name": template.name, + "task_type": template.task_type.value, + "examples": [asdict(e) for e in template.examples], + "source_session_id": template.source_session_id, + "created_at": template.created_at, + "usage_count": template.usage_count, + "description": template.description, + } + path.write_text(json.dumps(data, indent=2, sort_keys=True) + "\n") + return path + + def get_template(self, task_type: TaskType) -> Optional[SessionTemplate]: + """Get the best template for a given task type.""" + templates = self.list_templates(task_type) + if not templates: + return None + + # Prefer templates with more usage (proven useful) + templates.sort(key=lambda t: t.usage_count, reverse=True) + return templates[0] + + def list_templates(self, task_type: Optional[TaskType] = None) -> list[SessionTemplate]: + """List all templates, optionally filtered by type.""" + templates: list[SessionTemplate] = [] + + for path in sorted(self.templates_dir.glob("*.json")): + try: + data = json.loads(path.read_text()) + examples = [ToolCallExample(**e) for e in data.get("examples", [])] + template = SessionTemplate( + name=data["name"], + task_type=TaskType(data["task_type"]), + examples=examples, + source_session_id=data.get("source_session_id", ""), + created_at=data.get("created_at", ""), + usage_count=data.get("usage_count", 0), + description=data.get("description", ""), + ) + if task_type is None or template.task_type == task_type: + templates.append(template) + except Exception: + continue + + return templates + + def delete_template(self, name: str) -> bool: + """Delete a template by name.""" + path = self.templates_dir / f"{name}.json" + if path.exists(): + path.unlink() + return True + return False + + def inject_into_messages( + self, + template: SessionTemplate, + messages: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + """Inject template examples into a session's messages. + + Inserts tool call examples after system messages to establish + feedback loops early. + """ + if not template or not template.examples: + return messages + + # Build injection messages + injection: list[dict[str, Any]] = [] + + # System note about the template + injection.append({ + "role": "system", + "content": ( + f"[Session Template: '{template.name}' ({template.task_type.value})]\n" + f"The following are examples of successful tool calls from a similar session. " + f"Use them as patterns for your own tool usage." + ), + }) + + # Add example tool call/result pairs + for ex in template.examples: + if not ex.success: + continue # Only inject successful examples + + injection.append({ + "role": "assistant", + "content": None, + "tool_calls": [{ + "id": f"template_{template.name}_{ex.tool_name}", + "type": "function", + "function": { + "name": ex.tool_name, + "arguments": json.dumps(ex.args), + }, + }], + }) + injection.append({ + "role": "tool", + "tool_call_id": f"template_{template.name}_{ex.tool_name}", + "content": ex.result_preview or "(example result)", + }) + + # Find insertion point: after system messages + insert_idx = 0 + for i, msg in enumerate(messages): + if msg.get("role") == "system": + insert_idx = i + 1 + else: + break + + # Insert + result = messages[:insert_idx] + injection + messages[insert_idx:] + + # Update usage count + template.usage_count += 1 + self.save_template(template) + + return result + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def _cli(): + """Simple CLI for session template management.""" + import argparse + import sys + + parser = argparse.ArgumentParser(description="Session template management") + sub = parser.add_subparsers(dest="command") + + # list + list_cmd = sub.add_parser("list", help="List templates") + list_cmd.add_argument("--type", choices=["code", "file", "research", "mixed"]) + + # create + create_cmd = sub.add_parser("create", help="Create template from session") + create_cmd.add_argument("session_id", help="Session ID to extract from") + create_cmd.add_argument("--name", help="Template name") + create_cmd.add_argument("--max-examples", type=int, default=10) + + # delete + delete_cmd = sub.add_parser("delete", help="Delete template") + delete_cmd.add_argument("name", help="Template name") + + args = parser.parse_args() + tm = SessionTemplates() + + if args.command == "list": + task_type = TaskType(args.type) if args.type else None + templates = tm.list_templates(task_type) + if not templates: + print("No templates found.") + return + for t in templates: + print(f" {t.name:30s} {t.task_type.value:10s} {len(t.examples)} examples, used {t.usage_count}x") + + elif args.command == "create": + template = tm.create_template(args.session_id, name=args.name, max_examples=args.max_examples) + if template: + print(f"Created template: {template.name} ({template.task_type.value}, {len(template.examples)} examples)") + else: + print(f"No tool calls found in session {args.session_id}") + sys.exit(1) + + elif args.command == "delete": + if tm.delete_template(args.name): + print(f"Deleted template: {args.name}") + else: + print(f"Template not found: {args.name}") + sys.exit(1) + + else: + parser.print_help() + + +if __name__ == "__main__": + _cli()