updates for stability and speed
This commit is contained in:
@@ -155,7 +155,8 @@ def _process_single_prompt(
|
||||
if config.get("verbose"):
|
||||
print(f" Prompt {prompt_index}: Using toolsets {selected_toolsets}")
|
||||
|
||||
# Initialize agent with sampled toolsets
|
||||
# Initialize agent with sampled toolsets and log prefix for identification
|
||||
log_prefix = f"[B{batch_num}:P{prompt_index}]"
|
||||
agent = AIAgent(
|
||||
base_url=config.get("base_url"),
|
||||
api_key=config.get("api_key"),
|
||||
@@ -165,7 +166,12 @@ def _process_single_prompt(
|
||||
save_trajectories=False, # We handle saving ourselves
|
||||
verbose_logging=config.get("verbose", False),
|
||||
ephemeral_system_prompt=config.get("ephemeral_system_prompt"),
|
||||
log_prefix_chars=config.get("log_prefix_chars", 100)
|
||||
log_prefix_chars=config.get("log_prefix_chars", 100),
|
||||
log_prefix=log_prefix,
|
||||
providers_allowed=config.get("providers_allowed"),
|
||||
providers_ignored=config.get("providers_ignored"),
|
||||
providers_order=config.get("providers_order"),
|
||||
provider_sort=config.get("provider_sort"),
|
||||
)
|
||||
|
||||
# Run the agent with task_id to ensure each task gets its own isolated VM
|
||||
@@ -326,6 +332,10 @@ class BatchRunner:
|
||||
verbose: bool = False,
|
||||
ephemeral_system_prompt: str = None,
|
||||
log_prefix_chars: int = 100,
|
||||
providers_allowed: List[str] = None,
|
||||
providers_ignored: List[str] = None,
|
||||
providers_order: List[str] = None,
|
||||
provider_sort: str = None,
|
||||
):
|
||||
"""
|
||||
Initialize the batch runner.
|
||||
@@ -343,6 +353,10 @@ class BatchRunner:
|
||||
verbose (bool): Enable verbose logging
|
||||
ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
|
||||
log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20)
|
||||
providers_allowed (List[str]): OpenRouter providers to allow (optional)
|
||||
providers_ignored (List[str]): OpenRouter providers to ignore (optional)
|
||||
providers_order (List[str]): OpenRouter providers to try in order (optional)
|
||||
provider_sort (str): Sort providers by price/throughput/latency (optional)
|
||||
"""
|
||||
self.dataset_file = Path(dataset_file)
|
||||
self.batch_size = batch_size
|
||||
@@ -356,6 +370,10 @@ class BatchRunner:
|
||||
self.verbose = verbose
|
||||
self.ephemeral_system_prompt = ephemeral_system_prompt
|
||||
self.log_prefix_chars = log_prefix_chars
|
||||
self.providers_allowed = providers_allowed
|
||||
self.providers_ignored = providers_ignored
|
||||
self.providers_order = providers_order
|
||||
self.provider_sort = provider_sort
|
||||
|
||||
# Validate distribution
|
||||
if not validate_distribution(distribution):
|
||||
@@ -512,7 +530,11 @@ class BatchRunner:
|
||||
"api_key": self.api_key,
|
||||
"verbose": self.verbose,
|
||||
"ephemeral_system_prompt": self.ephemeral_system_prompt,
|
||||
"log_prefix_chars": self.log_prefix_chars
|
||||
"log_prefix_chars": self.log_prefix_chars,
|
||||
"providers_allowed": self.providers_allowed,
|
||||
"providers_ignored": self.providers_ignored,
|
||||
"providers_order": self.providers_order,
|
||||
"provider_sort": self.provider_sort,
|
||||
}
|
||||
|
||||
# Get completed prompts set
|
||||
@@ -523,6 +545,8 @@ class BatchRunner:
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
print(f"\n🔧 Initializing {self.num_workers} worker processes...")
|
||||
|
||||
# Process batches in parallel
|
||||
with Pool(processes=self.num_workers) as pool:
|
||||
# Create tasks for each batch
|
||||
@@ -537,6 +561,9 @@ class BatchRunner:
|
||||
for batch_num, batch_data in enumerate(self.batches)
|
||||
]
|
||||
|
||||
print(f"✅ Created {len(tasks)} batch tasks")
|
||||
print(f"🚀 Starting parallel batch processing...\n")
|
||||
|
||||
# Use map to process batches in parallel
|
||||
results = pool.map(_process_batch_worker, tasks)
|
||||
|
||||
@@ -657,6 +684,10 @@ def main(
|
||||
list_distributions: bool = False,
|
||||
ephemeral_system_prompt: str = None,
|
||||
log_prefix_chars: int = 100,
|
||||
providers_allowed: str = None,
|
||||
providers_ignored: str = None,
|
||||
providers_order: str = None,
|
||||
provider_sort: str = None,
|
||||
):
|
||||
"""
|
||||
Run batch processing of agent prompts from a dataset.
|
||||
@@ -676,6 +707,10 @@ def main(
|
||||
list_distributions (bool): List available toolset distributions and exit
|
||||
ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
|
||||
log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20)
|
||||
providers_allowed (str): Comma-separated list of OpenRouter providers to allow (e.g. "anthropic,openai")
|
||||
providers_ignored (str): Comma-separated list of OpenRouter providers to ignore (e.g. "together,deepinfra")
|
||||
providers_order (str): Comma-separated list of OpenRouter providers to try in order (e.g. "anthropic,openai,google")
|
||||
provider_sort (str): Sort providers by "price", "throughput", or "latency" (OpenRouter only)
|
||||
|
||||
Examples:
|
||||
# Basic usage
|
||||
@@ -723,6 +758,11 @@ def main(
|
||||
print("❌ Error: --run_name is required")
|
||||
return
|
||||
|
||||
# Parse provider preferences (comma-separated strings to lists)
|
||||
providers_allowed_list = [p.strip() for p in providers_allowed.split(",")] if providers_allowed else None
|
||||
providers_ignored_list = [p.strip() for p in providers_ignored.split(",")] if providers_ignored else None
|
||||
providers_order_list = [p.strip() for p in providers_order.split(",")] if providers_order else None
|
||||
|
||||
# Initialize and run batch runner
|
||||
try:
|
||||
runner = BatchRunner(
|
||||
@@ -737,7 +777,11 @@ def main(
|
||||
num_workers=num_workers,
|
||||
verbose=verbose,
|
||||
ephemeral_system_prompt=ephemeral_system_prompt,
|
||||
log_prefix_chars=log_prefix_chars
|
||||
log_prefix_chars=log_prefix_chars,
|
||||
providers_allowed=providers_allowed_list,
|
||||
providers_ignored=providers_ignored_list,
|
||||
providers_order=providers_order_list,
|
||||
provider_sort=provider_sort,
|
||||
)
|
||||
|
||||
runner.run(resume=resume)
|
||||
|
||||
Reference in New Issue
Block a user