Add RL training configuration and tools
- Updated `.env.example` to include Tinker and WandB API keys for reinforcement learning training. - Enhanced `model_tools.py` to clarify configuration options and streamline the RL training process. - Expanded `README.md` with detailed instructions for setting up RL training using Tinker and WandB. - Modified `hermes_cli` files to integrate RL training tools and ensure proper configuration checks. - Improved `rl_training_tool.py` to reflect changes in training parameters and configuration management.
This commit is contained in:
18
.env.example
18
.env.example
@@ -165,3 +165,21 @@ IMAGE_TOOLS_DEBUG=false
|
||||
# CONTEXT_COMPRESSION_ENABLED=true # Enable auto-compression (default: true)
|
||||
# CONTEXT_COMPRESSION_THRESHOLD=0.85 # Compress at 85% of context limit
|
||||
# CONTEXT_COMPRESSION_MODEL=google/gemini-2.0-flash-001 # Fast model for summaries
|
||||
|
||||
# =============================================================================
|
||||
# RL TRAINING (Tinker + Atropos)
|
||||
# =============================================================================
|
||||
# Run reinforcement learning training on language models using the Tinker API.
|
||||
# Requires the rl-server to be running (from tinker-atropos package).
|
||||
|
||||
# Tinker API Key - RL training service
|
||||
# Get at: https://tinker-console.thinkingmachines.ai/keys
|
||||
TINKER_API_KEY=
|
||||
|
||||
# Weights & Biases API Key - Experiment tracking and metrics
|
||||
# Get at: https://wandb.ai/authorize
|
||||
WANDB_API_KEY=
|
||||
|
||||
# RL API Server URL (default: http://localhost:8080)
|
||||
# Change if running the rl-server on a different host/port
|
||||
# RL_API_URL=http://localhost:8080
|
||||
|
||||
56
README.md
56
README.md
@@ -74,6 +74,7 @@ You need at least one LLM provider:
|
||||
| Web scraping | [Firecrawl](https://firecrawl.dev/) | `FIRECRAWL_API_KEY` |
|
||||
| Browser automation | [Browserbase](https://browserbase.com/) | `BROWSERBASE_API_KEY`, `BROWSERBASE_PROJECT_ID` |
|
||||
| Image generation | [FAL](https://fal.ai/) | `FAL_KEY` |
|
||||
| RL Training | [Tinker](https://tinker-console.thinkingmachines.ai/) + [WandB](https://wandb.ai/) | `TINKER_API_KEY`, `WANDB_API_KEY` |
|
||||
| Messaging | Telegram, Discord | `TELEGRAM_BOT_TOKEN`, `DISCORD_BOT_TOKEN` |
|
||||
|
||||
---
|
||||
@@ -270,6 +271,61 @@ When enabled, you'll see messages like:
|
||||
|
||||
See [docs/messaging.md](docs/messaging.md) for WhatsApp and advanced setup.
|
||||
|
||||
### 🤖 RL Training (Tinker + Atropos)
|
||||
|
||||
Train language models with reinforcement learning using the Tinker API and Atropos framework.
|
||||
|
||||
#### Requirements
|
||||
|
||||
1. **API Keys:** Add to `~/.hermes/.env`:
|
||||
```bash
|
||||
TINKER_API_KEY=your-tinker-key # Get from https://tinker-console.thinkingmachines.ai/keys
|
||||
WANDB_API_KEY=your-wandb-key # Get from https://wandb.ai/authorize
|
||||
```
|
||||
|
||||
2. **Install tinker-atropos:** (in a separate directory)
|
||||
```bash
|
||||
cd ~/tinker-atropos
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
3. **Start the RL API server:**
|
||||
```bash
|
||||
rl-server # Runs on port 8080 by default
|
||||
```
|
||||
|
||||
#### Using RL Tools
|
||||
|
||||
The agent can now use RL training tools:
|
||||
|
||||
```
|
||||
You: Start training on GSM8k with group_size=16
|
||||
|
||||
Agent: I'll set up an RL training run on the GSM8k environment...
|
||||
[Uses rl_list_environments, rl_select_environment, rl_edit_config, rl_start_training]
|
||||
```
|
||||
|
||||
#### Available RL Tools
|
||||
|
||||
| Tool | Description |
|
||||
|------|-------------|
|
||||
| `rl_list_environments` | List available RL environments |
|
||||
| `rl_select_environment` | Select an environment for training |
|
||||
| `rl_get_current_config` | View all configurable options |
|
||||
| `rl_edit_config` | Change a configuration value |
|
||||
| `rl_start_training` | Start a training run |
|
||||
| `rl_check_status` | Check training progress |
|
||||
| `rl_stop_training` | Stop a running training |
|
||||
| `rl_get_results` | Fetch WandB metrics |
|
||||
|
||||
#### Dedicated RL CLI
|
||||
|
||||
For extended RL workflows with longer timeouts:
|
||||
|
||||
```bash
|
||||
python rl_cli.py --model "anthropic/claude-sonnet-4-20250514"
|
||||
```
|
||||
|
||||
### ⏰ Scheduled Tasks (Cron)
|
||||
|
||||
Schedule tasks to run automatically:
|
||||
|
||||
@@ -151,6 +151,20 @@ OPTIONAL_ENV_VARS = {
|
||||
"tools": ["image_generate"],
|
||||
"password": True,
|
||||
},
|
||||
"TINKER_API_KEY": {
|
||||
"description": "Tinker API key for RL training",
|
||||
"prompt": "Tinker API key",
|
||||
"url": "https://tinker-console.thinkingmachines.ai/keys",
|
||||
"tools": ["rl_start_training", "rl_check_status", "rl_stop_training"],
|
||||
"password": True,
|
||||
},
|
||||
"WANDB_API_KEY": {
|
||||
"description": "Weights & Biases API key for experiment tracking",
|
||||
"prompt": "WandB API key",
|
||||
"url": "https://wandb.ai/authorize",
|
||||
"tools": ["rl_get_results", "rl_check_status"],
|
||||
"password": True,
|
||||
},
|
||||
"OPENAI_BASE_URL": {
|
||||
"description": "Custom OpenAI-compatible API endpoint URL",
|
||||
"prompt": "API base URL (e.g., https://api.example.com/v1)",
|
||||
|
||||
@@ -186,6 +186,14 @@ def _print_setup_summary(config: dict, hermes_home):
|
||||
else:
|
||||
tool_status.append(("Image Generation", False, "FAL_KEY"))
|
||||
|
||||
# Tinker + WandB (RL training)
|
||||
if get_env_value('TINKER_API_KEY') and get_env_value('WANDB_API_KEY'):
|
||||
tool_status.append(("RL Training (Tinker)", True, None))
|
||||
elif get_env_value('TINKER_API_KEY'):
|
||||
tool_status.append(("RL Training (Tinker)", False, "WANDB_API_KEY"))
|
||||
else:
|
||||
tool_status.append(("RL Training (Tinker)", False, "TINKER_API_KEY"))
|
||||
|
||||
# Terminal (always available if system deps met)
|
||||
tool_status.append(("Terminal/Commands", True, None))
|
||||
|
||||
@@ -932,6 +940,47 @@ def run_setup_wizard(args):
|
||||
if api_key:
|
||||
save_env_value("FAL_KEY", api_key)
|
||||
print_success(" Configured ✓")
|
||||
print()
|
||||
|
||||
# Tinker + WandB - RL Training
|
||||
print_info("─" * 50)
|
||||
print(color(" RL Training (Tinker + WandB)", Colors.CYAN))
|
||||
print_info(" Enables: rl_start_training, rl_check_status, rl_get_results tools")
|
||||
print_info(" Use case: Run reinforcement learning training via Tinker API")
|
||||
tinker_configured = get_env_value('TINKER_API_KEY')
|
||||
wandb_configured = get_env_value('WANDB_API_KEY')
|
||||
|
||||
if tinker_configured and wandb_configured:
|
||||
print_success(" Status: Configured ✓")
|
||||
if prompt_yes_no(" Update RL training credentials?", False):
|
||||
api_key = prompt(" Tinker API key", password=True)
|
||||
if api_key:
|
||||
save_env_value("TINKER_API_KEY", api_key)
|
||||
wandb_key = prompt(" WandB API key", password=True)
|
||||
if wandb_key:
|
||||
save_env_value("WANDB_API_KEY", wandb_key)
|
||||
print_success(" Updated")
|
||||
else:
|
||||
if tinker_configured:
|
||||
print_warning(" Status: Tinker configured, WandB missing")
|
||||
elif wandb_configured:
|
||||
print_warning(" Status: WandB configured, Tinker missing")
|
||||
else:
|
||||
print_warning(" Status: Not configured (tools will be disabled)")
|
||||
|
||||
if prompt_yes_no(" Set up RL Training?", False):
|
||||
print_info(" Get Tinker key at: https://tinker-console.thinkingmachines.ai/keys")
|
||||
print_info(" Get WandB key at: https://wandb.ai/authorize")
|
||||
api_key = prompt(" Tinker API key", password=True)
|
||||
if api_key:
|
||||
save_env_value("TINKER_API_KEY", api_key)
|
||||
wandb_key = prompt(" WandB API key", password=True)
|
||||
if wandb_key:
|
||||
save_env_value("WANDB_API_KEY", wandb_key)
|
||||
if api_key and wandb_key:
|
||||
print_success(" Configured ✓")
|
||||
else:
|
||||
print_warning(" Partially configured (both keys required)")
|
||||
|
||||
# =========================================================================
|
||||
# Save config and show summary
|
||||
|
||||
@@ -74,6 +74,8 @@ def show_status(args):
|
||||
"Firecrawl": "FIRECRAWL_API_KEY",
|
||||
"Browserbase": "BROWSERBASE_API_KEY",
|
||||
"FAL": "FAL_KEY",
|
||||
"Tinker": "TINKER_API_KEY",
|
||||
"WandB": "WANDB_API_KEY",
|
||||
}
|
||||
|
||||
for name, env_var in keys.items():
|
||||
|
||||
@@ -554,13 +554,13 @@ def get_rl_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_edit_config",
|
||||
"description": "Update a configuration field. Valid fields: group_size (int), max_token_length (int), total_steps (int), steps_per_eval (int), use_wandb (bool), wandb_name (str), max_num_workers (int).",
|
||||
"description": "Update a configuration field. Use rl_get_current_config() first to see all available fields for the selected environment. Each environment has different configurable options. Infrastructure settings (tokenizer, URLs, lora_rank, learning_rate) are locked.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"field": {
|
||||
"type": "string",
|
||||
"description": "Name of the field to update"
|
||||
"description": "Name of the field to update (get available fields from rl_get_current_config)"
|
||||
},
|
||||
"value": {
|
||||
"description": "New value for the field"
|
||||
@@ -574,26 +574,10 @@ def get_rl_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_start_training",
|
||||
"description": "Start a new RL training run. WARNING: Training can take hours. Use rl_check_status() to monitor (30-minute intervals recommended). Test with rl_test_inference() first!",
|
||||
"description": "Start a new RL training run with the current environment and config. Most training parameters (lora_rank, learning_rate, etc.) are fixed. Use rl_edit_config() to set group_size, batch_size, wandb_project before starting. WARNING: Training takes hours. Test with rl_test_inference() first!",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"wandb_project": {
|
||||
"type": "string",
|
||||
"description": "WandB project name for logging",
|
||||
"default": "rl-training"
|
||||
},
|
||||
"lora_rank": {
|
||||
"type": "integer",
|
||||
"description": "LoRA rank for training",
|
||||
"default": 32
|
||||
},
|
||||
"learning_rate": {
|
||||
"type": "number",
|
||||
"description": "Learning rate",
|
||||
"default": 4e-5
|
||||
}
|
||||
},
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
@@ -1324,13 +1308,7 @@ def handle_rl_function_call(
|
||||
)
|
||||
|
||||
elif function_name == "rl_start_training":
|
||||
return loop.run_until_complete(
|
||||
rl_start_training(
|
||||
wandb_project=function_args.get("wandb_project", "rl-training"),
|
||||
lora_rank=function_args.get("lora_rank", 32),
|
||||
learning_rate=function_args.get("learning_rate", 4e-5)
|
||||
)
|
||||
)
|
||||
return loop.run_until_complete(rl_start_training())
|
||||
|
||||
elif function_name == "rl_check_status":
|
||||
return loop.run_until_complete(
|
||||
|
||||
@@ -168,20 +168,22 @@ async def rl_get_current_config() -> str:
|
||||
"""
|
||||
Get the current environment configuration.
|
||||
|
||||
Returns only the fields that are safe to modify. Other fields
|
||||
(tokenizer_name, rollout_server_url, etc.) are fixed by the system.
|
||||
Returns all configurable fields for the selected environment.
|
||||
Each environment may have different configuration options.
|
||||
|
||||
Available fields:
|
||||
- group_size: Rollouts per prompt (4-16 typical)
|
||||
- max_token_length: Max generation tokens (2048-16384)
|
||||
- total_steps: Training steps (50-2000)
|
||||
- steps_per_eval: Steps between evaluations
|
||||
- use_wandb: Enable WandB logging
|
||||
Fields are divided into:
|
||||
- configurable_fields: Can be changed with rl_edit_config()
|
||||
- locked_fields: Infrastructure settings that cannot be changed
|
||||
|
||||
Common configurable fields include:
|
||||
- group_size: Rollouts per prompt
|
||||
- batch_size: Training batch size
|
||||
- wandb_name: WandB run name prefix
|
||||
- max_num_workers: Concurrent workers (-1 = auto)
|
||||
- system_prompt: Model instructions
|
||||
- And any environment-specific options
|
||||
|
||||
Returns:
|
||||
JSON string with current config fields and their values
|
||||
JSON string with configurable and locked fields
|
||||
"""
|
||||
result = await _make_request("GET", "/config")
|
||||
return json.dumps(result, indent=2)
|
||||
@@ -191,21 +193,15 @@ async def rl_edit_config(field: str, value: Any) -> str:
|
||||
"""
|
||||
Update a configuration field.
|
||||
|
||||
Only exposed fields can be modified. Validates field name and type.
|
||||
Use rl_get_current_config() first to see available fields for the
|
||||
selected environment. Each environment has different options.
|
||||
|
||||
Locked fields (infrastructure settings) cannot be changed.
|
||||
|
||||
Args:
|
||||
field: Name of the field to update (e.g., "group_size", "total_steps")
|
||||
field: Name of the field to update (from rl_get_current_config)
|
||||
value: New value for the field
|
||||
|
||||
Valid fields:
|
||||
- group_size (int): Rollouts per prompt
|
||||
- max_token_length (int): Max generation tokens
|
||||
- total_steps (int): Training steps
|
||||
- steps_per_eval (int): Eval frequency
|
||||
- use_wandb (bool): Enable logging
|
||||
- wandb_name (str): Run name prefix
|
||||
- max_num_workers (int): Workers count
|
||||
|
||||
Returns:
|
||||
JSON string with updated config or error message
|
||||
"""
|
||||
@@ -217,37 +213,28 @@ async def rl_edit_config(field: str, value: Any) -> str:
|
||||
# Training Management Tools
|
||||
# ============================================================================
|
||||
|
||||
async def rl_start_training(
|
||||
wandb_project: str = "rl-training",
|
||||
lora_rank: int = 32,
|
||||
learning_rate: float = 4e-5,
|
||||
) -> str:
|
||||
async def rl_start_training() -> str:
|
||||
"""
|
||||
Start a new RL training run with the current environment and config.
|
||||
|
||||
Requires an environment to be selected first using rl_select_environment().
|
||||
Use rl_edit_config() to set group_size, batch_size, wandb_project before starting.
|
||||
|
||||
WARNING: Training runs can take hours to days. Use rl_check_status() to
|
||||
monitor progress (recommended: check every 30 minutes at most).
|
||||
Most training parameters are fixed (lora_rank=32, learning_rate=4e-5, etc.)
|
||||
and cannot be changed.
|
||||
|
||||
Args:
|
||||
wandb_project: WandB project name for logging
|
||||
lora_rank: LoRA rank for training (default: 32)
|
||||
learning_rate: Learning rate (default: 4e-5)
|
||||
WARNING: Training runs take hours. Use rl_check_status() to monitor
|
||||
progress (recommended: check every 30 minutes at most).
|
||||
|
||||
Returns:
|
||||
JSON string with run_id and initial status
|
||||
|
||||
TIP: Before starting training:
|
||||
1. Test with rl_test_inference() to verify the environment works
|
||||
2. Start with fewer total_steps to validate the setup
|
||||
2. Configure group_size and batch_size appropriately
|
||||
3. Monitor WandB metrics for reward/mean and percent_correct
|
||||
"""
|
||||
result = await _make_request("POST", "/runs", {
|
||||
"wandb_project": wandb_project,
|
||||
"lora_rank": lora_rank,
|
||||
"learning_rate": learning_rate,
|
||||
})
|
||||
result = await _make_request("POST", "/runs", {})
|
||||
return json.dumps(result, indent=2)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user