# forked from Rockachopa/Timmy-time-dashboard
# Co-authored-by: Kimi Agent <kimi@timmy.local>
# Co-committed-by: Kimi Agent <kimi@timmy.local>
"""Unit tests for timmy.memory.embeddings — embedding, similarity, and keyword overlap."""
|
|
|
|
import math
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
import timmy.memory.embeddings as emb
|
|
from timmy.memory.embeddings import (
|
|
_keyword_overlap,
|
|
_simple_hash_embedding,
|
|
cosine_similarity,
|
|
embed_text,
|
|
)
|
|
|
|
# ── _simple_hash_embedding ──────────────────────────────────────────────────


class TestSimpleHashEmbedding:
    """Behavioural checks for the deterministic hash-based fallback embedding."""

    def test_returns_128_dim_vector(self):
        """The fallback embedding always has 128 dimensions."""
        assert len(_simple_hash_embedding("hello world")) == 128

    def test_normalized(self):
        """Non-empty embeddings are L2-normalised to unit length."""
        embedding = _simple_hash_embedding("some text for embedding")
        magnitude = math.sqrt(sum(component**2 for component in embedding))
        assert magnitude == pytest.approx(1.0, abs=1e-6)

    def test_deterministic(self):
        """Identical input text always hashes to an identical vector."""
        first = _simple_hash_embedding("same input")
        second = _simple_hash_embedding("same input")
        assert first == second

    def test_different_texts_differ(self):
        """Distinct inputs should not collide to the same vector."""
        left = _simple_hash_embedding("hello world")
        right = _simple_hash_embedding("goodbye moon")
        assert left != right

    def test_empty_string(self):
        """Empty input yields a 128-dim vector of all zeros."""
        embedding = _simple_hash_embedding("")
        assert len(embedding) == 128
        # All zeros normalised stays zero (mag fallback to 1.0)
        assert not any(embedding)

    def test_long_text_truncates_at_50_words(self):
        """Words beyond 50 should not change the result."""
        first_fifty = " ".join(f"word{i}" for i in range(50))
        padded = f"{first_fifty} extra1 extra2 extra3"
        assert _simple_hash_embedding(first_fifty) == _simple_hash_embedding(padded)
# ── cosine_similarity ────────────────────────────────────────────────────────


class TestCosineSimilarity:
    """cosine_similarity over plain Python float lists, including zero vectors."""

    def test_identical_vectors(self):
        """A vector compared with itself scores exactly 1."""
        vec = [1.0, 2.0, 3.0]
        assert cosine_similarity(vec, vec) == pytest.approx(1.0)

    def test_orthogonal_vectors(self):
        """Perpendicular vectors score 0."""
        assert cosine_similarity([1.0, 0.0], [0.0, 1.0]) == pytest.approx(0.0)

    def test_opposite_vectors(self):
        """Anti-parallel vectors score -1."""
        assert cosine_similarity([1.0, 0.0], [-1.0, 0.0]) == pytest.approx(-1.0)

    def test_zero_vector_returns_zero(self):
        """A zero vector on either side short-circuits to 0.0."""
        zero, nonzero = [0.0, 0.0], [1.0, 2.0]
        assert cosine_similarity(zero, nonzero) == 0.0
        assert cosine_similarity(nonzero, zero) == 0.0

    def test_both_zero_vectors(self):
        """Two zero vectors also score 0.0 rather than dividing by zero."""
        assert cosine_similarity([0.0], [0.0]) == 0.0
# ── _keyword_overlap ─────────────────────────────────────────────────────────


class TestKeywordOverlap:
    """_keyword_overlap: fraction of query words present in the target text."""

    def test_full_overlap(self):
        """Every query word present yields 1.0."""
        assert _keyword_overlap("hello world", "hello world") == pytest.approx(1.0)

    def test_partial_overlap(self):
        """One of two query words present yields 0.5."""
        assert _keyword_overlap("hello world", "hello moon") == pytest.approx(0.5)

    def test_no_overlap(self):
        """No shared words yields 0.0."""
        assert _keyword_overlap("hello", "goodbye") == pytest.approx(0.0)

    def test_empty_query(self):
        """An empty query scores 0.0 rather than dividing by zero."""
        assert _keyword_overlap("", "anything") == 0.0

    def test_case_insensitive(self):
        """Matching ignores letter case."""
        assert _keyword_overlap("Hello World", "hello world") == pytest.approx(1.0)
# ── embed_text ───────────────────────────────────────────────────────────────


class TestEmbedText:
    """embed_text dispatch: real model when loaded, hash fallback otherwise."""

    def setup_method(self):
        # Stash the module-level model so each test starts from a clean slate.
        self._saved_model = emb.EMBEDDING_MODEL
        emb.EMBEDDING_MODEL = None

    def teardown_method(self):
        # Restore whatever was cached before the test ran.
        emb.EMBEDDING_MODEL = self._saved_model

    def test_uses_fallback_when_model_disabled(self):
        """EMBEDDING_MODEL = False forces the hash fallback path."""
        emb.EMBEDDING_MODEL = False
        assert len(embed_text("test")) == 128  # hash fallback dimension

    def test_uses_model_when_available(self):
        """A loaded model's encode() output is returned for the input text."""
        import numpy as np

        fake_model = MagicMock()
        fake_model.encode.return_value = np.array([0.1, 0.2, 0.3])
        emb.EMBEDDING_MODEL = fake_model

        assert embed_text("test") == pytest.approx([0.1, 0.2, 0.3])
        fake_model.encode.assert_called_once_with("test")
# ── _get_embedding_model ─────────────────────────────────────────────────────


class TestGetEmbeddingModel:
    """_get_embedding_model lazy loading: skip setting, import failure, cache hit."""

    def setup_method(self):
        # Stash the module-level model so each test starts from a clean slate.
        self._saved_model = emb.EMBEDDING_MODEL
        emb.EMBEDDING_MODEL = None

    def teardown_method(self):
        # Restore whatever was cached before the test ran.
        emb.EMBEDDING_MODEL = self._saved_model

    def test_skip_embeddings_setting(self):
        """timmy_skip_embeddings=True disables the model (returns False)."""
        fake_settings = MagicMock(timmy_skip_embeddings=True)
        with patch.dict("sys.modules", {"config": MagicMock(settings=fake_settings)}):
            emb.EMBEDDING_MODEL = None
            assert emb._get_embedding_model() is False

    def test_fallback_when_transformers_missing(self):
        """An unimportable sentence_transformers package also yields False."""
        fake_settings = MagicMock(timmy_skip_embeddings=False)
        fake_modules = {
            "config": MagicMock(settings=fake_settings),
            "sentence_transformers": None,
        }
        with patch.dict("sys.modules", fake_modules):
            emb.EMBEDDING_MODEL = None
            assert emb._get_embedding_model() is False

    def test_returns_cached_model(self):
        """An already-loaded model is returned as-is without reloading."""
        sentinel = object()
        emb.EMBEDDING_MODEL = sentinel
        assert emb._get_embedding_model() is sentinel