"""Tests for infrastructure.models.multimodal — multi-modal model management."""

import json
from unittest.mock import MagicMock, patch

from infrastructure.models.multimodal import (
    DEFAULT_FALLBACK_CHAINS,
    KNOWN_MODEL_CAPABILITIES,
    ModelCapability,
    ModelInfo,
    MultiModalManager,
    get_model_for_capability,
    model_supports_tools,
    model_supports_vision,
    pull_model_with_fallback,
)

# ---------------------------------------------------------------------------
# ModelCapability enum
# ---------------------------------------------------------------------------


class TestModelCapability:
    def test_members_exist(self):
        assert ModelCapability.TEXT
        assert ModelCapability.VISION
        assert ModelCapability.AUDIO
        assert ModelCapability.TOOLS
        assert ModelCapability.JSON
        assert ModelCapability.STREAMING

    def test_all_members_unique(self):
        values = [m.value for m in ModelCapability]
        assert len(values) == len(set(values))


# ---------------------------------------------------------------------------
# ModelInfo dataclass
# ---------------------------------------------------------------------------


class TestModelInfo:
    def test_defaults(self):
        info = ModelInfo(name="test-model")
        assert info.name == "test-model"
        assert info.capabilities == set()
        assert info.is_available is False
        assert info.is_pulled is False
        assert info.size_mb is None
        assert info.description == ""

    def test_supports_true(self):
        info = ModelInfo(name="m", capabilities={ModelCapability.TEXT, ModelCapability.VISION})
        assert info.supports(ModelCapability.TEXT) is True
        assert info.supports(ModelCapability.VISION) is True

    def test_supports_false(self):
        info = ModelInfo(name="m", capabilities={ModelCapability.TEXT})
        assert info.supports(ModelCapability.VISION) is False


# ---------------------------------------------------------------------------
# Known model capabilities lookup table
# ---------------------------------------------------------------------------


class TestKnownModelCapabilities:
    def test_vision_models_have_vision(self):
        vision_names = [
            "llama3.2-vision",
            "llava",
            "moondream",
            "qwen2.5-vl",
        ]
        for name in vision_names:
            assert ModelCapability.VISION in KNOWN_MODEL_CAPABILITIES[name], name

    def test_text_models_lack_vision(self):
        text_only = ["deepseek-r1", "gemma2", "phi3"]
        for name in text_only:
            assert ModelCapability.VISION not in KNOWN_MODEL_CAPABILITIES[name], name

    def test_all_models_have_text(self):
        for name, caps in KNOWN_MODEL_CAPABILITIES.items():
            assert ModelCapability.TEXT in caps, f"{name} should have TEXT"


# ---------------------------------------------------------------------------
# Default fallback chains
# ---------------------------------------------------------------------------


class TestDefaultFallbackChains:
    def test_vision_chain_non_empty(self):
        assert len(DEFAULT_FALLBACK_CHAINS[ModelCapability.VISION]) > 0

    def test_tools_chain_non_empty(self):
        assert len(DEFAULT_FALLBACK_CHAINS[ModelCapability.TOOLS]) > 0

    def test_audio_chain_empty(self):
        assert DEFAULT_FALLBACK_CHAINS[ModelCapability.AUDIO] == []


# ---------------------------------------------------------------------------
# Helpers to build a manager without hitting the network
# ---------------------------------------------------------------------------


def _fake_ollama_tags(*model_names: str) -> bytes:
    """Build a JSON response mimicking Ollama /api/tags."""
    models = []
    for name in model_names:
        models.append({"name": name, "size": 4 * 1024 * 1024 * 1024, "details": {"family": "test"}})
    return json.dumps({"models": models}).encode()


def _mock_http_response(payload: bytes | None = None) -> MagicMock:
    """Build a context-manager mock mimicking urllib's HTTP response object.

    `payload`, when given, becomes the value returned by `.read()`.
    """
    resp = MagicMock()
    resp.__enter__ = MagicMock(return_value=resp)
    resp.__exit__ = MagicMock(return_value=False)
    resp.status = 200
    if payload is not None:
        resp.read.return_value = payload
    return resp


def _make_manager(model_names: list[str] | None = None) -> MultiModalManager:
    """Create a MultiModalManager with mocked Ollama responses."""
    if model_names is None:
        # No models available — Ollama unreachable
        with patch("urllib.request.urlopen", side_effect=ConnectionError("no ollama")):
            return MultiModalManager(ollama_url="http://localhost:11434")

    resp = _mock_http_response(_fake_ollama_tags(*model_names))

    with patch("urllib.request.urlopen", return_value=resp):
        return MultiModalManager(ollama_url="http://localhost:11434")


# ---------------------------------------------------------------------------
# MultiModalManager — init & refresh
# ---------------------------------------------------------------------------


class TestMultiModalManagerInit:
    def test_init_no_ollama(self):
        mgr = _make_manager(None)
        assert mgr.list_available_models() == []

    def test_init_with_models(self):
        mgr = _make_manager(["llama3.1:8b", "llava:7b"])
        names = {m.name for m in mgr.list_available_models()}
        assert names == {"llama3.1:8b", "llava:7b"}

    def test_refresh_updates_models(self):
        mgr = _make_manager([])
        assert mgr.list_available_models() == []

        resp = _mock_http_response(_fake_ollama_tags("gemma2:9b"))

        with patch("urllib.request.urlopen", return_value=resp):
            mgr.refresh()

        names = {m.name for m in mgr.list_available_models()}
        assert "gemma2:9b" in names


# ---------------------------------------------------------------------------
# _detect_capabilities
# ---------------------------------------------------------------------------


class TestDetectCapabilities:
    def test_exact_match(self):
        mgr = _make_manager(None)
        caps = mgr._detect_capabilities("llava:7b")
        assert ModelCapability.VISION in caps

    def test_base_name_match(self):
        mgr = _make_manager(None)
        caps = mgr._detect_capabilities("llava:99b")
        # "llava:99b" not in table, but "llava" is
        assert ModelCapability.VISION in caps

    def test_unknown_model_defaults_to_text(self):
        mgr = _make_manager(None)
        caps = mgr._detect_capabilities("totally-unknown-model:1b")
        assert caps == {ModelCapability.TEXT, ModelCapability.STREAMING}


# ---------------------------------------------------------------------------
# get_model_capabilities / model_supports
# ---------------------------------------------------------------------------


class TestGetModelCapabilities:
    def test_available_model(self):
        mgr = _make_manager(["llava:7b"])
        caps = mgr.get_model_capabilities("llava:7b")
        assert ModelCapability.VISION in caps

    def test_unavailable_model_uses_detection(self):
        mgr = _make_manager([])
        caps = mgr.get_model_capabilities("llava:7b")
        assert ModelCapability.VISION in caps


class TestModelSupports:
    def test_supports_true(self):
        mgr = _make_manager(["llava:7b"])
        assert mgr.model_supports("llava:7b", ModelCapability.VISION) is True

    def test_supports_false(self):
        mgr = _make_manager(["deepseek-r1:7b"])
        assert mgr.model_supports("deepseek-r1:7b", ModelCapability.VISION) is False


# ---------------------------------------------------------------------------
# get_models_with_capability
# ---------------------------------------------------------------------------


class TestGetModelsWithCapability:
    def test_returns_vision_models(self):
        mgr = _make_manager(["llava:7b", "deepseek-r1:7b"])
        vision = mgr.get_models_with_capability(ModelCapability.VISION)
        names = {m.name for m in vision}
        assert "llava:7b" in names
        assert "deepseek-r1:7b" not in names

    def test_empty_when_none_available(self):
        mgr = _make_manager(["deepseek-r1:7b"])
        vision = mgr.get_models_with_capability(ModelCapability.VISION)
        assert vision == []


# ---------------------------------------------------------------------------
# get_best_model_for
# ---------------------------------------------------------------------------


class TestGetBestModelFor:
    def test_preferred_model_with_capability(self):
        mgr = _make_manager(["llava:7b", "llama3.1:8b"])
        result = mgr.get_best_model_for(ModelCapability.VISION, preferred_model="llava:7b")
        assert result == "llava:7b"

    def test_preferred_model_without_capability_uses_fallback(self):
        mgr = _make_manager(["deepseek-r1:7b", "llava:7b"])
        # preferred doesn't have VISION, fallback chain has llava:7b
        result = mgr.get_best_model_for(ModelCapability.VISION, preferred_model="deepseek-r1:7b")
        assert result == "llava:7b"

    def test_fallback_chain_order(self):
        # First in chain: llama3.2:3b
        mgr = _make_manager(["llama3.2:3b", "llava:7b"])
        result = mgr.get_best_model_for(ModelCapability.VISION)
        assert result == "llama3.2:3b"

    def test_any_capable_model_when_no_fallback(self):
        mgr = _make_manager(["moondream:1.8b"])
        mgr._fallback_chains[ModelCapability.VISION] = []  # clear chain
        result = mgr.get_best_model_for(ModelCapability.VISION)
        assert result == "moondream:1.8b"

    def test_none_when_no_capable_model(self):
        mgr = _make_manager(["deepseek-r1:7b"])
        result = mgr.get_best_model_for(ModelCapability.VISION)
        assert result is None

    def test_preferred_model_not_available_skipped(self):
        mgr = _make_manager(["llava:7b"])
        # preferred_model "llava:13b" is not in available_models
        result = mgr.get_best_model_for(ModelCapability.VISION, preferred_model="llava:13b")
        assert result == "llava:7b"


# ---------------------------------------------------------------------------
# pull_model_with_fallback (manager method)
# ---------------------------------------------------------------------------


class TestPullModelWithFallback:
    def test_already_available(self):
        mgr = _make_manager(["llama3.1:8b"])
        model, is_fallback = mgr.pull_model_with_fallback("llama3.1:8b")
        assert model == "llama3.1:8b"
        assert is_fallback is False

    def test_pull_succeeds(self):
        mgr = _make_manager([])

        pull_resp = _mock_http_response()

        # After pull, refresh returns the model
        refresh_resp = _mock_http_response(_fake_ollama_tags("llama3.1:8b"))

        with patch("urllib.request.urlopen", side_effect=[pull_resp, refresh_resp]):
            model, is_fallback = mgr.pull_model_with_fallback("llama3.1:8b")
        assert model == "llama3.1:8b"
        assert is_fallback is False

    def test_pull_fails_uses_capability_fallback(self):
        mgr = _make_manager(["llava:7b"])
        with patch("urllib.request.urlopen", side_effect=ConnectionError("fail")):
            model, is_fallback = mgr.pull_model_with_fallback(
                "nonexistent-vision:1b",
                capability=ModelCapability.VISION,
            )
        assert model == "llava:7b"
        assert is_fallback is True

    def test_pull_fails_uses_default_model(self):
        mgr = _make_manager([settings_ollama_model := "llama3.1:8b"])
        with (
            patch("urllib.request.urlopen", side_effect=ConnectionError("fail")),
            patch("infrastructure.models.multimodal.settings") as mock_settings,
        ):
            mock_settings.ollama_model = settings_ollama_model
            mock_settings.ollama_url = "http://localhost:11434"
            model, is_fallback = mgr.pull_model_with_fallback("missing-model:99b")
        assert model == "llama3.1:8b"
        assert is_fallback is True

    def test_auto_pull_false_skips_pull(self):
        mgr = _make_manager([])
        with patch("infrastructure.models.multimodal.settings") as mock_settings:
            mock_settings.ollama_model = "default"
            model, is_fallback = mgr.pull_model_with_fallback("missing:1b", auto_pull=False)
        # Falls through to absolute last resort
        assert model == "missing:1b"
        assert is_fallback is False

    def test_absolute_last_resort(self):
        mgr = _make_manager([])
        with (
            patch("urllib.request.urlopen", side_effect=ConnectionError("fail")),
            patch("infrastructure.models.multimodal.settings") as mock_settings,
        ):
            mock_settings.ollama_model = "not-available"
            model, is_fallback = mgr.pull_model_with_fallback("primary:1b")
        assert model == "primary:1b"
        assert is_fallback is False


# ---------------------------------------------------------------------------
# _pull_model
# ---------------------------------------------------------------------------


class TestPullModel:
    def test_pull_success(self):
        mgr = _make_manager([])

        pull_resp = _mock_http_response()
        refresh_resp = _mock_http_response(_fake_ollama_tags("new-model:1b"))

        with patch("urllib.request.urlopen", side_effect=[pull_resp, refresh_resp]):
            assert mgr._pull_model("new-model:1b") is True

    def test_pull_network_error(self):
        mgr = _make_manager([])
        with patch("urllib.request.urlopen", side_effect=ConnectionError("offline")):
            assert mgr._pull_model("any-model:1b") is False


# ---------------------------------------------------------------------------
# configure_fallback_chain / get_fallback_chain
# ---------------------------------------------------------------------------


class TestFallbackChainConfig:
    def test_configure_and_get(self):
        mgr = _make_manager(None)
        mgr.configure_fallback_chain(ModelCapability.VISION, ["model-a", "model-b"])
        assert mgr.get_fallback_chain(ModelCapability.VISION) == ["model-a", "model-b"]

    def test_get_returns_copy(self):
        mgr = _make_manager(None)
        chain = mgr.get_fallback_chain(ModelCapability.VISION)
        chain.append("mutated")
        assert "mutated" not in mgr.get_fallback_chain(ModelCapability.VISION)

    def test_get_empty_for_unknown(self):
        mgr = _make_manager(None)
        # AUDIO has an empty chain by default
        assert mgr.get_fallback_chain(ModelCapability.AUDIO) == []


# ---------------------------------------------------------------------------
# get_model_for_content
# ---------------------------------------------------------------------------


class TestGetModelForContent:
    def test_image_content(self):
        mgr = _make_manager(["llava:7b"])
        model, is_fb = mgr.get_model_for_content("image")
        assert model == "llava:7b"

    def test_vision_content(self):
        mgr = _make_manager(["llava:7b"])
        model, _ = mgr.get_model_for_content("vision")
        assert model == "llava:7b"

    def test_multimodal_content(self):
        mgr = _make_manager(["llava:7b"])
        model, _ = mgr.get_model_for_content("multimodal")
        assert model == "llava:7b"

    def test_audio_content(self):
        mgr = _make_manager(["llama3.1:8b"])
        with patch("infrastructure.models.multimodal.settings") as mock_settings:
            mock_settings.ollama_model = "llama3.1:8b"
            mock_settings.ollama_url = "http://localhost:11434"
            model, _ = mgr.get_model_for_content("audio")
        assert model == "llama3.1:8b"

    def test_text_content(self):
        mgr = _make_manager(["llama3.1:8b"])
        with patch("infrastructure.models.multimodal.settings") as mock_settings:
            mock_settings.ollama_model = "llama3.1:8b"
            mock_settings.ollama_url = "http://localhost:11434"
            model, _ = mgr.get_model_for_content("text")
        assert model == "llama3.1:8b"

    def test_preferred_model_respected(self):
        mgr = _make_manager(["llama3.2:3b", "llava:7b"])
        model, _ = mgr.get_model_for_content("image", preferred_model="llama3.2:3b")
        assert model == "llama3.2:3b"

    def test_case_insensitive(self):
        mgr = _make_manager(["llava:7b"])
        model, _ = mgr.get_model_for_content("IMAGE")
        assert model == "llava:7b"


# ---------------------------------------------------------------------------
# Module-level convenience functions
# ---------------------------------------------------------------------------


class TestConvenienceFunctions:
    def _patch_manager(self, mgr):
        return patch(
            "infrastructure.models.multimodal._multimodal_manager",
            mgr,
        )

    def test_get_model_for_capability(self):
        mgr = _make_manager(["llava:7b"])
        with self._patch_manager(mgr):
            result = get_model_for_capability(ModelCapability.VISION)
        assert result == "llava:7b"

    def test_pull_model_with_fallback_convenience(self):
        mgr = _make_manager(["llama3.1:8b"])
        with self._patch_manager(mgr):
            model, is_fb = pull_model_with_fallback("llama3.1:8b")
        assert model == "llama3.1:8b"
        assert is_fb is False

    def test_model_supports_vision_true(self):
        mgr = _make_manager(["llava:7b"])
        with self._patch_manager(mgr):
            assert model_supports_vision("llava:7b") is True

    def test_model_supports_vision_false(self):
        mgr = _make_manager(["llama3.1:8b"])
        with self._patch_manager(mgr):
            assert model_supports_vision("llama3.1:8b") is False

    def test_model_supports_tools_true(self):
        mgr = _make_manager(["llama3.1:8b"])
        with self._patch_manager(mgr):
            assert model_supports_tools("llama3.1:8b") is True

    def test_model_supports_tools_false(self):
        mgr = _make_manager(["deepseek-r1:7b"])
        with self._patch_manager(mgr):
            assert model_supports_tools("deepseek-r1:7b") is False


# ---------------------------------------------------------------------------
# ModelInfo in available_models — size_mb and description populated
# ---------------------------------------------------------------------------


class TestModelInfoPopulation:
    def test_size_and_description(self):
        mgr = _make_manager(["llama3.1:8b"])
        info = mgr._available_models["llama3.1:8b"]
        assert info.is_available is True
        assert info.is_pulled is True
        assert info.size_mb == 4 * 1024  # 4 GiB in MiB
        assert info.description == "test"