1
0

feat: add custom weights, model registry, per-agent models, and reward scoring

Inspired by OpenClaw-RL's multi-model orchestration, this adds four
features for custom model management:

1. Custom model registry (infrastructure/models/registry.py) — SQLite-backed
   registry for GGUF, safetensors, HF checkpoint, and Ollama models with
   role-based lookups (general, reward, teacher, judge).

2. Per-agent model assignment — each swarm persona can use a different model
   instead of sharing the global default. Resolved via registry assignment >
   persona default > global default.

3. Runtime model management API (/api/v1/models) — REST endpoints to register,
   list, assign, enable/disable, and remove custom models without restart.
   Includes a dashboard page at /models.

4. Reward model scoring (PRM-style) — majority-vote quality evaluation of
   agent outputs using a configurable reward model. Scores persist in SQLite
   and feed into the swarm learner.

New config settings: custom_weights_dir, reward_model_enabled,
reward_model_name, reward_model_votes.

54 new tests covering registry CRUD, API endpoints, agent assignments,
role lookups, and reward scoring.

https://claude.ai/code/session_01V4iTozMwcE2gjfnCJdCugC
This commit is contained in:
Claude
2026-02-27 01:08:03 +00:00
parent e4d5ec5ed4
commit 211c54bc8c
13 changed files with 1603 additions and 1 deletions

View File

@@ -103,6 +103,17 @@ class Settings(BaseSettings):
work_orders_auto_execute: bool = False # Master switch for auto-execution
work_orders_auto_threshold: str = "low" # Max priority that auto-executes: "low" | "medium" | "high" | "none"
# ── Custom Weights & Models ──────────────────────────────────────
# Directory for custom model weights (GGUF, safetensors, HF checkpoints).
# Models placed here can be registered at runtime and assigned to agents.
custom_weights_dir: str = "data/models"
# Enable the reward model for scoring agent outputs (PRM-style).
reward_model_enabled: bool = False
# Reward model name (must be available via Ollama or a custom weight path).
reward_model_name: str = ""
# Minimum votes for majority-vote reward scoring (odd number recommended).
reward_model_votes: int = 3
# ── Browser Local Models (iPhone / WebGPU) ───────────────────────
# Enable in-browser LLM inference via WebLLM for offline iPhone use.
# When enabled, the mobile dashboard loads a small model directly

View File

@@ -36,6 +36,8 @@ from dashboard.routes.self_coding import router as self_coding_router
from dashboard.routes.self_coding import self_modify_router
from dashboard.routes.hands import router as hands_router
from dashboard.routes.grok import router as grok_router
from dashboard.routes.models import router as models_router
from dashboard.routes.models import api_router as models_api_router
from infrastructure.router.api import router as cascade_router
logging.basicConfig(
@@ -208,6 +210,8 @@ app.include_router(tasks_router)
app.include_router(scripture_router)
app.include_router(hands_router)
app.include_router(grok_router)
app.include_router(models_router)
app.include_router(models_api_router)
app.include_router(cascade_router)

View File

@@ -0,0 +1,272 @@
"""Custom model management routes — register, list, assign, and swap models.
Provides a REST API for managing custom model weights and their assignment
to swarm agents. Inspired by OpenClaw-RL's multi-model orchestration.
"""
import logging
from pathlib import Path
from typing import Any, Optional
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from config import settings
from infrastructure.models.registry import (
CustomModel,
ModelFormat,
ModelRegistry,
ModelRole,
model_registry,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/models", tags=["models"])
api_router = APIRouter(prefix="/api/v1/models", tags=["models-api"])
templates = Jinja2Templates(directory=str(Path(__file__).parent.parent / "templates"))
# ── Pydantic schemas ──────────────────────────────────────────────────────────
class RegisterModelRequest(BaseModel):
    """Request body for model registration."""

    # Unique registry key; re-registering the same name replaces the entry.
    name: str
    format: str  # gguf, safetensors, hf, ollama
    # Filesystem path to the weights; for format="ollama" this is the Ollama model name.
    path: str
    # One of: general, reward, teacher, judge (validated server-side).
    role: str = "general"
    context_window: int = 4096
    description: str = ""
    # Per-model generation defaults applied when this model serves an agent.
    default_temperature: float = 0.7
    max_tokens: int = 2048
class AssignModelRequest(BaseModel):
    """Request body for assigning a model to an agent."""

    # Swarm agent identifier (persona id) receiving the override.
    agent_id: str
    # Must match a model name already present in the registry.
    model_name: str
class SetActiveRequest(BaseModel):
    """Request body for enabling/disabling a model."""

    # True to enable the model, False to disable it without unregistering.
    active: bool
# ── API endpoints ─────────────────────────────────────────────────────────────
@api_router.get("")
async def list_models(role: Optional[str] = None) -> dict[str, Any]:
    """List all registered custom models.

    Args:
        role: Optional role filter ("general", "reward", "teacher", "judge").

    Raises:
        HTTPException: 400 when ``role`` is not a valid ModelRole value
            (previously this leaked a ValueError as a 500).
    """
    model_role: Optional[ModelRole] = None
    if role:
        try:
            model_role = ModelRole(role)
        except ValueError:
            # An invalid filter is a client error, mirroring the POST endpoint.
            raise HTTPException(
                status_code=400,
                detail=f"Invalid role: {role}. "
                f"Choose from: {[r.value for r in ModelRole]}",
            ) from None
    models = model_registry.list_models(role=model_role)
    return {
        "models": [
            {
                "name": m.name,
                "format": m.format.value,
                "path": m.path,
                "role": m.role.value,
                "context_window": m.context_window,
                "description": m.description,
                "active": m.active,
                "registered_at": m.registered_at,
                "default_temperature": m.default_temperature,
                "max_tokens": m.max_tokens,
            }
            for m in models
        ],
        "total": len(models),
        "weights_dir": settings.custom_weights_dir,
    }
@api_router.post("")
async def register_model(request: RegisterModelRequest) -> dict[str, Any]:
    """Register a new custom model.

    Validates the format and role enums, and — for weights that live on
    disk — that the weight path actually exists. Ollama models are looked
    up by name at inference time, so no path check applies to them.

    Raises:
        HTTPException: 400 for an invalid format, invalid role, or a
            missing weight path.
    """
    try:
        fmt = ModelFormat(request.format)
    except ValueError:
        # `from None`: the ValueError carries no extra information for clients.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid format: {request.format}. "
            f"Choose from: {[f.value for f in ModelFormat]}",
        ) from None
    try:
        role = ModelRole(request.role)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid role: {request.role}. "
            f"Choose from: {[r.value for r in ModelRole]}",
        ) from None
    # Validate path exists for non-Ollama formats
    if fmt != ModelFormat.OLLAMA:
        weight_path = Path(request.path)
        if not weight_path.exists():
            raise HTTPException(
                status_code=400,
                detail=f"Weight path does not exist: {request.path}",
            )
    model = CustomModel(
        name=request.name,
        format=fmt,
        path=request.path,
        role=role,
        context_window=request.context_window,
        description=request.description,
        default_temperature=request.default_temperature,
        max_tokens=request.max_tokens,
    )
    registered = model_registry.register(model)
    return {
        "message": f"Model {registered.name} registered",
        "model": {
            "name": registered.name,
            "format": registered.format.value,
            "role": registered.role.value,
            "path": registered.path,
        },
    }
@api_router.get("/{model_name}")
async def get_model(model_name: str) -> dict[str, Any]:
    """Return the full details of a single registered model (404 if unknown)."""
    found = model_registry.get(model_name)
    if found is None:
        raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
    details: dict[str, Any] = {
        "name": found.name,
        "format": found.format.value,
        "path": found.path,
        "role": found.role.value,
        "context_window": found.context_window,
        "description": found.description,
        "active": found.active,
        "registered_at": found.registered_at,
        "default_temperature": found.default_temperature,
        "max_tokens": found.max_tokens,
    }
    return details
@api_router.delete("/{model_name}")
async def unregister_model(model_name: str) -> dict[str, str]:
    """Delete a model from the registry; 404 when no such model exists."""
    removed = model_registry.unregister(model_name)
    if not removed:
        raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
    return {"message": f"Model {model_name} unregistered"}
@api_router.patch("/{model_name}/active")
async def set_model_active(
    model_name: str, request: SetActiveRequest
) -> dict[str, str]:
    """Toggle a model's active flag without removing it from the registry."""
    updated = model_registry.set_active(model_name, request.active)
    if not updated:
        raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
    if request.active:
        state = "enabled"
    else:
        state = "disabled"
    return {"message": f"Model {model_name} {state}"}
# ── Agent assignment endpoints ────────────────────────────────────────────────
@api_router.get("/assignments/all")
async def list_assignments() -> dict[str, Any]:
    """Return every agent-to-model assignment currently in effect."""
    current = model_registry.get_agent_assignments()
    rows = [
        {"agent_id": agent, "model_name": model}
        for agent, model in current.items()
    ]
    return {"assignments": rows, "total": len(rows)}
@api_router.post("/assignments")
async def assign_model(request: AssignModelRequest) -> dict[str, str]:
    """Pin a swarm agent to a specific registered model (404 if unregistered)."""
    succeeded = model_registry.assign_model(request.agent_id, request.model_name)
    if not succeeded:
        raise HTTPException(
            status_code=404,
            detail=f"Model {request.model_name} not found in registry",
        )
    return {
        "message": f"Model {request.model_name} assigned to {request.agent_id}",
    }
@api_router.delete("/assignments/{agent_id}")
async def unassign_model(agent_id: str) -> dict[str, str]:
    """Clear an agent's model override so it falls back to the default model."""
    cleared = model_registry.unassign_model(agent_id)
    if not cleared:
        raise HTTPException(
            status_code=404,
            detail=f"No model assignment for agent {agent_id}",
        )
    return {"message": f"Model assignment removed for {agent_id}"}
# ── Role-based lookups ────────────────────────────────────────────────────────
@api_router.get("/roles/reward")
async def get_reward_model() -> dict[str, Any]:
    """Return the active reward (PRM) model plus the global enable flag."""
    enabled = settings.reward_model_enabled
    found = model_registry.get_reward_model()
    if found is None:
        return {"reward_model": None, "reward_enabled": enabled}
    summary = {
        "name": found.name,
        "format": found.format.value,
        "path": found.path,
    }
    return {"reward_model": summary, "reward_enabled": enabled}
@api_router.get("/roles/teacher")
async def get_teacher_model() -> dict[str, Any]:
    """Return the active teacher model used for distillation, if any."""
    found = model_registry.get_teacher_model()
    if found is None:
        return {"teacher_model": None}
    summary = {
        "name": found.name,
        "format": found.format.value,
        "path": found.path,
    }
    return {"teacher_model": summary}
# ── Dashboard page ────────────────────────────────────────────────────────────
@router.get("", response_class=HTMLResponse)
async def models_page(request: Request):
    """Render the custom-models management dashboard page."""
    context = {
        "page_title": "Custom Models",
        "models": model_registry.list_models(),
        "assignments": model_registry.get_agent_assignments(),
        "reward_model": model_registry.get_reward_model(),
        "weights_dir": settings.custom_weights_dir,
        "reward_enabled": settings.reward_model_enabled,
    }
    return templates.TemplateResponse(request, "models.html", context)

View File

@@ -0,0 +1,119 @@
{% extends "base.html" %}
{% block title %}Custom Models - Timmy Time{% endblock %}
{% block content %}
<div class="mc-panel">
<div class="mc-panel-header">
<h1 class="page-title">Custom Models</h1>
<p class="mc-text-secondary">Manage model weights and agent assignments</p>
</div>
<!-- Stats -->
<div class="mc-stats-row">
<div class="mc-stat-card">
<div class="mc-stat-value">{{ models|length }}</div>
<div class="mc-stat-label">Models</div>
</div>
<div class="mc-stat-card">
<div class="mc-stat-value">{{ assignments|length }}</div>
<div class="mc-stat-label">Assignments</div>
</div>
<div class="mc-stat-card">
<div class="mc-stat-value">{{ "Yes" if reward_model else "No" }}</div>
<div class="mc-stat-label">Reward Model</div>
</div>
</div>
<!-- Register Model Form -->
{# NOTE(review): htmx posts this form as application/x-www-form-urlencoded,
   but POST /api/v1/models declares a Pydantic JSON body — FastAPI will
   likely answer 422. Confirm; if so, add the htmx json-enc extension
   (hx-ext="json-enc") or accept Form fields server-side. #}
{# NOTE(review): the API responds with JSON, which is swapped into
   #model-result as raw text — verify that is the intended UX. #}
<div class="mc-section" style="margin-top: 1.5rem;">
<h2>Register Model</h2>
<form hx-post="/api/v1/models" hx-target="#model-result" hx-swap="innerHTML"
style="display: grid; gap: 0.5rem; max-width: 500px;">
<input name="name" placeholder="Model name" required class="mc-input" />
<select name="format" class="mc-input">
<option value="ollama">Ollama</option>
<option value="gguf">GGUF</option>
<option value="safetensors">Safetensors</option>
<option value="hf">HF Checkpoint</option>
</select>
<input name="path" placeholder="Path or Ollama model name" required class="mc-input" />
<select name="role" class="mc-input">
<option value="general">General</option>
<option value="reward">Reward (PRM)</option>
<option value="teacher">Teacher</option>
<option value="judge">Judge</option>
</select>
<input name="context_window" type="number" value="4096" class="mc-input" />
<input name="description" placeholder="Description (optional)" class="mc-input" />
<button type="submit" class="mc-btn mc-btn-primary">Register</button>
</form>
<div id="model-result" style="margin-top: 0.5rem;"></div>
</div>
<!-- Registered Models -->
<div class="mc-section" style="margin-top: 1.5rem;">
<h2>Registered Models</h2>
{% if models %}
<table class="mc-table">
<thead>
<tr>
<th>Name</th>
<th>Format</th>
<th>Role</th>
<th>Context</th>
<th>Active</th>
<th>Actions</th>
</tr>
</thead>
<tbody>
{% for m in models %}
<tr>
<td><strong>{{ m.name }}</strong></td>
<td>{{ m.format.value }}</td>
<td>{{ m.role.value }}</td>
<td>{{ m.context_window }}</td>
<td>{{ "Yes" if m.active else "No" }}</td>
<td>
{# NOTE(review): DELETE returns a JSON message that replaces the whole <tr>
   via outerHTML — the raw JSON text will appear where the row was. #}
<button class="mc-btn mc-btn-sm"
hx-delete="/api/v1/models/{{ m.name }}"
hx-confirm="Remove {{ m.name }}?"
hx-target="closest tr"
hx-swap="outerHTML">Remove</button>
</td>
</tr>
{% endfor %}
</tbody>
</table>
{% else %}
<p class="mc-text-secondary">No custom models registered. Use the form above or the API.</p>
{% endif %}
</div>
<!-- Agent Assignments -->
<div class="mc-section" style="margin-top: 1.5rem;">
<h2>Agent Model Assignments</h2>
{% if assignments %}
<table class="mc-table">
<thead>
<tr><th>Agent</th><th>Model</th></tr>
</thead>
<tbody>
{% for agent_id, model_name in assignments.items() %}
<tr>
<td>{{ agent_id }}</td>
<td>{{ model_name }}</td>
</tr>
{% endfor %}
</tbody>
</table>
{% else %}
<p class="mc-text-secondary">No agent-specific model assignments. All agents use the global default.</p>
{% endif %}
</div>
<div class="mc-section" style="margin-top: 1rem;">
<p class="mc-text-secondary">Weights directory: <code>{{ weights_dir }}</code></p>
</div>
</div>
{% endblock %}

View File

View File

@@ -0,0 +1,268 @@
"""Custom model registry — register, load, and manage model weights.
Tracks custom models (GGUF files, HF checkpoints, Ollama modelfiles)
and their assignment to swarm agents. Models can be registered at
runtime via the API or pre-configured via providers.yaml.
Inspired by OpenClaw-RL's multi-model orchestration where distinct
model roles (student, teacher, judge/PRM) run on dedicated resources.
"""
import logging
import sqlite3
import threading
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
from typing import Optional
from config import settings
logger = logging.getLogger(__name__)
DB_PATH = Path("data/swarm.db")
class ModelFormat(str, Enum):
    """Supported model weight formats.

    Subclasses ``str`` so members compare equal to (and serialize as)
    their raw string values, e.g. in SQLite rows and JSON responses.
    """

    GGUF = "gguf"  # Ollama-compatible quantised weights
    SAFETENSORS = "safetensors"  # HuggingFace safetensors
    HF_CHECKPOINT = "hf"  # Full HuggingFace checkpoint directory
    OLLAMA = "ollama"  # Already loaded in Ollama by name
class ModelRole(str, Enum):
    """Role a model can play in the system (OpenClaw-RL style).

    ``str`` subclass: members serialize as their plain string values.
    """

    GENERAL = "general"  # Default agent inference
    REWARD = "reward"  # Process Reward Model (PRM) scoring
    TEACHER = "teacher"  # On-policy distillation teacher
    JUDGE = "judge"  # Output quality evaluation
@dataclass
class CustomModel:
    """A registered custom model and its per-model generation defaults."""

    name: str
    format: ModelFormat
    path: str  # Absolute path or Ollama model name
    role: ModelRole = ModelRole.GENERAL
    context_window: int = 4096
    description: str = ""
    registered_at: str = ""
    active: bool = True
    # Per-model generation settings
    default_temperature: float = 0.7
    max_tokens: int = 2048

    def __post_init__(self):
        # Stamp the registration time only when the caller did not supply one.
        self.registered_at = self.registered_at or datetime.now(timezone.utc).isoformat()
def _get_conn() -> sqlite3.Connection:
    """Open a SQLite connection to the registry DB, creating the schema if needed.

    A fresh connection is opened per call and callers are responsible for
    closing it; nothing is shared across threads.
    """
    DB_PATH.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(str(DB_PATH))
    # Row factory lets loaders access columns by name (row["name"], …).
    conn.row_factory = sqlite3.Row
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS custom_models (
            name TEXT PRIMARY KEY,
            format TEXT NOT NULL,
            path TEXT NOT NULL,
            role TEXT NOT NULL DEFAULT 'general',
            context_window INTEGER NOT NULL DEFAULT 4096,
            description TEXT NOT NULL DEFAULT '',
            registered_at TEXT NOT NULL,
            active INTEGER NOT NULL DEFAULT 1,
            default_temperature REAL NOT NULL DEFAULT 0.7,
            max_tokens INTEGER NOT NULL DEFAULT 2048
        )
        """
    )
    # NOTE(review): SQLite does not enforce this FOREIGN KEY unless
    # PRAGMA foreign_keys=ON is set per connection; the registry deletes
    # assignments manually on unregister, which compensates — confirm.
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS agent_model_assignments (
            agent_id TEXT PRIMARY KEY,
            model_name TEXT NOT NULL,
            assigned_at TEXT NOT NULL,
            FOREIGN KEY (model_name) REFERENCES custom_models(name)
        )
        """
    )
    conn.commit()
    return conn
class ModelRegistry:
    """Singleton registry for custom models and agent-model assignments.

    Every mutation writes to SQLite first, then updates the in-memory
    cache under ``self._lock``; reads are served straight from the cache
    that is bootstrapped at construction time.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        # In-memory cache for fast lookups (mirrors the SQLite tables).
        self._models: dict[str, CustomModel] = {}
        self._agent_assignments: dict[str, str] = {}
        self._load_from_db()

    def _load_from_db(self) -> None:
        """Bootstrap cache from SQLite.

        Loads *every* row, including inactive models: ``get``,
        ``set_active`` and ``unregister`` all consult this cache, so
        filtering out inactive rows here would make disabled models
        impossible to re-enable, inspect, or remove after a restart.
        """
        try:
            conn = _get_conn()
            for row in conn.execute("SELECT * FROM custom_models").fetchall():
                self._models[row["name"]] = CustomModel(
                    name=row["name"],
                    format=ModelFormat(row["format"]),
                    path=row["path"],
                    role=ModelRole(row["role"]),
                    context_window=row["context_window"],
                    description=row["description"],
                    registered_at=row["registered_at"],
                    active=bool(row["active"]),
                    default_temperature=row["default_temperature"],
                    max_tokens=row["max_tokens"],
                )
            for row in conn.execute("SELECT * FROM agent_model_assignments").fetchall():
                self._agent_assignments[row["agent_id"]] = row["model_name"]
            conn.close()
        except Exception as exc:
            # Best-effort bootstrap: a broken DB leaves an empty registry
            # rather than preventing startup.
            logger.warning("Failed to load model registry from DB: %s", exc)

    # ── Model CRUD ─────────────────────────────────────────────────────────

    def register(self, model: CustomModel) -> CustomModel:
        """Register (or replace, by name) a custom model and persist it."""
        with self._lock:
            conn = _get_conn()
            conn.execute(
                """
                INSERT OR REPLACE INTO custom_models
                (name, format, path, role, context_window, description,
                 registered_at, active, default_temperature, max_tokens)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    model.name, model.format.value, model.path,
                    model.role.value, model.context_window, model.description,
                    model.registered_at, int(model.active),
                    model.default_temperature, model.max_tokens,
                ),
            )
            conn.commit()
            conn.close()
            self._models[model.name] = model
        logger.info("Registered model: %s (%s)", model.name, model.format.value)
        return model

    def unregister(self, name: str) -> bool:
        """Remove a model and any agent assignments that reference it.

        Returns:
            True when the model existed and was removed, False otherwise.
        """
        with self._lock:
            if name not in self._models:
                return False
            conn = _get_conn()
            conn.execute("DELETE FROM custom_models WHERE name = ?", (name,))
            conn.execute(
                "DELETE FROM agent_model_assignments WHERE model_name = ?", (name,)
            )
            conn.commit()
            conn.close()
            del self._models[name]
            # Remove any agent assignments using this model
            self._agent_assignments = {
                k: v for k, v in self._agent_assignments.items() if v != name
            }
        logger.info("Unregistered model: %s", name)
        return True

    def get(self, name: str) -> Optional[CustomModel]:
        """Look up a model by name (active or not)."""
        return self._models.get(name)

    def list_models(self, role: Optional[ModelRole] = None) -> list[CustomModel]:
        """List all registered models, optionally filtered by role."""
        models = list(self._models.values())
        if role is not None:
            models = [m for m in models if m.role == role]
        return models

    def set_active(self, name: str, active: bool) -> bool:
        """Enable or disable a model without removing it.

        Returns:
            True when the model exists, False otherwise.
        """
        with self._lock:
            # Existence check inside the lock so a concurrent unregister
            # cannot slip between check and update.
            model = self._models.get(name)
            if not model:
                return False
            model.active = active
            conn = _get_conn()
            conn.execute(
                "UPDATE custom_models SET active = ? WHERE name = ?",
                (int(active), name),
            )
            conn.commit()
            conn.close()
        return True

    # ── Agent-model assignments ────────────────────────────────────────────

    def assign_model(self, agent_id: str, model_name: str) -> bool:
        """Assign a specific registered model to an agent.

        Returns:
            False when ``model_name`` is not in the registry.
        """
        with self._lock:
            if model_name not in self._models:
                return False
            now = datetime.now(timezone.utc).isoformat()
            conn = _get_conn()
            conn.execute(
                """
                INSERT OR REPLACE INTO agent_model_assignments
                (agent_id, model_name, assigned_at)
                VALUES (?, ?, ?)
                """,
                (agent_id, model_name, now),
            )
            conn.commit()
            conn.close()
            self._agent_assignments[agent_id] = model_name
        logger.info("Assigned model %s to agent %s", model_name, agent_id)
        return True

    def unassign_model(self, agent_id: str) -> bool:
        """Remove model assignment from an agent (falls back to default)."""
        with self._lock:
            if agent_id not in self._agent_assignments:
                return False
            conn = _get_conn()
            conn.execute(
                "DELETE FROM agent_model_assignments WHERE agent_id = ?",
                (agent_id,),
            )
            conn.commit()
            conn.close()
            del self._agent_assignments[agent_id]
        return True

    def get_agent_model(self, agent_id: str) -> Optional[CustomModel]:
        """Get the model assigned to an agent, or None for the global default."""
        model_name = self._agent_assignments.get(agent_id)
        if model_name:
            return self._models.get(model_name)
        return None

    def get_agent_assignments(self) -> dict[str, str]:
        """Return a copy of all agent-to-model assignments."""
        return dict(self._agent_assignments)

    # ── Role-based lookups ─────────────────────────────────────────────────

    def get_reward_model(self) -> Optional[CustomModel]:
        """Get the first *active* reward/PRM model, if any."""
        reward_models = self.list_models(role=ModelRole.REWARD)
        active = [m for m in reward_models if m.active]
        return active[0] if active else None

    def get_teacher_model(self) -> Optional[CustomModel]:
        """Get the first *active* teacher model for distillation."""
        teacher_models = self.list_models(role=ModelRole.TEACHER)
        active = [m for m in teacher_models if m.active]
        return active[0] if active else None


# Module-level singleton
model_registry = ModelRegistry()

View File

@@ -251,3 +251,193 @@ def learned_keywords(agent_id: str) -> list[dict]:
results.append({"keyword": kw, "wins": wins, "failures": fails, "net": wins - fails})
results.sort(key=lambda x: x["net"], reverse=True)
return results
# ── Reward model scoring (PRM-style) ─────────────────────────────────────────
import logging as _logging
from config import settings as _settings
_reward_logger = _logging.getLogger(__name__ + ".reward")
@dataclass
class RewardScore:
    """Result from reward-model evaluation."""

    score: float  # Normalised score in [-1.0, 1.0]
    positive_votes: int  # Votes that answered GOOD
    negative_votes: int  # Votes that answered BAD (request errors count here too)
    total_votes: int  # positive_votes + negative_votes
    model_used: str  # Name of the reward model that produced the votes
def _ensure_reward_table() -> None:
    """Create the reward_scores table if needed.

    Idempotent; called before every read/write so the table exists even
    on a fresh database. ``scored_at`` defaults to the insert time.
    """
    conn = _get_conn()
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS reward_scores (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            task_id TEXT NOT NULL,
            agent_id TEXT NOT NULL,
            output_text TEXT NOT NULL,
            score REAL NOT NULL,
            positive INTEGER NOT NULL,
            negative INTEGER NOT NULL,
            total INTEGER NOT NULL,
            model_used TEXT NOT NULL,
            scored_at TEXT NOT NULL DEFAULT (datetime('now'))
        )
        """
    )
    conn.commit()
    conn.close()
def score_output(
    task_id: str,
    agent_id: str,
    task_description: str,
    output_text: str,
) -> Optional[RewardScore]:
    """Score an agent's output using the reward model (majority vote).

    Calls the reward model N times (settings.reward_model_votes) with a
    quality-evaluation prompt. Each vote is +1 (good) or -1 (bad).
    Final score is (positive - negative) / total, in [-1.0, 1.0].
    Returns None if the reward model is disabled or unavailable.
    """
    if not _settings.reward_model_enabled:
        return None
    # Resolve model name: explicit setting > registry reward model > skip
    model_name = _settings.reward_model_name
    if not model_name:
        try:
            # Imported lazily so this module works without the registry.
            from infrastructure.models.registry import model_registry
            reward = model_registry.get_reward_model()
            if reward:
                # Ollama models are addressed by their Ollama name (stored
                # in `path`); other formats are addressed by registry name.
                model_name = reward.path if reward.format.value == "ollama" else reward.name
        except Exception:
            pass
    if not model_name:
        _reward_logger.debug("No reward model configured, skipping scoring")
        return None
    num_votes = max(1, _settings.reward_model_votes)
    positive = 0
    negative = 0
    # Output truncated to 2000 chars to keep the prompt inside the context window.
    prompt = (
        f"You are a quality evaluator. Rate the following agent output.\n\n"
        f"TASK: {task_description}\n\n"
        f"OUTPUT:\n{output_text[:2000]}\n\n"
        f"Is this output correct, helpful, and complete? "
        f"Reply with exactly one word: GOOD or BAD."
    )
    try:
        import requests as _req
        ollama_url = _settings.ollama_url
        for _ in range(num_votes):
            try:
                # temperature > 0 presumably gives vote diversity across the
                # identical prompts; num_predict caps per-vote cost.
                resp = _req.post(
                    f"{ollama_url}/api/generate",
                    json={
                        "model": model_name,
                        "prompt": prompt,
                        "stream": False,
                        "options": {"temperature": 0.3, "num_predict": 10},
                    },
                    timeout=30,
                )
                if resp.status_code == 200:
                    answer = resp.json().get("response", "").strip().upper()
                    # NOTE(review): substring match — a reply like "NOT GOOD"
                    # also counts as positive; consider matching the first
                    # token instead. Confirm intended behaviour.
                    if "GOOD" in answer:
                        positive += 1
                    else:
                        negative += 1
                else:
                    negative += 1  # Treat errors as negative conservatively
            except Exception as vote_exc:
                _reward_logger.debug("Vote failed: %s", vote_exc)
                negative += 1
    except ImportError:
        _reward_logger.warning("requests library not available for reward scoring")
        return None
    total = positive + negative
    if total == 0:
        return None
    score = (positive - negative) / total
    result = RewardScore(
        score=score,
        positive_votes=positive,
        negative_votes=negative,
        total_votes=total,
        model_used=model_name,
    )
    # Persist to DB
    try:
        _ensure_reward_table()
        conn = _get_conn()
        conn.execute(
            """
            INSERT INTO reward_scores
            (task_id, agent_id, output_text, score, positive, negative, total, model_used)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                task_id, agent_id, output_text[:5000],
                score, positive, negative, total, model_name,
            ),
        )
        conn.commit()
        conn.close()
    except Exception as db_exc:
        # Scoring still succeeds even if persistence fails.
        _reward_logger.warning("Failed to persist reward score: %s", db_exc)
    _reward_logger.info(
        "Scored task %s agent %s: %.2f (%d+/%d- of %d votes)",
        task_id, agent_id, score, positive, negative, total,
    )
    return result
def get_reward_scores(
    agent_id: Optional[str] = None, limit: int = 50
) -> list[dict]:
    """Return up to `limit` most recent reward scores, newest first.

    When `agent_id` is given, only that agent's scores are returned.
    """
    _ensure_reward_table()
    conn = _get_conn()
    if agent_id:
        query = "SELECT * FROM reward_scores WHERE agent_id = ? ORDER BY id DESC LIMIT ?"
        params = (agent_id, limit)
    else:
        query = "SELECT * FROM reward_scores ORDER BY id DESC LIMIT ?"
        params = (limit,)
    rows = conn.execute(query, params).fetchall()
    conn.close()
    columns = (
        "task_id", "agent_id", "score", "positive",
        "negative", "total", "model_used", "scored_at",
    )
    return [{col: r[col] for col in columns} for r in rows]

View File

@@ -51,6 +51,16 @@ class PersonaNode(SwarmNode):
self._meta = meta
self._persona_id = persona_id
self._use_learner = use_learner
# Resolve model: registry assignment > persona default > global default
self._model_name: Optional[str] = meta.get("model")
try:
from infrastructure.models.registry import model_registry
assigned = model_registry.get_agent_model(agent_id)
if assigned:
self._model_name = assigned.name
except Exception:
pass # Graceful degradation — use persona/global default
# Initialize tool executor for task execution
self._tool_executor: Optional[ToolExecutor] = None
@@ -213,6 +223,11 @@ class PersonaNode(SwarmNode):
"""Return the task ID currently being executed, if any."""
return self._current_task
@property
def model_name(self) -> Optional[str]:
"""Return the model this agent uses, or None for global default."""
return self._model_name
@property
def tool_capabilities(self) -> list[str]:
"""Return list of available tool names."""

View File

@@ -14,7 +14,7 @@ from __future__ import annotations
from typing import TypedDict
class PersonaMeta(TypedDict):
class PersonaMeta(TypedDict, total=False):
id: str
name: str
role: str
@@ -24,6 +24,11 @@ class PersonaMeta(TypedDict):
bid_base: int # typical bid when task matches persona
bid_jitter: int # ± random jitter added to bid_base
preferred_keywords: list[str]
# Optional: custom model override for this persona.
# When set, this persona uses this model instead of the global default.
# Value is a model name registered in the ModelRegistry, or an Ollama
# model name like "llama3.2" or "deepseek-r1:1.5b".
model: str
PERSONAS: dict[str, PersonaMeta] = {