From f4a75baca0f4e027dc21e061769be5f2ebf64dd8 Mon Sep 17 00:00:00 2001 From: Alexander Whitestone Date: Tue, 14 Apr 2026 11:41:25 +0000 Subject: [PATCH] test(templates): Add tests for session templates Resolves #329 --- tests/test_session_templates.py | 92 +++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 tests/test_session_templates.py diff --git a/tests/test_session_templates.py b/tests/test_session_templates.py new file mode 100644 index 000000000..76d03c655 --- /dev/null +++ b/tests/test_session_templates.py @@ -0,0 +1,92 @@ +"""Tests for session templates.""" +import json +import pytest +import tempfile +from pathlib import Path +from tools.session_templates import Templates, Template, ToolExample, TaskType + + +class TestClassification: + def test_code(self): + t = Templates() + assert t.classify([{"tool_name": "execute_code"}] * 4 + [{"tool_name": "read_file"}]) == TaskType.CODE + + def test_file(self): + t = Templates() + assert t.classify([{"tool_name": "read_file"}] * 4 + [{"tool_name": "execute_code"}]) == TaskType.FILE + + def test_mixed(self): + t = Templates() + assert t.classify([{"tool_name": "execute_code"}, {"tool_name": "read_file"}]) == TaskType.MIXED + + +class TestToolExample: + def test_roundtrip(self): + e = ToolExample("execute_code", {"code": "print('hi')"}, "hi", True, 0) + d = e.to_dict() + e2 = ToolExample.from_dict(d) + assert e2.tool_name == "execute_code" + assert e2.result == "hi" + + +class TestTemplate: + def test_roundtrip(self): + examples = [ToolExample("execute_code", {}, "result", True)] + t = Template("test", TaskType.CODE, examples) + d = t.to_dict() + t2 = Template.from_dict(d) + assert t2.name == "test" + assert t2.task_type == TaskType.CODE + + +class TestTemplates: + def test_create_list(self): + with tempfile.TemporaryDirectory() as d: + ts = Templates(Path(d)) + examples = [ToolExample("execute_code", {"code": "print('hi')"}, "hi", True)] + t = Template("test", TaskType.CODE, examples) + ts.templates["test"] = t + ts._save(t) + assert len(ts.list()) == 1 + + def test_get(self): + with tempfile.TemporaryDirectory() as d: + ts = Templates(Path(d)) + ts.templates["code"] = Template("code", TaskType.CODE, []) + ts.templates["file"] = Template("file", TaskType.FILE, []) + assert ts.get(TaskType.CODE).name == "code" + assert ts.get(TaskType.FILE).name == "file" + assert ts.get(TaskType.RESEARCH) is None + + def test_inject(self): + with tempfile.TemporaryDirectory() as d: + ts = Templates(Path(d)) + examples = [ToolExample("execute_code", {"code": "print('hi')"}, "hi", True)] + t = Template("test", TaskType.CODE, examples) + ts.templates["test"] = t + messages = [{"role": "system", "content": "You are helpful."}] + result = ts.inject(t, messages) + assert len(result) > len(messages) + assert t.used == 1 + + def test_delete(self): + with tempfile.TemporaryDirectory() as d: + ts = Templates(Path(d)) + t = Template("test", TaskType.CODE, []) + ts.templates["test"] = t + ts._save(t) + assert ts.delete("test") + assert "test" not in ts.templates + + def test_stats(self): + with tempfile.TemporaryDirectory() as d: + ts = Templates(Path(d)) + ts.templates["a"] = Template("a", TaskType.CODE, [ToolExample("x", {}, "", True)]) + ts.templates["b"] = Template("b", TaskType.FILE, [ToolExample("y", {}, "", True)]) + s = ts.stats() + assert s["total"] == 2 + assert s["examples"] == 2 + + +if __name__ == "__main__": + pytest.main([__file__])