fix: add unit tests for memory/embeddings.py
Fixes #431

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
"""Unit tests for timmy.memory.embeddings module."""
|
||||
"""Unit tests for timmy.memory.embeddings — embedding, similarity, and overlap."""
|
||||
|
||||
import math
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import timmy.memory.embeddings as emb
|
||||
from timmy.memory.embeddings import (
|
||||
_keyword_overlap,
|
||||
_simple_hash_embedding,
|
||||
@@ -12,82 +13,42 @@ from timmy.memory.embeddings import (
|
||||
embed_text,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _simple_hash_embedding
|
||||
# ---------------------------------------------------------------------------
|
||||
# ── _simple_hash_embedding ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSimpleHashEmbedding:
|
||||
def test_returns_128_dim_vector(self):
|
||||
def test_returns_list_of_floats(self):
|
||||
vec = _simple_hash_embedding("hello world")
|
||||
assert isinstance(vec, list)
|
||||
assert all(isinstance(v, float) for v in vec)
|
||||
|
||||
def test_dimension_is_128(self):
|
||||
vec = _simple_hash_embedding("test sentence")
|
||||
assert len(vec) == 128
|
||||
|
||||
def test_deterministic(self):
|
||||
assert _simple_hash_embedding("test") == _simple_hash_embedding("test")
|
||||
|
||||
def test_normalized(self):
|
||||
vec = _simple_hash_embedding("the quick brown fox")
|
||||
def test_normalized_unit_vector(self):
|
||||
vec = _simple_hash_embedding("some text here")
|
||||
mag = math.sqrt(sum(x * x for x in vec))
|
||||
assert mag == pytest.approx(1.0, abs=1e-6)
|
||||
|
||||
def test_deterministic(self):
|
||||
a = _simple_hash_embedding("deterministic check")
|
||||
b = _simple_hash_embedding("deterministic check")
|
||||
assert a == b
|
||||
|
||||
def test_different_texts_differ(self):
|
||||
a = _simple_hash_embedding("hello")
|
||||
b = _simple_hash_embedding("goodbye")
|
||||
a = _simple_hash_embedding("alpha")
|
||||
b = _simple_hash_embedding("beta")
|
||||
assert a != b
|
||||
|
||||
def test_empty_string(self):
|
||||
vec = _simple_hash_embedding("")
|
||||
assert len(vec) == 128
|
||||
# All zeros → magnitude 0 → division by 1.0 → still all zeros
|
||||
assert all(x == 0.0 for x in vec)
|
||||
|
||||
def test_unicode(self):
|
||||
vec = _simple_hash_embedding("café résumé naïve")
|
||||
assert len(vec) == 128
|
||||
mag = math.sqrt(sum(x * x for x in vec))
|
||||
assert mag == pytest.approx(1.0, abs=1e-6)
|
||||
|
||||
def test_long_text_truncates_to_50_words(self):
|
||||
words = [f"word{i}" for i in range(100)]
|
||||
vec_100 = _simple_hash_embedding(" ".join(words))
|
||||
vec_50 = _simple_hash_embedding(" ".join(words[:50]))
|
||||
assert vec_100 == vec_50
|
||||
# All zeros normalized → magnitude 0 guarded by `or 1.0`
|
||||
assert all(v == 0.0 for v in vec)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# embed_text
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEmbedText:
|
||||
@patch("timmy.memory.embeddings._get_embedding_model")
|
||||
def test_uses_fallback_when_model_false(self, mock_get):
|
||||
mock_get.return_value = False
|
||||
vec = embed_text("hello")
|
||||
assert len(vec) == 128 # hash fallback dimension
|
||||
|
||||
@patch("timmy.memory.embeddings._get_embedding_model")
|
||||
def test_uses_model_when_available(self, mock_get):
|
||||
import numpy as np
|
||||
|
||||
fake_model = MagicMock()
|
||||
fake_model.encode.return_value = np.array([0.1, 0.2, 0.3])
|
||||
mock_get.return_value = fake_model
|
||||
|
||||
result = embed_text("test")
|
||||
fake_model.encode.assert_called_once_with("test")
|
||||
assert result == pytest.approx([0.1, 0.2, 0.3])
|
||||
|
||||
@patch("timmy.memory.embeddings._get_embedding_model")
|
||||
def test_uses_fallback_when_model_none(self, mock_get):
|
||||
mock_get.return_value = None
|
||||
vec = embed_text("hello")
|
||||
assert len(vec) == 128
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cosine_similarity
|
||||
# ---------------------------------------------------------------------------
|
||||
# ── cosine_similarity ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCosineSimilarity:
|
||||
@@ -108,18 +69,12 @@ class TestCosineSimilarity:
|
||||
def test_zero_vector_returns_zero(self):
|
||||
assert cosine_similarity([0.0, 0.0], [1.0, 2.0]) == 0.0
|
||||
assert cosine_similarity([1.0, 2.0], [0.0, 0.0]) == 0.0
|
||||
assert cosine_similarity([0.0, 0.0], [0.0, 0.0]) == 0.0
|
||||
|
||||
def test_different_length_vectors(self):
|
||||
# zip(strict=False) truncates to shorter
|
||||
a = [1.0, 0.0, 0.0]
|
||||
b = [1.0, 0.0]
|
||||
assert cosine_similarity(a, b) == pytest.approx(1.0)
|
||||
def test_both_zero_returns_zero(self):
|
||||
assert cosine_similarity([0.0], [0.0]) == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _keyword_overlap
|
||||
# ---------------------------------------------------------------------------
|
||||
# ── _keyword_overlap ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestKeywordOverlap:
|
||||
@@ -130,51 +85,69 @@ class TestKeywordOverlap:
|
||||
assert _keyword_overlap("hello world", "hello there") == pytest.approx(0.5)
|
||||
|
||||
def test_no_overlap(self):
|
||||
assert _keyword_overlap("hello", "world") == pytest.approx(0.0)
|
||||
assert _keyword_overlap("foo bar", "baz qux") == pytest.approx(0.0)
|
||||
|
||||
def test_empty_query(self):
|
||||
def test_empty_query_returns_zero(self):
|
||||
assert _keyword_overlap("", "some content") == 0.0
|
||||
|
||||
def test_empty_content(self):
|
||||
assert _keyword_overlap("hello", "") == 0.0
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert _keyword_overlap("Hello World", "hello world") == pytest.approx(1.0)
|
||||
|
||||
def test_superset_content(self):
|
||||
assert _keyword_overlap("cat", "the cat sat on the mat") == pytest.approx(1.0)
|
||||
assert _keyword_overlap("Hello WORLD", "hello world") == pytest.approx(1.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_embedding_model
|
||||
# ---------------------------------------------------------------------------
|
||||
# ── _get_embedding_model ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGetEmbeddingModel:
|
||||
def test_skip_embeddings_setting(self):
|
||||
"""When timmy_skip_embeddings is True, model should be False."""
|
||||
import timmy.memory.embeddings as mod
|
||||
def setup_method(self):
|
||||
# Reset global state before each test
|
||||
emb.EMBEDDING_MODEL = None
|
||||
|
||||
original = mod.EMBEDDING_MODEL
|
||||
try:
|
||||
mod.EMBEDDING_MODEL = None # reset lazy cache
|
||||
with patch("timmy.memory.embeddings.settings", create=True) as mock_settings:
|
||||
mock_settings.timmy_skip_embeddings = True
|
||||
# Patch the import inside _get_embedding_model
|
||||
with patch.dict("sys.modules", {"config": MagicMock(settings=mock_settings)}):
|
||||
result = mod._get_embedding_model()
|
||||
assert result is False
|
||||
finally:
|
||||
mod.EMBEDDING_MODEL = original
|
||||
def teardown_method(self):
|
||||
emb.EMBEDDING_MODEL = None
|
||||
|
||||
def test_skip_embeddings_setting(self):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.timmy_skip_embeddings = True
|
||||
with patch.dict("sys.modules", {"config": MagicMock(settings=mock_settings)}):
|
||||
emb.EMBEDDING_MODEL = None
|
||||
result = emb._get_embedding_model()
|
||||
assert result is False
|
||||
|
||||
def test_fallback_when_sentence_transformers_missing(self):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.timmy_skip_embeddings = False
|
||||
with patch.dict("sys.modules", {"config": MagicMock(settings=mock_settings)}):
|
||||
with patch.dict("sys.modules", {"sentence_transformers": None}):
|
||||
emb.EMBEDDING_MODEL = None
|
||||
result = emb._get_embedding_model()
|
||||
assert result is False
|
||||
|
||||
def test_caches_model(self):
|
||||
"""Subsequent calls return the cached value."""
|
||||
import timmy.memory.embeddings as mod
|
||||
sentinel = MagicMock()
|
||||
emb.EMBEDDING_MODEL = sentinel
|
||||
assert emb._get_embedding_model() is sentinel
|
||||
|
||||
original = mod.EMBEDDING_MODEL
|
||||
try:
|
||||
mod.EMBEDDING_MODEL = "cached_sentinel"
|
||||
result = mod._get_embedding_model()
|
||||
assert result == "cached_sentinel"
|
||||
finally:
|
||||
mod.EMBEDDING_MODEL = original
|
||||
|
||||
# ── embed_text ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestEmbedText:
|
||||
def test_uses_fallback_when_model_false(self):
|
||||
with patch.object(emb, "_get_embedding_model", return_value=False):
|
||||
vec = embed_text("test")
|
||||
assert len(vec) == 128
|
||||
|
||||
def test_uses_model_when_available(self):
|
||||
import numpy as np
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_model.encode.return_value = np.array([0.1, 0.2, 0.3])
|
||||
with patch.object(emb, "_get_embedding_model", return_value=mock_model):
|
||||
vec = embed_text("test")
|
||||
assert vec == [0.1, 0.2, 0.3]
|
||||
mock_model.encode.assert_called_once_with("test")
|
||||
|
||||
def test_uses_fallback_when_model_none(self):
|
||||
with patch.object(emb, "_get_embedding_model", return_value=None):
|
||||
vec = embed_text("test")
|
||||
assert len(vec) == 128
|
||||
|
||||
Reference in New Issue
Block a user