"""Tests for session templates."""

import json
import tempfile
from pathlib import Path

import pytest

from tools.session_templates import Templates, Template, ToolExample, TaskType


class TestClassification:
    """Task-type classification from a list of tool-call records."""

    # NOTE(review): these tests construct Templates() with no directory
    # argument; if the default constructor touches a real on-disk location,
    # consider passing a temporary directory as the other test classes do.

    def test_code(self):
        """Mostly execute_code calls classify as CODE."""
        t = Templates()
        calls = [{"tool_name": "execute_code"}] * 4 + [{"tool_name": "read_file"}]
        assert t.classify(calls) == TaskType.CODE

    def test_file(self):
        """Mostly read_file calls classify as FILE."""
        t = Templates()
        calls = [{"tool_name": "read_file"}] * 4 + [{"tool_name": "execute_code"}]
        assert t.classify(calls) == TaskType.FILE

    def test_mixed(self):
        """An even split between tool kinds classifies as MIXED."""
        t = Templates()
        calls = [{"tool_name": "execute_code"}, {"tool_name": "read_file"}]
        assert t.classify(calls) == TaskType.MIXED


class TestToolExample:
    """Serialization of a single ToolExample."""

    def test_roundtrip(self):
        """to_dict followed by from_dict preserves tool_name and result."""
        e = ToolExample("execute_code", {"code": "print('hi')"}, "hi", True, 0)
        d = e.to_dict()
        e2 = ToolExample.from_dict(d)
        assert e2.tool_name == "execute_code"
        assert e2.result == "hi"


class TestTemplate:
    """Serialization of a Template."""

    def test_roundtrip(self):
        """to_dict followed by from_dict preserves name and task_type."""
        examples = [ToolExample("execute_code", {}, "result", True)]
        t = Template("test", TaskType.CODE, examples)
        d = t.to_dict()
        t2 = Template.from_dict(d)
        assert t2.name == "test"
        assert t2.task_type == TaskType.CODE


class TestTemplates:
    """Collection-level behaviour of Templates backed by a temp directory."""

    def test_create_list(self):
        """A saved template appears in list()."""
        with tempfile.TemporaryDirectory() as d:
            ts = Templates(Path(d))
            examples = [ToolExample("execute_code", {"code": "print('hi')"}, "hi", True)]
            t = Template("test", TaskType.CODE, examples)
            ts.templates["test"] = t
            ts._save(t)
            assert len(ts.list()) == 1

    def test_get(self):
        """get() looks up a template by task type, or returns None if absent."""
        with tempfile.TemporaryDirectory() as d:
            ts = Templates(Path(d))
            ts.templates["code"] = Template("code", TaskType.CODE, [])
            ts.templates["file"] = Template("file", TaskType.FILE, [])
            assert ts.get(TaskType.CODE).name == "code"
            assert ts.get(TaskType.FILE).name == "file"
            assert ts.get(TaskType.RESEARCH) is None

    def test_inject(self):
        """inject() returns a longer message list and bumps the usage counter."""
        with tempfile.TemporaryDirectory() as d:
            ts = Templates(Path(d))
            examples = [ToolExample("execute_code", {"code": "print('hi')"}, "hi", True)]
            t = Template("test", TaskType.CODE, examples)
            ts.templates["test"] = t
            messages = [{"role": "system", "content": "You are helpful."}]
            result = ts.inject(t, messages)
            # Compares against the caller's list, so this also expects inject()
            # not to grow `messages` in place.
            assert len(result) > len(messages)
            assert t.used == 1

    def test_delete(self):
        """delete() reports success and drops the template from memory."""
        with tempfile.TemporaryDirectory() as d:
            ts = Templates(Path(d))
            t = Template("test", TaskType.CODE, [])
            ts.templates["test"] = t
            ts._save(t)
            assert ts.delete("test")
            assert "test" not in ts.templates

    def test_stats(self):
        """stats() counts templates and their examples."""
        with tempfile.TemporaryDirectory() as d:
            ts = Templates(Path(d))
            ts.templates["a"] = Template("a", TaskType.CODE, [ToolExample("x", {}, "", True)])
            ts.templates["b"] = Template("b", TaskType.FILE, [ToolExample("y", {}, "", True)])
            s = ts.stats()
            assert s["total"] == 2
            assert s["examples"] == 2


if __name__ == "__main__":
    # Propagate pytest's exit status so a direct run fails when tests fail.
    raise SystemExit(pytest.main([__file__]))