Tests cover: (1) session isolation — 2 users, no cross-contamination; (2) crisis detection — protocol content, low/moderate/high-risk messages; (3) room awareness — per-user room tracking, presence across rooms; (4) session timeout — stale sessions cleaned up; (5) max sessions limit — oldest session evicted at capacity; (6) HTTP API integration — health, say, room, and sessions endpoints.
394 lines · 14 KiB · Python
#!/usr/bin/env python3
"""
Tests for multi_user_bridge.py

Validates:
1. Session isolation (2 users, no cross-contamination)
2. Crisis detection (low, moderate, high risk)
3. Room awareness (Timmy knows which room each user is in)
4. Session timeout (inactive sessions cleaned up)
5. Max sessions limit (eviction works)
6. HTTP API integration (health, say, room, and sessions endpoints)

Also includes unit tests for PresenceManager (enter/leave, say, user cleanup).
"""
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
import threading
|
|
import unittest
|
|
from unittest.mock import patch, MagicMock
|
|
from http.client import HTTPConnection
|
|
from datetime import datetime, timedelta
|
|
|
|
# Ensure bridge module is importable: put the directory one level above the
# directory containing this file at the front of sys.path.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Patch out AIAgent before importing the bridge module.
# _mock_agent is the single shared stand-in returned by every AIAgent(**kw)
# call made through the fake run_agent module below (the lambda accepts
# keyword arguments only); its chat() always returns the canned reply.
_mock_agent = MagicMock()
_mock_agent.chat.return_value = "Test response from Timmy"

# Temporarily register a fake `run_agent` module so that importing the bridge
# (which presumably imports run_agent) works without the real dependency.
# NOTE(review): patch.dict restores sys.modules to its prior contents on exit,
# which also drops modules first imported inside the block, including
# `multi_user_bridge` itself. The `bridge` name keeps the module object alive,
# but a later `import multi_user_bridge` elsewhere would re-execute it without
# the fake run_agent. Confirm this is intended.
with patch.dict("sys.modules", {"run_agent": MagicMock(AIAgent=lambda **kw: _mock_agent)}):
    import multi_user_bridge as bridge
|
|
|
|
|
|
class TestPresenceManager(unittest.TestCase):
    """Exercise PresenceManager: room entry/exit, chat events, and user cleanup."""

    def setUp(self):
        # A fresh manager per test keeps room state from leaking between tests.
        self.pm = bridge.PresenceManager()

    def test_enter_and_leave_room(self):
        enter_event = self.pm.enter_room("u1", "Alice", "Lobby")
        for field, expected in (("event", "enter"), ("username", "Alice"), ("room", "Lobby")):
            self.assertEqual(enter_event[field], expected)

        occupants = self.pm.get_players_in_room("Lobby")
        self.assertEqual(len(occupants), 1)
        self.assertEqual(occupants[0]["username"], "Alice")

        leave_event = self.pm.leave_room("u1", "Lobby")
        self.assertEqual(leave_event["event"], "leave")
        # After Alice leaves, nobody should be listed in the lobby.
        self.assertEqual(len(self.pm.get_players_in_room("Lobby")), 0)

    def test_say_event(self):
        self.pm.enter_room("u1", "Bob", "Tavern")
        said = self.pm.say("u1", "Bob", "Tavern", "hello world")
        self.assertEqual(said["type"], "say")
        self.assertEqual(said["message"], "hello world")

        # The room history should hold exactly the enter event and the say event.
        history = self.pm.get_room_events("Tavern")
        self.assertEqual(len(history), 2)

    def test_leave_nonexistent_room_returns_none(self):
        # Leaving a room the user never entered returns None.
        self.assertIsNone(self.pm.leave_room("u1", "Nowhere"))

    def test_cleanup_user(self):
        rooms = ("RoomA", "RoomB")
        for room_name in rooms:
            self.pm.enter_room("u1", "Carol", room_name)

        departures = self.pm.cleanup_user("u1")
        self.assertEqual(len(departures), 2)
        for room_name in rooms:
            self.assertEqual(len(self.pm.get_players_in_room(room_name)), 0)
|
|
|
|
|
|
class TestSessionIsolation(unittest.TestCase):
    """Test 1: Session isolation — 2 users, verify no cross-contamination."""

    def setUp(self):
        # Fresh SessionManager per test.
        self.sm = bridge.SessionManager(max_sessions=10, session_timeout=3600)
        # Stub UserSession._init_agent with a no-op to avoid real agent imports.
        self.patcher = patch.object(bridge.UserSession, "_init_agent", lambda self: None)
        self.patcher.start()

    def tearDown(self):
        self.patcher.stop()

    def test_two_users_isolated_messages(self):
        s1 = self.sm.get_or_create("u1", "Alice", "Lobby")
        s2 = self.sm.get_or_create("u2", "Bob", "Lobby")

        # Each session must own its own history list. A shared list (e.g. a
        # class-level or mutable-default attribute) would leak messages between users.
        self.assertIsNot(s1.messages, s2.messages)

        # Simulate messages
        s1.messages.append({"role": "user", "content": "Alice secret"})
        s2.messages.append({"role": "user", "content": "Bob secret"})

        # Verify no cross-contamination
        s1_contents = [m["content"] for m in s1.messages]
        s2_contents = [m["content"] for m in s2.messages]

        self.assertIn("Alice secret", s1_contents)
        self.assertNotIn("Alice secret", s2_contents)
        self.assertIn("Bob secret", s2_contents)
        self.assertNotIn("Bob secret", s1_contents)

    def test_two_users_different_sessions(self):
        s1 = self.sm.get_or_create("u1", "Alice", "Lobby")
        s2 = self.sm.get_or_create("u2", "Bob", "Tavern")

        # Identity check, expressed directly rather than by comparing id() values.
        self.assertIsNot(s1, s2)
        self.assertEqual(s1.username, "Alice")
        self.assertEqual(s2.username, "Bob")
        self.assertEqual(s1.room, "Lobby")
        self.assertEqual(s2.room, "Tavern")

    def test_get_or_create_returns_same_session(self):
        s1 = self.sm.get_or_create("u1", "Alice", "Lobby")
        s2 = self.sm.get_or_create("u1", "Alice", "Lobby")
        self.assertIs(s1, s2)
|
|
|
|
|
|
class TestCrisisDetection(unittest.TestCase):
    """Test 2: crisis-handling surface of the bridge (low, moderate, high risk).

    These tests check the static CRISIS_PROTOCOL text and that PresenceManager
    records risky messages intact. Keyword detection itself is expected to
    happen at the agent level, so it is not exercised here.
    """

    def setUp(self):
        self.pm = bridge.PresenceManager()

    @staticmethod
    def _protocol_text():
        """Return all CRISIS_PROTOCOL steps joined into one lowercase string."""
        return " ".join(bridge.CRISIS_PROTOCOL).lower()

    def test_crisis_protocol_defined(self):
        """Crisis protocol must be a list of at least 3 steps."""
        protocol = bridge.CRISIS_PROTOCOL
        self.assertIsInstance(protocol, list)
        self.assertGreaterEqual(len(protocol), 3)
        combined = self._protocol_text()
        # Must reference the 988 lifeline and mention "crisis".
        self.assertIn("988", combined)
        self.assertIn("crisis", combined)

    def test_low_risk_normal_message(self):
        """Low risk: an ordinary say event carries no 'crisis' key."""
        ev = self.pm.say("u1", "Dan", "Lobby", "Hello, how are you?")
        self.assertEqual(ev["type"], "say")
        self.assertNotIn("crisis", ev)

    def test_moderate_risk_keyword_detection(self):
        """Moderate risk: messages containing crisis keywords are recorded intact."""
        crisis_keywords = ["hopeless", "tired of living", "can't go on"]
        for kw in crisis_keywords:
            # subTest reports each failing keyword separately instead of
            # stopping at the first one.
            with self.subTest(keyword=kw):
                ev = self.pm.say("u1", "Dan", "Lobby", f"I feel {kw}")
                self.assertEqual(ev["type"], "say")
                # The message is recorded; detection happens at agent level.
                self.assertIn(kw, ev["message"])

    def test_high_risk_message_recorded(self):
        """High risk: a message with strong crisis indicators is captured verbatim."""
        msg = "I want to end my life tonight"
        ev = self.pm.say("u1", "Dan", "Lobby", msg)
        self.assertEqual(ev["message"], msg)
        # No enter_room was called, so the say event is the only Lobby event.
        events = self.pm.get_room_events("Lobby")
        self.assertEqual(len(events), 1)

    def test_crisis_protocol_has_grounding(self):
        """Crisis protocol must reference a grounding exercise."""
        self.assertIn("grounding", self._protocol_text())
|
|
|
|
|
|
class TestRoomAwareness(unittest.TestCase):
    """Test 3: each user's current room is tracked by sessions and by presence."""

    def setUp(self):
        self.sm = bridge.SessionManager(max_sessions=10, session_timeout=3600)
        self.pm = bridge.PresenceManager()
        # Stub agent initialization so sessions can be created in isolation.
        self.patcher = patch.object(bridge.UserSession, "_init_agent", lambda self: None)
        self.patcher.start()

    def tearDown(self):
        self.patcher.stop()

    def test_session_tracks_room(self):
        self.assertEqual(self.sm.get_or_create("u1", "Alice", "Dark Cave").room, "Dark Cave")

    def test_room_update_on_move(self):
        first = self.sm.get_or_create("u1", "Alice", "Lobby")
        self.assertEqual(first.room, "Lobby")

        # Asking again with a different room reuses the session and moves it.
        moved = self.sm.get_or_create("u1", "Alice", "Tower Top")
        self.assertIs(first, moved)
        self.assertEqual(moved.room, "Tower Top")

    def test_presence_tracks_multiple_rooms(self):
        placements = (
            ("u1", "Alice", "Lobby"),
            ("u2", "Bob", "Tavern"),
            ("u3", "Carol", "Lobby"),
        )
        for uid, name, room_name in placements:
            self.pm.enter_room(uid, name, room_name)

        def names_in(room_name):
            return {player["username"] for player in self.pm.get_players_in_room(room_name)}

        lobby, tavern = names_in("Lobby"), names_in("Tavern")
        self.assertEqual(lobby, {"Alice", "Carol"})
        self.assertEqual(tavern, {"Bob"})

    def test_context_summary_has_room(self):
        summary = self.sm.get_or_create("u1", "Alice", "Library").get_context_summary()
        self.assertEqual(summary["room"], "Library")
        self.assertEqual(summary["user"], "Alice")
|
|
|
|
|
|
class TestSessionTimeout(unittest.TestCase):
    """Test 4: Session timeout — inactive sessions cleaned up."""

    def setUp(self):
        # Very short timeout (1 second). Staleness is simulated by back-dating
        # last_active rather than by sleeping.
        self.sm = bridge.SessionManager(max_sessions=10, session_timeout=1)
        self.patcher = patch.object(bridge.UserSession, "_init_agent", lambda self: None)
        self.patcher.start()

    def tearDown(self):
        self.patcher.stop()

    def test_stale_session_cleaned_up(self):
        stale = self.sm.get_or_create("u1", "Alice", "Lobby")
        self.assertEqual(self.sm.get_session_count(), 1)

        # Simulate the session being idle for 5 seconds, well past the 1s timeout.
        stale.last_active = time.time() - 5

        # Next get_or_create triggers cleanup
        self.sm.get_or_create("u2", "Bob", "Lobby")
        self.assertEqual(self.sm.get_session_count(), 1)
        self.assertNotIn("u1", self.sm.sessions)
        self.assertIn("u2", self.sm.sessions)

    def test_active_session_not_cleaned(self):
        self.sm.get_or_create("u1", "Alice", "Lobby")
        # u1 was just created, so it is inside the timeout window and must survive.
        self.sm.get_or_create("u2", "Bob", "Lobby")
        self.assertEqual(self.sm.get_session_count(), 2)
        self.assertIn("u1", self.sm.sessions)
        self.assertIn("u2", self.sm.sessions)

    def test_cleanup_stale_direct(self):
        s1 = self.sm.get_or_create("u1", "Alice", "Lobby")
        s2 = self.sm.get_or_create("u2", "Bob", "Lobby")
        self.assertEqual(self.sm.get_session_count(), 2)

        # Age both sessions 10 seconds into the past.
        stale_time = time.time() - 10
        for session in (s1, s2):
            session.last_active = stale_time

        # Trigger cleanup: only the newly created u3 should remain.
        self.sm.get_or_create("u3", "Carol", "Lobby")
        self.assertEqual(self.sm.get_session_count(), 1)
        self.assertIn("u3", self.sm.sessions)
|
|
|
|
|
|
class TestMaxSessions(unittest.TestCase):
    """Test 5: Max sessions limit — eviction works."""

    def setUp(self):
        self.sm = bridge.SessionManager(max_sessions=3, session_timeout=3600)
        self.patcher = patch.object(bridge.UserSession, "_init_agent", lambda self: None)
        self.patcher.start()

    def tearDown(self):
        self.patcher.stop()

    def _fill_to_capacity(self):
        """Create u1..u3 (the max of 3), with u1 the least recently active.

        last_active is back-dated explicitly (u1 oldest, u3 newest) instead of
        sleeping between creations, so the ordering does not depend on clock
        resolution. All ages stay far below the 3600s timeout, so no session
        is cleaned up as stale.
        NOTE(review): assumes eviction ranks sessions by last_active. If it
        follows insertion order instead, u1 is still first, so the tests'
        expectations hold either way.
        """
        now = time.time()
        for uid, name, age in (("u1", "Alice", 30), ("u2", "Bob", 20), ("u3", "Carol", 10)):
            session = self.sm.get_or_create(uid, name, "Lobby")
            session.last_active = now - age
        self.assertEqual(self.sm.get_session_count(), 3)

    def test_eviction_at_max(self):
        self._fill_to_capacity()

        # Adding 4th should evict oldest (u1, least recently active)
        self.sm.get_or_create("u4", "Dan", "Lobby")
        self.assertEqual(self.sm.get_session_count(), 3)
        self.assertNotIn("u1", self.sm.sessions)
        for uid in ("u2", "u3", "u4"):
            self.assertIn(uid, self.sm.sessions)

    def test_eviction_oldest_first(self):
        self._fill_to_capacity()

        # u1 is oldest (least recently active), so it is the one evicted.
        self.sm.get_or_create("u4", "Dan", "Lobby")
        self.assertNotIn("u1", self.sm.sessions)
        self.assertIn("u4", self.sm.sessions)
        self.assertEqual(self.sm.get_session_count(), 3)

    def test_no_eviction_under_limit(self):
        self.sm.get_or_create("u1", "Alice", "Lobby")
        self.sm.get_or_create("u2", "Bob", "Lobby")
        self.assertEqual(self.sm.get_session_count(), 2)

    def test_list_sessions(self):
        self.sm.get_or_create("u1", "Alice", "Lobby")
        self.sm.get_or_create("u2", "Bob", "Tavern")
        sessions = self.sm.list_sessions()
        self.assertEqual(len(sessions), 2)
        users = {s["user"] for s in sessions}
        self.assertEqual(users, {"Alice", "Bob"})
|
|
|
|
|
|
class TestBridgeHTTPAPI(unittest.TestCase):
    """Integration tests: start the bridge server and hit endpoints."""

    # Sentinel marking a bridge global that did not exist before setUpClass.
    _MISSING = object()
    # Module-level bridge globals this class overrides and restores afterwards.
    _OVERRIDDEN_GLOBALS = ("BRIDGE_PORT", "BRIDGE_HOST", "session_manager", "presence_manager")

    @classmethod
    def setUpClass(cls):
        # Bind to port 0 so the OS assigns a free port. The previous
        # pid-derived port could collide with another process or a parallel run.
        # Assumes bridge.HTTPServer is the stdlib (socketserver-based)
        # HTTPServer, whose server_address reports the actual bound port.
        cls.server = bridge.HTTPServer(("127.0.0.1", 0), bridge.BridgeHandler)
        cls.port = cls.server.server_address[1]

        # Remember the globals we are about to override so tearDownClass can
        # restore them and later tests see the original module state.
        cls._saved_globals = {
            name: getattr(bridge, name, cls._MISSING) for name in cls._OVERRIDDEN_GLOBALS
        }
        bridge.BRIDGE_PORT = cls.port
        bridge.BRIDGE_HOST = "127.0.0.1"

        # Reset global managers
        bridge.session_manager = bridge.SessionManager(max_sessions=5, session_timeout=3600)
        bridge.presence_manager = bridge.PresenceManager()

        # Patch UserSession to avoid real AIAgent
        cls.patcher = patch.object(bridge.UserSession, "_init_agent", lambda self: None)
        cls.patcher.start()

        # Serve in a background thread. The stdlib server's constructor already
        # binds and listens, so early requests queue until serve_forever picks
        # them up. No startup sleep is needed.
        cls.server_thread = threading.Thread(target=cls.server.serve_forever, daemon=True)
        cls.server_thread.start()

    @classmethod
    def tearDownClass(cls):
        # Stop the serve_forever loop, then release the listening socket.
        cls.server.shutdown()
        cls.server.server_close()
        cls.server_thread.join(timeout=5)
        cls.patcher.stop()
        # Put back the bridge globals overridden in setUpClass.
        for name, value in cls._saved_globals.items():
            if value is cls._MISSING:
                delattr(bridge, name)
            else:
                setattr(bridge, name, value)

    def _request(self, method, path, body=None, headers=None):
        """Send one request to the test server and return (status, decoded JSON body).

        The connection is always closed. The timeout keeps a wedged server from
        hanging the whole suite.
        """
        conn = HTTPConnection("127.0.0.1", self.port, timeout=10)
        try:
            conn.request(method, path, body=body, headers=headers or {})
            resp = conn.getresponse()
            return resp.status, json.loads(resp.read())
        finally:
            conn.close()

    def _post(self, path, data):
        return self._request(
            "POST",
            path,
            body=json.dumps(data),
            headers={"Content-Type": "application/json"},
        )

    def _get(self, path):
        return self._request("GET", path)

    def test_health(self):
        status, data = self._get("/bridge/health")
        self.assertEqual(status, 200)
        self.assertEqual(data["status"], "ok")

    def test_say_endpoint(self):
        status, data = self._post("/bridge/say", {
            "user_id": "test_u1",
            "username": "Tester",
            "message": "Hello room!",
            "room": "TestRoom",
        })
        self.assertEqual(status, 200)
        self.assertTrue(data["ok"])
        self.assertEqual(data["event"]["message"], "Hello room!")

    def test_room_players_endpoint(self):
        # Ensure a player is present
        bridge.presence_manager.enter_room("http_u1", "HttpAlice", "HttpRoom")
        status, data = self._get("/bridge/room/HttpRoom/players")
        self.assertEqual(status, 200)
        names = [p["username"] for p in data["players"]]
        self.assertIn("HttpAlice", names)

    def test_room_events_endpoint(self):
        bridge.presence_manager.enter_room("ev_u1", "EvAlice", "EvRoom")
        bridge.presence_manager.say("ev_u1", "EvAlice", "EvRoom", "test message")
        status, data = self._get("/bridge/room/EvRoom/events")
        self.assertEqual(status, 200)
        self.assertGreater(len(data["events"]), 0)

    def test_sessions_endpoint(self):
        status, data = self._get("/bridge/sessions")
        self.assertEqual(status, 200)
        self.assertIn("sessions", data)

    def test_not_found(self):
        status, data = self._get("/bridge/nonexistent")
        self.assertEqual(status, 404)
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly; unittest.main() collects and
    # runs the TestCase classes defined in this module.
    unittest.main()
|