#!/usr/bin/env python3 """ Parallel Agent System for Allegro-Primus Manages multiple concurrent worktrees and task distribution """ import json import multiprocessing import threading import queue import time from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed from dataclasses import dataclass, asdict from datetime import datetime from pathlib import Path from typing import Dict, List, Optional, Callable, Any, Set import hashlib try: from .worktree_manager import WorktreeManager from .task_runner import TaskRunner, TaskResult except ImportError: from worktree_manager import WorktreeManager from task_runner import TaskRunner, TaskResult @dataclass class ParallelTask: """A task to be executed in parallel""" task_id: str issue_number: str task_type: str # 'command', 'script', 'function' payload: Any # Command list, script string, or function dependencies: List[str] # Task IDs that must complete first priority: int = 0 timeout: int = 300 auto_commit: bool = False metadata: Dict = None def __post_init__(self): if self.metadata is None: self.metadata = {} @dataclass class ConflictInfo: """Information about a detected conflict""" task_id: str issue_number: str conflicting_files: List[str] conflict_type: str # 'merge', 'content', 'dependency' details: str timestamp: str class TaskQueue: """Priority queue for task distribution""" def __init__(self): self._queue = queue.PriorityQueue() self._tasks: Dict[str, ParallelTask] = {} self._completed: Set[str] = set() self._failed: Set[str] = set() self._lock = threading.Lock() def add(self, task: ParallelTask): """Add task to queue""" with self._lock: self._tasks[task.task_id] = task # Lower priority number = higher priority self._queue.put((task.priority, task.task_id)) def get(self) -> Optional[ParallelTask]: """Get next ready task (dependencies satisfied)""" while True: try: _, task_id = self._queue.get(timeout=0.1) task = self._tasks[task_id] # Check dependencies with self._lock: pending_deps = set(task.dependencies) - self._completed if pending_deps: # Re-queue with lower priority if dependencies not met self._queue.put((task.priority + 1, task_id)) continue return task except queue.Empty: return None def mark_completed(self, task_id: str): """Mark a task as completed""" with self._lock: self._completed.add(task_id) def mark_failed(self, task_id: str): """Mark a task as failed""" with self._lock: self._failed.add(task_id) def is_done(self) -> bool: """Check if all tasks are processed""" with self._lock: return self._completed.union(self._failed) == set(self._tasks.keys()) def get_stats(self) -> Dict: """Get queue statistics""" with self._lock: return { "total": len(self._tasks), "completed": len(self._completed), "failed": len(self._failed), "pending": len(self._tasks) - len(self._completed) - len(self._failed) } class ConflictDetector: """Detects conflicts between parallel worktrees""" def __init__(self, worktree_manager: WorktreeManager): self.wm = worktree_manager self._file_ownership: Dict[str, str] = {} # file -> task_id self._lock = threading.Lock() def register_files(self, task_id: str, files: List[str]): """Register files as being modified by a task""" with self._lock: for f in files: if f in self._file_ownership and self._file_ownership[f] != task_id: # Conflict detected return self._file_ownership[f] self._file_ownership[f] = task_id return None def check_overlap(self, issue1: str, issue2: str) -> List[str]: """Check for overlapping file changes between two issues""" wt1_status = self.wm.get_worktree_status(issue1) wt2_status = self.wm.get_worktree_status(issue2) if not wt1_status.get("exists") or not wt2_status.get("exists"): return [] files1 = set(f.split()[-1] for f in wt1_status.get("files_changed", []) if f) files2 = set(f.split()[-1] for f in wt2_status.get("files_changed", []) if f) return list(files1 & files2) def detect_conflicts(self, results: List[TaskResult]) -> List[ConflictInfo]: """Detect conflicts from a batch of results""" conflicts = [] # Group by files modified file_modifications: Dict[str, List[str]] = {} for result in results: if not result.success: continue # Get files modified in this task wt_status = self.wm.get_worktree_status(result.issue_number) for f in wt_status.get("files_changed", []): if f: file_path = f.split()[-1] if file_path not in file_modifications: file_modifications[file_path] = [] file_modifications[file_path].append(result.task_id) # Find conflicts (same file modified by multiple tasks) for file_path, tasks in file_modifications.items(): if len(tasks) > 1: conflicts.append(ConflictInfo( task_id=tasks[0], issue_number=results[0].issue_number if results else "", conflicting_files=[file_path], conflict_type="content", details=f"File {file_path} modified by tasks: {', '.join(tasks)}", timestamp=datetime.now().isoformat() )) return conflicts class ParallelAgent: """Manages parallel execution across multiple worktrees""" def __init__(self, max_workers: int = None, repo_root: str = None): self.max_workers = max_workers or multiprocessing.cpu_count() self.wm = WorktreeManager(repo_root) self.queue = TaskQueue() self.conflict_detector = ConflictDetector(self.wm) self.results: List[TaskResult] = [] self.conflicts: List[ConflictInfo] = [] self._lock = threading.Lock() self._worktrees: Dict[str, str] = {} # issue_number -> path def _prepare_worktree(self, issue_number: str, description: str = None) -> str: """Prepare a worktree for an issue""" existing = self.wm.get_worktree_for_issue(issue_number) if existing: return existing return self.wm.create_worktree(issue_number, description) def _execute_task(self, task: ParallelTask) -> TaskResult: """Execute a single task in its worktree""" # Ensure worktree exists if task.issue_number not in self._worktrees: self._worktrees[task.issue_number] = self._prepare_worktree( task.issue_number, task.metadata.get("description") ) worktree_path = self._worktrees[task.issue_number] runner = TaskRunner(worktree_path) # Execute based on type if task.task_type == "command": result = runner.run_command( task_id=task.task_id, command=task.payload, issue_number=task.issue_number, timeout=task.timeout ) elif task.task_type == "script": result = runner.run_script( task_id=task.task_id, script_content=task.payload, issue_number=task.issue_number, timeout=task.timeout ) elif task.task_type == "function": result = runner.run_python_function( task_id=task.task_id, func=task.payload, issue_number=task.issue_number, args=task.metadata.get("args", ()), kwargs=task.metadata.get("kwargs", {}) ) else: result = TaskResult( task_id=task.task_id, issue_number=task.issue_number, success=False, return_code=-1, stdout="", stderr=f"Unknown task type: {task.task_type}", execution_time=0, timestamp=datetime.now().isoformat(), artifacts=[], error="Invalid task type" ) # Auto-commit if successful and requested if result.success and task.auto_commit: commit_hash = runner.auto_commit( message=f"Auto-commit for {task.task_id}", allow_empty=False ) if commit_hash: result.commit_hash = commit_hash return result def add_task(self, task: ParallelTask): """Add a task to the queue""" self.queue.add(task) def run_parallel(self, tasks: List[ParallelTask] = None) -> Dict: """ Execute all queued tasks in parallel Returns: Dict with results, conflicts, and statistics """ if tasks: for task in tasks: self.add_task(task) # Group tasks by issue for worktree preparation issues = set() with self.queue._lock: for task in self.queue._tasks.values(): issues.add(task.issue_number) # Pre-create worktrees print(f"Preparing {len(issues)} worktrees...") for issue in issues: self._worktrees[issue] = self._prepare_worktree(issue) # Execute tasks in parallel with ThreadPoolExecutor(max_workers=self.max_workers) as executor: futures = {} active_tasks = {} # Submit initial batch while len(futures) < self.max_workers: task = self.queue.get() if not task: break future = executor.submit(self._execute_task, task) futures[future] = task.task_id active_tasks[task.task_id] = task # Process results and submit new tasks for future in as_completed(futures): task_id = futures[future] task = active_tasks[task_id] try: result = future.result() except Exception as e: result = TaskResult( task_id=task_id, issue_number=task.issue_number, success=False, return_code=-1, stdout="", stderr=str(e), execution_time=0, timestamp=datetime.now().isoformat(), artifacts=[], error=str(e) ) with self._lock: self.results.append(result) if result.success: self.queue.mark_completed(task_id) else: self.queue.mark_failed(task_id) # Submit next task next_task = self.queue.get() if next_task: new_future = executor.submit(self._execute_task, next_task) futures[new_future] = next_task.task_id active_tasks[next_task.task_id] = next_task # Detect conflicts self.conflicts = self.conflict_detector.detect_conflicts(self.results) return { "results": [r.to_dict() for r in self.results], "conflicts": [asdict(c) for c in self.conflicts], "stats": { "total_tasks": len(self.results), "successful": sum(1 for r in self.results if r.success), "failed": sum(1 for r in self.results if not r.success), "conflicts": len(self.conflicts), "worktrees_used": len(self._worktrees) } } def generate_report(self, output_path: str = None) -> str: """Generate execution report""" report = { "timestamp": datetime.now().isoformat(), "max_workers": self.max_workers, "worktrees": list(self._worktrees.keys()), "results": [r.to_dict() for r in self.results], "conflicts": [asdict(c) for c in self.conflicts], "stats": { "total_tasks": len(self.results), "successful": sum(1 for r in self.results if r.success), "failed": sum(1 for r in self.results if not r.success), "conflicts": len(self.conflicts) } } report_json = json.dumps(report, indent=2) if output_path: Path(output_path).write_text(report_json) return report_json class IssueParallelizer: """High-level interface for working on multiple issues in parallel""" def __init__(self, repo_root: str = None): self.agent = ParallelAgent(repo_root=repo_root) self.issues: Dict[str, Dict] = {} def register_issue(self, issue_number: str, description: str = None): """Register an issue for parallel work""" self.issues[issue_number] = { "description": description, "tasks": [] } def add_task_to_issue(self, issue_number: str, command: List[str], task_id: str = None, **kwargs): """Add a command task to an issue""" if issue_number not in self.issues: self.register_issue(issue_number) task_id = task_id or f"{issue_number}-{len(self.issues[issue_number]['tasks'])}" task = ParallelTask( task_id=task_id, issue_number=issue_number, task_type="command", payload=command, dependencies=[], metadata={"description": self.issues[issue_number]["description"]}, **kwargs ) self.agent.add_task(task) def run(self) -> Dict: """Execute all registered tasks in parallel""" return self.agent.run_parallel() if __name__ == "__main__": # Example usage print("ParallelAgent loaded successfully") print(f"Max workers: {multiprocessing.cpu_count()}")