From b39aee90b47da70464cf3eef2d2637e60c47a617 Mon Sep 17 00:00:00 2001 From: Alexander Whitestone Date: Wed, 22 Apr 2026 11:33:03 -0400 Subject: [PATCH] fix(pipelines): resume checkpoints and atomically claim jobs (#621) Load checkpoints from the checkpoint table when resuming paused jobs and atomically claim pending work so parallel runs do not process the same job twice. --- pipelines/orchestrator.py | 116 +++++++++++++++++++------------------- 1 file changed, 58 insertions(+), 58 deletions(-) diff --git a/pipelines/orchestrator.py b/pipelines/orchestrator.py index 70550814..9e33bd93 100644 --- a/pipelines/orchestrator.py +++ b/pipelines/orchestrator.py @@ -235,6 +235,33 @@ class OrchestratorDB: conn = sqlite3.connect(str(self.db_path)) conn.row_factory = sqlite3.Row return conn + + def _job_from_row(self, row: sqlite3.Row) -> Job: + """Hydrate a Job from a DB row, loading checkpoints from either storage path.""" + checkpoint = None + if row['checkpoint']: + checkpoint = JobCheckpoint.from_dict(json.loads(row['checkpoint'])) + else: + checkpoint = self.get_checkpoint(row['id']) + + return Job( + id=row['id'], + pipeline=row['pipeline'], + task=json.loads(row['task']), + status=JobStatus(row['status']), + priority=JobPriority(row['priority']), + token_budget=row['token_budget'], + token_usage=TokenUsage.from_dict(json.loads(row['token_usage'])), + retry_count=row['retry_count'], + max_retries=row['max_retries'], + created_at=row['created_at'], + started_at=row['started_at'], + completed_at=row['completed_at'], + error=row['error'], + result=json.loads(row['result']) if row['result'] else None, + checkpoint=checkpoint, + metadata=json.loads(row['metadata']) if row['metadata'] else {} + ) def save_job(self, job: Job): """Save or update a job.""" @@ -265,24 +292,7 @@ class OrchestratorDB: if not row: return None - return Job( - id=row['id'], - pipeline=row['pipeline'], - task=json.loads(row['task']), - status=JobStatus(row['status']), - priority=JobPriority(row['priority']), - token_budget=row['token_budget'], - token_usage=TokenUsage.from_dict(json.loads(row['token_usage'])), - retry_count=row['retry_count'], - max_retries=row['max_retries'], - created_at=row['created_at'], - started_at=row['started_at'], - completed_at=row['completed_at'], - error=row['error'], - result=json.loads(row['result']) if row['result'] else None, - checkpoint=JobCheckpoint.from_dict(json.loads(row['checkpoint'])) if row['checkpoint'] else None, - metadata=json.loads(row['metadata']) if row['metadata'] else {} - ) + return self._job_from_row(row) def get_next_job(self, pipeline: Optional[str] = None) -> Optional[Job]: """Get next pending job (highest priority first).""" @@ -303,24 +313,34 @@ class OrchestratorDB: if not row: return None - return Job( - id=row['id'], - pipeline=row['pipeline'], - task=json.loads(row['task']), - status=JobStatus(row['status']), - priority=JobPriority(row['priority']), - token_budget=row['token_budget'], - token_usage=TokenUsage.from_dict(json.loads(row['token_usage'])), - retry_count=row['retry_count'], - max_retries=row['max_retries'], - created_at=row['created_at'], - started_at=row['started_at'], - completed_at=row['completed_at'], - error=row['error'], - result=json.loads(row['result']) if row['result'] else None, - checkpoint=JobCheckpoint.from_dict(json.loads(row['checkpoint'])) if row['checkpoint'] else None, - metadata=json.loads(row['metadata']) if row['metadata'] else {} - ) + return self._job_from_row(row) + + def claim_next_job(self, pipeline: Optional[str] = None) -> Optional[Job]: + """Atomically claim the next pending job for execution.""" + conn = self._get_conn() + try: + conn.execute("BEGIN IMMEDIATE") + query = "SELECT id FROM jobs WHERE status = 'pending'" + params = [] + if pipeline: + query += " AND pipeline = ?" + params.append(pipeline) + query += " ORDER BY priority DESC, created_at ASC LIMIT 1" + row = conn.execute(query, params).fetchone() + if not row: + conn.commit() + return None + + started_at = time.time() + conn.execute( + "UPDATE jobs SET status = ?, started_at = ? WHERE id = ?", + (JobStatus.RUNNING.value, started_at, row['id']) + ) + claimed = conn.execute("SELECT * FROM jobs WHERE id = ?", (row['id'],)).fetchone() + conn.commit() + return self._job_from_row(claimed) + finally: + conn.close() def get_jobs_by_status(self, status: JobStatus, pipeline: Optional[str] = None) -> List[Job]: """Get all jobs with given status.""" @@ -338,27 +358,7 @@ class OrchestratorDB: rows = conn.execute(query, params).fetchall() conn.close() - return [ - Job( - id=row['id'], - pipeline=row['pipeline'], - task=json.loads(row['task']), - status=JobStatus(row['status']), - priority=JobPriority(row['priority']), - token_budget=row['token_budget'], - token_usage=TokenUsage.from_dict(json.loads(row['token_usage'])), - retry_count=row['retry_count'], - max_retries=row['max_retries'], - created_at=row['created_at'], - started_at=row['started_at'], - completed_at=row['completed_at'], - error=row['error'], - result=json.loads(row['result']) if row['result'] else None, - checkpoint=JobCheckpoint.from_dict(json.loads(row['checkpoint'])) if row['checkpoint'] else None, - metadata=json.loads(row['metadata']) if row['metadata'] else {} - ) - for row in rows - ] + return [self._job_from_row(row) for row in rows] def save_checkpoint(self, job_id: str, checkpoint: JobCheckpoint): """Save a checkpoint for a job.""" @@ -612,7 +612,7 @@ class PipelineOrchestrator: continue # Get next job - job = self.db.get_next_job(pipeline) + job = self.db.claim_next_job(pipeline) if not job: if not futures: