diff --git a/config/providers.yaml b/config/providers.yaml
index 90630484..f078eacc 100644
--- a/config/providers.yaml
+++ b/config/providers.yaml
@@ -68,6 +68,37 @@ providers:
- name: claude-3-sonnet-20240229
context_window: 200000
+# ── Custom Models ──────────────────────────────────────────────────────
+# Register custom model weights for per-agent assignment.
+# Supports GGUF (Ollama), safetensors, and HuggingFace checkpoint dirs.
+# Models can also be registered at runtime via the /api/v1/models API.
+#
+# Roles: general (default inference), reward (PRM scoring),
+# teacher (distillation), judge (output evaluation)
+custom_models: []
+ # Example entries:
+ # - name: my-finetuned-llama
+ # format: gguf
+ # path: /path/to/model.gguf
+ # role: general
+ # context_window: 8192
+ # description: "Fine-tuned Llama for code tasks"
+ #
+ # - name: reward-model
+ # format: ollama
+ # path: deepseek-r1:1.5b
+ # role: reward
+ # context_window: 32000
+ # description: "Process reward model for scoring outputs"
+
+# ── Agent Model Assignments ─────────────────────────────────────────────
+# Map persona agent IDs to specific models.
+# Agents without an assignment use the global default (ollama_model).
+agent_model_assignments: {}
+ # Example:
+ # persona-forge: my-finetuned-llama
+ # persona-echo: deepseek-r1:1.5b
+
# Cost tracking (optional, for budget monitoring)
cost_tracking:
enabled: true
diff --git a/src/config.py b/src/config.py
index 30d3c644..d46595c0 100644
--- a/src/config.py
+++ b/src/config.py
@@ -103,6 +103,17 @@ class Settings(BaseSettings):
work_orders_auto_execute: bool = False # Master switch for auto-execution
work_orders_auto_threshold: str = "low" # Max priority that auto-executes: "low" | "medium" | "high" | "none"
+ # ── Custom Weights & Models ──────────────────────────────────────
+ # Directory for custom model weights (GGUF, safetensors, HF checkpoints).
+ # Models placed here can be registered at runtime and assigned to agents.
+ custom_weights_dir: str = "data/models"
+ # Enable the reward model for scoring agent outputs (PRM-style).
+ reward_model_enabled: bool = False
+ # Reward model name (must be available via Ollama or a custom weight path).
+ reward_model_name: str = ""
+    # Number of votes used for majority-vote reward scoring (odd number recommended).
+ reward_model_votes: int = 3
+
# ── Browser Local Models (iPhone / WebGPU) ───────────────────────
# Enable in-browser LLM inference via WebLLM for offline iPhone use.
# When enabled, the mobile dashboard loads a small model directly
diff --git a/src/dashboard/app.py b/src/dashboard/app.py
index b0bc7fa5..dbc6f846 100644
--- a/src/dashboard/app.py
+++ b/src/dashboard/app.py
@@ -36,6 +36,8 @@ from dashboard.routes.self_coding import router as self_coding_router
from dashboard.routes.self_coding import self_modify_router
from dashboard.routes.hands import router as hands_router
from dashboard.routes.grok import router as grok_router
+from dashboard.routes.models import router as models_router
+from dashboard.routes.models import api_router as models_api_router
from infrastructure.router.api import router as cascade_router
logging.basicConfig(
@@ -208,6 +210,8 @@ app.include_router(tasks_router)
app.include_router(scripture_router)
app.include_router(hands_router)
app.include_router(grok_router)
+app.include_router(models_router)
+app.include_router(models_api_router)
app.include_router(cascade_router)
diff --git a/src/dashboard/routes/models.py b/src/dashboard/routes/models.py
new file mode 100644
index 00000000..77c566e9
--- /dev/null
+++ b/src/dashboard/routes/models.py
@@ -0,0 +1,272 @@
+"""Custom model management routes — register, list, assign, and swap models.
+
+Provides a REST API for managing custom model weights and their assignment
+to swarm agents. Inspired by OpenClaw-RL's multi-model orchestration.
+"""
+
+import logging
+from pathlib import Path
+from typing import Any, Optional
+
+from fastapi import APIRouter, HTTPException, Request
+from fastapi.responses import HTMLResponse
+from fastapi.templating import Jinja2Templates
+from pydantic import BaseModel
+
+from config import settings
+from infrastructure.models.registry import (
+ CustomModel,
+ ModelFormat,
+ ModelRegistry,
+ ModelRole,
+ model_registry,
+)
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(prefix="/models", tags=["models"])
+api_router = APIRouter(prefix="/api/v1/models", tags=["models-api"])
+templates = Jinja2Templates(directory=str(Path(__file__).parent.parent / "templates"))
+
+
+# ── Pydantic schemas ──────────────────────────────────────────────────────────
+
+
+class RegisterModelRequest(BaseModel):
+ """Request body for model registration."""
+ name: str
+ format: str # gguf, safetensors, hf, ollama
+ path: str
+ role: str = "general"
+ context_window: int = 4096
+ description: str = ""
+ default_temperature: float = 0.7
+ max_tokens: int = 2048
+
+
+class AssignModelRequest(BaseModel):
+ """Request body for assigning a model to an agent."""
+ agent_id: str
+ model_name: str
+
+
+class SetActiveRequest(BaseModel):
+ """Request body for enabling/disabling a model."""
+ active: bool
+
+
+# ── API endpoints ─────────────────────────────────────────────────────────────
+
+
+@api_router.get("")
+async def list_models(role: Optional[str] = None) -> dict[str, Any]:
+ """List all registered custom models."""
+ model_role = ModelRole(role) if role else None
+ models = model_registry.list_models(role=model_role)
+ return {
+ "models": [
+ {
+ "name": m.name,
+ "format": m.format.value,
+ "path": m.path,
+ "role": m.role.value,
+ "context_window": m.context_window,
+ "description": m.description,
+ "active": m.active,
+ "registered_at": m.registered_at,
+ "default_temperature": m.default_temperature,
+ "max_tokens": m.max_tokens,
+ }
+ for m in models
+ ],
+ "total": len(models),
+ "weights_dir": settings.custom_weights_dir,
+ }
+
+
+@api_router.post("")
+async def register_model(request: RegisterModelRequest) -> dict[str, Any]:
+ """Register a new custom model."""
+ try:
+ fmt = ModelFormat(request.format)
+ except ValueError:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid format: {request.format}. "
+ f"Choose from: {[f.value for f in ModelFormat]}",
+ )
+ try:
+ role = ModelRole(request.role)
+ except ValueError:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid role: {request.role}. "
+ f"Choose from: {[r.value for r in ModelRole]}",
+ )
+
+ # Validate path exists for non-Ollama formats
+ if fmt != ModelFormat.OLLAMA:
+ weight_path = Path(request.path)
+ if not weight_path.exists():
+ raise HTTPException(
+ status_code=400,
+ detail=f"Weight path does not exist: {request.path}",
+ )
+
+ model = CustomModel(
+ name=request.name,
+ format=fmt,
+ path=request.path,
+ role=role,
+ context_window=request.context_window,
+ description=request.description,
+ default_temperature=request.default_temperature,
+ max_tokens=request.max_tokens,
+ )
+ registered = model_registry.register(model)
+ return {
+ "message": f"Model {registered.name} registered",
+ "model": {
+ "name": registered.name,
+ "format": registered.format.value,
+ "role": registered.role.value,
+ "path": registered.path,
+ },
+ }
+
+
+@api_router.get("/{model_name}")
+async def get_model(model_name: str) -> dict[str, Any]:
+ """Get details of a specific model."""
+ model = model_registry.get(model_name)
+ if not model:
+ raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
+ return {
+ "name": model.name,
+ "format": model.format.value,
+ "path": model.path,
+ "role": model.role.value,
+ "context_window": model.context_window,
+ "description": model.description,
+ "active": model.active,
+ "registered_at": model.registered_at,
+ "default_temperature": model.default_temperature,
+ "max_tokens": model.max_tokens,
+ }
+
+
+@api_router.delete("/{model_name}")
+async def unregister_model(model_name: str) -> dict[str, str]:
+ """Remove a model from the registry."""
+ if not model_registry.unregister(model_name):
+ raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
+ return {"message": f"Model {model_name} unregistered"}
+
+
+@api_router.patch("/{model_name}/active")
+async def set_model_active(
+ model_name: str, request: SetActiveRequest
+) -> dict[str, str]:
+ """Enable or disable a model."""
+ if not model_registry.set_active(model_name, request.active):
+ raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
+ state = "enabled" if request.active else "disabled"
+ return {"message": f"Model {model_name} {state}"}
+
+
+# ── Agent assignment endpoints ────────────────────────────────────────────────
+
+
+@api_router.get("/assignments/all")
+async def list_assignments() -> dict[str, Any]:
+ """List all agent-to-model assignments."""
+ assignments = model_registry.get_agent_assignments()
+ return {
+ "assignments": [
+ {"agent_id": aid, "model_name": mname}
+ for aid, mname in assignments.items()
+ ],
+ "total": len(assignments),
+ }
+
+
+@api_router.post("/assignments")
+async def assign_model(request: AssignModelRequest) -> dict[str, str]:
+ """Assign a model to a swarm agent."""
+ if not model_registry.assign_model(request.agent_id, request.model_name):
+ raise HTTPException(
+ status_code=404,
+ detail=f"Model {request.model_name} not found in registry",
+ )
+ return {
+ "message": f"Model {request.model_name} assigned to {request.agent_id}",
+ }
+
+
+@api_router.delete("/assignments/{agent_id}")
+async def unassign_model(agent_id: str) -> dict[str, str]:
+ """Remove model assignment from an agent (reverts to default)."""
+ if not model_registry.unassign_model(agent_id):
+ raise HTTPException(
+ status_code=404,
+ detail=f"No model assignment for agent {agent_id}",
+ )
+ return {"message": f"Model assignment removed for {agent_id}"}
+
+
+# ── Role-based lookups ────────────────────────────────────────────────────────
+
+
+@api_router.get("/roles/reward")
+async def get_reward_model() -> dict[str, Any]:
+ """Get the active reward (PRM) model."""
+ model = model_registry.get_reward_model()
+ if not model:
+ return {"reward_model": None, "reward_enabled": settings.reward_model_enabled}
+ return {
+ "reward_model": {
+ "name": model.name,
+ "format": model.format.value,
+ "path": model.path,
+ },
+ "reward_enabled": settings.reward_model_enabled,
+ }
+
+
+@api_router.get("/roles/teacher")
+async def get_teacher_model() -> dict[str, Any]:
+ """Get the active teacher model for distillation."""
+ model = model_registry.get_teacher_model()
+ if not model:
+ return {"teacher_model": None}
+ return {
+ "teacher_model": {
+ "name": model.name,
+ "format": model.format.value,
+ "path": model.path,
+ },
+ }
+
+
+# ── Dashboard page ────────────────────────────────────────────────────────────
+
+
+@router.get("", response_class=HTMLResponse)
+async def models_page(request: Request):
+ """Custom models management dashboard page."""
+ models = model_registry.list_models()
+ assignments = model_registry.get_agent_assignments()
+ reward = model_registry.get_reward_model()
+
+ return templates.TemplateResponse(
+ request,
+ "models.html",
+ {
+ "page_title": "Custom Models",
+ "models": models,
+ "assignments": assignments,
+ "reward_model": reward,
+ "weights_dir": settings.custom_weights_dir,
+ "reward_enabled": settings.reward_model_enabled,
+ },
+ )
diff --git a/src/dashboard/templates/models.html b/src/dashboard/templates/models.html
new file mode 100644
index 00000000..47ccea3b
--- /dev/null
+++ b/src/dashboard/templates/models.html
@@ -0,0 +1,119 @@
+{% extends "base.html" %}
+
+{% block title %}Custom Models - Timmy Time{% endblock %}
+
+{% block content %}
+
+
+
+
+
+
+
{{ models|length }}
+
Models
+
+
+
{{ assignments|length }}
+
Assignments
+
+
+
{{ "Yes" if reward_model else "No" }}
+
Reward Model
+
+
+
+
+
+
+
+
+
Registered Models
+ {% if models %}
+
+
+
+ | Name |
+ Format |
+ Role |
+ Context |
+ Active |
+ Actions |
+
+
+
+ {% for m in models %}
+
+ | {{ m.name }} |
+ {{ m.format.value }} |
+ {{ m.role.value }} |
+ {{ m.context_window }} |
+ {{ "Yes" if m.active else "No" }} |
+
+
+ |
+
+ {% endfor %}
+
+
+ {% else %}
+
No custom models registered. Use the form above or the API.
+ {% endif %}
+
+
+
+
+
Agent Model Assignments
+ {% if assignments %}
+
+
+ | Agent | Model |
+
+
+ {% for agent_id, model_name in assignments.items() %}
+
+ | {{ agent_id }} |
+ {{ model_name }} |
+
+ {% endfor %}
+
+
+ {% else %}
+
No agent-specific model assignments. All agents use the global default.
+ {% endif %}
+
+
+
+
Weights directory: {{ weights_dir }}
+
+
+{% endblock %}
diff --git a/src/infrastructure/models/__init__.py b/src/infrastructure/models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/infrastructure/models/registry.py b/src/infrastructure/models/registry.py
new file mode 100644
index 00000000..b9a568c0
--- /dev/null
+++ b/src/infrastructure/models/registry.py
@@ -0,0 +1,268 @@
+"""Custom model registry — register, load, and manage model weights.
+
+Tracks custom models (GGUF files, HF checkpoints, Ollama modelfiles)
+and their assignment to swarm agents. Models can be registered at
+runtime via the API or pre-configured via providers.yaml.
+
+Inspired by OpenClaw-RL's multi-model orchestration where distinct
+model roles (student, teacher, judge/PRM) run on dedicated resources.
+"""
+
+import logging
+import sqlite3
+import threading
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from enum import Enum
+from pathlib import Path
+from typing import Optional
+
+from config import settings
+
+logger = logging.getLogger(__name__)
+
+DB_PATH = Path("data/swarm.db")
+
+
+class ModelFormat(str, Enum):
+ """Supported model weight formats."""
+ GGUF = "gguf" # Ollama-compatible quantised weights
+ SAFETENSORS = "safetensors" # HuggingFace safetensors
+ HF_CHECKPOINT = "hf" # Full HuggingFace checkpoint directory
+ OLLAMA = "ollama" # Already loaded in Ollama by name
+
+
+class ModelRole(str, Enum):
+ """Role a model can play in the system (OpenClaw-RL style)."""
+ GENERAL = "general" # Default agent inference
+ REWARD = "reward" # Process Reward Model (PRM) scoring
+ TEACHER = "teacher" # On-policy distillation teacher
+ JUDGE = "judge" # Output quality evaluation
+
+
+@dataclass
+class CustomModel:
+ """A registered custom model."""
+ name: str
+ format: ModelFormat
+ path: str # Absolute path or Ollama model name
+ role: ModelRole = ModelRole.GENERAL
+ context_window: int = 4096
+ description: str = ""
+ registered_at: str = ""
+ active: bool = True
+ # Per-model generation settings
+ default_temperature: float = 0.7
+ max_tokens: int = 2048
+
+ def __post_init__(self):
+ if not self.registered_at:
+ self.registered_at = datetime.now(timezone.utc).isoformat()
+
+
+def _get_conn() -> sqlite3.Connection:
+ DB_PATH.parent.mkdir(parents=True, exist_ok=True)
+ conn = sqlite3.connect(str(DB_PATH))
+ conn.row_factory = sqlite3.Row
+ conn.execute(
+ """
+ CREATE TABLE IF NOT EXISTS custom_models (
+ name TEXT PRIMARY KEY,
+ format TEXT NOT NULL,
+ path TEXT NOT NULL,
+ role TEXT NOT NULL DEFAULT 'general',
+ context_window INTEGER NOT NULL DEFAULT 4096,
+ description TEXT NOT NULL DEFAULT '',
+ registered_at TEXT NOT NULL,
+ active INTEGER NOT NULL DEFAULT 1,
+ default_temperature REAL NOT NULL DEFAULT 0.7,
+ max_tokens INTEGER NOT NULL DEFAULT 2048
+ )
+ """
+ )
+ conn.execute(
+ """
+ CREATE TABLE IF NOT EXISTS agent_model_assignments (
+ agent_id TEXT PRIMARY KEY,
+ model_name TEXT NOT NULL,
+ assigned_at TEXT NOT NULL,
+ FOREIGN KEY (model_name) REFERENCES custom_models(name)
+ )
+ """
+ )
+ conn.commit()
+ return conn
+
+
+class ModelRegistry:
+ """Singleton registry for custom models and agent-model assignments."""
+
+ def __init__(self) -> None:
+ self._lock = threading.Lock()
+ # In-memory cache for fast lookups
+ self._models: dict[str, CustomModel] = {}
+ self._agent_assignments: dict[str, str] = {}
+ self._load_from_db()
+
+ def _load_from_db(self) -> None:
+ """Bootstrap cache from SQLite."""
+ try:
+ conn = _get_conn()
+ for row in conn.execute("SELECT * FROM custom_models WHERE active = 1").fetchall():
+ self._models[row["name"]] = CustomModel(
+ name=row["name"],
+ format=ModelFormat(row["format"]),
+ path=row["path"],
+ role=ModelRole(row["role"]),
+ context_window=row["context_window"],
+ description=row["description"],
+ registered_at=row["registered_at"],
+ active=bool(row["active"]),
+ default_temperature=row["default_temperature"],
+ max_tokens=row["max_tokens"],
+ )
+ for row in conn.execute("SELECT * FROM agent_model_assignments").fetchall():
+ self._agent_assignments[row["agent_id"]] = row["model_name"]
+ conn.close()
+ except Exception as exc:
+ logger.warning("Failed to load model registry from DB: %s", exc)
+
+ # ── Model CRUD ─────────────────────────────────────────────────────────
+
+ def register(self, model: CustomModel) -> CustomModel:
+ """Register a new custom model."""
+ with self._lock:
+ conn = _get_conn()
+ conn.execute(
+ """
+ INSERT OR REPLACE INTO custom_models
+ (name, format, path, role, context_window, description,
+ registered_at, active, default_temperature, max_tokens)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+ """,
+ (
+ model.name, model.format.value, model.path,
+ model.role.value, model.context_window, model.description,
+ model.registered_at, int(model.active),
+ model.default_temperature, model.max_tokens,
+ ),
+ )
+ conn.commit()
+ conn.close()
+ self._models[model.name] = model
+ logger.info("Registered model: %s (%s)", model.name, model.format.value)
+ return model
+
+ def unregister(self, name: str) -> bool:
+ """Remove a model from the registry."""
+ with self._lock:
+ if name not in self._models:
+ return False
+ conn = _get_conn()
+ conn.execute("DELETE FROM custom_models WHERE name = ?", (name,))
+ conn.execute(
+ "DELETE FROM agent_model_assignments WHERE model_name = ?", (name,)
+ )
+ conn.commit()
+ conn.close()
+ del self._models[name]
+ # Remove any agent assignments using this model
+ self._agent_assignments = {
+ k: v for k, v in self._agent_assignments.items() if v != name
+ }
+ logger.info("Unregistered model: %s", name)
+ return True
+
+ def get(self, name: str) -> Optional[CustomModel]:
+ """Look up a model by name."""
+ return self._models.get(name)
+
+ def list_models(self, role: Optional[ModelRole] = None) -> list[CustomModel]:
+ """List all registered models, optionally filtered by role."""
+ models = list(self._models.values())
+ if role is not None:
+ models = [m for m in models if m.role == role]
+ return models
+
+ def set_active(self, name: str, active: bool) -> bool:
+ """Enable or disable a model without removing it."""
+ model = self._models.get(name)
+ if not model:
+ return False
+ with self._lock:
+ model.active = active
+ conn = _get_conn()
+ conn.execute(
+ "UPDATE custom_models SET active = ? WHERE name = ?",
+ (int(active), name),
+ )
+ conn.commit()
+ conn.close()
+ return True
+
+ # ── Agent-model assignments ────────────────────────────────────────────
+
+ def assign_model(self, agent_id: str, model_name: str) -> bool:
+ """Assign a specific model to an agent."""
+ if model_name not in self._models:
+ return False
+ with self._lock:
+ now = datetime.now(timezone.utc).isoformat()
+ conn = _get_conn()
+ conn.execute(
+ """
+ INSERT OR REPLACE INTO agent_model_assignments
+ (agent_id, model_name, assigned_at)
+ VALUES (?, ?, ?)
+ """,
+ (agent_id, model_name, now),
+ )
+ conn.commit()
+ conn.close()
+ self._agent_assignments[agent_id] = model_name
+ logger.info("Assigned model %s to agent %s", model_name, agent_id)
+ return True
+
+ def unassign_model(self, agent_id: str) -> bool:
+ """Remove model assignment from an agent (falls back to default)."""
+ with self._lock:
+ if agent_id not in self._agent_assignments:
+ return False
+ conn = _get_conn()
+ conn.execute(
+ "DELETE FROM agent_model_assignments WHERE agent_id = ?",
+ (agent_id,),
+ )
+ conn.commit()
+ conn.close()
+ del self._agent_assignments[agent_id]
+ return True
+
+ def get_agent_model(self, agent_id: str) -> Optional[CustomModel]:
+ """Get the model assigned to an agent, or None for default."""
+ model_name = self._agent_assignments.get(agent_id)
+ if model_name:
+ return self._models.get(model_name)
+ return None
+
+ def get_agent_assignments(self) -> dict[str, str]:
+ """Return all agent-to-model assignments."""
+ return dict(self._agent_assignments)
+
+ # ── Role-based lookups ─────────────────────────────────────────────────
+
+ def get_reward_model(self) -> Optional[CustomModel]:
+ """Get the active reward/PRM model, if any."""
+ reward_models = self.list_models(role=ModelRole.REWARD)
+ active = [m for m in reward_models if m.active]
+ return active[0] if active else None
+
+ def get_teacher_model(self) -> Optional[CustomModel]:
+ """Get the active teacher model for distillation."""
+ teacher_models = self.list_models(role=ModelRole.TEACHER)
+ active = [m for m in teacher_models if m.active]
+ return active[0] if active else None
+
+
+# Module-level singleton
+model_registry = ModelRegistry()
diff --git a/src/swarm/learner.py b/src/swarm/learner.py
index 3f82a46e..b8559f50 100644
--- a/src/swarm/learner.py
+++ b/src/swarm/learner.py
@@ -251,3 +251,193 @@ def learned_keywords(agent_id: str) -> list[dict]:
results.append({"keyword": kw, "wins": wins, "failures": fails, "net": wins - fails})
results.sort(key=lambda x: x["net"], reverse=True)
return results
+
+
+# ── Reward model scoring (PRM-style) ─────────────────────────────────────────
+
+import logging as _logging
+from config import settings as _settings
+
+_reward_logger = _logging.getLogger(__name__ + ".reward")
+
+
+@dataclass
+class RewardScore:
+ """Result from reward-model evaluation."""
+ score: float # Normalised score in [-1.0, 1.0]
+ positive_votes: int
+ negative_votes: int
+ total_votes: int
+ model_used: str
+
+
+def _ensure_reward_table() -> None:
+ """Create the reward_scores table if needed."""
+ conn = _get_conn()
+ conn.execute(
+ """
+ CREATE TABLE IF NOT EXISTS reward_scores (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ task_id TEXT NOT NULL,
+ agent_id TEXT NOT NULL,
+ output_text TEXT NOT NULL,
+ score REAL NOT NULL,
+ positive INTEGER NOT NULL,
+ negative INTEGER NOT NULL,
+ total INTEGER NOT NULL,
+ model_used TEXT NOT NULL,
+ scored_at TEXT NOT NULL DEFAULT (datetime('now'))
+ )
+ """
+ )
+ conn.commit()
+ conn.close()
+
+
+def score_output(
+ task_id: str,
+ agent_id: str,
+ task_description: str,
+ output_text: str,
+) -> Optional[RewardScore]:
+ """Score an agent's output using the reward model (majority vote).
+
+ Calls the reward model N times (settings.reward_model_votes) with a
+ quality-evaluation prompt. Each vote is +1 (good) or -1 (bad).
+ Final score is (positive - negative) / total, in [-1.0, 1.0].
+
+ Returns None if the reward model is disabled or unavailable.
+ """
+ if not _settings.reward_model_enabled:
+ return None
+
+ # Resolve model name: explicit setting > registry reward model > skip
+ model_name = _settings.reward_model_name
+ if not model_name:
+ try:
+ from infrastructure.models.registry import model_registry
+ reward = model_registry.get_reward_model()
+ if reward:
+ model_name = reward.path if reward.format.value == "ollama" else reward.name
+ except Exception:
+ pass
+
+ if not model_name:
+ _reward_logger.debug("No reward model configured, skipping scoring")
+ return None
+
+ num_votes = max(1, _settings.reward_model_votes)
+ positive = 0
+ negative = 0
+
+ prompt = (
+ f"You are a quality evaluator. Rate the following agent output.\n\n"
+ f"TASK: {task_description}\n\n"
+ f"OUTPUT:\n{output_text[:2000]}\n\n"
+ f"Is this output correct, helpful, and complete? "
+ f"Reply with exactly one word: GOOD or BAD."
+ )
+
+ try:
+ import requests as _req
+ ollama_url = _settings.ollama_url
+
+ for _ in range(num_votes):
+ try:
+ resp = _req.post(
+ f"{ollama_url}/api/generate",
+ json={
+ "model": model_name,
+ "prompt": prompt,
+ "stream": False,
+ "options": {"temperature": 0.3, "num_predict": 10},
+ },
+ timeout=30,
+ )
+ if resp.status_code == 200:
+ answer = resp.json().get("response", "").strip().upper()
+ if "GOOD" in answer:
+ positive += 1
+ else:
+ negative += 1
+ else:
+ negative += 1 # Treat errors as negative conservatively
+ except Exception as vote_exc:
+ _reward_logger.debug("Vote failed: %s", vote_exc)
+ negative += 1
+
+ except ImportError:
+ _reward_logger.warning("requests library not available for reward scoring")
+ return None
+
+ total = positive + negative
+ if total == 0:
+ return None
+
+ score = (positive - negative) / total
+
+ result = RewardScore(
+ score=score,
+ positive_votes=positive,
+ negative_votes=negative,
+ total_votes=total,
+ model_used=model_name,
+ )
+
+ # Persist to DB
+ try:
+ _ensure_reward_table()
+ conn = _get_conn()
+ conn.execute(
+ """
+ INSERT INTO reward_scores
+ (task_id, agent_id, output_text, score, positive, negative, total, model_used)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+ """,
+ (
+ task_id, agent_id, output_text[:5000],
+ score, positive, negative, total, model_name,
+ ),
+ )
+ conn.commit()
+ conn.close()
+ except Exception as db_exc:
+ _reward_logger.warning("Failed to persist reward score: %s", db_exc)
+
+ _reward_logger.info(
+ "Scored task %s agent %s: %.2f (%d+/%d- of %d votes)",
+ task_id, agent_id, score, positive, negative, total,
+ )
+ return result
+
+
+def get_reward_scores(
+ agent_id: Optional[str] = None, limit: int = 50
+) -> list[dict]:
+ """Retrieve historical reward scores from the database."""
+ _ensure_reward_table()
+ conn = _get_conn()
+ if agent_id:
+ rows = conn.execute(
+ "SELECT * FROM reward_scores WHERE agent_id = ? ORDER BY id DESC LIMIT ?",
+ (agent_id, limit),
+ ).fetchall()
+ else:
+ rows = conn.execute(
+ "SELECT * FROM reward_scores ORDER BY id DESC LIMIT ?",
+ (limit,),
+ ).fetchall()
+ conn.close()
+ return [
+ {
+ "task_id": r["task_id"],
+ "agent_id": r["agent_id"],
+ "score": r["score"],
+ "positive": r["positive"],
+ "negative": r["negative"],
+ "total": r["total"],
+ "model_used": r["model_used"],
+ "scored_at": r["scored_at"],
+ }
+ for r in rows
+ ]
diff --git a/src/swarm/persona_node.py b/src/swarm/persona_node.py
index 98a36755..1ed7481f 100644
--- a/src/swarm/persona_node.py
+++ b/src/swarm/persona_node.py
@@ -51,6 +51,16 @@ class PersonaNode(SwarmNode):
self._meta = meta
self._persona_id = persona_id
self._use_learner = use_learner
+
+ # Resolve model: registry assignment > persona default > global default
+ self._model_name: Optional[str] = meta.get("model")
+ try:
+ from infrastructure.models.registry import model_registry
+ assigned = model_registry.get_agent_model(agent_id)
+ if assigned:
+ self._model_name = assigned.name
+ except Exception:
+ pass # Graceful degradation — use persona/global default
# Initialize tool executor for task execution
self._tool_executor: Optional[ToolExecutor] = None
@@ -213,6 +223,11 @@ class PersonaNode(SwarmNode):
"""Return the task ID currently being executed, if any."""
return self._current_task
+ @property
+ def model_name(self) -> Optional[str]:
+ """Return the model this agent uses, or None for global default."""
+ return self._model_name
+
@property
def tool_capabilities(self) -> list[str]:
"""Return list of available tool names."""
diff --git a/src/swarm/personas.py b/src/swarm/personas.py
index 6a73548d..3c583ff6 100644
--- a/src/swarm/personas.py
+++ b/src/swarm/personas.py
@@ -14,7 +14,7 @@ from __future__ import annotations
from typing import TypedDict
-class PersonaMeta(TypedDict):
+class PersonaMeta(TypedDict, total=False):
id: str
name: str
role: str
@@ -24,6 +24,11 @@ class PersonaMeta(TypedDict):
bid_base: int # typical bid when task matches persona
bid_jitter: int # ± random jitter added to bid_base
preferred_keywords: list[str]
+ # Optional: custom model override for this persona.
+ # When set, this persona uses this model instead of the global default.
+ # Value is a model name registered in the ModelRegistry, or an Ollama
+ # model name like "llama3.2" or "deepseek-r1:1.5b".
+ model: str
PERSONAS: dict[str, PersonaMeta] = {
diff --git a/tests/infrastructure/test_model_registry.py b/tests/infrastructure/test_model_registry.py
new file mode 100644
index 00000000..d8e70d4d
--- /dev/null
+++ b/tests/infrastructure/test_model_registry.py
@@ -0,0 +1,217 @@
+"""Tests for the custom model registry."""
+
+import sqlite3
+from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+
+from infrastructure.models.registry import (
+ CustomModel,
+ ModelFormat,
+ ModelRegistry,
+ ModelRole,
+)
+
+
+@pytest.fixture
+def registry(tmp_path):
+ """Create a fresh ModelRegistry backed by a temporary database."""
+ db = tmp_path / "test.db"
+ with patch("infrastructure.models.registry.DB_PATH", db):
+ reg = ModelRegistry()
+ yield reg
+
+
+@pytest.fixture
+def sample_model():
+ """A sample CustomModel for testing."""
+ return CustomModel(
+ name="test-llama",
+ format=ModelFormat.OLLAMA,
+ path="llama3.2",
+ role=ModelRole.GENERAL,
+ context_window=8192,
+ description="Test model",
+ )
+
+
+@pytest.fixture
+def reward_model():
+ """A sample reward model."""
+ return CustomModel(
+ name="test-reward",
+ format=ModelFormat.OLLAMA,
+ path="deepseek-r1:1.5b",
+ role=ModelRole.REWARD,
+ context_window=32000,
+ description="Test reward model",
+ )
+
+
+class TestModelCRUD:
+ """Test model registration, lookup, and removal."""
+
+ def test_register_model(self, registry, sample_model):
+ registered = registry.register(sample_model)
+ assert registered.name == "test-llama"
+ assert registered.format == ModelFormat.OLLAMA
+
+ def test_get_model(self, registry, sample_model):
+ registry.register(sample_model)
+ found = registry.get("test-llama")
+ assert found is not None
+ assert found.name == "test-llama"
+ assert found.path == "llama3.2"
+
+ def test_get_nonexistent_model(self, registry):
+ assert registry.get("nonexistent") is None
+
+ def test_list_models(self, registry, sample_model, reward_model):
+ registry.register(sample_model)
+ registry.register(reward_model)
+ all_models = registry.list_models()
+ assert len(all_models) == 2
+
+ def test_list_models_by_role(self, registry, sample_model, reward_model):
+ registry.register(sample_model)
+ registry.register(reward_model)
+ general = registry.list_models(role=ModelRole.GENERAL)
+ assert len(general) == 1
+ assert general[0].name == "test-llama"
+ rewards = registry.list_models(role=ModelRole.REWARD)
+ assert len(rewards) == 1
+ assert rewards[0].name == "test-reward"
+
+ def test_unregister_model(self, registry, sample_model):
+ registry.register(sample_model)
+ assert registry.unregister("test-llama") is True
+ assert registry.get("test-llama") is None
+
+ def test_unregister_nonexistent(self, registry):
+ assert registry.unregister("nonexistent") is False
+
+ def test_set_active(self, registry, sample_model):
+ registry.register(sample_model)
+ assert registry.set_active("test-llama", False) is True
+ model = registry.get("test-llama")
+ assert model.active is False
+ assert registry.set_active("test-llama", True) is True
+ model = registry.get("test-llama")
+ assert model.active is True
+
+ def test_set_active_nonexistent(self, registry):
+ assert registry.set_active("nonexistent", True) is False
+
+ def test_register_replaces_existing(self, registry, sample_model):
+ registry.register(sample_model)
+ updated = CustomModel(
+ name="test-llama",
+ format=ModelFormat.GGUF,
+ path="/new/path.gguf",
+ role=ModelRole.GENERAL,
+ description="Updated model",
+ )
+ registry.register(updated)
+ found = registry.get("test-llama")
+ assert found.format == ModelFormat.GGUF
+ assert found.path == "/new/path.gguf"
+
+
+class TestAgentAssignments:
+ """Test agent-to-model assignment management."""
+
+ def test_assign_model(self, registry, sample_model):
+ registry.register(sample_model)
+ assert registry.assign_model("agent-1", "test-llama") is True
+ model = registry.get_agent_model("agent-1")
+ assert model is not None
+ assert model.name == "test-llama"
+
+ def test_assign_nonexistent_model(self, registry):
+ assert registry.assign_model("agent-1", "nonexistent") is False
+
+ def test_unassign_model(self, registry, sample_model):
+ registry.register(sample_model)
+ registry.assign_model("agent-1", "test-llama")
+ assert registry.unassign_model("agent-1") is True
+ assert registry.get_agent_model("agent-1") is None
+
+ def test_unassign_nonexistent(self, registry):
+ assert registry.unassign_model("agent-1") is False
+
+ def test_get_agent_model_none(self, registry):
+ assert registry.get_agent_model("agent-1") is None
+
+ def test_get_all_assignments(self, registry, sample_model, reward_model):
+ registry.register(sample_model)
+ registry.register(reward_model)
+ registry.assign_model("agent-1", "test-llama")
+ registry.assign_model("agent-2", "test-reward")
+ assignments = registry.get_agent_assignments()
+ assert len(assignments) == 2
+ assert assignments["agent-1"] == "test-llama"
+ assert assignments["agent-2"] == "test-reward"
+
+ def test_unregister_removes_assignments(self, registry, sample_model):
+ registry.register(sample_model)
+ registry.assign_model("agent-1", "test-llama")
+ registry.unregister("test-llama")
+ assert registry.get_agent_model("agent-1") is None
+ assert len(registry.get_agent_assignments()) == 0
+
+
+class TestRoleLookups:
+ """Test role-based model lookups."""
+
+ def test_get_reward_model(self, registry, reward_model):
+ registry.register(reward_model)
+ found = registry.get_reward_model()
+ assert found is not None
+ assert found.name == "test-reward"
+ assert found.role == ModelRole.REWARD
+
+ def test_get_reward_model_none(self, registry):
+ assert registry.get_reward_model() is None
+
+ def test_get_teacher_model(self, registry):
+ teacher = CustomModel(
+ name="teacher-model",
+ format=ModelFormat.OLLAMA,
+ path="teacher:latest",
+ role=ModelRole.TEACHER,
+ )
+ registry.register(teacher)
+ found = registry.get_teacher_model()
+ assert found is not None
+ assert found.name == "teacher-model"
+
+ def test_get_teacher_model_none(self, registry):
+ assert registry.get_teacher_model() is None
+
+ def test_inactive_reward_model_not_returned(self, registry, reward_model):
+ registry.register(reward_model)
+ registry.set_active("test-reward", False)
+ assert registry.get_reward_model() is None
+
+
+class TestCustomModelDataclass:
+ """Test CustomModel construction."""
+
+ def test_default_registered_at(self):
+ model = CustomModel(
+ name="test", format=ModelFormat.OLLAMA, path="test"
+ )
+ assert model.registered_at != ""
+
+ def test_model_roles(self):
+ assert ModelRole.GENERAL.value == "general"
+ assert ModelRole.REWARD.value == "reward"
+ assert ModelRole.TEACHER.value == "teacher"
+ assert ModelRole.JUDGE.value == "judge"
+
+ def test_model_formats(self):
+ assert ModelFormat.GGUF.value == "gguf"
+ assert ModelFormat.SAFETENSORS.value == "safetensors"
+ assert ModelFormat.HF_CHECKPOINT.value == "hf"
+ assert ModelFormat.OLLAMA.value == "ollama"
diff --git a/tests/infrastructure/test_models_api.py b/tests/infrastructure/test_models_api.py
new file mode 100644
index 00000000..212c513b
--- /dev/null
+++ b/tests/infrastructure/test_models_api.py
@@ -0,0 +1,273 @@
+"""Tests for the custom models API routes."""
+
+from unittest.mock import patch
+
+import pytest
+
+from infrastructure.models.registry import (
+ CustomModel,
+ ModelFormat,
+ ModelRegistry,
+ ModelRole,
+)
+
+
+@pytest.fixture
+def registry(tmp_path):
+ """A fresh ModelRegistry for each test."""
+ db = tmp_path / "api_test.db"
+ with patch("infrastructure.models.registry.DB_PATH", db):
+ reg = ModelRegistry()
+ yield reg
+
+
+class TestModelsAPIList:
+ """Test listing models via the API."""
+
+ def test_list_models_empty(self, client, tmp_path):
+ db = tmp_path / "api.db"
+ with patch("infrastructure.models.registry.DB_PATH", db):
+ with patch(
+ "dashboard.routes.models.model_registry"
+ ) as mock_reg:
+ mock_reg.list_models.return_value = []
+ resp = client.get("/api/v1/models")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert "models" in data
+ assert "total" in data
+
+ def test_list_models_with_data(self, client):
+ model = CustomModel(
+ name="test-m",
+ format=ModelFormat.OLLAMA,
+ path="llama3.2",
+ role=ModelRole.GENERAL,
+ )
+ with patch(
+ "dashboard.routes.models.model_registry"
+ ) as mock_reg:
+ mock_reg.list_models.return_value = [model]
+ resp = client.get("/api/v1/models")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["total"] == 1
+ assert data["models"][0]["name"] == "test-m"
+
+
+class TestModelsAPIRegister:
+ """Test model registration via the API."""
+
+ def test_register_ollama_model(self, client):
+ with patch(
+ "dashboard.routes.models.model_registry"
+ ) as mock_reg:
+ mock_reg.register.return_value = CustomModel(
+ name="my-model",
+ format=ModelFormat.OLLAMA,
+ path="llama3.2",
+ role=ModelRole.GENERAL,
+ )
+ resp = client.post(
+ "/api/v1/models",
+ json={
+ "name": "my-model",
+ "format": "ollama",
+ "path": "llama3.2",
+ "role": "general",
+ },
+ )
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["model"]["name"] == "my-model"
+
+ def test_register_invalid_format(self, client):
+ resp = client.post(
+ "/api/v1/models",
+ json={
+ "name": "bad-model",
+ "format": "invalid_format",
+ "path": "whatever",
+ },
+ )
+ assert resp.status_code == 400
+ assert "Invalid format" in resp.json()["detail"]
+
+ def test_register_invalid_role(self, client):
+ resp = client.post(
+ "/api/v1/models",
+ json={
+ "name": "bad-model",
+ "format": "ollama",
+ "path": "llama3.2",
+ "role": "invalid_role",
+ },
+ )
+ assert resp.status_code == 400
+ assert "Invalid role" in resp.json()["detail"]
+
+
+class TestModelsAPIDelete:
+ """Test model deletion via the API."""
+
+ def test_delete_model(self, client):
+ with patch(
+ "dashboard.routes.models.model_registry"
+ ) as mock_reg:
+ mock_reg.unregister.return_value = True
+ resp = client.delete("/api/v1/models/my-model")
+ assert resp.status_code == 200
+
+ def test_delete_nonexistent(self, client):
+ with patch(
+ "dashboard.routes.models.model_registry"
+ ) as mock_reg:
+ mock_reg.unregister.return_value = False
+ resp = client.delete("/api/v1/models/nonexistent")
+ assert resp.status_code == 404
+
+
+class TestModelsAPIGet:
+ """Test getting a specific model."""
+
+ def test_get_model(self, client):
+ model = CustomModel(
+ name="my-model",
+ format=ModelFormat.OLLAMA,
+ path="llama3.2",
+ role=ModelRole.GENERAL,
+ )
+ with patch(
+ "dashboard.routes.models.model_registry"
+ ) as mock_reg:
+ mock_reg.get.return_value = model
+ resp = client.get("/api/v1/models/my-model")
+ assert resp.status_code == 200
+ assert resp.json()["name"] == "my-model"
+
+ def test_get_nonexistent(self, client):
+ with patch(
+ "dashboard.routes.models.model_registry"
+ ) as mock_reg:
+ mock_reg.get.return_value = None
+ resp = client.get("/api/v1/models/nonexistent")
+ assert resp.status_code == 404
+
+
+class TestModelsAPIAssignments:
+ """Test agent model assignment endpoints."""
+
+ def test_assign_model(self, client):
+ with patch(
+ "dashboard.routes.models.model_registry"
+ ) as mock_reg:
+ mock_reg.assign_model.return_value = True
+ resp = client.post(
+ "/api/v1/models/assignments",
+ json={"agent_id": "agent-1", "model_name": "my-model"},
+ )
+ assert resp.status_code == 200
+
+ def test_assign_nonexistent_model(self, client):
+ with patch(
+ "dashboard.routes.models.model_registry"
+ ) as mock_reg:
+ mock_reg.assign_model.return_value = False
+ resp = client.post(
+ "/api/v1/models/assignments",
+ json={"agent_id": "agent-1", "model_name": "nonexistent"},
+ )
+ assert resp.status_code == 404
+
+ def test_unassign_model(self, client):
+ with patch(
+ "dashboard.routes.models.model_registry"
+ ) as mock_reg:
+ mock_reg.unassign_model.return_value = True
+ resp = client.delete("/api/v1/models/assignments/agent-1")
+ assert resp.status_code == 200
+
+ def test_unassign_nonexistent(self, client):
+ with patch(
+ "dashboard.routes.models.model_registry"
+ ) as mock_reg:
+ mock_reg.unassign_model.return_value = False
+ resp = client.delete("/api/v1/models/assignments/nonexistent")
+ assert resp.status_code == 404
+
+ def test_list_assignments(self, client):
+ with patch(
+ "dashboard.routes.models.model_registry"
+ ) as mock_reg:
+ mock_reg.get_agent_assignments.return_value = {
+ "agent-1": "model-a",
+ "agent-2": "model-b",
+ }
+ resp = client.get("/api/v1/models/assignments/all")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["total"] == 2
+
+
+class TestModelsAPIRoles:
+ """Test role-based lookup endpoints."""
+
+ def test_get_reward_model(self, client):
+ model = CustomModel(
+ name="reward-m",
+ format=ModelFormat.OLLAMA,
+ path="deepseek-r1:1.5b",
+ role=ModelRole.REWARD,
+ )
+ with patch(
+ "dashboard.routes.models.model_registry"
+ ) as mock_reg:
+ mock_reg.get_reward_model.return_value = model
+ resp = client.get("/api/v1/models/roles/reward")
+ assert resp.status_code == 200
+ data = resp.json()
+ assert data["reward_model"]["name"] == "reward-m"
+
+ def test_get_reward_model_none(self, client):
+ with patch(
+ "dashboard.routes.models.model_registry"
+ ) as mock_reg:
+ mock_reg.get_reward_model.return_value = None
+ resp = client.get("/api/v1/models/roles/reward")
+ assert resp.status_code == 200
+ assert resp.json()["reward_model"] is None
+
+ def test_get_teacher_model(self, client):
+ with patch(
+ "dashboard.routes.models.model_registry"
+ ) as mock_reg:
+ mock_reg.get_teacher_model.return_value = None
+ resp = client.get("/api/v1/models/roles/teacher")
+ assert resp.status_code == 200
+ assert resp.json()["teacher_model"] is None
+
+
+class TestModelsAPISetActive:
+ """Test enable/disable model endpoint."""
+
+ def test_enable_model(self, client):
+ with patch(
+ "dashboard.routes.models.model_registry"
+ ) as mock_reg:
+ mock_reg.set_active.return_value = True
+ resp = client.patch(
+ "/api/v1/models/my-model/active",
+ json={"active": True},
+ )
+ assert resp.status_code == 200
+
+ def test_disable_nonexistent(self, client):
+ with patch(
+ "dashboard.routes.models.model_registry"
+ ) as mock_reg:
+ mock_reg.set_active.return_value = False
+ resp = client.patch(
+ "/api/v1/models/nonexistent/active",
+ json={"active": False},
+ )
+ assert resp.status_code == 404
diff --git a/tests/swarm/test_reward_scoring.py b/tests/swarm/test_reward_scoring.py
new file mode 100644
index 00000000..34939ef4
--- /dev/null
+++ b/tests/swarm/test_reward_scoring.py
@@ -0,0 +1,196 @@
+"""Tests for reward model scoring in the swarm learner."""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from swarm.learner import (
+ RewardScore,
+ get_reward_scores,
+ score_output,
+)
+
+
+@pytest.fixture(autouse=True)
+def _isolate_db(tmp_path):
+ """Point the learner at a temporary database."""
+ db = tmp_path / "learner_test.db"
+ with patch("swarm.learner.DB_PATH", db):
+ yield
+
+
+class TestScoreOutput:
+ """Test the score_output function."""
+
+ def test_returns_none_when_disabled(self):
+ with patch("swarm.learner._settings") as mock_s:
+ mock_s.reward_model_enabled = False
+ result = score_output("task-1", "agent-1", "do X", "done X")
+ assert result is None
+
+ def test_returns_none_when_no_model(self):
+ with patch("swarm.learner._settings") as mock_s:
+ mock_s.reward_model_enabled = True
+ mock_s.reward_model_name = ""
+ with patch(
+ "infrastructure.models.registry.model_registry"
+ ) as mock_reg:
+ mock_reg.get_reward_model.return_value = None
+ result = score_output("task-1", "agent-1", "do X", "done X")
+ assert result is None
+
+ def test_positive_scoring(self):
+ """All votes return GOOD → score = 1.0."""
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = {"response": "GOOD"}
+
+ with patch("swarm.learner._settings") as mock_s:
+ mock_s.reward_model_enabled = True
+ mock_s.reward_model_name = "test-model"
+ mock_s.reward_model_votes = 3
+ mock_s.ollama_url = "http://localhost:11434"
+
+ with patch("requests.post", return_value=mock_response):
+ result = score_output("task-1", "agent-1", "do X", "done X")
+
+ assert result is not None
+ assert result.score == 1.0
+ assert result.positive_votes == 3
+ assert result.negative_votes == 0
+ assert result.total_votes == 3
+ assert result.model_used == "test-model"
+
+ def test_negative_scoring(self):
+ """All votes return BAD → score = -1.0."""
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = {"response": "BAD"}
+
+ with patch("swarm.learner._settings") as mock_s:
+ mock_s.reward_model_enabled = True
+ mock_s.reward_model_name = "test-model"
+ mock_s.reward_model_votes = 3
+ mock_s.ollama_url = "http://localhost:11434"
+
+ with patch("requests.post", return_value=mock_response):
+ result = score_output("task-1", "agent-1", "do X", "bad output")
+
+ assert result is not None
+ assert result.score == -1.0
+ assert result.negative_votes == 3
+
+ def test_mixed_scoring(self):
+ """2 GOOD + 1 BAD → score ≈ 0.33."""
+ responses = []
+ for text in ["GOOD", "GOOD", "BAD"]:
+ resp = MagicMock()
+ resp.status_code = 200
+ resp.json.return_value = {"response": text}
+ responses.append(resp)
+
+ with patch("swarm.learner._settings") as mock_s:
+ mock_s.reward_model_enabled = True
+ mock_s.reward_model_name = "test-model"
+ mock_s.reward_model_votes = 3
+ mock_s.ollama_url = "http://localhost:11434"
+
+ with patch("requests.post", side_effect=responses):
+ result = score_output("task-1", "agent-1", "do X", "ok output")
+
+ assert result is not None
+ assert abs(result.score - (1 / 3)) < 0.01
+ assert result.positive_votes == 2
+ assert result.negative_votes == 1
+
+ def test_uses_registry_reward_model(self):
+ """Falls back to registry reward model when setting is empty."""
+ mock_model = MagicMock()
+ mock_model.path = "registry-reward-model"
+ mock_model.format = MagicMock()
+ mock_model.format.value = "ollama"
+
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = {"response": "GOOD"}
+
+ with patch("swarm.learner._settings") as mock_s:
+ mock_s.reward_model_enabled = True
+ mock_s.reward_model_name = ""
+ mock_s.reward_model_votes = 1
+ mock_s.ollama_url = "http://localhost:11434"
+
+ with patch(
+ "infrastructure.models.registry.model_registry"
+ ) as mock_reg:
+ mock_reg.get_reward_model.return_value = mock_model
+
+ with patch("requests.post", return_value=mock_response):
+ result = score_output("task-1", "agent-1", "do X", "ok")
+
+ assert result is not None
+ assert result.model_used == "registry-reward-model"
+
+
+class TestGetRewardScores:
+ """Test retrieving historical reward scores."""
+
+ def test_empty_history(self):
+ scores = get_reward_scores()
+ assert scores == []
+
+ def test_scores_persisted(self):
+ """Scores from score_output are retrievable."""
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = {"response": "GOOD"}
+
+ with patch("swarm.learner._settings") as mock_s:
+ mock_s.reward_model_enabled = True
+ mock_s.reward_model_name = "test-model"
+ mock_s.reward_model_votes = 1
+ mock_s.ollama_url = "http://localhost:11434"
+
+ with patch("requests.post", return_value=mock_response):
+ score_output("task-1", "agent-1", "do X", "done X")
+
+ scores = get_reward_scores()
+ assert len(scores) == 1
+ assert scores[0]["task_id"] == "task-1"
+ assert scores[0]["agent_id"] == "agent-1"
+ assert scores[0]["score"] == 1.0
+
+ def test_filter_by_agent(self):
+ """Filter scores by agent_id."""
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = {"response": "GOOD"}
+
+ with patch("swarm.learner._settings") as mock_s:
+ mock_s.reward_model_enabled = True
+ mock_s.reward_model_name = "test-model"
+ mock_s.reward_model_votes = 1
+ mock_s.ollama_url = "http://localhost:11434"
+
+ with patch("requests.post", return_value=mock_response):
+ score_output("task-1", "agent-1", "task A", "output A")
+ score_output("task-2", "agent-2", "task B", "output B")
+
+ agent1_scores = get_reward_scores(agent_id="agent-1")
+ assert len(agent1_scores) == 1
+ assert agent1_scores[0]["agent_id"] == "agent-1"
+
+
+class TestRewardScoreDataclass:
+ """Test RewardScore construction."""
+
+ def test_create_score(self):
+ score = RewardScore(
+ score=0.5,
+ positive_votes=3,
+ negative_votes=1,
+ total_votes=4,
+ model_used="test-model",
+ )
+ assert score.score == 0.5
+ assert score.total_votes == 4