From f018999da97862bfc919a8eaddfec57ce0cdea18 Mon Sep 17 00:00:00 2001 From: teknium1 Date: Tue, 3 Feb 2026 23:41:26 -0800 Subject: [PATCH] initial RL training tools and loop --- model_tools.py | 360 ++++++++++++++++++++++++++++++- rl_cli.py | 363 +++++++++++++++++++++++++++++++ tools/__init__.py | 31 +++ tools/rl_training_tool.py | 436 ++++++++++++++++++++++++++++++++++++++ toolsets.py | 12 ++ 5 files changed, 1199 insertions(+), 3 deletions(-) create mode 100644 rl_cli.py create mode 100644 tools/rl_training_tool.py diff --git a/model_tools.py b/model_tools.py index e78323f60..ebabaf564 100644 --- a/model_tools.py +++ b/model_tools.py @@ -39,6 +39,21 @@ from tools.vision_tools import vision_analyze_tool, check_vision_requirements from tools.mixture_of_agents_tool import mixture_of_agents_tool, check_moa_requirements from tools.image_generation_tool import image_generate_tool, check_image_generation_requirements from tools.skills_tool import skills_categories, skills_list, skill_view, check_skills_requirements, SKILLS_TOOL_DESCRIPTION +# RL Training tools (Tinker-Atropos) +from tools.rl_training_tool import ( + rl_list_environments, + rl_select_environment, + rl_get_current_config, + rl_edit_config, + rl_start_training, + rl_check_status, + rl_stop_training, + rl_get_results, + rl_test_inference, + rl_list_runs, + rl_health_check, + check_rl_api_keys, +) # Cronjob management tools (CLI-only) from tools.cronjob_tools import ( schedule_cronjob, @@ -128,6 +143,19 @@ TOOLSET_REQUIREMENTS = { "setup_url": None, "tools": ["skills_categories", "skills_list", "skill_view"], }, + "rl": { + "name": "RL Training (Tinker-Atropos)", + "env_vars": ["TINKER_API_KEY", "WANDB_API_KEY"], + "check_fn": check_rl_api_keys, + "setup_url": "https://wandb.ai/authorize", + "tools": [ + "rl_list_environments", "rl_select_environment", + "rl_get_current_config", "rl_edit_config", + "rl_start_training", "rl_check_status", + "rl_stop_training", "rl_get_results", + "rl_test_inference", 
"rl_list_runs", + ], + }, } @@ -471,6 +499,199 @@ def get_cronjob_tool_definitions_formatted() -> List[Dict[str, Any]]: ]] +def get_rl_tool_definitions() -> List[Dict[str, Any]]: + """ + Get tool definitions for RL training tools in OpenAI's expected format. + + These tools enable running RL training through Tinker-Atropos. + + Returns: + List[Dict]: List of RL tool definitions compatible with OpenAI API + """ + return [ + { + "type": "function", + "function": { + "name": "rl_list_environments", + "description": "List all available RL environments. Returns environment names, paths, and descriptions. TIP: Read the file_path with file tools to understand how each environment works (verifiers, data loading, rewards).", + "parameters": { + "type": "object", + "properties": {}, + "required": [] + } + } + }, + { + "type": "function", + "function": { + "name": "rl_select_environment", + "description": "Select an RL environment for training. Loads the environment's default configuration. After selecting, use rl_get_current_config() to see settings and rl_edit_config() to modify them.", + "parameters": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Name of the environment to select (from rl_list_environments)" + } + }, + "required": ["name"] + } + } + }, + { + "type": "function", + "function": { + "name": "rl_get_current_config", + "description": "Get the current environment configuration. Returns only fields that can be modified: group_size, max_token_length, total_steps, steps_per_eval, use_wandb, wandb_name, max_num_workers.", + "parameters": { + "type": "object", + "properties": {}, + "required": [] + } + } + }, + { + "type": "function", + "function": { + "name": "rl_edit_config", + "description": "Update a configuration field. 
Valid fields: group_size (int), max_token_length (int), total_steps (int), steps_per_eval (int), use_wandb (bool), wandb_name (str), max_num_workers (int).", + "parameters": { + "type": "object", + "properties": { + "field": { + "type": "string", + "description": "Name of the field to update" + }, + "value": { + "description": "New value for the field" + } + }, + "required": ["field", "value"] + } + } + }, + { + "type": "function", + "function": { + "name": "rl_start_training", + "description": "Start a new RL training run. WARNING: Training can take hours. Use rl_check_status() to monitor (30-minute intervals recommended). Test with rl_test_inference() first!", + "parameters": { + "type": "object", + "properties": { + "wandb_project": { + "type": "string", + "description": "WandB project name for logging", + "default": "rl-training" + }, + "lora_rank": { + "type": "integer", + "description": "LoRA rank for training", + "default": 32 + }, + "learning_rate": { + "type": "number", + "description": "Learning rate", + "default": 4e-5 + } + }, + "required": [] + } + } + }, + { + "type": "function", + "function": { + "name": "rl_check_status", + "description": "Get status and metrics for a training run. RATE LIMITED: enforces 30-minute minimum between checks for the same run. Returns WandB metrics: step, state, reward_mean, loss, percent_correct.", + "parameters": { + "type": "object", + "properties": { + "run_id": { + "type": "string", + "description": "The run ID from rl_start_training()" + } + }, + "required": ["run_id"] + } + } + }, + { + "type": "function", + "function": { + "name": "rl_stop_training", + "description": "Stop a running training job. 
Use if metrics look bad, training is stagnant, or you want to try different settings.", + "parameters": { + "type": "object", + "properties": { + "run_id": { + "type": "string", + "description": "The run ID to stop" + } + }, + "required": ["run_id"] + } + } + }, + { + "type": "function", + "function": { + "name": "rl_get_results", + "description": "Get final results and metrics for a completed training run. Returns final metrics and path to trained weights.", + "parameters": { + "type": "object", + "properties": { + "run_id": { + "type": "string", + "description": "The run ID to get results for" + } + }, + "required": ["run_id"] + } + } + }, + { + "type": "function", + "function": { + "name": "rl_test_inference", + "description": "Test inference + verifier on sample prompts WITHOUT full training. Use to validate environments before committing to long training runs. Tests data loading, inference, and verifier logic.", + "parameters": { + "type": "object", + "properties": { + "prompts": { + "type": "array", + "items": {"type": "string"}, + "description": "List of test prompts to run through the environment" + }, + "max_tokens": { + "type": "integer", + "description": "Maximum tokens to generate per prompt", + "default": 256 + }, + "temperature": { + "type": "number", + "description": "Sampling temperature", + "default": 1.0 + } + }, + "required": ["prompts"] + } + } + }, + { + "type": "function", + "function": { + "name": "rl_list_runs", + "description": "List all training runs (active and completed) with their status.", + "parameters": { + "type": "object", + "properties": {}, + "required": [] + } + } + } + ] + + def get_all_tool_names() -> List[str]: """ Get the names of all available tools across all toolsets. 
@@ -519,6 +740,16 @@ def get_all_tool_names() -> List[str]: "schedule_cronjob", "list_cronjobs", "remove_cronjob" ]) + # RL Training tools + if check_rl_api_keys(): + tool_names.extend([ + "rl_list_environments", "rl_select_environment", + "rl_get_current_config", "rl_edit_config", + "rl_start_training", "rl_check_status", + "rl_stop_training", "rl_get_results", + "rl_test_inference", "rl_list_runs" + ]) + return tool_names @@ -557,7 +788,18 @@ def get_toolset_for_tool(tool_name: str) -> str: # Cronjob management tools "schedule_cronjob": "cronjob_tools", "list_cronjobs": "cronjob_tools", - "remove_cronjob": "cronjob_tools" + "remove_cronjob": "cronjob_tools", + # RL Training tools + "rl_list_environments": "rl_tools", + "rl_select_environment": "rl_tools", + "rl_get_current_config": "rl_tools", + "rl_edit_config": "rl_tools", + "rl_start_training": "rl_tools", + "rl_check_status": "rl_tools", + "rl_stop_training": "rl_tools", + "rl_get_results": "rl_tools", + "rl_test_inference": "rl_tools", + "rl_list_runs": "rl_tools", } return toolset_mapping.get(tool_name, "unknown") @@ -635,6 +877,11 @@ def get_tool_definitions( for tool in get_cronjob_tool_definitions_formatted(): all_available_tools_map[tool["function"]["name"]] = tool + # RL Training tools + if check_rl_api_keys(): + for tool in get_rl_tool_definitions(): + all_available_tools_map[tool["function"]["name"]] = tool + # Determine which tools to include based on toolsets tools_to_include = set() @@ -663,7 +910,14 @@ def get_tool_definitions( "browser_press", "browser_close", "browser_get_images", "browser_vision" ], - "cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"] + "cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"], + "rl_tools": [ + "rl_list_environments", "rl_select_environment", + "rl_get_current_config", "rl_edit_config", + "rl_start_training", "rl_check_status", + "rl_stop_training", "rl_get_results", + "rl_test_inference", "rl_list_runs" + ] } legacy_tools 
= legacy_map.get(toolset_name, []) tools_to_include.update(legacy_tools) @@ -708,7 +962,14 @@ def get_tool_definitions( "browser_press", "browser_close", "browser_get_images", "browser_vision" ], - "cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"] + "cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"], + "rl_tools": [ + "rl_list_environments", "rl_select_environment", + "rl_get_current_config", "rl_edit_config", + "rl_start_training", "rl_check_status", + "rl_stop_training", "rl_get_results", + "rl_test_inference", "rl_list_runs" + ] } legacy_tools = legacy_map.get(toolset_name, []) tools_to_include.difference_update(legacy_tools) @@ -1018,6 +1279,89 @@ def handle_cronjob_function_call( return json.dumps({"error": f"Unknown cronjob function: {function_name}"}, ensure_ascii=False) +def handle_rl_function_call( + function_name: str, + function_args: Dict[str, Any] +) -> str: + """ + Handle function calls for RL training tools. + + These tools communicate with the RL API server to manage training runs. 
+ + Args: + function_name (str): Name of the RL function to call + function_args (Dict): Arguments for the function + + Returns: + str: Function result as JSON string + """ + # Run async functions in event loop + import asyncio + + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + if function_name == "rl_list_environments": + return loop.run_until_complete(rl_list_environments()) + + elif function_name == "rl_select_environment": + return loop.run_until_complete( + rl_select_environment(name=function_args.get("name", "")) + ) + + elif function_name == "rl_get_current_config": + return loop.run_until_complete(rl_get_current_config()) + + elif function_name == "rl_edit_config": + return loop.run_until_complete( + rl_edit_config( + field=function_args.get("field", ""), + value=function_args.get("value") + ) + ) + + elif function_name == "rl_start_training": + return loop.run_until_complete( + rl_start_training( + wandb_project=function_args.get("wandb_project", "rl-training"), + lora_rank=function_args.get("lora_rank", 32), + learning_rate=function_args.get("learning_rate", 4e-5) + ) + ) + + elif function_name == "rl_check_status": + return loop.run_until_complete( + rl_check_status(run_id=function_args.get("run_id", "")) + ) + + elif function_name == "rl_stop_training": + return loop.run_until_complete( + rl_stop_training(run_id=function_args.get("run_id", "")) + ) + + elif function_name == "rl_get_results": + return loop.run_until_complete( + rl_get_results(run_id=function_args.get("run_id", "")) + ) + + elif function_name == "rl_test_inference": + return loop.run_until_complete( + rl_test_inference( + prompts=function_args.get("prompts", []), + max_tokens=function_args.get("max_tokens", 256), + temperature=function_args.get("temperature", 1.0) + ) + ) + + elif function_name == "rl_list_runs": + return loop.run_until_complete(rl_list_runs()) + + return json.dumps({"error": f"Unknown RL 
function: {function_name}"}, ensure_ascii=False) + + def handle_function_call( function_name: str, function_args: Dict[str, Any], @@ -1081,6 +1425,16 @@ def handle_function_call( elif function_name in ["schedule_cronjob", "list_cronjobs", "remove_cronjob"]: return handle_cronjob_function_call(function_name, function_args, task_id) + # Route RL training tools + elif function_name in [ + "rl_list_environments", "rl_select_environment", + "rl_get_current_config", "rl_edit_config", + "rl_start_training", "rl_check_status", + "rl_stop_training", "rl_get_results", + "rl_test_inference", "rl_list_runs" + ]: + return handle_rl_function_call(function_name, function_args) + else: error_msg = f"Unknown function: {function_name}" print(f"āŒ {error_msg}") diff --git a/rl_cli.py b/rl_cli.py new file mode 100644 index 000000000..cd76c91d6 --- /dev/null +++ b/rl_cli.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python3 +""" +RL Training CLI Runner + +Dedicated CLI runner for RL training workflows with: +- Extended timeouts for long-running training +- RL-focused system prompts +- Full toolset including RL training tools +- Special handling for 30-minute check intervals + +Usage: + python rl_cli.py "Train a model on GSM8k for math reasoning" + python rl_cli.py --interactive + python rl_cli.py --list-environments + +Environment Variables: + TINKER_API_KEY: API key for Tinker service (required) + WANDB_API_KEY: API key for WandB metrics (required) + RL_API_URL: URL of RL API server (default: http://localhost:8080) + OPENROUTER_API_KEY: API key for OpenRouter (required for agent) +""" + +import asyncio +import os +import sys +from pathlib import Path + +import fire + +# Load environment variables from .env file +from dotenv import load_dotenv + +env_path = Path(__file__).parent / '.env' +if env_path.exists(): + load_dotenv(dotenv_path=env_path) + print(f"āœ… Loaded environment variables from {env_path}") + +# Import agent and tools +from run_agent import AIAgent +from model_tools import 
get_tool_definitions, check_toolset_requirements +from tools.rl_training_tool import check_rl_api_keys, get_missing_keys, rl_health_check + + +# ============================================================================ +# RL-Specific Configuration +# ============================================================================ + +# Extended timeouts for long-running RL operations +RL_MAX_ITERATIONS = 200 # Allow many more iterations for long workflows + +# RL-focused system prompt +RL_SYSTEM_PROMPT = """You are an automated post-training engineer specializing in reinforcement learning for language models. + +## Your Capabilities + +You have access to RL training tools for running reinforcement learning on models through Tinker-Atropos: + +1. **DISCOVER**: Use `rl_list_environments` to see available RL environments +2. **INSPECT**: Read environment files to understand how they work (verifiers, data loading, rewards) +3. **INSPECT DATA**: Use terminal to explore HuggingFace datasets and understand their format +4. **CREATE**: Copy existing environments as templates, modify for your needs +5. **CONFIGURE**: Use `rl_select_environment` and `rl_edit_config` to set up training +6. **TEST**: Always use `rl_test_inference` before full training to validate your setup +7. **TRAIN**: Use `rl_start_training` to begin, `rl_check_status` to monitor +8. **EVALUATE**: Use `rl_get_results` and analyze WandB metrics to assess performance + +## Environment Files + +Environment files are located in: `tinker-atropos/tinker_atropos/environments/` + +Study existing environments to learn patterns. Look for: +- `load_dataset()` calls - how data is loaded +- `score_answer()` / `score()` - verification logic +- `get_next_item()` - prompt formatting +- `system_prompt` - instruction format +- `config_init()` - default configuration + +## Creating New Environments + +To create a new environment: +1. Read an existing environment file (e.g., gsm8k_tinker.py) +2. 
Use terminal to explore the target dataset format +3. Copy the environment file as a template +4. Modify the dataset loading, prompt formatting, and verifier logic +5. Test with `rl_test_inference` before training + +## Important Guidelines + +- **Always test before training**: Training runs take hours - verify everything works first +- **Monitor metrics**: Check WandB for reward/mean and percent_correct +- **Status check intervals**: Wait at least 30 minutes between status checks +- **Early stopping**: Stop training early if metrics look bad or stagnant +- **Iterate quickly**: Start with small total_steps to validate, then scale up + +## Available Toolsets + +You have access to: +- **RL tools**: Environment discovery, config management, training, testing +- **Terminal**: Run commands, inspect files, explore datasets +- **Web**: Search for information, documentation, papers +- **File tools**: Read and modify code files + +When asked to train a model, follow this workflow: +1. List available environments +2. Select and configure the appropriate environment +3. Test with sample prompts +4. Start training with conservative settings +5. 
Monitor progress and adjust as needed +""" + +# Toolsets to enable for RL workflows +RL_TOOLSETS = ["base", "terminal", "web", "rl"] + + +# ============================================================================ +# Helper Functions +# ============================================================================ + +def check_requirements(): + """Check that all required environment variables and services are available.""" + errors = [] + + # Check API keys + if not os.getenv("OPENROUTER_API_KEY"): + errors.append("OPENROUTER_API_KEY not set - required for agent") + + missing_rl_keys = get_missing_keys() + if missing_rl_keys: + errors.append(f"Missing RL API keys: {', '.join(missing_rl_keys)}") + + if errors: + print("āŒ Missing requirements:") + for error in errors: + print(f" - {error}") + print("\nPlease set these environment variables in your .env file or shell.") + return False + + return True + + +async def check_rl_server(): + """Check if the RL API server is running.""" + try: + result = await rl_health_check() + import json + data = json.loads(result) + if "error" in data: + return False, data["error"] + return True, data + except Exception as e: + return False, str(e) + + +def list_environments_sync(): + """List available environments (synchronous wrapper).""" + from tools.rl_training_tool import rl_list_environments + import json + + async def _list(): + result = await rl_list_environments() + return json.loads(result) + + return asyncio.run(_list()) + + +# ============================================================================ +# Main CLI +# ============================================================================ + +def main( + task: str = None, + model: str = "anthropic/claude-sonnet-4-20250514", + api_key: str = None, + base_url: str = "https://openrouter.ai/api/v1", + max_iterations: int = RL_MAX_ITERATIONS, + interactive: bool = False, + list_environments: bool = False, + check_server: bool = False, + verbose: bool = False, + 
save_trajectories: bool = True, +): + """ + RL Training CLI - Dedicated runner for RL training workflows. + + Args: + task: The training task/goal (e.g., "Train a model on GSM8k for math") + model: Model to use for the agent (default: claude-sonnet-4) + api_key: OpenRouter API key (uses OPENROUTER_API_KEY env var if not provided) + base_url: API base URL (default: OpenRouter) + max_iterations: Maximum agent iterations (default: 200 for long workflows) + interactive: Run in interactive mode (multiple conversations) + list_environments: Just list available RL environments and exit + check_server: Check if RL API server is running and exit + verbose: Enable verbose logging + save_trajectories: Save conversation trajectories (default: True for RL) + + Examples: + # Train on a specific environment + python rl_cli.py "Train a model on GSM8k math problems" + + # Interactive mode + python rl_cli.py --interactive + + # List available environments + python rl_cli.py --list-environments + + # Check server status + python rl_cli.py --check-server + """ + print("šŸŽÆ RL Training Agent") + print("=" * 60) + + # Handle server check + if check_server: + print("\nšŸ” Checking RL API server...") + ok, result = asyncio.run(check_rl_server()) + if ok: + print("āœ… RL API server is running") + print(f" Environments discovered: {result.get('environments_discovered', 'unknown')}") + print(f" Current environment: {result.get('current_environment', 'none')}") + print(f" Active runs: {result.get('active_runs', 0)}") + else: + print(f"āŒ RL API server not accessible: {result}") + print("\nTo start the server:") + print(" cd tinker-atropos && uvicorn rl_api_server:app --port 8080") + return + + # Handle environment listing + if list_environments: + print("\nšŸ“‹ Available RL Environments:") + print("-" * 40) + try: + data = list_environments_sync() + if "error" in data: + print(f"āŒ Error: {data['error']}") + return + + envs = data.get("environments", []) + if not envs: + print("No 
environments found.") + print("\nMake sure the RL API server is running:") + print(" cd tinker-atropos && uvicorn rl_api_server:app --port 8080") + return + + for env in envs: + print(f"\n šŸ“¦ {env['name']}") + print(f" Class: {env['class_name']}") + print(f" Path: {env['file_path']}") + if env.get('description'): + desc = env['description'][:100] + "..." if len(env.get('description', '')) > 100 else env.get('description', '') + print(f" Description: {desc}") + + print(f"\nšŸ“Š Total: {len(envs)} environments") + print("\nUse `rl_select_environment(name)` to select an environment for training.") + except Exception as e: + print(f"āŒ Error listing environments: {e}") + print("\nMake sure the RL API server is running:") + print(" cd tinker-atropos && uvicorn rl_api_server:app --port 8080") + return + + # Check requirements + if not check_requirements(): + sys.exit(1) + + # Set default task if none provided + if not task and not interactive: + print("\nāš ļø No task provided. Use --interactive for interactive mode or provide a task.") + print("\nExamples:") + print(' python rl_cli.py "Train a model on GSM8k math problems"') + print(' python rl_cli.py "Create an RL environment for code generation"') + print(' python rl_cli.py --interactive') + return + + # Get API key + api_key = api_key or os.getenv("OPENROUTER_API_KEY") + if not api_key: + print("āŒ No API key provided. 
Set OPENROUTER_API_KEY or pass --api-key") + sys.exit(1) + + print(f"\nšŸ¤– Model: {model}") + print(f"šŸ”§ Max iterations: {max_iterations}") + print(f"šŸ“ Toolsets: {', '.join(RL_TOOLSETS)}") + print("=" * 60) + + # Create agent with RL configuration + agent = AIAgent( + base_url=base_url, + api_key=api_key, + model=model, + max_iterations=max_iterations, + enabled_toolsets=RL_TOOLSETS, + save_trajectories=save_trajectories, + verbose_logging=verbose, + quiet_mode=False, + ephemeral_system_prompt=RL_SYSTEM_PROMPT, + ) + + if interactive: + # Interactive mode - multiple conversations + print("\nšŸ”„ Interactive RL Training Mode") + print("Type 'quit' or 'exit' to end the session.") + print("Type 'status' to check active training runs.") + print("-" * 40) + + while True: + try: + user_input = input("\nšŸŽÆ RL Task> ").strip() + + if not user_input: + continue + + if user_input.lower() in ('quit', 'exit', 'q'): + print("\nšŸ‘‹ Goodbye!") + break + + if user_input.lower() == 'status': + # Quick status check + from tools.rl_training_tool import rl_list_runs + import json + result = asyncio.run(rl_list_runs()) + runs = json.loads(result) + if isinstance(runs, list) and runs: + print("\nšŸ“Š Active Runs:") + for run in runs: + print(f" - {run['run_id']}: {run['environment']} ({run['status']})") + else: + print("\nNo active runs.") + continue + + # Run the agent + print("\n" + "=" * 60) + response = agent.run_conversation(user_input) + print("\n" + "=" * 60) + + except KeyboardInterrupt: + print("\n\nšŸ‘‹ Interrupted. 
Goodbye!") + break + except Exception as e: + print(f"\nāŒ Error: {e}") + if verbose: + import traceback + traceback.print_exc() + else: + # Single task mode + print(f"\nšŸ“ Task: {task}") + print("-" * 40) + + try: + response = agent.run_conversation(task) + print("\n" + "=" * 60) + print("āœ… Task completed") + except KeyboardInterrupt: + print("\n\nāš ļø Interrupted by user") + except Exception as e: + print(f"\nāŒ Error: {e}") + if verbose: + import traceback + traceback.print_exc() + sys.exit(1) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/tools/__init__.py b/tools/__init__.py index 3365dab44..dd8bb4dac 100644 --- a/tools/__init__.py +++ b/tools/__init__.py @@ -95,6 +95,23 @@ from .cronjob_tools import ( REMOVE_CRONJOB_SCHEMA ) +# RL Training tools (Tinker-Atropos) +from .rl_training_tool import ( + rl_list_environments, + rl_select_environment, + rl_get_current_config, + rl_edit_config, + rl_start_training, + rl_check_status, + rl_stop_training, + rl_get_results, + rl_test_inference, + rl_list_runs, + rl_health_check, + check_rl_api_keys, + get_missing_keys, +) + __all__ = [ # Web tools 'web_search_tool', @@ -152,5 +169,19 @@ __all__ = [ 'SCHEDULE_CRONJOB_SCHEMA', 'LIST_CRONJOBS_SCHEMA', 'REMOVE_CRONJOB_SCHEMA', + # RL Training tools + 'rl_list_environments', + 'rl_select_environment', + 'rl_get_current_config', + 'rl_edit_config', + 'rl_start_training', + 'rl_check_status', + 'rl_stop_training', + 'rl_get_results', + 'rl_test_inference', + 'rl_list_runs', + 'rl_health_check', + 'check_rl_api_keys', + 'get_missing_keys', ] diff --git a/tools/rl_training_tool.py b/tools/rl_training_tool.py new file mode 100644 index 000000000..1b7401c1c --- /dev/null +++ b/tools/rl_training_tool.py @@ -0,0 +1,436 @@ +#!/usr/bin/env python3 +""" +RL Training Tools Module + +This module provides tools for running RL training through Tinker-Atropos. 
#!/usr/bin/env python3
"""
RL Training Tools Module

This module provides tools for running RL training through Tinker-Atropos.
Communicates with the RL API server (rl_api_server.py) to manage:
- Environment discovery and selection
- Configuration management
- Training run lifecycle
- WandB metrics monitoring
- Inference-only testing

Required environment variables:
- TINKER_API_KEY: API key for Tinker service
- WANDB_API_KEY: API key for Weights & Biases metrics

Optional environment variables:
- RL_API_URL: URL of the RL API server (default: http://localhost:8080)
- WANDB_ENTITY: WandB entity/team name
- WANDB_PROJECT: Default WandB project name
"""

import json
import os
import time
from typing import Any, Dict, List, Optional

import aiohttp

# ============================================================================
# Configuration
# ============================================================================

# Default RL API server URL (can be overridden via environment variable).
RL_API_URL = os.getenv("RL_API_URL", "http://localhost:8080")

# Minimum interval between status checks for the same run (30 minutes).
MIN_STATUS_CHECK_INTERVAL = 30 * 60
# run_id -> unix timestamp of the last *successful* status check.
_last_status_check: Dict[str, float] = {}


# ============================================================================
# Helper Functions
# ============================================================================

async def _make_request(
    method: str,
    endpoint: str,
    data: Optional[Dict] = None,
    timeout: int = 30,
) -> Dict[str, Any]:
    """Make an HTTP request to the RL API server.

    Args:
        method: "GET" or "POST".
        endpoint: Path appended to RL_API_URL (e.g. "/environments").
        data: JSON payload for POST requests (ignored for GET).
        timeout: Total request timeout in seconds.

    Returns:
        The decoded JSON body on HTTP 200, otherwise a dict with an
        "error" key describing what went wrong. Never raises for
        connection/transport failures.
    """
    url = f"{RL_API_URL}{endpoint}"
    # aiohttp expects a ClientTimeout object; passing a bare int is deprecated.
    client_timeout = aiohttp.ClientTimeout(total=timeout)

    try:
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            if method == "GET":
                request_ctx = session.get(url)
            elif method == "POST":
                request_ctx = session.post(url, json=data)
            else:
                # Previously fell through and returned None implicitly;
                # surface the programming error explicitly instead.
                return {"error": f"Unsupported HTTP method: {method}"}

            async with request_ctx as response:
                if response.status == 200:
                    return await response.json()
                error_text = await response.text()
                return {"error": f"HTTP {response.status}: {error_text}"}
    except aiohttp.ClientConnectorError:
        return {
            "error": f"Cannot connect to RL API server at {RL_API_URL}. "
            "Make sure the server is running: "
            "cd tinker-atropos && uvicorn rl_api_server:app --port 8080"
        }
    except Exception as e:
        return {"error": f"Request failed: {str(e)}"}


# ============================================================================
# Environment Discovery Tools
# ============================================================================

async def rl_list_environments() -> str:
    """
    List all available RL environments.

    Queries the RL API server, which scans
    tinker-atropos/tinker_atropos/environments/ for Python files containing
    classes that inherit from BaseEnv.

    Returns information about each environment including:
    - name: Environment identifier
    - class_name: Python class name
    - file_path: Path to the environment file
    - description: Brief description if available

    TIP: To create or modify RL environments:
    1. Use terminal/file tools to inspect existing environments
    2. Study how they load datasets, define verifiers, and structure rewards
    3. Inspect HuggingFace datasets to understand data formats
    4. Copy an existing environment as a template
    5. Test with rl_test_inference before running full training

    Returns:
        JSON string with list of environments or error message
    """
    result = await _make_request("GET", "/environments")

    if "error" in result:
        return json.dumps(result, indent=2)

    # Wrap the raw server payload with a count and usage tips for the agent.
    response = {
        "environments": result,
        "count": len(result),
        "tips": [
            "Use rl_select_environment(name) to select an environment",
            "Read the file_path with file tools to understand how each environment works",
            "Look for load_dataset(), score_answer(), get_next_item() methods",
        ],
    }

    return json.dumps(response, indent=2)


async def rl_select_environment(name: str) -> str:
    """
    Select an RL environment for training.

    This loads the environment's default configuration into the config state.
    After selecting, use rl_get_current_config() to see the configuration
    and rl_edit_config() to modify specific fields.

    Args:
        name: Name of the environment to select (from rl_list_environments)

    Returns:
        JSON string with selection result, file path, and current config

    TIP: Read the returned file_path to understand how the environment works:
    - How it loads data (load_dataset calls)
    - How it verifies answers (score_answer method)
    - What prompts it uses (system_prompt, get_next_item)
    """
    result = await _make_request("POST", f"/environments/{name}/select")
    return json.dumps(result, indent=2)


# ============================================================================
# Configuration Tools
# ============================================================================

async def rl_get_current_config() -> str:
    """
    Get the current environment configuration.

    Returns only the fields that are safe to modify. Other fields
    (tokenizer_name, rollout_server_url, etc.) are fixed by the system.

    Available fields:
    - group_size: Rollouts per prompt (4-16 typical)
    - max_token_length: Max generation tokens (2048-16384)
    - total_steps: Training steps (50-2000)
    - steps_per_eval: Steps between evaluations
    - use_wandb: Enable WandB logging
    - wandb_name: WandB run name prefix
    - max_num_workers: Concurrent workers (-1 = auto)

    Returns:
        JSON string with current config fields and their values
    """
    result = await _make_request("GET", "/config")
    return json.dumps(result, indent=2)


async def rl_edit_config(field: str, value: Any) -> str:
    """
    Update a configuration field.

    Only exposed fields can be modified; the server validates the field
    name and value type.

    Args:
        field: Name of the field to update (e.g., "group_size", "total_steps")
        value: New value for the field

    Valid fields:
    - group_size (int): Rollouts per prompt
    - max_token_length (int): Max generation tokens
    - total_steps (int): Training steps
    - steps_per_eval (int): Eval frequency
    - use_wandb (bool): Enable logging
    - wandb_name (str): Run name prefix
    - max_num_workers (int): Workers count

    Returns:
        JSON string with updated config or error message
    """
    result = await _make_request("POST", "/config", {"field": field, "value": value})
    return json.dumps(result, indent=2)


# ============================================================================
# Training Management Tools
# ============================================================================

async def rl_start_training(
    wandb_project: str = "rl-training",
    lora_rank: int = 32,
    learning_rate: float = 4e-5,
) -> str:
    """
    Start a new RL training run with the current environment and config.

    Requires an environment to be selected first using rl_select_environment().

    WARNING: Training runs can take hours to days. Use rl_check_status() to
    monitor progress (recommended: check every 30 minutes at most).

    Args:
        wandb_project: WandB project name for logging
        lora_rank: LoRA rank for training (default: 32)
        learning_rate: Learning rate (default: 4e-5)

    Returns:
        JSON string with run_id and initial status

    TIP: Before starting training:
    1. Test with rl_test_inference() to verify the environment works
    2. Start with fewer total_steps to validate the setup
    3. Monitor WandB metrics for reward/mean and percent_correct
    """
    result = await _make_request("POST", "/runs", {
        "wandb_project": wandb_project,
        "lora_rank": lora_rank,
        "learning_rate": learning_rate,
    })
    return json.dumps(result, indent=2)


async def rl_check_status(run_id: str) -> str:
    """
    Get status and metrics for a training run.

    RATE LIMITED: For long-running training, this function enforces a
    minimum 30-minute interval between checks for the same run_id.

    Fetches latest metrics from WandB if available:
    - step: Current training step
    - state: Run state (running, finished, crashed)
    - reward_mean: Average reward across batches
    - loss: Training loss
    - percent_correct: Training accuracy
    - eval_percent_correct: Evaluation accuracy

    Args:
        run_id: The run ID returned by rl_start_training()

    Returns:
        JSON string with run status and metrics, or rate limit message
    """
    now = time.time()
    last = _last_status_check.get(run_id)
    if last is not None:
        elapsed = now - last
        if elapsed < MIN_STATUS_CHECK_INTERVAL:
            remaining = MIN_STATUS_CHECK_INTERVAL - elapsed
            return json.dumps({
                "rate_limited": True,
                "run_id": run_id,
                "message": f"Rate limited. Next check available in {remaining/60:.0f} minutes.",
                "next_check_in_seconds": remaining,
            }, indent=2)

    result = await _make_request("GET", f"/runs/{run_id}")
    # Only consume the rate-limit window when the server actually answered;
    # a transport failure should not lock the caller out for 30 minutes.
    if "error" not in result:
        _last_status_check[run_id] = now
    return json.dumps(result, indent=2)


async def rl_stop_training(run_id: str) -> str:
    """
    Stop a running training job.

    Use this if:
    - Metrics look bad or training is stagnant
    - You want to try different settings
    - You need to free up resources

    Args:
        run_id: The run ID to stop

    Returns:
        JSON string with stop confirmation
    """
    result = await _make_request("POST", f"/runs/{run_id}/stop")
    return json.dumps(result, indent=2)
+ + Use this if: + - Metrics look bad or training is stagnant + - You want to try different settings + - You need to free up resources + + Args: + run_id: The run ID to stop + + Returns: + JSON string with stop confirmation + """ + result = await _make_request("POST", f"/runs/{run_id}/stop") + return json.dumps(result, indent=2) + + +async def rl_get_results(run_id: str) -> str: + """ + Get final results and metrics for a completed training run. + + Returns: + - Final metrics (reward, loss, accuracy) + - WandB run URL for detailed analysis + - Path to trained weights (tinker:// URL) + + Args: + run_id: The run ID to get results for + + Returns: + JSON string with final results and weights path + """ + result = await _make_request("GET", f"/runs/{run_id}/metrics") + return json.dumps(result, indent=2) + + +# ============================================================================ +# Inference Testing Tools +# ============================================================================ + +async def rl_test_inference( + prompts: List[str], + max_tokens: int = 256, + temperature: float = 1.0, +) -> str: + """ + Test inference + verifier on sample prompts WITHOUT full training. + + Use this to validate environments before committing to long training runs. + Tests: + - Data loading and formatting + - Model inference through Tinker + - Verifier/reward function logic + + NOTE: This still requires the RL API server to be running with + Tinker access for the Sample() method. + + Args: + prompts: List of test prompts to run through the environment + max_tokens: Maximum tokens to generate per prompt + temperature: Sampling temperature + + Returns: + JSON string with responses and verifier scores for each prompt + + TIP: Include prompts with known correct/incorrect answers to verify + the reward function is working correctly. 
+ """ + result = await _make_request("POST", "/test/inference", { + "prompts": prompts, + "max_tokens": max_tokens, + "temperature": temperature, + }) + return json.dumps(result, indent=2) + + +# ============================================================================ +# Utility Tools +# ============================================================================ + +async def rl_list_runs() -> str: + """ + List all training runs (active and completed). + + Returns: + JSON string with list of runs and their status + """ + result = await _make_request("GET", "/runs") + return json.dumps(result, indent=2) + + +# ============================================================================ +# Requirements Check +# ============================================================================ + +def check_rl_api_keys() -> bool: + """ + Check if required API keys are available in environment variables. + + Required: + - TINKER_API_KEY: For Tinker training service + - WANDB_API_KEY: For metrics logging and fetching + + Returns: + bool: True if all required keys are set, False otherwise + """ + tinker_key = os.getenv("TINKER_API_KEY") + wandb_key = os.getenv("WANDB_API_KEY") + + return bool(tinker_key) and bool(wandb_key) + + +def get_missing_keys() -> List[str]: + """ + Get list of missing required API keys. + + Returns: + List of missing key names + """ + missing = [] + if not os.getenv("TINKER_API_KEY"): + missing.append("TINKER_API_KEY") + if not os.getenv("WANDB_API_KEY"): + missing.append("WANDB_API_KEY") + return missing + + +# ============================================================================ +# Debug/Status +# ============================================================================ + +async def rl_health_check() -> str: + """ + Check if the RL API server is running and accessible. 
+ + Returns: + JSON string with server health status + """ + result = await _make_request("GET", "/health") + return json.dumps(result, indent=2) diff --git a/toolsets.py b/toolsets.py index 5d08731ec..e4644251c 100644 --- a/toolsets.py +++ b/toolsets.py @@ -90,6 +90,18 @@ TOOLSETS = { "includes": [] }, + "rl": { + "description": "RL training tools for running reinforcement learning on Tinker-Atropos", + "tools": [ + "rl_list_environments", "rl_select_environment", + "rl_get_current_config", "rl_edit_config", + "rl_start_training", "rl_check_status", + "rl_stop_training", "rl_get_results", + "rl_test_inference", "rl_list_runs" + ], + "includes": [] + }, + # Scenario-specific toolsets "debugging": {