feat: add custom weights, model registry, per-agent models, and reward scoring

Inspired by OpenClaw-RL's multi-model orchestration, this adds four
features for custom model management:

1. Custom model registry (infrastructure/models/registry.py) — SQLite-backed
   registry for GGUF, safetensors, HF checkpoint, and Ollama models with
   role-based lookups (general, reward, teacher, judge).

2. Per-agent model assignment — each swarm persona can use a different model
   instead of sharing the global default. Resolved via registry assignment >
   persona default > global default.

3. Runtime model management API (/api/v1/models) — REST endpoints to register,
   list, assign, enable/disable, and remove custom models without restart.
   Includes a dashboard page at /models.

4. Reward model scoring (PRM-style) — majority-vote quality evaluation of
   agent outputs using a configurable reward model. Scores persist in SQLite
   and feed into the swarm learner.

New config settings: custom_weights_dir, reward_model_enabled,
reward_model_name, reward_model_votes.

54 new tests covering registry CRUD, API endpoints, agent assignments,
role lookups, and reward scoring.

https://claude.ai/code/session_01V4iTozMwcE2gjfnCJdCugC
This commit is contained in:
Claude
2026-02-27 01:08:03 +00:00
parent e4d5ec5ed4
commit 211c54bc8c
13 changed files with 1603 additions and 1 deletions

View File

@@ -68,6 +68,37 @@ providers:
- name: claude-3-sonnet-20240229
context_window: 200000
# ── Custom Models ──────────────────────────────────────────────────────
# Register custom model weights for per-agent assignment.
# Supports GGUF (Ollama), safetensors, and HuggingFace checkpoint dirs.
# Models can also be registered at runtime via the /api/v1/models API.
#
# Roles: general (default inference), reward (PRM scoring),
# teacher (distillation), judge (output evaluation)
custom_models: []
# Example entries:
# - name: my-finetuned-llama
# format: gguf
# path: /path/to/model.gguf
# role: general
# context_window: 8192
# description: "Fine-tuned Llama for code tasks"
#
# - name: reward-model
# format: ollama
# path: deepseek-r1:1.5b
# role: reward
# context_window: 32000
# description: "Process reward model for scoring outputs"
# ── Agent Model Assignments ─────────────────────────────────────────────
# Map persona agent IDs to specific models.
# Agents without an assignment use the global default (ollama_model).
agent_model_assignments: {}
# Example:
# persona-forge: my-finetuned-llama
# persona-echo: deepseek-r1:1.5b
# Cost tracking (optional, for budget monitoring)
cost_tracking:
enabled: true

View File

@@ -103,6 +103,17 @@ class Settings(BaseSettings):
work_orders_auto_execute: bool = False # Master switch for auto-execution
work_orders_auto_threshold: str = "low" # Max priority that auto-executes: "low" | "medium" | "high" | "none"
# ── Custom Weights & Models ──────────────────────────────────────
# Directory for custom model weights (GGUF, safetensors, HF checkpoints).
# Models placed here can be registered at runtime and assigned to agents.
custom_weights_dir: str = "data/models"
# Enable the reward model for scoring agent outputs (PRM-style).
reward_model_enabled: bool = False
# Reward model name (must be available via Ollama or a custom weight path).
reward_model_name: str = ""
# Minimum votes for majority-vote reward scoring (odd number recommended).
reward_model_votes: int = 3
# ── Browser Local Models (iPhone / WebGPU) ───────────────────────
# Enable in-browser LLM inference via WebLLM for offline iPhone use.
# When enabled, the mobile dashboard loads a small model directly

View File

@@ -36,6 +36,8 @@ from dashboard.routes.self_coding import router as self_coding_router
from dashboard.routes.self_coding import self_modify_router
from dashboard.routes.hands import router as hands_router
from dashboard.routes.grok import router as grok_router
from dashboard.routes.models import router as models_router
from dashboard.routes.models import api_router as models_api_router
from infrastructure.router.api import router as cascade_router
logging.basicConfig(
@@ -208,6 +210,8 @@ app.include_router(tasks_router)
app.include_router(scripture_router)
app.include_router(hands_router)
app.include_router(grok_router)
app.include_router(models_router)
app.include_router(models_api_router)
app.include_router(cascade_router)

View File

@@ -0,0 +1,272 @@
"""Custom model management routes — register, list, assign, and swap models.
Provides a REST API for managing custom model weights and their assignment
to swarm agents. Inspired by OpenClaw-RL's multi-model orchestration.
"""
import logging
from pathlib import Path
from typing import Any, Optional
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from config import settings
from infrastructure.models.registry import (
CustomModel,
ModelFormat,
ModelRegistry,
ModelRole,
model_registry,
)
logger = logging.getLogger(__name__)
# Two routers: `router` serves the HTML dashboard page at /models, while
# `api_router` serves the JSON REST API under /api/v1/models.
router = APIRouter(prefix="/models", tags=["models"])
api_router = APIRouter(prefix="/api/v1/models", tags=["models-api"])
# Templates directory is shared with the rest of the dashboard package.
templates = Jinja2Templates(directory=str(Path(__file__).parent.parent / "templates"))
# ── Pydantic schemas ──────────────────────────────────────────────────────────
class RegisterModelRequest(BaseModel):
    """Request body for model registration."""
    name: str  # Unique registry key for the model
    format: str  # gguf, safetensors, hf, ollama
    path: str  # Filesystem path, or an Ollama model name for format=ollama
    role: str = "general"  # general | reward | teacher | judge
    context_window: int = 4096
    description: str = ""
    default_temperature: float = 0.7  # Per-model generation default
    max_tokens: int = 2048  # Per-model generation cap
class AssignModelRequest(BaseModel):
    """Request body for assigning a model to an agent."""
    agent_id: str  # Swarm persona/agent identifier
    model_name: str  # Must already be registered in the ModelRegistry
class SetActiveRequest(BaseModel):
    """Request body for enabling/disabling a model."""
    active: bool  # True to enable, False to disable; model stays registered
# ── API endpoints ─────────────────────────────────────────────────────────────
@api_router.get("")
async def list_models(role: Optional[str] = None) -> dict[str, Any]:
"""List all registered custom models."""
model_role = ModelRole(role) if role else None
models = model_registry.list_models(role=model_role)
return {
"models": [
{
"name": m.name,
"format": m.format.value,
"path": m.path,
"role": m.role.value,
"context_window": m.context_window,
"description": m.description,
"active": m.active,
"registered_at": m.registered_at,
"default_temperature": m.default_temperature,
"max_tokens": m.max_tokens,
}
for m in models
],
"total": len(models),
"weights_dir": settings.custom_weights_dir,
}
@api_router.post("")
async def register_model(request: RegisterModelRequest) -> dict[str, Any]:
"""Register a new custom model."""
try:
fmt = ModelFormat(request.format)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid format: {request.format}. "
f"Choose from: {[f.value for f in ModelFormat]}",
)
try:
role = ModelRole(request.role)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid role: {request.role}. "
f"Choose from: {[r.value for r in ModelRole]}",
)
# Validate path exists for non-Ollama formats
if fmt != ModelFormat.OLLAMA:
weight_path = Path(request.path)
if not weight_path.exists():
raise HTTPException(
status_code=400,
detail=f"Weight path does not exist: {request.path}",
)
model = CustomModel(
name=request.name,
format=fmt,
path=request.path,
role=role,
context_window=request.context_window,
description=request.description,
default_temperature=request.default_temperature,
max_tokens=request.max_tokens,
)
registered = model_registry.register(model)
return {
"message": f"Model {registered.name} registered",
"model": {
"name": registered.name,
"format": registered.format.value,
"role": registered.role.value,
"path": registered.path,
},
}
@api_router.get("/{model_name}")
async def get_model(model_name: str) -> dict[str, Any]:
"""Get details of a specific model."""
model = model_registry.get(model_name)
if not model:
raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
return {
"name": model.name,
"format": model.format.value,
"path": model.path,
"role": model.role.value,
"context_window": model.context_window,
"description": model.description,
"active": model.active,
"registered_at": model.registered_at,
"default_temperature": model.default_temperature,
"max_tokens": model.max_tokens,
}
@api_router.delete("/{model_name}")
async def unregister_model(model_name: str) -> dict[str, str]:
"""Remove a model from the registry."""
if not model_registry.unregister(model_name):
raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
return {"message": f"Model {model_name} unregistered"}
@api_router.patch("/{model_name}/active")
async def set_model_active(
model_name: str, request: SetActiveRequest
) -> dict[str, str]:
"""Enable or disable a model."""
if not model_registry.set_active(model_name, request.active):
raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
state = "enabled" if request.active else "disabled"
return {"message": f"Model {model_name} {state}"}
# ── Agent assignment endpoints ────────────────────────────────────────────────
@api_router.get("/assignments/all")
async def list_assignments() -> dict[str, Any]:
"""List all agent-to-model assignments."""
assignments = model_registry.get_agent_assignments()
return {
"assignments": [
{"agent_id": aid, "model_name": mname}
for aid, mname in assignments.items()
],
"total": len(assignments),
}
@api_router.post("/assignments")
async def assign_model(request: AssignModelRequest) -> dict[str, str]:
"""Assign a model to a swarm agent."""
if not model_registry.assign_model(request.agent_id, request.model_name):
raise HTTPException(
status_code=404,
detail=f"Model {request.model_name} not found in registry",
)
return {
"message": f"Model {request.model_name} assigned to {request.agent_id}",
}
@api_router.delete("/assignments/{agent_id}")
async def unassign_model(agent_id: str) -> dict[str, str]:
"""Remove model assignment from an agent (reverts to default)."""
if not model_registry.unassign_model(agent_id):
raise HTTPException(
status_code=404,
detail=f"No model assignment for agent {agent_id}",
)
return {"message": f"Model assignment removed for {agent_id}"}
# ── Role-based lookups ────────────────────────────────────────────────────────
@api_router.get("/roles/reward")
async def get_reward_model() -> dict[str, Any]:
"""Get the active reward (PRM) model."""
model = model_registry.get_reward_model()
if not model:
return {"reward_model": None, "reward_enabled": settings.reward_model_enabled}
return {
"reward_model": {
"name": model.name,
"format": model.format.value,
"path": model.path,
},
"reward_enabled": settings.reward_model_enabled,
}
@api_router.get("/roles/teacher")
async def get_teacher_model() -> dict[str, Any]:
"""Get the active teacher model for distillation."""
model = model_registry.get_teacher_model()
if not model:
return {"teacher_model": None}
return {
"teacher_model": {
"name": model.name,
"format": model.format.value,
"path": model.path,
},
}
# ── Dashboard page ────────────────────────────────────────────────────────────
@router.get("", response_class=HTMLResponse)
async def models_page(request: Request):
"""Custom models management dashboard page."""
models = model_registry.list_models()
assignments = model_registry.get_agent_assignments()
reward = model_registry.get_reward_model()
return templates.TemplateResponse(
request,
"models.html",
{
"page_title": "Custom Models",
"models": models,
"assignments": assignments,
"reward_model": reward,
"weights_dir": settings.custom_weights_dir,
"reward_enabled": settings.reward_model_enabled,
},
)

View File

@@ -0,0 +1,119 @@
{% extends "base.html" %}
{% block title %}Custom Models - Timmy Time{% endblock %}
{% block content %}
<div class="mc-panel">
<div class="mc-panel-header">
<h1 class="page-title">Custom Models</h1>
<p class="mc-text-secondary">Manage model weights and agent assignments</p>
</div>
<!-- Stats -->
<div class="mc-stats-row">
<div class="mc-stat-card">
<div class="mc-stat-value">{{ models|length }}</div>
<div class="mc-stat-label">Models</div>
</div>
<div class="mc-stat-card">
<div class="mc-stat-value">{{ assignments|length }}</div>
<div class="mc-stat-label">Assignments</div>
</div>
<div class="mc-stat-card">
<div class="mc-stat-value">{{ "Yes" if reward_model else "No" }}</div>
<div class="mc-stat-label">Reward Model</div>
</div>
</div>
<!-- Register Model Form -->
<div class="mc-section" style="margin-top: 1.5rem;">
<h2>Register Model</h2>
<form hx-post="/api/v1/models" hx-target="#model-result" hx-swap="innerHTML"
style="display: grid; gap: 0.5rem; max-width: 500px;">
<input name="name" placeholder="Model name" required class="mc-input" />
<select name="format" class="mc-input">
<option value="ollama">Ollama</option>
<option value="gguf">GGUF</option>
<option value="safetensors">Safetensors</option>
<option value="hf">HF Checkpoint</option>
</select>
<input name="path" placeholder="Path or Ollama model name" required class="mc-input" />
<select name="role" class="mc-input">
<option value="general">General</option>
<option value="reward">Reward (PRM)</option>
<option value="teacher">Teacher</option>
<option value="judge">Judge</option>
</select>
<input name="context_window" type="number" value="4096" class="mc-input" />
<input name="description" placeholder="Description (optional)" class="mc-input" />
<button type="submit" class="mc-btn mc-btn-primary">Register</button>
</form>
<div id="model-result" style="margin-top: 0.5rem;"></div>
</div>
<!-- Registered Models -->
<div class="mc-section" style="margin-top: 1.5rem;">
<h2>Registered Models</h2>
{% if models %}
<table class="mc-table">
<thead>
<tr>
<th>Name</th>
<th>Format</th>
<th>Role</th>
<th>Context</th>
<th>Active</th>
<th>Actions</th>
</tr>
</thead>
<tbody>
{% for m in models %}
<tr>
<td><strong>{{ m.name }}</strong></td>
<td>{{ m.format.value }}</td>
<td>{{ m.role.value }}</td>
<td>{{ m.context_window }}</td>
<td>{{ "Yes" if m.active else "No" }}</td>
<td>
<button class="mc-btn mc-btn-sm"
hx-delete="/api/v1/models/{{ m.name }}"
hx-confirm="Remove {{ m.name }}?"
hx-target="closest tr"
hx-swap="outerHTML">Remove</button>
</td>
</tr>
{% endfor %}
</tbody>
</table>
{% else %}
<p class="mc-text-secondary">No custom models registered. Use the form above or the API.</p>
{% endif %}
</div>
<!-- Agent Assignments -->
<div class="mc-section" style="margin-top: 1.5rem;">
<h2>Agent Model Assignments</h2>
{% if assignments %}
<table class="mc-table">
<thead>
<tr><th>Agent</th><th>Model</th></tr>
</thead>
<tbody>
{% for agent_id, model_name in assignments.items() %}
<tr>
<td>{{ agent_id }}</td>
<td>{{ model_name }}</td>
</tr>
{% endfor %}
</tbody>
</table>
{% else %}
<p class="mc-text-secondary">No agent-specific model assignments. All agents use the global default.</p>
{% endif %}
</div>
<div class="mc-section" style="margin-top: 1rem;">
<p class="mc-text-secondary">Weights directory: <code>{{ weights_dir }}</code></p>
</div>
</div>
{% endblock %}

View File

View File

@@ -0,0 +1,268 @@
"""Custom model registry — register, load, and manage model weights.
Tracks custom models (GGUF files, HF checkpoints, Ollama modelfiles)
and their assignment to swarm agents. Models can be registered at
runtime via the API or pre-configured via providers.yaml.
Inspired by OpenClaw-RL's multi-model orchestration where distinct
model roles (student, teacher, judge/PRM) run on dedicated resources.
"""
import logging
import sqlite3
import threading
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
from typing import Optional
from config import settings
logger = logging.getLogger(__name__)
# Registry tables live in the shared swarm SQLite database.
DB_PATH = Path("data/swarm.db")
class ModelFormat(str, Enum):
    """Weight formats the registry knows how to track."""

    GGUF = "gguf"  # Quantised weights consumable by Ollama
    SAFETENSORS = "safetensors"  # HuggingFace safetensors file
    HF_CHECKPOINT = "hf"  # Complete HuggingFace checkpoint directory
    OLLAMA = "ollama"  # Model already pulled into Ollama, referenced by name
class ModelRole(str, Enum):
    """Role a model can play in the system (OpenClaw-RL style)."""

    GENERAL = "general"  # Default agent inference
    REWARD = "reward"  # Process Reward Model (PRM) scoring
    TEACHER = "teacher"  # On-policy distillation teacher
    JUDGE = "judge"  # Output quality evaluation
@dataclass
class CustomModel:
    """A registered custom model and its per-model generation defaults."""

    name: str
    format: ModelFormat
    path: str  # Absolute path on disk, or bare Ollama model name
    role: ModelRole = ModelRole.GENERAL
    context_window: int = 4096
    description: str = ""
    registered_at: str = ""  # ISO-8601 UTC; auto-stamped when left empty
    active: bool = True
    # Per-model generation settings
    default_temperature: float = 0.7
    max_tokens: int = 2048

    def __post_init__(self):
        # Stamp registration time lazily so callers (e.g. DB loads) may
        # supply the original timestamp instead.
        self.registered_at = self.registered_at or datetime.now(timezone.utc).isoformat()
def _get_conn() -> sqlite3.Connection:
    """Open the registry database, creating the schema on first use."""
    DB_PATH.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(str(DB_PATH))
    conn.row_factory = sqlite3.Row
    # Both tables are idempotently created on every open.
    for ddl in (
        """
        CREATE TABLE IF NOT EXISTS custom_models (
        name TEXT PRIMARY KEY,
        format TEXT NOT NULL,
        path TEXT NOT NULL,
        role TEXT NOT NULL DEFAULT 'general',
        context_window INTEGER NOT NULL DEFAULT 4096,
        description TEXT NOT NULL DEFAULT '',
        registered_at TEXT NOT NULL,
        active INTEGER NOT NULL DEFAULT 1,
        default_temperature REAL NOT NULL DEFAULT 0.7,
        max_tokens INTEGER NOT NULL DEFAULT 2048
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS agent_model_assignments (
        agent_id TEXT PRIMARY KEY,
        model_name TEXT NOT NULL,
        assigned_at TEXT NOT NULL,
        FOREIGN KEY (model_name) REFERENCES custom_models(name)
        )
        """,
    ):
        conn.execute(ddl)
    conn.commit()
    return conn
class ModelRegistry:
    """Singleton registry for custom models and agent-model assignments.

    State is persisted in SQLite (via ``_get_conn``) and mirrored into
    in-memory dicts so read paths never touch the database. Mutations take
    ``self._lock`` and write through to SQLite before updating the cache.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        # In-memory cache for fast lookups
        self._models: dict[str, CustomModel] = {}
        self._agent_assignments: dict[str, str] = {}
        self._load_from_db()

    def _load_from_db(self) -> None:
        """Bootstrap cache from SQLite.

        Loads ALL models, including inactive ones. (Loading only
        ``active = 1`` rows — the previous behaviour — made a disabled
        model invisible after a restart, so it could never be re-enabled
        via set_active() nor removed via unregister().)
        """
        try:
            conn = _get_conn()
            for row in conn.execute("SELECT * FROM custom_models").fetchall():
                self._models[row["name"]] = CustomModel(
                    name=row["name"],
                    format=ModelFormat(row["format"]),
                    path=row["path"],
                    role=ModelRole(row["role"]),
                    context_window=row["context_window"],
                    description=row["description"],
                    registered_at=row["registered_at"],
                    active=bool(row["active"]),
                    default_temperature=row["default_temperature"],
                    max_tokens=row["max_tokens"],
                )
            for row in conn.execute("SELECT * FROM agent_model_assignments").fetchall():
                self._agent_assignments[row["agent_id"]] = row["model_name"]
            conn.close()
        except Exception as exc:
            # Best-effort bootstrap: an unreadable DB yields an empty
            # registry rather than blocking process startup.
            logger.warning("Failed to load model registry from DB: %s", exc)

    # ── Model CRUD ─────────────────────────────────────────────────────────
    def register(self, model: CustomModel) -> CustomModel:
        """Register a new custom model (replacing any same-named entry)."""
        with self._lock:
            conn = _get_conn()
            conn.execute(
                """
                INSERT OR REPLACE INTO custom_models
                (name, format, path, role, context_window, description,
                 registered_at, active, default_temperature, max_tokens)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    model.name, model.format.value, model.path,
                    model.role.value, model.context_window, model.description,
                    model.registered_at, int(model.active),
                    model.default_temperature, model.max_tokens,
                ),
            )
            conn.commit()
            conn.close()
            self._models[model.name] = model
            logger.info("Registered model: %s (%s)", model.name, model.format.value)
            return model

    def unregister(self, name: str) -> bool:
        """Remove a model and any agent assignments that reference it."""
        with self._lock:
            if name not in self._models:
                return False
            conn = _get_conn()
            conn.execute("DELETE FROM custom_models WHERE name = ?", (name,))
            conn.execute(
                "DELETE FROM agent_model_assignments WHERE model_name = ?", (name,)
            )
            conn.commit()
            conn.close()
            del self._models[name]
            # Remove any agent assignments using this model
            self._agent_assignments = {
                k: v for k, v in self._agent_assignments.items() if v != name
            }
            logger.info("Unregistered model: %s", name)
            return True

    def get(self, name: str) -> Optional[CustomModel]:
        """Look up a model by name (active or not)."""
        return self._models.get(name)

    def list_models(self, role: Optional[ModelRole] = None) -> list[CustomModel]:
        """List all registered models, optionally filtered by role.

        Includes disabled models; callers that need only enabled ones
        should filter on ``model.active`` (as the role lookups below do).
        """
        models = list(self._models.values())
        if role is not None:
            models = [m for m in models if m.role == role]
        return models

    def set_active(self, name: str, active: bool) -> bool:
        """Enable or disable a model without removing it.

        Returns False when the model is unknown. The cache lookup happens
        under the lock so it cannot race a concurrent unregister().
        """
        with self._lock:
            model = self._models.get(name)
            if not model:
                return False
            conn = _get_conn()
            conn.execute(
                "UPDATE custom_models SET active = ? WHERE name = ?",
                (int(active), name),
            )
            conn.commit()
            conn.close()
            model.active = active
            return True

    # ── Agent-model assignments ────────────────────────────────────────────
    def assign_model(self, agent_id: str, model_name: str) -> bool:
        """Assign a specific model to an agent; False if model is unknown."""
        if model_name not in self._models:
            return False
        with self._lock:
            now = datetime.now(timezone.utc).isoformat()
            conn = _get_conn()
            conn.execute(
                """
                INSERT OR REPLACE INTO agent_model_assignments
                (agent_id, model_name, assigned_at)
                VALUES (?, ?, ?)
                """,
                (agent_id, model_name, now),
            )
            conn.commit()
            conn.close()
            self._agent_assignments[agent_id] = model_name
            logger.info("Assigned model %s to agent %s", model_name, agent_id)
            return True

    def unassign_model(self, agent_id: str) -> bool:
        """Remove model assignment from an agent (falls back to default)."""
        with self._lock:
            if agent_id not in self._agent_assignments:
                return False
            conn = _get_conn()
            conn.execute(
                "DELETE FROM agent_model_assignments WHERE agent_id = ?",
                (agent_id,),
            )
            conn.commit()
            conn.close()
            del self._agent_assignments[agent_id]
            return True

    def get_agent_model(self, agent_id: str) -> Optional[CustomModel]:
        """Get the model assigned to an agent, or None for the default."""
        model_name = self._agent_assignments.get(agent_id)
        if model_name:
            return self._models.get(model_name)
        return None

    def get_agent_assignments(self) -> dict[str, str]:
        """Return a copy of all agent-to-model assignments."""
        return dict(self._agent_assignments)

    # ── Role-based lookups ─────────────────────────────────────────────────
    def get_reward_model(self) -> Optional[CustomModel]:
        """Get the first active reward/PRM model, if any."""
        reward_models = self.list_models(role=ModelRole.REWARD)
        active = [m for m in reward_models if m.active]
        return active[0] if active else None

    def get_teacher_model(self) -> Optional[CustomModel]:
        """Get the first active teacher model for distillation, if any."""
        teacher_models = self.list_models(role=ModelRole.TEACHER)
        active = [m for m in teacher_models if m.active]
        return active[0] if active else None
# Module-level singleton — import this instance rather than constructing a
# new ModelRegistry, so every caller shares one cache and one lock.
model_registry = ModelRegistry()

View File

@@ -251,3 +251,193 @@ def learned_keywords(agent_id: str) -> list[dict]:
results.append({"keyword": kw, "wins": wins, "failures": fails, "net": wins - fails})
results.sort(key=lambda x: x["net"], reverse=True)
return results
# ── Reward model scoring (PRM-style) ─────────────────────────────────────────
import logging as _logging
from config import settings as _settings
_reward_logger = _logging.getLogger(__name__ + ".reward")
@dataclass
class RewardScore:
    """Outcome of a majority-vote reward-model evaluation."""

    score: float  # Normalised score in [-1.0, 1.0]
    positive_votes: int
    negative_votes: int
    total_votes: int
    model_used: str
def _ensure_reward_table() -> None:
    """Create the reward_scores table in SQLite if it does not exist yet."""
    ddl = """
        CREATE TABLE IF NOT EXISTS reward_scores (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        task_id TEXT NOT NULL,
        agent_id TEXT NOT NULL,
        output_text TEXT NOT NULL,
        score REAL NOT NULL,
        positive INTEGER NOT NULL,
        negative INTEGER NOT NULL,
        total INTEGER NOT NULL,
        model_used TEXT NOT NULL,
        scored_at TEXT NOT NULL DEFAULT (datetime('now'))
        )
        """
    conn = _get_conn()
    conn.execute(ddl)
    conn.commit()
    conn.close()
def score_output(
    task_id: str,
    agent_id: str,
    task_description: str,
    output_text: str,
) -> Optional[RewardScore]:
    """Score an agent's output using the reward model (majority vote).

    Calls the reward model N times (settings.reward_model_votes) with a
    quality-evaluation prompt. Each vote is +1 (good) or -1 (bad).
    Final score is (positive - negative) / total, in [-1.0, 1.0].
    Returns None if the reward model is disabled or unavailable.
    """
    if not _settings.reward_model_enabled:
        return None
    # Resolve model name: explicit setting > registry reward model > skip
    model_name = _settings.reward_model_name
    if not model_name:
        try:
            # Imported lazily to avoid a hard dependency on the registry module.
            from infrastructure.models.registry import model_registry
            reward = model_registry.get_reward_model()
            if reward:
                # Ollama models are addressed by their path field (the Ollama
                # model name); other formats by their registry name.
                model_name = reward.path if reward.format.value == "ollama" else reward.name
        except Exception:
            pass  # Registry unavailable — treated the same as "no model".
    if not model_name:
        _reward_logger.debug("No reward model configured, skipping scoring")
        return None
    num_votes = max(1, _settings.reward_model_votes)
    positive = 0
    negative = 0
    # Identical prompt for every vote; output is truncated to 2000 chars.
    prompt = (
        f"You are a quality evaluator. Rate the following agent output.\n\n"
        f"TASK: {task_description}\n\n"
        f"OUTPUT:\n{output_text[:2000]}\n\n"
        f"Is this output correct, helpful, and complete? "
        f"Reply with exactly one word: GOOD or BAD."
    )
    try:
        import requests as _req
        ollama_url = _settings.ollama_url
        for _ in range(num_votes):
            try:
                resp = _req.post(
                    f"{ollama_url}/api/generate",
                    json={
                        "model": model_name,
                        "prompt": prompt,
                        "stream": False,
                        "options": {"temperature": 0.3, "num_predict": 10},
                    },
                    timeout=30,
                )
                if resp.status_code == 200:
                    # NOTE(review): substring match means replies like
                    # "NOT GOOD" also count as positive, and reasoning models
                    # (e.g. deepseek-r1) may emit extra text — confirm this
                    # heuristic against real model output.
                    answer = resp.json().get("response", "").strip().upper()
                    if "GOOD" in answer:
                        positive += 1
                    else:
                        negative += 1
                else:
                    negative += 1  # Treat errors as negative conservatively
            except Exception as vote_exc:
                # A failed vote (timeout, connection error) counts as negative.
                _reward_logger.debug("Vote failed: %s", vote_exc)
                negative += 1
    except ImportError:
        _reward_logger.warning("requests library not available for reward scoring")
        return None
    total = positive + negative
    if total == 0:
        return None
    score = (positive - negative) / total
    result = RewardScore(
        score=score,
        positive_votes=positive,
        negative_votes=negative,
        total_votes=total,
        model_used=model_name,
    )
    # Persist to DB (best-effort; scoring result is returned regardless).
    try:
        _ensure_reward_table()
        conn = _get_conn()
        conn.execute(
            """
            INSERT INTO reward_scores
            (task_id, agent_id, output_text, score, positive, negative, total, model_used)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                task_id, agent_id, output_text[:5000],
                score, positive, negative, total, model_name,
            ),
        )
        conn.commit()
        conn.close()
    except Exception as db_exc:
        _reward_logger.warning("Failed to persist reward score: %s", db_exc)
    _reward_logger.info(
        "Scored task %s agent %s: %.2f (%d+/%d- of %d votes)",
        task_id, agent_id, score, positive, negative, total,
    )
    return result
def get_reward_scores(
    agent_id: Optional[str] = None, limit: int = 50
) -> list[dict]:
    """Retrieve historical reward scores, newest first, optionally per agent."""
    _ensure_reward_table()
    conn = _get_conn()
    if agent_id:
        sql = "SELECT * FROM reward_scores WHERE agent_id = ? ORDER BY id DESC LIMIT ?"
        params: tuple = (agent_id, limit)
    else:
        sql = "SELECT * FROM reward_scores ORDER BY id DESC LIMIT ?"
        params = (limit,)
    rows = conn.execute(sql, params).fetchall()
    conn.close()
    columns = (
        "task_id", "agent_id", "score", "positive",
        "negative", "total", "model_used", "scored_at",
    )
    return [{col: row[col] for col in columns} for row in rows]

View File

@@ -51,6 +51,16 @@ class PersonaNode(SwarmNode):
self._meta = meta
self._persona_id = persona_id
self._use_learner = use_learner
# Resolve model: registry assignment > persona default > global default
self._model_name: Optional[str] = meta.get("model")
try:
from infrastructure.models.registry import model_registry
assigned = model_registry.get_agent_model(agent_id)
if assigned:
self._model_name = assigned.name
except Exception:
pass # Graceful degradation — use persona/global default
# Initialize tool executor for task execution
self._tool_executor: Optional[ToolExecutor] = None
@@ -213,6 +223,11 @@ class PersonaNode(SwarmNode):
"""Return the task ID currently being executed, if any."""
return self._current_task
@property
def model_name(self) -> Optional[str]:
"""Return the model this agent uses, or None for global default."""
return self._model_name
@property
def tool_capabilities(self) -> list[str]:
"""Return list of available tool names."""

View File

@@ -14,7 +14,7 @@ from __future__ import annotations
from typing import TypedDict
class PersonaMeta(TypedDict):
class PersonaMeta(TypedDict, total=False):
id: str
name: str
role: str
@@ -24,6 +24,11 @@ class PersonaMeta(TypedDict):
bid_base: int # typical bid when task matches persona
bid_jitter: int # ± random jitter added to bid_base
preferred_keywords: list[str]
# Optional: custom model override for this persona.
# When set, this persona uses this model instead of the global default.
# Value is a model name registered in the ModelRegistry, or an Ollama
# model name like "llama3.2" or "deepseek-r1:1.5b".
model: str
PERSONAS: dict[str, PersonaMeta] = {

View File

@@ -0,0 +1,217 @@
"""Tests for the custom model registry."""
import sqlite3
from pathlib import Path
from unittest.mock import patch
import pytest
from infrastructure.models.registry import (
CustomModel,
ModelFormat,
ModelRegistry,
ModelRole,
)
@pytest.fixture
def registry(tmp_path):
    """Fresh ModelRegistry backed by a throwaway SQLite database."""
    test_db = tmp_path / "test.db"
    # Keep the patch active for the whole test so every registry method
    # hits the temporary database.
    with patch("infrastructure.models.registry.DB_PATH", test_db):
        yield ModelRegistry()
@pytest.fixture
def sample_model():
    """A general-purpose Ollama-backed CustomModel used across CRUD tests."""
    return CustomModel(
        name="test-llama",
        format=ModelFormat.OLLAMA,
        path="llama3.2",
        role=ModelRole.GENERAL,
        context_window=8192,
        description="Test model",
    )
@pytest.fixture
def reward_model():
    """A CustomModel with the reward (PRM) role for role-lookup tests."""
    return CustomModel(
        name="test-reward",
        format=ModelFormat.OLLAMA,
        path="deepseek-r1:1.5b",
        role=ModelRole.REWARD,
        context_window=32000,
        description="Test reward model",
    )
class TestModelCRUD:
    """Registration, lookup, activation toggling, and removal of models."""

    def test_register_model(self, registry, sample_model):
        result = registry.register(sample_model)
        assert result.name == "test-llama"
        assert result.format == ModelFormat.OLLAMA

    def test_get_model(self, registry, sample_model):
        registry.register(sample_model)
        fetched = registry.get("test-llama")
        assert fetched is not None
        assert fetched.name == "test-llama"
        assert fetched.path == "llama3.2"

    def test_get_nonexistent_model(self, registry):
        assert registry.get("nonexistent") is None

    def test_list_models(self, registry, sample_model, reward_model):
        registry.register(sample_model)
        registry.register(reward_model)
        assert len(registry.list_models()) == 2

    def test_list_models_by_role(self, registry, sample_model, reward_model):
        registry.register(sample_model)
        registry.register(reward_model)
        general_only = registry.list_models(role=ModelRole.GENERAL)
        assert len(general_only) == 1
        assert general_only[0].name == "test-llama"
        reward_only = registry.list_models(role=ModelRole.REWARD)
        assert len(reward_only) == 1
        assert reward_only[0].name == "test-reward"

    def test_unregister_model(self, registry, sample_model):
        registry.register(sample_model)
        assert registry.unregister("test-llama") is True
        assert registry.get("test-llama") is None

    def test_unregister_nonexistent(self, registry):
        assert registry.unregister("nonexistent") is False

    def test_set_active(self, registry, sample_model):
        registry.register(sample_model)
        assert registry.set_active("test-llama", False) is True
        assert registry.get("test-llama").active is False
        assert registry.set_active("test-llama", True) is True
        assert registry.get("test-llama").active is True

    def test_set_active_nonexistent(self, registry):
        assert registry.set_active("nonexistent", True) is False

    def test_register_replaces_existing(self, registry, sample_model):
        registry.register(sample_model)
        replacement = CustomModel(
            name="test-llama",
            format=ModelFormat.GGUF,
            path="/new/path.gguf",
            role=ModelRole.GENERAL,
            description="Updated model",
        )
        registry.register(replacement)
        current = registry.get("test-llama")
        assert current.format == ModelFormat.GGUF
        assert current.path == "/new/path.gguf"
class TestAgentAssignments:
    """Test agent-to-model assignment management."""

    def test_assign_model(self, registry, sample_model):
        """An agent can be pointed at a registered model and resolved back."""
        registry.register(sample_model)
        assert registry.assign_model("agent-1", "test-llama") is True
        assigned = registry.get_agent_model("agent-1")
        assert assigned is not None
        assert assigned.name == "test-llama"

    def test_assign_nonexistent_model(self, registry):
        """Assigning a model that was never registered fails."""
        assert registry.assign_model("agent-1", "nonexistent") is False

    def test_unassign_model(self, registry, sample_model):
        """Unassigning clears the agent's model resolution."""
        registry.register(sample_model)
        registry.assign_model("agent-1", "test-llama")
        assert registry.unassign_model("agent-1") is True
        assert registry.get_agent_model("agent-1") is None

    def test_unassign_nonexistent(self, registry):
        """Unassigning an agent that has no assignment reports failure."""
        assert registry.unassign_model("agent-1") is False

    def test_get_agent_model_none(self, registry):
        """An agent with no assignment resolves to None."""
        assert registry.get_agent_model("agent-1") is None

    def test_get_all_assignments(self, registry, sample_model, reward_model):
        """get_agent_assignments maps every agent to its model name."""
        registry.register(sample_model)
        registry.register(reward_model)
        registry.assign_model("agent-1", "test-llama")
        registry.assign_model("agent-2", "test-reward")
        expected = {"agent-1": "test-llama", "agent-2": "test-reward"}
        assert registry.get_agent_assignments() == expected

    def test_unregister_removes_assignments(self, registry, sample_model):
        """Deleting a model cascades to any agent assignments pointing at it."""
        registry.register(sample_model)
        registry.assign_model("agent-1", "test-llama")
        registry.unregister("test-llama")
        assert registry.get_agent_model("agent-1") is None
        assert registry.get_agent_assignments() == {}
class TestRoleLookups:
    """Test role-based model lookups."""

    def test_get_reward_model(self, registry, reward_model):
        """A registered reward model is resolved via get_reward_model."""
        registry.register(reward_model)
        resolved = registry.get_reward_model()
        assert resolved is not None
        assert resolved.name == "test-reward"
        assert resolved.role == ModelRole.REWARD

    def test_get_reward_model_none(self, registry):
        """With no reward model registered, the lookup yields None."""
        assert registry.get_reward_model() is None

    def test_get_teacher_model(self, registry):
        """A registered teacher model is resolved via get_teacher_model."""
        teacher_model = CustomModel(
            name="teacher-model",
            format=ModelFormat.OLLAMA,
            path="teacher:latest",
            role=ModelRole.TEACHER,
        )
        registry.register(teacher_model)
        resolved = registry.get_teacher_model()
        assert resolved is not None
        assert resolved.name == "teacher-model"

    def test_get_teacher_model_none(self, registry):
        """With no teacher model registered, the lookup yields None."""
        assert registry.get_teacher_model() is None

    def test_inactive_reward_model_not_returned(self, registry, reward_model):
        """Deactivated reward models are excluded from role lookups."""
        registry.register(reward_model)
        registry.set_active("test-reward", False)
        assert registry.get_reward_model() is None
class TestCustomModelDataclass:
    """Test CustomModel construction."""

    def test_default_registered_at(self):
        """A registration timestamp is filled in when none is supplied."""
        model = CustomModel(name="test", format=ModelFormat.OLLAMA, path="test")
        assert model.registered_at != ""

    def test_model_roles(self):
        """Each role enum member carries its expected wire value."""
        expected_roles = {
            ModelRole.GENERAL: "general",
            ModelRole.REWARD: "reward",
            ModelRole.TEACHER: "teacher",
            ModelRole.JUDGE: "judge",
        }
        for role, wire_value in expected_roles.items():
            assert role.value == wire_value

    def test_model_formats(self):
        """Each format enum member carries its expected wire value."""
        expected_formats = {
            ModelFormat.GGUF: "gguf",
            ModelFormat.SAFETENSORS: "safetensors",
            ModelFormat.HF_CHECKPOINT: "hf",
            ModelFormat.OLLAMA: "ollama",
        }
        for fmt, wire_value in expected_formats.items():
            assert fmt.value == wire_value

View File

@@ -0,0 +1,273 @@
"""Tests for the custom models API routes."""
from unittest.mock import patch, MagicMock
import pytest
from infrastructure.models.registry import (
CustomModel,
ModelFormat,
ModelRegistry,
ModelRole,
)
@pytest.fixture
def registry(tmp_path):
    """A fresh ModelRegistry backed by a throwaway SQLite file."""
    db_file = tmp_path / "api_test.db"
    # Redirect the module-level DB path so each test gets isolated storage.
    with patch("infrastructure.models.registry.DB_PATH", db_file):
        yield ModelRegistry()
class TestModelsAPIList:
    """Test listing models via the API."""

    def test_list_models_empty(self, client, tmp_path):
        """An empty registry still yields the models/total envelope."""
        db_file = tmp_path / "api.db"
        with patch("infrastructure.models.registry.DB_PATH", db_file):
            with patch("dashboard.routes.models.model_registry") as registry_mock:
                registry_mock.list_models.return_value = []
                response = client.get("/api/v1/models")
        assert response.status_code == 200
        payload = response.json()
        assert "models" in payload
        assert "total" in payload

    def test_list_models_with_data(self, client):
        """Registered models appear in the listing with their names."""
        stub = CustomModel(
            name="test-m",
            format=ModelFormat.OLLAMA,
            path="llama3.2",
            role=ModelRole.GENERAL,
        )
        with patch("dashboard.routes.models.model_registry") as registry_mock:
            registry_mock.list_models.return_value = [stub]
            response = client.get("/api/v1/models")
        assert response.status_code == 200
        payload = response.json()
        assert payload["total"] == 1
        assert payload["models"][0]["name"] == "test-m"
class TestModelsAPIRegister:
    """Test model registration via the API."""

    def test_register_ollama_model(self, client):
        """A valid registration echoes the stored model back."""
        with patch("dashboard.routes.models.model_registry") as registry_mock:
            registry_mock.register.return_value = CustomModel(
                name="my-model",
                format=ModelFormat.OLLAMA,
                path="llama3.2",
                role=ModelRole.GENERAL,
            )
            response = client.post(
                "/api/v1/models",
                json={
                    "name": "my-model",
                    "format": "ollama",
                    "path": "llama3.2",
                    "role": "general",
                },
            )
        assert response.status_code == 200
        assert response.json()["model"]["name"] == "my-model"

    def test_register_invalid_format(self, client):
        """An unknown format string is rejected with a 400."""
        response = client.post(
            "/api/v1/models",
            json={
                "name": "bad-model",
                "format": "invalid_format",
                "path": "whatever",
            },
        )
        assert response.status_code == 400
        assert "Invalid format" in response.json()["detail"]

    def test_register_invalid_role(self, client):
        """An unknown role string is rejected with a 400."""
        response = client.post(
            "/api/v1/models",
            json={
                "name": "bad-model",
                "format": "ollama",
                "path": "llama3.2",
                "role": "invalid_role",
            },
        )
        assert response.status_code == 400
        assert "Invalid role" in response.json()["detail"]
class TestModelsAPIDelete:
    """Test model deletion via the API."""

    def test_delete_model(self, client):
        """Deleting a known model returns success."""
        with patch("dashboard.routes.models.model_registry") as registry_mock:
            registry_mock.unregister.return_value = True
            response = client.delete("/api/v1/models/my-model")
        assert response.status_code == 200

    def test_delete_nonexistent(self, client):
        """Deleting an unknown model yields a 404."""
        with patch("dashboard.routes.models.model_registry") as registry_mock:
            registry_mock.unregister.return_value = False
            response = client.delete("/api/v1/models/nonexistent")
        assert response.status_code == 404
class TestModelsAPIGet:
    """Test getting a specific model."""

    def test_get_model(self, client):
        """A known model is returned by name."""
        stub = CustomModel(
            name="my-model",
            format=ModelFormat.OLLAMA,
            path="llama3.2",
            role=ModelRole.GENERAL,
        )
        with patch("dashboard.routes.models.model_registry") as registry_mock:
            registry_mock.get.return_value = stub
            response = client.get("/api/v1/models/my-model")
        assert response.status_code == 200
        assert response.json()["name"] == "my-model"

    def test_get_nonexistent(self, client):
        """Requesting an unknown name yields a 404."""
        with patch("dashboard.routes.models.model_registry") as registry_mock:
            registry_mock.get.return_value = None
            response = client.get("/api/v1/models/nonexistent")
        assert response.status_code == 404
class TestModelsAPIAssignments:
    """Test agent model assignment endpoints."""

    def test_assign_model(self, client):
        """Assigning a known model to an agent succeeds."""
        with patch("dashboard.routes.models.model_registry") as registry_mock:
            registry_mock.assign_model.return_value = True
            response = client.post(
                "/api/v1/models/assignments",
                json={"agent_id": "agent-1", "model_name": "my-model"},
            )
        assert response.status_code == 200

    def test_assign_nonexistent_model(self, client):
        """Assigning an unknown model yields a 404."""
        with patch("dashboard.routes.models.model_registry") as registry_mock:
            registry_mock.assign_model.return_value = False
            response = client.post(
                "/api/v1/models/assignments",
                json={"agent_id": "agent-1", "model_name": "nonexistent"},
            )
        assert response.status_code == 404

    def test_unassign_model(self, client):
        """Removing an existing assignment succeeds."""
        with patch("dashboard.routes.models.model_registry") as registry_mock:
            registry_mock.unassign_model.return_value = True
            response = client.delete("/api/v1/models/assignments/agent-1")
        assert response.status_code == 200

    def test_unassign_nonexistent(self, client):
        """Removing a missing assignment yields a 404."""
        with patch("dashboard.routes.models.model_registry") as registry_mock:
            registry_mock.unassign_model.return_value = False
            response = client.delete("/api/v1/models/assignments/nonexistent")
        assert response.status_code == 404

    def test_list_assignments(self, client):
        """All assignments are returned along with a total count."""
        with patch("dashboard.routes.models.model_registry") as registry_mock:
            registry_mock.get_agent_assignments.return_value = {
                "agent-1": "model-a",
                "agent-2": "model-b",
            }
            response = client.get("/api/v1/models/assignments/all")
        assert response.status_code == 200
        assert response.json()["total"] == 2
class TestModelsAPIRoles:
    """Test role-based lookup endpoints."""

    def test_get_reward_model(self, client):
        """The configured reward model is returned by the role endpoint."""
        stub = CustomModel(
            name="reward-m",
            format=ModelFormat.OLLAMA,
            path="deepseek-r1:1.5b",
            role=ModelRole.REWARD,
        )
        with patch("dashboard.routes.models.model_registry") as registry_mock:
            registry_mock.get_reward_model.return_value = stub
            response = client.get("/api/v1/models/roles/reward")
        assert response.status_code == 200
        assert response.json()["reward_model"]["name"] == "reward-m"

    def test_get_reward_model_none(self, client):
        """A missing reward model is reported as null, not an error."""
        with patch("dashboard.routes.models.model_registry") as registry_mock:
            registry_mock.get_reward_model.return_value = None
            response = client.get("/api/v1/models/roles/reward")
        assert response.status_code == 200
        assert response.json()["reward_model"] is None

    def test_get_teacher_model(self, client):
        """A missing teacher model is reported as null, not an error."""
        with patch("dashboard.routes.models.model_registry") as registry_mock:
            registry_mock.get_teacher_model.return_value = None
            response = client.get("/api/v1/models/roles/teacher")
        assert response.status_code == 200
        assert response.json()["teacher_model"] is None
class TestModelsAPISetActive:
    """Test enable/disable model endpoint."""

    def test_enable_model(self, client):
        """Enabling a known model succeeds."""
        with patch("dashboard.routes.models.model_registry") as registry_mock:
            registry_mock.set_active.return_value = True
            response = client.patch(
                "/api/v1/models/my-model/active",
                json={"active": True},
            )
        assert response.status_code == 200

    def test_disable_nonexistent(self, client):
        """Toggling an unknown model yields a 404."""
        with patch("dashboard.routes.models.model_registry") as registry_mock:
            registry_mock.set_active.return_value = False
            response = client.patch(
                "/api/v1/models/nonexistent/active",
                json={"active": False},
            )
        assert response.status_code == 404

View File

@@ -0,0 +1,197 @@
"""Tests for reward model scoring in the swarm learner."""
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from swarm.learner import (
RewardScore,
get_reward_scores,
score_output,
)
@pytest.fixture(autouse=True)
def _isolate_db(tmp_path):
    """Point the learner at a temporary database."""
    temp_db = tmp_path / "learner_test.db"
    # autouse: every test in this module runs against its own SQLite file.
    with patch("swarm.learner.DB_PATH", temp_db):
        yield
class TestScoreOutput:
    """Test the score_output function."""

    @staticmethod
    def _ollama_response(text):
        """Build a fake HTTP response whose JSON body votes *text*."""
        fake = MagicMock()
        fake.status_code = 200
        fake.json.return_value = {"response": text}
        return fake

    def test_returns_none_when_disabled(self):
        """No scoring happens when the feature flag is off."""
        with patch("swarm.learner._settings") as settings:
            settings.reward_model_enabled = False
            assert score_output("task-1", "agent-1", "do X", "done X") is None

    def test_returns_none_when_no_model(self):
        """Scoring is skipped when no reward model can be resolved."""
        with patch("swarm.learner._settings") as settings:
            settings.reward_model_enabled = True
            settings.reward_model_name = ""
            with patch(
                "infrastructure.models.registry.model_registry"
            ) as registry_mock:
                registry_mock.get_reward_model.return_value = None
                outcome = score_output("task-1", "agent-1", "do X", "done X")
        assert outcome is None

    def test_positive_scoring(self):
        """All votes return GOOD → score = 1.0."""
        with patch("swarm.learner._settings") as settings:
            settings.reward_model_enabled = True
            settings.reward_model_name = "test-model"
            settings.reward_model_votes = 3
            settings.ollama_url = "http://localhost:11434"
            with patch(
                "requests.post", return_value=self._ollama_response("GOOD")
            ):
                outcome = score_output("task-1", "agent-1", "do X", "done X")
        assert outcome is not None
        assert outcome.score == 1.0
        assert outcome.positive_votes == 3
        assert outcome.negative_votes == 0
        assert outcome.total_votes == 3
        assert outcome.model_used == "test-model"

    def test_negative_scoring(self):
        """All votes return BAD → score = -1.0."""
        with patch("swarm.learner._settings") as settings:
            settings.reward_model_enabled = True
            settings.reward_model_name = "test-model"
            settings.reward_model_votes = 3
            settings.ollama_url = "http://localhost:11434"
            with patch(
                "requests.post", return_value=self._ollama_response("BAD")
            ):
                outcome = score_output("task-1", "agent-1", "do X", "bad output")
        assert outcome is not None
        assert outcome.score == -1.0
        assert outcome.negative_votes == 3

    def test_mixed_scoring(self):
        """2 GOOD + 1 BAD → score ≈ 0.33."""
        votes = [self._ollama_response(v) for v in ("GOOD", "GOOD", "BAD")]
        with patch("swarm.learner._settings") as settings:
            settings.reward_model_enabled = True
            settings.reward_model_name = "test-model"
            settings.reward_model_votes = 3
            settings.ollama_url = "http://localhost:11434"
            with patch("requests.post", side_effect=votes):
                outcome = score_output("task-1", "agent-1", "do X", "ok output")
        assert outcome is not None
        assert abs(outcome.score - (1 / 3)) < 0.01
        assert outcome.positive_votes == 2
        assert outcome.negative_votes == 1

    def test_uses_registry_reward_model(self):
        """Falls back to registry reward model when setting is empty."""
        registry_model = MagicMock()
        registry_model.path = "registry-reward-model"
        registry_model.format = MagicMock()
        registry_model.format.value = "ollama"
        with patch("swarm.learner._settings") as settings:
            settings.reward_model_enabled = True
            settings.reward_model_name = ""
            settings.reward_model_votes = 1
            settings.ollama_url = "http://localhost:11434"
            with patch(
                "infrastructure.models.registry.model_registry"
            ) as registry_mock:
                registry_mock.get_reward_model.return_value = registry_model
                with patch(
                    "requests.post", return_value=self._ollama_response("GOOD")
                ):
                    outcome = score_output("task-1", "agent-1", "do X", "ok")
        assert outcome is not None
        assert outcome.model_used == "registry-reward-model"
class TestGetRewardScores:
    """Test retrieving historical reward scores."""

    @staticmethod
    def _good_response():
        """A fake HTTP response voting GOOD."""
        fake = MagicMock()
        fake.status_code = 200
        fake.json.return_value = {"response": "GOOD"}
        return fake

    def test_empty_history(self):
        """With nothing scored yet, the history is empty."""
        assert get_reward_scores() == []

    def test_scores_persisted(self):
        """Scores from score_output are retrievable."""
        with patch("swarm.learner._settings") as settings:
            settings.reward_model_enabled = True
            settings.reward_model_name = "test-model"
            settings.reward_model_votes = 1
            settings.ollama_url = "http://localhost:11434"
            with patch("requests.post", return_value=self._good_response()):
                score_output("task-1", "agent-1", "do X", "done X")
            history = get_reward_scores()
        assert len(history) == 1
        assert history[0]["task_id"] == "task-1"
        assert history[0]["agent_id"] == "agent-1"
        assert history[0]["score"] == 1.0

    def test_filter_by_agent(self):
        """Filter scores by agent_id."""
        with patch("swarm.learner._settings") as settings:
            settings.reward_model_enabled = True
            settings.reward_model_name = "test-model"
            settings.reward_model_votes = 1
            settings.ollama_url = "http://localhost:11434"
            with patch("requests.post", return_value=self._good_response()):
                score_output("task-1", "agent-1", "task A", "output A")
                score_output("task-2", "agent-2", "task B", "output B")
            filtered = get_reward_scores(agent_id="agent-1")
        assert len(filtered) == 1
        assert filtered[0]["agent_id"] == "agent-1"
class TestRewardScoreDataclass:
"""Test RewardScore construction."""
def test_create_score(self):
    """RewardScore keeps the field values it is constructed with."""
    reward = RewardScore(
        score=0.5,
        positive_votes=3,
        negative_votes=1,
        total_votes=4,
        model_used="test-model",
    )
    assert reward.score == 0.5
    assert reward.total_votes == 4