Enhance RL test inference with WandB integration and real-time output streaming
- Added unique run ID generation for WandB tracking during test inference.
- Enabled WandB usage for test tracking and updated command-line arguments accordingly.
- Implemented real-time output streaming for process execution, improving log visibility and debugging.
- Enhanced error handling to display the last few lines of stderr for better troubleshooting.
This commit is contained in:
@@ -1093,6 +1093,10 @@ async def rl_test_inference(
|
||||
# Output file for this test run
|
||||
output_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.jsonl"
|
||||
|
||||
# Generate unique run ID for wandb
|
||||
test_run_id = str(uuid.uuid4())[:8]
|
||||
wandb_run_name = f"test_inference_RSIAgent_{_current_env}_{test_run_id}"
|
||||
|
||||
# Build the process command using Atropos's built-in CLI
|
||||
# This runs the environment's actual code with OpenRouter as the inference backend
|
||||
# We pass our locked settings + test-specific overrides via CLI args
|
||||
@@ -1101,7 +1105,8 @@ async def rl_test_inference(
|
||||
# Test-specific overrides
|
||||
"--env.total_steps", str(num_steps),
|
||||
"--env.group_size", str(group_size),
|
||||
"--env.use_wandb", "false", # No wandb for quick tests
|
||||
"--env.use_wandb", "true", # Enable wandb for test tracking
|
||||
"--env.wandb_name", wandb_run_name,
|
||||
"--env.data_path_to_save_groups", str(output_file),
|
||||
# Use locked settings from our config
|
||||
"--env.tokenizer_name", LOCKED_FIELDS["env"]["tokenizer_name"],
|
||||
@@ -1124,12 +1129,14 @@ async def rl_test_inference(
|
||||
cmd_display = cmd_str.replace(api_key, "***API_KEY***")
|
||||
print(f"Command: {cmd_display}")
|
||||
print(f"Working dir: {TINKER_ATROPOS_ROOT}")
|
||||
print(f"WandB run: {wandb_run_name}")
|
||||
print(f" {num_steps} steps × {group_size} completions = {total_rollouts_per_model} rollouts")
|
||||
|
||||
model_results = {
|
||||
"model": model_id,
|
||||
"name": model_info["name"],
|
||||
"scale": model_info["scale"],
|
||||
"wandb_run": wandb_run_name,
|
||||
"output_file": str(output_file),
|
||||
"steps": [],
|
||||
"steps_tested": 0,
|
||||
@@ -1138,7 +1145,7 @@ async def rl_test_inference(
|
||||
}
|
||||
|
||||
try:
|
||||
# Run the process command
|
||||
# Run the process command with real-time output streaming
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
@@ -1146,17 +1153,43 @@ async def rl_test_inference(
|
||||
cwd=str(TINKER_ATROPOS_ROOT),
|
||||
)
|
||||
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(),
|
||||
timeout=600, # 10 minute timeout per model
|
||||
)
|
||||
# Stream output in real-time while collecting for logs
|
||||
stdout_lines = []
|
||||
stderr_lines = []
|
||||
log_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.log"
|
||||
|
||||
# Decode output
|
||||
stdout_text = stdout.decode() if stdout else ""
|
||||
stderr_text = stderr.decode() if stderr else ""
|
||||
async def read_stream(stream, lines_list, prefix=""):
    """Read ``stream`` line by line, collecting every line and echoing progress lines.

    Args:
        stream: Object with an awaitable ``readline()`` that returns bytes and
            returns an empty value once the stream is exhausted.
        lines_list: List to which each decoded line (trailing whitespace
            stripped) is appended, in the order read.
        prefix: Text printed immediately before each echoed line.

    Only lines containing one of the progress keywords (case-insensitive) are
    printed; all lines are collected regardless.
    """
    progress_keywords = ("processing", "group", "step", "progress", "%", "completed")
    while True:
        line = await stream.readline()
        if not line:
            break
        # The child's output is not guaranteed to be valid UTF-8. Replace bad
        # bytes rather than raising, so one stray byte cannot abort collection.
        decoded = line.decode(errors="replace").rstrip()
        lines_list.append(decoded)
        # Print progress-related lines in real-time
        lowered = decoded.lower()
        if any(kw in lowered for kw in progress_keywords):
            print(f" {prefix}{decoded}")
|
||||
|
||||
# Read both streams concurrently with timeout
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.gather(
|
||||
read_stream(process.stdout, stdout_lines, "📊 "),
|
||||
read_stream(process.stderr, stderr_lines, "⚠️ "),
|
||||
),
|
||||
timeout=600, # 10 minute timeout per model
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
process.kill()
|
||||
raise
|
||||
|
||||
await process.wait()
|
||||
|
||||
# Combine output for logging
|
||||
stdout_text = "\n".join(stdout_lines)
|
||||
stderr_text = "\n".join(stderr_lines)
|
||||
|
||||
# Write logs to files for inspection outside CLI
|
||||
log_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.log"
|
||||
with open(log_file, "w") as f:
|
||||
f.write(f"Command: {cmd_display}\n")
|
||||
f.write(f"Working dir: {TINKER_ATROPOS_ROOT}\n")
|
||||
@@ -1170,21 +1203,17 @@ async def rl_test_inference(
|
||||
|
||||
print(f" Log file: {log_file}")
|
||||
|
||||
# Print to console for immediate debugging
|
||||
if stdout_text.strip():
|
||||
print(f"\n--- STDOUT ---")
|
||||
print(stdout_text[-2000:]) # Last 2000 chars
|
||||
|
||||
if stderr_text.strip():
|
||||
print(f"\n--- STDERR ---")
|
||||
print(stderr_text[-2000:]) # Last 2000 chars
|
||||
|
||||
if process.returncode != 0:
|
||||
model_results["error"] = f"Process exited with code {process.returncode}"
|
||||
model_results["stderr"] = stderr_text[-1000:]
|
||||
model_results["stdout"] = stdout_text[-1000:]
|
||||
model_results["log_file"] = str(log_file)
|
||||
print(f"\n ❌ Error: {model_results['error']}")
|
||||
# Print last few lines of stderr for debugging
|
||||
if stderr_lines:
|
||||
print(f" Last errors:")
|
||||
for line in stderr_lines[-5:]:
|
||||
print(f" {line}")
|
||||
else:
|
||||
print(f"\n ✅ Process completed successfully")
|
||||
print(f" Output file: {output_file}")
|
||||
|
||||
Reference in New Issue
Block a user