471 lines
16 KiB
Python
471 lines
16 KiB
Python
"""Tests for the async event bus (infrastructure.events.bus)."""
|
|
|
|
import sqlite3
|
|
from pathlib import Path
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
|
|
import infrastructure.events.bus as bus_module
|
|
from infrastructure.events.bus import (
|
|
Event,
|
|
EventBus,
|
|
emit,
|
|
event_bus,
|
|
get_event_bus,
|
|
init_event_bus_persistence,
|
|
on,
|
|
)
|
|
|
|
|
|
class TestEvent:
    """Exercise construction of the ``Event`` dataclass."""

    def test_event_defaults(self):
        # Only type and source are supplied; everything else comes from defaults.
        evt = Event(type="test.event", source="unit_test")

        assert (evt.type, evt.source) == ("test.event", "unit_test")
        assert evt.data == {}
        # A timestamp is filled in automatically and must be non-empty.
        assert evt.timestamp
        # Auto-generated ids carry the "evt_" prefix.
        assert evt.id.startswith("evt_")

    def test_event_custom_data(self):
        # Explicit data and id must be kept as given.
        payload = {"key": "val"}
        evt = Event(type="a.b", source="s", data=payload, id="custom-id")

        assert evt.data == {"key": "val"}
        assert evt.id == "custom-id"
|
|
|
|
|
|
class TestEventBus:
    """Subscribe / publish / unsubscribe / history behaviour of ``EventBus``."""

    def _fresh_bus(self) -> EventBus:
        """Return a brand-new bus, isolated from the module-level singleton."""
        return EventBus()

    # ── subscribe + publish ──────────────────────────────────────────────

    async def test_exact_match_subscribe(self):
        bus = self._fresh_bus()
        seen = []

        @bus.subscribe("task.created")
        async def on_created(event: Event):
            seen.append(event)

        delivered = await bus.publish(Event(type="task.created", source="test"))

        assert delivered == 1
        assert [evt.type for evt in seen] == ["task.created"]

    async def test_wildcard_subscribe(self):
        bus = self._fresh_bus()
        seen = []

        @bus.subscribe("agent.*")
        async def on_agent(event: Event):
            seen.append(event)

        # The last type lies outside the "agent.*" namespace and must NOT match.
        for event_type in ("agent.joined", "agent.left", "task.created"):
            await bus.publish(Event(type=event_type, source="test"))

        assert len(seen) == 2

    async def test_star_subscribes_to_all(self):
        bus = self._fresh_bus()
        seen = []

        @bus.subscribe("*")
        async def on_anything(event: Event):
            seen.append(event)

        for event_type in ("anything.here", "x"):
            await bus.publish(Event(type=event_type, source="test"))

        assert len(seen) == 2

    async def test_no_subscribers_returns_zero(self):
        delivered = await self._fresh_bus().publish(
            Event(type="orphan.event", source="test")
        )
        assert delivered == 0

    async def test_multiple_handlers_same_pattern(self):
        bus = self._fresh_bus()
        hits = dict.fromkeys(("a", "b"), 0)

        @bus.subscribe("foo.bar")
        async def first(event):
            hits["a"] += 1

        @bus.subscribe("foo.bar")
        async def second(event):
            hits["b"] += 1

        await bus.publish(Event(type="foo.bar", source="test"))

        # Each handler on the shared pattern fires exactly once.
        assert hits == {"a": 1, "b": 1}

    # ── unsubscribe ──────────────────────────────────────────────────────

    async def test_unsubscribe(self):
        bus = self._fresh_bus()
        seen = []

        @bus.subscribe("x.y")
        async def on_xy(event):
            seen.append(event)

        assert bus.unsubscribe("x.y", on_xy) is True

        await bus.publish(Event(type="x.y", source="test"))
        assert len(seen) == 0

    async def test_unsubscribe_nonexistent_pattern(self):
        bus = self._fresh_bus()

        async def never_registered(event):
            pass

        assert bus.unsubscribe("nope", never_registered) is False

    async def test_unsubscribe_wrong_handler(self):
        bus = self._fresh_bus()

        @bus.subscribe("a.b")
        async def registered(event):
            pass

        async def stranger(event):
            pass

        # The pattern exists, but this particular handler was never attached to it.
        assert bus.unsubscribe("a.b", stranger) is False

    # ── error handling ───────────────────────────────────────────────────

    async def test_handler_error_does_not_break_other_handlers(self):
        bus = self._fresh_bus()
        seen = []

        @bus.subscribe("err.test")
        async def exploding(event):
            raise ValueError("boom")

        @bus.subscribe("err.test")
        async def healthy(event):
            seen.append(event)

        delivered = await bus.publish(Event(type="err.test", source="test"))

        # Both handlers are counted even though the first one raised ...
        assert delivered == 2
        # ... and that failure did not stop the second one from running.
        assert len(seen) == 1

    # ── history ──────────────────────────────────────────────────────────

    async def test_history_stores_events(self):
        bus = self._fresh_bus()
        for event_type in ("h.a", "h.b"):
            await bus.publish(Event(type=event_type, source="s"))

        assert len(bus.get_history()) == 2

    async def test_history_filter_by_type(self):
        bus = self._fresh_bus()
        for event_type in ("h.a", "h.b"):
            await bus.publish(Event(type=event_type, source="s"))

        assert len(bus.get_history(event_type="h.a")) == 1

    async def test_history_filter_by_source(self):
        bus = self._fresh_bus()
        for event_type, origin in (("h.a", "x"), ("h.b", "y")):
            await bus.publish(Event(type=event_type, source=origin))

        assert len(bus.get_history(source="x")) == 1

    async def test_history_limit(self):
        bus = self._fresh_bus()
        for _ in range(5):
            await bus.publish(Event(type="h.x", source="s"))

        assert len(bus.get_history(limit=3)) == 3

    async def test_history_max_cap(self):
        bus = self._fresh_bus()
        # Shrink the cap so that overflowing it is cheap.
        bus._max_history = 10
        for _ in range(15):
            await bus.publish(Event(type="cap", source="s"))

        # The stored history never grows past the configured cap.
        assert len(bus._history) == 10

    async def test_clear_history(self):
        bus = self._fresh_bus()
        await bus.publish(Event(type="x", source="s"))

        bus.clear_history()

        assert len(bus.get_history()) == 0

    # ── pattern matching ─────────────────────────────────────────────────

    def test_match_exact(self):
        bus = self._fresh_bus()
        for event_type, pattern, expected in (
            ("a.b.c", "a.b.c", True),
            ("a.b.c", "a.b.d", False),
        ):
            assert bus._match_pattern(event_type, pattern) is expected

    def test_match_wildcard(self):
        bus = self._fresh_bus()
        for event_type, expected in (
            ("agent.joined", True),
            ("agent.left", True),
            ("task.created", False),
        ):
            assert bus._match_pattern(event_type, "agent.*") is expected

    def test_match_star(self):
        assert self._fresh_bus()._match_pattern("anything", "*") is True
|
|
|
|
|
|
class TestConvenienceFunctions:
    """Test module-level emit() and on() helpers."""

    async def test_emit(self):
        """emit() delivers an Event carrying its data to a handler registered via on().

        This test runs against the shared module-level singleton. It clears the
        singleton's subscribers and history up front, and clears them again in
        ``finally``. A failing assertion therefore cannot leak the
        ``conv.test`` handler (or its events) into other tests.
        """
        # Start from an empty singleton so count reflects only our handler.
        event_bus.clear_history()
        event_bus._subscribers.clear()

        try:
            received = []

            @on("conv.test")
            async def handler(event):
                received.append(event)

            count = await emit("conv.test", "unit", {"foo": "bar"})

            assert count == 1
            # Check the length first so an undelivered event is reported as an
            # assertion failure rather than an IndexError.
            assert len(received) == 1
            assert received[0].data == {"foo": "bar"}
        finally:
            # Always undo our registration on the shared singleton.
            event_bus._subscribers.clear()
            event_bus.clear_history()
|
|
|
|
|
|
# ── Persistence ──────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestEventBusPersistence:
    """Test that EventBus persists events to SQLite."""

    @pytest.fixture
    def persistent_bus(self, tmp_path):
        """Create an EventBus with persistence enabled on a per-test database."""
        db_path = tmp_path / "events.db"
        bus = EventBus()
        bus.enable_persistence(db_path)
        return bus

    @staticmethod
    def _failing_conn_ctx():
        """Build a stand-in for ``EventBus._get_persistence_conn`` that always fails.

        Returns a zero-argument callable producing a context manager. The
        manager yields a mock connection whose ``execute`` raises
        ``sqlite3.OperationalError``. Install it with
        ``patch.object(bus, "_get_persistence_conn", ...)``.
        """
        from contextlib import contextmanager
        from unittest.mock import MagicMock

        mock_conn = MagicMock()
        mock_conn.execute.side_effect = sqlite3.OperationalError("simulated failure")

        @contextmanager
        def fake_ctx():
            yield mock_conn

        return fake_ctx

    async def test_publish_persists_event(self, persistent_bus):
        """Published events should be written to SQLite."""
        await persistent_bus.publish(
            Event(type="task.created", source="test", data={"task_id": "t1"})
        )

        events = persistent_bus.replay(event_type="task.created")
        # The database is fresh for every test, so exactly one row must exist.
        assert len(events) == 1
        assert events[0].type == "task.created"
        assert events[0].data["task_id"] == "t1"

    async def test_replay_returns_persisted_events(self, persistent_bus):
        """Replay should return events from SQLite, not just in-memory history."""
        for i in range(5):
            await persistent_bus.publish(
                Event(type="task.created", source="test", data={"i": i})
            )

        # A fresh bus pointing at the same DB has no in-memory history, so
        # anything it replays must have come from disk.
        bus2 = EventBus()
        bus2.enable_persistence(persistent_bus._persistence_db_path)
        events = bus2.replay(event_type="task.created")
        assert len(events) == 5

    async def test_replay_filters_by_type(self, persistent_bus):
        """Replay should filter by event type."""
        await persistent_bus.publish(Event(type="task.created", source="s"))
        await persistent_bus.publish(Event(type="agent.joined", source="s"))

        tasks = persistent_bus.replay(event_type="task.created")
        agents = persistent_bus.replay(event_type="agent.joined")
        assert len(tasks) == 1
        assert len(agents) == 1

    async def test_replay_filters_by_source(self, persistent_bus):
        """Replay should filter by source."""
        await persistent_bus.publish(Event(type="x", source="alpha"))
        await persistent_bus.publish(Event(type="x", source="beta"))

        alpha_events = persistent_bus.replay(source="alpha")
        assert len(alpha_events) == 1
        assert alpha_events[0].source == "alpha"

    async def test_replay_filters_by_task_id(self, persistent_bus):
        """Replay should filter by task_id in data."""
        await persistent_bus.publish(
            Event(type="task.started", source="s", data={"task_id": "abc"})
        )
        await persistent_bus.publish(
            Event(type="task.started", source="s", data={"task_id": "xyz"})
        )

        events = persistent_bus.replay(task_id="abc")
        assert len(events) == 1
        assert events[0].data["task_id"] == "abc"

    async def test_replay_respects_limit(self, persistent_bus):
        """Replay should respect the limit parameter."""
        for _i in range(10):
            await persistent_bus.publish(Event(type="x", source="s"))

        events = persistent_bus.replay(limit=3)
        assert len(events) == 3

    async def test_persistence_failure_does_not_crash(self, tmp_path):
        """If persistence fails, publish should still work (graceful degradation)."""
        bus = EventBus()
        bus.enable_persistence(tmp_path / "events.db")

        received = []

        @bus.subscribe("test.event")
        async def handler(event):
            received.append(event)

        # A writable tmp_path alone never fails. Swap in a connection whose
        # every statement raises, so the persistence step really errors out.
        with patch.object(bus, "_get_persistence_conn", self._failing_conn_ctx()):
            count = await bus.publish(Event(type="test.event", source="test"))

        assert count == 1
        assert len(received) == 1

    async def test_bus_without_persistence_still_works(self):
        """EventBus should work fine without persistence enabled."""
        bus = EventBus()
        received = []

        @bus.subscribe("x")
        async def handler(event):
            received.append(event)

        await bus.publish(Event(type="x", source="s"))
        assert len(received) == 1

        # replay returns empty when no persistence
        events = bus.replay()
        assert events == []

    async def test_wal_mode_on_persistence_db(self, persistent_bus):
        """Persistence database should use WAL mode."""
        conn = sqlite3.connect(str(persistent_bus._persistence_db_path))
        try:
            mode = conn.execute("PRAGMA journal_mode").fetchone()[0]
            assert mode == "wal"
        finally:
            conn.close()

    async def test_persist_event_exception_is_swallowed(self, tmp_path):
        """_persist_event must not propagate SQLite errors."""
        bus = EventBus()
        bus.enable_persistence(tmp_path / "events.db")

        with patch.object(bus, "_get_persistence_conn", self._failing_conn_ctx()):
            # Should not raise
            bus._persist_event(Event(type="x", source="s"))

    async def test_replay_exception_returns_empty(self, tmp_path):
        """replay() must return [] when SQLite query fails."""
        bus = EventBus()
        bus.enable_persistence(tmp_path / "events.db")

        with patch.object(bus, "_get_persistence_conn", self._failing_conn_ctx()):
            result = bus.replay()
        assert result == []
|
|
|
|
|
|
# ── Singleton helpers ─────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestSingletonHelpers:
    """Test get_event_bus(), init_event_bus_persistence(), and module __getattr__."""

    @pytest.fixture
    def fresh_singleton(self, monkeypatch):
        """Return the singleton bus with ``_persistence_db_path`` reset to ``None``.

        ``monkeypatch`` puts the original ``_persistence_db_path`` back at
        teardown, whatever the test (or ``init_event_bus_persistence``)
        assigned in between. The shared singleton is therefore not left
        pointing at a temporary database. Only that one attribute is
        restored.
        """
        bus = get_event_bus()
        monkeypatch.setattr(bus, "_persistence_db_path", None)
        return bus

    def test_get_event_bus_returns_same_instance(self):
        """get_event_bus() is a true singleton."""
        a = get_event_bus()
        b = get_event_bus()
        assert a is b

    def test_module_event_bus_attr_is_singleton(self):
        """Accessing bus_module.event_bus via __getattr__ returns the singleton."""
        assert bus_module.event_bus is get_event_bus()

    def test_module_getattr_unknown_raises(self):
        """Accessing an unknown module attribute raises AttributeError."""
        with pytest.raises(AttributeError):
            _ = bus_module.no_such_attr  # type: ignore[attr-defined]

    def test_init_event_bus_persistence_sets_path(self, tmp_path, fresh_singleton):
        """init_event_bus_persistence() enables persistence on the singleton."""
        db_path = tmp_path / "test_init.db"
        init_event_bus_persistence(db_path)
        assert fresh_singleton._persistence_db_path == db_path

    def test_init_event_bus_persistence_is_idempotent(self, tmp_path, fresh_singleton):
        """Calling init_event_bus_persistence() twice keeps the first path."""
        first_path = tmp_path / "first.db"
        second_path = tmp_path / "second.db"
        init_event_bus_persistence(first_path)
        init_event_bus_persistence(second_path)  # should be ignored
        assert fresh_singleton._persistence_db_path == first_path

    def test_init_event_bus_persistence_default_path(self, fresh_singleton):
        """init_event_bus_persistence() uses 'data/events.db' when no path given."""
        # Patch enable_persistence to capture what path it receives without
        # touching the real filesystem.
        captured = {}

        def fake_enable(path: Path) -> None:
            captured["path"] = path

        with patch.object(fresh_singleton, "enable_persistence", side_effect=fake_enable):
            init_event_bus_persistence()

        # Fail with a clear message rather than a KeyError if it was never called.
        assert "path" in captured, "enable_persistence() was not called"
        assert captured["path"] == Path("data/events.db")
|