fix(gateway): persist watcher metadata in checkpoint for crash recovery (#1706)

Salvaged from PR #1573 by @eren-karakus0. Cherry-picked with authorship preserved.

Fixes #1143 — background process notifications resume after gateway restart.

Co-authored-by: Muhammet Eren Karakuş <erenkar950@gmail.com>
This commit is contained in:
Teknium
2026-03-17 03:52:15 -07:00
committed by GitHub
parent ce7418e274
commit d87655afff
5 changed files with 151 additions and 5 deletions

View File

@@ -984,6 +984,16 @@ class GatewayRunner:
):
self._schedule_update_notification_watch()
# Drain any recovered process watchers (from crash recovery checkpoint)
try:
from tools.process_registry import process_registry
while process_registry.pending_watchers:
watcher = process_registry.pending_watchers.pop(0)
asyncio.create_task(self._run_process_watcher(watcher))
logger.info("Resumed watcher for recovered process %s", watcher.get("session_id"))
except Exception as e:
logger.error("Recovered watcher setup error: %s", e)
# Start background session expiry watcher for proactive memory flushing
asyncio.create_task(self._session_expiry_watcher())

View File

@@ -50,13 +50,16 @@ def _build_runner(monkeypatch, tmp_path, mode: str) -> GatewayRunner:
return runner
def _watcher_dict(session_id="proc_test"):
return {
def _watcher_dict(session_id="proc_test", thread_id=""):
d = {
"session_id": session_id,
"check_interval": 0,
"platform": "telegram",
"chat_id": "123",
}
if thread_id:
d["thread_id"] = thread_id
return d
# ---------------------------------------------------------------------------
@@ -196,3 +199,47 @@ async def test_run_process_watcher_respects_notification_mode(
if expected_fragment is not None:
sent_message = adapter.send.await_args.args[1]
assert expected_fragment in sent_message
@pytest.mark.asyncio
async def test_thread_id_passed_to_send(monkeypatch, tmp_path):
"""thread_id from watcher dict is forwarded as metadata to adapter.send()."""
import tools.process_registry as pr_module
sessions = [SimpleNamespace(output_buffer="done\n", exited=True, exit_code=0)]
monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions))
async def _instant_sleep(*_a, **_kw):
pass
monkeypatch.setattr(asyncio, "sleep", _instant_sleep)
runner = _build_runner(monkeypatch, tmp_path, "all")
adapter = runner.adapters[Platform.TELEGRAM]
await runner._run_process_watcher(_watcher_dict(thread_id="42"))
assert adapter.send.await_count == 1
_, kwargs = adapter.send.call_args
assert kwargs["metadata"] == {"thread_id": "42"}
@pytest.mark.asyncio
async def test_no_thread_id_sends_no_metadata(monkeypatch, tmp_path):
"""When thread_id is empty, metadata should be None (general topic)."""
import tools.process_registry as pr_module
sessions = [SimpleNamespace(output_buffer="done\n", exited=True, exit_code=0)]
monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions))
async def _instant_sleep(*_a, **_kw):
pass
monkeypatch.setattr(asyncio, "sleep", _instant_sleep)
runner = _build_runner(monkeypatch, tmp_path, "all")
adapter = runner.adapters[Platform.TELEGRAM]
await runner._run_process_watcher(_watcher_dict())
assert adapter.send.await_count == 1
_, kwargs = adapter.send.call_args
assert kwargs["metadata"] is None

View File

@@ -294,6 +294,61 @@ class TestCheckpoint:
recovered = registry.recover_from_checkpoint()
assert recovered == 0
def test_write_checkpoint_includes_watcher_metadata(self, registry, tmp_path):
with patch("tools.process_registry.CHECKPOINT_PATH", tmp_path / "procs.json"):
s = _make_session()
s.watcher_platform = "telegram"
s.watcher_chat_id = "999"
s.watcher_thread_id = "42"
s.watcher_interval = 60
registry._running[s.id] = s
registry._write_checkpoint()
data = json.loads((tmp_path / "procs.json").read_text())
assert len(data) == 1
assert data[0]["watcher_platform"] == "telegram"
assert data[0]["watcher_chat_id"] == "999"
assert data[0]["watcher_thread_id"] == "42"
assert data[0]["watcher_interval"] == 60
def test_recover_enqueues_watchers(self, registry, tmp_path):
checkpoint = tmp_path / "procs.json"
checkpoint.write_text(json.dumps([{
"session_id": "proc_live",
"command": "sleep 999",
"pid": os.getpid(), # current process — guaranteed alive
"task_id": "t1",
"session_key": "sk1",
"watcher_platform": "telegram",
"watcher_chat_id": "123",
"watcher_thread_id": "42",
"watcher_interval": 60,
}]))
with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint):
recovered = registry.recover_from_checkpoint()
assert recovered == 1
assert len(registry.pending_watchers) == 1
w = registry.pending_watchers[0]
assert w["session_id"] == "proc_live"
assert w["platform"] == "telegram"
assert w["chat_id"] == "123"
assert w["thread_id"] == "42"
assert w["check_interval"] == 60
def test_recover_skips_watcher_when_no_interval(self, registry, tmp_path):
checkpoint = tmp_path / "procs.json"
checkpoint.write_text(json.dumps([{
"session_id": "proc_live",
"command": "sleep 999",
"pid": os.getpid(),
"task_id": "t1",
"watcher_interval": 0,
}]))
with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint):
recovered = registry.recover_from_checkpoint()
assert recovered == 1
assert len(registry.pending_watchers) == 0
# =========================================================================
# Kill process

View File

@@ -78,6 +78,11 @@ class ProcessSession:
output_buffer: str = "" # Rolling output (last MAX_OUTPUT_CHARS)
max_output_chars: int = MAX_OUTPUT_CHARS
detached: bool = False # True if recovered from crash (no pipe)
# Watcher/notification metadata (persisted for crash recovery)
watcher_platform: str = ""
watcher_chat_id: str = ""
watcher_thread_id: str = ""
watcher_interval: int = 0 # 0 = no watcher configured
_lock: threading.Lock = field(default_factory=threading.Lock)
_reader_thread: Optional[threading.Thread] = field(default=None, repr=False)
_pty: Any = field(default=None, repr=False) # ptyprocess handle (when use_pty=True)
@@ -709,6 +714,10 @@ class ProcessRegistry:
"started_at": s.started_at,
"task_id": s.task_id,
"session_key": s.session_key,
"watcher_platform": s.watcher_platform,
"watcher_chat_id": s.watcher_chat_id,
"watcher_thread_id": s.watcher_thread_id,
"watcher_interval": s.watcher_interval,
})
# Atomic write to avoid corruption on crash
@@ -755,12 +764,27 @@ class ProcessRegistry:
cwd=entry.get("cwd"),
started_at=entry.get("started_at", time.time()),
detached=True, # Can't read output, but can report status + kill
watcher_platform=entry.get("watcher_platform", ""),
watcher_chat_id=entry.get("watcher_chat_id", ""),
watcher_thread_id=entry.get("watcher_thread_id", ""),
watcher_interval=entry.get("watcher_interval", 0),
)
with self._lock:
self._running[session.id] = session
recovered += 1
logger.info("Recovered detached process: %s (pid=%d)", session.command[:60], pid)
# Re-enqueue watcher so gateway can resume notifications
if session.watcher_interval > 0:
self.pending_watchers.append({
"session_id": session.id,
"check_interval": session.watcher_interval,
"session_key": session.session_key,
"platform": session.watcher_platform,
"chat_id": session.watcher_chat_id,
"thread_id": session.watcher_thread_id,
})
# Clear the checkpoint (will be rewritten as processes finish)
try:
from utils import atomic_json_write

View File

@@ -1082,13 +1082,23 @@ def terminal_tool(
result_data["check_interval_note"] = (
f"Requested {check_interval}s raised to minimum 30s"
)
watcher_platform = os.getenv("HERMES_SESSION_PLATFORM", "")
watcher_chat_id = os.getenv("HERMES_SESSION_CHAT_ID", "")
watcher_thread_id = os.getenv("HERMES_SESSION_THREAD_ID", "")
# Store on session for checkpoint persistence
proc_session.watcher_platform = watcher_platform
proc_session.watcher_chat_id = watcher_chat_id
proc_session.watcher_thread_id = watcher_thread_id
proc_session.watcher_interval = effective_interval
process_registry.pending_watchers.append({
"session_id": proc_session.id,
"check_interval": effective_interval,
"session_key": session_key,
"platform": os.getenv("HERMES_SESSION_PLATFORM", ""),
"chat_id": os.getenv("HERMES_SESSION_CHAT_ID", ""),
"thread_id": os.getenv("HERMES_SESSION_THREAD_ID", ""),
"platform": watcher_platform,
"chat_id": watcher_chat_id,
"thread_id": watcher_thread_id,
})
return json.dumps(result_data, ensure_ascii=False)