[kimi] refactor: extract WebSocket message types into shared protocol module (#667) #696
261
src/infrastructure/protocol.py
Normal file
261
src/infrastructure/protocol.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""Shared WebSocket message protocol for the Matrix frontend.
|
||||
|
||||
Defines all WebSocket message types as an enum and typed dataclasses
|
||||
with ``to_json()`` / ``from_json()`` helpers so every producer and the
|
||||
gateway speak the same language.
|
||||
|
||||
Message wire format
|
||||
-------------------
|
||||
.. code-block:: json
|
||||
|
||||
{"type": "agent_state", "agent_id": "timmy", "data": {...}, "ts": 1234567890}
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageType(StrEnum):
|
||||
"""All WebSocket message types defined by the Matrix PROTOCOL.md."""
|
||||
|
||||
AGENT_STATE = "agent_state"
|
||||
VISITOR_STATE = "visitor_state"
|
||||
BARK = "bark"
|
||||
THOUGHT = "thought"
|
||||
SYSTEM_STATUS = "system_status"
|
||||
CONNECTION_ACK = "connection_ack"
|
||||
ERROR = "error"
|
||||
TASK_UPDATE = "task_update"
|
||||
MEMORY_FLASH = "memory_flash"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Base message
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class WSMessage:
|
||||
"""Base WebSocket message with common envelope fields."""
|
||||
|
||||
type: str
|
||||
ts: float = field(default_factory=time.time)
|
||||
|
||||
def to_json(self) -> str:
|
||||
"""Serialise the message to a JSON string."""
|
||||
return json.dumps(asdict(self))
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, raw: str) -> "WSMessage":
|
||||
"""Deserialise a JSON string into the correct message subclass.
|
||||
|
||||
Falls back to the base ``WSMessage`` when the ``type`` field is
|
||||
unrecognised.
|
||||
"""
|
||||
data = json.loads(raw)
|
||||
msg_type = data.get("type")
|
||||
sub = _REGISTRY.get(msg_type)
|
||||
if sub is not None:
|
||||
return sub.from_json(raw)
|
||||
return cls(**data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Concrete message types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentStateMessage(WSMessage):
|
||||
"""State update for a single agent."""
|
||||
|
||||
type: str = field(default=MessageType.AGENT_STATE)
|
||||
agent_id: str = ""
|
||||
data: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, raw: str) -> "AgentStateMessage":
|
||||
payload = json.loads(raw)
|
||||
return cls(
|
||||
type=payload.get("type", MessageType.AGENT_STATE),
|
||||
ts=payload.get("ts", time.time()),
|
||||
agent_id=payload.get("agent_id", ""),
|
||||
data=payload.get("data", {}),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisitorStateMessage(WSMessage):
|
||||
"""State update for a visitor / user session."""
|
||||
|
||||
type: str = field(default=MessageType.VISITOR_STATE)
|
||||
visitor_id: str = ""
|
||||
data: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, raw: str) -> "VisitorStateMessage":
|
||||
payload = json.loads(raw)
|
||||
return cls(
|
||||
type=payload.get("type", MessageType.VISITOR_STATE),
|
||||
ts=payload.get("ts", time.time()),
|
||||
visitor_id=payload.get("visitor_id", ""),
|
||||
data=payload.get("data", {}),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BarkMessage(WSMessage):
|
||||
"""A bark (chat-like utterance) from an agent."""
|
||||
|
||||
type: str = field(default=MessageType.BARK)
|
||||
agent_id: str = ""
|
||||
content: str = ""
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, raw: str) -> "BarkMessage":
|
||||
payload = json.loads(raw)
|
||||
return cls(
|
||||
type=payload.get("type", MessageType.BARK),
|
||||
ts=payload.get("ts", time.time()),
|
||||
agent_id=payload.get("agent_id", ""),
|
||||
content=payload.get("content", ""),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ThoughtMessage(WSMessage):
|
||||
"""An inner thought from an agent."""
|
||||
|
||||
type: str = field(default=MessageType.THOUGHT)
|
||||
agent_id: str = ""
|
||||
content: str = ""
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, raw: str) -> "ThoughtMessage":
|
||||
payload = json.loads(raw)
|
||||
return cls(
|
||||
type=payload.get("type", MessageType.THOUGHT),
|
||||
ts=payload.get("ts", time.time()),
|
||||
agent_id=payload.get("agent_id", ""),
|
||||
content=payload.get("content", ""),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemStatusMessage(WSMessage):
|
||||
"""System-wide status broadcast."""
|
||||
|
||||
type: str = field(default=MessageType.SYSTEM_STATUS)
|
||||
status: str = ""
|
||||
data: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, raw: str) -> "SystemStatusMessage":
|
||||
payload = json.loads(raw)
|
||||
return cls(
|
||||
type=payload.get("type", MessageType.SYSTEM_STATUS),
|
||||
ts=payload.get("ts", time.time()),
|
||||
status=payload.get("status", ""),
|
||||
data=payload.get("data", {}),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionAckMessage(WSMessage):
|
||||
"""Acknowledgement sent when a client connects."""
|
||||
|
||||
type: str = field(default=MessageType.CONNECTION_ACK)
|
||||
client_id: str = ""
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, raw: str) -> "ConnectionAckMessage":
|
||||
payload = json.loads(raw)
|
||||
return cls(
|
||||
type=payload.get("type", MessageType.CONNECTION_ACK),
|
||||
ts=payload.get("ts", time.time()),
|
||||
client_id=payload.get("client_id", ""),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ErrorMessage(WSMessage):
|
||||
"""Error message sent to a client."""
|
||||
|
||||
type: str = field(default=MessageType.ERROR)
|
||||
code: str = ""
|
||||
message: str = ""
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, raw: str) -> "ErrorMessage":
|
||||
payload = json.loads(raw)
|
||||
return cls(
|
||||
type=payload.get("type", MessageType.ERROR),
|
||||
ts=payload.get("ts", time.time()),
|
||||
code=payload.get("code", ""),
|
||||
message=payload.get("message", ""),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskUpdateMessage(WSMessage):
|
||||
"""Update about a task (created, assigned, completed, etc.)."""
|
||||
|
||||
type: str = field(default=MessageType.TASK_UPDATE)
|
||||
task_id: str = ""
|
||||
status: str = ""
|
||||
data: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, raw: str) -> "TaskUpdateMessage":
|
||||
payload = json.loads(raw)
|
||||
return cls(
|
||||
type=payload.get("type", MessageType.TASK_UPDATE),
|
||||
ts=payload.get("ts", time.time()),
|
||||
task_id=payload.get("task_id", ""),
|
||||
status=payload.get("status", ""),
|
||||
data=payload.get("data", {}),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryFlashMessage(WSMessage):
|
||||
"""A flash of memory — a recalled or stored memory event."""
|
||||
|
||||
type: str = field(default=MessageType.MEMORY_FLASH)
|
||||
agent_id: str = ""
|
||||
memory_key: str = ""
|
||||
content: str = ""
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, raw: str) -> "MemoryFlashMessage":
|
||||
payload = json.loads(raw)
|
||||
return cls(
|
||||
type=payload.get("type", MessageType.MEMORY_FLASH),
|
||||
ts=payload.get("ts", time.time()),
|
||||
agent_id=payload.get("agent_id", ""),
|
||||
memory_key=payload.get("memory_key", ""),
|
||||
content=payload.get("content", ""),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry for from_json dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_REGISTRY: dict[str, type[WSMessage]] = {
|
||||
MessageType.AGENT_STATE: AgentStateMessage,
|
||||
MessageType.VISITOR_STATE: VisitorStateMessage,
|
||||
MessageType.BARK: BarkMessage,
|
||||
MessageType.THOUGHT: ThoughtMessage,
|
||||
MessageType.SYSTEM_STATUS: SystemStatusMessage,
|
||||
MessageType.CONNECTION_ACK: ConnectionAckMessage,
|
||||
MessageType.ERROR: ErrorMessage,
|
||||
MessageType.TASK_UPDATE: TaskUpdateMessage,
|
||||
MessageType.MEMORY_FLASH: MemoryFlashMessage,
|
||||
}
|
||||
173
tests/unit/test_protocol.py
Normal file
173
tests/unit/test_protocol.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""Tests for infrastructure.protocol — WebSocket message types."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.protocol import (
|
||||
AgentStateMessage,
|
||||
BarkMessage,
|
||||
ConnectionAckMessage,
|
||||
ErrorMessage,
|
||||
MemoryFlashMessage,
|
||||
MessageType,
|
||||
SystemStatusMessage,
|
||||
TaskUpdateMessage,
|
||||
ThoughtMessage,
|
||||
VisitorStateMessage,
|
||||
WSMessage,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MessageType enum
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessageType:
|
||||
"""MessageType enum covers all 9 Matrix PROTOCOL.md types."""
|
||||
|
||||
def test_has_all_nine_types(self):
|
||||
assert len(MessageType) == 9
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"member,value",
|
||||
[
|
||||
(MessageType.AGENT_STATE, "agent_state"),
|
||||
(MessageType.VISITOR_STATE, "visitor_state"),
|
||||
(MessageType.BARK, "bark"),
|
||||
(MessageType.THOUGHT, "thought"),
|
||||
(MessageType.SYSTEM_STATUS, "system_status"),
|
||||
(MessageType.CONNECTION_ACK, "connection_ack"),
|
||||
(MessageType.ERROR, "error"),
|
||||
(MessageType.TASK_UPDATE, "task_update"),
|
||||
(MessageType.MEMORY_FLASH, "memory_flash"),
|
||||
],
|
||||
)
|
||||
def test_enum_values(self, member, value):
|
||||
assert member.value == value
|
||||
|
||||
def test_str_comparison(self):
|
||||
"""MessageType is a str enum so it can be compared to plain strings."""
|
||||
assert MessageType.BARK == "bark"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# to_json / from_json round-trip
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAgentStateMessage:
|
||||
def test_defaults(self):
|
||||
msg = AgentStateMessage()
|
||||
assert msg.type == "agent_state"
|
||||
assert msg.agent_id == ""
|
||||
assert msg.data == {}
|
||||
|
||||
def test_round_trip(self):
|
||||
msg = AgentStateMessage(agent_id="timmy", data={"mood": "happy"}, ts=1000.0)
|
||||
raw = msg.to_json()
|
||||
restored = AgentStateMessage.from_json(raw)
|
||||
assert restored.agent_id == "timmy"
|
||||
assert restored.data == {"mood": "happy"}
|
||||
assert restored.ts == 1000.0
|
||||
|
||||
def test_to_json_structure(self):
|
||||
msg = AgentStateMessage(agent_id="timmy", data={"x": 1}, ts=123.0)
|
||||
parsed = json.loads(msg.to_json())
|
||||
assert parsed["type"] == "agent_state"
|
||||
assert parsed["agent_id"] == "timmy"
|
||||
assert parsed["data"] == {"x": 1}
|
||||
assert parsed["ts"] == 123.0
|
||||
|
||||
|
||||
class TestVisitorStateMessage:
|
||||
def test_round_trip(self):
|
||||
msg = VisitorStateMessage(visitor_id="v1", data={"page": "/"}, ts=1.0)
|
||||
restored = VisitorStateMessage.from_json(msg.to_json())
|
||||
assert restored.visitor_id == "v1"
|
||||
assert restored.data == {"page": "/"}
|
||||
|
||||
|
||||
class TestBarkMessage:
|
||||
def test_round_trip(self):
|
||||
msg = BarkMessage(agent_id="timmy", content="woof!", ts=1.0)
|
||||
restored = BarkMessage.from_json(msg.to_json())
|
||||
assert restored.agent_id == "timmy"
|
||||
assert restored.content == "woof!"
|
||||
|
||||
|
||||
class TestThoughtMessage:
|
||||
def test_round_trip(self):
|
||||
msg = ThoughtMessage(agent_id="timmy", content="hmm...", ts=1.0)
|
||||
restored = ThoughtMessage.from_json(msg.to_json())
|
||||
assert restored.content == "hmm..."
|
||||
|
||||
|
||||
class TestSystemStatusMessage:
|
||||
def test_round_trip(self):
|
||||
msg = SystemStatusMessage(status="healthy", data={"uptime": 3600}, ts=1.0)
|
||||
restored = SystemStatusMessage.from_json(msg.to_json())
|
||||
assert restored.status == "healthy"
|
||||
assert restored.data == {"uptime": 3600}
|
||||
|
||||
|
||||
class TestConnectionAckMessage:
|
||||
def test_round_trip(self):
|
||||
msg = ConnectionAckMessage(client_id="abc-123", ts=1.0)
|
||||
restored = ConnectionAckMessage.from_json(msg.to_json())
|
||||
assert restored.client_id == "abc-123"
|
||||
|
||||
|
||||
class TestErrorMessage:
|
||||
def test_round_trip(self):
|
||||
msg = ErrorMessage(code="INVALID", message="bad request", ts=1.0)
|
||||
restored = ErrorMessage.from_json(msg.to_json())
|
||||
assert restored.code == "INVALID"
|
||||
assert restored.message == "bad request"
|
||||
|
||||
|
||||
class TestTaskUpdateMessage:
|
||||
def test_round_trip(self):
|
||||
msg = TaskUpdateMessage(task_id="t1", status="completed", data={"result": "ok"}, ts=1.0)
|
||||
restored = TaskUpdateMessage.from_json(msg.to_json())
|
||||
assert restored.task_id == "t1"
|
||||
assert restored.status == "completed"
|
||||
assert restored.data == {"result": "ok"}
|
||||
|
||||
|
||||
class TestMemoryFlashMessage:
|
||||
def test_round_trip(self):
|
||||
msg = MemoryFlashMessage(agent_id="timmy", memory_key="fav_food", content="kibble", ts=1.0)
|
||||
restored = MemoryFlashMessage.from_json(msg.to_json())
|
||||
assert restored.memory_key == "fav_food"
|
||||
assert restored.content == "kibble"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WSMessage.from_json dispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWSMessageDispatch:
|
||||
"""WSMessage.from_json dispatches to the correct subclass."""
|
||||
|
||||
def test_dispatch_to_bark(self):
|
||||
raw = json.dumps({"type": "bark", "agent_id": "t", "content": "woof", "ts": 1.0})
|
||||
msg = WSMessage.from_json(raw)
|
||||
assert isinstance(msg, BarkMessage)
|
||||
assert msg.content == "woof"
|
||||
|
||||
def test_dispatch_to_error(self):
|
||||
raw = json.dumps({"type": "error", "code": "E1", "message": "oops", "ts": 1.0})
|
||||
msg = WSMessage.from_json(raw)
|
||||
assert isinstance(msg, ErrorMessage)
|
||||
|
||||
def test_unknown_type_returns_base(self):
|
||||
raw = json.dumps({"type": "unknown_future_type", "ts": 1.0})
|
||||
msg = WSMessage.from_json(raw)
|
||||
assert type(msg) is WSMessage
|
||||
assert msg.type == "unknown_future_type"
|
||||
|
||||
def test_invalid_json_raises(self):
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
WSMessage.from_json("not json")
|
||||
Reference in New Issue
Block a user