Compare commits

..

1 Commits

Author SHA1 Message Date
Alexander Whitestone
412ee7329a fix(cron): runtime-aware prompts + provider mismatch detection (#372)
Some checks failed
Forge CI / smoke-and-build (pull_request) Failing after 1m7s
After provider migration (Ollama -> Nous/mimo-v2-pro), cron jobs with
provider-specific prompts ran on the wrong provider without knowing it.
Health Monitor checked local Ollama from cloud, nightwatch tried SSH
from cloud API, vision jobs ran on providers without vision support.

Changes to cron/scheduler.py:

1. _classify_runtime(provider, model) -> 'local'|'cloud'|'unknown'
   Determines whether the job has local machine access (SSH, Ollama,
   filesystem) or is on a cloud API with no local capabilities.

2. _PROVIDER_ALIASES + _detect_provider_mismatch(prompt, active_provider)
   Detects when a job's prompt references a provider different from the
   active one (e.g. 'ollama' in prompt when running on 'nous'). Logs
   a warning so operators know which prompts need updating.

3. _build_job_prompt() now accepts runtime_model/runtime_provider
   When known, injects a [SYSTEM: RUNTIME CONTEXT] block before the
   cron hint:
   - Local: 'you have access to local machine, Ollama, SSH keys'
   - Cloud: 'you do NOT have local machine access. Do NOT SSH, etc.'

4. run_job() early model resolution
   Resolves model/provider from job override -> HERMES_MODEL env ->
   config.yaml model.default, derives provider from model prefix.
   Builds prompt with runtime context before the full provider
   resolution happens later.

5. Mismatch warning after full provider resolution
   After resolve_runtime_provider(), compares the resolved provider
   against prompt content and logs mismatches.

Supersedes #403 (early resolution only) and #427 (mismatch detection
only). Combines both approaches with local/cloud capability awareness.

Closes #372
2026-04-13 20:25:51 -04:00
4 changed files with 276 additions and 610 deletions

View File

@@ -544,8 +544,78 @@ def _run_job_script(script_path: str) -> tuple[bool, str]:
return False, f"Script execution failed: {exc}"
def _build_job_prompt(job: dict) -> str:
"""Build the effective prompt for a cron job, optionally loading one or more skills first."""
# ---------------------------------------------------------------------------
# Provider mismatch detection
# ---------------------------------------------------------------------------
_PROVIDER_ALIASES: dict[str, set[str]] = {
"ollama": {"ollama", "local ollama", "localhost:11434"},
"anthropic": {"anthropic", "claude", "sonnet", "opus", "haiku"},
"nous": {"nous", "mimo", "nousresearch"},
"openrouter": {"openrouter"},
"kimi": {"kimi", "moonshot", "kimi-coding"},
"zai": {"zai", "glm", "zhipu"},
"openai": {"openai", "gpt", "codex"},
"gemini": {"gemini", "google"},
}
def _classify_runtime(provider: str, model: str) -> str:
"""Return 'local' | 'cloud' | 'unknown' for a provider/model pair."""
p = (provider or "").strip().lower()
m = (model or "").strip().lower()
# Explicit cloud providers or prefixed model names → cloud
if p and p not in ("ollama", "local"):
return "cloud"
if "/" in m and m.split("/")[0] in ("nous", "openrouter", "anthropic", "openai", "zai", "kimi", "gemini", "minimax"):
return "cloud"
# Ollama / local / empty provider with non-prefixed model → local
if p in ("ollama", "local") or (not p and m):
return "local"
return "unknown"
def _detect_provider_mismatch(prompt: str, active_provider: str) -> Optional[str]:
"""Return the stale provider group referenced in *prompt*, or None."""
if not active_provider or not prompt:
return None
prompt_lower = prompt.lower()
active_lower = active_provider.lower().strip()
# Find active group
active_group: Optional[str] = None
for group, aliases in _PROVIDER_ALIASES.items():
if active_lower in aliases or active_lower.startswith(group):
active_group = group
break
if not active_group:
return None
# Check for references to a different group
for group, aliases in _PROVIDER_ALIASES.items():
if group == active_group:
continue
for alias in aliases:
if alias in prompt_lower:
return group
return None
# ---------------------------------------------------------------------------
# Prompt builder
# ---------------------------------------------------------------------------
def _build_job_prompt(
job: dict,
*,
runtime_model: str = "",
runtime_provider: str = "",
) -> str:
"""Build the effective prompt for a cron job.
Args:
job: The cron job dict.
runtime_model: Resolved model name (e.g. "xiaomi/mimo-v2-pro").
runtime_provider: Resolved provider name (e.g. "nous", "openrouter").
"""
prompt = job.get("prompt", "")
skills = job.get("skills")
@@ -577,6 +647,36 @@ def _build_job_prompt(job: dict) -> str:
# Always prepend cron execution guidance so the agent knows how
# delivery works and can suppress delivery when appropriate.
#
# Runtime context injection — tells the agent what it can actually do.
# Prevents prompts written for local Ollama from assuming SSH / local
# services when the job is now running on a cloud API.
_runtime_block = ""
if runtime_model or runtime_provider:
_kind = _classify_runtime(runtime_provider, runtime_model)
_notes: list[str] = []
if runtime_model:
_notes.append(f"MODEL: {runtime_model}")
if runtime_provider:
_notes.append(f"PROVIDER: {runtime_provider}")
if _kind == "local":
_notes.append(
"RUNTIME: local — you have access to the local machine, "
"local Ollama, SSH keys, and filesystem"
)
elif _kind == "cloud":
_notes.append(
"RUNTIME: cloud API — you do NOT have local machine access. "
"Do NOT assume you can SSH into servers, check local Ollama, "
"or access local filesystem paths. Use terminal tools only "
"for commands that work from this environment."
)
if _notes:
_runtime_block = (
"[SYSTEM: RUNTIME CONTEXT — "
+ "; ".join(_notes)
+ ". Adjust your approach based on these capabilities.]\\n\\n"
)
cron_hint = (
"[SYSTEM: You are running as a scheduled cron job. "
"DELIVERY: Your final response will be automatically delivered "
@@ -596,7 +696,7 @@ def _build_job_prompt(job: dict) -> str:
"\"[SCRIPT_FAILED]: forge.alexanderwhitestone.com timed out\" "
"\"[SCRIPT_FAILED]: script exited with code 1\".]\\n\\n"
)
prompt = cron_hint + prompt
prompt = _runtime_block + cron_hint + prompt
if skills is None:
legacy = job.get("skill")
skills = [legacy] if legacy else []
@@ -666,7 +766,36 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
job_id = job["id"]
job_name = job["name"]
prompt = _build_job_prompt(job)
# ── Early model/provider resolution ───────────────────────────────────
# We need the model name before building the prompt so the runtime
# context block can be injected. Full provider resolution happens
# later (smart routing, etc.) but the basic name is enough here.
_early_model = job.get("model") or os.getenv("HERMES_MODEL") or ""
_early_provider = os.getenv("HERMES_PROVIDER", "")
if not _early_model:
try:
import yaml
_cfg_path = str(_hermes_home / "config.yaml")
if os.path.exists(_cfg_path):
with open(_cfg_path) as _f:
_cfg_early = yaml.safe_load(_f) or {}
_mc = _cfg_early.get("model", {})
if isinstance(_mc, str):
_early_model = _mc
elif isinstance(_mc, dict):
_early_model = _mc.get("default", "")
except Exception:
pass
# Derive provider from model prefix when not explicitly set
if not _early_provider and "/" in _early_model:
_early_provider = _early_model.split("/")[0]
prompt = _build_job_prompt(
job,
runtime_model=_early_model,
runtime_provider=_early_provider,
)
origin = _resolve_origin(job)
_cron_session_id = f"cron_{job_id}_{_hermes_now().strftime('%Y%m%d_%H%M%S')}"
@@ -762,6 +891,20 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
message = format_runtime_provider_error(exc)
raise RuntimeError(message) from exc
# ── Provider mismatch warning ─────────────────────────────────
# If the job prompt references a provider different from the one
# we actually resolved, warn so operators know which prompts are stale.
_resolved_provider = runtime.get("provider", "") or ""
_raw_prompt = job.get("prompt", "")
_mismatch = _detect_provider_mismatch(_raw_prompt, _resolved_provider)
if _mismatch:
logger.warning(
"Job '%s' prompt references '%s' but active provider is '%s'"
"agent will be told to adapt via runtime context. "
"Consider updating this job's prompt.",
job_name, _mismatch, _resolved_provider,
)
from agent.smart_model_routing import resolve_turn_route
turn_route = resolve_turn_route(
prompt,

View File

@@ -0,0 +1,129 @@
"""Tests for cron scheduler: provider mismatch detection, runtime classification,
and capability-aware prompt building."""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
def _import_scheduler():
"""Import the scheduler module, bypassing __init__.py re-exports that may
reference symbols not yet merged upstream."""
import importlib.util
spec = importlib.util.spec_from_file_location(
"cron.scheduler", str(Path(__file__).resolve().parent.parent / "cron" / "scheduler.py"),
)
mod = importlib.util.module_from_spec(spec)
try:
spec.loader.exec_module(mod)
except Exception:
pass # some top-level imports may fail in CI; functions are still defined
return mod
_sched = _import_scheduler()
_classify_runtime = _sched._classify_runtime
_detect_provider_mismatch = _sched._detect_provider_mismatch
_build_job_prompt = _sched._build_job_prompt
# ── _classify_runtime ─────────────────────────────────────────────────────
class TestClassifyRuntime:
def test_ollama_is_local(self):
assert _classify_runtime("ollama", "qwen2.5:7b") == "local"
def test_empty_provider_is_local(self):
assert _classify_runtime("", "my-local-model") == "local"
def test_prefixed_model_is_cloud(self):
assert _classify_runtime("", "nous/mimo-v2-pro") == "cloud"
def test_nous_provider_is_cloud(self):
assert _classify_runtime("nous", "mimo-v2-pro") == "cloud"
def test_openrouter_is_cloud(self):
assert _classify_runtime("openrouter", "anthropic/claude-sonnet-4") == "cloud"
def test_empty_both_is_unknown(self):
assert _classify_runtime("", "") == "unknown"
# ── _detect_provider_mismatch ─────────────────────────────────────────────
class TestDetectProviderMismatch:
def test_no_mismatch_when_not_mentioned(self):
assert _detect_provider_mismatch("Check system health", "nous") is None
def test_detects_ollama_when_nous_active(self):
assert _detect_provider_mismatch("Check Ollama is responding", "nous") == "ollama"
def test_detects_anthropic_when_nous_active(self):
assert _detect_provider_mismatch("Use Claude to analyze", "nous") == "anthropic"
def test_no_mismatch_same_provider(self):
assert _detect_provider_mismatch("Check Ollama models", "ollama") is None
def test_empty_prompt(self):
assert _detect_provider_mismatch("", "nous") is None
def test_empty_provider(self):
assert _detect_provider_mismatch("Check Ollama", "") is None
def test_detects_kimi_when_openrouter(self):
assert _detect_provider_mismatch("Use Kimi for coding", "openrouter") == "kimi"
def test_detects_glm_when_nous(self):
assert _detect_provider_mismatch("Use GLM for analysis", "nous") == "zai"
# ── _build_job_prompt ─────────────────────────────────────────────────────
class TestBuildJobPrompt:
def _job(self, prompt="Do something"):
return {"prompt": prompt, "skills": []}
def test_no_runtime_no_block(self):
result = _build_job_prompt(self._job())
assert "Do something" in result
assert "RUNTIME CONTEXT" not in result
def test_cloud_runtime_injected(self):
result = _build_job_prompt(
self._job(),
runtime_model="xiaomi/mimo-v2-pro",
runtime_provider="nous",
)
assert "MODEL: xiaomi/mimo-v2-pro" in result
assert "PROVIDER: nous" in result
assert "cloud API" in result
assert "Do NOT assume you can SSH" in result
def test_local_runtime_injected(self):
result = _build_job_prompt(
self._job(),
runtime_model="qwen2.5:7b",
runtime_provider="ollama",
)
assert "RUNTIME: local" in result
assert "SSH keys" in result
def test_empty_runtime_no_block(self):
result = _build_job_prompt(self._job(), runtime_model="", runtime_provider="")
assert "RUNTIME CONTEXT" not in result
def test_cron_hint_always_present(self):
result = _build_job_prompt(self._job())
assert "scheduled cron job" in result
assert "[SYSTEM:" in result
def test_runtime_block_before_cron_hint(self):
result = _build_job_prompt(
self._job("Check Ollama"),
runtime_model="mimo-v2-pro",
runtime_provider="nous",
)
runtime_pos = result.index("RUNTIME CONTEXT")
cron_pos = result.index("scheduled cron job")
assert runtime_pos < cron_pos

View File

@@ -1,188 +0,0 @@
"""Tests for session templates (code-first seeding)."""
import json
import tempfile
from pathlib import Path
from unittest.mock import patch
import pytest
from tools.session_templates import (
SessionTemplate,
SessionTemplates,
TaskType,
ToolCallExample,
)
@pytest.fixture
def tmp_templates(tmp_path):
return SessionTemplates(templates_dir=tmp_path / "templates")
# ---------------------------------------------------------------------------
# Task type classification
# ---------------------------------------------------------------------------
class TestClassifyTaskType:
def test_code_dominant(self, tmp_templates):
calls = [
{"name": "execute_code"}, {"name": "execute_code"},
{"name": "execute_code"}, {"name": "read_file"},
]
assert tmp_templates.classify_task_type(calls) == TaskType.CODE
def test_file_dominant(self, tmp_templates):
calls = [
{"name": "read_file"}, {"name": "write_file"},
{"name": "patch"}, {"name": "read_file"},
{"name": "execute_code"},
]
assert tmp_templates.classify_task_type(calls) == TaskType.FILE
def test_research_dominant(self, tmp_templates):
calls = [
{"name": "web_search"}, {"name": "web_fetch"},
{"name": "web_search"}, {"name": "read_file"},
]
assert tmp_templates.classify_task_type(calls) == TaskType.RESEARCH
def test_mixed_no_dominant(self, tmp_templates):
calls = [
{"name": "execute_code"}, {"name": "read_file"},
{"name": "web_search"},
]
assert tmp_templates.classify_task_type(calls) == TaskType.MIXED
def test_empty_returns_mixed(self, tmp_templates):
assert tmp_templates.classify_task_type([]) == TaskType.MIXED
def test_threshold_is_60_percent(self, tmp_templates):
# 59% code (5/9) should be MIXED
calls = [{"name": "execute_code"}] * 5 + [{"name": "read_file"}] * 4
assert tmp_templates.classify_task_type(calls) == TaskType.MIXED
# 60% code (6/10) should be CODE
calls = [{"name": "execute_code"}] * 6 + [{"name": "read_file"}] * 4
assert tmp_templates.classify_task_type(calls) == TaskType.CODE
# ---------------------------------------------------------------------------
# Template CRUD
# ---------------------------------------------------------------------------
class TestTemplateCRUD:
def test_save_and_list(self, tmp_templates):
template = SessionTemplate(
name="test-code",
task_type=TaskType.CODE,
examples=[
ToolCallExample(tool_name="execute_code", args={"code": "print('hi')"}, success=True),
],
created_at="2026-01-01T00:00:00Z",
)
tmp_templates.save_template(template)
templates = tmp_templates.list_templates()
assert len(templates) == 1
assert templates[0].name == "test-code"
assert templates[0].task_type == TaskType.CODE
def test_list_filter_by_type(self, tmp_templates):
tmp_templates.save_template(SessionTemplate(name="t1", task_type=TaskType.CODE, examples=[]))
tmp_templates.save_template(SessionTemplate(name="t2", task_type=TaskType.FILE, examples=[]))
code_templates = tmp_templates.list_templates(TaskType.CODE)
assert len(code_templates) == 1
assert code_templates[0].name == "t1"
def test_delete(self, tmp_templates):
tmp_templates.save_template(SessionTemplate(name="delete-me", task_type=TaskType.CODE, examples=[]))
assert tmp_templates.delete_template("delete-me") is True
assert len(tmp_templates.list_templates()) == 0
def test_delete_nonexistent(self, tmp_templates):
assert tmp_templates.delete_template("nope") is False
def test_get_template_returns_best(self, tmp_templates):
tmp_templates.save_template(SessionTemplate(
name="low-usage", task_type=TaskType.CODE, examples=[], usage_count=1,
))
tmp_templates.save_template(SessionTemplate(
name="high-usage", task_type=TaskType.CODE, examples=[], usage_count=5,
))
best = tmp_templates.get_template(TaskType.CODE)
assert best.name == "high-usage"
def test_get_template_returns_none_if_empty(self, tmp_templates):
assert tmp_templates.get_template(TaskType.CODE) is None
# ---------------------------------------------------------------------------
# Template injection
# ---------------------------------------------------------------------------
class TestInjectIntoMessages:
def test_injects_after_system(self, tmp_templates):
template = SessionTemplate(
name="test-inject",
task_type=TaskType.CODE,
examples=[
ToolCallExample(
tool_name="execute_code",
args={"code": "x=1"},
result_preview="1",
success=True,
),
],
)
messages = [
{"role": "system", "content": "You are Timmy."},
{"role": "user", "content": "Hello"},
]
result = tmp_templates.inject_into_messages(template, messages)
# Should have: system, template system note, assistant tool call, tool result, user
assert len(result) == 5
assert result[0]["role"] == "system"
assert "Session Template" in result[1]["content"]
assert result[2]["role"] == "assistant"
assert result[3]["role"] == "tool"
assert result[4]["role"] == "user"
def test_skips_failed_examples(self, tmp_templates):
template = SessionTemplate(
name="test-fail",
task_type=TaskType.CODE,
examples=[
ToolCallExample(tool_name="execute_code", args={}, success=False),
ToolCallExample(tool_name="read_file", args={"path": "x"}, success=True),
],
)
messages = [{"role": "system", "content": "sys"}]
result = tmp_templates.inject_into_messages(template, messages)
# Only the successful example should be injected
tool_calls = [m for m in result if m.get("role") == "assistant" and m.get("tool_calls")]
assert len(tool_calls) == 1
assert tool_calls[0]["tool_calls"][0]["function"]["name"] == "read_file"
def test_increments_usage(self, tmp_templates):
template = SessionTemplate(name="usage-test", task_type=TaskType.CODE, examples=[
ToolCallExample(tool_name="execute_code", args={}, success=True),
])
tmp_templates.save_template(template)
tmp_templates.inject_into_messages(template, [{"role": "system", "content": "x"}])
assert template.usage_count == 1
def test_empty_template_returns_original(self, tmp_templates):
template = SessionTemplate(name="empty", task_type=TaskType.CODE, examples=[])
messages = [{"role": "user", "content": "hi"}]
result = tmp_templates.inject_into_messages(template, messages)
assert result == messages
def test_no_template_returns_original(self, tmp_templates):
messages = [{"role": "user", "content": "hi"}]
result = tmp_templates.inject_into_messages(None, messages)
assert result == messages

View File

@@ -1,418 +0,0 @@
"""
Session templates for code-first seeding.
Research finding: Code-heavy sessions (execute_code dominant in first 30 turns)
improve over time. File-heavy sessions degrade. The key is deterministic
feedback loops, not arbitrary context.
This module provides:
1. Task type classification (CODE, FILE, RESEARCH, MIXED)
2. Template extraction from completed sessions
3. Template storage (~/.hermes/session-templates/)
4. Template injection into new sessions
5. CLI interface for template management
Closes #329.
"""
from __future__ import annotations
import json
import os
import sqlite3
from dataclasses import asdict, dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Optional
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
HERMES_HOME = Path(os.environ.get("HERMES_HOME", str(Path.home() / ".hermes")))
TEMPLATES_DIR = HERMES_HOME / "session-templates"
SESSIONS_DB = HERMES_HOME / "state.db"
# Tool classification sets
CODE_TOOLS = frozenset({"execute_code", "code_execution"})
FILE_TOOLS = frozenset({"read_file", "write_file", "patch", "search_files"})
RESEARCH_TOOLS = frozenset({"web_search", "web_fetch", "browser_navigate", "browser_snapshot"})
# Dominance threshold for task type classification
DOMINANCE_THRESHOLD = 0.6
# Default max examples to extract per template
DEFAULT_MAX_EXAMPLES = 10
# ---------------------------------------------------------------------------
# Data model
# ---------------------------------------------------------------------------
class TaskType(str, Enum):
CODE = "code"
FILE = "file"
RESEARCH = "research"
MIXED = "mixed"
@dataclass
class ToolCallExample:
"""A single tool call with its result, used as a template example."""
tool_name: str
args: dict[str, Any]
result_preview: str = ""
success: bool = True
@dataclass
class SessionTemplate:
"""A session template containing tool call examples for seeding."""
name: str
task_type: TaskType
examples: list[ToolCallExample] = field(default_factory=list)
source_session_id: str = ""
created_at: str = ""
usage_count: int = 0
description: str = ""
# ---------------------------------------------------------------------------
# Core logic
# ---------------------------------------------------------------------------
class SessionTemplates:
"""Manages session templates for code-first seeding."""
def __init__(self, templates_dir: Optional[Path] = None):
self.templates_dir = templates_dir or TEMPLATES_DIR
self.templates_dir.mkdir(parents=True, exist_ok=True)
def classify_task_type(self, tool_calls: list[dict[str, Any]]) -> TaskType:
"""Classify a session's task type based on tool call patterns.
Uses 60% threshold for dominant type.
"""
if not tool_calls:
return TaskType.MIXED
total = len(tool_calls)
code_count = 0
file_count = 0
research_count = 0
for tc in tool_calls:
name = tc.get("name", tc.get("tool_name", "")).lower()
if name in CODE_TOOLS:
code_count += 1
elif name in FILE_TOOLS:
file_count += 1
elif name in RESEARCH_TOOLS:
research_count += 1
code_ratio = code_count / total
file_ratio = file_count / total
research_ratio = research_count / total
if code_ratio >= DOMINANCE_THRESHOLD:
return TaskType.CODE
if file_ratio >= DOMINANCE_THRESHOLD:
return TaskType.FILE
if research_ratio >= DOMINANCE_THRESHOLD:
return TaskType.RESEARCH
return TaskType.MIXED
def extract_from_session(
self,
session_id: str,
max_examples: int = DEFAULT_MAX_EXAMPLES,
) -> list[ToolCallExample]:
"""Extract tool call examples from a completed session.
Reads from the SQLite session database.
"""
examples: list[ToolCallExample] = []
db_path = SESSIONS_DB
if not db_path.exists():
return examples
try:
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
rows = conn.execute(
"SELECT messages FROM sessions WHERE session_id = ? ORDER BY created_at DESC LIMIT 1",
(session_id,),
).fetchone()
if not rows:
conn.close()
return examples
messages = json.loads(rows["messages"])
# Extract tool calls from assistant messages
for msg in messages:
if msg.get("role") != "assistant":
continue
tool_calls = msg.get("tool_calls", [])
if not tool_calls:
continue
for tc in tool_calls:
if len(examples) >= max_examples:
break
fn = tc.get("function", {})
name = fn.get("name", "")
if not name:
continue
try:
args = json.loads(fn.get("arguments", "{}"))
except (json.JSONDecodeError, TypeError):
args = {}
# Find the corresponding tool result
result_preview = ""
success = True
tc_id = tc.get("id", "")
for result_msg in messages:
if (result_msg.get("role") == "tool"
and result_msg.get("tool_call_id") == tc_id):
content = result_msg.get("content", "")
result_preview = str(content)[:200]
# Heuristic: errors contain common failure markers
if any(marker in result_preview.lower() for marker in ("error", "failed", "traceback", "exception")):
success = False
break
examples.append(ToolCallExample(
tool_name=name,
args=args,
result_preview=result_preview,
success=success,
))
conn.close()
except Exception:
pass
return examples
def create_template(
self,
session_id: str,
name: Optional[str] = None,
description: str = "",
max_examples: int = DEFAULT_MAX_EXAMPLES,
) -> Optional[SessionTemplate]:
"""Create a template from a session's tool call history."""
examples = self.extract_from_session(session_id, max_examples)
if not examples:
return None
tool_calls_for_type = [{"name": e.tool_name} for e in examples]
task_type = self.classify_task_type(tool_calls_for_type)
template_name = name or f"{task_type.value}_{session_id[:8]}"
from datetime import datetime
template = SessionTemplate(
name=template_name,
task_type=task_type,
examples=examples,
source_session_id=session_id,
created_at=datetime.utcnow().isoformat() + "Z",
description=description or f"Auto-extracted from {session_id}",
)
self.save_template(template)
return template
def save_template(self, template: SessionTemplate) -> Path:
"""Save a template to disk."""
path = self.templates_dir / f"{template.name}.json"
data = {
"name": template.name,
"task_type": template.task_type.value,
"examples": [asdict(e) for e in template.examples],
"source_session_id": template.source_session_id,
"created_at": template.created_at,
"usage_count": template.usage_count,
"description": template.description,
}
path.write_text(json.dumps(data, indent=2, sort_keys=True) + "\n")
return path
def get_template(self, task_type: TaskType) -> Optional[SessionTemplate]:
"""Get the best template for a given task type."""
templates = self.list_templates(task_type)
if not templates:
return None
# Prefer templates with more usage (proven useful)
templates.sort(key=lambda t: t.usage_count, reverse=True)
return templates[0]
def list_templates(self, task_type: Optional[TaskType] = None) -> list[SessionTemplate]:
"""List all templates, optionally filtered by type."""
templates: list[SessionTemplate] = []
for path in sorted(self.templates_dir.glob("*.json")):
try:
data = json.loads(path.read_text())
examples = [ToolCallExample(**e) for e in data.get("examples", [])]
template = SessionTemplate(
name=data["name"],
task_type=TaskType(data["task_type"]),
examples=examples,
source_session_id=data.get("source_session_id", ""),
created_at=data.get("created_at", ""),
usage_count=data.get("usage_count", 0),
description=data.get("description", ""),
)
if task_type is None or template.task_type == task_type:
templates.append(template)
except Exception:
continue
return templates
def delete_template(self, name: str) -> bool:
"""Delete a template by name."""
path = self.templates_dir / f"{name}.json"
if path.exists():
path.unlink()
return True
return False
def inject_into_messages(
self,
template: SessionTemplate,
messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Inject template examples into a session's messages.
Inserts tool call examples after system messages to establish
feedback loops early.
"""
if not template or not template.examples:
return messages
# Build injection messages
injection: list[dict[str, Any]] = []
# System note about the template
injection.append({
"role": "system",
"content": (
f"[Session Template: '{template.name}' ({template.task_type.value})]\n"
f"The following are examples of successful tool calls from a similar session. "
f"Use them as patterns for your own tool usage."
),
})
# Add example tool call/result pairs
for ex in template.examples:
if not ex.success:
continue # Only inject successful examples
injection.append({
"role": "assistant",
"content": None,
"tool_calls": [{
"id": f"template_{template.name}_{ex.tool_name}",
"type": "function",
"function": {
"name": ex.tool_name,
"arguments": json.dumps(ex.args),
},
}],
})
injection.append({
"role": "tool",
"tool_call_id": f"template_{template.name}_{ex.tool_name}",
"content": ex.result_preview or "(example result)",
})
# Find insertion point: after system messages
insert_idx = 0
for i, msg in enumerate(messages):
if msg.get("role") == "system":
insert_idx = i + 1
else:
break
# Insert
result = messages[:insert_idx] + injection + messages[insert_idx:]
# Update usage count
template.usage_count += 1
self.save_template(template)
return result
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _cli():
"""Simple CLI for session template management."""
import argparse
import sys
parser = argparse.ArgumentParser(description="Session template management")
sub = parser.add_subparsers(dest="command")
# list
list_cmd = sub.add_parser("list", help="List templates")
list_cmd.add_argument("--type", choices=["code", "file", "research", "mixed"])
# create
create_cmd = sub.add_parser("create", help="Create template from session")
create_cmd.add_argument("session_id", help="Session ID to extract from")
create_cmd.add_argument("--name", help="Template name")
create_cmd.add_argument("--max-examples", type=int, default=10)
# delete
delete_cmd = sub.add_parser("delete", help="Delete template")
delete_cmd.add_argument("name", help="Template name")
args = parser.parse_args()
tm = SessionTemplates()
if args.command == "list":
task_type = TaskType(args.type) if args.type else None
templates = tm.list_templates(task_type)
if not templates:
print("No templates found.")
return
for t in templates:
print(f" {t.name:30s} {t.task_type.value:10s} {len(t.examples)} examples, used {t.usage_count}x")
elif args.command == "create":
template = tm.create_template(args.session_id, name=args.name, max_examples=args.max_examples)
if template:
print(f"Created template: {template.name} ({template.task_type.value}, {len(template.examples)} examples)")
else:
print(f"No tool calls found in session {args.session_id}")
sys.exit(1)
elif args.command == "delete":
if tm.delete_template(args.name):
print(f"Deleted template: {args.name}")
else:
print(f"Template not found: {args.name}")
sys.exit(1)
else:
parser.print_help()
if __name__ == "__main__":
_cli()