diff --git a/src/infrastructure/protocol.py b/src/infrastructure/protocol.py new file mode 100644 index 00000000..dfabd3fe --- /dev/null +++ b/src/infrastructure/protocol.py @@ -0,0 +1,261 @@ +"""Shared WebSocket message protocol for the Matrix frontend. + +Defines all WebSocket message types as an enum and typed dataclasses +with ``to_json()`` / ``from_json()`` helpers so every producer and the +gateway speak the same language. + +Message wire format +------------------- +.. code-block:: json + + {"type": "agent_state", "agent_id": "timmy", "data": {...}, "ts": 1234567890} +""" + +import json +import logging +import time +from dataclasses import asdict, dataclass, field +from enum import StrEnum +from typing import Any + +logger = logging.getLogger(__name__) + + +class MessageType(StrEnum): + """All WebSocket message types defined by the Matrix PROTOCOL.md.""" + + AGENT_STATE = "agent_state" + VISITOR_STATE = "visitor_state" + BARK = "bark" + THOUGHT = "thought" + SYSTEM_STATUS = "system_status" + CONNECTION_ACK = "connection_ack" + ERROR = "error" + TASK_UPDATE = "task_update" + MEMORY_FLASH = "memory_flash" + + +# --------------------------------------------------------------------------- +# Base message +# --------------------------------------------------------------------------- + + +@dataclass +class WSMessage: + """Base WebSocket message with common envelope fields.""" + + type: str + ts: float = field(default_factory=time.time) + + def to_json(self) -> str: + """Serialise the message to a JSON string.""" + return json.dumps(asdict(self)) + + @classmethod + def from_json(cls, raw: str) -> "WSMessage": + """Deserialise a JSON string into the correct message subclass. + + Falls back to the base ``WSMessage`` when the ``type`` field is + unrecognised. + """ + data = json.loads(raw) + msg_type = data.get("type") + sub = _REGISTRY.get(msg_type) + if sub is not None: + return sub.from_json(raw) + return cls(**data) + + +# --------------------------------------------------------------------------- +# Concrete message types +# --------------------------------------------------------------------------- + + +@dataclass +class AgentStateMessage(WSMessage): + """State update for a single agent.""" + + type: str = field(default=MessageType.AGENT_STATE) + agent_id: str = "" + data: dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_json(cls, raw: str) -> "AgentStateMessage": + payload = json.loads(raw) + return cls( + type=payload.get("type", MessageType.AGENT_STATE), + ts=payload.get("ts", time.time()), + agent_id=payload.get("agent_id", ""), + data=payload.get("data", {}), + ) + + +@dataclass +class VisitorStateMessage(WSMessage): + """State update for a visitor / user session.""" + + type: str = field(default=MessageType.VISITOR_STATE) + visitor_id: str = "" + data: dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_json(cls, raw: str) -> "VisitorStateMessage": + payload = json.loads(raw) + return cls( + type=payload.get("type", MessageType.VISITOR_STATE), + ts=payload.get("ts", time.time()), + visitor_id=payload.get("visitor_id", ""), + data=payload.get("data", {}), + ) + + +@dataclass +class BarkMessage(WSMessage): + """A bark (chat-like utterance) from an agent.""" + + type: str = field(default=MessageType.BARK) + agent_id: str = "" + content: str = "" + + @classmethod + def from_json(cls, raw: str) -> "BarkMessage": + payload = json.loads(raw) + return cls( + type=payload.get("type", MessageType.BARK), + ts=payload.get("ts", time.time()), + agent_id=payload.get("agent_id", ""), + content=payload.get("content", ""), + ) + + +@dataclass +class ThoughtMessage(WSMessage): + """An inner thought from an agent.""" + + type: str = field(default=MessageType.THOUGHT) + agent_id: str = "" + content: str = "" + + @classmethod + def from_json(cls, raw: str) -> "ThoughtMessage": + payload = json.loads(raw) + return cls( + type=payload.get("type", MessageType.THOUGHT), + ts=payload.get("ts", time.time()), + agent_id=payload.get("agent_id", ""), + content=payload.get("content", ""), + ) + + +@dataclass +class SystemStatusMessage(WSMessage): + """System-wide status broadcast.""" + + type: str = field(default=MessageType.SYSTEM_STATUS) + status: str = "" + data: dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_json(cls, raw: str) -> "SystemStatusMessage": + payload = json.loads(raw) + return cls( + type=payload.get("type", MessageType.SYSTEM_STATUS), + ts=payload.get("ts", time.time()), + status=payload.get("status", ""), + data=payload.get("data", {}), + ) + + +@dataclass +class ConnectionAckMessage(WSMessage): + """Acknowledgement sent when a client connects.""" + + type: str = field(default=MessageType.CONNECTION_ACK) + client_id: str = "" + + @classmethod + def from_json(cls, raw: str) -> "ConnectionAckMessage": + payload = json.loads(raw) + return cls( + type=payload.get("type", MessageType.CONNECTION_ACK), + ts=payload.get("ts", time.time()), + client_id=payload.get("client_id", ""), + ) + + +@dataclass +class ErrorMessage(WSMessage): + """Error message sent to a client.""" + + type: str = field(default=MessageType.ERROR) + code: str = "" + message: str = "" + + @classmethod + def from_json(cls, raw: str) -> "ErrorMessage": + payload = json.loads(raw) + return cls( + type=payload.get("type", MessageType.ERROR), + ts=payload.get("ts", time.time()), + code=payload.get("code", ""), + message=payload.get("message", ""), + ) + + +@dataclass +class TaskUpdateMessage(WSMessage): + """Update about a task (created, assigned, completed, etc.).""" + + type: str = field(default=MessageType.TASK_UPDATE) + task_id: str = "" + status: str = "" + data: dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_json(cls, raw: str) -> "TaskUpdateMessage": + payload = json.loads(raw) + return cls( + type=payload.get("type", MessageType.TASK_UPDATE), + ts=payload.get("ts", time.time()), + task_id=payload.get("task_id", ""), + status=payload.get("status", ""), + data=payload.get("data", {}), + ) + + +@dataclass +class MemoryFlashMessage(WSMessage): + """A flash of memory — a recalled or stored memory event.""" + + type: str = field(default=MessageType.MEMORY_FLASH) + agent_id: str = "" + memory_key: str = "" + content: str = "" + + @classmethod + def from_json(cls, raw: str) -> "MemoryFlashMessage": + payload = json.loads(raw) + return cls( + type=payload.get("type", MessageType.MEMORY_FLASH), + ts=payload.get("ts", time.time()), + agent_id=payload.get("agent_id", ""), + memory_key=payload.get("memory_key", ""), + content=payload.get("content", ""), + ) + + +# --------------------------------------------------------------------------- +# Registry for from_json dispatch +# --------------------------------------------------------------------------- + +_REGISTRY: dict[str, type[WSMessage]] = { + MessageType.AGENT_STATE: AgentStateMessage, + MessageType.VISITOR_STATE: VisitorStateMessage, + MessageType.BARK: BarkMessage, + MessageType.THOUGHT: ThoughtMessage, + MessageType.SYSTEM_STATUS: SystemStatusMessage, + MessageType.CONNECTION_ACK: ConnectionAckMessage, + MessageType.ERROR: ErrorMessage, + MessageType.TASK_UPDATE: TaskUpdateMessage, + MessageType.MEMORY_FLASH: MemoryFlashMessage, +} diff --git a/tests/unit/test_protocol.py b/tests/unit/test_protocol.py new file mode 100644 index 00000000..7a7526c3 --- /dev/null +++ b/tests/unit/test_protocol.py @@ -0,0 +1,173 @@ +"""Tests for infrastructure.protocol — WebSocket message types.""" + +import json + +import pytest + +from infrastructure.protocol import ( + AgentStateMessage, + BarkMessage, + ConnectionAckMessage, + ErrorMessage, + MemoryFlashMessage, + MessageType, + SystemStatusMessage, + TaskUpdateMessage, + ThoughtMessage, + VisitorStateMessage, + WSMessage, +) + +# --------------------------------------------------------------------------- +# MessageType enum +# --------------------------------------------------------------------------- + + +class TestMessageType: + """MessageType enum covers all 9 Matrix PROTOCOL.md types.""" + + def test_has_all_nine_types(self): + assert len(MessageType) == 9 + + @pytest.mark.parametrize( + "member,value", + [ + (MessageType.AGENT_STATE, "agent_state"), + (MessageType.VISITOR_STATE, "visitor_state"), + (MessageType.BARK, "bark"), + (MessageType.THOUGHT, "thought"), + (MessageType.SYSTEM_STATUS, "system_status"), + (MessageType.CONNECTION_ACK, "connection_ack"), + (MessageType.ERROR, "error"), + (MessageType.TASK_UPDATE, "task_update"), + (MessageType.MEMORY_FLASH, "memory_flash"), + ], + ) + def test_enum_values(self, member, value): + assert member.value == value + + def test_str_comparison(self): + """MessageType is a str enum so it can be compared to plain strings.""" + assert MessageType.BARK == "bark" + + +# --------------------------------------------------------------------------- +# to_json / from_json round-trip +# --------------------------------------------------------------------------- + + +class TestAgentStateMessage: + def test_defaults(self): + msg = AgentStateMessage() + assert msg.type == "agent_state" + assert msg.agent_id == "" + assert msg.data == {} + + def test_round_trip(self): + msg = AgentStateMessage(agent_id="timmy", data={"mood": "happy"}, ts=1000.0) + raw = msg.to_json() + restored = AgentStateMessage.from_json(raw) + assert restored.agent_id == "timmy" + assert restored.data == {"mood": "happy"} + assert restored.ts == 1000.0 + + def test_to_json_structure(self): + msg = AgentStateMessage(agent_id="timmy", data={"x": 1}, ts=123.0) + parsed = json.loads(msg.to_json()) + assert parsed["type"] == "agent_state" + assert parsed["agent_id"] == "timmy" + assert parsed["data"] == {"x": 1} + assert parsed["ts"] == 123.0 + + +class TestVisitorStateMessage: + def test_round_trip(self): + msg = VisitorStateMessage(visitor_id="v1", data={"page": "/"}, ts=1.0) + restored = VisitorStateMessage.from_json(msg.to_json()) + assert restored.visitor_id == "v1" + assert restored.data == {"page": "/"} + + +class TestBarkMessage: + def test_round_trip(self): + msg = BarkMessage(agent_id="timmy", content="woof!", ts=1.0) + restored = BarkMessage.from_json(msg.to_json()) + assert restored.agent_id == "timmy" + assert restored.content == "woof!" + + +class TestThoughtMessage: + def test_round_trip(self): + msg = ThoughtMessage(agent_id="timmy", content="hmm...", ts=1.0) + restored = ThoughtMessage.from_json(msg.to_json()) + assert restored.content == "hmm..." + + +class TestSystemStatusMessage: + def test_round_trip(self): + msg = SystemStatusMessage(status="healthy", data={"uptime": 3600}, ts=1.0) + restored = SystemStatusMessage.from_json(msg.to_json()) + assert restored.status == "healthy" + assert restored.data == {"uptime": 3600} + + +class TestConnectionAckMessage: + def test_round_trip(self): + msg = ConnectionAckMessage(client_id="abc-123", ts=1.0) + restored = ConnectionAckMessage.from_json(msg.to_json()) + assert restored.client_id == "abc-123" + + +class TestErrorMessage: + def test_round_trip(self): + msg = ErrorMessage(code="INVALID", message="bad request", ts=1.0) + restored = ErrorMessage.from_json(msg.to_json()) + assert restored.code == "INVALID" + assert restored.message == "bad request" + + +class TestTaskUpdateMessage: + def test_round_trip(self): + msg = TaskUpdateMessage(task_id="t1", status="completed", data={"result": "ok"}, ts=1.0) + restored = TaskUpdateMessage.from_json(msg.to_json()) + assert restored.task_id == "t1" + assert restored.status == "completed" + assert restored.data == {"result": "ok"} + + +class TestMemoryFlashMessage: + def test_round_trip(self): + msg = MemoryFlashMessage(agent_id="timmy", memory_key="fav_food", content="kibble", ts=1.0) + restored = MemoryFlashMessage.from_json(msg.to_json()) + assert restored.memory_key == "fav_food" + assert restored.content == "kibble" + + +# --------------------------------------------------------------------------- +# WSMessage.from_json dispatch +# --------------------------------------------------------------------------- + + +class TestWSMessageDispatch: + """WSMessage.from_json dispatches to the correct subclass.""" + + def test_dispatch_to_bark(self): + raw = json.dumps({"type": "bark", "agent_id": "t", "content": "woof", "ts": 1.0}) + msg = WSMessage.from_json(raw) + assert isinstance(msg, BarkMessage) + assert msg.content == "woof" + + def test_dispatch_to_error(self): + raw = json.dumps({"type": "error", "code": "E1", "message": "oops", "ts": 1.0}) + msg = WSMessage.from_json(raw) + assert isinstance(msg, ErrorMessage) + + def test_unknown_type_returns_base(self): + raw = json.dumps({"type": "unknown_future_type", "ts": 1.0}) + msg = WSMessage.from_json(raw) + assert type(msg) is WSMessage + assert msg.type == "unknown_future_type" + + def test_invalid_json_raises(self): + with pytest.raises(json.JSONDecodeError): + WSMessage.from_json("not json")