Compare commits

...

2 Commits

Author SHA1 Message Date
f4a75baca0 test(templates): Add tests for session templates
Some checks failed
Forge CI / smoke-and-build (pull_request) Failing after 33s
Resolves #329
2026-04-14 11:41:25 +00:00
33d80f27ff feat(templates): Session templates for code-first seeding (#329)
Implement session templates based on research: code-heavy sessions improve over time.

Features:
- Task type classification (code/file/research/mixed)
- Template extraction from sessions
- Template storage in ~/.hermes/session-templates/
- Template injection into new sessions
- Tags support
- Usage tracking
- CLI: list/create/delete/stats

Resolves #329
2026-04-14 11:39:32 +00:00
2 changed files with 367 additions and 0 deletions

View File

@@ -0,0 +1,92 @@
"""Tests for session templates."""
import json
import pytest
import tempfile
from pathlib import Path
from tools.session_templates import Templates, Template, ToolExample, TaskType
class TestClassification:
def test_code(self):
t = Templates()
assert t.classify([{"tool_name": "execute_code"}] * 4 + [{"tool_name": "read_file"}]) == TaskType.CODE
def test_file(self):
t = Templates()
assert t.classify([{"tool_name": "read_file"}] * 4 + [{"tool_name": "execute_code"}]) == TaskType.FILE
def test_mixed(self):
t = Templates()
assert t.classify([{"tool_name": "execute_code"}, {"tool_name": "read_file"}]) == TaskType.MIXED
class TestToolExample:
def test_roundtrip(self):
e = ToolExample("execute_code", {"code": "print('hi')"}, "hi", True, 0)
d = e.to_dict()
e2 = ToolExample.from_dict(d)
assert e2.tool_name == "execute_code"
assert e2.result == "hi"
class TestTemplate:
def test_roundtrip(self):
examples = [ToolExample("execute_code", {}, "result", True)]
t = Template("test", TaskType.CODE, examples)
d = t.to_dict()
t2 = Template.from_dict(d)
assert t2.name == "test"
assert t2.task_type == TaskType.CODE
class TestTemplates:
def test_create_list(self):
with tempfile.TemporaryDirectory() as d:
ts = Templates(Path(d))
examples = [ToolExample("execute_code", {"code": "print('hi')"}, "hi", True)]
t = Template("test", TaskType.CODE, examples)
ts.templates["test"] = t
ts._save(t)
assert len(ts.list()) == 1
def test_get(self):
with tempfile.TemporaryDirectory() as d:
ts = Templates(Path(d))
ts.templates["code"] = Template("code", TaskType.CODE, [])
ts.templates["file"] = Template("file", TaskType.FILE, [])
assert ts.get(TaskType.CODE).name == "code"
assert ts.get(TaskType.FILE).name == "file"
assert ts.get(TaskType.RESEARCH) is None
def test_inject(self):
with tempfile.TemporaryDirectory() as d:
ts = Templates(Path(d))
examples = [ToolExample("execute_code", {"code": "print('hi')"}, "hi", True)]
t = Template("test", TaskType.CODE, examples)
ts.templates["test"] = t
messages = [{"role": "system", "content": "You are helpful."}]
result = ts.inject(t, messages)
assert len(result) > len(messages)
assert t.used == 1
def test_delete(self):
with tempfile.TemporaryDirectory() as d:
ts = Templates(Path(d))
t = Template("test", TaskType.CODE, [])
ts.templates["test"] = t
ts._save(t)
assert ts.delete("test")
assert "test" not in ts.templates
def test_stats(self):
with tempfile.TemporaryDirectory() as d:
ts = Templates(Path(d))
ts.templates["a"] = Template("a", TaskType.CODE, [ToolExample("x", {}, "", True)])
ts.templates["b"] = Template("b", TaskType.FILE, [ToolExample("y", {}, "", True)])
s = ts.stats()
assert s["total"] == 2
assert s["examples"] == 2
if __name__ == "__main__":
pytest.main([__file__])

275
tools/session_templates.py Normal file
View File

@@ -0,0 +1,275 @@
"""
Session templates for code-first seeding.
Research: Code-heavy sessions (execute_code dominant in first 30 turns) improve over time.
File-heavy sessions degrade. Key is deterministic feedback loops.
"""
import json
import logging
import sqlite3
import time
from pathlib import Path
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, asdict, field
from enum import Enum
logger = logging.getLogger(__name__)
TEMPLATE_DIR = Path.home() / ".hermes" / "session-templates"
class TaskType(Enum):
CODE = "code"
FILE = "file"
RESEARCH = "research"
MIXED = "mixed"
@dataclass
class ToolExample:
tool_name: str
arguments: Dict[str, Any]
result: str
success: bool
turn: int = 0
def to_dict(self):
return asdict(self)
@classmethod
def from_dict(cls, data):
return cls(**data)
@dataclass
class Template:
name: str
task_type: TaskType
examples: List[ToolExample]
desc: str = ""
created: float = 0.0
used: int = 0
session_id: Optional[str] = None
tags: List[str] = field(default_factory=list)
def __post_init__(self):
if self.created == 0.0:
self.created = time.time()
def to_dict(self):
d = asdict(self)
d['task_type'] = self.task_type.value
return d
@classmethod
def from_dict(cls, data):
data['task_type'] = TaskType(data['task_type'])
data['examples'] = [ToolExample.from_dict(e) for e in data.get('examples', [])]
return cls(**data)
class Templates:
def __init__(self, dir=None):
self.dir = dir or TEMPLATE_DIR
self.dir.mkdir(parents=True, exist_ok=True)
self.templates = {}
self._load()
def _load(self):
for f in self.dir.glob("*.json"):
try:
with open(f) as fh:
t = Template.from_dict(json.load(fh))
self.templates[t.name] = t
except Exception as e:
logger.warning(f"Load failed {f}: {e}")
def _save(self, t):
with open(self.dir / f"{t.name}.json", 'w') as f:
json.dump(t.to_dict(), f, indent=2)
def classify(self, calls):
if not calls:
return TaskType.MIXED
code = {'execute_code', 'code_execution'}
file_ops = {'read_file', 'write_file', 'patch', 'search_files'}
research = {'web_search', 'web_fetch', 'browser_navigate'}
names = [c.get('tool_name', '') for c in calls]
total = len(names)
if sum(1 for n in names if n in code) / total > 0.6:
return TaskType.CODE
if sum(1 for n in names if n in file_ops) / total > 0.6:
return TaskType.FILE
if sum(1 for n in names if n in research) / total > 0.6:
return TaskType.RESEARCH
return TaskType.MIXED
def extract(self, session_id, max_n=10):
db = Path.home() / ".hermes" / "state.db"
if not db.exists():
return []
try:
conn = sqlite3.connect(str(db))
conn.row_factory = sqlite3.Row
rows = conn.execute(
"SELECT role, content, tool_calls FROM messages WHERE session_id=? ORDER BY timestamp LIMIT 100",
(session_id,)
).fetchall()
conn.close()
examples = []
turn = 0
for r in rows:
if len(examples) >= max_n:
break
if r['role'] == 'assistant' and r['tool_calls']:
try:
for tc in json.loads(r['tool_calls']):
if len(examples) >= max_n:
break
name = tc.get('function', {}).get('name')
if not name:
continue
try:
args = json.loads(tc.get('function', {}).get('arguments', '{}'))
except:
args = {}
examples.append(ToolExample(name, args, "", True, turn))
turn += 1
except:
continue
elif r['role'] == 'tool' and examples and examples[-1].result == "":
examples[-1].result = r['content'] or ""
return examples
except Exception as e:
logger.error(f"Extract failed: {e}")
return []
def create(self, session_id, name=None, task_type=None, max_n=10, desc="", tags=None):
examples = self.extract(session_id, max_n)
if not examples:
return None
if task_type is None:
task_type = self.classify([{'tool_name': e.tool_name} for e in examples])
if name is None:
name = f"{task_type.value}_{session_id[:8]}_{int(time.time())}"
t = Template(name, task_type, examples, desc or f"{len(examples)} examples", time.time(), 0, session_id, tags or [])
self.templates[name] = t
self._save(t)
logger.info(f"Created {name} with {len(examples)} examples")
return t
def get(self, task_type, tags=None):
matching = [t for t in self.templates.values() if t.task_type == task_type]
if tags:
matching = [t for t in matching if any(tag in t.tags for tag in tags)]
if not matching:
return None
matching.sort(key=lambda t: t.used)
return matching[0]
def inject(self, template, messages):
if not template.examples:
return messages
injection = [{
"role": "system",
"content": f"Template: {template.name} ({template.task_type.value})\n{template.desc}"
}]
for i, ex in enumerate(template.examples):
injection.append({
"role": "assistant",
"content": None,
"tool_calls": [{
"id": f"tpl_{i}",
"type": "function",
"function": {"name": ex.tool_name, "arguments": json.dumps(ex.arguments)}
}]
})
injection.append({
"role": "tool",
"tool_call_id": f"tpl_{i}",
"content": ex.result
})
idx = 0
for i, m in enumerate(messages):
if m.get("role") != "system":
break
idx = i + 1
for i, m in enumerate(injection):
messages.insert(idx + i, m)
template.used += 1
self._save(template)
return messages
def list(self, task_type=None, tags=None):
ts = list(self.templates.values())
if task_type:
ts = [t for t in ts if t.task_type == task_type]
if tags:
ts = [t for t in ts if any(tag in t.tags for tag in tags)]
ts.sort(key=lambda t: t.created, reverse=True)
return ts
def delete(self, name):
if name not in self.templates:
return False
del self.templates[name]
p = self.dir / f"{name}.json"
if p.exists():
p.unlink()
return True
def stats(self):
if not self.templates:
return {"total": 0, "by_type": {}, "examples": 0, "usage": 0}
by_type = {}
total_ex = 0
total_use = 0
for t in self.templates.values():
by_type[t.task_type.value] = by_type.get(t.task_type.value, 0) + 1
total_ex += len(t.examples)
total_use += t.used
return {"total": len(self.templates), "by_type": by_type, "examples": total_ex, "usage": total_use}
if __name__ == "__main__":
import argparse
p = argparse.ArgumentParser()
s = p.add_subparsers(dest="cmd")
lp = s.add_parser("list")
lp.add_argument("--type", choices=["code", "file", "research", "mixed"])
lp.add_argument("--tags")
cp = s.add_parser("create")
cp.add_argument("session_id")
cp.add_argument("--name")
cp.add_argument("--type", choices=["code", "file", "research", "mixed"])
cp.add_argument("--max", type=int, default=10)
cp.add_argument("--desc")
cp.add_argument("--tags")
dp = s.add_parser("delete")
dp.add_argument("name")
sp = s.add_parser("stats")
args = p.parse_args()
ts = Templates()
if args.cmd == "list":
tt = TaskType(args.type) if args.type else None
tags = args.tags.split(",") if args.tags else None
for t in ts.list(tt, tags):
print(f"{t.name}: {t.task_type.value} ({len(t.examples)} ex, used {t.used}x)")
elif args.cmd == "create":
tt = TaskType(args.type) if args.type else None
tags = args.tags.split(",") if args.tags else None
t = ts.create(args.session_id, args.name, tt, args.max, args.desc or "", tags)
if t:
print(f"Created: {t.name} ({len(t.examples)} examples)")
else:
print("Failed")
elif args.cmd == "delete":
print("Deleted" if ts.delete(args.name) else "Not found")
elif args.cmd == "stats":
s = ts.stats()
print(f"Total: {s['total']}, Examples: {s['examples']}, Usage: {s['usage']}")
for k, v in s['by_type'].items():
print(f" {k}: {v}")
else:
p.print_help()