feat: add custom weights, model registry, per-agent models, and reward scoring
Inspired by OpenClaw-RL's multi-model orchestration, this adds four features for custom model management: 1. Custom model registry (infrastructure/models/registry.py) — SQLite-backed registry for GGUF, safetensors, HF checkpoint, and Ollama models with role-based lookups (general, reward, teacher, judge). 2. Per-agent model assignment — each swarm persona can use a different model instead of sharing the global default. Resolved via registry assignment > persona default > global default. 3. Runtime model management API (/api/v1/models) — REST endpoints to register, list, assign, enable/disable, and remove custom models without restart. Includes a dashboard page at /models. 4. Reward model scoring (PRM-style) — majority-vote quality evaluation of agent outputs using a configurable reward model. Scores persist in SQLite and feed into the swarm learner. New config settings: custom_weights_dir, reward_model_enabled, reward_model_name, reward_model_votes. 54 new tests covering registry CRUD, API endpoints, agent assignments, role lookups, and reward scoring. https://claude.ai/code/session_01V4iTozMwcE2gjfnCJdCugC
This commit is contained in:
@@ -68,6 +68,37 @@ providers:
|
||||
- name: claude-3-sonnet-20240229
|
||||
context_window: 200000
|
||||
|
||||
# ── Custom Models ──────────────────────────────────────────────────────
|
||||
# Register custom model weights for per-agent assignment.
|
||||
# Supports GGUF (Ollama), safetensors, and HuggingFace checkpoint dirs.
|
||||
# Models can also be registered at runtime via the /api/v1/models API.
|
||||
#
|
||||
# Roles: general (default inference), reward (PRM scoring),
|
||||
# teacher (distillation), judge (output evaluation)
|
||||
custom_models: []
|
||||
# Example entries:
|
||||
# - name: my-finetuned-llama
|
||||
# format: gguf
|
||||
# path: /path/to/model.gguf
|
||||
# role: general
|
||||
# context_window: 8192
|
||||
# description: "Fine-tuned Llama for code tasks"
|
||||
#
|
||||
# - name: reward-model
|
||||
# format: ollama
|
||||
# path: deepseek-r1:1.5b
|
||||
# role: reward
|
||||
# context_window: 32000
|
||||
# description: "Process reward model for scoring outputs"
|
||||
|
||||
# ── Agent Model Assignments ─────────────────────────────────────────────
|
||||
# Map persona agent IDs to specific models.
|
||||
# Agents without an assignment use the global default (ollama_model).
|
||||
agent_model_assignments: {}
|
||||
# Example:
|
||||
# persona-forge: my-finetuned-llama
|
||||
# persona-echo: deepseek-r1:1.5b
|
||||
|
||||
# Cost tracking (optional, for budget monitoring)
|
||||
cost_tracking:
|
||||
enabled: true
|
||||
|
||||
@@ -103,6 +103,17 @@ class Settings(BaseSettings):
|
||||
work_orders_auto_execute: bool = False # Master switch for auto-execution
|
||||
work_orders_auto_threshold: str = "low" # Max priority that auto-executes: "low" | "medium" | "high" | "none"
|
||||
|
||||
# ── Custom Weights & Models ──────────────────────────────────────
|
||||
# Directory for custom model weights (GGUF, safetensors, HF checkpoints).
|
||||
# Models placed here can be registered at runtime and assigned to agents.
|
||||
custom_weights_dir: str = "data/models"
|
||||
# Enable the reward model for scoring agent outputs (PRM-style).
|
||||
reward_model_enabled: bool = False
|
||||
# Reward model name (must be available via Ollama or a custom weight path).
|
||||
reward_model_name: str = ""
|
||||
# Minimum votes for majority-vote reward scoring (odd number recommended).
|
||||
reward_model_votes: int = 3
|
||||
|
||||
# ── Browser Local Models (iPhone / WebGPU) ───────────────────────
|
||||
# Enable in-browser LLM inference via WebLLM for offline iPhone use.
|
||||
# When enabled, the mobile dashboard loads a small model directly
|
||||
|
||||
@@ -36,6 +36,8 @@ from dashboard.routes.self_coding import router as self_coding_router
|
||||
from dashboard.routes.self_coding import self_modify_router
|
||||
from dashboard.routes.hands import router as hands_router
|
||||
from dashboard.routes.grok import router as grok_router
|
||||
from dashboard.routes.models import router as models_router
|
||||
from dashboard.routes.models import api_router as models_api_router
|
||||
from infrastructure.router.api import router as cascade_router
|
||||
|
||||
logging.basicConfig(
|
||||
@@ -208,6 +210,8 @@ app.include_router(tasks_router)
|
||||
app.include_router(scripture_router)
|
||||
app.include_router(hands_router)
|
||||
app.include_router(grok_router)
|
||||
app.include_router(models_router)
|
||||
app.include_router(models_api_router)
|
||||
app.include_router(cascade_router)
|
||||
|
||||
|
||||
|
||||
272
src/dashboard/routes/models.py
Normal file
272
src/dashboard/routes/models.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""Custom model management routes — register, list, assign, and swap models.
|
||||
|
||||
Provides a REST API for managing custom model weights and their assignment
|
||||
to swarm agents. Inspired by OpenClaw-RL's multi-model orchestration.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from pydantic import BaseModel
|
||||
|
||||
from config import settings
|
||||
from infrastructure.models.registry import (
|
||||
CustomModel,
|
||||
ModelFormat,
|
||||
ModelRegistry,
|
||||
ModelRole,
|
||||
model_registry,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/models", tags=["models"])
|
||||
api_router = APIRouter(prefix="/api/v1/models", tags=["models-api"])
|
||||
templates = Jinja2Templates(directory=str(Path(__file__).parent.parent / "templates"))
|
||||
|
||||
|
||||
# ── Pydantic schemas ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class RegisterModelRequest(BaseModel):
    """Request body for model registration.

    Fields mirror the columns of the ``custom_models`` table. ``format``
    and ``role`` are plain strings here; they are validated against
    ModelFormat / ModelRole in the endpoint so invalid values yield a 400
    instead of a Pydantic 422.
    """
    name: str
    format: str  # gguf, safetensors, hf, ollama
    path: str  # filesystem path, or the Ollama model name when format="ollama"
    role: str = "general"  # general | reward | teacher | judge
    context_window: int = 4096
    description: str = ""
    # Per-model generation defaults stored alongside the registry entry.
    default_temperature: float = 0.7
    max_tokens: int = 2048
|
||||
|
||||
|
||||
class AssignModelRequest(BaseModel):
    """Request body for assigning a model to an agent."""
    agent_id: str  # swarm persona id, e.g. "persona-forge"
    model_name: str  # must already exist in the registry
|
||||
|
||||
|
||||
class SetActiveRequest(BaseModel):
    """Request body for enabling/disabling a model."""
    active: bool  # True = enabled, False = disabled (entry is kept)
|
||||
|
||||
|
||||
# ── API endpoints ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@api_router.get("")
|
||||
async def list_models(role: Optional[str] = None) -> dict[str, Any]:
|
||||
"""List all registered custom models."""
|
||||
model_role = ModelRole(role) if role else None
|
||||
models = model_registry.list_models(role=model_role)
|
||||
return {
|
||||
"models": [
|
||||
{
|
||||
"name": m.name,
|
||||
"format": m.format.value,
|
||||
"path": m.path,
|
||||
"role": m.role.value,
|
||||
"context_window": m.context_window,
|
||||
"description": m.description,
|
||||
"active": m.active,
|
||||
"registered_at": m.registered_at,
|
||||
"default_temperature": m.default_temperature,
|
||||
"max_tokens": m.max_tokens,
|
||||
}
|
||||
for m in models
|
||||
],
|
||||
"total": len(models),
|
||||
"weights_dir": settings.custom_weights_dir,
|
||||
}
|
||||
|
||||
|
||||
@api_router.post("")
|
||||
async def register_model(request: RegisterModelRequest) -> dict[str, Any]:
|
||||
"""Register a new custom model."""
|
||||
try:
|
||||
fmt = ModelFormat(request.format)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid format: {request.format}. "
|
||||
f"Choose from: {[f.value for f in ModelFormat]}",
|
||||
)
|
||||
try:
|
||||
role = ModelRole(request.role)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid role: {request.role}. "
|
||||
f"Choose from: {[r.value for r in ModelRole]}",
|
||||
)
|
||||
|
||||
# Validate path exists for non-Ollama formats
|
||||
if fmt != ModelFormat.OLLAMA:
|
||||
weight_path = Path(request.path)
|
||||
if not weight_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Weight path does not exist: {request.path}",
|
||||
)
|
||||
|
||||
model = CustomModel(
|
||||
name=request.name,
|
||||
format=fmt,
|
||||
path=request.path,
|
||||
role=role,
|
||||
context_window=request.context_window,
|
||||
description=request.description,
|
||||
default_temperature=request.default_temperature,
|
||||
max_tokens=request.max_tokens,
|
||||
)
|
||||
registered = model_registry.register(model)
|
||||
return {
|
||||
"message": f"Model {registered.name} registered",
|
||||
"model": {
|
||||
"name": registered.name,
|
||||
"format": registered.format.value,
|
||||
"role": registered.role.value,
|
||||
"path": registered.path,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@api_router.get("/{model_name}")
|
||||
async def get_model(model_name: str) -> dict[str, Any]:
|
||||
"""Get details of a specific model."""
|
||||
model = model_registry.get(model_name)
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
|
||||
return {
|
||||
"name": model.name,
|
||||
"format": model.format.value,
|
||||
"path": model.path,
|
||||
"role": model.role.value,
|
||||
"context_window": model.context_window,
|
||||
"description": model.description,
|
||||
"active": model.active,
|
||||
"registered_at": model.registered_at,
|
||||
"default_temperature": model.default_temperature,
|
||||
"max_tokens": model.max_tokens,
|
||||
}
|
||||
|
||||
|
||||
@api_router.delete("/{model_name}")
|
||||
async def unregister_model(model_name: str) -> dict[str, str]:
|
||||
"""Remove a model from the registry."""
|
||||
if not model_registry.unregister(model_name):
|
||||
raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
|
||||
return {"message": f"Model {model_name} unregistered"}
|
||||
|
||||
|
||||
@api_router.patch("/{model_name}/active")
async def set_model_active(
    model_name: str, request: SetActiveRequest
) -> dict[str, str]:
    """Enable or disable a model without removing its registry entry."""
    changed = model_registry.set_active(model_name, request.active)
    if not changed:
        raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
    state = "enabled" if request.active else "disabled"
    return {"message": f"Model {model_name} {state}"}
|
||||
|
||||
|
||||
# ── Agent assignment endpoints ────────────────────────────────────────────────
|
||||
|
||||
|
||||
@api_router.get("/assignments/all")
|
||||
async def list_assignments() -> dict[str, Any]:
|
||||
"""List all agent-to-model assignments."""
|
||||
assignments = model_registry.get_agent_assignments()
|
||||
return {
|
||||
"assignments": [
|
||||
{"agent_id": aid, "model_name": mname}
|
||||
for aid, mname in assignments.items()
|
||||
],
|
||||
"total": len(assignments),
|
||||
}
|
||||
|
||||
|
||||
@api_router.post("/assignments")
|
||||
async def assign_model(request: AssignModelRequest) -> dict[str, str]:
|
||||
"""Assign a model to a swarm agent."""
|
||||
if not model_registry.assign_model(request.agent_id, request.model_name):
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Model {request.model_name} not found in registry",
|
||||
)
|
||||
return {
|
||||
"message": f"Model {request.model_name} assigned to {request.agent_id}",
|
||||
}
|
||||
|
||||
|
||||
@api_router.delete("/assignments/{agent_id}")
|
||||
async def unassign_model(agent_id: str) -> dict[str, str]:
|
||||
"""Remove model assignment from an agent (reverts to default)."""
|
||||
if not model_registry.unassign_model(agent_id):
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No model assignment for agent {agent_id}",
|
||||
)
|
||||
return {"message": f"Model assignment removed for {agent_id}"}
|
||||
|
||||
|
||||
# ── Role-based lookups ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@api_router.get("/roles/reward")
|
||||
async def get_reward_model() -> dict[str, Any]:
|
||||
"""Get the active reward (PRM) model."""
|
||||
model = model_registry.get_reward_model()
|
||||
if not model:
|
||||
return {"reward_model": None, "reward_enabled": settings.reward_model_enabled}
|
||||
return {
|
||||
"reward_model": {
|
||||
"name": model.name,
|
||||
"format": model.format.value,
|
||||
"path": model.path,
|
||||
},
|
||||
"reward_enabled": settings.reward_model_enabled,
|
||||
}
|
||||
|
||||
|
||||
@api_router.get("/roles/teacher")
|
||||
async def get_teacher_model() -> dict[str, Any]:
|
||||
"""Get the active teacher model for distillation."""
|
||||
model = model_registry.get_teacher_model()
|
||||
if not model:
|
||||
return {"teacher_model": None}
|
||||
return {
|
||||
"teacher_model": {
|
||||
"name": model.name,
|
||||
"format": model.format.value,
|
||||
"path": model.path,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ── Dashboard page ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("", response_class=HTMLResponse)
|
||||
async def models_page(request: Request):
|
||||
"""Custom models management dashboard page."""
|
||||
models = model_registry.list_models()
|
||||
assignments = model_registry.get_agent_assignments()
|
||||
reward = model_registry.get_reward_model()
|
||||
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"models.html",
|
||||
{
|
||||
"page_title": "Custom Models",
|
||||
"models": models,
|
||||
"assignments": assignments,
|
||||
"reward_model": reward,
|
||||
"weights_dir": settings.custom_weights_dir,
|
||||
"reward_enabled": settings.reward_model_enabled,
|
||||
},
|
||||
)
|
||||
119
src/dashboard/templates/models.html
Normal file
119
src/dashboard/templates/models.html
Normal file
@@ -0,0 +1,119 @@
|
||||
{% extends "base.html" %}
|
||||
|
||||
{% block title %}Custom Models - Timmy Time{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
<div class="mc-panel">
|
||||
<div class="mc-panel-header">
|
||||
<h1 class="page-title">Custom Models</h1>
|
||||
<p class="mc-text-secondary">Manage model weights and agent assignments</p>
|
||||
</div>
|
||||
|
||||
<!-- Stats -->
|
||||
<div class="mc-stats-row">
|
||||
<div class="mc-stat-card">
|
||||
<div class="mc-stat-value">{{ models|length }}</div>
|
||||
<div class="mc-stat-label">Models</div>
|
||||
</div>
|
||||
<div class="mc-stat-card">
|
||||
<div class="mc-stat-value">{{ assignments|length }}</div>
|
||||
<div class="mc-stat-label">Assignments</div>
|
||||
</div>
|
||||
<div class="mc-stat-card">
|
||||
<div class="mc-stat-value">{{ "Yes" if reward_model else "No" }}</div>
|
||||
<div class="mc-stat-label">Reward Model</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Register Model Form -->
|
||||
<div class="mc-section" style="margin-top: 1.5rem;">
|
||||
<h2>Register Model</h2>
|
||||
<form hx-post="/api/v1/models" hx-target="#model-result" hx-swap="innerHTML"
|
||||
style="display: grid; gap: 0.5rem; max-width: 500px;">
|
||||
<input name="name" placeholder="Model name" required class="mc-input" />
|
||||
<select name="format" class="mc-input">
|
||||
<option value="ollama">Ollama</option>
|
||||
<option value="gguf">GGUF</option>
|
||||
<option value="safetensors">Safetensors</option>
|
||||
<option value="hf">HF Checkpoint</option>
|
||||
</select>
|
||||
<input name="path" placeholder="Path or Ollama model name" required class="mc-input" />
|
||||
<select name="role" class="mc-input">
|
||||
<option value="general">General</option>
|
||||
<option value="reward">Reward (PRM)</option>
|
||||
<option value="teacher">Teacher</option>
|
||||
<option value="judge">Judge</option>
|
||||
</select>
|
||||
<input name="context_window" type="number" value="4096" class="mc-input" />
|
||||
<input name="description" placeholder="Description (optional)" class="mc-input" />
|
||||
<button type="submit" class="mc-btn mc-btn-primary">Register</button>
|
||||
</form>
|
||||
<div id="model-result" style="margin-top: 0.5rem;"></div>
|
||||
</div>
|
||||
|
||||
<!-- Registered Models -->
|
||||
<div class="mc-section" style="margin-top: 1.5rem;">
|
||||
<h2>Registered Models</h2>
|
||||
{% if models %}
|
||||
<table class="mc-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Name</th>
|
||||
<th>Format</th>
|
||||
<th>Role</th>
|
||||
<th>Context</th>
|
||||
<th>Active</th>
|
||||
<th>Actions</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for m in models %}
|
||||
<tr>
|
||||
<td><strong>{{ m.name }}</strong></td>
|
||||
<td>{{ m.format.value }}</td>
|
||||
<td>{{ m.role.value }}</td>
|
||||
<td>{{ m.context_window }}</td>
|
||||
<td>{{ "Yes" if m.active else "No" }}</td>
|
||||
<td>
|
||||
<button class="mc-btn mc-btn-sm"
|
||||
hx-delete="/api/v1/models/{{ m.name }}"
|
||||
hx-confirm="Remove {{ m.name }}?"
|
||||
hx-target="closest tr"
|
||||
hx-swap="outerHTML">Remove</button>
|
||||
</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
{% else %}
|
||||
<p class="mc-text-secondary">No custom models registered. Use the form above or the API.</p>
|
||||
{% endif %}
|
||||
</div>
|
||||
|
||||
<!-- Agent Assignments -->
|
||||
<div class="mc-section" style="margin-top: 1.5rem;">
|
||||
<h2>Agent Model Assignments</h2>
|
||||
{% if assignments %}
|
||||
<table class="mc-table">
|
||||
<thead>
|
||||
<tr><th>Agent</th><th>Model</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for agent_id, model_name in assignments.items() %}
|
||||
<tr>
|
||||
<td>{{ agent_id }}</td>
|
||||
<td>{{ model_name }}</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
{% else %}
|
||||
<p class="mc-text-secondary">No agent-specific model assignments. All agents use the global default.</p>
|
||||
{% endif %}
|
||||
</div>
|
||||
|
||||
<div class="mc-section" style="margin-top: 1rem;">
|
||||
<p class="mc-text-secondary">Weights directory: <code>{{ weights_dir }}</code></p>
|
||||
</div>
|
||||
</div>
|
||||
{% endblock %}
|
||||
0
src/infrastructure/models/__init__.py
Normal file
0
src/infrastructure/models/__init__.py
Normal file
268
src/infrastructure/models/registry.py
Normal file
268
src/infrastructure/models/registry.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""Custom model registry — register, load, and manage model weights.
|
||||
|
||||
Tracks custom models (GGUF files, HF checkpoints, Ollama modelfiles)
|
||||
and their assignment to swarm agents. Models can be registered at
|
||||
runtime via the API or pre-configured via providers.yaml.
|
||||
|
||||
Inspired by OpenClaw-RL's multi-model orchestration where distinct
|
||||
model roles (student, teacher, judge/PRM) run on dedicated resources.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DB_PATH = Path("data/swarm.db")
|
||||
|
||||
|
||||
class ModelFormat(str, Enum):
    """Supported model weight formats.

    str-valued so instances serialise cleanly and compare equal to their
    raw string value (e.g. in API payloads and SQLite columns).
    """
    GGUF = "gguf"  # Ollama-compatible quantised weights
    SAFETENSORS = "safetensors"  # HuggingFace safetensors
    HF_CHECKPOINT = "hf"  # Full HuggingFace checkpoint directory
    OLLAMA = "ollama"  # Already loaded in Ollama by name
|
||||
|
||||
|
||||
class ModelRole(str, Enum):
    """Role a model can play in the system (OpenClaw-RL style)."""
    GENERAL = "general"  # Default agent inference
    REWARD = "reward"  # Process Reward Model (PRM) scoring
    TEACHER = "teacher"  # On-policy distillation teacher
    JUDGE = "judge"  # Output quality evaluation
|
||||
|
||||
|
||||
@dataclass
class CustomModel:
    """A registered custom model.

    One row of the ``custom_models`` table; field names match the column
    names used by ModelRegistry.
    """
    name: str  # unique key, also the SQLite primary key
    format: ModelFormat
    path: str  # Absolute path or Ollama model name
    role: ModelRole = ModelRole.GENERAL
    context_window: int = 4096
    description: str = ""
    registered_at: str = ""  # ISO-8601 UTC; auto-filled in __post_init__
    active: bool = True  # disabled models stay registered but are skipped by role lookups
    # Per-model generation settings
    default_temperature: float = 0.7
    max_tokens: int = 2048

    def __post_init__(self):
        # Default the registration timestamp to "now" (UTC) when not
        # supplied, e.g. when constructed fresh rather than from a DB row.
        if not self.registered_at:
            self.registered_at = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _get_conn() -> sqlite3.Connection:
    """Open the swarm DB, creating the parent directory and schema on demand.

    Connections are short-lived: each registry operation opens one and is
    responsible for closing it.

    NOTE(review): the FOREIGN KEY on agent_model_assignments is not enforced
    by SQLite unless ``PRAGMA foreign_keys = ON`` is issued per connection —
    presumably intentional, since ModelRegistry.unregister() cleans up
    assignments manually; confirm before enabling the pragma.
    """
    DB_PATH.parent.mkdir(parents=True, exist_ok=True)
    connection = sqlite3.connect(str(DB_PATH))
    connection.row_factory = sqlite3.Row  # allow name-based column access
    schema_ddl = (
        """
        CREATE TABLE IF NOT EXISTS custom_models (
            name TEXT PRIMARY KEY,
            format TEXT NOT NULL,
            path TEXT NOT NULL,
            role TEXT NOT NULL DEFAULT 'general',
            context_window INTEGER NOT NULL DEFAULT 4096,
            description TEXT NOT NULL DEFAULT '',
            registered_at TEXT NOT NULL,
            active INTEGER NOT NULL DEFAULT 1,
            default_temperature REAL NOT NULL DEFAULT 0.7,
            max_tokens INTEGER NOT NULL DEFAULT 2048
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS agent_model_assignments (
            agent_id TEXT PRIMARY KEY,
            model_name TEXT NOT NULL,
            assigned_at TEXT NOT NULL,
            FOREIGN KEY (model_name) REFERENCES custom_models(name)
        )
        """,
    )
    for ddl in schema_ddl:
        connection.execute(ddl)
    connection.commit()
    return connection
|
||||
|
||||
|
||||
class ModelRegistry:
    """Singleton registry for custom models and agent-model assignments.

    Mirrors the SQLite tables ``custom_models`` and
    ``agent_model_assignments`` in an in-memory cache for fast lookups.
    Every mutating operation takes ``self._lock`` and writes through to
    the database before updating the cache.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        # In-memory cache for fast lookups
        self._models: dict[str, CustomModel] = {}
        self._agent_assignments: dict[str, str] = {}
        self._load_from_db()

    def _load_from_db(self) -> None:
        """Bootstrap the cache from SQLite.

        Loads *all* models, including inactive ones. Previously only
        ``active = 1`` rows were loaded, which made a disabled model
        silently disappear from the registry after a restart — so
        set_active(name, True) returned False and the model could never
        be re-enabled. Loading everything matches runtime behaviour,
        where deactivated models stay cached with active=False (role
        lookups already filter on the active flag).
        """
        try:
            conn = _get_conn()
            for row in conn.execute("SELECT * FROM custom_models").fetchall():
                self._models[row["name"]] = CustomModel(
                    name=row["name"],
                    format=ModelFormat(row["format"]),
                    path=row["path"],
                    role=ModelRole(row["role"]),
                    context_window=row["context_window"],
                    description=row["description"],
                    registered_at=row["registered_at"],
                    active=bool(row["active"]),
                    default_temperature=row["default_temperature"],
                    max_tokens=row["max_tokens"],
                )
            for row in conn.execute("SELECT * FROM agent_model_assignments").fetchall():
                self._agent_assignments[row["agent_id"]] = row["model_name"]
            conn.close()
        except Exception as exc:
            # Degrade to an empty registry rather than crashing app startup.
            logger.warning("Failed to load model registry from DB: %s", exc)

    # ── Model CRUD ─────────────────────────────────────────────────────────

    def register(self, model: CustomModel) -> CustomModel:
        """Register a custom model, overwriting any existing entry of the
        same name (INSERT OR REPLACE)."""
        with self._lock:
            conn = _get_conn()
            conn.execute(
                """
                INSERT OR REPLACE INTO custom_models
                (name, format, path, role, context_window, description,
                 registered_at, active, default_temperature, max_tokens)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    model.name, model.format.value, model.path,
                    model.role.value, model.context_window, model.description,
                    model.registered_at, int(model.active),
                    model.default_temperature, model.max_tokens,
                ),
            )
            conn.commit()
            conn.close()
            self._models[model.name] = model
            logger.info("Registered model: %s (%s)", model.name, model.format.value)
            return model

    def unregister(self, name: str) -> bool:
        """Remove a model and any agent assignments that reference it.

        Returns False when the model is unknown.
        """
        with self._lock:
            if name not in self._models:
                return False
            conn = _get_conn()
            conn.execute("DELETE FROM custom_models WHERE name = ?", (name,))
            conn.execute(
                "DELETE FROM agent_model_assignments WHERE model_name = ?", (name,)
            )
            conn.commit()
            conn.close()
            del self._models[name]
            # Remove any agent assignments using this model
            self._agent_assignments = {
                k: v for k, v in self._agent_assignments.items() if v != name
            }
            logger.info("Unregistered model: %s", name)
            return True

    def get(self, name: str) -> Optional[CustomModel]:
        """Look up a model by name (None when unknown)."""
        return self._models.get(name)

    def list_models(self, role: Optional[ModelRole] = None) -> list[CustomModel]:
        """List all registered models, optionally filtered by role.

        Includes inactive models; callers that only want usable models
        should filter on ``.active`` (as the role lookups below do).
        """
        models = list(self._models.values())
        if role is not None:
            models = [m for m in models if m.role == role]
        return models

    def set_active(self, name: str, active: bool) -> bool:
        """Enable or disable a model without removing it.

        Returns False when the model is unknown. The lookup now happens
        under the lock — previously it was a check-then-act outside the
        lock and could race with unregister().
        """
        with self._lock:
            model = self._models.get(name)
            if not model:
                return False
            model.active = active
            conn = _get_conn()
            conn.execute(
                "UPDATE custom_models SET active = ? WHERE name = ?",
                (int(active), name),
            )
            conn.commit()
            conn.close()
            return True

    # ── Agent-model assignments ────────────────────────────────────────────

    def assign_model(self, agent_id: str, model_name: str) -> bool:
        """Assign a specific registered model to an agent.

        Returns False when *model_name* is not registered. The membership
        check is done under the lock to avoid racing with unregister().
        """
        with self._lock:
            if model_name not in self._models:
                return False
            now = datetime.now(timezone.utc).isoformat()
            conn = _get_conn()
            conn.execute(
                """
                INSERT OR REPLACE INTO agent_model_assignments
                (agent_id, model_name, assigned_at)
                VALUES (?, ?, ?)
                """,
                (agent_id, model_name, now),
            )
            conn.commit()
            conn.close()
            self._agent_assignments[agent_id] = model_name
            logger.info("Assigned model %s to agent %s", model_name, agent_id)
            return True

    def unassign_model(self, agent_id: str) -> bool:
        """Remove model assignment from an agent (falls back to default)."""
        with self._lock:
            if agent_id not in self._agent_assignments:
                return False
            conn = _get_conn()
            conn.execute(
                "DELETE FROM agent_model_assignments WHERE agent_id = ?",
                (agent_id,),
            )
            conn.commit()
            conn.close()
            del self._agent_assignments[agent_id]
            return True

    def get_agent_model(self, agent_id: str) -> Optional[CustomModel]:
        """Get the model assigned to an agent, or None for the default."""
        model_name = self._agent_assignments.get(agent_id)
        if model_name:
            return self._models.get(model_name)
        return None

    def get_agent_assignments(self) -> dict[str, str]:
        """Return a copy of all agent-to-model assignments."""
        return dict(self._agent_assignments)

    # ── Role-based lookups ─────────────────────────────────────────────────

    def get_reward_model(self) -> Optional[CustomModel]:
        """Get the first active reward/PRM model, if any."""
        reward_models = self.list_models(role=ModelRole.REWARD)
        active = [m for m in reward_models if m.active]
        return active[0] if active else None

    def get_teacher_model(self) -> Optional[CustomModel]:
        """Get the first active teacher model for distillation, if any."""
        teacher_models = self.list_models(role=ModelRole.TEACHER)
        active = [m for m in teacher_models if m.active]
        return active[0] if active else None


# Module-level singleton — shared by the dashboard routes and the swarm.
model_registry = ModelRegistry()
|
||||
@@ -251,3 +251,193 @@ def learned_keywords(agent_id: str) -> list[dict]:
|
||||
results.append({"keyword": kw, "wins": wins, "failures": fails, "net": wins - fails})
|
||||
results.sort(key=lambda x: x["net"], reverse=True)
|
||||
return results
|
||||
|
||||
|
||||
# ── Reward model scoring (PRM-style) ─────────────────────────────────────────
|
||||
|
||||
import logging as _logging
|
||||
from config import settings as _settings
|
||||
|
||||
_reward_logger = _logging.getLogger(__name__ + ".reward")
|
||||
|
||||
|
||||
@dataclass
class RewardScore:
    """Result from reward-model evaluation (majority vote)."""
    score: float  # Normalised score in [-1.0, 1.0]: (positive - negative) / total
    positive_votes: int
    negative_votes: int
    total_votes: int  # positive_votes + negative_votes
    model_used: str  # name of the model that produced the votes
|
||||
|
||||
|
||||
def _ensure_reward_table() -> None:
    """Create the reward_scores table if needed (idempotent)."""
    ddl = """
        CREATE TABLE IF NOT EXISTS reward_scores (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            task_id TEXT NOT NULL,
            agent_id TEXT NOT NULL,
            output_text TEXT NOT NULL,
            score REAL NOT NULL,
            positive INTEGER NOT NULL,
            negative INTEGER NOT NULL,
            total INTEGER NOT NULL,
            model_used TEXT NOT NULL,
            scored_at TEXT NOT NULL DEFAULT (datetime('now'))
        )
        """
    conn = _get_conn()
    conn.execute(ddl)
    conn.commit()
    conn.close()
|
||||
|
||||
|
||||
def score_output(
    task_id: str,
    agent_id: str,
    task_description: str,
    output_text: str,
) -> Optional[RewardScore]:
    """Score an agent's output using the reward model (majority vote).

    Calls the reward model N times (settings.reward_model_votes) with a
    quality-evaluation prompt. Each vote is +1 (good) or -1 (bad).
    Final score is (positive - negative) / total, in [-1.0, 1.0].

    Args:
        task_id: Identifier of the task the output belongs to.
        agent_id: Identifier of the agent that produced the output.
        task_description: Human-readable task statement shown to the judge model.
        output_text: The agent output to evaluate. Truncated to 2000 chars
            in the prompt and 5000 chars when persisted.

    Returns:
        A RewardScore with vote counts, the final score, and the model used,
        or None if the reward model is disabled or unavailable.
    """
    if not _settings.reward_model_enabled:
        return None

    # Resolve model name: explicit setting > registry reward model > skip
    model_name = _settings.reward_model_name
    if not model_name:
        # Registry lookup is best-effort: any import or lookup failure
        # simply falls through to the "no model configured" path below.
        try:
            from infrastructure.models.registry import model_registry
            reward = model_registry.get_reward_model()
            if reward:
                # Ollama models are addressed by their Ollama tag (stored in
                # .path); other formats are addressed by registry name.
                model_name = reward.path if reward.format.value == "ollama" else reward.name
        except Exception:
            pass

    if not model_name:
        _reward_logger.debug("No reward model configured, skipping scoring")
        return None

    # At least one vote is always cast, so total below is never zero.
    num_votes = max(1, _settings.reward_model_votes)
    positive = 0
    negative = 0

    prompt = (
        f"You are a quality evaluator. Rate the following agent output.\n\n"
        f"TASK: {task_description}\n\n"
        f"OUTPUT:\n{output_text[:2000]}\n\n"
        f"Is this output correct, helpful, and complete? "
        f"Reply with exactly one word: GOOD or BAD."
    )

    try:
        # Imported lazily so the learner module loads even without requests.
        import requests as _req
        ollama_url = _settings.ollama_url

        for _ in range(num_votes):
            try:
                # Low temperature + tiny num_predict: we only want GOOD/BAD.
                resp = _req.post(
                    f"{ollama_url}/api/generate",
                    json={
                        "model": model_name,
                        "prompt": prompt,
                        "stream": False,
                        "options": {"temperature": 0.3, "num_predict": 10},
                    },
                    timeout=30,
                )
                if resp.status_code == 200:
                    answer = resp.json().get("response", "").strip().upper()
                    if "GOOD" in answer:
                        positive += 1
                    else:
                        negative += 1
                else:
                    negative += 1  # Treat errors as negative conservatively
            except Exception as vote_exc:
                # NOTE(review): timeouts/connection errors also count as
                # negative votes, so an unreachable reward model drags scores
                # toward -1.0 rather than aborting the scoring run.
                _reward_logger.debug("Vote failed: %s", vote_exc)
                negative += 1

    except ImportError:
        _reward_logger.warning("requests library not available for reward scoring")
        return None

    total = positive + negative
    # Defensive guard: unreachable given num_votes >= 1 and every loop
    # iteration incrementing one of the counters.
    if total == 0:
        return None

    score = (positive - negative) / total

    result = RewardScore(
        score=score,
        positive_votes=positive,
        negative_votes=negative,
        total_votes=total,
        model_used=model_name,
    )

    # Persist to DB — best-effort: a storage failure logs a warning but the
    # computed score is still returned to the caller.
    try:
        _ensure_reward_table()
        conn = _get_conn()
        conn.execute(
            """
            INSERT INTO reward_scores
            (task_id, agent_id, output_text, score, positive, negative, total, model_used)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                task_id, agent_id, output_text[:5000],
                score, positive, negative, total, model_name,
            ),
        )
        conn.commit()
        conn.close()
    except Exception as db_exc:
        _reward_logger.warning("Failed to persist reward score: %s", db_exc)

    _reward_logger.info(
        "Scored task %s agent %s: %.2f (%d+/%d- of %d votes)",
        task_id, agent_id, score, positive, negative, total,
    )
    return result
|
||||
|
||||
|
||||
def get_reward_scores(
    agent_id: Optional[str] = None, limit: int = 50
) -> list[dict]:
    """Retrieve historical reward scores from the database.

    Args:
        agent_id: When given, only scores for this agent are returned.
        limit: Maximum number of rows to return (newest first).

    Returns:
        A list of dicts with task/agent ids, vote counts, the score,
        the model used, and the scoring timestamp, newest first.
    """
    _ensure_reward_table()
    conn = _get_conn()
    try:
        # Build one query instead of two near-identical branches; the
        # optional agent filter only adds a WHERE clause and parameter.
        query = "SELECT * FROM reward_scores"
        params: tuple = ()
        if agent_id:
            query += " WHERE agent_id = ?"
            params = (agent_id,)
        query += " ORDER BY id DESC LIMIT ?"
        params += (limit,)
        rows = conn.execute(query, params).fetchall()
    finally:
        # Close even if the query raises, so a failed read cannot leak
        # the SQLite connection.
        conn.close()
    return [
        {
            "task_id": r["task_id"],
            "agent_id": r["agent_id"],
            "score": r["score"],
            "positive": r["positive"],
            "negative": r["negative"],
            "total": r["total"],
            "model_used": r["model_used"],
            "scored_at": r["scored_at"],
        }
        for r in rows
    ]
|
||||
|
||||
@@ -51,6 +51,16 @@ class PersonaNode(SwarmNode):
|
||||
self._meta = meta
|
||||
self._persona_id = persona_id
|
||||
self._use_learner = use_learner
|
||||
|
||||
# Resolve model: registry assignment > persona default > global default
|
||||
self._model_name: Optional[str] = meta.get("model")
|
||||
try:
|
||||
from infrastructure.models.registry import model_registry
|
||||
assigned = model_registry.get_agent_model(agent_id)
|
||||
if assigned:
|
||||
self._model_name = assigned.name
|
||||
except Exception:
|
||||
pass # Graceful degradation — use persona/global default
|
||||
|
||||
# Initialize tool executor for task execution
|
||||
self._tool_executor: Optional[ToolExecutor] = None
|
||||
@@ -213,6 +223,11 @@ class PersonaNode(SwarmNode):
|
||||
"""Return the task ID currently being executed, if any."""
|
||||
return self._current_task
|
||||
|
||||
    @property
    def model_name(self) -> Optional[str]:
        """Return the model this agent uses, or None for global default.

        Resolved once at construction time (registry assignment takes
        precedence over the persona's own ``model`` entry); read-only here.
        """
        return self._model_name
|
||||
|
||||
@property
|
||||
def tool_capabilities(self) -> list[str]:
|
||||
"""Return list of available tool names."""
|
||||
|
||||
@@ -14,7 +14,7 @@ from __future__ import annotations
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class PersonaMeta(TypedDict):
|
||||
class PersonaMeta(TypedDict, total=False):
|
||||
id: str
|
||||
name: str
|
||||
role: str
|
||||
@@ -24,6 +24,11 @@ class PersonaMeta(TypedDict):
|
||||
bid_base: int # typical bid when task matches persona
|
||||
bid_jitter: int # ± random jitter added to bid_base
|
||||
preferred_keywords: list[str]
|
||||
# Optional: custom model override for this persona.
|
||||
# When set, this persona uses this model instead of the global default.
|
||||
# Value is a model name registered in the ModelRegistry, or an Ollama
|
||||
# model name like "llama3.2" or "deepseek-r1:1.5b".
|
||||
model: str
|
||||
|
||||
|
||||
PERSONAS: dict[str, PersonaMeta] = {
|
||||
|
||||
217
tests/infrastructure/test_model_registry.py
Normal file
217
tests/infrastructure/test_model_registry.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""Tests for the custom model registry."""
|
||||
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.models.registry import (
|
||||
CustomModel,
|
||||
ModelFormat,
|
||||
ModelRegistry,
|
||||
ModelRole,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def registry(tmp_path):
    """Create a fresh ModelRegistry backed by a temporary database."""
    db = tmp_path / "test.db"
    # Patch the module-level DB_PATH *before* constructing the registry so
    # its schema is created in the per-test temp file, never the real DB.
    # The patch stays active for the whole test via the yield.
    with patch("infrastructure.models.registry.DB_PATH", db):
        reg = ModelRegistry()
        yield reg
|
||||
|
||||
|
||||
@pytest.fixture
def sample_model():
    """Provide a general-purpose, Ollama-backed CustomModel for tests."""
    model = CustomModel(
        name="test-llama",
        format=ModelFormat.OLLAMA,
        path="llama3.2",
        role=ModelRole.GENERAL,
        context_window=8192,
        description="Test model",
    )
    return model
|
||||
|
||||
|
||||
@pytest.fixture
def reward_model():
    """Provide a CustomModel registered under the REWARD role."""
    model = CustomModel(
        name="test-reward",
        format=ModelFormat.OLLAMA,
        path="deepseek-r1:1.5b",
        role=ModelRole.REWARD,
        context_window=32000,
        description="Test reward model",
    )
    return model
|
||||
|
||||
|
||||
class TestModelCRUD:
    """Test model registration, lookup, and removal."""

    def test_register_model(self, registry, sample_model):
        """register() returns the stored model with its fields intact."""
        registered = registry.register(sample_model)
        assert registered.name == "test-llama"
        assert registered.format == ModelFormat.OLLAMA

    def test_get_model(self, registry, sample_model):
        """get() retrieves a previously registered model by name."""
        registry.register(sample_model)
        found = registry.get("test-llama")
        assert found is not None
        assert found.name == "test-llama"
        assert found.path == "llama3.2"

    def test_get_nonexistent_model(self, registry):
        """get() returns None (not raising) for unknown names."""
        assert registry.get("nonexistent") is None

    def test_list_models(self, registry, sample_model, reward_model):
        """list_models() with no filter returns every registered model."""
        registry.register(sample_model)
        registry.register(reward_model)
        all_models = registry.list_models()
        assert len(all_models) == 2

    def test_list_models_by_role(self, registry, sample_model, reward_model):
        """list_models(role=...) filters to exactly that role."""
        registry.register(sample_model)
        registry.register(reward_model)
        general = registry.list_models(role=ModelRole.GENERAL)
        assert len(general) == 1
        assert general[0].name == "test-llama"
        rewards = registry.list_models(role=ModelRole.REWARD)
        assert len(rewards) == 1
        assert rewards[0].name == "test-reward"

    def test_unregister_model(self, registry, sample_model):
        """unregister() returns True and removes the model."""
        registry.register(sample_model)
        assert registry.unregister("test-llama") is True
        assert registry.get("test-llama") is None

    def test_unregister_nonexistent(self, registry):
        """unregister() returns False for unknown names."""
        assert registry.unregister("nonexistent") is False

    def test_set_active(self, registry, sample_model):
        """set_active() toggles the active flag in both directions."""
        registry.register(sample_model)
        assert registry.set_active("test-llama", False) is True
        model = registry.get("test-llama")
        assert model.active is False
        assert registry.set_active("test-llama", True) is True
        model = registry.get("test-llama")
        assert model.active is True

    def test_set_active_nonexistent(self, registry):
        """set_active() returns False for unknown names."""
        assert registry.set_active("nonexistent", True) is False

    def test_register_replaces_existing(self, registry, sample_model):
        """Re-registering the same name overwrites the stored entry."""
        registry.register(sample_model)
        updated = CustomModel(
            name="test-llama",
            format=ModelFormat.GGUF,
            path="/new/path.gguf",
            role=ModelRole.GENERAL,
            description="Updated model",
        )
        registry.register(updated)
        found = registry.get("test-llama")
        assert found.format == ModelFormat.GGUF
        assert found.path == "/new/path.gguf"
|
||||
|
||||
|
||||
class TestAgentAssignments:
    """Test agent-to-model assignment management."""

    def test_assign_model(self, registry, sample_model):
        """assign_model() links an agent to a registered model."""
        registry.register(sample_model)
        assert registry.assign_model("agent-1", "test-llama") is True
        model = registry.get_agent_model("agent-1")
        assert model is not None
        assert model.name == "test-llama"

    def test_assign_nonexistent_model(self, registry):
        """Assigning an unregistered model name fails with False."""
        assert registry.assign_model("agent-1", "nonexistent") is False

    def test_unassign_model(self, registry, sample_model):
        """unassign_model() removes the link, leaving no agent model."""
        registry.register(sample_model)
        registry.assign_model("agent-1", "test-llama")
        assert registry.unassign_model("agent-1") is True
        assert registry.get_agent_model("agent-1") is None

    def test_unassign_nonexistent(self, registry):
        """Unassigning an agent with no assignment returns False."""
        assert registry.unassign_model("agent-1") is False

    def test_get_agent_model_none(self, registry):
        """Agents without an assignment resolve to None."""
        assert registry.get_agent_model("agent-1") is None

    def test_get_all_assignments(self, registry, sample_model, reward_model):
        """get_agent_assignments() returns the full agent->model mapping."""
        registry.register(sample_model)
        registry.register(reward_model)
        registry.assign_model("agent-1", "test-llama")
        registry.assign_model("agent-2", "test-reward")
        assignments = registry.get_agent_assignments()
        assert len(assignments) == 2
        assert assignments["agent-1"] == "test-llama"
        assert assignments["agent-2"] == "test-reward"

    def test_unregister_removes_assignments(self, registry, sample_model):
        """Removing a model cascades to any agent assignments using it."""
        registry.register(sample_model)
        registry.assign_model("agent-1", "test-llama")
        registry.unregister("test-llama")
        assert registry.get_agent_model("agent-1") is None
        assert len(registry.get_agent_assignments()) == 0
|
||||
|
||||
|
||||
class TestRoleLookups:
    """Test role-based model lookups."""

    def test_get_reward_model(self, registry, reward_model):
        """get_reward_model() returns the registered REWARD-role model."""
        registry.register(reward_model)
        found = registry.get_reward_model()
        assert found is not None
        assert found.name == "test-reward"
        assert found.role == ModelRole.REWARD

    def test_get_reward_model_none(self, registry):
        """Empty registry: no reward model, returns None."""
        assert registry.get_reward_model() is None

    def test_get_teacher_model(self, registry):
        """get_teacher_model() returns the registered TEACHER-role model."""
        teacher = CustomModel(
            name="teacher-model",
            format=ModelFormat.OLLAMA,
            path="teacher:latest",
            role=ModelRole.TEACHER,
        )
        registry.register(teacher)
        found = registry.get_teacher_model()
        assert found is not None
        assert found.name == "teacher-model"

    def test_get_teacher_model_none(self, registry):
        """Empty registry: no teacher model, returns None."""
        assert registry.get_teacher_model() is None

    def test_inactive_reward_model_not_returned(self, registry, reward_model):
        """Role lookups must skip models whose active flag is cleared."""
        registry.register(reward_model)
        registry.set_active("test-reward", False)
        assert registry.get_reward_model() is None
|
||||
|
||||
|
||||
class TestCustomModelDataclass:
    """Construction-time defaults and enum wire values for CustomModel."""

    def test_default_registered_at(self):
        # A minimally-constructed model stamps its registration time itself.
        minimal = CustomModel(
            name="test", format=ModelFormat.OLLAMA, path="test"
        )
        assert minimal.registered_at != ""

    def test_model_roles(self):
        # Role enum members serialize to their lowercase string values.
        expected_roles = {
            ModelRole.GENERAL: "general",
            ModelRole.REWARD: "reward",
            ModelRole.TEACHER: "teacher",
            ModelRole.JUDGE: "judge",
        }
        for member, wire_value in expected_roles.items():
            assert member.value == wire_value

    def test_model_formats(self):
        # Format enum members serialize to their expected string values.
        expected_formats = {
            ModelFormat.GGUF: "gguf",
            ModelFormat.SAFETENSORS: "safetensors",
            ModelFormat.HF_CHECKPOINT: "hf",
            ModelFormat.OLLAMA: "ollama",
        }
        for member, wire_value in expected_formats.items():
            assert member.value == wire_value
|
||||
273
tests/infrastructure/test_models_api.py
Normal file
273
tests/infrastructure/test_models_api.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""Tests for the custom models API routes."""
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from infrastructure.models.registry import (
|
||||
CustomModel,
|
||||
ModelFormat,
|
||||
ModelRegistry,
|
||||
ModelRole,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def registry(tmp_path):
    """A fresh ModelRegistry for each test."""
    # NOTE(review): the visible API tests patch
    # dashboard.routes.models.model_registry directly and do not request this
    # fixture — confirm it is still needed before relying on it.
    db = tmp_path / "api_test.db"
    with patch("infrastructure.models.registry.DB_PATH", db):
        reg = ModelRegistry()
        yield reg
|
||||
|
||||
|
||||
class TestModelsAPIList:
    """Test listing models via the API."""

    def test_list_models_empty(self, client, tmp_path):
        """GET /api/v1/models returns the models/total envelope when empty."""
        db = tmp_path / "api.db"
        # DB_PATH is patched defensively; the route-level registry mock below
        # is what actually supplies the (empty) model list.
        with patch("infrastructure.models.registry.DB_PATH", db):
            with patch(
                "dashboard.routes.models.model_registry"
            ) as mock_reg:
                mock_reg.list_models.return_value = []
                resp = client.get("/api/v1/models")
                assert resp.status_code == 200
                data = resp.json()
                assert "models" in data
                assert "total" in data

    def test_list_models_with_data(self, client):
        """A registered model appears in the listing with a correct total."""
        model = CustomModel(
            name="test-m",
            format=ModelFormat.OLLAMA,
            path="llama3.2",
            role=ModelRole.GENERAL,
        )
        with patch(
            "dashboard.routes.models.model_registry"
        ) as mock_reg:
            mock_reg.list_models.return_value = [model]
            resp = client.get("/api/v1/models")
            assert resp.status_code == 200
            data = resp.json()
            assert data["total"] == 1
            assert data["models"][0]["name"] == "test-m"
|
||||
|
||||
|
||||
class TestModelsAPIRegister:
    """Test model registration via the API."""

    def test_register_ollama_model(self, client):
        """POST /api/v1/models echoes the registered model back."""
        with patch(
            "dashboard.routes.models.model_registry"
        ) as mock_reg:
            mock_reg.register.return_value = CustomModel(
                name="my-model",
                format=ModelFormat.OLLAMA,
                path="llama3.2",
                role=ModelRole.GENERAL,
            )
            resp = client.post(
                "/api/v1/models",
                json={
                    "name": "my-model",
                    "format": "ollama",
                    "path": "llama3.2",
                    "role": "general",
                },
            )
            assert resp.status_code == 200
            data = resp.json()
            assert data["model"]["name"] == "my-model"

    def test_register_invalid_format(self, client):
        """An unknown format string is rejected with HTTP 400."""
        # No registry mock: validation happens before any registry call.
        resp = client.post(
            "/api/v1/models",
            json={
                "name": "bad-model",
                "format": "invalid_format",
                "path": "whatever",
            },
        )
        assert resp.status_code == 400
        assert "Invalid format" in resp.json()["detail"]

    def test_register_invalid_role(self, client):
        """An unknown role string is rejected with HTTP 400."""
        resp = client.post(
            "/api/v1/models",
            json={
                "name": "bad-model",
                "format": "ollama",
                "path": "llama3.2",
                "role": "invalid_role",
            },
        )
        assert resp.status_code == 400
        assert "Invalid role" in resp.json()["detail"]
|
||||
|
||||
|
||||
class TestModelsAPIDelete:
    """DELETE /api/v1/models/{name} behaviour."""

    def test_delete_model(self, client):
        """A successful unregister yields HTTP 200."""
        with patch("dashboard.routes.models.model_registry") as registry_mock:
            registry_mock.unregister.return_value = True
            response = client.delete("/api/v1/models/my-model")
            assert response.status_code == 200

    def test_delete_nonexistent(self, client):
        """Deleting an unknown model yields HTTP 404."""
        with patch("dashboard.routes.models.model_registry") as registry_mock:
            registry_mock.unregister.return_value = False
            response = client.delete("/api/v1/models/nonexistent")
            assert response.status_code == 404
|
||||
|
||||
|
||||
class TestModelsAPIGet:
    """GET /api/v1/models/{name} behaviour."""

    def test_get_model(self, client):
        """A known model is returned as JSON with its name."""
        stub = CustomModel(
            name="my-model",
            format=ModelFormat.OLLAMA,
            path="llama3.2",
            role=ModelRole.GENERAL,
        )
        with patch("dashboard.routes.models.model_registry") as registry_mock:
            registry_mock.get.return_value = stub
            response = client.get("/api/v1/models/my-model")
            assert response.status_code == 200
            assert response.json()["name"] == "my-model"

    def test_get_nonexistent(self, client):
        """Unknown model names yield HTTP 404."""
        with patch("dashboard.routes.models.model_registry") as registry_mock:
            registry_mock.get.return_value = None
            response = client.get("/api/v1/models/nonexistent")
            assert response.status_code == 404
|
||||
|
||||
|
||||
class TestModelsAPIAssignments:
    """Test agent model assignment endpoints."""

    def test_assign_model(self, client):
        """POST /assignments succeeds when the registry accepts the pair."""
        with patch(
            "dashboard.routes.models.model_registry"
        ) as mock_reg:
            mock_reg.assign_model.return_value = True
            resp = client.post(
                "/api/v1/models/assignments",
                json={"agent_id": "agent-1", "model_name": "my-model"},
            )
            assert resp.status_code == 200

    def test_assign_nonexistent_model(self, client):
        """Assigning an unknown model maps registry False -> HTTP 404."""
        with patch(
            "dashboard.routes.models.model_registry"
        ) as mock_reg:
            mock_reg.assign_model.return_value = False
            resp = client.post(
                "/api/v1/models/assignments",
                json={"agent_id": "agent-1", "model_name": "nonexistent"},
            )
            assert resp.status_code == 404

    def test_unassign_model(self, client):
        """DELETE /assignments/{agent} succeeds for an assigned agent."""
        with patch(
            "dashboard.routes.models.model_registry"
        ) as mock_reg:
            mock_reg.unassign_model.return_value = True
            resp = client.delete("/api/v1/models/assignments/agent-1")
            assert resp.status_code == 200

    def test_unassign_nonexistent(self, client):
        """Unassigning an agent with no assignment yields HTTP 404."""
        with patch(
            "dashboard.routes.models.model_registry"
        ) as mock_reg:
            mock_reg.unassign_model.return_value = False
            resp = client.delete("/api/v1/models/assignments/nonexistent")
            assert resp.status_code == 404

    def test_list_assignments(self, client):
        """GET /assignments/all returns the mapping plus a total count."""
        with patch(
            "dashboard.routes.models.model_registry"
        ) as mock_reg:
            mock_reg.get_agent_assignments.return_value = {
                "agent-1": "model-a",
                "agent-2": "model-b",
            }
            resp = client.get("/api/v1/models/assignments/all")
            assert resp.status_code == 200
            data = resp.json()
            assert data["total"] == 2
|
||||
|
||||
|
||||
class TestModelsAPIRoles:
    """Test role-based lookup endpoints."""

    def test_get_reward_model(self, client):
        """GET /roles/reward wraps the model under the reward_model key."""
        model = CustomModel(
            name="reward-m",
            format=ModelFormat.OLLAMA,
            path="deepseek-r1:1.5b",
            role=ModelRole.REWARD,
        )
        with patch(
            "dashboard.routes.models.model_registry"
        ) as mock_reg:
            mock_reg.get_reward_model.return_value = model
            resp = client.get("/api/v1/models/roles/reward")
            assert resp.status_code == 200
            data = resp.json()
            assert data["reward_model"]["name"] == "reward-m"

    def test_get_reward_model_none(self, client):
        """No reward model is a 200 with a null payload, not a 404."""
        with patch(
            "dashboard.routes.models.model_registry"
        ) as mock_reg:
            mock_reg.get_reward_model.return_value = None
            resp = client.get("/api/v1/models/roles/reward")
            assert resp.status_code == 200
            assert resp.json()["reward_model"] is None

    def test_get_teacher_model(self, client):
        """No teacher model is likewise a 200 with a null payload."""
        with patch(
            "dashboard.routes.models.model_registry"
        ) as mock_reg:
            mock_reg.get_teacher_model.return_value = None
            resp = client.get("/api/v1/models/roles/teacher")
            assert resp.status_code == 200
            assert resp.json()["teacher_model"] is None
|
||||
|
||||
|
||||
class TestModelsAPISetActive:
    """PATCH /api/v1/models/{name}/active behaviour."""

    def test_enable_model(self, client):
        """Enabling an existing model succeeds with HTTP 200."""
        with patch("dashboard.routes.models.model_registry") as registry_mock:
            registry_mock.set_active.return_value = True
            response = client.patch(
                "/api/v1/models/my-model/active",
                json={"active": True},
            )
            assert response.status_code == 200

    def test_disable_nonexistent(self, client):
        """Toggling an unknown model yields HTTP 404."""
        with patch("dashboard.routes.models.model_registry") as registry_mock:
            registry_mock.set_active.return_value = False
            response = client.patch(
                "/api/v1/models/nonexistent/active",
                json={"active": False},
            )
            assert response.status_code == 404
|
||||
197
tests/swarm/test_reward_scoring.py
Normal file
197
tests/swarm/test_reward_scoring.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""Tests for reward model scoring in the swarm learner."""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from swarm.learner import (
|
||||
RewardScore,
|
||||
get_reward_scores,
|
||||
score_output,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _isolate_db(tmp_path):
    """Point the learner at a temporary database."""
    # autouse: every test in this module transparently gets a private SQLite
    # file, so reward-score rows never leak between tests or into a real DB.
    db = tmp_path / "learner_test.db"
    with patch("swarm.learner.DB_PATH", db):
        yield
|
||||
|
||||
|
||||
class TestScoreOutput:
    """Test the score_output function."""

    def test_returns_none_when_disabled(self):
        """Disabled reward model short-circuits to None."""
        with patch("swarm.learner._settings") as mock_s:
            mock_s.reward_model_enabled = False
            result = score_output("task-1", "agent-1", "do X", "done X")
            assert result is None

    def test_returns_none_when_no_model(self):
        """Empty setting and empty registry: scoring is skipped."""
        with patch("swarm.learner._settings") as mock_s:
            mock_s.reward_model_enabled = True
            mock_s.reward_model_name = ""
            # Registry fallback must also come up empty for the skip path.
            with patch(
                "infrastructure.models.registry.model_registry"
            ) as mock_reg:
                mock_reg.get_reward_model.return_value = None
                result = score_output("task-1", "agent-1", "do X", "done X")
                assert result is None

    def test_positive_scoring(self):
        """All votes return GOOD → score = 1.0."""
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"response": "GOOD"}

        with patch("swarm.learner._settings") as mock_s:
            mock_s.reward_model_enabled = True
            mock_s.reward_model_name = "test-model"
            mock_s.reward_model_votes = 3
            mock_s.ollama_url = "http://localhost:11434"

            # return_value (not side_effect): same GOOD reply for all 3 votes.
            with patch("requests.post", return_value=mock_response):
                result = score_output("task-1", "agent-1", "do X", "done X")

        assert result is not None
        assert result.score == 1.0
        assert result.positive_votes == 3
        assert result.negative_votes == 0
        assert result.total_votes == 3
        assert result.model_used == "test-model"

    def test_negative_scoring(self):
        """All votes return BAD → score = -1.0."""
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"response": "BAD"}

        with patch("swarm.learner._settings") as mock_s:
            mock_s.reward_model_enabled = True
            mock_s.reward_model_name = "test-model"
            mock_s.reward_model_votes = 3
            mock_s.ollama_url = "http://localhost:11434"

            with patch("requests.post", return_value=mock_response):
                result = score_output("task-1", "agent-1", "do X", "bad output")

        assert result is not None
        assert result.score == -1.0
        assert result.negative_votes == 3

    def test_mixed_scoring(self):
        """2 GOOD + 1 BAD → score ≈ 0.33."""
        # side_effect consumes one prepared response per vote, in order.
        responses = []
        for text in ["GOOD", "GOOD", "BAD"]:
            resp = MagicMock()
            resp.status_code = 200
            resp.json.return_value = {"response": text}
            responses.append(resp)

        with patch("swarm.learner._settings") as mock_s:
            mock_s.reward_model_enabled = True
            mock_s.reward_model_name = "test-model"
            mock_s.reward_model_votes = 3
            mock_s.ollama_url = "http://localhost:11434"

            with patch("requests.post", side_effect=responses):
                result = score_output("task-1", "agent-1", "do X", "ok output")

        assert result is not None
        assert abs(result.score - (1 / 3)) < 0.01
        assert result.positive_votes == 2
        assert result.negative_votes == 1

    def test_uses_registry_reward_model(self):
        """Falls back to registry reward model when setting is empty."""
        # Ollama-format registry entries are addressed by .path, so that is
        # the value expected in result.model_used.
        mock_model = MagicMock()
        mock_model.path = "registry-reward-model"
        mock_model.format = MagicMock()
        mock_model.format.value = "ollama"

        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"response": "GOOD"}

        with patch("swarm.learner._settings") as mock_s:
            mock_s.reward_model_enabled = True
            mock_s.reward_model_name = ""
            mock_s.reward_model_votes = 1
            mock_s.ollama_url = "http://localhost:11434"

            with patch(
                "infrastructure.models.registry.model_registry"
            ) as mock_reg:
                mock_reg.get_reward_model.return_value = mock_model

                with patch("requests.post", return_value=mock_response):
                    result = score_output("task-1", "agent-1", "do X", "ok")

        assert result is not None
        assert result.model_used == "registry-reward-model"
|
||||
|
||||
|
||||
class TestGetRewardScores:
    """Test retrieving historical reward scores."""

    def test_empty_history(self):
        """A fresh (isolated) database returns no scores."""
        scores = get_reward_scores()
        assert scores == []

    def test_scores_persisted(self):
        """Scores from score_output are retrievable."""
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"response": "GOOD"}

        with patch("swarm.learner._settings") as mock_s:
            mock_s.reward_model_enabled = True
            mock_s.reward_model_name = "test-model"
            mock_s.reward_model_votes = 1
            mock_s.ollama_url = "http://localhost:11434"

            with patch("requests.post", return_value=mock_response):
                score_output("task-1", "agent-1", "do X", "done X")

        # Read back through the public API — exercises the real SQLite path.
        scores = get_reward_scores()
        assert len(scores) == 1
        assert scores[0]["task_id"] == "task-1"
        assert scores[0]["agent_id"] == "agent-1"
        assert scores[0]["score"] == 1.0

    def test_filter_by_agent(self):
        """Filter scores by agent_id."""
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {"response": "GOOD"}

        with patch("swarm.learner._settings") as mock_s:
            mock_s.reward_model_enabled = True
            mock_s.reward_model_name = "test-model"
            mock_s.reward_model_votes = 1
            mock_s.ollama_url = "http://localhost:11434"

            with patch("requests.post", return_value=mock_response):
                score_output("task-1", "agent-1", "task A", "output A")
                score_output("task-2", "agent-2", "task B", "output B")

        agent1_scores = get_reward_scores(agent_id="agent-1")
        assert len(agent1_scores) == 1
        assert agent1_scores[0]["agent_id"] == "agent-1"
||||
|
||||
|
||||
class TestRewardScoreDataclass:
    """Construction of the RewardScore value object."""

    def test_create_score(self):
        """Fields passed to the constructor are readable back unchanged."""
        reward = RewardScore(
            score=0.5,
            positive_votes=3,
            negative_votes=1,
            total_votes=4,
            model_used="test-model",
        )
        assert reward.score == 0.5
        assert reward.total_votes == 4
||||
Reference in New Issue
Block a user