refactor: extract atomic_json_write helper, add 24 checkpoint tests

Extract the duplicated temp-file + fsync + os.replace pattern from
batch_runner.py (1 instance) and process_registry.py (2 instances) into
a shared utils.atomic_json_write() function.

Add 12 tests for atomic_json_write covering: valid JSON, parent dir
creation, overwrite, crash safety (original preserved on error), no temp
file leaks, string paths, unicode, custom indent, concurrent writes.

Add 12 tests for batch_runner checkpoint behavior covering:
_save_checkpoint (valid JSON, last_updated, overwrite, lock/no-lock,
parent dirs, no temp leaks), _load_checkpoint (missing file, existing
data, corrupt JSON), and resume logic (preserves prior progress,
different run_name starts fresh).
This commit is contained in:
teknium1
2026-03-06 05:50:12 -08:00
parent c05c60665e
commit d63b363cde
5 changed files with 340 additions and 64 deletions

View File

@@ -29,8 +29,6 @@ from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from multiprocessing import Pool, Lock
import traceback
import tempfile
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
from rich.console import Console
import fire
@@ -703,32 +701,12 @@ class BatchRunner:
"""
checkpoint_data["last_updated"] = datetime.now().isoformat()
def _atomic_write():
"""Write checkpoint atomically (temp file + replace) to avoid corruption on crash."""
self.checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
fd, tmp_path = tempfile.mkstemp(
dir=str(self.checkpoint_file.parent),
prefix='.checkpoint_',
suffix='.tmp',
)
try:
with os.fdopen(fd, 'w', encoding='utf-8') as f:
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
f.flush()
os.fsync(f.fileno())
os.replace(tmp_path, self.checkpoint_file)
except BaseException:
try:
os.unlink(tmp_path)
except OSError:
pass
raise
from utils import atomic_json_write
if lock:
with lock:
_atomic_write()
atomic_json_write(self.checkpoint_file, checkpoint_data)
else:
_atomic_write()
atomic_json_write(self.checkpoint_file, checkpoint_data)
def _scan_completed_prompts_by_content(self) -> set:
"""