#!/usr/bin/env python3
|
|
"""
|
|
Parallel Agent System for Allegro-Primus
|
|
Manages multiple concurrent worktrees and task distribution
|
|
"""
|
|
|
|
import json
|
|
import multiprocessing
|
|
import threading
|
|
import queue
|
|
import time
|
|
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
|
from dataclasses import dataclass, asdict
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Callable, Any, Set
|
|
import hashlib
|
|
|
|
try:
|
|
from .worktree_manager import WorktreeManager
|
|
from .task_runner import TaskRunner, TaskResult
|
|
except ImportError:
|
|
from worktree_manager import WorktreeManager
|
|
from task_runner import TaskRunner, TaskResult
|
|
|
|
|
|
@dataclass
class ParallelTask:
    """A unit of work to be executed inside an issue's worktree.

    Attributes:
        task_id: Unique identifier for this task.
        issue_number: Issue whose worktree the task runs in.
        task_type: One of 'command', 'script', or 'function' — selects how
            ``payload`` is interpreted by the task runner.
        payload: Command argv list, script text, or a callable, depending
            on ``task_type``.
        dependencies: Task IDs that must complete before this task runs.
        priority: Lower number = scheduled earlier.
        timeout: Per-task timeout in seconds.
        auto_commit: Commit worktree changes automatically on success.
        metadata: Free-form extra data (e.g. 'description', 'args', 'kwargs').
    """
    task_id: str
    issue_number: str
    task_type: str  # 'command', 'script', 'function'
    payload: Any  # Command list, script string, or function
    dependencies: List[str]  # Task IDs that must complete first
    priority: int = 0
    timeout: int = 300
    auto_commit: bool = False
    # Fix: was annotated as plain ``Dict`` with a None default; Optional
    # reflects the value actually accepted, and __post_init__ normalizes it.
    metadata: Optional[Dict] = None

    def __post_init__(self):
        # Avoid the shared-mutable-default pitfall: each instance gets its
        # own fresh dict when the caller passes nothing.
        if self.metadata is None:
            self.metadata = {}
|
|
|
|
|
|
@dataclass
class ConflictInfo:
    """Information about a detected conflict between parallel tasks."""

    task_id: str  # Task the conflict is attributed to
    issue_number: str  # Issue associated with that task
    conflicting_files: List[str]  # Paths touched by more than one task
    conflict_type: str  # 'merge', 'content', 'dependency'
    details: str  # Human-readable description of the conflict
    timestamp: str  # ISO-8601 time the conflict was recorded
|
|
|
|
|
|
class TaskQueue:
    """Thread-safe priority queue that releases tasks only once their
    dependencies have completed.

    Lower ``priority`` numbers are dequeued first.  Tasks whose
    dependencies have FAILED are marked failed and skipped — the original
    implementation re-queued them unconditionally, so ``get()`` spun
    forever on a task whose dependency could never complete.
    """

    def __init__(self):
        self._queue = queue.PriorityQueue()
        self._tasks: Dict[str, "ParallelTask"] = {}
        self._completed: Set[str] = set()
        self._failed: Set[str] = set()
        self._lock = threading.Lock()

    def add(self, task: "ParallelTask"):
        """Register *task* and enqueue it by priority."""
        with self._lock:
            self._tasks[task.task_id] = task
            # Lower priority number = higher priority.
            self._queue.put((task.priority, task.task_id))

    def get(self) -> Optional["ParallelTask"]:
        """Return the next ready task (all dependencies completed).

        Returns None when nothing becomes ready within the polling window.
        A task with a failed dependency can never run, so it is marked
        failed instead of being re-queued (fixes an infinite busy loop).
        """
        while True:
            try:
                _, task_id = self._queue.get(timeout=0.1)
            except queue.Empty:
                return None

            with self._lock:
                # Fix: task lookup now happens under the lock.
                task = self._tasks[task_id]
                deps = set(task.dependencies)
                if deps & self._failed:
                    # A dependency failed permanently — propagate failure.
                    self._failed.add(task_id)
                    continue
                pending = deps - self._completed
                if pending:
                    # Not ready yet: re-queue (slightly demoted) and keep
                    # looking for a runnable task.
                    self._queue.put((task.priority + 1, task_id))
                    continue

            return task

    def mark_completed(self, task_id: str):
        """Record *task_id* as successfully completed."""
        with self._lock:
            self._completed.add(task_id)

    def mark_failed(self, task_id: str):
        """Record *task_id* as failed."""
        with self._lock:
            self._failed.add(task_id)

    def is_done(self) -> bool:
        """True when every registered task is either completed or failed."""
        with self._lock:
            return self._completed.union(self._failed) == set(self._tasks.keys())

    def get_stats(self) -> Dict:
        """Return counts of total / completed / failed / pending tasks."""
        with self._lock:
            done = len(self._completed)
            failed = len(self._failed)
            return {
                "total": len(self._tasks),
                "completed": done,
                "failed": failed,
                "pending": len(self._tasks) - done - failed
            }
|
|
|
|
|
|
class ConflictDetector:
    """Detects file-level conflicts between parallel worktrees."""

    def __init__(self, worktree_manager: "WorktreeManager"):
        self.wm = worktree_manager
        self._file_ownership: Dict[str, str] = {}  # file -> owning task_id
        self._lock = threading.Lock()

    def register_files(self, task_id: str, files: List[str]):
        """Atomically claim *files* for *task_id*.

        Returns the owning task_id if any file is already claimed by a
        different task, otherwise None.  Fix: the original recorded
        ownership file-by-file and bailed on the first clash, leaving
        partially registered state; now nothing is recorded on conflict.
        """
        with self._lock:
            for f in files:
                owner = self._file_ownership.get(f)
                if owner is not None and owner != task_id:
                    return owner
            for f in files:
                self._file_ownership[f] = task_id
            return None

    def check_overlap(self, issue1: str, issue2: str) -> List[str]:
        """Return the files changed in both issues' worktrees."""
        wt1_status = self.wm.get_worktree_status(issue1)
        wt2_status = self.wm.get_worktree_status(issue2)

        # No overlap possible unless both worktrees exist.
        if not wt1_status.get("exists") or not wt2_status.get("exists"):
            return []

        # Status entries look like "<flags> <path>"; keep the path part.
        files1 = set(f.split()[-1] for f in wt1_status.get("files_changed", []) if f)
        files2 = set(f.split()[-1] for f in wt2_status.get("files_changed", []) if f)

        return list(files1 & files2)

    def detect_conflicts(self, results: List["TaskResult"]) -> List["ConflictInfo"]:
        """Detect content conflicts across a batch of successful results.

        A conflict is any file reported as changed by the worktrees of
        more than one successful task.
        """
        conflicts = []
        # Map each task back to its own issue so conflicts are attributed
        # correctly (fix: the original always used results[0].issue_number).
        issue_by_task = {r.task_id: r.issue_number for r in results}

        # file path -> list of task_ids that modified it
        file_modifications: Dict[str, List[str]] = {}

        for result in results:
            if not result.success:
                continue
            wt_status = self.wm.get_worktree_status(result.issue_number)
            for f in wt_status.get("files_changed", []):
                if f:
                    file_path = f.split()[-1]
                    file_modifications.setdefault(file_path, []).append(result.task_id)

        # Same file modified by multiple tasks => content conflict.
        for file_path, tasks in file_modifications.items():
            if len(tasks) > 1:
                conflicts.append(ConflictInfo(
                    task_id=tasks[0],
                    issue_number=issue_by_task.get(tasks[0], ""),
                    conflicting_files=[file_path],
                    conflict_type="content",
                    details=f"File {file_path} modified by tasks: {', '.join(tasks)}",
                    timestamp=datetime.now().isoformat()
                ))

        return conflicts
|
|
|
|
|
|
class ParallelAgent:
    """Coordinates parallel task execution across multiple git worktrees.

    One worktree is maintained per issue; tasks are drawn from a
    dependency-aware priority queue and executed on a thread pool.
    """

    def __init__(self, max_workers: int = None, repo_root: str = None):
        """
        Args:
            max_workers: Thread-pool size; defaults to the CPU count.
            repo_root: Repository root forwarded to WorktreeManager.
        """
        self.max_workers = max_workers or multiprocessing.cpu_count()
        self.wm = WorktreeManager(repo_root)
        self.queue = TaskQueue()
        self.conflict_detector = ConflictDetector(self.wm)
        self.results: List["TaskResult"] = []
        self.conflicts: List["ConflictInfo"] = []
        self._lock = threading.Lock()
        self._worktrees: Dict[str, str] = {}  # issue_number -> path

    def _prepare_worktree(self, issue_number: str, description: str = None) -> str:
        """Return an existing worktree path for the issue or create one."""
        existing = self.wm.get_worktree_for_issue(issue_number)
        if existing:
            return existing
        return self.wm.create_worktree(issue_number, description)

    def _execute_task(self, task: "ParallelTask") -> "TaskResult":
        """Execute a single task in its worktree (runs on pool threads)."""
        # Ensure the worktree exists.  The map is shared between pool
        # threads, so access is guarded by the agent lock (the original
        # mutated it unsynchronized).
        with self._lock:
            worktree_path = self._worktrees.get(task.issue_number)
        if worktree_path is None:
            worktree_path = self._prepare_worktree(
                task.issue_number,
                task.metadata.get("description")
            )
            with self._lock:
                self._worktrees[task.issue_number] = worktree_path

        runner = TaskRunner(worktree_path)

        # Dispatch on the declared task type.
        if task.task_type == "command":
            result = runner.run_command(
                task_id=task.task_id,
                command=task.payload,
                issue_number=task.issue_number,
                timeout=task.timeout
            )
        elif task.task_type == "script":
            result = runner.run_script(
                task_id=task.task_id,
                script_content=task.payload,
                issue_number=task.issue_number,
                timeout=task.timeout
            )
        elif task.task_type == "function":
            result = runner.run_python_function(
                task_id=task.task_id,
                func=task.payload,
                issue_number=task.issue_number,
                args=task.metadata.get("args", ()),
                kwargs=task.metadata.get("kwargs", {})
            )
        else:
            # Unknown type: synthesize a failed result rather than raise.
            result = TaskResult(
                task_id=task.task_id,
                issue_number=task.issue_number,
                success=False,
                return_code=-1,
                stdout="",
                stderr=f"Unknown task type: {task.task_type}",
                execution_time=0,
                timestamp=datetime.now().isoformat(),
                artifacts=[],
                error="Invalid task type"
            )

        # Auto-commit if successful and requested.
        if result.success and task.auto_commit:
            commit_hash = runner.auto_commit(
                message=f"Auto-commit for {task.task_id}",
                allow_empty=False
            )
            if commit_hash:
                result.commit_hash = commit_hash

        return result

    def add_task(self, task: "ParallelTask"):
        """Queue a task for execution."""
        self.queue.add(task)

    def run_parallel(self, tasks: List["ParallelTask"] = None) -> Dict:
        """
        Execute all queued tasks in parallel.

        Args:
            tasks: Optional extra tasks to queue before running.

        Returns:
            Dict with results, conflicts, and statistics.
        """
        # Local import: these names are only needed inside this method.
        from concurrent.futures import FIRST_COMPLETED, wait

        if tasks:
            for task in tasks:
                self.add_task(task)

        # Collect the distinct issues so worktrees can be created up
        # front.  NOTE(review): reaches into TaskQueue internals; a
        # public accessor on TaskQueue would be cleaner.
        with self.queue._lock:
            issues = {t.issue_number for t in self.queue._tasks.values()}

        # Pre-create worktrees.
        print(f"Preparing {len(issues)} worktrees...")
        for issue in issues:
            self._worktrees[issue] = self._prepare_worktree(issue)

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = {}       # future -> task_id
            active_tasks = {}  # task_id -> ParallelTask

            def submit_next() -> bool:
                """Submit the next ready task; False when none is ready."""
                nxt = self.queue.get()
                if nxt is None:
                    return False
                fut = executor.submit(self._execute_task, nxt)
                futures[fut] = nxt.task_id
                active_tasks[nxt.task_id] = nxt
                return True

            # Prime the pool with an initial batch.
            while len(futures) < self.max_workers and submit_next():
                pass

            # Bug fix: the original iterated ``as_completed(futures)``,
            # which snapshots only the initially submitted futures, so
            # any task submitted during the loop was never collected —
            # its result was dropped and it was never marked completed
            # or failed (wedging dependents).  A wait() drain loop
            # collects every future, including late submissions.
            while futures:
                done, _ = wait(list(futures), return_when=FIRST_COMPLETED)
                for future in done:
                    task_id = futures.pop(future)
                    task = active_tasks.pop(task_id)

                    try:
                        result = future.result()
                    except Exception as e:
                        # A crash inside the runner becomes a failed
                        # TaskResult instead of aborting the whole batch.
                        result = TaskResult(
                            task_id=task_id,
                            issue_number=task.issue_number,
                            success=False,
                            return_code=-1,
                            stdout="",
                            stderr=str(e),
                            execution_time=0,
                            timestamp=datetime.now().isoformat(),
                            artifacts=[],
                            error=str(e)
                        )

                    with self._lock:
                        self.results.append(result)

                    if result.success:
                        self.queue.mark_completed(task_id)
                    else:
                        self.queue.mark_failed(task_id)

                    # Backfill the freed worker slot.
                    submit_next()

        # Detect conflicts across everything that ran.
        self.conflicts = self.conflict_detector.detect_conflicts(self.results)

        return {
            "results": [r.to_dict() for r in self.results],
            "conflicts": [asdict(c) for c in self.conflicts],
            "stats": {
                "total_tasks": len(self.results),
                "successful": sum(1 for r in self.results if r.success),
                "failed": sum(1 for r in self.results if not r.success),
                "conflicts": len(self.conflicts),
                "worktrees_used": len(self._worktrees)
            }
        }

    def generate_report(self, output_path: str = None) -> str:
        """Return a JSON execution report; optionally write it to a file.

        Args:
            output_path: When given, the report is also written there.

        Returns:
            The report as a JSON string.
        """
        report = {
            "timestamp": datetime.now().isoformat(),
            "max_workers": self.max_workers,
            "worktrees": list(self._worktrees.keys()),
            "results": [r.to_dict() for r in self.results],
            "conflicts": [asdict(c) for c in self.conflicts],
            "stats": {
                "total_tasks": len(self.results),
                "successful": sum(1 for r in self.results if r.success),
                "failed": sum(1 for r in self.results if not r.success),
                "conflicts": len(self.conflicts)
            }
        }

        report_json = json.dumps(report, indent=2)

        if output_path:
            Path(output_path).write_text(report_json)

        return report_json
|
|
|
|
|
|
class IssueParallelizer:
    """High-level interface for working on multiple issues in parallel."""

    def __init__(self, repo_root: str = None):
        self.agent = ParallelAgent(repo_root=repo_root)
        self.issues: Dict[str, Dict] = {}  # issue_number -> {description, tasks}

    def register_issue(self, issue_number: str, description: str = None):
        """Register an issue for parallel work (resets its task list)."""
        self.issues[issue_number] = {
            "description": description,
            "tasks": []
        }

    def add_task_to_issue(self, issue_number: str, command: List[str],
                          task_id: str = None, **kwargs):
        """Add a command task to an issue.

        When *task_id* is omitted one is generated as
        ``"<issue_number>-<index>"``.  Extra keyword args are forwarded
        to ParallelTask (e.g. priority, timeout, auto_commit).
        """
        if issue_number not in self.issues:
            self.register_issue(issue_number)

        task_id = task_id or f"{issue_number}-{len(self.issues[issue_number]['tasks'])}"

        task = ParallelTask(
            task_id=task_id,
            issue_number=issue_number,
            task_type="command",
            payload=command,
            dependencies=[],
            metadata={"description": self.issues[issue_number]["description"]},
            **kwargs
        )

        # Bug fix: the per-issue task list was never appended to, so
        # every auto-generated task_id collapsed to "<issue>-0" and
        # collided in the queue.  Record the id so the index advances.
        self.issues[issue_number]["tasks"].append(task_id)

        self.agent.add_task(task)

    def run(self) -> Dict:
        """Execute all registered tasks in parallel."""
        return self.agent.run_parallel()
|
|
|
|
|
|
if __name__ == "__main__":
    # Example usage / smoke check: confirm the module imports cleanly and
    # show the default worker count a ParallelAgent would use.
    print("ParallelAgent loaded successfully")
    print(f"Max workers: {multiprocessing.cpu_count()}")
|