forked from Rockachopa/Timmy-time-dashboard
268 lines
10 KiB
Python
268 lines
10 KiB
Python
"""Custom model registry — register, load, and manage model weights.
|
|
|
|
Tracks custom models (GGUF files, HF checkpoints, Ollama modelfiles)
|
|
and their assignment to swarm agents. Models can be registered at
|
|
runtime via the API or pre-configured via providers.yaml.
|
|
|
|
Inspired by OpenClaw-RL's multi-model orchestration where distinct
|
|
model roles (student, teacher, judge/PRM) run on dedicated resources.
|
|
"""
|
|
|
|
import logging
|
|
import sqlite3
|
|
import threading
|
|
from collections.abc import Generator
|
|
from contextlib import closing, contextmanager
|
|
from dataclasses import dataclass
|
|
from datetime import UTC, datetime
|
|
from enum import StrEnum
|
|
from pathlib import Path
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# SQLite database file that stores the registry tables. The path is
# relative, so it resolves against the process's current working directory.
DB_PATH = Path("data/swarm.db")
|
|
|
|
|
|
class ModelFormat(StrEnum):
|
|
"""Supported model weight formats."""
|
|
|
|
GGUF = "gguf" # Ollama-compatible quantised weights
|
|
SAFETENSORS = "safetensors" # HuggingFace safetensors
|
|
HF_CHECKPOINT = "hf" # Full HuggingFace checkpoint directory
|
|
OLLAMA = "ollama" # Already loaded in Ollama by name
|
|
|
|
|
|
class ModelRole(StrEnum):
|
|
"""Role a model can play in the system (OpenClaw-RL style)."""
|
|
|
|
GENERAL = "general" # Default agent inference
|
|
REWARD = "reward" # Process Reward Model (PRM) scoring
|
|
TEACHER = "teacher" # On-policy distillation teacher
|
|
JUDGE = "judge" # Output quality evaluation
|
|
|
|
|
|
@dataclass
class CustomModel:
    """Record describing one custom model.

    ``registered_at`` holds an ISO-8601 UTC timestamp. When it is left
    empty, it is filled with the current time as soon as the instance is
    created.
    """

    name: str
    format: ModelFormat
    # Absolute path, or the model's name when it lives in Ollama
    path: str
    role: ModelRole = ModelRole.GENERAL
    context_window: int = 4096
    description: str = ""
    registered_at: str = ""
    active: bool = True
    # Generation settings kept per model
    default_temperature: float = 0.7
    max_tokens: int = 2048

    def __post_init__(self) -> None:
        """Stamp the registration time unless the caller supplied one."""
        if self.registered_at:
            return
        self.registered_at = datetime.now(UTC).isoformat()
|
|
|
|
|
|
@contextmanager
def _get_conn() -> Generator[sqlite3.Connection, None, None]:
    """Yield a SQLite connection to ``DB_PATH`` with the schema in place.

    The parent directory of ``DB_PATH`` is created if missing. Each
    connection uses ``sqlite3.Row`` rows, WAL journaling and a 5000 ms busy
    timeout, and both registry tables are created when they do not exist.
    The connection is closed when the ``with`` block exits; callers must
    commit their own writes.
    """
    models_ddl = """
        CREATE TABLE IF NOT EXISTS custom_models (
            name TEXT PRIMARY KEY,
            format TEXT NOT NULL,
            path TEXT NOT NULL,
            role TEXT NOT NULL DEFAULT 'general',
            context_window INTEGER NOT NULL DEFAULT 4096,
            description TEXT NOT NULL DEFAULT '',
            registered_at TEXT NOT NULL,
            active INTEGER NOT NULL DEFAULT 1,
            default_temperature REAL NOT NULL DEFAULT 0.7,
            max_tokens INTEGER NOT NULL DEFAULT 2048
        )
    """
    assignments_ddl = """
        CREATE TABLE IF NOT EXISTS agent_model_assignments (
            agent_id TEXT PRIMARY KEY,
            model_name TEXT NOT NULL,
            assigned_at TEXT NOT NULL,
            FOREIGN KEY (model_name) REFERENCES custom_models(name)
        )
    """
    setup_statements = (
        "PRAGMA journal_mode=WAL",
        "PRAGMA busy_timeout=5000",
        models_ddl,
        assignments_ddl,
    )

    DB_PATH.parent.mkdir(parents=True, exist_ok=True)
    connection = sqlite3.connect(str(DB_PATH))
    try:
        connection.row_factory = sqlite3.Row
        for statement in setup_statements:
            connection.execute(statement)
        connection.commit()
        yield connection
    finally:
        # Always release the handle, whether setup or the caller's block failed.
        connection.close()
|
|
|
|
|
|
class ModelRegistry:
    """Registry for custom models and agent-to-model assignments.

    State is persisted in SQLite through ``_get_conn`` and mirrored in two
    in-memory dicts for fast reads. Mutating methods hold ``self._lock``
    while they write to the database and then update the cache, so the
    cache only changes after the database write has succeeded. Read methods
    consult the cache only and do not take the lock.

    The class does not enforce a single instance; the module creates one
    shared instance for general use.
    """

    def __init__(self) -> None:
        """Create the lock and the empty caches, then load them from SQLite."""
        self._lock = threading.Lock()
        # In-memory cache for fast lookups
        self._models: dict[str, CustomModel] = {}
        self._agent_assignments: dict[str, str] = {}
        self._load_from_db()

    @staticmethod
    def _model_from_row(row: sqlite3.Row) -> CustomModel:
        """Build a ``CustomModel`` from one ``custom_models`` row.

        Raises:
            ValueError: if the stored format or role is not a known enum value.
        """
        return CustomModel(
            name=row["name"],
            format=ModelFormat(row["format"]),
            path=row["path"],
            role=ModelRole(row["role"]),
            context_window=row["context_window"],
            description=row["description"],
            registered_at=row["registered_at"],
            active=bool(row["active"]),
            default_temperature=row["default_temperature"],
            max_tokens=row["max_tokens"],
        )

    def _load_from_db(self) -> None:
        """Bootstrap the caches from SQLite (best effort).

        Inactive models are loaded as well, so that a model disabled with
        ``set_active`` can still be found and re-enabled after a restart;
        the role-based lookups filter on ``active`` themselves. A row that
        cannot be converted is skipped with a warning instead of aborting
        the whole load. Any other failure is logged and leaves the caches
        with whatever was loaded so far.
        """
        try:
            with _get_conn() as conn:
                for row in conn.execute("SELECT * FROM custom_models").fetchall():
                    try:
                        model = self._model_from_row(row)
                    except (ValueError, KeyError, IndexError) as exc:
                        logger.warning("Skipping malformed model row: %s", exc)
                        continue
                    self._models[model.name] = model
                for row in conn.execute("SELECT * FROM agent_model_assignments").fetchall():
                    self._agent_assignments[row["agent_id"]] = row["model_name"]
        except Exception as exc:
            logger.warning("Failed to load model registry from DB: %s", exc)

    # ── Model CRUD ─────────────────────────────────────────────────────────

    def register(self, model: CustomModel) -> CustomModel:
        """Register a custom model, replacing any existing one with the same name.

        Args:
            model: The model to persist and cache.

        Returns:
            The same ``model`` instance that was passed in.
        """
        with self._lock:
            with _get_conn() as conn:
                conn.execute(
                    """
                    INSERT OR REPLACE INTO custom_models
                    (name, format, path, role, context_window, description,
                    registered_at, active, default_temperature, max_tokens)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        model.name,
                        model.format.value,
                        model.path,
                        model.role.value,
                        model.context_window,
                        model.description,
                        model.registered_at,
                        int(model.active),
                        model.default_temperature,
                        model.max_tokens,
                    ),
                )
                conn.commit()
            self._models[model.name] = model
        logger.info("Registered model: %s (%s)", model.name, model.format.value)
        return model

    def unregister(self, name: str) -> bool:
        """Remove a model and every agent assignment that points at it.

        Returns:
            ``True`` if the model was removed, ``False`` if no model with
            that name is registered.
        """
        with self._lock:
            if name not in self._models:
                return False
            with _get_conn() as conn:
                # Delete the referencing rows before the row they reference.
                conn.execute("DELETE FROM agent_model_assignments WHERE model_name = ?", (name,))
                conn.execute("DELETE FROM custom_models WHERE name = ?", (name,))
                conn.commit()
            del self._models[name]
            # Remove any agent assignments using this model
            self._agent_assignments = {
                agent: assigned
                for agent, assigned in self._agent_assignments.items()
                if assigned != name
            }
        logger.info("Unregistered model: %s", name)
        return True

    def get(self, name: str) -> CustomModel | None:
        """Look up a model by name; returns ``None`` when it is not registered."""
        return self._models.get(name)

    def list_models(self, role: ModelRole | None = None) -> list[CustomModel]:
        """List all registered models, optionally filtered by role.

        Inactive models are included; callers that need only active ones
        must check ``model.active``.
        """
        if role is None:
            return list(self._models.values())
        return [m for m in self._models.values() if m.role == role]

    def set_active(self, name: str, active: bool) -> bool:
        """Enable or disable a model without removing it.

        Returns:
            ``True`` if the flag was stored, ``False`` if the model is unknown.
        """
        with self._lock:
            # Look the model up under the lock so a concurrent unregister
            # cannot remove it between the check and the update.
            model = self._models.get(name)
            if model is None:
                return False
            with _get_conn() as conn:
                conn.execute(
                    "UPDATE custom_models SET active = ? WHERE name = ?",
                    (int(active), name),
                )
                conn.commit()
            # Flip the cached flag only after the database accepted the change.
            model.active = active
        return True

    # ── Agent-model assignments ────────────────────────────────────────────

    def assign_model(self, agent_id: str, model_name: str) -> bool:
        """Assign a specific model to an agent, replacing any earlier assignment.

        Returns:
            ``True`` if the assignment was stored, ``False`` if
            ``model_name`` is not a registered model.
        """
        with self._lock:
            # Checked under the lock so the model cannot be unregistered
            # between the check and the insert.
            if model_name not in self._models:
                return False
            now = datetime.now(UTC).isoformat()
            with _get_conn() as conn:
                conn.execute(
                    """
                    INSERT OR REPLACE INTO agent_model_assignments
                    (agent_id, model_name, assigned_at)
                    VALUES (?, ?, ?)
                    """,
                    (agent_id, model_name, now),
                )
                conn.commit()
            self._agent_assignments[agent_id] = model_name
        logger.info("Assigned model %s to agent %s", model_name, agent_id)
        return True

    def unassign_model(self, agent_id: str) -> bool:
        """Remove model assignment from an agent (falls back to default).

        Returns:
            ``True`` if an assignment was removed, ``False`` if the agent
            had none.
        """
        with self._lock:
            if agent_id not in self._agent_assignments:
                return False
            with _get_conn() as conn:
                conn.execute(
                    "DELETE FROM agent_model_assignments WHERE agent_id = ?",
                    (agent_id,),
                )
                conn.commit()
            del self._agent_assignments[agent_id]
        return True

    def get_agent_model(self, agent_id: str) -> CustomModel | None:
        """Get the model assigned to an agent, or ``None`` for default."""
        model_name = self._agent_assignments.get(agent_id)
        if model_name:
            return self._models.get(model_name)
        return None

    def get_agent_assignments(self) -> dict[str, str]:
        """Return a copy of all agent-to-model assignments."""
        return dict(self._agent_assignments)

    # ── Role-based lookups ─────────────────────────────────────────────────

    def _first_active(self, role: ModelRole) -> CustomModel | None:
        """Return the first active model with ``role`` in cache order, if any."""
        return next((m for m in self.list_models(role=role) if m.active), None)

    def get_reward_model(self) -> CustomModel | None:
        """Get the active reward/PRM model, if any."""
        return self._first_active(ModelRole.REWARD)

    def get_teacher_model(self) -> CustomModel | None:
        """Get the active teacher model for distillation, if any."""
        return self._first_active(ModelRole.TEACHER)
|
|
|
|
|
|
# Module-level singleton. It is created at import time: the constructor
# calls _load_from_db(), which opens the SQLite database at DB_PATH
# (creating the directory and tables if needed) and fills the caches.
model_registry = ModelRegistry()
|