fix: call _stop_training_run on early-return failure paths
The four early-return failure paths in _spawn_training_run (API exit,
trainer exit, env not found, env exit) either called process.terminate()
manually or returned without any cleanup, leaking open log file handles.
All four paths now call _stop_training_run(), which handles both process
termination and file handle closure.
Also adds 12 tests for _stop_training_run covering file handle
cleanup, process termination, status transitions, and edge cases.
Inspired by PR #715 (0xbyt4), which identified the early-return issue.
The core file handle fix was already on main via e28dc13 (memosr.eth).
This commit is contained in:
142
tests/tools/test_rl_training_tool.py
Normal file
142
tests/tools/test_rl_training_tool.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""Tests for rl_training_tool.py — file handle lifecycle and cleanup.
|
||||
|
||||
Verifies that _stop_training_run properly closes log file handles,
|
||||
terminates processes, and handles edge cases on failure paths.
|
||||
Inspired by PR #715 (0xbyt4).
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.rl_training_tool import RunState, _stop_training_run
|
||||
|
||||
|
||||
def _make_run_state(**overrides) -> RunState:
    """Build a bare-bones RunState for use in tests.

    Keyword arguments replace or extend the baseline ``run_id``,
    ``environment`` and ``config`` values before construction.
    """
    fields = {
        "run_id": "test-run-001",
        "environment": "test_env",
        "config": {},
        **overrides,
    }
    return RunState(**fields)
|
||||
|
||||
|
||||
class TestStopTrainingRunFileHandles:
    """Log file handles stored on a RunState are closed by _stop_training_run."""

    def test_closes_all_log_file_handles(self):
        """Each of the three log handles is closed once and reset to None."""
        state = _make_run_state()
        handles = {
            name: MagicMock()
            for name in ("api_log_file", "trainer_log_file", "env_log_file")
        }
        for name, handle in handles.items():
            setattr(state, name, handle)

        _stop_training_run(state)

        for name, handle in handles.items():
            handle.close.assert_called_once()
            assert getattr(state, name) is None

    def test_clears_file_attrs_to_none(self):
        """A single attached handle attribute is None after the stop call."""
        state = _make_run_state()
        handle = MagicMock()
        state.api_log_file = handle

        _stop_training_run(state)

        assert state.api_log_file is None

    def test_close_exception_does_not_propagate(self):
        """If a file handle .close() raises, it must not crash."""
        state = _make_run_state()
        # Configure the failing close() at construction time.
        failing = MagicMock(**{"close.side_effect": OSError("already closed")})
        healthy = MagicMock()
        state.api_log_file = failing
        state.trainer_log_file = healthy

        _stop_training_run(state)  # should not raise

        # Both handles must have been attempted despite the first one failing.
        failing.close.assert_called_once()
        healthy.close.assert_called_once()

    def test_handles_missing_file_attrs(self):
        """RunState without log file attrs should not crash."""
        state = _make_run_state()
        # No log file attrs set at all — getattr(..., None) should handle it
        _stop_training_run(state)  # should not raise
|
||||
|
||||
|
||||
class TestStopTrainingRunProcesses:
    """Process termination behaviour of _stop_training_run."""

    def test_terminates_running_processes(self):
        """Processes whose poll() returns None are each terminated once."""
        state = _make_run_state()
        names = ("api_process", "trainer_process", "env_process")
        for name in names:
            # poll() returning None means the process is still running.
            setattr(state, name, MagicMock(**{"poll.return_value": None}))

        _stop_training_run(state)

        for name in names:
            getattr(state, name).terminate.assert_called_once()

    def test_does_not_terminate_exited_processes(self):
        """A process whose poll() returns an exit code is left alone."""
        state = _make_run_state()
        finished = MagicMock(**{"poll.return_value": 0})  # already exited
        state.api_process = finished

        _stop_training_run(state)

        finished.terminate.assert_not_called()

    def test_handles_none_processes(self):
        """A state with no processes attached does not raise."""
        state = _make_run_state()
        # All process attrs are None by default
        _stop_training_run(state)  # should not raise

    def test_handles_mixed_running_and_exited_processes(self):
        """Only the still-running process is terminated in a mixed state."""
        state = _make_run_state()
        running_api = MagicMock(**{"poll.return_value": None})  # still running
        exited_trainer = MagicMock(**{"poll.return_value": 0})  # already exited
        state.api_process = running_api
        state.trainer_process = exited_trainer
        state.env_process = None  # never started

        _stop_training_run(state)

        running_api.terminate.assert_called_once()
        exited_trainer.terminate.assert_not_called()
|
||||
|
||||
|
||||
class TestStopTrainingRunStatus:
    """Status transitions performed by _stop_training_run."""

    def test_sets_status_to_stopped_when_running(self):
        """A "running" state becomes "stopped"."""
        run = _make_run_state(status="running")

        _stop_training_run(run)

        assert run.status == "stopped"

    def test_does_not_change_status_when_failed(self):
        """A "failed" state keeps its status."""
        run = _make_run_state(status="failed")

        _stop_training_run(run)

        assert run.status == "failed"

    def test_does_not_change_status_when_pending(self):
        """A "pending" state keeps its status."""
        run = _make_run_state(status="pending")

        _stop_training_run(run)

        assert run.status == "pending"

    def test_no_crash_with_no_processes_and_no_files(self):
        """A freshly built state survives the call and still reports "pending"."""
        run = _make_run_state()

        _stop_training_run(run)  # should not raise

        assert run.status == "pending"
|
||||
@@ -340,6 +340,7 @@ async def _spawn_training_run(run_state: RunState, config_path: Path):
|
||||
if run_state.api_process.poll() is not None:
|
||||
run_state.status = "failed"
|
||||
run_state.error_message = f"API server exited with code {run_state.api_process.returncode}. Check {api_log}"
|
||||
_stop_training_run(run_state)
|
||||
return
|
||||
|
||||
print(f"[{run_id}] Atropos API server started")
|
||||
@@ -364,8 +365,7 @@ async def _spawn_training_run(run_state: RunState, config_path: Path):
|
||||
if run_state.trainer_process.poll() is not None:
|
||||
run_state.status = "failed"
|
||||
run_state.error_message = f"Trainer exited with code {run_state.trainer_process.returncode}. Check {trainer_log}"
|
||||
if run_state.api_process:
|
||||
run_state.api_process.terminate()
|
||||
_stop_training_run(run_state)
|
||||
return
|
||||
|
||||
print(f"[{run_id}] Trainer started, inference server on port 8001")
|
||||
@@ -384,6 +384,7 @@ async def _spawn_training_run(run_state: RunState, config_path: Path):
|
||||
if not env_info:
|
||||
run_state.status = "failed"
|
||||
run_state.error_message = f"Environment '{run_state.environment}' not found"
|
||||
_stop_training_run(run_state)
|
||||
return
|
||||
|
||||
print(f"[{run_id}] Starting environment: {env_info.file_path} serve")
|
||||
@@ -403,10 +404,7 @@ async def _spawn_training_run(run_state: RunState, config_path: Path):
|
||||
if run_state.env_process.poll() is not None:
|
||||
run_state.status = "failed"
|
||||
run_state.error_message = f"Environment exited with code {run_state.env_process.returncode}. Check {env_log}"
|
||||
if run_state.trainer_process:
|
||||
run_state.trainer_process.terminate()
|
||||
if run_state.api_process:
|
||||
run_state.api_process.terminate()
|
||||
_stop_training_run(run_state)
|
||||
return
|
||||
|
||||
run_state.status = "running"
|
||||
|
||||
Reference in New Issue
Block a user