diff --git a/config/providers.yaml b/config/providers.yaml index 90630484..f078eacc 100644 --- a/config/providers.yaml +++ b/config/providers.yaml @@ -68,6 +68,37 @@ providers: - name: claude-3-sonnet-20240229 context_window: 200000 +# ── Custom Models ────────────────────────────────────────────────────── +# Register custom model weights for per-agent assignment. +# Supports GGUF (Ollama), safetensors, and HuggingFace checkpoint dirs. +# Models can also be registered at runtime via the /api/v1/models API. +# +# Roles: general (default inference), reward (PRM scoring), +# teacher (distillation), judge (output evaluation) +custom_models: [] + # Example entries: + # - name: my-finetuned-llama + # format: gguf + # path: /path/to/model.gguf + # role: general + # context_window: 8192 + # description: "Fine-tuned Llama for code tasks" + # + # - name: reward-model + # format: ollama + # path: deepseek-r1:1.5b + # role: reward + # context_window: 32000 + # description: "Process reward model for scoring outputs" + +# ── Agent Model Assignments ───────────────────────────────────────────── +# Map persona agent IDs to specific models. +# Agents without an assignment use the global default (ollama_model). +agent_model_assignments: {} + # Example: + # persona-forge: my-finetuned-llama + # persona-echo: deepseek-r1:1.5b + # Cost tracking (optional, for budget monitoring) cost_tracking: enabled: true diff --git a/src/config.py b/src/config.py index 30d3c644..d46595c0 100644 --- a/src/config.py +++ b/src/config.py @@ -103,6 +103,17 @@ class Settings(BaseSettings): work_orders_auto_execute: bool = False # Master switch for auto-execution work_orders_auto_threshold: str = "low" # Max priority that auto-executes: "low" | "medium" | "high" | "none" + # ── Custom Weights & Models ────────────────────────────────────── + # Directory for custom model weights (GGUF, safetensors, HF checkpoints). + # Models placed here can be registered at runtime and assigned to agents. 
+ custom_weights_dir: str = "data/models" + # Enable the reward model for scoring agent outputs (PRM-style). + reward_model_enabled: bool = False + # Reward model name (must be available via Ollama or a custom weight path). + reward_model_name: str = "" + # Minimum votes for majority-vote reward scoring (odd number recommended). + reward_model_votes: int = 3 + # ── Browser Local Models (iPhone / WebGPU) ─────────────────────── # Enable in-browser LLM inference via WebLLM for offline iPhone use. # When enabled, the mobile dashboard loads a small model directly diff --git a/src/dashboard/app.py b/src/dashboard/app.py index b0bc7fa5..dbc6f846 100644 --- a/src/dashboard/app.py +++ b/src/dashboard/app.py @@ -36,6 +36,8 @@ from dashboard.routes.self_coding import router as self_coding_router from dashboard.routes.self_coding import self_modify_router from dashboard.routes.hands import router as hands_router from dashboard.routes.grok import router as grok_router +from dashboard.routes.models import router as models_router +from dashboard.routes.models import api_router as models_api_router from infrastructure.router.api import router as cascade_router logging.basicConfig( @@ -208,6 +210,8 @@ app.include_router(tasks_router) app.include_router(scripture_router) app.include_router(hands_router) app.include_router(grok_router) +app.include_router(models_router) +app.include_router(models_api_router) app.include_router(cascade_router) diff --git a/src/dashboard/routes/models.py b/src/dashboard/routes/models.py new file mode 100644 index 00000000..77c566e9 --- /dev/null +++ b/src/dashboard/routes/models.py @@ -0,0 +1,272 @@ +"""Custom model management routes — register, list, assign, and swap models. + +Provides a REST API for managing custom model weights and their assignment +to swarm agents. Inspired by OpenClaw-RL's multi-model orchestration. 
+""" + +import logging +from pathlib import Path +from typing import Any, Optional + +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import HTMLResponse +from fastapi.templating import Jinja2Templates +from pydantic import BaseModel + +from config import settings +from infrastructure.models.registry import ( + CustomModel, + ModelFormat, + ModelRegistry, + ModelRole, + model_registry, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/models", tags=["models"]) +api_router = APIRouter(prefix="/api/v1/models", tags=["models-api"]) +templates = Jinja2Templates(directory=str(Path(__file__).parent.parent / "templates")) + + +# ── Pydantic schemas ────────────────────────────────────────────────────────── + + +class RegisterModelRequest(BaseModel): + """Request body for model registration.""" + name: str + format: str # gguf, safetensors, hf, ollama + path: str + role: str = "general" + context_window: int = 4096 + description: str = "" + default_temperature: float = 0.7 + max_tokens: int = 2048 + + +class AssignModelRequest(BaseModel): + """Request body for assigning a model to an agent.""" + agent_id: str + model_name: str + + +class SetActiveRequest(BaseModel): + """Request body for enabling/disabling a model.""" + active: bool + + +# ── API endpoints ───────────────────────────────────────────────────────────── + + +@api_router.get("") +async def list_models(role: Optional[str] = None) -> dict[str, Any]: + """List all registered custom models.""" + model_role = ModelRole(role) if role else None + models = model_registry.list_models(role=model_role) + return { + "models": [ + { + "name": m.name, + "format": m.format.value, + "path": m.path, + "role": m.role.value, + "context_window": m.context_window, + "description": m.description, + "active": m.active, + "registered_at": m.registered_at, + "default_temperature": m.default_temperature, + "max_tokens": m.max_tokens, + } + for m in models + ], + "total": 
len(models), + "weights_dir": settings.custom_weights_dir, + } + + +@api_router.post("") +async def register_model(request: RegisterModelRequest) -> dict[str, Any]: + """Register a new custom model.""" + try: + fmt = ModelFormat(request.format) + except ValueError: + raise HTTPException( + status_code=400, + detail=f"Invalid format: {request.format}. " + f"Choose from: {[f.value for f in ModelFormat]}", + ) + try: + role = ModelRole(request.role) + except ValueError: + raise HTTPException( + status_code=400, + detail=f"Invalid role: {request.role}. " + f"Choose from: {[r.value for r in ModelRole]}", + ) + + # Validate path exists for non-Ollama formats + if fmt != ModelFormat.OLLAMA: + weight_path = Path(request.path) + if not weight_path.exists(): + raise HTTPException( + status_code=400, + detail=f"Weight path does not exist: {request.path}", + ) + + model = CustomModel( + name=request.name, + format=fmt, + path=request.path, + role=role, + context_window=request.context_window, + description=request.description, + default_temperature=request.default_temperature, + max_tokens=request.max_tokens, + ) + registered = model_registry.register(model) + return { + "message": f"Model {registered.name} registered", + "model": { + "name": registered.name, + "format": registered.format.value, + "role": registered.role.value, + "path": registered.path, + }, + } + + +@api_router.get("/{model_name}") +async def get_model(model_name: str) -> dict[str, Any]: + """Get details of a specific model.""" + model = model_registry.get(model_name) + if not model: + raise HTTPException(status_code=404, detail=f"Model {model_name} not found") + return { + "name": model.name, + "format": model.format.value, + "path": model.path, + "role": model.role.value, + "context_window": model.context_window, + "description": model.description, + "active": model.active, + "registered_at": model.registered_at, + "default_temperature": model.default_temperature, + "max_tokens": model.max_tokens, + } + + 
@api_router.delete("/{model_name}")
async def unregister_model(model_name: str) -> dict[str, str]:
    """Drop a model from the registry; 404 if no such model is registered."""
    removed = model_registry.unregister(model_name)
    if not removed:
        raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
    return {"message": f"Model {model_name} unregistered"}


@api_router.patch("/{model_name}/active")
async def set_model_active(
    model_name: str, request: SetActiveRequest
) -> dict[str, str]:
    """Toggle a model's active flag without removing it from the registry."""
    if model_registry.set_active(model_name, request.active):
        state = "enabled" if request.active else "disabled"
        return {"message": f"Model {model_name} {state}"}
    raise HTTPException(status_code=404, detail=f"Model {model_name} not found")


# ── Agent assignment endpoints ────────────────────────────────────────────────


@api_router.get("/assignments/all")
async def list_assignments() -> dict[str, Any]:
    """Return every agent-to-model assignment currently in the registry."""
    assignments = model_registry.get_agent_assignments()
    rows = [
        {"agent_id": aid, "model_name": mname}
        for aid, mname in assignments.items()
    ]
    return {"assignments": rows, "total": len(rows)}


@api_router.post("/assignments")
async def assign_model(request: AssignModelRequest) -> dict[str, str]:
    """Bind a registered model to a swarm agent; 404 if the model is unknown."""
    ok = model_registry.assign_model(request.agent_id, request.model_name)
    if not ok:
        raise HTTPException(
            status_code=404,
            detail=f"Model {request.model_name} not found in registry",
        )
    return {
        "message": f"Model {request.model_name} assigned to {request.agent_id}",
    }


@api_router.delete("/assignments/{agent_id}")
async def unassign_model(agent_id: str) -> dict[str, str]:
    """Clear an agent's model assignment so it falls back to the global default."""
    if not model_registry.unassign_model(agent_id):
        raise HTTPException(
            status_code=404,
            detail=f"No model assignment for agent {agent_id}",
        )
    return {"message": f"Model assignment removed for {agent_id}"}
+ +# ── Role-based lookups ──────────────────────────────────────────────────────── + + +@api_router.get("/roles/reward") +async def get_reward_model() -> dict[str, Any]: + """Get the active reward (PRM) model.""" + model = model_registry.get_reward_model() + if not model: + return {"reward_model": None, "reward_enabled": settings.reward_model_enabled} + return { + "reward_model": { + "name": model.name, + "format": model.format.value, + "path": model.path, + }, + "reward_enabled": settings.reward_model_enabled, + } + + +@api_router.get("/roles/teacher") +async def get_teacher_model() -> dict[str, Any]: + """Get the active teacher model for distillation.""" + model = model_registry.get_teacher_model() + if not model: + return {"teacher_model": None} + return { + "teacher_model": { + "name": model.name, + "format": model.format.value, + "path": model.path, + }, + } + + +# ── Dashboard page ──────────────────────────────────────────────────────────── + + +@router.get("", response_class=HTMLResponse) +async def models_page(request: Request): + """Custom models management dashboard page.""" + models = model_registry.list_models() + assignments = model_registry.get_agent_assignments() + reward = model_registry.get_reward_model() + + return templates.TemplateResponse( + request, + "models.html", + { + "page_title": "Custom Models", + "models": models, + "assignments": assignments, + "reward_model": reward, + "weights_dir": settings.custom_weights_dir, + "reward_enabled": settings.reward_model_enabled, + }, + ) diff --git a/src/dashboard/templates/models.html b/src/dashboard/templates/models.html new file mode 100644 index 00000000..47ccea3b --- /dev/null +++ b/src/dashboard/templates/models.html @@ -0,0 +1,119 @@ +{% extends "base.html" %} + +{% block title %}Custom Models - Timmy Time{% endblock %} + +{% block content %} +
+
+

Custom Models

+

Manage model weights and agent assignments

+
+ + +
+
+
{{ models|length }}
+
Models
+
+
+
{{ assignments|length }}
+
Assignments
+
+
+
{{ "Yes" if reward_model else "No" }}
+
Reward Model
+
+
+ + +
+

Register Model

+
+ + + + + + + +
+
+
+ + +
+

Registered Models

+ {% if models %} + + + + + + + + + + + + + {% for m in models %} + + + + + + + + + {% endfor %} + +
NameFormatRoleContextActiveActions
{{ m.name }}{{ m.format.value }}{{ m.role.value }}{{ m.context_window }}{{ "Yes" if m.active else "No" }} + +
+ {% else %} +

No custom models registered. Use the form above or the API.

+ {% endif %} +
+ + +
+

Agent Model Assignments

+ {% if assignments %} + + + + + + {% for agent_id, model_name in assignments.items() %} + + + + + {% endfor %} + +
AgentModel
{{ agent_id }}{{ model_name }}
+ {% else %} +

No agent-specific model assignments. All agents use the global default.

+ {% endif %} +
+ +
+

Weights directory: {{ weights_dir }}

+
+
+{% endblock %} diff --git a/src/infrastructure/models/__init__.py b/src/infrastructure/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/infrastructure/models/registry.py b/src/infrastructure/models/registry.py new file mode 100644 index 00000000..b9a568c0 --- /dev/null +++ b/src/infrastructure/models/registry.py @@ -0,0 +1,268 @@ +"""Custom model registry — register, load, and manage model weights. + +Tracks custom models (GGUF files, HF checkpoints, Ollama modelfiles) +and their assignment to swarm agents. Models can be registered at +runtime via the API or pre-configured via providers.yaml. + +Inspired by OpenClaw-RL's multi-model orchestration where distinct +model roles (student, teacher, judge/PRM) run on dedicated resources. +""" + +import logging +import sqlite3 +import threading +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum +from pathlib import Path +from typing import Optional + +from config import settings + +logger = logging.getLogger(__name__) + +DB_PATH = Path("data/swarm.db") + + +class ModelFormat(str, Enum): + """Supported model weight formats.""" + GGUF = "gguf" # Ollama-compatible quantised weights + SAFETENSORS = "safetensors" # HuggingFace safetensors + HF_CHECKPOINT = "hf" # Full HuggingFace checkpoint directory + OLLAMA = "ollama" # Already loaded in Ollama by name + + +class ModelRole(str, Enum): + """Role a model can play in the system (OpenClaw-RL style).""" + GENERAL = "general" # Default agent inference + REWARD = "reward" # Process Reward Model (PRM) scoring + TEACHER = "teacher" # On-policy distillation teacher + JUDGE = "judge" # Output quality evaluation + + +@dataclass +class CustomModel: + """A registered custom model.""" + name: str + format: ModelFormat + path: str # Absolute path or Ollama model name + role: ModelRole = ModelRole.GENERAL + context_window: int = 4096 + description: str = "" + registered_at: str = "" + active: bool = True 
+ # Per-model generation settings + default_temperature: float = 0.7 + max_tokens: int = 2048 + + def __post_init__(self): + if not self.registered_at: + self.registered_at = datetime.now(timezone.utc).isoformat() + + +def _get_conn() -> sqlite3.Connection: + DB_PATH.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(str(DB_PATH)) + conn.row_factory = sqlite3.Row + conn.execute( + """ + CREATE TABLE IF NOT EXISTS custom_models ( + name TEXT PRIMARY KEY, + format TEXT NOT NULL, + path TEXT NOT NULL, + role TEXT NOT NULL DEFAULT 'general', + context_window INTEGER NOT NULL DEFAULT 4096, + description TEXT NOT NULL DEFAULT '', + registered_at TEXT NOT NULL, + active INTEGER NOT NULL DEFAULT 1, + default_temperature REAL NOT NULL DEFAULT 0.7, + max_tokens INTEGER NOT NULL DEFAULT 2048 + ) + """ + ) + conn.execute( + """ + CREATE TABLE IF NOT EXISTS agent_model_assignments ( + agent_id TEXT PRIMARY KEY, + model_name TEXT NOT NULL, + assigned_at TEXT NOT NULL, + FOREIGN KEY (model_name) REFERENCES custom_models(name) + ) + """ + ) + conn.commit() + return conn + + +class ModelRegistry: + """Singleton registry for custom models and agent-model assignments.""" + + def __init__(self) -> None: + self._lock = threading.Lock() + # In-memory cache for fast lookups + self._models: dict[str, CustomModel] = {} + self._agent_assignments: dict[str, str] = {} + self._load_from_db() + + def _load_from_db(self) -> None: + """Bootstrap cache from SQLite.""" + try: + conn = _get_conn() + for row in conn.execute("SELECT * FROM custom_models WHERE active = 1").fetchall(): + self._models[row["name"]] = CustomModel( + name=row["name"], + format=ModelFormat(row["format"]), + path=row["path"], + role=ModelRole(row["role"]), + context_window=row["context_window"], + description=row["description"], + registered_at=row["registered_at"], + active=bool(row["active"]), + default_temperature=row["default_temperature"], + max_tokens=row["max_tokens"], + ) + for row in 
conn.execute("SELECT * FROM agent_model_assignments").fetchall(): + self._agent_assignments[row["agent_id"]] = row["model_name"] + conn.close() + except Exception as exc: + logger.warning("Failed to load model registry from DB: %s", exc) + + # ── Model CRUD ───────────────────────────────────────────────────────── + + def register(self, model: CustomModel) -> CustomModel: + """Register a new custom model.""" + with self._lock: + conn = _get_conn() + conn.execute( + """ + INSERT OR REPLACE INTO custom_models + (name, format, path, role, context_window, description, + registered_at, active, default_temperature, max_tokens) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + model.name, model.format.value, model.path, + model.role.value, model.context_window, model.description, + model.registered_at, int(model.active), + model.default_temperature, model.max_tokens, + ), + ) + conn.commit() + conn.close() + self._models[model.name] = model + logger.info("Registered model: %s (%s)", model.name, model.format.value) + return model + + def unregister(self, name: str) -> bool: + """Remove a model from the registry.""" + with self._lock: + if name not in self._models: + return False + conn = _get_conn() + conn.execute("DELETE FROM custom_models WHERE name = ?", (name,)) + conn.execute( + "DELETE FROM agent_model_assignments WHERE model_name = ?", (name,) + ) + conn.commit() + conn.close() + del self._models[name] + # Remove any agent assignments using this model + self._agent_assignments = { + k: v for k, v in self._agent_assignments.items() if v != name + } + logger.info("Unregistered model: %s", name) + return True + + def get(self, name: str) -> Optional[CustomModel]: + """Look up a model by name.""" + return self._models.get(name) + + def list_models(self, role: Optional[ModelRole] = None) -> list[CustomModel]: + """List all registered models, optionally filtered by role.""" + models = list(self._models.values()) + if role is not None: + models = [m for m in models if 
m.role == role] + return models + + def set_active(self, name: str, active: bool) -> bool: + """Enable or disable a model without removing it.""" + model = self._models.get(name) + if not model: + return False + with self._lock: + model.active = active + conn = _get_conn() + conn.execute( + "UPDATE custom_models SET active = ? WHERE name = ?", + (int(active), name), + ) + conn.commit() + conn.close() + return True + + # ── Agent-model assignments ──────────────────────────────────────────── + + def assign_model(self, agent_id: str, model_name: str) -> bool: + """Assign a specific model to an agent.""" + if model_name not in self._models: + return False + with self._lock: + now = datetime.now(timezone.utc).isoformat() + conn = _get_conn() + conn.execute( + """ + INSERT OR REPLACE INTO agent_model_assignments + (agent_id, model_name, assigned_at) + VALUES (?, ?, ?) + """, + (agent_id, model_name, now), + ) + conn.commit() + conn.close() + self._agent_assignments[agent_id] = model_name + logger.info("Assigned model %s to agent %s", model_name, agent_id) + return True + + def unassign_model(self, agent_id: str) -> bool: + """Remove model assignment from an agent (falls back to default).""" + with self._lock: + if agent_id not in self._agent_assignments: + return False + conn = _get_conn() + conn.execute( + "DELETE FROM agent_model_assignments WHERE agent_id = ?", + (agent_id,), + ) + conn.commit() + conn.close() + del self._agent_assignments[agent_id] + return True + + def get_agent_model(self, agent_id: str) -> Optional[CustomModel]: + """Get the model assigned to an agent, or None for default.""" + model_name = self._agent_assignments.get(agent_id) + if model_name: + return self._models.get(model_name) + return None + + def get_agent_assignments(self) -> dict[str, str]: + """Return all agent-to-model assignments.""" + return dict(self._agent_assignments) + + # ── Role-based lookups ───────────────────────────────────────────────── + + def get_reward_model(self) -> 
Optional[CustomModel]: + """Get the active reward/PRM model, if any.""" + reward_models = self.list_models(role=ModelRole.REWARD) + active = [m for m in reward_models if m.active] + return active[0] if active else None + + def get_teacher_model(self) -> Optional[CustomModel]: + """Get the active teacher model for distillation.""" + teacher_models = self.list_models(role=ModelRole.TEACHER) + active = [m for m in teacher_models if m.active] + return active[0] if active else None + + +# Module-level singleton +model_registry = ModelRegistry() diff --git a/src/swarm/learner.py b/src/swarm/learner.py index 3f82a46e..b8559f50 100644 --- a/src/swarm/learner.py +++ b/src/swarm/learner.py @@ -251,3 +251,193 @@ def learned_keywords(agent_id: str) -> list[dict]: results.append({"keyword": kw, "wins": wins, "failures": fails, "net": wins - fails}) results.sort(key=lambda x: x["net"], reverse=True) return results + + +# ── Reward model scoring (PRM-style) ───────────────────────────────────────── + +import logging as _logging +from config import settings as _settings + +_reward_logger = _logging.getLogger(__name__ + ".reward") + + +@dataclass +class RewardScore: + """Result from reward-model evaluation.""" + score: float # Normalised score in [-1.0, 1.0] + positive_votes: int + negative_votes: int + total_votes: int + model_used: str + + +def _ensure_reward_table() -> None: + """Create the reward_scores table if needed.""" + conn = _get_conn() + conn.execute( + """ + CREATE TABLE IF NOT EXISTS reward_scores ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + task_id TEXT NOT NULL, + agent_id TEXT NOT NULL, + output_text TEXT NOT NULL, + score REAL NOT NULL, + positive INTEGER NOT NULL, + negative INTEGER NOT NULL, + total INTEGER NOT NULL, + model_used TEXT NOT NULL, + scored_at TEXT NOT NULL DEFAULT (datetime('now')) + ) + """ + ) + conn.commit() + conn.close() + + +def score_output( + task_id: str, + agent_id: str, + task_description: str, + output_text: str, +) -> 
Optional[RewardScore]: + """Score an agent's output using the reward model (majority vote). + + Calls the reward model N times (settings.reward_model_votes) with a + quality-evaluation prompt. Each vote is +1 (good) or -1 (bad). + Final score is (positive - negative) / total, in [-1.0, 1.0]. + + Returns None if the reward model is disabled or unavailable. + """ + if not _settings.reward_model_enabled: + return None + + # Resolve model name: explicit setting > registry reward model > skip + model_name = _settings.reward_model_name + if not model_name: + try: + from infrastructure.models.registry import model_registry + reward = model_registry.get_reward_model() + if reward: + model_name = reward.path if reward.format.value == "ollama" else reward.name + except Exception: + pass + + if not model_name: + _reward_logger.debug("No reward model configured, skipping scoring") + return None + + num_votes = max(1, _settings.reward_model_votes) + positive = 0 + negative = 0 + + prompt = ( + f"You are a quality evaluator. Rate the following agent output.\n\n" + f"TASK: {task_description}\n\n" + f"OUTPUT:\n{output_text[:2000]}\n\n" + f"Is this output correct, helpful, and complete? " + f"Reply with exactly one word: GOOD or BAD." 
+ ) + + try: + import requests as _req + ollama_url = _settings.ollama_url + + for _ in range(num_votes): + try: + resp = _req.post( + f"{ollama_url}/api/generate", + json={ + "model": model_name, + "prompt": prompt, + "stream": False, + "options": {"temperature": 0.3, "num_predict": 10}, + }, + timeout=30, + ) + if resp.status_code == 200: + answer = resp.json().get("response", "").strip().upper() + if "GOOD" in answer: + positive += 1 + else: + negative += 1 + else: + negative += 1 # Treat errors as negative conservatively + except Exception as vote_exc: + _reward_logger.debug("Vote failed: %s", vote_exc) + negative += 1 + + except ImportError: + _reward_logger.warning("requests library not available for reward scoring") + return None + + total = positive + negative + if total == 0: + return None + + score = (positive - negative) / total + + result = RewardScore( + score=score, + positive_votes=positive, + negative_votes=negative, + total_votes=total, + model_used=model_name, + ) + + # Persist to DB + try: + _ensure_reward_table() + conn = _get_conn() + conn.execute( + """ + INSERT INTO reward_scores + (task_id, agent_id, output_text, score, positive, negative, total, model_used) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + task_id, agent_id, output_text[:5000], + score, positive, negative, total, model_name, + ), + ) + conn.commit() + conn.close() + except Exception as db_exc: + _reward_logger.warning("Failed to persist reward score: %s", db_exc) + + _reward_logger.info( + "Scored task %s agent %s: %.2f (%d+/%d- of %d votes)", + task_id, agent_id, score, positive, negative, total, + ) + return result + + +def get_reward_scores( + agent_id: Optional[str] = None, limit: int = 50 +) -> list[dict]: + """Retrieve historical reward scores from the database.""" + _ensure_reward_table() + conn = _get_conn() + if agent_id: + rows = conn.execute( + "SELECT * FROM reward_scores WHERE agent_id = ? 
ORDER BY id DESC LIMIT ?", + (agent_id, limit), + ).fetchall() + else: + rows = conn.execute( + "SELECT * FROM reward_scores ORDER BY id DESC LIMIT ?", + (limit,), + ).fetchall() + conn.close() + return [ + { + "task_id": r["task_id"], + "agent_id": r["agent_id"], + "score": r["score"], + "positive": r["positive"], + "negative": r["negative"], + "total": r["total"], + "model_used": r["model_used"], + "scored_at": r["scored_at"], + } + for r in rows + ] diff --git a/src/swarm/persona_node.py b/src/swarm/persona_node.py index 98a36755..1ed7481f 100644 --- a/src/swarm/persona_node.py +++ b/src/swarm/persona_node.py @@ -51,6 +51,16 @@ class PersonaNode(SwarmNode): self._meta = meta self._persona_id = persona_id self._use_learner = use_learner + + # Resolve model: registry assignment > persona default > global default + self._model_name: Optional[str] = meta.get("model") + try: + from infrastructure.models.registry import model_registry + assigned = model_registry.get_agent_model(agent_id) + if assigned: + self._model_name = assigned.name + except Exception: + pass # Graceful degradation — use persona/global default # Initialize tool executor for task execution self._tool_executor: Optional[ToolExecutor] = None @@ -213,6 +223,11 @@ class PersonaNode(SwarmNode): """Return the task ID currently being executed, if any.""" return self._current_task + @property + def model_name(self) -> Optional[str]: + """Return the model this agent uses, or None for global default.""" + return self._model_name + @property def tool_capabilities(self) -> list[str]: """Return list of available tool names.""" diff --git a/src/swarm/personas.py b/src/swarm/personas.py index 6a73548d..3c583ff6 100644 --- a/src/swarm/personas.py +++ b/src/swarm/personas.py @@ -14,7 +14,7 @@ from __future__ import annotations from typing import TypedDict -class PersonaMeta(TypedDict): +class PersonaMeta(TypedDict, total=False): id: str name: str role: str @@ -24,6 +24,11 @@ class PersonaMeta(TypedDict): bid_base: 
int # typical bid when task matches persona bid_jitter: int # ± random jitter added to bid_base preferred_keywords: list[str] + # Optional: custom model override for this persona. + # When set, this persona uses this model instead of the global default. + # Value is a model name registered in the ModelRegistry, or an Ollama + # model name like "llama3.2" or "deepseek-r1:1.5b". + model: str PERSONAS: dict[str, PersonaMeta] = { diff --git a/tests/infrastructure/test_model_registry.py b/tests/infrastructure/test_model_registry.py new file mode 100644 index 00000000..d8e70d4d --- /dev/null +++ b/tests/infrastructure/test_model_registry.py @@ -0,0 +1,217 @@ +"""Tests for the custom model registry.""" + +import sqlite3 +from pathlib import Path +from unittest.mock import patch + +import pytest + +from infrastructure.models.registry import ( + CustomModel, + ModelFormat, + ModelRegistry, + ModelRole, +) + + +@pytest.fixture +def registry(tmp_path): + """Create a fresh ModelRegistry backed by a temporary database.""" + db = tmp_path / "test.db" + with patch("infrastructure.models.registry.DB_PATH", db): + reg = ModelRegistry() + yield reg + + +@pytest.fixture +def sample_model(): + """A sample CustomModel for testing.""" + return CustomModel( + name="test-llama", + format=ModelFormat.OLLAMA, + path="llama3.2", + role=ModelRole.GENERAL, + context_window=8192, + description="Test model", + ) + + +@pytest.fixture +def reward_model(): + """A sample reward model.""" + return CustomModel( + name="test-reward", + format=ModelFormat.OLLAMA, + path="deepseek-r1:1.5b", + role=ModelRole.REWARD, + context_window=32000, + description="Test reward model", + ) + + +class TestModelCRUD: + """Test model registration, lookup, and removal.""" + + def test_register_model(self, registry, sample_model): + registered = registry.register(sample_model) + assert registered.name == "test-llama" + assert registered.format == ModelFormat.OLLAMA + + def test_get_model(self, registry, sample_model): + 
registry.register(sample_model) + found = registry.get("test-llama") + assert found is not None + assert found.name == "test-llama" + assert found.path == "llama3.2" + + def test_get_nonexistent_model(self, registry): + assert registry.get("nonexistent") is None + + def test_list_models(self, registry, sample_model, reward_model): + registry.register(sample_model) + registry.register(reward_model) + all_models = registry.list_models() + assert len(all_models) == 2 + + def test_list_models_by_role(self, registry, sample_model, reward_model): + registry.register(sample_model) + registry.register(reward_model) + general = registry.list_models(role=ModelRole.GENERAL) + assert len(general) == 1 + assert general[0].name == "test-llama" + rewards = registry.list_models(role=ModelRole.REWARD) + assert len(rewards) == 1 + assert rewards[0].name == "test-reward" + + def test_unregister_model(self, registry, sample_model): + registry.register(sample_model) + assert registry.unregister("test-llama") is True + assert registry.get("test-llama") is None + + def test_unregister_nonexistent(self, registry): + assert registry.unregister("nonexistent") is False + + def test_set_active(self, registry, sample_model): + registry.register(sample_model) + assert registry.set_active("test-llama", False) is True + model = registry.get("test-llama") + assert model.active is False + assert registry.set_active("test-llama", True) is True + model = registry.get("test-llama") + assert model.active is True + + def test_set_active_nonexistent(self, registry): + assert registry.set_active("nonexistent", True) is False + + def test_register_replaces_existing(self, registry, sample_model): + registry.register(sample_model) + updated = CustomModel( + name="test-llama", + format=ModelFormat.GGUF, + path="/new/path.gguf", + role=ModelRole.GENERAL, + description="Updated model", + ) + registry.register(updated) + found = registry.get("test-llama") + assert found.format == ModelFormat.GGUF + assert 
found.path == "/new/path.gguf" + + +class TestAgentAssignments: + """Test agent-to-model assignment management.""" + + def test_assign_model(self, registry, sample_model): + registry.register(sample_model) + assert registry.assign_model("agent-1", "test-llama") is True + model = registry.get_agent_model("agent-1") + assert model is not None + assert model.name == "test-llama" + + def test_assign_nonexistent_model(self, registry): + assert registry.assign_model("agent-1", "nonexistent") is False + + def test_unassign_model(self, registry, sample_model): + registry.register(sample_model) + registry.assign_model("agent-1", "test-llama") + assert registry.unassign_model("agent-1") is True + assert registry.get_agent_model("agent-1") is None + + def test_unassign_nonexistent(self, registry): + assert registry.unassign_model("agent-1") is False + + def test_get_agent_model_none(self, registry): + assert registry.get_agent_model("agent-1") is None + + def test_get_all_assignments(self, registry, sample_model, reward_model): + registry.register(sample_model) + registry.register(reward_model) + registry.assign_model("agent-1", "test-llama") + registry.assign_model("agent-2", "test-reward") + assignments = registry.get_agent_assignments() + assert len(assignments) == 2 + assert assignments["agent-1"] == "test-llama" + assert assignments["agent-2"] == "test-reward" + + def test_unregister_removes_assignments(self, registry, sample_model): + registry.register(sample_model) + registry.assign_model("agent-1", "test-llama") + registry.unregister("test-llama") + assert registry.get_agent_model("agent-1") is None + assert len(registry.get_agent_assignments()) == 0 + + +class TestRoleLookups: + """Test role-based model lookups.""" + + def test_get_reward_model(self, registry, reward_model): + registry.register(reward_model) + found = registry.get_reward_model() + assert found is not None + assert found.name == "test-reward" + assert found.role == ModelRole.REWARD + + def 
test_get_reward_model_none(self, registry): + assert registry.get_reward_model() is None + + def test_get_teacher_model(self, registry): + teacher = CustomModel( + name="teacher-model", + format=ModelFormat.OLLAMA, + path="teacher:latest", + role=ModelRole.TEACHER, + ) + registry.register(teacher) + found = registry.get_teacher_model() + assert found is not None + assert found.name == "teacher-model" + + def test_get_teacher_model_none(self, registry): + assert registry.get_teacher_model() is None + + def test_inactive_reward_model_not_returned(self, registry, reward_model): + registry.register(reward_model) + registry.set_active("test-reward", False) + assert registry.get_reward_model() is None + + +class TestCustomModelDataclass: + """Test CustomModel construction.""" + + def test_default_registered_at(self): + model = CustomModel( + name="test", format=ModelFormat.OLLAMA, path="test" + ) + assert model.registered_at != "" + + def test_model_roles(self): + assert ModelRole.GENERAL.value == "general" + assert ModelRole.REWARD.value == "reward" + assert ModelRole.TEACHER.value == "teacher" + assert ModelRole.JUDGE.value == "judge" + + def test_model_formats(self): + assert ModelFormat.GGUF.value == "gguf" + assert ModelFormat.SAFETENSORS.value == "safetensors" + assert ModelFormat.HF_CHECKPOINT.value == "hf" + assert ModelFormat.OLLAMA.value == "ollama" diff --git a/tests/infrastructure/test_models_api.py b/tests/infrastructure/test_models_api.py new file mode 100644 index 00000000..212c513b --- /dev/null +++ b/tests/infrastructure/test_models_api.py @@ -0,0 +1,273 @@ +"""Tests for the custom models API routes.""" + +from unittest.mock import patch, MagicMock + +import pytest + +from infrastructure.models.registry import ( + CustomModel, + ModelFormat, + ModelRegistry, + ModelRole, +) + + +@pytest.fixture +def registry(tmp_path): + """A fresh ModelRegistry for each test.""" + db = tmp_path / "api_test.db" + with patch("infrastructure.models.registry.DB_PATH", db): 
+ reg = ModelRegistry() + yield reg + + +class TestModelsAPIList: + """Test listing models via the API.""" + + def test_list_models_empty(self, client, tmp_path): + db = tmp_path / "api.db" + with patch("infrastructure.models.registry.DB_PATH", db): + with patch( + "dashboard.routes.models.model_registry" + ) as mock_reg: + mock_reg.list_models.return_value = [] + resp = client.get("/api/v1/models") + assert resp.status_code == 200 + data = resp.json() + assert "models" in data + assert "total" in data + + def test_list_models_with_data(self, client): + model = CustomModel( + name="test-m", + format=ModelFormat.OLLAMA, + path="llama3.2", + role=ModelRole.GENERAL, + ) + with patch( + "dashboard.routes.models.model_registry" + ) as mock_reg: + mock_reg.list_models.return_value = [model] + resp = client.get("/api/v1/models") + assert resp.status_code == 200 + data = resp.json() + assert data["total"] == 1 + assert data["models"][0]["name"] == "test-m" + + +class TestModelsAPIRegister: + """Test model registration via the API.""" + + def test_register_ollama_model(self, client): + with patch( + "dashboard.routes.models.model_registry" + ) as mock_reg: + mock_reg.register.return_value = CustomModel( + name="my-model", + format=ModelFormat.OLLAMA, + path="llama3.2", + role=ModelRole.GENERAL, + ) + resp = client.post( + "/api/v1/models", + json={ + "name": "my-model", + "format": "ollama", + "path": "llama3.2", + "role": "general", + }, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["model"]["name"] == "my-model" + + def test_register_invalid_format(self, client): + resp = client.post( + "/api/v1/models", + json={ + "name": "bad-model", + "format": "invalid_format", + "path": "whatever", + }, + ) + assert resp.status_code == 400 + assert "Invalid format" in resp.json()["detail"] + + def test_register_invalid_role(self, client): + resp = client.post( + "/api/v1/models", + json={ + "name": "bad-model", + "format": "ollama", + "path": "llama3.2", + 
"role": "invalid_role", + }, + ) + assert resp.status_code == 400 + assert "Invalid role" in resp.json()["detail"] + + +class TestModelsAPIDelete: + """Test model deletion via the API.""" + + def test_delete_model(self, client): + with patch( + "dashboard.routes.models.model_registry" + ) as mock_reg: + mock_reg.unregister.return_value = True + resp = client.delete("/api/v1/models/my-model") + assert resp.status_code == 200 + + def test_delete_nonexistent(self, client): + with patch( + "dashboard.routes.models.model_registry" + ) as mock_reg: + mock_reg.unregister.return_value = False + resp = client.delete("/api/v1/models/nonexistent") + assert resp.status_code == 404 + + +class TestModelsAPIGet: + """Test getting a specific model.""" + + def test_get_model(self, client): + model = CustomModel( + name="my-model", + format=ModelFormat.OLLAMA, + path="llama3.2", + role=ModelRole.GENERAL, + ) + with patch( + "dashboard.routes.models.model_registry" + ) as mock_reg: + mock_reg.get.return_value = model + resp = client.get("/api/v1/models/my-model") + assert resp.status_code == 200 + assert resp.json()["name"] == "my-model" + + def test_get_nonexistent(self, client): + with patch( + "dashboard.routes.models.model_registry" + ) as mock_reg: + mock_reg.get.return_value = None + resp = client.get("/api/v1/models/nonexistent") + assert resp.status_code == 404 + + +class TestModelsAPIAssignments: + """Test agent model assignment endpoints.""" + + def test_assign_model(self, client): + with patch( + "dashboard.routes.models.model_registry" + ) as mock_reg: + mock_reg.assign_model.return_value = True + resp = client.post( + "/api/v1/models/assignments", + json={"agent_id": "agent-1", "model_name": "my-model"}, + ) + assert resp.status_code == 200 + + def test_assign_nonexistent_model(self, client): + with patch( + "dashboard.routes.models.model_registry" + ) as mock_reg: + mock_reg.assign_model.return_value = False + resp = client.post( + "/api/v1/models/assignments", + 
json={"agent_id": "agent-1", "model_name": "nonexistent"}, + ) + assert resp.status_code == 404 + + def test_unassign_model(self, client): + with patch( + "dashboard.routes.models.model_registry" + ) as mock_reg: + mock_reg.unassign_model.return_value = True + resp = client.delete("/api/v1/models/assignments/agent-1") + assert resp.status_code == 200 + + def test_unassign_nonexistent(self, client): + with patch( + "dashboard.routes.models.model_registry" + ) as mock_reg: + mock_reg.unassign_model.return_value = False + resp = client.delete("/api/v1/models/assignments/nonexistent") + assert resp.status_code == 404 + + def test_list_assignments(self, client): + with patch( + "dashboard.routes.models.model_registry" + ) as mock_reg: + mock_reg.get_agent_assignments.return_value = { + "agent-1": "model-a", + "agent-2": "model-b", + } + resp = client.get("/api/v1/models/assignments/all") + assert resp.status_code == 200 + data = resp.json() + assert data["total"] == 2 + + +class TestModelsAPIRoles: + """Test role-based lookup endpoints.""" + + def test_get_reward_model(self, client): + model = CustomModel( + name="reward-m", + format=ModelFormat.OLLAMA, + path="deepseek-r1:1.5b", + role=ModelRole.REWARD, + ) + with patch( + "dashboard.routes.models.model_registry" + ) as mock_reg: + mock_reg.get_reward_model.return_value = model + resp = client.get("/api/v1/models/roles/reward") + assert resp.status_code == 200 + data = resp.json() + assert data["reward_model"]["name"] == "reward-m" + + def test_get_reward_model_none(self, client): + with patch( + "dashboard.routes.models.model_registry" + ) as mock_reg: + mock_reg.get_reward_model.return_value = None + resp = client.get("/api/v1/models/roles/reward") + assert resp.status_code == 200 + assert resp.json()["reward_model"] is None + + def test_get_teacher_model(self, client): + with patch( + "dashboard.routes.models.model_registry" + ) as mock_reg: + mock_reg.get_teacher_model.return_value = None + resp = 
client.get("/api/v1/models/roles/teacher") + assert resp.status_code == 200 + assert resp.json()["teacher_model"] is None + + +class TestModelsAPISetActive: + """Test enable/disable model endpoint.""" + + def test_enable_model(self, client): + with patch( + "dashboard.routes.models.model_registry" + ) as mock_reg: + mock_reg.set_active.return_value = True + resp = client.patch( + "/api/v1/models/my-model/active", + json={"active": True}, + ) + assert resp.status_code == 200 + + def test_disable_nonexistent(self, client): + with patch( + "dashboard.routes.models.model_registry" + ) as mock_reg: + mock_reg.set_active.return_value = False + resp = client.patch( + "/api/v1/models/nonexistent/active", + json={"active": False}, + ) + assert resp.status_code == 404 diff --git a/tests/swarm/test_reward_scoring.py b/tests/swarm/test_reward_scoring.py new file mode 100644 index 00000000..34939ef4 --- /dev/null +++ b/tests/swarm/test_reward_scoring.py @@ -0,0 +1,197 @@ +"""Tests for reward model scoring in the swarm learner.""" + +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from swarm.learner import ( + RewardScore, + get_reward_scores, + score_output, +) + + +@pytest.fixture(autouse=True) +def _isolate_db(tmp_path): + """Point the learner at a temporary database.""" + db = tmp_path / "learner_test.db" + with patch("swarm.learner.DB_PATH", db): + yield + + +class TestScoreOutput: + """Test the score_output function.""" + + def test_returns_none_when_disabled(self): + with patch("swarm.learner._settings") as mock_s: + mock_s.reward_model_enabled = False + result = score_output("task-1", "agent-1", "do X", "done X") + assert result is None + + def test_returns_none_when_no_model(self): + with patch("swarm.learner._settings") as mock_s: + mock_s.reward_model_enabled = True + mock_s.reward_model_name = "" + with patch( + "infrastructure.models.registry.model_registry" + ) as mock_reg: + mock_reg.get_reward_model.return_value = None + 
result = score_output("task-1", "agent-1", "do X", "done X") + assert result is None + + def test_positive_scoring(self): + """All votes return GOOD → score = 1.0.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"response": "GOOD"} + + with patch("swarm.learner._settings") as mock_s: + mock_s.reward_model_enabled = True + mock_s.reward_model_name = "test-model" + mock_s.reward_model_votes = 3 + mock_s.ollama_url = "http://localhost:11434" + + with patch("requests.post", return_value=mock_response): + result = score_output("task-1", "agent-1", "do X", "done X") + + assert result is not None + assert result.score == 1.0 + assert result.positive_votes == 3 + assert result.negative_votes == 0 + assert result.total_votes == 3 + assert result.model_used == "test-model" + + def test_negative_scoring(self): + """All votes return BAD → score = -1.0.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"response": "BAD"} + + with patch("swarm.learner._settings") as mock_s: + mock_s.reward_model_enabled = True + mock_s.reward_model_name = "test-model" + mock_s.reward_model_votes = 3 + mock_s.ollama_url = "http://localhost:11434" + + with patch("requests.post", return_value=mock_response): + result = score_output("task-1", "agent-1", "do X", "bad output") + + assert result is not None + assert result.score == -1.0 + assert result.negative_votes == 3 + + def test_mixed_scoring(self): + """2 GOOD + 1 BAD → score ≈ 0.33.""" + responses = [] + for text in ["GOOD", "GOOD", "BAD"]: + resp = MagicMock() + resp.status_code = 200 + resp.json.return_value = {"response": text} + responses.append(resp) + + with patch("swarm.learner._settings") as mock_s: + mock_s.reward_model_enabled = True + mock_s.reward_model_name = "test-model" + mock_s.reward_model_votes = 3 + mock_s.ollama_url = "http://localhost:11434" + + with patch("requests.post", side_effect=responses): + result = 
score_output("task-1", "agent-1", "do X", "ok output") + + assert result is not None + assert abs(result.score - (1 / 3)) < 0.01 + assert result.positive_votes == 2 + assert result.negative_votes == 1 + + def test_uses_registry_reward_model(self): + """Falls back to registry reward model when setting is empty.""" + mock_model = MagicMock() + mock_model.path = "registry-reward-model" + mock_model.format = MagicMock() + mock_model.format.value = "ollama" + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"response": "GOOD"} + + with patch("swarm.learner._settings") as mock_s: + mock_s.reward_model_enabled = True + mock_s.reward_model_name = "" + mock_s.reward_model_votes = 1 + mock_s.ollama_url = "http://localhost:11434" + + with patch( + "infrastructure.models.registry.model_registry" + ) as mock_reg: + mock_reg.get_reward_model.return_value = mock_model + + with patch("requests.post", return_value=mock_response): + result = score_output("task-1", "agent-1", "do X", "ok") + + assert result is not None + assert result.model_used == "registry-reward-model" + + +class TestGetRewardScores: + """Test retrieving historical reward scores.""" + + def test_empty_history(self): + scores = get_reward_scores() + assert scores == [] + + def test_scores_persisted(self): + """Scores from score_output are retrievable.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"response": "GOOD"} + + with patch("swarm.learner._settings") as mock_s: + mock_s.reward_model_enabled = True + mock_s.reward_model_name = "test-model" + mock_s.reward_model_votes = 1 + mock_s.ollama_url = "http://localhost:11434" + + with patch("requests.post", return_value=mock_response): + score_output("task-1", "agent-1", "do X", "done X") + + scores = get_reward_scores() + assert len(scores) == 1 + assert scores[0]["task_id"] == "task-1" + assert scores[0]["agent_id"] == "agent-1" + assert scores[0]["score"] 
== 1.0 + + def test_filter_by_agent(self): + """Filter scores by agent_id.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"response": "GOOD"} + + with patch("swarm.learner._settings") as mock_s: + mock_s.reward_model_enabled = True + mock_s.reward_model_name = "test-model" + mock_s.reward_model_votes = 1 + mock_s.ollama_url = "http://localhost:11434" + + with patch("requests.post", return_value=mock_response): + score_output("task-1", "agent-1", "task A", "output A") + score_output("task-2", "agent-2", "task B", "output B") + + agent1_scores = get_reward_scores(agent_id="agent-1") + assert len(agent1_scores) == 1 + assert agent1_scores[0]["agent_id"] == "agent-1" + + +class TestRewardScoreDataclass: + """Test RewardScore construction.""" + + def test_create_score(self): + score = RewardScore( + score=0.5, + positive_votes=3, + negative_votes=1, + total_votes=4, + model_used="test-model", + ) + assert score.score == 0.5 + assert score.total_votes == 4