diff --git a/tests/unit/test_memory_embeddings.py b/tests/unit/test_memory_embeddings.py index 6d3e6bb5..7436ebc4 100644 --- a/tests/unit/test_memory_embeddings.py +++ b/tests/unit/test_memory_embeddings.py @@ -99,16 +99,9 @@ class TestKeywordOverlap: class TestEmbedText: - def setup_method(self): - self._saved_model = emb.EMBEDDING_MODEL - emb.EMBEDDING_MODEL = None - - def teardown_method(self): - emb.EMBEDDING_MODEL = self._saved_model - def test_uses_fallback_when_model_disabled(self): - emb.EMBEDDING_MODEL = False - vec = embed_text("test") + with patch.object(emb, "_get_embedding_model", return_value=False): + vec = embed_text("test") assert len(vec) == 128 # hash fallback dimension def test_uses_model_when_available(self): @@ -116,9 +109,9 @@ class TestEmbedText: mock_encoding.tolist.return_value = [0.1, 0.2, 0.3] mock_model = MagicMock() mock_model.encode.return_value = mock_encoding - emb.EMBEDDING_MODEL = mock_model - result = embed_text("test") + with patch.object(emb, "_get_embedding_model", return_value=mock_model): + result = embed_text("test") assert result == pytest.approx([0.1, 0.2, 0.3]) mock_model.encode.assert_called_once_with("test")