fix: resolve merge conflict with main in clipboard.py
This commit is contained in:
@@ -176,3 +176,93 @@ class TestCompressWithClient:
|
||||
contents = [m.get("content", "") for m in result]
|
||||
assert any("CONTEXT SUMMARY" in c for c in contents)
|
||||
assert len(result) < len(msgs)
|
||||
|
||||
def test_summarization_does_not_split_tool_call_pairs(self):
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "[CONTEXT SUMMARY]: compressed middle"
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch("agent.context_compressor.get_model_context_length", return_value=100000), \
|
||||
patch("agent.context_compressor.get_text_auxiliary_client", return_value=(mock_client, "test-model")):
|
||||
c = ContextCompressor(
|
||||
model="test",
|
||||
quiet_mode=True,
|
||||
protect_first_n=3,
|
||||
protect_last_n=4,
|
||||
)
|
||||
|
||||
msgs = [
|
||||
{"role": "user", "content": "Could you address the reviewer comments in PR#71"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{"id": "call_a", "type": "function", "function": {"name": "skill_view", "arguments": "{}"}},
|
||||
{"id": "call_b", "type": "function", "function": {"name": "skill_view", "arguments": "{}"}},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_a", "content": "output a"},
|
||||
{"role": "tool", "tool_call_id": "call_b", "content": "output b"},
|
||||
{"role": "user", "content": "later 1"},
|
||||
{"role": "assistant", "content": "later 2"},
|
||||
{"role": "tool", "tool_call_id": "call_x", "content": "later output"},
|
||||
{"role": "assistant", "content": "later 3"},
|
||||
{"role": "user", "content": "later 4"},
|
||||
]
|
||||
|
||||
result = c.compress(msgs)
|
||||
|
||||
answered_ids = {
|
||||
msg.get("tool_call_id")
|
||||
for msg in result
|
||||
if msg.get("role") == "tool" and msg.get("tool_call_id")
|
||||
}
|
||||
for msg in result:
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
for tc in msg["tool_calls"]:
|
||||
assert tc["id"] in answered_ids
|
||||
|
||||
def test_summarization_does_not_start_tail_with_tool_outputs(self):
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "[CONTEXT SUMMARY]: compressed middle"
|
||||
mock_client.chat.completions.create.return_value = mock_response
|
||||
|
||||
with patch("agent.context_compressor.get_model_context_length", return_value=100000), \
|
||||
patch("agent.context_compressor.get_text_auxiliary_client", return_value=(mock_client, "test-model")):
|
||||
c = ContextCompressor(
|
||||
model="test",
|
||||
quiet_mode=True,
|
||||
protect_first_n=2,
|
||||
protect_last_n=3,
|
||||
)
|
||||
|
||||
msgs = [
|
||||
{"role": "user", "content": "earlier 1"},
|
||||
{"role": "assistant", "content": "earlier 2"},
|
||||
{"role": "user", "content": "earlier 3"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{"id": "call_c", "type": "function", "function": {"name": "search_files", "arguments": "{}"}},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_c", "content": "output c"},
|
||||
{"role": "user", "content": "latest user"},
|
||||
]
|
||||
|
||||
result = c.compress(msgs)
|
||||
|
||||
called_ids = {
|
||||
tc["id"]
|
||||
for msg in result
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls")
|
||||
for tc in msg["tool_calls"]
|
||||
}
|
||||
for msg in result:
|
||||
if msg.get("role") == "tool" and msg.get("tool_call_id"):
|
||||
assert msg["tool_call_id"] in called_ids
|
||||
|
||||
180
tests/gateway/test_async_memory_flush.py
Normal file
180
tests/gateway/test_async_memory_flush.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""Tests for proactive memory flush on session expiry.
|
||||
|
||||
Verifies that:
|
||||
1. _is_session_expired() works from a SessionEntry alone (no source needed)
|
||||
2. The sync callback is no longer called in get_or_create_session
|
||||
3. _pre_flushed_sessions tracking works correctly
|
||||
4. The background watcher can detect expired sessions
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from gateway.config import Platform, GatewayConfig, SessionResetPolicy
|
||||
from gateway.session import SessionSource, SessionStore, SessionEntry
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def idle_store(tmp_path):
|
||||
"""SessionStore with a 60-minute idle reset policy."""
|
||||
config = GatewayConfig(
|
||||
default_reset_policy=SessionResetPolicy(mode="idle", idle_minutes=60),
|
||||
)
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
s = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
s._db = None
|
||||
s._loaded = True
|
||||
return s
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def no_reset_store(tmp_path):
|
||||
"""SessionStore with no reset policy (mode=none)."""
|
||||
config = GatewayConfig(
|
||||
default_reset_policy=SessionResetPolicy(mode="none"),
|
||||
)
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
s = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
s._db = None
|
||||
s._loaded = True
|
||||
return s
|
||||
|
||||
|
||||
class TestIsSessionExpired:
|
||||
"""_is_session_expired should detect expiry from entry alone."""
|
||||
|
||||
def test_idle_session_expired(self, idle_store):
|
||||
entry = SessionEntry(
|
||||
session_key="agent:main:telegram:dm",
|
||||
session_id="sid_1",
|
||||
created_at=datetime.now() - timedelta(hours=3),
|
||||
updated_at=datetime.now() - timedelta(minutes=120),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
assert idle_store._is_session_expired(entry) is True
|
||||
|
||||
def test_active_session_not_expired(self, idle_store):
|
||||
entry = SessionEntry(
|
||||
session_key="agent:main:telegram:dm",
|
||||
session_id="sid_2",
|
||||
created_at=datetime.now() - timedelta(hours=1),
|
||||
updated_at=datetime.now() - timedelta(minutes=10),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
assert idle_store._is_session_expired(entry) is False
|
||||
|
||||
def test_none_mode_never_expires(self, no_reset_store):
|
||||
entry = SessionEntry(
|
||||
session_key="agent:main:telegram:dm",
|
||||
session_id="sid_3",
|
||||
created_at=datetime.now() - timedelta(days=30),
|
||||
updated_at=datetime.now() - timedelta(days=30),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
assert no_reset_store._is_session_expired(entry) is False
|
||||
|
||||
def test_active_processes_prevent_expiry(self, idle_store):
|
||||
"""Sessions with active background processes should never expire."""
|
||||
idle_store._has_active_processes_fn = lambda key: True
|
||||
entry = SessionEntry(
|
||||
session_key="agent:main:telegram:dm",
|
||||
session_id="sid_4",
|
||||
created_at=datetime.now() - timedelta(hours=5),
|
||||
updated_at=datetime.now() - timedelta(hours=5),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
assert idle_store._is_session_expired(entry) is False
|
||||
|
||||
def test_daily_mode_expired(self, tmp_path):
|
||||
"""Daily mode should expire sessions from before today's reset hour."""
|
||||
config = GatewayConfig(
|
||||
default_reset_policy=SessionResetPolicy(mode="daily", at_hour=4),
|
||||
)
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
store._db = None
|
||||
store._loaded = True
|
||||
|
||||
entry = SessionEntry(
|
||||
session_key="agent:main:telegram:dm",
|
||||
session_id="sid_5",
|
||||
created_at=datetime.now() - timedelta(days=2),
|
||||
updated_at=datetime.now() - timedelta(days=2),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
assert store._is_session_expired(entry) is True
|
||||
|
||||
|
||||
class TestGetOrCreateSessionNoCallback:
|
||||
"""get_or_create_session should NOT call a sync flush callback."""
|
||||
|
||||
def test_auto_reset_cleans_pre_flushed_marker(self, idle_store):
|
||||
"""When a session auto-resets, the pre_flushed marker should be discarded."""
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="123",
|
||||
chat_type="dm",
|
||||
)
|
||||
# Create initial session
|
||||
entry1 = idle_store.get_or_create_session(source)
|
||||
old_sid = entry1.session_id
|
||||
|
||||
# Simulate the watcher having flushed it
|
||||
idle_store._pre_flushed_sessions.add(old_sid)
|
||||
|
||||
# Simulate the session going idle
|
||||
entry1.updated_at = datetime.now() - timedelta(minutes=120)
|
||||
idle_store._save()
|
||||
|
||||
# Next call should auto-reset
|
||||
entry2 = idle_store.get_or_create_session(source)
|
||||
assert entry2.session_id != old_sid
|
||||
assert entry2.was_auto_reset is True
|
||||
|
||||
# The old session_id should be removed from pre_flushed
|
||||
assert old_sid not in idle_store._pre_flushed_sessions
|
||||
|
||||
def test_no_sync_callback_invoked(self, idle_store):
|
||||
"""No synchronous callback should block during auto-reset."""
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="123",
|
||||
chat_type="dm",
|
||||
)
|
||||
entry1 = idle_store.get_or_create_session(source)
|
||||
entry1.updated_at = datetime.now() - timedelta(minutes=120)
|
||||
idle_store._save()
|
||||
|
||||
# Verify no _on_auto_reset attribute
|
||||
assert not hasattr(idle_store, '_on_auto_reset')
|
||||
|
||||
# This should NOT block (no sync LLM call)
|
||||
entry2 = idle_store.get_or_create_session(source)
|
||||
assert entry2.was_auto_reset is True
|
||||
|
||||
|
||||
class TestPreFlushedSessionsTracking:
|
||||
"""The _pre_flushed_sessions set should prevent double-flushing."""
|
||||
|
||||
def test_starts_empty(self, idle_store):
|
||||
assert len(idle_store._pre_flushed_sessions) == 0
|
||||
|
||||
def test_add_and_check(self, idle_store):
|
||||
idle_store._pre_flushed_sessions.add("sid_old")
|
||||
assert "sid_old" in idle_store._pre_flushed_sessions
|
||||
assert "sid_other" not in idle_store._pre_flushed_sessions
|
||||
|
||||
def test_discard_on_reset(self, idle_store):
|
||||
"""discard should remove without raising if not present."""
|
||||
idle_store._pre_flushed_sessions.add("sid_a")
|
||||
idle_store._pre_flushed_sessions.discard("sid_a")
|
||||
assert "sid_a" not in idle_store._pre_flushed_sessions
|
||||
# discard on non-existent should not raise
|
||||
idle_store._pre_flushed_sessions.discard("sid_nonexistent")
|
||||
200
tests/gateway/test_resume_command.py
Normal file
200
tests/gateway/test_resume_command.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""Tests for /resume gateway slash command.
|
||||
|
||||
Tests the _handle_resume_command handler (switch to a previously-named session)
|
||||
across gateway messenger platforms.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
|
||||
|
||||
def _make_event(text="/resume", platform=Platform.TELEGRAM,
|
||||
user_id="12345", chat_id="67890"):
|
||||
"""Build a MessageEvent for testing."""
|
||||
source = SessionSource(
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
user_name="testuser",
|
||||
)
|
||||
return MessageEvent(text=text, source=source)
|
||||
|
||||
|
||||
def _session_key_for_event(event):
|
||||
"""Get the session key that build_session_key produces for an event."""
|
||||
return build_session_key(event.source)
|
||||
|
||||
|
||||
def _make_runner(session_db=None, current_session_id="current_session_001",
|
||||
event=None):
|
||||
"""Create a bare GatewayRunner with a mock session_store and optional session_db."""
|
||||
from gateway.run import GatewayRunner
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._session_db = session_db
|
||||
runner._running_agents = {}
|
||||
|
||||
# Compute the real session key if an event is provided
|
||||
session_key = build_session_key(event.source) if event else "agent:main:telegram:dm"
|
||||
|
||||
# Mock session_store that returns a session entry with a known session_id
|
||||
mock_session_entry = MagicMock()
|
||||
mock_session_entry.session_id = current_session_id
|
||||
mock_session_entry.session_key = session_key
|
||||
mock_store = MagicMock()
|
||||
mock_store.get_or_create_session.return_value = mock_session_entry
|
||||
mock_store.load_transcript.return_value = []
|
||||
mock_store.switch_session.return_value = mock_session_entry
|
||||
runner.session_store = mock_store
|
||||
|
||||
# Stub out memory flushing
|
||||
runner._async_flush_memories = AsyncMock()
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _handle_resume_command
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHandleResumeCommand:
|
||||
"""Tests for GatewayRunner._handle_resume_command."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_session_db(self):
|
||||
"""Returns error when session database is unavailable."""
|
||||
runner = _make_runner(session_db=None)
|
||||
event = _make_event(text="/resume My Project")
|
||||
result = await runner._handle_resume_command(event)
|
||||
assert "not available" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_named_sessions_when_no_arg(self, tmp_path):
|
||||
"""With no argument, lists recently titled sessions."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("sess_001", "telegram")
|
||||
db.create_session("sess_002", "telegram")
|
||||
db.set_session_title("sess_001", "Research")
|
||||
db.set_session_title("sess_002", "Coding")
|
||||
|
||||
event = _make_event(text="/resume")
|
||||
runner = _make_runner(session_db=db, event=event)
|
||||
result = await runner._handle_resume_command(event)
|
||||
assert "Research" in result
|
||||
assert "Coding" in result
|
||||
assert "Named Sessions" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_shows_usage_when_no_titled(self, tmp_path):
|
||||
"""With no arg and no titled sessions, shows instructions."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("sess_001", "telegram") # No title
|
||||
|
||||
event = _make_event(text="/resume")
|
||||
runner = _make_runner(session_db=db, event=event)
|
||||
result = await runner._handle_resume_command(event)
|
||||
assert "No named sessions" in result
|
||||
assert "/title" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_by_name(self, tmp_path):
|
||||
"""Resolves a title and switches to that session."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("old_session_abc", "telegram")
|
||||
db.set_session_title("old_session_abc", "My Project")
|
||||
db.create_session("current_session_001", "telegram")
|
||||
|
||||
event = _make_event(text="/resume My Project")
|
||||
runner = _make_runner(session_db=db, current_session_id="current_session_001",
|
||||
event=event)
|
||||
result = await runner._handle_resume_command(event)
|
||||
|
||||
assert "Resumed" in result
|
||||
assert "My Project" in result
|
||||
# Verify switch_session was called with the old session ID
|
||||
runner.session_store.switch_session.assert_called_once()
|
||||
call_args = runner.session_store.switch_session.call_args
|
||||
assert call_args[0][1] == "old_session_abc"
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_nonexistent_name(self, tmp_path):
|
||||
"""Returns error for unknown session name."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("current_session_001", "telegram")
|
||||
|
||||
event = _make_event(text="/resume Nonexistent Session")
|
||||
runner = _make_runner(session_db=db, event=event)
|
||||
result = await runner._handle_resume_command(event)
|
||||
assert "No session found" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_already_on_session(self, tmp_path):
|
||||
"""Returns friendly message when already on the requested session."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("current_session_001", "telegram")
|
||||
db.set_session_title("current_session_001", "Active Project")
|
||||
|
||||
event = _make_event(text="/resume Active Project")
|
||||
runner = _make_runner(session_db=db, current_session_id="current_session_001",
|
||||
event=event)
|
||||
result = await runner._handle_resume_command(event)
|
||||
assert "Already on session" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_auto_lineage(self, tmp_path):
|
||||
"""Asking for 'My Project' when 'My Project #2' exists gets the latest."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("sess_v1", "telegram")
|
||||
db.set_session_title("sess_v1", "My Project")
|
||||
db.create_session("sess_v2", "telegram")
|
||||
db.set_session_title("sess_v2", "My Project #2")
|
||||
db.create_session("current_session_001", "telegram")
|
||||
|
||||
event = _make_event(text="/resume My Project")
|
||||
runner = _make_runner(session_db=db, current_session_id="current_session_001",
|
||||
event=event)
|
||||
result = await runner._handle_resume_command(event)
|
||||
|
||||
assert "Resumed" in result
|
||||
# Should resolve to #2 (latest in lineage)
|
||||
call_args = runner.session_store.switch_session.call_args
|
||||
assert call_args[0][1] == "sess_v2"
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_clears_running_agent(self, tmp_path):
|
||||
"""Switching sessions clears any cached running agent."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("old_session", "telegram")
|
||||
db.set_session_title("old_session", "Old Work")
|
||||
db.create_session("current_session_001", "telegram")
|
||||
|
||||
event = _make_event(text="/resume Old Work")
|
||||
runner = _make_runner(session_db=db, current_session_id="current_session_001",
|
||||
event=event)
|
||||
# Simulate a running agent using the real session key
|
||||
real_key = _session_key_for_event(event)
|
||||
runner._running_agents[real_key] = MagicMock()
|
||||
|
||||
await runner._handle_resume_command(event)
|
||||
|
||||
assert real_key not in runner._running_agents
|
||||
db.close()
|
||||
335
tests/gateway/test_send_image_file.py
Normal file
335
tests/gateway/test_send_image_file.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
Tests for send_image_file() on Telegram, Discord, and Slack platforms,
|
||||
and MEDIA: .png extraction/routing in the base platform adapter.
|
||||
|
||||
Covers: local image file sending, file-not-found handling, fallback on error,
|
||||
MEDIA: tag extraction for image extensions, and routing to send_image_file.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, SendResult
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MEDIA: extraction tests for image files
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExtractMediaImages:
|
||||
"""Test that MEDIA: tags with image extensions are correctly extracted."""
|
||||
|
||||
def test_png_image_extracted(self):
|
||||
content = "Here is the screenshot:\nMEDIA:/home/user/.hermes/browser_screenshots/shot.png"
|
||||
media, cleaned = BasePlatformAdapter.extract_media(content)
|
||||
assert len(media) == 1
|
||||
assert media[0][0] == "/home/user/.hermes/browser_screenshots/shot.png"
|
||||
assert "MEDIA:" not in cleaned
|
||||
assert "Here is the screenshot" in cleaned
|
||||
|
||||
def test_jpg_image_extracted(self):
|
||||
content = "MEDIA:/tmp/photo.jpg"
|
||||
media, cleaned = BasePlatformAdapter.extract_media(content)
|
||||
assert len(media) == 1
|
||||
assert media[0][0] == "/tmp/photo.jpg"
|
||||
|
||||
def test_webp_image_extracted(self):
|
||||
content = "MEDIA:/tmp/image.webp"
|
||||
media, _ = BasePlatformAdapter.extract_media(content)
|
||||
assert len(media) == 1
|
||||
|
||||
def test_mixed_audio_and_image(self):
|
||||
content = "MEDIA:/audio.ogg\nMEDIA:/screenshot.png"
|
||||
media, _ = BasePlatformAdapter.extract_media(content)
|
||||
assert len(media) == 2
|
||||
paths = [m[0] for m in media]
|
||||
assert "/audio.ogg" in paths
|
||||
assert "/screenshot.png" in paths
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Telegram send_image_file tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ensure_telegram_mock():
|
||||
"""Install mock telegram modules so TelegramAdapter can be imported."""
|
||||
if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"):
|
||||
return
|
||||
|
||||
telegram_mod = MagicMock()
|
||||
telegram_mod.ext.ContextTypes.DEFAULT_TYPE = type(None)
|
||||
telegram_mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
|
||||
telegram_mod.constants.ChatType.GROUP = "group"
|
||||
telegram_mod.constants.ChatType.SUPERGROUP = "supergroup"
|
||||
telegram_mod.constants.ChatType.CHANNEL = "channel"
|
||||
telegram_mod.constants.ChatType.PRIVATE = "private"
|
||||
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants"):
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
|
||||
|
||||
_ensure_telegram_mock()
|
||||
|
||||
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
|
||||
|
||||
|
||||
class TestTelegramSendImageFile:
|
||||
@pytest.fixture
|
||||
def adapter(self):
|
||||
config = PlatformConfig(enabled=True, token="fake-token")
|
||||
a = TelegramAdapter(config)
|
||||
a._bot = MagicMock()
|
||||
return a
|
||||
|
||||
def test_sends_local_image_as_photo(self, adapter, tmp_path):
|
||||
"""send_image_file should call bot.send_photo with the opened file."""
|
||||
img = tmp_path / "screenshot.png"
|
||||
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100) # Minimal PNG-like
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.message_id = 42
|
||||
adapter._bot.send_photo = AsyncMock(return_value=mock_msg)
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter.send_image_file(chat_id="12345", image_path=str(img))
|
||||
)
|
||||
assert result.success
|
||||
assert result.message_id == "42"
|
||||
adapter._bot.send_photo.assert_awaited_once()
|
||||
|
||||
# Verify photo arg was a file object (opened in rb mode)
|
||||
call_kwargs = adapter._bot.send_photo.call_args
|
||||
assert call_kwargs.kwargs["chat_id"] == 12345
|
||||
|
||||
def test_returns_error_when_file_missing(self, adapter):
|
||||
"""send_image_file should return error for nonexistent file."""
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter.send_image_file(chat_id="12345", image_path="/nonexistent/image.png")
|
||||
)
|
||||
assert not result.success
|
||||
assert "not found" in result.error
|
||||
|
||||
def test_returns_error_when_not_connected(self, adapter):
|
||||
"""send_image_file should return error when bot is None."""
|
||||
adapter._bot = None
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter.send_image_file(chat_id="12345", image_path="/tmp/img.png")
|
||||
)
|
||||
assert not result.success
|
||||
assert "Not connected" in result.error
|
||||
|
||||
def test_caption_truncated_to_1024(self, adapter, tmp_path):
|
||||
"""Telegram captions have a 1024 char limit."""
|
||||
img = tmp_path / "shot.png"
|
||||
img.write_bytes(b"\x89PNG" + b"\x00" * 50)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.message_id = 1
|
||||
adapter._bot.send_photo = AsyncMock(return_value=mock_msg)
|
||||
|
||||
long_caption = "A" * 2000
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
adapter.send_image_file(chat_id="12345", image_path=str(img), caption=long_caption)
|
||||
)
|
||||
|
||||
call_kwargs = adapter._bot.send_photo.call_args.kwargs
|
||||
assert len(call_kwargs["caption"]) == 1024
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Discord send_image_file tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ensure_discord_mock():
|
||||
"""Install mock discord module so DiscordAdapter can be imported."""
|
||||
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||
return
|
||||
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.Client = MagicMock
|
||||
discord_mod.File = MagicMock
|
||||
|
||||
for name in ("discord", "discord.ext", "discord.ext.commands"):
|
||||
sys.modules.setdefault(name, discord_mod)
|
||||
|
||||
|
||||
_ensure_discord_mock()
|
||||
|
||||
import discord as discord_mod_ref # noqa: E402
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
class TestDiscordSendImageFile:
|
||||
@pytest.fixture
|
||||
def adapter(self):
|
||||
config = PlatformConfig(enabled=True, token="fake-token")
|
||||
a = DiscordAdapter(config)
|
||||
a._client = MagicMock()
|
||||
return a
|
||||
|
||||
def test_sends_local_image_as_attachment(self, adapter, tmp_path):
|
||||
"""send_image_file should create discord.File and send to channel."""
|
||||
img = tmp_path / "screenshot.png"
|
||||
img.write_bytes(b"\x89PNG" + b"\x00" * 50)
|
||||
|
||||
mock_channel = MagicMock()
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.id = 99
|
||||
mock_channel.send = AsyncMock(return_value=mock_msg)
|
||||
adapter._client.get_channel = MagicMock(return_value=mock_channel)
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter.send_image_file(chat_id="67890", image_path=str(img))
|
||||
)
|
||||
assert result.success
|
||||
assert result.message_id == "99"
|
||||
mock_channel.send.assert_awaited_once()
|
||||
|
||||
def test_returns_error_when_file_missing(self, adapter):
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter.send_image_file(chat_id="67890", image_path="/nonexistent.png")
|
||||
)
|
||||
assert not result.success
|
||||
assert "not found" in result.error
|
||||
|
||||
def test_returns_error_when_not_connected(self, adapter):
|
||||
adapter._client = None
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter.send_image_file(chat_id="67890", image_path="/tmp/img.png")
|
||||
)
|
||||
assert not result.success
|
||||
assert "Not connected" in result.error
|
||||
|
||||
def test_handles_missing_channel(self, adapter):
|
||||
adapter._client.get_channel = MagicMock(return_value=None)
|
||||
adapter._client.fetch_channel = AsyncMock(return_value=None)
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter.send_image_file(chat_id="99999", image_path="/tmp/img.png")
|
||||
)
|
||||
assert not result.success
|
||||
assert "not found" in result.error
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slack send_image_file tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ensure_slack_mock():
|
||||
"""Install mock slack_bolt module so SlackAdapter can be imported."""
|
||||
if "slack_bolt" in sys.modules and hasattr(sys.modules["slack_bolt"], "__file__"):
|
||||
return
|
||||
|
||||
slack_mod = MagicMock()
|
||||
for name in ("slack_bolt", "slack_bolt.async_app", "slack_sdk", "slack_sdk.web.async_client"):
|
||||
sys.modules.setdefault(name, slack_mod)
|
||||
|
||||
|
||||
_ensure_slack_mock()
|
||||
|
||||
from gateway.platforms.slack import SlackAdapter # noqa: E402
|
||||
|
||||
|
||||
class TestSlackSendImageFile:
|
||||
@pytest.fixture
|
||||
def adapter(self):
|
||||
config = PlatformConfig(enabled=True, token="xoxb-fake")
|
||||
a = SlackAdapter(config)
|
||||
a._app = MagicMock()
|
||||
return a
|
||||
|
||||
def test_sends_local_image_via_upload(self, adapter, tmp_path):
|
||||
"""send_image_file should call files_upload_v2 with the local path."""
|
||||
img = tmp_path / "screenshot.png"
|
||||
img.write_bytes(b"\x89PNG" + b"\x00" * 50)
|
||||
|
||||
mock_result = MagicMock()
|
||||
adapter._app.client.files_upload_v2 = AsyncMock(return_value=mock_result)
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter.send_image_file(chat_id="C12345", image_path=str(img))
|
||||
)
|
||||
assert result.success
|
||||
adapter._app.client.files_upload_v2.assert_awaited_once()
|
||||
|
||||
call_kwargs = adapter._app.client.files_upload_v2.call_args.kwargs
|
||||
assert call_kwargs["file"] == str(img)
|
||||
assert call_kwargs["filename"] == "screenshot.png"
|
||||
assert call_kwargs["channel"] == "C12345"
|
||||
|
||||
def test_returns_error_when_file_missing(self, adapter):
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter.send_image_file(chat_id="C12345", image_path="/nonexistent.png")
|
||||
)
|
||||
assert not result.success
|
||||
assert "not found" in result.error
|
||||
|
||||
def test_returns_error_when_not_connected(self, adapter):
|
||||
adapter._app = None
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
adapter.send_image_file(chat_id="C12345", image_path="/tmp/img.png")
|
||||
)
|
||||
assert not result.success
|
||||
assert "Not connected" in result.error
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# browser_vision screenshot cleanup tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestScreenshotCleanup:
|
||||
def test_cleanup_removes_old_screenshots(self, tmp_path):
|
||||
"""_cleanup_old_screenshots should remove files older than max_age_hours."""
|
||||
import time
|
||||
from tools.browser_tool import _cleanup_old_screenshots
|
||||
|
||||
# Create a "fresh" file
|
||||
fresh = tmp_path / "browser_screenshot_fresh.png"
|
||||
fresh.write_bytes(b"new")
|
||||
|
||||
# Create an "old" file and backdate its mtime
|
||||
old = tmp_path / "browser_screenshot_old.png"
|
||||
old.write_bytes(b"old")
|
||||
old_time = time.time() - (25 * 3600) # 25 hours ago
|
||||
os.utime(str(old), (old_time, old_time))
|
||||
|
||||
_cleanup_old_screenshots(tmp_path, max_age_hours=24)
|
||||
|
||||
assert fresh.exists(), "Fresh screenshot should not be removed"
|
||||
assert not old.exists(), "Old screenshot should be removed"
|
||||
|
||||
def test_cleanup_ignores_non_screenshot_files(self, tmp_path):
|
||||
"""Only files matching browser_screenshot_*.png should be cleaned."""
|
||||
import time
|
||||
from tools.browser_tool import _cleanup_old_screenshots
|
||||
|
||||
other_file = tmp_path / "important_data.txt"
|
||||
other_file.write_bytes(b"keep me")
|
||||
old_time = time.time() - (48 * 3600)
|
||||
os.utime(str(other_file), (old_time, old_time))
|
||||
|
||||
_cleanup_old_screenshots(tmp_path, max_age_hours=24)
|
||||
|
||||
assert other_file.exists(), "Non-screenshot files should not be touched"
|
||||
|
||||
def test_cleanup_handles_empty_dir(self, tmp_path):
|
||||
"""Cleanup should not fail on empty directory."""
|
||||
from tools.browser_tool import _cleanup_old_screenshots
|
||||
_cleanup_old_screenshots(tmp_path, max_age_hours=24) # Should not raise
|
||||
|
||||
def test_cleanup_handles_nonexistent_dir(self):
|
||||
"""Cleanup should not fail if directory doesn't exist."""
|
||||
from pathlib import Path
|
||||
from tools.browser_tool import _cleanup_old_screenshots
|
||||
_cleanup_old_screenshots(Path("/nonexistent/dir"), max_age_hours=24) # Should not raise
|
||||
159
tests/gateway/test_session_hygiene.py
Normal file
159
tests/gateway/test_session_hygiene.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""Tests for gateway session hygiene — auto-compression of large sessions.
|
||||
|
||||
Verifies that the gateway detects pathologically large transcripts and
|
||||
triggers auto-compression before running the agent. (#628)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
from agent.model_metadata import estimate_messages_tokens_rough
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_history(n_messages: int, content_size: int = 100) -> list:
|
||||
"""Build a fake transcript with n_messages user/assistant pairs."""
|
||||
history = []
|
||||
content = "x" * content_size
|
||||
for i in range(n_messages):
|
||||
role = "user" if i % 2 == 0 else "assistant"
|
||||
history.append({"role": role, "content": content, "timestamp": f"t{i}"})
|
||||
return history
|
||||
|
||||
|
||||
def _make_large_history_tokens(target_tokens: int) -> list:
|
||||
"""Build a history that estimates to roughly target_tokens tokens."""
|
||||
# estimate_messages_tokens_rough counts total chars in str(msg) // 4
|
||||
# Each msg dict has ~60 chars of overhead + content chars
|
||||
# So for N tokens we need roughly N * 4 total chars across all messages
|
||||
target_chars = target_tokens * 4
|
||||
# Each message as a dict string is roughly len(content) + 60 chars
|
||||
msg_overhead = 60
|
||||
# Use 50 messages with appropriately sized content
|
||||
n_msgs = 50
|
||||
content_size = max(10, (target_chars // n_msgs) - msg_overhead)
|
||||
return _make_history(n_msgs, content_size=content_size)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detection threshold tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSessionHygieneThresholds:
|
||||
"""Test that the threshold logic correctly identifies large sessions."""
|
||||
|
||||
def test_small_session_below_thresholds(self):
|
||||
"""A 10-message session should not trigger compression."""
|
||||
history = _make_history(10)
|
||||
msg_count = len(history)
|
||||
approx_tokens = estimate_messages_tokens_rough(history)
|
||||
|
||||
compress_token_threshold = 100_000
|
||||
compress_msg_threshold = 200
|
||||
|
||||
needs_compress = (
|
||||
approx_tokens >= compress_token_threshold
|
||||
or msg_count >= compress_msg_threshold
|
||||
)
|
||||
assert not needs_compress
|
||||
|
||||
def test_large_message_count_triggers(self):
|
||||
"""200+ messages should trigger compression even if tokens are low."""
|
||||
history = _make_history(250, content_size=10)
|
||||
msg_count = len(history)
|
||||
|
||||
compress_msg_threshold = 200
|
||||
needs_compress = msg_count >= compress_msg_threshold
|
||||
assert needs_compress
|
||||
|
||||
def test_large_token_count_triggers(self):
|
||||
"""High token count should trigger compression even if message count is low."""
|
||||
# 50 messages with huge content to exceed 100K tokens
|
||||
history = _make_history(50, content_size=10_000)
|
||||
approx_tokens = estimate_messages_tokens_rough(history)
|
||||
|
||||
compress_token_threshold = 100_000
|
||||
needs_compress = approx_tokens >= compress_token_threshold
|
||||
assert needs_compress
|
||||
|
||||
def test_under_both_thresholds_no_trigger(self):
|
||||
"""Session under both thresholds should not trigger."""
|
||||
history = _make_history(100, content_size=100)
|
||||
msg_count = len(history)
|
||||
approx_tokens = estimate_messages_tokens_rough(history)
|
||||
|
||||
compress_token_threshold = 100_000
|
||||
compress_msg_threshold = 200
|
||||
|
||||
needs_compress = (
|
||||
approx_tokens >= compress_token_threshold
|
||||
or msg_count >= compress_msg_threshold
|
||||
)
|
||||
assert not needs_compress
|
||||
|
||||
def test_custom_thresholds(self):
|
||||
"""Custom thresholds from config should be respected."""
|
||||
history = _make_history(60, content_size=100)
|
||||
msg_count = len(history)
|
||||
|
||||
# Custom lower threshold
|
||||
compress_msg_threshold = 50
|
||||
needs_compress = msg_count >= compress_msg_threshold
|
||||
assert needs_compress
|
||||
|
||||
# Custom higher threshold
|
||||
compress_msg_threshold = 100
|
||||
needs_compress = msg_count >= compress_msg_threshold
|
||||
assert not needs_compress
|
||||
|
||||
def test_minimum_message_guard(self):
|
||||
"""Sessions with fewer than 4 messages should never trigger."""
|
||||
history = _make_history(3, content_size=100_000)
|
||||
# Even with enormous content, < 4 messages should be skipped
|
||||
# (the gateway code checks `len(history) >= 4` before evaluating)
|
||||
assert len(history) < 4
|
||||
|
||||
|
||||
class TestSessionHygieneWarnThreshold:
|
||||
"""Test the post-compression warning threshold."""
|
||||
|
||||
def test_warn_when_still_large(self):
|
||||
"""If compressed result is still above warn_tokens, should warn."""
|
||||
# Simulate post-compression tokens
|
||||
warn_threshold = 200_000
|
||||
post_compress_tokens = 250_000
|
||||
assert post_compress_tokens >= warn_threshold
|
||||
|
||||
def test_no_warn_when_under(self):
|
||||
"""If compressed result is under warn_tokens, no warning."""
|
||||
warn_threshold = 200_000
|
||||
post_compress_tokens = 150_000
|
||||
assert post_compress_tokens < warn_threshold
|
||||
|
||||
|
||||
class TestTokenEstimation:
|
||||
"""Verify rough token estimation works as expected for hygiene checks."""
|
||||
|
||||
def test_empty_history(self):
|
||||
assert estimate_messages_tokens_rough([]) == 0
|
||||
|
||||
def test_proportional_to_content(self):
|
||||
small = _make_history(10, content_size=100)
|
||||
large = _make_history(10, content_size=10_000)
|
||||
assert estimate_messages_tokens_rough(large) > estimate_messages_tokens_rough(small)
|
||||
|
||||
def test_proportional_to_count(self):
|
||||
few = _make_history(10, content_size=1000)
|
||||
many = _make_history(100, content_size=1000)
|
||||
assert estimate_messages_tokens_rough(many) > estimate_messages_tokens_rough(few)
|
||||
|
||||
def test_pathological_session_detected(self):
|
||||
"""The reported pathological case: 648 messages, ~299K tokens."""
|
||||
# Simulate a 648-message session averaging ~460 tokens per message
|
||||
history = _make_history(648, content_size=1800)
|
||||
tokens = estimate_messages_tokens_rough(history)
|
||||
# Should be well above the 100K default threshold
|
||||
assert tokens > 100_000
|
||||
assert len(history) > 200
|
||||
207
tests/gateway/test_title_command.py
Normal file
207
tests/gateway/test_title_command.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""Tests for /title gateway slash command.
|
||||
|
||||
Tests the _handle_title_command handler (set/show session titles)
|
||||
across all gateway messenger platforms.
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
def _make_event(text="/title", platform=Platform.TELEGRAM,
|
||||
user_id="12345", chat_id="67890"):
|
||||
"""Build a MessageEvent for testing."""
|
||||
source = SessionSource(
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
user_name="testuser",
|
||||
)
|
||||
return MessageEvent(text=text, source=source)
|
||||
|
||||
|
||||
def _make_runner(session_db=None):
|
||||
"""Create a bare GatewayRunner with a mock session_store and optional session_db."""
|
||||
from gateway.run import GatewayRunner
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._session_db = session_db
|
||||
|
||||
# Mock session_store that returns a session entry with a known session_id
|
||||
mock_session_entry = MagicMock()
|
||||
mock_session_entry.session_id = "test_session_123"
|
||||
mock_session_entry.session_key = "telegram:12345:67890"
|
||||
mock_store = MagicMock()
|
||||
mock_store.get_or_create_session.return_value = mock_session_entry
|
||||
runner.session_store = mock_store
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _handle_title_command
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHandleTitleCommand:
|
||||
"""Tests for GatewayRunner._handle_title_command."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_title(self, tmp_path):
|
||||
"""Setting a title returns confirmation."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("test_session_123", "telegram")
|
||||
|
||||
runner = _make_runner(session_db=db)
|
||||
event = _make_event(text="/title My Research Project")
|
||||
result = await runner._handle_title_command(event)
|
||||
assert "My Research Project" in result
|
||||
assert "✏️" in result
|
||||
|
||||
# Verify in DB
|
||||
assert db.get_session_title("test_session_123") == "My Research Project"
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_show_title_when_set(self, tmp_path):
|
||||
"""Showing title when one is set returns the title."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("test_session_123", "telegram")
|
||||
db.set_session_title("test_session_123", "Existing Title")
|
||||
|
||||
runner = _make_runner(session_db=db)
|
||||
event = _make_event(text="/title")
|
||||
result = await runner._handle_title_command(event)
|
||||
assert "Existing Title" in result
|
||||
assert "📌" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_show_title_when_not_set(self, tmp_path):
|
||||
"""Showing title when none is set returns usage hint."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("test_session_123", "telegram")
|
||||
|
||||
runner = _make_runner(session_db=db)
|
||||
event = _make_event(text="/title")
|
||||
result = await runner._handle_title_command(event)
|
||||
assert "No title set" in result
|
||||
assert "/title" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_title_conflict(self, tmp_path):
|
||||
"""Setting a title already used by another session returns error."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("other_session", "telegram")
|
||||
db.set_session_title("other_session", "Taken Title")
|
||||
db.create_session("test_session_123", "telegram")
|
||||
|
||||
runner = _make_runner(session_db=db)
|
||||
event = _make_event(text="/title Taken Title")
|
||||
result = await runner._handle_title_command(event)
|
||||
assert "already in use" in result
|
||||
assert "⚠️" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_session_db(self):
|
||||
"""Returns error when session database is not available."""
|
||||
runner = _make_runner(session_db=None)
|
||||
event = _make_event(text="/title My Title")
|
||||
result = await runner._handle_title_command(event)
|
||||
assert "not available" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_title_too_long(self, tmp_path):
|
||||
"""Setting a title that exceeds max length returns error."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("test_session_123", "telegram")
|
||||
|
||||
runner = _make_runner(session_db=db)
|
||||
long_title = "A" * 150
|
||||
event = _make_event(text=f"/title {long_title}")
|
||||
result = await runner._handle_title_command(event)
|
||||
assert "too long" in result
|
||||
assert "⚠️" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_title_control_chars_sanitized(self, tmp_path):
|
||||
"""Control characters are stripped and sanitized title is stored."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("test_session_123", "telegram")
|
||||
|
||||
runner = _make_runner(session_db=db)
|
||||
event = _make_event(text="/title hello\x00world")
|
||||
result = await runner._handle_title_command(event)
|
||||
assert "helloworld" in result
|
||||
assert db.get_session_title("test_session_123") == "helloworld"
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_title_only_control_chars(self, tmp_path):
|
||||
"""Title with only control chars returns empty error."""
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("test_session_123", "telegram")
|
||||
|
||||
runner = _make_runner(session_db=db)
|
||||
event = _make_event(text="/title \x00\x01\x02")
|
||||
result = await runner._handle_title_command(event)
|
||||
assert "empty after cleanup" in result
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_works_across_platforms(self, tmp_path):
|
||||
"""The /title command works for Discord, Slack, and WhatsApp too."""
|
||||
from hermes_state import SessionDB
|
||||
for platform in [Platform.DISCORD, Platform.TELEGRAM]:
|
||||
db = SessionDB(db_path=tmp_path / f"state_{platform.value}.db")
|
||||
db.create_session("test_session_123", platform.value)
|
||||
|
||||
runner = _make_runner(session_db=db)
|
||||
event = _make_event(text="/title Cross-Platform Test", platform=platform)
|
||||
result = await runner._handle_title_command(event)
|
||||
assert "Cross-Platform Test" in result
|
||||
assert db.get_session_title("test_session_123") == "Cross-Platform Test"
|
||||
db.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# /title in help and known_commands
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTitleInHelp:
|
||||
"""Verify /title appears in help text and known commands."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_title_in_help_output(self):
|
||||
"""The /help output includes /title."""
|
||||
runner = _make_runner()
|
||||
event = _make_event(text="/help")
|
||||
# Need hooks for help command
|
||||
from gateway.hooks import HookRegistry
|
||||
runner.hooks = HookRegistry()
|
||||
result = await runner._handle_help_command(event)
|
||||
assert "/title" in result
|
||||
|
||||
def test_title_is_known_command(self):
|
||||
"""The /title command is in the _known_commands set."""
|
||||
from gateway.run import GatewayRunner
|
||||
import inspect
|
||||
source = inspect.getsource(GatewayRunner._handle_message)
|
||||
assert '"title"' in source
|
||||
145
tests/hermes_cli/test_commands.py
Normal file
145
tests/hermes_cli/test_commands.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""Tests for shared slash command definitions and autocomplete."""
|
||||
|
||||
from prompt_toolkit.completion import CompleteEvent
|
||||
from prompt_toolkit.document import Document
|
||||
|
||||
from hermes_cli.commands import COMMANDS, SlashCommandCompleter
|
||||
|
||||
|
||||
# All commands that must be present in the shared COMMANDS dict.
|
||||
EXPECTED_COMMANDS = {
|
||||
"/help", "/tools", "/toolsets", "/model", "/provider", "/prompt",
|
||||
"/personality", "/clear", "/history", "/new", "/reset", "/retry",
|
||||
"/undo", "/save", "/config", "/cron", "/skills", "/platforms",
|
||||
"/verbose", "/compress", "/title", "/usage", "/insights", "/paste",
|
||||
"/reload-mcp", "/quit",
|
||||
}
|
||||
|
||||
|
||||
def _completions(completer: SlashCommandCompleter, text: str):
|
||||
return list(
|
||||
completer.get_completions(
|
||||
Document(text=text),
|
||||
CompleteEvent(completion_requested=True),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TestCommands:
|
||||
def test_shared_commands_include_cli_specific_entries(self):
|
||||
"""Entries that previously only existed in cli.py are now in the shared dict."""
|
||||
assert COMMANDS["/paste"] == "Check clipboard for an image and attach it"
|
||||
assert COMMANDS["/reload-mcp"] == "Reload MCP servers from config.yaml"
|
||||
|
||||
def test_all_expected_commands_present(self):
|
||||
"""Regression guard — every known command must appear in the shared dict."""
|
||||
assert set(COMMANDS.keys()) == EXPECTED_COMMANDS
|
||||
|
||||
def test_every_command_has_nonempty_description(self):
|
||||
for cmd, desc in COMMANDS.items():
|
||||
assert isinstance(desc, str) and len(desc) > 0, f"{cmd} has empty description"
|
||||
|
||||
|
||||
class TestSlashCommandCompleter:
|
||||
# -- basic prefix completion -----------------------------------------
|
||||
|
||||
def test_builtin_prefix_completion_uses_shared_registry(self):
|
||||
completions = _completions(SlashCommandCompleter(), "/re")
|
||||
texts = {item.text for item in completions}
|
||||
|
||||
assert "reset" in texts
|
||||
assert "retry" in texts
|
||||
assert "reload-mcp" in texts
|
||||
|
||||
def test_builtin_completion_display_meta_shows_description(self):
|
||||
completions = _completions(SlashCommandCompleter(), "/help")
|
||||
assert len(completions) == 1
|
||||
assert completions[0].display_meta_text == "Show this help message"
|
||||
|
||||
# -- exact-match trailing space --------------------------------------
|
||||
|
||||
def test_exact_match_completion_adds_trailing_space(self):
|
||||
completions = _completions(SlashCommandCompleter(), "/help")
|
||||
|
||||
assert [item.text for item in completions] == ["help "]
|
||||
|
||||
def test_partial_match_does_not_add_trailing_space(self):
|
||||
completions = _completions(SlashCommandCompleter(), "/hel")
|
||||
|
||||
assert [item.text for item in completions] == ["help"]
|
||||
|
||||
# -- non-slash input returns nothing ---------------------------------
|
||||
|
||||
def test_no_completions_for_non_slash_input(self):
|
||||
assert _completions(SlashCommandCompleter(), "help") == []
|
||||
|
||||
def test_no_completions_for_empty_input(self):
|
||||
assert _completions(SlashCommandCompleter(), "") == []
|
||||
|
||||
# -- skill commands via provider ------------------------------------
|
||||
|
||||
def test_skill_commands_are_completed_from_provider(self):
|
||||
completer = SlashCommandCompleter(
|
||||
skill_commands_provider=lambda: {
|
||||
"/gif-search": {"description": "Search for GIFs across providers"},
|
||||
}
|
||||
)
|
||||
|
||||
completions = _completions(completer, "/gif")
|
||||
|
||||
assert len(completions) == 1
|
||||
assert completions[0].text == "gif-search"
|
||||
assert completions[0].display_text == "/gif-search"
|
||||
assert completions[0].display_meta_text == "⚡ Search for GIFs across providers"
|
||||
|
||||
def test_skill_exact_match_adds_trailing_space(self):
|
||||
completer = SlashCommandCompleter(
|
||||
skill_commands_provider=lambda: {
|
||||
"/gif-search": {"description": "Search for GIFs"},
|
||||
}
|
||||
)
|
||||
|
||||
completions = _completions(completer, "/gif-search")
|
||||
|
||||
assert len(completions) == 1
|
||||
assert completions[0].text == "gif-search "
|
||||
|
||||
def test_no_skill_provider_means_no_skill_completions(self):
|
||||
"""Default (None) provider should not blow up or add completions."""
|
||||
completer = SlashCommandCompleter()
|
||||
completions = _completions(completer, "/gif")
|
||||
# /gif doesn't match any builtin command
|
||||
assert completions == []
|
||||
|
||||
def test_skill_provider_exception_is_swallowed(self):
|
||||
"""A broken provider should not crash autocomplete."""
|
||||
completer = SlashCommandCompleter(
|
||||
skill_commands_provider=lambda: (_ for _ in ()).throw(RuntimeError("boom")),
|
||||
)
|
||||
# Should return builtin matches only, no crash
|
||||
completions = _completions(completer, "/he")
|
||||
texts = {item.text for item in completions}
|
||||
assert "help" in texts
|
||||
|
||||
def test_skill_description_truncated_at_50_chars(self):
|
||||
long_desc = "A" * 80
|
||||
completer = SlashCommandCompleter(
|
||||
skill_commands_provider=lambda: {
|
||||
"/long-skill": {"description": long_desc},
|
||||
}
|
||||
)
|
||||
completions = _completions(completer, "/long")
|
||||
assert len(completions) == 1
|
||||
meta = completions[0].display_meta_text
|
||||
# "⚡ " prefix + 50 chars + "..."
|
||||
assert meta == f"⚡ {'A' * 50}..."
|
||||
|
||||
def test_skill_missing_description_uses_fallback(self):
|
||||
completer = SlashCommandCompleter(
|
||||
skill_commands_provider=lambda: {
|
||||
"/no-desc": {},
|
||||
}
|
||||
)
|
||||
completions = _completions(completer, "/no-desc")
|
||||
assert len(completions) == 1
|
||||
assert "Skill command" in completions[0].display_meta_text
|
||||
17
tests/hermes_cli/test_doctor.py
Normal file
17
tests/hermes_cli/test_doctor.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Tests for hermes doctor helpers."""
|
||||
|
||||
from hermes_cli.doctor import _has_provider_env_config
|
||||
|
||||
|
||||
class TestProviderEnvDetection:
|
||||
def test_detects_openai_api_key(self):
|
||||
content = "OPENAI_BASE_URL=http://localhost:1234/v1\nOPENAI_API_KEY=sk-test-key\n"
|
||||
assert _has_provider_env_config(content)
|
||||
|
||||
def test_detects_custom_endpoint_without_openrouter_key(self):
|
||||
content = "OPENAI_BASE_URL=http://localhost:8080/v1\n"
|
||||
assert _has_provider_env_config(content)
|
||||
|
||||
def test_returns_false_when_no_provider_settings(self):
|
||||
content = "TERMINAL_ENV=local\n"
|
||||
assert not _has_provider_env_config(content)
|
||||
220
tests/hermes_cli/test_model_validation.py
Normal file
220
tests/hermes_cli/test_model_validation.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""Tests for provider-aware `/model` validation in hermes_cli.models."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from hermes_cli.models import (
|
||||
curated_models_for_provider,
|
||||
fetch_api_models,
|
||||
normalize_provider,
|
||||
parse_model_input,
|
||||
provider_model_ids,
|
||||
validate_requested_model,
|
||||
)
|
||||
|
||||
|
||||
# -- helpers -----------------------------------------------------------------
|
||||
|
||||
FAKE_API_MODELS = [
|
||||
"anthropic/claude-opus-4.6",
|
||||
"anthropic/claude-sonnet-4.5",
|
||||
"openai/gpt-5.4-pro",
|
||||
"openai/gpt-5.4",
|
||||
"google/gemini-3-pro-preview",
|
||||
]
|
||||
|
||||
|
||||
def _validate(model, provider="openrouter", api_models=FAKE_API_MODELS, **kw):
|
||||
"""Shortcut: call validate_requested_model with mocked API."""
|
||||
with patch("hermes_cli.models.fetch_api_models", return_value=api_models):
|
||||
return validate_requested_model(model, provider, **kw)
|
||||
|
||||
|
||||
# -- parse_model_input -------------------------------------------------------
|
||||
|
||||
class TestParseModelInput:
|
||||
def test_plain_model_keeps_current_provider(self):
|
||||
provider, model = parse_model_input("anthropic/claude-sonnet-4.5", "openrouter")
|
||||
assert provider == "openrouter"
|
||||
assert model == "anthropic/claude-sonnet-4.5"
|
||||
|
||||
def test_provider_colon_model_switches_provider(self):
|
||||
provider, model = parse_model_input("openrouter:anthropic/claude-sonnet-4.5", "nous")
|
||||
assert provider == "openrouter"
|
||||
assert model == "anthropic/claude-sonnet-4.5"
|
||||
|
||||
def test_provider_alias_resolved(self):
|
||||
provider, model = parse_model_input("glm:glm-5", "openrouter")
|
||||
assert provider == "zai"
|
||||
assert model == "glm-5"
|
||||
|
||||
def test_no_slash_no_colon_keeps_provider(self):
|
||||
provider, model = parse_model_input("gpt-5.4", "openrouter")
|
||||
assert provider == "openrouter"
|
||||
assert model == "gpt-5.4"
|
||||
|
||||
def test_nous_provider_switch(self):
|
||||
provider, model = parse_model_input("nous:hermes-3", "openrouter")
|
||||
assert provider == "nous"
|
||||
assert model == "hermes-3"
|
||||
|
||||
def test_empty_model_after_colon_keeps_current(self):
|
||||
provider, model = parse_model_input("openrouter:", "nous")
|
||||
assert provider == "nous"
|
||||
assert model == "openrouter:"
|
||||
|
||||
def test_colon_at_start_keeps_current(self):
|
||||
provider, model = parse_model_input(":something", "openrouter")
|
||||
assert provider == "openrouter"
|
||||
assert model == ":something"
|
||||
|
||||
def test_unknown_prefix_colon_not_treated_as_provider(self):
|
||||
"""Colons are only provider delimiters if the left side is a known provider."""
|
||||
provider, model = parse_model_input("anthropic/claude-3.5-sonnet:beta", "openrouter")
|
||||
assert provider == "openrouter"
|
||||
assert model == "anthropic/claude-3.5-sonnet:beta"
|
||||
|
||||
def test_http_url_not_treated_as_provider(self):
|
||||
provider, model = parse_model_input("http://localhost:8080/model", "openrouter")
|
||||
assert provider == "openrouter"
|
||||
assert model == "http://localhost:8080/model"
|
||||
|
||||
|
||||
# -- curated_models_for_provider ---------------------------------------------
|
||||
|
||||
class TestCuratedModelsForProvider:
|
||||
def test_openrouter_returns_curated_list(self):
|
||||
models = curated_models_for_provider("openrouter")
|
||||
assert len(models) > 0
|
||||
assert any("claude" in m[0] for m in models)
|
||||
|
||||
def test_zai_returns_glm_models(self):
|
||||
models = curated_models_for_provider("zai")
|
||||
assert any("glm" in m[0] for m in models)
|
||||
|
||||
def test_unknown_provider_returns_empty(self):
|
||||
assert curated_models_for_provider("totally-unknown") == []
|
||||
|
||||
|
||||
# -- normalize_provider ------------------------------------------------------
|
||||
|
||||
class TestNormalizeProvider:
|
||||
def test_defaults_to_openrouter(self):
|
||||
assert normalize_provider(None) == "openrouter"
|
||||
assert normalize_provider("") == "openrouter"
|
||||
|
||||
def test_known_aliases(self):
|
||||
assert normalize_provider("glm") == "zai"
|
||||
assert normalize_provider("kimi") == "kimi-coding"
|
||||
assert normalize_provider("moonshot") == "kimi-coding"
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert normalize_provider("OpenRouter") == "openrouter"
|
||||
|
||||
|
||||
# -- provider_model_ids ------------------------------------------------------
|
||||
|
||||
class TestProviderModelIds:
|
||||
def test_openrouter_returns_curated_list(self):
|
||||
ids = provider_model_ids("openrouter")
|
||||
assert len(ids) > 0
|
||||
assert all("/" in mid for mid in ids)
|
||||
|
||||
def test_unknown_provider_returns_empty(self):
|
||||
assert provider_model_ids("some-unknown-provider") == []
|
||||
|
||||
def test_zai_returns_glm_models(self):
|
||||
assert "glm-5" in provider_model_ids("zai")
|
||||
|
||||
|
||||
# -- fetch_api_models --------------------------------------------------------
|
||||
|
||||
class TestFetchApiModels:
|
||||
def test_returns_none_when_no_base_url(self):
|
||||
assert fetch_api_models("key", None) is None
|
||||
|
||||
def test_returns_none_on_network_error(self):
|
||||
with patch("hermes_cli.models.urllib.request.urlopen", side_effect=Exception("timeout")):
|
||||
assert fetch_api_models("key", "https://example.com/v1") is None
|
||||
|
||||
|
||||
# -- validate — format checks -----------------------------------------------
|
||||
|
||||
class TestValidateFormatChecks:
|
||||
def test_empty_model_rejected(self):
|
||||
result = _validate("")
|
||||
assert result["accepted"] is False
|
||||
assert "empty" in result["message"]
|
||||
|
||||
def test_whitespace_only_rejected(self):
|
||||
result = _validate(" ")
|
||||
assert result["accepted"] is False
|
||||
|
||||
def test_model_with_spaces_rejected(self):
|
||||
result = _validate("anthropic/ claude-opus")
|
||||
assert result["accepted"] is False
|
||||
|
||||
def test_no_slash_model_still_probes_api(self):
|
||||
result = _validate("gpt-5.4", api_models=["gpt-5.4", "gpt-5.4-pro"])
|
||||
assert result["accepted"] is True
|
||||
assert result["persist"] is True
|
||||
|
||||
def test_no_slash_model_rejected_if_not_in_api(self):
|
||||
result = _validate("gpt-5.4", api_models=["openai/gpt-5.4"])
|
||||
assert result["accepted"] is False
|
||||
|
||||
|
||||
# -- validate — API found ----------------------------------------------------
|
||||
|
||||
class TestValidateApiFound:
|
||||
def test_model_found_in_api(self):
|
||||
result = _validate("anthropic/claude-opus-4.6")
|
||||
assert result["accepted"] is True
|
||||
assert result["persist"] is True
|
||||
assert result["recognized"] is True
|
||||
|
||||
def test_model_found_for_custom_endpoint(self):
|
||||
result = _validate(
|
||||
"my-model", provider="openrouter",
|
||||
api_models=["my-model"], base_url="http://localhost:11434/v1",
|
||||
)
|
||||
assert result["accepted"] is True
|
||||
assert result["persist"] is True
|
||||
|
||||
|
||||
# -- validate — API not found ------------------------------------------------
|
||||
|
||||
class TestValidateApiNotFound:
|
||||
def test_model_not_in_api_rejected(self):
|
||||
result = _validate("anthropic/claude-nonexistent")
|
||||
assert result["accepted"] is False
|
||||
assert "not a valid model" in result["message"]
|
||||
|
||||
def test_rejection_includes_suggestions(self):
|
||||
result = _validate("anthropic/claude-opus-4.5")
|
||||
assert result["accepted"] is False
|
||||
assert "Did you mean" in result["message"]
|
||||
|
||||
|
||||
# -- validate — API unreachable (fallback) -----------------------------------
|
||||
|
||||
class TestValidateApiFallback:
|
||||
def test_known_catalog_model_accepted_when_api_down(self):
|
||||
result = _validate("anthropic/claude-opus-4.6", api_models=None)
|
||||
assert result["accepted"] is True
|
||||
assert result["persist"] is True
|
||||
|
||||
def test_unknown_model_session_only_when_api_down(self):
|
||||
result = _validate("anthropic/claude-next-gen", api_models=None)
|
||||
assert result["accepted"] is True
|
||||
assert result["persist"] is False
|
||||
assert "session only" in result["message"].lower()
|
||||
|
||||
def test_zai_known_model_accepted_when_api_down(self):
|
||||
result = _validate("glm-5", provider="zai", api_models=None)
|
||||
assert result["accepted"] is True
|
||||
assert result["persist"] is True
|
||||
|
||||
def test_unknown_provider_session_only_when_api_down(self):
|
||||
result = _validate("some-model", provider="totally-unknown", api_models=None)
|
||||
assert result["accepted"] is True
|
||||
assert result["persist"] is False
|
||||
31
tests/hermes_cli/test_skills_hub.py
Normal file
31
tests/hermes_cli/test_skills_hub.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from io import StringIO
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
from hermes_cli.skills_hub import do_list
|
||||
|
||||
|
||||
def test_do_list_initializes_hub_dir(monkeypatch, tmp_path):
|
||||
import tools.skills_hub as hub
|
||||
import tools.skills_tool as skills_tool
|
||||
|
||||
hub_dir = tmp_path / "skills" / ".hub"
|
||||
monkeypatch.setattr(hub, "SKILLS_DIR", tmp_path / "skills")
|
||||
monkeypatch.setattr(hub, "HUB_DIR", hub_dir)
|
||||
monkeypatch.setattr(hub, "LOCK_FILE", hub_dir / "lock.json")
|
||||
monkeypatch.setattr(hub, "QUARANTINE_DIR", hub_dir / "quarantine")
|
||||
monkeypatch.setattr(hub, "AUDIT_LOG", hub_dir / "audit.log")
|
||||
monkeypatch.setattr(hub, "TAPS_FILE", hub_dir / "taps.json")
|
||||
monkeypatch.setattr(hub, "INDEX_CACHE_DIR", hub_dir / "index-cache")
|
||||
monkeypatch.setattr(skills_tool, "_find_all_skills", lambda: [])
|
||||
|
||||
console = Console(file=StringIO(), force_terminal=False, color_system=None)
|
||||
|
||||
assert not hub_dir.exists()
|
||||
|
||||
do_list(console=console)
|
||||
|
||||
assert hub_dir.exists()
|
||||
assert (hub_dir / "lock.json").exists()
|
||||
assert (hub_dir / "quarantine").is_dir()
|
||||
assert (hub_dir / "index-cache").is_dir()
|
||||
19
tests/hermes_cli/test_tools_config.py
Normal file
19
tests/hermes_cli/test_tools_config.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Tests for hermes_cli.tools_config platform tool persistence."""
|
||||
|
||||
from hermes_cli.tools_config import _get_platform_tools
|
||||
|
||||
|
||||
def test_get_platform_tools_uses_default_when_platform_not_configured():
|
||||
config = {}
|
||||
|
||||
enabled = _get_platform_tools(config, "cli")
|
||||
|
||||
assert enabled
|
||||
|
||||
|
||||
def test_get_platform_tools_preserves_explicit_empty_selection():
|
||||
config = {"platform_toolsets": {"cli": []}}
|
||||
|
||||
enabled = _get_platform_tools(config, "cli")
|
||||
|
||||
assert enabled == set()
|
||||
@@ -20,6 +20,8 @@ from hermes_cli.auth import (
|
||||
resolve_api_key_provider_credentials,
|
||||
get_auth_status,
|
||||
AuthError,
|
||||
KIMI_CODE_BASE_URL,
|
||||
_resolve_kimi_base_url,
|
||||
)
|
||||
|
||||
|
||||
@@ -84,7 +86,7 @@ class TestProviderRegistry:
|
||||
PROVIDER_ENV_VARS = (
|
||||
"OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY",
|
||||
"GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY",
|
||||
"KIMI_API_KEY", "MINIMAX_API_KEY", "MINIMAX_CN_API_KEY",
|
||||
"KIMI_API_KEY", "KIMI_BASE_URL", "MINIMAX_API_KEY", "MINIMAX_CN_API_KEY",
|
||||
"OPENAI_BASE_URL",
|
||||
)
|
||||
|
||||
@@ -340,3 +342,87 @@ class TestHasAnyProviderConfigured:
|
||||
monkeypatch.setattr(config_module, "get_hermes_home", lambda: hermes_home)
|
||||
from hermes_cli.main import _has_any_provider_configured
|
||||
assert _has_any_provider_configured() is True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Kimi Code auto-detection tests
|
||||
# =============================================================================
|
||||
|
||||
MOONSHOT_DEFAULT_URL = "https://api.moonshot.ai/v1"
|
||||
|
||||
|
||||
class TestResolveKimiBaseUrl:
|
||||
"""Test _resolve_kimi_base_url() helper for key-prefix auto-detection."""
|
||||
|
||||
def test_sk_kimi_prefix_routes_to_kimi_code(self):
|
||||
url = _resolve_kimi_base_url("sk-kimi-abc123", MOONSHOT_DEFAULT_URL, "")
|
||||
assert url == KIMI_CODE_BASE_URL
|
||||
|
||||
def test_legacy_key_uses_default(self):
|
||||
url = _resolve_kimi_base_url("sk-abc123", MOONSHOT_DEFAULT_URL, "")
|
||||
assert url == MOONSHOT_DEFAULT_URL
|
||||
|
||||
def test_empty_key_uses_default(self):
|
||||
url = _resolve_kimi_base_url("", MOONSHOT_DEFAULT_URL, "")
|
||||
assert url == MOONSHOT_DEFAULT_URL
|
||||
|
||||
def test_env_override_wins_over_sk_kimi(self):
|
||||
"""KIMI_BASE_URL env var should always take priority."""
|
||||
custom = "https://custom.example.com/v1"
|
||||
url = _resolve_kimi_base_url("sk-kimi-abc123", MOONSHOT_DEFAULT_URL, custom)
|
||||
assert url == custom
|
||||
|
||||
def test_env_override_wins_over_legacy(self):
|
||||
custom = "https://custom.example.com/v1"
|
||||
url = _resolve_kimi_base_url("sk-abc123", MOONSHOT_DEFAULT_URL, custom)
|
||||
assert url == custom
|
||||
|
||||
|
||||
class TestKimiCodeStatusAutoDetect:
|
||||
"""Test that get_api_key_provider_status auto-detects sk-kimi- keys."""
|
||||
|
||||
def test_sk_kimi_key_gets_kimi_code_url(self, monkeypatch):
|
||||
monkeypatch.setenv("KIMI_API_KEY", "sk-kimi-test-key-123")
|
||||
status = get_api_key_provider_status("kimi-coding")
|
||||
assert status["configured"] is True
|
||||
assert status["base_url"] == KIMI_CODE_BASE_URL
|
||||
|
||||
def test_legacy_key_gets_moonshot_url(self, monkeypatch):
|
||||
monkeypatch.setenv("KIMI_API_KEY", "sk-legacy-test-key")
|
||||
status = get_api_key_provider_status("kimi-coding")
|
||||
assert status["configured"] is True
|
||||
assert status["base_url"] == MOONSHOT_DEFAULT_URL
|
||||
|
||||
def test_env_override_wins(self, monkeypatch):
|
||||
monkeypatch.setenv("KIMI_API_KEY", "sk-kimi-test-key")
|
||||
monkeypatch.setenv("KIMI_BASE_URL", "https://override.example/v1")
|
||||
status = get_api_key_provider_status("kimi-coding")
|
||||
assert status["base_url"] == "https://override.example/v1"
|
||||
|
||||
|
||||
class TestKimiCodeCredentialAutoDetect:
|
||||
"""Test that resolve_api_key_provider_credentials auto-detects sk-kimi- keys."""
|
||||
|
||||
def test_sk_kimi_key_gets_kimi_code_url(self, monkeypatch):
|
||||
monkeypatch.setenv("KIMI_API_KEY", "sk-kimi-secret-key")
|
||||
creds = resolve_api_key_provider_credentials("kimi-coding")
|
||||
assert creds["api_key"] == "sk-kimi-secret-key"
|
||||
assert creds["base_url"] == KIMI_CODE_BASE_URL
|
||||
|
||||
def test_legacy_key_gets_moonshot_url(self, monkeypatch):
|
||||
monkeypatch.setenv("KIMI_API_KEY", "sk-legacy-secret-key")
|
||||
creds = resolve_api_key_provider_credentials("kimi-coding")
|
||||
assert creds["api_key"] == "sk-legacy-secret-key"
|
||||
assert creds["base_url"] == MOONSHOT_DEFAULT_URL
|
||||
|
||||
def test_env_override_wins(self, monkeypatch):
|
||||
monkeypatch.setenv("KIMI_API_KEY", "sk-kimi-secret-key")
|
||||
monkeypatch.setenv("KIMI_BASE_URL", "https://override.example/v1")
|
||||
creds = resolve_api_key_provider_credentials("kimi-coding")
|
||||
assert creds["base_url"] == "https://override.example/v1"
|
||||
|
||||
def test_non_kimi_providers_unaffected(self, monkeypatch):
|
||||
"""Ensure the auto-detect logic doesn't leak to other providers."""
|
||||
monkeypatch.setenv("GLM_API_KEY", "sk-kimi-looks-like-kimi-but-isnt")
|
||||
creds = resolve_api_key_provider_credentials("zai")
|
||||
assert creds["base_url"] == "https://api.z.ai/api/paas/v4"
|
||||
|
||||
@@ -3,9 +3,7 @@ that only manifest at runtime (not in mocked unit tests)."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
@@ -72,6 +70,38 @@ class TestVerboseAndToolProgress:
|
||||
assert cli.tool_progress_mode in ("off", "new", "all", "verbose")
|
||||
|
||||
|
||||
class TestHistoryDisplay:
|
||||
def test_history_numbers_only_visible_messages_and_summarizes_tools(self, capsys):
|
||||
cli = _make_cli()
|
||||
cli.conversation_history = [
|
||||
{"role": "system", "content": "system prompt"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{"id": "call_1"}, {"id": "call_2"}],
|
||||
},
|
||||
{"role": "tool", "content": "tool output 1"},
|
||||
{"role": "tool", "content": "tool output 2"},
|
||||
{"role": "assistant", "content": "All set."},
|
||||
{"role": "user", "content": "A" * 250},
|
||||
]
|
||||
|
||||
cli.show_history()
|
||||
output = capsys.readouterr().out
|
||||
|
||||
assert "[You #1]" in output
|
||||
assert "[Hermes #2]" in output
|
||||
assert "(requested 2 tool calls)" in output
|
||||
assert "[Tools]" in output
|
||||
assert "(2 tool messages hidden)" in output
|
||||
assert "[Hermes #3]" in output
|
||||
assert "[You #4]" in output
|
||||
assert "[You #5]" not in output
|
||||
assert "A" * 250 in output
|
||||
assert "A" * 250 + "..." not in output
|
||||
|
||||
|
||||
class TestProviderResolution:
|
||||
def test_api_key_is_string_or_none(self):
|
||||
cli = _make_cli()
|
||||
|
||||
133
tests/test_cli_model_command.py
Normal file
133
tests/test_cli_model_command.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Regression tests for the `/model` slash command in the interactive CLI."""
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from cli import HermesCLI
|
||||
|
||||
|
||||
class TestModelCommand:
|
||||
def _make_cli(self):
|
||||
cli_obj = HermesCLI.__new__(HermesCLI)
|
||||
cli_obj.model = "anthropic/claude-opus-4.6"
|
||||
cli_obj.agent = object()
|
||||
cli_obj.provider = "openrouter"
|
||||
cli_obj.requested_provider = "openrouter"
|
||||
cli_obj.base_url = "https://openrouter.ai/api/v1"
|
||||
cli_obj.api_key = "test-key"
|
||||
cli_obj._explicit_api_key = None
|
||||
cli_obj._explicit_base_url = None
|
||||
return cli_obj
|
||||
|
||||
def test_valid_model_from_api_saved_to_config(self, capsys):
|
||||
cli_obj = self._make_cli()
|
||||
|
||||
with patch("hermes_cli.models.fetch_api_models",
|
||||
return_value=["anthropic/claude-sonnet-4.5", "openai/gpt-5.4"]), \
|
||||
patch("cli.save_config_value", return_value=True) as save_mock:
|
||||
cli_obj.process_command("/model anthropic/claude-sonnet-4.5")
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "saved to config" in output
|
||||
assert cli_obj.model == "anthropic/claude-sonnet-4.5"
|
||||
save_mock.assert_called_once_with("model.default", "anthropic/claude-sonnet-4.5")
|
||||
|
||||
def test_invalid_model_from_api_is_rejected(self, capsys):
|
||||
cli_obj = self._make_cli()
|
||||
|
||||
with patch("hermes_cli.models.fetch_api_models",
|
||||
return_value=["anthropic/claude-opus-4.6"]), \
|
||||
patch("cli.save_config_value") as save_mock:
|
||||
cli_obj.process_command("/model anthropic/fake-model")
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "not a valid model" in output
|
||||
assert "Model unchanged" in output
|
||||
assert cli_obj.model == "anthropic/claude-opus-4.6"
|
||||
save_mock.assert_not_called()
|
||||
|
||||
def test_api_unreachable_falls_back_session_only(self, capsys):
|
||||
cli_obj = self._make_cli()
|
||||
|
||||
with patch("hermes_cli.models.fetch_api_models", return_value=None), \
|
||||
patch("cli.save_config_value") as save_mock:
|
||||
cli_obj.process_command("/model anthropic/claude-sonnet-next")
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "session only" in output
|
||||
assert "will revert on restart" in output
|
||||
assert cli_obj.model == "anthropic/claude-sonnet-next"
|
||||
save_mock.assert_not_called()
|
||||
|
||||
def test_no_slash_model_probes_api_and_rejects(self, capsys):
|
||||
cli_obj = self._make_cli()
|
||||
|
||||
with patch("hermes_cli.models.fetch_api_models",
|
||||
return_value=["openai/gpt-5.4"]) as fetch_mock, \
|
||||
patch("cli.save_config_value") as save_mock:
|
||||
cli_obj.process_command("/model gpt-5.4")
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "not a valid model" in output
|
||||
assert "Model unchanged" in output
|
||||
assert cli_obj.model == "anthropic/claude-opus-4.6" # unchanged
|
||||
assert cli_obj.agent is not None # not reset
|
||||
save_mock.assert_not_called()
|
||||
|
||||
def test_validation_crash_falls_back_to_save(self, capsys):
|
||||
cli_obj = self._make_cli()
|
||||
|
||||
with patch("hermes_cli.models.validate_requested_model",
|
||||
side_effect=RuntimeError("boom")), \
|
||||
patch("cli.save_config_value", return_value=True) as save_mock:
|
||||
cli_obj.process_command("/model anthropic/claude-sonnet-4.5")
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "saved to config" in output
|
||||
assert cli_obj.model == "anthropic/claude-sonnet-4.5"
|
||||
save_mock.assert_called_once()
|
||||
|
||||
def test_show_model_when_no_argument(self, capsys):
|
||||
cli_obj = self._make_cli()
|
||||
cli_obj.process_command("/model")
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "anthropic/claude-opus-4.6" in output
|
||||
assert "OpenRouter" in output
|
||||
assert "Available models" in output
|
||||
assert "provider:model-name" in output
|
||||
|
||||
# -- provider switching tests -------------------------------------------
|
||||
|
||||
def test_provider_colon_model_switches_provider(self, capsys):
|
||||
cli_obj = self._make_cli()
|
||||
|
||||
with patch("hermes_cli.runtime_provider.resolve_runtime_provider", return_value={
|
||||
"provider": "zai",
|
||||
"api_key": "zai-key",
|
||||
"base_url": "https://api.z.ai/api/paas/v4",
|
||||
}), \
|
||||
patch("hermes_cli.models.fetch_api_models",
|
||||
return_value=["glm-5", "glm-4.7"]), \
|
||||
patch("cli.save_config_value", return_value=True) as save_mock:
|
||||
cli_obj.process_command("/model zai:glm-5")
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "glm-5" in output
|
||||
assert "provider:" in output.lower() or "Z.AI" in output
|
||||
assert cli_obj.model == "glm-5"
|
||||
assert cli_obj.provider == "zai"
|
||||
assert cli_obj.base_url == "https://api.z.ai/api/paas/v4"
|
||||
# Both model and provider should be saved
|
||||
assert save_mock.call_count == 2
|
||||
|
||||
def test_provider_switch_fails_on_bad_credentials(self, capsys):
|
||||
cli_obj = self._make_cli()
|
||||
|
||||
with patch("hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
side_effect=Exception("No API key found")):
|
||||
cli_obj.process_command("/model nous:hermes-3")
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "Could not resolve credentials" in output
|
||||
assert cli_obj.model == "anthropic/claude-opus-4.6" # unchanged
|
||||
assert cli_obj.provider == "openrouter" # unchanged
|
||||
@@ -162,6 +162,128 @@ def test_runtime_resolution_rebuilds_agent_on_routing_change(monkeypatch):
|
||||
assert shell.api_mode == "codex_responses"
|
||||
|
||||
|
||||
def test_codex_provider_replaces_incompatible_default_model(monkeypatch):
|
||||
"""When provider resolves to openai-codex and no model was explicitly
|
||||
chosen, the global config default (e.g. anthropic/claude-opus-4.6) must
|
||||
be replaced with a Codex-compatible model. Fixes #651."""
|
||||
cli = _import_cli()
|
||||
|
||||
monkeypatch.delenv("LLM_MODEL", raising=False)
|
||||
monkeypatch.delenv("OPENAI_MODEL", raising=False)
|
||||
|
||||
def _runtime_resolve(**kwargs):
|
||||
return {
|
||||
"provider": "openai-codex",
|
||||
"api_mode": "codex_responses",
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"api_key": "test-key",
|
||||
"source": "env/config",
|
||||
}
|
||||
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.resolve_runtime_provider", _runtime_resolve)
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.format_runtime_provider_error", lambda exc: str(exc))
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.codex_models.get_codex_model_ids",
|
||||
lambda access_token=None: ["gpt-5.2-codex", "gpt-5.1-codex-mini"],
|
||||
)
|
||||
|
||||
shell = cli.HermesCLI(compact=True, max_turns=1)
|
||||
|
||||
assert shell._model_is_default is True
|
||||
assert shell._ensure_runtime_credentials() is True
|
||||
assert shell.provider == "openai-codex"
|
||||
assert "anthropic" not in shell.model
|
||||
assert "claude" not in shell.model
|
||||
assert shell.model == "gpt-5.2-codex"
|
||||
|
||||
|
||||
def test_codex_provider_replaces_incompatible_envvar_model(monkeypatch):
|
||||
"""Exact scenario from #651: LLM_MODEL is set to a non-Codex model and
|
||||
provider resolves to openai-codex. The model must be replaced and a
|
||||
warning printed since the user explicitly chose it."""
|
||||
cli = _import_cli()
|
||||
|
||||
monkeypatch.setenv("LLM_MODEL", "claude-opus-4-6")
|
||||
monkeypatch.delenv("OPENAI_MODEL", raising=False)
|
||||
|
||||
def _runtime_resolve(**kwargs):
|
||||
return {
|
||||
"provider": "openai-codex",
|
||||
"api_mode": "codex_responses",
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"api_key": "test-key",
|
||||
"source": "env/config",
|
||||
}
|
||||
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.resolve_runtime_provider", _runtime_resolve)
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.format_runtime_provider_error", lambda exc: str(exc))
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.codex_models.get_codex_model_ids",
|
||||
lambda access_token=None: ["gpt-5.2-codex", "gpt-5.1-codex-mini"],
|
||||
)
|
||||
|
||||
shell = cli.HermesCLI(compact=True, max_turns=1)
|
||||
|
||||
assert shell._model_is_default is False
|
||||
assert shell._ensure_runtime_credentials() is True
|
||||
assert shell.provider == "openai-codex"
|
||||
assert "claude" not in shell.model
|
||||
assert shell.model == "gpt-5.2-codex"
|
||||
|
||||
|
||||
def test_codex_provider_preserves_explicit_codex_model(monkeypatch):
|
||||
"""If the user explicitly passes a Codex-compatible model, it must be
|
||||
preserved even when the provider resolves to openai-codex."""
|
||||
cli = _import_cli()
|
||||
|
||||
monkeypatch.delenv("LLM_MODEL", raising=False)
|
||||
monkeypatch.delenv("OPENAI_MODEL", raising=False)
|
||||
|
||||
def _runtime_resolve(**kwargs):
|
||||
return {
|
||||
"provider": "openai-codex",
|
||||
"api_mode": "codex_responses",
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"api_key": "test-key",
|
||||
"source": "env/config",
|
||||
}
|
||||
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.resolve_runtime_provider", _runtime_resolve)
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.format_runtime_provider_error", lambda exc: str(exc))
|
||||
|
||||
shell = cli.HermesCLI(model="gpt-5.1-codex-mini", compact=True, max_turns=1)
|
||||
|
||||
assert shell._model_is_default is False
|
||||
assert shell._ensure_runtime_credentials() is True
|
||||
assert shell.model == "gpt-5.1-codex-mini"
|
||||
|
||||
|
||||
def test_codex_provider_strips_provider_prefix_from_model(monkeypatch):
|
||||
"""openai/gpt-5.3-codex should become gpt-5.3-codex — the Codex
|
||||
Responses API does not accept provider-prefixed model slugs."""
|
||||
cli = _import_cli()
|
||||
|
||||
monkeypatch.delenv("LLM_MODEL", raising=False)
|
||||
monkeypatch.delenv("OPENAI_MODEL", raising=False)
|
||||
|
||||
def _runtime_resolve(**kwargs):
|
||||
return {
|
||||
"provider": "openai-codex",
|
||||
"api_mode": "codex_responses",
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"api_key": "test-key",
|
||||
"source": "env/config",
|
||||
}
|
||||
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.resolve_runtime_provider", _runtime_resolve)
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.format_runtime_provider_error", lambda exc: str(exc))
|
||||
|
||||
shell = cli.HermesCLI(model="openai/gpt-5.3-codex", compact=True, max_turns=1)
|
||||
|
||||
assert shell._ensure_runtime_credentials() is True
|
||||
assert shell.model == "gpt-5.3-codex"
|
||||
|
||||
|
||||
def test_cmd_model_falls_back_to_auto_on_invalid_provider(monkeypatch, capsys):
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.config.load_config",
|
||||
|
||||
@@ -30,6 +30,14 @@ def test_get_codex_model_ids_prioritizes_default_and_cache(tmp_path, monkeypatch
|
||||
assert "gpt-5-hidden-codex" not in models
|
||||
|
||||
|
||||
def test_setup_wizard_codex_import_resolves():
|
||||
"""Regression test for #712: setup.py must import the correct function name."""
|
||||
# This mirrors the exact import used in hermes_cli/setup.py line 873.
|
||||
# A prior bug had 'get_codex_models' (wrong) instead of 'get_codex_model_ids'.
|
||||
from hermes_cli.codex_models import get_codex_model_ids as setup_import
|
||||
assert callable(setup_import)
|
||||
|
||||
|
||||
def test_get_codex_model_ids_falls_back_to_curated_defaults(tmp_path, monkeypatch):
|
||||
codex_home = tmp_path / "codex-home"
|
||||
codex_home.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -351,6 +351,173 @@ class TestPruneSessions:
|
||||
# Schema and WAL mode
|
||||
# =========================================================================
|
||||
|
||||
# =========================================================================
|
||||
# Session title
|
||||
# =========================================================================
|
||||
|
||||
class TestSessionTitle:
|
||||
def test_set_and_get_title(self, db):
|
||||
db.create_session(session_id="s1", source="cli")
|
||||
assert db.set_session_title("s1", "My Session") is True
|
||||
|
||||
session = db.get_session("s1")
|
||||
assert session["title"] == "My Session"
|
||||
|
||||
def test_set_title_nonexistent_session(self, db):
|
||||
assert db.set_session_title("nonexistent", "Title") is False
|
||||
|
||||
def test_title_initially_none(self, db):
|
||||
db.create_session(session_id="s1", source="cli")
|
||||
session = db.get_session("s1")
|
||||
assert session["title"] is None
|
||||
|
||||
def test_update_title(self, db):
|
||||
db.create_session(session_id="s1", source="cli")
|
||||
db.set_session_title("s1", "First Title")
|
||||
db.set_session_title("s1", "Updated Title")
|
||||
|
||||
session = db.get_session("s1")
|
||||
assert session["title"] == "Updated Title"
|
||||
|
||||
def test_title_in_search_sessions(self, db):
|
||||
db.create_session(session_id="s1", source="cli")
|
||||
db.set_session_title("s1", "Debugging Auth")
|
||||
db.create_session(session_id="s2", source="cli")
|
||||
|
||||
sessions = db.search_sessions()
|
||||
titled = [s for s in sessions if s.get("title") == "Debugging Auth"]
|
||||
assert len(titled) == 1
|
||||
assert titled[0]["id"] == "s1"
|
||||
|
||||
def test_title_in_export(self, db):
|
||||
db.create_session(session_id="s1", source="cli")
|
||||
db.set_session_title("s1", "Export Test")
|
||||
db.append_message("s1", role="user", content="Hello")
|
||||
|
||||
export = db.export_session("s1")
|
||||
assert export["title"] == "Export Test"
|
||||
|
||||
def test_title_with_special_characters(self, db):
|
||||
db.create_session(session_id="s1", source="cli")
|
||||
title = "PR #438 — fixing the 'auth' middleware"
|
||||
db.set_session_title("s1", title)
|
||||
|
||||
session = db.get_session("s1")
|
||||
assert session["title"] == title
|
||||
|
||||
def test_title_empty_string_normalized_to_none(self, db):
|
||||
"""Empty strings are normalized to None (clearing the title)."""
|
||||
db.create_session(session_id="s1", source="cli")
|
||||
db.set_session_title("s1", "My Title")
|
||||
# Setting to empty string should clear the title (normalize to None)
|
||||
db.set_session_title("s1", "")
|
||||
|
||||
session = db.get_session("s1")
|
||||
assert session["title"] is None
|
||||
|
||||
def test_multiple_empty_titles_no_conflict(self, db):
|
||||
"""Multiple sessions can have empty-string (normalized to NULL) titles."""
|
||||
db.create_session(session_id="s1", source="cli")
|
||||
db.create_session(session_id="s2", source="cli")
|
||||
db.set_session_title("s1", "")
|
||||
db.set_session_title("s2", "")
|
||||
# Both should be None, no uniqueness conflict
|
||||
assert db.get_session("s1")["title"] is None
|
||||
assert db.get_session("s2")["title"] is None
|
||||
|
||||
def test_title_survives_end_session(self, db):
|
||||
db.create_session(session_id="s1", source="cli")
|
||||
db.set_session_title("s1", "Before End")
|
||||
db.end_session("s1", end_reason="user_exit")
|
||||
|
||||
session = db.get_session("s1")
|
||||
assert session["title"] == "Before End"
|
||||
assert session["ended_at"] is not None
|
||||
|
||||
|
||||
class TestSanitizeTitle:
|
||||
"""Tests for SessionDB.sanitize_title() validation and cleaning."""
|
||||
|
||||
def test_normal_title_unchanged(self):
|
||||
assert SessionDB.sanitize_title("My Project") == "My Project"
|
||||
|
||||
def test_strips_whitespace(self):
|
||||
assert SessionDB.sanitize_title(" hello world ") == "hello world"
|
||||
|
||||
def test_collapses_internal_whitespace(self):
|
||||
assert SessionDB.sanitize_title("hello world") == "hello world"
|
||||
|
||||
def test_tabs_and_newlines_collapsed(self):
|
||||
assert SessionDB.sanitize_title("hello\t\nworld") == "hello world"
|
||||
|
||||
def test_none_returns_none(self):
|
||||
assert SessionDB.sanitize_title(None) is None
|
||||
|
||||
def test_empty_string_returns_none(self):
|
||||
assert SessionDB.sanitize_title("") is None
|
||||
|
||||
def test_whitespace_only_returns_none(self):
|
||||
assert SessionDB.sanitize_title(" \t\n ") is None
|
||||
|
||||
def test_control_chars_stripped(self):
|
||||
# Null byte, bell, backspace, etc.
|
||||
assert SessionDB.sanitize_title("hello\x00world") == "helloworld"
|
||||
assert SessionDB.sanitize_title("\x07\x08test\x1b") == "test"
|
||||
|
||||
def test_del_char_stripped(self):
|
||||
assert SessionDB.sanitize_title("hello\x7fworld") == "helloworld"
|
||||
|
||||
def test_zero_width_chars_stripped(self):
|
||||
# Zero-width space (U+200B), zero-width joiner (U+200D)
|
||||
assert SessionDB.sanitize_title("hello\u200bworld") == "helloworld"
|
||||
assert SessionDB.sanitize_title("hello\u200dworld") == "helloworld"
|
||||
|
||||
def test_rtl_override_stripped(self):
|
||||
# Right-to-left override (U+202E) — used in filename spoofing attacks
|
||||
assert SessionDB.sanitize_title("hello\u202eworld") == "helloworld"
|
||||
|
||||
def test_bom_stripped(self):
|
||||
# Byte order mark (U+FEFF)
|
||||
assert SessionDB.sanitize_title("\ufeffhello") == "hello"
|
||||
|
||||
def test_only_control_chars_returns_none(self):
|
||||
assert SessionDB.sanitize_title("\x00\x01\x02\u200b\ufeff") is None
|
||||
|
||||
def test_max_length_allowed(self):
|
||||
title = "A" * 100
|
||||
assert SessionDB.sanitize_title(title) == title
|
||||
|
||||
def test_exceeds_max_length_raises(self):
|
||||
title = "A" * 101
|
||||
with pytest.raises(ValueError, match="too long"):
|
||||
SessionDB.sanitize_title(title)
|
||||
|
||||
def test_unicode_emoji_allowed(self):
|
||||
assert SessionDB.sanitize_title("🚀 My Project 🎉") == "🚀 My Project 🎉"
|
||||
|
||||
def test_cjk_characters_allowed(self):
|
||||
assert SessionDB.sanitize_title("我的项目") == "我的项目"
|
||||
|
||||
def test_accented_characters_allowed(self):
|
||||
assert SessionDB.sanitize_title("Résumé éditing") == "Résumé éditing"
|
||||
|
||||
def test_special_punctuation_allowed(self):
|
||||
title = "PR #438 — fixing the 'auth' middleware"
|
||||
assert SessionDB.sanitize_title(title) == title
|
||||
|
||||
def test_sanitize_applied_in_set_session_title(self, db):
|
||||
"""set_session_title applies sanitize_title internally."""
|
||||
db.create_session("s1", "cli")
|
||||
db.set_session_title("s1", " hello\x00 world ")
|
||||
assert db.get_session("s1")["title"] == "hello world"
|
||||
|
||||
def test_too_long_title_rejected_by_set(self, db):
|
||||
"""set_session_title raises ValueError for overly long titles."""
|
||||
db.create_session("s1", "cli")
|
||||
with pytest.raises(ValueError, match="too long"):
|
||||
db.set_session_title("s1", "X" * 150)
|
||||
|
||||
|
||||
class TestSchemaInit:
|
||||
def test_wal_mode(self, db):
|
||||
cursor = db._conn.execute("PRAGMA journal_mode")
|
||||
@@ -373,4 +540,297 @@ class TestSchemaInit:
|
||||
def test_schema_version(self, db):
|
||||
cursor = db._conn.execute("SELECT version FROM schema_version")
|
||||
version = cursor.fetchone()[0]
|
||||
assert version == 2
|
||||
assert version == 4
|
||||
|
||||
def test_title_column_exists(self, db):
|
||||
"""Verify the title column was created in the sessions table."""
|
||||
cursor = db._conn.execute("PRAGMA table_info(sessions)")
|
||||
columns = {row[1] for row in cursor.fetchall()}
|
||||
assert "title" in columns
|
||||
|
||||
def test_migration_from_v2(self, tmp_path):
|
||||
"""Simulate a v2 database and verify migration adds title column."""
|
||||
import sqlite3
|
||||
|
||||
db_path = tmp_path / "migrate_test.db"
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
# Create v2 schema (without title column)
|
||||
conn.executescript("""
|
||||
CREATE TABLE schema_version (version INTEGER NOT NULL);
|
||||
INSERT INTO schema_version (version) VALUES (2);
|
||||
|
||||
CREATE TABLE sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
source TEXT NOT NULL,
|
||||
user_id TEXT,
|
||||
model TEXT,
|
||||
model_config TEXT,
|
||||
system_prompt TEXT,
|
||||
parent_session_id TEXT,
|
||||
started_at REAL NOT NULL,
|
||||
ended_at REAL,
|
||||
end_reason TEXT,
|
||||
message_count INTEGER DEFAULT 0,
|
||||
tool_call_count INTEGER DEFAULT 0,
|
||||
input_tokens INTEGER DEFAULT 0,
|
||||
output_tokens INTEGER DEFAULT 0
|
||||
);
|
||||
|
||||
CREATE TABLE messages (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id TEXT NOT NULL,
|
||||
role TEXT NOT NULL,
|
||||
content TEXT,
|
||||
tool_call_id TEXT,
|
||||
tool_calls TEXT,
|
||||
tool_name TEXT,
|
||||
timestamp REAL NOT NULL,
|
||||
token_count INTEGER,
|
||||
finish_reason TEXT
|
||||
);
|
||||
""")
|
||||
conn.execute(
|
||||
"INSERT INTO sessions (id, source, started_at) VALUES (?, ?, ?)",
|
||||
("existing", "cli", 1000.0),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
# Open with SessionDB — should migrate to v4
|
||||
migrated_db = SessionDB(db_path=db_path)
|
||||
|
||||
# Verify migration
|
||||
cursor = migrated_db._conn.execute("SELECT version FROM schema_version")
|
||||
assert cursor.fetchone()[0] == 4
|
||||
|
||||
# Verify title column exists and is NULL for existing sessions
|
||||
session = migrated_db.get_session("existing")
|
||||
assert session is not None
|
||||
assert session["title"] is None
|
||||
|
||||
# Verify we can set title on migrated session
|
||||
assert migrated_db.set_session_title("existing", "Migrated Title") is True
|
||||
session = migrated_db.get_session("existing")
|
||||
assert session["title"] == "Migrated Title"
|
||||
|
||||
migrated_db.close()
|
||||
|
||||
|
||||
class TestTitleUniqueness:
    """Tests for unique title enforcement and title-based lookups."""
    # NOTE(review): assumes the `db` fixture yields a fresh session database
    # per test — confirm against conftest.py.

    def test_duplicate_title_raises(self, db):
        """Setting a title already used by another session raises ValueError."""
        db.create_session("s1", "cli")
        db.create_session("s2", "cli")
        db.set_session_title("s1", "my project")
        with pytest.raises(ValueError, match="already in use"):
            db.set_session_title("s2", "my project")

    def test_same_session_can_keep_title(self, db):
        """A session can re-set its own title without error."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        # Should not raise — it's the same session
        assert db.set_session_title("s1", "my project") is True

    def test_null_titles_not_unique(self, db):
        """Multiple sessions can have NULL titles (no constraint violation)."""
        db.create_session("s1", "cli")
        db.create_session("s2", "cli")
        # Both have NULL titles — no error
        assert db.get_session("s1")["title"] is None
        assert db.get_session("s2")["title"] is None

    def test_get_session_by_title(self, db):
        """Exact-title lookup returns the full session record."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "refactoring auth")
        result = db.get_session_by_title("refactoring auth")
        assert result is not None
        assert result["id"] == "s1"

    def test_get_session_by_title_not_found(self, db):
        """Lookup of an unknown title returns None rather than raising."""
        assert db.get_session_by_title("nonexistent") is None

    def test_get_session_title(self, db):
        """get_session_title reflects the current title (None until set)."""
        db.create_session("s1", "cli")
        assert db.get_session_title("s1") is None
        db.set_session_title("s1", "my title")
        assert db.get_session_title("s1") == "my title"

    def test_get_session_title_nonexistent(self, db):
        """Asking for the title of an unknown session id returns None."""
        assert db.get_session_title("nonexistent") is None
|
||||
|
||||
|
||||
class TestTitleLineage:
    """Tests for title lineage resolution and auto-numbering."""

    def test_resolve_exact_title(self, db):
        """An un-numbered title with no variants resolves to its own session."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        assert db.resolve_session_by_title("my project") == "s1"

    def test_resolve_returns_latest_numbered(self, db):
        """When numbered variants exist, return the most recent one."""
        import time
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        # sleeps give each session a distinct creation timestamp for ordering
        time.sleep(0.01)
        db.create_session("s2", "cli")
        db.set_session_title("s2", "my project #2")
        time.sleep(0.01)
        db.create_session("s3", "cli")
        db.set_session_title("s3", "my project #3")
        # Resolving "my project" should return s3 (latest numbered variant)
        assert db.resolve_session_by_title("my project") == "s3"

    def test_resolve_exact_numbered(self, db):
        """Resolving an exact numbered title returns that specific session."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        db.create_session("s2", "cli")
        db.set_session_title("s2", "my project #2")
        # Resolving "my project #2" exactly should return s2
        assert db.resolve_session_by_title("my project #2") == "s2"

    def test_resolve_nonexistent_title(self, db):
        """Resolution of an unknown title returns None."""
        assert db.resolve_session_by_title("nonexistent") is None

    def test_next_title_no_existing(self, db):
        """With no existing sessions, base title is returned as-is."""
        assert db.get_next_title_in_lineage("my project") == "my project"

    def test_next_title_first_continuation(self, db):
        """First continuation after the original gets #2."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        assert db.get_next_title_in_lineage("my project") == "my project #2"

    def test_next_title_increments(self, db):
        """Each continuation increments the number."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        db.create_session("s2", "cli")
        db.set_session_title("s2", "my project #2")
        db.create_session("s3", "cli")
        db.set_session_title("s3", "my project #3")
        assert db.get_next_title_in_lineage("my project") == "my project #4"

    def test_next_title_strips_existing_number(self, db):
        """Passing a numbered title strips the number and finds the base."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        db.create_session("s2", "cli")
        db.set_session_title("s2", "my project #2")
        # Even when called with "my project #2", it should return #3
        assert db.get_next_title_in_lineage("my project #2") == "my project #3"
|
||||
|
||||
|
||||
class TestTitleSqlWildcards:
    """Titles containing SQL LIKE wildcards (%, _) must not cause false matches."""

    def test_resolve_title_with_underscore(self, db):
        """A title like 'test_project' should not match 'testXproject #2'."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "test_project")
        # "_" in LIKE matches any single char, so an unescaped pattern would
        # also hit "testXproject #2" — s2 exists to catch that regression.
        db.create_session("s2", "cli")
        db.set_session_title("s2", "testXproject #2")
        # Resolving "test_project" should return s1 (exact), not s2
        assert db.resolve_session_by_title("test_project") == "s1"

    def test_resolve_title_with_percent(self, db):
        """A title with '%' should not wildcard-match unrelated sessions."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "100% done")
        db.create_session("s2", "cli")
        db.set_session_title("s2", "100X done #2")
        # Should resolve to s1 (exact), not s2
        assert db.resolve_session_by_title("100% done") == "s1"

    def test_next_lineage_with_underscore(self, db):
        """get_next_title_in_lineage with underscores doesn't match wrong sessions."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "test_project")
        db.create_session("s2", "cli")
        db.set_session_title("s2", "testXproject #2")
        # Only "test_project" exists, so next should be "test_project #2"
        assert db.get_next_title_in_lineage("test_project") == "test_project #2"
|
||||
|
||||
|
||||
class TestListSessionsRich:
    """Tests for enhanced session listing with preview and last_active."""

    def test_preview_from_first_user_message(self, db):
        """Preview is taken from the first *user* message, not system/assistant."""
        db.create_session("s1", "cli")
        db.append_message("s1", "system", "You are a helpful assistant.")
        db.append_message("s1", "user", "Help me refactor the auth module please")
        db.append_message("s1", "assistant", "Sure, let me look at it.")
        sessions = db.list_sessions_rich()
        assert len(sessions) == 1
        assert "Help me refactor the auth module" in sessions[0]["preview"]

    def test_preview_truncated_at_60(self, db):
        """Long previews are cut to 60 characters plus an ellipsis."""
        db.create_session("s1", "cli")
        long_msg = "A" * 100
        db.append_message("s1", "user", long_msg)
        sessions = db.list_sessions_rich()
        assert len(sessions[0]["preview"]) == 63  # 60 chars + "..."
        assert sessions[0]["preview"].endswith("...")

    def test_preview_empty_when_no_user_messages(self, db):
        """A session with only system messages yields an empty preview."""
        db.create_session("s1", "cli")
        db.append_message("s1", "system", "System prompt")
        sessions = db.list_sessions_rich()
        assert sessions[0]["preview"] == ""

    def test_last_active_from_latest_message(self, db):
        """last_active tracks the timestamp of the newest message."""
        import time
        db.create_session("s1", "cli")
        db.append_message("s1", "user", "Hello")
        time.sleep(0.01)
        db.append_message("s1", "assistant", "Hi there!")
        sessions = db.list_sessions_rich()
        # last_active should be close to now (the assistant message)
        assert sessions[0]["last_active"] > sessions[0]["started_at"]

    def test_last_active_fallback_to_started_at(self, db):
        """With no messages, last_active falls back to the session start time."""
        db.create_session("s1", "cli")
        sessions = db.list_sessions_rich()
        # No messages, so last_active falls back to started_at
        assert sessions[0]["last_active"] == sessions[0]["started_at"]

    def test_rich_list_includes_title(self, db):
        """The rich listing carries each session's title through."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "refactoring auth")
        sessions = db.list_sessions_rich()
        assert sessions[0]["title"] == "refactoring auth"

    def test_rich_list_source_filter(self, db):
        """source= filters the listing to sessions from that origin only."""
        db.create_session("s1", "cli")
        db.create_session("s2", "telegram")
        sessions = db.list_sessions_rich(source="cli")
        assert len(sessions) == 1
        assert sessions[0]["id"] == "s1"

    def test_preview_newlines_collapsed(self, db):
        """Multi-line messages are flattened to a single-line preview."""
        db.create_session("s1", "cli")
        db.append_message("s1", "user", "Line one\nLine two\nLine three")
        sessions = db.list_sessions_rich()
        assert "\n" not in sessions[0]["preview"]
        assert "Line one Line two" in sessions[0]["preview"]
|
||||
|
||||
|
||||
class TestResolveSessionByNameOrId:
    """Tests for the main.py helper that resolves names or IDs."""

    def test_resolve_by_id(self, db):
        """A raw session id resolves directly via get_session."""
        db.create_session("test-id-123", "cli")
        session = db.get_session("test-id-123")
        assert session is not None
        assert session["id"] == "test-id-123"

    def test_resolve_by_title_falls_back(self, db):
        """When the argument is not an id, title resolution is used instead."""
        db.create_session("s1", "cli")
        db.set_session_title("s1", "my project")
        result = db.resolve_session_by_title("my project")
        assert result == "s1"
|
||||
|
||||
@@ -145,7 +145,7 @@ class TestBuildApiKwargsCodex:
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
kwargs = agent._build_api_kwargs(messages)
|
||||
assert "reasoning" in kwargs
|
||||
assert kwargs["reasoning"]["effort"] == "xhigh"
|
||||
assert kwargs["reasoning"]["effort"] == "medium"
|
||||
|
||||
def test_includes_encrypted_content_in_include(self, monkeypatch):
|
||||
agent = _make_agent(monkeypatch, "openai-codex", api_mode="codex_responses",
|
||||
@@ -596,19 +596,19 @@ class TestCodexReasoningPreflight:
|
||||
# ── Reasoning effort consistency tests ───────────────────────────────────────
|
||||
|
||||
class TestReasoningEffortDefaults:
|
||||
"""Verify reasoning effort defaults to xhigh across all provider paths."""
|
||||
"""Verify reasoning effort defaults to medium across all provider paths."""
|
||||
|
||||
def test_openrouter_default_xhigh(self, monkeypatch):
|
||||
def test_openrouter_default_medium(self, monkeypatch):
|
||||
agent = _make_agent(monkeypatch, "openrouter")
|
||||
kwargs = agent._build_api_kwargs([{"role": "user", "content": "hi"}])
|
||||
reasoning = kwargs["extra_body"]["reasoning"]
|
||||
assert reasoning["effort"] == "xhigh"
|
||||
assert reasoning["effort"] == "medium"
|
||||
|
||||
def test_codex_default_xhigh(self, monkeypatch):
|
||||
def test_codex_default_medium(self, monkeypatch):
|
||||
agent = _make_agent(monkeypatch, "openai-codex", api_mode="codex_responses",
|
||||
base_url="https://chatgpt.com/backend-api/codex")
|
||||
kwargs = agent._build_api_kwargs([{"role": "user", "content": "hi"}])
|
||||
assert kwargs["reasoning"]["effort"] == "xhigh"
|
||||
assert kwargs["reasoning"]["effort"] == "medium"
|
||||
|
||||
def test_codex_reasoning_disabled(self, monkeypatch):
|
||||
agent = _make_agent(monkeypatch, "openai-codex", api_mode="codex_responses",
|
||||
|
||||
@@ -280,22 +280,21 @@ class TestMaskApiKey:
|
||||
|
||||
|
||||
class TestInit:
|
||||
def test_anthropic_base_url_fails_fast(self):
|
||||
"""Anthropic native endpoints should error before building an OpenAI client."""
|
||||
def test_anthropic_base_url_accepted(self):
|
||||
"""Anthropic base URLs should be accepted (OpenAI-compatible endpoint)."""
|
||||
with (
|
||||
patch("run_agent.get_tool_definitions", return_value=[]),
|
||||
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||
patch("run_agent.OpenAI") as mock_openai,
|
||||
):
|
||||
with pytest.raises(ValueError, match="not supported yet"):
|
||||
AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
base_url="https://api.anthropic.com/v1/messages",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
mock_openai.assert_not_called()
|
||||
AIAgent(
|
||||
api_key="test-key-1234567890",
|
||||
base_url="https://api.anthropic.com/v1/",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
mock_openai.assert_called_once()
|
||||
|
||||
def test_prompt_caching_claude_openrouter(self):
|
||||
"""Claude model via OpenRouter should enable prompt caching."""
|
||||
@@ -498,12 +497,12 @@ class TestBuildApiKwargs:
|
||||
assert kwargs["extra_body"]["provider"]["only"] == ["Anthropic"]
|
||||
|
||||
def test_reasoning_config_default_openrouter(self, agent):
|
||||
"""Default reasoning config for OpenRouter should be xhigh."""
|
||||
"""Default reasoning config for OpenRouter should be medium."""
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
kwargs = agent._build_api_kwargs(messages)
|
||||
reasoning = kwargs["extra_body"]["reasoning"]
|
||||
assert reasoning["enabled"] is True
|
||||
assert reasoning["effort"] == "xhigh"
|
||||
assert reasoning["effort"] == "medium"
|
||||
|
||||
def test_reasoning_config_custom(self, agent):
|
||||
agent.reasoning_config = {"enabled": False}
|
||||
|
||||
635
tests/test_worktree.py
Normal file
635
tests/test_worktree.py
Normal file
@@ -0,0 +1,635 @@
|
||||
"""Tests for git worktree isolation (CLI --worktree / -w flag).
|
||||
|
||||
Verifies worktree creation, cleanup, .worktreeinclude handling,
|
||||
.gitignore management, and integration with the CLI. (#652)
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
@pytest.fixture
def git_repo(tmp_path):
    """Yield a throwaway git repository containing one initial commit.

    Worktrees cannot be added to a repo with no commits, so the fixture
    seeds the repo with a committed README.
    """
    repo = tmp_path / "test-repo"
    repo.mkdir()
    # Initialise the repo and give it an identity so commits succeed.
    for cmd in (
        ["git", "init"],
        ["git", "config", "user.email", "test@test.com"],
        ["git", "config", "user.name", "Test"],
    ):
        subprocess.run(cmd, cwd=repo, capture_output=True)
    # Create initial commit (worktrees need at least one commit)
    (repo / "README.md").write_text("# Test Repo\n")
    subprocess.run(["git", "add", "."], cwd=repo, capture_output=True)
    subprocess.run(["git", "commit", "-m", "Initial commit"], cwd=repo, capture_output=True)
    return repo
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lightweight reimplementations for testing (avoid importing cli.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _git_repo_root(cwd=None):
    """Return the toplevel of the git repo containing *cwd*, else None.

    Mirrors cli.py's helper: any failure (git missing, timeout, not a
    repo) is reported as None rather than an exception.
    """
    try:
        proc = subprocess.run(
            ["git", "rev-parse", "--show-toplevel"],
            capture_output=True, text=True, timeout=5,
            cwd=cwd,
        )
    except Exception:
        return None
    # Non-zero exit code means cwd is outside any working tree.
    return proc.stdout.strip() if proc.returncode == 0 else None
|
||||
|
||||
|
||||
def _setup_worktree(repo_root):
    """Create a uniquely-named worktree under <repo_root>/.worktrees.

    Returns a dict with "path", "branch" and "repo_root" on success,
    or None when `git worktree add` fails (e.g. repo has no commits).
    """
    import uuid

    # Random suffix keeps concurrent worktrees from colliding.
    suffix = uuid.uuid4().hex[:8]
    wt_name = f"hermes-{suffix}"
    branch_name = f"hermes/{wt_name}"

    container = Path(repo_root) / ".worktrees"
    container.mkdir(parents=True, exist_ok=True)
    target = container / wt_name

    proc = subprocess.run(
        ["git", "worktree", "add", str(target), "-b", branch_name, "HEAD"],
        capture_output=True, text=True, timeout=30, cwd=repo_root,
    )
    if proc.returncode != 0:
        return None

    return {"path": str(target), "branch": branch_name, "repo_root": repo_root}
|
||||
|
||||
|
||||
def _cleanup_worktree(info):
    """Remove a worktree and its branch unless it has uncommitted changes.

    Returns True when the worktree was removed, False when it was kept
    because it was dirty, and None (implicit) when it no longer exists.
    """
    wt_path, branch, repo_root = info["path"], info["branch"], info["repo_root"]

    # Already gone — nothing to do.
    if not Path(wt_path).exists():
        return

    # A worktree with uncommitted changes is preserved so no work is lost.
    status = subprocess.run(
        ["git", "status", "--porcelain"],
        capture_output=True, text=True, timeout=10, cwd=wt_path,
    )
    if status.stdout.strip():
        return False  # Did not clean up

    # Remove the worktree first, then its dedicated branch.
    for cmd, limit in (
        (["git", "worktree", "remove", wt_path, "--force"], 15),
        (["git", "branch", "-D", branch], 10),
    ):
        subprocess.run(
            cmd, capture_output=True, text=True, timeout=limit, cwd=repo_root,
        )
    return True  # Cleaned up
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGitRepoDetection:
    """Test git repo root detection."""

    def test_detects_git_repo(self, git_repo):
        """The repo root itself is detected as the toplevel."""
        root = _git_repo_root(cwd=str(git_repo))
        assert root is not None
        assert Path(root).resolve() == git_repo.resolve()

    def test_detects_subdirectory(self, git_repo):
        """Detection from a nested subdirectory still finds the repo root."""
        subdir = git_repo / "src" / "lib"
        subdir.mkdir(parents=True)
        root = _git_repo_root(cwd=str(subdir))
        assert root is not None
        assert Path(root).resolve() == git_repo.resolve()

    def test_returns_none_outside_repo(self, tmp_path):
        """Outside any repository, detection reports None instead of raising."""
        # tmp_path itself is not a git repo
        bare_dir = tmp_path / "not-a-repo"
        bare_dir.mkdir()
        root = _git_repo_root(cwd=str(bare_dir))
        assert root is None
|
||||
|
||||
|
||||
class TestWorktreeCreation:
    """Test worktree setup."""

    def test_creates_worktree(self, git_repo):
        """Setup creates a real worktree on a fresh hermes/* branch."""
        info = _setup_worktree(str(git_repo))
        assert info is not None
        assert Path(info["path"]).exists()
        assert info["branch"].startswith("hermes/hermes-")
        assert info["repo_root"] == str(git_repo)

        # Verify it's a valid git worktree
        result = subprocess.run(
            ["git", "rev-parse", "--is-inside-work-tree"],
            capture_output=True, text=True, cwd=info["path"],
        )
        assert result.stdout.strip() == "true"

    def test_worktree_has_own_branch(self, git_repo):
        """The worktree is checked out on its dedicated branch."""
        info = _setup_worktree(str(git_repo))
        assert info is not None

        # Check branch name in worktree
        result = subprocess.run(
            ["git", "branch", "--show-current"],
            capture_output=True, text=True, cwd=info["path"],
        )
        assert result.stdout.strip() == info["branch"]

    def test_worktree_is_independent(self, git_repo):
        """Two worktrees from the same repo are independent."""
        info1 = _setup_worktree(str(git_repo))
        info2 = _setup_worktree(str(git_repo))
        assert info1 is not None
        assert info2 is not None
        assert info1["path"] != info2["path"]
        assert info1["branch"] != info2["branch"]

        # Create a file in worktree 1
        (Path(info1["path"]) / "only-in-wt1.txt").write_text("hello")

        # It should NOT appear in worktree 2
        assert not (Path(info2["path"]) / "only-in-wt1.txt").exists()

    def test_worktrees_dir_created(self, git_repo):
        """Setup creates the .worktrees container directory on demand."""
        info = _setup_worktree(str(git_repo))
        assert info is not None
        assert (git_repo / ".worktrees").is_dir()

    def test_worktree_has_repo_files(self, git_repo):
        """Worktree should contain the repo's tracked files."""
        info = _setup_worktree(str(git_repo))
        assert info is not None
        assert (Path(info["path"]) / "README.md").exists()
|
||||
|
||||
|
||||
class TestWorktreeCleanup:
    """Test worktree cleanup on exit."""

    def test_clean_worktree_removed(self, git_repo):
        """A worktree with no local changes is removed and reports True."""
        info = _setup_worktree(str(git_repo))
        assert info is not None
        assert Path(info["path"]).exists()

        result = _cleanup_worktree(info)
        assert result is True
        assert not Path(info["path"]).exists()

    def test_dirty_worktree_kept(self, git_repo):
        """Uncommitted (staged) changes make cleanup refuse to delete."""
        info = _setup_worktree(str(git_repo))
        assert info is not None

        # Make uncommitted changes
        (Path(info["path"]) / "new-file.txt").write_text("uncommitted")
        subprocess.run(
            ["git", "add", "new-file.txt"],
            cwd=info["path"], capture_output=True,
        )

        result = _cleanup_worktree(info)
        assert result is False
        assert Path(info["path"]).exists()  # Still there

    def test_branch_deleted_on_cleanup(self, git_repo):
        """Cleanup also deletes the worktree's dedicated branch."""
        info = _setup_worktree(str(git_repo))
        branch = info["branch"]

        _cleanup_worktree(info)

        # Branch should be gone
        result = subprocess.run(
            ["git", "branch", "--list", branch],
            capture_output=True, text=True, cwd=str(git_repo),
        )
        assert branch not in result.stdout

    def test_cleanup_nonexistent_worktree(self, git_repo):
        """Cleanup should handle already-removed worktrees gracefully."""
        info = {
            "path": str(git_repo / ".worktrees" / "nonexistent"),
            "branch": "hermes/nonexistent",
            "repo_root": str(git_repo),
        }
        # Should not raise
        _cleanup_worktree(info)
|
||||
|
||||
|
||||
class TestWorktreeInclude:
    """Test .worktreeinclude file handling."""

    def test_copies_included_files(self, git_repo):
        """Files listed in .worktreeinclude should be copied to the worktree."""
        # Create a .env file (gitignored)
        (git_repo / ".env").write_text("SECRET=abc123")
        (git_repo / ".gitignore").write_text(".env\n.worktrees/\n")
        subprocess.run(
            ["git", "add", ".gitignore"],
            cwd=str(git_repo), capture_output=True,
        )
        subprocess.run(
            ["git", "commit", "-m", "Add gitignore"],
            cwd=str(git_repo), capture_output=True,
        )

        # Create .worktreeinclude
        (git_repo / ".worktreeinclude").write_text(".env\n")

        # Import and use the real _setup_worktree logic for include handling
        info = _setup_worktree(str(git_repo))
        assert info is not None

        # Manually copy .worktreeinclude entries (mirrors cli.py logic)
        import shutil
        include_file = git_repo / ".worktreeinclude"
        wt_path = Path(info["path"])
        for line in include_file.read_text().splitlines():
            entry = line.strip()
            # Blank lines and '#' comments are not include entries.
            if not entry or entry.startswith("#"):
                continue
            src = git_repo / entry
            dst = wt_path / entry
            if src.is_file():
                dst.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(str(src), str(dst))

        # Verify .env was copied
        assert (wt_path / ".env").exists()
        assert (wt_path / ".env").read_text() == "SECRET=abc123"

    def test_ignores_comments_and_blanks(self, git_repo):
        """Comments and blank lines in .worktreeinclude should be skipped."""
        (git_repo / ".worktreeinclude").write_text(
            "# This is a comment\n"
            "\n"
            "  # Another comment\n"
        )
        info = _setup_worktree(str(git_repo))
        assert info is not None
        # Should not crash — just skip all lines
|
||||
|
||||
|
||||
class TestGitignoreManagement:
    """Test that .worktrees/ is added to .gitignore."""

    def test_adds_to_gitignore(self, git_repo):
        """Creating a worktree should add .worktrees/ to .gitignore."""
        # Remove any existing .gitignore
        gitignore = git_repo / ".gitignore"
        if gitignore.exists():
            gitignore.unlink()

        info = _setup_worktree(str(git_repo))
        assert info is not None

        # Now manually add .worktrees/ to .gitignore (mirrors cli.py logic)
        _ignore_entry = ".worktrees/"
        existing = gitignore.read_text() if gitignore.exists() else ""
        if _ignore_entry not in existing.splitlines():
            with open(gitignore, "a") as f:
                # Keep the entry on its own line even if the file lacked a
                # trailing newline.
                if existing and not existing.endswith("\n"):
                    f.write("\n")
                f.write(f"{_ignore_entry}\n")

        content = gitignore.read_text()
        assert ".worktrees/" in content

    def test_does_not_duplicate_gitignore_entry(self, git_repo):
        """If .worktrees/ is already in .gitignore, don't add again."""
        gitignore = git_repo / ".gitignore"
        gitignore.write_text(".worktrees/\n")

        # The check should see it's already there
        existing = gitignore.read_text()
        assert ".worktrees/" in existing.splitlines()
|
||||
|
||||
|
||||
class TestMultipleWorktrees:
    """Test running multiple worktrees concurrently (the core use case)."""

    def test_ten_concurrent_worktrees(self, git_repo):
        """Create 10 worktrees — simulating 10 parallel agents."""
        worktrees = []
        for _ in range(10):
            info = _setup_worktree(str(git_repo))
            assert info is not None
            worktrees.append(info)

        # All should exist and be independent
        paths = [info["path"] for info in worktrees]
        assert len(set(paths)) == 10  # All unique

        # Each should have the repo files
        for info in worktrees:
            assert (Path(info["path"]) / "README.md").exists()

        # Edit a file in one worktree
        (Path(worktrees[0]["path"]) / "README.md").write_text("Modified in wt0")

        # Others should be unaffected
        for info in worktrees[1:]:
            assert (Path(info["path"]) / "README.md").read_text() == "# Test Repo\n"

        # List worktrees via git
        result = subprocess.run(
            ["git", "worktree", "list"],
            capture_output=True, text=True, cwd=str(git_repo),
        )
        # Should have 11 entries: main + 10 worktrees
        lines = [l for l in result.stdout.strip().splitlines() if l.strip()]
        assert len(lines) == 11

        # Cleanup all
        for info in worktrees:
            # Discard changes first so cleanup works
            subprocess.run(
                ["git", "checkout", "--", "."],
                cwd=info["path"], capture_output=True,
            )
            _cleanup_worktree(info)

        # All should be removed
        for info in worktrees:
            assert not Path(info["path"]).exists()
|
||||
|
||||
|
||||
class TestWorktreeDirectorySymlink:
    """Test .worktreeinclude with directories (symlinked)."""

    def test_symlinks_directory(self, git_repo):
        """Directories in .worktreeinclude should be symlinked."""
        # Create a .venv directory
        venv_dir = git_repo / ".venv" / "lib"
        venv_dir.mkdir(parents=True)
        (venv_dir / "marker.txt").write_text("venv marker")
        (git_repo / ".gitignore").write_text(".venv/\n.worktrees/\n")
        subprocess.run(
            ["git", "add", ".gitignore"], cwd=str(git_repo), capture_output=True
        )
        subprocess.run(
            ["git", "commit", "-m", "gitignore"], cwd=str(git_repo), capture_output=True
        )

        (git_repo / ".worktreeinclude").write_text(".venv/\n")

        info = _setup_worktree(str(git_repo))
        assert info is not None

        wt_path = Path(info["path"])
        src = git_repo / ".venv"
        dst = wt_path / ".venv"

        # Manually symlink (mirrors cli.py logic)
        if not dst.exists():
            dst.parent.mkdir(parents=True, exist_ok=True)
            os.symlink(str(src.resolve()), str(dst))

        # Symlink must resolve through to the original directory contents.
        assert dst.is_symlink()
        assert (dst / "lib" / "marker.txt").read_text() == "venv marker"
|
||||
|
||||
|
||||
class TestStaleWorktreePruning:
    """Test _prune_stale_worktrees garbage collection."""

    def test_prunes_old_clean_worktree(self, git_repo):
        """Old clean worktrees should be removed on prune."""
        import time

        info = _setup_worktree(str(git_repo))
        assert info is not None
        assert Path(info["path"]).exists()

        # Make the worktree look old (set mtime to 25h ago)
        old_time = time.time() - (25 * 3600)
        os.utime(info["path"], (old_time, old_time))

        # Reimplementation of prune logic (matches cli.py)
        worktrees_dir = git_repo / ".worktrees"
        cutoff = time.time() - (24 * 3600)

        for entry in worktrees_dir.iterdir():
            # Only hermes-managed worktree directories are candidates.
            if not entry.is_dir() or not entry.name.startswith("hermes-"):
                continue
            try:
                mtime = entry.stat().st_mtime
                if mtime > cutoff:
                    continue
            except Exception:
                continue

            # Dirty worktrees are never pruned.
            status = subprocess.run(
                ["git", "status", "--porcelain"],
                capture_output=True, text=True, timeout=5, cwd=str(entry),
            )
            if status.stdout.strip():
                continue

            # Capture the branch before removing the worktree, then delete both.
            branch_result = subprocess.run(
                ["git", "branch", "--show-current"],
                capture_output=True, text=True, timeout=5, cwd=str(entry),
            )
            branch = branch_result.stdout.strip()
            subprocess.run(
                ["git", "worktree", "remove", str(entry), "--force"],
                capture_output=True, text=True, timeout=15, cwd=str(git_repo),
            )
            if branch:
                subprocess.run(
                    ["git", "branch", "-D", branch],
                    capture_output=True, text=True, timeout=10, cwd=str(git_repo),
                )

        assert not Path(info["path"]).exists()

    def test_keeps_recent_worktree(self, git_repo):
        """Recent worktrees should NOT be pruned."""
        import time

        info = _setup_worktree(str(git_repo))
        assert info is not None

        # Don't modify mtime — it's recent
        worktrees_dir = git_repo / ".worktrees"
        cutoff = time.time() - (24 * 3600)

        pruned = False
        for entry in worktrees_dir.iterdir():
            if not entry.is_dir() or not entry.name.startswith("hermes-"):
                continue
            mtime = entry.stat().st_mtime
            if mtime > cutoff:
                continue  # Too recent
            pruned = True

        assert not pruned
        assert Path(info["path"]).exists()

    def test_keeps_dirty_old_worktree(self, git_repo):
        """Old worktrees with uncommitted changes should NOT be pruned."""
        import time

        info = _setup_worktree(str(git_repo))
        assert info is not None

        # Make it dirty
        (Path(info["path"]) / "dirty.txt").write_text("uncommitted")
        subprocess.run(
            ["git", "add", "dirty.txt"],
            cwd=info["path"], capture_output=True,
        )

        # Make it old
        old_time = time.time() - (25 * 3600)
        os.utime(info["path"], (old_time, old_time))

        # Check if it would be pruned
        status = subprocess.run(
            ["git", "status", "--porcelain"],
            capture_output=True, text=True, cwd=info["path"],
        )
        has_changes = bool(status.stdout.strip())
        assert has_changes  # Should be dirty → not pruned
        assert Path(info["path"]).exists()
|
||||
|
||||
|
||||
class TestEdgeCases:
    """Test edge cases for robustness."""

    def test_no_commits_repo(self, tmp_path):
        """Worktree creation should fail gracefully on a repo with no commits."""
        repo = tmp_path / "empty-repo"
        repo.mkdir()
        subprocess.run(["git", "init"], cwd=str(repo), capture_output=True)

        info = _setup_worktree(str(repo))
        assert info is None  # Should fail gracefully

    def test_not_a_git_repo(self, tmp_path):
        """Repo detection should return None for non-git directories."""
        bare = tmp_path / "not-git"
        bare.mkdir()
        root = _git_repo_root(cwd=str(bare))
        assert root is None

    def test_worktrees_dir_already_exists(self, git_repo):
        """Should work fine if .worktrees/ already exists."""
        (git_repo / ".worktrees").mkdir(exist_ok=True)
        info = _setup_worktree(str(git_repo))
        assert info is not None
        assert Path(info["path"]).exists()
|
||||
|
||||
|
||||
class TestCLIFlagLogic:
    """Test the flag/config OR logic from main()."""

    def test_worktree_flag_triggers(self):
        """--worktree flag should trigger worktree creation."""
        worktree, w, config_worktree = True, False, False
        assert worktree or w or config_worktree

    def test_w_flag_triggers(self):
        """-w flag should trigger worktree creation."""
        worktree, w, config_worktree = False, True, False
        assert worktree or w or config_worktree

    def test_config_triggers(self):
        """worktree: true in config should trigger worktree creation."""
        worktree, w, config_worktree = False, False, True
        assert worktree or w or config_worktree

    def test_none_set_no_trigger(self):
        """No flags and no config should not trigger."""
        worktree, w, config_worktree = False, False, False
        assert not (worktree or w or config_worktree)
|
||||
|
||||
|
||||
class TestTerminalCWDIntegration:
    """Test that TERMINAL_CWD is correctly set to the worktree path."""

    def test_terminal_cwd_set(self, git_repo):
        """After worktree setup, TERMINAL_CWD should point to the worktree."""
        info = _setup_worktree(str(git_repo))
        assert info is not None

        # This is what main() does:
        os.environ["TERMINAL_CWD"] = info["path"]
        assert os.environ["TERMINAL_CWD"] == info["path"]
        assert Path(os.environ["TERMINAL_CWD"]).exists()

        # Clean up env
        del os.environ["TERMINAL_CWD"]

    def test_terminal_cwd_is_valid_git_repo(self, git_repo):
        """The TERMINAL_CWD worktree should be a valid git working tree."""
        info = _setup_worktree(str(git_repo))
        assert info is not None

        result = subprocess.run(
            ["git", "rev-parse", "--is-inside-work-tree"],
            capture_output=True, text=True, cwd=info["path"],
        )
        assert result.stdout.strip() == "true"
|
||||
|
||||
|
||||
class TestSystemPromptInjection:
|
||||
"""Test that the agent gets worktree context in its system prompt."""
|
||||
|
||||
def test_prompt_note_format(self, git_repo):
|
||||
"""Verify the system prompt note contains all required info."""
|
||||
info = _setup_worktree(str(git_repo))
|
||||
assert info is not None
|
||||
|
||||
# This is what main() does:
|
||||
wt_note = (
|
||||
f"\n\n[System note: You are working in an isolated git worktree at "
|
||||
f"{info['path']}. Your branch is `{info['branch']}`. "
|
||||
f"Changes here do not affect the main working tree or other agents. "
|
||||
f"Remember to commit and push your changes, and create a PR if appropriate. "
|
||||
f"The original repo is at {info['repo_root']}.]"
|
||||
)
|
||||
|
||||
assert info["path"] in wt_note
|
||||
assert info["branch"] in wt_note
|
||||
assert info["repo_root"] in wt_note
|
||||
assert "isolated git worktree" in wt_note
|
||||
assert "commit and push" in wt_note
|
||||
@@ -550,14 +550,13 @@ class TestConvertToPng:
|
||||
"""BMP file should still be reported as success if no converter available."""
|
||||
dest = tmp_path / "img.png"
|
||||
dest.write_bytes(FAKE_BMP) # it's a BMP but named .png
|
||||
# Both Pillow and ImageMagick fail
|
||||
with patch("hermes_cli.clipboard.subprocess.run", side_effect=FileNotFoundError):
|
||||
# Pillow import fails
|
||||
with pytest.raises(Exception):
|
||||
from PIL import Image # noqa — this may or may not work
|
||||
# The function should still return True if file exists and has content
|
||||
# (raw BMP is better than nothing)
|
||||
assert dest.exists() and dest.stat().st_size > 0
|
||||
# Both Pillow and ImageMagick unavailable
|
||||
with patch.dict(sys.modules, {"PIL": None, "PIL.Image": None}):
|
||||
with patch("hermes_cli.clipboard.subprocess.run", side_effect=FileNotFoundError):
|
||||
result = _convert_to_png(dest)
|
||||
# Raw BMP is better than nothing — function should return True
|
||||
assert result is True
|
||||
assert dest.exists() and dest.stat().st_size > 0
|
||||
|
||||
def test_imagemagick_failure_preserves_original(self, tmp_path):
|
||||
"""When ImageMagick convert fails, the original file must not be lost."""
|
||||
@@ -647,11 +646,11 @@ class TestHasClipboardImage:
|
||||
|
||||
|
||||
# ═════════════════════════════════════════════════════════════════════════
|
||||
# Level 2: _build_multimodal_content — image → OpenAI vision format
|
||||
# Level 2: _preprocess_images_with_vision — image → text via vision tool
|
||||
# ═════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestBuildMultimodalContent:
|
||||
"""Test the extracted _build_multimodal_content method directly."""
|
||||
class TestPreprocessImagesWithVision:
|
||||
"""Test vision-based image pre-processing for the CLI."""
|
||||
|
||||
@pytest.fixture
|
||||
def cli(self):
|
||||
@@ -682,55 +681,81 @@ class TestBuildMultimodalContent:
|
||||
img.write_bytes(content)
|
||||
return img
|
||||
|
||||
def _mock_vision_success(self, description="A test image with colored pixels."):
|
||||
"""Return an async mock that simulates a successful vision_analyze_tool call."""
|
||||
import json
|
||||
async def _fake_vision(**kwargs):
|
||||
return json.dumps({"success": True, "analysis": description})
|
||||
return _fake_vision
|
||||
|
||||
def _mock_vision_failure(self):
|
||||
"""Return an async mock that simulates a failed vision_analyze_tool call."""
|
||||
import json
|
||||
async def _fake_vision(**kwargs):
|
||||
return json.dumps({"success": False, "analysis": "Error"})
|
||||
return _fake_vision
|
||||
|
||||
def test_single_image_with_text(self, cli, tmp_path):
|
||||
img = self._make_image(tmp_path)
|
||||
result = cli._build_multimodal_content("Describe this", [img])
|
||||
with patch("tools.vision_tools.vision_analyze_tool", side_effect=self._mock_vision_success()):
|
||||
result = cli._preprocess_images_with_vision("Describe this", [img])
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0] == {"type": "text", "text": "Describe this"}
|
||||
assert result[1]["type"] == "image_url"
|
||||
url = result[1]["image_url"]["url"]
|
||||
assert url.startswith("data:image/png;base64,")
|
||||
# Verify the base64 actually decodes to our image
|
||||
b64_data = url.split(",", 1)[1]
|
||||
assert base64.b64decode(b64_data) == FAKE_PNG
|
||||
assert isinstance(result, str)
|
||||
assert "A test image with colored pixels." in result
|
||||
assert "Describe this" in result
|
||||
assert str(img) in result
|
||||
assert "base64," not in result # no raw base64 image content
|
||||
|
||||
def test_multiple_images(self, cli, tmp_path):
|
||||
imgs = [self._make_image(tmp_path, f"img{i}.png") for i in range(3)]
|
||||
result = cli._build_multimodal_content("Compare", imgs)
|
||||
assert len(result) == 4 # 1 text + 3 images
|
||||
assert all(r["type"] == "image_url" for r in result[1:])
|
||||
with patch("tools.vision_tools.vision_analyze_tool", side_effect=self._mock_vision_success()):
|
||||
result = cli._preprocess_images_with_vision("Compare", imgs)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert "Compare" in result
|
||||
# Each image path should be referenced
|
||||
for img in imgs:
|
||||
assert str(img) in result
|
||||
|
||||
def test_empty_text_gets_default_question(self, cli, tmp_path):
|
||||
img = self._make_image(tmp_path)
|
||||
result = cli._build_multimodal_content("", [img])
|
||||
assert result[0]["text"] == "What do you see in this image?"
|
||||
|
||||
def test_jpeg_mime_type(self, cli, tmp_path):
|
||||
img = self._make_image(tmp_path, "photo.jpg", b"\xff\xd8\xff\x00" * 20)
|
||||
result = cli._build_multimodal_content("test", [img])
|
||||
assert "image/jpeg" in result[1]["image_url"]["url"]
|
||||
|
||||
def test_webp_mime_type(self, cli, tmp_path):
|
||||
img = self._make_image(tmp_path, "img.webp", b"RIFF\x00\x00" * 10)
|
||||
result = cli._build_multimodal_content("test", [img])
|
||||
assert "image/webp" in result[1]["image_url"]["url"]
|
||||
|
||||
def test_unknown_extension_defaults_to_png(self, cli, tmp_path):
|
||||
img = self._make_image(tmp_path, "data.bmp", b"\x00" * 50)
|
||||
result = cli._build_multimodal_content("test", [img])
|
||||
assert "image/png" in result[1]["image_url"]["url"]
|
||||
with patch("tools.vision_tools.vision_analyze_tool", side_effect=self._mock_vision_success()):
|
||||
result = cli._preprocess_images_with_vision("", [img])
|
||||
assert isinstance(result, str)
|
||||
assert "A test image with colored pixels." in result
|
||||
|
||||
def test_missing_image_skipped(self, cli, tmp_path):
|
||||
missing = tmp_path / "gone.png"
|
||||
result = cli._build_multimodal_content("test", [missing])
|
||||
assert len(result) == 1 # only text
|
||||
with patch("tools.vision_tools.vision_analyze_tool", side_effect=self._mock_vision_success()):
|
||||
result = cli._preprocess_images_with_vision("test", [missing])
|
||||
# No images analyzed, falls back to default
|
||||
assert result == "test"
|
||||
|
||||
def test_mix_of_existing_and_missing(self, cli, tmp_path):
|
||||
real = self._make_image(tmp_path, "real.png")
|
||||
missing = tmp_path / "gone.png"
|
||||
result = cli._build_multimodal_content("test", [real, missing])
|
||||
assert len(result) == 2 # text + 1 real image
|
||||
with patch("tools.vision_tools.vision_analyze_tool", side_effect=self._mock_vision_success()):
|
||||
result = cli._preprocess_images_with_vision("test", [real, missing])
|
||||
assert str(real) in result
|
||||
assert str(missing) not in result
|
||||
assert "test" in result
|
||||
|
||||
def test_vision_failure_includes_path(self, cli, tmp_path):
|
||||
img = self._make_image(tmp_path)
|
||||
with patch("tools.vision_tools.vision_analyze_tool", side_effect=self._mock_vision_failure()):
|
||||
result = cli._preprocess_images_with_vision("check this", [img])
|
||||
assert isinstance(result, str)
|
||||
assert str(img) in result # path still included for retry
|
||||
assert "check this" in result
|
||||
|
||||
def test_vision_exception_includes_path(self, cli, tmp_path):
|
||||
img = self._make_image(tmp_path)
|
||||
async def _explode(**kwargs):
|
||||
raise RuntimeError("API down")
|
||||
with patch("tools.vision_tools.vision_analyze_tool", side_effect=_explode):
|
||||
result = cli._preprocess_images_with_vision("check this", [img])
|
||||
assert isinstance(result, str)
|
||||
assert str(img) in result # path still included for retry
|
||||
|
||||
|
||||
# ═════════════════════════════════════════════════════════════════════════
|
||||
|
||||
@@ -56,7 +56,6 @@ class TestDelegateRequirements(unittest.TestCase):
|
||||
self.assertIn("tasks", props)
|
||||
self.assertIn("context", props)
|
||||
self.assertIn("toolsets", props)
|
||||
self.assertIn("model", props)
|
||||
self.assertIn("max_iterations", props)
|
||||
self.assertEqual(props["tasks"]["maxItems"], 3)
|
||||
|
||||
|
||||
@@ -259,6 +259,70 @@ class TestShellFileOpsHelpers:
|
||||
assert ops.cwd == "/"
|
||||
|
||||
|
||||
class TestSearchPathValidation:
|
||||
"""Test that search() returns an error for non-existent paths."""
|
||||
|
||||
def test_search_nonexistent_path_returns_error(self, mock_env):
|
||||
"""search() should return an error when the path doesn't exist."""
|
||||
def side_effect(command, **kwargs):
|
||||
if "test -e" in command:
|
||||
return {"output": "not_found", "returncode": 1}
|
||||
if "command -v" in command:
|
||||
return {"output": "yes", "returncode": 0}
|
||||
return {"output": "", "returncode": 0}
|
||||
mock_env.execute.side_effect = side_effect
|
||||
ops = ShellFileOperations(mock_env)
|
||||
result = ops.search("pattern", path="/nonexistent/path")
|
||||
assert result.error is not None
|
||||
assert "not found" in result.error.lower() or "Path not found" in result.error
|
||||
|
||||
def test_search_nonexistent_path_files_mode(self, mock_env):
|
||||
"""search(target='files') should also return error for bad paths."""
|
||||
def side_effect(command, **kwargs):
|
||||
if "test -e" in command:
|
||||
return {"output": "not_found", "returncode": 1}
|
||||
if "command -v" in command:
|
||||
return {"output": "yes", "returncode": 0}
|
||||
return {"output": "", "returncode": 0}
|
||||
mock_env.execute.side_effect = side_effect
|
||||
ops = ShellFileOperations(mock_env)
|
||||
result = ops.search("*.py", path="/nonexistent/path", target="files")
|
||||
assert result.error is not None
|
||||
assert "not found" in result.error.lower() or "Path not found" in result.error
|
||||
|
||||
def test_search_existing_path_proceeds(self, mock_env):
|
||||
"""search() should proceed normally when the path exists."""
|
||||
def side_effect(command, **kwargs):
|
||||
if "test -e" in command:
|
||||
return {"output": "exists", "returncode": 0}
|
||||
if "command -v" in command:
|
||||
return {"output": "yes", "returncode": 0}
|
||||
# rg returns exit 1 (no matches) with empty output
|
||||
return {"output": "", "returncode": 1}
|
||||
mock_env.execute.side_effect = side_effect
|
||||
ops = ShellFileOperations(mock_env)
|
||||
result = ops.search("pattern", path="/existing/path")
|
||||
assert result.error is None
|
||||
assert result.total_count == 0 # No matches but no error
|
||||
|
||||
def test_search_rg_error_exit_code(self, mock_env):
|
||||
"""search() should report error when rg returns exit code 2."""
|
||||
call_count = {"n": 0}
|
||||
def side_effect(command, **kwargs):
|
||||
call_count["n"] += 1
|
||||
if "test -e" in command:
|
||||
return {"output": "exists", "returncode": 0}
|
||||
if "command -v" in command:
|
||||
return {"output": "yes", "returncode": 0}
|
||||
# rg returns exit 2 (error) with empty output
|
||||
return {"output": "", "returncode": 2}
|
||||
mock_env.execute.side_effect = side_effect
|
||||
ops = ShellFileOperations(mock_env)
|
||||
result = ops.search("pattern", path="/some/path")
|
||||
assert result.error is not None
|
||||
assert "search failed" in result.error.lower() or "Search error" in result.error
|
||||
|
||||
|
||||
class TestShellFileOpsWriteDenied:
|
||||
def test_write_file_denied_path(self, file_ops):
|
||||
result = file_ops.write_file("~/.ssh/authorized_keys", "evil key")
|
||||
|
||||
Reference in New Issue
Block a user