Some checks failed
Forge CI / smoke-and-build (pull_request) Failing after 33s
Resolves #329
93 lines
3.3 KiB
Python
93 lines
3.3 KiB
Python
"""Tests for session templates."""
|
|
import json
|
|
import pytest
|
|
import tempfile
|
|
from pathlib import Path
|
|
from tools.session_templates import Templates, Template, ToolExample, TaskType
|
|
|
|
|
|
class TestClassification:
    """Tests for ``Templates.classify``: mapping a tool-call history to a TaskType."""

    @staticmethod
    def _classify(calls):
        """Classify *calls* using a ``Templates`` store rooted in a throwaway directory.

        The other suites in this file build ``Templates`` on a temporary
        directory; doing the same here keeps these tests away from whatever
        default location ``Templates()`` would otherwise use (not visible in
        this file — presumably a real on-disk directory).
        """
        with tempfile.TemporaryDirectory() as d:
            return Templates(Path(d)).classify(calls)

    def test_code(self):
        # Four execute_code calls against one read_file -> CODE.
        calls = [{"tool_name": "execute_code"}] * 4 + [{"tool_name": "read_file"}]
        assert self._classify(calls) == TaskType.CODE

    def test_file(self):
        # Four read_file calls against one execute_code -> FILE.
        calls = [{"tool_name": "read_file"}] * 4 + [{"tool_name": "execute_code"}]
        assert self._classify(calls) == TaskType.FILE

    def test_mixed(self):
        # One call of each kind -> MIXED.
        calls = [{"tool_name": "execute_code"}, {"tool_name": "read_file"}]
        assert self._classify(calls) == TaskType.MIXED
|
|
|
|
|
|
class TestToolExample:
    """Serialization tests for ``ToolExample``."""

    def test_roundtrip(self):
        """``to_dict``/``from_dict`` preserve every serialized field."""
        e = ToolExample("execute_code", {"code": "print('hi')"}, "hi", True, 0)
        d = e.to_dict()
        # Pass through JSON text so the test also fails if to_dict() emits a
        # value json cannot encode. Assumes templates are persisted as JSON —
        # TODO confirm against Templates._save.
        e2 = ToolExample.from_dict(json.loads(json.dumps(d)))
        assert e2.tool_name == "execute_code"
        assert e2.result == "hi"
        # Full-dict comparison covers the fields not checked by name above.
        assert e2.to_dict() == d
|
|
|
|
|
|
class TestTemplate:
    """Serialization tests for ``Template``."""

    def test_roundtrip(self):
        """``to_dict``/``from_dict`` preserve name, task type and examples."""
        examples = [ToolExample("execute_code", {}, "result", True)]
        t = Template("test", TaskType.CODE, examples)
        d = t.to_dict()
        t2 = Template.from_dict(d)
        assert t2.name == "test"
        assert t2.task_type == TaskType.CODE
        # Name and type alone would pass even if the examples were dropped;
        # comparing the re-serialized dict checks they survived too.
        assert t2.to_dict() == d
|
|
|
|
|
|
class TestTemplates:
    """Tests for the ``Templates`` store: listing, lookup, injection, deletion, stats.

    Each test roots the store in its own ``TemporaryDirectory`` so nothing is
    written outside the test.
    """

    def test_create_list(self):
        """A saved template shows up in ``list()``."""
        with tempfile.TemporaryDirectory() as d:
            ts = Templates(Path(d))
            examples = [ToolExample("execute_code", {"code": "print('hi')"}, "hi", True)]
            t = Template("test", TaskType.CODE, examples)
            ts.templates["test"] = t
            ts._save(t)
            assert len(ts.list()) == 1

    def test_get(self):
        """``get`` looks templates up by task type, returning None when absent."""
        with tempfile.TemporaryDirectory() as d:
            ts = Templates(Path(d))
            ts.templates["code"] = Template("code", TaskType.CODE, [])
            ts.templates["file"] = Template("file", TaskType.FILE, [])
            assert ts.get(TaskType.CODE).name == "code"
            assert ts.get(TaskType.FILE).name == "file"
            # No RESEARCH template was registered.
            assert ts.get(TaskType.RESEARCH) is None

    def test_inject(self):
        """``inject`` adds messages and bumps the template's usage counter."""
        with tempfile.TemporaryDirectory() as d:
            ts = Templates(Path(d))
            examples = [ToolExample("execute_code", {"code": "print('hi')"}, "hi", True)]
            t = Template("test", TaskType.CODE, examples)
            ts.templates["test"] = t
            messages = [{"role": "system", "content": "You are helpful."}]
            result = ts.inject(t, messages)
            assert len(result) > len(messages)
            assert t.used == 1

    def test_delete(self):
        """``delete`` reports success and removes the template from the store."""
        with tempfile.TemporaryDirectory() as d:
            ts = Templates(Path(d))
            t = Template("test", TaskType.CODE, [])
            ts.templates["test"] = t
            ts._save(t)
            assert ts.delete("test")
            assert "test" not in ts.templates
            # The deletion must also be reflected by list(), which
            # test_create_list shows counts saved templates.
            assert len(ts.list()) == 0

    def test_stats(self):
        """``stats`` counts templates and their examples."""
        with tempfile.TemporaryDirectory() as d:
            ts = Templates(Path(d))
            ts.templates["a"] = Template("a", TaskType.CODE, [ToolExample("x", {}, "", True)])
            ts.templates["b"] = Template("b", TaskType.FILE, [ToolExample("y", {}, "", True)])
            s = ts.stats()
            assert s["total"] == 2
            assert s["examples"] == 2
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns the session's exit code; hand it to the shell so a
    # failing run is reported as a failure when this file is executed directly.
    raise SystemExit(pytest.main([__file__]))
|