Refactor batch processing with rich progress tracking and update logging in AIAgent

- Replaced tqdm with rich for enhanced visual progress tracking in batch processing.
- Adjusted logging levels in AIAgent to suppress asyncio debug messages.
- Modified the datagen script to reduce the number of workers for improved performance.
This commit is contained in:
teknium
2026-01-14 14:02:59 +00:00
parent 6e3dbb8d8b
commit b32cc4b09d
4 changed files with 37 additions and 12 deletions

View File

@@ -30,7 +30,8 @@ from datetime import datetime
from multiprocessing import Pool, Manager, Lock
import traceback
from tqdm import tqdm
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
from rich.console import Console
import fire
from run_agent import AIAgent
@@ -643,14 +644,36 @@ class BatchRunner:
print(f"✅ Created {len(tasks)} batch tasks")
print(f"🚀 Starting parallel batch processing...\n")
# Use imap_unordered with tqdm for progress tracking
results = list(tqdm(
pool.imap_unordered(_process_batch_worker, tasks),
total=len(tasks),
desc="📦 Batches",
unit="batch",
ncols=80
))
# Use rich Progress for better visual tracking with persistent bottom bar
# redirect_stdout/redirect_stderr are both set to False below, so rich does not intercept stdout/stderr
results = []
console = Console(force_terminal=True)
with Progress(
SpinnerColumn(),
TextColumn("[bold blue]📦 Batches"),
BarColumn(bar_width=40),
MofNCompleteColumn(),
TextColumn(""),
TimeRemainingColumn(),
console=console,
refresh_per_second=2,
transient=False,
redirect_stdout=False,
redirect_stderr=False,
) as progress:
task = progress.add_task("Processing", total=len(tasks))
# Temporarily raise the root logger to WARNING (silences DEBUG and INFO) to avoid bar interference
root_logger = logging.getLogger()
original_level = root_logger.level
root_logger.setLevel(logging.WARNING)
try:
for result in pool.imap_unordered(_process_batch_worker, tasks):
results.append(result)
progress.update(task, advance=1)
finally:
root_logger.setLevel(original_level)
# Aggregate all batch statistics and update checkpoint
all_completed_prompts = list(completed_prompts_set)