fix(gateway): persist watcher metadata in checkpoint for crash recovery (#1706)
Salvaged from PR #1573 by @eren-karakus0. Cherry-picked with authorship preserved. Fixes #1143 — background process notifications resume after gateway restart. Co-authored-by: Muhammet Eren Karakuş <erenkar950@gmail.com>
This commit is contained in:
@@ -984,6 +984,16 @@ class GatewayRunner:
|
||||
):
|
||||
self._schedule_update_notification_watch()
|
||||
|
||||
# Drain any recovered process watchers (from crash recovery checkpoint)
|
||||
try:
|
||||
from tools.process_registry import process_registry
|
||||
while process_registry.pending_watchers:
|
||||
watcher = process_registry.pending_watchers.pop(0)
|
||||
asyncio.create_task(self._run_process_watcher(watcher))
|
||||
logger.info("Resumed watcher for recovered process %s", watcher.get("session_id"))
|
||||
except Exception as e:
|
||||
logger.error("Recovered watcher setup error: %s", e)
|
||||
|
||||
# Start background session expiry watcher for proactive memory flushing
|
||||
asyncio.create_task(self._session_expiry_watcher())
|
||||
|
||||
|
||||
@@ -50,13 +50,16 @@ def _build_runner(monkeypatch, tmp_path, mode: str) -> GatewayRunner:
|
||||
return runner
|
||||
|
||||
|
||||
def _watcher_dict(session_id="proc_test"):
|
||||
return {
|
||||
def _watcher_dict(session_id="proc_test", thread_id=""):
|
||||
d = {
|
||||
"session_id": session_id,
|
||||
"check_interval": 0,
|
||||
"platform": "telegram",
|
||||
"chat_id": "123",
|
||||
}
|
||||
if thread_id:
|
||||
d["thread_id"] = thread_id
|
||||
return d
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -196,3 +199,47 @@ async def test_run_process_watcher_respects_notification_mode(
|
||||
if expected_fragment is not None:
|
||||
sent_message = adapter.send.await_args.args[1]
|
||||
assert expected_fragment in sent_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_thread_id_passed_to_send(monkeypatch, tmp_path):
|
||||
"""thread_id from watcher dict is forwarded as metadata to adapter.send()."""
|
||||
import tools.process_registry as pr_module
|
||||
|
||||
sessions = [SimpleNamespace(output_buffer="done\n", exited=True, exit_code=0)]
|
||||
monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions))
|
||||
|
||||
async def _instant_sleep(*_a, **_kw):
|
||||
pass
|
||||
monkeypatch.setattr(asyncio, "sleep", _instant_sleep)
|
||||
|
||||
runner = _build_runner(monkeypatch, tmp_path, "all")
|
||||
adapter = runner.adapters[Platform.TELEGRAM]
|
||||
|
||||
await runner._run_process_watcher(_watcher_dict(thread_id="42"))
|
||||
|
||||
assert adapter.send.await_count == 1
|
||||
_, kwargs = adapter.send.call_args
|
||||
assert kwargs["metadata"] == {"thread_id": "42"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_thread_id_sends_no_metadata(monkeypatch, tmp_path):
|
||||
"""When thread_id is empty, metadata should be None (general topic)."""
|
||||
import tools.process_registry as pr_module
|
||||
|
||||
sessions = [SimpleNamespace(output_buffer="done\n", exited=True, exit_code=0)]
|
||||
monkeypatch.setattr(pr_module, "process_registry", _FakeRegistry(sessions))
|
||||
|
||||
async def _instant_sleep(*_a, **_kw):
|
||||
pass
|
||||
monkeypatch.setattr(asyncio, "sleep", _instant_sleep)
|
||||
|
||||
runner = _build_runner(monkeypatch, tmp_path, "all")
|
||||
adapter = runner.adapters[Platform.TELEGRAM]
|
||||
|
||||
await runner._run_process_watcher(_watcher_dict())
|
||||
|
||||
assert adapter.send.await_count == 1
|
||||
_, kwargs = adapter.send.call_args
|
||||
assert kwargs["metadata"] is None
|
||||
|
||||
@@ -294,6 +294,61 @@ class TestCheckpoint:
|
||||
recovered = registry.recover_from_checkpoint()
|
||||
assert recovered == 0
|
||||
|
||||
def test_write_checkpoint_includes_watcher_metadata(self, registry, tmp_path):
|
||||
with patch("tools.process_registry.CHECKPOINT_PATH", tmp_path / "procs.json"):
|
||||
s = _make_session()
|
||||
s.watcher_platform = "telegram"
|
||||
s.watcher_chat_id = "999"
|
||||
s.watcher_thread_id = "42"
|
||||
s.watcher_interval = 60
|
||||
registry._running[s.id] = s
|
||||
registry._write_checkpoint()
|
||||
|
||||
data = json.loads((tmp_path / "procs.json").read_text())
|
||||
assert len(data) == 1
|
||||
assert data[0]["watcher_platform"] == "telegram"
|
||||
assert data[0]["watcher_chat_id"] == "999"
|
||||
assert data[0]["watcher_thread_id"] == "42"
|
||||
assert data[0]["watcher_interval"] == 60
|
||||
|
||||
def test_recover_enqueues_watchers(self, registry, tmp_path):
|
||||
checkpoint = tmp_path / "procs.json"
|
||||
checkpoint.write_text(json.dumps([{
|
||||
"session_id": "proc_live",
|
||||
"command": "sleep 999",
|
||||
"pid": os.getpid(), # current process — guaranteed alive
|
||||
"task_id": "t1",
|
||||
"session_key": "sk1",
|
||||
"watcher_platform": "telegram",
|
||||
"watcher_chat_id": "123",
|
||||
"watcher_thread_id": "42",
|
||||
"watcher_interval": 60,
|
||||
}]))
|
||||
with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint):
|
||||
recovered = registry.recover_from_checkpoint()
|
||||
assert recovered == 1
|
||||
assert len(registry.pending_watchers) == 1
|
||||
w = registry.pending_watchers[0]
|
||||
assert w["session_id"] == "proc_live"
|
||||
assert w["platform"] == "telegram"
|
||||
assert w["chat_id"] == "123"
|
||||
assert w["thread_id"] == "42"
|
||||
assert w["check_interval"] == 60
|
||||
|
||||
def test_recover_skips_watcher_when_no_interval(self, registry, tmp_path):
|
||||
checkpoint = tmp_path / "procs.json"
|
||||
checkpoint.write_text(json.dumps([{
|
||||
"session_id": "proc_live",
|
||||
"command": "sleep 999",
|
||||
"pid": os.getpid(),
|
||||
"task_id": "t1",
|
||||
"watcher_interval": 0,
|
||||
}]))
|
||||
with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint):
|
||||
recovered = registry.recover_from_checkpoint()
|
||||
assert recovered == 1
|
||||
assert len(registry.pending_watchers) == 0
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Kill process
|
||||
|
||||
@@ -78,6 +78,11 @@ class ProcessSession:
|
||||
output_buffer: str = "" # Rolling output (last MAX_OUTPUT_CHARS)
|
||||
max_output_chars: int = MAX_OUTPUT_CHARS
|
||||
detached: bool = False # True if recovered from crash (no pipe)
|
||||
# Watcher/notification metadata (persisted for crash recovery)
|
||||
watcher_platform: str = ""
|
||||
watcher_chat_id: str = ""
|
||||
watcher_thread_id: str = ""
|
||||
watcher_interval: int = 0 # 0 = no watcher configured
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
_reader_thread: Optional[threading.Thread] = field(default=None, repr=False)
|
||||
_pty: Any = field(default=None, repr=False) # ptyprocess handle (when use_pty=True)
|
||||
@@ -709,6 +714,10 @@ class ProcessRegistry:
|
||||
"started_at": s.started_at,
|
||||
"task_id": s.task_id,
|
||||
"session_key": s.session_key,
|
||||
"watcher_platform": s.watcher_platform,
|
||||
"watcher_chat_id": s.watcher_chat_id,
|
||||
"watcher_thread_id": s.watcher_thread_id,
|
||||
"watcher_interval": s.watcher_interval,
|
||||
})
|
||||
|
||||
# Atomic write to avoid corruption on crash
|
||||
@@ -755,12 +764,27 @@ class ProcessRegistry:
|
||||
cwd=entry.get("cwd"),
|
||||
started_at=entry.get("started_at", time.time()),
|
||||
detached=True, # Can't read output, but can report status + kill
|
||||
watcher_platform=entry.get("watcher_platform", ""),
|
||||
watcher_chat_id=entry.get("watcher_chat_id", ""),
|
||||
watcher_thread_id=entry.get("watcher_thread_id", ""),
|
||||
watcher_interval=entry.get("watcher_interval", 0),
|
||||
)
|
||||
with self._lock:
|
||||
self._running[session.id] = session
|
||||
recovered += 1
|
||||
logger.info("Recovered detached process: %s (pid=%d)", session.command[:60], pid)
|
||||
|
||||
# Re-enqueue watcher so gateway can resume notifications
|
||||
if session.watcher_interval > 0:
|
||||
self.pending_watchers.append({
|
||||
"session_id": session.id,
|
||||
"check_interval": session.watcher_interval,
|
||||
"session_key": session.session_key,
|
||||
"platform": session.watcher_platform,
|
||||
"chat_id": session.watcher_chat_id,
|
||||
"thread_id": session.watcher_thread_id,
|
||||
})
|
||||
|
||||
# Clear the checkpoint (will be rewritten as processes finish)
|
||||
try:
|
||||
from utils import atomic_json_write
|
||||
|
||||
@@ -1082,13 +1082,23 @@ def terminal_tool(
|
||||
result_data["check_interval_note"] = (
|
||||
f"Requested {check_interval}s raised to minimum 30s"
|
||||
)
|
||||
watcher_platform = os.getenv("HERMES_SESSION_PLATFORM", "")
|
||||
watcher_chat_id = os.getenv("HERMES_SESSION_CHAT_ID", "")
|
||||
watcher_thread_id = os.getenv("HERMES_SESSION_THREAD_ID", "")
|
||||
|
||||
# Store on session for checkpoint persistence
|
||||
proc_session.watcher_platform = watcher_platform
|
||||
proc_session.watcher_chat_id = watcher_chat_id
|
||||
proc_session.watcher_thread_id = watcher_thread_id
|
||||
proc_session.watcher_interval = effective_interval
|
||||
|
||||
process_registry.pending_watchers.append({
|
||||
"session_id": proc_session.id,
|
||||
"check_interval": effective_interval,
|
||||
"session_key": session_key,
|
||||
"platform": os.getenv("HERMES_SESSION_PLATFORM", ""),
|
||||
"chat_id": os.getenv("HERMES_SESSION_CHAT_ID", ""),
|
||||
"thread_id": os.getenv("HERMES_SESSION_THREAD_ID", ""),
|
||||
"platform": watcher_platform,
|
||||
"chat_id": watcher_chat_id,
|
||||
"thread_id": watcher_thread_id,
|
||||
})
|
||||
|
||||
return json.dumps(result_data, ensure_ascii=False)
|
||||
|
||||
Reference in New Issue
Block a user