From ac6d747fa6105a8d54a4a4c6d7013369497959a6 Mon Sep 17 00:00:00 2001 From: aydnOktay Date: Tue, 3 Mar 2026 01:43:07 +0300 Subject: [PATCH] Make batch_runner checkpoint incremental and atomic --- batch_runner.py | 79 +++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 63 insertions(+), 16 deletions(-) diff --git a/batch_runner.py b/batch_runner.py index 54a1a585..894f7c09 100644 --- a/batch_runner.py +++ b/batch_runner.py @@ -29,6 +29,7 @@ from typing import List, Dict, Any, Optional, Tuple from datetime import datetime from multiprocessing import Pool, Lock import traceback +import tempfile from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn from rich.console import Console @@ -649,14 +650,33 @@ class BatchRunner: lock (Lock): Optional lock for thread-safe access """ checkpoint_data["last_updated"] = datetime.now().isoformat() - + + def _atomic_write(): + """Write checkpoint atomically (temp file + replace) to avoid corruption on crash.""" + self.checkpoint_file.parent.mkdir(parents=True, exist_ok=True) + fd, tmp_path = tempfile.mkstemp( + dir=str(self.checkpoint_file.parent), + prefix='.checkpoint_', + suffix='.tmp', + ) + try: + with os.fdopen(fd, 'w', encoding='utf-8') as f: + json.dump(checkpoint_data, f, indent=2, ensure_ascii=False) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp_path, self.checkpoint_file) + except BaseException: + try: + os.unlink(tmp_path) + except OSError: + pass + raise + if lock: with lock: - with open(self.checkpoint_file, 'w', encoding='utf-8') as f: - json.dump(checkpoint_data, f, indent=2, ensure_ascii=False) + _atomic_write() else: - with open(self.checkpoint_file, 'w', encoding='utf-8') as f: - json.dump(checkpoint_data, f, indent=2, ensure_ascii=False) + _atomic_write() def _scan_completed_prompts_by_content(self) -> set: """ @@ -781,13 +801,15 @@ class BatchRunner: print(f" New batches created: {len(batches_to_process)}") print("=" * 70 + "\n") - # Initialize checkpoint data (needed for saving at the end) - checkpoint_data = { - "run_name": self.run_name, - "completed_prompts": [], - "batch_stats": {}, - "last_updated": None - } + # Load existing checkpoint (so resume doesn't clobber prior progress) + checkpoint_data = self._load_checkpoint() + if checkpoint_data.get("run_name") != self.run_name: + checkpoint_data = { + "run_name": self.run_name, + "completed_prompts": [], + "batch_stats": {}, + "last_updated": None + } # Prepare configuration for workers config = { @@ -809,7 +831,7 @@ class BatchRunner: } # For backward compatibility, still track by index (but this is secondary to content matching) - completed_prompts_set = set() + completed_prompts_set = set(checkpoint_data.get("completed_prompts", [])) # Aggregate statistics across all batches total_tool_stats = {} @@ -818,6 +840,9 @@ class BatchRunner: print(f"\n🔧 Initializing {self.num_workers} worker processes...") + # Checkpoint writes happen in the parent process; keep a lock for safety. + checkpoint_lock = Lock() + # Process batches in parallel with Pool(processes=self.num_workers) as pool: # Create tasks for each batch @@ -863,6 +888,25 @@ class BatchRunner: for result in pool.imap_unordered(_process_batch_worker, tasks): results.append(result) progress.update(task, advance=1) + + # Incremental checkpoint update (so resume works after crash) + try: + batch_num = result.get('batch_num') + completed = result.get('completed_prompts', []) or [] + completed_prompts_set.update(completed) + + if isinstance(batch_num, int): + checkpoint_data.setdefault('batch_stats', {})[str(batch_num)] = { + 'processed': result.get('processed', 0), + 'skipped': result.get('skipped', 0), + 'discarded_no_reasoning': result.get('discarded_no_reasoning', 0), + } + + checkpoint_data['completed_prompts'] = sorted(completed_prompts_set) + self._save_checkpoint(checkpoint_data, lock=checkpoint_lock) + except Exception as ckpt_err: + # Don't fail the run if checkpoint write fails + print(f"⚠️ Warning: Failed to save incremental checkpoint: {ckpt_err}") finally: root_logger.setLevel(original_level) @@ -891,9 +935,12 @@ class BatchRunner: for key in total_reasoning_stats: total_reasoning_stats[key] += batch_result.get("reasoning_stats", {}).get(key, 0) - # Save final checkpoint - checkpoint_data["completed_prompts"] = all_completed_prompts - self._save_checkpoint(checkpoint_data) + # Save final checkpoint (best-effort; incremental writes already happened) + try: + checkpoint_data["completed_prompts"] = all_completed_prompts + self._save_checkpoint(checkpoint_data, lock=checkpoint_lock) + except Exception as ckpt_err: + print(f"⚠️ Warning: Failed to save final checkpoint: {ckpt_err}") # Calculate success rates for tool_name in total_tool_stats: