Some checks failed
Forge CI / smoke-and-build (pull_request) Failing after 1m6s
Code-heavy sessions improve over time. File-heavy sessions degrade. Session templates seed new sessions with successful tool call patterns to establish feedback loops early.

## What

- tools/session_templates.py — task type classification (60% threshold), template extraction from SQLite sessions, storage, injection, CLI
- tests/tools/test_session_templates.py — 17 tests

## Task types

CODE, FILE, RESEARCH, MIXED — classified by tool call frequency.

## CLI

python -m tools.session_templates list
python -m tools.session_templates create <session_id> --name my-template
python -m tools.session_templates delete <name>

Closes #329.
189 lines
7.2 KiB
Python
189 lines
7.2 KiB
Python
"""Tests for session templates (code-first seeding)."""
|
|
|
|
import json
|
|
import tempfile
|
|
from pathlib import Path
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
|
|
from tools.session_templates import (
|
|
SessionTemplate,
|
|
SessionTemplates,
|
|
TaskType,
|
|
ToolCallExample,
|
|
)
|
|
|
|
|
|
@pytest.fixture
def tmp_templates(tmp_path):
    """Build a ``SessionTemplates`` whose ``templates_dir`` is ``tmp_path / "templates"``."""
    storage_dir = tmp_path / "templates"
    return SessionTemplates(templates_dir=storage_dir)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Task type classification
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestClassifyTaskType:
    """Checks for ``classify_task_type`` on lists of ``{"name": <tool>}`` call dicts."""

    def test_code_dominant(self, tmp_templates):
        """3 of 4 calls are ``execute_code`` -> CODE."""
        calls = [
            {"name": "execute_code"}, {"name": "execute_code"},
            {"name": "execute_code"}, {"name": "read_file"},
        ]
        assert tmp_templates.classify_task_type(calls) == TaskType.CODE

    def test_file_dominant(self, tmp_templates):
        """4 of 5 calls are read_file/write_file/patch -> FILE."""
        calls = [
            {"name": "read_file"}, {"name": "write_file"},
            {"name": "patch"}, {"name": "read_file"},
            {"name": "execute_code"},
        ]
        assert tmp_templates.classify_task_type(calls) == TaskType.FILE

    def test_research_dominant(self, tmp_templates):
        """3 of 4 calls are web_search/web_fetch -> RESEARCH."""
        calls = [
            {"name": "web_search"}, {"name": "web_fetch"},
            {"name": "web_search"}, {"name": "read_file"},
        ]
        assert tmp_templates.classify_task_type(calls) == TaskType.RESEARCH

    def test_mixed_no_dominant(self, tmp_templates):
        """One call of each kind: nothing dominates -> MIXED."""
        calls = [
            {"name": "execute_code"}, {"name": "read_file"},
            {"name": "web_search"},
        ]
        assert tmp_templates.classify_task_type(calls) == TaskType.MIXED

    def test_empty_returns_mixed(self, tmp_templates):
        """An empty call list is classified as MIXED."""
        assert tmp_templates.classify_task_type([]) == TaskType.MIXED

    def test_threshold_is_60_percent(self, tmp_templates):
        """Pin the boundary: below 60% is MIXED, exactly 60% is CODE."""
        # 5/9 is ~56% code: clearly under the threshold -> MIXED
        calls = [{"name": "execute_code"}] * 5 + [{"name": "read_file"}] * 4
        assert tmp_templates.classify_task_type(calls) == TaskType.MIXED

        # 59/100 is exactly 59% code: just under the threshold -> still MIXED.
        # (The 5/9 case above was previously mislabelled as 59%; this case
        # actually exercises the value immediately below the boundary.)
        calls = [{"name": "execute_code"}] * 59 + [{"name": "read_file"}] * 41
        assert tmp_templates.classify_task_type(calls) == TaskType.MIXED

        # 6/10 is exactly 60% code: the threshold is inclusive -> CODE
        calls = [{"name": "execute_code"}] * 6 + [{"name": "read_file"}] * 4
        assert tmp_templates.classify_task_type(calls) == TaskType.CODE
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Template CRUD
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestTemplateCRUD:
    """Save / list / delete / lookup checks against the ``tmp_templates`` fixture."""

    def test_save_and_list(self, tmp_templates):
        """A saved template comes back from ``list_templates`` with name and type intact."""
        example = ToolCallExample(
            tool_name="execute_code",
            args={"code": "print('hi')"},
            success=True,
        )
        saved = SessionTemplate(
            name="test-code",
            task_type=TaskType.CODE,
            examples=[example],
            created_at="2026-01-01T00:00:00Z",
        )
        tmp_templates.save_template(saved)

        listed = tmp_templates.list_templates()
        assert len(listed) == 1
        only = listed[0]
        assert only.name == "test-code"
        assert only.task_type == TaskType.CODE

    def test_list_filter_by_type(self, tmp_templates):
        """Passing a task type to ``list_templates`` returns only that type."""
        for name, kind in (("t1", TaskType.CODE), ("t2", TaskType.FILE)):
            tmp_templates.save_template(
                SessionTemplate(name=name, task_type=kind, examples=[])
            )

        code_only = tmp_templates.list_templates(TaskType.CODE)
        assert len(code_only) == 1
        assert code_only[0].name == "t1"

    def test_delete(self, tmp_templates):
        """Deleting a saved template returns True and leaves the listing empty."""
        doomed = SessionTemplate(name="delete-me", task_type=TaskType.CODE, examples=[])
        tmp_templates.save_template(doomed)

        removed = tmp_templates.delete_template("delete-me")
        assert removed is True

        remaining = tmp_templates.list_templates()
        assert len(remaining) == 0

    def test_delete_nonexistent(self, tmp_templates):
        """Deleting an unknown name returns False."""
        outcome = tmp_templates.delete_template("nope")
        assert outcome is False

    def test_get_template_returns_best(self, tmp_templates):
        """With two CODE templates, ``get_template`` picks the higher ``usage_count``."""
        for name, uses in (("low-usage", 1), ("high-usage", 5)):
            tmp_templates.save_template(SessionTemplate(
                name=name, task_type=TaskType.CODE, examples=[], usage_count=uses,
            ))

        chosen = tmp_templates.get_template(TaskType.CODE)
        assert chosen.name == "high-usage"

    def test_get_template_returns_none_if_empty(self, tmp_templates):
        """With nothing saved, ``get_template`` returns None."""
        found = tmp_templates.get_template(TaskType.CODE)
        assert found is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Template injection
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestInjectIntoMessages:
    """Checks for ``inject_into_messages`` on small message lists."""

    def test_injects_after_system(self, tmp_templates):
        """One successful example expands a 2-message list to 5 messages."""
        example = ToolCallExample(
            tool_name="execute_code",
            args={"code": "x=1"},
            result_preview="1",
            success=True,
        )
        template = SessionTemplate(
            name="test-inject",
            task_type=TaskType.CODE,
            examples=[example],
        )
        system_msg = {"role": "system", "content": "You are Timmy."}
        user_msg = {"role": "user", "content": "Hello"}

        result = tmp_templates.inject_into_messages(template, [system_msg, user_msg])

        # Expected layout: system, template system note, assistant tool call,
        # tool result, user
        assert len(result) == 5
        assert result[0]["role"] == "system"
        assert "Session Template" in result[1]["content"]
        for position, expected_role in ((2, "assistant"), (3, "tool"), (4, "user")):
            assert result[position]["role"] == expected_role

    def test_skips_failed_examples(self, tmp_templates):
        """Examples with ``success=False`` produce no assistant tool call."""
        failed = ToolCallExample(tool_name="execute_code", args={}, success=False)
        succeeded = ToolCallExample(tool_name="read_file", args={"path": "x"}, success=True)
        template = SessionTemplate(
            name="test-fail",
            task_type=TaskType.CODE,
            examples=[failed, succeeded],
        )

        result = tmp_templates.inject_into_messages(
            template, [{"role": "system", "content": "sys"}]
        )

        # Only the successful example should be injected
        assistant_calls = [
            msg for msg in result
            if msg.get("role") == "assistant" and msg.get("tool_calls")
        ]
        assert len(assistant_calls) == 1
        first_call = assistant_calls[0]["tool_calls"][0]
        assert first_call["function"]["name"] == "read_file"

    def test_increments_usage(self, tmp_templates):
        """Injecting a saved template once leaves its ``usage_count`` at 1."""
        example = ToolCallExample(tool_name="execute_code", args={}, success=True)
        template = SessionTemplate(
            name="usage-test", task_type=TaskType.CODE, examples=[example],
        )
        tmp_templates.save_template(template)

        tmp_templates.inject_into_messages(template, [{"role": "system", "content": "x"}])
        assert template.usage_count == 1

    def test_empty_template_returns_original(self, tmp_templates):
        """A template with no examples yields a result equal to the input messages."""
        messages = [{"role": "user", "content": "hi"}]
        blank = SessionTemplate(name="empty", task_type=TaskType.CODE, examples=[])

        result = tmp_templates.inject_into_messages(blank, messages)
        assert result == messages

    def test_no_template_returns_original(self, tmp_templates):
        """Passing ``None`` as the template yields a result equal to the input messages."""
        messages = [{"role": "user", "content": "hi"}]

        result = tmp_templates.inject_into_messages(None, messages)
        assert result == messages
|