"""Tests for the async event bus (infrastructure.events.bus)."""

import copy
import sqlite3

import pytest

from infrastructure.events.bus import Event, EventBus, emit, event_bus, on


class TestEvent:
    """Test Event dataclass."""

    def test_event_defaults(self):
        e = Event(type="test.event", source="unit_test")
        assert e.type == "test.event"
        assert e.source == "unit_test"
        assert e.data == {}
        assert e.timestamp  # auto-generated
        assert e.id.startswith("evt_")

    def test_event_custom_data(self):
        e = Event(type="a.b", source="s", data={"key": "val"}, id="custom-id")
        assert e.data == {"key": "val"}
        assert e.id == "custom-id"


class TestEventBus:
    """Test EventBus subscribe/publish/history."""

    def _fresh_bus(self) -> EventBus:
        return EventBus()

    # ── subscribe + publish ──────────────────────────────────────────────

    async def test_exact_match_subscribe(self):
        bus = self._fresh_bus()
        received = []

        @bus.subscribe("task.created")
        async def handler(event: Event):
            received.append(event)

        count = await bus.publish(Event(type="task.created", source="test"))
        assert count == 1
        assert len(received) == 1
        assert received[0].type == "task.created"

    async def test_wildcard_subscribe(self):
        bus = self._fresh_bus()
        received = []

        @bus.subscribe("agent.*")
        async def handler(event: Event):
            received.append(event)

        await bus.publish(Event(type="agent.joined", source="test"))
        await bus.publish(Event(type="agent.left", source="test"))
        await bus.publish(Event(type="task.created", source="test"))  # should NOT match
        assert len(received) == 2

    async def test_star_subscribes_to_all(self):
        bus = self._fresh_bus()
        received = []

        @bus.subscribe("*")
        async def handler(event: Event):
            received.append(event)

        await bus.publish(Event(type="anything.here", source="test"))
        await bus.publish(Event(type="x", source="test"))
        assert len(received) == 2

    async def test_no_subscribers_returns_zero(self):
        bus = self._fresh_bus()
        count = await bus.publish(Event(type="orphan.event", source="test"))
        assert count == 0

    async def test_multiple_handlers_same_pattern(self):
        bus = self._fresh_bus()
        calls = {"a": 0, "b": 0}

        @bus.subscribe("foo.bar")
        async def handler_a(event):
            calls["a"] += 1

        @bus.subscribe("foo.bar")
        async def handler_b(event):
            calls["b"] += 1

        await bus.publish(Event(type="foo.bar", source="test"))
        assert calls["a"] == 1
        assert calls["b"] == 1

    # ── unsubscribe ──────────────────────────────────────────────────────

    async def test_unsubscribe(self):
        bus = self._fresh_bus()
        received = []

        @bus.subscribe("x.y")
        async def handler(event):
            received.append(event)

        ok = bus.unsubscribe("x.y", handler)
        assert ok is True

        await bus.publish(Event(type="x.y", source="test"))
        assert len(received) == 0

    async def test_unsubscribe_nonexistent_pattern(self):
        bus = self._fresh_bus()

        async def dummy(event):
            pass

        assert bus.unsubscribe("nope", dummy) is False

    async def test_unsubscribe_wrong_handler(self):
        bus = self._fresh_bus()

        @bus.subscribe("a.b")
        async def handler_a(event):
            pass

        async def handler_b(event):
            pass

        assert bus.unsubscribe("a.b", handler_b) is False

    # ── error handling ───────────────────────────────────────────────────

    async def test_handler_error_does_not_break_other_handlers(self):
        bus = self._fresh_bus()
        received = []

        @bus.subscribe("err.test")
        async def bad_handler(event):
            raise ValueError("boom")

        @bus.subscribe("err.test")
        async def good_handler(event):
            received.append(event)

        count = await bus.publish(Event(type="err.test", source="test"))
        assert count == 2  # both were invoked
        assert len(received) == 1  # good_handler still ran

    # ── history ──────────────────────────────────────────────────────────

    async def test_history_stores_events(self):
        bus = self._fresh_bus()
        await bus.publish(Event(type="h.a", source="s"))
        await bus.publish(Event(type="h.b", source="s"))

        history = bus.get_history()
        assert len(history) == 2

    async def test_history_filter_by_type(self):
        bus = self._fresh_bus()
        await bus.publish(Event(type="h.a", source="s"))
        await bus.publish(Event(type="h.b", source="s"))

        assert len(bus.get_history(event_type="h.a")) == 1

    async def test_history_filter_by_source(self):
        bus = self._fresh_bus()
        await bus.publish(Event(type="h.a", source="x"))
        await bus.publish(Event(type="h.b", source="y"))

        assert len(bus.get_history(source="x")) == 1

    async def test_history_limit(self):
        bus = self._fresh_bus()
        for _i in range(5):
            await bus.publish(Event(type="h.x", source="s"))

        assert len(bus.get_history(limit=3)) == 3

    async def test_history_max_cap(self):
        bus = self._fresh_bus()
        bus._max_history = 10
        for _i in range(15):
            await bus.publish(Event(type="cap", source="s"))

        assert len(bus._history) == 10

    async def test_clear_history(self):
        bus = self._fresh_bus()
        await bus.publish(Event(type="x", source="s"))
        bus.clear_history()
        assert len(bus.get_history()) == 0

    # ── pattern matching ─────────────────────────────────────────────────

    def test_match_exact(self):
        bus = self._fresh_bus()
        assert bus._match_pattern("a.b.c", "a.b.c") is True
        assert bus._match_pattern("a.b.c", "a.b.d") is False

    def test_match_wildcard(self):
        bus = self._fresh_bus()
        assert bus._match_pattern("agent.joined", "agent.*") is True
        assert bus._match_pattern("agent.left", "agent.*") is True
        assert bus._match_pattern("task.created", "agent.*") is False

    def test_match_star(self):
        bus = self._fresh_bus()
        assert bus._match_pattern("anything", "*") is True


class TestConvenienceFunctions:
    """Test module-level emit() and on() helpers."""

    async def test_emit(self, monkeypatch):
        # Swap the singleton's subscriber registry for an empty copy instead of
        # clearing it in place: handlers registered elsewhere (e.g. at import
        # time) neither fire here nor get wiped out, and monkeypatch restores
        # the original registry even if an assertion below fails.
        # copy.copy keeps the container type (e.g. a defaultdict's factory).
        isolated = copy.copy(event_bus._subscribers)
        isolated.clear()
        monkeypatch.setattr(event_bus, "_subscribers", isolated)
        event_bus.clear_history()

        try:
            received = []

            @on("conv.test")
            async def handler(event):
                received.append(event)

            count = await emit("conv.test", "unit", {"foo": "bar"})
            assert count == 1
            assert received[0].data == {"foo": "bar"}
        finally:
            # Always drop the events this test put into the shared singleton.
            event_bus.clear_history()


# ── Persistence ──────────────────────────────────────────────────────────


class TestEventBusPersistence:
    """Test that EventBus persists events to SQLite."""

    @pytest.fixture
    def persistent_bus(self, tmp_path):
        """Create an EventBus with persistence enabled."""
        db_path = tmp_path / "events.db"
        bus = EventBus()
        bus.enable_persistence(db_path)
        return bus

    async def test_publish_persists_event(self, persistent_bus):
        """Published events should be written to SQLite."""
        await persistent_bus.publish(
            Event(type="task.created", source="test", data={"task_id": "t1"})
        )

        events = persistent_bus.replay(event_type="task.created")
        # The fixture DB is fresh per test, so exactly one row must come back;
        # ">= 1" would hide an event being persisted more than once.
        assert len(events) == 1
        assert events[0].type == "task.created"
        assert events[0].data["task_id"] == "t1"

    async def test_replay_returns_persisted_events(self, persistent_bus):
        """Replay should return events from SQLite, not just in-memory history."""
        for i in range(5):
            await persistent_bus.publish(Event(type="task.created", source="test", data={"i": i}))

        # Create a fresh bus pointing at the same DB to prove persistence
        bus2 = EventBus()
        bus2.enable_persistence(persistent_bus._persistence_db_path)
        events = bus2.replay(event_type="task.created")
        assert len(events) == 5

    async def test_replay_filters_by_type(self, persistent_bus):
        """Replay should filter by event type."""
        await persistent_bus.publish(Event(type="task.created", source="s"))
        await persistent_bus.publish(Event(type="agent.joined", source="s"))

        tasks = persistent_bus.replay(event_type="task.created")
        agents = persistent_bus.replay(event_type="agent.joined")
        assert len(tasks) == 1
        assert len(agents) == 1

    async def test_replay_filters_by_source(self, persistent_bus):
        """Replay should filter by source."""
        await persistent_bus.publish(Event(type="x", source="alpha"))
        await persistent_bus.publish(Event(type="x", source="beta"))

        alpha_events = persistent_bus.replay(source="alpha")
        assert len(alpha_events) == 1
        assert alpha_events[0].source == "alpha"

    async def test_replay_filters_by_task_id(self, persistent_bus):
        """Replay should filter by task_id in data."""
        await persistent_bus.publish(
            Event(type="task.started", source="s", data={"task_id": "abc"})
        )
        await persistent_bus.publish(
            Event(type="task.started", source="s", data={"task_id": "xyz"})
        )

        events = persistent_bus.replay(task_id="abc")
        assert len(events) == 1
        assert events[0].data["task_id"] == "abc"

    async def test_replay_respects_limit(self, persistent_bus):
        """Replay should respect the limit parameter."""
        for _i in range(10):
            await persistent_bus.publish(Event(type="x", source="s"))

        events = persistent_bus.replay(limit=3)
        assert len(events) == 3

    async def test_persistence_failure_does_not_crash(self, tmp_path):
        """If persistence fails, publish should still work (graceful degradation)."""
        bus = EventBus()
        bus.enable_persistence(tmp_path / "events.db")

        # Actually break persistence: point the bus at a directory, which
        # SQLite cannot open as a database file, so persisting the event
        # fails. NOTE(review): this assumes publish() opens
        # `_persistence_db_path` when it writes rather than reusing a
        # connection created in enable_persistence(); confirm against bus.py.
        unusable = tmp_path / "not_a_database"
        unusable.mkdir()
        bus._persistence_db_path = unusable

        received = []

        @bus.subscribe("test.event")
        async def handler(event):
            received.append(event)

        # Should not raise even though persistence fails
        count = await bus.publish(Event(type="test.event", source="test"))
        assert count == 1
        assert len(received) == 1

    async def test_bus_without_persistence_still_works(self):
        """EventBus should work fine without persistence enabled."""
        bus = EventBus()
        received = []

        @bus.subscribe("x")
        async def handler(event):
            received.append(event)

        await bus.publish(Event(type="x", source="s"))
        assert len(received) == 1

        # replay returns empty when no persistence
        events = bus.replay()
        assert events == []

    async def test_wal_mode_on_persistence_db(self, persistent_bus):
        """Persistence database should use WAL mode."""
        conn = sqlite3.connect(str(persistent_bus._persistence_db_path))
        try:
            mode = conn.execute("PRAGMA journal_mode").fetchone()[0]
            assert mode == "wal"
        finally:
            conn.close()