diff --git a/src/infrastructure/visitor.py b/src/infrastructure/visitor.py new file mode 100644 index 00000000..78c0ca45 --- /dev/null +++ b/src/infrastructure/visitor.py @@ -0,0 +1,166 @@ +"""Visitor state tracking for the Matrix frontend. + +Tracks active visitors as they connect and move around the 3D world, +and provides serialization for Matrix protocol broadcast messages. +""" + +import time +from dataclasses import dataclass, field +from datetime import UTC, datetime + + +@dataclass +class VisitorState: + """State for a single visitor in the Matrix. + + Attributes + ---------- + visitor_id: Unique identifier for the visitor (client ID). + display_name: Human-readable name shown above the visitor. + position: 3D coordinates (x, y, z) in the world. + rotation: Rotation angle in degrees (0-360). + connected_at: ISO timestamp when the visitor connected. + """ + + visitor_id: str + display_name: str = "" + position: dict[str, float] = field(default_factory=lambda: {"x": 0.0, "y": 0.0, "z": 0.0}) + rotation: float = 0.0 + connected_at: str = field( + default_factory=lambda: datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ") + ) + + def __post_init__(self): + """Set display_name to visitor_id if not provided; copy position dict.""" + if not self.display_name: + self.display_name = self.visitor_id + # Copy position to avoid shared mutable state + self.position = dict(self.position) + + +class VisitorRegistry: + """Registry of active visitors in the Matrix. + + Thread-safe singleton pattern (Python GIL protects dict operations). + Used by the WebSocket layer to track and broadcast visitor positions. + """ + + _instance: "VisitorRegistry | None" = None + + def __new__(cls) -> "VisitorRegistry": + """Singleton constructor.""" + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._visitors: dict[str, VisitorState] = {} + return cls._instance + + def add( + self, visitor_id: str, display_name: str = "", position: dict | None = None + ) -> VisitorState: + """Add a new visitor to the registry. + + Parameters + ---------- + visitor_id: Unique identifier for the visitor. + display_name: Optional display name (defaults to visitor_id). + position: Optional initial position (defaults to origin). + + Returns + ------- + The newly created VisitorState. + """ + visitor = VisitorState( + visitor_id=visitor_id, + display_name=display_name, + position=position if position else {"x": 0.0, "y": 0.0, "z": 0.0}, + ) + self._visitors[visitor_id] = visitor + return visitor + + def remove(self, visitor_id: str) -> bool: + """Remove a visitor from the registry. + + Parameters + ---------- + visitor_id: The visitor to remove. + + Returns + ------- + True if the visitor was found and removed, False otherwise. + """ + if visitor_id in self._visitors: + del self._visitors[visitor_id] + return True + return False + + def update_position( + self, + visitor_id: str, + position: dict[str, float], + rotation: float | None = None, + ) -> bool: + """Update a visitor's position and rotation. + + Parameters + ---------- + visitor_id: The visitor to update. + position: New 3D coordinates (x, y, z). + rotation: Optional new rotation angle. + + Returns + ------- + True if the visitor was found and updated, False otherwise. + """ + if visitor_id not in self._visitors: + return False + + self._visitors[visitor_id].position = position + if rotation is not None: + self._visitors[visitor_id].rotation = rotation + return True + + def get(self, visitor_id: str) -> VisitorState | None: + """Get a single visitor's state. + + Parameters + ---------- + visitor_id: The visitor to retrieve. + + Returns + ------- + The VisitorState if found, None otherwise. + """ + return self._visitors.get(visitor_id) + + def get_all(self) -> list[dict]: + """Get all active visitors as Matrix protocol message dicts. + + Returns + ------- + List of visitor_state dicts ready for WebSocket broadcast. + Each dict has: type, visitor_id, data (with display_name, + position, rotation, connected_at), and ts. + """ + now = int(time.time()) + return [ + { + "type": "visitor_state", + "visitor_id": v.visitor_id, + "data": { + "display_name": v.display_name, + "position": v.position, + "rotation": v.rotation, + "connected_at": v.connected_at, + }, + "ts": now, + } + for v in self._visitors.values() + ] + + def clear(self) -> None: + """Remove all visitors (useful for testing).""" + self._visitors.clear() + + def __len__(self) -> int: + """Return the number of active visitors.""" + return len(self._visitors) diff --git a/tests/unit/test_visitor.py b/tests/unit/test_visitor.py new file mode 100644 index 00000000..e6c282aa --- /dev/null +++ b/tests/unit/test_visitor.py @@ -0,0 +1,367 @@ +"""Tests for infrastructure.visitor — visitor state tracking.""" + +from unittest.mock import patch + +from infrastructure.visitor import VisitorRegistry, VisitorState + +# --------------------------------------------------------------------------- +# VisitorState dataclass tests +# --------------------------------------------------------------------------- + + +class TestVisitorState: + """Tests for the VisitorState dataclass.""" + + def test_defaults(self): + """VisitorState has correct defaults when only visitor_id is provided.""" + v = VisitorState(visitor_id="v1") + + assert v.visitor_id == "v1" + assert v.display_name == "v1" # Defaults to visitor_id + assert v.position == {"x": 0.0, "y": 0.0, "z": 0.0} + assert v.rotation == 0.0 + assert "T" in v.connected_at # ISO format check + + def test_custom_values(self): + """VisitorState accepts custom values for all fields.""" + v = VisitorState( + visitor_id="v2", + display_name="Alice", + position={"x": 1.0, "y": 2.0, "z": 3.0}, + rotation=90.0, + connected_at="2026-03-21T12:00:00Z", + ) + + assert v.visitor_id == "v2" + assert v.display_name == "Alice" + assert v.position == {"x": 1.0, "y": 2.0, "z": 3.0} + assert v.rotation == 90.0 + assert v.connected_at == "2026-03-21T12:00:00Z" + + def test_display_name_defaults_to_visitor_id(self): + """Empty display_name falls back to visitor_id.""" + v = VisitorState(visitor_id="charlie", display_name="") + assert v.display_name == "charlie" + + def test_position_is_copied_not_shared(self): + """Each VisitorState has its own position dict.""" + pos = {"x": 1.0, "y": 2.0, "z": 3.0} + v1 = VisitorState(visitor_id="v1", position=pos) + v2 = VisitorState(visitor_id="v2", position=pos) + + v1.position["x"] = 99.0 + assert v2.position["x"] == 1.0 # v2 unchanged + + +# --------------------------------------------------------------------------- +# VisitorRegistry singleton tests +# --------------------------------------------------------------------------- + + +class TestVisitorRegistrySingleton: + """Tests for the VisitorRegistry singleton behavior.""" + + def setup_method(self): + """Clear registry before each test.""" + VisitorRegistry._instance = None + + def teardown_method(self): + """Clean up after each test.""" + VisitorRegistry._instance = None + + def test_singleton_returns_same_instance(self): + """Multiple calls return the same registry object.""" + r1 = VisitorRegistry() + r2 = VisitorRegistry() + assert r1 is r2 + + def test_singleton_shares_state(self): + """State is shared across all references to the singleton.""" + r1 = VisitorRegistry() + r1.add("v1") + + r2 = VisitorRegistry() + assert len(r2) == 1 + assert r2.get("v1") is not None + + +# --------------------------------------------------------------------------- +# VisitorRegistry.add tests +# --------------------------------------------------------------------------- + + +class TestVisitorRegistryAdd: + """Tests for VisitorRegistry.add().""" + + def setup_method(self): + """Clear registry before each test.""" + VisitorRegistry._instance = None + self.registry = VisitorRegistry() + + def teardown_method(self): + """Clean up after each test.""" + VisitorRegistry._instance = None + + def test_add_returns_visitor_state(self): + """add() returns the created VisitorState.""" + result = self.registry.add("v1") + assert isinstance(result, VisitorState) + assert result.visitor_id == "v1" + + def test_add_with_display_name(self): + """add() accepts a custom display name.""" + result = self.registry.add("v1", display_name="Alice") + assert result.display_name == "Alice" + + def test_add_with_position(self): + """add() accepts an initial position.""" + pos = {"x": 10.0, "y": 20.0, "z": 30.0} + result = self.registry.add("v1", position=pos) + assert result.position == pos + + def test_add_increases_count(self): + """Each add increases the registry size.""" + assert len(self.registry) == 0 + self.registry.add("v1") + assert len(self.registry) == 1 + self.registry.add("v2") + assert len(self.registry) == 2 + + +# --------------------------------------------------------------------------- +# VisitorRegistry.remove tests +# --------------------------------------------------------------------------- + + +class TestVisitorRegistryRemove: + """Tests for VisitorRegistry.remove().""" + + def setup_method(self): + """Clear registry and add test visitors.""" + VisitorRegistry._instance = None + self.registry = VisitorRegistry() + self.registry.add("v1") + self.registry.add("v2") + + def teardown_method(self): + """Clean up after each test.""" + VisitorRegistry._instance = None + + def test_remove_existing_returns_true(self): + """Removing an existing visitor returns True.""" + result = self.registry.remove("v1") + assert result is True + assert len(self.registry) == 1 + + def test_remove_nonexistent_returns_false(self): + """Removing a non-existent visitor returns False.""" + result = self.registry.remove("unknown") + assert result is False + assert len(self.registry) == 2 + + def test_removes_correct_visitor(self): + """remove() only removes the specified visitor.""" + self.registry.remove("v1") + assert self.registry.get("v1") is None + assert self.registry.get("v2") is not None + + +# --------------------------------------------------------------------------- +# VisitorRegistry.update_position tests +# --------------------------------------------------------------------------- + + +class TestVisitorRegistryUpdatePosition: + """Tests for VisitorRegistry.update_position().""" + + def setup_method(self): + """Clear registry and add test visitor.""" + VisitorRegistry._instance = None + self.registry = VisitorRegistry() + self.registry.add("v1", position={"x": 0.0, "y": 0.0, "z": 0.0}) + + def teardown_method(self): + """Clean up after each test.""" + VisitorRegistry._instance = None + + def test_update_position_returns_true(self): + """update_position returns True for existing visitor.""" + result = self.registry.update_position("v1", {"x": 1.0, "y": 2.0, "z": 3.0}) + assert result is True + + def test_update_position_returns_false_for_unknown(self): + """update_position returns False for non-existent visitor.""" + result = self.registry.update_position("unknown", {"x": 1.0, "y": 2.0, "z": 3.0}) + assert result is False + + def test_update_position_changes_values(self): + """update_position updates the stored position.""" + new_pos = {"x": 10.0, "y": 20.0, "z": 30.0} + self.registry.update_position("v1", new_pos) + + visitor = self.registry.get("v1") + assert visitor.position == new_pos + + def test_update_position_with_rotation(self): + """update_position can also update rotation.""" + self.registry.update_position("v1", {"x": 1.0, "y": 0.0, "z": 0.0}, rotation=180.0) + + visitor = self.registry.get("v1") + assert visitor.rotation == 180.0 + + def test_update_position_without_rotation_preserves_it(self): + """Calling update_position without rotation preserves existing rotation.""" + self.registry.update_position("v1", {"x": 1.0, "y": 0.0, "z": 0.0}, rotation=90.0) + self.registry.update_position("v1", {"x": 2.0, "y": 0.0, "z": 0.0}) + + visitor = self.registry.get("v1") + assert visitor.rotation == 90.0 + + +# --------------------------------------------------------------------------- +# VisitorRegistry.get tests +# --------------------------------------------------------------------------- + + +class TestVisitorRegistryGet: + """Tests for VisitorRegistry.get().""" + + def setup_method(self): + """Clear registry and add test visitor.""" + VisitorRegistry._instance = None + self.registry = VisitorRegistry() + self.registry.add("v1", display_name="Alice") + + def teardown_method(self): + """Clean up after each test.""" + VisitorRegistry._instance = None + + def test_get_existing_returns_visitor(self): + """get() returns VisitorState for existing visitor.""" + result = self.registry.get("v1") + assert isinstance(result, VisitorState) + assert result.visitor_id == "v1" + assert result.display_name == "Alice" + + def test_get_nonexistent_returns_none(self): + """get() returns None for non-existent visitor.""" + result = self.registry.get("unknown") + assert result is None + + +# --------------------------------------------------------------------------- +# VisitorRegistry.get_all tests +# --------------------------------------------------------------------------- + + +class TestVisitorRegistryGetAll: + """Tests for VisitorRegistry.get_all() — Matrix protocol format.""" + + def setup_method(self): + """Clear registry and add test visitors.""" + VisitorRegistry._instance = None + self.registry = VisitorRegistry() + self.registry.add("v1", display_name="Alice", position={"x": 1.0, "y": 2.0, "z": 3.0}) + self.registry.add("v2", display_name="Bob", position={"x": 4.0, "y": 5.0, "z": 6.0}) + + def teardown_method(self): + """Clean up after each test.""" + VisitorRegistry._instance = None + + def test_get_all_returns_list(self): + """get_all() returns a list.""" + result = self.registry.get_all() + assert isinstance(result, list) + assert len(result) == 2 + + def test_get_all_format_has_required_fields(self): + """Each entry has type, visitor_id, data, and ts.""" + result = self.registry.get_all() + + for entry in result: + assert "type" in entry + assert "visitor_id" in entry + assert "data" in entry + assert "ts" in entry + + def test_get_all_type_is_visitor_state(self): + """The type field is 'visitor_state'.""" + result = self.registry.get_all() + assert all(entry["type"] == "visitor_state" for entry in result) + + def test_get_all_data_has_required_fields(self): + """data dict contains display_name, position, rotation, connected_at.""" + result = self.registry.get_all() + + for entry in result: + data = entry["data"] + assert "display_name" in data + assert "position" in data + assert "rotation" in data + assert "connected_at" in data + + def test_get_all_position_is_dict(self): + """position within data is a dict with x, y, z.""" + result = self.registry.get_all() + + for entry in result: + pos = entry["data"]["position"] + assert isinstance(pos, dict) + assert "x" in pos + assert "y" in pos + assert "z" in pos + + def test_get_all_ts_is_unix_timestamp(self): + """ts is an integer Unix timestamp.""" + result = self.registry.get_all() + + for entry in result: + assert isinstance(entry["ts"], int) + assert entry["ts"] > 0 + + @patch("infrastructure.visitor.time") + def test_get_all_uses_current_time(self, mock_time): + """ts is set from time.time().""" + mock_time.time.return_value = 1742529600 + + result = self.registry.get_all() + assert all(entry["ts"] == 1742529600 for entry in result) + + def test_get_all_empty_registry(self): + """get_all() returns empty list when no visitors.""" + self.registry.clear() + result = self.registry.get_all() + assert result == [] + + +# --------------------------------------------------------------------------- +# VisitorRegistry.clear tests +# --------------------------------------------------------------------------- + + +class TestVisitorRegistryClear: + """Tests for VisitorRegistry.clear().""" + + def setup_method(self): + """Clear registry and add test visitors.""" + VisitorRegistry._instance = None + self.registry = VisitorRegistry() + self.registry.add("v1") + self.registry.add("v2") + self.registry.add("v3") + + def teardown_method(self): + """Clean up after each test.""" + VisitorRegistry._instance = None + + def test_clear_removes_all_visitors(self): + """clear() removes all visitors from the registry.""" + assert len(self.registry) == 3 + self.registry.clear() + assert len(self.registry) == 0 + + def test_clear_allows_readding(self): + """Visitors can be re-added after clear().""" + self.registry.clear() + self.registry.add("v1") + assert len(self.registry) == 1