225 lines
7.0 KiB
Python
225 lines
7.0 KiB
Python
#!/usr/bin/env python3
|
|
"""WebSocket load test harness for Nexus gateway infrastructure."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import math
|
|
import sys
|
|
import time
|
|
import tracemalloc
|
|
from pathlib import Path
|
|
from typing import Any, Callable
|
|
|
|
try:
|
|
import resource
|
|
except ImportError: # pragma: no cover - not expected on Unix CI, but keep portable
|
|
resource = None
|
|
|
|
try:
|
|
import websockets
|
|
except ImportError: # pragma: no cover - tests inject connector
|
|
websockets = None
|
|
|
|
|
|
def percentile(values: list[float], pct: float, ndigits: int = 1) -> float:
    """Return the *pct*-th percentile of *values* using linear interpolation.

    The values are sorted and the fractional rank
    ``(len(values) - 1) * pct / 100`` is computed; when that rank falls
    between two positions the result is interpolated linearly between them.

    Args:
        values: Sample values; need not be sorted. Each is coerced to ``float``.
        pct: Percentile to compute, in the closed range ``[0, 100]``.
        ndigits: Number of decimal places to round the result to.

    Returns:
        The percentile rounded to *ndigits*, or ``0.0`` when *values* is empty.

    Raises:
        ValueError: If *pct* is outside ``[0, 100]`` (or is NaN).
    """
    if not 0.0 <= pct <= 100.0:
        # Out-of-range percentiles would index past the end of the list
        # (IndexError) or wrap around via negative indices (silent garbage).
        raise ValueError(f"pct must be within [0, 100], got {pct!r}")
    if not values:
        return 0.0

    ordered = sorted(float(value) for value in values)
    rank = (len(ordered) - 1) * (pct / 100.0)
    lower = math.floor(rank)
    upper = math.ceil(rank)
    if lower == upper:
        # Rank lands exactly on an element (this also covers a single value).
        return round(ordered[lower], ndigits)

    weight = rank - lower
    interpolated = ordered[lower] + (ordered[upper] - ordered[lower]) * weight
    return round(interpolated, ndigits)
|
|
|
|
|
|
def measure_memory() -> dict[str, int | None]:
    """Take a snapshot of process memory statistics.

    Returns:
        A dict with three keys:

        * ``rss_bytes`` — ``ru_maxrss`` from ``getrusage(RUSAGE_SELF)`` in
          bytes. The value is used as-is on macOS and multiplied by 1024 on
          every other platform. This is the *maximum* resident set size so
          far, not the current one. ``None`` when ``resource`` is unavailable
          or the query fails.
        * ``tracemalloc_current_bytes`` / ``tracemalloc_peak_bytes`` — the
          pair returned by ``tracemalloc.get_traced_memory()``.
    """
    traced_now, traced_peak = tracemalloc.get_traced_memory()

    max_rss = None
    if resource is not None:
        try:
            raw_rss = int(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
        except Exception:
            # Best-effort metric: a failed rusage query must not abort the run.
            max_rss = None
        else:
            # macOS already reports bytes; other platforms are scaled from KiB.
            max_rss = raw_rss * (1 if sys.platform == "darwin" else 1024)

    return {
        "rss_bytes": max_rss,
        "tracemalloc_current_bytes": int(traced_now),
        "tracemalloc_peak_bytes": int(traced_peak),
    }
|
|
|
|
|
|
def write_report(path: str | Path, report: dict[str, Any]) -> None:
    """Write *report* to *path* as indented JSON followed by a newline.

    Missing parent directories are created, so an ``--output`` path inside a
    not-yet-existing directory works. The report is serialized before anything
    touches the filesystem, so an unserializable report does not leave an
    empty directory or a truncated file behind.

    Args:
        path: Destination file path.
        report: JSON-serializable mapping to write.

    Raises:
        TypeError: If *report* contains values ``json.dumps`` cannot encode.
        OSError: If the directory or file cannot be created or written.
    """
    text = json.dumps(report, indent=2) + "\n"
    output = Path(path)
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text(text, encoding="utf-8")
|
|
|
|
|
|
def _normalize_payload(payload: Any) -> str | None:
|
|
if payload is None:
|
|
return None
|
|
if isinstance(payload, str):
|
|
return payload
|
|
return json.dumps(payload)
|
|
|
|
|
|
async def _connect_once(
|
|
url: str,
|
|
hold_seconds: float,
|
|
payload: str | None,
|
|
connector: Callable[[str], Any],
|
|
) -> dict[str, Any]:
|
|
start = time.perf_counter()
|
|
try:
|
|
async with connector(url) as websocket:
|
|
connect_ms = (time.perf_counter() - start) * 1000
|
|
messages_sent = 0
|
|
if payload is not None:
|
|
await websocket.send(payload)
|
|
messages_sent = 1
|
|
if hold_seconds > 0:
|
|
await asyncio.sleep(hold_seconds)
|
|
return {
|
|
"success": True,
|
|
"connect_ms": connect_ms,
|
|
"messages_sent": messages_sent,
|
|
}
|
|
except Exception as exc: # pragma: no cover - exercised in live use
|
|
return {
|
|
"success": False,
|
|
"connect_ms": (time.perf_counter() - start) * 1000,
|
|
"messages_sent": 0,
|
|
"error": str(exc),
|
|
}
|
|
|
|
|
|
async def run_load_test(
    *,
    url: str,
    concurrency: int,
    rounds: int,
    hold_seconds: float = 0.1,
    payload: Any = None,
    connector: Callable[[str], Any] | None = None,
    max_errors: int = 5,
) -> dict[str, Any]:
    """Run *rounds* sequential waves of *concurrency* simultaneous connections.

    Each round launches ``concurrency`` connection attempts at once and waits
    for all of them to finish before the next round starts.

    Args:
        url: WebSocket URL to connect to.
        concurrency: Simultaneous connections per round; must be >= 1.
        rounds: Number of sequential rounds; must be >= 1.
        hold_seconds: How long each connection stays open after connecting.
        payload: Optional message sent once per connection, converted with
            ``_normalize_payload`` before sending.
        connector: Callable ``url -> async context manager``. Defaults to
            ``websockets.connect`` with ``open_timeout=10``.
        max_errors: Maximum number of error messages copied into the report.

    Returns:
        A report dict containing:

        * connection counts and the success rate;
        * connect-latency statistics in milliseconds, over successful
          connections only;
        * total wall time;
        * ``measure_memory`` snapshots taken before and after the run;
        * ``memory_peak_delta_bytes``.

        NOTE: the peak delta compares tracemalloc peaks. If tracing was already
        active before this call, an earlier higher peak can make the delta 0.

    Raises:
        ValueError: If *concurrency* or *rounds* is < 1, or *max_errors* < 0.
        RuntimeError: If no *connector* is given and ``websockets`` is missing.
    """
    if concurrency < 1:
        raise ValueError("concurrency must be >= 1")
    if rounds < 1:
        raise ValueError("rounds must be >= 1")
    if max_errors < 0:
        raise ValueError("max_errors must be >= 0")

    if connector is None:
        if websockets is None:
            raise RuntimeError("websockets package is required for live load testing")

        def connector(target_url: str):
            return websockets.connect(target_url, open_timeout=10)

    payload_text = _normalize_payload(payload)

    # Only stop tracing if this call started it, so a caller that is already
    # tracing keeps its session. ``finally`` ensures tracing is stopped even
    # when the run is cancelled or interrupted mid-round.
    started_tracing = not tracemalloc.is_tracing()
    if started_tracing:
        tracemalloc.start()
    try:
        memory_before = measure_memory()
        wall_start = time.perf_counter()
        results: list[dict[str, Any]] = []
        for _ in range(rounds):
            tasks = [
                asyncio.create_task(
                    _connect_once(
                        url=url,
                        hold_seconds=hold_seconds,
                        payload=payload_text,
                        connector=connector,
                    )
                )
                for _ in range(concurrency)
            ]
            results.extend(await asyncio.gather(*tasks))
        wall_time_ms = round((time.perf_counter() - wall_start) * 1000, 1)
        memory_after = measure_memory()
    finally:
        if started_tracing:
            tracemalloc.stop()

    attempted = len(results)
    successful = sum(1 for result in results if result["success"])
    latencies = [result["connect_ms"] for result in results if result["success"]]
    errors = [result["error"] for result in results if result.get("error")]

    return {
        "url": url,
        "concurrency": concurrency,
        "rounds": rounds,
        "hold_seconds": hold_seconds,
        "attempted_connections": attempted,
        "successful_connections": successful,
        "failed_connections": attempted - successful,
        "messages_sent": sum(result["messages_sent"] for result in results),
        "success_rate": round(successful / attempted, 4) if attempted else 0.0,
        "avg_connect_ms": round(sum(latencies) / len(latencies), 1) if latencies else 0.0,
        "min_connect_ms": round(min(latencies), 1) if latencies else 0.0,
        "max_connect_ms": round(max(latencies), 1) if latencies else 0.0,
        "p95_connect_ms": percentile(latencies, 95),
        "wall_time_ms": wall_time_ms,
        "memory_before": memory_before,
        "memory_after": memory_after,
        "memory_peak_delta_bytes": max(
            0,
            memory_after["tracemalloc_peak_bytes"] - memory_before["tracemalloc_peak_bytes"],
        ),
        "errors": errors[:max_errors],
    }
|
|
|
|
|
|
def _parse_payload(raw: str | None) -> Any:
|
|
if raw is None:
|
|
return None
|
|
try:
|
|
return json.loads(raw)
|
|
except json.JSONDecodeError:
|
|
return raw
|
|
|
|
|
|
def main(argv: list[str] | None = None) -> int:
    """Command-line entry point: run the load test and print the JSON report.

    Args:
        argv: Arguments to parse; ``None`` lets argparse use ``sys.argv[1:]``.

    Returns:
        ``0`` once the report has been produced. The status does not reflect
        failed connections. Invalid arguments exit with status 2 via argparse.
    """
    parser = argparse.ArgumentParser(description="WebSocket load test harness for Nexus")
    parser.add_argument("--url", required=True, help="WebSocket URL to exercise, e.g. ws://127.0.0.1:8765")
    parser.add_argument("--concurrency", type=int, default=10, help="Concurrent connections per round")
    parser.add_argument("--rounds", type=int, default=1, help="Number of connection rounds to execute")
    parser.add_argument("--hold-seconds", type=float, default=0.1, help="How long to keep each connection open")
    parser.add_argument("--payload", help="Optional message payload to send after connect; JSON accepted")
    parser.add_argument("--output", help="Optional path to write JSON report")
    args = parser.parse_args(argv)

    # Validate here so bad CLI input produces a usage error (exit status 2)
    # instead of a ValueError traceback from inside the event loop.
    if args.concurrency < 1:
        parser.error("--concurrency must be >= 1")
    if args.rounds < 1:
        parser.error("--rounds must be >= 1")

    report = asyncio.run(
        run_load_test(
            url=args.url,
            concurrency=args.concurrency,
            rounds=args.rounds,
            hold_seconds=args.hold_seconds,
            payload=_parse_payload(args.payload),
        )
    )

    if args.output:
        write_report(args.output, report)
    print(json.dumps(report, indent=2))
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
|