This repository has been archived on 2026-03-24. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
Timmy-time-dashboard/tests/infrastructure/test_model_registry.py

231 lines
8.0 KiB
Python

"""Tests for the custom model registry."""
from unittest.mock import patch
import pytest
from infrastructure.models.registry import CustomModel, ModelFormat, ModelRegistry, ModelRole
@pytest.fixture
def registry(tmp_path):
    """Yield a fresh ModelRegistry backed by a temporary database file.

    The ``yield`` deliberately sits inside the ``patch`` block. This keeps
    ``infrastructure.models.registry.DB_PATH`` pointing at the temp file for
    the whole test, not just while the registry is constructed. Any
    ``DB_PATH`` lookup made by registry calls during the test therefore
    sees the temporary path instead of the real database.
    """
    db = tmp_path / "test.db"
    with patch("infrastructure.models.registry.DB_PATH", db):
        yield ModelRegistry()
@pytest.fixture
def sample_model():
    """Return an Ollama-format model in the GENERAL role, used across tests."""
    model = CustomModel(
        name="test-llama",
        path="llama3.2",
        format=ModelFormat.OLLAMA,
        role=ModelRole.GENERAL,
        description="Test model",
        context_window=8192,
    )
    return model
@pytest.fixture
def reward_model():
    """Return an Ollama-format model registered under the REWARD role."""
    model = CustomModel(
        name="test-reward",
        path="deepseek-r1:1.5b",
        format=ModelFormat.OLLAMA,
        role=ModelRole.REWARD,
        description="Test reward model",
        context_window=32000,
    )
    return model
class TestModelCRUD:
    """Test model registration, lookup, and removal."""

    def test_register_model(self, registry, sample_model):
        """register() hands back the stored model."""
        registered = registry.register(sample_model)
        assert registered.name == "test-llama"
        assert registered.format == ModelFormat.OLLAMA

    def test_get_model(self, registry, sample_model):
        """get() finds a registered model by name."""
        registry.register(sample_model)
        found = registry.get("test-llama")
        assert found is not None
        assert found.name == "test-llama"
        assert found.path == "llama3.2"

    def test_get_nonexistent_model(self, registry):
        """get() returns None for an unknown name."""
        assert registry.get("nonexistent") is None

    def test_list_models(self, registry, sample_model, reward_model):
        """list_models() with no filter returns every registered model."""
        registry.register(sample_model)
        registry.register(reward_model)
        all_models = registry.list_models()
        assert len(all_models) == 2
        # A count alone would pass if one model came back twice; check identities too.
        assert {m.name for m in all_models} == {"test-llama", "test-reward"}

    def test_list_models_by_role(self, registry, sample_model, reward_model):
        """list_models(role=...) only returns models with that role."""
        registry.register(sample_model)
        registry.register(reward_model)
        general = registry.list_models(role=ModelRole.GENERAL)
        assert len(general) == 1
        assert general[0].name == "test-llama"
        rewards = registry.list_models(role=ModelRole.REWARD)
        assert len(rewards) == 1
        assert rewards[0].name == "test-reward"

    def test_unregister_model(self, registry, sample_model):
        """unregister() reports success and the model is no longer retrievable."""
        registry.register(sample_model)
        assert registry.unregister("test-llama") is True
        assert registry.get("test-llama") is None

    def test_unregister_nonexistent(self, registry):
        """unregister() reports False for an unknown name."""
        assert registry.unregister("nonexistent") is False

    def test_set_active(self, registry, sample_model):
        """set_active() toggles the stored ``active`` flag in both directions."""
        registry.register(sample_model)
        assert registry.set_active("test-llama", False) is True
        model = registry.get("test-llama")
        # Guard so a missing model fails here rather than as an AttributeError.
        assert model is not None
        assert model.active is False
        assert registry.set_active("test-llama", True) is True
        model = registry.get("test-llama")
        assert model is not None
        assert model.active is True

    def test_set_active_nonexistent(self, registry):
        """set_active() reports False for an unknown name."""
        assert registry.set_active("nonexistent", True) is False

    def test_register_replaces_existing(self, registry, sample_model):
        """Registering an existing name overwrites it instead of adding a second entry."""
        registry.register(sample_model)
        updated = CustomModel(
            name="test-llama",
            format=ModelFormat.GGUF,
            path="/new/path.gguf",
            role=ModelRole.GENERAL,
            description="Updated model",
        )
        registry.register(updated)
        found = registry.get("test-llama")
        assert found is not None
        assert found.format == ModelFormat.GGUF
        assert found.path == "/new/path.gguf"
        assert found.description == "Updated model"
        # get() alone can't detect a duplicate row; the registry must still hold one model.
        assert len(registry.list_models()) == 1
class TestAgentAssignments:
    """Exercise how agents are bound to, and unbound from, registered models."""

    def test_assign_model(self, registry, sample_model):
        registry.register(sample_model)
        ok = registry.assign_model("agent-1", "test-llama")
        assert ok is True
        assigned = registry.get_agent_model("agent-1")
        assert assigned is not None
        assert assigned.name == "test-llama"

    def test_assign_nonexistent_model(self, registry):
        result = registry.assign_model("agent-1", "nonexistent")
        assert result is False

    def test_unassign_model(self, registry, sample_model):
        registry.register(sample_model)
        registry.assign_model("agent-1", "test-llama")
        removed = registry.unassign_model("agent-1")
        assert removed is True
        assert registry.get_agent_model("agent-1") is None

    def test_unassign_nonexistent(self, registry):
        removed = registry.unassign_model("agent-1")
        assert removed is False

    def test_get_agent_model_none(self, registry):
        model = registry.get_agent_model("agent-1")
        assert model is None

    def test_get_all_assignments(self, registry, sample_model, reward_model):
        for fixture_model in (sample_model, reward_model):
            registry.register(fixture_model)
        expected = {"agent-1": "test-llama", "agent-2": "test-reward"}
        for agent_id, model_name in expected.items():
            registry.assign_model(agent_id, model_name)
        assignments = registry.get_agent_assignments()
        assert len(assignments) == len(expected)
        for agent_id, model_name in expected.items():
            assert assignments[agent_id] == model_name

    def test_unregister_removes_assignments(self, registry, sample_model):
        registry.register(sample_model)
        registry.assign_model("agent-1", "test-llama")
        registry.unregister("test-llama")
        # Unregistering the model must also clear the agent that was bound to it.
        assert registry.get_agent_model("agent-1") is None
        remaining = registry.get_agent_assignments()
        assert len(remaining) == 0
class TestRoleLookups:
    """Check the role-specific accessors (reward / teacher) on the registry."""

    def test_get_reward_model(self, registry, reward_model):
        registry.register(reward_model)
        found = registry.get_reward_model()
        assert found is not None
        assert (found.name, found.role) == ("test-reward", ModelRole.REWARD)

    def test_get_reward_model_none(self, registry):
        found = registry.get_reward_model()
        assert found is None

    def test_get_teacher_model(self, registry):
        registry.register(
            CustomModel(
                name="teacher-model",
                format=ModelFormat.OLLAMA,
                path="teacher:latest",
                role=ModelRole.TEACHER,
            )
        )
        found = registry.get_teacher_model()
        assert found is not None
        assert found.name == "teacher-model"

    def test_get_teacher_model_none(self, registry):
        found = registry.get_teacher_model()
        assert found is None

    def test_inactive_reward_model_not_returned(self, registry, reward_model):
        registry.register(reward_model)
        registry.set_active("test-reward", False)
        # Once deactivated, the reward model must not be returned by the role lookup.
        found = registry.get_reward_model()
        assert found is None
class TestCustomModelDataclass:
    """Check CustomModel defaults and the string values behind its enums."""

    def test_default_registered_at(self):
        model = CustomModel(
            name="test",
            format=ModelFormat.OLLAMA,
            path="test",
        )
        # registered_at is not passed explicitly, so a default must fill it in.
        assert model.registered_at != ""

    def test_model_roles(self):
        expected = [
            (ModelRole.GENERAL, "general"),
            (ModelRole.REWARD, "reward"),
            (ModelRole.TEACHER, "teacher"),
            (ModelRole.JUDGE, "judge"),
        ]
        for role, value in expected:
            assert role.value == value

    def test_model_formats(self):
        expected = [
            (ModelFormat.GGUF, "gguf"),
            (ModelFormat.SAFETENSORS, "safetensors"),
            (ModelFormat.HF_CHECKPOINT, "hf"),
            (ModelFormat.OLLAMA, "ollama"),
        ]
        for fmt, value in expected:
            assert fmt.value == value
class TestWALMode:
    """Verify WAL mode is enabled for the model registry database."""

    @staticmethod
    def _read_pragma(db_path, pragma):
        """Return the first column of ``PRAGMA <pragma>`` on a registry connection.

        ``DB_PATH`` is patched around the ``_get_conn()`` call, so the
        connection is opened against *db_path* rather than the real database.

        Args:
            db_path: Path of the temporary database file to connect to.
            pragma: Name of the SQLite pragma to query.

        Returns:
            The first column of the pragma's single result row.
        """
        from infrastructure.models.registry import _get_conn

        with patch("infrastructure.models.registry.DB_PATH", db_path):
            with _get_conn() as conn:
                # PRAGMA names cannot be bound as SQL parameters; *pragma* is
                # always a literal supplied by the tests in this class.
                return conn.execute(f"PRAGMA {pragma}").fetchone()[0]

    def test_registry_db_uses_wal(self, tmp_path):
        mode = self._read_pragma(tmp_path / "wal_test.db", "journal_mode")
        assert mode == "wal"

    def test_registry_db_busy_timeout(self, tmp_path):
        timeout = self._read_pragma(tmp_path / "wal_test.db", "busy_timeout")
        assert timeout == 5000