Add timeout configuration for trajectory processing

- Updated `trajectory_compression.yaml` to include a new `per_trajectory_timeout` setting, which sets a timeout of 300 seconds per trajectory. This prevents processing from hanging on problematic entries, improving the overall reliability and efficiency of trajectory handling.
This commit is contained in:
teknium
2026-01-30 07:34:58 +00:00
parent e8c6135a91
commit 8e8b6be690
3 changed files with 41 additions and 8 deletions

View File

@@ -84,6 +84,10 @@ processing:
# If true, save trajectories even if compression can't get under target
# (will compress as much as possible)
save_over_limit: true
# Timeout per trajectory in seconds (skip the entry if it takes longer)
# Helps avoid hanging on problematic entries
per_trajectory_timeout: 300 # 5 minutes
# Metrics to track
metrics:

View File

@@ -1252,7 +1252,7 @@ def main(
if save_sample:
import uuid
sample_id = str(uuid.uuid4())[:8]
sample_filename = f"sample_{sample_id}.jsonl"
sample_filename = f"sample_{sample_id}.json"
# Convert messages to trajectory format (same as batch_runner)
trajectory = agent._convert_to_trajectory_format(
@@ -1271,7 +1271,8 @@ def main(
try:
with open(sample_filename, "w", encoding="utf-8") as f:
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
# Pretty-print JSON with indent for readability
f.write(json.dumps(entry, ensure_ascii=False, indent=2))
print(f"\n💾 Sample trajectory saved to: {sample_filename}")
except Exception as e:
print(f"\n⚠️ Failed to save sample: {e}")

View File

@@ -86,6 +86,7 @@ class CompressionConfig:
max_concurrent_requests: int = 50 # Max concurrent API calls for summarization
skip_under_target: bool = True
save_over_limit: bool = True
per_trajectory_timeout: int = 300 # Timeout per trajectory in seconds (default: 5 min)
# Metrics
metrics_enabled: bool = True
@@ -966,10 +967,13 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
# Results storage: {file_path: {entry_idx: (processed_entry, metrics)}}
results = {f: {} for f in jsonl_files}
# Track timeouts separately
timeout_count = 0
async def process_single(file_path: Path, entry_idx: int, entry: Dict,
progress, main_task, status_task):
"""Process a single entry with semaphore rate limiting."""
nonlocal compressed_count, skipped_count, api_calls, in_flight
"""Process a single entry with semaphore rate limiting and timeout."""
nonlocal compressed_count, skipped_count, api_calls, in_flight, timeout_count
async with semaphore:
# Track in-flight
@@ -977,7 +981,11 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
in_flight += 1
try:
processed_entry, metrics = await self.process_entry_async(entry)
# Apply per-trajectory timeout
processed_entry, metrics = await asyncio.wait_for(
self.process_entry_async(entry),
timeout=self.config.per_trajectory_timeout
)
results[file_path][entry_idx] = (processed_entry, metrics)
# Update aggregate metrics (with lock for thread safety)
@@ -997,8 +1005,24 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
progress.advance(main_task)
progress.update(
status_task,
description=f"[dim]✅ {compressed_count} compressed | ⏭️ {skipped_count} skipped | 🔄 {api_calls} API calls | ⚡ {in_flight} in-flight[/dim]"
description=f"[dim]✅ {compressed_count} compressed | ⏭️ {skipped_count} skipped | ⏱️ {timeout_count} timeout | 🔄 {api_calls} API calls | ⚡ {in_flight} in-flight[/dim]"
)
except asyncio.TimeoutError:
self.logger.warning(f"Timeout processing entry from {file_path}:{entry_idx} (>{self.config.per_trajectory_timeout}s)")
async with progress_lock:
self.aggregate_metrics.trajectories_failed += 1
timeout_count += 1
in_flight -= 1
progress.advance(main_task)
progress.update(
status_task,
description=f"[dim]✅ {compressed_count} compressed | ⏭️ {skipped_count} skipped | ⏱️ {timeout_count} timeout | 🔄 {api_calls} API calls | ⚡ {in_flight} in-flight[/dim]"
)
# Skip this entry entirely (don't include in output)
results[file_path][entry_idx] = None
except Exception as e:
self.logger.error(f"Error processing entry from {file_path}:{entry_idx}: {e}")
@@ -1056,8 +1080,12 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
output_path = output_dir / file_path.name
file_results = results[file_path]
# Sort by original entry index to preserve order
sorted_entries = [file_results[idx][0] for idx in sorted(file_results.keys())]
# Sort by original entry index to preserve order, skip None (timed out) entries
sorted_entries = [
file_results[idx][0]
for idx in sorted(file_results.keys())
if file_results[idx] is not None
]
with open(output_path, 'w', encoding='utf-8') as f:
for entry in sorted_entries: