Merge PR #297: Make batch_runner checkpoint incremental and atomic
Authored by aydnOktay. Three improvements to batch_runner fault tolerance: 1) Atomic checkpoint writes (temp file + fsync + os.replace) to prevent corruption on crashes — same pattern as auth.py's _save_auth_store(). 2) Incremental checkpoints after each batch result instead of only at end, so interrupted runs can resume with minimal progress loss. 3) Resume loads existing checkpoint state instead of initializing empty, preventing clobber of prior progress. Conflict resolved: kept both the incremental checkpoint logic (PR) and the batch worker error handling (HEAD) in the imap_unordered loop.
This commit is contained in:
@@ -29,6 +29,7 @@ from typing import List, Dict, Any, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from multiprocessing import Pool, Lock
|
||||
import traceback
|
||||
import tempfile
|
||||
|
||||
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
|
||||
from rich.console import Console
|
||||
@@ -701,14 +702,33 @@ class BatchRunner:
|
||||
lock (Lock): Optional lock for thread-safe access
|
||||
"""
|
||||
checkpoint_data["last_updated"] = datetime.now().isoformat()
|
||||
|
||||
|
||||
def _atomic_write():
|
||||
"""Write checkpoint atomically (temp file + replace) to avoid corruption on crash."""
|
||||
self.checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
fd, tmp_path = tempfile.mkstemp(
|
||||
dir=str(self.checkpoint_file.parent),
|
||||
prefix='.checkpoint_',
|
||||
suffix='.tmp',
|
||||
)
|
||||
try:
|
||||
with os.fdopen(fd, 'w', encoding='utf-8') as f:
|
||||
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
os.replace(tmp_path, self.checkpoint_file)
|
||||
except BaseException:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
|
||||
if lock:
|
||||
with lock:
|
||||
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
|
||||
_atomic_write()
|
||||
else:
|
||||
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
|
||||
_atomic_write()
|
||||
|
||||
def _scan_completed_prompts_by_content(self) -> set:
|
||||
"""
|
||||
@@ -833,13 +853,15 @@ class BatchRunner:
|
||||
print(f" New batches created: {len(batches_to_process)}")
|
||||
print("=" * 70 + "\n")
|
||||
|
||||
# Initialize checkpoint data (needed for saving at the end)
|
||||
checkpoint_data = {
|
||||
"run_name": self.run_name,
|
||||
"completed_prompts": [],
|
||||
"batch_stats": {},
|
||||
"last_updated": None
|
||||
}
|
||||
# Load existing checkpoint (so resume doesn't clobber prior progress)
|
||||
checkpoint_data = self._load_checkpoint()
|
||||
if checkpoint_data.get("run_name") != self.run_name:
|
||||
checkpoint_data = {
|
||||
"run_name": self.run_name,
|
||||
"completed_prompts": [],
|
||||
"batch_stats": {},
|
||||
"last_updated": None
|
||||
}
|
||||
|
||||
# Prepare configuration for workers
|
||||
config = {
|
||||
@@ -861,7 +883,7 @@ class BatchRunner:
|
||||
}
|
||||
|
||||
# For backward compatibility, still track by index (but this is secondary to content matching)
|
||||
completed_prompts_set = set()
|
||||
completed_prompts_set = set(checkpoint_data.get("completed_prompts", []))
|
||||
|
||||
# Aggregate statistics across all batches
|
||||
total_tool_stats = {}
|
||||
@@ -870,6 +892,9 @@ class BatchRunner:
|
||||
|
||||
print(f"\n🔧 Initializing {self.num_workers} worker processes...")
|
||||
|
||||
# Checkpoint writes happen in the parent process; keep a lock for safety.
|
||||
checkpoint_lock = Lock()
|
||||
|
||||
# Process batches in parallel
|
||||
with Pool(processes=self.num_workers) as pool:
|
||||
# Create tasks for each batch
|
||||
@@ -915,6 +940,25 @@ class BatchRunner:
|
||||
for result in pool.imap_unordered(_process_batch_worker, tasks):
|
||||
results.append(result)
|
||||
progress.update(task, advance=1)
|
||||
|
||||
# Incremental checkpoint update (so resume works after crash)
|
||||
try:
|
||||
batch_num = result.get('batch_num')
|
||||
completed = result.get('completed_prompts', []) or []
|
||||
completed_prompts_set.update(completed)
|
||||
|
||||
if isinstance(batch_num, int):
|
||||
checkpoint_data.setdefault('batch_stats', {})[str(batch_num)] = {
|
||||
'processed': result.get('processed', 0),
|
||||
'skipped': result.get('skipped', 0),
|
||||
'discarded_no_reasoning': result.get('discarded_no_reasoning', 0),
|
||||
}
|
||||
|
||||
checkpoint_data['completed_prompts'] = sorted(completed_prompts_set)
|
||||
self._save_checkpoint(checkpoint_data, lock=checkpoint_lock)
|
||||
except Exception as ckpt_err:
|
||||
# Don't fail the run if checkpoint write fails
|
||||
print(f"⚠️ Warning: Failed to save incremental checkpoint: {ckpt_err}")
|
||||
except Exception as e:
|
||||
logger.error("Batch worker failed: %s", e, exc_info=True)
|
||||
raise
|
||||
@@ -946,9 +990,12 @@ class BatchRunner:
|
||||
for key in total_reasoning_stats:
|
||||
total_reasoning_stats[key] += batch_result.get("reasoning_stats", {}).get(key, 0)
|
||||
|
||||
# Save final checkpoint
|
||||
checkpoint_data["completed_prompts"] = all_completed_prompts
|
||||
self._save_checkpoint(checkpoint_data)
|
||||
# Save final checkpoint (best-effort; incremental writes already happened)
|
||||
try:
|
||||
checkpoint_data["completed_prompts"] = all_completed_prompts
|
||||
self._save_checkpoint(checkpoint_data, lock=checkpoint_lock)
|
||||
except Exception as ckpt_err:
|
||||
print(f"âš ï¸ Warning: Failed to save final checkpoint: {ckpt_err}")
|
||||
|
||||
# Calculate success rates
|
||||
for tool_name in total_tool_stats:
|
||||
|
||||
Reference in New Issue
Block a user