Add tinker-atropos submodule and update RL training tools

- Added the tinker-atropos submodule for enhanced RL training capabilities.
- Updated model_tools.py to reorder RL function definitions and improve descriptions.
- Modified rl_cli.py to include checks for the tinker-atropos setup and provide user guidance.
- Adjusted toolsets.py and __init__.py to reflect changes in RL function availability.
- Enhanced rl_training_tool.py to manage training processes directly without a separate API server.
This commit is contained in:
teknium1
2026-02-04 10:36:01 -08:00
parent f6574978de
commit 12bbca95ec
7 changed files with 1059 additions and 256 deletions

3
.gitmodules vendored
View File

@@ -1,3 +1,6 @@
[submodule "mini-swe-agent"] [submodule "mini-swe-agent"]
path = mini-swe-agent path = mini-swe-agent
url = https://github.com/SWE-agent/mini-swe-agent url = https://github.com/SWE-agent/mini-swe-agent
[submodule "tinker-atropos"]
path = tinker-atropos
url = https://github.com/nousresearch/tinker-atropos

View File

@@ -49,9 +49,8 @@ from tools.rl_training_tool import (
rl_check_status, rl_check_status,
rl_stop_training, rl_stop_training,
rl_get_results, rl_get_results,
rl_test_inference,
rl_list_runs, rl_list_runs,
rl_health_check, rl_test_inference,
check_rl_api_keys, check_rl_api_keys,
) )
# Cronjob management tools (CLI-only) # Cronjob management tools (CLI-only)
@@ -153,7 +152,7 @@ TOOLSET_REQUIREMENTS = {
"rl_get_current_config", "rl_edit_config", "rl_get_current_config", "rl_edit_config",
"rl_start_training", "rl_check_status", "rl_start_training", "rl_check_status",
"rl_stop_training", "rl_get_results", "rl_stop_training", "rl_get_results",
"rl_test_inference", "rl_list_runs", "rl_list_runs", "rl_test_inference",
], ],
}, },
} }
@@ -574,7 +573,7 @@ def get_rl_tool_definitions() -> List[Dict[str, Any]]:
"type": "function", "type": "function",
"function": { "function": {
"name": "rl_start_training", "name": "rl_start_training",
"description": "Start a new RL training run with the current environment and config. Most training parameters (lora_rank, learning_rate, etc.) are fixed. Use rl_edit_config() to set group_size, batch_size, wandb_project before starting. WARNING: Training takes hours. Test with rl_test_inference() first!", "description": "Start a new RL training run with the current environment and config. Most training parameters (lora_rank, learning_rate, etc.) are fixed. Use rl_edit_config() to set group_size, batch_size, wandb_project before starting. WARNING: Training takes hours.",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": {}, "properties": {},
@@ -636,39 +635,39 @@ def get_rl_tool_definitions() -> List[Dict[str, Any]]:
{ {
"type": "function", "type": "function",
"function": { "function": {
"name": "rl_test_inference", "name": "rl_list_runs",
"description": "Test inference + verifier on sample prompts WITHOUT full training. Use to validate environments before committing to long training runs. Tests data loading, inference, and verifier logic.", "description": "List all training runs (active and completed) with their status.",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {},
"prompts": { "required": []
"type": "array",
"items": {"type": "string"},
"description": "List of test prompts to run through the environment"
},
"max_tokens": {
"type": "integer",
"description": "Maximum tokens to generate per prompt",
"default": 256
},
"temperature": {
"type": "number",
"description": "Sampling temperature",
"default": 1.0
}
},
"required": ["prompts"]
} }
} }
}, },
{ {
"type": "function", "type": "function",
"function": { "function": {
"name": "rl_list_runs", "name": "rl_test_inference",
"description": "List all training runs (active and completed) with their status.", "description": "Quick inference test for any environment. Runs a few steps of inference + scoring using OpenRouter. Default: 3 steps × 16 completions = 48 rollouts per model, testing 3 models = 144 total. Tests environment loading, prompt construction, inference parsing, and verifier logic. Use BEFORE training to catch issues.",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": {}, "properties": {
"num_steps": {
"type": "integer",
"description": "Number of steps to run (default: 3, recommended max for testing)",
"default": 3
},
"group_size": {
"type": "integer",
"description": "Completions per step (default: 16, like training)",
"default": 16
},
"models": {
"type": "array",
"items": {"type": "string"},
"description": "Optional list of OpenRouter model IDs. Default: qwen/qwen3-8b, zhipu-ai/glm-4-flash, minimax/minimax-m1"
}
},
"required": [] "required": []
} }
} }
@@ -731,7 +730,7 @@ def get_all_tool_names() -> List[str]:
"rl_get_current_config", "rl_edit_config", "rl_get_current_config", "rl_edit_config",
"rl_start_training", "rl_check_status", "rl_start_training", "rl_check_status",
"rl_stop_training", "rl_get_results", "rl_stop_training", "rl_get_results",
"rl_test_inference", "rl_list_runs" "rl_list_runs"
]) ])
return tool_names return tool_names
@@ -782,7 +781,6 @@ def get_toolset_for_tool(tool_name: str) -> str:
"rl_check_status": "rl_tools", "rl_check_status": "rl_tools",
"rl_stop_training": "rl_tools", "rl_stop_training": "rl_tools",
"rl_get_results": "rl_tools", "rl_get_results": "rl_tools",
"rl_test_inference": "rl_tools",
"rl_list_runs": "rl_tools", "rl_list_runs": "rl_tools",
} }
@@ -900,7 +898,7 @@ def get_tool_definitions(
"rl_get_current_config", "rl_edit_config", "rl_get_current_config", "rl_edit_config",
"rl_start_training", "rl_check_status", "rl_start_training", "rl_check_status",
"rl_stop_training", "rl_get_results", "rl_stop_training", "rl_get_results",
"rl_test_inference", "rl_list_runs" "rl_list_runs"
] ]
} }
legacy_tools = legacy_map.get(toolset_name, []) legacy_tools = legacy_map.get(toolset_name, [])
@@ -952,7 +950,7 @@ def get_tool_definitions(
"rl_get_current_config", "rl_edit_config", "rl_get_current_config", "rl_edit_config",
"rl_start_training", "rl_check_status", "rl_start_training", "rl_check_status",
"rl_stop_training", "rl_get_results", "rl_stop_training", "rl_get_results",
"rl_test_inference", "rl_list_runs" "rl_list_runs"
] ]
} }
legacy_tools = legacy_map.get(toolset_name, []) legacy_tools = legacy_map.get(toolset_name, [])
@@ -1325,18 +1323,18 @@ def handle_rl_function_call(
rl_get_results(run_id=function_args.get("run_id", "")) rl_get_results(run_id=function_args.get("run_id", ""))
) )
elif function_name == "rl_list_runs":
return loop.run_until_complete(rl_list_runs())
elif function_name == "rl_test_inference": elif function_name == "rl_test_inference":
return loop.run_until_complete( return loop.run_until_complete(
rl_test_inference( rl_test_inference(
prompts=function_args.get("prompts", []), num_steps=function_args.get("num_steps", 3),
max_tokens=function_args.get("max_tokens", 256), group_size=function_args.get("group_size", 16),
temperature=function_args.get("temperature", 1.0) models=function_args.get("models"),
) )
) )
elif function_name == "rl_list_runs":
return loop.run_until_complete(rl_list_runs())
return json.dumps({"error": f"Unknown RL function: {function_name}"}, ensure_ascii=False) return json.dumps({"error": f"Unknown RL function: {function_name}"}, ensure_ascii=False)
@@ -1409,7 +1407,7 @@ def handle_function_call(
"rl_get_current_config", "rl_edit_config", "rl_get_current_config", "rl_edit_config",
"rl_start_training", "rl_check_status", "rl_start_training", "rl_check_status",
"rl_stop_training", "rl_get_results", "rl_stop_training", "rl_get_results",
"rl_test_inference", "rl_list_runs" "rl_list_runs"
]: ]:
return handle_rl_function_call(function_name, function_args) return handle_rl_function_call(function_name, function_args)

View File

@@ -16,7 +16,6 @@ Usage:
Environment Variables: Environment Variables:
TINKER_API_KEY: API key for Tinker service (required) TINKER_API_KEY: API key for Tinker service (required)
WANDB_API_KEY: API key for WandB metrics (required) WANDB_API_KEY: API key for WandB metrics (required)
RL_API_URL: URL of RL API server (default: http://localhost:8080)
OPENROUTER_API_KEY: API key for OpenRouter (required for agent) OPENROUTER_API_KEY: API key for OpenRouter (required for agent)
""" """
@@ -38,7 +37,7 @@ if env_path.exists():
# Import agent and tools # Import agent and tools
from run_agent import AIAgent from run_agent import AIAgent
from model_tools import get_tool_definitions, check_toolset_requirements from model_tools import get_tool_definitions, check_toolset_requirements
from tools.rl_training_tool import check_rl_api_keys, get_missing_keys, rl_health_check from tools.rl_training_tool import check_rl_api_keys, get_missing_keys
# ============================================================================ # ============================================================================
@@ -138,17 +137,21 @@ def check_requirements():
return True return True
async def check_rl_server(): def check_tinker_atropos():
"""Check if the RL API server is running.""" """Check if tinker-atropos submodule is properly set up."""
try: tinker_path = Path(__file__).parent / "tinker-atropos"
result = await rl_health_check()
import json if not tinker_path.exists():
data = json.loads(result) return False, "tinker-atropos submodule not found. Run: git submodule update --init"
if "error" in data:
return False, data["error"] envs_path = tinker_path / "tinker_atropos" / "environments"
return True, data if not envs_path.exists():
except Exception as e: return False, f"environments directory not found at {envs_path}"
return False, str(e)
env_files = list(envs_path.glob("*.py"))
env_files = [f for f in env_files if not f.name.startswith("_")]
return True, {"path": str(tinker_path), "environments_count": len(env_files)}
def list_environments_sync(): def list_environments_sync():
@@ -210,19 +213,27 @@ def main(
print("🎯 RL Training Agent") print("🎯 RL Training Agent")
print("=" * 60) print("=" * 60)
# Handle server check # Handle setup check
if check_server: if check_server:
print("\n🔍 Checking RL API server...") print("\n🔍 Checking tinker-atropos setup...")
ok, result = asyncio.run(check_rl_server()) ok, result = check_tinker_atropos()
if ok: if ok:
print("RL API server is running") print("tinker-atropos submodule found")
print(f" Environments discovered: {result.get('environments_discovered', 'unknown')}") print(f" Path: {result.get('path')}")
print(f" Current environment: {result.get('current_environment', 'none')}") print(f" Environments found: {result.get('environments_count', 0)}")
print(f" Active runs: {result.get('active_runs', 0)}")
# Also check API keys
missing = get_missing_keys()
if missing:
print(f"\n⚠️ Missing API keys: {', '.join(missing)}")
print(" Add them to ~/.hermes/.env")
else:
print("✅ API keys configured")
else: else:
print(f"RL API server not accessible: {result}") print(f"tinker-atropos not set up: {result}")
print("\nTo start the server:") print("\nTo set up:")
print(" cd tinker-atropos && uvicorn rl_api_server:app --port 8080") print(" git submodule update --init")
print(" pip install -e ./tinker-atropos")
return return
# Handle environment listing # Handle environment listing
@@ -238,8 +249,8 @@ def main(
envs = data.get("environments", []) envs = data.get("environments", [])
if not envs: if not envs:
print("No environments found.") print("No environments found.")
print("\nMake sure the RL API server is running:") print("\nMake sure tinker-atropos is set up:")
print(" cd tinker-atropos && uvicorn rl_api_server:app --port 8080") print(" git submodule update --init")
return return
for env in envs: for env in envs:
@@ -254,8 +265,9 @@ def main(
print("\nUse `rl_select_environment(name)` to select an environment for training.") print("\nUse `rl_select_environment(name)` to select an environment for training.")
except Exception as e: except Exception as e:
print(f"❌ Error listing environments: {e}") print(f"❌ Error listing environments: {e}")
print("\nMake sure the RL API server is running:") print("\nMake sure tinker-atropos is set up:")
print(" cd tinker-atropos && uvicorn rl_api_server:app --port 8080") print(" git submodule update --init")
print(" pip install -e ./tinker-atropos")
return return
# Check requirements # Check requirements

1
tinker-atropos Submodule

Submodule tinker-atropos added at 65f084ee80

View File

@@ -105,9 +105,8 @@ from .rl_training_tool import (
rl_check_status, rl_check_status,
rl_stop_training, rl_stop_training,
rl_get_results, rl_get_results,
rl_test_inference,
rl_list_runs, rl_list_runs,
rl_health_check, rl_test_inference,
check_rl_api_keys, check_rl_api_keys,
get_missing_keys, get_missing_keys,
) )
@@ -178,9 +177,8 @@ __all__ = [
'rl_check_status', 'rl_check_status',
'rl_stop_training', 'rl_stop_training',
'rl_get_results', 'rl_get_results',
'rl_test_inference',
'rl_list_runs', 'rl_list_runs',
'rl_health_check', 'rl_test_inference',
'check_rl_api_keys', 'check_rl_api_keys',
'get_missing_keys', 'get_missing_keys',
] ]

File diff suppressed because it is too large Load Diff

View File

@@ -97,7 +97,7 @@ TOOLSETS = {
"rl_get_current_config", "rl_edit_config", "rl_get_current_config", "rl_edit_config",
"rl_start_training", "rl_check_status", "rl_start_training", "rl_check_status",
"rl_stop_training", "rl_get_results", "rl_stop_training", "rl_get_results",
"rl_test_inference", "rl_list_runs" "rl_list_runs", "rl_test_inference"
], ],
"includes": [] "includes": []
}, },