Add timeout configuration for trajectory processing
- Updated `trajectory_compression.yaml` with a new `per_trajectory_timeout` setting (default: 300 seconds per trajectory). Entries that exceed the timeout are logged, counted, and skipped, so a single problematic entry can no longer hang the whole processing run; a sketch of the underlying pattern follows below.
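The mechanism, per the diff below, is `asyncio.wait_for` wrapped around each trajectory's processing coroutine. A minimal standalone sketch of the pattern (the names `process_entry`, `process_with_timeout`, and `PER_TRAJECTORY_TIMEOUT` are illustrative placeholders, not the project's identifiers):

import asyncio
from typing import Optional

PER_TRAJECTORY_TIMEOUT = 300  # seconds, mirroring the new config default

async def process_entry(entry: dict) -> dict:
    # Stand-in for the real per-trajectory compression work.
    await asyncio.sleep(0.01)
    return entry

async def process_with_timeout(entry: dict) -> Optional[dict]:
    """Return the processed entry, or None if it exceeds the timeout."""
    try:
        return await asyncio.wait_for(process_entry(entry), timeout=PER_TRAJECTORY_TIMEOUT)
    except asyncio.TimeoutError:
        # Skip the problematic entry instead of hanging the whole run.
        return None
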
@@ -84,6 +84,10 @@ processing:
   # If true, save trajectories even if compression can't get under target
   # (will compress as much as possible)
   save_over_limit: true
 
+  # Timeout per trajectory in seconds (skip if takes longer)
+  # Helps avoid hanging on problematic entries
+  per_trajectory_timeout: 300  # 5 minutes
+
   # Metrics to track
   metrics:

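The new key sits under the `processing:` section of `trajectory_compression.yaml` and mirrors the `CompressionConfig.per_trajectory_timeout` field added further down. A hedged sketch of how such a key could be read into the dataclass, assuming PyYAML; the `load_config` helper is hypothetical and not part of this diff:

import yaml  # assumes PyYAML is available
from dataclasses import dataclass

@dataclass
class CompressionConfig:
    per_trajectory_timeout: int = 300  # Timeout per trajectory in seconds

def load_config(path: str) -> CompressionConfig:
    # Read the processing section and fall back to the default on a missing key.
    with open(path, "r", encoding="utf-8") as f:
        raw = yaml.safe_load(f) or {}
    processing = raw.get("processing", {})
    return CompressionConfig(
        per_trajectory_timeout=int(processing.get("per_trajectory_timeout", 300)),
    )
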
@@ -1252,7 +1252,7 @@ def main(
     if save_sample:
         import uuid
         sample_id = str(uuid.uuid4())[:8]
-        sample_filename = f"sample_{sample_id}.jsonl"
+        sample_filename = f"sample_{sample_id}.json"
 
         # Convert messages to trajectory format (same as batch_runner)
         trajectory = agent._convert_to_trajectory_format(

@@ -1271,7 +1271,8 @@ def main(
         try:
             with open(sample_filename, "w", encoding="utf-8") as f:
-                f.write(json.dumps(entry, ensure_ascii=False) + "\n")
+                # Pretty-print JSON with indent for readability
+                f.write(json.dumps(entry, ensure_ascii=False, indent=2))
             print(f"\n💾 Sample trajectory saved to: {sample_filename}")
         except Exception as e:
             print(f"\n⚠️ Failed to save sample: {e}")

@@ -86,6 +86,7 @@ class CompressionConfig:
     max_concurrent_requests: int = 50  # Max concurrent API calls for summarization
     skip_under_target: bool = True
     save_over_limit: bool = True
+    per_trajectory_timeout: int = 300  # Timeout per trajectory in seconds (default: 5 min)
 
     # Metrics
     metrics_enabled: bool = True

@@ -966,10 +967,13 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
         # Results storage: {file_path: {entry_idx: (processed_entry, metrics)}}
         results = {f: {} for f in jsonl_files}
 
+        # Track timeouts separately
+        timeout_count = 0
+
         async def process_single(file_path: Path, entry_idx: int, entry: Dict,
                                   progress, main_task, status_task):
-            """Process a single entry with semaphore rate limiting."""
-            nonlocal compressed_count, skipped_count, api_calls, in_flight
+            """Process a single entry with semaphore rate limiting and timeout."""
+            nonlocal compressed_count, skipped_count, api_calls, in_flight, timeout_count
 
             async with semaphore:
                 # Track in-flight

@@ -977,7 +981,11 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
                     in_flight += 1
 
                 try:
-                    processed_entry, metrics = await self.process_entry_async(entry)
+                    # Apply per-trajectory timeout
+                    processed_entry, metrics = await asyncio.wait_for(
+                        self.process_entry_async(entry),
+                        timeout=self.config.per_trajectory_timeout
+                    )
                     results[file_path][entry_idx] = (processed_entry, metrics)
 
                     # Update aggregate metrics (with lock for thread safety)

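Worth noting about `asyncio.wait_for` as used above: when the timeout fires it cancels the wrapped coroutine, so a timed-out entry stops issuing API calls rather than continuing in the background. A small standalone demonstration of that behaviour (not project code):

import asyncio

async def slow_work():
    try:
        await asyncio.sleep(10)  # stands in for a long-running trajectory
    except asyncio.CancelledError:
        print("inner work cancelled by the timeout")
        raise

async def main():
    try:
        await asyncio.wait_for(slow_work(), timeout=0.1)
    except asyncio.TimeoutError:
        print("timed out, skipping this entry")

asyncio.run(main())
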
@@ -997,8 +1005,24 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
                         progress.advance(main_task)
                         progress.update(
                             status_task,
-                            description=f"[dim]✅ {compressed_count} compressed | ⏭️ {skipped_count} skipped | 🔄 {api_calls} API calls | ⚡ {in_flight} in-flight[/dim]"
+                            description=f"[dim]✅ {compressed_count} compressed | ⏭️ {skipped_count} skipped | ⏱️ {timeout_count} timeout | 🔄 {api_calls} API calls | ⚡ {in_flight} in-flight[/dim]"
                         )
 
+                except asyncio.TimeoutError:
+                    self.logger.warning(f"Timeout processing entry from {file_path}:{entry_idx} (>{self.config.per_trajectory_timeout}s)")
+
+                    async with progress_lock:
+                        self.aggregate_metrics.trajectories_failed += 1
+                        timeout_count += 1
+                        in_flight -= 1
+                        progress.advance(main_task)
+                        progress.update(
+                            status_task,
+                            description=f"[dim]✅ {compressed_count} compressed | ⏭️ {skipped_count} skipped | ⏱️ {timeout_count} timeout | 🔄 {api_calls} API calls | ⚡ {in_flight} in-flight[/dim]"
+                        )
+
+                    # Skip this entry entirely (don't include in output)
+                    results[file_path][entry_idx] = None
+
                 except Exception as e:
                     self.logger.error(f"Error processing entry from {file_path}:{entry_idx}: {e}")

@@ -1056,8 +1080,12 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
             output_path = output_dir / file_path.name
             file_results = results[file_path]
 
-            # Sort by original entry index to preserve order
-            sorted_entries = [file_results[idx][0] for idx in sorted(file_results.keys())]
+            # Sort by original entry index to preserve order, skip None (timed out) entries
+            sorted_entries = [
+                file_results[idx][0]
+                for idx in sorted(file_results.keys())
+                if file_results[idx] is not None
+            ]
 
             with open(output_path, 'w', encoding='utf-8') as f:
                 for entry in sorted_entries: