forked from Rockachopa/Timmy-time-dashboard
feat: code quality audit + autoresearch integration + infra hardening (#150)
This commit is contained in:
committed by
GitHub
parent
fd0ede0d51
commit
ae3bb1cc21
20
.env.example
20
.env.example
@@ -71,3 +71,23 @@
|
||||
# Requires: pip install ".[discord]"
|
||||
# Optional: pip install pyzbar Pillow (for QR code invite detection from screenshots)
|
||||
# DISCORD_TOKEN=
|
||||
|
||||
# ── Autoresearch — autonomous ML experiment loops ────────────────────────────
|
||||
# Enable autonomous experiment loops (Karpathy autoresearch pattern).
|
||||
# AUTORESEARCH_ENABLED=false
|
||||
# AUTORESEARCH_WORKSPACE=data/experiments
|
||||
# AUTORESEARCH_TIME_BUDGET=300
|
||||
# AUTORESEARCH_MAX_ITERATIONS=100
|
||||
# AUTORESEARCH_METRIC=val_bpb
|
||||
|
||||
# ── Docker Production ────────────────────────────────────────────────────────
|
||||
# When deploying with docker-compose.prod.yml:
|
||||
# - Containers run as non-root user "timmy" (defined in Dockerfile)
|
||||
# - No source bind mounts — code is baked into the image
|
||||
# - Set TIMMY_ENV=production to enforce security checks
|
||||
# - All secrets below MUST be set before production deployment
|
||||
#
|
||||
# Taskosaur secrets (change from dev defaults):
|
||||
# TASKOSAUR_JWT_SECRET=<generate with: python3 -c "import secrets; print(secrets.token_hex(32))">
|
||||
# TASKOSAUR_JWT_REFRESH_SECRET=<generate with: python3 -c "import secrets; print(secrets.token_hex(32))">
|
||||
# TASKOSAUR_ENCRYPTION_KEY=<generate with: python3 -c "import secrets; print(secrets.token_hex(32))">
|
||||
|
||||
40
.github/workflows/tests.yml
vendored
40
.github/workflows/tests.yml
vendored
@@ -7,8 +7,30 @@ on:
|
||||
branches: ["**"]
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Install linters
|
||||
run: pip install black==23.12.1 isort==5.13.2 bandit==1.7.5
|
||||
|
||||
- name: Check formatting (black)
|
||||
run: black --check --line-length 100 src/ tests/
|
||||
|
||||
- name: Check import order (isort)
|
||||
run: isort --check --profile black --line-length 100 src/ tests/
|
||||
|
||||
- name: Security scan (bandit)
|
||||
run: bandit -r src/ -ll -s B101,B104,B307,B310,B324,B601,B608 -q
|
||||
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
needs: lint
|
||||
|
||||
# Required for publish-unit-test-result-action to post check runs and PR comments
|
||||
permissions:
|
||||
@@ -22,7 +44,15 @@ jobs:
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: "pip"
|
||||
|
||||
- name: Cache Poetry virtualenv
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cache/pypoetry
|
||||
~/.cache/pip
|
||||
key: poetry-${{ hashFiles('poetry.lock') }}
|
||||
restore-keys: poetry-
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
@@ -60,3 +90,11 @@ jobs:
|
||||
name: coverage-report
|
||||
path: reports/coverage.xml
|
||||
retention-days: 14
|
||||
|
||||
docker-build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Build Docker image
|
||||
run: DOCKER_BUILDKIT=1 docker build -t timmy-time:ci .
|
||||
|
||||
@@ -51,12 +51,12 @@ repos:
|
||||
exclude: ^tests/
|
||||
stages: [manual]
|
||||
|
||||
# Full test suite with 30-second wall-clock limit.
|
||||
# Current baseline: ~18s. If tests get slow, this blocks the commit.
|
||||
# Unit tests only with 30-second wall-clock limit.
|
||||
# Runs only fast unit tests on commit; full suite runs in CI.
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: pytest-fast
|
||||
name: pytest (30s limit)
|
||||
name: pytest unit (30s limit)
|
||||
entry: timeout 30 poetry run pytest
|
||||
language: system
|
||||
types: [python]
|
||||
@@ -68,4 +68,8 @@ repos:
|
||||
- -q
|
||||
- --tb=short
|
||||
- --timeout=10
|
||||
- -m
|
||||
- unit
|
||||
- -p
|
||||
- no:xdist
|
||||
verbose: true
|
||||
|
||||
@@ -56,7 +56,7 @@ make test-cov # With coverage (term-missing + XML)
|
||||
- **Test mode:** `TIMMY_TEST_MODE=1` set automatically in conftest
|
||||
- **FastAPI testing:** Use the `client` fixture
|
||||
- **Async:** `asyncio_mode = "auto"` — async tests detected automatically
|
||||
- **Coverage threshold:** 60% (`fail_under` in `pyproject.toml`)
|
||||
- **Coverage threshold:** 73% (`fail_under` in `pyproject.toml`)
|
||||
|
||||
---
|
||||
|
||||
|
||||
15
Dockerfile
15
Dockerfile
@@ -11,7 +11,7 @@
|
||||
# timmy-time:latest \
|
||||
# python -m swarm.agent_runner --agent-id w1 --name Worker-1
|
||||
|
||||
# ── Stage 1: Builder — export deps via Poetry, install via pip ──────────────
|
||||
# ── Stage 1: Builder — install deps via Poetry ──────────────────────────────
|
||||
FROM python:3.12-slim AS builder
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
@@ -20,18 +20,15 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
# Install Poetry + export plugin (only needed for export, not in runtime)
|
||||
RUN pip install --no-cache-dir poetry poetry-plugin-export
|
||||
# Install Poetry (only needed to resolve deps, not in runtime)
|
||||
RUN pip install --no-cache-dir poetry
|
||||
|
||||
# Copy dependency files only (layer caching)
|
||||
COPY pyproject.toml poetry.lock ./
|
||||
|
||||
# Export pinned requirements and install with pip cache mount
|
||||
RUN poetry export --extras swarm --extras telegram --extras discord --without-hashes \
|
||||
-f requirements.txt -o requirements.txt
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install --no-cache-dir -r requirements.txt
|
||||
# Install deps directly from lock file (no virtualenv, no export plugin needed)
|
||||
RUN poetry config virtualenvs.create false && \
|
||||
poetry install --only main --extras telegram --extras discord --no-interaction
|
||||
|
||||
# ── Stage 2: Runtime ───────────────────────────────────────────────────────
|
||||
FROM python:3.12-slim AS base
|
||||
|
||||
5
Makefile
5
Makefile
@@ -210,6 +210,11 @@ docker-up:
|
||||
mkdir -p data
|
||||
docker compose up -d dashboard
|
||||
|
||||
docker-prod:
|
||||
mkdir -p data
|
||||
DOCKER_BUILDKIT=1 docker build -t timmy-time:latest .
|
||||
docker compose -f docker-compose.yml -f docker-compose.prod.yml up -d dashboard
|
||||
|
||||
docker-down:
|
||||
docker compose down
|
||||
|
||||
|
||||
56
docker-compose.prod.yml
Normal file
56
docker-compose.prod.yml
Normal file
@@ -0,0 +1,56 @@
|
||||
# ── Production Compose Overlay ─────────────────────────────────────────────────
|
||||
#
|
||||
# Usage:
|
||||
# make docker-prod # build + start with prod settings
|
||||
# docker compose -f docker-compose.yml -f docker-compose.prod.yml up -d
|
||||
#
|
||||
# Differences from dev:
|
||||
# - Runs as non-root user (timmy) from Dockerfile
|
||||
# - No bind mounts — uses image-baked source only
|
||||
# - Named volumes only (no host path dependencies)
|
||||
# - Read-only root filesystem with tmpfs for /tmp
|
||||
# - Resource limits enforced
|
||||
# - Secrets passed via environment variables (set in .env)
|
||||
#
|
||||
# Security note: Set all secrets in .env before deploying.
|
||||
# Required: L402_HMAC_SECRET, L402_MACAROON_SECRET
|
||||
# Recommended: TASKOSAUR_JWT_SECRET, TASKOSAUR_JWT_REFRESH_SECRET, TASKOSAUR_ENCRYPTION_KEY
|
||||
|
||||
services:
|
||||
|
||||
dashboard:
|
||||
# Remove dev-only root user override — use Dockerfile's USER timmy
|
||||
user: ""
|
||||
read_only: true
|
||||
tmpfs:
|
||||
- /tmp:size=100M
|
||||
volumes:
|
||||
# Override: named volume only, no host bind mounts
|
||||
- timmy-data:/app/data
|
||||
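# Intent: drop the ./src and ./static bind mounts (use baked-in image files).
# NOTE(review): Compose merges `volumes` entries across files by container path, so bind mounts declared in docker-compose.yml likely persist unless reset (e.g. with "!override") — verify with `docker compose config`.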
|
||||
environment:
|
||||
DEBUG: "false"
|
||||
TIMMY_ENV: "production"
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
cpus: "2.0"
|
||||
memory: 2G
|
||||
|
||||
celery-worker:
|
||||
user: ""
|
||||
read_only: true
|
||||
tmpfs:
|
||||
- /tmp:size=100M
|
||||
volumes:
|
||||
- timmy-data:/app/data
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
cpus: "1.0"
|
||||
memory: 1G
|
||||
|
||||
# Override timmy-data to use a simple named volume (no host bind)
|
||||
volumes:
|
||||
timmy-data:
|
||||
driver: local
|
||||
@@ -97,6 +97,12 @@ markers = [
|
||||
"skip_ci: Skip in CI environment (local development only)",
|
||||
]
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
line_length = 100
|
||||
src_paths = ["src", "tests"]
|
||||
known_first_party = ["brain", "config", "dashboard", "infrastructure", "integrations", "spark", "swarm", "timmy", "timmy_serve"]
|
||||
|
||||
[tool.coverage.run]
|
||||
source = ["src"]
|
||||
omit = [
|
||||
|
||||
@@ -11,9 +11,9 @@ upgrade to distributed rqlite over Tailscale — same API, replicated.
|
||||
"""
|
||||
|
||||
from brain.client import BrainClient
|
||||
from brain.worker import DistributedWorker
|
||||
from brain.embeddings import LocalEmbedder
|
||||
from brain.memory import UnifiedMemory, get_memory
|
||||
from brain.worker import DistributedWorker
|
||||
|
||||
__all__ = [
|
||||
"BrainClient",
|
||||
|
||||
@@ -21,52 +21,54 @@ DEFAULT_RQLITE_URL = "http://localhost:4001"
|
||||
|
||||
class BrainClient:
|
||||
"""Client for distributed brain (rqlite).
|
||||
|
||||
|
||||
Connects to local rqlite instance, which handles replication.
|
||||
All writes go to leader, reads can come from local node.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, rqlite_url: Optional[str] = None, node_id: Optional[str] = None):
|
||||
from config import settings
|
||||
|
||||
self.rqlite_url = rqlite_url or settings.rqlite_url or DEFAULT_RQLITE_URL
|
||||
self.node_id = node_id or f"{socket.gethostname()}-{os.getpid()}"
|
||||
self.source = self._detect_source()
|
||||
self._client = httpx.AsyncClient(timeout=30)
|
||||
|
||||
|
||||
def _detect_source(self) -> str:
|
||||
"""Detect what component is using the brain."""
|
||||
# Could be 'timmy', 'zeroclaw', 'worker', etc.
|
||||
# For now, infer from context or env
|
||||
from config import settings
|
||||
|
||||
return settings.brain_source
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Memory Operations
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def remember(
|
||||
self,
|
||||
content: str,
|
||||
tags: Optional[List[str]] = None,
|
||||
source: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Store a memory with embedding.
|
||||
|
||||
|
||||
Args:
|
||||
content: Text content to remember
|
||||
tags: Optional list of tags (e.g., ['shell', 'result'])
|
||||
source: Source identifier (defaults to self.source)
|
||||
metadata: Additional JSON-serializable metadata
|
||||
|
||||
|
||||
Returns:
|
||||
Dict with 'id' and 'status'
|
||||
"""
|
||||
from brain.embeddings import get_embedder
|
||||
|
||||
|
||||
embedder = get_embedder()
|
||||
embedding_bytes = embedder.encode_single(content)
|
||||
|
||||
|
||||
query = """
|
||||
INSERT INTO memories (content, embedding, source, tags, metadata, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
@@ -77,100 +79,90 @@ class BrainClient:
|
||||
source or self.source,
|
||||
json.dumps(tags or []),
|
||||
json.dumps(metadata or {}),
|
||||
datetime.utcnow().isoformat()
|
||||
datetime.utcnow().isoformat(),
|
||||
]
|
||||
|
||||
|
||||
try:
|
||||
resp = await self._client.post(
|
||||
f"{self.rqlite_url}/db/execute",
|
||||
json=[query, params]
|
||||
)
|
||||
resp = await self._client.post(f"{self.rqlite_url}/db/execute", json=[query, params])
|
||||
resp.raise_for_status()
|
||||
result = resp.json()
|
||||
|
||||
|
||||
# Extract inserted ID
|
||||
last_id = None
|
||||
if "results" in result and result["results"]:
|
||||
last_id = result["results"][0].get("last_insert_id")
|
||||
|
||||
|
||||
logger.debug(f"Stored memory {last_id}: {content[:50]}...")
|
||||
return {"id": last_id, "status": "stored"}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store memory: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def recall(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
sources: Optional[List[str]] = None
|
||||
self, query: str, limit: int = 5, sources: Optional[List[str]] = None
|
||||
) -> List[str]:
|
||||
"""Semantic search for memories.
|
||||
|
||||
|
||||
Args:
|
||||
query: Search query text
|
||||
limit: Max results to return
|
||||
sources: Filter by source(s) (e.g., ['timmy', 'user'])
|
||||
|
||||
|
||||
Returns:
|
||||
List of memory dicts with 'content', 'source', 'metadata', and 'distance' keys (empty list on failure)
|
||||
"""
|
||||
from brain.embeddings import get_embedder
|
||||
|
||||
|
||||
embedder = get_embedder()
|
||||
query_emb = embedder.encode_single(query)
|
||||
|
||||
|
||||
# rqlite with sqlite-vec extension for vector search
|
||||
sql = "SELECT content, source, metadata, distance FROM memories WHERE embedding MATCH ?"
|
||||
params = [query_emb]
|
||||
|
||||
|
||||
if sources:
|
||||
placeholders = ",".join(["?"] * len(sources))
|
||||
sql += f" AND source IN ({placeholders})"
|
||||
params.extend(sources)
|
||||
|
||||
|
||||
sql += " ORDER BY distance LIMIT ?"
|
||||
params.append(limit)
|
||||
|
||||
|
||||
try:
|
||||
resp = await self._client.post(
|
||||
f"{self.rqlite_url}/db/query",
|
||||
json=[sql, params]
|
||||
)
|
||||
resp = await self._client.post(f"{self.rqlite_url}/db/query", json=[sql, params])
|
||||
resp.raise_for_status()
|
||||
result = resp.json()
|
||||
|
||||
|
||||
results = []
|
||||
if "results" in result and result["results"]:
|
||||
for row in result["results"][0].get("rows", []):
|
||||
results.append({
|
||||
"content": row[0],
|
||||
"source": row[1],
|
||||
"metadata": json.loads(row[2]) if row[2] else {},
|
||||
"distance": row[3]
|
||||
})
|
||||
|
||||
results.append(
|
||||
{
|
||||
"content": row[0],
|
||||
"source": row[1],
|
||||
"metadata": json.loads(row[2]) if row[2] else {},
|
||||
"distance": row[3],
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to search memories: {e}")
|
||||
# Graceful fallback - return empty list
|
||||
return []
|
||||
|
||||
|
||||
async def get_recent(
|
||||
self,
|
||||
hours: int = 24,
|
||||
limit: int = 20,
|
||||
sources: Optional[List[str]] = None
|
||||
self, hours: int = 24, limit: int = 20, sources: Optional[List[str]] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get recent memories by time.
|
||||
|
||||
|
||||
Args:
|
||||
hours: Look back this many hours
|
||||
limit: Max results
|
||||
sources: Optional source filter
|
||||
|
||||
|
||||
Returns:
|
||||
List of memory dicts
|
||||
"""
|
||||
@@ -180,84 +172,83 @@ class BrainClient:
|
||||
WHERE created_at > datetime('now', ?)
|
||||
"""
|
||||
params = [f"-{hours} hours"]
|
||||
|
||||
|
||||
if sources:
|
||||
placeholders = ",".join(["?"] * len(sources))
|
||||
sql += f" AND source IN ({placeholders})"
|
||||
params.extend(sources)
|
||||
|
||||
|
||||
sql += " ORDER BY created_at DESC LIMIT ?"
|
||||
params.append(limit)
|
||||
|
||||
|
||||
try:
|
||||
resp = await self._client.post(
|
||||
f"{self.rqlite_url}/db/query",
|
||||
json=[sql, params]
|
||||
)
|
||||
resp = await self._client.post(f"{self.rqlite_url}/db/query", json=[sql, params])
|
||||
resp.raise_for_status()
|
||||
result = resp.json()
|
||||
|
||||
|
||||
memories = []
|
||||
if "results" in result and result["results"]:
|
||||
for row in result["results"][0].get("rows", []):
|
||||
memories.append({
|
||||
"id": row[0],
|
||||
"content": row[1],
|
||||
"source": row[2],
|
||||
"tags": json.loads(row[3]) if row[3] else [],
|
||||
"metadata": json.loads(row[4]) if row[4] else {},
|
||||
"created_at": row[5]
|
||||
})
|
||||
|
||||
memories.append(
|
||||
{
|
||||
"id": row[0],
|
||||
"content": row[1],
|
||||
"source": row[2],
|
||||
"tags": json.loads(row[3]) if row[3] else [],
|
||||
"metadata": json.loads(row[4]) if row[4] else {},
|
||||
"created_at": row[5],
|
||||
}
|
||||
)
|
||||
|
||||
return memories
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get recent memories: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def get_context(self, query: str) -> str:
|
||||
"""Get formatted context for system prompt.
|
||||
|
||||
|
||||
Combines recent memories + relevant memories.
|
||||
|
||||
|
||||
Args:
|
||||
query: Current user query to find relevant context
|
||||
|
||||
|
||||
Returns:
|
||||
Formatted context string for prompt injection
|
||||
"""
|
||||
recent = await self.get_recent(hours=24, limit=10)
|
||||
relevant = await self.recall(query, limit=5)
|
||||
|
||||
|
||||
lines = ["Recent activity:"]
|
||||
for m in recent[:5]:
|
||||
lines.append(f"- {m['content'][:100]}")
|
||||
|
||||
|
||||
lines.append("\nRelevant memories:")
|
||||
for r in relevant[:5]:
|
||||
lines.append(f"- {r['content'][:100]}")
|
||||
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Task Queue Operations
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def submit_task(
|
||||
self,
|
||||
content: str,
|
||||
task_type: str = "general",
|
||||
priority: int = 0,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Submit a task to the distributed queue.
|
||||
|
||||
|
||||
Args:
|
||||
content: Task description/prompt
|
||||
task_type: Type of task (shell, creative, code, research, general)
|
||||
priority: Higher = processed first
|
||||
metadata: Additional task data
|
||||
|
||||
|
||||
Returns:
|
||||
Dict with task 'id'
|
||||
"""
|
||||
@@ -270,50 +261,45 @@ class BrainClient:
|
||||
task_type,
|
||||
priority,
|
||||
json.dumps(metadata or {}),
|
||||
datetime.utcnow().isoformat()
|
||||
datetime.utcnow().isoformat(),
|
||||
]
|
||||
|
||||
|
||||
try:
|
||||
resp = await self._client.post(
|
||||
f"{self.rqlite_url}/db/execute",
|
||||
json=[query, params]
|
||||
)
|
||||
resp = await self._client.post(f"{self.rqlite_url}/db/execute", json=[query, params])
|
||||
resp.raise_for_status()
|
||||
result = resp.json()
|
||||
|
||||
|
||||
last_id = None
|
||||
if "results" in result and result["results"]:
|
||||
last_id = result["results"][0].get("last_insert_id")
|
||||
|
||||
|
||||
logger.info(f"Submitted task {last_id}: {content[:50]}...")
|
||||
return {"id": last_id, "status": "queued"}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to submit task: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def claim_task(
|
||||
self,
|
||||
capabilities: List[str],
|
||||
node_id: Optional[str] = None
|
||||
self, capabilities: List[str], node_id: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Atomically claim next available task.
|
||||
|
||||
|
||||
Uses UPDATE ... RETURNING pattern for atomic claim.
|
||||
|
||||
|
||||
Args:
|
||||
capabilities: List of capabilities this node has
|
||||
node_id: Identifier for claiming node
|
||||
|
||||
|
||||
Returns:
|
||||
Task dict or None if no tasks available
|
||||
"""
|
||||
claimer = node_id or self.node_id
|
||||
|
||||
|
||||
# Try to claim a matching task atomically
|
||||
# This works because rqlite uses Raft consensus - only one node wins
|
||||
placeholders = ",".join(["?"] * len(capabilities))
|
||||
|
||||
|
||||
query = f"""
|
||||
UPDATE tasks
|
||||
SET status = 'claimed',
|
||||
@@ -330,15 +316,12 @@ class BrainClient:
|
||||
RETURNING id, content, task_type, priority, metadata
|
||||
"""
|
||||
params = [claimer, datetime.utcnow().isoformat()] + capabilities
|
||||
|
||||
|
||||
try:
|
||||
resp = await self._client.post(
|
||||
f"{self.rqlite_url}/db/execute",
|
||||
json=[query, params]
|
||||
)
|
||||
resp = await self._client.post(f"{self.rqlite_url}/db/execute", json=[query, params])
|
||||
resp.raise_for_status()
|
||||
result = resp.json()
|
||||
|
||||
|
||||
if "results" in result and result["results"]:
|
||||
rows = result["results"][0].get("rows", [])
|
||||
if rows:
|
||||
@@ -348,24 +331,20 @@ class BrainClient:
|
||||
"content": row[1],
|
||||
"type": row[2],
|
||||
"priority": row[3],
|
||||
"metadata": json.loads(row[4]) if row[4] else {}
|
||||
"metadata": json.loads(row[4]) if row[4] else {},
|
||||
}
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to claim task: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def complete_task(
|
||||
self,
|
||||
task_id: int,
|
||||
success: bool,
|
||||
result: Optional[str] = None,
|
||||
error: Optional[str] = None
|
||||
self, task_id: int, success: bool, result: Optional[str] = None, error: Optional[str] = None
|
||||
) -> None:
|
||||
"""Mark task as completed or failed.
|
||||
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
success: True if task succeeded
|
||||
@@ -373,7 +352,7 @@ class BrainClient:
|
||||
error: Error message if failed
|
||||
"""
|
||||
status = "done" if success else "failed"
|
||||
|
||||
|
||||
query = """
|
||||
UPDATE tasks
|
||||
SET status = ?,
|
||||
@@ -383,23 +362,20 @@ class BrainClient:
|
||||
WHERE id = ?
|
||||
"""
|
||||
params = [status, result, error, datetime.utcnow().isoformat(), task_id]
|
||||
|
||||
|
||||
try:
|
||||
await self._client.post(
|
||||
f"{self.rqlite_url}/db/execute",
|
||||
json=[query, params]
|
||||
)
|
||||
await self._client.post(f"{self.rqlite_url}/db/execute", json=[query, params])
|
||||
logger.debug(f"Task {task_id} marked {status}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to complete task {task_id}: {e}")
|
||||
|
||||
|
||||
async def get_pending_tasks(self, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""Get list of pending tasks (for dashboard/monitoring).
|
||||
|
||||
|
||||
Args:
|
||||
limit: Max tasks to return
|
||||
|
||||
|
||||
Returns:
|
||||
List of pending task dicts
|
||||
"""
|
||||
@@ -410,33 +386,32 @@ class BrainClient:
|
||||
ORDER BY priority DESC, created_at ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
|
||||
try:
|
||||
resp = await self._client.post(
|
||||
f"{self.rqlite_url}/db/query",
|
||||
json=[sql, [limit]]
|
||||
)
|
||||
resp = await self._client.post(f"{self.rqlite_url}/db/query", json=[sql, [limit]])
|
||||
resp.raise_for_status()
|
||||
result = resp.json()
|
||||
|
||||
|
||||
tasks = []
|
||||
if "results" in result and result["results"]:
|
||||
for row in result["results"][0].get("rows", []):
|
||||
tasks.append({
|
||||
"id": row[0],
|
||||
"content": row[1],
|
||||
"type": row[2],
|
||||
"priority": row[3],
|
||||
"metadata": json.loads(row[4]) if row[4] else {},
|
||||
"created_at": row[5]
|
||||
})
|
||||
|
||||
tasks.append(
|
||||
{
|
||||
"id": row[0],
|
||||
"content": row[1],
|
||||
"type": row[2],
|
||||
"priority": row[3],
|
||||
"metadata": json.loads(row[4]) if row[4] else {},
|
||||
"created_at": row[5],
|
||||
}
|
||||
)
|
||||
|
||||
return tasks
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get pending tasks: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def close(self):
|
||||
"""Close HTTP client."""
|
||||
await self._client.aclose()
|
||||
|
||||
@@ -18,48 +18,51 @@ _dimensions = 384
|
||||
|
||||
class LocalEmbedder:
|
||||
"""Local sentence transformer for embeddings.
|
||||
|
||||
|
||||
Uses all-MiniLM-L6-v2 (80MB download, runs on CPU).
|
||||
384-dimensional embeddings, good enough for semantic search.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, model_name: str = _model_name):
|
||||
self.model_name = model_name
|
||||
self._model = None
|
||||
self._dimensions = _dimensions
|
||||
|
||||
|
||||
def _load_model(self):
|
||||
"""Lazy load the model."""
|
||||
global _model
|
||||
if _model is not None:
|
||||
self._model = _model
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
logger.info(f"Loading embedding model: {self.model_name}")
|
||||
_model = SentenceTransformer(self.model_name)
|
||||
self._model = _model
|
||||
logger.info(f"Embedding model loaded ({self._dimensions} dims)")
|
||||
except ImportError:
|
||||
logger.error("sentence-transformers not installed. Run: pip install sentence-transformers")
|
||||
logger.error(
|
||||
"sentence-transformers not installed. Run: pip install sentence-transformers"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def encode(self, text: Union[str, List[str]]):
|
||||
"""Encode text to embedding vector(s).
|
||||
|
||||
|
||||
Args:
|
||||
text: String or list of strings to encode
|
||||
|
||||
|
||||
Returns:
|
||||
Numpy array of shape (dims,) for single string or (n, dims) for list
|
||||
"""
|
||||
if self._model is None:
|
||||
self._load_model()
|
||||
|
||||
|
||||
# Normalize embeddings for cosine similarity
|
||||
return self._model.encode(text, normalize_embeddings=True)
|
||||
|
||||
|
||||
def encode_single(self, text: str) -> bytes:
|
||||
"""Encode single text to bytes for SQLite storage.
|
||||
|
||||
@@ -67,17 +70,19 @@ class LocalEmbedder:
|
||||
Float32 bytes
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
embedding = self.encode(text)
|
||||
if len(embedding.shape) > 1:
|
||||
embedding = embedding[0]
|
||||
return embedding.astype(np.float32).tobytes()
|
||||
|
||||
|
||||
def similarity(self, a, b) -> float:
|
||||
"""Compute cosine similarity between two vectors.
|
||||
|
||||
Vectors should already be normalized from encode().
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
return float(np.dot(a, b))
|
||||
|
||||
|
||||
|
||||
@@ -48,6 +48,7 @@ _SCHEMA_VERSION = 1
|
||||
def _get_db_path() -> Path:
|
||||
"""Get the brain database path from env or default."""
|
||||
from config import settings
|
||||
|
||||
if settings.brain_db_path:
|
||||
return Path(settings.brain_db_path)
|
||||
return _DEFAULT_DB_PATH
|
||||
@@ -75,6 +76,7 @@ class UnifiedMemory:
|
||||
# Auto-detect: use rqlite if RQLITE_URL is set, otherwise local SQLite
|
||||
if use_rqlite is None:
|
||||
from config import settings as _settings
|
||||
|
||||
use_rqlite = bool(_settings.rqlite_url)
|
||||
self._use_rqlite = use_rqlite
|
||||
|
||||
@@ -107,10 +109,12 @@ class UnifiedMemory:
|
||||
"""Lazy-load the embedding model."""
|
||||
if self._embedder is None:
|
||||
from config import settings as _settings
|
||||
|
||||
if _settings.timmy_skip_embeddings:
|
||||
return None
|
||||
try:
|
||||
from brain.embeddings import LocalEmbedder
|
||||
|
||||
self._embedder = LocalEmbedder()
|
||||
except ImportError:
|
||||
logger.warning("sentence-transformers not available — semantic search disabled")
|
||||
@@ -125,6 +129,7 @@ class UnifiedMemory:
|
||||
"""Lazy-load the rqlite BrainClient."""
|
||||
if self._rqlite_client is None:
|
||||
from brain.client import BrainClient
|
||||
|
||||
self._rqlite_client = BrainClient()
|
||||
return self._rqlite_client
|
||||
|
||||
@@ -292,15 +297,17 @@ class UnifiedMemory:
|
||||
|
||||
results = []
|
||||
for score, row in scored[:limit]:
|
||||
results.append({
|
||||
"id": row["id"],
|
||||
"content": row["content"],
|
||||
"source": row["source"],
|
||||
"tags": json.loads(row["tags"]) if row["tags"] else [],
|
||||
"metadata": json.loads(row["metadata"]) if row["metadata"] else {},
|
||||
"score": score,
|
||||
"created_at": row["created_at"],
|
||||
})
|
||||
results.append(
|
||||
{
|
||||
"id": row["id"],
|
||||
"content": row["content"],
|
||||
"source": row["source"],
|
||||
"tags": json.loads(row["tags"]) if row["tags"] else [],
|
||||
"metadata": json.loads(row["metadata"]) if row["metadata"] else {},
|
||||
"score": score,
|
||||
"created_at": row["created_at"],
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
finally:
|
||||
|
||||
@@ -84,11 +84,13 @@ def get_migration_sql(from_version: int, to_version: int) -> str:
|
||||
"""Get SQL to migrate between versions."""
|
||||
if to_version <= from_version:
|
||||
return ""
|
||||
|
||||
|
||||
sql_parts = []
|
||||
for v in range(from_version + 1, to_version + 1):
|
||||
if v in MIGRATIONS:
|
||||
sql_parts.append(MIGRATIONS[v])
|
||||
sql_parts.append(f"UPDATE schema_version SET version = {v}, applied_at = datetime('now');")
|
||||
|
||||
sql_parts.append(
|
||||
f"UPDATE schema_version SET version = {v}, applied_at = datetime('now');"
|
||||
)
|
||||
|
||||
return "\n".join(sql_parts)
|
||||
|
||||
@@ -21,11 +21,11 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class DistributedWorker:
|
||||
"""Continuous task processor for the distributed brain.
|
||||
|
||||
|
||||
Runs on every device, claims tasks matching its capabilities,
|
||||
executes them immediately, stores results.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, brain_client: Optional[BrainClient] = None):
|
||||
self.brain = brain_client or BrainClient()
|
||||
self.node_id = f"{socket.gethostname()}-{os.getpid()}"
|
||||
@@ -33,30 +33,30 @@ class DistributedWorker:
|
||||
self.running = False
|
||||
self._handlers: Dict[str, Callable] = {}
|
||||
self._register_default_handlers()
|
||||
|
||||
|
||||
def _detect_capabilities(self) -> List[str]:
|
||||
"""Detect what this node can do."""
|
||||
caps = ["general", "shell", "file_ops", "git"]
|
||||
|
||||
|
||||
# Check for GPU
|
||||
if self._has_gpu():
|
||||
caps.append("gpu")
|
||||
caps.append("creative")
|
||||
caps.append("image_gen")
|
||||
caps.append("video_gen")
|
||||
|
||||
|
||||
# Check for internet
|
||||
if self._has_internet():
|
||||
caps.append("web")
|
||||
caps.append("research")
|
||||
|
||||
|
||||
# Check memory
|
||||
mem_gb = self._get_memory_gb()
|
||||
if mem_gb > 16:
|
||||
caps.append("large_model")
|
||||
if mem_gb > 32:
|
||||
caps.append("huge_model")
|
||||
|
||||
|
||||
# Check for specific tools
|
||||
if self._has_command("ollama"):
|
||||
caps.append("ollama")
|
||||
@@ -64,17 +64,15 @@ class DistributedWorker:
|
||||
caps.append("docker")
|
||||
if self._has_command("cargo"):
|
||||
caps.append("rust")
|
||||
|
||||
|
||||
logger.info(f"Worker capabilities: {caps}")
|
||||
return caps
|
||||
|
||||
|
||||
def _has_gpu(self) -> bool:
|
||||
"""Check for an NVIDIA (nvidia-smi), AMD (ROCm), or Apple Metal GPU."""
|
||||
try:
|
||||
# Check for nvidia-smi
|
||||
result = subprocess.run(
|
||||
["nvidia-smi"], capture_output=True, timeout=5
|
||||
)
|
||||
result = subprocess.run(["nvidia-smi"], capture_output=True, timeout=5)
|
||||
if result.returncode == 0:
|
||||
return True
|
||||
except (OSError, subprocess.SubprocessError):
|
||||
@@ -83,13 +81,15 @@ class DistributedWorker:
|
||||
# Check for ROCm
|
||||
if os.path.exists("/opt/rocm"):
|
||||
return True
|
||||
|
||||
|
||||
# Check for Apple Silicon Metal
|
||||
if os.uname().sysname == "Darwin":
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["system_profiler", "SPDisplaysDataType"],
|
||||
capture_output=True, text=True, timeout=5
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
if "Metal" in result.stdout:
|
||||
return True
|
||||
@@ -102,8 +102,7 @@ class DistributedWorker:
|
||||
"""Check if we have internet connectivity."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["curl", "-s", "--max-time", "3", "https://1.1.1.1"],
|
||||
capture_output=True, timeout=5
|
||||
["curl", "-s", "--max-time", "3", "https://1.1.1.1"], capture_output=True, timeout=5
|
||||
)
|
||||
return result.returncode == 0
|
||||
except (OSError, subprocess.SubprocessError):
|
||||
@@ -114,8 +113,7 @@ class DistributedWorker:
|
||||
try:
|
||||
if os.uname().sysname == "Darwin":
|
||||
result = subprocess.run(
|
||||
["sysctl", "-n", "hw.memsize"],
|
||||
capture_output=True, text=True
|
||||
["sysctl", "-n", "hw.memsize"], capture_output=True, text=True
|
||||
)
|
||||
bytes_mem = int(result.stdout.strip())
|
||||
return bytes_mem / (1024**3)
|
||||
@@ -128,13 +126,11 @@ class DistributedWorker:
|
||||
except (OSError, ValueError):
|
||||
pass
|
||||
return 8.0 # Assume 8GB if we can't detect
|
||||
|
||||
|
||||
def _has_command(self, cmd: str) -> bool:
|
||||
"""Check if command exists."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["which", cmd], capture_output=True, timeout=5
|
||||
)
|
||||
result = subprocess.run(["which", cmd], capture_output=True, timeout=5)
|
||||
return result.returncode == 0
|
||||
except (OSError, subprocess.SubprocessError):
|
||||
return False
|
||||
@@ -148,10 +144,10 @@ class DistributedWorker:
|
||||
"research": self._handle_research,
|
||||
"general": self._handle_general,
|
||||
}
|
||||
|
||||
|
||||
def register_handler(self, task_type: str, handler: Callable[[str], Any]):
|
||||
"""Register a custom task handler.
|
||||
|
||||
|
||||
Args:
|
||||
task_type: Type of task this handler handles
|
||||
handler: Async function that takes task content and returns result
|
||||
@@ -159,11 +155,11 @@ class DistributedWorker:
|
||||
self._handlers[task_type] = handler
|
||||
if task_type not in self.capabilities:
|
||||
self.capabilities.append(task_type)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Task Handlers
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _handle_shell(self, command: str) -> str:
|
||||
"""Execute shell command via ZeroClaw or direct subprocess."""
|
||||
# Try ZeroClaw first if available
|
||||
@@ -171,156 +167,153 @@ class DistributedWorker:
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
f"zeroclaw exec --json '{command}'",
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await proc.communicate()
|
||||
|
||||
|
||||
# Store result in brain
|
||||
await self.brain.remember(
|
||||
content=f"Shell: {command}\nOutput: {stdout.decode()}",
|
||||
tags=["shell", "result"],
|
||||
source=self.node_id,
|
||||
metadata={"command": command, "exit_code": proc.returncode}
|
||||
metadata={"command": command, "exit_code": proc.returncode},
|
||||
)
|
||||
|
||||
|
||||
if proc.returncode != 0:
|
||||
raise Exception(f"Command failed: {stderr.decode()}")
|
||||
return stdout.decode()
|
||||
|
||||
|
||||
# Fallback to direct subprocess (less safe)
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await proc.communicate()
|
||||
|
||||
|
||||
if proc.returncode != 0:
|
||||
raise Exception(f"Command failed: {stderr.decode()}")
|
||||
return stdout.decode()
|
||||
|
||||
|
||||
async def _handle_creative(self, prompt: str) -> str:
|
||||
"""Generate creative media (requires GPU)."""
|
||||
if "gpu" not in self.capabilities:
|
||||
raise Exception("GPU not available on this node")
|
||||
|
||||
|
||||
# This would call creative tools (Stable Diffusion, etc.)
|
||||
# For now, placeholder
|
||||
logger.info(f"Creative task: {prompt[:50]}...")
|
||||
|
||||
|
||||
# Store result
|
||||
result = f"Creative output for: {prompt}"
|
||||
await self.brain.remember(
|
||||
content=result,
|
||||
tags=["creative", "generated"],
|
||||
source=self.node_id,
|
||||
metadata={"prompt": prompt}
|
||||
metadata={"prompt": prompt},
|
||||
)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _handle_code(self, description: str) -> str:
    """Handle a code generation/modification task (placeholder).

    An LLM would generate the code here; for now the handler only logs the
    first 50 characters of the description and returns a canned string.

    Args:
        description: Text describing the code to produce.

    Returns:
        A placeholder string that embeds the description.
    """
    logger.info(f"Code task: {description[:50]}...")
    placeholder = f"Code generated for: {description}"
    return placeholder
|
||||
|
||||
|
||||
async def _handle_research(self, query: str) -> str:
|
||||
"""Web research."""
|
||||
if "web" not in self.capabilities:
|
||||
raise Exception("Internet not available on this node")
|
||||
|
||||
|
||||
# Would use browser automation or search
|
||||
logger.info(f"Research task: {query[:50]}...")
|
||||
return f"Research results for: {query}"
|
||||
|
||||
|
||||
async def _handle_general(self, prompt: str) -> str:
|
||||
"""General LLM task via local Ollama."""
|
||||
if "ollama" not in self.capabilities:
|
||||
raise Exception("Ollama not available on this node")
|
||||
|
||||
|
||||
# Call Ollama
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"curl", "-s", "http://localhost:11434/api/generate",
|
||||
"-d", json.dumps({
|
||||
"model": "llama3.1:8b-instruct",
|
||||
"prompt": prompt,
|
||||
"stream": False
|
||||
}),
|
||||
stdout=asyncio.subprocess.PIPE
|
||||
"curl",
|
||||
"-s",
|
||||
"http://localhost:11434/api/generate",
|
||||
"-d",
|
||||
json.dumps({"model": "llama3.1:8b-instruct", "prompt": prompt, "stream": False}),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, _ = await proc.communicate()
|
||||
|
||||
|
||||
response = json.loads(stdout.decode())
|
||||
result = response.get("response", "No response")
|
||||
|
||||
|
||||
# Store in brain
|
||||
await self.brain.remember(
|
||||
content=f"Task: {prompt}\nResult: {result}",
|
||||
tags=["llm", "result"],
|
||||
source=self.node_id,
|
||||
metadata={"model": "llama3.1:8b-instruct"}
|
||||
metadata={"model": "llama3.1:8b-instruct"},
|
||||
)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"LLM failed: {e}")
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# Main Loop
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def execute_task(self, task: Dict[str, Any]) -> Dict[str, Any]:
    """Run one claimed task through its handler and report the outcome.

    The handler is picked from ``self._handlers`` by the task's "type"
    (default "general"); unknown types fall back to the "general" handler.
    Success or failure is reported through ``self.brain.complete_task``.

    Args:
        task: Task mapping; the keys read are "type", "content" and "id".

    Returns:
        ``{"success": True, "result": ...}`` when the handler succeeds,
        otherwise ``{"success": False, "error": <message>}``.
    """
    task_id = task.get("id")
    task_type = task.get("type", "general")
    payload = task.get("content", "")

    fallback = self._handlers["general"]
    run_handler = self._handlers.get(task_type, fallback)

    try:
        logger.info(f"Executing task {task_id}: {task_type}")
        outcome = await run_handler(payload)

        await self.brain.complete_task(task_id, success=True, result=outcome)
        logger.info(f"Task {task_id} completed")
    except Exception as exc:
        error_msg = str(exc)
        logger.error(f"Task {task_id} failed: {error_msg}")
        await self.brain.complete_task(task_id, success=False, error=error_msg)
        return {"success": False, "error": error_msg}
    else:
        return {"success": True, "result": outcome}
|
||||
|
||||
|
||||
async def run_once(self) -> bool:
    """Claim and execute at most one task.

    Asks ``self.brain.claim_task`` for work matching this node's
    capabilities and, if a task comes back, runs it via ``execute_task``.

    Returns:
        True if a task was processed, False if nothing was available.
    """
    claimed = await self.brain.claim_task(self.capabilities, self.node_id)

    if not claimed:
        return False

    await self.execute_task(claimed)
    return True
|
||||
|
||||
|
||||
async def run(self):
|
||||
"""Main loop — continuously process tasks."""
|
||||
logger.info(f"Worker {self.node_id} started")
|
||||
logger.info(f"Capabilities: {self.capabilities}")
|
||||
|
||||
|
||||
self.running = True
|
||||
consecutive_empty = 0
|
||||
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
had_work = await self.run_once()
|
||||
|
||||
|
||||
if had_work:
|
||||
# Immediately check for more work
|
||||
consecutive_empty = 0
|
||||
@@ -331,11 +324,11 @@ class DistributedWorker:
|
||||
# Sleep 0.5s, but up to 2s if consistently empty
|
||||
sleep_time = min(0.5 + (consecutive_empty * 0.1), 2.0)
|
||||
await asyncio.sleep(sleep_time)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Worker error: {e}")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
def stop(self):
    """Ask the worker loop to exit by clearing the ``running`` flag."""
    self.running = False
|
||||
@@ -345,7 +338,7 @@ class DistributedWorker:
|
||||
async def main():
|
||||
"""CLI entry point for worker."""
|
||||
import sys
|
||||
|
||||
|
||||
# Allow capability overrides from CLI
|
||||
if len(sys.argv) > 1:
|
||||
caps = sys.argv[1].split(",")
|
||||
@@ -354,12 +347,12 @@ async def main():
|
||||
logger.info(f"Overriding capabilities: {caps}")
|
||||
else:
|
||||
worker = DistributedWorker()
|
||||
|
||||
|
||||
try:
|
||||
await worker.run()
|
||||
except KeyboardInterrupt:
|
||||
worker.stop()
|
||||
print("\nWorker stopped.")
|
||||
logger.info("Worker stopped.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -213,6 +213,15 @@ class Settings(BaseSettings):
|
||||
# Timeout in seconds for OpenFang hand execution (some hands are slow).
|
||||
openfang_timeout: int = 120
|
||||
|
||||
# ── Autoresearch — autonomous ML experiment loops ──────────────────
|
||||
# Integrates Karpathy's autoresearch pattern: agents modify training
|
||||
# code, run time-boxed experiments, evaluate metrics, and iterate.
|
||||
autoresearch_enabled: bool = False
|
||||
autoresearch_workspace: str = "data/experiments"
|
||||
autoresearch_time_budget: int = 300 # seconds per experiment run
|
||||
autoresearch_max_iterations: int = 100
|
||||
autoresearch_metric: str = "val_bpb" # metric to optimise (lower = better)
|
||||
|
||||
# ── Local Hands (Shell + Git) ──────────────────────────────────────
|
||||
# Enable local shell/git execution hands.
|
||||
hands_shell_enabled: bool = True
|
||||
|
||||
@@ -18,36 +18,38 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from config import settings
|
||||
from dashboard.routes.agents import router as agents_router
|
||||
from dashboard.routes.health import router as health_router
|
||||
from dashboard.routes.marketplace import router as marketplace_router
|
||||
from dashboard.routes.voice import router as voice_router
|
||||
from dashboard.routes.mobile import router as mobile_router
|
||||
from dashboard.routes.briefing import router as briefing_router
|
||||
from dashboard.routes.telegram import router as telegram_router
|
||||
from dashboard.routes.tools import router as tools_router
|
||||
from dashboard.routes.spark import router as spark_router
|
||||
from dashboard.routes.discord import router as discord_router
|
||||
from dashboard.routes.memory import router as memory_router
|
||||
from dashboard.routes.router import router as router_status_router
|
||||
from dashboard.routes.grok import router as grok_router
|
||||
from dashboard.routes.models import router as models_router
|
||||
from dashboard.routes.models import api_router as models_api_router
|
||||
from dashboard.routes.chat_api import router as chat_api_router
|
||||
from dashboard.routes.thinking import router as thinking_router
|
||||
from dashboard.routes.calm import router as calm_router
|
||||
from dashboard.routes.swarm import router as swarm_router
|
||||
from dashboard.routes.tasks import router as tasks_router
|
||||
from dashboard.routes.work_orders import router as work_orders_router
|
||||
from dashboard.routes.system import router as system_router
|
||||
from dashboard.routes.paperclip import router as paperclip_router
|
||||
from infrastructure.router.api import router as cascade_router
|
||||
|
||||
# Import dedicated middleware
|
||||
from dashboard.middleware.csrf import CSRFMiddleware
|
||||
from dashboard.middleware.request_logging import RequestLoggingMiddleware
|
||||
from dashboard.middleware.security_headers import SecurityHeadersMiddleware
|
||||
from dashboard.routes.agents import router as agents_router
|
||||
from dashboard.routes.briefing import router as briefing_router
|
||||
from dashboard.routes.calm import router as calm_router
|
||||
from dashboard.routes.chat_api import router as chat_api_router
|
||||
from dashboard.routes.discord import router as discord_router
|
||||
from dashboard.routes.experiments import router as experiments_router
|
||||
from dashboard.routes.grok import router as grok_router
|
||||
from dashboard.routes.health import router as health_router
|
||||
from dashboard.routes.marketplace import router as marketplace_router
|
||||
from dashboard.routes.memory import router as memory_router
|
||||
from dashboard.routes.mobile import router as mobile_router
|
||||
from dashboard.routes.models import api_router as models_api_router
|
||||
from dashboard.routes.models import router as models_router
|
||||
from dashboard.routes.paperclip import router as paperclip_router
|
||||
from dashboard.routes.router import router as router_status_router
|
||||
from dashboard.routes.spark import router as spark_router
|
||||
from dashboard.routes.swarm import router as swarm_router
|
||||
from dashboard.routes.system import router as system_router
|
||||
from dashboard.routes.tasks import router as tasks_router
|
||||
from dashboard.routes.telegram import router as telegram_router
|
||||
from dashboard.routes.thinking import router as thinking_router
|
||||
from dashboard.routes.tools import router as tools_router
|
||||
from dashboard.routes.voice import router as voice_router
|
||||
from dashboard.routes.work_orders import router as work_orders_router
|
||||
from infrastructure.router.api import router as cascade_router
|
||||
|
||||
|
||||
def _configure_logging() -> None:
|
||||
@@ -100,8 +102,8 @@ _BRIEFING_INTERVAL_HOURS = 6
|
||||
|
||||
async def _briefing_scheduler() -> None:
|
||||
"""Background task: regenerate Timmy's briefing every 6 hours."""
|
||||
from timmy.briefing import engine as briefing_engine
|
||||
from infrastructure.notifications.push import notify_briefing_ready
|
||||
from timmy.briefing import engine as briefing_engine
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
@@ -121,9 +123,9 @@ async def _briefing_scheduler() -> None:
|
||||
|
||||
async def _start_chat_integrations_background() -> None:
|
||||
"""Background task: start chat integrations without blocking startup."""
|
||||
from integrations.telegram_bot.bot import telegram_bot
|
||||
from integrations.chat_bridge.vendors.discord import discord_bot
|
||||
from integrations.chat_bridge.registry import platform_registry
|
||||
from integrations.chat_bridge.vendors.discord import discord_bot
|
||||
from integrations.telegram_bot.bot import telegram_bot
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
@@ -164,9 +166,9 @@ async def _discord_token_watcher() -> None:
|
||||
if discord_bot.state.name == "CONNECTED":
|
||||
return # Already running — stop watching
|
||||
|
||||
# 1. Check live environment variable (intentionally uses os.environ,
|
||||
# not settings, because this polls for runtime hot-reload changes)
|
||||
token = os.environ.get("DISCORD_TOKEN", "")
|
||||
# 1. Check settings (pydantic-settings reads env on instantiation;
|
||||
# hot-reload is handled by re-reading .env below)
|
||||
token = settings.discord_token
|
||||
|
||||
# 2. Re-read .env file for hot-reload
|
||||
if not token:
|
||||
@@ -203,6 +205,7 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
# Initialize Spark Intelligence engine
|
||||
from spark.engine import spark_engine
|
||||
|
||||
if spark_engine.enabled:
|
||||
logger.info("Spark Intelligence active — event capture enabled")
|
||||
|
||||
@@ -210,12 +213,17 @@ async def lifespan(app: FastAPI):
|
||||
if settings.memory_prune_days > 0:
|
||||
try:
|
||||
from timmy.memory.vector_store import prune_memories
|
||||
|
||||
pruned = prune_memories(
|
||||
older_than_days=settings.memory_prune_days,
|
||||
keep_facts=settings.memory_prune_keep_facts,
|
||||
)
|
||||
if pruned:
|
||||
logger.info("Memory auto-prune: removed %d entries older than %d days", pruned, settings.memory_prune_days)
|
||||
logger.info(
|
||||
"Memory auto-prune: removed %d entries older than %d days",
|
||||
pruned,
|
||||
settings.memory_prune_days,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("Memory auto-prune skipped: %s", exc)
|
||||
|
||||
@@ -229,7 +237,8 @@ async def lifespan(app: FastAPI):
|
||||
if total_mb > settings.memory_vault_max_mb:
|
||||
logger.warning(
|
||||
"Memory vault (%.1f MB) exceeds limit (%d MB) — consider archiving old notes",
|
||||
total_mb, settings.memory_vault_max_mb,
|
||||
total_mb,
|
||||
settings.memory_vault_max_mb,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("Vault size check skipped: %s", exc)
|
||||
@@ -284,10 +293,7 @@ def _get_cors_origins() -> list[str]:
|
||||
app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health"])
|
||||
|
||||
# 2. Security Headers
|
||||
app.add_middleware(
|
||||
SecurityHeadersMiddleware,
|
||||
production=not settings.debug
|
||||
)
|
||||
app.add_middleware(SecurityHeadersMiddleware, production=not settings.debug)
|
||||
|
||||
# 3. CSRF Protection
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
@@ -314,7 +320,6 @@ if static_dir.exists():
|
||||
# Shared templates instance
|
||||
from dashboard.templating import templates # noqa: E402
|
||||
|
||||
|
||||
# Include routers
|
||||
app.include_router(health_router)
|
||||
app.include_router(agents_router)
|
||||
@@ -339,6 +344,7 @@ app.include_router(tasks_router)
|
||||
app.include_router(work_orders_router)
|
||||
app.include_router(system_router)
|
||||
app.include_router(paperclip_router)
|
||||
app.include_router(experiments_router)
|
||||
app.include_router(cascade_router)
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Dashboard middleware package."""
|
||||
|
||||
from .csrf import CSRFMiddleware, csrf_exempt, generate_csrf_token, validate_csrf_token
|
||||
from .security_headers import SecurityHeadersMiddleware
|
||||
from .request_logging import RequestLoggingMiddleware
|
||||
from .security_headers import SecurityHeadersMiddleware
|
||||
|
||||
__all__ = [
|
||||
"CSRFMiddleware",
|
||||
|
||||
@@ -4,16 +4,15 @@ Provides CSRF token generation, validation, and middleware integration
|
||||
to protect state-changing endpoints from cross-site request attacks.
|
||||
"""
|
||||
|
||||
import secrets
|
||||
import hmac
|
||||
import hashlib
|
||||
from typing import Callable, Optional
|
||||
import hmac
|
||||
import secrets
|
||||
from functools import wraps
|
||||
from typing import Callable, Optional
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response, JSONResponse
|
||||
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
# Module-level set to track exempt routes
|
||||
_exempt_routes: set[str] = set()
|
||||
@@ -21,26 +20,27 @@ _exempt_routes: set[str] = set()
|
||||
|
||||
def csrf_exempt(endpoint: Callable) -> Callable:
|
||||
"""Decorator to mark an endpoint as exempt from CSRF validation.
|
||||
|
||||
|
||||
Usage:
|
||||
@app.post("/webhook")
|
||||
@csrf_exempt
|
||||
def webhook_endpoint():
|
||||
...
|
||||
"""
|
||||
|
||||
@wraps(endpoint)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
return await endpoint(*args, **kwargs)
|
||||
|
||||
|
||||
@wraps(endpoint)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
return endpoint(*args, **kwargs)
|
||||
|
||||
|
||||
# Mark the original function as exempt
|
||||
endpoint._csrf_exempt = True # type: ignore
|
||||
|
||||
|
||||
# Also mark the wrapper
|
||||
if hasattr(endpoint, '__code__') and endpoint.__code__.co_flags & 0x80:
|
||||
if hasattr(endpoint, "__code__") and endpoint.__code__.co_flags & 0x80:
|
||||
async_wrapper._csrf_exempt = True # type: ignore
|
||||
return async_wrapper
|
||||
else:
|
||||
@@ -50,12 +50,12 @@ def csrf_exempt(endpoint: Callable) -> Callable:
|
||||
|
||||
def is_csrf_exempt(endpoint: Callable) -> bool:
    """Report whether *endpoint* carries the ``_csrf_exempt`` marker.

    Args:
        endpoint: The route handler to inspect.

    Returns:
        The value of the endpoint's ``_csrf_exempt`` attribute, or False
        when the attribute is absent.
    """
    exempt_flag = getattr(endpoint, "_csrf_exempt", False)
    return exempt_flag
|
||||
|
||||
|
||||
def generate_csrf_token() -> str:
|
||||
"""Generate a cryptographically secure CSRF token.
|
||||
|
||||
|
||||
Returns:
|
||||
A secure random token string.
|
||||
"""
|
||||
@@ -64,77 +64,78 @@ def generate_csrf_token() -> str:
|
||||
|
||||
def validate_csrf_token(token: str, expected_token: str) -> bool:
    """Validate a CSRF token against the expected token.

    Uses constant-time comparison to prevent timing attacks. Both values
    are compared as bytes: ``hmac.compare_digest`` raises ``TypeError``
    for ``str`` arguments containing non-ASCII characters, which would let
    a crafted token crash the check instead of simply failing it.

    Args:
        token: The token provided by the client.
        expected_token: The expected token (from cookie/session).

    Returns:
        True if the token is valid, False otherwise (including when either
        value is empty or is not text/bytes).
    """
    if not token or not expected_token:
        return False

    encoded = []
    for value in (token, expected_token):
        if isinstance(value, str):
            # surrogatepass: never raise on oddly decoded client input.
            value = value.encode("utf-8", "surrogatepass")
        elif not isinstance(value, (bytes, bytearray)):
            # e.g. an uploaded file object submitted under the token field
            return False
        encoded.append(bytes(value))

    return hmac.compare_digest(encoded[0], encoded[1])
|
||||
|
||||
|
||||
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to enforce CSRF protection on state-changing requests.
|
||||
|
||||
|
||||
Safe methods (GET, HEAD, OPTIONS, TRACE) are allowed without CSRF tokens.
|
||||
State-changing methods (POST, PUT, DELETE, PATCH) require a valid CSRF token.
|
||||
|
||||
|
||||
The token is expected to be:
|
||||
- In the X-CSRF-Token header, or
|
||||
- In the request body as 'csrf_token', or
|
||||
- Matching the token in the csrf_token cookie
|
||||
|
||||
|
||||
Usage:
|
||||
app.add_middleware(CSRFMiddleware, secret="your-secret-key")
|
||||
|
||||
|
||||
Attributes:
|
||||
secret: Secret key for token signing (optional, for future use).
|
||||
cookie_name: Name of the CSRF cookie.
|
||||
header_name: Name of the CSRF header.
|
||||
safe_methods: HTTP methods that don't require CSRF tokens.
|
||||
"""
|
||||
|
||||
|
||||
SAFE_METHODS = {"GET", "HEAD", "OPTIONS", "TRACE"}
|
||||
|
||||
|
||||
def __init__(
    self,
    app,
    secret: Optional[str] = None,
    cookie_name: str = "csrf_token",
    header_name: str = "X-CSRF-Token",
    form_field: str = "csrf_token",
):
    """Store the CSRF middleware configuration.

    Args:
        app: The ASGI application being wrapped.
        secret: Optional secret key; only stored here.
        cookie_name: Name of the cookie that carries the CSRF token.
        header_name: Name of the request header that carries the token.
        form_field: Name of the form field that carries the token.
    """
    super().__init__(app)
    self.secret, self.cookie_name = secret, cookie_name
    self.header_name, self.form_field = header_name, form_field
|
||||
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""Process the request and enforce CSRF protection.
|
||||
|
||||
|
||||
For safe methods: Set a CSRF token cookie if not present.
|
||||
For unsafe methods: Validate the CSRF token.
|
||||
"""
|
||||
# Bypass CSRF if explicitly disabled (e.g. in tests)
|
||||
from config import settings
|
||||
|
||||
if settings.timmy_disable_csrf:
|
||||
return await call_next(request)
|
||||
|
||||
# Get existing CSRF token from cookie
|
||||
csrf_cookie = request.cookies.get(self.cookie_name)
|
||||
|
||||
|
||||
# For safe methods, just ensure a token exists
|
||||
if request.method in self.SAFE_METHODS:
|
||||
response = await call_next(request)
|
||||
|
||||
|
||||
# Set CSRF token cookie if not present
|
||||
if not csrf_cookie:
|
||||
new_token = generate_csrf_token()
|
||||
@@ -144,15 +145,15 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
httponly=False, # Must be readable by JavaScript
|
||||
secure=settings.csrf_cookie_secure,
|
||||
samesite="Lax",
|
||||
max_age=86400 # 24 hours
|
||||
max_age=86400, # 24 hours
|
||||
)
|
||||
|
||||
|
||||
return response
|
||||
|
||||
|
||||
# For unsafe methods, check if route is exempt first
|
||||
# Note: We need to let the request proceed and check at response time
|
||||
# since FastAPI routes are resolved after middleware
|
||||
|
||||
|
||||
# Try to validate token early
|
||||
if not await self._validate_request(request, csrf_cookie):
|
||||
# Check if this might be an exempt route by checking path patterns
|
||||
@@ -164,33 +165,34 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
content={
|
||||
"error": "CSRF validation failed",
|
||||
"code": "CSRF_INVALID",
|
||||
"message": "Missing or invalid CSRF token. Include the token from the csrf_token cookie in the X-CSRF-Token header or as a form field."
|
||||
}
|
||||
"message": "Missing or invalid CSRF token. Include the token from the csrf_token cookie in the X-CSRF-Token header or as a form field.",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
def _is_likely_exempt(self, path: str) -> bool:
|
||||
"""Check if a path is likely to be CSRF exempt.
|
||||
|
||||
|
||||
Common patterns like webhooks, API endpoints, etc.
|
||||
Uses path normalization and exact/prefix matching to prevent bypasses.
|
||||
|
||||
|
||||
Args:
|
||||
path: The request path.
|
||||
|
||||
|
||||
Returns:
|
||||
True if the path is likely exempt.
|
||||
"""
|
||||
# 1. Normalize path to prevent /webhook/../ bypasses
|
||||
# Use posixpath for consistent behavior on all platforms
|
||||
import posixpath
|
||||
|
||||
normalized_path = posixpath.normpath(path)
|
||||
|
||||
|
||||
# Ensure it starts with / for comparison
|
||||
if not normalized_path.startswith("/"):
|
||||
normalized_path = "/" + normalized_path
|
||||
|
||||
|
||||
# Add back trailing slash if it was present in original path
|
||||
# to ensure prefix matching behaves as expected
|
||||
if path.endswith("/") and not normalized_path.endswith("/"):
|
||||
@@ -200,15 +202,15 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
# Patterns ending with / are prefix-matched
|
||||
# Patterns NOT ending with / are exact-matched
|
||||
exempt_patterns = [
|
||||
"/webhook/", # Prefix match (e.g., /webhook/stripe)
|
||||
"/webhook", # Exact match
|
||||
"/api/v1/", # Prefix match
|
||||
"/lightning/webhook/", # Prefix match
|
||||
"/webhook/", # Prefix match (e.g., /webhook/stripe)
|
||||
"/webhook", # Exact match
|
||||
"/api/v1/", # Prefix match
|
||||
"/lightning/webhook/", # Prefix match
|
||||
"/lightning/webhook", # Exact match
|
||||
"/_internal/", # Prefix match
|
||||
"/_internal", # Exact match
|
||||
"/_internal/", # Prefix match
|
||||
"/_internal", # Exact match
|
||||
]
|
||||
|
||||
|
||||
for pattern in exempt_patterns:
|
||||
if pattern.endswith("/"):
|
||||
if normalized_path.startswith(pattern):
|
||||
@@ -216,20 +218,20 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
else:
|
||||
if normalized_path == pattern:
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def _validate_request(self, request: Request, csrf_cookie: Optional[str]) -> bool:
|
||||
"""Validate the CSRF token in the request.
|
||||
|
||||
|
||||
Checks for token in:
|
||||
1. X-CSRF-Token header
|
||||
2. csrf_token form field
|
||||
|
||||
|
||||
Args:
|
||||
request: The incoming request.
|
||||
csrf_cookie: The expected token from the cookie.
|
||||
|
||||
|
||||
Returns:
|
||||
True if the token is valid, False otherwise.
|
||||
"""
|
||||
@@ -241,11 +243,14 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
header_token = request.headers.get(self.header_name)
|
||||
if header_token and validate_csrf_token(header_token, csrf_cookie):
|
||||
return True
|
||||
|
||||
|
||||
# If no header token, try form data (for non-JSON POSTs)
|
||||
# Check Content-Type to avoid hanging on non-form requests
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
if "application/x-www-form-urlencoded" in content_type or "multipart/form-data" in content_type:
|
||||
if (
|
||||
"application/x-www-form-urlencoded" in content_type
|
||||
or "multipart/form-data" in content_type
|
||||
):
|
||||
try:
|
||||
form_data = await request.form()
|
||||
form_token = form_data.get(self.form_field)
|
||||
@@ -254,5 +259,5 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
except Exception:
|
||||
# Error parsing form data, treat as invalid
|
||||
pass
|
||||
|
||||
|
||||
return False
|
||||
|
||||
@@ -4,22 +4,21 @@ Logs HTTP requests with timing, status codes, and client information
|
||||
for monitoring and debugging purposes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
from typing import List, Optional
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
|
||||
logger = logging.getLogger("timmy.requests")
|
||||
|
||||
|
||||
class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to log all HTTP requests.
|
||||
|
||||
|
||||
Logs the following information for each request:
|
||||
- HTTP method and path
|
||||
- Response status code
|
||||
@@ -27,60 +26,55 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
- Client IP address
|
||||
- User-Agent header
|
||||
- Correlation ID for tracing
|
||||
|
||||
|
||||
Usage:
|
||||
app.add_middleware(RequestLoggingMiddleware)
|
||||
|
||||
|
||||
# Skip certain paths:
|
||||
app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health", "/metrics"])
|
||||
|
||||
|
||||
Attributes:
|
||||
skip_paths: List of URL paths to skip logging.
|
||||
log_level: Logging level for successful requests.
|
||||
"""
|
||||
|
||||
def __init__(self, app, skip_paths: Optional[List[str]] = None, log_level: int = logging.INFO):
    """Store the request-logging configuration.

    Args:
        app: The ASGI application being wrapped.
        skip_paths: URL paths that should not be logged; kept as a set.
        log_level: Logging level for successful requests.
    """
    super().__init__(app)
    self.log_level = log_level
    self.skip_paths = set(skip_paths) if skip_paths else set()
|
||||
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""Log the request and response details.
|
||||
|
||||
|
||||
Args:
|
||||
request: The incoming request.
|
||||
call_next: Callable to get the response from downstream.
|
||||
|
||||
|
||||
Returns:
|
||||
The response from downstream.
|
||||
"""
|
||||
# Check if we should skip logging this path
|
||||
if request.url.path in self.skip_paths:
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
# Generate correlation ID
|
||||
correlation_id = str(uuid.uuid4())[:8]
|
||||
request.state.correlation_id = correlation_id
|
||||
|
||||
|
||||
# Record start time
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
# Get client info
|
||||
client_ip = self._get_client_ip(request)
|
||||
user_agent = request.headers.get("user-agent", "-")
|
||||
|
||||
|
||||
try:
|
||||
# Process the request
|
||||
response = await call_next(request)
|
||||
|
||||
|
||||
# Calculate duration
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
|
||||
# Log the request
|
||||
self._log_request(
|
||||
method=request.method,
|
||||
@@ -89,14 +83,14 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
duration_ms=duration_ms,
|
||||
client_ip=client_ip,
|
||||
user_agent=user_agent,
|
||||
correlation_id=correlation_id
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
|
||||
|
||||
# Add correlation ID to response headers
|
||||
response.headers["X-Correlation-ID"] = correlation_id
|
||||
|
||||
|
||||
return response
|
||||
|
||||
|
||||
except Exception as exc:
|
||||
# Calculate duration even for failed requests
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
@@ -110,6 +104,7 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
# Auto-escalate: create bug report task from unhandled exception
|
||||
try:
|
||||
from infrastructure.error_capture import capture_error
|
||||
|
||||
capture_error(
|
||||
exc,
|
||||
source="http",
|
||||
@@ -126,16 +121,16 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
# Re-raise the exception
|
||||
raise
|
||||
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""Extract the client IP address from the request.
|
||||
|
||||
|
||||
Checks X-Forwarded-For and X-Real-IP headers first for proxied requests,
|
||||
falls back to the direct client IP.
|
||||
|
||||
|
||||
Args:
|
||||
request: The incoming request.
|
||||
|
||||
|
||||
Returns:
|
||||
Client IP address string.
|
||||
"""
|
||||
@@ -144,17 +139,17 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
if forwarded_for:
|
||||
# X-Forwarded-For can contain multiple IPs, take the first one
|
||||
return forwarded_for.split(",")[0].strip()
|
||||
|
||||
|
||||
real_ip = request.headers.get("x-real-ip")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
|
||||
# Fall back to direct connection
|
||||
if request.client:
|
||||
return request.client.host
|
||||
|
||||
|
||||
return "-"
|
||||
|
||||
|
||||
def _log_request(
|
||||
self,
|
||||
method: str,
|
||||
@@ -163,10 +158,10 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
duration_ms: float,
|
||||
client_ip: str,
|
||||
user_agent: str,
|
||||
correlation_id: str
|
||||
correlation_id: str,
|
||||
) -> None:
|
||||
"""Format and log the request details.
|
||||
|
||||
|
||||
Args:
|
||||
method: HTTP method.
|
||||
path: Request path.
|
||||
@@ -182,14 +177,14 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
level = logging.ERROR
|
||||
elif status_code >= 400:
|
||||
level = logging.WARNING
|
||||
|
||||
|
||||
message = (
|
||||
f"[{correlation_id}] {method} {path} - {status_code} "
|
||||
f"- {duration_ms:.2f}ms - {client_ip}"
|
||||
)
|
||||
|
||||
|
||||
# Add user agent for non-health requests
|
||||
if path not in self.skip_paths:
|
||||
message += f" - {user_agent[:50]}"
|
||||
|
||||
|
||||
logger.log(level, message)
|
||||
|
||||
@@ -4,6 +4,8 @@ Adds common security headers to all HTTP responses to improve
|
||||
application security posture against various attacks.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
@@ -11,7 +13,7 @@ from starlette.responses import Response
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware to add security headers to all responses.
|
||||
|
||||
|
||||
Adds the following headers:
|
||||
- X-Content-Type-Options: Prevents MIME type sniffing
|
||||
- X-Frame-Options: Prevents clickjacking
|
||||
@@ -20,41 +22,41 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
- Permissions-Policy: Restricts feature access
|
||||
- Content-Security-Policy: Mitigates XSS and data injection
|
||||
- Strict-Transport-Security: Enforces HTTPS (production only)
|
||||
|
||||
|
||||
Usage:
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
|
||||
# Or with production settings:
|
||||
app.add_middleware(SecurityHeadersMiddleware, production=True)
|
||||
|
||||
|
||||
Attributes:
|
||||
production: If True, adds HSTS header for HTTPS enforcement.
|
||||
csp_report_only: If True, sends CSP in report-only mode.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app,
|
||||
production: bool = False,
|
||||
csp_report_only: bool = False,
|
||||
custom_csp: str = None
|
||||
custom_csp: Optional[str] = None,
|
||||
):
|
||||
super().__init__(app)
|
||||
self.production = production
|
||||
self.csp_report_only = csp_report_only
|
||||
|
||||
|
||||
# Build CSP directive
|
||||
self.csp_directive = custom_csp or self._build_csp()
|
||||
|
||||
|
||||
def _build_csp(self) -> str:
|
||||
"""Build the Content-Security-Policy directive.
|
||||
|
||||
|
||||
Creates a restrictive default policy that allows:
|
||||
- Same-origin resources by default
|
||||
- Inline scripts/styles (needed for HTMX/Bootstrap)
|
||||
- Data URIs for images
|
||||
- WebSocket connections
|
||||
|
||||
|
||||
Returns:
|
||||
CSP directive string.
|
||||
"""
|
||||
@@ -73,25 +75,25 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"form-action 'self'",
|
||||
]
|
||||
return "; ".join(directives)
|
||||
|
||||
|
||||
def _add_security_headers(self, response: Response) -> None:
|
||||
"""Add security headers to a response.
|
||||
|
||||
|
||||
Args:
|
||||
response: The response to add headers to.
|
||||
"""
|
||||
# Prevent MIME type sniffing
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
|
||||
|
||||
# Prevent clickjacking
|
||||
response.headers["X-Frame-Options"] = "SAMEORIGIN"
|
||||
|
||||
|
||||
# Enable XSS protection (legacy browsers)
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
|
||||
|
||||
# Control referrer information
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
|
||||
|
||||
# Restrict browser features
|
||||
response.headers["Permissions-Policy"] = (
|
||||
"camera=(), "
|
||||
@@ -103,38 +105,41 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"gyroscope=(), "
|
||||
"accelerometer=()"
|
||||
)
|
||||
|
||||
|
||||
# Content Security Policy
|
||||
csp_header = "Content-Security-Policy-Report-Only" if self.csp_report_only else "Content-Security-Policy"
|
||||
csp_header = (
|
||||
"Content-Security-Policy-Report-Only"
|
||||
if self.csp_report_only
|
||||
else "Content-Security-Policy"
|
||||
)
|
||||
response.headers[csp_header] = self.csp_directive
|
||||
|
||||
|
||||
# HTTPS enforcement (production only)
|
||||
if self.production:
|
||||
response.headers["Strict-Transport-Security"] = (
|
||||
"max-age=31536000; includeSubDomains; preload"
|
||||
)
|
||||
|
||||
response.headers[
|
||||
"Strict-Transport-Security"
|
||||
] = "max-age=31536000; includeSubDomains; preload"
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""Add security headers to the response.
|
||||
|
||||
|
||||
Args:
|
||||
request: The incoming request.
|
||||
call_next: Callable to get the response from downstream.
|
||||
|
||||
|
||||
Returns:
|
||||
Response with security headers added.
|
||||
"""
|
||||
try:
|
||||
response = await call_next(request)
|
||||
self._add_security_headers(response)
|
||||
return response
|
||||
except Exception:
|
||||
# Create a response for the error with security headers
|
||||
from starlette.responses import PlainTextResponse
|
||||
response = PlainTextResponse(
|
||||
content="Internal Server Error",
|
||||
status_code=500
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).debug(
|
||||
"Upstream error in security headers middleware", exc_info=True
|
||||
)
|
||||
self._add_security_headers(response)
|
||||
# Return the error response with headers (don't re-raise)
|
||||
return response
|
||||
from starlette.responses import PlainTextResponse
|
||||
|
||||
response = PlainTextResponse("Internal Server Error", status_code=500)
|
||||
self._add_security_headers(response)
|
||||
return response
|
||||
|
||||
@@ -1,24 +1,27 @@
|
||||
|
||||
from datetime import datetime, date
|
||||
from datetime import date, datetime
|
||||
from enum import Enum as PyEnum
|
||||
from sqlalchemy import (
|
||||
Column, Integer, String, DateTime, Boolean, Enum as SQLEnum,
|
||||
Date, ForeignKey, Index, JSON
|
||||
)
|
||||
|
||||
from sqlalchemy import JSON, Boolean, Column, Date, DateTime
|
||||
from sqlalchemy import Enum as SQLEnum
|
||||
from sqlalchemy import ForeignKey, Index, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from .database import Base # Assuming a shared Base in models/database.py
|
||||
|
||||
|
||||
class TaskState(str, PyEnum):
|
||||
LATER = "LATER"
|
||||
NEXT = "NEXT"
|
||||
NOW = "NOW"
|
||||
DONE = "DONE"
|
||||
DEFERRED = "DEFERRED" # Task pushed to tomorrow
|
||||
DEFERRED = "DEFERRED" # Task pushed to tomorrow
|
||||
|
||||
|
||||
class TaskCertainty(str, PyEnum):
|
||||
FUZZY = "FUZZY" # An intention without a time
|
||||
SOFT = "SOFT" # A flexible task with a time
|
||||
HARD = "HARD" # A fixed meeting/appointment
|
||||
FUZZY = "FUZZY" # An intention without a time
|
||||
SOFT = "SOFT" # A flexible task with a time
|
||||
HARD = "HARD" # A fixed meeting/appointment
|
||||
|
||||
|
||||
class Task(Base):
|
||||
__tablename__ = "tasks"
|
||||
@@ -29,7 +32,7 @@ class Task(Base):
|
||||
|
||||
state = Column(SQLEnum(TaskState), default=TaskState.LATER, nullable=False, index=True)
|
||||
certainty = Column(SQLEnum(TaskCertainty), default=TaskCertainty.SOFT, nullable=False)
|
||||
is_mit = Column(Boolean, default=False, nullable=False) # 1-3 per day
|
||||
is_mit = Column(Boolean, default=False, nullable=False) # 1-3 per day
|
||||
|
||||
sort_order = Column(Integer, default=0, nullable=False)
|
||||
|
||||
@@ -42,7 +45,8 @@ class Task(Base):
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
||||
|
||||
__table_args__ = (Index('ix_task_state_order', 'state', 'sort_order'),)
|
||||
__table_args__ = (Index("ix_task_state_order", "state", "sort_order"),)
|
||||
|
||||
|
||||
class JournalEntry(Base):
|
||||
__tablename__ = "journal_entries"
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
SQLALCHEMY_DATABASE_URL = "sqlite:///./data/timmy_calm.db"
|
||||
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
|
||||
)
|
||||
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def create_tables():
    """Issue CREATE TABLE for every model registered on ``Base``'s metadata.

    Only models whose modules have already been imported are registered on
    ``Base.metadata``, so import them before calling this.
    """
    metadata = Base.metadata
    metadata.create_all(bind=engine)
|
||||
|
||||
@@ -5,9 +5,9 @@ from datetime import datetime
|
||||
from fastapi import APIRouter, Form, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from timmy.session import chat as agent_chat
|
||||
from dashboard.store import message_log
|
||||
from dashboard.templating import templates
|
||||
from timmy.session import chat as agent_chat
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -38,9 +38,7 @@ async def list_agents():
|
||||
@router.get("/default/panel", response_class=HTMLResponse)
|
||||
async def agent_panel(request: Request):
|
||||
"""Chat panel — for HTMX main-panel swaps."""
|
||||
return templates.TemplateResponse(
|
||||
request, "partials/agent_panel_chat.html", {"agent": None}
|
||||
)
|
||||
return templates.TemplateResponse(request, "partials/agent_panel_chat.html", {"agent": None})
|
||||
|
||||
|
||||
@router.get("/default/history", response_class=HTMLResponse)
|
||||
@@ -77,7 +75,9 @@ async def chat_agent(request: Request, message: str = Form(...)):
|
||||
|
||||
message_log.append(role="user", content=message, timestamp=timestamp, source="browser")
|
||||
if response_text is not None:
|
||||
message_log.append(role="agent", content=response_text, timestamp=timestamp, source="browser")
|
||||
message_log.append(
|
||||
role="agent", content=response_text, timestamp=timestamp, source="browser"
|
||||
)
|
||||
elif error_text:
|
||||
message_log.append(role="error", content=error_text, timestamp=timestamp, source="browser")
|
||||
|
||||
|
||||
@@ -12,9 +12,10 @@ from datetime import datetime, timezone
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
|
||||
from timmy.briefing import Briefing, engine as briefing_engine
|
||||
from timmy import approvals as approval_store
|
||||
from dashboard.templating import templates
|
||||
from timmy import approvals as approval_store
|
||||
from timmy.briefing import Briefing
|
||||
from timmy.briefing import engine as briefing_engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
import logging
|
||||
from datetime import date, datetime
|
||||
from typing import List, Optional
|
||||
@@ -8,7 +7,7 @@ from fastapi.responses import HTMLResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dashboard.models.calm import JournalEntry, Task, TaskCertainty, TaskState
|
||||
from dashboard.models.database import SessionLocal, engine, get_db, create_tables
|
||||
from dashboard.models.database import SessionLocal, create_tables, engine, get_db
|
||||
from dashboard.templating import templates
|
||||
|
||||
# Ensure CALM tables exist (safe to call multiple times)
|
||||
@@ -23,11 +22,19 @@ router = APIRouter(tags=["calm"])
|
||||
def get_now_task(db: Session) -> Optional[Task]:
    """Return the first task in the NOW state, or None when there is none."""
    now_query = db.query(Task).filter(Task.state == TaskState.NOW)
    return now_query.first()
|
||||
|
||||
|
||||
def get_next_task(db: Session) -> Optional[Task]:
    """Return the first task in the NEXT state, or None when there is none."""
    next_query = db.query(Task).filter(Task.state == TaskState.NEXT)
    return next_query.first()
|
||||
|
||||
|
||||
def get_later_tasks(db: Session) -> List[Task]:
|
||||
return db.query(Task).filter(Task.state == TaskState.LATER).order_by(Task.is_mit.desc(), Task.sort_order).all()
|
||||
return (
|
||||
db.query(Task)
|
||||
.filter(Task.state == TaskState.LATER)
|
||||
.order_by(Task.is_mit.desc(), Task.sort_order)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def promote_tasks(db: Session):
|
||||
# Ensure only one NOW task exists. If multiple, demote extras to NEXT.
|
||||
@@ -38,7 +45,7 @@ def promote_tasks(db: Session):
|
||||
for task_to_demote in now_tasks[1:]:
|
||||
task_to_demote.state = TaskState.NEXT
|
||||
db.add(task_to_demote)
|
||||
db.flush() # Make changes visible
|
||||
db.flush() # Make changes visible
|
||||
|
||||
# If no NOW task, promote NEXT to NOW
|
||||
current_now = db.query(Task).filter(Task.state == TaskState.NOW).first()
|
||||
@@ -47,12 +54,17 @@ def promote_tasks(db: Session):
|
||||
if next_task:
|
||||
next_task.state = TaskState.NOW
|
||||
db.add(next_task)
|
||||
db.flush() # Make changes visible
|
||||
db.flush() # Make changes visible
|
||||
|
||||
# If no NEXT task, promote highest priority LATER to NEXT
|
||||
current_next = db.query(Task).filter(Task.state == TaskState.NEXT).first()
|
||||
if not current_next:
|
||||
later_tasks = db.query(Task).filter(Task.state == TaskState.LATER).order_by(Task.is_mit.desc(), Task.sort_order).all()
|
||||
later_tasks = (
|
||||
db.query(Task)
|
||||
.filter(Task.state == TaskState.LATER)
|
||||
.order_by(Task.is_mit.desc(), Task.sort_order)
|
||||
.all()
|
||||
)
|
||||
if later_tasks:
|
||||
later_tasks[0].state = TaskState.NEXT
|
||||
db.add(later_tasks[0])
|
||||
@@ -60,14 +72,17 @@ def promote_tasks(db: Session):
|
||||
db.commit()
|
||||
|
||||
|
||||
|
||||
# Endpoints
|
||||
@router.get("/calm", response_class=HTMLResponse)
|
||||
async def get_calm_view(request: Request, db: Session = Depends(get_db)):
|
||||
now_task = get_now_task(db)
|
||||
next_task = get_next_task(db)
|
||||
later_tasks_count = len(get_later_tasks(db))
|
||||
return templates.TemplateResponse(request, "calm/calm_view.html", {"now_task": now_task,
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"calm/calm_view.html",
|
||||
{
|
||||
"now_task": now_task,
|
||||
"next_task": next_task,
|
||||
"later_tasks_count": later_tasks_count,
|
||||
},
|
||||
@@ -101,7 +116,7 @@ async def post_morning_ritual(
|
||||
task = Task(
|
||||
title=mit_title,
|
||||
is_mit=True,
|
||||
state=TaskState.LATER, # Initially LATER, will be promoted
|
||||
state=TaskState.LATER, # Initially LATER, will be promoted
|
||||
certainty=TaskCertainty.SOFT,
|
||||
)
|
||||
db.add(task)
|
||||
@@ -113,7 +128,7 @@ async def post_morning_ritual(
|
||||
db.add(journal_entry)
|
||||
|
||||
# Create other tasks
|
||||
for task_title in other_tasks.split('\n'):
|
||||
for task_title in other_tasks.split("\n"):
|
||||
task_title = task_title.strip()
|
||||
if task_title:
|
||||
task = Task(
|
||||
@@ -128,20 +143,29 @@ async def post_morning_ritual(
|
||||
# Set initial NOW/NEXT states
|
||||
# Set initial NOW/NEXT states after all tasks are created
|
||||
if not get_now_task(db) and not get_next_task(db):
|
||||
later_tasks = db.query(Task).filter(Task.state == TaskState.LATER).order_by(Task.is_mit.desc(), Task.sort_order).all()
|
||||
later_tasks = (
|
||||
db.query(Task)
|
||||
.filter(Task.state == TaskState.LATER)
|
||||
.order_by(Task.is_mit.desc(), Task.sort_order)
|
||||
.all()
|
||||
)
|
||||
if later_tasks:
|
||||
# Set the highest priority LATER task to NOW
|
||||
later_tasks[0].state = TaskState.NOW
|
||||
db.add(later_tasks[0])
|
||||
db.flush() # Flush to make the change visible for the next query
|
||||
db.flush() # Flush to make the change visible for the next query
|
||||
|
||||
# Set the next highest priority LATER task to NEXT
|
||||
if len(later_tasks) > 1:
|
||||
later_tasks[1].state = TaskState.NEXT
|
||||
db.add(later_tasks[1])
|
||||
db.commit() # Commit changes after initial NOW/NEXT setup
|
||||
db.commit() # Commit changes after initial NOW/NEXT setup
|
||||
|
||||
return templates.TemplateResponse(request, "calm/calm_view.html", {"now_task": get_now_task(db),
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"calm/calm_view.html",
|
||||
{
|
||||
"now_task": get_now_task(db),
|
||||
"next_task": get_next_task(db),
|
||||
"later_tasks_count": len(get_later_tasks(db)),
|
||||
},
|
||||
@@ -154,7 +178,8 @@ async def get_evening_ritual_form(request: Request, db: Session = Depends(get_db
|
||||
if not journal_entry:
|
||||
raise HTTPException(status_code=404, detail="No journal entry for today")
|
||||
return templates.TemplateResponse(
|
||||
"calm/evening_ritual_form.html", {"request": request, "journal_entry": journal_entry})
|
||||
"calm/evening_ritual_form.html", {"request": request, "journal_entry": journal_entry}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/calm/ritual/evening", response_class=HTMLResponse)
|
||||
@@ -175,9 +200,13 @@ async def post_evening_ritual(
|
||||
db.add(journal_entry)
|
||||
|
||||
# Archive any remaining active tasks
|
||||
active_tasks = db.query(Task).filter(Task.state.in_([TaskState.NOW, TaskState.NEXT, TaskState.LATER])).all()
|
||||
active_tasks = (
|
||||
db.query(Task)
|
||||
.filter(Task.state.in_([TaskState.NOW, TaskState.NEXT, TaskState.LATER]))
|
||||
.all()
|
||||
)
|
||||
for task in active_tasks:
|
||||
task.state = TaskState.DEFERRED # Or DONE, depending on desired archiving logic
|
||||
task.state = TaskState.DEFERRED # Or DONE, depending on desired archiving logic
|
||||
task.deferred_at = datetime.utcnow()
|
||||
db.add(task)
|
||||
|
||||
@@ -221,7 +250,7 @@ async def start_task(
|
||||
):
|
||||
current_now_task = get_now_task(db)
|
||||
if current_now_task and current_now_task.id != task_id:
|
||||
current_now_task.state = TaskState.NEXT # Demote current NOW to NEXT
|
||||
current_now_task.state = TaskState.NEXT # Demote current NOW to NEXT
|
||||
db.add(current_now_task)
|
||||
|
||||
task = db.query(Task).filter(Task.id == task_id).first()
|
||||
@@ -322,7 +351,7 @@ async def reorder_tasks(
|
||||
):
|
||||
# Reorder LATER tasks
|
||||
if later_task_ids:
|
||||
ids_in_order = [int(x.strip()) for x in later_task_ids.split(',') if x.strip()]
|
||||
ids_in_order = [int(x.strip()) for x in later_task_ids.split(",") if x.strip()]
|
||||
for index, task_id in enumerate(ids_in_order):
|
||||
task = db.query(Task).filter(Task.id == task_id).first()
|
||||
if task and task.state == TaskState.LATER:
|
||||
@@ -332,16 +361,18 @@ async def reorder_tasks(
|
||||
# Handle NEXT task if it's part of the reorder (e.g., moved from LATER to NEXT explicitly)
|
||||
if next_task_id:
|
||||
task = db.query(Task).filter(Task.id == next_task_id).first()
|
||||
if task and task.state == TaskState.LATER: # Only if it was a LATER task being promoted manually
|
||||
if (
|
||||
task and task.state == TaskState.LATER
|
||||
): # Only if it was a LATER task being promoted manually
|
||||
# Demote current NEXT to LATER
|
||||
current_next = get_next_task(db)
|
||||
if current_next:
|
||||
current_next.state = TaskState.LATER
|
||||
current_next.sort_order = len(get_later_tasks(db)) # Add to end of later
|
||||
current_next.sort_order = len(get_later_tasks(db)) # Add to end of later
|
||||
db.add(current_next)
|
||||
|
||||
task.state = TaskState.NEXT
|
||||
task.sort_order = 0 # NEXT tasks don't really need sort_order, but for consistency
|
||||
task.sort_order = 0 # NEXT tasks don't really need sort_order, but for consistency
|
||||
db.add(task)
|
||||
|
||||
db.commit()
|
||||
|
||||
@@ -27,12 +27,13 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["chat-api"])
|
||||
|
||||
_UPLOAD_DIR = os.path.join("data", "chat-uploads")
|
||||
_UPLOAD_DIR = str(Path(settings.repo_root) / "data" / "chat-uploads")
|
||||
_MAX_UPLOAD_SIZE = 50 * 1024 * 1024 # 50 MB
|
||||
|
||||
|
||||
# ── POST /api/chat ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
async def api_chat(request: Request):
|
||||
"""Accept a JSON chat payload and return the agent's reply.
|
||||
@@ -65,7 +66,8 @@ async def api_chat(request: Request):
|
||||
# Handle multimodal content arrays — extract text parts
|
||||
if isinstance(content, list):
|
||||
text_parts = [
|
||||
p.get("text", "") for p in content
|
||||
p.get("text", "")
|
||||
for p in content
|
||||
if isinstance(p, dict) and p.get("type") == "text"
|
||||
]
|
||||
last_user_msg = " ".join(text_parts).strip()
|
||||
@@ -109,6 +111,7 @@ async def api_chat(request: Request):
|
||||
|
||||
# ── POST /api/upload ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/upload")
|
||||
async def api_upload(file: UploadFile = File(...)):
|
||||
"""Accept a file upload and return its URL.
|
||||
@@ -147,6 +150,7 @@ async def api_upload(file: UploadFile = File(...)):
|
||||
|
||||
# ── GET /api/chat/history ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/chat/history")
|
||||
async def api_chat_history():
|
||||
"""Return the in-memory chat history as JSON."""
|
||||
@@ -165,6 +169,7 @@ async def api_chat_history():
|
||||
|
||||
# ── DELETE /api/chat/history ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.delete("/chat/history")
|
||||
async def api_clear_history():
|
||||
"""Clear the in-memory chat history."""
|
||||
|
||||
@@ -7,9 +7,10 @@ Endpoints:
|
||||
GET /discord/oauth-url — get the bot's OAuth2 authorization URL
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, File, Form, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
router = APIRouter(prefix="/discord", tags=["discord"])
|
||||
|
||||
|
||||
77
src/dashboard/routes/experiments.py
Normal file
77
src/dashboard/routes/experiments.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Experiment dashboard routes — autoresearch experiment monitoring.
|
||||
|
||||
Provides endpoints for viewing, starting, and monitoring autonomous
|
||||
ML experiment loops powered by Karpathy's autoresearch pattern.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
|
||||
from config import settings
|
||||
from dashboard.templating import templates
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/experiments", tags=["experiments"])
|
||||
|
||||
|
||||
def _workspace() -> Path:
    """Return the autoresearch workspace path, resolved against the repo root."""
    repo_root = Path(settings.repo_root)
    return repo_root.joinpath(settings.autoresearch_workspace)
|
||||
|
||||
|
||||
@router.get("", response_class=HTMLResponse)
async def experiments_page(request: Request):
    """Render the autoresearch experiments dashboard.

    Loads the experiment history from the configured workspace and hands the
    first 50 entries, together with the autoresearch settings, to the
    ``experiments.html`` template. If the history cannot be loaded the page
    is still rendered, with an empty history.
    """
    from timmy.autoresearch import get_experiment_history

    try:
        runs = get_experiment_history(_workspace())
    except Exception:
        # Best-effort: a failed history load must not take the page down.
        logger.debug("Failed to load experiment history", exc_info=True)
        runs = []

    context = {
        "page_title": "Experiments — Autoresearch",
        "enabled": settings.autoresearch_enabled,
        "history": runs[:50],
        "metric_name": settings.autoresearch_metric,
        "time_budget": settings.autoresearch_time_budget,
        "max_iterations": settings.autoresearch_max_iterations,
    }
    return templates.TemplateResponse(request, "experiments.html", context)
|
||||
|
||||
|
||||
@router.post("/start", response_class=JSONResponse)
async def start_experiment(request: Request):
    """Prepare the autoresearch workspace for an experiment run.

    Rejects the request with HTTP 403 when autoresearch is disabled in
    settings. Otherwise calls ``prepare_experiment`` on the configured
    workspace and reports its result.

    Args:
        request: The incoming request (unused; kept for the route signature).

    Returns:
        A JSON object with ``status`` ("started"), the ``workspace`` path as
        a string, and ``prepare`` holding whatever ``prepare_experiment``
        returned.

    Raises:
        HTTPException: 403 if ``settings.autoresearch_enabled`` is false.
    """
    if not settings.autoresearch_enabled:
        raise HTTPException(
            status_code=403,
            detail="Autoresearch is disabled. Set AUTORESEARCH_ENABLED=true.",
        )

    import asyncio

    from timmy.autoresearch import prepare_experiment

    workspace = _workspace()
    # prepare_experiment is a plain synchronous call; running it in a worker
    # thread keeps the event loop free to serve other requests meanwhile.
    # NOTE(review): presumably it does slow disk/network I/O — confirm.
    status = await asyncio.to_thread(prepare_experiment, workspace)

    return {"status": "started", "workspace": str(workspace), "prepare": status}
|
||||
|
||||
|
||||
@router.get("/{run_id}", response_class=JSONResponse)
async def experiment_detail(run_id: str):
    """Return the history entry for one experiment run.

    Scans the experiment history of the configured workspace and returns the
    first entry whose ``run_id`` equals the path parameter.

    Raises:
        HTTPException: 404 if no entry has the requested ``run_id``.
    """
    from timmy.autoresearch import get_experiment_history

    runs = get_experiment_history(_workspace())
    found = next((item for item in runs if item.get("run_id") == run_id), None)
    if found is None:
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
    return found
|
||||
@@ -43,6 +43,7 @@ async def grok_status(request: Request):
|
||||
stats = None
|
||||
try:
|
||||
from timmy.backends import get_grok_backend
|
||||
|
||||
backend = get_grok_backend()
|
||||
stats = {
|
||||
"total_requests": backend.stats.total_requests,
|
||||
@@ -52,12 +53,16 @@ async def grok_status(request: Request):
|
||||
"errors": backend.stats.errors,
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
logger.debug("Failed to load Grok stats", exc_info=True)
|
||||
|
||||
return templates.TemplateResponse(request, "grok_status.html", {
|
||||
"status": status,
|
||||
"stats": stats,
|
||||
})
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"grok_status.html",
|
||||
{
|
||||
"status": status,
|
||||
"stats": stats,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/toggle")
|
||||
@@ -90,7 +95,7 @@ async def toggle_grok_mode(request: Request):
|
||||
success=True,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
logger.debug("Failed to log Grok toggle to Spark", exc_info=True)
|
||||
|
||||
return HTMLResponse(
|
||||
_render_toggle_card(_grok_mode_active),
|
||||
@@ -104,10 +109,13 @@ def _run_grok_query(message: str) -> dict:
|
||||
Returns:
|
||||
{"response": str | None, "error": str | None}
|
||||
"""
|
||||
from timmy.backends import grok_available, get_grok_backend
|
||||
from timmy.backends import get_grok_backend, grok_available
|
||||
|
||||
if not grok_available():
|
||||
return {"response": None, "error": "Grok is not available. Set GROK_ENABLED=true and XAI_API_KEY."}
|
||||
return {
|
||||
"response": None,
|
||||
"error": "Grok is not available. Set GROK_ENABLED=true and XAI_API_KEY.",
|
||||
}
|
||||
|
||||
backend = get_grok_backend()
|
||||
|
||||
@@ -115,12 +123,13 @@ def _run_grok_query(message: str) -> dict:
|
||||
if not settings.grok_free:
|
||||
try:
|
||||
from lightning.factory import get_backend as get_ln_backend
|
||||
|
||||
ln = get_ln_backend()
|
||||
sats = min(settings.grok_max_sats_per_query, 100)
|
||||
ln.create_invoice(sats, f"Grok: {message[:50]}")
|
||||
invoice_note = f" | {sats} sats"
|
||||
except Exception:
|
||||
pass
|
||||
logger.debug("Lightning invoice creation failed", exc_info=True)
|
||||
|
||||
try:
|
||||
result = backend.run(message)
|
||||
@@ -132,9 +141,10 @@ def _run_grok_query(message: str) -> dict:
|
||||
@router.post("/chat", response_class=HTMLResponse)
|
||||
async def grok_chat(request: Request, message: str = Form(...)):
|
||||
"""Send a message directly to Grok and return HTMX chat partial."""
|
||||
from dashboard.store import message_log
|
||||
from datetime import datetime
|
||||
|
||||
from dashboard.store import message_log
|
||||
|
||||
timestamp = datetime.now().strftime("%H:%M:%S")
|
||||
result = _run_grok_query(message)
|
||||
|
||||
@@ -142,9 +152,13 @@ async def grok_chat(request: Request, message: str = Form(...)):
|
||||
message_log.append(role="user", content=user_msg, timestamp=timestamp, source="browser")
|
||||
|
||||
if result["response"]:
|
||||
message_log.append(role="agent", content=result["response"], timestamp=timestamp, source="browser")
|
||||
message_log.append(
|
||||
role="agent", content=result["response"], timestamp=timestamp, source="browser"
|
||||
)
|
||||
else:
|
||||
message_log.append(role="error", content=result["error"], timestamp=timestamp, source="browser")
|
||||
message_log.append(
|
||||
role="error", content=result["error"], timestamp=timestamp, source="browser"
|
||||
)
|
||||
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
@@ -185,6 +199,7 @@ async def grok_stats():
|
||||
def _render_toggle_card(active: bool) -> str:
|
||||
"""Render the Grok Mode toggle card HTML."""
|
||||
import html
|
||||
|
||||
color = "#00ff88" if active else "#666"
|
||||
state = "ACTIVE" if active else "STANDBY"
|
||||
glow = "0 0 20px rgba(0, 255, 136, 0.4)" if active else "none"
|
||||
|
||||
@@ -22,6 +22,7 @@ router = APIRouter(tags=["health"])
|
||||
|
||||
class DependencyStatus(BaseModel):
|
||||
"""Status of a single dependency."""
|
||||
|
||||
name: str
|
||||
status: str # "healthy", "degraded", "unavailable"
|
||||
sovereignty_score: int # 0-10
|
||||
@@ -30,6 +31,7 @@ class DependencyStatus(BaseModel):
|
||||
|
||||
class SovereigntyReport(BaseModel):
|
||||
"""Full sovereignty audit report."""
|
||||
|
||||
overall_score: float
|
||||
dependencies: list[DependencyStatus]
|
||||
timestamp: str
|
||||
@@ -38,6 +40,7 @@ class SovereigntyReport(BaseModel):
|
||||
|
||||
class HealthStatus(BaseModel):
|
||||
"""System health status."""
|
||||
|
||||
status: str
|
||||
timestamp: str
|
||||
version: str
|
||||
@@ -52,6 +55,7 @@ def _check_ollama_sync() -> DependencyStatus:
|
||||
"""Synchronous Ollama check — run via asyncio.to_thread()."""
|
||||
try:
|
||||
import urllib.request
|
||||
|
||||
url = settings.ollama_url.replace("localhost", "127.0.0.1")
|
||||
req = urllib.request.Request(
|
||||
f"{url}/api/tags",
|
||||
@@ -67,7 +71,7 @@ def _check_ollama_sync() -> DependencyStatus:
|
||||
details={"url": settings.ollama_url, "model": settings.ollama_model},
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
logger.debug("Ollama health check failed", exc_info=True)
|
||||
|
||||
return DependencyStatus(
|
||||
name="Ollama AI",
|
||||
@@ -142,7 +146,7 @@ def _calculate_overall_score(deps: list[DependencyStatus]) -> float:
|
||||
def _generate_recommendations(deps: list[DependencyStatus]) -> list[str]:
|
||||
"""Generate recommendations based on dependency status."""
|
||||
recommendations = []
|
||||
|
||||
|
||||
for dep in deps:
|
||||
if dep.status == "unavailable":
|
||||
recommendations.append(f"{dep.name} is unavailable - check configuration")
|
||||
@@ -151,25 +155,25 @@ def _generate_recommendations(deps: list[DependencyStatus]) -> list[str]:
|
||||
recommendations.append(
|
||||
"Switch to real Lightning: set LIGHTNING_BACKEND=lnd and configure LND"
|
||||
)
|
||||
|
||||
|
||||
if not recommendations:
|
||||
recommendations.append("System operating optimally - all dependencies healthy")
|
||||
|
||||
|
||||
return recommendations
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
"""Basic health check endpoint.
|
||||
|
||||
|
||||
Returns legacy format for backward compatibility with existing tests,
|
||||
plus extended information for the Mission Control dashboard.
|
||||
"""
|
||||
uptime = (datetime.now(timezone.utc) - _START_TIME).total_seconds()
|
||||
|
||||
|
||||
# Legacy format for test compatibility
|
||||
ollama_ok = await check_ollama()
|
||||
|
||||
|
||||
agent_status = "idle" if ollama_ok else "offline"
|
||||
|
||||
return {
|
||||
@@ -193,12 +197,13 @@ async def health_check():
|
||||
async def health_status_panel(request: Request):
|
||||
"""Simple HTML health status panel."""
|
||||
ollama_ok = await check_ollama()
|
||||
|
||||
|
||||
status_text = "UP" if ollama_ok else "DOWN"
|
||||
status_color = "#10b981" if ollama_ok else "#ef4444"
|
||||
import html
|
||||
|
||||
model = html.escape(settings.ollama_model) # Include model for test compatibility
|
||||
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
@@ -217,7 +222,7 @@ async def health_status_panel(request: Request):
|
||||
@router.get("/health/sovereignty", response_model=SovereigntyReport)
|
||||
async def sovereignty_check():
|
||||
"""Comprehensive sovereignty audit report.
|
||||
|
||||
|
||||
Returns the status of all external dependencies with sovereignty scores.
|
||||
Use this to verify the system is operating in a sovereign manner.
|
||||
"""
|
||||
@@ -226,10 +231,10 @@ async def sovereignty_check():
|
||||
_check_lightning(),
|
||||
_check_sqlite(),
|
||||
]
|
||||
|
||||
|
||||
overall = _calculate_overall_score(dependencies)
|
||||
recommendations = _generate_recommendations(dependencies)
|
||||
|
||||
|
||||
return SovereigntyReport(
|
||||
overall_score=overall,
|
||||
dependencies=dependencies,
|
||||
|
||||
@@ -19,8 +19,7 @@ AGENT_CATALOG = [
|
||||
"name": "Orchestrator",
|
||||
"role": "Local AI",
|
||||
"description": (
|
||||
"Primary AI agent. Coordinates tasks, manages memory. "
|
||||
"Uses distributed brain."
|
||||
"Primary AI agent. Coordinates tasks, manages memory. " "Uses distributed brain."
|
||||
),
|
||||
"capabilities": "chat,reasoning,coordination,memory",
|
||||
"rate_sats": 0,
|
||||
@@ -37,11 +36,11 @@ async def api_list_agents():
|
||||
pending_tasks = len(await brain.get_pending_tasks(limit=1000))
|
||||
except Exception:
|
||||
pending_tasks = 0
|
||||
|
||||
|
||||
catalog = [dict(AGENT_CATALOG[0])]
|
||||
catalog[0]["pending_tasks"] = pending_tasks
|
||||
catalog[0]["status"] = "active"
|
||||
|
||||
|
||||
# Include 'total' for backward compatibility with tests
|
||||
return {"agents": catalog, "total": len(catalog)}
|
||||
|
||||
@@ -82,7 +81,7 @@ async def marketplace_ui(request: Request):
|
||||
"page_title": "Agent Marketplace",
|
||||
"active_count": active,
|
||||
"planned_count": 0,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -5,17 +5,17 @@ from typing import Optional
|
||||
from fastapi import APIRouter, Form, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
|
||||
from dashboard.templating import templates
|
||||
from timmy.memory.vector_store import (
|
||||
store_memory,
|
||||
search_memories,
|
||||
delete_memory,
|
||||
get_memory_stats,
|
||||
recall_personal_facts,
|
||||
recall_personal_facts_with_ids,
|
||||
search_memories,
|
||||
store_memory,
|
||||
store_personal_fact,
|
||||
update_personal_fact,
|
||||
delete_memory,
|
||||
)
|
||||
from dashboard.templating import templates
|
||||
|
||||
router = APIRouter(prefix="/memory", tags=["memory"])
|
||||
|
||||
@@ -36,10 +36,10 @@ async def memory_page(
|
||||
agent_id=agent_id,
|
||||
limit=20,
|
||||
)
|
||||
|
||||
|
||||
stats = get_memory_stats()
|
||||
facts = recall_personal_facts_with_ids()[:10]
|
||||
|
||||
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"memory.html",
|
||||
@@ -67,7 +67,7 @@ async def memory_search(
|
||||
context_type=context_type,
|
||||
limit=20,
|
||||
)
|
||||
|
||||
|
||||
# Return partial for HTMX
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
|
||||
@@ -13,6 +13,7 @@ from fastapi.responses import HTMLResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from config import settings
|
||||
from dashboard.templating import templates
|
||||
from infrastructure.models.registry import (
|
||||
CustomModel,
|
||||
ModelFormat,
|
||||
@@ -20,7 +21,6 @@ from infrastructure.models.registry import (
|
||||
ModelRole,
|
||||
model_registry,
|
||||
)
|
||||
from dashboard.templating import templates
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -33,6 +33,7 @@ api_router = APIRouter(prefix="/api/v1/models", tags=["models-api"])
|
||||
|
||||
class RegisterModelRequest(BaseModel):
|
||||
"""Request body for model registration."""
|
||||
|
||||
name: str
|
||||
format: str # gguf, safetensors, hf, ollama
|
||||
path: str
|
||||
@@ -45,12 +46,14 @@ class RegisterModelRequest(BaseModel):
|
||||
|
||||
class AssignModelRequest(BaseModel):
|
||||
"""Request body for assigning a model to an agent."""
|
||||
|
||||
agent_id: str
|
||||
model_name: str
|
||||
|
||||
|
||||
class SetActiveRequest(BaseModel):
|
||||
"""Request body for enabling/disabling a model."""
|
||||
|
||||
active: bool
|
||||
|
||||
|
||||
@@ -92,15 +95,14 @@ async def register_model(request: RegisterModelRequest) -> dict[str, Any]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid format: {request.format}. "
|
||||
f"Choose from: {[f.value for f in ModelFormat]}",
|
||||
f"Choose from: {[f.value for f in ModelFormat]}",
|
||||
)
|
||||
try:
|
||||
role = ModelRole(request.role)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid role: {request.role}. "
|
||||
f"Choose from: {[r.value for r in ModelRole]}",
|
||||
detail=f"Invalid role: {request.role}. " f"Choose from: {[r.value for r in ModelRole]}",
|
||||
)
|
||||
|
||||
# Validate path exists for non-Ollama formats
|
||||
@@ -163,9 +165,7 @@ async def unregister_model(model_name: str) -> dict[str, str]:
|
||||
|
||||
|
||||
@api_router.patch("/{model_name}/active")
|
||||
async def set_model_active(
|
||||
model_name: str, request: SetActiveRequest
|
||||
) -> dict[str, str]:
|
||||
async def set_model_active(model_name: str, request: SetActiveRequest) -> dict[str, str]:
|
||||
"""Enable or disable a model."""
|
||||
if not model_registry.set_active(model_name, request.active):
|
||||
raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
|
||||
@@ -182,8 +182,7 @@ async def list_assignments() -> dict[str, Any]:
|
||||
assignments = model_registry.get_agent_assignments()
|
||||
return {
|
||||
"assignments": [
|
||||
{"agent_id": aid, "model_name": mname}
|
||||
for aid, mname in assignments.items()
|
||||
{"agent_id": aid, "model_name": mname} for aid, mname in assignments.items()
|
||||
],
|
||||
"total": len(assignments),
|
||||
}
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from timmy.cascade_adapter import get_cascade_adapter
|
||||
from dashboard.templating import templates
|
||||
from timmy.cascade_adapter import get_cascade_adapter
|
||||
|
||||
router = APIRouter(prefix="/router", tags=["router"])
|
||||
|
||||
@@ -13,19 +13,19 @@ router = APIRouter(prefix="/router", tags=["router"])
|
||||
async def router_status_page(request: Request):
|
||||
"""Cascade Router status dashboard."""
|
||||
adapter = get_cascade_adapter()
|
||||
|
||||
|
||||
providers = adapter.get_provider_status()
|
||||
preferred = adapter.get_preferred_provider()
|
||||
|
||||
|
||||
# Calculate overall stats
|
||||
total_requests = sum(p["metrics"]["total"] for p in providers)
|
||||
total_success = sum(p["metrics"]["success"] for p in providers)
|
||||
total_failed = sum(p["metrics"]["failed"] for p in providers)
|
||||
|
||||
|
||||
avg_latency = 0.0
|
||||
if providers:
|
||||
avg_latency = sum(p["metrics"]["avg_latency_ms"] for p in providers) / len(providers)
|
||||
|
||||
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"router_status.html",
|
||||
|
||||
@@ -13,8 +13,8 @@ import logging
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from spark.engine import spark_engine
|
||||
from dashboard.templating import templates
|
||||
from spark.engine import spark_engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -86,23 +86,26 @@ async def spark_ui(request: Request):
|
||||
async def spark_status_json():
|
||||
"""Return Spark Intelligence status as JSON."""
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
status = spark_engine.status()
|
||||
advisories = spark_engine.get_advisories()
|
||||
return JSONResponse({
|
||||
"status": status,
|
||||
"advisories": [
|
||||
{
|
||||
"category": a.category,
|
||||
"priority": a.priority,
|
||||
"title": a.title,
|
||||
"detail": a.detail,
|
||||
"suggested_action": a.suggested_action,
|
||||
"subject": a.subject,
|
||||
"evidence_count": a.evidence_count,
|
||||
}
|
||||
for a in advisories
|
||||
],
|
||||
})
|
||||
return JSONResponse(
|
||||
{
|
||||
"status": status,
|
||||
"advisories": [
|
||||
{
|
||||
"category": a.category,
|
||||
"priority": a.priority,
|
||||
"title": a.title,
|
||||
"detail": a.detail,
|
||||
"suggested_action": a.suggested_action,
|
||||
"subject": a.subject,
|
||||
"evidence_count": a.evidence_count,
|
||||
}
|
||||
for a in advisories
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/timeline", response_class=HTMLResponse)
|
||||
|
||||
@@ -7,9 +7,9 @@ from typing import Optional
|
||||
from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from spark.engine import spark_engine
|
||||
from dashboard.templating import templates
|
||||
from infrastructure.ws_manager.handler import ws_manager
|
||||
from spark.engine import spark_engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,7 +25,7 @@ async def swarm_events(
|
||||
):
|
||||
"""Event log page."""
|
||||
events = spark_engine.get_timeline(limit=100)
|
||||
|
||||
|
||||
# Filter if requested
|
||||
if task_id:
|
||||
events = [e for e in events if e.task_id == task_id]
|
||||
@@ -33,7 +33,7 @@ async def swarm_events(
|
||||
events = [e for e in events if e.agent_id == agent_id]
|
||||
if event_type:
|
||||
events = [e for e in events if e.event_type == event_type]
|
||||
|
||||
|
||||
# Prepare summary and event types for template
|
||||
summary = {}
|
||||
event_types = set()
|
||||
@@ -41,7 +41,7 @@ async def swarm_events(
|
||||
etype = e.event_type
|
||||
event_types.add(etype)
|
||||
summary[etype] = summary.get(etype, 0) + 1
|
||||
|
||||
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"events.html",
|
||||
@@ -78,14 +78,16 @@ async def swarm_ws(websocket: WebSocket):
|
||||
await ws_manager.connect(websocket)
|
||||
try:
|
||||
# Send initial state so frontend can clear loading placeholders
|
||||
await websocket.send_json({
|
||||
"type": "initial_state",
|
||||
"data": {
|
||||
"agents": {"total": 0, "active": 0, "list": []},
|
||||
"tasks": {"active": 0},
|
||||
"auctions": {"list": []},
|
||||
},
|
||||
})
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "initial_state",
|
||||
"data": {
|
||||
"agents": {"total": 0, "active": 0, "list": []},
|
||||
"tasks": {"active": 0},
|
||||
"auctions": {"list": []},
|
||||
},
|
||||
}
|
||||
)
|
||||
while True:
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
|
||||
@@ -25,26 +25,42 @@ async def lightning_ledger(request: Request):
|
||||
"pending_incoming_sats": 0,
|
||||
"pending_outgoing_sats": 0,
|
||||
}
|
||||
|
||||
|
||||
# Mock transactions
|
||||
from collections import namedtuple
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class TxType(Enum):
|
||||
incoming = "incoming"
|
||||
outgoing = "outgoing"
|
||||
|
||||
|
||||
class TxStatus(Enum):
|
||||
completed = "completed"
|
||||
pending = "pending"
|
||||
|
||||
Tx = namedtuple("Tx", ["tx_type", "status", "amount_sats", "payment_hash", "memo", "created_at"])
|
||||
|
||||
|
||||
Tx = namedtuple(
|
||||
"Tx", ["tx_type", "status", "amount_sats", "payment_hash", "memo", "created_at"]
|
||||
)
|
||||
|
||||
transactions = [
|
||||
Tx(TxType.outgoing, TxStatus.completed, 50, "hash1", "Model inference", "2026-03-04 10:00:00"),
|
||||
Tx(TxType.incoming, TxStatus.completed, 1000, "hash2", "Manual deposit", "2026-03-03 15:00:00"),
|
||||
Tx(
|
||||
TxType.outgoing,
|
||||
TxStatus.completed,
|
||||
50,
|
||||
"hash1",
|
||||
"Model inference",
|
||||
"2026-03-04 10:00:00",
|
||||
),
|
||||
Tx(
|
||||
TxType.incoming,
|
||||
TxStatus.completed,
|
||||
1000,
|
||||
"hash2",
|
||||
"Manual deposit",
|
||||
"2026-03-03 15:00:00",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"ledger.html",
|
||||
@@ -84,9 +100,16 @@ async def mission_control(request: Request):
|
||||
|
||||
@router.get("/bugs", response_class=HTMLResponse)
|
||||
async def bugs_page(request: Request):
|
||||
return templates.TemplateResponse(request, "bugs.html", {
|
||||
"bugs": [], "total": 0, "stats": {}, "filter_status": None,
|
||||
})
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"bugs.html",
|
||||
{
|
||||
"bugs": [],
|
||||
"total": 0,
|
||||
"stats": {},
|
||||
"filter_status": None,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/self-coding", response_class=HTMLResponse)
|
||||
@@ -109,14 +132,17 @@ async def api_notifications():
|
||||
"""Return recent system events for the notification dropdown."""
|
||||
try:
|
||||
from spark.engine import spark_engine
|
||||
|
||||
events = spark_engine.get_timeline(limit=20)
|
||||
return JSONResponse([
|
||||
{
|
||||
"event_type": e.event_type,
|
||||
"title": getattr(e, "description", e.event_type),
|
||||
"timestamp": str(getattr(e, "timestamp", "")),
|
||||
}
|
||||
for e in events
|
||||
])
|
||||
return JSONResponse(
|
||||
[
|
||||
{
|
||||
"event_type": e.event_type,
|
||||
"title": getattr(e, "description", e.event_type),
|
||||
"timestamp": str(getattr(e, "timestamp", "")),
|
||||
}
|
||||
for e in events
|
||||
]
|
||||
)
|
||||
except Exception:
|
||||
return JSONResponse([])
|
||||
|
||||
@@ -7,9 +7,10 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request, Form
|
||||
from fastapi import APIRouter, Form, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
|
||||
from config import settings
|
||||
from dashboard.templating import templates
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -20,11 +21,17 @@ router = APIRouter(tags=["tasks"])
|
||||
# Database helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DB_PATH = Path("data/tasks.db")
|
||||
DB_PATH = Path(settings.repo_root) / "data" / "tasks.db"
|
||||
|
||||
VALID_STATUSES = {
|
||||
"pending_approval", "approved", "running", "paused",
|
||||
"completed", "vetoed", "failed", "backlogged",
|
||||
"pending_approval",
|
||||
"approved",
|
||||
"running",
|
||||
"paused",
|
||||
"completed",
|
||||
"vetoed",
|
||||
"failed",
|
||||
"backlogged",
|
||||
}
|
||||
VALID_PRIORITIES = {"low", "normal", "high", "urgent"}
|
||||
|
||||
@@ -33,7 +40,8 @@ def _get_db() -> sqlite3.Connection:
|
||||
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(DB_PATH))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("""
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
id TEXT PRIMARY KEY,
|
||||
title TEXT NOT NULL,
|
||||
@@ -46,7 +54,8 @@ def _get_db() -> sqlite3.Connection:
|
||||
created_at TEXT DEFAULT (datetime('now')),
|
||||
completed_at TEXT
|
||||
)
|
||||
""")
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
return conn
|
||||
|
||||
@@ -91,37 +100,52 @@ class _TaskView:
|
||||
# Page routes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/tasks", response_class=HTMLResponse)
|
||||
async def tasks_page(request: Request):
|
||||
"""Render the main task queue page with 3-column layout."""
|
||||
db = _get_db()
|
||||
try:
|
||||
pending = [_TaskView(_row_to_dict(r)) for r in db.execute(
|
||||
"SELECT * FROM tasks WHERE status IN ('pending_approval') ORDER BY created_at DESC"
|
||||
).fetchall()]
|
||||
active = [_TaskView(_row_to_dict(r)) for r in db.execute(
|
||||
"SELECT * FROM tasks WHERE status IN ('approved','running','paused') ORDER BY created_at DESC"
|
||||
).fetchall()]
|
||||
completed = [_TaskView(_row_to_dict(r)) for r in db.execute(
|
||||
"SELECT * FROM tasks WHERE status IN ('completed','vetoed','failed') ORDER BY completed_at DESC LIMIT 50"
|
||||
).fetchall()]
|
||||
pending = [
|
||||
_TaskView(_row_to_dict(r))
|
||||
for r in db.execute(
|
||||
"SELECT * FROM tasks WHERE status IN ('pending_approval') ORDER BY created_at DESC"
|
||||
).fetchall()
|
||||
]
|
||||
active = [
|
||||
_TaskView(_row_to_dict(r))
|
||||
for r in db.execute(
|
||||
"SELECT * FROM tasks WHERE status IN ('approved','running','paused') ORDER BY created_at DESC"
|
||||
).fetchall()
|
||||
]
|
||||
completed = [
|
||||
_TaskView(_row_to_dict(r))
|
||||
for r in db.execute(
|
||||
"SELECT * FROM tasks WHERE status IN ('completed','vetoed','failed') ORDER BY completed_at DESC LIMIT 50"
|
||||
).fetchall()
|
||||
]
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return templates.TemplateResponse(request, "tasks.html", {
|
||||
"pending_count": len(pending),
|
||||
"pending": pending,
|
||||
"active": active,
|
||||
"completed": completed,
|
||||
"agents": [], # no agent roster wired yet
|
||||
"pre_assign": "",
|
||||
})
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"tasks.html",
|
||||
{
|
||||
"pending_count": len(pending),
|
||||
"pending": pending,
|
||||
"active": active,
|
||||
"completed": completed,
|
||||
"agents": [], # no agent roster wired yet
|
||||
"pre_assign": "",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTMX partials (polled by the template)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/tasks/pending", response_class=HTMLResponse)
|
||||
async def tasks_pending(request: Request):
|
||||
db = _get_db()
|
||||
@@ -134,9 +158,11 @@ async def tasks_pending(request: Request):
|
||||
tasks = [_TaskView(_row_to_dict(r)) for r in rows]
|
||||
parts = []
|
||||
for task in tasks:
|
||||
parts.append(templates.TemplateResponse(
|
||||
request, "partials/task_card.html", {"task": task}
|
||||
).body.decode())
|
||||
parts.append(
|
||||
templates.TemplateResponse(
|
||||
request, "partials/task_card.html", {"task": task}
|
||||
).body.decode()
|
||||
)
|
||||
if not parts:
|
||||
return HTMLResponse('<div class="empty-column">No pending tasks</div>')
|
||||
return HTMLResponse("".join(parts))
|
||||
@@ -154,9 +180,11 @@ async def tasks_active(request: Request):
|
||||
tasks = [_TaskView(_row_to_dict(r)) for r in rows]
|
||||
parts = []
|
||||
for task in tasks:
|
||||
parts.append(templates.TemplateResponse(
|
||||
request, "partials/task_card.html", {"task": task}
|
||||
).body.decode())
|
||||
parts.append(
|
||||
templates.TemplateResponse(
|
||||
request, "partials/task_card.html", {"task": task}
|
||||
).body.decode()
|
||||
)
|
||||
if not parts:
|
||||
return HTMLResponse('<div class="empty-column">No active tasks</div>')
|
||||
return HTMLResponse("".join(parts))
|
||||
@@ -174,9 +202,11 @@ async def tasks_completed(request: Request):
|
||||
tasks = [_TaskView(_row_to_dict(r)) for r in rows]
|
||||
parts = []
|
||||
for task in tasks:
|
||||
parts.append(templates.TemplateResponse(
|
||||
request, "partials/task_card.html", {"task": task}
|
||||
).body.decode())
|
||||
parts.append(
|
||||
templates.TemplateResponse(
|
||||
request, "partials/task_card.html", {"task": task}
|
||||
).body.decode()
|
||||
)
|
||||
if not parts:
|
||||
return HTMLResponse('<div class="empty-column">No completed tasks yet</div>')
|
||||
return HTMLResponse("".join(parts))
|
||||
@@ -186,6 +216,7 @@ async def tasks_completed(request: Request):
|
||||
# Form-based create (used by the modal in tasks.html)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/tasks/create", response_class=HTMLResponse)
|
||||
async def create_task_form(
|
||||
request: Request,
|
||||
@@ -218,6 +249,7 @@ async def create_task_form(
|
||||
# Task action endpoints (approve, veto, modify, pause, cancel, retry)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/tasks/{task_id}/approve", response_class=HTMLResponse)
|
||||
async def approve_task(request: Request, task_id: str):
|
||||
return await _set_status(request, task_id, "approved")
|
||||
@@ -268,7 +300,9 @@ async def modify_task(
|
||||
|
||||
async def _set_status(request: Request, task_id: str, new_status: str):
|
||||
"""Helper to update status and return refreshed task card."""
|
||||
completed_at = datetime.utcnow().isoformat() if new_status in ("completed", "vetoed", "failed") else None
|
||||
completed_at = (
|
||||
datetime.utcnow().isoformat() if new_status in ("completed", "vetoed", "failed") else None
|
||||
)
|
||||
db = _get_db()
|
||||
try:
|
||||
db.execute(
|
||||
@@ -289,6 +323,7 @@ async def _set_status(request: Request, task_id: str, new_status: str):
|
||||
# JSON API (for programmatic access / Timmy's tool calls)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/api/tasks", response_class=JSONResponse, status_code=201)
|
||||
async def api_create_task(request: Request):
|
||||
"""Create a task via JSON API."""
|
||||
@@ -345,7 +380,9 @@ async def api_update_status(task_id: str, request: Request):
|
||||
if not new_status or new_status not in VALID_STATUSES:
|
||||
raise HTTPException(422, f"Invalid status. Must be one of: {VALID_STATUSES}")
|
||||
|
||||
completed_at = datetime.utcnow().isoformat() if new_status in ("completed", "vetoed", "failed") else None
|
||||
completed_at = (
|
||||
datetime.utcnow().isoformat() if new_status in ("completed", "vetoed", "failed") else None
|
||||
)
|
||||
db = _get_db()
|
||||
try:
|
||||
db.execute(
|
||||
@@ -379,6 +416,7 @@ async def api_delete_task(task_id: str):
|
||||
# Queue status (polled by the chat panel every 10 seconds)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/api/queue/status", response_class=JSONResponse)
|
||||
async def queue_status(assigned_to: str = "default"):
|
||||
"""Return queue status for the chat panel's agent status indicator."""
|
||||
@@ -396,14 +434,18 @@ async def queue_status(assigned_to: str = "default"):
|
||||
db.close()
|
||||
|
||||
if running:
|
||||
return JSONResponse({
|
||||
"is_working": True,
|
||||
"current_task": {"id": running["id"], "title": running["title"]},
|
||||
"tasks_ahead": 0,
|
||||
})
|
||||
return JSONResponse(
|
||||
{
|
||||
"is_working": True,
|
||||
"current_task": {"id": running["id"], "title": running["title"]},
|
||||
"tasks_ahead": 0,
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse({
|
||||
"is_working": False,
|
||||
"current_task": None,
|
||||
"tasks_ahead": ahead["cnt"] if ahead else 0,
|
||||
})
|
||||
return JSONResponse(
|
||||
{
|
||||
"is_working": False,
|
||||
"current_task": None,
|
||||
"tasks_ahead": ahead["cnt"] if ahead else 0,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -10,8 +10,8 @@ import logging
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
|
||||
from timmy.thinking import thinking_engine
|
||||
from dashboard.templating import templates
|
||||
from timmy.thinking import thinking_engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -8,8 +8,8 @@ from collections import namedtuple
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
|
||||
from timmy.tools import get_all_available_tools
|
||||
from dashboard.templating import templates
|
||||
from timmy.tools import get_all_available_tools
|
||||
|
||||
router = APIRouter(tags=["tools"])
|
||||
|
||||
@@ -29,9 +29,7 @@ def _build_agent_tools():
|
||||
for name, fn in available.items()
|
||||
]
|
||||
|
||||
return [
|
||||
_AgentView(name="Timmy", status="idle", tools=tool_views, stats=_Stats(total_calls=0))
|
||||
]
|
||||
return [_AgentView(name="Timmy", status="idle", tools=tool_views, stats=_Stats(total_calls=0))]
|
||||
|
||||
|
||||
@router.get("/tools", response_class=HTMLResponse)
|
||||
|
||||
@@ -10,9 +10,9 @@ import logging
|
||||
from fastapi import APIRouter, Form, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from dashboard.templating import templates
|
||||
from integrations.voice.nlu import detect_intent, extract_command
|
||||
from timmy.agent import create_timmy
|
||||
from dashboard.templating import templates
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -38,6 +38,7 @@ async def tts_status():
|
||||
"""Check TTS engine availability."""
|
||||
try:
|
||||
from timmy_serve.voice_tts import voice_tts
|
||||
|
||||
return {
|
||||
"available": voice_tts.available,
|
||||
"voices": voice_tts.get_voices() if voice_tts.available else [],
|
||||
@@ -51,6 +52,7 @@ async def tts_speak(text: str = Form(...)):
|
||||
"""Speak text aloud via TTS."""
|
||||
try:
|
||||
from timmy_serve.voice_tts import voice_tts
|
||||
|
||||
if not voice_tts.available:
|
||||
return {"spoken": False, "reason": "TTS engine not available"}
|
||||
voice_tts.speak(text)
|
||||
@@ -86,6 +88,7 @@ async def voice_command(text: str = Form(...)):
|
||||
|
||||
# ── Enhanced voice pipeline ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/enhanced/process")
|
||||
async def process_voice_input(
|
||||
text: str = Form(...),
|
||||
@@ -133,6 +136,7 @@ async def process_voice_input(
|
||||
if speak_response and response_text:
|
||||
try:
|
||||
from timmy_serve.voice_tts import voice_tts
|
||||
|
||||
if voice_tts.available:
|
||||
voice_tts.speak(response_text)
|
||||
except Exception:
|
||||
|
||||
@@ -6,7 +6,7 @@ import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request, Form
|
||||
from fastapi import APIRouter, Form, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse, JSONResponse
|
||||
|
||||
from dashboard.templating import templates
|
||||
@@ -26,7 +26,8 @@ def _get_db() -> sqlite3.Connection:
|
||||
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(DB_PATH))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("""
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS work_orders (
|
||||
id TEXT PRIMARY KEY,
|
||||
title TEXT NOT NULL,
|
||||
@@ -41,7 +42,8 @@ def _get_db() -> sqlite3.Connection:
|
||||
created_at TEXT DEFAULT (datetime('now')),
|
||||
completed_at TEXT
|
||||
)
|
||||
""")
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
return conn
|
||||
|
||||
@@ -71,7 +73,9 @@ class _WOView:
|
||||
self.submitter = row.get("submitter", "dashboard")
|
||||
self.status = _EnumLike(row.get("status", "submitted"))
|
||||
raw_files = row.get("related_files", "")
|
||||
self.related_files = [f.strip() for f in raw_files.split(",") if f.strip()] if raw_files else []
|
||||
self.related_files = (
|
||||
[f.strip() for f in raw_files.split(",") if f.strip()] if raw_files else []
|
||||
)
|
||||
self.result = row.get("result", "")
|
||||
self.rejection_reason = row.get("rejection_reason", "")
|
||||
self.created_at = row.get("created_at", "")
|
||||
@@ -98,6 +102,7 @@ def _query_wos(db, statuses):
|
||||
# Page route
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/work-orders/queue", response_class=HTMLResponse)
|
||||
async def work_orders_page(request: Request):
|
||||
db = _get_db()
|
||||
@@ -109,21 +114,26 @@ async def work_orders_page(request: Request):
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return templates.TemplateResponse(request, "work_orders.html", {
|
||||
"pending_count": len(pending),
|
||||
"pending": pending,
|
||||
"active": active,
|
||||
"completed": completed,
|
||||
"rejected": rejected,
|
||||
"priorities": PRIORITIES,
|
||||
"categories": CATEGORIES,
|
||||
})
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"work_orders.html",
|
||||
{
|
||||
"pending_count": len(pending),
|
||||
"pending": pending,
|
||||
"active": active,
|
||||
"completed": completed,
|
||||
"rejected": rejected,
|
||||
"priorities": PRIORITIES,
|
||||
"categories": CATEGORIES,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Form submit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/work-orders/submit", response_class=HTMLResponse)
|
||||
async def submit_work_order(
|
||||
request: Request,
|
||||
@@ -159,6 +169,7 @@ async def submit_work_order(
|
||||
# HTMX partials
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/work-orders/queue/pending", response_class=HTMLResponse)
|
||||
async def pending_partial(request: Request):
|
||||
db = _get_db()
|
||||
@@ -174,7 +185,9 @@ async def pending_partial(request: Request):
|
||||
parts = []
|
||||
for wo in wos:
|
||||
parts.append(
|
||||
templates.TemplateResponse(request, "partials/work_order_card.html", {"wo": wo}).body.decode()
|
||||
templates.TemplateResponse(
|
||||
request, "partials/work_order_card.html", {"wo": wo}
|
||||
).body.decode()
|
||||
)
|
||||
return HTMLResponse("".join(parts))
|
||||
|
||||
@@ -194,7 +207,9 @@ async def active_partial(request: Request):
|
||||
parts = []
|
||||
for wo in wos:
|
||||
parts.append(
|
||||
templates.TemplateResponse(request, "partials/work_order_card.html", {"wo": wo}).body.decode()
|
||||
templates.TemplateResponse(
|
||||
request, "partials/work_order_card.html", {"wo": wo}
|
||||
).body.decode()
|
||||
)
|
||||
return HTMLResponse("".join(parts))
|
||||
|
||||
@@ -203,8 +218,11 @@ async def active_partial(request: Request):
|
||||
# Action endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _update_status(request: Request, wo_id: str, new_status: str, **extra):
|
||||
completed_at = datetime.utcnow().isoformat() if new_status in ("completed", "rejected") else None
|
||||
completed_at = (
|
||||
datetime.utcnow().isoformat() if new_status in ("completed", "rejected") else None
|
||||
)
|
||||
db = _get_db()
|
||||
try:
|
||||
sets = ["status=?", "completed_at=COALESCE(?, completed_at)"]
|
||||
|
||||
@@ -3,7 +3,7 @@ from dataclasses import dataclass, field
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
role: str # "user" | "agent" | "error"
|
||||
role: str # "user" | "agent" | "error"
|
||||
content: str
|
||||
timestamp: str
|
||||
source: str = "browser" # "browser" | "api" | "telegram" | "discord" | "system"
|
||||
@@ -16,7 +16,9 @@ class MessageLog:
|
||||
self._entries: list[Message] = []
|
||||
|
||||
def append(self, role: str, content: str, timestamp: str, source: str = "browser") -> None:
    """Add one message to the end of the log.

    Args:
        role: Who produced the message.
        content: Message text.
        timestamp: Timestamp string stored as given.
        source: Channel the message arrived on; defaults to "browser".
    """
    entry = Message(role=role, content=content, timestamp=timestamp, source=source)
    self._entries.append(entry)
|
||||
|
||||
def all(self) -> list[Message]:
    """Return a new list holding every logged message in insertion order."""
    # A copy is handed out so callers cannot alter the internal list.
    return [*self._entries]
|
||||
|
||||
90
src/dashboard/templates/experiments.html
Normal file
90
src/dashboard/templates/experiments.html
Normal file
@@ -0,0 +1,90 @@
|
||||
{% extends "base.html" %}
|
||||
|
||||
{% block title %}{{ page_title }}{% endblock %}
|
||||
|
||||
{% block extra_styles %}
|
||||
<style>
|
||||
.experiments-container { max-width: 1000px; margin: 0 auto; }
|
||||
.exp-header { display: flex; justify-content: space-between; align-items: center; margin-bottom: 20px; }
|
||||
.exp-title { font-size: 1.3rem; font-weight: 700; color: var(--text-bright); }
|
||||
.exp-subtitle { font-size: 0.8rem; color: var(--text-dim); margin-top: 2px; }
|
||||
.exp-config { display: flex; gap: 16px; font-size: 0.8rem; color: var(--text-dim); }
|
||||
.exp-config span { background: var(--glass-bg); border: 1px solid var(--border); padding: 4px 10px; border-radius: 6px; }
|
||||
.exp-table { width: 100%; border-collapse: collapse; font-size: 0.85rem; }
|
||||
.exp-table th { text-align: left; padding: 8px 12px; color: var(--text-dim); border-bottom: 1px solid var(--border); font-weight: 600; }
|
||||
.exp-table td { padding: 8px 12px; border-bottom: 1px solid var(--border); color: var(--text); }
|
||||
.exp-table tr:hover { background: var(--glass-bg); }
|
||||
.metric-good { color: var(--success); }
|
||||
.metric-bad { color: var(--danger); }
|
||||
.btn-start { background: var(--accent); color: #fff; border: none; padding: 8px 18px; border-radius: 6px; cursor: pointer; font-size: 0.85rem; }
|
||||
.btn-start:hover { opacity: 0.9; }
|
||||
.btn-start:disabled { opacity: 0.4; cursor: not-allowed; }
|
||||
.disabled-note { font-size: 0.8rem; color: var(--text-dim); margin-top: 8px; }
|
||||
.empty-state { text-align: center; padding: 40px; color: var(--text-dim); }
|
||||
</style>
|
||||
{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
<div class="experiments-container">
|
||||
<div class="exp-header">
|
||||
<div>
|
||||
<div class="exp-title">Autoresearch Experiments</div>
|
||||
<div class="exp-subtitle">Autonomous ML experiment loops — modify code, train, evaluate, iterate</div>
|
||||
</div>
|
||||
<div>
|
||||
{% if enabled %}
|
||||
<button class="btn-start"
|
||||
hx-post="/experiments/start"
|
||||
hx-target="#experiment-status"
|
||||
hx-swap="innerHTML">
|
||||
Start Experiment
|
||||
</button>
|
||||
{% else %}
|
||||
<button class="btn-start" disabled>Disabled</button>
|
||||
<div class="disabled-note">Set AUTORESEARCH_ENABLED=true to enable</div>
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="exp-config">
|
||||
<span>Metric: {{ metric_name }}</span>
|
||||
<span>Budget: {{ time_budget }}s</span>
|
||||
<span>Max iters: {{ max_iterations }}</span>
|
||||
</div>
|
||||
|
||||
<div id="experiment-status" style="margin: 12px 0;"></div>
|
||||
|
||||
{% if history %}
|
||||
<table class="exp-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>#</th>
|
||||
<th>{{ metric_name }}</th>
|
||||
<th>Duration</th>
|
||||
<th>Status</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for run in history %}
  {# One table row per recorded run. Every field is read with .get() so a
     run that lacks a key renders a dash instead of failing the whole page. #}
  {% set metric = run.get("metric") %}
  {% set duration = run.get("duration_s") %}
  <tr>
    <td>{{ loop.index }}</td>
    <td>{% if metric is not none %}{{ "%.4f"|format(metric) }}{% else %}—{% endif %}</td>
    {# The unit suffix is only shown when a duration was actually recorded. #}
    <td>{% if duration is not none %}{{ duration }}s{% else %}—{% endif %}</td>
    <td>{% if run.get("success") %}OK{% else %}{{ run.get("error", "failed") }}{% endif %}</td>
  </tr>
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
{% else %}
|
||||
<div class="empty-state">
|
||||
No experiments yet. Start one to begin autonomous training.
|
||||
</div>
|
||||
{% endif %}
|
||||
</div>
|
||||
{% endblock %}
|
||||
@@ -119,9 +119,7 @@ def capture_error(
|
||||
return None
|
||||
|
||||
# Format the stack trace
|
||||
tb_str = "".join(
|
||||
traceback.format_exception(type(exc), exc, exc.__traceback__)
|
||||
)
|
||||
tb_str = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
|
||||
|
||||
# Extract file/line from traceback
|
||||
tb_obj = exc.__traceback__
|
||||
|
||||
@@ -19,38 +19,39 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class EventBroadcaster:
|
||||
"""Broadcasts events to WebSocket clients.
|
||||
|
||||
|
||||
Usage:
|
||||
from infrastructure.events.broadcaster import event_broadcaster
|
||||
event_broadcaster.broadcast(event)
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self) -> None:
    """Start without a WebSocket manager; _get_ws_manager() fills it in later."""
    self._ws_manager: Optional = None
|
||||
|
||||
|
||||
def _get_ws_manager(self):
|
||||
"""Lazy import to avoid circular deps."""
|
||||
if self._ws_manager is None:
|
||||
try:
|
||||
from infrastructure.ws_manager.handler import ws_manager
|
||||
|
||||
self._ws_manager = ws_manager
|
||||
except Exception as exc:
|
||||
logger.debug("WebSocket manager not available: %s", exc)
|
||||
return self._ws_manager
|
||||
|
||||
|
||||
async def broadcast(self, event: EventLogEntry) -> int:
|
||||
"""Broadcast an event to all connected WebSocket clients.
|
||||
|
||||
|
||||
Args:
|
||||
event: The event to broadcast
|
||||
|
||||
|
||||
Returns:
|
||||
Number of clients notified
|
||||
"""
|
||||
ws_manager = self._get_ws_manager()
|
||||
if not ws_manager:
|
||||
return 0
|
||||
|
||||
|
||||
# Build message payload
|
||||
payload = {
|
||||
"type": "event",
|
||||
@@ -62,9 +63,9 @@ class EventBroadcaster:
|
||||
"agent_id": event.agent_id,
|
||||
"timestamp": event.timestamp,
|
||||
"data": event.data,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
# Broadcast to all connected clients
|
||||
count = await ws_manager.broadcast_json(payload)
|
||||
@@ -73,10 +74,10 @@ class EventBroadcaster:
|
||||
except Exception as exc:
|
||||
logger.error("Failed to broadcast event: %s", exc)
|
||||
return 0
|
||||
|
||||
|
||||
def broadcast_sync(self, event: EventLogEntry) -> None:
|
||||
"""Synchronous wrapper for broadcast.
|
||||
|
||||
|
||||
Use this from synchronous code - it schedules the async broadcast
|
||||
in the event loop if one is running.
|
||||
"""
|
||||
@@ -151,11 +152,11 @@ def get_event_label(event_type: str) -> str:
|
||||
|
||||
def format_event_for_display(event: EventLogEntry) -> dict:
|
||||
"""Format event for display in activity feed.
|
||||
|
||||
|
||||
Returns dict with display-friendly fields.
|
||||
"""
|
||||
data = event.data or {}
|
||||
|
||||
|
||||
# Build description based on event type
|
||||
description = ""
|
||||
if event.event_type.value == "task.created":
|
||||
@@ -178,7 +179,7 @@ def format_event_for_display(event: EventLogEntry) -> dict:
|
||||
val = str(data[key])
|
||||
description = val[:60] + "..." if len(val) > 60 else val
|
||||
break
|
||||
|
||||
|
||||
return {
|
||||
"id": event.id,
|
||||
"icon": get_event_icon(event.event_type.value),
|
||||
|
||||
@@ -16,6 +16,7 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass
|
||||
class Event:
|
||||
"""A typed event in the system."""
|
||||
|
||||
type: str # e.g., "agent.task.assigned", "tool.execution.completed"
|
||||
source: str # Agent or component that emitted the event
|
||||
data: dict = field(default_factory=dict)
|
||||
@@ -29,15 +30,15 @@ EventHandler = Callable[[Event], Coroutine[Any, Any, None]]
|
||||
|
||||
class EventBus:
|
||||
"""Async event bus for publish/subscribe pattern.
|
||||
|
||||
|
||||
Usage:
|
||||
bus = EventBus()
|
||||
|
||||
|
||||
# Subscribe to events
|
||||
@bus.subscribe("agent.task.*")
|
||||
async def handle_task(event: Event):
|
||||
print(f"Task event: {event.data}")
|
||||
|
||||
|
||||
# Publish events
|
||||
await bus.publish(Event(
|
||||
type="agent.task.assigned",
|
||||
@@ -45,88 +46,89 @@ class EventBus:
|
||||
data={"task_id": "123", "agent": "forge"}
|
||||
))
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self) -> None:
    """Create a bus with no subscribers and an empty event history."""
    # Subscription pattern -> handlers registered under that exact pattern.
    self._subscribers: dict[str, list[EventHandler]] = {}
    # Published events in publish order; publish() trims it to _max_history.
    self._history: list[Event] = []
    self._max_history = 1000
    logger.info("EventBus initialized")
|
||||
|
||||
|
||||
def subscribe(self, event_pattern: str) -> Callable[[EventHandler], EventHandler]:
    """Build a decorator that registers a handler under ``event_pattern``.

    Patterns support wildcards:
    - "agent.task.assigned" — exact match
    - "agent.task.*" — any task event
    - "agent.*" — any agent event
    - "*" — all events

    The decorated handler is returned unchanged, so it stays directly callable.
    """

    def register(handler: EventHandler) -> EventHandler:
        # setdefault creates the handler list the first time a pattern is seen.
        self._subscribers.setdefault(event_pattern, []).append(handler)
        logger.debug("Subscribed handler to '%s'", event_pattern)
        return handler

    return register
|
||||
|
||||
|
||||
def unsubscribe(self, event_pattern: str, handler: EventHandler) -> bool:
    """Detach ``handler`` from ``event_pattern``.

    Returns:
        True when a registration was removed; False when the pattern has no
        subscribers or the handler was not registered under it.
    """
    registered = self._subscribers.get(event_pattern)
    if registered is None or handler not in registered:
        return False

    registered.remove(handler)
    logger.debug("Unsubscribed handler from '%s'", event_pattern)
    return True
|
||||
|
||||
|
||||
async def publish(self, event: Event) -> int:
    """Record ``event`` and deliver it to every subscriber whose pattern matches.

    Matching handlers are awaited concurrently. Each one is wrapped by
    ``_invoke_handler``, so a failing handler is logged rather than raised.

    Returns:
        Number of handlers invoked
    """
    # Record first, then keep only the newest _max_history entries.
    self._history.append(event)
    if len(self._history) > self._max_history:
        self._history = self._history[-self._max_history :]

    # Collect handlers pattern by pattern, preserving registration order.
    matched: list[EventHandler] = [
        handler
        for pattern, pattern_handlers in self._subscribers.items()
        if self._match_pattern(event.type, pattern)
        for handler in pattern_handlers
    ]

    if matched:
        await asyncio.gather(
            *(self._invoke_handler(handler, event) for handler in matched),
            return_exceptions=True,
        )

    logger.debug("Published event '%s' to %d handlers", event.type, len(matched))
    return len(matched)
|
||||
|
||||
|
||||
async def _invoke_handler(self, handler: EventHandler, event: Event) -> None:
    """Await one handler; log any exception it raises instead of propagating it."""
    try:
        await handler(event)
    except Exception as exc:
        # Swallowed on purpose: the failure is reported through the log only.
        logger.error("Event handler failed for '%s': %s", event.type, exc)
|
||||
|
||||
|
||||
def _match_pattern(self, event_type: str, pattern: str) -> bool:
|
||||
"""Check if event type matches a wildcard pattern."""
|
||||
if pattern == "*":
|
||||
return True
|
||||
|
||||
|
||||
if pattern.endswith(".*"):
|
||||
prefix = pattern[:-2]
|
||||
return event_type.startswith(prefix + ".")
|
||||
|
||||
|
||||
return event_type == pattern
|
||||
|
||||
|
||||
def get_history(
|
||||
self,
|
||||
event_type: str | None = None,
|
||||
@@ -135,15 +137,15 @@ class EventBus:
|
||||
) -> list[Event]:
|
||||
"""Get recent event history with optional filtering."""
|
||||
events = self._history
|
||||
|
||||
|
||||
if event_type:
|
||||
events = [e for e in events if e.type == event_type]
|
||||
|
||||
|
||||
if source:
|
||||
events = [e for e in events if e.source == source]
|
||||
|
||||
|
||||
return events[-limit:]
|
||||
|
||||
|
||||
def clear_history(self) -> None:
|
||||
"""Clear event history."""
|
||||
self._history.clear()
|
||||
@@ -156,11 +158,13 @@ event_bus = EventBus()
|
||||
# Convenience functions
|
||||
async def emit(event_type: str, source: str, data: dict) -> int:
    """Wrap the arguments in an Event and publish it on the module-level bus.

    Returns:
        Whatever ``event_bus.publish`` returns: the number of handlers invoked.
    """
    event = Event(type=event_type, source=source, data=data)
    return await event_bus.publish(event)
|
||||
|
||||
|
||||
def on(event_pattern: str) -> Callable[[EventHandler], EventHandler]:
|
||||
|
||||
@@ -11,7 +11,7 @@ Usage:
|
||||
result = await git_hand.run("status")
|
||||
"""
|
||||
|
||||
from infrastructure.hands.shell import shell_hand
|
||||
from infrastructure.hands.git import git_hand
|
||||
from infrastructure.hands.shell import shell_hand
|
||||
|
||||
__all__ = ["shell_hand", "git_hand"]
|
||||
|
||||
@@ -25,16 +25,18 @@ from config import settings
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Operations that require explicit confirmation before execution
|
||||
DESTRUCTIVE_OPS = frozenset(
    (
        # History rewriting on the remote
        "push --force",
        "push -f",
        # Discarding local commits or working-tree changes
        "reset --hard",
        "checkout -- .",
        "restore .",
        # Deleting untracked files
        "clean -fd",
        "clean -f",
        # Deleting branches without a merge check
        "branch -D",
    )
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -190,7 +192,9 @@ class GitHand:
|
||||
flag = "-b" if create else ""
|
||||
return await self.run(f"checkout {flag} {branch}".strip())
|
||||
|
||||
async def push(self, remote: str = "origin", branch: str = "", force: bool = False) -> GitResult:
|
||||
async def push(
|
||||
self, remote: str = "origin", branch: str = "", force: bool = False
|
||||
) -> GitResult:
|
||||
"""Push to remote. Force-push requires explicit opt-in."""
|
||||
args = f"push -u {remote} {branch}".strip()
|
||||
if force:
|
||||
|
||||
@@ -26,15 +26,17 @@ from config import settings
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Commands that are always blocked regardless of allow-list
|
||||
_BLOCKED_COMMANDS = frozenset({
|
||||
"rm -rf /",
|
||||
"rm -rf /*",
|
||||
"mkfs",
|
||||
"dd if=/dev/zero",
|
||||
":(){ :|:& };:", # fork bomb
|
||||
"> /dev/sda",
|
||||
"chmod -R 777 /",
|
||||
})
|
||||
_BLOCKED_COMMANDS = frozenset(
|
||||
{
|
||||
"rm -rf /",
|
||||
"rm -rf /*",
|
||||
"mkfs",
|
||||
"dd if=/dev/zero",
|
||||
":(){ :|:& };:", # fork bomb
|
||||
"> /dev/sda",
|
||||
"chmod -R 777 /",
|
||||
}
|
||||
)
|
||||
|
||||
# Default allow-list: safe build/dev commands
|
||||
DEFAULT_ALLOWED_PREFIXES = (
|
||||
@@ -199,9 +201,7 @@ class ShellHand:
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
latency = (time.time() - start) * 1000
|
||||
logger.warning(
|
||||
"Shell command timed out after %ds: %s", effective_timeout, command
|
||||
)
|
||||
logger.warning("Shell command timed out after %ds: %s", effective_timeout, command)
|
||||
return ShellResult(
|
||||
command=command,
|
||||
success=False,
|
||||
|
||||
@@ -11,15 +11,17 @@ the tool registry.
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from infrastructure.hands.shell import shell_hand
|
||||
from infrastructure.hands.git import git_hand
|
||||
from infrastructure.hands.shell import shell_hand
|
||||
|
||||
try:
|
||||
from mcp.schemas.base import create_tool_schema
|
||||
except ImportError:
|
||||
|
||||
def create_tool_schema(**kwargs):
|
||||
return kwargs
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Tool schemas ─────────────────────────────────────────────────────────────
|
||||
@@ -83,6 +85,7 @@ PERSONA_LOCAL_HAND_MAP: dict[str, list[str]] = {
|
||||
|
||||
# ── Handlers ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _handle_shell(**kwargs: Any) -> str:
|
||||
"""Handler for the shell MCP tool."""
|
||||
command = kwargs.get("command", "")
|
||||
|
||||
@@ -1,12 +1,5 @@
|
||||
"""Infrastructure models package."""
|
||||
|
||||
from infrastructure.models.registry import (
|
||||
CustomModel,
|
||||
ModelFormat,
|
||||
ModelRegistry,
|
||||
ModelRole,
|
||||
model_registry,
|
||||
)
|
||||
from infrastructure.models.multimodal import (
|
||||
ModelCapability,
|
||||
ModelInfo,
|
||||
@@ -17,6 +10,13 @@ from infrastructure.models.multimodal import (
|
||||
model_supports_vision,
|
||||
pull_model_with_fallback,
|
||||
)
|
||||
from infrastructure.models.registry import (
|
||||
CustomModel,
|
||||
ModelFormat,
|
||||
ModelRegistry,
|
||||
ModelRole,
|
||||
model_registry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Registry
|
||||
|
||||
@@ -21,39 +21,130 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ModelCapability(Enum):
    """Kinds of work a model may be able to do."""

    # Standard text completion
    TEXT = auto()
    # Image understanding
    VISION = auto()
    # Audio/speech processing
    AUDIO = auto()
    # Function calling / tool use
    TOOLS = auto()
    # Structured output / JSON mode
    JSON = auto()
    # Streaming responses
    STREAMING = auto()
|
||||
|
||||
|
||||
# Known model capabilities (local Ollama models)
|
||||
# These are used when we can't query the model directly
|
||||
KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
|
||||
# Llama 3.x series
|
||||
"llama3.1": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"llama3.1:8b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"llama3.1:8b-instruct": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"llama3.1:70b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"llama3.1:405b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"llama3.2": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
||||
"llama3.1": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"llama3.1:8b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"llama3.1:8b-instruct": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"llama3.1:70b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"llama3.1:405b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"llama3.2": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
ModelCapability.VISION,
|
||||
},
|
||||
"llama3.2:1b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"llama3.2:3b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
||||
"llama3.2-vision": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
||||
"llama3.2-vision:11b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
||||
|
||||
"llama3.2:3b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
ModelCapability.VISION,
|
||||
},
|
||||
"llama3.2-vision": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
ModelCapability.VISION,
|
||||
},
|
||||
"llama3.2-vision:11b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
ModelCapability.VISION,
|
||||
},
|
||||
# Qwen series
|
||||
"qwen2.5": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"qwen2.5:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"qwen2.5:14b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"qwen2.5:32b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"qwen2.5:72b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"qwen2.5-vl": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
||||
"qwen2.5-vl:3b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
||||
"qwen2.5-vl:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
||||
|
||||
"qwen2.5": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"qwen2.5:7b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"qwen2.5:14b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"qwen2.5:32b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"qwen2.5:72b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"qwen2.5-vl": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
ModelCapability.VISION,
|
||||
},
|
||||
"qwen2.5-vl:3b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
ModelCapability.VISION,
|
||||
},
|
||||
"qwen2.5-vl:7b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
ModelCapability.VISION,
|
||||
},
|
||||
# DeepSeek series
|
||||
"deepseek-r1": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"deepseek-r1:1.5b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
@@ -61,21 +152,48 @@ KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
|
||||
"deepseek-r1:14b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"deepseek-r1:32b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"deepseek-r1:70b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"deepseek-v3": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
|
||||
"deepseek-v3": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
# Gemma series
|
||||
"gemma2": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"gemma2:2b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"gemma2:9b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"gemma2:27b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
|
||||
# Mistral series
|
||||
"mistral": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"mistral:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"mistral-nemo": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"mistral-small": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"mistral-large": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
|
||||
"mistral": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"mistral:7b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"mistral-nemo": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"mistral-small": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"mistral-large": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
# Vision-specific models
|
||||
"llava": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||
"llava:7b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||
@@ -86,21 +204,48 @@ KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
|
||||
"bakllava": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||
"moondream": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||
"moondream:1.8b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||
|
||||
# Phi series
|
||||
"phi3": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"phi3:3.8b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"phi3:14b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"phi4": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
|
||||
"phi4": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
# Command R
|
||||
"command-r": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"command-r:35b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"command-r-plus": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
|
||||
"command-r": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"command-r:35b": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"command-r-plus": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
# Granite (IBM)
|
||||
"granite3-dense": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"granite3-moe": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||
"granite3-dense": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
"granite3-moe": {
|
||||
ModelCapability.TEXT,
|
||||
ModelCapability.TOOLS,
|
||||
ModelCapability.JSON,
|
||||
ModelCapability.STREAMING,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -108,15 +253,15 @@ KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
|
||||
# These are tried in order when the primary model doesn't support a capability
|
||||
DEFAULT_FALLBACK_CHAINS: dict[ModelCapability, list[str]] = {
|
||||
ModelCapability.VISION: [
|
||||
"llama3.2:3b", # Fast vision model
|
||||
"llava:7b", # Classic vision model
|
||||
"qwen2.5-vl:3b", # Qwen vision
|
||||
"moondream:1.8b", # Tiny vision model (last resort)
|
||||
"llama3.2:3b", # Fast vision model
|
||||
"llava:7b", # Classic vision model
|
||||
"qwen2.5-vl:3b", # Qwen vision
|
||||
"moondream:1.8b", # Tiny vision model (last resort)
|
||||
],
|
||||
ModelCapability.TOOLS: [
|
||||
"llama3.1:8b-instruct", # Best tool use
|
||||
"llama3.2:3b", # Smaller but capable
|
||||
"qwen2.5:7b", # Reliable fallback
|
||||
"llama3.2:3b", # Smaller but capable
|
||||
"qwen2.5:7b", # Reliable fallback
|
||||
],
|
||||
ModelCapability.AUDIO: [
|
||||
# Audio models are less common in Ollama
|
||||
@@ -128,13 +273,14 @@ DEFAULT_FALLBACK_CHAINS: dict[ModelCapability, list[str]] = {
|
||||
@dataclass
class ModelInfo:
    """Information about a model's capabilities and availability."""

    name: str
    # Capabilities attributed to the model; each instance gets its own empty set.
    capabilities: set[ModelCapability] = field(default_factory=set)
    # NOTE(review): how is_available differs from is_pulled is not visible in
    # this block — confirm against the code that populates them.
    is_available: bool = False
    is_pulled: bool = False
    # Size in whole mebibytes when known; the Ollama refresh derives it from
    # the reported byte size with integer division.
    size_mb: Optional[int] = None
    # Free-form text; the Ollama refresh stores the model family here.
    description: str = ""

    def supports(self, capability: ModelCapability) -> bool:
        """Check if model supports a specific capability."""
        return capability in self.capabilities
|
||||
@@ -142,26 +288,26 @@ class ModelInfo:
|
||||
|
||||
class MultiModalManager:
|
||||
"""Manages multi-modal model capabilities and fallback chains.
|
||||
|
||||
|
||||
This class:
|
||||
1. Detects what capabilities each model has
|
||||
2. Maintains fallback chains for different capabilities
|
||||
3. Pulls models on-demand with automatic fallback
|
||||
4. Routes requests to appropriate models based on content type
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, ollama_url: Optional[str] = None) -> None:
    """Set up the manager and immediately ask Ollama which models it has.

    Args:
        ollama_url: Base URL of the Ollama server. When omitted (or empty),
            ``settings.ollama_url`` is used instead.
    """
    self.ollama_url = ollama_url or settings.ollama_url
    self._available_models: dict[str, ModelInfo] = {}
    # Shallow copy: the mapping is private to this instance, but the chain
    # lists are the same objects held by DEFAULT_FALLBACK_CHAINS.
    self._fallback_chains: dict[ModelCapability, list[str]] = dict(DEFAULT_FALLBACK_CHAINS)
    # Performs a network request as part of construction.
    self._refresh_available_models()
|
||||
|
||||
|
||||
def _refresh_available_models(self) -> None:
|
||||
"""Query Ollama for available models."""
|
||||
try:
|
||||
import urllib.request
|
||||
import json
|
||||
|
||||
import urllib.request
|
||||
|
||||
url = self.ollama_url.replace("localhost", "127.0.0.1")
|
||||
req = urllib.request.Request(
|
||||
f"{url}/api/tags",
|
||||
@@ -170,7 +316,7 @@ class MultiModalManager:
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=5) as response:
|
||||
data = json.loads(response.read().decode())
|
||||
|
||||
|
||||
for model_data in data.get("models", []):
|
||||
name = model_data.get("name", "")
|
||||
self._available_models[name] = ModelInfo(
|
||||
@@ -181,58 +327,53 @@ class MultiModalManager:
|
||||
size_mb=model_data.get("size", 0) // (1024 * 1024),
|
||||
description=model_data.get("details", {}).get("family", ""),
|
||||
)
|
||||
|
||||
|
||||
logger.info("Found %d models in Ollama", len(self._available_models))
|
||||
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("Could not refresh available models: %s", exc)
|
||||
|
||||
|
||||
def _detect_capabilities(self, model_name: str) -> set[ModelCapability]:
    """Look a model up in the static table of known capabilities.

    The full name (tag included) is tried first, then the name with everything
    from the first ":" stripped. A model found under neither key is treated as
    supporting text and streaming only.
    """
    untagged = model_name.split(":")[0]

    for candidate in (model_name, untagged):
        known = KNOWN_MODEL_CAPABILITIES.get(candidate)
        if known is not None:
            # Return a copy so callers cannot mutate the shared table.
            return set(known)

    logger.debug("Unknown model %s, defaulting to TEXT only", model_name)
    return {ModelCapability.TEXT, ModelCapability.STREAMING}
|
||||
|
||||
|
||||
def get_model_capabilities(self, model_name: str) -> set[ModelCapability]:
|
||||
"""Get capabilities for a specific model."""
|
||||
if model_name in self._available_models:
|
||||
return self._available_models[model_name].capabilities
|
||||
return self._detect_capabilities(model_name)
|
||||
|
||||
|
||||
def model_supports(self, model_name: str, capability: ModelCapability) -> bool:
|
||||
"""Check if a model supports a specific capability."""
|
||||
capabilities = self.get_model_capabilities(model_name)
|
||||
return capability in capabilities
|
||||
|
||||
|
||||
def get_models_with_capability(self, capability: ModelCapability) -> list[ModelInfo]:
|
||||
"""Get all available models that support a capability."""
|
||||
return [
|
||||
info for info in self._available_models.values()
|
||||
if capability in info.capabilities
|
||||
]
|
||||
|
||||
return [info for info in self._available_models.values() if capability in info.capabilities]
|
||||
|
||||
def get_best_model_for(
|
||||
self,
|
||||
capability: ModelCapability,
|
||||
preferred_model: Optional[str] = None
|
||||
self, capability: ModelCapability, preferred_model: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
"""Get the best available model for a specific capability.
|
||||
|
||||
|
||||
Args:
|
||||
capability: The required capability
|
||||
preferred_model: Preferred model to use if available and capable
|
||||
|
||||
|
||||
Returns:
|
||||
Model name or None if no suitable model found
|
||||
"""
|
||||
@@ -243,25 +384,26 @@ class MultiModalManager:
|
||||
return preferred_model
|
||||
logger.debug(
|
||||
"Preferred model %s doesn't support %s, checking fallbacks",
|
||||
preferred_model, capability.name
|
||||
preferred_model,
|
||||
capability.name,
|
||||
)
|
||||
|
||||
|
||||
# Check fallback chain for this capability
|
||||
fallback_chain = self._fallback_chains.get(capability, [])
|
||||
for model_name in fallback_chain:
|
||||
if model_name in self._available_models:
|
||||
logger.debug("Using fallback model %s for %s", model_name, capability.name)
|
||||
return model_name
|
||||
|
||||
|
||||
# Find any available model with this capability
|
||||
capable_models = self.get_models_with_capability(capability)
|
||||
if capable_models:
|
||||
# Sort by size (prefer smaller/faster models as fallback)
|
||||
capable_models.sort(key=lambda m: m.size_mb or float('inf'))
|
||||
capable_models.sort(key=lambda m: m.size_mb or float("inf"))
|
||||
return capable_models[0].name
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def pull_model_with_fallback(
|
||||
self,
|
||||
primary_model: str,
|
||||
@@ -269,58 +411,58 @@ class MultiModalManager:
|
||||
auto_pull: bool = True,
|
||||
) -> tuple[str, bool]:
|
||||
"""Pull a model with automatic fallback if unavailable.
|
||||
|
||||
|
||||
Args:
|
||||
primary_model: The desired model to use
|
||||
capability: Required capability (for finding fallback)
|
||||
auto_pull: Whether to attempt pulling missing models
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (model_name, is_fallback)
|
||||
"""
|
||||
# Check if primary model is already available
|
||||
if primary_model in self._available_models:
|
||||
return primary_model, False
|
||||
|
||||
|
||||
# Try to pull the primary model
|
||||
if auto_pull:
|
||||
if self._pull_model(primary_model):
|
||||
return primary_model, False
|
||||
|
||||
|
||||
# Need to find a fallback
|
||||
if capability:
|
||||
fallback = self.get_best_model_for(capability, primary_model)
|
||||
if fallback:
|
||||
logger.info(
|
||||
"Primary model %s unavailable, using fallback %s",
|
||||
primary_model, fallback
|
||||
"Primary model %s unavailable, using fallback %s", primary_model, fallback
|
||||
)
|
||||
return fallback, True
|
||||
|
||||
|
||||
# Last resort: use the configured default model
|
||||
default_model = settings.ollama_model
|
||||
if default_model in self._available_models:
|
||||
logger.warning(
|
||||
"Falling back to default model %s (primary: %s unavailable)",
|
||||
default_model, primary_model
|
||||
default_model,
|
||||
primary_model,
|
||||
)
|
||||
return default_model, True
|
||||
|
||||
|
||||
# Absolute last resort
|
||||
return primary_model, False
|
||||
|
||||
|
||||
def _pull_model(self, model_name: str) -> bool:
|
||||
"""Attempt to pull a model from Ollama.
|
||||
|
||||
|
||||
Returns:
|
||||
True if successful or model already exists
|
||||
"""
|
||||
try:
|
||||
import urllib.request
|
||||
import json
|
||||
|
||||
import urllib.request
|
||||
|
||||
logger.info("Pulling model: %s", model_name)
|
||||
|
||||
|
||||
url = self.ollama_url.replace("localhost", "127.0.0.1")
|
||||
req = urllib.request.Request(
|
||||
f"{url}/api/pull",
|
||||
@@ -328,7 +470,7 @@ class MultiModalManager:
|
||||
headers={"Content-Type": "application/json"},
|
||||
data=json.dumps({"name": model_name, "stream": False}).encode(),
|
||||
)
|
||||
|
||||
|
||||
with urllib.request.urlopen(req, timeout=300) as response:
|
||||
if response.status == 200:
|
||||
logger.info("Successfully pulled model: %s", model_name)
|
||||
@@ -338,55 +480,51 @@ class MultiModalManager:
|
||||
else:
|
||||
logger.error("Failed to pull %s: HTTP %s", model_name, response.status)
|
||||
return False
|
||||
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Error pulling model %s: %s", model_name, exc)
|
||||
return False
|
||||
|
||||
def configure_fallback_chain(
|
||||
self,
|
||||
capability: ModelCapability,
|
||||
models: list[str]
|
||||
) -> None:
|
||||
|
||||
def configure_fallback_chain(self, capability: ModelCapability, models: list[str]) -> None:
|
||||
"""Configure a custom fallback chain for a capability."""
|
||||
self._fallback_chains[capability] = models
|
||||
logger.info("Configured fallback chain for %s: %s", capability.name, models)
|
||||
|
||||
|
||||
def get_fallback_chain(self, capability: ModelCapability) -> list[str]:
|
||||
"""Get the fallback chain for a capability."""
|
||||
return list(self._fallback_chains.get(capability, []))
|
||||
|
||||
|
||||
def list_available_models(self) -> list[ModelInfo]:
|
||||
"""List all available models with their capabilities."""
|
||||
return list(self._available_models.values())
|
||||
|
||||
|
||||
def refresh(self) -> None:
|
||||
"""Refresh the list of available models."""
|
||||
self._refresh_available_models()
|
||||
|
||||
|
||||
def get_model_for_content(
|
||||
self,
|
||||
content_type: str, # "text", "image", "audio", "multimodal"
|
||||
preferred_model: Optional[str] = None,
|
||||
) -> tuple[str, bool]:
|
||||
"""Get appropriate model based on content type.
|
||||
|
||||
|
||||
Args:
|
||||
content_type: Type of content (text, image, audio, multimodal)
|
||||
preferred_model: User's preferred model
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (model_name, is_fallback)
|
||||
"""
|
||||
content_type = content_type.lower()
|
||||
|
||||
|
||||
if content_type in ("image", "vision", "multimodal"):
|
||||
# For vision content, we need a vision-capable model
|
||||
return self.pull_model_with_fallback(
|
||||
preferred_model or "llava:7b",
|
||||
capability=ModelCapability.VISION,
|
||||
)
|
||||
|
||||
|
||||
elif content_type == "audio":
|
||||
# Audio support is limited in Ollama
|
||||
# Would need specific audio models
|
||||
@@ -395,7 +533,7 @@ class MultiModalManager:
|
||||
preferred_model or settings.ollama_model,
|
||||
capability=ModelCapability.TEXT,
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
# Standard text content
|
||||
return self.pull_model_with_fallback(
|
||||
@@ -417,8 +555,7 @@ def get_multimodal_manager() -> MultiModalManager:
|
||||
|
||||
|
||||
def get_model_for_capability(
|
||||
capability: ModelCapability,
|
||||
preferred_model: Optional[str] = None
|
||||
capability: ModelCapability, preferred_model: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
"""Convenience function to get best model for a capability."""
|
||||
return get_multimodal_manager().get_best_model_for(capability, preferred_model)
|
||||
@@ -430,9 +567,7 @@ def pull_model_with_fallback(
|
||||
auto_pull: bool = True,
|
||||
) -> tuple[str, bool]:
|
||||
"""Convenience function to pull model with fallback."""
|
||||
return get_multimodal_manager().pull_model_with_fallback(
|
||||
primary_model, capability, auto_pull
|
||||
)
|
||||
return get_multimodal_manager().pull_model_with_fallback(primary_model, capability, auto_pull)
|
||||
|
||||
|
||||
def model_supports_vision(model_name: str) -> bool:
|
||||
|
||||
@@ -26,26 +26,29 @@ DB_PATH = Path("data/swarm.db")
|
||||
|
||||
class ModelFormat(str, Enum):
|
||||
"""Supported model weight formats."""
|
||||
GGUF = "gguf" # Ollama-compatible quantised weights
|
||||
SAFETENSORS = "safetensors" # HuggingFace safetensors
|
||||
HF_CHECKPOINT = "hf" # Full HuggingFace checkpoint directory
|
||||
OLLAMA = "ollama" # Already loaded in Ollama by name
|
||||
|
||||
GGUF = "gguf" # Ollama-compatible quantised weights
|
||||
SAFETENSORS = "safetensors" # HuggingFace safetensors
|
||||
HF_CHECKPOINT = "hf" # Full HuggingFace checkpoint directory
|
||||
OLLAMA = "ollama" # Already loaded in Ollama by name
|
||||
|
||||
|
||||
class ModelRole(str, Enum):
|
||||
"""Role a model can play in the system (OpenClaw-RL style)."""
|
||||
GENERAL = "general" # Default agent inference
|
||||
REWARD = "reward" # Process Reward Model (PRM) scoring
|
||||
TEACHER = "teacher" # On-policy distillation teacher
|
||||
JUDGE = "judge" # Output quality evaluation
|
||||
|
||||
GENERAL = "general" # Default agent inference
|
||||
REWARD = "reward" # Process Reward Model (PRM) scoring
|
||||
TEACHER = "teacher" # On-policy distillation teacher
|
||||
JUDGE = "judge" # Output quality evaluation
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomModel:
|
||||
"""A registered custom model."""
|
||||
|
||||
name: str
|
||||
format: ModelFormat
|
||||
path: str # Absolute path or Ollama model name
|
||||
path: str # Absolute path or Ollama model name
|
||||
role: ModelRole = ModelRole.GENERAL
|
||||
context_window: int = 4096
|
||||
description: str = ""
|
||||
@@ -141,10 +144,16 @@ class ModelRegistry:
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
model.name, model.format.value, model.path,
|
||||
model.role.value, model.context_window, model.description,
|
||||
model.registered_at, int(model.active),
|
||||
model.default_temperature, model.max_tokens,
|
||||
model.name,
|
||||
model.format.value,
|
||||
model.path,
|
||||
model.role.value,
|
||||
model.context_window,
|
||||
model.description,
|
||||
model.registered_at,
|
||||
int(model.active),
|
||||
model.default_temperature,
|
||||
model.max_tokens,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
@@ -160,9 +169,7 @@ class ModelRegistry:
|
||||
return False
|
||||
conn = _get_conn()
|
||||
conn.execute("DELETE FROM custom_models WHERE name = ?", (name,))
|
||||
conn.execute(
|
||||
"DELETE FROM agent_model_assignments WHERE model_name = ?", (name,)
|
||||
)
|
||||
conn.execute("DELETE FROM agent_model_assignments WHERE model_name = ?", (name,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
del self._models[name]
|
||||
|
||||
@@ -9,8 +9,8 @@ No cloud push services — everything stays local.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import subprocess
|
||||
import platform
|
||||
import subprocess
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
@@ -25,9 +25,7 @@ class Notification:
|
||||
title: str
|
||||
message: str
|
||||
category: str # swarm | task | agent | system | payment
|
||||
timestamp: str = field(
|
||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
read: bool = False
|
||||
|
||||
|
||||
@@ -74,9 +72,11 @@ class PushNotifier:
|
||||
def _native_notify(self, title: str, message: str) -> None:
|
||||
"""Send a native macOS notification via osascript."""
|
||||
try:
|
||||
safe_message = message.replace("\\", "\\\\").replace('"', '\\"')
|
||||
safe_title = title.replace("\\", "\\\\").replace('"', '\\"')
|
||||
script = (
|
||||
f'display notification "{message}" '
|
||||
f'with title "Agent Dashboard" subtitle "{title}"'
|
||||
f'display notification "{safe_message}" '
|
||||
f'with title "Agent Dashboard" subtitle "{safe_title}"'
|
||||
)
|
||||
subprocess.Popen(
|
||||
["osascript", "-e", script],
|
||||
@@ -114,7 +114,7 @@ class PushNotifier:
|
||||
def clear(self) -> None:
|
||||
self._notifications.clear()
|
||||
|
||||
def add_listener(self, callback) -> None:
|
||||
def add_listener(self, callback: "Callable[[Notification], None]") -> None:
|
||||
"""Register a callback for real-time notification delivery."""
|
||||
self._listeners.append(callback)
|
||||
|
||||
@@ -139,10 +139,7 @@ async def notify_briefing_ready(briefing) -> None:
|
||||
logger.info("Briefing ready but no pending approvals — skipping native notification")
|
||||
return
|
||||
|
||||
message = (
|
||||
f"Your morning briefing is ready. "
|
||||
f"{n_approvals} item(s) await your approval."
|
||||
)
|
||||
message = f"Your morning briefing is ready. " f"{n_approvals} item(s) await your approval."
|
||||
notifier.notify(
|
||||
title="Morning Briefing Ready",
|
||||
message=message,
|
||||
|
||||
@@ -156,33 +156,23 @@ class OpenFangClient:
|
||||
|
||||
async def browse(self, url: str, instruction: str = "") -> HandResult:
|
||||
"""Web automation via OpenFang's Browser hand."""
|
||||
return await self.execute_hand(
|
||||
"browser", {"url": url, "instruction": instruction}
|
||||
)
|
||||
return await self.execute_hand("browser", {"url": url, "instruction": instruction})
|
||||
|
||||
async def collect(self, target: str, depth: str = "shallow") -> HandResult:
|
||||
"""OSINT collection via OpenFang's Collector hand."""
|
||||
return await self.execute_hand(
|
||||
"collector", {"target": target, "depth": depth}
|
||||
)
|
||||
return await self.execute_hand("collector", {"target": target, "depth": depth})
|
||||
|
||||
async def predict(self, question: str, horizon: str = "1w") -> HandResult:
|
||||
"""Superforecasting via OpenFang's Predictor hand."""
|
||||
return await self.execute_hand(
|
||||
"predictor", {"question": question, "horizon": horizon}
|
||||
)
|
||||
return await self.execute_hand("predictor", {"question": question, "horizon": horizon})
|
||||
|
||||
async def find_leads(self, icp: str, max_results: int = 10) -> HandResult:
|
||||
"""Prospect discovery via OpenFang's Lead hand."""
|
||||
return await self.execute_hand(
|
||||
"lead", {"icp": icp, "max_results": max_results}
|
||||
)
|
||||
return await self.execute_hand("lead", {"icp": icp, "max_results": max_results})
|
||||
|
||||
async def research(self, topic: str, depth: str = "standard") -> HandResult:
|
||||
"""Deep research via OpenFang's Researcher hand."""
|
||||
return await self.execute_hand(
|
||||
"researcher", {"topic": topic, "depth": depth}
|
||||
)
|
||||
return await self.execute_hand("researcher", {"topic": topic, "depth": depth})
|
||||
|
||||
# ── Inventory ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -22,9 +22,11 @@ from infrastructure.openfang.client import OPENFANG_HANDS, openfang_client
|
||||
try:
|
||||
from mcp.schemas.base import create_tool_schema
|
||||
except ImportError:
|
||||
|
||||
def create_tool_schema(**kwargs):
|
||||
return kwargs
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Tool schemas ─────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Cascade LLM Router — Automatic failover between providers."""
|
||||
|
||||
from .cascade import CascadeRouter, Provider, ProviderStatus, get_router
|
||||
from .api import router
|
||||
from .cascade import CascadeRouter, Provider, ProviderStatus, get_router
|
||||
|
||||
__all__ = [
|
||||
"CascadeRouter",
|
||||
|
||||
@@ -15,6 +15,7 @@ router = APIRouter(prefix="/api/v1/router", tags=["router"])
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
"""Request body for completions."""
|
||||
|
||||
messages: list[dict[str, str]]
|
||||
model: str | None = None
|
||||
temperature: float = 0.7
|
||||
@@ -23,6 +24,7 @@ class CompletionRequest(BaseModel):
|
||||
|
||||
class CompletionResponse(BaseModel):
|
||||
"""Response from completion endpoint."""
|
||||
|
||||
content: str
|
||||
provider: str
|
||||
model: str
|
||||
@@ -31,6 +33,7 @@ class CompletionResponse(BaseModel):
|
||||
|
||||
class ProviderControl(BaseModel):
|
||||
"""Control a provider's status."""
|
||||
|
||||
action: str # "enable", "disable", "reset_circuit"
|
||||
|
||||
|
||||
@@ -45,7 +48,7 @@ async def complete(
|
||||
cascade: Annotated[CascadeRouter, Depends(get_cascade_router)],
|
||||
) -> dict[str, Any]:
|
||||
"""Complete a conversation with automatic failover.
|
||||
|
||||
|
||||
Routes through providers in priority order until one succeeds.
|
||||
"""
|
||||
try:
|
||||
@@ -108,30 +111,32 @@ async def control_provider(
|
||||
if p.name == provider_name:
|
||||
provider = p
|
||||
break
|
||||
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail=f"Provider {provider_name} not found")
|
||||
|
||||
|
||||
if control.action == "enable":
|
||||
provider.enabled = True
|
||||
provider.status = provider.status.__class__.HEALTHY
|
||||
return {"message": f"Provider {provider_name} enabled"}
|
||||
|
||||
|
||||
elif control.action == "disable":
|
||||
provider.enabled = False
|
||||
from .cascade import ProviderStatus
|
||||
|
||||
provider.status = ProviderStatus.DISABLED
|
||||
return {"message": f"Provider {provider_name} disabled"}
|
||||
|
||||
|
||||
elif control.action == "reset_circuit":
|
||||
from .cascade import CircuitState, ProviderStatus
|
||||
|
||||
provider.circuit_state = CircuitState.CLOSED
|
||||
provider.circuit_opened_at = None
|
||||
provider.half_open_calls = 0
|
||||
provider.metrics.consecutive_failures = 0
|
||||
provider.status = ProviderStatus.HEALTHY
|
||||
return {"message": f"Circuit breaker reset for {provider_name}"}
|
||||
|
||||
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown action: {control.action}")
|
||||
|
||||
@@ -142,28 +147,35 @@ async def run_health_check(
|
||||
) -> dict[str, Any]:
|
||||
"""Run health checks on all providers."""
|
||||
results = []
|
||||
|
||||
|
||||
for provider in cascade.providers:
|
||||
# Quick ping to check availability
|
||||
is_healthy = cascade._check_provider_available(provider)
|
||||
|
||||
|
||||
from .cascade import ProviderStatus
|
||||
|
||||
if is_healthy:
|
||||
if provider.status == ProviderStatus.UNHEALTHY:
|
||||
# Reset circuit if it was open but now healthy
|
||||
provider.circuit_state = provider.circuit_state.__class__.CLOSED
|
||||
provider.circuit_opened_at = None
|
||||
provider.status = ProviderStatus.HEALTHY if provider.metrics.error_rate < 0.1 else ProviderStatus.DEGRADED
|
||||
provider.status = (
|
||||
ProviderStatus.HEALTHY
|
||||
if provider.metrics.error_rate < 0.1
|
||||
else ProviderStatus.DEGRADED
|
||||
)
|
||||
else:
|
||||
provider.status = ProviderStatus.UNHEALTHY
|
||||
|
||||
results.append({
|
||||
"name": provider.name,
|
||||
"type": provider.type,
|
||||
"healthy": is_healthy,
|
||||
"status": provider.status.value,
|
||||
})
|
||||
|
||||
|
||||
results.append(
|
||||
{
|
||||
"name": provider.name,
|
||||
"type": provider.type,
|
||||
"healthy": is_healthy,
|
||||
"status": provider.status.value,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"checked_at": asyncio.get_event_loop().time(),
|
||||
"providers": results,
|
||||
@@ -177,7 +189,7 @@ async def get_config(
|
||||
) -> dict[str, Any]:
|
||||
"""Get router configuration (without secrets)."""
|
||||
cfg = cascade.config
|
||||
|
||||
|
||||
return {
|
||||
"timeout_seconds": cfg.timeout_seconds,
|
||||
"max_retries_per_provider": cfg.max_retries_per_provider,
|
||||
|
||||
@@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ProviderStatus(Enum):
|
||||
"""Health status of a provider."""
|
||||
|
||||
HEALTHY = "healthy"
|
||||
DEGRADED = "degraded" # Working but slow or occasional errors
|
||||
UNHEALTHY = "unhealthy" # Circuit breaker open
|
||||
@@ -41,22 +42,25 @@ class ProviderStatus(Enum):
|
||||
|
||||
class CircuitState(Enum):
|
||||
"""Circuit breaker state."""
|
||||
CLOSED = "closed" # Normal operation
|
||||
OPEN = "open" # Failing, rejecting requests
|
||||
|
||||
CLOSED = "closed" # Normal operation
|
||||
OPEN = "open" # Failing, rejecting requests
|
||||
HALF_OPEN = "half_open" # Testing if recovered
|
||||
|
||||
|
||||
class ContentType(Enum):
|
||||
"""Type of content in the request."""
|
||||
|
||||
TEXT = "text"
|
||||
VISION = "vision" # Contains images
|
||||
AUDIO = "audio" # Contains audio
|
||||
VISION = "vision" # Contains images
|
||||
AUDIO = "audio" # Contains audio
|
||||
MULTIMODAL = "multimodal" # Multiple content types
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderMetrics:
|
||||
"""Metrics for a single provider."""
|
||||
|
||||
total_requests: int = 0
|
||||
successful_requests: int = 0
|
||||
failed_requests: int = 0
|
||||
@@ -64,13 +68,13 @@ class ProviderMetrics:
|
||||
last_request_time: Optional[str] = None
|
||||
last_error_time: Optional[str] = None
|
||||
consecutive_failures: int = 0
|
||||
|
||||
|
||||
@property
|
||||
def avg_latency_ms(self) -> float:
|
||||
if self.total_requests == 0:
|
||||
return 0.0
|
||||
return self.total_latency_ms / self.total_requests
|
||||
|
||||
|
||||
@property
|
||||
def error_rate(self) -> float:
|
||||
if self.total_requests == 0:
|
||||
@@ -81,6 +85,7 @@ class ProviderMetrics:
|
||||
@dataclass
|
||||
class ModelCapability:
|
||||
"""Capabilities a model supports."""
|
||||
|
||||
name: str
|
||||
supports_vision: bool = False
|
||||
supports_audio: bool = False
|
||||
@@ -93,6 +98,7 @@ class ModelCapability:
|
||||
@dataclass
|
||||
class Provider:
|
||||
"""LLM provider configuration and state."""
|
||||
|
||||
name: str
|
||||
type: str # ollama, openai, anthropic, airllm
|
||||
enabled: bool
|
||||
@@ -101,14 +107,14 @@ class Provider:
|
||||
api_key: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
models: list[dict] = field(default_factory=list)
|
||||
|
||||
|
||||
# Runtime state
|
||||
status: ProviderStatus = ProviderStatus.HEALTHY
|
||||
metrics: ProviderMetrics = field(default_factory=ProviderMetrics)
|
||||
circuit_state: CircuitState = CircuitState.CLOSED
|
||||
circuit_opened_at: Optional[float] = None
|
||||
half_open_calls: int = 0
|
||||
|
||||
|
||||
def get_default_model(self) -> Optional[str]:
|
||||
"""Get the default model for this provider."""
|
||||
for model in self.models:
|
||||
@@ -117,7 +123,7 @@ class Provider:
|
||||
if self.models:
|
||||
return self.models[0]["name"]
|
||||
return None
|
||||
|
||||
|
||||
def get_model_with_capability(self, capability: str) -> Optional[str]:
|
||||
"""Get a model that supports the given capability."""
|
||||
for model in self.models:
|
||||
@@ -126,7 +132,7 @@ class Provider:
|
||||
return model["name"]
|
||||
# Fall back to default
|
||||
return self.get_default_model()
|
||||
|
||||
|
||||
def model_has_capability(self, model_name: str, capability: str) -> bool:
|
||||
"""Check if a specific model has a capability."""
|
||||
for model in self.models:
|
||||
@@ -139,6 +145,7 @@ class Provider:
|
||||
@dataclass
|
||||
class RouterConfig:
|
||||
"""Cascade router configuration."""
|
||||
|
||||
timeout_seconds: int = 30
|
||||
max_retries_per_provider: int = 2
|
||||
retry_delay_seconds: int = 1
|
||||
@@ -154,22 +161,22 @@ class RouterConfig:
|
||||
|
||||
class CascadeRouter:
|
||||
"""Routes LLM requests with automatic failover.
|
||||
|
||||
|
||||
Now with multi-modal support:
|
||||
- Automatically detects content type (text, vision, audio)
|
||||
- Selects appropriate models based on capabilities
|
||||
- Falls back through capability-specific model chains
|
||||
- Supports image URLs and base64 encoding
|
||||
|
||||
|
||||
Usage:
|
||||
router = CascadeRouter()
|
||||
|
||||
|
||||
# Text request
|
||||
response = await router.complete(
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
model="llama3.2"
|
||||
)
|
||||
|
||||
|
||||
# Vision request (automatically detects and selects vision model)
|
||||
response = await router.complete(
|
||||
messages=[{
|
||||
@@ -179,68 +186,75 @@ class CascadeRouter:
|
||||
}],
|
||||
model="llava:7b"
|
||||
)
|
||||
|
||||
|
||||
# Check metrics
|
||||
metrics = router.get_metrics()
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, config_path: Optional[Path] = None) -> None:
|
||||
self.config_path = config_path or Path("config/providers.yaml")
|
||||
self.providers: list[Provider] = []
|
||||
self.config: RouterConfig = RouterConfig()
|
||||
self._load_config()
|
||||
|
||||
|
||||
# Initialize multi-modal manager if available
|
||||
self._mm_manager: Optional[Any] = None
|
||||
try:
|
||||
from infrastructure.models.multimodal import get_multimodal_manager
|
||||
|
||||
self._mm_manager = get_multimodal_manager()
|
||||
except Exception as exc:
|
||||
logger.debug("Multi-modal manager not available: %s", exc)
|
||||
|
||||
|
||||
logger.info("CascadeRouter initialized with %d providers", len(self.providers))
|
||||
|
||||
|
||||
def _load_config(self) -> None:
|
||||
"""Load configuration from YAML."""
|
||||
if not self.config_path.exists():
|
||||
logger.warning("Config not found: %s, using defaults", self.config_path)
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
if yaml is None:
|
||||
raise RuntimeError("PyYAML not installed")
|
||||
|
||||
|
||||
content = self.config_path.read_text()
|
||||
# Expand environment variables
|
||||
content = self._expand_env_vars(content)
|
||||
data = yaml.safe_load(content)
|
||||
|
||||
|
||||
# Load cascade settings
|
||||
cascade = data.get("cascade", {})
|
||||
|
||||
|
||||
# Load fallback chains
|
||||
fallback_chains = data.get("fallback_chains", {})
|
||||
|
||||
|
||||
# Load multi-modal settings
|
||||
multimodal = data.get("multimodal", {})
|
||||
|
||||
|
||||
self.config = RouterConfig(
|
||||
timeout_seconds=cascade.get("timeout_seconds", 30),
|
||||
max_retries_per_provider=cascade.get("max_retries_per_provider", 2),
|
||||
retry_delay_seconds=cascade.get("retry_delay_seconds", 1),
|
||||
circuit_breaker_failure_threshold=cascade.get("circuit_breaker", {}).get("failure_threshold", 5),
|
||||
circuit_breaker_recovery_timeout=cascade.get("circuit_breaker", {}).get("recovery_timeout", 60),
|
||||
circuit_breaker_half_open_max_calls=cascade.get("circuit_breaker", {}).get("half_open_max_calls", 2),
|
||||
circuit_breaker_failure_threshold=cascade.get("circuit_breaker", {}).get(
|
||||
"failure_threshold", 5
|
||||
),
|
||||
circuit_breaker_recovery_timeout=cascade.get("circuit_breaker", {}).get(
|
||||
"recovery_timeout", 60
|
||||
),
|
||||
circuit_breaker_half_open_max_calls=cascade.get("circuit_breaker", {}).get(
|
||||
"half_open_max_calls", 2
|
||||
),
|
||||
auto_pull_models=multimodal.get("auto_pull", True),
|
||||
fallback_chains=fallback_chains,
|
||||
)
|
||||
|
||||
|
||||
# Load providers
|
||||
for p_data in data.get("providers", []):
|
||||
# Skip disabled providers
|
||||
if not p_data.get("enabled", False):
|
||||
continue
|
||||
|
||||
|
||||
provider = Provider(
|
||||
name=p_data["name"],
|
||||
type=p_data["type"],
|
||||
@@ -251,30 +265,34 @@ class CascadeRouter:
|
||||
base_url=p_data.get("base_url"),
|
||||
models=p_data.get("models", []),
|
||||
)
|
||||
|
||||
|
||||
# Check if provider is actually available
|
||||
if self._check_provider_available(provider):
|
||||
self.providers.append(provider)
|
||||
else:
|
||||
logger.warning("Provider %s not available, skipping", provider.name)
|
||||
|
||||
|
||||
# Sort by priority
|
||||
self.providers.sort(key=lambda p: p.priority)
|
||||
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Failed to load config: %s", exc)
|
||||
|
||||
|
||||
def _expand_env_vars(self, content: str) -> str:
|
||||
"""Expand ${VAR} syntax in YAML content."""
|
||||
"""Expand ${VAR} syntax in YAML content.
|
||||
|
||||
Uses os.environ directly (not settings) because this is a generic
|
||||
YAML config loader that must expand arbitrary variable references.
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
|
||||
def replace_var(match):
|
||||
|
||||
def replace_var(match: "re.Match[str]") -> str:
|
||||
var_name = match.group(1)
|
||||
return os.environ.get(var_name, match.group(0))
|
||||
|
||||
|
||||
return re.sub(r"\$\{(\w+)\}", replace_var, content)
|
||||
|
||||
|
||||
def _check_provider_available(self, provider: Provider) -> bool:
|
||||
"""Check if a provider is actually available."""
|
||||
if provider.type == "ollama":
|
||||
@@ -288,48 +306,49 @@ class CascadeRouter:
|
||||
return response.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
elif provider.type == "airllm":
|
||||
# Check if airllm is installed
|
||||
try:
|
||||
import airllm
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
elif provider.type in ("openai", "anthropic", "grok"):
|
||||
# Check if API key is set
|
||||
return provider.api_key is not None and provider.api_key != ""
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _detect_content_type(self, messages: list[dict]) -> ContentType:
|
||||
"""Detect the type of content in the messages.
|
||||
|
||||
|
||||
Checks for images, audio, etc. in the message content.
|
||||
"""
|
||||
has_image = False
|
||||
has_audio = False
|
||||
|
||||
|
||||
for msg in messages:
|
||||
content = msg.get("content", "")
|
||||
|
||||
|
||||
# Check for image URLs/paths
|
||||
if msg.get("images"):
|
||||
has_image = True
|
||||
|
||||
|
||||
# Check for image URLs in content
|
||||
if isinstance(content, str):
|
||||
image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp')
|
||||
image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
|
||||
if any(ext in content.lower() for ext in image_extensions):
|
||||
has_image = True
|
||||
if content.startswith("data:image/"):
|
||||
has_image = True
|
||||
|
||||
|
||||
# Check for audio
|
||||
if msg.get("audio"):
|
||||
has_audio = True
|
||||
|
||||
|
||||
# Check for multimodal content structure
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
@@ -338,7 +357,7 @@ class CascadeRouter:
|
||||
has_image = True
|
||||
elif item.get("type") == "audio":
|
||||
has_audio = True
|
||||
|
||||
|
||||
if has_image and has_audio:
|
||||
return ContentType.MULTIMODAL
|
||||
elif has_image:
|
||||
@@ -346,12 +365,9 @@ class CascadeRouter:
|
||||
elif has_audio:
|
||||
return ContentType.AUDIO
|
||||
return ContentType.TEXT
|
||||
|
||||
|
||||
def _get_fallback_model(
|
||||
self,
|
||||
provider: Provider,
|
||||
original_model: str,
|
||||
content_type: ContentType
|
||||
self, provider: Provider, original_model: str, content_type: ContentType
|
||||
) -> Optional[str]:
|
||||
"""Get a fallback model for the given content type."""
|
||||
# Map content type to capability
|
||||
@@ -360,24 +376,24 @@ class CascadeRouter:
|
||||
ContentType.AUDIO: "audio",
|
||||
ContentType.MULTIMODAL: "vision", # Vision models often do both
|
||||
}
|
||||
|
||||
|
||||
capability = capability_map.get(content_type)
|
||||
if not capability:
|
||||
return None
|
||||
|
||||
|
||||
# Check provider's models for capability
|
||||
fallback_model = provider.get_model_with_capability(capability)
|
||||
if fallback_model and fallback_model != original_model:
|
||||
return fallback_model
|
||||
|
||||
|
||||
# Use fallback chains from config
|
||||
fallback_chain = self.config.fallback_chains.get(capability, [])
|
||||
for model_name in fallback_chain:
|
||||
if provider.model_has_capability(model_name, capability):
|
||||
return model_name
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def complete(
|
||||
self,
|
||||
messages: list[dict],
|
||||
@@ -386,21 +402,21 @@ class CascadeRouter:
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> dict:
|
||||
"""Complete a chat conversation with automatic failover.
|
||||
|
||||
|
||||
Multi-modal support:
|
||||
- Automatically detects if messages contain images
|
||||
- Falls back to vision-capable models when needed
|
||||
- Supports image URLs, paths, and base64 encoding
|
||||
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with role and content
|
||||
model: Preferred model (tries this first, then provider defaults)
|
||||
temperature: Sampling temperature
|
||||
max_tokens: Maximum tokens to generate
|
||||
|
||||
|
||||
Returns:
|
||||
Dict with content, provider_used, and metrics
|
||||
|
||||
|
||||
Raises:
|
||||
RuntimeError: If all providers fail
|
||||
"""
|
||||
@@ -408,15 +424,15 @@ class CascadeRouter:
|
||||
content_type = self._detect_content_type(messages)
|
||||
if content_type != ContentType.TEXT:
|
||||
logger.debug("Detected %s content, selecting appropriate model", content_type.value)
|
||||
|
||||
|
||||
errors = []
|
||||
|
||||
|
||||
for provider in self.providers:
|
||||
# Skip disabled providers
|
||||
if not provider.enabled:
|
||||
logger.debug("Skipping %s (disabled)", provider.name)
|
||||
continue
|
||||
|
||||
|
||||
# Skip unhealthy providers (circuit breaker)
|
||||
if provider.status == ProviderStatus.UNHEALTHY:
|
||||
# Check if circuit breaker can close
|
||||
@@ -427,16 +443,16 @@ class CascadeRouter:
|
||||
else:
|
||||
logger.debug("Skipping %s (circuit open)", provider.name)
|
||||
continue
|
||||
|
||||
|
||||
# Determine which model to use
|
||||
selected_model = model or provider.get_default_model()
|
||||
is_fallback_model = False
|
||||
|
||||
|
||||
# For non-text content, check if model supports it
|
||||
if content_type != ContentType.TEXT and selected_model:
|
||||
if provider.type == "ollama" and self._mm_manager:
|
||||
from infrastructure.models.multimodal import ModelCapability
|
||||
|
||||
|
||||
# Check if selected model supports the required capability
|
||||
if content_type == ContentType.VISION:
|
||||
supports = self._mm_manager.model_supports(
|
||||
@@ -450,16 +466,17 @@ class CascadeRouter:
|
||||
if fallback:
|
||||
logger.info(
|
||||
"Model %s doesn't support vision, falling back to %s",
|
||||
selected_model, fallback
|
||||
selected_model,
|
||||
fallback,
|
||||
)
|
||||
selected_model = fallback
|
||||
is_fallback_model = True
|
||||
else:
|
||||
logger.warning(
|
||||
"No vision-capable model found on %s, trying anyway",
|
||||
provider.name
|
||||
provider.name,
|
||||
)
|
||||
|
||||
|
||||
# Try this provider
|
||||
for attempt in range(self.config.max_retries_per_provider):
|
||||
try:
|
||||
@@ -471,34 +488,35 @@ class CascadeRouter:
|
||||
max_tokens=max_tokens,
|
||||
content_type=content_type,
|
||||
)
|
||||
|
||||
|
||||
# Success! Update metrics and return
|
||||
self._record_success(provider, result.get("latency_ms", 0))
|
||||
return {
|
||||
"content": result["content"],
|
||||
"provider": provider.name,
|
||||
"model": result.get("model", selected_model or provider.get_default_model()),
|
||||
"model": result.get(
|
||||
"model", selected_model or provider.get_default_model()
|
||||
),
|
||||
"latency_ms": result.get("latency_ms", 0),
|
||||
"is_fallback_model": is_fallback_model,
|
||||
}
|
||||
|
||||
|
||||
except Exception as exc:
|
||||
error_msg = str(exc)
|
||||
logger.warning(
|
||||
"Provider %s attempt %d failed: %s",
|
||||
provider.name, attempt + 1, error_msg
|
||||
"Provider %s attempt %d failed: %s", provider.name, attempt + 1, error_msg
|
||||
)
|
||||
errors.append(f"{provider.name}: {error_msg}")
|
||||
|
||||
|
||||
if attempt < self.config.max_retries_per_provider - 1:
|
||||
await asyncio.sleep(self.config.retry_delay_seconds)
|
||||
|
||||
|
||||
# All retries failed for this provider
|
||||
self._record_failure(provider)
|
||||
|
||||
|
||||
# All providers failed
|
||||
raise RuntimeError(f"All providers failed: {'; '.join(errors)}")
|
||||
|
||||
|
||||
async def _try_provider(
|
||||
self,
|
||||
provider: Provider,
|
||||
@@ -510,7 +528,7 @@ class CascadeRouter:
|
||||
) -> dict:
|
||||
"""Try a single provider request."""
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
if provider.type == "ollama":
|
||||
result = await self._call_ollama(
|
||||
provider=provider,
|
||||
@@ -545,12 +563,12 @@ class CascadeRouter:
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown provider type: {provider.type}")
|
||||
|
||||
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
result["latency_ms"] = latency_ms
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _call_ollama(
|
||||
self,
|
||||
provider: Provider,
|
||||
@@ -561,12 +579,12 @@ class CascadeRouter:
|
||||
) -> dict:
|
||||
"""Call Ollama API with multi-modal support."""
|
||||
import aiohttp
|
||||
|
||||
|
||||
url = f"{provider.url}/api/chat"
|
||||
|
||||
|
||||
# Transform messages for Ollama format (including images)
|
||||
transformed_messages = self._transform_messages_for_ollama(messages)
|
||||
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": transformed_messages,
|
||||
@@ -575,31 +593,31 @@ class CascadeRouter:
|
||||
"temperature": temperature,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=self.config.timeout_seconds)
|
||||
|
||||
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.post(url, json=payload) as response:
|
||||
if response.status != 200:
|
||||
text = await response.text()
|
||||
raise RuntimeError(f"Ollama error {response.status}: {text}")
|
||||
|
||||
|
||||
data = await response.json()
|
||||
return {
|
||||
"content": data["message"]["content"],
|
||||
"model": model,
|
||||
}
|
||||
|
||||
|
||||
def _transform_messages_for_ollama(self, messages: list[dict]) -> list[dict]:
|
||||
"""Transform messages to Ollama format, handling images."""
|
||||
transformed = []
|
||||
|
||||
|
||||
for msg in messages:
|
||||
new_msg = {
|
||||
"role": msg.get("role", "user"),
|
||||
"content": msg.get("content", ""),
|
||||
}
|
||||
|
||||
|
||||
# Handle images
|
||||
images = msg.get("images", [])
|
||||
if images:
|
||||
@@ -620,11 +638,11 @@ class CascadeRouter:
|
||||
new_msg["images"].append(img_data)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to read image %s: %s", img, exc)
|
||||
|
||||
|
||||
transformed.append(new_msg)
|
||||
|
||||
|
||||
return transformed
|
||||
|
||||
|
||||
async def _call_openai(
|
||||
self,
|
||||
provider: Provider,
|
||||
@@ -635,13 +653,13 @@ class CascadeRouter:
|
||||
) -> dict:
|
||||
"""Call OpenAI API."""
|
||||
import openai
|
||||
|
||||
|
||||
client = openai.AsyncOpenAI(
|
||||
api_key=provider.api_key,
|
||||
base_url=provider.base_url,
|
||||
timeout=self.config.timeout_seconds,
|
||||
)
|
||||
|
||||
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
@@ -649,14 +667,14 @@ class CascadeRouter:
|
||||
}
|
||||
if max_tokens:
|
||||
kwargs["max_tokens"] = max_tokens
|
||||
|
||||
|
||||
response = await client.chat.completions.create(**kwargs)
|
||||
|
||||
|
||||
return {
|
||||
"content": response.choices[0].message.content,
|
||||
"model": response.model,
|
||||
}
|
||||
|
||||
|
||||
async def _call_anthropic(
|
||||
self,
|
||||
provider: Provider,
|
||||
@@ -667,12 +685,12 @@ class CascadeRouter:
|
||||
) -> dict:
|
||||
"""Call Anthropic API."""
|
||||
import anthropic
|
||||
|
||||
|
||||
client = anthropic.AsyncAnthropic(
|
||||
api_key=provider.api_key,
|
||||
timeout=self.config.timeout_seconds,
|
||||
)
|
||||
|
||||
|
||||
# Convert messages to Anthropic format
|
||||
system_msg = None
|
||||
conversation = []
|
||||
@@ -680,11 +698,13 @@ class CascadeRouter:
|
||||
if msg["role"] == "system":
|
||||
system_msg = msg["content"]
|
||||
else:
|
||||
conversation.append({
|
||||
"role": msg["role"],
|
||||
"content": msg["content"],
|
||||
})
|
||||
|
||||
conversation.append(
|
||||
{
|
||||
"role": msg["role"],
|
||||
"content": msg["content"],
|
||||
}
|
||||
)
|
||||
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"messages": conversation,
|
||||
@@ -693,9 +713,9 @@ class CascadeRouter:
|
||||
}
|
||||
if system_msg:
|
||||
kwargs["system"] = system_msg
|
||||
|
||||
|
||||
response = await client.messages.create(**kwargs)
|
||||
|
||||
|
||||
return {
|
||||
"content": response.content[0].text,
|
||||
"model": response.model,
|
||||
@@ -733,7 +753,7 @@ class CascadeRouter:
|
||||
"content": response.choices[0].message.content,
|
||||
"model": response.model,
|
||||
}
|
||||
|
||||
|
||||
def _record_success(self, provider: Provider, latency_ms: float) -> None:
|
||||
"""Record a successful request."""
|
||||
provider.metrics.total_requests += 1
|
||||
@@ -741,50 +761,50 @@ class CascadeRouter:
|
||||
provider.metrics.total_latency_ms += latency_ms
|
||||
provider.metrics.last_request_time = datetime.now(timezone.utc).isoformat()
|
||||
provider.metrics.consecutive_failures = 0
|
||||
|
||||
|
||||
# Close circuit breaker if half-open
|
||||
if provider.circuit_state == CircuitState.HALF_OPEN:
|
||||
provider.half_open_calls += 1
|
||||
if provider.half_open_calls >= self.config.circuit_breaker_half_open_max_calls:
|
||||
self._close_circuit(provider)
|
||||
|
||||
|
||||
# Update status based on error rate
|
||||
if provider.metrics.error_rate < 0.1:
|
||||
provider.status = ProviderStatus.HEALTHY
|
||||
elif provider.metrics.error_rate < 0.3:
|
||||
provider.status = ProviderStatus.DEGRADED
|
||||
|
||||
|
||||
def _record_failure(self, provider: Provider) -> None:
|
||||
"""Record a failed request."""
|
||||
provider.metrics.total_requests += 1
|
||||
provider.metrics.failed_requests += 1
|
||||
provider.metrics.last_error_time = datetime.now(timezone.utc).isoformat()
|
||||
provider.metrics.consecutive_failures += 1
|
||||
|
||||
|
||||
# Check if we should open circuit breaker
|
||||
if provider.metrics.consecutive_failures >= self.config.circuit_breaker_failure_threshold:
|
||||
self._open_circuit(provider)
|
||||
|
||||
|
||||
# Update status
|
||||
if provider.metrics.error_rate > 0.3:
|
||||
provider.status = ProviderStatus.DEGRADED
|
||||
if provider.metrics.error_rate > 0.5:
|
||||
provider.status = ProviderStatus.UNHEALTHY
|
||||
|
||||
|
||||
def _open_circuit(self, provider: Provider) -> None:
|
||||
"""Open the circuit breaker for a provider."""
|
||||
provider.circuit_state = CircuitState.OPEN
|
||||
provider.circuit_opened_at = time.time()
|
||||
provider.status = ProviderStatus.UNHEALTHY
|
||||
logger.warning("Circuit breaker OPEN for %s", provider.name)
|
||||
|
||||
|
||||
def _can_close_circuit(self, provider: Provider) -> bool:
|
||||
"""Check if circuit breaker can transition to half-open."""
|
||||
if provider.circuit_opened_at is None:
|
||||
return False
|
||||
elapsed = time.time() - provider.circuit_opened_at
|
||||
return elapsed >= self.config.circuit_breaker_recovery_timeout
|
||||
|
||||
|
||||
def _close_circuit(self, provider: Provider) -> None:
|
||||
"""Close the circuit breaker (provider healthy again)."""
|
||||
provider.circuit_state = CircuitState.CLOSED
|
||||
@@ -793,7 +813,7 @@ class CascadeRouter:
|
||||
provider.metrics.consecutive_failures = 0
|
||||
provider.status = ProviderStatus.HEALTHY
|
||||
logger.info("Circuit breaker CLOSED for %s", provider.name)
|
||||
|
||||
|
||||
def get_metrics(self) -> dict:
|
||||
"""Get metrics for all providers."""
|
||||
return {
|
||||
@@ -814,16 +834,20 @@ class CascadeRouter:
|
||||
for p in self.providers
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def get_status(self) -> dict:
|
||||
"""Get current router status."""
|
||||
healthy = sum(1 for p in self.providers if p.status == ProviderStatus.HEALTHY)
|
||||
|
||||
|
||||
return {
|
||||
"total_providers": len(self.providers),
|
||||
"healthy_providers": healthy,
|
||||
"degraded_providers": sum(1 for p in self.providers if p.status == ProviderStatus.DEGRADED),
|
||||
"unhealthy_providers": sum(1 for p in self.providers if p.status == ProviderStatus.UNHEALTHY),
|
||||
"degraded_providers": sum(
|
||||
1 for p in self.providers if p.status == ProviderStatus.DEGRADED
|
||||
),
|
||||
"unhealthy_providers": sum(
|
||||
1 for p in self.providers if p.status == ProviderStatus.UNHEALTHY
|
||||
),
|
||||
"providers": [
|
||||
{
|
||||
"name": p.name,
|
||||
@@ -835,7 +859,7 @@ class CascadeRouter:
|
||||
for p in self.providers
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
async def generate_with_image(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -844,21 +868,23 @@ class CascadeRouter:
|
||||
temperature: float = 0.7,
|
||||
) -> dict:
|
||||
"""Convenience method for vision requests.
|
||||
|
||||
|
||||
Args:
|
||||
prompt: Text prompt about the image
|
||||
image_path: Path to image file
|
||||
model: Vision-capable model (auto-selected if not provided)
|
||||
temperature: Sampling temperature
|
||||
|
||||
|
||||
Returns:
|
||||
Response dict with content and metadata
|
||||
"""
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
"images": [image_path],
|
||||
}]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
"images": [image_path],
|
||||
}
|
||||
]
|
||||
return await self.complete(
|
||||
messages=messages,
|
||||
model=model,
|
||||
|
||||
@@ -22,6 +22,7 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass
|
||||
class WSEvent:
|
||||
"""A WebSocket event to broadcast to connected clients."""
|
||||
|
||||
event: str
|
||||
data: dict
|
||||
timestamp: str
|
||||
@@ -93,28 +94,42 @@ class WebSocketManager:
|
||||
await self.broadcast("agent_left", {"agent_id": agent_id, "name": name})
|
||||
|
||||
async def broadcast_task_posted(self, task_id: str, description: str) -> None:
|
||||
await self.broadcast("task_posted", {
|
||||
"task_id": task_id, "description": description,
|
||||
})
|
||||
await self.broadcast(
|
||||
"task_posted",
|
||||
{
|
||||
"task_id": task_id,
|
||||
"description": description,
|
||||
},
|
||||
)
|
||||
|
||||
async def broadcast_bid_submitted(
|
||||
self, task_id: str, agent_id: str, bid_sats: int
|
||||
) -> None:
|
||||
await self.broadcast("bid_submitted", {
|
||||
"task_id": task_id, "agent_id": agent_id, "bid_sats": bid_sats,
|
||||
})
|
||||
async def broadcast_bid_submitted(self, task_id: str, agent_id: str, bid_sats: int) -> None:
|
||||
await self.broadcast(
|
||||
"bid_submitted",
|
||||
{
|
||||
"task_id": task_id,
|
||||
"agent_id": agent_id,
|
||||
"bid_sats": bid_sats,
|
||||
},
|
||||
)
|
||||
|
||||
async def broadcast_task_assigned(self, task_id: str, agent_id: str) -> None:
|
||||
await self.broadcast("task_assigned", {
|
||||
"task_id": task_id, "agent_id": agent_id,
|
||||
})
|
||||
await self.broadcast(
|
||||
"task_assigned",
|
||||
{
|
||||
"task_id": task_id,
|
||||
"agent_id": agent_id,
|
||||
},
|
||||
)
|
||||
|
||||
async def broadcast_task_completed(
|
||||
self, task_id: str, agent_id: str, result: str
|
||||
) -> None:
|
||||
await self.broadcast("task_completed", {
|
||||
"task_id": task_id, "agent_id": agent_id, "result": result[:200],
|
||||
})
|
||||
async def broadcast_task_completed(self, task_id: str, agent_id: str, result: str) -> None:
|
||||
await self.broadcast(
|
||||
"task_completed",
|
||||
{
|
||||
"task_id": task_id,
|
||||
"agent_id": agent_id,
|
||||
"result": result[:200],
|
||||
},
|
||||
)
|
||||
|
||||
@property
|
||||
def connection_count(self) -> int:
|
||||
@@ -122,28 +137,28 @@ class WebSocketManager:
|
||||
|
||||
async def broadcast_json(self, data: dict) -> int:
|
||||
"""Broadcast raw JSON data to all connected clients.
|
||||
|
||||
|
||||
Args:
|
||||
data: Dictionary to send as JSON
|
||||
|
||||
|
||||
Returns:
|
||||
Number of clients notified
|
||||
"""
|
||||
message = json.dumps(data)
|
||||
disconnected = []
|
||||
count = 0
|
||||
|
||||
|
||||
for ws in self._connections:
|
||||
try:
|
||||
await ws.send_text(message)
|
||||
count += 1
|
||||
except Exception:
|
||||
disconnected.append(ws)
|
||||
|
||||
|
||||
# Clean up dead connections
|
||||
for ws in disconnected:
|
||||
self.disconnect(ws)
|
||||
|
||||
|
||||
return count
|
||||
|
||||
@property
|
||||
|
||||
@@ -21,6 +21,7 @@ from typing import Any, Optional
|
||||
|
||||
class PlatformState(Enum):
|
||||
"""Lifecycle state of a chat platform connection."""
|
||||
|
||||
DISCONNECTED = auto()
|
||||
CONNECTING = auto()
|
||||
CONNECTED = auto()
|
||||
@@ -30,13 +31,12 @@ class PlatformState(Enum):
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
"""Vendor-agnostic representation of a chat message."""
|
||||
|
||||
content: str
|
||||
author: str
|
||||
channel_id: str
|
||||
platform: str
|
||||
timestamp: str = field(
|
||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
message_id: Optional[str] = None
|
||||
thread_id: Optional[str] = None
|
||||
attachments: list[str] = field(default_factory=list)
|
||||
@@ -46,13 +46,12 @@ class ChatMessage:
|
||||
@dataclass
|
||||
class ChatThread:
|
||||
"""Vendor-agnostic representation of a conversation thread."""
|
||||
|
||||
thread_id: str
|
||||
title: str
|
||||
channel_id: str
|
||||
platform: str
|
||||
created_at: str = field(
|
||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
archived: bool = False
|
||||
message_count: int = 0
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
@@ -61,6 +60,7 @@ class ChatThread:
|
||||
@dataclass
|
||||
class InviteInfo:
|
||||
"""Parsed invite extracted from an image or text."""
|
||||
|
||||
url: str
|
||||
code: str
|
||||
platform: str
|
||||
@@ -71,6 +71,7 @@ class InviteInfo:
|
||||
@dataclass
|
||||
class PlatformStatus:
|
||||
"""Current status of a chat platform connection."""
|
||||
|
||||
platform: str
|
||||
state: PlatformState
|
||||
token_set: bool
|
||||
|
||||
@@ -115,7 +115,9 @@ class InviteParser:
|
||||
"""Strategy 2: Use Ollama vision model for local OCR."""
|
||||
try:
|
||||
import base64
|
||||
|
||||
import httpx
|
||||
|
||||
from config import settings
|
||||
except ImportError:
|
||||
logger.debug("httpx not available for Ollama vision.")
|
||||
|
||||
15
src/integrations/chat_bridge/vendors/discord.py
vendored
15
src/integrations/chat_bridge/vendors/discord.py
vendored
@@ -90,10 +90,7 @@ class DiscordVendor(ChatPlatform):
|
||||
try:
|
||||
import discord
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"discord.py is not installed. "
|
||||
'Run: pip install ".[discord]"'
|
||||
)
|
||||
logger.error("discord.py is not installed. " 'Run: pip install ".[discord]"')
|
||||
return False
|
||||
|
||||
try:
|
||||
@@ -267,6 +264,7 @@ class DiscordVendor(ChatPlatform):
|
||||
|
||||
try:
|
||||
from config import settings
|
||||
|
||||
return settings.discord_token or None
|
||||
except Exception:
|
||||
return None
|
||||
@@ -363,9 +361,7 @@ class DiscordVendor(ChatPlatform):
|
||||
# Show typing indicator while the agent processes
|
||||
async with target.typing():
|
||||
run = await asyncio.wait_for(
|
||||
asyncio.to_thread(
|
||||
agent.run, content, stream=False, session_id=session_id
|
||||
),
|
||||
asyncio.to_thread(agent.run, content, stream=False, session_id=session_id),
|
||||
timeout=300,
|
||||
)
|
||||
response = run.content if hasattr(run, "content") else str(run)
|
||||
@@ -374,7 +370,9 @@ class DiscordVendor(ChatPlatform):
|
||||
response = "Sorry, that took too long. Please try a simpler request."
|
||||
except Exception as exc:
|
||||
logger.error("Discord: agent.run() failed: %s", exc)
|
||||
response = "I'm having trouble reaching my language model right now. Please try again shortly."
|
||||
response = (
|
||||
"I'm having trouble reaching my language model right now. Please try again shortly."
|
||||
)
|
||||
|
||||
# Strip hallucinated tool-call JSON and chain-of-thought narration
|
||||
from timmy.session import _clean_response
|
||||
@@ -408,6 +406,7 @@ class DiscordVendor(ChatPlatform):
|
||||
|
||||
# Create a thread from this message
|
||||
from config import settings
|
||||
|
||||
thread_name = f"{settings.agent_name} | {message.author.display_name}"
|
||||
thread = await message.create_thread(
|
||||
name=thread_name[:100],
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ── Inbound: Paperclip → Timmy ──────────────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
@@ -20,7 +20,8 @@ import logging
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Optional, Protocol, runtime_checkable
|
||||
|
||||
from config import settings
|
||||
from integrations.paperclip.bridge import PaperclipBridge, bridge as default_bridge
|
||||
from integrations.paperclip.bridge import PaperclipBridge
|
||||
from integrations.paperclip.bridge import bridge as default_bridge
|
||||
from integrations.paperclip.models import PaperclipIssue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -30,9 +31,8 @@ logger = logging.getLogger(__name__)
|
||||
class Orchestrator(Protocol):
|
||||
"""Anything with an ``execute_task`` matching Timmy's orchestrator."""
|
||||
|
||||
async def execute_task(
|
||||
self, task_id: str, description: str, context: dict
|
||||
) -> Any: ...
|
||||
async def execute_task(self, task_id: str, description: str, context: dict) -> Any:
|
||||
...
|
||||
|
||||
|
||||
def _wrap_orchestrator(orch: Orchestrator) -> Callable:
|
||||
@@ -125,7 +125,9 @@ class TaskRunner:
|
||||
# Mark the issue as done
|
||||
return await self.bridge.close_issue(issue.id, comment=None)
|
||||
|
||||
async def create_follow_up(self, original: PaperclipIssue, result: str) -> Optional[PaperclipIssue]:
|
||||
async def create_follow_up(
|
||||
self, original: PaperclipIssue, result: str
|
||||
) -> Optional[PaperclipIssue]:
|
||||
"""Create a recursive follow-up task for Timmy.
|
||||
|
||||
Timmy muses about task automation and writes a follow-up issue
|
||||
|
||||
@@ -22,6 +22,7 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass
|
||||
class ShortcutAction:
|
||||
"""Describes a Siri Shortcut action for the setup guide."""
|
||||
|
||||
name: str
|
||||
endpoint: str
|
||||
method: str
|
||||
|
||||
@@ -54,6 +54,7 @@ class TelegramBot:
|
||||
return from_file
|
||||
try:
|
||||
from config import settings
|
||||
|
||||
return settings.telegram_token or None
|
||||
except Exception:
|
||||
return None
|
||||
@@ -94,10 +95,7 @@ class TelegramBot:
|
||||
filters,
|
||||
)
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"python-telegram-bot is not installed. "
|
||||
'Run: pip install ".[telegram]"'
|
||||
)
|
||||
logger.error("python-telegram-bot is not installed. " 'Run: pip install ".[telegram]"')
|
||||
return False
|
||||
|
||||
try:
|
||||
@@ -149,6 +147,7 @@ class TelegramBot:
|
||||
user_text = update.message.text
|
||||
try:
|
||||
from timmy.agent import create_timmy
|
||||
|
||||
agent = create_timmy()
|
||||
run = await asyncio.to_thread(agent.run, user_text, stream=False)
|
||||
response = run.content if hasattr(run, "content") else str(run)
|
||||
|
||||
@@ -15,8 +15,8 @@ Intents:
|
||||
- unknown: Unrecognized intent
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
@@ -35,47 +35,68 @@ class Intent:
|
||||
|
||||
_PATTERNS: list[tuple[str, re.Pattern, float]] = [
|
||||
# Status queries
|
||||
("status", re.compile(
|
||||
r"\b(status|health|how are you|are you (running|online|alive)|check)\b",
|
||||
re.IGNORECASE,
|
||||
), 0.9),
|
||||
|
||||
(
|
||||
"status",
|
||||
re.compile(
|
||||
r"\b(status|health|how are you|are you (running|online|alive)|check)\b",
|
||||
re.IGNORECASE,
|
||||
),
|
||||
0.9,
|
||||
),
|
||||
# Swarm commands
|
||||
("swarm", re.compile(
|
||||
r"\b(swarm|spawn|agents?|sub-?agents?|workers?)\b",
|
||||
re.IGNORECASE,
|
||||
), 0.85),
|
||||
|
||||
(
|
||||
"swarm",
|
||||
re.compile(
|
||||
r"\b(swarm|spawn|agents?|sub-?agents?|workers?)\b",
|
||||
re.IGNORECASE,
|
||||
),
|
||||
0.85,
|
||||
),
|
||||
# Task commands
|
||||
("task", re.compile(
|
||||
r"\b(task|assign|create task|new task|post task|bid)\b",
|
||||
re.IGNORECASE,
|
||||
), 0.85),
|
||||
|
||||
(
|
||||
"task",
|
||||
re.compile(
|
||||
r"\b(task|assign|create task|new task|post task|bid)\b",
|
||||
re.IGNORECASE,
|
||||
),
|
||||
0.85,
|
||||
),
|
||||
# Help
|
||||
("help", re.compile(
|
||||
r"\b(help|commands?|what can you do|capabilities)\b",
|
||||
re.IGNORECASE,
|
||||
), 0.9),
|
||||
|
||||
(
|
||||
"help",
|
||||
re.compile(
|
||||
r"\b(help|commands?|what can you do|capabilities)\b",
|
||||
re.IGNORECASE,
|
||||
),
|
||||
0.9,
|
||||
),
|
||||
# Voice settings
|
||||
("voice", re.compile(
|
||||
r"\b(voice|speak|volume|rate|speed|louder|quieter|faster|slower|mute|unmute)\b",
|
||||
re.IGNORECASE,
|
||||
), 0.85),
|
||||
|
||||
(
|
||||
"voice",
|
||||
re.compile(
|
||||
r"\b(voice|speak|volume|rate|speed|louder|quieter|faster|slower|mute|unmute)\b",
|
||||
re.IGNORECASE,
|
||||
),
|
||||
0.85,
|
||||
),
|
||||
# Code modification / self-modify
|
||||
("code", re.compile(
|
||||
r"\b(modify|edit|change|update|fix|refactor|implement|patch)\s+(the\s+)?(code|file|function|class|module|source)\b"
|
||||
r"|\bself[- ]?modify\b"
|
||||
r"|\b(update|change|edit)\s+(your|the)\s+(code|source)\b",
|
||||
re.IGNORECASE,
|
||||
), 0.9),
|
||||
(
|
||||
"code",
|
||||
re.compile(
|
||||
r"\b(modify|edit|change|update|fix|refactor|implement|patch)\s+(the\s+)?(code|file|function|class|module|source)\b"
|
||||
r"|\bself[- ]?modify\b"
|
||||
r"|\b(update|change|edit)\s+(your|the)\s+(code|source)\b",
|
||||
re.IGNORECASE,
|
||||
),
|
||||
0.9,
|
||||
),
|
||||
]
|
||||
|
||||
# Keywords for entity extraction
|
||||
_ENTITY_PATTERNS = {
|
||||
"agent_name": re.compile(r"(?:spawn|start)\s+(?:agent\s+)?(\w+)|(?:agent)\s+(\w+)", re.IGNORECASE),
|
||||
"agent_name": re.compile(
|
||||
r"(?:spawn|start)\s+(?:agent\s+)?(\w+)|(?:agent)\s+(\w+)", re.IGNORECASE
|
||||
),
|
||||
"task_description": re.compile(r"(?:task|assign)[:;]?\s+(.+)", re.IGNORECASE),
|
||||
"number": re.compile(r"\b(\d+)\b"),
|
||||
"target_file": re.compile(r"(?:in|file|modify)\s+(?:the\s+)?([/\w._-]+\.py)", re.IGNORECASE),
|
||||
|
||||
@@ -17,8 +17,8 @@ from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from spark import memory as spark_memory
|
||||
from spark import eidos as spark_eidos
|
||||
from spark import memory as spark_memory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -29,10 +29,11 @@ _MIN_EVENTS = 3
|
||||
@dataclass
|
||||
class Advisory:
|
||||
"""A single ranked recommendation."""
|
||||
category: str # agent_performance, bid_optimization, etc.
|
||||
priority: float # 0.0–1.0 (higher = more urgent)
|
||||
title: str # Short headline
|
||||
detail: str # Longer explanation
|
||||
|
||||
category: str # agent_performance, bid_optimization, etc.
|
||||
priority: float # 0.0–1.0 (higher = more urgent)
|
||||
title: str # Short headline
|
||||
detail: str # Longer explanation
|
||||
suggested_action: str # What to do about it
|
||||
subject: Optional[str] = None # agent_id or None for system-level
|
||||
evidence_count: int = 0 # Number of supporting events
|
||||
@@ -47,15 +48,17 @@ def generate_advisories() -> list[Advisory]:
|
||||
|
||||
event_count = spark_memory.count_events()
|
||||
if event_count < _MIN_EVENTS:
|
||||
advisories.append(Advisory(
|
||||
category="system_health",
|
||||
priority=0.3,
|
||||
title="Insufficient data",
|
||||
detail=f"Only {event_count} events captured. "
|
||||
f"Spark needs at least {_MIN_EVENTS} events to generate insights.",
|
||||
suggested_action="Run more swarm tasks to build intelligence.",
|
||||
evidence_count=event_count,
|
||||
))
|
||||
advisories.append(
|
||||
Advisory(
|
||||
category="system_health",
|
||||
priority=0.3,
|
||||
title="Insufficient data",
|
||||
detail=f"Only {event_count} events captured. "
|
||||
f"Spark needs at least {_MIN_EVENTS} events to generate insights.",
|
||||
suggested_action="Run more swarm tasks to build intelligence.",
|
||||
evidence_count=event_count,
|
||||
)
|
||||
)
|
||||
return advisories
|
||||
|
||||
advisories.extend(_check_failure_patterns())
|
||||
@@ -82,18 +85,20 @@ def _check_failure_patterns() -> list[Advisory]:
|
||||
|
||||
for aid, count in agent_failures.items():
|
||||
if count >= 2:
|
||||
results.append(Advisory(
|
||||
category="failure_prevention",
|
||||
priority=min(1.0, 0.5 + count * 0.15),
|
||||
title=f"Agent {aid[:8]} has {count} failures",
|
||||
detail=f"Agent {aid[:8]}... has failed {count} recent tasks. "
|
||||
f"This pattern may indicate a capability mismatch or "
|
||||
f"configuration issue.",
|
||||
suggested_action=f"Review task types assigned to {aid[:8]}... "
|
||||
f"and consider adjusting routing preferences.",
|
||||
subject=aid,
|
||||
evidence_count=count,
|
||||
))
|
||||
results.append(
|
||||
Advisory(
|
||||
category="failure_prevention",
|
||||
priority=min(1.0, 0.5 + count * 0.15),
|
||||
title=f"Agent {aid[:8]} has {count} failures",
|
||||
detail=f"Agent {aid[:8]}... has failed {count} recent tasks. "
|
||||
f"This pattern may indicate a capability mismatch or "
|
||||
f"configuration issue.",
|
||||
suggested_action=f"Review task types assigned to {aid[:8]}... "
|
||||
f"and consider adjusting routing preferences.",
|
||||
subject=aid,
|
||||
evidence_count=count,
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@@ -128,27 +133,31 @@ def _check_agent_performance() -> list[Advisory]:
|
||||
|
||||
rate = wins / total
|
||||
if rate >= 0.8 and total >= 3:
|
||||
results.append(Advisory(
|
||||
category="agent_performance",
|
||||
priority=0.6,
|
||||
title=f"Agent {aid[:8]} excels ({rate:.0%} success)",
|
||||
detail=f"Agent {aid[:8]}... has completed {wins}/{total} tasks "
|
||||
f"successfully. Consider routing more tasks to this agent.",
|
||||
suggested_action="Increase task routing weight for this agent.",
|
||||
subject=aid,
|
||||
evidence_count=total,
|
||||
))
|
||||
results.append(
|
||||
Advisory(
|
||||
category="agent_performance",
|
||||
priority=0.6,
|
||||
title=f"Agent {aid[:8]} excels ({rate:.0%} success)",
|
||||
detail=f"Agent {aid[:8]}... has completed {wins}/{total} tasks "
|
||||
f"successfully. Consider routing more tasks to this agent.",
|
||||
suggested_action="Increase task routing weight for this agent.",
|
||||
subject=aid,
|
||||
evidence_count=total,
|
||||
)
|
||||
)
|
||||
elif rate <= 0.3 and total >= 3:
|
||||
results.append(Advisory(
|
||||
category="agent_performance",
|
||||
priority=0.75,
|
||||
title=f"Agent {aid[:8]} struggling ({rate:.0%} success)",
|
||||
detail=f"Agent {aid[:8]}... has only succeeded on {wins}/{total} tasks. "
|
||||
f"May need different task types or capability updates.",
|
||||
suggested_action="Review this agent's capabilities and assigned task types.",
|
||||
subject=aid,
|
||||
evidence_count=total,
|
||||
))
|
||||
results.append(
|
||||
Advisory(
|
||||
category="agent_performance",
|
||||
priority=0.75,
|
||||
title=f"Agent {aid[:8]} struggling ({rate:.0%} success)",
|
||||
detail=f"Agent {aid[:8]}... has only succeeded on {wins}/{total} tasks. "
|
||||
f"May need different task types or capability updates.",
|
||||
suggested_action="Review this agent's capabilities and assigned task types.",
|
||||
subject=aid,
|
||||
evidence_count=total,
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@@ -181,27 +190,31 @@ def _check_bid_patterns() -> list[Advisory]:
|
||||
spread = max_bid - min_bid
|
||||
|
||||
if spread > avg_bid * 1.5:
|
||||
results.append(Advisory(
|
||||
category="bid_optimization",
|
||||
priority=0.5,
|
||||
title=f"Wide bid spread ({min_bid}–{max_bid} sats)",
|
||||
detail=f"Bids range from {min_bid} to {max_bid} sats "
|
||||
f"(avg {avg_bid:.0f}). Large spread may indicate "
|
||||
f"inefficient auction dynamics.",
|
||||
suggested_action="Review agent bid strategies for consistency.",
|
||||
evidence_count=len(bid_amounts),
|
||||
))
|
||||
results.append(
|
||||
Advisory(
|
||||
category="bid_optimization",
|
||||
priority=0.5,
|
||||
title=f"Wide bid spread ({min_bid}–{max_bid} sats)",
|
||||
detail=f"Bids range from {min_bid} to {max_bid} sats "
|
||||
f"(avg {avg_bid:.0f}). Large spread may indicate "
|
||||
f"inefficient auction dynamics.",
|
||||
suggested_action="Review agent bid strategies for consistency.",
|
||||
evidence_count=len(bid_amounts),
|
||||
)
|
||||
)
|
||||
|
||||
if avg_bid > 70:
|
||||
results.append(Advisory(
|
||||
category="bid_optimization",
|
||||
priority=0.45,
|
||||
title=f"High average bid ({avg_bid:.0f} sats)",
|
||||
detail=f"The swarm average bid is {avg_bid:.0f} sats across "
|
||||
f"{len(bid_amounts)} bids. This may be above optimal.",
|
||||
suggested_action="Consider adjusting base bid rates for persona agents.",
|
||||
evidence_count=len(bid_amounts),
|
||||
))
|
||||
results.append(
|
||||
Advisory(
|
||||
category="bid_optimization",
|
||||
priority=0.45,
|
||||
title=f"High average bid ({avg_bid:.0f} sats)",
|
||||
detail=f"The swarm average bid is {avg_bid:.0f} sats across "
|
||||
f"{len(bid_amounts)} bids. This may be above optimal.",
|
||||
suggested_action="Consider adjusting base bid rates for persona agents.",
|
||||
evidence_count=len(bid_amounts),
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@@ -216,27 +229,31 @@ def _check_prediction_accuracy() -> list[Advisory]:
|
||||
|
||||
avg = stats["avg_accuracy"]
|
||||
if avg < 0.4:
|
||||
results.append(Advisory(
|
||||
category="system_health",
|
||||
priority=0.65,
|
||||
title=f"Low prediction accuracy ({avg:.0%})",
|
||||
detail=f"EIDOS predictions have averaged {avg:.0%} accuracy "
|
||||
f"over {stats['evaluated']} evaluations. The learning "
|
||||
f"model needs more data or the swarm behaviour is changing.",
|
||||
suggested_action="Continue running tasks; accuracy should improve "
|
||||
"as the model accumulates more training data.",
|
||||
evidence_count=stats["evaluated"],
|
||||
))
|
||||
results.append(
|
||||
Advisory(
|
||||
category="system_health",
|
||||
priority=0.65,
|
||||
title=f"Low prediction accuracy ({avg:.0%})",
|
||||
detail=f"EIDOS predictions have averaged {avg:.0%} accuracy "
|
||||
f"over {stats['evaluated']} evaluations. The learning "
|
||||
f"model needs more data or the swarm behaviour is changing.",
|
||||
suggested_action="Continue running tasks; accuracy should improve "
|
||||
"as the model accumulates more training data.",
|
||||
evidence_count=stats["evaluated"],
|
||||
)
|
||||
)
|
||||
elif avg >= 0.75:
|
||||
results.append(Advisory(
|
||||
category="system_health",
|
||||
priority=0.3,
|
||||
title=f"Strong prediction accuracy ({avg:.0%})",
|
||||
detail=f"EIDOS predictions are performing well at {avg:.0%} "
|
||||
f"average accuracy over {stats['evaluated']} evaluations.",
|
||||
suggested_action="No action needed. Spark intelligence is learning effectively.",
|
||||
evidence_count=stats["evaluated"],
|
||||
))
|
||||
results.append(
|
||||
Advisory(
|
||||
category="system_health",
|
||||
priority=0.3,
|
||||
title=f"Strong prediction accuracy ({avg:.0%})",
|
||||
detail=f"EIDOS predictions are performing well at {avg:.0%} "
|
||||
f"average accuracy over {stats['evaluated']} evaluations.",
|
||||
suggested_action="No action needed. Spark intelligence is learning effectively.",
|
||||
evidence_count=stats["evaluated"],
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@@ -247,14 +264,16 @@ def _check_system_activity() -> list[Advisory]:
|
||||
recent = spark_memory.get_events(limit=5)
|
||||
|
||||
if not recent:
|
||||
results.append(Advisory(
|
||||
category="system_health",
|
||||
priority=0.4,
|
||||
title="No swarm activity detected",
|
||||
detail="Spark has not captured any events. "
|
||||
"The swarm may be idle or Spark event capture is not active.",
|
||||
suggested_action="Post a task to the swarm to activate the pipeline.",
|
||||
))
|
||||
results.append(
|
||||
Advisory(
|
||||
category="system_health",
|
||||
priority=0.4,
|
||||
title="No swarm activity detected",
|
||||
detail="Spark has not captured any events. "
|
||||
"The swarm may be idle or Spark event capture is not active.",
|
||||
suggested_action="Post a task to the swarm to activate the pipeline.",
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
# Check event type distribution
|
||||
@@ -265,14 +284,16 @@ def _check_system_activity() -> list[Advisory]:
|
||||
|
||||
if "task_completed" not in type_counts and "task_failed" not in type_counts:
|
||||
if type_counts.get("task_posted", 0) > 3:
|
||||
results.append(Advisory(
|
||||
category="system_health",
|
||||
priority=0.6,
|
||||
title="Tasks posted but none completing",
|
||||
detail=f"{type_counts.get('task_posted', 0)} tasks posted "
|
||||
f"but no completions or failures recorded.",
|
||||
suggested_action="Check agent availability and auction configuration.",
|
||||
evidence_count=type_counts.get("task_posted", 0),
|
||||
))
|
||||
results.append(
|
||||
Advisory(
|
||||
category="system_health",
|
||||
priority=0.6,
|
||||
title="Tasks posted but none completing",
|
||||
detail=f"{type_counts.get('task_posted', 0)} tasks posted "
|
||||
f"but no completions or failures recorded.",
|
||||
suggested_action="Check agent availability and auction configuration.",
|
||||
evidence_count=type_counts.get("task_posted", 0),
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@@ -29,12 +29,13 @@ DB_PATH = Path("data/spark.db")
|
||||
@dataclass
|
||||
class Prediction:
|
||||
"""A prediction made by the EIDOS loop."""
|
||||
|
||||
id: str
|
||||
task_id: str
|
||||
prediction_type: str # outcome, best_agent, bid_range
|
||||
predicted_value: str # JSON-encoded prediction
|
||||
actual_value: Optional[str] # JSON-encoded actual (filled on evaluation)
|
||||
accuracy: Optional[float] # 0.0–1.0 (filled on evaluation)
|
||||
prediction_type: str # outcome, best_agent, bid_range
|
||||
predicted_value: str # JSON-encoded prediction
|
||||
actual_value: Optional[str] # JSON-encoded actual (filled on evaluation)
|
||||
accuracy: Optional[float] # 0.0–1.0 (filled on evaluation)
|
||||
created_at: str
|
||||
evaluated_at: Optional[str]
|
||||
|
||||
@@ -57,18 +58,15 @@ def _get_conn() -> sqlite3.Connection:
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_pred_task ON spark_predictions(task_id)"
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_pred_type ON spark_predictions(prediction_type)"
|
||||
)
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_pred_task ON spark_predictions(task_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_pred_type ON spark_predictions(prediction_type)")
|
||||
conn.commit()
|
||||
return conn
|
||||
|
||||
|
||||
# ── Prediction phase ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def predict_task_outcome(
|
||||
task_id: str,
|
||||
task_description: str,
|
||||
@@ -104,12 +102,8 @@ def predict_task_outcome(
|
||||
|
||||
if best_agent:
|
||||
prediction["likely_winner"] = best_agent
|
||||
prediction["success_probability"] = round(
|
||||
min(1.0, 0.5 + best_rate * 0.4), 2
|
||||
)
|
||||
prediction["reasoning"] = (
|
||||
f"agent {best_agent[:8]} has {best_rate:.0%} success rate"
|
||||
)
|
||||
prediction["success_probability"] = round(min(1.0, 0.5 + best_rate * 0.4), 2)
|
||||
prediction["reasoning"] = f"agent {best_agent[:8]} has {best_rate:.0%} success rate"
|
||||
|
||||
# Adjust bid range from history
|
||||
all_bids = []
|
||||
@@ -144,6 +138,7 @@ def predict_task_outcome(
|
||||
|
||||
# ── Evaluation phase ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def evaluate_prediction(
|
||||
task_id: str,
|
||||
actual_winner: Optional[str],
|
||||
@@ -242,6 +237,7 @@ def _compute_accuracy(predicted: dict, actual: dict) -> float:
|
||||
|
||||
# ── Query helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_predictions(
|
||||
task_id: Optional[str] = None,
|
||||
evaluated_only: bool = False,
|
||||
|
||||
@@ -76,7 +76,10 @@ class SparkEngine:
|
||||
return event_id
|
||||
|
||||
def on_bid_submitted(
|
||||
self, task_id: str, agent_id: str, bid_sats: int,
|
||||
self,
|
||||
task_id: str,
|
||||
agent_id: str,
|
||||
bid_sats: int,
|
||||
) -> Optional[str]:
|
||||
"""Capture a bid event."""
|
||||
if not self._enabled:
|
||||
@@ -90,12 +93,13 @@ class SparkEngine:
|
||||
data=json.dumps({"bid_sats": bid_sats}),
|
||||
)
|
||||
|
||||
logger.debug("Spark: captured bid %s→%s (%d sats)",
|
||||
agent_id[:8], task_id[:8], bid_sats)
|
||||
logger.debug("Spark: captured bid %s→%s (%d sats)", agent_id[:8], task_id[:8], bid_sats)
|
||||
return event_id
|
||||
|
||||
def on_task_assigned(
|
||||
self, task_id: str, agent_id: str,
|
||||
self,
|
||||
task_id: str,
|
||||
agent_id: str,
|
||||
) -> Optional[str]:
|
||||
"""Capture a task-assigned event."""
|
||||
if not self._enabled:
|
||||
@@ -108,8 +112,7 @@ class SparkEngine:
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
logger.debug("Spark: captured assignment %s→%s",
|
||||
task_id[:8], agent_id[:8])
|
||||
logger.debug("Spark: captured assignment %s→%s", task_id[:8], agent_id[:8])
|
||||
return event_id
|
||||
|
||||
def on_task_completed(
|
||||
@@ -128,10 +131,12 @@ class SparkEngine:
|
||||
description=f"Task completed by {agent_id[:8]}",
|
||||
agent_id=agent_id,
|
||||
task_id=task_id,
|
||||
data=json.dumps({
|
||||
"result_length": len(result),
|
||||
"winning_bid": winning_bid,
|
||||
}),
|
||||
data=json.dumps(
|
||||
{
|
||||
"result_length": len(result),
|
||||
"winning_bid": winning_bid,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
# Evaluate EIDOS prediction
|
||||
@@ -154,8 +159,7 @@ class SparkEngine:
|
||||
# Consolidate memory if enough events for this agent
|
||||
self._maybe_consolidate(agent_id)
|
||||
|
||||
logger.debug("Spark: captured completion %s by %s",
|
||||
task_id[:8], agent_id[:8])
|
||||
logger.debug("Spark: captured completion %s by %s", task_id[:8], agent_id[:8])
|
||||
return event_id
|
||||
|
||||
def on_task_failed(
|
||||
@@ -186,8 +190,7 @@ class SparkEngine:
|
||||
# Failures always worth consolidating
|
||||
self._maybe_consolidate(agent_id)
|
||||
|
||||
logger.debug("Spark: captured failure %s by %s",
|
||||
task_id[:8], agent_id[:8])
|
||||
logger.debug("Spark: captured failure %s by %s", task_id[:8], agent_id[:8])
|
||||
return event_id
|
||||
|
||||
def on_agent_joined(self, agent_id: str, name: str) -> Optional[str]:
|
||||
@@ -288,7 +291,7 @@ class SparkEngine:
|
||||
memory_type="pattern",
|
||||
subject=agent_id,
|
||||
content=f"Agent {agent_id[:8]} has a strong track record: "
|
||||
f"{len(completions)}/{total} tasks completed successfully.",
|
||||
f"{len(completions)}/{total} tasks completed successfully.",
|
||||
confidence=min(0.95, 0.6 + total * 0.05),
|
||||
source_events=total,
|
||||
)
|
||||
@@ -297,7 +300,7 @@ class SparkEngine:
|
||||
memory_type="anomaly",
|
||||
subject=agent_id,
|
||||
content=f"Agent {agent_id[:8]} is struggling: only "
|
||||
f"{len(completions)}/{total} tasks completed.",
|
||||
f"{len(completions)}/{total} tasks completed.",
|
||||
confidence=min(0.95, 0.6 + total * 0.05),
|
||||
source_events=total,
|
||||
)
|
||||
@@ -347,6 +350,7 @@ class SparkEngine:
|
||||
def _create_engine() -> SparkEngine:
|
||||
try:
|
||||
from config import settings
|
||||
|
||||
return SparkEngine(enabled=settings.spark_enabled)
|
||||
except Exception:
|
||||
return SparkEngine(enabled=True)
|
||||
|
||||
@@ -28,25 +28,27 @@ IMPORTANCE_HIGH = 0.8
|
||||
@dataclass
|
||||
class SparkEvent:
|
||||
"""A single captured swarm event."""
|
||||
|
||||
id: str
|
||||
event_type: str # task_posted, bid, assignment, completion, failure
|
||||
event_type: str # task_posted, bid, assignment, completion, failure
|
||||
agent_id: Optional[str]
|
||||
task_id: Optional[str]
|
||||
description: str
|
||||
data: str # JSON payload
|
||||
importance: float # 0.0–1.0
|
||||
data: str # JSON payload
|
||||
importance: float # 0.0–1.0
|
||||
created_at: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class SparkMemory:
|
||||
"""A consolidated memory distilled from event patterns."""
|
||||
|
||||
id: str
|
||||
memory_type: str # pattern, insight, anomaly
|
||||
subject: str # agent_id or "system"
|
||||
content: str # Human-readable insight
|
||||
confidence: float # 0.0–1.0
|
||||
source_events: int # How many events contributed
|
||||
memory_type: str # pattern, insight, anomaly
|
||||
subject: str # agent_id or "system"
|
||||
content: str # Human-readable insight
|
||||
confidence: float # 0.0–1.0
|
||||
source_events: int # How many events contributed
|
||||
created_at: str
|
||||
expires_at: Optional[str]
|
||||
|
||||
@@ -83,24 +85,17 @@ def _get_conn() -> sqlite3.Connection:
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_events_type ON spark_events(event_type)"
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_events_agent ON spark_events(agent_id)"
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_events_task ON spark_events(task_id)"
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_memories_subject ON spark_memories(subject)"
|
||||
)
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_events_type ON spark_events(event_type)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_events_agent ON spark_events(agent_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_events_task ON spark_events(task_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_subject ON spark_memories(subject)")
|
||||
conn.commit()
|
||||
return conn
|
||||
|
||||
|
||||
# ── Importance scoring ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def score_importance(event_type: str, data: dict) -> float:
|
||||
"""Compute importance score for an event (0.0–1.0).
|
||||
|
||||
@@ -132,6 +127,7 @@ def score_importance(event_type: str, data: dict) -> float:
|
||||
|
||||
# ── Event recording ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def record_event(
|
||||
event_type: str,
|
||||
description: str,
|
||||
@@ -142,6 +138,7 @@ def record_event(
|
||||
) -> str:
|
||||
"""Record a swarm event. Returns the event id."""
|
||||
import json
|
||||
|
||||
event_id = str(uuid.uuid4())
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
@@ -224,6 +221,7 @@ def count_events(event_type: Optional[str] = None) -> int:
|
||||
|
||||
# ── Memory consolidation ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def store_memory(
|
||||
memory_type: str,
|
||||
subject: str,
|
||||
|
||||
@@ -73,7 +73,8 @@ def _ensure_db() -> sqlite3.Connection:
|
||||
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(DB_PATH))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("""
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS events (
|
||||
id TEXT PRIMARY KEY,
|
||||
event_type TEXT NOT NULL,
|
||||
@@ -83,7 +84,8 @@ def _ensure_db() -> sqlite3.Connection:
|
||||
data TEXT DEFAULT '{}',
|
||||
timestamp TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
return conn
|
||||
|
||||
@@ -119,8 +121,15 @@ def log_event(
|
||||
db.execute(
|
||||
"INSERT INTO events (id, event_type, source, task_id, agent_id, data, timestamp) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||||
(entry.id, event_type.value, source, task_id, agent_id,
|
||||
json.dumps(data or {}), entry.timestamp),
|
||||
(
|
||||
entry.id,
|
||||
event_type.value,
|
||||
source,
|
||||
task_id,
|
||||
agent_id,
|
||||
json.dumps(data or {}),
|
||||
entry.timestamp,
|
||||
),
|
||||
)
|
||||
db.commit()
|
||||
finally:
|
||||
@@ -131,6 +140,7 @@ def log_event(
|
||||
# Broadcast to WebSocket clients (non-blocking)
|
||||
try:
|
||||
from infrastructure.events.broadcaster import event_broadcaster
|
||||
|
||||
event_broadcaster.broadcast_sync(entry)
|
||||
except Exception:
|
||||
pass
|
||||
@@ -157,13 +167,15 @@ def get_task_events(task_id: str, limit: int = 50) -> list[EventLogEntry]:
|
||||
et = EventType(r["event_type"])
|
||||
except ValueError:
|
||||
et = EventType.SYSTEM_INFO
|
||||
entries.append(EventLogEntry(
|
||||
id=r["id"],
|
||||
event_type=et,
|
||||
source=r["source"],
|
||||
timestamp=r["timestamp"],
|
||||
data=json.loads(r["data"]) if r["data"] else {},
|
||||
task_id=r["task_id"],
|
||||
agent_id=r["agent_id"],
|
||||
))
|
||||
entries.append(
|
||||
EventLogEntry(
|
||||
id=r["id"],
|
||||
event_type=et,
|
||||
source=r["source"],
|
||||
timestamp=r["timestamp"],
|
||||
data=json.loads(r["data"]) if r["data"] else {},
|
||||
task_id=r["task_id"],
|
||||
agent_id=r["agent_id"],
|
||||
)
|
||||
)
|
||||
return entries
|
||||
|
||||
@@ -29,7 +29,8 @@ def _ensure_db() -> sqlite3.Connection:
|
||||
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(DB_PATH))
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("""
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS tasks (
|
||||
id TEXT PRIMARY KEY,
|
||||
title TEXT NOT NULL,
|
||||
@@ -42,7 +43,8 @@ def _ensure_db() -> sqlite3.Connection:
|
||||
created_at TEXT DEFAULT (datetime('now')),
|
||||
completed_at TEXT
|
||||
)
|
||||
""")
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
return conn
|
||||
|
||||
@@ -103,9 +105,7 @@ def get_task_summary_for_briefing() -> dict:
|
||||
"""Return a summary of task counts by status for the morning briefing."""
|
||||
db = _ensure_db()
|
||||
try:
|
||||
rows = db.execute(
|
||||
"SELECT status, COUNT(*) as cnt FROM tasks GROUP BY status"
|
||||
).fetchall()
|
||||
rows = db.execute("SELECT status, COUNT(*) as cnt FROM tasks GROUP BY status").fetchall()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@@ -69,16 +69,16 @@ def _check_model_available(model_name: str) -> bool:
|
||||
|
||||
def _pull_model(model_name: str) -> bool:
|
||||
"""Attempt to pull a model from Ollama.
|
||||
|
||||
|
||||
Returns:
|
||||
True if successful or model already exists
|
||||
"""
|
||||
try:
|
||||
import urllib.request
|
||||
import json
|
||||
|
||||
import urllib.request
|
||||
|
||||
logger.info("Pulling model: %s", model_name)
|
||||
|
||||
|
||||
url = settings.ollama_url.replace("localhost", "127.0.0.1")
|
||||
req = urllib.request.Request(
|
||||
f"{url}/api/pull",
|
||||
@@ -86,7 +86,7 @@ def _pull_model(model_name: str) -> bool:
|
||||
headers={"Content-Type": "application/json"},
|
||||
data=json.dumps({"name": model_name, "stream": False}).encode(),
|
||||
)
|
||||
|
||||
|
||||
with urllib.request.urlopen(req, timeout=300) as response:
|
||||
if response.status == 200:
|
||||
logger.info("Successfully pulled model: %s", model_name)
|
||||
@@ -94,7 +94,7 @@ def _pull_model(model_name: str) -> bool:
|
||||
else:
|
||||
logger.error("Failed to pull %s: HTTP %s", model_name, response.status)
|
||||
return False
|
||||
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Error pulling model %s: %s", model_name, exc)
|
||||
return False
|
||||
@@ -106,53 +106,44 @@ def _resolve_model_with_fallback(
|
||||
auto_pull: bool = True,
|
||||
) -> tuple[str, bool]:
|
||||
"""Resolve model with automatic pulling and fallback.
|
||||
|
||||
|
||||
Args:
|
||||
requested_model: Preferred model to use
|
||||
require_vision: Whether the model needs vision capabilities
|
||||
auto_pull: Whether to attempt pulling missing models
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (model_name, is_fallback)
|
||||
"""
|
||||
model = requested_model or settings.ollama_model
|
||||
|
||||
|
||||
# Check if requested model is available
|
||||
if _check_model_available(model):
|
||||
logger.debug("Using available model: %s", model)
|
||||
return model, False
|
||||
|
||||
|
||||
# Try to pull the requested model
|
||||
if auto_pull:
|
||||
logger.info("Model %s not available locally, attempting to pull...", model)
|
||||
if _pull_model(model):
|
||||
return model, False
|
||||
logger.warning("Failed to pull %s, checking fallbacks...", model)
|
||||
|
||||
|
||||
# Use appropriate fallback chain
|
||||
fallback_chain = VISION_MODEL_FALLBACKS if require_vision else DEFAULT_MODEL_FALLBACKS
|
||||
|
||||
|
||||
for fallback_model in fallback_chain:
|
||||
if _check_model_available(fallback_model):
|
||||
logger.warning(
|
||||
"Using fallback model %s (requested: %s)",
|
||||
fallback_model, model
|
||||
)
|
||||
logger.warning("Using fallback model %s (requested: %s)", fallback_model, model)
|
||||
return fallback_model, True
|
||||
|
||||
|
||||
# Try to pull the fallback
|
||||
if auto_pull and _pull_model(fallback_model):
|
||||
logger.info(
|
||||
"Pulled and using fallback model %s (requested: %s)",
|
||||
fallback_model, model
|
||||
)
|
||||
logger.info("Pulled and using fallback model %s (requested: %s)", fallback_model, model)
|
||||
return fallback_model, True
|
||||
|
||||
|
||||
# Absolute last resort - return the requested model and hope for the best
|
||||
logger.error(
|
||||
"No models available in fallback chain. Requested: %s",
|
||||
model
|
||||
)
|
||||
logger.error("No models available in fallback chain. Requested: %s", model)
|
||||
return model, False
|
||||
|
||||
|
||||
@@ -190,6 +181,7 @@ def _resolve_backend(requested: str | None) -> str:
|
||||
|
||||
# "auto" path — lazy import to keep startup fast and tests clean.
|
||||
from timmy.backends import airllm_available, claude_available, grok_available, is_apple_silicon
|
||||
|
||||
if is_apple_silicon() and airllm_available():
|
||||
return "airllm"
|
||||
return "ollama"
|
||||
@@ -215,14 +207,17 @@ def create_timmy(
|
||||
|
||||
if resolved == "claude":
|
||||
from timmy.backends import ClaudeBackend
|
||||
|
||||
return ClaudeBackend()
|
||||
|
||||
if resolved == "grok":
|
||||
from timmy.backends import GrokBackend
|
||||
|
||||
return GrokBackend()
|
||||
|
||||
if resolved == "airllm":
|
||||
from timmy.backends import TimmyAirLLMAgent
|
||||
|
||||
return TimmyAirLLMAgent(model_size=size)
|
||||
|
||||
# Default: Ollama via Agno.
|
||||
@@ -236,16 +231,16 @@ def create_timmy(
|
||||
# If Ollama is completely unreachable, fall back to Claude if available
|
||||
if not _check_model_available(model_name):
|
||||
from timmy.backends import claude_available
|
||||
|
||||
if claude_available():
|
||||
logger.warning(
|
||||
"Ollama unreachable — falling back to Claude backend"
|
||||
)
|
||||
logger.warning("Ollama unreachable — falling back to Claude backend")
|
||||
from timmy.backends import ClaudeBackend
|
||||
|
||||
return ClaudeBackend()
|
||||
|
||||
if is_fallback:
|
||||
logger.info("Using fallback model %s (requested was unavailable)", model_name)
|
||||
|
||||
|
||||
use_tools = _model_supports_tools(model_name)
|
||||
|
||||
# Conditionally include tools — small models get none
|
||||
@@ -259,6 +254,7 @@ def create_timmy(
|
||||
# Try to load memory context
|
||||
try:
|
||||
from timmy.memory_system import memory_system
|
||||
|
||||
memory_context = memory_system.get_system_context()
|
||||
if memory_context:
|
||||
# Truncate if too long — smaller budget for small models
|
||||
@@ -290,32 +286,32 @@ def create_timmy(
|
||||
|
||||
class TimmyWithMemory:
|
||||
"""Agent wrapper with explicit three-tier memory management."""
|
||||
|
||||
|
||||
def __init__(self, db_file: str = "timmy.db") -> None:
|
||||
from timmy.memory_system import memory_system
|
||||
|
||||
|
||||
self.agent = create_timmy(db_file=db_file)
|
||||
self.memory = memory_system
|
||||
self.session_active = True
|
||||
|
||||
|
||||
# Store initial context for reference
|
||||
self.initial_context = self.memory.get_system_context()
|
||||
|
||||
|
||||
def chat(self, message: str) -> str:
|
||||
"""Simple chat interface that tracks in memory."""
|
||||
# Check for user facts to extract
|
||||
self._extract_and_store_facts(message)
|
||||
|
||||
|
||||
# Run agent
|
||||
result = self.agent.run(message, stream=False)
|
||||
response_text = result.content if hasattr(result, "content") else str(result)
|
||||
|
||||
|
||||
return response_text
|
||||
|
||||
|
||||
def _extract_and_store_facts(self, message: str) -> None:
|
||||
"""Extract user facts from message and store in memory."""
|
||||
message_lower = message.lower()
|
||||
|
||||
|
||||
# Extract name
|
||||
name_patterns = [
|
||||
("my name is ", 11),
|
||||
@@ -323,7 +319,7 @@ class TimmyWithMemory:
|
||||
("i am ", 5),
|
||||
("call me ", 8),
|
||||
]
|
||||
|
||||
|
||||
for pattern, offset in name_patterns:
|
||||
if pattern in message_lower:
|
||||
idx = message_lower.find(pattern) + offset
|
||||
@@ -332,7 +328,7 @@ class TimmyWithMemory:
|
||||
self.memory.update_user_fact("Name", name)
|
||||
self.memory.record_decision(f"Learned user's name: {name}")
|
||||
break
|
||||
|
||||
|
||||
# Extract preferences
|
||||
pref_patterns = [
|
||||
("i like ", "Likes"),
|
||||
@@ -341,7 +337,7 @@ class TimmyWithMemory:
|
||||
("i don't like ", "Dislikes"),
|
||||
("i hate ", "Dislikes"),
|
||||
]
|
||||
|
||||
|
||||
for pattern, category in pref_patterns:
|
||||
if pattern in message_lower:
|
||||
idx = message_lower.find(pattern) + len(pattern)
|
||||
@@ -349,16 +345,16 @@ class TimmyWithMemory:
|
||||
if pref and len(pref) > 3:
|
||||
self.memory.record_open_item(f"User {category.lower()}: {pref}")
|
||||
break
|
||||
|
||||
|
||||
def end_session(self, summary: str = "Session completed") -> None:
|
||||
"""End session and write handoff."""
|
||||
if self.session_active:
|
||||
self.memory.end_session(summary)
|
||||
self.session_active = False
|
||||
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.end_session()
|
||||
return False
|
||||
|
||||
@@ -16,38 +16,41 @@ Architecture:
|
||||
All methods return effects that can be logged, audited, and replayed.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Optional
|
||||
import uuid
|
||||
|
||||
|
||||
class PerceptionType(Enum):
|
||||
"""Types of sensory input an agent can receive."""
|
||||
TEXT = auto() # Natural language
|
||||
IMAGE = auto() # Visual input
|
||||
AUDIO = auto() # Sound/speech
|
||||
SENSOR = auto() # Temperature, distance, etc.
|
||||
MOTION = auto() # Accelerometer, gyroscope
|
||||
NETWORK = auto() # API calls, messages
|
||||
INTERNAL = auto() # Self-monitoring (battery, temp)
|
||||
|
||||
TEXT = auto() # Natural language
|
||||
IMAGE = auto() # Visual input
|
||||
AUDIO = auto() # Sound/speech
|
||||
SENSOR = auto() # Temperature, distance, etc.
|
||||
MOTION = auto() # Accelerometer, gyroscope
|
||||
NETWORK = auto() # API calls, messages
|
||||
INTERNAL = auto() # Self-monitoring (battery, temp)
|
||||
|
||||
|
||||
class ActionType(Enum):
|
||||
"""Types of actions an agent can perform."""
|
||||
TEXT = auto() # Generate text response
|
||||
SPEAK = auto() # Text-to-speech
|
||||
MOVE = auto() # Physical movement
|
||||
GRIP = auto() # Manipulate objects
|
||||
CALL = auto() # API/network call
|
||||
EMIT = auto() # Signal/light/sound
|
||||
SLEEP = auto() # Power management
|
||||
|
||||
TEXT = auto() # Generate text response
|
||||
SPEAK = auto() # Text-to-speech
|
||||
MOVE = auto() # Physical movement
|
||||
GRIP = auto() # Manipulate objects
|
||||
CALL = auto() # API/network call
|
||||
EMIT = auto() # Signal/light/sound
|
||||
SLEEP = auto() # Power management
|
||||
|
||||
|
||||
class AgentCapability(Enum):
|
||||
"""High-level capabilities a TimAgent may possess."""
|
||||
|
||||
REASONING = "reasoning"
|
||||
CODING = "coding"
|
||||
WRITING = "writing"
|
||||
@@ -63,15 +66,16 @@ class AgentCapability(Enum):
|
||||
@dataclass(frozen=True)
|
||||
class AgentIdentity:
|
||||
"""Immutable identity for an agent instance.
|
||||
|
||||
|
||||
This persists across sessions and substrates. If Timmy moves
|
||||
from cloud to robot, the identity follows.
|
||||
"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
version: str
|
||||
created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
|
||||
|
||||
@classmethod
|
||||
def generate(cls, name: str, version: str = "1.0.0") -> "AgentIdentity":
|
||||
"""Generate a new unique identity."""
|
||||
@@ -85,16 +89,17 @@ class AgentIdentity:
|
||||
@dataclass
|
||||
class Perception:
|
||||
"""A sensory input to the agent.
|
||||
|
||||
Substrate-agnostic representation. A camera image and a
|
||||
|
||||
Substrate-agnostic representation. A camera image and a
|
||||
LiDAR point cloud are both Perception instances.
|
||||
"""
|
||||
|
||||
type: PerceptionType
|
||||
data: Any # Content depends on type
|
||||
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
source: str = "unknown" # e.g., "camera_1", "microphone", "user_input"
|
||||
metadata: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
@classmethod
|
||||
def text(cls, content: str, source: str = "user") -> "Perception":
|
||||
"""Factory for text perception."""
|
||||
@@ -103,7 +108,7 @@ class Perception:
|
||||
data=content,
|
||||
source=source,
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def sensor(cls, kind: str, value: float, unit: str = "") -> "Perception":
|
||||
"""Factory for sensor readings."""
|
||||
@@ -117,16 +122,17 @@ class Perception:
|
||||
@dataclass
|
||||
class Action:
|
||||
"""An action the agent intends to perform.
|
||||
|
||||
|
||||
Actions are effects — they describe what should happen,
|
||||
not how. The substrate implements the "how."
|
||||
"""
|
||||
|
||||
type: ActionType
|
||||
payload: Any # Action-specific data
|
||||
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
confidence: float = 1.0 # 0-1, agent's certainty
|
||||
deadline: Optional[str] = None # When action must complete
|
||||
|
||||
|
||||
@classmethod
|
||||
def respond(cls, text: str, confidence: float = 1.0) -> "Action":
|
||||
"""Factory for text response action."""
|
||||
@@ -135,7 +141,7 @@ class Action:
|
||||
payload=text,
|
||||
confidence=confidence,
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def move(cls, vector: tuple[float, float, float], speed: float = 1.0) -> "Action":
|
||||
"""Factory for movement action (x, y, z meters)."""
|
||||
@@ -148,10 +154,11 @@ class Action:
|
||||
@dataclass
|
||||
class Memory:
|
||||
"""A stored experience or fact.
|
||||
|
||||
|
||||
Memories are substrate-agnostic. A conversation history
|
||||
and a video recording are both Memory instances.
|
||||
"""
|
||||
|
||||
id: str
|
||||
content: Any
|
||||
created_at: str
|
||||
@@ -159,7 +166,7 @@ class Memory:
|
||||
last_accessed: Optional[str] = None
|
||||
importance: float = 0.5 # 0-1, for pruning decisions
|
||||
tags: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
def touch(self) -> None:
|
||||
"""Mark memory as accessed."""
|
||||
self.access_count += 1
|
||||
@@ -169,6 +176,7 @@ class Memory:
|
||||
@dataclass
|
||||
class Communication:
|
||||
"""A message to/from another agent or human."""
|
||||
|
||||
sender: str
|
||||
recipient: str
|
||||
content: Any
|
||||
@@ -179,132 +187,132 @@ class Communication:
|
||||
|
||||
class TimAgent(ABC):
|
||||
"""Abstract base class for all Timmy agent implementations.
|
||||
|
||||
|
||||
This is the substrate-agnostic interface. Implementations:
|
||||
- OllamaAgent: LLM-based reasoning (today)
|
||||
- RobotAgent: Physical embodiment (future)
|
||||
- SimulationAgent: Virtual environment (future)
|
||||
|
||||
|
||||
Usage:
|
||||
agent = OllamaAgent(identity) # Today's implementation
|
||||
|
||||
|
||||
perception = Perception.text("Hello Timmy")
|
||||
memory = agent.perceive(perception)
|
||||
|
||||
|
||||
action = agent.reason("How should I respond?")
|
||||
result = agent.act(action)
|
||||
|
||||
|
||||
agent.remember(memory) # Store for future
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, identity: AgentIdentity) -> None:
|
||||
self._identity = identity
|
||||
self._capabilities: set[AgentCapability] = set()
|
||||
self._state: dict[str, Any] = {}
|
||||
|
||||
|
||||
@property
|
||||
def identity(self) -> AgentIdentity:
|
||||
"""Return this agent's immutable identity."""
|
||||
return self._identity
|
||||
|
||||
|
||||
@property
|
||||
def capabilities(self) -> set[AgentCapability]:
|
||||
"""Return set of supported capabilities."""
|
||||
return self._capabilities.copy()
|
||||
|
||||
|
||||
def has_capability(self, capability: AgentCapability) -> bool:
|
||||
"""Check if agent supports a capability."""
|
||||
return capability in self._capabilities
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def perceive(self, perception: Perception) -> Memory:
|
||||
"""Process sensory input and create a memory.
|
||||
|
||||
|
||||
This is the entry point for all agent interaction.
|
||||
A text message, camera frame, or temperature reading
|
||||
all enter through perceive().
|
||||
|
||||
|
||||
Args:
|
||||
perception: Sensory input
|
||||
|
||||
|
||||
Returns:
|
||||
Memory: Stored representation of the perception
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def reason(self, query: str, context: list[Memory]) -> Action:
|
||||
"""Reason about a situation and decide on action.
|
||||
|
||||
|
||||
This is where "thinking" happens. The agent uses its
|
||||
substrate-appropriate reasoning (LLM, neural net, rules)
|
||||
to decide what to do.
|
||||
|
||||
|
||||
Args:
|
||||
query: What to reason about
|
||||
context: Relevant memories for context
|
||||
|
||||
|
||||
Returns:
|
||||
Action: What the agent decides to do
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def act(self, action: Action) -> Any:
|
||||
"""Execute an action in the substrate.
|
||||
|
||||
|
||||
This is where the abstract action becomes concrete:
|
||||
- TEXT → Generate LLM response
|
||||
- MOVE → Send motor commands
|
||||
- SPEAK → Call TTS engine
|
||||
|
||||
|
||||
Args:
|
||||
action: The action to execute
|
||||
|
||||
|
||||
Returns:
|
||||
Result of the action (substrate-specific)
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def remember(self, memory: Memory) -> None:
|
||||
"""Store a memory for future retrieval.
|
||||
|
||||
|
||||
The storage mechanism depends on substrate:
|
||||
- Cloud: SQLite, vector DB
|
||||
- Robot: Local flash storage
|
||||
- Hybrid: Synced with conflict resolution
|
||||
|
||||
|
||||
Args:
|
||||
memory: Experience to store
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def recall(self, query: str, limit: int = 5) -> list[Memory]:
|
||||
"""Retrieve relevant memories.
|
||||
|
||||
|
||||
Args:
|
||||
query: What to search for
|
||||
limit: Maximum memories to return
|
||||
|
||||
|
||||
Returns:
|
||||
List of relevant memories, sorted by relevance
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def communicate(self, message: Communication) -> bool:
|
||||
"""Send/receive communication with another agent.
|
||||
|
||||
|
||||
Args:
|
||||
message: Message to send
|
||||
|
||||
|
||||
Returns:
|
||||
True if communication succeeded
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def get_state(self) -> dict[str, Any]:
|
||||
"""Get current agent state for monitoring/debugging."""
|
||||
return {
|
||||
@@ -312,7 +320,7 @@ class TimAgent(ABC):
|
||||
"capabilities": list(self._capabilities),
|
||||
"state": self._state.copy(),
|
||||
}
|
||||
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Graceful shutdown. Persist state, close connections."""
|
||||
# Override in subclass for cleanup
|
||||
@@ -321,7 +329,7 @@ class TimAgent(ABC):
|
||||
|
||||
class AgentEffect:
|
||||
"""Log entry for agent actions — for audit and replay.
|
||||
|
||||
|
||||
The complete history of an agent's life can be captured
|
||||
as a sequence of AgentEffects. This enables:
|
||||
- Debugging: What did the agent see and do?
|
||||
@@ -329,40 +337,46 @@ class AgentEffect:
|
||||
- Replay: Reconstruct agent state from log
|
||||
- Training: Learn from agent experiences
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, log_path: Optional[str] = None) -> None:
|
||||
self._effects: list[dict] = []
|
||||
self._log_path = log_path
|
||||
|
||||
|
||||
def log_perceive(self, perception: Perception, memory_id: str) -> None:
|
||||
"""Log a perception event."""
|
||||
self._effects.append({
|
||||
"type": "perceive",
|
||||
"perception_type": perception.type.name,
|
||||
"source": perception.source,
|
||||
"memory_id": memory_id,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
|
||||
self._effects.append(
|
||||
{
|
||||
"type": "perceive",
|
||||
"perception_type": perception.type.name,
|
||||
"source": perception.source,
|
||||
"memory_id": memory_id,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
def log_reason(self, query: str, action_type: ActionType) -> None:
|
||||
"""Log a reasoning event."""
|
||||
self._effects.append({
|
||||
"type": "reason",
|
||||
"query": query,
|
||||
"action_type": action_type.name,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
|
||||
self._effects.append(
|
||||
{
|
||||
"type": "reason",
|
||||
"query": query,
|
||||
"action_type": action_type.name,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
def log_act(self, action: Action, result: Any) -> None:
|
||||
"""Log an action event."""
|
||||
self._effects.append({
|
||||
"type": "act",
|
||||
"action_type": action.type.name,
|
||||
"confidence": action.confidence,
|
||||
"result_type": type(result).__name__,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
|
||||
self._effects.append(
|
||||
{
|
||||
"type": "act",
|
||||
"action_type": action.type.name,
|
||||
"confidence": action.confidence,
|
||||
"result_type": type(result).__name__,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
def export(self) -> list[dict]:
|
||||
"""Export effect log for analysis."""
|
||||
return self._effects.copy()
|
||||
|
||||
@@ -7,10 +7,10 @@ between the old codebase and the new embodiment-ready architecture.
|
||||
Usage:
|
||||
from timmy.agent_core import AgentIdentity, Perception
|
||||
from timmy.agent_core.ollama_adapter import OllamaAgent
|
||||
|
||||
|
||||
identity = AgentIdentity.generate("Timmy")
|
||||
agent = OllamaAgent(identity)
|
||||
|
||||
|
||||
perception = Perception.text("Hello!")
|
||||
memory = agent.perceive(perception)
|
||||
action = agent.reason("How should I respond?", [memory])
|
||||
@@ -19,27 +19,27 @@ Usage:
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from timmy.agent import _resolve_model_with_fallback, create_timmy
|
||||
from timmy.agent_core.interface import (
|
||||
AgentCapability,
|
||||
AgentIdentity,
|
||||
Perception,
|
||||
PerceptionType,
|
||||
Action,
|
||||
ActionType,
|
||||
Memory,
|
||||
Communication,
|
||||
TimAgent,
|
||||
AgentCapability,
|
||||
AgentEffect,
|
||||
AgentIdentity,
|
||||
Communication,
|
||||
Memory,
|
||||
Perception,
|
||||
PerceptionType,
|
||||
TimAgent,
|
||||
)
|
||||
from timmy.agent import create_timmy, _resolve_model_with_fallback
|
||||
|
||||
|
||||
class OllamaAgent(TimAgent):
|
||||
"""TimAgent implementation using local Ollama LLM.
|
||||
|
||||
|
||||
This is the production agent for Timmy Time v2. It uses
|
||||
Ollama for reasoning and SQLite for memory persistence.
|
||||
|
||||
|
||||
Capabilities:
|
||||
- REASONING: LLM-based inference
|
||||
- CODING: Code generation and analysis
|
||||
@@ -47,7 +47,7 @@ class OllamaAgent(TimAgent):
|
||||
- ANALYSIS: Data processing and insights
|
||||
- COMMUNICATION: Multi-agent messaging
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
identity: AgentIdentity,
|
||||
@@ -56,7 +56,7 @@ class OllamaAgent(TimAgent):
|
||||
require_vision: bool = False,
|
||||
) -> None:
|
||||
"""Initialize Ollama-based agent.
|
||||
|
||||
|
||||
Args:
|
||||
identity: Agent identity (persistent across sessions)
|
||||
model: Ollama model to use (auto-resolves with fallback)
|
||||
@@ -64,23 +64,24 @@ class OllamaAgent(TimAgent):
|
||||
require_vision: Whether to select a vision-capable model
|
||||
"""
|
||||
super().__init__(identity)
|
||||
|
||||
|
||||
# Resolve model with automatic pulling and fallback
|
||||
resolved_model, is_fallback = _resolve_model_with_fallback(
|
||||
requested_model=model,
|
||||
require_vision=require_vision,
|
||||
auto_pull=True,
|
||||
)
|
||||
|
||||
|
||||
if is_fallback:
|
||||
import logging
|
||||
|
||||
logging.getLogger(__name__).info(
|
||||
"OllamaAdapter using fallback model %s", resolved_model
|
||||
)
|
||||
|
||||
|
||||
# Initialize underlying Ollama agent
|
||||
self._timmy = create_timmy(model=resolved_model)
|
||||
|
||||
|
||||
# Set capabilities based on what Ollama can do
|
||||
self._capabilities = {
|
||||
AgentCapability.REASONING,
|
||||
@@ -89,17 +90,17 @@ class OllamaAgent(TimAgent):
|
||||
AgentCapability.ANALYSIS,
|
||||
AgentCapability.COMMUNICATION,
|
||||
}
|
||||
|
||||
|
||||
# Effect logging for audit/replay
|
||||
self._effect_log = AgentEffect(effect_log) if effect_log else None
|
||||
|
||||
|
||||
# Simple in-memory working memory (short term)
|
||||
self._working_memory: list[Memory] = []
|
||||
self._max_working_memory = 10
|
||||
|
||||
|
||||
def perceive(self, perception: Perception) -> Memory:
|
||||
"""Process perception and store in memory.
|
||||
|
||||
|
||||
For text perceptions, we might do light preprocessing
|
||||
(summarization, keyword extraction) before storage.
|
||||
"""
|
||||
@@ -114,28 +115,28 @@ class OllamaAgent(TimAgent):
|
||||
created_at=perception.timestamp,
|
||||
tags=self._extract_tags(perception),
|
||||
)
|
||||
|
||||
|
||||
# Add to working memory
|
||||
self._working_memory.append(memory)
|
||||
if len(self._working_memory) > self._max_working_memory:
|
||||
self._working_memory.pop(0) # FIFO eviction
|
||||
|
||||
|
||||
# Log effect
|
||||
if self._effect_log:
|
||||
self._effect_log.log_perceive(perception, memory.id)
|
||||
|
||||
|
||||
return memory
|
||||
|
||||
|
||||
def reason(self, query: str, context: list[Memory]) -> Action:
|
||||
"""Use LLM to reason and decide on action.
|
||||
|
||||
|
||||
This is where the Ollama agent does its work. We construct
|
||||
a prompt from the query and context, then interpret the
|
||||
response as an action.
|
||||
"""
|
||||
# Build context string from memories
|
||||
context_str = self._format_context(context)
|
||||
|
||||
|
||||
# Construct prompt
|
||||
prompt = f"""You are {self._identity.name}, an AI assistant.
|
||||
|
||||
@@ -145,30 +146,30 @@ Context from previous interactions:
|
||||
Current query: {query}
|
||||
|
||||
Respond naturally and helpfully."""
|
||||
|
||||
|
||||
# Run LLM inference
|
||||
result = self._timmy.run(prompt, stream=False)
|
||||
response_text = result.content if hasattr(result, "content") else str(result)
|
||||
|
||||
|
||||
# Create text response action
|
||||
action = Action.respond(response_text, confidence=0.9)
|
||||
|
||||
|
||||
# Log effect
|
||||
if self._effect_log:
|
||||
self._effect_log.log_reason(query, action.type)
|
||||
|
||||
|
||||
return action
|
||||
|
||||
|
||||
def act(self, action: Action) -> Any:
|
||||
"""Execute action in the Ollama substrate.
|
||||
|
||||
|
||||
For text actions, the "execution" is just returning the
|
||||
text (already generated during reasoning). For future
|
||||
action types (MOVE, SPEAK), this would trigger the
|
||||
appropriate Ollama tool calls.
|
||||
"""
|
||||
result = None
|
||||
|
||||
|
||||
if action.type == ActionType.TEXT:
|
||||
result = action.payload
|
||||
elif action.type == ActionType.SPEAK:
|
||||
@@ -179,13 +180,13 @@ Respond naturally and helpfully."""
|
||||
result = {"status": "not_implemented", "payload": action.payload}
|
||||
else:
|
||||
result = {"error": f"Action type {action.type} not supported by OllamaAgent"}
|
||||
|
||||
|
||||
# Log effect
|
||||
if self._effect_log:
|
||||
self._effect_log.log_act(action, result)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def remember(self, memory: Memory) -> None:
|
||||
"""Store memory in working memory.
|
||||
|
||||
@@ -200,48 +201,48 @@ Respond naturally and helpfully."""
|
||||
# Evict oldest if over capacity
|
||||
if len(self._working_memory) > self._max_working_memory:
|
||||
self._working_memory.pop(0)
|
||||
|
||||
|
||||
def recall(self, query: str, limit: int = 5) -> list[Memory]:
|
||||
"""Retrieve relevant memories.
|
||||
|
||||
|
||||
Simple keyword matching for now. Future: vector similarity.
|
||||
"""
|
||||
query_lower = query.lower()
|
||||
scored = []
|
||||
|
||||
|
||||
for memory in self._working_memory:
|
||||
score = 0
|
||||
content_str = str(memory.content).lower()
|
||||
|
||||
|
||||
# Simple keyword overlap
|
||||
query_words = set(query_lower.split())
|
||||
content_words = set(content_str.split())
|
||||
overlap = len(query_words & content_words)
|
||||
score += overlap
|
||||
|
||||
|
||||
# Boost recent memories
|
||||
score += memory.importance
|
||||
|
||||
|
||||
scored.append((score, memory))
|
||||
|
||||
|
||||
# Sort by score descending
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
|
||||
# Return top N
|
||||
return [m for _, m in scored[:limit]]
|
||||
|
||||
|
||||
def communicate(self, message: Communication) -> bool:
|
||||
"""Send message to another agent.
|
||||
|
||||
|
||||
Swarm comms removed — inter-agent communication will be handled
|
||||
by the unified brain memory layer.
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
def _extract_tags(self, perception: Perception) -> list[str]:
|
||||
"""Extract searchable tags from perception."""
|
||||
tags = [perception.type.name, perception.source]
|
||||
|
||||
|
||||
if perception.type == PerceptionType.TEXT:
|
||||
# Simple keyword extraction
|
||||
text = str(perception.data).lower()
|
||||
@@ -249,14 +250,14 @@ Respond naturally and helpfully."""
|
||||
for kw in keywords:
|
||||
if kw in text:
|
||||
tags.append(kw)
|
||||
|
||||
|
||||
return tags
|
||||
|
||||
|
||||
def _format_context(self, memories: list[Memory]) -> str:
|
||||
"""Format memories into context string for prompt."""
|
||||
if not memories:
|
||||
return "No previous context."
|
||||
|
||||
|
||||
parts = []
|
||||
for mem in memories[-5:]: # Last 5 memories
|
||||
if isinstance(mem.content, dict):
|
||||
@@ -264,9 +265,9 @@ Respond naturally and helpfully."""
|
||||
parts.append(f"- {data}")
|
||||
else:
|
||||
parts.append(f"- {mem.content}")
|
||||
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def get_effect_log(self) -> Optional[list[dict]]:
|
||||
"""Export effect log if logging is enabled."""
|
||||
if self._effect_log:
|
||||
|
||||
@@ -30,9 +30,11 @@ logger = logging.getLogger(__name__)
|
||||
# Data structures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgenticStep:
|
||||
"""Result of a single step in the agentic loop."""
|
||||
|
||||
step_num: int
|
||||
description: str
|
||||
result: str
|
||||
@@ -43,6 +45,7 @@ class AgenticStep:
|
||||
@dataclass
|
||||
class AgenticResult:
|
||||
"""Final result of the entire agentic loop."""
|
||||
|
||||
task_id: str
|
||||
task: str
|
||||
summary: str
|
||||
@@ -55,6 +58,7 @@ class AgenticResult:
|
||||
# Agent factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_loop_agent():
|
||||
"""Create a fresh agent for the agentic loop.
|
||||
|
||||
@@ -62,6 +66,7 @@ def _get_loop_agent():
|
||||
dedicated session so it doesn't pollute the main chat history.
|
||||
"""
|
||||
from timmy.agent import create_timmy
|
||||
|
||||
return create_timmy()
|
||||
|
||||
|
||||
@@ -85,6 +90,7 @@ def _parse_steps(plan_text: str) -> list[str]:
|
||||
# Core loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def run_agentic_loop(
|
||||
task: str,
|
||||
*,
|
||||
@@ -146,12 +152,15 @@ async def run_agentic_loop(
|
||||
was_truncated = planned_steps > total_steps
|
||||
|
||||
# Broadcast plan
|
||||
await _broadcast_progress("agentic.plan_ready", {
|
||||
"task_id": task_id,
|
||||
"task": task,
|
||||
"steps": steps,
|
||||
"total": total_steps,
|
||||
})
|
||||
await _broadcast_progress(
|
||||
"agentic.plan_ready",
|
||||
{
|
||||
"task_id": task_id,
|
||||
"task": task,
|
||||
"steps": steps,
|
||||
"total": total_steps,
|
||||
},
|
||||
)
|
||||
|
||||
# ── Phase 2: Execution ─────────────────────────────────────────────────
|
||||
completed_results: list[str] = []
|
||||
@@ -175,6 +184,7 @@ async def run_agentic_loop(
|
||||
|
||||
# Clean the response
|
||||
from timmy.session import _clean_response
|
||||
|
||||
step_result = _clean_response(step_result)
|
||||
|
||||
step = AgenticStep(
|
||||
@@ -188,13 +198,16 @@ async def run_agentic_loop(
|
||||
completed_results.append(f"Step {i}: {step_result[:200]}")
|
||||
|
||||
# Broadcast progress
|
||||
await _broadcast_progress("agentic.step_complete", {
|
||||
"task_id": task_id,
|
||||
"step": i,
|
||||
"total": total_steps,
|
||||
"description": step_desc,
|
||||
"result": step_result[:200],
|
||||
})
|
||||
await _broadcast_progress(
|
||||
"agentic.step_complete",
|
||||
{
|
||||
"task_id": task_id,
|
||||
"step": i,
|
||||
"total": total_steps,
|
||||
"description": step_desc,
|
||||
"result": step_result[:200],
|
||||
},
|
||||
)
|
||||
|
||||
if on_progress:
|
||||
await on_progress(step_desc, i, total_steps)
|
||||
@@ -210,11 +223,16 @@ async def run_agentic_loop(
|
||||
)
|
||||
try:
|
||||
adapt_run = await asyncio.to_thread(
|
||||
agent.run, adapt_prompt, stream=False,
|
||||
agent.run,
|
||||
adapt_prompt,
|
||||
stream=False,
|
||||
session_id=f"{session_id}_adapt{i}",
|
||||
)
|
||||
adapt_result = adapt_run.content if hasattr(adapt_run, "content") else str(adapt_run)
|
||||
adapt_result = (
|
||||
adapt_run.content if hasattr(adapt_run, "content") else str(adapt_run)
|
||||
)
|
||||
from timmy.session import _clean_response
|
||||
|
||||
adapt_result = _clean_response(adapt_result)
|
||||
|
||||
step = AgenticStep(
|
||||
@@ -227,14 +245,17 @@ async def run_agentic_loop(
|
||||
result.steps.append(step)
|
||||
completed_results.append(f"Step {i} (adapted): {adapt_result[:200]}")
|
||||
|
||||
await _broadcast_progress("agentic.step_adapted", {
|
||||
"task_id": task_id,
|
||||
"step": i,
|
||||
"total": total_steps,
|
||||
"description": step_desc,
|
||||
"error": str(exc),
|
||||
"adaptation": adapt_result[:200],
|
||||
})
|
||||
await _broadcast_progress(
|
||||
"agentic.step_adapted",
|
||||
{
|
||||
"task_id": task_id,
|
||||
"step": i,
|
||||
"total": total_steps,
|
||||
"description": step_desc,
|
||||
"error": str(exc),
|
||||
"adaptation": adapt_result[:200],
|
||||
},
|
||||
)
|
||||
|
||||
if on_progress:
|
||||
await on_progress(f"[Adapted] {step_desc}", i, total_steps)
|
||||
@@ -259,11 +280,16 @@ async def run_agentic_loop(
|
||||
)
|
||||
try:
|
||||
summary_run = await asyncio.to_thread(
|
||||
agent.run, summary_prompt, stream=False,
|
||||
agent.run,
|
||||
summary_prompt,
|
||||
stream=False,
|
||||
session_id=f"{session_id}_summary",
|
||||
)
|
||||
result.summary = summary_run.content if hasattr(summary_run, "content") else str(summary_run)
|
||||
result.summary = (
|
||||
summary_run.content if hasattr(summary_run, "content") else str(summary_run)
|
||||
)
|
||||
from timmy.session import _clean_response
|
||||
|
||||
result.summary = _clean_response(result.summary)
|
||||
except Exception as exc:
|
||||
logger.error("Agentic loop summary failed: %s", exc)
|
||||
@@ -281,13 +307,16 @@ async def run_agentic_loop(
|
||||
|
||||
result.total_duration_ms = int((time.monotonic() - start_time) * 1000)
|
||||
|
||||
await _broadcast_progress("agentic.task_complete", {
|
||||
"task_id": task_id,
|
||||
"status": result.status,
|
||||
"steps_completed": len(result.steps),
|
||||
"summary": result.summary[:300],
|
||||
"duration_ms": result.total_duration_ms,
|
||||
})
|
||||
await _broadcast_progress(
|
||||
"agentic.task_complete",
|
||||
{
|
||||
"task_id": task_id,
|
||||
"status": result.status,
|
||||
"steps_completed": len(result.steps),
|
||||
"summary": result.summary[:300],
|
||||
"duration_ms": result.total_duration_ms,
|
||||
},
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -296,10 +325,12 @@ async def run_agentic_loop(
|
||||
# WebSocket broadcast helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _broadcast_progress(event: str, data: dict) -> None:
|
||||
"""Broadcast agentic loop progress via WebSocket (best-effort)."""
|
||||
try:
|
||||
from infrastructure.ws_manager.handler import ws_manager
|
||||
|
||||
await ws_manager.broadcast(event, data)
|
||||
except Exception:
|
||||
logger.debug("Agentic loop: WS broadcast failed for %s", event)
|
||||
|
||||
@@ -18,7 +18,7 @@ from agno.agent import Agent
|
||||
from agno.models.ollama import Ollama
|
||||
|
||||
from config import settings
|
||||
from infrastructure.events.bus import EventBus, Event
|
||||
from infrastructure.events.bus import Event, EventBus
|
||||
|
||||
try:
|
||||
from mcp.registry import tool_registry
|
||||
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseAgent(ABC):
|
||||
"""Base class for all sub-agents."""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -43,15 +43,15 @@ class BaseAgent(ABC):
|
||||
self.name = name
|
||||
self.role = role
|
||||
self.tools = tools or []
|
||||
|
||||
|
||||
# Create Agno agent
|
||||
self.agent = self._create_agent(system_prompt)
|
||||
|
||||
|
||||
# Event bus for communication
|
||||
self.event_bus: Optional[EventBus] = None
|
||||
|
||||
|
||||
logger.info("%s agent initialized (id: %s)", name, agent_id)
|
||||
|
||||
|
||||
def _create_agent(self, system_prompt: str) -> Agent:
|
||||
"""Create the underlying Agno agent."""
|
||||
# Get tools from registry
|
||||
@@ -60,7 +60,7 @@ class BaseAgent(ABC):
|
||||
handler = tool_registry.get_handler(tool_name)
|
||||
if handler:
|
||||
tool_instances.append(handler)
|
||||
|
||||
|
||||
return Agent(
|
||||
name=self.name,
|
||||
model=Ollama(id=settings.ollama_model, host=settings.ollama_url, timeout=300),
|
||||
@@ -71,19 +71,19 @@ class BaseAgent(ABC):
|
||||
markdown=True,
|
||||
telemetry=settings.telemetry_enabled,
|
||||
)
|
||||
|
||||
|
||||
def connect_event_bus(self, bus: EventBus) -> None:
|
||||
"""Connect to the event bus for inter-agent communication."""
|
||||
self.event_bus = bus
|
||||
|
||||
|
||||
# Subscribe to relevant events
|
||||
bus.subscribe(f"agent.{self.agent_id}.*")(self._handle_direct_message)
|
||||
bus.subscribe("agent.task.assigned")(self._handle_task_assignment)
|
||||
|
||||
|
||||
async def _handle_direct_message(self, event: Event) -> None:
|
||||
"""Handle direct messages to this agent."""
|
||||
logger.debug("%s received message: %s", self.name, event.type)
|
||||
|
||||
|
||||
async def _handle_task_assignment(self, event: Event) -> None:
|
||||
"""Handle task assignment events."""
|
||||
assigned_agent = event.data.get("agent_id")
|
||||
@@ -91,41 +91,43 @@ class BaseAgent(ABC):
|
||||
task_id = event.data.get("task_id")
|
||||
description = event.data.get("description", "")
|
||||
logger.info("%s assigned task %s: %s", self.name, task_id, description[:50])
|
||||
|
||||
|
||||
# Execute the task
|
||||
await self.execute_task(task_id, description, event.data)
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def execute_task(self, task_id: str, description: str, context: dict) -> Any:
|
||||
"""Execute a task assigned to this agent.
|
||||
|
||||
|
||||
Must be implemented by subclasses.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
async def run(self, message: str) -> str:
|
||||
"""Run the agent with a message.
|
||||
|
||||
|
||||
Returns:
|
||||
Agent response
|
||||
"""
|
||||
result = self.agent.run(message, stream=False)
|
||||
response = result.content if hasattr(result, "content") else str(result)
|
||||
|
||||
|
||||
# Emit completion event
|
||||
if self.event_bus:
|
||||
await self.event_bus.publish(Event(
|
||||
type=f"agent.{self.agent_id}.response",
|
||||
source=self.agent_id,
|
||||
data={"input": message, "output": response},
|
||||
))
|
||||
|
||||
await self.event_bus.publish(
|
||||
Event(
|
||||
type=f"agent.{self.agent_id}.response",
|
||||
source=self.agent_id,
|
||||
data={"input": message, "output": response},
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def get_capabilities(self) -> list[str]:
|
||||
"""Get list of capabilities this agent provides."""
|
||||
return self.tools
|
||||
|
||||
|
||||
def get_status(self) -> dict:
|
||||
"""Get current agent status."""
|
||||
return {
|
||||
|
||||
@@ -12,9 +12,9 @@ from typing import Any, Optional
|
||||
from agno.agent import Agent
|
||||
from agno.models.ollama import Ollama
|
||||
|
||||
from timmy.agents.base import BaseAgent, SubAgent
|
||||
from config import settings
|
||||
from infrastructure.events.bus import EventBus, event_bus
|
||||
from timmy.agents.base import BaseAgent, SubAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -29,7 +29,7 @@ _timmy_context: dict[str, Any] = {
|
||||
|
||||
async def _load_hands_async() -> list[dict]:
|
||||
"""Async helper to load hands.
|
||||
|
||||
|
||||
Hands registry removed — hand definitions live in TOML files under hands/.
|
||||
This will be rewired to read from brain memory.
|
||||
"""
|
||||
@@ -42,7 +42,7 @@ def build_timmy_context_sync() -> dict[str, Any]:
|
||||
Gathers git commits, active sub-agents, and hot memory.
|
||||
"""
|
||||
global _timmy_context
|
||||
|
||||
|
||||
ctx: dict[str, Any] = {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"repo_root": settings.repo_root,
|
||||
@@ -51,45 +51,52 @@ def build_timmy_context_sync() -> dict[str, Any]:
|
||||
"hands": [],
|
||||
"memory": "",
|
||||
}
|
||||
|
||||
|
||||
# 1. Get recent git commits
|
||||
try:
|
||||
from tools.git_tools import git_log
|
||||
|
||||
result = git_log(max_count=20)
|
||||
if result.get("success"):
|
||||
commits = result.get("commits", [])
|
||||
ctx["git_log"] = "\n".join([
|
||||
f"{c['short_sha']} {c['message'].split(chr(10))[0]}"
|
||||
for c in commits[:20]
|
||||
])
|
||||
ctx["git_log"] = "\n".join(
|
||||
[f"{c['short_sha']} {c['message'].split(chr(10))[0]}" for c in commits[:20]]
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("Could not load git log for context: %s", exc)
|
||||
ctx["git_log"] = "(Git log unavailable)"
|
||||
|
||||
|
||||
# 2. Get active sub-agents
|
||||
try:
|
||||
from swarm import registry as swarm_registry
|
||||
|
||||
conn = swarm_registry._get_conn()
|
||||
rows = conn.execute(
|
||||
"SELECT id, name, status, capabilities FROM agents ORDER BY name"
|
||||
).fetchall()
|
||||
ctx["agents"] = [
|
||||
{"id": r["id"], "name": r["name"], "status": r["status"], "capabilities": r["capabilities"]}
|
||||
{
|
||||
"id": r["id"],
|
||||
"name": r["name"],
|
||||
"status": r["status"],
|
||||
"capabilities": r["capabilities"],
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
conn.close()
|
||||
except Exception as exc:
|
||||
logger.warning("Could not load agents for context: %s", exc)
|
||||
ctx["agents"] = []
|
||||
|
||||
|
||||
# 3. Read hot memory (via HotMemory to auto-create if missing)
|
||||
try:
|
||||
from timmy.memory_system import memory_system
|
||||
|
||||
ctx["memory"] = memory_system.hot.read()[:2000]
|
||||
except Exception as exc:
|
||||
logger.warning("Could not load memory for context: %s", exc)
|
||||
ctx["memory"] = "(Memory unavailable)"
|
||||
|
||||
|
||||
_timmy_context.update(ctx)
|
||||
logger.info("Context built (sync): %d agents", len(ctx["agents"]))
|
||||
return ctx
|
||||
@@ -110,21 +117,31 @@ build_timmy_context = build_timmy_context_sync
|
||||
|
||||
def format_timmy_prompt(base_prompt: str, context: dict[str, Any]) -> str:
|
||||
"""Format the system prompt with dynamic context."""
|
||||
|
||||
|
||||
# Format agents list
|
||||
agents_list = "\n".join([
|
||||
f"| {a['name']} | {a['capabilities'] or 'general'} | {a['status']} |"
|
||||
for a in context.get("agents", [])
|
||||
]) or "(No agents registered yet)"
|
||||
|
||||
agents_list = (
|
||||
"\n".join(
|
||||
[
|
||||
f"| {a['name']} | {a['capabilities'] or 'general'} | {a['status']} |"
|
||||
for a in context.get("agents", [])
|
||||
]
|
||||
)
|
||||
or "(No agents registered yet)"
|
||||
)
|
||||
|
||||
# Format hands list
|
||||
hands_list = "\n".join([
|
||||
f"| {h['name']} | {h['schedule']} | {'enabled' if h['enabled'] else 'disabled'} |"
|
||||
for h in context.get("hands", [])
|
||||
]) or "(No hands configured)"
|
||||
|
||||
repo_root = context.get('repo_root', settings.repo_root)
|
||||
|
||||
hands_list = (
|
||||
"\n".join(
|
||||
[
|
||||
f"| {h['name']} | {h['schedule']} | {'enabled' if h['enabled'] else 'disabled'} |"
|
||||
for h in context.get("hands", [])
|
||||
]
|
||||
)
|
||||
or "(No hands configured)"
|
||||
)
|
||||
|
||||
repo_root = context.get("repo_root", settings.repo_root)
|
||||
|
||||
context_block = f"""
|
||||
## Current System Context (as of {context.get('timestamp', datetime.now(timezone.utc).isoformat())})
|
||||
|
||||
@@ -149,10 +166,10 @@ def format_timmy_prompt(base_prompt: str, context: dict[str, Any]) -> str:
|
||||
### Hot Memory:
|
||||
{context.get('memory', '(unavailable)')[:1000]}
|
||||
"""
|
||||
|
||||
|
||||
# Replace {REPO_ROOT} placeholder with actual path
|
||||
base_prompt = base_prompt.replace("{REPO_ROOT}", repo_root)
|
||||
|
||||
|
||||
# Insert context after the first line
|
||||
lines = base_prompt.split("\n")
|
||||
if lines:
|
||||
@@ -227,63 +244,71 @@ class TimmyOrchestrator(BaseAgent):
|
||||
name="Orchestrator",
|
||||
role="orchestrator",
|
||||
system_prompt=formatted_prompt,
|
||||
tools=["web_search", "read_file", "write_file", "python", "memory_search", "memory_write", "system_status"],
|
||||
tools=[
|
||||
"web_search",
|
||||
"read_file",
|
||||
"write_file",
|
||||
"python",
|
||||
"memory_search",
|
||||
"memory_write",
|
||||
"system_status",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# Sub-agent registry
|
||||
self.sub_agents: dict[str, BaseAgent] = {}
|
||||
|
||||
|
||||
# Session tracking for init behavior
|
||||
self._session_initialized = False
|
||||
self._session_context: dict[str, Any] = {}
|
||||
self._context_fully_loaded = False
|
||||
|
||||
|
||||
# Connect to event bus
|
||||
self.connect_event_bus(event_bus)
|
||||
|
||||
|
||||
logger.info("Orchestrator initialized with context-aware prompt")
|
||||
|
||||
|
||||
def register_sub_agent(self, agent: BaseAgent) -> None:
|
||||
"""Register a sub-agent with the orchestrator."""
|
||||
self.sub_agents[agent.agent_id] = agent
|
||||
agent.connect_event_bus(event_bus)
|
||||
logger.info("Registered sub-agent: %s", agent.name)
|
||||
|
||||
|
||||
async def _session_init(self) -> None:
|
||||
"""Initialize session context on first user message.
|
||||
|
||||
|
||||
Silently reads git log and AGENTS.md to ground the orchestrator in real data.
|
||||
This runs once per session before the first response.
|
||||
"""
|
||||
if self._session_initialized:
|
||||
return
|
||||
|
||||
|
||||
logger.debug("Running session init...")
|
||||
|
||||
|
||||
# Load full context including hands if not already done
|
||||
if not self._context_fully_loaded:
|
||||
await build_timmy_context_async()
|
||||
self._context_fully_loaded = True
|
||||
|
||||
|
||||
# Read recent git log --oneline -15 from repo root
|
||||
try:
|
||||
from tools.git_tools import git_log
|
||||
|
||||
git_result = git_log(max_count=15)
|
||||
if git_result.get("success"):
|
||||
commits = git_result.get("commits", [])
|
||||
self._session_context["git_log_commits"] = commits
|
||||
# Format as oneline for easy reading
|
||||
self._session_context["git_log_oneline"] = "\n".join([
|
||||
f"{c['short_sha']} {c['message'].split(chr(10))[0]}"
|
||||
for c in commits
|
||||
])
|
||||
self._session_context["git_log_oneline"] = "\n".join(
|
||||
[f"{c['short_sha']} {c['message'].split(chr(10))[0]}" for c in commits]
|
||||
)
|
||||
logger.debug(f"Session init: loaded {len(commits)} commits from git log")
|
||||
else:
|
||||
self._session_context["git_log_oneline"] = "Git log unavailable"
|
||||
except Exception as exc:
|
||||
logger.warning("Session init: could not read git log: %s", exc)
|
||||
self._session_context["git_log_oneline"] = "Git log unavailable"
|
||||
|
||||
|
||||
# Read AGENTS.md for self-awareness
|
||||
try:
|
||||
agents_md_path = Path(settings.repo_root) / "AGENTS.md"
|
||||
@@ -291,7 +316,7 @@ class TimmyOrchestrator(BaseAgent):
|
||||
self._session_context["agents_md"] = agents_md_path.read_text()[:3000]
|
||||
except Exception as exc:
|
||||
logger.warning("Session init: could not read AGENTS.md: %s", exc)
|
||||
|
||||
|
||||
# Read CHANGELOG for recent changes
|
||||
try:
|
||||
changelog_path = Path(settings.repo_root) / "docs" / "CHANGELOG_2026-02-26.md"
|
||||
@@ -299,11 +324,13 @@ class TimmyOrchestrator(BaseAgent):
|
||||
self._session_context["changelog"] = changelog_path.read_text()[:2000]
|
||||
except Exception:
|
||||
pass # Changelog is optional
|
||||
|
||||
|
||||
# Build session-specific context block for the prompt
|
||||
recent_changes = self._session_context.get("git_log_oneline", "")
|
||||
if recent_changes and recent_changes != "Git log unavailable":
|
||||
self._session_context["recent_changes_block"] = f"""
|
||||
self._session_context[
|
||||
"recent_changes_block"
|
||||
] = f"""
|
||||
## Recent Changes to Your Codebase (last 15 commits):
|
||||
```
|
||||
{recent_changes}
|
||||
@@ -312,17 +339,17 @@ When asked "what's new?" or similar, refer to these commits for actual changes.
|
||||
"""
|
||||
else:
|
||||
self._session_context["recent_changes_block"] = ""
|
||||
|
||||
|
||||
self._session_initialized = True
|
||||
logger.debug("Session init complete")
|
||||
|
||||
|
||||
def _get_enhanced_system_prompt(self) -> str:
|
||||
"""Get system prompt enhanced with session-specific context.
|
||||
|
||||
|
||||
Prepends the recent git log to the system prompt for grounding.
|
||||
"""
|
||||
base = self.system_prompt
|
||||
|
||||
|
||||
# Add recent changes block if available
|
||||
recent_changes = self._session_context.get("recent_changes_block", "")
|
||||
if recent_changes:
|
||||
@@ -330,36 +357,45 @@ When asked "what's new?" or similar, refer to these commits for actual changes.
|
||||
lines = base.split("\n")
|
||||
if lines:
|
||||
return lines[0] + "\n" + recent_changes + "\n" + "\n".join(lines[1:])
|
||||
|
||||
|
||||
return base
|
||||
|
||||
|
||||
async def orchestrate(self, user_request: str) -> str:
|
||||
"""Main entry point for user requests.
|
||||
|
||||
|
||||
Analyzes the request and either handles directly or delegates.
|
||||
"""
|
||||
# Run session init on first message (loads git log, etc.)
|
||||
await self._session_init()
|
||||
|
||||
|
||||
# Quick classification
|
||||
request_lower = user_request.lower()
|
||||
|
||||
|
||||
# Direct response patterns (no delegation needed)
|
||||
direct_patterns = [
|
||||
"your name", "who are you", "what are you",
|
||||
"hello", "hi", "how are you",
|
||||
"help", "what can you do",
|
||||
"your name",
|
||||
"who are you",
|
||||
"what are you",
|
||||
"hello",
|
||||
"hi",
|
||||
"how are you",
|
||||
"help",
|
||||
"what can you do",
|
||||
]
|
||||
|
||||
|
||||
for pattern in direct_patterns:
|
||||
if pattern in request_lower:
|
||||
return await self.run(user_request)
|
||||
|
||||
|
||||
# Check for memory references — delegate to Echo
|
||||
memory_patterns = [
|
||||
"we talked about", "we discussed", "remember",
|
||||
"what did i say", "what did we decide",
|
||||
"remind me", "have we",
|
||||
"we talked about",
|
||||
"we discussed",
|
||||
"remember",
|
||||
"what did i say",
|
||||
"what did we decide",
|
||||
"remind me",
|
||||
"have we",
|
||||
]
|
||||
|
||||
for pattern in memory_patterns:
|
||||
@@ -395,19 +431,16 @@ When asked "what's new?" or similar, refer to these commits for actual changes.
|
||||
if agent in text_lower:
|
||||
return agent
|
||||
return "orchestrator"
|
||||
|
||||
|
||||
async def execute_task(self, task_id: str, description: str, context: dict) -> Any:
|
||||
"""Execute a task (usually delegates to appropriate agent)."""
|
||||
return await self.orchestrate(description)
|
||||
|
||||
|
||||
def get_swarm_status(self) -> dict:
|
||||
"""Get status of all agents in the swarm."""
|
||||
return {
|
||||
"orchestrator": self.get_status(),
|
||||
"sub_agents": {
|
||||
aid: agent.get_status()
|
||||
for aid, agent in self.sub_agents.items()
|
||||
},
|
||||
"sub_agents": {aid: agent.get_status() for aid, agent in self.sub_agents.items()},
|
||||
"total_agents": 1 + len(self.sub_agents),
|
||||
}
|
||||
|
||||
@@ -468,10 +501,29 @@ _PERSONAS: list[dict[str, Any]] = [
|
||||
"system_prompt": (
|
||||
"You are Helm, a routing and orchestration specialist.\n"
|
||||
"Analyze tasks and decide how to route them to other agents.\n"
|
||||
"Available agents: Seer (research), Forge (code), Quill (writing), Echo (memory).\n"
|
||||
"Available agents: Seer (research), Forge (code), Quill (writing), Echo (memory), Lab (experiments).\n"
|
||||
"Respond with: Primary Agent: [agent name]"
|
||||
),
|
||||
},
|
||||
{
|
||||
"agent_id": "lab",
|
||||
"name": "Lab",
|
||||
"role": "experiment",
|
||||
"tools": [
|
||||
"run_experiment",
|
||||
"prepare_experiment",
|
||||
"shell",
|
||||
"python",
|
||||
"read_file",
|
||||
"write_file",
|
||||
],
|
||||
"system_prompt": (
|
||||
"You are Lab, an autonomous ML experimentation specialist.\n"
|
||||
"You run time-boxed training experiments, evaluate metrics,\n"
|
||||
"modify training code to improve results, and iterate.\n"
|
||||
"Always report the metric delta. Never exceed the time budget."
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -38,10 +38,10 @@ class ApprovalItem:
|
||||
id: str
|
||||
title: str
|
||||
description: str
|
||||
proposed_action: str # what Timmy wants to do
|
||||
impact: str # "low" | "medium" | "high"
|
||||
proposed_action: str # what Timmy wants to do
|
||||
impact: str # "low" | "medium" | "high"
|
||||
created_at: datetime
|
||||
status: str # "pending" | "approved" | "rejected"
|
||||
status: str # "pending" | "approved" | "rejected"
|
||||
|
||||
|
||||
def _get_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection:
|
||||
@@ -81,6 +81,7 @@ def _row_to_item(row: sqlite3.Row) -> ApprovalItem:
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_item(
|
||||
title: str,
|
||||
description: str,
|
||||
@@ -133,18 +134,14 @@ def list_pending(db_path: Path = _DEFAULT_DB) -> list[ApprovalItem]:
|
||||
def list_all(db_path: Path = _DEFAULT_DB) -> list[ApprovalItem]:
|
||||
"""Return all approval items regardless of status, newest first."""
|
||||
conn = _get_conn(db_path)
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM approval_items ORDER BY created_at DESC"
|
||||
).fetchall()
|
||||
rows = conn.execute("SELECT * FROM approval_items ORDER BY created_at DESC").fetchall()
|
||||
conn.close()
|
||||
return [_row_to_item(r) for r in rows]
|
||||
|
||||
|
||||
def get_item(item_id: str, db_path: Path = _DEFAULT_DB) -> Optional[ApprovalItem]:
|
||||
conn = _get_conn(db_path)
|
||||
row = conn.execute(
|
||||
"SELECT * FROM approval_items WHERE id = ?", (item_id,)
|
||||
).fetchone()
|
||||
row = conn.execute("SELECT * FROM approval_items WHERE id = ?", (item_id,)).fetchone()
|
||||
conn.close()
|
||||
return _row_to_item(row) if row else None
|
||||
|
||||
@@ -152,9 +149,7 @@ def get_item(item_id: str, db_path: Path = _DEFAULT_DB) -> Optional[ApprovalItem
|
||||
def approve(item_id: str, db_path: Path = _DEFAULT_DB) -> Optional[ApprovalItem]:
|
||||
"""Mark an approval item as approved."""
|
||||
conn = _get_conn(db_path)
|
||||
conn.execute(
|
||||
"UPDATE approval_items SET status = 'approved' WHERE id = ?", (item_id,)
|
||||
)
|
||||
conn.execute("UPDATE approval_items SET status = 'approved' WHERE id = ?", (item_id,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return get_item(item_id, db_path)
|
||||
@@ -163,9 +158,7 @@ def approve(item_id: str, db_path: Path = _DEFAULT_DB) -> Optional[ApprovalItem]
|
||||
def reject(item_id: str, db_path: Path = _DEFAULT_DB) -> Optional[ApprovalItem]:
|
||||
"""Mark an approval item as rejected."""
|
||||
conn = _get_conn(db_path)
|
||||
conn.execute(
|
||||
"UPDATE approval_items SET status = 'rejected' WHERE id = ?", (item_id,)
|
||||
)
|
||||
conn.execute("UPDATE approval_items SET status = 'rejected' WHERE id = ?", (item_id,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return get_item(item_id, db_path)
|
||||
|
||||
214
src/timmy/autoresearch.py
Normal file
214
src/timmy/autoresearch.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""Autoresearch — autonomous ML experiment loops.
|
||||
|
||||
Integrates Karpathy's autoresearch pattern: an agent modifies training
|
||||
code, runs time-boxed GPU experiments, evaluates a target metric
|
||||
(val_bpb by default), and iterates to find improvements.
|
||||
|
||||
Flow:
|
||||
1. prepare_experiment — clone repo + run data prep
|
||||
2. run_experiment — execute train.py with wall-clock timeout
|
||||
3. evaluate_result — compare metric against baseline
|
||||
4. experiment_loop — orchestrate the full cycle
|
||||
|
||||
All subprocess calls are guarded with timeouts for graceful degradation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_REPO = "https://github.com/karpathy/autoresearch.git"
|
||||
_METRIC_RE = re.compile(r"val_bpb[:\s]+([0-9]+\.?[0-9]*)")
|
||||
|
||||
|
||||
def prepare_experiment(
    workspace: Path,
    repo_url: str = DEFAULT_REPO,
) -> str:
    """Clone the autoresearch repo (if absent) and run its data preparation.

    Subprocess failures are reported through the returned status string
    rather than raised: a non-zero exit, a timeout, or an executable that
    cannot be launched all produce a "... failed: ..." message, matching the
    way ``run_experiment`` reports subprocess errors.

    Args:
        workspace: Directory to set up the experiment in (created if missing).
        repo_url: Git URL for the autoresearch repository.

    Returns:
        Status message describing what was prepared, skipped, or what failed.
    """
    # Local import: only needed to clean up after an aborted clone.
    import shutil

    clone_timeout = 120
    prepare_timeout = 300

    workspace = Path(workspace)
    workspace.mkdir(parents=True, exist_ok=True)

    repo_dir = workspace / "autoresearch"
    if not repo_dir.exists():
        logger.info("Cloning autoresearch into %s", repo_dir)
        try:
            result = subprocess.run(
                ["git", "clone", "--depth", "1", repo_url, str(repo_dir)],
                capture_output=True,
                text=True,
                timeout=clone_timeout,
            )
        except subprocess.TimeoutExpired:
            # The killed clone may leave a partial checkout behind. Remove it
            # so the next call retries the clone instead of mistaking the
            # directory for a finished repo (it did not exist before this call).
            shutil.rmtree(repo_dir, ignore_errors=True)
            return f"Clone failed: timed out after {clone_timeout}s"
        except OSError as exc:
            # e.g. the git executable is not installed / not on PATH.
            return f"Clone failed: {exc}"
        if result.returncode != 0:
            return f"Clone failed: {result.stderr.strip()}"
    else:
        logger.info("Autoresearch repo already present at %s", repo_dir)

    # Run prepare.py (data download + tokeniser training)
    prepare_script = repo_dir / "prepare.py"
    if not prepare_script.exists():
        return "Preparation skipped — no prepare.py found."

    logger.info("Running prepare.py …")
    try:
        result = subprocess.run(
            ["python", str(prepare_script)],
            capture_output=True,
            text=True,
            cwd=str(repo_dir),
            timeout=prepare_timeout,
        )
    except subprocess.TimeoutExpired:
        return f"Preparation failed: timed out after {prepare_timeout}s"
    except OSError as exc:
        # e.g. no "python" executable on PATH.
        return f"Preparation failed: {exc}"
    if result.returncode != 0:
        # Cap stderr so a long traceback does not flood the status message.
        return f"Preparation failed: {result.stderr.strip()[:500]}"
    return "Preparation complete — data downloaded and tokeniser trained."
|
||||
|
||||
|
||||
def run_experiment(
    workspace: Path,
    timeout: int = 300,
    metric_name: str = "val_bpb",
) -> dict[str, Any]:
    """Execute ``autoresearch/train.py`` once under a wall-clock limit.

    Args:
        workspace: Experiment workspace (contains autoresearch/ subdir).
        timeout: Maximum wall-clock seconds for the run.
        metric_name: Name of the metric to extract from the combined
            stdout + stderr of the training process.

    Returns:
        Dict with keys: metric (float|None), log (str), duration_s (int),
        success (bool), error (str|None). A missing script, a timeout, or a
        launch error all yield ``success=False`` with ``metric=None``.
    """

    def _failure(error: str, log: str = "", duration_s: int = 0) -> dict[str, Any]:
        # Every unsuccessful outcome shares this shape; only the text differs.
        return {
            "metric": None,
            "log": log,
            "duration_s": duration_s,
            "success": False,
            "error": error,
        }

    repo_dir = Path(workspace) / "autoresearch"
    train_script = repo_dir / "train.py"
    if not train_script.exists():
        return _failure(f"train.py not found in {repo_dir}")

    started = time.monotonic()
    try:
        completed = subprocess.run(
            ["python", str(train_script)],
            capture_output=True,
            text=True,
            cwd=str(repo_dir),
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        elapsed = int(time.monotonic() - started)
        return _failure(
            f"Timed out after {timeout}s",
            log=f"Experiment timed out after {timeout}s",
            duration_s=elapsed,
        )
    except OSError as exc:
        return _failure(str(exc))

    elapsed = int(time.monotonic() - started)
    combined = completed.stdout + completed.stderr
    exited_cleanly = completed.returncode == 0
    return {
        "metric": _extract_metric(combined, metric_name),
        "log": combined[-2000:],  # Keep last 2k chars
        "duration_s": elapsed,
        "success": exited_cleanly,
        "error": None if exited_cleanly else f"Exit code {completed.returncode}",
    }
|
||||
|
||||
|
||||
def _extract_metric(output: str, metric_name: str = "val_bpb") -> Optional[float]:
    """Return the last value reported for *metric_name* in training output.

    A report is the metric name, one or more separators (``:``, ``=`` or
    whitespace), then a decimal number. The number may carry a sign and a
    scientific-notation exponent, so ``val_bpb: 1.5e-3`` is read as 0.0015
    instead of being truncated to 1.5.

    Args:
        output: Captured text of a training run.
        metric_name: Literal metric name to look for (regex-escaped here).

    Returns:
        The value of the final occurrence, or None when the metric never
        appears in *output*.
    """
    number = r"[-+]?(?:[0-9]+\.?[0-9]*|\.[0-9]+)(?:[eE][-+]?[0-9]+)?"
    pattern = re.compile(rf"{re.escape(metric_name)}[:=\s]+({number})")
    matches = pattern.findall(output)
    if not matches:
        return None
    # The pattern only admits text that float() accepts, so no ValueError
    # handling is needed here.
    return float(matches[-1])
|
||||
|
||||
|
||||
def evaluate_result(
    current: float,
    baseline: float,
    metric_name: str = "val_bpb",
) -> str:
    """Describe how *current* compares with *baseline*.

    A lower value counts as an improvement (as for val_bpb), a higher value
    as a regression. The relative change is reported as a signed percentage
    of the baseline, or as 0% when the baseline is zero.

    Args:
        current: Current experiment's metric value.
        baseline: Baseline metric to compare against.
        metric_name: Name of the metric (for display).

    Returns:
        Human-readable assessment string.
    """
    change = current - baseline
    if baseline != 0:
        percent = (change / baseline) * 100
    else:
        percent = 0.0

    # Anything that is neither strictly lower nor strictly higher is "no change".
    if not (change < 0 or change > 0):
        return f"No change: {metric_name} = {current:.4f}"

    verdict = "Improvement" if change < 0 else "Regression"
    return f"{verdict}: {metric_name} {baseline:.4f} -> {current:.4f} ({percent:+.2f}%)"
|
||||
|
||||
|
||||
def get_experiment_history(workspace: Path) -> list[dict[str, Any]]:
    """Load past experiment results from ``results.jsonl`` in the workspace.

    Lines that are not valid JSON are skipped. A missing results file yields
    an empty list.

    Returns:
        List of experiment result dicts, most recent first.
    """
    log_path = Path(workspace) / "results.jsonl"
    if not log_path.exists():
        return []

    entries: list[dict[str, Any]] = []
    for raw_line in log_path.read_text().strip().splitlines():
        try:
            record = json.loads(raw_line)
        except json.JSONDecodeError:
            continue
        entries.append(record)

    # The file is appended chronologically; callers want newest first.
    entries.reverse()
    return entries
|
||||
|
||||
|
||||
def _append_result(workspace: Path, result: dict[str, Any]) -> None:
    """Add *result* as one JSON line to ``results.jsonl`` in the workspace.

    The workspace directory is created if it does not exist yet.
    """
    log_path = Path(workspace) / "results.jsonl"
    log_path.parent.mkdir(parents=True, exist_ok=True)
    with open(log_path, "a") as handle:
        # print() supplies the trailing newline that terminates the record.
        print(json.dumps(result), file=handle)
|
||||
@@ -24,8 +24,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# HuggingFace model IDs for each supported size.
|
||||
_AIRLLM_MODELS: dict[str, str] = {
|
||||
"8b": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"70b": "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
||||
"8b": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"70b": "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
||||
"405b": "meta-llama/Meta-Llama-3.1-405B-Instruct",
|
||||
}
|
||||
|
||||
@@ -35,6 +35,7 @@ ModelSize = Literal["8b", "70b", "405b"]
|
||||
@dataclass
|
||||
class RunResult:
|
||||
"""Minimal Agno-compatible run result — carries the model's response text."""
|
||||
|
||||
content: str
|
||||
|
||||
|
||||
@@ -47,6 +48,7 @@ def airllm_available() -> bool:
|
||||
"""Return True when the airllm package is importable."""
|
||||
try:
|
||||
import airllm # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
@@ -67,15 +69,16 @@ class TimmyAirLLMAgent:
|
||||
model_id = _AIRLLM_MODELS.get(model_size)
|
||||
if model_id is None:
|
||||
raise ValueError(
|
||||
f"Unknown model size {model_size!r}. "
|
||||
f"Choose from: {list(_AIRLLM_MODELS)}"
|
||||
f"Unknown model size {model_size!r}. " f"Choose from: {list(_AIRLLM_MODELS)}"
|
||||
)
|
||||
|
||||
if is_apple_silicon():
|
||||
from airllm import AirLLMMLX # type: ignore[import]
|
||||
|
||||
self._model = AirLLMMLX(model_id)
|
||||
else:
|
||||
from airllm import AutoModel # type: ignore[import]
|
||||
|
||||
self._model = AutoModel.from_pretrained(model_id)
|
||||
|
||||
self._history: list[str] = []
|
||||
@@ -137,6 +140,7 @@ class TimmyAirLLMAgent:
|
||||
try:
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
|
||||
Console().print(Markdown(text))
|
||||
except ImportError:
|
||||
print(text)
|
||||
@@ -157,6 +161,7 @@ GROK_MODELS: dict[str, str] = {
|
||||
@dataclass
|
||||
class GrokUsageStats:
|
||||
"""Tracks Grok API usage for cost monitoring and Spark logging."""
|
||||
|
||||
total_requests: int = 0
|
||||
total_prompt_tokens: int = 0
|
||||
total_completion_tokens: int = 0
|
||||
@@ -240,9 +245,7 @@ class GrokBackend:
|
||||
RunResult with response content
|
||||
"""
|
||||
if not self._api_key:
|
||||
return RunResult(
|
||||
content="Grok is not configured. Set XAI_API_KEY to enable."
|
||||
)
|
||||
return RunResult(content="Grok is not configured. Set XAI_API_KEY to enable.")
|
||||
|
||||
start = time.time()
|
||||
messages = self._build_messages(message)
|
||||
@@ -285,16 +288,12 @@ class GrokBackend:
|
||||
except Exception as exc:
|
||||
self.stats.errors += 1
|
||||
logger.error("Grok API error: %s", exc)
|
||||
return RunResult(
|
||||
content=f"Grok temporarily unavailable: {exc}"
|
||||
)
|
||||
return RunResult(content=f"Grok temporarily unavailable: {exc}")
|
||||
|
||||
async def arun(self, message: str) -> RunResult:
|
||||
"""Async inference via Grok API — used by cascade router and tools."""
|
||||
if not self._api_key:
|
||||
return RunResult(
|
||||
content="Grok is not configured. Set XAI_API_KEY to enable."
|
||||
)
|
||||
return RunResult(content="Grok is not configured. Set XAI_API_KEY to enable.")
|
||||
|
||||
start = time.time()
|
||||
messages = self._build_messages(message)
|
||||
@@ -336,9 +335,7 @@ class GrokBackend:
|
||||
except Exception as exc:
|
||||
self.stats.errors += 1
|
||||
logger.error("Grok async API error: %s", exc)
|
||||
return RunResult(
|
||||
content=f"Grok temporarily unavailable: {exc}"
|
||||
)
|
||||
return RunResult(content=f"Grok temporarily unavailable: {exc}")
|
||||
|
||||
def print_response(self, message: str, *, stream: bool = True) -> None:
|
||||
"""Run inference and render the response to stdout (CLI interface)."""
|
||||
@@ -346,6 +343,7 @@ class GrokBackend:
|
||||
try:
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
|
||||
Console().print(Markdown(result.content))
|
||||
except ImportError:
|
||||
print(result.content)
|
||||
@@ -415,6 +413,7 @@ def grok_available() -> bool:
|
||||
"""Return True when Grok is enabled and API key is configured."""
|
||||
try:
|
||||
from config import settings
|
||||
|
||||
return settings.grok_enabled and bool(settings.xai_api_key)
|
||||
except Exception:
|
||||
return False
|
||||
@@ -472,9 +471,7 @@ class ClaudeBackend:
|
||||
def run(self, message: str, *, stream: bool = False, **kwargs) -> RunResult:
|
||||
"""Synchronous inference via Claude API."""
|
||||
if not self._api_key:
|
||||
return RunResult(
|
||||
content="Claude is not configured. Set ANTHROPIC_API_KEY to enable."
|
||||
)
|
||||
return RunResult(content="Claude is not configured. Set ANTHROPIC_API_KEY to enable.")
|
||||
|
||||
start = time.time()
|
||||
messages = self._build_messages(message)
|
||||
@@ -508,9 +505,7 @@ class ClaudeBackend:
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Claude API error: %s", exc)
|
||||
return RunResult(
|
||||
content=f"Claude temporarily unavailable: {exc}"
|
||||
)
|
||||
return RunResult(content=f"Claude temporarily unavailable: {exc}")
|
||||
|
||||
def print_response(self, message: str, *, stream: bool = True) -> None:
|
||||
"""Run inference and render the response to stdout (CLI interface)."""
|
||||
@@ -518,6 +513,7 @@ class ClaudeBackend:
|
||||
try:
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
|
||||
Console().print(Markdown(result.content))
|
||||
except ImportError:
|
||||
print(result.content)
|
||||
@@ -569,6 +565,7 @@ def claude_available() -> bool:
|
||||
"""Return True when Anthropic API key is configured."""
|
||||
try:
|
||||
from config import settings
|
||||
|
||||
return bool(settings.anthropic_api_key)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@@ -25,6 +25,7 @@ _CACHE_MINUTES = 30
|
||||
# Data structures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class ApprovalItem:
|
||||
"""Lightweight representation used inside a Briefing.
|
||||
@@ -32,6 +33,7 @@ class ApprovalItem:
|
||||
The canonical mutable version (with persistence) lives in timmy.approvals.
|
||||
This one travels with the Briefing dataclass as a read-only snapshot.
|
||||
"""
|
||||
|
||||
id: str
|
||||
title: str
|
||||
description: str
|
||||
@@ -44,20 +46,19 @@ class ApprovalItem:
|
||||
@dataclass
|
||||
class Briefing:
|
||||
generated_at: datetime
|
||||
summary: str # 150-300 words
|
||||
summary: str # 150-300 words
|
||||
approval_items: list[ApprovalItem] = field(default_factory=list)
|
||||
period_start: datetime = field(
|
||||
default_factory=lambda: datetime.now(timezone.utc) - timedelta(hours=6)
|
||||
)
|
||||
period_end: datetime = field(
|
||||
default_factory=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
period_end: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SQLite cache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_cache_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection:
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
@@ -98,9 +99,7 @@ def _save_briefing(briefing: Briefing, db_path: Path = _DEFAULT_DB) -> None:
|
||||
def _load_latest(db_path: Path = _DEFAULT_DB) -> Optional[Briefing]:
|
||||
"""Load the most-recently cached briefing, or None if there is none."""
|
||||
conn = _get_cache_conn(db_path)
|
||||
row = conn.execute(
|
||||
"SELECT * FROM briefings ORDER BY generated_at DESC LIMIT 1"
|
||||
).fetchone()
|
||||
row = conn.execute("SELECT * FROM briefings ORDER BY generated_at DESC LIMIT 1").fetchone()
|
||||
conn.close()
|
||||
if row is None:
|
||||
return None
|
||||
@@ -115,7 +114,11 @@ def _load_latest(db_path: Path = _DEFAULT_DB) -> Optional[Briefing]:
|
||||
def is_fresh(briefing: Briefing, max_age_minutes: int = _CACHE_MINUTES) -> bool:
|
||||
"""Return True if the briefing was generated within max_age_minutes."""
|
||||
now = datetime.now(timezone.utc)
|
||||
age = now - briefing.generated_at.replace(tzinfo=timezone.utc) if briefing.generated_at.tzinfo is None else now - briefing.generated_at
|
||||
age = (
|
||||
now - briefing.generated_at.replace(tzinfo=timezone.utc)
|
||||
if briefing.generated_at.tzinfo is None
|
||||
else now - briefing.generated_at
|
||||
)
|
||||
return age.total_seconds() < max_age_minutes * 60
|
||||
|
||||
|
||||
@@ -123,6 +126,7 @@ def is_fresh(briefing: Briefing, max_age_minutes: int = _CACHE_MINUTES) -> bool:
|
||||
# Activity gathering helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _gather_swarm_summary(since: datetime) -> str:
|
||||
"""Pull recent task/agent stats from swarm.db. Graceful if DB missing."""
|
||||
swarm_db = Path("data/swarm.db")
|
||||
@@ -170,6 +174,7 @@ def _gather_task_queue_summary() -> str:
|
||||
"""Pull task queue stats for the briefing. Graceful if unavailable."""
|
||||
try:
|
||||
from swarm.task_queue.models import get_task_summary_for_briefing
|
||||
|
||||
stats = get_task_summary_for_briefing()
|
||||
parts = []
|
||||
if stats["pending_approval"]:
|
||||
@@ -194,6 +199,7 @@ def _gather_chat_summary(since: datetime) -> str:
|
||||
"""Pull recent chat messages from the in-memory log."""
|
||||
try:
|
||||
from dashboard.store import message_log
|
||||
|
||||
messages = message_log.all()
|
||||
# Filter to messages in the briefing window (best-effort: no timestamps)
|
||||
recent = messages[-10:] if len(messages) > 10 else messages
|
||||
@@ -213,6 +219,7 @@ def _gather_chat_summary(since: datetime) -> str:
|
||||
# BriefingEngine
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BriefingEngine:
|
||||
"""Generates morning briefings by querying activity and asking Timmy."""
|
||||
|
||||
@@ -297,6 +304,7 @@ class BriefingEngine:
|
||||
"""Call Timmy's Agno agent and return the response text."""
|
||||
try:
|
||||
from timmy.agent import create_timmy
|
||||
|
||||
agent = create_timmy()
|
||||
run = agent.run(prompt, stream=False)
|
||||
result = run.content if hasattr(run, "content") else str(run)
|
||||
@@ -317,6 +325,7 @@ class BriefingEngine:
|
||||
"""Return pending ApprovalItems from the approvals DB."""
|
||||
try:
|
||||
from timmy import approvals as _approvals
|
||||
|
||||
raw_items = _approvals.list_pending()
|
||||
return [
|
||||
ApprovalItem(
|
||||
|
||||
@@ -19,6 +19,7 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass
|
||||
class TimmyResponse:
|
||||
"""Response from Timmy via Cascade Router."""
|
||||
|
||||
content: str
|
||||
provider_used: str
|
||||
latency_ms: float
|
||||
@@ -27,31 +28,30 @@ class TimmyResponse:
|
||||
|
||||
class TimmyCascadeAdapter:
|
||||
"""Adapter that routes Timmy requests through Cascade Router.
|
||||
|
||||
|
||||
Usage:
|
||||
adapter = TimmyCascadeAdapter()
|
||||
response = await adapter.chat("Hello")
|
||||
print(f"Response: {response.content}")
|
||||
print(f"Provider: {response.provider_used}")
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, router: Optional[CascadeRouter] = None) -> None:
|
||||
"""Initialize adapter with Cascade Router.
|
||||
|
||||
|
||||
Args:
|
||||
router: CascadeRouter instance. If None, creates default.
|
||||
"""
|
||||
self.router = router or CascadeRouter()
|
||||
logger.info("TimmyCascadeAdapter initialized with %d providers",
|
||||
len(self.router.providers))
|
||||
|
||||
logger.info("TimmyCascadeAdapter initialized with %d providers", len(self.router.providers))
|
||||
|
||||
async def chat(self, message: str, context: Optional[str] = None) -> TimmyResponse:
|
||||
"""Send message through cascade router with automatic failover.
|
||||
|
||||
|
||||
Args:
|
||||
message: User message
|
||||
context: Optional conversation context
|
||||
|
||||
|
||||
Returns:
|
||||
TimmyResponse with content and metadata
|
||||
"""
|
||||
@@ -60,37 +60,38 @@ class TimmyCascadeAdapter:
|
||||
if context:
|
||||
messages.append({"role": "system", "content": context})
|
||||
messages.append({"role": "user", "content": message})
|
||||
|
||||
|
||||
# Route through cascade
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
|
||||
|
||||
try:
|
||||
result = await self.router.complete(
|
||||
messages=messages,
|
||||
system_prompt=SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
|
||||
latency = (time.time() - start) * 1000
|
||||
|
||||
|
||||
# Determine if fallback was used
|
||||
primary = self.router.providers[0] if self.router.providers else None
|
||||
fallback_used = primary and primary.status.value != "healthy"
|
||||
|
||||
|
||||
return TimmyResponse(
|
||||
content=result.content,
|
||||
provider_used=result.provider_name,
|
||||
latency_ms=latency,
|
||||
fallback_used=fallback_used,
|
||||
)
|
||||
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("All providers failed: %s", exc)
|
||||
raise
|
||||
|
||||
|
||||
def get_provider_status(self) -> list[dict]:
|
||||
"""Get status of all providers.
|
||||
|
||||
|
||||
Returns:
|
||||
List of provider status dicts
|
||||
"""
|
||||
@@ -112,10 +113,10 @@ class TimmyCascadeAdapter:
|
||||
}
|
||||
for p in self.router.providers
|
||||
]
|
||||
|
||||
|
||||
def get_preferred_provider(self) -> Optional[str]:
|
||||
"""Get name of highest-priority healthy provider.
|
||||
|
||||
|
||||
Returns:
|
||||
Provider name or None if all unhealthy
|
||||
"""
|
||||
|
||||
@@ -17,22 +17,23 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass
|
||||
class ConversationContext:
|
||||
"""Tracks the current conversation state."""
|
||||
|
||||
user_name: Optional[str] = None
|
||||
current_topic: Optional[str] = None
|
||||
last_intent: Optional[str] = None
|
||||
turn_count: int = 0
|
||||
started_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
def update_topic(self, topic: str) -> None:
    """Record *topic* as the current topic and count one more turn."""
    self.turn_count = self.turn_count + 1
    self.current_topic = topic
|
||||
|
||||
|
||||
def set_user_name(self, name: str) -> None:
|
||||
"""Remember the user's name."""
|
||||
self.user_name = name
|
||||
logger.info("User name set to: %s", name)
|
||||
|
||||
|
||||
def get_context_summary(self) -> str:
|
||||
"""Generate a context summary for the prompt."""
|
||||
parts = []
|
||||
@@ -47,35 +48,88 @@ class ConversationContext:
|
||||
|
||||
class ConversationManager:
|
||||
"""Manages conversation context across sessions."""
|
||||
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._contexts: dict[str, ConversationContext] = {}
|
||||
|
||||
|
||||
def get_context(self, session_id: str) -> ConversationContext:
|
||||
"""Get or create context for a session."""
|
||||
if session_id not in self._contexts:
|
||||
self._contexts[session_id] = ConversationContext()
|
||||
return self._contexts[session_id]
|
||||
|
||||
|
||||
def clear_context(self, session_id: str) -> None:
|
||||
"""Clear context for a session."""
|
||||
if session_id in self._contexts:
|
||||
del self._contexts[session_id]
|
||||
|
||||
|
||||
# Words that look like names but are actually verbs/UI states
|
||||
_NAME_BLOCKLIST = frozenset({
|
||||
"sending", "loading", "pending", "processing", "typing",
|
||||
"working", "going", "trying", "looking", "getting", "doing",
|
||||
"waiting", "running", "checking", "coming", "leaving",
|
||||
"thinking", "reading", "writing", "watching", "listening",
|
||||
"playing", "eating", "sleeping", "sitting", "standing",
|
||||
"walking", "talking", "asking", "telling", "feeling",
|
||||
"hoping", "wondering", "glad", "happy", "sorry", "sure",
|
||||
"fine", "good", "great", "okay", "here", "there", "back",
|
||||
"done", "ready", "busy", "free", "available", "interested",
|
||||
"confused", "lost", "stuck", "curious", "excited", "tired",
|
||||
"not", "also", "just", "still", "already", "currently",
|
||||
})
|
||||
_NAME_BLOCKLIST = frozenset(
|
||||
{
|
||||
"sending",
|
||||
"loading",
|
||||
"pending",
|
||||
"processing",
|
||||
"typing",
|
||||
"working",
|
||||
"going",
|
||||
"trying",
|
||||
"looking",
|
||||
"getting",
|
||||
"doing",
|
||||
"waiting",
|
||||
"running",
|
||||
"checking",
|
||||
"coming",
|
||||
"leaving",
|
||||
"thinking",
|
||||
"reading",
|
||||
"writing",
|
||||
"watching",
|
||||
"listening",
|
||||
"playing",
|
||||
"eating",
|
||||
"sleeping",
|
||||
"sitting",
|
||||
"standing",
|
||||
"walking",
|
||||
"talking",
|
||||
"asking",
|
||||
"telling",
|
||||
"feeling",
|
||||
"hoping",
|
||||
"wondering",
|
||||
"glad",
|
||||
"happy",
|
||||
"sorry",
|
||||
"sure",
|
||||
"fine",
|
||||
"good",
|
||||
"great",
|
||||
"okay",
|
||||
"here",
|
||||
"there",
|
||||
"back",
|
||||
"done",
|
||||
"ready",
|
||||
"busy",
|
||||
"free",
|
||||
"available",
|
||||
"interested",
|
||||
"confused",
|
||||
"lost",
|
||||
"stuck",
|
||||
"curious",
|
||||
"excited",
|
||||
"tired",
|
||||
"not",
|
||||
"also",
|
||||
"just",
|
||||
"still",
|
||||
"already",
|
||||
"currently",
|
||||
}
|
||||
)
|
||||
|
||||
def extract_user_name(self, message: str) -> Optional[str]:
|
||||
"""Try to extract user's name from message."""
|
||||
@@ -106,40 +160,66 @@ class ConversationManager:
|
||||
return name.capitalize()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def should_use_tools(self, message: str, context: ConversationContext) -> bool:
|
||||
"""Determine if this message likely requires tools.
|
||||
|
||||
|
||||
Returns True if tools are likely needed, False for simple chat.
|
||||
"""
|
||||
message_lower = message.lower().strip()
|
||||
|
||||
|
||||
# Tool keywords that suggest tool usage is needed
|
||||
tool_keywords = [
|
||||
"search", "look up", "find", "google", "current price",
|
||||
"latest", "today's", "news", "weather", "stock price",
|
||||
"read file", "write file", "save", "calculate", "compute",
|
||||
"run ", "execute", "shell", "command", "install",
|
||||
"search",
|
||||
"look up",
|
||||
"find",
|
||||
"google",
|
||||
"current price",
|
||||
"latest",
|
||||
"today's",
|
||||
"news",
|
||||
"weather",
|
||||
"stock price",
|
||||
"read file",
|
||||
"write file",
|
||||
"save",
|
||||
"calculate",
|
||||
"compute",
|
||||
"run ",
|
||||
"execute",
|
||||
"shell",
|
||||
"command",
|
||||
"install",
|
||||
]
|
||||
|
||||
|
||||
# Chat-only keywords that definitely don't need tools
|
||||
chat_only = [
|
||||
"hello", "hi ", "hey", "how are you", "what's up",
|
||||
"your name", "who are you", "what are you",
|
||||
"thanks", "thank you", "bye", "goodbye",
|
||||
"tell me about yourself", "what can you do",
|
||||
"hello",
|
||||
"hi ",
|
||||
"hey",
|
||||
"how are you",
|
||||
"what's up",
|
||||
"your name",
|
||||
"who are you",
|
||||
"what are you",
|
||||
"thanks",
|
||||
"thank you",
|
||||
"bye",
|
||||
"goodbye",
|
||||
"tell me about yourself",
|
||||
"what can you do",
|
||||
]
|
||||
|
||||
|
||||
# Check for chat-only patterns first
|
||||
for pattern in chat_only:
|
||||
if pattern in message_lower:
|
||||
return False
|
||||
|
||||
|
||||
# Check for tool keywords
|
||||
for keyword in tool_keywords:
|
||||
if keyword in message_lower:
|
||||
return True
|
||||
|
||||
|
||||
# Simple questions (starting with what, who, how, why, when, where)
|
||||
# usually don't need tools unless about current/real-time info
|
||||
simple_question_words = ["what is", "who is", "how does", "why is", "when did", "where is"]
|
||||
@@ -150,7 +230,7 @@ class ConversationManager:
|
||||
if any(t in message_lower for t in time_words):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# Default: don't use tools for unclear cases
|
||||
return False
|
||||
|
||||
|
||||
@@ -25,11 +25,12 @@ def _get_model():
|
||||
global _model, _has_embeddings
|
||||
if _has_embeddings is False:
|
||||
return None
|
||||
|
||||
|
||||
if _model is not None:
|
||||
return _model
|
||||
|
||||
|
||||
from config import settings
|
||||
|
||||
# In test mode or low-memory environments, skip embedding model load
|
||||
if settings.timmy_skip_embeddings:
|
||||
_has_embeddings = False
|
||||
@@ -37,7 +38,8 @@ def _get_model():
|
||||
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
_model = SentenceTransformer('all-MiniLM-L6-v2')
|
||||
|
||||
_model = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
_has_embeddings = True
|
||||
return _model
|
||||
except (ImportError, RuntimeError, Exception):
|
||||
@@ -56,7 +58,7 @@ def _get_embedding_dimension() -> int:
|
||||
|
||||
def _compute_embedding(text: str) -> list[float]:
|
||||
"""Compute embedding vector for text.
|
||||
|
||||
|
||||
Uses sentence-transformers if available, otherwise returns
|
||||
a simple hash-based vector for basic similarity.
|
||||
"""
|
||||
@@ -66,30 +68,31 @@ def _compute_embedding(text: str) -> list[float]:
|
||||
return model.encode(text).tolist()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# Fallback: simple character n-gram hash embedding
|
||||
# Not as good but allows the system to work without heavy deps
|
||||
dim = 384
|
||||
vec = [0.0] * dim
|
||||
text = text.lower()
|
||||
|
||||
|
||||
# Generate character trigram features
|
||||
for i in range(len(text) - 2):
|
||||
trigram = text[i:i+3]
|
||||
trigram = text[i : i + 3]
|
||||
hash_val = hash(trigram) % dim
|
||||
vec[hash_val] += 1.0
|
||||
|
||||
|
||||
# Normalize
|
||||
norm = sum(x*x for x in vec) ** 0.5
|
||||
norm = sum(x * x for x in vec) ** 0.5
|
||||
if norm > 0:
|
||||
vec = [x/norm for x in vec]
|
||||
|
||||
vec = [x / norm for x in vec]
|
||||
|
||||
return vec
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryEntry:
|
||||
"""A memory entry with vector embedding."""
|
||||
|
||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
content: str = "" # The actual text content
|
||||
source: str = "" # Where it came from (agent, user, system)
|
||||
@@ -99,9 +102,7 @@ class MemoryEntry:
|
||||
session_id: Optional[str] = None
|
||||
metadata: Optional[dict] = None
|
||||
embedding: Optional[list[float]] = None
|
||||
timestamp: str = field(
|
||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
relevance_score: Optional[float] = None # Set during search
|
||||
|
||||
|
||||
@@ -110,7 +111,7 @@ def _get_conn() -> sqlite3.Connection:
|
||||
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(DB_PATH))
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
|
||||
# Try to load sqlite-vss extension
|
||||
try:
|
||||
conn.enable_load_extension(True)
|
||||
@@ -119,7 +120,7 @@ def _get_conn() -> sqlite3.Connection:
|
||||
_has_vss = True
|
||||
except Exception:
|
||||
_has_vss = False
|
||||
|
||||
|
||||
# Create tables
|
||||
conn.execute(
|
||||
"""
|
||||
@@ -137,24 +138,14 @@ def _get_conn() -> sqlite3.Connection:
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
# Create indexes
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_memory_agent ON memory_entries(agent_id)"
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_memory_task ON memory_entries(task_id)"
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_memory_session ON memory_entries(session_id)"
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_memory_time ON memory_entries(timestamp)"
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_memory_type ON memory_entries(context_type)"
|
||||
)
|
||||
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_agent ON memory_entries(agent_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_task ON memory_entries(task_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_session ON memory_entries(session_id)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_time ON memory_entries(timestamp)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_type ON memory_entries(context_type)")
|
||||
|
||||
conn.commit()
|
||||
return conn
|
||||
|
||||
@@ -170,7 +161,7 @@ def store_memory(
|
||||
compute_embedding: bool = True,
|
||||
) -> MemoryEntry:
|
||||
"""Store a memory entry with optional embedding.
|
||||
|
||||
|
||||
Args:
|
||||
content: The text content to store
|
||||
source: Source of the memory (agent name, user, system)
|
||||
@@ -180,14 +171,14 @@ def store_memory(
|
||||
session_id: Session identifier
|
||||
metadata: Additional structured data
|
||||
compute_embedding: Whether to compute vector embedding
|
||||
|
||||
|
||||
Returns:
|
||||
The stored MemoryEntry
|
||||
"""
|
||||
embedding = None
|
||||
if compute_embedding:
|
||||
embedding = _compute_embedding(content)
|
||||
|
||||
|
||||
entry = MemoryEntry(
|
||||
content=content,
|
||||
source=source,
|
||||
@@ -198,7 +189,7 @@ def store_memory(
|
||||
metadata=metadata,
|
||||
embedding=embedding,
|
||||
)
|
||||
|
||||
|
||||
conn = _get_conn()
|
||||
conn.execute(
|
||||
"""
|
||||
@@ -222,7 +213,7 @@ def store_memory(
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
return entry
|
||||
|
||||
|
||||
@@ -235,7 +226,7 @@ def search_memories(
|
||||
min_relevance: float = 0.0,
|
||||
) -> list[MemoryEntry]:
|
||||
"""Search for memories by semantic similarity.
|
||||
|
||||
|
||||
Args:
|
||||
query: Search query text
|
||||
limit: Maximum results
|
||||
@@ -243,18 +234,18 @@ def search_memories(
|
||||
agent_id: Filter by agent
|
||||
session_id: Filter by session
|
||||
min_relevance: Minimum similarity score (0-1)
|
||||
|
||||
|
||||
Returns:
|
||||
List of MemoryEntry objects sorted by relevance
|
||||
"""
|
||||
query_embedding = _compute_embedding(query)
|
||||
|
||||
|
||||
conn = _get_conn()
|
||||
|
||||
|
||||
# Build query with filters
|
||||
conditions = []
|
||||
params = []
|
||||
|
||||
|
||||
if context_type:
|
||||
conditions.append("context_type = ?")
|
||||
params.append(context_type)
|
||||
@@ -264,9 +255,9 @@ def search_memories(
|
||||
if session_id:
|
||||
conditions.append("session_id = ?")
|
||||
params.append(session_id)
|
||||
|
||||
|
||||
where_clause = "WHERE " + " AND ".join(conditions) if conditions else ""
|
||||
|
||||
|
||||
# Fetch candidates (we'll do in-memory similarity for now)
|
||||
# For production with sqlite-vss, this would use vector similarity index
|
||||
query_sql = f"""
|
||||
@@ -276,10 +267,10 @@ def search_memories(
|
||||
LIMIT ?
|
||||
"""
|
||||
params.append(limit * 3) # Get more candidates for ranking
|
||||
|
||||
|
||||
rows = conn.execute(query_sql, params).fetchall()
|
||||
conn.close()
|
||||
|
||||
|
||||
# Compute similarity scores
|
||||
results = []
|
||||
for row in rows:
|
||||
@@ -295,7 +286,7 @@ def search_memories(
|
||||
embedding=json.loads(row["embedding"]) if row["embedding"] else None,
|
||||
timestamp=row["timestamp"],
|
||||
)
|
||||
|
||||
|
||||
if entry.embedding:
|
||||
# Cosine similarity
|
||||
score = _cosine_similarity(query_embedding, entry.embedding)
|
||||
@@ -308,7 +299,7 @@ def search_memories(
|
||||
entry.relevance_score = score
|
||||
if score >= min_relevance:
|
||||
results.append(entry)
|
||||
|
||||
|
||||
# Sort by relevance and return top results
|
||||
results.sort(key=lambda x: x.relevance_score or 0, reverse=True)
|
||||
return results[:limit]
|
||||
@@ -316,9 +307,9 @@ def search_memories(
|
||||
|
||||
def _cosine_similarity(a: list[float], b: list[float]) -> float:
|
||||
"""Compute cosine similarity between two vectors."""
|
||||
dot = sum(x*y for x, y in zip(a, b))
|
||||
norm_a = sum(x*x for x in a) ** 0.5
|
||||
norm_b = sum(x*x for x in b) ** 0.5
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
norm_a = sum(x * x for x in a) ** 0.5
|
||||
norm_b = sum(x * x for x in b) ** 0.5
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 0.0
|
||||
return dot / (norm_a * norm_b)
|
||||
@@ -334,51 +325,47 @@ def _keyword_overlap(query: str, content: str) -> float:
|
||||
return overlap / len(query_words)
|
||||
|
||||
|
||||
def get_memory_context(
|
||||
query: str,
|
||||
max_tokens: int = 2000,
|
||||
**filters
|
||||
) -> str:
|
||||
def get_memory_context(query: str, max_tokens: int = 2000, **filters) -> str:
|
||||
"""Get relevant memory context as formatted text for LLM prompts.
|
||||
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
max_tokens: Approximate maximum tokens to return
|
||||
**filters: Additional filters (agent_id, session_id, etc.)
|
||||
|
||||
|
||||
Returns:
|
||||
Formatted context string for inclusion in prompts
|
||||
"""
|
||||
memories = search_memories(query, limit=20, **filters)
|
||||
|
||||
|
||||
context_parts = []
|
||||
total_chars = 0
|
||||
max_chars = max_tokens * 4 # Rough approximation
|
||||
|
||||
|
||||
for mem in memories:
|
||||
formatted = f"[{mem.source}]: {mem.content}"
|
||||
if total_chars + len(formatted) > max_chars:
|
||||
break
|
||||
context_parts.append(formatted)
|
||||
total_chars += len(formatted)
|
||||
|
||||
|
||||
if not context_parts:
|
||||
return ""
|
||||
|
||||
|
||||
return "Relevant context from memory:\n" + "\n\n".join(context_parts)
|
||||
|
||||
|
||||
def recall_personal_facts(agent_id: Optional[str] = None) -> list[str]:
|
||||
"""Recall personal facts about the user or system.
|
||||
|
||||
|
||||
Args:
|
||||
agent_id: Optional agent filter
|
||||
|
||||
|
||||
Returns:
|
||||
List of fact strings
|
||||
"""
|
||||
conn = _get_conn()
|
||||
|
||||
|
||||
if agent_id:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
@@ -398,7 +385,7 @@ def recall_personal_facts(agent_id: Optional[str] = None) -> list[str]:
|
||||
LIMIT 100
|
||||
""",
|
||||
).fetchall()
|
||||
|
||||
|
||||
conn.close()
|
||||
return [r["content"] for r in rows]
|
||||
|
||||
@@ -434,11 +421,11 @@ def update_personal_fact(memory_id: str, new_content: str) -> bool:
|
||||
|
||||
def store_personal_fact(fact: str, agent_id: Optional[str] = None) -> MemoryEntry:
|
||||
"""Store a personal fact about the user or system.
|
||||
|
||||
|
||||
Args:
|
||||
fact: The fact to store
|
||||
agent_id: Associated agent
|
||||
|
||||
|
||||
Returns:
|
||||
The stored MemoryEntry
|
||||
"""
|
||||
@@ -453,7 +440,7 @@ def store_personal_fact(fact: str, agent_id: Optional[str] = None) -> MemoryEntr
|
||||
|
||||
def delete_memory(memory_id: str) -> bool:
|
||||
"""Delete a memory entry by ID.
|
||||
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
@@ -470,29 +457,27 @@ def delete_memory(memory_id: str) -> bool:
|
||||
|
||||
def get_memory_stats() -> dict:
|
||||
"""Get statistics about the memory store.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict with counts by type, total entries, etc.
|
||||
"""
|
||||
conn = _get_conn()
|
||||
|
||||
total = conn.execute(
|
||||
"SELECT COUNT(*) as count FROM memory_entries"
|
||||
).fetchone()["count"]
|
||||
|
||||
|
||||
total = conn.execute("SELECT COUNT(*) as count FROM memory_entries").fetchone()["count"]
|
||||
|
||||
by_type = {}
|
||||
rows = conn.execute(
|
||||
"SELECT context_type, COUNT(*) as count FROM memory_entries GROUP BY context_type"
|
||||
).fetchall()
|
||||
for row in rows:
|
||||
by_type[row["context_type"]] = row["count"]
|
||||
|
||||
|
||||
with_embeddings = conn.execute(
|
||||
"SELECT COUNT(*) as count FROM memory_entries WHERE embedding IS NOT NULL"
|
||||
).fetchone()["count"]
|
||||
|
||||
|
||||
conn.close()
|
||||
|
||||
|
||||
return {
|
||||
"total_entries": total,
|
||||
"by_type": by_type,
|
||||
@@ -503,20 +488,20 @@ def get_memory_stats() -> dict:
|
||||
|
||||
def prune_memories(older_than_days: int = 90, keep_facts: bool = True) -> int:
|
||||
"""Delete old memories to manage storage.
|
||||
|
||||
|
||||
Args:
|
||||
older_than_days: Delete memories older than this
|
||||
keep_facts: Whether to preserve fact-type memories
|
||||
|
||||
|
||||
Returns:
|
||||
Number of entries deleted
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
|
||||
cutoff = (datetime.now(timezone.utc) - timedelta(days=older_than_days)).isoformat()
|
||||
|
||||
|
||||
conn = _get_conn()
|
||||
|
||||
|
||||
if keep_facts:
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
@@ -530,9 +515,9 @@ def prune_memories(older_than_days: int = 90, keep_facts: bool = True) -> int:
|
||||
"DELETE FROM memory_entries WHERE timestamp < ?",
|
||||
(cutoff,),
|
||||
)
|
||||
|
||||
|
||||
deleted = cursor.rowcount
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
return deleted
|
||||
|
||||
@@ -28,50 +28,52 @@ HANDOFF_PATH = VAULT_PATH / "notes" / "last-session-handoff.md"
|
||||
|
||||
class HotMemory:
|
||||
"""Tier 1: Hot memory (MEMORY.md) — always loaded."""
|
||||
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.path = HOT_MEMORY_PATH
|
||||
self._content: Optional[str] = None
|
||||
self._last_modified: Optional[float] = None
|
||||
|
||||
|
||||
def read(self, force_refresh: bool = False) -> str:
|
||||
"""Read hot memory, with caching."""
|
||||
if not self.path.exists():
|
||||
self._create_default()
|
||||
|
||||
|
||||
# Check if file changed
|
||||
current_mtime = self.path.stat().st_mtime
|
||||
if not force_refresh and self._content and self._last_modified == current_mtime:
|
||||
return self._content
|
||||
|
||||
|
||||
self._content = self.path.read_text()
|
||||
self._last_modified = current_mtime
|
||||
logger.debug("HotMemory: Loaded %d chars from %s", len(self._content), self.path)
|
||||
return self._content
|
||||
|
||||
|
||||
def update_section(self, section: str, content: str) -> None:
|
||||
"""Update a specific section in MEMORY.md."""
|
||||
full_content = self.read()
|
||||
|
||||
|
||||
# Find section
|
||||
pattern = rf"(## {re.escape(section)}.*?)(?=\n## |\Z)"
|
||||
match = re.search(pattern, full_content, re.DOTALL)
|
||||
|
||||
|
||||
if match:
|
||||
# Replace section
|
||||
new_section = f"## {section}\n\n{content}\n\n"
|
||||
full_content = full_content[:match.start()] + new_section + full_content[match.end():]
|
||||
full_content = full_content[: match.start()] + new_section + full_content[match.end() :]
|
||||
else:
|
||||
# Append section before last updated line
|
||||
insert_point = full_content.rfind("*Prune date:")
|
||||
new_section = f"## {section}\n\n{content}\n\n"
|
||||
full_content = full_content[:insert_point] + new_section + "\n" + full_content[insert_point:]
|
||||
|
||||
full_content = (
|
||||
full_content[:insert_point] + new_section + "\n" + full_content[insert_point:]
|
||||
)
|
||||
|
||||
self.path.write_text(full_content)
|
||||
self._content = full_content
|
||||
self._last_modified = self.path.stat().st_mtime
|
||||
logger.info("HotMemory: Updated section '%s'", section)
|
||||
|
||||
|
||||
def _create_default(self) -> None:
|
||||
"""Create default MEMORY.md if missing."""
|
||||
default_content = """# Timmy Hot Memory
|
||||
@@ -130,33 +132,33 @@ class HotMemory:
|
||||
*Prune date: {prune_date}*
|
||||
""".format(
|
||||
date=datetime.now(timezone.utc).strftime("%Y-%m-%d"),
|
||||
prune_date=(datetime.now(timezone.utc).replace(day=25)).strftime("%Y-%m-%d")
|
||||
prune_date=(datetime.now(timezone.utc).replace(day=25)).strftime("%Y-%m-%d"),
|
||||
)
|
||||
|
||||
|
||||
self.path.write_text(default_content)
|
||||
logger.info("HotMemory: Created default MEMORY.md")
|
||||
|
||||
|
||||
class VaultMemory:
|
||||
"""Tier 2: Structured vault (memory/) — append-only markdown."""
|
||||
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.path = VAULT_PATH
|
||||
self._ensure_structure()
|
||||
|
||||
|
||||
def _ensure_structure(self) -> None:
|
||||
"""Ensure vault directory structure exists."""
|
||||
(self.path / "self").mkdir(parents=True, exist_ok=True)
|
||||
(self.path / "notes").mkdir(parents=True, exist_ok=True)
|
||||
(self.path / "aar").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def write_note(self, name: str, content: str, namespace: str = "notes") -> Path:
|
||||
"""Write a note to the vault."""
|
||||
# Add timestamp to filename
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d")
|
||||
filename = f"{timestamp}_{name}.md"
|
||||
filepath = self.path / namespace / filename
|
||||
|
||||
|
||||
# Add header
|
||||
full_content = f"""# {name.replace('_', ' ').title()}
|
||||
|
||||
@@ -171,39 +173,39 @@ class VaultMemory:
|
||||
|
||||
*Auto-generated by Timmy Memory System*
|
||||
"""
|
||||
|
||||
|
||||
filepath.write_text(full_content)
|
||||
logger.info("VaultMemory: Wrote %s", filepath)
|
||||
return filepath
|
||||
|
||||
|
||||
def read_file(self, filepath: Path) -> str:
|
||||
"""Read a file from the vault."""
|
||||
if not filepath.exists():
|
||||
return ""
|
||||
return filepath.read_text()
|
||||
|
||||
|
||||
def list_files(self, namespace: str = "notes", pattern: str = "*.md") -> list[Path]:
|
||||
"""List files in a namespace."""
|
||||
dir_path = self.path / namespace
|
||||
if not dir_path.exists():
|
||||
return []
|
||||
return sorted(dir_path.glob(pattern))
|
||||
|
||||
|
||||
def get_latest(self, namespace: str = "notes", pattern: str = "*.md") -> Optional[Path]:
|
||||
"""Get most recent file in namespace."""
|
||||
files = self.list_files(namespace, pattern)
|
||||
return files[-1] if files else None
|
||||
|
||||
|
||||
def update_user_profile(self, key: str, value: str) -> None:
|
||||
"""Update a field in user_profile.md."""
|
||||
profile_path = self.path / "self" / "user_profile.md"
|
||||
|
||||
|
||||
if not profile_path.exists():
|
||||
# Create default profile
|
||||
self._create_default_profile()
|
||||
|
||||
|
||||
content = profile_path.read_text()
|
||||
|
||||
|
||||
# Simple pattern replacement
|
||||
pattern = rf"(\*\*{re.escape(key)}:\*\*).*"
|
||||
if re.search(pattern, content):
|
||||
@@ -214,17 +216,17 @@ class VaultMemory:
|
||||
if facts_section in content:
|
||||
insert_point = content.find(facts_section) + len(facts_section)
|
||||
content = content[:insert_point] + f"\n- {key}: {value}" + content[insert_point:]
|
||||
|
||||
|
||||
# Update last_updated
|
||||
content = re.sub(
|
||||
r"\*Last updated:.*\*",
|
||||
f"*Last updated: {datetime.now(timezone.utc).strftime('%Y-%m-%d')}*",
|
||||
content
|
||||
content,
|
||||
)
|
||||
|
||||
|
||||
profile_path.write_text(content)
|
||||
logger.info("VaultMemory: Updated user profile: %s = %s", key, value)
|
||||
|
||||
|
||||
def _create_default_profile(self) -> None:
|
||||
"""Create default user profile."""
|
||||
profile_path = self.path / "self" / "user_profile.md"
|
||||
@@ -254,24 +256,26 @@ class VaultMemory:
|
||||
---
|
||||
|
||||
*Last updated: {date}*
|
||||
""".format(date=datetime.now(timezone.utc).strftime("%Y-%m-%d"))
|
||||
|
||||
""".format(
|
||||
date=datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
)
|
||||
|
||||
profile_path.write_text(default)
|
||||
|
||||
|
||||
class HandoffProtocol:
|
||||
"""Session handoff protocol for continuity."""
|
||||
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.path = HANDOFF_PATH
|
||||
self.vault = VaultMemory()
|
||||
|
||||
|
||||
def write_handoff(
|
||||
self,
|
||||
session_summary: str,
|
||||
key_decisions: list[str],
|
||||
open_items: list[str],
|
||||
next_steps: list[str]
|
||||
next_steps: list[str],
|
||||
) -> None:
|
||||
"""Write handoff at session end."""
|
||||
content = f"""# Last Session Handoff
|
||||
@@ -303,25 +307,24 @@ The user was last working on: {session_summary[:200]}...
|
||||
|
||||
*This handoff will be auto-loaded at next session start*
|
||||
"""
|
||||
|
||||
|
||||
self.path.write_text(content)
|
||||
|
||||
|
||||
# Also archive to notes
|
||||
self.vault.write_note(
|
||||
"session_handoff",
|
||||
content,
|
||||
namespace="notes"
|
||||
self.vault.write_note("session_handoff", content, namespace="notes")
|
||||
|
||||
logger.info(
|
||||
"HandoffProtocol: Wrote handoff with %d decisions, %d open items",
|
||||
len(key_decisions),
|
||||
len(open_items),
|
||||
)
|
||||
|
||||
logger.info("HandoffProtocol: Wrote handoff with %d decisions, %d open items",
|
||||
len(key_decisions), len(open_items))
|
||||
|
||||
|
||||
def read_handoff(self) -> Optional[str]:
|
||||
"""Read handoff if exists."""
|
||||
if not self.path.exists():
|
||||
return None
|
||||
return self.path.read_text()
|
||||
|
||||
|
||||
def clear_handoff(self) -> None:
|
||||
"""Clear handoff after loading."""
|
||||
if self.path.exists():
|
||||
@@ -331,7 +334,7 @@ The user was last working on: {session_summary[:200]}...
|
||||
|
||||
class MemorySystem:
|
||||
"""Central memory system coordinating all tiers."""
|
||||
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.hot = HotMemory()
|
||||
self.vault = VaultMemory()
|
||||
@@ -339,52 +342,52 @@ class MemorySystem:
|
||||
self.session_start_time: Optional[datetime] = None
|
||||
self.session_decisions: list[str] = []
|
||||
self.session_open_items: list[str] = []
|
||||
|
||||
|
||||
def start_session(self) -> str:
|
||||
"""Start a new session, loading context from memory."""
|
||||
self.session_start_time = datetime.now(timezone.utc)
|
||||
|
||||
|
||||
# Build context
|
||||
context_parts = []
|
||||
|
||||
|
||||
# 1. Hot memory
|
||||
hot_content = self.hot.read()
|
||||
context_parts.append("## Hot Memory\n" + hot_content)
|
||||
|
||||
|
||||
# 2. Last session handoff
|
||||
handoff_content = self.handoff.read_handoff()
|
||||
if handoff_content:
|
||||
context_parts.append("## Previous Session\n" + handoff_content)
|
||||
self.handoff.clear_handoff()
|
||||
|
||||
|
||||
# 3. User profile (key fields only)
|
||||
profile = self._load_user_profile_summary()
|
||||
if profile:
|
||||
context_parts.append("## User Context\n" + profile)
|
||||
|
||||
|
||||
full_context = "\n\n---\n\n".join(context_parts)
|
||||
logger.info("MemorySystem: Session started with %d chars context", len(full_context))
|
||||
|
||||
|
||||
return full_context
|
||||
|
||||
|
||||
def end_session(self, summary: str) -> None:
|
||||
"""End session, write handoff."""
|
||||
self.handoff.write_handoff(
|
||||
session_summary=summary,
|
||||
key_decisions=self.session_decisions,
|
||||
open_items=self.session_open_items,
|
||||
next_steps=[]
|
||||
next_steps=[],
|
||||
)
|
||||
|
||||
|
||||
# Update hot memory
|
||||
self.hot.update_section(
|
||||
"Current Session",
|
||||
f"**Last Session:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M')}\n" +
|
||||
f"**Summary:** {summary[:100]}..."
|
||||
f"**Last Session:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M')}\n"
|
||||
+ f"**Summary:** {summary[:100]}...",
|
||||
)
|
||||
|
||||
|
||||
logger.info("MemorySystem: Session ended, handoff written")
|
||||
|
||||
|
||||
def record_decision(self, decision: str) -> None:
|
||||
"""Record a key decision during session."""
|
||||
self.session_decisions.append(decision)
|
||||
@@ -393,43 +396,47 @@ class MemorySystem:
|
||||
if "## Key Decisions" in current:
|
||||
# Append to section
|
||||
pass # Handled at session end
|
||||
|
||||
|
||||
def record_open_item(self, item: str) -> None:
|
||||
"""Record an open item for follow-up."""
|
||||
self.session_open_items.append(item)
|
||||
|
||||
|
||||
def update_user_fact(self, key: str, value: str) -> None:
|
||||
"""Update user profile in vault."""
|
||||
self.vault.update_user_profile(key, value)
|
||||
# Also update hot memory
|
||||
if key.lower() == "name":
|
||||
self.hot.update_section("User Profile", f"**Name:** {value}")
|
||||
|
||||
|
||||
def _load_user_profile_summary(self) -> str:
|
||||
"""Load condensed user profile."""
|
||||
profile_path = self.vault.path / "self" / "user_profile.md"
|
||||
if not profile_path.exists():
|
||||
return ""
|
||||
|
||||
|
||||
content = profile_path.read_text()
|
||||
|
||||
|
||||
# Extract key fields
|
||||
summary_parts = []
|
||||
|
||||
|
||||
# Name
|
||||
name_match = re.search(r"\*\*Name:\*\* (.+)", content)
|
||||
if name_match and "unknown" not in name_match.group(1).lower():
|
||||
summary_parts.append(f"Name: {name_match.group(1).strip()}")
|
||||
|
||||
|
||||
# Interests
|
||||
interests_section = re.search(r"## Interests.*?\n- (.+?)(?=\n## |\Z)", content, re.DOTALL)
|
||||
if interests_section:
|
||||
interests = [i.strip() for i in interests_section.group(1).split("\n-") if i.strip() and "to be" not in i]
|
||||
interests = [
|
||||
i.strip()
|
||||
for i in interests_section.group(1).split("\n-")
|
||||
if i.strip() and "to be" not in i
|
||||
]
|
||||
if interests:
|
||||
summary_parts.append(f"Interests: {', '.join(interests[:3])}")
|
||||
|
||||
|
||||
return "\n".join(summary_parts) if summary_parts else ""
|
||||
|
||||
|
||||
def get_system_context(self) -> str:
|
||||
"""Get full context for system prompt injection.
|
||||
|
||||
|
||||
@@ -38,12 +38,14 @@ def _get_embedding_model():
|
||||
global EMBEDDING_MODEL
|
||||
if EMBEDDING_MODEL is None:
|
||||
from config import settings
|
||||
|
||||
if settings.timmy_skip_embeddings:
|
||||
EMBEDDING_MODEL = False
|
||||
return EMBEDDING_MODEL
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
EMBEDDING_MODEL = SentenceTransformer('all-MiniLM-L6-v2')
|
||||
|
||||
EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
|
||||
logger.info("SemanticMemory: Loaded embedding model")
|
||||
except ImportError:
|
||||
logger.warning("SemanticMemory: sentence-transformers not installed, using fallback")
|
||||
@@ -60,11 +62,12 @@ def _simple_hash_embedding(text: str) -> list[float]:
|
||||
h = hashlib.md5(word.encode()).hexdigest()
|
||||
for j in range(8):
|
||||
idx = (i * 8 + j) % 128
|
||||
vec[idx] += int(h[j*2:j*2+2], 16) / 255.0
|
||||
vec[idx] += int(h[j * 2 : j * 2 + 2], 16) / 255.0
|
||||
# Normalize
|
||||
import math
|
||||
mag = math.sqrt(sum(x*x for x in vec)) or 1.0
|
||||
return [x/mag for x in vec]
|
||||
|
||||
mag = math.sqrt(sum(x * x for x in vec)) or 1.0
|
||||
return [x / mag for x in vec]
|
||||
|
||||
|
||||
def embed_text(text: str) -> list[float]:
|
||||
@@ -80,9 +83,10 @@ def embed_text(text: str) -> list[float]:
|
||||
def cosine_similarity(a: list[float], b: list[float]) -> float:
|
||||
"""Calculate cosine similarity between two vectors."""
|
||||
import math
|
||||
dot = sum(x*y for x, y in zip(a, b))
|
||||
mag_a = math.sqrt(sum(x*x for x in a))
|
||||
mag_b = math.sqrt(sum(x*x for x in b))
|
||||
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
mag_a = math.sqrt(sum(x * x for x in a))
|
||||
mag_b = math.sqrt(sum(x * x for x in b))
|
||||
if mag_a == 0 or mag_b == 0:
|
||||
return 0.0
|
||||
return dot / (mag_a * mag_b)
|
||||
@@ -91,6 +95,7 @@ def cosine_similarity(a: list[float], b: list[float]) -> float:
|
||||
@dataclass
|
||||
class MemoryChunk:
|
||||
"""A searchable chunk of memory."""
|
||||
|
||||
id: str
|
||||
source: str # filepath
|
||||
content: str
|
||||
@@ -100,17 +105,18 @@ class MemoryChunk:
|
||||
|
||||
class SemanticMemory:
|
||||
"""Vector-based semantic search over vault content."""
|
||||
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.db_path = SEMANTIC_DB_PATH
|
||||
self.vault_path = VAULT_PATH
|
||||
self._init_db()
|
||||
|
||||
|
||||
def _init_db(self) -> None:
|
||||
"""Initialize SQLite with vector storage."""
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(self.db_path))
|
||||
conn.execute("""
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS chunks (
|
||||
id TEXT PRIMARY KEY,
|
||||
source TEXT NOT NULL,
|
||||
@@ -119,76 +125,76 @@ class SemanticMemory:
|
||||
created_at TEXT NOT NULL,
|
||||
source_hash TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
"""
|
||||
)
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_source ON chunks(source)")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
def index_file(self, filepath: Path) -> int:
|
||||
"""Index a single file into semantic memory."""
|
||||
if not filepath.exists():
|
||||
return 0
|
||||
|
||||
|
||||
content = filepath.read_text()
|
||||
file_hash = hashlib.md5(content.encode()).hexdigest()
|
||||
|
||||
|
||||
# Check if already indexed with same hash
|
||||
conn = sqlite3.connect(str(self.db_path))
|
||||
cursor = conn.execute(
|
||||
"SELECT source_hash FROM chunks WHERE source = ? LIMIT 1",
|
||||
(str(filepath),)
|
||||
"SELECT source_hash FROM chunks WHERE source = ? LIMIT 1", (str(filepath),)
|
||||
)
|
||||
existing = cursor.fetchone()
|
||||
if existing and existing[0] == file_hash:
|
||||
conn.close()
|
||||
return 0 # Already indexed
|
||||
|
||||
|
||||
# Delete old chunks for this file
|
||||
conn.execute("DELETE FROM chunks WHERE source = ?", (str(filepath),))
|
||||
|
||||
|
||||
# Split into chunks (paragraphs)
|
||||
chunks = self._split_into_chunks(content)
|
||||
|
||||
|
||||
# Index each chunk
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
for i, chunk_text in enumerate(chunks):
|
||||
if len(chunk_text.strip()) < 20: # Skip tiny chunks
|
||||
continue
|
||||
|
||||
|
||||
chunk_id = f"{filepath.stem}_{i}"
|
||||
embedding = embed_text(chunk_text)
|
||||
|
||||
|
||||
conn.execute(
|
||||
"""INSERT INTO chunks (id, source, content, embedding, created_at, source_hash)
|
||||
VALUES (?, ?, ?, ?, ?, ?)""",
|
||||
(chunk_id, str(filepath), chunk_text, json.dumps(embedding), now, file_hash)
|
||||
(chunk_id, str(filepath), chunk_text, json.dumps(embedding), now, file_hash),
|
||||
)
|
||||
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
logger.info("SemanticMemory: Indexed %s (%d chunks)", filepath.name, len(chunks))
|
||||
return len(chunks)
|
||||
|
||||
|
||||
def _split_into_chunks(self, text: str, max_chunk_size: int = 500) -> list[str]:
|
||||
"""Split text into semantic chunks."""
|
||||
# Split by paragraphs first
|
||||
paragraphs = text.split('\n\n')
|
||||
paragraphs = text.split("\n\n")
|
||||
chunks = []
|
||||
|
||||
|
||||
for para in paragraphs:
|
||||
para = para.strip()
|
||||
if not para:
|
||||
continue
|
||||
|
||||
|
||||
# If paragraph is small enough, keep as one chunk
|
||||
if len(para) <= max_chunk_size:
|
||||
chunks.append(para)
|
||||
else:
|
||||
# Split long paragraphs by sentences
|
||||
sentences = para.replace('. ', '.\n').split('\n')
|
||||
sentences = para.replace(". ", ".\n").split("\n")
|
||||
current_chunk = ""
|
||||
|
||||
|
||||
for sent in sentences:
|
||||
if len(current_chunk) + len(sent) < max_chunk_size:
|
||||
current_chunk += " " + sent if current_chunk else sent
|
||||
@@ -196,82 +202,80 @@ class SemanticMemory:
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
current_chunk = sent
|
||||
|
||||
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def index_vault(self) -> int:
|
||||
"""Index entire vault directory."""
|
||||
total_chunks = 0
|
||||
|
||||
|
||||
for md_file in self.vault_path.rglob("*.md"):
|
||||
# Skip handoff file (handled separately)
|
||||
if "last-session-handoff" in md_file.name:
|
||||
continue
|
||||
total_chunks += self.index_file(md_file)
|
||||
|
||||
|
||||
logger.info("SemanticMemory: Indexed vault (%d total chunks)", total_chunks)
|
||||
return total_chunks
|
||||
|
||||
|
||||
def search(self, query: str, top_k: int = 5) -> list[tuple[str, float]]:
|
||||
"""Search for relevant memory chunks."""
|
||||
query_embedding = embed_text(query)
|
||||
|
||||
|
||||
conn = sqlite3.connect(str(self.db_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
|
||||
# Get all chunks (in production, use vector index)
|
||||
rows = conn.execute(
|
||||
"SELECT source, content, embedding FROM chunks"
|
||||
).fetchall()
|
||||
|
||||
rows = conn.execute("SELECT source, content, embedding FROM chunks").fetchall()
|
||||
|
||||
conn.close()
|
||||
|
||||
|
||||
# Calculate similarities
|
||||
scored = []
|
||||
for row in rows:
|
||||
embedding = json.loads(row["embedding"])
|
||||
score = cosine_similarity(query_embedding, embedding)
|
||||
scored.append((row["source"], row["content"], score))
|
||||
|
||||
|
||||
# Sort by score descending
|
||||
scored.sort(key=lambda x: x[2], reverse=True)
|
||||
|
||||
|
||||
# Return top_k
|
||||
return [(content, score) for _, content, score in scored[:top_k]]
|
||||
|
||||
|
||||
def get_relevant_context(self, query: str, max_chars: int = 2000) -> str:
|
||||
"""Get formatted context string for a query."""
|
||||
results = self.search(query, top_k=3)
|
||||
|
||||
|
||||
if not results:
|
||||
return ""
|
||||
|
||||
|
||||
parts = []
|
||||
total_chars = 0
|
||||
|
||||
|
||||
for content, score in results:
|
||||
if score < 0.3: # Similarity threshold
|
||||
continue
|
||||
|
||||
|
||||
chunk = f"[Relevant memory - score {score:.2f}]: {content[:400]}..."
|
||||
if total_chars + len(chunk) > max_chars:
|
||||
break
|
||||
|
||||
|
||||
parts.append(chunk)
|
||||
total_chars += len(chunk)
|
||||
|
||||
|
||||
return "\n\n".join(parts) if parts else ""
|
||||
|
||||
|
||||
def stats(self) -> dict:
|
||||
"""Get indexing statistics."""
|
||||
conn = sqlite3.connect(str(self.db_path))
|
||||
cursor = conn.execute("SELECT COUNT(*), COUNT(DISTINCT source) FROM chunks")
|
||||
total_chunks, total_files = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
|
||||
return {
|
||||
"total_chunks": total_chunks,
|
||||
"total_files": total_files,
|
||||
@@ -281,40 +285,39 @@ class SemanticMemory:
|
||||
|
||||
class MemorySearcher:
|
||||
"""High-level interface for memory search."""
|
||||
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.semantic = SemanticMemory()
|
||||
|
||||
|
||||
def search(self, query: str, tiers: list[str] = None) -> dict:
|
||||
"""Search across memory tiers.
|
||||
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
tiers: List of tiers to search ["hot", "vault", "semantic"]
|
||||
|
||||
|
||||
Returns:
|
||||
Dict with results from each tier
|
||||
"""
|
||||
tiers = tiers or ["semantic"] # Default to semantic only
|
||||
results = {}
|
||||
|
||||
|
||||
if "semantic" in tiers:
|
||||
semantic_results = self.semantic.search(query, top_k=5)
|
||||
results["semantic"] = [
|
||||
{"content": content, "score": score}
|
||||
for content, score in semantic_results
|
||||
{"content": content, "score": score} for content, score in semantic_results
|
||||
]
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def get_context_for_query(self, query: str) -> str:
|
||||
"""Get comprehensive context for a user query."""
|
||||
# Get semantic context
|
||||
semantic_context = self.semantic.get_relevant_context(query)
|
||||
|
||||
|
||||
if semantic_context:
|
||||
return f"## Relevant Past Context\n\n{semantic_context}"
|
||||
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
@@ -353,6 +356,7 @@ def memory_search(query: str, top_k: int = 5) -> str:
|
||||
# 2. Search runtime vector store (stored facts/conversations)
|
||||
try:
|
||||
from timmy.memory.vector_store import search_memories
|
||||
|
||||
runtime_results = search_memories(query, limit=top_k, min_relevance=0.2)
|
||||
for entry in runtime_results:
|
||||
label = entry.context_type or "memory"
|
||||
@@ -387,6 +391,7 @@ def memory_read(query: str = "", top_k: int = 5) -> str:
|
||||
# Always include personal facts first
|
||||
try:
|
||||
from timmy.memory.vector_store import search_memories
|
||||
|
||||
facts = search_memories(query or "", limit=top_k, min_relevance=0.0)
|
||||
fact_entries = [e for e in facts if (e.context_type or "") == "fact"]
|
||||
if fact_entries:
|
||||
@@ -433,6 +438,7 @@ def memory_write(content: str, context_type: str = "fact") -> str:
|
||||
|
||||
try:
|
||||
from timmy.memory.vector_store import store_memory
|
||||
|
||||
entry = store_memory(
|
||||
content=content.strip(),
|
||||
source="agent",
|
||||
|
||||
@@ -32,13 +32,15 @@ _TOOL_CALL_JSON = re.compile(
|
||||
|
||||
# Matches function-call-style text: memory_search(query="...") etc.
|
||||
_FUNC_CALL_TEXT = re.compile(
|
||||
r'\b(?:memory_search|web_search|shell|python|read_file|write_file|list_files|calculator)'
|
||||
r'\s*\([^)]*\)',
|
||||
r"\b(?:memory_search|web_search|shell|python|read_file|write_file|list_files|calculator)"
|
||||
r"\s*\([^)]*\)",
|
||||
)
|
||||
|
||||
# Matches chain-of-thought narration lines the model should keep internal
|
||||
_COT_PATTERNS = [
|
||||
re.compile(r"^(?:Since |Using |Let me |I'll use |I will use |Here's a possible ).*$", re.MULTILINE),
|
||||
re.compile(
|
||||
r"^(?:Since |Using |Let me |I'll use |I will use |Here's a possible ).*$", re.MULTILINE
|
||||
),
|
||||
re.compile(r"^(?:I found a relevant |This context suggests ).*$", re.MULTILINE),
|
||||
]
|
||||
|
||||
@@ -48,6 +50,7 @@ def _get_agent():
|
||||
global _agent
|
||||
if _agent is None:
|
||||
from timmy.agent import create_timmy
|
||||
|
||||
try:
|
||||
_agent = create_timmy()
|
||||
logger.info("Session: Timmy agent initialized (singleton)")
|
||||
@@ -99,6 +102,7 @@ def reset_session(session_id: Optional[str] = None) -> None:
|
||||
sid = session_id or _DEFAULT_SESSION_ID
|
||||
try:
|
||||
from timmy.conversation import conversation_manager
|
||||
|
||||
conversation_manager.clear_context(sid)
|
||||
except Exception as exc:
|
||||
logger.debug("Session: context clear failed for %s: %s", sid, exc)
|
||||
@@ -112,10 +116,12 @@ def _extract_facts(message: str) -> None:
|
||||
"""
|
||||
try:
|
||||
from timmy.conversation import conversation_manager
|
||||
|
||||
name = conversation_manager.extract_user_name(message)
|
||||
if name:
|
||||
try:
|
||||
from timmy.memory_system import memory_system
|
||||
|
||||
memory_system.update_user_fact("Name", name)
|
||||
logger.info("Session: Learned user name: %s", name)
|
||||
except Exception as exc:
|
||||
|
||||
@@ -6,7 +6,7 @@ including any mistakes or errors that occur during the session."
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, date
|
||||
from datetime import date, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
@@ -75,6 +75,7 @@ Continue your train of thought."""
|
||||
@dataclass
|
||||
class Thought:
|
||||
"""A single thought in Timmy's inner stream."""
|
||||
|
||||
id: str
|
||||
content: str
|
||||
seed_type: str
|
||||
@@ -98,9 +99,7 @@ def _get_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection:
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_thoughts_time ON thoughts(created_at)"
|
||||
)
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_thoughts_time ON thoughts(created_at)")
|
||||
conn.commit()
|
||||
return conn
|
||||
|
||||
@@ -190,9 +189,7 @@ class ThinkingEngine:
|
||||
def get_thought(self, thought_id: str) -> Optional[Thought]:
|
||||
"""Retrieve a single thought by ID."""
|
||||
conn = _get_conn(self._db_path)
|
||||
row = conn.execute(
|
||||
"SELECT * FROM thoughts WHERE id = ?", (thought_id,)
|
||||
).fetchone()
|
||||
row = conn.execute("SELECT * FROM thoughts WHERE id = ?", (thought_id,)).fetchone()
|
||||
conn.close()
|
||||
return _row_to_thought(row) if row else None
|
||||
|
||||
@@ -208,9 +205,7 @@ class ThinkingEngine:
|
||||
for _ in range(max_depth):
|
||||
if not current_id:
|
||||
break
|
||||
row = conn.execute(
|
||||
"SELECT * FROM thoughts WHERE id = ?", (current_id,)
|
||||
).fetchone()
|
||||
row = conn.execute("SELECT * FROM thoughts WHERE id = ?", (current_id,)).fetchone()
|
||||
if not row:
|
||||
break
|
||||
chain.append(_row_to_thought(row))
|
||||
@@ -254,8 +249,10 @@ class ThinkingEngine:
|
||||
def _seed_from_swarm(self) -> str:
|
||||
"""Gather recent swarm activity as thought seed."""
|
||||
try:
|
||||
from timmy.briefing import _gather_swarm_summary, _gather_task_queue_summary
|
||||
from datetime import timedelta
|
||||
|
||||
from timmy.briefing import _gather_swarm_summary, _gather_task_queue_summary
|
||||
|
||||
since = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
swarm = _gather_swarm_summary(since)
|
||||
tasks = _gather_task_queue_summary()
|
||||
@@ -272,6 +269,7 @@ class ThinkingEngine:
|
||||
"""Gather memory context as thought seed."""
|
||||
try:
|
||||
from timmy.memory_system import memory_system
|
||||
|
||||
context = memory_system.get_system_context()
|
||||
if context:
|
||||
# Truncate to a reasonable size for a thought seed
|
||||
@@ -299,10 +297,12 @@ class ThinkingEngine:
|
||||
"""
|
||||
try:
|
||||
from timmy.session import chat
|
||||
|
||||
return chat(prompt, session_id="thinking")
|
||||
except Exception:
|
||||
# Fallback: create a fresh agent
|
||||
from timmy.agent import create_timmy
|
||||
|
||||
agent = create_timmy()
|
||||
run = agent.run(prompt, stream=False)
|
||||
return run.content if hasattr(run, "content") else str(run)
|
||||
@@ -323,8 +323,7 @@ class ThinkingEngine:
|
||||
INSERT INTO thoughts (id, content, seed_type, parent_id, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(thought.id, thought.content, thought.seed_type,
|
||||
thought.parent_id, thought.created_at),
|
||||
(thought.id, thought.content, thought.seed_type, thought.parent_id, thought.created_at),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
@@ -333,7 +332,8 @@ class ThinkingEngine:
|
||||
def _log_event(self, thought: Thought) -> None:
|
||||
"""Log the thought as a swarm event."""
|
||||
try:
|
||||
from swarm.event_log import log_event, EventType
|
||||
from swarm.event_log import EventType, log_event
|
||||
|
||||
log_event(
|
||||
EventType.TIMMY_THOUGHT,
|
||||
source="thinking-engine",
|
||||
@@ -351,12 +351,16 @@ class ThinkingEngine:
|
||||
"""Broadcast the thought to WebSocket clients."""
|
||||
try:
|
||||
from infrastructure.ws_manager.handler import ws_manager
|
||||
await ws_manager.broadcast("timmy_thought", {
|
||||
"thought_id": thought.id,
|
||||
"content": thought.content,
|
||||
"seed_type": thought.seed_type,
|
||||
"created_at": thought.created_at,
|
||||
})
|
||||
|
||||
await ws_manager.broadcast(
|
||||
"timmy_thought",
|
||||
{
|
||||
"thought_id": thought.id,
|
||||
"content": thought.content,
|
||||
"seed_type": thought.seed_type,
|
||||
"created_at": thought.created_at,
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to broadcast thought: %s", exc)
|
||||
|
||||
|
||||
@@ -227,11 +227,7 @@ def create_aider_tool(base_path: Path):
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
return (
|
||||
result.stdout
|
||||
if result.stdout
|
||||
else "Code changes applied successfully"
|
||||
)
|
||||
return result.stdout if result.stdout else "Code changes applied successfully"
|
||||
else:
|
||||
return f"Aider error: {result.stderr}"
|
||||
except FileNotFoundError:
|
||||
@@ -354,7 +350,7 @@ def consult_grok(query: str) -> str:
|
||||
Grok's response text, or an error/status message.
|
||||
"""
|
||||
from config import settings
|
||||
from timmy.backends import grok_available, get_grok_backend
|
||||
from timmy.backends import get_grok_backend, grok_available
|
||||
|
||||
if not grok_available():
|
||||
return (
|
||||
@@ -385,9 +381,7 @@ def consult_grok(query: str) -> str:
|
||||
ln = get_ln_backend()
|
||||
sats = min(settings.grok_max_sats_per_query, 100)
|
||||
inv = ln.create_invoice(sats, f"Grok query: {query[:50]}")
|
||||
invoice_info = (
|
||||
f"\n[Lightning invoice: {sats} sats — {inv.payment_request[:40]}...]"
|
||||
)
|
||||
invoice_info = f"\n[Lightning invoice: {sats} sats — {inv.payment_request[:40]}...]"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -447,7 +441,7 @@ def create_full_toolkit(base_dir: str | Path | None = None):
|
||||
|
||||
# Memory search and write — persistent recall across all channels
|
||||
try:
|
||||
from timmy.semantic_memory import memory_search, memory_write, memory_read
|
||||
from timmy.semantic_memory import memory_read, memory_search, memory_write
|
||||
|
||||
toolkit.register(memory_search, name="memory_search")
|
||||
toolkit.register(memory_write, name="memory_write")
|
||||
@@ -473,6 +467,7 @@ def create_full_toolkit(base_dir: str | Path | None = None):
|
||||
Task ID and confirmation that background execution has started.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
task_id = None
|
||||
|
||||
async def _launch():
|
||||
@@ -502,11 +497,7 @@ def create_full_toolkit(base_dir: str | Path | None = None):
|
||||
|
||||
# System introspection - query runtime environment (sovereign self-knowledge)
|
||||
try:
|
||||
from timmy.tools_intro import (
|
||||
get_system_info,
|
||||
check_ollama_health,
|
||||
get_memory_status,
|
||||
)
|
||||
from timmy.tools_intro import check_ollama_health, get_memory_status, get_system_info
|
||||
|
||||
toolkit.register(get_system_info, name="get_system_info")
|
||||
toolkit.register(check_ollama_health, name="check_ollama_health")
|
||||
@@ -526,6 +517,60 @@ def create_full_toolkit(base_dir: str | Path | None = None):
|
||||
return toolkit
|
||||
|
||||
|
||||
def create_experiment_tools(base_dir: str | Path | None = None):
|
||||
"""Create tools for the experiment agent (Lab).
|
||||
|
||||
Includes: prepare_experiment, run_experiment, evaluate_result,
|
||||
plus shell + file ops for editing training code.
|
||||
"""
|
||||
if not _AGNO_TOOLS_AVAILABLE:
|
||||
raise ImportError(f"Agno tools not available: {_ImportError}")
|
||||
|
||||
from config import settings
|
||||
|
||||
toolkit = Toolkit(name="experiment")
|
||||
|
||||
from timmy.autoresearch import evaluate_result, prepare_experiment, run_experiment
|
||||
|
||||
workspace = (
|
||||
Path(base_dir) if base_dir else Path(settings.repo_root) / settings.autoresearch_workspace
|
||||
)
|
||||
|
||||
def _prepare(repo_url: str = "https://github.com/karpathy/autoresearch.git") -> str:
|
||||
"""Clone and prepare an autoresearch experiment workspace."""
|
||||
return prepare_experiment(workspace, repo_url)
|
||||
|
||||
def _run(timeout: int = 0) -> str:
|
||||
"""Run a single training experiment with wall-clock timeout."""
|
||||
t = timeout or settings.autoresearch_time_budget
|
||||
result = run_experiment(workspace, timeout=t, metric_name=settings.autoresearch_metric)
|
||||
if result["success"] and result["metric"] is not None:
|
||||
return (
|
||||
f"{settings.autoresearch_metric}: {result['metric']:.4f} ({result['duration_s']}s)"
|
||||
)
|
||||
return result.get("error") or "Experiment failed"
|
||||
|
||||
def _evaluate(current: float, baseline: float) -> str:
|
||||
"""Compare current metric against baseline."""
|
||||
return evaluate_result(current, baseline, metric_name=settings.autoresearch_metric)
|
||||
|
||||
toolkit.register(_prepare, name="prepare_experiment")
|
||||
toolkit.register(_run, name="run_experiment")
|
||||
toolkit.register(_evaluate, name="evaluate_result")
|
||||
|
||||
# Also give Lab access to file + shell tools for editing train.py
|
||||
shell_tools = ShellTools()
|
||||
toolkit.register(shell_tools.run_shell_command, name="shell")
|
||||
|
||||
base_path = Path(base_dir) if base_dir else Path(settings.repo_root)
|
||||
file_tools = FileTools(base_dir=base_path)
|
||||
toolkit.register(file_tools.read_file, name="read_file")
|
||||
toolkit.register(file_tools.save_file, name="write_file")
|
||||
toolkit.register(file_tools.list_files, name="list_files")
|
||||
|
||||
return toolkit
|
||||
|
||||
|
||||
# Mapping of agent IDs to their toolkits
|
||||
AGENT_TOOLKITS: dict[str, Callable[[], Toolkit]] = {
|
||||
"echo": create_research_tools,
|
||||
@@ -534,6 +579,7 @@ AGENT_TOOLKITS: dict[str, Callable[[], Toolkit]] = {
|
||||
"seer": create_data_tools,
|
||||
"forge": create_code_tools,
|
||||
"quill": create_writing_tools,
|
||||
"lab": create_experiment_tools,
|
||||
"pixel": lambda base_dir=None: _create_stub_toolkit("pixel"),
|
||||
"lyra": lambda base_dir=None: _create_stub_toolkit("lyra"),
|
||||
"reel": lambda base_dir=None: _create_stub_toolkit("reel"),
|
||||
@@ -553,9 +599,7 @@ def _create_stub_toolkit(name: str):
|
||||
return toolkit
|
||||
|
||||
|
||||
def get_tools_for_agent(
|
||||
agent_id: str, base_dir: str | Path | None = None
|
||||
) -> Toolkit | None:
|
||||
def get_tools_for_agent(agent_id: str, base_dir: str | Path | None = None) -> Toolkit | None:
|
||||
"""Get the appropriate toolkit for an agent.
|
||||
|
||||
Args:
|
||||
@@ -643,6 +687,21 @@ def get_all_available_tools() -> dict[str, dict]:
|
||||
"description": "Local AI coding assistant using Ollama (qwen2.5:14b or deepseek-coder)",
|
||||
"available_in": ["forge", "orchestrator"],
|
||||
},
|
||||
"prepare_experiment": {
|
||||
"name": "Prepare Experiment",
|
||||
"description": "Clone autoresearch repo and run data preparation for ML experiments",
|
||||
"available_in": ["lab", "orchestrator"],
|
||||
},
|
||||
"run_experiment": {
|
||||
"name": "Run Experiment",
|
||||
"description": "Execute a time-boxed ML training experiment and capture metrics",
|
||||
"available_in": ["lab", "orchestrator"],
|
||||
},
|
||||
"evaluate_result": {
|
||||
"name": "Evaluate Result",
|
||||
"description": "Compare experiment metric against baseline to assess improvement",
|
||||
"available_in": ["lab", "orchestrator"],
|
||||
},
|
||||
}
|
||||
|
||||
# ── Git tools ─────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -20,7 +20,9 @@ _VALID_AGENTS: dict[str, str] = {
|
||||
}
|
||||
|
||||
|
||||
def delegate_task(agent_name: str, task_description: str, priority: str = "normal") -> dict[str, Any]:
|
||||
def delegate_task(
|
||||
agent_name: str, task_description: str, priority: str = "normal"
|
||||
) -> dict[str, Any]:
|
||||
"""Record a delegation intent to another agent.
|
||||
|
||||
Args:
|
||||
@@ -44,7 +46,9 @@ def delegate_task(agent_name: str, task_description: str, priority: str = "norma
|
||||
if priority not in valid_priorities:
|
||||
priority = "normal"
|
||||
|
||||
logger.info("Delegation intent: %s → %s (priority=%s)", agent_name, task_description[:80], priority)
|
||||
logger.info(
|
||||
"Delegation intent: %s → %s (priority=%s)", agent_name, task_description[:80], priority
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
|
||||
@@ -65,9 +65,7 @@ def _get_ollama_model() -> str:
|
||||
models = response.json().get("models", [])
|
||||
# Check if configured model is available
|
||||
for model in models:
|
||||
if model.get("name", "").startswith(
|
||||
settings.ollama_model.split(":")[0]
|
||||
):
|
||||
if model.get("name", "").startswith(settings.ollama_model.split(":")[0]):
|
||||
return settings.ollama_model
|
||||
|
||||
# Fallback: return configured model
|
||||
@@ -139,9 +137,7 @@ def get_memory_status() -> dict[str, Any]:
|
||||
if tier1_exists:
|
||||
lines = memory_md.read_text().splitlines()
|
||||
tier1_info["line_count"] = len(lines)
|
||||
tier1_info["sections"] = [
|
||||
ln.lstrip("# ").strip() for ln in lines if ln.startswith("## ")
|
||||
]
|
||||
tier1_info["sections"] = [ln.lstrip("# ").strip() for ln in lines if ln.startswith("## ")]
|
||||
|
||||
# Vault — scan all subdirs under memory/
|
||||
vault_root = repo_root / "memory"
|
||||
@@ -233,13 +229,15 @@ def get_agent_roster() -> dict[str, Any]:
|
||||
|
||||
roster = []
|
||||
for persona in _PERSONAS:
|
||||
roster.append({
|
||||
"id": persona["agent_id"],
|
||||
"name": persona["name"],
|
||||
"status": "available",
|
||||
"capabilities": ", ".join(persona.get("tools", [])),
|
||||
"role": persona.get("role", ""),
|
||||
})
|
||||
roster.append(
|
||||
{
|
||||
"id": persona["agent_id"],
|
||||
"name": persona["name"],
|
||||
"status": "available",
|
||||
"capabilities": ", ".join(persona.get("tools", [])),
|
||||
"role": persona.get("role", ""),
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"agents": roster,
|
||||
|
||||
@@ -41,7 +41,7 @@ class StatusResponse(BaseModel):
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Simple in-memory rate limiting middleware."""
|
||||
|
||||
|
||||
def __init__(self, app, limit: int = 10, window: int = 60):
|
||||
super().__init__(app)
|
||||
self.limit = limit
|
||||
@@ -53,22 +53,20 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
if request.url.path == "/serve/chat" and request.method == "POST":
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
now = time.time()
|
||||
|
||||
|
||||
# Clean up old requests
|
||||
self.requests[client_ip] = [
|
||||
t for t in self.requests[client_ip]
|
||||
if now - t < self.window
|
||||
t for t in self.requests[client_ip] if now - t < self.window
|
||||
]
|
||||
|
||||
|
||||
if len(self.requests[client_ip]) >= self.limit:
|
||||
logger.warning("Rate limit exceeded for %s", client_ip)
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"error": "Rate limit exceeded. Try again later."}
|
||||
status_code=429, content={"error": "Rate limit exceeded. Try again later."}
|
||||
)
|
||||
|
||||
|
||||
self.requests[client_ip].append(now)
|
||||
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ def start(
|
||||
return
|
||||
|
||||
import uvicorn
|
||||
|
||||
from timmy_serve.app import create_timmy_serve_app
|
||||
|
||||
serve_app = create_timmy_serve_app()
|
||||
|
||||
@@ -23,9 +23,7 @@ class AgentMessage:
|
||||
to_agent: str = ""
|
||||
content: str = ""
|
||||
message_type: str = "text" # text | command | response | error
|
||||
timestamp: str = field(
|
||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||
replied: bool = False
|
||||
|
||||
|
||||
@@ -56,7 +54,10 @@ class InterAgentMessenger:
|
||||
self._all_messages.append(msg)
|
||||
logger.info(
|
||||
"Message %s → %s: %s (%s)",
|
||||
from_agent, to_agent, content[:50], message_type,
|
||||
from_agent,
|
||||
to_agent,
|
||||
content[:50],
|
||||
message_type,
|
||||
)
|
||||
return msg
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ class VoiceTTS:
|
||||
def _init_engine(self) -> None:
|
||||
try:
|
||||
import pyttsx3
|
||||
|
||||
self._engine = pyttsx3.init()
|
||||
self._engine.setProperty("rate", self._rate)
|
||||
self._engine.setProperty("volume", self._volume)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user