forked from Rockachopa/Timmy-time-dashboard
feat: code quality audit + autoresearch integration + infra hardening (#150)
This commit is contained in:
committed by
GitHub
parent
fd0ede0d51
commit
ae3bb1cc21
20
.env.example
20
.env.example
@@ -71,3 +71,23 @@
|
|||||||
# Requires: pip install ".[discord]"
|
# Requires: pip install ".[discord]"
|
||||||
# Optional: pip install pyzbar Pillow (for QR code invite detection from screenshots)
|
# Optional: pip install pyzbar Pillow (for QR code invite detection from screenshots)
|
||||||
# DISCORD_TOKEN=
|
# DISCORD_TOKEN=
|
||||||
|
|
||||||
|
# ── Autoresearch — autonomous ML experiment loops ────────────────────────────
|
||||||
|
# Enable autonomous experiment loops (Karpathy autoresearch pattern).
|
||||||
|
# AUTORESEARCH_ENABLED=false
|
||||||
|
# AUTORESEARCH_WORKSPACE=data/experiments
|
||||||
|
# AUTORESEARCH_TIME_BUDGET=300
|
||||||
|
# AUTORESEARCH_MAX_ITERATIONS=100
|
||||||
|
# AUTORESEARCH_METRIC=val_bpb
|
||||||
|
|
||||||
|
# ── Docker Production ────────────────────────────────────────────────────────
|
||||||
|
# When deploying with docker-compose.prod.yml:
|
||||||
|
# - Containers run as non-root user "timmy" (defined in Dockerfile)
|
||||||
|
# - No source bind mounts — code is baked into the image
|
||||||
|
# - Set TIMMY_ENV=production to enforce security checks
|
||||||
|
# - All secrets below MUST be set before production deployment
|
||||||
|
#
|
||||||
|
# Taskosaur secrets (change from dev defaults):
|
||||||
|
# TASKOSAUR_JWT_SECRET=<generate with: python3 -c "import secrets; print(secrets.token_hex(32))">
|
||||||
|
# TASKOSAUR_JWT_REFRESH_SECRET=<generate with: python3 -c "import secrets; print(secrets.token_hex(32))">
|
||||||
|
# TASKOSAUR_ENCRYPTION_KEY=<generate with: python3 -c "import secrets; print(secrets.token_hex(32))">
|
||||||
|
|||||||
40
.github/workflows/tests.yml
vendored
40
.github/workflows/tests.yml
vendored
@@ -7,8 +7,30 @@ on:
|
|||||||
branches: ["**"]
|
branches: ["**"]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
lint:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.11"
|
||||||
|
|
||||||
|
- name: Install linters
|
||||||
|
run: pip install black==23.12.1 isort==5.13.2 bandit==1.7.5
|
||||||
|
|
||||||
|
- name: Check formatting (black)
|
||||||
|
run: black --check --line-length 100 src/ tests/
|
||||||
|
|
||||||
|
- name: Check import order (isort)
|
||||||
|
run: isort --check --profile black --line-length 100 src/ tests/
|
||||||
|
|
||||||
|
- name: Security scan (bandit)
|
||||||
|
run: bandit -r src/ -ll -s B101,B104,B307,B310,B324,B601,B608 -q
|
||||||
|
|
||||||
test:
|
test:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
needs: lint
|
||||||
|
|
||||||
# Required for publish-unit-test-result-action to post check runs and PR comments
|
# Required for publish-unit-test-result-action to post check runs and PR comments
|
||||||
permissions:
|
permissions:
|
||||||
@@ -22,7 +44,15 @@ jobs:
|
|||||||
- uses: actions/setup-python@v5
|
- uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
cache: "pip"
|
|
||||||
|
- name: Cache Poetry virtualenv
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
~/.cache/pypoetry
|
||||||
|
~/.cache/pip
|
||||||
|
key: poetry-${{ hashFiles('poetry.lock') }}
|
||||||
|
restore-keys: poetry-
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
@@ -60,3 +90,11 @@ jobs:
|
|||||||
name: coverage-report
|
name: coverage-report
|
||||||
path: reports/coverage.xml
|
path: reports/coverage.xml
|
||||||
retention-days: 14
|
retention-days: 14
|
||||||
|
|
||||||
|
docker-build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Build Docker image
|
||||||
|
run: DOCKER_BUILDKIT=1 docker build -t timmy-time:ci .
|
||||||
|
|||||||
@@ -51,12 +51,12 @@ repos:
|
|||||||
exclude: ^tests/
|
exclude: ^tests/
|
||||||
stages: [manual]
|
stages: [manual]
|
||||||
|
|
||||||
# Full test suite with 30-second wall-clock limit.
|
# Unit tests only with 30-second wall-clock limit.
|
||||||
# Current baseline: ~18s. If tests get slow, this blocks the commit.
|
# Runs only fast unit tests on commit; full suite runs in CI.
|
||||||
- repo: local
|
- repo: local
|
||||||
hooks:
|
hooks:
|
||||||
- id: pytest-fast
|
- id: pytest-fast
|
||||||
name: pytest (30s limit)
|
name: pytest unit (30s limit)
|
||||||
entry: timeout 30 poetry run pytest
|
entry: timeout 30 poetry run pytest
|
||||||
language: system
|
language: system
|
||||||
types: [python]
|
types: [python]
|
||||||
@@ -68,4 +68,8 @@ repos:
|
|||||||
- -q
|
- -q
|
||||||
- --tb=short
|
- --tb=short
|
||||||
- --timeout=10
|
- --timeout=10
|
||||||
|
- -m
|
||||||
|
- unit
|
||||||
|
- -p
|
||||||
|
- no:xdist
|
||||||
verbose: true
|
verbose: true
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ make test-cov # With coverage (term-missing + XML)
|
|||||||
- **Test mode:** `TIMMY_TEST_MODE=1` set automatically in conftest
|
- **Test mode:** `TIMMY_TEST_MODE=1` set automatically in conftest
|
||||||
- **FastAPI testing:** Use the `client` fixture
|
- **FastAPI testing:** Use the `client` fixture
|
||||||
- **Async:** `asyncio_mode = "auto"` — async tests detected automatically
|
- **Async:** `asyncio_mode = "auto"` — async tests detected automatically
|
||||||
- **Coverage threshold:** 60% (`fail_under` in `pyproject.toml`)
|
- **Coverage threshold:** 73% (`fail_under` in `pyproject.toml`)
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
15
Dockerfile
15
Dockerfile
@@ -11,7 +11,7 @@
|
|||||||
# timmy-time:latest \
|
# timmy-time:latest \
|
||||||
# python -m swarm.agent_runner --agent-id w1 --name Worker-1
|
# python -m swarm.agent_runner --agent-id w1 --name Worker-1
|
||||||
|
|
||||||
# ── Stage 1: Builder — export deps via Poetry, install via pip ──────────────
|
# ── Stage 1: Builder — install deps via Poetry ──────────────────────────────
|
||||||
FROM python:3.12-slim AS builder
|
FROM python:3.12-slim AS builder
|
||||||
|
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
@@ -20,18 +20,15 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||||||
|
|
||||||
WORKDIR /build
|
WORKDIR /build
|
||||||
|
|
||||||
# Install Poetry + export plugin (only needed for export, not in runtime)
|
# Install Poetry (only needed to resolve deps, not in runtime)
|
||||||
RUN pip install --no-cache-dir poetry poetry-plugin-export
|
RUN pip install --no-cache-dir poetry
|
||||||
|
|
||||||
# Copy dependency files only (layer caching)
|
# Copy dependency files only (layer caching)
|
||||||
COPY pyproject.toml poetry.lock ./
|
COPY pyproject.toml poetry.lock ./
|
||||||
|
|
||||||
# Export pinned requirements and install with pip cache mount
|
# Install deps directly from lock file (no virtualenv, no export plugin needed)
|
||||||
RUN poetry export --extras swarm --extras telegram --extras discord --without-hashes \
|
RUN poetry config virtualenvs.create false && \
|
||||||
-f requirements.txt -o requirements.txt
|
poetry install --only main --extras telegram --extras discord --no-interaction
|
||||||
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
|
||||||
pip install --no-cache-dir -r requirements.txt
|
|
||||||
|
|
||||||
# ── Stage 2: Runtime ───────────────────────────────────────────────────────
|
# ── Stage 2: Runtime ───────────────────────────────────────────────────────
|
||||||
FROM python:3.12-slim AS base
|
FROM python:3.12-slim AS base
|
||||||
|
|||||||
5
Makefile
5
Makefile
@@ -210,6 +210,11 @@ docker-up:
|
|||||||
mkdir -p data
|
mkdir -p data
|
||||||
docker compose up -d dashboard
|
docker compose up -d dashboard
|
||||||
|
|
||||||
|
docker-prod:
|
||||||
|
mkdir -p data
|
||||||
|
DOCKER_BUILDKIT=1 docker build -t timmy-time:latest .
|
||||||
|
docker compose -f docker-compose.yml -f docker-compose.prod.yml up -d dashboard
|
||||||
|
|
||||||
docker-down:
|
docker-down:
|
||||||
docker compose down
|
docker compose down
|
||||||
|
|
||||||
|
|||||||
56
docker-compose.prod.yml
Normal file
56
docker-compose.prod.yml
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
# ── Production Compose Overlay ─────────────────────────────────────────────────
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# make docker-prod # build + start with prod settings
|
||||||
|
# docker compose -f docker-compose.yml -f docker-compose.prod.yml up -d
|
||||||
|
#
|
||||||
|
# Differences from dev:
|
||||||
|
# - Runs as non-root user (timmy) from Dockerfile
|
||||||
|
# - No bind mounts — uses image-baked source only
|
||||||
|
# - Named volumes only (no host path dependencies)
|
||||||
|
# - Read-only root filesystem with tmpfs for /tmp
|
||||||
|
# - Resource limits enforced
|
||||||
|
# - Secrets passed via environment variables (set in .env)
|
||||||
|
#
|
||||||
|
# Security note: Set all secrets in .env before deploying.
|
||||||
|
# Required: L402_HMAC_SECRET, L402_MACAROON_SECRET
|
||||||
|
# Recommended: TASKOSAUR_JWT_SECRET, TASKOSAUR_JWT_REFRESH_SECRET, TASKOSAUR_ENCRYPTION_KEY
|
||||||
|
|
||||||
|
services:
|
||||||
|
|
||||||
|
dashboard:
|
||||||
|
# Remove dev-only root user override — use Dockerfile's USER timmy
|
||||||
|
user: ""
|
||||||
|
read_only: true
|
||||||
|
tmpfs:
|
||||||
|
- /tmp:size=100M
|
||||||
|
volumes:
|
||||||
|
# Override: named volume only, no host bind mounts
|
||||||
|
- timmy-data:/app/data
|
||||||
|
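# Intent: drop the ./src and ./static bind mounts (use baked-in image files). NOTE(review): Compose merges `volumes` across files by container path, so omitting them here may not remove mounts declared in docker-compose.yml — confirm, and use `!override` if needed.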
|
||||||
|
environment:
|
||||||
|
DEBUG: "false"
|
||||||
|
TIMMY_ENV: "production"
|
||||||
|
deploy:
|
||||||
|
resources:
|
||||||
|
limits:
|
||||||
|
cpus: "2.0"
|
||||||
|
memory: 2G
|
||||||
|
|
||||||
|
celery-worker:
|
||||||
|
user: ""
|
||||||
|
read_only: true
|
||||||
|
tmpfs:
|
||||||
|
- /tmp:size=100M
|
||||||
|
volumes:
|
||||||
|
- timmy-data:/app/data
|
||||||
|
deploy:
|
||||||
|
resources:
|
||||||
|
limits:
|
||||||
|
cpus: "1.0"
|
||||||
|
memory: 1G
|
||||||
|
|
||||||
|
# Override timmy-data to use a simple named volume (no host bind)
|
||||||
|
volumes:
|
||||||
|
timmy-data:
|
||||||
|
driver: local
|
||||||
@@ -97,6 +97,12 @@ markers = [
|
|||||||
"skip_ci: Skip in CI environment (local development only)",
|
"skip_ci: Skip in CI environment (local development only)",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.isort]
|
||||||
|
profile = "black"
|
||||||
|
line_length = 100
|
||||||
|
src_paths = ["src", "tests"]
|
||||||
|
known_first_party = ["brain", "config", "dashboard", "infrastructure", "integrations", "spark", "swarm", "timmy", "timmy_serve"]
|
||||||
|
|
||||||
[tool.coverage.run]
|
[tool.coverage.run]
|
||||||
source = ["src"]
|
source = ["src"]
|
||||||
omit = [
|
omit = [
|
||||||
|
|||||||
@@ -11,9 +11,9 @@ upgrade to distributed rqlite over Tailscale — same API, replicated.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from brain.client import BrainClient
|
from brain.client import BrainClient
|
||||||
from brain.worker import DistributedWorker
|
|
||||||
from brain.embeddings import LocalEmbedder
|
from brain.embeddings import LocalEmbedder
|
||||||
from brain.memory import UnifiedMemory, get_memory
|
from brain.memory import UnifiedMemory, get_memory
|
||||||
|
from brain.worker import DistributedWorker
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BrainClient",
|
"BrainClient",
|
||||||
|
|||||||
@@ -21,52 +21,54 @@ DEFAULT_RQLITE_URL = "http://localhost:4001"
|
|||||||
|
|
||||||
class BrainClient:
|
class BrainClient:
|
||||||
"""Client for distributed brain (rqlite).
|
"""Client for distributed brain (rqlite).
|
||||||
|
|
||||||
Connects to local rqlite instance, which handles replication.
|
Connects to local rqlite instance, which handles replication.
|
||||||
All writes go to leader, reads can come from local node.
|
All writes go to leader, reads can come from local node.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, rqlite_url: Optional[str] = None, node_id: Optional[str] = None):
|
def __init__(self, rqlite_url: Optional[str] = None, node_id: Optional[str] = None):
|
||||||
from config import settings
|
from config import settings
|
||||||
|
|
||||||
self.rqlite_url = rqlite_url or settings.rqlite_url or DEFAULT_RQLITE_URL
|
self.rqlite_url = rqlite_url or settings.rqlite_url or DEFAULT_RQLITE_URL
|
||||||
self.node_id = node_id or f"{socket.gethostname()}-{os.getpid()}"
|
self.node_id = node_id or f"{socket.gethostname()}-{os.getpid()}"
|
||||||
self.source = self._detect_source()
|
self.source = self._detect_source()
|
||||||
self._client = httpx.AsyncClient(timeout=30)
|
self._client = httpx.AsyncClient(timeout=30)
|
||||||
|
|
||||||
def _detect_source(self) -> str:
|
def _detect_source(self) -> str:
|
||||||
"""Detect what component is using the brain."""
|
"""Detect what component is using the brain."""
|
||||||
# Could be 'timmy', 'zeroclaw', 'worker', etc.
|
# Could be 'timmy', 'zeroclaw', 'worker', etc.
|
||||||
# For now, infer from context or env
|
# For now, infer from context or env
|
||||||
from config import settings
|
from config import settings
|
||||||
|
|
||||||
return settings.brain_source
|
return settings.brain_source
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────────
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
# Memory Operations
|
# Memory Operations
|
||||||
# ──────────────────────────────────────────────────────────────────────────
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def remember(
|
async def remember(
|
||||||
self,
|
self,
|
||||||
content: str,
|
content: str,
|
||||||
tags: Optional[List[str]] = None,
|
tags: Optional[List[str]] = None,
|
||||||
source: Optional[str] = None,
|
source: Optional[str] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Store a memory with embedding.
|
"""Store a memory with embedding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: Text content to remember
|
content: Text content to remember
|
||||||
tags: Optional list of tags (e.g., ['shell', 'result'])
|
tags: Optional list of tags (e.g., ['shell', 'result'])
|
||||||
source: Source identifier (defaults to self.source)
|
source: Source identifier (defaults to self.source)
|
||||||
metadata: Additional JSON-serializable metadata
|
metadata: Additional JSON-serializable metadata
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with 'id' and 'status'
|
Dict with 'id' and 'status'
|
||||||
"""
|
"""
|
||||||
from brain.embeddings import get_embedder
|
from brain.embeddings import get_embedder
|
||||||
|
|
||||||
embedder = get_embedder()
|
embedder = get_embedder()
|
||||||
embedding_bytes = embedder.encode_single(content)
|
embedding_bytes = embedder.encode_single(content)
|
||||||
|
|
||||||
query = """
|
query = """
|
||||||
INSERT INTO memories (content, embedding, source, tags, metadata, created_at)
|
INSERT INTO memories (content, embedding, source, tags, metadata, created_at)
|
||||||
VALUES (?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?)
|
||||||
@@ -77,100 +79,90 @@ class BrainClient:
|
|||||||
source or self.source,
|
source or self.source,
|
||||||
json.dumps(tags or []),
|
json.dumps(tags or []),
|
||||||
json.dumps(metadata or {}),
|
json.dumps(metadata or {}),
|
||||||
datetime.utcnow().isoformat()
|
datetime.utcnow().isoformat(),
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resp = await self._client.post(
|
resp = await self._client.post(f"{self.rqlite_url}/db/execute", json=[query, params])
|
||||||
f"{self.rqlite_url}/db/execute",
|
|
||||||
json=[query, params]
|
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
result = resp.json()
|
result = resp.json()
|
||||||
|
|
||||||
# Extract inserted ID
|
# Extract inserted ID
|
||||||
last_id = None
|
last_id = None
|
||||||
if "results" in result and result["results"]:
|
if "results" in result and result["results"]:
|
||||||
last_id = result["results"][0].get("last_insert_id")
|
last_id = result["results"][0].get("last_insert_id")
|
||||||
|
|
||||||
logger.debug(f"Stored memory {last_id}: {content[:50]}...")
|
logger.debug(f"Stored memory {last_id}: {content[:50]}...")
|
||||||
return {"id": last_id, "status": "stored"}
|
return {"id": last_id, "status": "stored"}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to store memory: {e}")
|
logger.error(f"Failed to store memory: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def recall(
|
async def recall(
|
||||||
self,
|
self, query: str, limit: int = 5, sources: Optional[List[str]] = None
|
||||||
query: str,
|
|
||||||
limit: int = 5,
|
|
||||||
sources: Optional[List[str]] = None
|
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""Semantic search for memories.
|
"""Semantic search for memories.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: Search query text
|
query: Search query text
|
||||||
limit: Max results to return
|
limit: Max results to return
|
||||||
sources: Filter by source(s) (e.g., ['timmy', 'user'])
|
sources: Filter by source(s) (e.g., ['timmy', 'user'])
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
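List of dicts with 'content', 'source', 'metadata' and 'distance' keys (empty list if the search fails)
|
List of dicts with 'content', 'source', 'metadata' and 'distance' keys (empty list if the search fails)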
|
||||||
"""
|
"""
|
||||||
from brain.embeddings import get_embedder
|
from brain.embeddings import get_embedder
|
||||||
|
|
||||||
embedder = get_embedder()
|
embedder = get_embedder()
|
||||||
query_emb = embedder.encode_single(query)
|
query_emb = embedder.encode_single(query)
|
||||||
|
|
||||||
# rqlite with sqlite-vec extension for vector search
|
# rqlite with sqlite-vec extension for vector search
|
||||||
sql = "SELECT content, source, metadata, distance FROM memories WHERE embedding MATCH ?"
|
sql = "SELECT content, source, metadata, distance FROM memories WHERE embedding MATCH ?"
|
||||||
params = [query_emb]
|
params = [query_emb]
|
||||||
|
|
||||||
if sources:
|
if sources:
|
||||||
placeholders = ",".join(["?"] * len(sources))
|
placeholders = ",".join(["?"] * len(sources))
|
||||||
sql += f" AND source IN ({placeholders})"
|
sql += f" AND source IN ({placeholders})"
|
||||||
params.extend(sources)
|
params.extend(sources)
|
||||||
|
|
||||||
sql += " ORDER BY distance LIMIT ?"
|
sql += " ORDER BY distance LIMIT ?"
|
||||||
params.append(limit)
|
params.append(limit)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resp = await self._client.post(
|
resp = await self._client.post(f"{self.rqlite_url}/db/query", json=[sql, params])
|
||||||
f"{self.rqlite_url}/db/query",
|
|
||||||
json=[sql, params]
|
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
result = resp.json()
|
result = resp.json()
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
if "results" in result and result["results"]:
|
if "results" in result and result["results"]:
|
||||||
for row in result["results"][0].get("rows", []):
|
for row in result["results"][0].get("rows", []):
|
||||||
results.append({
|
results.append(
|
||||||
"content": row[0],
|
{
|
||||||
"source": row[1],
|
"content": row[0],
|
||||||
"metadata": json.loads(row[2]) if row[2] else {},
|
"source": row[1],
|
||||||
"distance": row[3]
|
"metadata": json.loads(row[2]) if row[2] else {},
|
||||||
})
|
"distance": row[3],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to search memories: {e}")
|
logger.error(f"Failed to search memories: {e}")
|
||||||
# Graceful fallback - return empty list
|
# Graceful fallback - return empty list
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def get_recent(
|
async def get_recent(
|
||||||
self,
|
self, hours: int = 24, limit: int = 20, sources: Optional[List[str]] = None
|
||||||
hours: int = 24,
|
|
||||||
limit: int = 20,
|
|
||||||
sources: Optional[List[str]] = None
|
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""Get recent memories by time.
|
"""Get recent memories by time.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hours: Look back this many hours
|
hours: Look back this many hours
|
||||||
limit: Max results
|
limit: Max results
|
||||||
sources: Optional source filter
|
sources: Optional source filter
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of memory dicts
|
List of memory dicts
|
||||||
"""
|
"""
|
||||||
@@ -180,84 +172,83 @@ class BrainClient:
|
|||||||
WHERE created_at > datetime('now', ?)
|
WHERE created_at > datetime('now', ?)
|
||||||
"""
|
"""
|
||||||
params = [f"-{hours} hours"]
|
params = [f"-{hours} hours"]
|
||||||
|
|
||||||
if sources:
|
if sources:
|
||||||
placeholders = ",".join(["?"] * len(sources))
|
placeholders = ",".join(["?"] * len(sources))
|
||||||
sql += f" AND source IN ({placeholders})"
|
sql += f" AND source IN ({placeholders})"
|
||||||
params.extend(sources)
|
params.extend(sources)
|
||||||
|
|
||||||
sql += " ORDER BY created_at DESC LIMIT ?"
|
sql += " ORDER BY created_at DESC LIMIT ?"
|
||||||
params.append(limit)
|
params.append(limit)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resp = await self._client.post(
|
resp = await self._client.post(f"{self.rqlite_url}/db/query", json=[sql, params])
|
||||||
f"{self.rqlite_url}/db/query",
|
|
||||||
json=[sql, params]
|
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
result = resp.json()
|
result = resp.json()
|
||||||
|
|
||||||
memories = []
|
memories = []
|
||||||
if "results" in result and result["results"]:
|
if "results" in result and result["results"]:
|
||||||
for row in result["results"][0].get("rows", []):
|
for row in result["results"][0].get("rows", []):
|
||||||
memories.append({
|
memories.append(
|
||||||
"id": row[0],
|
{
|
||||||
"content": row[1],
|
"id": row[0],
|
||||||
"source": row[2],
|
"content": row[1],
|
||||||
"tags": json.loads(row[3]) if row[3] else [],
|
"source": row[2],
|
||||||
"metadata": json.loads(row[4]) if row[4] else {},
|
"tags": json.loads(row[3]) if row[3] else [],
|
||||||
"created_at": row[5]
|
"metadata": json.loads(row[4]) if row[4] else {},
|
||||||
})
|
"created_at": row[5],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return memories
|
return memories
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get recent memories: {e}")
|
logger.error(f"Failed to get recent memories: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def get_context(self, query: str) -> str:
|
async def get_context(self, query: str) -> str:
|
||||||
"""Get formatted context for system prompt.
|
"""Get formatted context for system prompt.
|
||||||
|
|
||||||
Combines recent memories + relevant memories.
|
Combines recent memories + relevant memories.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: Current user query to find relevant context
|
query: Current user query to find relevant context
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Formatted context string for prompt injection
|
Formatted context string for prompt injection
|
||||||
"""
|
"""
|
||||||
recent = await self.get_recent(hours=24, limit=10)
|
recent = await self.get_recent(hours=24, limit=10)
|
||||||
relevant = await self.recall(query, limit=5)
|
relevant = await self.recall(query, limit=5)
|
||||||
|
|
||||||
lines = ["Recent activity:"]
|
lines = ["Recent activity:"]
|
||||||
for m in recent[:5]:
|
for m in recent[:5]:
|
||||||
lines.append(f"- {m['content'][:100]}")
|
lines.append(f"- {m['content'][:100]}")
|
||||||
|
|
||||||
lines.append("\nRelevant memories:")
|
lines.append("\nRelevant memories:")
|
||||||
for r in relevant[:5]:
|
for r in relevant[:5]:
|
||||||
lines.append(f"- {r['content'][:100]}")
|
lines.append(f"- {r['content'][:100]}")
|
||||||
|
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────────
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
# Task Queue Operations
|
# Task Queue Operations
|
||||||
# ──────────────────────────────────────────────────────────────────────────
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def submit_task(
|
async def submit_task(
|
||||||
self,
|
self,
|
||||||
content: str,
|
content: str,
|
||||||
task_type: str = "general",
|
task_type: str = "general",
|
||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
metadata: Optional[Dict[str, Any]] = None
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Submit a task to the distributed queue.
|
"""Submit a task to the distributed queue.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: Task description/prompt
|
content: Task description/prompt
|
||||||
task_type: Type of task (shell, creative, code, research, general)
|
task_type: Type of task (shell, creative, code, research, general)
|
||||||
priority: Higher = processed first
|
priority: Higher = processed first
|
||||||
metadata: Additional task data
|
metadata: Additional task data
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with task 'id'
|
Dict with task 'id'
|
||||||
"""
|
"""
|
||||||
@@ -270,50 +261,45 @@ class BrainClient:
|
|||||||
task_type,
|
task_type,
|
||||||
priority,
|
priority,
|
||||||
json.dumps(metadata or {}),
|
json.dumps(metadata or {}),
|
||||||
datetime.utcnow().isoformat()
|
datetime.utcnow().isoformat(),
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resp = await self._client.post(
|
resp = await self._client.post(f"{self.rqlite_url}/db/execute", json=[query, params])
|
||||||
f"{self.rqlite_url}/db/execute",
|
|
||||||
json=[query, params]
|
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
result = resp.json()
|
result = resp.json()
|
||||||
|
|
||||||
last_id = None
|
last_id = None
|
||||||
if "results" in result and result["results"]:
|
if "results" in result and result["results"]:
|
||||||
last_id = result["results"][0].get("last_insert_id")
|
last_id = result["results"][0].get("last_insert_id")
|
||||||
|
|
||||||
logger.info(f"Submitted task {last_id}: {content[:50]}...")
|
logger.info(f"Submitted task {last_id}: {content[:50]}...")
|
||||||
return {"id": last_id, "status": "queued"}
|
return {"id": last_id, "status": "queued"}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to submit task: {e}")
|
logger.error(f"Failed to submit task: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def claim_task(
|
async def claim_task(
|
||||||
self,
|
self, capabilities: List[str], node_id: Optional[str] = None
|
||||||
capabilities: List[str],
|
|
||||||
node_id: Optional[str] = None
|
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
"""Atomically claim next available task.
|
"""Atomically claim next available task.
|
||||||
|
|
||||||
Uses UPDATE ... RETURNING pattern for atomic claim.
|
Uses UPDATE ... RETURNING pattern for atomic claim.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
capabilities: List of capabilities this node has
|
capabilities: List of capabilities this node has
|
||||||
node_id: Identifier for claiming node
|
node_id: Identifier for claiming node
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Task dict or None if no tasks available
|
Task dict or None if no tasks available
|
||||||
"""
|
"""
|
||||||
claimer = node_id or self.node_id
|
claimer = node_id or self.node_id
|
||||||
|
|
||||||
# Try to claim a matching task atomically
|
# Try to claim a matching task atomically
|
||||||
# This works because rqlite uses Raft consensus - only one node wins
|
# This works because rqlite uses Raft consensus - only one node wins
|
||||||
placeholders = ",".join(["?"] * len(capabilities))
|
placeholders = ",".join(["?"] * len(capabilities))
|
||||||
|
|
||||||
query = f"""
|
query = f"""
|
||||||
UPDATE tasks
|
UPDATE tasks
|
||||||
SET status = 'claimed',
|
SET status = 'claimed',
|
||||||
@@ -330,15 +316,12 @@ class BrainClient:
|
|||||||
RETURNING id, content, task_type, priority, metadata
|
RETURNING id, content, task_type, priority, metadata
|
||||||
"""
|
"""
|
||||||
params = [claimer, datetime.utcnow().isoformat()] + capabilities
|
params = [claimer, datetime.utcnow().isoformat()] + capabilities
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resp = await self._client.post(
|
resp = await self._client.post(f"{self.rqlite_url}/db/execute", json=[query, params])
|
||||||
f"{self.rqlite_url}/db/execute",
|
|
||||||
json=[query, params]
|
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
result = resp.json()
|
result = resp.json()
|
||||||
|
|
||||||
if "results" in result and result["results"]:
|
if "results" in result and result["results"]:
|
||||||
rows = result["results"][0].get("rows", [])
|
rows = result["results"][0].get("rows", [])
|
||||||
if rows:
|
if rows:
|
||||||
@@ -348,24 +331,20 @@ class BrainClient:
|
|||||||
"content": row[1],
|
"content": row[1],
|
||||||
"type": row[2],
|
"type": row[2],
|
||||||
"priority": row[3],
|
"priority": row[3],
|
||||||
"metadata": json.loads(row[4]) if row[4] else {}
|
"metadata": json.loads(row[4]) if row[4] else {},
|
||||||
}
|
}
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to claim task: {e}")
|
logger.error(f"Failed to claim task: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def complete_task(
|
async def complete_task(
|
||||||
self,
|
self, task_id: int, success: bool, result: Optional[str] = None, error: Optional[str] = None
|
||||||
task_id: int,
|
|
||||||
success: bool,
|
|
||||||
result: Optional[str] = None,
|
|
||||||
error: Optional[str] = None
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Mark task as completed or failed.
|
"""Mark task as completed or failed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
task_id: Task ID
|
task_id: Task ID
|
||||||
success: True if task succeeded
|
success: True if task succeeded
|
||||||
@@ -373,7 +352,7 @@ class BrainClient:
|
|||||||
error: Error message if failed
|
error: Error message if failed
|
||||||
"""
|
"""
|
||||||
status = "done" if success else "failed"
|
status = "done" if success else "failed"
|
||||||
|
|
||||||
query = """
|
query = """
|
||||||
UPDATE tasks
|
UPDATE tasks
|
||||||
SET status = ?,
|
SET status = ?,
|
||||||
@@ -383,23 +362,20 @@ class BrainClient:
|
|||||||
WHERE id = ?
|
WHERE id = ?
|
||||||
"""
|
"""
|
||||||
params = [status, result, error, datetime.utcnow().isoformat(), task_id]
|
params = [status, result, error, datetime.utcnow().isoformat(), task_id]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._client.post(
|
await self._client.post(f"{self.rqlite_url}/db/execute", json=[query, params])
|
||||||
f"{self.rqlite_url}/db/execute",
|
|
||||||
json=[query, params]
|
|
||||||
)
|
|
||||||
logger.debug(f"Task {task_id} marked {status}")
|
logger.debug(f"Task {task_id} marked {status}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to complete task {task_id}: {e}")
|
logger.error(f"Failed to complete task {task_id}: {e}")
|
||||||
|
|
||||||
async def get_pending_tasks(self, limit: int = 100) -> List[Dict[str, Any]]:
|
async def get_pending_tasks(self, limit: int = 100) -> List[Dict[str, Any]]:
|
||||||
"""Get list of pending tasks (for dashboard/monitoring).
|
"""Get list of pending tasks (for dashboard/monitoring).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
limit: Max tasks to return
|
limit: Max tasks to return
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of pending task dicts
|
List of pending task dicts
|
||||||
"""
|
"""
|
||||||
@@ -410,33 +386,32 @@ class BrainClient:
|
|||||||
ORDER BY priority DESC, created_at ASC
|
ORDER BY priority DESC, created_at ASC
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resp = await self._client.post(
|
resp = await self._client.post(f"{self.rqlite_url}/db/query", json=[sql, [limit]])
|
||||||
f"{self.rqlite_url}/db/query",
|
|
||||||
json=[sql, [limit]]
|
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
result = resp.json()
|
result = resp.json()
|
||||||
|
|
||||||
tasks = []
|
tasks = []
|
||||||
if "results" in result and result["results"]:
|
if "results" in result and result["results"]:
|
||||||
for row in result["results"][0].get("rows", []):
|
for row in result["results"][0].get("rows", []):
|
||||||
tasks.append({
|
tasks.append(
|
||||||
"id": row[0],
|
{
|
||||||
"content": row[1],
|
"id": row[0],
|
||||||
"type": row[2],
|
"content": row[1],
|
||||||
"priority": row[3],
|
"type": row[2],
|
||||||
"metadata": json.loads(row[4]) if row[4] else {},
|
"priority": row[3],
|
||||||
"created_at": row[5]
|
"metadata": json.loads(row[4]) if row[4] else {},
|
||||||
})
|
"created_at": row[5],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return tasks
|
return tasks
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get pending tasks: {e}")
|
logger.error(f"Failed to get pending tasks: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def close(self):
    """Release the HTTP client held by this object.

    Awaits ``aclose()`` on ``self._client``; nothing is returned.
    """
    http_client = self._client
    await http_client.aclose()
|
||||||
|
|||||||
@@ -18,48 +18,51 @@ _dimensions = 384
|
|||||||
|
|
||||||
class LocalEmbedder:
|
class LocalEmbedder:
|
||||||
"""Local sentence transformer for embeddings.
|
"""Local sentence transformer for embeddings.
|
||||||
|
|
||||||
Uses all-MiniLM-L6-v2 (80MB download, runs on CPU).
|
Uses all-MiniLM-L6-v2 (80MB download, runs on CPU).
|
||||||
384-dimensional embeddings, good enough for semantic search.
|
384-dimensional embeddings, good enough for semantic search.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_name: str = _model_name):
    """Remember which model to use; the model itself is loaded lazily.

    Args:
        model_name: Name handed to ``SentenceTransformer`` when the model
            is first loaded. Defaults to the module-level ``_model_name``.
    """
    self._dimensions = _dimensions
    # Stays None until _load_model() runs (triggered by the first encode()).
    self._model = None
    self.model_name = model_name
|
||||||
|
|
||||||
def _load_model(self):
    """Attach the embedding model to this instance, loading it if needed.

    The loaded model is cached in the module-level ``_model`` so that it is
    constructed at most once per process; later calls just reuse it.

    Raises:
        ImportError: If ``sentence_transformers`` cannot be imported. The
            error is logged with an install hint and then re-raised.
    """
    # NOTE(review): the cache is not keyed by model name, so an instance
    # created with a different ``model_name`` reuses whichever model was
    # loaded first — confirm this is intended.
    global _model
    if _model is None:
        try:
            from sentence_transformers import SentenceTransformer

            logger.info(f"Loading embedding model: {self.model_name}")
            _model = SentenceTransformer(self.model_name)
            self._model = _model
            logger.info(f"Embedding model loaded ({self._dimensions} dims)")
        except ImportError:
            logger.error(
                "sentence-transformers not installed. Run: pip install sentence-transformers"
            )
            raise
    else:
        # Already loaded earlier in this process: share the cached model.
        self._model = _model
|
||||||
|
|
||||||
def encode(self, text: Union[str, List[str]]):
    """Embed a string or a list of strings.

    The model is loaded on first use via ``_load_model()``.

    Args:
        text: A single string or a list of strings to embed.

    Returns:
        Whatever the underlying model's ``encode`` returns when called with
        ``normalize_embeddings=True``.
    """
    model_missing = self._model is None
    if model_missing:
        self._load_model()

    # Normalised vectors let a plain dot product act as cosine similarity.
    embeddings = self._model.encode(text, normalize_embeddings=True)
    return embeddings
|
||||||
|
|
||||||
def encode_single(self, text: str) -> bytes:
    """Embed one text and serialise the vector for SQLite storage.

    Args:
        text: Text to embed.

    Returns:
        The embedding as raw float32 bytes. If ``encode()`` returns an
        array with more than one dimension, only its first row is used.
    """
    import numpy as np

    vector = self.encode(text)
    is_batched = len(vector.shape) > 1
    if is_batched:
        vector = vector[0]
    as_float32 = vector.astype(np.float32)
    return as_float32.tobytes()
|
||||||
|
|
||||||
def similarity(self, a, b) -> float:
    """Return the dot product of two vectors as a Python float.

    No normalisation happens here. For unit-length inputs (``encode()``
    requests normalised embeddings) the dot product is the cosine similarity.
    """
    import numpy as np

    dot_product = np.dot(a, b)
    return float(dot_product)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ _SCHEMA_VERSION = 1
|
|||||||
def _get_db_path() -> Path:
    """Resolve where the brain database lives.

    Returns:
        ``Path(settings.brain_db_path)`` when that setting is truthy,
        otherwise the module-level ``_DEFAULT_DB_PATH``.
    """
    # Imported at call time rather than at module import.
    from config import settings

    if not settings.brain_db_path:
        return _DEFAULT_DB_PATH
    return Path(settings.brain_db_path)
|
||||||
@@ -75,6 +76,7 @@ class UnifiedMemory:
|
|||||||
# Auto-detect: use rqlite if RQLITE_URL is set, otherwise local SQLite
|
# Auto-detect: use rqlite if RQLITE_URL is set, otherwise local SQLite
|
||||||
if use_rqlite is None:
|
if use_rqlite is None:
|
||||||
from config import settings as _settings
|
from config import settings as _settings
|
||||||
|
|
||||||
use_rqlite = bool(_settings.rqlite_url)
|
use_rqlite = bool(_settings.rqlite_url)
|
||||||
self._use_rqlite = use_rqlite
|
self._use_rqlite = use_rqlite
|
||||||
|
|
||||||
@@ -107,10 +109,12 @@ class UnifiedMemory:
|
|||||||
"""Lazy-load the embedding model."""
|
"""Lazy-load the embedding model."""
|
||||||
if self._embedder is None:
|
if self._embedder is None:
|
||||||
from config import settings as _settings
|
from config import settings as _settings
|
||||||
|
|
||||||
if _settings.timmy_skip_embeddings:
|
if _settings.timmy_skip_embeddings:
|
||||||
return None
|
return None
|
||||||
try:
|
try:
|
||||||
from brain.embeddings import LocalEmbedder
|
from brain.embeddings import LocalEmbedder
|
||||||
|
|
||||||
self._embedder = LocalEmbedder()
|
self._embedder = LocalEmbedder()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning("sentence-transformers not available — semantic search disabled")
|
logger.warning("sentence-transformers not available — semantic search disabled")
|
||||||
@@ -125,6 +129,7 @@ class UnifiedMemory:
|
|||||||
"""Lazy-load the rqlite BrainClient."""
|
"""Lazy-load the rqlite BrainClient."""
|
||||||
if self._rqlite_client is None:
|
if self._rqlite_client is None:
|
||||||
from brain.client import BrainClient
|
from brain.client import BrainClient
|
||||||
|
|
||||||
self._rqlite_client = BrainClient()
|
self._rqlite_client = BrainClient()
|
||||||
return self._rqlite_client
|
return self._rqlite_client
|
||||||
|
|
||||||
@@ -292,15 +297,17 @@ class UnifiedMemory:
|
|||||||
|
|
||||||
results = []
|
results = []
|
||||||
for score, row in scored[:limit]:
|
for score, row in scored[:limit]:
|
||||||
results.append({
|
results.append(
|
||||||
"id": row["id"],
|
{
|
||||||
"content": row["content"],
|
"id": row["id"],
|
||||||
"source": row["source"],
|
"content": row["content"],
|
||||||
"tags": json.loads(row["tags"]) if row["tags"] else [],
|
"source": row["source"],
|
||||||
"metadata": json.loads(row["metadata"]) if row["metadata"] else {},
|
"tags": json.loads(row["tags"]) if row["tags"] else [],
|
||||||
"score": score,
|
"metadata": json.loads(row["metadata"]) if row["metadata"] else {},
|
||||||
"created_at": row["created_at"],
|
"score": score,
|
||||||
})
|
"created_at": row["created_at"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
@@ -84,11 +84,13 @@ def get_migration_sql(from_version: int, to_version: int) -> str:
|
|||||||
"""Get SQL to migrate between versions."""
|
"""Get SQL to migrate between versions."""
|
||||||
if to_version <= from_version:
|
if to_version <= from_version:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
sql_parts = []
|
sql_parts = []
|
||||||
for v in range(from_version + 1, to_version + 1):
|
for v in range(from_version + 1, to_version + 1):
|
||||||
if v in MIGRATIONS:
|
if v in MIGRATIONS:
|
||||||
sql_parts.append(MIGRATIONS[v])
|
sql_parts.append(MIGRATIONS[v])
|
||||||
sql_parts.append(f"UPDATE schema_version SET version = {v}, applied_at = datetime('now');")
|
sql_parts.append(
|
||||||
|
f"UPDATE schema_version SET version = {v}, applied_at = datetime('now');"
|
||||||
|
)
|
||||||
|
|
||||||
return "\n".join(sql_parts)
|
return "\n".join(sql_parts)
|
||||||
|
|||||||
@@ -21,11 +21,11 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class DistributedWorker:
|
class DistributedWorker:
|
||||||
"""Continuous task processor for the distributed brain.
|
"""Continuous task processor for the distributed brain.
|
||||||
|
|
||||||
Runs on every device, claims tasks matching its capabilities,
|
Runs on every device, claims tasks matching its capabilities,
|
||||||
executes them immediately, stores results.
|
executes them immediately, stores results.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, brain_client: Optional[BrainClient] = None):
|
def __init__(self, brain_client: Optional[BrainClient] = None):
|
||||||
self.brain = brain_client or BrainClient()
|
self.brain = brain_client or BrainClient()
|
||||||
self.node_id = f"{socket.gethostname()}-{os.getpid()}"
|
self.node_id = f"{socket.gethostname()}-{os.getpid()}"
|
||||||
@@ -33,30 +33,30 @@ class DistributedWorker:
|
|||||||
self.running = False
|
self.running = False
|
||||||
self._handlers: Dict[str, Callable] = {}
|
self._handlers: Dict[str, Callable] = {}
|
||||||
self._register_default_handlers()
|
self._register_default_handlers()
|
||||||
|
|
||||||
def _detect_capabilities(self) -> List[str]:
|
def _detect_capabilities(self) -> List[str]:
|
||||||
"""Detect what this node can do."""
|
"""Detect what this node can do."""
|
||||||
caps = ["general", "shell", "file_ops", "git"]
|
caps = ["general", "shell", "file_ops", "git"]
|
||||||
|
|
||||||
# Check for GPU
|
# Check for GPU
|
||||||
if self._has_gpu():
|
if self._has_gpu():
|
||||||
caps.append("gpu")
|
caps.append("gpu")
|
||||||
caps.append("creative")
|
caps.append("creative")
|
||||||
caps.append("image_gen")
|
caps.append("image_gen")
|
||||||
caps.append("video_gen")
|
caps.append("video_gen")
|
||||||
|
|
||||||
# Check for internet
|
# Check for internet
|
||||||
if self._has_internet():
|
if self._has_internet():
|
||||||
caps.append("web")
|
caps.append("web")
|
||||||
caps.append("research")
|
caps.append("research")
|
||||||
|
|
||||||
# Check memory
|
# Check memory
|
||||||
mem_gb = self._get_memory_gb()
|
mem_gb = self._get_memory_gb()
|
||||||
if mem_gb > 16:
|
if mem_gb > 16:
|
||||||
caps.append("large_model")
|
caps.append("large_model")
|
||||||
if mem_gb > 32:
|
if mem_gb > 32:
|
||||||
caps.append("huge_model")
|
caps.append("huge_model")
|
||||||
|
|
||||||
# Check for specific tools
|
# Check for specific tools
|
||||||
if self._has_command("ollama"):
|
if self._has_command("ollama"):
|
||||||
caps.append("ollama")
|
caps.append("ollama")
|
||||||
@@ -64,17 +64,15 @@ class DistributedWorker:
|
|||||||
caps.append("docker")
|
caps.append("docker")
|
||||||
if self._has_command("cargo"):
|
if self._has_command("cargo"):
|
||||||
caps.append("rust")
|
caps.append("rust")
|
||||||
|
|
||||||
logger.info(f"Worker capabilities: {caps}")
|
logger.info(f"Worker capabilities: {caps}")
|
||||||
return caps
|
return caps
|
||||||
|
|
||||||
def _has_gpu(self) -> bool:
|
def _has_gpu(self) -> bool:
|
||||||
"""Check for NVIDIA or AMD GPU."""
|
"""Check for NVIDIA or AMD GPU."""
|
||||||
try:
|
try:
|
||||||
# Check for nvidia-smi
|
# Check for nvidia-smi
|
||||||
result = subprocess.run(
|
result = subprocess.run(["nvidia-smi"], capture_output=True, timeout=5)
|
||||||
["nvidia-smi"], capture_output=True, timeout=5
|
|
||||||
)
|
|
||||||
if result.returncode == 0:
|
if result.returncode == 0:
|
||||||
return True
|
return True
|
||||||
except (OSError, subprocess.SubprocessError):
|
except (OSError, subprocess.SubprocessError):
|
||||||
@@ -83,13 +81,15 @@ class DistributedWorker:
|
|||||||
# Check for ROCm
|
# Check for ROCm
|
||||||
if os.path.exists("/opt/rocm"):
|
if os.path.exists("/opt/rocm"):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Check for Apple Silicon Metal
|
# Check for Apple Silicon Metal
|
||||||
if os.uname().sysname == "Darwin":
|
if os.uname().sysname == "Darwin":
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(
|
result = subprocess.run(
|
||||||
["system_profiler", "SPDisplaysDataType"],
|
["system_profiler", "SPDisplaysDataType"],
|
||||||
capture_output=True, text=True, timeout=5
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=5,
|
||||||
)
|
)
|
||||||
if "Metal" in result.stdout:
|
if "Metal" in result.stdout:
|
||||||
return True
|
return True
|
||||||
@@ -102,8 +102,7 @@ class DistributedWorker:
|
|||||||
"""Check if we have internet connectivity."""
|
"""Check if we have internet connectivity."""
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(
|
result = subprocess.run(
|
||||||
["curl", "-s", "--max-time", "3", "https://1.1.1.1"],
|
["curl", "-s", "--max-time", "3", "https://1.1.1.1"], capture_output=True, timeout=5
|
||||||
capture_output=True, timeout=5
|
|
||||||
)
|
)
|
||||||
return result.returncode == 0
|
return result.returncode == 0
|
||||||
except (OSError, subprocess.SubprocessError):
|
except (OSError, subprocess.SubprocessError):
|
||||||
@@ -114,8 +113,7 @@ class DistributedWorker:
|
|||||||
try:
|
try:
|
||||||
if os.uname().sysname == "Darwin":
|
if os.uname().sysname == "Darwin":
|
||||||
result = subprocess.run(
|
result = subprocess.run(
|
||||||
["sysctl", "-n", "hw.memsize"],
|
["sysctl", "-n", "hw.memsize"], capture_output=True, text=True
|
||||||
capture_output=True, text=True
|
|
||||||
)
|
)
|
||||||
bytes_mem = int(result.stdout.strip())
|
bytes_mem = int(result.stdout.strip())
|
||||||
return bytes_mem / (1024**3)
|
return bytes_mem / (1024**3)
|
||||||
@@ -128,13 +126,11 @@ class DistributedWorker:
|
|||||||
except (OSError, ValueError):
|
except (OSError, ValueError):
|
||||||
pass
|
pass
|
||||||
return 8.0 # Assume 8GB if we can't detect
|
return 8.0 # Assume 8GB if we can't detect
|
||||||
|
|
||||||
def _has_command(self, cmd: str) -> bool:
|
def _has_command(self, cmd: str) -> bool:
|
||||||
"""Check if command exists."""
|
"""Check if command exists."""
|
||||||
try:
|
try:
|
||||||
result = subprocess.run(
|
result = subprocess.run(["which", cmd], capture_output=True, timeout=5)
|
||||||
["which", cmd], capture_output=True, timeout=5
|
|
||||||
)
|
|
||||||
return result.returncode == 0
|
return result.returncode == 0
|
||||||
except (OSError, subprocess.SubprocessError):
|
except (OSError, subprocess.SubprocessError):
|
||||||
return False
|
return False
|
||||||
@@ -148,10 +144,10 @@ class DistributedWorker:
|
|||||||
"research": self._handle_research,
|
"research": self._handle_research,
|
||||||
"general": self._handle_general,
|
"general": self._handle_general,
|
||||||
}
|
}
|
||||||
|
|
||||||
def register_handler(self, task_type: str, handler: Callable[[str], Any]):
|
def register_handler(self, task_type: str, handler: Callable[[str], Any]):
|
||||||
"""Register a custom task handler.
|
"""Register a custom task handler.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
task_type: Type of task this handler handles
|
task_type: Type of task this handler handles
|
||||||
handler: Async function that takes task content and returns result
|
handler: Async function that takes task content and returns result
|
||||||
@@ -159,11 +155,11 @@ class DistributedWorker:
|
|||||||
self._handlers[task_type] = handler
|
self._handlers[task_type] = handler
|
||||||
if task_type not in self.capabilities:
|
if task_type not in self.capabilities:
|
||||||
self.capabilities.append(task_type)
|
self.capabilities.append(task_type)
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────────
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
# Task Handlers
|
# Task Handlers
|
||||||
# ──────────────────────────────────────────────────────────────────────────
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def _handle_shell(self, command: str) -> str:
|
async def _handle_shell(self, command: str) -> str:
|
||||||
"""Execute shell command via ZeroClaw or direct subprocess."""
|
"""Execute shell command via ZeroClaw or direct subprocess."""
|
||||||
# Try ZeroClaw first if available
|
# Try ZeroClaw first if available
|
||||||
@@ -171,156 +167,153 @@ class DistributedWorker:
|
|||||||
proc = await asyncio.create_subprocess_shell(
|
proc = await asyncio.create_subprocess_shell(
|
||||||
f"zeroclaw exec --json '{command}'",
|
f"zeroclaw exec --json '{command}'",
|
||||||
stdout=asyncio.subprocess.PIPE,
|
stdout=asyncio.subprocess.PIPE,
|
||||||
stderr=asyncio.subprocess.PIPE
|
stderr=asyncio.subprocess.PIPE,
|
||||||
)
|
)
|
||||||
stdout, stderr = await proc.communicate()
|
stdout, stderr = await proc.communicate()
|
||||||
|
|
||||||
# Store result in brain
|
# Store result in brain
|
||||||
await self.brain.remember(
|
await self.brain.remember(
|
||||||
content=f"Shell: {command}\nOutput: {stdout.decode()}",
|
content=f"Shell: {command}\nOutput: {stdout.decode()}",
|
||||||
tags=["shell", "result"],
|
tags=["shell", "result"],
|
||||||
source=self.node_id,
|
source=self.node_id,
|
||||||
metadata={"command": command, "exit_code": proc.returncode}
|
metadata={"command": command, "exit_code": proc.returncode},
|
||||||
)
|
)
|
||||||
|
|
||||||
if proc.returncode != 0:
|
if proc.returncode != 0:
|
||||||
raise Exception(f"Command failed: {stderr.decode()}")
|
raise Exception(f"Command failed: {stderr.decode()}")
|
||||||
return stdout.decode()
|
return stdout.decode()
|
||||||
|
|
||||||
# Fallback to direct subprocess (less safe)
|
# Fallback to direct subprocess (less safe)
|
||||||
proc = await asyncio.create_subprocess_shell(
|
proc = await asyncio.create_subprocess_shell(
|
||||||
command,
|
command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||||
stdout=asyncio.subprocess.PIPE,
|
|
||||||
stderr=asyncio.subprocess.PIPE
|
|
||||||
)
|
)
|
||||||
stdout, stderr = await proc.communicate()
|
stdout, stderr = await proc.communicate()
|
||||||
|
|
||||||
if proc.returncode != 0:
|
if proc.returncode != 0:
|
||||||
raise Exception(f"Command failed: {stderr.decode()}")
|
raise Exception(f"Command failed: {stderr.decode()}")
|
||||||
return stdout.decode()
|
return stdout.decode()
|
||||||
|
|
||||||
async def _handle_creative(self, prompt: str) -> str:
|
async def _handle_creative(self, prompt: str) -> str:
|
||||||
"""Generate creative media (requires GPU)."""
|
"""Generate creative media (requires GPU)."""
|
||||||
if "gpu" not in self.capabilities:
|
if "gpu" not in self.capabilities:
|
||||||
raise Exception("GPU not available on this node")
|
raise Exception("GPU not available on this node")
|
||||||
|
|
||||||
# This would call creative tools (Stable Diffusion, etc.)
|
# This would call creative tools (Stable Diffusion, etc.)
|
||||||
# For now, placeholder
|
# For now, placeholder
|
||||||
logger.info(f"Creative task: {prompt[:50]}...")
|
logger.info(f"Creative task: {prompt[:50]}...")
|
||||||
|
|
||||||
# Store result
|
# Store result
|
||||||
result = f"Creative output for: {prompt}"
|
result = f"Creative output for: {prompt}"
|
||||||
await self.brain.remember(
|
await self.brain.remember(
|
||||||
content=result,
|
content=result,
|
||||||
tags=["creative", "generated"],
|
tags=["creative", "generated"],
|
||||||
source=self.node_id,
|
source=self.node_id,
|
||||||
metadata={"prompt": prompt}
|
metadata={"prompt": prompt},
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def _handle_code(self, description: str) -> str:
    """Handle a code-generation task (currently a placeholder).

    Nothing is generated yet: the first 50 characters of the description
    are logged and a fixed-format string is returned.

    Args:
        description: What code should be produced or changed.

    Returns:
        A placeholder string that echoes the description.
    """
    # An LLM call would go here once code generation is implemented.
    preview = description[:50]
    logger.info(f"Code task: {preview}...")
    return f"Code generated for: {description}"
|
||||||
|
|
||||||
async def _handle_research(self, query: str) -> str:
|
async def _handle_research(self, query: str) -> str:
|
||||||
"""Web research."""
|
"""Web research."""
|
||||||
if "web" not in self.capabilities:
|
if "web" not in self.capabilities:
|
||||||
raise Exception("Internet not available on this node")
|
raise Exception("Internet not available on this node")
|
||||||
|
|
||||||
# Would use browser automation or search
|
# Would use browser automation or search
|
||||||
logger.info(f"Research task: {query[:50]}...")
|
logger.info(f"Research task: {query[:50]}...")
|
||||||
return f"Research results for: {query}"
|
return f"Research results for: {query}"
|
||||||
|
|
||||||
async def _handle_general(self, prompt: str) -> str:
|
async def _handle_general(self, prompt: str) -> str:
|
||||||
"""General LLM task via local Ollama."""
|
"""General LLM task via local Ollama."""
|
||||||
if "ollama" not in self.capabilities:
|
if "ollama" not in self.capabilities:
|
||||||
raise Exception("Ollama not available on this node")
|
raise Exception("Ollama not available on this node")
|
||||||
|
|
||||||
# Call Ollama
|
# Call Ollama
|
||||||
try:
|
try:
|
||||||
proc = await asyncio.create_subprocess_exec(
|
proc = await asyncio.create_subprocess_exec(
|
||||||
"curl", "-s", "http://localhost:11434/api/generate",
|
"curl",
|
||||||
"-d", json.dumps({
|
"-s",
|
||||||
"model": "llama3.1:8b-instruct",
|
"http://localhost:11434/api/generate",
|
||||||
"prompt": prompt,
|
"-d",
|
||||||
"stream": False
|
json.dumps({"model": "llama3.1:8b-instruct", "prompt": prompt, "stream": False}),
|
||||||
}),
|
stdout=asyncio.subprocess.PIPE,
|
||||||
stdout=asyncio.subprocess.PIPE
|
|
||||||
)
|
)
|
||||||
stdout, _ = await proc.communicate()
|
stdout, _ = await proc.communicate()
|
||||||
|
|
||||||
response = json.loads(stdout.decode())
|
response = json.loads(stdout.decode())
|
||||||
result = response.get("response", "No response")
|
result = response.get("response", "No response")
|
||||||
|
|
||||||
# Store in brain
|
# Store in brain
|
||||||
await self.brain.remember(
|
await self.brain.remember(
|
||||||
content=f"Task: {prompt}\nResult: {result}",
|
content=f"Task: {prompt}\nResult: {result}",
|
||||||
tags=["llm", "result"],
|
tags=["llm", "result"],
|
||||||
source=self.node_id,
|
source=self.node_id,
|
||||||
metadata={"model": "llama3.1:8b-instruct"}
|
metadata={"model": "llama3.1:8b-instruct"},
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"LLM failed: {e}")
|
raise Exception(f"LLM failed: {e}")
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────────
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
# Main Loop
|
# Main Loop
|
||||||
# ──────────────────────────────────────────────────────────────────────────
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def execute_task(self, task: Dict[str, Any]) -> Dict[str, Any]:
    """Run the handler for a claimed task and record the outcome in the brain.

    The handler registered for the task's type is used; unknown types fall
    back to the ``"general"`` handler. Any failure, including a missing
    handler, is reported through ``brain.complete_task(..., success=False)``
    instead of propagating to the caller.

    Args:
        task: Task dict. Reads ``"type"`` (default ``"general"``),
            ``"content"`` (default ``""``) and ``"id"``.

    Returns:
        ``{"success": True, "result": ...}`` on success, otherwise
        ``{"success": False, "error": <message>}``.
    """
    task_type = task.get("type", "general")
    content = task.get("content", "")
    task_id = task.get("id")

    try:
        # Resolve the handler inside the try and look up the fallback lazily:
        # an eager self._handlers["general"] default would raise KeyError
        # even when the task's own handler exists, and raising outside the
        # try would leave the claimed task never marked as failed.
        handler = self._handlers.get(task_type)
        if handler is None:
            handler = self._handlers.get("general")
        if handler is None:
            raise LookupError(f"No handler registered for task type {task_type!r}")

        logger.info(f"Executing task {task_id}: {task_type}")
        result = await handler(content)

        await self.brain.complete_task(task_id, success=True, result=result)
        logger.info(f"Task {task_id} completed")
        return {"success": True, "result": result}

    except Exception as e:
        error_msg = str(e)
        logger.error(f"Task {task_id} failed: {error_msg}")
        await self.brain.complete_task(task_id, success=False, error=error_msg)
        return {"success": False, "error": error_msg}
|
||||||
|
|
||||||
async def run_once(self) -> bool:
    """Claim and execute at most one task.

    Asks the brain for a task matching this worker's capabilities and node
    id, and passes it to ``execute_task`` when one is returned.

    Returns:
        True if a task was claimed and executed, False if none was claimed.
    """
    claimed = await self.brain.claim_task(self.capabilities, self.node_id)
    if not claimed:
        return False

    await self.execute_task(claimed)
    return True
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""Main loop — continuously process tasks."""
|
"""Main loop — continuously process tasks."""
|
||||||
logger.info(f"Worker {self.node_id} started")
|
logger.info(f"Worker {self.node_id} started")
|
||||||
logger.info(f"Capabilities: {self.capabilities}")
|
logger.info(f"Capabilities: {self.capabilities}")
|
||||||
|
|
||||||
self.running = True
|
self.running = True
|
||||||
consecutive_empty = 0
|
consecutive_empty = 0
|
||||||
|
|
||||||
while self.running:
|
while self.running:
|
||||||
try:
|
try:
|
||||||
had_work = await self.run_once()
|
had_work = await self.run_once()
|
||||||
|
|
||||||
if had_work:
|
if had_work:
|
||||||
# Immediately check for more work
|
# Immediately check for more work
|
||||||
consecutive_empty = 0
|
consecutive_empty = 0
|
||||||
@@ -331,11 +324,11 @@ class DistributedWorker:
|
|||||||
# Sleep 0.5s, but up to 2s if consistently empty
|
# Sleep 0.5s, but up to 2s if consistently empty
|
||||||
sleep_time = min(0.5 + (consecutive_empty * 0.1), 2.0)
|
sleep_time = min(0.5 + (consecutive_empty * 0.1), 2.0)
|
||||||
await asyncio.sleep(sleep_time)
|
await asyncio.sleep(sleep_time)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Worker error: {e}")
|
logger.error(f"Worker error: {e}")
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
def stop(self):
    """Request shutdown of the worker loop.

    Clears ``self.running``; ``run()`` tests that flag at the top of each
    iteration, so the loop ends once the current iteration finishes.
    """
    self.running = False
|
||||||
@@ -345,7 +338,7 @@ class DistributedWorker:
|
|||||||
async def main():
|
async def main():
|
||||||
"""CLI entry point for worker."""
|
"""CLI entry point for worker."""
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
# Allow capability overrides from CLI
|
# Allow capability overrides from CLI
|
||||||
if len(sys.argv) > 1:
|
if len(sys.argv) > 1:
|
||||||
caps = sys.argv[1].split(",")
|
caps = sys.argv[1].split(",")
|
||||||
@@ -354,12 +347,12 @@ async def main():
|
|||||||
logger.info(f"Overriding capabilities: {caps}")
|
logger.info(f"Overriding capabilities: {caps}")
|
||||||
else:
|
else:
|
||||||
worker = DistributedWorker()
|
worker = DistributedWorker()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await worker.run()
|
await worker.run()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
worker.stop()
|
worker.stop()
|
||||||
print("\nWorker stopped.")
|
logger.info("Worker stopped.")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -213,6 +213,15 @@ class Settings(BaseSettings):
|
|||||||
# Timeout in seconds for OpenFang hand execution (some hands are slow).
|
# Timeout in seconds for OpenFang hand execution (some hands are slow).
|
||||||
openfang_timeout: int = 120
|
openfang_timeout: int = 120
|
||||||
|
|
||||||
|
# ── Autoresearch — autonomous ML experiment loops ──────────────────
|
||||||
|
# Integrates Karpathy's autoresearch pattern: agents modify training
|
||||||
|
# code, run time-boxed experiments, evaluate metrics, and iterate.
|
||||||
|
autoresearch_enabled: bool = False
|
||||||
|
autoresearch_workspace: str = "data/experiments"
|
||||||
|
autoresearch_time_budget: int = 300 # seconds per experiment run
|
||||||
|
autoresearch_max_iterations: int = 100
|
||||||
|
autoresearch_metric: str = "val_bpb" # metric to optimise (lower = better)
|
||||||
|
|
||||||
# ── Local Hands (Shell + Git) ──────────────────────────────────────
|
# ── Local Hands (Shell + Git) ──────────────────────────────────────
|
||||||
# Enable local shell/git execution hands.
|
# Enable local shell/git execution hands.
|
||||||
hands_shell_enabled: bool = True
|
hands_shell_enabled: bool = True
|
||||||
|
|||||||
@@ -18,36 +18,38 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
from fastapi.middleware.trustedhost import TrustedHostMiddleware
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
from config import settings
|
from config import settings
|
||||||
from dashboard.routes.agents import router as agents_router
|
|
||||||
from dashboard.routes.health import router as health_router
|
|
||||||
from dashboard.routes.marketplace import router as marketplace_router
|
|
||||||
from dashboard.routes.voice import router as voice_router
|
|
||||||
from dashboard.routes.mobile import router as mobile_router
|
|
||||||
from dashboard.routes.briefing import router as briefing_router
|
|
||||||
from dashboard.routes.telegram import router as telegram_router
|
|
||||||
from dashboard.routes.tools import router as tools_router
|
|
||||||
from dashboard.routes.spark import router as spark_router
|
|
||||||
from dashboard.routes.discord import router as discord_router
|
|
||||||
from dashboard.routes.memory import router as memory_router
|
|
||||||
from dashboard.routes.router import router as router_status_router
|
|
||||||
from dashboard.routes.grok import router as grok_router
|
|
||||||
from dashboard.routes.models import router as models_router
|
|
||||||
from dashboard.routes.models import api_router as models_api_router
|
|
||||||
from dashboard.routes.chat_api import router as chat_api_router
|
|
||||||
from dashboard.routes.thinking import router as thinking_router
|
|
||||||
from dashboard.routes.calm import router as calm_router
|
|
||||||
from dashboard.routes.swarm import router as swarm_router
|
|
||||||
from dashboard.routes.tasks import router as tasks_router
|
|
||||||
from dashboard.routes.work_orders import router as work_orders_router
|
|
||||||
from dashboard.routes.system import router as system_router
|
|
||||||
from dashboard.routes.paperclip import router as paperclip_router
|
|
||||||
from infrastructure.router.api import router as cascade_router
|
|
||||||
|
|
||||||
# Import dedicated middleware
|
# Import dedicated middleware
|
||||||
from dashboard.middleware.csrf import CSRFMiddleware
|
from dashboard.middleware.csrf import CSRFMiddleware
|
||||||
from dashboard.middleware.request_logging import RequestLoggingMiddleware
|
from dashboard.middleware.request_logging import RequestLoggingMiddleware
|
||||||
from dashboard.middleware.security_headers import SecurityHeadersMiddleware
|
from dashboard.middleware.security_headers import SecurityHeadersMiddleware
|
||||||
|
from dashboard.routes.agents import router as agents_router
|
||||||
|
from dashboard.routes.briefing import router as briefing_router
|
||||||
|
from dashboard.routes.calm import router as calm_router
|
||||||
|
from dashboard.routes.chat_api import router as chat_api_router
|
||||||
|
from dashboard.routes.discord import router as discord_router
|
||||||
|
from dashboard.routes.experiments import router as experiments_router
|
||||||
|
from dashboard.routes.grok import router as grok_router
|
||||||
|
from dashboard.routes.health import router as health_router
|
||||||
|
from dashboard.routes.marketplace import router as marketplace_router
|
||||||
|
from dashboard.routes.memory import router as memory_router
|
||||||
|
from dashboard.routes.mobile import router as mobile_router
|
||||||
|
from dashboard.routes.models import api_router as models_api_router
|
||||||
|
from dashboard.routes.models import router as models_router
|
||||||
|
from dashboard.routes.paperclip import router as paperclip_router
|
||||||
|
from dashboard.routes.router import router as router_status_router
|
||||||
|
from dashboard.routes.spark import router as spark_router
|
||||||
|
from dashboard.routes.swarm import router as swarm_router
|
||||||
|
from dashboard.routes.system import router as system_router
|
||||||
|
from dashboard.routes.tasks import router as tasks_router
|
||||||
|
from dashboard.routes.telegram import router as telegram_router
|
||||||
|
from dashboard.routes.thinking import router as thinking_router
|
||||||
|
from dashboard.routes.tools import router as tools_router
|
||||||
|
from dashboard.routes.voice import router as voice_router
|
||||||
|
from dashboard.routes.work_orders import router as work_orders_router
|
||||||
|
from infrastructure.router.api import router as cascade_router
|
||||||
|
|
||||||
|
|
||||||
def _configure_logging() -> None:
|
def _configure_logging() -> None:
|
||||||
@@ -100,8 +102,8 @@ _BRIEFING_INTERVAL_HOURS = 6
|
|||||||
|
|
||||||
async def _briefing_scheduler() -> None:
|
async def _briefing_scheduler() -> None:
|
||||||
"""Background task: regenerate Timmy's briefing every 6 hours."""
|
"""Background task: regenerate Timmy's briefing every 6 hours."""
|
||||||
from timmy.briefing import engine as briefing_engine
|
|
||||||
from infrastructure.notifications.push import notify_briefing_ready
|
from infrastructure.notifications.push import notify_briefing_ready
|
||||||
|
from timmy.briefing import engine as briefing_engine
|
||||||
|
|
||||||
await asyncio.sleep(2)
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
@@ -121,9 +123,9 @@ async def _briefing_scheduler() -> None:
|
|||||||
|
|
||||||
async def _start_chat_integrations_background() -> None:
|
async def _start_chat_integrations_background() -> None:
|
||||||
"""Background task: start chat integrations without blocking startup."""
|
"""Background task: start chat integrations without blocking startup."""
|
||||||
from integrations.telegram_bot.bot import telegram_bot
|
|
||||||
from integrations.chat_bridge.vendors.discord import discord_bot
|
|
||||||
from integrations.chat_bridge.registry import platform_registry
|
from integrations.chat_bridge.registry import platform_registry
|
||||||
|
from integrations.chat_bridge.vendors.discord import discord_bot
|
||||||
|
from integrations.telegram_bot.bot import telegram_bot
|
||||||
|
|
||||||
await asyncio.sleep(0.5)
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
@@ -164,9 +166,9 @@ async def _discord_token_watcher() -> None:
|
|||||||
if discord_bot.state.name == "CONNECTED":
|
if discord_bot.state.name == "CONNECTED":
|
||||||
return # Already running — stop watching
|
return # Already running — stop watching
|
||||||
|
|
||||||
# 1. Check live environment variable (intentionally uses os.environ,
|
# 1. Check settings (pydantic-settings reads env on instantiation;
|
||||||
# not settings, because this polls for runtime hot-reload changes)
|
# hot-reload is handled by re-reading .env below)
|
||||||
token = os.environ.get("DISCORD_TOKEN", "")
|
token = settings.discord_token
|
||||||
|
|
||||||
# 2. Re-read .env file for hot-reload
|
# 2. Re-read .env file for hot-reload
|
||||||
if not token:
|
if not token:
|
||||||
@@ -203,6 +205,7 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
# Initialize Spark Intelligence engine
|
# Initialize Spark Intelligence engine
|
||||||
from spark.engine import spark_engine
|
from spark.engine import spark_engine
|
||||||
|
|
||||||
if spark_engine.enabled:
|
if spark_engine.enabled:
|
||||||
logger.info("Spark Intelligence active — event capture enabled")
|
logger.info("Spark Intelligence active — event capture enabled")
|
||||||
|
|
||||||
@@ -210,12 +213,17 @@ async def lifespan(app: FastAPI):
|
|||||||
if settings.memory_prune_days > 0:
|
if settings.memory_prune_days > 0:
|
||||||
try:
|
try:
|
||||||
from timmy.memory.vector_store import prune_memories
|
from timmy.memory.vector_store import prune_memories
|
||||||
|
|
||||||
pruned = prune_memories(
|
pruned = prune_memories(
|
||||||
older_than_days=settings.memory_prune_days,
|
older_than_days=settings.memory_prune_days,
|
||||||
keep_facts=settings.memory_prune_keep_facts,
|
keep_facts=settings.memory_prune_keep_facts,
|
||||||
)
|
)
|
||||||
if pruned:
|
if pruned:
|
||||||
logger.info("Memory auto-prune: removed %d entries older than %d days", pruned, settings.memory_prune_days)
|
logger.info(
|
||||||
|
"Memory auto-prune: removed %d entries older than %d days",
|
||||||
|
pruned,
|
||||||
|
settings.memory_prune_days,
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.debug("Memory auto-prune skipped: %s", exc)
|
logger.debug("Memory auto-prune skipped: %s", exc)
|
||||||
|
|
||||||
@@ -229,7 +237,8 @@ async def lifespan(app: FastAPI):
|
|||||||
if total_mb > settings.memory_vault_max_mb:
|
if total_mb > settings.memory_vault_max_mb:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Memory vault (%.1f MB) exceeds limit (%d MB) — consider archiving old notes",
|
"Memory vault (%.1f MB) exceeds limit (%d MB) — consider archiving old notes",
|
||||||
total_mb, settings.memory_vault_max_mb,
|
total_mb,
|
||||||
|
settings.memory_vault_max_mb,
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.debug("Vault size check skipped: %s", exc)
|
logger.debug("Vault size check skipped: %s", exc)
|
||||||
@@ -284,10 +293,7 @@ def _get_cors_origins() -> list[str]:
|
|||||||
app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health"])
|
app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health"])
|
||||||
|
|
||||||
# 2. Security Headers
|
# 2. Security Headers
|
||||||
app.add_middleware(
|
app.add_middleware(SecurityHeadersMiddleware, production=not settings.debug)
|
||||||
SecurityHeadersMiddleware,
|
|
||||||
production=not settings.debug
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. CSRF Protection
|
# 3. CSRF Protection
|
||||||
app.add_middleware(CSRFMiddleware)
|
app.add_middleware(CSRFMiddleware)
|
||||||
@@ -314,7 +320,6 @@ if static_dir.exists():
|
|||||||
# Shared templates instance
|
# Shared templates instance
|
||||||
from dashboard.templating import templates # noqa: E402
|
from dashboard.templating import templates # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
# Include routers
|
# Include routers
|
||||||
app.include_router(health_router)
|
app.include_router(health_router)
|
||||||
app.include_router(agents_router)
|
app.include_router(agents_router)
|
||||||
@@ -339,6 +344,7 @@ app.include_router(tasks_router)
|
|||||||
app.include_router(work_orders_router)
|
app.include_router(work_orders_router)
|
||||||
app.include_router(system_router)
|
app.include_router(system_router)
|
||||||
app.include_router(paperclip_router)
|
app.include_router(paperclip_router)
|
||||||
|
app.include_router(experiments_router)
|
||||||
app.include_router(cascade_router)
|
app.include_router(cascade_router)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
"""Dashboard middleware package."""
|
"""Dashboard middleware package."""
|
||||||
|
|
||||||
from .csrf import CSRFMiddleware, csrf_exempt, generate_csrf_token, validate_csrf_token
|
from .csrf import CSRFMiddleware, csrf_exempt, generate_csrf_token, validate_csrf_token
|
||||||
from .security_headers import SecurityHeadersMiddleware
|
|
||||||
from .request_logging import RequestLoggingMiddleware
|
from .request_logging import RequestLoggingMiddleware
|
||||||
|
from .security_headers import SecurityHeadersMiddleware
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CSRFMiddleware",
|
"CSRFMiddleware",
|
||||||
|
|||||||
@@ -4,16 +4,15 @@ Provides CSRF token generation, validation, and middleware integration
|
|||||||
to protect state-changing endpoints from cross-site request attacks.
|
to protect state-changing endpoints from cross-site request attacks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import secrets
|
|
||||||
import hmac
|
|
||||||
import hashlib
|
import hashlib
|
||||||
from typing import Callable, Optional
|
import hmac
|
||||||
|
import secrets
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import Response, JSONResponse
|
from starlette.responses import JSONResponse, Response
|
||||||
|
|
||||||
|
|
||||||
# Module-level set to track exempt routes
|
# Module-level set to track exempt routes
|
||||||
_exempt_routes: set[str] = set()
|
_exempt_routes: set[str] = set()
|
||||||
@@ -21,26 +20,27 @@ _exempt_routes: set[str] = set()
|
|||||||
|
|
||||||
def csrf_exempt(endpoint: Callable) -> Callable:
|
def csrf_exempt(endpoint: Callable) -> Callable:
|
||||||
"""Decorator to mark an endpoint as exempt from CSRF validation.
|
"""Decorator to mark an endpoint as exempt from CSRF validation.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
@app.post("/webhook")
|
@app.post("/webhook")
|
||||||
@csrf_exempt
|
@csrf_exempt
|
||||||
def webhook_endpoint():
|
def webhook_endpoint():
|
||||||
...
|
...
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@wraps(endpoint)
|
@wraps(endpoint)
|
||||||
async def async_wrapper(*args, **kwargs):
|
async def async_wrapper(*args, **kwargs):
|
||||||
return await endpoint(*args, **kwargs)
|
return await endpoint(*args, **kwargs)
|
||||||
|
|
||||||
@wraps(endpoint)
|
@wraps(endpoint)
|
||||||
def sync_wrapper(*args, **kwargs):
|
def sync_wrapper(*args, **kwargs):
|
||||||
return endpoint(*args, **kwargs)
|
return endpoint(*args, **kwargs)
|
||||||
|
|
||||||
# Mark the original function as exempt
|
# Mark the original function as exempt
|
||||||
endpoint._csrf_exempt = True # type: ignore
|
endpoint._csrf_exempt = True # type: ignore
|
||||||
|
|
||||||
# Also mark the wrapper
|
# Also mark the wrapper
|
||||||
if hasattr(endpoint, '__code__') and endpoint.__code__.co_flags & 0x80:
|
if hasattr(endpoint, "__code__") and endpoint.__code__.co_flags & 0x80:
|
||||||
async_wrapper._csrf_exempt = True # type: ignore
|
async_wrapper._csrf_exempt = True # type: ignore
|
||||||
return async_wrapper
|
return async_wrapper
|
||||||
else:
|
else:
|
||||||
@@ -50,12 +50,12 @@ def csrf_exempt(endpoint: Callable) -> Callable:
|
|||||||
|
|
||||||
def is_csrf_exempt(endpoint: Callable) -> bool:
|
def is_csrf_exempt(endpoint: Callable) -> bool:
|
||||||
"""Check if an endpoint is marked as CSRF exempt."""
|
"""Check if an endpoint is marked as CSRF exempt."""
|
||||||
return getattr(endpoint, '_csrf_exempt', False)
|
return getattr(endpoint, "_csrf_exempt", False)
|
||||||
|
|
||||||
|
|
||||||
def generate_csrf_token() -> str:
|
def generate_csrf_token() -> str:
|
||||||
"""Generate a cryptographically secure CSRF token.
|
"""Generate a cryptographically secure CSRF token.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A secure random token string.
|
A secure random token string.
|
||||||
"""
|
"""
|
||||||
@@ -64,77 +64,78 @@ def generate_csrf_token() -> str:
|
|||||||
|
|
||||||
def validate_csrf_token(token: str, expected_token: str) -> bool:
|
def validate_csrf_token(token: str, expected_token: str) -> bool:
|
||||||
"""Validate a CSRF token against the expected token.
|
"""Validate a CSRF token against the expected token.
|
||||||
|
|
||||||
Uses constant-time comparison to prevent timing attacks.
|
Uses constant-time comparison to prevent timing attacks.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token: The token provided by the client.
|
token: The token provided by the client.
|
||||||
expected_token: The expected token (from cookie/session).
|
expected_token: The expected token (from cookie/session).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if the token is valid, False otherwise.
|
True if the token is valid, False otherwise.
|
||||||
"""
|
"""
|
||||||
if not token or not expected_token:
|
if not token or not expected_token:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return hmac.compare_digest(token, expected_token)
|
return hmac.compare_digest(token, expected_token)
|
||||||
|
|
||||||
|
|
||||||
class CSRFMiddleware(BaseHTTPMiddleware):
|
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||||
"""Middleware to enforce CSRF protection on state-changing requests.
|
"""Middleware to enforce CSRF protection on state-changing requests.
|
||||||
|
|
||||||
Safe methods (GET, HEAD, OPTIONS, TRACE) are allowed without CSRF tokens.
|
Safe methods (GET, HEAD, OPTIONS, TRACE) are allowed without CSRF tokens.
|
||||||
State-changing methods (POST, PUT, DELETE, PATCH) require a valid CSRF token.
|
State-changing methods (POST, PUT, DELETE, PATCH) require a valid CSRF token.
|
||||||
|
|
||||||
The token is expected to be:
|
The token is expected to be:
|
||||||
- In the X-CSRF-Token header, or
|
- In the X-CSRF-Token header, or
|
||||||
- In the request body as 'csrf_token', or
|
- In the request body as 'csrf_token', or
|
||||||
- Matching the token in the csrf_token cookie
|
- Matching the token in the csrf_token cookie
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
app.add_middleware(CSRFMiddleware, secret="your-secret-key")
|
app.add_middleware(CSRFMiddleware, secret="your-secret-key")
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
secret: Secret key for token signing (optional, for future use).
|
secret: Secret key for token signing (optional, for future use).
|
||||||
cookie_name: Name of the CSRF cookie.
|
cookie_name: Name of the CSRF cookie.
|
||||||
header_name: Name of the CSRF header.
|
header_name: Name of the CSRF header.
|
||||||
safe_methods: HTTP methods that don't require CSRF tokens.
|
safe_methods: HTTP methods that don't require CSRF tokens.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SAFE_METHODS = {"GET", "HEAD", "OPTIONS", "TRACE"}
|
SAFE_METHODS = {"GET", "HEAD", "OPTIONS", "TRACE"}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
app,
|
app,
|
||||||
secret: Optional[str] = None,
|
secret: Optional[str] = None,
|
||||||
cookie_name: str = "csrf_token",
|
cookie_name: str = "csrf_token",
|
||||||
header_name: str = "X-CSRF-Token",
|
header_name: str = "X-CSRF-Token",
|
||||||
form_field: str = "csrf_token"
|
form_field: str = "csrf_token",
|
||||||
):
|
):
|
||||||
super().__init__(app)
|
super().__init__(app)
|
||||||
self.secret = secret
|
self.secret = secret
|
||||||
self.cookie_name = cookie_name
|
self.cookie_name = cookie_name
|
||||||
self.header_name = header_name
|
self.header_name = header_name
|
||||||
self.form_field = form_field
|
self.form_field = form_field
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next) -> Response:
|
async def dispatch(self, request: Request, call_next) -> Response:
|
||||||
"""Process the request and enforce CSRF protection.
|
"""Process the request and enforce CSRF protection.
|
||||||
|
|
||||||
For safe methods: Set a CSRF token cookie if not present.
|
For safe methods: Set a CSRF token cookie if not present.
|
||||||
For unsafe methods: Validate the CSRF token.
|
For unsafe methods: Validate the CSRF token.
|
||||||
"""
|
"""
|
||||||
# Bypass CSRF if explicitly disabled (e.g. in tests)
|
# Bypass CSRF if explicitly disabled (e.g. in tests)
|
||||||
from config import settings
|
from config import settings
|
||||||
|
|
||||||
if settings.timmy_disable_csrf:
|
if settings.timmy_disable_csrf:
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
# Get existing CSRF token from cookie
|
# Get existing CSRF token from cookie
|
||||||
csrf_cookie = request.cookies.get(self.cookie_name)
|
csrf_cookie = request.cookies.get(self.cookie_name)
|
||||||
|
|
||||||
# For safe methods, just ensure a token exists
|
# For safe methods, just ensure a token exists
|
||||||
if request.method in self.SAFE_METHODS:
|
if request.method in self.SAFE_METHODS:
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
|
|
||||||
# Set CSRF token cookie if not present
|
# Set CSRF token cookie if not present
|
||||||
if not csrf_cookie:
|
if not csrf_cookie:
|
||||||
new_token = generate_csrf_token()
|
new_token = generate_csrf_token()
|
||||||
@@ -144,15 +145,15 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
|||||||
httponly=False, # Must be readable by JavaScript
|
httponly=False, # Must be readable by JavaScript
|
||||||
secure=settings.csrf_cookie_secure,
|
secure=settings.csrf_cookie_secure,
|
||||||
samesite="Lax",
|
samesite="Lax",
|
||||||
max_age=86400 # 24 hours
|
max_age=86400, # 24 hours
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
# For unsafe methods, check if route is exempt first
|
# For unsafe methods, check if route is exempt first
|
||||||
# Note: We need to let the request proceed and check at response time
|
# Note: We need to let the request proceed and check at response time
|
||||||
# since FastAPI routes are resolved after middleware
|
# since FastAPI routes are resolved after middleware
|
||||||
|
|
||||||
# Try to validate token early
|
# Try to validate token early
|
||||||
if not await self._validate_request(request, csrf_cookie):
|
if not await self._validate_request(request, csrf_cookie):
|
||||||
# Check if this might be an exempt route by checking path patterns
|
# Check if this might be an exempt route by checking path patterns
|
||||||
@@ -164,33 +165,34 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
|||||||
content={
|
content={
|
||||||
"error": "CSRF validation failed",
|
"error": "CSRF validation failed",
|
||||||
"code": "CSRF_INVALID",
|
"code": "CSRF_INVALID",
|
||||||
"message": "Missing or invalid CSRF token. Include the token from the csrf_token cookie in the X-CSRF-Token header or as a form field."
|
"message": "Missing or invalid CSRF token. Include the token from the csrf_token cookie in the X-CSRF-Token header or as a form field.",
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
def _is_likely_exempt(self, path: str) -> bool:
|
def _is_likely_exempt(self, path: str) -> bool:
|
||||||
"""Check if a path is likely to be CSRF exempt.
|
"""Check if a path is likely to be CSRF exempt.
|
||||||
|
|
||||||
Common patterns like webhooks, API endpoints, etc.
|
Common patterns like webhooks, API endpoints, etc.
|
||||||
Uses path normalization and exact/prefix matching to prevent bypasses.
|
Uses path normalization and exact/prefix matching to prevent bypasses.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path: The request path.
|
path: The request path.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if the path is likely exempt.
|
True if the path is likely exempt.
|
||||||
"""
|
"""
|
||||||
# 1. Normalize path to prevent /webhook/../ bypasses
|
# 1. Normalize path to prevent /webhook/../ bypasses
|
||||||
# Use posixpath for consistent behavior on all platforms
|
# Use posixpath for consistent behavior on all platforms
|
||||||
import posixpath
|
import posixpath
|
||||||
|
|
||||||
normalized_path = posixpath.normpath(path)
|
normalized_path = posixpath.normpath(path)
|
||||||
|
|
||||||
# Ensure it starts with / for comparison
|
# Ensure it starts with / for comparison
|
||||||
if not normalized_path.startswith("/"):
|
if not normalized_path.startswith("/"):
|
||||||
normalized_path = "/" + normalized_path
|
normalized_path = "/" + normalized_path
|
||||||
|
|
||||||
# Add back trailing slash if it was present in original path
|
# Add back trailing slash if it was present in original path
|
||||||
# to ensure prefix matching behaves as expected
|
# to ensure prefix matching behaves as expected
|
||||||
if path.endswith("/") and not normalized_path.endswith("/"):
|
if path.endswith("/") and not normalized_path.endswith("/"):
|
||||||
@@ -200,15 +202,15 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
|||||||
# Patterns ending with / are prefix-matched
|
# Patterns ending with / are prefix-matched
|
||||||
# Patterns NOT ending with / are exact-matched
|
# Patterns NOT ending with / are exact-matched
|
||||||
exempt_patterns = [
|
exempt_patterns = [
|
||||||
"/webhook/", # Prefix match (e.g., /webhook/stripe)
|
"/webhook/", # Prefix match (e.g., /webhook/stripe)
|
||||||
"/webhook", # Exact match
|
"/webhook", # Exact match
|
||||||
"/api/v1/", # Prefix match
|
"/api/v1/", # Prefix match
|
||||||
"/lightning/webhook/", # Prefix match
|
"/lightning/webhook/", # Prefix match
|
||||||
"/lightning/webhook", # Exact match
|
"/lightning/webhook", # Exact match
|
||||||
"/_internal/", # Prefix match
|
"/_internal/", # Prefix match
|
||||||
"/_internal", # Exact match
|
"/_internal", # Exact match
|
||||||
]
|
]
|
||||||
|
|
||||||
for pattern in exempt_patterns:
|
for pattern in exempt_patterns:
|
||||||
if pattern.endswith("/"):
|
if pattern.endswith("/"):
|
||||||
if normalized_path.startswith(pattern):
|
if normalized_path.startswith(pattern):
|
||||||
@@ -216,20 +218,20 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
|||||||
else:
|
else:
|
||||||
if normalized_path == pattern:
|
if normalized_path == pattern:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def _validate_request(self, request: Request, csrf_cookie: Optional[str]) -> bool:
|
async def _validate_request(self, request: Request, csrf_cookie: Optional[str]) -> bool:
|
||||||
"""Validate the CSRF token in the request.
|
"""Validate the CSRF token in the request.
|
||||||
|
|
||||||
Checks for token in:
|
Checks for token in:
|
||||||
1. X-CSRF-Token header
|
1. X-CSRF-Token header
|
||||||
2. csrf_token form field
|
2. csrf_token form field
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: The incoming request.
|
request: The incoming request.
|
||||||
csrf_cookie: The expected token from the cookie.
|
csrf_cookie: The expected token from the cookie.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if the token is valid, False otherwise.
|
True if the token is valid, False otherwise.
|
||||||
"""
|
"""
|
||||||
@@ -241,11 +243,14 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
|||||||
header_token = request.headers.get(self.header_name)
|
header_token = request.headers.get(self.header_name)
|
||||||
if header_token and validate_csrf_token(header_token, csrf_cookie):
|
if header_token and validate_csrf_token(header_token, csrf_cookie):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# If no header token, try form data (for non-JSON POSTs)
|
# If no header token, try form data (for non-JSON POSTs)
|
||||||
# Check Content-Type to avoid hanging on non-form requests
|
# Check Content-Type to avoid hanging on non-form requests
|
||||||
content_type = request.headers.get("Content-Type", "")
|
content_type = request.headers.get("Content-Type", "")
|
||||||
if "application/x-www-form-urlencoded" in content_type or "multipart/form-data" in content_type:
|
if (
|
||||||
|
"application/x-www-form-urlencoded" in content_type
|
||||||
|
or "multipart/form-data" in content_type
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
form_data = await request.form()
|
form_data = await request.form()
|
||||||
form_token = form_data.get(self.form_field)
|
form_token = form_data.get(self.form_field)
|
||||||
@@ -254,5 +259,5 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
|||||||
except Exception:
|
except Exception:
|
||||||
# Error parsing form data, treat as invalid
|
# Error parsing form data, treat as invalid
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -4,22 +4,21 @@ Logs HTTP requests with timing, status codes, and client information
|
|||||||
for monitoring and debugging purposes.
|
for monitoring and debugging purposes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
import logging
|
from typing import List, Optional
|
||||||
from typing import Optional, List
|
|
||||||
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger("timmy.requests")
|
logger = logging.getLogger("timmy.requests")
|
||||||
|
|
||||||
|
|
||||||
class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||||
"""Middleware to log all HTTP requests.
|
"""Middleware to log all HTTP requests.
|
||||||
|
|
||||||
Logs the following information for each request:
|
Logs the following information for each request:
|
||||||
- HTTP method and path
|
- HTTP method and path
|
||||||
- Response status code
|
- Response status code
|
||||||
@@ -27,60 +26,55 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
|||||||
- Client IP address
|
- Client IP address
|
||||||
- User-Agent header
|
- User-Agent header
|
||||||
- Correlation ID for tracing
|
- Correlation ID for tracing
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
app.add_middleware(RequestLoggingMiddleware)
|
app.add_middleware(RequestLoggingMiddleware)
|
||||||
|
|
||||||
# Skip certain paths:
|
# Skip certain paths:
|
||||||
app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health", "/metrics"])
|
app.add_middleware(RequestLoggingMiddleware, skip_paths=["/health", "/metrics"])
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
skip_paths: List of URL paths to skip logging.
|
skip_paths: List of URL paths to skip logging.
|
||||||
log_level: Logging level for successful requests.
|
log_level: Logging level for successful requests.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, app, skip_paths: Optional[List[str]] = None, log_level: int = logging.INFO):
|
||||||
self,
|
|
||||||
app,
|
|
||||||
skip_paths: Optional[List[str]] = None,
|
|
||||||
log_level: int = logging.INFO
|
|
||||||
):
|
|
||||||
super().__init__(app)
|
super().__init__(app)
|
||||||
self.skip_paths = set(skip_paths or [])
|
self.skip_paths = set(skip_paths or [])
|
||||||
self.log_level = log_level
|
self.log_level = log_level
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next) -> Response:
|
async def dispatch(self, request: Request, call_next) -> Response:
|
||||||
"""Log the request and response details.
|
"""Log the request and response details.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: The incoming request.
|
request: The incoming request.
|
||||||
call_next: Callable to get the response from downstream.
|
call_next: Callable to get the response from downstream.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The response from downstream.
|
The response from downstream.
|
||||||
"""
|
"""
|
||||||
# Check if we should skip logging this path
|
# Check if we should skip logging this path
|
||||||
if request.url.path in self.skip_paths:
|
if request.url.path in self.skip_paths:
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
# Generate correlation ID
|
# Generate correlation ID
|
||||||
correlation_id = str(uuid.uuid4())[:8]
|
correlation_id = str(uuid.uuid4())[:8]
|
||||||
request.state.correlation_id = correlation_id
|
request.state.correlation_id = correlation_id
|
||||||
|
|
||||||
# Record start time
|
# Record start time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# Get client info
|
# Get client info
|
||||||
client_ip = self._get_client_ip(request)
|
client_ip = self._get_client_ip(request)
|
||||||
user_agent = request.headers.get("user-agent", "-")
|
user_agent = request.headers.get("user-agent", "-")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Process the request
|
# Process the request
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
|
|
||||||
# Calculate duration
|
# Calculate duration
|
||||||
duration_ms = (time.time() - start_time) * 1000
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
# Log the request
|
# Log the request
|
||||||
self._log_request(
|
self._log_request(
|
||||||
method=request.method,
|
method=request.method,
|
||||||
@@ -89,14 +83,14 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
|||||||
duration_ms=duration_ms,
|
duration_ms=duration_ms,
|
||||||
client_ip=client_ip,
|
client_ip=client_ip,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
correlation_id=correlation_id
|
correlation_id=correlation_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add correlation ID to response headers
|
# Add correlation ID to response headers
|
||||||
response.headers["X-Correlation-ID"] = correlation_id
|
response.headers["X-Correlation-ID"] = correlation_id
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
# Calculate duration even for failed requests
|
# Calculate duration even for failed requests
|
||||||
duration_ms = (time.time() - start_time) * 1000
|
duration_ms = (time.time() - start_time) * 1000
|
||||||
@@ -110,6 +104,7 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
|||||||
# Auto-escalate: create bug report task from unhandled exception
|
# Auto-escalate: create bug report task from unhandled exception
|
||||||
try:
|
try:
|
||||||
from infrastructure.error_capture import capture_error
|
from infrastructure.error_capture import capture_error
|
||||||
|
|
||||||
capture_error(
|
capture_error(
|
||||||
exc,
|
exc,
|
||||||
source="http",
|
source="http",
|
||||||
@@ -126,16 +121,16 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
# Re-raise the exception
|
# Re-raise the exception
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _get_client_ip(self, request: Request) -> str:
|
def _get_client_ip(self, request: Request) -> str:
|
||||||
"""Extract the client IP address from the request.
|
"""Extract the client IP address from the request.
|
||||||
|
|
||||||
Checks X-Forwarded-For and X-Real-IP headers first for proxied requests,
|
Checks X-Forwarded-For and X-Real-IP headers first for proxied requests,
|
||||||
falls back to the direct client IP.
|
falls back to the direct client IP.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: The incoming request.
|
request: The incoming request.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Client IP address string.
|
Client IP address string.
|
||||||
"""
|
"""
|
||||||
@@ -144,17 +139,17 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
|||||||
if forwarded_for:
|
if forwarded_for:
|
||||||
# X-Forwarded-For can contain multiple IPs, take the first one
|
# X-Forwarded-For can contain multiple IPs, take the first one
|
||||||
return forwarded_for.split(",")[0].strip()
|
return forwarded_for.split(",")[0].strip()
|
||||||
|
|
||||||
real_ip = request.headers.get("x-real-ip")
|
real_ip = request.headers.get("x-real-ip")
|
||||||
if real_ip:
|
if real_ip:
|
||||||
return real_ip
|
return real_ip
|
||||||
|
|
||||||
# Fall back to direct connection
|
# Fall back to direct connection
|
||||||
if request.client:
|
if request.client:
|
||||||
return request.client.host
|
return request.client.host
|
||||||
|
|
||||||
return "-"
|
return "-"
|
||||||
|
|
||||||
def _log_request(
|
def _log_request(
|
||||||
self,
|
self,
|
||||||
method: str,
|
method: str,
|
||||||
@@ -163,10 +158,10 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
|||||||
duration_ms: float,
|
duration_ms: float,
|
||||||
client_ip: str,
|
client_ip: str,
|
||||||
user_agent: str,
|
user_agent: str,
|
||||||
correlation_id: str
|
correlation_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Format and log the request details.
|
"""Format and log the request details.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
method: HTTP method.
|
method: HTTP method.
|
||||||
path: Request path.
|
path: Request path.
|
||||||
@@ -182,14 +177,14 @@ class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
|||||||
level = logging.ERROR
|
level = logging.ERROR
|
||||||
elif status_code >= 400:
|
elif status_code >= 400:
|
||||||
level = logging.WARNING
|
level = logging.WARNING
|
||||||
|
|
||||||
message = (
|
message = (
|
||||||
f"[{correlation_id}] {method} {path} - {status_code} "
|
f"[{correlation_id}] {method} {path} - {status_code} "
|
||||||
f"- {duration_ms:.2f}ms - {client_ip}"
|
f"- {duration_ms:.2f}ms - {client_ip}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add user agent for non-health requests
|
# Add user agent for non-health requests
|
||||||
if path not in self.skip_paths:
|
if path not in self.skip_paths:
|
||||||
message += f" - {user_agent[:50]}"
|
message += f" - {user_agent[:50]}"
|
||||||
|
|
||||||
logger.log(level, message)
|
logger.log(level, message)
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ Adds common security headers to all HTTP responses to improve
|
|||||||
application security posture against various attacks.
|
application security posture against various attacks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
@@ -11,7 +13,7 @@ from starlette.responses import Response
|
|||||||
|
|
||||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||||
"""Middleware to add security headers to all responses.
|
"""Middleware to add security headers to all responses.
|
||||||
|
|
||||||
Adds the following headers:
|
Adds the following headers:
|
||||||
- X-Content-Type-Options: Prevents MIME type sniffing
|
- X-Content-Type-Options: Prevents MIME type sniffing
|
||||||
- X-Frame-Options: Prevents clickjacking
|
- X-Frame-Options: Prevents clickjacking
|
||||||
@@ -20,41 +22,41 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
|||||||
- Permissions-Policy: Restricts feature access
|
- Permissions-Policy: Restricts feature access
|
||||||
- Content-Security-Policy: Mitigates XSS and data injection
|
- Content-Security-Policy: Mitigates XSS and data injection
|
||||||
- Strict-Transport-Security: Enforces HTTPS (production only)
|
- Strict-Transport-Security: Enforces HTTPS (production only)
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
app.add_middleware(SecurityHeadersMiddleware)
|
app.add_middleware(SecurityHeadersMiddleware)
|
||||||
|
|
||||||
# Or with production settings:
|
# Or with production settings:
|
||||||
app.add_middleware(SecurityHeadersMiddleware, production=True)
|
app.add_middleware(SecurityHeadersMiddleware, production=True)
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
production: If True, adds HSTS header for HTTPS enforcement.
|
production: If True, adds HSTS header for HTTPS enforcement.
|
||||||
csp_report_only: If True, sends CSP in report-only mode.
|
csp_report_only: If True, sends CSP in report-only mode.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
app,
|
app,
|
||||||
production: bool = False,
|
production: bool = False,
|
||||||
csp_report_only: bool = False,
|
csp_report_only: bool = False,
|
||||||
custom_csp: str = None
|
custom_csp: Optional[str] = None,
|
||||||
):
|
):
|
||||||
super().__init__(app)
|
super().__init__(app)
|
||||||
self.production = production
|
self.production = production
|
||||||
self.csp_report_only = csp_report_only
|
self.csp_report_only = csp_report_only
|
||||||
|
|
||||||
# Build CSP directive
|
# Build CSP directive
|
||||||
self.csp_directive = custom_csp or self._build_csp()
|
self.csp_directive = custom_csp or self._build_csp()
|
||||||
|
|
||||||
def _build_csp(self) -> str:
|
def _build_csp(self) -> str:
|
||||||
"""Build the Content-Security-Policy directive.
|
"""Build the Content-Security-Policy directive.
|
||||||
|
|
||||||
Creates a restrictive default policy that allows:
|
Creates a restrictive default policy that allows:
|
||||||
- Same-origin resources by default
|
- Same-origin resources by default
|
||||||
- Inline scripts/styles (needed for HTMX/Bootstrap)
|
- Inline scripts/styles (needed for HTMX/Bootstrap)
|
||||||
- Data URIs for images
|
- Data URIs for images
|
||||||
- WebSocket connections
|
- WebSocket connections
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
CSP directive string.
|
CSP directive string.
|
||||||
"""
|
"""
|
||||||
@@ -73,25 +75,25 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
|||||||
"form-action 'self'",
|
"form-action 'self'",
|
||||||
]
|
]
|
||||||
return "; ".join(directives)
|
return "; ".join(directives)
|
||||||
|
|
||||||
def _add_security_headers(self, response: Response) -> None:
|
def _add_security_headers(self, response: Response) -> None:
|
||||||
"""Add security headers to a response.
|
"""Add security headers to a response.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
response: The response to add headers to.
|
response: The response to add headers to.
|
||||||
"""
|
"""
|
||||||
# Prevent MIME type sniffing
|
# Prevent MIME type sniffing
|
||||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||||
|
|
||||||
# Prevent clickjacking
|
# Prevent clickjacking
|
||||||
response.headers["X-Frame-Options"] = "SAMEORIGIN"
|
response.headers["X-Frame-Options"] = "SAMEORIGIN"
|
||||||
|
|
||||||
# Enable XSS protection (legacy browsers)
|
# Enable XSS protection (legacy browsers)
|
||||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||||
|
|
||||||
# Control referrer information
|
# Control referrer information
|
||||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||||
|
|
||||||
# Restrict browser features
|
# Restrict browser features
|
||||||
response.headers["Permissions-Policy"] = (
|
response.headers["Permissions-Policy"] = (
|
||||||
"camera=(), "
|
"camera=(), "
|
||||||
@@ -103,38 +105,41 @@ class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
|||||||
"gyroscope=(), "
|
"gyroscope=(), "
|
||||||
"accelerometer=()"
|
"accelerometer=()"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Content Security Policy
|
# Content Security Policy
|
||||||
csp_header = "Content-Security-Policy-Report-Only" if self.csp_report_only else "Content-Security-Policy"
|
csp_header = (
|
||||||
|
"Content-Security-Policy-Report-Only"
|
||||||
|
if self.csp_report_only
|
||||||
|
else "Content-Security-Policy"
|
||||||
|
)
|
||||||
response.headers[csp_header] = self.csp_directive
|
response.headers[csp_header] = self.csp_directive
|
||||||
|
|
||||||
# HTTPS enforcement (production only)
|
# HTTPS enforcement (production only)
|
||||||
if self.production:
|
if self.production:
|
||||||
response.headers["Strict-Transport-Security"] = (
|
response.headers[
|
||||||
"max-age=31536000; includeSubDomains; preload"
|
"Strict-Transport-Security"
|
||||||
)
|
] = "max-age=31536000; includeSubDomains; preload"
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next) -> Response:
|
async def dispatch(self, request: Request, call_next) -> Response:
|
||||||
"""Add security headers to the response.
|
"""Add security headers to the response.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: The incoming request.
|
request: The incoming request.
|
||||||
call_next: Callable to get the response from downstream.
|
call_next: Callable to get the response from downstream.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Response with security headers added.
|
Response with security headers added.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
self._add_security_headers(response)
|
|
||||||
return response
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# Create a response for the error with security headers
|
import logging
|
||||||
from starlette.responses import PlainTextResponse
|
|
||||||
response = PlainTextResponse(
|
logging.getLogger(__name__).debug(
|
||||||
content="Internal Server Error",
|
"Upstream error in security headers middleware", exc_info=True
|
||||||
status_code=500
|
|
||||||
)
|
)
|
||||||
self._add_security_headers(response)
|
from starlette.responses import PlainTextResponse
|
||||||
# Return the error response with headers (don't re-raise)
|
|
||||||
return response
|
response = PlainTextResponse("Internal Server Error", status_code=500)
|
||||||
|
self._add_security_headers(response)
|
||||||
|
return response
|
||||||
|
|||||||
@@ -1,24 +1,27 @@
|
|||||||
|
from datetime import date, datetime
|
||||||
from datetime import datetime, date
|
|
||||||
from enum import Enum as PyEnum
|
from enum import Enum as PyEnum
|
||||||
from sqlalchemy import (
|
|
||||||
Column, Integer, String, DateTime, Boolean, Enum as SQLEnum,
|
from sqlalchemy import JSON, Boolean, Column, Date, DateTime
|
||||||
Date, ForeignKey, Index, JSON
|
from sqlalchemy import Enum as SQLEnum
|
||||||
)
|
from sqlalchemy import ForeignKey, Index, Integer, String
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
from .database import Base # Assuming a shared Base in models/database.py
|
from .database import Base # Assuming a shared Base in models/database.py
|
||||||
|
|
||||||
|
|
||||||
class TaskState(str, PyEnum):
|
class TaskState(str, PyEnum):
|
||||||
LATER = "LATER"
|
LATER = "LATER"
|
||||||
NEXT = "NEXT"
|
NEXT = "NEXT"
|
||||||
NOW = "NOW"
|
NOW = "NOW"
|
||||||
DONE = "DONE"
|
DONE = "DONE"
|
||||||
DEFERRED = "DEFERRED" # Task pushed to tomorrow
|
DEFERRED = "DEFERRED" # Task pushed to tomorrow
|
||||||
|
|
||||||
|
|
||||||
class TaskCertainty(str, PyEnum):
|
class TaskCertainty(str, PyEnum):
|
||||||
FUZZY = "FUZZY" # An intention without a time
|
FUZZY = "FUZZY" # An intention without a time
|
||||||
SOFT = "SOFT" # A flexible task with a time
|
SOFT = "SOFT" # A flexible task with a time
|
||||||
HARD = "HARD" # A fixed meeting/appointment
|
HARD = "HARD" # A fixed meeting/appointment
|
||||||
|
|
||||||
|
|
||||||
class Task(Base):
|
class Task(Base):
|
||||||
__tablename__ = "tasks"
|
__tablename__ = "tasks"
|
||||||
@@ -29,7 +32,7 @@ class Task(Base):
|
|||||||
|
|
||||||
state = Column(SQLEnum(TaskState), default=TaskState.LATER, nullable=False, index=True)
|
state = Column(SQLEnum(TaskState), default=TaskState.LATER, nullable=False, index=True)
|
||||||
certainty = Column(SQLEnum(TaskCertainty), default=TaskCertainty.SOFT, nullable=False)
|
certainty = Column(SQLEnum(TaskCertainty), default=TaskCertainty.SOFT, nullable=False)
|
||||||
is_mit = Column(Boolean, default=False, nullable=False) # 1-3 per day
|
is_mit = Column(Boolean, default=False, nullable=False) # 1-3 per day
|
||||||
|
|
||||||
sort_order = Column(Integer, default=0, nullable=False)
|
sort_order = Column(Integer, default=0, nullable=False)
|
||||||
|
|
||||||
@@ -42,7 +45,8 @@ class Task(Base):
|
|||||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
|
||||||
|
|
||||||
__table_args__ = (Index('ix_task_state_order', 'state', 'sort_order'),)
|
__table_args__ = (Index("ix_task_state_order", "state", "sort_order"),)
|
||||||
|
|
||||||
|
|
||||||
class JournalEntry(Base):
|
class JournalEntry(Base):
|
||||||
__tablename__ = "journal_entries"
|
__tablename__ = "journal_entries"
|
||||||
|
|||||||
@@ -1,17 +1,16 @@
|
|||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.orm import sessionmaker, Session
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
SQLALCHEMY_DATABASE_URL = "sqlite:///./data/timmy_calm.db"
|
SQLALCHEMY_DATABASE_URL = "sqlite:///./data/timmy_calm.db"
|
||||||
|
|
||||||
engine = create_engine(
|
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
|
||||||
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
|
|
||||||
)
|
|
||||||
|
|
||||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
def create_tables():
|
def create_tables():
|
||||||
"""Create all tables defined by models that have imported Base."""
|
"""Create all tables defined by models that have imported Base."""
|
||||||
Base.metadata.create_all(bind=engine)
|
Base.metadata.create_all(bind=engine)
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ from datetime import datetime
|
|||||||
from fastapi import APIRouter, Form, Request
|
from fastapi import APIRouter, Form, Request
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
|
|
||||||
from timmy.session import chat as agent_chat
|
|
||||||
from dashboard.store import message_log
|
from dashboard.store import message_log
|
||||||
from dashboard.templating import templates
|
from dashboard.templating import templates
|
||||||
|
from timmy.session import chat as agent_chat
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -38,9 +38,7 @@ async def list_agents():
|
|||||||
@router.get("/default/panel", response_class=HTMLResponse)
|
@router.get("/default/panel", response_class=HTMLResponse)
|
||||||
async def agent_panel(request: Request):
|
async def agent_panel(request: Request):
|
||||||
"""Chat panel — for HTMX main-panel swaps."""
|
"""Chat panel — for HTMX main-panel swaps."""
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(request, "partials/agent_panel_chat.html", {"agent": None})
|
||||||
request, "partials/agent_panel_chat.html", {"agent": None}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/default/history", response_class=HTMLResponse)
|
@router.get("/default/history", response_class=HTMLResponse)
|
||||||
@@ -77,7 +75,9 @@ async def chat_agent(request: Request, message: str = Form(...)):
|
|||||||
|
|
||||||
message_log.append(role="user", content=message, timestamp=timestamp, source="browser")
|
message_log.append(role="user", content=message, timestamp=timestamp, source="browser")
|
||||||
if response_text is not None:
|
if response_text is not None:
|
||||||
message_log.append(role="agent", content=response_text, timestamp=timestamp, source="browser")
|
message_log.append(
|
||||||
|
role="agent", content=response_text, timestamp=timestamp, source="browser"
|
||||||
|
)
|
||||||
elif error_text:
|
elif error_text:
|
||||||
message_log.append(role="error", content=error_text, timestamp=timestamp, source="browser")
|
message_log.append(role="error", content=error_text, timestamp=timestamp, source="browser")
|
||||||
|
|
||||||
|
|||||||
@@ -12,9 +12,10 @@ from datetime import datetime, timezone
|
|||||||
from fastapi import APIRouter, Request
|
from fastapi import APIRouter, Request
|
||||||
from fastapi.responses import HTMLResponse, JSONResponse
|
from fastapi.responses import HTMLResponse, JSONResponse
|
||||||
|
|
||||||
from timmy.briefing import Briefing, engine as briefing_engine
|
|
||||||
from timmy import approvals as approval_store
|
|
||||||
from dashboard.templating import templates
|
from dashboard.templating import templates
|
||||||
|
from timmy import approvals as approval_store
|
||||||
|
from timmy.briefing import Briefing
|
||||||
|
from timmy.briefing import engine as briefing_engine
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
@@ -8,7 +7,7 @@ from fastapi.responses import HTMLResponse
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from dashboard.models.calm import JournalEntry, Task, TaskCertainty, TaskState
|
from dashboard.models.calm import JournalEntry, Task, TaskCertainty, TaskState
|
||||||
from dashboard.models.database import SessionLocal, engine, get_db, create_tables
|
from dashboard.models.database import SessionLocal, create_tables, engine, get_db
|
||||||
from dashboard.templating import templates
|
from dashboard.templating import templates
|
||||||
|
|
||||||
# Ensure CALM tables exist (safe to call multiple times)
|
# Ensure CALM tables exist (safe to call multiple times)
|
||||||
@@ -23,11 +22,19 @@ router = APIRouter(tags=["calm"])
|
|||||||
def get_now_task(db: Session) -> Optional[Task]:
|
def get_now_task(db: Session) -> Optional[Task]:
|
||||||
return db.query(Task).filter(Task.state == TaskState.NOW).first()
|
return db.query(Task).filter(Task.state == TaskState.NOW).first()
|
||||||
|
|
||||||
|
|
||||||
def get_next_task(db: Session) -> Optional[Task]:
|
def get_next_task(db: Session) -> Optional[Task]:
|
||||||
return db.query(Task).filter(Task.state == TaskState.NEXT).first()
|
return db.query(Task).filter(Task.state == TaskState.NEXT).first()
|
||||||
|
|
||||||
|
|
||||||
def get_later_tasks(db: Session) -> List[Task]:
|
def get_later_tasks(db: Session) -> List[Task]:
|
||||||
return db.query(Task).filter(Task.state == TaskState.LATER).order_by(Task.is_mit.desc(), Task.sort_order).all()
|
return (
|
||||||
|
db.query(Task)
|
||||||
|
.filter(Task.state == TaskState.LATER)
|
||||||
|
.order_by(Task.is_mit.desc(), Task.sort_order)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def promote_tasks(db: Session):
|
def promote_tasks(db: Session):
|
||||||
# Ensure only one NOW task exists. If multiple, demote extras to NEXT.
|
# Ensure only one NOW task exists. If multiple, demote extras to NEXT.
|
||||||
@@ -38,7 +45,7 @@ def promote_tasks(db: Session):
|
|||||||
for task_to_demote in now_tasks[1:]:
|
for task_to_demote in now_tasks[1:]:
|
||||||
task_to_demote.state = TaskState.NEXT
|
task_to_demote.state = TaskState.NEXT
|
||||||
db.add(task_to_demote)
|
db.add(task_to_demote)
|
||||||
db.flush() # Make changes visible
|
db.flush() # Make changes visible
|
||||||
|
|
||||||
# If no NOW task, promote NEXT to NOW
|
# If no NOW task, promote NEXT to NOW
|
||||||
current_now = db.query(Task).filter(Task.state == TaskState.NOW).first()
|
current_now = db.query(Task).filter(Task.state == TaskState.NOW).first()
|
||||||
@@ -47,12 +54,17 @@ def promote_tasks(db: Session):
|
|||||||
if next_task:
|
if next_task:
|
||||||
next_task.state = TaskState.NOW
|
next_task.state = TaskState.NOW
|
||||||
db.add(next_task)
|
db.add(next_task)
|
||||||
db.flush() # Make changes visible
|
db.flush() # Make changes visible
|
||||||
|
|
||||||
# If no NEXT task, promote highest priority LATER to NEXT
|
# If no NEXT task, promote highest priority LATER to NEXT
|
||||||
current_next = db.query(Task).filter(Task.state == TaskState.NEXT).first()
|
current_next = db.query(Task).filter(Task.state == TaskState.NEXT).first()
|
||||||
if not current_next:
|
if not current_next:
|
||||||
later_tasks = db.query(Task).filter(Task.state == TaskState.LATER).order_by(Task.is_mit.desc(), Task.sort_order).all()
|
later_tasks = (
|
||||||
|
db.query(Task)
|
||||||
|
.filter(Task.state == TaskState.LATER)
|
||||||
|
.order_by(Task.is_mit.desc(), Task.sort_order)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
if later_tasks:
|
if later_tasks:
|
||||||
later_tasks[0].state = TaskState.NEXT
|
later_tasks[0].state = TaskState.NEXT
|
||||||
db.add(later_tasks[0])
|
db.add(later_tasks[0])
|
||||||
@@ -60,14 +72,17 @@ def promote_tasks(db: Session):
|
|||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Endpoints
|
# Endpoints
|
||||||
@router.get("/calm", response_class=HTMLResponse)
|
@router.get("/calm", response_class=HTMLResponse)
|
||||||
async def get_calm_view(request: Request, db: Session = Depends(get_db)):
|
async def get_calm_view(request: Request, db: Session = Depends(get_db)):
|
||||||
now_task = get_now_task(db)
|
now_task = get_now_task(db)
|
||||||
next_task = get_next_task(db)
|
next_task = get_next_task(db)
|
||||||
later_tasks_count = len(get_later_tasks(db))
|
later_tasks_count = len(get_later_tasks(db))
|
||||||
return templates.TemplateResponse(request, "calm/calm_view.html", {"now_task": now_task,
|
return templates.TemplateResponse(
|
||||||
|
request,
|
||||||
|
"calm/calm_view.html",
|
||||||
|
{
|
||||||
|
"now_task": now_task,
|
||||||
"next_task": next_task,
|
"next_task": next_task,
|
||||||
"later_tasks_count": later_tasks_count,
|
"later_tasks_count": later_tasks_count,
|
||||||
},
|
},
|
||||||
@@ -101,7 +116,7 @@ async def post_morning_ritual(
|
|||||||
task = Task(
|
task = Task(
|
||||||
title=mit_title,
|
title=mit_title,
|
||||||
is_mit=True,
|
is_mit=True,
|
||||||
state=TaskState.LATER, # Initially LATER, will be promoted
|
state=TaskState.LATER, # Initially LATER, will be promoted
|
||||||
certainty=TaskCertainty.SOFT,
|
certainty=TaskCertainty.SOFT,
|
||||||
)
|
)
|
||||||
db.add(task)
|
db.add(task)
|
||||||
@@ -113,7 +128,7 @@ async def post_morning_ritual(
|
|||||||
db.add(journal_entry)
|
db.add(journal_entry)
|
||||||
|
|
||||||
# Create other tasks
|
# Create other tasks
|
||||||
for task_title in other_tasks.split('\n'):
|
for task_title in other_tasks.split("\n"):
|
||||||
task_title = task_title.strip()
|
task_title = task_title.strip()
|
||||||
if task_title:
|
if task_title:
|
||||||
task = Task(
|
task = Task(
|
||||||
@@ -128,20 +143,29 @@ async def post_morning_ritual(
|
|||||||
# Set initial NOW/NEXT states
|
# Set initial NOW/NEXT states
|
||||||
# Set initial NOW/NEXT states after all tasks are created
|
# Set initial NOW/NEXT states after all tasks are created
|
||||||
if not get_now_task(db) and not get_next_task(db):
|
if not get_now_task(db) and not get_next_task(db):
|
||||||
later_tasks = db.query(Task).filter(Task.state == TaskState.LATER).order_by(Task.is_mit.desc(), Task.sort_order).all()
|
later_tasks = (
|
||||||
|
db.query(Task)
|
||||||
|
.filter(Task.state == TaskState.LATER)
|
||||||
|
.order_by(Task.is_mit.desc(), Task.sort_order)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
if later_tasks:
|
if later_tasks:
|
||||||
# Set the highest priority LATER task to NOW
|
# Set the highest priority LATER task to NOW
|
||||||
later_tasks[0].state = TaskState.NOW
|
later_tasks[0].state = TaskState.NOW
|
||||||
db.add(later_tasks[0])
|
db.add(later_tasks[0])
|
||||||
db.flush() # Flush to make the change visible for the next query
|
db.flush() # Flush to make the change visible for the next query
|
||||||
|
|
||||||
# Set the next highest priority LATER task to NEXT
|
# Set the next highest priority LATER task to NEXT
|
||||||
if len(later_tasks) > 1:
|
if len(later_tasks) > 1:
|
||||||
later_tasks[1].state = TaskState.NEXT
|
later_tasks[1].state = TaskState.NEXT
|
||||||
db.add(later_tasks[1])
|
db.add(later_tasks[1])
|
||||||
db.commit() # Commit changes after initial NOW/NEXT setup
|
db.commit() # Commit changes after initial NOW/NEXT setup
|
||||||
|
|
||||||
return templates.TemplateResponse(request, "calm/calm_view.html", {"now_task": get_now_task(db),
|
return templates.TemplateResponse(
|
||||||
|
request,
|
||||||
|
"calm/calm_view.html",
|
||||||
|
{
|
||||||
|
"now_task": get_now_task(db),
|
||||||
"next_task": get_next_task(db),
|
"next_task": get_next_task(db),
|
||||||
"later_tasks_count": len(get_later_tasks(db)),
|
"later_tasks_count": len(get_later_tasks(db)),
|
||||||
},
|
},
|
||||||
@@ -154,7 +178,8 @@ async def get_evening_ritual_form(request: Request, db: Session = Depends(get_db
|
|||||||
if not journal_entry:
|
if not journal_entry:
|
||||||
raise HTTPException(status_code=404, detail="No journal entry for today")
|
raise HTTPException(status_code=404, detail="No journal entry for today")
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
"calm/evening_ritual_form.html", {"request": request, "journal_entry": journal_entry})
|
"calm/evening_ritual_form.html", {"request": request, "journal_entry": journal_entry}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/calm/ritual/evening", response_class=HTMLResponse)
|
@router.post("/calm/ritual/evening", response_class=HTMLResponse)
|
||||||
@@ -175,9 +200,13 @@ async def post_evening_ritual(
|
|||||||
db.add(journal_entry)
|
db.add(journal_entry)
|
||||||
|
|
||||||
# Archive any remaining active tasks
|
# Archive any remaining active tasks
|
||||||
active_tasks = db.query(Task).filter(Task.state.in_([TaskState.NOW, TaskState.NEXT, TaskState.LATER])).all()
|
active_tasks = (
|
||||||
|
db.query(Task)
|
||||||
|
.filter(Task.state.in_([TaskState.NOW, TaskState.NEXT, TaskState.LATER]))
|
||||||
|
.all()
|
||||||
|
)
|
||||||
for task in active_tasks:
|
for task in active_tasks:
|
||||||
task.state = TaskState.DEFERRED # Or DONE, depending on desired archiving logic
|
task.state = TaskState.DEFERRED # Or DONE, depending on desired archiving logic
|
||||||
task.deferred_at = datetime.utcnow()
|
task.deferred_at = datetime.utcnow()
|
||||||
db.add(task)
|
db.add(task)
|
||||||
|
|
||||||
@@ -221,7 +250,7 @@ async def start_task(
|
|||||||
):
|
):
|
||||||
current_now_task = get_now_task(db)
|
current_now_task = get_now_task(db)
|
||||||
if current_now_task and current_now_task.id != task_id:
|
if current_now_task and current_now_task.id != task_id:
|
||||||
current_now_task.state = TaskState.NEXT # Demote current NOW to NEXT
|
current_now_task.state = TaskState.NEXT # Demote current NOW to NEXT
|
||||||
db.add(current_now_task)
|
db.add(current_now_task)
|
||||||
|
|
||||||
task = db.query(Task).filter(Task.id == task_id).first()
|
task = db.query(Task).filter(Task.id == task_id).first()
|
||||||
@@ -322,7 +351,7 @@ async def reorder_tasks(
|
|||||||
):
|
):
|
||||||
# Reorder LATER tasks
|
# Reorder LATER tasks
|
||||||
if later_task_ids:
|
if later_task_ids:
|
||||||
ids_in_order = [int(x.strip()) for x in later_task_ids.split(',') if x.strip()]
|
ids_in_order = [int(x.strip()) for x in later_task_ids.split(",") if x.strip()]
|
||||||
for index, task_id in enumerate(ids_in_order):
|
for index, task_id in enumerate(ids_in_order):
|
||||||
task = db.query(Task).filter(Task.id == task_id).first()
|
task = db.query(Task).filter(Task.id == task_id).first()
|
||||||
if task and task.state == TaskState.LATER:
|
if task and task.state == TaskState.LATER:
|
||||||
@@ -332,16 +361,18 @@ async def reorder_tasks(
|
|||||||
# Handle NEXT task if it's part of the reorder (e.g., moved from LATER to NEXT explicitly)
|
# Handle NEXT task if it's part of the reorder (e.g., moved from LATER to NEXT explicitly)
|
||||||
if next_task_id:
|
if next_task_id:
|
||||||
task = db.query(Task).filter(Task.id == next_task_id).first()
|
task = db.query(Task).filter(Task.id == next_task_id).first()
|
||||||
if task and task.state == TaskState.LATER: # Only if it was a LATER task being promoted manually
|
if (
|
||||||
|
task and task.state == TaskState.LATER
|
||||||
|
): # Only if it was a LATER task being promoted manually
|
||||||
# Demote current NEXT to LATER
|
# Demote current NEXT to LATER
|
||||||
current_next = get_next_task(db)
|
current_next = get_next_task(db)
|
||||||
if current_next:
|
if current_next:
|
||||||
current_next.state = TaskState.LATER
|
current_next.state = TaskState.LATER
|
||||||
current_next.sort_order = len(get_later_tasks(db)) # Add to end of later
|
current_next.sort_order = len(get_later_tasks(db)) # Add to end of later
|
||||||
db.add(current_next)
|
db.add(current_next)
|
||||||
|
|
||||||
task.state = TaskState.NEXT
|
task.state = TaskState.NEXT
|
||||||
task.sort_order = 0 # NEXT tasks don't really need sort_order, but for consistency
|
task.sort_order = 0 # NEXT tasks don't really need sort_order, but for consistency
|
||||||
db.add(task)
|
db.add(task)
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|||||||
@@ -27,12 +27,13 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
router = APIRouter(prefix="/api", tags=["chat-api"])
|
router = APIRouter(prefix="/api", tags=["chat-api"])
|
||||||
|
|
||||||
_UPLOAD_DIR = os.path.join("data", "chat-uploads")
|
_UPLOAD_DIR = str(Path(settings.repo_root) / "data" / "chat-uploads")
|
||||||
_MAX_UPLOAD_SIZE = 50 * 1024 * 1024 # 50 MB
|
_MAX_UPLOAD_SIZE = 50 * 1024 * 1024 # 50 MB
|
||||||
|
|
||||||
|
|
||||||
# ── POST /api/chat ────────────────────────────────────────────────────────────
|
# ── POST /api/chat ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@router.post("/chat")
|
@router.post("/chat")
|
||||||
async def api_chat(request: Request):
|
async def api_chat(request: Request):
|
||||||
"""Accept a JSON chat payload and return the agent's reply.
|
"""Accept a JSON chat payload and return the agent's reply.
|
||||||
@@ -65,7 +66,8 @@ async def api_chat(request: Request):
|
|||||||
# Handle multimodal content arrays — extract text parts
|
# Handle multimodal content arrays — extract text parts
|
||||||
if isinstance(content, list):
|
if isinstance(content, list):
|
||||||
text_parts = [
|
text_parts = [
|
||||||
p.get("text", "") for p in content
|
p.get("text", "")
|
||||||
|
for p in content
|
||||||
if isinstance(p, dict) and p.get("type") == "text"
|
if isinstance(p, dict) and p.get("type") == "text"
|
||||||
]
|
]
|
||||||
last_user_msg = " ".join(text_parts).strip()
|
last_user_msg = " ".join(text_parts).strip()
|
||||||
@@ -109,6 +111,7 @@ async def api_chat(request: Request):
|
|||||||
|
|
||||||
# ── POST /api/upload ──────────────────────────────────────────────────────────
|
# ── POST /api/upload ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@router.post("/upload")
|
@router.post("/upload")
|
||||||
async def api_upload(file: UploadFile = File(...)):
|
async def api_upload(file: UploadFile = File(...)):
|
||||||
"""Accept a file upload and return its URL.
|
"""Accept a file upload and return its URL.
|
||||||
@@ -147,6 +150,7 @@ async def api_upload(file: UploadFile = File(...)):
|
|||||||
|
|
||||||
# ── GET /api/chat/history ────────────────────────────────────────────────────
|
# ── GET /api/chat/history ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@router.get("/chat/history")
|
@router.get("/chat/history")
|
||||||
async def api_chat_history():
|
async def api_chat_history():
|
||||||
"""Return the in-memory chat history as JSON."""
|
"""Return the in-memory chat history as JSON."""
|
||||||
@@ -165,6 +169,7 @@ async def api_chat_history():
|
|||||||
|
|
||||||
# ── DELETE /api/chat/history ──────────────────────────────────────────────────
|
# ── DELETE /api/chat/history ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/chat/history")
|
@router.delete("/chat/history")
|
||||||
async def api_clear_history():
|
async def api_clear_history():
|
||||||
"""Clear the in-memory chat history."""
|
"""Clear the in-memory chat history."""
|
||||||
|
|||||||
@@ -7,9 +7,10 @@ Endpoints:
|
|||||||
GET /discord/oauth-url — get the bot's OAuth2 authorization URL
|
GET /discord/oauth-url — get the bot's OAuth2 authorization URL
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, File, Form, UploadFile
|
from fastapi import APIRouter, File, Form, UploadFile
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/discord", tags=["discord"])
|
router = APIRouter(prefix="/discord", tags=["discord"])
|
||||||
|
|
||||||
|
|||||||
77
src/dashboard/routes/experiments.py
Normal file
77
src/dashboard/routes/experiments.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
"""Experiment dashboard routes — autoresearch experiment monitoring.
|
||||||
|
|
||||||
|
Provides endpoints for viewing, starting, and monitoring autonomous
|
||||||
|
ML experiment loops powered by Karpathy's autoresearch pattern.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from fastapi.responses import HTMLResponse, JSONResponse
|
||||||
|
|
||||||
|
from config import settings
|
||||||
|
from dashboard.templating import templates
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/experiments", tags=["experiments"])
|
||||||
|
|
||||||
|
|
||||||
|
def _workspace() -> Path:
    """Return the autoresearch workspace path.

    The path is built by joining ``settings.autoresearch_workspace`` onto
    ``settings.repo_root``.
    """
    repo_root = Path(settings.repo_root)
    return repo_root / settings.autoresearch_workspace
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_class=HTMLResponse)
async def experiments_page(request: Request):
    """Render the autoresearch experiments dashboard page.

    Loads the experiment history for the configured workspace and renders
    ``experiments.html`` with at most the first 50 history entries plus the
    autoresearch settings (enabled flag, metric name, time budget and
    maximum iterations).

    Loading the history is best-effort: any exception raised while doing so
    is logged at debug level and the page is rendered with an empty history.
    """
    from timmy.autoresearch import get_experiment_history

    try:
        runs = get_experiment_history(_workspace())
    except Exception:
        logger.debug("Failed to load experiment history", exc_info=True)
        runs = []

    context = {
        "page_title": "Experiments — Autoresearch",
        "enabled": settings.autoresearch_enabled,
        "history": runs[:50],  # cap the number of entries handed to the template
        "metric_name": settings.autoresearch_metric,
        "time_budget": settings.autoresearch_time_budget,
        "max_iterations": settings.autoresearch_max_iterations,
    }
    return templates.TemplateResponse(request, "experiments.html", context)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/start", response_class=JSONResponse)
async def start_experiment(request: Request):
    """Prepare the autoresearch workspace and report the outcome.

    Args:
        request: The incoming request (currently unused by the handler).

    Returns:
        A dict with ``"status"`` (always ``"started"``), ``"workspace"`` (the
        workspace path as a string) and ``"prepare"`` (whatever
        ``prepare_experiment`` returned).

    Raises:
        HTTPException: 403 when ``settings.autoresearch_enabled`` is falsy.
    """
    import asyncio

    if not settings.autoresearch_enabled:
        raise HTTPException(
            status_code=403,
            detail="Autoresearch is disabled. Set AUTORESEARCH_ENABLED=true.",
        )

    from timmy.autoresearch import prepare_experiment

    workspace = _workspace()
    # prepare_experiment is a synchronous call; run it in a worker thread so
    # this async handler does not stall the event loop while it works.
    status = await asyncio.to_thread(prepare_experiment, workspace)

    return {"status": "started", "workspace": str(workspace), "prepare": status}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{run_id}", response_class=JSONResponse)
async def experiment_detail(run_id: str):
    """Return the history entry for a single experiment run.

    Looks through the experiment history of the configured workspace and
    returns the first entry whose ``"run_id"`` equals *run_id*.

    Raises:
        HTTPException: 404 when no history entry has a matching ``run_id``.
    """
    from timmy.autoresearch import get_experiment_history

    runs = get_experiment_history(_workspace())
    found = next((run for run in runs if run.get("run_id") == run_id), None)
    if found is None:
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
    return found
|
||||||
@@ -43,6 +43,7 @@ async def grok_status(request: Request):
|
|||||||
stats = None
|
stats = None
|
||||||
try:
|
try:
|
||||||
from timmy.backends import get_grok_backend
|
from timmy.backends import get_grok_backend
|
||||||
|
|
||||||
backend = get_grok_backend()
|
backend = get_grok_backend()
|
||||||
stats = {
|
stats = {
|
||||||
"total_requests": backend.stats.total_requests,
|
"total_requests": backend.stats.total_requests,
|
||||||
@@ -52,12 +53,16 @@ async def grok_status(request: Request):
|
|||||||
"errors": backend.stats.errors,
|
"errors": backend.stats.errors,
|
||||||
}
|
}
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
logger.debug("Failed to load Grok stats", exc_info=True)
|
||||||
|
|
||||||
return templates.TemplateResponse(request, "grok_status.html", {
|
return templates.TemplateResponse(
|
||||||
"status": status,
|
request,
|
||||||
"stats": stats,
|
"grok_status.html",
|
||||||
})
|
{
|
||||||
|
"status": status,
|
||||||
|
"stats": stats,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/toggle")
|
@router.post("/toggle")
|
||||||
@@ -90,7 +95,7 @@ async def toggle_grok_mode(request: Request):
|
|||||||
success=True,
|
success=True,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
logger.debug("Failed to log Grok toggle to Spark", exc_info=True)
|
||||||
|
|
||||||
return HTMLResponse(
|
return HTMLResponse(
|
||||||
_render_toggle_card(_grok_mode_active),
|
_render_toggle_card(_grok_mode_active),
|
||||||
@@ -104,10 +109,13 @@ def _run_grok_query(message: str) -> dict:
|
|||||||
Returns:
|
Returns:
|
||||||
{"response": str | None, "error": str | None}
|
{"response": str | None, "error": str | None}
|
||||||
"""
|
"""
|
||||||
from timmy.backends import grok_available, get_grok_backend
|
from timmy.backends import get_grok_backend, grok_available
|
||||||
|
|
||||||
if not grok_available():
|
if not grok_available():
|
||||||
return {"response": None, "error": "Grok is not available. Set GROK_ENABLED=true and XAI_API_KEY."}
|
return {
|
||||||
|
"response": None,
|
||||||
|
"error": "Grok is not available. Set GROK_ENABLED=true and XAI_API_KEY.",
|
||||||
|
}
|
||||||
|
|
||||||
backend = get_grok_backend()
|
backend = get_grok_backend()
|
||||||
|
|
||||||
@@ -115,12 +123,13 @@ def _run_grok_query(message: str) -> dict:
|
|||||||
if not settings.grok_free:
|
if not settings.grok_free:
|
||||||
try:
|
try:
|
||||||
from lightning.factory import get_backend as get_ln_backend
|
from lightning.factory import get_backend as get_ln_backend
|
||||||
|
|
||||||
ln = get_ln_backend()
|
ln = get_ln_backend()
|
||||||
sats = min(settings.grok_max_sats_per_query, 100)
|
sats = min(settings.grok_max_sats_per_query, 100)
|
||||||
ln.create_invoice(sats, f"Grok: {message[:50]}")
|
ln.create_invoice(sats, f"Grok: {message[:50]}")
|
||||||
invoice_note = f" | {sats} sats"
|
invoice_note = f" | {sats} sats"
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
logger.debug("Lightning invoice creation failed", exc_info=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = backend.run(message)
|
result = backend.run(message)
|
||||||
@@ -132,9 +141,10 @@ def _run_grok_query(message: str) -> dict:
|
|||||||
@router.post("/chat", response_class=HTMLResponse)
|
@router.post("/chat", response_class=HTMLResponse)
|
||||||
async def grok_chat(request: Request, message: str = Form(...)):
|
async def grok_chat(request: Request, message: str = Form(...)):
|
||||||
"""Send a message directly to Grok and return HTMX chat partial."""
|
"""Send a message directly to Grok and return HTMX chat partial."""
|
||||||
from dashboard.store import message_log
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
from dashboard.store import message_log
|
||||||
|
|
||||||
timestamp = datetime.now().strftime("%H:%M:%S")
|
timestamp = datetime.now().strftime("%H:%M:%S")
|
||||||
result = _run_grok_query(message)
|
result = _run_grok_query(message)
|
||||||
|
|
||||||
@@ -142,9 +152,13 @@ async def grok_chat(request: Request, message: str = Form(...)):
|
|||||||
message_log.append(role="user", content=user_msg, timestamp=timestamp, source="browser")
|
message_log.append(role="user", content=user_msg, timestamp=timestamp, source="browser")
|
||||||
|
|
||||||
if result["response"]:
|
if result["response"]:
|
||||||
message_log.append(role="agent", content=result["response"], timestamp=timestamp, source="browser")
|
message_log.append(
|
||||||
|
role="agent", content=result["response"], timestamp=timestamp, source="browser"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
message_log.append(role="error", content=result["error"], timestamp=timestamp, source="browser")
|
message_log.append(
|
||||||
|
role="error", content=result["error"], timestamp=timestamp, source="browser"
|
||||||
|
)
|
||||||
|
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
request,
|
request,
|
||||||
@@ -185,6 +199,7 @@ async def grok_stats():
|
|||||||
def _render_toggle_card(active: bool) -> str:
|
def _render_toggle_card(active: bool) -> str:
|
||||||
"""Render the Grok Mode toggle card HTML."""
|
"""Render the Grok Mode toggle card HTML."""
|
||||||
import html
|
import html
|
||||||
|
|
||||||
color = "#00ff88" if active else "#666"
|
color = "#00ff88" if active else "#666"
|
||||||
state = "ACTIVE" if active else "STANDBY"
|
state = "ACTIVE" if active else "STANDBY"
|
||||||
glow = "0 0 20px rgba(0, 255, 136, 0.4)" if active else "none"
|
glow = "0 0 20px rgba(0, 255, 136, 0.4)" if active else "none"
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ router = APIRouter(tags=["health"])
|
|||||||
|
|
||||||
class DependencyStatus(BaseModel):
|
class DependencyStatus(BaseModel):
|
||||||
"""Status of a single dependency."""
|
"""Status of a single dependency."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
status: str # "healthy", "degraded", "unavailable"
|
status: str # "healthy", "degraded", "unavailable"
|
||||||
sovereignty_score: int # 0-10
|
sovereignty_score: int # 0-10
|
||||||
@@ -30,6 +31,7 @@ class DependencyStatus(BaseModel):
|
|||||||
|
|
||||||
class SovereigntyReport(BaseModel):
|
class SovereigntyReport(BaseModel):
|
||||||
"""Full sovereignty audit report."""
|
"""Full sovereignty audit report."""
|
||||||
|
|
||||||
overall_score: float
|
overall_score: float
|
||||||
dependencies: list[DependencyStatus]
|
dependencies: list[DependencyStatus]
|
||||||
timestamp: str
|
timestamp: str
|
||||||
@@ -38,6 +40,7 @@ class SovereigntyReport(BaseModel):
|
|||||||
|
|
||||||
class HealthStatus(BaseModel):
|
class HealthStatus(BaseModel):
|
||||||
"""System health status."""
|
"""System health status."""
|
||||||
|
|
||||||
status: str
|
status: str
|
||||||
timestamp: str
|
timestamp: str
|
||||||
version: str
|
version: str
|
||||||
@@ -52,6 +55,7 @@ def _check_ollama_sync() -> DependencyStatus:
|
|||||||
"""Synchronous Ollama check — run via asyncio.to_thread()."""
|
"""Synchronous Ollama check — run via asyncio.to_thread()."""
|
||||||
try:
|
try:
|
||||||
import urllib.request
|
import urllib.request
|
||||||
|
|
||||||
url = settings.ollama_url.replace("localhost", "127.0.0.1")
|
url = settings.ollama_url.replace("localhost", "127.0.0.1")
|
||||||
req = urllib.request.Request(
|
req = urllib.request.Request(
|
||||||
f"{url}/api/tags",
|
f"{url}/api/tags",
|
||||||
@@ -67,7 +71,7 @@ def _check_ollama_sync() -> DependencyStatus:
|
|||||||
details={"url": settings.ollama_url, "model": settings.ollama_model},
|
details={"url": settings.ollama_url, "model": settings.ollama_model},
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
logger.debug("Ollama health check failed", exc_info=True)
|
||||||
|
|
||||||
return DependencyStatus(
|
return DependencyStatus(
|
||||||
name="Ollama AI",
|
name="Ollama AI",
|
||||||
@@ -142,7 +146,7 @@ def _calculate_overall_score(deps: list[DependencyStatus]) -> float:
|
|||||||
def _generate_recommendations(deps: list[DependencyStatus]) -> list[str]:
|
def _generate_recommendations(deps: list[DependencyStatus]) -> list[str]:
|
||||||
"""Generate recommendations based on dependency status."""
|
"""Generate recommendations based on dependency status."""
|
||||||
recommendations = []
|
recommendations = []
|
||||||
|
|
||||||
for dep in deps:
|
for dep in deps:
|
||||||
if dep.status == "unavailable":
|
if dep.status == "unavailable":
|
||||||
recommendations.append(f"{dep.name} is unavailable - check configuration")
|
recommendations.append(f"{dep.name} is unavailable - check configuration")
|
||||||
@@ -151,25 +155,25 @@ def _generate_recommendations(deps: list[DependencyStatus]) -> list[str]:
|
|||||||
recommendations.append(
|
recommendations.append(
|
||||||
"Switch to real Lightning: set LIGHTNING_BACKEND=lnd and configure LND"
|
"Switch to real Lightning: set LIGHTNING_BACKEND=lnd and configure LND"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not recommendations:
|
if not recommendations:
|
||||||
recommendations.append("System operating optimally - all dependencies healthy")
|
recommendations.append("System operating optimally - all dependencies healthy")
|
||||||
|
|
||||||
return recommendations
|
return recommendations
|
||||||
|
|
||||||
|
|
||||||
@router.get("/health")
|
@router.get("/health")
|
||||||
async def health_check():
|
async def health_check():
|
||||||
"""Basic health check endpoint.
|
"""Basic health check endpoint.
|
||||||
|
|
||||||
Returns legacy format for backward compatibility with existing tests,
|
Returns legacy format for backward compatibility with existing tests,
|
||||||
plus extended information for the Mission Control dashboard.
|
plus extended information for the Mission Control dashboard.
|
||||||
"""
|
"""
|
||||||
uptime = (datetime.now(timezone.utc) - _START_TIME).total_seconds()
|
uptime = (datetime.now(timezone.utc) - _START_TIME).total_seconds()
|
||||||
|
|
||||||
# Legacy format for test compatibility
|
# Legacy format for test compatibility
|
||||||
ollama_ok = await check_ollama()
|
ollama_ok = await check_ollama()
|
||||||
|
|
||||||
agent_status = "idle" if ollama_ok else "offline"
|
agent_status = "idle" if ollama_ok else "offline"
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -193,12 +197,13 @@ async def health_check():
|
|||||||
async def health_status_panel(request: Request):
|
async def health_status_panel(request: Request):
|
||||||
"""Simple HTML health status panel."""
|
"""Simple HTML health status panel."""
|
||||||
ollama_ok = await check_ollama()
|
ollama_ok = await check_ollama()
|
||||||
|
|
||||||
status_text = "UP" if ollama_ok else "DOWN"
|
status_text = "UP" if ollama_ok else "DOWN"
|
||||||
status_color = "#10b981" if ollama_ok else "#ef4444"
|
status_color = "#10b981" if ollama_ok else "#ef4444"
|
||||||
import html
|
import html
|
||||||
|
|
||||||
model = html.escape(settings.ollama_model) # Include model for test compatibility
|
model = html.escape(settings.ollama_model) # Include model for test compatibility
|
||||||
|
|
||||||
html_content = f"""
|
html_content = f"""
|
||||||
<!DOCTYPE html>
|
<!DOCTYPE html>
|
||||||
<html>
|
<html>
|
||||||
@@ -217,7 +222,7 @@ async def health_status_panel(request: Request):
|
|||||||
@router.get("/health/sovereignty", response_model=SovereigntyReport)
|
@router.get("/health/sovereignty", response_model=SovereigntyReport)
|
||||||
async def sovereignty_check():
|
async def sovereignty_check():
|
||||||
"""Comprehensive sovereignty audit report.
|
"""Comprehensive sovereignty audit report.
|
||||||
|
|
||||||
Returns the status of all external dependencies with sovereignty scores.
|
Returns the status of all external dependencies with sovereignty scores.
|
||||||
Use this to verify the system is operating in a sovereign manner.
|
Use this to verify the system is operating in a sovereign manner.
|
||||||
"""
|
"""
|
||||||
@@ -226,10 +231,10 @@ async def sovereignty_check():
|
|||||||
_check_lightning(),
|
_check_lightning(),
|
||||||
_check_sqlite(),
|
_check_sqlite(),
|
||||||
]
|
]
|
||||||
|
|
||||||
overall = _calculate_overall_score(dependencies)
|
overall = _calculate_overall_score(dependencies)
|
||||||
recommendations = _generate_recommendations(dependencies)
|
recommendations = _generate_recommendations(dependencies)
|
||||||
|
|
||||||
return SovereigntyReport(
|
return SovereigntyReport(
|
||||||
overall_score=overall,
|
overall_score=overall,
|
||||||
dependencies=dependencies,
|
dependencies=dependencies,
|
||||||
|
|||||||
@@ -19,8 +19,7 @@ AGENT_CATALOG = [
|
|||||||
"name": "Orchestrator",
|
"name": "Orchestrator",
|
||||||
"role": "Local AI",
|
"role": "Local AI",
|
||||||
"description": (
|
"description": (
|
||||||
"Primary AI agent. Coordinates tasks, manages memory. "
|
"Primary AI agent. Coordinates tasks, manages memory. " "Uses distributed brain."
|
||||||
"Uses distributed brain."
|
|
||||||
),
|
),
|
||||||
"capabilities": "chat,reasoning,coordination,memory",
|
"capabilities": "chat,reasoning,coordination,memory",
|
||||||
"rate_sats": 0,
|
"rate_sats": 0,
|
||||||
@@ -37,11 +36,11 @@ async def api_list_agents():
|
|||||||
pending_tasks = len(await brain.get_pending_tasks(limit=1000))
|
pending_tasks = len(await brain.get_pending_tasks(limit=1000))
|
||||||
except Exception:
|
except Exception:
|
||||||
pending_tasks = 0
|
pending_tasks = 0
|
||||||
|
|
||||||
catalog = [dict(AGENT_CATALOG[0])]
|
catalog = [dict(AGENT_CATALOG[0])]
|
||||||
catalog[0]["pending_tasks"] = pending_tasks
|
catalog[0]["pending_tasks"] = pending_tasks
|
||||||
catalog[0]["status"] = "active"
|
catalog[0]["status"] = "active"
|
||||||
|
|
||||||
# Include 'total' for backward compatibility with tests
|
# Include 'total' for backward compatibility with tests
|
||||||
return {"agents": catalog, "total": len(catalog)}
|
return {"agents": catalog, "total": len(catalog)}
|
||||||
|
|
||||||
@@ -82,7 +81,7 @@ async def marketplace_ui(request: Request):
|
|||||||
"page_title": "Agent Marketplace",
|
"page_title": "Agent Marketplace",
|
||||||
"active_count": active,
|
"active_count": active,
|
||||||
"planned_count": 0,
|
"planned_count": 0,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,17 +5,17 @@ from typing import Optional
|
|||||||
from fastapi import APIRouter, Form, HTTPException, Request
|
from fastapi import APIRouter, Form, HTTPException, Request
|
||||||
from fastapi.responses import HTMLResponse, JSONResponse
|
from fastapi.responses import HTMLResponse, JSONResponse
|
||||||
|
|
||||||
|
from dashboard.templating import templates
|
||||||
from timmy.memory.vector_store import (
|
from timmy.memory.vector_store import (
|
||||||
store_memory,
|
delete_memory,
|
||||||
search_memories,
|
|
||||||
get_memory_stats,
|
get_memory_stats,
|
||||||
recall_personal_facts,
|
recall_personal_facts,
|
||||||
recall_personal_facts_with_ids,
|
recall_personal_facts_with_ids,
|
||||||
|
search_memories,
|
||||||
|
store_memory,
|
||||||
store_personal_fact,
|
store_personal_fact,
|
||||||
update_personal_fact,
|
update_personal_fact,
|
||||||
delete_memory,
|
|
||||||
)
|
)
|
||||||
from dashboard.templating import templates
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/memory", tags=["memory"])
|
router = APIRouter(prefix="/memory", tags=["memory"])
|
||||||
|
|
||||||
@@ -36,10 +36,10 @@ async def memory_page(
|
|||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
limit=20,
|
limit=20,
|
||||||
)
|
)
|
||||||
|
|
||||||
stats = get_memory_stats()
|
stats = get_memory_stats()
|
||||||
facts = recall_personal_facts_with_ids()[:10]
|
facts = recall_personal_facts_with_ids()[:10]
|
||||||
|
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
request,
|
request,
|
||||||
"memory.html",
|
"memory.html",
|
||||||
@@ -67,7 +67,7 @@ async def memory_search(
|
|||||||
context_type=context_type,
|
context_type=context_type,
|
||||||
limit=20,
|
limit=20,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return partial for HTMX
|
# Return partial for HTMX
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
request,
|
request,
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from fastapi.responses import HTMLResponse
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from config import settings
|
from config import settings
|
||||||
|
from dashboard.templating import templates
|
||||||
from infrastructure.models.registry import (
|
from infrastructure.models.registry import (
|
||||||
CustomModel,
|
CustomModel,
|
||||||
ModelFormat,
|
ModelFormat,
|
||||||
@@ -20,7 +21,6 @@ from infrastructure.models.registry import (
|
|||||||
ModelRole,
|
ModelRole,
|
||||||
model_registry,
|
model_registry,
|
||||||
)
|
)
|
||||||
from dashboard.templating import templates
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -33,6 +33,7 @@ api_router = APIRouter(prefix="/api/v1/models", tags=["models-api"])
|
|||||||
|
|
||||||
class RegisterModelRequest(BaseModel):
|
class RegisterModelRequest(BaseModel):
|
||||||
"""Request body for model registration."""
|
"""Request body for model registration."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
format: str # gguf, safetensors, hf, ollama
|
format: str # gguf, safetensors, hf, ollama
|
||||||
path: str
|
path: str
|
||||||
@@ -45,12 +46,14 @@ class RegisterModelRequest(BaseModel):
|
|||||||
|
|
||||||
class AssignModelRequest(BaseModel):
|
class AssignModelRequest(BaseModel):
|
||||||
"""Request body for assigning a model to an agent."""
|
"""Request body for assigning a model to an agent."""
|
||||||
|
|
||||||
agent_id: str
|
agent_id: str
|
||||||
model_name: str
|
model_name: str
|
||||||
|
|
||||||
|
|
||||||
class SetActiveRequest(BaseModel):
|
class SetActiveRequest(BaseModel):
|
||||||
"""Request body for enabling/disabling a model."""
|
"""Request body for enabling/disabling a model."""
|
||||||
|
|
||||||
active: bool
|
active: bool
|
||||||
|
|
||||||
|
|
||||||
@@ -92,15 +95,14 @@ async def register_model(request: RegisterModelRequest) -> dict[str, Any]:
|
|||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Invalid format: {request.format}. "
|
detail=f"Invalid format: {request.format}. "
|
||||||
f"Choose from: {[f.value for f in ModelFormat]}",
|
f"Choose from: {[f.value for f in ModelFormat]}",
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
role = ModelRole(request.role)
|
role = ModelRole(request.role)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"Invalid role: {request.role}. "
|
detail=f"Invalid role: {request.role}. " f"Choose from: {[r.value for r in ModelRole]}",
|
||||||
f"Choose from: {[r.value for r in ModelRole]}",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate path exists for non-Ollama formats
|
# Validate path exists for non-Ollama formats
|
||||||
@@ -163,9 +165,7 @@ async def unregister_model(model_name: str) -> dict[str, str]:
|
|||||||
|
|
||||||
|
|
||||||
@api_router.patch("/{model_name}/active")
|
@api_router.patch("/{model_name}/active")
|
||||||
async def set_model_active(
|
async def set_model_active(model_name: str, request: SetActiveRequest) -> dict[str, str]:
|
||||||
model_name: str, request: SetActiveRequest
|
|
||||||
) -> dict[str, str]:
|
|
||||||
"""Enable or disable a model."""
|
"""Enable or disable a model."""
|
||||||
if not model_registry.set_active(model_name, request.active):
|
if not model_registry.set_active(model_name, request.active):
|
||||||
raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
|
raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
|
||||||
@@ -182,8 +182,7 @@ async def list_assignments() -> dict[str, Any]:
|
|||||||
assignments = model_registry.get_agent_assignments()
|
assignments = model_registry.get_agent_assignments()
|
||||||
return {
|
return {
|
||||||
"assignments": [
|
"assignments": [
|
||||||
{"agent_id": aid, "model_name": mname}
|
{"agent_id": aid, "model_name": mname} for aid, mname in assignments.items()
|
||||||
for aid, mname in assignments.items()
|
|
||||||
],
|
],
|
||||||
"total": len(assignments),
|
"total": len(assignments),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,8 +3,8 @@
|
|||||||
from fastapi import APIRouter, Request
|
from fastapi import APIRouter, Request
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
|
|
||||||
from timmy.cascade_adapter import get_cascade_adapter
|
|
||||||
from dashboard.templating import templates
|
from dashboard.templating import templates
|
||||||
|
from timmy.cascade_adapter import get_cascade_adapter
|
||||||
|
|
||||||
router = APIRouter(prefix="/router", tags=["router"])
|
router = APIRouter(prefix="/router", tags=["router"])
|
||||||
|
|
||||||
@@ -13,19 +13,19 @@ router = APIRouter(prefix="/router", tags=["router"])
|
|||||||
async def router_status_page(request: Request):
|
async def router_status_page(request: Request):
|
||||||
"""Cascade Router status dashboard."""
|
"""Cascade Router status dashboard."""
|
||||||
adapter = get_cascade_adapter()
|
adapter = get_cascade_adapter()
|
||||||
|
|
||||||
providers = adapter.get_provider_status()
|
providers = adapter.get_provider_status()
|
||||||
preferred = adapter.get_preferred_provider()
|
preferred = adapter.get_preferred_provider()
|
||||||
|
|
||||||
# Calculate overall stats
|
# Calculate overall stats
|
||||||
total_requests = sum(p["metrics"]["total"] for p in providers)
|
total_requests = sum(p["metrics"]["total"] for p in providers)
|
||||||
total_success = sum(p["metrics"]["success"] for p in providers)
|
total_success = sum(p["metrics"]["success"] for p in providers)
|
||||||
total_failed = sum(p["metrics"]["failed"] for p in providers)
|
total_failed = sum(p["metrics"]["failed"] for p in providers)
|
||||||
|
|
||||||
avg_latency = 0.0
|
avg_latency = 0.0
|
||||||
if providers:
|
if providers:
|
||||||
avg_latency = sum(p["metrics"]["avg_latency_ms"] for p in providers) / len(providers)
|
avg_latency = sum(p["metrics"]["avg_latency_ms"] for p in providers) / len(providers)
|
||||||
|
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
request,
|
request,
|
||||||
"router_status.html",
|
"router_status.html",
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ import logging
|
|||||||
from fastapi import APIRouter, Request
|
from fastapi import APIRouter, Request
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
|
|
||||||
from spark.engine import spark_engine
|
|
||||||
from dashboard.templating import templates
|
from dashboard.templating import templates
|
||||||
|
from spark.engine import spark_engine
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -86,23 +86,26 @@ async def spark_ui(request: Request):
|
|||||||
async def spark_status_json():
|
async def spark_status_json():
|
||||||
"""Return Spark Intelligence status as JSON."""
|
"""Return Spark Intelligence status as JSON."""
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
status = spark_engine.status()
|
status = spark_engine.status()
|
||||||
advisories = spark_engine.get_advisories()
|
advisories = spark_engine.get_advisories()
|
||||||
return JSONResponse({
|
return JSONResponse(
|
||||||
"status": status,
|
{
|
||||||
"advisories": [
|
"status": status,
|
||||||
{
|
"advisories": [
|
||||||
"category": a.category,
|
{
|
||||||
"priority": a.priority,
|
"category": a.category,
|
||||||
"title": a.title,
|
"priority": a.priority,
|
||||||
"detail": a.detail,
|
"title": a.title,
|
||||||
"suggested_action": a.suggested_action,
|
"detail": a.detail,
|
||||||
"subject": a.subject,
|
"suggested_action": a.suggested_action,
|
||||||
"evidence_count": a.evidence_count,
|
"subject": a.subject,
|
||||||
}
|
"evidence_count": a.evidence_count,
|
||||||
for a in advisories
|
}
|
||||||
],
|
for a in advisories
|
||||||
})
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/timeline", response_class=HTMLResponse)
|
@router.get("/timeline", response_class=HTMLResponse)
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ from typing import Optional
|
|||||||
from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect
|
from fastapi import APIRouter, Request, WebSocket, WebSocketDisconnect
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
|
|
||||||
from spark.engine import spark_engine
|
|
||||||
from dashboard.templating import templates
|
from dashboard.templating import templates
|
||||||
from infrastructure.ws_manager.handler import ws_manager
|
from infrastructure.ws_manager.handler import ws_manager
|
||||||
|
from spark.engine import spark_engine
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -25,7 +25,7 @@ async def swarm_events(
|
|||||||
):
|
):
|
||||||
"""Event log page."""
|
"""Event log page."""
|
||||||
events = spark_engine.get_timeline(limit=100)
|
events = spark_engine.get_timeline(limit=100)
|
||||||
|
|
||||||
# Filter if requested
|
# Filter if requested
|
||||||
if task_id:
|
if task_id:
|
||||||
events = [e for e in events if e.task_id == task_id]
|
events = [e for e in events if e.task_id == task_id]
|
||||||
@@ -33,7 +33,7 @@ async def swarm_events(
|
|||||||
events = [e for e in events if e.agent_id == agent_id]
|
events = [e for e in events if e.agent_id == agent_id]
|
||||||
if event_type:
|
if event_type:
|
||||||
events = [e for e in events if e.event_type == event_type]
|
events = [e for e in events if e.event_type == event_type]
|
||||||
|
|
||||||
# Prepare summary and event types for template
|
# Prepare summary and event types for template
|
||||||
summary = {}
|
summary = {}
|
||||||
event_types = set()
|
event_types = set()
|
||||||
@@ -41,7 +41,7 @@ async def swarm_events(
|
|||||||
etype = e.event_type
|
etype = e.event_type
|
||||||
event_types.add(etype)
|
event_types.add(etype)
|
||||||
summary[etype] = summary.get(etype, 0) + 1
|
summary[etype] = summary.get(etype, 0) + 1
|
||||||
|
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
request,
|
request,
|
||||||
"events.html",
|
"events.html",
|
||||||
@@ -78,14 +78,16 @@ async def swarm_ws(websocket: WebSocket):
|
|||||||
await ws_manager.connect(websocket)
|
await ws_manager.connect(websocket)
|
||||||
try:
|
try:
|
||||||
# Send initial state so frontend can clear loading placeholders
|
# Send initial state so frontend can clear loading placeholders
|
||||||
await websocket.send_json({
|
await websocket.send_json(
|
||||||
"type": "initial_state",
|
{
|
||||||
"data": {
|
"type": "initial_state",
|
||||||
"agents": {"total": 0, "active": 0, "list": []},
|
"data": {
|
||||||
"tasks": {"active": 0},
|
"agents": {"total": 0, "active": 0, "list": []},
|
||||||
"auctions": {"list": []},
|
"tasks": {"active": 0},
|
||||||
},
|
"auctions": {"list": []},
|
||||||
})
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
while True:
|
while True:
|
||||||
await websocket.receive_text()
|
await websocket.receive_text()
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
|
|||||||
@@ -25,26 +25,42 @@ async def lightning_ledger(request: Request):
|
|||||||
"pending_incoming_sats": 0,
|
"pending_incoming_sats": 0,
|
||||||
"pending_outgoing_sats": 0,
|
"pending_outgoing_sats": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Mock transactions
|
# Mock transactions
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
class TxType(Enum):
|
class TxType(Enum):
|
||||||
incoming = "incoming"
|
incoming = "incoming"
|
||||||
outgoing = "outgoing"
|
outgoing = "outgoing"
|
||||||
|
|
||||||
class TxStatus(Enum):
|
class TxStatus(Enum):
|
||||||
completed = "completed"
|
completed = "completed"
|
||||||
pending = "pending"
|
pending = "pending"
|
||||||
|
|
||||||
Tx = namedtuple("Tx", ["tx_type", "status", "amount_sats", "payment_hash", "memo", "created_at"])
|
Tx = namedtuple(
|
||||||
|
"Tx", ["tx_type", "status", "amount_sats", "payment_hash", "memo", "created_at"]
|
||||||
|
)
|
||||||
|
|
||||||
transactions = [
|
transactions = [
|
||||||
Tx(TxType.outgoing, TxStatus.completed, 50, "hash1", "Model inference", "2026-03-04 10:00:00"),
|
Tx(
|
||||||
Tx(TxType.incoming, TxStatus.completed, 1000, "hash2", "Manual deposit", "2026-03-03 15:00:00"),
|
TxType.outgoing,
|
||||||
|
TxStatus.completed,
|
||||||
|
50,
|
||||||
|
"hash1",
|
||||||
|
"Model inference",
|
||||||
|
"2026-03-04 10:00:00",
|
||||||
|
),
|
||||||
|
Tx(
|
||||||
|
TxType.incoming,
|
||||||
|
TxStatus.completed,
|
||||||
|
1000,
|
||||||
|
"hash2",
|
||||||
|
"Manual deposit",
|
||||||
|
"2026-03-03 15:00:00",
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
request,
|
request,
|
||||||
"ledger.html",
|
"ledger.html",
|
||||||
@@ -84,9 +100,16 @@ async def mission_control(request: Request):
|
|||||||
|
|
||||||
@router.get("/bugs", response_class=HTMLResponse)
|
@router.get("/bugs", response_class=HTMLResponse)
|
||||||
async def bugs_page(request: Request):
|
async def bugs_page(request: Request):
|
||||||
return templates.TemplateResponse(request, "bugs.html", {
|
return templates.TemplateResponse(
|
||||||
"bugs": [], "total": 0, "stats": {}, "filter_status": None,
|
request,
|
||||||
})
|
"bugs.html",
|
||||||
|
{
|
||||||
|
"bugs": [],
|
||||||
|
"total": 0,
|
||||||
|
"stats": {},
|
||||||
|
"filter_status": None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/self-coding", response_class=HTMLResponse)
|
@router.get("/self-coding", response_class=HTMLResponse)
|
||||||
@@ -109,14 +132,17 @@ async def api_notifications():
|
|||||||
"""Return recent system events for the notification dropdown."""
|
"""Return recent system events for the notification dropdown."""
|
||||||
try:
|
try:
|
||||||
from spark.engine import spark_engine
|
from spark.engine import spark_engine
|
||||||
|
|
||||||
events = spark_engine.get_timeline(limit=20)
|
events = spark_engine.get_timeline(limit=20)
|
||||||
return JSONResponse([
|
return JSONResponse(
|
||||||
{
|
[
|
||||||
"event_type": e.event_type,
|
{
|
||||||
"title": getattr(e, "description", e.event_type),
|
"event_type": e.event_type,
|
||||||
"timestamp": str(getattr(e, "timestamp", "")),
|
"title": getattr(e, "description", e.event_type),
|
||||||
}
|
"timestamp": str(getattr(e, "timestamp", "")),
|
||||||
for e in events
|
}
|
||||||
])
|
for e in events
|
||||||
|
]
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
return JSONResponse([])
|
return JSONResponse([])
|
||||||
|
|||||||
@@ -7,9 +7,10 @@ from datetime import datetime
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Request, Form
|
from fastapi import APIRouter, Form, HTTPException, Request
|
||||||
from fastapi.responses import HTMLResponse, JSONResponse
|
from fastapi.responses import HTMLResponse, JSONResponse
|
||||||
|
|
||||||
|
from config import settings
|
||||||
from dashboard.templating import templates
|
from dashboard.templating import templates
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -20,11 +21,17 @@ router = APIRouter(tags=["tasks"])
|
|||||||
# Database helpers
|
# Database helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
DB_PATH = Path("data/tasks.db")
|
DB_PATH = Path(settings.repo_root) / "data" / "tasks.db"
|
||||||
|
|
||||||
VALID_STATUSES = {
|
VALID_STATUSES = {
|
||||||
"pending_approval", "approved", "running", "paused",
|
"pending_approval",
|
||||||
"completed", "vetoed", "failed", "backlogged",
|
"approved",
|
||||||
|
"running",
|
||||||
|
"paused",
|
||||||
|
"completed",
|
||||||
|
"vetoed",
|
||||||
|
"failed",
|
||||||
|
"backlogged",
|
||||||
}
|
}
|
||||||
VALID_PRIORITIES = {"low", "normal", "high", "urgent"}
|
VALID_PRIORITIES = {"low", "normal", "high", "urgent"}
|
||||||
|
|
||||||
@@ -33,7 +40,8 @@ def _get_db() -> sqlite3.Connection:
|
|||||||
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||||
conn = sqlite3.connect(str(DB_PATH))
|
conn = sqlite3.connect(str(DB_PATH))
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
conn.execute("""
|
conn.execute(
|
||||||
|
"""
|
||||||
CREATE TABLE IF NOT EXISTS tasks (
|
CREATE TABLE IF NOT EXISTS tasks (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
title TEXT NOT NULL,
|
title TEXT NOT NULL,
|
||||||
@@ -46,7 +54,8 @@ def _get_db() -> sqlite3.Connection:
|
|||||||
created_at TEXT DEFAULT (datetime('now')),
|
created_at TEXT DEFAULT (datetime('now')),
|
||||||
completed_at TEXT
|
completed_at TEXT
|
||||||
)
|
)
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
@@ -91,37 +100,52 @@ class _TaskView:
|
|||||||
# Page routes
|
# Page routes
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@router.get("/tasks", response_class=HTMLResponse)
|
@router.get("/tasks", response_class=HTMLResponse)
|
||||||
async def tasks_page(request: Request):
|
async def tasks_page(request: Request):
|
||||||
"""Render the main task queue page with 3-column layout."""
|
"""Render the main task queue page with 3-column layout."""
|
||||||
db = _get_db()
|
db = _get_db()
|
||||||
try:
|
try:
|
||||||
pending = [_TaskView(_row_to_dict(r)) for r in db.execute(
|
pending = [
|
||||||
"SELECT * FROM tasks WHERE status IN ('pending_approval') ORDER BY created_at DESC"
|
_TaskView(_row_to_dict(r))
|
||||||
).fetchall()]
|
for r in db.execute(
|
||||||
active = [_TaskView(_row_to_dict(r)) for r in db.execute(
|
"SELECT * FROM tasks WHERE status IN ('pending_approval') ORDER BY created_at DESC"
|
||||||
"SELECT * FROM tasks WHERE status IN ('approved','running','paused') ORDER BY created_at DESC"
|
).fetchall()
|
||||||
).fetchall()]
|
]
|
||||||
completed = [_TaskView(_row_to_dict(r)) for r in db.execute(
|
active = [
|
||||||
"SELECT * FROM tasks WHERE status IN ('completed','vetoed','failed') ORDER BY completed_at DESC LIMIT 50"
|
_TaskView(_row_to_dict(r))
|
||||||
).fetchall()]
|
for r in db.execute(
|
||||||
|
"SELECT * FROM tasks WHERE status IN ('approved','running','paused') ORDER BY created_at DESC"
|
||||||
|
).fetchall()
|
||||||
|
]
|
||||||
|
completed = [
|
||||||
|
_TaskView(_row_to_dict(r))
|
||||||
|
for r in db.execute(
|
||||||
|
"SELECT * FROM tasks WHERE status IN ('completed','vetoed','failed') ORDER BY completed_at DESC LIMIT 50"
|
||||||
|
).fetchall()
|
||||||
|
]
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
return templates.TemplateResponse(request, "tasks.html", {
|
return templates.TemplateResponse(
|
||||||
"pending_count": len(pending),
|
request,
|
||||||
"pending": pending,
|
"tasks.html",
|
||||||
"active": active,
|
{
|
||||||
"completed": completed,
|
"pending_count": len(pending),
|
||||||
"agents": [], # no agent roster wired yet
|
"pending": pending,
|
||||||
"pre_assign": "",
|
"active": active,
|
||||||
})
|
"completed": completed,
|
||||||
|
"agents": [], # no agent roster wired yet
|
||||||
|
"pre_assign": "",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# HTMX partials (polled by the template)
|
# HTMX partials (polled by the template)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@router.get("/tasks/pending", response_class=HTMLResponse)
|
@router.get("/tasks/pending", response_class=HTMLResponse)
|
||||||
async def tasks_pending(request: Request):
|
async def tasks_pending(request: Request):
|
||||||
db = _get_db()
|
db = _get_db()
|
||||||
@@ -134,9 +158,11 @@ async def tasks_pending(request: Request):
|
|||||||
tasks = [_TaskView(_row_to_dict(r)) for r in rows]
|
tasks = [_TaskView(_row_to_dict(r)) for r in rows]
|
||||||
parts = []
|
parts = []
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
parts.append(templates.TemplateResponse(
|
parts.append(
|
||||||
request, "partials/task_card.html", {"task": task}
|
templates.TemplateResponse(
|
||||||
).body.decode())
|
request, "partials/task_card.html", {"task": task}
|
||||||
|
).body.decode()
|
||||||
|
)
|
||||||
if not parts:
|
if not parts:
|
||||||
return HTMLResponse('<div class="empty-column">No pending tasks</div>')
|
return HTMLResponse('<div class="empty-column">No pending tasks</div>')
|
||||||
return HTMLResponse("".join(parts))
|
return HTMLResponse("".join(parts))
|
||||||
@@ -154,9 +180,11 @@ async def tasks_active(request: Request):
|
|||||||
tasks = [_TaskView(_row_to_dict(r)) for r in rows]
|
tasks = [_TaskView(_row_to_dict(r)) for r in rows]
|
||||||
parts = []
|
parts = []
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
parts.append(templates.TemplateResponse(
|
parts.append(
|
||||||
request, "partials/task_card.html", {"task": task}
|
templates.TemplateResponse(
|
||||||
).body.decode())
|
request, "partials/task_card.html", {"task": task}
|
||||||
|
).body.decode()
|
||||||
|
)
|
||||||
if not parts:
|
if not parts:
|
||||||
return HTMLResponse('<div class="empty-column">No active tasks</div>')
|
return HTMLResponse('<div class="empty-column">No active tasks</div>')
|
||||||
return HTMLResponse("".join(parts))
|
return HTMLResponse("".join(parts))
|
||||||
@@ -174,9 +202,11 @@ async def tasks_completed(request: Request):
|
|||||||
tasks = [_TaskView(_row_to_dict(r)) for r in rows]
|
tasks = [_TaskView(_row_to_dict(r)) for r in rows]
|
||||||
parts = []
|
parts = []
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
parts.append(templates.TemplateResponse(
|
parts.append(
|
||||||
request, "partials/task_card.html", {"task": task}
|
templates.TemplateResponse(
|
||||||
).body.decode())
|
request, "partials/task_card.html", {"task": task}
|
||||||
|
).body.decode()
|
||||||
|
)
|
||||||
if not parts:
|
if not parts:
|
||||||
return HTMLResponse('<div class="empty-column">No completed tasks yet</div>')
|
return HTMLResponse('<div class="empty-column">No completed tasks yet</div>')
|
||||||
return HTMLResponse("".join(parts))
|
return HTMLResponse("".join(parts))
|
||||||
@@ -186,6 +216,7 @@ async def tasks_completed(request: Request):
|
|||||||
# Form-based create (used by the modal in tasks.html)
|
# Form-based create (used by the modal in tasks.html)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@router.post("/tasks/create", response_class=HTMLResponse)
|
@router.post("/tasks/create", response_class=HTMLResponse)
|
||||||
async def create_task_form(
|
async def create_task_form(
|
||||||
request: Request,
|
request: Request,
|
||||||
@@ -218,6 +249,7 @@ async def create_task_form(
|
|||||||
# Task action endpoints (approve, veto, modify, pause, cancel, retry)
|
# Task action endpoints (approve, veto, modify, pause, cancel, retry)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@router.post("/tasks/{task_id}/approve", response_class=HTMLResponse)
|
@router.post("/tasks/{task_id}/approve", response_class=HTMLResponse)
|
||||||
async def approve_task(request: Request, task_id: str):
|
async def approve_task(request: Request, task_id: str):
|
||||||
return await _set_status(request, task_id, "approved")
|
return await _set_status(request, task_id, "approved")
|
||||||
@@ -268,7 +300,9 @@ async def modify_task(
|
|||||||
|
|
||||||
async def _set_status(request: Request, task_id: str, new_status: str):
|
async def _set_status(request: Request, task_id: str, new_status: str):
|
||||||
"""Helper to update status and return refreshed task card."""
|
"""Helper to update status and return refreshed task card."""
|
||||||
completed_at = datetime.utcnow().isoformat() if new_status in ("completed", "vetoed", "failed") else None
|
completed_at = (
|
||||||
|
datetime.utcnow().isoformat() if new_status in ("completed", "vetoed", "failed") else None
|
||||||
|
)
|
||||||
db = _get_db()
|
db = _get_db()
|
||||||
try:
|
try:
|
||||||
db.execute(
|
db.execute(
|
||||||
@@ -289,6 +323,7 @@ async def _set_status(request: Request, task_id: str, new_status: str):
|
|||||||
# JSON API (for programmatic access / Timmy's tool calls)
|
# JSON API (for programmatic access / Timmy's tool calls)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@router.post("/api/tasks", response_class=JSONResponse, status_code=201)
|
@router.post("/api/tasks", response_class=JSONResponse, status_code=201)
|
||||||
async def api_create_task(request: Request):
|
async def api_create_task(request: Request):
|
||||||
"""Create a task via JSON API."""
|
"""Create a task via JSON API."""
|
||||||
@@ -345,7 +380,9 @@ async def api_update_status(task_id: str, request: Request):
|
|||||||
if not new_status or new_status not in VALID_STATUSES:
|
if not new_status or new_status not in VALID_STATUSES:
|
||||||
raise HTTPException(422, f"Invalid status. Must be one of: {VALID_STATUSES}")
|
raise HTTPException(422, f"Invalid status. Must be one of: {VALID_STATUSES}")
|
||||||
|
|
||||||
completed_at = datetime.utcnow().isoformat() if new_status in ("completed", "vetoed", "failed") else None
|
completed_at = (
|
||||||
|
datetime.utcnow().isoformat() if new_status in ("completed", "vetoed", "failed") else None
|
||||||
|
)
|
||||||
db = _get_db()
|
db = _get_db()
|
||||||
try:
|
try:
|
||||||
db.execute(
|
db.execute(
|
||||||
@@ -379,6 +416,7 @@ async def api_delete_task(task_id: str):
|
|||||||
# Queue status (polled by the chat panel every 10 seconds)
|
# Queue status (polled by the chat panel every 10 seconds)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@router.get("/api/queue/status", response_class=JSONResponse)
|
@router.get("/api/queue/status", response_class=JSONResponse)
|
||||||
async def queue_status(assigned_to: str = "default"):
|
async def queue_status(assigned_to: str = "default"):
|
||||||
"""Return queue status for the chat panel's agent status indicator."""
|
"""Return queue status for the chat panel's agent status indicator."""
|
||||||
@@ -396,14 +434,18 @@ async def queue_status(assigned_to: str = "default"):
|
|||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
if running:
|
if running:
|
||||||
return JSONResponse({
|
return JSONResponse(
|
||||||
"is_working": True,
|
{
|
||||||
"current_task": {"id": running["id"], "title": running["title"]},
|
"is_working": True,
|
||||||
"tasks_ahead": 0,
|
"current_task": {"id": running["id"], "title": running["title"]},
|
||||||
})
|
"tasks_ahead": 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return JSONResponse({
|
return JSONResponse(
|
||||||
"is_working": False,
|
{
|
||||||
"current_task": None,
|
"is_working": False,
|
||||||
"tasks_ahead": ahead["cnt"] if ahead else 0,
|
"current_task": None,
|
||||||
})
|
"tasks_ahead": ahead["cnt"] if ahead else 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ import logging
|
|||||||
from fastapi import APIRouter, Request
|
from fastapi import APIRouter, Request
|
||||||
from fastapi.responses import HTMLResponse, JSONResponse
|
from fastapi.responses import HTMLResponse, JSONResponse
|
||||||
|
|
||||||
from timmy.thinking import thinking_engine
|
|
||||||
from dashboard.templating import templates
|
from dashboard.templating import templates
|
||||||
|
from timmy.thinking import thinking_engine
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ from collections import namedtuple
|
|||||||
from fastapi import APIRouter, Request
|
from fastapi import APIRouter, Request
|
||||||
from fastapi.responses import HTMLResponse, JSONResponse
|
from fastapi.responses import HTMLResponse, JSONResponse
|
||||||
|
|
||||||
from timmy.tools import get_all_available_tools
|
|
||||||
from dashboard.templating import templates
|
from dashboard.templating import templates
|
||||||
|
from timmy.tools import get_all_available_tools
|
||||||
|
|
||||||
router = APIRouter(tags=["tools"])
|
router = APIRouter(tags=["tools"])
|
||||||
|
|
||||||
@@ -29,9 +29,7 @@ def _build_agent_tools():
|
|||||||
for name, fn in available.items()
|
for name, fn in available.items()
|
||||||
]
|
]
|
||||||
|
|
||||||
return [
|
return [_AgentView(name="Timmy", status="idle", tools=tool_views, stats=_Stats(total_calls=0))]
|
||||||
_AgentView(name="Timmy", status="idle", tools=tool_views, stats=_Stats(total_calls=0))
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/tools", response_class=HTMLResponse)
|
@router.get("/tools", response_class=HTMLResponse)
|
||||||
|
|||||||
@@ -10,9 +10,9 @@ import logging
|
|||||||
from fastapi import APIRouter, Form, Request
|
from fastapi import APIRouter, Form, Request
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse
|
||||||
|
|
||||||
|
from dashboard.templating import templates
|
||||||
from integrations.voice.nlu import detect_intent, extract_command
|
from integrations.voice.nlu import detect_intent, extract_command
|
||||||
from timmy.agent import create_timmy
|
from timmy.agent import create_timmy
|
||||||
from dashboard.templating import templates
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -38,6 +38,7 @@ async def tts_status():
|
|||||||
"""Check TTS engine availability."""
|
"""Check TTS engine availability."""
|
||||||
try:
|
try:
|
||||||
from timmy_serve.voice_tts import voice_tts
|
from timmy_serve.voice_tts import voice_tts
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"available": voice_tts.available,
|
"available": voice_tts.available,
|
||||||
"voices": voice_tts.get_voices() if voice_tts.available else [],
|
"voices": voice_tts.get_voices() if voice_tts.available else [],
|
||||||
@@ -51,6 +52,7 @@ async def tts_speak(text: str = Form(...)):
|
|||||||
"""Speak text aloud via TTS."""
|
"""Speak text aloud via TTS."""
|
||||||
try:
|
try:
|
||||||
from timmy_serve.voice_tts import voice_tts
|
from timmy_serve.voice_tts import voice_tts
|
||||||
|
|
||||||
if not voice_tts.available:
|
if not voice_tts.available:
|
||||||
return {"spoken": False, "reason": "TTS engine not available"}
|
return {"spoken": False, "reason": "TTS engine not available"}
|
||||||
voice_tts.speak(text)
|
voice_tts.speak(text)
|
||||||
@@ -86,6 +88,7 @@ async def voice_command(text: str = Form(...)):
|
|||||||
|
|
||||||
# ── Enhanced voice pipeline ──────────────────────────────────────────────
|
# ── Enhanced voice pipeline ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@router.post("/enhanced/process")
|
@router.post("/enhanced/process")
|
||||||
async def process_voice_input(
|
async def process_voice_input(
|
||||||
text: str = Form(...),
|
text: str = Form(...),
|
||||||
@@ -133,6 +136,7 @@ async def process_voice_input(
|
|||||||
if speak_response and response_text:
|
if speak_response and response_text:
|
||||||
try:
|
try:
|
||||||
from timmy_serve.voice_tts import voice_tts
|
from timmy_serve.voice_tts import voice_tts
|
||||||
|
|
||||||
if voice_tts.available:
|
if voice_tts.available:
|
||||||
voice_tts.speak(response_text)
|
voice_tts.speak(response_text)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import uuid
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Request, Form
|
from fastapi import APIRouter, Form, HTTPException, Request
|
||||||
from fastapi.responses import HTMLResponse, JSONResponse
|
from fastapi.responses import HTMLResponse, JSONResponse
|
||||||
|
|
||||||
from dashboard.templating import templates
|
from dashboard.templating import templates
|
||||||
@@ -26,7 +26,8 @@ def _get_db() -> sqlite3.Connection:
|
|||||||
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||||
conn = sqlite3.connect(str(DB_PATH))
|
conn = sqlite3.connect(str(DB_PATH))
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
conn.execute("""
|
conn.execute(
|
||||||
|
"""
|
||||||
CREATE TABLE IF NOT EXISTS work_orders (
|
CREATE TABLE IF NOT EXISTS work_orders (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
title TEXT NOT NULL,
|
title TEXT NOT NULL,
|
||||||
@@ -41,7 +42,8 @@ def _get_db() -> sqlite3.Connection:
|
|||||||
created_at TEXT DEFAULT (datetime('now')),
|
created_at TEXT DEFAULT (datetime('now')),
|
||||||
completed_at TEXT
|
completed_at TEXT
|
||||||
)
|
)
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
@@ -71,7 +73,9 @@ class _WOView:
|
|||||||
self.submitter = row.get("submitter", "dashboard")
|
self.submitter = row.get("submitter", "dashboard")
|
||||||
self.status = _EnumLike(row.get("status", "submitted"))
|
self.status = _EnumLike(row.get("status", "submitted"))
|
||||||
raw_files = row.get("related_files", "")
|
raw_files = row.get("related_files", "")
|
||||||
self.related_files = [f.strip() for f in raw_files.split(",") if f.strip()] if raw_files else []
|
self.related_files = (
|
||||||
|
[f.strip() for f in raw_files.split(",") if f.strip()] if raw_files else []
|
||||||
|
)
|
||||||
self.result = row.get("result", "")
|
self.result = row.get("result", "")
|
||||||
self.rejection_reason = row.get("rejection_reason", "")
|
self.rejection_reason = row.get("rejection_reason", "")
|
||||||
self.created_at = row.get("created_at", "")
|
self.created_at = row.get("created_at", "")
|
||||||
@@ -98,6 +102,7 @@ def _query_wos(db, statuses):
|
|||||||
# Page route
|
# Page route
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@router.get("/work-orders/queue", response_class=HTMLResponse)
|
@router.get("/work-orders/queue", response_class=HTMLResponse)
|
||||||
async def work_orders_page(request: Request):
|
async def work_orders_page(request: Request):
|
||||||
db = _get_db()
|
db = _get_db()
|
||||||
@@ -109,21 +114,26 @@ async def work_orders_page(request: Request):
|
|||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
return templates.TemplateResponse(request, "work_orders.html", {
|
return templates.TemplateResponse(
|
||||||
"pending_count": len(pending),
|
request,
|
||||||
"pending": pending,
|
"work_orders.html",
|
||||||
"active": active,
|
{
|
||||||
"completed": completed,
|
"pending_count": len(pending),
|
||||||
"rejected": rejected,
|
"pending": pending,
|
||||||
"priorities": PRIORITIES,
|
"active": active,
|
||||||
"categories": CATEGORIES,
|
"completed": completed,
|
||||||
})
|
"rejected": rejected,
|
||||||
|
"priorities": PRIORITIES,
|
||||||
|
"categories": CATEGORIES,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Form submit
|
# Form submit
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@router.post("/work-orders/submit", response_class=HTMLResponse)
|
@router.post("/work-orders/submit", response_class=HTMLResponse)
|
||||||
async def submit_work_order(
|
async def submit_work_order(
|
||||||
request: Request,
|
request: Request,
|
||||||
@@ -159,6 +169,7 @@ async def submit_work_order(
|
|||||||
# HTMX partials
|
# HTMX partials
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@router.get("/work-orders/queue/pending", response_class=HTMLResponse)
|
@router.get("/work-orders/queue/pending", response_class=HTMLResponse)
|
||||||
async def pending_partial(request: Request):
|
async def pending_partial(request: Request):
|
||||||
db = _get_db()
|
db = _get_db()
|
||||||
@@ -174,7 +185,9 @@ async def pending_partial(request: Request):
|
|||||||
parts = []
|
parts = []
|
||||||
for wo in wos:
|
for wo in wos:
|
||||||
parts.append(
|
parts.append(
|
||||||
templates.TemplateResponse(request, "partials/work_order_card.html", {"wo": wo}).body.decode()
|
templates.TemplateResponse(
|
||||||
|
request, "partials/work_order_card.html", {"wo": wo}
|
||||||
|
).body.decode()
|
||||||
)
|
)
|
||||||
return HTMLResponse("".join(parts))
|
return HTMLResponse("".join(parts))
|
||||||
|
|
||||||
@@ -194,7 +207,9 @@ async def active_partial(request: Request):
|
|||||||
parts = []
|
parts = []
|
||||||
for wo in wos:
|
for wo in wos:
|
||||||
parts.append(
|
parts.append(
|
||||||
templates.TemplateResponse(request, "partials/work_order_card.html", {"wo": wo}).body.decode()
|
templates.TemplateResponse(
|
||||||
|
request, "partials/work_order_card.html", {"wo": wo}
|
||||||
|
).body.decode()
|
||||||
)
|
)
|
||||||
return HTMLResponse("".join(parts))
|
return HTMLResponse("".join(parts))
|
||||||
|
|
||||||
@@ -203,8 +218,11 @@ async def active_partial(request: Request):
|
|||||||
# Action endpoints
|
# Action endpoints
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
async def _update_status(request: Request, wo_id: str, new_status: str, **extra):
|
async def _update_status(request: Request, wo_id: str, new_status: str, **extra):
|
||||||
completed_at = datetime.utcnow().isoformat() if new_status in ("completed", "rejected") else None
|
completed_at = (
|
||||||
|
datetime.utcnow().isoformat() if new_status in ("completed", "rejected") else None
|
||||||
|
)
|
||||||
db = _get_db()
|
db = _get_db()
|
||||||
try:
|
try:
|
||||||
sets = ["status=?", "completed_at=COALESCE(?, completed_at)"]
|
sets = ["status=?", "completed_at=COALESCE(?, completed_at)"]
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from dataclasses import dataclass, field
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Message:
|
class Message:
|
||||||
role: str # "user" | "agent" | "error"
|
role: str # "user" | "agent" | "error"
|
||||||
content: str
|
content: str
|
||||||
timestamp: str
|
timestamp: str
|
||||||
source: str = "browser" # "browser" | "api" | "telegram" | "discord" | "system"
|
source: str = "browser" # "browser" | "api" | "telegram" | "discord" | "system"
|
||||||
@@ -16,7 +16,9 @@ class MessageLog:
|
|||||||
self._entries: list[Message] = []
|
self._entries: list[Message] = []
|
||||||
|
|
||||||
def append(self, role: str, content: str, timestamp: str, source: str = "browser") -> None:
|
def append(self, role: str, content: str, timestamp: str, source: str = "browser") -> None:
|
||||||
self._entries.append(Message(role=role, content=content, timestamp=timestamp, source=source))
|
self._entries.append(
|
||||||
|
Message(role=role, content=content, timestamp=timestamp, source=source)
|
||||||
|
)
|
||||||
|
|
||||||
def all(self) -> list[Message]:
|
def all(self) -> list[Message]:
|
||||||
return list(self._entries)
|
return list(self._entries)
|
||||||
|
|||||||
90
src/dashboard/templates/experiments.html
Normal file
90
src/dashboard/templates/experiments.html
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
{% extends "base.html" %}
|
||||||
|
|
||||||
|
{% block title %}{{ page_title }}{% endblock %}
|
||||||
|
|
||||||
|
{% block extra_styles %}
|
||||||
|
<style>
|
||||||
|
.experiments-container { max-width: 1000px; margin: 0 auto; }
|
||||||
|
.exp-header { display: flex; justify-content: space-between; align-items: center; margin-bottom: 20px; }
|
||||||
|
.exp-title { font-size: 1.3rem; font-weight: 700; color: var(--text-bright); }
|
||||||
|
.exp-subtitle { font-size: 0.8rem; color: var(--text-dim); margin-top: 2px; }
|
||||||
|
.exp-config { display: flex; gap: 16px; font-size: 0.8rem; color: var(--text-dim); }
|
||||||
|
.exp-config span { background: var(--glass-bg); border: 1px solid var(--border); padding: 4px 10px; border-radius: 6px; }
|
||||||
|
.exp-table { width: 100%; border-collapse: collapse; font-size: 0.85rem; }
|
||||||
|
.exp-table th { text-align: left; padding: 8px 12px; color: var(--text-dim); border-bottom: 1px solid var(--border); font-weight: 600; }
|
||||||
|
.exp-table td { padding: 8px 12px; border-bottom: 1px solid var(--border); color: var(--text); }
|
||||||
|
.exp-table tr:hover { background: var(--glass-bg); }
|
||||||
|
.metric-good { color: var(--success); }
|
||||||
|
.metric-bad { color: var(--danger); }
|
||||||
|
.btn-start { background: var(--accent); color: #fff; border: none; padding: 8px 18px; border-radius: 6px; cursor: pointer; font-size: 0.85rem; }
|
||||||
|
.btn-start:hover { opacity: 0.9; }
|
||||||
|
.btn-start:disabled { opacity: 0.4; cursor: not-allowed; }
|
||||||
|
.disabled-note { font-size: 0.8rem; color: var(--text-dim); margin-top: 8px; }
|
||||||
|
.empty-state { text-align: center; padding: 40px; color: var(--text-dim); }
|
||||||
|
</style>
|
||||||
|
{% endblock %}
|
||||||
|
|
||||||
|
{% block content %}
|
||||||
|
<div class="experiments-container">
|
||||||
|
<div class="exp-header">
|
||||||
|
<div>
|
||||||
|
<div class="exp-title">Autoresearch Experiments</div>
|
||||||
|
<div class="exp-subtitle">Autonomous ML experiment loops — modify code, train, evaluate, iterate</div>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
{% if enabled %}
|
||||||
|
<button class="btn-start"
|
||||||
|
hx-post="/experiments/start"
|
||||||
|
hx-target="#experiment-status"
|
||||||
|
hx-swap="innerHTML">
|
||||||
|
Start Experiment
|
||||||
|
</button>
|
||||||
|
{% else %}
|
||||||
|
<button class="btn-start" disabled>Disabled</button>
|
||||||
|
<div class="disabled-note">Set AUTORESEARCH_ENABLED=true to enable</div>
|
||||||
|
{% endif %}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="exp-config">
|
||||||
|
<span>Metric: {{ metric_name }}</span>
|
||||||
|
<span>Budget: {{ time_budget }}s</span>
|
||||||
|
<span>Max iters: {{ max_iterations }}</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div id="experiment-status" style="margin: 12px 0;"></div>
|
||||||
|
|
||||||
|
{% if history %}
|
||||||
|
<table class="exp-table">
|
||||||
|
<thead>
|
||||||
|
<tr>
|
||||||
|
<th>#</th>
|
||||||
|
<th>{{ metric_name }}</th>
|
||||||
|
<th>Duration</th>
|
||||||
|
<th>Status</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{% for run in history %}
|
||||||
|
<tr>
|
||||||
|
<td>{{ loop.index }}</td>
|
||||||
|
<td>
|
||||||
|
{% if run.metric is not none %}
|
||||||
|
{{ "%.4f"|format(run.metric) }}
|
||||||
|
{% else %}
|
||||||
|
—
|
||||||
|
{% endif %}
|
||||||
|
</td>
|
||||||
|
<td>{{ run.get("duration_s", "—") }}s</td>
|
||||||
|
<td>{% if run.get("success") %}OK{% else %}{{ run.get("error", "failed") }}{% endif %}</td>
|
||||||
|
</tr>
|
||||||
|
{% endfor %}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
{% else %}
|
||||||
|
<div class="empty-state">
|
||||||
|
No experiments yet. Start one to begin autonomous training.
|
||||||
|
</div>
|
||||||
|
{% endif %}
|
||||||
|
</div>
|
||||||
|
{% endblock %}
|
||||||
@@ -119,9 +119,7 @@ def capture_error(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# Format the stack trace
|
# Format the stack trace
|
||||||
tb_str = "".join(
|
tb_str = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
|
||||||
traceback.format_exception(type(exc), exc, exc.__traceback__)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract file/line from traceback
|
# Extract file/line from traceback
|
||||||
tb_obj = exc.__traceback__
|
tb_obj = exc.__traceback__
|
||||||
|
|||||||
@@ -19,38 +19,39 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class EventBroadcaster:
|
class EventBroadcaster:
|
||||||
"""Broadcasts events to WebSocket clients.
|
"""Broadcasts events to WebSocket clients.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
from infrastructure.events.broadcaster import event_broadcaster
|
from infrastructure.events.broadcaster import event_broadcaster
|
||||||
event_broadcaster.broadcast(event)
|
event_broadcaster.broadcast(event)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._ws_manager: Optional = None
|
self._ws_manager: Optional = None
|
||||||
|
|
||||||
def _get_ws_manager(self):
|
def _get_ws_manager(self):
|
||||||
"""Lazy import to avoid circular deps."""
|
"""Lazy import to avoid circular deps."""
|
||||||
if self._ws_manager is None:
|
if self._ws_manager is None:
|
||||||
try:
|
try:
|
||||||
from infrastructure.ws_manager.handler import ws_manager
|
from infrastructure.ws_manager.handler import ws_manager
|
||||||
|
|
||||||
self._ws_manager = ws_manager
|
self._ws_manager = ws_manager
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.debug("WebSocket manager not available: %s", exc)
|
logger.debug("WebSocket manager not available: %s", exc)
|
||||||
return self._ws_manager
|
return self._ws_manager
|
||||||
|
|
||||||
async def broadcast(self, event: EventLogEntry) -> int:
|
async def broadcast(self, event: EventLogEntry) -> int:
|
||||||
"""Broadcast an event to all connected WebSocket clients.
|
"""Broadcast an event to all connected WebSocket clients.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event: The event to broadcast
|
event: The event to broadcast
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Number of clients notified
|
Number of clients notified
|
||||||
"""
|
"""
|
||||||
ws_manager = self._get_ws_manager()
|
ws_manager = self._get_ws_manager()
|
||||||
if not ws_manager:
|
if not ws_manager:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# Build message payload
|
# Build message payload
|
||||||
payload = {
|
payload = {
|
||||||
"type": "event",
|
"type": "event",
|
||||||
@@ -62,9 +63,9 @@ class EventBroadcaster:
|
|||||||
"agent_id": event.agent_id,
|
"agent_id": event.agent_id,
|
||||||
"timestamp": event.timestamp,
|
"timestamp": event.timestamp,
|
||||||
"data": event.data,
|
"data": event.data,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Broadcast to all connected clients
|
# Broadcast to all connected clients
|
||||||
count = await ws_manager.broadcast_json(payload)
|
count = await ws_manager.broadcast_json(payload)
|
||||||
@@ -73,10 +74,10 @@ class EventBroadcaster:
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("Failed to broadcast event: %s", exc)
|
logger.error("Failed to broadcast event: %s", exc)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def broadcast_sync(self, event: EventLogEntry) -> None:
|
def broadcast_sync(self, event: EventLogEntry) -> None:
|
||||||
"""Synchronous wrapper for broadcast.
|
"""Synchronous wrapper for broadcast.
|
||||||
|
|
||||||
Use this from synchronous code - it schedules the async broadcast
|
Use this from synchronous code - it schedules the async broadcast
|
||||||
in the event loop if one is running.
|
in the event loop if one is running.
|
||||||
"""
|
"""
|
||||||
@@ -151,11 +152,11 @@ def get_event_label(event_type: str) -> str:
|
|||||||
|
|
||||||
def format_event_for_display(event: EventLogEntry) -> dict:
|
def format_event_for_display(event: EventLogEntry) -> dict:
|
||||||
"""Format event for display in activity feed.
|
"""Format event for display in activity feed.
|
||||||
|
|
||||||
Returns dict with display-friendly fields.
|
Returns dict with display-friendly fields.
|
||||||
"""
|
"""
|
||||||
data = event.data or {}
|
data = event.data or {}
|
||||||
|
|
||||||
# Build description based on event type
|
# Build description based on event type
|
||||||
description = ""
|
description = ""
|
||||||
if event.event_type.value == "task.created":
|
if event.event_type.value == "task.created":
|
||||||
@@ -178,7 +179,7 @@ def format_event_for_display(event: EventLogEntry) -> dict:
|
|||||||
val = str(data[key])
|
val = str(data[key])
|
||||||
description = val[:60] + "..." if len(val) > 60 else val
|
description = val[:60] + "..." if len(val) > 60 else val
|
||||||
break
|
break
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"id": event.id,
|
"id": event.id,
|
||||||
"icon": get_event_icon(event.event_type.value),
|
"icon": get_event_icon(event.event_type.value),
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ logger = logging.getLogger(__name__)
|
|||||||
@dataclass
|
@dataclass
|
||||||
class Event:
|
class Event:
|
||||||
"""A typed event in the system."""
|
"""A typed event in the system."""
|
||||||
|
|
||||||
type: str # e.g., "agent.task.assigned", "tool.execution.completed"
|
type: str # e.g., "agent.task.assigned", "tool.execution.completed"
|
||||||
source: str # Agent or component that emitted the event
|
source: str # Agent or component that emitted the event
|
||||||
data: dict = field(default_factory=dict)
|
data: dict = field(default_factory=dict)
|
||||||
@@ -29,15 +30,15 @@ EventHandler = Callable[[Event], Coroutine[Any, Any, None]]
|
|||||||
|
|
||||||
class EventBus:
|
class EventBus:
|
||||||
"""Async event bus for publish/subscribe pattern.
|
"""Async event bus for publish/subscribe pattern.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
bus = EventBus()
|
bus = EventBus()
|
||||||
|
|
||||||
# Subscribe to events
|
# Subscribe to events
|
||||||
@bus.subscribe("agent.task.*")
|
@bus.subscribe("agent.task.*")
|
||||||
async def handle_task(event: Event):
|
async def handle_task(event: Event):
|
||||||
print(f"Task event: {event.data}")
|
print(f"Task event: {event.data}")
|
||||||
|
|
||||||
# Publish events
|
# Publish events
|
||||||
await bus.publish(Event(
|
await bus.publish(Event(
|
||||||
type="agent.task.assigned",
|
type="agent.task.assigned",
|
||||||
@@ -45,88 +46,89 @@ class EventBus:
|
|||||||
data={"task_id": "123", "agent": "forge"}
|
data={"task_id": "123", "agent": "forge"}
|
||||||
))
|
))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._subscribers: dict[str, list[EventHandler]] = {}
|
self._subscribers: dict[str, list[EventHandler]] = {}
|
||||||
self._history: list[Event] = []
|
self._history: list[Event] = []
|
||||||
self._max_history = 1000
|
self._max_history = 1000
|
||||||
logger.info("EventBus initialized")
|
logger.info("EventBus initialized")
|
||||||
|
|
||||||
def subscribe(self, event_pattern: str) -> Callable[[EventHandler], EventHandler]:
|
def subscribe(self, event_pattern: str) -> Callable[[EventHandler], EventHandler]:
|
||||||
"""Decorator to subscribe to events matching a pattern.
|
"""Decorator to subscribe to events matching a pattern.
|
||||||
|
|
||||||
Patterns support wildcards:
|
Patterns support wildcards:
|
||||||
- "agent.task.assigned" — exact match
|
- "agent.task.assigned" — exact match
|
||||||
- "agent.task.*" — any task event
|
- "agent.task.*" — any task event
|
||||||
- "agent.*" — any agent event
|
- "agent.*" — any agent event
|
||||||
- "*" — all events
|
- "*" — all events
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(handler: EventHandler) -> EventHandler:
|
def decorator(handler: EventHandler) -> EventHandler:
|
||||||
if event_pattern not in self._subscribers:
|
if event_pattern not in self._subscribers:
|
||||||
self._subscribers[event_pattern] = []
|
self._subscribers[event_pattern] = []
|
||||||
self._subscribers[event_pattern].append(handler)
|
self._subscribers[event_pattern].append(handler)
|
||||||
logger.debug("Subscribed handler to '%s'", event_pattern)
|
logger.debug("Subscribed handler to '%s'", event_pattern)
|
||||||
return handler
|
return handler
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def unsubscribe(self, event_pattern: str, handler: EventHandler) -> bool:
|
def unsubscribe(self, event_pattern: str, handler: EventHandler) -> bool:
|
||||||
"""Remove a handler from a subscription."""
|
"""Remove a handler from a subscription."""
|
||||||
if event_pattern not in self._subscribers:
|
if event_pattern not in self._subscribers:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if handler in self._subscribers[event_pattern]:
|
if handler in self._subscribers[event_pattern]:
|
||||||
self._subscribers[event_pattern].remove(handler)
|
self._subscribers[event_pattern].remove(handler)
|
||||||
logger.debug("Unsubscribed handler from '%s'", event_pattern)
|
logger.debug("Unsubscribed handler from '%s'", event_pattern)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def publish(self, event: Event) -> int:
|
async def publish(self, event: Event) -> int:
|
||||||
"""Publish an event to all matching subscribers.
|
"""Publish an event to all matching subscribers.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Number of handlers invoked
|
Number of handlers invoked
|
||||||
"""
|
"""
|
||||||
# Store in history
|
# Store in history
|
||||||
self._history.append(event)
|
self._history.append(event)
|
||||||
if len(self._history) > self._max_history:
|
if len(self._history) > self._max_history:
|
||||||
self._history = self._history[-self._max_history:]
|
self._history = self._history[-self._max_history :]
|
||||||
|
|
||||||
# Find matching handlers
|
# Find matching handlers
|
||||||
handlers: list[EventHandler] = []
|
handlers: list[EventHandler] = []
|
||||||
|
|
||||||
for pattern, pattern_handlers in self._subscribers.items():
|
for pattern, pattern_handlers in self._subscribers.items():
|
||||||
if self._match_pattern(event.type, pattern):
|
if self._match_pattern(event.type, pattern):
|
||||||
handlers.extend(pattern_handlers)
|
handlers.extend(pattern_handlers)
|
||||||
|
|
||||||
# Invoke handlers concurrently
|
# Invoke handlers concurrently
|
||||||
if handlers:
|
if handlers:
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[self._invoke_handler(h, event) for h in handlers],
|
*[self._invoke_handler(h, event) for h in handlers], return_exceptions=True
|
||||||
return_exceptions=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("Published event '%s' to %d handlers", event.type, len(handlers))
|
logger.debug("Published event '%s' to %d handlers", event.type, len(handlers))
|
||||||
return len(handlers)
|
return len(handlers)
|
||||||
|
|
||||||
async def _invoke_handler(self, handler: EventHandler, event: Event) -> None:
|
async def _invoke_handler(self, handler: EventHandler, event: Event) -> None:
|
||||||
"""Invoke a handler with error handling."""
|
"""Invoke a handler with error handling."""
|
||||||
try:
|
try:
|
||||||
await handler(event)
|
await handler(event)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("Event handler failed for '%s': %s", event.type, exc)
|
logger.error("Event handler failed for '%s': %s", event.type, exc)
|
||||||
|
|
||||||
def _match_pattern(self, event_type: str, pattern: str) -> bool:
|
def _match_pattern(self, event_type: str, pattern: str) -> bool:
|
||||||
"""Check if event type matches a wildcard pattern."""
|
"""Check if event type matches a wildcard pattern."""
|
||||||
if pattern == "*":
|
if pattern == "*":
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if pattern.endswith(".*"):
|
if pattern.endswith(".*"):
|
||||||
prefix = pattern[:-2]
|
prefix = pattern[:-2]
|
||||||
return event_type.startswith(prefix + ".")
|
return event_type.startswith(prefix + ".")
|
||||||
|
|
||||||
return event_type == pattern
|
return event_type == pattern
|
||||||
|
|
||||||
def get_history(
|
def get_history(
|
||||||
self,
|
self,
|
||||||
event_type: str | None = None,
|
event_type: str | None = None,
|
||||||
@@ -135,15 +137,15 @@ class EventBus:
|
|||||||
) -> list[Event]:
|
) -> list[Event]:
|
||||||
"""Get recent event history with optional filtering."""
|
"""Get recent event history with optional filtering."""
|
||||||
events = self._history
|
events = self._history
|
||||||
|
|
||||||
if event_type:
|
if event_type:
|
||||||
events = [e for e in events if e.type == event_type]
|
events = [e for e in events if e.type == event_type]
|
||||||
|
|
||||||
if source:
|
if source:
|
||||||
events = [e for e in events if e.source == source]
|
events = [e for e in events if e.source == source]
|
||||||
|
|
||||||
return events[-limit:]
|
return events[-limit:]
|
||||||
|
|
||||||
def clear_history(self) -> None:
|
def clear_history(self) -> None:
|
||||||
"""Clear event history."""
|
"""Clear event history."""
|
||||||
self._history.clear()
|
self._history.clear()
|
||||||
@@ -156,11 +158,13 @@ event_bus = EventBus()
|
|||||||
# Convenience functions
|
# Convenience functions
|
||||||
async def emit(event_type: str, source: str, data: dict) -> int:
|
async def emit(event_type: str, source: str, data: dict) -> int:
|
||||||
"""Quick emit an event."""
|
"""Quick emit an event."""
|
||||||
return await event_bus.publish(Event(
|
return await event_bus.publish(
|
||||||
type=event_type,
|
Event(
|
||||||
source=source,
|
type=event_type,
|
||||||
data=data,
|
source=source,
|
||||||
))
|
data=data,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def on(event_pattern: str) -> Callable[[EventHandler], EventHandler]:
|
def on(event_pattern: str) -> Callable[[EventHandler], EventHandler]:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ Usage:
|
|||||||
result = await git_hand.run("status")
|
result = await git_hand.run("status")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from infrastructure.hands.shell import shell_hand
|
|
||||||
from infrastructure.hands.git import git_hand
|
from infrastructure.hands.git import git_hand
|
||||||
|
from infrastructure.hands.shell import shell_hand
|
||||||
|
|
||||||
__all__ = ["shell_hand", "git_hand"]
|
__all__ = ["shell_hand", "git_hand"]
|
||||||
|
|||||||
@@ -25,16 +25,18 @@ from config import settings
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Operations that require explicit confirmation before execution
|
# Operations that require explicit confirmation before execution
|
||||||
DESTRUCTIVE_OPS = frozenset({
|
DESTRUCTIVE_OPS = frozenset(
|
||||||
"push --force",
|
{
|
||||||
"push -f",
|
"push --force",
|
||||||
"reset --hard",
|
"push -f",
|
||||||
"clean -fd",
|
"reset --hard",
|
||||||
"clean -f",
|
"clean -fd",
|
||||||
"branch -D",
|
"clean -f",
|
||||||
"checkout -- .",
|
"branch -D",
|
||||||
"restore .",
|
"checkout -- .",
|
||||||
})
|
"restore .",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -190,7 +192,9 @@ class GitHand:
|
|||||||
flag = "-b" if create else ""
|
flag = "-b" if create else ""
|
||||||
return await self.run(f"checkout {flag} {branch}".strip())
|
return await self.run(f"checkout {flag} {branch}".strip())
|
||||||
|
|
||||||
async def push(self, remote: str = "origin", branch: str = "", force: bool = False) -> GitResult:
|
async def push(
|
||||||
|
self, remote: str = "origin", branch: str = "", force: bool = False
|
||||||
|
) -> GitResult:
|
||||||
"""Push to remote. Force-push requires explicit opt-in."""
|
"""Push to remote. Force-push requires explicit opt-in."""
|
||||||
args = f"push -u {remote} {branch}".strip()
|
args = f"push -u {remote} {branch}".strip()
|
||||||
if force:
|
if force:
|
||||||
|
|||||||
@@ -26,15 +26,17 @@ from config import settings
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Commands that are always blocked regardless of allow-list
|
# Commands that are always blocked regardless of allow-list
|
||||||
_BLOCKED_COMMANDS = frozenset({
|
_BLOCKED_COMMANDS = frozenset(
|
||||||
"rm -rf /",
|
{
|
||||||
"rm -rf /*",
|
"rm -rf /",
|
||||||
"mkfs",
|
"rm -rf /*",
|
||||||
"dd if=/dev/zero",
|
"mkfs",
|
||||||
":(){ :|:& };:", # fork bomb
|
"dd if=/dev/zero",
|
||||||
"> /dev/sda",
|
":(){ :|:& };:", # fork bomb
|
||||||
"chmod -R 777 /",
|
"> /dev/sda",
|
||||||
})
|
"chmod -R 777 /",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Default allow-list: safe build/dev commands
|
# Default allow-list: safe build/dev commands
|
||||||
DEFAULT_ALLOWED_PREFIXES = (
|
DEFAULT_ALLOWED_PREFIXES = (
|
||||||
@@ -199,9 +201,7 @@ class ShellHand:
|
|||||||
proc.kill()
|
proc.kill()
|
||||||
await proc.wait()
|
await proc.wait()
|
||||||
latency = (time.time() - start) * 1000
|
latency = (time.time() - start) * 1000
|
||||||
logger.warning(
|
logger.warning("Shell command timed out after %ds: %s", effective_timeout, command)
|
||||||
"Shell command timed out after %ds: %s", effective_timeout, command
|
|
||||||
)
|
|
||||||
return ShellResult(
|
return ShellResult(
|
||||||
command=command,
|
command=command,
|
||||||
success=False,
|
success=False,
|
||||||
|
|||||||
@@ -11,15 +11,17 @@ the tool registry.
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from infrastructure.hands.shell import shell_hand
|
|
||||||
from infrastructure.hands.git import git_hand
|
from infrastructure.hands.git import git_hand
|
||||||
|
from infrastructure.hands.shell import shell_hand
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from mcp.schemas.base import create_tool_schema
|
from mcp.schemas.base import create_tool_schema
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
||||||
def create_tool_schema(**kwargs):
|
def create_tool_schema(**kwargs):
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# ── Tool schemas ─────────────────────────────────────────────────────────────
|
# ── Tool schemas ─────────────────────────────────────────────────────────────
|
||||||
@@ -83,6 +85,7 @@ PERSONA_LOCAL_HAND_MAP: dict[str, list[str]] = {
|
|||||||
|
|
||||||
# ── Handlers ─────────────────────────────────────────────────────────────────
|
# ── Handlers ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
async def _handle_shell(**kwargs: Any) -> str:
|
async def _handle_shell(**kwargs: Any) -> str:
|
||||||
"""Handler for the shell MCP tool."""
|
"""Handler for the shell MCP tool."""
|
||||||
command = kwargs.get("command", "")
|
command = kwargs.get("command", "")
|
||||||
|
|||||||
@@ -1,12 +1,5 @@
|
|||||||
"""Infrastructure models package."""
|
"""Infrastructure models package."""
|
||||||
|
|
||||||
from infrastructure.models.registry import (
|
|
||||||
CustomModel,
|
|
||||||
ModelFormat,
|
|
||||||
ModelRegistry,
|
|
||||||
ModelRole,
|
|
||||||
model_registry,
|
|
||||||
)
|
|
||||||
from infrastructure.models.multimodal import (
|
from infrastructure.models.multimodal import (
|
||||||
ModelCapability,
|
ModelCapability,
|
||||||
ModelInfo,
|
ModelInfo,
|
||||||
@@ -17,6 +10,13 @@ from infrastructure.models.multimodal import (
|
|||||||
model_supports_vision,
|
model_supports_vision,
|
||||||
pull_model_with_fallback,
|
pull_model_with_fallback,
|
||||||
)
|
)
|
||||||
|
from infrastructure.models.registry import (
|
||||||
|
CustomModel,
|
||||||
|
ModelFormat,
|
||||||
|
ModelRegistry,
|
||||||
|
ModelRole,
|
||||||
|
model_registry,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Registry
|
# Registry
|
||||||
|
|||||||
@@ -21,39 +21,130 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class ModelCapability(Enum):
|
class ModelCapability(Enum):
|
||||||
"""Capabilities a model can have."""
|
"""Capabilities a model can have."""
|
||||||
TEXT = auto() # Standard text completion
|
|
||||||
VISION = auto() # Image understanding
|
TEXT = auto() # Standard text completion
|
||||||
AUDIO = auto() # Audio/speech processing
|
VISION = auto() # Image understanding
|
||||||
TOOLS = auto() # Function calling / tool use
|
AUDIO = auto() # Audio/speech processing
|
||||||
JSON = auto() # Structured output / JSON mode
|
TOOLS = auto() # Function calling / tool use
|
||||||
STREAMING = auto() # Streaming responses
|
JSON = auto() # Structured output / JSON mode
|
||||||
|
STREAMING = auto() # Streaming responses
|
||||||
|
|
||||||
|
|
||||||
# Known model capabilities (local Ollama models)
|
# Known model capabilities (local Ollama models)
|
||||||
# These are used when we can't query the model directly
|
# These are used when we can't query the model directly
|
||||||
KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
|
KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
|
||||||
# Llama 3.x series
|
# Llama 3.x series
|
||||||
"llama3.1": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
"llama3.1": {
|
||||||
"llama3.1:8b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
ModelCapability.TEXT,
|
||||||
"llama3.1:8b-instruct": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
ModelCapability.TOOLS,
|
||||||
"llama3.1:70b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
ModelCapability.JSON,
|
||||||
"llama3.1:405b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
ModelCapability.STREAMING,
|
||||||
"llama3.2": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
},
|
||||||
|
"llama3.1:8b": {
|
||||||
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
},
|
||||||
|
"llama3.1:8b-instruct": {
|
||||||
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
},
|
||||||
|
"llama3.1:70b": {
|
||||||
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
},
|
||||||
|
"llama3.1:405b": {
|
||||||
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
},
|
||||||
|
"llama3.2": {
|
||||||
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
ModelCapability.VISION,
|
||||||
|
},
|
||||||
"llama3.2:1b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
"llama3.2:1b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
"llama3.2:3b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
"llama3.2:3b": {
|
||||||
"llama3.2-vision": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
ModelCapability.TEXT,
|
||||||
"llama3.2-vision:11b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
ModelCapability.VISION,
|
||||||
|
},
|
||||||
|
"llama3.2-vision": {
|
||||||
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
ModelCapability.VISION,
|
||||||
|
},
|
||||||
|
"llama3.2-vision:11b": {
|
||||||
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
ModelCapability.VISION,
|
||||||
|
},
|
||||||
# Qwen series
|
# Qwen series
|
||||||
"qwen2.5": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
"qwen2.5": {
|
||||||
"qwen2.5:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
ModelCapability.TEXT,
|
||||||
"qwen2.5:14b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
ModelCapability.TOOLS,
|
||||||
"qwen2.5:32b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
ModelCapability.JSON,
|
||||||
"qwen2.5:72b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
ModelCapability.STREAMING,
|
||||||
"qwen2.5-vl": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
},
|
||||||
"qwen2.5-vl:3b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
"qwen2.5:7b": {
|
||||||
"qwen2.5-vl:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
},
|
||||||
|
"qwen2.5:14b": {
|
||||||
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
},
|
||||||
|
"qwen2.5:32b": {
|
||||||
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
},
|
||||||
|
"qwen2.5:72b": {
|
||||||
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
},
|
||||||
|
"qwen2.5-vl": {
|
||||||
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
ModelCapability.VISION,
|
||||||
|
},
|
||||||
|
"qwen2.5-vl:3b": {
|
||||||
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
ModelCapability.VISION,
|
||||||
|
},
|
||||||
|
"qwen2.5-vl:7b": {
|
||||||
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
ModelCapability.VISION,
|
||||||
|
},
|
||||||
# DeepSeek series
|
# DeepSeek series
|
||||||
"deepseek-r1": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
"deepseek-r1": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
"deepseek-r1:1.5b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
"deepseek-r1:1.5b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
@@ -61,21 +152,48 @@ KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
|
|||||||
"deepseek-r1:14b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
"deepseek-r1:14b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
"deepseek-r1:32b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
"deepseek-r1:32b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
"deepseek-r1:70b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
"deepseek-r1:70b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
"deepseek-v3": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
"deepseek-v3": {
|
||||||
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
},
|
||||||
# Gemma series
|
# Gemma series
|
||||||
"gemma2": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
"gemma2": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
"gemma2:2b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
"gemma2:2b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
"gemma2:9b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
"gemma2:9b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
"gemma2:27b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
"gemma2:27b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
|
||||||
# Mistral series
|
# Mistral series
|
||||||
"mistral": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
"mistral": {
|
||||||
"mistral:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
ModelCapability.TEXT,
|
||||||
"mistral-nemo": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
ModelCapability.TOOLS,
|
||||||
"mistral-small": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
ModelCapability.JSON,
|
||||||
"mistral-large": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
ModelCapability.STREAMING,
|
||||||
|
},
|
||||||
|
"mistral:7b": {
|
||||||
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
},
|
||||||
|
"mistral-nemo": {
|
||||||
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
},
|
||||||
|
"mistral-small": {
|
||||||
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
},
|
||||||
|
"mistral-large": {
|
||||||
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
},
|
||||||
# Vision-specific models
|
# Vision-specific models
|
||||||
"llava": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
"llava": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||||
"llava:7b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
"llava:7b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||||
@@ -86,21 +204,48 @@ KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
|
|||||||
"bakllava": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
"bakllava": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||||
"moondream": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
"moondream": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||||
"moondream:1.8b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
"moondream:1.8b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||||
|
|
||||||
# Phi series
|
# Phi series
|
||||||
"phi3": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
"phi3": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
"phi3:3.8b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
"phi3:3.8b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
"phi3:14b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
"phi3:14b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
"phi4": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
"phi4": {
|
||||||
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
},
|
||||||
# Command R
|
# Command R
|
||||||
"command-r": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
"command-r": {
|
||||||
"command-r:35b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
ModelCapability.TEXT,
|
||||||
"command-r-plus": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
},
|
||||||
|
"command-r:35b": {
|
||||||
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
},
|
||||||
|
"command-r-plus": {
|
||||||
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
},
|
||||||
# Granite (IBM)
|
# Granite (IBM)
|
||||||
"granite3-dense": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
"granite3-dense": {
|
||||||
"granite3-moe": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
},
|
||||||
|
"granite3-moe": {
|
||||||
|
ModelCapability.TEXT,
|
||||||
|
ModelCapability.TOOLS,
|
||||||
|
ModelCapability.JSON,
|
||||||
|
ModelCapability.STREAMING,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -108,15 +253,15 @@ KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
|
|||||||
# These are tried in order when the primary model doesn't support a capability
|
# These are tried in order when the primary model doesn't support a capability
|
||||||
DEFAULT_FALLBACK_CHAINS: dict[ModelCapability, list[str]] = {
|
DEFAULT_FALLBACK_CHAINS: dict[ModelCapability, list[str]] = {
|
||||||
ModelCapability.VISION: [
|
ModelCapability.VISION: [
|
||||||
"llama3.2:3b", # Fast vision model
|
"llama3.2:3b", # Fast vision model
|
||||||
"llava:7b", # Classic vision model
|
"llava:7b", # Classic vision model
|
||||||
"qwen2.5-vl:3b", # Qwen vision
|
"qwen2.5-vl:3b", # Qwen vision
|
||||||
"moondream:1.8b", # Tiny vision model (last resort)
|
"moondream:1.8b", # Tiny vision model (last resort)
|
||||||
],
|
],
|
||||||
ModelCapability.TOOLS: [
|
ModelCapability.TOOLS: [
|
||||||
"llama3.1:8b-instruct", # Best tool use
|
"llama3.1:8b-instruct", # Best tool use
|
||||||
"llama3.2:3b", # Smaller but capable
|
"llama3.2:3b", # Smaller but capable
|
||||||
"qwen2.5:7b", # Reliable fallback
|
"qwen2.5:7b", # Reliable fallback
|
||||||
],
|
],
|
||||||
ModelCapability.AUDIO: [
|
ModelCapability.AUDIO: [
|
||||||
# Audio models are less common in Ollama
|
# Audio models are less common in Ollama
|
||||||
@@ -128,13 +273,14 @@ DEFAULT_FALLBACK_CHAINS: dict[ModelCapability, list[str]] = {
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ModelInfo:
|
class ModelInfo:
|
||||||
"""Information about a model's capabilities and availability."""
|
"""Information about a model's capabilities and availability."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
capabilities: set[ModelCapability] = field(default_factory=set)
|
capabilities: set[ModelCapability] = field(default_factory=set)
|
||||||
is_available: bool = False
|
is_available: bool = False
|
||||||
is_pulled: bool = False
|
is_pulled: bool = False
|
||||||
size_mb: Optional[int] = None
|
size_mb: Optional[int] = None
|
||||||
description: str = ""
|
description: str = ""
|
||||||
|
|
||||||
def supports(self, capability: ModelCapability) -> bool:
|
def supports(self, capability: ModelCapability) -> bool:
|
||||||
"""Check if model supports a specific capability."""
|
"""Check if model supports a specific capability."""
|
||||||
return capability in self.capabilities
|
return capability in self.capabilities
|
||||||
@@ -142,26 +288,26 @@ class ModelInfo:
|
|||||||
|
|
||||||
class MultiModalManager:
|
class MultiModalManager:
|
||||||
"""Manages multi-modal model capabilities and fallback chains.
|
"""Manages multi-modal model capabilities and fallback chains.
|
||||||
|
|
||||||
This class:
|
This class:
|
||||||
1. Detects what capabilities each model has
|
1. Detects what capabilities each model has
|
||||||
2. Maintains fallback chains for different capabilities
|
2. Maintains fallback chains for different capabilities
|
||||||
3. Pulls models on-demand with automatic fallback
|
3. Pulls models on-demand with automatic fallback
|
||||||
4. Routes requests to appropriate models based on content type
|
4. Routes requests to appropriate models based on content type
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, ollama_url: Optional[str] = None) -> None:
|
def __init__(self, ollama_url: Optional[str] = None) -> None:
|
||||||
self.ollama_url = ollama_url or settings.ollama_url
|
self.ollama_url = ollama_url or settings.ollama_url
|
||||||
self._available_models: dict[str, ModelInfo] = {}
|
self._available_models: dict[str, ModelInfo] = {}
|
||||||
self._fallback_chains: dict[ModelCapability, list[str]] = dict(DEFAULT_FALLBACK_CHAINS)
|
self._fallback_chains: dict[ModelCapability, list[str]] = dict(DEFAULT_FALLBACK_CHAINS)
|
||||||
self._refresh_available_models()
|
self._refresh_available_models()
|
||||||
|
|
||||||
def _refresh_available_models(self) -> None:
|
def _refresh_available_models(self) -> None:
|
||||||
"""Query Ollama for available models."""
|
"""Query Ollama for available models."""
|
||||||
try:
|
try:
|
||||||
import urllib.request
|
|
||||||
import json
|
import json
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
url = self.ollama_url.replace("localhost", "127.0.0.1")
|
url = self.ollama_url.replace("localhost", "127.0.0.1")
|
||||||
req = urllib.request.Request(
|
req = urllib.request.Request(
|
||||||
f"{url}/api/tags",
|
f"{url}/api/tags",
|
||||||
@@ -170,7 +316,7 @@ class MultiModalManager:
|
|||||||
)
|
)
|
||||||
with urllib.request.urlopen(req, timeout=5) as response:
|
with urllib.request.urlopen(req, timeout=5) as response:
|
||||||
data = json.loads(response.read().decode())
|
data = json.loads(response.read().decode())
|
||||||
|
|
||||||
for model_data in data.get("models", []):
|
for model_data in data.get("models", []):
|
||||||
name = model_data.get("name", "")
|
name = model_data.get("name", "")
|
||||||
self._available_models[name] = ModelInfo(
|
self._available_models[name] = ModelInfo(
|
||||||
@@ -181,58 +327,53 @@ class MultiModalManager:
|
|||||||
size_mb=model_data.get("size", 0) // (1024 * 1024),
|
size_mb=model_data.get("size", 0) // (1024 * 1024),
|
||||||
description=model_data.get("details", {}).get("family", ""),
|
description=model_data.get("details", {}).get("family", ""),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Found %d models in Ollama", len(self._available_models))
|
logger.info("Found %d models in Ollama", len(self._available_models))
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Could not refresh available models: %s", exc)
|
logger.warning("Could not refresh available models: %s", exc)
|
||||||
|
|
||||||
def _detect_capabilities(self, model_name: str) -> set[ModelCapability]:
|
def _detect_capabilities(self, model_name: str) -> set[ModelCapability]:
|
||||||
"""Detect capabilities for a model based on known data."""
|
"""Detect capabilities for a model based on known data."""
|
||||||
# Normalize model name (strip tags for lookup)
|
# Normalize model name (strip tags for lookup)
|
||||||
base_name = model_name.split(":")[0]
|
base_name = model_name.split(":")[0]
|
||||||
|
|
||||||
# Try exact match first
|
# Try exact match first
|
||||||
if model_name in KNOWN_MODEL_CAPABILITIES:
|
if model_name in KNOWN_MODEL_CAPABILITIES:
|
||||||
return set(KNOWN_MODEL_CAPABILITIES[model_name])
|
return set(KNOWN_MODEL_CAPABILITIES[model_name])
|
||||||
|
|
||||||
# Try base name match
|
# Try base name match
|
||||||
if base_name in KNOWN_MODEL_CAPABILITIES:
|
if base_name in KNOWN_MODEL_CAPABILITIES:
|
||||||
return set(KNOWN_MODEL_CAPABILITIES[base_name])
|
return set(KNOWN_MODEL_CAPABILITIES[base_name])
|
||||||
|
|
||||||
# Default to text-only for unknown models
|
# Default to text-only for unknown models
|
||||||
logger.debug("Unknown model %s, defaulting to TEXT only", model_name)
|
logger.debug("Unknown model %s, defaulting to TEXT only", model_name)
|
||||||
return {ModelCapability.TEXT, ModelCapability.STREAMING}
|
return {ModelCapability.TEXT, ModelCapability.STREAMING}
|
||||||
|
|
||||||
def get_model_capabilities(self, model_name: str) -> set[ModelCapability]:
|
def get_model_capabilities(self, model_name: str) -> set[ModelCapability]:
|
||||||
"""Get capabilities for a specific model."""
|
"""Get capabilities for a specific model."""
|
||||||
if model_name in self._available_models:
|
if model_name in self._available_models:
|
||||||
return self._available_models[model_name].capabilities
|
return self._available_models[model_name].capabilities
|
||||||
return self._detect_capabilities(model_name)
|
return self._detect_capabilities(model_name)
|
||||||
|
|
||||||
def model_supports(self, model_name: str, capability: ModelCapability) -> bool:
|
def model_supports(self, model_name: str, capability: ModelCapability) -> bool:
|
||||||
"""Check if a model supports a specific capability."""
|
"""Check if a model supports a specific capability."""
|
||||||
capabilities = self.get_model_capabilities(model_name)
|
capabilities = self.get_model_capabilities(model_name)
|
||||||
return capability in capabilities
|
return capability in capabilities
|
||||||
|
|
||||||
def get_models_with_capability(self, capability: ModelCapability) -> list[ModelInfo]:
|
def get_models_with_capability(self, capability: ModelCapability) -> list[ModelInfo]:
|
||||||
"""Get all available models that support a capability."""
|
"""Get all available models that support a capability."""
|
||||||
return [
|
return [info for info in self._available_models.values() if capability in info.capabilities]
|
||||||
info for info in self._available_models.values()
|
|
||||||
if capability in info.capabilities
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_best_model_for(
|
def get_best_model_for(
|
||||||
self,
|
self, capability: ModelCapability, preferred_model: Optional[str] = None
|
||||||
capability: ModelCapability,
|
|
||||||
preferred_model: Optional[str] = None
|
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""Get the best available model for a specific capability.
|
"""Get the best available model for a specific capability.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
capability: The required capability
|
capability: The required capability
|
||||||
preferred_model: Preferred model to use if available and capable
|
preferred_model: Preferred model to use if available and capable
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Model name or None if no suitable model found
|
Model name or None if no suitable model found
|
||||||
"""
|
"""
|
||||||
@@ -243,25 +384,26 @@ class MultiModalManager:
|
|||||||
return preferred_model
|
return preferred_model
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Preferred model %s doesn't support %s, checking fallbacks",
|
"Preferred model %s doesn't support %s, checking fallbacks",
|
||||||
preferred_model, capability.name
|
preferred_model,
|
||||||
|
capability.name,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check fallback chain for this capability
|
# Check fallback chain for this capability
|
||||||
fallback_chain = self._fallback_chains.get(capability, [])
|
fallback_chain = self._fallback_chains.get(capability, [])
|
||||||
for model_name in fallback_chain:
|
for model_name in fallback_chain:
|
||||||
if model_name in self._available_models:
|
if model_name in self._available_models:
|
||||||
logger.debug("Using fallback model %s for %s", model_name, capability.name)
|
logger.debug("Using fallback model %s for %s", model_name, capability.name)
|
||||||
return model_name
|
return model_name
|
||||||
|
|
||||||
# Find any available model with this capability
|
# Find any available model with this capability
|
||||||
capable_models = self.get_models_with_capability(capability)
|
capable_models = self.get_models_with_capability(capability)
|
||||||
if capable_models:
|
if capable_models:
|
||||||
# Sort by size (prefer smaller/faster models as fallback)
|
# Sort by size (prefer smaller/faster models as fallback)
|
||||||
capable_models.sort(key=lambda m: m.size_mb or float('inf'))
|
capable_models.sort(key=lambda m: m.size_mb or float("inf"))
|
||||||
return capable_models[0].name
|
return capable_models[0].name
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def pull_model_with_fallback(
|
def pull_model_with_fallback(
|
||||||
self,
|
self,
|
||||||
primary_model: str,
|
primary_model: str,
|
||||||
@@ -269,58 +411,58 @@ class MultiModalManager:
|
|||||||
auto_pull: bool = True,
|
auto_pull: bool = True,
|
||||||
) -> tuple[str, bool]:
|
) -> tuple[str, bool]:
|
||||||
"""Pull a model with automatic fallback if unavailable.
|
"""Pull a model with automatic fallback if unavailable.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
primary_model: The desired model to use
|
primary_model: The desired model to use
|
||||||
capability: Required capability (for finding fallback)
|
capability: Required capability (for finding fallback)
|
||||||
auto_pull: Whether to attempt pulling missing models
|
auto_pull: Whether to attempt pulling missing models
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (model_name, is_fallback)
|
Tuple of (model_name, is_fallback)
|
||||||
"""
|
"""
|
||||||
# Check if primary model is already available
|
# Check if primary model is already available
|
||||||
if primary_model in self._available_models:
|
if primary_model in self._available_models:
|
||||||
return primary_model, False
|
return primary_model, False
|
||||||
|
|
||||||
# Try to pull the primary model
|
# Try to pull the primary model
|
||||||
if auto_pull:
|
if auto_pull:
|
||||||
if self._pull_model(primary_model):
|
if self._pull_model(primary_model):
|
||||||
return primary_model, False
|
return primary_model, False
|
||||||
|
|
||||||
# Need to find a fallback
|
# Need to find a fallback
|
||||||
if capability:
|
if capability:
|
||||||
fallback = self.get_best_model_for(capability, primary_model)
|
fallback = self.get_best_model_for(capability, primary_model)
|
||||||
if fallback:
|
if fallback:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Primary model %s unavailable, using fallback %s",
|
"Primary model %s unavailable, using fallback %s", primary_model, fallback
|
||||||
primary_model, fallback
|
|
||||||
)
|
)
|
||||||
return fallback, True
|
return fallback, True
|
||||||
|
|
||||||
# Last resort: use the configured default model
|
# Last resort: use the configured default model
|
||||||
default_model = settings.ollama_model
|
default_model = settings.ollama_model
|
||||||
if default_model in self._available_models:
|
if default_model in self._available_models:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Falling back to default model %s (primary: %s unavailable)",
|
"Falling back to default model %s (primary: %s unavailable)",
|
||||||
default_model, primary_model
|
default_model,
|
||||||
|
primary_model,
|
||||||
)
|
)
|
||||||
return default_model, True
|
return default_model, True
|
||||||
|
|
||||||
# Absolute last resort
|
# Absolute last resort
|
||||||
return primary_model, False
|
return primary_model, False
|
||||||
|
|
||||||
def _pull_model(self, model_name: str) -> bool:
|
def _pull_model(self, model_name: str) -> bool:
|
||||||
"""Attempt to pull a model from Ollama.
|
"""Attempt to pull a model from Ollama.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if successful or model already exists
|
True if successful or model already exists
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
import urllib.request
|
|
||||||
import json
|
import json
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
logger.info("Pulling model: %s", model_name)
|
logger.info("Pulling model: %s", model_name)
|
||||||
|
|
||||||
url = self.ollama_url.replace("localhost", "127.0.0.1")
|
url = self.ollama_url.replace("localhost", "127.0.0.1")
|
||||||
req = urllib.request.Request(
|
req = urllib.request.Request(
|
||||||
f"{url}/api/pull",
|
f"{url}/api/pull",
|
||||||
@@ -328,7 +470,7 @@ class MultiModalManager:
|
|||||||
headers={"Content-Type": "application/json"},
|
headers={"Content-Type": "application/json"},
|
||||||
data=json.dumps({"name": model_name, "stream": False}).encode(),
|
data=json.dumps({"name": model_name, "stream": False}).encode(),
|
||||||
)
|
)
|
||||||
|
|
||||||
with urllib.request.urlopen(req, timeout=300) as response:
|
with urllib.request.urlopen(req, timeout=300) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
logger.info("Successfully pulled model: %s", model_name)
|
logger.info("Successfully pulled model: %s", model_name)
|
||||||
@@ -338,55 +480,51 @@ class MultiModalManager:
|
|||||||
else:
|
else:
|
||||||
logger.error("Failed to pull %s: HTTP %s", model_name, response.status)
|
logger.error("Failed to pull %s: HTTP %s", model_name, response.status)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("Error pulling model %s: %s", model_name, exc)
|
logger.error("Error pulling model %s: %s", model_name, exc)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def configure_fallback_chain(
|
def configure_fallback_chain(self, capability: ModelCapability, models: list[str]) -> None:
|
||||||
self,
|
|
||||||
capability: ModelCapability,
|
|
||||||
models: list[str]
|
|
||||||
) -> None:
|
|
||||||
"""Configure a custom fallback chain for a capability."""
|
"""Configure a custom fallback chain for a capability."""
|
||||||
self._fallback_chains[capability] = models
|
self._fallback_chains[capability] = models
|
||||||
logger.info("Configured fallback chain for %s: %s", capability.name, models)
|
logger.info("Configured fallback chain for %s: %s", capability.name, models)
|
||||||
|
|
||||||
def get_fallback_chain(self, capability: ModelCapability) -> list[str]:
|
def get_fallback_chain(self, capability: ModelCapability) -> list[str]:
|
||||||
"""Get the fallback chain for a capability."""
|
"""Get the fallback chain for a capability."""
|
||||||
return list(self._fallback_chains.get(capability, []))
|
return list(self._fallback_chains.get(capability, []))
|
||||||
|
|
||||||
def list_available_models(self) -> list[ModelInfo]:
|
def list_available_models(self) -> list[ModelInfo]:
|
||||||
"""List all available models with their capabilities."""
|
"""List all available models with their capabilities."""
|
||||||
return list(self._available_models.values())
|
return list(self._available_models.values())
|
||||||
|
|
||||||
def refresh(self) -> None:
|
def refresh(self) -> None:
|
||||||
"""Refresh the list of available models."""
|
"""Refresh the list of available models."""
|
||||||
self._refresh_available_models()
|
self._refresh_available_models()
|
||||||
|
|
||||||
def get_model_for_content(
|
def get_model_for_content(
|
||||||
self,
|
self,
|
||||||
content_type: str, # "text", "image", "audio", "multimodal"
|
content_type: str, # "text", "image", "audio", "multimodal"
|
||||||
preferred_model: Optional[str] = None,
|
preferred_model: Optional[str] = None,
|
||||||
) -> tuple[str, bool]:
|
) -> tuple[str, bool]:
|
||||||
"""Get appropriate model based on content type.
|
"""Get appropriate model based on content type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content_type: Type of content (text, image, audio, multimodal)
|
content_type: Type of content (text, image, audio, multimodal)
|
||||||
preferred_model: User's preferred model
|
preferred_model: User's preferred model
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (model_name, is_fallback)
|
Tuple of (model_name, is_fallback)
|
||||||
"""
|
"""
|
||||||
content_type = content_type.lower()
|
content_type = content_type.lower()
|
||||||
|
|
||||||
if content_type in ("image", "vision", "multimodal"):
|
if content_type in ("image", "vision", "multimodal"):
|
||||||
# For vision content, we need a vision-capable model
|
# For vision content, we need a vision-capable model
|
||||||
return self.pull_model_with_fallback(
|
return self.pull_model_with_fallback(
|
||||||
preferred_model or "llava:7b",
|
preferred_model or "llava:7b",
|
||||||
capability=ModelCapability.VISION,
|
capability=ModelCapability.VISION,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif content_type == "audio":
|
elif content_type == "audio":
|
||||||
# Audio support is limited in Ollama
|
# Audio support is limited in Ollama
|
||||||
# Would need specific audio models
|
# Would need specific audio models
|
||||||
@@ -395,7 +533,7 @@ class MultiModalManager:
|
|||||||
preferred_model or settings.ollama_model,
|
preferred_model or settings.ollama_model,
|
||||||
capability=ModelCapability.TEXT,
|
capability=ModelCapability.TEXT,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Standard text content
|
# Standard text content
|
||||||
return self.pull_model_with_fallback(
|
return self.pull_model_with_fallback(
|
||||||
@@ -417,8 +555,7 @@ def get_multimodal_manager() -> MultiModalManager:
|
|||||||
|
|
||||||
|
|
||||||
def get_model_for_capability(
|
def get_model_for_capability(
|
||||||
capability: ModelCapability,
|
capability: ModelCapability, preferred_model: Optional[str] = None
|
||||||
preferred_model: Optional[str] = None
|
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""Convenience function to get best model for a capability."""
|
"""Convenience function to get best model for a capability."""
|
||||||
return get_multimodal_manager().get_best_model_for(capability, preferred_model)
|
return get_multimodal_manager().get_best_model_for(capability, preferred_model)
|
||||||
@@ -430,9 +567,7 @@ def pull_model_with_fallback(
|
|||||||
auto_pull: bool = True,
|
auto_pull: bool = True,
|
||||||
) -> tuple[str, bool]:
|
) -> tuple[str, bool]:
|
||||||
"""Convenience function to pull model with fallback."""
|
"""Convenience function to pull model with fallback."""
|
||||||
return get_multimodal_manager().pull_model_with_fallback(
|
return get_multimodal_manager().pull_model_with_fallback(primary_model, capability, auto_pull)
|
||||||
primary_model, capability, auto_pull
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def model_supports_vision(model_name: str) -> bool:
|
def model_supports_vision(model_name: str) -> bool:
|
||||||
|
|||||||
@@ -26,26 +26,29 @@ DB_PATH = Path("data/swarm.db")
|
|||||||
|
|
||||||
class ModelFormat(str, Enum):
|
class ModelFormat(str, Enum):
|
||||||
"""Supported model weight formats."""
|
"""Supported model weight formats."""
|
||||||
GGUF = "gguf" # Ollama-compatible quantised weights
|
|
||||||
SAFETENSORS = "safetensors" # HuggingFace safetensors
|
GGUF = "gguf" # Ollama-compatible quantised weights
|
||||||
HF_CHECKPOINT = "hf" # Full HuggingFace checkpoint directory
|
SAFETENSORS = "safetensors" # HuggingFace safetensors
|
||||||
OLLAMA = "ollama" # Already loaded in Ollama by name
|
HF_CHECKPOINT = "hf" # Full HuggingFace checkpoint directory
|
||||||
|
OLLAMA = "ollama" # Already loaded in Ollama by name
|
||||||
|
|
||||||
|
|
||||||
class ModelRole(str, Enum):
|
class ModelRole(str, Enum):
|
||||||
"""Role a model can play in the system (OpenClaw-RL style)."""
|
"""Role a model can play in the system (OpenClaw-RL style)."""
|
||||||
GENERAL = "general" # Default agent inference
|
|
||||||
REWARD = "reward" # Process Reward Model (PRM) scoring
|
GENERAL = "general" # Default agent inference
|
||||||
TEACHER = "teacher" # On-policy distillation teacher
|
REWARD = "reward" # Process Reward Model (PRM) scoring
|
||||||
JUDGE = "judge" # Output quality evaluation
|
TEACHER = "teacher" # On-policy distillation teacher
|
||||||
|
JUDGE = "judge" # Output quality evaluation
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CustomModel:
|
class CustomModel:
|
||||||
"""A registered custom model."""
|
"""A registered custom model."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
format: ModelFormat
|
format: ModelFormat
|
||||||
path: str # Absolute path or Ollama model name
|
path: str # Absolute path or Ollama model name
|
||||||
role: ModelRole = ModelRole.GENERAL
|
role: ModelRole = ModelRole.GENERAL
|
||||||
context_window: int = 4096
|
context_window: int = 4096
|
||||||
description: str = ""
|
description: str = ""
|
||||||
@@ -141,10 +144,16 @@ class ModelRegistry:
|
|||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
model.name, model.format.value, model.path,
|
model.name,
|
||||||
model.role.value, model.context_window, model.description,
|
model.format.value,
|
||||||
model.registered_at, int(model.active),
|
model.path,
|
||||||
model.default_temperature, model.max_tokens,
|
model.role.value,
|
||||||
|
model.context_window,
|
||||||
|
model.description,
|
||||||
|
model.registered_at,
|
||||||
|
int(model.active),
|
||||||
|
model.default_temperature,
|
||||||
|
model.max_tokens,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
@@ -160,9 +169,7 @@ class ModelRegistry:
|
|||||||
return False
|
return False
|
||||||
conn = _get_conn()
|
conn = _get_conn()
|
||||||
conn.execute("DELETE FROM custom_models WHERE name = ?", (name,))
|
conn.execute("DELETE FROM custom_models WHERE name = ?", (name,))
|
||||||
conn.execute(
|
conn.execute("DELETE FROM agent_model_assignments WHERE model_name = ?", (name,))
|
||||||
"DELETE FROM agent_model_assignments WHERE model_name = ?", (name,)
|
|
||||||
)
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
del self._models[name]
|
del self._models[name]
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ No cloud push services — everything stays local.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import subprocess
|
|
||||||
import platform
|
import platform
|
||||||
|
import subprocess
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
@@ -25,9 +25,7 @@ class Notification:
|
|||||||
title: str
|
title: str
|
||||||
message: str
|
message: str
|
||||||
category: str # swarm | task | agent | system | payment
|
category: str # swarm | task | agent | system | payment
|
||||||
timestamp: str = field(
|
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
|
||||||
)
|
|
||||||
read: bool = False
|
read: bool = False
|
||||||
|
|
||||||
|
|
||||||
@@ -74,9 +72,11 @@ class PushNotifier:
|
|||||||
def _native_notify(self, title: str, message: str) -> None:
|
def _native_notify(self, title: str, message: str) -> None:
|
||||||
"""Send a native macOS notification via osascript."""
|
"""Send a native macOS notification via osascript."""
|
||||||
try:
|
try:
|
||||||
|
safe_message = message.replace("\\", "\\\\").replace('"', '\\"')
|
||||||
|
safe_title = title.replace("\\", "\\\\").replace('"', '\\"')
|
||||||
script = (
|
script = (
|
||||||
f'display notification "{message}" '
|
f'display notification "{safe_message}" '
|
||||||
f'with title "Agent Dashboard" subtitle "{title}"'
|
f'with title "Agent Dashboard" subtitle "{safe_title}"'
|
||||||
)
|
)
|
||||||
subprocess.Popen(
|
subprocess.Popen(
|
||||||
["osascript", "-e", script],
|
["osascript", "-e", script],
|
||||||
@@ -114,7 +114,7 @@ class PushNotifier:
|
|||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
self._notifications.clear()
|
self._notifications.clear()
|
||||||
|
|
||||||
def add_listener(self, callback) -> None:
|
def add_listener(self, callback: "Callable[[Notification], None]") -> None:
|
||||||
"""Register a callback for real-time notification delivery."""
|
"""Register a callback for real-time notification delivery."""
|
||||||
self._listeners.append(callback)
|
self._listeners.append(callback)
|
||||||
|
|
||||||
@@ -139,10 +139,7 @@ async def notify_briefing_ready(briefing) -> None:
|
|||||||
logger.info("Briefing ready but no pending approvals — skipping native notification")
|
logger.info("Briefing ready but no pending approvals — skipping native notification")
|
||||||
return
|
return
|
||||||
|
|
||||||
message = (
|
message = f"Your morning briefing is ready. " f"{n_approvals} item(s) await your approval."
|
||||||
f"Your morning briefing is ready. "
|
|
||||||
f"{n_approvals} item(s) await your approval."
|
|
||||||
)
|
|
||||||
notifier.notify(
|
notifier.notify(
|
||||||
title="Morning Briefing Ready",
|
title="Morning Briefing Ready",
|
||||||
message=message,
|
message=message,
|
||||||
|
|||||||
@@ -156,33 +156,23 @@ class OpenFangClient:
|
|||||||
|
|
||||||
async def browse(self, url: str, instruction: str = "") -> HandResult:
|
async def browse(self, url: str, instruction: str = "") -> HandResult:
|
||||||
"""Web automation via OpenFang's Browser hand."""
|
"""Web automation via OpenFang's Browser hand."""
|
||||||
return await self.execute_hand(
|
return await self.execute_hand("browser", {"url": url, "instruction": instruction})
|
||||||
"browser", {"url": url, "instruction": instruction}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def collect(self, target: str, depth: str = "shallow") -> HandResult:
|
async def collect(self, target: str, depth: str = "shallow") -> HandResult:
|
||||||
"""OSINT collection via OpenFang's Collector hand."""
|
"""OSINT collection via OpenFang's Collector hand."""
|
||||||
return await self.execute_hand(
|
return await self.execute_hand("collector", {"target": target, "depth": depth})
|
||||||
"collector", {"target": target, "depth": depth}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def predict(self, question: str, horizon: str = "1w") -> HandResult:
|
async def predict(self, question: str, horizon: str = "1w") -> HandResult:
|
||||||
"""Superforecasting via OpenFang's Predictor hand."""
|
"""Superforecasting via OpenFang's Predictor hand."""
|
||||||
return await self.execute_hand(
|
return await self.execute_hand("predictor", {"question": question, "horizon": horizon})
|
||||||
"predictor", {"question": question, "horizon": horizon}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def find_leads(self, icp: str, max_results: int = 10) -> HandResult:
|
async def find_leads(self, icp: str, max_results: int = 10) -> HandResult:
|
||||||
"""Prospect discovery via OpenFang's Lead hand."""
|
"""Prospect discovery via OpenFang's Lead hand."""
|
||||||
return await self.execute_hand(
|
return await self.execute_hand("lead", {"icp": icp, "max_results": max_results})
|
||||||
"lead", {"icp": icp, "max_results": max_results}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def research(self, topic: str, depth: str = "standard") -> HandResult:
|
async def research(self, topic: str, depth: str = "standard") -> HandResult:
|
||||||
"""Deep research via OpenFang's Researcher hand."""
|
"""Deep research via OpenFang's Researcher hand."""
|
||||||
return await self.execute_hand(
|
return await self.execute_hand("researcher", {"topic": topic, "depth": depth})
|
||||||
"researcher", {"topic": topic, "depth": depth}
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Inventory ────────────────────────────────────────────────────────────
|
# ── Inventory ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|||||||
@@ -22,9 +22,11 @@ from infrastructure.openfang.client import OPENFANG_HANDS, openfang_client
|
|||||||
try:
|
try:
|
||||||
from mcp.schemas.base import create_tool_schema
|
from mcp.schemas.base import create_tool_schema
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
||||||
def create_tool_schema(**kwargs):
|
def create_tool_schema(**kwargs):
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# ── Tool schemas ─────────────────────────────────────────────────────────────
|
# ── Tool schemas ─────────────────────────────────────────────────────────────
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Cascade LLM Router — Automatic failover between providers."""
|
"""Cascade LLM Router — Automatic failover between providers."""
|
||||||
|
|
||||||
from .cascade import CascadeRouter, Provider, ProviderStatus, get_router
|
|
||||||
from .api import router
|
from .api import router
|
||||||
|
from .cascade import CascadeRouter, Provider, ProviderStatus, get_router
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CascadeRouter",
|
"CascadeRouter",
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ router = APIRouter(prefix="/api/v1/router", tags=["router"])
|
|||||||
|
|
||||||
class CompletionRequest(BaseModel):
|
class CompletionRequest(BaseModel):
|
||||||
"""Request body for completions."""
|
"""Request body for completions."""
|
||||||
|
|
||||||
messages: list[dict[str, str]]
|
messages: list[dict[str, str]]
|
||||||
model: str | None = None
|
model: str | None = None
|
||||||
temperature: float = 0.7
|
temperature: float = 0.7
|
||||||
@@ -23,6 +24,7 @@ class CompletionRequest(BaseModel):
|
|||||||
|
|
||||||
class CompletionResponse(BaseModel):
|
class CompletionResponse(BaseModel):
|
||||||
"""Response from completion endpoint."""
|
"""Response from completion endpoint."""
|
||||||
|
|
||||||
content: str
|
content: str
|
||||||
provider: str
|
provider: str
|
||||||
model: str
|
model: str
|
||||||
@@ -31,6 +33,7 @@ class CompletionResponse(BaseModel):
|
|||||||
|
|
||||||
class ProviderControl(BaseModel):
|
class ProviderControl(BaseModel):
|
||||||
"""Control a provider's status."""
|
"""Control a provider's status."""
|
||||||
|
|
||||||
action: str # "enable", "disable", "reset_circuit"
|
action: str # "enable", "disable", "reset_circuit"
|
||||||
|
|
||||||
|
|
||||||
@@ -45,7 +48,7 @@ async def complete(
|
|||||||
cascade: Annotated[CascadeRouter, Depends(get_cascade_router)],
|
cascade: Annotated[CascadeRouter, Depends(get_cascade_router)],
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Complete a conversation with automatic failover.
|
"""Complete a conversation with automatic failover.
|
||||||
|
|
||||||
Routes through providers in priority order until one succeeds.
|
Routes through providers in priority order until one succeeds.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
@@ -108,30 +111,32 @@ async def control_provider(
|
|||||||
if p.name == provider_name:
|
if p.name == provider_name:
|
||||||
provider = p
|
provider = p
|
||||||
break
|
break
|
||||||
|
|
||||||
if not provider:
|
if not provider:
|
||||||
raise HTTPException(status_code=404, detail=f"Provider {provider_name} not found")
|
raise HTTPException(status_code=404, detail=f"Provider {provider_name} not found")
|
||||||
|
|
||||||
if control.action == "enable":
|
if control.action == "enable":
|
||||||
provider.enabled = True
|
provider.enabled = True
|
||||||
provider.status = provider.status.__class__.HEALTHY
|
provider.status = provider.status.__class__.HEALTHY
|
||||||
return {"message": f"Provider {provider_name} enabled"}
|
return {"message": f"Provider {provider_name} enabled"}
|
||||||
|
|
||||||
elif control.action == "disable":
|
elif control.action == "disable":
|
||||||
provider.enabled = False
|
provider.enabled = False
|
||||||
from .cascade import ProviderStatus
|
from .cascade import ProviderStatus
|
||||||
|
|
||||||
provider.status = ProviderStatus.DISABLED
|
provider.status = ProviderStatus.DISABLED
|
||||||
return {"message": f"Provider {provider_name} disabled"}
|
return {"message": f"Provider {provider_name} disabled"}
|
||||||
|
|
||||||
elif control.action == "reset_circuit":
|
elif control.action == "reset_circuit":
|
||||||
from .cascade import CircuitState, ProviderStatus
|
from .cascade import CircuitState, ProviderStatus
|
||||||
|
|
||||||
provider.circuit_state = CircuitState.CLOSED
|
provider.circuit_state = CircuitState.CLOSED
|
||||||
provider.circuit_opened_at = None
|
provider.circuit_opened_at = None
|
||||||
provider.half_open_calls = 0
|
provider.half_open_calls = 0
|
||||||
provider.metrics.consecutive_failures = 0
|
provider.metrics.consecutive_failures = 0
|
||||||
provider.status = ProviderStatus.HEALTHY
|
provider.status = ProviderStatus.HEALTHY
|
||||||
return {"message": f"Circuit breaker reset for {provider_name}"}
|
return {"message": f"Circuit breaker reset for {provider_name}"}
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=400, detail=f"Unknown action: {control.action}")
|
raise HTTPException(status_code=400, detail=f"Unknown action: {control.action}")
|
||||||
|
|
||||||
@@ -142,28 +147,35 @@ async def run_health_check(
|
|||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Run health checks on all providers."""
|
"""Run health checks on all providers."""
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for provider in cascade.providers:
|
for provider in cascade.providers:
|
||||||
# Quick ping to check availability
|
# Quick ping to check availability
|
||||||
is_healthy = cascade._check_provider_available(provider)
|
is_healthy = cascade._check_provider_available(provider)
|
||||||
|
|
||||||
from .cascade import ProviderStatus
|
from .cascade import ProviderStatus
|
||||||
|
|
||||||
if is_healthy:
|
if is_healthy:
|
||||||
if provider.status == ProviderStatus.UNHEALTHY:
|
if provider.status == ProviderStatus.UNHEALTHY:
|
||||||
# Reset circuit if it was open but now healthy
|
# Reset circuit if it was open but now healthy
|
||||||
provider.circuit_state = provider.circuit_state.__class__.CLOSED
|
provider.circuit_state = provider.circuit_state.__class__.CLOSED
|
||||||
provider.circuit_opened_at = None
|
provider.circuit_opened_at = None
|
||||||
provider.status = ProviderStatus.HEALTHY if provider.metrics.error_rate < 0.1 else ProviderStatus.DEGRADED
|
provider.status = (
|
||||||
|
ProviderStatus.HEALTHY
|
||||||
|
if provider.metrics.error_rate < 0.1
|
||||||
|
else ProviderStatus.DEGRADED
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
provider.status = ProviderStatus.UNHEALTHY
|
provider.status = ProviderStatus.UNHEALTHY
|
||||||
|
|
||||||
results.append({
|
results.append(
|
||||||
"name": provider.name,
|
{
|
||||||
"type": provider.type,
|
"name": provider.name,
|
||||||
"healthy": is_healthy,
|
"type": provider.type,
|
||||||
"status": provider.status.value,
|
"healthy": is_healthy,
|
||||||
})
|
"status": provider.status.value,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"checked_at": asyncio.get_event_loop().time(),
|
"checked_at": asyncio.get_event_loop().time(),
|
||||||
"providers": results,
|
"providers": results,
|
||||||
@@ -177,7 +189,7 @@ async def get_config(
|
|||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Get router configuration (without secrets)."""
|
"""Get router configuration (without secrets)."""
|
||||||
cfg = cascade.config
|
cfg = cascade.config
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"timeout_seconds": cfg.timeout_seconds,
|
"timeout_seconds": cfg.timeout_seconds,
|
||||||
"max_retries_per_provider": cfg.max_retries_per_provider,
|
"max_retries_per_provider": cfg.max_retries_per_provider,
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class ProviderStatus(Enum):
|
class ProviderStatus(Enum):
|
||||||
"""Health status of a provider."""
|
"""Health status of a provider."""
|
||||||
|
|
||||||
HEALTHY = "healthy"
|
HEALTHY = "healthy"
|
||||||
DEGRADED = "degraded" # Working but slow or occasional errors
|
DEGRADED = "degraded" # Working but slow or occasional errors
|
||||||
UNHEALTHY = "unhealthy" # Circuit breaker open
|
UNHEALTHY = "unhealthy" # Circuit breaker open
|
||||||
@@ -41,22 +42,25 @@ class ProviderStatus(Enum):
|
|||||||
|
|
||||||
class CircuitState(Enum):
|
class CircuitState(Enum):
|
||||||
"""Circuit breaker state."""
|
"""Circuit breaker state."""
|
||||||
CLOSED = "closed" # Normal operation
|
|
||||||
OPEN = "open" # Failing, rejecting requests
|
CLOSED = "closed" # Normal operation
|
||||||
|
OPEN = "open" # Failing, rejecting requests
|
||||||
HALF_OPEN = "half_open" # Testing if recovered
|
HALF_OPEN = "half_open" # Testing if recovered
|
||||||
|
|
||||||
|
|
||||||
class ContentType(Enum):
|
class ContentType(Enum):
|
||||||
"""Type of content in the request."""
|
"""Type of content in the request."""
|
||||||
|
|
||||||
TEXT = "text"
|
TEXT = "text"
|
||||||
VISION = "vision" # Contains images
|
VISION = "vision" # Contains images
|
||||||
AUDIO = "audio" # Contains audio
|
AUDIO = "audio" # Contains audio
|
||||||
MULTIMODAL = "multimodal" # Multiple content types
|
MULTIMODAL = "multimodal" # Multiple content types
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ProviderMetrics:
|
class ProviderMetrics:
|
||||||
"""Metrics for a single provider."""
|
"""Metrics for a single provider."""
|
||||||
|
|
||||||
total_requests: int = 0
|
total_requests: int = 0
|
||||||
successful_requests: int = 0
|
successful_requests: int = 0
|
||||||
failed_requests: int = 0
|
failed_requests: int = 0
|
||||||
@@ -64,13 +68,13 @@ class ProviderMetrics:
|
|||||||
last_request_time: Optional[str] = None
|
last_request_time: Optional[str] = None
|
||||||
last_error_time: Optional[str] = None
|
last_error_time: Optional[str] = None
|
||||||
consecutive_failures: int = 0
|
consecutive_failures: int = 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def avg_latency_ms(self) -> float:
|
def avg_latency_ms(self) -> float:
|
||||||
if self.total_requests == 0:
|
if self.total_requests == 0:
|
||||||
return 0.0
|
return 0.0
|
||||||
return self.total_latency_ms / self.total_requests
|
return self.total_latency_ms / self.total_requests
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def error_rate(self) -> float:
|
def error_rate(self) -> float:
|
||||||
if self.total_requests == 0:
|
if self.total_requests == 0:
|
||||||
@@ -81,6 +85,7 @@ class ProviderMetrics:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ModelCapability:
|
class ModelCapability:
|
||||||
"""Capabilities a model supports."""
|
"""Capabilities a model supports."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
supports_vision: bool = False
|
supports_vision: bool = False
|
||||||
supports_audio: bool = False
|
supports_audio: bool = False
|
||||||
@@ -93,6 +98,7 @@ class ModelCapability:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class Provider:
|
class Provider:
|
||||||
"""LLM provider configuration and state."""
|
"""LLM provider configuration and state."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
type: str # ollama, openai, anthropic, airllm
|
type: str # ollama, openai, anthropic, airllm
|
||||||
enabled: bool
|
enabled: bool
|
||||||
@@ -101,14 +107,14 @@ class Provider:
|
|||||||
api_key: Optional[str] = None
|
api_key: Optional[str] = None
|
||||||
base_url: Optional[str] = None
|
base_url: Optional[str] = None
|
||||||
models: list[dict] = field(default_factory=list)
|
models: list[dict] = field(default_factory=list)
|
||||||
|
|
||||||
# Runtime state
|
# Runtime state
|
||||||
status: ProviderStatus = ProviderStatus.HEALTHY
|
status: ProviderStatus = ProviderStatus.HEALTHY
|
||||||
metrics: ProviderMetrics = field(default_factory=ProviderMetrics)
|
metrics: ProviderMetrics = field(default_factory=ProviderMetrics)
|
||||||
circuit_state: CircuitState = CircuitState.CLOSED
|
circuit_state: CircuitState = CircuitState.CLOSED
|
||||||
circuit_opened_at: Optional[float] = None
|
circuit_opened_at: Optional[float] = None
|
||||||
half_open_calls: int = 0
|
half_open_calls: int = 0
|
||||||
|
|
||||||
def get_default_model(self) -> Optional[str]:
|
def get_default_model(self) -> Optional[str]:
|
||||||
"""Get the default model for this provider."""
|
"""Get the default model for this provider."""
|
||||||
for model in self.models:
|
for model in self.models:
|
||||||
@@ -117,7 +123,7 @@ class Provider:
|
|||||||
if self.models:
|
if self.models:
|
||||||
return self.models[0]["name"]
|
return self.models[0]["name"]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_model_with_capability(self, capability: str) -> Optional[str]:
|
def get_model_with_capability(self, capability: str) -> Optional[str]:
|
||||||
"""Get a model that supports the given capability."""
|
"""Get a model that supports the given capability."""
|
||||||
for model in self.models:
|
for model in self.models:
|
||||||
@@ -126,7 +132,7 @@ class Provider:
|
|||||||
return model["name"]
|
return model["name"]
|
||||||
# Fall back to default
|
# Fall back to default
|
||||||
return self.get_default_model()
|
return self.get_default_model()
|
||||||
|
|
||||||
def model_has_capability(self, model_name: str, capability: str) -> bool:
|
def model_has_capability(self, model_name: str, capability: str) -> bool:
|
||||||
"""Check if a specific model has a capability."""
|
"""Check if a specific model has a capability."""
|
||||||
for model in self.models:
|
for model in self.models:
|
||||||
@@ -139,6 +145,7 @@ class Provider:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class RouterConfig:
|
class RouterConfig:
|
||||||
"""Cascade router configuration."""
|
"""Cascade router configuration."""
|
||||||
|
|
||||||
timeout_seconds: int = 30
|
timeout_seconds: int = 30
|
||||||
max_retries_per_provider: int = 2
|
max_retries_per_provider: int = 2
|
||||||
retry_delay_seconds: int = 1
|
retry_delay_seconds: int = 1
|
||||||
@@ -154,22 +161,22 @@ class RouterConfig:
|
|||||||
|
|
||||||
class CascadeRouter:
|
class CascadeRouter:
|
||||||
"""Routes LLM requests with automatic failover.
|
"""Routes LLM requests with automatic failover.
|
||||||
|
|
||||||
Now with multi-modal support:
|
Now with multi-modal support:
|
||||||
- Automatically detects content type (text, vision, audio)
|
- Automatically detects content type (text, vision, audio)
|
||||||
- Selects appropriate models based on capabilities
|
- Selects appropriate models based on capabilities
|
||||||
- Falls back through capability-specific model chains
|
- Falls back through capability-specific model chains
|
||||||
- Supports image URLs and base64 encoding
|
- Supports image URLs and base64 encoding
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
router = CascadeRouter()
|
router = CascadeRouter()
|
||||||
|
|
||||||
# Text request
|
# Text request
|
||||||
response = await router.complete(
|
response = await router.complete(
|
||||||
messages=[{"role": "user", "content": "Hello"}],
|
messages=[{"role": "user", "content": "Hello"}],
|
||||||
model="llama3.2"
|
model="llama3.2"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Vision request (automatically detects and selects vision model)
|
# Vision request (automatically detects and selects vision model)
|
||||||
response = await router.complete(
|
response = await router.complete(
|
||||||
messages=[{
|
messages=[{
|
||||||
@@ -179,68 +186,75 @@ class CascadeRouter:
|
|||||||
}],
|
}],
|
||||||
model="llava:7b"
|
model="llava:7b"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check metrics
|
# Check metrics
|
||||||
metrics = router.get_metrics()
|
metrics = router.get_metrics()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config_path: Optional[Path] = None) -> None:
|
def __init__(self, config_path: Optional[Path] = None) -> None:
|
||||||
self.config_path = config_path or Path("config/providers.yaml")
|
self.config_path = config_path or Path("config/providers.yaml")
|
||||||
self.providers: list[Provider] = []
|
self.providers: list[Provider] = []
|
||||||
self.config: RouterConfig = RouterConfig()
|
self.config: RouterConfig = RouterConfig()
|
||||||
self._load_config()
|
self._load_config()
|
||||||
|
|
||||||
# Initialize multi-modal manager if available
|
# Initialize multi-modal manager if available
|
||||||
self._mm_manager: Optional[Any] = None
|
self._mm_manager: Optional[Any] = None
|
||||||
try:
|
try:
|
||||||
from infrastructure.models.multimodal import get_multimodal_manager
|
from infrastructure.models.multimodal import get_multimodal_manager
|
||||||
|
|
||||||
self._mm_manager = get_multimodal_manager()
|
self._mm_manager = get_multimodal_manager()
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.debug("Multi-modal manager not available: %s", exc)
|
logger.debug("Multi-modal manager not available: %s", exc)
|
||||||
|
|
||||||
logger.info("CascadeRouter initialized with %d providers", len(self.providers))
|
logger.info("CascadeRouter initialized with %d providers", len(self.providers))
|
||||||
|
|
||||||
def _load_config(self) -> None:
|
def _load_config(self) -> None:
|
||||||
"""Load configuration from YAML."""
|
"""Load configuration from YAML."""
|
||||||
if not self.config_path.exists():
|
if not self.config_path.exists():
|
||||||
logger.warning("Config not found: %s, using defaults", self.config_path)
|
logger.warning("Config not found: %s, using defaults", self.config_path)
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if yaml is None:
|
if yaml is None:
|
||||||
raise RuntimeError("PyYAML not installed")
|
raise RuntimeError("PyYAML not installed")
|
||||||
|
|
||||||
content = self.config_path.read_text()
|
content = self.config_path.read_text()
|
||||||
# Expand environment variables
|
# Expand environment variables
|
||||||
content = self._expand_env_vars(content)
|
content = self._expand_env_vars(content)
|
||||||
data = yaml.safe_load(content)
|
data = yaml.safe_load(content)
|
||||||
|
|
||||||
# Load cascade settings
|
# Load cascade settings
|
||||||
cascade = data.get("cascade", {})
|
cascade = data.get("cascade", {})
|
||||||
|
|
||||||
# Load fallback chains
|
# Load fallback chains
|
||||||
fallback_chains = data.get("fallback_chains", {})
|
fallback_chains = data.get("fallback_chains", {})
|
||||||
|
|
||||||
# Load multi-modal settings
|
# Load multi-modal settings
|
||||||
multimodal = data.get("multimodal", {})
|
multimodal = data.get("multimodal", {})
|
||||||
|
|
||||||
self.config = RouterConfig(
|
self.config = RouterConfig(
|
||||||
timeout_seconds=cascade.get("timeout_seconds", 30),
|
timeout_seconds=cascade.get("timeout_seconds", 30),
|
||||||
max_retries_per_provider=cascade.get("max_retries_per_provider", 2),
|
max_retries_per_provider=cascade.get("max_retries_per_provider", 2),
|
||||||
retry_delay_seconds=cascade.get("retry_delay_seconds", 1),
|
retry_delay_seconds=cascade.get("retry_delay_seconds", 1),
|
||||||
circuit_breaker_failure_threshold=cascade.get("circuit_breaker", {}).get("failure_threshold", 5),
|
circuit_breaker_failure_threshold=cascade.get("circuit_breaker", {}).get(
|
||||||
circuit_breaker_recovery_timeout=cascade.get("circuit_breaker", {}).get("recovery_timeout", 60),
|
"failure_threshold", 5
|
||||||
circuit_breaker_half_open_max_calls=cascade.get("circuit_breaker", {}).get("half_open_max_calls", 2),
|
),
|
||||||
|
circuit_breaker_recovery_timeout=cascade.get("circuit_breaker", {}).get(
|
||||||
|
"recovery_timeout", 60
|
||||||
|
),
|
||||||
|
circuit_breaker_half_open_max_calls=cascade.get("circuit_breaker", {}).get(
|
||||||
|
"half_open_max_calls", 2
|
||||||
|
),
|
||||||
auto_pull_models=multimodal.get("auto_pull", True),
|
auto_pull_models=multimodal.get("auto_pull", True),
|
||||||
fallback_chains=fallback_chains,
|
fallback_chains=fallback_chains,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load providers
|
# Load providers
|
||||||
for p_data in data.get("providers", []):
|
for p_data in data.get("providers", []):
|
||||||
# Skip disabled providers
|
# Skip disabled providers
|
||||||
if not p_data.get("enabled", False):
|
if not p_data.get("enabled", False):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
provider = Provider(
|
provider = Provider(
|
||||||
name=p_data["name"],
|
name=p_data["name"],
|
||||||
type=p_data["type"],
|
type=p_data["type"],
|
||||||
@@ -251,30 +265,34 @@ class CascadeRouter:
|
|||||||
base_url=p_data.get("base_url"),
|
base_url=p_data.get("base_url"),
|
||||||
models=p_data.get("models", []),
|
models=p_data.get("models", []),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if provider is actually available
|
# Check if provider is actually available
|
||||||
if self._check_provider_available(provider):
|
if self._check_provider_available(provider):
|
||||||
self.providers.append(provider)
|
self.providers.append(provider)
|
||||||
else:
|
else:
|
||||||
logger.warning("Provider %s not available, skipping", provider.name)
|
logger.warning("Provider %s not available, skipping", provider.name)
|
||||||
|
|
||||||
# Sort by priority
|
# Sort by priority
|
||||||
self.providers.sort(key=lambda p: p.priority)
|
self.providers.sort(key=lambda p: p.priority)
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("Failed to load config: %s", exc)
|
logger.error("Failed to load config: %s", exc)
|
||||||
|
|
||||||
def _expand_env_vars(self, content: str) -> str:
|
def _expand_env_vars(self, content: str) -> str:
|
||||||
"""Expand ${VAR} syntax in YAML content."""
|
"""Expand ${VAR} syntax in YAML content.
|
||||||
|
|
||||||
|
Uses os.environ directly (not settings) because this is a generic
|
||||||
|
YAML config loader that must expand arbitrary variable references.
|
||||||
|
"""
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
def replace_var(match):
|
def replace_var(match: "re.Match[str]") -> str:
|
||||||
var_name = match.group(1)
|
var_name = match.group(1)
|
||||||
return os.environ.get(var_name, match.group(0))
|
return os.environ.get(var_name, match.group(0))
|
||||||
|
|
||||||
return re.sub(r"\$\{(\w+)\}", replace_var, content)
|
return re.sub(r"\$\{(\w+)\}", replace_var, content)
|
||||||
|
|
||||||
def _check_provider_available(self, provider: Provider) -> bool:
|
def _check_provider_available(self, provider: Provider) -> bool:
|
||||||
"""Check if a provider is actually available."""
|
"""Check if a provider is actually available."""
|
||||||
if provider.type == "ollama":
|
if provider.type == "ollama":
|
||||||
@@ -288,48 +306,49 @@ class CascadeRouter:
|
|||||||
return response.status_code == 200
|
return response.status_code == 200
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
elif provider.type == "airllm":
|
elif provider.type == "airllm":
|
||||||
# Check if airllm is installed
|
# Check if airllm is installed
|
||||||
try:
|
try:
|
||||||
import airllm
|
import airllm
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
elif provider.type in ("openai", "anthropic", "grok"):
|
elif provider.type in ("openai", "anthropic", "grok"):
|
||||||
# Check if API key is set
|
# Check if API key is set
|
||||||
return provider.api_key is not None and provider.api_key != ""
|
return provider.api_key is not None and provider.api_key != ""
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _detect_content_type(self, messages: list[dict]) -> ContentType:
|
def _detect_content_type(self, messages: list[dict]) -> ContentType:
|
||||||
"""Detect the type of content in the messages.
|
"""Detect the type of content in the messages.
|
||||||
|
|
||||||
Checks for images, audio, etc. in the message content.
|
Checks for images, audio, etc. in the message content.
|
||||||
"""
|
"""
|
||||||
has_image = False
|
has_image = False
|
||||||
has_audio = False
|
has_audio = False
|
||||||
|
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
content = msg.get("content", "")
|
content = msg.get("content", "")
|
||||||
|
|
||||||
# Check for image URLs/paths
|
# Check for image URLs/paths
|
||||||
if msg.get("images"):
|
if msg.get("images"):
|
||||||
has_image = True
|
has_image = True
|
||||||
|
|
||||||
# Check for image URLs in content
|
# Check for image URLs in content
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp')
|
image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
|
||||||
if any(ext in content.lower() for ext in image_extensions):
|
if any(ext in content.lower() for ext in image_extensions):
|
||||||
has_image = True
|
has_image = True
|
||||||
if content.startswith("data:image/"):
|
if content.startswith("data:image/"):
|
||||||
has_image = True
|
has_image = True
|
||||||
|
|
||||||
# Check for audio
|
# Check for audio
|
||||||
if msg.get("audio"):
|
if msg.get("audio"):
|
||||||
has_audio = True
|
has_audio = True
|
||||||
|
|
||||||
# Check for multimodal content structure
|
# Check for multimodal content structure
|
||||||
if isinstance(content, list):
|
if isinstance(content, list):
|
||||||
for item in content:
|
for item in content:
|
||||||
@@ -338,7 +357,7 @@ class CascadeRouter:
|
|||||||
has_image = True
|
has_image = True
|
||||||
elif item.get("type") == "audio":
|
elif item.get("type") == "audio":
|
||||||
has_audio = True
|
has_audio = True
|
||||||
|
|
||||||
if has_image and has_audio:
|
if has_image and has_audio:
|
||||||
return ContentType.MULTIMODAL
|
return ContentType.MULTIMODAL
|
||||||
elif has_image:
|
elif has_image:
|
||||||
@@ -346,12 +365,9 @@ class CascadeRouter:
|
|||||||
elif has_audio:
|
elif has_audio:
|
||||||
return ContentType.AUDIO
|
return ContentType.AUDIO
|
||||||
return ContentType.TEXT
|
return ContentType.TEXT
|
||||||
|
|
||||||
def _get_fallback_model(
|
def _get_fallback_model(
|
||||||
self,
|
self, provider: Provider, original_model: str, content_type: ContentType
|
||||||
provider: Provider,
|
|
||||||
original_model: str,
|
|
||||||
content_type: ContentType
|
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""Get a fallback model for the given content type."""
|
"""Get a fallback model for the given content type."""
|
||||||
# Map content type to capability
|
# Map content type to capability
|
||||||
@@ -360,24 +376,24 @@ class CascadeRouter:
|
|||||||
ContentType.AUDIO: "audio",
|
ContentType.AUDIO: "audio",
|
||||||
ContentType.MULTIMODAL: "vision", # Vision models often do both
|
ContentType.MULTIMODAL: "vision", # Vision models often do both
|
||||||
}
|
}
|
||||||
|
|
||||||
capability = capability_map.get(content_type)
|
capability = capability_map.get(content_type)
|
||||||
if not capability:
|
if not capability:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Check provider's models for capability
|
# Check provider's models for capability
|
||||||
fallback_model = provider.get_model_with_capability(capability)
|
fallback_model = provider.get_model_with_capability(capability)
|
||||||
if fallback_model and fallback_model != original_model:
|
if fallback_model and fallback_model != original_model:
|
||||||
return fallback_model
|
return fallback_model
|
||||||
|
|
||||||
# Use fallback chains from config
|
# Use fallback chains from config
|
||||||
fallback_chain = self.config.fallback_chains.get(capability, [])
|
fallback_chain = self.config.fallback_chains.get(capability, [])
|
||||||
for model_name in fallback_chain:
|
for model_name in fallback_chain:
|
||||||
if provider.model_has_capability(model_name, capability):
|
if provider.model_has_capability(model_name, capability):
|
||||||
return model_name
|
return model_name
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def complete(
|
async def complete(
|
||||||
self,
|
self,
|
||||||
messages: list[dict],
|
messages: list[dict],
|
||||||
@@ -386,21 +402,21 @@ class CascadeRouter:
|
|||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Complete a chat conversation with automatic failover.
|
"""Complete a chat conversation with automatic failover.
|
||||||
|
|
||||||
Multi-modal support:
|
Multi-modal support:
|
||||||
- Automatically detects if messages contain images
|
- Automatically detects if messages contain images
|
||||||
- Falls back to vision-capable models when needed
|
- Falls back to vision-capable models when needed
|
||||||
- Supports image URLs, paths, and base64 encoding
|
- Supports image URLs, paths, and base64 encoding
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: List of message dicts with role and content
|
messages: List of message dicts with role and content
|
||||||
model: Preferred model (tries this first, then provider defaults)
|
model: Preferred model (tries this first, then provider defaults)
|
||||||
temperature: Sampling temperature
|
temperature: Sampling temperature
|
||||||
max_tokens: Maximum tokens to generate
|
max_tokens: Maximum tokens to generate
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with content, provider_used, and metrics
|
Dict with content, provider_used, and metrics
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If all providers fail
|
RuntimeError: If all providers fail
|
||||||
"""
|
"""
|
||||||
@@ -408,15 +424,15 @@ class CascadeRouter:
|
|||||||
content_type = self._detect_content_type(messages)
|
content_type = self._detect_content_type(messages)
|
||||||
if content_type != ContentType.TEXT:
|
if content_type != ContentType.TEXT:
|
||||||
logger.debug("Detected %s content, selecting appropriate model", content_type.value)
|
logger.debug("Detected %s content, selecting appropriate model", content_type.value)
|
||||||
|
|
||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
for provider in self.providers:
|
for provider in self.providers:
|
||||||
# Skip disabled providers
|
# Skip disabled providers
|
||||||
if not provider.enabled:
|
if not provider.enabled:
|
||||||
logger.debug("Skipping %s (disabled)", provider.name)
|
logger.debug("Skipping %s (disabled)", provider.name)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Skip unhealthy providers (circuit breaker)
|
# Skip unhealthy providers (circuit breaker)
|
||||||
if provider.status == ProviderStatus.UNHEALTHY:
|
if provider.status == ProviderStatus.UNHEALTHY:
|
||||||
# Check if circuit breaker can close
|
# Check if circuit breaker can close
|
||||||
@@ -427,16 +443,16 @@ class CascadeRouter:
|
|||||||
else:
|
else:
|
||||||
logger.debug("Skipping %s (circuit open)", provider.name)
|
logger.debug("Skipping %s (circuit open)", provider.name)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Determine which model to use
|
# Determine which model to use
|
||||||
selected_model = model or provider.get_default_model()
|
selected_model = model or provider.get_default_model()
|
||||||
is_fallback_model = False
|
is_fallback_model = False
|
||||||
|
|
||||||
# For non-text content, check if model supports it
|
# For non-text content, check if model supports it
|
||||||
if content_type != ContentType.TEXT and selected_model:
|
if content_type != ContentType.TEXT and selected_model:
|
||||||
if provider.type == "ollama" and self._mm_manager:
|
if provider.type == "ollama" and self._mm_manager:
|
||||||
from infrastructure.models.multimodal import ModelCapability
|
from infrastructure.models.multimodal import ModelCapability
|
||||||
|
|
||||||
# Check if selected model supports the required capability
|
# Check if selected model supports the required capability
|
||||||
if content_type == ContentType.VISION:
|
if content_type == ContentType.VISION:
|
||||||
supports = self._mm_manager.model_supports(
|
supports = self._mm_manager.model_supports(
|
||||||
@@ -450,16 +466,17 @@ class CascadeRouter:
|
|||||||
if fallback:
|
if fallback:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Model %s doesn't support vision, falling back to %s",
|
"Model %s doesn't support vision, falling back to %s",
|
||||||
selected_model, fallback
|
selected_model,
|
||||||
|
fallback,
|
||||||
)
|
)
|
||||||
selected_model = fallback
|
selected_model = fallback
|
||||||
is_fallback_model = True
|
is_fallback_model = True
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"No vision-capable model found on %s, trying anyway",
|
"No vision-capable model found on %s, trying anyway",
|
||||||
provider.name
|
provider.name,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try this provider
|
# Try this provider
|
||||||
for attempt in range(self.config.max_retries_per_provider):
|
for attempt in range(self.config.max_retries_per_provider):
|
||||||
try:
|
try:
|
||||||
@@ -471,34 +488,35 @@ class CascadeRouter:
|
|||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
content_type=content_type,
|
content_type=content_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Success! Update metrics and return
|
# Success! Update metrics and return
|
||||||
self._record_success(provider, result.get("latency_ms", 0))
|
self._record_success(provider, result.get("latency_ms", 0))
|
||||||
return {
|
return {
|
||||||
"content": result["content"],
|
"content": result["content"],
|
||||||
"provider": provider.name,
|
"provider": provider.name,
|
||||||
"model": result.get("model", selected_model or provider.get_default_model()),
|
"model": result.get(
|
||||||
|
"model", selected_model or provider.get_default_model()
|
||||||
|
),
|
||||||
"latency_ms": result.get("latency_ms", 0),
|
"latency_ms": result.get("latency_ms", 0),
|
||||||
"is_fallback_model": is_fallback_model,
|
"is_fallback_model": is_fallback_model,
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
error_msg = str(exc)
|
error_msg = str(exc)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Provider %s attempt %d failed: %s",
|
"Provider %s attempt %d failed: %s", provider.name, attempt + 1, error_msg
|
||||||
provider.name, attempt + 1, error_msg
|
|
||||||
)
|
)
|
||||||
errors.append(f"{provider.name}: {error_msg}")
|
errors.append(f"{provider.name}: {error_msg}")
|
||||||
|
|
||||||
if attempt < self.config.max_retries_per_provider - 1:
|
if attempt < self.config.max_retries_per_provider - 1:
|
||||||
await asyncio.sleep(self.config.retry_delay_seconds)
|
await asyncio.sleep(self.config.retry_delay_seconds)
|
||||||
|
|
||||||
# All retries failed for this provider
|
# All retries failed for this provider
|
||||||
self._record_failure(provider)
|
self._record_failure(provider)
|
||||||
|
|
||||||
# All providers failed
|
# All providers failed
|
||||||
raise RuntimeError(f"All providers failed: {'; '.join(errors)}")
|
raise RuntimeError(f"All providers failed: {'; '.join(errors)}")
|
||||||
|
|
||||||
async def _try_provider(
|
async def _try_provider(
|
||||||
self,
|
self,
|
||||||
provider: Provider,
|
provider: Provider,
|
||||||
@@ -510,7 +528,7 @@ class CascadeRouter:
|
|||||||
) -> dict:
|
) -> dict:
|
||||||
"""Try a single provider request."""
|
"""Try a single provider request."""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
if provider.type == "ollama":
|
if provider.type == "ollama":
|
||||||
result = await self._call_ollama(
|
result = await self._call_ollama(
|
||||||
provider=provider,
|
provider=provider,
|
||||||
@@ -545,12 +563,12 @@ class CascadeRouter:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown provider type: {provider.type}")
|
raise ValueError(f"Unknown provider type: {provider.type}")
|
||||||
|
|
||||||
latency_ms = (time.time() - start_time) * 1000
|
latency_ms = (time.time() - start_time) * 1000
|
||||||
result["latency_ms"] = latency_ms
|
result["latency_ms"] = latency_ms
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def _call_ollama(
|
async def _call_ollama(
|
||||||
self,
|
self,
|
||||||
provider: Provider,
|
provider: Provider,
|
||||||
@@ -561,12 +579,12 @@ class CascadeRouter:
|
|||||||
) -> dict:
|
) -> dict:
|
||||||
"""Call Ollama API with multi-modal support."""
|
"""Call Ollama API with multi-modal support."""
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
url = f"{provider.url}/api/chat"
|
url = f"{provider.url}/api/chat"
|
||||||
|
|
||||||
# Transform messages for Ollama format (including images)
|
# Transform messages for Ollama format (including images)
|
||||||
transformed_messages = self._transform_messages_for_ollama(messages)
|
transformed_messages = self._transform_messages_for_ollama(messages)
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"messages": transformed_messages,
|
"messages": transformed_messages,
|
||||||
@@ -575,31 +593,31 @@ class CascadeRouter:
|
|||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
timeout = aiohttp.ClientTimeout(total=self.config.timeout_seconds)
|
timeout = aiohttp.ClientTimeout(total=self.config.timeout_seconds)
|
||||||
|
|
||||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
async with session.post(url, json=payload) as response:
|
async with session.post(url, json=payload) as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
text = await response.text()
|
text = await response.text()
|
||||||
raise RuntimeError(f"Ollama error {response.status}: {text}")
|
raise RuntimeError(f"Ollama error {response.status}: {text}")
|
||||||
|
|
||||||
data = await response.json()
|
data = await response.json()
|
||||||
return {
|
return {
|
||||||
"content": data["message"]["content"],
|
"content": data["message"]["content"],
|
||||||
"model": model,
|
"model": model,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _transform_messages_for_ollama(self, messages: list[dict]) -> list[dict]:
|
def _transform_messages_for_ollama(self, messages: list[dict]) -> list[dict]:
|
||||||
"""Transform messages to Ollama format, handling images."""
|
"""Transform messages to Ollama format, handling images."""
|
||||||
transformed = []
|
transformed = []
|
||||||
|
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
new_msg = {
|
new_msg = {
|
||||||
"role": msg.get("role", "user"),
|
"role": msg.get("role", "user"),
|
||||||
"content": msg.get("content", ""),
|
"content": msg.get("content", ""),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Handle images
|
# Handle images
|
||||||
images = msg.get("images", [])
|
images = msg.get("images", [])
|
||||||
if images:
|
if images:
|
||||||
@@ -620,11 +638,11 @@ class CascadeRouter:
|
|||||||
new_msg["images"].append(img_data)
|
new_msg["images"].append(img_data)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("Failed to read image %s: %s", img, exc)
|
logger.error("Failed to read image %s: %s", img, exc)
|
||||||
|
|
||||||
transformed.append(new_msg)
|
transformed.append(new_msg)
|
||||||
|
|
||||||
return transformed
|
return transformed
|
||||||
|
|
||||||
async def _call_openai(
|
async def _call_openai(
|
||||||
self,
|
self,
|
||||||
provider: Provider,
|
provider: Provider,
|
||||||
@@ -635,13 +653,13 @@ class CascadeRouter:
|
|||||||
) -> dict:
|
) -> dict:
|
||||||
"""Call OpenAI API."""
|
"""Call OpenAI API."""
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
client = openai.AsyncOpenAI(
|
client = openai.AsyncOpenAI(
|
||||||
api_key=provider.api_key,
|
api_key=provider.api_key,
|
||||||
base_url=provider.base_url,
|
base_url=provider.base_url,
|
||||||
timeout=self.config.timeout_seconds,
|
timeout=self.config.timeout_seconds,
|
||||||
)
|
)
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
@@ -649,14 +667,14 @@ class CascadeRouter:
|
|||||||
}
|
}
|
||||||
if max_tokens:
|
if max_tokens:
|
||||||
kwargs["max_tokens"] = max_tokens
|
kwargs["max_tokens"] = max_tokens
|
||||||
|
|
||||||
response = await client.chat.completions.create(**kwargs)
|
response = await client.chat.completions.create(**kwargs)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"content": response.choices[0].message.content,
|
"content": response.choices[0].message.content,
|
||||||
"model": response.model,
|
"model": response.model,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _call_anthropic(
|
async def _call_anthropic(
|
||||||
self,
|
self,
|
||||||
provider: Provider,
|
provider: Provider,
|
||||||
@@ -667,12 +685,12 @@ class CascadeRouter:
|
|||||||
) -> dict:
|
) -> dict:
|
||||||
"""Call Anthropic API."""
|
"""Call Anthropic API."""
|
||||||
import anthropic
|
import anthropic
|
||||||
|
|
||||||
client = anthropic.AsyncAnthropic(
|
client = anthropic.AsyncAnthropic(
|
||||||
api_key=provider.api_key,
|
api_key=provider.api_key,
|
||||||
timeout=self.config.timeout_seconds,
|
timeout=self.config.timeout_seconds,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert messages to Anthropic format
|
# Convert messages to Anthropic format
|
||||||
system_msg = None
|
system_msg = None
|
||||||
conversation = []
|
conversation = []
|
||||||
@@ -680,11 +698,13 @@ class CascadeRouter:
|
|||||||
if msg["role"] == "system":
|
if msg["role"] == "system":
|
||||||
system_msg = msg["content"]
|
system_msg = msg["content"]
|
||||||
else:
|
else:
|
||||||
conversation.append({
|
conversation.append(
|
||||||
"role": msg["role"],
|
{
|
||||||
"content": msg["content"],
|
"role": msg["role"],
|
||||||
})
|
"content": msg["content"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"messages": conversation,
|
"messages": conversation,
|
||||||
@@ -693,9 +713,9 @@ class CascadeRouter:
|
|||||||
}
|
}
|
||||||
if system_msg:
|
if system_msg:
|
||||||
kwargs["system"] = system_msg
|
kwargs["system"] = system_msg
|
||||||
|
|
||||||
response = await client.messages.create(**kwargs)
|
response = await client.messages.create(**kwargs)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"content": response.content[0].text,
|
"content": response.content[0].text,
|
||||||
"model": response.model,
|
"model": response.model,
|
||||||
@@ -733,7 +753,7 @@ class CascadeRouter:
|
|||||||
"content": response.choices[0].message.content,
|
"content": response.choices[0].message.content,
|
||||||
"model": response.model,
|
"model": response.model,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _record_success(self, provider: Provider, latency_ms: float) -> None:
|
def _record_success(self, provider: Provider, latency_ms: float) -> None:
|
||||||
"""Record a successful request."""
|
"""Record a successful request."""
|
||||||
provider.metrics.total_requests += 1
|
provider.metrics.total_requests += 1
|
||||||
@@ -741,50 +761,50 @@ class CascadeRouter:
|
|||||||
provider.metrics.total_latency_ms += latency_ms
|
provider.metrics.total_latency_ms += latency_ms
|
||||||
provider.metrics.last_request_time = datetime.now(timezone.utc).isoformat()
|
provider.metrics.last_request_time = datetime.now(timezone.utc).isoformat()
|
||||||
provider.metrics.consecutive_failures = 0
|
provider.metrics.consecutive_failures = 0
|
||||||
|
|
||||||
# Close circuit breaker if half-open
|
# Close circuit breaker if half-open
|
||||||
if provider.circuit_state == CircuitState.HALF_OPEN:
|
if provider.circuit_state == CircuitState.HALF_OPEN:
|
||||||
provider.half_open_calls += 1
|
provider.half_open_calls += 1
|
||||||
if provider.half_open_calls >= self.config.circuit_breaker_half_open_max_calls:
|
if provider.half_open_calls >= self.config.circuit_breaker_half_open_max_calls:
|
||||||
self._close_circuit(provider)
|
self._close_circuit(provider)
|
||||||
|
|
||||||
# Update status based on error rate
|
# Update status based on error rate
|
||||||
if provider.metrics.error_rate < 0.1:
|
if provider.metrics.error_rate < 0.1:
|
||||||
provider.status = ProviderStatus.HEALTHY
|
provider.status = ProviderStatus.HEALTHY
|
||||||
elif provider.metrics.error_rate < 0.3:
|
elif provider.metrics.error_rate < 0.3:
|
||||||
provider.status = ProviderStatus.DEGRADED
|
provider.status = ProviderStatus.DEGRADED
|
||||||
|
|
||||||
def _record_failure(self, provider: Provider) -> None:
|
def _record_failure(self, provider: Provider) -> None:
|
||||||
"""Record a failed request."""
|
"""Record a failed request."""
|
||||||
provider.metrics.total_requests += 1
|
provider.metrics.total_requests += 1
|
||||||
provider.metrics.failed_requests += 1
|
provider.metrics.failed_requests += 1
|
||||||
provider.metrics.last_error_time = datetime.now(timezone.utc).isoformat()
|
provider.metrics.last_error_time = datetime.now(timezone.utc).isoformat()
|
||||||
provider.metrics.consecutive_failures += 1
|
provider.metrics.consecutive_failures += 1
|
||||||
|
|
||||||
# Check if we should open circuit breaker
|
# Check if we should open circuit breaker
|
||||||
if provider.metrics.consecutive_failures >= self.config.circuit_breaker_failure_threshold:
|
if provider.metrics.consecutive_failures >= self.config.circuit_breaker_failure_threshold:
|
||||||
self._open_circuit(provider)
|
self._open_circuit(provider)
|
||||||
|
|
||||||
# Update status
|
# Update status
|
||||||
if provider.metrics.error_rate > 0.3:
|
if provider.metrics.error_rate > 0.3:
|
||||||
provider.status = ProviderStatus.DEGRADED
|
provider.status = ProviderStatus.DEGRADED
|
||||||
if provider.metrics.error_rate > 0.5:
|
if provider.metrics.error_rate > 0.5:
|
||||||
provider.status = ProviderStatus.UNHEALTHY
|
provider.status = ProviderStatus.UNHEALTHY
|
||||||
|
|
||||||
def _open_circuit(self, provider: Provider) -> None:
|
def _open_circuit(self, provider: Provider) -> None:
|
||||||
"""Open the circuit breaker for a provider."""
|
"""Open the circuit breaker for a provider."""
|
||||||
provider.circuit_state = CircuitState.OPEN
|
provider.circuit_state = CircuitState.OPEN
|
||||||
provider.circuit_opened_at = time.time()
|
provider.circuit_opened_at = time.time()
|
||||||
provider.status = ProviderStatus.UNHEALTHY
|
provider.status = ProviderStatus.UNHEALTHY
|
||||||
logger.warning("Circuit breaker OPEN for %s", provider.name)
|
logger.warning("Circuit breaker OPEN for %s", provider.name)
|
||||||
|
|
||||||
def _can_close_circuit(self, provider: Provider) -> bool:
|
def _can_close_circuit(self, provider: Provider) -> bool:
|
||||||
"""Check if circuit breaker can transition to half-open."""
|
"""Check if circuit breaker can transition to half-open."""
|
||||||
if provider.circuit_opened_at is None:
|
if provider.circuit_opened_at is None:
|
||||||
return False
|
return False
|
||||||
elapsed = time.time() - provider.circuit_opened_at
|
elapsed = time.time() - provider.circuit_opened_at
|
||||||
return elapsed >= self.config.circuit_breaker_recovery_timeout
|
return elapsed >= self.config.circuit_breaker_recovery_timeout
|
||||||
|
|
||||||
def _close_circuit(self, provider: Provider) -> None:
|
def _close_circuit(self, provider: Provider) -> None:
|
||||||
"""Close the circuit breaker (provider healthy again)."""
|
"""Close the circuit breaker (provider healthy again)."""
|
||||||
provider.circuit_state = CircuitState.CLOSED
|
provider.circuit_state = CircuitState.CLOSED
|
||||||
@@ -793,7 +813,7 @@ class CascadeRouter:
|
|||||||
provider.metrics.consecutive_failures = 0
|
provider.metrics.consecutive_failures = 0
|
||||||
provider.status = ProviderStatus.HEALTHY
|
provider.status = ProviderStatus.HEALTHY
|
||||||
logger.info("Circuit breaker CLOSED for %s", provider.name)
|
logger.info("Circuit breaker CLOSED for %s", provider.name)
|
||||||
|
|
||||||
def get_metrics(self) -> dict:
|
def get_metrics(self) -> dict:
|
||||||
"""Get metrics for all providers."""
|
"""Get metrics for all providers."""
|
||||||
return {
|
return {
|
||||||
@@ -814,16 +834,20 @@ class CascadeRouter:
|
|||||||
for p in self.providers
|
for p in self.providers
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_status(self) -> dict:
|
def get_status(self) -> dict:
|
||||||
"""Get current router status."""
|
"""Get current router status."""
|
||||||
healthy = sum(1 for p in self.providers if p.status == ProviderStatus.HEALTHY)
|
healthy = sum(1 for p in self.providers if p.status == ProviderStatus.HEALTHY)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"total_providers": len(self.providers),
|
"total_providers": len(self.providers),
|
||||||
"healthy_providers": healthy,
|
"healthy_providers": healthy,
|
||||||
"degraded_providers": sum(1 for p in self.providers if p.status == ProviderStatus.DEGRADED),
|
"degraded_providers": sum(
|
||||||
"unhealthy_providers": sum(1 for p in self.providers if p.status == ProviderStatus.UNHEALTHY),
|
1 for p in self.providers if p.status == ProviderStatus.DEGRADED
|
||||||
|
),
|
||||||
|
"unhealthy_providers": sum(
|
||||||
|
1 for p in self.providers if p.status == ProviderStatus.UNHEALTHY
|
||||||
|
),
|
||||||
"providers": [
|
"providers": [
|
||||||
{
|
{
|
||||||
"name": p.name,
|
"name": p.name,
|
||||||
@@ -835,7 +859,7 @@ class CascadeRouter:
|
|||||||
for p in self.providers
|
for p in self.providers
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
async def generate_with_image(
|
async def generate_with_image(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@@ -844,21 +868,23 @@ class CascadeRouter:
|
|||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Convenience method for vision requests.
|
"""Convenience method for vision requests.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt: Text prompt about the image
|
prompt: Text prompt about the image
|
||||||
image_path: Path to image file
|
image_path: Path to image file
|
||||||
model: Vision-capable model (auto-selected if not provided)
|
model: Vision-capable model (auto-selected if not provided)
|
||||||
temperature: Sampling temperature
|
temperature: Sampling temperature
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Response dict with content and metadata
|
Response dict with content and metadata
|
||||||
"""
|
"""
|
||||||
messages = [{
|
messages = [
|
||||||
"role": "user",
|
{
|
||||||
"content": prompt,
|
"role": "user",
|
||||||
"images": [image_path],
|
"content": prompt,
|
||||||
}]
|
"images": [image_path],
|
||||||
|
}
|
||||||
|
]
|
||||||
return await self.complete(
|
return await self.complete(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
model=model,
|
model=model,
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ logger = logging.getLogger(__name__)
|
|||||||
@dataclass
|
@dataclass
|
||||||
class WSEvent:
|
class WSEvent:
|
||||||
"""A WebSocket event to broadcast to connected clients."""
|
"""A WebSocket event to broadcast to connected clients."""
|
||||||
|
|
||||||
event: str
|
event: str
|
||||||
data: dict
|
data: dict
|
||||||
timestamp: str
|
timestamp: str
|
||||||
@@ -93,28 +94,42 @@ class WebSocketManager:
|
|||||||
await self.broadcast("agent_left", {"agent_id": agent_id, "name": name})
|
await self.broadcast("agent_left", {"agent_id": agent_id, "name": name})
|
||||||
|
|
||||||
async def broadcast_task_posted(self, task_id: str, description: str) -> None:
|
async def broadcast_task_posted(self, task_id: str, description: str) -> None:
|
||||||
await self.broadcast("task_posted", {
|
await self.broadcast(
|
||||||
"task_id": task_id, "description": description,
|
"task_posted",
|
||||||
})
|
{
|
||||||
|
"task_id": task_id,
|
||||||
|
"description": description,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
async def broadcast_bid_submitted(
|
async def broadcast_bid_submitted(self, task_id: str, agent_id: str, bid_sats: int) -> None:
|
||||||
self, task_id: str, agent_id: str, bid_sats: int
|
await self.broadcast(
|
||||||
) -> None:
|
"bid_submitted",
|
||||||
await self.broadcast("bid_submitted", {
|
{
|
||||||
"task_id": task_id, "agent_id": agent_id, "bid_sats": bid_sats,
|
"task_id": task_id,
|
||||||
})
|
"agent_id": agent_id,
|
||||||
|
"bid_sats": bid_sats,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
async def broadcast_task_assigned(self, task_id: str, agent_id: str) -> None:
|
async def broadcast_task_assigned(self, task_id: str, agent_id: str) -> None:
|
||||||
await self.broadcast("task_assigned", {
|
await self.broadcast(
|
||||||
"task_id": task_id, "agent_id": agent_id,
|
"task_assigned",
|
||||||
})
|
{
|
||||||
|
"task_id": task_id,
|
||||||
|
"agent_id": agent_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
async def broadcast_task_completed(
|
async def broadcast_task_completed(self, task_id: str, agent_id: str, result: str) -> None:
|
||||||
self, task_id: str, agent_id: str, result: str
|
await self.broadcast(
|
||||||
) -> None:
|
"task_completed",
|
||||||
await self.broadcast("task_completed", {
|
{
|
||||||
"task_id": task_id, "agent_id": agent_id, "result": result[:200],
|
"task_id": task_id,
|
||||||
})
|
"agent_id": agent_id,
|
||||||
|
"result": result[:200],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def connection_count(self) -> int:
|
def connection_count(self) -> int:
|
||||||
@@ -122,28 +137,28 @@ class WebSocketManager:
|
|||||||
|
|
||||||
async def broadcast_json(self, data: dict) -> int:
|
async def broadcast_json(self, data: dict) -> int:
|
||||||
"""Broadcast raw JSON data to all connected clients.
|
"""Broadcast raw JSON data to all connected clients.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: Dictionary to send as JSON
|
data: Dictionary to send as JSON
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Number of clients notified
|
Number of clients notified
|
||||||
"""
|
"""
|
||||||
message = json.dumps(data)
|
message = json.dumps(data)
|
||||||
disconnected = []
|
disconnected = []
|
||||||
count = 0
|
count = 0
|
||||||
|
|
||||||
for ws in self._connections:
|
for ws in self._connections:
|
||||||
try:
|
try:
|
||||||
await ws.send_text(message)
|
await ws.send_text(message)
|
||||||
count += 1
|
count += 1
|
||||||
except Exception:
|
except Exception:
|
||||||
disconnected.append(ws)
|
disconnected.append(ws)
|
||||||
|
|
||||||
# Clean up dead connections
|
# Clean up dead connections
|
||||||
for ws in disconnected:
|
for ws in disconnected:
|
||||||
self.disconnect(ws)
|
self.disconnect(ws)
|
||||||
|
|
||||||
return count
|
return count
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
class PlatformState(Enum):
|
class PlatformState(Enum):
|
||||||
"""Lifecycle state of a chat platform connection."""
|
"""Lifecycle state of a chat platform connection."""
|
||||||
|
|
||||||
DISCONNECTED = auto()
|
DISCONNECTED = auto()
|
||||||
CONNECTING = auto()
|
CONNECTING = auto()
|
||||||
CONNECTED = auto()
|
CONNECTED = auto()
|
||||||
@@ -30,13 +31,12 @@ class PlatformState(Enum):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ChatMessage:
|
class ChatMessage:
|
||||||
"""Vendor-agnostic representation of a chat message."""
|
"""Vendor-agnostic representation of a chat message."""
|
||||||
|
|
||||||
content: str
|
content: str
|
||||||
author: str
|
author: str
|
||||||
channel_id: str
|
channel_id: str
|
||||||
platform: str
|
platform: str
|
||||||
timestamp: str = field(
|
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
|
||||||
)
|
|
||||||
message_id: Optional[str] = None
|
message_id: Optional[str] = None
|
||||||
thread_id: Optional[str] = None
|
thread_id: Optional[str] = None
|
||||||
attachments: list[str] = field(default_factory=list)
|
attachments: list[str] = field(default_factory=list)
|
||||||
@@ -46,13 +46,12 @@ class ChatMessage:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ChatThread:
|
class ChatThread:
|
||||||
"""Vendor-agnostic representation of a conversation thread."""
|
"""Vendor-agnostic representation of a conversation thread."""
|
||||||
|
|
||||||
thread_id: str
|
thread_id: str
|
||||||
title: str
|
title: str
|
||||||
channel_id: str
|
channel_id: str
|
||||||
platform: str
|
platform: str
|
||||||
created_at: str = field(
|
created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
|
||||||
)
|
|
||||||
archived: bool = False
|
archived: bool = False
|
||||||
message_count: int = 0
|
message_count: int = 0
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
@@ -61,6 +60,7 @@ class ChatThread:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class InviteInfo:
|
class InviteInfo:
|
||||||
"""Parsed invite extracted from an image or text."""
|
"""Parsed invite extracted from an image or text."""
|
||||||
|
|
||||||
url: str
|
url: str
|
||||||
code: str
|
code: str
|
||||||
platform: str
|
platform: str
|
||||||
@@ -71,6 +71,7 @@ class InviteInfo:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class PlatformStatus:
|
class PlatformStatus:
|
||||||
"""Current status of a chat platform connection."""
|
"""Current status of a chat platform connection."""
|
||||||
|
|
||||||
platform: str
|
platform: str
|
||||||
state: PlatformState
|
state: PlatformState
|
||||||
token_set: bool
|
token_set: bool
|
||||||
|
|||||||
@@ -115,7 +115,9 @@ class InviteParser:
|
|||||||
"""Strategy 2: Use Ollama vision model for local OCR."""
|
"""Strategy 2: Use Ollama vision model for local OCR."""
|
||||||
try:
|
try:
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from config import settings
|
from config import settings
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.debug("httpx not available for Ollama vision.")
|
logger.debug("httpx not available for Ollama vision.")
|
||||||
|
|||||||
15
src/integrations/chat_bridge/vendors/discord.py
vendored
15
src/integrations/chat_bridge/vendors/discord.py
vendored
@@ -90,10 +90,7 @@ class DiscordVendor(ChatPlatform):
|
|||||||
try:
|
try:
|
||||||
import discord
|
import discord
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.error(
|
logger.error("discord.py is not installed. " 'Run: pip install ".[discord]"')
|
||||||
"discord.py is not installed. "
|
|
||||||
'Run: pip install ".[discord]"'
|
|
||||||
)
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -267,6 +264,7 @@ class DiscordVendor(ChatPlatform):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from config import settings
|
from config import settings
|
||||||
|
|
||||||
return settings.discord_token or None
|
return settings.discord_token or None
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
@@ -363,9 +361,7 @@ class DiscordVendor(ChatPlatform):
|
|||||||
# Show typing indicator while the agent processes
|
# Show typing indicator while the agent processes
|
||||||
async with target.typing():
|
async with target.typing():
|
||||||
run = await asyncio.wait_for(
|
run = await asyncio.wait_for(
|
||||||
asyncio.to_thread(
|
asyncio.to_thread(agent.run, content, stream=False, session_id=session_id),
|
||||||
agent.run, content, stream=False, session_id=session_id
|
|
||||||
),
|
|
||||||
timeout=300,
|
timeout=300,
|
||||||
)
|
)
|
||||||
response = run.content if hasattr(run, "content") else str(run)
|
response = run.content if hasattr(run, "content") else str(run)
|
||||||
@@ -374,7 +370,9 @@ class DiscordVendor(ChatPlatform):
|
|||||||
response = "Sorry, that took too long. Please try a simpler request."
|
response = "Sorry, that took too long. Please try a simpler request."
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("Discord: agent.run() failed: %s", exc)
|
logger.error("Discord: agent.run() failed: %s", exc)
|
||||||
response = "I'm having trouble reaching my language model right now. Please try again shortly."
|
response = (
|
||||||
|
"I'm having trouble reaching my language model right now. Please try again shortly."
|
||||||
|
)
|
||||||
|
|
||||||
# Strip hallucinated tool-call JSON and chain-of-thought narration
|
# Strip hallucinated tool-call JSON and chain-of-thought narration
|
||||||
from timmy.session import _clean_response
|
from timmy.session import _clean_response
|
||||||
@@ -408,6 +406,7 @@ class DiscordVendor(ChatPlatform):
|
|||||||
|
|
||||||
# Create a thread from this message
|
# Create a thread from this message
|
||||||
from config import settings
|
from config import settings
|
||||||
|
|
||||||
thread_name = f"{settings.agent_name} | {message.author.display_name}"
|
thread_name = f"{settings.agent_name} | {message.author.display_name}"
|
||||||
thread = await message.create_thread(
|
thread = await message.create_thread(
|
||||||
name=thread_name[:100],
|
name=thread_name[:100],
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
# ── Inbound: Paperclip → Timmy ──────────────────────────────────────────────
|
# ── Inbound: Paperclip → Timmy ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,8 @@ import logging
|
|||||||
from typing import Any, Callable, Coroutine, Dict, List, Optional, Protocol, runtime_checkable
|
from typing import Any, Callable, Coroutine, Dict, List, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from config import settings
|
from config import settings
|
||||||
from integrations.paperclip.bridge import PaperclipBridge, bridge as default_bridge
|
from integrations.paperclip.bridge import PaperclipBridge
|
||||||
|
from integrations.paperclip.bridge import bridge as default_bridge
|
||||||
from integrations.paperclip.models import PaperclipIssue
|
from integrations.paperclip.models import PaperclipIssue
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -30,9 +31,8 @@ logger = logging.getLogger(__name__)
|
|||||||
class Orchestrator(Protocol):
|
class Orchestrator(Protocol):
|
||||||
"""Anything with an ``execute_task`` matching Timmy's orchestrator."""
|
"""Anything with an ``execute_task`` matching Timmy's orchestrator."""
|
||||||
|
|
||||||
async def execute_task(
|
async def execute_task(self, task_id: str, description: str, context: dict) -> Any:
|
||||||
self, task_id: str, description: str, context: dict
|
...
|
||||||
) -> Any: ...
|
|
||||||
|
|
||||||
|
|
||||||
def _wrap_orchestrator(orch: Orchestrator) -> Callable:
|
def _wrap_orchestrator(orch: Orchestrator) -> Callable:
|
||||||
@@ -125,7 +125,9 @@ class TaskRunner:
|
|||||||
# Mark the issue as done
|
# Mark the issue as done
|
||||||
return await self.bridge.close_issue(issue.id, comment=None)
|
return await self.bridge.close_issue(issue.id, comment=None)
|
||||||
|
|
||||||
async def create_follow_up(self, original: PaperclipIssue, result: str) -> Optional[PaperclipIssue]:
|
async def create_follow_up(
|
||||||
|
self, original: PaperclipIssue, result: str
|
||||||
|
) -> Optional[PaperclipIssue]:
|
||||||
"""Create a recursive follow-up task for Timmy.
|
"""Create a recursive follow-up task for Timmy.
|
||||||
|
|
||||||
Timmy muses about task automation and writes a follow-up issue
|
Timmy muses about task automation and writes a follow-up issue
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ logger = logging.getLogger(__name__)
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ShortcutAction:
|
class ShortcutAction:
|
||||||
"""Describes a Siri Shortcut action for the setup guide."""
|
"""Describes a Siri Shortcut action for the setup guide."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
endpoint: str
|
endpoint: str
|
||||||
method: str
|
method: str
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ class TelegramBot:
|
|||||||
return from_file
|
return from_file
|
||||||
try:
|
try:
|
||||||
from config import settings
|
from config import settings
|
||||||
|
|
||||||
return settings.telegram_token or None
|
return settings.telegram_token or None
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
@@ -94,10 +95,7 @@ class TelegramBot:
|
|||||||
filters,
|
filters,
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.error(
|
logger.error("python-telegram-bot is not installed. " 'Run: pip install ".[telegram]"')
|
||||||
"python-telegram-bot is not installed. "
|
|
||||||
'Run: pip install ".[telegram]"'
|
|
||||||
)
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -149,6 +147,7 @@ class TelegramBot:
|
|||||||
user_text = update.message.text
|
user_text = update.message.text
|
||||||
try:
|
try:
|
||||||
from timmy.agent import create_timmy
|
from timmy.agent import create_timmy
|
||||||
|
|
||||||
agent = create_timmy()
|
agent = create_timmy()
|
||||||
run = await asyncio.to_thread(agent.run, user_text, stream=False)
|
run = await asyncio.to_thread(agent.run, user_text, stream=False)
|
||||||
response = run.content if hasattr(run, "content") else str(run)
|
response = run.content if hasattr(run, "content") else str(run)
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ Intents:
|
|||||||
- unknown: Unrecognized intent
|
- unknown: Unrecognized intent
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -35,47 +35,68 @@ class Intent:
|
|||||||
|
|
||||||
_PATTERNS: list[tuple[str, re.Pattern, float]] = [
|
_PATTERNS: list[tuple[str, re.Pattern, float]] = [
|
||||||
# Status queries
|
# Status queries
|
||||||
("status", re.compile(
|
(
|
||||||
r"\b(status|health|how are you|are you (running|online|alive)|check)\b",
|
"status",
|
||||||
re.IGNORECASE,
|
re.compile(
|
||||||
), 0.9),
|
r"\b(status|health|how are you|are you (running|online|alive)|check)\b",
|
||||||
|
re.IGNORECASE,
|
||||||
|
),
|
||||||
|
0.9,
|
||||||
|
),
|
||||||
# Swarm commands
|
# Swarm commands
|
||||||
("swarm", re.compile(
|
(
|
||||||
r"\b(swarm|spawn|agents?|sub-?agents?|workers?)\b",
|
"swarm",
|
||||||
re.IGNORECASE,
|
re.compile(
|
||||||
), 0.85),
|
r"\b(swarm|spawn|agents?|sub-?agents?|workers?)\b",
|
||||||
|
re.IGNORECASE,
|
||||||
|
),
|
||||||
|
0.85,
|
||||||
|
),
|
||||||
# Task commands
|
# Task commands
|
||||||
("task", re.compile(
|
(
|
||||||
r"\b(task|assign|create task|new task|post task|bid)\b",
|
"task",
|
||||||
re.IGNORECASE,
|
re.compile(
|
||||||
), 0.85),
|
r"\b(task|assign|create task|new task|post task|bid)\b",
|
||||||
|
re.IGNORECASE,
|
||||||
|
),
|
||||||
|
0.85,
|
||||||
|
),
|
||||||
# Help
|
# Help
|
||||||
("help", re.compile(
|
(
|
||||||
r"\b(help|commands?|what can you do|capabilities)\b",
|
"help",
|
||||||
re.IGNORECASE,
|
re.compile(
|
||||||
), 0.9),
|
r"\b(help|commands?|what can you do|capabilities)\b",
|
||||||
|
re.IGNORECASE,
|
||||||
|
),
|
||||||
|
0.9,
|
||||||
|
),
|
||||||
# Voice settings
|
# Voice settings
|
||||||
("voice", re.compile(
|
(
|
||||||
r"\b(voice|speak|volume|rate|speed|louder|quieter|faster|slower|mute|unmute)\b",
|
"voice",
|
||||||
re.IGNORECASE,
|
re.compile(
|
||||||
), 0.85),
|
r"\b(voice|speak|volume|rate|speed|louder|quieter|faster|slower|mute|unmute)\b",
|
||||||
|
re.IGNORECASE,
|
||||||
|
),
|
||||||
|
0.85,
|
||||||
|
),
|
||||||
# Code modification / self-modify
|
# Code modification / self-modify
|
||||||
("code", re.compile(
|
(
|
||||||
r"\b(modify|edit|change|update|fix|refactor|implement|patch)\s+(the\s+)?(code|file|function|class|module|source)\b"
|
"code",
|
||||||
r"|\bself[- ]?modify\b"
|
re.compile(
|
||||||
r"|\b(update|change|edit)\s+(your|the)\s+(code|source)\b",
|
r"\b(modify|edit|change|update|fix|refactor|implement|patch)\s+(the\s+)?(code|file|function|class|module|source)\b"
|
||||||
re.IGNORECASE,
|
r"|\bself[- ]?modify\b"
|
||||||
), 0.9),
|
r"|\b(update|change|edit)\s+(your|the)\s+(code|source)\b",
|
||||||
|
re.IGNORECASE,
|
||||||
|
),
|
||||||
|
0.9,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Keywords for entity extraction
|
# Keywords for entity extraction
|
||||||
_ENTITY_PATTERNS = {
|
_ENTITY_PATTERNS = {
|
||||||
"agent_name": re.compile(r"(?:spawn|start)\s+(?:agent\s+)?(\w+)|(?:agent)\s+(\w+)", re.IGNORECASE),
|
"agent_name": re.compile(
|
||||||
|
r"(?:spawn|start)\s+(?:agent\s+)?(\w+)|(?:agent)\s+(\w+)", re.IGNORECASE
|
||||||
|
),
|
||||||
"task_description": re.compile(r"(?:task|assign)[:;]?\s+(.+)", re.IGNORECASE),
|
"task_description": re.compile(r"(?:task|assign)[:;]?\s+(.+)", re.IGNORECASE),
|
||||||
"number": re.compile(r"\b(\d+)\b"),
|
"number": re.compile(r"\b(\d+)\b"),
|
||||||
"target_file": re.compile(r"(?:in|file|modify)\s+(?:the\s+)?([/\w._-]+\.py)", re.IGNORECASE),
|
"target_file": re.compile(r"(?:in|file|modify)\s+(?:the\s+)?([/\w._-]+\.py)", re.IGNORECASE),
|
||||||
|
|||||||
@@ -17,8 +17,8 @@ from dataclasses import dataclass, field
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from spark import memory as spark_memory
|
|
||||||
from spark import eidos as spark_eidos
|
from spark import eidos as spark_eidos
|
||||||
|
from spark import memory as spark_memory
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -29,10 +29,11 @@ _MIN_EVENTS = 3
|
|||||||
@dataclass
|
@dataclass
|
||||||
class Advisory:
|
class Advisory:
|
||||||
"""A single ranked recommendation."""
|
"""A single ranked recommendation."""
|
||||||
category: str # agent_performance, bid_optimization, etc.
|
|
||||||
priority: float # 0.0–1.0 (higher = more urgent)
|
category: str # agent_performance, bid_optimization, etc.
|
||||||
title: str # Short headline
|
priority: float # 0.0–1.0 (higher = more urgent)
|
||||||
detail: str # Longer explanation
|
title: str # Short headline
|
||||||
|
detail: str # Longer explanation
|
||||||
suggested_action: str # What to do about it
|
suggested_action: str # What to do about it
|
||||||
subject: Optional[str] = None # agent_id or None for system-level
|
subject: Optional[str] = None # agent_id or None for system-level
|
||||||
evidence_count: int = 0 # Number of supporting events
|
evidence_count: int = 0 # Number of supporting events
|
||||||
@@ -47,15 +48,17 @@ def generate_advisories() -> list[Advisory]:
|
|||||||
|
|
||||||
event_count = spark_memory.count_events()
|
event_count = spark_memory.count_events()
|
||||||
if event_count < _MIN_EVENTS:
|
if event_count < _MIN_EVENTS:
|
||||||
advisories.append(Advisory(
|
advisories.append(
|
||||||
category="system_health",
|
Advisory(
|
||||||
priority=0.3,
|
category="system_health",
|
||||||
title="Insufficient data",
|
priority=0.3,
|
||||||
detail=f"Only {event_count} events captured. "
|
title="Insufficient data",
|
||||||
f"Spark needs at least {_MIN_EVENTS} events to generate insights.",
|
detail=f"Only {event_count} events captured. "
|
||||||
suggested_action="Run more swarm tasks to build intelligence.",
|
f"Spark needs at least {_MIN_EVENTS} events to generate insights.",
|
||||||
evidence_count=event_count,
|
suggested_action="Run more swarm tasks to build intelligence.",
|
||||||
))
|
evidence_count=event_count,
|
||||||
|
)
|
||||||
|
)
|
||||||
return advisories
|
return advisories
|
||||||
|
|
||||||
advisories.extend(_check_failure_patterns())
|
advisories.extend(_check_failure_patterns())
|
||||||
@@ -82,18 +85,20 @@ def _check_failure_patterns() -> list[Advisory]:
|
|||||||
|
|
||||||
for aid, count in agent_failures.items():
|
for aid, count in agent_failures.items():
|
||||||
if count >= 2:
|
if count >= 2:
|
||||||
results.append(Advisory(
|
results.append(
|
||||||
category="failure_prevention",
|
Advisory(
|
||||||
priority=min(1.0, 0.5 + count * 0.15),
|
category="failure_prevention",
|
||||||
title=f"Agent {aid[:8]} has {count} failures",
|
priority=min(1.0, 0.5 + count * 0.15),
|
||||||
detail=f"Agent {aid[:8]}... has failed {count} recent tasks. "
|
title=f"Agent {aid[:8]} has {count} failures",
|
||||||
f"This pattern may indicate a capability mismatch or "
|
detail=f"Agent {aid[:8]}... has failed {count} recent tasks. "
|
||||||
f"configuration issue.",
|
f"This pattern may indicate a capability mismatch or "
|
||||||
suggested_action=f"Review task types assigned to {aid[:8]}... "
|
f"configuration issue.",
|
||||||
f"and consider adjusting routing preferences.",
|
suggested_action=f"Review task types assigned to {aid[:8]}... "
|
||||||
subject=aid,
|
f"and consider adjusting routing preferences.",
|
||||||
evidence_count=count,
|
subject=aid,
|
||||||
))
|
evidence_count=count,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -128,27 +133,31 @@ def _check_agent_performance() -> list[Advisory]:
|
|||||||
|
|
||||||
rate = wins / total
|
rate = wins / total
|
||||||
if rate >= 0.8 and total >= 3:
|
if rate >= 0.8 and total >= 3:
|
||||||
results.append(Advisory(
|
results.append(
|
||||||
category="agent_performance",
|
Advisory(
|
||||||
priority=0.6,
|
category="agent_performance",
|
||||||
title=f"Agent {aid[:8]} excels ({rate:.0%} success)",
|
priority=0.6,
|
||||||
detail=f"Agent {aid[:8]}... has completed {wins}/{total} tasks "
|
title=f"Agent {aid[:8]} excels ({rate:.0%} success)",
|
||||||
f"successfully. Consider routing more tasks to this agent.",
|
detail=f"Agent {aid[:8]}... has completed {wins}/{total} tasks "
|
||||||
suggested_action="Increase task routing weight for this agent.",
|
f"successfully. Consider routing more tasks to this agent.",
|
||||||
subject=aid,
|
suggested_action="Increase task routing weight for this agent.",
|
||||||
evidence_count=total,
|
subject=aid,
|
||||||
))
|
evidence_count=total,
|
||||||
|
)
|
||||||
|
)
|
||||||
elif rate <= 0.3 and total >= 3:
|
elif rate <= 0.3 and total >= 3:
|
||||||
results.append(Advisory(
|
results.append(
|
||||||
category="agent_performance",
|
Advisory(
|
||||||
priority=0.75,
|
category="agent_performance",
|
||||||
title=f"Agent {aid[:8]} struggling ({rate:.0%} success)",
|
priority=0.75,
|
||||||
detail=f"Agent {aid[:8]}... has only succeeded on {wins}/{total} tasks. "
|
title=f"Agent {aid[:8]} struggling ({rate:.0%} success)",
|
||||||
f"May need different task types or capability updates.",
|
detail=f"Agent {aid[:8]}... has only succeeded on {wins}/{total} tasks. "
|
||||||
suggested_action="Review this agent's capabilities and assigned task types.",
|
f"May need different task types or capability updates.",
|
||||||
subject=aid,
|
suggested_action="Review this agent's capabilities and assigned task types.",
|
||||||
evidence_count=total,
|
subject=aid,
|
||||||
))
|
evidence_count=total,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -181,27 +190,31 @@ def _check_bid_patterns() -> list[Advisory]:
|
|||||||
spread = max_bid - min_bid
|
spread = max_bid - min_bid
|
||||||
|
|
||||||
if spread > avg_bid * 1.5:
|
if spread > avg_bid * 1.5:
|
||||||
results.append(Advisory(
|
results.append(
|
||||||
category="bid_optimization",
|
Advisory(
|
||||||
priority=0.5,
|
category="bid_optimization",
|
||||||
title=f"Wide bid spread ({min_bid}–{max_bid} sats)",
|
priority=0.5,
|
||||||
detail=f"Bids range from {min_bid} to {max_bid} sats "
|
title=f"Wide bid spread ({min_bid}–{max_bid} sats)",
|
||||||
f"(avg {avg_bid:.0f}). Large spread may indicate "
|
detail=f"Bids range from {min_bid} to {max_bid} sats "
|
||||||
f"inefficient auction dynamics.",
|
f"(avg {avg_bid:.0f}). Large spread may indicate "
|
||||||
suggested_action="Review agent bid strategies for consistency.",
|
f"inefficient auction dynamics.",
|
||||||
evidence_count=len(bid_amounts),
|
suggested_action="Review agent bid strategies for consistency.",
|
||||||
))
|
evidence_count=len(bid_amounts),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if avg_bid > 70:
|
if avg_bid > 70:
|
||||||
results.append(Advisory(
|
results.append(
|
||||||
category="bid_optimization",
|
Advisory(
|
||||||
priority=0.45,
|
category="bid_optimization",
|
||||||
title=f"High average bid ({avg_bid:.0f} sats)",
|
priority=0.45,
|
||||||
detail=f"The swarm average bid is {avg_bid:.0f} sats across "
|
title=f"High average bid ({avg_bid:.0f} sats)",
|
||||||
f"{len(bid_amounts)} bids. This may be above optimal.",
|
detail=f"The swarm average bid is {avg_bid:.0f} sats across "
|
||||||
suggested_action="Consider adjusting base bid rates for persona agents.",
|
f"{len(bid_amounts)} bids. This may be above optimal.",
|
||||||
evidence_count=len(bid_amounts),
|
suggested_action="Consider adjusting base bid rates for persona agents.",
|
||||||
))
|
evidence_count=len(bid_amounts),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -216,27 +229,31 @@ def _check_prediction_accuracy() -> list[Advisory]:
|
|||||||
|
|
||||||
avg = stats["avg_accuracy"]
|
avg = stats["avg_accuracy"]
|
||||||
if avg < 0.4:
|
if avg < 0.4:
|
||||||
results.append(Advisory(
|
results.append(
|
||||||
category="system_health",
|
Advisory(
|
||||||
priority=0.65,
|
category="system_health",
|
||||||
title=f"Low prediction accuracy ({avg:.0%})",
|
priority=0.65,
|
||||||
detail=f"EIDOS predictions have averaged {avg:.0%} accuracy "
|
title=f"Low prediction accuracy ({avg:.0%})",
|
||||||
f"over {stats['evaluated']} evaluations. The learning "
|
detail=f"EIDOS predictions have averaged {avg:.0%} accuracy "
|
||||||
f"model needs more data or the swarm behaviour is changing.",
|
f"over {stats['evaluated']} evaluations. The learning "
|
||||||
suggested_action="Continue running tasks; accuracy should improve "
|
f"model needs more data or the swarm behaviour is changing.",
|
||||||
"as the model accumulates more training data.",
|
suggested_action="Continue running tasks; accuracy should improve "
|
||||||
evidence_count=stats["evaluated"],
|
"as the model accumulates more training data.",
|
||||||
))
|
evidence_count=stats["evaluated"],
|
||||||
|
)
|
||||||
|
)
|
||||||
elif avg >= 0.75:
|
elif avg >= 0.75:
|
||||||
results.append(Advisory(
|
results.append(
|
||||||
category="system_health",
|
Advisory(
|
||||||
priority=0.3,
|
category="system_health",
|
||||||
title=f"Strong prediction accuracy ({avg:.0%})",
|
priority=0.3,
|
||||||
detail=f"EIDOS predictions are performing well at {avg:.0%} "
|
title=f"Strong prediction accuracy ({avg:.0%})",
|
||||||
f"average accuracy over {stats['evaluated']} evaluations.",
|
detail=f"EIDOS predictions are performing well at {avg:.0%} "
|
||||||
suggested_action="No action needed. Spark intelligence is learning effectively.",
|
f"average accuracy over {stats['evaluated']} evaluations.",
|
||||||
evidence_count=stats["evaluated"],
|
suggested_action="No action needed. Spark intelligence is learning effectively.",
|
||||||
))
|
evidence_count=stats["evaluated"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -247,14 +264,16 @@ def _check_system_activity() -> list[Advisory]:
|
|||||||
recent = spark_memory.get_events(limit=5)
|
recent = spark_memory.get_events(limit=5)
|
||||||
|
|
||||||
if not recent:
|
if not recent:
|
||||||
results.append(Advisory(
|
results.append(
|
||||||
category="system_health",
|
Advisory(
|
||||||
priority=0.4,
|
category="system_health",
|
||||||
title="No swarm activity detected",
|
priority=0.4,
|
||||||
detail="Spark has not captured any events. "
|
title="No swarm activity detected",
|
||||||
"The swarm may be idle or Spark event capture is not active.",
|
detail="Spark has not captured any events. "
|
||||||
suggested_action="Post a task to the swarm to activate the pipeline.",
|
"The swarm may be idle or Spark event capture is not active.",
|
||||||
))
|
suggested_action="Post a task to the swarm to activate the pipeline.",
|
||||||
|
)
|
||||||
|
)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
# Check event type distribution
|
# Check event type distribution
|
||||||
@@ -265,14 +284,16 @@ def _check_system_activity() -> list[Advisory]:
|
|||||||
|
|
||||||
if "task_completed" not in type_counts and "task_failed" not in type_counts:
|
if "task_completed" not in type_counts and "task_failed" not in type_counts:
|
||||||
if type_counts.get("task_posted", 0) > 3:
|
if type_counts.get("task_posted", 0) > 3:
|
||||||
results.append(Advisory(
|
results.append(
|
||||||
category="system_health",
|
Advisory(
|
||||||
priority=0.6,
|
category="system_health",
|
||||||
title="Tasks posted but none completing",
|
priority=0.6,
|
||||||
detail=f"{type_counts.get('task_posted', 0)} tasks posted "
|
title="Tasks posted but none completing",
|
||||||
f"but no completions or failures recorded.",
|
detail=f"{type_counts.get('task_posted', 0)} tasks posted "
|
||||||
suggested_action="Check agent availability and auction configuration.",
|
f"but no completions or failures recorded.",
|
||||||
evidence_count=type_counts.get("task_posted", 0),
|
suggested_action="Check agent availability and auction configuration.",
|
||||||
))
|
evidence_count=type_counts.get("task_posted", 0),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|||||||
@@ -29,12 +29,13 @@ DB_PATH = Path("data/spark.db")
|
|||||||
@dataclass
|
@dataclass
|
||||||
class Prediction:
|
class Prediction:
|
||||||
"""A prediction made by the EIDOS loop."""
|
"""A prediction made by the EIDOS loop."""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
task_id: str
|
task_id: str
|
||||||
prediction_type: str # outcome, best_agent, bid_range
|
prediction_type: str # outcome, best_agent, bid_range
|
||||||
predicted_value: str # JSON-encoded prediction
|
predicted_value: str # JSON-encoded prediction
|
||||||
actual_value: Optional[str] # JSON-encoded actual (filled on evaluation)
|
actual_value: Optional[str] # JSON-encoded actual (filled on evaluation)
|
||||||
accuracy: Optional[float] # 0.0–1.0 (filled on evaluation)
|
accuracy: Optional[float] # 0.0–1.0 (filled on evaluation)
|
||||||
created_at: str
|
created_at: str
|
||||||
evaluated_at: Optional[str]
|
evaluated_at: Optional[str]
|
||||||
|
|
||||||
@@ -57,18 +58,15 @@ def _get_conn() -> sqlite3.Connection:
|
|||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
conn.execute(
|
conn.execute("CREATE INDEX IF NOT EXISTS idx_pred_task ON spark_predictions(task_id)")
|
||||||
"CREATE INDEX IF NOT EXISTS idx_pred_task ON spark_predictions(task_id)"
|
conn.execute("CREATE INDEX IF NOT EXISTS idx_pred_type ON spark_predictions(prediction_type)")
|
||||||
)
|
|
||||||
conn.execute(
|
|
||||||
"CREATE INDEX IF NOT EXISTS idx_pred_type ON spark_predictions(prediction_type)"
|
|
||||||
)
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
|
|
||||||
# ── Prediction phase ────────────────────────────────────────────────────────
|
# ── Prediction phase ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def predict_task_outcome(
|
def predict_task_outcome(
|
||||||
task_id: str,
|
task_id: str,
|
||||||
task_description: str,
|
task_description: str,
|
||||||
@@ -104,12 +102,8 @@ def predict_task_outcome(
|
|||||||
|
|
||||||
if best_agent:
|
if best_agent:
|
||||||
prediction["likely_winner"] = best_agent
|
prediction["likely_winner"] = best_agent
|
||||||
prediction["success_probability"] = round(
|
prediction["success_probability"] = round(min(1.0, 0.5 + best_rate * 0.4), 2)
|
||||||
min(1.0, 0.5 + best_rate * 0.4), 2
|
prediction["reasoning"] = f"agent {best_agent[:8]} has {best_rate:.0%} success rate"
|
||||||
)
|
|
||||||
prediction["reasoning"] = (
|
|
||||||
f"agent {best_agent[:8]} has {best_rate:.0%} success rate"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Adjust bid range from history
|
# Adjust bid range from history
|
||||||
all_bids = []
|
all_bids = []
|
||||||
@@ -144,6 +138,7 @@ def predict_task_outcome(
|
|||||||
|
|
||||||
# ── Evaluation phase ────────────────────────────────────────────────────────
|
# ── Evaluation phase ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def evaluate_prediction(
|
def evaluate_prediction(
|
||||||
task_id: str,
|
task_id: str,
|
||||||
actual_winner: Optional[str],
|
actual_winner: Optional[str],
|
||||||
@@ -242,6 +237,7 @@ def _compute_accuracy(predicted: dict, actual: dict) -> float:
|
|||||||
|
|
||||||
# ── Query helpers ──────────────────────────────────────────────────────────
|
# ── Query helpers ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def get_predictions(
|
def get_predictions(
|
||||||
task_id: Optional[str] = None,
|
task_id: Optional[str] = None,
|
||||||
evaluated_only: bool = False,
|
evaluated_only: bool = False,
|
||||||
|
|||||||
@@ -76,7 +76,10 @@ class SparkEngine:
|
|||||||
return event_id
|
return event_id
|
||||||
|
|
||||||
def on_bid_submitted(
|
def on_bid_submitted(
|
||||||
self, task_id: str, agent_id: str, bid_sats: int,
|
self,
|
||||||
|
task_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
bid_sats: int,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""Capture a bid event."""
|
"""Capture a bid event."""
|
||||||
if not self._enabled:
|
if not self._enabled:
|
||||||
@@ -90,12 +93,13 @@ class SparkEngine:
|
|||||||
data=json.dumps({"bid_sats": bid_sats}),
|
data=json.dumps({"bid_sats": bid_sats}),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("Spark: captured bid %s→%s (%d sats)",
|
logger.debug("Spark: captured bid %s→%s (%d sats)", agent_id[:8], task_id[:8], bid_sats)
|
||||||
agent_id[:8], task_id[:8], bid_sats)
|
|
||||||
return event_id
|
return event_id
|
||||||
|
|
||||||
def on_task_assigned(
|
def on_task_assigned(
|
||||||
self, task_id: str, agent_id: str,
|
self,
|
||||||
|
task_id: str,
|
||||||
|
agent_id: str,
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""Capture a task-assigned event."""
|
"""Capture a task-assigned event."""
|
||||||
if not self._enabled:
|
if not self._enabled:
|
||||||
@@ -108,8 +112,7 @@ class SparkEngine:
|
|||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("Spark: captured assignment %s→%s",
|
logger.debug("Spark: captured assignment %s→%s", task_id[:8], agent_id[:8])
|
||||||
task_id[:8], agent_id[:8])
|
|
||||||
return event_id
|
return event_id
|
||||||
|
|
||||||
def on_task_completed(
|
def on_task_completed(
|
||||||
@@ -128,10 +131,12 @@ class SparkEngine:
|
|||||||
description=f"Task completed by {agent_id[:8]}",
|
description=f"Task completed by {agent_id[:8]}",
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
data=json.dumps({
|
data=json.dumps(
|
||||||
"result_length": len(result),
|
{
|
||||||
"winning_bid": winning_bid,
|
"result_length": len(result),
|
||||||
}),
|
"winning_bid": winning_bid,
|
||||||
|
}
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Evaluate EIDOS prediction
|
# Evaluate EIDOS prediction
|
||||||
@@ -154,8 +159,7 @@ class SparkEngine:
|
|||||||
# Consolidate memory if enough events for this agent
|
# Consolidate memory if enough events for this agent
|
||||||
self._maybe_consolidate(agent_id)
|
self._maybe_consolidate(agent_id)
|
||||||
|
|
||||||
logger.debug("Spark: captured completion %s by %s",
|
logger.debug("Spark: captured completion %s by %s", task_id[:8], agent_id[:8])
|
||||||
task_id[:8], agent_id[:8])
|
|
||||||
return event_id
|
return event_id
|
||||||
|
|
||||||
def on_task_failed(
|
def on_task_failed(
|
||||||
@@ -186,8 +190,7 @@ class SparkEngine:
|
|||||||
# Failures always worth consolidating
|
# Failures always worth consolidating
|
||||||
self._maybe_consolidate(agent_id)
|
self._maybe_consolidate(agent_id)
|
||||||
|
|
||||||
logger.debug("Spark: captured failure %s by %s",
|
logger.debug("Spark: captured failure %s by %s", task_id[:8], agent_id[:8])
|
||||||
task_id[:8], agent_id[:8])
|
|
||||||
return event_id
|
return event_id
|
||||||
|
|
||||||
def on_agent_joined(self, agent_id: str, name: str) -> Optional[str]:
|
def on_agent_joined(self, agent_id: str, name: str) -> Optional[str]:
|
||||||
@@ -288,7 +291,7 @@ class SparkEngine:
|
|||||||
memory_type="pattern",
|
memory_type="pattern",
|
||||||
subject=agent_id,
|
subject=agent_id,
|
||||||
content=f"Agent {agent_id[:8]} has a strong track record: "
|
content=f"Agent {agent_id[:8]} has a strong track record: "
|
||||||
f"{len(completions)}/{total} tasks completed successfully.",
|
f"{len(completions)}/{total} tasks completed successfully.",
|
||||||
confidence=min(0.95, 0.6 + total * 0.05),
|
confidence=min(0.95, 0.6 + total * 0.05),
|
||||||
source_events=total,
|
source_events=total,
|
||||||
)
|
)
|
||||||
@@ -297,7 +300,7 @@ class SparkEngine:
|
|||||||
memory_type="anomaly",
|
memory_type="anomaly",
|
||||||
subject=agent_id,
|
subject=agent_id,
|
||||||
content=f"Agent {agent_id[:8]} is struggling: only "
|
content=f"Agent {agent_id[:8]} is struggling: only "
|
||||||
f"{len(completions)}/{total} tasks completed.",
|
f"{len(completions)}/{total} tasks completed.",
|
||||||
confidence=min(0.95, 0.6 + total * 0.05),
|
confidence=min(0.95, 0.6 + total * 0.05),
|
||||||
source_events=total,
|
source_events=total,
|
||||||
)
|
)
|
||||||
@@ -347,6 +350,7 @@ class SparkEngine:
|
|||||||
def _create_engine() -> SparkEngine:
|
def _create_engine() -> SparkEngine:
|
||||||
try:
|
try:
|
||||||
from config import settings
|
from config import settings
|
||||||
|
|
||||||
return SparkEngine(enabled=settings.spark_enabled)
|
return SparkEngine(enabled=settings.spark_enabled)
|
||||||
except Exception:
|
except Exception:
|
||||||
return SparkEngine(enabled=True)
|
return SparkEngine(enabled=True)
|
||||||
|
|||||||
@@ -28,25 +28,27 @@ IMPORTANCE_HIGH = 0.8
|
|||||||
@dataclass
|
@dataclass
|
||||||
class SparkEvent:
|
class SparkEvent:
|
||||||
"""A single captured swarm event."""
|
"""A single captured swarm event."""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
event_type: str # task_posted, bid, assignment, completion, failure
|
event_type: str # task_posted, bid, assignment, completion, failure
|
||||||
agent_id: Optional[str]
|
agent_id: Optional[str]
|
||||||
task_id: Optional[str]
|
task_id: Optional[str]
|
||||||
description: str
|
description: str
|
||||||
data: str # JSON payload
|
data: str # JSON payload
|
||||||
importance: float # 0.0–1.0
|
importance: float # 0.0–1.0
|
||||||
created_at: str
|
created_at: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SparkMemory:
|
class SparkMemory:
|
||||||
"""A consolidated memory distilled from event patterns."""
|
"""A consolidated memory distilled from event patterns."""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
memory_type: str # pattern, insight, anomaly
|
memory_type: str # pattern, insight, anomaly
|
||||||
subject: str # agent_id or "system"
|
subject: str # agent_id or "system"
|
||||||
content: str # Human-readable insight
|
content: str # Human-readable insight
|
||||||
confidence: float # 0.0–1.0
|
confidence: float # 0.0–1.0
|
||||||
source_events: int # How many events contributed
|
source_events: int # How many events contributed
|
||||||
created_at: str
|
created_at: str
|
||||||
expires_at: Optional[str]
|
expires_at: Optional[str]
|
||||||
|
|
||||||
@@ -83,24 +85,17 @@ def _get_conn() -> sqlite3.Connection:
|
|||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
conn.execute(
|
conn.execute("CREATE INDEX IF NOT EXISTS idx_events_type ON spark_events(event_type)")
|
||||||
"CREATE INDEX IF NOT EXISTS idx_events_type ON spark_events(event_type)"
|
conn.execute("CREATE INDEX IF NOT EXISTS idx_events_agent ON spark_events(agent_id)")
|
||||||
)
|
conn.execute("CREATE INDEX IF NOT EXISTS idx_events_task ON spark_events(task_id)")
|
||||||
conn.execute(
|
conn.execute("CREATE INDEX IF NOT EXISTS idx_memories_subject ON spark_memories(subject)")
|
||||||
"CREATE INDEX IF NOT EXISTS idx_events_agent ON spark_events(agent_id)"
|
|
||||||
)
|
|
||||||
conn.execute(
|
|
||||||
"CREATE INDEX IF NOT EXISTS idx_events_task ON spark_events(task_id)"
|
|
||||||
)
|
|
||||||
conn.execute(
|
|
||||||
"CREATE INDEX IF NOT EXISTS idx_memories_subject ON spark_memories(subject)"
|
|
||||||
)
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
|
|
||||||
# ── Importance scoring ──────────────────────────────────────────────────────
|
# ── Importance scoring ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def score_importance(event_type: str, data: dict) -> float:
|
def score_importance(event_type: str, data: dict) -> float:
|
||||||
"""Compute importance score for an event (0.0–1.0).
|
"""Compute importance score for an event (0.0–1.0).
|
||||||
|
|
||||||
@@ -132,6 +127,7 @@ def score_importance(event_type: str, data: dict) -> float:
|
|||||||
|
|
||||||
# ── Event recording ─────────────────────────────────────────────────────────
|
# ── Event recording ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def record_event(
|
def record_event(
|
||||||
event_type: str,
|
event_type: str,
|
||||||
description: str,
|
description: str,
|
||||||
@@ -142,6 +138,7 @@ def record_event(
|
|||||||
) -> str:
|
) -> str:
|
||||||
"""Record a swarm event. Returns the event id."""
|
"""Record a swarm event. Returns the event id."""
|
||||||
import json
|
import json
|
||||||
|
|
||||||
event_id = str(uuid.uuid4())
|
event_id = str(uuid.uuid4())
|
||||||
now = datetime.now(timezone.utc).isoformat()
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
@@ -224,6 +221,7 @@ def count_events(event_type: Optional[str] = None) -> int:
|
|||||||
|
|
||||||
# ── Memory consolidation ───────────────────────────────────────────────────
|
# ── Memory consolidation ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def store_memory(
|
def store_memory(
|
||||||
memory_type: str,
|
memory_type: str,
|
||||||
subject: str,
|
subject: str,
|
||||||
|
|||||||
@@ -73,7 +73,8 @@ def _ensure_db() -> sqlite3.Connection:
|
|||||||
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||||
conn = sqlite3.connect(str(DB_PATH))
|
conn = sqlite3.connect(str(DB_PATH))
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
conn.execute("""
|
conn.execute(
|
||||||
|
"""
|
||||||
CREATE TABLE IF NOT EXISTS events (
|
CREATE TABLE IF NOT EXISTS events (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
event_type TEXT NOT NULL,
|
event_type TEXT NOT NULL,
|
||||||
@@ -83,7 +84,8 @@ def _ensure_db() -> sqlite3.Connection:
|
|||||||
data TEXT DEFAULT '{}',
|
data TEXT DEFAULT '{}',
|
||||||
timestamp TEXT NOT NULL
|
timestamp TEXT NOT NULL
|
||||||
)
|
)
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
@@ -119,8 +121,15 @@ def log_event(
|
|||||||
db.execute(
|
db.execute(
|
||||||
"INSERT INTO events (id, event_type, source, task_id, agent_id, data, timestamp) "
|
"INSERT INTO events (id, event_type, source, task_id, agent_id, data, timestamp) "
|
||||||
"VALUES (?, ?, ?, ?, ?, ?, ?)",
|
"VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||||||
(entry.id, event_type.value, source, task_id, agent_id,
|
(
|
||||||
json.dumps(data or {}), entry.timestamp),
|
entry.id,
|
||||||
|
event_type.value,
|
||||||
|
source,
|
||||||
|
task_id,
|
||||||
|
agent_id,
|
||||||
|
json.dumps(data or {}),
|
||||||
|
entry.timestamp,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
db.commit()
|
db.commit()
|
||||||
finally:
|
finally:
|
||||||
@@ -131,6 +140,7 @@ def log_event(
|
|||||||
# Broadcast to WebSocket clients (non-blocking)
|
# Broadcast to WebSocket clients (non-blocking)
|
||||||
try:
|
try:
|
||||||
from infrastructure.events.broadcaster import event_broadcaster
|
from infrastructure.events.broadcaster import event_broadcaster
|
||||||
|
|
||||||
event_broadcaster.broadcast_sync(entry)
|
event_broadcaster.broadcast_sync(entry)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
@@ -157,13 +167,15 @@ def get_task_events(task_id: str, limit: int = 50) -> list[EventLogEntry]:
|
|||||||
et = EventType(r["event_type"])
|
et = EventType(r["event_type"])
|
||||||
except ValueError:
|
except ValueError:
|
||||||
et = EventType.SYSTEM_INFO
|
et = EventType.SYSTEM_INFO
|
||||||
entries.append(EventLogEntry(
|
entries.append(
|
||||||
id=r["id"],
|
EventLogEntry(
|
||||||
event_type=et,
|
id=r["id"],
|
||||||
source=r["source"],
|
event_type=et,
|
||||||
timestamp=r["timestamp"],
|
source=r["source"],
|
||||||
data=json.loads(r["data"]) if r["data"] else {},
|
timestamp=r["timestamp"],
|
||||||
task_id=r["task_id"],
|
data=json.loads(r["data"]) if r["data"] else {},
|
||||||
agent_id=r["agent_id"],
|
task_id=r["task_id"],
|
||||||
))
|
agent_id=r["agent_id"],
|
||||||
|
)
|
||||||
|
)
|
||||||
return entries
|
return entries
|
||||||
|
|||||||
@@ -29,7 +29,8 @@ def _ensure_db() -> sqlite3.Connection:
|
|||||||
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||||
conn = sqlite3.connect(str(DB_PATH))
|
conn = sqlite3.connect(str(DB_PATH))
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
conn.execute("""
|
conn.execute(
|
||||||
|
"""
|
||||||
CREATE TABLE IF NOT EXISTS tasks (
|
CREATE TABLE IF NOT EXISTS tasks (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
title TEXT NOT NULL,
|
title TEXT NOT NULL,
|
||||||
@@ -42,7 +43,8 @@ def _ensure_db() -> sqlite3.Connection:
|
|||||||
created_at TEXT DEFAULT (datetime('now')),
|
created_at TEXT DEFAULT (datetime('now')),
|
||||||
completed_at TEXT
|
completed_at TEXT
|
||||||
)
|
)
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
@@ -103,9 +105,7 @@ def get_task_summary_for_briefing() -> dict:
|
|||||||
"""Return a summary of task counts by status for the morning briefing."""
|
"""Return a summary of task counts by status for the morning briefing."""
|
||||||
db = _ensure_db()
|
db = _ensure_db()
|
||||||
try:
|
try:
|
||||||
rows = db.execute(
|
rows = db.execute("SELECT status, COUNT(*) as cnt FROM tasks GROUP BY status").fetchall()
|
||||||
"SELECT status, COUNT(*) as cnt FROM tasks GROUP BY status"
|
|
||||||
).fetchall()
|
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|||||||
@@ -69,16 +69,16 @@ def _check_model_available(model_name: str) -> bool:
|
|||||||
|
|
||||||
def _pull_model(model_name: str) -> bool:
|
def _pull_model(model_name: str) -> bool:
|
||||||
"""Attempt to pull a model from Ollama.
|
"""Attempt to pull a model from Ollama.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if successful or model already exists
|
True if successful or model already exists
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
import urllib.request
|
|
||||||
import json
|
import json
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
logger.info("Pulling model: %s", model_name)
|
logger.info("Pulling model: %s", model_name)
|
||||||
|
|
||||||
url = settings.ollama_url.replace("localhost", "127.0.0.1")
|
url = settings.ollama_url.replace("localhost", "127.0.0.1")
|
||||||
req = urllib.request.Request(
|
req = urllib.request.Request(
|
||||||
f"{url}/api/pull",
|
f"{url}/api/pull",
|
||||||
@@ -86,7 +86,7 @@ def _pull_model(model_name: str) -> bool:
|
|||||||
headers={"Content-Type": "application/json"},
|
headers={"Content-Type": "application/json"},
|
||||||
data=json.dumps({"name": model_name, "stream": False}).encode(),
|
data=json.dumps({"name": model_name, "stream": False}).encode(),
|
||||||
)
|
)
|
||||||
|
|
||||||
with urllib.request.urlopen(req, timeout=300) as response:
|
with urllib.request.urlopen(req, timeout=300) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
logger.info("Successfully pulled model: %s", model_name)
|
logger.info("Successfully pulled model: %s", model_name)
|
||||||
@@ -94,7 +94,7 @@ def _pull_model(model_name: str) -> bool:
|
|||||||
else:
|
else:
|
||||||
logger.error("Failed to pull %s: HTTP %s", model_name, response.status)
|
logger.error("Failed to pull %s: HTTP %s", model_name, response.status)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("Error pulling model %s: %s", model_name, exc)
|
logger.error("Error pulling model %s: %s", model_name, exc)
|
||||||
return False
|
return False
|
||||||
@@ -106,53 +106,44 @@ def _resolve_model_with_fallback(
|
|||||||
auto_pull: bool = True,
|
auto_pull: bool = True,
|
||||||
) -> tuple[str, bool]:
|
) -> tuple[str, bool]:
|
||||||
"""Resolve model with automatic pulling and fallback.
|
"""Resolve model with automatic pulling and fallback.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
requested_model: Preferred model to use
|
requested_model: Preferred model to use
|
||||||
require_vision: Whether the model needs vision capabilities
|
require_vision: Whether the model needs vision capabilities
|
||||||
auto_pull: Whether to attempt pulling missing models
|
auto_pull: Whether to attempt pulling missing models
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (model_name, is_fallback)
|
Tuple of (model_name, is_fallback)
|
||||||
"""
|
"""
|
||||||
model = requested_model or settings.ollama_model
|
model = requested_model or settings.ollama_model
|
||||||
|
|
||||||
# Check if requested model is available
|
# Check if requested model is available
|
||||||
if _check_model_available(model):
|
if _check_model_available(model):
|
||||||
logger.debug("Using available model: %s", model)
|
logger.debug("Using available model: %s", model)
|
||||||
return model, False
|
return model, False
|
||||||
|
|
||||||
# Try to pull the requested model
|
# Try to pull the requested model
|
||||||
if auto_pull:
|
if auto_pull:
|
||||||
logger.info("Model %s not available locally, attempting to pull...", model)
|
logger.info("Model %s not available locally, attempting to pull...", model)
|
||||||
if _pull_model(model):
|
if _pull_model(model):
|
||||||
return model, False
|
return model, False
|
||||||
logger.warning("Failed to pull %s, checking fallbacks...", model)
|
logger.warning("Failed to pull %s, checking fallbacks...", model)
|
||||||
|
|
||||||
# Use appropriate fallback chain
|
# Use appropriate fallback chain
|
||||||
fallback_chain = VISION_MODEL_FALLBACKS if require_vision else DEFAULT_MODEL_FALLBACKS
|
fallback_chain = VISION_MODEL_FALLBACKS if require_vision else DEFAULT_MODEL_FALLBACKS
|
||||||
|
|
||||||
for fallback_model in fallback_chain:
|
for fallback_model in fallback_chain:
|
||||||
if _check_model_available(fallback_model):
|
if _check_model_available(fallback_model):
|
||||||
logger.warning(
|
logger.warning("Using fallback model %s (requested: %s)", fallback_model, model)
|
||||||
"Using fallback model %s (requested: %s)",
|
|
||||||
fallback_model, model
|
|
||||||
)
|
|
||||||
return fallback_model, True
|
return fallback_model, True
|
||||||
|
|
||||||
# Try to pull the fallback
|
# Try to pull the fallback
|
||||||
if auto_pull and _pull_model(fallback_model):
|
if auto_pull and _pull_model(fallback_model):
|
||||||
logger.info(
|
logger.info("Pulled and using fallback model %s (requested: %s)", fallback_model, model)
|
||||||
"Pulled and using fallback model %s (requested: %s)",
|
|
||||||
fallback_model, model
|
|
||||||
)
|
|
||||||
return fallback_model, True
|
return fallback_model, True
|
||||||
|
|
||||||
# Absolute last resort - return the requested model and hope for the best
|
# Absolute last resort - return the requested model and hope for the best
|
||||||
logger.error(
|
logger.error("No models available in fallback chain. Requested: %s", model)
|
||||||
"No models available in fallback chain. Requested: %s",
|
|
||||||
model
|
|
||||||
)
|
|
||||||
return model, False
|
return model, False
|
||||||
|
|
||||||
|
|
||||||
@@ -190,6 +181,7 @@ def _resolve_backend(requested: str | None) -> str:
|
|||||||
|
|
||||||
# "auto" path — lazy import to keep startup fast and tests clean.
|
# "auto" path — lazy import to keep startup fast and tests clean.
|
||||||
from timmy.backends import airllm_available, claude_available, grok_available, is_apple_silicon
|
from timmy.backends import airllm_available, claude_available, grok_available, is_apple_silicon
|
||||||
|
|
||||||
if is_apple_silicon() and airllm_available():
|
if is_apple_silicon() and airllm_available():
|
||||||
return "airllm"
|
return "airllm"
|
||||||
return "ollama"
|
return "ollama"
|
||||||
@@ -215,14 +207,17 @@ def create_timmy(
|
|||||||
|
|
||||||
if resolved == "claude":
|
if resolved == "claude":
|
||||||
from timmy.backends import ClaudeBackend
|
from timmy.backends import ClaudeBackend
|
||||||
|
|
||||||
return ClaudeBackend()
|
return ClaudeBackend()
|
||||||
|
|
||||||
if resolved == "grok":
|
if resolved == "grok":
|
||||||
from timmy.backends import GrokBackend
|
from timmy.backends import GrokBackend
|
||||||
|
|
||||||
return GrokBackend()
|
return GrokBackend()
|
||||||
|
|
||||||
if resolved == "airllm":
|
if resolved == "airllm":
|
||||||
from timmy.backends import TimmyAirLLMAgent
|
from timmy.backends import TimmyAirLLMAgent
|
||||||
|
|
||||||
return TimmyAirLLMAgent(model_size=size)
|
return TimmyAirLLMAgent(model_size=size)
|
||||||
|
|
||||||
# Default: Ollama via Agno.
|
# Default: Ollama via Agno.
|
||||||
@@ -236,16 +231,16 @@ def create_timmy(
|
|||||||
# If Ollama is completely unreachable, fall back to Claude if available
|
# If Ollama is completely unreachable, fall back to Claude if available
|
||||||
if not _check_model_available(model_name):
|
if not _check_model_available(model_name):
|
||||||
from timmy.backends import claude_available
|
from timmy.backends import claude_available
|
||||||
|
|
||||||
if claude_available():
|
if claude_available():
|
||||||
logger.warning(
|
logger.warning("Ollama unreachable — falling back to Claude backend")
|
||||||
"Ollama unreachable — falling back to Claude backend"
|
|
||||||
)
|
|
||||||
from timmy.backends import ClaudeBackend
|
from timmy.backends import ClaudeBackend
|
||||||
|
|
||||||
return ClaudeBackend()
|
return ClaudeBackend()
|
||||||
|
|
||||||
if is_fallback:
|
if is_fallback:
|
||||||
logger.info("Using fallback model %s (requested was unavailable)", model_name)
|
logger.info("Using fallback model %s (requested was unavailable)", model_name)
|
||||||
|
|
||||||
use_tools = _model_supports_tools(model_name)
|
use_tools = _model_supports_tools(model_name)
|
||||||
|
|
||||||
# Conditionally include tools — small models get none
|
# Conditionally include tools — small models get none
|
||||||
@@ -259,6 +254,7 @@ def create_timmy(
|
|||||||
# Try to load memory context
|
# Try to load memory context
|
||||||
try:
|
try:
|
||||||
from timmy.memory_system import memory_system
|
from timmy.memory_system import memory_system
|
||||||
|
|
||||||
memory_context = memory_system.get_system_context()
|
memory_context = memory_system.get_system_context()
|
||||||
if memory_context:
|
if memory_context:
|
||||||
# Truncate if too long — smaller budget for small models
|
# Truncate if too long — smaller budget for small models
|
||||||
@@ -290,32 +286,32 @@ def create_timmy(
|
|||||||
|
|
||||||
class TimmyWithMemory:
|
class TimmyWithMemory:
|
||||||
"""Agent wrapper with explicit three-tier memory management."""
|
"""Agent wrapper with explicit three-tier memory management."""
|
||||||
|
|
||||||
def __init__(self, db_file: str = "timmy.db") -> None:
|
def __init__(self, db_file: str = "timmy.db") -> None:
|
||||||
from timmy.memory_system import memory_system
|
from timmy.memory_system import memory_system
|
||||||
|
|
||||||
self.agent = create_timmy(db_file=db_file)
|
self.agent = create_timmy(db_file=db_file)
|
||||||
self.memory = memory_system
|
self.memory = memory_system
|
||||||
self.session_active = True
|
self.session_active = True
|
||||||
|
|
||||||
# Store initial context for reference
|
# Store initial context for reference
|
||||||
self.initial_context = self.memory.get_system_context()
|
self.initial_context = self.memory.get_system_context()
|
||||||
|
|
||||||
def chat(self, message: str) -> str:
|
def chat(self, message: str) -> str:
|
||||||
"""Simple chat interface that tracks in memory."""
|
"""Simple chat interface that tracks in memory."""
|
||||||
# Check for user facts to extract
|
# Check for user facts to extract
|
||||||
self._extract_and_store_facts(message)
|
self._extract_and_store_facts(message)
|
||||||
|
|
||||||
# Run agent
|
# Run agent
|
||||||
result = self.agent.run(message, stream=False)
|
result = self.agent.run(message, stream=False)
|
||||||
response_text = result.content if hasattr(result, "content") else str(result)
|
response_text = result.content if hasattr(result, "content") else str(result)
|
||||||
|
|
||||||
return response_text
|
return response_text
|
||||||
|
|
||||||
def _extract_and_store_facts(self, message: str) -> None:
|
def _extract_and_store_facts(self, message: str) -> None:
|
||||||
"""Extract user facts from message and store in memory."""
|
"""Extract user facts from message and store in memory."""
|
||||||
message_lower = message.lower()
|
message_lower = message.lower()
|
||||||
|
|
||||||
# Extract name
|
# Extract name
|
||||||
name_patterns = [
|
name_patterns = [
|
||||||
("my name is ", 11),
|
("my name is ", 11),
|
||||||
@@ -323,7 +319,7 @@ class TimmyWithMemory:
|
|||||||
("i am ", 5),
|
("i am ", 5),
|
||||||
("call me ", 8),
|
("call me ", 8),
|
||||||
]
|
]
|
||||||
|
|
||||||
for pattern, offset in name_patterns:
|
for pattern, offset in name_patterns:
|
||||||
if pattern in message_lower:
|
if pattern in message_lower:
|
||||||
idx = message_lower.find(pattern) + offset
|
idx = message_lower.find(pattern) + offset
|
||||||
@@ -332,7 +328,7 @@ class TimmyWithMemory:
|
|||||||
self.memory.update_user_fact("Name", name)
|
self.memory.update_user_fact("Name", name)
|
||||||
self.memory.record_decision(f"Learned user's name: {name}")
|
self.memory.record_decision(f"Learned user's name: {name}")
|
||||||
break
|
break
|
||||||
|
|
||||||
# Extract preferences
|
# Extract preferences
|
||||||
pref_patterns = [
|
pref_patterns = [
|
||||||
("i like ", "Likes"),
|
("i like ", "Likes"),
|
||||||
@@ -341,7 +337,7 @@ class TimmyWithMemory:
|
|||||||
("i don't like ", "Dislikes"),
|
("i don't like ", "Dislikes"),
|
||||||
("i hate ", "Dislikes"),
|
("i hate ", "Dislikes"),
|
||||||
]
|
]
|
||||||
|
|
||||||
for pattern, category in pref_patterns:
|
for pattern, category in pref_patterns:
|
||||||
if pattern in message_lower:
|
if pattern in message_lower:
|
||||||
idx = message_lower.find(pattern) + len(pattern)
|
idx = message_lower.find(pattern) + len(pattern)
|
||||||
@@ -349,16 +345,16 @@ class TimmyWithMemory:
|
|||||||
if pref and len(pref) > 3:
|
if pref and len(pref) > 3:
|
||||||
self.memory.record_open_item(f"User {category.lower()}: {pref}")
|
self.memory.record_open_item(f"User {category.lower()}: {pref}")
|
||||||
break
|
break
|
||||||
|
|
||||||
def end_session(self, summary: str = "Session completed") -> None:
|
def end_session(self, summary: str = "Session completed") -> None:
|
||||||
"""End session and write handoff."""
|
"""End session and write handoff."""
|
||||||
if self.session_active:
|
if self.session_active:
|
||||||
self.memory.end_session(summary)
|
self.memory.end_session(summary)
|
||||||
self.session_active = False
|
self.session_active = False
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
self.end_session()
|
self.end_session()
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -16,38 +16,41 @@ Architecture:
|
|||||||
All methods return effects that can be logged, audited, and replayed.
|
All methods return effects that can be logged, audited, and replayed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
import uuid
|
|
||||||
|
|
||||||
|
|
||||||
class PerceptionType(Enum):
|
class PerceptionType(Enum):
|
||||||
"""Types of sensory input an agent can receive."""
|
"""Types of sensory input an agent can receive."""
|
||||||
TEXT = auto() # Natural language
|
|
||||||
IMAGE = auto() # Visual input
|
TEXT = auto() # Natural language
|
||||||
AUDIO = auto() # Sound/speech
|
IMAGE = auto() # Visual input
|
||||||
SENSOR = auto() # Temperature, distance, etc.
|
AUDIO = auto() # Sound/speech
|
||||||
MOTION = auto() # Accelerometer, gyroscope
|
SENSOR = auto() # Temperature, distance, etc.
|
||||||
NETWORK = auto() # API calls, messages
|
MOTION = auto() # Accelerometer, gyroscope
|
||||||
INTERNAL = auto() # Self-monitoring (battery, temp)
|
NETWORK = auto() # API calls, messages
|
||||||
|
INTERNAL = auto() # Self-monitoring (battery, temp)
|
||||||
|
|
||||||
|
|
||||||
class ActionType(Enum):
|
class ActionType(Enum):
|
||||||
"""Types of actions an agent can perform."""
|
"""Types of actions an agent can perform."""
|
||||||
TEXT = auto() # Generate text response
|
|
||||||
SPEAK = auto() # Text-to-speech
|
TEXT = auto() # Generate text response
|
||||||
MOVE = auto() # Physical movement
|
SPEAK = auto() # Text-to-speech
|
||||||
GRIP = auto() # Manipulate objects
|
MOVE = auto() # Physical movement
|
||||||
CALL = auto() # API/network call
|
GRIP = auto() # Manipulate objects
|
||||||
EMIT = auto() # Signal/light/sound
|
CALL = auto() # API/network call
|
||||||
SLEEP = auto() # Power management
|
EMIT = auto() # Signal/light/sound
|
||||||
|
SLEEP = auto() # Power management
|
||||||
|
|
||||||
|
|
||||||
class AgentCapability(Enum):
|
class AgentCapability(Enum):
|
||||||
"""High-level capabilities a TimAgent may possess."""
|
"""High-level capabilities a TimAgent may possess."""
|
||||||
|
|
||||||
REASONING = "reasoning"
|
REASONING = "reasoning"
|
||||||
CODING = "coding"
|
CODING = "coding"
|
||||||
WRITING = "writing"
|
WRITING = "writing"
|
||||||
@@ -63,15 +66,16 @@ class AgentCapability(Enum):
|
|||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class AgentIdentity:
|
class AgentIdentity:
|
||||||
"""Immutable identity for an agent instance.
|
"""Immutable identity for an agent instance.
|
||||||
|
|
||||||
This persists across sessions and substrates. If Timmy moves
|
This persists across sessions and substrates. If Timmy moves
|
||||||
from cloud to robot, the identity follows.
|
from cloud to robot, the identity follows.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
version: str
|
version: str
|
||||||
created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate(cls, name: str, version: str = "1.0.0") -> "AgentIdentity":
|
def generate(cls, name: str, version: str = "1.0.0") -> "AgentIdentity":
|
||||||
"""Generate a new unique identity."""
|
"""Generate a new unique identity."""
|
||||||
@@ -85,16 +89,17 @@ class AgentIdentity:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class Perception:
|
class Perception:
|
||||||
"""A sensory input to the agent.
|
"""A sensory input to the agent.
|
||||||
|
|
||||||
Substrate-agnostic representation. A camera image and a
|
Substrate-agnostic representation. A camera image and a
|
||||||
LiDAR point cloud are both Perception instances.
|
LiDAR point cloud are both Perception instances.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: PerceptionType
|
type: PerceptionType
|
||||||
data: Any # Content depends on type
|
data: Any # Content depends on type
|
||||||
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
source: str = "unknown" # e.g., "camera_1", "microphone", "user_input"
|
source: str = "unknown" # e.g., "camera_1", "microphone", "user_input"
|
||||||
metadata: dict = field(default_factory=dict)
|
metadata: dict = field(default_factory=dict)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def text(cls, content: str, source: str = "user") -> "Perception":
|
def text(cls, content: str, source: str = "user") -> "Perception":
|
||||||
"""Factory for text perception."""
|
"""Factory for text perception."""
|
||||||
@@ -103,7 +108,7 @@ class Perception:
|
|||||||
data=content,
|
data=content,
|
||||||
source=source,
|
source=source,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sensor(cls, kind: str, value: float, unit: str = "") -> "Perception":
|
def sensor(cls, kind: str, value: float, unit: str = "") -> "Perception":
|
||||||
"""Factory for sensor readings."""
|
"""Factory for sensor readings."""
|
||||||
@@ -117,16 +122,17 @@ class Perception:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class Action:
|
class Action:
|
||||||
"""An action the agent intends to perform.
|
"""An action the agent intends to perform.
|
||||||
|
|
||||||
Actions are effects — they describe what should happen,
|
Actions are effects — they describe what should happen,
|
||||||
not how. The substrate implements the "how."
|
not how. The substrate implements the "how."
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: ActionType
|
type: ActionType
|
||||||
payload: Any # Action-specific data
|
payload: Any # Action-specific data
|
||||||
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
confidence: float = 1.0 # 0-1, agent's certainty
|
confidence: float = 1.0 # 0-1, agent's certainty
|
||||||
deadline: Optional[str] = None # When action must complete
|
deadline: Optional[str] = None # When action must complete
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def respond(cls, text: str, confidence: float = 1.0) -> "Action":
|
def respond(cls, text: str, confidence: float = 1.0) -> "Action":
|
||||||
"""Factory for text response action."""
|
"""Factory for text response action."""
|
||||||
@@ -135,7 +141,7 @@ class Action:
|
|||||||
payload=text,
|
payload=text,
|
||||||
confidence=confidence,
|
confidence=confidence,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def move(cls, vector: tuple[float, float, float], speed: float = 1.0) -> "Action":
|
def move(cls, vector: tuple[float, float, float], speed: float = 1.0) -> "Action":
|
||||||
"""Factory for movement action (x, y, z meters)."""
|
"""Factory for movement action (x, y, z meters)."""
|
||||||
@@ -148,10 +154,11 @@ class Action:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class Memory:
|
class Memory:
|
||||||
"""A stored experience or fact.
|
"""A stored experience or fact.
|
||||||
|
|
||||||
Memories are substrate-agnostic. A conversation history
|
Memories are substrate-agnostic. A conversation history
|
||||||
and a video recording are both Memory instances.
|
and a video recording are both Memory instances.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
content: Any
|
content: Any
|
||||||
created_at: str
|
created_at: str
|
||||||
@@ -159,7 +166,7 @@ class Memory:
|
|||||||
last_accessed: Optional[str] = None
|
last_accessed: Optional[str] = None
|
||||||
importance: float = 0.5 # 0-1, for pruning decisions
|
importance: float = 0.5 # 0-1, for pruning decisions
|
||||||
tags: list[str] = field(default_factory=list)
|
tags: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
def touch(self) -> None:
|
def touch(self) -> None:
|
||||||
"""Mark memory as accessed."""
|
"""Mark memory as accessed."""
|
||||||
self.access_count += 1
|
self.access_count += 1
|
||||||
@@ -169,6 +176,7 @@ class Memory:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class Communication:
|
class Communication:
|
||||||
"""A message to/from another agent or human."""
|
"""A message to/from another agent or human."""
|
||||||
|
|
||||||
sender: str
|
sender: str
|
||||||
recipient: str
|
recipient: str
|
||||||
content: Any
|
content: Any
|
||||||
@@ -179,132 +187,132 @@ class Communication:
|
|||||||
|
|
||||||
class TimAgent(ABC):
|
class TimAgent(ABC):
|
||||||
"""Abstract base class for all Timmy agent implementations.
|
"""Abstract base class for all Timmy agent implementations.
|
||||||
|
|
||||||
This is the substrate-agnostic interface. Implementations:
|
This is the substrate-agnostic interface. Implementations:
|
||||||
- OllamaAgent: LLM-based reasoning (today)
|
- OllamaAgent: LLM-based reasoning (today)
|
||||||
- RobotAgent: Physical embodiment (future)
|
- RobotAgent: Physical embodiment (future)
|
||||||
- SimulationAgent: Virtual environment (future)
|
- SimulationAgent: Virtual environment (future)
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
agent = OllamaAgent(identity) # Today's implementation
|
agent = OllamaAgent(identity) # Today's implementation
|
||||||
|
|
||||||
perception = Perception.text("Hello Timmy")
|
perception = Perception.text("Hello Timmy")
|
||||||
memory = agent.perceive(perception)
|
memory = agent.perceive(perception)
|
||||||
|
|
||||||
action = agent.reason("How should I respond?", [memory])
|
action = agent.reason("How should I respond?", [memory])
|
||||||
result = agent.act(action)
|
result = agent.act(action)
|
||||||
|
|
||||||
agent.remember(memory) # Store for future
|
agent.remember(memory) # Store for future
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, identity: AgentIdentity) -> None:
    """Bind the agent to *identity* and start with empty bookkeeping.

    Args:
        identity: Identity object later exposed through the ``identity``
            property.
    """
    # No capabilities and no state until a subclass or caller sets them.
    self._state: dict[str, Any] = {}
    self._capabilities: set[AgentCapability] = set()
    self._identity = identity
|
||||||
|
|
||||||
@property
def identity(self) -> AgentIdentity:
    """Read-only accessor for the identity given to the constructor."""
    return self._identity
|
||||||
|
|
||||||
@property
def capabilities(self) -> set[AgentCapability]:
    """Snapshot of the capabilities this agent supports.

    Returns:
        A new set on every access, so mutating the result does not touch
        the agent's own capability set.
    """
    return set(self._capabilities)
|
||||||
|
|
||||||
def has_capability(self, capability: AgentCapability) -> bool:
    """Tell whether *capability* is among this agent's supported capabilities."""
    supported = self._capabilities
    return capability in supported
|
||||||
|
|
||||||
@abstractmethod
def perceive(self, perception: Perception) -> Memory:
    """Turn a piece of sensory input into a memory.

    Every interaction with an agent starts here: a text message, a camera
    frame and a temperature reading are all handed to ``perceive()``.

    Args:
        perception: The sensory input to process.

    Returns:
        Memory: The stored representation of that input.
    """
|
||||||
|
|
||||||
@abstractmethod
def reason(self, query: str, context: list[Memory]) -> Action:
    """Think about *query* and choose what to do next.

    Implementations use whatever reasoning their substrate offers (LLM,
    neural net, rules) to arrive at a decision.

    Args:
        query: The question or situation to reason about.
        context: Memories relevant to the query.

    Returns:
        Action: The action the agent has decided on.
    """
|
||||||
|
|
||||||
@abstractmethod
def act(self, action: Action) -> Any:
    """Carry out *action* in the agent's substrate.

    The abstract action is made concrete here, for example:
    - TEXT → generate an LLM response
    - MOVE → send motor commands
    - SPEAK → call a TTS engine

    Args:
        action: The action to execute.

    Returns:
        The substrate-specific result of executing the action.
    """
|
||||||
|
|
||||||
@abstractmethod
def remember(self, memory: Memory) -> None:
    """Keep *memory* so it can be retrieved later.

    Where it is stored is up to the substrate:
    - Cloud: SQLite, vector DB
    - Robot: local flash storage
    - Hybrid: synced with conflict resolution

    Args:
        memory: The experience to store.
    """
|
||||||
|
|
||||||
@abstractmethod
def recall(self, query: str, limit: int = 5) -> list[Memory]:
    """Look up memories relevant to *query*.

    Args:
        query: What to search for.
        limit: Upper bound on the number of memories returned.

    Returns:
        Relevant memories, ordered from most to least relevant.
    """
|
||||||
|
|
||||||
@abstractmethod
def communicate(self, message: Communication) -> bool:
    """Exchange a message with another agent.

    Args:
        message: The message to send.

    Returns:
        True if the communication succeeded.
    """
|
||||||
|
|
||||||
def get_state(self) -> dict[str, Any]:
|
def get_state(self) -> dict[str, Any]:
|
||||||
"""Get current agent state for monitoring/debugging."""
|
"""Get current agent state for monitoring/debugging."""
|
||||||
return {
|
return {
|
||||||
@@ -312,7 +320,7 @@ class TimAgent(ABC):
|
|||||||
"capabilities": list(self._capabilities),
|
"capabilities": list(self._capabilities),
|
||||||
"state": self._state.copy(),
|
"state": self._state.copy(),
|
||||||
}
|
}
|
||||||
|
|
||||||
def shutdown(self) -> None:
|
def shutdown(self) -> None:
|
||||||
"""Graceful shutdown. Persist state, close connections."""
|
"""Graceful shutdown. Persist state, close connections."""
|
||||||
# Override in subclass for cleanup
|
# Override in subclass for cleanup
|
||||||
@@ -321,7 +329,7 @@ class TimAgent(ABC):
|
|||||||
|
|
||||||
class AgentEffect:
|
class AgentEffect:
|
||||||
"""Log entry for agent actions — for audit and replay.
|
"""Log entry for agent actions — for audit and replay.
|
||||||
|
|
||||||
The complete history of an agent's life can be captured
|
The complete history of an agent's life can be captured
|
||||||
as a sequence of AgentEffects. This enables:
|
as a sequence of AgentEffects. This enables:
|
||||||
- Debugging: What did the agent see and do?
|
- Debugging: What did the agent see and do?
|
||||||
@@ -329,40 +337,46 @@ class AgentEffect:
|
|||||||
- Replay: Reconstruct agent state from log
|
- Replay: Reconstruct agent state from log
|
||||||
- Training: Learn from agent experiences
|
- Training: Learn from agent experiences
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, log_path: Optional[str] = None) -> None:
    """Start with an empty in-memory effect log.

    Args:
        log_path: Optional path, stored on the instance as given.
            NOTE(review): none of the logging methods next to this
            constructor write to it; confirm whether persistence happens
            elsewhere.
    """
    self._log_path = log_path
    self._effects: list[dict] = []
|
||||||
|
|
||||||
def log_perceive(self, perception: Perception, memory_id: str) -> None:
    """Append a ``"perceive"`` entry for *perception* to the effect log.

    Args:
        perception: The perceived input; its type name and source are
            recorded.
        memory_id: Id of the memory created from the perception.
    """
    entry = {
        "type": "perceive",
        "perception_type": perception.type.name,
        "source": perception.source,
        "memory_id": memory_id,
        # Timezone-aware UTC time, ISO 8601 formatted.
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    self._effects.append(entry)
|
||||||
|
|
||||||
def log_reason(self, query: str, action_type: ActionType) -> None:
    """Append a ``"reason"`` entry to the effect log.

    Args:
        query: The query that was reasoned about.
        action_type: Type of the action chosen; its ``name`` is recorded.
    """
    entry = {
        "type": "reason",
        "query": query,
        "action_type": action_type.name,
        # Timezone-aware UTC time, ISO 8601 formatted.
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    self._effects.append(entry)
|
||||||
|
|
||||||
def log_act(self, action: Action, result: Any) -> None:
    """Append an ``"act"`` entry to the effect log.

    Only the class name of *result* is recorded, not the result itself.

    Args:
        action: The executed action; its type name and confidence are
            recorded.
        result: Whatever executing the action returned.
    """
    entry = {
        "type": "act",
        "action_type": action.type.name,
        "confidence": action.confidence,
        "result_type": type(result).__name__,
        # Timezone-aware UTC time, ISO 8601 formatted.
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    self._effects.append(entry)
|
||||||
|
|
||||||
def export(self) -> list[dict]:
    """Export effect log for analysis.

    Returns:
        A new list with a copy of every logged entry, oldest first.
        Entries are copied one level deep, so a caller editing the
        exported data cannot alter the log kept on this instance.
    """
    # A plain list copy would still share the entry dicts with the caller.
    return [dict(effect) for effect in self._effects]
|
||||||
|
|||||||
@@ -7,10 +7,10 @@ between the old codebase and the new embodiment-ready architecture.
|
|||||||
Usage:
|
Usage:
|
||||||
from timmy.agent_core import AgentIdentity, Perception
|
from timmy.agent_core import AgentIdentity, Perception
|
||||||
from timmy.agent_core.ollama_adapter import OllamaAgent
|
from timmy.agent_core.ollama_adapter import OllamaAgent
|
||||||
|
|
||||||
identity = AgentIdentity.generate("Timmy")
|
identity = AgentIdentity.generate("Timmy")
|
||||||
agent = OllamaAgent(identity)
|
agent = OllamaAgent(identity)
|
||||||
|
|
||||||
perception = Perception.text("Hello!")
|
perception = Perception.text("Hello!")
|
||||||
memory = agent.perceive(perception)
|
memory = agent.perceive(perception)
|
||||||
action = agent.reason("How should I respond?", [memory])
|
action = agent.reason("How should I respond?", [memory])
|
||||||
@@ -19,27 +19,27 @@ Usage:
|
|||||||
|
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from timmy.agent import _resolve_model_with_fallback, create_timmy
|
||||||
from timmy.agent_core.interface import (
|
from timmy.agent_core.interface import (
|
||||||
AgentCapability,
|
|
||||||
AgentIdentity,
|
|
||||||
Perception,
|
|
||||||
PerceptionType,
|
|
||||||
Action,
|
Action,
|
||||||
ActionType,
|
ActionType,
|
||||||
Memory,
|
AgentCapability,
|
||||||
Communication,
|
|
||||||
TimAgent,
|
|
||||||
AgentEffect,
|
AgentEffect,
|
||||||
|
AgentIdentity,
|
||||||
|
Communication,
|
||||||
|
Memory,
|
||||||
|
Perception,
|
||||||
|
PerceptionType,
|
||||||
|
TimAgent,
|
||||||
)
|
)
|
||||||
from timmy.agent import create_timmy, _resolve_model_with_fallback
|
|
||||||
|
|
||||||
|
|
||||||
class OllamaAgent(TimAgent):
|
class OllamaAgent(TimAgent):
|
||||||
"""TimAgent implementation using local Ollama LLM.
|
"""TimAgent implementation using local Ollama LLM.
|
||||||
|
|
||||||
This is the production agent for Timmy Time v2. It uses
|
This is the production agent for Timmy Time v2. It uses
|
||||||
Ollama for reasoning and SQLite for memory persistence.
|
Ollama for reasoning and SQLite for memory persistence.
|
||||||
|
|
||||||
Capabilities:
|
Capabilities:
|
||||||
- REASONING: LLM-based inference
|
- REASONING: LLM-based inference
|
||||||
- CODING: Code generation and analysis
|
- CODING: Code generation and analysis
|
||||||
@@ -47,7 +47,7 @@ class OllamaAgent(TimAgent):
|
|||||||
- ANALYSIS: Data processing and insights
|
- ANALYSIS: Data processing and insights
|
||||||
- COMMUNICATION: Multi-agent messaging
|
- COMMUNICATION: Multi-agent messaging
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
identity: AgentIdentity,
|
identity: AgentIdentity,
|
||||||
@@ -56,7 +56,7 @@ class OllamaAgent(TimAgent):
|
|||||||
require_vision: bool = False,
|
require_vision: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize Ollama-based agent.
|
"""Initialize Ollama-based agent.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
identity: Agent identity (persistent across sessions)
|
identity: Agent identity (persistent across sessions)
|
||||||
model: Ollama model to use (auto-resolves with fallback)
|
model: Ollama model to use (auto-resolves with fallback)
|
||||||
@@ -64,23 +64,24 @@ class OllamaAgent(TimAgent):
|
|||||||
require_vision: Whether to select a vision-capable model
|
require_vision: Whether to select a vision-capable model
|
||||||
"""
|
"""
|
||||||
super().__init__(identity)
|
super().__init__(identity)
|
||||||
|
|
||||||
# Resolve model with automatic pulling and fallback
|
# Resolve model with automatic pulling and fallback
|
||||||
resolved_model, is_fallback = _resolve_model_with_fallback(
|
resolved_model, is_fallback = _resolve_model_with_fallback(
|
||||||
requested_model=model,
|
requested_model=model,
|
||||||
require_vision=require_vision,
|
require_vision=require_vision,
|
||||||
auto_pull=True,
|
auto_pull=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_fallback:
|
if is_fallback:
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging.getLogger(__name__).info(
|
logging.getLogger(__name__).info(
|
||||||
"OllamaAdapter using fallback model %s", resolved_model
|
"OllamaAdapter using fallback model %s", resolved_model
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize underlying Ollama agent
|
# Initialize underlying Ollama agent
|
||||||
self._timmy = create_timmy(model=resolved_model)
|
self._timmy = create_timmy(model=resolved_model)
|
||||||
|
|
||||||
# Set capabilities based on what Ollama can do
|
# Set capabilities based on what Ollama can do
|
||||||
self._capabilities = {
|
self._capabilities = {
|
||||||
AgentCapability.REASONING,
|
AgentCapability.REASONING,
|
||||||
@@ -89,17 +90,17 @@ class OllamaAgent(TimAgent):
|
|||||||
AgentCapability.ANALYSIS,
|
AgentCapability.ANALYSIS,
|
||||||
AgentCapability.COMMUNICATION,
|
AgentCapability.COMMUNICATION,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Effect logging for audit/replay
|
# Effect logging for audit/replay
|
||||||
self._effect_log = AgentEffect(effect_log) if effect_log else None
|
self._effect_log = AgentEffect(effect_log) if effect_log else None
|
||||||
|
|
||||||
# Simple in-memory working memory (short term)
|
# Simple in-memory working memory (short term)
|
||||||
self._working_memory: list[Memory] = []
|
self._working_memory: list[Memory] = []
|
||||||
self._max_working_memory = 10
|
self._max_working_memory = 10
|
||||||
|
|
||||||
def perceive(self, perception: Perception) -> Memory:
|
def perceive(self, perception: Perception) -> Memory:
|
||||||
"""Process perception and store in memory.
|
"""Process perception and store in memory.
|
||||||
|
|
||||||
For text perceptions, we might do light preprocessing
|
For text perceptions, we might do light preprocessing
|
||||||
(summarization, keyword extraction) before storage.
|
(summarization, keyword extraction) before storage.
|
||||||
"""
|
"""
|
||||||
@@ -114,28 +115,28 @@ class OllamaAgent(TimAgent):
|
|||||||
created_at=perception.timestamp,
|
created_at=perception.timestamp,
|
||||||
tags=self._extract_tags(perception),
|
tags=self._extract_tags(perception),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add to working memory
|
# Add to working memory
|
||||||
self._working_memory.append(memory)
|
self._working_memory.append(memory)
|
||||||
if len(self._working_memory) > self._max_working_memory:
|
if len(self._working_memory) > self._max_working_memory:
|
||||||
self._working_memory.pop(0) # FIFO eviction
|
self._working_memory.pop(0) # FIFO eviction
|
||||||
|
|
||||||
# Log effect
|
# Log effect
|
||||||
if self._effect_log:
|
if self._effect_log:
|
||||||
self._effect_log.log_perceive(perception, memory.id)
|
self._effect_log.log_perceive(perception, memory.id)
|
||||||
|
|
||||||
return memory
|
return memory
|
||||||
|
|
||||||
def reason(self, query: str, context: list[Memory]) -> Action:
|
def reason(self, query: str, context: list[Memory]) -> Action:
|
||||||
"""Use LLM to reason and decide on action.
|
"""Use LLM to reason and decide on action.
|
||||||
|
|
||||||
This is where the Ollama agent does its work. We construct
|
This is where the Ollama agent does its work. We construct
|
||||||
a prompt from the query and context, then interpret the
|
a prompt from the query and context, then interpret the
|
||||||
response as an action.
|
response as an action.
|
||||||
"""
|
"""
|
||||||
# Build context string from memories
|
# Build context string from memories
|
||||||
context_str = self._format_context(context)
|
context_str = self._format_context(context)
|
||||||
|
|
||||||
# Construct prompt
|
# Construct prompt
|
||||||
prompt = f"""You are {self._identity.name}, an AI assistant.
|
prompt = f"""You are {self._identity.name}, an AI assistant.
|
||||||
|
|
||||||
@@ -145,30 +146,30 @@ Context from previous interactions:
|
|||||||
Current query: {query}
|
Current query: {query}
|
||||||
|
|
||||||
Respond naturally and helpfully."""
|
Respond naturally and helpfully."""
|
||||||
|
|
||||||
# Run LLM inference
|
# Run LLM inference
|
||||||
result = self._timmy.run(prompt, stream=False)
|
result = self._timmy.run(prompt, stream=False)
|
||||||
response_text = result.content if hasattr(result, "content") else str(result)
|
response_text = result.content if hasattr(result, "content") else str(result)
|
||||||
|
|
||||||
# Create text response action
|
# Create text response action
|
||||||
action = Action.respond(response_text, confidence=0.9)
|
action = Action.respond(response_text, confidence=0.9)
|
||||||
|
|
||||||
# Log effect
|
# Log effect
|
||||||
if self._effect_log:
|
if self._effect_log:
|
||||||
self._effect_log.log_reason(query, action.type)
|
self._effect_log.log_reason(query, action.type)
|
||||||
|
|
||||||
return action
|
return action
|
||||||
|
|
||||||
def act(self, action: Action) -> Any:
|
def act(self, action: Action) -> Any:
|
||||||
"""Execute action in the Ollama substrate.
|
"""Execute action in the Ollama substrate.
|
||||||
|
|
||||||
For text actions, the "execution" is just returning the
|
For text actions, the "execution" is just returning the
|
||||||
text (already generated during reasoning). For future
|
text (already generated during reasoning). For future
|
||||||
action types (MOVE, SPEAK), this would trigger the
|
action types (MOVE, SPEAK), this would trigger the
|
||||||
appropriate Ollama tool calls.
|
appropriate Ollama tool calls.
|
||||||
"""
|
"""
|
||||||
result = None
|
result = None
|
||||||
|
|
||||||
if action.type == ActionType.TEXT:
|
if action.type == ActionType.TEXT:
|
||||||
result = action.payload
|
result = action.payload
|
||||||
elif action.type == ActionType.SPEAK:
|
elif action.type == ActionType.SPEAK:
|
||||||
@@ -179,13 +180,13 @@ Respond naturally and helpfully."""
|
|||||||
result = {"status": "not_implemented", "payload": action.payload}
|
result = {"status": "not_implemented", "payload": action.payload}
|
||||||
else:
|
else:
|
||||||
result = {"error": f"Action type {action.type} not supported by OllamaAgent"}
|
result = {"error": f"Action type {action.type} not supported by OllamaAgent"}
|
||||||
|
|
||||||
# Log effect
|
# Log effect
|
||||||
if self._effect_log:
|
if self._effect_log:
|
||||||
self._effect_log.log_act(action, result)
|
self._effect_log.log_act(action, result)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def remember(self, memory: Memory) -> None:
|
def remember(self, memory: Memory) -> None:
|
||||||
"""Store memory in working memory.
|
"""Store memory in working memory.
|
||||||
|
|
||||||
@@ -200,48 +201,48 @@ Respond naturally and helpfully."""
|
|||||||
# Evict oldest if over capacity
|
# Evict oldest if over capacity
|
||||||
if len(self._working_memory) > self._max_working_memory:
|
if len(self._working_memory) > self._max_working_memory:
|
||||||
self._working_memory.pop(0)
|
self._working_memory.pop(0)
|
||||||
|
|
||||||
def recall(self, query: str, limit: int = 5) -> list[Memory]:
    """Retrieve the working-memory entries most relevant to *query*.

    Simple keyword matching for now. Future: vector similarity.

    A memory's score is the number of distinct lower-cased,
    whitespace-separated words it shares with the query (compared against
    ``str(memory.content)``), plus the memory's ``importance``.

    Args:
        query: Free-text search string.
        limit: Maximum memories to return. Zero or a negative value yields
            an empty list.

    Returns:
        Up to ``limit`` memories, highest score first. Memories with equal
        scores keep their working-memory order.
    """
    if limit <= 0:
        # A negative slice bound would drop results from the end of the
        # ranking instead of limiting it.
        return []

    # The query is the same for every memory, so tokenise it once.
    query_words = set(query.lower().split())

    def relevance(memory: Memory) -> float:
        content_words = set(str(memory.content).lower().split())
        # Keyword overlap, weighted up by the memory's importance.
        return len(query_words & content_words) + memory.importance

    # sorted() is stable, also with reverse=True, so ties keep their order.
    ranked = sorted(self._working_memory, key=relevance, reverse=True)
    return ranked[:limit]
|
||||||
|
|
||||||
def communicate(self, message: Communication) -> bool:
    """Inter-agent messaging entry point; currently does nothing.

    Swarm comms were removed. Inter-agent communication will be handled
    by the unified brain memory layer.

    Args:
        message: The message that would be sent; it is ignored.

    Returns:
        Always ``False``, because no message is sent.
    """
    delivered = False
    return delivered
|
||||||
|
|
||||||
def _extract_tags(self, perception: Perception) -> list[str]:
|
def _extract_tags(self, perception: Perception) -> list[str]:
|
||||||
"""Extract searchable tags from perception."""
|
"""Extract searchable tags from perception."""
|
||||||
tags = [perception.type.name, perception.source]
|
tags = [perception.type.name, perception.source]
|
||||||
|
|
||||||
if perception.type == PerceptionType.TEXT:
|
if perception.type == PerceptionType.TEXT:
|
||||||
# Simple keyword extraction
|
# Simple keyword extraction
|
||||||
text = str(perception.data).lower()
|
text = str(perception.data).lower()
|
||||||
@@ -249,14 +250,14 @@ Respond naturally and helpfully."""
|
|||||||
for kw in keywords:
|
for kw in keywords:
|
||||||
if kw in text:
|
if kw in text:
|
||||||
tags.append(kw)
|
tags.append(kw)
|
||||||
|
|
||||||
return tags
|
return tags
|
||||||
|
|
||||||
def _format_context(self, memories: list[Memory]) -> str:
|
def _format_context(self, memories: list[Memory]) -> str:
|
||||||
"""Format memories into context string for prompt."""
|
"""Format memories into context string for prompt."""
|
||||||
if not memories:
|
if not memories:
|
||||||
return "No previous context."
|
return "No previous context."
|
||||||
|
|
||||||
parts = []
|
parts = []
|
||||||
for mem in memories[-5:]: # Last 5 memories
|
for mem in memories[-5:]: # Last 5 memories
|
||||||
if isinstance(mem.content, dict):
|
if isinstance(mem.content, dict):
|
||||||
@@ -264,9 +265,9 @@ Respond naturally and helpfully."""
|
|||||||
parts.append(f"- {data}")
|
parts.append(f"- {data}")
|
||||||
else:
|
else:
|
||||||
parts.append(f"- {mem.content}")
|
parts.append(f"- {mem.content}")
|
||||||
|
|
||||||
return "\n".join(parts)
|
return "\n".join(parts)
|
||||||
|
|
||||||
def get_effect_log(self) -> Optional[list[dict]]:
|
def get_effect_log(self) -> Optional[list[dict]]:
|
||||||
"""Export effect log if logging is enabled."""
|
"""Export effect log if logging is enabled."""
|
||||||
if self._effect_log:
|
if self._effect_log:
|
||||||
|
|||||||
@@ -30,9 +30,11 @@ logger = logging.getLogger(__name__)
|
|||||||
# Data structures
|
# Data structures
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AgenticStep:
|
class AgenticStep:
|
||||||
"""Result of a single step in the agentic loop."""
|
"""Result of a single step in the agentic loop."""
|
||||||
|
|
||||||
step_num: int
|
step_num: int
|
||||||
description: str
|
description: str
|
||||||
result: str
|
result: str
|
||||||
@@ -43,6 +45,7 @@ class AgenticStep:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class AgenticResult:
|
class AgenticResult:
|
||||||
"""Final result of the entire agentic loop."""
|
"""Final result of the entire agentic loop."""
|
||||||
|
|
||||||
task_id: str
|
task_id: str
|
||||||
task: str
|
task: str
|
||||||
summary: str
|
summary: str
|
||||||
@@ -55,6 +58,7 @@ class AgenticResult:
|
|||||||
# Agent factory
|
# Agent factory
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _get_loop_agent():
|
def _get_loop_agent():
|
||||||
"""Create a fresh agent for the agentic loop.
|
"""Create a fresh agent for the agentic loop.
|
||||||
|
|
||||||
@@ -62,6 +66,7 @@ def _get_loop_agent():
|
|||||||
dedicated session so it doesn't pollute the main chat history.
|
dedicated session so it doesn't pollute the main chat history.
|
||||||
"""
|
"""
|
||||||
from timmy.agent import create_timmy
|
from timmy.agent import create_timmy
|
||||||
|
|
||||||
return create_timmy()
|
return create_timmy()
|
||||||
|
|
||||||
|
|
||||||
@@ -85,6 +90,7 @@ def _parse_steps(plan_text: str) -> list[str]:
|
|||||||
# Core loop
|
# Core loop
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
async def run_agentic_loop(
|
async def run_agentic_loop(
|
||||||
task: str,
|
task: str,
|
||||||
*,
|
*,
|
||||||
@@ -146,12 +152,15 @@ async def run_agentic_loop(
|
|||||||
was_truncated = planned_steps > total_steps
|
was_truncated = planned_steps > total_steps
|
||||||
|
|
||||||
# Broadcast plan
|
# Broadcast plan
|
||||||
await _broadcast_progress("agentic.plan_ready", {
|
await _broadcast_progress(
|
||||||
"task_id": task_id,
|
"agentic.plan_ready",
|
||||||
"task": task,
|
{
|
||||||
"steps": steps,
|
"task_id": task_id,
|
||||||
"total": total_steps,
|
"task": task,
|
||||||
})
|
"steps": steps,
|
||||||
|
"total": total_steps,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# ── Phase 2: Execution ─────────────────────────────────────────────────
|
# ── Phase 2: Execution ─────────────────────────────────────────────────
|
||||||
completed_results: list[str] = []
|
completed_results: list[str] = []
|
||||||
@@ -175,6 +184,7 @@ async def run_agentic_loop(
|
|||||||
|
|
||||||
# Clean the response
|
# Clean the response
|
||||||
from timmy.session import _clean_response
|
from timmy.session import _clean_response
|
||||||
|
|
||||||
step_result = _clean_response(step_result)
|
step_result = _clean_response(step_result)
|
||||||
|
|
||||||
step = AgenticStep(
|
step = AgenticStep(
|
||||||
@@ -188,13 +198,16 @@ async def run_agentic_loop(
|
|||||||
completed_results.append(f"Step {i}: {step_result[:200]}")
|
completed_results.append(f"Step {i}: {step_result[:200]}")
|
||||||
|
|
||||||
# Broadcast progress
|
# Broadcast progress
|
||||||
await _broadcast_progress("agentic.step_complete", {
|
await _broadcast_progress(
|
||||||
"task_id": task_id,
|
"agentic.step_complete",
|
||||||
"step": i,
|
{
|
||||||
"total": total_steps,
|
"task_id": task_id,
|
||||||
"description": step_desc,
|
"step": i,
|
||||||
"result": step_result[:200],
|
"total": total_steps,
|
||||||
})
|
"description": step_desc,
|
||||||
|
"result": step_result[:200],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
if on_progress:
|
if on_progress:
|
||||||
await on_progress(step_desc, i, total_steps)
|
await on_progress(step_desc, i, total_steps)
|
||||||
@@ -210,11 +223,16 @@ async def run_agentic_loop(
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
adapt_run = await asyncio.to_thread(
|
adapt_run = await asyncio.to_thread(
|
||||||
agent.run, adapt_prompt, stream=False,
|
agent.run,
|
||||||
|
adapt_prompt,
|
||||||
|
stream=False,
|
||||||
session_id=f"{session_id}_adapt{i}",
|
session_id=f"{session_id}_adapt{i}",
|
||||||
)
|
)
|
||||||
adapt_result = adapt_run.content if hasattr(adapt_run, "content") else str(adapt_run)
|
adapt_result = (
|
||||||
|
adapt_run.content if hasattr(adapt_run, "content") else str(adapt_run)
|
||||||
|
)
|
||||||
from timmy.session import _clean_response
|
from timmy.session import _clean_response
|
||||||
|
|
||||||
adapt_result = _clean_response(adapt_result)
|
adapt_result = _clean_response(adapt_result)
|
||||||
|
|
||||||
step = AgenticStep(
|
step = AgenticStep(
|
||||||
@@ -227,14 +245,17 @@ async def run_agentic_loop(
|
|||||||
result.steps.append(step)
|
result.steps.append(step)
|
||||||
completed_results.append(f"Step {i} (adapted): {adapt_result[:200]}")
|
completed_results.append(f"Step {i} (adapted): {adapt_result[:200]}")
|
||||||
|
|
||||||
await _broadcast_progress("agentic.step_adapted", {
|
await _broadcast_progress(
|
||||||
"task_id": task_id,
|
"agentic.step_adapted",
|
||||||
"step": i,
|
{
|
||||||
"total": total_steps,
|
"task_id": task_id,
|
||||||
"description": step_desc,
|
"step": i,
|
||||||
"error": str(exc),
|
"total": total_steps,
|
||||||
"adaptation": adapt_result[:200],
|
"description": step_desc,
|
||||||
})
|
"error": str(exc),
|
||||||
|
"adaptation": adapt_result[:200],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
if on_progress:
|
if on_progress:
|
||||||
await on_progress(f"[Adapted] {step_desc}", i, total_steps)
|
await on_progress(f"[Adapted] {step_desc}", i, total_steps)
|
||||||
@@ -259,11 +280,16 @@ async def run_agentic_loop(
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
summary_run = await asyncio.to_thread(
|
summary_run = await asyncio.to_thread(
|
||||||
agent.run, summary_prompt, stream=False,
|
agent.run,
|
||||||
|
summary_prompt,
|
||||||
|
stream=False,
|
||||||
session_id=f"{session_id}_summary",
|
session_id=f"{session_id}_summary",
|
||||||
)
|
)
|
||||||
result.summary = summary_run.content if hasattr(summary_run, "content") else str(summary_run)
|
result.summary = (
|
||||||
|
summary_run.content if hasattr(summary_run, "content") else str(summary_run)
|
||||||
|
)
|
||||||
from timmy.session import _clean_response
|
from timmy.session import _clean_response
|
||||||
|
|
||||||
result.summary = _clean_response(result.summary)
|
result.summary = _clean_response(result.summary)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("Agentic loop summary failed: %s", exc)
|
logger.error("Agentic loop summary failed: %s", exc)
|
||||||
@@ -281,13 +307,16 @@ async def run_agentic_loop(
|
|||||||
|
|
||||||
result.total_duration_ms = int((time.monotonic() - start_time) * 1000)
|
result.total_duration_ms = int((time.monotonic() - start_time) * 1000)
|
||||||
|
|
||||||
await _broadcast_progress("agentic.task_complete", {
|
await _broadcast_progress(
|
||||||
"task_id": task_id,
|
"agentic.task_complete",
|
||||||
"status": result.status,
|
{
|
||||||
"steps_completed": len(result.steps),
|
"task_id": task_id,
|
||||||
"summary": result.summary[:300],
|
"status": result.status,
|
||||||
"duration_ms": result.total_duration_ms,
|
"steps_completed": len(result.steps),
|
||||||
})
|
"summary": result.summary[:300],
|
||||||
|
"duration_ms": result.total_duration_ms,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -296,10 +325,12 @@ async def run_agentic_loop(
|
|||||||
# WebSocket broadcast helper
|
# WebSocket broadcast helper
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
async def _broadcast_progress(event: str, data: dict) -> None:
|
async def _broadcast_progress(event: str, data: dict) -> None:
|
||||||
"""Broadcast agentic loop progress via WebSocket (best-effort)."""
|
"""Broadcast agentic loop progress via WebSocket (best-effort)."""
|
||||||
try:
|
try:
|
||||||
from infrastructure.ws_manager.handler import ws_manager
|
from infrastructure.ws_manager.handler import ws_manager
|
||||||
|
|
||||||
await ws_manager.broadcast(event, data)
|
await ws_manager.broadcast(event, data)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("Agentic loop: WS broadcast failed for %s", event)
|
logger.debug("Agentic loop: WS broadcast failed for %s", event)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from agno.agent import Agent
|
|||||||
from agno.models.ollama import Ollama
|
from agno.models.ollama import Ollama
|
||||||
|
|
||||||
from config import settings
|
from config import settings
|
||||||
from infrastructure.events.bus import EventBus, Event
|
from infrastructure.events.bus import Event, EventBus
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from mcp.registry import tool_registry
|
from mcp.registry import tool_registry
|
||||||
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class BaseAgent(ABC):
|
class BaseAgent(ABC):
|
||||||
"""Base class for all sub-agents."""
|
"""Base class for all sub-agents."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
@@ -43,15 +43,15 @@ class BaseAgent(ABC):
|
|||||||
self.name = name
|
self.name = name
|
||||||
self.role = role
|
self.role = role
|
||||||
self.tools = tools or []
|
self.tools = tools or []
|
||||||
|
|
||||||
# Create Agno agent
|
# Create Agno agent
|
||||||
self.agent = self._create_agent(system_prompt)
|
self.agent = self._create_agent(system_prompt)
|
||||||
|
|
||||||
# Event bus for communication
|
# Event bus for communication
|
||||||
self.event_bus: Optional[EventBus] = None
|
self.event_bus: Optional[EventBus] = None
|
||||||
|
|
||||||
logger.info("%s agent initialized (id: %s)", name, agent_id)
|
logger.info("%s agent initialized (id: %s)", name, agent_id)
|
||||||
|
|
||||||
def _create_agent(self, system_prompt: str) -> Agent:
|
def _create_agent(self, system_prompt: str) -> Agent:
|
||||||
"""Create the underlying Agno agent."""
|
"""Create the underlying Agno agent."""
|
||||||
# Get tools from registry
|
# Get tools from registry
|
||||||
@@ -60,7 +60,7 @@ class BaseAgent(ABC):
|
|||||||
handler = tool_registry.get_handler(tool_name)
|
handler = tool_registry.get_handler(tool_name)
|
||||||
if handler:
|
if handler:
|
||||||
tool_instances.append(handler)
|
tool_instances.append(handler)
|
||||||
|
|
||||||
return Agent(
|
return Agent(
|
||||||
name=self.name,
|
name=self.name,
|
||||||
model=Ollama(id=settings.ollama_model, host=settings.ollama_url, timeout=300),
|
model=Ollama(id=settings.ollama_model, host=settings.ollama_url, timeout=300),
|
||||||
@@ -71,19 +71,19 @@ class BaseAgent(ABC):
|
|||||||
markdown=True,
|
markdown=True,
|
||||||
telemetry=settings.telemetry_enabled,
|
telemetry=settings.telemetry_enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
def connect_event_bus(self, bus: EventBus) -> None:
|
def connect_event_bus(self, bus: EventBus) -> None:
|
||||||
"""Connect to the event bus for inter-agent communication."""
|
"""Connect to the event bus for inter-agent communication."""
|
||||||
self.event_bus = bus
|
self.event_bus = bus
|
||||||
|
|
||||||
# Subscribe to relevant events
|
# Subscribe to relevant events
|
||||||
bus.subscribe(f"agent.{self.agent_id}.*")(self._handle_direct_message)
|
bus.subscribe(f"agent.{self.agent_id}.*")(self._handle_direct_message)
|
||||||
bus.subscribe("agent.task.assigned")(self._handle_task_assignment)
|
bus.subscribe("agent.task.assigned")(self._handle_task_assignment)
|
||||||
|
|
||||||
async def _handle_direct_message(self, event: Event) -> None:
|
async def _handle_direct_message(self, event: Event) -> None:
|
||||||
"""Handle direct messages to this agent."""
|
"""Handle direct messages to this agent."""
|
||||||
logger.debug("%s received message: %s", self.name, event.type)
|
logger.debug("%s received message: %s", self.name, event.type)
|
||||||
|
|
||||||
async def _handle_task_assignment(self, event: Event) -> None:
|
async def _handle_task_assignment(self, event: Event) -> None:
|
||||||
"""Handle task assignment events."""
|
"""Handle task assignment events."""
|
||||||
assigned_agent = event.data.get("agent_id")
|
assigned_agent = event.data.get("agent_id")
|
||||||
@@ -91,41 +91,43 @@ class BaseAgent(ABC):
|
|||||||
task_id = event.data.get("task_id")
|
task_id = event.data.get("task_id")
|
||||||
description = event.data.get("description", "")
|
description = event.data.get("description", "")
|
||||||
logger.info("%s assigned task %s: %s", self.name, task_id, description[:50])
|
logger.info("%s assigned task %s: %s", self.name, task_id, description[:50])
|
||||||
|
|
||||||
# Execute the task
|
# Execute the task
|
||||||
await self.execute_task(task_id, description, event.data)
|
await self.execute_task(task_id, description, event.data)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def execute_task(self, task_id: str, description: str, context: dict) -> Any:
|
async def execute_task(self, task_id: str, description: str, context: dict) -> Any:
|
||||||
"""Execute a task assigned to this agent.
|
"""Execute a task assigned to this agent.
|
||||||
|
|
||||||
Must be implemented by subclasses.
|
Must be implemented by subclasses.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def run(self, message: str) -> str:
|
async def run(self, message: str) -> str:
|
||||||
"""Run the agent with a message.
|
"""Run the agent with a message.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Agent response
|
Agent response
|
||||||
"""
|
"""
|
||||||
result = self.agent.run(message, stream=False)
|
result = self.agent.run(message, stream=False)
|
||||||
response = result.content if hasattr(result, "content") else str(result)
|
response = result.content if hasattr(result, "content") else str(result)
|
||||||
|
|
||||||
# Emit completion event
|
# Emit completion event
|
||||||
if self.event_bus:
|
if self.event_bus:
|
||||||
await self.event_bus.publish(Event(
|
await self.event_bus.publish(
|
||||||
type=f"agent.{self.agent_id}.response",
|
Event(
|
||||||
source=self.agent_id,
|
type=f"agent.{self.agent_id}.response",
|
||||||
data={"input": message, "output": response},
|
source=self.agent_id,
|
||||||
))
|
data={"input": message, "output": response},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def get_capabilities(self) -> list[str]:
|
def get_capabilities(self) -> list[str]:
|
||||||
"""Get list of capabilities this agent provides."""
|
"""Get list of capabilities this agent provides."""
|
||||||
return self.tools
|
return self.tools
|
||||||
|
|
||||||
def get_status(self) -> dict:
|
def get_status(self) -> dict:
|
||||||
"""Get current agent status."""
|
"""Get current agent status."""
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -12,9 +12,9 @@ from typing import Any, Optional
|
|||||||
from agno.agent import Agent
|
from agno.agent import Agent
|
||||||
from agno.models.ollama import Ollama
|
from agno.models.ollama import Ollama
|
||||||
|
|
||||||
from timmy.agents.base import BaseAgent, SubAgent
|
|
||||||
from config import settings
|
from config import settings
|
||||||
from infrastructure.events.bus import EventBus, event_bus
|
from infrastructure.events.bus import EventBus, event_bus
|
||||||
|
from timmy.agents.base import BaseAgent, SubAgent
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ _timmy_context: dict[str, Any] = {
|
|||||||
|
|
||||||
async def _load_hands_async() -> list[dict]:
|
async def _load_hands_async() -> list[dict]:
|
||||||
"""Async helper to load hands.
|
"""Async helper to load hands.
|
||||||
|
|
||||||
Hands registry removed — hand definitions live in TOML files under hands/.
|
Hands registry removed — hand definitions live in TOML files under hands/.
|
||||||
This will be rewired to read from brain memory.
|
This will be rewired to read from brain memory.
|
||||||
"""
|
"""
|
||||||
@@ -42,7 +42,7 @@ def build_timmy_context_sync() -> dict[str, Any]:
|
|||||||
Gathers git commits, active sub-agents, and hot memory.
|
Gathers git commits, active sub-agents, and hot memory.
|
||||||
"""
|
"""
|
||||||
global _timmy_context
|
global _timmy_context
|
||||||
|
|
||||||
ctx: dict[str, Any] = {
|
ctx: dict[str, Any] = {
|
||||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||||
"repo_root": settings.repo_root,
|
"repo_root": settings.repo_root,
|
||||||
@@ -51,45 +51,52 @@ def build_timmy_context_sync() -> dict[str, Any]:
|
|||||||
"hands": [],
|
"hands": [],
|
||||||
"memory": "",
|
"memory": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
# 1. Get recent git commits
|
# 1. Get recent git commits
|
||||||
try:
|
try:
|
||||||
from tools.git_tools import git_log
|
from tools.git_tools import git_log
|
||||||
|
|
||||||
result = git_log(max_count=20)
|
result = git_log(max_count=20)
|
||||||
if result.get("success"):
|
if result.get("success"):
|
||||||
commits = result.get("commits", [])
|
commits = result.get("commits", [])
|
||||||
ctx["git_log"] = "\n".join([
|
ctx["git_log"] = "\n".join(
|
||||||
f"{c['short_sha']} {c['message'].split(chr(10))[0]}"
|
[f"{c['short_sha']} {c['message'].split(chr(10))[0]}" for c in commits[:20]]
|
||||||
for c in commits[:20]
|
)
|
||||||
])
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Could not load git log for context: %s", exc)
|
logger.warning("Could not load git log for context: %s", exc)
|
||||||
ctx["git_log"] = "(Git log unavailable)"
|
ctx["git_log"] = "(Git log unavailable)"
|
||||||
|
|
||||||
# 2. Get active sub-agents
|
# 2. Get active sub-agents
|
||||||
try:
|
try:
|
||||||
from swarm import registry as swarm_registry
|
from swarm import registry as swarm_registry
|
||||||
|
|
||||||
conn = swarm_registry._get_conn()
|
conn = swarm_registry._get_conn()
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
"SELECT id, name, status, capabilities FROM agents ORDER BY name"
|
"SELECT id, name, status, capabilities FROM agents ORDER BY name"
|
||||||
).fetchall()
|
).fetchall()
|
||||||
ctx["agents"] = [
|
ctx["agents"] = [
|
||||||
{"id": r["id"], "name": r["name"], "status": r["status"], "capabilities": r["capabilities"]}
|
{
|
||||||
|
"id": r["id"],
|
||||||
|
"name": r["name"],
|
||||||
|
"status": r["status"],
|
||||||
|
"capabilities": r["capabilities"],
|
||||||
|
}
|
||||||
for r in rows
|
for r in rows
|
||||||
]
|
]
|
||||||
conn.close()
|
conn.close()
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Could not load agents for context: %s", exc)
|
logger.warning("Could not load agents for context: %s", exc)
|
||||||
ctx["agents"] = []
|
ctx["agents"] = []
|
||||||
|
|
||||||
# 3. Read hot memory (via HotMemory to auto-create if missing)
|
# 3. Read hot memory (via HotMemory to auto-create if missing)
|
||||||
try:
|
try:
|
||||||
from timmy.memory_system import memory_system
|
from timmy.memory_system import memory_system
|
||||||
|
|
||||||
ctx["memory"] = memory_system.hot.read()[:2000]
|
ctx["memory"] = memory_system.hot.read()[:2000]
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Could not load memory for context: %s", exc)
|
logger.warning("Could not load memory for context: %s", exc)
|
||||||
ctx["memory"] = "(Memory unavailable)"
|
ctx["memory"] = "(Memory unavailable)"
|
||||||
|
|
||||||
_timmy_context.update(ctx)
|
_timmy_context.update(ctx)
|
||||||
logger.info("Context built (sync): %d agents", len(ctx["agents"]))
|
logger.info("Context built (sync): %d agents", len(ctx["agents"]))
|
||||||
return ctx
|
return ctx
|
||||||
@@ -110,21 +117,31 @@ build_timmy_context = build_timmy_context_sync
|
|||||||
|
|
||||||
def format_timmy_prompt(base_prompt: str, context: dict[str, Any]) -> str:
|
def format_timmy_prompt(base_prompt: str, context: dict[str, Any]) -> str:
|
||||||
"""Format the system prompt with dynamic context."""
|
"""Format the system prompt with dynamic context."""
|
||||||
|
|
||||||
# Format agents list
|
# Format agents list
|
||||||
agents_list = "\n".join([
|
agents_list = (
|
||||||
f"| {a['name']} | {a['capabilities'] or 'general'} | {a['status']} |"
|
"\n".join(
|
||||||
for a in context.get("agents", [])
|
[
|
||||||
]) or "(No agents registered yet)"
|
f"| {a['name']} | {a['capabilities'] or 'general'} | {a['status']} |"
|
||||||
|
for a in context.get("agents", [])
|
||||||
|
]
|
||||||
|
)
|
||||||
|
or "(No agents registered yet)"
|
||||||
|
)
|
||||||
|
|
||||||
# Format hands list
|
# Format hands list
|
||||||
hands_list = "\n".join([
|
hands_list = (
|
||||||
f"| {h['name']} | {h['schedule']} | {'enabled' if h['enabled'] else 'disabled'} |"
|
"\n".join(
|
||||||
for h in context.get("hands", [])
|
[
|
||||||
]) or "(No hands configured)"
|
f"| {h['name']} | {h['schedule']} | {'enabled' if h['enabled'] else 'disabled'} |"
|
||||||
|
for h in context.get("hands", [])
|
||||||
repo_root = context.get('repo_root', settings.repo_root)
|
]
|
||||||
|
)
|
||||||
|
or "(No hands configured)"
|
||||||
|
)
|
||||||
|
|
||||||
|
repo_root = context.get("repo_root", settings.repo_root)
|
||||||
|
|
||||||
context_block = f"""
|
context_block = f"""
|
||||||
## Current System Context (as of {context.get('timestamp', datetime.now(timezone.utc).isoformat())})
|
## Current System Context (as of {context.get('timestamp', datetime.now(timezone.utc).isoformat())})
|
||||||
|
|
||||||
@@ -149,10 +166,10 @@ def format_timmy_prompt(base_prompt: str, context: dict[str, Any]) -> str:
|
|||||||
### Hot Memory:
|
### Hot Memory:
|
||||||
{context.get('memory', '(unavailable)')[:1000]}
|
{context.get('memory', '(unavailable)')[:1000]}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Replace {REPO_ROOT} placeholder with actual path
|
# Replace {REPO_ROOT} placeholder with actual path
|
||||||
base_prompt = base_prompt.replace("{REPO_ROOT}", repo_root)
|
base_prompt = base_prompt.replace("{REPO_ROOT}", repo_root)
|
||||||
|
|
||||||
# Insert context after the first line
|
# Insert context after the first line
|
||||||
lines = base_prompt.split("\n")
|
lines = base_prompt.split("\n")
|
||||||
if lines:
|
if lines:
|
||||||
@@ -227,63 +244,71 @@ class TimmyOrchestrator(BaseAgent):
|
|||||||
name="Orchestrator",
|
name="Orchestrator",
|
||||||
role="orchestrator",
|
role="orchestrator",
|
||||||
system_prompt=formatted_prompt,
|
system_prompt=formatted_prompt,
|
||||||
tools=["web_search", "read_file", "write_file", "python", "memory_search", "memory_write", "system_status"],
|
tools=[
|
||||||
|
"web_search",
|
||||||
|
"read_file",
|
||||||
|
"write_file",
|
||||||
|
"python",
|
||||||
|
"memory_search",
|
||||||
|
"memory_write",
|
||||||
|
"system_status",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sub-agent registry
|
# Sub-agent registry
|
||||||
self.sub_agents: dict[str, BaseAgent] = {}
|
self.sub_agents: dict[str, BaseAgent] = {}
|
||||||
|
|
||||||
# Session tracking for init behavior
|
# Session tracking for init behavior
|
||||||
self._session_initialized = False
|
self._session_initialized = False
|
||||||
self._session_context: dict[str, Any] = {}
|
self._session_context: dict[str, Any] = {}
|
||||||
self._context_fully_loaded = False
|
self._context_fully_loaded = False
|
||||||
|
|
||||||
# Connect to event bus
|
# Connect to event bus
|
||||||
self.connect_event_bus(event_bus)
|
self.connect_event_bus(event_bus)
|
||||||
|
|
||||||
logger.info("Orchestrator initialized with context-aware prompt")
|
logger.info("Orchestrator initialized with context-aware prompt")
|
||||||
|
|
||||||
def register_sub_agent(self, agent: BaseAgent) -> None:
|
def register_sub_agent(self, agent: BaseAgent) -> None:
|
||||||
"""Register a sub-agent with the orchestrator."""
|
"""Register a sub-agent with the orchestrator."""
|
||||||
self.sub_agents[agent.agent_id] = agent
|
self.sub_agents[agent.agent_id] = agent
|
||||||
agent.connect_event_bus(event_bus)
|
agent.connect_event_bus(event_bus)
|
||||||
logger.info("Registered sub-agent: %s", agent.name)
|
logger.info("Registered sub-agent: %s", agent.name)
|
||||||
|
|
||||||
async def _session_init(self) -> None:
|
async def _session_init(self) -> None:
|
||||||
"""Initialize session context on first user message.
|
"""Initialize session context on first user message.
|
||||||
|
|
||||||
Silently reads git log and AGENTS.md to ground the orchestrator in real data.
|
Silently reads git log and AGENTS.md to ground the orchestrator in real data.
|
||||||
This runs once per session before the first response.
|
This runs once per session before the first response.
|
||||||
"""
|
"""
|
||||||
if self._session_initialized:
|
if self._session_initialized:
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.debug("Running session init...")
|
logger.debug("Running session init...")
|
||||||
|
|
||||||
# Load full context including hands if not already done
|
# Load full context including hands if not already done
|
||||||
if not self._context_fully_loaded:
|
if not self._context_fully_loaded:
|
||||||
await build_timmy_context_async()
|
await build_timmy_context_async()
|
||||||
self._context_fully_loaded = True
|
self._context_fully_loaded = True
|
||||||
|
|
||||||
# Read recent git log --oneline -15 from repo root
|
# Read recent git log --oneline -15 from repo root
|
||||||
try:
|
try:
|
||||||
from tools.git_tools import git_log
|
from tools.git_tools import git_log
|
||||||
|
|
||||||
git_result = git_log(max_count=15)
|
git_result = git_log(max_count=15)
|
||||||
if git_result.get("success"):
|
if git_result.get("success"):
|
||||||
commits = git_result.get("commits", [])
|
commits = git_result.get("commits", [])
|
||||||
self._session_context["git_log_commits"] = commits
|
self._session_context["git_log_commits"] = commits
|
||||||
# Format as oneline for easy reading
|
# Format as oneline for easy reading
|
||||||
self._session_context["git_log_oneline"] = "\n".join([
|
self._session_context["git_log_oneline"] = "\n".join(
|
||||||
f"{c['short_sha']} {c['message'].split(chr(10))[0]}"
|
[f"{c['short_sha']} {c['message'].split(chr(10))[0]}" for c in commits]
|
||||||
for c in commits
|
)
|
||||||
])
|
|
||||||
logger.debug(f"Session init: loaded {len(commits)} commits from git log")
|
logger.debug(f"Session init: loaded {len(commits)} commits from git log")
|
||||||
else:
|
else:
|
||||||
self._session_context["git_log_oneline"] = "Git log unavailable"
|
self._session_context["git_log_oneline"] = "Git log unavailable"
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Session init: could not read git log: %s", exc)
|
logger.warning("Session init: could not read git log: %s", exc)
|
||||||
self._session_context["git_log_oneline"] = "Git log unavailable"
|
self._session_context["git_log_oneline"] = "Git log unavailable"
|
||||||
|
|
||||||
# Read AGENTS.md for self-awareness
|
# Read AGENTS.md for self-awareness
|
||||||
try:
|
try:
|
||||||
agents_md_path = Path(settings.repo_root) / "AGENTS.md"
|
agents_md_path = Path(settings.repo_root) / "AGENTS.md"
|
||||||
@@ -291,7 +316,7 @@ class TimmyOrchestrator(BaseAgent):
|
|||||||
self._session_context["agents_md"] = agents_md_path.read_text()[:3000]
|
self._session_context["agents_md"] = agents_md_path.read_text()[:3000]
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Session init: could not read AGENTS.md: %s", exc)
|
logger.warning("Session init: could not read AGENTS.md: %s", exc)
|
||||||
|
|
||||||
# Read CHANGELOG for recent changes
|
# Read CHANGELOG for recent changes
|
||||||
try:
|
try:
|
||||||
changelog_path = Path(settings.repo_root) / "docs" / "CHANGELOG_2026-02-26.md"
|
changelog_path = Path(settings.repo_root) / "docs" / "CHANGELOG_2026-02-26.md"
|
||||||
@@ -299,11 +324,13 @@ class TimmyOrchestrator(BaseAgent):
|
|||||||
self._session_context["changelog"] = changelog_path.read_text()[:2000]
|
self._session_context["changelog"] = changelog_path.read_text()[:2000]
|
||||||
except Exception:
|
except Exception:
|
||||||
pass # Changelog is optional
|
pass # Changelog is optional
|
||||||
|
|
||||||
# Build session-specific context block for the prompt
|
# Build session-specific context block for the prompt
|
||||||
recent_changes = self._session_context.get("git_log_oneline", "")
|
recent_changes = self._session_context.get("git_log_oneline", "")
|
||||||
if recent_changes and recent_changes != "Git log unavailable":
|
if recent_changes and recent_changes != "Git log unavailable":
|
||||||
self._session_context["recent_changes_block"] = f"""
|
self._session_context[
|
||||||
|
"recent_changes_block"
|
||||||
|
] = f"""
|
||||||
## Recent Changes to Your Codebase (last 15 commits):
|
## Recent Changes to Your Codebase (last 15 commits):
|
||||||
```
|
```
|
||||||
{recent_changes}
|
{recent_changes}
|
||||||
@@ -312,17 +339,17 @@ When asked "what's new?" or similar, refer to these commits for actual changes.
|
|||||||
"""
|
"""
|
||||||
else:
|
else:
|
||||||
self._session_context["recent_changes_block"] = ""
|
self._session_context["recent_changes_block"] = ""
|
||||||
|
|
||||||
self._session_initialized = True
|
self._session_initialized = True
|
||||||
logger.debug("Session init complete")
|
logger.debug("Session init complete")
|
||||||
|
|
||||||
def _get_enhanced_system_prompt(self) -> str:
|
def _get_enhanced_system_prompt(self) -> str:
|
||||||
"""Get system prompt enhanced with session-specific context.
|
"""Get system prompt enhanced with session-specific context.
|
||||||
|
|
||||||
Prepends the recent git log to the system prompt for grounding.
|
Prepends the recent git log to the system prompt for grounding.
|
||||||
"""
|
"""
|
||||||
base = self.system_prompt
|
base = self.system_prompt
|
||||||
|
|
||||||
# Add recent changes block if available
|
# Add recent changes block if available
|
||||||
recent_changes = self._session_context.get("recent_changes_block", "")
|
recent_changes = self._session_context.get("recent_changes_block", "")
|
||||||
if recent_changes:
|
if recent_changes:
|
||||||
@@ -330,36 +357,45 @@ When asked "what's new?" or similar, refer to these commits for actual changes.
|
|||||||
lines = base.split("\n")
|
lines = base.split("\n")
|
||||||
if lines:
|
if lines:
|
||||||
return lines[0] + "\n" + recent_changes + "\n" + "\n".join(lines[1:])
|
return lines[0] + "\n" + recent_changes + "\n" + "\n".join(lines[1:])
|
||||||
|
|
||||||
return base
|
return base
|
||||||
|
|
||||||
async def orchestrate(self, user_request: str) -> str:
|
async def orchestrate(self, user_request: str) -> str:
|
||||||
"""Main entry point for user requests.
|
"""Main entry point for user requests.
|
||||||
|
|
||||||
Analyzes the request and either handles directly or delegates.
|
Analyzes the request and either handles directly or delegates.
|
||||||
"""
|
"""
|
||||||
# Run session init on first message (loads git log, etc.)
|
# Run session init on first message (loads git log, etc.)
|
||||||
await self._session_init()
|
await self._session_init()
|
||||||
|
|
||||||
# Quick classification
|
# Quick classification
|
||||||
request_lower = user_request.lower()
|
request_lower = user_request.lower()
|
||||||
|
|
||||||
# Direct response patterns (no delegation needed)
|
# Direct response patterns (no delegation needed)
|
||||||
direct_patterns = [
|
direct_patterns = [
|
||||||
"your name", "who are you", "what are you",
|
"your name",
|
||||||
"hello", "hi", "how are you",
|
"who are you",
|
||||||
"help", "what can you do",
|
"what are you",
|
||||||
|
"hello",
|
||||||
|
"hi",
|
||||||
|
"how are you",
|
||||||
|
"help",
|
||||||
|
"what can you do",
|
||||||
]
|
]
|
||||||
|
|
||||||
for pattern in direct_patterns:
|
for pattern in direct_patterns:
|
||||||
if pattern in request_lower:
|
if pattern in request_lower:
|
||||||
return await self.run(user_request)
|
return await self.run(user_request)
|
||||||
|
|
||||||
# Check for memory references — delegate to Echo
|
# Check for memory references — delegate to Echo
|
||||||
memory_patterns = [
|
memory_patterns = [
|
||||||
"we talked about", "we discussed", "remember",
|
"we talked about",
|
||||||
"what did i say", "what did we decide",
|
"we discussed",
|
||||||
"remind me", "have we",
|
"remember",
|
||||||
|
"what did i say",
|
||||||
|
"what did we decide",
|
||||||
|
"remind me",
|
||||||
|
"have we",
|
||||||
]
|
]
|
||||||
|
|
||||||
for pattern in memory_patterns:
|
for pattern in memory_patterns:
|
||||||
@@ -395,19 +431,16 @@ When asked "what's new?" or similar, refer to these commits for actual changes.
|
|||||||
if agent in text_lower:
|
if agent in text_lower:
|
||||||
return agent
|
return agent
|
||||||
return "orchestrator"
|
return "orchestrator"
|
||||||
|
|
||||||
async def execute_task(self, task_id: str, description: str, context: dict) -> Any:
|
async def execute_task(self, task_id: str, description: str, context: dict) -> Any:
|
||||||
"""Execute a task (usually delegates to appropriate agent)."""
|
"""Execute a task (usually delegates to appropriate agent)."""
|
||||||
return await self.orchestrate(description)
|
return await self.orchestrate(description)
|
||||||
|
|
||||||
def get_swarm_status(self) -> dict:
|
def get_swarm_status(self) -> dict:
|
||||||
"""Get status of all agents in the swarm."""
|
"""Get status of all agents in the swarm."""
|
||||||
return {
|
return {
|
||||||
"orchestrator": self.get_status(),
|
"orchestrator": self.get_status(),
|
||||||
"sub_agents": {
|
"sub_agents": {aid: agent.get_status() for aid, agent in self.sub_agents.items()},
|
||||||
aid: agent.get_status()
|
|
||||||
for aid, agent in self.sub_agents.items()
|
|
||||||
},
|
|
||||||
"total_agents": 1 + len(self.sub_agents),
|
"total_agents": 1 + len(self.sub_agents),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -468,10 +501,29 @@ _PERSONAS: list[dict[str, Any]] = [
|
|||||||
"system_prompt": (
|
"system_prompt": (
|
||||||
"You are Helm, a routing and orchestration specialist.\n"
|
"You are Helm, a routing and orchestration specialist.\n"
|
||||||
"Analyze tasks and decide how to route them to other agents.\n"
|
"Analyze tasks and decide how to route them to other agents.\n"
|
||||||
"Available agents: Seer (research), Forge (code), Quill (writing), Echo (memory).\n"
|
"Available agents: Seer (research), Forge (code), Quill (writing), Echo (memory), Lab (experiments).\n"
|
||||||
"Respond with: Primary Agent: [agent name]"
|
"Respond with: Primary Agent: [agent name]"
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"agent_id": "lab",
|
||||||
|
"name": "Lab",
|
||||||
|
"role": "experiment",
|
||||||
|
"tools": [
|
||||||
|
"run_experiment",
|
||||||
|
"prepare_experiment",
|
||||||
|
"shell",
|
||||||
|
"python",
|
||||||
|
"read_file",
|
||||||
|
"write_file",
|
||||||
|
],
|
||||||
|
"system_prompt": (
|
||||||
|
"You are Lab, an autonomous ML experimentation specialist.\n"
|
||||||
|
"You run time-boxed training experiments, evaluate metrics,\n"
|
||||||
|
"modify training code to improve results, and iterate.\n"
|
||||||
|
"Always report the metric delta. Never exceed the time budget."
|
||||||
|
),
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -38,10 +38,10 @@ class ApprovalItem:
|
|||||||
id: str
|
id: str
|
||||||
title: str
|
title: str
|
||||||
description: str
|
description: str
|
||||||
proposed_action: str # what Timmy wants to do
|
proposed_action: str # what Timmy wants to do
|
||||||
impact: str # "low" | "medium" | "high"
|
impact: str # "low" | "medium" | "high"
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
status: str # "pending" | "approved" | "rejected"
|
status: str # "pending" | "approved" | "rejected"
|
||||||
|
|
||||||
|
|
||||||
def _get_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection:
|
def _get_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection:
|
||||||
@@ -81,6 +81,7 @@ def _row_to_item(row: sqlite3.Row) -> ApprovalItem:
|
|||||||
# Public API
|
# Public API
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def create_item(
|
def create_item(
|
||||||
title: str,
|
title: str,
|
||||||
description: str,
|
description: str,
|
||||||
@@ -133,18 +134,14 @@ def list_pending(db_path: Path = _DEFAULT_DB) -> list[ApprovalItem]:
|
|||||||
def list_all(db_path: Path = _DEFAULT_DB) -> list[ApprovalItem]:
    """Return all approval items regardless of status, newest first.

    Args:
        db_path: SQLite database file to read from.

    Returns:
        Every row of ``approval_items`` converted with ``_row_to_item``,
        ordered by ``created_at`` descending.
    """
    conn = _get_conn(db_path)
    try:
        rows = conn.execute("SELECT * FROM approval_items ORDER BY created_at DESC").fetchall()
    finally:
        # Close the handle even when the query raises, so a failing call
        # does not leave a connection open.
        conn.close()
    return [_row_to_item(r) for r in rows]
|
||||||
|
|
||||||
|
|
||||||
def get_item(item_id: str, db_path: Path = _DEFAULT_DB) -> Optional[ApprovalItem]:
    """Look up a single approval item by its id.

    Args:
        item_id: Value matched against the ``id`` column.
        db_path: SQLite database file to read from.

    Returns:
        The matching item, or ``None`` when no row has that id.
    """
    conn = _get_conn(db_path)
    try:
        row = conn.execute("SELECT * FROM approval_items WHERE id = ?", (item_id,)).fetchone()
    finally:
        # Close the handle even when the query raises.
        conn.close()
    return _row_to_item(row) if row else None
|
||||||
|
|
||||||
@@ -152,9 +149,7 @@ def get_item(item_id: str, db_path: Path = _DEFAULT_DB) -> Optional[ApprovalItem
|
|||||||
def approve(item_id: str, db_path: Path = _DEFAULT_DB) -> Optional[ApprovalItem]:
    """Mark an approval item as approved.

    Args:
        item_id: Id of the item whose status is set to ``'approved'``.
        db_path: SQLite database file to update.

    Returns:
        The item as re-read through ``get_item`` after the update, which is
        ``None`` when no row has that id.
    """
    conn = _get_conn(db_path)
    try:
        conn.execute("UPDATE approval_items SET status = 'approved' WHERE id = ?", (item_id,))
        conn.commit()
    finally:
        # Close the handle even when the UPDATE or the commit raises.
        conn.close()
    return get_item(item_id, db_path)
|
||||||
@@ -163,9 +158,7 @@ def approve(item_id: str, db_path: Path = _DEFAULT_DB) -> Optional[ApprovalItem]
|
|||||||
def reject(item_id: str, db_path: Path = _DEFAULT_DB) -> Optional[ApprovalItem]:
    """Mark an approval item as rejected.

    Args:
        item_id: Id of the item whose status is set to ``'rejected'``.
        db_path: SQLite database file to update.

    Returns:
        The item as re-read through ``get_item`` after the update, which is
        ``None`` when no row has that id.
    """
    conn = _get_conn(db_path)
    try:
        conn.execute("UPDATE approval_items SET status = 'rejected' WHERE id = ?", (item_id,))
        conn.commit()
    finally:
        # Close the handle even when the UPDATE or the commit raises.
        conn.close()
    return get_item(item_id, db_path)
|
||||||
|
|||||||
214
src/timmy/autoresearch.py
Normal file
214
src/timmy/autoresearch.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
"""Autoresearch — autonomous ML experiment loops.
|
||||||
|
|
||||||
|
Integrates Karpathy's autoresearch pattern: an agent modifies training
|
||||||
|
code, runs time-boxed GPU experiments, evaluates a target metric
|
||||||
|
(val_bpb by default), and iterates to find improvements.
|
||||||
|
|
||||||
|
Flow:
|
||||||
|
1. prepare_experiment — clone repo + run data prep
|
||||||
|
2. run_experiment — execute train.py with wall-clock timeout
|
||||||
|
3. evaluate_result — compare metric against baseline
|
||||||
|
4. experiment_loop — orchestrate the full cycle
|
||||||
|
|
||||||
|
All subprocess calls are guarded with timeouts for graceful degradation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_REPO = "https://github.com/karpathy/autoresearch.git"
|
||||||
|
_METRIC_RE = re.compile(r"val_bpb[:\s]+([0-9]+\.?[0-9]*)")
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_experiment(
    workspace: Path,
    repo_url: str = DEFAULT_REPO,
) -> str:
    """Clone the autoresearch repo (if absent) and run its data preparation.

    The workspace directory is created when missing. The repository is cloned
    into ``<workspace>/autoresearch`` unless that path already exists, and the
    checkout's ``prepare.py`` is then executed when present.

    Subprocess failures are reported through the returned message rather than
    raised: a binary that cannot be launched or a step that exceeds its
    timeout yields a "... failed" status, just like a non-zero exit code.

    Args:
        workspace: Directory to set up the experiment in.
        repo_url: Git URL for the autoresearch repository.

    Returns:
        Status message describing what was prepared, skipped or what failed.
    """
    workspace = Path(workspace)
    workspace.mkdir(parents=True, exist_ok=True)

    repo_dir = workspace / "autoresearch"
    if not repo_dir.exists():
        logger.info("Cloning autoresearch into %s", repo_dir)
        try:
            result = subprocess.run(
                ["git", "clone", "--depth", "1", repo_url, str(repo_dir)],
                capture_output=True,
                text=True,
                timeout=120,
            )
        except subprocess.TimeoutExpired:
            # NOTE(review): an interrupted clone may leave a partial checkout
            # behind, which the exists() check above would then accept on the
            # next call — confirm whether cleanup is wanted here.
            return "Clone failed: timed out after 120s"
        except OSError as exc:
            # Raised when the git executable cannot be started at all.
            return f"Clone failed: {exc}"
        if result.returncode != 0:
            return f"Clone failed: {result.stderr.strip()}"
    else:
        logger.info("Autoresearch repo already present at %s", repo_dir)

    # Run prepare.py (data download + tokeniser training)
    prepare_script = repo_dir / "prepare.py"
    if not prepare_script.exists():
        return "Preparation skipped — no prepare.py found."

    logger.info("Running prepare.py …")
    try:
        result = subprocess.run(
            ["python", str(prepare_script)],
            capture_output=True,
            text=True,
            cwd=str(repo_dir),
            timeout=300,
        )
    except subprocess.TimeoutExpired:
        return "Preparation failed: timed out after 300s"
    except OSError as exc:
        # Raised when the python executable cannot be started at all.
        return f"Preparation failed: {exc}"
    if result.returncode != 0:
        # Cap the message so a long traceback does not flood the status string.
        return f"Preparation failed: {result.stderr.strip()[:500]}"
    return "Preparation complete — data downloaded and tokeniser trained."
|
||||||
|
|
||||||
|
|
||||||
|
def run_experiment(
|
||||||
|
workspace: Path,
|
||||||
|
timeout: int = 300,
|
||||||
|
metric_name: str = "val_bpb",
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Run a single training experiment with a wall-clock timeout.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace: Experiment workspace (contains autoresearch/ subdir).
|
||||||
|
timeout: Maximum wall-clock seconds for the run.
|
||||||
|
metric_name: Name of the metric to extract from stdout.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with keys: metric (float|None), log (str), duration_s (int),
|
||||||
|
success (bool), error (str|None).
|
||||||
|
"""
|
||||||
|
repo_dir = Path(workspace) / "autoresearch"
|
||||||
|
train_script = repo_dir / "train.py"
|
||||||
|
|
||||||
|
if not train_script.exists():
|
||||||
|
return {
|
||||||
|
"metric": None,
|
||||||
|
"log": "",
|
||||||
|
"duration_s": 0,
|
||||||
|
"success": False,
|
||||||
|
"error": f"train.py not found in {repo_dir}",
|
||||||
|
}
|
||||||
|
|
||||||
|
start = time.monotonic()
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
["python", str(train_script)],
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
cwd=str(repo_dir),
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
duration = int(time.monotonic() - start)
|
||||||
|
output = result.stdout + result.stderr
|
||||||
|
|
||||||
|
# Extract metric from output
|
||||||
|
metric_val = _extract_metric(output, metric_name)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"metric": metric_val,
|
||||||
|
"log": output[-2000:], # Keep last 2k chars
|
||||||
|
"duration_s": duration,
|
||||||
|
"success": result.returncode == 0,
|
||||||
|
"error": None if result.returncode == 0 else f"Exit code {result.returncode}",
|
||||||
|
}
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
duration = int(time.monotonic() - start)
|
||||||
|
return {
|
||||||
|
"metric": None,
|
||||||
|
"log": f"Experiment timed out after {timeout}s",
|
||||||
|
"duration_s": duration,
|
||||||
|
"success": False,
|
||||||
|
"error": f"Timed out after {timeout}s",
|
||||||
|
}
|
||||||
|
except OSError as exc:
|
||||||
|
return {
|
||||||
|
"metric": None,
|
||||||
|
"log": "",
|
||||||
|
"duration_s": 0,
|
||||||
|
"success": False,
|
||||||
|
"error": str(exc),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_metric(output: str, metric_name: str = "val_bpb") -> Optional[float]:
|
||||||
|
"""Extract the last occurrence of a metric value from training output."""
|
||||||
|
pattern = re.compile(rf"{re.escape(metric_name)}[:\s]+([0-9]+\.?[0-9]*)")
|
||||||
|
matches = pattern.findall(output)
|
||||||
|
if matches:
|
||||||
|
try:
|
||||||
|
return float(matches[-1])
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_result(
    current: float,
    baseline: float,
    metric_name: str = "val_bpb",
    *,
    lower_is_better: bool = True,
) -> str:
    """Compare a metric against baseline and return an assessment.

    By default a decrease counts as an improvement, which suits val_bpb.
    Pass ``lower_is_better=False`` for metrics where higher values are better.

    Args:
        current: Current experiment's metric value.
        baseline: Baseline metric to compare against.
        metric_name: Name of the metric (for display).
        lower_is_better: Whether a decrease is reported as an improvement.

    Returns:
        Human-readable assessment string with the signed relative change.
    """
    delta = current - baseline
    # Divide by the magnitude so the sign of the percentage always follows
    # the direction of the change, even for a negative baseline. A zero
    # baseline has no meaningful relative change and is shown as 0%.
    pct = (delta / abs(baseline)) * 100 if baseline != 0 else 0.0

    if delta < 0:
        verdict = "Improvement" if lower_is_better else "Regression"
    elif delta > 0:
        verdict = "Regression" if lower_is_better else "Improvement"
    else:
        return f"No change: {metric_name} = {current:.4f}"

    return f"{verdict}: {metric_name} {baseline:.4f} -> {current:.4f} ({pct:+.2f}%)"
|
||||||
|
|
||||||
|
|
||||||
|
def get_experiment_history(workspace: Path) -> list[dict[str, Any]]:
    """Read experiment history from the workspace results file.

    Reads ``<workspace>/results.jsonl`` as UTF-8, one JSON object per line.
    Blank lines, lines that are not valid JSON and JSON values that are not
    objects are skipped.

    Args:
        workspace: Experiment workspace containing ``results.jsonl``.

    Returns:
        List of experiment result dicts, most recent first; empty when the
        results file does not exist.
    """
    results_file = Path(workspace) / "results.jsonl"
    try:
        # Read directly instead of checking exists() first, so a file removed
        # between the check and the read cannot raise.
        raw = results_file.read_text(encoding="utf-8")
    except FileNotFoundError:
        return []

    history: list[dict[str, Any]] = []
    for line in raw.splitlines():
        if not line.strip():
            continue
        try:
            entry = json.loads(line)
        except json.JSONDecodeError:
            continue
        # Keep the declared element type: a stray scalar or array is not a result.
        if isinstance(entry, dict):
            history.append(entry)

    history.reverse()
    return history
|
||||||
|
|
||||||
|
|
||||||
|
def _append_result(workspace: Path, result: dict[str, Any]) -> None:
|
||||||
|
"""Append a result to the workspace JSONL log."""
|
||||||
|
results_file = Path(workspace) / "results.jsonl"
|
||||||
|
results_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with results_file.open("a") as f:
|
||||||
|
f.write(json.dumps(result) + "\n")
|
||||||
@@ -24,8 +24,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# HuggingFace model IDs for each supported size.
|
# HuggingFace model IDs for each supported size.
|
||||||
_AIRLLM_MODELS: dict[str, str] = {
|
_AIRLLM_MODELS: dict[str, str] = {
|
||||||
"8b": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
"8b": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||||
"70b": "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
"70b": "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
||||||
"405b": "meta-llama/Meta-Llama-3.1-405B-Instruct",
|
"405b": "meta-llama/Meta-Llama-3.1-405B-Instruct",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -35,6 +35,7 @@ ModelSize = Literal["8b", "70b", "405b"]
|
|||||||
@dataclass
|
@dataclass
|
||||||
class RunResult:
|
class RunResult:
|
||||||
"""Minimal Agno-compatible run result — carries the model's response text."""
|
"""Minimal Agno-compatible run result — carries the model's response text."""
|
||||||
|
|
||||||
content: str
|
content: str
|
||||||
|
|
||||||
|
|
||||||
@@ -47,6 +48,7 @@ def airllm_available() -> bool:
|
|||||||
"""Return True when the airllm package is importable."""
|
"""Return True when the airllm package is importable."""
|
||||||
try:
|
try:
|
||||||
import airllm # noqa: F401
|
import airllm # noqa: F401
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
return False
|
return False
|
||||||
@@ -67,15 +69,16 @@ class TimmyAirLLMAgent:
|
|||||||
model_id = _AIRLLM_MODELS.get(model_size)
|
model_id = _AIRLLM_MODELS.get(model_size)
|
||||||
if model_id is None:
|
if model_id is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown model size {model_size!r}. "
|
f"Unknown model size {model_size!r}. " f"Choose from: {list(_AIRLLM_MODELS)}"
|
||||||
f"Choose from: {list(_AIRLLM_MODELS)}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_apple_silicon():
|
if is_apple_silicon():
|
||||||
from airllm import AirLLMMLX # type: ignore[import]
|
from airllm import AirLLMMLX # type: ignore[import]
|
||||||
|
|
||||||
self._model = AirLLMMLX(model_id)
|
self._model = AirLLMMLX(model_id)
|
||||||
else:
|
else:
|
||||||
from airllm import AutoModel # type: ignore[import]
|
from airllm import AutoModel # type: ignore[import]
|
||||||
|
|
||||||
self._model = AutoModel.from_pretrained(model_id)
|
self._model = AutoModel.from_pretrained(model_id)
|
||||||
|
|
||||||
self._history: list[str] = []
|
self._history: list[str] = []
|
||||||
@@ -137,6 +140,7 @@ class TimmyAirLLMAgent:
|
|||||||
try:
|
try:
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
|
|
||||||
Console().print(Markdown(text))
|
Console().print(Markdown(text))
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print(text)
|
print(text)
|
||||||
@@ -157,6 +161,7 @@ GROK_MODELS: dict[str, str] = {
|
|||||||
@dataclass
|
@dataclass
|
||||||
class GrokUsageStats:
|
class GrokUsageStats:
|
||||||
"""Tracks Grok API usage for cost monitoring and Spark logging."""
|
"""Tracks Grok API usage for cost monitoring and Spark logging."""
|
||||||
|
|
||||||
total_requests: int = 0
|
total_requests: int = 0
|
||||||
total_prompt_tokens: int = 0
|
total_prompt_tokens: int = 0
|
||||||
total_completion_tokens: int = 0
|
total_completion_tokens: int = 0
|
||||||
@@ -240,9 +245,7 @@ class GrokBackend:
|
|||||||
RunResult with response content
|
RunResult with response content
|
||||||
"""
|
"""
|
||||||
if not self._api_key:
|
if not self._api_key:
|
||||||
return RunResult(
|
return RunResult(content="Grok is not configured. Set XAI_API_KEY to enable.")
|
||||||
content="Grok is not configured. Set XAI_API_KEY to enable."
|
|
||||||
)
|
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
messages = self._build_messages(message)
|
messages = self._build_messages(message)
|
||||||
@@ -285,16 +288,12 @@ class GrokBackend:
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
self.stats.errors += 1
|
self.stats.errors += 1
|
||||||
logger.error("Grok API error: %s", exc)
|
logger.error("Grok API error: %s", exc)
|
||||||
return RunResult(
|
return RunResult(content=f"Grok temporarily unavailable: {exc}")
|
||||||
content=f"Grok temporarily unavailable: {exc}"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def arun(self, message: str) -> RunResult:
|
async def arun(self, message: str) -> RunResult:
|
||||||
"""Async inference via Grok API — used by cascade router and tools."""
|
"""Async inference via Grok API — used by cascade router and tools."""
|
||||||
if not self._api_key:
|
if not self._api_key:
|
||||||
return RunResult(
|
return RunResult(content="Grok is not configured. Set XAI_API_KEY to enable.")
|
||||||
content="Grok is not configured. Set XAI_API_KEY to enable."
|
|
||||||
)
|
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
messages = self._build_messages(message)
|
messages = self._build_messages(message)
|
||||||
@@ -336,9 +335,7 @@ class GrokBackend:
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
self.stats.errors += 1
|
self.stats.errors += 1
|
||||||
logger.error("Grok async API error: %s", exc)
|
logger.error("Grok async API error: %s", exc)
|
||||||
return RunResult(
|
return RunResult(content=f"Grok temporarily unavailable: {exc}")
|
||||||
content=f"Grok temporarily unavailable: {exc}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def print_response(self, message: str, *, stream: bool = True) -> None:
|
def print_response(self, message: str, *, stream: bool = True) -> None:
|
||||||
"""Run inference and render the response to stdout (CLI interface)."""
|
"""Run inference and render the response to stdout (CLI interface)."""
|
||||||
@@ -346,6 +343,7 @@ class GrokBackend:
|
|||||||
try:
|
try:
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
|
|
||||||
Console().print(Markdown(result.content))
|
Console().print(Markdown(result.content))
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print(result.content)
|
print(result.content)
|
||||||
@@ -415,6 +413,7 @@ def grok_available() -> bool:
|
|||||||
"""Return True when Grok is enabled and API key is configured."""
|
"""Return True when Grok is enabled and API key is configured."""
|
||||||
try:
|
try:
|
||||||
from config import settings
|
from config import settings
|
||||||
|
|
||||||
return settings.grok_enabled and bool(settings.xai_api_key)
|
return settings.grok_enabled and bool(settings.xai_api_key)
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
@@ -472,9 +471,7 @@ class ClaudeBackend:
|
|||||||
def run(self, message: str, *, stream: bool = False, **kwargs) -> RunResult:
|
def run(self, message: str, *, stream: bool = False, **kwargs) -> RunResult:
|
||||||
"""Synchronous inference via Claude API."""
|
"""Synchronous inference via Claude API."""
|
||||||
if not self._api_key:
|
if not self._api_key:
|
||||||
return RunResult(
|
return RunResult(content="Claude is not configured. Set ANTHROPIC_API_KEY to enable.")
|
||||||
content="Claude is not configured. Set ANTHROPIC_API_KEY to enable."
|
|
||||||
)
|
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
messages = self._build_messages(message)
|
messages = self._build_messages(message)
|
||||||
@@ -508,9 +505,7 @@ class ClaudeBackend:
|
|||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("Claude API error: %s", exc)
|
logger.error("Claude API error: %s", exc)
|
||||||
return RunResult(
|
return RunResult(content=f"Claude temporarily unavailable: {exc}")
|
||||||
content=f"Claude temporarily unavailable: {exc}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def print_response(self, message: str, *, stream: bool = True) -> None:
|
def print_response(self, message: str, *, stream: bool = True) -> None:
|
||||||
"""Run inference and render the response to stdout (CLI interface)."""
|
"""Run inference and render the response to stdout (CLI interface)."""
|
||||||
@@ -518,6 +513,7 @@ class ClaudeBackend:
|
|||||||
try:
|
try:
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
|
|
||||||
Console().print(Markdown(result.content))
|
Console().print(Markdown(result.content))
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print(result.content)
|
print(result.content)
|
||||||
@@ -569,6 +565,7 @@ def claude_available() -> bool:
|
|||||||
"""Return True when Anthropic API key is configured."""
|
"""Return True when Anthropic API key is configured."""
|
||||||
try:
|
try:
|
||||||
from config import settings
|
from config import settings
|
||||||
|
|
||||||
return bool(settings.anthropic_api_key)
|
return bool(settings.anthropic_api_key)
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ _CACHE_MINUTES = 30
|
|||||||
# Data structures
|
# Data structures
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ApprovalItem:
|
class ApprovalItem:
|
||||||
"""Lightweight representation used inside a Briefing.
|
"""Lightweight representation used inside a Briefing.
|
||||||
@@ -32,6 +33,7 @@ class ApprovalItem:
|
|||||||
The canonical mutable version (with persistence) lives in timmy.approvals.
|
The canonical mutable version (with persistence) lives in timmy.approvals.
|
||||||
This one travels with the Briefing dataclass as a read-only snapshot.
|
This one travels with the Briefing dataclass as a read-only snapshot.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
title: str
|
title: str
|
||||||
description: str
|
description: str
|
||||||
@@ -44,20 +46,19 @@ class ApprovalItem:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class Briefing:
|
class Briefing:
|
||||||
generated_at: datetime
|
generated_at: datetime
|
||||||
summary: str # 150-300 words
|
summary: str # 150-300 words
|
||||||
approval_items: list[ApprovalItem] = field(default_factory=list)
|
approval_items: list[ApprovalItem] = field(default_factory=list)
|
||||||
period_start: datetime = field(
|
period_start: datetime = field(
|
||||||
default_factory=lambda: datetime.now(timezone.utc) - timedelta(hours=6)
|
default_factory=lambda: datetime.now(timezone.utc) - timedelta(hours=6)
|
||||||
)
|
)
|
||||||
period_end: datetime = field(
|
period_end: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
default_factory=lambda: datetime.now(timezone.utc)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# SQLite cache
|
# SQLite cache
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _get_cache_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection:
|
def _get_cache_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection:
|
||||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
conn = sqlite3.connect(str(db_path))
|
conn = sqlite3.connect(str(db_path))
|
||||||
@@ -98,9 +99,7 @@ def _save_briefing(briefing: Briefing, db_path: Path = _DEFAULT_DB) -> None:
|
|||||||
def _load_latest(db_path: Path = _DEFAULT_DB) -> Optional[Briefing]:
|
def _load_latest(db_path: Path = _DEFAULT_DB) -> Optional[Briefing]:
|
||||||
"""Load the most-recently cached briefing, or None if there is none."""
|
"""Load the most-recently cached briefing, or None if there is none."""
|
||||||
conn = _get_cache_conn(db_path)
|
conn = _get_cache_conn(db_path)
|
||||||
row = conn.execute(
|
row = conn.execute("SELECT * FROM briefings ORDER BY generated_at DESC LIMIT 1").fetchone()
|
||||||
"SELECT * FROM briefings ORDER BY generated_at DESC LIMIT 1"
|
|
||||||
).fetchone()
|
|
||||||
conn.close()
|
conn.close()
|
||||||
if row is None:
|
if row is None:
|
||||||
return None
|
return None
|
||||||
@@ -115,7 +114,11 @@ def _load_latest(db_path: Path = _DEFAULT_DB) -> Optional[Briefing]:
|
|||||||
def is_fresh(briefing: Briefing, max_age_minutes: int = _CACHE_MINUTES) -> bool:
    """Tell whether a briefing is younger than ``max_age_minutes``.

    A ``generated_at`` without tzinfo is interpreted as UTC before its age is
    measured against the current UTC time.
    """
    current_time = datetime.now(timezone.utc)
    stamp = briefing.generated_at
    if stamp.tzinfo is None:
        # Naive timestamps cannot be subtracted from an aware "now".
        stamp = stamp.replace(tzinfo=timezone.utc)
    elapsed_seconds = (current_time - stamp).total_seconds()
    return elapsed_seconds < max_age_minutes * 60
|
||||||
|
|
||||||
|
|
||||||
@@ -123,6 +126,7 @@ def is_fresh(briefing: Briefing, max_age_minutes: int = _CACHE_MINUTES) -> bool:
|
|||||||
# Activity gathering helpers
|
# Activity gathering helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _gather_swarm_summary(since: datetime) -> str:
|
def _gather_swarm_summary(since: datetime) -> str:
|
||||||
"""Pull recent task/agent stats from swarm.db. Graceful if DB missing."""
|
"""Pull recent task/agent stats from swarm.db. Graceful if DB missing."""
|
||||||
swarm_db = Path("data/swarm.db")
|
swarm_db = Path("data/swarm.db")
|
||||||
@@ -170,6 +174,7 @@ def _gather_task_queue_summary() -> str:
|
|||||||
"""Pull task queue stats for the briefing. Graceful if unavailable."""
|
"""Pull task queue stats for the briefing. Graceful if unavailable."""
|
||||||
try:
|
try:
|
||||||
from swarm.task_queue.models import get_task_summary_for_briefing
|
from swarm.task_queue.models import get_task_summary_for_briefing
|
||||||
|
|
||||||
stats = get_task_summary_for_briefing()
|
stats = get_task_summary_for_briefing()
|
||||||
parts = []
|
parts = []
|
||||||
if stats["pending_approval"]:
|
if stats["pending_approval"]:
|
||||||
@@ -194,6 +199,7 @@ def _gather_chat_summary(since: datetime) -> str:
|
|||||||
"""Pull recent chat messages from the in-memory log."""
|
"""Pull recent chat messages from the in-memory log."""
|
||||||
try:
|
try:
|
||||||
from dashboard.store import message_log
|
from dashboard.store import message_log
|
||||||
|
|
||||||
messages = message_log.all()
|
messages = message_log.all()
|
||||||
# Filter to messages in the briefing window (best-effort: no timestamps)
|
# Filter to messages in the briefing window (best-effort: no timestamps)
|
||||||
recent = messages[-10:] if len(messages) > 10 else messages
|
recent = messages[-10:] if len(messages) > 10 else messages
|
||||||
@@ -213,6 +219,7 @@ def _gather_chat_summary(since: datetime) -> str:
|
|||||||
# BriefingEngine
|
# BriefingEngine
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class BriefingEngine:
|
class BriefingEngine:
|
||||||
"""Generates morning briefings by querying activity and asking Timmy."""
|
"""Generates morning briefings by querying activity and asking Timmy."""
|
||||||
|
|
||||||
@@ -297,6 +304,7 @@ class BriefingEngine:
|
|||||||
"""Call Timmy's Agno agent and return the response text."""
|
"""Call Timmy's Agno agent and return the response text."""
|
||||||
try:
|
try:
|
||||||
from timmy.agent import create_timmy
|
from timmy.agent import create_timmy
|
||||||
|
|
||||||
agent = create_timmy()
|
agent = create_timmy()
|
||||||
run = agent.run(prompt, stream=False)
|
run = agent.run(prompt, stream=False)
|
||||||
result = run.content if hasattr(run, "content") else str(run)
|
result = run.content if hasattr(run, "content") else str(run)
|
||||||
@@ -317,6 +325,7 @@ class BriefingEngine:
|
|||||||
"""Return pending ApprovalItems from the approvals DB."""
|
"""Return pending ApprovalItems from the approvals DB."""
|
||||||
try:
|
try:
|
||||||
from timmy import approvals as _approvals
|
from timmy import approvals as _approvals
|
||||||
|
|
||||||
raw_items = _approvals.list_pending()
|
raw_items = _approvals.list_pending()
|
||||||
return [
|
return [
|
||||||
ApprovalItem(
|
ApprovalItem(
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ logger = logging.getLogger(__name__)
|
|||||||
@dataclass
|
@dataclass
|
||||||
class TimmyResponse:
|
class TimmyResponse:
|
||||||
"""Response from Timmy via Cascade Router."""
|
"""Response from Timmy via Cascade Router."""
|
||||||
|
|
||||||
content: str
|
content: str
|
||||||
provider_used: str
|
provider_used: str
|
||||||
latency_ms: float
|
latency_ms: float
|
||||||
@@ -27,31 +28,30 @@ class TimmyResponse:
|
|||||||
|
|
||||||
class TimmyCascadeAdapter:
|
class TimmyCascadeAdapter:
|
||||||
"""Adapter that routes Timmy requests through Cascade Router.
|
"""Adapter that routes Timmy requests through Cascade Router.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
adapter = TimmyCascadeAdapter()
|
adapter = TimmyCascadeAdapter()
|
||||||
response = await adapter.chat("Hello")
|
response = await adapter.chat("Hello")
|
||||||
print(f"Response: {response.content}")
|
print(f"Response: {response.content}")
|
||||||
print(f"Provider: {response.provider_used}")
|
print(f"Provider: {response.provider_used}")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, router: Optional[CascadeRouter] = None) -> None:
|
def __init__(self, router: Optional[CascadeRouter] = None) -> None:
|
||||||
"""Initialize adapter with Cascade Router.
|
"""Initialize adapter with Cascade Router.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
router: CascadeRouter instance. If None, creates default.
|
router: CascadeRouter instance. If None, creates default.
|
||||||
"""
|
"""
|
||||||
self.router = router or CascadeRouter()
|
self.router = router or CascadeRouter()
|
||||||
logger.info("TimmyCascadeAdapter initialized with %d providers",
|
logger.info("TimmyCascadeAdapter initialized with %d providers", len(self.router.providers))
|
||||||
len(self.router.providers))
|
|
||||||
|
|
||||||
async def chat(self, message: str, context: Optional[str] = None) -> TimmyResponse:
|
async def chat(self, message: str, context: Optional[str] = None) -> TimmyResponse:
|
||||||
"""Send message through cascade router with automatic failover.
|
"""Send message through cascade router with automatic failover.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message: User message
|
message: User message
|
||||||
context: Optional conversation context
|
context: Optional conversation context
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
TimmyResponse with content and metadata
|
TimmyResponse with content and metadata
|
||||||
"""
|
"""
|
||||||
@@ -60,37 +60,38 @@ class TimmyCascadeAdapter:
|
|||||||
if context:
|
if context:
|
||||||
messages.append({"role": "system", "content": context})
|
messages.append({"role": "system", "content": context})
|
||||||
messages.append({"role": "user", "content": message})
|
messages.append({"role": "user", "content": message})
|
||||||
|
|
||||||
# Route through cascade
|
# Route through cascade
|
||||||
import time
|
import time
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await self.router.complete(
|
result = await self.router.complete(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
system_prompt=SYSTEM_PROMPT,
|
system_prompt=SYSTEM_PROMPT,
|
||||||
)
|
)
|
||||||
|
|
||||||
latency = (time.time() - start) * 1000
|
latency = (time.time() - start) * 1000
|
||||||
|
|
||||||
# Determine if fallback was used
|
# Determine if fallback was used
|
||||||
primary = self.router.providers[0] if self.router.providers else None
|
primary = self.router.providers[0] if self.router.providers else None
|
||||||
fallback_used = primary and primary.status.value != "healthy"
|
fallback_used = primary and primary.status.value != "healthy"
|
||||||
|
|
||||||
return TimmyResponse(
|
return TimmyResponse(
|
||||||
content=result.content,
|
content=result.content,
|
||||||
provider_used=result.provider_name,
|
provider_used=result.provider_name,
|
||||||
latency_ms=latency,
|
latency_ms=latency,
|
||||||
fallback_used=fallback_used,
|
fallback_used=fallback_used,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("All providers failed: %s", exc)
|
logger.error("All providers failed: %s", exc)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def get_provider_status(self) -> list[dict]:
|
def get_provider_status(self) -> list[dict]:
|
||||||
"""Get status of all providers.
|
"""Get status of all providers.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of provider status dicts
|
List of provider status dicts
|
||||||
"""
|
"""
|
||||||
@@ -112,10 +113,10 @@ class TimmyCascadeAdapter:
|
|||||||
}
|
}
|
||||||
for p in self.router.providers
|
for p in self.router.providers
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_preferred_provider(self) -> Optional[str]:
|
def get_preferred_provider(self) -> Optional[str]:
|
||||||
"""Get name of highest-priority healthy provider.
|
"""Get name of highest-priority healthy provider.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Provider name or None if all unhealthy
|
Provider name or None if all unhealthy
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -17,22 +17,23 @@ logger = logging.getLogger(__name__)
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ConversationContext:
|
class ConversationContext:
|
||||||
"""Tracks the current conversation state."""
|
"""Tracks the current conversation state."""
|
||||||
|
|
||||||
user_name: Optional[str] = None
|
user_name: Optional[str] = None
|
||||||
current_topic: Optional[str] = None
|
current_topic: Optional[str] = None
|
||||||
last_intent: Optional[str] = None
|
last_intent: Optional[str] = None
|
||||||
turn_count: int = 0
|
turn_count: int = 0
|
||||||
started_at: datetime = field(default_factory=datetime.now)
|
started_at: datetime = field(default_factory=datetime.now)
|
||||||
|
|
||||||
def update_topic(self, topic: str) -> None:
|
def update_topic(self, topic: str) -> None:
|
||||||
"""Update the current conversation topic."""
|
"""Update the current conversation topic."""
|
||||||
self.current_topic = topic
|
self.current_topic = topic
|
||||||
self.turn_count += 1
|
self.turn_count += 1
|
||||||
|
|
||||||
def set_user_name(self, name: str) -> None:
|
def set_user_name(self, name: str) -> None:
|
||||||
"""Remember the user's name."""
|
"""Remember the user's name."""
|
||||||
self.user_name = name
|
self.user_name = name
|
||||||
logger.info("User name set to: %s", name)
|
logger.info("User name set to: %s", name)
|
||||||
|
|
||||||
def get_context_summary(self) -> str:
|
def get_context_summary(self) -> str:
|
||||||
"""Generate a context summary for the prompt."""
|
"""Generate a context summary for the prompt."""
|
||||||
parts = []
|
parts = []
|
||||||
@@ -47,35 +48,88 @@ class ConversationContext:
|
|||||||
|
|
||||||
class ConversationManager:
|
class ConversationManager:
|
||||||
"""Manages conversation context across sessions."""
|
"""Manages conversation context across sessions."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._contexts: dict[str, ConversationContext] = {}
|
self._contexts: dict[str, ConversationContext] = {}
|
||||||
|
|
||||||
def get_context(self, session_id: str) -> ConversationContext:
|
def get_context(self, session_id: str) -> ConversationContext:
|
||||||
"""Get or create context for a session."""
|
"""Get or create context for a session."""
|
||||||
if session_id not in self._contexts:
|
if session_id not in self._contexts:
|
||||||
self._contexts[session_id] = ConversationContext()
|
self._contexts[session_id] = ConversationContext()
|
||||||
return self._contexts[session_id]
|
return self._contexts[session_id]
|
||||||
|
|
||||||
def clear_context(self, session_id: str) -> None:
|
def clear_context(self, session_id: str) -> None:
|
||||||
"""Clear context for a session."""
|
"""Clear context for a session."""
|
||||||
if session_id in self._contexts:
|
if session_id in self._contexts:
|
||||||
del self._contexts[session_id]
|
del self._contexts[session_id]
|
||||||
|
|
||||||
# Words that look like names but are actually verbs/UI states
|
# Words that look like names but are actually verbs/UI states
|
||||||
_NAME_BLOCKLIST = frozenset({
|
_NAME_BLOCKLIST = frozenset(
|
||||||
"sending", "loading", "pending", "processing", "typing",
|
{
|
||||||
"working", "going", "trying", "looking", "getting", "doing",
|
"sending",
|
||||||
"waiting", "running", "checking", "coming", "leaving",
|
"loading",
|
||||||
"thinking", "reading", "writing", "watching", "listening",
|
"pending",
|
||||||
"playing", "eating", "sleeping", "sitting", "standing",
|
"processing",
|
||||||
"walking", "talking", "asking", "telling", "feeling",
|
"typing",
|
||||||
"hoping", "wondering", "glad", "happy", "sorry", "sure",
|
"working",
|
||||||
"fine", "good", "great", "okay", "here", "there", "back",
|
"going",
|
||||||
"done", "ready", "busy", "free", "available", "interested",
|
"trying",
|
||||||
"confused", "lost", "stuck", "curious", "excited", "tired",
|
"looking",
|
||||||
"not", "also", "just", "still", "already", "currently",
|
"getting",
|
||||||
})
|
"doing",
|
||||||
|
"waiting",
|
||||||
|
"running",
|
||||||
|
"checking",
|
||||||
|
"coming",
|
||||||
|
"leaving",
|
||||||
|
"thinking",
|
||||||
|
"reading",
|
||||||
|
"writing",
|
||||||
|
"watching",
|
||||||
|
"listening",
|
||||||
|
"playing",
|
||||||
|
"eating",
|
||||||
|
"sleeping",
|
||||||
|
"sitting",
|
||||||
|
"standing",
|
||||||
|
"walking",
|
||||||
|
"talking",
|
||||||
|
"asking",
|
||||||
|
"telling",
|
||||||
|
"feeling",
|
||||||
|
"hoping",
|
||||||
|
"wondering",
|
||||||
|
"glad",
|
||||||
|
"happy",
|
||||||
|
"sorry",
|
||||||
|
"sure",
|
||||||
|
"fine",
|
||||||
|
"good",
|
||||||
|
"great",
|
||||||
|
"okay",
|
||||||
|
"here",
|
||||||
|
"there",
|
||||||
|
"back",
|
||||||
|
"done",
|
||||||
|
"ready",
|
||||||
|
"busy",
|
||||||
|
"free",
|
||||||
|
"available",
|
||||||
|
"interested",
|
||||||
|
"confused",
|
||||||
|
"lost",
|
||||||
|
"stuck",
|
||||||
|
"curious",
|
||||||
|
"excited",
|
||||||
|
"tired",
|
||||||
|
"not",
|
||||||
|
"also",
|
||||||
|
"just",
|
||||||
|
"still",
|
||||||
|
"already",
|
||||||
|
"currently",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def extract_user_name(self, message: str) -> Optional[str]:
|
def extract_user_name(self, message: str) -> Optional[str]:
|
||||||
"""Try to extract user's name from message."""
|
"""Try to extract user's name from message."""
|
||||||
@@ -106,40 +160,66 @@ class ConversationManager:
|
|||||||
return name.capitalize()
|
return name.capitalize()
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def should_use_tools(self, message: str, context: ConversationContext) -> bool:
|
def should_use_tools(self, message: str, context: ConversationContext) -> bool:
|
||||||
"""Determine if this message likely requires tools.
|
"""Determine if this message likely requires tools.
|
||||||
|
|
||||||
Returns True if tools are likely needed, False for simple chat.
|
Returns True if tools are likely needed, False for simple chat.
|
||||||
"""
|
"""
|
||||||
message_lower = message.lower().strip()
|
message_lower = message.lower().strip()
|
||||||
|
|
||||||
# Tool keywords that suggest tool usage is needed
|
# Tool keywords that suggest tool usage is needed
|
||||||
tool_keywords = [
|
tool_keywords = [
|
||||||
"search", "look up", "find", "google", "current price",
|
"search",
|
||||||
"latest", "today's", "news", "weather", "stock price",
|
"look up",
|
||||||
"read file", "write file", "save", "calculate", "compute",
|
"find",
|
||||||
"run ", "execute", "shell", "command", "install",
|
"google",
|
||||||
|
"current price",
|
||||||
|
"latest",
|
||||||
|
"today's",
|
||||||
|
"news",
|
||||||
|
"weather",
|
||||||
|
"stock price",
|
||||||
|
"read file",
|
||||||
|
"write file",
|
||||||
|
"save",
|
||||||
|
"calculate",
|
||||||
|
"compute",
|
||||||
|
"run ",
|
||||||
|
"execute",
|
||||||
|
"shell",
|
||||||
|
"command",
|
||||||
|
"install",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Chat-only keywords that definitely don't need tools
|
# Chat-only keywords that definitely don't need tools
|
||||||
chat_only = [
|
chat_only = [
|
||||||
"hello", "hi ", "hey", "how are you", "what's up",
|
"hello",
|
||||||
"your name", "who are you", "what are you",
|
"hi ",
|
||||||
"thanks", "thank you", "bye", "goodbye",
|
"hey",
|
||||||
"tell me about yourself", "what can you do",
|
"how are you",
|
||||||
|
"what's up",
|
||||||
|
"your name",
|
||||||
|
"who are you",
|
||||||
|
"what are you",
|
||||||
|
"thanks",
|
||||||
|
"thank you",
|
||||||
|
"bye",
|
||||||
|
"goodbye",
|
||||||
|
"tell me about yourself",
|
||||||
|
"what can you do",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Check for chat-only patterns first
|
# Check for chat-only patterns first
|
||||||
for pattern in chat_only:
|
for pattern in chat_only:
|
||||||
if pattern in message_lower:
|
if pattern in message_lower:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check for tool keywords
|
# Check for tool keywords
|
||||||
for keyword in tool_keywords:
|
for keyword in tool_keywords:
|
||||||
if keyword in message_lower:
|
if keyword in message_lower:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Simple questions (starting with what, who, how, why, when, where)
|
# Simple questions (starting with what, who, how, why, when, where)
|
||||||
# usually don't need tools unless about current/real-time info
|
# usually don't need tools unless about current/real-time info
|
||||||
simple_question_words = ["what is", "who is", "how does", "why is", "when did", "where is"]
|
simple_question_words = ["what is", "who is", "how does", "why is", "when did", "where is"]
|
||||||
@@ -150,7 +230,7 @@ class ConversationManager:
|
|||||||
if any(t in message_lower for t in time_words):
|
if any(t in message_lower for t in time_words):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Default: don't use tools for unclear cases
|
# Default: don't use tools for unclear cases
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@@ -25,11 +25,12 @@ def _get_model():
|
|||||||
global _model, _has_embeddings
|
global _model, _has_embeddings
|
||||||
if _has_embeddings is False:
|
if _has_embeddings is False:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if _model is not None:
|
if _model is not None:
|
||||||
return _model
|
return _model
|
||||||
|
|
||||||
from config import settings
|
from config import settings
|
||||||
|
|
||||||
# In test mode or low-memory environments, skip embedding model load
|
# In test mode or low-memory environments, skip embedding model load
|
||||||
if settings.timmy_skip_embeddings:
|
if settings.timmy_skip_embeddings:
|
||||||
_has_embeddings = False
|
_has_embeddings = False
|
||||||
@@ -37,7 +38,8 @@ def _get_model():
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
_model = SentenceTransformer('all-MiniLM-L6-v2')
|
|
||||||
|
_model = SentenceTransformer("all-MiniLM-L6-v2")
|
||||||
_has_embeddings = True
|
_has_embeddings = True
|
||||||
return _model
|
return _model
|
||||||
except (ImportError, RuntimeError, Exception):
|
except (ImportError, RuntimeError, Exception):
|
||||||
@@ -56,7 +58,7 @@ def _get_embedding_dimension() -> int:
|
|||||||
|
|
||||||
def _compute_embedding(text: str) -> list[float]:
|
def _compute_embedding(text: str) -> list[float]:
|
||||||
"""Compute embedding vector for text.
|
"""Compute embedding vector for text.
|
||||||
|
|
||||||
Uses sentence-transformers if available, otherwise returns
|
Uses sentence-transformers if available, otherwise returns
|
||||||
a simple hash-based vector for basic similarity.
|
a simple hash-based vector for basic similarity.
|
||||||
"""
|
"""
|
||||||
@@ -66,30 +68,31 @@ def _compute_embedding(text: str) -> list[float]:
|
|||||||
return model.encode(text).tolist()
|
return model.encode(text).tolist()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Fallback: simple character n-gram hash embedding
|
# Fallback: simple character n-gram hash embedding
|
||||||
# Not as good but allows the system to work without heavy deps
|
# Not as good but allows the system to work without heavy deps
|
||||||
dim = 384
|
dim = 384
|
||||||
vec = [0.0] * dim
|
vec = [0.0] * dim
|
||||||
text = text.lower()
|
text = text.lower()
|
||||||
|
|
||||||
# Generate character trigram features
|
# Generate character trigram features
|
||||||
for i in range(len(text) - 2):
|
for i in range(len(text) - 2):
|
||||||
trigram = text[i:i+3]
|
trigram = text[i : i + 3]
|
||||||
hash_val = hash(trigram) % dim
|
hash_val = hash(trigram) % dim
|
||||||
vec[hash_val] += 1.0
|
vec[hash_val] += 1.0
|
||||||
|
|
||||||
# Normalize
|
# Normalize
|
||||||
norm = sum(x*x for x in vec) ** 0.5
|
norm = sum(x * x for x in vec) ** 0.5
|
||||||
if norm > 0:
|
if norm > 0:
|
||||||
vec = [x/norm for x in vec]
|
vec = [x / norm for x in vec]
|
||||||
|
|
||||||
return vec
|
return vec
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MemoryEntry:
|
class MemoryEntry:
|
||||||
"""A memory entry with vector embedding."""
|
"""A memory entry with vector embedding."""
|
||||||
|
|
||||||
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||||
content: str = "" # The actual text content
|
content: str = "" # The actual text content
|
||||||
source: str = "" # Where it came from (agent, user, system)
|
source: str = "" # Where it came from (agent, user, system)
|
||||||
@@ -99,9 +102,7 @@ class MemoryEntry:
|
|||||||
session_id: Optional[str] = None
|
session_id: Optional[str] = None
|
||||||
metadata: Optional[dict] = None
|
metadata: Optional[dict] = None
|
||||||
embedding: Optional[list[float]] = None
|
embedding: Optional[list[float]] = None
|
||||||
timestamp: str = field(
|
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
|
||||||
)
|
|
||||||
relevance_score: Optional[float] = None # Set during search
|
relevance_score: Optional[float] = None # Set during search
|
||||||
|
|
||||||
|
|
||||||
@@ -110,7 +111,7 @@ def _get_conn() -> sqlite3.Connection:
|
|||||||
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||||
conn = sqlite3.connect(str(DB_PATH))
|
conn = sqlite3.connect(str(DB_PATH))
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
|
|
||||||
# Try to load sqlite-vss extension
|
# Try to load sqlite-vss extension
|
||||||
try:
|
try:
|
||||||
conn.enable_load_extension(True)
|
conn.enable_load_extension(True)
|
||||||
@@ -119,7 +120,7 @@ def _get_conn() -> sqlite3.Connection:
|
|||||||
_has_vss = True
|
_has_vss = True
|
||||||
except Exception:
|
except Exception:
|
||||||
_has_vss = False
|
_has_vss = False
|
||||||
|
|
||||||
# Create tables
|
# Create tables
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"""
|
"""
|
||||||
@@ -137,24 +138,14 @@ def _get_conn() -> sqlite3.Connection:
|
|||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create indexes
|
# Create indexes
|
||||||
conn.execute(
|
conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_agent ON memory_entries(agent_id)")
|
||||||
"CREATE INDEX IF NOT EXISTS idx_memory_agent ON memory_entries(agent_id)"
|
conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_task ON memory_entries(task_id)")
|
||||||
)
|
conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_session ON memory_entries(session_id)")
|
||||||
conn.execute(
|
conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_time ON memory_entries(timestamp)")
|
||||||
"CREATE INDEX IF NOT EXISTS idx_memory_task ON memory_entries(task_id)"
|
conn.execute("CREATE INDEX IF NOT EXISTS idx_memory_type ON memory_entries(context_type)")
|
||||||
)
|
|
||||||
conn.execute(
|
|
||||||
"CREATE INDEX IF NOT EXISTS idx_memory_session ON memory_entries(session_id)"
|
|
||||||
)
|
|
||||||
conn.execute(
|
|
||||||
"CREATE INDEX IF NOT EXISTS idx_memory_time ON memory_entries(timestamp)"
|
|
||||||
)
|
|
||||||
conn.execute(
|
|
||||||
"CREATE INDEX IF NOT EXISTS idx_memory_type ON memory_entries(context_type)"
|
|
||||||
)
|
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
@@ -170,7 +161,7 @@ def store_memory(
|
|||||||
compute_embedding: bool = True,
|
compute_embedding: bool = True,
|
||||||
) -> MemoryEntry:
|
) -> MemoryEntry:
|
||||||
"""Store a memory entry with optional embedding.
|
"""Store a memory entry with optional embedding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: The text content to store
|
content: The text content to store
|
||||||
source: Source of the memory (agent name, user, system)
|
source: Source of the memory (agent name, user, system)
|
||||||
@@ -180,14 +171,14 @@ def store_memory(
|
|||||||
session_id: Session identifier
|
session_id: Session identifier
|
||||||
metadata: Additional structured data
|
metadata: Additional structured data
|
||||||
compute_embedding: Whether to compute vector embedding
|
compute_embedding: Whether to compute vector embedding
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The stored MemoryEntry
|
The stored MemoryEntry
|
||||||
"""
|
"""
|
||||||
embedding = None
|
embedding = None
|
||||||
if compute_embedding:
|
if compute_embedding:
|
||||||
embedding = _compute_embedding(content)
|
embedding = _compute_embedding(content)
|
||||||
|
|
||||||
entry = MemoryEntry(
|
entry = MemoryEntry(
|
||||||
content=content,
|
content=content,
|
||||||
source=source,
|
source=source,
|
||||||
@@ -198,7 +189,7 @@ def store_memory(
|
|||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
embedding=embedding,
|
embedding=embedding,
|
||||||
)
|
)
|
||||||
|
|
||||||
conn = _get_conn()
|
conn = _get_conn()
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"""
|
"""
|
||||||
@@ -222,7 +213,7 @@ def store_memory(
|
|||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
return entry
|
return entry
|
||||||
|
|
||||||
|
|
||||||
@@ -235,7 +226,7 @@ def search_memories(
|
|||||||
min_relevance: float = 0.0,
|
min_relevance: float = 0.0,
|
||||||
) -> list[MemoryEntry]:
|
) -> list[MemoryEntry]:
|
||||||
"""Search for memories by semantic similarity.
|
"""Search for memories by semantic similarity.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: Search query text
|
query: Search query text
|
||||||
limit: Maximum results
|
limit: Maximum results
|
||||||
@@ -243,18 +234,18 @@ def search_memories(
|
|||||||
agent_id: Filter by agent
|
agent_id: Filter by agent
|
||||||
session_id: Filter by session
|
session_id: Filter by session
|
||||||
min_relevance: Minimum similarity score (0-1)
|
min_relevance: Minimum similarity score (0-1)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of MemoryEntry objects sorted by relevance
|
List of MemoryEntry objects sorted by relevance
|
||||||
"""
|
"""
|
||||||
query_embedding = _compute_embedding(query)
|
query_embedding = _compute_embedding(query)
|
||||||
|
|
||||||
conn = _get_conn()
|
conn = _get_conn()
|
||||||
|
|
||||||
# Build query with filters
|
# Build query with filters
|
||||||
conditions = []
|
conditions = []
|
||||||
params = []
|
params = []
|
||||||
|
|
||||||
if context_type:
|
if context_type:
|
||||||
conditions.append("context_type = ?")
|
conditions.append("context_type = ?")
|
||||||
params.append(context_type)
|
params.append(context_type)
|
||||||
@@ -264,9 +255,9 @@ def search_memories(
|
|||||||
if session_id:
|
if session_id:
|
||||||
conditions.append("session_id = ?")
|
conditions.append("session_id = ?")
|
||||||
params.append(session_id)
|
params.append(session_id)
|
||||||
|
|
||||||
where_clause = "WHERE " + " AND ".join(conditions) if conditions else ""
|
where_clause = "WHERE " + " AND ".join(conditions) if conditions else ""
|
||||||
|
|
||||||
# Fetch candidates (we'll do in-memory similarity for now)
|
# Fetch candidates (we'll do in-memory similarity for now)
|
||||||
# For production with sqlite-vss, this would use vector similarity index
|
# For production with sqlite-vss, this would use vector similarity index
|
||||||
query_sql = f"""
|
query_sql = f"""
|
||||||
@@ -276,10 +267,10 @@ def search_memories(
|
|||||||
LIMIT ?
|
LIMIT ?
|
||||||
"""
|
"""
|
||||||
params.append(limit * 3) # Get more candidates for ranking
|
params.append(limit * 3) # Get more candidates for ranking
|
||||||
|
|
||||||
rows = conn.execute(query_sql, params).fetchall()
|
rows = conn.execute(query_sql, params).fetchall()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
# Compute similarity scores
|
# Compute similarity scores
|
||||||
results = []
|
results = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
@@ -295,7 +286,7 @@ def search_memories(
|
|||||||
embedding=json.loads(row["embedding"]) if row["embedding"] else None,
|
embedding=json.loads(row["embedding"]) if row["embedding"] else None,
|
||||||
timestamp=row["timestamp"],
|
timestamp=row["timestamp"],
|
||||||
)
|
)
|
||||||
|
|
||||||
if entry.embedding:
|
if entry.embedding:
|
||||||
# Cosine similarity
|
# Cosine similarity
|
||||||
score = _cosine_similarity(query_embedding, entry.embedding)
|
score = _cosine_similarity(query_embedding, entry.embedding)
|
||||||
@@ -308,7 +299,7 @@ def search_memories(
|
|||||||
entry.relevance_score = score
|
entry.relevance_score = score
|
||||||
if score >= min_relevance:
|
if score >= min_relevance:
|
||||||
results.append(entry)
|
results.append(entry)
|
||||||
|
|
||||||
# Sort by relevance and return top results
|
# Sort by relevance and return top results
|
||||||
results.sort(key=lambda x: x.relevance_score or 0, reverse=True)
|
results.sort(key=lambda x: x.relevance_score or 0, reverse=True)
|
||||||
return results[:limit]
|
return results[:limit]
|
||||||
@@ -316,9 +307,9 @@ def search_memories(
|
|||||||
|
|
||||||
def _cosine_similarity(a: list[float], b: list[float]) -> float:
|
def _cosine_similarity(a: list[float], b: list[float]) -> float:
|
||||||
"""Compute cosine similarity between two vectors."""
|
"""Compute cosine similarity between two vectors."""
|
||||||
dot = sum(x*y for x, y in zip(a, b))
|
dot = sum(x * y for x, y in zip(a, b))
|
||||||
norm_a = sum(x*x for x in a) ** 0.5
|
norm_a = sum(x * x for x in a) ** 0.5
|
||||||
norm_b = sum(x*x for x in b) ** 0.5
|
norm_b = sum(x * x for x in b) ** 0.5
|
||||||
if norm_a == 0 or norm_b == 0:
|
if norm_a == 0 or norm_b == 0:
|
||||||
return 0.0
|
return 0.0
|
||||||
return dot / (norm_a * norm_b)
|
return dot / (norm_a * norm_b)
|
||||||
@@ -334,51 +325,47 @@ def _keyword_overlap(query: str, content: str) -> float:
|
|||||||
return overlap / len(query_words)
|
return overlap / len(query_words)
|
||||||
|
|
||||||
|
|
||||||
def get_memory_context(
|
def get_memory_context(query: str, max_tokens: int = 2000, **filters) -> str:
|
||||||
query: str,
|
|
||||||
max_tokens: int = 2000,
|
|
||||||
**filters
|
|
||||||
) -> str:
|
|
||||||
"""Get relevant memory context as formatted text for LLM prompts.
|
"""Get relevant memory context as formatted text for LLM prompts.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: Search query
|
query: Search query
|
||||||
max_tokens: Approximate maximum tokens to return
|
max_tokens: Approximate maximum tokens to return
|
||||||
**filters: Additional filters (agent_id, session_id, etc.)
|
**filters: Additional filters (agent_id, session_id, etc.)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Formatted context string for inclusion in prompts
|
Formatted context string for inclusion in prompts
|
||||||
"""
|
"""
|
||||||
memories = search_memories(query, limit=20, **filters)
|
memories = search_memories(query, limit=20, **filters)
|
||||||
|
|
||||||
context_parts = []
|
context_parts = []
|
||||||
total_chars = 0
|
total_chars = 0
|
||||||
max_chars = max_tokens * 4 # Rough approximation
|
max_chars = max_tokens * 4 # Rough approximation
|
||||||
|
|
||||||
for mem in memories:
|
for mem in memories:
|
||||||
formatted = f"[{mem.source}]: {mem.content}"
|
formatted = f"[{mem.source}]: {mem.content}"
|
||||||
if total_chars + len(formatted) > max_chars:
|
if total_chars + len(formatted) > max_chars:
|
||||||
break
|
break
|
||||||
context_parts.append(formatted)
|
context_parts.append(formatted)
|
||||||
total_chars += len(formatted)
|
total_chars += len(formatted)
|
||||||
|
|
||||||
if not context_parts:
|
if not context_parts:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
return "Relevant context from memory:\n" + "\n\n".join(context_parts)
|
return "Relevant context from memory:\n" + "\n\n".join(context_parts)
|
||||||
|
|
||||||
|
|
||||||
def recall_personal_facts(agent_id: Optional[str] = None) -> list[str]:
|
def recall_personal_facts(agent_id: Optional[str] = None) -> list[str]:
|
||||||
"""Recall personal facts about the user or system.
|
"""Recall personal facts about the user or system.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
agent_id: Optional agent filter
|
agent_id: Optional agent filter
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of fact strings
|
List of fact strings
|
||||||
"""
|
"""
|
||||||
conn = _get_conn()
|
conn = _get_conn()
|
||||||
|
|
||||||
if agent_id:
|
if agent_id:
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
"""
|
"""
|
||||||
@@ -398,7 +385,7 @@ def recall_personal_facts(agent_id: Optional[str] = None) -> list[str]:
|
|||||||
LIMIT 100
|
LIMIT 100
|
||||||
""",
|
""",
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
conn.close()
|
conn.close()
|
||||||
return [r["content"] for r in rows]
|
return [r["content"] for r in rows]
|
||||||
|
|
||||||
@@ -434,11 +421,11 @@ def update_personal_fact(memory_id: str, new_content: str) -> bool:
|
|||||||
|
|
||||||
def store_personal_fact(fact: str, agent_id: Optional[str] = None) -> MemoryEntry:
|
def store_personal_fact(fact: str, agent_id: Optional[str] = None) -> MemoryEntry:
|
||||||
"""Store a personal fact about the user or system.
|
"""Store a personal fact about the user or system.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fact: The fact to store
|
fact: The fact to store
|
||||||
agent_id: Associated agent
|
agent_id: Associated agent
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The stored MemoryEntry
|
The stored MemoryEntry
|
||||||
"""
|
"""
|
||||||
@@ -453,7 +440,7 @@ def store_personal_fact(fact: str, agent_id: Optional[str] = None) -> MemoryEntr
|
|||||||
|
|
||||||
def delete_memory(memory_id: str) -> bool:
|
def delete_memory(memory_id: str) -> bool:
|
||||||
"""Delete a memory entry by ID.
|
"""Delete a memory entry by ID.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if deleted, False if not found
|
True if deleted, False if not found
|
||||||
"""
|
"""
|
||||||
@@ -470,29 +457,27 @@ def delete_memory(memory_id: str) -> bool:
|
|||||||
|
|
||||||
def get_memory_stats() -> dict:
|
def get_memory_stats() -> dict:
|
||||||
"""Get statistics about the memory store.
|
"""Get statistics about the memory store.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with counts by type, total entries, etc.
|
Dict with counts by type, total entries, etc.
|
||||||
"""
|
"""
|
||||||
conn = _get_conn()
|
conn = _get_conn()
|
||||||
|
|
||||||
total = conn.execute(
|
total = conn.execute("SELECT COUNT(*) as count FROM memory_entries").fetchone()["count"]
|
||||||
"SELECT COUNT(*) as count FROM memory_entries"
|
|
||||||
).fetchone()["count"]
|
|
||||||
|
|
||||||
by_type = {}
|
by_type = {}
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
"SELECT context_type, COUNT(*) as count FROM memory_entries GROUP BY context_type"
|
"SELECT context_type, COUNT(*) as count FROM memory_entries GROUP BY context_type"
|
||||||
).fetchall()
|
).fetchall()
|
||||||
for row in rows:
|
for row in rows:
|
||||||
by_type[row["context_type"]] = row["count"]
|
by_type[row["context_type"]] = row["count"]
|
||||||
|
|
||||||
with_embeddings = conn.execute(
|
with_embeddings = conn.execute(
|
||||||
"SELECT COUNT(*) as count FROM memory_entries WHERE embedding IS NOT NULL"
|
"SELECT COUNT(*) as count FROM memory_entries WHERE embedding IS NOT NULL"
|
||||||
).fetchone()["count"]
|
).fetchone()["count"]
|
||||||
|
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"total_entries": total,
|
"total_entries": total,
|
||||||
"by_type": by_type,
|
"by_type": by_type,
|
||||||
@@ -503,20 +488,20 @@ def get_memory_stats() -> dict:
|
|||||||
|
|
||||||
def prune_memories(older_than_days: int = 90, keep_facts: bool = True) -> int:
|
def prune_memories(older_than_days: int = 90, keep_facts: bool = True) -> int:
|
||||||
"""Delete old memories to manage storage.
|
"""Delete old memories to manage storage.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
older_than_days: Delete memories older than this
|
older_than_days: Delete memories older than this
|
||||||
keep_facts: Whether to preserve fact-type memories
|
keep_facts: Whether to preserve fact-type memories
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Number of entries deleted
|
Number of entries deleted
|
||||||
"""
|
"""
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
cutoff = (datetime.now(timezone.utc) - timedelta(days=older_than_days)).isoformat()
|
cutoff = (datetime.now(timezone.utc) - timedelta(days=older_than_days)).isoformat()
|
||||||
|
|
||||||
conn = _get_conn()
|
conn = _get_conn()
|
||||||
|
|
||||||
if keep_facts:
|
if keep_facts:
|
||||||
cursor = conn.execute(
|
cursor = conn.execute(
|
||||||
"""
|
"""
|
||||||
@@ -530,9 +515,9 @@ def prune_memories(older_than_days: int = 90, keep_facts: bool = True) -> int:
|
|||||||
"DELETE FROM memory_entries WHERE timestamp < ?",
|
"DELETE FROM memory_entries WHERE timestamp < ?",
|
||||||
(cutoff,),
|
(cutoff,),
|
||||||
)
|
)
|
||||||
|
|
||||||
deleted = cursor.rowcount
|
deleted = cursor.rowcount
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
return deleted
|
return deleted
|
||||||
|
|||||||
@@ -28,50 +28,52 @@ HANDOFF_PATH = VAULT_PATH / "notes" / "last-session-handoff.md"
|
|||||||
|
|
||||||
class HotMemory:
|
class HotMemory:
|
||||||
"""Tier 1: Hot memory (MEMORY.md) — always loaded."""
|
"""Tier 1: Hot memory (MEMORY.md) — always loaded."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.path = HOT_MEMORY_PATH
|
self.path = HOT_MEMORY_PATH
|
||||||
self._content: Optional[str] = None
|
self._content: Optional[str] = None
|
||||||
self._last_modified: Optional[float] = None
|
self._last_modified: Optional[float] = None
|
||||||
|
|
||||||
def read(self, force_refresh: bool = False) -> str:
|
def read(self, force_refresh: bool = False) -> str:
|
||||||
"""Read hot memory, with caching."""
|
"""Read hot memory, with caching."""
|
||||||
if not self.path.exists():
|
if not self.path.exists():
|
||||||
self._create_default()
|
self._create_default()
|
||||||
|
|
||||||
# Check if file changed
|
# Check if file changed
|
||||||
current_mtime = self.path.stat().st_mtime
|
current_mtime = self.path.stat().st_mtime
|
||||||
if not force_refresh and self._content and self._last_modified == current_mtime:
|
if not force_refresh and self._content and self._last_modified == current_mtime:
|
||||||
return self._content
|
return self._content
|
||||||
|
|
||||||
self._content = self.path.read_text()
|
self._content = self.path.read_text()
|
||||||
self._last_modified = current_mtime
|
self._last_modified = current_mtime
|
||||||
logger.debug("HotMemory: Loaded %d chars from %s", len(self._content), self.path)
|
logger.debug("HotMemory: Loaded %d chars from %s", len(self._content), self.path)
|
||||||
return self._content
|
return self._content
|
||||||
|
|
||||||
def update_section(self, section: str, content: str) -> None:
|
def update_section(self, section: str, content: str) -> None:
|
||||||
"""Update a specific section in MEMORY.md."""
|
"""Update a specific section in MEMORY.md."""
|
||||||
full_content = self.read()
|
full_content = self.read()
|
||||||
|
|
||||||
# Find section
|
# Find section
|
||||||
pattern = rf"(## {re.escape(section)}.*?)(?=\n## |\Z)"
|
pattern = rf"(## {re.escape(section)}.*?)(?=\n## |\Z)"
|
||||||
match = re.search(pattern, full_content, re.DOTALL)
|
match = re.search(pattern, full_content, re.DOTALL)
|
||||||
|
|
||||||
if match:
|
if match:
|
||||||
# Replace section
|
# Replace section
|
||||||
new_section = f"## {section}\n\n{content}\n\n"
|
new_section = f"## {section}\n\n{content}\n\n"
|
||||||
full_content = full_content[:match.start()] + new_section + full_content[match.end():]
|
full_content = full_content[: match.start()] + new_section + full_content[match.end() :]
|
||||||
else:
|
else:
|
||||||
# Append section before last updated line
|
# Append section before last updated line
|
||||||
insert_point = full_content.rfind("*Prune date:")
|
insert_point = full_content.rfind("*Prune date:")
|
||||||
new_section = f"## {section}\n\n{content}\n\n"
|
new_section = f"## {section}\n\n{content}\n\n"
|
||||||
full_content = full_content[:insert_point] + new_section + "\n" + full_content[insert_point:]
|
full_content = (
|
||||||
|
full_content[:insert_point] + new_section + "\n" + full_content[insert_point:]
|
||||||
|
)
|
||||||
|
|
||||||
self.path.write_text(full_content)
|
self.path.write_text(full_content)
|
||||||
self._content = full_content
|
self._content = full_content
|
||||||
self._last_modified = self.path.stat().st_mtime
|
self._last_modified = self.path.stat().st_mtime
|
||||||
logger.info("HotMemory: Updated section '%s'", section)
|
logger.info("HotMemory: Updated section '%s'", section)
|
||||||
|
|
||||||
def _create_default(self) -> None:
|
def _create_default(self) -> None:
|
||||||
"""Create default MEMORY.md if missing."""
|
"""Create default MEMORY.md if missing."""
|
||||||
default_content = """# Timmy Hot Memory
|
default_content = """# Timmy Hot Memory
|
||||||
@@ -130,33 +132,33 @@ class HotMemory:
|
|||||||
*Prune date: {prune_date}*
|
*Prune date: {prune_date}*
|
||||||
""".format(
|
""".format(
|
||||||
date=datetime.now(timezone.utc).strftime("%Y-%m-%d"),
|
date=datetime.now(timezone.utc).strftime("%Y-%m-%d"),
|
||||||
prune_date=(datetime.now(timezone.utc).replace(day=25)).strftime("%Y-%m-%d")
|
prune_date=(datetime.now(timezone.utc).replace(day=25)).strftime("%Y-%m-%d"),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.path.write_text(default_content)
|
self.path.write_text(default_content)
|
||||||
logger.info("HotMemory: Created default MEMORY.md")
|
logger.info("HotMemory: Created default MEMORY.md")
|
||||||
|
|
||||||
|
|
||||||
class VaultMemory:
|
class VaultMemory:
|
||||||
"""Tier 2: Structured vault (memory/) — append-only markdown."""
|
"""Tier 2: Structured vault (memory/) — append-only markdown."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.path = VAULT_PATH
|
self.path = VAULT_PATH
|
||||||
self._ensure_structure()
|
self._ensure_structure()
|
||||||
|
|
||||||
def _ensure_structure(self) -> None:
|
def _ensure_structure(self) -> None:
|
||||||
"""Ensure vault directory structure exists."""
|
"""Ensure vault directory structure exists."""
|
||||||
(self.path / "self").mkdir(parents=True, exist_ok=True)
|
(self.path / "self").mkdir(parents=True, exist_ok=True)
|
||||||
(self.path / "notes").mkdir(parents=True, exist_ok=True)
|
(self.path / "notes").mkdir(parents=True, exist_ok=True)
|
||||||
(self.path / "aar").mkdir(parents=True, exist_ok=True)
|
(self.path / "aar").mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
def write_note(self, name: str, content: str, namespace: str = "notes") -> Path:
|
def write_note(self, name: str, content: str, namespace: str = "notes") -> Path:
|
||||||
"""Write a note to the vault."""
|
"""Write a note to the vault."""
|
||||||
# Add timestamp to filename
|
# Add timestamp to filename
|
||||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d")
|
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d")
|
||||||
filename = f"{timestamp}_{name}.md"
|
filename = f"{timestamp}_{name}.md"
|
||||||
filepath = self.path / namespace / filename
|
filepath = self.path / namespace / filename
|
||||||
|
|
||||||
# Add header
|
# Add header
|
||||||
full_content = f"""# {name.replace('_', ' ').title()}
|
full_content = f"""# {name.replace('_', ' ').title()}
|
||||||
|
|
||||||
@@ -171,39 +173,39 @@ class VaultMemory:
|
|||||||
|
|
||||||
*Auto-generated by Timmy Memory System*
|
*Auto-generated by Timmy Memory System*
|
||||||
"""
|
"""
|
||||||
|
|
||||||
filepath.write_text(full_content)
|
filepath.write_text(full_content)
|
||||||
logger.info("VaultMemory: Wrote %s", filepath)
|
logger.info("VaultMemory: Wrote %s", filepath)
|
||||||
return filepath
|
return filepath
|
||||||
|
|
||||||
def read_file(self, filepath: Path) -> str:
|
def read_file(self, filepath: Path) -> str:
|
||||||
"""Read a file from the vault."""
|
"""Read a file from the vault."""
|
||||||
if not filepath.exists():
|
if not filepath.exists():
|
||||||
return ""
|
return ""
|
||||||
return filepath.read_text()
|
return filepath.read_text()
|
||||||
|
|
||||||
def list_files(self, namespace: str = "notes", pattern: str = "*.md") -> list[Path]:
|
def list_files(self, namespace: str = "notes", pattern: str = "*.md") -> list[Path]:
|
||||||
"""List files in a namespace."""
|
"""List files in a namespace."""
|
||||||
dir_path = self.path / namespace
|
dir_path = self.path / namespace
|
||||||
if not dir_path.exists():
|
if not dir_path.exists():
|
||||||
return []
|
return []
|
||||||
return sorted(dir_path.glob(pattern))
|
return sorted(dir_path.glob(pattern))
|
||||||
|
|
||||||
def get_latest(self, namespace: str = "notes", pattern: str = "*.md") -> Optional[Path]:
|
def get_latest(self, namespace: str = "notes", pattern: str = "*.md") -> Optional[Path]:
|
||||||
"""Get most recent file in namespace."""
|
"""Get most recent file in namespace."""
|
||||||
files = self.list_files(namespace, pattern)
|
files = self.list_files(namespace, pattern)
|
||||||
return files[-1] if files else None
|
return files[-1] if files else None
|
||||||
|
|
||||||
def update_user_profile(self, key: str, value: str) -> None:
|
def update_user_profile(self, key: str, value: str) -> None:
|
||||||
"""Update a field in user_profile.md."""
|
"""Update a field in user_profile.md."""
|
||||||
profile_path = self.path / "self" / "user_profile.md"
|
profile_path = self.path / "self" / "user_profile.md"
|
||||||
|
|
||||||
if not profile_path.exists():
|
if not profile_path.exists():
|
||||||
# Create default profile
|
# Create default profile
|
||||||
self._create_default_profile()
|
self._create_default_profile()
|
||||||
|
|
||||||
content = profile_path.read_text()
|
content = profile_path.read_text()
|
||||||
|
|
||||||
# Simple pattern replacement
|
# Simple pattern replacement
|
||||||
pattern = rf"(\*\*{re.escape(key)}:\*\*).*"
|
pattern = rf"(\*\*{re.escape(key)}:\*\*).*"
|
||||||
if re.search(pattern, content):
|
if re.search(pattern, content):
|
||||||
@@ -214,17 +216,17 @@ class VaultMemory:
|
|||||||
if facts_section in content:
|
if facts_section in content:
|
||||||
insert_point = content.find(facts_section) + len(facts_section)
|
insert_point = content.find(facts_section) + len(facts_section)
|
||||||
content = content[:insert_point] + f"\n- {key}: {value}" + content[insert_point:]
|
content = content[:insert_point] + f"\n- {key}: {value}" + content[insert_point:]
|
||||||
|
|
||||||
# Update last_updated
|
# Update last_updated
|
||||||
content = re.sub(
|
content = re.sub(
|
||||||
r"\*Last updated:.*\*",
|
r"\*Last updated:.*\*",
|
||||||
f"*Last updated: {datetime.now(timezone.utc).strftime('%Y-%m-%d')}*",
|
f"*Last updated: {datetime.now(timezone.utc).strftime('%Y-%m-%d')}*",
|
||||||
content
|
content,
|
||||||
)
|
)
|
||||||
|
|
||||||
profile_path.write_text(content)
|
profile_path.write_text(content)
|
||||||
logger.info("VaultMemory: Updated user profile: %s = %s", key, value)
|
logger.info("VaultMemory: Updated user profile: %s = %s", key, value)
|
||||||
|
|
||||||
def _create_default_profile(self) -> None:
|
def _create_default_profile(self) -> None:
|
||||||
"""Create default user profile."""
|
"""Create default user profile."""
|
||||||
profile_path = self.path / "self" / "user_profile.md"
|
profile_path = self.path / "self" / "user_profile.md"
|
||||||
@@ -254,24 +256,26 @@ class VaultMemory:
|
|||||||
---
|
---
|
||||||
|
|
||||||
*Last updated: {date}*
|
*Last updated: {date}*
|
||||||
""".format(date=datetime.now(timezone.utc).strftime("%Y-%m-%d"))
|
""".format(
|
||||||
|
date=datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||||
|
)
|
||||||
|
|
||||||
profile_path.write_text(default)
|
profile_path.write_text(default)
|
||||||
|
|
||||||
|
|
||||||
class HandoffProtocol:
|
class HandoffProtocol:
|
||||||
"""Session handoff protocol for continuity."""
|
"""Session handoff protocol for continuity."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.path = HANDOFF_PATH
|
self.path = HANDOFF_PATH
|
||||||
self.vault = VaultMemory()
|
self.vault = VaultMemory()
|
||||||
|
|
||||||
def write_handoff(
|
def write_handoff(
|
||||||
self,
|
self,
|
||||||
session_summary: str,
|
session_summary: str,
|
||||||
key_decisions: list[str],
|
key_decisions: list[str],
|
||||||
open_items: list[str],
|
open_items: list[str],
|
||||||
next_steps: list[str]
|
next_steps: list[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Write handoff at session end."""
|
"""Write handoff at session end."""
|
||||||
content = f"""# Last Session Handoff
|
content = f"""# Last Session Handoff
|
||||||
@@ -303,25 +307,24 @@ The user was last working on: {session_summary[:200]}...
|
|||||||
|
|
||||||
*This handoff will be auto-loaded at next session start*
|
*This handoff will be auto-loaded at next session start*
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.path.write_text(content)
|
self.path.write_text(content)
|
||||||
|
|
||||||
# Also archive to notes
|
# Also archive to notes
|
||||||
self.vault.write_note(
|
self.vault.write_note("session_handoff", content, namespace="notes")
|
||||||
"session_handoff",
|
|
||||||
content,
|
logger.info(
|
||||||
namespace="notes"
|
"HandoffProtocol: Wrote handoff with %d decisions, %d open items",
|
||||||
|
len(key_decisions),
|
||||||
|
len(open_items),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("HandoffProtocol: Wrote handoff with %d decisions, %d open items",
|
|
||||||
len(key_decisions), len(open_items))
|
|
||||||
|
|
||||||
def read_handoff(self) -> Optional[str]:
|
def read_handoff(self) -> Optional[str]:
|
||||||
"""Read handoff if exists."""
|
"""Read handoff if exists."""
|
||||||
if not self.path.exists():
|
if not self.path.exists():
|
||||||
return None
|
return None
|
||||||
return self.path.read_text()
|
return self.path.read_text()
|
||||||
|
|
||||||
def clear_handoff(self) -> None:
|
def clear_handoff(self) -> None:
|
||||||
"""Clear handoff after loading."""
|
"""Clear handoff after loading."""
|
||||||
if self.path.exists():
|
if self.path.exists():
|
||||||
@@ -331,7 +334,7 @@ The user was last working on: {session_summary[:200]}...
|
|||||||
|
|
||||||
class MemorySystem:
|
class MemorySystem:
|
||||||
"""Central memory system coordinating all tiers."""
|
"""Central memory system coordinating all tiers."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.hot = HotMemory()
|
self.hot = HotMemory()
|
||||||
self.vault = VaultMemory()
|
self.vault = VaultMemory()
|
||||||
@@ -339,52 +342,52 @@ class MemorySystem:
|
|||||||
self.session_start_time: Optional[datetime] = None
|
self.session_start_time: Optional[datetime] = None
|
||||||
self.session_decisions: list[str] = []
|
self.session_decisions: list[str] = []
|
||||||
self.session_open_items: list[str] = []
|
self.session_open_items: list[str] = []
|
||||||
|
|
||||||
def start_session(self) -> str:
|
def start_session(self) -> str:
|
||||||
"""Start a new session, loading context from memory."""
|
"""Start a new session, loading context from memory."""
|
||||||
self.session_start_time = datetime.now(timezone.utc)
|
self.session_start_time = datetime.now(timezone.utc)
|
||||||
|
|
||||||
# Build context
|
# Build context
|
||||||
context_parts = []
|
context_parts = []
|
||||||
|
|
||||||
# 1. Hot memory
|
# 1. Hot memory
|
||||||
hot_content = self.hot.read()
|
hot_content = self.hot.read()
|
||||||
context_parts.append("## Hot Memory\n" + hot_content)
|
context_parts.append("## Hot Memory\n" + hot_content)
|
||||||
|
|
||||||
# 2. Last session handoff
|
# 2. Last session handoff
|
||||||
handoff_content = self.handoff.read_handoff()
|
handoff_content = self.handoff.read_handoff()
|
||||||
if handoff_content:
|
if handoff_content:
|
||||||
context_parts.append("## Previous Session\n" + handoff_content)
|
context_parts.append("## Previous Session\n" + handoff_content)
|
||||||
self.handoff.clear_handoff()
|
self.handoff.clear_handoff()
|
||||||
|
|
||||||
# 3. User profile (key fields only)
|
# 3. User profile (key fields only)
|
||||||
profile = self._load_user_profile_summary()
|
profile = self._load_user_profile_summary()
|
||||||
if profile:
|
if profile:
|
||||||
context_parts.append("## User Context\n" + profile)
|
context_parts.append("## User Context\n" + profile)
|
||||||
|
|
||||||
full_context = "\n\n---\n\n".join(context_parts)
|
full_context = "\n\n---\n\n".join(context_parts)
|
||||||
logger.info("MemorySystem: Session started with %d chars context", len(full_context))
|
logger.info("MemorySystem: Session started with %d chars context", len(full_context))
|
||||||
|
|
||||||
return full_context
|
return full_context
|
||||||
|
|
||||||
def end_session(self, summary: str) -> None:
|
def end_session(self, summary: str) -> None:
|
||||||
"""End session, write handoff."""
|
"""End session, write handoff."""
|
||||||
self.handoff.write_handoff(
|
self.handoff.write_handoff(
|
||||||
session_summary=summary,
|
session_summary=summary,
|
||||||
key_decisions=self.session_decisions,
|
key_decisions=self.session_decisions,
|
||||||
open_items=self.session_open_items,
|
open_items=self.session_open_items,
|
||||||
next_steps=[]
|
next_steps=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update hot memory
|
# Update hot memory
|
||||||
self.hot.update_section(
|
self.hot.update_section(
|
||||||
"Current Session",
|
"Current Session",
|
||||||
f"**Last Session:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M')}\n" +
|
f"**Last Session:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M')}\n"
|
||||||
f"**Summary:** {summary[:100]}..."
|
+ f"**Summary:** {summary[:100]}...",
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("MemorySystem: Session ended, handoff written")
|
logger.info("MemorySystem: Session ended, handoff written")
|
||||||
|
|
||||||
def record_decision(self, decision: str) -> None:
|
def record_decision(self, decision: str) -> None:
|
||||||
"""Record a key decision during session."""
|
"""Record a key decision during session."""
|
||||||
self.session_decisions.append(decision)
|
self.session_decisions.append(decision)
|
||||||
@@ -393,43 +396,47 @@ class MemorySystem:
|
|||||||
if "## Key Decisions" in current:
|
if "## Key Decisions" in current:
|
||||||
# Append to section
|
# Append to section
|
||||||
pass # Handled at session end
|
pass # Handled at session end
|
||||||
|
|
||||||
def record_open_item(self, item: str) -> None:
|
def record_open_item(self, item: str) -> None:
|
||||||
"""Record an open item for follow-up."""
|
"""Record an open item for follow-up."""
|
||||||
self.session_open_items.append(item)
|
self.session_open_items.append(item)
|
||||||
|
|
||||||
def update_user_fact(self, key: str, value: str) -> None:
|
def update_user_fact(self, key: str, value: str) -> None:
|
||||||
"""Update user profile in vault."""
|
"""Update user profile in vault."""
|
||||||
self.vault.update_user_profile(key, value)
|
self.vault.update_user_profile(key, value)
|
||||||
# Also update hot memory
|
# Also update hot memory
|
||||||
if key.lower() == "name":
|
if key.lower() == "name":
|
||||||
self.hot.update_section("User Profile", f"**Name:** {value}")
|
self.hot.update_section("User Profile", f"**Name:** {value}")
|
||||||
|
|
||||||
def _load_user_profile_summary(self) -> str:
|
def _load_user_profile_summary(self) -> str:
|
||||||
"""Load condensed user profile."""
|
"""Load condensed user profile."""
|
||||||
profile_path = self.vault.path / "self" / "user_profile.md"
|
profile_path = self.vault.path / "self" / "user_profile.md"
|
||||||
if not profile_path.exists():
|
if not profile_path.exists():
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
content = profile_path.read_text()
|
content = profile_path.read_text()
|
||||||
|
|
||||||
# Extract key fields
|
# Extract key fields
|
||||||
summary_parts = []
|
summary_parts = []
|
||||||
|
|
||||||
# Name
|
# Name
|
||||||
name_match = re.search(r"\*\*Name:\*\* (.+)", content)
|
name_match = re.search(r"\*\*Name:\*\* (.+)", content)
|
||||||
if name_match and "unknown" not in name_match.group(1).lower():
|
if name_match and "unknown" not in name_match.group(1).lower():
|
||||||
summary_parts.append(f"Name: {name_match.group(1).strip()}")
|
summary_parts.append(f"Name: {name_match.group(1).strip()}")
|
||||||
|
|
||||||
# Interests
|
# Interests
|
||||||
interests_section = re.search(r"## Interests.*?\n- (.+?)(?=\n## |\Z)", content, re.DOTALL)
|
interests_section = re.search(r"## Interests.*?\n- (.+?)(?=\n## |\Z)", content, re.DOTALL)
|
||||||
if interests_section:
|
if interests_section:
|
||||||
interests = [i.strip() for i in interests_section.group(1).split("\n-") if i.strip() and "to be" not in i]
|
interests = [
|
||||||
|
i.strip()
|
||||||
|
for i in interests_section.group(1).split("\n-")
|
||||||
|
if i.strip() and "to be" not in i
|
||||||
|
]
|
||||||
if interests:
|
if interests:
|
||||||
summary_parts.append(f"Interests: {', '.join(interests[:3])}")
|
summary_parts.append(f"Interests: {', '.join(interests[:3])}")
|
||||||
|
|
||||||
return "\n".join(summary_parts) if summary_parts else ""
|
return "\n".join(summary_parts) if summary_parts else ""
|
||||||
|
|
||||||
def get_system_context(self) -> str:
|
def get_system_context(self) -> str:
|
||||||
"""Get full context for system prompt injection.
|
"""Get full context for system prompt injection.
|
||||||
|
|
||||||
|
|||||||
@@ -38,12 +38,14 @@ def _get_embedding_model():
|
|||||||
global EMBEDDING_MODEL
|
global EMBEDDING_MODEL
|
||||||
if EMBEDDING_MODEL is None:
|
if EMBEDDING_MODEL is None:
|
||||||
from config import settings
|
from config import settings
|
||||||
|
|
||||||
if settings.timmy_skip_embeddings:
|
if settings.timmy_skip_embeddings:
|
||||||
EMBEDDING_MODEL = False
|
EMBEDDING_MODEL = False
|
||||||
return EMBEDDING_MODEL
|
return EMBEDDING_MODEL
|
||||||
try:
|
try:
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
EMBEDDING_MODEL = SentenceTransformer('all-MiniLM-L6-v2')
|
|
||||||
|
EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
|
||||||
logger.info("SemanticMemory: Loaded embedding model")
|
logger.info("SemanticMemory: Loaded embedding model")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning("SemanticMemory: sentence-transformers not installed, using fallback")
|
logger.warning("SemanticMemory: sentence-transformers not installed, using fallback")
|
||||||
@@ -60,11 +62,12 @@ def _simple_hash_embedding(text: str) -> list[float]:
|
|||||||
h = hashlib.md5(word.encode()).hexdigest()
|
h = hashlib.md5(word.encode()).hexdigest()
|
||||||
for j in range(8):
|
for j in range(8):
|
||||||
idx = (i * 8 + j) % 128
|
idx = (i * 8 + j) % 128
|
||||||
vec[idx] += int(h[j*2:j*2+2], 16) / 255.0
|
vec[idx] += int(h[j * 2 : j * 2 + 2], 16) / 255.0
|
||||||
# Normalize
|
# Normalize
|
||||||
import math
|
import math
|
||||||
mag = math.sqrt(sum(x*x for x in vec)) or 1.0
|
|
||||||
return [x/mag for x in vec]
|
mag = math.sqrt(sum(x * x for x in vec)) or 1.0
|
||||||
|
return [x / mag for x in vec]
|
||||||
|
|
||||||
|
|
||||||
def embed_text(text: str) -> list[float]:
|
def embed_text(text: str) -> list[float]:
|
||||||
@@ -80,9 +83,10 @@ def embed_text(text: str) -> list[float]:
|
|||||||
def cosine_similarity(a: list[float], b: list[float]) -> float:
|
def cosine_similarity(a: list[float], b: list[float]) -> float:
|
||||||
"""Calculate cosine similarity between two vectors."""
|
"""Calculate cosine similarity between two vectors."""
|
||||||
import math
|
import math
|
||||||
dot = sum(x*y for x, y in zip(a, b))
|
|
||||||
mag_a = math.sqrt(sum(x*x for x in a))
|
dot = sum(x * y for x, y in zip(a, b))
|
||||||
mag_b = math.sqrt(sum(x*x for x in b))
|
mag_a = math.sqrt(sum(x * x for x in a))
|
||||||
|
mag_b = math.sqrt(sum(x * x for x in b))
|
||||||
if mag_a == 0 or mag_b == 0:
|
if mag_a == 0 or mag_b == 0:
|
||||||
return 0.0
|
return 0.0
|
||||||
return dot / (mag_a * mag_b)
|
return dot / (mag_a * mag_b)
|
||||||
@@ -91,6 +95,7 @@ def cosine_similarity(a: list[float], b: list[float]) -> float:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class MemoryChunk:
|
class MemoryChunk:
|
||||||
"""A searchable chunk of memory."""
|
"""A searchable chunk of memory."""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
source: str # filepath
|
source: str # filepath
|
||||||
content: str
|
content: str
|
||||||
@@ -100,17 +105,18 @@ class MemoryChunk:
|
|||||||
|
|
||||||
class SemanticMemory:
|
class SemanticMemory:
|
||||||
"""Vector-based semantic search over vault content."""
|
"""Vector-based semantic search over vault content."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.db_path = SEMANTIC_DB_PATH
|
self.db_path = SEMANTIC_DB_PATH
|
||||||
self.vault_path = VAULT_PATH
|
self.vault_path = VAULT_PATH
|
||||||
self._init_db()
|
self._init_db()
|
||||||
|
|
||||||
def _init_db(self) -> None:
|
def _init_db(self) -> None:
|
||||||
"""Initialize SQLite with vector storage."""
|
"""Initialize SQLite with vector storage."""
|
||||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
conn = sqlite3.connect(str(self.db_path))
|
conn = sqlite3.connect(str(self.db_path))
|
||||||
conn.execute("""
|
conn.execute(
|
||||||
|
"""
|
||||||
CREATE TABLE IF NOT EXISTS chunks (
|
CREATE TABLE IF NOT EXISTS chunks (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
source TEXT NOT NULL,
|
source TEXT NOT NULL,
|
||||||
@@ -119,76 +125,76 @@ class SemanticMemory:
|
|||||||
created_at TEXT NOT NULL,
|
created_at TEXT NOT NULL,
|
||||||
source_hash TEXT NOT NULL
|
source_hash TEXT NOT NULL
|
||||||
)
|
)
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_source ON chunks(source)")
|
conn.execute("CREATE INDEX IF NOT EXISTS idx_source ON chunks(source)")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def index_file(self, filepath: Path) -> int:
|
def index_file(self, filepath: Path) -> int:
|
||||||
"""Index a single file into semantic memory."""
|
"""Index a single file into semantic memory."""
|
||||||
if not filepath.exists():
|
if not filepath.exists():
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
content = filepath.read_text()
|
content = filepath.read_text()
|
||||||
file_hash = hashlib.md5(content.encode()).hexdigest()
|
file_hash = hashlib.md5(content.encode()).hexdigest()
|
||||||
|
|
||||||
# Check if already indexed with same hash
|
# Check if already indexed with same hash
|
||||||
conn = sqlite3.connect(str(self.db_path))
|
conn = sqlite3.connect(str(self.db_path))
|
||||||
cursor = conn.execute(
|
cursor = conn.execute(
|
||||||
"SELECT source_hash FROM chunks WHERE source = ? LIMIT 1",
|
"SELECT source_hash FROM chunks WHERE source = ? LIMIT 1", (str(filepath),)
|
||||||
(str(filepath),)
|
|
||||||
)
|
)
|
||||||
existing = cursor.fetchone()
|
existing = cursor.fetchone()
|
||||||
if existing and existing[0] == file_hash:
|
if existing and existing[0] == file_hash:
|
||||||
conn.close()
|
conn.close()
|
||||||
return 0 # Already indexed
|
return 0 # Already indexed
|
||||||
|
|
||||||
# Delete old chunks for this file
|
# Delete old chunks for this file
|
||||||
conn.execute("DELETE FROM chunks WHERE source = ?", (str(filepath),))
|
conn.execute("DELETE FROM chunks WHERE source = ?", (str(filepath),))
|
||||||
|
|
||||||
# Split into chunks (paragraphs)
|
# Split into chunks (paragraphs)
|
||||||
chunks = self._split_into_chunks(content)
|
chunks = self._split_into_chunks(content)
|
||||||
|
|
||||||
# Index each chunk
|
# Index each chunk
|
||||||
now = datetime.now(timezone.utc).isoformat()
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
for i, chunk_text in enumerate(chunks):
|
for i, chunk_text in enumerate(chunks):
|
||||||
if len(chunk_text.strip()) < 20: # Skip tiny chunks
|
if len(chunk_text.strip()) < 20: # Skip tiny chunks
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chunk_id = f"{filepath.stem}_{i}"
|
chunk_id = f"{filepath.stem}_{i}"
|
||||||
embedding = embed_text(chunk_text)
|
embedding = embed_text(chunk_text)
|
||||||
|
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"""INSERT INTO chunks (id, source, content, embedding, created_at, source_hash)
|
"""INSERT INTO chunks (id, source, content, embedding, created_at, source_hash)
|
||||||
VALUES (?, ?, ?, ?, ?, ?)""",
|
VALUES (?, ?, ?, ?, ?, ?)""",
|
||||||
(chunk_id, str(filepath), chunk_text, json.dumps(embedding), now, file_hash)
|
(chunk_id, str(filepath), chunk_text, json.dumps(embedding), now, file_hash),
|
||||||
)
|
)
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
logger.info("SemanticMemory: Indexed %s (%d chunks)", filepath.name, len(chunks))
|
logger.info("SemanticMemory: Indexed %s (%d chunks)", filepath.name, len(chunks))
|
||||||
return len(chunks)
|
return len(chunks)
|
||||||
|
|
||||||
def _split_into_chunks(self, text: str, max_chunk_size: int = 500) -> list[str]:
|
def _split_into_chunks(self, text: str, max_chunk_size: int = 500) -> list[str]:
|
||||||
"""Split text into semantic chunks."""
|
"""Split text into semantic chunks."""
|
||||||
# Split by paragraphs first
|
# Split by paragraphs first
|
||||||
paragraphs = text.split('\n\n')
|
paragraphs = text.split("\n\n")
|
||||||
chunks = []
|
chunks = []
|
||||||
|
|
||||||
for para in paragraphs:
|
for para in paragraphs:
|
||||||
para = para.strip()
|
para = para.strip()
|
||||||
if not para:
|
if not para:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# If paragraph is small enough, keep as one chunk
|
# If paragraph is small enough, keep as one chunk
|
||||||
if len(para) <= max_chunk_size:
|
if len(para) <= max_chunk_size:
|
||||||
chunks.append(para)
|
chunks.append(para)
|
||||||
else:
|
else:
|
||||||
# Split long paragraphs by sentences
|
# Split long paragraphs by sentences
|
||||||
sentences = para.replace('. ', '.\n').split('\n')
|
sentences = para.replace(". ", ".\n").split("\n")
|
||||||
current_chunk = ""
|
current_chunk = ""
|
||||||
|
|
||||||
for sent in sentences:
|
for sent in sentences:
|
||||||
if len(current_chunk) + len(sent) < max_chunk_size:
|
if len(current_chunk) + len(sent) < max_chunk_size:
|
||||||
current_chunk += " " + sent if current_chunk else sent
|
current_chunk += " " + sent if current_chunk else sent
|
||||||
@@ -196,82 +202,80 @@ class SemanticMemory:
|
|||||||
if current_chunk:
|
if current_chunk:
|
||||||
chunks.append(current_chunk.strip())
|
chunks.append(current_chunk.strip())
|
||||||
current_chunk = sent
|
current_chunk = sent
|
||||||
|
|
||||||
if current_chunk:
|
if current_chunk:
|
||||||
chunks.append(current_chunk.strip())
|
chunks.append(current_chunk.strip())
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
def index_vault(self) -> int:
|
def index_vault(self) -> int:
|
||||||
"""Index entire vault directory."""
|
"""Index entire vault directory."""
|
||||||
total_chunks = 0
|
total_chunks = 0
|
||||||
|
|
||||||
for md_file in self.vault_path.rglob("*.md"):
|
for md_file in self.vault_path.rglob("*.md"):
|
||||||
# Skip handoff file (handled separately)
|
# Skip handoff file (handled separately)
|
||||||
if "last-session-handoff" in md_file.name:
|
if "last-session-handoff" in md_file.name:
|
||||||
continue
|
continue
|
||||||
total_chunks += self.index_file(md_file)
|
total_chunks += self.index_file(md_file)
|
||||||
|
|
||||||
logger.info("SemanticMemory: Indexed vault (%d total chunks)", total_chunks)
|
logger.info("SemanticMemory: Indexed vault (%d total chunks)", total_chunks)
|
||||||
return total_chunks
|
return total_chunks
|
||||||
|
|
||||||
def search(self, query: str, top_k: int = 5) -> list[tuple[str, float]]:
|
def search(self, query: str, top_k: int = 5) -> list[tuple[str, float]]:
|
||||||
"""Search for relevant memory chunks."""
|
"""Search for relevant memory chunks."""
|
||||||
query_embedding = embed_text(query)
|
query_embedding = embed_text(query)
|
||||||
|
|
||||||
conn = sqlite3.connect(str(self.db_path))
|
conn = sqlite3.connect(str(self.db_path))
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
|
|
||||||
# Get all chunks (in production, use vector index)
|
# Get all chunks (in production, use vector index)
|
||||||
rows = conn.execute(
|
rows = conn.execute("SELECT source, content, embedding FROM chunks").fetchall()
|
||||||
"SELECT source, content, embedding FROM chunks"
|
|
||||||
).fetchall()
|
|
||||||
|
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
# Calculate similarities
|
# Calculate similarities
|
||||||
scored = []
|
scored = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
embedding = json.loads(row["embedding"])
|
embedding = json.loads(row["embedding"])
|
||||||
score = cosine_similarity(query_embedding, embedding)
|
score = cosine_similarity(query_embedding, embedding)
|
||||||
scored.append((row["source"], row["content"], score))
|
scored.append((row["source"], row["content"], score))
|
||||||
|
|
||||||
# Sort by score descending
|
# Sort by score descending
|
||||||
scored.sort(key=lambda x: x[2], reverse=True)
|
scored.sort(key=lambda x: x[2], reverse=True)
|
||||||
|
|
||||||
# Return top_k
|
# Return top_k
|
||||||
return [(content, score) for _, content, score in scored[:top_k]]
|
return [(content, score) for _, content, score in scored[:top_k]]
|
||||||
|
|
||||||
def get_relevant_context(self, query: str, max_chars: int = 2000) -> str:
|
def get_relevant_context(self, query: str, max_chars: int = 2000) -> str:
|
||||||
"""Get formatted context string for a query."""
|
"""Get formatted context string for a query."""
|
||||||
results = self.search(query, top_k=3)
|
results = self.search(query, top_k=3)
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
parts = []
|
parts = []
|
||||||
total_chars = 0
|
total_chars = 0
|
||||||
|
|
||||||
for content, score in results:
|
for content, score in results:
|
||||||
if score < 0.3: # Similarity threshold
|
if score < 0.3: # Similarity threshold
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chunk = f"[Relevant memory - score {score:.2f}]: {content[:400]}..."
|
chunk = f"[Relevant memory - score {score:.2f}]: {content[:400]}..."
|
||||||
if total_chars + len(chunk) > max_chars:
|
if total_chars + len(chunk) > max_chars:
|
||||||
break
|
break
|
||||||
|
|
||||||
parts.append(chunk)
|
parts.append(chunk)
|
||||||
total_chars += len(chunk)
|
total_chars += len(chunk)
|
||||||
|
|
||||||
return "\n\n".join(parts) if parts else ""
|
return "\n\n".join(parts) if parts else ""
|
||||||
|
|
||||||
def stats(self) -> dict:
|
def stats(self) -> dict:
|
||||||
"""Get indexing statistics."""
|
"""Get indexing statistics."""
|
||||||
conn = sqlite3.connect(str(self.db_path))
|
conn = sqlite3.connect(str(self.db_path))
|
||||||
cursor = conn.execute("SELECT COUNT(*), COUNT(DISTINCT source) FROM chunks")
|
cursor = conn.execute("SELECT COUNT(*), COUNT(DISTINCT source) FROM chunks")
|
||||||
total_chunks, total_files = cursor.fetchone()
|
total_chunks, total_files = cursor.fetchone()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"total_chunks": total_chunks,
|
"total_chunks": total_chunks,
|
||||||
"total_files": total_files,
|
"total_files": total_files,
|
||||||
@@ -281,40 +285,39 @@ class SemanticMemory:
|
|||||||
|
|
||||||
class MemorySearcher:
    """Thin facade over ``SemanticMemory`` for callers that want search results or context."""

    def __init__(self) -> None:
        self.semantic = SemanticMemory()

    def search(self, query: str, tiers: list[str] = None) -> dict:
        """Look up a query in the requested memory tiers.

        Args:
            query: Search query.
            tiers: Tier names to consult. ``None`` or an empty list selects
                ``["semantic"]``. This method only acts on the "semantic"
                tier; any other name produces no entry in the result.

        Returns:
            Dict keyed by tier name. The "semantic" entry is a list of
            ``{"content": ..., "score": ...}`` dicts for up to five hits.
        """
        selected = tiers if tiers else ["semantic"]  # Semantic-only by default
        found = {}

        if "semantic" not in selected:
            return found

        hits = []
        for content, score in self.semantic.search(query, top_k=5):
            hits.append({"content": content, "score": score})
        found["semantic"] = hits

        return found

    def get_context_for_query(self, query: str) -> str:
        """Return the semantic context for a query under a markdown heading, or ""."""
        context = self.semantic.get_relevant_context(query)
        if not context:
            return ""
        return f"## Relevant Past Context\n\n{context}"
|
||||||
|
|
||||||
|
|
||||||
@@ -353,6 +356,7 @@ def memory_search(query: str, top_k: int = 5) -> str:
|
|||||||
# 2. Search runtime vector store (stored facts/conversations)
|
# 2. Search runtime vector store (stored facts/conversations)
|
||||||
try:
|
try:
|
||||||
from timmy.memory.vector_store import search_memories
|
from timmy.memory.vector_store import search_memories
|
||||||
|
|
||||||
runtime_results = search_memories(query, limit=top_k, min_relevance=0.2)
|
runtime_results = search_memories(query, limit=top_k, min_relevance=0.2)
|
||||||
for entry in runtime_results:
|
for entry in runtime_results:
|
||||||
label = entry.context_type or "memory"
|
label = entry.context_type or "memory"
|
||||||
@@ -387,6 +391,7 @@ def memory_read(query: str = "", top_k: int = 5) -> str:
|
|||||||
# Always include personal facts first
|
# Always include personal facts first
|
||||||
try:
|
try:
|
||||||
from timmy.memory.vector_store import search_memories
|
from timmy.memory.vector_store import search_memories
|
||||||
|
|
||||||
facts = search_memories(query or "", limit=top_k, min_relevance=0.0)
|
facts = search_memories(query or "", limit=top_k, min_relevance=0.0)
|
||||||
fact_entries = [e for e in facts if (e.context_type or "") == "fact"]
|
fact_entries = [e for e in facts if (e.context_type or "") == "fact"]
|
||||||
if fact_entries:
|
if fact_entries:
|
||||||
@@ -433,6 +438,7 @@ def memory_write(content: str, context_type: str = "fact") -> str:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from timmy.memory.vector_store import store_memory
|
from timmy.memory.vector_store import store_memory
|
||||||
|
|
||||||
entry = store_memory(
|
entry = store_memory(
|
||||||
content=content.strip(),
|
content=content.strip(),
|
||||||
source="agent",
|
source="agent",
|
||||||
|
|||||||
@@ -32,13 +32,15 @@ _TOOL_CALL_JSON = re.compile(
|
|||||||
|
|
||||||
# Tool invocations written out as plain text, e.g. memory_search(query="..."):
# a known tool name, optional whitespace, then one parenthesised argument list.
_FUNC_CALL_TEXT = re.compile(
    r"\b(?:memory_search|web_search|shell|python|read_file|write_file|list_files|calculator)"
    r"\s*\([^)]*\)"
)

# Whole lines of chain-of-thought narration that should stay internal.
# Each pattern is anchored per line, hence re.MULTILINE.
_COT_PATTERNS = [
    re.compile(line_pattern, re.MULTILINE)
    for line_pattern in (
        r"^(?:Since |Using |Let me |I'll use |I will use |Here's a possible ).*$",
        r"^(?:I found a relevant |This context suggests ).*$",
    )
]
|
||||||
|
|
||||||
@@ -48,6 +50,7 @@ def _get_agent():
|
|||||||
global _agent
|
global _agent
|
||||||
if _agent is None:
|
if _agent is None:
|
||||||
from timmy.agent import create_timmy
|
from timmy.agent import create_timmy
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_agent = create_timmy()
|
_agent = create_timmy()
|
||||||
logger.info("Session: Timmy agent initialized (singleton)")
|
logger.info("Session: Timmy agent initialized (singleton)")
|
||||||
@@ -99,6 +102,7 @@ def reset_session(session_id: Optional[str] = None) -> None:
|
|||||||
sid = session_id or _DEFAULT_SESSION_ID
|
sid = session_id or _DEFAULT_SESSION_ID
|
||||||
try:
|
try:
|
||||||
from timmy.conversation import conversation_manager
|
from timmy.conversation import conversation_manager
|
||||||
|
|
||||||
conversation_manager.clear_context(sid)
|
conversation_manager.clear_context(sid)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.debug("Session: context clear failed for %s: %s", sid, exc)
|
logger.debug("Session: context clear failed for %s: %s", sid, exc)
|
||||||
@@ -112,10 +116,12 @@ def _extract_facts(message: str) -> None:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from timmy.conversation import conversation_manager
|
from timmy.conversation import conversation_manager
|
||||||
|
|
||||||
name = conversation_manager.extract_user_name(message)
|
name = conversation_manager.extract_user_name(message)
|
||||||
if name:
|
if name:
|
||||||
try:
|
try:
|
||||||
from timmy.memory_system import memory_system
|
from timmy.memory_system import memory_system
|
||||||
|
|
||||||
memory_system.update_user_fact("Name", name)
|
memory_system.update_user_fact("Name", name)
|
||||||
logger.info("Session: Learned user name: %s", name)
|
logger.info("Session: Learned user name: %s", name)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ including any mistakes or errors that occur during the session."
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, date
|
from datetime import date, datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|||||||
@@ -75,6 +75,7 @@ Continue your train of thought."""
|
|||||||
@dataclass
|
@dataclass
|
||||||
class Thought:
|
class Thought:
|
||||||
"""A single thought in Timmy's inner stream."""
|
"""A single thought in Timmy's inner stream."""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
content: str
|
content: str
|
||||||
seed_type: str
|
seed_type: str
|
||||||
@@ -98,9 +99,7 @@ def _get_conn(db_path: Path = _DEFAULT_DB) -> sqlite3.Connection:
|
|||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
conn.execute(
|
conn.execute("CREATE INDEX IF NOT EXISTS idx_thoughts_time ON thoughts(created_at)")
|
||||||
"CREATE INDEX IF NOT EXISTS idx_thoughts_time ON thoughts(created_at)"
|
|
||||||
)
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
return conn
|
return conn
|
||||||
|
|
||||||
@@ -190,9 +189,7 @@ class ThinkingEngine:
|
|||||||
def get_thought(self, thought_id: str) -> Optional[Thought]:
    """Retrieve a single thought by ID.

    Args:
        thought_id: Value matched against the ``id`` column of ``thoughts``.

    Returns:
        The matching row converted with ``_row_to_thought``, or ``None`` when
        no row has that ID.
    """
    conn = _get_conn(self._db_path)
    try:
        row = conn.execute("SELECT * FROM thoughts WHERE id = ?", (thought_id,)).fetchone()
    finally:
        # Close the connection even if the query raises, so it is not leaked.
        conn.close()
    return _row_to_thought(row) if row else None
|
||||||
|
|
||||||
@@ -208,9 +205,7 @@ class ThinkingEngine:
|
|||||||
for _ in range(max_depth):
|
for _ in range(max_depth):
|
||||||
if not current_id:
|
if not current_id:
|
||||||
break
|
break
|
||||||
row = conn.execute(
|
row = conn.execute("SELECT * FROM thoughts WHERE id = ?", (current_id,)).fetchone()
|
||||||
"SELECT * FROM thoughts WHERE id = ?", (current_id,)
|
|
||||||
).fetchone()
|
|
||||||
if not row:
|
if not row:
|
||||||
break
|
break
|
||||||
chain.append(_row_to_thought(row))
|
chain.append(_row_to_thought(row))
|
||||||
@@ -254,8 +249,10 @@ class ThinkingEngine:
|
|||||||
def _seed_from_swarm(self) -> str:
|
def _seed_from_swarm(self) -> str:
|
||||||
"""Gather recent swarm activity as thought seed."""
|
"""Gather recent swarm activity as thought seed."""
|
||||||
try:
|
try:
|
||||||
from timmy.briefing import _gather_swarm_summary, _gather_task_queue_summary
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
|
from timmy.briefing import _gather_swarm_summary, _gather_task_queue_summary
|
||||||
|
|
||||||
since = datetime.now(timezone.utc) - timedelta(hours=1)
|
since = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||||
swarm = _gather_swarm_summary(since)
|
swarm = _gather_swarm_summary(since)
|
||||||
tasks = _gather_task_queue_summary()
|
tasks = _gather_task_queue_summary()
|
||||||
@@ -272,6 +269,7 @@ class ThinkingEngine:
|
|||||||
"""Gather memory context as thought seed."""
|
"""Gather memory context as thought seed."""
|
||||||
try:
|
try:
|
||||||
from timmy.memory_system import memory_system
|
from timmy.memory_system import memory_system
|
||||||
|
|
||||||
context = memory_system.get_system_context()
|
context = memory_system.get_system_context()
|
||||||
if context:
|
if context:
|
||||||
# Truncate to a reasonable size for a thought seed
|
# Truncate to a reasonable size for a thought seed
|
||||||
@@ -299,10 +297,12 @@ class ThinkingEngine:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from timmy.session import chat
|
from timmy.session import chat
|
||||||
|
|
||||||
return chat(prompt, session_id="thinking")
|
return chat(prompt, session_id="thinking")
|
||||||
except Exception:
|
except Exception:
|
||||||
# Fallback: create a fresh agent
|
# Fallback: create a fresh agent
|
||||||
from timmy.agent import create_timmy
|
from timmy.agent import create_timmy
|
||||||
|
|
||||||
agent = create_timmy()
|
agent = create_timmy()
|
||||||
run = agent.run(prompt, stream=False)
|
run = agent.run(prompt, stream=False)
|
||||||
return run.content if hasattr(run, "content") else str(run)
|
return run.content if hasattr(run, "content") else str(run)
|
||||||
@@ -323,8 +323,7 @@ class ThinkingEngine:
|
|||||||
INSERT INTO thoughts (id, content, seed_type, parent_id, created_at)
|
INSERT INTO thoughts (id, content, seed_type, parent_id, created_at)
|
||||||
VALUES (?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?)
|
||||||
""",
|
""",
|
||||||
(thought.id, thought.content, thought.seed_type,
|
(thought.id, thought.content, thought.seed_type, thought.parent_id, thought.created_at),
|
||||||
thought.parent_id, thought.created_at),
|
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
@@ -333,7 +332,8 @@ class ThinkingEngine:
|
|||||||
def _log_event(self, thought: Thought) -> None:
|
def _log_event(self, thought: Thought) -> None:
|
||||||
"""Log the thought as a swarm event."""
|
"""Log the thought as a swarm event."""
|
||||||
try:
|
try:
|
||||||
from swarm.event_log import log_event, EventType
|
from swarm.event_log import EventType, log_event
|
||||||
|
|
||||||
log_event(
|
log_event(
|
||||||
EventType.TIMMY_THOUGHT,
|
EventType.TIMMY_THOUGHT,
|
||||||
source="thinking-engine",
|
source="thinking-engine",
|
||||||
@@ -351,12 +351,16 @@ class ThinkingEngine:
|
|||||||
"""Broadcast the thought to WebSocket clients."""
|
"""Broadcast the thought to WebSocket clients."""
|
||||||
try:
|
try:
|
||||||
from infrastructure.ws_manager.handler import ws_manager
|
from infrastructure.ws_manager.handler import ws_manager
|
||||||
await ws_manager.broadcast("timmy_thought", {
|
|
||||||
"thought_id": thought.id,
|
await ws_manager.broadcast(
|
||||||
"content": thought.content,
|
"timmy_thought",
|
||||||
"seed_type": thought.seed_type,
|
{
|
||||||
"created_at": thought.created_at,
|
"thought_id": thought.id,
|
||||||
})
|
"content": thought.content,
|
||||||
|
"seed_type": thought.seed_type,
|
||||||
|
"created_at": thought.created_at,
|
||||||
|
},
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.debug("Failed to broadcast thought: %s", exc)
|
logger.debug("Failed to broadcast thought: %s", exc)
|
||||||
|
|
||||||
|
|||||||
@@ -227,11 +227,7 @@ def create_aider_tool(base_path: Path):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if result.returncode == 0:
|
if result.returncode == 0:
|
||||||
return (
|
return result.stdout if result.stdout else "Code changes applied successfully"
|
||||||
result.stdout
|
|
||||||
if result.stdout
|
|
||||||
else "Code changes applied successfully"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return f"Aider error: {result.stderr}"
|
return f"Aider error: {result.stderr}"
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
@@ -354,7 +350,7 @@ def consult_grok(query: str) -> str:
|
|||||||
Grok's response text, or an error/status message.
|
Grok's response text, or an error/status message.
|
||||||
"""
|
"""
|
||||||
from config import settings
|
from config import settings
|
||||||
from timmy.backends import grok_available, get_grok_backend
|
from timmy.backends import get_grok_backend, grok_available
|
||||||
|
|
||||||
if not grok_available():
|
if not grok_available():
|
||||||
return (
|
return (
|
||||||
@@ -385,9 +381,7 @@ def consult_grok(query: str) -> str:
|
|||||||
ln = get_ln_backend()
|
ln = get_ln_backend()
|
||||||
sats = min(settings.grok_max_sats_per_query, 100)
|
sats = min(settings.grok_max_sats_per_query, 100)
|
||||||
inv = ln.create_invoice(sats, f"Grok query: {query[:50]}")
|
inv = ln.create_invoice(sats, f"Grok query: {query[:50]}")
|
||||||
invoice_info = (
|
invoice_info = f"\n[Lightning invoice: {sats} sats — {inv.payment_request[:40]}...]"
|
||||||
f"\n[Lightning invoice: {sats} sats — {inv.payment_request[:40]}...]"
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -447,7 +441,7 @@ def create_full_toolkit(base_dir: str | Path | None = None):
|
|||||||
|
|
||||||
# Memory search and write — persistent recall across all channels
|
# Memory search and write — persistent recall across all channels
|
||||||
try:
|
try:
|
||||||
from timmy.semantic_memory import memory_search, memory_write, memory_read
|
from timmy.semantic_memory import memory_read, memory_search, memory_write
|
||||||
|
|
||||||
toolkit.register(memory_search, name="memory_search")
|
toolkit.register(memory_search, name="memory_search")
|
||||||
toolkit.register(memory_write, name="memory_write")
|
toolkit.register(memory_write, name="memory_write")
|
||||||
@@ -473,6 +467,7 @@ def create_full_toolkit(base_dir: str | Path | None = None):
|
|||||||
Task ID and confirmation that background execution has started.
|
Task ID and confirmation that background execution has started.
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
task_id = None
|
task_id = None
|
||||||
|
|
||||||
async def _launch():
|
async def _launch():
|
||||||
@@ -502,11 +497,7 @@ def create_full_toolkit(base_dir: str | Path | None = None):
|
|||||||
|
|
||||||
# System introspection - query runtime environment (sovereign self-knowledge)
|
# System introspection - query runtime environment (sovereign self-knowledge)
|
||||||
try:
|
try:
|
||||||
from timmy.tools_intro import (
|
from timmy.tools_intro import check_ollama_health, get_memory_status, get_system_info
|
||||||
get_system_info,
|
|
||||||
check_ollama_health,
|
|
||||||
get_memory_status,
|
|
||||||
)
|
|
||||||
|
|
||||||
toolkit.register(get_system_info, name="get_system_info")
|
toolkit.register(get_system_info, name="get_system_info")
|
||||||
toolkit.register(check_ollama_health, name="check_ollama_health")
|
toolkit.register(check_ollama_health, name="check_ollama_health")
|
||||||
@@ -526,6 +517,60 @@ def create_full_toolkit(base_dir: str | Path | None = None):
|
|||||||
return toolkit
|
return toolkit
|
||||||
|
|
||||||
|
|
||||||
|
def create_experiment_tools(base_dir: str | Path | None = None):
    """Build the toolkit used by the experiment agent (Lab).

    Registers, in order: the autoresearch wrappers ``prepare_experiment``,
    ``run_experiment`` and ``evaluate_result``, then a ``shell`` tool, then
    ``read_file`` / ``write_file`` / ``list_files`` for editing training code.

    Args:
        base_dir: Optional root directory. When given, it is used both as the
            experiment workspace and as the file-tool base directory.
            Otherwise the workspace is
            ``settings.repo_root / settings.autoresearch_workspace`` and the
            file tools are rooted at ``settings.repo_root``.

    Returns:
        The populated toolkit named "experiment".

    Raises:
        ImportError: If the Agno tool classes are not available.
    """
    if not _AGNO_TOOLS_AVAILABLE:
        raise ImportError(f"Agno tools not available: {_ImportError}")

    from config import settings

    toolkit = Toolkit(name="experiment")

    from timmy.autoresearch import evaluate_result, prepare_experiment, run_experiment

    if base_dir:
        workspace = Path(base_dir)
    else:
        workspace = Path(settings.repo_root) / settings.autoresearch_workspace

    # NOTE: the wrapper names and docstrings below are kept verbatim; they are
    # handed to the toolkit together with the callables.
    def _prepare(repo_url: str = "https://github.com/karpathy/autoresearch.git") -> str:
        """Clone and prepare an autoresearch experiment workspace."""
        return prepare_experiment(workspace, repo_url)

    def _run(timeout: int = 0) -> str:
        """Run a single training experiment with wall-clock timeout."""
        # A timeout of 0 (the default) falls back to the configured time budget.
        budget = timeout or settings.autoresearch_time_budget
        outcome = run_experiment(
            workspace, timeout=budget, metric_name=settings.autoresearch_metric
        )
        succeeded = outcome["success"] and outcome["metric"] is not None
        if not succeeded:
            return outcome.get("error") or "Experiment failed"
        return f"{settings.autoresearch_metric}: {outcome['metric']:.4f} ({outcome['duration_s']}s)"

    def _evaluate(current: float, baseline: float) -> str:
        """Compare current metric against baseline."""
        return evaluate_result(current, baseline, metric_name=settings.autoresearch_metric)

    for func, tool_name in (
        (_prepare, "prepare_experiment"),
        (_run, "run_experiment"),
        (_evaluate, "evaluate_result"),
    ):
        toolkit.register(func, name=tool_name)

    # Lab also needs shell and file access so it can edit train.py.
    shell_tools = ShellTools()
    toolkit.register(shell_tools.run_shell_command, name="shell")

    base_path = Path(base_dir) if base_dir else Path(settings.repo_root)
    file_tools = FileTools(base_dir=base_path)
    for method, tool_name in (
        (file_tools.read_file, "read_file"),
        (file_tools.save_file, "write_file"),
        (file_tools.list_files, "list_files"),
    ):
        toolkit.register(method, name=tool_name)

    return toolkit
|
||||||
|
|
||||||
|
|
||||||
# Mapping of agent IDs to their toolkits
|
# Mapping of agent IDs to their toolkits
|
||||||
AGENT_TOOLKITS: dict[str, Callable[[], Toolkit]] = {
|
AGENT_TOOLKITS: dict[str, Callable[[], Toolkit]] = {
|
||||||
"echo": create_research_tools,
|
"echo": create_research_tools,
|
||||||
@@ -534,6 +579,7 @@ AGENT_TOOLKITS: dict[str, Callable[[], Toolkit]] = {
|
|||||||
"seer": create_data_tools,
|
"seer": create_data_tools,
|
||||||
"forge": create_code_tools,
|
"forge": create_code_tools,
|
||||||
"quill": create_writing_tools,
|
"quill": create_writing_tools,
|
||||||
|
"lab": create_experiment_tools,
|
||||||
"pixel": lambda base_dir=None: _create_stub_toolkit("pixel"),
|
"pixel": lambda base_dir=None: _create_stub_toolkit("pixel"),
|
||||||
"lyra": lambda base_dir=None: _create_stub_toolkit("lyra"),
|
"lyra": lambda base_dir=None: _create_stub_toolkit("lyra"),
|
||||||
"reel": lambda base_dir=None: _create_stub_toolkit("reel"),
|
"reel": lambda base_dir=None: _create_stub_toolkit("reel"),
|
||||||
@@ -553,9 +599,7 @@ def _create_stub_toolkit(name: str):
|
|||||||
return toolkit
|
return toolkit
|
||||||
|
|
||||||
|
|
||||||
def get_tools_for_agent(
|
def get_tools_for_agent(agent_id: str, base_dir: str | Path | None = None) -> Toolkit | None:
|
||||||
agent_id: str, base_dir: str | Path | None = None
|
|
||||||
) -> Toolkit | None:
|
|
||||||
"""Get the appropriate toolkit for an agent.
|
"""Get the appropriate toolkit for an agent.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -643,6 +687,21 @@ def get_all_available_tools() -> dict[str, dict]:
|
|||||||
"description": "Local AI coding assistant using Ollama (qwen2.5:14b or deepseek-coder)",
|
"description": "Local AI coding assistant using Ollama (qwen2.5:14b or deepseek-coder)",
|
||||||
"available_in": ["forge", "orchestrator"],
|
"available_in": ["forge", "orchestrator"],
|
||||||
},
|
},
|
||||||
|
"prepare_experiment": {
|
||||||
|
"name": "Prepare Experiment",
|
||||||
|
"description": "Clone autoresearch repo and run data preparation for ML experiments",
|
||||||
|
"available_in": ["lab", "orchestrator"],
|
||||||
|
},
|
||||||
|
"run_experiment": {
|
||||||
|
"name": "Run Experiment",
|
||||||
|
"description": "Execute a time-boxed ML training experiment and capture metrics",
|
||||||
|
"available_in": ["lab", "orchestrator"],
|
||||||
|
},
|
||||||
|
"evaluate_result": {
|
||||||
|
"name": "Evaluate Result",
|
||||||
|
"description": "Compare experiment metric against baseline to assess improvement",
|
||||||
|
"available_in": ["lab", "orchestrator"],
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# ── Git tools ─────────────────────────────────────────────────────────────
|
# ── Git tools ─────────────────────────────────────────────────────────────
|
||||||
|
|||||||
@@ -20,7 +20,9 @@ _VALID_AGENTS: dict[str, str] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def delegate_task(agent_name: str, task_description: str, priority: str = "normal") -> dict[str, Any]:
|
def delegate_task(
|
||||||
|
agent_name: str, task_description: str, priority: str = "normal"
|
||||||
|
) -> dict[str, Any]:
|
||||||
"""Record a delegation intent to another agent.
|
"""Record a delegation intent to another agent.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -44,7 +46,9 @@ def delegate_task(agent_name: str, task_description: str, priority: str = "norma
|
|||||||
if priority not in valid_priorities:
|
if priority not in valid_priorities:
|
||||||
priority = "normal"
|
priority = "normal"
|
||||||
|
|
||||||
logger.info("Delegation intent: %s → %s (priority=%s)", agent_name, task_description[:80], priority)
|
logger.info(
|
||||||
|
"Delegation intent: %s → %s (priority=%s)", agent_name, task_description[:80], priority
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
|
|||||||
@@ -65,9 +65,7 @@ def _get_ollama_model() -> str:
|
|||||||
models = response.json().get("models", [])
|
models = response.json().get("models", [])
|
||||||
# Check if configured model is available
|
# Check if configured model is available
|
||||||
for model in models:
|
for model in models:
|
||||||
if model.get("name", "").startswith(
|
if model.get("name", "").startswith(settings.ollama_model.split(":")[0]):
|
||||||
settings.ollama_model.split(":")[0]
|
|
||||||
):
|
|
||||||
return settings.ollama_model
|
return settings.ollama_model
|
||||||
|
|
||||||
# Fallback: return configured model
|
# Fallback: return configured model
|
||||||
@@ -139,9 +137,7 @@ def get_memory_status() -> dict[str, Any]:
|
|||||||
if tier1_exists:
|
if tier1_exists:
|
||||||
lines = memory_md.read_text().splitlines()
|
lines = memory_md.read_text().splitlines()
|
||||||
tier1_info["line_count"] = len(lines)
|
tier1_info["line_count"] = len(lines)
|
||||||
tier1_info["sections"] = [
|
tier1_info["sections"] = [ln.lstrip("# ").strip() for ln in lines if ln.startswith("## ")]
|
||||||
ln.lstrip("# ").strip() for ln in lines if ln.startswith("## ")
|
|
||||||
]
|
|
||||||
|
|
||||||
# Vault — scan all subdirs under memory/
|
# Vault — scan all subdirs under memory/
|
||||||
vault_root = repo_root / "memory"
|
vault_root = repo_root / "memory"
|
||||||
@@ -233,13 +229,15 @@ def get_agent_roster() -> dict[str, Any]:
|
|||||||
|
|
||||||
roster = []
|
roster = []
|
||||||
for persona in _PERSONAS:
|
for persona in _PERSONAS:
|
||||||
roster.append({
|
roster.append(
|
||||||
"id": persona["agent_id"],
|
{
|
||||||
"name": persona["name"],
|
"id": persona["agent_id"],
|
||||||
"status": "available",
|
"name": persona["name"],
|
||||||
"capabilities": ", ".join(persona.get("tools", [])),
|
"status": "available",
|
||||||
"role": persona.get("role", ""),
|
"capabilities": ", ".join(persona.get("tools", [])),
|
||||||
})
|
"role": persona.get("role", ""),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"agents": roster,
|
"agents": roster,
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ class StatusResponse(BaseModel):
|
|||||||
|
|
||||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||||
"""Simple in-memory rate limiting middleware."""
|
"""Simple in-memory rate limiting middleware."""
|
||||||
|
|
||||||
def __init__(self, app, limit: int = 10, window: int = 60):
|
def __init__(self, app, limit: int = 10, window: int = 60):
|
||||||
super().__init__(app)
|
super().__init__(app)
|
||||||
self.limit = limit
|
self.limit = limit
|
||||||
@@ -53,22 +53,20 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
|||||||
if request.url.path == "/serve/chat" and request.method == "POST":
|
if request.url.path == "/serve/chat" and request.method == "POST":
|
||||||
client_ip = request.client.host if request.client else "unknown"
|
client_ip = request.client.host if request.client else "unknown"
|
||||||
now = time.time()
|
now = time.time()
|
||||||
|
|
||||||
# Clean up old requests
|
# Clean up old requests
|
||||||
self.requests[client_ip] = [
|
self.requests[client_ip] = [
|
||||||
t for t in self.requests[client_ip]
|
t for t in self.requests[client_ip] if now - t < self.window
|
||||||
if now - t < self.window
|
|
||||||
]
|
]
|
||||||
|
|
||||||
if len(self.requests[client_ip]) >= self.limit:
|
if len(self.requests[client_ip]) >= self.limit:
|
||||||
logger.warning("Rate limit exceeded for %s", client_ip)
|
logger.warning("Rate limit exceeded for %s", client_ip)
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=429,
|
status_code=429, content={"error": "Rate limit exceeded. Try again later."}
|
||||||
content={"error": "Rate limit exceeded. Try again later."}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.requests[client_ip].append(now)
|
self.requests[client_ip].append(now)
|
||||||
|
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ def start(
|
|||||||
return
|
return
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
from timmy_serve.app import create_timmy_serve_app
|
from timmy_serve.app import create_timmy_serve_app
|
||||||
|
|
||||||
serve_app = create_timmy_serve_app()
|
serve_app = create_timmy_serve_app()
|
||||||
|
|||||||
@@ -23,9 +23,7 @@ class AgentMessage:
|
|||||||
to_agent: str = ""
|
to_agent: str = ""
|
||||||
content: str = ""
|
content: str = ""
|
||||||
message_type: str = "text" # text | command | response | error
|
message_type: str = "text" # text | command | response | error
|
||||||
timestamp: str = field(
|
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
|
||||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
|
||||||
)
|
|
||||||
replied: bool = False
|
replied: bool = False
|
||||||
|
|
||||||
|
|
||||||
@@ -56,7 +54,10 @@ class InterAgentMessenger:
|
|||||||
self._all_messages.append(msg)
|
self._all_messages.append(msg)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Message %s → %s: %s (%s)",
|
"Message %s → %s: %s (%s)",
|
||||||
from_agent, to_agent, content[:50], message_type,
|
from_agent,
|
||||||
|
to_agent,
|
||||||
|
content[:50],
|
||||||
|
message_type,
|
||||||
)
|
)
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ class VoiceTTS:
|
|||||||
def _init_engine(self) -> None:
|
def _init_engine(self) -> None:
|
||||||
try:
|
try:
|
||||||
import pyttsx3
|
import pyttsx3
|
||||||
|
|
||||||
self._engine = pyttsx3.init()
|
self._engine = pyttsx3.init()
|
||||||
self._engine.setProperty("rate", self._rate)
|
self._engine.setProperty("rate", self._rate)
|
||||||
self._engine.setProperty("volume", self._volume)
|
self._engine.setProperty("volume", self._volume)
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user