Compare commits
1 Commits
gemini/iss
...
kimi/issue
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ab0363f700 |
154
tests/unit/test_embeddings.py
Normal file
154
tests/unit/test_embeddings.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""Unit tests for timmy.memory.embeddings — embedding, similarity, and keyword overlap."""
|
||||
|
||||
import math
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import timmy.memory.embeddings as emb
|
||||
from timmy.memory.embeddings import (
|
||||
_keyword_overlap,
|
||||
_simple_hash_embedding,
|
||||
cosine_similarity,
|
||||
embed_text,
|
||||
)
|
||||
|
||||
# ── _simple_hash_embedding ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSimpleHashEmbedding:
    """Tests for the deterministic hash-based fallback embedding."""

    def test_returns_128_floats(self):
        """The fallback always yields a fixed-width vector of floats."""
        vector = _simple_hash_embedding("hello world")
        assert len(vector) == 128
        assert all(isinstance(component, float) for component in vector)

    def test_deterministic(self):
        """Same input always produces the same vector."""
        first = _simple_hash_embedding("test")
        second = _simple_hash_embedding("test")
        assert first == second

    def test_normalized(self):
        """Output vector has unit magnitude."""
        vector = _simple_hash_embedding("some text for testing")
        # math.hypot over all components == sqrt(sum of squares).
        magnitude = math.hypot(*vector)
        assert magnitude == pytest.approx(1.0, abs=1e-6)

    def test_empty_string(self):
        """Empty string doesn't crash — returns a zero-ish vector."""
        assert len(_simple_hash_embedding("")) == 128

    def test_different_inputs_differ(self):
        """Distinct inputs should not collide into the same vector."""
        assert _simple_hash_embedding("alpha") != _simple_hash_embedding("beta")
|
||||
|
||||
|
||||
# ── cosine_similarity ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCosineSimilarity:
    """Tests for cosine similarity calculation."""

    def test_identical_vectors(self):
        """A vector compared with itself has similarity 1."""
        v = [1.0, 2.0, 3.0]
        assert cosine_similarity(v, v) == pytest.approx(1.0)

    def test_orthogonal_vectors(self):
        """Perpendicular vectors have similarity 0."""
        a = [1.0, 0.0]
        b = [0.0, 1.0]
        assert cosine_similarity(a, b) == pytest.approx(0.0)

    def test_opposite_vectors(self):
        """Anti-parallel vectors have similarity -1."""
        a = [1.0, 0.0]
        b = [-1.0, 0.0]
        assert cosine_similarity(a, b) == pytest.approx(-1.0)

    def test_zero_vector_returns_zero(self):
        """A zero vector on either side short-circuits to 0.0 (no div-by-zero)."""
        assert cosine_similarity([0.0, 0.0], [1.0, 2.0]) == 0.0
        assert cosine_similarity([1.0, 2.0], [0.0, 0.0]) == 0.0

    def test_different_length_uses_zip(self):
        """zip(strict=False) truncates to shortest — verify no crash."""
        result = cosine_similarity([1.0, 0.0], [1.0, 0.0, 9.9])
        # The dot product uses only the first two elements (== 1.0), but
        # mag_b includes the extra 9.9, so the similarity is diluted below
        # 1.0 while staying positive. Assert the documented property — the
        # original only checked isinstance, which would pass for any float.
        assert isinstance(result, float)
        assert 0.0 < result < 1.0
|
||||
|
||||
|
||||
# ── _keyword_overlap ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestKeywordOverlap:
    """Tests for keyword overlap scoring."""

    def test_full_overlap(self):
        """Every query word appears in the content — score is 1."""
        score = _keyword_overlap("hello world", "hello world extra")
        assert score == pytest.approx(1.0)

    def test_partial_overlap(self):
        """Half the query words appear in the content — score is 0.5."""
        score = _keyword_overlap("hello world", "hello there")
        assert score == pytest.approx(0.5)

    def test_no_overlap(self):
        """Disjoint word sets score 0."""
        score = _keyword_overlap("alpha", "beta")
        assert score == pytest.approx(0.0)

    def test_empty_query(self):
        """An empty query yields exactly 0.0 rather than raising."""
        assert _keyword_overlap("", "some content") == 0.0

    def test_case_insensitive(self):
        """Matching ignores letter case."""
        score = _keyword_overlap("Hello", "hello world")
        assert score == pytest.approx(1.0)
|
||||
|
||||
|
||||
# ── embed_text ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestEmbedText:
    """Tests for the main embed_text entry point."""

    def test_uses_fallback_when_model_false(self):
        """When _get_embedding_model returns False, use hash fallback."""
        # Both the cached global and the loader are forced to False so the
        # 128-dim hash fallback is the only available path.
        with patch.object(emb, "EMBEDDING_MODEL", False), \
                patch.object(emb, "_get_embedding_model", return_value=False):
            assert len(embed_text("hello")) == 128

    def test_uses_model_when_available(self):
        """When a real model is loaded, call model.encode()."""
        fake_model = MagicMock()
        fake_model.encode.return_value = MagicMock(
            tolist=MagicMock(return_value=[0.1, 0.2])
        )
        with patch.object(emb, "_get_embedding_model", return_value=fake_model):
            result = embed_text("hello")
        assert result == [0.1, 0.2]
        fake_model.encode.assert_called_once_with("hello")
|
||||
|
||||
|
||||
# ── _get_embedding_model ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGetEmbeddingModel:
    """Tests for lazy model loading."""

    def setup_method(self):
        """Reset global state before each test."""
        emb.EMBEDDING_MODEL = None

    def test_skip_embeddings_setting(self):
        """When settings.timmy_skip_embeddings is True, model is set to False."""
        fake_settings = MagicMock()
        fake_settings.timmy_skip_embeddings = True
        fake_config = MagicMock(settings=fake_settings)
        with patch.dict("sys.modules", {"config": fake_config}):
            emb.EMBEDDING_MODEL = None
            assert emb._get_embedding_model() is False

    def test_sentence_transformers_import_error(self):
        """When sentence-transformers is missing, falls back to False."""
        fake_settings = MagicMock()
        fake_settings.timmy_skip_embeddings = False
        # Mapping a module to None in sys.modules makes its import raise
        # ImportError; one patch.dict installs both fakes at once.
        overrides = {
            "config": MagicMock(settings=fake_settings),
            "sentence_transformers": None,
        }
        with patch.dict("sys.modules", overrides):
            emb.EMBEDDING_MODEL = None
            assert emb._get_embedding_model() is False

    def teardown_method(self):
        """Leave the module-level cache clean for other test classes."""
        emb.EMBEDDING_MODEL = None
|
||||
Reference in New Issue
Block a user