diff --git a/tests/tools/test_rl_training_tool.py b/tests/tools/test_rl_training_tool.py new file mode 100644 index 000000000..8b68ea8d9 --- /dev/null +++ b/tests/tools/test_rl_training_tool.py @@ -0,0 +1,142 @@ +"""Tests for rl_training_tool.py — file handle lifecycle and cleanup. + +Verifies that _stop_training_run properly closes log file handles, +terminates processes, and handles edge cases on failure paths. +Inspired by PR #715 (0xbyt4). +""" + +from unittest.mock import MagicMock + +import pytest + +from tools.rl_training_tool import RunState, _stop_training_run + + +def _make_run_state(**overrides) -> RunState: + """Create a minimal RunState for testing.""" + defaults = { + "run_id": "test-run-001", + "environment": "test_env", + "config": {}, + } + defaults.update(overrides) + return RunState(**defaults) + + +class TestStopTrainingRunFileHandles: + """Verify that _stop_training_run closes log file handles stored as attributes.""" + + def test_closes_all_log_file_handles(self): + state = _make_run_state() + files = {} + for attr in ("api_log_file", "trainer_log_file", "env_log_file"): + fh = MagicMock() + setattr(state, attr, fh) + files[attr] = fh + + _stop_training_run(state) + + for attr, fh in files.items(): + fh.close.assert_called_once() + assert getattr(state, attr) is None + + def test_clears_file_attrs_to_none(self): + state = _make_run_state() + state.api_log_file = MagicMock() + + _stop_training_run(state) + + assert state.api_log_file is None + + def test_close_exception_does_not_propagate(self): + """If a file handle .close() raises, it must not crash.""" + state = _make_run_state() + bad_fh = MagicMock() + bad_fh.close.side_effect = OSError("already closed") + good_fh = MagicMock() + state.api_log_file = bad_fh + state.trainer_log_file = good_fh + + _stop_training_run(state) # should not raise + + bad_fh.close.assert_called_once() + good_fh.close.assert_called_once() + + def test_handles_missing_file_attrs(self): + """RunState without log file attrs should not crash.""" + state = _make_run_state() + # No log file attrs set at all — getattr(..., None) should handle it + _stop_training_run(state) # should not raise + + +class TestStopTrainingRunProcesses: + """Verify that _stop_training_run terminates processes correctly.""" + + def test_terminates_running_processes(self): + state = _make_run_state() + for attr in ("api_process", "trainer_process", "env_process"): + proc = MagicMock() + proc.poll.return_value = None # still running + setattr(state, attr, proc) + + _stop_training_run(state) + + for attr in ("api_process", "trainer_process", "env_process"): + getattr(state, attr).terminate.assert_called_once() + + def test_does_not_terminate_exited_processes(self): + state = _make_run_state() + proc = MagicMock() + proc.poll.return_value = 0 # already exited + state.api_process = proc + + _stop_training_run(state) + + proc.terminate.assert_not_called() + + def test_handles_none_processes(self): + state = _make_run_state() + # All process attrs are None by default + _stop_training_run(state) # should not raise + + def test_handles_mixed_running_and_exited_processes(self): + state = _make_run_state() + # api still running + api = MagicMock() + api.poll.return_value = None + state.api_process = api + # trainer already exited + trainer = MagicMock() + trainer.poll.return_value = 0 + state.trainer_process = trainer + # env is None + state.env_process = None + + _stop_training_run(state) + + api.terminate.assert_called_once() + trainer.terminate.assert_not_called() + + +class TestStopTrainingRunStatus: + """Verify status transitions in _stop_training_run.""" + + def test_sets_status_to_stopped_when_running(self): + state = _make_run_state(status="running") + _stop_training_run(state) + assert state.status == "stopped" + + def test_does_not_change_status_when_failed(self): + state = _make_run_state(status="failed") + _stop_training_run(state) + assert state.status == "failed" + + def test_does_not_change_status_when_pending(self): + state = _make_run_state(status="pending") + _stop_training_run(state) + assert state.status == "pending" + + def test_no_crash_with_no_processes_and_no_files(self): + state = _make_run_state() + _stop_training_run(state) # should not raise + assert state.status == "pending" diff --git a/tools/rl_training_tool.py b/tools/rl_training_tool.py index bf4c6ad64..61b8a7088 100644 --- a/tools/rl_training_tool.py +++ b/tools/rl_training_tool.py @@ -340,6 +340,7 @@ async def _spawn_training_run(run_state: RunState, config_path: Path): if run_state.api_process.poll() is not None: run_state.status = "failed" run_state.error_message = f"API server exited with code {run_state.api_process.returncode}. Check {api_log}" + _stop_training_run(run_state) return print(f"[{run_id}] Atropos API server started") @@ -364,8 +365,7 @@ async def _spawn_training_run(run_state: RunState, config_path: Path): if run_state.trainer_process.poll() is not None: run_state.status = "failed" run_state.error_message = f"Trainer exited with code {run_state.trainer_process.returncode}. Check {trainer_log}" - if run_state.api_process: - run_state.api_process.terminate() + _stop_training_run(run_state) return print(f"[{run_id}] Trainer started, inference server on port 8001") @@ -384,6 +384,7 @@ async def _spawn_training_run(run_state: RunState, config_path: Path): if not env_info: run_state.status = "failed" run_state.error_message = f"Environment '{run_state.environment}' not found" + _stop_training_run(run_state) return print(f"[{run_id}] Starting environment: {env_info.file_path} serve") @@ -403,10 +404,7 @@ async def _spawn_training_run(run_state: RunState, config_path: Path): if run_state.env_process.poll() is not None: run_state.status = "failed" run_state.error_message = f"Environment exited with code {run_state.env_process.returncode}. Check {env_log}" - if run_state.trainer_process: - run_state.trainer_process.terminate() - if run_state.api_process: - run_state.api_process.terminate() + _stop_training_run(run_state) return run_state.status = "running"