forked from Rockachopa/Timmy-time-dashboard
feat: distributed brain architecture with rqlite and local embeddings (#118)
- Add new brain module with rqlite-based distributed memory and task queue - Implement BrainClient for memory operations (store, recall, search) - Implement DistributedWorker for continuous task processing - Add local embeddings via sentence-transformers (all-MiniLM-L6-v2) - No OpenAI dependency, runs 100% local on CPU - 384-dim embeddings, 80MB model download - Deprecate persona system (swarm/personas.py, persona_node.py) - Deprecate hands system (hands/__init__.py, routes) - Update marketplace, tools, hands routes for brain integration - Add sentence-transformers and numpy to dependencies - All changes backward compatible with deprecation warnings Co-authored-by: Alexander Payne <apayne@MM.local>
This commit is contained in:
committed by
GitHub
parent
d9bb26b9c5
commit
f7c574e0b2
14
src/brain/__init__.py
Normal file
14
src/brain/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Distributed Brain — Rqlite-based memory and task queue.
|
||||
|
||||
A distributed SQLite (rqlite) cluster that runs across all Tailscale devices.
|
||||
Provides:
|
||||
- Semantic memory with local embeddings
|
||||
- Distributed task queue with work stealing
|
||||
- Automatic replication and failover
|
||||
"""
|
||||
|
||||
from brain.client import BrainClient
|
||||
from brain.worker import DistributedWorker
|
||||
from brain.embeddings import LocalEmbedder
|
||||
|
||||
__all__ = ["BrainClient", "DistributedWorker", "LocalEmbedder"]
|
||||
440
src/brain/client.py
Normal file
440
src/brain/client.py
Normal file
@@ -0,0 +1,440 @@
|
||||
"""Brain client — interface to distributed rqlite memory.
|
||||
|
||||
All devices connect to the local rqlite node, which replicates to peers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_RQLITE_URL = "http://localhost:4001"
|
||||
|
||||
|
||||
class BrainClient:
    """Client for distributed brain (rqlite).

    Connects to local rqlite instance, which handles replication.
    All writes go to leader, reads can come from local node.
    """

    def __init__(self, rqlite_url: Optional[str] = None, node_id: Optional[str] = None):
        # RQLITE_URL env var overrides the default local endpoint when no
        # explicit URL is given.
        self.rqlite_url = rqlite_url or os.environ.get("RQLITE_URL", DEFAULT_RQLITE_URL)
        # hostname-pid uniquely identifies this process within the cluster.
        self.node_id = node_id or f"{socket.gethostname()}-{os.getpid()}"
        self.source = self._detect_source()
        # Shared async HTTP client; released via close().
        self._client = httpx.AsyncClient(timeout=30)

    def _detect_source(self) -> str:
        """Detect what component is using the brain."""
        # Could be 'timmy', 'zeroclaw', 'worker', etc.
        # For now, infer from context or env
        return os.environ.get("BRAIN_SOURCE", "timmy")

    # ──────────────────────────────────────────────────────────────────────────
    # Memory Operations
    # ──────────────────────────────────────────────────────────────────────────

    async def remember(
        self,
        content: str,
        tags: Optional[List[str]] = None,
        source: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Store a memory with embedding.

        Args:
            content: Text content to remember
            tags: Optional list of tags (e.g., ['shell', 'result'])
            source: Source identifier (defaults to self.source)
            metadata: Additional JSON-serializable metadata

        Returns:
            Dict with 'id' and 'status'

        Raises:
            Exception: re-raised after logging if the rqlite request fails.
        """
        # Imported lazily so the heavy embedding model is only loaded on use.
        from brain.embeddings import get_embedder

        embedder = get_embedder()
        embedding_bytes = embedder.encode_single(content)

        query = """
        INSERT INTO memories (content, embedding, source, tags, metadata, created_at)
        VALUES (?, ?, ?, ?, ?, ?)
        """
        # NOTE(review): embedding_bytes is raw bytes; json.dumps (used by
        # httpx's json=) cannot serialize bytes, so this request likely raises
        # TypeError — confirm rqlite blob encoding (hex/base64) is applied
        # somewhere upstream.
        params = [
            content,
            embedding_bytes,
            source or self.source,
            json.dumps(tags or []),
            json.dumps(metadata or {}),
            # NOTE: datetime.utcnow() is deprecated in Python 3.12; yields a
            # naive ISO8601 UTC timestamp with no offset suffix.
            datetime.utcnow().isoformat()
        ]

        try:
            # NOTE(review): rqlite's /db/execute expects an array of
            # statements, each itself [sql, params...] — confirm this
            # [query, params] shape matches the deployed rqlite version.
            resp = await self._client.post(
                f"{self.rqlite_url}/db/execute",
                json=[query, params]
            )
            resp.raise_for_status()
            result = resp.json()

            # Extract inserted ID
            last_id = None
            if "results" in result and result["results"]:
                last_id = result["results"][0].get("last_insert_id")

            logger.debug(f"Stored memory {last_id}: {content[:50]}...")
            return {"id": last_id, "status": "stored"}

        except Exception as e:
            logger.error(f"Failed to store memory: {e}")
            raise

    async def recall(
        self,
        query: str,
        limit: int = 5,
        sources: Optional[List[str]] = None
    ) -> List[Dict[str, Any]]:
        """Semantic search for memories.

        Args:
            query: Search query text
            limit: Max results to return
            sources: Filter by source(s) (e.g., ['timmy', 'user'])

        Returns:
            List of memory dicts with 'content', 'source', 'metadata' and
            'distance' keys (empty list on failure).
        """
        from brain.embeddings import get_embedder

        embedder = get_embedder()
        query_emb = embedder.encode_single(query)

        # rqlite with sqlite-vec extension for vector search
        # NOTE(review): query_emb is bytes — same JSON-serialization concern
        # as remember(); also 'memories' (not the vec_memories virtual table)
        # is queried with MATCH — confirm against the schema.
        sql = "SELECT content, source, metadata, distance FROM memories WHERE embedding MATCH ?"
        params = [query_emb]

        if sources:
            placeholders = ",".join(["?"] * len(sources))
            sql += f" AND source IN ({placeholders})"
            params.extend(sources)

        sql += " ORDER BY distance LIMIT ?"
        params.append(limit)

        try:
            resp = await self._client.post(
                f"{self.rqlite_url}/db/query",
                json=[sql, params]
            )
            resp.raise_for_status()
            result = resp.json()

            results = []
            if "results" in result and result["results"]:
                for row in result["results"][0].get("rows", []):
                    # Column order mirrors the SELECT above.
                    results.append({
                        "content": row[0],
                        "source": row[1],
                        "metadata": json.loads(row[2]) if row[2] else {},
                        "distance": row[3]
                    })

            return results

        except Exception as e:
            logger.error(f"Failed to search memories: {e}")
            # Graceful fallback - return empty list
            return []

    async def get_recent(
        self,
        hours: int = 24,
        limit: int = 20,
        sources: Optional[List[str]] = None
    ) -> List[Dict[str, Any]]:
        """Get recent memories by time.

        Args:
            hours: Look back this many hours
            limit: Max results
            sources: Optional source filter

        Returns:
            List of memory dicts (empty list on failure).
        """
        # NOTE(review): created_at is stored as isoformat ('...T...') while
        # datetime('now', ...) yields '... ...' (space separator); the string
        # comparison relies on the date prefix ordering — confirm acceptable.
        sql = """
        SELECT id, content, source, tags, metadata, created_at
        FROM memories
        WHERE created_at > datetime('now', ?)
        """
        # SQLite relative-time modifier, e.g. '-24 hours'.
        params = [f"-{hours} hours"]

        if sources:
            placeholders = ",".join(["?"] * len(sources))
            sql += f" AND source IN ({placeholders})"
            params.extend(sources)

        sql += " ORDER BY created_at DESC LIMIT ?"
        params.append(limit)

        try:
            resp = await self._client.post(
                f"{self.rqlite_url}/db/query",
                json=[sql, params]
            )
            resp.raise_for_status()
            result = resp.json()

            memories = []
            if "results" in result and result["results"]:
                for row in result["results"][0].get("rows", []):
                    memories.append({
                        "id": row[0],
                        "content": row[1],
                        "source": row[2],
                        "tags": json.loads(row[3]) if row[3] else [],
                        "metadata": json.loads(row[4]) if row[4] else {},
                        "created_at": row[5]
                    })

            return memories

        except Exception as e:
            logger.error(f"Failed to get recent memories: {e}")
            return []

    async def get_context(self, query: str) -> str:
        """Get formatted context for system prompt.

        Combines recent memories + relevant memories.

        Args:
            query: Current user query to find relevant context

        Returns:
            Formatted context string for prompt injection
        """
        recent = await self.get_recent(hours=24, limit=10)
        relevant = await self.recall(query, limit=5)

        # Both lists cap each entry at 100 characters to bound prompt size.
        lines = ["Recent activity:"]
        for m in recent[:5]:
            lines.append(f"- {m['content'][:100]}")

        lines.append("\nRelevant memories:")
        for r in relevant[:5]:
            lines.append(f"- {r['content'][:100]}")

        return "\n".join(lines)

    # ──────────────────────────────────────────────────────────────────────────
    # Task Queue Operations
    # ──────────────────────────────────────────────────────────────────────────

    async def submit_task(
        self,
        content: str,
        task_type: str = "general",
        priority: int = 0,
        metadata: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Submit a task to the distributed queue.

        Args:
            content: Task description/prompt
            task_type: Type of task (shell, creative, code, research, general)
            priority: Higher = processed first
            metadata: Additional task data

        Returns:
            Dict with task 'id'

        Raises:
            Exception: re-raised after logging if the rqlite request fails.
        """
        query = """
        INSERT INTO tasks (content, task_type, priority, status, metadata, created_at)
        VALUES (?, ?, ?, 'pending', ?, ?)
        """
        params = [
            content,
            task_type,
            priority,
            json.dumps(metadata or {}),
            datetime.utcnow().isoformat()
        ]

        try:
            resp = await self._client.post(
                f"{self.rqlite_url}/db/execute",
                json=[query, params]
            )
            resp.raise_for_status()
            result = resp.json()

            last_id = None
            if "results" in result and result["results"]:
                last_id = result["results"][0].get("last_insert_id")

            logger.info(f"Submitted task {last_id}: {content[:50]}...")
            return {"id": last_id, "status": "queued"}

        except Exception as e:
            logger.error(f"Failed to submit task: {e}")
            raise

    async def claim_task(
        self,
        capabilities: List[str],
        node_id: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """Atomically claim next available task.

        Uses UPDATE ... RETURNING pattern for atomic claim.

        Args:
            capabilities: List of capabilities this node has
            node_id: Identifier for claiming node

        Returns:
            Task dict or None if no tasks available
        """
        claimer = node_id or self.node_id

        # Try to claim a matching task atomically
        # This works because rqlite uses Raft consensus - only one node wins
        placeholders = ",".join(["?"] * len(capabilities))

        # The outer "AND status = 'pending'" re-check guards against another
        # node claiming the same row between subquery and update.
        query = f"""
        UPDATE tasks
        SET status = 'claimed',
            claimed_by = ?,
            claimed_at = ?
        WHERE id = (
            SELECT id FROM tasks
            WHERE status = 'pending'
            AND (task_type IN ({placeholders}) OR task_type = 'general')
            ORDER BY priority DESC, created_at ASC
            LIMIT 1
        )
        AND status = 'pending'
        RETURNING id, content, task_type, priority, metadata
        """
        params = [claimer, datetime.utcnow().isoformat()] + capabilities

        try:
            resp = await self._client.post(
                f"{self.rqlite_url}/db/execute",
                json=[query, params]
            )
            resp.raise_for_status()
            result = resp.json()

            # NOTE(review): assumes rqlite exposes RETURNING rows under
            # results[0]["rows"] on /db/execute — confirm for the deployed
            # rqlite version.
            if "results" in result and result["results"]:
                rows = result["results"][0].get("rows", [])
                if rows:
                    row = rows[0]
                    return {
                        "id": row[0],
                        "content": row[1],
                        "type": row[2],
                        "priority": row[3],
                        "metadata": json.loads(row[4]) if row[4] else {}
                    }

            return None

        except Exception as e:
            logger.error(f"Failed to claim task: {e}")
            return None

    async def complete_task(
        self,
        task_id: int,
        success: bool,
        result: Optional[str] = None,
        error: Optional[str] = None
    ) -> None:
        """Mark task as completed or failed.

        Args:
            task_id: Task ID
            success: True if task succeeded
            result: Task result/output
            error: Error message if failed
        """
        status = "done" if success else "failed"

        query = """
        UPDATE tasks
        SET status = ?,
            result = ?,
            error = ?,
            completed_at = ?
        WHERE id = ?
        """
        params = [status, result, error, datetime.utcnow().isoformat(), task_id]

        try:
            # Best-effort: HTTP errors are not raised here (no
            # raise_for_status), failures are only logged below.
            await self._client.post(
                f"{self.rqlite_url}/db/execute",
                json=[query, params]
            )
            logger.debug(f"Task {task_id} marked {status}")

        except Exception as e:
            logger.error(f"Failed to complete task {task_id}: {e}")

    async def get_pending_tasks(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Get list of pending tasks (for dashboard/monitoring).

        Args:
            limit: Max tasks to return

        Returns:
            List of pending task dicts (empty list on failure)
        """
        sql = """
        SELECT id, content, task_type, priority, metadata, created_at
        FROM tasks
        WHERE status = 'pending'
        ORDER BY priority DESC, created_at ASC
        LIMIT ?
        """

        try:
            resp = await self._client.post(
                f"{self.rqlite_url}/db/query",
                json=[sql, [limit]]
            )
            resp.raise_for_status()
            result = resp.json()

            tasks = []
            if "results" in result and result["results"]:
                for row in result["results"][0].get("rows", []):
                    tasks.append({
                        "id": row[0],
                        "content": row[1],
                        "type": row[2],
                        "priority": row[3],
                        "metadata": json.loads(row[4]) if row[4] else {},
                        "created_at": row[5]
                    })

            return tasks

        except Exception as e:
            logger.error(f"Failed to get pending tasks: {e}")
            return []

    async def close(self):
        """Close HTTP client."""
        await self._client.aclose()
|
||||
86
src/brain/embeddings.py
Normal file
86
src/brain/embeddings.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""Local embeddings using sentence-transformers.
|
||||
|
||||
No OpenAI dependency. Runs 100% locally on CPU.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import numpy as np
|
||||
from typing import List, Union
|
||||
|
||||
logger = logging.getLogger(__name__)

# Process-wide model cache: the sentence-transformers model is expensive to
# load, so it is created once and shared by every LocalEmbedder instance.
_model = None
_model_name = "all-MiniLM-L6-v2"
_dimensions = 384

# Lazily-created singleton returned by get_embedder().
_embedder = None


class LocalEmbedder:
    """Local sentence transformer for embeddings.

    Uses all-MiniLM-L6-v2 (80MB download, runs on CPU).
    384-dimensional embeddings, good enough for semantic search.
    """

    def __init__(self, model_name: str = _model_name):
        self.model_name = model_name
        self._model = None              # populated lazily by _load_model()
        self._dimensions = _dimensions  # embedding width (384)

    def _load_model(self):
        """Lazy load the model, reusing the process-wide cache when present.

        Raises:
            ImportError: if sentence-transformers is not installed.
        """
        global _model
        if _model is not None:
            self._model = _model
            return

        try:
            from sentence_transformers import SentenceTransformer
            logger.info(f"Loading embedding model: {self.model_name}")
            _model = SentenceTransformer(self.model_name)
            self._model = _model
            logger.info(f"Embedding model loaded ({self._dimensions} dims)")
        except ImportError:
            logger.error("sentence-transformers not installed. Run: pip install sentence-transformers")
            raise

    def encode(self, text: Union[str, List[str]]) -> np.ndarray:
        """Encode text to embedding vector(s).

        Args:
            text: String or list of strings to encode

        Returns:
            Numpy array of shape (dims,) for single string or (n, dims) for list
        """
        if self._model is None:
            self._load_model()

        # Normalize embeddings so cosine similarity reduces to a dot product.
        return self._model.encode(text, normalize_embeddings=True)

    def encode_single(self, text: str) -> bytes:
        """Encode single text to bytes for SQLite storage.

        Returns:
            Float32 bytes
        """
        embedding = self.encode(text)
        # Defensive: collapse a (1, dims) batch result to a flat vector.
        if len(embedding.shape) > 1:
            embedding = embedding[0]
        return embedding.astype(np.float32).tobytes()

    def similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Compute cosine similarity between two vectors.

        Vectors should already be normalized from encode().
        """
        return float(np.dot(a, b))


def get_embedder() -> LocalEmbedder:
    """Get singleton embedder instance (created lazily on first call).

    Previously a new LocalEmbedder was constructed per call, contradicting
    the "singleton" contract; callers now always receive the same instance.
    """
    global _embedder
    if _embedder is None:
        _embedder = LocalEmbedder()
    return _embedder
|
||||
94
src/brain/schema.py
Normal file
94
src/brain/schema.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Database schema for distributed brain.
|
||||
|
||||
SQL to initialize rqlite with memories and tasks tables.
|
||||
"""
|
||||
|
||||
# Schema version for migrations
|
||||
SCHEMA_VERSION = 1
|
||||
|
||||
INIT_SQL = """
|
||||
-- Enable SQLite extensions
|
||||
.load vector0
|
||||
.load vec0
|
||||
|
||||
-- Memories table with vector search
|
||||
CREATE TABLE IF NOT EXISTS memories (
|
||||
id INTEGER PRIMARY KEY,
|
||||
content TEXT NOT NULL,
|
||||
embedding BLOB, -- 384-dim float32 array (normalized)
|
||||
source TEXT, -- 'timmy', 'zeroclaw', 'worker', 'user'
|
||||
tags TEXT, -- JSON array
|
||||
metadata TEXT, -- JSON object
|
||||
created_at TEXT -- ISO8601
|
||||
);
|
||||
|
||||
-- Tasks table (distributed queue)
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
id INTEGER PRIMARY KEY,
|
||||
content TEXT NOT NULL,
|
||||
task_type TEXT DEFAULT 'general', -- shell, creative, code, research, general
|
||||
priority INTEGER DEFAULT 0, -- Higher = process first
|
||||
status TEXT DEFAULT 'pending', -- pending, claimed, done, failed
|
||||
claimed_by TEXT, -- Node ID
|
||||
claimed_at TEXT,
|
||||
result TEXT,
|
||||
error TEXT,
|
||||
metadata TEXT, -- JSON
|
||||
created_at TEXT,
|
||||
completed_at TEXT
|
||||
);
|
||||
|
||||
-- Node registry (who's online)
|
||||
CREATE TABLE IF NOT EXISTS nodes (
|
||||
node_id TEXT PRIMARY KEY,
|
||||
capabilities TEXT, -- JSON array
|
||||
last_seen TEXT, -- ISO8601
|
||||
load_average REAL
|
||||
);
|
||||
|
||||
-- Indexes for performance
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_source ON memories(source);
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_created ON memories(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_tasks_status_priority ON tasks(status, priority DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_tasks_claimed ON tasks(claimed_by, status);
|
||||
CREATE INDEX IF NOT EXISTS idx_tasks_type ON tasks(task_type);
|
||||
|
||||
-- Virtual table for vector search (if using sqlite-vec)
|
||||
-- Note: This requires sqlite-vec extension loaded
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS vec_memories USING vec0(
|
||||
embedding float[384]
|
||||
);
|
||||
|
||||
-- Schema version tracking
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
version INTEGER PRIMARY KEY,
|
||||
applied_at TEXT
|
||||
);
|
||||
|
||||
INSERT OR REPLACE INTO schema_version (version, applied_at)
|
||||
VALUES (1, datetime('now'));
|
||||
"""
|
||||
|
||||
MIGRATIONS = {
|
||||
# Future migrations go here
|
||||
# 2: "ALTER TABLE ...",
|
||||
}
|
||||
|
||||
|
||||
def get_init_sql() -> str:
|
||||
"""Get SQL to initialize fresh database."""
|
||||
return INIT_SQL
|
||||
|
||||
|
||||
def get_migration_sql(from_version: int, to_version: int) -> str:
|
||||
"""Get SQL to migrate between versions."""
|
||||
if to_version <= from_version:
|
||||
return ""
|
||||
|
||||
sql_parts = []
|
||||
for v in range(from_version + 1, to_version + 1):
|
||||
if v in MIGRATIONS:
|
||||
sql_parts.append(MIGRATIONS[v])
|
||||
sql_parts.append(f"UPDATE schema_version SET version = {v}, applied_at = datetime('now');")
|
||||
|
||||
return "\n".join(sql_parts)
|
||||
366
src/brain/worker.py
Normal file
366
src/brain/worker.py
Normal file
@@ -0,0 +1,366 @@
|
||||
"""Distributed Worker — continuously processes tasks from the brain queue.
|
||||
|
||||
Each device runs a worker that claims and executes tasks based on capabilities.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from brain.client import BrainClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DistributedWorker:
    """Continuous task processor for the distributed brain.

    Runs on every device, claims tasks matching its capabilities,
    executes them immediately, stores results.
    """

    def __init__(self, brain_client: Optional[BrainClient] = None):
        self.brain = brain_client or BrainClient()
        # hostname-pid identifies this worker process in the task queue.
        self.node_id = f"{socket.gethostname()}-{os.getpid()}"
        self.capabilities = self._detect_capabilities()
        self.running = False
        # task_type -> async handler(content) -> result string.
        self._handlers: Dict[str, Callable] = {}
        self._register_default_handlers()

    def _detect_capabilities(self) -> List[str]:
        """Detect what this node can do.

        Returns:
            List of capability tags; always includes the baseline set.
        """
        caps = ["general", "shell", "file_ops", "git"]

        # Check for GPU
        if self._has_gpu():
            caps.append("gpu")
            caps.append("creative")
            caps.append("image_gen")
            caps.append("video_gen")

        # Check for internet
        if self._has_internet():
            caps.append("web")
            caps.append("research")

        # Check memory
        mem_gb = self._get_memory_gb()
        if mem_gb > 16:
            caps.append("large_model")
        if mem_gb > 32:
            caps.append("huge_model")

        # Check for specific tools
        if self._has_command("ollama"):
            caps.append("ollama")
        if self._has_command("docker"):
            caps.append("docker")
        if self._has_command("cargo"):
            caps.append("rust")

        logger.info(f"Worker capabilities: {caps}")
        return caps

    def _has_gpu(self) -> bool:
        """Check for NVIDIA or AMD GPU.

        NOTE(review): os.uname() does not exist on Windows — confirm this
        only ever runs on POSIX hosts.
        """
        try:
            # Check for nvidia-smi
            result = subprocess.run(
                ["nvidia-smi"], capture_output=True, timeout=5
            )
            if result.returncode == 0:
                return True
        except:
            # NOTE(review): bare except also swallows KeyboardInterrupt /
            # SystemExit; narrow to (OSError, subprocess.SubprocessError).
            pass

        # Check for ROCm
        if os.path.exists("/opt/rocm"):
            return True

        # Check for Apple Silicon Metal
        if os.uname().sysname == "Darwin":
            try:
                result = subprocess.run(
                    ["system_profiler", "SPDisplaysDataType"],
                    capture_output=True, text=True, timeout=5
                )
                if "Metal" in result.stdout:
                    return True
            except:
                pass

        return False

    def _has_internet(self) -> bool:
        """Check if we have internet connectivity."""
        # Uses curl against Cloudflare's 1.1.1.1 with a 3s cap; any failure
        # (missing curl, timeout) is treated as "offline".
        try:
            result = subprocess.run(
                ["curl", "-s", "--max-time", "3", "https://1.1.1.1"],
                capture_output=True, timeout=5
            )
            return result.returncode == 0
        except:
            return False

    def _get_memory_gb(self) -> float:
        """Get total system memory in GB."""
        try:
            if os.uname().sysname == "Darwin":
                # macOS: hw.memsize reports bytes.
                result = subprocess.run(
                    ["sysctl", "-n", "hw.memsize"],
                    capture_output=True, text=True
                )
                bytes_mem = int(result.stdout.strip())
                return bytes_mem / (1024**3)
            else:
                # Linux: MemTotal in /proc/meminfo is in kB.
                with open("/proc/meminfo") as f:
                    for line in f:
                        if line.startswith("MemTotal:"):
                            kb = int(line.split()[1])
                            return kb / (1024**2)
        except:
            pass
        return 8.0  # Assume 8GB if we can't detect

    def _has_command(self, cmd: str) -> bool:
        """Check if command exists."""
        # NOTE(review): shutil.which(cmd) would avoid the subprocess and the
        # POSIX-only "which" dependency.
        try:
            result = subprocess.run(
                ["which", cmd], capture_output=True, timeout=5
            )
            return result.returncode == 0
        except:
            return False

    def _register_default_handlers(self):
        """Register built-in task handlers."""
        self._handlers = {
            "shell": self._handle_shell,
            "creative": self._handle_creative,
            "code": self._handle_code,
            "research": self._handle_research,
            "general": self._handle_general,
        }

    def register_handler(self, task_type: str, handler: Callable[[str], Any]):
        """Register a custom task handler.

        Args:
            task_type: Type of task this handler handles
            handler: Async function that takes task content and returns result
        """
        self._handlers[task_type] = handler
        # Advertise the new task type so claim_task() can match it.
        if task_type not in self.capabilities:
            self.capabilities.append(task_type)

    # ──────────────────────────────────────────────────────────────────────────
    # Task Handlers
    # ──────────────────────────────────────────────────────────────────────────

    async def _handle_shell(self, command: str) -> str:
        """Execute shell command via ZeroClaw or direct subprocess.

        Raises:
            Exception: when the command exits non-zero (stderr in message).
        """
        # Try ZeroClaw first if available
        if self._has_command("zeroclaw"):
            # NOTE(review): interpolating `command` into a single-quoted shell
            # string is injectable (a ' in the task content escapes the
            # quoting) — prefer create_subprocess_exec with an argv list.
            proc = await asyncio.create_subprocess_shell(
                f"zeroclaw exec --json '{command}'",
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE
            )
            stdout, stderr = await proc.communicate()

            # Store result in brain
            await self.brain.remember(
                content=f"Shell: {command}\nOutput: {stdout.decode()}",
                tags=["shell", "result"],
                source=self.node_id,
                metadata={"command": command, "exit_code": proc.returncode}
            )

            if proc.returncode != 0:
                raise Exception(f"Command failed: {stderr.decode()}")
            return stdout.decode()

        # Fallback to direct subprocess (less safe)
        proc = await asyncio.create_subprocess_shell(
            command,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE
        )
        stdout, stderr = await proc.communicate()

        if proc.returncode != 0:
            raise Exception(f"Command failed: {stderr.decode()}")
        return stdout.decode()

    async def _handle_creative(self, prompt: str) -> str:
        """Generate creative media (requires GPU).

        Raises:
            Exception: when this node has no 'gpu' capability.
        """
        if "gpu" not in self.capabilities:
            raise Exception("GPU not available on this node")

        # This would call creative tools (Stable Diffusion, etc.)
        # For now, placeholder
        logger.info(f"Creative task: {prompt[:50]}...")

        # Store result
        result = f"Creative output for: {prompt}"
        await self.brain.remember(
            content=result,
            tags=["creative", "generated"],
            source=self.node_id,
            metadata={"prompt": prompt}
        )

        return result

    async def _handle_code(self, description: str) -> str:
        """Code generation and modification (placeholder implementation)."""
        # Would use LLM to generate code
        # For now, placeholder
        logger.info(f"Code task: {description[:50]}...")
        return f"Code generated for: {description}"

    async def _handle_research(self, query: str) -> str:
        """Web research (placeholder implementation).

        Raises:
            Exception: when this node has no 'web' capability.
        """
        if "web" not in self.capabilities:
            raise Exception("Internet not available on this node")

        # Would use browser automation or search
        logger.info(f"Research task: {query[:50]}...")
        return f"Research results for: {query}"

    async def _handle_general(self, prompt: str) -> str:
        """General LLM task via local Ollama.

        Raises:
            Exception: when Ollama is unavailable or the request fails.
        """
        if "ollama" not in self.capabilities:
            raise Exception("Ollama not available on this node")

        # Call Ollama
        # NOTE(review): shelling out to curl rather than using httpx (already
        # a dependency of brain.client) — confirm this is intentional.
        try:
            proc = await asyncio.create_subprocess_exec(
                "curl", "-s", "http://localhost:11434/api/generate",
                "-d", json.dumps({
                    "model": "llama3.1:8b-instruct",
                    "prompt": prompt,
                    "stream": False
                }),
                stdout=asyncio.subprocess.PIPE
            )
            stdout, _ = await proc.communicate()

            response = json.loads(stdout.decode())
            result = response.get("response", "No response")

            # Store in brain
            await self.brain.remember(
                content=f"Task: {prompt}\nResult: {result}",
                tags=["llm", "result"],
                source=self.node_id,
                metadata={"model": "llama3.1:8b-instruct"}
            )

            return result

        except Exception as e:
            raise Exception(f"LLM failed: {e}")

    # ──────────────────────────────────────────────────────────────────────────
    # Main Loop
    # ──────────────────────────────────────────────────────────────────────────

    async def execute_task(self, task: Dict[str, Any]) -> Dict[str, Any]:
        """Execute a claimed task.

        Args:
            task: Claimed task dict ('id', 'content', 'type', ...).

        Returns:
            {'success': True, 'result': ...} or {'success': False, 'error': ...}
        """
        task_type = task.get("type", "general")
        content = task.get("content", "")
        task_id = task.get("id")

        # Unknown task types fall back to the general LLM handler.
        handler = self._handlers.get(task_type, self._handlers["general"])

        try:
            logger.info(f"Executing task {task_id}: {task_type}")
            result = await handler(content)

            await self.brain.complete_task(task_id, success=True, result=result)
            logger.info(f"Task {task_id} completed")
            return {"success": True, "result": result}

        except Exception as e:
            error_msg = str(e)
            logger.error(f"Task {task_id} failed: {error_msg}")
            await self.brain.complete_task(task_id, success=False, error=error_msg)
            return {"success": False, "error": error_msg}

    async def run_once(self) -> bool:
        """Process one task if available.

        Returns:
            True if a task was processed, False if no tasks available
        """
        task = await self.brain.claim_task(self.capabilities, self.node_id)

        if task:
            await self.execute_task(task)
            return True

        return False

    async def run(self):
        """Main loop — continuously process tasks."""
        logger.info(f"Worker {self.node_id} started")
        logger.info(f"Capabilities: {self.capabilities}")

        self.running = True
        consecutive_empty = 0

        while self.running:
            try:
                had_work = await self.run_once()

                if had_work:
                    # Immediately check for more work
                    consecutive_empty = 0
                    await asyncio.sleep(0.1)
                else:
                    # No work available - adaptive sleep
                    consecutive_empty += 1
                    # Sleep 0.5s, but up to 2s if consistently empty
                    sleep_time = min(0.5 + (consecutive_empty * 0.1), 2.0)
                    await asyncio.sleep(sleep_time)

            except Exception as e:
                # Keep the worker alive through transient errors; back off 1s.
                logger.error(f"Worker error: {e}")
                await asyncio.sleep(1)

    def stop(self):
        """Stop the worker loop."""
        # run() exits after the current iteration observes running == False.
        self.running = False
        logger.info("Worker stopping...")
|
||||
|
||||
|
||||
async def main():
    """CLI entry point for worker.

    An optional first argument is a comma-separated capability list that
    replaces the auto-detected capabilities.
    """
    import sys

    worker = DistributedWorker()

    # Allow capability overrides from CLI
    if len(sys.argv) > 1:
        caps = sys.argv[1].split(",")
        worker.capabilities = caps
        logger.info(f"Overriding capabilities: {caps}")

    try:
        await worker.run()
    except KeyboardInterrupt:
        worker.stop()
        print("\nWorker stopped.")


if __name__ == "__main__":
    asyncio.run(main())
|
||||
Reference in New Issue
Block a user