1
0

feat: code quality audit + autoresearch integration + infra hardening (#150)

This commit is contained in:
Alexander Whitestone
2026-03-08 12:50:44 -04:00
committed by GitHub
parent fd0ede0d51
commit ae3bb1cc21
186 changed files with 5129 additions and 3289 deletions

View File

@@ -71,3 +71,23 @@
# Requires: pip install ".[discord]" # Requires: pip install ".[discord]"
# Optional: pip install pyzbar Pillow (for QR code invite detection from screenshots) # Optional: pip install pyzbar Pillow (for QR code invite detection from screenshots)
# DISCORD_TOKEN= # DISCORD_TOKEN=
# ── Autoresearch — autonomous ML experiment loops ────────────────────────────
# Enable autonomous experiment loops (Karpathy autoresearch pattern).
# AUTORESEARCH_ENABLED=false
# AUTORESEARCH_WORKSPACE=data/experiments
# AUTORESEARCH_TIME_BUDGET=300
# AUTORESEARCH_MAX_ITERATIONS=100
# AUTORESEARCH_METRIC=val_bpb
# ── Docker Production ────────────────────────────────────────────────────────
# When deploying with docker-compose.prod.yml:
# - Containers run as non-root user "timmy" (defined in Dockerfile)
# - No source bind mounts — code is baked into the image
# - Set TIMMY_ENV=production to enforce security checks
# - All secrets below MUST be set before production deployment
#
# Taskosaur secrets (change from dev defaults):
# TASKOSAUR_JWT_SECRET=<generate with: python3 -c "import secrets; print(secrets.token_hex(32))">
# TASKOSAUR_JWT_REFRESH_SECRET=<generate with: python3 -c "import secrets; print(secrets.token_hex(32))">
# TASKOSAUR_ENCRYPTION_KEY=<generate with: python3 -c "import secrets; print(secrets.token_hex(32))">

View File

@@ -7,8 +7,30 @@ on:
branches: ["**"] branches: ["**"]
jobs: jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install linters
run: pip install black==23.12.1 isort==5.13.2 bandit==1.7.5
- name: Check formatting (black)
run: black --check --line-length 100 src/ tests/
- name: Check import order (isort)
run: isort --check --profile black --line-length 100 src/ tests/
- name: Security scan (bandit)
run: bandit -r src/ -ll -s B101,B104,B307,B310,B324,B601,B608 -q
test: test:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: lint
# Required for publish-unit-test-result-action to post check runs and PR comments # Required for publish-unit-test-result-action to post check runs and PR comments
permissions: permissions:
@@ -22,7 +44,15 @@ jobs:
- uses: actions/setup-python@v5 - uses: actions/setup-python@v5
with: with:
python-version: "3.11" python-version: "3.11"
cache: "pip"
- name: Cache Poetry virtualenv
uses: actions/cache@v4
with:
path: |
~/.cache/pypoetry
~/.cache/pip
key: poetry-${{ hashFiles('poetry.lock') }}
restore-keys: poetry-
- name: Install dependencies - name: Install dependencies
run: | run: |
@@ -60,3 +90,11 @@ jobs:
name: coverage-report name: coverage-report
path: reports/coverage.xml path: reports/coverage.xml
retention-days: 14 retention-days: 14
docker-build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Build Docker image
run: DOCKER_BUILDKIT=1 docker build -t timmy-time:ci .

View File

@@ -51,12 +51,12 @@ repos:
exclude: ^tests/ exclude: ^tests/
stages: [manual] stages: [manual]
# Full test suite with 30-second wall-clock limit. # Unit tests only with 30-second wall-clock limit.
# Current baseline: ~18s. If tests get slow, this blocks the commit. # Runs only fast unit tests on commit; full suite runs in CI.
- repo: local - repo: local
hooks: hooks:
- id: pytest-fast - id: pytest-fast
name: pytest (30s limit) name: pytest unit (30s limit)
entry: timeout 30 poetry run pytest entry: timeout 30 poetry run pytest
language: system language: system
types: [python] types: [python]
@@ -68,4 +68,8 @@ repos:
- -q - -q
- --tb=short - --tb=short
- --timeout=10 - --timeout=10
- -m
- unit
- -p
- no:xdist
verbose: true verbose: true

View File

@@ -56,7 +56,7 @@ make test-cov # With coverage (term-missing + XML)
- **Test mode:** `TIMMY_TEST_MODE=1` set automatically in conftest - **Test mode:** `TIMMY_TEST_MODE=1` set automatically in conftest
- **FastAPI testing:** Use the `client` fixture - **FastAPI testing:** Use the `client` fixture
- **Async:** `asyncio_mode = "auto"` — async tests detected automatically - **Async:** `asyncio_mode = "auto"` — async tests detected automatically
- **Coverage threshold:** 60% (`fail_under` in `pyproject.toml`) - **Coverage threshold:** 73% (`fail_under` in `pyproject.toml`)
--- ---

View File

@@ -11,7 +11,7 @@
# timmy-time:latest \ # timmy-time:latest \
# python -m swarm.agent_runner --agent-id w1 --name Worker-1 # python -m swarm.agent_runner --agent-id w1 --name Worker-1
# ── Stage 1: Builder — export deps via Poetry, install via pip ────────────── # ── Stage 1: Builder — install deps via Poetry ──────────────────────────────
FROM python:3.12-slim AS builder FROM python:3.12-slim AS builder
RUN apt-get update && apt-get install -y --no-install-recommends \ RUN apt-get update && apt-get install -y --no-install-recommends \
@@ -20,18 +20,15 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
WORKDIR /build WORKDIR /build
# Install Poetry + export plugin (only needed for export, not in runtime) # Install Poetry (only needed to resolve deps, not in runtime)
RUN pip install --no-cache-dir poetry poetry-plugin-export RUN pip install --no-cache-dir poetry
# Copy dependency files only (layer caching) # Copy dependency files only (layer caching)
COPY pyproject.toml poetry.lock ./ COPY pyproject.toml poetry.lock ./
# Export pinned requirements and install with pip cache mount # Install deps directly from lock file (no virtualenv, no export plugin needed)
RUN poetry export --extras swarm --extras telegram --extras discord --without-hashes \ RUN poetry config virtualenvs.create false && \
-f requirements.txt -o requirements.txt poetry install --only main --extras telegram --extras discord --no-interaction
RUN --mount=type=cache,target=/root/.cache/pip \
pip install --no-cache-dir -r requirements.txt
# ── Stage 2: Runtime ─────────────────────────────────────────────────────── # ── Stage 2: Runtime ───────────────────────────────────────────────────────
FROM python:3.12-slim AS base FROM python:3.12-slim AS base

View File

@@ -210,6 +210,11 @@ docker-up:
mkdir -p data mkdir -p data
docker compose up -d dashboard docker compose up -d dashboard
docker-prod:
mkdir -p data
DOCKER_BUILDKIT=1 docker build -t timmy-time:latest .
docker compose -f docker-compose.yml -f docker-compose.prod.yml up -d dashboard
docker-down: docker-down:
docker compose down docker compose down

56
docker-compose.prod.yml Normal file
View File

@@ -0,0 +1,56 @@
# ── Production Compose Overlay ─────────────────────────────────────────────────
#
# Usage:
# make docker-prod # build + start with prod settings
# docker compose -f docker-compose.yml -f docker-compose.prod.yml up -d
#
# Differences from dev:
# - Runs as non-root user (timmy) from Dockerfile
# - No bind mounts — uses image-baked source only
# - Named volumes only (no host path dependencies)
# - Read-only root filesystem with tmpfs for /tmp
# - Resource limits enforced
# - Secrets passed via environment variables (set in .env)
#
# Security note: Set all secrets in .env before deploying.
# Required: L402_HMAC_SECRET, L402_MACAROON_SECRET
# Recommended: TASKOSAUR_JWT_SECRET, TASKOSAUR_ENCRYPTION_KEY
services:
dashboard:
# Remove dev-only root user override — use Dockerfile's USER timmy
user: ""
read_only: true
tmpfs:
- /tmp:size=100M
volumes:
# Override: named volume only, no host bind mounts
- timmy-data:/app/data
# Remove ./src and ./static bind mounts (use baked-in image files)
environment:
DEBUG: "false"
TIMMY_ENV: "production"
deploy:
resources:
limits:
cpus: "2.0"
memory: 2G
celery-worker:
user: ""
read_only: true
tmpfs:
- /tmp:size=100M
volumes:
- timmy-data:/app/data
deploy:
resources:
limits:
cpus: "1.0"
memory: 1G
# Override timmy-data to use a simple named volume (no host bind)
volumes:
timmy-data:
driver: local

View File

@@ -97,6 +97,12 @@ markers = [
"skip_ci: Skip in CI environment (local development only)", "skip_ci: Skip in CI environment (local development only)",
] ]
[tool.isort]
profile = "black"
line_length = 100
src_paths = ["src", "tests"]
known_first_party = ["brain", "config", "dashboard", "infrastructure", "integrations", "spark", "swarm", "timmy", "timmy_serve"]
[tool.coverage.run] [tool.coverage.run]
source = ["src"] source = ["src"]
omit = [ omit = [

View File

@@ -11,9 +11,9 @@ upgrade to distributed rqlite over Tailscale — same API, replicated.
""" """
from brain.client import BrainClient from brain.client import BrainClient
from brain.worker import DistributedWorker
from brain.embeddings import LocalEmbedder from brain.embeddings import LocalEmbedder
from brain.memory import UnifiedMemory, get_memory from brain.memory import UnifiedMemory, get_memory
from brain.worker import DistributedWorker
__all__ = [ __all__ = [
"BrainClient", "BrainClient",

View File

@@ -21,52 +21,54 @@ DEFAULT_RQLITE_URL = "http://localhost:4001"
class BrainClient: class BrainClient:
"""Client for distributed brain (rqlite). """Client for distributed brain (rqlite).
Connects to local rqlite instance, which handles replication. Connects to local rqlite instance, which handles replication.
All writes go to leader, reads can come from local node. All writes go to leader, reads can come from local node.
""" """
def __init__(self, rqlite_url: Optional[str] = None, node_id: Optional[str] = None): def __init__(self, rqlite_url: Optional[str] = None, node_id: Optional[str] = None):
from config import settings from config import settings
self.rqlite_url = rqlite_url or settings.rqlite_url or DEFAULT_RQLITE_URL self.rqlite_url = rqlite_url or settings.rqlite_url or DEFAULT_RQLITE_URL
self.node_id = node_id or f"{socket.gethostname()}-{os.getpid()}" self.node_id = node_id or f"{socket.gethostname()}-{os.getpid()}"
self.source = self._detect_source() self.source = self._detect_source()
self._client = httpx.AsyncClient(timeout=30) self._client = httpx.AsyncClient(timeout=30)
def _detect_source(self) -> str: def _detect_source(self) -> str:
"""Detect what component is using the brain.""" """Detect what component is using the brain."""
# Could be 'timmy', 'zeroclaw', 'worker', etc. # Could be 'timmy', 'zeroclaw', 'worker', etc.
# For now, infer from context or env # For now, infer from context or env
from config import settings from config import settings
return settings.brain_source return settings.brain_source
# ────────────────────────────────────────────────────────────────────────── # ──────────────────────────────────────────────────────────────────────────
# Memory Operations # Memory Operations
# ────────────────────────────────────────────────────────────────────────── # ──────────────────────────────────────────────────────────────────────────
async def remember( async def remember(
self, self,
content: str, content: str,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
source: Optional[str] = None, source: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None metadata: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Store a memory with embedding. """Store a memory with embedding.
Args: Args:
content: Text content to remember content: Text content to remember
tags: Optional list of tags (e.g., ['shell', 'result']) tags: Optional list of tags (e.g., ['shell', 'result'])
source: Source identifier (defaults to self.source) source: Source identifier (defaults to self.source)
metadata: Additional JSON-serializable metadata metadata: Additional JSON-serializable metadata
Returns: Returns:
Dict with 'id' and 'status' Dict with 'id' and 'status'
""" """
from brain.embeddings import get_embedder from brain.embeddings import get_embedder
embedder = get_embedder() embedder = get_embedder()
embedding_bytes = embedder.encode_single(content) embedding_bytes = embedder.encode_single(content)
query = """ query = """
INSERT INTO memories (content, embedding, source, tags, metadata, created_at) INSERT INTO memories (content, embedding, source, tags, metadata, created_at)
VALUES (?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?)
@@ -77,100 +79,90 @@ class BrainClient:
source or self.source, source or self.source,
json.dumps(tags or []), json.dumps(tags or []),
json.dumps(metadata or {}), json.dumps(metadata or {}),
datetime.utcnow().isoformat() datetime.utcnow().isoformat(),
] ]
try: try:
resp = await self._client.post( resp = await self._client.post(f"{self.rqlite_url}/db/execute", json=[query, params])
f"{self.rqlite_url}/db/execute",
json=[query, params]
)
resp.raise_for_status() resp.raise_for_status()
result = resp.json() result = resp.json()
# Extract inserted ID # Extract inserted ID
last_id = None last_id = None
if "results" in result and result["results"]: if "results" in result and result["results"]:
last_id = result["results"][0].get("last_insert_id") last_id = result["results"][0].get("last_insert_id")
logger.debug(f"Stored memory {last_id}: {content[:50]}...") logger.debug(f"Stored memory {last_id}: {content[:50]}...")
return {"id": last_id, "status": "stored"} return {"id": last_id, "status": "stored"}
except Exception as e: except Exception as e:
logger.error(f"Failed to store memory: {e}") logger.error(f"Failed to store memory: {e}")
raise raise
async def recall( async def recall(
self, self, query: str, limit: int = 5, sources: Optional[List[str]] = None
query: str,
limit: int = 5,
sources: Optional[List[str]] = None
) -> List[str]: ) -> List[str]:
"""Semantic search for memories. """Semantic search for memories.
Args: Args:
query: Search query text query: Search query text
limit: Max results to return limit: Max results to return
sources: Filter by source(s) (e.g., ['timmy', 'user']) sources: Filter by source(s) (e.g., ['timmy', 'user'])
Returns: Returns:
List of memory content strings List of memory content strings
""" """
from brain.embeddings import get_embedder from brain.embeddings import get_embedder
embedder = get_embedder() embedder = get_embedder()
query_emb = embedder.encode_single(query) query_emb = embedder.encode_single(query)
# rqlite with sqlite-vec extension for vector search # rqlite with sqlite-vec extension for vector search
sql = "SELECT content, source, metadata, distance FROM memories WHERE embedding MATCH ?" sql = "SELECT content, source, metadata, distance FROM memories WHERE embedding MATCH ?"
params = [query_emb] params = [query_emb]
if sources: if sources:
placeholders = ",".join(["?"] * len(sources)) placeholders = ",".join(["?"] * len(sources))
sql += f" AND source IN ({placeholders})" sql += f" AND source IN ({placeholders})"
params.extend(sources) params.extend(sources)
sql += " ORDER BY distance LIMIT ?" sql += " ORDER BY distance LIMIT ?"
params.append(limit) params.append(limit)
try: try:
resp = await self._client.post( resp = await self._client.post(f"{self.rqlite_url}/db/query", json=[sql, params])
f"{self.rqlite_url}/db/query",
json=[sql, params]
)
resp.raise_for_status() resp.raise_for_status()
result = resp.json() result = resp.json()
results = [] results = []
if "results" in result and result["results"]: if "results" in result and result["results"]:
for row in result["results"][0].get("rows", []): for row in result["results"][0].get("rows", []):
results.append({ results.append(
"content": row[0], {
"source": row[1], "content": row[0],
"metadata": json.loads(row[2]) if row[2] else {}, "source": row[1],
"distance": row[3] "metadata": json.loads(row[2]) if row[2] else {},
}) "distance": row[3],
}
)
return results return results
except Exception as e: except Exception as e:
logger.error(f"Failed to search memories: {e}") logger.error(f"Failed to search memories: {e}")
# Graceful fallback - return empty list # Graceful fallback - return empty list
return [] return []
async def get_recent( async def get_recent(
self, self, hours: int = 24, limit: int = 20, sources: Optional[List[str]] = None
hours: int = 24,
limit: int = 20,
sources: Optional[List[str]] = None
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
"""Get recent memories by time. """Get recent memories by time.
Args: Args:
hours: Look back this many hours hours: Look back this many hours
limit: Max results limit: Max results
sources: Optional source filter sources: Optional source filter
Returns: Returns:
List of memory dicts List of memory dicts
""" """
@@ -180,84 +172,83 @@ class BrainClient:
WHERE created_at > datetime('now', ?) WHERE created_at > datetime('now', ?)
""" """
params = [f"-{hours} hours"] params = [f"-{hours} hours"]
if sources: if sources:
placeholders = ",".join(["?"] * len(sources)) placeholders = ",".join(["?"] * len(sources))
sql += f" AND source IN ({placeholders})" sql += f" AND source IN ({placeholders})"
params.extend(sources) params.extend(sources)
sql += " ORDER BY created_at DESC LIMIT ?" sql += " ORDER BY created_at DESC LIMIT ?"
params.append(limit) params.append(limit)
try: try:
resp = await self._client.post( resp = await self._client.post(f"{self.rqlite_url}/db/query", json=[sql, params])
f"{self.rqlite_url}/db/query",
json=[sql, params]
)
resp.raise_for_status() resp.raise_for_status()
result = resp.json() result = resp.json()
memories = [] memories = []
if "results" in result and result["results"]: if "results" in result and result["results"]:
for row in result["results"][0].get("rows", []): for row in result["results"][0].get("rows", []):
memories.append({ memories.append(
"id": row[0], {
"content": row[1], "id": row[0],
"source": row[2], "content": row[1],
"tags": json.loads(row[3]) if row[3] else [], "source": row[2],
"metadata": json.loads(row[4]) if row[4] else {}, "tags": json.loads(row[3]) if row[3] else [],
"created_at": row[5] "metadata": json.loads(row[4]) if row[4] else {},
}) "created_at": row[5],
}
)
return memories return memories
except Exception as e: except Exception as e:
logger.error(f"Failed to get recent memories: {e}") logger.error(f"Failed to get recent memories: {e}")
return [] return []
async def get_context(self, query: str) -> str: async def get_context(self, query: str) -> str:
"""Get formatted context for system prompt. """Get formatted context for system prompt.
Combines recent memories + relevant memories. Combines recent memories + relevant memories.
Args: Args:
query: Current user query to find relevant context query: Current user query to find relevant context
Returns: Returns:
Formatted context string for prompt injection Formatted context string for prompt injection
""" """
recent = await self.get_recent(hours=24, limit=10) recent = await self.get_recent(hours=24, limit=10)
relevant = await self.recall(query, limit=5) relevant = await self.recall(query, limit=5)
lines = ["Recent activity:"] lines = ["Recent activity:"]
for m in recent[:5]: for m in recent[:5]:
lines.append(f"- {m['content'][:100]}") lines.append(f"- {m['content'][:100]}")
lines.append("\nRelevant memories:") lines.append("\nRelevant memories:")
for r in relevant[:5]: for r in relevant[:5]:
lines.append(f"- {r['content'][:100]}") lines.append(f"- {r['content'][:100]}")
return "\n".join(lines) return "\n".join(lines)
# ────────────────────────────────────────────────────────────────────────── # ──────────────────────────────────────────────────────────────────────────
# Task Queue Operations # Task Queue Operations
# ────────────────────────────────────────────────────────────────────────── # ──────────────────────────────────────────────────────────────────────────
async def submit_task( async def submit_task(
self, self,
content: str, content: str,
task_type: str = "general", task_type: str = "general",
priority: int = 0, priority: int = 0,
metadata: Optional[Dict[str, Any]] = None metadata: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Submit a task to the distributed queue. """Submit a task to the distributed queue.
Args: Args:
content: Task description/prompt content: Task description/prompt
task_type: Type of task (shell, creative, code, research, general) task_type: Type of task (shell, creative, code, research, general)
priority: Higher = processed first priority: Higher = processed first
metadata: Additional task data metadata: Additional task data
Returns: Returns:
Dict with task 'id' Dict with task 'id'
""" """
@@ -270,50 +261,45 @@ class BrainClient:
task_type, task_type,
priority, priority,
json.dumps(metadata or {}), json.dumps(metadata or {}),
datetime.utcnow().isoformat() datetime.utcnow().isoformat(),
] ]
try: try:
resp = await self._client.post( resp = await self._client.post(f"{self.rqlite_url}/db/execute", json=[query, params])
f"{self.rqlite_url}/db/execute",
json=[query, params]
)
resp.raise_for_status() resp.raise_for_status()
result = resp.json() result = resp.json()
last_id = None last_id = None
if "results" in result and result["results"]: if "results" in result and result["results"]:
last_id = result["results"][0].get("last_insert_id") last_id = result["results"][0].get("last_insert_id")
logger.info(f"Submitted task {last_id}: {content[:50]}...") logger.info(f"Submitted task {last_id}: {content[:50]}...")
return {"id": last_id, "status": "queued"} return {"id": last_id, "status": "queued"}
except Exception as e: except Exception as e:
logger.error(f"Failed to submit task: {e}") logger.error(f"Failed to submit task: {e}")
raise raise
async def claim_task( async def claim_task(
self, self, capabilities: List[str], node_id: Optional[str] = None
capabilities: List[str],
node_id: Optional[str] = None
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
"""Atomically claim next available task. """Atomically claim next available task.
Uses UPDATE ... RETURNING pattern for atomic claim. Uses UPDATE ... RETURNING pattern for atomic claim.
Args: Args:
capabilities: List of capabilities this node has capabilities: List of capabilities this node has
node_id: Identifier for claiming node node_id: Identifier for claiming node
Returns: Returns:
Task dict or None if no tasks available Task dict or None if no tasks available
""" """
claimer = node_id or self.node_id claimer = node_id or self.node_id
# Try to claim a matching task atomically # Try to claim a matching task atomically
# This works because rqlite uses Raft consensus - only one node wins # This works because rqlite uses Raft consensus - only one node wins
placeholders = ",".join(["?"] * len(capabilities)) placeholders = ",".join(["?"] * len(capabilities))
query = f""" query = f"""
UPDATE tasks UPDATE tasks
SET status = 'claimed', SET status = 'claimed',
@@ -330,15 +316,12 @@ class BrainClient:
RETURNING id, content, task_type, priority, metadata RETURNING id, content, task_type, priority, metadata
""" """
params = [claimer, datetime.utcnow().isoformat()] + capabilities params = [claimer, datetime.utcnow().isoformat()] + capabilities
try: try:
resp = await self._client.post( resp = await self._client.post(f"{self.rqlite_url}/db/execute", json=[query, params])
f"{self.rqlite_url}/db/execute",
json=[query, params]
)
resp.raise_for_status() resp.raise_for_status()
result = resp.json() result = resp.json()
if "results" in result and result["results"]: if "results" in result and result["results"]:
rows = result["results"][0].get("rows", []) rows = result["results"][0].get("rows", [])
if rows: if rows:
@@ -348,24 +331,20 @@ class BrainClient:
"content": row[1], "content": row[1],
"type": row[2], "type": row[2],
"priority": row[3], "priority": row[3],
"metadata": json.loads(row[4]) if row[4] else {} "metadata": json.loads(row[4]) if row[4] else {},
} }
return None return None
except Exception as e: except Exception as e:
logger.error(f"Failed to claim task: {e}") logger.error(f"Failed to claim task: {e}")
return None return None
async def complete_task( async def complete_task(
self, self, task_id: int, success: bool, result: Optional[str] = None, error: Optional[str] = None
task_id: int,
success: bool,
result: Optional[str] = None,
error: Optional[str] = None
) -> None: ) -> None:
"""Mark task as completed or failed. """Mark task as completed or failed.
Args: Args:
task_id: Task ID task_id: Task ID
success: True if task succeeded success: True if task succeeded
@@ -373,7 +352,7 @@ class BrainClient:
error: Error message if failed error: Error message if failed
""" """
status = "done" if success else "failed" status = "done" if success else "failed"
query = """ query = """
UPDATE tasks UPDATE tasks
SET status = ?, SET status = ?,
@@ -383,23 +362,20 @@ class BrainClient:
WHERE id = ? WHERE id = ?
""" """
params = [status, result, error, datetime.utcnow().isoformat(), task_id] params = [status, result, error, datetime.utcnow().isoformat(), task_id]
try: try:
await self._client.post( await self._client.post(f"{self.rqlite_url}/db/execute", json=[query, params])
f"{self.rqlite_url}/db/execute",
json=[query, params]
)
logger.debug(f"Task {task_id} marked {status}") logger.debug(f"Task {task_id} marked {status}")
except Exception as e: except Exception as e:
logger.error(f"Failed to complete task {task_id}: {e}") logger.error(f"Failed to complete task {task_id}: {e}")
async def get_pending_tasks(self, limit: int = 100) -> List[Dict[str, Any]]: async def get_pending_tasks(self, limit: int = 100) -> List[Dict[str, Any]]:
"""Get list of pending tasks (for dashboard/monitoring). """Get list of pending tasks (for dashboard/monitoring).
Args: Args:
limit: Max tasks to return limit: Max tasks to return
Returns: Returns:
List of pending task dicts List of pending task dicts
""" """
@@ -410,33 +386,32 @@ class BrainClient:
ORDER BY priority DESC, created_at ASC ORDER BY priority DESC, created_at ASC
LIMIT ? LIMIT ?
""" """
try: try:
resp = await self._client.post( resp = await self._client.post(f"{self.rqlite_url}/db/query", json=[sql, [limit]])
f"{self.rqlite_url}/db/query",
json=[sql, [limit]]
)
resp.raise_for_status() resp.raise_for_status()
result = resp.json() result = resp.json()
tasks = [] tasks = []
if "results" in result and result["results"]: if "results" in result and result["results"]:
for row in result["results"][0].get("rows", []): for row in result["results"][0].get("rows", []):
tasks.append({ tasks.append(
"id": row[0], {
"content": row[1], "id": row[0],
"type": row[2], "content": row[1],
"priority": row[3], "type": row[2],
"metadata": json.loads(row[4]) if row[4] else {}, "priority": row[3],
"created_at": row[5] "metadata": json.loads(row[4]) if row[4] else {},
}) "created_at": row[5],
}
)
return tasks return tasks
except Exception as e: except Exception as e:
logger.error(f"Failed to get pending tasks: {e}") logger.error(f"Failed to get pending tasks: {e}")
return [] return []
async def close(self): async def close(self):
"""Close HTTP client.""" """Close HTTP client."""
await self._client.aclose() await self._client.aclose()

View File

@@ -18,48 +18,51 @@ _dimensions = 384
class LocalEmbedder: class LocalEmbedder:
"""Local sentence transformer for embeddings. """Local sentence transformer for embeddings.
Uses all-MiniLM-L6-v2 (80MB download, runs on CPU). Uses all-MiniLM-L6-v2 (80MB download, runs on CPU).
384-dimensional embeddings, good enough for semantic search. 384-dimensional embeddings, good enough for semantic search.
""" """
def __init__(self, model_name: str = _model_name): def __init__(self, model_name: str = _model_name):
self.model_name = model_name self.model_name = model_name
self._model = None self._model = None
self._dimensions = _dimensions self._dimensions = _dimensions
def _load_model(self): def _load_model(self):
"""Lazy load the model.""" """Lazy load the model."""
global _model global _model
if _model is not None: if _model is not None:
self._model = _model self._model = _model
return return
try: try:
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
logger.info(f"Loading embedding model: {self.model_name}") logger.info(f"Loading embedding model: {self.model_name}")
_model = SentenceTransformer(self.model_name) _model = SentenceTransformer(self.model_name)
self._model = _model self._model = _model
logger.info(f"Embedding model loaded ({self._dimensions} dims)") logger.info(f"Embedding model loaded ({self._dimensions} dims)")
except ImportError: except ImportError:
logger.error("sentence-transformers not installed. Run: pip install sentence-transformers") logger.error(
"sentence-transformers not installed. Run: pip install sentence-transformers"
)
raise raise
def encode(self, text: Union[str, List[str]]): def encode(self, text: Union[str, List[str]]):
"""Encode text to embedding vector(s). """Encode text to embedding vector(s).
Args: Args:
text: String or list of strings to encode text: String or list of strings to encode
Returns: Returns:
Numpy array of shape (dims,) for single string or (n, dims) for list Numpy array of shape (dims,) for single string or (n, dims) for list
""" """
if self._model is None: if self._model is None:
self._load_model() self._load_model()
# Normalize embeddings for cosine similarity # Normalize embeddings for cosine similarity
return self._model.encode(text, normalize_embeddings=True) return self._model.encode(text, normalize_embeddings=True)
def encode_single(self, text: str) -> bytes: def encode_single(self, text: str) -> bytes:
"""Encode single text to bytes for SQLite storage. """Encode single text to bytes for SQLite storage.
@@ -67,17 +70,19 @@ class LocalEmbedder:
Float32 bytes Float32 bytes
""" """
import numpy as np import numpy as np
embedding = self.encode(text) embedding = self.encode(text)
if len(embedding.shape) > 1: if len(embedding.shape) > 1:
embedding = embedding[0] embedding = embedding[0]
return embedding.astype(np.float32).tobytes() return embedding.astype(np.float32).tobytes()
def similarity(self, a, b) -> float: def similarity(self, a, b) -> float:
"""Compute cosine similarity between two vectors. """Compute cosine similarity between two vectors.
Vectors should already be normalized from encode(). Vectors should already be normalized from encode().
""" """
import numpy as np import numpy as np
return float(np.dot(a, b)) return float(np.dot(a, b))

View File

@@ -48,6 +48,7 @@ _SCHEMA_VERSION = 1
def _get_db_path() -> Path: def _get_db_path() -> Path:
"""Get the brain database path from env or default.""" """Get the brain database path from env or default."""
from config import settings from config import settings
if settings.brain_db_path: if settings.brain_db_path:
return Path(settings.brain_db_path) return Path(settings.brain_db_path)
return _DEFAULT_DB_PATH return _DEFAULT_DB_PATH
@@ -75,6 +76,7 @@ class UnifiedMemory:
# Auto-detect: use rqlite if RQLITE_URL is set, otherwise local SQLite # Auto-detect: use rqlite if RQLITE_URL is set, otherwise local SQLite
if use_rqlite is None: if use_rqlite is None:
from config import settings as _settings from config import settings as _settings
use_rqlite = bool(_settings.rqlite_url) use_rqlite = bool(_settings.rqlite_url)
self._use_rqlite = use_rqlite self._use_rqlite = use_rqlite
@@ -107,10 +109,12 @@ class UnifiedMemory:
"""Lazy-load the embedding model.""" """Lazy-load the embedding model."""
if self._embedder is None: if self._embedder is None:
from config import settings as _settings from config import settings as _settings
if _settings.timmy_skip_embeddings: if _settings.timmy_skip_embeddings:
return None return None
try: try:
from brain.embeddings import LocalEmbedder from brain.embeddings import LocalEmbedder
self._embedder = LocalEmbedder() self._embedder = LocalEmbedder()
except ImportError: except ImportError:
logger.warning("sentence-transformers not available — semantic search disabled") logger.warning("sentence-transformers not available — semantic search disabled")
@@ -125,6 +129,7 @@ class UnifiedMemory:
"""Lazy-load the rqlite BrainClient.""" """Lazy-load the rqlite BrainClient."""
if self._rqlite_client is None: if self._rqlite_client is None:
from brain.client import BrainClient from brain.client import BrainClient
self._rqlite_client = BrainClient() self._rqlite_client = BrainClient()
return self._rqlite_client return self._rqlite_client
@@ -292,15 +297,17 @@ class UnifiedMemory:
results = [] results = []
for score, row in scored[:limit]: for score, row in scored[:limit]:
results.append({ results.append(
"id": row["id"], {
"content": row["content"], "id": row["id"],
"source": row["source"], "content": row["content"],
"tags": json.loads(row["tags"]) if row["tags"] else [], "source": row["source"],
"metadata": json.loads(row["metadata"]) if row["metadata"] else {}, "tags": json.loads(row["tags"]) if row["tags"] else [],
"score": score, "metadata": json.loads(row["metadata"]) if row["metadata"] else {},
"created_at": row["created_at"], "score": score,
}) "created_at": row["created_at"],
}
)
return results return results
finally: finally:

View File

@@ -84,11 +84,13 @@ def get_migration_sql(from_version: int, to_version: int) -> str:
"""Get SQL to migrate between versions.""" """Get SQL to migrate between versions."""
if to_version <= from_version: if to_version <= from_version:
return "" return ""
sql_parts = [] sql_parts = []
for v in range(from_version + 1, to_version + 1): for v in range(from_version + 1, to_version + 1):
if v in MIGRATIONS: if v in MIGRATIONS:
sql_parts.append(MIGRATIONS[v]) sql_parts.append(MIGRATIONS[v])
sql_parts.append(f"UPDATE schema_version SET version = {v}, applied_at = datetime('now');") sql_parts.append(
f"UPDATE schema_version SET version = {v}, applied_at = datetime('now');"
)
return "\n".join(sql_parts) return "\n".join(sql_parts)

View File

@@ -21,11 +21,11 @@ logger = logging.getLogger(__name__)
class DistributedWorker: class DistributedWorker:
"""Continuous task processor for the distributed brain. """Continuous task processor for the distributed brain.
Runs on every device, claims tasks matching its capabilities, Runs on every device, claims tasks matching its capabilities,
executes them immediately, stores results. executes them immediately, stores results.
""" """
def __init__(self, brain_client: Optional[BrainClient] = None): def __init__(self, brain_client: Optional[BrainClient] = None):
self.brain = brain_client or BrainClient() self.brain = brain_client or BrainClient()
self.node_id = f"{socket.gethostname()}-{os.getpid()}" self.node_id = f"{socket.gethostname()}-{os.getpid()}"
@@ -33,30 +33,30 @@ class DistributedWorker:
self.running = False self.running = False
self._handlers: Dict[str, Callable] = {} self._handlers: Dict[str, Callable] = {}
self._register_default_handlers() self._register_default_handlers()
def _detect_capabilities(self) -> List[str]: def _detect_capabilities(self) -> List[str]:
"""Detect what this node can do.""" """Detect what this node can do."""
caps = ["general", "shell", "file_ops", "git"] caps = ["general", "shell", "file_ops", "git"]
# Check for GPU # Check for GPU
if self._has_gpu(): if self._has_gpu():
caps.append("gpu") caps.append("gpu")
caps.append("creative") caps.append("creative")
caps.append("image_gen") caps.append("image_gen")
caps.append("video_gen") caps.append("video_gen")
# Check for internet # Check for internet
if self._has_internet(): if self._has_internet():
caps.append("web") caps.append("web")
caps.append("research") caps.append("research")
# Check memory # Check memory
mem_gb = self._get_memory_gb() mem_gb = self._get_memory_gb()
if mem_gb > 16: if mem_gb > 16:
caps.append("large_model") caps.append("large_model")
if mem_gb > 32: if mem_gb > 32:
caps.append("huge_model") caps.append("huge_model")
# Check for specific tools # Check for specific tools
if self._has_command("ollama"): if self._has_command("ollama"):
caps.append("ollama") caps.append("ollama")
@@ -64,17 +64,15 @@ class DistributedWorker:
caps.append("docker") caps.append("docker")
if self._has_command("cargo"): if self._has_command("cargo"):
caps.append("rust") caps.append("rust")
logger.info(f"Worker capabilities: {caps}") logger.info(f"Worker capabilities: {caps}")
return caps return caps
def _has_gpu(self) -> bool: def _has_gpu(self) -> bool:
"""Check for NVIDIA or AMD GPU.""" """Check for NVIDIA or AMD GPU."""
try: try:
# Check for nvidia-smi # Check for nvidia-smi
result = subprocess.run( result = subprocess.run(["nvidia-smi"], capture_output=True, timeout=5)
["nvidia-smi"], capture_output=True, timeout=5
)
if result.returncode == 0: if result.returncode == 0:
return True return True
except (OSError, subprocess.SubprocessError): except (OSError, subprocess.SubprocessError):
@@ -83,13 +81,15 @@ class DistributedWorker:
# Check for ROCm # Check for ROCm
if os.path.exists("/opt/rocm"): if os.path.exists("/opt/rocm"):
return True return True
# Check for Apple Silicon Metal # Check for Apple Silicon Metal
if os.uname().sysname == "Darwin": if os.uname().sysname == "Darwin":
try: try:
result = subprocess.run( result = subprocess.run(
["system_profiler", "SPDisplaysDataType"], ["system_profiler", "SPDisplaysDataType"],
capture_output=True, text=True, timeout=5 capture_output=True,
text=True,
timeout=5,
) )
if "Metal" in result.stdout: if "Metal" in result.stdout:
return True return True
@@ -102,8 +102,7 @@ class DistributedWorker:
"""Check if we have internet connectivity.""" """Check if we have internet connectivity."""
try: try:
result = subprocess.run( result = subprocess.run(
["curl", "-s", "--max-time", "3", "https://1.1.1.1"], ["curl", "-s", "--max-time", "3", "https://1.1.1.1"], capture_output=True, timeout=5
capture_output=True, timeout=5
) )
return result.returncode == 0 return result.returncode == 0
except (OSError, subprocess.SubprocessError): except (OSError, subprocess.SubprocessError):
@@ -114,8 +113,7 @@ class DistributedWorker:
try: try:
if os.uname().sysname == "Darwin": if os.uname().sysname == "Darwin":
result = subprocess.run( result = subprocess.run(
["sysctl", "-n", "hw.memsize"], ["sysctl", "-n", "hw.memsize"], capture_output=True, text=True
capture_output=True, text=True
) )
bytes_mem = int(result.stdout.strip()) bytes_mem = int(result.stdout.strip())
return bytes_mem / (1024**3) return bytes_mem / (1024**3)
@@ -128,13 +126,11 @@ class DistributedWorker:
except (OSError, ValueError): except (OSError, ValueError):
pass pass
return 8.0 # Assume 8GB if we can't detect return 8.0 # Assume 8GB if we can't detect
def _has_command(self, cmd: str) -> bool: def _has_command(self, cmd: str) -> bool:
"""Check if command exists.""" """Check if command exists."""
try: try:
result = subprocess.run( result = subprocess.run(["which", cmd], capture_output=True, timeout=5)
["which", cmd], capture_output=True, timeout=5
)
return result.returncode == 0 return result.returncode == 0
except (OSError, subprocess.SubprocessError): except (OSError, subprocess.SubprocessError):
return False return False
@@ -148,10 +144,10 @@ class DistributedWorker:
"research": self._handle_research, "research": self._handle_research,
"general": self._handle_general, "general": self._handle_general,
} }
def register_handler(self, task_type: str, handler: Callable[[str], Any]): def register_handler(self, task_type: str, handler: Callable[[str], Any]):
"""Register a custom task handler. """Register a custom task handler.
Args: Args:
task_type: Type of task this handler handles task_type: Type of task this handler handles
handler: Async function that takes task content and returns result handler: Async function that takes task content and returns result
@@ -159,11 +155,11 @@ class DistributedWorker:
self._handlers[task_type] = handler self._handlers[task_type] = handler
if task_type not in self.capabilities: if task_type not in self.capabilities:
self.capabilities.append(task_type) self.capabilities.append(task_type)
# ────────────────────────────────────────────────────────────────────────── # ──────────────────────────────────────────────────────────────────────────
# Task Handlers # Task Handlers
# ────────────────────────────────────────────────────────────────────────── # ──────────────────────────────────────────────────────────────────────────
async def _handle_shell(self, command: str) -> str: async def _handle_shell(self, command: str) -> str:
"""Execute shell command via ZeroClaw or direct subprocess.""" """Execute shell command via ZeroClaw or direct subprocess."""
# Try ZeroClaw first if available # Try ZeroClaw first if available
@@ -171,156 +167,153 @@ class DistributedWorker:
proc = await asyncio.create_subprocess_shell( proc = await asyncio.create_subprocess_shell(
f"zeroclaw exec --json '{command}'", f"zeroclaw exec --json '{command}'",
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE stderr=asyncio.subprocess.PIPE,
) )
stdout, stderr = await proc.communicate() stdout, stderr = await proc.communicate()
# Store result in brain # Store result in brain
await self.brain.remember( await self.brain.remember(
content=f"Shell: {command}\nOutput: {stdout.decode()}", content=f"Shell: {command}\nOutput: {stdout.decode()}",
tags=["shell", "result"], tags=["shell", "result"],
source=self.node_id, source=self.node_id,
metadata={"command": command, "exit_code": proc.returncode} metadata={"command": command, "exit_code": proc.returncode},
) )
if proc.returncode != 0: if proc.returncode != 0:
raise Exception(f"Command failed: {stderr.decode()}") raise Exception(f"Command failed: {stderr.decode()}")
return stdout.decode() return stdout.decode()
# Fallback to direct subprocess (less safe) # Fallback to direct subprocess (less safe)
proc = await asyncio.create_subprocess_shell( proc = await asyncio.create_subprocess_shell(
command, command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
) )
stdout, stderr = await proc.communicate() stdout, stderr = await proc.communicate()
if proc.returncode != 0: if proc.returncode != 0:
raise Exception(f"Command failed: {stderr.decode()}") raise Exception(f"Command failed: {stderr.decode()}")
return stdout.decode() return stdout.decode()
async def _handle_creative(self, prompt: str) -> str: async def _handle_creative(self, prompt: str) -> str:
"""Generate creative media (requires GPU).""" """Generate creative media (requires GPU)."""
if "gpu" not in self.capabilities: if "gpu" not in self.capabilities:
raise Exception("GPU not available on this node") raise Exception("GPU not available on this node")
# This would call creative tools (Stable Diffusion, etc.) # This would call creative tools (Stable Diffusion, etc.)
# For now, placeholder # For now, placeholder
logger.info(f"Creative task: {prompt[:50]}...") logger.info(f"Creative task: {prompt[:50]}...")
# Store result # Store result
result = f"Creative output for: {prompt}" result = f"Creative output for: {prompt}"
await self.brain.remember( await self.brain.remember(
content=result, content=result,
tags=["creative", "generated"], tags=["creative", "generated"],
source=self.node_id, source=self.node_id,
metadata={"prompt": prompt} metadata={"prompt": prompt},
) )
return result return result
async def _handle_code(self, description: str) -> str: async def _handle_code(self, description: str) -> str:
"""Code generation and modification.""" """Code generation and modification."""
# Would use LLM to generate code # Would use LLM to generate code
# For now, placeholder # For now, placeholder
logger.info(f"Code task: {description[:50]}...") logger.info(f"Code task: {description[:50]}...")
return f"Code generated for: {description}" return f"Code generated for: {description}"
async def _handle_research(self, query: str) -> str: async def _handle_research(self, query: str) -> str:
"""Web research.""" """Web research."""
if "web" not in self.capabilities: if "web" not in self.capabilities:
raise Exception("Internet not available on this node") raise Exception("Internet not available on this node")
# Would use browser automation or search # Would use browser automation or search
logger.info(f"Research task: {query[:50]}...") logger.info(f"Research task: {query[:50]}...")
return f"Research results for: {query}" return f"Research results for: {query}"
async def _handle_general(self, prompt: str) -> str: async def _handle_general(self, prompt: str) -> str:
"""General LLM task via local Ollama.""" """General LLM task via local Ollama."""
if "ollama" not in self.capabilities: if "ollama" not in self.capabilities:
raise Exception("Ollama not available on this node") raise Exception("Ollama not available on this node")
# Call Ollama # Call Ollama
try: try:
proc = await asyncio.create_subprocess_exec( proc = await asyncio.create_subprocess_exec(
"curl", "-s", "http://localhost:11434/api/generate", "curl",
"-d", json.dumps({ "-s",
"model": "llama3.1:8b-instruct", "http://localhost:11434/api/generate",
"prompt": prompt, "-d",
"stream": False json.dumps({"model": "llama3.1:8b-instruct", "prompt": prompt, "stream": False}),
}), stdout=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE
) )
stdout, _ = await proc.communicate() stdout, _ = await proc.communicate()
response = json.loads(stdout.decode()) response = json.loads(stdout.decode())
result = response.get("response", "No response") result = response.get("response", "No response")
# Store in brain # Store in brain
await self.brain.remember( await self.brain.remember(
content=f"Task: {prompt}\nResult: {result}", content=f"Task: {prompt}\nResult: {result}",
tags=["llm", "result"], tags=["llm", "result"],
source=self.node_id, source=self.node_id,
metadata={"model": "llama3.1:8b-instruct"} metadata={"model": "llama3.1:8b-instruct"},
) )
return result return result
except Exception as e: except Exception as e:
raise Exception(f"LLM failed: {e}") raise Exception(f"LLM failed: {e}")
# ────────────────────────────────────────────────────────────────────────── # ──────────────────────────────────────────────────────────────────────────
# Main Loop # Main Loop
# ────────────────────────────────────────────────────────────────────────── # ──────────────────────────────────────────────────────────────────────────
async def execute_task(self, task: Dict[str, Any]) -> Dict[str, Any]: async def execute_task(self, task: Dict[str, Any]) -> Dict[str, Any]:
"""Execute a claimed task.""" """Execute a claimed task."""
task_type = task.get("type", "general") task_type = task.get("type", "general")
content = task.get("content", "") content = task.get("content", "")
task_id = task.get("id") task_id = task.get("id")
handler = self._handlers.get(task_type, self._handlers["general"]) handler = self._handlers.get(task_type, self._handlers["general"])
try: try:
logger.info(f"Executing task {task_id}: {task_type}") logger.info(f"Executing task {task_id}: {task_type}")
result = await handler(content) result = await handler(content)
await self.brain.complete_task(task_id, success=True, result=result) await self.brain.complete_task(task_id, success=True, result=result)
logger.info(f"Task {task_id} completed") logger.info(f"Task {task_id} completed")
return {"success": True, "result": result} return {"success": True, "result": result}
except Exception as e: except Exception as e:
error_msg = str(e) error_msg = str(e)
logger.error(f"Task {task_id} failed: {error_msg}") logger.error(f"Task {task_id} failed: {error_msg}")
await self.brain.complete_task(task_id, success=False, error=error_msg) await self.brain.complete_task(task_id, success=False, error=error_msg)
return {"success": False, "error": error_msg} return {"success": False, "error": error_msg}
async def run_once(self) -> bool: async def run_once(self) -> bool:
"""Process one task if available. """Process one task if available.
Returns: Returns:
True if a task was processed, False if no tasks available True if a task was processed, False if no tasks available
""" """
task = await self.brain.claim_task(self.capabilities, self.node_id) task = await self.brain.claim_task(self.capabilities, self.node_id)
if task: if task:
await self.execute_task(task) await self.execute_task(task)
return True return True
return False return False
async def run(self): async def run(self):
"""Main loop — continuously process tasks.""" """Main loop — continuously process tasks."""
logger.info(f"Worker {self.node_id} started") logger.info(f"Worker {self.node_id} started")
logger.info(f"Capabilities: {self.capabilities}") logger.info(f"Capabilities: {self.capabilities}")
self.running = True self.running = True
consecutive_empty = 0 consecutive_empty = 0
while self.running: while self.running:
try: try:
had_work = await self.run_once() had_work = await self.run_once()
if had_work: if had_work:
# Immediately check for more work # Immediately check for more work
consecutive_empty = 0 consecutive_empty = 0
@@ -331,11 +324,11 @@ class DistributedWorker:
# Sleep 0.5s, but up to 2s if consistently empty # Sleep 0.5s, but up to 2s if consistently empty
sleep_time = min(0.5 + (consecutive_empty * 0.1), 2.0) sleep_time = min(0.5 + (consecutive_empty * 0.1), 2.0)
await asyncio.sleep(sleep_time) await asyncio.sleep(sleep_time)
except Exception as e: except Exception as e:
logger.error(f"Worker error: {e}") logger.error(f"Worker error: {e}")
await asyncio.sleep(1) await asyncio.sleep(1)
def stop(self): def stop(self):
"""Stop the worker loop.""" """Stop the worker loop."""
self.running = False self.running = False
@@ -345,7 +338,7 @@ class DistributedWorker:
async def main(): async def main():
"""CLI entry point for worker.""" """CLI entry point for worker."""
import sys import sys
# Allow capability overrides from CLI # Allow capability overrides from CLI
if len(sys.argv) > 1: if len(sys.argv) > 1:
caps = sys.argv[1].split(",") caps = sys.argv[1].split(",")
@@ -354,12 +347,12 @@ async def main():
logger.info(f"Overriding capabilities: {caps}") logger.info(f"Overriding capabilities: {caps}")
else: else:
worker = DistributedWorker() worker = DistributedWorker()
try: try:
await worker.run() await worker.run()
except KeyboardInterrupt: except KeyboardInterrupt:
worker.stop() worker.stop()
print("\nWorker stopped.") logger.info("Worker stopped.")
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -213,6 +213,15 @@ class Settings(BaseSettings):
# Timeout in seconds for OpenFang hand execution (some hands are slow). # Timeout in seconds for OpenFang hand execution (some hands are slow).
openfang_timeout: int = 120 openfang_timeout: int = 120
# ── Autoresearch — autonomous ML experiment loops ──────────────────
# Integrates Karpathy's autoresearch pattern: agents modify training
# code, run time-boxed experiments, evaluate metrics, and iterate.
autoresearch_enabled: bool = False
autoresearch_workspace: str = "data/experiments"
autoresearch_time_budget: int = 300 # seconds per experiment run
autoresearch_max_iterations: int = 100
autoresearch_metric: str = "val_bpb" # metric to optimise (lower = better)
# ── Local Hands (Shell + Git) ────────────────────────────────────── # ── Local Hands (Shell + Git) ──────────────────────────────────────
# Enable local shell/git execution hands. # Enable local shell/git execution hands.
hands_shell_enabled: bool = True hands_shell_enabled: bool = True

View File

@@ -18,36 +18,38 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from config import settings from config import settings
from dashboard.routes.agents import router as agents_router
from dashboard.routes.health import router as health_router
from dashboard.routes.marketplace import router as marketplace_router
from dashboard.routes.voice import router as voice_router
from dashboard.routes.mobile import router as mobile_router
from dashboard.routes.briefing import router as briefing_router
from dashboard.routes.telegram import router as telegram_router
from dashboard.routes.tools import router as tools_router
from dashboard.routes.spark import router as spark_router
from dashboard.routes.discord import router as discord_router
from dashboard.routes.memory import router as memory_router
from dashboard.routes.router import router as router_status_router
from dashboard.routes.grok import router as grok_router
from dashboard.routes.models import router as models_router
from dashboard.routes.models import api_router as models_api_router
from dashboard.routes.chat_api import router as chat_api_router
from dashboard.routes.thinking import router as thinking_router
from dashboard.routes.calm import router as calm_router
from dashboard.routes.swarm import router as swarm_router
from dashboard.routes.tasks import router as tasks_router
from dashboard.routes.work_orders import router as work_orders_router
from dashboard.routes.system import router as system_router
from dashboard.routes.paperclip import router as paperclip_router
from infrastructure.router.api import router as cascade_router
# Import dedicated middleware # Import dedicated middleware
from dashboard.middleware.csrf import CSRFMiddleware from dashboard.middleware.csrf import CSRFMiddleware
from dashboard.middleware.request_logging import RequestLoggingMiddleware from dashboard.middleware.request_logging import RequestLoggingMiddleware
from dashboard.middleware.security_headers import SecurityHeadersMiddleware from dashboard.middleware.security_headers import SecurityHeadersMiddleware
from dashboard.routes.agents import router as agents_router
from dashboard.routes.briefing import router as briefing_router
from dashboard.routes.calm import router as calm_router
from dashboard.routes.chat_api import router as chat_api_router
from dashboard.routes.discord import router as discord_router
from dashboard.routes.experiments import router as experiments_router
from dashboard.routes.grok import router as grok_router
from dashboard.routes.health import router as health_router
from dashboard.routes.marketplace import router as marketplace_router
from dashboard.routes.memory import router as memory_router
from dashboard.routes.mobile import router as mobile_router
from dashboard.routes.models import api_router as models_api_router
from dashboard.routes.models import router as models_router
from dashboard.routes.paperclip import router as paperclip_router
from dashboard.routes.router import router as router_status_router
from dashboard.routes.spark import router as spark_router
from dashboard.routes.swarm import router as swarm_router
from dashboard.routes.system import router as system_router
from dashboard.routes.tasks import router as tasks_router
from dashboard.routes.telegram import router as telegram_router
from dashboard.routes.thinking import router as thinking_router
from dashboard.routes.tools import router as tools_router
from dashboard.routes.voice import router as voice_router
from dashboard.routes.work_orders import router as work_orders_router
from infrastructure.router.api import router as cascade_router
def _configure_logging() -> None: def _configure_logging() -> None:
@@ -100,8 +102,8 @@ _BRIEFING_INTERVAL_HOURS = 6
async def _briefing_scheduler() -> None: async def _briefing_scheduler() -> None:
"""Background task: regenerate Timmy's briefing every 6 hours.""" """Background task: regenerate Timmy's briefing every 6 hours."""
from timmy.briefing import engine as briefing_engine
from infrastructure.notifications.push import notify_briefing_ready from infrastructure.notifications.push import notify_briefing_ready
from timmy.briefing import engine as briefing_engine
await asyncio.sleep(2) await asyncio.sleep(2)
@@ -121,9 +123,9 @@ async def _briefing_scheduler() -> None:
async def _start_chat_integrations_background() -> None: async def _start_chat_integrations_background() -> None:
"""Background task: start chat integrations without blocking startup.""" """Background task: start chat integrations without blocking startup."""
from integrations.telegram_bot.bot import telegram_bot
from integrations.chat_bridge.vendors.discord import discord_bot
from integrations.chat_bridge.registry import platform_registry from integrations.chat_bridge.registry import platform_registry
from integrations.chat_bridge.vendors.discord import discord_bot
from integrations.telegram_bot.bot import telegram_bot
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
@@ -164,9 +166,9 @@ async def _discord_token_watcher() -> None:
if discord_bot.state.name == "CONNECTED": if discord_bot.state.name == "CONNECTED":
return # Already running — stop watching return # Already running — stop watching
# 1. Check live environment variable (intentionally uses os.environ, # 1. Check settings (pydantic-settings reads env on instantiation;
# not settings, because this polls for runtime hot-reload changes) # hot-reload is handled by re-reading .env below)
token = os.environ.get("DISCORD_TOKEN", "") token = settings.discord_token
# 2. Re-read .env file for hot-reload # 2. Re-read .env file for hot-reload
if not token: if not token:
@@ -203,6 +205,7 @@ async def lifespan(app: FastAPI):
# Initialize Spark Intelligence engine # Initialize Spark Intelligence engine
from spark.engine import spark_engine from spark.engine import spark_engine
if spark_engine.enabled: if spark_engine.enabled:
logger.info("Spark Intelligence active — event capture enabled") logger.info("Spark Intelligence active — event capture enabled")
@@ -210,12 +213,17 @@ async def lifespan(app: FastAPI):
if settings.memory_prune_days > 0: if settings.memory_prune_days > 0:
try: try:
from timmy.memory.vector_store import prune_memories from timmy.memory.vector_store import prune_memories
pruned = prune_memories( pruned = prune_memories(
older_than_days=settings.memory_prune_days, older_than_days=settings.memory_prune_days,
keep_facts=settings.memory_prune_keep_facts, keep_facts=settings.memory_prune_keep_facts,
) )
if pruned: if pruned:
logger.info("Memory auto-prune: removed %d entries older than %d days", pruned, settings.memory_prune_days) logger.info(
"Memory auto-prune: removed %d entries older than %d days",
pruned,
settings.memory_prune_days,
)
except Exception as exc: except Exception as exc:
logger.debug("Memory auto-prune skipped: %s", exc) logger.debug("Memory auto-prune skipped: %s", exc)
@@ -229,7 +237,8 @@ async def lifespan(app: FastAPI):
if total_mb > settings.memory_vault_max_mb: if total_mb > settings.memory_vault_max_mb:
logger.warning( logger.warning(
"Memory vault (%.1f MB) exceeds limit (%d MB) — consider archiving old notes", "Memory vault (%.1f MB) exceeds limit (%d MB) — consider archiving old notes",
total_mb, settings.memory_vault_max_mb, total_mb,
settings.memory_vault_max_mb,
) )
except Exception as exc: except Exception as exc:
logger.debug("Vault size check skipped: %s", exc) logger.debug("Vault size check skipped: %s", exc)
@@ -284,10 +293,7 @@ def _get_cors_origins() -> list[str]:
app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health"]) app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health"])
# 2. Security Headers # 2. Security Headers
app.add_middleware( app.add_middleware(SecurityHeadersMiddleware, production=not settings.debug)
SecurityHeadersMiddleware,
production=not settings.debug
)
# 3. CSRF Protection # 3. CSRF Protection
app.add_middleware(CSRFMiddleware) app.add_middleware(CSRFMiddleware)
@@ -314,7 +320,6 @@ if static_dir.exists():
# Shared templates instance # Shared templates instance
from dashboard.templating import templates # noqa: E402 from dashboard.templating import templates # noqa: E402
# Include routers # Include routers
app.include_router(health_router) app.include_router(health_router)
app.include_router(agents_router) app.include_router(agents_router)
@@ -339,6 +344,7 @@ app.include_router(tasks_router)
app.include_router(work_orders_router) app.include_router(work_orders_router)
app.include_router(system_router) app.include_router(system_router)
app.include_router(paperclip_router) app.include_router(paperclip_router)
app.include_router(experiments_router)
app.include_router(cascade_router) app.include_router(cascade_router)

View File

@@ -1,8 +1,8 @@
"""Dashboard middleware package.""" """Dashboard middleware package."""
from .csrf import CSRFMiddleware, csrf_exempt, generate_csrf_token, validate_csrf_token from .csrf import CSRFMiddleware, csrf_exempt, generate_csrf_token, validate_csrf_token
from .security_headers import SecurityHeadersMiddleware
from .request_logging import RequestLoggingMiddleware from .request_logging import RequestLoggingMiddleware
from .security_headers import SecurityHeadersMiddleware
__all__ = [ __all__ = [
"CSRFMiddleware", "CSRFMiddleware",

View File

@@ -4,16 +4,15 @@ Provides CSRF token generation, validation, and middleware integration
to protect state-changing endpoints from cross-site request attacks. to protect state-changing endpoints from cross-site request attacks.
""" """
import secrets
import hmac
import hashlib import hashlib
from typing import Callable, Optional import hmac
import secrets
from functools import wraps from functools import wraps
from typing import Callable, Optional
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response, JSONResponse from starlette.responses import JSONResponse, Response
# Module-level set to track exempt routes # Module-level set to track exempt routes
_exempt_routes: set[str] = set() _exempt_routes: set[str] = set()
@@ -21,26 +20,27 @@ _exempt_routes: set[str] = set()
def csrf_exempt(endpoint: Callable) -> Callable: def csrf_exempt(endpoint: Callable) -> Callable:
"""Decorator to mark an endpoint as exempt from CSRF validation. """Decorator to mark an endpoint as exempt from CSRF validation.
Usage: Usage:
@app.post("/webhook") @app.post("/webhook")
@csrf_exempt @csrf_exempt
def webhook_endpoint(): def webhook_endpoint():
... ...
""" """
@wraps(endpoint) @wraps(endpoint)
async def async_wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
return await endpoint(*args, **kwargs) return await endpoint(*args, **kwargs)
@wraps(endpoint) @wraps(endpoint)
def sync_wrapper(*args, **kwargs): def sync_wrapper(*args, **kwargs):
return endpoint(*args, **kwargs) return endpoint(*args, **kwargs)
# Mark the original function as exempt # Mark the original function as exempt
endpoint._csrf_exempt = True # type: ignore endpoint._csrf_exempt = True # type: ignore
# Also mark the wrapper # Also mark the wrapper
if hasattr(endpoint, '__code__') and endpoint.__code__.co_flags & 0x80: if hasattr(endpoint, "__code__") and endpoint.__code__.co_flags & 0x80:
async_wrapper._csrf_exempt = True # type: ignore async_wrapper._csrf_exempt = True # type: ignore
return async_wrapper return async_wrapper
else: else:
@@ -50,12 +50,12 @@ def csrf_exempt(endpoint: Callable) -> Callable:
def is_csrf_exempt(endpoint: Callable) -> bool: def is_csrf_exempt(endpoint: Callable) -> bool:
"""Check if an endpoint is marked as CSRF exempt.""" """Check if an endpoint is marked as CSRF exempt."""
return getattr(endpoint, '_csrf_exempt', False) return getattr(endpoint, "_csrf_exempt", False)
def generate_csrf_token() -> str: def generate_csrf_token() -> str:
"""Generate a cryptographically secure CSRF token. """Generate a cryptographically secure CSRF token.
Returns: Returns:
A secure random token string. A secure random token string.
""" """
@@ -64,77 +64,78 @@ def generate_csrf_token() -> str:
def validate_csrf_token(token: str, expected_token: str) -> bool: def validate_csrf_token(token: str, expected_token: str) -> bool:
"""Validate a CSRF token against the expected token. """Validate a CSRF token against the expected token.
Uses constant-time comparison to prevent timing attacks. Uses constant-time comparison to prevent timing attacks.
Args: Args:
token: The token provided by the client. token: The token provided by the client.
expected_token: The expected token (from cookie/session). expected_token: The expected token (from cookie/session).
Returns: Returns:
True if the token is valid, False otherwise. True if the token is valid, False otherwise.
""" """
if not token or not expected_token: if not token or not expected_token:
return False return False
return hmac.compare_digest(token, expected_token) return hmac.compare_digest(token, expected_token)
class CSRFMiddleware(BaseHTTPMiddleware): class CSRFMiddleware(BaseHTTPMiddleware):
"""Middleware to enforce CSRF protection on state-changing requests. """Middleware to enforce CSRF protection on state-changing requests.
Safe methods (GET, HEAD, OPTIONS, TRACE) are allowed without CSRF tokens. Safe methods (GET, HEAD, OPTIONS, TRACE) are allowed without CSRF tokens.
State-changing methods (POST, PUT, DELETE, PATCH) require a valid CSRF token. State-changing methods (POST, PUT, DELETE, PATCH) require a valid CSRF token.
The token is expected to be: The token is expected to be:
- In the X-CSRF-Token header, or - In the X-CSRF-Token header, or
- In the request body as 'csrf_token', or - In the request body as 'csrf_token', or
- Matching the token in the csrf_token cookie - Matching the token in the csrf_token cookie
Usage: Usage:
app.add_middleware(CSRFMiddleware, secret="your-secret-key") app.add_middleware(CSRFMiddleware, secret="your-secret-key")
Attributes: Attributes:
secret: Secret key for token signing (optional, for future use). secret: Secret key for token signing (optional, for future use).
cookie_name: Name of the CSRF cookie. cookie_name: Name of the CSRF cookie.
header_name: Name of the CSRF header. header_name: Name of the CSRF header.
safe_methods: HTTP methods that don't require CSRF tokens. safe_methods: HTTP methods that don't require CSRF tokens.
""" """
SAFE_METHODS = {"GET", "HEAD", "OPTIONS", "TRACE"} SAFE_METHODS = {"GET", "HEAD", "OPTIONS", "TRACE"}
def __init__( def __init__(
self, self,
app, app,
secret: Optional[str] = None, secret: Optional[str] = None,
cookie_name: str = "csrf_token", cookie_name: str = "csrf_token",
header_name: str = "X-CSRF-Token", header_name: str = "X-CSRF-Token",
form_field: str = "csrf_token" form_field: str = "csrf_token",
): ):
super().__init__(app) super().__init__(app)
self.secret = secret self.secret = secret
self.cookie_name = cookie_name self.cookie_name = cookie_name
self.header_name = header_name self.header_name = header_name
self.form_field = form_field self.form_field = form_field
async def dispatch(self, request: Request, call_next) -> Response: async def dispatch(self, request: Request, call_next) -> Response:
"""Process the request and enforce CSRF protection. """Process the request and enforce CSRF protection.
For safe methods: Set a CSRF token cookie if not present. For safe methods: Set a CSRF token cookie if not present.
For unsafe methods: Validate the CSRF token. For unsafe methods: Validate the CSRF token.
""" """
# Bypass CSRF if explicitly disabled (e.g. in tests) # Bypass CSRF if explicitly disabled (e.g. in tests)
from config import settings from config import settings
if settings.timmy_disable_csrf: if settings.timmy_disable_csrf:
return await call_next(request) return await call_next(request)
# Get existing CSRF token from cookie # Get existing CSRF token from cookie
csrf_cookie = request.cookies.get(self.cookie_name) csrf_cookie = request.cookies.get(self.cookie_name)
# For safe methods, just ensure a token exists # For safe methods, just ensure a token exists
if request.method in self.SAFE_METHODS: if request.method in self.SAFE_METHODS:
response = await call_next(request) response = await call_next(request)
# Set CSRF token cookie if not present # Set CSRF token cookie if not present
if not csrf_cookie: if not csrf_cookie:
new_token = generate_csrf_token() new_token = generate_csrf_token()
@@ -144,15 +145,15 @@ class CSRFMiddleware(BaseHTTPMiddleware):
httponly=False, # Must be readable by JavaScript httponly=False, # Must be readable by JavaScript
secure=settings.csrf_cookie_secure, secure=settings.csrf_cookie_secure,
samesite="Lax", samesite="Lax",
max_age=86400 # 24 hours max_age=86400, # 24 hours
) )
return response return response
# For unsafe methods, check if route is exempt first # For unsafe methods, check if route is exempt first
# Note: We need to let the request proceed and check at response time # Note: We need to let the request proceed and check at response time
# since FastAPI routes are resolved after middleware # since FastAPI routes are resolved after middleware
# Try to validate token early # Try to validate token early
if not await self._validate_request(request, csrf_cookie): if not await self._validate_request(request, csrf_cookie):
# Check if this might be an exempt route by checking path patterns # Check if this might be an exempt route by checking path patterns
@@ -164,33 +165,34 @@ class CSRFMiddleware(BaseHTTPMiddleware):
content={ content={
"error": "CSRF validation failed", "error": "CSRF validation failed",
"code": "CSRF_INVALID", "code": "CSRF_INVALID",
"message": "Missing or invalid CSRF token. Include the token from the csrf_token cookie in the X-CSRF-Token header or as a form field." "message": "Missing or invalid CSRF token. Include the token from the csrf_token cookie in the X-CSRF-Token header or as a form field.",
} },
) )
return await call_next(request) return await call_next(request)
def _is_likely_exempt(self, path: str) -> bool: def _is_likely_exempt(self, path: str) -> bool:
"""Check if a path is likely to be CSRF exempt. """Check if a path is likely to be CSRF exempt.
Common patterns like webhooks, API endpoints, etc. Common patterns like webhooks, API endpoints, etc.
Uses path normalization and exact/prefix matching to prevent bypasses. Uses path normalization and exact/prefix matching to prevent bypasses.
Args: Args:
path: The request path. path: The request path.
Returns: Returns:
True if the path is likely exempt. True if the path is likely exempt.
""" """
# 1. Normalize path to prevent /webhook/../ bypasses # 1. Normalize path to prevent /webhook/../ bypasses
# Use posixpath for consistent behavior on all platforms # Use posixpath for consistent behavior on all platforms
import posixpath import posixpath
normalized_path = posixpath.normpath(path) normalized_path = posixpath.normpath(path)
# Ensure it starts with / for comparison # Ensure it starts with / for comparison
if not normalized_path.startswith("/"): if not normalized_path.startswith("/"):
normalized_path = "/" + normalized_path normalized_path = "/" + normalized_path
# Add back trailing slash if it was present in original path # Add back trailing slash if it was present in original path
# to ensure prefix matching behaves as expected # to ensure prefix matching behaves as expected
if path.endswith("/") and not normalized_path.endswith("/"): if path.endswith("/") and not normalized_path.endswith("/"):
@@ -200,15 +202,15 @@ class CSRFMiddleware(BaseHTTPMiddleware):
# Patterns ending with / are prefix-matched # Patterns ending with / are prefix-matched
# Patterns NOT ending with / are exact-matched # Patterns NOT ending with / are exact-matched
exempt_patterns = [ exempt_patterns = [
"/webhook/", # Prefix match (e.g., /webhook/stripe) "/webhook/", # Prefix match (e.g., /webhook/stripe)
"/webhook", # Exact match "/webhook", # Exact match
"/api/v1/", # Prefix match "/api/v1/", # Prefix match
"/lightning/webhook/", # Prefix match "/lightning/webhook/", # Prefix match
"/lightning/webhook", # Exact match "/lightning/webhook", # Exact match
"/_internal/", # Prefix match "/_internal/", # Prefix match
"/_internal", # Exact match "/_internal", # Exact match
] ]
for pattern in exempt_patterns: for pattern in exempt_patterns:
if pattern.endswith("/"): if pattern.endswith("/"):
if normalized_path.startswith(pattern): if normalized_path.startswith(pattern):
@@ -216,20 +218,20 @@ class CSRFMiddleware(BaseHTTPMiddleware):
else: else:
if normalized_path == pattern: if normalized_path == pattern:
return True return True
return False return False
async def _validate_request(self, request: Request, csrf_cookie: Optional[str]) -> bool: async def _validate_request(self, request: Request, csrf_cookie: Optional[str]) -> bool:
"""Validate the CSRF token in the request. """Validate the CSRF token in the request.
Checks for token in: Checks for token in:
1. X-CSRF-Token header 1. X-CSRF-Token header
2. csrf_token form field 2. csrf_token form field
Args: Args:
request: The incoming request. request: The incoming request.
csrf_cookie: The expected token from the cookie. csrf_cookie: The expected token from the cookie.
Returns: Returns:
True if the token is valid, False otherwise. True if the token is valid, False otherwise.
""" """
@@ -241,11 +243,14 @@ class CSRFMiddleware(BaseHTTPMiddleware):
header_token = request.headers.get(self.header_name) header_token = request.headers.get(self.header_name)
if header_token and validate_csrf_token(header_token, csrf_cookie): if header_token and validate_csrf_token(header_token, csrf_cookie):
return True return True
# If no header token, try form data (for non-JSON POSTs) # If no header token, try form data (for non-JSON POSTs)
# Check Content-Type to avoid hanging on non-form requests # Check Content-Type to avoid hanging on non-form requests
content_type = request.headers.get("Content-Type", "") content_type = request.headers.get("Content-Type", "")
if "application/x-www-form-urlencoded" in content_type or "multipart/form-data" in content_type: if (
"application/x-www-form-urlencoded" in content_type
or "multipart/form-data" in content_type
):
try: try:
form_data = await request.form() form_data = await request.form()
form_token = form_data.get(self.form_field) form_token = form_data.get(self.form_field)
@@ -254,5 +259,5 @@ class CSRFMiddleware(BaseHTTPMiddleware):
except Exception: except Exception:
# Error parsing form data, treat as invalid # Error parsing form data, treat as invalid
pass pass
return False return False

View File

@@ -4,22 +4,21 @@ Logs HTTP requests with timing, status codes, and client information
for monitoring and debugging purposes. for monitoring and debugging purposes.
""" """
import logging
import time import time
import uuid import uuid
import logging from typing import List, Optional
from typing import Optional, List
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response from starlette.responses import Response
logger = logging.getLogger("timmy.requests") logger = logging.getLogger("timmy.requests")
class RequestLoggingMiddleware(BaseHTTPMiddleware): class RequestLoggingMiddleware(BaseHTTPMiddleware):
"""Middleware to log all HTTP requests. """Middleware to log all HTTP requests.
Logs the following information for each request: Logs the following information for each request:
- HTTP method and path - HTTP method and path
- Response status code - Response status code
@@ -27,60 +26,55 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
- Client IP address - Client IP address
- User-Agent header - User-Agent header
- Correlation ID for tracing - Correlation ID for tracing
Usage: Usage:
app.add_middleware(RequestLoggingMiddleware) app.add_middleware(RequestLoggingMiddleware)
# Skip certain paths: # Skip certain paths:
app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health", "/metrics"]) app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health", "/metrics"])
Attributes: Attributes:
skip_paths: List of URL paths to skip logging. skip_paths: List of URL paths to skip logging.
log_level: Logging level for successful requests. log_level: Logging level for successful requests.
""" """
def __init__( def __init__(self, app, skip_paths: Optional[List[str]] = None, log_level: int = logging.INFO):
self,
app,
skip_paths: Optional[List[str]] = None,
log_level: int = logging.INFO
):
super().__init__(app) super().__init__(app)
self.skip_paths = set(skip_paths or []) self.skip_paths = set(skip_paths or [])
self.log_level = log_level self.log_level = log_level
async def dispatch(self, request: Request, call_next) -> Response: async def dispatch(self, request: Request, call_next) -> Response:
"""Log the request and response details. """Log the request and response details.
Args: Args:
request: The incoming request. request: The incoming request.
call_next: Callable to get the response from downstream. call_next: Callable to get the response from downstream.
Returns: Returns:
The response from downstream. The response from downstream.
""" """
# Check if we should skip logging this path # Check if we should skip logging this path
if request.url.path in self.skip_paths: if request.url.path in self.skip_paths:
return await call_next(request) return await call_next(request)
# Generate correlation ID # Generate correlation ID
correlation_id = str(uuid.uuid4())[:8] correlation_id = str(uuid.uuid4())[:8]
request.state.correlation_id = correlation_id request.state.correlation_id = correlation_id
# Record start time # Record start time
start_time = time.time() start_time = time.time()
# Get client info # Get client info
client_ip = self._get_client_ip(request) client_ip = self._get_client_ip(request)
user_agent = request.headers.get("user-agent", "-") user_agent = request.headers.get("user-agent", "-")
try: try:
# Process the request # Process the request
response = await call_next(request) response = await call_next(request)
# Calculate duration # Calculate duration
duration_ms = (time.time() - start_time) * 1000 duration_ms = (time.time() - start_time) * 1000
# Log the request # Log the request
self._log_request( self._log_request(
method=request.method, method=request.method,
@@ -89,14 +83,14 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
duration_ms=duration_ms, duration_ms=duration_ms,
client_ip=client_ip, client_ip=client_ip,
user_agent=user_agent, user_agent=user_agent,
correlation_id=correlation_id correlation_id=correlation_id,
) )
# Add correlation ID to response headers # Add correlation ID to response headers
response.headers["X-Correlation-ID"] = correlation_id response.headers["X-Correlation-ID"] = correlation_id
return response return response
except Exception as exc: except Exception as exc:
# Calculate duration even for failed requests # Calculate duration even for failed requests
duration_ms = (time.time() - start_time) * 1000 duration_ms = (time.time() - start_time) * 1000
@@ -110,6 +104,7 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
# Auto-escalate: create bug report task from unhandled exception # Auto-escalate: create bug report task from unhandled exception
try: try:
from infrastructure.error_capture import capture_error from infrastructure.error_capture import capture_error
capture_error( capture_error(
exc, exc,
source="http", source="http",
@@ -126,16 +121,16 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
# Re-raise the exception # Re-raise the exception
raise raise
def _get_client_ip(self, request: Request) -> str: def _get_client_ip(self, request: Request) -> str:
"""Extract the client IP address from the request. """Extract the client IP address from the request.
Checks X-Forwarded-For and X-Real-IP headers first for proxied requests, Checks X-Forwarded-For and X-Real-IP headers first for proxied requests,
falls back to the direct client IP. falls back to the direct client IP.
Args: Args:
request: The incoming request. request: The incoming request.
Returns: Returns:
Client IP address string. Client IP address string.
""" """
@@ -144,17 +139,17 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
if forwarded_for: if forwarded_for:
# X-Forwarded-For can contain multiple IPs, take the first one # X-Forwarded-For can contain multiple IPs, take the first one
return forwarded_for.split(",")[0].strip() return forwarded_for.split(",")[0].strip()
real_ip = request.headers.get("x-real-ip") real_ip = request.headers.get("x-real-ip")
if real_ip: if real_ip:
return real_ip return real_ip
# Fall back to direct connection # Fall back to direct connection
if request.client: if request.client:
return request.client.host return request.client.host
return "-" return "-"
def _log_request( def _log_request(
self, self,
method: str, method: str,
@@ -163,10 +158,10 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
duration_ms: float, duration_ms: float,
client_ip: str, client_ip: str,
user_agent: str, user_agent: str,
correlation_id: str correlation_id: str,
) -> None: ) -> None:
"""Format and log the request details. """Format and log the request details.
Args: Args:
method: HTTP method. method: HTTP method.
path: Request path. path: Request path.
@@ -182,14 +177,14 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
level = logging.ERROR level = logging.ERROR
elif status_code >= 400: elif status_code >= 400:
level = logging.WARNING level = logging.WARNING
message = ( message = (
f"[{correlation_id}] {method} {path} - {status_code} " f"[{correlation_id}] {method} {path} - {status_code} "
f"- {duration_ms:.2f}ms - {client_ip}" f"- {duration_ms:.2f}ms - {client_ip}"
) )
# Add user agent for non-health requests # Add user agent for non-health requests
if path not in self.skip_paths: if path not in self.skip_paths:
message += f" - {user_agent[:50]}" message += f" - {user_agent[:50]}"
logger.log(level, message) logger.log(level, message)

View File

@@ -4,6 +4,8 @@ Adds common security headers to all HTTP responses to improve
application security posture against various attacks. application security posture against various attacks.
""" """
from typing import Optional
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response from starlette.responses import Response
@@ -11,7 +13,7 @@ from starlette.responses import Response
class SecurityHeadersMiddleware(BaseHTTPMiddleware): class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"""Middleware to add security headers to all responses. """Middleware to add security headers to all responses.
Adds the following headers: Adds the following headers:
- X-Content-Type-Options: Prevents MIME type sniffing - X-Content-Type-Options: Prevents MIME type sniffing
- X-Frame-Options: Prevents clickjacking - X-Frame-Options: Prevents clickjacking
@@ -20,41 +22,41 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
- Permissions-Policy: Restricts feature access - Permissions-Policy: Restricts feature access
- Content-Security-Policy: Mitigates XSS and data injection - Content-Security-Policy: Mitigates XSS and data injection
- Strict-Transport-Security: Enforces HTTPS (production only) - Strict-Transport-Security: Enforces HTTPS (production only)
Usage: Usage:
app.add_middleware(SecurityHeadersMiddleware) app.add_middleware(SecurityHeadersMiddleware)
# Or with production settings: # Or with production settings:
app.add_middleware(SecurityHeadersMiddleware, production=True) app.add_middleware(SecurityHeadersMiddleware, production=True)
Attributes: Attributes:
production: If True, adds HSTS header for HTTPS enforcement. production: If True, adds HSTS header for HTTPS enforcement.
csp_report_only: If True, sends CSP in report-only mode. csp_report_only: If True, sends CSP in report-only mode.
""" """
def __init__( def __init__(
self, self,
app, app,
production: bool = False, production: bool = False,
csp_report_only: bool = False, csp_report_only: bool = False,
custom_csp: str = None custom_csp: Optional[str] = None,
): ):
super().__init__(app) super().__init__(app)
self.production = production self.production = production
self.csp_report_only = csp_report_only self.csp_report_only = csp_report_only
# Build CSP directive # Build CSP directive
self.csp_directive = custom_csp or self._build_csp() self.csp_directive = custom_csp or self._build_csp()
def _build_csp(self) -> str: def _build_csp(self) -> str:
"""Build the Content-Security-Policy directive. """Build the Content-Security-Policy directive.
Creates a restrictive default policy that allows: Creates a restrictive default policy that allows:
- Same-origin resources by default - Same-origin resources by default
- Inline scripts/styles (needed for HTMX/Bootstrap) - Inline scripts/styles (needed for HTMX/Bootstrap)
- Data URIs for images - Data URIs for images
- WebSocket connections - WebSocket connections
Returns: Returns:
CSP directive string. CSP directive string.
""" """
@@ -73,25 +75,25 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"form-action 'self'", "form-action 'self'",
] ]
return "; ".join(directives) return "; ".join(directives)
def _add_security_headers(self, response: Response) -> None: def _add_security_headers(self, response: Response) -> None:
"""Add security headers to a response. """Add security headers to a response.
Args: Args:
response: The response to add headers to. response: The response to add headers to.
""" """
# Prevent MIME type sniffing # Prevent MIME type sniffing
response.headers["X-Content-Type-Options"] = "nosniff" response.headers["X-Content-Type-Options"] = "nosniff"
# Prevent clickjacking # Prevent clickjacking
response.headers["X-Frame-Options"] = "SAMEORIGIN" response.headers["X-Frame-Options"] = "SAMEORIGIN"
# Enable XSS protection (legacy browsers) # Enable XSS protection (legacy browsers)
response.headers["X-XSS-Protection"] = "1; mode=block" response.headers["X-XSS-Protection"] = "1; mode=block"
# Control referrer information # Control referrer information
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
# Restrict browser features # Restrict browser features
response.headers["Permissions-Policy"] = ( response.headers["Permissions-Policy"] = (
"camera=(), " "camera=(), "
@@ -103,38 +105,41 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"gyroscope=(), " "gyroscope=(), "
"accelerometer=()" "accelerometer=()"
) )
# Content Security Policy # Content Security Policy
csp_header = "Content-Security-Policy-Report-Only" if self.csp_report_only else "Content-Security-Policy" csp_header = (
"Content-Security-Policy-Report-Only"
if self.csp_report_only
else "Content-Security-Policy"
)
response.headers[csp_header] = self.csp_directive response.headers[csp_header] = self.csp_directive
# HTTPS enforcement (production only) # HTTPS enforcement (production only)
if self.production: if self.production:
response.headers["Strict-Transport-Security"] = ( response.headers[
"max-age=31536000; includeSubDomains; preload" "Strict-Transport-Security"
) ] = "max-age=31536000; includeSubDomains; preload"
async def dispatch(self, request: Request, call_next) -> Response: async def dispatch(self, request: Request, call_next) -> Response:
"""Add security headers to the response. """Add security headers to the response.
Args: Args:
request: The incoming request. request: The incoming request.
call_next: Callable to get the response from downstream. call_next: Callable to get the response from downstream.
Returns: Returns:
Response with security headers added. Response with security headers added.
""" """
try: try:
response = await call_next(request) response = await call_next(request)
self._add_security_headers(response)
return response
except Exception: except Exception:
# Create a response for the error with security headers import logging
from starlette.responses import PlainTextResponse
response = PlainTextResponse( logging.getLogger(__name__).debug(
content="Internal Server Error", "Upstream error in security headers middleware", exc_info=True
status_code=500
) )
self._add_security_headers(response) from starlette.responses import PlainTextResponse
# Return the error response with headers (don't re-raise)
return response response = PlainTextResponse("Internal Server Error", status_code=500)
self._add_security_headers(response)
return response

View File

@@ -1,24 +1,27 @@
from datetime import date, datetime
from datetime import datetime, date
from enum import Enum as PyEnum from enum import Enum as PyEnum
from sqlalchemy import (
Column, Integer, String, DateTime, Boolean, Enum as SQLEnum, from sqlalchemy import JSON, Boolean, Column, Date, DateTime
Date, ForeignKey, Index, JSON from sqlalchemy import Enum as SQLEnum
) from sqlalchemy import ForeignKey, Index, Integer, String
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from .database import Base # Assuming a shared Base in models/database.py from .database import Base # Assuming a shared Base in models/database.py
class TaskState(str, PyEnum): class TaskState(str, PyEnum):
LATER = "LATER" LATER = "LATER"
NEXT = "NEXT" NEXT = "NEXT"
NOW = "NOW" NOW = "NOW"
DONE = "DONE" DONE = "DONE"
DEFERRED = "DEFERRED" # Task pushed to tomorrow DEFERRED = "DEFERRED" # Task pushed to tomorrow
class TaskCertainty(str, PyEnum): class TaskCertainty(str, PyEnum):
FUZZY = "FUZZY" # An intention without a time FUZZY = "FUZZY" # An intention without a time
SOFT = "SOFT" # A flexible task with a time SOFT = "SOFT" # A flexible task with a time
HARD = "HARD" # A fixed meeting/appointment HARD = "HARD" # A fixed meeting/appointment
class Task(Base): class Task(Base):
__tablename__ = "tasks" __tablename__ = "tasks"
@@ -29,7 +32,7 @@ class Task(Base):
state = Column(SQLEnum(TaskState), default=TaskState.LATER, nullable=False, index=True) state = Column(SQLEnum(TaskState), default=TaskState.LATER, nullable=False, index=True)
certainty = Column(SQLEnum(TaskCertainty), default=TaskCertainty.SOFT, nullable=False) certainty = Column(SQLEnum(TaskCertainty), default=TaskCertainty.SOFT, nullable=False)
is_mit = Column(Boolean, default=False, nullable=False) # 1-3 per day is_mit = Column(Boolean, default=False, nullable=False) # 1-3 per day
sort_order = Column(Integer, default=0, nullable=False) sort_order = Column(Integer, default=0, nullable=False)
@@ -42,7 +45,8 @@ class Task(Base):
created_at = Column(DateTime, default=datetime.utcnow, nullable=False) created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
__table_args__ = (Index('ix_task_state_order', 'state', 'sort_order'),) __table_args__ = (Index("ix_task_state_order", "state", "sort_order"),)
class JournalEntry(Base): class JournalEntry(Base):
__tablename__ = "journal_entries" __tablename__ = "journal_entries"

View File

@@ -1,17 +1,16 @@
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session from sqlalchemy.orm import Session, sessionmaker
SQLALCHEMY_DATABASE_URL = "sqlite:///./data/timmy_calm.db" SQLALCHEMY_DATABASE_URL = "sqlite:///./data/timmy_calm.db"
engine = create_engine( engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base() Base = declarative_base()
def create_tables(): def create_tables():
"""Create all tables defined by models that have imported Base.""" """Create all tables defined by models that have imported Base."""
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)

View File

@@ -5,9 +5,9 @@ from datetime import datetime
from fastapi import APIRouter, Form, Request from fastapi import APIRouter, Form, Request
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from timmy.session import chat as agent_chat
from dashboard.store import message_log from dashboard.store import message_log
from dashboard.templating import templates from dashboard.templating import templates
from timmy.session import chat as agent_chat
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -38,9 +38,7 @@ async def list_agents():
@router.get("/default/panel", response_class=HTMLResponse) @router.get("/default/panel", response_class=HTMLResponse)
async def agent_panel(request: Request): async def agent_panel(request: Request):
"""Chat panel — for HTMX main-panel swaps.""" """Chat panel — for HTMX main-panel swaps."""
return templates.TemplateResponse( return templates.TemplateResponse(request, "partials/agent_panel_chat.html", {"agent": None})
request, "partials/agent_panel_chat.html", {"agent": None}
)
@router.get("/default/history", response_class=HTMLResponse) @router.get("/default/history", response_class=HTMLResponse)
@@ -77,7 +75,9 @@ async def chat_agent(request: Request, message: str = Form(...)):
message_log.append(role="user", content=message, timestamp=timestamp, source="browser") message_log.append(role="user", content=message, timestamp=timestamp, source="browser")
if response_text is not None: if response_text is not None:
message_log.append(role="agent", content=response_text, timestamp=timestamp, source="browser") message_log.append(
role="agent", content=response_text, timestamp=timestamp, source="browser"
)
elif error_text: elif error_text:
message_log.append(role="error", content=error_text, timestamp=timestamp, source="browser") message_log.append(role="error", content=error_text, timestamp=timestamp, source="browser")

View File

@@ -12,9 +12,10 @@ from datetime import datetime, timezone
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
from fastapi.responses import HTMLResponse, JSONResponse from fastapi.responses import HTMLResponse, JSONResponse
from timmy.briefing import Briefing, engine as briefing_engine
from timmy import approvals as approval_store
from dashboard.templating import templates from dashboard.templating import templates
from timmy import approvals as approval_store
from timmy.briefing import Briefing
from timmy.briefing import engine as briefing_engine
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -1,4 +1,3 @@
import logging import logging
from datetime import date, datetime from datetime import date, datetime
from typing import List, Optional from typing import List, Optional
@@ -8,7 +7,7 @@ from fastapi.responses import HTMLResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from dashboard.models.calm import JournalEntry, Task, TaskCertainty, TaskState from dashboard.models.calm import JournalEntry, Task, TaskCertainty, TaskState
from dashboard.models.database import SessionLocal, engine, get_db, create_tables from dashboard.models.database import SessionLocal, create_tables, engine, get_db
from dashboard.templating import templates from dashboard.templating import templates
# Ensure CALM tables exist (safe to call multiple times) # Ensure CALM tables exist (safe to call multiple times)
@@ -23,11 +22,19 @@ router = APIRouter(tags=["calm"])
def get_now_task(db: Session) -> Optional[Task]: def get_now_task(db: Session) -> Optional[Task]:
return db.query(Task).filter(Task.state == TaskState.NOW).first() return db.query(Task).filter(Task.state == TaskState.NOW).first()
def get_next_task(db: Session) -> Optional[Task]: def get_next_task(db: Session) -> Optional[Task]:
return db.query(Task).filter(Task.state == TaskState.NEXT).first() return db.query(Task).filter(Task.state == TaskState.NEXT).first()
def get_later_tasks(db: Session) -> List[Task]: def get_later_tasks(db: Session) -> List[Task]:
return db.query(Task).filter(Task.state == TaskState.LATER).order_by(Task.is_mit.desc(), Task.sort_order).all() return (
db.query(Task)
.filter(Task.state == TaskState.LATER)
.order_by(Task.is_mit.desc(), Task.sort_order)
.all()
)
def promote_tasks(db: Session): def promote_tasks(db: Session):
# Ensure only one NOW task exists. If multiple, demote extras to NEXT. # Ensure only one NOW task exists. If multiple, demote extras to NEXT.
@@ -38,7 +45,7 @@ def promote_tasks(db: Session):
for task_to_demote in now_tasks[1:]: for task_to_demote in now_tasks[1:]:
task_to_demote.state = TaskState.NEXT task_to_demote.state = TaskState.NEXT
db.add(task_to_demote) db.add(task_to_demote)
db.flush() # Make changes visible db.flush() # Make changes visible
# If no NOW task, promote NEXT to NOW # If no NOW task, promote NEXT to NOW
current_now = db.query(Task).filter(Task.state == TaskState.NOW).first() current_now = db.query(Task).filter(Task.state == TaskState.NOW).first()
@@ -47,12 +54,17 @@ def promote_tasks(db: Session):
if next_task: if next_task:
next_task.state = TaskState.NOW next_task.state = TaskState.NOW
db.add(next_task) db.add(next_task)
db.flush() # Make changes visible db.flush() # Make changes visible
# If no NEXT task, promote highest priority LATER to NEXT # If no NEXT task, promote highest priority LATER to NEXT
current_next = db.query(Task).filter(Task.state == TaskState.NEXT).first() current_next = db.query(Task).filter(Task.state == TaskState.NEXT).first()
if not current_next: if not current_next:
later_tasks = db.query(Task).filter(Task.state == TaskState.LATER).order_by(Task.is_mit.desc(), Task.sort_order).all() later_tasks = (
db.query(Task)
.filter(Task.state == TaskState.LATER)
.order_by(Task.is_mit.desc(), Task.sort_order)
.all()
)
if later_tasks: if later_tasks:
later_tasks[0].state = TaskState.NEXT later_tasks[0].state = TaskState.NEXT
db.add(later_tasks[0]) db.add(later_tasks[0])
@@ -60,14 +72,17 @@ def promote_tasks(db: Session):
db.commit() db.commit()
# Endpoints # Endpoints
@router.get("/calm", response_class=HTMLResponse) @router.get("/calm", response_class=HTMLResponse)
async def get_calm_view(request: Request, db: Session = Depends(get_db)): async def get_calm_view(request: Request, db: Session = Depends(get_db)):
now_task = get_now_task(db) now_task = get_now_task(db)
next_task = get_next_task(db) next_task = get_next_task(db)
later_tasks_count = len(get_later_tasks(db)) later_tasks_count = len(get_later_tasks(db))
return templates.TemplateResponse(request, "calm/calm_view.html", {"now_task": now_task, return templates.TemplateResponse(
request,
"calm/calm_view.html",
{
"now_task": now_task,
"next_task": next_task, "next_task": next_task,
"later_tasks_count": later_tasks_count, "later_tasks_count": later_tasks_count,
}, },
@@ -101,7 +116,7 @@ async def post_morning_ritual(
task = Task( task = Task(
title=mit_title, title=mit_title,
is_mit=True, is_mit=True,
state=TaskState.LATER, # Initially LATER, will be promoted state=TaskState.LATER, # Initially LATER, will be promoted
certainty=TaskCertainty.SOFT, certainty=TaskCertainty.SOFT,
) )
db.add(task) db.add(task)
@@ -113,7 +128,7 @@ async def post_morning_ritual(
db.add(journal_entry) db.add(journal_entry)
# Create other tasks # Create other tasks
for task_title in other_tasks.split('\n'): for task_title in other_tasks.split("\n"):
task_title = task_title.strip() task_title = task_title.strip()
if task_title: if task_title:
task = Task( task = Task(
@@ -128,20 +143,29 @@ async def post_morning_ritual(
# Set initial NOW/NEXT states # Set initial NOW/NEXT states
# Set initial NOW/NEXT states after all tasks are created # Set initial NOW/NEXT states after all tasks are created
if not get_now_task(db) and not get_next_task(db): if not get_now_task(db) and not get_next_task(db):
later_tasks = db.query(Task).filter(Task.state == TaskState.LATER).order_by(Task.is_mit.desc(), Task.sort_order).all() later_tasks = (
db.query(Task)
.filter(Task.state == TaskState.LATER)
.order_by(Task.is_mit.desc(), Task.sort_order)
.all()
)
if later_tasks: if later_tasks:
# Set the highest priority LATER task to NOW # Set the highest priority LATER task to NOW
later_tasks[0].state = TaskState.NOW later_tasks[0].state = TaskState.NOW
db.add(later_tasks[0]) db.add(later_tasks[0])
db.flush() # Flush to make the change visible for the next query db.flush() # Flush to make the change visible for the next query
# Set the next highest priority LATER task to NEXT # Set the next highest priority LATER task to NEXT
if len(later_tasks) > 1: if len(later_tasks) > 1:
later_tasks[1].state = TaskState.NEXT later_tasks[1].state = TaskState.NEXT
db.add(later_tasks[1]) db.add(later_tasks[1])
db.commit() # Commit changes after initial NOW/NEXT setup db.commit() # Commit changes after initial NOW/NEXT setup
return templates.TemplateResponse(request, "calm/calm_view.html", {"now_task": get_now_task(db), return templates.TemplateResponse(
request,
"calm/calm_view.html",
{
"now_task": get_now_task(db),
"next_task": get_next_task(db), "next_task": get_next_task(db),
"later_tasks_count": len(get_later_tasks(db)), "later_tasks_count": len(get_later_tasks(db)),
}, },
@@ -154,7 +178,8 @@ async def get_evening_ritual_form(request: Request, db: Session = Depends(get_db
if not journal_entry: if not journal_entry:
raise HTTPException(status_code=404, detail="No journal entry for today") raise HTTPException(status_code=404, detail="No journal entry for today")
return templates.TemplateResponse( return templates.TemplateResponse(
"calm/evening_ritual_form.html", {"request": request, "journal_entry": journal_entry}) "calm/evening_ritual_form.html", {"request": request, "journal_entry": journal_entry}
)
@router.post("/calm/ritual/evening", response_class=HTMLResponse) @router.post("/calm/ritual/evening", response_class=HTMLResponse)
@@ -175,9 +200,13 @@ async def post_evening_ritual(
db.add(journal_entry) db.add(journal_entry)
# Archive any remaining active tasks # Archive any remaining active tasks
active_tasks = db.query(Task).filter(Task.state.in_([TaskState.NOW, TaskState.NEXT, TaskState.LATER])).all() active_tasks = (
db.query(Task)
.filter(Task.state.in_([TaskState.NOW, TaskState.NEXT, TaskState.LATER]))
.all()
)
for task in active_tasks: for task in active_tasks:
task.state = TaskState.DEFERRED # Or DONE, depending on desired archiving logic task.state = TaskState.DEFERRED # Or DONE, depending on desired archiving logic
task.deferred_at = datetime.utcnow() task.deferred_at = datetime.utcnow()
db.add(task) db.add(task)
@@ -221,7 +250,7 @@ async def start_task(
): ):
current_now_task = get_now_task(db) current_now_task = get_now_task(db)
if current_now_task and current_now_task.id != task_id: if current_now_task and current_now_task.id != task_id:
current_now_task.state = TaskState.NEXT # Demote current NOW to NEXT current_now_task.state = TaskState.NEXT # Demote current NOW to NEXT
db.add(current_now_task) db.add(current_now_task)
task = db.query(Task).filter(Task.id == task_id).first() task = db.query(Task).filter(Task.id == task_id).first()
@@ -322,7 +351,7 @@ async def reorder_tasks(
): ):
# Reorder LATER tasks # Reorder LATER tasks
if later_task_ids: if later_task_ids:
ids_in_order = [int(x.strip()) for x in later_task_ids.split(',') if x.strip()] ids_in_order = [int(x.strip()) for x in later_task_ids.split(",") if x.strip()]
for index, task_id in enumerate(ids_in_order): for index, task_id in enumerate(ids_in_order):
task = db.query(Task).filter(Task.id == task_id).first() task = db.query(Task).filter(Task.id == task_id).first()
if task and task.state == TaskState.LATER: if task and task.state == TaskState.LATER:
@@ -332,16 +361,18 @@ async def reorder_tasks(
# Handle NEXT task if it's part of the reorder (e.g., moved from LATER to NEXT explicitly) # Handle NEXT task if it's part of the reorder (e.g., moved from LATER to NEXT explicitly)
if next_task_id: if next_task_id:
task = db.query(Task).filter(Task.id == next_task_id).first() task = db.query(Task).filter(Task.id == next_task_id).first()
if task and task.state == TaskState.LATER: # Only if it was a LATER task being promoted manually if (
task and task.state == TaskState.LATER
): # Only if it was a LATER task being promoted manually
# Demote current NEXT to LATER # Demote current NEXT to LATER
current_next = get_next_task(db) current_next = get_next_task(db)
if current_next: if current_next:
current_next.state = TaskState.LATER current_next.state = TaskState.LATER
current_next.sort_order = len(get_later_tasks(db)) # Add to end of later current_next.sort_order = len(get_later_tasks(db)) # Add to end of later
db.add(current_next) db.add(current_next)
task.state = TaskState.NEXT task.state = TaskState.NEXT
task.sort_order = 0 # NEXT tasks don't really need sort_order, but for consistency task.sort_order = 0 # NEXT tasks don't really need sort_order, but for consistency
db.add(task) db.add(task)
db.commit() db.commit()

View File

@@ -27,12 +27,13 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["chat-api"]) router = APIRouter(prefix="/api", tags=["chat-api"])
_UPLOAD_DIR = os.path.join("data", "chat-uploads") _UPLOAD_DIR = str(Path(settings.repo_root) / "data" / "chat-uploads")
_MAX_UPLOAD_SIZE = 50 * 1024 * 1024 # 50 MB _MAX_UPLOAD_SIZE = 50 * 1024 * 1024 # 50 MB
# ── POST /api/chat ──────────────────────────────────────────────────────────── # ── POST /api/chat ────────────────────────────────────────────────────────────
@router.post("/chat") @router.post("/chat")
async def api_chat(request: Request): async def api_chat(request: Request):
"""Accept a JSON chat payload and return the agent's reply. """Accept a JSON chat payload and return the agent's reply.
@@ -65,7 +66,8 @@ async def api_chat(request: Request):
# Handle multimodal content arrays — extract text parts # Handle multimodal content arrays — extract text parts
if isinstance(content, list): if isinstance(content, list):
text_parts = [ text_parts = [
p.get("text", "") for p in content p.get("text", "")
for p in content
if isinstance(p, dict) and p.get("type") == "text" if isinstance(p, dict) and p.get("type") == "text"
] ]
last_user_msg = " ".join(text_parts).strip() last_user_msg = " ".join(text_parts).strip()
@@ -109,6 +111,7 @@ async def api_chat(request: Request):
# ── POST /api/upload ────────────────────────────────────────────────────────── # ── POST /api/upload ──────────────────────────────────────────────────────────
@router.post("/upload") @router.post("/upload")
async def api_upload(file: UploadFile = File(...)): async def api_upload(file: UploadFile = File(...)):
"""Accept a file upload and return its URL. """Accept a file upload and return its URL.
@@ -147,6 +150,7 @@ async def api_upload(file: UploadFile = File(...)):
# ── GET /api/chat/history ──────────────────────────────────────────────────── # ── GET /api/chat/history ────────────────────────────────────────────────────
@router.get("/chat/history") @router.get("/chat/history")
async def api_chat_history(): async def api_chat_history():
"""Return the in-memory chat history as JSON.""" """Return the in-memory chat history as JSON."""
@@ -165,6 +169,7 @@ async def api_chat_history():
# ── DELETE /api/chat/history ────────────────────────────────────────────────── # ── DELETE /api/chat/history ──────────────────────────────────────────────────
@router.delete("/chat/history") @router.delete("/chat/history")
async def api_clear_history(): async def api_clear_history():
"""Clear the in-memory chat history.""" """Clear the in-memory chat history."""

View File

@@ -7,9 +7,10 @@ Endpoints:
GET /discord/oauth-url — get the bot's OAuth2 authorization URL GET /discord/oauth-url — get the bot's OAuth2 authorization URL
""" """
from typing import Optional
from fastapi import APIRouter, File, Form, UploadFile from fastapi import APIRouter, File, Form, UploadFile
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional
router = APIRouter(prefix="/discord", tags=["discord"]) router = APIRouter(prefix="/discord", tags=["discord"])

View File

@@ -0,0 +1,77 @@
"""Experiment dashboard routes — autoresearch experiment monitoring.
Provides endpoints for viewing, starting, and monitoring autonomous
ML experiment loops powered by Karpathy's autoresearch pattern.
"""
import logging
from pathlib import Path
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse
from config import settings
from dashboard.templating import templates
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/experiments", tags=["experiments"])
def _workspace() -> Path:
    """Resolve the autoresearch workspace directory beneath the repo root."""
    # Path(a, b) joins segments exactly like Path(a) / b.
    return Path(settings.repo_root, settings.autoresearch_workspace)
@router.get("", response_class=HTMLResponse)
async def experiments_page(request: Request):
"""Experiment dashboard — lists past runs and allows starting new ones."""
from timmy.autoresearch import get_experiment_history
history = []
try:
history = get_experiment_history(_workspace())
except Exception:
logger.debug("Failed to load experiment history", exc_info=True)
return templates.TemplateResponse(
request,
"experiments.html",
{
"page_title": "Experiments — Autoresearch",
"enabled": settings.autoresearch_enabled,
"history": history[:50],
"metric_name": settings.autoresearch_metric,
"time_budget": settings.autoresearch_time_budget,
"max_iterations": settings.autoresearch_max_iterations,
},
)
@router.post("/start", response_class=JSONResponse)
async def start_experiment(request: Request):
"""Kick off an experiment loop in the background."""
if not settings.autoresearch_enabled:
raise HTTPException(
status_code=403,
detail="Autoresearch is disabled. Set AUTORESEARCH_ENABLED=true.",
)
from timmy.autoresearch import prepare_experiment
workspace = _workspace()
status = prepare_experiment(workspace)
return {"status": "started", "workspace": str(workspace), "prepare": status}
@router.get("/{run_id}", response_class=JSONResponse)
async def experiment_detail(run_id: str):
"""Get details for a specific experiment run."""
from timmy.autoresearch import get_experiment_history
history = get_experiment_history(_workspace())
for entry in history:
if entry.get("run_id") == run_id:
return entry
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")

View File

@@ -43,6 +43,7 @@ async def grok_status(request: Request):
stats = None stats = None
try: try:
from timmy.backends import get_grok_backend from timmy.backends import get_grok_backend
backend = get_grok_backend() backend = get_grok_backend()
stats = { stats = {
"total_requests": backend.stats.total_requests, "total_requests": backend.stats.total_requests,
@@ -52,12 +53,16 @@ async def grok_status(request: Request):
"errors": backend.stats.errors, "errors": backend.stats.errors,
} }
except Exception: except Exception:
pass logger.debug("Failed to load Grok stats", exc_info=True)
return templates.TemplateResponse(request, "grok_status.html", { return templates.TemplateResponse(
"status": status, request,
"stats": stats, "grok_status.html",
}) {
"status": status,
"stats": stats,
},
)
@router.post("/toggle") @router.post("/toggle")
@@ -90,7 +95,7 @@ async def toggle_grok_mode(request: Request):
success=True, success=True,
) )
except Exception: except Exception:
pass logger.debug("Failed to log Grok toggle to Spark", exc_info=True)
return HTMLResponse( return HTMLResponse(
_render_toggle_card(_grok_mode_active), _render_toggle_card(_grok_mode_active),
@@ -104,10 +109,13 @@ def _run_grok_query(message: str) -> dict:
Returns: Returns:
{"response": str | None, "error": str | None} {"response": str | None, "error": str | None}
""" """
from timmy.backends import grok_available, get_grok_backend from timmy.backends import get_grok_backend, grok_available
if not grok_available(): if not grok_available():
return {"response": None, "error": "Grok is not available. Set GROK_ENABLED=true and XAI_API_KEY."} return {
"response": None,
"error": "Grok is not available. Set GROK_ENABLED=true and XAI_API_KEY.",
}
backend = get_grok_backend() backend = get_grok_backend()
@@ -115,12 +123,13 @@ def _run_grok_query(message: str) -> dict:
if not settings.grok_free: if not settings.grok_free:
try: try:
from lightning.factory import get_backend as get_ln_backend from lightning.factory import get_backend as get_ln_backend
ln = get_ln_backend() ln = get_ln_backend()
sats = min(settings.grok_max_sats_per_query, 100) sats = min(settings.grok_max_sats_per_query, 100)
ln.create_invoice(sats, f"Grok: {message[:50]}") ln.create_invoice(sats, f"Grok: {message[:50]}")
invoice_note = f" | {sats} sats" invoice_note = f" | {sats} sats"
except Exception: except Exception:
pass logger.debug("Lightning invoice creation failed", exc_info=True)
try: try:
result = backend.run(message) result = backend.run(message)
@@ -132,9 +141,10 @@ def _run_grok_query(message: str) -> dict:
@router.post("/chat", response_class=HTMLResponse) @router.post("/chat", response_class=HTMLResponse)
async def grok_chat(request: Request, message: str = Form(...)): async def grok_chat(request: Request, message: str = Form(...)):
"""Send a message directly to Grok and return HTMX chat partial.""" """Send a message directly to Grok and return HTMX chat partial."""
from dashboard.store import message_log
from datetime import datetime from datetime import datetime
from dashboard.store import message_log
timestamp = datetime.now().strftime("%H:%M:%S") timestamp = datetime.now().strftime("%H:%M:%S")
result = _run_grok_query(message) result = _run_grok_query(message)
@@ -142,9 +152,13 @@ async def grok_chat(request: Request, message: str = Form(...)):
message_log.append(role="user", content=user_msg, timestamp=timestamp, source="browser") message_log.append(role="user", content=user_msg, timestamp=timestamp, source="browser")
if result["response"]: if result["response"]:
message_log.append(role="agent", content=result["response"], timestamp=timestamp, source="browser") message_log.append(
role="agent", content=result["response"], timestamp=timestamp, source="browser"
)
else: else:
message_log.append(role="error", content=result["error"], timestamp=timestamp, source="browser") message_log.append(
role="error", content=result["error"], timestamp=timestamp, source="browser"
)
return templates.TemplateResponse( return templates.TemplateResponse(
request, request,
@@ -185,6 +199,7 @@ async def grok_stats():
def _render_toggle_card(active: bool) -> str: def _render_toggle_card(active: bool) -> str:
"""Render the Grok Mode toggle card HTML.""" """Render the Grok Mode toggle card HTML."""
import html import html
color = "#00ff88" if active else "#666" color = "#00ff88" if active else "#666"
state = "ACTIVE" if active else "STANDBY" state = "ACTIVE" if active else "STANDBY"
glow = "0 0 20px rgba(0, 255, 136, 0.4)" if active else "none" glow = "0 0 20px rgba(0, 255, 136, 0.4)" if active else "none"

View File

@@ -22,6 +22,7 @@ router = APIRouter(tags=["health"])
class DependencyStatus(BaseModel): class DependencyStatus(BaseModel):
"""Status of a single dependency.""" """Status of a single dependency."""
name: str name: str
status: str # "healthy", "degraded", "unavailable" status: str # "healthy", "degraded", "unavailable"
sovereignty_score: int # 0-10 sovereignty_score: int # 0-10
@@ -30,6 +31,7 @@ class DependencyStatus(BaseModel):
class SovereigntyReport(BaseModel): class SovereigntyReport(BaseModel):
"""Full sovereignty audit report.""" """Full sovereignty audit report."""
overall_score: float overall_score: float
dependencies: list[DependencyStatus] dependencies: list[DependencyStatus]
timestamp: str timestamp: str
@@ -38,6 +40,7 @@ class SovereigntyReport(BaseModel):
class HealthStatus(BaseModel): class HealthStatus(BaseModel):
"""System health status.""" """System health status."""
status: str status: str
timestamp: str timestamp: str
version: str version: str
@@ -52,6 +55,7 @@ def _check_ollama_sync() -> DependencyStatus:
"""Synchronous Ollama check — run via asyncio.to_thread().""" """Synchronous Ollama check — run via asyncio.to_thread()."""
try: try:
import urllib.request import urllib.request
url = settings.ollama_url.replace("localhost", "127.0.0.1") url = settings.ollama_url.replace("localhost", "127.0.0.1")
req = urllib.request.Request( req = urllib.request.Request(
f"{url}/api/tags", f"{url}/api/tags",
@@ -67,7 +71,7 @@ def _check_ollama_sync() -> DependencyStatus:
details={"url": settings.ollama_url, "model": settings.ollama_model}, details={"url": settings.ollama_url, "model": settings.ollama_model},
) )
except Exception: except Exception:
pass logger.debug("Ollama health check failed", exc_info=True)
return DependencyStatus( return DependencyStatus(
name="Ollama AI", name="Ollama AI",
@@ -142,7 +146,7 @@ def _calculate_overall_score(deps: list[DependencyStatus]) -> float:
def _generate_recommendations(deps: list[DependencyStatus]) -> list[str]: def _generate_recommendations(deps: list[DependencyStatus]) -> list[str]:
"""Generate recommendations based on dependency status.""" """Generate recommendations based on dependency status."""
recommendations = [] recommendations = []
for dep in deps: for dep in deps:
if dep.status == "unavailable": if dep.status == "unavailable":
recommendations.append(f"{dep.name} is unavailable - check configuration") recommendations.append(f"{dep.name} is unavailable - check configuration")
@@ -151,25 +155,25 @@ def _generate_recommendations(deps: list[DependencyStatus]) -> list[str]:
recommendations.append( recommendations.append(
"Switch to real Lightning: set LIGHTNING_BACKEND=lnd and configure LND" "Switch to real Lightning: set LIGHTNING_BACKEND=lnd and configure LND"
) )
if not recommendations: if not recommendations:
recommendations.append("System operating optimally - all dependencies healthy") recommendations.append("System operating optimally - all dependencies healthy")
return recommendations return recommendations
@router.get("/health") @router.get("/health")
async def health_check(): async def health_check():
"""Basic health check endpoint. """Basic health check endpoint.
Returns legacy format for backward compatibility with existing tests, Returns legacy format for backward compatibility with existing tests,
plus extended information for the Mission Control dashboard. plus extended information for the Mission Control dashboard.
""" """
uptime = (datetime.now(timezone.utc) - _START_TIME).total_seconds() uptime = (datetime.now(timezone.utc) - _START_TIME).total_seconds()
# Legacy format for test compatibility # Legacy format for test compatibility
ollama_ok = await check_ollama() ollama_ok = await check_ollama()
agent_status = "idle" if ollama_ok else "offline" agent_status = "idle" if ollama_ok else "offline"
return { return {
@@ -193,12 +197,13 @@ async def health_check():
async def health_status_panel(request: Request): async def health_status_panel(request: Request):
"""Simple HTML health status panel.""" """Simple HTML health status panel."""
ollama_ok = await check_ollama() ollama_ok = await check_ollama()
status_text = "UP" if ollama_ok else "DOWN" status_text = "UP" if ollama_ok else "DOWN"
status_color = "#10b981" if ollama_ok else "#ef4444" status_color = "#10b981" if ollama_ok else "#ef4444"
import html import html
model = html.escape(settings.ollama_model) # Include model for test compatibility model = html.escape(settings.ollama_model) # Include model for test compatibility
html_content = f""" html_content = f"""
<!DOCTYPE html> <!DOCTYPE html>
<html> <html>
@@ -217,7 +222,7 @@ async def health_status_panel(request: Request):
@router.get("/health/sovereignty", response_model=SovereigntyReport) @router.get("/health/sovereignty", response_model=SovereigntyReport)
async def sovereignty_check(): async def sovereignty_check():
"""Comprehensive sovereignty audit report. """Comprehensive sovereignty audit report.
Returns the status of all external dependencies with sovereignty scores. Returns the status of all external dependencies with sovereignty scores.
Use this to verify the system is operating in a sovereign manner. Use this to verify the system is operating in a sovereign manner.
""" """
@@ -226,10 +231,10 @@ async def sovereignty_check():
_check_lightning(), _check_lightning(),
_check_sqlite(), _check_sqlite(),
] ]
overall = _calculate_overall_score(dependencies) overall = _calculate_overall_score(dependencies)
recommendations = _generate_recommendations(dependencies) recommendations = _generate_recommendations(dependencies)
return SovereigntyReport( return SovereigntyReport(
overall_score=overall, overall_score=overall,
dependencies=dependencies, dependencies=dependencies,

View File

@@ -19,8 +19,7 @@ AGENT_CATALOG = [
"name": "Orchestrator", "name": "Orchestrator",
"role": "Local AI", "role": "Local AI",
"description": ( "description": (
"Primary AI agent. Coordinates tasks, manages memory. " "Primary AI agent. Coordinates tasks, manages memory. " "Uses distributed brain."
"Uses distributed brain."
), ),
"capabilities": "chat,reasoning,coordination,memory", "capabilities": "chat,reasoning,coordination,memory",
"rate_sats": 0, "rate_sats": 0,
@@ -37,11 +36,11 @@ async def api_list_agents():
pending_tasks = len(await brain.get_pending_tasks(limit=1000)) pending_tasks = len(await brain.get_pending_tasks(limit=1000))
except Exception: except Exception:
pending_tasks = 0 pending_tasks = 0
catalog = [dict(AGENT_CATALOG[0])] catalog = [dict(AGENT_CATALOG[0])]
catalog[0]["pending_tasks"] = pending_tasks catalog[0]["pending_tasks"] = pending_tasks
catalog[0]["status"] = "active" catalog[0]["status"] = "active"
# Include 'total' for backward compatibility with tests # Include 'total' for backward compatibility with tests
return {"agents": catalog, "total": len(catalog)} return {"agents": catalog, "total": len(catalog)}
@@ -82,7 +81,7 @@ async def marketplace_ui(request: Request):
"page_title": "Agent Marketplace", "page_title": "Agent Marketplace",
"active_count": active, "active_count": active,
"planned_count": 0, "planned_count": 0,
} },
) )

View File

@@ -5,17 +5,17 @@ from typing import Optional
from fastapi import APIRouter, Form, HTTPException, Request from fastapi import APIRouter, Form, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse from fastapi.responses import HTMLResponse, JSONResponse
from dashboard.templating import templates
from timmy.memory.vector_store import ( from timmy.memory.vector_store import (
store_memory, delete_memory,
search_memories,
get_memory_stats, get_memory_stats,
recall_personal_facts, recall_personal_facts,
recall_personal_facts_with_ids, recall_personal_facts_with_ids,
search_memories,
store_memory,
store_personal_fact, store_personal_fact,
update_personal_fact, update_personal_fact,
delete_memory,
) )
from dashboard.templating import templates
router = APIRouter(prefix="/memory", tags=["memory"]) router = APIRouter(prefix="/memory", tags=["memory"])
@@ -36,10 +36,10 @@ async def memory_page(
agent_id=agent_id, agent_id=agent_id,
limit=20, limit=20,
) )
stats = get_memory_stats() stats = get_memory_stats()
facts = recall_personal_facts_with_ids()[:10] facts = recall_personal_facts_with_ids()[:10]
return templates.TemplateResponse( return templates.TemplateResponse(
request, request,
"memory.html", "memory.html",
@@ -67,7 +67,7 @@ async def memory_search(
context_type=context_type, context_type=context_type,
limit=20, limit=20,
) )
# Return partial for HTMX # Return partial for HTMX
return templates.TemplateResponse( return templates.TemplateResponse(
request, request,

View File

@@ -13,6 +13,7 @@ from fastapi.responses import HTMLResponse
from pydantic import BaseModel from pydantic import BaseModel
from config import settings from config import settings
from dashboard.templating import templates
from infrastructure.models.registry import ( from infrastructure.models.registry import (
CustomModel, CustomModel,
ModelFormat, ModelFormat,
@@ -20,7 +21,6 @@ from infrastructure.models.registry import (
ModelRole, ModelRole,
model_registry, model_registry,
) )
from dashboard.templating import templates
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -33,6 +33,7 @@ api_router = APIRouter(prefix="/api/v1/models", tags=["models-api"])
class RegisterModelRequest(BaseModel): class RegisterModelRequest(BaseModel):
"""Request body for model registration.""" """Request body for model registration."""
name: str name: str
format: str # gguf, safetensors, hf, ollama format: str # gguf, safetensors, hf, ollama
path: str path: str
@@ -45,12 +46,14 @@ class RegisterModelRequest(BaseModel):
class AssignModelRequest(BaseModel): class AssignModelRequest(BaseModel):
"""Request body for assigning a model to an agent.""" """Request body for assigning a model to an agent."""
agent_id: str agent_id: str
model_name: str model_name: str
class SetActiveRequest(BaseModel): class SetActiveRequest(BaseModel):
"""Request body for enabling/disabling a model.""" """Request body for enabling/disabling a model."""
active: bool active: bool
@@ -92,15 +95,14 @@ async def register_model(request: RegisterModelRequest) -> dict[str, Any]:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Invalid format: {request.format}. " detail=f"Invalid format: {request.format}. "
f"Choose from: {[f.value for f in ModelFormat]}", f"Choose from: {[f.value for f in ModelFormat]}",
) )
try: try:
role = ModelRole(request.role) role = ModelRole(request.role)
except ValueError: except ValueError:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Invalid role: {request.role}. " detail=f"Invalid role: {request.role}. " f"Choose from: {[r.value for r in ModelRole]}",
f"Choose from: {[r.value for r in ModelRole]}",
) )
# Validate path exists for non-Ollama formats # Validate path exists for non-Ollama formats
@@ -163,9 +165,7 @@ async def unregister_model(model_name: str) -> dict[str, str]:
@api_router.patch("/{model_name}/active") @api_router.patch("/{model_name}/active")
async def set_model_active( async def set_model_active(model_name: str, request: SetActiveRequest) -> dict[str, str]:
model_name: str, request: SetActiveRequest
) -> dict[str, str]:
"""Enable or disable a model.""" """Enable or disable a model."""
if not model_registry.set_active(model_name, request.active): if not model_registry.set_active(model_name, request.active):
raise HTTPException(status_code=404, detail=f"Model {model_name} not found") raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
@@ -182,8 +182,7 @@ async def list_assignments() -> dict[str, Any]:
assignments = model_registry.get_agent_assignments() assignments = model_registry.get_agent_assignments()
return { return {
"assignments": [ "assignments": [
{"agent_id": aid, "model_name": mname} {"agent_id": aid, "model_name": mname} for aid, mname in assignments.items()
for aid, mname in assignments.items()
], ],
"total": len(assignments), "total": len(assignments),
} }

View File

@@ -3,8 +3,8 @@
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from timmy.cascade_adapter import get_cascade_adapter
from dashboard.templating import templates from dashboard.templating import templates
from timmy.cascade_adapter import get_cascade_adapter
router = APIRouter(prefix="/router", tags=["router"]) router = APIRouter(prefix="/router", tags=["router"])
@@ -13,19 +13,19 @@ router = APIRouter(prefix="/router", tags=["router"])
async def router_status_page(request: Request): async def router_status_page(request: Request):
"""Cascade Router status dashboard.""" """Cascade Router status dashboard."""
adapter = get_cascade_adapter() adapter = get_cascade_adapter()
providers = adapter.get_provider_status() providers = adapter.get_provider_status()
preferred = adapter.get_preferred_provider() preferred = adapter.get_preferred_provider()
# Calculate overall stats # Calculate overall stats
total_requests = sum(p["metrics"]["total"] for p in providers) total_requests = sum(p["metrics"]["total"] for p in providers)
total_success = sum(p["metrics"]["success"] for p in providers) total_success = sum(p["metrics"]["success"] for p in providers)
total_failed = sum(p["metrics"]["failed"] for p in providers) total_failed = sum(p["metrics"]["failed"] for p in providers)
avg_latency = 0.0 avg_latency = 0.0
if providers: if providers:
avg_latency = sum(p["metrics"]["avg_latency_ms"] for p in providers) / len(providers) avg_latency = sum(p["metrics"]["avg_latency_ms"] for p in providers) / len(providers)
return templates.TemplateResponse( return templates.TemplateResponse(
request, request,
"router_status.html", "router_status.html",

View File

@@ -13,8 +13,8 @@ import logging
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from spark.engine import spark_engine
from dashboard.templating import templates from dashboard.templating import templates
from spark.engine import spark_engine
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -86,23 +86,26 @@ async def spark_ui(request: Request):
async def spark_status_json(): async def spark_status_json():
"""Return Spark Intelligence status as JSON.""" """Return Spark Intelligence status as JSON."""
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
status = spark_engine.status() status = spark_engine.status()
advisories = spark_engine.get_advisories() advisories = spark_engine.get_advisories()
return JSONResponse({ return JSONResponse(
"status": status, {
"advisories": [ "status": status,
{ "advisories": [
"category": a.category, {
"priority": a.priority, "category": a.category,
"title": a.title, "priority": a.priority,
"detail": a.detail, "title": a.title,
"suggested_action": a.suggested_action, "detail": a.detail,
"subject": a.subject, "suggested_action": a.suggested_action,
"evidence_count": a.evidence_count, "subject": a.subject,
} "evidence_count": a.evidence_count,
for a in advisories }
], for a in advisories
}) ],
}
)
@router.get("/timeline", response_class=HTMLResponse) @router.get("/timeline", response_class=HTMLResponse)

View File

@@ -7,9 +7,9 @@ from typing import Optional
from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from spark.engine import spark_engine
from dashboard.templating import templates from dashboard.templating import templates
from infrastructure.ws_manager.handler import ws_manager from infrastructure.ws_manager.handler import ws_manager
from spark.engine import spark_engine
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -25,7 +25,7 @@ async def swarm_events(
): ):
"""Event log page.""" """Event log page."""
events = spark_engine.get_timeline(limit=100) events = spark_engine.get_timeline(limit=100)
# Filter if requested # Filter if requested
if task_id: if task_id:
events = [e for e in events if e.task_id == task_id] events = [e for e in events if e.task_id == task_id]
@@ -33,7 +33,7 @@ async def swarm_events(
events = [e for e in events if e.agent_id == agent_id] events = [e for e in events if e.agent_id == agent_id]
if event_type: if event_type:
events = [e for e in events if e.event_type == event_type] events = [e for e in events if e.event_type == event_type]
# Prepare summary and event types for template # Prepare summary and event types for template
summary = {} summary = {}
event_types = set() event_types = set()
@@ -41,7 +41,7 @@ async def swarm_events(
etype = e.event_type etype = e.event_type
event_types.add(etype) event_types.add(etype)
summary[etype] = summary.get(etype, 0) + 1 summary[etype] = summary.get(etype, 0) + 1
return templates.TemplateResponse( return templates.TemplateResponse(
request, request,
"events.html", "events.html",
@@ -78,14 +78,16 @@ async def swarm_ws(websocket: WebSocket):
await ws_manager.connect(websocket) await ws_manager.connect(websocket)
try: try:
# Send initial state so frontend can clear loading placeholders # Send initial state so frontend can clear loading placeholders
await websocket.send_json({ await websocket.send_json(
"type": "initial_state", {
"data": { "type": "initial_state",
"agents": {"total": 0, "active": 0, "list": []}, "data": {
"tasks": {"active": 0}, "agents": {"total": 0, "active": 0, "list": []},
"auctions": {"list": []}, "tasks": {"active": 0},
}, "auctions": {"list": []},
}) },
}
)
while True: while True:
await websocket.receive_text() await websocket.receive_text()
except WebSocketDisconnect: except WebSocketDisconnect:

View File

@@ -25,26 +25,42 @@ async def lightning_ledger(request: Request):
"pending_incoming_sats": 0, "pending_incoming_sats": 0,
"pending_outgoing_sats": 0, "pending_outgoing_sats": 0,
} }
# Mock transactions # Mock transactions
from collections import namedtuple from collections import namedtuple
from enum import Enum from enum import Enum
class TxType(Enum): class TxType(Enum):
incoming = "incoming" incoming = "incoming"
outgoing = "outgoing" outgoing = "outgoing"
class TxStatus(Enum): class TxStatus(Enum):
completed = "completed" completed = "completed"
pending = "pending" pending = "pending"
Tx = namedtuple("Tx", ["tx_type", "status", "amount_sats", "payment_hash", "memo", "created_at"]) Tx = namedtuple(
"Tx", ["tx_type", "status", "amount_sats", "payment_hash", "memo", "created_at"]
)
transactions = [ transactions = [
Tx(TxType.outgoing, TxStatus.completed, 50, "hash1", "Model inference", "2026-03-04 10:00:00"), Tx(
Tx(TxType.incoming, TxStatus.completed, 1000, "hash2", "Manual deposit", "2026-03-03 15:00:00"), TxType.outgoing,
TxStatus.completed,
50,
"hash1",
"Model inference",
"2026-03-04 10:00:00",
),
Tx(
TxType.incoming,
TxStatus.completed,
1000,
"hash2",
"Manual deposit",
"2026-03-03 15:00:00",
),
] ]
return templates.TemplateResponse( return templates.TemplateResponse(
request, request,
"ledger.html", "ledger.html",
@@ -84,9 +100,16 @@ async def mission_control(request: Request):
@router.get("/bugs", response_class=HTMLResponse) @router.get("/bugs", response_class=HTMLResponse)
async def bugs_page(request: Request): async def bugs_page(request: Request):
return templates.TemplateResponse(request, "bugs.html", { return templates.TemplateResponse(
"bugs": [], "total": 0, "stats": {}, "filter_status": None, request,
}) "bugs.html",
{
"bugs": [],
"total": 0,
"stats": {},
"filter_status": None,
},
)
@router.get("/self-coding", response_class=HTMLResponse) @router.get("/self-coding", response_class=HTMLResponse)
@@ -109,14 +132,17 @@ async def api_notifications():
"""Return recent system events for the notification dropdown.""" """Return recent system events for the notification dropdown."""
try: try:
from spark.engine import spark_engine from spark.engine import spark_engine
events = spark_engine.get_timeline(limit=20) events = spark_engine.get_timeline(limit=20)
return JSONResponse([ return JSONResponse(
{ [
"event_type": e.event_type, {
"title": getattr(e, "description", e.event_type), "event_type": e.event_type,
"timestamp": str(getattr(e, "timestamp", "")), "title": getattr(e, "description", e.event_type),
} "timestamp": str(getattr(e, "timestamp", "")),
for e in events }
]) for e in events
]
)
except Exception: except Exception:
return JSONResponse([]) return JSONResponse([])

View File

@@ -7,9 +7,10 @@ from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from fastapi import APIRouter, HTTPException, Request, Form from fastapi import APIRouter, Form, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse from fastapi.responses import HTMLResponse, JSONResponse
from config import settings
from dashboard.templating import templates from dashboard.templating import templates
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -20,11 +21,17 @@ router = APIRouter(tags=["tasks"])
# Database helpers # Database helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
DB_PATH = Path("data/tasks.db") DB_PATH = Path(settings.repo_root) / "data" / "tasks.db"
VALID_STATUSES = { VALID_STATUSES = {
"pending_approval", "approved", "running", "paused", "pending_approval",
"completed", "vetoed", "failed", "backlogged", "approved",
"running",
"paused",
"completed",
"vetoed",
"failed",
"backlogged",
} }
VALID_PRIORITIES = {"low", "normal", "high", "urgent"} VALID_PRIORITIES = {"low", "normal", "high", "urgent"}
@@ -33,7 +40,8 @@ def _get_db() -> sqlite3.Connection:
DB_PATH.parent.mkdir(parents=True, exist_ok=True) DB_PATH.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(DB_PATH)) conn = sqlite3.connect(str(DB_PATH))
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
conn.execute(""" conn.execute(
"""
CREATE TABLE IF NOT EXISTS tasks ( CREATE TABLE IF NOT EXISTS tasks (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
title TEXT NOT NULL, title TEXT NOT NULL,
@@ -46,7 +54,8 @@ def _get_db() -> sqlite3.Connection:
created_at TEXT DEFAULT (datetime('now')), created_at TEXT DEFAULT (datetime('now')),
completed_at TEXT completed_at TEXT
) )
""") """
)
conn.commit() conn.commit()
return conn return conn
@@ -91,37 +100,52 @@ class _TaskView:
# Page routes # Page routes
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get("/tasks", response_class=HTMLResponse) @router.get("/tasks", response_class=HTMLResponse)
async def tasks_page(request: Request): async def tasks_page(request: Request):
"""Render the main task queue page with 3-column layout.""" """Render the main task queue page with 3-column layout."""
db = _get_db() db = _get_db()
try: try:
pending = [_TaskView(_row_to_dict(r)) for r in db.execute( pending = [
"SELECT * FROM tasks WHERE status IN ('pending_approval') ORDER BY created_at DESC" _TaskView(_row_to_dict(r))
).fetchall()] for r in db.execute(
active = [_TaskView(_row_to_dict(r)) for r in db.execute( "SELECT * FROM tasks WHERE status IN ('pending_approval') ORDER BY created_at DESC"
"SELECT * FROM tasks WHERE status IN ('approved','running','paused') ORDER BY created_at DESC" ).fetchall()
).fetchall()] ]
completed = [_TaskView(_row_to_dict(r)) for r in db.execute( active = [
"SELECT * FROM tasks WHERE status IN ('completed','vetoed','failed') ORDER BY completed_at DESC LIMIT 50" _TaskView(_row_to_dict(r))
).fetchall()] for r in db.execute(
"SELECT * FROM tasks WHERE status IN ('approved','running','paused') ORDER BY created_at DESC"
).fetchall()
]
completed = [
_TaskView(_row_to_dict(r))
for r in db.execute(
"SELECT * FROM tasks WHERE status IN ('completed','vetoed','failed') ORDER BY completed_at DESC LIMIT 50"
).fetchall()
]
finally: finally:
db.close() db.close()
return templates.TemplateResponse(request, "tasks.html", { return templates.TemplateResponse(
"pending_count": len(pending), request,
"pending": pending, "tasks.html",
"active": active, {
"completed": completed, "pending_count": len(pending),
"agents": [], # no agent roster wired yet "pending": pending,
"pre_assign": "", "active": active,
}) "completed": completed,
"agents": [], # no agent roster wired yet
"pre_assign": "",
},
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# HTMX partials (polled by the template) # HTMX partials (polled by the template)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get("/tasks/pending", response_class=HTMLResponse) @router.get("/tasks/pending", response_class=HTMLResponse)
async def tasks_pending(request: Request): async def tasks_pending(request: Request):
db = _get_db() db = _get_db()
@@ -134,9 +158,11 @@ async def tasks_pending(request: Request):
tasks = [_TaskView(_row_to_dict(r)) for r in rows] tasks = [_TaskView(_row_to_dict(r)) for r in rows]
parts = [] parts = []
for task in tasks: for task in tasks:
parts.append(templates.TemplateResponse( parts.append(
request, "partials/task_card.html", {"task": task} templates.TemplateResponse(
).body.decode()) request, "partials/task_card.html", {"task": task}
).body.decode()
)
if not parts: if not parts:
return HTMLResponse('<div class="empty-column">No pending tasks</div>') return HTMLResponse('<div class="empty-column">No pending tasks</div>')
return HTMLResponse("".join(parts)) return HTMLResponse("".join(parts))
@@ -154,9 +180,11 @@ async def tasks_active(request: Request):
tasks = [_TaskView(_row_to_dict(r)) for r in rows] tasks = [_TaskView(_row_to_dict(r)) for r in rows]
parts = [] parts = []
for task in tasks: for task in tasks:
parts.append(templates.TemplateResponse( parts.append(
request, "partials/task_card.html", {"task": task} templates.TemplateResponse(
).body.decode()) request, "partials/task_card.html", {"task": task}
).body.decode()
)
if not parts: if not parts:
return HTMLResponse('<div class="empty-column">No active tasks</div>') return HTMLResponse('<div class="empty-column">No active tasks</div>')
return HTMLResponse("".join(parts)) return HTMLResponse("".join(parts))
@@ -174,9 +202,11 @@ async def tasks_completed(request: Request):
tasks = [_TaskView(_row_to_dict(r)) for r in rows] tasks = [_TaskView(_row_to_dict(r)) for r in rows]
parts = [] parts = []
for task in tasks: for task in tasks:
parts.append(templates.TemplateResponse( parts.append(
request, "partials/task_card.html", {"task": task} templates.TemplateResponse(
).body.decode()) request, "partials/task_card.html", {"task": task}
).body.decode()
)
if not parts: if not parts:
return HTMLResponse('<div class="empty-column">No completed tasks yet</div>') return HTMLResponse('<div class="empty-column">No completed tasks yet</div>')
return HTMLResponse("".join(parts)) return HTMLResponse("".join(parts))
@@ -186,6 +216,7 @@ async def tasks_completed(request: Request):
# Form-based create (used by the modal in tasks.html) # Form-based create (used by the modal in tasks.html)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.post("/tasks/create", response_class=HTMLResponse) @router.post("/tasks/create", response_class=HTMLResponse)
async def create_task_form( async def create_task_form(
request: Request, request: Request,
@@ -218,6 +249,7 @@ async def create_task_form(
# Task action endpoints (approve, veto, modify, pause, cancel, retry) # Task action endpoints (approve, veto, modify, pause, cancel, retry)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.post("/tasks/{task_id}/approve", response_class=HTMLResponse) @router.post("/tasks/{task_id}/approve", response_class=HTMLResponse)
async def approve_task(request: Request, task_id: str): async def approve_task(request: Request, task_id: str):
return await _set_status(request, task_id, "approved") return await _set_status(request, task_id, "approved")
@@ -268,7 +300,9 @@ async def modify_task(
async def _set_status(request: Request, task_id: str, new_status: str): async def _set_status(request: Request, task_id: str, new_status: str):
"""Helper to update status and return refreshed task card.""" """Helper to update status and return refreshed task card."""
completed_at = datetime.utcnow().isoformat() if new_status in ("completed", "vetoed", "failed") else None completed_at = (
datetime.utcnow().isoformat() if new_status in ("completed", "vetoed", "failed") else None
)
db = _get_db() db = _get_db()
try: try:
db.execute( db.execute(
@@ -289,6 +323,7 @@ async def _set_status(request: Request, task_id: str, new_status: str):
# JSON API (for programmatic access / Timmy's tool calls) # JSON API (for programmatic access / Timmy's tool calls)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.post("/api/tasks", response_class=JSONResponse, status_code=201) @router.post("/api/tasks", response_class=JSONResponse, status_code=201)
async def api_create_task(request: Request): async def api_create_task(request: Request):
"""Create a task via JSON API.""" """Create a task via JSON API."""
@@ -345,7 +380,9 @@ async def api_update_status(task_id: str, request: Request):
if not new_status or new_status not in VALID_STATUSES: if not new_status or new_status not in VALID_STATUSES:
raise HTTPException(422, f"Invalid status. Must be one of: {VALID_STATUSES}") raise HTTPException(422, f"Invalid status. Must be one of: {VALID_STATUSES}")
completed_at = datetime.utcnow().isoformat() if new_status in ("completed", "vetoed", "failed") else None completed_at = (
datetime.utcnow().isoformat() if new_status in ("completed", "vetoed", "failed") else None
)
db = _get_db() db = _get_db()
try: try:
db.execute( db.execute(
@@ -379,6 +416,7 @@ async def api_delete_task(task_id: str):
# Queue status (polled by the chat panel every 10 seconds) # Queue status (polled by the chat panel every 10 seconds)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get("/api/queue/status", response_class=JSONResponse) @router.get("/api/queue/status", response_class=JSONResponse)
async def queue_status(assigned_to: str = "default"): async def queue_status(assigned_to: str = "default"):
"""Return queue status for the chat panel's agent status indicator.""" """Return queue status for the chat panel's agent status indicator."""
@@ -396,14 +434,18 @@ async def queue_status(assigned_to: str = "default"):
db.close() db.close()
if running: if running:
return JSONResponse({ return JSONResponse(
"is_working": True, {
"current_task": {"id": running["id"], "title": running["title"]}, "is_working": True,
"tasks_ahead": 0, "current_task": {"id": running["id"], "title": running["title"]},
}) "tasks_ahead": 0,
}
)
return JSONResponse({ return JSONResponse(
"is_working": False, {
"current_task": None, "is_working": False,
"tasks_ahead": ahead["cnt"] if ahead else 0, "current_task": None,
}) "tasks_ahead": ahead["cnt"] if ahead else 0,
}
)

View File

@@ -10,8 +10,8 @@ import logging
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
from fastapi.responses import HTMLResponse, JSONResponse from fastapi.responses import HTMLResponse, JSONResponse
from timmy.thinking import thinking_engine
from dashboard.templating import templates from dashboard.templating import templates
from timmy.thinking import thinking_engine
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -8,8 +8,8 @@ from collections import namedtuple
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
from fastapi.responses import HTMLResponse, JSONResponse from fastapi.responses import HTMLResponse, JSONResponse
from timmy.tools import get_all_available_tools
from dashboard.templating import templates from dashboard.templating import templates
from timmy.tools import get_all_available_tools
router = APIRouter(tags=["tools"]) router = APIRouter(tags=["tools"])
@@ -29,9 +29,7 @@ def _build_agent_tools():
for name, fn in available.items() for name, fn in available.items()
] ]
return [ return [_AgentView(name="Timmy", status="idle", tools=tool_views, stats=_Stats(total_calls=0))]
_AgentView(name="Timmy", status="idle", tools=tool_views, stats=_Stats(total_calls=0))
]
@router.get("/tools", response_class=HTMLResponse) @router.get("/tools", response_class=HTMLResponse)

View File

@@ -10,9 +10,9 @@ import logging
from fastapi import APIRouter, Form, Request from fastapi import APIRouter, Form, Request
from fastapi.responses import HTMLResponse from fastapi.responses import HTMLResponse
from dashboard.templating import templates
from integrations.voice.nlu import detect_intent, extract_command from integrations.voice.nlu import detect_intent, extract_command
from timmy.agent import create_timmy from timmy.agent import create_timmy
from dashboard.templating import templates
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -38,6 +38,7 @@ async def tts_status():
"""Check TTS engine availability.""" """Check TTS engine availability."""
try: try:
from timmy_serve.voice_tts import voice_tts from timmy_serve.voice_tts import voice_tts
return { return {
"available": voice_tts.available, "available": voice_tts.available,
"voices": voice_tts.get_voices() if voice_tts.available else [], "voices": voice_tts.get_voices() if voice_tts.available else [],
@@ -51,6 +52,7 @@ async def tts_speak(text: str = Form(...)):
"""Speak text aloud via TTS.""" """Speak text aloud via TTS."""
try: try:
from timmy_serve.voice_tts import voice_tts from timmy_serve.voice_tts import voice_tts
if not voice_tts.available: if not voice_tts.available:
return {"spoken": False, "reason": "TTS engine not available"} return {"spoken": False, "reason": "TTS engine not available"}
voice_tts.speak(text) voice_tts.speak(text)
@@ -86,6 +88,7 @@ async def voice_command(text: str = Form(...)):
# ── Enhanced voice pipeline ────────────────────────────────────────────── # ── Enhanced voice pipeline ──────────────────────────────────────────────
@router.post("/enhanced/process") @router.post("/enhanced/process")
async def process_voice_input( async def process_voice_input(
text: str = Form(...), text: str = Form(...),
@@ -133,6 +136,7 @@ async def process_voice_input(
if speak_response and response_text: if speak_response and response_text:
try: try:
from timmy_serve.voice_tts import voice_tts from timmy_serve.voice_tts import voice_tts
if voice_tts.available: if voice_tts.available:
voice_tts.speak(response_text) voice_tts.speak(response_text)
except Exception: except Exception:

View File

@@ -6,7 +6,7 @@ import uuid
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from fastapi import APIRouter, HTTPException, Request, Form from fastapi import APIRouter, Form, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse from fastapi.responses import HTMLResponse, JSONResponse
from dashboard.templating import templates from dashboard.templating import templates
@@ -26,7 +26,8 @@ def _get_db() -> sqlite3.Connection:
DB_PATH.parent.mkdir(parents=True, exist_ok=True) DB_PATH.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(DB_PATH)) conn = sqlite3.connect(str(DB_PATH))
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
conn.execute(""" conn.execute(
"""
CREATE TABLE IF NOT EXISTS work_orders ( CREATE TABLE IF NOT EXISTS work_orders (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
title TEXT NOT NULL, title TEXT NOT NULL,
@@ -41,7 +42,8 @@ def _get_db() -> sqlite3.Connection:
created_at TEXT DEFAULT (datetime('now')), created_at TEXT DEFAULT (datetime('now')),
completed_at TEXT completed_at TEXT
) )
""") """
)
conn.commit() conn.commit()
return conn return conn
@@ -71,7 +73,9 @@ class _WOView:
self.submitter = row.get("submitter", "dashboard") self.submitter = row.get("submitter", "dashboard")
self.status = _EnumLike(row.get("status", "submitted")) self.status = _EnumLike(row.get("status", "submitted"))
raw_files = row.get("related_files", "") raw_files = row.get("related_files", "")
self.related_files = [f.strip() for f in raw_files.split(",") if f.strip()] if raw_files else [] self.related_files = (
[f.strip() for f in raw_files.split(",") if f.strip()] if raw_files else []
)
self.result = row.get("result", "") self.result = row.get("result", "")
self.rejection_reason = row.get("rejection_reason", "") self.rejection_reason = row.get("rejection_reason", "")
self.created_at = row.get("created_at", "") self.created_at = row.get("created_at", "")
@@ -98,6 +102,7 @@ def _query_wos(db, statuses):
# Page route # Page route
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get("/work-orders/queue", response_class=HTMLResponse) @router.get("/work-orders/queue", response_class=HTMLResponse)
async def work_orders_page(request: Request): async def work_orders_page(request: Request):
db = _get_db() db = _get_db()
@@ -109,21 +114,26 @@ async def work_orders_page(request: Request):
finally: finally:
db.close() db.close()
return templates.TemplateResponse(request, "work_orders.html", { return templates.TemplateResponse(
"pending_count": len(pending), request,
"pending": pending, "work_orders.html",
"active": active, {
"completed": completed, "pending_count": len(pending),
"rejected": rejected, "pending": pending,
"priorities": PRIORITIES, "active": active,
"categories": CATEGORIES, "completed": completed,
}) "rejected": rejected,
"priorities": PRIORITIES,
"categories": CATEGORIES,
},
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Form submit # Form submit
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.post("/work-orders/submit", response_class=HTMLResponse) @router.post("/work-orders/submit", response_class=HTMLResponse)
async def submit_work_order( async def submit_work_order(
request: Request, request: Request,
@@ -159,6 +169,7 @@ async def submit_work_order(
# HTMX partials # HTMX partials
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get("/work-orders/queue/pending", response_class=HTMLResponse) @router.get("/work-orders/queue/pending", response_class=HTMLResponse)
async def pending_partial(request: Request): async def pending_partial(request: Request):
db = _get_db() db = _get_db()
@@ -174,7 +185,9 @@ async def pending_partial(request: Request):
parts = [] parts = []
for wo in wos: for wo in wos:
parts.append( parts.append(
templates.TemplateResponse(request, "partials/work_order_card.html", {"wo": wo}).body.decode() templates.TemplateResponse(
request, "partials/work_order_card.html", {"wo": wo}
).body.decode()
) )
return HTMLResponse("".join(parts)) return HTMLResponse("".join(parts))
@@ -194,7 +207,9 @@ async def active_partial(request: Request):
parts = [] parts = []
for wo in wos: for wo in wos:
parts.append( parts.append(
templates.TemplateResponse(request, "partials/work_order_card.html", {"wo": wo}).body.decode() templates.TemplateResponse(
request, "partials/work_order_card.html", {"wo": wo}
).body.decode()
) )
return HTMLResponse("".join(parts)) return HTMLResponse("".join(parts))
@@ -203,8 +218,11 @@ async def active_partial(request: Request):
# Action endpoints # Action endpoints
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
async def _update_status(request: Request, wo_id: str, new_status: str, **extra): async def _update_status(request: Request, wo_id: str, new_status: str, **extra):
completed_at = datetime.utcnow().isoformat() if new_status in ("completed", "rejected") else None completed_at = (
datetime.utcnow().isoformat() if new_status in ("completed", "rejected") else None
)
db = _get_db() db = _get_db()
try: try:
sets = ["status=?", "completed_at=COALESCE(?, completed_at)"] sets = ["status=?", "completed_at=COALESCE(?, completed_at)"]

View File

@@ -3,7 +3,7 @@ from dataclasses import dataclass, field
@dataclass @dataclass
class Message: class Message:
role: str # "user" | "agent" | "error" role: str # "user" | "agent" | "error"
content: str content: str
timestamp: str timestamp: str
source: str = "browser" # "browser" | "api" | "telegram" | "discord" | "system" source: str = "browser" # "browser" | "api" | "telegram" | "discord" | "system"
@@ -16,7 +16,9 @@ class MessageLog:
self._entries: list[Message] = [] self._entries: list[Message] = []
def append(self, role: str, content: str, timestamp: str, source: str = "browser") -> None: def append(self, role: str, content: str, timestamp: str, source: str = "browser") -> None:
self._entries.append(Message(role=role, content=content, timestamp=timestamp, source=source)) self._entries.append(
Message(role=role, content=content, timestamp=timestamp, source=source)
)
def all(self) -> list[Message]: def all(self) -> list[Message]:
return list(self._entries) return list(self._entries)

View File

@@ -0,0 +1,90 @@
{% extends "base.html" %}
{% block title %}{{ page_title }}{% endblock %}
{% block extra_styles %}
<style>
.experiments-container { max-width: 1000px; margin: 0 auto; }
.exp-header { display: flex; justify-content: space-between; align-items: center; margin-bottom: 20px; }
.exp-title { font-size: 1.3rem; font-weight: 700; color: var(--text-bright); }
.exp-subtitle { font-size: 0.8rem; color: var(--text-dim); margin-top: 2px; }
.exp-config { display: flex; gap: 16px; font-size: 0.8rem; color: var(--text-dim); }
.exp-config span { background: var(--glass-bg); border: 1px solid var(--border); padding: 4px 10px; border-radius: 6px; }
.exp-table { width: 100%; border-collapse: collapse; font-size: 0.85rem; }
.exp-table th { text-align: left; padding: 8px 12px; color: var(--text-dim); border-bottom: 1px solid var(--border); font-weight: 600; }
.exp-table td { padding: 8px 12px; border-bottom: 1px solid var(--border); color: var(--text); }
.exp-table tr:hover { background: var(--glass-bg); }
.metric-good { color: var(--success); }
.metric-bad { color: var(--danger); }
.btn-start { background: var(--accent); color: #fff; border: none; padding: 8px 18px; border-radius: 6px; cursor: pointer; font-size: 0.85rem; }
.btn-start:hover { opacity: 0.9; }
.btn-start:disabled { opacity: 0.4; cursor: not-allowed; }
.disabled-note { font-size: 0.8rem; color: var(--text-dim); margin-top: 8px; }
.empty-state { text-align: center; padding: 40px; color: var(--text-dim); }
</style>
{% endblock %}
{% block content %}
<div class="experiments-container">
<div class="exp-header">
<div>
<div class="exp-title">Autoresearch Experiments</div>
<div class="exp-subtitle">Autonomous ML experiment loops — modify code, train, evaluate, iterate</div>
</div>
<div>
{% if enabled %}
<button class="btn-start"
hx-post="/experiments/start"
hx-target="#experiment-status"
hx-swap="innerHTML">
Start Experiment
</button>
{% else %}
<button class="btn-start" disabled>Disabled</button>
<div class="disabled-note">Set AUTORESEARCH_ENABLED=true to enable</div>
{% endif %}
</div>
</div>
<div class="exp-config">
<span>Metric: {{ metric_name }}</span>
<span>Budget: {{ time_budget }}s</span>
<span>Max iters: {{ max_iterations }}</span>
</div>
<div id="experiment-status" style="margin: 12px 0;"></div>
{% if history %}
<table class="exp-table">
<thead>
<tr>
<th>#</th>
<th>{{ metric_name }}</th>
<th>Duration</th>
<th>Status</th>
</tr>
</thead>
<tbody>
{% for run in history %}
<tr>
<td>{{ loop.index }}</td>
<td>
{% if run.metric is not none %}
{{ "%.4f"|format(run.metric) }}
{% else %}
{% endif %}
</td>
<td>{{ run.get("duration_s", "—") }}s</td>
<td>{% if run.get("success") %}OK{% else %}{{ run.get("error", "failed") }}{% endif %}</td>
</tr>
{% endfor %}
</tbody>
</table>
{% else %}
<div class="empty-state">
No experiments yet. Start one to begin autonomous training.
</div>
{% endif %}
</div>
{% endblock %}

View File

@@ -119,9 +119,7 @@ def capture_error(
return None return None
# Format the stack trace # Format the stack trace
tb_str = "".join( tb_str = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
traceback.format_exception(type(exc), exc, exc.__traceback__)
)
# Extract file/line from traceback # Extract file/line from traceback
tb_obj = exc.__traceback__ tb_obj = exc.__traceback__

View File

@@ -19,38 +19,39 @@ logger = logging.getLogger(__name__)
class EventBroadcaster: class EventBroadcaster:
"""Broadcasts events to WebSocket clients. """Broadcasts events to WebSocket clients.
Usage: Usage:
from infrastructure.events.broadcaster import event_broadcaster from infrastructure.events.broadcaster import event_broadcaster
event_broadcaster.broadcast(event) event_broadcaster.broadcast(event)
""" """
def __init__(self) -> None: def __init__(self) -> None:
self._ws_manager: Optional = None self._ws_manager: Optional = None
def _get_ws_manager(self): def _get_ws_manager(self):
"""Lazy import to avoid circular deps.""" """Lazy import to avoid circular deps."""
if self._ws_manager is None: if self._ws_manager is None:
try: try:
from infrastructure.ws_manager.handler import ws_manager from infrastructure.ws_manager.handler import ws_manager
self._ws_manager = ws_manager self._ws_manager = ws_manager
except Exception as exc: except Exception as exc:
logger.debug("WebSocket manager not available: %s", exc) logger.debug("WebSocket manager not available: %s", exc)
return self._ws_manager return self._ws_manager
async def broadcast(self, event: EventLogEntry) -> int: async def broadcast(self, event: EventLogEntry) -> int:
"""Broadcast an event to all connected WebSocket clients. """Broadcast an event to all connected WebSocket clients.
Args: Args:
event: The event to broadcast event: The event to broadcast
Returns: Returns:
Number of clients notified Number of clients notified
""" """
ws_manager = self._get_ws_manager() ws_manager = self._get_ws_manager()
if not ws_manager: if not ws_manager:
return 0 return 0
# Build message payload # Build message payload
payload = { payload = {
"type": "event", "type": "event",
@@ -62,9 +63,9 @@ class EventBroadcaster:
"agent_id": event.agent_id, "agent_id": event.agent_id,
"timestamp": event.timestamp, "timestamp": event.timestamp,
"data": event.data, "data": event.data,
} },
} }
try: try:
# Broadcast to all connected clients # Broadcast to all connected clients
count = await ws_manager.broadcast_json(payload) count = await ws_manager.broadcast_json(payload)
@@ -73,10 +74,10 @@ class EventBroadcaster:
except Exception as exc: except Exception as exc:
logger.error("Failed to broadcast event: %s", exc) logger.error("Failed to broadcast event: %s", exc)
return 0 return 0
def broadcast_sync(self, event: EventLogEntry) -> None: def broadcast_sync(self, event: EventLogEntry) -> None:
"""Synchronous wrapper for broadcast. """Synchronous wrapper for broadcast.
Use this from synchronous code - it schedules the async broadcast Use this from synchronous code - it schedules the async broadcast
in the event loop if one is running. in the event loop if one is running.
""" """
@@ -151,11 +152,11 @@ def get_event_label(event_type: str) -> str:
def format_event_for_display(event: EventLogEntry) -> dict: def format_event_for_display(event: EventLogEntry) -> dict:
"""Format event for display in activity feed. """Format event for display in activity feed.
Returns dict with display-friendly fields. Returns dict with display-friendly fields.
""" """
data = event.data or {} data = event.data or {}
# Build description based on event type # Build description based on event type
description = "" description = ""
if event.event_type.value == "task.created": if event.event_type.value == "task.created":
@@ -178,7 +179,7 @@ def format_event_for_display(event: EventLogEntry) -> dict:
val = str(data[key]) val = str(data[key])
description = val[:60] + "..." if len(val) > 60 else val description = val[:60] + "..." if len(val) > 60 else val
break break
return { return {
"id": event.id, "id": event.id,
"icon": get_event_icon(event.event_type.value), "icon": get_event_icon(event.event_type.value),

View File

@@ -16,6 +16,7 @@ logger = logging.getLogger(__name__)
@dataclass @dataclass
class Event: class Event:
"""A typed event in the system.""" """A typed event in the system."""
type: str # e.g., "agent.task.assigned", "tool.execution.completed" type: str # e.g., "agent.task.assigned", "tool.execution.completed"
source: str # Agent or component that emitted the event source: str # Agent or component that emitted the event
data: dict = field(default_factory=dict) data: dict = field(default_factory=dict)
@@ -29,15 +30,15 @@ EventHandler = Callable[[Event], Coroutine[Any, Any, None]]
class EventBus: class EventBus:
"""Async event bus for publish/subscribe pattern. """Async event bus for publish/subscribe pattern.
Usage: Usage:
bus = EventBus() bus = EventBus()
# Subscribe to events # Subscribe to events
@bus.subscribe("agent.task.*") @bus.subscribe("agent.task.*")
async def handle_task(event: Event): async def handle_task(event: Event):
print(f"Task event: {event.data}") print(f"Task event: {event.data}")
# Publish events # Publish events
await bus.publish(Event( await bus.publish(Event(
type="agent.task.assigned", type="agent.task.assigned",
@@ -45,88 +46,89 @@ class EventBus:
data={"task_id": "123", "agent": "forge"} data={"task_id": "123", "agent": "forge"}
)) ))
""" """
def __init__(self) -> None: def __init__(self) -> None:
self._subscribers: dict[str, list[EventHandler]] = {} self._subscribers: dict[str, list[EventHandler]] = {}
self._history: list[Event] = [] self._history: list[Event] = []
self._max_history = 1000 self._max_history = 1000
logger.info("EventBus initialized") logger.info("EventBus initialized")
def subscribe(self, event_pattern: str) -> Callable[[EventHandler], EventHandler]: def subscribe(self, event_pattern: str) -> Callable[[EventHandler], EventHandler]:
"""Decorator to subscribe to events matching a pattern. """Decorator to subscribe to events matching a pattern.
Patterns support wildcards: Patterns support wildcards:
- "agent.task.assigned" — exact match - "agent.task.assigned" — exact match
- "agent.task.*" — any task event - "agent.task.*" — any task event
- "agent.*" — any agent event - "agent.*" — any agent event
- "*" — all events - "*" — all events
""" """
def decorator(handler: EventHandler) -> EventHandler: def decorator(handler: EventHandler) -> EventHandler:
if event_pattern not in self._subscribers: if event_pattern not in self._subscribers:
self._subscribers[event_pattern] = [] self._subscribers[event_pattern] = []
self._subscribers[event_pattern].append(handler) self._subscribers[event_pattern].append(handler)
logger.debug("Subscribed handler to '%s'", event_pattern) logger.debug("Subscribed handler to '%s'", event_pattern)
return handler return handler
return decorator return decorator
def unsubscribe(self, event_pattern: str, handler: EventHandler) -> bool: def unsubscribe(self, event_pattern: str, handler: EventHandler) -> bool:
"""Remove a handler from a subscription.""" """Remove a handler from a subscription."""
if event_pattern not in self._subscribers: if event_pattern not in self._subscribers:
return False return False
if handler in self._subscribers[event_pattern]: if handler in self._subscribers[event_pattern]:
self._subscribers[event_pattern].remove(handler) self._subscribers[event_pattern].remove(handler)
logger.debug("Unsubscribed handler from '%s'", event_pattern) logger.debug("Unsubscribed handler from '%s'", event_pattern)
return True return True
return False return False
async def publish(self, event: Event) -> int: async def publish(self, event: Event) -> int:
"""Publish an event to all matching subscribers. """Publish an event to all matching subscribers.
Returns: Returns:
Number of handlers invoked Number of handlers invoked
""" """
# Store in history # Store in history
self._history.append(event) self._history.append(event)
if len(self._history) > self._max_history: if len(self._history) > self._max_history:
self._history = self._history[-self._max_history:] self._history = self._history[-self._max_history :]
# Find matching handlers # Find matching handlers
handlers: list[EventHandler] = [] handlers: list[EventHandler] = []
for pattern, pattern_handlers in self._subscribers.items(): for pattern, pattern_handlers in self._subscribers.items():
if self._match_pattern(event.type, pattern): if self._match_pattern(event.type, pattern):
handlers.extend(pattern_handlers) handlers.extend(pattern_handlers)
# Invoke handlers concurrently # Invoke handlers concurrently
if handlers: if handlers:
await asyncio.gather( await asyncio.gather(
*[self._invoke_handler(h, event) for h in handlers], *[self._invoke_handler(h, event) for h in handlers], return_exceptions=True
return_exceptions=True
) )
logger.debug("Published event '%s' to %d handlers", event.type, len(handlers)) logger.debug("Published event '%s' to %d handlers", event.type, len(handlers))
return len(handlers) return len(handlers)
async def _invoke_handler(self, handler: EventHandler, event: Event) -> None: async def _invoke_handler(self, handler: EventHandler, event: Event) -> None:
"""Invoke a handler with error handling.""" """Invoke a handler with error handling."""
try: try:
await handler(event) await handler(event)
except Exception as exc: except Exception as exc:
logger.error("Event handler failed for '%s': %s", event.type, exc) logger.error("Event handler failed for '%s': %s", event.type, exc)
def _match_pattern(self, event_type: str, pattern: str) -> bool: def _match_pattern(self, event_type: str, pattern: str) -> bool:
"""Check if event type matches a wildcard pattern.""" """Check if event type matches a wildcard pattern."""
if pattern == "*": if pattern == "*":
return True return True
if pattern.endswith(".*"): if pattern.endswith(".*"):
prefix = pattern[:-2] prefix = pattern[:-2]
return event_type.startswith(prefix + ".") return event_type.startswith(prefix + ".")
return event_type == pattern return event_type == pattern
def get_history( def get_history(
self, self,
event_type: str | None = None, event_type: str | None = None,
@@ -135,15 +137,15 @@ class EventBus:
) -> list[Event]: ) -> list[Event]:
"""Get recent event history with optional filtering.""" """Get recent event history with optional filtering."""
events = self._history events = self._history
if event_type: if event_type:
events = [e for e in events if e.type == event_type] events = [e for e in events if e.type == event_type]
if source: if source:
events = [e for e in events if e.source == source] events = [e for e in events if e.source == source]
return events[-limit:] return events[-limit:]
def clear_history(self) -> None: def clear_history(self) -> None:
"""Clear event history.""" """Clear event history."""
self._history.clear() self._history.clear()
@@ -156,11 +158,13 @@ event_bus = EventBus()
# Convenience functions # Convenience functions
async def emit(event_type: str, source: str, data: dict) -> int: async def emit(event_type: str, source: str, data: dict) -> int:
"""Quick emit an event.""" """Quick emit an event."""
return await event_bus.publish(Event( return await event_bus.publish(
type=event_type, Event(
source=source, type=event_type,
data=data, source=source,
)) data=data,
)
)
def on(event_pattern: str) -> Callable[[EventHandler], EventHandler]: def on(event_pattern: str) -> Callable[[EventHandler], EventHandler]:

View File

@@ -11,7 +11,7 @@ Usage:
result = await git_hand.run("status") result = await git_hand.run("status")
""" """
from infrastructure.hands.shell import shell_hand
from infrastructure.hands.git import git_hand from infrastructure.hands.git import git_hand
from infrastructure.hands.shell import shell_hand
__all__ = ["shell_hand", "git_hand"] __all__ = ["shell_hand", "git_hand"]

View File

@@ -25,16 +25,18 @@ from config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Operations that require explicit confirmation before execution # Operations that require explicit confirmation before execution
DESTRUCTIVE_OPS = frozenset({ DESTRUCTIVE_OPS = frozenset(
"push --force", {
"push -f", "push --force",
"reset --hard", "push -f",
"clean -fd", "reset --hard",
"clean -f", "clean -fd",
"branch -D", "clean -f",
"checkout -- .", "branch -D",
"restore .", "checkout -- .",
}) "restore .",
}
)
@dataclass @dataclass
@@ -190,7 +192,9 @@ class GitHand:
flag = "-b" if create else "" flag = "-b" if create else ""
return await self.run(f"checkout {flag} {branch}".strip()) return await self.run(f"checkout {flag} {branch}".strip())
async def push(self, remote: str = "origin", branch: str = "", force: bool = False) -> GitResult: async def push(
self, remote: str = "origin", branch: str = "", force: bool = False
) -> GitResult:
"""Push to remote. Force-push requires explicit opt-in.""" """Push to remote. Force-push requires explicit opt-in."""
args = f"push -u {remote} {branch}".strip() args = f"push -u {remote} {branch}".strip()
if force: if force:

View File

@@ -26,15 +26,17 @@ from config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Commands that are always blocked regardless of allow-list # Commands that are always blocked regardless of allow-list
_BLOCKED_COMMANDS = frozenset({ _BLOCKED_COMMANDS = frozenset(
"rm -rf /", {
"rm -rf /*", "rm -rf /",
"mkfs", "rm -rf /*",
"dd if=/dev/zero", "mkfs",
":(){ :|:& };:", # fork bomb "dd if=/dev/zero",
"> /dev/sda", ":(){ :|:& };:", # fork bomb
"chmod -R 777 /", "> /dev/sda",
}) "chmod -R 777 /",
}
)
# Default allow-list: safe build/dev commands # Default allow-list: safe build/dev commands
DEFAULT_ALLOWED_PREFIXES = ( DEFAULT_ALLOWED_PREFIXES = (
@@ -199,9 +201,7 @@ class ShellHand:
proc.kill() proc.kill()
await proc.wait() await proc.wait()
latency = (time.time() - start) * 1000 latency = (time.time() - start) * 1000
logger.warning( logger.warning("Shell command timed out after %ds: %s", effective_timeout, command)
"Shell command timed out after %ds: %s", effective_timeout, command
)
return ShellResult( return ShellResult(
command=command, command=command,
success=False, success=False,

View File

@@ -11,15 +11,17 @@ the tool registry.
import logging import logging
from typing import Any from typing import Any
from infrastructure.hands.shell import shell_hand
from infrastructure.hands.git import git_hand from infrastructure.hands.git import git_hand
from infrastructure.hands.shell import shell_hand
try: try:
from mcp.schemas.base import create_tool_schema from mcp.schemas.base import create_tool_schema
except ImportError: except ImportError:
def create_tool_schema(**kwargs): def create_tool_schema(**kwargs):
return kwargs return kwargs
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ── Tool schemas ───────────────────────────────────────────────────────────── # ── Tool schemas ─────────────────────────────────────────────────────────────
@@ -83,6 +85,7 @@ PERSONA_LOCAL_HAND_MAP: dict[str, list[str]] = {
# ── Handlers ───────────────────────────────────────────────────────────────── # ── Handlers ─────────────────────────────────────────────────────────────────
async def _handle_shell(**kwargs: Any) -> str: async def _handle_shell(**kwargs: Any) -> str:
"""Handler for the shell MCP tool.""" """Handler for the shell MCP tool."""
command = kwargs.get("command", "") command = kwargs.get("command", "")

View File

@@ -1,12 +1,5 @@
"""Infrastructure models package.""" """Infrastructure models package."""
from infrastructure.models.registry import (
CustomModel,
ModelFormat,
ModelRegistry,
ModelRole,
model_registry,
)
from infrastructure.models.multimodal import ( from infrastructure.models.multimodal import (
ModelCapability, ModelCapability,
ModelInfo, ModelInfo,
@@ -17,6 +10,13 @@ from infrastructure.models.multimodal import (
model_supports_vision, model_supports_vision,
pull_model_with_fallback, pull_model_with_fallback,
) )
from infrastructure.models.registry import (
CustomModel,
ModelFormat,
ModelRegistry,
ModelRole,
model_registry,
)
__all__ = [ __all__ = [
# Registry # Registry

View File

@@ -21,39 +21,130 @@ logger = logging.getLogger(__name__)
class ModelCapability(Enum): class ModelCapability(Enum):
"""Capabilities a model can have.""" """Capabilities a model can have."""
TEXT = auto() # Standard text completion
VISION = auto() # Image understanding TEXT = auto() # Standard text completion
AUDIO = auto() # Audio/speech processing VISION = auto() # Image understanding
TOOLS = auto() # Function calling / tool use AUDIO = auto() # Audio/speech processing
JSON = auto() # Structured output / JSON mode TOOLS = auto() # Function calling / tool use
STREAMING = auto() # Streaming responses JSON = auto() # Structured output / JSON mode
STREAMING = auto() # Streaming responses
# Known model capabilities (local Ollama models) # Known model capabilities (local Ollama models)
# These are used when we can't query the model directly # These are used when we can't query the model directly
KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = { KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
# Llama 3.x series # Llama 3.x series
"llama3.1": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, "llama3.1": {
"llama3.1:8b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, ModelCapability.TEXT,
"llama3.1:8b-instruct": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, ModelCapability.TOOLS,
"llama3.1:70b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, ModelCapability.JSON,
"llama3.1:405b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, ModelCapability.STREAMING,
"llama3.2": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION}, },
"llama3.1:8b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"llama3.1:8b-instruct": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"llama3.1:70b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"llama3.1:405b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"llama3.2": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
ModelCapability.VISION,
},
"llama3.2:1b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, "llama3.2:1b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
"llama3.2:3b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION}, "llama3.2:3b": {
"llama3.2-vision": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION}, ModelCapability.TEXT,
"llama3.2-vision:11b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION}, ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
ModelCapability.VISION,
},
"llama3.2-vision": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
ModelCapability.VISION,
},
"llama3.2-vision:11b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
ModelCapability.VISION,
},
# Qwen series # Qwen series
"qwen2.5": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, "qwen2.5": {
"qwen2.5:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, ModelCapability.TEXT,
"qwen2.5:14b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, ModelCapability.TOOLS,
"qwen2.5:32b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, ModelCapability.JSON,
"qwen2.5:72b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, ModelCapability.STREAMING,
"qwen2.5-vl": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION}, },
"qwen2.5-vl:3b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION}, "qwen2.5:7b": {
"qwen2.5-vl:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION}, ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"qwen2.5:14b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"qwen2.5:32b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"qwen2.5:72b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"qwen2.5-vl": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
ModelCapability.VISION,
},
"qwen2.5-vl:3b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
ModelCapability.VISION,
},
"qwen2.5-vl:7b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
ModelCapability.VISION,
},
# DeepSeek series # DeepSeek series
"deepseek-r1": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, "deepseek-r1": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
"deepseek-r1:1.5b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, "deepseek-r1:1.5b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
@@ -61,21 +152,48 @@ KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
"deepseek-r1:14b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, "deepseek-r1:14b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
"deepseek-r1:32b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, "deepseek-r1:32b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
"deepseek-r1:70b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, "deepseek-r1:70b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
"deepseek-v3": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, "deepseek-v3": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
# Gemma series # Gemma series
"gemma2": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, "gemma2": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
"gemma2:2b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, "gemma2:2b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
"gemma2:9b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, "gemma2:9b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
"gemma2:27b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, "gemma2:27b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
# Mistral series # Mistral series
"mistral": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, "mistral": {
"mistral:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, ModelCapability.TEXT,
"mistral-nemo": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, ModelCapability.TOOLS,
"mistral-small": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, ModelCapability.JSON,
"mistral-large": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, ModelCapability.STREAMING,
},
"mistral:7b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"mistral-nemo": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"mistral-small": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"mistral-large": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
# Vision-specific models # Vision-specific models
"llava": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING}, "llava": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
"llava:7b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING}, "llava:7b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
@@ -86,21 +204,48 @@ KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
"bakllava": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING}, "bakllava": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
"moondream": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING}, "moondream": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
"moondream:1.8b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING}, "moondream:1.8b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
# Phi series # Phi series
"phi3": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, "phi3": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
"phi3:3.8b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, "phi3:3.8b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
"phi3:14b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING}, "phi3:14b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
"phi4": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, "phi4": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
# Command R # Command R
"command-r": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, "command-r": {
"command-r:35b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, ModelCapability.TEXT,
"command-r-plus": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"command-r:35b": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"command-r-plus": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
# Granite (IBM) # Granite (IBM)
"granite3-dense": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, "granite3-dense": {
"granite3-moe": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING}, ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
"granite3-moe": {
ModelCapability.TEXT,
ModelCapability.TOOLS,
ModelCapability.JSON,
ModelCapability.STREAMING,
},
} }
@@ -108,15 +253,15 @@ KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
# These are tried in order when the primary model doesn't support a capability # These are tried in order when the primary model doesn't support a capability
DEFAULT_FALLBACK_CHAINS: dict[ModelCapability, list[str]] = { DEFAULT_FALLBACK_CHAINS: dict[ModelCapability, list[str]] = {
ModelCapability.VISION: [ ModelCapability.VISION: [
"llama3.2:3b", # Fast vision model "llama3.2:3b", # Fast vision model
"llava:7b", # Classic vision model "llava:7b", # Classic vision model
"qwen2.5-vl:3b", # Qwen vision "qwen2.5-vl:3b", # Qwen vision
"moondream:1.8b", # Tiny vision model (last resort) "moondream:1.8b", # Tiny vision model (last resort)
], ],
ModelCapability.TOOLS: [ ModelCapability.TOOLS: [
"llama3.1:8b-instruct", # Best tool use "llama3.1:8b-instruct", # Best tool use
"llama3.2:3b", # Smaller but capable "llama3.2:3b", # Smaller but capable
"qwen2.5:7b", # Reliable fallback "qwen2.5:7b", # Reliable fallback
], ],
ModelCapability.AUDIO: [ ModelCapability.AUDIO: [
# Audio models are less common in Ollama # Audio models are less common in Ollama
@@ -128,13 +273,14 @@ DEFAULT_FALLBACK_CHAINS: dict[ModelCapability, list[str]] = {
@dataclass @dataclass
class ModelInfo: class ModelInfo:
"""Information about a model's capabilities and availability.""" """Information about a model's capabilities and availability."""
name: str name: str
capabilities: set[ModelCapability] = field(default_factory=set) capabilities: set[ModelCapability] = field(default_factory=set)
is_available: bool = False is_available: bool = False
is_pulled: bool = False is_pulled: bool = False
size_mb: Optional[int] = None size_mb: Optional[int] = None
description: str = "" description: str = ""
def supports(self, capability: ModelCapability) -> bool: def supports(self, capability: ModelCapability) -> bool:
"""Check if model supports a specific capability.""" """Check if model supports a specific capability."""
return capability in self.capabilities return capability in self.capabilities
@@ -142,26 +288,26 @@ class ModelInfo:
class MultiModalManager: class MultiModalManager:
"""Manages multi-modal model capabilities and fallback chains. """Manages multi-modal model capabilities and fallback chains.
This class: This class:
1. Detects what capabilities each model has 1. Detects what capabilities each model has
2. Maintains fallback chains for different capabilities 2. Maintains fallback chains for different capabilities
3. Pulls models on-demand with automatic fallback 3. Pulls models on-demand with automatic fallback
4. Routes requests to appropriate models based on content type 4. Routes requests to appropriate models based on content type
""" """
def __init__(self, ollama_url: Optional[str] = None) -> None: def __init__(self, ollama_url: Optional[str] = None) -> None:
self.ollama_url = ollama_url or settings.ollama_url self.ollama_url = ollama_url or settings.ollama_url
self._available_models: dict[str, ModelInfo] = {} self._available_models: dict[str, ModelInfo] = {}
self._fallback_chains: dict[ModelCapability, list[str]] = dict(DEFAULT_FALLBACK_CHAINS) self._fallback_chains: dict[ModelCapability, list[str]] = dict(DEFAULT_FALLBACK_CHAINS)
self._refresh_available_models() self._refresh_available_models()
def _refresh_available_models(self) -> None: def _refresh_available_models(self) -> None:
"""Query Ollama for available models.""" """Query Ollama for available models."""
try: try:
import urllib.request
import json import json
import urllib.request
url = self.ollama_url.replace("localhost", "127.0.0.1") url = self.ollama_url.replace("localhost", "127.0.0.1")
req = urllib.request.Request( req = urllib.request.Request(
f"{url}/api/tags", f"{url}/api/tags",
@@ -170,7 +316,7 @@ class MultiModalManager:
) )
with urllib.request.urlopen(req, timeout=5) as response: with urllib.request.urlopen(req, timeout=5) as response:
data = json.loads(response.read().decode()) data = json.loads(response.read().decode())
for model_data in data.get("models", []): for model_data in data.get("models", []):
name = model_data.get("name", "") name = model_data.get("name", "")
self._available_models[name] = ModelInfo( self._available_models[name] = ModelInfo(
@@ -181,58 +327,53 @@ class MultiModalManager:
size_mb=model_data.get("size", 0) // (1024 * 1024), size_mb=model_data.get("size", 0) // (1024 * 1024),
description=model_data.get("details", {}).get("family", ""), description=model_data.get("details", {}).get("family", ""),
) )
logger.info("Found %d models in Ollama", len(self._available_models)) logger.info("Found %d models in Ollama", len(self._available_models))
except Exception as exc: except Exception as exc:
logger.warning("Could not refresh available models: %s", exc) logger.warning("Could not refresh available models: %s", exc)
def _detect_capabilities(self, model_name: str) -> set[ModelCapability]: def _detect_capabilities(self, model_name: str) -> set[ModelCapability]:
"""Detect capabilities for a model based on known data.""" """Detect capabilities for a model based on known data."""
# Normalize model name (strip tags for lookup) # Normalize model name (strip tags for lookup)
base_name = model_name.split(":")[0] base_name = model_name.split(":")[0]
# Try exact match first # Try exact match first
if model_name in KNOWN_MODEL_CAPABILITIES: if model_name in KNOWN_MODEL_CAPABILITIES:
return set(KNOWN_MODEL_CAPABILITIES[model_name]) return set(KNOWN_MODEL_CAPABILITIES[model_name])
# Try base name match # Try base name match
if base_name in KNOWN_MODEL_CAPABILITIES: if base_name in KNOWN_MODEL_CAPABILITIES:
return set(KNOWN_MODEL_CAPABILITIES[base_name]) return set(KNOWN_MODEL_CAPABILITIES[base_name])
# Default to text-only for unknown models # Default to text-only for unknown models
logger.debug("Unknown model %s, defaulting to TEXT only", model_name) logger.debug("Unknown model %s, defaulting to TEXT only", model_name)
return {ModelCapability.TEXT, ModelCapability.STREAMING} return {ModelCapability.TEXT, ModelCapability.STREAMING}
def get_model_capabilities(self, model_name: str) -> set[ModelCapability]: def get_model_capabilities(self, model_name: str) -> set[ModelCapability]:
"""Get capabilities for a specific model.""" """Get capabilities for a specific model."""
if model_name in self._available_models: if model_name in self._available_models:
return self._available_models[model_name].capabilities return self._available_models[model_name].capabilities
return self._detect_capabilities(model_name) return self._detect_capabilities(model_name)
def model_supports(self, model_name: str, capability: ModelCapability) -> bool: def model_supports(self, model_name: str, capability: ModelCapability) -> bool:
"""Check if a model supports a specific capability.""" """Check if a model supports a specific capability."""
capabilities = self.get_model_capabilities(model_name) capabilities = self.get_model_capabilities(model_name)
return capability in capabilities return capability in capabilities
def get_models_with_capability(self, capability: ModelCapability) -> list[ModelInfo]: def get_models_with_capability(self, capability: ModelCapability) -> list[ModelInfo]:
"""Get all available models that support a capability.""" """Get all available models that support a capability."""
return [ return [info for info in self._available_models.values() if capability in info.capabilities]
info for info in self._available_models.values()
if capability in info.capabilities
]
def get_best_model_for( def get_best_model_for(
self, self, capability: ModelCapability, preferred_model: Optional[str] = None
capability: ModelCapability,
preferred_model: Optional[str] = None
) -> Optional[str]: ) -> Optional[str]:
"""Get the best available model for a specific capability. """Get the best available model for a specific capability.
Args: Args:
capability: The required capability capability: The required capability
preferred_model: Preferred model to use if available and capable preferred_model: Preferred model to use if available and capable
Returns: Returns:
Model name or None if no suitable model found Model name or None if no suitable model found
""" """
@@ -243,25 +384,26 @@ class MultiModalManager:
return preferred_model return preferred_model
logger.debug( logger.debug(
"Preferred model %s doesn't support %s, checking fallbacks", "Preferred model %s doesn't support %s, checking fallbacks",
preferred_model, capability.name preferred_model,
capability.name,
) )
# Check fallback chain for this capability # Check fallback chain for this capability
fallback_chain = self._fallback_chains.get(capability, []) fallback_chain = self._fallback_chains.get(capability, [])
for model_name in fallback_chain: for model_name in fallback_chain:
if model_name in self._available_models: if model_name in self._available_models:
logger.debug("Using fallback model %s for %s", model_name, capability.name) logger.debug("Using fallback model %s for %s", model_name, capability.name)
return model_name return model_name
# Find any available model with this capability # Find any available model with this capability
capable_models = self.get_models_with_capability(capability) capable_models = self.get_models_with_capability(capability)
if capable_models: if capable_models:
# Sort by size (prefer smaller/faster models as fallback) # Sort by size (prefer smaller/faster models as fallback)
capable_models.sort(key=lambda m: m.size_mb or float('inf')) capable_models.sort(key=lambda m: m.size_mb or float("inf"))
return capable_models[0].name return capable_models[0].name
return None return None
def pull_model_with_fallback( def pull_model_with_fallback(
self, self,
primary_model: str, primary_model: str,
@@ -269,58 +411,58 @@ class MultiModalManager:
auto_pull: bool = True, auto_pull: bool = True,
) -> tuple[str, bool]: ) -> tuple[str, bool]:
"""Pull a model with automatic fallback if unavailable. """Pull a model with automatic fallback if unavailable.
Args: Args:
primary_model: The desired model to use primary_model: The desired model to use
capability: Required capability (for finding fallback) capability: Required capability (for finding fallback)
auto_pull: Whether to attempt pulling missing models auto_pull: Whether to attempt pulling missing models
Returns: Returns:
Tuple of (model_name, is_fallback) Tuple of (model_name, is_fallback)
""" """
# Check if primary model is already available # Check if primary model is already available
if primary_model in self._available_models: if primary_model in self._available_models:
return primary_model, False return primary_model, False
# Try to pull the primary model # Try to pull the primary model
if auto_pull: if auto_pull:
if self._pull_model(primary_model): if self._pull_model(primary_model):
return primary_model, False return primary_model, False
# Need to find a fallback # Need to find a fallback
if capability: if capability:
fallback = self.get_best_model_for(capability, primary_model) fallback = self.get_best_model_for(capability, primary_model)
if fallback: if fallback:
logger.info( logger.info(
"Primary model %s unavailable, using fallback %s", "Primary model %s unavailable, using fallback %s", primary_model, fallback
primary_model, fallback
) )
return fallback, True return fallback, True
# Last resort: use the configured default model # Last resort: use the configured default model
default_model = settings.ollama_model default_model = settings.ollama_model
if default_model in self._available_models: if default_model in self._available_models:
logger.warning( logger.warning(
"Falling back to default model %s (primary: %s unavailable)", "Falling back to default model %s (primary: %s unavailable)",
default_model, primary_model default_model,
primary_model,
) )
return default_model, True return default_model, True
# Absolute last resort # Absolute last resort
return primary_model, False return primary_model, False
def _pull_model(self, model_name: str) -> bool: def _pull_model(self, model_name: str) -> bool:
"""Attempt to pull a model from Ollama. """Attempt to pull a model from Ollama.
Returns: Returns:
True if successful or model already exists True if successful or model already exists
""" """
try: try:
import urllib.request
import json import json
import urllib.request
logger.info("Pulling model: %s", model_name) logger.info("Pulling model: %s", model_name)
url = self.ollama_url.replace("localhost", "127.0.0.1") url = self.ollama_url.replace("localhost", "127.0.0.1")
req = urllib.request.Request( req = urllib.request.Request(
f"{url}/api/pull", f"{url}/api/pull",
@@ -328,7 +470,7 @@ class MultiModalManager:
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
data=json.dumps({"name": model_name, "stream": False}).encode(), data=json.dumps({"name": model_name, "stream": False}).encode(),
) )
with urllib.request.urlopen(req, timeout=300) as response: with urllib.request.urlopen(req, timeout=300) as response:
if response.status == 200: if response.status == 200:
logger.info("Successfully pulled model: %s", model_name) logger.info("Successfully pulled model: %s", model_name)
@@ -338,55 +480,51 @@ class MultiModalManager:
else: else:
logger.error("Failed to pull %s: HTTP %s", model_name, response.status) logger.error("Failed to pull %s: HTTP %s", model_name, response.status)
return False return False
except Exception as exc: except Exception as exc:
logger.error("Error pulling model %s: %s", model_name, exc) logger.error("Error pulling model %s: %s", model_name, exc)
return False return False
def configure_fallback_chain( def configure_fallback_chain(self, capability: ModelCapability, models: list[str]) -> None:
self,
capability: ModelCapability,
models: list[str]
) -> None:
"""Configure a custom fallback chain for a capability.""" """Configure a custom fallback chain for a capability."""
self._fallback_chains[capability] = models self._fallback_chains[capability] = models
logger.info("Configured fallback chain for %s: %s", capability.name, models) logger.info("Configured fallback chain for %s: %s", capability.name, models)
def get_fallback_chain(self, capability: ModelCapability) -> list[str]: def get_fallback_chain(self, capability: ModelCapability) -> list[str]:
"""Get the fallback chain for a capability.""" """Get the fallback chain for a capability."""
return list(self._fallback_chains.get(capability, [])) return list(self._fallback_chains.get(capability, []))
def list_available_models(self) -> list[ModelInfo]: def list_available_models(self) -> list[ModelInfo]:
"""List all available models with their capabilities.""" """List all available models with their capabilities."""
return list(self._available_models.values()) return list(self._available_models.values())
def refresh(self) -> None: def refresh(self) -> None:
"""Refresh the list of available models.""" """Refresh the list of available models."""
self._refresh_available_models() self._refresh_available_models()
def get_model_for_content( def get_model_for_content(
self, self,
content_type: str, # "text", "image", "audio", "multimodal" content_type: str, # "text", "image", "audio", "multimodal"
preferred_model: Optional[str] = None, preferred_model: Optional[str] = None,
) -> tuple[str, bool]: ) -> tuple[str, bool]:
"""Get appropriate model based on content type. """Get appropriate model based on content type.
Args: Args:
content_type: Type of content (text, image, audio, multimodal) content_type: Type of content (text, image, audio, multimodal)
preferred_model: User's preferred model preferred_model: User's preferred model
Returns: Returns:
Tuple of (model_name, is_fallback) Tuple of (model_name, is_fallback)
""" """
content_type = content_type.lower() content_type = content_type.lower()
if content_type in ("image", "vision", "multimodal"): if content_type in ("image", "vision", "multimodal"):
# For vision content, we need a vision-capable model # For vision content, we need a vision-capable model
return self.pull_model_with_fallback( return self.pull_model_with_fallback(
preferred_model or "llava:7b", preferred_model or "llava:7b",
capability=ModelCapability.VISION, capability=ModelCapability.VISION,
) )
elif content_type == "audio": elif content_type == "audio":
# Audio support is limited in Ollama # Audio support is limited in Ollama
# Would need specific audio models # Would need specific audio models
@@ -395,7 +533,7 @@ class MultiModalManager:
preferred_model or settings.ollama_model, preferred_model or settings.ollama_model,
capability=ModelCapability.TEXT, capability=ModelCapability.TEXT,
) )
else: else:
# Standard text content # Standard text content
return self.pull_model_with_fallback( return self.pull_model_with_fallback(
@@ -417,8 +555,7 @@ def get_multimodal_manager() -> MultiModalManager:
def get_model_for_capability( def get_model_for_capability(
capability: ModelCapability, capability: ModelCapability, preferred_model: Optional[str] = None
preferred_model: Optional[str] = None
) -> Optional[str]: ) -> Optional[str]:
"""Convenience function to get best model for a capability.""" """Convenience function to get best model for a capability."""
return get_multimodal_manager().get_best_model_for(capability, preferred_model) return get_multimodal_manager().get_best_model_for(capability, preferred_model)
@@ -430,9 +567,7 @@ def pull_model_with_fallback(
auto_pull: bool = True, auto_pull: bool = True,
) -> tuple[str, bool]: ) -> tuple[str, bool]:
"""Convenience function to pull model with fallback.""" """Convenience function to pull model with fallback."""
return get_multimodal_manager().pull_model_with_fallback( return get_multimodal_manager().pull_model_with_fallback(primary_model, capability, auto_pull)
primary_model, capability, auto_pull
)
def model_supports_vision(model_name: str) -> bool: def model_supports_vision(model_name: str) -> bool:

View File

@@ -26,26 +26,29 @@ DB_PATH = Path("data/swarm.db")
class ModelFormat(str, Enum): class ModelFormat(str, Enum):
"""Supported model weight formats.""" """Supported model weight formats."""
GGUF = "gguf" # Ollama-compatible quantised weights
SAFETENSORS = "safetensors" # HuggingFace safetensors GGUF = "gguf" # Ollama-compatible quantised weights
HF_CHECKPOINT = "hf" # Full HuggingFace checkpoint directory SAFETENSORS = "safetensors" # HuggingFace safetensors
OLLAMA = "ollama" # Already loaded in Ollama by name HF_CHECKPOINT = "hf" # Full HuggingFace checkpoint directory
OLLAMA = "ollama" # Already loaded in Ollama by name
class ModelRole(str, Enum): class ModelRole(str, Enum):
"""Role a model can play in the system (OpenClaw-RL style).""" """Role a model can play in the system (OpenClaw-RL style)."""
GENERAL = "general" # Default agent inference
REWARD = "reward" # Process Reward Model (PRM) scoring GENERAL = "general" # Default agent inference
TEACHER = "teacher" # On-policy distillation teacher REWARD = "reward" # Process Reward Model (PRM) scoring
JUDGE = "judge" # Output quality evaluation TEACHER = "teacher" # On-policy distillation teacher
JUDGE = "judge" # Output quality evaluation
@dataclass @dataclass
class CustomModel: class CustomModel:
"""A registered custom model.""" """A registered custom model."""
name: str name: str
format: ModelFormat format: ModelFormat
path: str # Absolute path or Ollama model name path: str # Absolute path or Ollama model name
role: ModelRole = ModelRole.GENERAL role: ModelRole = ModelRole.GENERAL
context_window: int = 4096 context_window: int = 4096
description: str = "" description: str = ""
@@ -141,10 +144,16 @@ class ModelRegistry:
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", """,
( (
model.name, model.format.value, model.path, model.name,
model.role.value, model.context_window, model.description, model.format.value,
model.registered_at, int(model.active), model.path,
model.default_temperature, model.max_tokens, model.role.value,
model.context_window,
model.description,
model.registered_at,
int(model.active),
model.default_temperature,
model.max_tokens,
), ),
) )
conn.commit() conn.commit()
@@ -160,9 +169,7 @@ class ModelRegistry:
return False return False
conn = _get_conn() conn = _get_conn()
conn.execute("DELETE FROM custom_models WHERE name = ?", (name,)) conn.execute("DELETE FROM custom_models WHERE name = ?", (name,))
conn.execute( conn.execute("DELETE FROM agent_model_assignments WHERE model_name = ?", (name,))
"DELETE FROM agent_model_assignments WHERE model_name = ?", (name,)
)
conn.commit() conn.commit()
conn.close() conn.close()
del self._models[name] del self._models[name]

View File

@@ -9,8 +9,8 @@ No cloud push services — everything stays local.
""" """
import logging import logging
import subprocess
import platform import platform
import subprocess
from collections import deque from collections import deque
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timezone from datetime import datetime, timezone
@@ -25,9 +25,7 @@ class Notification:
title: str title: str
message: str message: str
category: str # swarm | task | agent | system | payment category: str # swarm | task | agent | system | payment
timestamp: str = field( timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
default_factory=lambda: datetime.now(timezone.utc).isoformat()
)
read: bool = False read: bool = False
@@ -74,9 +72,11 @@ class PushNotifier:
def _native_notify(self, title: str, message: str) -> None: def _native_notify(self, title: str, message: str) -> None:
"""Send a native macOS notification via osascript.""" """Send a native macOS notification via osascript."""
try: try:
safe_message = message.replace("\\", "\\\\").replace('"', '\\"')
safe_title = title.replace("\\", "\\\\").replace('"', '\\"')
script = ( script = (
f'display notification "{message}" ' f'display notification "{safe_message}" '
f'with title "Agent Dashboard" subtitle "{title}"' f'with title "Agent Dashboard" subtitle "{safe_title}"'
) )
subprocess.Popen( subprocess.Popen(
["osascript", "-e", script], ["osascript", "-e", script],
@@ -114,7 +114,7 @@ class PushNotifier:
def clear(self) -> None: def clear(self) -> None:
self._notifications.clear() self._notifications.clear()
def add_listener(self, callback) -> None: def add_listener(self, callback: "Callable[[Notification], None]") -> None:
"""Register a callback for real-time notification delivery.""" """Register a callback for real-time notification delivery."""
self._listeners.append(callback) self._listeners.append(callback)
@@ -139,10 +139,7 @@ async def notify_briefing_ready(briefing) -> None:
logger.info("Briefing ready but no pending approvals — skipping native notification") logger.info("Briefing ready but no pending approvals — skipping native notification")
return return
message = ( message = f"Your morning briefing is ready. " f"{n_approvals} item(s) await your approval."
f"Your morning briefing is ready. "
f"{n_approvals} item(s) await your approval."
)
notifier.notify( notifier.notify(
title="Morning Briefing Ready", title="Morning Briefing Ready",
message=message, message=message,

View File

@@ -156,33 +156,23 @@ class OpenFangClient:
async def browse(self, url: str, instruction: str = "") -> HandResult: async def browse(self, url: str, instruction: str = "") -> HandResult:
"""Web automation via OpenFang's Browser hand.""" """Web automation via OpenFang's Browser hand."""
return await self.execute_hand( return await self.execute_hand("browser", {"url": url, "instruction": instruction})
"browser", {"url": url, "instruction": instruction}
)
async def collect(self, target: str, depth: str = "shallow") -> HandResult: async def collect(self, target: str, depth: str = "shallow") -> HandResult:
"""OSINT collection via OpenFang's Collector hand.""" """OSINT collection via OpenFang's Collector hand."""
return await self.execute_hand( return await self.execute_hand("collector", {"target": target, "depth": depth})
"collector", {"target": target, "depth": depth}
)
async def predict(self, question: str, horizon: str = "1w") -> HandResult: async def predict(self, question: str, horizon: str = "1w") -> HandResult:
"""Superforecasting via OpenFang's Predictor hand.""" """Superforecasting via OpenFang's Predictor hand."""
return await self.execute_hand( return await self.execute_hand("predictor", {"question": question, "horizon": horizon})
"predictor", {"question": question, "horizon": horizon}
)
async def find_leads(self, icp: str, max_results: int = 10) -> HandResult: async def find_leads(self, icp: str, max_results: int = 10) -> HandResult:
"""Prospect discovery via OpenFang's Lead hand.""" """Prospect discovery via OpenFang's Lead hand."""
return await self.execute_hand( return await self.execute_hand("lead", {"icp": icp, "max_results": max_results})
"lead", {"icp": icp, "max_results": max_results}
)
async def research(self, topic: str, depth: str = "standard") -> HandResult: async def research(self, topic: str, depth: str = "standard") -> HandResult:
"""Deep research via OpenFang's Researcher hand.""" """Deep research via OpenFang's Researcher hand."""
return await self.execute_hand( return await self.execute_hand("researcher", {"topic": topic, "depth": depth})
"researcher", {"topic": topic, "depth": depth}
)
# ── Inventory ──────────────────────────────────────────────────────────── # ── Inventory ────────────────────────────────────────────────────────────

View File

@@ -22,9 +22,11 @@ from infrastructure.openfang.client import OPENFANG_HANDS, openfang_client
try: try:
from mcp.schemas.base import create_tool_schema from mcp.schemas.base import create_tool_schema
except ImportError: except ImportError:
def create_tool_schema(**kwargs): def create_tool_schema(**kwargs):
return kwargs return kwargs
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ── Tool schemas ───────────────────────────────────────────────────────────── # ── Tool schemas ─────────────────────────────────────────────────────────────

View File

@@ -1,7 +1,7 @@
"""Cascade LLM Router — Automatic failover between providers.""" """Cascade LLM Router — Automatic failover between providers."""
from .cascade import CascadeRouter, Provider, ProviderStatus, get_router
from .api import router from .api import router
from .cascade import CascadeRouter, Provider, ProviderStatus, get_router
__all__ = [ __all__ = [
"CascadeRouter", "CascadeRouter",

View File

@@ -15,6 +15,7 @@ router = APIRouter(prefix="/api/v1/router", tags=["router"])
class CompletionRequest(BaseModel): class CompletionRequest(BaseModel):
"""Request body for completions.""" """Request body for completions."""
messages: list[dict[str, str]] messages: list[dict[str, str]]
model: str | None = None model: str | None = None
temperature: float = 0.7 temperature: float = 0.7
@@ -23,6 +24,7 @@ class CompletionRequest(BaseModel):
class CompletionResponse(BaseModel): class CompletionResponse(BaseModel):
"""Response from completion endpoint.""" """Response from completion endpoint."""
content: str content: str
provider: str provider: str
model: str model: str
@@ -31,6 +33,7 @@ class CompletionResponse(BaseModel):
class ProviderControl(BaseModel): class ProviderControl(BaseModel):
"""Control a provider's status.""" """Control a provider's status."""
action: str # "enable", "disable", "reset_circuit" action: str # "enable", "disable", "reset_circuit"
@@ -45,7 +48,7 @@ async def complete(
cascade: Annotated[CascadeRouter, Depends(get_cascade_router)], cascade: Annotated[CascadeRouter, Depends(get_cascade_router)],
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Complete a conversation with automatic failover. """Complete a conversation with automatic failover.
Routes through providers in priority order until one succeeds. Routes through providers in priority order until one succeeds.
""" """
try: try:
@@ -108,30 +111,32 @@ async def control_provider(
if p.name == provider_name: if p.name == provider_name:
provider = p provider = p
break break
if not provider: if not provider:
raise HTTPException(status_code=404, detail=f"Provider {provider_name} not found") raise HTTPException(status_code=404, detail=f"Provider {provider_name} not found")
if control.action == "enable": if control.action == "enable":
provider.enabled = True provider.enabled = True
provider.status = provider.status.__class__.HEALTHY provider.status = provider.status.__class__.HEALTHY
return {"message": f"Provider {provider_name} enabled"} return {"message": f"Provider {provider_name} enabled"}
elif control.action == "disable": elif control.action == "disable":
provider.enabled = False provider.enabled = False
from .cascade import ProviderStatus from .cascade import ProviderStatus
provider.status = ProviderStatus.DISABLED provider.status = ProviderStatus.DISABLED
return {"message": f"Provider {provider_name} disabled"} return {"message": f"Provider {provider_name} disabled"}
elif control.action == "reset_circuit": elif control.action == "reset_circuit":
from .cascade import CircuitState, ProviderStatus from .cascade import CircuitState, ProviderStatus
provider.circuit_state = CircuitState.CLOSED provider.circuit_state = CircuitState.CLOSED
provider.circuit_opened_at = None provider.circuit_opened_at = None
provider.half_open_calls = 0 provider.half_open_calls = 0
provider.metrics.consecutive_failures = 0 provider.metrics.consecutive_failures = 0
provider.status = ProviderStatus.HEALTHY provider.status = ProviderStatus.HEALTHY
return {"message": f"Circuit breaker reset for {provider_name}"} return {"message": f"Circuit breaker reset for {provider_name}"}
else: else:
raise HTTPException(status_code=400, detail=f"Unknown action: {control.action}") raise HTTPException(status_code=400, detail=f"Unknown action: {control.action}")
@@ -142,28 +147,35 @@ async def run_health_check(
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Run health checks on all providers.""" """Run health checks on all providers."""
results = [] results = []
for provider in cascade.providers: for provider in cascade.providers:
# Quick ping to check availability # Quick ping to check availability
is_healthy = cascade._check_provider_available(provider) is_healthy = cascade._check_provider_available(provider)
from .cascade import ProviderStatus from .cascade import ProviderStatus
if is_healthy: if is_healthy:
if provider.status == ProviderStatus.UNHEALTHY: if provider.status == ProviderStatus.UNHEALTHY:
# Reset circuit if it was open but now healthy # Reset circuit if it was open but now healthy
provider.circuit_state = provider.circuit_state.__class__.CLOSED provider.circuit_state = provider.circuit_state.__class__.CLOSED
provider.circuit_opened_at = None provider.circuit_opened_at = None
provider.status = ProviderStatus.HEALTHY if provider.metrics.error_rate < 0.1 else ProviderStatus.DEGRADED provider.status = (
ProviderStatus.HEALTHY
if provider.metrics.error_rate < 0.1
else ProviderStatus.DEGRADED
)
else: else:
provider.status = ProviderStatus.UNHEALTHY provider.status = ProviderStatus.UNHEALTHY
results.append({ results.append(
"name": provider.name, {
"type": provider.type, "name": provider.name,
"healthy": is_healthy, "type": provider.type,
"status": provider.status.value, "healthy": is_healthy,
}) "status": provider.status.value,
}
)
return { return {
"checked_at": asyncio.get_event_loop().time(), "checked_at": asyncio.get_event_loop().time(),
"providers": results, "providers": results,
@@ -177,7 +189,7 @@ async def get_config(
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Get router configuration (without secrets).""" """Get router configuration (without secrets)."""
cfg = cascade.config cfg = cascade.config
return { return {
"timeout_seconds": cfg.timeout_seconds, "timeout_seconds": cfg.timeout_seconds,
"max_retries_per_provider": cfg.max_retries_per_provider, "max_retries_per_provider": cfg.max_retries_per_provider,

View File

@@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
class ProviderStatus(Enum): class ProviderStatus(Enum):
"""Health status of a provider.""" """Health status of a provider."""
HEALTHY = "healthy" HEALTHY = "healthy"
DEGRADED = "degraded" # Working but slow or occasional errors DEGRADED = "degraded" # Working but slow or occasional errors
UNHEALTHY = "unhealthy" # Circuit breaker open UNHEALTHY = "unhealthy" # Circuit breaker open
@@ -41,22 +42,25 @@ class ProviderStatus(Enum):
class CircuitState(Enum): class CircuitState(Enum):
"""Circuit breaker state.""" """Circuit breaker state."""
CLOSED = "closed" # Normal operation
OPEN = "open" # Failing, rejecting requests CLOSED = "closed" # Normal operation
OPEN = "open" # Failing, rejecting requests
HALF_OPEN = "half_open" # Testing if recovered HALF_OPEN = "half_open" # Testing if recovered
class ContentType(Enum): class ContentType(Enum):
"""Type of content in the request.""" """Type of content in the request."""
TEXT = "text" TEXT = "text"
VISION = "vision" # Contains images VISION = "vision" # Contains images
AUDIO = "audio" # Contains audio AUDIO = "audio" # Contains audio
MULTIMODAL = "multimodal" # Multiple content types MULTIMODAL = "multimodal" # Multiple content types
@dataclass @dataclass
class ProviderMetrics: class ProviderMetrics:
"""Metrics for a single provider.""" """Metrics for a single provider."""
total_requests: int = 0 total_requests: int = 0
successful_requests: int = 0 successful_requests: int = 0
failed_requests: int = 0 failed_requests: int = 0
@@ -64,13 +68,13 @@ class ProviderMetrics:
last_request_time: Optional[str] = None last_request_time: Optional[str] = None
last_error_time: Optional[str] = None last_error_time: Optional[str] = None
consecutive_failures: int = 0 consecutive_failures: int = 0
@property @property
def avg_latency_ms(self) -> float: def avg_latency_ms(self) -> float:
if self.total_requests == 0: if self.total_requests == 0:
return 0.0 return 0.0
return self.total_latency_ms / self.total_requests return self.total_latency_ms / self.total_requests
@property @property
def error_rate(self) -> float: def error_rate(self) -> float:
if self.total_requests == 0: if self.total_requests == 0:
@@ -81,6 +85,7 @@ class ProviderMetrics:
@dataclass @dataclass
class ModelCapability: class ModelCapability:
"""Capabilities a model supports.""" """Capabilities a model supports."""
name: str name: str
supports_vision: bool = False supports_vision: bool = False
supports_audio: bool = False supports_audio: bool = False
@@ -93,6 +98,7 @@ class ModelCapability:
@dataclass @dataclass
class Provider: class Provider:
"""LLM provider configuration and state.""" """LLM provider configuration and state."""
name: str name: str
type: str # ollama, openai, anthropic, airllm type: str # ollama, openai, anthropic, airllm
enabled: bool enabled: bool
@@ -101,14 +107,14 @@ class Provider:
api_key: Optional[str] = None api_key: Optional[str] = None
base_url: Optional[str] = None base_url: Optional[str] = None
models: list[dict] = field(default_factory=list) models: list[dict] = field(default_factory=list)
# Runtime state # Runtime state
status: ProviderStatus = ProviderStatus.HEALTHY status: ProviderStatus = ProviderStatus.HEALTHY
metrics: ProviderMetrics = field(default_factory=ProviderMetrics) metrics: ProviderMetrics = field(default_factory=ProviderMetrics)
circuit_state: CircuitState = CircuitState.CLOSED circuit_state: CircuitState = CircuitState.CLOSED
circuit_opened_at: Optional[float] = None circuit_opened_at: Optional[float] = None
half_open_calls: int = 0 half_open_calls: int = 0
def get_default_model(self) -> Optional[str]: def get_default_model(self) -> Optional[str]:
"""Get the default model for this provider.""" """Get the default model for this provider."""
for model in self.models: for model in self.models:
@@ -117,7 +123,7 @@ class Provider:
if self.models: if self.models:
return self.models[0]["name"] return self.models[0]["name"]
return None return None
def get_model_with_capability(self, capability: str) -> Optional[str]: def get_model_with_capability(self, capability: str) -> Optional[str]:
"""Get a model that supports the given capability.""" """Get a model that supports the given capability."""
for model in self.models: for model in self.models:
@@ -126,7 +132,7 @@ class Provider:
return model["name"] return model["name"]
# Fall back to default # Fall back to default
return self.get_default_model() return self.get_default_model()
def model_has_capability(self, model_name: str, capability: str) -> bool: def model_has_capability(self, model_name: str, capability: str) -> bool:
"""Check if a specific model has a capability.""" """Check if a specific model has a capability."""
for model in self.models: for model in self.models:
@@ -139,6 +145,7 @@ class Provider:
@dataclass @dataclass
class RouterConfig: class RouterConfig:
"""Cascade router configuration.""" """Cascade router configuration."""
timeout_seconds: int = 30 timeout_seconds: int = 30
max_retries_per_provider: int = 2 max_retries_per_provider: int = 2
retry_delay_seconds: int = 1 retry_delay_seconds: int = 1
@@ -154,22 +161,22 @@ class RouterConfig:
class CascadeRouter: class CascadeRouter:
"""Routes LLM requests with automatic failover. """Routes LLM requests with automatic failover.
Now with multi-modal support: Now with multi-modal support:
- Automatically detects content type (text, vision, audio) - Automatically detects content type (text, vision, audio)
- Selects appropriate models based on capabilities - Selects appropriate models based on capabilities
- Falls back through capability-specific model chains - Falls back through capability-specific model chains
- Supports image URLs and base64 encoding - Supports image URLs and base64 encoding
Usage: Usage:
router = CascadeRouter() router = CascadeRouter()
# Text request # Text request
response = await router.complete( response = await router.complete(
messages=[{"role": "user", "content": "Hello"}], messages=[{"role": "user", "content": "Hello"}],
model="llama3.2" model="llama3.2"
) )
# Vision request (automatically detects and selects vision model) # Vision request (automatically detects and selects vision model)
response = await router.complete( response = await router.complete(
messages=[{ messages=[{
@@ -179,68 +186,75 @@ class CascadeRouter:
}], }],
model="llava:7b" model="llava:7b"
) )
# Check metrics # Check metrics
metrics = router.get_metrics() metrics = router.get_metrics()
""" """
def __init__(self, config_path: Optional[Path] = None) -> None: def __init__(self, config_path: Optional[Path] = None) -> None:
self.config_path = config_path or Path("config/providers.yaml") self.config_path = config_path or Path("config/providers.yaml")
self.providers: list[Provider] = [] self.providers: list[Provider] = []
self.config: RouterConfig = RouterConfig() self.config: RouterConfig = RouterConfig()
self._load_config() self._load_config()
# Initialize multi-modal manager if available # Initialize multi-modal manager if available
self._mm_manager: Optional[Any] = None self._mm_manager: Optional[Any] = None
try: try:
from infrastructure.models.multimodal import get_multimodal_manager from infrastructure.models.multimodal import get_multimodal_manager
self._mm_manager = get_multimodal_manager() self._mm_manager = get_multimodal_manager()
except Exception as exc: except Exception as exc:
logger.debug("Multi-modal manager not available: %s", exc) logger.debug("Multi-modal manager not available: %s", exc)
logger.info("CascadeRouter initialized with %d providers", len(self.providers)) logger.info("CascadeRouter initialized with %d providers", len(self.providers))
def _load_config(self) -> None: def _load_config(self) -> None:
"""Load configuration from YAML.""" """Load configuration from YAML."""
if not self.config_path.exists(): if not self.config_path.exists():
logger.warning("Config not found: %s, using defaults", self.config_path) logger.warning("Config not found: %s, using defaults", self.config_path)
return return
try: try:
if yaml is None: if yaml is None:
raise RuntimeError("PyYAML not installed") raise RuntimeError("PyYAML not installed")
content = self.config_path.read_text() content = self.config_path.read_text()
# Expand environment variables # Expand environment variables
content = self._expand_env_vars(content) content = self._expand_env_vars(content)
data = yaml.safe_load(content) data = yaml.safe_load(content)
# Load cascade settings # Load cascade settings
cascade = data.get("cascade", {}) cascade = data.get("cascade", {})
# Load fallback chains # Load fallback chains
fallback_chains = data.get("fallback_chains", {}) fallback_chains = data.get("fallback_chains", {})
# Load multi-modal settings # Load multi-modal settings
multimodal = data.get("multimodal", {}) multimodal = data.get("multimodal", {})
self.config = RouterConfig( self.config = RouterConfig(
timeout_seconds=cascade.get("timeout_seconds", 30), timeout_seconds=cascade.get("timeout_seconds", 30),
max_retries_per_provider=cascade.get("max_retries_per_provider", 2), max_retries_per_provider=cascade.get("max_retries_per_provider", 2),
retry_delay_seconds=cascade.get("retry_delay_seconds", 1), retry_delay_seconds=cascade.get("retry_delay_seconds", 1),
circuit_breaker_failure_threshold=cascade.get("circuit_breaker", {}).get("failure_threshold", 5), circuit_breaker_failure_threshold=cascade.get("circuit_breaker", {}).get(
circuit_breaker_recovery_timeout=cascade.get("circuit_breaker", {}).get("recovery_timeout", 60), "failure_threshold", 5
circuit_breaker_half_open_max_calls=cascade.get("circuit_breaker", {}).get("half_open_max_calls", 2), ),
circuit_breaker_recovery_timeout=cascade.get("circuit_breaker", {}).get(
"recovery_timeout", 60
),
circuit_breaker_half_open_max_calls=cascade.get("circuit_breaker", {}).get(
"half_open_max_calls", 2
),
auto_pull_models=multimodal.get("auto_pull", True), auto_pull_models=multimodal.get("auto_pull", True),
fallback_chains=fallback_chains, fallback_chains=fallback_chains,
) )
# Load providers # Load providers
for p_data in data.get("providers", []): for p_data in data.get("providers", []):
# Skip disabled providers # Skip disabled providers
if not p_data.get("enabled", False): if not p_data.get("enabled", False):
continue continue
provider = Provider( provider = Provider(
name=p_data["name"], name=p_data["name"],
type=p_data["type"], type=p_data["type"],
@@ -251,30 +265,34 @@ class CascadeRouter:
base_url=p_data.get("base_url"), base_url=p_data.get("base_url"),
models=p_data.get("models", []), models=p_data.get("models", []),
) )
# Check if provider is actually available # Check if provider is actually available
if self._check_provider_available(provider): if self._check_provider_available(provider):
self.providers.append(provider) self.providers.append(provider)
else: else:
logger.warning("Provider %s not available, skipping", provider.name) logger.warning("Provider %s not available, skipping", provider.name)
# Sort by priority # Sort by priority
self.providers.sort(key=lambda p: p.priority) self.providers.sort(key=lambda p: p.priority)
except Exception as exc: except Exception as exc:
logger.error("Failed to load config: %s", exc) logger.error("Failed to load config: %s", exc)
def _expand_env_vars(self, content: str) -> str: def _expand_env_vars(self, content: str) -> str:
"""Expand ${VAR} syntax in YAML content.""" """Expand ${VAR} syntax in YAML content.
Uses os.environ directly (not settings) because this is a generic
YAML config loader that must expand arbitrary variable references.
"""
import os import os
import re import re
def replace_var(match): def replace_var(match: "re.Match[str]") -> str:
var_name = match.group(1) var_name = match.group(1)
return os.environ.get(var_name, match.group(0)) return os.environ.get(var_name, match.group(0))
return re.sub(r"\$\{(\w+)\}", replace_var, content) return re.sub(r"\$\{(\w+)\}", replace_var, content)
def _check_provider_available(self, provider: Provider) -> bool: def _check_provider_available(self, provider: Provider) -> bool:
"""Check if a provider is actually available.""" """Check if a provider is actually available."""
if provider.type == "ollama": if provider.type == "ollama":
@@ -288,48 +306,49 @@ class CascadeRouter:
return response.status_code == 200 return response.status_code == 200
except Exception: except Exception:
return False return False
elif provider.type == "airllm": elif provider.type == "airllm":
# Check if airllm is installed # Check if airllm is installed
try: try:
import airllm import airllm
return True return True
except ImportError: except ImportError:
return False return False
elif provider.type in ("openai", "anthropic", "grok"): elif provider.type in ("openai", "anthropic", "grok"):
# Check if API key is set # Check if API key is set
return provider.api_key is not None and provider.api_key != "" return provider.api_key is not None and provider.api_key != ""
return True return True
def _detect_content_type(self, messages: list[dict]) -> ContentType: def _detect_content_type(self, messages: list[dict]) -> ContentType:
"""Detect the type of content in the messages. """Detect the type of content in the messages.
Checks for images, audio, etc. in the message content. Checks for images, audio, etc. in the message content.
""" """
has_image = False has_image = False
has_audio = False has_audio = False
for msg in messages: for msg in messages:
content = msg.get("content", "") content = msg.get("content", "")
# Check for image URLs/paths # Check for image URLs/paths
if msg.get("images"): if msg.get("images"):
has_image = True has_image = True
# Check for image URLs in content # Check for image URLs in content
if isinstance(content, str): if isinstance(content, str):
image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp') image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
if any(ext in content.lower() for ext in image_extensions): if any(ext in content.lower() for ext in image_extensions):
has_image = True has_image = True
if content.startswith("data:image/"): if content.startswith("data:image/"):
has_image = True has_image = True
# Check for audio # Check for audio
if msg.get("audio"): if msg.get("audio"):
has_audio = True has_audio = True
# Check for multimodal content structure # Check for multimodal content structure
if isinstance(content, list): if isinstance(content, list):
for item in content: for item in content:
@@ -338,7 +357,7 @@ class CascadeRouter:
has_image = True has_image = True
elif item.get("type") == "audio": elif item.get("type") == "audio":
has_audio = True has_audio = True
if has_image and has_audio: if has_image and has_audio:
return ContentType.MULTIMODAL return ContentType.MULTIMODAL
elif has_image: elif has_image:
@@ -346,12 +365,9 @@ class CascadeRouter:
elif has_audio: elif has_audio:
return ContentType.AUDIO return ContentType.AUDIO
return ContentType.TEXT return ContentType.TEXT
def _get_fallback_model( def _get_fallback_model(
self, self, provider: Provider, original_model: str, content_type: ContentType
provider: Provider,
original_model: str,
content_type: ContentType
) -> Optional[str]: ) -> Optional[str]:
"""Get a fallback model for the given content type.""" """Get a fallback model for the given content type."""
# Map content type to capability # Map content type to capability
@@ -360,24 +376,24 @@ class CascadeRouter:
ContentType.AUDIO: "audio", ContentType.AUDIO: "audio",
ContentType.MULTIMODAL: "vision", # Vision models often do both ContentType.MULTIMODAL: "vision", # Vision models often do both
} }
capability = capability_map.get(content_type) capability = capability_map.get(content_type)
if not capability: if not capability:
return None return None
# Check provider's models for capability # Check provider's models for capability
fallback_model = provider.get_model_with_capability(capability) fallback_model = provider.get_model_with_capability(capability)
if fallback_model and fallback_model != original_model: if fallback_model and fallback_model != original_model:
return fallback_model return fallback_model
# Use fallback chains from config # Use fallback chains from config
fallback_chain = self.config.fallback_chains.get(capability, []) fallback_chain = self.config.fallback_chains.get(capability, [])
for model_name in fallback_chain: for model_name in fallback_chain:
if provider.model_has_capability(model_name, capability): if provider.model_has_capability(model_name, capability):
return model_name return model_name
return None return None
async def complete( async def complete(
self, self,
messages: list[dict], messages: list[dict],
@@ -386,21 +402,21 @@ class CascadeRouter:
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
) -> dict: ) -> dict:
"""Complete a chat conversation with automatic failover. """Complete a chat conversation with automatic failover.
Multi-modal support: Multi-modal support:
- Automatically detects if messages contain images - Automatically detects if messages contain images
- Falls back to vision-capable models when needed - Falls back to vision-capable models when needed
- Supports image URLs, paths, and base64 encoding - Supports image URLs, paths, and base64 encoding
Args: Args:
messages: List of message dicts with role and content messages: List of message dicts with role and content
model: Preferred model (tries this first, then provider defaults) model: Preferred model (tries this first, then provider defaults)
temperature: Sampling temperature temperature: Sampling temperature
max_tokens: Maximum tokens to generate max_tokens: Maximum tokens to generate
Returns: Returns:
Dict with content, provider_used, and metrics Dict with content, provider_used, and metrics
Raises: Raises:
RuntimeError: If all providers fail RuntimeError: If all providers fail
""" """
@@ -408,15 +424,15 @@ class CascadeRouter:
content_type = self._detect_content_type(messages) content_type = self._detect_content_type(messages)
if content_type != ContentType.TEXT: if content_type != ContentType.TEXT:
logger.debug("Detected %s content, selecting appropriate model", content_type.value) logger.debug("Detected %s content, selecting appropriate model", content_type.value)
errors = [] errors = []
for provider in self.providers: for provider in self.providers:
# Skip disabled providers # Skip disabled providers
if not provider.enabled: if not provider.enabled:
logger.debug("Skipping %s (disabled)", provider.name) logger.debug("Skipping %s (disabled)", provider.name)
continue continue
# Skip unhealthy providers (circuit breaker) # Skip unhealthy providers (circuit breaker)
if provider.status == ProviderStatus.UNHEALTHY: if provider.status == ProviderStatus.UNHEALTHY:
# Check if circuit breaker can close # Check if circuit breaker can close
@@ -427,16 +443,16 @@ class CascadeRouter:
else: else:
logger.debug("Skipping %s (circuit open)", provider.name) logger.debug("Skipping %s (circuit open)", provider.name)
continue continue
# Determine which model to use # Determine which model to use
selected_model = model or provider.get_default_model() selected_model = model or provider.get_default_model()
is_fallback_model = False is_fallback_model = False
# For non-text content, check if model supports it # For non-text content, check if model supports it
if content_type != ContentType.TEXT and selected_model: if content_type != ContentType.TEXT and selected_model:
if provider.type == "ollama" and self._mm_manager: if provider.type == "ollama" and self._mm_manager:
from infrastructure.models.multimodal import ModelCapability from infrastructure.models.multimodal import ModelCapability
# Check if selected model supports the required capability # Check if selected model supports the required capability
if content_type == ContentType.VISION: if content_type == ContentType.VISION:
supports = self._mm_manager.model_supports( supports = self._mm_manager.model_supports(
@@ -450,16 +466,17 @@ class CascadeRouter:
if fallback: if fallback:
logger.info( logger.info(
"Model %s doesn't support vision, falling back to %s", "Model %s doesn't support vision, falling back to %s",
selected_model, fallback selected_model,
fallback,
) )
selected_model = fallback selected_model = fallback
is_fallback_model = True is_fallback_model = True
else: else:
logger.warning( logger.warning(
"No vision-capable model found on %s, trying anyway", "No vision-capable model found on %s, trying anyway",
provider.name provider.name,
) )
# Try this provider # Try this provider
for attempt in range(self.config.max_retries_per_provider): for attempt in range(self.config.max_retries_per_provider):
try: try:
@@ -471,34 +488,35 @@ class CascadeRouter:
max_tokens=max_tokens, max_tokens=max_tokens,
content_type=content_type, content_type=content_type,
) )
# Success! Update metrics and return # Success! Update metrics and return
self._record_success(provider, result.get("latency_ms", 0)) self._record_success(provider, result.get("latency_ms", 0))
return { return {
"content": result["content"], "content": result["content"],
"provider": provider.name, "provider": provider.name,
"model": result.get("model", selected_model or provider.get_default_model()), "model": result.get(
"model", selected_model or provider.get_default_model()
),
"latency_ms": result.get("latency_ms", 0), "latency_ms": result.get("latency_ms", 0),
"is_fallback_model": is_fallback_model, "is_fallback_model": is_fallback_model,
} }
except Exception as exc: except Exception as exc:
error_msg = str(exc) error_msg = str(exc)
logger.warning( logger.warning(
"Provider %s attempt %d failed: %s", "Provider %s attempt %d failed: %s", provider.name, attempt + 1, error_msg
provider.name, attempt + 1, error_msg
) )
errors.append(f"{provider.name}: {error_msg}") errors.append(f"{provider.name}: {error_msg}")
if attempt < self.config.max_retries_per_provider - 1: if attempt < self.config.max_retries_per_provider - 1:
await asyncio.sleep(self.config.retry_delay_seconds) await asyncio.sleep(self.config.retry_delay_seconds)
# All retries failed for this provider # All retries failed for this provider
self._record_failure(provider) self._record_failure(provider)
# All providers failed # All providers failed
raise RuntimeError(f"All providers failed: {'; '.join(errors)}") raise RuntimeError(f"All providers failed: {'; '.join(errors)}")
async def _try_provider( async def _try_provider(
self, self,
provider: Provider, provider: Provider,
@@ -510,7 +528,7 @@ class CascadeRouter:
) -> dict: ) -> dict:
"""Try a single provider request.""" """Try a single provider request."""
start_time = time.time() start_time = time.time()
if provider.type == "ollama": if provider.type == "ollama":
result = await self._call_ollama( result = await self._call_ollama(
provider=provider, provider=provider,
@@ -545,12 +563,12 @@ class CascadeRouter:
) )
else: else:
raise ValueError(f"Unknown provider type: {provider.type}") raise ValueError(f"Unknown provider type: {provider.type}")
latency_ms = (time.time() - start_time) * 1000 latency_ms = (time.time() - start_time) * 1000
result["latency_ms"] = latency_ms result["latency_ms"] = latency_ms
return result return result
async def _call_ollama( async def _call_ollama(
self, self,
provider: Provider, provider: Provider,
@@ -561,12 +579,12 @@ class CascadeRouter:
) -> dict: ) -> dict:
"""Call Ollama API with multi-modal support.""" """Call Ollama API with multi-modal support."""
import aiohttp import aiohttp
url = f"{provider.url}/api/chat" url = f"{provider.url}/api/chat"
# Transform messages for Ollama format (including images) # Transform messages for Ollama format (including images)
transformed_messages = self._transform_messages_for_ollama(messages) transformed_messages = self._transform_messages_for_ollama(messages)
payload = { payload = {
"model": model, "model": model,
"messages": transformed_messages, "messages": transformed_messages,
@@ -575,31 +593,31 @@ class CascadeRouter:
"temperature": temperature, "temperature": temperature,
}, },
} }
timeout = aiohttp.ClientTimeout(total=self.config.timeout_seconds) timeout = aiohttp.ClientTimeout(total=self.config.timeout_seconds)
async with aiohttp.ClientSession(timeout=timeout) as session: async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.post(url, json=payload) as response: async with session.post(url, json=payload) as response:
if response.status != 200: if response.status != 200:
text = await response.text() text = await response.text()
raise RuntimeError(f"Ollama error {response.status}: {text}") raise RuntimeError(f"Ollama error {response.status}: {text}")
data = await response.json() data = await response.json()
return { return {
"content": data["message"]["content"], "content": data["message"]["content"],
"model": model, "model": model,
} }
def _transform_messages_for_ollama(self, messages: list[dict]) -> list[dict]: def _transform_messages_for_ollama(self, messages: list[dict]) -> list[dict]:
"""Transform messages to Ollama format, handling images.""" """Transform messages to Ollama format, handling images."""
transformed = [] transformed = []
for msg in messages: for msg in messages:
new_msg = { new_msg = {
"role": msg.get("role", "user"), "role": msg.get("role", "user"),
"content": msg.get("content", ""), "content": msg.get("content", ""),
} }
# Handle images # Handle images
images = msg.get("images", []) images = msg.get("images", [])
if images: if images:
@@ -620,11 +638,11 @@ class CascadeRouter:
new_msg["images"].append(img_data) new_msg["images"].append(img_data)
except Exception as exc: except Exception as exc:
logger.error("Failed to read image %s: %s", img, exc) logger.error("Failed to read image %s: %s", img, exc)
transformed.append(new_msg) transformed.append(new_msg)
return transformed return transformed
async def _call_openai( async def _call_openai(
self, self,
provider: Provider, provider: Provider,
@@ -635,13 +653,13 @@ class CascadeRouter:
) -> dict: ) -> dict:
"""Call OpenAI API.""" """Call OpenAI API."""
import openai import openai
client = openai.AsyncOpenAI( client = openai.AsyncOpenAI(
api_key=provider.api_key, api_key=provider.api_key,
base_url=provider.base_url, base_url=provider.base_url,
timeout=self.config.timeout_seconds, timeout=self.config.timeout_seconds,
) )
kwargs = { kwargs = {
"model": model, "model": model,
"messages": messages, "messages": messages,
@@ -649,14 +667,14 @@ class CascadeRouter:
} }
if max_tokens: if max_tokens:
kwargs["max_tokens"] = max_tokens kwargs["max_tokens"] = max_tokens
response = await client.chat.completions.create(**kwargs) response = await client.chat.completions.create(**kwargs)
return { return {
"content": response.choices[0].message.content, "content": response.choices[0].message.content,
"model": response.model, "model": response.model,
} }
async def _call_anthropic( async def _call_anthropic(
self, self,
provider: Provider, provider: Provider,
@@ -667,12 +685,12 @@ class CascadeRouter:
) -> dict: ) -> dict:
"""Call Anthropic API.""" """Call Anthropic API."""
import anthropic import anthropic
client = anthropic.AsyncAnthropic( client = anthropic.AsyncAnthropic(
api_key=provider.api_key, api_key=provider.api_key,
timeout=self.config.timeout_seconds, timeout=self.config.timeout_seconds,
) )
# Convert messages to Anthropic format # Convert messages to Anthropic format
system_msg = None system_msg = None
conversation = [] conversation = []
@@ -680,11 +698,13 @@ class CascadeRouter:
if msg["role"] == "system": if msg["role"] == "system":
system_msg = msg["content"] system_msg = msg["content"]
else: else:
conversation.append({ conversation.append(
"role": msg["role"], {
"content": msg["content"], "role": msg["role"],
}) "content": msg["content"],
}
)
kwargs = { kwargs = {
"model": model, "model": model,
"messages": conversation, "messages": conversation,
@@ -693,9 +713,9 @@ class CascadeRouter:
} }
if system_msg: if system_msg:
kwargs["system"] = system_msg kwargs["system"] = system_msg
response = await client.messages.create(**kwargs) response = await client.messages.create(**kwargs)
return { return {
"content": response.content[0].text, "content": response.content[0].text,
"model": response.model, "model": response.model,
@@ -733,7 +753,7 @@ class CascadeRouter:
"content": response.choices[0].message.content, "content": response.choices[0].message.content,
"model": response.model, "model": response.model,
} }
def _record_success(self, provider: Provider, latency_ms: float) -> None: def _record_success(self, provider: Provider, latency_ms: float) -> None:
"""Record a successful request.""" """Record a successful request."""
provider.metrics.total_requests += 1 provider.metrics.total_requests += 1
@@ -741,50 +761,50 @@ class CascadeRouter:
provider.metrics.total_latency_ms += latency_ms provider.metrics.total_latency_ms += latency_ms
provider.metrics.last_request_time = datetime.now(timezone.utc).isoformat() provider.metrics.last_request_time = datetime.now(timezone.utc).isoformat()
provider.metrics.consecutive_failures = 0 provider.metrics.consecutive_failures = 0
# Close circuit breaker if half-open # Close circuit breaker if half-open
if provider.circuit_state == CircuitState.HALF_OPEN: if provider.circuit_state == CircuitState.HALF_OPEN:
provider.half_open_calls += 1 provider.half_open_calls += 1
if provider.half_open_calls >= self.config.circuit_breaker_half_open_max_calls: if provider.half_open_calls >= self.config.circuit_breaker_half_open_max_calls:
self._close_circuit(provider) self._close_circuit(provider)
# Update status based on error rate # Update status based on error rate
if provider.metrics.error_rate < 0.1: if provider.metrics.error_rate < 0.1:
provider.status = ProviderStatus.HEALTHY provider.status = ProviderStatus.HEALTHY
elif provider.metrics.error_rate < 0.3: elif provider.metrics.error_rate < 0.3:
provider.status = ProviderStatus.DEGRADED provider.status = ProviderStatus.DEGRADED
def _record_failure(self, provider: Provider) -> None: def _record_failure(self, provider: Provider) -> None:
"""Record a failed request.""" """Record a failed request."""
provider.metrics.total_requests += 1 provider.metrics.total_requests += 1
provider.metrics.failed_requests += 1 provider.metrics.failed_requests += 1
provider.metrics.last_error_time = datetime.now(timezone.utc).isoformat() provider.metrics.last_error_time = datetime.now(timezone.utc).isoformat()
provider.metrics.consecutive_failures += 1 provider.metrics.consecutive_failures += 1
# Check if we should open circuit breaker # Check if we should open circuit breaker
if provider.metrics.consecutive_failures >= self.config.circuit_breaker_failure_threshold: if provider.metrics.consecutive_failures >= self.config.circuit_breaker_failure_threshold:
self._open_circuit(provider) self._open_circuit(provider)
# Update status # Update status
if provider.metrics.error_rate > 0.3: if provider.metrics.error_rate > 0.3:
provider.status = ProviderStatus.DEGRADED provider.status = ProviderStatus.DEGRADED
if provider.metrics.error_rate > 0.5: if provider.metrics.error_rate > 0.5:
provider.status = ProviderStatus.UNHEALTHY provider.status = ProviderStatus.UNHEALTHY
def _open_circuit(self, provider: Provider) -> None: def _open_circuit(self, provider: Provider) -> None:
"""Open the circuit breaker for a provider.""" """Open the circuit breaker for a provider."""
provider.circuit_state = CircuitState.OPEN provider.circuit_state = CircuitState.OPEN
provider.circuit_opened_at = time.time() provider.circuit_opened_at = time.time()
provider.status = ProviderStatus.UNHEALTHY provider.status = ProviderStatus.UNHEALTHY
logger.warning("Circuit breaker OPEN for %s", provider.name) logger.warning("Circuit breaker OPEN for %s", provider.name)
def _can_close_circuit(self, provider: Provider) -> bool: def _can_close_circuit(self, provider: Provider) -> bool:
"""Check if circuit breaker can transition to half-open.""" """Check if circuit breaker can transition to half-open."""
if provider.circuit_opened_at is None: if provider.circuit_opened_at is None:
return False return False
elapsed = time.time() - provider.circuit_opened_at elapsed = time.time() - provider.circuit_opened_at
return elapsed >= self.config.circuit_breaker_recovery_timeout return elapsed >= self.config.circuit_breaker_recovery_timeout
def _close_circuit(self, provider: Provider) -> None: def _close_circuit(self, provider: Provider) -> None:
"""Close the circuit breaker (provider healthy again).""" """Close the circuit breaker (provider healthy again)."""
provider.circuit_state = CircuitState.CLOSED provider.circuit_state = CircuitState.CLOSED
@@ -793,7 +813,7 @@ class CascadeRouter:
provider.metrics.consecutive_failures = 0 provider.metrics.consecutive_failures = 0
provider.status = ProviderStatus.HEALTHY provider.status = ProviderStatus.HEALTHY
logger.info("Circuit breaker CLOSED for %s", provider.name) logger.info("Circuit breaker CLOSED for %s", provider.name)
def get_metrics(self) -> dict: def get_metrics(self) -> dict:
"""Get metrics for all providers.""" """Get metrics for all providers."""
return { return {
@@ -814,16 +834,20 @@ class CascadeRouter:
for p in self.providers for p in self.providers
] ]
} }
def get_status(self) -> dict: def get_status(self) -> dict:
"""Get current router status.""" """Get current router status."""
healthy = sum(1 for p in self.providers if p.status == ProviderStatus.HEALTHY) healthy = sum(1 for p in self.providers if p.status == ProviderStatus.HEALTHY)
return { return {
"total_providers": len(self.providers), "total_providers": len(self.providers),
"healthy_providers": healthy, "healthy_providers": healthy,
"degraded_providers": sum(1 for p in self.providers if p.status == ProviderStatus.DEGRADED), "degraded_providers": sum(
"unhealthy_providers": sum(1 for p in self.providers if p.status == ProviderStatus.UNHEALTHY), 1 for p in self.providers if p.status == ProviderStatus.DEGRADED
),
"unhealthy_providers": sum(
1 for p in self.providers if p.status == ProviderStatus.UNHEALTHY
),
"providers": [ "providers": [
{ {
"name": p.name, "name": p.name,
@@ -835,7 +859,7 @@ class CascadeRouter:
for p in self.providers for p in self.providers
], ],
} }
async def generate_with_image( async def generate_with_image(
self, self,
prompt: str, prompt: str,
@@ -844,21 +868,23 @@ class CascadeRouter:
temperature: float = 0.7, temperature: float = 0.7,
) -> dict: ) -> dict:
"""Convenience method for vision requests. """Convenience method for vision requests.
Args: Args:
prompt: Text prompt about the image prompt: Text prompt about the image
image_path: Path to image file image_path: Path to image file
model: Vision-capable model (auto-selected if not provided) model: Vision-capable model (auto-selected if not provided)
temperature: Sampling temperature temperature: Sampling temperature
Returns: Returns:
Response dict with content and metadata Response dict with content and metadata
""" """
messages = [{ messages = [
"role": "user", {
"content": prompt, "role": "user",
"images": [image_path], "content": prompt,
}] "images": [image_path],
}
]
return await self.complete( return await self.complete(
messages=messages, messages=messages,
model=model, model=model,

View File

@@ -22,6 +22,7 @@ logger = logging.getLogger(__name__)
@dataclass @dataclass
class WSEvent: class WSEvent:
"""A WebSocket event to broadcast to connected clients.""" """A WebSocket event to broadcast to connected clients."""
event: str event: str
data: dict data: dict
timestamp: str timestamp: str
@@ -93,28 +94,42 @@ class WebSocketManager:
await self.broadcast("agent_left", {"agent_id": agent_id, "name": name}) await self.broadcast("agent_left", {"agent_id": agent_id, "name": name})
async def broadcast_task_posted(self, task_id: str, description: str) -> None: async def broadcast_task_posted(self, task_id: str, description: str) -> None:
await self.broadcast("task_posted", { await self.broadcast(
"task_id": task_id, "description": description, "task_posted",
}) {
"task_id": task_id,
"description": description,
},
)
async def broadcast_bid_submitted( async def broadcast_bid_submitted(self, task_id: str, agent_id: str, bid_sats: int) -> None:
self, task_id: str, agent_id: str, bid_sats: int await self.broadcast(
) -> None: "bid_submitted",
await self.broadcast("bid_submitted", { {
"task_id": task_id, "agent_id": agent_id, "bid_sats": bid_sats, "task_id": task_id,
}) "agent_id": agent_id,
"bid_sats": bid_sats,
},
)
async def broadcast_task_assigned(self, task_id: str, agent_id: str) -> None: async def broadcast_task_assigned(self, task_id: str, agent_id: str) -> None:
await self.broadcast("task_assigned", { await self.broadcast(
"task_id": task_id, "agent_id": agent_id, "task_assigned",
}) {
"task_id": task_id,
"agent_id": agent_id,
},
)
async def broadcast_task_completed( async def broadcast_task_completed(self, task_id: str, agent_id: str, result: str) -> None:
self, task_id: str, agent_id: str, result: str await self.broadcast(
) -> None: "task_completed",
await self.broadcast("task_completed", { {
"task_id": task_id, "agent_id": agent_id, "result": result[:200], "task_id": task_id,
}) "agent_id": agent_id,
"result": result[:200],
},
)
@property @property
def connection_count(self) -> int: def connection_count(self) -> int:
@@ -122,28 +137,28 @@ class WebSocketManager:
async def broadcast_json(self, data: dict) -> int: async def broadcast_json(self, data: dict) -> int:
"""Broadcast raw JSON data to all connected clients. """Broadcast raw JSON data to all connected clients.
Args: Args:
data: Dictionary to send as JSON data: Dictionary to send as JSON
Returns: Returns:
Number of clients notified Number of clients notified
""" """
message = json.dumps(data) message = json.dumps(data)
disconnected = [] disconnected = []
count = 0 count = 0
for ws in self._connections: for ws in self._connections:
try: try:
await ws.send_text(message) await ws.send_text(message)
count += 1 count += 1
except Exception: except Exception:
disconnected.append(ws) disconnected.append(ws)
# Clean up dead connections # Clean up dead connections
for ws in disconnected: for ws in disconnected:
self.disconnect(ws) self.disconnect(ws)
return count return count
@property @property

View File

@@ -21,6 +21,7 @@ from typing import Any, Optional
class PlatformState(Enum): class PlatformState(Enum):
"""Lifecycle state of a chat platform connection.""" """Lifecycle state of a chat platform connection."""
DISCONNECTED = auto() DISCONNECTED = auto()
CONNECTING = auto() CONNECTING = auto()
CONNECTED = auto() CONNECTED = auto()
@@ -30,13 +31,12 @@ class PlatformState(Enum):
@dataclass @dataclass
class ChatMessage: class ChatMessage:
"""Vendor-agnostic representation of a chat message.""" """Vendor-agnostic representation of a chat message."""
content: str content: str
author: str author: str
channel_id: str channel_id: str
platform: str platform: str
timestamp: str = field( timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
default_factory=lambda: datetime.now(timezone.utc).isoformat()
)
message_id: Optional[str] = None message_id: Optional[str] = None
thread_id: Optional[str] = None thread_id: Optional[str] = None
attachments: list[str] = field(default_factory=list) attachments: list[str] = field(default_factory=list)
@@ -46,13 +46,12 @@ class ChatMessage:
@dataclass @dataclass
class ChatThread: class ChatThread:
"""Vendor-agnostic representation of a conversation thread.""" """Vendor-agnostic representation of a conversation thread."""
thread_id: str thread_id: str
title: str title: str
channel_id: str channel_id: str
platform: str platform: str
created_at: str = field( created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
default_factory=lambda: datetime.now(timezone.utc).isoformat()
)
archived: bool = False archived: bool = False
message_count: int = 0 message_count: int = 0
metadata: dict[str, Any] = field(default_factory=dict) metadata: dict[str, Any] = field(default_factory=dict)
@@ -61,6 +60,7 @@ class ChatThread:
@dataclass @dataclass
class InviteInfo: class InviteInfo:
"""Parsed invite extracted from an image or text.""" """Parsed invite extracted from an image or text."""
url: str url: str
code: str code: str
platform: str platform: str
@@ -71,6 +71,7 @@ class InviteInfo:
@dataclass @dataclass
class PlatformStatus: class PlatformStatus:
"""Current status of a chat platform connection.""" """Current status of a chat platform connection."""
platform: str platform: str
state: PlatformState state: PlatformState
token_set: bool token_set: bool

View File

@@ -115,7 +115,9 @@ class InviteParser:
"""Strategy 2: Use Ollama vision model for local OCR.""" """Strategy 2: Use Ollama vision model for local OCR."""
try: try:
import base64 import base64
import httpx import httpx
from config import settings from config import settings
except ImportError: except ImportError:
logger.debug("httpx not available for Ollama vision.") logger.debug("httpx not available for Ollama vision.")

View File

@@ -90,10 +90,7 @@ class DiscordVendor(ChatPlatform):
try: try:
import discord import discord
except ImportError: except ImportError:
logger.error( logger.error("discord.py is not installed. " 'Run: pip install ".[discord]"')
"discord.py is not installed. "
'Run: pip install ".[discord]"'
)
return False return False
try: try:
@@ -267,6 +264,7 @@ class DiscordVendor(ChatPlatform):
try: try:
from config import settings from config import settings
return settings.discord_token or None return settings.discord_token or None
except Exception: except Exception:
return None return None
@@ -363,9 +361,7 @@ class DiscordVendor(ChatPlatform):
# Show typing indicator while the agent processes # Show typing indicator while the agent processes
async with target.typing(): async with target.typing():
run = await asyncio.wait_for( run = await asyncio.wait_for(
asyncio.to_thread( asyncio.to_thread(agent.run, content, stream=False, session_id=session_id),
agent.run, content, stream=False, session_id=session_id
),
timeout=300, timeout=300,
) )
response = run.content if hasattr(run, "content") else str(run) response = run.content if hasattr(run, "content") else str(run)
@@ -374,7 +370,9 @@ class DiscordVendor(ChatPlatform):
response = "Sorry, that took too long. Please try a simpler request." response = "Sorry, that took too long. Please try a simpler request."
except Exception as exc: except Exception as exc:
logger.error("Discord: agent.run() failed: %s", exc) logger.error("Discord: agent.run() failed: %s", exc)
response = "I'm having trouble reaching my language model right now. Please try again shortly." response = (
"I'm having trouble reaching my language model right now. Please try again shortly."
)
# Strip hallucinated tool-call JSON and chain-of-thought narration # Strip hallucinated tool-call JSON and chain-of-thought narration
from timmy.session import _clean_response from timmy.session import _clean_response
@@ -408,6 +406,7 @@ class DiscordVendor(ChatPlatform):
# Create a thread from this message # Create a thread from this message
from config import settings from config import settings
thread_name = f"{settings.agent_name} | {message.author.display_name}" thread_name = f"{settings.agent_name} | {message.author.display_name}"
thread = await message.create_thread( thread = await message.create_thread(
name=thread_name[:100], name=thread_name[:100],

View File

@@ -7,7 +7,6 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
# ── Inbound: Paperclip → Timmy ────────────────────────────────────────────── # ── Inbound: Paperclip → Timmy ──────────────────────────────────────────────

View File

@@ -20,7 +20,8 @@ import logging
from typing import Any, Callable, Coroutine, Dict, List, Optional, Protocol, runtime_checkable from typing import Any, Callable, Coroutine, Dict, List, Optional, Protocol, runtime_checkable
from config import settings from config import settings
from integrations.paperclip.bridge import PaperclipBridge, bridge as default_bridge from integrations.paperclip.bridge import PaperclipBridge
from integrations.paperclip.bridge import bridge as default_bridge
from integrations.paperclip.models import PaperclipIssue from integrations.paperclip.models import PaperclipIssue
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -30,9 +31,8 @@ logger = logging.getLogger(__name__)
class Orchestrator(Protocol): class Orchestrator(Protocol):
"""Anything with an ``execute_task`` matching Timmy's orchestrator.""" """Anything with an ``execute_task`` matching Timmy's orchestrator."""
async def execute_task( async def execute_task(self, task_id: str, description: str, context: dict) -> Any:
self, task_id: str, description: str, context: dict ...
) -> Any: ...
def _wrap_orchestrator(orch: Orchestrator) -> Callable: def _wrap_orchestrator(orch: Orchestrator) -> Callable:
@@ -125,7 +125,9 @@ class TaskRunner:
# Mark the issue as done # Mark the issue as done
return await self.bridge.close_issue(issue.id, comment=None) return await self.bridge.close_issue(issue.id, comment=None)
async def create_follow_up(self, original: PaperclipIssue, result: str) -> Optional[PaperclipIssue]: async def create_follow_up(
self, original: PaperclipIssue, result: str
) -> Optional[PaperclipIssue]:
"""Create a recursive follow-up task for Timmy. """Create a recursive follow-up task for Timmy.
Timmy muses about task automation and writes a follow-up issue Timmy muses about task automation and writes a follow-up issue

View File

@@ -22,6 +22,7 @@ logger = logging.getLogger(__name__)
@dataclass @dataclass
class ShortcutAction: class ShortcutAction:
"""Describes a Siri Shortcut action for the setup guide.""" """Describes a Siri Shortcut action for the setup guide."""
name: str name: str
endpoint: str endpoint: str
method: str method: str

View File

@@ -54,6 +54,7 @@ class TelegramBot:
return from_file return from_file
try: try:
from config import settings from config import settings
return settings.telegram_token or None return settings.telegram_token or None
except Exception: except Exception:
return None return None
@@ -94,10 +95,7 @@ class TelegramBot:
filters, filters,
) )
except ImportError: except ImportError:
logger.error( logger.error("python-telegram-bot is not installed. " 'Run: pip install ".[telegram]"')
"python-telegram-bot is not installed. "
'Run: pip install ".[telegram]"'
)
return False return False
try: try:
@@ -149,6 +147,7 @@ class TelegramBot:
user_text = update.message.text user_text = update.message.text
try: try:
from timmy.agent import create_timmy from timmy.agent import create_timmy
agent = create_timmy() agent = create_timmy()
run = await asyncio.to_thread(agent.run, user_text, stream=False) run = await asyncio.to_thread(agent.run, user_text, stream=False)
response = run.content if hasattr(run, "content") else str(run) response = run.content if hasattr(run, "content") else str(run)

View File

@@ -15,8 +15,8 @@ Intents:
- unknown: Unrecognized intent - unknown: Unrecognized intent
""" """
import re
import logging import logging
import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
@@ -35,47 +35,68 @@ class Intent:
_PATTERNS: list[tuple[str, re.Pattern, float]] = [ _PATTERNS: list[tuple[str, re.Pattern, float]] = [
# Status queries # Status queries
("status", re.compile( (
r"\b(status|health|how are you|are you (running|online|alive)|check)\b", "status",
re.IGNORECASE, re.compile(
), 0.9), r"\b(status|health|how are you|are you (running|online|alive)|check)\b",
re.IGNORECASE,
),
0.9,
),
# Swarm commands # Swarm commands
("swarm", re.compile( (
r"\b(swarm|spawn|agents?|sub-?agents?|workers?)\b", "swarm",
re.IGNORECASE, re.compile(
), 0.85), r"\b(swarm|spawn|agents?|sub-?agents?|workers?)\b",
re.IGNORECASE,
),
0.85,
),
# Task commands # Task commands
("task", re.compile( (
r"\b(task|assign|create task|new task|post task|bid)\b", "task",
re.IGNORECASE, re.compile(
), 0.85), r"\b(task|assign|create task|new task|post task|bid)\b",
re.IGNORECASE,
),
0.85,
),
# Help # Help
("help", re.compile( (
r"\b(help|commands?|what can you do|capabilities)\b", "help",
re.IGNORECASE, re.compile(
), 0.9), r"\b(help|commands?|what can you do|capabilities)\b",
re.IGNORECASE,
),
0.9,
),
# Voice settings # Voice settings
("voice", re.compile( (
r"\b(voice|speak|volume|rate|speed|louder|quieter|faster|slower|mute|unmute)\b", "voice",
re.IGNORECASE, re.compile(
), 0.85), r"\b(voice|speak|volume|rate|speed|louder|quieter|faster|slower|mute|unmute)\b",
re.IGNORECASE,
),
0.85,
),
# Code modification / self-modify # Code modification / self-modify
("code", re.compile( (
r"\b(modify|edit|change|update|fix|refactor|implement|patch)\s+(the\s+)?(code|file|function|class|module|source)\b" "code",
r"|\bself[- ]?modify\b" re.compile(
r"|\b(update|change|edit)\s+(your|the)\s+(code|source)\b", r"\b(modify|edit|change|update|fix|refactor|implement|patch)\s+(the\s+)?(code|file|function|class|module|source)\b"
re.IGNORECASE, r"|\bself[- ]?modify\b"
), 0.9), r"|\b(update|change|edit)\s+(your|the)\s+(code|source)\b",
re.IGNORECASE,
),
0.9,
),
] ]
# Keywords for entity extraction # Keywords for entity extraction
_ENTITY_PATTERNS = { _ENTITY_PATTERNS = {
"agent_name": re.compile(r"(?:spawn|start)\s+(?:agent\s+)?(\w+)|(?:agent)\s+(\w+)", re.IGNORECASE), "agent_name": re.compile(
r"(?:spawn|start)\s+(?:agent\s+)?(\w+)|(?:agent)\s+(\w+)", re.IGNORECASE
),
"task_description": re.compile(r"(?:task|assign)[:;]?\s+(.+)", re.IGNORECASE), "task_description": re.compile(r"(?:task|assign)[:;]?\s+(.+)", re.IGNORECASE),
"number": re.compile(r"\b(\d+)\b"), "number": re.compile(r"\b(\d+)\b"),
"target_file": re.compile(r"(?:in|file|modify)\s+(?:the\s+)?([/\w._-]+\.py)", re.IGNORECASE), "target_file": re.compile(r"(?:in|file|modify)\s+(?:the\s+)?([/\w._-]+\.py)", re.IGNORECASE),

View File

@@ -17,8 +17,8 @@ from dataclasses import dataclass, field
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Optional from typing import Optional
from spark import memory as spark_memory
from spark import eidos as spark_eidos from spark import eidos as spark_eidos
from spark import memory as spark_memory
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -29,10 +29,11 @@ _MIN_EVENTS = 3
@dataclass @dataclass
class Advisory: class Advisory:
"""A single ranked recommendation.""" """A single ranked recommendation."""
category: str # agent_performance, bid_optimization, etc.
priority: float # 0.01.0 (higher = more urgent) category: str # agent_performance, bid_optimization, etc.
title: str # Short headline priority: float # 0.01.0 (higher = more urgent)
detail: str # Longer explanation title: str # Short headline
detail: str # Longer explanation
suggested_action: str # What to do about it suggested_action: str # What to do about it
subject: Optional[str] = None # agent_id or None for system-level subject: Optional[str] = None # agent_id or None for system-level
evidence_count: int = 0 # Number of supporting events evidence_count: int = 0 # Number of supporting events
@@ -47,15 +48,17 @@ def generate_advisories() -> list[Advisory]:
event_count = spark_memory.count_events() event_count = spark_memory.count_events()
if event_count < _MIN_EVENTS: if event_count < _MIN_EVENTS:
advisories.append(Advisory( advisories.append(
category="system_health", Advisory(
priority=0.3, category="system_health",
title="Insufficient data", priority=0.3,
detail=f"Only {event_count} events captured. " title="Insufficient data",
f"Spark needs at least {_MIN_EVENTS} events to generate insights.", detail=f"Only {event_count} events captured. "
suggested_action="Run more swarm tasks to build intelligence.", f"Spark needs at least {_MIN_EVENTS} events to generate insights.",
evidence_count=event_count, suggested_action="Run more swarm tasks to build intelligence.",
)) evidence_count=event_count,
)
)
return advisories return advisories
advisories.extend(_check_failure_patterns()) advisories.extend(_check_failure_patterns())
@@ -82,18 +85,20 @@ def _check_failure_patterns() -> list[Advisory]:
for aid, count in agent_failures.items(): for aid, count in agent_failures.items():
if count >= 2: if count >= 2:
results.append(Advisory( results.append(
category="failure_prevention", Advisory(
priority=min(1.0, 0.5 + count * 0.15), category="failure_prevention",
title=f"Agent {aid[:8]} has {count} failures", priority=min(1.0, 0.5 + count * 0.15),
detail=f"Agent {aid[:8]}... has failed {count} recent tasks. " title=f"Agent {aid[:8]} has {count} failures",
f"This pattern may indicate a capability mismatch or " detail=f"Agent {aid[:8]}... has failed {count} recent tasks. "
f"configuration issue.", f"This pattern may indicate a capability mismatch or "
suggested_action=f"Review task types assigned to {aid[:8]}... " f"configuration issue.",
f"and consider adjusting routing preferences.", suggested_action=f"Review task types assigned to {aid[:8]}... "
subject=aid, f"and consider adjusting routing preferences.",
evidence_count=count, subject=aid,
)) evidence_count=count,
)
)
return results return results
@@ -128,27 +133,31 @@ def _check_agent_performance() -> list[Advisory]:
rate = wins / total rate = wins / total
if rate >= 0.8 and total >= 3: if rate >= 0.8 and total >= 3:
results.append(Advisory( results.append(
category="agent_performance", Advisory(
priority=0.6, category="agent_performance",
title=f"Agent {aid[:8]} excels ({rate:.0%} success)", priority=0.6,
detail=f"Agent {aid[:8]}... has completed {wins}/{total} tasks " title=f"Agent {aid[:8]} excels ({rate:.0%} success)",
f"successfully. Consider routing more tasks to this agent.", detail=f"Agent {aid[:8]}... has completed {wins}/{total} tasks "
suggested_action="Increase task routing weight for this agent.", f"successfully. Consider routing more tasks to this agent.",
subject=aid, suggested_action="Increase task routing weight for this agent.",
evidence_count=total, subject=aid,
)) evidence_count=total,
)
)
elif rate <= 0.3 and total >= 3: elif rate <= 0.3 and total >= 3:
results.append(Advisory( results.append(
category="agent_performance", Advisory(
priority=0.75, category="agent_performance",
title=f"Agent {aid[:8]} struggling ({rate:.0%} success)", priority=0.75,
detail=f"Agent {aid[:8]}... has only succeeded on {wins}/{total} tasks. " title=f"Agent {aid[:8]} struggling ({rate:.0%} success)",
f"May need different task types or capability updates.", detail=f"Agent {aid[:8]}... has only succeeded on {wins}/{total} tasks. "
suggested_action="Review this agent's capabilities and assigned task types.", f"May need different task types or capability updates.",
subject=aid, suggested_action="Review this agent's capabilities and assigned task types.",
evidence_count=total, subject=aid,
)) evidence_count=total,
)
)
return results return results
@@ -181,27 +190,31 @@ def _check_bid_patterns() -> list[Advisory]:
spread = max_bid - min_bid spread = max_bid - min_bid
if spread > avg_bid * 1.5: if spread > avg_bid * 1.5:
results.append(Advisory( results.append(
category="bid_optimization", Advisory(
priority=0.5, category="bid_optimization",
title=f"Wide bid spread ({min_bid}{max_bid} sats)", priority=0.5,
detail=f"Bids range from {min_bid} to {max_bid} sats " title=f"Wide bid spread ({min_bid}{max_bid} sats)",
f"(avg {avg_bid:.0f}). Large spread may indicate " detail=f"Bids range from {min_bid} to {max_bid} sats "
f"inefficient auction dynamics.", f"(avg {avg_bid:.0f}). Large spread may indicate "
suggested_action="Review agent bid strategies for consistency.", f"inefficient auction dynamics.",
evidence_count=len(bid_amounts), suggested_action="Review agent bid strategies for consistency.",
)) evidence_count=len(bid_amounts),
)
)
if avg_bid > 70: if avg_bid > 70:
results.append(Advisory( results.append(
category="bid_optimization", Advisory(
priority=0.45, category="bid_optimization",
title=f"High average bid ({avg_bid:.0f} sats)", priority=0.45,
detail=f"The swarm average bid is {avg_bid:.0f} sats across " title=f"High average bid ({avg_bid:.0f} sats)",
f"{len(bid_amounts)} bids. This may be above optimal.", detail=f"The swarm average bid is {avg_bid:.0f} sats across "
suggested_action="Consider adjusting base bid rates for persona agents.", f"{len(bid_amounts)} bids. This may be above optimal.",
evidence_count=len(bid_amounts), suggested_action="Consider adjusting base bid rates for persona agents.",
)) evidence_count=len(bid_amounts),
)
)
return results return results
@@ -216,27 +229,31 @@ def _check_prediction_accuracy() -> list[Advisory]:
avg = stats["avg_accuracy"] avg = stats["avg_accuracy"]
if avg < 0.4: if avg < 0.4:
results.append(Advisory( results.append(
category="system_health", Advisory(
priority=0.65, category="system_health",
title=f"Low prediction accuracy ({avg:.0%})", priority=0.65,
detail=f"EIDOS predictions have averaged {avg:.0%} accuracy " title=f"Low prediction accuracy ({avg:.0%})",
f"over {stats['evaluated']} evaluations. The learning " detail=f"EIDOS predictions have averaged {avg:.0%} accuracy "
f"model needs more data or the swarm behaviour is changing.", f"over {stats['evaluated']} evaluations. The learning "
suggested_action="Continue running tasks; accuracy should improve " f"model needs more data or the swarm behaviour is changing.",
"as the model accumulates more training data.", suggested_action="Continue running tasks; accuracy should improve "
evidence_count=stats["evaluated"], "as the model accumulates more training data.",
)) evidence_count=stats["evaluated"],
)
)
elif avg >= 0.75: elif avg >= 0.75:
results.append(Advisory( results.append(
category="system_health", Advisory(
priority=0.3, category="system_health",
title=f"Strong prediction accuracy ({avg:.0%})", priority=0.3,
detail=f"EIDOS predictions are performing well at {avg:.0%} " title=f"Strong prediction accuracy ({avg:.0%})",
f"average accuracy over {stats['evaluated']} evaluations.", detail=f"EIDOS predictions are performing well at {avg:.0%} "
suggested_action="No action needed. Spark intelligence is learning effectively.", f"average accuracy over {stats['evaluated']} evaluations.",
evidence_count=stats["evaluated"], suggested_action="No action needed. Spark intelligence is learning effectively.",
)) evidence_count=stats["evaluated"],
)
)
return results return results
@@ -247,14 +264,16 @@ def _check_system_activity() -> list[Advisory]:
recent = spark_memory.get_events(limit=5) recent = spark_memory.get_events(limit=5)
if not recent: if not recent:
results.append(Advisory( results.append(
category="system_health", Advisory(
priority=0.4, category="system_health",
title="No swarm activity detected", priority=0.4,
detail="Spark has not captured any events. " title="No swarm activity detected",
"The swarm may be idle or Spark event capture is not active.", detail="Spark has not captured any events. "
suggested_action="Post a task to the swarm to activate the pipeline.", "The swarm may be idle or Spark event capture is not active.",
)) suggested_action="Post a task to the swarm to activate the pipeline.",
)
)
return results return results
# Check event type distribution # Check event type distribution
@@ -265,14 +284,16 @@ def _check_system_activity() -> list[Advisory]:
if "task_completed" not in type_counts and "task_failed" not in type_counts: if "task_completed" not in type_counts and "task_failed" not in type_counts:
if type_counts.get("task_posted", 0) > 3: if type_counts.get("task_posted", 0) > 3:
results.append(Advisory( results.append(
category="system_health", Advisory(
priority=0.6, category="system_health",
title="Tasks posted but none completing", priority=0.6,
detail=f"{type_counts.get('task_posted', 0)} tasks posted " title="Tasks posted but none completing",
f"but no completions or failures recorded.", detail=f"{type_counts.get('task_posted', 0)} tasks posted "
suggested_action="Check agent availability and auction configuration.", f"but no completions or failures recorded.",
evidence_count=type_counts.get("task_posted", 0), suggested_action="Check agent availability and auction configuration.",
)) evidence_count=type_counts.get("task_posted", 0),
)
)
return results return results

View File

@@ -29,12 +29,13 @@ DB_PATH = Path("data/spark.db")
@dataclass @dataclass
class Prediction: class Prediction:
"""A prediction made by the EIDOS loop.""" """A prediction made by the EIDOS loop."""
id: str id: str
task_id: str task_id: str
prediction_type: str # outcome, best_agent, bid_range prediction_type: str # outcome, best_agent, bid_range
predicted_value: str # JSON-encoded prediction predicted_value: str # JSON-encoded prediction
actual_value: Optional[str] # JSON-encoded actual (filled on evaluation) actual_value: Optional[str] # JSON-encoded actual (filled on evaluation)
accuracy: Optional[float] # 0.01.0 (filled on evaluation) accuracy: Optional[float] # 0.01.0 (filled on evaluation)
created_at: str created_at: str
evaluated_at: Optional[str] evaluated_at: Optional[str]
@@ -57,18 +58,15 @@ def _get_conn() -> sqlite3.Connection:
) )
""" """
) )
conn.execute( conn.execute("CREATE INDEX IF NOT EXISTS idx_pred_task ON spark_predictions(task_id)")
"CREATE INDEX IF NOT EXISTS idx_pred_task ON spark_predictions(task_id)" conn.execute("CREATE INDEX IF NOT EXISTS idx_pred_type ON spark_predictions(prediction_type)")
)
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_pred_type ON spark_predictions(prediction_type)"
)
conn.commit() conn.commit()
return conn return conn
# ── Prediction phase ──────────────────────────────────────────────────────── # ── Prediction phase ────────────────────────────────────────────────────────
def predict_task_outcome( def predict_task_outcome(
task_id: str, task_id: str,
task_description: str, task_description: str,
@@ -104,12 +102,8 @@ def predict_task_outcome(
if best_agent: if best_agent:
prediction["likely_winner"] = best_agent prediction["likely_winner"] = best_agent
prediction["success_probability"] = round( prediction["success_probability"] = round(min(1.0, 0.5 + best_rate * 0.4), 2)
min(1.0, 0.5 + best_rate * 0.4), 2 prediction["reasoning"] = f"agent {best_agent[:8]} has {best_rate:.0%} success rate"
)
prediction["reasoning"] = (
f"agent {best_agent[:8]} has {best_rate:.0%} success rate"
)
# Adjust bid range from history # Adjust bid range from history
all_bids = [] all_bids = []
@@ -144,6 +138,7 @@ def predict_task_outcome(
# ── Evaluation phase ──────────────────────────────────────────────────────── # ── Evaluation phase ────────────────────────────────────────────────────────
def evaluate_prediction( def evaluate_prediction(
task_id: str, task_id: str,
actual_winner: Optional[str], actual_winner: Optional[str],
@@ -242,6 +237,7 @@ def _compute_accuracy(predicted: dict, actual: dict) -> float:
# ── Query helpers ────────────────────────────────────────────────────────── # ── Query helpers ──────────────────────────────────────────────────────────
def get_predictions( def get_predictions(
task_id: Optional[str] = None, task_id: Optional[str] = None,
evaluated_only: bool = False, evaluated_only: bool = False,

View File

@@ -76,7 +76,10 @@ class SparkEngine:
return event_id return event_id
def on_bid_submitted( def on_bid_submitted(
self, task_id: str, agent_id: str, bid_sats: int, self,
task_id: str,
agent_id: str,
bid_sats: int,
) -> Optional[str]: ) -> Optional[str]:
"""Capture a bid event.""" """Capture a bid event."""
if not self._enabled: if not self._enabled:
@@ -90,12 +93,13 @@ class SparkEngine:
data=json.dumps({"bid_sats": bid_sats}), data=json.dumps({"bid_sats": bid_sats}),
) )
logger.debug("Spark: captured bid %s%s (%d sats)", logger.debug("Spark: captured bid %s%s (%d sats)", agent_id[:8], task_id[:8], bid_sats)
agent_id[:8], task_id[:8], bid_sats)
return event_id return event_id
def on_task_assigned( def on_task_assigned(
self, task_id: str, agent_id: str, self,
task_id: str,
agent_id: str,
) -> Optional[str]: ) -> Optional[str]:
"""Capture a task-assigned event.""" """Capture a task-assigned event."""
if not self._enabled: if not self._enabled:
@@ -108,8 +112,7 @@ class SparkEngine:
task_id=task_id, task_id=task_id,
) )
logger.debug("Spark: captured assignment %s%s", logger.debug("Spark: captured assignment %s%s", task_id[:8], agent_id[:8])
task_id[:8], agent_id[:8])
return event_id return event_id
def on_task_completed( def on_task_completed(
@@ -128,10 +131,12 @@ class SparkEngine:
description=f"Task completed by {agent_id[:8]}", description=f"Task completed by {agent_id[:8]}",
agent_id=agent_id, agent_id=agent_id,
task_id=task_id, task_id=task_id,
data=json.dumps({ data=json.dumps(
"result_length": len(result), {
"winning_bid": winning_bid, "result_length": len(result),
}), "winning_bid": winning_bid,
}
),
) )
# Evaluate EIDOS prediction # Evaluate EIDOS prediction
@@ -154,8 +159,7 @@ class SparkEngine:
# Consolidate memory if enough events for this agent # Consolidate memory if enough events for this agent
self._maybe_consolidate(agent_id) self._maybe_consolidate(agent_id)
logger.debug("Spark: captured completion %s by %s", logger.debug("Spark: captured completion %s by %s", task_id[:8], agent_id[:8])
task_id[:8], agent_id[:8])
return event_id return event_id
def on_task_failed( def on_task_failed(
@@ -186,8 +190,7 @@ class SparkEngine:
# Failures always worth consolidating # Failures always worth consolidating
self._maybe_consolidate(agent_id) self._maybe_consolidate(agent_id)
logger.debug("Spark: captured failure %s by %s", logger.debug("Spark: captured failure %s by %s", task_id[:8], agent_id[:8])
task_id[:8], agent_id[:8])
return event_id return event_id
def on_agent_joined(self, agent_id: str, name: str) -> Optional[str]: def on_agent_joined(self, agent_id: str, name: str) -> Optional[str]:
@@ -288,7 +291,7 @@ class SparkEngine:
memory_type="pattern", memory_type="pattern",
subject=agent_id, subject=agent_id,
content=f"Agent {agent_id[:8]} has a strong track record: " content=f"Agent {agent_id[:8]} has a strong track record: "
f"{len(completions)}/{total} tasks completed successfully.", f"{len(completions)}/{total} tasks completed successfully.",
confidence=min(0.95, 0.6 + total * 0.05), confidence=min(0.95, 0.6 + total * 0.05),
source_events=total, source_events=total,
) )
@@ -297,7 +300,7 @@ class SparkEngine:
memory_type="anomaly", memory_type="anomaly",
subject=agent_id, subject=agent_id,
content=f"Agent {agent_id[:8]} is struggling: only " content=f"Agent {agent_id[:8]} is struggling: only "
f"{len(completions)}/{total} tasks completed.", f"{len(completions)}/{total} tasks completed.",
confidence=min(0.95, 0.6 + total * 0.05), confidence=min(0.95, 0.6 + total * 0.05),
source_events=total, source_events=total,
) )
@@ -347,6 +350,7 @@ class SparkEngine:
def _create_engine() -> SparkEngine: def _create_engine() -> SparkEngine:
try: try:
from config import settings from config import settings
return SparkEngine(enabled=settings.spark_enabled) return SparkEngine(enabled=settings.spark_enabled)
except Exception: except Exception:
return SparkEngine(enabled=True) return SparkEngine(enabled=True)

View File

@@ -28,25 +28,27 @@ IMPORTANCE_HIGH = 0.8
@dataclass @dataclass
class SparkEvent: class SparkEvent:
"""A single captured swarm event.""" """A single captured swarm event."""
id: str id: str
event_type: str # task_posted, bid, assignment, completion, failure event_type: str # task_posted, bid, assignment, completion, failure
agent_id: Optional[str] agent_id: Optional[str]
task_id: Optional[str] task_id: Optional[str]
description: str description: str
data: str # JSON payload data: str # JSON payload
importance: float # 0.01.0 importance: float # 0.01.0
created_at: str created_at: str
@dataclass @dataclass
class SparkMemory: class SparkMemory:
"""A consolidated memory distilled from event patterns.""" """A consolidated memory distilled from event patterns."""
id: str id: str
memory_type: str # pattern, insight, anomaly memory_type: str # pattern, insight, anomaly
subject: str # agent_id or "system" subject: str # agent_id or "system"
content: str # Human-readable insight content: str # Human-readable insight
confidence: float # 0.01.0 confidence: float # 0.01.0
source_events: int # How many events contributed source_events: int # How many events contributed
created_at: str created_at: str
expires_at: Optional[str] expires_at: Optional[str]
@@ -83,24 +85,17 @@ def _get_conn() -> sqlite3.Connection:
) )
""" """
) )
conn.execute( conn.execute("CREATE INDEX IF NOT EXISTS idx_events_type ON spark_events(event_type)")
"CREATE INDEX IF NOT EXISTS idx_events_type ON spark_events(event_type)" conn.execute("CREATE INDEX IF NOT EXISTS idx_events_agent ON spark_events(agent_id)")
) conn.execute("CREATE INDEX IF NOT EXISTS idx_events_task ON spark_events(task_id)")
conn.execute( conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_subject ON spark_memories(subject)")
"CREATE INDEX IF NOT EXISTS idx_events_agent ON spark_events(agent_id)"
)
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_events_task ON spark_events(task_id)"
)
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_memories_subject ON spark_memories(subject)"
)
conn.commit() conn.commit()
return conn return conn
# ── Importance scoring ────────────────────────────────────────────────────── # ── Importance scoring ──────────────────────────────────────────────────────
def score_importance(event_type: str, data: dict) -> float: def score_importance(event_type: str, data: dict) -> float:
"""Compute importance score for an event (0.01.0). """Compute importance score for an event (0.01.0).
@@ -132,6 +127,7 @@ def score_importance(event_type: str, data: dict) -> float:
# ── Event recording ───────────────────────────────────────────────────────── # ── Event recording ─────────────────────────────────────────────────────────
def record_event( def record_event(
event_type: str, event_type: str,
description: str, description: str,
@@ -142,6 +138,7 @@ def record_event(
) -> str: ) -> str:
"""Record a swarm event. Returns the event id.""" """Record a swarm event. Returns the event id."""
import json import json
event_id = str(uuid.uuid4()) event_id = str(uuid.uuid4())
now = datetime.now(timezone.utc).isoformat() now = datetime.now(timezone.utc).isoformat()
@@ -224,6 +221,7 @@ def count_events(event_type: Optional[str] = None) -> int:
# ── Memory consolidation ─────────────────────────────────────────────────── # ── Memory consolidation ───────────────────────────────────────────────────
def store_memory( def store_memory(
memory_type: str, memory_type: str,
subject: str, subject: str,

View File

@@ -73,7 +73,8 @@ def _ensure_db() -> sqlite3.Connection:
DB_PATH.parent.mkdir(parents=True, exist_ok=True) DB_PATH.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(DB_PATH)) conn = sqlite3.connect(str(DB_PATH))
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
conn.execute(""" conn.execute(
"""
CREATE TABLE IF NOT EXISTS events ( CREATE TABLE IF NOT EXISTS events (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
event_type TEXT NOT NULL, event_type TEXT NOT NULL,
@@ -83,7 +84,8 @@ def _ensure_db() -> sqlite3.Connection:
data TEXT DEFAULT '{}', data TEXT DEFAULT '{}',
timestamp TEXT NOT NULL timestamp TEXT NOT NULL
) )
""") """
)
conn.commit() conn.commit()
return conn return conn
@@ -119,8 +121,15 @@ def log_event(
db.execute( db.execute(
"INSERT INTO events (id, event_type, source, task_id, agent_id, data, timestamp) " "INSERT INTO events (id, event_type, source, task_id, agent_id, data, timestamp) "
"VALUES (?, ?, ?, ?, ?, ?, ?)", "VALUES (?, ?, ?, ?, ?, ?, ?)",
(entry.id, event_type.value, source, task_id, agent_id, (
json.dumps(data or {}), entry.timestamp), entry.id,
event_type.value,
source,
task_id,
agent_id,
json.dumps(data or {}),
entry.timestamp,
),
) )
db.commit() db.commit()
finally: finally:
@@ -131,6 +140,7 @@ def log_event(
# Broadcast to WebSocket clients (non-blocking) # Broadcast to WebSocket clients (non-blocking)
try: try:
from infrastructure.events.broadcaster import event_broadcaster from infrastructure.events.broadcaster import event_broadcaster
event_broadcaster.broadcast_sync(entry) event_broadcaster.broadcast_sync(entry)
except Exception: except Exception:
pass pass
@@ -157,13 +167,15 @@ def get_task_events(task_id: str, limit: int = 50) -> list[EventLogEntry]:
et = EventType(r["event_type"]) et = EventType(r["event_type"])
except ValueError: except ValueError:
et = EventType.SYSTEM_INFO et = EventType.SYSTEM_INFO
entries.append(EventLogEntry( entries.append(
id=r["id"], EventLogEntry(
event_type=et, id=r["id"],
source=r["source"], event_type=et,
timestamp=r["timestamp"], source=r["source"],
data=json.loads(r["data"]) if r["data"] else {}, timestamp=r["timestamp"],
task_id=r["task_id"], data=json.loads(r["data"]) if r["data"] else {},
agent_id=r["agent_id"], task_id=r["task_id"],
)) agent_id=r["agent_id"],
)
)
return entries return entries

View File

@@ -29,7 +29,8 @@ def _ensure_db() -> sqlite3.Connection:
DB_PATH.parent.mkdir(parents=True, exist_ok=True) DB_PATH.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(DB_PATH)) conn = sqlite3.connect(str(DB_PATH))
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
conn.execute(""" conn.execute(
"""
CREATE TABLE IF NOT EXISTS tasks ( CREATE TABLE IF NOT EXISTS tasks (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
title TEXT NOT NULL, title TEXT NOT NULL,
@@ -42,7 +43,8 @@ def _ensure_db() -> sqlite3.Connection:
created_at TEXT DEFAULT (datetime('now')), created_at TEXT DEFAULT (datetime('now')),
completed_at TEXT completed_at TEXT
) )
""") """
)
conn.commit() conn.commit()
return conn return conn
@@ -103,9 +105,7 @@ def get_task_summary_for_briefing() -> dict:
"""Return a summary of task counts by status for the morning briefing.""" """Return a summary of task counts by status for the morning briefing."""
db = _ensure_db() db = _ensure_db()
try: try:
rows = db.execute( rows = db.execute("SELECT status, COUNT(*) as cnt FROM tasks GROUP BY status").fetchall()
"SELECT status, COUNT(*) as cnt FROM tasks GROUP BY status"
).fetchall()
finally: finally:
db.close() db.close()

View File

@@ -69,16 +69,16 @@ def _check_model_available(model_name: str) -> bool:
def _pull_model(model_name: str) -> bool: def _pull_model(model_name: str) -> bool:
"""Attempt to pull a model from Ollama. """Attempt to pull a model from Ollama.
Returns: Returns:
True if successful or model already exists True if successful or model already exists
""" """
try: try:
import urllib.request
import json import json
import urllib.request
logger.info("Pulling model: %s", model_name) logger.info("Pulling model: %s", model_name)
url = settings.ollama_url.replace("localhost", "127.0.0.1") url = settings.ollama_url.replace("localhost", "127.0.0.1")
req = urllib.request.Request( req = urllib.request.Request(
f"{url}/api/pull", f"{url}/api/pull",
@@ -86,7 +86,7 @@ def _pull_model(model_name: str) -> bool:
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
data=json.dumps({"name": model_name, "stream": False}).encode(), data=json.dumps({"name": model_name, "stream": False}).encode(),
) )
with urllib.request.urlopen(req, timeout=300) as response: with urllib.request.urlopen(req, timeout=300) as response:
if response.status == 200: if response.status == 200:
logger.info("Successfully pulled model: %s", model_name) logger.info("Successfully pulled model: %s", model_name)
@@ -94,7 +94,7 @@ def _pull_model(model_name: str) -> bool:
else: else:
logger.error("Failed to pull %s: HTTP %s", model_name, response.status) logger.error("Failed to pull %s: HTTP %s", model_name, response.status)
return False return False
except Exception as exc: except Exception as exc:
logger.error("Error pulling model %s: %s", model_name, exc) logger.error("Error pulling model %s: %s", model_name, exc)
return False return False
@@ -106,53 +106,44 @@ def _resolve_model_with_fallback(
auto_pull: bool = True, auto_pull: bool = True,
) -> tuple[str, bool]: ) -> tuple[str, bool]:
"""Resolve model with automatic pulling and fallback. """Resolve model with automatic pulling and fallback.
Args: Args:
requested_model: Preferred model to use requested_model: Preferred model to use
require_vision: Whether the model needs vision capabilities require_vision: Whether the model needs vision capabilities
auto_pull: Whether to attempt pulling missing models auto_pull: Whether to attempt pulling missing models
Returns: Returns:
Tuple of (model_name, is_fallback) Tuple of (model_name, is_fallback)
""" """
model = requested_model or settings.ollama_model model = requested_model or settings.ollama_model
# Check if requested model is available # Check if requested model is available
if _check_model_available(model): if _check_model_available(model):
logger.debug("Using available model: %s", model) logger.debug("Using available model: %s", model)
return model, False return model, False
# Try to pull the requested model # Try to pull the requested model
if auto_pull: if auto_pull:
logger.info("Model %s not available locally, attempting to pull...", model) logger.info("Model %s not available locally, attempting to pull...", model)
if _pull_model(model): if _pull_model(model):
return model, False return model, False
logger.warning("Failed to pull %s, checking fallbacks...", model) logger.warning("Failed to pull %s, checking fallbacks...", model)
# Use appropriate fallback chain # Use appropriate fallback chain
fallback_chain = VISION_MODEL_FALLBACKS if require_vision else DEFAULT_MODEL_FALLBACKS fallback_chain = VISION_MODEL_FALLBACKS if require_vision else DEFAULT_MODEL_FALLBACKS
for fallback_model in fallback_chain: for fallback_model in fallback_chain:
if _check_model_available(fallback_model): if _check_model_available(fallback_model):
logger.warning( logger.warning("Using fallback model %s (requested: %s)", fallback_model, model)
"Using fallback model %s (requested: %s)",
fallback_model, model
)
return fallback_model, True return fallback_model, True
# Try to pull the fallback # Try to pull the fallback
if auto_pull and _pull_model(fallback_model): if auto_pull and _pull_model(fallback_model):
logger.info( logger.info("Pulled and using fallback model %s (requested: %s)", fallback_model, model)
"Pulled and using fallback model %s (requested: %s)",
fallback_model, model
)
return fallback_model, True return fallback_model, True
# Absolute last resort - return the requested model and hope for the best # Absolute last resort - return the requested model and hope for the best
logger.error( logger.error("No models available in fallback chain. Requested: %s", model)
"No models available in fallback chain. Requested: %s",
model
)
return model, False return model, False
@@ -190,6 +181,7 @@ def _resolve_backend(requested: str | None) -> str:
# "auto" path — lazy import to keep startup fast and tests clean. # "auto" path — lazy import to keep startup fast and tests clean.
from timmy.backends import airllm_available, claude_available, grok_available, is_apple_silicon from timmy.backends import airllm_available, claude_available, grok_available, is_apple_silicon
if is_apple_silicon() and airllm_available(): if is_apple_silicon() and airllm_available():
return "airllm" return "airllm"
return "ollama" return "ollama"
@@ -215,14 +207,17 @@ def create_timmy(
if resolved == "claude": if resolved == "claude":
from timmy.backends import ClaudeBackend from timmy.backends import ClaudeBackend
return ClaudeBackend() return ClaudeBackend()
if resolved == "grok": if resolved == "grok":
from timmy.backends import GrokBackend from timmy.backends import GrokBackend
return GrokBackend() return GrokBackend()
if resolved == "airllm": if resolved == "airllm":
from timmy.backends import TimmyAirLLMAgent from timmy.backends import TimmyAirLLMAgent
return TimmyAirLLMAgent(model_size=size) return TimmyAirLLMAgent(model_size=size)
# Default: Ollama via Agno. # Default: Ollama via Agno.
@@ -236,16 +231,16 @@ def create_timmy(
# If Ollama is completely unreachable, fall back to Claude if available # If Ollama is completely unreachable, fall back to Claude if available
if not _check_model_available(model_name): if not _check_model_available(model_name):
from timmy.backends import claude_available from timmy.backends import claude_available
if claude_available(): if claude_available():
logger.warning( logger.warning("Ollama unreachable — falling back to Claude backend")
"Ollama unreachable — falling back to Claude backend"
)
from timmy.backends import ClaudeBackend from timmy.backends import ClaudeBackend
return ClaudeBackend() return ClaudeBackend()
if is_fallback: if is_fallback:
logger.info("Using fallback model %s (requested was unavailable)", model_name) logger.info("Using fallback model %s (requested was unavailable)", model_name)
use_tools = _model_supports_tools(model_name) use_tools = _model_supports_tools(model_name)
# Conditionally include tools — small models get none # Conditionally include tools — small models get none
@@ -259,6 +254,7 @@ def create_timmy(
# Try to load memory context # Try to load memory context
try: try:
from timmy.memory_system import memory_system from timmy.memory_system import memory_system
memory_context = memory_system.get_system_context() memory_context = memory_system.get_system_context()
if memory_context: if memory_context:
# Truncate if too long — smaller budget for small models # Truncate if too long — smaller budget for small models
@@ -290,32 +286,32 @@ def create_timmy(
class TimmyWithMemory: class TimmyWithMemory:
"""Agent wrapper with explicit three-tier memory management.""" """Agent wrapper with explicit three-tier memory management."""
def __init__(self, db_file: str = "timmy.db") -> None: def __init__(self, db_file: str = "timmy.db") -> None:
from timmy.memory_system import memory_system from timmy.memory_system import memory_system
self.agent = create_timmy(db_file=db_file) self.agent = create_timmy(db_file=db_file)
self.memory = memory_system self.memory = memory_system
self.session_active = True self.session_active = True
# Store initial context for reference # Store initial context for reference
self.initial_context = self.memory.get_system_context() self.initial_context = self.memory.get_system_context()
def chat(self, message: str) -> str: def chat(self, message: str) -> str:
"""Simple chat interface that tracks in memory.""" """Simple chat interface that tracks in memory."""
# Check for user facts to extract # Check for user facts to extract
self._extract_and_store_facts(message) self._extract_and_store_facts(message)
# Run agent # Run agent
result = self.agent.run(message, stream=False) result = self.agent.run(message, stream=False)
response_text = result.content if hasattr(result, "content") else str(result) response_text = result.content if hasattr(result, "content") else str(result)
return response_text return response_text
def _extract_and_store_facts(self, message: str) -> None: def _extract_and_store_facts(self, message: str) -> None:
"""Extract user facts from message and store in memory.""" """Extract user facts from message and store in memory."""
message_lower = message.lower() message_lower = message.lower()
# Extract name # Extract name
name_patterns = [ name_patterns = [
("my name is ", 11), ("my name is ", 11),
@@ -323,7 +319,7 @@ class TimmyWithMemory:
("i am ", 5), ("i am ", 5),
("call me ", 8), ("call me ", 8),
] ]
for pattern, offset in name_patterns: for pattern, offset in name_patterns:
if pattern in message_lower: if pattern in message_lower:
idx = message_lower.find(pattern) + offset idx = message_lower.find(pattern) + offset
@@ -332,7 +328,7 @@ class TimmyWithMemory:
self.memory.update_user_fact("Name", name) self.memory.update_user_fact("Name", name)
self.memory.record_decision(f"Learned user's name: {name}") self.memory.record_decision(f"Learned user's name: {name}")
break break
# Extract preferences # Extract preferences
pref_patterns = [ pref_patterns = [
("i like ", "Likes"), ("i like ", "Likes"),
@@ -341,7 +337,7 @@ class TimmyWithMemory:
("i don't like ", "Dislikes"), ("i don't like ", "Dislikes"),
("i hate ", "Dislikes"), ("i hate ", "Dislikes"),
] ]
for pattern, category in pref_patterns: for pattern, category in pref_patterns:
if pattern in message_lower: if pattern in message_lower:
idx = message_lower.find(pattern) + len(pattern) idx = message_lower.find(pattern) + len(pattern)
@@ -349,16 +345,16 @@ class TimmyWithMemory:
if pref and len(pref) > 3: if pref and len(pref) > 3:
self.memory.record_open_item(f"User {category.lower()}: {pref}") self.memory.record_open_item(f"User {category.lower()}: {pref}")
break break
def end_session(self, summary: str = "Session completed") -> None: def end_session(self, summary: str = "Session completed") -> None:
"""End session and write handoff.""" """End session and write handoff."""
if self.session_active: if self.session_active:
self.memory.end_session(summary) self.memory.end_session(summary)
self.session_active = False self.session_active = False
def __enter__(self): def __enter__(self):
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
self.end_session() self.end_session()
return False return False

View File

@@ -16,38 +16,41 @@ Architecture:
All methods return effects that can be logged, audited, and replayed. All methods return effects that can be logged, audited, and replayed.
""" """
import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timezone from datetime import datetime, timezone
from enum import Enum, auto from enum import Enum, auto
from typing import Any, Optional from typing import Any, Optional
import uuid
class PerceptionType(Enum): class PerceptionType(Enum):
"""Types of sensory input an agent can receive.""" """Types of sensory input an agent can receive."""
TEXT = auto() # Natural language
IMAGE = auto() # Visual input TEXT = auto() # Natural language
AUDIO = auto() # Sound/speech IMAGE = auto() # Visual input
SENSOR = auto() # Temperature, distance, etc. AUDIO = auto() # Sound/speech
MOTION = auto() # Accelerometer, gyroscope SENSOR = auto() # Temperature, distance, etc.
NETWORK = auto() # API calls, messages MOTION = auto() # Accelerometer, gyroscope
INTERNAL = auto() # Self-monitoring (battery, temp) NETWORK = auto() # API calls, messages
INTERNAL = auto() # Self-monitoring (battery, temp)
class ActionType(Enum): class ActionType(Enum):
"""Types of actions an agent can perform.""" """Types of actions an agent can perform."""
TEXT = auto() # Generate text response
SPEAK = auto() # Text-to-speech TEXT = auto() # Generate text response
MOVE = auto() # Physical movement SPEAK = auto() # Text-to-speech
GRIP = auto() # Manipulate objects MOVE = auto() # Physical movement
CALL = auto() # API/network call GRIP = auto() # Manipulate objects
EMIT = auto() # Signal/light/sound CALL = auto() # API/network call
SLEEP = auto() # Power management EMIT = auto() # Signal/light/sound
SLEEP = auto() # Power management
class AgentCapability(Enum): class AgentCapability(Enum):
"""High-level capabilities a TimAgent may possess.""" """High-level capabilities a TimAgent may possess."""
REASONING = "reasoning" REASONING = "reasoning"
CODING = "coding" CODING = "coding"
WRITING = "writing" WRITING = "writing"
@@ -63,15 +66,16 @@ class AgentCapability(Enum):
@dataclass(frozen=True) @dataclass(frozen=True)
class AgentIdentity: class AgentIdentity:
"""Immutable identity for an agent instance. """Immutable identity for an agent instance.
This persists across sessions and substrates. If Timmy moves This persists across sessions and substrates. If Timmy moves
from cloud to robot, the identity follows. from cloud to robot, the identity follows.
""" """
id: str id: str
name: str name: str
version: str version: str
created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
@classmethod @classmethod
def generate(cls, name: str, version: str = "1.0.0") -> "AgentIdentity": def generate(cls, name: str, version: str = "1.0.0") -> "AgentIdentity":
"""Generate a new unique identity.""" """Generate a new unique identity."""
@@ -85,16 +89,17 @@ class AgentIdentity:
@dataclass @dataclass
class Perception: class Perception:
"""A sensory input to the agent. """A sensory input to the agent.
Substrate-agnostic representation. A camera image and a Substrate-agnostic representation. A camera image and a
LiDAR point cloud are both Perception instances. LiDAR point cloud are both Perception instances.
""" """
type: PerceptionType type: PerceptionType
data: Any # Content depends on type data: Any # Content depends on type
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
source: str = "unknown" # e.g., "camera_1", "microphone", "user_input" source: str = "unknown" # e.g., "camera_1", "microphone", "user_input"
metadata: dict = field(default_factory=dict) metadata: dict = field(default_factory=dict)
@classmethod @classmethod
def text(cls, content: str, source: str = "user") -> "Perception": def text(cls, content: str, source: str = "user") -> "Perception":
"""Factory for text perception.""" """Factory for text perception."""
@@ -103,7 +108,7 @@ class Perception:
data=content, data=content,
source=source, source=source,
) )
@classmethod @classmethod
def sensor(cls, kind: str, value: float, unit: str = "") -> "Perception": def sensor(cls, kind: str, value: float, unit: str = "") -> "Perception":
"""Factory for sensor readings.""" """Factory for sensor readings."""
@@ -117,16 +122,17 @@ class Perception:
@dataclass @dataclass
class Action: class Action:
"""An action the agent intends to perform. """An action the agent intends to perform.
Actions are effects — they describe what should happen, Actions are effects — they describe what should happen,
not how. The substrate implements the "how." not how. The substrate implements the "how."
""" """
type: ActionType type: ActionType
payload: Any # Action-specific data payload: Any # Action-specific data
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
confidence: float = 1.0 # 0-1, agent's certainty confidence: float = 1.0 # 0-1, agent's certainty
deadline: Optional[str] = None # When action must complete deadline: Optional[str] = None # When action must complete
@classmethod @classmethod
def respond(cls, text: str, confidence: float = 1.0) -> "Action": def respond(cls, text: str, confidence: float = 1.0) -> "Action":
"""Factory for text response action.""" """Factory for text response action."""
@@ -135,7 +141,7 @@ class Action:
payload=text, payload=text,
confidence=confidence, confidence=confidence,
) )
@classmethod @classmethod
def move(cls, vector: tuple[float, float, float], speed: float = 1.0) -> "Action": def move(cls, vector: tuple[float, float, float], speed: float = 1.0) -> "Action":
"""Factory for movement action (x, y, z meters).""" """Factory for movement action (x, y, z meters)."""
@@ -148,10 +154,11 @@ class Action:
@dataclass @dataclass
class Memory: class Memory:
"""A stored experience or fact. """A stored experience or fact.
Memories are substrate-agnostic. A conversation history Memories are substrate-agnostic. A conversation history
and a video recording are both Memory instances. and a video recording are both Memory instances.
""" """
id: str id: str
content: Any content: Any
created_at: str created_at: str
@@ -159,7 +166,7 @@ class Memory:
last_accessed: Optional[str] = None last_accessed: Optional[str] = None
importance: float = 0.5 # 0-1, for pruning decisions importance: float = 0.5 # 0-1, for pruning decisions
tags: list[str] = field(default_factory=list) tags: list[str] = field(default_factory=list)
def touch(self) -> None: def touch(self) -> None:
"""Mark memory as accessed.""" """Mark memory as accessed."""
self.access_count += 1 self.access_count += 1
@@ -169,6 +176,7 @@ class Memory:
@dataclass @dataclass
class Communication: class Communication:
"""A message to/from another agent or human.""" """A message to/from another agent or human."""
sender: str sender: str
recipient: str recipient: str
content: Any content: Any
@@ -179,132 +187,132 @@ class Communication:
class TimAgent(ABC): class TimAgent(ABC):
"""Abstract base class for all Timmy agent implementations. """Abstract base class for all Timmy agent implementations.
This is the substrate-agnostic interface. Implementations: This is the substrate-agnostic interface. Implementations:
- OllamaAgent: LLM-based reasoning (today) - OllamaAgent: LLM-based reasoning (today)
- RobotAgent: Physical embodiment (future) - RobotAgent: Physical embodiment (future)
- SimulationAgent: Virtual environment (future) - SimulationAgent: Virtual environment (future)
Usage: Usage:
agent = OllamaAgent(identity) # Today's implementation agent = OllamaAgent(identity) # Today's implementation
perception = Perception.text("Hello Timmy") perception = Perception.text("Hello Timmy")
memory = agent.perceive(perception) memory = agent.perceive(perception)
action = agent.reason("How should I respond?") action = agent.reason("How should I respond?")
result = agent.act(action) result = agent.act(action)
agent.remember(memory) # Store for future agent.remember(memory) # Store for future
""" """
def __init__(self, identity: AgentIdentity) -> None: def __init__(self, identity: AgentIdentity) -> None:
self._identity = identity self._identity = identity
self._capabilities: set[AgentCapability] = set() self._capabilities: set[AgentCapability] = set()
self._state: dict[str, Any] = {} self._state: dict[str, Any] = {}
@property @property
def identity(self) -> AgentIdentity: def identity(self) -> AgentIdentity:
"""Return this agent's immutable identity.""" """Return this agent's immutable identity."""
return self._identity return self._identity
@property @property
def capabilities(self) -> set[AgentCapability]: def capabilities(self) -> set[AgentCapability]:
"""Return set of supported capabilities.""" """Return set of supported capabilities."""
return self._capabilities.copy() return self._capabilities.copy()
def has_capability(self, capability: AgentCapability) -> bool: def has_capability(self, capability: AgentCapability) -> bool:
"""Check if agent supports a capability.""" """Check if agent supports a capability."""
return capability in self._capabilities return capability in self._capabilities
@abstractmethod @abstractmethod
def perceive(self, perception: Perception) -> Memory: def perceive(self, perception: Perception) -> Memory:
"""Process sensory input and create a memory. """Process sensory input and create a memory.
This is the entry point for all agent interaction. This is the entry point for all agent interaction.
A text message, camera frame, or temperature reading A text message, camera frame, or temperature reading
all enter through perceive(). all enter through perceive().
Args: Args:
perception: Sensory input perception: Sensory input
Returns: Returns:
Memory: Stored representation of the perception Memory: Stored representation of the perception
""" """
pass pass
@abstractmethod @abstractmethod
def reason(self, query: str, context: list[Memory]) -> Action: def reason(self, query: str, context: list[Memory]) -> Action:
"""Reason about a situation and decide on action. """Reason about a situation and decide on action.
This is where "thinking" happens. The agent uses its This is where "thinking" happens. The agent uses its
substrate-appropriate reasoning (LLM, neural net, rules) substrate-appropriate reasoning (LLM, neural net, rules)
to decide what to do. to decide what to do.
Args: Args:
query: What to reason about query: What to reason about
context: Relevant memories for context context: Relevant memories for context
Returns: Returns:
Action: What the agent decides to do Action: What the agent decides to do
""" """
pass pass
@abstractmethod @abstractmethod
def act(self, action: Action) -> Any: def act(self, action: Action) -> Any:
"""Execute an action in the substrate. """Execute an action in the substrate.
This is where the abstract action becomes concrete: This is where the abstract action becomes concrete:
- TEXT → Generate LLM response - TEXT → Generate LLM response
- MOVE → Send motor commands - MOVE → Send motor commands
- SPEAK → Call TTS engine - SPEAK → Call TTS engine
Args: Args:
action: The action to execute action: The action to execute
Returns: Returns:
Result of the action (substrate-specific) Result of the action (substrate-specific)
""" """
pass pass
@abstractmethod @abstractmethod
def remember(self, memory: Memory) -> None: def remember(self, memory: Memory) -> None:
"""Store a memory for future retrieval. """Store a memory for future retrieval.
The storage mechanism depends on substrate: The storage mechanism depends on substrate:
- Cloud: SQLite, vector DB - Cloud: SQLite, vector DB
- Robot: Local flash storage - Robot: Local flash storage
- Hybrid: Synced with conflict resolution - Hybrid: Synced with conflict resolution
Args: Args:
memory: Experience to store memory: Experience to store
""" """
pass pass
@abstractmethod @abstractmethod
def recall(self, query: str, limit: int = 5) -> list[Memory]: def recall(self, query: str, limit: int = 5) -> list[Memory]:
"""Retrieve relevant memories. """Retrieve relevant memories.
Args: Args:
query: What to search for query: What to search for
limit: Maximum memories to return limit: Maximum memories to return
Returns: Returns:
List of relevant memories, sorted by relevance List of relevant memories, sorted by relevance
""" """
pass pass
@abstractmethod @abstractmethod
def communicate(self, message: Communication) -> bool: def communicate(self, message: Communication) -> bool:
"""Send/receive communication with another agent. """Send/receive communication with another agent.
Args: Args:
message: Message to send message: Message to send
Returns: Returns:
True if communication succeeded True if communication succeeded
""" """
pass pass
def get_state(self) -> dict[str, Any]: def get_state(self) -> dict[str, Any]:
"""Get current agent state for monitoring/debugging.""" """Get current agent state for monitoring/debugging."""
return { return {
@@ -312,7 +320,7 @@ class TimAgent(ABC):
"capabilities": list(self._capabilities), "capabilities": list(self._capabilities),
"state": self._state.copy(), "state": self._state.copy(),
} }
def shutdown(self) -> None: def shutdown(self) -> None:
"""Graceful shutdown. Persist state, close connections.""" """Graceful shutdown. Persist state, close connections."""
# Override in subclass for cleanup # Override in subclass for cleanup
@@ -321,7 +329,7 @@ class TimAgent(ABC):
class AgentEffect: class AgentEffect:
"""Log entry for agent actions — for audit and replay. """Log entry for agent actions — for audit and replay.
The complete history of an agent's life can be captured The complete history of an agent's life can be captured
as a sequence of AgentEffects. This enables: as a sequence of AgentEffects. This enables:
- Debugging: What did the agent see and do? - Debugging: What did the agent see and do?
@@ -329,40 +337,46 @@ class AgentEffect:
- Replay: Reconstruct agent state from log - Replay: Reconstruct agent state from log
- Training: Learn from agent experiences - Training: Learn from agent experiences
""" """
def __init__(self, log_path: Optional[str] = None) -> None: def __init__(self, log_path: Optional[str] = None) -> None:
self._effects: list[dict] = [] self._effects: list[dict] = []
self._log_path = log_path self._log_path = log_path
def log_perceive(self, perception: Perception, memory_id: str) -> None: def log_perceive(self, perception: Perception, memory_id: str) -> None:
"""Log a perception event.""" """Log a perception event."""
self._effects.append({ self._effects.append(
"type": "perceive", {
"perception_type": perception.type.name, "type": "perceive",
"source": perception.source, "perception_type": perception.type.name,
"memory_id": memory_id, "source": perception.source,
"timestamp": datetime.now(timezone.utc).isoformat(), "memory_id": memory_id,
}) "timestamp": datetime.now(timezone.utc).isoformat(),
}
)
def log_reason(self, query: str, action_type: ActionType) -> None: def log_reason(self, query: str, action_type: ActionType) -> None:
"""Log a reasoning event.""" """Log a reasoning event."""
self._effects.append({ self._effects.append(
"type": "reason", {
"query": query, "type": "reason",
"action_type": action_type.name, "query": query,
"timestamp": datetime.now(timezone.utc).isoformat(), "action_type": action_type.name,
}) "timestamp": datetime.now(timezone.utc).isoformat(),
}
)
def log_act(self, action: Action, result: Any) -> None: def log_act(self, action: Action, result: Any) -> None:
"""Log an action event.""" """Log an action event."""
self._effects.append({ self._effects.append(
"type": "act", {
"action_type": action.type.name, "type": "act",
"confidence": action.confidence, "action_type": action.type.name,
"result_type": type(result).__name__, "confidence": action.confidence,
"timestamp": datetime.now(timezone.utc).isoformat(), "result_type": type(result).__name__,
}) "timestamp": datetime.now(timezone.utc).isoformat(),
}
)
def export(self) -> list[dict]: def export(self) -> list[dict]:
"""Export effect log for analysis.""" """Export effect log for analysis."""
return self._effects.copy() return self._effects.copy()

View File

@@ -7,10 +7,10 @@ between the old codebase and the new embodiment-ready architecture.
Usage: Usage:
from timmy.agent_core import AgentIdentity, Perception from timmy.agent_core import AgentIdentity, Perception
from timmy.agent_core.ollama_adapter import OllamaAgent from timmy.agent_core.ollama_adapter import OllamaAgent
identity = AgentIdentity.generate("Timmy") identity = AgentIdentity.generate("Timmy")
agent = OllamaAgent(identity) agent = OllamaAgent(identity)
perception = Perception.text("Hello!") perception = Perception.text("Hello!")
memory = agent.perceive(perception) memory = agent.perceive(perception)
action = agent.reason("How should I respond?", [memory]) action = agent.reason("How should I respond?", [memory])
@@ -19,27 +19,27 @@ Usage:
from typing import Any, Optional from typing import Any, Optional
from timmy.agent import _resolve_model_with_fallback, create_timmy
from timmy.agent_core.interface import ( from timmy.agent_core.interface import (
AgentCapability,
AgentIdentity,
Perception,
PerceptionType,
Action, Action,
ActionType, ActionType,
Memory, AgentCapability,
Communication,
TimAgent,
AgentEffect, AgentEffect,
AgentIdentity,
Communication,
Memory,
Perception,
PerceptionType,
TimAgent,
) )
from timmy.agent import create_timmy, _resolve_model_with_fallback
class OllamaAgent(TimAgent): class OllamaAgent(TimAgent):
"""TimAgent implementation using local Ollama LLM. """TimAgent implementation using local Ollama LLM.
This is the production agent for Timmy Time v2. It uses This is the production agent for Timmy Time v2. It uses
Ollama for reasoning and SQLite for memory persistence. Ollama for reasoning and SQLite for memory persistence.
Capabilities: Capabilities:
- REASONING: LLM-based inference - REASONING: LLM-based inference
- CODING: Code generation and analysis - CODING: Code generation and analysis
@@ -47,7 +47,7 @@ class OllamaAgent(TimAgent):
- ANALYSIS: Data processing and insights - ANALYSIS: Data processing and insights
- COMMUNICATION: Multi-agent messaging - COMMUNICATION: Multi-agent messaging
""" """
def __init__( def __init__(
self, self,
identity: AgentIdentity, identity: AgentIdentity,
@@ -56,7 +56,7 @@ class OllamaAgent(TimAgent):
require_vision: bool = False, require_vision: bool = False,
) -> None: ) -> None:
"""Initialize Ollama-based agent. """Initialize Ollama-based agent.
Args: Args:
identity: Agent identity (persistent across sessions) identity: Agent identity (persistent across sessions)
model: Ollama model to use (auto-resolves with fallback) model: Ollama model to use (auto-resolves with fallback)
@@ -64,23 +64,24 @@ class OllamaAgent(TimAgent):
require_vision: Whether to select a vision-capable model require_vision: Whether to select a vision-capable model
""" """
super().__init__(identity) super().__init__(identity)
# Resolve model with automatic pulling and fallback # Resolve model with automatic pulling and fallback
resolved_model, is_fallback = _resolve_model_with_fallback( resolved_model, is_fallback = _resolve_model_with_fallback(
requested_model=model, requested_model=model,
require_vision=require_vision, require_vision=require_vision,
auto_pull=True, auto_pull=True,
) )
if is_fallback: if is_fallback:
import logging import logging
logging.getLogger(__name__).info( logging.getLogger(__name__).info(
"OllamaAdapter using fallback model %s", resolved_model "OllamaAdapter using fallback model %s", resolved_model
) )
# Initialize underlying Ollama agent # Initialize underlying Ollama agent
self._timmy = create_timmy(model=resolved_model) self._timmy = create_timmy(model=resolved_model)
# Set capabilities based on what Ollama can do # Set capabilities based on what Ollama can do
self._capabilities = { self._capabilities = {
AgentCapability.REASONING, AgentCapability.REASONING,
@@ -89,17 +90,17 @@ class OllamaAgent(TimAgent):
AgentCapability.ANALYSIS, AgentCapability.ANALYSIS,
AgentCapability.COMMUNICATION, AgentCapability.COMMUNICATION,
} }
# Effect logging for audit/replay # Effect logging for audit/replay
self._effect_log = AgentEffect(effect_log) if effect_log else None self._effect_log = AgentEffect(effect_log) if effect_log else None
# Simple in-memory working memory (short term) # Simple in-memory working memory (short term)
self._working_memory: list[Memory] = [] self._working_memory: list[Memory] = []
self._max_working_memory = 10 self._max_working_memory = 10
def perceive(self, perception: Perception) -> Memory: def perceive(self, perception: Perception) -> Memory:
"""Process perception and store in memory. """Process perception and store in memory.
For text perceptions, we might do light preprocessing For text perceptions, we might do light preprocessing
(summarization, keyword extraction) before storage. (summarization, keyword extraction) before storage.
""" """
@@ -114,28 +115,28 @@ class OllamaAgent(TimAgent):
created_at=perception.timestamp, created_at=perception.timestamp,
tags=self._extract_tags(perception), tags=self._extract_tags(perception),
) )
# Add to working memory # Add to working memory
self._working_memory.append(memory) self._working_memory.append(memory)
if len(self._working_memory) > self._max_working_memory: if len(self._working_memory) > self._max_working_memory:
self._working_memory.pop(0) # FIFO eviction self._working_memory.pop(0) # FIFO eviction
# Log effect # Log effect
if self._effect_log: if self._effect_log:
self._effect_log.log_perceive(perception, memory.id) self._effect_log.log_perceive(perception, memory.id)
return memory return memory
def reason(self, query: str, context: list[Memory]) -> Action: def reason(self, query: str, context: list[Memory]) -> Action:
"""Use LLM to reason and decide on action. """Use LLM to reason and decide on action.
This is where the Ollama agent does its work. We construct This is where the Ollama agent does its work. We construct
a prompt from the query and context, then interpret the a prompt from the query and context, then interpret the
response as an action. response as an action.
""" """
# Build context string from memories # Build context string from memories
context_str = self._format_context(context) context_str = self._format_context(context)
# Construct prompt # Construct prompt
prompt = f"""You are {self._identity.name}, an AI assistant. prompt = f"""You are {self._identity.name}, an AI assistant.
@@ -145,30 +146,30 @@ Context from previous interactions:
Current query: {query} Current query: {query}
Respond naturally and helpfully.""" Respond naturally and helpfully."""
# Run LLM inference # Run LLM inference
result = self._timmy.run(prompt, stream=False) result = self._timmy.run(prompt, stream=False)
response_text = result.content if hasattr(result, "content") else str(result) response_text = result.content if hasattr(result, "content") else str(result)
# Create text response action # Create text response action
action = Action.respond(response_text, confidence=0.9) action = Action.respond(response_text, confidence=0.9)
# Log effect # Log effect
if self._effect_log: if self._effect_log:
self._effect_log.log_reason(query, action.type) self._effect_log.log_reason(query, action.type)
return action return action
def act(self, action: Action) -> Any: def act(self, action: Action) -> Any:
"""Execute action in the Ollama substrate. """Execute action in the Ollama substrate.
For text actions, the "execution" is just returning the For text actions, the "execution" is just returning the
text (already generated during reasoning). For future text (already generated during reasoning). For future
action types (MOVE, SPEAK), this would trigger the action types (MOVE, SPEAK), this would trigger the
appropriate Ollama tool calls. appropriate Ollama tool calls.
""" """
result = None result = None
if action.type == ActionType.TEXT: if action.type == ActionType.TEXT:
result = action.payload result = action.payload
elif action.type == ActionType.SPEAK: elif action.type == ActionType.SPEAK:
@@ -179,13 +180,13 @@ Respond naturally and helpfully."""
result = {"status": "not_implemented", "payload": action.payload} result = {"status": "not_implemented", "payload": action.payload}
else: else:
result = {"error": f"Action type {action.type} not supported by OllamaAgent"} result = {"error": f"Action type {action.type} not supported by OllamaAgent"}
# Log effect # Log effect
if self._effect_log: if self._effect_log:
self._effect_log.log_act(action, result) self._effect_log.log_act(action, result)
return result return result
def remember(self, memory: Memory) -> None: def remember(self, memory: Memory) -> None:
"""Store memory in working memory. """Store memory in working memory.
@@ -200,48 +201,48 @@ Respond naturally and helpfully."""
# Evict oldest if over capacity # Evict oldest if over capacity
if len(self._working_memory) > self._max_working_memory: if len(self._working_memory) > self._max_working_memory:
self._working_memory.pop(0) self._working_memory.pop(0)
def recall(self, query: str, limit: int = 5) -> list[Memory]: def recall(self, query: str, limit: int = 5) -> list[Memory]:
"""Retrieve relevant memories. """Retrieve relevant memories.
Simple keyword matching for now. Future: vector similarity. Simple keyword matching for now. Future: vector similarity.
""" """
query_lower = query.lower() query_lower = query.lower()
scored = [] scored = []
for memory in self._working_memory: for memory in self._working_memory:
score = 0 score = 0
content_str = str(memory.content).lower() content_str = str(memory.content).lower()
# Simple keyword overlap # Simple keyword overlap
query_words = set(query_lower.split()) query_words = set(query_lower.split())
content_words = set(content_str.split()) content_words = set(content_str.split())
overlap = len(query_words & content_words) overlap = len(query_words & content_words)
score += overlap score += overlap
# Boost recent memories # Boost recent memories
score += memory.importance score += memory.importance
scored.append((score, memory)) scored.append((score, memory))
# Sort by score descending # Sort by score descending
scored.sort(key=lambda x: x[0], reverse=True) scored.sort(key=lambda x: x[0], reverse=True)
# Return top N # Return top N
return [m for _, m in scored[:limit]] return [m for _, m in scored[:limit]]
def communicate(self, message: Communication) -> bool: def communicate(self, message: Communication) -> bool:
"""Send message to another agent. """Send message to another agent.
Swarm comms removed — inter-agent communication will be handled Swarm comms removed — inter-agent communication will be handled
by the unified brain memory layer. by the unified brain memory layer.
""" """
return False return False
def _extract_tags(self, perception: Perception) -> list[str]: def _extract_tags(self, perception: Perception) -> list[str]:
"""Extract searchable tags from perception.""" """Extract searchable tags from perception."""
tags = [perception.type.name, perception.source] tags = [perception.type.name, perception.source]
if perception.type == PerceptionType.TEXT: if perception.type == PerceptionType.TEXT:
# Simple keyword extraction # Simple keyword extraction
text = str(perception.data).lower() text = str(perception.data).lower()
@@ -249,14 +250,14 @@ Respond naturally and helpfully."""
for kw in keywords: for kw in keywords:
if kw in text: if kw in text:
tags.append(kw) tags.append(kw)
return tags return tags
def _format_context(self, memories: list[Memory]) -> str: def _format_context(self, memories: list[Memory]) -> str:
"""Format memories into context string for prompt.""" """Format memories into context string for prompt."""
if not memories: if not memories:
return "No previous context." return "No previous context."
parts = [] parts = []
for mem in memories[-5:]: # Last 5 memories for mem in memories[-5:]: # Last 5 memories
if isinstance(mem.content, dict): if isinstance(mem.content, dict):
@@ -264,9 +265,9 @@ Respond naturally and helpfully."""
parts.append(f"- {data}") parts.append(f"- {data}")
else: else:
parts.append(f"- {mem.content}") parts.append(f"- {mem.content}")
return "\n".join(parts) return "\n".join(parts)
def get_effect_log(self) -> Optional[list[dict]]: def get_effect_log(self) -> Optional[list[dict]]:
"""Export effect log if logging is enabled.""" """Export effect log if logging is enabled."""
if self._effect_log: if self._effect_log:

View File

@@ -30,9 +30,11 @@ logger = logging.getLogger(__name__)
# Data structures # Data structures
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@dataclass @dataclass
class AgenticStep: class AgenticStep:
"""Result of a single step in the agentic loop.""" """Result of a single step in the agentic loop."""
step_num: int step_num: int
description: str description: str
result: str result: str
@@ -43,6 +45,7 @@ class AgenticStep:
@dataclass @dataclass
class AgenticResult: class AgenticResult:
"""Final result of the entire agentic loop.""" """Final result of the entire agentic loop."""
task_id: str task_id: str
task: str task: str
summary: str summary: str
@@ -55,6 +58,7 @@ class AgenticResult:
# Agent factory # Agent factory
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _get_loop_agent(): def _get_loop_agent():
"""Create a fresh agent for the agentic loop. """Create a fresh agent for the agentic loop.
@@ -62,6 +66,7 @@ def _get_loop_agent():
dedicated session so it doesn't pollute the main chat history. dedicated session so it doesn't pollute the main chat history.
""" """
from timmy.agent import create_timmy from timmy.agent import create_timmy
return create_timmy() return create_timmy()
@@ -85,6 +90,7 @@ def _parse_steps(plan_text: str) -> list[str]:
# Core loop # Core loop
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
async def run_agentic_loop( async def run_agentic_loop(
task: str, task: str,
*, *,
@@ -146,12 +152,15 @@ async def run_agentic_loop(
was_truncated = planned_steps > total_steps was_truncated = planned_steps > total_steps
# Broadcast plan # Broadcast plan
await _broadcast_progress("agentic.plan_ready", { await _broadcast_progress(
"task_id": task_id, "agentic.plan_ready",
"task": task, {
"steps": steps, "task_id": task_id,
"total": total_steps, "task": task,
}) "steps": steps,
"total": total_steps,
},
)
# ── Phase 2: Execution ───────────────────────────────────────────────── # ── Phase 2: Execution ─────────────────────────────────────────────────
completed_results: list[str] = [] completed_results: list[str] = []
@@ -175,6 +184,7 @@ async def run_agentic_loop(
# Clean the response # Clean the response
from timmy.session import _clean_response from timmy.session import _clean_response
step_result = _clean_response(step_result) step_result = _clean_response(step_result)
step = AgenticStep( step = AgenticStep(
@@ -188,13 +198,16 @@ async def run_agentic_loop(
completed_results.append(f"Step {i}: {step_result[:200]}") completed_results.append(f"Step {i}: {step_result[:200]}")
# Broadcast progress # Broadcast progress
await _broadcast_progress("agentic.step_complete", { await _broadcast_progress(
"task_id": task_id, "agentic.step_complete",
"step": i, {
"total": total_steps, "task_id": task_id,
"description": step_desc, "step": i,
"result": step_result[:200], "total": total_steps,
}) "description": step_desc,
"result": step_result[:200],
},
)
if on_progress: if on_progress:
await on_progress(step_desc, i, total_steps) await on_progress(step_desc, i, total_steps)
@@ -210,11 +223,16 @@ async def run_agentic_loop(
) )
try: try:
adapt_run = await asyncio.to_thread( adapt_run = await asyncio.to_thread(
agent.run, adapt_prompt, stream=False, agent.run,
adapt_prompt,
stream=False,
session_id=f"{session_id}_adapt{i}", session_id=f"{session_id}_adapt{i}",
) )
adapt_result = adapt_run.content if hasattr(adapt_run, "content") else str(adapt_run) adapt_result = (
adapt_run.content if hasattr(adapt_run, "content") else str(adapt_run)
)
from timmy.session import _clean_response from timmy.session import _clean_response
adapt_result = _clean_response(adapt_result) adapt_result = _clean_response(adapt_result)
step = AgenticStep( step = AgenticStep(
@@ -227,14 +245,17 @@ async def run_agentic_loop(
result.steps.append(step) result.steps.append(step)
completed_results.append(f"Step {i} (adapted): {adapt_result[:200]}") completed_results.append(f"Step {i} (adapted): {adapt_result[:200]}")
await _broadcast_progress("agentic.step_adapted", { await _broadcast_progress(
"task_id": task_id, "agentic.step_adapted",
"step": i, {
"total": total_steps, "task_id": task_id,
"description": step_desc, "step": i,
"error": str(exc), "total": total_steps,
"adaptation": adapt_result[:200], "description": step_desc,
}) "error": str(exc),
"adaptation": adapt_result[:200],
},
)
if on_progress: if on_progress:
await on_progress(f"[Adapted] {step_desc}", i, total_steps) await on_progress(f"[Adapted] {step_desc}", i, total_steps)
@@ -259,11 +280,16 @@ async def run_agentic_loop(
) )
try: try:
summary_run = await asyncio.to_thread( summary_run = await asyncio.to_thread(
agent.run, summary_prompt, stream=False, agent.run,
summary_prompt,
stream=False,
session_id=f"{session_id}_summary", session_id=f"{session_id}_summary",
) )
result.summary = summary_run.content if hasattr(summary_run, "content") else str(summary_run) result.summary = (
summary_run.content if hasattr(summary_run, "content") else str(summary_run)
)
from timmy.session import _clean_response from timmy.session import _clean_response
result.summary = _clean_response(result.summary) result.summary = _clean_response(result.summary)
except Exception as exc: except Exception as exc:
logger.error("Agentic loop summary failed: %s", exc) logger.error("Agentic loop summary failed: %s", exc)
@@ -281,13 +307,16 @@ async def run_agentic_loop(
result.total_duration_ms = int((time.monotonic() - start_time) * 1000) result.total_duration_ms = int((time.monotonic() - start_time) * 1000)
await _broadcast_progress("agentic.task_complete", { await _broadcast_progress(
"task_id": task_id, "agentic.task_complete",
"status": result.status, {
"steps_completed": len(result.steps), "task_id": task_id,
"summary": result.summary[:300], "status": result.status,
"duration_ms": result.total_duration_ms, "steps_completed": len(result.steps),
}) "summary": result.summary[:300],
"duration_ms": result.total_duration_ms,
},
)
return result return result
@@ -296,10 +325,12 @@ async def run_agentic_loop(
# WebSocket broadcast helper # WebSocket broadcast helper
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
async def _broadcast_progress(event: str, data: dict) -> None: async def _broadcast_progress(event: str, data: dict) -> None:
"""Broadcast agentic loop progress via WebSocket (best-effort).""" """Broadcast agentic loop progress via WebSocket (best-effort)."""
try: try:
from infrastructure.ws_manager.handler import ws_manager from infrastructure.ws_manager.handler import ws_manager
await ws_manager.broadcast(event, data) await ws_manager.broadcast(event, data)
except Exception: except Exception:
logger.debug("Agentic loop: WS broadcast failed for %s", event) logger.debug("Agentic loop: WS broadcast failed for %s", event)

View File

@@ -18,7 +18,7 @@ from agno.agent import Agent
from agno.models.ollama import Ollama from agno.models.ollama import Ollama
from config import settings from config import settings
from infrastructure.events.bus import EventBus, Event from infrastructure.events.bus import Event, EventBus
try: try:
from mcp.registry import tool_registry from mcp.registry import tool_registry
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
class BaseAgent(ABC): class BaseAgent(ABC):
"""Base class for all sub-agents.""" """Base class for all sub-agents."""
def __init__( def __init__(
self, self,
agent_id: str, agent_id: str,
@@ -43,15 +43,15 @@ class BaseAgent(ABC):
self.name = name self.name = name
self.role = role self.role = role
self.tools = tools or [] self.tools = tools or []
# Create Agno agent # Create Agno agent
self.agent = self._create_agent(system_prompt) self.agent = self._create_agent(system_prompt)
# Event bus for communication # Event bus for communication
self.event_bus: Optional[EventBus] = None self.event_bus: Optional[EventBus] = None
logger.info("%s agent initialized (id: %s)", name, agent_id) logger.info("%s agent initialized (id: %s)", name, agent_id)
def _create_agent(self, system_prompt: str) -> Agent: def _create_agent(self, system_prompt: str) -> Agent:
"""Create the underlying Agno agent.""" """Create the underlying Agno agent."""
# Get tools from registry # Get tools from registry
@@ -60,7 +60,7 @@ class BaseAgent(ABC):
handler = tool_registry.get_handler(tool_name) handler = tool_registry.get_handler(tool_name)
if handler: if handler:
tool_instances.append(handler) tool_instances.append(handler)
return Agent( return Agent(
name=self.name, name=self.name,
model=Ollama(id=settings.ollama_model, host=settings.ollama_url, timeout=300), model=Ollama(id=settings.ollama_model, host=settings.ollama_url, timeout=300),
@@ -71,19 +71,19 @@ class BaseAgent(ABC):
markdown=True, markdown=True,
telemetry=settings.telemetry_enabled, telemetry=settings.telemetry_enabled,
) )
def connect_event_bus(self, bus: EventBus) -> None: def connect_event_bus(self, bus: EventBus) -> None:
"""Connect to the event bus for inter-agent communication.""" """Connect to the event bus for inter-agent communication."""
self.event_bus = bus self.event_bus = bus
# Subscribe to relevant events # Subscribe to relevant events
bus.subscribe(f"agent.{self.agent_id}.*")(self._handle_direct_message) bus.subscribe(f"agent.{self.agent_id}.*")(self._handle_direct_message)
bus.subscribe("agent.task.assigned")(self._handle_task_assignment) bus.subscribe("agent.task.assigned")(self._handle_task_assignment)
async def _handle_direct_message(self, event: Event) -> None: async def _handle_direct_message(self, event: Event) -> None:
"""Handle direct messages to this agent.""" """Handle direct messages to this agent."""
logger.debug("%s received message: %s", self.name, event.type) logger.debug("%s received message: %s", self.name, event.type)
async def _handle_task_assignment(self, event: Event) -> None: async def _handle_task_assignment(self, event: Event) -> None:
"""Handle task assignment events.""" """Handle task assignment events."""
assigned_agent = event.data.get("agent_id") assigned_agent = event.data.get("agent_id")
@@ -91,41 +91,43 @@ class BaseAgent(ABC):
task_id = event.data.get("task_id") task_id = event.data.get("task_id")
description = event.data.get("description", "") description = event.data.get("description", "")
logger.info("%s assigned task %s: %s", self.name, task_id, description[:50]) logger.info("%s assigned task %s: %s", self.name, task_id, description[:50])
# Execute the task # Execute the task
await self.execute_task(task_id, description, event.data) await self.execute_task(task_id, description, event.data)
@abstractmethod @abstractmethod
async def execute_task(self, task_id: str, description: str, context: dict) -> Any: async def execute_task(self, task_id: str, description: str, context: dict) -> Any:
"""Execute a task assigned to this agent. """Execute a task assigned to this agent.
Must be implemented by subclasses. Must be implemented by subclasses.
""" """
pass pass
async def run(self, message: str) -> str: async def run(self, message: str) -> str:
"""Run the agent with a message. """Run the agent with a message.
Returns: Returns:
Agent response Agent response
""" """
result = self.agent.run(message, stream=False) result = self.agent.run(message, stream=False)
response = result.content if hasattr(result, "content") else str(result) response = result.content if hasattr(result, "content") else str(result)
# Emit completion event # Emit completion event
if self.event_bus: if self.event_bus:
await self.event_bus.publish(Event( await self.event_bus.publish(
type=f"agent.{self.agent_id}.response", Event(
source=self.agent_id, type=f"agent.{self.agent_id}.response",
data={"input": message, "output": response}, source=self.agent_id,
)) data={"input": message, "output": response},
)
)
return response return response
def get_capabilities(self) -> list[str]: def get_capabilities(self) -> list[str]:
"""Get list of capabilities this agent provides.""" """Get list of capabilities this agent provides."""
return self.tools return self.tools
def get_status(self) -> dict: def get_status(self) -> dict:
"""Get current agent status.""" """Get current agent status."""
return { return {

View File

@@ -12,9 +12,9 @@ from typing import Any, Optional
from agno.agent import Agent from agno.agent import Agent
from agno.models.ollama import Ollama from agno.models.ollama import Ollama
from timmy.agents.base import BaseAgent, SubAgent
from config import settings from config import settings
from infrastructure.events.bus import EventBus, event_bus from infrastructure.events.bus import EventBus, event_bus
from timmy.agents.base import BaseAgent, SubAgent
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -29,7 +29,7 @@ _timmy_context: dict[str, Any] = {
async def _load_hands_async() -> list[dict]: async def _load_hands_async() -> list[dict]:
"""Async helper to load hands. """Async helper to load hands.
Hands registry removed — hand definitions live in TOML files under hands/. Hands registry removed — hand definitions live in TOML files under hands/.
This will be rewired to read from brain memory. This will be rewired to read from brain memory.
""" """
@@ -42,7 +42,7 @@ def build_timmy_context_sync() -> dict[str, Any]:
Gathers git commits, active sub-agents, and hot memory. Gathers git commits, active sub-agents, and hot memory.
""" """
global _timmy_context global _timmy_context
ctx: dict[str, Any] = { ctx: dict[str, Any] = {
"timestamp": datetime.now(timezone.utc).isoformat(), "timestamp": datetime.now(timezone.utc).isoformat(),
"repo_root": settings.repo_root, "repo_root": settings.repo_root,
@@ -51,45 +51,52 @@ def build_timmy_context_sync() -> dict[str, Any]:
"hands": [], "hands": [],
"memory": "", "memory": "",
} }
# 1. Get recent git commits # 1. Get recent git commits
try: try:
from tools.git_tools import git_log from tools.git_tools import git_log
result = git_log(max_count=20) result = git_log(max_count=20)
if result.get("success"): if result.get("success"):
commits = result.get("commits", []) commits = result.get("commits", [])
ctx["git_log"] = "\n".join([ ctx["git_log"] = "\n".join(
f"{c['short_sha']} {c['message'].split(chr(10))[0]}" [f"{c['short_sha']} {c['message'].split(chr(10))[0]}" for c in commits[:20]]
for c in commits[:20] )
])
except Exception as exc: except Exception as exc:
logger.warning("Could not load git log for context: %s", exc) logger.warning("Could not load git log for context: %s", exc)
ctx["git_log"] = "(Git log unavailable)" ctx["git_log"] = "(Git log unavailable)"
# 2. Get active sub-agents # 2. Get active sub-agents
try: try:
from swarm import registry as swarm_registry from swarm import registry as swarm_registry
conn = swarm_registry._get_conn() conn = swarm_registry._get_conn()
rows = conn.execute( rows = conn.execute(
"SELECT id, name, status, capabilities FROM agents ORDER BY name" "SELECT id, name, status, capabilities FROM agents ORDER BY name"
).fetchall() ).fetchall()
ctx["agents"] = [ ctx["agents"] = [
{"id": r["id"], "name": r["name"], "status": r["status"], "capabilities": r["capabilities"]} {
"id": r["id"],
"name": r["name"],
"status": r["status"],
"capabilities": r["capabilities"],
}
for r in rows for r in rows
] ]
conn.close() conn.close()
except Exception as exc: except Exception as exc:
logger.warning("Could not load agents for context: %s", exc) logger.warning("Could not load agents for context: %s", exc)
ctx["agents"] = [] ctx["agents"] = []
# 3. Read hot memory (via HotMemory to auto-create if missing) # 3. Read hot memory (via HotMemory to auto-create if missing)
try: try:
from timmy.memory_system import memory_system from timmy.memory_system import memory_system
ctx["memory"] = memory_system.hot.read()[:2000] ctx["memory"] = memory_system.hot.read()[:2000]
except Exception as exc: except Exception as exc:
logger.warning("Could not load memory for context: %s", exc) logger.warning("Could not load memory for context: %s", exc)
ctx["memory"] = "(Memory unavailable)" ctx["memory"] = "(Memory unavailable)"
_timmy_context.update(ctx) _timmy_context.update(ctx)
logger.info("Context built (sync): %d agents", len(ctx["agents"])) logger.info("Context built (sync): %d agents", len(ctx["agents"]))
return ctx return ctx
@@ -110,21 +117,31 @@ build_timmy_context = build_timmy_context_sync
def format_timmy_prompt(base_prompt: str, context: dict[str, Any]) -> str: def format_timmy_prompt(base_prompt: str, context: dict[str, Any]) -> str:
"""Format the system prompt with dynamic context.""" """Format the system prompt with dynamic context."""
# Format agents list # Format agents list
agents_list = "\n".join([ agents_list = (
f"| {a['name']} | {a['capabilities'] or 'general'} | {a['status']} |" "\n".join(
for a in context.get("agents", []) [
]) or "(No agents registered yet)" f"| {a['name']} | {a['capabilities'] or 'general'} | {a['status']} |"
for a in context.get("agents", [])
]
)
or "(No agents registered yet)"
)
# Format hands list # Format hands list
hands_list = "\n".join([ hands_list = (
f"| {h['name']} | {h['schedule']} | {'enabled' if h['enabled'] else 'disabled'} |" "\n".join(
for h in context.get("hands", []) [
]) or "(No hands configured)" f"| {h['name']} | {h['schedule']} | {'enabled' if h['enabled'] else 'disabled'} |"
for h in context.get("hands", [])
repo_root = context.get('repo_root', settings.repo_root) ]
)
or "(No hands configured)"
)
repo_root = context.get("repo_root", settings.repo_root)
context_block = f""" context_block = f"""
## Current System Context (as of {context.get('timestamp', datetime.now(timezone.utc).isoformat())}) ## Current System Context (as of {context.get('timestamp', datetime.now(timezone.utc).isoformat())})
@@ -149,10 +166,10 @@ def format_timmy_prompt(base_prompt: str, context: dict[str, Any]) -> str:
### Hot Memory: ### Hot Memory:
{context.get('memory', '(unavailable)')[:1000]} {context.get('memory', '(unavailable)')[:1000]}
""" """
# Replace {REPO_ROOT} placeholder with actual path # Replace {REPO_ROOT} placeholder with actual path
base_prompt = base_prompt.replace("{REPO_ROOT}", repo_root) base_prompt = base_prompt.replace("{REPO_ROOT}", repo_root)
# Insert context after the first line # Insert context after the first line
lines = base_prompt.split("\n") lines = base_prompt.split("\n")
if lines: if lines:
@@ -227,63 +244,71 @@ class TimmyOrchestrator(BaseAgent):
name="Orchestrator", name="Orchestrator",
role="orchestrator", role="orchestrator",
system_prompt=formatted_prompt, system_prompt=formatted_prompt,
tools=["web_search", "read_file", "write_file", "python", "memory_search", "memory_write", "system_status"], tools=[
"web_search",
"read_file",
"write_file",
"python",
"memory_search",
"memory_write",
"system_status",
],
) )
# Sub-agent registry # Sub-agent registry
self.sub_agents: dict[str, BaseAgent] = {} self.sub_agents: dict[str, BaseAgent] = {}
# Session tracking for init behavior # Session tracking for init behavior
self._session_initialized = False self._session_initialized = False
self._session_context: dict[str, Any] = {} self._session_context: dict[str, Any] = {}
self._context_fully_loaded = False self._context_fully_loaded = False
# Connect to event bus # Connect to event bus
self.connect_event_bus(event_bus) self.connect_event_bus(event_bus)
logger.info("Orchestrator initialized with context-aware prompt") logger.info("Orchestrator initialized with context-aware prompt")
def register_sub_agent(self, agent: BaseAgent) -> None: def register_sub_agent(self, agent: BaseAgent) -> None:
"""Register a sub-agent with the orchestrator.""" """Register a sub-agent with the orchestrator."""
self.sub_agents[agent.agent_id] = agent self.sub_agents[agent.agent_id] = agent
agent.connect_event_bus(event_bus) agent.connect_event_bus(event_bus)
logger.info("Registered sub-agent: %s", agent.name) logger.info("Registered sub-agent: %s", agent.name)
async def _session_init(self) -> None: async def _session_init(self) -> None:
"""Initialize session context on first user message. """Initialize session context on first user message.
Silently reads git log and AGENTS.md to ground the orchestrator in real data. Silently reads git log and AGENTS.md to ground the orchestrator in real data.
This runs once per session before the first response. This runs once per session before the first response.
""" """
if self._session_initialized: if self._session_initialized:
return return
logger.debug("Running session init...") logger.debug("Running session init...")
# Load full context including hands if not already done # Load full context including hands if not already done
if not self._context_fully_loaded: if not self._context_fully_loaded:
await build_timmy_context_async() await build_timmy_context_async()
self._context_fully_loaded = True self._context_fully_loaded = True
# Read recent git log --oneline -15 from repo root # Read recent git log --oneline -15 from repo root
try: try:
from tools.git_tools import git_log from tools.git_tools import git_log
git_result = git_log(max_count=15) git_result = git_log(max_count=15)
if git_result.get("success"): if git_result.get("success"):
commits = git_result.get("commits", []) commits = git_result.get("commits", [])
self._session_context["git_log_commits"] = commits self._session_context["git_log_commits"] = commits
# Format as oneline for easy reading # Format as oneline for easy reading
self._session_context["git_log_oneline"] = "\n".join([ self._session_context["git_log_oneline"] = "\n".join(
f"{c['short_sha']} {c['message'].split(chr(10))[0]}" [f"{c['short_sha']} {c['message'].split(chr(10))[0]}" for c in commits]
for c in commits )
])
logger.debug(f"Session init: loaded {len(commits)} commits from git log") logger.debug(f"Session init: loaded {len(commits)} commits from git log")
else: else:
self._session_context["git_log_oneline"] = "Git log unavailable" self._session_context["git_log_oneline"] = "Git log unavailable"
except Exception as exc: except Exception as exc:
logger.warning("Session init: could not read git log: %s", exc) logger.warning("Session init: could not read git log: %s", exc)
self._session_context["git_log_oneline"] = "Git log unavailable" self._session_context["git_log_oneline"] = "Git log unavailable"
# Read AGENTS.md for self-awareness # Read AGENTS.md for self-awareness
try: try:
agents_md_path = Path(settings.repo_root) / "AGENTS.md" agents_md_path = Path(settings.repo_root) / "AGENTS.md"
@@ -291,7 +316,7 @@ class TimmyOrchestrator(BaseAgent):
self._session_context["agents_md"] = agents_md_path.read_text()[:3000] self._session_context["agents_md"] = agents_md_path.read_text()[:3000]
except Exception as exc: except Exception as exc:
logger.warning("Session init: could not read AGENTS.md: %s", exc) logger.warning("Session init: could not read AGENTS.md: %s", exc)
# Read CHANGELOG for recent changes # Read CHANGELOG for recent changes
try: try:
changelog_path = Path(settings.repo_root) / "docs" / "CHANGELOG_2026-02-26.md" changelog_path = Path(settings.repo_root) / "docs" / "CHANGELOG_2026-02-26.md"
@@ -299,11 +324,13 @@ class TimmyOrchestrator(BaseAgent):
self._session_context["changelog"] = changelog_path.read_text()[:2000] self._session_context["changelog"] = changelog_path.read_text()[:2000]
except Exception: except Exception:
pass # Changelog is optional pass # Changelog is optional
# Build session-specific context block for the prompt # Build session-specific context block for the prompt
recent_changes = self._session_context.get("git_log_oneline", "") recent_changes = self._session_context.get("git_log_oneline", "")
if recent_changes and recent_changes != "Git log unavailable": if recent_changes and recent_changes != "Git log unavailable":
self._session_context["recent_changes_block"] = f""" self._session_context[
"recent_changes_block"
] = f"""
## Recent Changes to Your Codebase (last 15 commits): ## Recent Changes to Your Codebase (last 15 commits):
``` ```
{recent_changes} {recent_changes}
@@ -312,17 +339,17 @@ When asked "what's new?" or similar, refer to these commits for actual changes.
""" """
else: else:
self._session_context["recent_changes_block"] = "" self._session_context["recent_changes_block"] = ""
self._session_initialized = True self._session_initialized = True
logger.debug("Session init complete") logger.debug("Session init complete")
def _get_enhanced_system_prompt(self) -> str: def _get_enhanced_system_prompt(self) -> str:
"""Get system prompt enhanced with session-specific context. """Get system prompt enhanced with session-specific context.
Prepends the recent git log to the system prompt for grounding. Prepends the recent git log to the system prompt for grounding.
""" """
base = self.system_prompt base = self.system_prompt
# Add recent changes block if available # Add recent changes block if available
recent_changes = self._session_context.get("recent_changes_block", "") recent_changes = self._session_context.get("recent_changes_block", "")
if recent_changes: if recent_changes:
@@ -330,36 +357,45 @@ When asked "what's new?" or similar, refer to these commits for actual changes.
lines = base.split("\n") lines = base.split("\n")
if lines: if lines:
return lines[0] + "\n" + recent_changes + "\n" + "\n".join(lines[1:]) return lines[0] + "\n" + recent_changes + "\n" + "\n".join(lines[1:])
return base return base
async def orchestrate(self, user_request: str) -> str: async def orchestrate(self, user_request: str) -> str:
"""Main entry point for user requests. """Main entry point for user requests.
Analyzes the request and either handles directly or delegates. Analyzes the request and either handles directly or delegates.
""" """
# Run session init on first message (loads git log, etc.) # Run session init on first message (loads git log, etc.)
await self._session_init() await self._session_init()
# Quick classification # Quick classification
request_lower = user_request.lower() request_lower = user_request.lower()
# Direct response patterns (no delegation needed) # Direct response patterns (no delegation needed)
direct_patterns = [ direct_patterns = [
"your name", "who are you", "what are you", "your name",
"hello", "hi", "how are you", "who are you",
"help", "what can you do", "what are you",
"hello",
"hi",
"how are you",
"help",
"what can you do",
] ]
for pattern in direct_patterns: for pattern in direct_patterns:
if pattern in request_lower: if pattern in request_lower:
return await self.run(user_request) return await self.run(user_request)
# Check for memory references — delegate to Echo # Check for memory references — delegate to Echo
memory_patterns = [ memory_patterns = [
"we talked about", "we discussed", "remember", "we talked about",
"what did i say", "what did we decide", "we discussed",
"remind me", "have we", "remember",
"what did i say",
"what did we decide",
"remind me",
"have we",
] ]
for pattern in memory_patterns: for pattern in memory_patterns:
@@ -395,19 +431,16 @@ When asked "what's new?" or similar, refer to these commits for actual changes.
if agent in text_lower: if agent in text_lower:
return agent return agent
return "orchestrator" return "orchestrator"
async def execute_task(self, task_id: str, description: str, context: dict) -> Any: async def execute_task(self, task_id: str, description: str, context: dict) -> Any:
"""Execute a task (usually delegates to appropriate agent).""" """Execute a task (usually delegates to appropriate agent)."""
return await self.orchestrate(description) return await self.orchestrate(description)
def get_swarm_status(self) -> dict: def get_swarm_status(self) -> dict:
"""Get status of all agents in the swarm.""" """Get status of all agents in the swarm."""
return { return {
"orchestrator": self.get_status(), "orchestrator": self.get_status(),
"sub_agents": { "sub_agents": {aid: agent.get_status() for aid, agent in self.sub_agents.items()},
aid: agent.get_status()
for aid, agent in self.sub_agents.items()
},
"total_agents": 1 + len(self.sub_agents), "total_agents": 1 + len(self.sub_agents),
} }
@@ -468,10 +501,29 @@ _PERSONAS: list[dict[str, Any]] = [
"system_prompt": ( "system_prompt": (
"You are Helm, a routing and orchestration specialist.\n" "You are Helm, a routing and orchestration specialist.\n"
"Analyze tasks and decide how to route them to other agents.\n" "Analyze tasks and decide how to route them to other agents.\n"
"Available agents: Seer (research), Forge (code), Quill (writing), Echo (memory).\n" "Available agents: Seer (research), Forge (code), Quill (writing), Echo (memory), Lab (experiments).\n"
"Respond with: Primary Agent: [agent name]" "Respond with: Primary Agent: [agent name]"
), ),
}, },
{
"agent_id": "lab",
"name": "Lab",
"role": "experiment",
"tools": [
"run_experiment",
"prepare_experiment",
"shell",
"python",
"read_file",
"write_file",
],
"system_prompt": (
"You are Lab, an autonomous ML experimentation specialist.\n"
"You run time-boxed training experiments, evaluate metrics,\n"
"modify training code to improve results, and iterate.\n"
"Always report the metric delta. Never exceed the time budget."
),
},
] ]

View File

@@ -38,10 +38,10 @@ class ApprovalItem:
id: str id: str
title: str title: str
description: str description: str
proposed_action: str # what Timmy wants to do proposed_action: str # what Timmy wants to do
impact: str # "low" | "medium" | "high" impact: str # "low" | "medium" | "high"
created_at: datetime created_at: datetime
status: str # "pending" | "approved" | "rejected" status: str # "pending" | "approved" | "rejected"
def _get_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection: def _get_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection:
@@ -81,6 +81,7 @@ def _row_to_item(row: sqlite3.Row) -> ApprovalItem:
# Public API # Public API
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def create_item( def create_item(
title: str, title: str,
description: str, description: str,
@@ -133,18 +134,14 @@ def list_pending(db_path: Path = _DEFAULT_DB) -> list[ApprovalItem]:
def list_all(db_path: Path = _DEFAULT_DB) -> list[ApprovalItem]: def list_all(db_path: Path = _DEFAULT_DB) -> list[ApprovalItem]:
"""Return all approval items regardless of status, newest first.""" """Return all approval items regardless of status, newest first."""
conn = _get_conn(db_path) conn = _get_conn(db_path)
rows = conn.execute( rows = conn.execute("SELECT * FROM approval_items ORDER BY created_at DESC").fetchall()
"SELECT * FROM approval_items ORDER BY created_at DESC"
).fetchall()
conn.close() conn.close()
return [_row_to_item(r) for r in rows] return [_row_to_item(r) for r in rows]
def get_item(item_id: str, db_path: Path = _DEFAULT_DB) -> Optional[ApprovalItem]: def get_item(item_id: str, db_path: Path = _DEFAULT_DB) -> Optional[ApprovalItem]:
conn = _get_conn(db_path) conn = _get_conn(db_path)
row = conn.execute( row = conn.execute("SELECT * FROM approval_items WHERE id = ?", (item_id,)).fetchone()
"SELECT * FROM approval_items WHERE id = ?", (item_id,)
).fetchone()
conn.close() conn.close()
return _row_to_item(row) if row else None return _row_to_item(row) if row else None
@@ -152,9 +149,7 @@ def get_item(item_id: str, db_path: Path = _DEFAULT_DB) -> Optional[ApprovalItem
def approve(item_id: str, db_path: Path = _DEFAULT_DB) -> Optional[ApprovalItem]: def approve(item_id: str, db_path: Path = _DEFAULT_DB) -> Optional[ApprovalItem]:
"""Mark an approval item as approved.""" """Mark an approval item as approved."""
conn = _get_conn(db_path) conn = _get_conn(db_path)
conn.execute( conn.execute("UPDATE approval_items SET status = 'approved' WHERE id = ?", (item_id,))
"UPDATE approval_items SET status = 'approved' WHERE id = ?", (item_id,)
)
conn.commit() conn.commit()
conn.close() conn.close()
return get_item(item_id, db_path) return get_item(item_id, db_path)
@@ -163,9 +158,7 @@ def approve(item_id: str, db_path: Path = _DEFAULT_DB) -> Optional[ApprovalItem]
def reject(item_id: str, db_path: Path = _DEFAULT_DB) -> Optional[ApprovalItem]: def reject(item_id: str, db_path: Path = _DEFAULT_DB) -> Optional[ApprovalItem]:
"""Mark an approval item as rejected.""" """Mark an approval item as rejected."""
conn = _get_conn(db_path) conn = _get_conn(db_path)
conn.execute( conn.execute("UPDATE approval_items SET status = 'rejected' WHERE id = ?", (item_id,))
"UPDATE approval_items SET status = 'rejected' WHERE id = ?", (item_id,)
)
conn.commit() conn.commit()
conn.close() conn.close()
return get_item(item_id, db_path) return get_item(item_id, db_path)

214
src/timmy/autoresearch.py Normal file
View File

@@ -0,0 +1,214 @@
"""Autoresearch — autonomous ML experiment loops.
Integrates Karpathy's autoresearch pattern: an agent modifies training
code, runs time-boxed GPU experiments, evaluates a target metric
(val_bpb by default), and iterates to find improvements.
Flow:
1. prepare_experiment — clone repo + run data prep
2. run_experiment — execute train.py with wall-clock timeout
3. evaluate_result — compare metric against baseline
4. experiment_loop — orchestrate the full cycle
All subprocess calls are guarded with timeouts for graceful degradation.
"""
from __future__ import annotations
import json
import logging
import re
import subprocess
import time
from pathlib import Path
from typing import Any, Callable, Optional
logger = logging.getLogger(__name__)
DEFAULT_REPO = "https://github.com/karpathy/autoresearch.git"
_METRIC_RE = re.compile(r"val_bpb[:\s]+([0-9]+\.?[0-9]*)")
def prepare_experiment(
    workspace: Path,
    repo_url: str = DEFAULT_REPO,
) -> str:
    """Clone autoresearch repo and run data preparation.

    Both subprocess calls are wrapped so that a hung clone or a slow
    prepare.py degrades to an error string instead of raising — this matches
    the module's stated contract that all subprocess calls degrade gracefully.

    Args:
        workspace: Directory to set up the experiment in.
        repo_url: Git URL for the autoresearch repository.

    Returns:
        Status message describing what was prepared (or why it failed).
    """
    workspace = Path(workspace)
    workspace.mkdir(parents=True, exist_ok=True)
    repo_dir = workspace / "autoresearch"
    if not repo_dir.exists():
        logger.info("Cloning autoresearch into %s", repo_dir)
        try:
            result = subprocess.run(
                ["git", "clone", "--depth", "1", repo_url, str(repo_dir)],
                capture_output=True,
                text=True,
                timeout=120,
            )
        except subprocess.TimeoutExpired:
            # Network stall — report instead of propagating.
            return "Clone failed: timed out after 120s"
        except OSError as exc:
            # e.g. git binary not installed.
            return f"Clone failed: {exc}"
        if result.returncode != 0:
            return f"Clone failed: {result.stderr.strip()}"
    else:
        logger.info("Autoresearch repo already present at %s", repo_dir)
    # Run prepare.py (data download + tokeniser training)
    prepare_script = repo_dir / "prepare.py"
    if prepare_script.exists():
        logger.info("Running prepare.py …")
        try:
            result = subprocess.run(
                ["python", str(prepare_script)],
                capture_output=True,
                text=True,
                cwd=str(repo_dir),
                timeout=300,
            )
        except subprocess.TimeoutExpired:
            return "Preparation failed: timed out after 300s"
        except OSError as exc:
            return f"Preparation failed: {exc}"
        if result.returncode != 0:
            return f"Preparation failed: {result.stderr.strip()[:500]}"
        return "Preparation complete — data downloaded and tokeniser trained."
    return "Preparation skipped — no prepare.py found."
def run_experiment(
    workspace: Path,
    timeout: int = 300,
    metric_name: str = "val_bpb",
) -> dict[str, Any]:
    """Run a single training experiment with a wall-clock timeout.

    Args:
        workspace: Experiment workspace (contains autoresearch/ subdir).
        timeout: Maximum wall-clock seconds for the run.
        metric_name: Name of the metric to extract from stdout.

    Returns:
        Dict with keys: metric (float|None), log (str), duration_s (int),
        success (bool), error (str|None).
    """
    repo_dir = Path(workspace) / "autoresearch"
    train_script = repo_dir / "train.py"
    if not train_script.exists():
        return {
            "metric": None,
            "log": "",
            "duration_s": 0,
            "success": False,
            "error": f"train.py not found in {repo_dir}",
        }
    start = time.monotonic()
    try:
        result = subprocess.run(
            ["python", str(train_script)],
            capture_output=True,
            text=True,
            cwd=str(repo_dir),
            timeout=timeout,
        )
        duration = int(time.monotonic() - start)
        output = result.stdout + result.stderr
        return {
            "metric": _extract_metric(output, metric_name),
            "log": output[-2000:],  # Keep last 2k chars
            "duration_s": duration,
            "success": result.returncode == 0,
            "error": None if result.returncode == 0 else f"Exit code {result.returncode}",
        }
    except subprocess.TimeoutExpired as exc:
        duration = int(time.monotonic() - start)
        # A time-boxed run typically prints metrics before the wall-clock
        # kill; salvage the partial output TimeoutExpired captured instead of
        # discarding a metric from an already-completed eval step.
        # NOTE(review): some Python versions attach bytes to TimeoutExpired
        # even in text mode, so decode defensively.
        parts: list[str] = []
        for stream in (exc.stdout, exc.stderr):
            if isinstance(stream, bytes):
                parts.append(stream.decode("utf-8", errors="replace"))
            elif stream:
                parts.append(stream)
        partial = "".join(parts)
        timeout_note = f"Experiment timed out after {timeout}s"
        return {
            "metric": _extract_metric(partial, metric_name),
            "log": (partial + "\n" + timeout_note)[-2000:] if partial else timeout_note,
            "duration_s": duration,
            "success": False,
            "error": f"Timed out after {timeout}s",
        }
    except OSError as exc:
        # e.g. no python interpreter on PATH — degrade gracefully.
        return {
            "metric": None,
            "log": "",
            "duration_s": 0,
            "success": False,
            "error": str(exc),
        }
def _extract_metric(output: str, metric_name: str = "val_bpb") -> Optional[float]:
"""Extract the last occurrence of a metric value from training output."""
pattern = re.compile(rf"{re.escape(metric_name)}[:\s]+([0-9]+\.?[0-9]*)")
matches = pattern.findall(output)
if matches:
try:
return float(matches[-1])
except ValueError:
pass
return None
def evaluate_result(
    current: float,
    baseline: float,
    metric_name: str = "val_bpb",
) -> str:
    """Compare a metric against baseline and return an assessment.

    For val_bpb, lower is better.

    Args:
        current: Current experiment's metric value.
        baseline: Baseline metric to compare against.
        metric_name: Name of the metric (for display).

    Returns:
        Human-readable assessment string.
    """
    delta = current - baseline
    pct = 0.0 if baseline == 0 else (delta / baseline) * 100
    if delta == 0:
        return f"No change: {metric_name} = {current:.4f}"
    # Lower is better for bits-per-byte style metrics.
    verdict = "Improvement" if delta < 0 else "Regression"
    return f"{verdict}: {metric_name} {baseline:.4f} -> {current:.4f} ({pct:+.2f}%)"
def get_experiment_history(workspace: Path) -> list[dict[str, Any]]:
    """Read experiment history from the workspace results file.

    Returns:
        List of experiment result dicts, most recent first.
    """
    results_file = Path(workspace) / "results.jsonl"
    if not results_file.exists():
        return []
    records: list[dict[str, Any]] = []
    for raw_line in results_file.read_text().strip().splitlines():
        try:
            parsed = json.loads(raw_line)
        except json.JSONDecodeError:
            # Skip corrupt lines rather than fail the whole history read.
            continue
        records.append(parsed)
    records.reverse()  # file is append-only, so newest entries are last
    return records
def _append_result(workspace: Path, result: dict[str, Any]) -> None:
"""Append a result to the workspace JSONL log."""
results_file = Path(workspace) / "results.jsonl"
results_file.parent.mkdir(parents=True, exist_ok=True)
with results_file.open("a") as f:
f.write(json.dumps(result) + "\n")

View File

@@ -24,8 +24,8 @@ logger = logging.getLogger(__name__)
# HuggingFace model IDs for each supported size. # HuggingFace model IDs for each supported size.
_AIRLLM_MODELS: dict[str, str] = { _AIRLLM_MODELS: dict[str, str] = {
"8b": "meta-llama/Meta-Llama-3.1-8B-Instruct", "8b": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"70b": "meta-llama/Meta-Llama-3.1-70B-Instruct", "70b": "meta-llama/Meta-Llama-3.1-70B-Instruct",
"405b": "meta-llama/Meta-Llama-3.1-405B-Instruct", "405b": "meta-llama/Meta-Llama-3.1-405B-Instruct",
} }
@@ -35,6 +35,7 @@ ModelSize = Literal["8b", "70b", "405b"]
@dataclass @dataclass
class RunResult: class RunResult:
"""Minimal Agno-compatible run result — carries the model's response text.""" """Minimal Agno-compatible run result — carries the model's response text."""
content: str content: str
@@ -47,6 +48,7 @@ def airllm_available() -> bool:
"""Return True when the airllm package is importable.""" """Return True when the airllm package is importable."""
try: try:
import airllm # noqa: F401 import airllm # noqa: F401
return True return True
except ImportError: except ImportError:
return False return False
@@ -67,15 +69,16 @@ class TimmyAirLLMAgent:
model_id = _AIRLLM_MODELS.get(model_size) model_id = _AIRLLM_MODELS.get(model_size)
if model_id is None: if model_id is None:
raise ValueError( raise ValueError(
f"Unknown model size {model_size!r}. " f"Unknown model size {model_size!r}. " f"Choose from: {list(_AIRLLM_MODELS)}"
f"Choose from: {list(_AIRLLM_MODELS)}"
) )
if is_apple_silicon(): if is_apple_silicon():
from airllm import AirLLMMLX # type: ignore[import] from airllm import AirLLMMLX # type: ignore[import]
self._model = AirLLMMLX(model_id) self._model = AirLLMMLX(model_id)
else: else:
from airllm import AutoModel # type: ignore[import] from airllm import AutoModel # type: ignore[import]
self._model = AutoModel.from_pretrained(model_id) self._model = AutoModel.from_pretrained(model_id)
self._history: list[str] = [] self._history: list[str] = []
@@ -137,6 +140,7 @@ class TimmyAirLLMAgent:
try: try:
from rich.console import Console from rich.console import Console
from rich.markdown import Markdown from rich.markdown import Markdown
Console().print(Markdown(text)) Console().print(Markdown(text))
except ImportError: except ImportError:
print(text) print(text)
@@ -157,6 +161,7 @@ GROK_MODELS: dict[str, str] = {
@dataclass @dataclass
class GrokUsageStats: class GrokUsageStats:
"""Tracks Grok API usage for cost monitoring and Spark logging.""" """Tracks Grok API usage for cost monitoring and Spark logging."""
total_requests: int = 0 total_requests: int = 0
total_prompt_tokens: int = 0 total_prompt_tokens: int = 0
total_completion_tokens: int = 0 total_completion_tokens: int = 0
@@ -240,9 +245,7 @@ class GrokBackend:
RunResult with response content RunResult with response content
""" """
if not self._api_key: if not self._api_key:
return RunResult( return RunResult(content="Grok is not configured. Set XAI_API_KEY to enable.")
content="Grok is not configured. Set XAI_API_KEY to enable."
)
start = time.time() start = time.time()
messages = self._build_messages(message) messages = self._build_messages(message)
@@ -285,16 +288,12 @@ class GrokBackend:
except Exception as exc: except Exception as exc:
self.stats.errors += 1 self.stats.errors += 1
logger.error("Grok API error: %s", exc) logger.error("Grok API error: %s", exc)
return RunResult( return RunResult(content=f"Grok temporarily unavailable: {exc}")
content=f"Grok temporarily unavailable: {exc}"
)
async def arun(self, message: str) -> RunResult: async def arun(self, message: str) -> RunResult:
"""Async inference via Grok API — used by cascade router and tools.""" """Async inference via Grok API — used by cascade router and tools."""
if not self._api_key: if not self._api_key:
return RunResult( return RunResult(content="Grok is not configured. Set XAI_API_KEY to enable.")
content="Grok is not configured. Set XAI_API_KEY to enable."
)
start = time.time() start = time.time()
messages = self._build_messages(message) messages = self._build_messages(message)
@@ -336,9 +335,7 @@ class GrokBackend:
except Exception as exc: except Exception as exc:
self.stats.errors += 1 self.stats.errors += 1
logger.error("Grok async API error: %s", exc) logger.error("Grok async API error: %s", exc)
return RunResult( return RunResult(content=f"Grok temporarily unavailable: {exc}")
content=f"Grok temporarily unavailable: {exc}"
)
def print_response(self, message: str, *, stream: bool = True) -> None: def print_response(self, message: str, *, stream: bool = True) -> None:
"""Run inference and render the response to stdout (CLI interface).""" """Run inference and render the response to stdout (CLI interface)."""
@@ -346,6 +343,7 @@ class GrokBackend:
try: try:
from rich.console import Console from rich.console import Console
from rich.markdown import Markdown from rich.markdown import Markdown
Console().print(Markdown(result.content)) Console().print(Markdown(result.content))
except ImportError: except ImportError:
print(result.content) print(result.content)
@@ -415,6 +413,7 @@ def grok_available() -> bool:
"""Return True when Grok is enabled and API key is configured.""" """Return True when Grok is enabled and API key is configured."""
try: try:
from config import settings from config import settings
return settings.grok_enabled and bool(settings.xai_api_key) return settings.grok_enabled and bool(settings.xai_api_key)
except Exception: except Exception:
return False return False
@@ -472,9 +471,7 @@ class ClaudeBackend:
def run(self, message: str, *, stream: bool = False, **kwargs) -> RunResult: def run(self, message: str, *, stream: bool = False, **kwargs) -> RunResult:
"""Synchronous inference via Claude API.""" """Synchronous inference via Claude API."""
if not self._api_key: if not self._api_key:
return RunResult( return RunResult(content="Claude is not configured. Set ANTHROPIC_API_KEY to enable.")
content="Claude is not configured. Set ANTHROPIC_API_KEY to enable."
)
start = time.time() start = time.time()
messages = self._build_messages(message) messages = self._build_messages(message)
@@ -508,9 +505,7 @@ class ClaudeBackend:
except Exception as exc: except Exception as exc:
logger.error("Claude API error: %s", exc) logger.error("Claude API error: %s", exc)
return RunResult( return RunResult(content=f"Claude temporarily unavailable: {exc}")
content=f"Claude temporarily unavailable: {exc}"
)
def print_response(self, message: str, *, stream: bool = True) -> None: def print_response(self, message: str, *, stream: bool = True) -> None:
"""Run inference and render the response to stdout (CLI interface).""" """Run inference and render the response to stdout (CLI interface)."""
@@ -518,6 +513,7 @@ class ClaudeBackend:
try: try:
from rich.console import Console from rich.console import Console
from rich.markdown import Markdown from rich.markdown import Markdown
Console().print(Markdown(result.content)) Console().print(Markdown(result.content))
except ImportError: except ImportError:
print(result.content) print(result.content)
@@ -569,6 +565,7 @@ def claude_available() -> bool:
"""Return True when Anthropic API key is configured.""" """Return True when Anthropic API key is configured."""
try: try:
from config import settings from config import settings
return bool(settings.anthropic_api_key) return bool(settings.anthropic_api_key)
except Exception: except Exception:
return False return False

View File

@@ -25,6 +25,7 @@ _CACHE_MINUTES = 30
# Data structures # Data structures
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@dataclass @dataclass
class ApprovalItem: class ApprovalItem:
"""Lightweight representation used inside a Briefing. """Lightweight representation used inside a Briefing.
@@ -32,6 +33,7 @@ class ApprovalItem:
The canonical mutable version (with persistence) lives in timmy.approvals. The canonical mutable version (with persistence) lives in timmy.approvals.
This one travels with the Briefing dataclass as a read-only snapshot. This one travels with the Briefing dataclass as a read-only snapshot.
""" """
id: str id: str
title: str title: str
description: str description: str
@@ -44,20 +46,19 @@ class ApprovalItem:
@dataclass @dataclass
class Briefing: class Briefing:
generated_at: datetime generated_at: datetime
summary: str # 150-300 words summary: str # 150-300 words
approval_items: list[ApprovalItem] = field(default_factory=list) approval_items: list[ApprovalItem] = field(default_factory=list)
period_start: datetime = field( period_start: datetime = field(
default_factory=lambda: datetime.now(timezone.utc) - timedelta(hours=6) default_factory=lambda: datetime.now(timezone.utc) - timedelta(hours=6)
) )
period_end: datetime = field( period_end: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
default_factory=lambda: datetime.now(timezone.utc)
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# SQLite cache # SQLite cache
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _get_cache_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection: def _get_cache_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection:
db_path.parent.mkdir(parents=True, exist_ok=True) db_path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(db_path)) conn = sqlite3.connect(str(db_path))
@@ -98,9 +99,7 @@ def _save_briefing(briefing: Briefing, db_path: Path = _DEFAULT_DB) -> None:
def _load_latest(db_path: Path = _DEFAULT_DB) -> Optional[Briefing]: def _load_latest(db_path: Path = _DEFAULT_DB) -> Optional[Briefing]:
"""Load the most-recently cached briefing, or None if there is none.""" """Load the most-recently cached briefing, or None if there is none."""
conn = _get_cache_conn(db_path) conn = _get_cache_conn(db_path)
row = conn.execute( row = conn.execute("SELECT * FROM briefings ORDER BY generated_at DESC LIMIT 1").fetchone()
"SELECT * FROM briefings ORDER BY generated_at DESC LIMIT 1"
).fetchone()
conn.close() conn.close()
if row is None: if row is None:
return None return None
@@ -115,7 +114,11 @@ def _load_latest(db_path: Path = _DEFAULT_DB) -> Optional[Briefing]:
def is_fresh(briefing: Briefing, max_age_minutes: int = _CACHE_MINUTES) -> bool: def is_fresh(briefing: Briefing, max_age_minutes: int = _CACHE_MINUTES) -> bool:
"""Return True if the briefing was generated within max_age_minutes.""" """Return True if the briefing was generated within max_age_minutes."""
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
age = now - briefing.generated_at.replace(tzinfo=timezone.utc) if briefing.generated_at.tzinfo is None else now - briefing.generated_at age = (
now - briefing.generated_at.replace(tzinfo=timezone.utc)
if briefing.generated_at.tzinfo is None
else now - briefing.generated_at
)
return age.total_seconds() < max_age_minutes * 60 return age.total_seconds() < max_age_minutes * 60
@@ -123,6 +126,7 @@ def is_fresh(briefing: Briefing, max_age_minutes: int = _CACHE_MINUTES) -> bool:
# Activity gathering helpers # Activity gathering helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _gather_swarm_summary(since: datetime) -> str: def _gather_swarm_summary(since: datetime) -> str:
"""Pull recent task/agent stats from swarm.db. Graceful if DB missing.""" """Pull recent task/agent stats from swarm.db. Graceful if DB missing."""
swarm_db = Path("data/swarm.db") swarm_db = Path("data/swarm.db")
@@ -170,6 +174,7 @@ def _gather_task_queue_summary() -> str:
"""Pull task queue stats for the briefing. Graceful if unavailable.""" """Pull task queue stats for the briefing. Graceful if unavailable."""
try: try:
from swarm.task_queue.models import get_task_summary_for_briefing from swarm.task_queue.models import get_task_summary_for_briefing
stats = get_task_summary_for_briefing() stats = get_task_summary_for_briefing()
parts = [] parts = []
if stats["pending_approval"]: if stats["pending_approval"]:
@@ -194,6 +199,7 @@ def _gather_chat_summary(since: datetime) -> str:
"""Pull recent chat messages from the in-memory log.""" """Pull recent chat messages from the in-memory log."""
try: try:
from dashboard.store import message_log from dashboard.store import message_log
messages = message_log.all() messages = message_log.all()
# Filter to messages in the briefing window (best-effort: no timestamps) # Filter to messages in the briefing window (best-effort: no timestamps)
recent = messages[-10:] if len(messages) > 10 else messages recent = messages[-10:] if len(messages) > 10 else messages
@@ -213,6 +219,7 @@ def _gather_chat_summary(since: datetime) -> str:
# BriefingEngine # BriefingEngine
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class BriefingEngine: class BriefingEngine:
"""Generates morning briefings by querying activity and asking Timmy.""" """Generates morning briefings by querying activity and asking Timmy."""
@@ -297,6 +304,7 @@ class BriefingEngine:
"""Call Timmy's Agno agent and return the response text.""" """Call Timmy's Agno agent and return the response text."""
try: try:
from timmy.agent import create_timmy from timmy.agent import create_timmy
agent = create_timmy() agent = create_timmy()
run = agent.run(prompt, stream=False) run = agent.run(prompt, stream=False)
result = run.content if hasattr(run, "content") else str(run) result = run.content if hasattr(run, "content") else str(run)
@@ -317,6 +325,7 @@ class BriefingEngine:
"""Return pending ApprovalItems from the approvals DB.""" """Return pending ApprovalItems from the approvals DB."""
try: try:
from timmy import approvals as _approvals from timmy import approvals as _approvals
raw_items = _approvals.list_pending() raw_items = _approvals.list_pending()
return [ return [
ApprovalItem( ApprovalItem(

View File

@@ -19,6 +19,7 @@ logger = logging.getLogger(__name__)
@dataclass @dataclass
class TimmyResponse: class TimmyResponse:
"""Response from Timmy via Cascade Router.""" """Response from Timmy via Cascade Router."""
content: str content: str
provider_used: str provider_used: str
latency_ms: float latency_ms: float
@@ -27,31 +28,30 @@ class TimmyResponse:
class TimmyCascadeAdapter: class TimmyCascadeAdapter:
"""Adapter that routes Timmy requests through Cascade Router. """Adapter that routes Timmy requests through Cascade Router.
Usage: Usage:
adapter = TimmyCascadeAdapter() adapter = TimmyCascadeAdapter()
response = await adapter.chat("Hello") response = await adapter.chat("Hello")
print(f"Response: {response.content}") print(f"Response: {response.content}")
print(f"Provider: {response.provider_used}") print(f"Provider: {response.provider_used}")
""" """
def __init__(self, router: Optional[CascadeRouter] = None) -> None: def __init__(self, router: Optional[CascadeRouter] = None) -> None:
"""Initialize adapter with Cascade Router. """Initialize adapter with Cascade Router.
Args: Args:
router: CascadeRouter instance. If None, creates default. router: CascadeRouter instance. If None, creates default.
""" """
self.router = router or CascadeRouter() self.router = router or CascadeRouter()
logger.info("TimmyCascadeAdapter initialized with %d providers", logger.info("TimmyCascadeAdapter initialized with %d providers", len(self.router.providers))
len(self.router.providers))
async def chat(self, message: str, context: Optional[str] = None) -> TimmyResponse: async def chat(self, message: str, context: Optional[str] = None) -> TimmyResponse:
"""Send message through cascade router with automatic failover. """Send message through cascade router with automatic failover.
Args: Args:
message: User message message: User message
context: Optional conversation context context: Optional conversation context
Returns: Returns:
TimmyResponse with content and metadata TimmyResponse with content and metadata
""" """
@@ -60,37 +60,38 @@ class TimmyCascadeAdapter:
if context: if context:
messages.append({"role": "system", "content": context}) messages.append({"role": "system", "content": context})
messages.append({"role": "user", "content": message}) messages.append({"role": "user", "content": message})
# Route through cascade # Route through cascade
import time import time
start = time.time() start = time.time()
try: try:
result = await self.router.complete( result = await self.router.complete(
messages=messages, messages=messages,
system_prompt=SYSTEM_PROMPT, system_prompt=SYSTEM_PROMPT,
) )
latency = (time.time() - start) * 1000 latency = (time.time() - start) * 1000
# Determine if fallback was used # Determine if fallback was used
primary = self.router.providers[0] if self.router.providers else None primary = self.router.providers[0] if self.router.providers else None
fallback_used = primary and primary.status.value != "healthy" fallback_used = primary and primary.status.value != "healthy"
return TimmyResponse( return TimmyResponse(
content=result.content, content=result.content,
provider_used=result.provider_name, provider_used=result.provider_name,
latency_ms=latency, latency_ms=latency,
fallback_used=fallback_used, fallback_used=fallback_used,
) )
except Exception as exc: except Exception as exc:
logger.error("All providers failed: %s", exc) logger.error("All providers failed: %s", exc)
raise raise
def get_provider_status(self) -> list[dict]: def get_provider_status(self) -> list[dict]:
"""Get status of all providers. """Get status of all providers.
Returns: Returns:
List of provider status dicts List of provider status dicts
""" """
@@ -112,10 +113,10 @@ class TimmyCascadeAdapter:
} }
for p in self.router.providers for p in self.router.providers
] ]
def get_preferred_provider(self) -> Optional[str]: def get_preferred_provider(self) -> Optional[str]:
"""Get name of highest-priority healthy provider. """Get name of highest-priority healthy provider.
Returns: Returns:
Provider name or None if all unhealthy Provider name or None if all unhealthy
""" """

View File

@@ -17,22 +17,23 @@ logger = logging.getLogger(__name__)
@dataclass @dataclass
class ConversationContext: class ConversationContext:
"""Tracks the current conversation state.""" """Tracks the current conversation state."""
user_name: Optional[str] = None user_name: Optional[str] = None
current_topic: Optional[str] = None current_topic: Optional[str] = None
last_intent: Optional[str] = None last_intent: Optional[str] = None
turn_count: int = 0 turn_count: int = 0
started_at: datetime = field(default_factory=datetime.now) started_at: datetime = field(default_factory=datetime.now)
def update_topic(self, topic: str) -> None: def update_topic(self, topic: str) -> None:
"""Update the current conversation topic.""" """Update the current conversation topic."""
self.current_topic = topic self.current_topic = topic
self.turn_count += 1 self.turn_count += 1
def set_user_name(self, name: str) -> None: def set_user_name(self, name: str) -> None:
"""Remember the user's name.""" """Remember the user's name."""
self.user_name = name self.user_name = name
logger.info("User name set to: %s", name) logger.info("User name set to: %s", name)
def get_context_summary(self) -> str: def get_context_summary(self) -> str:
"""Generate a context summary for the prompt.""" """Generate a context summary for the prompt."""
parts = [] parts = []
@@ -47,35 +48,88 @@ class ConversationContext:
class ConversationManager: class ConversationManager:
"""Manages conversation context across sessions.""" """Manages conversation context across sessions."""
def __init__(self) -> None: def __init__(self) -> None:
self._contexts: dict[str, ConversationContext] = {} self._contexts: dict[str, ConversationContext] = {}
def get_context(self, session_id: str) -> ConversationContext: def get_context(self, session_id: str) -> ConversationContext:
"""Get or create context for a session.""" """Get or create context for a session."""
if session_id not in self._contexts: if session_id not in self._contexts:
self._contexts[session_id] = ConversationContext() self._contexts[session_id] = ConversationContext()
return self._contexts[session_id] return self._contexts[session_id]
def clear_context(self, session_id: str) -> None: def clear_context(self, session_id: str) -> None:
"""Clear context for a session.""" """Clear context for a session."""
if session_id in self._contexts: if session_id in self._contexts:
del self._contexts[session_id] del self._contexts[session_id]
# Words that look like names but are actually verbs/UI states # Words that look like names but are actually verbs/UI states
_NAME_BLOCKLIST = frozenset({ _NAME_BLOCKLIST = frozenset(
"sending", "loading", "pending", "processing", "typing", {
"working", "going", "trying", "looking", "getting", "doing", "sending",
"waiting", "running", "checking", "coming", "leaving", "loading",
"thinking", "reading", "writing", "watching", "listening", "pending",
"playing", "eating", "sleeping", "sitting", "standing", "processing",
"walking", "talking", "asking", "telling", "feeling", "typing",
"hoping", "wondering", "glad", "happy", "sorry", "sure", "working",
"fine", "good", "great", "okay", "here", "there", "back", "going",
"done", "ready", "busy", "free", "available", "interested", "trying",
"confused", "lost", "stuck", "curious", "excited", "tired", "looking",
"not", "also", "just", "still", "already", "currently", "getting",
}) "doing",
"waiting",
"running",
"checking",
"coming",
"leaving",
"thinking",
"reading",
"writing",
"watching",
"listening",
"playing",
"eating",
"sleeping",
"sitting",
"standing",
"walking",
"talking",
"asking",
"telling",
"feeling",
"hoping",
"wondering",
"glad",
"happy",
"sorry",
"sure",
"fine",
"good",
"great",
"okay",
"here",
"there",
"back",
"done",
"ready",
"busy",
"free",
"available",
"interested",
"confused",
"lost",
"stuck",
"curious",
"excited",
"tired",
"not",
"also",
"just",
"still",
"already",
"currently",
}
)
def extract_user_name(self, message: str) -> Optional[str]: def extract_user_name(self, message: str) -> Optional[str]:
"""Try to extract user's name from message.""" """Try to extract user's name from message."""
@@ -106,40 +160,66 @@ class ConversationManager:
return name.capitalize() return name.capitalize()
return None return None
def should_use_tools(self, message: str, context: ConversationContext) -> bool: def should_use_tools(self, message: str, context: ConversationContext) -> bool:
"""Determine if this message likely requires tools. """Determine if this message likely requires tools.
Returns True if tools are likely needed, False for simple chat. Returns True if tools are likely needed, False for simple chat.
""" """
message_lower = message.lower().strip() message_lower = message.lower().strip()
# Tool keywords that suggest tool usage is needed # Tool keywords that suggest tool usage is needed
tool_keywords = [ tool_keywords = [
"search", "look up", "find", "google", "current price", "search",
"latest", "today's", "news", "weather", "stock price", "look up",
"read file", "write file", "save", "calculate", "compute", "find",
"run ", "execute", "shell", "command", "install", "google",
"current price",
"latest",
"today's",
"news",
"weather",
"stock price",
"read file",
"write file",
"save",
"calculate",
"compute",
"run ",
"execute",
"shell",
"command",
"install",
] ]
# Chat-only keywords that definitely don't need tools # Chat-only keywords that definitely don't need tools
chat_only = [ chat_only = [
"hello", "hi ", "hey", "how are you", "what's up", "hello",
"your name", "who are you", "what are you", "hi ",
"thanks", "thank you", "bye", "goodbye", "hey",
"tell me about yourself", "what can you do", "how are you",
"what's up",
"your name",
"who are you",
"what are you",
"thanks",
"thank you",
"bye",
"goodbye",
"tell me about yourself",
"what can you do",
] ]
# Check for chat-only patterns first # Check for chat-only patterns first
for pattern in chat_only: for pattern in chat_only:
if pattern in message_lower: if pattern in message_lower:
return False return False
# Check for tool keywords # Check for tool keywords
for keyword in tool_keywords: for keyword in tool_keywords:
if keyword in message_lower: if keyword in message_lower:
return True return True
# Simple questions (starting with what, who, how, why, when, where) # Simple questions (starting with what, who, how, why, when, where)
# usually don't need tools unless about current/real-time info # usually don't need tools unless about current/real-time info
simple_question_words = ["what is", "who is", "how does", "why is", "when did", "where is"] simple_question_words = ["what is", "who is", "how does", "why is", "when did", "where is"]
@@ -150,7 +230,7 @@ class ConversationManager:
if any(t in message_lower for t in time_words): if any(t in message_lower for t in time_words):
return True return True
return False return False
# Default: don't use tools for unclear cases # Default: don't use tools for unclear cases
return False return False

View File

@@ -25,11 +25,12 @@ def _get_model():
global _model, _has_embeddings global _model, _has_embeddings
if _has_embeddings is False: if _has_embeddings is False:
return None return None
if _model is not None: if _model is not None:
return _model return _model
from config import settings from config import settings
# In test mode or low-memory environments, skip embedding model load # In test mode or low-memory environments, skip embedding model load
if settings.timmy_skip_embeddings: if settings.timmy_skip_embeddings:
_has_embeddings = False _has_embeddings = False
@@ -37,7 +38,8 @@ def _get_model():
try: try:
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
_model = SentenceTransformer('all-MiniLM-L6-v2')
_model = SentenceTransformer("all-MiniLM-L6-v2")
_has_embeddings = True _has_embeddings = True
return _model return _model
except (ImportError, RuntimeError, Exception): except (ImportError, RuntimeError, Exception):
@@ -56,7 +58,7 @@ def _get_embedding_dimension() -> int:
def _compute_embedding(text: str) -> list[float]: def _compute_embedding(text: str) -> list[float]:
"""Compute embedding vector for text. """Compute embedding vector for text.
Uses sentence-transformers if available, otherwise returns Uses sentence-transformers if available, otherwise returns
a simple hash-based vector for basic similarity. a simple hash-based vector for basic similarity.
""" """
@@ -66,30 +68,31 @@ def _compute_embedding(text: str) -> list[float]:
return model.encode(text).tolist() return model.encode(text).tolist()
except Exception: except Exception:
pass pass
# Fallback: simple character n-gram hash embedding # Fallback: simple character n-gram hash embedding
# Not as good but allows the system to work without heavy deps # Not as good but allows the system to work without heavy deps
dim = 384 dim = 384
vec = [0.0] * dim vec = [0.0] * dim
text = text.lower() text = text.lower()
# Generate character trigram features # Generate character trigram features
for i in range(len(text) - 2): for i in range(len(text) - 2):
trigram = text[i:i+3] trigram = text[i : i + 3]
hash_val = hash(trigram) % dim hash_val = hash(trigram) % dim
vec[hash_val] += 1.0 vec[hash_val] += 1.0
# Normalize # Normalize
norm = sum(x*x for x in vec) ** 0.5 norm = sum(x * x for x in vec) ** 0.5
if norm > 0: if norm > 0:
vec = [x/norm for x in vec] vec = [x / norm for x in vec]
return vec return vec
@dataclass @dataclass
class MemoryEntry: class MemoryEntry:
"""A memory entry with vector embedding.""" """A memory entry with vector embedding."""
id: str = field(default_factory=lambda: str(uuid.uuid4())) id: str = field(default_factory=lambda: str(uuid.uuid4()))
content: str = "" # The actual text content content: str = "" # The actual text content
source: str = "" # Where it came from (agent, user, system) source: str = "" # Where it came from (agent, user, system)
@@ -99,9 +102,7 @@ class MemoryEntry:
session_id: Optional[str] = None session_id: Optional[str] = None
metadata: Optional[dict] = None metadata: Optional[dict] = None
embedding: Optional[list[float]] = None embedding: Optional[list[float]] = None
timestamp: str = field( timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
default_factory=lambda: datetime.now(timezone.utc).isoformat()
)
relevance_score: Optional[float] = None # Set during search relevance_score: Optional[float] = None # Set during search
@@ -110,7 +111,7 @@ def _get_conn() -> sqlite3.Connection:
DB_PATH.parent.mkdir(parents=True, exist_ok=True) DB_PATH.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(DB_PATH)) conn = sqlite3.connect(str(DB_PATH))
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
# Try to load sqlite-vss extension # Try to load sqlite-vss extension
try: try:
conn.enable_load_extension(True) conn.enable_load_extension(True)
@@ -119,7 +120,7 @@ def _get_conn() -> sqlite3.Connection:
_has_vss = True _has_vss = True
except Exception: except Exception:
_has_vss = False _has_vss = False
# Create tables # Create tables
conn.execute( conn.execute(
""" """
@@ -137,24 +138,14 @@ def _get_conn() -> sqlite3.Connection:
) )
""" """
) )
# Create indexes # Create indexes
conn.execute( conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_agent ON memory_entries(agent_id)")
"CREATE INDEX IF NOT EXISTS idx_memory_agent ON memory_entries(agent_id)" conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_task ON memory_entries(task_id)")
) conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_session ON memory_entries(session_id)")
conn.execute( conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_time ON memory_entries(timestamp)")
"CREATE INDEX IF NOT EXISTS idx_memory_task ON memory_entries(task_id)" conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_type ON memory_entries(context_type)")
)
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_memory_session ON memory_entries(session_id)"
)
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_memory_time ON memory_entries(timestamp)"
)
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_memory_type ON memory_entries(context_type)"
)
conn.commit() conn.commit()
return conn return conn
@@ -170,7 +161,7 @@ def store_memory(
compute_embedding: bool = True, compute_embedding: bool = True,
) -> MemoryEntry: ) -> MemoryEntry:
"""Store a memory entry with optional embedding. """Store a memory entry with optional embedding.
Args: Args:
content: The text content to store content: The text content to store
source: Source of the memory (agent name, user, system) source: Source of the memory (agent name, user, system)
@@ -180,14 +171,14 @@ def store_memory(
session_id: Session identifier session_id: Session identifier
metadata: Additional structured data metadata: Additional structured data
compute_embedding: Whether to compute vector embedding compute_embedding: Whether to compute vector embedding
Returns: Returns:
The stored MemoryEntry The stored MemoryEntry
""" """
embedding = None embedding = None
if compute_embedding: if compute_embedding:
embedding = _compute_embedding(content) embedding = _compute_embedding(content)
entry = MemoryEntry( entry = MemoryEntry(
content=content, content=content,
source=source, source=source,
@@ -198,7 +189,7 @@ def store_memory(
metadata=metadata, metadata=metadata,
embedding=embedding, embedding=embedding,
) )
conn = _get_conn() conn = _get_conn()
conn.execute( conn.execute(
""" """
@@ -222,7 +213,7 @@ def store_memory(
) )
conn.commit() conn.commit()
conn.close() conn.close()
return entry return entry
@@ -235,7 +226,7 @@ def search_memories(
min_relevance: float = 0.0, min_relevance: float = 0.0,
) -> list[MemoryEntry]: ) -> list[MemoryEntry]:
"""Search for memories by semantic similarity. """Search for memories by semantic similarity.
Args: Args:
query: Search query text query: Search query text
limit: Maximum results limit: Maximum results
@@ -243,18 +234,18 @@ def search_memories(
agent_id: Filter by agent agent_id: Filter by agent
session_id: Filter by session session_id: Filter by session
min_relevance: Minimum similarity score (0-1) min_relevance: Minimum similarity score (0-1)
Returns: Returns:
List of MemoryEntry objects sorted by relevance List of MemoryEntry objects sorted by relevance
""" """
query_embedding = _compute_embedding(query) query_embedding = _compute_embedding(query)
conn = _get_conn() conn = _get_conn()
# Build query with filters # Build query with filters
conditions = [] conditions = []
params = [] params = []
if context_type: if context_type:
conditions.append("context_type = ?") conditions.append("context_type = ?")
params.append(context_type) params.append(context_type)
@@ -264,9 +255,9 @@ def search_memories(
if session_id: if session_id:
conditions.append("session_id = ?") conditions.append("session_id = ?")
params.append(session_id) params.append(session_id)
where_clause = "WHERE " + " AND ".join(conditions) if conditions else "" where_clause = "WHERE " + " AND ".join(conditions) if conditions else ""
# Fetch candidates (we'll do in-memory similarity for now) # Fetch candidates (we'll do in-memory similarity for now)
# For production with sqlite-vss, this would use vector similarity index # For production with sqlite-vss, this would use vector similarity index
query_sql = f""" query_sql = f"""
@@ -276,10 +267,10 @@ def search_memories(
LIMIT ? LIMIT ?
""" """
params.append(limit * 3) # Get more candidates for ranking params.append(limit * 3) # Get more candidates for ranking
rows = conn.execute(query_sql, params).fetchall() rows = conn.execute(query_sql, params).fetchall()
conn.close() conn.close()
# Compute similarity scores # Compute similarity scores
results = [] results = []
for row in rows: for row in rows:
@@ -295,7 +286,7 @@ def search_memories(
embedding=json.loads(row["embedding"]) if row["embedding"] else None, embedding=json.loads(row["embedding"]) if row["embedding"] else None,
timestamp=row["timestamp"], timestamp=row["timestamp"],
) )
if entry.embedding: if entry.embedding:
# Cosine similarity # Cosine similarity
score = _cosine_similarity(query_embedding, entry.embedding) score = _cosine_similarity(query_embedding, entry.embedding)
@@ -308,7 +299,7 @@ def search_memories(
entry.relevance_score = score entry.relevance_score = score
if score >= min_relevance: if score >= min_relevance:
results.append(entry) results.append(entry)
# Sort by relevance and return top results # Sort by relevance and return top results
results.sort(key=lambda x: x.relevance_score or 0, reverse=True) results.sort(key=lambda x: x.relevance_score or 0, reverse=True)
return results[:limit] return results[:limit]
@@ -316,9 +307,9 @@ def search_memories(
def _cosine_similarity(a: list[float], b: list[float]) -> float: def _cosine_similarity(a: list[float], b: list[float]) -> float:
"""Compute cosine similarity between two vectors.""" """Compute cosine similarity between two vectors."""
dot = sum(x*y for x, y in zip(a, b)) dot = sum(x * y for x, y in zip(a, b))
norm_a = sum(x*x for x in a) ** 0.5 norm_a = sum(x * x for x in a) ** 0.5
norm_b = sum(x*x for x in b) ** 0.5 norm_b = sum(x * x for x in b) ** 0.5
if norm_a == 0 or norm_b == 0: if norm_a == 0 or norm_b == 0:
return 0.0 return 0.0
return dot / (norm_a * norm_b) return dot / (norm_a * norm_b)
@@ -334,51 +325,47 @@ def _keyword_overlap(query: str, content: str) -> float:
return overlap / len(query_words) return overlap / len(query_words)
def get_memory_context( def get_memory_context(query: str, max_tokens: int = 2000, **filters) -> str:
query: str,
max_tokens: int = 2000,
**filters
) -> str:
"""Get relevant memory context as formatted text for LLM prompts. """Get relevant memory context as formatted text for LLM prompts.
Args: Args:
query: Search query query: Search query
max_tokens: Approximate maximum tokens to return max_tokens: Approximate maximum tokens to return
**filters: Additional filters (agent_id, session_id, etc.) **filters: Additional filters (agent_id, session_id, etc.)
Returns: Returns:
Formatted context string for inclusion in prompts Formatted context string for inclusion in prompts
""" """
memories = search_memories(query, limit=20, **filters) memories = search_memories(query, limit=20, **filters)
context_parts = [] context_parts = []
total_chars = 0 total_chars = 0
max_chars = max_tokens * 4 # Rough approximation max_chars = max_tokens * 4 # Rough approximation
for mem in memories: for mem in memories:
formatted = f"[{mem.source}]: {mem.content}" formatted = f"[{mem.source}]: {mem.content}"
if total_chars + len(formatted) > max_chars: if total_chars + len(formatted) > max_chars:
break break
context_parts.append(formatted) context_parts.append(formatted)
total_chars += len(formatted) total_chars += len(formatted)
if not context_parts: if not context_parts:
return "" return ""
return "Relevant context from memory:\n" + "\n\n".join(context_parts) return "Relevant context from memory:\n" + "\n\n".join(context_parts)
def recall_personal_facts(agent_id: Optional[str] = None) -> list[str]: def recall_personal_facts(agent_id: Optional[str] = None) -> list[str]:
"""Recall personal facts about the user or system. """Recall personal facts about the user or system.
Args: Args:
agent_id: Optional agent filter agent_id: Optional agent filter
Returns: Returns:
List of fact strings List of fact strings
""" """
conn = _get_conn() conn = _get_conn()
if agent_id: if agent_id:
rows = conn.execute( rows = conn.execute(
""" """
@@ -398,7 +385,7 @@ def recall_personal_facts(agent_id: Optional[str] = None) -> list[str]:
LIMIT 100 LIMIT 100
""", """,
).fetchall() ).fetchall()
conn.close() conn.close()
return [r["content"] for r in rows] return [r["content"] for r in rows]
@@ -434,11 +421,11 @@ def update_personal_fact(memory_id: str, new_content: str) -> bool:
def store_personal_fact(fact: str, agent_id: Optional[str] = None) -> MemoryEntry: def store_personal_fact(fact: str, agent_id: Optional[str] = None) -> MemoryEntry:
"""Store a personal fact about the user or system. """Store a personal fact about the user or system.
Args: Args:
fact: The fact to store fact: The fact to store
agent_id: Associated agent agent_id: Associated agent
Returns: Returns:
The stored MemoryEntry The stored MemoryEntry
""" """
@@ -453,7 +440,7 @@ def store_personal_fact(fact: str, agent_id: Optional[str] = None) -> MemoryEntr
def delete_memory(memory_id: str) -> bool: def delete_memory(memory_id: str) -> bool:
"""Delete a memory entry by ID. """Delete a memory entry by ID.
Returns: Returns:
True if deleted, False if not found True if deleted, False if not found
""" """
@@ -470,29 +457,27 @@ def delete_memory(memory_id: str) -> bool:
def get_memory_stats() -> dict: def get_memory_stats() -> dict:
"""Get statistics about the memory store. """Get statistics about the memory store.
Returns: Returns:
Dict with counts by type, total entries, etc. Dict with counts by type, total entries, etc.
""" """
conn = _get_conn() conn = _get_conn()
total = conn.execute( total = conn.execute("SELECT COUNT(*) as count FROM memory_entries").fetchone()["count"]
"SELECT COUNT(*) as count FROM memory_entries"
).fetchone()["count"]
by_type = {} by_type = {}
rows = conn.execute( rows = conn.execute(
"SELECT context_type, COUNT(*) as count FROM memory_entries GROUP BY context_type" "SELECT context_type, COUNT(*) as count FROM memory_entries GROUP BY context_type"
).fetchall() ).fetchall()
for row in rows: for row in rows:
by_type[row["context_type"]] = row["count"] by_type[row["context_type"]] = row["count"]
with_embeddings = conn.execute( with_embeddings = conn.execute(
"SELECT COUNT(*) as count FROM memory_entries WHERE embedding IS NOT NULL" "SELECT COUNT(*) as count FROM memory_entries WHERE embedding IS NOT NULL"
).fetchone()["count"] ).fetchone()["count"]
conn.close() conn.close()
return { return {
"total_entries": total, "total_entries": total,
"by_type": by_type, "by_type": by_type,
@@ -503,20 +488,20 @@ def get_memory_stats() -> dict:
def prune_memories(older_than_days: int = 90, keep_facts: bool = True) -> int: def prune_memories(older_than_days: int = 90, keep_facts: bool = True) -> int:
"""Delete old memories to manage storage. """Delete old memories to manage storage.
Args: Args:
older_than_days: Delete memories older than this older_than_days: Delete memories older than this
keep_facts: Whether to preserve fact-type memories keep_facts: Whether to preserve fact-type memories
Returns: Returns:
Number of entries deleted Number of entries deleted
""" """
from datetime import timedelta from datetime import timedelta
cutoff = (datetime.now(timezone.utc) - timedelta(days=older_than_days)).isoformat() cutoff = (datetime.now(timezone.utc) - timedelta(days=older_than_days)).isoformat()
conn = _get_conn() conn = _get_conn()
if keep_facts: if keep_facts:
cursor = conn.execute( cursor = conn.execute(
""" """
@@ -530,9 +515,9 @@ def prune_memories(older_than_days: int = 90, keep_facts: bool = True) -> int:
"DELETE FROM memory_entries WHERE timestamp < ?", "DELETE FROM memory_entries WHERE timestamp < ?",
(cutoff,), (cutoff,),
) )
deleted = cursor.rowcount deleted = cursor.rowcount
conn.commit() conn.commit()
conn.close() conn.close()
return deleted return deleted

View File

@@ -28,50 +28,52 @@ HANDOFF_PATH = VAULT_PATH / "notes" / "last-session-handoff.md"
class HotMemory: class HotMemory:
"""Tier 1: Hot memory (MEMORY.md) — always loaded.""" """Tier 1: Hot memory (MEMORY.md) — always loaded."""
def __init__(self) -> None: def __init__(self) -> None:
self.path = HOT_MEMORY_PATH self.path = HOT_MEMORY_PATH
self._content: Optional[str] = None self._content: Optional[str] = None
self._last_modified: Optional[float] = None self._last_modified: Optional[float] = None
def read(self, force_refresh: bool = False) -> str: def read(self, force_refresh: bool = False) -> str:
"""Read hot memory, with caching.""" """Read hot memory, with caching."""
if not self.path.exists(): if not self.path.exists():
self._create_default() self._create_default()
# Check if file changed # Check if file changed
current_mtime = self.path.stat().st_mtime current_mtime = self.path.stat().st_mtime
if not force_refresh and self._content and self._last_modified == current_mtime: if not force_refresh and self._content and self._last_modified == current_mtime:
return self._content return self._content
self._content = self.path.read_text() self._content = self.path.read_text()
self._last_modified = current_mtime self._last_modified = current_mtime
logger.debug("HotMemory: Loaded %d chars from %s", len(self._content), self.path) logger.debug("HotMemory: Loaded %d chars from %s", len(self._content), self.path)
return self._content return self._content
def update_section(self, section: str, content: str) -> None: def update_section(self, section: str, content: str) -> None:
"""Update a specific section in MEMORY.md.""" """Update a specific section in MEMORY.md."""
full_content = self.read() full_content = self.read()
# Find section # Find section
pattern = rf"(## {re.escape(section)}.*?)(?=\n## |\Z)" pattern = rf"(## {re.escape(section)}.*?)(?=\n## |\Z)"
match = re.search(pattern, full_content, re.DOTALL) match = re.search(pattern, full_content, re.DOTALL)
if match: if match:
# Replace section # Replace section
new_section = f"## {section}\n\n{content}\n\n" new_section = f"## {section}\n\n{content}\n\n"
full_content = full_content[:match.start()] + new_section + full_content[match.end():] full_content = full_content[: match.start()] + new_section + full_content[match.end() :]
else: else:
# Append section before last updated line # Append section before last updated line
insert_point = full_content.rfind("*Prune date:") insert_point = full_content.rfind("*Prune date:")
new_section = f"## {section}\n\n{content}\n\n" new_section = f"## {section}\n\n{content}\n\n"
full_content = full_content[:insert_point] + new_section + "\n" + full_content[insert_point:] full_content = (
full_content[:insert_point] + new_section + "\n" + full_content[insert_point:]
)
self.path.write_text(full_content) self.path.write_text(full_content)
self._content = full_content self._content = full_content
self._last_modified = self.path.stat().st_mtime self._last_modified = self.path.stat().st_mtime
logger.info("HotMemory: Updated section '%s'", section) logger.info("HotMemory: Updated section '%s'", section)
def _create_default(self) -> None: def _create_default(self) -> None:
"""Create default MEMORY.md if missing.""" """Create default MEMORY.md if missing."""
default_content = """# Timmy Hot Memory default_content = """# Timmy Hot Memory
@@ -130,33 +132,33 @@ class HotMemory:
*Prune date: {prune_date}* *Prune date: {prune_date}*
""".format( """.format(
date=datetime.now(timezone.utc).strftime("%Y-%m-%d"), date=datetime.now(timezone.utc).strftime("%Y-%m-%d"),
prune_date=(datetime.now(timezone.utc).replace(day=25)).strftime("%Y-%m-%d") prune_date=(datetime.now(timezone.utc).replace(day=25)).strftime("%Y-%m-%d"),
) )
self.path.write_text(default_content) self.path.write_text(default_content)
logger.info("HotMemory: Created default MEMORY.md") logger.info("HotMemory: Created default MEMORY.md")
class VaultMemory: class VaultMemory:
"""Tier 2: Structured vault (memory/) — append-only markdown.""" """Tier 2: Structured vault (memory/) — append-only markdown."""
def __init__(self) -> None: def __init__(self) -> None:
self.path = VAULT_PATH self.path = VAULT_PATH
self._ensure_structure() self._ensure_structure()
def _ensure_structure(self) -> None: def _ensure_structure(self) -> None:
"""Ensure vault directory structure exists.""" """Ensure vault directory structure exists."""
(self.path / "self").mkdir(parents=True, exist_ok=True) (self.path / "self").mkdir(parents=True, exist_ok=True)
(self.path / "notes").mkdir(parents=True, exist_ok=True) (self.path / "notes").mkdir(parents=True, exist_ok=True)
(self.path / "aar").mkdir(parents=True, exist_ok=True) (self.path / "aar").mkdir(parents=True, exist_ok=True)
def write_note(self, name: str, content: str, namespace: str = "notes") -> Path: def write_note(self, name: str, content: str, namespace: str = "notes") -> Path:
"""Write a note to the vault.""" """Write a note to the vault."""
# Add timestamp to filename # Add timestamp to filename
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d") timestamp = datetime.now(timezone.utc).strftime("%Y%m%d")
filename = f"{timestamp}_{name}.md" filename = f"{timestamp}_{name}.md"
filepath = self.path / namespace / filename filepath = self.path / namespace / filename
# Add header # Add header
full_content = f"""# {name.replace('_', ' ').title()} full_content = f"""# {name.replace('_', ' ').title()}
@@ -171,39 +173,39 @@ class VaultMemory:
*Auto-generated by Timmy Memory System* *Auto-generated by Timmy Memory System*
""" """
filepath.write_text(full_content) filepath.write_text(full_content)
logger.info("VaultMemory: Wrote %s", filepath) logger.info("VaultMemory: Wrote %s", filepath)
return filepath return filepath
def read_file(self, filepath: Path) -> str: def read_file(self, filepath: Path) -> str:
"""Read a file from the vault.""" """Read a file from the vault."""
if not filepath.exists(): if not filepath.exists():
return "" return ""
return filepath.read_text() return filepath.read_text()
def list_files(self, namespace: str = "notes", pattern: str = "*.md") -> list[Path]: def list_files(self, namespace: str = "notes", pattern: str = "*.md") -> list[Path]:
"""List files in a namespace.""" """List files in a namespace."""
dir_path = self.path / namespace dir_path = self.path / namespace
if not dir_path.exists(): if not dir_path.exists():
return [] return []
return sorted(dir_path.glob(pattern)) return sorted(dir_path.glob(pattern))
def get_latest(self, namespace: str = "notes", pattern: str = "*.md") -> Optional[Path]: def get_latest(self, namespace: str = "notes", pattern: str = "*.md") -> Optional[Path]:
"""Get most recent file in namespace.""" """Get most recent file in namespace."""
files = self.list_files(namespace, pattern) files = self.list_files(namespace, pattern)
return files[-1] if files else None return files[-1] if files else None
def update_user_profile(self, key: str, value: str) -> None: def update_user_profile(self, key: str, value: str) -> None:
"""Update a field in user_profile.md.""" """Update a field in user_profile.md."""
profile_path = self.path / "self" / "user_profile.md" profile_path = self.path / "self" / "user_profile.md"
if not profile_path.exists(): if not profile_path.exists():
# Create default profile # Create default profile
self._create_default_profile() self._create_default_profile()
content = profile_path.read_text() content = profile_path.read_text()
# Simple pattern replacement # Simple pattern replacement
pattern = rf"(\*\*{re.escape(key)}:\*\*).*" pattern = rf"(\*\*{re.escape(key)}:\*\*).*"
if re.search(pattern, content): if re.search(pattern, content):
@@ -214,17 +216,17 @@ class VaultMemory:
if facts_section in content: if facts_section in content:
insert_point = content.find(facts_section) + len(facts_section) insert_point = content.find(facts_section) + len(facts_section)
content = content[:insert_point] + f"\n- {key}: {value}" + content[insert_point:] content = content[:insert_point] + f"\n- {key}: {value}" + content[insert_point:]
# Update last_updated # Update last_updated
content = re.sub( content = re.sub(
r"\*Last updated:.*\*", r"\*Last updated:.*\*",
f"*Last updated: {datetime.now(timezone.utc).strftime('%Y-%m-%d')}*", f"*Last updated: {datetime.now(timezone.utc).strftime('%Y-%m-%d')}*",
content content,
) )
profile_path.write_text(content) profile_path.write_text(content)
logger.info("VaultMemory: Updated user profile: %s = %s", key, value) logger.info("VaultMemory: Updated user profile: %s = %s", key, value)
def _create_default_profile(self) -> None: def _create_default_profile(self) -> None:
"""Create default user profile.""" """Create default user profile."""
profile_path = self.path / "self" / "user_profile.md" profile_path = self.path / "self" / "user_profile.md"
@@ -254,24 +256,26 @@ class VaultMemory:
--- ---
*Last updated: {date}* *Last updated: {date}*
""".format(date=datetime.now(timezone.utc).strftime("%Y-%m-%d")) """.format(
date=datetime.now(timezone.utc).strftime("%Y-%m-%d")
)
profile_path.write_text(default) profile_path.write_text(default)
class HandoffProtocol: class HandoffProtocol:
"""Session handoff protocol for continuity.""" """Session handoff protocol for continuity."""
def __init__(self) -> None: def __init__(self) -> None:
self.path = HANDOFF_PATH self.path = HANDOFF_PATH
self.vault = VaultMemory() self.vault = VaultMemory()
def write_handoff( def write_handoff(
self, self,
session_summary: str, session_summary: str,
key_decisions: list[str], key_decisions: list[str],
open_items: list[str], open_items: list[str],
next_steps: list[str] next_steps: list[str],
) -> None: ) -> None:
"""Write handoff at session end.""" """Write handoff at session end."""
content = f"""# Last Session Handoff content = f"""# Last Session Handoff
@@ -303,25 +307,24 @@ The user was last working on: {session_summary[:200]}...
*This handoff will be auto-loaded at next session start* *This handoff will be auto-loaded at next session start*
""" """
self.path.write_text(content) self.path.write_text(content)
# Also archive to notes # Also archive to notes
self.vault.write_note( self.vault.write_note("session_handoff", content, namespace="notes")
"session_handoff",
content, logger.info(
namespace="notes" "HandoffProtocol: Wrote handoff with %d decisions, %d open items",
len(key_decisions),
len(open_items),
) )
logger.info("HandoffProtocol: Wrote handoff with %d decisions, %d open items",
len(key_decisions), len(open_items))
def read_handoff(self) -> Optional[str]: def read_handoff(self) -> Optional[str]:
"""Read handoff if exists.""" """Read handoff if exists."""
if not self.path.exists(): if not self.path.exists():
return None return None
return self.path.read_text() return self.path.read_text()
def clear_handoff(self) -> None: def clear_handoff(self) -> None:
"""Clear handoff after loading.""" """Clear handoff after loading."""
if self.path.exists(): if self.path.exists():
@@ -331,7 +334,7 @@ The user was last working on: {session_summary[:200]}...
class MemorySystem: class MemorySystem:
"""Central memory system coordinating all tiers.""" """Central memory system coordinating all tiers."""
def __init__(self) -> None: def __init__(self) -> None:
self.hot = HotMemory() self.hot = HotMemory()
self.vault = VaultMemory() self.vault = VaultMemory()
@@ -339,52 +342,52 @@ class MemorySystem:
self.session_start_time: Optional[datetime] = None self.session_start_time: Optional[datetime] = None
self.session_decisions: list[str] = [] self.session_decisions: list[str] = []
self.session_open_items: list[str] = [] self.session_open_items: list[str] = []
def start_session(self) -> str: def start_session(self) -> str:
"""Start a new session, loading context from memory.""" """Start a new session, loading context from memory."""
self.session_start_time = datetime.now(timezone.utc) self.session_start_time = datetime.now(timezone.utc)
# Build context # Build context
context_parts = [] context_parts = []
# 1. Hot memory # 1. Hot memory
hot_content = self.hot.read() hot_content = self.hot.read()
context_parts.append("## Hot Memory\n" + hot_content) context_parts.append("## Hot Memory\n" + hot_content)
# 2. Last session handoff # 2. Last session handoff
handoff_content = self.handoff.read_handoff() handoff_content = self.handoff.read_handoff()
if handoff_content: if handoff_content:
context_parts.append("## Previous Session\n" + handoff_content) context_parts.append("## Previous Session\n" + handoff_content)
self.handoff.clear_handoff() self.handoff.clear_handoff()
# 3. User profile (key fields only) # 3. User profile (key fields only)
profile = self._load_user_profile_summary() profile = self._load_user_profile_summary()
if profile: if profile:
context_parts.append("## User Context\n" + profile) context_parts.append("## User Context\n" + profile)
full_context = "\n\n---\n\n".join(context_parts) full_context = "\n\n---\n\n".join(context_parts)
logger.info("MemorySystem: Session started with %d chars context", len(full_context)) logger.info("MemorySystem: Session started with %d chars context", len(full_context))
return full_context return full_context
def end_session(self, summary: str) -> None: def end_session(self, summary: str) -> None:
"""End session, write handoff.""" """End session, write handoff."""
self.handoff.write_handoff( self.handoff.write_handoff(
session_summary=summary, session_summary=summary,
key_decisions=self.session_decisions, key_decisions=self.session_decisions,
open_items=self.session_open_items, open_items=self.session_open_items,
next_steps=[] next_steps=[],
) )
# Update hot memory # Update hot memory
self.hot.update_section( self.hot.update_section(
"Current Session", "Current Session",
f"**Last Session:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M')}\n" + f"**Last Session:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M')}\n"
f"**Summary:** {summary[:100]}..." + f"**Summary:** {summary[:100]}...",
) )
logger.info("MemorySystem: Session ended, handoff written") logger.info("MemorySystem: Session ended, handoff written")
def record_decision(self, decision: str) -> None: def record_decision(self, decision: str) -> None:
"""Record a key decision during session.""" """Record a key decision during session."""
self.session_decisions.append(decision) self.session_decisions.append(decision)
@@ -393,43 +396,47 @@ class MemorySystem:
if "## Key Decisions" in current: if "## Key Decisions" in current:
# Append to section # Append to section
pass # Handled at session end pass # Handled at session end
def record_open_item(self, item: str) -> None: def record_open_item(self, item: str) -> None:
"""Record an open item for follow-up.""" """Record an open item for follow-up."""
self.session_open_items.append(item) self.session_open_items.append(item)
def update_user_fact(self, key: str, value: str) -> None: def update_user_fact(self, key: str, value: str) -> None:
"""Update user profile in vault.""" """Update user profile in vault."""
self.vault.update_user_profile(key, value) self.vault.update_user_profile(key, value)
# Also update hot memory # Also update hot memory
if key.lower() == "name": if key.lower() == "name":
self.hot.update_section("User Profile", f"**Name:** {value}") self.hot.update_section("User Profile", f"**Name:** {value}")
def _load_user_profile_summary(self) -> str: def _load_user_profile_summary(self) -> str:
"""Load condensed user profile.""" """Load condensed user profile."""
profile_path = self.vault.path / "self" / "user_profile.md" profile_path = self.vault.path / "self" / "user_profile.md"
if not profile_path.exists(): if not profile_path.exists():
return "" return ""
content = profile_path.read_text() content = profile_path.read_text()
# Extract key fields # Extract key fields
summary_parts = [] summary_parts = []
# Name # Name
name_match = re.search(r"\*\*Name:\*\* (.+)", content) name_match = re.search(r"\*\*Name:\*\* (.+)", content)
if name_match and "unknown" not in name_match.group(1).lower(): if name_match and "unknown" not in name_match.group(1).lower():
summary_parts.append(f"Name: {name_match.group(1).strip()}") summary_parts.append(f"Name: {name_match.group(1).strip()}")
# Interests # Interests
interests_section = re.search(r"## Interests.*?\n- (.+?)(?=\n## |\Z)", content, re.DOTALL) interests_section = re.search(r"## Interests.*?\n- (.+?)(?=\n## |\Z)", content, re.DOTALL)
if interests_section: if interests_section:
interests = [i.strip() for i in interests_section.group(1).split("\n-") if i.strip() and "to be" not in i] interests = [
i.strip()
for i in interests_section.group(1).split("\n-")
if i.strip() and "to be" not in i
]
if interests: if interests:
summary_parts.append(f"Interests: {', '.join(interests[:3])}") summary_parts.append(f"Interests: {', '.join(interests[:3])}")
return "\n".join(summary_parts) if summary_parts else "" return "\n".join(summary_parts) if summary_parts else ""
def get_system_context(self) -> str: def get_system_context(self) -> str:
"""Get full context for system prompt injection. """Get full context for system prompt injection.

View File

@@ -38,12 +38,14 @@ def _get_embedding_model():
global EMBEDDING_MODEL global EMBEDDING_MODEL
if EMBEDDING_MODEL is None: if EMBEDDING_MODEL is None:
from config import settings from config import settings
if settings.timmy_skip_embeddings: if settings.timmy_skip_embeddings:
EMBEDDING_MODEL = False EMBEDDING_MODEL = False
return EMBEDDING_MODEL return EMBEDDING_MODEL
try: try:
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
EMBEDDING_MODEL = SentenceTransformer('all-MiniLM-L6-v2')
EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
logger.info("SemanticMemory: Loaded embedding model") logger.info("SemanticMemory: Loaded embedding model")
except ImportError: except ImportError:
logger.warning("SemanticMemory: sentence-transformers not installed, using fallback") logger.warning("SemanticMemory: sentence-transformers not installed, using fallback")
@@ -60,11 +62,12 @@ def _simple_hash_embedding(text: str) -> list[float]:
h = hashlib.md5(word.encode()).hexdigest() h = hashlib.md5(word.encode()).hexdigest()
for j in range(8): for j in range(8):
idx = (i * 8 + j) % 128 idx = (i * 8 + j) % 128
vec[idx] += int(h[j*2:j*2+2], 16) / 255.0 vec[idx] += int(h[j * 2 : j * 2 + 2], 16) / 255.0
# Normalize # Normalize
import math import math
mag = math.sqrt(sum(x*x for x in vec)) or 1.0
return [x/mag for x in vec] mag = math.sqrt(sum(x * x for x in vec)) or 1.0
return [x / mag for x in vec]
def embed_text(text: str) -> list[float]: def embed_text(text: str) -> list[float]:
@@ -80,9 +83,10 @@ def embed_text(text: str) -> list[float]:
def cosine_similarity(a: list[float], b: list[float]) -> float: def cosine_similarity(a: list[float], b: list[float]) -> float:
"""Calculate cosine similarity between two vectors.""" """Calculate cosine similarity between two vectors."""
import math import math
dot = sum(x*y for x, y in zip(a, b))
mag_a = math.sqrt(sum(x*x for x in a)) dot = sum(x * y for x, y in zip(a, b))
mag_b = math.sqrt(sum(x*x for x in b)) mag_a = math.sqrt(sum(x * x for x in a))
mag_b = math.sqrt(sum(x * x for x in b))
if mag_a == 0 or mag_b == 0: if mag_a == 0 or mag_b == 0:
return 0.0 return 0.0
return dot / (mag_a * mag_b) return dot / (mag_a * mag_b)
@@ -91,6 +95,7 @@ def cosine_similarity(a: list[float], b: list[float]) -> float:
@dataclass @dataclass
class MemoryChunk: class MemoryChunk:
"""A searchable chunk of memory.""" """A searchable chunk of memory."""
id: str id: str
source: str # filepath source: str # filepath
content: str content: str
@@ -100,17 +105,18 @@ class MemoryChunk:
class SemanticMemory: class SemanticMemory:
"""Vector-based semantic search over vault content.""" """Vector-based semantic search over vault content."""
def __init__(self) -> None: def __init__(self) -> None:
self.db_path = SEMANTIC_DB_PATH self.db_path = SEMANTIC_DB_PATH
self.vault_path = VAULT_PATH self.vault_path = VAULT_PATH
self._init_db() self._init_db()
def _init_db(self) -> None: def _init_db(self) -> None:
"""Initialize SQLite with vector storage.""" """Initialize SQLite with vector storage."""
self.db_path.parent.mkdir(parents=True, exist_ok=True) self.db_path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(self.db_path)) conn = sqlite3.connect(str(self.db_path))
conn.execute(""" conn.execute(
"""
CREATE TABLE IF NOT EXISTS chunks ( CREATE TABLE IF NOT EXISTS chunks (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
source TEXT NOT NULL, source TEXT NOT NULL,
@@ -119,76 +125,76 @@ class SemanticMemory:
created_at TEXT NOT NULL, created_at TEXT NOT NULL,
source_hash TEXT NOT NULL source_hash TEXT NOT NULL
) )
""") """
)
conn.execute("CREATE INDEX IF NOT EXISTS idx_source ON chunks(source)") conn.execute("CREATE INDEX IF NOT EXISTS idx_source ON chunks(source)")
conn.commit() conn.commit()
conn.close() conn.close()
def index_file(self, filepath: Path) -> int: def index_file(self, filepath: Path) -> int:
"""Index a single file into semantic memory.""" """Index a single file into semantic memory."""
if not filepath.exists(): if not filepath.exists():
return 0 return 0
content = filepath.read_text() content = filepath.read_text()
file_hash = hashlib.md5(content.encode()).hexdigest() file_hash = hashlib.md5(content.encode()).hexdigest()
# Check if already indexed with same hash # Check if already indexed with same hash
conn = sqlite3.connect(str(self.db_path)) conn = sqlite3.connect(str(self.db_path))
cursor = conn.execute( cursor = conn.execute(
"SELECT source_hash FROM chunks WHERE source = ? LIMIT 1", "SELECT source_hash FROM chunks WHERE source = ? LIMIT 1", (str(filepath),)
(str(filepath),)
) )
existing = cursor.fetchone() existing = cursor.fetchone()
if existing and existing[0] == file_hash: if existing and existing[0] == file_hash:
conn.close() conn.close()
return 0 # Already indexed return 0 # Already indexed
# Delete old chunks for this file # Delete old chunks for this file
conn.execute("DELETE FROM chunks WHERE source = ?", (str(filepath),)) conn.execute("DELETE FROM chunks WHERE source = ?", (str(filepath),))
# Split into chunks (paragraphs) # Split into chunks (paragraphs)
chunks = self._split_into_chunks(content) chunks = self._split_into_chunks(content)
# Index each chunk # Index each chunk
now = datetime.now(timezone.utc).isoformat() now = datetime.now(timezone.utc).isoformat()
for i, chunk_text in enumerate(chunks): for i, chunk_text in enumerate(chunks):
if len(chunk_text.strip()) < 20: # Skip tiny chunks if len(chunk_text.strip()) < 20: # Skip tiny chunks
continue continue
chunk_id = f"{filepath.stem}_{i}" chunk_id = f"{filepath.stem}_{i}"
embedding = embed_text(chunk_text) embedding = embed_text(chunk_text)
conn.execute( conn.execute(
"""INSERT INTO chunks (id, source, content, embedding, created_at, source_hash) """INSERT INTO chunks (id, source, content, embedding, created_at, source_hash)
VALUES (?, ?, ?, ?, ?, ?)""", VALUES (?, ?, ?, ?, ?, ?)""",
(chunk_id, str(filepath), chunk_text, json.dumps(embedding), now, file_hash) (chunk_id, str(filepath), chunk_text, json.dumps(embedding), now, file_hash),
) )
conn.commit() conn.commit()
conn.close() conn.close()
logger.info("SemanticMemory: Indexed %s (%d chunks)", filepath.name, len(chunks)) logger.info("SemanticMemory: Indexed %s (%d chunks)", filepath.name, len(chunks))
return len(chunks) return len(chunks)
def _split_into_chunks(self, text: str, max_chunk_size: int = 500) -> list[str]: def _split_into_chunks(self, text: str, max_chunk_size: int = 500) -> list[str]:
"""Split text into semantic chunks.""" """Split text into semantic chunks."""
# Split by paragraphs first # Split by paragraphs first
paragraphs = text.split('\n\n') paragraphs = text.split("\n\n")
chunks = [] chunks = []
for para in paragraphs: for para in paragraphs:
para = para.strip() para = para.strip()
if not para: if not para:
continue continue
# If paragraph is small enough, keep as one chunk # If paragraph is small enough, keep as one chunk
if len(para) <= max_chunk_size: if len(para) <= max_chunk_size:
chunks.append(para) chunks.append(para)
else: else:
# Split long paragraphs by sentences # Split long paragraphs by sentences
sentences = para.replace('. ', '.\n').split('\n') sentences = para.replace(". ", ".\n").split("\n")
current_chunk = "" current_chunk = ""
for sent in sentences: for sent in sentences:
if len(current_chunk) + len(sent) < max_chunk_size: if len(current_chunk) + len(sent) < max_chunk_size:
current_chunk += " " + sent if current_chunk else sent current_chunk += " " + sent if current_chunk else sent
@@ -196,82 +202,80 @@ class SemanticMemory:
if current_chunk: if current_chunk:
chunks.append(current_chunk.strip()) chunks.append(current_chunk.strip())
current_chunk = sent current_chunk = sent
if current_chunk: if current_chunk:
chunks.append(current_chunk.strip()) chunks.append(current_chunk.strip())
return chunks return chunks
def index_vault(self) -> int: def index_vault(self) -> int:
"""Index entire vault directory.""" """Index entire vault directory."""
total_chunks = 0 total_chunks = 0
for md_file in self.vault_path.rglob("*.md"): for md_file in self.vault_path.rglob("*.md"):
# Skip handoff file (handled separately) # Skip handoff file (handled separately)
if "last-session-handoff" in md_file.name: if "last-session-handoff" in md_file.name:
continue continue
total_chunks += self.index_file(md_file) total_chunks += self.index_file(md_file)
logger.info("SemanticMemory: Indexed vault (%d total chunks)", total_chunks) logger.info("SemanticMemory: Indexed vault (%d total chunks)", total_chunks)
return total_chunks return total_chunks
def search(self, query: str, top_k: int = 5) -> list[tuple[str, float]]: def search(self, query: str, top_k: int = 5) -> list[tuple[str, float]]:
"""Search for relevant memory chunks.""" """Search for relevant memory chunks."""
query_embedding = embed_text(query) query_embedding = embed_text(query)
conn = sqlite3.connect(str(self.db_path)) conn = sqlite3.connect(str(self.db_path))
conn.row_factory = sqlite3.Row conn.row_factory = sqlite3.Row
# Get all chunks (in production, use vector index) # Get all chunks (in production, use vector index)
rows = conn.execute( rows = conn.execute("SELECT source, content, embedding FROM chunks").fetchall()
"SELECT source, content, embedding FROM chunks"
).fetchall()
conn.close() conn.close()
# Calculate similarities # Calculate similarities
scored = [] scored = []
for row in rows: for row in rows:
embedding = json.loads(row["embedding"]) embedding = json.loads(row["embedding"])
score = cosine_similarity(query_embedding, embedding) score = cosine_similarity(query_embedding, embedding)
scored.append((row["source"], row["content"], score)) scored.append((row["source"], row["content"], score))
# Sort by score descending # Sort by score descending
scored.sort(key=lambda x: x[2], reverse=True) scored.sort(key=lambda x: x[2], reverse=True)
# Return top_k # Return top_k
return [(content, score) for _, content, score in scored[:top_k]] return [(content, score) for _, content, score in scored[:top_k]]
def get_relevant_context(self, query: str, max_chars: int = 2000) -> str: def get_relevant_context(self, query: str, max_chars: int = 2000) -> str:
"""Get formatted context string for a query.""" """Get formatted context string for a query."""
results = self.search(query, top_k=3) results = self.search(query, top_k=3)
if not results: if not results:
return "" return ""
parts = [] parts = []
total_chars = 0 total_chars = 0
for content, score in results: for content, score in results:
if score < 0.3: # Similarity threshold if score < 0.3: # Similarity threshold
continue continue
chunk = f"[Relevant memory - score {score:.2f}]: {content[:400]}..." chunk = f"[Relevant memory - score {score:.2f}]: {content[:400]}..."
if total_chars + len(chunk) > max_chars: if total_chars + len(chunk) > max_chars:
break break
parts.append(chunk) parts.append(chunk)
total_chars += len(chunk) total_chars += len(chunk)
return "\n\n".join(parts) if parts else "" return "\n\n".join(parts) if parts else ""
def stats(self) -> dict: def stats(self) -> dict:
"""Get indexing statistics.""" """Get indexing statistics."""
conn = sqlite3.connect(str(self.db_path)) conn = sqlite3.connect(str(self.db_path))
cursor = conn.execute("SELECT COUNT(*), COUNT(DISTINCT source) FROM chunks") cursor = conn.execute("SELECT COUNT(*), COUNT(DISTINCT source) FROM chunks")
total_chunks, total_files = cursor.fetchone() total_chunks, total_files = cursor.fetchone()
conn.close() conn.close()
return { return {
"total_chunks": total_chunks, "total_chunks": total_chunks,
"total_files": total_files, "total_files": total_files,
@@ -281,40 +285,39 @@ class SemanticMemory:
class MemorySearcher: class MemorySearcher:
"""High-level interface for memory search.""" """High-level interface for memory search."""
def __init__(self) -> None: def __init__(self) -> None:
self.semantic = SemanticMemory() self.semantic = SemanticMemory()
def search(self, query: str, tiers: list[str] = None) -> dict: def search(self, query: str, tiers: list[str] = None) -> dict:
"""Search across memory tiers. """Search across memory tiers.
Args: Args:
query: Search query query: Search query
tiers: List of tiers to search ["hot", "vault", "semantic"] tiers: List of tiers to search ["hot", "vault", "semantic"]
Returns: Returns:
Dict with results from each tier Dict with results from each tier
""" """
tiers = tiers or ["semantic"] # Default to semantic only tiers = tiers or ["semantic"] # Default to semantic only
results = {} results = {}
if "semantic" in tiers: if "semantic" in tiers:
semantic_results = self.semantic.search(query, top_k=5) semantic_results = self.semantic.search(query, top_k=5)
results["semantic"] = [ results["semantic"] = [
{"content": content, "score": score} {"content": content, "score": score} for content, score in semantic_results
for content, score in semantic_results
] ]
return results return results
def get_context_for_query(self, query: str) -> str: def get_context_for_query(self, query: str) -> str:
"""Get comprehensive context for a user query.""" """Get comprehensive context for a user query."""
# Get semantic context # Get semantic context
semantic_context = self.semantic.get_relevant_context(query) semantic_context = self.semantic.get_relevant_context(query)
if semantic_context: if semantic_context:
return f"## Relevant Past Context\n\n{semantic_context}" return f"## Relevant Past Context\n\n{semantic_context}"
return "" return ""
@@ -353,6 +356,7 @@ def memory_search(query: str, top_k: int = 5) -> str:
# 2. Search runtime vector store (stored facts/conversations) # 2. Search runtime vector store (stored facts/conversations)
try: try:
from timmy.memory.vector_store import search_memories from timmy.memory.vector_store import search_memories
runtime_results = search_memories(query, limit=top_k, min_relevance=0.2) runtime_results = search_memories(query, limit=top_k, min_relevance=0.2)
for entry in runtime_results: for entry in runtime_results:
label = entry.context_type or "memory" label = entry.context_type or "memory"
@@ -387,6 +391,7 @@ def memory_read(query: str = "", top_k: int = 5) -> str:
# Always include personal facts first # Always include personal facts first
try: try:
from timmy.memory.vector_store import search_memories from timmy.memory.vector_store import search_memories
facts = search_memories(query or "", limit=top_k, min_relevance=0.0) facts = search_memories(query or "", limit=top_k, min_relevance=0.0)
fact_entries = [e for e in facts if (e.context_type or "") == "fact"] fact_entries = [e for e in facts if (e.context_type or "") == "fact"]
if fact_entries: if fact_entries:
@@ -433,6 +438,7 @@ def memory_write(content: str, context_type: str = "fact") -> str:
try: try:
from timmy.memory.vector_store import store_memory from timmy.memory.vector_store import store_memory
entry = store_memory( entry = store_memory(
content=content.strip(), content=content.strip(),
source="agent", source="agent",

View File

@@ -32,13 +32,15 @@ _TOOL_CALL_JSON = re.compile(
# Matches function-call-style text: memory_search(query="...") etc. # Matches function-call-style text: memory_search(query="...") etc.
_FUNC_CALL_TEXT = re.compile( _FUNC_CALL_TEXT = re.compile(
r'\b(?:memory_search|web_search|shell|python|read_file|write_file|list_files|calculator)' r"\b(?:memory_search|web_search|shell|python|read_file|write_file|list_files|calculator)"
r'\s*\([^)]*\)', r"\s*\([^)]*\)",
) )
# Matches chain-of-thought narration lines the model should keep internal # Matches chain-of-thought narration lines the model should keep internal
_COT_PATTERNS = [ _COT_PATTERNS = [
re.compile(r"^(?:Since |Using |Let me |I'll use |I will use |Here's a possible ).*$", re.MULTILINE), re.compile(
r"^(?:Since |Using |Let me |I'll use |I will use |Here's a possible ).*$", re.MULTILINE
),
re.compile(r"^(?:I found a relevant |This context suggests ).*$", re.MULTILINE), re.compile(r"^(?:I found a relevant |This context suggests ).*$", re.MULTILINE),
] ]
@@ -48,6 +50,7 @@ def _get_agent():
global _agent global _agent
if _agent is None: if _agent is None:
from timmy.agent import create_timmy from timmy.agent import create_timmy
try: try:
_agent = create_timmy() _agent = create_timmy()
logger.info("Session: Timmy agent initialized (singleton)") logger.info("Session: Timmy agent initialized (singleton)")
@@ -99,6 +102,7 @@ def reset_session(session_id: Optional[str] = None) -> None:
sid = session_id or _DEFAULT_SESSION_ID sid = session_id or _DEFAULT_SESSION_ID
try: try:
from timmy.conversation import conversation_manager from timmy.conversation import conversation_manager
conversation_manager.clear_context(sid) conversation_manager.clear_context(sid)
except Exception as exc: except Exception as exc:
logger.debug("Session: context clear failed for %s: %s", sid, exc) logger.debug("Session: context clear failed for %s: %s", sid, exc)
@@ -112,10 +116,12 @@ def _extract_facts(message: str) -> None:
""" """
try: try:
from timmy.conversation import conversation_manager from timmy.conversation import conversation_manager
name = conversation_manager.extract_user_name(message) name = conversation_manager.extract_user_name(message)
if name: if name:
try: try:
from timmy.memory_system import memory_system from timmy.memory_system import memory_system
memory_system.update_user_fact("Name", name) memory_system.update_user_fact("Name", name)
logger.info("Session: Learned user name: %s", name) logger.info("Session: Learned user name: %s", name)
except Exception as exc: except Exception as exc:

View File

@@ -6,7 +6,7 @@ including any mistakes or errors that occur during the session."
import json import json
import logging import logging
from datetime import datetime, date from datetime import date, datetime
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any

View File

@@ -75,6 +75,7 @@ Continue your train of thought."""
@dataclass @dataclass
class Thought: class Thought:
"""A single thought in Timmy's inner stream.""" """A single thought in Timmy's inner stream."""
id: str id: str
content: str content: str
seed_type: str seed_type: str
@@ -98,9 +99,7 @@ def _get_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection:
) )
""" """
) )
conn.execute( conn.execute("CREATE INDEX IF NOT EXISTS idx_thoughts_time ON thoughts(created_at)")
"CREATE INDEX IF NOT EXISTS idx_thoughts_time ON thoughts(created_at)"
)
conn.commit() conn.commit()
return conn return conn
@@ -190,9 +189,7 @@ class ThinkingEngine:
def get_thought(self, thought_id: str) -> Optional[Thought]: def get_thought(self, thought_id: str) -> Optional[Thought]:
"""Retrieve a single thought by ID.""" """Retrieve a single thought by ID."""
conn = _get_conn(self._db_path) conn = _get_conn(self._db_path)
row = conn.execute( row = conn.execute("SELECT * FROM thoughts WHERE id = ?", (thought_id,)).fetchone()
"SELECT * FROM thoughts WHERE id = ?", (thought_id,)
).fetchone()
conn.close() conn.close()
return _row_to_thought(row) if row else None return _row_to_thought(row) if row else None
@@ -208,9 +205,7 @@ class ThinkingEngine:
for _ in range(max_depth): for _ in range(max_depth):
if not current_id: if not current_id:
break break
row = conn.execute( row = conn.execute("SELECT * FROM thoughts WHERE id = ?", (current_id,)).fetchone()
"SELECT * FROM thoughts WHERE id = ?", (current_id,)
).fetchone()
if not row: if not row:
break break
chain.append(_row_to_thought(row)) chain.append(_row_to_thought(row))
@@ -254,8 +249,10 @@ class ThinkingEngine:
def _seed_from_swarm(self) -> str: def _seed_from_swarm(self) -> str:
"""Gather recent swarm activity as thought seed.""" """Gather recent swarm activity as thought seed."""
try: try:
from timmy.briefing import _gather_swarm_summary, _gather_task_queue_summary
from datetime import timedelta from datetime import timedelta
from timmy.briefing import _gather_swarm_summary, _gather_task_queue_summary
since = datetime.now(timezone.utc) - timedelta(hours=1) since = datetime.now(timezone.utc) - timedelta(hours=1)
swarm = _gather_swarm_summary(since) swarm = _gather_swarm_summary(since)
tasks = _gather_task_queue_summary() tasks = _gather_task_queue_summary()
@@ -272,6 +269,7 @@ class ThinkingEngine:
"""Gather memory context as thought seed.""" """Gather memory context as thought seed."""
try: try:
from timmy.memory_system import memory_system from timmy.memory_system import memory_system
context = memory_system.get_system_context() context = memory_system.get_system_context()
if context: if context:
# Truncate to a reasonable size for a thought seed # Truncate to a reasonable size for a thought seed
@@ -299,10 +297,12 @@ class ThinkingEngine:
""" """
try: try:
from timmy.session import chat from timmy.session import chat
return chat(prompt, session_id="thinking") return chat(prompt, session_id="thinking")
except Exception: except Exception:
# Fallback: create a fresh agent # Fallback: create a fresh agent
from timmy.agent import create_timmy from timmy.agent import create_timmy
agent = create_timmy() agent = create_timmy()
run = agent.run(prompt, stream=False) run = agent.run(prompt, stream=False)
return run.content if hasattr(run, "content") else str(run) return run.content if hasattr(run, "content") else str(run)
@@ -323,8 +323,7 @@ class ThinkingEngine:
INSERT INTO thoughts (id, content, seed_type, parent_id, created_at) INSERT INTO thoughts (id, content, seed_type, parent_id, created_at)
VALUES (?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?)
""", """,
(thought.id, thought.content, thought.seed_type, (thought.id, thought.content, thought.seed_type, thought.parent_id, thought.created_at),
thought.parent_id, thought.created_at),
) )
conn.commit() conn.commit()
conn.close() conn.close()
@@ -333,7 +332,8 @@ class ThinkingEngine:
def _log_event(self, thought: Thought) -> None: def _log_event(self, thought: Thought) -> None:
"""Log the thought as a swarm event.""" """Log the thought as a swarm event."""
try: try:
from swarm.event_log import log_event, EventType from swarm.event_log import EventType, log_event
log_event( log_event(
EventType.TIMMY_THOUGHT, EventType.TIMMY_THOUGHT,
source="thinking-engine", source="thinking-engine",
@@ -351,12 +351,16 @@ class ThinkingEngine:
"""Broadcast the thought to WebSocket clients.""" """Broadcast the thought to WebSocket clients."""
try: try:
from infrastructure.ws_manager.handler import ws_manager from infrastructure.ws_manager.handler import ws_manager
await ws_manager.broadcast("timmy_thought", {
"thought_id": thought.id, await ws_manager.broadcast(
"content": thought.content, "timmy_thought",
"seed_type": thought.seed_type, {
"created_at": thought.created_at, "thought_id": thought.id,
}) "content": thought.content,
"seed_type": thought.seed_type,
"created_at": thought.created_at,
},
)
except Exception as exc: except Exception as exc:
logger.debug("Failed to broadcast thought: %s", exc) logger.debug("Failed to broadcast thought: %s", exc)

View File

@@ -227,11 +227,7 @@ def create_aider_tool(base_path: Path):
) )
if result.returncode == 0: if result.returncode == 0:
return ( return result.stdout if result.stdout else "Code changes applied successfully"
result.stdout
if result.stdout
else "Code changes applied successfully"
)
else: else:
return f"Aider error: {result.stderr}" return f"Aider error: {result.stderr}"
except FileNotFoundError: except FileNotFoundError:
@@ -354,7 +350,7 @@ def consult_grok(query: str) -> str:
Grok's response text, or an error/status message. Grok's response text, or an error/status message.
""" """
from config import settings from config import settings
from timmy.backends import grok_available, get_grok_backend from timmy.backends import get_grok_backend, grok_available
if not grok_available(): if not grok_available():
return ( return (
@@ -385,9 +381,7 @@ def consult_grok(query: str) -> str:
ln = get_ln_backend() ln = get_ln_backend()
sats = min(settings.grok_max_sats_per_query, 100) sats = min(settings.grok_max_sats_per_query, 100)
inv = ln.create_invoice(sats, f"Grok query: {query[:50]}") inv = ln.create_invoice(sats, f"Grok query: {query[:50]}")
invoice_info = ( invoice_info = f"\n[Lightning invoice: {sats} sats — {inv.payment_request[:40]}...]"
f"\n[Lightning invoice: {sats} sats — {inv.payment_request[:40]}...]"
)
except Exception: except Exception:
pass pass
@@ -447,7 +441,7 @@ def create_full_toolkit(base_dir: str | Path | None = None):
# Memory search and write — persistent recall across all channels # Memory search and write — persistent recall across all channels
try: try:
from timmy.semantic_memory import memory_search, memory_write, memory_read from timmy.semantic_memory import memory_read, memory_search, memory_write
toolkit.register(memory_search, name="memory_search") toolkit.register(memory_search, name="memory_search")
toolkit.register(memory_write, name="memory_write") toolkit.register(memory_write, name="memory_write")
@@ -473,6 +467,7 @@ def create_full_toolkit(base_dir: str | Path | None = None):
Task ID and confirmation that background execution has started. Task ID and confirmation that background execution has started.
""" """
import asyncio import asyncio
task_id = None task_id = None
async def _launch(): async def _launch():
@@ -502,11 +497,7 @@ def create_full_toolkit(base_dir: str | Path | None = None):
# System introspection - query runtime environment (sovereign self-knowledge) # System introspection - query runtime environment (sovereign self-knowledge)
try: try:
from timmy.tools_intro import ( from timmy.tools_intro import check_ollama_health, get_memory_status, get_system_info
get_system_info,
check_ollama_health,
get_memory_status,
)
toolkit.register(get_system_info, name="get_system_info") toolkit.register(get_system_info, name="get_system_info")
toolkit.register(check_ollama_health, name="check_ollama_health") toolkit.register(check_ollama_health, name="check_ollama_health")
@@ -526,6 +517,60 @@ def create_full_toolkit(base_dir: str | Path | None = None):
return toolkit return toolkit
def create_experiment_tools(base_dir: str | Path | None = None):
"""Create tools for the experiment agent (Lab).
Includes: prepare_experiment, run_experiment, evaluate_result,
plus shell + file ops for editing training code.
"""
if not _AGNO_TOOLS_AVAILABLE:
raise ImportError(f"Agno tools not available: {_ImportError}")
from config import settings
toolkit = Toolkit(name="experiment")
from timmy.autoresearch import evaluate_result, prepare_experiment, run_experiment
workspace = (
Path(base_dir) if base_dir else Path(settings.repo_root) / settings.autoresearch_workspace
)
def _prepare(repo_url: str = "https://github.com/karpathy/autoresearch.git") -> str:
"""Clone and prepare an autoresearch experiment workspace."""
return prepare_experiment(workspace, repo_url)
def _run(timeout: int = 0) -> str:
"""Run a single training experiment with wall-clock timeout."""
t = timeout or settings.autoresearch_time_budget
result = run_experiment(workspace, timeout=t, metric_name=settings.autoresearch_metric)
if result["success"] and result["metric"] is not None:
return (
f"{settings.autoresearch_metric}: {result['metric']:.4f} ({result['duration_s']}s)"
)
return result.get("error") or "Experiment failed"
def _evaluate(current: float, baseline: float) -> str:
"""Compare current metric against baseline."""
return evaluate_result(current, baseline, metric_name=settings.autoresearch_metric)
toolkit.register(_prepare, name="prepare_experiment")
toolkit.register(_run, name="run_experiment")
toolkit.register(_evaluate, name="evaluate_result")
# Also give Lab access to file + shell tools for editing train.py
shell_tools = ShellTools()
toolkit.register(shell_tools.run_shell_command, name="shell")
base_path = Path(base_dir) if base_dir else Path(settings.repo_root)
file_tools = FileTools(base_dir=base_path)
toolkit.register(file_tools.read_file, name="read_file")
toolkit.register(file_tools.save_file, name="write_file")
toolkit.register(file_tools.list_files, name="list_files")
return toolkit
# Mapping of agent IDs to their toolkits # Mapping of agent IDs to their toolkits
AGENT_TOOLKITS: dict[str, Callable[[], Toolkit]] = { AGENT_TOOLKITS: dict[str, Callable[[], Toolkit]] = {
"echo": create_research_tools, "echo": create_research_tools,
@@ -534,6 +579,7 @@ AGENT_TOOLKITS: dict[str, Callable[[], Toolkit]] = {
"seer": create_data_tools, "seer": create_data_tools,
"forge": create_code_tools, "forge": create_code_tools,
"quill": create_writing_tools, "quill": create_writing_tools,
"lab": create_experiment_tools,
"pixel": lambda base_dir=None: _create_stub_toolkit("pixel"), "pixel": lambda base_dir=None: _create_stub_toolkit("pixel"),
"lyra": lambda base_dir=None: _create_stub_toolkit("lyra"), "lyra": lambda base_dir=None: _create_stub_toolkit("lyra"),
"reel": lambda base_dir=None: _create_stub_toolkit("reel"), "reel": lambda base_dir=None: _create_stub_toolkit("reel"),
@@ -553,9 +599,7 @@ def _create_stub_toolkit(name: str):
return toolkit return toolkit
def get_tools_for_agent( def get_tools_for_agent(agent_id: str, base_dir: str | Path | None = None) -> Toolkit | None:
agent_id: str, base_dir: str | Path | None = None
) -> Toolkit | None:
"""Get the appropriate toolkit for an agent. """Get the appropriate toolkit for an agent.
Args: Args:
@@ -643,6 +687,21 @@ def get_all_available_tools() -> dict[str, dict]:
"description": "Local AI coding assistant using Ollama (qwen2.5:14b or deepseek-coder)", "description": "Local AI coding assistant using Ollama (qwen2.5:14b or deepseek-coder)",
"available_in": ["forge", "orchestrator"], "available_in": ["forge", "orchestrator"],
}, },
"prepare_experiment": {
"name": "Prepare Experiment",
"description": "Clone autoresearch repo and run data preparation for ML experiments",
"available_in": ["lab", "orchestrator"],
},
"run_experiment": {
"name": "Run Experiment",
"description": "Execute a time-boxed ML training experiment and capture metrics",
"available_in": ["lab", "orchestrator"],
},
"evaluate_result": {
"name": "Evaluate Result",
"description": "Compare experiment metric against baseline to assess improvement",
"available_in": ["lab", "orchestrator"],
},
} }
# ── Git tools ───────────────────────────────────────────────────────────── # ── Git tools ─────────────────────────────────────────────────────────────

View File

@@ -20,7 +20,9 @@ _VALID_AGENTS: dict[str, str] = {
} }
def delegate_task(agent_name: str, task_description: str, priority: str = "normal") -> dict[str, Any]: def delegate_task(
agent_name: str, task_description: str, priority: str = "normal"
) -> dict[str, Any]:
"""Record a delegation intent to another agent. """Record a delegation intent to another agent.
Args: Args:
@@ -44,7 +46,9 @@ def delegate_task(agent_name: str, task_description: str, priority: str = "norma
if priority not in valid_priorities: if priority not in valid_priorities:
priority = "normal" priority = "normal"
logger.info("Delegation intent: %s%s (priority=%s)", agent_name, task_description[:80], priority) logger.info(
"Delegation intent: %s%s (priority=%s)", agent_name, task_description[:80], priority
)
return { return {
"success": True, "success": True,

View File

@@ -65,9 +65,7 @@ def _get_ollama_model() -> str:
models = response.json().get("models", []) models = response.json().get("models", [])
# Check if configured model is available # Check if configured model is available
for model in models: for model in models:
if model.get("name", "").startswith( if model.get("name", "").startswith(settings.ollama_model.split(":")[0]):
settings.ollama_model.split(":")[0]
):
return settings.ollama_model return settings.ollama_model
# Fallback: return configured model # Fallback: return configured model
@@ -139,9 +137,7 @@ def get_memory_status() -> dict[str, Any]:
if tier1_exists: if tier1_exists:
lines = memory_md.read_text().splitlines() lines = memory_md.read_text().splitlines()
tier1_info["line_count"] = len(lines) tier1_info["line_count"] = len(lines)
tier1_info["sections"] = [ tier1_info["sections"] = [ln.lstrip("# ").strip() for ln in lines if ln.startswith("## ")]
ln.lstrip("# ").strip() for ln in lines if ln.startswith("## ")
]
# Vault — scan all subdirs under memory/ # Vault — scan all subdirs under memory/
vault_root = repo_root / "memory" vault_root = repo_root / "memory"
@@ -233,13 +229,15 @@ def get_agent_roster() -> dict[str, Any]:
roster = [] roster = []
for persona in _PERSONAS: for persona in _PERSONAS:
roster.append({ roster.append(
"id": persona["agent_id"], {
"name": persona["name"], "id": persona["agent_id"],
"status": "available", "name": persona["name"],
"capabilities": ", ".join(persona.get("tools", [])), "status": "available",
"role": persona.get("role", ""), "capabilities": ", ".join(persona.get("tools", [])),
}) "role": persona.get("role", ""),
}
)
return { return {
"agents": roster, "agents": roster,

View File

@@ -41,7 +41,7 @@ class StatusResponse(BaseModel):
class RateLimitMiddleware(BaseHTTPMiddleware): class RateLimitMiddleware(BaseHTTPMiddleware):
"""Simple in-memory rate limiting middleware.""" """Simple in-memory rate limiting middleware."""
def __init__(self, app, limit: int = 10, window: int = 60): def __init__(self, app, limit: int = 10, window: int = 60):
super().__init__(app) super().__init__(app)
self.limit = limit self.limit = limit
@@ -53,22 +53,20 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
if request.url.path == "/serve/chat" and request.method == "POST": if request.url.path == "/serve/chat" and request.method == "POST":
client_ip = request.client.host if request.client else "unknown" client_ip = request.client.host if request.client else "unknown"
now = time.time() now = time.time()
# Clean up old requests # Clean up old requests
self.requests[client_ip] = [ self.requests[client_ip] = [
t for t in self.requests[client_ip] t for t in self.requests[client_ip] if now - t < self.window
if now - t < self.window
] ]
if len(self.requests[client_ip]) >= self.limit: if len(self.requests[client_ip]) >= self.limit:
logger.warning("Rate limit exceeded for %s", client_ip) logger.warning("Rate limit exceeded for %s", client_ip)
return JSONResponse( return JSONResponse(
status_code=429, status_code=429, content={"error": "Rate limit exceeded. Try again later."}
content={"error": "Rate limit exceeded. Try again later."}
) )
self.requests[client_ip].append(now) self.requests[client_ip].append(now)
return await call_next(request) return await call_next(request)

View File

@@ -33,6 +33,7 @@ def start(
return return
import uvicorn import uvicorn
from timmy_serve.app import create_timmy_serve_app from timmy_serve.app import create_timmy_serve_app
serve_app = create_timmy_serve_app() serve_app = create_timmy_serve_app()

View File

@@ -23,9 +23,7 @@ class AgentMessage:
to_agent: str = "" to_agent: str = ""
content: str = "" content: str = ""
message_type: str = "text" # text | command | response | error message_type: str = "text" # text | command | response | error
timestamp: str = field( timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
default_factory=lambda: datetime.now(timezone.utc).isoformat()
)
replied: bool = False replied: bool = False
@@ -56,7 +54,10 @@ class InterAgentMessenger:
self._all_messages.append(msg) self._all_messages.append(msg)
logger.info( logger.info(
"Message %s%s: %s (%s)", "Message %s%s: %s (%s)",
from_agent, to_agent, content[:50], message_type, from_agent,
to_agent,
content[:50],
message_type,
) )
return msg return msg

View File

@@ -26,6 +26,7 @@ class VoiceTTS:
def _init_engine(self) -> None: def _init_engine(self) -> None:
try: try:
import pyttsx3 import pyttsx3
self._engine = pyttsx3.init() self._engine = pyttsx3.init()
self._engine.setProperty("rate", self._rate) self._engine.setProperty("rate", self._rate)
self._engine.setProperty("volume", self._volume) self._engine.setProperty("volume", self._volume)

Some files were not shown because too many files have changed in this diff Show More