Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 08f1d0bc8d | |||
| 42a9f6366c |
189
gateway/message_dedup.py
Normal file
189
gateway/message_dedup.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
"""
|
||||||
|
Gateway Message Deduplication — Prevent double-posting.
|
||||||
|
|
||||||
|
Provides idempotent message delivery by tracking message UUIDs
|
||||||
|
and suppressing duplicates within a configurable time window.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Dict, Optional, Set
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MessageRecord:
|
||||||
|
"""Record of a sent message."""
|
||||||
|
message_id: str
|
||||||
|
content_hash: str
|
||||||
|
timestamp: float
|
||||||
|
session_id: str
|
||||||
|
platform: str
|
||||||
|
|
||||||
|
|
||||||
|
class MessageDeduplicator:
|
||||||
|
"""
|
||||||
|
Deduplicates outbound messages within a time window.
|
||||||
|
|
||||||
|
Each message gets a UUID. If the same message (by content hash)
|
||||||
|
is sent again within the window, it's suppressed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, window_seconds: int = 60, max_records: int = 1000):
|
||||||
|
"""
|
||||||
|
Initialize deduplicator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
window_seconds: Time window for deduplication (default 60s)
|
||||||
|
max_records: Maximum records to keep in memory
|
||||||
|
"""
|
||||||
|
self.window_seconds = window_seconds
|
||||||
|
self.max_records = max_records
|
||||||
|
self._records: OrderedDict[str, MessageRecord] = OrderedDict()
|
||||||
|
self._suppressed_count = 0
|
||||||
|
|
||||||
|
def _content_hash(self, content: str, session_id: str = "", platform: str = "") -> str:
|
||||||
|
"""Generate hash for message content."""
|
||||||
|
combined = f"{session_id}:{platform}:{content}"
|
||||||
|
return hashlib.sha256(combined.encode()).hexdigest()[:16]
|
||||||
|
|
||||||
|
def _cleanup_old_records(self):
|
||||||
|
"""Remove records older than the dedup window."""
|
||||||
|
cutoff = time.time() - self.window_seconds
|
||||||
|
to_remove = []
|
||||||
|
|
||||||
|
for msg_id, record in self._records.items():
|
||||||
|
if record.timestamp < cutoff:
|
||||||
|
to_remove.append(msg_id)
|
||||||
|
|
||||||
|
for msg_id in to_remove:
|
||||||
|
del self._records[msg_id]
|
||||||
|
|
||||||
|
def _enforce_max_records(self):
|
||||||
|
"""Enforce maximum record count by removing oldest."""
|
||||||
|
while len(self._records) > self.max_records:
|
||||||
|
self._records.popitem(last=False)
|
||||||
|
|
||||||
|
def check_duplicate(self, content: str, session_id: str = "", platform: str = "") -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Check if message is a duplicate.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Message content
|
||||||
|
session_id: Session identifier
|
||||||
|
platform: Platform name (telegram, discord, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Message ID if duplicate found, None if new message
|
||||||
|
"""
|
||||||
|
self._cleanup_old_records()
|
||||||
|
|
||||||
|
content_hash = self._content_hash(content, session_id, platform)
|
||||||
|
|
||||||
|
for msg_id, record in self._records.items():
|
||||||
|
if record.content_hash == content_hash:
|
||||||
|
age = time.time() - record.timestamp
|
||||||
|
if age < self.window_seconds:
|
||||||
|
self._suppressed_count += 1
|
||||||
|
logger.info(
|
||||||
|
"Suppressed duplicate message (age: %.1fs, original: %s)",
|
||||||
|
age, msg_id
|
||||||
|
)
|
||||||
|
return msg_id
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def record_message(self, content: str, session_id: str = "", platform: str = "") -> str:
|
||||||
|
"""
|
||||||
|
Record a sent message and return its UUID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Message content
|
||||||
|
session_id: Session identifier
|
||||||
|
platform: Platform name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UUID for this message
|
||||||
|
"""
|
||||||
|
self._cleanup_old_records()
|
||||||
|
|
||||||
|
message_id = str(uuid.uuid4())
|
||||||
|
content_hash = self._content_hash(content, session_id, platform)
|
||||||
|
|
||||||
|
self._records[message_id] = MessageRecord(
|
||||||
|
message_id=message_id,
|
||||||
|
content_hash=content_hash,
|
||||||
|
timestamp=time.time(),
|
||||||
|
session_id=session_id,
|
||||||
|
platform=platform,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._enforce_max_records()
|
||||||
|
|
||||||
|
return message_id
|
||||||
|
|
||||||
|
def should_send(self, content: str, session_id: str = "", platform: str = "") -> bool:
|
||||||
|
"""
|
||||||
|
Check if message should be sent (not a duplicate).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Message content
|
||||||
|
session_id: Session identifier
|
||||||
|
platform: Platform name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if message should be sent, False if duplicate
|
||||||
|
"""
|
||||||
|
return self.check_duplicate(content, session_id, platform) is None
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict:
|
||||||
|
"""Get deduplication statistics."""
|
||||||
|
return {
|
||||||
|
"total_records": len(self._records),
|
||||||
|
"suppressed_count": self._suppressed_count,
|
||||||
|
"window_seconds": self.window_seconds,
|
||||||
|
"max_records": self.max_records,
|
||||||
|
}
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
"""Clear all records."""
|
||||||
|
self._records.clear()
|
||||||
|
self._suppressed_count = 0
|
||||||
|
|
||||||
|
|
||||||
|
# Global deduplicator instance
|
||||||
|
_deduplicator: Optional[MessageDeduplicator] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_deduplicator() -> MessageDeduplicator:
|
||||||
|
"""Get or create global deduplicator instance."""
|
||||||
|
global _deduplicator
|
||||||
|
if _deduplicator is None:
|
||||||
|
_deduplicator = MessageDeduplicator()
|
||||||
|
return _deduplicator
|
||||||
|
|
||||||
|
|
||||||
|
def deduplicate_message(content: str, session_id: str = "", platform: str = "") -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Check if message is duplicate. Returns message_id if duplicate, None if new.
|
||||||
|
"""
|
||||||
|
return get_deduplicator().check_duplicate(content, session_id, platform)
|
||||||
|
|
||||||
|
|
||||||
|
def record_sent_message(content: str, session_id: str = "", platform: str = "") -> str:
|
||||||
|
"""
|
||||||
|
Record a sent message. Returns UUID for the message.
|
||||||
|
"""
|
||||||
|
return get_deduplicator().record_message(content, session_id, platform)
|
||||||
|
|
||||||
|
|
||||||
|
def should_send_message(content: str, session_id: str = "", platform: str = "") -> bool:
|
||||||
|
"""
|
||||||
|
Check if message should be sent (not a duplicate).
|
||||||
|
"""
|
||||||
|
return get_deduplicator().should_send(content, session_id, platform)
|
||||||
57
tests/test_message_dedup.py
Normal file
57
tests/test_message_dedup.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
"""
|
||||||
|
Tests for message deduplication (#756).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import time
|
||||||
|
from gateway.message_dedup import MessageDeduplicator
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageDeduplicator:
|
||||||
|
def test_first_message_allowed(self):
|
||||||
|
dedup = MessageDeduplicator()
|
||||||
|
assert dedup.should_send("Hello") is True
|
||||||
|
|
||||||
|
def test_duplicate_suppressed(self):
|
||||||
|
dedup = MessageDeduplicator()
|
||||||
|
dedup.record_message("Hello", "session1", "telegram")
|
||||||
|
assert dedup.should_send("Hello", "session1", "telegram") is False
|
||||||
|
|
||||||
|
def test_different_session_allowed(self):
|
||||||
|
dedup = MessageDeduplicator()
|
||||||
|
dedup.record_message("Hello", "session1", "telegram")
|
||||||
|
assert dedup.should_send("Hello", "session2", "telegram") is True
|
||||||
|
|
||||||
|
def test_different_platform_allowed(self):
|
||||||
|
dedup = MessageDeduplicator()
|
||||||
|
dedup.record_message("Hello", "session1", "telegram")
|
||||||
|
assert dedup.should_send("Hello", "session1", "discord") is True
|
||||||
|
|
||||||
|
def test_different_content_allowed(self):
|
||||||
|
dedup = MessageDeduplicator()
|
||||||
|
dedup.record_message("Hello", "session1", "telegram")
|
||||||
|
assert dedup.should_send("World", "session1", "telegram") is True
|
||||||
|
|
||||||
|
def test_window_expiry(self):
|
||||||
|
dedup = MessageDeduplicator(window_seconds=1)
|
||||||
|
dedup.record_message("Hello", "session1", "telegram")
|
||||||
|
time.sleep(1.1)
|
||||||
|
assert dedup.should_send("Hello", "session1", "telegram") is True
|
||||||
|
|
||||||
|
def test_record_returns_uuid(self):
|
||||||
|
dedup = MessageDeduplicator()
|
||||||
|
msg_id = dedup.record_message("Hello")
|
||||||
|
assert msg_id is not None
|
||||||
|
assert len(msg_id) == 36 # UUID format
|
||||||
|
|
||||||
|
def test_stats(self):
|
||||||
|
dedup = MessageDeduplicator()
|
||||||
|
dedup.record_message("Hello")
|
||||||
|
dedup.record_message("Hello") # duplicate
|
||||||
|
stats = dedup.get_stats()
|
||||||
|
assert stats["total_records"] == 1
|
||||||
|
assert stats["suppressed_count"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__])
|
||||||
Reference in New Issue
Block a user