Files
hermes-agent/tools/rl_training_tool.py
teknium1 f6574978de Add RL training configuration and tools
- Updated `.env.example` to include Tinker and WandB API keys for reinforcement learning training.
- Enhanced `model_tools.py` to clarify configuration options and streamline the RL training process.
- Expanded `README.md` with detailed instructions for setting up RL training using Tinker and WandB.
- Modified `hermes_cli` files to integrate RL training tools and ensure proper configuration checks.
- Improved `rl_training_tool.py` to reflect changes in training parameters and configuration management.
2026-02-04 09:36:51 -08:00

424 lines
14 KiB
Python

#!/usr/bin/env python3
"""
RL Training Tools Module
This module provides tools for running RL training through Tinker-Atropos.
Communicates with the RL API server (rl_api_server.py) to manage:
- Environment discovery and selection
- Configuration management
- Training run lifecycle
- WandB metrics monitoring
- Inference-only testing
Required environment variables:
- TINKER_API_KEY: API key for Tinker service
- WANDB_API_KEY: API key for Weights & Biases metrics
Optional environment variables:
- RL_API_URL: URL of the RL API server (default: http://localhost:8080)
- WANDB_ENTITY: WandB entity/team name
- WANDB_PROJECT: Default WandB project name
Usage:
from tools.rl_training_tool import (
rl_list_environments,
rl_select_environment,
rl_get_current_config,
rl_edit_config,
rl_start_training,
rl_check_status,
rl_stop_training,
rl_get_results,
rl_test_inference,
)
"""
import json
import os
import time
from typing import Any, Dict, List, Optional
import aiohttp
# ============================================================================
# Configuration
# ============================================================================
# Default RL API server URL (can be overridden via environment variable)
RL_API_URL = os.getenv("RL_API_URL", "http://localhost:8080")
# Rate limiting for status checks (30 minutes in seconds)
MIN_STATUS_CHECK_INTERVAL = 30 * 60
# Maps run_id -> wall-clock timestamp (time.time()) of the most recent
# rl_check_status() call; used to enforce MIN_STATUS_CHECK_INTERVAL per run.
_last_status_check: Dict[str, float] = {}
# ============================================================================
# Helper Functions
# ============================================================================
async def _make_request(
    method: str,
    endpoint: str,
    data: Optional[Dict] = None,
    timeout: int = 30,
) -> Dict[str, Any]:
    """Make an HTTP request to the RL API server.

    Args:
        method: HTTP verb, e.g. "GET" or "POST" (any aiohttp-supported
            verb now works; previously anything but GET/POST silently
            returned None).
        endpoint: Path appended to RL_API_URL (e.g. "/runs").
        data: JSON body to send (ignored by the server for GET requests).
        timeout: Total request timeout in seconds.

    Returns:
        The parsed JSON response on HTTP 200; otherwise a dict with a
        single "error" key describing the failure. Never raises.
    """
    url = f"{RL_API_URL}{endpoint}"
    try:
        async with aiohttp.ClientSession() as session:
            # session.request() unifies the formerly duplicated GET/POST
            # branches; ClientTimeout is the non-deprecated way to pass a
            # total timeout to aiohttp.
            async with session.request(
                method,
                url,
                json=data,
                timeout=aiohttp.ClientTimeout(total=timeout),
            ) as response:
                if response.status == 200:
                    return await response.json()
                error_text = await response.text()
                return {"error": f"HTTP {response.status}: {error_text}"}
    except aiohttp.ClientConnectorError:
        return {
            "error": f"Cannot connect to RL API server at {RL_API_URL}. "
            "Make sure the server is running: "
            "cd tinker-atropos && uvicorn rl_api_server:app --port 8080"
        }
    except Exception as e:
        return {"error": f"Request failed: {str(e)}"}
# ============================================================================
# Environment Discovery Tools
# ============================================================================
async def rl_list_environments() -> str:
    """
    List all available RL environments.

    Queries GET /environments on the RL API server, which scans
    tinker-atropos/tinker_atropos/environments/ for Python files whose
    classes inherit from BaseEnv. Each entry reports:
    - name: Environment identifier
    - class_name: Python class name
    - file_path: Path to the environment file
    - description: Brief description if available

    TIP: To create or modify RL environments, inspect the existing ones
    with file tools, study their dataset loading / verifier / reward
    structure, copy one as a template, and validate with
    rl_test_inference() before launching a full training run.

    Returns:
        JSON string with the environment list (plus usage tips) or an
        error message if the server is unreachable.
    """
    envs = await _make_request("GET", "/environments")
    if "error" in envs:
        return json.dumps(envs, indent=2)
    # Wrap the raw listing with a count and actionable next steps.
    annotated = {
        "environments": envs,
        "count": len(envs),
        "tips": [
            "Use rl_select_environment(name) to select an environment",
            "Read the file_path with file tools to understand how each environment works",
            "Look for load_dataset(), score_answer(), get_next_item() methods",
        ],
    }
    return json.dumps(annotated, indent=2)
async def rl_select_environment(name: str) -> str:
    """
    Select an RL environment for training.

    Loads the named environment's default configuration into the server's
    config state. Follow up with rl_get_current_config() to inspect it and
    rl_edit_config() to change individual fields.

    Args:
        name: Environment name, as reported by rl_list_environments()

    Returns:
        JSON string with the selection result, file path, and current config

    TIP: Read the returned file_path to understand the environment:
    how it loads data (load_dataset calls), verifies answers
    (score_answer), and builds prompts (system_prompt, get_next_item).
    """
    selection = await _make_request("POST", f"/environments/{name}/select")
    return json.dumps(selection, indent=2)
# ============================================================================
# Configuration Tools
# ============================================================================
async def rl_get_current_config() -> str:
    """
    Get the current environment configuration.

    Fetches every configurable field for the selected environment; the
    available options vary per environment. The response separates:
    - configurable_fields: editable via rl_edit_config()
    - locked_fields: infrastructure settings that cannot be changed

    Commonly configurable fields include group_size (rollouts per prompt),
    batch_size, wandb_name (run name prefix), system_prompt, and any
    environment-specific options.

    Returns:
        JSON string with configurable and locked fields
    """
    return json.dumps(await _make_request("GET", "/config"), indent=2)
async def rl_edit_config(field: str, value: Any) -> str:
    """
    Update a single configuration field.

    Call rl_get_current_config() first to discover which fields the
    selected environment exposes; locked infrastructure fields are
    rejected by the server.

    Args:
        field: Field name to update (from rl_get_current_config)
        value: New value for that field

    Returns:
        JSON string with the updated config, or an error message
    """
    payload = {"field": field, "value": value}
    return json.dumps(await _make_request("POST", "/config", payload), indent=2)
# ============================================================================
# Training Management Tools
# ============================================================================
async def rl_start_training() -> str:
    """
    Start a new RL training run with the current environment and config.

    An environment must already be selected via rl_select_environment();
    set group_size, batch_size, and wandb_project with rl_edit_config()
    beforehand. Most training parameters (e.g. lora_rank=32,
    learning_rate=4e-5) are fixed and cannot be changed.

    WARNING: Training runs take hours. Monitor with rl_check_status(),
    ideally no more than once every 30 minutes.

    Returns:
        JSON string with run_id and initial status

    TIP: Before starting, validate the environment with
    rl_test_inference(), pick sensible group_size/batch_size values, and
    watch WandB for reward/mean and percent_correct.
    """
    return json.dumps(await _make_request("POST", "/runs", {}), indent=2)
async def rl_check_status(run_id: str) -> str:
    """
    Get status and metrics for a training run.

    RATE LIMITED: a minimum 30-minute interval is enforced between checks
    for the same run_id; within that window a rate-limit notice is
    returned instead of hitting the server.

    When allowed through, the server reports the latest WandB metrics:
    step, state (running/finished/crashed), reward_mean, loss,
    percent_correct, and eval_percent_correct.

    Args:
        run_id: The run ID returned by rl_start_training()

    Returns:
        JSON string with run status and metrics, or a rate-limit message
    """
    now = time.time()
    last_checked = _last_status_check.get(run_id)
    # Guard clause: short-circuit if this run was polled too recently.
    if last_checked is not None and now - last_checked < MIN_STATUS_CHECK_INTERVAL:
        remaining = MIN_STATUS_CHECK_INTERVAL - (now - last_checked)
        notice = {
            "rate_limited": True,
            "run_id": run_id,
            "message": f"Rate limited. Next check available in {remaining/60:.0f} minutes.",
            "next_check_in_seconds": remaining,
        }
        return json.dumps(notice, indent=2)
    _last_status_check[run_id] = now
    status = await _make_request("GET", f"/runs/{run_id}")
    return json.dumps(status, indent=2)
async def rl_stop_training(run_id: str) -> str:
    """
    Stop a running training job.

    Useful when metrics look bad or stagnant, when you want to retry with
    different settings, or to free up resources.

    Args:
        run_id: The run ID to stop

    Returns:
        JSON string with stop confirmation
    """
    return json.dumps(await _make_request("POST", f"/runs/{run_id}/stop"), indent=2)
async def rl_get_results(run_id: str) -> str:
    """
    Get final results and metrics for a completed training run.

    The server's response includes the final metrics (reward, loss,
    accuracy), the WandB run URL for deeper analysis, and the path to the
    trained weights (a tinker:// URL).

    Args:
        run_id: The run ID to get results for

    Returns:
        JSON string with final results and weights path
    """
    metrics = await _make_request("GET", f"/runs/{run_id}/metrics")
    return json.dumps(metrics, indent=2)
# ============================================================================
# Inference Testing Tools
# ============================================================================
async def rl_test_inference(
    prompts: List[str],
    max_tokens: int = 256,
    temperature: float = 1.0,
) -> str:
    """
    Test inference + verifier on sample prompts WITHOUT full training.

    A cheap sanity check before committing to a long run: exercises data
    loading/formatting, model inference through Tinker, and the
    verifier/reward logic.

    NOTE: The RL API server must still be running with Tinker access for
    the Sample() method.

    Args:
        prompts: Test prompts to run through the environment
        max_tokens: Maximum tokens to generate per prompt
        temperature: Sampling temperature

    Returns:
        JSON string with responses and verifier scores for each prompt

    TIP: Include prompts with known correct/incorrect answers so you can
    confirm the reward function behaves as expected.
    """
    body = {
        "prompts": prompts,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    return json.dumps(await _make_request("POST", "/test/inference", body), indent=2)
# ============================================================================
# Utility Tools
# ============================================================================
async def rl_list_runs() -> str:
    """
    List all training runs (active and completed).

    Returns:
        JSON string with list of runs and their status
    """
    runs = await _make_request("GET", "/runs")
    return json.dumps(runs, indent=2)
# ============================================================================
# Requirements Check
# ============================================================================
def check_rl_api_keys() -> bool:
    """
    Check if required API keys are available in environment variables.

    Required:
    - TINKER_API_KEY: For Tinker training service
    - WANDB_API_KEY: For metrics logging and fetching

    Returns:
        bool: True if all required keys are set, False otherwise
    """
    required = ("TINKER_API_KEY", "WANDB_API_KEY")
    return all(os.getenv(key) for key in required)
def get_missing_keys() -> List[str]:
    """
    Get list of missing required API keys.

    Returns:
        List of missing key names (TINKER_API_KEY and/or WANDB_API_KEY)
    """
    # Order matters for stable output: Tinker key reported first.
    required = ["TINKER_API_KEY", "WANDB_API_KEY"]
    return [key for key in required if not os.getenv(key)]
# ============================================================================
# Debug/Status
# ============================================================================
async def rl_health_check() -> str:
    """
    Check if the RL API server is running and accessible.

    Returns:
        JSON string with server health status
    """
    health = await _make_request("GET", "/health")
    return json.dumps(health, indent=2)