initial RL training tools and loop

This commit is contained in:
teknium1
2026-02-03 23:41:26 -08:00
parent 51a6b7d2b5
commit f018999da9
5 changed files with 1199 additions and 3 deletions

View File

@@ -39,6 +39,21 @@ from tools.vision_tools import vision_analyze_tool, check_vision_requirements
from tools.mixture_of_agents_tool import mixture_of_agents_tool, check_moa_requirements
from tools.image_generation_tool import image_generate_tool, check_image_generation_requirements
from tools.skills_tool import skills_categories, skills_list, skill_view, check_skills_requirements, SKILLS_TOOL_DESCRIPTION
# RL Training tools (Tinker-Atropos)
from tools.rl_training_tool import (
rl_list_environments,
rl_select_environment,
rl_get_current_config,
rl_edit_config,
rl_start_training,
rl_check_status,
rl_stop_training,
rl_get_results,
rl_test_inference,
rl_list_runs,
rl_health_check,
check_rl_api_keys,
)
# Cronjob management tools (CLI-only)
from tools.cronjob_tools import (
schedule_cronjob,
@@ -128,6 +143,19 @@ TOOLSET_REQUIREMENTS = {
"setup_url": None,
"tools": ["skills_categories", "skills_list", "skill_view"],
},
"rl": {
"name": "RL Training (Tinker-Atropos)",
"env_vars": ["TINKER_API_KEY", "WANDB_API_KEY"],
"check_fn": check_rl_api_keys,
"setup_url": "https://wandb.ai/authorize",
"tools": [
"rl_list_environments", "rl_select_environment",
"rl_get_current_config", "rl_edit_config",
"rl_start_training", "rl_check_status",
"rl_stop_training", "rl_get_results",
"rl_test_inference", "rl_list_runs",
],
},
}
@@ -471,6 +499,199 @@ def get_cronjob_tool_definitions_formatted() -> List[Dict[str, Any]]:
]]
def get_rl_tool_definitions() -> List[Dict[str, Any]]:
    """
    Get tool definitions for RL training tools in OpenAI's expected format.

    These tools enable running RL training through Tinker-Atropos.

    Returns:
        List[Dict]: List of RL tool definitions compatible with OpenAI API
    """

    def _tool(
        name: str,
        description: str,
        properties: Dict[str, Any] = None,
        required: List[str] = None,
    ) -> Dict[str, Any]:
        # Wrap a single tool in the OpenAI function-calling envelope.
        return {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": {
                    "type": "object",
                    "properties": properties if properties is not None else {},
                    "required": required if required is not None else [],
                },
            },
        }

    def _run_id(description: str) -> Dict[str, Any]:
        # Shared schema for the tools that take a single run_id argument.
        return {"run_id": {"type": "string", "description": description}}

    return [
        _tool(
            "rl_list_environments",
            "List all available RL environments. Returns environment names, paths, and descriptions. TIP: Read the file_path with file tools to understand how each environment works (verifiers, data loading, rewards).",
        ),
        _tool(
            "rl_select_environment",
            "Select an RL environment for training. Loads the environment's default configuration. After selecting, use rl_get_current_config() to see settings and rl_edit_config() to modify them.",
            properties={
                "name": {
                    "type": "string",
                    "description": "Name of the environment to select (from rl_list_environments)",
                }
            },
            required=["name"],
        ),
        _tool(
            "rl_get_current_config",
            "Get the current environment configuration. Returns only fields that can be modified: group_size, max_token_length, total_steps, steps_per_eval, use_wandb, wandb_name, max_num_workers.",
        ),
        _tool(
            "rl_edit_config",
            "Update a configuration field. Valid fields: group_size (int), max_token_length (int), total_steps (int), steps_per_eval (int), use_wandb (bool), wandb_name (str), max_num_workers (int).",
            properties={
                "field": {
                    "type": "string",
                    "description": "Name of the field to update",
                },
                # "value" is deliberately untyped: its type depends on the field.
                "value": {
                    "description": "New value for the field",
                },
            },
            required=["field", "value"],
        ),
        _tool(
            "rl_start_training",
            "Start a new RL training run. WARNING: Training can take hours. Use rl_check_status() to monitor (30-minute intervals recommended). Test with rl_test_inference() first!",
            properties={
                "wandb_project": {
                    "type": "string",
                    "description": "WandB project name for logging",
                    "default": "rl-training",
                },
                "lora_rank": {
                    "type": "integer",
                    "description": "LoRA rank for training",
                    "default": 32,
                },
                "learning_rate": {
                    "type": "number",
                    "description": "Learning rate",
                    "default": 4e-5,
                },
            },
        ),
        _tool(
            "rl_check_status",
            "Get status and metrics for a training run. RATE LIMITED: enforces 30-minute minimum between checks for the same run. Returns WandB metrics: step, state, reward_mean, loss, percent_correct.",
            properties=_run_id("The run ID from rl_start_training()"),
            required=["run_id"],
        ),
        _tool(
            "rl_stop_training",
            "Stop a running training job. Use if metrics look bad, training is stagnant, or you want to try different settings.",
            properties=_run_id("The run ID to stop"),
            required=["run_id"],
        ),
        _tool(
            "rl_get_results",
            "Get final results and metrics for a completed training run. Returns final metrics and path to trained weights.",
            properties=_run_id("The run ID to get results for"),
            required=["run_id"],
        ),
        _tool(
            "rl_test_inference",
            "Test inference + verifier on sample prompts WITHOUT full training. Use to validate environments before committing to long training runs. Tests data loading, inference, and verifier logic.",
            properties={
                "prompts": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "List of test prompts to run through the environment",
                },
                "max_tokens": {
                    "type": "integer",
                    "description": "Maximum tokens to generate per prompt",
                    "default": 256,
                },
                "temperature": {
                    "type": "number",
                    "description": "Sampling temperature",
                    "default": 1.0,
                },
            },
            required=["prompts"],
        ),
        _tool(
            "rl_list_runs",
            "List all training runs (active and completed) with their status.",
        ),
    ]
def get_all_tool_names() -> List[str]:
"""
Get the names of all available tools across all toolsets.
@@ -519,6 +740,16 @@ def get_all_tool_names() -> List[str]:
"schedule_cronjob", "list_cronjobs", "remove_cronjob"
])
# RL Training tools
if check_rl_api_keys():
tool_names.extend([
"rl_list_environments", "rl_select_environment",
"rl_get_current_config", "rl_edit_config",
"rl_start_training", "rl_check_status",
"rl_stop_training", "rl_get_results",
"rl_test_inference", "rl_list_runs"
])
return tool_names
@@ -557,7 +788,18 @@ def get_toolset_for_tool(tool_name: str) -> str:
# Cronjob management tools
"schedule_cronjob": "cronjob_tools",
"list_cronjobs": "cronjob_tools",
"remove_cronjob": "cronjob_tools"
"remove_cronjob": "cronjob_tools",
# RL Training tools
"rl_list_environments": "rl_tools",
"rl_select_environment": "rl_tools",
"rl_get_current_config": "rl_tools",
"rl_edit_config": "rl_tools",
"rl_start_training": "rl_tools",
"rl_check_status": "rl_tools",
"rl_stop_training": "rl_tools",
"rl_get_results": "rl_tools",
"rl_test_inference": "rl_tools",
"rl_list_runs": "rl_tools",
}
return toolset_mapping.get(tool_name, "unknown")
@@ -635,6 +877,11 @@ def get_tool_definitions(
for tool in get_cronjob_tool_definitions_formatted():
all_available_tools_map[tool["function"]["name"]] = tool
# RL Training tools
if check_rl_api_keys():
for tool in get_rl_tool_definitions():
all_available_tools_map[tool["function"]["name"]] = tool
# Determine which tools to include based on toolsets
tools_to_include = set()
@@ -663,7 +910,14 @@ def get_tool_definitions(
"browser_press", "browser_close", "browser_get_images",
"browser_vision"
],
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"]
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"],
"rl_tools": [
"rl_list_environments", "rl_select_environment",
"rl_get_current_config", "rl_edit_config",
"rl_start_training", "rl_check_status",
"rl_stop_training", "rl_get_results",
"rl_test_inference", "rl_list_runs"
]
}
legacy_tools = legacy_map.get(toolset_name, [])
tools_to_include.update(legacy_tools)
@@ -708,7 +962,14 @@ def get_tool_definitions(
"browser_press", "browser_close", "browser_get_images",
"browser_vision"
],
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"]
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"],
"rl_tools": [
"rl_list_environments", "rl_select_environment",
"rl_get_current_config", "rl_edit_config",
"rl_start_training", "rl_check_status",
"rl_stop_training", "rl_get_results",
"rl_test_inference", "rl_list_runs"
]
}
legacy_tools = legacy_map.get(toolset_name, [])
tools_to_include.difference_update(legacy_tools)
@@ -1018,6 +1279,89 @@ def handle_cronjob_function_call(
return json.dumps({"error": f"Unknown cronjob function: {function_name}"}, ensure_ascii=False)
def handle_rl_function_call(
    function_name: str,
    function_args: Dict[str, Any]
) -> str:
    """
    Handle function calls for RL training tools.

    These tools communicate with the RL API server to manage training runs.
    Every RL tool is an async coroutine, so this synchronous entry point runs
    the selected tool to completion on a fresh event loop.

    Args:
        function_name (str): Name of the RL function to call
        function_args (Dict): Arguments for the function

    Returns:
        str: Function result as JSON string
    """
    import asyncio

    # Dispatch table: tool name -> factory that builds the coroutine from the
    # argument dict. Lambdas defer coroutine creation until after lookup, so
    # no "coroutine was never awaited" warnings are produced for unused tools.
    # Defaults mirror the tool schemas in get_rl_tool_definitions().
    dispatch = {
        "rl_list_environments": lambda args: rl_list_environments(),
        "rl_select_environment": lambda args: rl_select_environment(
            name=args.get("name", "")
        ),
        "rl_get_current_config": lambda args: rl_get_current_config(),
        "rl_edit_config": lambda args: rl_edit_config(
            field=args.get("field", ""),
            value=args.get("value"),
        ),
        "rl_start_training": lambda args: rl_start_training(
            wandb_project=args.get("wandb_project", "rl-training"),
            lora_rank=args.get("lora_rank", 32),
            learning_rate=args.get("learning_rate", 4e-5),
        ),
        "rl_check_status": lambda args: rl_check_status(
            run_id=args.get("run_id", "")
        ),
        "rl_stop_training": lambda args: rl_stop_training(
            run_id=args.get("run_id", "")
        ),
        "rl_get_results": lambda args: rl_get_results(
            run_id=args.get("run_id", "")
        ),
        "rl_test_inference": lambda args: rl_test_inference(
            prompts=args.get("prompts", []),
            max_tokens=args.get("max_tokens", 256),
            temperature=args.get("temperature", 1.0),
        ),
        "rl_list_runs": lambda args: rl_list_runs(),
    }

    factory = dispatch.get(function_name)
    if factory is None:
        return json.dumps({"error": f"Unknown RL function: {function_name}"}, ensure_ascii=False)

    # asyncio.get_event_loop() is deprecated outside a running loop (3.10+)
    # and unreliable on 3.12+. asyncio.run() creates and tears down a private
    # loop per call, which is the supported pattern for a sync entry point.
    return asyncio.run(factory(function_args))
def handle_function_call(
function_name: str,
function_args: Dict[str, Any],
@@ -1081,6 +1425,16 @@ def handle_function_call(
elif function_name in ["schedule_cronjob", "list_cronjobs", "remove_cronjob"]:
return handle_cronjob_function_call(function_name, function_args, task_id)
# Route RL training tools
elif function_name in [
"rl_list_environments", "rl_select_environment",
"rl_get_current_config", "rl_edit_config",
"rl_start_training", "rl_check_status",
"rl_stop_training", "rl_get_results",
"rl_test_inference", "rl_list_runs"
]:
return handle_rl_function_call(function_name, function_args)
else:
error_msg = f"Unknown function: {function_name}"
print(f"{error_msg}")

363
rl_cli.py Normal file
View File

@@ -0,0 +1,363 @@
#!/usr/bin/env python3
"""
RL Training CLI Runner
Dedicated CLI runner for RL training workflows with:
- Extended timeouts for long-running training
- RL-focused system prompts
- Full toolset including RL training tools
- Special handling for 30-minute check intervals
Usage:
python rl_cli.py "Train a model on GSM8k for math reasoning"
python rl_cli.py --interactive
python rl_cli.py --list-environments
Environment Variables:
TINKER_API_KEY: API key for Tinker service (required)
WANDB_API_KEY: API key for WandB metrics (required)
RL_API_URL: URL of RL API server (default: http://localhost:8080)
OPENROUTER_API_KEY: API key for OpenRouter (required for agent)
"""
import asyncio
import os
import sys
from pathlib import Path
import fire
# Load environment variables from .env file
from dotenv import load_dotenv
env_path = Path(__file__).parent / '.env'
if env_path.exists():
load_dotenv(dotenv_path=env_path)
print(f"✅ Loaded environment variables from {env_path}")
# Import agent and tools
from run_agent import AIAgent
from model_tools import get_tool_definitions, check_toolset_requirements
from tools.rl_training_tool import check_rl_api_keys, get_missing_keys, rl_health_check
# ============================================================================
# RL-Specific Configuration
# ============================================================================
# Extended timeouts for long-running RL operations
RL_MAX_ITERATIONS = 200 # Allow many more iterations for long workflows
# RL-focused system prompt
RL_SYSTEM_PROMPT = """You are an automated post-training engineer specializing in reinforcement learning for language models.
## Your Capabilities
You have access to RL training tools for running reinforcement learning on models through Tinker-Atropos:
1. **DISCOVER**: Use `rl_list_environments` to see available RL environments
2. **INSPECT**: Read environment files to understand how they work (verifiers, data loading, rewards)
3. **INSPECT DATA**: Use terminal to explore HuggingFace datasets and understand their format
4. **CREATE**: Copy existing environments as templates, modify for your needs
5. **CONFIGURE**: Use `rl_select_environment` and `rl_edit_config` to set up training
6. **TEST**: Always use `rl_test_inference` before full training to validate your setup
7. **TRAIN**: Use `rl_start_training` to begin, `rl_check_status` to monitor
8. **EVALUATE**: Use `rl_get_results` and analyze WandB metrics to assess performance
## Environment Files
Environment files are located in: `tinker-atropos/tinker_atropos/environments/`
Study existing environments to learn patterns. Look for:
- `load_dataset()` calls - how data is loaded
- `score_answer()` / `score()` - verification logic
- `get_next_item()` - prompt formatting
- `system_prompt` - instruction format
- `config_init()` - default configuration
## Creating New Environments
To create a new environment:
1. Read an existing environment file (e.g., gsm8k_tinker.py)
2. Use terminal to explore the target dataset format
3. Copy the environment file as a template
4. Modify the dataset loading, prompt formatting, and verifier logic
5. Test with `rl_test_inference` before training
## Important Guidelines
- **Always test before training**: Training runs take hours - verify everything works first
- **Monitor metrics**: Check WandB for reward/mean and percent_correct
- **Status check intervals**: Wait at least 30 minutes between status checks
- **Early stopping**: Stop training early if metrics look bad or stagnant
- **Iterate quickly**: Start with small total_steps to validate, then scale up
## Available Toolsets
You have access to:
- **RL tools**: Environment discovery, config management, training, testing
- **Terminal**: Run commands, inspect files, explore datasets
- **Web**: Search for information, documentation, papers
- **File tools**: Read and modify code files
When asked to train a model, follow this workflow:
1. List available environments
2. Select and configure the appropriate environment
3. Test with sample prompts
4. Start training with conservative settings
5. Monitor progress and adjust as needed
"""
# Toolsets to enable for RL workflows
RL_TOOLSETS = ["base", "terminal", "web", "rl"]
# ============================================================================
# Helper Functions
# ============================================================================
def check_requirements():
    """Verify that every API key the RL CLI depends on is configured.

    Prints a human-readable summary of anything missing.

    Returns:
        bool: True when all requirements are satisfied, False otherwise.
    """
    problems = []

    # The agent itself cannot talk to a model without an OpenRouter key.
    if not os.getenv("OPENROUTER_API_KEY"):
        problems.append("OPENROUTER_API_KEY not set - required for agent")

    # RL-specific keys (Tinker/WandB) are validated by the tool module.
    missing_rl_keys = get_missing_keys()
    if missing_rl_keys:
        problems.append(f"Missing RL API keys: {', '.join(missing_rl_keys)}")

    if not problems:
        return True

    print("❌ Missing requirements:")
    for problem in problems:
        print(f" - {problem}")
    print("\nPlease set these environment variables in your .env file or shell.")
    return False
async def check_rl_server():
    """Probe the RL API server's health endpoint.

    Returns:
        tuple: (True, parsed health payload dict) when the server responds
        cleanly, or (False, error message string) on any failure.
    """
    import json
    try:
        raw = await rl_health_check()
        payload = json.loads(raw)
    except Exception as exc:
        return False, str(exc)
    if "error" in payload:
        return False, payload["error"]
    return True, payload
def list_environments_sync():
    """Fetch the environment list from the RL API server, blocking until done.

    Synchronous convenience wrapper around the async rl_list_environments
    tool, for use from plain CLI code paths.

    Returns:
        The parsed JSON payload returned by the server.
    """
    import json
    from tools.rl_training_tool import rl_list_environments

    async def _fetch():
        return json.loads(await rl_list_environments())

    return asyncio.run(_fetch())
# ============================================================================
# Main CLI
# ============================================================================
def main(
    task: str = None,
    model: str = "anthropic/claude-sonnet-4-20250514",
    api_key: str = None,
    base_url: str = "https://openrouter.ai/api/v1",
    max_iterations: int = RL_MAX_ITERATIONS,
    interactive: bool = False,
    list_environments: bool = False,
    check_server: bool = False,
    verbose: bool = False,
    save_trajectories: bool = True,
):
    """
    RL Training CLI - Dedicated runner for RL training workflows.

    Args:
        task: The training task/goal (e.g., "Train a model on GSM8k for math")
        model: Model to use for the agent (default: claude-sonnet-4)
        api_key: OpenRouter API key (uses OPENROUTER_API_KEY env var if not provided)
        base_url: API base URL (default: OpenRouter)
        max_iterations: Maximum agent iterations (default: 200 for long workflows)
        interactive: Run in interactive mode (multiple conversations)
        list_environments: Just list available RL environments and exit
        check_server: Check if RL API server is running and exit
        verbose: Enable verbose logging
        save_trajectories: Save conversation trajectories (default: True for RL)

    Examples:
        # Train on a specific environment
        python rl_cli.py "Train a model on GSM8k math problems"
        # Interactive mode
        python rl_cli.py --interactive
        # List available environments
        python rl_cli.py --list-environments
        # Check server status
        python rl_cli.py --check-server
    """
    print("🎯 RL Training Agent")
    print("=" * 60)

    # --check-server: probe the health endpoint, report, and exit without
    # touching any other requirement checks.
    if check_server:
        print("\n🔍 Checking RL API server...")
        ok, result = asyncio.run(check_rl_server())
        if ok:
            print("✅ RL API server is running")
            print(f" Environments discovered: {result.get('environments_discovered', 'unknown')}")
            print(f" Current environment: {result.get('current_environment', 'none')}")
            print(f" Active runs: {result.get('active_runs', 0)}")
        else:
            print(f"❌ RL API server not accessible: {result}")
            print("\nTo start the server:")
            print(" cd tinker-atropos && uvicorn rl_api_server:app --port 8080")
        return

    # --list-environments: print the discovered environments and exit.
    # Any failure here is reported with a hint to start the server.
    if list_environments:
        print("\n📋 Available RL Environments:")
        print("-" * 40)
        try:
            data = list_environments_sync()
            if "error" in data:
                print(f"❌ Error: {data['error']}")
                return
            envs = data.get("environments", [])
            if not envs:
                print("No environments found.")
                print("\nMake sure the RL API server is running:")
                print(" cd tinker-atropos && uvicorn rl_api_server:app --port 8080")
                return
            for env in envs:
                print(f"\n 📦 {env['name']}")
                print(f" Class: {env['class_name']}")
                print(f" Path: {env['file_path']}")
                if env.get('description'):
                    # Truncate long descriptions to keep the listing compact.
                    desc = env['description'][:100] + "..." if len(env.get('description', '')) > 100 else env.get('description', '')
                    print(f" Description: {desc}")
            print(f"\n📊 Total: {len(envs)} environments")
            print("\nUse `rl_select_environment(name)` to select an environment for training.")
        except Exception as e:
            print(f"❌ Error listing environments: {e}")
            print("\nMake sure the RL API server is running:")
            print(" cd tinker-atropos && uvicorn rl_api_server:app --port 8080")
        return

    # API-key requirements are only enforced for the agent-running modes;
    # the informational flags above work without them.
    if not check_requirements():
        sys.exit(1)

    # A task (or interactive mode) is mandatory beyond this point.
    if not task and not interactive:
        print("\n⚠️ No task provided. Use --interactive for interactive mode or provide a task.")
        print("\nExamples:")
        print(' python rl_cli.py "Train a model on GSM8k math problems"')
        print(' python rl_cli.py "Create an RL environment for code generation"')
        print(' python rl_cli.py --interactive')
        return

    # CLI flag takes precedence over the environment variable.
    api_key = api_key or os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        print("❌ No API key provided. Set OPENROUTER_API_KEY or pass --api-key")
        sys.exit(1)

    print(f"\n🤖 Model: {model}")
    print(f"🔧 Max iterations: {max_iterations}")
    print(f"📁 Toolsets: {', '.join(RL_TOOLSETS)}")
    print("=" * 60)

    # Build the agent with the RL-specific system prompt and toolsets.
    agent = AIAgent(
        base_url=base_url,
        api_key=api_key,
        model=model,
        max_iterations=max_iterations,
        enabled_toolsets=RL_TOOLSETS,
        save_trajectories=save_trajectories,
        verbose_logging=verbose,
        quiet_mode=False,
        ephemeral_system_prompt=RL_SYSTEM_PROMPT,
    )

    if interactive:
        # Interactive mode: REPL over multiple conversations until the user
        # quits. 'status' is handled locally without invoking the agent.
        print("\n🔄 Interactive RL Training Mode")
        print("Type 'quit' or 'exit' to end the session.")
        print("Type 'status' to check active training runs.")
        print("-" * 40)
        while True:
            try:
                user_input = input("\n🎯 RL Task> ").strip()
                if not user_input:
                    continue
                if user_input.lower() in ('quit', 'exit', 'q'):
                    print("\n👋 Goodbye!")
                    break
                if user_input.lower() == 'status':
                    # Quick status check: query the run list directly rather
                    # than spending an agent turn on it.
                    from tools.rl_training_tool import rl_list_runs
                    import json
                    result = asyncio.run(rl_list_runs())
                    runs = json.loads(result)
                    if isinstance(runs, list) and runs:
                        print("\n📊 Active Runs:")
                        for run in runs:
                            # NOTE(review): assumes each run dict carries
                            # run_id/environment/status keys — confirm against
                            # the /runs endpoint payload.
                            print(f" - {run['run_id']}: {run['environment']} ({run['status']})")
                    else:
                        print("\nNo active runs.")
                    continue
                # Run the agent for this one task.
                print("\n" + "=" * 60)
                response = agent.run_conversation(user_input)
                print("\n" + "=" * 60)
            except KeyboardInterrupt:
                # Ctrl-C ends the whole interactive session.
                print("\n\n👋 Interrupted. Goodbye!")
                break
            except Exception as e:
                # Keep the REPL alive on per-task errors; show a traceback
                # only when --verbose was requested.
                print(f"\n❌ Error: {e}")
                if verbose:
                    import traceback
                    traceback.print_exc()
    else:
        # Single task mode: one conversation, then exit. Errors exit nonzero.
        print(f"\n📝 Task: {task}")
        print("-" * 40)
        try:
            response = agent.run_conversation(task)
            print("\n" + "=" * 60)
            print("✅ Task completed")
        except KeyboardInterrupt:
            print("\n\n⚠️ Interrupted by user")
        except Exception as e:
            print(f"\n❌ Error: {e}")
            if verbose:
                import traceback
                traceback.print_exc()
            sys.exit(1)
# Entry point: python-fire exposes every keyword argument of main() as a CLI
# flag (e.g. --interactive, --list-environments, --check-server).
if __name__ == "__main__":
    fire.Fire(main)

View File

@@ -95,6 +95,23 @@ from .cronjob_tools import (
REMOVE_CRONJOB_SCHEMA
)
# RL Training tools (Tinker-Atropos)
from .rl_training_tool import (
rl_list_environments,
rl_select_environment,
rl_get_current_config,
rl_edit_config,
rl_start_training,
rl_check_status,
rl_stop_training,
rl_get_results,
rl_test_inference,
rl_list_runs,
rl_health_check,
check_rl_api_keys,
get_missing_keys,
)
__all__ = [
# Web tools
'web_search_tool',
@@ -152,5 +169,19 @@ __all__ = [
'SCHEDULE_CRONJOB_SCHEMA',
'LIST_CRONJOBS_SCHEMA',
'REMOVE_CRONJOB_SCHEMA',
# RL Training tools
'rl_list_environments',
'rl_select_environment',
'rl_get_current_config',
'rl_edit_config',
'rl_start_training',
'rl_check_status',
'rl_stop_training',
'rl_get_results',
'rl_test_inference',
'rl_list_runs',
'rl_health_check',
'check_rl_api_keys',
'get_missing_keys',
]

436
tools/rl_training_tool.py Normal file
View File

@@ -0,0 +1,436 @@
#!/usr/bin/env python3
"""
RL Training Tools Module
This module provides tools for running RL training through Tinker-Atropos.
Communicates with the RL API server (rl_api_server.py) to manage:
- Environment discovery and selection
- Configuration management
- Training run lifecycle
- WandB metrics monitoring
- Inference-only testing
Required environment variables:
- TINKER_API_KEY: API key for Tinker service
- WANDB_API_KEY: API key for Weights & Biases metrics
Optional environment variables:
- RL_API_URL: URL of the RL API server (default: http://localhost:8080)
- WANDB_ENTITY: WandB entity/team name
- WANDB_PROJECT: Default WandB project name
Usage:
from tools.rl_training_tool import (
rl_list_environments,
rl_select_environment,
rl_get_current_config,
rl_edit_config,
rl_start_training,
rl_check_status,
rl_stop_training,
rl_get_results,
rl_test_inference,
)
"""
import json
import os
import time
from typing import Any, Dict, List, Optional
import aiohttp
# ============================================================================
# Configuration
# ============================================================================
# Default RL API server URL (can be overridden via environment variable)
RL_API_URL = os.getenv("RL_API_URL", "http://localhost:8080")
# Rate limiting for status checks (30 minutes in seconds)
MIN_STATUS_CHECK_INTERVAL = 30 * 60
_last_status_check: Dict[str, float] = {}
# ============================================================================
# Helper Functions
# ============================================================================
async def _make_request(
    method: str,
    endpoint: str,
    data: Optional[Dict] = None,
    timeout: int = 30,
) -> Dict[str, Any]:
    """Make an HTTP request to the RL API server.

    Args:
        method: HTTP method name (e.g. "GET", "POST").
        endpoint: Path appended to RL_API_URL (e.g. "/environments").
        data: Optional JSON body; sent for non-GET requests only.
        timeout: Total request timeout in seconds.

    Returns:
        The parsed JSON response on HTTP 200; otherwise a dict with an
        "error" key. Never raises — connection failures and unexpected
        errors are folded into the returned dict so callers always get a
        JSON-serializable result.
    """
    url = f"{RL_API_URL}{endpoint}"
    # aiohttp deprecates passing a bare number as `timeout`; ClientTimeout
    # is the supported form and applies to the whole request.
    client_timeout = aiohttp.ClientTimeout(total=timeout)
    try:
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            # session.request() unifies the previously duplicated GET/POST
            # branches and fixes the silent `None` return that any other
            # method name used to fall through to.
            async with session.request(
                method, url, json=data if method != "GET" else None
            ) as response:
                if response.status == 200:
                    return await response.json()
                error_text = await response.text()
                return {"error": f"HTTP {response.status}: {error_text}"}
    except aiohttp.ClientConnectorError:
        return {
            "error": f"Cannot connect to RL API server at {RL_API_URL}. "
            "Make sure the server is running: "
            "cd tinker-atropos && uvicorn rl_api_server:app --port 8080"
        }
    except Exception as e:
        return {"error": f"Request failed: {str(e)}"}
# ============================================================================
# Environment Discovery Tools
# ============================================================================
async def rl_list_environments() -> str:
    """List all available RL environments.

    Queries the RL API server's /environments endpoint, which scans
    tinker-atropos/tinker_atropos/environments/ for classes deriving from
    BaseEnv. Each entry carries: name, class_name, file_path, and an
    optional description.

    TIP for creating or modifying environments: read existing environment
    files (dataset loading, verifiers, reward structure), inspect the
    HuggingFace datasets they use, copy one as a template, and validate with
    rl_test_inference before a full training run.

    Returns:
        JSON string with the environment list (plus usage tips) or an
        error message when the server is unreachable.
    """
    listing = await _make_request("GET", "/environments")
    if "error" in listing:
        return json.dumps(listing, indent=2)

    # Wrap the raw listing with a count and agent-facing guidance.
    enriched = {
        "environments": listing,
        "count": len(listing),
        "tips": [
            "Use rl_select_environment(name) to select an environment",
            "Read the file_path with file tools to understand how each environment works",
            "Look for load_dataset(), score_answer(), get_next_item() methods",
        ]
    }
    return json.dumps(enriched, indent=2)
async def rl_select_environment(name: str) -> str:
    """
    Select an RL environment for training.

    This loads the environment's default configuration into the config state.
    After selecting, use rl_get_current_config() to see the configuration
    and rl_edit_config() to modify specific fields.

    Args:
        name: Name of the environment to select (from rl_list_environments)

    Returns:
        JSON string with selection result, file path, and current config

    TIP: Read the returned file_path to understand how the environment works:
    - How it loads data (load_dataset calls)
    - How it verifies answers (score_answer method)
    - What prompts it uses (system_prompt, get_next_item)
    """
    from urllib.parse import quote

    # Percent-encode the (model-supplied) name so characters like "/", "?"
    # or spaces cannot escape the /environments/{name}/select path segment.
    result = await _make_request("POST", f"/environments/{quote(name, safe='')}/select")
    return json.dumps(result, indent=2)
# ============================================================================
# Configuration Tools
# ============================================================================
async def rl_get_current_config() -> str:
    """Return the editable portion of the selected environment's config.

    Only user-tunable fields are exposed; system-managed fields such as
    tokenizer_name or rollout_server_url are fixed server-side.

    Editable fields:
    - group_size: Rollouts per prompt (4-16 typical)
    - max_token_length: Max generation tokens (2048-16384)
    - total_steps: Training steps (50-2000)
    - steps_per_eval: Steps between evaluations
    - use_wandb: Enable WandB logging
    - wandb_name: WandB run name prefix
    - max_num_workers: Concurrent workers (-1 = auto)

    Returns:
        JSON string mapping each editable field to its current value.
    """
    config = await _make_request("GET", "/config")
    return json.dumps(config, indent=2)
async def rl_edit_config(field: str, value: Any) -> str:
    """Set one editable configuration field on the server.

    The server validates both the field name and the value's type; fields
    outside the editable set are rejected.

    Args:
        field: Editable field name — one of group_size (int),
            max_token_length (int), total_steps (int), steps_per_eval (int),
            use_wandb (bool), wandb_name (str), max_num_workers (int).
        value: New value for the field.

    Returns:
        JSON string with the updated configuration, or an error message.
    """
    payload = {"field": field, "value": value}
    response = await _make_request("POST", "/config", payload)
    return json.dumps(response, indent=2)
# ============================================================================
# Training Management Tools
# ============================================================================
async def rl_start_training(
    wandb_project: str = "rl-training",
    lora_rank: int = 32,
    learning_rate: float = 4e-5,
) -> str:
    """Kick off an RL training run with the selected environment and config.

    An environment must already be chosen via rl_select_environment().
    WARNING: runs can take hours to days — poll with rl_check_status()
    (at most every 30 minutes), and validate the setup first:
    1. Run rl_test_inference() to confirm the environment works
    2. Start with a small total_steps to validate, then scale up
    3. Watch WandB for reward/mean and percent_correct

    Args:
        wandb_project: WandB project used for metric logging.
        lora_rank: LoRA rank for the trained adapter (default: 32).
        learning_rate: Optimizer learning rate (default: 4e-5).

    Returns:
        JSON string containing the new run_id and its initial status.
    """
    request_body = {
        "wandb_project": wandb_project,
        "lora_rank": lora_rank,
        "learning_rate": learning_rate,
    }
    response = await _make_request("POST", "/runs", request_body)
    return json.dumps(response, indent=2)
async def rl_check_status(run_id: str) -> str:
    """
    Fetch the status and latest metrics of a training run.

    RATE LIMITED: consecutive checks of the same run_id must be at least
    MIN_STATUS_CHECK_INTERVAL seconds apart; earlier calls return a
    rate-limit notice without hitting the server.

    Metrics reported by the server (when WandB is available):
        - step: Current training step
        - state: Run state (running, finished, crashed)
        - reward_mean: Average reward across batches
        - loss: Training loss
        - percent_correct: Training accuracy
        - eval_percent_correct: Evaluation accuracy

    Args:
        run_id: The run ID returned by rl_start_training()

    Returns:
        JSON string with run status and metrics, or a rate-limit message.
    """
    global _last_status_check

    now = time.time()
    last_checked = _last_status_check.get(run_id)
    if last_checked is not None:
        remaining = MIN_STATUS_CHECK_INTERVAL - (now - last_checked)
        # Still inside the cooldown window: report when the next check opens.
        if remaining > 0:
            return json.dumps({
                "rate_limited": True,
                "run_id": run_id,
                "message": f"Rate limited. Next check available in {remaining/60:.0f} minutes.",
                "next_check_in_seconds": remaining,
            }, indent=2)

    # Record the timestamp only for checks that actually reach the server.
    _last_status_check[run_id] = now
    response = await _make_request("GET", f"/runs/{run_id}")
    return json.dumps(response, indent=2)
async def rl_stop_training(run_id: str) -> str:
    """
    Stop a running training job.

    Reasons to stop early:
        - Metrics look bad or training is stagnant
        - You want to try different settings
        - You need to free up resources

    Args:
        run_id: The run ID to stop

    Returns:
        JSON string confirming the stop request.
    """
    response = await _make_request("POST", f"/runs/{run_id}/stop")
    return json.dumps(response, indent=2)
async def rl_get_results(run_id: str) -> str:
    """
    Retrieve the final results of a completed training run.

    The server response includes:
        - Final metrics (reward, loss, accuracy)
        - WandB run URL for detailed analysis
        - Path to trained weights (tinker:// URL)

    Args:
        run_id: The run ID to get results for

    Returns:
        JSON string with final metrics and the weights path.
    """
    response = await _make_request("GET", f"/runs/{run_id}/metrics")
    return json.dumps(response, indent=2)
# ============================================================================
# Inference Testing Tools
# ============================================================================
async def rl_test_inference(
    prompts: List[str],
    max_tokens: int = 256,
    temperature: float = 1.0,
) -> str:
    """
    Run inference plus the verifier on sample prompts WITHOUT training.

    Useful for validating an environment before committing to a long run.
    Exercises:
        - Data loading and formatting
        - Model inference through Tinker
        - Verifier/reward function logic

    NOTE: The RL API server must still be running with Tinker access
    for the Sample() method.

    Args:
        prompts: Test prompts to feed through the environment
        max_tokens: Maximum tokens to generate per prompt
        temperature: Sampling temperature

    Returns:
        JSON string with responses and verifier scores per prompt.

    TIP: Include prompts with known correct/incorrect answers to verify
    the reward function is working correctly.
    """
    payload = {
        "prompts": prompts,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    response = await _make_request("POST", "/test/inference", payload)
    return json.dumps(response, indent=2)
# ============================================================================
# Utility Tools
# ============================================================================
async def rl_list_runs() -> str:
    """
    List every training run known to the server, active or completed.

    Returns:
        JSON string with the runs and their status.
    """
    response = await _make_request("GET", "/runs")
    return json.dumps(response, indent=2)
# ============================================================================
# Requirements Check
# ============================================================================
def check_rl_api_keys() -> bool:
    """
    Report whether the required API keys are present in the environment.

    Required:
        - TINKER_API_KEY: For Tinker training service
        - WANDB_API_KEY: For metrics logging and fetching

    Returns:
        bool: True when both keys are set and non-empty, False otherwise.
    """
    required = ("TINKER_API_KEY", "WANDB_API_KEY")
    return all(os.getenv(name) for name in required)
def get_missing_keys() -> List[str]:
    """
    Name the required API keys that are absent from the environment.

    Returns:
        List of missing key names (empty when everything is set).
    """
    required = ("TINKER_API_KEY", "WANDB_API_KEY")
    return [name for name in required if not os.getenv(name)]
# ============================================================================
# Debug/Status
# ============================================================================
async def rl_health_check() -> str:
    """
    Probe the RL API server to confirm it is running and reachable.

    Returns:
        JSON string with the server's health status.
    """
    response = await _make_request("GET", "/health")
    return json.dumps(response, indent=2)

View File

@@ -90,6 +90,18 @@ TOOLSETS = {
"includes": []
},
"rl": {
"description": "RL training tools for running reinforcement learning on Tinker-Atropos",
"tools": [
"rl_list_environments", "rl_select_environment",
"rl_get_current_config", "rl_edit_config",
"rl_start_training", "rl_check_status",
"rl_stop_training", "rl_get_results",
"rl_test_inference", "rl_list_runs"
],
"includes": []
},
# Scenario-specific toolsets
"debugging": {