From ab0363f70007a072a6377a5fbf0198994d2ed924 Mon Sep 17 00:00:00 2001
From: kimi
Date: Thu, 19 Mar 2026 11:11:49 -0400
Subject: [PATCH] fix: add unit tests for memory/embeddings.py

Tests cover _simple_hash_embedding, cosine_similarity, _keyword_overlap,
embed_text, and _get_embedding_model with proper mocking of the global
model state.

Fixes #431

Co-Authored-By: Claude Opus 4.6
---
 tests/unit/test_embeddings.py | 154 ++++++++++++++++++++++++++++++++++
 1 file changed, 154 insertions(+)
 create mode 100644 tests/unit/test_embeddings.py

diff --git a/tests/unit/test_embeddings.py b/tests/unit/test_embeddings.py
new file mode 100644
index 00000000..87c840ea
--- /dev/null
+++ b/tests/unit/test_embeddings.py
@@ -0,0 +1,154 @@
+"""Unit tests for timmy.memory.embeddings — embedding, similarity, and keyword overlap."""
+
+import math
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+import timmy.memory.embeddings as emb
+from timmy.memory.embeddings import (
+    _keyword_overlap,
+    _simple_hash_embedding,
+    cosine_similarity,
+    embed_text,
+)
+
+# ── _simple_hash_embedding ───────────────────────────────────────────────────
+
+
+class TestSimpleHashEmbedding:
+    """Tests for the deterministic hash-based fallback embedding."""
+
+    def test_returns_128_floats(self):
+        vec = _simple_hash_embedding("hello world")
+        assert len(vec) == 128
+        assert all(isinstance(x, float) for x in vec)
+
+    def test_deterministic(self):
+        """Same input always produces the same vector."""
+        assert _simple_hash_embedding("test") == _simple_hash_embedding("test")
+
+    def test_normalized(self):
+        """Output vector has unit magnitude."""
+        vec = _simple_hash_embedding("some text for testing")
+        mag = math.sqrt(sum(x * x for x in vec))
+        assert mag == pytest.approx(1.0, abs=1e-6)
+
+    def test_empty_string(self):
+        """Empty string doesn't crash — returns a zero-ish vector."""
+        vec = _simple_hash_embedding("")
+        assert len(vec) == 128
+
+    def test_different_inputs_differ(self):
+        a = _simple_hash_embedding("alpha")
+        b = _simple_hash_embedding("beta")
+        assert a != b
+
+
+# ── cosine_similarity ────────────────────────────────────────────────────────
+
+
+class TestCosineSimilarity:
+    """Tests for cosine similarity calculation."""
+
+    def test_identical_vectors(self):
+        v = [1.0, 2.0, 3.0]
+        assert cosine_similarity(v, v) == pytest.approx(1.0)
+
+    def test_orthogonal_vectors(self):
+        a = [1.0, 0.0]
+        b = [0.0, 1.0]
+        assert cosine_similarity(a, b) == pytest.approx(0.0)
+
+    def test_opposite_vectors(self):
+        a = [1.0, 0.0]
+        b = [-1.0, 0.0]
+        assert cosine_similarity(a, b) == pytest.approx(-1.0)
+
+    def test_zero_vector_returns_zero(self):
+        assert cosine_similarity([0.0, 0.0], [1.0, 2.0]) == 0.0
+        assert cosine_similarity([1.0, 2.0], [0.0, 0.0]) == 0.0
+
+    def test_different_length_uses_zip(self):
+        """zip(strict=False) truncates to shortest — verify no crash."""
+        result = cosine_similarity([1.0, 0.0], [1.0, 0.0, 9.9])
+        # mag_b includes the extra element, so result < 1.0
+        assert isinstance(result, float)
+
+
+# ── _keyword_overlap ─────────────────────────────────────────────────────────
+
+
+class TestKeywordOverlap:
+    """Tests for keyword overlap scoring."""
+
+    def test_full_overlap(self):
+        assert _keyword_overlap("hello world", "hello world extra") == pytest.approx(1.0)
+
+    def test_partial_overlap(self):
+        assert _keyword_overlap("hello world", "hello there") == pytest.approx(0.5)
+
+    def test_no_overlap(self):
+        assert _keyword_overlap("alpha", "beta") == pytest.approx(0.0)
+
+    def test_empty_query(self):
+        assert _keyword_overlap("", "some content") == 0.0
+
+    def test_case_insensitive(self):
+        assert _keyword_overlap("Hello", "hello world") == pytest.approx(1.0)
+
+
+# ── embed_text ───────────────────────────────────────────────────────────────
+
+
+class TestEmbedText:
+    """Tests for the main embed_text entry point."""
+
+    def test_uses_fallback_when_model_false(self):
+        """When _get_embedding_model returns False, use hash fallback."""
+        with patch.object(emb, "EMBEDDING_MODEL", False):
+            with patch.object(emb, "_get_embedding_model", return_value=False):
+                vec = embed_text("hello")
+                assert len(vec) == 128
+
+    def test_uses_model_when_available(self):
+        """When a real model is loaded, call model.encode()."""
+        mock_model = MagicMock()
+        mock_model.encode.return_value = MagicMock(tolist=MagicMock(return_value=[0.1, 0.2]))
+        with patch.object(emb, "_get_embedding_model", return_value=mock_model):
+            vec = embed_text("hello")
+            assert vec == [0.1, 0.2]
+            mock_model.encode.assert_called_once_with("hello")
+
+
+# ── _get_embedding_model ─────────────────────────────────────────────────────
+
+
+class TestGetEmbeddingModel:
+    """Tests for lazy model loading."""
+
+    def setup_method(self):
+        """Reset global state before each test."""
+        emb.EMBEDDING_MODEL = None
+
+    def test_skip_embeddings_setting(self):
+        """When settings.timmy_skip_embeddings is True, model is set to False."""
+        mock_settings = MagicMock()
+        mock_settings.timmy_skip_embeddings = True
+        with patch.dict("sys.modules", {"config": MagicMock(settings=mock_settings)}):
+            emb.EMBEDDING_MODEL = None
+            result = emb._get_embedding_model()
+            assert result is False
+
+    def test_sentence_transformers_import_error(self):
+        """When sentence-transformers is missing, falls back to False."""
+        mock_settings = MagicMock()
+        mock_settings.timmy_skip_embeddings = False
+        with patch.dict("sys.modules", {"config": MagicMock(settings=mock_settings)}):
+            with patch.dict("sys.modules", {"sentence_transformers": None}):
+                emb.EMBEDDING_MODEL = None
+                result = emb._get_embedding_model()
+                assert result is False
+
+    def teardown_method(self):
+        emb.EMBEDDING_MODEL = None