"""Tests for infrastructure.models.multimodal — multi-modal model management."""

import json
from unittest.mock import MagicMock, patch

from infrastructure.models.multimodal import (
    DEFAULT_FALLBACK_CHAINS,
    KNOWN_MODEL_CAPABILITIES,
    ModelCapability,
    ModelInfo,
    MultiModalManager,
    get_model_for_capability,
    model_supports_tools,
    model_supports_vision,
    pull_model_with_fallback,
)

# Size reported for every fake model in the mocked tags payload (4 GiB).
_FAKE_MODEL_SIZE_BYTES = 4 * 1024 * 1024 * 1024

# ---------------------------------------------------------------------------
# ModelCapability enum
# ---------------------------------------------------------------------------


class TestModelCapability:
    def test_members_exist(self):
        # Check by name so the test does not depend on member truthiness.
        expected = {"TEXT", "VISION", "AUDIO", "TOOLS", "JSON", "STREAMING"}
        assert expected <= set(ModelCapability.__members__)

    def test_all_members_unique(self):
        values = [m.value for m in ModelCapability]
        assert len(values) == len(set(values))


# ---------------------------------------------------------------------------
# ModelInfo dataclass
# ---------------------------------------------------------------------------


class TestModelInfo:
    def test_defaults(self):
        info = ModelInfo(name="test-model")
        assert info.name == "test-model"
        assert info.capabilities == set()
        assert info.is_available is False
        assert info.is_pulled is False
        assert info.size_mb is None
        assert info.description == ""

    def test_supports_true(self):
        info = ModelInfo(name="m", capabilities={ModelCapability.TEXT, ModelCapability.VISION})
        assert info.supports(ModelCapability.TEXT) is True
        assert info.supports(ModelCapability.VISION) is True

    def test_supports_false(self):
        info = ModelInfo(name="m", capabilities={ModelCapability.TEXT})
        assert info.supports(ModelCapability.VISION) is False


# ---------------------------------------------------------------------------
# Known model capabilities lookup table
# ---------------------------------------------------------------------------


class TestKnownModelCapabilities:
    def test_vision_models_have_vision(self):
        vision_names = [
            "llama3.2-vision",
            "llava",
            "moondream",
            "qwen2.5-vl",
        ]
        for name in vision_names:
            assert ModelCapability.VISION in KNOWN_MODEL_CAPABILITIES[name], name

    def test_text_models_lack_vision(self):
        text_only = ["deepseek-r1", "gemma2", "phi3"]
        for name in text_only:
            assert ModelCapability.VISION not in KNOWN_MODEL_CAPABILITIES[name], name

    def test_all_models_have_text(self):
        for name, caps in KNOWN_MODEL_CAPABILITIES.items():
            assert ModelCapability.TEXT in caps, f"{name} should have TEXT"


# ---------------------------------------------------------------------------
# Default fallback chains
# ---------------------------------------------------------------------------


class TestDefaultFallbackChains:
    def test_vision_chain_non_empty(self):
        assert len(DEFAULT_FALLBACK_CHAINS[ModelCapability.VISION]) > 0

    def test_tools_chain_non_empty(self):
        assert len(DEFAULT_FALLBACK_CHAINS[ModelCapability.TOOLS]) > 0

    def test_audio_chain_empty(self):
        assert DEFAULT_FALLBACK_CHAINS[ModelCapability.AUDIO] == []


# ---------------------------------------------------------------------------
# Helpers to build a manager without hitting the network
# ---------------------------------------------------------------------------


def _fake_ollama_tags(*model_names: str) -> bytes:
    """Build a JSON response mimicking Ollama /api/tags."""
    models = [
        {"name": name, "size": _FAKE_MODEL_SIZE_BYTES, "details": {"family": "test"}}
        for name in model_names
    ]
    return json.dumps({"models": models}).encode()


def _mock_urlopen_response(body: bytes | None = None) -> MagicMock:
    """Return a MagicMock that behaves like the context manager ``urlopen`` returns.

    ``status`` is 200 and ``__enter__`` yields the mock itself. When *body* is
    given, ``read()`` returns it. Otherwise ``read`` is left as an
    unconfigured MagicMock attribute, matching how the fake pull responses
    in these tests were built.
    """
    resp = MagicMock()
    resp.__enter__ = MagicMock(return_value=resp)
    # Returning False from __exit__ means exceptions are not suppressed.
    resp.__exit__ = MagicMock(return_value=False)
    if body is not None:
        resp.read.return_value = body
    resp.status = 200
    return resp


def _make_manager(model_names: list[str] | None = None) -> MultiModalManager:
    """Create a MultiModalManager with mocked Ollama responses.

    ``None`` simulates an unreachable Ollama: every ``urlopen`` call raises
    ``ConnectionError`` during construction. A list (possibly empty) is
    returned as the fake tags listing for every ``urlopen`` call made during
    construction.
    """
    if model_names is None:
        # No models available — Ollama unreachable
        with patch("urllib.request.urlopen", side_effect=ConnectionError("no ollama")):
            return MultiModalManager(ollama_url="http://localhost:11434")

    resp = _mock_urlopen_response(_fake_ollama_tags(*model_names))
    with patch("urllib.request.urlopen", return_value=resp):
        return MultiModalManager(ollama_url="http://localhost:11434")


# ---------------------------------------------------------------------------
# MultiModalManager — init & refresh
# ---------------------------------------------------------------------------


class TestMultiModalManagerInit:
    def test_init_no_ollama(self):
        mgr = _make_manager(None)
        assert mgr.list_available_models() == []

    def test_init_with_models(self):
        mgr = _make_manager(["llama3.1:8b", "llava:7b"])
        names = {m.name for m in mgr.list_available_models()}
        assert names == {"llama3.1:8b", "llava:7b"}

    def test_refresh_updates_models(self):
        mgr = _make_manager([])
        assert mgr.list_available_models() == []

        resp = _mock_urlopen_response(_fake_ollama_tags("gemma2:9b"))
        with patch("urllib.request.urlopen", return_value=resp):
            mgr.refresh()

        names = {m.name for m in mgr.list_available_models()}
        assert "gemma2:9b" in names


# ---------------------------------------------------------------------------
# _detect_capabilities
# ---------------------------------------------------------------------------


class TestDetectCapabilities:
    def test_exact_match(self):
        mgr = _make_manager(None)
        caps = mgr._detect_capabilities("llava:7b")
        assert ModelCapability.VISION in caps

    def test_base_name_match(self):
        mgr = _make_manager(None)
        caps = mgr._detect_capabilities("llava:99b")
        # "llava:99b" not in table, but "llava" is
        assert ModelCapability.VISION in caps

    def test_unknown_model_defaults_to_text(self):
        mgr = _make_manager(None)
        caps = mgr._detect_capabilities("totally-unknown-model:1b")
        assert caps == {ModelCapability.TEXT, ModelCapability.STREAMING}


# ---------------------------------------------------------------------------
# get_model_capabilities / model_supports
# ---------------------------------------------------------------------------


class TestGetModelCapabilities:
    def test_available_model(self):
        mgr = _make_manager(["llava:7b"])
        caps = mgr.get_model_capabilities("llava:7b")
        assert ModelCapability.VISION in caps

    def test_unavailable_model_uses_detection(self):
        mgr = _make_manager([])
        caps = mgr.get_model_capabilities("llava:7b")
        assert ModelCapability.VISION in caps


class TestModelSupports:
    def test_supports_true(self):
        mgr = _make_manager(["llava:7b"])
        assert mgr.model_supports("llava:7b", ModelCapability.VISION) is True

    def test_supports_false(self):
        mgr = _make_manager(["deepseek-r1:7b"])
        assert mgr.model_supports("deepseek-r1:7b", ModelCapability.VISION) is False


# ---------------------------------------------------------------------------
# get_models_with_capability
# ---------------------------------------------------------------------------


class TestGetModelsWithCapability:
    def test_returns_vision_models(self):
        mgr = _make_manager(["llava:7b", "deepseek-r1:7b"])
        vision = mgr.get_models_with_capability(ModelCapability.VISION)
        names = {m.name for m in vision}
        assert "llava:7b" in names
        assert "deepseek-r1:7b" not in names

    def test_empty_when_none_available(self):
        mgr = _make_manager(["deepseek-r1:7b"])
        vision = mgr.get_models_with_capability(ModelCapability.VISION)
        assert vision == []


# ---------------------------------------------------------------------------
# get_best_model_for
# ---------------------------------------------------------------------------


class TestGetBestModelFor:
    def test_preferred_model_with_capability(self):
        mgr = _make_manager(["llava:7b", "llama3.1:8b"])
        result = mgr.get_best_model_for(ModelCapability.VISION, preferred_model="llava:7b")
        assert result == "llava:7b"

    def test_preferred_model_without_capability_uses_fallback(self):
        mgr = _make_manager(["deepseek-r1:7b", "llava:7b"])
        # preferred doesn't have VISION, fallback chain has llava:7b
        result = mgr.get_best_model_for(ModelCapability.VISION, preferred_model="deepseek-r1:7b")
        assert result == "llava:7b"

    def test_fallback_chain_order(self):
        # First in chain: llama3.2:3b
        mgr = _make_manager(["llama3.2:3b", "llava:7b"])
        result = mgr.get_best_model_for(ModelCapability.VISION)
        assert result == "llama3.2:3b"

    def test_any_capable_model_when_no_fallback(self):
        mgr = _make_manager(["moondream:1.8b"])
        mgr._fallback_chains[ModelCapability.VISION] = []  # clear chain
        result = mgr.get_best_model_for(ModelCapability.VISION)
        assert result == "moondream:1.8b"

    def test_none_when_no_capable_model(self):
        mgr = _make_manager(["deepseek-r1:7b"])
        result = mgr.get_best_model_for(ModelCapability.VISION)
        assert result is None

    def test_preferred_model_not_available_skipped(self):
        mgr = _make_manager(["llava:7b"])
        # preferred_model "llava:13b" is not in available_models
        result = mgr.get_best_model_for(ModelCapability.VISION, preferred_model="llava:13b")
        assert result == "llava:7b"


# ---------------------------------------------------------------------------
# pull_model_with_fallback (manager method)
# ---------------------------------------------------------------------------


class TestPullModelWithFallback:
    def test_already_available(self):
        mgr = _make_manager(["llama3.1:8b"])
        model, is_fallback = mgr.pull_model_with_fallback("llama3.1:8b")
        assert model == "llama3.1:8b"
        assert is_fallback is False

    def test_pull_succeeds(self):
        mgr = _make_manager([])
        pull_resp = _mock_urlopen_response()
        # After pull, refresh returns the model
        refresh_resp = _mock_urlopen_response(_fake_ollama_tags("llama3.1:8b"))

        with patch("urllib.request.urlopen", side_effect=[pull_resp, refresh_resp]):
            model, is_fallback = mgr.pull_model_with_fallback("llama3.1:8b")

        assert model == "llama3.1:8b"
        assert is_fallback is False

    def test_pull_fails_uses_capability_fallback(self):
        mgr = _make_manager(["llava:7b"])
        with patch("urllib.request.urlopen", side_effect=ConnectionError("fail")):
            model, is_fallback = mgr.pull_model_with_fallback(
                "nonexistent-vision:1b",
                capability=ModelCapability.VISION,
            )
        assert model == "llava:7b"
        assert is_fallback is True

    def test_pull_fails_uses_default_model(self):
        default_model = "llama3.1:8b"
        mgr = _make_manager([default_model])
        with (
            patch("urllib.request.urlopen", side_effect=ConnectionError("fail")),
            patch("infrastructure.models.multimodal.settings") as mock_settings,
        ):
            mock_settings.ollama_model = default_model
            mock_settings.ollama_url = "http://localhost:11434"
            model, is_fallback = mgr.pull_model_with_fallback("missing-model:99b")
        assert model == default_model
        assert is_fallback is True

    def test_auto_pull_false_skips_pull(self):
        mgr = _make_manager([])
        # Patch urlopen so the test is hermetic and can prove no request is made.
        with (
            patch("urllib.request.urlopen") as mock_urlopen,
            patch("infrastructure.models.multimodal.settings") as mock_settings,
        ):
            mock_settings.ollama_model = "default"
            model, is_fallback = mgr.pull_model_with_fallback("missing:1b", auto_pull=False)
        mock_urlopen.assert_not_called()
        # Falls through to absolute last resort
        assert model == "missing:1b"
        assert is_fallback is False

    def test_absolute_last_resort(self):
        mgr = _make_manager([])
        with (
            patch("urllib.request.urlopen", side_effect=ConnectionError("fail")),
            patch("infrastructure.models.multimodal.settings") as mock_settings,
        ):
            mock_settings.ollama_model = "not-available"
            model, is_fallback = mgr.pull_model_with_fallback("primary:1b")
        assert model == "primary:1b"
        assert is_fallback is False


# ---------------------------------------------------------------------------
# _pull_model
# ---------------------------------------------------------------------------


class TestPullModel:
    def test_pull_success(self):
        mgr = _make_manager([])
        pull_resp = _mock_urlopen_response()
        refresh_resp = _mock_urlopen_response(_fake_ollama_tags("new-model:1b"))

        with patch("urllib.request.urlopen", side_effect=[pull_resp, refresh_resp]):
            assert mgr._pull_model("new-model:1b") is True

    def test_pull_network_error(self):
        mgr = _make_manager([])
        with patch("urllib.request.urlopen", side_effect=ConnectionError("offline")):
            assert mgr._pull_model("any-model:1b") is False


# ---------------------------------------------------------------------------
# configure_fallback_chain / get_fallback_chain
# ---------------------------------------------------------------------------


class TestFallbackChainConfig:
    def test_configure_and_get(self):
        mgr = _make_manager(None)
        mgr.configure_fallback_chain(ModelCapability.VISION, ["model-a", "model-b"])
        assert mgr.get_fallback_chain(ModelCapability.VISION) == ["model-a", "model-b"]

    def test_get_returns_copy(self):
        mgr = _make_manager(None)
        chain = mgr.get_fallback_chain(ModelCapability.VISION)
        chain.append("mutated")
        assert "mutated" not in mgr.get_fallback_chain(ModelCapability.VISION)

    def test_get_empty_for_unknown(self):
        mgr = _make_manager(None)
        # AUDIO has an empty chain by default
        assert mgr.get_fallback_chain(ModelCapability.AUDIO) == []


# ---------------------------------------------------------------------------
# get_model_for_content
# ---------------------------------------------------------------------------


class TestGetModelForContent:
    def test_image_content(self):
        mgr = _make_manager(["llava:7b"])
        model, _ = mgr.get_model_for_content("image")
        assert model == "llava:7b"

    def test_vision_content(self):
        mgr = _make_manager(["llava:7b"])
        model, _ = mgr.get_model_for_content("vision")
        assert model == "llava:7b"

    def test_multimodal_content(self):
        mgr = _make_manager(["llava:7b"])
        model, _ = mgr.get_model_for_content("multimodal")
        assert model == "llava:7b"

    def test_audio_content(self):
        mgr = _make_manager(["llama3.1:8b"])
        with patch("infrastructure.models.multimodal.settings") as mock_settings:
            mock_settings.ollama_model = "llama3.1:8b"
            mock_settings.ollama_url = "http://localhost:11434"
            model, _ = mgr.get_model_for_content("audio")
        assert model == "llama3.1:8b"

    def test_text_content(self):
        mgr = _make_manager(["llama3.1:8b"])
        with patch("infrastructure.models.multimodal.settings") as mock_settings:
            mock_settings.ollama_model = "llama3.1:8b"
            mock_settings.ollama_url = "http://localhost:11434"
            model, _ = mgr.get_model_for_content("text")
        assert model == "llama3.1:8b"

    def test_preferred_model_respected(self):
        mgr = _make_manager(["llama3.2:3b", "llava:7b"])
        model, _ = mgr.get_model_for_content("image", preferred_model="llama3.2:3b")
        assert model == "llama3.2:3b"

    def test_case_insensitive(self):
        mgr = _make_manager(["llava:7b"])
        model, _ = mgr.get_model_for_content("IMAGE")
        assert model == "llava:7b"


# ---------------------------------------------------------------------------
# Module-level convenience functions
# ---------------------------------------------------------------------------


class TestConvenienceFunctions:
    def _patch_manager(self, mgr):
        """Swap the module-level manager used by the convenience functions."""
        return patch(
            "infrastructure.models.multimodal._multimodal_manager",
            mgr,
        )

    def test_get_model_for_capability(self):
        mgr = _make_manager(["llava:7b"])
        with self._patch_manager(mgr):
            result = get_model_for_capability(ModelCapability.VISION)
        assert result == "llava:7b"

    def test_pull_model_with_fallback_convenience(self):
        mgr = _make_manager(["llama3.1:8b"])
        with self._patch_manager(mgr):
            model, is_fb = pull_model_with_fallback("llama3.1:8b")
        assert model == "llama3.1:8b"
        assert is_fb is False

    def test_model_supports_vision_true(self):
        mgr = _make_manager(["llava:7b"])
        with self._patch_manager(mgr):
            assert model_supports_vision("llava:7b") is True

    def test_model_supports_vision_false(self):
        mgr = _make_manager(["llama3.1:8b"])
        with self._patch_manager(mgr):
            assert model_supports_vision("llama3.1:8b") is False

    def test_model_supports_tools_true(self):
        mgr = _make_manager(["llama3.1:8b"])
        with self._patch_manager(mgr):
            assert model_supports_tools("llama3.1:8b") is True

    def test_model_supports_tools_false(self):
        mgr = _make_manager(["deepseek-r1:7b"])
        with self._patch_manager(mgr):
            assert model_supports_tools("deepseek-r1:7b") is False


# ---------------------------------------------------------------------------
# ModelInfo in available_models — size_mb and description populated
# ---------------------------------------------------------------------------


class TestModelInfoPopulation:
    def test_size_and_description(self):
        mgr = _make_manager(["llama3.1:8b"])
        info = mgr._available_models["llama3.1:8b"]
        assert info.is_available is True
        assert info.is_pulled is True
        # The fake tags payload reports 4 GiB, i.e. 4096 MiB.
        assert info.size_mb == _FAKE_MODEL_SIZE_BYTES // (1024 * 1024)
        assert info.description == "test"