Add tinker-atropos submodule and update RL training tools
- Added the tinker-atropos submodule for enhanced RL training capabilities.
- Updated model_tools.py to reorder RL function definitions and improve descriptions.
- Modified rl_cli.py to include checks for the tinker-atropos setup and provide user guidance.
- Adjusted toolsets.py and __init__.py to reflect changes in RL function availability.
- Enhanced rl_training_tool.py to manage training processes directly without a separate API server.
This commit is contained in:
@@ -49,9 +49,8 @@ from tools.rl_training_tool import (
|
||||
rl_check_status,
|
||||
rl_stop_training,
|
||||
rl_get_results,
|
||||
rl_test_inference,
|
||||
rl_list_runs,
|
||||
rl_health_check,
|
||||
rl_test_inference,
|
||||
check_rl_api_keys,
|
||||
)
|
||||
# Cronjob management tools (CLI-only)
|
||||
@@ -153,7 +152,7 @@ TOOLSET_REQUIREMENTS = {
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_test_inference", "rl_list_runs",
|
||||
"rl_list_runs", "rl_test_inference",
|
||||
],
|
||||
},
|
||||
}
|
||||
@@ -574,7 +573,7 @@ def get_rl_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_start_training",
|
||||
"description": "Start a new RL training run with the current environment and config. Most training parameters (lora_rank, learning_rate, etc.) are fixed. Use rl_edit_config() to set group_size, batch_size, wandb_project before starting. WARNING: Training takes hours. Test with rl_test_inference() first!",
|
||||
"description": "Start a new RL training run with the current environment and config. Most training parameters (lora_rank, learning_rate, etc.) are fixed. Use rl_edit_config() to set group_size, batch_size, wandb_project before starting. WARNING: Training takes hours.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
@@ -636,39 +635,39 @@ def get_rl_tool_definitions() -> List[Dict[str, Any]]:
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_test_inference",
|
||||
"description": "Test inference + verifier on sample prompts WITHOUT full training. Use to validate environments before committing to long training runs. Tests data loading, inference, and verifier logic.",
|
||||
"name": "rl_list_runs",
|
||||
"description": "List all training runs (active and completed) with their status.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompts": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "List of test prompts to run through the environment"
|
||||
},
|
||||
"max_tokens": {
|
||||
"type": "integer",
|
||||
"description": "Maximum tokens to generate per prompt",
|
||||
"default": 256
|
||||
},
|
||||
"temperature": {
|
||||
"type": "number",
|
||||
"description": "Sampling temperature",
|
||||
"default": 1.0
|
||||
}
|
||||
},
|
||||
"required": ["prompts"]
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_list_runs",
|
||||
"description": "List all training runs (active and completed) with their status.",
|
||||
"name": "rl_test_inference",
|
||||
"description": "Quick inference test for any environment. Runs a few steps of inference + scoring using OpenRouter. Default: 3 steps × 16 completions = 48 rollouts per model, testing 3 models = 144 total. Tests environment loading, prompt construction, inference parsing, and verifier logic. Use BEFORE training to catch issues.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"properties": {
|
||||
"num_steps": {
|
||||
"type": "integer",
|
||||
"description": "Number of steps to run (default: 3, recommended max for testing)",
|
||||
"default": 3
|
||||
},
|
||||
"group_size": {
|
||||
"type": "integer",
|
||||
"description": "Completions per step (default: 16, like training)",
|
||||
"default": 16
|
||||
},
|
||||
"models": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Optional list of OpenRouter model IDs. Default: qwen/qwen3-8b, zhipu-ai/glm-4-flash, minimax/minimax-m1"
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
@@ -731,7 +730,7 @@ def get_all_tool_names() -> List[str]:
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_test_inference", "rl_list_runs"
|
||||
"rl_list_runs"
|
||||
])
|
||||
|
||||
return tool_names
|
||||
@@ -782,7 +781,6 @@ def get_toolset_for_tool(tool_name: str) -> str:
|
||||
"rl_check_status": "rl_tools",
|
||||
"rl_stop_training": "rl_tools",
|
||||
"rl_get_results": "rl_tools",
|
||||
"rl_test_inference": "rl_tools",
|
||||
"rl_list_runs": "rl_tools",
|
||||
}
|
||||
|
||||
@@ -900,7 +898,7 @@ def get_tool_definitions(
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_test_inference", "rl_list_runs"
|
||||
"rl_list_runs"
|
||||
]
|
||||
}
|
||||
legacy_tools = legacy_map.get(toolset_name, [])
|
||||
@@ -952,7 +950,7 @@ def get_tool_definitions(
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_test_inference", "rl_list_runs"
|
||||
"rl_list_runs"
|
||||
]
|
||||
}
|
||||
legacy_tools = legacy_map.get(toolset_name, [])
|
||||
@@ -1325,18 +1323,18 @@ def handle_rl_function_call(
|
||||
rl_get_results(run_id=function_args.get("run_id", ""))
|
||||
)
|
||||
|
||||
elif function_name == "rl_list_runs":
|
||||
return loop.run_until_complete(rl_list_runs())
|
||||
|
||||
elif function_name == "rl_test_inference":
|
||||
return loop.run_until_complete(
|
||||
rl_test_inference(
|
||||
prompts=function_args.get("prompts", []),
|
||||
max_tokens=function_args.get("max_tokens", 256),
|
||||
temperature=function_args.get("temperature", 1.0)
|
||||
num_steps=function_args.get("num_steps", 3),
|
||||
group_size=function_args.get("group_size", 16),
|
||||
models=function_args.get("models"),
|
||||
)
|
||||
)
|
||||
|
||||
elif function_name == "rl_list_runs":
|
||||
return loop.run_until_complete(rl_list_runs())
|
||||
|
||||
return json.dumps({"error": f"Unknown RL function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
@@ -1409,7 +1407,7 @@ def handle_function_call(
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_test_inference", "rl_list_runs"
|
||||
"rl_list_runs"
|
||||
]:
|
||||
return handle_rl_function_call(function_name, function_args)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user