#!/usr/bin/env python3
"""
Parallel Agent System for Allegro-Primus
Manages multiple concurrent worktrees and task distribution
"""
import json
import multiprocessing
import threading
import queue
import time
from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED
from dataclasses import dataclass, asdict
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Callable, Any, Set
import hashlib
try:
    from .worktree_manager import WorktreeManager
    from .task_runner import TaskRunner, TaskResult
except ImportError:
    from worktree_manager import WorktreeManager
    from task_runner import TaskRunner, TaskResult


@dataclass
class ParallelTask:
    """A task to be executed in parallel"""
    task_id: str
    issue_number: str
    task_type: str  # 'command', 'script', 'function'
    payload: Any  # Command list, script string, or function
    dependencies: List[str]  # Task IDs that must complete first
    priority: int = 0
    timeout: int = 300
    auto_commit: bool = False
    metadata: Dict = None

    def __post_init__(self):
        if self.metadata is None:
            self.metadata = {}
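

# A minimal construction sketch (the issue number, task IDs, and command below
# are hypothetical): a "command" task that must wait for a prior setup task
# before it runs, and commits its changes automatically on success.
#
#   build_task = ParallelTask(
#       task_id="42-build",
#       issue_number="42",
#       task_type="command",
#       payload=["make", "build"],
#       dependencies=["42-setup"],
#       priority=1,
#       auto_commit=True,
#   )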


@dataclass
class ConflictInfo:
    """Information about a detected conflict"""
    task_id: str
    issue_number: str
    conflicting_files: List[str]
    conflict_type: str  # 'merge', 'content', 'dependency'
    details: str
    timestamp: str


class TaskQueue:
    """Priority queue for task distribution"""

    def __init__(self):
        self._queue = queue.PriorityQueue()
        self._tasks: Dict[str, ParallelTask] = {}
        self._completed: Set[str] = set()
        self._failed: Set[str] = set()
        self._lock = threading.Lock()

    def add(self, task: ParallelTask):
        """Add task to queue"""
        with self._lock:
            self._tasks[task.task_id] = task
            # Lower priority number = higher priority
            self._queue.put((task.priority, task.task_id))

    def get(self) -> Optional[ParallelTask]:
        """Get next ready task (dependencies satisfied), or None if none is ready"""
        deferred = []
        task = None
        while task is None:
            try:
                _, task_id = self._queue.get(timeout=0.1)
            except queue.Empty:
                break
            candidate = self._tasks[task_id]
            with self._lock:
                pending_deps = set(candidate.dependencies) - self._completed
                failed_deps = set(candidate.dependencies) & self._failed
            if failed_deps:
                # A dependency failed permanently, so this task can never run
                self.mark_failed(task_id)
                continue
            if pending_deps:
                # Defer with lowered priority until dependencies finish,
                # instead of busy-looping on the same queue entry
                deferred.append((candidate.priority + 1, task_id))
                continue
            task = candidate
        # Put deferred tasks back so a later call can pick them up
        for item in deferred:
            self._queue.put(item)
        return task

    def mark_completed(self, task_id: str):
        """Mark a task as completed"""
        with self._lock:
            self._completed.add(task_id)

    def mark_failed(self, task_id: str):
        """Mark a task as failed"""
        with self._lock:
            self._failed.add(task_id)

    def is_done(self) -> bool:
        """Check if all tasks are processed"""
        with self._lock:
            return self._completed.union(self._failed) == set(self._tasks.keys())

    def get_stats(self) -> Dict:
        """Get queue statistics"""
        with self._lock:
            return {
                "total": len(self._tasks),
                "completed": len(self._completed),
                "failed": len(self._failed),
                "pending": len(self._tasks) - len(self._completed) - len(self._failed)
            }


class ConflictDetector:
    """Detects conflicts between parallel worktrees"""

    def __init__(self, worktree_manager: WorktreeManager):
        self.wm = worktree_manager
        self._file_ownership: Dict[str, str] = {}  # file -> task_id
        self._lock = threading.Lock()

    def register_files(self, task_id: str, files: List[str]):
        """Register files as being modified by a task"""
        with self._lock:
            for f in files:
                if f in self._file_ownership and self._file_ownership[f] != task_id:
                    # Conflict detected
                    return self._file_ownership[f]
                self._file_ownership[f] = task_id
        return None

    def check_overlap(self, issue1: str, issue2: str) -> List[str]:
        """Check for overlapping file changes between two issues"""
        wt1_status = self.wm.get_worktree_status(issue1)
        wt2_status = self.wm.get_worktree_status(issue2)
        if not wt1_status.get("exists") or not wt2_status.get("exists"):
            return []
        files1 = set(f.split()[-1] for f in wt1_status.get("files_changed", []) if f)
        files2 = set(f.split()[-1] for f in wt2_status.get("files_changed", []) if f)
        return list(files1 & files2)

    def detect_conflicts(self, results: List[TaskResult]) -> List[ConflictInfo]:
        """Detect conflicts from a batch of results"""
        conflicts = []
        # Group by files modified
        file_modifications: Dict[str, List[str]] = {}
        for result in results:
            if not result.success:
                continue
            # Get files modified in this task
            wt_status = self.wm.get_worktree_status(result.issue_number)
            for f in wt_status.get("files_changed", []):
                if f:
                    file_path = f.split()[-1]
                    if file_path not in file_modifications:
                        file_modifications[file_path] = []
                    file_modifications[file_path].append(result.task_id)
        # Find conflicts (same file modified by multiple tasks)
        for file_path, tasks in file_modifications.items():
            if len(tasks) > 1:
                conflicts.append(ConflictInfo(
                    task_id=tasks[0],
                    issue_number=results[0].issue_number if results else "",
                    conflicting_files=[file_path],
                    conflict_type="content",
                    details=f"File {file_path} modified by tasks: {', '.join(tasks)}",
                    timestamp=datetime.now().isoformat()
                ))
        return conflicts


class ParallelAgent:
    """Manages parallel execution across multiple worktrees"""

    def __init__(self, max_workers: int = None, repo_root: str = None):
        self.max_workers = max_workers or multiprocessing.cpu_count()
        self.wm = WorktreeManager(repo_root)
        self.queue = TaskQueue()
        self.conflict_detector = ConflictDetector(self.wm)
        self.results: List[TaskResult] = []
        self.conflicts: List[ConflictInfo] = []
        self._lock = threading.Lock()
        self._worktrees: Dict[str, str] = {}  # issue_number -> path

    def _prepare_worktree(self, issue_number: str, description: str = None) -> str:
        """Prepare a worktree for an issue"""
        existing = self.wm.get_worktree_for_issue(issue_number)
        if existing:
            return existing
        return self.wm.create_worktree(issue_number, description)

    def _execute_task(self, task: ParallelTask) -> TaskResult:
        """Execute a single task in its worktree"""
        # Ensure worktree exists
        if task.issue_number not in self._worktrees:
            self._worktrees[task.issue_number] = self._prepare_worktree(
                task.issue_number,
                task.metadata.get("description")
            )
        worktree_path = self._worktrees[task.issue_number]
        runner = TaskRunner(worktree_path)
        # Execute based on type
        if task.task_type == "command":
            result = runner.run_command(
                task_id=task.task_id,
                command=task.payload,
                issue_number=task.issue_number,
                timeout=task.timeout
            )
        elif task.task_type == "script":
            result = runner.run_script(
                task_id=task.task_id,
                script_content=task.payload,
                issue_number=task.issue_number,
                timeout=task.timeout
            )
        elif task.task_type == "function":
            result = runner.run_python_function(
                task_id=task.task_id,
                func=task.payload,
                issue_number=task.issue_number,
                args=task.metadata.get("args", ()),
                kwargs=task.metadata.get("kwargs", {})
            )
        else:
            result = TaskResult(
                task_id=task.task_id,
                issue_number=task.issue_number,
                success=False,
                return_code=-1,
                stdout="",
                stderr=f"Unknown task type: {task.task_type}",
                execution_time=0,
                timestamp=datetime.now().isoformat(),
                artifacts=[],
                error="Invalid task type"
            )
        # Auto-commit if successful and requested
        if result.success and task.auto_commit:
            commit_hash = runner.auto_commit(
                message=f"Auto-commit for {task.task_id}",
                allow_empty=False
            )
            if commit_hash:
                result.commit_hash = commit_hash
        return result

    def add_task(self, task: ParallelTask):
        """Add a task to the queue"""
        self.queue.add(task)

    def run_parallel(self, tasks: List[ParallelTask] = None) -> Dict:
        """
        Execute all queued tasks in parallel

        Returns:
            Dict with results, conflicts, and statistics
        """
        if tasks:
            for task in tasks:
                self.add_task(task)
        # Group tasks by issue for worktree preparation
        issues = set()
        with self.queue._lock:
            for task in self.queue._tasks.values():
                issues.add(task.issue_number)
        # Pre-create worktrees
        print(f"Preparing {len(issues)} worktrees...")
        for issue in issues:
            self._worktrees[issue] = self._prepare_worktree(issue)
        # Execute tasks in parallel
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = {}
            active_tasks = {}
            # Submit initial batch
            while len(futures) < self.max_workers:
                task = self.queue.get()
                if not task:
                    break
                future = executor.submit(self._execute_task, task)
                futures[future] = task.task_id
                active_tasks[task.task_id] = task
            # Process results and keep submitting new tasks; wait() is
            # re-evaluated each round, so futures submitted mid-run are also
            # collected (as_completed() would only see the initial batch)
            while futures:
                done, _ = wait(futures, return_when=FIRST_COMPLETED)
                for future in done:
                    task_id = futures.pop(future)
                    task = active_tasks[task_id]
                    try:
                        result = future.result()
                    except Exception as e:
                        result = TaskResult(
                            task_id=task_id,
                            issue_number=task.issue_number,
                            success=False,
                            return_code=-1,
                            stdout="",
                            stderr=str(e),
                            execution_time=0,
                            timestamp=datetime.now().isoformat(),
                            artifacts=[],
                            error=str(e)
                        )
                    with self._lock:
                        self.results.append(result)
                    if result.success:
                        self.queue.mark_completed(task_id)
                    else:
                        self.queue.mark_failed(task_id)
                    # Refill: submit ready tasks while worker capacity remains
                    while len(futures) < self.max_workers:
                        next_task = self.queue.get()
                        if not next_task:
                            break
                        new_future = executor.submit(self._execute_task, next_task)
                        futures[new_future] = next_task.task_id
                        active_tasks[next_task.task_id] = next_task
        # Detect conflicts
        self.conflicts = self.conflict_detector.detect_conflicts(self.results)
        return {
            "results": [r.to_dict() for r in self.results],
            "conflicts": [asdict(c) for c in self.conflicts],
            "stats": {
                "total_tasks": len(self.results),
                "successful": sum(1 for r in self.results if r.success),
                "failed": sum(1 for r in self.results if not r.success),
                "conflicts": len(self.conflicts),
                "worktrees_used": len(self._worktrees)
            }
        }

    def generate_report(self, output_path: str = None) -> str:
        """Generate execution report"""
        report = {
            "timestamp": datetime.now().isoformat(),
            "max_workers": self.max_workers,
            "worktrees": list(self._worktrees.keys()),
            "results": [r.to_dict() for r in self.results],
            "conflicts": [asdict(c) for c in self.conflicts],
            "stats": {
                "total_tasks": len(self.results),
                "successful": sum(1 for r in self.results if r.success),
                "failed": sum(1 for r in self.results if not r.success),
                "conflicts": len(self.conflicts)
            }
        }
        report_json = json.dumps(report, indent=2)
        if output_path:
            Path(output_path).write_text(report_json)
        return report_json


class IssueParallelizer:
    """High-level interface for working on multiple issues in parallel"""

    def __init__(self, repo_root: str = None):
        self.agent = ParallelAgent(repo_root=repo_root)
        self.issues: Dict[str, Dict] = {}

    def register_issue(self, issue_number: str, description: str = None):
        """Register an issue for parallel work"""
        self.issues[issue_number] = {
            "description": description,
            "tasks": []
        }

    def add_task_to_issue(self, issue_number: str, command: List[str],
                          task_id: str = None, **kwargs):
        """Add a command task to an issue"""
        if issue_number not in self.issues:
            self.register_issue(issue_number)
        task_id = task_id or f"{issue_number}-{len(self.issues[issue_number]['tasks'])}"
        # Record the task under its issue so auto-generated IDs stay unique
        self.issues[issue_number]["tasks"].append(task_id)
        task = ParallelTask(
            task_id=task_id,
            issue_number=issue_number,
            task_type="command",
            payload=command,
            dependencies=[],
            metadata={"description": self.issues[issue_number]["description"]},
            **kwargs
        )
        self.agent.add_task(task)

    def run(self) -> Dict:
        """Execute all registered tasks in parallel"""
        return self.agent.run_parallel()


if __name__ == "__main__":
    # Example usage
    print("ParallelAgent loaded successfully")
    print(f"Max workers: {multiprocessing.cpu_count()}")