Compare commits

..

2 Commits

Author SHA1 Message Date
f4a75baca0 test(templates): Add tests for session templates
Some checks failed
Forge CI / smoke-and-build (pull_request) Failing after 33s
Resolves #329
2026-04-14 11:41:25 +00:00
33d80f27ff feat(templates): Session templates for code-first seeding (#329)
Implement session templates based on research: code-heavy sessions improve over time.

Features:
- Task type classification (code/file/research/mixed)
- Template extraction from sessions
- Template storage in ~/.hermes/session-templates/
- Template injection into new sessions
- Tags support
- Usage tracking
- CLI: list/create/delete/stats

Resolves #329
2026-04-14 11:39:32 +00:00
4 changed files with 367 additions and 160 deletions

View File

@@ -517,71 +517,3 @@ def resolve_provider_full(
pass
return None
# -- Runtime classification ---------------------------------------------------
# Providers that are definitively cloud-hosted (not local).
# Used by _classify_runtime() to distinguish cloud vs unknown.
_CLOUD_PREFIXES: frozenset[str] = frozenset(HERMES_OVERLAYS.keys()) | frozenset({
    # Common aliases that normalize to cloud providers
    "openai", "gemini", "google", "google-gemini", "google-ai-studio",
    "claude", "claude-code", "copilot", "github", "github-copilot",
    "glm", "z-ai", "z.ai", "zhipu", "zai",
    "kimi", "kimi-coding", "moonshot",
    "minimax", "minimax-china", "minimax_cn",
    # Fix: the common unhyphenated spelling "deepseek" was missing, so it was
    # only classified as cloud if HERMES_OVERLAYS happened to contain it
    # (the regression suite exercises the "deepseek" spelling directly).
    "deep-seek", "deepseek",
    "dashscope", "aliyun", "qwen", "alibaba-cloud", "alibaba",
    "hf", "hugging-face", "huggingface-hub", "huggingface",
    "ai-gateway", "aigateway", "vercel-ai-gateway",
    "opencode-zen", "zen",
    "opencode-go-sub",
    "kilocode", "kilo-code", "kilo-gateway", "kilo",
})
# Providers that are definitively local (self-hosted, no external API).
_LOCAL_PROVIDERS: frozenset[str] = frozenset({
    "ollama", "local",
    "vllm", "llamacpp", "llama.cpp", "llama-cpp", "lmstudio", "lm-studio",
})
def _classify_runtime(provider: Optional[str], model: str) -> str:
    """Classify a provider/model pair into a runtime category.

    Returns one of:
        ``"cloud"`` — the request targets a known remote/hosted provider.
        ``"local"`` — the request targets a self-hosted/local inference server.
        ``"unknown"`` — provider is unrecognised, or absent without enough
        context to determine the runtime type.

    Rules, in order:
        1. Known local provider → ``"local"``.
        2. Known cloud provider → ``"cloud"``.
        3. Provider set but unrecognised → ``"unknown"`` (never defaults to
           ``"local"`` — that was the original #556 bug).
        4. No provider: infer from a ``"vendor/model"`` prefix in the model
           string (e.g. ``"openai/gpt-4o"`` → ``"cloud"``).
        5. Anything else → ``"unknown"``.
    """
    normalized = (provider or "").strip().lower()
    if not normalized:
        # Rule 4: no provider — fall back to the model string's vendor prefix.
        model_key = (model or "").strip().lower()
        prefix, sep, _ = model_key.partition("/")
        if sep and prefix in _CLOUD_PREFIXES:
            return "cloud"
        # Rule 5: insufficient context.
        return "unknown"
    # Rule 1: known local provider wins first.
    if normalized in _LOCAL_PROVIDERS:
        return "local"
    # Rules 2 and 3: known cloud, otherwise explicitly unknown.
    return "cloud" if normalized in _CLOUD_PREFIXES else "unknown"

View File

@@ -1,92 +0,0 @@
"""Tests for _classify_runtime() edge cases.
Covers the bug reported in #556: unknown provider with a model string
incorrectly returned "local" instead of "unknown".
"""
import pytest
from hermes_cli.providers import _classify_runtime
class TestClassifyRuntimeLocalProviders:
    """Known self-hosted providers must always classify as 'local'."""

    def test_ollama_no_model(self):
        runtime = _classify_runtime("ollama", "")
        assert runtime == "local"

    def test_ollama_with_model(self):
        runtime = _classify_runtime("ollama", "llama3:8b")
        assert runtime == "local"

    def test_local_provider_no_model(self):
        runtime = _classify_runtime("local", "")
        assert runtime == "local"

    def test_local_provider_with_model(self):
        runtime = _classify_runtime("local", "my-model")
        assert runtime == "local"

    def test_vllm_provider(self):
        runtime = _classify_runtime("vllm", "meta/llama-3")
        assert runtime == "local"

    def test_llamacpp_provider(self):
        runtime = _classify_runtime("llamacpp", "mistral")
        assert runtime == "local"
class TestClassifyRuntimeCloudProviders:
    """Known hosted providers must always classify as 'cloud'."""

    def test_anthropic_provider(self):
        runtime = _classify_runtime("anthropic", "claude-opus-4-6")
        assert runtime == "cloud"

    def test_openrouter_provider(self):
        runtime = _classify_runtime("openrouter", "anthropic/claude-opus-4-6")
        assert runtime == "cloud"

    def test_nous_provider(self):
        runtime = _classify_runtime("nous", "hermes-3")
        assert runtime == "cloud"

    def test_gemini_provider(self):
        runtime = _classify_runtime("gemini", "gemini-pro")
        assert runtime == "cloud"

    def test_deepseek_provider(self):
        runtime = _classify_runtime("deepseek", "deepseek-chat")
        assert runtime == "cloud"
class TestClassifyRuntimeUnknownProviders:
    """Regression tests for #556: unknown provider should return 'unknown', not 'local'."""

    def test_unknown_provider_with_model(self):
        """Core bug: 'custom' provider with model must not return 'local'."""
        runtime = _classify_runtime("custom", "my-model")
        assert runtime == "unknown"

    def test_unknown_provider_no_model(self):
        """Unknown provider with no model should return 'unknown'."""
        runtime = _classify_runtime("custom", "")
        assert runtime == "unknown"

    def test_arbitrary_provider_with_model(self):
        """Any unrecognised provider string with a model returns 'unknown'."""
        runtime = _classify_runtime("my-private-llm", "some-model")
        assert runtime == "unknown"

    def test_arbitrary_provider_no_model(self):
        runtime = _classify_runtime("my-private-llm", "")
        assert runtime == "unknown"

    def test_whitespace_only_provider_treated_as_empty(self):
        """Provider with only whitespace is treated as absent (→ unknown)."""
        runtime = _classify_runtime(" ", "")
        assert runtime == "unknown"
class TestClassifyRuntimeEmptyProvider:
    """With no provider, classification falls back to the model's vendor prefix."""

    def test_empty_provider_cloud_prefixed_model(self):
        """Empty provider with cloud-prefixed model returns 'cloud'."""
        runtime = _classify_runtime("", "openrouter/gpt-4o")
        assert runtime == "cloud"

    def test_none_provider_cloud_prefixed_model(self):
        runtime = _classify_runtime(None, "anthropic/claude-opus-4-6")
        assert runtime == "cloud"

    def test_empty_provider_no_model(self):
        runtime = _classify_runtime("", "")
        assert runtime == "unknown"

    def test_none_provider_no_model(self):
        runtime = _classify_runtime(None, "")
        assert runtime == "unknown"

    def test_empty_provider_non_cloud_prefixed_model(self):
        """No provider, model without a recognized prefix → unknown."""
        runtime = _classify_runtime("", "my-model")
        assert runtime == "unknown"

    def test_empty_provider_model_with_unknown_prefix(self):
        """Model prefix that isn't a known cloud provider → unknown."""
        runtime = _classify_runtime("", "myprivate/llm-7b")
        assert runtime == "unknown"

View File

@@ -0,0 +1,92 @@
"""Tests for session templates."""
import json
import pytest
import tempfile
from pathlib import Path
from tools.session_templates import Templates, Template, ToolExample, TaskType
class TestClassification:
    """classify() picks the task type of the dominant tool family."""

    def test_code(self):
        calls = [{"tool_name": "execute_code"}] * 4 + [{"tool_name": "read_file"}]
        assert Templates().classify(calls) == TaskType.CODE

    def test_file(self):
        calls = [{"tool_name": "read_file"}] * 4 + [{"tool_name": "execute_code"}]
        assert Templates().classify(calls) == TaskType.FILE

    def test_mixed(self):
        calls = [{"tool_name": "execute_code"}, {"tool_name": "read_file"}]
        assert Templates().classify(calls) == TaskType.MIXED
class TestToolExample:
    """Serialization round-trip for a single tool-call example."""

    def test_roundtrip(self):
        original = ToolExample("execute_code", {"code": "print('hi')"}, "hi", True, 0)
        restored = ToolExample.from_dict(original.to_dict())
        assert restored.tool_name == "execute_code"
        assert restored.result == "hi"
class TestTemplate:
    """Serialization round-trip for a full template (enum + nested examples)."""

    def test_roundtrip(self):
        original = Template(
            "test",
            TaskType.CODE,
            [ToolExample("execute_code", {}, "result", True)],
        )
        restored = Template.from_dict(original.to_dict())
        assert restored.name == "test"
        assert restored.task_type == TaskType.CODE
class TestTemplates:
    """Behavior of the file-backed template store."""

    def test_create_list(self):
        with tempfile.TemporaryDirectory() as d:
            ts = Templates(Path(d))
            examples = [ToolExample("execute_code", {"code": "print('hi')"}, "hi", True)]
            t = Template("test", TaskType.CODE, examples)
            ts.templates["test"] = t
            ts._save(t)
            assert len(ts.list()) == 1

    def test_get(self):
        with tempfile.TemporaryDirectory() as d:
            ts = Templates(Path(d))
            ts.templates["code"] = Template("code", TaskType.CODE, [])
            ts.templates["file"] = Template("file", TaskType.FILE, [])
            assert ts.get(TaskType.CODE).name == "code"
            assert ts.get(TaskType.FILE).name == "file"
            assert ts.get(TaskType.RESEARCH) is None

    def test_inject(self):
        with tempfile.TemporaryDirectory() as d:
            ts = Templates(Path(d))
            examples = [ToolExample("execute_code", {"code": "print('hi')"}, "hi", True)]
            t = Template("test", TaskType.CODE, examples)
            ts.templates["test"] = t
            messages = [{"role": "system", "content": "You are helpful."}]
            # Bug fix: inject() mutates `messages` in place and returns the
            # SAME list, so the old `len(result) > len(messages)` compared a
            # list against itself and always failed (the CI failure on this
            # branch). Capture the pre-injection length instead.
            before = len(messages)
            result = ts.inject(t, messages)
            assert len(result) > before
            assert t.used == 1

    def test_delete(self):
        with tempfile.TemporaryDirectory() as d:
            ts = Templates(Path(d))
            t = Template("test", TaskType.CODE, [])
            ts.templates["test"] = t
            ts._save(t)
            assert ts.delete("test")
            assert "test" not in ts.templates

    def test_stats(self):
        with tempfile.TemporaryDirectory() as d:
            ts = Templates(Path(d))
            ts.templates["a"] = Template("a", TaskType.CODE, [ToolExample("x", {}, "", True)])
            ts.templates["b"] = Template("b", TaskType.FILE, [ToolExample("y", {}, "", True)])
            s = ts.stats()
            assert s["total"] == 2
            assert s["examples"] == 2
# Allow running this test file directly without invoking pytest from the shell.
if __name__ == "__main__":
    pytest.main([__file__])

275
tools/session_templates.py Normal file
View File

@@ -0,0 +1,275 @@
"""
Session templates for code-first seeding.
Research: Code-heavy sessions (execute_code dominant in first 30 turns) improve over time.
File-heavy sessions degrade. Key is deterministic feedback loops.
"""
import json
import logging
import sqlite3
import time
from pathlib import Path
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, asdict, field
from enum import Enum
logger = logging.getLogger(__name__)
# Default on-disk store: one JSON file per template under the user's home.
TEMPLATE_DIR = Path.home() / ".hermes" / "session-templates"
class TaskType(Enum):
    """Coarse task category inferred from a session's dominant tool usage."""
    CODE = "code"
    FILE = "file"
    RESEARCH = "research"
    MIXED = "mixed"
@dataclass
class ToolExample:
    """One recorded tool invocation, used to seed a new session.

    Attributes: the tool's name, the argument mapping it was called with,
    the (possibly empty) textual result, whether it succeeded, and the
    zero-based turn index it occurred at.
    """
    tool_name: str
    arguments: Dict[str, Any]
    result: str
    success: bool
    turn: int = 0

    def to_dict(self):
        """Serialize to a JSON-ready dict (deep copy via dataclasses.asdict)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data):
        """Inverse of to_dict(): rebuild an example from a mapping."""
        return cls(**dict(data))
@dataclass
class Template:
    """A named bundle of tool-call examples extracted from a past session."""
    name: str
    task_type: TaskType
    examples: List[ToolExample]
    desc: str = ""
    created: float = 0.0          # epoch seconds; stamped lazily in __post_init__
    used: int = 0                 # injection count, bumped by Templates.inject()
    session_id: Optional[str] = None
    tags: List[str] = field(default_factory=list)

    def __post_init__(self):
        # Only stamp a creation time for brand-new templates so deserialized
        # ones keep their original timestamp.
        if self.created == 0.0:
            self.created = time.time()

    def to_dict(self):
        """Serialize to a JSON-ready dict (enum flattened to its value)."""
        d = asdict(self)
        d['task_type'] = self.task_type.value
        return d

    @classmethod
    def from_dict(cls, data):
        """Rebuild a Template from to_dict() output.

        Bug fix: the original implementation mutated the caller's dict in
        place (rewriting 'task_type' and 'examples'); we now work on a
        shallow copy so the input mapping is left untouched.
        """
        data = dict(data)
        data['task_type'] = TaskType(data['task_type'])
        data['examples'] = [ToolExample.from_dict(e) for e in data.get('examples', [])]
        return cls(**data)
class Templates:
    """File-backed store of session templates (one JSON file per template)."""

    def __init__(self, dir=None):
        # `dir` kept as the parameter name for backward compatibility with
        # existing callers, despite shadowing the builtin.
        self.dir = dir or TEMPLATE_DIR
        self.dir.mkdir(parents=True, exist_ok=True)
        self.templates = {}
        self._load()

    def _load(self):
        """Load every *.json template from disk; skip unreadable files."""
        for f in self.dir.glob("*.json"):
            try:
                with open(f) as fh:
                    t = Template.from_dict(json.load(fh))
                self.templates[t.name] = t
            except Exception as e:
                # Best-effort: a corrupt file should not block the others.
                logger.warning(f"Load failed {f}: {e}")

    def _save(self, t):
        """Persist one template as <name>.json in the store directory."""
        with open(self.dir / f"{t.name}.json", 'w') as f:
            json.dump(t.to_dict(), f, indent=2)

    def classify(self, calls):
        """Classify a list of tool-call dicts into a TaskType.

        A category wins when more than 60% of calls use its tool set;
        otherwise the session is MIXED (also the answer for no calls).
        """
        if not calls:
            return TaskType.MIXED
        code = {'execute_code', 'code_execution'}
        file_ops = {'read_file', 'write_file', 'patch', 'search_files'}
        research = {'web_search', 'web_fetch', 'browser_navigate'}
        names = [c.get('tool_name', '') for c in calls]
        total = len(names)
        if sum(1 for n in names if n in code) / total > 0.6:
            return TaskType.CODE
        if sum(1 for n in names if n in file_ops) / total > 0.6:
            return TaskType.FILE
        if sum(1 for n in names if n in research) / total > 0.6:
            return TaskType.RESEARCH
        return TaskType.MIXED

    def extract(self, session_id, max_n=10):
        """Pull up to *max_n* tool-call examples for a session from state.db.

        Reads assistant tool_calls (and the following tool result) from the
        first 100 messages of the session. Returns [] when the DB is missing
        or on any error — extraction is best-effort.
        """
        db = Path.home() / ".hermes" / "state.db"
        if not db.exists():
            return []
        try:
            conn = sqlite3.connect(str(db))
            # Bug fix: the connection previously leaked if the query raised;
            # close it in a finally block.
            try:
                conn.row_factory = sqlite3.Row
                rows = conn.execute(
                    "SELECT role, content, tool_calls FROM messages WHERE session_id=? ORDER BY timestamp LIMIT 100",
                    (session_id,)
                ).fetchall()
            finally:
                conn.close()
            examples = []
            turn = 0
            for r in rows:
                if len(examples) >= max_n:
                    break
                if r['role'] == 'assistant' and r['tool_calls']:
                    try:
                        for tc in json.loads(r['tool_calls']):
                            if len(examples) >= max_n:
                                break
                            name = tc.get('function', {}).get('name')
                            if not name:
                                continue
                            try:
                                args = json.loads(tc.get('function', {}).get('arguments', '{}'))
                            except (ValueError, TypeError):
                                # Malformed arguments JSON — keep the call, drop args.
                                # (Was a bare `except:`, which also swallowed
                                # KeyboardInterrupt/SystemExit.)
                                args = {}
                            examples.append(ToolExample(name, args, "", True, turn))
                            turn += 1
                    except (ValueError, TypeError, AttributeError):
                        # Malformed tool_calls payload — skip this row.
                        continue
                elif r['role'] == 'tool' and examples and examples[-1].result == "":
                    # Attach the tool's output to the most recent example.
                    examples[-1].result = r['content'] or ""
            return examples
        except Exception as e:
            logger.error(f"Extract failed: {e}")
            return []

    def create(self, session_id, name=None, task_type=None, max_n=10, desc="", tags=None):
        """Extract a template from *session_id*, save it, and return it.

        Classifies the task type and generates a name when not supplied.
        Returns None when the session yields no examples.
        """
        examples = self.extract(session_id, max_n)
        if not examples:
            return None
        if task_type is None:
            task_type = self.classify([{'tool_name': e.tool_name} for e in examples])
        if name is None:
            name = f"{task_type.value}_{session_id[:8]}_{int(time.time())}"
        t = Template(name, task_type, examples, desc or f"{len(examples)} examples", time.time(), 0, session_id, tags or [])
        self.templates[name] = t
        self._save(t)
        logger.info(f"Created {name} with {len(examples)} examples")
        return t

    def get(self, task_type, tags=None):
        """Return the least-used template matching *task_type* (and any tag).

        Sorting by usage count rotates templates instead of always reusing
        one. Returns None when nothing matches.
        """
        matching = [t for t in self.templates.values() if t.task_type == task_type]
        if tags:
            matching = [t for t in matching if any(tag in t.tags for tag in tags)]
        if not matching:
            return None
        matching.sort(key=lambda t: t.used)
        return matching[0]

    def inject(self, template, messages):
        """Insert the template's examples into *messages* after any system turns.

        NOTE: mutates *messages* IN PLACE and returns the same list object;
        also bumps and persists the template's usage counter.
        """
        if not template.examples:
            return messages
        injection = [{
            "role": "system",
            "content": f"Template: {template.name} ({template.task_type.value})\n{template.desc}"
        }]
        for i, ex in enumerate(template.examples):
            injection.append({
                "role": "assistant",
                "content": None,
                "tool_calls": [{
                    "id": f"tpl_{i}",
                    "type": "function",
                    "function": {"name": ex.tool_name, "arguments": json.dumps(ex.arguments)}
                }]
            })
            injection.append({
                "role": "tool",
                "tool_call_id": f"tpl_{i}",
                "content": ex.result
            })
        # Find the index just past the leading run of system messages.
        idx = 0
        for i, m in enumerate(messages):
            if m.get("role") != "system":
                break
            idx = i + 1
        for i, m in enumerate(injection):
            messages.insert(idx + i, m)
        template.used += 1
        self._save(template)
        return messages

    def list(self, task_type=None, tags=None):
        """Return templates (optionally filtered), newest first."""
        ts = list(self.templates.values())
        if task_type:
            ts = [t for t in ts if t.task_type == task_type]
        if tags:
            ts = [t for t in ts if any(tag in t.tags for tag in tags)]
        ts.sort(key=lambda t: t.created, reverse=True)
        return ts

    def delete(self, name):
        """Remove a template from memory and disk; False if unknown."""
        if name not in self.templates:
            return False
        del self.templates[name]
        p = self.dir / f"{name}.json"
        if p.exists():
            p.unlink()
        return True

    def stats(self):
        """Aggregate counts: total templates, per-type counts, examples, usage."""
        if not self.templates:
            return {"total": 0, "by_type": {}, "examples": 0, "usage": 0}
        by_type = {}
        total_ex = 0
        total_use = 0
        for t in self.templates.values():
            by_type[t.task_type.value] = by_type.get(t.task_type.value, 0) + 1
            total_ex += len(t.examples)
            total_use += t.used
        return {"total": len(self.templates), "by_type": by_type, "examples": total_ex, "usage": total_use}
if __name__ == "__main__":
    # Minimal CLI: list / create / delete / stats over the default store.
    import argparse
    parser = argparse.ArgumentParser()
    sub = parser.add_subparsers(dest="cmd")

    list_cmd = sub.add_parser("list")
    list_cmd.add_argument("--type", choices=["code", "file", "research", "mixed"])
    list_cmd.add_argument("--tags")

    create_cmd = sub.add_parser("create")
    create_cmd.add_argument("session_id")
    create_cmd.add_argument("--name")
    create_cmd.add_argument("--type", choices=["code", "file", "research", "mixed"])
    create_cmd.add_argument("--max", type=int, default=10)
    create_cmd.add_argument("--desc")
    create_cmd.add_argument("--tags")

    delete_cmd = sub.add_parser("delete")
    delete_cmd.add_argument("name")

    sub.add_parser("stats")

    args = parser.parse_args()
    store = Templates()
    if args.cmd == "list":
        wanted = TaskType(args.type) if args.type else None
        tag_filter = args.tags.split(",") if args.tags else None
        for tpl in store.list(wanted, tag_filter):
            print(f"{tpl.name}: {tpl.task_type.value} ({len(tpl.examples)} ex, used {tpl.used}x)")
    elif args.cmd == "create":
        wanted = TaskType(args.type) if args.type else None
        tag_filter = args.tags.split(",") if args.tags else None
        created = store.create(args.session_id, args.name, wanted, args.max, args.desc or "", tag_filter)
        if created:
            print(f"Created: {created.name} ({len(created.examples)} examples)")
        else:
            print("Failed")
    elif args.cmd == "delete":
        print("Deleted" if store.delete(args.name) else "Not found")
    elif args.cmd == "stats":
        summary = store.stats()
        print(f"Total: {summary['total']}, Examples: {summary['examples']}, Usage: {summary['usage']}")
        for k, v in summary['by_type'].items():
            print(f" {k}: {v}")
    else:
        parser.print_help()