139 lines
5.4 KiB
Python
139 lines
5.4 KiB
Python
"""Embedding functions for Timmy's memory system.
|
|
|
|
Provides text-to-vector embedding using sentence-transformers (preferred)
|
|
with a deterministic hash-based fallback when the ML library is unavailable.
|
|
|
|
Also includes vector similarity utilities (cosine similarity, keyword overlap).
|
|
"""
|
|
|
|
import hashlib
import logging
import math
import json

import httpx  # Import httpx for Ollama API calls

from config import settings
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Embedding model - small, fast, local
|
|
EMBEDDING_MODEL = None
|
|
EMBEDDING_DIM = 384 # MiniLM dimension, will be overridden if Ollama model has different dim
|
|
|
|
class OllamaEmbedder:
|
|
"""Mimics SentenceTransformer interface for Ollama."""
|
|
def __init__(self, model_name: str, ollama_url: str):
|
|
self.model_name = model_name
|
|
self.ollama_url = ollama_url
|
|
self.dimension = 0 # Will be updated after first call
|
|
|
|
def encode(self, sentences: str | list[str], convert_to_numpy: bool = False, normalize_embeddings: bool = True) -> list[list[float]] | list[float]:
|
|
"""Generate embeddings using Ollama."""
|
|
if isinstance(sentences, str):
|
|
sentences = [sentences]
|
|
|
|
all_embeddings = []
|
|
for sentence in sentences:
|
|
try:
|
|
response = httpx.post(
|
|
f"{self.ollama_url}/api/embeddings",
|
|
json={"model": self.model_name, "prompt": sentence},
|
|
timeout=settings.mcp_bridge_timeout,
|
|
)
|
|
response.raise_for_status()
|
|
embedding = response.json()["embedding"]
|
|
if not self.dimension:
|
|
self.dimension = len(embedding) # Set dimension on first successful call
|
|
global EMBEDDING_DIM
|
|
EMBEDDING_DIM = self.dimension # Update global EMBEDDING_DIM
|
|
all_embeddings.append(embedding)
|
|
except httpx.RequestError as exc:
|
|
logger.error("Ollama embeddings request failed: %s", exc)
|
|
# Fallback to simple hash embedding on Ollama error
|
|
return _simple_hash_embedding(sentence)
|
|
except json.JSONDecodeError as exc:
|
|
logger.error("Failed to decode Ollama embeddings response: %s", exc)
|
|
return _simple_hash_embedding(sentence)
|
|
|
|
if len(all_embeddings) == 1 and isinstance(sentences, str):
|
|
return all_embeddings[0]
|
|
return all_embeddings
|
|
|
|
def _get_embedding_model():
|
|
"""Lazy-load embedding model, preferring Ollama if configured."""
|
|
global EMBEDDING_MODEL
|
|
global EMBEDDING_DIM
|
|
if EMBEDDING_MODEL is None:
|
|
if settings.timmy_skip_embeddings:
|
|
EMBEDDING_MODEL = False
|
|
return EMBEDDING_MODEL
|
|
|
|
if settings.timmy_embedding_backend == "ollama":
|
|
logger.info("MemorySystem: Using Ollama for embeddings with model %s", settings.ollama_embedding_model)
|
|
EMBEDDING_MODEL = OllamaEmbedder(settings.ollama_embedding_model, settings.normalized_ollama_url)
|
|
# We don't know the dimension until after the first call, so keep it default for now.
|
|
# It will be updated dynamically in OllamaEmbedder.encode
|
|
return EMBEDDING_MODEL
|
|
else:
|
|
try:
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
|
|
EMBEDDING_DIM = 384 # Reset to MiniLM dimension
|
|
logger.info("MemorySystem: Loaded local embedding model (all-MiniLM-L6-v2)")
|
|
except ImportError:
|
|
logger.warning("MemorySystem: sentence-transformers not installed, using fallback")
|
|
EMBEDDING_MODEL = False # Use fallback
|
|
return EMBEDDING_MODEL
|
|
|
|
|
|
def _simple_hash_embedding(text: str) -> list[float]:
|
|
"""Fallback: Simple hash-based embedding when transformers unavailable."""
|
|
words = text.lower().split()
|
|
vec = [0.0] * 128
|
|
for i, word in enumerate(words[:50]): # First 50 words
|
|
h = hashlib.md5(word.encode()).hexdigest()
|
|
for j in range(8):
|
|
idx = (i * 8 + j) % 128
|
|
vec[idx] += int(h[j * 2 : j * 2 + 2], 16) / 255.0
|
|
# Normalize
|
|
mag = math.sqrt(sum(x * x for x in vec)) or 1.0
|
|
return [x / mag for x in vec]
|
|
|
|
|
|
def embed_text(text: str) -> list[float]:
|
|
"""Generate embedding for text."""
|
|
model = _get_embedding_model()
|
|
if model and model is not False:
|
|
embedding = model.encode(text)
|
|
# Ensure it's a list of floats, not numpy array
|
|
if hasattr(embedding, 'tolist'):
|
|
return embedding.tolist()
|
|
return embedding
|
|
return _simple_hash_embedding(text)
|
|
|
|
|
|
|
|
def cosine_similarity(a: list[float], b: list[float]) -> float:
|
|
"""Calculate cosine similarity between two vectors."""
|
|
dot = sum(x * y for x, y in zip(a, b, strict=False))
|
|
mag_a = math.sqrt(sum(x * x for x in a))
|
|
mag_b = math.sqrt(sum(x * x for x in b))
|
|
if mag_a == 0 or mag_b == 0:
|
|
return 0.0
|
|
return dot / (mag_a * mag_b)
|
|
|
|
|
|
# Alias for backward compatibility
|
|
_cosine_similarity = cosine_similarity
|
|
|
|
|
|
def _keyword_overlap(query: str, content: str) -> float:
|
|
"""Simple keyword overlap score as fallback."""
|
|
query_words = set(query.lower().split())
|
|
content_words = set(content.lower().split())
|
|
if not query_words:
|
|
return 0.0
|
|
overlap = len(query_words & content_words)
|
|
return overlap / len(query_words)
|