test(templates): Add tests for session templates
Some checks failed
Forge CI / smoke-and-build (pull_request) Failing after 1m16s
Some checks failed
Forge CI / smoke-and-build (pull_request) Failing after 1m16s
Resolves #329
This commit is contained in:
92
tests/test_session_templates.py
Normal file
92
tests/test_session_templates.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Tests for session templates."""
|
||||
import json
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from tools.session_templates import Templates, Template, ToolExample, TaskType
|
||||
|
||||
|
||||
class TestClassification:
|
||||
def test_code(self):
|
||||
t = Templates()
|
||||
assert t.classify([{"tool_name": "execute_code"}] * 4 + [{"tool_name": "read_file"}]) == TaskType.CODE
|
||||
|
||||
def test_file(self):
|
||||
t = Templates()
|
||||
assert t.classify([{"tool_name": "read_file"}] * 4 + [{"tool_name": "execute_code"}]) == TaskType.FILE
|
||||
|
||||
def test_mixed(self):
|
||||
t = Templates()
|
||||
assert t.classify([{"tool_name": "execute_code"}, {"tool_name": "read_file"}]) == TaskType.MIXED
|
||||
|
||||
|
||||
class TestToolExample:
|
||||
def test_roundtrip(self):
|
||||
e = ToolExample("execute_code", {"code": "print('hi')"}, "hi", True, 0)
|
||||
d = e.to_dict()
|
||||
e2 = ToolExample.from_dict(d)
|
||||
assert e2.tool_name == "execute_code"
|
||||
assert e2.result == "hi"
|
||||
|
||||
|
||||
class TestTemplate:
|
||||
def test_roundtrip(self):
|
||||
examples = [ToolExample("execute_code", {}, "result", True)]
|
||||
t = Template("test", TaskType.CODE, examples)
|
||||
d = t.to_dict()
|
||||
t2 = Template.from_dict(d)
|
||||
assert t2.name == "test"
|
||||
assert t2.task_type == TaskType.CODE
|
||||
|
||||
|
||||
class TestTemplates:
|
||||
def test_create_list(self):
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
ts = Templates(Path(d))
|
||||
examples = [ToolExample("execute_code", {"code": "print('hi')"}, "hi", True)]
|
||||
t = Template("test", TaskType.CODE, examples)
|
||||
ts.templates["test"] = t
|
||||
ts._save(t)
|
||||
assert len(ts.list()) == 1
|
||||
|
||||
def test_get(self):
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
ts = Templates(Path(d))
|
||||
ts.templates["code"] = Template("code", TaskType.CODE, [])
|
||||
ts.templates["file"] = Template("file", TaskType.FILE, [])
|
||||
assert ts.get(TaskType.CODE).name == "code"
|
||||
assert ts.get(TaskType.FILE).name == "file"
|
||||
assert ts.get(TaskType.RESEARCH) is None
|
||||
|
||||
def test_inject(self):
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
ts = Templates(Path(d))
|
||||
examples = [ToolExample("execute_code", {"code": "print('hi')"}, "hi", True)]
|
||||
t = Template("test", TaskType.CODE, examples)
|
||||
ts.templates["test"] = t
|
||||
messages = [{"role": "system", "content": "You are helpful."}]
|
||||
result = ts.inject(t, messages)
|
||||
assert len(result) > len(messages)
|
||||
assert t.used == 1
|
||||
|
||||
def test_delete(self):
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
ts = Templates(Path(d))
|
||||
t = Template("test", TaskType.CODE, [])
|
||||
ts.templates["test"] = t
|
||||
ts._save(t)
|
||||
assert ts.delete("test")
|
||||
assert "test" not in ts.templates
|
||||
|
||||
def test_stats(self):
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
ts = Templates(Path(d))
|
||||
ts.templates["a"] = Template("a", TaskType.CODE, [ToolExample("x", {}, "", True)])
|
||||
ts.templates["b"] = Template("b", TaskType.FILE, [ToolExample("y", {}, "", True)])
|
||||
s = ts.stats()
|
||||
assert s["total"] == 2
|
||||
assert s["examples"] == 2
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
Reference in New Issue
Block a user