forked from Rockachopa/Timmy-time-dashboard
510 lines
19 KiB
Python
510 lines
19 KiB
Python
"""Tests for infrastructure.models.multimodal — multi-modal model management."""
|
|
|
|
import json
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from infrastructure.models.multimodal import (
|
|
DEFAULT_FALLBACK_CHAINS,
|
|
KNOWN_MODEL_CAPABILITIES,
|
|
ModelCapability,
|
|
ModelInfo,
|
|
MultiModalManager,
|
|
get_model_for_capability,
|
|
model_supports_tools,
|
|
model_supports_vision,
|
|
pull_model_with_fallback,
|
|
)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# ModelCapability enum
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestModelCapability:
    """Sanity checks on the ModelCapability enum."""

    def test_members_exist(self):
        # Look members up by name instead of asserting their truthiness: a
        # member whose value is falsy (e.g. 0 in an IntEnum/Flag) would make a
        # bare ``assert ModelCapability.X`` fail even though the member exists.
        expected = {"TEXT", "VISION", "AUDIO", "TOOLS", "JSON", "STREAMING"}
        missing = expected - set(ModelCapability.__members__)
        assert not missing, f"missing capabilities: {sorted(missing)}"

    def test_all_members_unique(self):
        # Iterating an Enum skips aliases (names that reuse an existing value),
        # so comparing ``[m.value for m in ModelCapability]`` against its set
        # can never fail. ``__members__`` does include aliases; an alias is a
        # name whose canonical member carries a different name.
        aliases = [
            alias
            for alias, member in ModelCapability.__members__.items()
            if member.name != alias
        ]
        assert not aliases, f"duplicate capability values (aliases): {aliases}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# ModelInfo dataclass
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestModelInfo:
    """ModelInfo defaults and its ``supports`` predicate."""

    def test_defaults(self):
        model = ModelInfo(name="test-model")
        # Only ``name`` is supplied; every other field takes its default.
        assert model.name == "test-model"
        assert model.capabilities == set()
        assert model.is_available is False
        assert model.is_pulled is False
        assert model.size_mb is None
        assert model.description == ""

    def test_supports_true(self):
        granted = {ModelCapability.TEXT, ModelCapability.VISION}
        model = ModelInfo(name="m", capabilities=granted)
        for capability in (ModelCapability.TEXT, ModelCapability.VISION):
            assert model.supports(capability) is True

    def test_supports_false(self):
        model = ModelInfo(name="m", capabilities={ModelCapability.TEXT})
        assert model.supports(ModelCapability.VISION) is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Known model capabilities lookup table
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestKnownModelCapabilities:
    """Consistency checks on the static KNOWN_MODEL_CAPABILITIES table."""

    def test_vision_models_have_vision(self):
        for family in ("llama3.2-vision", "llava", "moondream", "qwen2.5-vl"):
            caps = KNOWN_MODEL_CAPABILITIES[family]
            assert ModelCapability.VISION in caps, family

    def test_text_models_lack_vision(self):
        for family in ("deepseek-r1", "gemma2", "phi3"):
            caps = KNOWN_MODEL_CAPABILITIES[family]
            assert ModelCapability.VISION not in caps, family

    def test_all_models_have_text(self):
        for model_name, caps in KNOWN_MODEL_CAPABILITIES.items():
            assert ModelCapability.TEXT in caps, f"{model_name} should have TEXT"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Default fallback chains
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestDefaultFallbackChains:
    """Shape of the built-in per-capability fallback chains."""

    def test_vision_chain_non_empty(self):
        chain = DEFAULT_FALLBACK_CHAINS[ModelCapability.VISION]
        assert len(chain) > 0

    def test_tools_chain_non_empty(self):
        chain = DEFAULT_FALLBACK_CHAINS[ModelCapability.TOOLS]
        assert len(chain) > 0

    def test_audio_chain_empty(self):
        # The default AUDIO chain is expected to be an empty list.
        chain = DEFAULT_FALLBACK_CHAINS[ModelCapability.AUDIO]
        assert chain == []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers to build a manager without hitting the network
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _fake_ollama_tags(*model_names: str) -> bytes:
    """Return a JSON body shaped like Ollama's ``/api/tags`` response.

    Every listed model is reported with a size of 4 GiB and
    ``details.family == "test"``; the result is UTF-8 encoded bytes.
    """
    four_gib = 4 * 1024 * 1024 * 1024
    payload = {
        "models": [
            {"name": model_name, "size": four_gib, "details": {"family": "test"}}
            for model_name in model_names
        ]
    }
    return json.dumps(payload).encode()
|
|
|
|
|
|
def _make_manager(model_names: list[str] | None = None) -> MultiModalManager:
    """Build a MultiModalManager while ``urllib.request.urlopen`` is mocked.

    Passing ``None`` simulates an unreachable Ollama (``urlopen`` raises
    ``ConnectionError``). Passing a list, even an empty one, serves that list
    as the ``/api/tags`` response body.
    """
    if model_names is not None:
        fake_response = MagicMock()
        fake_response.__enter__ = MagicMock(return_value=fake_response)
        fake_response.__exit__ = MagicMock(return_value=False)
        fake_response.read.return_value = _fake_ollama_tags(*model_names)
        fake_response.status = 200
        urlopen_patch = patch("urllib.request.urlopen", return_value=fake_response)
    else:
        urlopen_patch = patch(
            "urllib.request.urlopen", side_effect=ConnectionError("no ollama")
        )

    with urlopen_patch:
        return MultiModalManager(ollama_url="http://localhost:11434")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# MultiModalManager — init & refresh
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestMultiModalManagerInit:
    """Model discovery at construction time and via ``refresh``."""

    def test_init_no_ollama(self):
        manager = _make_manager(None)
        assert manager.list_available_models() == []

    def test_init_with_models(self):
        manager = _make_manager(["llama3.1:8b", "llava:7b"])
        discovered = {info.name for info in manager.list_available_models()}
        assert discovered == {"llama3.1:8b", "llava:7b"}

    def test_refresh_updates_models(self):
        manager = _make_manager([])
        assert manager.list_available_models() == []

        # The next /api/tags response lists a single model.
        tags = MagicMock()
        tags.__enter__ = MagicMock(return_value=tags)
        tags.__exit__ = MagicMock(return_value=False)
        tags.read.return_value = _fake_ollama_tags("gemma2:9b")
        tags.status = 200

        with patch("urllib.request.urlopen", return_value=tags):
            manager.refresh()

        discovered = {info.name for info in manager.list_available_models()}
        assert "gemma2:9b" in discovered
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _detect_capabilities
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestDetectCapabilities:
    """Name-based capability inference in ``_detect_capabilities``."""

    def test_exact_match(self):
        detected = _make_manager(None)._detect_capabilities("llava:7b")
        assert ModelCapability.VISION in detected

    def test_base_name_match(self):
        # "llava:99b" is not listed verbatim; the base name "llava" is.
        detected = _make_manager(None)._detect_capabilities("llava:99b")
        assert ModelCapability.VISION in detected

    def test_unknown_model_defaults_to_text(self):
        detected = _make_manager(None)._detect_capabilities("totally-unknown-model:1b")
        assert detected == {ModelCapability.TEXT, ModelCapability.STREAMING}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# get_model_capabilities / model_supports
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestGetModelCapabilities:
    """``get_model_capabilities`` for available and unavailable models."""

    def test_available_model(self):
        manager = _make_manager(["llava:7b"])
        reported = manager.get_model_capabilities("llava:7b")
        assert ModelCapability.VISION in reported

    def test_unavailable_model_uses_detection(self):
        # No models are available, so the name alone must yield VISION.
        manager = _make_manager([])
        reported = manager.get_model_capabilities("llava:7b")
        assert ModelCapability.VISION in reported
|
|
|
|
|
|
class TestModelSupports:
    """Boolean ``model_supports`` checks on available models."""

    def test_supports_true(self):
        manager = _make_manager(["llava:7b"])
        supported = manager.model_supports("llava:7b", ModelCapability.VISION)
        assert supported is True

    def test_supports_false(self):
        manager = _make_manager(["deepseek-r1:7b"])
        supported = manager.model_supports("deepseek-r1:7b", ModelCapability.VISION)
        assert supported is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# get_models_with_capability
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestGetModelsWithCapability:
    """Filtering available models by a required capability."""

    def test_returns_vision_models(self):
        manager = _make_manager(["llava:7b", "deepseek-r1:7b"])
        matches = manager.get_models_with_capability(ModelCapability.VISION)
        matched_names = {info.name for info in matches}
        assert "llava:7b" in matched_names
        assert "deepseek-r1:7b" not in matched_names

    def test_empty_when_none_available(self):
        manager = _make_manager(["deepseek-r1:7b"])
        matches = manager.get_models_with_capability(ModelCapability.VISION)
        assert matches == []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# get_best_model_for
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestGetBestModelFor:
    """Model selection: preferred model, fallback chain, any capable model, or None."""

    def test_preferred_model_with_capability(self):
        manager = _make_manager(["llava:7b", "llama3.1:8b"])
        chosen = manager.get_best_model_for(ModelCapability.VISION, preferred_model="llava:7b")
        assert chosen == "llava:7b"

    def test_preferred_model_without_capability_uses_fallback(self):
        manager = _make_manager(["deepseek-r1:7b", "llava:7b"])
        # The preferred model lacks VISION; the fallback chain supplies llava:7b.
        chosen = manager.get_best_model_for(
            ModelCapability.VISION, preferred_model="deepseek-r1:7b"
        )
        assert chosen == "llava:7b"

    def test_fallback_chain_order(self):
        # Expected to come first in the VISION fallback chain: llama3.2:3b.
        manager = _make_manager(["llama3.2:3b", "llava:7b"])
        assert manager.get_best_model_for(ModelCapability.VISION) == "llama3.2:3b"

    def test_any_capable_model_when_no_fallback(self):
        manager = _make_manager(["moondream:1.8b"])
        manager._fallback_chains[ModelCapability.VISION] = []  # empty the chain
        assert manager.get_best_model_for(ModelCapability.VISION) == "moondream:1.8b"

    def test_none_when_no_capable_model(self):
        manager = _make_manager(["deepseek-r1:7b"])
        assert manager.get_best_model_for(ModelCapability.VISION) is None

    def test_preferred_model_not_available_skipped(self):
        manager = _make_manager(["llava:7b"])
        # "llava:13b" is not among the available models.
        chosen = manager.get_best_model_for(ModelCapability.VISION, preferred_model="llava:13b")
        assert chosen == "llava:7b"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# pull_model_with_fallback (manager method)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPullModelWithFallback:
    """Resolution order of ``MultiModalManager.pull_model_with_fallback``."""

    def test_already_available(self):
        mgr = _make_manager(["llama3.1:8b"])
        model, is_fallback = mgr.pull_model_with_fallback("llama3.1:8b")
        assert model == "llama3.1:8b"
        assert is_fallback is False

    def test_pull_succeeds(self):
        mgr = _make_manager([])

        pull_resp = MagicMock()
        pull_resp.__enter__ = MagicMock(return_value=pull_resp)
        pull_resp.__exit__ = MagicMock(return_value=False)
        pull_resp.status = 200

        # After the pull, the tags refresh reports the model.
        refresh_resp = MagicMock()
        refresh_resp.__enter__ = MagicMock(return_value=refresh_resp)
        refresh_resp.__exit__ = MagicMock(return_value=False)
        refresh_resp.read.return_value = _fake_ollama_tags("llama3.1:8b")
        refresh_resp.status = 200

        with patch("urllib.request.urlopen", side_effect=[pull_resp, refresh_resp]):
            model, is_fallback = mgr.pull_model_with_fallback("llama3.1:8b")
        assert model == "llama3.1:8b"
        assert is_fallback is False

    def test_pull_fails_uses_capability_fallback(self):
        mgr = _make_manager(["llava:7b"])
        with patch("urllib.request.urlopen", side_effect=ConnectionError("fail")):
            model, is_fallback = mgr.pull_model_with_fallback(
                "nonexistent-vision:1b",
                capability=ModelCapability.VISION,
            )
        assert model == "llava:7b"
        assert is_fallback is True

    def test_pull_fails_uses_default_model(self):
        # Plain assignment instead of a walrus buried in the list literal.
        default_model = "llama3.1:8b"
        mgr = _make_manager([default_model])
        with (
            patch("urllib.request.urlopen", side_effect=ConnectionError("fail")),
            patch("infrastructure.models.multimodal.settings") as mock_settings,
        ):
            mock_settings.ollama_model = default_model
            mock_settings.ollama_url = "http://localhost:11434"
            model, is_fallback = mgr.pull_model_with_fallback("missing-model:99b")
        assert model == default_model
        assert is_fallback is True

    def test_auto_pull_false_skips_pull(self):
        mgr = _make_manager([])
        # Block the network: without this patch a stray pull could reach a real
        # local Ollama. Asserting no urlopen call is what actually proves the
        # pull was skipped; the return value alone would be the same even if a
        # pull were attempted and failed.
        with (
            patch(
                "urllib.request.urlopen", side_effect=ConnectionError("fail")
            ) as mock_urlopen,
            patch("infrastructure.models.multimodal.settings") as mock_settings,
        ):
            mock_settings.ollama_model = "default"
            model, is_fallback = mgr.pull_model_with_fallback("missing:1b", auto_pull=False)
        mock_urlopen.assert_not_called()
        # Falls through to the absolute last resort.
        assert model == "missing:1b"
        assert is_fallback is False

    def test_absolute_last_resort(self):
        mgr = _make_manager([])
        with (
            patch("urllib.request.urlopen", side_effect=ConnectionError("fail")),
            patch("infrastructure.models.multimodal.settings") as mock_settings,
        ):
            mock_settings.ollama_model = "not-available"
            model, is_fallback = mgr.pull_model_with_fallback("primary:1b")
        assert model == "primary:1b"
        assert is_fallback is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _pull_model
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPullModel:
    """Success and failure paths of the low-level ``_pull_model``."""

    def test_pull_success(self):
        mgr = _make_manager([])

        def _fake_response():
            # A context-manager-compatible mock with HTTP status 200.
            response = MagicMock()
            response.__enter__ = MagicMock(return_value=response)
            response.__exit__ = MagicMock(return_value=False)
            response.status = 200
            return response

        pull_resp = _fake_response()
        refresh_resp = _fake_response()
        refresh_resp.read.return_value = _fake_ollama_tags("new-model:1b")

        # Responses are served in order: the pull request, then the tags refresh.
        with patch("urllib.request.urlopen", side_effect=[pull_resp, refresh_resp]):
            assert mgr._pull_model("new-model:1b") is True

    def test_pull_network_error(self):
        mgr = _make_manager([])
        offline = patch("urllib.request.urlopen", side_effect=ConnectionError("offline"))
        with offline:
            assert mgr._pull_model("any-model:1b") is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# configure_fallback_chain / get_fallback_chain
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestFallbackChainConfig:
    """Setting and reading per-capability fallback chains."""

    def test_configure_and_get(self):
        manager = _make_manager(None)
        manager.configure_fallback_chain(ModelCapability.VISION, ["model-a", "model-b"])
        configured = manager.get_fallback_chain(ModelCapability.VISION)
        assert configured == ["model-a", "model-b"]

    def test_get_returns_copy(self):
        manager = _make_manager(None)
        snapshot = manager.get_fallback_chain(ModelCapability.VISION)
        snapshot.append("mutated")
        # Mutating the returned list must not affect the manager's own chain.
        assert "mutated" not in manager.get_fallback_chain(ModelCapability.VISION)

    def test_get_empty_for_unknown(self):
        manager = _make_manager(None)
        # The default AUDIO chain is empty.
        assert manager.get_fallback_chain(ModelCapability.AUDIO) == []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# get_model_for_content
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestGetModelForContent:
    """Mapping content-type strings to a model via ``get_model_for_content``."""

    def test_image_content(self):
        manager = _make_manager(["llava:7b"])
        model, _is_fallback = manager.get_model_for_content("image")
        assert model == "llava:7b"

    def test_vision_content(self):
        manager = _make_manager(["llava:7b"])
        model = manager.get_model_for_content("vision")[0]
        assert model == "llava:7b"

    def test_multimodal_content(self):
        manager = _make_manager(["llava:7b"])
        model = manager.get_model_for_content("multimodal")[0]
        assert model == "llava:7b"

    def test_audio_content(self):
        manager = _make_manager(["llama3.1:8b"])
        with patch("infrastructure.models.multimodal.settings") as fake_settings:
            fake_settings.ollama_model = "llama3.1:8b"
            fake_settings.ollama_url = "http://localhost:11434"
            model = manager.get_model_for_content("audio")[0]
        assert model == "llama3.1:8b"

    def test_text_content(self):
        manager = _make_manager(["llama3.1:8b"])
        with patch("infrastructure.models.multimodal.settings") as fake_settings:
            fake_settings.ollama_model = "llama3.1:8b"
            fake_settings.ollama_url = "http://localhost:11434"
            model = manager.get_model_for_content("text")[0]
        assert model == "llama3.1:8b"

    def test_preferred_model_respected(self):
        manager = _make_manager(["llama3.2:3b", "llava:7b"])
        model = manager.get_model_for_content("image", preferred_model="llama3.2:3b")[0]
        assert model == "llama3.2:3b"

    def test_case_insensitive(self):
        # Upper-case content type must resolve like its lower-case form.
        manager = _make_manager(["llava:7b"])
        model = manager.get_model_for_content("IMAGE")[0]
        assert model == "llava:7b"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Module-level convenience functions
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestConvenienceFunctions:
    """Module-level helpers that delegate to the module's manager instance."""

    def _patch_manager(self, mgr):
        # Replace the module-level ``_multimodal_manager`` with ``mgr``.
        target = "infrastructure.models.multimodal._multimodal_manager"
        return patch(target, mgr)

    def test_get_model_for_capability(self):
        with self._patch_manager(_make_manager(["llava:7b"])):
            chosen = get_model_for_capability(ModelCapability.VISION)
        assert chosen == "llava:7b"

    def test_pull_model_with_fallback_convenience(self):
        with self._patch_manager(_make_manager(["llama3.1:8b"])):
            model, used_fallback = pull_model_with_fallback("llama3.1:8b")
        assert model == "llama3.1:8b"
        assert used_fallback is False

    def test_model_supports_vision_true(self):
        with self._patch_manager(_make_manager(["llava:7b"])):
            assert model_supports_vision("llava:7b") is True

    def test_model_supports_vision_false(self):
        with self._patch_manager(_make_manager(["llama3.1:8b"])):
            assert model_supports_vision("llama3.1:8b") is False

    def test_model_supports_tools_true(self):
        with self._patch_manager(_make_manager(["llama3.1:8b"])):
            assert model_supports_tools("llama3.1:8b") is True

    def test_model_supports_tools_false(self):
        with self._patch_manager(_make_manager(["deepseek-r1:7b"])):
            assert model_supports_tools("deepseek-r1:7b") is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# ModelInfo in available_models — size_mb and description populated
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestModelInfoPopulation:
    """ModelInfo fields populated from the mocked /api/tags listing."""

    def test_size_and_description(self):
        info = _make_manager(["llama3.1:8b"])._available_models["llama3.1:8b"]
        expected_size_mb = 4 * 1024  # _fake_ollama_tags reports 4 GiB = 4096 MiB
        assert info.is_available is True
        assert info.is_pulled is True
        assert info.size_mb == expected_size_mb
        # Matches the details.family value served by _fake_ollama_tags.
        assert info.description == "test"
|