updates for stability and speed

This commit is contained in:
teknium
2026-01-08 08:57:51 +00:00
parent f957ec2267
commit 6af6ff2a0a
5 changed files with 447 additions and 131 deletions

View File

@@ -155,7 +155,8 @@ def _process_single_prompt(
if config.get("verbose"):
print(f" Prompt {prompt_index}: Using toolsets {selected_toolsets}")
# Initialize agent with sampled toolsets
# Initialize agent with sampled toolsets and log prefix for identification
log_prefix = f"[B{batch_num}:P{prompt_index}]"
agent = AIAgent(
base_url=config.get("base_url"),
api_key=config.get("api_key"),
@@ -165,7 +166,12 @@ def _process_single_prompt(
save_trajectories=False, # We handle saving ourselves
verbose_logging=config.get("verbose", False),
ephemeral_system_prompt=config.get("ephemeral_system_prompt"),
log_prefix_chars=config.get("log_prefix_chars", 100)
log_prefix_chars=config.get("log_prefix_chars", 100),
log_prefix=log_prefix,
providers_allowed=config.get("providers_allowed"),
providers_ignored=config.get("providers_ignored"),
providers_order=config.get("providers_order"),
provider_sort=config.get("provider_sort"),
)
# Run the agent with task_id to ensure each task gets its own isolated VM
@@ -326,6 +332,10 @@ class BatchRunner:
verbose: bool = False,
ephemeral_system_prompt: str = None,
log_prefix_chars: int = 100,
providers_allowed: List[str] = None,
providers_ignored: List[str] = None,
providers_order: List[str] = None,
provider_sort: str = None,
):
"""
Initialize the batch runner.
@@ -343,6 +353,10 @@ class BatchRunner:
verbose (bool): Enable verbose logging
ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20)
providers_allowed (List[str]): OpenRouter providers to allow (optional)
providers_ignored (List[str]): OpenRouter providers to ignore (optional)
providers_order (List[str]): OpenRouter providers to try in order (optional)
provider_sort (str): Sort providers by price/throughput/latency (optional)
"""
self.dataset_file = Path(dataset_file)
self.batch_size = batch_size
@@ -356,6 +370,10 @@ class BatchRunner:
self.verbose = verbose
self.ephemeral_system_prompt = ephemeral_system_prompt
self.log_prefix_chars = log_prefix_chars
self.providers_allowed = providers_allowed
self.providers_ignored = providers_ignored
self.providers_order = providers_order
self.provider_sort = provider_sort
# Validate distribution
if not validate_distribution(distribution):
@@ -512,7 +530,11 @@ class BatchRunner:
"api_key": self.api_key,
"verbose": self.verbose,
"ephemeral_system_prompt": self.ephemeral_system_prompt,
"log_prefix_chars": self.log_prefix_chars
"log_prefix_chars": self.log_prefix_chars,
"providers_allowed": self.providers_allowed,
"providers_ignored": self.providers_ignored,
"providers_order": self.providers_order,
"provider_sort": self.provider_sort,
}
# Get completed prompts set
@@ -523,6 +545,8 @@ class BatchRunner:
start_time = time.time()
print(f"\n🔧 Initializing {self.num_workers} worker processes...")
# Process batches in parallel
with Pool(processes=self.num_workers) as pool:
# Create tasks for each batch
@@ -537,6 +561,9 @@ class BatchRunner:
for batch_num, batch_data in enumerate(self.batches)
]
print(f"✅ Created {len(tasks)} batch tasks")
print(f"🚀 Starting parallel batch processing...\n")
# Use map to process batches in parallel
results = pool.map(_process_batch_worker, tasks)
@@ -657,6 +684,10 @@ def main(
list_distributions: bool = False,
ephemeral_system_prompt: str = None,
log_prefix_chars: int = 100,
providers_allowed: str = None,
providers_ignored: str = None,
providers_order: str = None,
provider_sort: str = None,
):
"""
Run batch processing of agent prompts from a dataset.
@@ -676,6 +707,10 @@ def main(
list_distributions (bool): List available toolset distributions and exit
ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20)
providers_allowed (str): Comma-separated list of OpenRouter providers to allow (e.g. "anthropic,openai")
providers_ignored (str): Comma-separated list of OpenRouter providers to ignore (e.g. "together,deepinfra")
providers_order (str): Comma-separated list of OpenRouter providers to try in order (e.g. "anthropic,openai,google")
provider_sort (str): Sort providers by "price", "throughput", or "latency" (OpenRouter only)
Examples:
# Basic usage
@@ -723,6 +758,11 @@ def main(
print("❌ Error: --run_name is required")
return
# Parse provider preferences (comma-separated strings to lists)
providers_allowed_list = [p.strip() for p in providers_allowed.split(",")] if providers_allowed else None
providers_ignored_list = [p.strip() for p in providers_ignored.split(",")] if providers_ignored else None
providers_order_list = [p.strip() for p in providers_order.split(",")] if providers_order else None
# Initialize and run batch runner
try:
runner = BatchRunner(
@@ -737,7 +777,11 @@ def main(
num_workers=num_workers,
verbose=verbose,
ephemeral_system_prompt=ephemeral_system_prompt,
log_prefix_chars=log_prefix_chars
log_prefix_chars=log_prefix_chars,
providers_allowed=providers_allowed_list,
providers_ignored=providers_ignored_list,
providers_order=providers_order_list,
provider_sort=provider_sort,
)
runner.run(resume=resume)