Co-authored-by: Claude (Opus 4.6) <claude@hermes.local> Co-committed-by: Claude (Opus 4.6) <claude@hermes.local>
This commit was merged in pull request #1273.
This commit is contained in:
235
tests/unit/test_brain_worker.py
Normal file
235
tests/unit/test_brain_worker.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""Unit tests for brain.worker.DistributedWorker."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from brain.worker import MAX_RETRIES, DelegatedTask, DistributedWorker
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_task_registry():
|
||||
"""Reset the worker registry before each test."""
|
||||
DistributedWorker.clear()
|
||||
yield
|
||||
DistributedWorker.clear()
|
||||
|
||||
|
||||
class TestSubmit:
|
||||
def test_returns_task_id(self):
|
||||
with patch.object(DistributedWorker, "_run_task"):
|
||||
task_id = DistributedWorker.submit("researcher", "research", "find something")
|
||||
assert isinstance(task_id, str)
|
||||
assert len(task_id) == 8
|
||||
|
||||
def test_task_registered_as_queued(self):
|
||||
with patch.object(DistributedWorker, "_run_task"):
|
||||
task_id = DistributedWorker.submit("coder", "code", "fix the bug")
|
||||
status = DistributedWorker.get_status(task_id)
|
||||
assert status["found"] is True
|
||||
assert status["task_id"] == task_id
|
||||
assert status["agent"] == "coder"
|
||||
|
||||
def test_unique_task_ids(self):
|
||||
with patch.object(DistributedWorker, "_run_task"):
|
||||
ids = [DistributedWorker.submit("coder", "code", "task") for _ in range(10)]
|
||||
assert len(set(ids)) == 10
|
||||
|
||||
def test_starts_daemon_thread(self):
|
||||
event = threading.Event()
|
||||
|
||||
def fake_run_task(record):
|
||||
event.set()
|
||||
|
||||
with patch.object(DistributedWorker, "_run_task", side_effect=fake_run_task):
|
||||
DistributedWorker.submit("coder", "code", "something")
|
||||
|
||||
assert event.wait(timeout=2), "Background thread did not start"
|
||||
|
||||
def test_priority_stored(self):
|
||||
with patch.object(DistributedWorker, "_run_task"):
|
||||
task_id = DistributedWorker.submit("coder", "code", "task", priority="high")
|
||||
status = DistributedWorker.get_status(task_id)
|
||||
assert status["priority"] == "high"
|
||||
|
||||
|
||||
class TestGetStatus:
|
||||
def test_unknown_task_id(self):
|
||||
result = DistributedWorker.get_status("deadbeef")
|
||||
assert result["found"] is False
|
||||
assert result["task_id"] == "deadbeef"
|
||||
|
||||
def test_known_task_has_all_fields(self):
|
||||
with patch.object(DistributedWorker, "_run_task"):
|
||||
task_id = DistributedWorker.submit("writer", "writing", "write a blog post")
|
||||
status = DistributedWorker.get_status(task_id)
|
||||
for key in ("found", "task_id", "agent", "role", "status", "backend", "created_at"):
|
||||
assert key in status, f"Missing key: {key}"
|
||||
|
||||
|
||||
class TestListTasks:
|
||||
def test_empty_initially(self):
|
||||
assert DistributedWorker.list_tasks() == []
|
||||
|
||||
def test_returns_registered_tasks(self):
|
||||
with patch.object(DistributedWorker, "_run_task"):
|
||||
DistributedWorker.submit("coder", "code", "task A")
|
||||
DistributedWorker.submit("writer", "writing", "task B")
|
||||
tasks = DistributedWorker.list_tasks()
|
||||
assert len(tasks) == 2
|
||||
agents = {t["agent"] for t in tasks}
|
||||
assert agents == {"coder", "writer"}
|
||||
|
||||
|
||||
class TestSelectBackend:
|
||||
def test_defaults_to_agentic_loop(self):
|
||||
with patch("brain.worker.logger"):
|
||||
backend = DistributedWorker._select_backend("code", "fix the bug")
|
||||
assert backend == "agentic_loop"
|
||||
|
||||
def test_kimi_for_heavy_research_with_gitea(self):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.gitea_enabled = True
|
||||
mock_settings.gitea_token = "tok"
|
||||
mock_settings.paperclip_api_key = ""
|
||||
|
||||
with (
|
||||
patch("timmy.kimi_delegation.exceeds_local_capacity", return_value=True),
|
||||
patch("config.settings", mock_settings),
|
||||
):
|
||||
backend = DistributedWorker._select_backend("research", "comprehensive survey " * 10)
|
||||
assert backend == "kimi"
|
||||
|
||||
def test_agentic_loop_when_no_gitea(self):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.gitea_enabled = False
|
||||
mock_settings.gitea_token = ""
|
||||
mock_settings.paperclip_api_key = ""
|
||||
|
||||
with patch("config.settings", mock_settings):
|
||||
backend = DistributedWorker._select_backend("research", "comprehensive survey " * 10)
|
||||
assert backend == "agentic_loop"
|
||||
|
||||
def test_paperclip_when_api_key_configured(self):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.gitea_enabled = False
|
||||
mock_settings.gitea_token = ""
|
||||
mock_settings.paperclip_api_key = "pk_test_123"
|
||||
|
||||
with patch("config.settings", mock_settings):
|
||||
backend = DistributedWorker._select_backend("code", "build a widget")
|
||||
assert backend == "paperclip"
|
||||
|
||||
|
||||
class TestRunTask:
|
||||
def test_marks_completed_on_success(self):
|
||||
record = DelegatedTask(
|
||||
task_id="abc12345",
|
||||
agent_name="coder",
|
||||
agent_role="code",
|
||||
task_description="fix bug",
|
||||
priority="normal",
|
||||
backend="agentic_loop",
|
||||
)
|
||||
|
||||
with patch.object(DistributedWorker, "_dispatch", return_value={"success": True}):
|
||||
DistributedWorker._run_task(record)
|
||||
|
||||
assert record.status == "completed"
|
||||
assert record.result == {"success": True}
|
||||
assert record.error is None
|
||||
|
||||
def test_marks_failed_after_exhausting_retries(self):
|
||||
record = DelegatedTask(
|
||||
task_id="fail1234",
|
||||
agent_name="coder",
|
||||
agent_role="code",
|
||||
task_description="broken task",
|
||||
priority="normal",
|
||||
backend="agentic_loop",
|
||||
)
|
||||
|
||||
with patch.object(DistributedWorker, "_dispatch", side_effect=RuntimeError("boom")):
|
||||
DistributedWorker._run_task(record)
|
||||
|
||||
assert record.status == "failed"
|
||||
assert "boom" in record.error
|
||||
assert record.retries == MAX_RETRIES
|
||||
|
||||
def test_retries_before_failing(self):
|
||||
record = DelegatedTask(
|
||||
task_id="retry001",
|
||||
agent_name="coder",
|
||||
agent_role="code",
|
||||
task_description="flaky task",
|
||||
priority="normal",
|
||||
backend="agentic_loop",
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
|
||||
def flaky_dispatch(r):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < MAX_RETRIES + 1:
|
||||
raise RuntimeError("transient failure")
|
||||
return {"success": True}
|
||||
|
||||
with patch.object(DistributedWorker, "_dispatch", side_effect=flaky_dispatch):
|
||||
DistributedWorker._run_task(record)
|
||||
|
||||
assert record.status == "completed"
|
||||
assert call_count == MAX_RETRIES + 1
|
||||
|
||||
def test_succeeds_on_first_attempt(self):
|
||||
record = DelegatedTask(
|
||||
task_id="ok000001",
|
||||
agent_name="writer",
|
||||
agent_role="writing",
|
||||
task_description="write summary",
|
||||
priority="low",
|
||||
backend="agentic_loop",
|
||||
)
|
||||
|
||||
with patch.object(DistributedWorker, "_dispatch", return_value={"summary": "done"}):
|
||||
DistributedWorker._run_task(record)
|
||||
|
||||
assert record.status == "completed"
|
||||
assert record.retries == 0
|
||||
|
||||
|
||||
class TestDelegatetaskIntegration:
|
||||
"""Integration: delegate_task should wire to DistributedWorker."""
|
||||
|
||||
def test_delegate_task_returns_task_id(self):
|
||||
from timmy.tools_delegation import delegate_task
|
||||
|
||||
with patch.object(DistributedWorker, "_run_task"):
|
||||
result = delegate_task("researcher", "research something for me")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["task_id"] is not None
|
||||
assert result["status"] == "queued"
|
||||
|
||||
def test_delegate_task_status_queued_for_valid_agent(self):
|
||||
from timmy.tools_delegation import delegate_task
|
||||
|
||||
with patch.object(DistributedWorker, "_run_task"):
|
||||
result = delegate_task("coder", "implement feature X")
|
||||
|
||||
assert result["status"] == "queued"
|
||||
assert len(result["task_id"]) == 8
|
||||
|
||||
def test_task_in_registry_after_delegation(self):
|
||||
from timmy.tools_delegation import delegate_task
|
||||
|
||||
with patch.object(DistributedWorker, "_run_task"):
|
||||
result = delegate_task("writer", "write documentation")
|
||||
|
||||
task_id = result["task_id"]
|
||||
status = DistributedWorker.get_status(task_id)
|
||||
assert status["found"] is True
|
||||
assert status["agent"] == "writer"
|
||||
Reference in New Issue
Block a user