2026-02-03 23:41:26 -08:00
#!/usr/bin/env python3
"""
RL Training Tools Module
This module provides tools for running RL training through Tinker - Atropos .
2026-02-04 10:36:01 -08:00
Directly manages training processes without requiring a separate API server .
Features :
- Environment discovery ( AST - based scanning for BaseEnv subclasses )
- Configuration management with locked infrastructure settings
- Training run lifecycle via subprocess management
2026-02-03 23:41:26 -08:00
- WandB metrics monitoring
Required environment variables :
- TINKER_API_KEY : API key for Tinker service
- WANDB_API_KEY : API key for Weights & Biases metrics
Usage :
from tools . rl_training_tool import (
rl_list_environments ,
rl_select_environment ,
rl_get_current_config ,
rl_edit_config ,
rl_start_training ,
rl_check_status ,
rl_stop_training ,
rl_get_results ,
)
"""
2026-02-04 10:36:01 -08:00
import ast
import asyncio
import importlib . util
2026-02-03 23:41:26 -08:00
import json
import os
2026-02-04 10:36:01 -08:00
import subprocess
import sys
2026-02-03 23:41:26 -08:00
import time
2026-02-04 10:36:01 -08:00
import uuid
2026-02-04 13:57:59 -08:00
from datetime import datetime
2026-02-04 10:36:01 -08:00
import yaml
from dataclasses import dataclass , field
from pathlib import Path
2026-02-03 23:41:26 -08:00
from typing import Any , Dict , List , Optional
2026-02-04 10:36:01 -08:00
# ============================================================================
# Path Configuration
# ============================================================================
# Path to tinker-atropos submodule (relative to hermes-agent root)
HERMES_ROOT = Path ( __file__ ) . parent . parent
TINKER_ATROPOS_ROOT = HERMES_ROOT / " tinker-atropos "
ENVIRONMENTS_DIR = TINKER_ATROPOS_ROOT / " tinker_atropos " / " environments "
CONFIGS_DIR = TINKER_ATROPOS_ROOT / " configs "
2026-03-11 17:00:39 +05:30
LOGS_DIR = Path ( os . getenv ( " HERMES_HOME " , Path . home ( ) / " .hermes " ) ) / " logs " / " rl_training "
2026-02-04 10:36:01 -08:00
2026-03-12 09:11:44 -07:00
def _ensure_logs_dir ( ) :
""" Lazily create logs directory on first use (avoid side effects at import time). """
if TINKER_ATROPOS_ROOT . exists ( ) :
LOGS_DIR . mkdir ( exist_ok = True )
2026-02-03 23:41:26 -08:00
# ============================================================================
2026-02-04 10:36:01 -08:00
# Locked Configuration (Infrastructure Settings)
2026-02-03 23:41:26 -08:00
# ============================================================================
2026-02-04 10:36:01 -08:00
# These fields cannot be changed by the model - they're tuned for our infrastructure
LOCKED_FIELDS = {
" env " : {
" tokenizer_name " : " Qwen/Qwen3-8B " ,
" rollout_server_url " : " http://localhost:8000 " ,
" use_wandb " : True ,
" max_token_length " : 8192 ,
" max_num_workers " : 2048 ,
" worker_timeout " : 3600 ,
" total_steps " : 2500 ,
" steps_per_eval " : 25 ,
" max_batches_offpolicy " : 3 ,
" inference_weight " : 1.0 ,
" eval_limit_ratio " : 0.1 ,
} ,
" openai " : [
{
" model_name " : " Qwen/Qwen3-8B " ,
" base_url " : " http://localhost:8001/v1 " ,
" api_key " : " x " ,
" weight " : 1.0 ,
" num_requests_for_eval " : 256 ,
" timeout " : 3600 ,
2026-02-04 13:57:59 -08:00
" server_type " : " sglang " , # Tinker uses sglang for actual training
2026-02-04 10:36:01 -08:00
}
] ,
" tinker " : {
" lora_rank " : 32 ,
" learning_rate " : 0.00004 ,
" max_token_trainer_length " : 9000 ,
" checkpoint_dir " : " ./temp/ " ,
" save_checkpoint_interval " : 25 ,
} ,
" slurm " : False ,
" testing " : False ,
}
2026-02-03 23:41:26 -08:00
2026-02-04 10:36:01 -08:00
LOCKED_FIELD_NAMES = set ( LOCKED_FIELDS . get ( " env " , { } ) . keys ( ) )
# ============================================================================
# State Management
# ============================================================================
@dataclass
class EnvironmentInfo :
""" Information about a discovered environment. """
name : str
class_name : str
file_path : str
description : str = " "
config_class : str = " BaseEnvConfig "
@dataclass
class RunState :
""" State for a training run. """
run_id : str
environment : str
config : Dict [ str , Any ]
status : str = " pending " # pending, starting, running, stopping, stopped, completed, failed
error_message : str = " "
wandb_project : str = " "
wandb_run_name : str = " "
start_time : float = 0.0
# Process handles
api_process : Optional [ subprocess . Popen ] = None
trainer_process : Optional [ subprocess . Popen ] = None
env_process : Optional [ subprocess . Popen ] = None
# Global state
_environments : List [ EnvironmentInfo ] = [ ]
_current_env : Optional [ str ] = None
_current_config : Dict [ str , Any ] = { }
_env_config_cache : Dict [ str , Dict [ str , Dict [ str , Any ] ] ] = { }
_active_runs : Dict [ str , RunState ] = { }
2026-02-03 23:41:26 -08:00
_last_status_check : Dict [ str , float ] = { }
2026-02-04 10:36:01 -08:00
# Rate limiting for status checks (30 minutes)
MIN_STATUS_CHECK_INTERVAL = 30 * 60
2026-02-03 23:41:26 -08:00
# ============================================================================
2026-02-04 10:36:01 -08:00
# Environment Discovery
2026-02-03 23:41:26 -08:00
# ============================================================================
2026-02-04 10:36:01 -08:00
def _scan_environments ( ) - > List [ EnvironmentInfo ] :
"""
Scan the environments directory for BaseEnv subclasses using AST .
"""
environments = [ ]
2026-02-03 23:41:26 -08:00
2026-02-04 10:36:01 -08:00
if not ENVIRONMENTS_DIR . exists ( ) :
return environments
for py_file in ENVIRONMENTS_DIR . glob ( " *.py " ) :
if py_file . name . startswith ( " _ " ) :
continue
2026-02-03 23:41:26 -08:00
try :
2026-02-04 10:36:01 -08:00
with open ( py_file , " r " ) as f :
tree = ast . parse ( f . read ( ) )
for node in ast . walk ( tree ) :
if isinstance ( node , ast . ClassDef ) :
# Check if class has BaseEnv as base
for base in node . bases :
base_name = " "
if isinstance ( base , ast . Name ) :
base_name = base . id
elif isinstance ( base , ast . Attribute ) :
base_name = base . attr
if base_name == " BaseEnv " :
# Extract name from class attribute if present
env_name = py_file . stem
description = " "
config_class = " BaseEnvConfig "
for item in node . body :
if isinstance ( item , ast . Assign ) :
for target in item . targets :
if isinstance ( target , ast . Name ) :
if target . id == " name " and isinstance ( item . value , ast . Constant ) :
env_name = item . value . value
elif target . id == " env_config_cls " and isinstance ( item . value , ast . Name ) :
config_class = item . value . id
# Get docstring
if isinstance ( item , ast . Expr ) and isinstance ( item . value , ast . Constant ) :
if isinstance ( item . value . value , str ) and not description :
description = item . value . value . split ( " \n " ) [ 0 ] . strip ( )
environments . append ( EnvironmentInfo (
name = env_name ,
class_name = node . name ,
file_path = str ( py_file ) ,
description = description or f " Environment from { py_file . name } " ,
config_class = config_class ,
) )
break
2026-02-03 23:41:26 -08:00
except Exception as e :
2026-02-04 10:36:01 -08:00
print ( f " Warning: Could not parse { py_file } : { e } " )
return environments
def _get_env_config_fields ( env_file_path : str ) - > Dict [ str , Dict [ str , Any ] ] :
"""
Dynamically import an environment and extract its config fields .
2026-02-04 13:57:59 -08:00
Uses config_init ( ) to get the actual config class , with fallback to
directly importing BaseEnvConfig if config_init fails .
2026-02-04 10:36:01 -08:00
"""
try :
# Load the environment module
spec = importlib . util . spec_from_file_location ( " env_module " , env_file_path )
module = importlib . util . module_from_spec ( spec )
sys . modules [ " env_module " ] = module
spec . loader . exec_module ( module )
# Find the BaseEnv subclass
env_class = None
for name , obj in vars ( module ) . items ( ) :
if isinstance ( obj , type ) and name != " BaseEnv " :
if hasattr ( obj , " config_init " ) and callable ( getattr ( obj , " config_init " ) ) :
env_class = obj
break
if not env_class :
return { }
2026-02-04 13:57:59 -08:00
# Try calling config_init to get the actual config class
config_class = None
try :
env_config , server_configs = env_class . config_init ( )
config_class = type ( env_config )
except Exception as config_error :
# Fallback: try to import BaseEnvConfig directly from atroposlib
print ( f " Note: config_init failed ( { config_error } ), using BaseEnvConfig defaults " )
try :
from atroposlib . envs . base import BaseEnvConfig
config_class = BaseEnvConfig
except ImportError :
return { }
if not config_class :
return { }
# Helper to make values JSON-serializable (handle enums, etc.)
def make_serializable ( val ) :
if val is None :
return None
if hasattr ( val , ' value ' ) : # Enum
return val . value
if hasattr ( val , ' name ' ) and hasattr ( val , ' __class__ ' ) and ' Enum ' in str ( type ( val ) ) :
return val . name
return val
2026-02-04 10:36:01 -08:00
# Extract fields from the Pydantic model
fields = { }
for field_name , field_info in config_class . model_fields . items ( ) :
field_type = field_info . annotation
2026-02-04 13:57:59 -08:00
default = make_serializable ( field_info . default )
2026-02-04 10:36:01 -08:00
description = field_info . description or " "
is_locked = field_name in LOCKED_FIELD_NAMES
# Convert type to string
type_name = getattr ( field_type , " __name__ " , str ( field_type ) )
if hasattr ( field_type , " __origin__ " ) :
type_name = str ( field_type )
2026-02-04 13:57:59 -08:00
locked_value = LOCKED_FIELDS . get ( " env " , { } ) . get ( field_name , default )
current_value = make_serializable ( locked_value ) if is_locked else default
2026-02-04 10:36:01 -08:00
fields [ field_name ] = {
" type " : type_name ,
2026-02-04 13:57:59 -08:00
" default " : default ,
2026-02-04 10:36:01 -08:00
" description " : description ,
" locked " : is_locked ,
2026-02-04 13:57:59 -08:00
" current_value " : current_value ,
2026-02-04 10:36:01 -08:00
}
return fields
except Exception as e :
print ( f " Warning: Could not introspect environment config: { e } " )
return { }
def _initialize_environments ( ) :
""" Initialize environment list on first use. """
global _environments
if not _environments :
_environments = _scan_environments ( )
# ============================================================================
# Subprocess Management
# ============================================================================
async def _spawn_training_run ( run_state : RunState , config_path : Path ) :
"""
Spawn the three processes needed for training :
1. run - api ( Atropos API server )
2. launch_training . py ( Tinker trainer + inference server )
3. environment . py serve ( the Atropos environment )
"""
run_id = run_state . run_id
2026-03-12 09:11:44 -07:00
_ensure_logs_dir ( )
2026-02-04 10:36:01 -08:00
# Log file paths
api_log = LOGS_DIR / f " api_ { run_id } .log "
trainer_log = LOGS_DIR / f " trainer_ { run_id } .log "
env_log = LOGS_DIR / f " env_ { run_id } .log "
try :
# Step 1: Start the Atropos API server (run-api)
print ( f " [ { run_id } ] Starting Atropos API server (run-api)... " )
2026-03-08 22:38:02 +03:00
# File must stay open while the subprocess runs; we store the handle
# on run_state so _stop_training_run() can close it when done.
api_log_file = open ( api_log , " w " ) # closed by _stop_training_run
run_state . api_log_file = api_log_file
2026-02-04 10:36:01 -08:00
run_state . api_process = subprocess . Popen (
[ " run-api " ] ,
stdout = api_log_file ,
stderr = subprocess . STDOUT ,
cwd = str ( TINKER_ATROPOS_ROOT ) ,
)
# Wait for API to start
await asyncio . sleep ( 5 )
if run_state . api_process . poll ( ) is not None :
run_state . status = " failed "
run_state . error_message = f " API server exited with code { run_state . api_process . returncode } . Check { api_log } "
2026-03-10 17:09:51 -07:00
_stop_training_run ( run_state )
2026-02-04 10:36:01 -08:00
return
print ( f " [ { run_id } ] Atropos API server started " )
# Step 2: Start the Tinker trainer
print ( f " [ { run_id } ] Starting Tinker trainer: launch_training.py --config { config_path } " )
2026-03-08 22:38:02 +03:00
trainer_log_file = open ( trainer_log , " w " ) # closed by _stop_training_run
run_state . trainer_log_file = trainer_log_file
2026-02-04 10:36:01 -08:00
run_state . trainer_process = subprocess . Popen (
2026-02-04 13:57:59 -08:00
[ sys . executable , " launch_training.py " , " --config " , str ( config_path ) ] ,
2026-02-04 10:36:01 -08:00
stdout = trainer_log_file ,
stderr = subprocess . STDOUT ,
cwd = str ( TINKER_ATROPOS_ROOT ) ,
env = { * * os . environ , " TINKER_API_KEY " : os . getenv ( " TINKER_API_KEY " , " " ) } ,
)
# Wait for trainer to initialize (it starts FastAPI inference server on 8001)
print ( f " [ { run_id } ] Waiting 30 seconds for trainer to initialize... " )
await asyncio . sleep ( 30 )
if run_state . trainer_process . poll ( ) is not None :
run_state . status = " failed "
run_state . error_message = f " Trainer exited with code { run_state . trainer_process . returncode } . Check { trainer_log } "
2026-03-10 17:09:51 -07:00
_stop_training_run ( run_state )
2026-02-04 10:36:01 -08:00
return
print ( f " [ { run_id } ] Trainer started, inference server on port 8001 " )
# Step 3: Start the environment
print ( f " [ { run_id } ] Waiting 90 more seconds before starting environment... " )
await asyncio . sleep ( 90 )
# Find the environment file
env_info = None
for env in _environments :
if env . name == run_state . environment :
env_info = env
break
if not env_info :
run_state . status = " failed "
run_state . error_message = f " Environment ' { run_state . environment } ' not found "
2026-03-10 17:09:51 -07:00
_stop_training_run ( run_state )
2026-02-04 10:36:01 -08:00
return
print ( f " [ { run_id } ] Starting environment: { env_info . file_path } serve " )
2026-03-08 22:38:02 +03:00
env_log_file = open ( env_log , " w " ) # closed by _stop_training_run
run_state . env_log_file = env_log_file
2026-02-04 10:36:01 -08:00
run_state . env_process = subprocess . Popen (
2026-02-04 13:57:59 -08:00
[ sys . executable , str ( env_info . file_path ) , " serve " , " --config " , str ( config_path ) ] ,
2026-02-04 10:36:01 -08:00
stdout = env_log_file ,
stderr = subprocess . STDOUT ,
cwd = str ( TINKER_ATROPOS_ROOT ) ,
)
# Wait for environment to connect
await asyncio . sleep ( 10 )
if run_state . env_process . poll ( ) is not None :
run_state . status = " failed "
run_state . error_message = f " Environment exited with code { run_state . env_process . returncode } . Check { env_log } "
2026-03-10 17:09:51 -07:00
_stop_training_run ( run_state )
2026-02-04 10:36:01 -08:00
return
run_state . status = " running "
run_state . start_time = time . time ( )
print ( f " [ { run_id } ] Training run started successfully! " )
# Start background monitoring
asyncio . create_task ( _monitor_training_run ( run_state ) )
except Exception as e :
run_state . status = " failed "
run_state . error_message = str ( e )
_stop_training_run ( run_state )
async def _monitor_training_run ( run_state : RunState ) :
""" Background task to monitor a training run. """
while run_state . status == " running " :
await asyncio . sleep ( 30 ) # Check every 30 seconds
# Check if any process has died
if run_state . env_process and run_state . env_process . poll ( ) is not None :
exit_code = run_state . env_process . returncode
if exit_code == 0 :
run_state . status = " completed "
else :
run_state . status = " failed "
run_state . error_message = f " Environment process exited with code { exit_code } "
_stop_training_run ( run_state )
break
if run_state . trainer_process and run_state . trainer_process . poll ( ) is not None :
exit_code = run_state . trainer_process . returncode
if exit_code == 0 :
run_state . status = " completed "
else :
run_state . status = " failed "
run_state . error_message = f " Trainer process exited with code { exit_code } "
_stop_training_run ( run_state )
break
if run_state . api_process and run_state . api_process . poll ( ) is not None :
run_state . status = " failed "
run_state . error_message = f " API server exited unexpectedly "
_stop_training_run ( run_state )
break
def _stop_training_run ( run_state : RunState ) :
""" Stop all processes for a training run. """
# Stop in reverse order: env -> trainer -> api
if run_state . env_process and run_state . env_process . poll ( ) is None :
print ( f " [ { run_state . run_id } ] Stopping environment process... " )
run_state . env_process . terminate ( )
try :
run_state . env_process . wait ( timeout = 10 )
except subprocess . TimeoutExpired :
run_state . env_process . kill ( )
if run_state . trainer_process and run_state . trainer_process . poll ( ) is None :
print ( f " [ { run_state . run_id } ] Stopping trainer process... " )
run_state . trainer_process . terminate ( )
try :
run_state . trainer_process . wait ( timeout = 10 )
except subprocess . TimeoutExpired :
run_state . trainer_process . kill ( )
if run_state . api_process and run_state . api_process . poll ( ) is None :
print ( f " [ { run_state . run_id } ] Stopping API server... " )
run_state . api_process . terminate ( )
try :
run_state . api_process . wait ( timeout = 10 )
except subprocess . TimeoutExpired :
run_state . api_process . kill ( )
if run_state . status == " running " :
run_state . status = " stopped "
2026-02-03 23:41:26 -08:00
2026-03-08 22:38:02 +03:00
# Close log file handles that were opened for subprocess stdout.
for attr in ( " env_log_file " , " trainer_log_file " , " api_log_file " ) :
fh = getattr ( run_state , attr , None )
if fh is not None :
try :
fh . close ( )
except Exception :
pass
setattr ( run_state , attr , None )
2026-02-03 23:41:26 -08:00
# ============================================================================
# Environment Discovery Tools
# ============================================================================
async def rl_list_environments ( ) - > str :
"""
List all available RL environments .
Scans tinker - atropos / tinker_atropos / environments / for Python files
containing classes that inherit from BaseEnv .
Returns information about each environment including :
- name : Environment identifier
- class_name : Python class name
- file_path : Path to the environment file
- description : Brief description if available
TIP : To create or modify RL environments :
1. Use terminal / file tools to inspect existing environments
2. Study how they load datasets , define verifiers , and structure rewards
3. Inspect HuggingFace datasets to understand data formats
4. Copy an existing environment as a template
Returns :
2026-02-04 10:36:01 -08:00
JSON string with list of environments
2026-02-03 23:41:26 -08:00
"""
2026-02-04 10:36:01 -08:00
_initialize_environments ( )
2026-02-03 23:41:26 -08:00
response = {
2026-02-04 10:36:01 -08:00
" environments " : [
{
" name " : env . name ,
" class_name " : env . class_name ,
" file_path " : env . file_path ,
" description " : env . description ,
}
for env in _environments
] ,
" count " : len ( _environments ) ,
2026-02-03 23:41:26 -08:00
" tips " : [
" Use rl_select_environment(name) to select an environment " ,
" Read the file_path with file tools to understand how each environment works " ,
" Look for load_dataset(), score_answer(), get_next_item() methods " ,
]
}
return json . dumps ( response , indent = 2 )
async def rl_select_environment ( name : str ) - > str :
"""
Select an RL environment for training .
2026-02-04 10:36:01 -08:00
This loads the environment ' s configuration fields into memory.
After selecting , use rl_get_current_config ( ) to see all configurable options
2026-02-03 23:41:26 -08:00
and rl_edit_config ( ) to modify specific fields .
Args :
name : Name of the environment to select ( from rl_list_environments )
Returns :
2026-02-04 10:36:01 -08:00
JSON string with selection result , file path , and configurable field count
2026-02-03 23:41:26 -08:00
2026-02-04 10:36:01 -08:00
TIP : Read the returned file_path to understand how the environment works .
2026-02-03 23:41:26 -08:00
"""
2026-02-04 10:36:01 -08:00
global _current_env , _current_config , _env_config_cache
_initialize_environments ( )
env_info = None
for env in _environments :
if env . name == name :
env_info = env
break
if not env_info :
return json . dumps ( {
" error " : f " Environment ' { name } ' not found " ,
" available " : [ e . name for e in _environments ] ,
} , indent = 2 )
_current_env = name
# Dynamically discover config fields
config_fields = _get_env_config_fields ( env_info . file_path )
_env_config_cache [ name ] = config_fields
# Initialize current config with defaults for non-locked fields
_current_config = { }
for field_name , field_info in config_fields . items ( ) :
if not field_info . get ( " locked " , False ) :
_current_config [ field_name ] = field_info . get ( " default " )
2026-02-04 13:57:59 -08:00
# Auto-set wandb_name to "{env_name}-DATETIME" to avoid overlaps
timestamp = datetime . now ( ) . strftime ( " % Y % m %d - % H % M % S " )
_current_config [ " wandb_name " ] = f " { name } - { timestamp } "
2026-02-04 10:36:01 -08:00
return json . dumps ( {
" message " : f " Selected environment: { name } " ,
" environment " : name ,
" file_path " : env_info . file_path ,
} , indent = 2 )
2026-02-03 23:41:26 -08:00
# ============================================================================
# Configuration Tools
# ============================================================================
async def rl_get_current_config ( ) - > str :
"""
Get the current environment configuration .
2026-02-04 09:36:51 -08:00
Returns all configurable fields for the selected environment .
Each environment may have different configuration options .
2026-02-03 23:41:26 -08:00
2026-02-04 09:36:51 -08:00
Fields are divided into :
- configurable_fields : Can be changed with rl_edit_config ( )
- locked_fields : Infrastructure settings that cannot be changed
2026-02-03 23:41:26 -08:00
Returns :
2026-02-04 09:36:51 -08:00
JSON string with configurable and locked fields
2026-02-03 23:41:26 -08:00
"""
2026-02-04 10:36:01 -08:00
if not _current_env :
return json . dumps ( {
" error " : " No environment selected. Use rl_select_environment(name) first. " ,
} , indent = 2 )
config_fields = _env_config_cache . get ( _current_env , { } )
configurable = [ ]
locked = [ ]
for field_name , field_info in config_fields . items ( ) :
field_data = {
" name " : field_name ,
" type " : field_info . get ( " type " , " unknown " ) ,
" default " : field_info . get ( " default " ) ,
" description " : field_info . get ( " description " , " " ) ,
" current_value " : _current_config . get ( field_name , field_info . get ( " default " ) ) ,
}
if field_info . get ( " locked " , False ) :
field_data [ " locked_value " ] = LOCKED_FIELDS . get ( " env " , { } ) . get ( field_name )
locked . append ( field_data )
else :
configurable . append ( field_data )
return json . dumps ( {
" environment " : _current_env ,
" configurable_fields " : configurable ,
" locked_fields " : locked ,
" tip " : " Use rl_edit_config(field, value) to change any configurable field. " ,
} , indent = 2 )
2026-02-03 23:41:26 -08:00
async def rl_edit_config ( field : str , value : Any ) - > str :
"""
Update a configuration field .
2026-02-04 09:36:51 -08:00
Use rl_get_current_config ( ) first to see available fields for the
selected environment . Each environment has different options .
Locked fields ( infrastructure settings ) cannot be changed .
2026-02-03 23:41:26 -08:00
Args :
2026-02-04 09:36:51 -08:00
field : Name of the field to update ( from rl_get_current_config )
2026-02-03 23:41:26 -08:00
value : New value for the field
Returns :
JSON string with updated config or error message
"""
2026-02-04 10:36:01 -08:00
global _current_config
if not _current_env :
return json . dumps ( {
" error " : " No environment selected. Use rl_select_environment(name) first. " ,
} , indent = 2 )
config_fields = _env_config_cache . get ( _current_env , { } )
if field not in config_fields :
return json . dumps ( {
" error " : f " Unknown field ' { field } ' " ,
" available_fields " : list ( config_fields . keys ( ) ) ,
} , indent = 2 )
field_info = config_fields [ field ]
if field_info . get ( " locked " , False ) :
return json . dumps ( {
" error " : f " Field ' { field } ' is locked and cannot be changed " ,
" locked_value " : LOCKED_FIELDS . get ( " env " , { } ) . get ( field ) ,
} , indent = 2 )
_current_config [ field ] = value
return json . dumps ( {
" message " : f " Updated { field } = { value } " ,
" field " : field ,
" value " : value ,
" config " : _current_config ,
} , indent = 2 )
2026-02-03 23:41:26 -08:00
# ============================================================================
# Training Management Tools
# ============================================================================
2026-02-04 09:36:51 -08:00
async def rl_start_training ( ) - > str :
2026-02-03 23:41:26 -08:00
"""
Start a new RL training run with the current environment and config .
Requires an environment to be selected first using rl_select_environment ( ) .
2026-02-04 10:36:01 -08:00
Use rl_edit_config ( ) to adjust configuration before starting .
2026-02-03 23:41:26 -08:00
2026-02-04 10:36:01 -08:00
This spawns three processes :
1. run - api ( Atropos trajectory API )
2. launch_training . py ( Tinker trainer + inference server )
3. environment . py serve ( the selected environment )
2026-02-03 23:41:26 -08:00
2026-02-04 09:36:51 -08:00
WARNING : Training runs take hours . Use rl_check_status ( ) to monitor
progress ( recommended : check every 30 minutes at most ) .
2026-02-03 23:41:26 -08:00
Returns :
JSON string with run_id and initial status
"""
2026-02-04 10:36:01 -08:00
global _active_runs
if not _current_env :
return json . dumps ( {
" error " : " No environment selected. Use rl_select_environment(name) first. " ,
} , indent = 2 )
# Check API keys
if not os . getenv ( " TINKER_API_KEY " ) :
return json . dumps ( {
" error " : " TINKER_API_KEY not set. Add it to ~/.hermes/.env " ,
} , indent = 2 )
# Find environment file
env_info = None
for env in _environments :
if env . name == _current_env :
env_info = env
break
if not env_info or not Path ( env_info . file_path ) . exists ( ) :
return json . dumps ( {
" error " : f " Environment file not found for ' { _current_env } ' " ,
} , indent = 2 )
# Generate run ID
run_id = str ( uuid . uuid4 ( ) ) [ : 8 ]
# Create config YAML
CONFIGS_DIR . mkdir ( exist_ok = True )
config_path = CONFIGS_DIR / f " run_ { run_id } .yaml "
# Start with locked config as base
import copy
run_config = copy . deepcopy ( LOCKED_FIELDS )
if " env " not in run_config :
run_config [ " env " ] = { }
# Apply configurable fields
for field_name , value in _current_config . items ( ) :
if value is not None and value != " " :
run_config [ " env " ] [ field_name ] = value
# Set WandB settings
wandb_project = _current_config . get ( " wandb_project " , " atropos-tinker " )
if " tinker " not in run_config :
run_config [ " tinker " ] = { }
run_config [ " tinker " ] [ " wandb_project " ] = wandb_project
run_config [ " tinker " ] [ " wandb_run_name " ] = f " { _current_env } - { run_id } "
if " wandb_name " in _current_config and _current_config [ " wandb_name " ] :
run_config [ " env " ] [ " wandb_name " ] = _current_config [ " wandb_name " ]
with open ( config_path , " w " ) as f :
yaml . dump ( run_config , f , default_flow_style = False )
# Create run state
run_state = RunState (
run_id = run_id ,
environment = _current_env ,
config = _current_config . copy ( ) ,
status = " starting " ,
wandb_project = wandb_project ,
wandb_run_name = f " { _current_env } - { run_id } " ,
)
_active_runs [ run_id ] = run_state
# Start training in background
asyncio . create_task ( _spawn_training_run ( run_state , config_path ) )
return json . dumps ( {
" run_id " : run_id ,
" status " : " starting " ,
" environment " : _current_env ,
" config " : _current_config ,
" wandb_project " : wandb_project ,
" wandb_run_name " : f " { _current_env } - { run_id } " ,
" config_path " : str ( config_path ) ,
" logs " : {
" api " : str ( LOGS_DIR / f " api_ { run_id } .log " ) ,
" trainer " : str ( LOGS_DIR / f " trainer_ { run_id } .log " ) ,
" env " : str ( LOGS_DIR / f " env_ { run_id } .log " ) ,
} ,
" message " : " Training starting. Use rl_check_status(run_id) to monitor (recommended: every 30 minutes). " ,
} , indent = 2 )
2026-02-03 23:41:26 -08:00
async def rl_check_status ( run_id : str ) - > str :
"""
Get status and metrics for a training run .
RATE LIMITED : For long - running training , this function enforces a
minimum 30 - minute interval between checks for the same run_id .
Args :
run_id : The run ID returned by rl_start_training ( )
Returns :
2026-02-04 10:36:01 -08:00
JSON string with run status and metrics
2026-02-03 23:41:26 -08:00
"""
global _last_status_check
# Check rate limiting
now = time . time ( )
if run_id in _last_status_check :
elapsed = now - _last_status_check [ run_id ]
if elapsed < MIN_STATUS_CHECK_INTERVAL :
remaining = MIN_STATUS_CHECK_INTERVAL - elapsed
return json . dumps ( {
" rate_limited " : True ,
" run_id " : run_id ,
" message " : f " Rate limited. Next check available in { remaining / 60 : .0f } minutes. " ,
" next_check_in_seconds " : remaining ,
} , indent = 2 )
_last_status_check [ run_id ] = now
2026-02-04 10:36:01 -08:00
if run_id not in _active_runs :
return json . dumps ( {
" error " : f " Run ' { run_id } ' not found " ,
" active_runs " : list ( _active_runs . keys ( ) ) ,
} , indent = 2 )
run_state = _active_runs [ run_id ]
# Check process status
processes = {
" api " : run_state . api_process . poll ( ) if run_state . api_process else None ,
" trainer " : run_state . trainer_process . poll ( ) if run_state . trainer_process else None ,
" env " : run_state . env_process . poll ( ) if run_state . env_process else None ,
}
running_time = time . time ( ) - run_state . start_time if run_state . start_time else 0
result = {
" run_id " : run_id ,
" status " : run_state . status ,
" environment " : run_state . environment ,
" running_time_minutes " : running_time / 60 ,
" processes " : {
name : " running " if code is None else f " exited ( { code } ) "
for name , code in processes . items ( )
} ,
" wandb_project " : run_state . wandb_project ,
" wandb_run_name " : run_state . wandb_run_name ,
" logs " : {
" api " : str ( LOGS_DIR / f " api_ { run_id } .log " ) ,
" trainer " : str ( LOGS_DIR / f " trainer_ { run_id } .log " ) ,
" env " : str ( LOGS_DIR / f " env_ { run_id } .log " ) ,
} ,
}
if run_state . error_message :
result [ " error " ] = run_state . error_message
# Try to get WandB metrics if available
try :
import wandb
api = wandb . Api ( )
runs = api . runs (
f " { os . getenv ( ' WANDB_ENTITY ' , ' nousresearch ' ) } / { run_state . wandb_project } " ,
filters = { " display_name " : run_state . wandb_run_name }
)
if runs :
wandb_run = runs [ 0 ]
result [ " wandb_url " ] = wandb_run . url
result [ " metrics " ] = {
" step " : wandb_run . summary . get ( " _step " , 0 ) ,
" reward_mean " : wandb_run . summary . get ( " train/reward_mean " ) ,
" percent_correct " : wandb_run . summary . get ( " train/percent_correct " ) ,
" eval_percent_correct " : wandb_run . summary . get ( " eval/percent_correct " ) ,
}
except Exception as e :
result [ " wandb_error " ] = str ( e )
2026-02-03 23:41:26 -08:00
return json . dumps ( result , indent = 2 )
async def rl_stop_training ( run_id : str ) - > str :
"""
Stop a running training job .
Args :
run_id : The run ID to stop
Returns :
JSON string with stop confirmation
"""
2026-02-04 10:36:01 -08:00
if run_id not in _active_runs :
return json . dumps ( {
" error " : f " Run ' { run_id } ' not found " ,
" active_runs " : list ( _active_runs . keys ( ) ) ,
} , indent = 2 )
run_state = _active_runs [ run_id ]
if run_state . status not in ( " running " , " starting " ) :
return json . dumps ( {
" message " : f " Run ' { run_id } ' is not running (status: { run_state . status } ) " ,
} , indent = 2 )
_stop_training_run ( run_state )
return json . dumps ( {
" message " : f " Stopped training run ' { run_id } ' " ,
" run_id " : run_id ,
" status " : run_state . status ,
} , indent = 2 )
2026-02-03 23:41:26 -08:00
async def rl_get_results ( run_id : str ) - > str :
"""
2026-02-04 10:36:01 -08:00
Get final results and metrics for a training run .
2026-02-03 23:41:26 -08:00
Args :
run_id : The run ID to get results for
Returns :
2026-02-04 10:36:01 -08:00
JSON string with final results
2026-02-03 23:41:26 -08:00
"""
2026-02-04 10:36:01 -08:00
if run_id not in _active_runs :
return json . dumps ( {
" error " : f " Run ' { run_id } ' not found " ,
} , indent = 2 )
run_state = _active_runs [ run_id ]
result = {
" run_id " : run_id ,
" status " : run_state . status ,
" environment " : run_state . environment ,
" wandb_project " : run_state . wandb_project ,
" wandb_run_name " : run_state . wandb_run_name ,
}
# Get WandB metrics
try :
import wandb
api = wandb . Api ( )
runs = api . runs (
f " { os . getenv ( ' WANDB_ENTITY ' , ' nousresearch ' ) } / { run_state . wandb_project } " ,
filters = { " display_name " : run_state . wandb_run_name }
)
if runs :
wandb_run = runs [ 0 ]
result [ " wandb_url " ] = wandb_run . url
result [ " final_metrics " ] = dict ( wandb_run . summary )
result [ " history " ] = [ dict ( row ) for row in wandb_run . history ( samples = 10 ) ]
except Exception as e :
result [ " wandb_error " ] = str ( e )
2026-02-03 23:41:26 -08:00
return json . dumps ( result , indent = 2 )
2026-02-04 10:36:01 -08:00
async def rl_list_runs ( ) - > str :
"""
List all training runs ( active and completed ) .
Returns :
JSON string with list of runs and their status
"""
runs = [ ]
for run_id , run_state in _active_runs . items ( ) :
runs . append ( {
" run_id " : run_id ,
" environment " : run_state . environment ,
" status " : run_state . status ,
" wandb_run_name " : run_state . wandb_run_name ,
} )
return json . dumps ( {
" runs " : runs ,
" count " : len ( runs ) ,
} , indent = 2 )
2026-02-03 23:41:26 -08:00
# ============================================================================
2026-02-04 10:36:01 -08:00
# Inference Testing (via Atropos `process` mode with OpenRouter)
2026-02-03 23:41:26 -08:00
# ============================================================================
2026-02-04 10:36:01 -08:00
# Test models at different scales for robustness testing
2026-02-04 13:57:59 -08:00
# These are cheap, capable models on OpenRouter for testing parsing/scoring
2026-02-04 10:36:01 -08:00
TEST_MODELS = [
{ " id " : " qwen/qwen3-8b " , " name " : " Qwen3 8B " , " scale " : " small " } ,
2026-02-04 13:57:59 -08:00
{ " id " : " z-ai/glm-4.7-flash " , " name " : " GLM-4.7 Flash " , " scale " : " medium " } ,
2026-03-06 16:47:48 +08:00
{ " id " : " minimax/minimax-m2.5 " , " name " : " MiniMax M2.5 " , " scale " : " large " } ,
2026-02-04 10:36:01 -08:00
]
# Default test parameters - quick but representative
DEFAULT_NUM_STEPS = 3 # Number of steps (items) to test
DEFAULT_GROUP_SIZE = 16 # Completions per item (like training)
2026-02-03 23:41:26 -08:00
async def rl_test_inference (
2026-02-04 10:36:01 -08:00
num_steps : int = DEFAULT_NUM_STEPS ,
group_size : int = DEFAULT_GROUP_SIZE ,
models : Optional [ List [ str ] ] = None ,
2026-02-03 23:41:26 -08:00
) - > str :
"""
2026-02-04 10:36:01 -08:00
Quick inference test for any environment using Atropos ' s `process` mode.
2026-02-03 23:41:26 -08:00
2026-02-04 10:36:01 -08:00
Runs a few steps of inference + scoring to validate :
- Environment loads correctly
- Prompt construction works
- Inference parsing is robust ( tested with multiple model scales )
- Verifier / scoring logic works
2026-02-03 23:41:26 -08:00
2026-02-04 10:36:01 -08:00
Default : 3 steps × 16 completions = 48 total rollouts per model .
Tests 3 models = 144 total rollouts . Quick sanity check .
Test models ( varying intelligence levels for robustness ) :
- qwen / qwen3 - 8 b ( small )
- zhipu - ai / glm - 4 - flash ( medium )
- minimax / minimax - m1 ( large )
2026-02-03 23:41:26 -08:00
Args :
2026-02-04 10:36:01 -08:00
num_steps : Steps to run ( default : 3 , max recommended for testing )
group_size : Completions per step ( default : 16 , like training )
models : Optional model IDs to test . If None , uses all 3 test models .
2026-02-03 23:41:26 -08:00
Returns :
2026-02-04 10:36:01 -08:00
JSON with results per model : steps_tested , accuracy , scores
2026-02-03 23:41:26 -08:00
"""
2026-02-04 10:36:01 -08:00
if not _current_env :
return json . dumps ( {
" error " : " No environment selected. Use rl_select_environment(name) first. " ,
} , indent = 2 )
2026-02-03 23:41:26 -08:00
2026-02-04 10:36:01 -08:00
api_key = os . getenv ( " OPENROUTER_API_KEY " )
if not api_key :
return json . dumps ( {
" error " : " OPENROUTER_API_KEY not set. Required for inference testing. " ,
} , indent = 2 )
# Find environment info
env_info = None
for env in _environments :
if env . name == _current_env :
env_info = env
break
if not env_info :
return json . dumps ( {
" error " : f " Environment ' { _current_env } ' not found " ,
} , indent = 2 )
# Determine which models to test
if models :
test_models = [ m for m in TEST_MODELS if m [ " id " ] in models ]
if not test_models :
test_models = [ { " id " : m , " name " : m , " scale " : " custom " } for m in models ]
else :
test_models = TEST_MODELS
# Calculate total rollouts for logging
total_rollouts_per_model = num_steps * group_size
total_rollouts = total_rollouts_per_model * len ( test_models )
results = {
" environment " : _current_env ,
" environment_file " : env_info . file_path ,
" test_config " : {
" num_steps " : num_steps ,
" group_size " : group_size ,
" rollouts_per_model " : total_rollouts_per_model ,
" total_rollouts " : total_rollouts ,
} ,
" models_tested " : [ ] ,
}
# Create output directory for test results
2026-03-12 09:11:44 -07:00
_ensure_logs_dir ( )
2026-02-04 10:36:01 -08:00
test_output_dir = LOGS_DIR / " inference_tests "
test_output_dir . mkdir ( exist_ok = True )
for model_info in test_models :
model_id = model_info [ " id " ]
model_safe_name = model_id . replace ( " / " , " _ " )
print ( f " \n { ' = ' * 60 } " )
print ( f " Testing with { model_info [ ' name ' ] } ( { model_id } ) " )
print ( f " { ' = ' * 60 } " )
# Output file for this test run
output_file = test_output_dir / f " test_ { _current_env } _ { model_safe_name } .jsonl "
2026-02-04 21:07:07 -08:00
# Generate unique run ID for wandb
test_run_id = str ( uuid . uuid4 ( ) ) [ : 8 ]
wandb_run_name = f " test_inference_RSIAgent_ { _current_env } _ { test_run_id } "
2026-02-04 10:36:01 -08:00
# Build the process command using Atropos's built-in CLI
# This runs the environment's actual code with OpenRouter as the inference backend
2026-02-04 13:57:59 -08:00
# We pass our locked settings + test-specific overrides via CLI args
2026-02-04 10:36:01 -08:00
cmd = [
2026-02-04 13:57:59 -08:00
sys . executable , env_info . file_path , " process " ,
# Test-specific overrides
2026-02-04 10:36:01 -08:00
" --env.total_steps " , str ( num_steps ) ,
" --env.group_size " , str ( group_size ) ,
2026-02-04 21:07:07 -08:00
" --env.use_wandb " , " true " , # Enable wandb for test tracking
" --env.wandb_name " , wandb_run_name ,
2026-02-04 10:36:01 -08:00
" --env.data_path_to_save_groups " , str ( output_file ) ,
2026-02-04 13:57:59 -08:00
# Use locked settings from our config
" --env.tokenizer_name " , LOCKED_FIELDS [ " env " ] [ " tokenizer_name " ] ,
" --env.max_token_length " , str ( LOCKED_FIELDS [ " env " ] [ " max_token_length " ] ) ,
" --env.max_num_workers " , str ( LOCKED_FIELDS [ " env " ] [ " max_num_workers " ] ) ,
" --env.max_batches_offpolicy " , str ( LOCKED_FIELDS [ " env " ] [ " max_batches_offpolicy " ] ) ,
# OpenRouter config for inference testing
# IMPORTANT: Use server_type=openai for OpenRouter (not sglang)
# sglang is only for actual training with Tinker's inference server
2026-02-04 10:36:01 -08:00
" --openai.base_url " , " https://openrouter.ai/api/v1 " ,
" --openai.api_key " , api_key ,
" --openai.model_name " , model_id ,
2026-02-04 13:57:59 -08:00
" --openai.server_type " , " openai " , # OpenRouter is OpenAI-compatible
" --openai.health_check " , " false " , # OpenRouter doesn't have health endpoint
2026-02-04 10:36:01 -08:00
]
2026-02-04 13:57:59 -08:00
# Debug: Print the full command
cmd_str = " " . join ( str ( c ) for c in cmd )
# Hide API key in printed output
cmd_display = cmd_str . replace ( api_key , " ***API_KEY*** " )
print ( f " Command: { cmd_display } " )
print ( f " Working dir: { TINKER_ATROPOS_ROOT } " )
2026-02-04 21:07:07 -08:00
print ( f " WandB run: { wandb_run_name } " )
2026-02-04 10:36:01 -08:00
print ( f " { num_steps } steps × { group_size } completions = { total_rollouts_per_model } rollouts " )
model_results = {
" model " : model_id ,
" name " : model_info [ " name " ] ,
" scale " : model_info [ " scale " ] ,
2026-02-04 21:07:07 -08:00
" wandb_run " : wandb_run_name ,
2026-02-04 10:36:01 -08:00
" output_file " : str ( output_file ) ,
" steps " : [ ] ,
" steps_tested " : 0 ,
" total_completions " : 0 ,
" correct_completions " : 0 ,
}
try :
2026-02-04 21:07:07 -08:00
# Run the process command with real-time output streaming
2026-02-04 10:36:01 -08:00
process = await asyncio . create_subprocess_exec (
* cmd ,
stdout = asyncio . subprocess . PIPE ,
stderr = asyncio . subprocess . PIPE ,
cwd = str ( TINKER_ATROPOS_ROOT ) ,
)
2026-02-04 21:07:07 -08:00
# Stream output in real-time while collecting for logs
stdout_lines = [ ]
stderr_lines = [ ]
log_file = test_output_dir / f " test_ { _current_env } _ { model_safe_name } .log "
async def read_stream ( stream , lines_list , prefix = " " ) :
""" Read stream line by line and print in real-time. """
while True :
line = await stream . readline ( )
if not line :
break
decoded = line . decode ( ) . rstrip ( )
lines_list . append ( decoded )
# Print progress-related lines in real-time
if any ( kw in decoded . lower ( ) for kw in [ ' processing ' , ' group ' , ' step ' , ' progress ' , ' % ' , ' completed ' ] ) :
print ( f " { prefix } { decoded } " )
# Read both streams concurrently with timeout
try :
await asyncio . wait_for (
asyncio . gather (
read_stream ( process . stdout , stdout_lines , " 📊 " ) ,
read_stream ( process . stderr , stderr_lines , " ⚠️ " ) ,
) ,
timeout = 600 , # 10 minute timeout per model
)
except asyncio . TimeoutError :
process . kill ( )
raise
await process . wait ( )
2026-02-04 10:36:01 -08:00
2026-02-04 21:07:07 -08:00
# Combine output for logging
stdout_text = " \n " . join ( stdout_lines )
stderr_text = " \n " . join ( stderr_lines )
2026-02-04 13:57:59 -08:00
# Write logs to files for inspection outside CLI
with open ( log_file , " w " ) as f :
f . write ( f " Command: { cmd_display } \n " )
f . write ( f " Working dir: { TINKER_ATROPOS_ROOT } \n " )
f . write ( f " Return code: { process . returncode } \n " )
f . write ( f " \n { ' = ' * 60 } \n " )
f . write ( f " STDOUT: \n { ' = ' * 60 } \n " )
f . write ( stdout_text or " (empty) \n " )
f . write ( f " \n { ' = ' * 60 } \n " )
f . write ( f " STDERR: \n { ' = ' * 60 } \n " )
f . write ( stderr_text or " (empty) \n " )
print ( f " Log file: { log_file } " )
2026-02-04 10:36:01 -08:00
if process . returncode != 0 :
model_results [ " error " ] = f " Process exited with code { process . returncode } "
2026-02-04 13:57:59 -08:00
model_results [ " stderr " ] = stderr_text [ - 1000 : ]
model_results [ " stdout " ] = stdout_text [ - 1000 : ]
model_results [ " log_file " ] = str ( log_file )
print ( f " \n ❌ Error: { model_results [ ' error ' ] } " )
2026-02-04 21:07:07 -08:00
# Print last few lines of stderr for debugging
if stderr_lines :
print ( f " Last errors: " )
for line in stderr_lines [ - 5 : ] :
print ( f " { line } " )
2026-02-04 10:36:01 -08:00
else :
2026-02-04 13:57:59 -08:00
print ( f " \n ✅ Process completed successfully " )
print ( f " Output file: { output_file } " )
print ( f " File exists: { output_file . exists ( ) } " )
2026-02-04 10:36:01 -08:00
# Parse the output JSONL file
if output_file . exists ( ) :
# Read JSONL file (one JSON object per line = one step)
with open ( output_file , " r " ) as f :
for line in f :
line = line . strip ( )
if not line :
continue
try :
item = json . loads ( line )
scores = item . get ( " scores " , [ ] )
model_results [ " steps_tested " ] + = 1
model_results [ " total_completions " ] + = len ( scores )
correct = sum ( 1 for s in scores if s > 0 )
model_results [ " correct_completions " ] + = correct
model_results [ " steps " ] . append ( {
" step " : model_results [ " steps_tested " ] ,
" completions " : len ( scores ) ,
" correct " : correct ,
" scores " : scores ,
} )
except json . JSONDecodeError :
continue
print ( f " Completed { model_results [ ' steps_tested ' ] } steps " )
else :
model_results [ " error " ] = f " Output file not created: { output_file } "
except asyncio . TimeoutError :
model_results [ " error " ] = " Process timed out after 10 minutes "
print ( f " Timeout! " )
except Exception as e :
model_results [ " error " ] = str ( e )
print ( f " Error: { e } " )
# Calculate stats
if model_results [ " total_completions " ] > 0 :
model_results [ " accuracy " ] = round (
model_results [ " correct_completions " ] / model_results [ " total_completions " ] , 3
)
else :
model_results [ " accuracy " ] = 0
if model_results [ " steps_tested " ] > 0 :
steps_with_correct = sum ( 1 for s in model_results [ " steps " ] if s . get ( " correct " , 0 ) > 0 )
model_results [ " steps_with_correct " ] = steps_with_correct
model_results [ " step_success_rate " ] = round (
steps_with_correct / model_results [ " steps_tested " ] , 3
)
else :
model_results [ " steps_with_correct " ] = 0
model_results [ " step_success_rate " ] = 0
print ( f " Results: { model_results [ ' correct_completions ' ] } / { model_results [ ' total_completions ' ] } correct " )
print ( f " Accuracy: { model_results [ ' accuracy ' ] : .1% } " )
results [ " models_tested " ] . append ( model_results )
# Overall summary
working_models = [ m for m in results [ " models_tested " ] if m . get ( " steps_tested " , 0 ) > 0 ]
results [ " summary " ] = {
" steps_requested " : num_steps ,
" models_tested " : len ( test_models ) ,
" models_succeeded " : len ( working_models ) ,
" best_model " : max ( working_models , key = lambda x : x . get ( " accuracy " , 0 ) ) [ " model " ] if working_models else None ,
" avg_accuracy " : round (
sum ( m . get ( " accuracy " , 0 ) for m in working_models ) / len ( working_models ) , 3
) if working_models else 0 ,
" environment_working " : len ( working_models ) > 0 ,
" output_directory " : str ( test_output_dir ) ,
}
return json . dumps ( results , indent = 2 )
2026-02-03 23:41:26 -08:00
# ============================================================================
# Requirements Check
# ============================================================================
2026-02-07 00:05:04 +00:00
def check_rl_python_version ( ) - > bool :
"""
Check if Python version meets the minimum for RL tools .
tinker - atropos depends on the ' tinker ' package which requires Python > = 3.11 .
"""
return sys . version_info > = ( 3 , 11 )
2026-02-03 23:41:26 -08:00
def check_rl_api_keys ( ) - > bool :
"""
2026-02-07 00:05:04 +00:00
Check if required API keys and Python version are available .
RL training requires :
- Python > = 3.11 ( tinker package requirement )
- TINKER_API_KEY for the Tinker training API
- WANDB_API_KEY for Weights & Biases metrics
2026-02-03 23:41:26 -08:00
"""
2026-02-07 00:05:04 +00:00
if not check_rl_python_version ( ) :
return False
2026-02-03 23:41:26 -08:00
tinker_key = os . getenv ( " TINKER_API_KEY " )
wandb_key = os . getenv ( " WANDB_API_KEY " )
return bool ( tinker_key ) and bool ( wandb_key )
def get_missing_keys ( ) - > List [ str ] :
"""
2026-02-07 00:05:04 +00:00
Get list of missing requirements for RL tools ( API keys and Python version ) .
2026-02-03 23:41:26 -08:00
"""
missing = [ ]
2026-02-07 00:05:04 +00:00
if not check_rl_python_version ( ) :
missing . append ( f " Python >= 3.11 (current: { sys . version_info . major } . { sys . version_info . minor } ) " )
2026-02-03 23:41:26 -08:00
if not os . getenv ( " TINKER_API_KEY " ) :
missing . append ( " TINKER_API_KEY " )
if not os . getenv ( " WANDB_API_KEY " ) :
missing . append ( " WANDB_API_KEY " )
return missing
2026-02-21 20:22:33 -08:00
# ---------------------------------------------------------------------------
# Schemas + Registry
# ---------------------------------------------------------------------------
from tools . registry import registry
RL_LIST_ENVIRONMENTS_SCHEMA = { " name " : " rl_list_environments " , " description " : " List all available RL environments. Returns environment names, paths, and descriptions. TIP: Read the file_path with file tools to understand how each environment works (verifiers, data loading, rewards). " , " parameters " : { " type " : " object " , " properties " : { } , " required " : [ ] } }
RL_SELECT_ENVIRONMENT_SCHEMA = { " name " : " rl_select_environment " , " description " : " Select an RL environment for training. Loads the environment ' s default configuration. After selecting, use rl_get_current_config() to see settings and rl_edit_config() to modify them. " , " parameters " : { " type " : " object " , " properties " : { " name " : { " type " : " string " , " description " : " Name of the environment to select (from rl_list_environments) " } } , " required " : [ " name " ] } }
RL_GET_CURRENT_CONFIG_SCHEMA = { " name " : " rl_get_current_config " , " description " : " Get the current environment configuration. Returns only fields that can be modified: group_size, max_token_length, total_steps, steps_per_eval, use_wandb, wandb_name, max_num_workers. " , " parameters " : { " type " : " object " , " properties " : { } , " required " : [ ] } }
RL_EDIT_CONFIG_SCHEMA = { " name " : " rl_edit_config " , " description " : " Update a configuration field. Use rl_get_current_config() first to see all available fields for the selected environment. Each environment has different configurable options. Infrastructure settings (tokenizer, URLs, lora_rank, learning_rate) are locked. " , " parameters " : { " type " : " object " , " properties " : { " field " : { " type " : " string " , " description " : " Name of the field to update (get available fields from rl_get_current_config) " } , " value " : { " description " : " New value for the field " } } , " required " : [ " field " , " value " ] } }
RL_START_TRAINING_SCHEMA = { " name " : " rl_start_training " , " description " : " Start a new RL training run with the current environment and config. Most training parameters (lora_rank, learning_rate, etc.) are fixed. Use rl_edit_config() to set group_size, batch_size, wandb_project before starting. WARNING: Training takes hours. " , " parameters " : { " type " : " object " , " properties " : { } , " required " : [ ] } }
RL_CHECK_STATUS_SCHEMA = { " name " : " rl_check_status " , " description " : " Get status and metrics for a training run. RATE LIMITED: enforces 30-minute minimum between checks for the same run. Returns WandB metrics: step, state, reward_mean, loss, percent_correct. " , " parameters " : { " type " : " object " , " properties " : { " run_id " : { " type " : " string " , " description " : " The run ID from rl_start_training() " } } , " required " : [ " run_id " ] } }
RL_STOP_TRAINING_SCHEMA = { " name " : " rl_stop_training " , " description " : " Stop a running training job. Use if metrics look bad, training is stagnant, or you want to try different settings. " , " parameters " : { " type " : " object " , " properties " : { " run_id " : { " type " : " string " , " description " : " The run ID to stop " } } , " required " : [ " run_id " ] } }
RL_GET_RESULTS_SCHEMA = { " name " : " rl_get_results " , " description " : " Get final results and metrics for a completed training run. Returns final metrics and path to trained weights. " , " parameters " : { " type " : " object " , " properties " : { " run_id " : { " type " : " string " , " description " : " The run ID to get results for " } } , " required " : [ " run_id " ] } }
RL_LIST_RUNS_SCHEMA = { " name " : " rl_list_runs " , " description " : " List all training runs (active and completed) with their status. " , " parameters " : { " type " : " object " , " properties " : { } , " required " : [ ] } }
2026-03-06 16:47:48 +08:00
RL_TEST_INFERENCE_SCHEMA = { " name " : " rl_test_inference " , " description " : " Quick inference test for any environment. Runs a few steps of inference + scoring using OpenRouter. Default: 3 steps x 16 completions = 48 rollouts per model, testing 3 models = 144 total. Tests environment loading, prompt construction, inference parsing, and verifier logic. Use BEFORE training to catch issues. " , " parameters " : { " type " : " object " , " properties " : { " num_steps " : { " type " : " integer " , " description " : " Number of steps to run (default: 3, recommended max for testing) " , " default " : 3 } , " group_size " : { " type " : " integer " , " description " : " Completions per step (default: 16, like training) " , " default " : 16 } , " models " : { " type " : " array " , " items " : { " type " : " string " } , " description " : " Optional list of OpenRouter model IDs. Default: qwen/qwen3-8b, z-ai/glm-4.7-flash, minimax/minimax-m2.5 " } } , " required " : [ ] } }
2026-02-21 20:22:33 -08:00
_rl_env = [ " TINKER_API_KEY " , " WANDB_API_KEY " ]
2026-03-15 20:21:21 -07:00
registry . register ( name = " rl_list_environments " , emoji = " 🧪 " , toolset = " rl " , schema = RL_LIST_ENVIRONMENTS_SCHEMA ,
2026-02-21 20:22:33 -08:00
handler = lambda args , * * kw : rl_list_environments ( ) , check_fn = check_rl_api_keys , requires_env = _rl_env , is_async = True )
2026-03-15 20:21:21 -07:00
registry . register ( name = " rl_select_environment " , emoji = " 🧪 " , toolset = " rl " , schema = RL_SELECT_ENVIRONMENT_SCHEMA ,
2026-02-21 20:22:33 -08:00
handler = lambda args , * * kw : rl_select_environment ( name = args . get ( " name " , " " ) ) , check_fn = check_rl_api_keys , requires_env = _rl_env , is_async = True )
2026-03-15 20:21:21 -07:00
registry . register ( name = " rl_get_current_config " , emoji = " 🧪 " , toolset = " rl " , schema = RL_GET_CURRENT_CONFIG_SCHEMA ,
2026-02-21 20:22:33 -08:00
handler = lambda args , * * kw : rl_get_current_config ( ) , check_fn = check_rl_api_keys , requires_env = _rl_env , is_async = True )
2026-03-15 20:21:21 -07:00
registry . register ( name = " rl_edit_config " , emoji = " 🧪 " , toolset = " rl " , schema = RL_EDIT_CONFIG_SCHEMA ,
2026-02-21 20:22:33 -08:00
handler = lambda args , * * kw : rl_edit_config ( field = args . get ( " field " , " " ) , value = args . get ( " value " ) ) , check_fn = check_rl_api_keys , requires_env = _rl_env , is_async = True )
2026-03-15 20:21:21 -07:00
registry . register ( name = " rl_start_training " , emoji = " 🧪 " , toolset = " rl " , schema = RL_START_TRAINING_SCHEMA ,
2026-02-21 20:22:33 -08:00
handler = lambda args , * * kw : rl_start_training ( ) , check_fn = check_rl_api_keys , requires_env = _rl_env , is_async = True )
2026-03-15 20:21:21 -07:00
registry . register ( name = " rl_check_status " , emoji = " 🧪 " , toolset = " rl " , schema = RL_CHECK_STATUS_SCHEMA ,
2026-02-21 20:22:33 -08:00
handler = lambda args , * * kw : rl_check_status ( run_id = args . get ( " run_id " , " " ) ) , check_fn = check_rl_api_keys , requires_env = _rl_env , is_async = True )
2026-03-15 20:21:21 -07:00
registry . register ( name = " rl_stop_training " , emoji = " 🧪 " , toolset = " rl " , schema = RL_STOP_TRAINING_SCHEMA ,
2026-02-21 20:22:33 -08:00
handler = lambda args , * * kw : rl_stop_training ( run_id = args . get ( " run_id " , " " ) ) , check_fn = check_rl_api_keys , requires_env = _rl_env , is_async = True )
2026-03-15 20:21:21 -07:00
registry . register ( name = " rl_get_results " , emoji = " 🧪 " , toolset = " rl " , schema = RL_GET_RESULTS_SCHEMA ,
2026-02-21 20:22:33 -08:00
handler = lambda args , * * kw : rl_get_results ( run_id = args . get ( " run_id " , " " ) ) , check_fn = check_rl_api_keys , requires_env = _rl_env , is_async = True )
2026-03-15 20:21:21 -07:00
registry . register ( name = " rl_list_runs " , emoji = " 🧪 " , toolset = " rl " , schema = RL_LIST_RUNS_SCHEMA ,
2026-02-21 20:22:33 -08:00
handler = lambda args , * * kw : rl_list_runs ( ) , check_fn = check_rl_api_keys , requires_env = _rl_env , is_async = True )
2026-03-15 20:21:21 -07:00
registry . register ( name = " rl_test_inference " , emoji = " 🧪 " , toolset = " rl " , schema = RL_TEST_INFERENCE_SCHEMA ,
2026-02-21 20:22:33 -08:00
handler = lambda args , * * kw : rl_test_inference ( num_steps = args . get ( " num_steps " , 3 ) , group_size = args . get ( " group_size " , 16 ) , models = args . get ( " models " ) ) ,
check_fn = check_rl_api_keys , requires_env = _rl_env , is_async = True )