1
0
This repository was archived on 2026-03-24. You can view files and clone it, but you cannot open issues or pull requests, or push commits.
Files
Timmy-time-dashboard/tests/infrastructure/test_models_api.py
Claude 211c54bc8c feat: add custom weights, model registry, per-agent models, and reward scoring
Inspired by OpenClaw-RL's multi-model orchestration, this adds four
features for custom model management:

1. Custom model registry (infrastructure/models/registry.py) — SQLite-backed
   registry for GGUF, safetensors, HF checkpoint, and Ollama models with
   role-based lookups (general, reward, teacher, judge).

2. Per-agent model assignment — each swarm persona can use a different model
   instead of sharing the global default. Resolution order: registry
   assignment first, then the persona default, then the global default.

3. Runtime model management API (/api/v1/models) — REST endpoints to register,
   list, assign, enable/disable, and remove custom models without restart.
   Includes a dashboard page at /models.

4. Reward model scoring (PRM-style) — majority-vote quality evaluation of
   agent outputs using a configurable reward model. Scores persist in SQLite
   and feed into the swarm learner.

New config settings: custom_weights_dir, reward_model_enabled,
reward_model_name, reward_model_votes.

54 new tests covering registry CRUD, API endpoints, agent assignments,
role lookups, and reward scoring.

https://claude.ai/code/session_01V4iTozMwcE2gjfnCJdCugC
2026-02-27 01:27:53 +00:00

274 lines
8.6 KiB
Python

"""Tests for the custom models API routes."""
from unittest.mock import patch, MagicMock
import pytest
from infrastructure.models.registry import (
CustomModel,
ModelFormat,
ModelRegistry,
ModelRole,
)
@pytest.fixture
def registry(tmp_path):
    """Provide a ModelRegistry backed by a per-test SQLite file.

    ``infrastructure.models.registry.DB_PATH`` is redirected into ``tmp_path``
    before the registry is constructed, and it stays redirected until the
    consuming test finishes (the ``yield`` sits inside the ``patch`` context).

    NOTE(review): no test in this module requests this fixture; it may be
    intended for future registry-backed API tests.
    """
    isolated_db = tmp_path / "api_test.db"
    with patch("infrastructure.models.registry.DB_PATH", isolated_db):
        yield ModelRegistry()
class TestModelsAPIList:
    """Test listing models via ``GET /api/v1/models``."""

    def test_list_models_empty(self, client):
        """An empty registry yields an empty list and a zero total."""
        # The route's ``model_registry`` is replaced by a mock, so the real
        # registry's storage is never reached; patching DB_PATH as well would
        # be redundant setup.
        with patch("dashboard.routes.models.model_registry") as mock_reg:
            mock_reg.list_models.return_value = []
            resp = client.get("/api/v1/models")
        assert resp.status_code == 200
        data = resp.json()
        # Pin the actual values, not merely the presence of the keys.
        assert data["models"] == []
        assert data["total"] == 0

    def test_list_models_with_data(self, client):
        """A single registered model is serialized with its name and counted."""
        model = CustomModel(
            name="test-m",
            format=ModelFormat.OLLAMA,
            path="llama3.2",
            role=ModelRole.GENERAL,
        )
        with patch("dashboard.routes.models.model_registry") as mock_reg:
            mock_reg.list_models.return_value = [model]
            resp = client.get("/api/v1/models")
        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 1
        # ``total`` should agree with the length of the returned list.
        assert len(data["models"]) == 1
        assert data["models"][0]["name"] == "test-m"
class TestModelsAPIRegister:
    """Test model registration via ``POST /api/v1/models``."""

    def test_register_ollama_model(self, client):
        """A valid Ollama payload is registered and echoed back."""
        with patch("dashboard.routes.models.model_registry") as mock_reg:
            mock_reg.register.return_value = CustomModel(
                name="my-model",
                format=ModelFormat.OLLAMA,
                path="llama3.2",
                role=ModelRole.GENERAL,
            )
            resp = client.post(
                "/api/v1/models",
                json={
                    "name": "my-model",
                    "format": "ollama",
                    "path": "llama3.2",
                    "role": "general",
                },
            )
        assert resp.status_code == 200
        # The endpoint must actually delegate to the registry.
        mock_reg.register.assert_called_once()
        data = resp.json()
        assert data["model"]["name"] == "my-model"

    def test_register_invalid_format(self, client):
        """An unknown ``format`` is rejected with 400 and nothing is stored."""
        # Patch the registry so a request that slipped past validation could
        # never reach the route's real ``model_registry``, and so we can prove
        # that no registration happened.
        with patch("dashboard.routes.models.model_registry") as mock_reg:
            resp = client.post(
                "/api/v1/models",
                json={
                    "name": "bad-model",
                    "format": "invalid_format",
                    "path": "whatever",
                },
            )
        assert resp.status_code == 400
        assert "Invalid format" in resp.json()["detail"]
        mock_reg.register.assert_not_called()

    def test_register_invalid_role(self, client):
        """An unknown ``role`` is rejected with 400 and nothing is stored."""
        # Same isolation rationale as test_register_invalid_format.
        with patch("dashboard.routes.models.model_registry") as mock_reg:
            resp = client.post(
                "/api/v1/models",
                json={
                    "name": "bad-model",
                    "format": "ollama",
                    "path": "llama3.2",
                    "role": "invalid_role",
                },
            )
        assert resp.status_code == 400
        assert "Invalid role" in resp.json()["detail"]
        mock_reg.register.assert_not_called()
class TestModelsAPIDelete:
    """Exercise ``DELETE /api/v1/models/{name}`` against a mocked registry."""

    _REGISTRY_TARGET = "dashboard.routes.models.model_registry"

    def test_delete_model(self, client):
        # A truthy unregister() result is expected to map to HTTP 200.
        with patch(self._REGISTRY_TARGET) as registry_double:
            registry_double.unregister.return_value = True
            response = client.delete("/api/v1/models/my-model")
        assert response.status_code == 200

    def test_delete_nonexistent(self, client):
        # A falsy unregister() result is expected to map to HTTP 404.
        with patch(self._REGISTRY_TARGET) as registry_double:
            registry_double.unregister.return_value = False
            response = client.delete("/api/v1/models/nonexistent")
        assert response.status_code == 404
class TestModelsAPIGet:
    """Exercise ``GET /api/v1/models/{name}`` against a mocked registry."""

    _REGISTRY_TARGET = "dashboard.routes.models.model_registry"

    def test_get_model(self, client):
        # The registry hands back a concrete model; its name should be echoed.
        stored = CustomModel(
            name="my-model",
            format=ModelFormat.OLLAMA,
            path="llama3.2",
            role=ModelRole.GENERAL,
        )
        with patch(self._REGISTRY_TARGET) as registry_double:
            registry_double.get.return_value = stored
            response = client.get("/api/v1/models/my-model")
        assert response.status_code == 200
        assert response.json()["name"] == "my-model"

    def test_get_nonexistent(self, client):
        # A ``None`` lookup result is expected to surface as HTTP 404.
        with patch(self._REGISTRY_TARGET) as registry_double:
            registry_double.get.return_value = None
            response = client.get("/api/v1/models/nonexistent")
        assert response.status_code == 404
class TestModelsAPIAssignments:
    """Exercise the per-agent model assignment endpoints with a mocked registry."""

    _REGISTRY_TARGET = "dashboard.routes.models.model_registry"

    def test_assign_model(self, client):
        # assign_model() succeeding should yield HTTP 200.
        payload = {"agent_id": "agent-1", "model_name": "my-model"}
        with patch(self._REGISTRY_TARGET) as registry_double:
            registry_double.assign_model.return_value = True
            response = client.post("/api/v1/models/assignments", json=payload)
        assert response.status_code == 200

    def test_assign_nonexistent_model(self, client):
        # assign_model() failing (unknown model) should yield HTTP 404.
        payload = {"agent_id": "agent-1", "model_name": "nonexistent"}
        with patch(self._REGISTRY_TARGET) as registry_double:
            registry_double.assign_model.return_value = False
            response = client.post("/api/v1/models/assignments", json=payload)
        assert response.status_code == 404

    def test_unassign_model(self, client):
        # unassign_model() succeeding should yield HTTP 200.
        with patch(self._REGISTRY_TARGET) as registry_double:
            registry_double.unassign_model.return_value = True
            response = client.delete("/api/v1/models/assignments/agent-1")
        assert response.status_code == 200

    def test_unassign_nonexistent(self, client):
        # unassign_model() failing should yield HTTP 404.
        with patch(self._REGISTRY_TARGET) as registry_double:
            registry_double.unassign_model.return_value = False
            response = client.delete("/api/v1/models/assignments/nonexistent")
        assert response.status_code == 404

    def test_list_assignments(self, client):
        # Two agent -> model pairs should be reported as a total of 2.
        mapping = {
            "agent-1": "model-a",
            "agent-2": "model-b",
        }
        with patch(self._REGISTRY_TARGET) as registry_double:
            registry_double.get_agent_assignments.return_value = mapping
            response = client.get("/api/v1/models/assignments/all")
        assert response.status_code == 200
        assert response.json()["total"] == 2
class TestModelsAPIRoles:
    """Exercise the role-based lookup endpoints (reward and teacher)."""

    _REGISTRY_TARGET = "dashboard.routes.models.model_registry"

    def test_get_reward_model(self, client):
        # A configured reward model should be returned under "reward_model".
        reward = CustomModel(
            name="reward-m",
            format=ModelFormat.OLLAMA,
            path="deepseek-r1:1.5b",
            role=ModelRole.REWARD,
        )
        with patch(self._REGISTRY_TARGET) as registry_double:
            registry_double.get_reward_model.return_value = reward
            response = client.get("/api/v1/models/roles/reward")
        assert response.status_code == 200
        assert response.json()["reward_model"]["name"] == "reward-m"

    def test_get_reward_model_none(self, client):
        # No reward model is still a 200, with an explicit null payload.
        with patch(self._REGISTRY_TARGET) as registry_double:
            registry_double.get_reward_model.return_value = None
            response = client.get("/api/v1/models/roles/reward")
        assert response.status_code == 200
        assert response.json()["reward_model"] is None

    def test_get_teacher_model(self, client):
        # No teacher model is still a 200, with an explicit null payload.
        with patch(self._REGISTRY_TARGET) as registry_double:
            registry_double.get_teacher_model.return_value = None
            response = client.get("/api/v1/models/roles/teacher")
        assert response.status_code == 200
        assert response.json()["teacher_model"] is None
class TestModelsAPISetActive:
    """Exercise ``PATCH /api/v1/models/{name}/active`` with a mocked registry."""

    _REGISTRY_TARGET = "dashboard.routes.models.model_registry"

    def test_enable_model(self, client):
        # set_active() succeeding should yield HTTP 200.
        with patch(self._REGISTRY_TARGET) as registry_double:
            registry_double.set_active.return_value = True
            response = client.patch(
                "/api/v1/models/my-model/active",
                json={"active": True},
            )
        assert response.status_code == 200

    def test_disable_nonexistent(self, client):
        # set_active() failing (unknown model) should yield HTTP 404.
        with patch(self._REGISTRY_TARGET) as registry_double:
            registry_double.set_active.return_value = False
            response = client.patch(
                "/api/v1/models/nonexistent/active",
                json={"active": False},
            )
        assert response.status_code == 404