[claude] TES3MP server hardening — multi-player stability & anti-grief (#860) (#1321)
Some checks failed
Tests / lint (push) Has been cancelled
Tests / test (push) Has been cancelled

This commit was merged in pull request #1321.
This commit is contained in:
2026-03-24 02:13:57 +00:00
parent af162f1a80
commit d4e5a5d293
8 changed files with 1624 additions and 0 deletions

View File

@@ -0,0 +1,26 @@
"""TES3MP server hardening — multi-player stability and anti-grief.
Provides:
- ``MultiClientStressRunner`` — concurrent-client stress testing (Phase 8)
- ``QuestArbiter`` — quest-state conflict resolution
- ``AntiGriefPolicy`` — rate limiting and blocked-action enforcement
- ``RecoveryManager`` — crash recovery with state preservation
- ``WorldStateBackup`` — rotating world-state backups
- ``ResourceMonitor`` — CPU/RAM/disk monitoring under load
"""
from infrastructure.world.hardening.anti_grief import AntiGriefPolicy
from infrastructure.world.hardening.backup import WorldStateBackup
from infrastructure.world.hardening.monitor import ResourceMonitor
from infrastructure.world.hardening.quest_arbiter import QuestArbiter
from infrastructure.world.hardening.recovery import RecoveryManager
from infrastructure.world.hardening.stress import MultiClientStressRunner
__all__ = [
"AntiGriefPolicy",
"WorldStateBackup",
"ResourceMonitor",
"QuestArbiter",
"RecoveryManager",
"MultiClientStressRunner",
]

View File

@@ -0,0 +1,147 @@
"""Anti-grief policy for community agent deployments.
Enforces two controls:
1. **Blocked actions** — a configurable set of action names that are
never permitted (e.g. ``destroy``, ``kill_npc``, ``steal``).
2. **Rate limiting** — a sliding-window counter per player that caps the
number of actions in a given time window.
Usage::
policy = AntiGriefPolicy(max_actions_per_window=30, window_seconds=60.0)
result = policy.check("player-01", command)
if result is not None:
# action blocked — return result to the caller
return result
# proceed with the action
"""
from __future__ import annotations
import logging
import time
from collections import defaultdict, deque
from dataclasses import dataclass, field
from datetime import UTC, datetime
from infrastructure.world.types import ActionResult, ActionStatus, CommandInput
logger = logging.getLogger(__name__)

# Actions never permitted in community deployments. Deployment-specific
# additions are merged in via ``AntiGriefPolicy(blocked_actions=...)``.
_DEFAULT_BLOCKED: frozenset[str] = frozenset(
    {
        "destroy",
        "kill_npc",
        "steal",
        "grief",
        "cheat",
        "spawn_item",
    }
)
@dataclass
class ViolationRecord:
    """Record of a single policy violation (blocked action or rate limit).

    Attributes:
        player_id: Player that triggered the violation.
        action: Action name that was attempted.
        reason: Short reason string recorded by the policy.
        timestamp: When the violation occurred (timezone-aware UTC).
    """

    player_id: str
    action: str
    reason: str
    timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
class AntiGriefPolicy:
    """Enforce rate limits and action restrictions for agent deployments.

    Two controls are applied, in order:

    1. A deny-list of action names that are never permitted.
    2. A per-player sliding-window counter capping actions per window.

    Parameters
    ----------
    max_actions_per_window:
        Maximum actions allowed per player inside the sliding window.
    window_seconds:
        Duration of the sliding rate-limit window in seconds.
    blocked_actions:
        Additional action names to block beyond the built-in defaults.
    """

    def __init__(
        self,
        *,
        max_actions_per_window: int = 30,
        window_seconds: float = 60.0,
        blocked_actions: set[str] | None = None,
    ) -> None:
        extra = blocked_actions or set()
        self._max = max_actions_per_window
        self._window = window_seconds
        self._blocked = _DEFAULT_BLOCKED | extra
        # One deque of monotonic timestamps per player (sliding window).
        self._timestamps: dict[str, deque[float]] = defaultdict(deque)
        self._violations: list[ViolationRecord] = []

    # -- public API --------------------------------------------------------
    def check(self, player_id: str, command: CommandInput) -> ActionResult | None:
        """Evaluate *command* for *player_id*.

        Returns ``None`` if the action is permitted, or an ``ActionResult``
        with ``FAILURE`` status if it should be blocked. Callers must
        reject the action when a non-``None`` result is returned.
        """
        denied = self._check_blocked(player_id, command)
        if denied is not None:
            return denied
        return self._check_rate(player_id, command)

    def reset_player(self, player_id: str) -> None:
        """Clear the rate-limit bucket for *player_id* (e.g. on reconnect)."""
        self._timestamps.pop(player_id, None)

    def is_blocked_action(self, action: str) -> bool:
        """Return ``True`` if *action* is in the blocked-action set."""
        return action in self._blocked

    @property
    def violation_count(self) -> int:
        return len(self._violations)

    @property
    def violations(self) -> list[ViolationRecord]:
        return list(self._violations)

    # -- internal ----------------------------------------------------------
    def _check_blocked(
        self, player_id: str, command: CommandInput
    ) -> ActionResult | None:
        """Reject *command* outright when its action is on the deny list."""
        if command.action not in self._blocked:
            return None
        self._record(player_id, command.action, "blocked action type")
        return ActionResult(
            status=ActionStatus.FAILURE,
            message=(
                f"Action '{command.action}' is not permitted "
                "in community deployments."
            ),
        )

    def _check_rate(
        self, player_id: str, command: CommandInput
    ) -> ActionResult | None:
        """Apply the sliding-window limit; record the action if allowed."""
        now = time.monotonic()
        stamps = self._timestamps[player_id]
        # Drop timestamps that have aged out of the window.
        while stamps and now - stamps[0] > self._window:
            stamps.popleft()
        if len(stamps) >= self._max:
            self._record(player_id, command.action, "rate limit exceeded")
            return ActionResult(
                status=ActionStatus.FAILURE,
                message=(
                    f"Rate limit: player '{player_id}' exceeded "
                    f"{self._max} actions per {self._window:.0f}s window."
                ),
            )
        stamps.append(now)
        return None  # Permitted

    def _record(self, player_id: str, action: str, reason: str) -> None:
        entry = ViolationRecord(player_id=player_id, action=action, reason=reason)
        self._violations.append(entry)
        logger.warning(
            "AntiGrief: player=%s action=%s reason=%s",
            player_id,
            action,
            reason,
        )

View File

@@ -0,0 +1,178 @@
"""World-state backup strategy — timestamped files with rotation.
``WorldStateBackup`` writes each backup as a standalone JSON file and
maintains a ``MANIFEST.jsonl`` index for fast listing. Old backups
beyond the retention limit are rotated out automatically.
Usage::
backup = WorldStateBackup("var/backups/", max_backups=10)
record = backup.create(adapter, notes="pre-phase-8 checkpoint")
backup.restore(adapter, record.backup_id)
"""
from __future__ import annotations
import json
import logging
from dataclasses import asdict, dataclass
from datetime import UTC, datetime
from pathlib import Path
from infrastructure.world.adapters.mock import MockWorldAdapter
logger = logging.getLogger(__name__)


@dataclass
class BackupRecord:
    """Metadata entry written to the backup manifest (one JSON line each).

    Attributes:
        backup_id: Unique id; also the backup file's stem on disk.
        timestamp: ISO-8601 creation time (UTC).
        location: World location captured in the backup.
        entity_count: Number of entities captured.
        event_count: Number of events captured.
        size_bytes: Size of the backup file on disk.
        notes: Free-form caller-supplied annotation.
    """

    backup_id: str
    timestamp: str
    location: str
    entity_count: int
    event_count: int
    size_bytes: int = 0
    notes: str = ""
class WorldStateBackup:
"""Timestamped, rotating world-state backups.
Each backup is a JSON file named ``backup_<timestamp>.json`` inside
*backup_dir*. A ``MANIFEST.jsonl`` index tracks all backups for fast
listing and rotation.
Parameters
----------
backup_dir:
Directory where backup files and the manifest are stored.
max_backups:
Maximum number of backup files to retain.
"""
MANIFEST_NAME = "MANIFEST.jsonl"
def __init__(
self,
backup_dir: Path | str,
*,
max_backups: int = 10,
) -> None:
self._dir = Path(backup_dir)
self._dir.mkdir(parents=True, exist_ok=True)
self._max = max_backups
# -- create ------------------------------------------------------------
def create(
self,
adapter: MockWorldAdapter,
*,
notes: str = "",
) -> BackupRecord:
"""Snapshot *adapter* and write a new backup file.
Returns the ``BackupRecord`` describing the backup.
"""
perception = adapter.observe()
ts = datetime.now(UTC).strftime("%Y%m%dT%H%M%S%f")
backup_id = f"backup_{ts}"
payload = {
"backup_id": backup_id,
"timestamp": datetime.now(UTC).isoformat(),
"location": perception.location,
"entities": list(perception.entities),
"events": list(perception.events),
"raw": dict(perception.raw),
"notes": notes,
}
backup_path = self._dir / f"{backup_id}.json"
backup_path.write_text(json.dumps(payload, indent=2))
size = backup_path.stat().st_size
record = BackupRecord(
backup_id=backup_id,
timestamp=payload["timestamp"],
location=perception.location,
entity_count=len(perception.entities),
event_count=len(perception.events),
size_bytes=size,
notes=notes,
)
self._update_manifest(record)
self._rotate()
logger.info(
"WorldStateBackup: created %s (%d bytes)", backup_id, size
)
return record
# -- restore -----------------------------------------------------------
def restore(self, adapter: MockWorldAdapter, backup_id: str) -> bool:
"""Restore *adapter* state from backup *backup_id*.
Returns ``True`` on success, ``False`` if the backup file is missing.
"""
backup_path = self._dir / f"{backup_id}.json"
if not backup_path.exists():
logger.warning("WorldStateBackup: backup %s not found", backup_id)
return False
payload = json.loads(backup_path.read_text())
adapter._location = payload.get("location", "")
adapter._entities = list(payload.get("entities", []))
adapter._events = list(payload.get("events", []))
logger.info("WorldStateBackup: restored from %s", backup_id)
return True
# -- listing -----------------------------------------------------------
def list_backups(self) -> list[BackupRecord]:
"""Return all backup records, most recent first."""
manifest = self._dir / self.MANIFEST_NAME
if not manifest.exists():
return []
records: list[BackupRecord] = []
for line in manifest.read_text().strip().splitlines():
try:
data = json.loads(line)
records.append(BackupRecord(**data))
except (json.JSONDecodeError, TypeError):
continue
return list(reversed(records))
def latest(self) -> BackupRecord | None:
"""Return the most recent backup record, or ``None``."""
backups = self.list_backups()
return backups[0] if backups else None
# -- internal ----------------------------------------------------------
def _update_manifest(self, record: BackupRecord) -> None:
manifest = self._dir / self.MANIFEST_NAME
with manifest.open("a") as f:
f.write(json.dumps(asdict(record)) + "\n")
def _rotate(self) -> None:
"""Remove oldest backups when over the retention limit."""
backups = self.list_backups() # most recent first
if len(backups) <= self._max:
return
to_remove = backups[self._max :]
for rec in to_remove:
path = self._dir / f"{rec.backup_id}.json"
try:
path.unlink(missing_ok=True)
logger.debug("WorldStateBackup: rotated out %s", rec.backup_id)
except OSError as exc:
logger.warning(
"WorldStateBackup: could not remove %s: %s", path, exc
)
# Rewrite manifest with only the retained backups
keep = backups[: self._max]
manifest = self._dir / self.MANIFEST_NAME
manifest.write_text(
"\n".join(json.dumps(asdict(r)) for r in reversed(keep)) + "\n"
)

View File

@@ -0,0 +1,196 @@
"""Resource monitoring — CPU, RAM, and disk usage under load.
``ResourceMonitor`` collects lightweight resource snapshots. When
``psutil`` is installed it uses richer per-process metrics; otherwise it
falls back to stdlib primitives (``shutil.disk_usage``, ``os.getloadavg``).
Usage::
monitor = ResourceMonitor()
monitor.sample() # single reading
monitor.sample_n(10, interval_s=0.5) # 10 readings, 0.5 s apart
print(monitor.summary())
"""
from __future__ import annotations
import logging
import os
import shutil
import time
from dataclasses import dataclass
from datetime import UTC, datetime
logger = logging.getLogger(__name__)


@dataclass
class ResourceSnapshot:
    """Point-in-time resource usage reading.

    Attributes:
        timestamp: ISO-8601 timestamp.
        cpu_percent: CPU usage 0-100; ``-1`` if unavailable.
        memory_used_mb: Resident memory in MiB; ``-1`` if unavailable.
        memory_total_mb: Total system memory in MiB; ``-1`` if unavailable.
        disk_used_gb: Disk used for the watched path in GiB.
        disk_total_gb: Total disk for the watched path in GiB.
        load_avg_1m: 1-minute load average; ``-1`` on Windows.
    """

    timestamp: str
    cpu_percent: float = -1.0
    memory_used_mb: float = -1.0
    memory_total_mb: float = -1.0
    disk_used_gb: float = -1.0
    disk_total_gb: float = -1.0
    load_avg_1m: float = -1.0
class ResourceMonitor:
"""Lightweight resource monitor for multi-agent load testing.
Captures ``ResourceSnapshot`` readings and retains the last
*max_history* entries. Uses ``psutil`` when available, with a
graceful fallback to stdlib primitives.
Parameters
----------
max_history:
Maximum number of snapshots retained in memory.
watch_path:
Filesystem path used for disk-usage measurement.
"""
def __init__(
self,
*,
max_history: int = 100,
watch_path: str = ".",
) -> None:
self._max = max_history
self._watch = watch_path
self._history: list[ResourceSnapshot] = []
self._psutil = self._try_import_psutil()
# -- public API --------------------------------------------------------
def sample(self) -> ResourceSnapshot:
"""Take a single resource snapshot and add it to history."""
snap = self._collect()
self._history.append(snap)
if len(self._history) > self._max:
self._history = self._history[-self._max :]
return snap
def sample_n(
self,
n: int,
*,
interval_s: float = 0.1,
) -> list[ResourceSnapshot]:
"""Take *n* samples spaced *interval_s* seconds apart.
Useful for profiling resource usage during a stress test run.
"""
results: list[ResourceSnapshot] = []
for i in range(n):
results.append(self.sample())
if i < n - 1:
time.sleep(interval_s)
return results
@property
def history(self) -> list[ResourceSnapshot]:
return list(self._history)
def peak_cpu(self) -> float:
"""Return the highest cpu_percent seen, or ``-1`` if no samples."""
valid = [s.cpu_percent for s in self._history if s.cpu_percent >= 0]
return max(valid) if valid else -1.0
def peak_memory_mb(self) -> float:
"""Return the highest memory_used_mb seen, or ``-1`` if no samples."""
valid = [s.memory_used_mb for s in self._history if s.memory_used_mb >= 0]
return max(valid) if valid else -1.0
def summary(self) -> str:
"""Human-readable summary of recorded resource snapshots."""
if not self._history:
return "ResourceMonitor: no samples collected"
return (
f"ResourceMonitor: {len(self._history)} samples — "
f"peak CPU {self.peak_cpu():.1f}%, "
f"peak RAM {self.peak_memory_mb():.1f} MiB"
)
# -- internal ----------------------------------------------------------
def _collect(self) -> ResourceSnapshot:
ts = datetime.now(UTC).isoformat()
# Disk (always available via stdlib)
try:
usage = shutil.disk_usage(self._watch)
disk_used_gb = round((usage.total - usage.free) / (1024**3), 3)
disk_total_gb = round(usage.total / (1024**3), 3)
except OSError:
disk_used_gb = -1.0
disk_total_gb = -1.0
# Load average (POSIX only)
try:
load_avg_1m = round(os.getloadavg()[0], 3)
except AttributeError:
load_avg_1m = -1.0 # Windows
if self._psutil:
return self._collect_psutil(ts, disk_used_gb, disk_total_gb, load_avg_1m)
return ResourceSnapshot(
timestamp=ts,
disk_used_gb=disk_used_gb,
disk_total_gb=disk_total_gb,
load_avg_1m=load_avg_1m,
)
def _collect_psutil(
self,
ts: str,
disk_used_gb: float,
disk_total_gb: float,
load_avg_1m: float,
) -> ResourceSnapshot:
psutil = self._psutil
try:
cpu = round(psutil.cpu_percent(interval=None), 2)
except Exception:
cpu = -1.0
try:
vm = psutil.virtual_memory()
mem_used = round(vm.used / (1024**2), 2)
mem_total = round(vm.total / (1024**2), 2)
except Exception:
mem_used = -1.0
mem_total = -1.0
return ResourceSnapshot(
timestamp=ts,
cpu_percent=cpu,
memory_used_mb=mem_used,
memory_total_mb=mem_total,
disk_used_gb=disk_used_gb,
disk_total_gb=disk_total_gb,
load_avg_1m=load_avg_1m,
)
@staticmethod
def _try_import_psutil():
try:
import psutil
return psutil
except ImportError:
logger.debug(
"ResourceMonitor: psutil not available — using stdlib fallback"
)
return None

View File

@@ -0,0 +1,178 @@
"""Quest state conflict resolution for multi-player sessions.
When multiple agents attempt to advance the same quest simultaneously
the arbiter serialises access via a per-quest lock, records the
authoritative state, and rejects conflicting updates with a logged
``ConflictRecord``. First-come-first-served semantics are used.
"""
from __future__ import annotations
import logging
import threading
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import StrEnum
logger = logging.getLogger(__name__)


class QuestStage(StrEnum):
    """Canonical quest progression stages.

    ``StrEnum`` members compare equal to their lowercase string values,
    so stages serialise cleanly into logs and JSON.
    """

    AVAILABLE = "available"
    ACTIVE = "active"
    COMPLETED = "completed"
    FAILED = "failed"
@dataclass
class QuestLock:
    """Lock held by a player on a quest.

    Attributes:
        player_id: Player currently holding the lock.
        quest_id: Quest the lock applies to.
        stage: Stage recorded at the most recent claim/advance.
        acquired_at: When the lock was granted (timezone-aware UTC).
    """

    player_id: str
    quest_id: str
    stage: QuestStage
    acquired_at: datetime = field(default_factory=lambda: datetime.now(UTC))
@dataclass
class ConflictRecord:
    """Record of a detected quest-state conflict.

    Attributes:
        quest_id: Quest that was contested.
        winner: Player retaining the lock (the earlier claimant).
        loser: Player whose claim was rejected.
        resolution: Human-readable description of the outcome.
        timestamp: When the conflict was detected (timezone-aware UTC).
    """

    quest_id: str
    winner: str
    loser: str
    resolution: str
    timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
class QuestArbiter:
    """Serialise quest progression across multiple concurrent agents.

    The first player to ``claim`` a quest holds the authoritative lock.
    Subsequent claimants are rejected — their attempt is recorded in
    ``conflicts`` for audit purposes (first-come-first-served).

    Thread-safe: all mutations are protected by an internal lock.
    """

    def __init__(self) -> None:
        # quest_id -> authoritative lock
        self._locks: dict[str, QuestLock] = {}
        self._conflicts: list[ConflictRecord] = []
        self._mu = threading.Lock()

    # -- public API --------------------------------------------------------
    def claim(self, player_id: str, quest_id: str, stage: QuestStage) -> bool:
        """Attempt to claim *quest_id* for *player_id* at *stage*.

        Returns ``True`` if the claim was granted (no existing lock, or same
        player updating their own lock), ``False`` on conflict.
        """
        with self._mu:
            existing = self._locks.get(quest_id)
            if existing is None:
                self._locks[quest_id] = QuestLock(
                    player_id=player_id,
                    quest_id=quest_id,
                    stage=stage,
                )
                logger.info(
                    "QuestArbiter: %s claimed '%s' at stage %s",
                    player_id,
                    quest_id,
                    stage,
                )
                return True
            if existing.player_id == player_id:
                # Same player refreshing their own lock — not a conflict.
                existing.stage = stage
                return True
            # Conflict: different player already holds the lock
            conflict = ConflictRecord(
                quest_id=quest_id,
                winner=existing.player_id,
                loser=player_id,
                resolution=(
                    f"first-come-first-served; {existing.player_id} retains lock"
                ),
            )
            self._conflicts.append(conflict)
            # BUG FIX: the format string was "conflict on '%s'%s rejected",
            # which ran quest_id and the rejected player together with no
            # separator in the log output.
            logger.warning(
                "QuestArbiter: conflict on '%s' — %s rejected (held by %s)",
                quest_id,
                player_id,
                existing.player_id,
            )
            return False

    def release(self, player_id: str, quest_id: str) -> bool:
        """Release *player_id*'s lock on *quest_id*.

        Returns ``True`` if released, ``False`` if the player didn't hold it.
        """
        with self._mu:
            lock = self._locks.get(quest_id)
            if lock is not None and lock.player_id == player_id:
                del self._locks[quest_id]
                logger.info("QuestArbiter: %s released '%s'", player_id, quest_id)
                return True
            return False

    def advance(
        self,
        player_id: str,
        quest_id: str,
        new_stage: QuestStage,
    ) -> bool:
        """Advance a quest the player already holds to *new_stage*.

        Returns ``True`` on success. Locks for COMPLETED/FAILED stages are
        automatically released after the advance.
        """
        with self._mu:
            lock = self._locks.get(quest_id)
            if lock is None or lock.player_id != player_id:
                logger.warning(
                    "QuestArbiter: %s cannot advance '%s' — not the lock holder",
                    player_id,
                    quest_id,
                )
                return False
            lock.stage = new_stage
            logger.info(
                "QuestArbiter: %s advanced '%s' to %s",
                player_id,
                quest_id,
                new_stage,
            )
            # Terminal stages free the quest for other players.
            if new_stage in (QuestStage.COMPLETED, QuestStage.FAILED):
                del self._locks[quest_id]
            return True

    def get_stage(self, quest_id: str) -> QuestStage | None:
        """Return the authoritative stage for *quest_id*, or ``None``."""
        with self._mu:
            lock = self._locks.get(quest_id)
            return lock.stage if lock else None

    def lock_holder(self, quest_id: str) -> str | None:
        """Return the player_id holding the lock for *quest_id*, or ``None``."""
        with self._mu:
            lock = self._locks.get(quest_id)
            return lock.player_id if lock else None

    @property
    def active_lock_count(self) -> int:
        with self._mu:
            return len(self._locks)

    @property
    def conflict_count(self) -> int:
        return len(self._conflicts)

    @property
    def conflicts(self) -> list[ConflictRecord]:
        return list(self._conflicts)

View File

@@ -0,0 +1,184 @@
"""Crash recovery with world-state preservation.
``RecoveryManager`` takes periodic snapshots of a ``MockWorldAdapter``'s
state and persists them to a JSONL file. On restart, the last clean
snapshot can be loaded to rebuild adapter state and minimise data loss.
Usage::
mgr = RecoveryManager("var/recovery.jsonl")
snap = mgr.snapshot(adapter) # save state
...
mgr.restore(adapter) # restore latest on restart
"""
from __future__ import annotations
import json
import logging
from dataclasses import asdict, dataclass, field
from datetime import UTC, datetime
from pathlib import Path
from infrastructure.world.adapters.mock import MockWorldAdapter
logger = logging.getLogger(__name__)


@dataclass
class WorldSnapshot:
    """Serialisable snapshot of a world adapter's state.

    Attributes:
        snapshot_id: Unique identifier (UTC timestamp string by default).
        timestamp: ISO-8601 string of when the snapshot was taken.
        location: World location at snapshot time.
        entities: Entities present at snapshot time.
        events: Recent events at snapshot time.
        metadata: Arbitrary extra payload from the adapter's ``raw`` field.
    """

    snapshot_id: str
    timestamp: str
    location: str = ""
    entities: list[str] = field(default_factory=list)
    events: list[str] = field(default_factory=list)
    metadata: dict = field(default_factory=dict)
class RecoveryManager:
"""Snapshot-based crash recovery for world adapters.
Snapshots are appended to a JSONL file; the most recent entry is
used when restoring. Old snapshots beyond *max_snapshots* are
trimmed automatically.
Parameters
----------
state_path:
Path to the JSONL file where snapshots are stored.
max_snapshots:
Maximum number of snapshots to retain.
"""
def __init__(
self,
state_path: Path | str,
*,
max_snapshots: int = 50,
) -> None:
self._path = Path(state_path)
self._max = max_snapshots
self._path.parent.mkdir(parents=True, exist_ok=True)
# -- snapshot ----------------------------------------------------------
def snapshot(
self,
adapter: MockWorldAdapter,
*,
snapshot_id: str | None = None,
) -> WorldSnapshot:
"""Snapshot *adapter* state and persist to disk.
Returns the ``WorldSnapshot`` that was saved.
"""
perception = adapter.observe()
sid = snapshot_id or datetime.now(UTC).strftime("%Y%m%dT%H%M%S%f")
snap = WorldSnapshot(
snapshot_id=sid,
timestamp=datetime.now(UTC).isoformat(),
location=perception.location,
entities=list(perception.entities),
events=list(perception.events),
metadata=dict(perception.raw),
)
self._append(snap)
logger.info("RecoveryManager: snapshot %s saved to %s", sid, self._path)
return snap
# -- restore -----------------------------------------------------------
def restore(
self,
adapter: MockWorldAdapter,
*,
snapshot_id: str | None = None,
) -> WorldSnapshot | None:
"""Restore *adapter* from a snapshot.
Parameters
----------
snapshot_id:
If given, restore from that specific snapshot ID.
Otherwise restore from the most recent snapshot.
Returns the ``WorldSnapshot`` used to restore, or ``None`` if none found.
"""
history = self.load_history()
if not history:
logger.warning("RecoveryManager: no snapshots found at %s", self._path)
return None
if snapshot_id is None:
snap_data = history[0] # most recent
else:
snap_data = next(
(s for s in history if s["snapshot_id"] == snapshot_id),
None,
)
if snap_data is None:
logger.warning("RecoveryManager: snapshot %s not found", snapshot_id)
return None
snap = WorldSnapshot(**snap_data)
adapter._location = snap.location
adapter._entities = list(snap.entities)
adapter._events = list(snap.events)
logger.info("RecoveryManager: restored from snapshot %s", snap.snapshot_id)
return snap
# -- history -----------------------------------------------------------
def load_history(self) -> list[dict]:
"""Return all snapshots as dicts, most recent first."""
if not self._path.exists():
return []
records: list[dict] = []
for line in self._path.read_text().strip().splitlines():
try:
records.append(json.loads(line))
except json.JSONDecodeError:
continue
return list(reversed(records))
def latest(self) -> WorldSnapshot | None:
"""Return the most recent snapshot, or ``None``."""
history = self.load_history()
if not history:
return None
return WorldSnapshot(**history[0])
@property
def snapshot_count(self) -> int:
"""Number of snapshots currently on disk."""
return len(self.load_history())
# -- internal ----------------------------------------------------------
def _append(self, snap: WorldSnapshot) -> None:
with self._path.open("a") as f:
f.write(json.dumps(asdict(snap)) + "\n")
self._trim()
def _trim(self) -> None:
"""Keep only the last *max_snapshots* lines."""
lines = [
ln
for ln in self._path.read_text().strip().splitlines()
if ln.strip()
]
if len(lines) > self._max:
lines = lines[-self._max :]
self._path.write_text("\n".join(lines) + "\n")

View File

@@ -0,0 +1,168 @@
"""Multi-client stress runner — validates 6+ concurrent automated agents.
Runs N simultaneous ``MockWorldAdapter`` instances through heartbeat cycles
concurrently via asyncio and collects per-client results. The runner is
the primary gate for Phase 8 multi-player stability requirements.
"""
from __future__ import annotations
import asyncio
import logging
import time
from dataclasses import dataclass, field
from datetime import UTC, datetime
from infrastructure.world.adapters.mock import MockWorldAdapter
from infrastructure.world.benchmark.scenarios import BenchmarkScenario
from infrastructure.world.types import ActionStatus, CommandInput
logger = logging.getLogger(__name__)


@dataclass
class ClientResult:
    """Result for a single simulated client in a stress run.

    Attributes:
        client_id: Stable identifier of the simulated client.
        cycles_completed: Observe/act cycles finished without raising.
        actions_taken: Actions whose result status was SUCCESS.
        errors: Formatted exception messages, if any.
        wall_time_ms: Wall-clock duration of the client's run.
        success: True when every cycle completed without an exception.
    """

    client_id: str
    cycles_completed: int = 0
    actions_taken: int = 0
    errors: list[str] = field(default_factory=list)
    wall_time_ms: int = 0
    success: bool = False
@dataclass
class StressTestReport:
    """Aggregated report across all simulated clients.

    Attributes:
        client_count: Number of clients launched.
        scenario_name: Name of the scenario that was run.
        results: Per-client outcomes, in launch order.
        total_time_ms: Wall time for the whole run.
        timestamp: ISO-8601 time the run started.
    """

    client_count: int
    scenario_name: str
    results: list[ClientResult] = field(default_factory=list)
    total_time_ms: int = 0
    timestamp: str = ""

    @property
    def success_count(self) -> int:
        """Number of clients that finished all cycles successfully."""
        return sum(1 for r in self.results if r.success)

    @property
    def error_count(self) -> int:
        """Total error messages recorded across all clients."""
        return sum(len(r.errors) for r in self.results)

    @property
    def all_passed(self) -> bool:
        """True when every client succeeded (vacuously true with no results)."""
        return all(r.success for r in self.results)

    def summary(self) -> str:
        """Human-readable multi-line report of the run."""
        lines = [
            f"=== Stress Test: {self.scenario_name} ===",
            f"Clients: {self.client_count} Passed: {self.success_count} "
            f"Errors: {self.error_count} Time: {self.total_time_ms} ms",
        ]
        for r in self.results:
            status = "OK" if r.success else "FAIL"
            # BUG FIX: the adjacent f-strings previously joined client_id and
            # the cycle count with no separator ("client-003 cycles ...").
            lines.append(
                f" [{status}] {r.client_id} — "
                f"{r.cycles_completed} cycles, {r.actions_taken} actions, "
                f"{r.wall_time_ms} ms"
            )
            for err in r.errors:
                lines.append(f" Error: {err}")
        return "\n".join(lines)
class MultiClientStressRunner:
"""Run N concurrent automated clients through a scenario.
Each client gets its own ``MockWorldAdapter`` instance. All clients
run their observe/act cycles concurrently via ``asyncio.gather``.
Parameters
----------
client_count:
Number of simultaneous clients. Must be >= 1.
Phase 8 target is 6+ (see ``MIN_CLIENTS_FOR_PHASE8``).
cycles_per_client:
How many observe→act cycles each client executes.
"""
MIN_CLIENTS_FOR_PHASE8 = 6
def __init__(
self,
*,
client_count: int = 6,
cycles_per_client: int = 5,
) -> None:
if client_count < 1:
raise ValueError("client_count must be >= 1")
self._client_count = client_count
self._cycles = cycles_per_client
@property
def meets_phase8_requirement(self) -> bool:
"""True when client_count >= 6 (Phase 8 multi-player target)."""
return self._client_count >= self.MIN_CLIENTS_FOR_PHASE8
async def run(self, scenario: BenchmarkScenario) -> StressTestReport:
"""Launch all clients concurrently and return the aggregated report."""
report = StressTestReport(
client_count=self._client_count,
scenario_name=scenario.name,
timestamp=datetime.now(UTC).isoformat(),
)
suite_start = time.monotonic()
tasks = [
self._run_client(f"client-{i:02d}", scenario)
for i in range(self._client_count)
]
report.results = list(await asyncio.gather(*tasks))
report.total_time_ms = int((time.monotonic() - suite_start) * 1000)
logger.info(
"StressTest '%s': %d/%d clients passed in %d ms",
scenario.name,
report.success_count,
self._client_count,
report.total_time_ms,
)
return report
async def _run_client(
self,
client_id: str,
scenario: BenchmarkScenario,
) -> ClientResult:
result = ClientResult(client_id=client_id)
adapter = MockWorldAdapter(
location=scenario.start_location,
entities=list(scenario.entities),
events=list(scenario.events),
)
adapter.connect()
start = time.monotonic()
try:
for _ in range(self._cycles):
perception = adapter.observe()
result.cycles_completed += 1
cmd = CommandInput(
action="observe",
parameters={"location": perception.location},
)
action_result = adapter.act(cmd)
if action_result.status == ActionStatus.SUCCESS:
result.actions_taken += 1
# Yield to the event loop between cycles
await asyncio.sleep(0)
result.success = True
except Exception as exc:
msg = f"{type(exc).__name__}: {exc}"
result.errors.append(msg)
logger.warning("StressTest client %s failed: %s", client_id, msg)
finally:
adapter.disconnect()
result.wall_time_ms = int((time.monotonic() - start) * 1000)
return result

View File

@@ -0,0 +1,547 @@
"""Tests for TES3MP server hardening — multi-player stability & anti-grief.
Covers:
- MultiClientStressRunner (Phase 8: 6+ concurrent clients)
- QuestArbiter (conflict resolution)
- AntiGriefPolicy (rate limiting, blocked actions)
- RecoveryManager (snapshot / restore)
- WorldStateBackup (create / restore / rotate)
- ResourceMonitor (sampling, peak, summary)
"""
from __future__ import annotations
import pytest
from infrastructure.world.adapters.mock import MockWorldAdapter
from infrastructure.world.benchmark.scenarios import BenchmarkScenario
from infrastructure.world.hardening.anti_grief import AntiGriefPolicy
from infrastructure.world.hardening.backup import BackupRecord, WorldStateBackup
from infrastructure.world.hardening.monitor import ResourceMonitor, ResourceSnapshot
from infrastructure.world.hardening.quest_arbiter import (
QuestArbiter,
QuestStage,
)
from infrastructure.world.hardening.recovery import RecoveryManager, WorldSnapshot
from infrastructure.world.hardening.stress import (
MultiClientStressRunner,
StressTestReport,
)
from infrastructure.world.types import CommandInput
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# Minimal shared scenario (one entity, one event, short cycle budget) —
# keeps every stress test in this module fast.
_SIMPLE_SCENARIO = BenchmarkScenario(
    name="Stress Smoke",
    description="Minimal scenario for stress testing",
    start_location="Seyda Neen",
    entities=["Guard"],
    events=["player_spawned"],
    max_cycles=3,
    tags=["stress"],
)
# ---------------------------------------------------------------------------
# MultiClientStressRunner
# ---------------------------------------------------------------------------
class TestMultiClientStressRunner:
    """Phase 8 gate: 6+ concurrent clients must run cleanly."""

    def test_phase8_requirement_met(self):
        runner = MultiClientStressRunner(client_count=6)
        assert runner.meets_phase8_requirement is True

    def test_phase8_requirement_not_met(self):
        runner = MultiClientStressRunner(client_count=5)
        assert runner.meets_phase8_requirement is False

    def test_invalid_client_count(self):
        with pytest.raises(ValueError):
            MultiClientStressRunner(client_count=0)

    @pytest.mark.asyncio
    async def test_run_six_clients(self):
        runner = MultiClientStressRunner(client_count=6, cycles_per_client=3)
        report = await runner.run(_SIMPLE_SCENARIO)
        assert isinstance(report, StressTestReport)
        assert report.client_count == 6
        assert len(report.results) == 6
        assert report.all_passed is True

    @pytest.mark.asyncio
    async def test_all_clients_complete_cycles(self):
        runner = MultiClientStressRunner(client_count=6, cycles_per_client=4)
        report = await runner.run(_SIMPLE_SCENARIO)
        for result in report.results:
            assert result.cycles_completed == 4
            assert result.actions_taken == 4
            assert result.errors == []

    @pytest.mark.asyncio
    async def test_report_has_timestamp(self):
        runner = MultiClientStressRunner(client_count=2, cycles_per_client=1)
        report = await runner.run(_SIMPLE_SCENARIO)
        assert report.timestamp

    @pytest.mark.asyncio
    async def test_report_summary_string(self):
        runner = MultiClientStressRunner(client_count=2, cycles_per_client=1)
        report = await runner.run(_SIMPLE_SCENARIO)
        summary = report.summary()
        assert "Stress Smoke" in summary
        assert "Clients: 2" in summary

    @pytest.mark.asyncio
    async def test_single_client(self):
        runner = MultiClientStressRunner(client_count=1, cycles_per_client=2)
        report = await runner.run(_SIMPLE_SCENARIO)
        assert report.success_count == 1

    @pytest.mark.asyncio
    async def test_client_ids_are_unique(self):
        runner = MultiClientStressRunner(client_count=6, cycles_per_client=1)
        report = await runner.run(_SIMPLE_SCENARIO)
        ids = [r.client_id for r in report.results]
        assert len(ids) == len(set(ids))
# ---------------------------------------------------------------------------
# QuestArbiter
# ---------------------------------------------------------------------------
class TestQuestArbiter:
    """Conflict-resolution semantics for concurrent quest claims."""

    def test_first_claim_granted(self):
        # An uncontested claim always succeeds.
        qa = QuestArbiter()
        assert qa.claim("alice", "fargoth_ring", QuestStage.ACTIVE) is True

    def test_conflict_rejected(self):
        qa = QuestArbiter()
        qa.claim("alice", "fargoth_ring", QuestStage.ACTIVE)
        # A second claimant loses while the lock is held.
        assert qa.claim("bob", "fargoth_ring", QuestStage.ACTIVE) is False

    def test_conflict_recorded(self):
        qa = QuestArbiter()
        qa.claim("alice", "fargoth_ring", QuestStage.ACTIVE)
        qa.claim("bob", "fargoth_ring", QuestStage.ACTIVE)
        assert qa.conflict_count == 1
        conflict = qa.conflicts[0]
        assert conflict.winner == "alice"
        assert conflict.loser == "bob"

    def test_same_player_can_update_own_lock(self):
        qa = QuestArbiter()
        qa.claim("alice", "fargoth_ring", QuestStage.ACTIVE)
        # Re-claiming your own lock is an update, not a conflict.
        assert qa.claim("alice", "fargoth_ring", QuestStage.COMPLETED) is True
        assert qa.conflict_count == 0

    def test_release_frees_quest(self):
        qa = QuestArbiter()
        qa.claim("alice", "fargoth_ring", QuestStage.ACTIVE)
        qa.release("alice", "fargoth_ring")
        # Once released, another player may take the quest.
        assert qa.claim("bob", "fargoth_ring", QuestStage.ACTIVE) is True

    def test_release_wrong_player_fails(self):
        qa = QuestArbiter()
        qa.claim("alice", "fargoth_ring", QuestStage.ACTIVE)
        # Only the lock holder may release; the lock survives the attempt.
        assert qa.release("bob", "fargoth_ring") is False
        assert qa.active_lock_count == 1

    def test_advance_updates_stage(self):
        qa = QuestArbiter()
        qa.claim("alice", "fargoth_ring", QuestStage.ACTIVE)
        assert qa.advance("alice", "fargoth_ring", QuestStage.COMPLETED) is True
        # Reaching COMPLETED drops the lock automatically.
        assert qa.active_lock_count == 0

    def test_advance_failed_releases_lock(self):
        qa = QuestArbiter()
        qa.claim("alice", "fargoth_ring", QuestStage.ACTIVE)
        qa.advance("alice", "fargoth_ring", QuestStage.FAILED)
        # FAILED is terminal too, so the lock is dropped.
        assert qa.active_lock_count == 0

    def test_advance_wrong_player_fails(self):
        qa = QuestArbiter()
        qa.claim("alice", "fargoth_ring", QuestStage.ACTIVE)
        assert qa.advance("bob", "fargoth_ring", QuestStage.COMPLETED) is False

    def test_get_stage(self):
        qa = QuestArbiter()
        qa.claim("alice", "fargoth_ring", QuestStage.ACTIVE)
        assert qa.get_stage("fargoth_ring") == QuestStage.ACTIVE

    def test_get_stage_unknown_quest(self):
        assert QuestArbiter().get_stage("nonexistent") is None

    def test_lock_holder(self):
        qa = QuestArbiter()
        qa.claim("alice", "fargoth_ring", QuestStage.ACTIVE)
        assert qa.lock_holder("fargoth_ring") == "alice"

    def test_active_lock_count(self):
        qa = QuestArbiter()
        qa.claim("alice", "quest_a", QuestStage.ACTIVE)
        qa.claim("bob", "quest_b", QuestStage.ACTIVE)
        assert qa.active_lock_count == 2

    def test_multiple_quests_independent(self):
        qa = QuestArbiter()
        qa.claim("alice", "quest_a", QuestStage.ACTIVE)
        # Distinct quests never contend for the same lock.
        assert qa.claim("bob", "quest_b", QuestStage.ACTIVE) is True
        assert qa.conflict_count == 0
# ---------------------------------------------------------------------------
# AntiGriefPolicy
# ---------------------------------------------------------------------------
class TestAntiGriefPolicy:
    """Blocked-action and sliding-window rate-limit enforcement."""

    def test_permitted_action_returns_none(self):
        # None signals "allowed — proceed with the action".
        command = CommandInput(action="move", target="north")
        assert AntiGriefPolicy().check("player-01", command) is None

    def test_blocked_action_rejected(self):
        policy = AntiGriefPolicy()
        command = CommandInput(action="destroy", target="barrel")
        verdict = policy.check("player-01", command)
        assert verdict is not None
        assert "destroy" in verdict.message
        assert policy.violation_count == 1

    def test_custom_blocked_action(self):
        # The blocked set is configurable at construction time.
        policy = AntiGriefPolicy(blocked_actions={"teleport"})
        assert policy.check("player-01", CommandInput(action="teleport")) is not None

    def test_is_blocked_action(self):
        policy = AntiGriefPolicy()
        assert policy.is_blocked_action("kill_npc") is True
        assert policy.is_blocked_action("move") is False

    def test_rate_limit_exceeded(self):
        policy = AntiGriefPolicy(max_actions_per_window=3, window_seconds=60.0)
        move = CommandInput(action="move")
        # The window admits exactly three actions...
        assert all(policy.check("player-01", move) is None for _ in range(3))
        # ...and the fourth is refused.
        verdict = policy.check("player-01", move)
        assert verdict is not None
        assert "Rate limit" in verdict.message

    def test_rate_limit_per_player(self):
        policy = AntiGriefPolicy(max_actions_per_window=2, window_seconds=60.0)
        move = CommandInput(action="move")
        policy.check("player-01", move)
        policy.check("player-01", move)
        # player-01 has used up the window...
        assert policy.check("player-01", move) is not None
        # ...but buckets are tracked per player.
        assert policy.check("player-02", move) is None

    def test_reset_player_clears_bucket(self):
        policy = AntiGriefPolicy(max_actions_per_window=2, window_seconds=60.0)
        move = CommandInput(action="move")
        policy.check("player-01", move)
        policy.check("player-01", move)
        policy.reset_player("player-01")
        # A reset player starts from an empty window again.
        assert policy.check("player-01", move) is None

    def test_violations_list(self):
        policy = AntiGriefPolicy()
        policy.check("player-01", CommandInput(action="steal"))
        assert len(policy.violations) == 1
        violation = policy.violations[0]
        assert violation.player_id == "player-01"
        assert violation.action == "steal"

    def test_all_default_blocked_actions(self):
        policy = AntiGriefPolicy()
        defaults = ("destroy", "kill_npc", "steal", "grief", "cheat", "spawn_item")
        for action in defaults:
            assert policy.is_blocked_action(action), f"{action!r} should be blocked"
# ---------------------------------------------------------------------------
# RecoveryManager
# ---------------------------------------------------------------------------
class TestRecoveryManager:
    """Crash-recovery snapshots persisted to a JSONL history file."""

    @staticmethod
    def _live_adapter(**kwargs):
        # Every test needs a connected mock world; centralise the setup.
        adapter = MockWorldAdapter(**kwargs)
        adapter.connect()
        return adapter

    def test_snapshot_creates_file(self, tmp_path):
        history_path = tmp_path / "recovery.jsonl"
        recovery = RecoveryManager(history_path)
        snap = recovery.snapshot(self._live_adapter(location="Vivec"))
        assert history_path.exists()
        assert snap.location == "Vivec"

    def test_snapshot_returns_world_snapshot(self, tmp_path):
        recovery = RecoveryManager(tmp_path / "recovery.jsonl")
        adapter = self._live_adapter(location="Balmora", entities=["Guard"])
        snap = recovery.snapshot(adapter)
        assert isinstance(snap, WorldSnapshot)
        assert snap.location == "Balmora"
        assert "Guard" in snap.entities

    def test_restore_latest(self, tmp_path):
        recovery = RecoveryManager(tmp_path / "recovery.jsonl")
        adapter = self._live_adapter(location="Seyda Neen")
        recovery.snapshot(adapter)
        adapter._location = "Somewhere Else"
        # Restoring without an id picks the newest snapshot.
        assert recovery.restore(adapter) is not None
        assert adapter._location == "Seyda Neen"

    def test_restore_by_id(self, tmp_path):
        recovery = RecoveryManager(tmp_path / "recovery.jsonl")
        adapter = self._live_adapter(location="Ald'ruhn")
        recovery.snapshot(adapter, snapshot_id="snap-001")
        recovery.snapshot(adapter)  # a newer snapshot must not shadow the id lookup
        adapter._location = "Elsewhere"
        restored = recovery.restore(adapter, snapshot_id="snap-001")
        assert restored is not None
        assert restored.snapshot_id == "snap-001"

    def test_restore_missing_id_returns_none(self, tmp_path):
        recovery = RecoveryManager(tmp_path / "recovery.jsonl")
        adapter = self._live_adapter()
        recovery.snapshot(adapter)
        assert recovery.restore(adapter, snapshot_id="nonexistent") is None

    def test_restore_empty_history_returns_none(self, tmp_path):
        recovery = RecoveryManager(tmp_path / "recovery.jsonl")
        assert recovery.restore(self._live_adapter()) is None

    def test_load_history_most_recent_first(self, tmp_path):
        recovery = RecoveryManager(tmp_path / "recovery.jsonl")
        for i in range(3):
            recovery.snapshot(self._live_adapter(location=f"location-{i}"))
        history = recovery.load_history()
        assert len(history) == 3
        # location-2 was snapshotted last, so it leads the history.
        assert history[0]["location"] == "location-2"

    def test_latest_returns_snapshot(self, tmp_path):
        recovery = RecoveryManager(tmp_path / "recovery.jsonl")
        recovery.snapshot(self._live_adapter(location="Gnisis"))
        newest = recovery.latest()
        assert newest is not None
        assert newest.location == "Gnisis"

    def test_max_snapshots_trim(self, tmp_path):
        recovery = RecoveryManager(tmp_path / "recovery.jsonl", max_snapshots=3)
        for i in range(5):
            recovery.snapshot(self._live_adapter(location=f"loc-{i}"))
        # Oldest entries are trimmed to honour the cap.
        assert recovery.snapshot_count == 3

    def test_snapshot_count(self, tmp_path):
        recovery = RecoveryManager(tmp_path / "recovery.jsonl")
        adapter = self._live_adapter()
        for _ in range(4):
            recovery.snapshot(adapter)
        assert recovery.snapshot_count == 4
# ---------------------------------------------------------------------------
# WorldStateBackup
# ---------------------------------------------------------------------------
class TestWorldStateBackup:
    """Rotating on-disk world-state backups."""

    @staticmethod
    def _live_adapter(**kwargs):
        # Tests need a connected mock world; centralise the setup.
        adapter = MockWorldAdapter(**kwargs)
        adapter.connect()
        return adapter

    def test_create_writes_file(self, tmp_path):
        store = WorldStateBackup(tmp_path / "backups")
        record = store.create(self._live_adapter(location="Tel Vos"))
        # Each backup lands as <backup_id>.json inside the backup directory.
        assert (tmp_path / "backups" / f"{record.backup_id}.json").exists()

    def test_create_returns_record(self, tmp_path):
        store = WorldStateBackup(tmp_path / "backups")
        adapter = self._live_adapter(location="Caldera", entities=["Merchant"])
        record = store.create(adapter, notes="test note")
        assert isinstance(record, BackupRecord)
        assert record.location == "Caldera"
        assert record.entity_count == 1
        assert record.notes == "test note"
        assert record.size_bytes > 0

    def test_restore_from_backup(self, tmp_path):
        store = WorldStateBackup(tmp_path / "backups")
        adapter = self._live_adapter(location="Ald-ruhn")
        record = store.create(adapter)
        adapter._location = "Nowhere"
        assert store.restore(adapter, record.backup_id) is True
        assert adapter._location == "Ald-ruhn"

    def test_restore_missing_backup(self, tmp_path):
        store = WorldStateBackup(tmp_path / "backups")
        assert store.restore(self._live_adapter(), "backup_nonexistent") is False

    def test_list_backups_most_recent_first(self, tmp_path):
        store = WorldStateBackup(tmp_path / "backups")
        adapter = self._live_adapter()
        created = []
        for i in range(3):
            adapter._location = f"loc-{i}"
            created.append(store.create(adapter).backup_id)
        listed = store.list_backups()
        assert len(listed) == 3
        # Newest backup leads the listing.
        assert listed[0].backup_id == created[-1]

    def test_latest_returns_most_recent(self, tmp_path):
        store = WorldStateBackup(tmp_path / "backups")
        adapter = self._live_adapter(location="Vivec")
        store.create(adapter)
        adapter._location = "Molag Mar"
        newest = store.create(adapter)
        assert store.latest().backup_id == newest.backup_id

    def test_empty_list_returns_empty(self, tmp_path):
        store = WorldStateBackup(tmp_path / "backups")
        assert store.list_backups() == []
        assert store.latest() is None

    def test_rotation_removes_oldest(self, tmp_path):
        store = WorldStateBackup(tmp_path / "backups", max_backups=3)
        adapter = self._live_adapter()
        records = [store.create(adapter) for _ in range(5)]
        listed = store.list_backups()
        assert len(listed) == 3
        survivors = {r.backup_id for r in listed}
        # The two oldest backups were rotated out...
        assert records[0].backup_id not in survivors
        assert records[1].backup_id not in survivors
        # ...and the three newest remain.
        for record in records[2:]:
            assert record.backup_id in survivors
# ---------------------------------------------------------------------------
# ResourceMonitor
# ---------------------------------------------------------------------------
class TestResourceMonitor:
    """CPU/RAM/disk sampling with a bounded in-memory history."""

    def test_sample_returns_snapshot(self):
        snap = ResourceMonitor().sample()
        assert isinstance(snap, ResourceSnapshot)
        assert snap.timestamp

    def test_snapshot_has_disk_fields(self):
        snap = ResourceMonitor(watch_path=".").sample()
        # Disk metrics are expected to be reported on every platform.
        assert snap.disk_used_gb >= 0
        assert snap.disk_total_gb > 0

    def test_history_grows(self):
        monitor = ResourceMonitor()
        for _ in range(2):
            monitor.sample()
        assert len(monitor.history) == 2

    def test_history_capped(self):
        monitor = ResourceMonitor(max_history=3)
        for _ in range(5):
            monitor.sample()
        # Oldest samples fall off once max_history is reached.
        assert len(monitor.history) == 3

    def test_sample_n(self):
        batch = ResourceMonitor().sample_n(4, interval_s=0)
        assert len(batch) == 4
        assert all(isinstance(snap, ResourceSnapshot) for snap in batch)

    def test_peak_cpu_no_samples(self):
        # The sentinel -1.0 marks "no data yet".
        assert ResourceMonitor().peak_cpu() == -1.0

    def test_peak_memory_no_samples(self):
        assert ResourceMonitor().peak_memory_mb() == -1.0

    def test_summary_no_samples(self):
        assert "no samples" in ResourceMonitor().summary()

    def test_summary_with_samples(self):
        monitor = ResourceMonitor()
        monitor.sample()
        text = monitor.summary()
        assert "ResourceMonitor" in text
        assert "samples" in text

    def test_history_is_copy(self):
        monitor = ResourceMonitor()
        monitor.sample()
        # Mutating the returned list must not touch internal state.
        monitor.history.clear()
        assert len(monitor.history) == 1
# ---------------------------------------------------------------------------
# Module-level import test
# ---------------------------------------------------------------------------
class TestHardeningModuleImport:
    """The package's advertised public API is importable as documented."""

    def test_all_exports_importable(self):
        from infrastructure.world.hardening import (
            AntiGriefPolicy,
            MultiClientStressRunner,
            QuestArbiter,
            RecoveryManager,
            ResourceMonitor,
            WorldStateBackup,
        )

        exports = (
            AntiGriefPolicy,
            MultiClientStressRunner,
            QuestArbiter,
            RecoveryManager,
            ResourceMonitor,
            WorldStateBackup,
        )
        # Each name must resolve to a real object, not a lazy placeholder.
        assert all(cls is not None for cls in exports)