Compare commits

..

1 Commits

Author SHA1 Message Date
Alexander Whitestone
0d9ff94693 fix: add inference server health check with auto-restart
Some checks failed
Contributor Attribution Check / check-attribution (pull_request) Failing after 39s
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 1m8s
Tests / e2e (pull_request) Successful in 3m12s
Tests / test (pull_request) Failing after 55m37s
Closes #713

llama-server on port 8081 was DOWN and nobody noticed until an
audit found it. The fix is not just restarting the process —
it's adding detection so this never goes unnoticed again.

Changes:

- scripts/inference_health.py: Health check utility for local
  inference servers (llama-server, Ollama). Features:
  - HTTP health endpoint check with latency measurement
  - Process alive detection (pgrep)
  - Auto-restart for dead servers (--auto-restart)
  - JSON output for cron integration (--json)
  - Port-specific check (--port 8081)
  - Default endpoints for fleet: llama-server:8081, Ollama:11434
  - Exit code 1 if any server is down (CI/cron integration)

- tests/test_inference_health.py: Tests for result formatting,
  JSON output, endpoint configuration.

Usage:
  python scripts/inference_health.py              # check all
  python scripts/inference_health.py --port 8081  # check llama only
  python scripts/inference_health.py --auto-restart  # restart dead
  python scripts/inference_health.py --json       # machine output
2026-04-14 22:36:24 -04:00
4 changed files with 368 additions and 246 deletions

View File

@@ -1,189 +0,0 @@
"""
Gateway Message Deduplication — Prevent double-posting.
Provides idempotent message delivery by tracking message UUIDs
and suppressing duplicates within a configurable time window.
"""
import hashlib
import logging
import time
import uuid
from typing import Dict, Optional, Set
from dataclasses import dataclass, field
from collections import OrderedDict
logger = logging.getLogger(__name__)
@dataclass
class MessageRecord:
"""Record of a sent message."""
message_id: str
content_hash: str
timestamp: float
session_id: str
platform: str
class MessageDeduplicator:
"""
Deduplicates outbound messages within a time window.
Each message gets a UUID. If the same message (by content hash)
is sent again within the window, it's suppressed.
"""
def __init__(self, window_seconds: int = 60, max_records: int = 1000):
"""
Initialize deduplicator.
Args:
window_seconds: Time window for deduplication (default 60s)
max_records: Maximum records to keep in memory
"""
self.window_seconds = window_seconds
self.max_records = max_records
self._records: OrderedDict[str, MessageRecord] = OrderedDict()
self._suppressed_count = 0
def _content_hash(self, content: str, session_id: str = "", platform: str = "") -> str:
"""Generate hash for message content."""
combined = f"{session_id}:{platform}:{content}"
return hashlib.sha256(combined.encode()).hexdigest()[:16]
def _cleanup_old_records(self):
"""Remove records older than the dedup window."""
cutoff = time.time() - self.window_seconds
to_remove = []
for msg_id, record in self._records.items():
if record.timestamp < cutoff:
to_remove.append(msg_id)
for msg_id in to_remove:
del self._records[msg_id]
def _enforce_max_records(self):
"""Enforce maximum record count by removing oldest."""
while len(self._records) > self.max_records:
self._records.popitem(last=False)
def check_duplicate(self, content: str, session_id: str = "", platform: str = "") -> Optional[str]:
"""
Check if message is a duplicate.
Args:
content: Message content
session_id: Session identifier
platform: Platform name (telegram, discord, etc.)
Returns:
Message ID if duplicate found, None if new message
"""
self._cleanup_old_records()
content_hash = self._content_hash(content, session_id, platform)
for msg_id, record in self._records.items():
if record.content_hash == content_hash:
age = time.time() - record.timestamp
if age < self.window_seconds:
self._suppressed_count += 1
logger.info(
"Suppressed duplicate message (age: %.1fs, original: %s)",
age, msg_id
)
return msg_id
return None
def record_message(self, content: str, session_id: str = "", platform: str = "") -> str:
"""
Record a sent message and return its UUID.
Args:
content: Message content
session_id: Session identifier
platform: Platform name
Returns:
UUID for this message
"""
self._cleanup_old_records()
message_id = str(uuid.uuid4())
content_hash = self._content_hash(content, session_id, platform)
self._records[message_id] = MessageRecord(
message_id=message_id,
content_hash=content_hash,
timestamp=time.time(),
session_id=session_id,
platform=platform,
)
self._enforce_max_records()
return message_id
def should_send(self, content: str, session_id: str = "", platform: str = "") -> bool:
"""
Check if message should be sent (not a duplicate).
Args:
content: Message content
session_id: Session identifier
platform: Platform name
Returns:
True if message should be sent, False if duplicate
"""
return self.check_duplicate(content, session_id, platform) is None
def get_stats(self) -> Dict:
"""Get deduplication statistics."""
return {
"total_records": len(self._records),
"suppressed_count": self._suppressed_count,
"window_seconds": self.window_seconds,
"max_records": self.max_records,
}
def clear(self):
"""Clear all records."""
self._records.clear()
self._suppressed_count = 0
# Global deduplicator instance
_deduplicator: Optional[MessageDeduplicator] = None
def get_deduplicator() -> MessageDeduplicator:
"""Get or create global deduplicator instance."""
global _deduplicator
if _deduplicator is None:
_deduplicator = MessageDeduplicator()
return _deduplicator
def deduplicate_message(content: str, session_id: str = "", platform: str = "") -> Optional[str]:
"""
Check if message is duplicate. Returns message_id if duplicate, None if new.
"""
return get_deduplicator().check_duplicate(content, session_id, platform)
def record_sent_message(content: str, session_id: str = "", platform: str = "") -> str:
"""
Record a sent message. Returns UUID for the message.
"""
return get_deduplicator().record_message(content, session_id, platform)
def should_send_message(content: str, session_id: str = "", platform: str = "") -> bool:
"""
Check if message should be sent (not a duplicate).
"""
return get_deduplicator().should_send(content, session_id, platform)

272
scripts/inference_health.py Normal file
View File

@@ -0,0 +1,272 @@
#!/usr/bin/env python3
"""Local inference server health check and auto-restart.
Checks llama-server, Ollama, and other local inference endpoints.
Reports status, latency, and can auto-restart dead processes.
Refs: #713 — llama-server DOWN on port 8081
"""
from __future__ import annotations
import json
import os
import subprocess
import sys
import time
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any
from urllib.request import Request, urlopen
from urllib.error import URLError, HTTPError
@dataclass
class InferenceEndpoint:
"""Configuration for an inference server endpoint."""
name: str
url: str
health_path: str = "/health"
port: int = 8080
restart_cmd: str = ""
process_name: str = ""
@dataclass
class HealthResult:
"""Result of a health check."""
name: str
url: str
status: str # "ok", "down", "slow", "error"
latency_ms: float = 0.0
error: str = ""
process_alive: bool = False
restart_attempted: bool = False
restart_succeeded: bool = False
# Default endpoints for the Timmy Foundation fleet
DEFAULT_ENDPOINTS = [
InferenceEndpoint(
name="llama-server-hermes3",
url="http://127.0.0.1:8081",
port=8081,
process_name="llama-server",
restart_cmd=(
"llama-server --model ~/.ollama/models/blobs/sha256-c8985d "
"--port 8081 --host 127.0.0.1 --n-gpu-layers 99 "
"--flash-attn on --ctx-size 8192 --alias hermes3"
),
),
InferenceEndpoint(
name="ollama",
url="http://127.0.0.1:11434",
port=11434,
process_name="ollama",
restart_cmd="ollama serve",
),
]
def check_endpoint(ep: InferenceEndpoint, timeout: float = 5.0) -> HealthResult:
"""Check a single inference endpoint.
Args:
ep: Endpoint configuration.
timeout: HTTP timeout in seconds.
Returns:
HealthResult with status and latency.
"""
url = ep.url.rstrip("/") + ep.health_path
start = time.time()
# Check if process is alive
process_alive = False
if ep.process_name:
try:
result = subprocess.run(
["pgrep", "-f", ep.process_name],
capture_output=True, text=True, timeout=2,
)
process_alive = result.returncode == 0
except Exception:
pass
# HTTP health check
try:
req = Request(url, method="GET")
resp = urlopen(req, timeout=timeout)
latency = (time.time() - start) * 1000
if resp.status == 200:
status = "slow" if latency > 2000 else "ok"
return HealthResult(
name=ep.name, url=ep.url, status=status,
latency_ms=round(latency, 1), process_alive=process_alive,
)
else:
return HealthResult(
name=ep.name, url=ep.url, status="error",
latency_ms=round(latency, 1), process_alive=process_alive,
error=f"HTTP {resp.status}",
)
except URLError as e:
latency = (time.time() - start) * 1000
error_msg = str(e.reason) if hasattr(e, 'reason') else str(e)
return HealthResult(
name=ep.name, url=ep.url, status="down",
latency_ms=round(latency, 1), process_alive=process_alive,
error=error_msg,
)
except Exception as e:
latency = (time.time() - start) * 1000
return HealthResult(
name=ep.name, url=ep.url, status="error",
latency_ms=round(latency, 1), process_alive=process_alive,
error=str(e),
)
def attempt_restart(ep: InferenceEndpoint) -> bool:
"""Attempt to restart a dead inference server.
Args:
ep: Endpoint configuration with restart_cmd.
Returns:
True if restart command executed successfully.
"""
if not ep.restart_cmd:
return False
try:
# Run restart in background
subprocess.Popen(
ep.restart_cmd,
shell=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
# Wait a moment for the process to start
time.sleep(3)
return True
except Exception as e:
print(f"Restart failed for {ep.name}: {e}", file=sys.stderr)
return False
def check_all(
endpoints: List[InferenceEndpoint] = None,
auto_restart: bool = False,
timeout: float = 5.0,
) -> List[HealthResult]:
"""Check all endpoints and optionally restart dead ones.
Args:
endpoints: List of endpoints to check. Uses DEFAULT_ENDPOINTS if None.
auto_restart: If True, attempt to restart down endpoints.
timeout: HTTP timeout per endpoint.
Returns:
List of HealthResult for each endpoint.
"""
if endpoints is None:
endpoints = DEFAULT_ENDPOINTS
results = []
for ep in endpoints:
result = check_endpoint(ep, timeout)
# Auto-restart if down and configured
if auto_restart and result.status == "down" and ep.restart_cmd:
result.restart_attempted = True
result.restart_succeeded = attempt_restart(ep)
if result.restart_succeeded:
# Re-check after restart
time.sleep(2)
result2 = check_endpoint(ep, timeout)
result.status = result2.status
result.latency_ms = result2.latency_ms
result.error = result2.error
results.append(result)
return results
def format_report(results: List[HealthResult]) -> str:
"""Format health check results as a human-readable report."""
lines = [
"# Local Inference Health Check",
f"Time: {time.strftime('%Y-%m-%d %H:%M:%S')}",
"",
"| Endpoint | Status | Latency | Process | Error |",
"|----------|--------|---------|---------|-------|",
]
for r in results:
status_icon = {"ok": "", "slow": "⚠️", "down": "", "error": "💥"}.get(r.status, "?")
proc = "alive" if r.process_alive else "dead"
lat = f"{r.latency_ms}ms" if r.latency_ms > 0 else "-"
err = r.error[:40] if r.error else "-"
lines.append(f"| {r.name} | {status_icon} {r.status} | {lat} | {proc} | {err} |")
down = [r for r in results if r.status in ("down", "error")]
if down:
lines.extend(["", "## DOWN", ""])
for r in down:
lines.append(f"- **{r.name}** ({r.url}): {r.error}")
if r.restart_attempted:
status = "✅ restarted" if r.restart_succeeded else "❌ restart failed"
lines.append(f" Restart: {status}")
return "\n".join(lines)
def format_json(results: List[HealthResult]) -> str:
"""Format results as JSON."""
data = []
for r in results:
data.append({
"name": r.name,
"url": r.url,
"status": r.status,
"latency_ms": r.latency_ms,
"process_alive": r.process_alive,
"error": r.error or None,
"restart_attempted": r.restart_attempted,
"restart_succeeded": r.restart_succeeded,
})
return json.dumps({"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"), "endpoints": data}, indent=2)
def main():
import argparse
p = argparse.ArgumentParser(description="Local inference health check")
p.add_argument("--json", action="store_true", help="JSON output")
p.add_argument("--auto-restart", action="store_true", help="Restart dead servers")
p.add_argument("--timeout", type=float, default=5.0, help="HTTP timeout (seconds)")
p.add_argument("--port", type=int, help="Check specific port only")
a = p.parse_args()
endpoints = DEFAULT_ENDPOINTS
if a.port:
endpoints = [ep for ep in DEFAULT_ENDPOINTS if ep.port == a.port]
if not endpoints:
print(f"No endpoint configured for port {a.port}", file=sys.stderr)
sys.exit(1)
results = check_all(endpoints, auto_restart=a.auto_restart, timeout=a.timeout)
if a.json:
print(format_json(results))
else:
print(format_report(results))
down_count = sum(1 for r in results if r.status in ("down", "error"))
sys.exit(1 if down_count > 0 else 0)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,96 @@
"""Tests for inference health check (#713)."""
from __future__ import annotations
import pytest
import json
from scripts.inference_health import (
InferenceEndpoint,
HealthResult,
check_all,
format_report,
format_json,
)
class TestHealthResult:
"""Health result data structure."""
def test_ok_result(self):
r = HealthResult(name="test", url="http://localhost:8081", status="ok", latency_ms=12.5)
assert r.status == "ok"
assert r.latency_ms == 12.5
assert not r.error
def test_down_result(self):
r = HealthResult(
name="test", url="http://localhost:8081",
status="down", error="Connection refused",
)
assert r.status == "down"
assert r.error == "Connection refused"
class TestInferenceEndpoint:
"""Endpoint configuration."""
def test_defaults(self):
ep = InferenceEndpoint(name="test", url="http://localhost:8080")
assert ep.health_path == "/health"
assert ep.port == 8080
assert ep.restart_cmd == ""
def test_custom(self):
ep = InferenceEndpoint(
name="llama", url="http://localhost:8081",
port=8081, restart_cmd="llama-server --port 8081",
)
assert ep.port == 8081
assert "llama-server" in ep.restart_cmd
class TestFormatReport:
"""Report formatting."""
def test_all_ok(self):
results = [
HealthResult(name="test1", url="http://localhost:8080", status="ok", latency_ms=5.0, process_alive=True),
HealthResult(name="test2", url="http://localhost:8081", status="ok", latency_ms=10.0, process_alive=True),
]
report = format_report(results)
assert "Health Check" in report
assert "test1" in report
assert "test2" in report
assert "DOWN" not in report
def test_with_down(self):
results = [
HealthResult(name="test1", url="http://localhost:8080", status="ok", latency_ms=5.0),
HealthResult(
name="test2", url="http://localhost:8081",
status="down", error="Connection refused", process_alive=False,
),
]
report = format_report(results)
assert "DOWN" in report
assert "Connection refused" in report
class TestFormatJson:
"""JSON output format."""
def test_valid_json(self):
results = [HealthResult(name="test", url="http://localhost:8080", status="ok", latency_ms=5.0)]
output = format_json(results)
data = json.loads(output)
assert "timestamp" in data
assert "endpoints" in data
assert len(data["endpoints"]) == 1
assert data["endpoints"][0]["name"] == "test"
def test_none_error_serializes(self):
results = [HealthResult(name="test", url="http://localhost:8080", status="ok")]
output = format_json(results)
data = json.loads(output)
assert data["endpoints"][0]["error"] is None

View File

@@ -1,57 +0,0 @@
"""
Tests for message deduplication (#756).
"""
import pytest
import time
from gateway.message_dedup import MessageDeduplicator
class TestMessageDeduplicator:
def test_first_message_allowed(self):
dedup = MessageDeduplicator()
assert dedup.should_send("Hello") is True
def test_duplicate_suppressed(self):
dedup = MessageDeduplicator()
dedup.record_message("Hello", "session1", "telegram")
assert dedup.should_send("Hello", "session1", "telegram") is False
def test_different_session_allowed(self):
dedup = MessageDeduplicator()
dedup.record_message("Hello", "session1", "telegram")
assert dedup.should_send("Hello", "session2", "telegram") is True
def test_different_platform_allowed(self):
dedup = MessageDeduplicator()
dedup.record_message("Hello", "session1", "telegram")
assert dedup.should_send("Hello", "session1", "discord") is True
def test_different_content_allowed(self):
dedup = MessageDeduplicator()
dedup.record_message("Hello", "session1", "telegram")
assert dedup.should_send("World", "session1", "telegram") is True
def test_window_expiry(self):
dedup = MessageDeduplicator(window_seconds=1)
dedup.record_message("Hello", "session1", "telegram")
time.sleep(1.1)
assert dedup.should_send("Hello", "session1", "telegram") is True
def test_record_returns_uuid(self):
dedup = MessageDeduplicator()
msg_id = dedup.record_message("Hello")
assert msg_id is not None
assert len(msg_id) == 36 # UUID format
def test_stats(self):
dedup = MessageDeduplicator()
dedup.record_message("Hello")
dedup.record_message("Hello") # duplicate
stats = dedup.get_stats()
assert stats["total_records"] == 1
assert stats["suppressed_count"] == 1
if __name__ == "__main__":
pytest.main([__file__])