[kimi] test: add comprehensive tests for multimodal.py (#658) (#694)
Some checks failed
Tests / lint (push) Has been cancelled
Tests / test (push) Has been cancelled

This commit was merged in pull request #694.
This commit is contained in:
2026-03-21 04:00:53 +00:00
parent 6f0863b587
commit cc30bdb391

View File

@@ -0,0 +1,509 @@
"""Tests for infrastructure.models.multimodal — multi-modal model management."""
import json
from unittest.mock import MagicMock, patch
from infrastructure.models.multimodal import (
DEFAULT_FALLBACK_CHAINS,
KNOWN_MODEL_CAPABILITIES,
ModelCapability,
ModelInfo,
MultiModalManager,
get_model_for_capability,
model_supports_tools,
model_supports_vision,
pull_model_with_fallback,
)
# ---------------------------------------------------------------------------
# ModelCapability enum
# ---------------------------------------------------------------------------
class TestModelCapability:
def test_members_exist(self):
assert ModelCapability.TEXT
assert ModelCapability.VISION
assert ModelCapability.AUDIO
assert ModelCapability.TOOLS
assert ModelCapability.JSON
assert ModelCapability.STREAMING
def test_all_members_unique(self):
values = [m.value for m in ModelCapability]
assert len(values) == len(set(values))
# ---------------------------------------------------------------------------
# ModelInfo dataclass
# ---------------------------------------------------------------------------
class TestModelInfo:
def test_defaults(self):
info = ModelInfo(name="test-model")
assert info.name == "test-model"
assert info.capabilities == set()
assert info.is_available is False
assert info.is_pulled is False
assert info.size_mb is None
assert info.description == ""
def test_supports_true(self):
info = ModelInfo(name="m", capabilities={ModelCapability.TEXT, ModelCapability.VISION})
assert info.supports(ModelCapability.TEXT) is True
assert info.supports(ModelCapability.VISION) is True
def test_supports_false(self):
info = ModelInfo(name="m", capabilities={ModelCapability.TEXT})
assert info.supports(ModelCapability.VISION) is False
# ---------------------------------------------------------------------------
# Known model capabilities lookup table
# ---------------------------------------------------------------------------
class TestKnownModelCapabilities:
def test_vision_models_have_vision(self):
vision_names = [
"llama3.2-vision",
"llava",
"moondream",
"qwen2.5-vl",
]
for name in vision_names:
assert ModelCapability.VISION in KNOWN_MODEL_CAPABILITIES[name], name
def test_text_models_lack_vision(self):
text_only = ["deepseek-r1", "gemma2", "phi3"]
for name in text_only:
assert ModelCapability.VISION not in KNOWN_MODEL_CAPABILITIES[name], name
def test_all_models_have_text(self):
for name, caps in KNOWN_MODEL_CAPABILITIES.items():
assert ModelCapability.TEXT in caps, f"{name} should have TEXT"
# ---------------------------------------------------------------------------
# Default fallback chains
# ---------------------------------------------------------------------------
class TestDefaultFallbackChains:
def test_vision_chain_non_empty(self):
assert len(DEFAULT_FALLBACK_CHAINS[ModelCapability.VISION]) > 0
def test_tools_chain_non_empty(self):
assert len(DEFAULT_FALLBACK_CHAINS[ModelCapability.TOOLS]) > 0
def test_audio_chain_empty(self):
assert DEFAULT_FALLBACK_CHAINS[ModelCapability.AUDIO] == []
# ---------------------------------------------------------------------------
# Helpers to build a manager without hitting the network
# ---------------------------------------------------------------------------
def _fake_ollama_tags(*model_names: str) -> bytes:
"""Build a JSON response mimicking Ollama /api/tags."""
models = []
for name in model_names:
models.append({"name": name, "size": 4 * 1024 * 1024 * 1024, "details": {"family": "test"}})
return json.dumps({"models": models}).encode()
def _make_manager(model_names: list[str] | None = None) -> MultiModalManager:
"""Create a MultiModalManager with mocked Ollama responses."""
if model_names is None:
# No models available — Ollama unreachable
with patch("urllib.request.urlopen", side_effect=ConnectionError("no ollama")):
return MultiModalManager(ollama_url="http://localhost:11434")
resp = MagicMock()
resp.__enter__ = MagicMock(return_value=resp)
resp.__exit__ = MagicMock(return_value=False)
resp.read.return_value = _fake_ollama_tags(*model_names)
resp.status = 200
with patch("urllib.request.urlopen", return_value=resp):
return MultiModalManager(ollama_url="http://localhost:11434")
# ---------------------------------------------------------------------------
# MultiModalManager — init & refresh
# ---------------------------------------------------------------------------
class TestMultiModalManagerInit:
def test_init_no_ollama(self):
mgr = _make_manager(None)
assert mgr.list_available_models() == []
def test_init_with_models(self):
mgr = _make_manager(["llama3.1:8b", "llava:7b"])
names = {m.name for m in mgr.list_available_models()}
assert names == {"llama3.1:8b", "llava:7b"}
def test_refresh_updates_models(self):
mgr = _make_manager([])
assert mgr.list_available_models() == []
resp = MagicMock()
resp.__enter__ = MagicMock(return_value=resp)
resp.__exit__ = MagicMock(return_value=False)
resp.read.return_value = _fake_ollama_tags("gemma2:9b")
resp.status = 200
with patch("urllib.request.urlopen", return_value=resp):
mgr.refresh()
names = {m.name for m in mgr.list_available_models()}
assert "gemma2:9b" in names
# ---------------------------------------------------------------------------
# _detect_capabilities
# ---------------------------------------------------------------------------
class TestDetectCapabilities:
def test_exact_match(self):
mgr = _make_manager(None)
caps = mgr._detect_capabilities("llava:7b")
assert ModelCapability.VISION in caps
def test_base_name_match(self):
mgr = _make_manager(None)
caps = mgr._detect_capabilities("llava:99b")
# "llava:99b" not in table, but "llava" is
assert ModelCapability.VISION in caps
def test_unknown_model_defaults_to_text(self):
mgr = _make_manager(None)
caps = mgr._detect_capabilities("totally-unknown-model:1b")
assert caps == {ModelCapability.TEXT, ModelCapability.STREAMING}
# ---------------------------------------------------------------------------
# get_model_capabilities / model_supports
# ---------------------------------------------------------------------------
class TestGetModelCapabilities:
def test_available_model(self):
mgr = _make_manager(["llava:7b"])
caps = mgr.get_model_capabilities("llava:7b")
assert ModelCapability.VISION in caps
def test_unavailable_model_uses_detection(self):
mgr = _make_manager([])
caps = mgr.get_model_capabilities("llava:7b")
assert ModelCapability.VISION in caps
class TestModelSupports:
def test_supports_true(self):
mgr = _make_manager(["llava:7b"])
assert mgr.model_supports("llava:7b", ModelCapability.VISION) is True
def test_supports_false(self):
mgr = _make_manager(["deepseek-r1:7b"])
assert mgr.model_supports("deepseek-r1:7b", ModelCapability.VISION) is False
# ---------------------------------------------------------------------------
# get_models_with_capability
# ---------------------------------------------------------------------------
class TestGetModelsWithCapability:
def test_returns_vision_models(self):
mgr = _make_manager(["llava:7b", "deepseek-r1:7b"])
vision = mgr.get_models_with_capability(ModelCapability.VISION)
names = {m.name for m in vision}
assert "llava:7b" in names
assert "deepseek-r1:7b" not in names
def test_empty_when_none_available(self):
mgr = _make_manager(["deepseek-r1:7b"])
vision = mgr.get_models_with_capability(ModelCapability.VISION)
assert vision == []
# ---------------------------------------------------------------------------
# get_best_model_for
# ---------------------------------------------------------------------------
class TestGetBestModelFor:
def test_preferred_model_with_capability(self):
mgr = _make_manager(["llava:7b", "llama3.1:8b"])
result = mgr.get_best_model_for(ModelCapability.VISION, preferred_model="llava:7b")
assert result == "llava:7b"
def test_preferred_model_without_capability_uses_fallback(self):
mgr = _make_manager(["deepseek-r1:7b", "llava:7b"])
# preferred doesn't have VISION, fallback chain has llava:7b
result = mgr.get_best_model_for(ModelCapability.VISION, preferred_model="deepseek-r1:7b")
assert result == "llava:7b"
def test_fallback_chain_order(self):
# First in chain: llama3.2:3b
mgr = _make_manager(["llama3.2:3b", "llava:7b"])
result = mgr.get_best_model_for(ModelCapability.VISION)
assert result == "llama3.2:3b"
def test_any_capable_model_when_no_fallback(self):
mgr = _make_manager(["moondream:1.8b"])
mgr._fallback_chains[ModelCapability.VISION] = [] # clear chain
result = mgr.get_best_model_for(ModelCapability.VISION)
assert result == "moondream:1.8b"
def test_none_when_no_capable_model(self):
mgr = _make_manager(["deepseek-r1:7b"])
result = mgr.get_best_model_for(ModelCapability.VISION)
assert result is None
def test_preferred_model_not_available_skipped(self):
mgr = _make_manager(["llava:7b"])
# preferred_model "llava:13b" is not in available_models
result = mgr.get_best_model_for(ModelCapability.VISION, preferred_model="llava:13b")
assert result == "llava:7b"
# ---------------------------------------------------------------------------
# pull_model_with_fallback (manager method)
# ---------------------------------------------------------------------------
class TestPullModelWithFallback:
def test_already_available(self):
mgr = _make_manager(["llama3.1:8b"])
model, is_fallback = mgr.pull_model_with_fallback("llama3.1:8b")
assert model == "llama3.1:8b"
assert is_fallback is False
def test_pull_succeeds(self):
mgr = _make_manager([])
pull_resp = MagicMock()
pull_resp.__enter__ = MagicMock(return_value=pull_resp)
pull_resp.__exit__ = MagicMock(return_value=False)
pull_resp.status = 200
# After pull, refresh returns the model
refresh_resp = MagicMock()
refresh_resp.__enter__ = MagicMock(return_value=refresh_resp)
refresh_resp.__exit__ = MagicMock(return_value=False)
refresh_resp.read.return_value = _fake_ollama_tags("llama3.1:8b")
refresh_resp.status = 200
with patch("urllib.request.urlopen", side_effect=[pull_resp, refresh_resp]):
model, is_fallback = mgr.pull_model_with_fallback("llama3.1:8b")
assert model == "llama3.1:8b"
assert is_fallback is False
def test_pull_fails_uses_capability_fallback(self):
mgr = _make_manager(["llava:7b"])
with patch("urllib.request.urlopen", side_effect=ConnectionError("fail")):
model, is_fallback = mgr.pull_model_with_fallback(
"nonexistent-vision:1b",
capability=ModelCapability.VISION,
)
assert model == "llava:7b"
assert is_fallback is True
def test_pull_fails_uses_default_model(self):
mgr = _make_manager([settings_ollama_model := "llama3.1:8b"])
with (
patch("urllib.request.urlopen", side_effect=ConnectionError("fail")),
patch("infrastructure.models.multimodal.settings") as mock_settings,
):
mock_settings.ollama_model = settings_ollama_model
mock_settings.ollama_url = "http://localhost:11434"
model, is_fallback = mgr.pull_model_with_fallback("missing-model:99b")
assert model == "llama3.1:8b"
assert is_fallback is True
def test_auto_pull_false_skips_pull(self):
mgr = _make_manager([])
with patch("infrastructure.models.multimodal.settings") as mock_settings:
mock_settings.ollama_model = "default"
model, is_fallback = mgr.pull_model_with_fallback("missing:1b", auto_pull=False)
# Falls through to absolute last resort
assert model == "missing:1b"
assert is_fallback is False
def test_absolute_last_resort(self):
mgr = _make_manager([])
with (
patch("urllib.request.urlopen", side_effect=ConnectionError("fail")),
patch("infrastructure.models.multimodal.settings") as mock_settings,
):
mock_settings.ollama_model = "not-available"
model, is_fallback = mgr.pull_model_with_fallback("primary:1b")
assert model == "primary:1b"
assert is_fallback is False
# ---------------------------------------------------------------------------
# _pull_model
# ---------------------------------------------------------------------------
class TestPullModel:
def test_pull_success(self):
mgr = _make_manager([])
pull_resp = MagicMock()
pull_resp.__enter__ = MagicMock(return_value=pull_resp)
pull_resp.__exit__ = MagicMock(return_value=False)
pull_resp.status = 200
refresh_resp = MagicMock()
refresh_resp.__enter__ = MagicMock(return_value=refresh_resp)
refresh_resp.__exit__ = MagicMock(return_value=False)
refresh_resp.read.return_value = _fake_ollama_tags("new-model:1b")
refresh_resp.status = 200
with patch("urllib.request.urlopen", side_effect=[pull_resp, refresh_resp]):
assert mgr._pull_model("new-model:1b") is True
def test_pull_network_error(self):
mgr = _make_manager([])
with patch("urllib.request.urlopen", side_effect=ConnectionError("offline")):
assert mgr._pull_model("any-model:1b") is False
# ---------------------------------------------------------------------------
# configure_fallback_chain / get_fallback_chain
# ---------------------------------------------------------------------------
class TestFallbackChainConfig:
def test_configure_and_get(self):
mgr = _make_manager(None)
mgr.configure_fallback_chain(ModelCapability.VISION, ["model-a", "model-b"])
assert mgr.get_fallback_chain(ModelCapability.VISION) == ["model-a", "model-b"]
def test_get_returns_copy(self):
mgr = _make_manager(None)
chain = mgr.get_fallback_chain(ModelCapability.VISION)
chain.append("mutated")
assert "mutated" not in mgr.get_fallback_chain(ModelCapability.VISION)
def test_get_empty_for_unknown(self):
mgr = _make_manager(None)
# AUDIO has an empty chain by default
assert mgr.get_fallback_chain(ModelCapability.AUDIO) == []
# ---------------------------------------------------------------------------
# get_model_for_content
# ---------------------------------------------------------------------------
class TestGetModelForContent:
def test_image_content(self):
mgr = _make_manager(["llava:7b"])
model, is_fb = mgr.get_model_for_content("image")
assert model == "llava:7b"
def test_vision_content(self):
mgr = _make_manager(["llava:7b"])
model, _ = mgr.get_model_for_content("vision")
assert model == "llava:7b"
def test_multimodal_content(self):
mgr = _make_manager(["llava:7b"])
model, _ = mgr.get_model_for_content("multimodal")
assert model == "llava:7b"
def test_audio_content(self):
mgr = _make_manager(["llama3.1:8b"])
with patch("infrastructure.models.multimodal.settings") as mock_settings:
mock_settings.ollama_model = "llama3.1:8b"
mock_settings.ollama_url = "http://localhost:11434"
model, _ = mgr.get_model_for_content("audio")
assert model == "llama3.1:8b"
def test_text_content(self):
mgr = _make_manager(["llama3.1:8b"])
with patch("infrastructure.models.multimodal.settings") as mock_settings:
mock_settings.ollama_model = "llama3.1:8b"
mock_settings.ollama_url = "http://localhost:11434"
model, _ = mgr.get_model_for_content("text")
assert model == "llama3.1:8b"
def test_preferred_model_respected(self):
mgr = _make_manager(["llama3.2:3b", "llava:7b"])
model, _ = mgr.get_model_for_content("image", preferred_model="llama3.2:3b")
assert model == "llama3.2:3b"
def test_case_insensitive(self):
mgr = _make_manager(["llava:7b"])
model, _ = mgr.get_model_for_content("IMAGE")
assert model == "llava:7b"
# ---------------------------------------------------------------------------
# Module-level convenience functions
# ---------------------------------------------------------------------------
class TestConvenienceFunctions:
def _patch_manager(self, mgr):
return patch(
"infrastructure.models.multimodal._multimodal_manager",
mgr,
)
def test_get_model_for_capability(self):
mgr = _make_manager(["llava:7b"])
with self._patch_manager(mgr):
result = get_model_for_capability(ModelCapability.VISION)
assert result == "llava:7b"
def test_pull_model_with_fallback_convenience(self):
mgr = _make_manager(["llama3.1:8b"])
with self._patch_manager(mgr):
model, is_fb = pull_model_with_fallback("llama3.1:8b")
assert model == "llama3.1:8b"
assert is_fb is False
def test_model_supports_vision_true(self):
mgr = _make_manager(["llava:7b"])
with self._patch_manager(mgr):
assert model_supports_vision("llava:7b") is True
def test_model_supports_vision_false(self):
mgr = _make_manager(["llama3.1:8b"])
with self._patch_manager(mgr):
assert model_supports_vision("llama3.1:8b") is False
def test_model_supports_tools_true(self):
mgr = _make_manager(["llama3.1:8b"])
with self._patch_manager(mgr):
assert model_supports_tools("llama3.1:8b") is True
def test_model_supports_tools_false(self):
mgr = _make_manager(["deepseek-r1:7b"])
with self._patch_manager(mgr):
assert model_supports_tools("deepseek-r1:7b") is False
# ---------------------------------------------------------------------------
# ModelInfo in available_models — size_mb and description populated
# ---------------------------------------------------------------------------
class TestModelInfoPopulation:
def test_size_and_description(self):
mgr = _make_manager(["llama3.1:8b"])
info = mgr._available_models["llama3.1:8b"]
assert info.is_available is True
assert info.is_pulled is True
assert info.size_mb == 4 * 1024 # 4 GiB in MiB
assert info.description == "test"