"""Unit tests for brain.worker.DistributedWorker.""" from __future__ import annotations import threading from unittest.mock import MagicMock, patch import pytest from brain.worker import MAX_RETRIES, DelegatedTask, DistributedWorker @pytest.fixture(autouse=True) def clear_task_registry(): """Reset the worker registry before each test.""" DistributedWorker.clear() yield DistributedWorker.clear() class TestSubmit: def test_returns_task_id(self): with patch.object(DistributedWorker, "_run_task"): task_id = DistributedWorker.submit("researcher", "research", "find something") assert isinstance(task_id, str) assert len(task_id) == 8 def test_task_registered_as_queued(self): with patch.object(DistributedWorker, "_run_task"): task_id = DistributedWorker.submit("coder", "code", "fix the bug") status = DistributedWorker.get_status(task_id) assert status["found"] is True assert status["task_id"] == task_id assert status["agent"] == "coder" def test_unique_task_ids(self): with patch.object(DistributedWorker, "_run_task"): ids = [DistributedWorker.submit("coder", "code", "task") for _ in range(10)] assert len(set(ids)) == 10 def test_starts_daemon_thread(self): event = threading.Event() def fake_run_task(record): event.set() with patch.object(DistributedWorker, "_run_task", side_effect=fake_run_task): DistributedWorker.submit("coder", "code", "something") assert event.wait(timeout=2), "Background thread did not start" def test_priority_stored(self): with patch.object(DistributedWorker, "_run_task"): task_id = DistributedWorker.submit("coder", "code", "task", priority="high") status = DistributedWorker.get_status(task_id) assert status["priority"] == "high" class TestGetStatus: def test_unknown_task_id(self): result = DistributedWorker.get_status("deadbeef") assert result["found"] is False assert result["task_id"] == "deadbeef" def test_known_task_has_all_fields(self): with patch.object(DistributedWorker, "_run_task"): task_id = DistributedWorker.submit("writer", "writing", "write a blog post") status = DistributedWorker.get_status(task_id) for key in ("found", "task_id", "agent", "role", "status", "backend", "created_at"): assert key in status, f"Missing key: {key}" class TestListTasks: def test_empty_initially(self): assert DistributedWorker.list_tasks() == [] def test_returns_registered_tasks(self): with patch.object(DistributedWorker, "_run_task"): DistributedWorker.submit("coder", "code", "task A") DistributedWorker.submit("writer", "writing", "task B") tasks = DistributedWorker.list_tasks() assert len(tasks) == 2 agents = {t["agent"] for t in tasks} assert agents == {"coder", "writer"} class TestSelectBackend: def test_defaults_to_agentic_loop(self): with patch("brain.worker.logger"): backend = DistributedWorker._select_backend("code", "fix the bug") assert backend == "agentic_loop" def test_kimi_for_heavy_research_with_gitea(self): mock_settings = MagicMock() mock_settings.gitea_enabled = True mock_settings.gitea_token = "tok" mock_settings.paperclip_api_key = "" with ( patch("timmy.kimi_delegation.exceeds_local_capacity", return_value=True), patch("config.settings", mock_settings), ): backend = DistributedWorker._select_backend("research", "comprehensive survey " * 10) assert backend == "kimi" def test_agentic_loop_when_no_gitea(self): mock_settings = MagicMock() mock_settings.gitea_enabled = False mock_settings.gitea_token = "" mock_settings.paperclip_api_key = "" with patch("config.settings", mock_settings): backend = DistributedWorker._select_backend("research", "comprehensive survey " * 10) assert backend == "agentic_loop" def test_paperclip_when_api_key_configured(self): mock_settings = MagicMock() mock_settings.gitea_enabled = False mock_settings.gitea_token = "" mock_settings.paperclip_api_key = "pk_test_123" with patch("config.settings", mock_settings): backend = DistributedWorker._select_backend("code", "build a widget") assert backend == "paperclip" class TestRunTask: def test_marks_completed_on_success(self): record = DelegatedTask( task_id="abc12345", agent_name="coder", agent_role="code", task_description="fix bug", priority="normal", backend="agentic_loop", ) with patch.object(DistributedWorker, "_dispatch", return_value={"success": True}): DistributedWorker._run_task(record) assert record.status == "completed" assert record.result == {"success": True} assert record.error is None def test_marks_failed_after_exhausting_retries(self): record = DelegatedTask( task_id="fail1234", agent_name="coder", agent_role="code", task_description="broken task", priority="normal", backend="agentic_loop", ) with patch.object(DistributedWorker, "_dispatch", side_effect=RuntimeError("boom")): DistributedWorker._run_task(record) assert record.status == "failed" assert "boom" in record.error assert record.retries == MAX_RETRIES def test_retries_before_failing(self): record = DelegatedTask( task_id="retry001", agent_name="coder", agent_role="code", task_description="flaky task", priority="normal", backend="agentic_loop", ) call_count = 0 def flaky_dispatch(r): nonlocal call_count call_count += 1 if call_count < MAX_RETRIES + 1: raise RuntimeError("transient failure") return {"success": True} with patch.object(DistributedWorker, "_dispatch", side_effect=flaky_dispatch): DistributedWorker._run_task(record) assert record.status == "completed" assert call_count == MAX_RETRIES + 1 def test_succeeds_on_first_attempt(self): record = DelegatedTask( task_id="ok000001", agent_name="writer", agent_role="writing", task_description="write summary", priority="low", backend="agentic_loop", ) with patch.object(DistributedWorker, "_dispatch", return_value={"summary": "done"}): DistributedWorker._run_task(record) assert record.status == "completed" assert record.retries == 0 class TestDelegatetaskIntegration: """Integration: delegate_task should wire to DistributedWorker.""" def test_delegate_task_returns_task_id(self): from timmy.tools_delegation import delegate_task with patch.object(DistributedWorker, "_run_task"): result = delegate_task("researcher", "research something for me") assert result["success"] is True assert result["task_id"] is not None assert result["status"] == "queued" def test_delegate_task_status_queued_for_valid_agent(self): from timmy.tools_delegation import delegate_task with patch.object(DistributedWorker, "_run_task"): result = delegate_task("coder", "implement feature X") assert result["status"] == "queued" assert len(result["task_id"]) == 8 def test_task_in_registry_after_delegation(self): from timmy.tools_delegation import delegate_task with patch.object(DistributedWorker, "_run_task"): result = delegate_task("writer", "write documentation") task_id = result["task_id"] status = DistributedWorker.get_status(task_id) assert status["found"] is True assert status["agent"] == "writer"