Enhance RL test inference with WandB integration and real-time output streaming

- Added unique run ID generation for WandB tracking during test inference.
- Enabled WandB usage for test tracking and updated command-line arguments accordingly.
- Implemented real-time output streaming for process execution, improving log visibility and debugging.
- Enhanced error handling to display last few lines of stderr for better troubleshooting.
This commit is contained in:
teknium1
2026-02-04 21:07:07 -08:00
parent 3c0d0dba49
commit 5c3105b437

View File

@@ -1093,6 +1093,10 @@ async def rl_test_inference(
# Output file for this test run
output_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.jsonl"
# Generate unique run ID for wandb
test_run_id = str(uuid.uuid4())[:8]
wandb_run_name = f"test_inference_RSIAgent_{_current_env}_{test_run_id}"
# Build the process command using Atropos's built-in CLI
# This runs the environment's actual code with OpenRouter as the inference backend
# We pass our locked settings + test-specific overrides via CLI args
@@ -1101,7 +1105,8 @@ async def rl_test_inference(
# Test-specific overrides
"--env.total_steps", str(num_steps),
"--env.group_size", str(group_size),
"--env.use_wandb", "false", # No wandb for quick tests
"--env.use_wandb", "true", # Enable wandb for test tracking
"--env.wandb_name", wandb_run_name,
"--env.data_path_to_save_groups", str(output_file),
# Use locked settings from our config
"--env.tokenizer_name", LOCKED_FIELDS["env"]["tokenizer_name"],
@@ -1124,12 +1129,14 @@ async def rl_test_inference(
cmd_display = cmd_str.replace(api_key, "***API_KEY***")
print(f"Command: {cmd_display}")
print(f"Working dir: {TINKER_ATROPOS_ROOT}")
print(f"WandB run: {wandb_run_name}")
print(f" {num_steps} steps × {group_size} completions = {total_rollouts_per_model} rollouts")
model_results = {
"model": model_id,
"name": model_info["name"],
"scale": model_info["scale"],
"wandb_run": wandb_run_name,
"output_file": str(output_file),
"steps": [],
"steps_tested": 0,
@@ -1138,7 +1145,7 @@ async def rl_test_inference(
}
try:
# Run the process command
# Run the process command with real-time output streaming
process = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
@@ -1146,17 +1153,43 @@ async def rl_test_inference(
cwd=str(TINKER_ATROPOS_ROOT),
)
stdout, stderr = await asyncio.wait_for(
process.communicate(),
timeout=600, # 10 minute timeout per model
)
# Stream output in real-time while collecting for logs
stdout_lines = []
stderr_lines = []
log_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.log"
# Decode output
stdout_text = stdout.decode() if stdout else ""
stderr_text = stderr.decode() if stderr else ""
async def read_stream(stream, lines_list, prefix=""):
    """Read a subprocess stream line by line, echoing progress lines live.

    Every line is decoded and appended to *lines_list* (mutated in place)
    so the full output is available for the log file afterwards; lines
    that look like progress updates are additionally printed to the
    console in real time.

    Args:
        stream: asyncio StreamReader (the process's stdout or stderr).
        lines_list: list that collects every decoded line, in order.
        prefix: marker string prepended to echoed lines (e.g. an emoji).
    """
    while True:
        line = await stream.readline()
        if not line:
            # Empty bytes == EOF: the subprocess closed this stream.
            break
        # errors="replace" keeps the reader alive if the subprocess emits
        # bytes that are not valid UTF-8 — a bare .decode() would raise
        # UnicodeDecodeError and abort the whole streaming gather().
        decoded = line.decode(errors="replace").rstrip()
        lines_list.append(decoded)
        # Echo only progress-looking lines so the console stays readable
        # while lines_list still captures everything for the log file.
        if any(kw in decoded.lower() for kw in ['processing', 'group', 'step', 'progress', '%', 'completed']):
            print(f"  {prefix}{decoded}")
# Read both streams concurrently with timeout
try:
await asyncio.wait_for(
asyncio.gather(
read_stream(process.stdout, stdout_lines, "📊 "),
read_stream(process.stderr, stderr_lines, "⚠️ "),
),
timeout=600, # 10 minute timeout per model
)
except asyncio.TimeoutError:
process.kill()
raise
await process.wait()
# Combine output for logging
stdout_text = "\n".join(stdout_lines)
stderr_text = "\n".join(stderr_lines)
# Write logs to files for inspection outside CLI
log_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.log"
with open(log_file, "w") as f:
f.write(f"Command: {cmd_display}\n")
f.write(f"Working dir: {TINKER_ATROPOS_ROOT}\n")
@@ -1170,21 +1203,17 @@ async def rl_test_inference(
print(f" Log file: {log_file}")
# Print to console for immediate debugging
if stdout_text.strip():
print(f"\n--- STDOUT ---")
print(stdout_text[-2000:]) # Last 2000 chars
if stderr_text.strip():
print(f"\n--- STDERR ---")
print(stderr_text[-2000:]) # Last 2000 chars
if process.returncode != 0:
model_results["error"] = f"Process exited with code {process.returncode}"
model_results["stderr"] = stderr_text[-1000:]
model_results["stdout"] = stdout_text[-1000:]
model_results["log_file"] = str(log_file)
print(f"\n ❌ Error: {model_results['error']}")
# Print last few lines of stderr for debugging
if stderr_lines:
print(f" Last errors:")
for line in stderr_lines[-5:]:
print(f" {line}")
else:
print(f"\n ✅ Process completed successfully")
print(f" Output file: {output_file}")