Merge: custom providers instant activation + model persistence
This commit is contained in:
@@ -777,11 +777,14 @@ def cmd_model(args):
|
||||
# Generate a stable key from the name
|
||||
key = "custom:" + name.lower().replace(" ", "-")
|
||||
short_url = base_url.replace("https://", "").replace("http://", "").rstrip("/")
|
||||
providers.append((key, f"{name} ({short_url})"))
|
||||
saved_model = entry.get("model", "")
|
||||
model_hint = f" — {saved_model}" if saved_model else ""
|
||||
providers.append((key, f"{name} ({short_url}){model_hint}"))
|
||||
_custom_provider_map[key] = {
|
||||
"name": name,
|
||||
"base_url": base_url,
|
||||
"api_key": entry.get("api_key", ""),
|
||||
"model": saved_model,
|
||||
}
|
||||
|
||||
# Always add the manual custom endpoint option last
|
||||
@@ -1099,14 +1102,15 @@ def _model_flow_custom(config):
|
||||
print("Endpoint saved. Use `/model` in chat or `hermes model` to set a model.")
|
||||
|
||||
# Auto-save to custom_providers so it appears in the menu next time
|
||||
_save_custom_provider(effective_url, effective_key)
|
||||
_save_custom_provider(effective_url, effective_key, model_name or "")
|
||||
|
||||
|
||||
def _save_custom_provider(base_url, api_key=""):
|
||||
def _save_custom_provider(base_url, api_key="", model=""):
|
||||
"""Save a custom endpoint to custom_providers in config.yaml.
|
||||
|
||||
Deduplicates by base_url — if the URL already exists, silently skips.
|
||||
Auto-generates a name from the URL hostname.
|
||||
Deduplicates by base_url — if the URL already exists, updates the
|
||||
model name but doesn't add a duplicate entry.
|
||||
Auto-generates a display name from the URL hostname.
|
||||
"""
|
||||
from hermes_cli.config import load_config, save_config
|
||||
|
||||
@@ -1115,10 +1119,14 @@ def _save_custom_provider(base_url, api_key=""):
|
||||
if not isinstance(providers, list):
|
||||
providers = []
|
||||
|
||||
# Check if this URL is already saved
|
||||
# Check if this URL is already saved — update model if so
|
||||
for entry in providers:
|
||||
if isinstance(entry, dict) and entry.get("base_url", "").rstrip("/") == base_url.rstrip("/"):
|
||||
return # already saved
|
||||
if model and entry.get("model") != model:
|
||||
entry["model"] = model
|
||||
cfg["custom_providers"] = providers
|
||||
save_config(cfg)
|
||||
return # already saved, updated model if needed
|
||||
|
||||
# Auto-generate a name from the URL
|
||||
import re
|
||||
@@ -1138,6 +1146,8 @@ def _save_custom_provider(base_url, api_key=""):
|
||||
entry = {"name": name, "base_url": base_url}
|
||||
if api_key:
|
||||
entry["api_key"] = api_key
|
||||
if model:
|
||||
entry["model"] = model
|
||||
|
||||
providers.append(entry)
|
||||
cfg["custom_providers"] = providers
|
||||
@@ -1201,7 +1211,11 @@ def _remove_custom_provider(config):
|
||||
|
||||
|
||||
def _model_flow_named_custom(config, provider_info):
|
||||
"""Handle a named custom provider from config.yaml custom_providers list."""
|
||||
"""Handle a named custom provider from config.yaml custom_providers list.
|
||||
|
||||
If the entry has a saved model name, activates it immediately.
|
||||
Otherwise probes the endpoint's /models API to let the user pick one.
|
||||
"""
|
||||
from hermes_cli.auth import _save_model_choice, deactivate_provider
|
||||
from hermes_cli.config import save_env_value, load_config, save_config
|
||||
from hermes_cli.models import fetch_api_models
|
||||
@@ -1209,18 +1223,36 @@ def _model_flow_named_custom(config, provider_info):
|
||||
name = provider_info["name"]
|
||||
base_url = provider_info["base_url"]
|
||||
api_key = provider_info.get("api_key", "")
|
||||
saved_model = provider_info.get("model", "")
|
||||
|
||||
# If a model is saved, just activate immediately — no probing needed
|
||||
if saved_model:
|
||||
save_env_value("OPENAI_BASE_URL", base_url)
|
||||
if api_key:
|
||||
save_env_value("OPENAI_API_KEY", api_key)
|
||||
_save_model_choice(saved_model)
|
||||
|
||||
cfg = load_config()
|
||||
model = cfg.get("model")
|
||||
if isinstance(model, dict):
|
||||
model["provider"] = "custom"
|
||||
model["base_url"] = base_url
|
||||
save_config(cfg)
|
||||
deactivate_provider()
|
||||
|
||||
print(f"✅ Switched to: {saved_model}")
|
||||
print(f" Provider: {name} ({base_url})")
|
||||
return
|
||||
|
||||
# No saved model — probe endpoint and let user pick
|
||||
print(f" Provider: {name}")
|
||||
print(f" URL: {base_url}")
|
||||
print()
|
||||
|
||||
# Try to fetch available models from the endpoint
|
||||
print("Fetching available models...")
|
||||
print("No model saved for this provider. Fetching available models...")
|
||||
models = fetch_api_models(api_key, base_url, timeout=8.0)
|
||||
|
||||
if models:
|
||||
print(f"Found {len(models)} model(s):\n")
|
||||
# Show model selection menu
|
||||
try:
|
||||
from simple_term_menu import TerminalMenu
|
||||
menu_items = [f" {m}" for m in models] + [" Cancel"]
|
||||
@@ -1238,7 +1270,6 @@ def _model_flow_named_custom(config, provider_info):
|
||||
return
|
||||
model_name = models[idx]
|
||||
except (ImportError, NotImplementedError):
|
||||
# Fallback: numbered list
|
||||
for i, m in enumerate(models, 1):
|
||||
print(f" {i}. {m}")
|
||||
print(f" {len(models) + 1}. Cancel")
|
||||
@@ -1267,13 +1298,12 @@ def _model_flow_named_custom(config, provider_info):
|
||||
print("No model specified. Cancelled.")
|
||||
return
|
||||
|
||||
# Save the endpoint + model
|
||||
# Activate and save the model to the custom_providers entry
|
||||
save_env_value("OPENAI_BASE_URL", base_url)
|
||||
if api_key:
|
||||
save_env_value("OPENAI_API_KEY", api_key)
|
||||
_save_model_choice(model_name)
|
||||
|
||||
# Update config
|
||||
cfg = load_config()
|
||||
model = cfg.get("model")
|
||||
if isinstance(model, dict):
|
||||
@@ -1282,6 +1312,9 @@ def _model_flow_named_custom(config, provider_info):
|
||||
save_config(cfg)
|
||||
deactivate_provider()
|
||||
|
||||
# Save model name to the custom_providers entry for next time
|
||||
_save_custom_provider(base_url, api_key, model_name)
|
||||
|
||||
print(f"\n✅ Model set to: {model_name}")
|
||||
print(f" Provider: {name} ({base_url})")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user