Files
timmy-home/timmy-local/cache/agent_cache.py

657 lines
23 KiB
Python

#!/usr/bin/env python3
"""
Multi-Tier Caching Layer for Local Timmy
Issue #103 — Cache Everywhere
Provides:
- Tier 1: KV Cache (prompt prefix caching)
- Tier 2: Semantic Response Cache (full LLM responses)
- Tier 3: Tool Result Cache (stable tool outputs)
- Tier 4: Embedding Cache (RAG embeddings)
- Tier 5: Template Cache (pre-compiled prompts)
- Tier 6: HTTP Response Cache (API responses)
"""
import functools
import hashlib
import json
import pickle
import sqlite3
import threading
import time
from collections import OrderedDict
from contextlib import closing
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Optional, Any, Dict, List, Callable
@dataclass
class CacheStats:
    """Running hit/miss/eviction counters for one cache tier.

    ``hit_rate`` is refreshed every time a hit or a miss is recorded and
    holds ``hits / (hits + misses)``; it stays at its initial value until
    the first lookup is recorded.
    """

    hits: int = 0
    misses: int = 0
    evictions: int = 0
    hit_rate: float = 0.0

    def record_hit(self) -> None:
        """Count one successful lookup and refresh ``hit_rate``."""
        self.hits = self.hits + 1
        self._update_rate()

    def record_miss(self) -> None:
        """Count one failed lookup and refresh ``hit_rate``."""
        self.misses = self.misses + 1
        self._update_rate()

    def record_eviction(self) -> None:
        """Count one evicted entry; evictions do not affect ``hit_rate``."""
        self.evictions = self.evictions + 1

    def _update_rate(self) -> None:
        """Recompute ``hit_rate`` from the current hit and miss counters."""
        lookups = self.hits + self.misses
        if lookups <= 0:
            # Nothing recorded yet: leave the rate untouched (avoids 0/0).
            return
        self.hit_rate = self.hits / lookups
class LRUCache:
    """In-memory LRU cache for the hot path.

    Entries live in an ``OrderedDict`` whose iteration order is the recency
    order (least recently used first), so lookups, inserts and evictions do
    not need to scan a separate list.  All operations take ``self.lock``.

    Attributes:
        max_size: Maximum number of entries kept; ``<= 0`` disables storage.
        cache: The underlying mapping (``len(cache)`` is the entry count).
        lock: Re-entrant lock guarding ``cache``.
    """

    def __init__(self, max_size: int = 1000):
        self.max_size = max_size
        self.cache: "OrderedDict[str, Any]" = OrderedDict()
        self.lock = threading.RLock()

    @property
    def access_order(self) -> List[str]:
        """Return a snapshot of the keys, least recently used first."""
        with self.lock:
            return list(self.cache)

    def get(self, key: str) -> Optional[Any]:
        """Return the value stored under *key* and mark it most recently used.

        Returns ``None`` when the key is absent.
        """
        with self.lock:
            if key not in self.cache:
                return None
            self.cache.move_to_end(key)
            return self.cache[key]

    def put(self, key: str, value: Any):
        """Store *value* under *key*, evicting the oldest entry when full."""
        with self.lock:
            if self.max_size <= 0:
                # A zero-capacity cache stores nothing; evicting from an
                # empty cache would otherwise raise.
                return
            if key in self.cache:
                self.cache.move_to_end(key)
            else:
                # Loop (rather than evict once) so the cache also shrinks if
                # max_size was lowered after entries were added.
                while len(self.cache) >= self.max_size:
                    self.cache.popitem(last=False)
            self.cache[key] = value

    def invalidate(self, key: str):
        """Remove *key* if present; unknown keys are ignored."""
        with self.lock:
            self.cache.pop(key, None)

    def clear(self):
        """Drop every entry."""
        with self.lock:
            self.cache.clear()
class ResponseCache:
    """Tier 2: Semantic Response Cache — full LLM responses.

    Responses are keyed by a hash of the normalized prompt and persisted in a
    SQLite table, with a small in-memory LRU in front of it.  Each LRU entry
    is a ``(response, created_at, ttl)`` tuple so that memory hits obey the
    same expiry rule as disk hits.
    """

    def __init__(self, db_path: str = "~/.timmy/cache/responses.db"):
        """Open (creating if necessary) the response database.

        Args:
            db_path: Location of the SQLite file.  It is wrapped in ``Path``,
                ``~`` is expanded, and missing parent directories are created.
        """
        self.db_path = Path(db_path).expanduser()
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.stats = CacheStats()
        self.lru = LRUCache(max_size=100)
        self._init_db()

    def _init_db(self):
        """Create the ``responses`` table and its access-time index if absent."""
        # sqlite3's connection context manager only commits or rolls back;
        # closing() additionally releases the connection on exit.
        with closing(sqlite3.connect(self.db_path)) as conn, conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS responses (
                    prompt_hash TEXT PRIMARY KEY,
                    response TEXT NOT NULL,
                    created_at REAL NOT NULL,
                    ttl INTEGER NOT NULL,
                    access_count INTEGER DEFAULT 0,
                    last_accessed REAL
                )
            """)
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_accessed ON responses(last_accessed)
            """)

    def _hash_prompt(self, prompt: str) -> str:
        """Return the cache key for *prompt*.

        The prompt is lowercased and runs of whitespace are collapsed to a
        single space before hashing, so prompts differing only in case or
        spacing share a key.  The key is the first 32 hex characters of the
        SHA-256 digest.
        """
        normalized = " ".join(prompt.lower().split())
        return hashlib.sha256(normalized.encode()).hexdigest()[:32]

    def get(self, prompt: str, ttl: int = 3600) -> Optional[str]:
        """Return the cached response for *prompt*, or ``None``.

        Args:
            prompt: Prompt text; normalized and hashed to form the key.
            ttl: Maximum acceptable age in seconds.  The effective limit is
                the smaller of this value and the TTL stored with the entry.

        Returns:
            The cached response, or ``None`` when there is no entry or the
            entry is older than the effective TTL.  An entry found on disk
            that is too old is deleted and counted as an eviction.
        """
        prompt_hash = self._hash_prompt(prompt)
        now = time.time()

        # Memory layer.  "is not None" so that an empty-string response is
        # still a hit, and the age check keeps memory hits from outliving
        # their TTL.
        entry = self.lru.get(prompt_hash)
        if entry is not None:
            response, created_at, stored_ttl = entry
            if now - created_at < min(ttl, stored_ttl):
                self.stats.record_hit()
                return response
            # Too old: drop it here and let the disk path do the accounting.
            self.lru.invalidate(prompt_hash)

        # Disk layer.
        with closing(sqlite3.connect(self.db_path)) as conn, conn:
            row = conn.execute(
                "SELECT response, created_at, ttl FROM responses WHERE prompt_hash = ?",
                (prompt_hash,)
            ).fetchone()
            if row is not None:
                response, created_at, stored_ttl = row
                # Use minimum of requested and stored TTL
                if now - created_at < min(ttl, stored_ttl):
                    self.stats.record_hit()
                    conn.execute(
                        "UPDATE responses SET access_count = access_count + 1, "
                        "last_accessed = ? WHERE prompt_hash = ?",
                        (now, prompt_hash)
                    )
                    # Promote to memory together with its expiry metadata.
                    self.lru.put(prompt_hash, (response, created_at, stored_ttl))
                    return response
                # Expired
                conn.execute("DELETE FROM responses WHERE prompt_hash = ?", (prompt_hash,))
                self.stats.record_eviction()

        self.stats.record_miss()
        return None

    def put(self, prompt: str, response: str, ttl: int = 3600):
        """Cache *response* for *prompt* for *ttl* seconds.

        An existing row for the same prompt is replaced.
        """
        prompt_hash = self._hash_prompt(prompt)
        now = time.time()
        self.lru.put(prompt_hash, (response, now, ttl))
        with closing(sqlite3.connect(self.db_path)) as conn, conn:
            conn.execute(
                """INSERT OR REPLACE INTO responses
                   (prompt_hash, response, created_at, ttl, last_accessed)
                   VALUES (?, ?, ?, ?, ?)""",
                (prompt_hash, response, now, ttl, now)
            )

    def invalidate_pattern(self, pattern: str):
        """Delete every cached response whose text contains *pattern*.

        The match is a SQL ``LIKE '%pattern%'`` on the response column, so
        ``%`` and ``_`` inside *pattern* act as wildcards.
        """
        with closing(sqlite3.connect(self.db_path)) as conn, conn:
            conn.execute("DELETE FROM responses WHERE response LIKE ?", (f"%{pattern}%",))
        # The memory layer is keyed by hash and cannot be filtered with the
        # same LIKE semantics, so drop it entirely rather than risk serving
        # a response that was just invalidated.
        self.lru.clear()

    def get_stats(self) -> Dict[str, Any]:
        """Return entry counts and hit/miss statistics for this tier."""
        with closing(sqlite3.connect(self.db_path)) as conn, conn:
            count = conn.execute("SELECT COUNT(*) FROM responses").fetchone()[0]
            # SUM() yields NULL on an empty table, hence the "or 0".
            total_accesses = conn.execute(
                "SELECT SUM(access_count) FROM responses"
            ).fetchone()[0] or 0
        return {
            "tier": "response_cache",
            "memory_entries": len(self.lru.cache),
            "disk_entries": count,
            "hits": self.stats.hits,
            "misses": self.stats.misses,
            "hit_rate": f"{self.stats.hit_rate:.1%}",
            "total_accesses": total_accesses
        }
class ToolCache:
    """Tier 3: Tool Result Cache — stable tool outputs.

    Only tools listed in ``TOOL_TTL`` are cached.  Results are pickled and
    stored in SQLite, with an in-memory LRU in front.  Each LRU entry is a
    ``(tool_name, pickled_result, created_at, ttl)`` tuple so that memory hits
    expire like disk hits and can be invalidated per tool.
    """

    # TTL configuration per tool type (seconds)
    TOOL_TTL = {
        "system_info": 60,
        "disk_usage": 120,
        "git_status": 30,
        "git_log": 300,
        "health_check": 60,
        "gitea_list_issues": 120,
        "file_read": 30,
        "process_list": 30,
        "service_status": 60,
    }

    # Tools that invalidate cache on write operations
    INVALIDATORS = {
        "git_commit": ["git_status", "git_log"],
        "git_pull": ["git_status", "git_log"],
        "file_write": ["file_read"],
        "gitea_create_issue": ["gitea_list_issues"],
        "gitea_comment": ["gitea_list_issues"],
    }

    def __init__(self, db_path: str = "~/.timmy/cache/tool_cache.db"):
        """Open (creating if necessary) the tool-result database.

        Args:
            db_path: Location of the SQLite file.  It is wrapped in ``Path``,
                ``~`` is expanded, and missing parent directories are created.
        """
        self.db_path = Path(db_path).expanduser()
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.stats = CacheStats()
        self.lru = LRUCache(max_size=500)
        self._init_db()

    def _init_db(self):
        """Create the ``tool_results`` table and its tool-name index if absent."""
        # sqlite3's connection context manager only commits or rolls back;
        # closing() additionally releases the connection on exit.
        with closing(sqlite3.connect(self.db_path)) as conn, conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS tool_results (
                    tool_hash TEXT PRIMARY KEY,
                    tool_name TEXT NOT NULL,
                    params_hash TEXT NOT NULL,
                    result TEXT NOT NULL,
                    created_at REAL NOT NULL,
                    ttl INTEGER NOT NULL
                )
            """)
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_tool_name ON tool_results(tool_name)
            """)

    def _hash_call(self, tool_name: str, params: Dict) -> str:
        """Return the cache key for a tool invocation.

        *params* is serialized with ``json.dumps(sort_keys=True)``, so it must
        be JSON-serializable; ``json.dumps`` raises ``TypeError`` otherwise.
        """
        param_str = json.dumps(params, sort_keys=True)
        combined = f"{tool_name}:{param_str}"
        return hashlib.sha256(combined.encode()).hexdigest()[:32]

    def get(self, tool_name: str, params: Dict) -> Optional[Any]:
        """Return the cached result for a tool call, or ``None``.

        ``None`` is returned when the tool is not listed in ``TOOL_TTL``, when
        nothing is cached, or when the cached entry is older than its TTL.  A
        cached result that is itself ``None`` is therefore indistinguishable
        from a miss for the caller.
        """
        if tool_name not in self.TOOL_TTL:
            return None  # Not cacheable
        tool_hash = self._hash_call(tool_name, params)
        now = time.time()

        # Memory layer: check the age too, otherwise a short-TTL result
        # (e.g. git_status) would be served until it is evicted.
        entry = self.lru.get(tool_hash)
        if entry is not None:
            _name, pickled, created_at, ttl = entry
            if now - created_at < ttl:
                self.stats.record_hit()
                return pickle.loads(pickled)
            self.lru.invalidate(tool_hash)

        # Disk layer.
        with closing(sqlite3.connect(self.db_path)) as conn, conn:
            row = conn.execute(
                "SELECT result, created_at, ttl FROM tool_results WHERE tool_hash = ?",
                (tool_hash,)
            ).fetchone()
            if row is not None:
                pickled, created_at, ttl = row
                if now - created_at < ttl:
                    self.stats.record_hit()
                    self.lru.put(tool_hash, (tool_name, pickled, created_at, ttl))
                    # NOTE(security): pickle.loads executes code embedded in
                    # the payload; this is only safe while the database file
                    # is writable by trusted processes alone.
                    return pickle.loads(pickled)
                conn.execute("DELETE FROM tool_results WHERE tool_hash = ?", (tool_hash,))
                self.stats.record_eviction()

        self.stats.record_miss()
        return None

    def put(self, tool_name: str, params: Dict, result: Any):
        """Cache *result* for a tool call using the tool's configured TTL.

        Tools not listed in ``TOOL_TTL`` are silently skipped.  *result* must
        be picklable.
        """
        if tool_name not in self.TOOL_TTL:
            return  # Not cacheable
        ttl = self.TOOL_TTL[tool_name]
        tool_hash = self._hash_call(tool_name, params)
        params_hash = hashlib.sha256(
            json.dumps(params, sort_keys=True).encode()
        ).hexdigest()[:16]
        pickled = pickle.dumps(result)
        now = time.time()

        self.lru.put(tool_hash, (tool_name, pickled, now, ttl))
        with closing(sqlite3.connect(self.db_path)) as conn, conn:
            conn.execute(
                """INSERT OR REPLACE INTO tool_results
                   (tool_hash, tool_name, params_hash, result, created_at, ttl)
                   VALUES (?, ?, ?, ?, ?, ?)""",
                (tool_hash, tool_name, params_hash, pickled, now, ttl)
            )

    def invalidate(self, tool_name: str):
        """Invalidate all cached results for a tool, on disk and in memory."""
        with closing(sqlite3.connect(self.db_path)) as conn, conn:
            conn.execute("DELETE FROM tool_results WHERE tool_name = ?", (tool_name,))
        # LRU entries carry their tool name, so only the matching ones need
        # to go.  Collect the keys first: the mapping must not change while
        # it is being iterated.
        with self.lru.lock:
            stale_keys = [
                key for key, entry in self.lru.cache.items()
                if entry[0] == tool_name
            ]
            for key in stale_keys:
                self.lru.invalidate(key)

    def handle_invalidation(self, tool_name: str):
        """Invalidate the tools that *tool_name* makes stale, per ``INVALIDATORS``."""
        for dependent in self.INVALIDATORS.get(tool_name, ()):
            self.invalidate(dependent)

    def get_stats(self) -> Dict[str, Any]:
        """Return entry counts (total and per tool) and hit/miss statistics."""
        with closing(sqlite3.connect(self.db_path)) as conn, conn:
            count = conn.execute("SELECT COUNT(*) FROM tool_results").fetchone()[0]
            by_tool = conn.execute(
                "SELECT tool_name, COUNT(*) FROM tool_results GROUP BY tool_name"
            ).fetchall()
        return {
            "tier": "tool_cache",
            "memory_entries": len(self.lru.cache),
            "disk_entries": count,
            "hits": self.stats.hits,
            "misses": self.stats.misses,
            "hit_rate": f"{self.stats.hit_rate:.1%}",
            "by_tool": dict(by_tool)
        }
class EmbeddingCache:
    """Tier 4: Embedding Cache — for RAG pipeline (#93).

    One embedding is stored per file path.  An entry is valid only while the
    file's modification time and the embedding model name both match the
    values it was stored with.
    """

    def __init__(self, db_path: str = "~/.timmy/cache/embeddings.db"):
        """Open (creating if necessary) the embedding database.

        Args:
            db_path: Location of the SQLite file.  It is wrapped in ``Path``,
                ``~`` is expanded, and missing parent directories are created.
        """
        self.db_path = Path(db_path).expanduser()
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.stats = CacheStats()
        self._init_db()

    def _init_db(self):
        """Create the ``embeddings`` table if absent."""
        # sqlite3's connection context manager only commits or rolls back;
        # closing() additionally releases the connection on exit.
        with closing(sqlite3.connect(self.db_path)) as conn, conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS embeddings (
                    file_path TEXT PRIMARY KEY,
                    mtime REAL NOT NULL,
                    embedding BLOB NOT NULL,
                    model_name TEXT NOT NULL,
                    created_at REAL NOT NULL
                )
            """)

    def get(self, file_path: str, mtime: float, model_name: str) -> Optional[List[float]]:
        """Return the stored embedding if the file and model are unchanged.

        Args:
            file_path: Key the embedding was stored under.
            mtime: Current modification time of the file; must equal the
                stored value exactly.
            model_name: Embedding model in use; must equal the stored value.

        Returns:
            The unpickled embedding, or ``None`` when there is no row or the
            stored ``mtime`` / ``model_name`` differ.
        """
        with closing(sqlite3.connect(self.db_path)) as conn, conn:
            row = conn.execute(
                "SELECT embedding, mtime, model_name FROM embeddings WHERE file_path = ?",
                (file_path,)
            ).fetchone()

        if row is not None:
            embedding_blob, stored_mtime, stored_model = row
            if stored_mtime == mtime and stored_model == model_name:
                self.stats.record_hit()
                # NOTE(security): pickle.loads executes code embedded in the
                # payload; this is only safe while the database file is
                # writable by trusted processes alone.
                return pickle.loads(embedding_blob)

        self.stats.record_miss()
        return None

    def put(self, file_path: str, mtime: float, embedding: List[float], model_name: str):
        """Store *embedding* for *file_path*, replacing any existing row."""
        with closing(sqlite3.connect(self.db_path)) as conn, conn:
            conn.execute(
                """INSERT OR REPLACE INTO embeddings
                   (file_path, mtime, embedding, model_name, created_at)
                   VALUES (?, ?, ?, ?, ?)""",
                (file_path, mtime, pickle.dumps(embedding), model_name, time.time())
            )

    def get_stats(self) -> Dict[str, Any]:
        """Return entry counts (total and per model) and hit/miss statistics."""
        with closing(sqlite3.connect(self.db_path)) as conn, conn:
            count = conn.execute("SELECT COUNT(*) FROM embeddings").fetchone()[0]
            models = conn.execute(
                "SELECT model_name, COUNT(*) FROM embeddings GROUP BY model_name"
            ).fetchall()
        return {
            "tier": "embedding_cache",
            "entries": count,
            "hits": self.stats.hits,
            "misses": self.stats.misses,
            "hit_rate": f"{self.stats.hit_rate:.1%}",
            "by_model": dict(models)
        }
class TemplateCache:
    """Tier 5: Template Cache — pre-compiled prompts.

    Purely in-memory: template text is kept in ``templates`` and optional
    tokenizer output in ``tokenized``, both keyed by template name.
    """

    def __init__(self):
        self.templates: Dict[str, str] = {}
        self.tokenized: Dict[str, Any] = {}  # For tokenizer outputs
        self.stats = CacheStats()

    def load_template(self, name: str, path: str) -> str:
        """Return the template called *name*, reading *path* on first use.

        The cache is keyed by *name* only: once a name is cached, later calls
        return the cached text even if a different *path* is given.

        Raises:
            OSError: If the file cannot be opened on a cache miss.
        """
        if name in self.templates:
            self.stats.record_hit()
        else:
            # Decode explicitly as UTF-8 so the result does not depend on the
            # platform's default locale encoding.
            with open(path, 'r', encoding="utf-8") as f:
                self.templates[name] = f.read()
            self.stats.record_miss()
        return self.templates[name]

    def get(self, name: str) -> Optional[str]:
        """Return the cached template *name*, or ``None`` if it is not loaded."""
        if name in self.templates:
            self.stats.record_hit()
            return self.templates[name]
        self.stats.record_miss()
        return None

    def cache_tokenized(self, name: str, tokens: Any):
        """Store the tokenized form of template *name*."""
        self.tokenized[name] = tokens

    def get_tokenized(self, name: str) -> Optional[Any]:
        """Return the tokenized form of *name*, or ``None`` if not cached.

        Tokenized lookups are not counted in the hit/miss statistics.
        """
        return self.tokenized.get(name)

    def get_stats(self) -> Dict[str, Any]:
        """Return the number of cached items and hit/miss statistics."""
        return {
            "tier": "template_cache",
            "templates_cached": len(self.templates),
            "tokenized_cached": len(self.tokenized),
            "hits": self.stats.hits,
            "misses": self.stats.misses,
            "hit_rate": f"{self.stats.hit_rate:.1%}"
        }
class HTTPCache:
    """Tier 6: HTTP Response Cache — for API calls.

    Responses are keyed by a hash of the URL and persisted in SQLite, with an
    in-memory LRU in front.  Each LRU entry is a
    ``(result_dict, created_at, ttl)`` tuple so that memory hits expire like
    disk hits.
    """

    def __init__(self, db_path: str = "~/.timmy/cache/http_cache.db"):
        """Open (creating if necessary) the HTTP response database.

        Args:
            db_path: Location of the SQLite file.  It is wrapped in ``Path``,
                ``~`` is expanded, and missing parent directories are created.
        """
        self.db_path = Path(db_path).expanduser()
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.stats = CacheStats()
        self.lru = LRUCache(max_size=200)
        self._init_db()

    def _init_db(self):
        """Create the ``http_responses`` table if absent."""
        # sqlite3's connection context manager only commits or rolls back;
        # closing() additionally releases the connection on exit.
        with closing(sqlite3.connect(self.db_path)) as conn, conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS http_responses (
                    url_hash TEXT PRIMARY KEY,
                    url TEXT NOT NULL,
                    response TEXT NOT NULL,
                    etag TEXT,
                    last_modified TEXT,
                    created_at REAL NOT NULL,
                    ttl INTEGER NOT NULL
                )
            """)

    def _hash_url(self, url: str) -> str:
        """Return the cache key: the first 32 hex characters of SHA-256(url)."""
        return hashlib.sha256(url.encode()).hexdigest()[:32]

    def get(self, url: str, ttl: int = 300) -> Optional[Dict]:
        """Return the cached response for *url*, or ``None``.

        Args:
            url: Exact URL the response was stored under.
            ttl: Maximum acceptable age in seconds.  The effective limit is
                the smaller of this value and the TTL stored with the entry.

        Returns:
            A dict with keys ``"response"``, ``"etag"`` and
            ``"last_modified"``, or ``None`` when there is no entry or it is
            older than the effective TTL.  An entry found on disk that is too
            old is deleted and counted as an eviction.
        """
        url_hash = self._hash_url(url)
        now = time.time()

        # Memory layer, with the same age check as the disk layer.
        entry = self.lru.get(url_hash)
        if entry is not None:
            result, created_at, stored_ttl = entry
            if now - created_at < min(ttl, stored_ttl):
                self.stats.record_hit()
                # Hand out a copy so callers cannot mutate the cached dict.
                return dict(result)
            self.lru.invalidate(url_hash)

        # Disk layer.
        with closing(sqlite3.connect(self.db_path)) as conn, conn:
            row = conn.execute(
                "SELECT response, etag, last_modified, created_at, ttl "
                "FROM http_responses WHERE url_hash = ?",
                (url_hash,)
            ).fetchone()
            if row is not None:
                response, etag, last_modified, created_at, stored_ttl = row
                if now - created_at < min(ttl, stored_ttl):
                    self.stats.record_hit()
                    result = {
                        "response": response,
                        "etag": etag,
                        "last_modified": last_modified
                    }
                    self.lru.put(url_hash, (dict(result), created_at, stored_ttl))
                    return result
                conn.execute("DELETE FROM http_responses WHERE url_hash = ?", (url_hash,))
                self.stats.record_eviction()

        self.stats.record_miss()
        return None

    def put(self, url: str, response: str, etag: Optional[str] = None,
            last_modified: Optional[str] = None, ttl: int = 300):
        """Cache an HTTP response body and its validators for *ttl* seconds.

        An existing row for the same URL is replaced.
        """
        url_hash = self._hash_url(url)
        now = time.time()
        result = {
            "response": response,
            "etag": etag,
            "last_modified": last_modified
        }
        self.lru.put(url_hash, (result, now, ttl))
        with closing(sqlite3.connect(self.db_path)) as conn, conn:
            conn.execute(
                """INSERT OR REPLACE INTO http_responses
                   (url_hash, url, response, etag, last_modified, created_at, ttl)
                   VALUES (?, ?, ?, ?, ?, ?, ?)""",
                (url_hash, url, response, etag, last_modified, now, ttl)
            )

    def get_stats(self) -> Dict[str, Any]:
        """Return entry counts and hit/miss statistics for this tier."""
        with closing(sqlite3.connect(self.db_path)) as conn, conn:
            count = conn.execute("SELECT COUNT(*) FROM http_responses").fetchone()[0]
        return {
            "tier": "http_cache",
            "memory_entries": len(self.lru.cache),
            "disk_entries": count,
            "hits": self.stats.hits,
            "misses": self.stats.misses,
            "hit_rate": f"{self.stats.hit_rate:.1%}"
        }
class CacheManager:
    """Central manager for all cache tiers.

    Attributes:
        base_path: Directory holding the tier databases.
        response: Tier 2 ``ResponseCache``.
        tool: Tier 3 ``ToolCache``.
        embedding: Tier 4 ``EmbeddingCache``.
        template: Tier 5 ``TemplateCache`` (memory only).
        http: Tier 6 ``HTTPCache``.
    """

    def __init__(self, base_path: str = "~/.timmy/cache"):
        """Create *base_path* if needed and instantiate every tier inside it."""
        self.base_path = Path(base_path).expanduser()
        self.base_path.mkdir(parents=True, exist_ok=True)
        # Initialize all tiers
        self.response = ResponseCache(self.base_path / "responses.db")
        self.tool = ToolCache(self.base_path / "tool_cache.db")
        self.embedding = EmbeddingCache(self.base_path / "embeddings.db")
        self.template = TemplateCache()
        self.http = HTTPCache(self.base_path / "http_cache.db")
        # KV cache handled by llama-server (external)

    def get_all_stats(self) -> Dict[str, Dict]:
        """Return the ``get_stats()`` dict of every tier, keyed by tier name."""
        return {
            "response_cache": self.response.get_stats(),
            "tool_cache": self.tool.get_stats(),
            "embedding_cache": self.embedding.get_stats(),
            "template_cache": self.template.get_stats(),
            "http_cache": self.http.get_stats(),
        }

    def clear_all(self):
        """Empty every in-memory layer and every table of every ``*.db`` file.

        All ``*.db`` files directly under ``base_path`` are emptied, not only
        the ones this manager created.  Hit/miss statistics are not reset.
        """
        self.response.lru.clear()
        self.tool.lru.clear()
        self.http.lru.clear()
        self.template.templates.clear()
        self.template.tokenized.clear()

        # Clear databases
        for db_file in self.base_path.glob("*.db"):
            # closing() releases the connection; sqlite3's own context
            # manager only commits or rolls back.
            with closing(sqlite3.connect(db_file)) as conn, conn:
                tables = conn.execute(
                    "SELECT name FROM sqlite_master WHERE type='table'"
                ).fetchall()
                for (table,) in tables:
                    # Identifiers cannot be bound as parameters, so quote the
                    # name (doubling embedded quotes) before interpolating it.
                    quoted = '"' + table.replace('"', '""') + '"'
                    conn.execute(f"DELETE FROM {quoted}")

    def cached_tool(self, ttl: Optional[int] = None):
        """Decorator for caching tool results in the tool tier.

        The decorated function's ``__name__`` is used as the tool name, so
        only functions whose name is a key of ``ToolCache.TOOL_TTL`` are
        actually cached; other functions are simply called.  Arguments must
        be JSON-serializable and results picklable.  A result of ``None`` is
        treated as "not cached" and the function is called again.

        Args:
            ttl: Currently unused; the TTL comes from ``ToolCache.TOOL_TTL``.
        """
        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                tool_name = func.__name__
                params = {"args": args, "kwargs": kwargs}

                # Try cache
                cached = self.tool.get(tool_name, params)
                if cached is not None:
                    return cached

                # Execute and cache
                result = func(*args, **kwargs)
                self.tool.put(tool_name, params, result)
                return result
            return wrapper
        return decorator
# Singleton instance, created when this module is imported. With the default
# base_path this creates ~/.timmy/cache and the tier databases if missing.
cache_manager = CacheManager()
if __name__ == "__main__":
    # Manual smoke test: exercises the response and tool tiers of the
    # module-level cache_manager and prints the statistics of every tier.
    print("Testing Timmy Cache Layer...")
    print()

    # Round-trip one prompt through the response tier.
    response_tier = cache_manager.response
    print("1. Response Cache:")
    response_tier.put("What is 2+2?", "4", ttl=60)
    cached = response_tier.get("What is 2+2?")
    print(f" Cached: {cached}")
    print(f" Stats: {response_tier.get_stats()}")
    print()

    # Round-trip one result through the tool tier.
    tool_tier = cache_manager.tool
    print("2. Tool Cache:")
    tool_tier.put("system_info", {}, {"cpu": "ARM64", "ram": "8GB"})
    cached = tool_tier.get("system_info", {})
    print(f" Cached: {cached}")
    print(f" Stats: {tool_tier.get_stats()}")
    print()

    # Dump the statistics of every tier.
    print("3. All Cache Stats:")
    for tier, tier_stats in cache_manager.get_all_stats().items():
        print(f" {tier}: {tier_stats}")
    print()

    print("✅ Cache layer operational")