Make batch_runner checkpoint incremental and atomic

This commit is contained in:
aydnOktay
2026-03-03 01:43:07 +03:00
parent 669e4d0297
commit ac6d747fa6

View File

@@ -29,6 +29,7 @@ from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from multiprocessing import Pool, Lock
import traceback
import tempfile
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
from rich.console import Console
@@ -649,14 +650,33 @@ class BatchRunner:
lock (Lock): Optional lock for thread-safe access
"""
checkpoint_data["last_updated"] = datetime.now().isoformat()
def _atomic_write():
"""Write checkpoint atomically (temp file + replace) to avoid corruption on crash."""
self.checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
fd, tmp_path = tempfile.mkstemp(
dir=str(self.checkpoint_file.parent),
prefix='.checkpoint_',
suffix='.tmp',
)
try:
with os.fdopen(fd, 'w', encoding='utf-8') as f:
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
f.flush()
os.fsync(f.fileno())
os.replace(tmp_path, self.checkpoint_file)
except BaseException:
try:
os.unlink(tmp_path)
except OSError:
pass
raise
if lock:
with lock:
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
_atomic_write()
else:
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
_atomic_write()
def _scan_completed_prompts_by_content(self) -> set:
"""
@@ -781,13 +801,15 @@ class BatchRunner:
print(f" New batches created: {len(batches_to_process)}")
print("=" * 70 + "\n")
# Initialize checkpoint data (needed for saving at the end)
checkpoint_data = {
"run_name": self.run_name,
"completed_prompts": [],
"batch_stats": {},
"last_updated": None
}
# Load existing checkpoint (so resume doesn't clobber prior progress)
checkpoint_data = self._load_checkpoint()
if checkpoint_data.get("run_name") != self.run_name:
checkpoint_data = {
"run_name": self.run_name,
"completed_prompts": [],
"batch_stats": {},
"last_updated": None
}
# Prepare configuration for workers
config = {
@@ -809,7 +831,7 @@ class BatchRunner:
}
# For backward compatibility, still track by index (but this is secondary to content matching)
completed_prompts_set = set()
completed_prompts_set = set(checkpoint_data.get("completed_prompts", []))
# Aggregate statistics across all batches
total_tool_stats = {}
@@ -818,6 +840,9 @@ class BatchRunner:
print(f"\n🔧 Initializing {self.num_workers} worker processes...")
# Checkpoint writes happen in the parent process; keep a lock for safety.
checkpoint_lock = Lock()
# Process batches in parallel
with Pool(processes=self.num_workers) as pool:
# Create tasks for each batch
@@ -863,6 +888,25 @@ class BatchRunner:
for result in pool.imap_unordered(_process_batch_worker, tasks):
results.append(result)
progress.update(task, advance=1)
# Incremental checkpoint update (so resume works after crash)
try:
batch_num = result.get('batch_num')
completed = result.get('completed_prompts', []) or []
completed_prompts_set.update(completed)
if isinstance(batch_num, int):
checkpoint_data.setdefault('batch_stats', {})[str(batch_num)] = {
'processed': result.get('processed', 0),
'skipped': result.get('skipped', 0),
'discarded_no_reasoning': result.get('discarded_no_reasoning', 0),
}
checkpoint_data['completed_prompts'] = sorted(completed_prompts_set)
self._save_checkpoint(checkpoint_data, lock=checkpoint_lock)
except Exception as ckpt_err:
# Don't fail the run if checkpoint write fails
print(f"⚠️ Warning: Failed to save incremental checkpoint: {ckpt_err}")
finally:
root_logger.setLevel(original_level)
@@ -891,9 +935,12 @@ class BatchRunner:
for key in total_reasoning_stats:
total_reasoning_stats[key] += batch_result.get("reasoning_stats", {}).get(key, 0)
# Save final checkpoint
checkpoint_data["completed_prompts"] = all_completed_prompts
self._save_checkpoint(checkpoint_data)
# Save final checkpoint (best-effort; incremental writes already happened)
try:
checkpoint_data["completed_prompts"] = all_completed_prompts
self._save_checkpoint(checkpoint_data, lock=checkpoint_lock)
except Exception as ckpt_err:
print(f"⚠️ Warning: Failed to save final checkpoint: {ckpt_err}")
# Calculate success rates
for tool_name in total_tool_stats: