initial RL training tools and loop
This commit is contained in:
360
model_tools.py
360
model_tools.py
@@ -39,6 +39,21 @@ from tools.vision_tools import vision_analyze_tool, check_vision_requirements
|
||||
from tools.mixture_of_agents_tool import mixture_of_agents_tool, check_moa_requirements
|
||||
from tools.image_generation_tool import image_generate_tool, check_image_generation_requirements
|
||||
from tools.skills_tool import skills_categories, skills_list, skill_view, check_skills_requirements, SKILLS_TOOL_DESCRIPTION
|
||||
# RL Training tools (Tinker-Atropos)
|
||||
from tools.rl_training_tool import (
|
||||
rl_list_environments,
|
||||
rl_select_environment,
|
||||
rl_get_current_config,
|
||||
rl_edit_config,
|
||||
rl_start_training,
|
||||
rl_check_status,
|
||||
rl_stop_training,
|
||||
rl_get_results,
|
||||
rl_test_inference,
|
||||
rl_list_runs,
|
||||
rl_health_check,
|
||||
check_rl_api_keys,
|
||||
)
|
||||
# Cronjob management tools (CLI-only)
|
||||
from tools.cronjob_tools import (
|
||||
schedule_cronjob,
|
||||
@@ -128,6 +143,19 @@ TOOLSET_REQUIREMENTS = {
|
||||
"setup_url": None,
|
||||
"tools": ["skills_categories", "skills_list", "skill_view"],
|
||||
},
|
||||
"rl": {
|
||||
"name": "RL Training (Tinker-Atropos)",
|
||||
"env_vars": ["TINKER_API_KEY", "WANDB_API_KEY"],
|
||||
"check_fn": check_rl_api_keys,
|
||||
"setup_url": "https://wandb.ai/authorize",
|
||||
"tools": [
|
||||
"rl_list_environments", "rl_select_environment",
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_test_inference", "rl_list_runs",
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -471,6 +499,199 @@ def get_cronjob_tool_definitions_formatted() -> List[Dict[str, Any]]:
|
||||
]]
|
||||
|
||||
|
||||
def get_rl_tool_definitions() -> List[Dict[str, Any]]:
    """
    Get tool definitions for RL training tools in OpenAI's expected format.

    These tools enable running RL training through Tinker-Atropos.

    Returns:
        List[Dict]: List of RL tool definitions compatible with OpenAI API
    """

    def _tool(name: str, description: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
        # Every RL tool shares the same OpenAI function-tool envelope;
        # only the name, description, and parameter schema differ.
        return {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": parameters,
            },
        }

    def _no_params() -> Dict[str, Any]:
        # Built fresh per tool so each definition owns its own schema dict.
        return {"type": "object", "properties": {}, "required": []}

    def _run_id_params(description: str) -> Dict[str, Any]:
        # Schema shared by the status/stop/results tools: one required run_id.
        return {
            "type": "object",
            "properties": {
                "run_id": {
                    "type": "string",
                    "description": description,
                }
            },
            "required": ["run_id"],
        }

    return [
        _tool(
            "rl_list_environments",
            "List all available RL environments. Returns environment names, paths, and descriptions. TIP: Read the file_path with file tools to understand how each environment works (verifiers, data loading, rewards).",
            _no_params(),
        ),
        _tool(
            "rl_select_environment",
            "Select an RL environment for training. Loads the environment's default configuration. After selecting, use rl_get_current_config() to see settings and rl_edit_config() to modify them.",
            {
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string",
                        "description": "Name of the environment to select (from rl_list_environments)",
                    }
                },
                "required": ["name"],
            },
        ),
        _tool(
            "rl_get_current_config",
            "Get the current environment configuration. Returns only fields that can be modified: group_size, max_token_length, total_steps, steps_per_eval, use_wandb, wandb_name, max_num_workers.",
            _no_params(),
        ),
        _tool(
            "rl_edit_config",
            "Update a configuration field. Valid fields: group_size (int), max_token_length (int), total_steps (int), steps_per_eval (int), use_wandb (bool), wandb_name (str), max_num_workers (int).",
            {
                "type": "object",
                "properties": {
                    "field": {
                        "type": "string",
                        "description": "Name of the field to update",
                    },
                    "value": {
                        # Deliberately untyped: the value type depends on the field.
                        "description": "New value for the field",
                    },
                },
                "required": ["field", "value"],
            },
        ),
        _tool(
            "rl_start_training",
            "Start a new RL training run. WARNING: Training can take hours. Use rl_check_status() to monitor (30-minute intervals recommended). Test with rl_test_inference() first!",
            {
                "type": "object",
                "properties": {
                    "wandb_project": {
                        "type": "string",
                        "description": "WandB project name for logging",
                        "default": "rl-training",
                    },
                    "lora_rank": {
                        "type": "integer",
                        "description": "LoRA rank for training",
                        "default": 32,
                    },
                    "learning_rate": {
                        "type": "number",
                        "description": "Learning rate",
                        "default": 4e-5,
                    },
                },
                "required": [],
            },
        ),
        _tool(
            "rl_check_status",
            "Get status and metrics for a training run. RATE LIMITED: enforces 30-minute minimum between checks for the same run. Returns WandB metrics: step, state, reward_mean, loss, percent_correct.",
            _run_id_params("The run ID from rl_start_training()"),
        ),
        _tool(
            "rl_stop_training",
            "Stop a running training job. Use if metrics look bad, training is stagnant, or you want to try different settings.",
            _run_id_params("The run ID to stop"),
        ),
        _tool(
            "rl_get_results",
            "Get final results and metrics for a completed training run. Returns final metrics and path to trained weights.",
            _run_id_params("The run ID to get results for"),
        ),
        _tool(
            "rl_test_inference",
            "Test inference + verifier on sample prompts WITHOUT full training. Use to validate environments before committing to long training runs. Tests data loading, inference, and verifier logic.",
            {
                "type": "object",
                "properties": {
                    "prompts": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "List of test prompts to run through the environment",
                    },
                    "max_tokens": {
                        "type": "integer",
                        "description": "Maximum tokens to generate per prompt",
                        "default": 256,
                    },
                    "temperature": {
                        "type": "number",
                        "description": "Sampling temperature",
                        "default": 1.0,
                    },
                },
                "required": ["prompts"],
            },
        ),
        _tool(
            "rl_list_runs",
            "List all training runs (active and completed) with their status.",
            _no_params(),
        ),
    ]
|
||||
|
||||
|
||||
def get_all_tool_names() -> List[str]:
|
||||
"""
|
||||
Get the names of all available tools across all toolsets.
|
||||
@@ -519,6 +740,16 @@ def get_all_tool_names() -> List[str]:
|
||||
"schedule_cronjob", "list_cronjobs", "remove_cronjob"
|
||||
])
|
||||
|
||||
# RL Training tools
|
||||
if check_rl_api_keys():
|
||||
tool_names.extend([
|
||||
"rl_list_environments", "rl_select_environment",
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_test_inference", "rl_list_runs"
|
||||
])
|
||||
|
||||
return tool_names
|
||||
|
||||
|
||||
@@ -557,7 +788,18 @@ def get_toolset_for_tool(tool_name: str) -> str:
|
||||
# Cronjob management tools
|
||||
"schedule_cronjob": "cronjob_tools",
|
||||
"list_cronjobs": "cronjob_tools",
|
||||
"remove_cronjob": "cronjob_tools"
|
||||
"remove_cronjob": "cronjob_tools",
|
||||
# RL Training tools
|
||||
"rl_list_environments": "rl_tools",
|
||||
"rl_select_environment": "rl_tools",
|
||||
"rl_get_current_config": "rl_tools",
|
||||
"rl_edit_config": "rl_tools",
|
||||
"rl_start_training": "rl_tools",
|
||||
"rl_check_status": "rl_tools",
|
||||
"rl_stop_training": "rl_tools",
|
||||
"rl_get_results": "rl_tools",
|
||||
"rl_test_inference": "rl_tools",
|
||||
"rl_list_runs": "rl_tools",
|
||||
}
|
||||
|
||||
return toolset_mapping.get(tool_name, "unknown")
|
||||
@@ -635,6 +877,11 @@ def get_tool_definitions(
|
||||
for tool in get_cronjob_tool_definitions_formatted():
|
||||
all_available_tools_map[tool["function"]["name"]] = tool
|
||||
|
||||
# RL Training tools
|
||||
if check_rl_api_keys():
|
||||
for tool in get_rl_tool_definitions():
|
||||
all_available_tools_map[tool["function"]["name"]] = tool
|
||||
|
||||
# Determine which tools to include based on toolsets
|
||||
tools_to_include = set()
|
||||
|
||||
@@ -663,7 +910,14 @@ def get_tool_definitions(
|
||||
"browser_press", "browser_close", "browser_get_images",
|
||||
"browser_vision"
|
||||
],
|
||||
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"]
|
||||
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"],
|
||||
"rl_tools": [
|
||||
"rl_list_environments", "rl_select_environment",
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_test_inference", "rl_list_runs"
|
||||
]
|
||||
}
|
||||
legacy_tools = legacy_map.get(toolset_name, [])
|
||||
tools_to_include.update(legacy_tools)
|
||||
@@ -708,7 +962,14 @@ def get_tool_definitions(
|
||||
"browser_press", "browser_close", "browser_get_images",
|
||||
"browser_vision"
|
||||
],
|
||||
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"]
|
||||
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"],
|
||||
"rl_tools": [
|
||||
"rl_list_environments", "rl_select_environment",
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_test_inference", "rl_list_runs"
|
||||
]
|
||||
}
|
||||
legacy_tools = legacy_map.get(toolset_name, [])
|
||||
tools_to_include.difference_update(legacy_tools)
|
||||
@@ -1018,6 +1279,89 @@ def handle_cronjob_function_call(
|
||||
return json.dumps({"error": f"Unknown cronjob function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_rl_function_call(
    function_name: str,
    function_args: Dict[str, Any]
) -> str:
    """
    Handle function calls for RL training tools.

    These tools communicate with the RL API server to manage training runs.

    Args:
        function_name (str): Name of the RL function to call
        function_args (Dict): Arguments for the function

    Returns:
        str: Function result as JSON string. An unknown function name
             returns a JSON object with an "error" key.

    Raises:
        RuntimeError: If invoked from inside an already-running event loop
            (the RL tools are coroutines and must be driven to completion
            synchronously here).
    """
    import asyncio

    def _run(coro):
        # asyncio.get_event_loop() is deprecated outside a running loop
        # (Python 3.10+) and the old new_event_loop/set_event_loop fallback
        # leaked a loop installed as current. asyncio.run() creates and
        # tears down a fresh loop per call, which is safe for these short
        # HTTP-client coroutines.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro)
        # Mirrors the RuntimeError run_until_complete() raised when called
        # from within a running loop, with a clearer message.
        coro.close()  # avoid "coroutine was never awaited" warning
        raise RuntimeError(
            "handle_rl_function_call cannot be used inside a running event loop"
        )

    if function_name == "rl_list_environments":
        return _run(rl_list_environments())

    elif function_name == "rl_select_environment":
        return _run(rl_select_environment(name=function_args.get("name", "")))

    elif function_name == "rl_get_current_config":
        return _run(rl_get_current_config())

    elif function_name == "rl_edit_config":
        return _run(
            rl_edit_config(
                field=function_args.get("field", ""),
                value=function_args.get("value")
            )
        )

    elif function_name == "rl_start_training":
        # Defaults mirror the tool schema advertised to the model.
        return _run(
            rl_start_training(
                wandb_project=function_args.get("wandb_project", "rl-training"),
                lora_rank=function_args.get("lora_rank", 32),
                learning_rate=function_args.get("learning_rate", 4e-5)
            )
        )

    elif function_name == "rl_check_status":
        return _run(rl_check_status(run_id=function_args.get("run_id", "")))

    elif function_name == "rl_stop_training":
        return _run(rl_stop_training(run_id=function_args.get("run_id", "")))

    elif function_name == "rl_get_results":
        return _run(rl_get_results(run_id=function_args.get("run_id", "")))

    elif function_name == "rl_test_inference":
        return _run(
            rl_test_inference(
                prompts=function_args.get("prompts", []),
                max_tokens=function_args.get("max_tokens", 256),
                temperature=function_args.get("temperature", 1.0)
            )
        )

    elif function_name == "rl_list_runs":
        return _run(rl_list_runs())

    return json.dumps({"error": f"Unknown RL function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_function_call(
|
||||
function_name: str,
|
||||
function_args: Dict[str, Any],
|
||||
@@ -1081,6 +1425,16 @@ def handle_function_call(
|
||||
elif function_name in ["schedule_cronjob", "list_cronjobs", "remove_cronjob"]:
|
||||
return handle_cronjob_function_call(function_name, function_args, task_id)
|
||||
|
||||
# Route RL training tools
|
||||
elif function_name in [
|
||||
"rl_list_environments", "rl_select_environment",
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_test_inference", "rl_list_runs"
|
||||
]:
|
||||
return handle_rl_function_call(function_name, function_args)
|
||||
|
||||
else:
|
||||
error_msg = f"Unknown function: {function_name}"
|
||||
print(f"❌ {error_msg}")
|
||||
|
||||
363
rl_cli.py
Normal file
363
rl_cli.py
Normal file
@@ -0,0 +1,363 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
RL Training CLI Runner
|
||||
|
||||
Dedicated CLI runner for RL training workflows with:
|
||||
- Extended timeouts for long-running training
|
||||
- RL-focused system prompts
|
||||
- Full toolset including RL training tools
|
||||
- Special handling for 30-minute check intervals
|
||||
|
||||
Usage:
|
||||
python rl_cli.py "Train a model on GSM8k for math reasoning"
|
||||
python rl_cli.py --interactive
|
||||
python rl_cli.py --list-environments
|
||||
|
||||
Environment Variables:
|
||||
TINKER_API_KEY: API key for Tinker service (required)
|
||||
WANDB_API_KEY: API key for WandB metrics (required)
|
||||
RL_API_URL: URL of RL API server (default: http://localhost:8080)
|
||||
OPENROUTER_API_KEY: API key for OpenRouter (required for agent)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
|
||||
# Load environment variables from .env file
|
||||
from dotenv import load_dotenv
|
||||
|
||||
env_path = Path(__file__).parent / '.env'
|
||||
if env_path.exists():
|
||||
load_dotenv(dotenv_path=env_path)
|
||||
print(f"✅ Loaded environment variables from {env_path}")
|
||||
|
||||
# Import agent and tools
|
||||
from run_agent import AIAgent
|
||||
from model_tools import get_tool_definitions, check_toolset_requirements
|
||||
from tools.rl_training_tool import check_rl_api_keys, get_missing_keys, rl_health_check
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# RL-Specific Configuration
|
||||
# ============================================================================
|
||||
|
||||
# Extended timeouts for long-running RL operations
|
||||
RL_MAX_ITERATIONS = 200 # Allow many more iterations for long workflows
|
||||
|
||||
# RL-focused system prompt
|
||||
RL_SYSTEM_PROMPT = """You are an automated post-training engineer specializing in reinforcement learning for language models.
|
||||
|
||||
## Your Capabilities
|
||||
|
||||
You have access to RL training tools for running reinforcement learning on models through Tinker-Atropos:
|
||||
|
||||
1. **DISCOVER**: Use `rl_list_environments` to see available RL environments
|
||||
2. **INSPECT**: Read environment files to understand how they work (verifiers, data loading, rewards)
|
||||
3. **INSPECT DATA**: Use terminal to explore HuggingFace datasets and understand their format
|
||||
4. **CREATE**: Copy existing environments as templates, modify for your needs
|
||||
5. **CONFIGURE**: Use `rl_select_environment` and `rl_edit_config` to set up training
|
||||
6. **TEST**: Always use `rl_test_inference` before full training to validate your setup
|
||||
7. **TRAIN**: Use `rl_start_training` to begin, `rl_check_status` to monitor
|
||||
8. **EVALUATE**: Use `rl_get_results` and analyze WandB metrics to assess performance
|
||||
|
||||
## Environment Files
|
||||
|
||||
Environment files are located in: `tinker-atropos/tinker_atropos/environments/`
|
||||
|
||||
Study existing environments to learn patterns. Look for:
|
||||
- `load_dataset()` calls - how data is loaded
|
||||
- `score_answer()` / `score()` - verification logic
|
||||
- `get_next_item()` - prompt formatting
|
||||
- `system_prompt` - instruction format
|
||||
- `config_init()` - default configuration
|
||||
|
||||
## Creating New Environments
|
||||
|
||||
To create a new environment:
|
||||
1. Read an existing environment file (e.g., gsm8k_tinker.py)
|
||||
2. Use terminal to explore the target dataset format
|
||||
3. Copy the environment file as a template
|
||||
4. Modify the dataset loading, prompt formatting, and verifier logic
|
||||
5. Test with `rl_test_inference` before training
|
||||
|
||||
## Important Guidelines
|
||||
|
||||
- **Always test before training**: Training runs take hours - verify everything works first
|
||||
- **Monitor metrics**: Check WandB for reward/mean and percent_correct
|
||||
- **Status check intervals**: Wait at least 30 minutes between status checks
|
||||
- **Early stopping**: Stop training early if metrics look bad or stagnant
|
||||
- **Iterate quickly**: Start with small total_steps to validate, then scale up
|
||||
|
||||
## Available Toolsets
|
||||
|
||||
You have access to:
|
||||
- **RL tools**: Environment discovery, config management, training, testing
|
||||
- **Terminal**: Run commands, inspect files, explore datasets
|
||||
- **Web**: Search for information, documentation, papers
|
||||
- **File tools**: Read and modify code files
|
||||
|
||||
When asked to train a model, follow this workflow:
|
||||
1. List available environments
|
||||
2. Select and configure the appropriate environment
|
||||
3. Test with sample prompts
|
||||
4. Start training with conservative settings
|
||||
5. Monitor progress and adjust as needed
|
||||
"""
|
||||
|
||||
# Toolsets to enable for RL workflows
|
||||
RL_TOOLSETS = ["base", "terminal", "web", "rl"]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Helper Functions
|
||||
# ============================================================================
|
||||
|
||||
def check_requirements():
    """Check that all required environment variables and services are available.

    Returns:
        bool: True when everything needed is present; False after printing
        a summary of what is missing.
    """
    problems = []

    # The agent itself talks to OpenRouter; the RL tools need their own keys.
    if not os.getenv("OPENROUTER_API_KEY"):
        problems.append("OPENROUTER_API_KEY not set - required for agent")

    missing = get_missing_keys()
    if missing:
        problems.append(f"Missing RL API keys: {', '.join(missing)}")

    if not problems:
        return True

    print("❌ Missing requirements:")
    for problem in problems:
        print(f" - {problem}")
    print("\nPlease set these environment variables in your .env file or shell.")
    return False
|
||||
|
||||
|
||||
async def check_rl_server():
    """Check if the RL API server is running.

    Returns:
        tuple: (True, parsed health payload) when the server responds
        without an error field; otherwise (False, error message string).
    """
    import json

    try:
        payload = json.loads(await rl_health_check())
        if "error" in payload:
            return False, payload["error"]
        return True, payload
    except Exception as exc:  # connection failure, bad JSON, etc.
        return False, str(exc)
|
||||
|
||||
|
||||
def list_environments_sync():
    """List available environments (synchronous wrapper).

    Drives the async ``rl_list_environments`` tool to completion on a
    fresh event loop and returns the decoded JSON payload.
    """
    import json

    from tools.rl_training_tool import rl_list_environments

    async def _fetch():
        return json.loads(await rl_list_environments())

    return asyncio.run(_fetch())
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Main CLI
|
||||
# ============================================================================
|
||||
|
||||
def main(
    task: str = None,
    model: str = "anthropic/claude-sonnet-4-20250514",
    api_key: str = None,
    base_url: str = "https://openrouter.ai/api/v1",
    max_iterations: int = RL_MAX_ITERATIONS,
    interactive: bool = False,
    list_environments: bool = False,
    check_server: bool = False,
    verbose: bool = False,
    save_trajectories: bool = True,
):
    """
    RL Training CLI - Dedicated runner for RL training workflows.

    Args:
        task: The training task/goal (e.g., "Train a model on GSM8k for math")
        model: Model to use for the agent (default: claude-sonnet-4)
        api_key: OpenRouter API key (uses OPENROUTER_API_KEY env var if not provided)
        base_url: API base URL (default: OpenRouter)
        max_iterations: Maximum agent iterations (default: 200 for long workflows)
        interactive: Run in interactive mode (multiple conversations)
        list_environments: Just list available RL environments and exit
        check_server: Check if RL API server is running and exit
        verbose: Enable verbose logging
        save_trajectories: Save conversation trajectories (default: True for RL)

    Examples:
        # Train on a specific environment
        python rl_cli.py "Train a model on GSM8k math problems"

        # Interactive mode
        python rl_cli.py --interactive

        # List available environments
        python rl_cli.py --list-environments

        # Check server status
        python rl_cli.py --check-server
    """
    print("🎯 RL Training Agent")
    print("=" * 60)

    # Handle server check: --check-server short-circuits everything else.
    if check_server:
        print("\n🔍 Checking RL API server...")
        ok, result = asyncio.run(check_rl_server())
        if ok:
            # On success `result` is the parsed health payload dict.
            print("✅ RL API server is running")
            print(f" Environments discovered: {result.get('environments_discovered', 'unknown')}")
            print(f" Current environment: {result.get('current_environment', 'none')}")
            print(f" Active runs: {result.get('active_runs', 0)}")
        else:
            # On failure `result` is an error-message string.
            print(f"❌ RL API server not accessible: {result}")
            print("\nTo start the server:")
            print(" cd tinker-atropos && uvicorn rl_api_server:app --port 8080")
        return

    # Handle environment listing: --list-environments also exits early.
    if list_environments:
        print("\n📋 Available RL Environments:")
        print("-" * 40)
        try:
            data = list_environments_sync()
            if "error" in data:
                print(f"❌ Error: {data['error']}")
                return

            envs = data.get("environments", [])
            if not envs:
                print("No environments found.")
                print("\nMake sure the RL API server is running:")
                print(" cd tinker-atropos && uvicorn rl_api_server:app --port 8080")
                return

            for env in envs:
                print(f"\n 📦 {env['name']}")
                print(f" Class: {env['class_name']}")
                print(f" Path: {env['file_path']}")
                if env.get('description'):
                    # Truncate long descriptions to keep the listing scannable.
                    desc = env['description'][:100] + "..." if len(env.get('description', '')) > 100 else env.get('description', '')
                    print(f" Description: {desc}")

            print(f"\n📊 Total: {len(envs)} environments")
            print("\nUse `rl_select_environment(name)` to select an environment for training.")
        except Exception as e:
            print(f"❌ Error listing environments: {e}")
            print("\nMake sure the RL API server is running:")
            print(" cd tinker-atropos && uvicorn rl_api_server:app --port 8080")
        return

    # Check requirements (OPENROUTER_API_KEY plus the RL keys); exit non-zero
    # so shell callers can detect a misconfigured environment.
    if not check_requirements():
        sys.exit(1)

    # No task and not interactive: print usage examples and bail out.
    if not task and not interactive:
        print("\n⚠️  No task provided. Use --interactive for interactive mode or provide a task.")
        print("\nExamples:")
        print('  python rl_cli.py "Train a model on GSM8k math problems"')
        print('  python rl_cli.py "Create an RL environment for code generation"')
        print('  python rl_cli.py --interactive')
        return

    # Get API key: explicit flag wins over the environment variable.
    api_key = api_key or os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        print("❌ No API key provided. Set OPENROUTER_API_KEY or pass --api-key")
        sys.exit(1)

    print(f"\n🤖 Model: {model}")
    print(f"🔧 Max iterations: {max_iterations}")
    print(f"📁 Toolsets: {', '.join(RL_TOOLSETS)}")
    print("=" * 60)

    # Create agent with RL configuration (RL system prompt + RL toolsets).
    agent = AIAgent(
        base_url=base_url,
        api_key=api_key,
        model=model,
        max_iterations=max_iterations,
        enabled_toolsets=RL_TOOLSETS,
        save_trajectories=save_trajectories,
        verbose_logging=verbose,
        quiet_mode=False,
        ephemeral_system_prompt=RL_SYSTEM_PROMPT,
    )

    if interactive:
        # Interactive mode - multiple conversations in one session.
        print("\n🔄 Interactive RL Training Mode")
        print("Type 'quit' or 'exit' to end the session.")
        print("Type 'status' to check active training runs.")
        print("-" * 40)

        while True:
            try:
                user_input = input("\n🎯 RL Task> ").strip()

                if not user_input:
                    continue

                if user_input.lower() in ('quit', 'exit', 'q'):
                    print("\n👋 Goodbye!")
                    break

                if user_input.lower() == 'status':
                    # Quick status check handled locally, without the agent.
                    from tools.rl_training_tool import rl_list_runs
                    import json
                    result = asyncio.run(rl_list_runs())
                    runs = json.loads(result)
                    if isinstance(runs, list) and runs:
                        print("\n📊 Active Runs:")
                        for run in runs:
                            print(f" - {run['run_id']}: {run['environment']} ({run['status']})")
                    else:
                        print("\nNo active runs.")
                    continue

                # Run the agent on the user's task.
                # NOTE(review): assumes run_conversation blocks until the
                # conversation finishes — confirm against AIAgent.
                print("\n" + "=" * 60)
                response = agent.run_conversation(user_input)
                print("\n" + "=" * 60)

            except KeyboardInterrupt:
                # Ctrl-C ends the interactive session cleanly.
                print("\n\n👋 Interrupted. Goodbye!")
                break
            except Exception as e:
                # Keep the interactive loop alive after a failed task.
                print(f"\n❌ Error: {e}")
                if verbose:
                    import traceback
                    traceback.print_exc()
    else:
        # Single task mode: run one conversation then exit.
        print(f"\n📝 Task: {task}")
        print("-" * 40)

        try:
            response = agent.run_conversation(task)
            print("\n" + "=" * 60)
            print("✅ Task completed")
        except KeyboardInterrupt:
            print("\n\n⚠️  Interrupted by user")
        except Exception as e:
            # Unlike interactive mode, a failure here exits non-zero so
            # scripted callers can detect it.
            print(f"\n❌ Error: {e}")
            if verbose:
                import traceback
                traceback.print_exc()
            sys.exit(1)
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
||||
@@ -95,6 +95,23 @@ from .cronjob_tools import (
|
||||
REMOVE_CRONJOB_SCHEMA
|
||||
)
|
||||
|
||||
# RL Training tools (Tinker-Atropos)
|
||||
from .rl_training_tool import (
|
||||
rl_list_environments,
|
||||
rl_select_environment,
|
||||
rl_get_current_config,
|
||||
rl_edit_config,
|
||||
rl_start_training,
|
||||
rl_check_status,
|
||||
rl_stop_training,
|
||||
rl_get_results,
|
||||
rl_test_inference,
|
||||
rl_list_runs,
|
||||
rl_health_check,
|
||||
check_rl_api_keys,
|
||||
get_missing_keys,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Web tools
|
||||
'web_search_tool',
|
||||
@@ -152,5 +169,19 @@ __all__ = [
|
||||
'SCHEDULE_CRONJOB_SCHEMA',
|
||||
'LIST_CRONJOBS_SCHEMA',
|
||||
'REMOVE_CRONJOB_SCHEMA',
|
||||
# RL Training tools
|
||||
'rl_list_environments',
|
||||
'rl_select_environment',
|
||||
'rl_get_current_config',
|
||||
'rl_edit_config',
|
||||
'rl_start_training',
|
||||
'rl_check_status',
|
||||
'rl_stop_training',
|
||||
'rl_get_results',
|
||||
'rl_test_inference',
|
||||
'rl_list_runs',
|
||||
'rl_health_check',
|
||||
'check_rl_api_keys',
|
||||
'get_missing_keys',
|
||||
]
|
||||
|
||||
|
||||
436
tools/rl_training_tool.py
Normal file
436
tools/rl_training_tool.py
Normal file
@@ -0,0 +1,436 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
RL Training Tools Module
|
||||
|
||||
This module provides tools for running RL training through Tinker-Atropos.
|
||||
Communicates with the RL API server (rl_api_server.py) to manage:
|
||||
- Environment discovery and selection
|
||||
- Configuration management
|
||||
- Training run lifecycle
|
||||
- WandB metrics monitoring
|
||||
- Inference-only testing
|
||||
|
||||
Required environment variables:
|
||||
- TINKER_API_KEY: API key for Tinker service
|
||||
- WANDB_API_KEY: API key for Weights & Biases metrics
|
||||
|
||||
Optional environment variables:
|
||||
- RL_API_URL: URL of the RL API server (default: http://localhost:8080)
|
||||
- WANDB_ENTITY: WandB entity/team name
|
||||
- WANDB_PROJECT: Default WandB project name
|
||||
|
||||
Usage:
|
||||
from tools.rl_training_tool import (
|
||||
rl_list_environments,
|
||||
rl_select_environment,
|
||||
rl_get_current_config,
|
||||
rl_edit_config,
|
||||
rl_start_training,
|
||||
rl_check_status,
|
||||
rl_stop_training,
|
||||
rl_get_results,
|
||||
rl_test_inference,
|
||||
)
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
# ============================================================================
|
||||
# Configuration
|
||||
# ============================================================================
|
||||
|
||||
# Default RL API server URL (can be overridden via environment variable)
|
||||
RL_API_URL = os.getenv("RL_API_URL", "http://localhost:8080")
|
||||
|
||||
# Rate limiting for status checks (30 minutes in seconds)
|
||||
MIN_STATUS_CHECK_INTERVAL = 30 * 60
|
||||
_last_status_check: Dict[str, float] = {}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Helper Functions
|
||||
# ============================================================================
|
||||
|
||||
async def _make_request(
    method: str,
    endpoint: str,
    data: Optional[Dict] = None,
    timeout: int = 30,
) -> Dict[str, Any]:
    """Make an HTTP request to the RL API server.

    Args:
        method: HTTP method, "GET" or "POST".
        endpoint: Path appended to RL_API_URL (e.g. "/environments").
        data: JSON body for POST requests; ignored for GET.
        timeout: Total request timeout in seconds.

    Returns:
        Dict: The decoded JSON response body on HTTP 200; otherwise a
        dict with a single "error" key describing the failure. This
        helper never raises — all transport errors are folded into the
        error-dict convention the rl_* tools return to the model.
    """
    url = f"{RL_API_URL}{endpoint}"

    request_kwargs: Dict[str, Any] = {
        # Explicit ClientTimeout: passing a bare int is deprecated in
        # recent aiohttp releases.
        "timeout": aiohttp.ClientTimeout(total=timeout),
    }
    if method == "POST":
        request_kwargs["json"] = data
    elif method != "GET":
        # Previously any other verb fell through and returned None,
        # which crashed callers expecting a dict. Fail loudly instead,
        # in the same error-dict convention.
        return {"error": f"Unsupported HTTP method: {method}"}

    async with aiohttp.ClientSession() as session:
        try:
            # session.request unifies the previously duplicated GET/POST
            # branches, which differed only in the `json=` argument.
            async with session.request(method, url, **request_kwargs) as response:
                if response.status == 200:
                    return await response.json()
                error_text = await response.text()
                return {"error": f"HTTP {response.status}: {error_text}"}
        except aiohttp.ClientConnectorError:
            return {
                "error": f"Cannot connect to RL API server at {RL_API_URL}. "
                         "Make sure the server is running: "
                         "cd tinker-atropos && uvicorn rl_api_server:app --port 8080"
            }
        except Exception as e:
            return {"error": f"Request failed: {str(e)}"}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Environment Discovery Tools
|
||||
# ============================================================================
|
||||
|
||||
async def rl_list_environments() -> str:
    """
    List all available RL environments.

    Scans tinker-atropos/tinker_atropos/environments/ for Python files
    containing classes that inherit from BaseEnv.

    Each environment entry includes:
    - name: Environment identifier
    - class_name: Python class name
    - file_path: Path to the environment file
    - description: Brief description if available

    TIP: To create or modify RL environments:
    1. Use terminal/file tools to inspect existing environments
    2. Study how they load datasets, define verifiers, and structure rewards
    3. Inspect HuggingFace datasets to understand data formats
    4. Copy an existing environment as a template
    5. Test with rl_test_inference before running full training

    Returns:
        JSON string with the list of environments (plus usage tips),
        or an error message if the API server is unreachable.
    """
    envs = await _make_request("GET", "/environments")

    if "error" in envs:
        return json.dumps(envs, indent=2)

    # Wrap the raw listing with guidance for the caller.
    enriched = {
        "environments": envs,
        "count": len(envs),
        "tips": [
            "Use rl_select_environment(name) to select an environment",
            "Read the file_path with file tools to understand how each environment works",
            "Look for load_dataset(), score_answer(), get_next_item() methods",
        ],
    }
    return json.dumps(enriched, indent=2)
||||
|
||||
|
||||
async def rl_select_environment(name: str) -> str:
    """
    Select an RL environment for training.

    Loads the environment's default configuration into the config state.
    Follow up with rl_get_current_config() to inspect the configuration
    and rl_edit_config() to change individual fields.

    Args:
        name: Environment name as reported by rl_list_environments().

    Returns:
        JSON string with the selection result, file path, and current config.

    TIP: Read the returned file_path to understand how the environment works:
    - How it loads data (load_dataset calls)
    - How it verifies answers (score_answer method)
    - What prompts it uses (system_prompt, get_next_item)
    """
    return json.dumps(
        await _make_request("POST", f"/environments/{name}/select"),
        indent=2,
    )
||||
|
||||
|
||||
# ============================================================================
|
||||
# Configuration Tools
|
||||
# ============================================================================
|
||||
|
||||
async def rl_get_current_config() -> str:
    """
    Fetch the current environment configuration.

    Only the user-tunable fields are returned; system-managed fields
    (tokenizer_name, rollout_server_url, etc.) are fixed and not exposed.

    Tunable fields:
    - group_size: Rollouts per prompt (4-16 typical)
    - max_token_length: Max generation tokens (2048-16384)
    - total_steps: Training steps (50-2000)
    - steps_per_eval: Steps between evaluations
    - use_wandb: Enable WandB logging
    - wandb_name: WandB run name prefix
    - max_num_workers: Concurrent workers (-1 = auto)

    Returns:
        JSON string mapping each tunable field to its current value.
    """
    config = await _make_request("GET", "/config")
    return json.dumps(config, indent=2)
||||
|
||||
|
||||
async def rl_edit_config(field: str, value: Any) -> str:
    """
    Set one configuration field to a new value.

    Only exposed fields can be modified; the server validates both the
    field name and the value's type.

    Args:
        field: Field to update (e.g., "group_size", "total_steps").
        value: New value for the field.

    Valid fields:
    - group_size (int): Rollouts per prompt
    - max_token_length (int): Max generation tokens
    - total_steps (int): Training steps
    - steps_per_eval (int): Eval frequency
    - use_wandb (bool): Enable logging
    - wandb_name (str): Run name prefix
    - max_num_workers (int): Workers count

    Returns:
        JSON string with the updated config, or an error message.
    """
    payload = {"field": field, "value": value}
    return json.dumps(await _make_request("POST", "/config", payload), indent=2)
||||
|
||||
|
||||
# ============================================================================
|
||||
# Training Management Tools
|
||||
# ============================================================================
|
||||
|
||||
async def rl_start_training(
    wandb_project: str = "rl-training",
    lora_rank: int = 32,
    learning_rate: float = 4e-5,
) -> str:
    """
    Start a new RL training run with the current environment and config.

    An environment must already be selected via rl_select_environment().

    WARNING: Training runs can take hours to days. Use rl_check_status()
    to monitor progress (recommended: check every 30 minutes at most).

    Args:
        wandb_project: WandB project name for logging.
        lora_rank: LoRA rank for training (default: 32).
        learning_rate: Learning rate (default: 4e-5).

    Returns:
        JSON string with run_id and initial status.

    TIP: Before starting training:
    1. Test with rl_test_inference() to verify the environment works
    2. Start with fewer total_steps to validate the setup
    3. Monitor WandB metrics for reward/mean and percent_correct
    """
    payload = {
        "wandb_project": wandb_project,
        "lora_rank": lora_rank,
        "learning_rate": learning_rate,
    }
    run_info = await _make_request("POST", "/runs", payload)
    return json.dumps(run_info, indent=2)
||||
|
||||
|
||||
async def rl_check_status(run_id: str) -> str:
    """
    Get status and metrics for a training run.

    RATE LIMITED: For long-running training, this function enforces a
    minimum 30-minute interval between checks for the same run_id.
    A failed request (e.g., server unreachable) does NOT consume the
    slot, so the caller can retry immediately after a transient error.

    Fetches latest metrics from WandB if available:
    - step: Current training step
    - state: Run state (running, finished, crashed)
    - reward_mean: Average reward across batches
    - loss: Training loss
    - percent_correct: Training accuracy
    - eval_percent_correct: Evaluation accuracy

    Args:
        run_id: The run ID returned by rl_start_training()

    Returns:
        JSON string with run status and metrics, or rate limit message
    """
    # Note: mutating the module-level dict needs no `global` declaration.
    now = time.time()
    last = _last_status_check.get(run_id)
    if last is not None:
        elapsed = now - last
        if elapsed < MIN_STATUS_CHECK_INTERVAL:
            remaining = MIN_STATUS_CHECK_INTERVAL - elapsed
            return json.dumps({
                "rate_limited": True,
                "run_id": run_id,
                "message": f"Rate limited. Next check available in {remaining/60:.0f} minutes.",
                "next_check_in_seconds": remaining,
            }, indent=2)

    _last_status_check[run_id] = now
    result = await _make_request("GET", f"/runs/{run_id}")
    if "error" in result:
        # Bug fix: the original recorded the timestamp unconditionally, so a
        # failed request locked the caller out for 30 minutes with no data.
        # Roll the slot back on error so the next attempt is allowed.
        _last_status_check.pop(run_id, None)
    return json.dumps(result, indent=2)
||||
|
||||
|
||||
async def rl_stop_training(run_id: str) -> str:
    """
    Stop a running training job.

    Reasons to stop a run:
    - Metrics look bad or training is stagnant
    - You want to try different settings
    - You need to free up resources

    Args:
        run_id: The run ID to stop.

    Returns:
        JSON string with stop confirmation.
    """
    return json.dumps(
        await _make_request("POST", f"/runs/{run_id}/stop"),
        indent=2,
    )
||||
|
||||
|
||||
async def rl_get_results(run_id: str) -> str:
    """
    Get final results and metrics for a completed training run.

    The result includes:
    - Final metrics (reward, loss, accuracy)
    - WandB run URL for detailed analysis
    - Path to trained weights (tinker:// URL)

    Args:
        run_id: The run ID to get results for.

    Returns:
        JSON string with final results and weights path.
    """
    metrics = await _make_request("GET", f"/runs/{run_id}/metrics")
    return json.dumps(metrics, indent=2)
||||
|
||||
|
||||
# ============================================================================
|
||||
# Inference Testing Tools
|
||||
# ============================================================================
|
||||
|
||||
async def rl_test_inference(
    prompts: List[str],
    max_tokens: int = 256,
    temperature: float = 1.0,
) -> str:
    """
    Test inference + verifier on sample prompts WITHOUT full training.

    Validates an environment before committing to a long training run.
    Exercises:
    - Data loading and formatting
    - Model inference through Tinker
    - Verifier/reward function logic

    NOTE: This still requires the RL API server to be running with
    Tinker access for the Sample() method.

    Args:
        prompts: Test prompts to run through the environment.
        max_tokens: Maximum tokens to generate per prompt.
        temperature: Sampling temperature.

    Returns:
        JSON string with responses and verifier scores for each prompt.

    TIP: Include prompts with known correct/incorrect answers to verify
    the reward function is working correctly.
    """
    payload = {
        "prompts": prompts,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    scores = await _make_request("POST", "/test/inference", payload)
    return json.dumps(scores, indent=2)
||||
|
||||
|
||||
# ============================================================================
|
||||
# Utility Tools
|
||||
# ============================================================================
|
||||
|
||||
async def rl_list_runs() -> str:
    """
    List all training runs (active and completed).

    Returns:
        JSON string with the list of runs and their status.
    """
    return json.dumps(await _make_request("GET", "/runs"), indent=2)
||||
|
||||
|
||||
# ============================================================================
|
||||
# Requirements Check
|
||||
# ============================================================================
|
||||
|
||||
def check_rl_api_keys() -> bool:
    """
    Check whether the required API keys are present in the environment.

    Required:
    - TINKER_API_KEY: For Tinker training service
    - WANDB_API_KEY: For metrics logging and fetching

    Returns:
        bool: True if every required key is set (non-empty), False otherwise.
    """
    return all(
        os.getenv(key) for key in ("TINKER_API_KEY", "WANDB_API_KEY")
    )
||||
|
||||
|
||||
def get_missing_keys() -> List[str]:
    """
    Return the names of required API keys that are not set.

    Returns:
        List of missing key names (empty when everything is configured).
    """
    required = ("TINKER_API_KEY", "WANDB_API_KEY")
    return [key for key in required if not os.getenv(key)]
||||
|
||||
|
||||
# ============================================================================
|
||||
# Debug/Status
|
||||
# ============================================================================
|
||||
|
||||
async def rl_health_check() -> str:
    """
    Check if the RL API server is running and accessible.

    Returns:
        JSON string with server health status.
    """
    health = await _make_request("GET", "/health")
    return json.dumps(health, indent=2)
||||
12
toolsets.py
12
toolsets.py
@@ -90,6 +90,18 @@ TOOLSETS = {
|
||||
"includes": []
|
||||
},
|
||||
|
||||
"rl": {
|
||||
"description": "RL training tools for running reinforcement learning on Tinker-Atropos",
|
||||
"tools": [
|
||||
"rl_list_environments", "rl_select_environment",
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_test_inference", "rl_list_runs"
|
||||
],
|
||||
"includes": []
|
||||
},
|
||||
|
||||
# Scenario-specific toolsets
|
||||
|
||||
"debugging": {
|
||||
|
||||
Reference in New Issue
Block a user