diff --git a/gateway/run.py b/gateway/run.py index 7bfb8059e..99ed538c1 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -3502,10 +3502,12 @@ class GatewayRunner: os.environ["HERMES_SESSION_CHAT_ID"] = context.source.chat_id if context.source.chat_name: os.environ["HERMES_SESSION_CHAT_NAME"] = context.source.chat_name + if context.source.thread_id: + os.environ["HERMES_SESSION_THREAD_ID"] = str(context.source.thread_id) def _clear_session_env(self) -> None: """Clear session environment variables.""" - for var in ["HERMES_SESSION_PLATFORM", "HERMES_SESSION_CHAT_ID", "HERMES_SESSION_CHAT_NAME"]: + for var in ["HERMES_SESSION_PLATFORM", "HERMES_SESSION_CHAT_ID", "HERMES_SESSION_CHAT_NAME", "HERMES_SESSION_THREAD_ID"]: if var in os.environ: del os.environ[var] diff --git a/tests/gateway/test_session_env.py b/tests/gateway/test_session_env.py new file mode 100644 index 000000000..596df89ec --- /dev/null +++ b/tests/gateway/test_session_env.py @@ -0,0 +1,45 @@ +import os + +from gateway.config import Platform +from gateway.run import GatewayRunner +from gateway.session import SessionContext, SessionSource + + +def test_set_session_env_includes_thread_id(monkeypatch): + runner = object.__new__(GatewayRunner) + source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="-1001", + chat_name="Group", + chat_type="group", + thread_id="17585", + ) + context = SessionContext(source=source, connected_platforms=[], home_channels={}) + + monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False) + monkeypatch.delenv("HERMES_SESSION_CHAT_ID", raising=False) + monkeypatch.delenv("HERMES_SESSION_CHAT_NAME", raising=False) + monkeypatch.delenv("HERMES_SESSION_THREAD_ID", raising=False) + + runner._set_session_env(context) + + assert os.getenv("HERMES_SESSION_PLATFORM") == "telegram" + assert os.getenv("HERMES_SESSION_CHAT_ID") == "-1001" + assert os.getenv("HERMES_SESSION_CHAT_NAME") == "Group" + assert os.getenv("HERMES_SESSION_THREAD_ID") == "17585" + + +def test_clear_session_env_removes_thread_id(monkeypatch): + runner = object.__new__(GatewayRunner) + + monkeypatch.setenv("HERMES_SESSION_PLATFORM", "telegram") + monkeypatch.setenv("HERMES_SESSION_CHAT_ID", "-1001") + monkeypatch.setenv("HERMES_SESSION_CHAT_NAME", "Group") + monkeypatch.setenv("HERMES_SESSION_THREAD_ID", "17585") + + runner._clear_session_env() + + assert os.getenv("HERMES_SESSION_PLATFORM") is None + assert os.getenv("HERMES_SESSION_CHAT_ID") is None + assert os.getenv("HERMES_SESSION_CHAT_NAME") is None + assert os.getenv("HERMES_SESSION_THREAD_ID") is None diff --git a/tests/tools/test_cronjob_tools.py b/tests/tools/test_cronjob_tools.py index 293622070..2a9197083 100644 --- a/tests/tools/test_cronjob_tools.py +++ b/tests/tools/test_cronjob_tools.py @@ -153,6 +153,36 @@ class TestScheduleCronjob: assert job["provider"] == "custom" assert job["base_url"] == "http://127.0.0.1:4000/v1" + def test_thread_id_captured_in_origin(self, monkeypatch): + monkeypatch.setenv("HERMES_SESSION_PLATFORM", "telegram") + monkeypatch.setenv("HERMES_SESSION_CHAT_ID", "123456") + monkeypatch.setenv("HERMES_SESSION_THREAD_ID", "42") + import cron.jobs as _jobs + created = json.loads(schedule_cronjob( + prompt="Thread test", + schedule="every 1h", + deliver="origin", + )) + assert created["success"] is True + job_id = created["job_id"] + job = _jobs.get_job(job_id) + assert job["origin"]["thread_id"] == "42" + + def test_thread_id_absent_when_not_set(self, monkeypatch): + monkeypatch.setenv("HERMES_SESSION_PLATFORM", "telegram") + monkeypatch.setenv("HERMES_SESSION_CHAT_ID", "123456") + monkeypatch.delenv("HERMES_SESSION_THREAD_ID", raising=False) + import cron.jobs as _jobs + created = json.loads(schedule_cronjob( + prompt="No thread test", + schedule="every 1h", + deliver="origin", + )) + assert created["success"] is True + job_id = created["job_id"] + job = _jobs.get_job(job_id) + assert job["origin"].get("thread_id") is None + # ========================================================================= # list_cronjobs diff --git a/tools/cronjob_tools.py b/tools/cronjob_tools.py index 9ff7127bb..7a0daaf88 100644 --- a/tools/cronjob_tools.py +++ b/tools/cronjob_tools.py @@ -72,6 +72,7 @@ def _origin_from_env() -> Optional[Dict[str, str]]: "platform": origin_platform, "chat_id": origin_chat_id, "chat_name": os.getenv("HERMES_SESSION_CHAT_NAME"), + "thread_id": os.getenv("HERMES_SESSION_THREAD_ID"), } return None