Refactor batch processing with rich progress tracking and update logging in AIAgent
- Replaced tqdm with rich for enhanced visual progress tracking in batch processing.
- Adjusted logging levels in AIAgent to suppress asyncio debug messages.
- Modified datagen script to reduce the number of workers for improved performance.
This commit is contained in:
@@ -30,7 +30,8 @@ from datetime import datetime
|
|||||||
from multiprocessing import Pool, Manager, Lock
|
from multiprocessing import Pool, Manager, Lock
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from tqdm import tqdm
|
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
|
||||||
|
from rich.console import Console
|
||||||
import fire
|
import fire
|
||||||
|
|
||||||
from run_agent import AIAgent
|
from run_agent import AIAgent
|
||||||
@@ -643,14 +644,36 @@ class BatchRunner:
|
|||||||
print(f"✅ Created {len(tasks)} batch tasks")
|
print(f"✅ Created {len(tasks)} batch tasks")
|
||||||
print(f"🚀 Starting parallel batch processing...\n")
|
print(f"🚀 Starting parallel batch processing...\n")
|
||||||
|
|
||||||
# Use imap_unordered with tqdm for progress tracking
|
# Use rich Progress for better visual tracking with persistent bottom bar
|
||||||
results = list(tqdm(
|
# redirect_stdout/redirect_stderr are set to False below, so rich does not capture print/stderr output
|
||||||
pool.imap_unordered(_process_batch_worker, tasks),
|
results = []
|
||||||
total=len(tasks),
|
console = Console(force_terminal=True)
|
||||||
desc="📦 Batches",
|
with Progress(
|
||||||
unit="batch",
|
SpinnerColumn(),
|
||||||
ncols=80
|
TextColumn("[bold blue]📦 Batches"),
|
||||||
))
|
BarColumn(bar_width=40),
|
||||||
|
MofNCompleteColumn(),
|
||||||
|
TextColumn("•"),
|
||||||
|
TimeRemainingColumn(),
|
||||||
|
console=console,
|
||||||
|
refresh_per_second=2,
|
||||||
|
transient=False,
|
||||||
|
redirect_stdout=False,
|
||||||
|
redirect_stderr=False,
|
||||||
|
) as progress:
|
||||||
|
task = progress.add_task("Processing", total=len(tasks))
|
||||||
|
|
||||||
|
# Temporarily suppress DEBUG logging to avoid bar interference
|
||||||
|
root_logger = logging.getLogger()
|
||||||
|
original_level = root_logger.level
|
||||||
|
root_logger.setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
try:
|
||||||
|
for result in pool.imap_unordered(_process_batch_worker, tasks):
|
||||||
|
results.append(result)
|
||||||
|
progress.update(task, advance=1)
|
||||||
|
finally:
|
||||||
|
root_logger.setLevel(original_level)
|
||||||
|
|
||||||
# Aggregate all batch statistics and update checkpoint
|
# Aggregate all batch statistics and update checkpoint
|
||||||
all_completed_prompts = list(completed_prompts_set)
|
all_completed_prompts = list(completed_prompts_set)
|
||||||
|
|||||||
@@ -7,3 +7,4 @@ tenacity
|
|||||||
python-dotenv
|
python-dotenv
|
||||||
fire
|
fire
|
||||||
httpx
|
httpx
|
||||||
|
rich
|
||||||
|
|||||||
@@ -127,7 +127,8 @@ class AIAgent:
|
|||||||
logging.getLogger('openai._base_client').setLevel(logging.WARNING)
|
logging.getLogger('openai._base_client').setLevel(logging.WARNING)
|
||||||
logging.getLogger('httpx').setLevel(logging.WARNING)
|
logging.getLogger('httpx').setLevel(logging.WARNING)
|
||||||
logging.getLogger('httpcore').setLevel(logging.WARNING)
|
logging.getLogger('httpcore').setLevel(logging.WARNING)
|
||||||
print("🔍 Verbose logging enabled (OpenAI/httpx internal logs suppressed)")
|
logging.getLogger('asyncio').setLevel(logging.WARNING) # Suppress asyncio debug
|
||||||
|
print("🔍 Verbose logging enabled (OpenAI/httpx/asyncio internal logs suppressed)")
|
||||||
else:
|
else:
|
||||||
# Set logging to INFO level for important messages only
|
# Set logging to INFO level for important messages only
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
|
|||||||
@@ -17,10 +17,10 @@ python batch_runner.py \
|
|||||||
--model="z-ai/glm-4.7" \
|
--model="z-ai/glm-4.7" \
|
||||||
--base_url="https://openrouter.ai/api/v1" \
|
--base_url="https://openrouter.ai/api/v1" \
|
||||||
--providers_allowed="gmicloud,siliconflow,atlas-cloud,z-ai,novita" \
|
--providers_allowed="gmicloud,siliconflow,atlas-cloud,z-ai,novita" \
|
||||||
--num_workers=25 \
|
--num_workers=1 \
|
||||||
--max_turns=25 \
|
--max_turns=25 \
|
||||||
--verbose \
|
|
||||||
--ephemeral_system_prompt="When generating an image for the user view the image by using the vision_analyze tool to ensure it is what the user wanted. If it isn't feel free to retry a few times. If none are perfect, choose the best option that is the closest match, and explain its imperfections. If the image generation tool fails, try again a few times. If the vision analyze tool fails, provide the image to the user and explain it is your best effort attempt." \
|
--ephemeral_system_prompt="When generating an image for the user view the image by using the vision_analyze tool to ensure it is what the user wanted. If it isn't feel free to retry a few times. If none are perfect, choose the best option that is the closest match, and explain its imperfections. If the image generation tool fails, try again a few times. If the vision analyze tool fails, provide the image to the user and explain it is your best effort attempt." \
|
||||||
2>&1 | tee "$LOG_FILE"
|
2>&1 | tee "$LOG_FILE"
|
||||||
|
|
||||||
echo "✅ Log saved to: $LOG_FILE"
|
echo "✅ Log saved to: $LOG_FILE"
|
||||||
|
# --verbose \
|
||||||
Reference in New Issue
Block a user