Merge PR #297: Make batch_runner checkpoint incremental and atomic

Authored by aydnOktay. Three improvements to batch_runner fault tolerance:
1) Atomic checkpoint writes (temp file + fsync + os.replace) to prevent
   corruption on crashes — same pattern as auth.py's _save_auth_store().
2) Incremental checkpoints after each batch result instead of only at end,
   so interrupted runs can resume with minimal progress loss.
3) Resume loads existing checkpoint state instead of initializing empty,
   preventing clobber of prior progress.

Conflict resolved: kept both the incremental checkpoint logic (PR) and
the batch worker error handling (HEAD) in the imap_unordered loop.
This commit is contained in:
teknium1
2026-03-06 05:16:31 -08:00

View File

@@ -29,6 +29,7 @@ from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from multiprocessing import Pool, Lock
import traceback
import tempfile
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
from rich.console import Console
@@ -701,14 +702,33 @@ class BatchRunner:
lock (Lock): Optional lock for thread-safe access
"""
checkpoint_data["last_updated"] = datetime.now().isoformat()
def _atomic_write():
"""Write checkpoint atomically (temp file + replace) to avoid corruption on crash."""
self.checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
fd, tmp_path = tempfile.mkstemp(
dir=str(self.checkpoint_file.parent),
prefix='.checkpoint_',
suffix='.tmp',
)
try:
with os.fdopen(fd, 'w', encoding='utf-8') as f:
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
f.flush()
os.fsync(f.fileno())
os.replace(tmp_path, self.checkpoint_file)
except BaseException:
try:
os.unlink(tmp_path)
except OSError:
pass
raise
if lock:
with lock:
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
_atomic_write()
else:
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
_atomic_write()
def _scan_completed_prompts_by_content(self) -> set:
"""
@@ -833,13 +853,15 @@ class BatchRunner:
print(f" New batches created: {len(batches_to_process)}")
print("=" * 70 + "\n")
# Initialize checkpoint data (needed for saving at the end)
checkpoint_data = {
"run_name": self.run_name,
"completed_prompts": [],
"batch_stats": {},
"last_updated": None
}
# Load existing checkpoint (so resume doesn't clobber prior progress)
checkpoint_data = self._load_checkpoint()
if checkpoint_data.get("run_name") != self.run_name:
checkpoint_data = {
"run_name": self.run_name,
"completed_prompts": [],
"batch_stats": {},
"last_updated": None
}
# Prepare configuration for workers
config = {
@@ -861,7 +883,7 @@ class BatchRunner:
}
# For backward compatibility, still track by index (but this is secondary to content matching)
completed_prompts_set = set()
completed_prompts_set = set(checkpoint_data.get("completed_prompts", []))
# Aggregate statistics across all batches
total_tool_stats = {}
@@ -870,6 +892,9 @@ class BatchRunner:
print(f"\n🔧 Initializing {self.num_workers} worker processes...")
# Checkpoint writes happen in the parent process; keep a lock for safety.
checkpoint_lock = Lock()
# Process batches in parallel
with Pool(processes=self.num_workers) as pool:
# Create tasks for each batch
@@ -915,6 +940,25 @@ class BatchRunner:
for result in pool.imap_unordered(_process_batch_worker, tasks):
results.append(result)
progress.update(task, advance=1)
# Incremental checkpoint update (so resume works after crash)
try:
batch_num = result.get('batch_num')
completed = result.get('completed_prompts', []) or []
completed_prompts_set.update(completed)
if isinstance(batch_num, int):
checkpoint_data.setdefault('batch_stats', {})[str(batch_num)] = {
'processed': result.get('processed', 0),
'skipped': result.get('skipped', 0),
'discarded_no_reasoning': result.get('discarded_no_reasoning', 0),
}
checkpoint_data['completed_prompts'] = sorted(completed_prompts_set)
self._save_checkpoint(checkpoint_data, lock=checkpoint_lock)
except Exception as ckpt_err:
# Don't fail the run if checkpoint write fails
print(f"⚠️ Warning: Failed to save incremental checkpoint: {ckpt_err}")
except Exception as e:
logger.error("Batch worker failed: %s", e, exc_info=True)
raise
@@ -946,9 +990,12 @@ class BatchRunner:
for key in total_reasoning_stats:
total_reasoning_stats[key] += batch_result.get("reasoning_stats", {}).get(key, 0)
# Save final checkpoint
checkpoint_data["completed_prompts"] = all_completed_prompts
self._save_checkpoint(checkpoint_data)
# Save final checkpoint (best-effort; incremental writes already happened)
try:
checkpoint_data["completed_prompts"] = all_completed_prompts
self._save_checkpoint(checkpoint_data, lock=checkpoint_lock)
except Exception as ckpt_err:
print(f"⚠️ Warning: Failed to save final checkpoint: {ckpt_err}")
# Calculate success rates
for tool_name in total_tool_stats: