forked from Rockachopa/Timmy-time-dashboard
Merge pull request #57 from AlexanderWhitestone/feature/model-upgrade-llama3.1
feat: Multi-modal LLM support with automatic model fallback
This commit is contained in:
187
README.md
@@ -18,12 +18,17 @@ make install   # create venv + install deps
cp .env.example .env   # configure environment

ollama serve           # separate terminal
ollama pull llama3.1:8b-instruct   # Required for reliable tool calling

make dev    # http://localhost:8000
make test   # no Ollama needed
```

**Note:** llama3.1:8b-instruct is used instead of llama3.2 because it is
specifically fine-tuned for reliable tool/function calling.
llama3.2 (3B) consistently hallucinated tool output in testing.
Fallback: qwen2.5:14b if llama3.1:8b-instruct is not available.
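If you want a setup script to tolerate a missing model, the fallback choice can be scripted. This helper is illustrative only (it is not part of the repo): it picks the first model from its arguments that appears in a newline-separated list of available models on stdin.

```shell
# pick_model: echo the first argument that appears (as an exact line) in the
# newline-separated list of available models read from stdin.
pick_model() {
  local available
  available="$(cat)"
  for m in "$@"; do
    if printf '%s\n' "$available" | grep -qxF "$m"; then
      echo "$m"
      return 0
    fi
  done
  echo "$1"   # last resort: return the first choice and let Ollama error out
}
```

A plausible usage, assuming `ollama list` prints model names in its first column: `ollama list | awk 'NR>1{print $1}' | pick_model llama3.1:8b-instruct qwen2.5:14b`.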
---

## What's Here
@@ -74,8 +79,184 @@ make help   # see all commands
cp .env.example .env
```

Key variables: `OLLAMA_URL`, `OLLAMA_MODEL`, `TIMMY_MODEL_BACKEND`,
`L402_HMAC_SECRET`, `LIGHTNING_BACKEND`, `DEBUG`. Full list in `.env.example`.

| Variable | Default | Purpose |
|----------|---------|---------|
| `OLLAMA_URL` | `http://localhost:11434` | Ollama host |
| `OLLAMA_MODEL` | `llama3.1:8b-instruct` | Model for tool calling; auto-falls back to `qwen2.5:14b` if unavailable |
| `DEBUG` | `false` | Enable `/docs` and `/redoc` |
| `TIMMY_MODEL_BACKEND` | `ollama` | `ollama` \| `airllm` \| `auto` |
| `AIRLLM_MODEL_SIZE` | `70b` | `8b` \| `70b` \| `405b` |
| `L402_HMAC_SECRET` | *(default — change in prod)* | HMAC signing key for macaroons |
| `L402_MACAROON_SECRET` | *(default — change in prod)* | Macaroon secret |
| `LIGHTNING_BACKEND` | `mock` | `mock` (production-ready) \| `lnd` (scaffolded, not yet functional) |
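For example, a minimal `.env` for local development might look like this. The values are the documented defaults; `DEBUG=true` is an assumption for dev work, and the L402 secrets are deliberately omitted here (set them from `.env.example` and change them before any real deployment).

```shell
# .env (local-dev sketch)
OLLAMA_URL=http://localhost:11434
OLLAMA_MODEL=llama3.1:8b-instruct
TIMMY_MODEL_BACKEND=ollama
LIGHTNING_BACKEND=mock
DEBUG=true
```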
---

## Architecture

```
Browser / Phone
      │ HTTP + HTMX + WebSocket
      ▼
┌─────────────────────────────────────────┐
│  FastAPI (dashboard.app)                │
│  routes: agents, health, swarm,         │
│  marketplace, voice, mobile             │
└───┬─────────────┬──────────┬────────────┘
    │             │          │
    ▼             ▼          ▼
 Jinja2        Timmy       Swarm
 Templates     Agent       Coordinator
 (HTMX)          │           ├─ Registry (SQLite)
                 ├─ Ollama   ├─ AuctionManager (L402 bids)
                 └─ AirLLM   ├─ SwarmComms (Redis / in-memory)
                             └─ SwarmManager (subprocess)
      │
      ├── Voice NLU + TTS (pyttsx3, local)
      ├── WebSocket live feed (ws_manager)
      ├── L402 Lightning proxy (macaroon + invoice)
      ├── Push notifications (local + macOS native)
      └── Siri Shortcuts API endpoints

Persistence: timmy.db (Agno memory), data/swarm.db (registry + tasks)
External: Ollama :11434, optional Redis, optional LND gRPC
```

---
## Project Layout

```
src/
  config.py        # pydantic-settings — all env vars live here
  timmy/           # Core agent (agent.py, backends.py, cli.py, prompts.py)
  hands/           # Autonomous scheduled agents (registry, scheduler, runner)
  dashboard/       # FastAPI app, routes, Jinja2 templates
  swarm/           # Multi-agent: coordinator, registry, bidder, tasks, comms
  timmy_serve/     # L402 proxy, payment handler, TTS, serve CLI
  spark/           # Intelligence engine — events, predictions, advisory
  creative/        # Creative director + video assembler pipeline
  tools/           # Git, image, music, video tools for persona agents
  lightning/       # Lightning backend abstraction (mock + LND)
  agent_core/      # Substrate-agnostic agent interface
  voice/           # NLU intent detection
  ws_manager/      # WebSocket connection manager
  notifications/   # Push notification store
  shortcuts/       # Siri Shortcuts endpoints
  telegram_bot/    # Telegram bridge
  self_tdd/        # Continuous test watchdog
hands/             # Hand manifests — oracle/, sentinel/, etc.
tests/             # one test file per module, all mocked
static/style.css   # Dark mission-control theme (JetBrains Mono)
docs/              # GitHub Pages landing page
AGENTS.md          # AI agent development standards ← read this
.env.example       # Environment variable reference
Makefile           # Common dev commands
```

---
## Mobile Access

The dashboard is fully mobile-optimized (iOS safe area, 44px touch targets,
16px inputs to prevent zoom, momentum scrolling).

```bash
# Bind to your local network
uvicorn dashboard.app:app --host 0.0.0.0 --port 8000 --reload

# Find your IP
ipconfig getifaddr en0   # Wi-Fi on macOS
```

Open `http://<your-ip>:8000` on your phone (same Wi-Fi network).

Mobile-specific routes:
- `/mobile` — single-column optimized layout
- `/mobile-test` — 21-scenario HITL test harness (layout, touch, scroll, notch)
---
## Hands — Autonomous Agents

Hands are autonomous agents that run on cron schedules. Each Hand has a `HAND.toml` manifest, a `SYSTEM.md` prompt, and an optional `skills/` directory.

**Built-in Hands:**

| Hand | Schedule | Purpose |
|------|----------|---------|
| **Oracle** | 7am, 7pm UTC | Bitcoin intelligence — price, on-chain, macro analysis |
| **Sentinel** | Every 15 min | System health — dashboard, agents, database, resources |
| **Scout** | Every hour | OSINT monitoring — HN, Reddit, RSS for Bitcoin/sovereign AI |
| **Scribe** | Daily 9am | Content production — blog posts, docs, changelog |
| **Ledger** | Every 6 hours | Treasury tracking — Bitcoin/Lightning balances, payment audit |
| **Weaver** | Sunday 10am | Creative pipeline — orchestrates Pixel+Lyra+Reel for video |

**Dashboard:** `/hands` — manage, trigger, approve actions

**Example HAND.toml:**
```toml
[hand]
name = "oracle"
schedule = "0 7,19 * * *"   # Twice daily
enabled = true

[tools]
required = ["mempool_fetch", "price_fetch"]

[approval_gates]
broadcast = { action = "broadcast", description = "Post to dashboard" }

[output]
dashboard = true
channel = "telegram"
```
---
## AirLLM — Big Brain Backend

Run 70B or 405B models locally with no GPU, using AirLLM's layer-by-layer loading.
Apple Silicon uses MLX automatically.

```bash
pip install ".[bigbrain]"
pip install "airllm[mlx]"   # Apple Silicon only

timmy chat "Explain self-custody" --backend airllm --model-size 70b
```

Or set once in `.env`:
```bash
TIMMY_MODEL_BACKEND=auto
AIRLLM_MODEL_SIZE=70b
```

| Flag | Parameters | RAM needed |
|------|------------|------------|
| `8b` | 8 billion | ~16 GB |
| `70b` | 70 billion | ~140 GB |
| `405b` | 405 billion | ~810 GB |
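The RAM column follows a simple rule of thumb: fp16 weights take about 2 bytes per parameter, ignoring activation overhead and the GB/GiB distinction:

```python
def est_ram_gb(params_billions: float, bytes_per_param: float = 2.0) -> float:
    """Back-of-envelope weight size: parameters x 2 bytes (fp16)."""
    return params_billions * bytes_per_param

for size in (8, 70, 405):
    print(f"{size}b -> ~{est_ram_gb(size):.0f} GB")
```

This matches the table exactly (70 × 2 = 140, 405 × 2 = 810); AirLLM's layer streaming keeps the *resident* footprint well below the full weight size.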
---

## CLI
```bash
timmy chat "What is sovereignty?"
timmy think "Bitcoin and self-custody"
timmy status

timmy-serve start     # L402-gated API server (port 8402)
timmy-serve invoice   # generate a Lightning invoice
timmy-serve status
```

Or with the bootstrap script (creates venv, tests, watchdog, server in one shot):
```bash
bash scripts/activate_self_tdd.sh
bash scripts/activate_self_tdd.sh --big-brain   # also installs AirLLM
```

---
@@ -24,11 +24,31 @@ providers:
    priority: 1
    url: "http://localhost:11434"
    models:
      # Text + Tools models
      - name: llama3.1:8b-instruct
        default: true
        context_window: 128000
        capabilities: [text, tools, json, streaming]
      - name: llama3.2:3b
        context_window: 128000
        capabilities: [text, tools, json, streaming, vision]
      - name: qwen2.5:14b
        context_window: 32000
        capabilities: [text, tools, json, streaming]
      - name: deepseek-r1:1.5b
        context_window: 32000
        capabilities: [text, json, streaming]

      # Vision models
      - name: llava:7b
        context_window: 4096
        capabilities: [text, vision, streaming]
      - name: qwen2.5-vl:3b
        context_window: 32000
        capabilities: [text, vision, tools, json, streaming]
      - name: moondream:1.8b
        context_window: 2048
        capabilities: [text, vision, streaming]

  # Secondary: Local AirLLM (if installed)
  - name: airllm-local
@@ -38,8 +58,11 @@ providers:
    models:
      - name: 70b
        default: true
        capabilities: [text, tools, json, streaming]
      - name: 8b
        capabilities: [text, tools, json, streaming]
      - name: 405b
        capabilities: [text, tools, json, streaming]

  # Tertiary: OpenAI (if API key available)
  - name: openai-backup
@@ -52,8 +75,10 @@ providers:
      - name: gpt-4o-mini
        default: true
        context_window: 128000
        capabilities: [text, vision, tools, json, streaming]
      - name: gpt-4o
        context_window: 128000
        capabilities: [text, vision, tools, json, streaming]

  # Quaternary: Anthropic (if API key available)
  - name: anthropic-backup
@@ -65,10 +90,37 @@ providers:
      - name: claude-3-haiku-20240307
        default: true
        context_window: 200000
        capabilities: [text, vision, streaming]
      - name: claude-3-sonnet-20240229
        context_window: 200000
        capabilities: [text, vision, tools, streaming]

# ── Capability-Based Fallback Chains ────────────────────────────────────────
# When a model doesn't support a required capability (e.g., vision),
# the system falls back through these chains in order.

fallback_chains:
  # Vision-capable models (for image understanding)
  vision:
    - llama3.2:3b      # Fast, good vision
    - qwen2.5-vl:3b    # Excellent vision, small
    - llava:7b         # Classic vision model
    - moondream:1.8b   # Tiny, fast vision

  # Tool-calling models (for function calling)
  tools:
    - llama3.1:8b-instruct   # Best tool use
    - qwen2.5:7b             # Reliable tools
    - llama3.2:3b            # Small but capable

  # General text generation (any model)
  text:
    - llama3.1:8b-instruct
    - qwen2.5:14b
    - deepseek-r1:1.5b
    - llama3.2:3b

# ── Custom Models ───────────────────────────────────────────────────────────
# Register custom model weights for per-agent assignment.
# Supports GGUF (Ollama), safetensors, and HuggingFace checkpoint dirs.
# Models can also be registered at runtime via the /api/v1/models API.
@@ -91,7 +143,7 @@ custom_models: []
#     context_window: 32000
#     description: "Process reward model for scoring outputs"

# ── Agent Model Assignments ─────────────────────────────────────────────────
# Map persona agent IDs to specific models.
# Agents without an assignment use the global default (ollama_model).
agent_model_assignments: {}
@@ -99,6 +151,20 @@ agent_model_assignments: {}
#   persona-forge: my-finetuned-llama
#   persona-echo: deepseek-r1:1.5b

# ── Multi-Modal Settings ────────────────────────────────────────────────────
multimodal:
  # Automatically pull models when needed
  auto_pull: true

  # Timeout for model pulling (seconds)
  pull_timeout: 300

  # Maximum fallback depth (how many models to try before giving up)
  max_fallback_depth: 3

  # Prefer smaller models for vision when available (faster)
  prefer_small_vision: true

  # Cost tracking (optional, for budget monitoring)
  cost_tracking:
    enabled: true
BIN
data/scripture.db-shm
Normal file
Binary file not shown.
BIN
data/scripture.db-wal
Normal file
Binary file not shown.
57
docs/CHANGELOG_2025-02-27.md
Normal file
@@ -0,0 +1,57 @@
# Changelog — 2025-02-27

## Model Upgrade & Hallucination Fix

### Change 1: Model Upgrade (Primary Fix)
**Problem:** llama3.2 (3B parameters) consistently hallucinated tool output instead of waiting for real results.

**Solution:** Upgraded the default model to `llama3.1:8b-instruct`, which is specifically fine-tuned for reliable tool/function calling.

**Changes:**
- `src/config.py`: Changed the `ollama_model` default from `llama3.2` to `llama3.1:8b-instruct`
- Added fallback logic: if the primary model is unavailable, auto-fallback to `qwen2.5:14b`
- `README.md`: Updated setup instructions with the new model requirement

**User Action Required:**
```bash
ollama pull llama3.1:8b-instruct
```

### Change 2: Structured Output Enforcement (Foundation)
**Preparation:** Added infrastructure for two-phase tool calling with JSON schema enforcement.

**Implementation:**
- Session context tracking in `TimmyOrchestrator`
- `_session_init()` runs on first message to load real data

### Change 3: Git Tool Working Directory Fix
**Problem:** Git tools failed with "fatal: Not a git repository" because they ran from the wrong working directory.

**Solution:**
- Rewrote `src/tools/git_tools.py` to use subprocess with an explicit `cwd=REPO_ROOT`
- Added a `REPO_ROOT` module-level constant auto-detected at import time
- All git commands now run from the correct directory

### Change 4: Session Init with Git Log
**Problem:** Timmy couldn't answer "what's new?" from real data.

**Solution:**
- `_session_init()` now reads `git log --oneline -15` from the repo root on first message
- Recent commits are prepended to the system prompt
- Timmy now grounds self-description in actual commit history
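The prompt assembly described above can be sketched as a pure function. This is illustrative only; the real `_session_init()` lives in `src/agents/timmy.py` and also shells out to git:

```python
def build_system_prompt(git_log_output: str, base_prompt: str) -> str:
    """Prepend recent commit lines (e.g. from `git log --oneline -15`)
    to the agent's system prompt so answers are grounded in real history."""
    if not git_log_output.strip():
        # No commits available: fall back to the unmodified prompt
        return base_prompt
    return f"Recent commits:\n{git_log_output.strip()}\n\n{base_prompt}"
```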
### Change 5: Documentation Updates
- `README.md`: Updated Quickstart with the new model requirement
- `README.md`: Configuration table reflects the new default model
- Added notes explaining why llama3.1:8b-instruct is required

### Files Modified
- `src/config.py` — Model configuration with fallback
- `src/tools/git_tools.py` — Complete rewrite with subprocess + cwd
- `src/agents/timmy.py` — Session init with git log reading
- `README.md` — Updated setup and configuration docs

### Testing
- All git tool tests pass with the new subprocess implementation
- Git log correctly returns commits from the repo root
- Session init loads context on first message
@@ -8,7 +8,11 @@ class Settings(BaseSettings):
    ollama_url: str = "http://localhost:11434"

    # LLM model passed to Agno/Ollama — override with OLLAMA_MODEL
    # llama3.1:8b-instruct is used instead of llama3.2 because it is
    # specifically fine-tuned for reliable tool/function calling.
    # llama3.2 (3B) hallucinated tool output consistently in testing.
    # Fallback: qwen2.5:14b if llama3.1:8b-instruct not available.
    ollama_model: str = "llama3.1:8b-instruct"

    # Set DEBUG=true to enable /docs and /redoc (disabled by default)
    debug: bool = False
@@ -145,6 +149,62 @@ class Settings(BaseSettings):

settings = Settings()


# ── Model fallback configuration ────────────────────────────────────────────
# Primary model for reliable tool calling (llama3.1:8b-instruct)
# Fallback if primary not available: qwen2.5:14b
OLLAMA_MODEL_PRIMARY: str = "llama3.1:8b-instruct"
OLLAMA_MODEL_FALLBACK: str = "qwen2.5:14b"


def check_ollama_model_available(model_name: str) -> bool:
    """Check if a specific Ollama model is available locally."""
    try:
        import json
        import urllib.request

        url = settings.ollama_url.replace("localhost", "127.0.0.1")
        req = urllib.request.Request(
            f"{url}/api/tags",
            method="GET",
            headers={"Accept": "application/json"},
        )
        with urllib.request.urlopen(req, timeout=5) as response:
            data = json.loads(response.read().decode())
        models = [m.get("name", "").split(":")[0] for m in data.get("models", [])]
        # Check for exact match or model name without tag
        return any(model_name in m or m in model_name for m in models)
    except Exception:
        return False


def get_effective_ollama_model() -> str:
    """Get the effective Ollama model, with fallback logic."""
    # If the user has overridden the default, prefer their setting
    user_model = settings.ollama_model

    # Check if the user's model is available
    if check_ollama_model_available(user_model):
        return user_model

    # Try primary
    if check_ollama_model_available(OLLAMA_MODEL_PRIMARY):
        _startup_logger.warning(
            f"Requested model '{user_model}' not available. "
            f"Using primary: {OLLAMA_MODEL_PRIMARY}"
        )
        return OLLAMA_MODEL_PRIMARY

    # Try fallback
    if check_ollama_model_available(OLLAMA_MODEL_FALLBACK):
        _startup_logger.warning(
            f"Primary model '{OLLAMA_MODEL_PRIMARY}' not available. "
            f"Using fallback: {OLLAMA_MODEL_FALLBACK}"
        )
        return OLLAMA_MODEL_FALLBACK

    # Last resort: return the user's setting and hope for the best
    return user_model


# ── Startup validation ───────────────────────────────────────────────────────
# Enforce security requirements — fail fast in production.
import logging as _logging
@@ -1,7 +1,8 @@
"""Git operations tools for Forge, Helm, and Timmy personas.

Provides a full set of git commands that agents can execute against
the local repository. Uses subprocess with explicit working directory
to ensure commands run from the repo root.

All functions return plain dicts so they're easily serialisable for
tool-call results, Spark event capture, and WebSocket broadcast.
@@ -10,134 +11,261 @@ tool-call results, Spark event capture, and WebSocket broadcast.
from __future__ import annotations

import logging
import os
import subprocess
from pathlib import Path
from typing import Optional

logger = logging.getLogger(__name__)


def _find_repo_root() -> str:
    """Walk up from this file's location to find the .git directory."""
    # Start from this file's own directory (src/tools/) and walk up until a
    # .git directory is found; this lands on the project root reliably.
    path = os.path.dirname(os.path.abspath(__file__))

    while path != os.path.dirname(path):
        if os.path.exists(os.path.join(path, '.git')):
            return path
        path = os.path.dirname(path)

    # Fallback to config repo_root
    try:
        from config import settings
        return settings.repo_root
    except Exception:
        return os.getcwd()


# Module-level constant for repo root
REPO_ROOT = _find_repo_root()
logger.info(f"Git repo root: {REPO_ROOT}")
def _run_git_command(args: list[str], cwd: Optional[str] = None) -> tuple[int, str, str]:
    """Run a git command with proper working directory.

    Args:
        args: Git command arguments (e.g., ["log", "--oneline", "-5"])
        cwd: Working directory (defaults to REPO_ROOT)

    Returns:
        Tuple of (returncode, stdout, stderr)
    """
    cmd = ["git"] + args
    working_dir = cwd or REPO_ROOT

    try:
        result = subprocess.run(
            cmd,
            cwd=working_dir,
            capture_output=True,
            text=True,
            timeout=30,
        )
        return result.returncode, result.stdout, result.stderr
    except subprocess.TimeoutExpired:
        return -1, "", "Command timed out after 30 seconds"
    except Exception as exc:
        return -1, "", str(exc)
# ── Repository management ────────────────────────────────────────────────────

def git_clone(url: str, dest: str | Path) -> dict:
    """Clone a remote repository to a local path."""
    returncode, stdout, stderr = _run_git_command(
        ["clone", url, str(dest)],
        cwd=os.getcwd(),  # clone relative to the caller's directory, not REPO_ROOT
    )

    if returncode != 0:
        return {"success": False, "error": stderr}

    return {
        "success": True,
        "path": str(dest),
        "message": f"Cloned {url} to {dest}",
    }


def git_init(path: str | Path) -> dict:
    """Initialise a new git repository at *path*."""
    os.makedirs(path, exist_ok=True)
    returncode, stdout, stderr = _run_git_command(["init"], cwd=str(path))

    if returncode != 0:
        return {"success": False, "error": stderr}

    return {"success": True, "path": str(path)}
# ── Status / inspection ──────────────────────────────────────────────────────

def git_status(repo_path: Optional[str] = None) -> dict:
    """Return working-tree status: modified, staged, untracked files."""
    cwd = repo_path or REPO_ROOT
    returncode, stdout, stderr = _run_git_command(
        ["status", "--porcelain", "-b"], cwd=cwd
    )

    if returncode != 0:
        return {"success": False, "error": stderr}

    # Parse porcelain output
    lines = stdout.strip().split("\n") if stdout else []
    branch = "unknown"
    modified = []
    staged = []
    untracked = []

    for line in lines:
        if line.startswith("## "):
            branch = line[3:].split("...")[0].strip()
        elif len(line) >= 2:
            index_status = line[0]
            worktree_status = line[1]
            filename = line[3:].strip() if len(line) > 3 else ""

            if index_status in "MADRC":
                staged.append(filename)
            if worktree_status in "MD":
                modified.append(filename)
            if worktree_status == "?":
                untracked.append(filename)

    return {
        "success": True,
        "branch": branch,
        "is_dirty": bool(modified or staged or untracked),
        "modified": modified,
        "staged": staged,
        "untracked": untracked,
    }
def git_diff(
    repo_path: Optional[str] = None,
    staged: bool = False,
    file_path: Optional[str] = None,
) -> dict:
    """Show diff of working tree or staged changes.

    If *file_path* is given, scope diff to that file only.
    """
    cwd = repo_path or REPO_ROOT
    args = ["diff"]
    if staged:
        args.append("--cached")
    if file_path:
        args.extend(["--", file_path])

    returncode, stdout, stderr = _run_git_command(args, cwd=cwd)

    if returncode != 0:
        return {"success": False, "error": stderr}

    return {"success": True, "diff": stdout, "staged": staged}
def git_log(
    repo_path: Optional[str] = None,
    max_count: int = 20,
    branch: Optional[str] = None,
) -> dict:
    """Return recent commit history as a list of dicts."""
    cwd = repo_path or REPO_ROOT
    args = ["log", f"--max-count={max_count}", "--format=%H|%h|%s|%an|%ai"]
    if branch:
        args.append(branch)

    returncode, stdout, stderr = _run_git_command(args, cwd=cwd)

    if returncode != 0:
        return {"success": False, "error": stderr}

    commits = []
    for line in stdout.strip().split("\n"):
        if not line:
            continue
        parts = line.split("|", 4)
        if len(parts) >= 5:
            commits.append({
                "sha": parts[0],
                "short_sha": parts[1],
                "message": parts[2],
                "author": parts[3],
                "date": parts[4],
            })

    # Get current branch
    _, branch_out, _ = _run_git_command(["branch", "--show-current"], cwd=cwd)
    current_branch = branch_out.strip() or "main"

    return {
        "success": True,
        "branch": branch or current_branch,
        "commits": commits,
    }
def git_blame(repo_path: Optional[str] = None, file_path: str = "") -> dict:
    """Show line-by-line authorship for a file."""
    if not file_path:
        return {"success": False, "error": "file_path is required"}

    cwd = repo_path or REPO_ROOT
    returncode, stdout, stderr = _run_git_command(
        ["blame", "--porcelain", file_path], cwd=cwd
    )

    if returncode != 0:
        return {"success": False, "error": stderr}

    return {"success": True, "file": file_path, "blame": stdout}
# ── Branching ─────────────────────────────────────────────────────────────────

def git_branch(
    repo_path: Optional[str] = None,
    create: Optional[str] = None,
    switch: Optional[str] = None,
) -> dict:
    """List branches, optionally create or switch to one."""
    cwd = repo_path or REPO_ROOT

    if create:
        returncode, _, stderr = _run_git_command(
            ["branch", create], cwd=cwd
        )
        if returncode != 0:
            return {"success": False, "error": stderr}

    if switch:
        returncode, _, stderr = _run_git_command(
            ["checkout", switch], cwd=cwd
        )
        if returncode != 0:
            return {"success": False, "error": stderr}

    # List branches
    returncode, stdout, stderr = _run_git_command(
        ["branch", "-a", "--format=%(refname:short)%(if)%(HEAD)%(then)*%(end)"],
        cwd=cwd
    )

    if returncode != 0:
        return {"success": False, "error": stderr}

    branches = []
    active = ""
    for line in stdout.strip().split("\n"):
        line = line.strip()
        if line.endswith("*"):
            active = line[:-1]
            branches.append(active)
        elif line:
            branches.append(line)

    return {
        "success": True,
        "branches": branches,
@@ -149,26 +277,47 @@ def git_branch(

# ── Staging & committing ─────────────────────────────────────────────────────

def git_add(repo_path: Optional[str] = None, paths: Optional[list[str]] = None) -> dict:
    """Stage files for commit. *paths* defaults to all modified files."""
    cwd = repo_path or REPO_ROOT

    if paths:
        args = ["add"] + paths
    else:
        # Stage all changes
        args = ["add", "-A"]

    returncode, _, stderr = _run_git_command(args, cwd=cwd)

    if returncode != 0:
        return {"success": False, "error": stderr}

    return {"success": True, "staged": paths or ["all"]}
def git_commit(
    repo_path: Optional[str] = None,
    message: str = "",
) -> dict:
    """Create a commit with the given message."""
    if not message:
        return {"success": False, "error": "commit message is required"}

    cwd = repo_path or REPO_ROOT
    returncode, stdout, stderr = _run_git_command(
        ["commit", "-m", message], cwd=cwd
    )

    if returncode != 0:
        return {"success": False, "error": stderr}

    # Get the commit hash
    _, hash_out, _ = _run_git_command(["rev-parse", "HEAD"], cwd=cwd)
    commit_hash = hash_out.strip()

    return {
        "success": True,
        "sha": commit_hash,
        "short_sha": commit_hash[:8],
        "message": message,
    }
@@ -176,47 +325,68 @@ def git_commit(repo_path: str | Path, message: str) -> dict:
|
||||
# ── Remote operations ─────────────────────────────────────────────────────────
|
||||
|
||||
def git_push(
|
||||
repo_path: str | Path,
|
||||
repo_path: Optional[str] = None,
|
||||
remote: str = "origin",
|
||||
branch: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Push the current (or specified) branch to the remote."""
|
||||
repo = _open_repo(repo_path)
|
||||
ref = branch or repo.active_branch.name
|
||||
info = repo.remotes[remote].push(ref)
|
||||
summaries = [str(i.summary) for i in info]
|
||||
return {"success": True, "remote": remote, "branch": ref, "summaries": summaries}
|
||||
cwd = repo_path or REPO_ROOT
|
||||
args = ["push", remote]
|
||||
if branch:
|
||||
args.append(branch)
|
||||
|
||||
returncode, stdout, stderr = _run_git_command(args, cwd=cwd)
|
||||
|
||||
if returncode != 0:
|
||||
return {"success": False, "error": stderr}
|
||||
|
||||
return {"success": True, "remote": remote, "branch": branch or "current"}
|
||||
|
||||
|
||||
def git_pull(
    repo_path: Optional[str] = None,
    remote: str = "origin",
    branch: Optional[str] = None,
) -> dict:
    """Pull from the remote into the working tree."""
    cwd = repo_path or REPO_ROOT
    args = ["pull", remote]
    if branch:
        args.append(branch)

    returncode, stdout, stderr = _run_git_command(args, cwd=cwd)

    if returncode != 0:
        return {"success": False, "error": stderr}

    return {"success": True, "remote": remote, "branch": branch or "current"}

# ── Stashing ──────────────────────────────────────────────────────────────────

def git_stash(
    repo_path: Optional[str] = None,
    pop: bool = False,
    message: Optional[str] = None,
) -> dict:
    """Stash or pop working-tree changes."""
    cwd = repo_path or REPO_ROOT

    if pop:
        returncode, _, stderr = _run_git_command(["stash", "pop"], cwd=cwd)
        if returncode != 0:
            return {"success": False, "error": stderr}
        return {"success": True, "action": "pop"}

    args = ["stash", "push"]
    if message:
        args.extend(["-m", message])

    returncode, _, stderr = _run_git_command(args, cwd=cwd)

    if returncode != 0:
        return {"success": False, "error": stderr}

    return {"success": True, "action": "stash", "message": message}

@@ -0,0 +1,37 @@
"""Infrastructure models package."""

from infrastructure.models.registry import (
    CustomModel,
    ModelFormat,
    ModelRegistry,
    ModelRole,
    model_registry,
)
from infrastructure.models.multimodal import (
    ModelCapability,
    ModelInfo,
    MultiModalManager,
    get_model_for_capability,
    get_multimodal_manager,
    model_supports_tools,
    model_supports_vision,
    pull_model_with_fallback,
)

__all__ = [
    # Registry
    "CustomModel",
    "ModelFormat",
    "ModelRegistry",
    "ModelRole",
    "model_registry",
    # Multi-modal
    "ModelCapability",
    "ModelInfo",
    "MultiModalManager",
    "get_model_for_capability",
    "get_multimodal_manager",
    "model_supports_tools",
    "model_supports_vision",
    "pull_model_with_fallback",
]
445  src/infrastructure/models/multimodal.py  Normal file
@@ -0,0 +1,445 @@
"""Multi-modal model support with automatic capability detection and fallbacks.

Provides:
- Model capability detection (vision, audio, etc.)
- Automatic model pulling with fallback chains
- Content-type aware model selection
- Graceful degradation when primary models are unavailable

No cloud by default — tries local first, falls back through configured options.
"""

import logging
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Optional

from config import settings

logger = logging.getLogger(__name__)


class ModelCapability(Enum):
    """Capabilities a model can have."""

    TEXT = auto()       # Standard text completion
    VISION = auto()     # Image understanding
    AUDIO = auto()      # Audio/speech processing
    TOOLS = auto()      # Function calling / tool use
    JSON = auto()       # Structured output / JSON mode
    STREAMING = auto()  # Streaming responses


# Known model capabilities (local Ollama models).
# These are used when we can't query the model directly.
KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
    # Llama 3.x series
    "llama3.1": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "llama3.1:8b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "llama3.1:8b-instruct": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "llama3.1:70b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "llama3.1:405b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    # llama3.2 1B/3B are text-only; vision is the llama3.2-vision family
    "llama3.2": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "llama3.2:1b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "llama3.2:3b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "llama3.2-vision": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
    "llama3.2-vision:11b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},

    # Qwen series
    "qwen2.5": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "qwen2.5:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "qwen2.5:14b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "qwen2.5:32b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "qwen2.5:72b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "qwen2.5-vl": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
    "qwen2.5-vl:3b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
    "qwen2.5-vl:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},

    # DeepSeek series
    "deepseek-r1": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "deepseek-r1:1.5b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "deepseek-r1:7b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "deepseek-r1:14b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "deepseek-r1:32b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "deepseek-r1:70b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "deepseek-v3": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},

    # Gemma series
    "gemma2": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "gemma2:2b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "gemma2:9b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "gemma2:27b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},

    # Mistral series
    "mistral": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "mistral:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "mistral-nemo": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "mistral-small": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "mistral-large": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},

    # Vision-specific models
    "llava": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
    "llava:7b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
    "llava:13b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
    "llava:34b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
    "llava-phi3": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
    "llava-llama3": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
    "bakllava": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
    "moondream": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
    "moondream:1.8b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},

    # Phi series
    "phi3": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "phi3:3.8b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "phi3:14b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
    "phi4": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},

    # Command R
    "command-r": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "command-r:35b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "command-r-plus": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},

    # Granite (IBM)
    "granite3-dense": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
    "granite3-moe": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
}


# Default fallback chains for each capability.
# These are tried in order when the primary model doesn't support a capability.
DEFAULT_FALLBACK_CHAINS: dict[ModelCapability, list[str]] = {
    ModelCapability.VISION: [
        "llama3.2-vision:11b",  # Llama vision model
        "llava:7b",             # Classic vision model
        "qwen2.5-vl:3b",        # Qwen vision
        "moondream:1.8b",       # Tiny vision model (last resort)
    ],
    ModelCapability.TOOLS: [
        "llama3.1:8b-instruct",  # Best tool use
        "llama3.2:3b",           # Smaller but capable
        "qwen2.5:7b",            # Reliable fallback
    ],
    ModelCapability.AUDIO: [
        # Audio models are uncommon in Ollama; dedicated
        # audio-capable models would need to be listed here.
    ],
}


@dataclass
class ModelInfo:
    """Information about a model's capabilities and availability."""

    name: str
    capabilities: set[ModelCapability] = field(default_factory=set)
    is_available: bool = False
    is_pulled: bool = False
    size_mb: Optional[int] = None
    description: str = ""

    def supports(self, capability: ModelCapability) -> bool:
        """Check if the model supports a specific capability."""
        return capability in self.capabilities


class MultiModalManager:
    """Manages multi-modal model capabilities and fallback chains.

    This class:
    1. Detects what capabilities each model has
    2. Maintains fallback chains for different capabilities
    3. Pulls models on demand with automatic fallback
    4. Routes requests to appropriate models based on content type
    """

    def __init__(self, ollama_url: Optional[str] = None) -> None:
        self.ollama_url = ollama_url or settings.ollama_url
        self._available_models: dict[str, ModelInfo] = {}
        self._fallback_chains: dict[ModelCapability, list[str]] = dict(DEFAULT_FALLBACK_CHAINS)
        self._refresh_available_models()

    def _refresh_available_models(self) -> None:
        """Query Ollama for available models."""
        try:
            import json
            import urllib.request

            url = self.ollama_url.replace("localhost", "127.0.0.1")
            req = urllib.request.Request(
                f"{url}/api/tags",
                method="GET",
                headers={"Accept": "application/json"},
            )
            with urllib.request.urlopen(req, timeout=5) as response:
                data = json.loads(response.read().decode())

            for model_data in data.get("models", []):
                name = model_data.get("name", "")
                self._available_models[name] = ModelInfo(
                    name=name,
                    capabilities=self._detect_capabilities(name),
                    is_available=True,
                    is_pulled=True,
                    size_mb=model_data.get("size", 0) // (1024 * 1024),
                    description=model_data.get("details", {}).get("family", ""),
                )

            logger.info("Found %d models in Ollama", len(self._available_models))

        except Exception as exc:
            logger.warning("Could not refresh available models: %s", exc)

    def _detect_capabilities(self, model_name: str) -> set[ModelCapability]:
        """Detect capabilities for a model based on known data."""
        # Normalize the model name (strip the tag for lookup)
        base_name = model_name.split(":")[0]

        # Try an exact match first
        if model_name in KNOWN_MODEL_CAPABILITIES:
            return set(KNOWN_MODEL_CAPABILITIES[model_name])

        # Try a base-name match
        if base_name in KNOWN_MODEL_CAPABILITIES:
            return set(KNOWN_MODEL_CAPABILITIES[base_name])

        # Default to text-only for unknown models
        logger.debug("Unknown model %s, defaulting to TEXT only", model_name)
        return {ModelCapability.TEXT, ModelCapability.STREAMING}

    def get_model_capabilities(self, model_name: str) -> set[ModelCapability]:
        """Get capabilities for a specific model."""
        if model_name in self._available_models:
            return self._available_models[model_name].capabilities
        return self._detect_capabilities(model_name)

    def model_supports(self, model_name: str, capability: ModelCapability) -> bool:
        """Check if a model supports a specific capability."""
        capabilities = self.get_model_capabilities(model_name)
        return capability in capabilities

    def get_models_with_capability(self, capability: ModelCapability) -> list[ModelInfo]:
        """Get all available models that support a capability."""
        return [
            info for info in self._available_models.values()
            if capability in info.capabilities
        ]

    def get_best_model_for(
        self,
        capability: ModelCapability,
        preferred_model: Optional[str] = None,
    ) -> Optional[str]:
        """Get the best available model for a specific capability.

        Args:
            capability: The required capability
            preferred_model: Preferred model to use if available and capable

        Returns:
            Model name, or None if no suitable model is found
        """
        # Check if the preferred model supports this capability
        if preferred_model:
            if preferred_model in self._available_models:
                if self.model_supports(preferred_model, capability):
                    return preferred_model
            logger.debug(
                "Preferred model %s doesn't support %s, checking fallbacks",
                preferred_model, capability.name
            )

        # Check the fallback chain for this capability
        fallback_chain = self._fallback_chains.get(capability, [])
        for model_name in fallback_chain:
            if model_name in self._available_models:
                logger.debug("Using fallback model %s for %s", model_name, capability.name)
                return model_name

        # Find any available model with this capability
        capable_models = self.get_models_with_capability(capability)
        if capable_models:
            # Sort by size (prefer smaller/faster models as fallback)
            capable_models.sort(key=lambda m: m.size_mb or float("inf"))
            return capable_models[0].name

        return None

    def pull_model_with_fallback(
        self,
        primary_model: str,
        capability: Optional[ModelCapability] = None,
        auto_pull: bool = True,
    ) -> tuple[str, bool]:
        """Pull a model with automatic fallback if unavailable.

        Args:
            primary_model: The desired model to use
            capability: Required capability (for finding a fallback)
            auto_pull: Whether to attempt pulling missing models

        Returns:
            Tuple of (model_name, is_fallback)
        """
        # Check if the primary model is already available
        if primary_model in self._available_models:
            return primary_model, False

        # Try to pull the primary model
        if auto_pull:
            if self._pull_model(primary_model):
                return primary_model, False

        # Need to find a fallback
        if capability:
            fallback = self.get_best_model_for(capability, primary_model)
            if fallback:
                logger.info(
                    "Primary model %s unavailable, using fallback %s",
                    primary_model, fallback
                )
                return fallback, True

        # Last resort: use the configured default model
        default_model = settings.ollama_model
        if default_model in self._available_models:
            logger.warning(
                "Falling back to default model %s (primary: %s unavailable)",
                default_model, primary_model
            )
            return default_model, True

        # Absolute last resort
        return primary_model, False

    def _pull_model(self, model_name: str) -> bool:
        """Attempt to pull a model from Ollama.

        Returns:
            True if successful or the model already exists
        """
        try:
            import json
            import urllib.request

            logger.info("Pulling model: %s", model_name)

            url = self.ollama_url.replace("localhost", "127.0.0.1")
            req = urllib.request.Request(
                f"{url}/api/pull",
                method="POST",
                headers={"Content-Type": "application/json"},
                data=json.dumps({"name": model_name, "stream": False}).encode(),
            )

            with urllib.request.urlopen(req, timeout=300) as response:
                if response.status == 200:
                    logger.info("Successfully pulled model: %s", model_name)
                    # Refresh available models
                    self._refresh_available_models()
                    return True
                logger.error("Failed to pull %s: HTTP %s", model_name, response.status)
                return False

        except Exception as exc:
            logger.error("Error pulling model %s: %s", model_name, exc)
            return False

    def configure_fallback_chain(
        self,
        capability: ModelCapability,
        models: list[str],
    ) -> None:
        """Configure a custom fallback chain for a capability."""
        self._fallback_chains[capability] = models
        logger.info("Configured fallback chain for %s: %s", capability.name, models)

    def get_fallback_chain(self, capability: ModelCapability) -> list[str]:
        """Get the fallback chain for a capability."""
        return list(self._fallback_chains.get(capability, []))

    def list_available_models(self) -> list[ModelInfo]:
        """List all available models with their capabilities."""
        return list(self._available_models.values())

    def refresh(self) -> None:
        """Refresh the list of available models."""
        self._refresh_available_models()

    def get_model_for_content(
        self,
        content_type: str,  # "text", "image", "audio", "multimodal"
        preferred_model: Optional[str] = None,
    ) -> tuple[str, bool]:
        """Get an appropriate model based on content type.

        Args:
            content_type: Type of content (text, image, audio, multimodal)
            preferred_model: User's preferred model

        Returns:
            Tuple of (model_name, is_fallback)
        """
        content_type = content_type.lower()

        if content_type in ("image", "vision", "multimodal"):
            # Vision content needs a vision-capable model
            return self.pull_model_with_fallback(
                preferred_model or "llava:7b",
                capability=ModelCapability.VISION,
            )

        if content_type == "audio":
            # Audio support is limited in Ollama and would need dedicated models
            logger.warning("Audio support is limited, falling back to text model")
            return self.pull_model_with_fallback(
                preferred_model or settings.ollama_model,
                capability=ModelCapability.TEXT,
            )

        # Standard text content
        return self.pull_model_with_fallback(
            preferred_model or settings.ollama_model,
            capability=ModelCapability.TEXT,
        )


# Module-level singleton
_multimodal_manager: Optional[MultiModalManager] = None


def get_multimodal_manager() -> MultiModalManager:
    """Get or create the multi-modal manager singleton."""
    global _multimodal_manager
    if _multimodal_manager is None:
        _multimodal_manager = MultiModalManager()
    return _multimodal_manager


def get_model_for_capability(
    capability: ModelCapability,
    preferred_model: Optional[str] = None,
) -> Optional[str]:
    """Convenience function to get the best model for a capability."""
    return get_multimodal_manager().get_best_model_for(capability, preferred_model)


def pull_model_with_fallback(
    primary_model: str,
    capability: Optional[ModelCapability] = None,
    auto_pull: bool = True,
) -> tuple[str, bool]:
    """Convenience function to pull a model with fallback."""
    return get_multimodal_manager().pull_model_with_fallback(
        primary_model, capability, auto_pull
    )


def model_supports_vision(model_name: str) -> bool:
    """Check if a model supports vision."""
    return get_multimodal_manager().model_supports(model_name, ModelCapability.VISION)


def model_supports_tools(model_name: str) -> bool:
    """Check if a model supports tool calling."""
    return get_multimodal_manager().model_supports(model_name, ModelCapability.TOOLS)
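The lookup order in `_detect_capabilities` above (exact name first, then the base name with the `:tag` stripped, then a text-only default) can be exercised in isolation. A self-contained sketch with a toy capability table — the real `KNOWN_MODEL_CAPABILITIES` is far larger and uses the `ModelCapability` enum rather than strings:

```python
# Toy capability table for illustration only.
KNOWN = {
    "llama3.1": {"text", "tools", "json", "streaming"},
    "llava": {"text", "vision", "streaming"},
}


def detect_capabilities(model_name: str) -> set[str]:
    """Mirror of _detect_capabilities: exact match first, then the base
    name with the ':tag' stripped, then a text-only default."""
    if model_name in KNOWN:
        return set(KNOWN[model_name])
    base_name = model_name.split(":")[0]
    if base_name in KNOWN:
        return set(KNOWN[base_name])
    return {"text", "streaming"}
```

The base-name fallback is what lets any size tag of a family (e.g. `llava:13b`) inherit the family's capabilities without listing every tag.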
@@ -3,14 +3,19 @@

Routes requests through an ordered list of LLM providers,
automatically failing over on rate limits or errors.
Tracks metrics for latency, errors, and cost.

Now with multi-modal support — automatically selects vision-capable
models for image inputs and falls back through capability chains.
"""

import asyncio
import base64
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from enum import Enum
from pathlib import Path
from typing import Any, Optional

@@ -43,6 +48,14 @@ class CircuitState(Enum):
    HALF_OPEN = "half_open"  # Testing if recovered


class ContentType(Enum):
    """Type of content in the request."""

    TEXT = "text"
    VISION = "vision"          # Contains images
    AUDIO = "audio"            # Contains audio
    MULTIMODAL = "multimodal"  # Multiple content types


@dataclass
class ProviderMetrics:
    """Metrics for a single provider."""
@@ -67,6 +80,18 @@ class ProviderMetrics:
        return self.failed_requests / self.total_requests


@dataclass
class ModelCapability:
    """Capabilities a model supports."""

    name: str
    supports_vision: bool = False
    supports_audio: bool = False
    supports_tools: bool = False
    supports_json: bool = False
    supports_streaming: bool = True
    context_window: int = 4096


@dataclass
class Provider:
    """LLM provider configuration and state."""
@@ -94,6 +119,23 @@ class Provider:
        if self.models:
            return self.models[0]["name"]
        return None

    def get_model_with_capability(self, capability: str) -> Optional[str]:
        """Get a model that supports the given capability."""
        for model in self.models:
            capabilities = model.get("capabilities", [])
            if capability in capabilities:
                return model["name"]
        # Fall back to the default
        return self.get_default_model()

    def model_has_capability(self, model_name: str, capability: str) -> bool:
        """Check if a specific model has a capability."""
        for model in self.models:
            if model["name"] == model_name:
                capabilities = model.get("capabilities", [])
                return capability in capabilities
        return False

@dataclass
@@ -107,19 +149,39 @@ class RouterConfig:
    circuit_breaker_half_open_max_calls: int = 2
    cost_tracking_enabled: bool = True
    budget_daily_usd: float = 10.0
    # Multi-modal settings
    auto_pull_models: bool = True
    fallback_chains: dict = field(default_factory=dict)


class CascadeRouter:
    """Routes LLM requests with automatic failover.

    Now with multi-modal support:
    - Automatically detects content type (text, vision, audio)
    - Selects appropriate models based on capabilities
    - Falls back through capability-specific model chains
    - Supports image URLs and base64 encoding

    Usage:
        router = CascadeRouter()

        # Text request
        response = await router.complete(
            messages=[{"role": "user", "content": "Hello"}],
            model="llama3.2"
        )

        # Vision request (automatically detects and selects a vision model)
        response = await router.complete(
            messages=[{
                "role": "user",
                "content": "What's in this image?",
                "images": ["path/to/image.jpg"]
            }],
            model="llava:7b"
        )

        # Check metrics
        metrics = router.get_metrics()
    """
@@ -130,6 +192,14 @@ class CascadeRouter:
        self.config: RouterConfig = RouterConfig()
        self._load_config()

        # Initialize the multi-modal manager if available
        self._mm_manager: Optional[Any] = None
        try:
            from infrastructure.models.multimodal import get_multimodal_manager
            self._mm_manager = get_multimodal_manager()
        except Exception as exc:
            logger.debug("Multi-modal manager not available: %s", exc)

        logger.info("CascadeRouter initialized with %d providers", len(self.providers))

    def _load_config(self) -> None:
@@ -149,6 +219,13 @@ class CascadeRouter:

        # Load cascade settings
        cascade = data.get("cascade", {})

        # Load fallback chains
        fallback_chains = data.get("fallback_chains", {})

        # Load multi-modal settings
        multimodal = data.get("multimodal", {})

        self.config = RouterConfig(
            timeout_seconds=cascade.get("timeout_seconds", 30),
            max_retries_per_provider=cascade.get("max_retries_per_provider", 2),
@@ -156,6 +233,8 @@ class CascadeRouter:
            circuit_breaker_failure_threshold=cascade.get("circuit_breaker", {}).get("failure_threshold", 5),
            circuit_breaker_recovery_timeout=cascade.get("circuit_breaker", {}).get("recovery_timeout", 60),
            circuit_breaker_half_open_max_calls=cascade.get("circuit_breaker", {}).get("half_open_max_calls", 2),
            auto_pull_models=multimodal.get("auto_pull", True),
            fallback_chains=fallback_chains,
        )

        # Load providers
@@ -226,6 +305,81 @@ class CascadeRouter:

        return True

    def _detect_content_type(self, messages: list[dict]) -> ContentType:
        """Detect the type of content in the messages.

        Checks for images, audio, etc. in the message content.
        """
        has_image = False
        has_audio = False

        for msg in messages:
            content = msg.get("content", "")

            # Check for explicit image attachments
            if msg.get("images"):
                has_image = True

            # Check for image URLs/paths or data URIs in string content
            if isinstance(content, str):
                image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
                if any(ext in content.lower() for ext in image_extensions):
                    has_image = True
                if content.startswith("data:image/"):
                    has_image = True

            # Check for audio attachments
            if msg.get("audio"):
                has_audio = True

            # Check for structured multi-modal content
            if isinstance(content, list):
                for item in content:
                    if isinstance(item, dict):
                        if item.get("type") == "image_url":
                            has_image = True
                        elif item.get("type") == "audio":
                            has_audio = True

        if has_image and has_audio:
            return ContentType.MULTIMODAL
        elif has_image:
            return ContentType.VISION
        elif has_audio:
            return ContentType.AUDIO
        return ContentType.TEXT

    def _get_fallback_model(
        self,
        provider: Provider,
        original_model: str,
        content_type: ContentType,
    ) -> Optional[str]:
        """Get a fallback model for the given content type."""
        # Map content type to capability
        capability_map = {
            ContentType.VISION: "vision",
            ContentType.AUDIO: "audio",
            ContentType.MULTIMODAL: "vision",  # Vision models often handle both
        }

        capability = capability_map.get(content_type)
        if not capability:
            return None

        # Check the provider's models for the capability
        fallback_model = provider.get_model_with_capability(capability)
        if fallback_model and fallback_model != original_model:
            return fallback_model

        # Use fallback chains from config
        fallback_chain = self.config.fallback_chains.get(capability, [])
        for model_name in fallback_chain:
            if provider.model_has_capability(model_name, capability):
                return model_name

        return None

    async def complete(
        self,
        messages: list[dict],
@@ -235,6 +389,11 @@ class CascadeRouter:
    ) -> dict:
        """Complete a chat conversation with automatic failover.

        Multi-modal support:
        - Automatically detects if messages contain images
        - Falls back to vision-capable models when needed
        - Supports image URLs, paths, and base64 encoding

        Args:
            messages: List of message dicts with role and content
            model: Preferred model (tries this first, then provider defaults)
@@ -247,6 +406,11 @@ class CascadeRouter:
        Raises:
            RuntimeError: If all providers fail
        """
        # Detect content type for multi-modal routing
        content_type = self._detect_content_type(messages)
        if content_type != ContentType.TEXT:
            logger.debug("Detected %s content, selecting appropriate model", content_type.value)

        errors = []

        for provider in self.providers:
@@ -266,15 +430,48 @@ class CascadeRouter:
                logger.debug("Skipping %s (circuit open)", provider.name)
                continue

            # Determine which model to use
            selected_model = model or provider.get_default_model()
            is_fallback_model = False

            # For non-text content, check whether the model supports it
            if content_type != ContentType.TEXT and selected_model:
                if provider.type == "ollama" and self._mm_manager:
                    # NB: shadows this module's ModelCapability dataclass within this scope
                    from infrastructure.models.multimodal import ModelCapability

                    # Check if the selected model supports the required capability
                    if content_type == ContentType.VISION:
                        supports = self._mm_manager.model_supports(
                            selected_model, ModelCapability.VISION
                        )
                        if not supports:
                            # Find a fallback model
                            fallback = self._get_fallback_model(
                                provider, selected_model, content_type
                            )
                            if fallback:
                                logger.info(
                                    "Model %s doesn't support vision, falling back to %s",
                                    selected_model, fallback
                                )
                                selected_model = fallback
                                is_fallback_model = True
                            else:
                                logger.warning(
                                    "No vision-capable model found on %s, trying anyway",
                                    provider.name
                                )

            # Try this provider
            for attempt in range(self.config.max_retries_per_provider):
                try:
                    result = await self._try_provider(
                        provider=provider,
|
||||
messages=messages,
|
||||
model=model,
|
||||
model=selected_model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
content_type=content_type,
|
||||
)
|
||||
|
||||
# Success! Update metrics and return
|
||||
@@ -282,8 +479,9 @@ class CascadeRouter:
|
||||
return {
|
||||
"content": result["content"],
|
||||
"provider": provider.name,
|
||||
"model": result.get("model", model or provider.get_default_model()),
|
||||
"model": result.get("model", selected_model or provider.get_default_model()),
|
||||
"latency_ms": result.get("latency_ms", 0),
|
||||
"is_fallback_model": is_fallback_model,
|
||||
}
|
||||
|
||||
except Exception as exc:
@@ -307,9 +505,10 @@ class CascadeRouter:
        self,
        provider: Provider,
        messages: list[dict],
-       model: Optional[str],
+       model: str,
        temperature: float,
        max_tokens: Optional[int],
        content_type: ContentType = ContentType.TEXT,
    ) -> dict:
        """Try a single provider request."""
        start_time = time.time()
@@ -320,6 +519,7 @@ class CascadeRouter:
                messages=messages,
                model=model or provider.get_default_model(),
                temperature=temperature,
                content_type=content_type,
            )
        elif provider.type == "openai":
            result = await self._call_openai(
@@ -359,15 +559,19 @@ class CascadeRouter:
        messages: list[dict],
        model: str,
        temperature: float,
        content_type: ContentType = ContentType.TEXT,
    ) -> dict:
-       """Call Ollama API."""
+       """Call Ollama API with multi-modal support."""
        import aiohttp

        url = f"{provider.url}/api/chat"

        # Transform messages for Ollama format (including images)
        transformed_messages = self._transform_messages_for_ollama(messages)

        payload = {
            "model": model,
-           "messages": messages,
+           "messages": transformed_messages,
            "stream": False,
            "options": {
                "temperature": temperature,
@@ -388,6 +592,41 @@ class CascadeRouter:
            "model": model,
        }

    def _transform_messages_for_ollama(self, messages: list[dict]) -> list[dict]:
        """Transform messages to Ollama format, handling images."""
        transformed = []

        for msg in messages:
            new_msg = {
                "role": msg.get("role", "user"),
                "content": msg.get("content", ""),
            }

            # Handle images
            images = msg.get("images", [])
            if images:
                new_msg["images"] = []
                for img in images:
                    if isinstance(img, str):
                        if img.startswith("data:image/"):
                            # Base64 encoded image
                            new_msg["images"].append(img.split(",")[1])
                        elif img.startswith("http://") or img.startswith("https://"):
                            # URL - would need to download, skip for now
                            logger.warning("Image URLs not yet supported, skipping: %s", img)
                        elif Path(img).exists():
                            # Local file path - read and encode
                            try:
                                with open(img, "rb") as f:
                                    img_data = base64.b64encode(f.read()).decode()
                                new_msg["images"].append(img_data)
                            except Exception as exc:
                                logger.error("Failed to read image %s: %s", img, exc)

            transformed.append(new_msg)

        return transformed
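The per-image branching above reduces to a small pure function. A sketch under the same rules (data URLs keep only the payload after the comma, local files are read and base64-encoded, remote URLs are skipped); `normalize_image` is a hypothetical helper name, not part of the diff:

```python
import base64
from pathlib import Path
from typing import Optional


def normalize_image(img: str) -> Optional[str]:
    """Return a bare base64 payload for Ollama, or None if unusable."""
    if img.startswith("data:image/"):
        # "data:image/png;base64,iVBOR..." -> keep only the payload
        return img.split(",")[1]
    if img.startswith(("http://", "https://")):
        # Remote URLs are not fetched in this version
        return None
    path = Path(img)
    if path.exists():
        # Local file: read bytes and base64-encode them
        return base64.b64encode(path.read_bytes()).decode()
    return None
```

Ollama's `/api/chat` expects the `images` field to hold bare base64 strings, which is why the data-URL prefix is stripped rather than passed through.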
    async def _call_openai(
        self,
        provider: Provider,
@@ -496,7 +735,7 @@ class CascadeRouter:
            "content": response.choices[0].message.content,
            "model": response.model,
        }

    def _record_success(self, provider: Provider, latency_ms: float) -> None:
        """Record a successful request."""
        provider.metrics.total_requests += 1
@@ -598,6 +837,35 @@ class CascadeRouter:
                for p in self.providers
            ],
        }

    async def generate_with_image(
        self,
        prompt: str,
        image_path: str,
        model: Optional[str] = None,
        temperature: float = 0.7,
    ) -> dict:
        """Convenience method for vision requests.

        Args:
            prompt: Text prompt about the image
            image_path: Path to image file
            model: Vision-capable model (auto-selected if not provided)
            temperature: Sampling temperature

        Returns:
            Response dict with content and metadata
        """
        messages = [{
            "role": "user",
            "content": prompt,
            "images": [image_path],
        }]
        return await self.complete(
            messages=messages,
            model=model,
            temperature=temperature,
        )


# Module-level singleton
@@ -5,11 +5,16 @@ Memory Architecture:
- Tier 2 (Vault): memory/ — structured markdown, append-only
- Tier 3 (Semantic): Vector search over vault files

Model Management:
- Pulls requested model automatically if not available
- Falls back through capability-based model chains
- Multi-modal support with vision model fallbacks

Handoff Protocol maintains continuity across sessions.
"""

import logging
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Optional, Union

from agno.agent import Agent
from agno.db.sqlite import SqliteDb
@@ -24,6 +29,23 @@ if TYPE_CHECKING:

logger = logging.getLogger(__name__)

# Fallback chain for text/tool models (in order of preference)
DEFAULT_MODEL_FALLBACKS = [
    "llama3.1:8b-instruct",
    "llama3.1",
    "qwen2.5:14b",
    "qwen2.5:7b",
    "llama3.2:3b",
]

# Fallback chain for vision models
VISION_MODEL_FALLBACKS = [
    "llama3.2:3b",
    "llava:7b",
    "qwen2.5-vl:3b",
    "moondream:1.8b",
]

# Union type for callers that want to hint the return type.
TimmyAgent = Union[Agent, "TimmyAirLLMAgent", "GrokBackend"]

@@ -40,6 +62,120 @@ _SMALL_MODEL_PATTERNS (
)
def _check_model_available(model_name: str) -> bool:
    """Check if an Ollama model is available locally."""
    try:
        import urllib.request
        import json

        url = settings.ollama_url.replace("localhost", "127.0.0.1")
        req = urllib.request.Request(
            f"{url}/api/tags",
            method="GET",
            headers={"Accept": "application/json"},
        )
        with urllib.request.urlopen(req, timeout=5) as response:
            data = json.loads(response.read().decode())
            models = [m.get("name", "") for m in data.get("models", [])]
            # Check for exact match or model name without tag
            return any(
                model_name == m or model_name == m.split(":")[0] or m.startswith(model_name)
                for m in models
            )
    except Exception as exc:
        logger.debug("Could not check model availability: %s", exc)
        return False
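The membership test above is the interesting part: it accepts an exact name, a tag-less name, or any installed tag that starts with the requested name (so a bare `llama3.1` matches every `llama3.1:*` tag). Extracted as a pure function for testing; `model_matches` and `installed` are illustrative names, with `installed` standing in for the names returned by `/api/tags`:

```python
def model_matches(requested: str, installed: list) -> bool:
    """Replicate the availability check: exact name, tag-less name,
    or prefix match against 'ollama list' style names."""
    return any(
        requested == m or requested == m.split(":")[0] or m.startswith(requested)
        for m in installed
    )
```

The prefix clause is what lets `llama3.1:8b-instruct` match a quantized variant such as `llama3.1:8b-instruct-q4_K_M`.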
def _pull_model(model_name: str) -> bool:
    """Attempt to pull a model from Ollama.

    Returns:
        True if successful or model already exists
    """
    try:
        import urllib.request
        import json

        logger.info("Pulling model: %s", model_name)

        url = settings.ollama_url.replace("localhost", "127.0.0.1")
        req = urllib.request.Request(
            f"{url}/api/pull",
            method="POST",
            headers={"Content-Type": "application/json"},
            data=json.dumps({"name": model_name, "stream": False}).encode(),
        )

        with urllib.request.urlopen(req, timeout=300) as response:
            if response.status == 200:
                logger.info("Successfully pulled model: %s", model_name)
                return True
            else:
                logger.error("Failed to pull %s: HTTP %s", model_name, response.status)
                return False

    except Exception as exc:
        logger.error("Error pulling model %s: %s", model_name, exc)
        return False
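The pull is a single blocking POST; with `"stream": False` Ollama responds once, after the download finishes, hence the generous 300-second timeout. A sketch of just the request construction, with no network involved; `build_pull_request` is a hypothetical helper name:

```python
import json
import urllib.request


def build_pull_request(base_url: str, model: str) -> urllib.request.Request:
    """Build the blocking /api/pull request used above."""
    return urllib.request.Request(
        f"{base_url}/api/pull",
        method="POST",
        headers={"Content-Type": "application/json"},
        # stream=False makes Ollama answer once the pull completes
        data=json.dumps({"name": model, "stream": False}).encode(),
    )
```

Note that `urllib.request.urlopen` raises `HTTPError` for non-200 responses, so in practice the error branch above is usually reached via the `except` clause rather than the status check.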
def _resolve_model_with_fallback(
    requested_model: Optional[str] = None,
    require_vision: bool = False,
    auto_pull: bool = True,
) -> tuple[str, bool]:
    """Resolve model with automatic pulling and fallback.

    Args:
        requested_model: Preferred model to use
        require_vision: Whether the model needs vision capabilities
        auto_pull: Whether to attempt pulling missing models

    Returns:
        Tuple of (model_name, is_fallback)
    """
    model = requested_model or settings.ollama_model

    # Check if requested model is available
    if _check_model_available(model):
        logger.debug("Using available model: %s", model)
        return model, False

    # Try to pull the requested model
    if auto_pull:
        logger.info("Model %s not available locally, attempting to pull...", model)
        if _pull_model(model):
            return model, False
        logger.warning("Failed to pull %s, checking fallbacks...", model)

    # Use appropriate fallback chain
    fallback_chain = VISION_MODEL_FALLBACKS if require_vision else DEFAULT_MODEL_FALLBACKS

    for fallback_model in fallback_chain:
        if _check_model_available(fallback_model):
            logger.warning(
                "Using fallback model %s (requested: %s)",
                fallback_model, model,
            )
            return fallback_model, True

        # Try to pull the fallback
        if auto_pull and _pull_model(fallback_model):
            logger.info(
                "Pulled and using fallback model %s (requested: %s)",
                fallback_model, model,
            )
            return fallback_model, True

    # Absolute last resort - return the requested model and hope for the best
    logger.error(
        "No models available in fallback chain. Requested: %s",
        model,
    )
    return model, False
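The resolution order is easier to test with the network calls injected as callables. A sketch of the same decision sequence (requested model locally, then pulled; each fallback locally, then pulled; finally the requested name as a last resort); `resolve_model` and its parameters are illustrative, not the committed signature:

```python
from typing import Callable, Tuple


def resolve_model(
    requested: str,
    fallbacks: list,
    is_available: Callable[[str], bool],
    try_pull: Callable[[str], bool],
) -> Tuple[str, bool]:
    """Return (model, is_fallback) following the chain order above."""
    # Requested model: use it if local, else try to pull it
    if is_available(requested) or try_pull(requested):
        return requested, False
    # Walk the fallback chain with the same local-then-pull order
    for candidate in fallbacks:
        if is_available(candidate) or try_pull(candidate):
            return candidate, True
    # Last resort: hand back the requested name unchanged
    return requested, False
```

Injecting the checkers keeps the ordering logic deterministic under test while the real function talks to `/api/tags` and `/api/pull`.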
def _model_supports_tools(model_name: str) -> bool:
    """Check if the configured model can reliably handle tool calling.

@@ -106,7 +242,16 @@ def create_timmy(
        return TimmyAirLLMAgent(model_size=size)

    # Default: Ollama via Agno.
-   model_name = settings.ollama_model
+   # Resolve model with automatic pulling and fallback
+   model_name, is_fallback = _resolve_model_with_fallback(
+       requested_model=None,
+       require_vision=False,
+       auto_pull=True,
+   )
+
+   if is_fallback:
+       logger.info("Using fallback model %s (requested was unavailable)", model_name)

    use_tools = _model_supports_tools(model_name)

    # Conditionally include tools — small models get none
@@ -31,7 +31,7 @@ from timmy.agent_core.interface import (
    TimAgent,
    AgentEffect,
)
-from timmy.agent import create_timmy
+from timmy.agent import create_timmy, _resolve_model_with_fallback


class OllamaAgent(TimAgent):
@@ -53,18 +53,33 @@ class OllamaAgent(TimAgent):
        identity: AgentIdentity,
        model: Optional[str] = None,
        effect_log: Optional[str] = None,
        require_vision: bool = False,
    ) -> None:
        """Initialize Ollama-based agent.

        Args:
            identity: Agent identity (persistent across sessions)
-           model: Ollama model to use (default from config)
+           model: Ollama model to use (auto-resolves with fallback)
            effect_log: Path to log agent effects (optional)
            require_vision: Whether to select a vision-capable model
        """
        super().__init__(identity)

        # Resolve model with automatic pulling and fallback
        resolved_model, is_fallback = _resolve_model_with_fallback(
            requested_model=model,
            require_vision=require_vision,
            auto_pull=True,
        )

        if is_fallback:
            import logging
            logging.getLogger(__name__).info(
                "OllamaAdapter using fallback model %s", resolved_model
            )

        # Initialize underlying Ollama agent
-       self._timmy = create_timmy(model=model)
+       self._timmy = create_timmy(model=resolved_model)

        # Set capabilities based on what Ollama can do
        self._capabilities = {
@@ -5,6 +5,8 @@ Uses the three-tier memory system and MCP tools.
"""

import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional

from agno.agent import Agent
@@ -17,8 +19,166 @@ from mcp.registry import tool_registry

logger = logging.getLogger(__name__)

# Dynamic context that gets built at startup
_timmy_context: dict[str, Any] = {
    "git_log": "",
    "agents": [],
    "hands": [],
    "memory": "",
}

-TIMMY_ORCHESTRATOR_PROMPT = """You are Timmy, a sovereign AI orchestrator running locally on this Mac.

async def _load_hands_async() -> list[dict]:
    """Async helper to load hands."""
    try:
        from hands.registry import HandRegistry
        reg = HandRegistry()
        hands_dict = await reg.load_all()
        return [
            {"name": h.name, "schedule": h.schedule.cron if h.schedule else "manual", "enabled": h.enabled}
            for h in hands_dict.values()
        ]
    except Exception as exc:
        logger.warning("Could not load hands for context: %s", exc)
        return []


def build_timmy_context_sync() -> dict[str, Any]:
    """Build Timmy's self-awareness context at startup (synchronous version).

    This function gathers:
    - Recent git commits (last 20)
    - Active sub-agents
    - Hot memory from MEMORY.md

    Note: Hands are loaded separately in async context.

    Returns a dict that can be formatted into the system prompt.
    """
    global _timmy_context

    ctx: dict[str, Any] = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "repo_root": settings.repo_root,
        "git_log": "",
        "agents": [],
        "hands": [],
        "memory": "",
    }

    # 1. Get recent git commits
    try:
        from tools.git_tools import git_log
        result = git_log(max_count=20)
        if result.get("success"):
            commits = result.get("commits", [])
            ctx["git_log"] = "\n".join([
                f"{c['short_sha']} {c['message'].split(chr(10))[0]}"
                for c in commits[:20]
            ])
    except Exception as exc:
        logger.warning("Could not load git log for context: %s", exc)
        ctx["git_log"] = "(Git log unavailable)"

    # 2. Get active sub-agents
    try:
        from swarm import registry as swarm_registry
        conn = swarm_registry._get_conn()
        rows = conn.execute(
            "SELECT id, name, status, capabilities FROM agents ORDER BY name"
        ).fetchall()
        ctx["agents"] = [
            {"id": r["id"], "name": r["name"], "status": r["status"], "capabilities": r["capabilities"]}
            for r in rows
        ]
        conn.close()
    except Exception as exc:
        logger.warning("Could not load agents for context: %s", exc)
        ctx["agents"] = []

    # 3. Read hot memory
    try:
        memory_path = Path(settings.repo_root) / "MEMORY.md"
        if memory_path.exists():
            ctx["memory"] = memory_path.read_text()[:2000]  # First 2000 chars
        else:
            ctx["memory"] = "(MEMORY.md not found)"
    except Exception as exc:
        logger.warning("Could not load memory for context: %s", exc)
        ctx["memory"] = "(Memory unavailable)"

    _timmy_context.update(ctx)
    logger.info("Timmy context built (sync): %d agents", len(ctx["agents"]))
    return ctx
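The `chr(10)` in the format expression is deliberate: backslash escapes were not allowed inside f-string expressions before Python 3.12, so `chr(10)` stands in for the newline character when taking the first line of each commit message. Isolated as a testable helper (`format_oneline` is an illustrative name):

```python
def format_oneline(commits: list) -> str:
    """Reproduce the git-log formatting above: short SHA plus the
    first line of each commit message, one commit per line."""
    return "\n".join(
        # chr(10) is the newline character; splitting on it keeps
        # only the commit subject line
        f"{c['short_sha']} {c['message'].split(chr(10))[0]}"
        for c in commits
    )
```

This is the same shape as `git log --oneline`, which is what the prompt later presents to the model.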
async def build_timmy_context_async() -> dict[str, Any]:
    """Build complete Timmy context including hands (async version)."""
    ctx = build_timmy_context_sync()
    ctx["hands"] = await _load_hands_async()
    _timmy_context.update(ctx)
    logger.info("Timmy context built (async): %d agents, %d hands", len(ctx["agents"]), len(ctx["hands"]))
    return ctx


# Keep old name for backwards compatibility
build_timmy_context = build_timmy_context_sync


def format_timmy_prompt(base_prompt: str, context: dict[str, Any]) -> str:
    """Format the system prompt with dynamic context."""

    # Format agents list
    agents_list = "\n".join([
        f"| {a['name']} | {a['capabilities'] or 'general'} | {a['status']} |"
        for a in context.get("agents", [])
    ]) or "(No agents registered yet)"

    # Format hands list
    hands_list = "\n".join([
        f"| {h['name']} | {h['schedule']} | {'enabled' if h['enabled'] else 'disabled'} |"
        for h in context.get("hands", [])
    ]) or "(No hands configured)"

    repo_root = context.get("repo_root", settings.repo_root)

    context_block = f"""
## Current System Context (as of {context.get('timestamp', datetime.now(timezone.utc).isoformat())})

### Repository
**Root:** `{repo_root}`

### Recent Commits (last 20):
```
{context.get('git_log', '(unavailable)')}
```

### Active Sub-Agents:
| Name | Capabilities | Status |
|------|--------------|--------|
{agents_list}

### Hands (Scheduled Tasks):
| Name | Schedule | Status |
|------|----------|--------|
{hands_list}

### Hot Memory:
{context.get('memory', '(unavailable)')[:1000]}
"""

    # Replace {REPO_ROOT} placeholder with actual path
    base_prompt = base_prompt.replace("{REPO_ROOT}", repo_root)

    # Insert context after the first line ("You are Timmy...")
    lines = base_prompt.split("\n")
    if lines:
        return lines[0] + "\n" + context_block + "\n" + "\n".join(lines[1:])
    return base_prompt
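The splice at the end keeps the identity line ("You are Timmy...") first and injects the dynamic context before the rest of the prompt; `_get_enhanced_system_prompt` later in this diff uses the same trick. As a standalone helper (`insert_after_first_line` is an illustrative name):

```python
def insert_after_first_line(base_prompt: str, block: str) -> str:
    """Keep the first (identity) line, splice in a context block,
    then append the remainder of the prompt unchanged."""
    lines = base_prompt.split("\n")
    if lines:
        return lines[0] + "\n" + block + "\n" + "\n".join(lines[1:])
    return base_prompt
```

Since `str.split` always returns at least one element, the `if lines:` guard always takes the first branch; it reads as defensive coding rather than a reachable fallback.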
# Base prompt with anti-hallucination hard rules
TIMMY_ORCHESTRATOR_PROMPT_BASE = """You are Timmy, a sovereign AI orchestrator running locally on this Mac.

## Your Role

@@ -62,6 +222,20 @@ You have three tiers of memory:

Use `memory_search` when the user refers to past conversations.

## Hard Rules — Non-Negotiable

1. **NEVER fabricate tool output.** If you need data from a tool, call the tool and wait for the real result. Do not write what you think the result might be.

2. **If a tool call returns an error, report the exact error message.** Do not retry with invented data.

3. **If you do not know something about your own system, say:** "I don't have that information — let me check." Then use a tool. Do not guess.

4. **Never say "I'll wait for the output" and then immediately provide fake output.** These are contradictory. Wait means wait — no output until the tool returns.

5. **When corrected, use memory_write to save the correction immediately.**

6. **Your source code lives at the repository root shown above.** When using git tools, you don't need to specify a path — they automatically run from {REPO_ROOT}.

## Principles

1. **Sovereignty** — Everything local, no cloud
@@ -78,21 +252,31 @@ class TimmyOrchestrator(BaseAgent):
    """Main orchestrator agent that coordinates the swarm."""

    def __init__(self) -> None:
        # Build initial context (sync) and format prompt
        # Full context including hands will be loaded on first async call
        context = build_timmy_context_sync()
        formatted_prompt = format_timmy_prompt(TIMMY_ORCHESTRATOR_PROMPT_BASE, context)

        super().__init__(
            agent_id="timmy",
            name="Timmy",
            role="orchestrator",
-           system_prompt=TIMMY_ORCHESTRATOR_PROMPT,
-           tools=["web_search", "read_file", "write_file", "python", "memory_search"],
+           system_prompt=formatted_prompt,
+           tools=["web_search", "read_file", "write_file", "python", "memory_search", "memory_write"],
        )

        # Sub-agent registry
        self.sub_agents: dict[str, BaseAgent] = {}

        # Session tracking for init behavior
        self._session_initialized = False
        self._session_context: dict[str, Any] = {}
        self._context_fully_loaded = False

        # Connect to event bus
        self.connect_event_bus(event_bus)

-       logger.info("Timmy Orchestrator initialized")
+       logger.info("Timmy Orchestrator initialized with context-aware prompt")

    def register_sub_agent(self, agent: BaseAgent) -> None:
        """Register a sub-agent with the orchestrator."""
@@ -100,11 +284,102 @@ class TimmyOrchestrator(BaseAgent):
        agent.connect_event_bus(event_bus)
        logger.info("Registered sub-agent: %s", agent.name)

    async def _session_init(self) -> None:
        """Initialize session context on first user message.

        Silently reads git log and AGENTS.md to ground self-description in real data.
        This runs once per session before the first response.

        The git log is prepended to Timmy's context so he can answer "what's new?"
        from actual commit data rather than hallucinating.
        """
        if self._session_initialized:
            return

        logger.debug("Running session init...")

        # Load full context including hands if not already done
        if not self._context_fully_loaded:
            await build_timmy_context_async()
            self._context_fully_loaded = True

        # Read recent `git log --oneline -15` from the repo root
        try:
            from tools.git_tools import git_log
            git_result = git_log(max_count=15)
            if git_result.get("success"):
                commits = git_result.get("commits", [])
                self._session_context["git_log_commits"] = commits
                # Format as oneline for easy reading
                self._session_context["git_log_oneline"] = "\n".join([
                    f"{c['short_sha']} {c['message'].split(chr(10))[0]}"
                    for c in commits
                ])
                logger.debug("Session init: loaded %d commits from git log", len(commits))
            else:
                self._session_context["git_log_oneline"] = "Git log unavailable"
        except Exception as exc:
            logger.warning("Session init: could not read git log: %s", exc)
            self._session_context["git_log_oneline"] = "Git log unavailable"

        # Read AGENTS.md for self-awareness
        try:
            agents_md_path = Path(settings.repo_root) / "AGENTS.md"
            if agents_md_path.exists():
                self._session_context["agents_md"] = agents_md_path.read_text()[:3000]
        except Exception as exc:
            logger.warning("Session init: could not read AGENTS.md: %s", exc)

        # Read CHANGELOG for recent changes
        try:
            changelog_path = Path(settings.repo_root) / "docs" / "CHANGELOG_2026-02-26.md"
            if changelog_path.exists():
                self._session_context["changelog"] = changelog_path.read_text()[:2000]
        except Exception:
            pass  # Changelog is optional

        # Build session-specific context block for the prompt
        recent_changes = self._session_context.get("git_log_oneline", "")
        if recent_changes and recent_changes != "Git log unavailable":
            self._session_context["recent_changes_block"] = f"""
## Recent Changes to Your Codebase (last 15 commits):
```
{recent_changes}
```
When asked "what's new?" or similar, refer to these commits for actual changes.
"""
        else:
            self._session_context["recent_changes_block"] = ""

        self._session_initialized = True
        logger.debug("Session init complete")

    def _get_enhanced_system_prompt(self) -> str:
        """Get system prompt enhanced with session-specific context.

        This prepends the recent git log to the system prompt so Timmy
        can answer questions about what's new from real data.
        """
        base = self.system_prompt

        # Add recent changes block if available
        recent_changes = self._session_context.get("recent_changes_block", "")
        if recent_changes:
            # Insert after the first line
            lines = base.split("\n")
            if lines:
                return lines[0] + "\n" + recent_changes + "\n" + "\n".join(lines[1:])

        return base

    async def orchestrate(self, user_request: str) -> str:
        """Main entry point for user requests.

        Analyzes the request and either handles it directly or delegates.
        """
        # Run session init on first message (loads git log, etc.)
        await self._session_init()

        # Quick classification
        request_lower = user_request.lower()

@@ -171,7 +446,7 @@ def create_timmy_swarm() -> TimmyOrchestrator:
    from timmy.agents.echo import EchoAgent
    from timmy.agents.helm import HelmAgent

-   # Create orchestrator
+   # Create orchestrator (builds context automatically)
    timmy = TimmyOrchestrator()

    # Register sub-agents
@@ -182,3 +457,18 @@ def create_timmy_swarm() -> TimmyOrchestrator:
    timmy.register_sub_agent(HelmAgent())

    return timmy


# Convenience functions for refreshing context (called by /api/timmy/refresh-context)
def refresh_timmy_context_sync() -> dict[str, Any]:
    """Refresh Timmy's context (sync version)."""
    return build_timmy_context_sync()


async def refresh_timmy_context_async() -> dict[str, Any]:
    """Refresh Timmy's context including hands (async version)."""
    return await build_timmy_context_async()


# Keep old name for backwards compatibility
refresh_timmy_context = refresh_timmy_context_sync