forked from Rockachopa/Timmy-time-dashboard
Merge pull request #57 from AlexanderWhitestone/feature/model-upgrade-llama3.1
feat: Multi-modal LLM support with automatic model fallback
README.md (187 lines changed)
@@ -18,12 +18,17 @@ make install # create venv + install deps
 cp .env.example .env               # configure environment
 ollama serve                       # separate terminal
-ollama pull llama3.2
+ollama pull llama3.1:8b-instruct   # Required for reliable tool calling
 make dev                           # http://localhost:8000
 make test                          # no Ollama needed
 ```
+
+**Note:** llama3.1:8b-instruct is used instead of llama3.2 because it is
+specifically fine-tuned for reliable tool/function calling.
+llama3.2 (3B) was found to hallucinate tool output consistently in testing.
+Fallback: qwen2.5:14b if llama3.1:8b-instruct is not available.
 
 ---
 
 ## What's Here
@@ -74,8 +79,184 @@ make help # see all commands
 cp .env.example .env
 ```
 
-Key variables: `OLLAMA_URL`, `OLLAMA_MODEL`, `TIMMY_MODEL_BACKEND`,
-`L402_HMAC_SECRET`, `LIGHTNING_BACKEND`, `DEBUG`. Full list in `.env.example`.
+| Variable | Default | Purpose |
+|----------|---------|---------|
+| `OLLAMA_URL` | `http://localhost:11434` | Ollama host |
+| `OLLAMA_MODEL` | `llama3.1:8b-instruct` | Model for tool calling. Use llama3.1:8b-instruct for reliable tool use; fallback to qwen2.5:14b |
+| `DEBUG` | `false` | Enable `/docs` and `/redoc` |
+| `TIMMY_MODEL_BACKEND` | `ollama` | `ollama` \| `airllm` \| `auto` |
+| `AIRLLM_MODEL_SIZE` | `70b` | `8b` \| `70b` \| `405b` |
+| `L402_HMAC_SECRET` | *(default — change in prod)* | HMAC signing key for macaroons |
+| `L402_MACAROON_SECRET` | *(default — change in prod)* | Macaroon secret |
+| `LIGHTNING_BACKEND` | `mock` | `mock` (production-ready) \| `lnd` (scaffolded, not yet functional) |
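`L402_HMAC_SECRET` is the signing key for macaroons, which is why the table flags the default for replacement in production. As a rough illustration of what an HMAC signing key protects, here is a minimal sketch; the helper names are hypothetical and this is not the project's actual macaroon code:

```python
import hashlib
import hmac

def sign_macaroon(secret: bytes, identifier: bytes) -> str:
    """Bind an identifier to the secret key with HMAC-SHA256 (hex digest)."""
    return hmac.new(secret, identifier, hashlib.sha256).hexdigest()

def verify_macaroon(secret: bytes, identifier: bytes, signature: str) -> bool:
    """Constant-time comparison to avoid timing side channels."""
    return hmac.compare_digest(sign_macaroon(secret, identifier), signature)

# Anyone who knows the default secret can forge valid signatures,
# hence "change in prod".
sig = sign_macaroon(b"change-me-in-prod", b"invoice:42")
```

A signature made under one key never verifies under another, so rotating the secret invalidates all previously issued macaroons.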
+
+---
+
+## Architecture
+
+```
+Browser / Phone
+    │  HTTP + HTMX + WebSocket
+    ▼
+┌─────────────────────────────────────────┐
+│  FastAPI (dashboard.app)                │
+│  routes: agents, health, swarm,         │
+│  marketplace, voice, mobile             │
+└───┬─────────────┬──────────┬────────────┘
+    │             │          │
+    ▼             ▼          ▼
+ Jinja2        Timmy       Swarm
+ Templates     Agent       Coordinator
+ (HTMX)          │           ├─ Registry (SQLite)
+            ├─ Ollama        ├─ AuctionManager (L402 bids)
+            └─ AirLLM        ├─ SwarmComms (Redis / in-memory)
+                             └─ SwarmManager (subprocess)
+    │
+    ├── Voice NLU + TTS (pyttsx3, local)
+    ├── WebSocket live feed (ws_manager)
+    ├── L402 Lightning proxy (macaroon + invoice)
+    ├── Push notifications (local + macOS native)
+    └── Siri Shortcuts API endpoints
+
+Persistence: timmy.db (Agno memory), data/swarm.db (registry + tasks)
+External: Ollama :11434, optional Redis, optional LND gRPC
+```
+
+---
+
+## Project Layout
+
+```
+src/
+  config.py        # pydantic-settings — all env vars live here
+  timmy/           # Core agent (agent.py, backends.py, cli.py, prompts.py)
+  hands/           # Autonomous scheduled agents (registry, scheduler, runner)
+  dashboard/       # FastAPI app, routes, Jinja2 templates
+  swarm/           # Multi-agent: coordinator, registry, bidder, tasks, comms
+  timmy_serve/     # L402 proxy, payment handler, TTS, serve CLI
+  spark/           # Intelligence engine — events, predictions, advisory
+  creative/        # Creative director + video assembler pipeline
+  tools/           # Git, image, music, video tools for persona agents
+  lightning/       # Lightning backend abstraction (mock + LND)
+  agent_core/      # Substrate-agnostic agent interface
+  voice/           # NLU intent detection
+  ws_manager/      # WebSocket connection manager
+  notifications/   # Push notification store
+  shortcuts/       # Siri Shortcuts endpoints
+  telegram_bot/    # Telegram bridge
+  self_tdd/        # Continuous test watchdog
+hands/             # Hand manifests — oracle/, sentinel/, etc.
+tests/             # one test file per module, all mocked
+static/style.css   # Dark mission-control theme (JetBrains Mono)
+docs/              # GitHub Pages landing page
+AGENTS.md          # AI agent development standards ← read this
+.env.example       # Environment variable reference
+Makefile           # Common dev commands
+```
+
+---
+
+## Mobile Access
+
+The dashboard is fully mobile-optimized (iOS safe area, 44px touch targets, 16px
+input to prevent zoom, momentum scroll).
+
+```bash
+# Bind to your local network
+uvicorn dashboard.app:app --host 0.0.0.0 --port 8000 --reload
+
+# Find your IP
+ipconfig getifaddr en0   # Wi-Fi on macOS
+```
+
+Open `http://<your-ip>:8000` on your phone (same Wi-Fi network).
+
+Mobile-specific routes:
+- `/mobile` — single-column optimized layout
+- `/mobile-test` — 21-scenario HITL test harness (layout, touch, scroll, notch)
+
+---
+
+## Hands — Autonomous Agents
+
+Hands are autonomous agents that run on cron schedules. Each Hand has a `HAND.toml` manifest, a `SYSTEM.md` prompt, and an optional `skills/` directory.
+
+**Built-in Hands:**
+
+| Hand | Schedule | Purpose |
+|------|----------|---------|
+| **Oracle** | 7am, 7pm UTC | Bitcoin intelligence — price, on-chain, macro analysis |
+| **Sentinel** | Every 15 min | System health — dashboard, agents, database, resources |
+| **Scout** | Every hour | OSINT monitoring — HN, Reddit, RSS for Bitcoin/sovereign AI |
+| **Scribe** | Daily 9am | Content production — blog posts, docs, changelog |
+| **Ledger** | Every 6 hours | Treasury tracking — Bitcoin/Lightning balances, payment audit |
+| **Weaver** | Sunday 10am | Creative pipeline — orchestrates Pixel+Lyra+Reel for video |
+
+**Dashboard:** `/hands` — manage, trigger, approve actions
+
+**Example HAND.toml:**
+```toml
+[hand]
+name = "oracle"
+schedule = "0 7,19 * * *"   # Twice daily
+enabled = true
+
+[tools]
+required = ["mempool_fetch", "price_fetch"]
+
+[approval_gates]
+broadcast = { action = "broadcast", description = "Post to dashboard" }
+
+[output]
+dashboard = true
+channel = "telegram"
+```
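Oracle's `schedule = "0 7,19 * * *"` is standard five-field cron (minute, hour, day-of-month, month, day-of-week). A minimal matcher for just the subset of syntax these manifests use (`*` and comma lists); this is an assumed sketch, not the project's scheduler:

```python
from datetime import datetime

def cron_field_matches(field: str, value: int) -> bool:
    """Match one cron field, supporting '*' and comma-separated lists."""
    if field == "*":
        return True
    return value in {int(part) for part in field.split(",")}

def cron_matches(expr: str, when: datetime) -> bool:
    """Check a 5-field cron expression: minute hour dom month dow."""
    minute, hour, dom, month, dow = expr.split()
    return (
        cron_field_matches(minute, when.minute)
        and cron_field_matches(hour, when.hour)
        and cron_field_matches(dom, when.day)
        and cron_field_matches(month, when.month)
        # NB: real cron numbers Sunday as 0; datetime.weekday() uses 0 = Monday.
        # Irrelevant here because dow is '*' in every built-in Hand schedule.
        and cron_field_matches(dow, when.weekday())
    )
```

So Oracle fires at 07:00 and 19:00 each day, matching the "7am, 7pm UTC" row in the table above.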
+
+---
+
+## AirLLM — Big Brain Backend
+
+Run 70B or 405B models locally with no GPU, using AirLLM's layer-by-layer loading.
+Apple Silicon uses MLX automatically.
+
+```bash
+pip install ".[bigbrain]"
+pip install "airllm[mlx]"   # Apple Silicon only
+
+timmy chat "Explain self-custody" --backend airllm --model-size 70b
+```
+
+Or set once in `.env`:
+```bash
+TIMMY_MODEL_BACKEND=auto
+AIRLLM_MODEL_SIZE=70b
+```
+
+| Flag | Parameters | RAM needed |
+|------|------------|------------|
+| `8b` | 8 billion | ~16 GB |
+| `70b` | 70 billion | ~140 GB |
+| `405b` | 405 billion | ~810 GB |
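The RAM column tracks a simple back-of-envelope estimate of 2 bytes per parameter (fp16 weights):

```python
def fp16_ram_gb(params_billions: float) -> float:
    """Approximate RAM to hold fp16 weights: 2 bytes per parameter,
    so 1 billion parameters is about 2 GB."""
    return params_billions * 2
```

AirLLM's layer-by-layer loading is what makes these sizes usable without holding the full weight set in memory at once; the table gives the full-footprint figure the estimate produces.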
+
+---
+
+## CLI
+
+```bash
+timmy chat "What is sovereignty?"
+timmy think "Bitcoin and self-custody"
+timmy status
+
+timmy-serve start     # L402-gated API server (port 8402)
+timmy-serve invoice   # generate a Lightning invoice
+timmy-serve status
+```
+
+Or with the bootstrap script (creates venv, tests, watchdog, server in one shot):
+```bash
+bash scripts/activate_self_tdd.sh
+bash scripts/activate_self_tdd.sh --big-brain   # also installs AirLLM
+```
+
 ---
@@ -24,11 +24,31 @@ providers:
     priority: 1
     url: "http://localhost:11434"
     models:
-      - name: llama3.2
+      # Text + Tools models
+      - name: llama3.1:8b-instruct
         default: true
         context_window: 128000
+        capabilities: [text, tools, json, streaming]
+      - name: llama3.2:3b
+        context_window: 128000
+        capabilities: [text, tools, json, streaming, vision]
+      - name: qwen2.5:14b
+        context_window: 32000
+        capabilities: [text, tools, json, streaming]
       - name: deepseek-r1:1.5b
         context_window: 32000
+        capabilities: [text, json, streaming]
+
+      # Vision models
+      - name: llava:7b
+        context_window: 4096
+        capabilities: [text, vision, streaming]
+      - name: qwen2.5-vl:3b
+        context_window: 32000
+        capabilities: [text, vision, tools, json, streaming]
+      - name: moondream:1.8b
+        context_window: 2048
+        capabilities: [text, vision, streaming]
 
   # Secondary: Local AirLLM (if installed)
   - name: airllm-local
@@ -38,8 +58,11 @@ providers:
     models:
       - name: 70b
         default: true
+        capabilities: [text, tools, json, streaming]
       - name: 8b
+        capabilities: [text, tools, json, streaming]
       - name: 405b
+        capabilities: [text, tools, json, streaming]
 
   # Tertiary: OpenAI (if API key available)
   - name: openai-backup
@@ -52,8 +75,10 @@ providers:
       - name: gpt-4o-mini
         default: true
         context_window: 128000
+        capabilities: [text, vision, tools, json, streaming]
       - name: gpt-4o
         context_window: 128000
+        capabilities: [text, vision, tools, json, streaming]
 
   # Quaternary: Anthropic (if API key available)
   - name: anthropic-backup
@@ -65,10 +90,37 @@ providers:
       - name: claude-3-haiku-20240307
         default: true
         context_window: 200000
+        capabilities: [text, vision, streaming]
       - name: claude-3-sonnet-20240229
         context_window: 200000
+        capabilities: [text, vision, tools, streaming]
 
-# ── Custom Models ──────────────────────────────────────────────────────
+# ── Capability-Based Fallback Chains ────────────────────────────────────────
+# When a model doesn't support a required capability (e.g., vision),
+# the system falls back through these chains in order.
+
+fallback_chains:
+  # Vision-capable models (for image understanding)
+  vision:
+    - llama3.2:3b       # Fast, good vision
+    - qwen2.5-vl:3b     # Excellent vision, small
+    - llava:7b          # Classic vision model
+    - moondream:1.8b    # Tiny, fast vision
+
+  # Tool-calling models (for function calling)
+  tools:
+    - llama3.1:8b-instruct   # Best tool use
+    - qwen2.5:7b             # Reliable tools
+    - llama3.2:3b            # Small but capable
+
+  # General text generation (any model)
+  text:
+    - llama3.1:8b-instruct
+    - qwen2.5:14b
+    - deepseek-r1:1.5b
+    - llama3.2:3b
+
+# ── Custom Models ───────────────────────────────────────────────────────────
 # Register custom model weights for per-agent assignment.
 # Supports GGUF (Ollama), safetensors, and HuggingFace checkpoint dirs.
 # Models can also be registered at runtime via the /api/v1/models API.
@@ -91,7 +143,7 @@ custom_models: []
 #     context_window: 32000
 #     description: "Process reward model for scoring outputs"
 
-# ── Agent Model Assignments ─────────────────────────────────────────────
+# ── Agent Model Assignments ─────────────────────────────────────────────────
 # Map persona agent IDs to specific models.
 # Agents without an assignment use the global default (ollama_model).
 agent_model_assignments: {}
@@ -99,6 +151,20 @@ agent_model_assignments: {}
 #   persona-forge: my-finetuned-llama
 #   persona-echo: deepseek-r1:1.5b
 
+# ── Multi-Modal Settings ────────────────────────────────────────────────────
+multimodal:
+  # Automatically pull models when needed
+  auto_pull: true
+
+  # Timeout for model pulling (seconds)
+  pull_timeout: 300
+
+  # Maximum fallback depth (how many models to try before giving up)
+  max_fallback_depth: 3
+
+  # Prefer smaller models for vision when available (faster)
+  prefer_small_vision: true
+
 # Cost tracking (optional, for budget monitoring)
 cost_tracking:
   enabled: true
data/scripture.db-shm         new binary file (not shown)
data/scripture.db-wal         new binary file (not shown)
docs/CHANGELOG_2025-02-27.md  new file (57 lines)
@@ -0,0 +1,57 @@
+# Changelog — 2025-02-27
+
+## Model Upgrade & Hallucination Fix
+
+### Change 1: Model Upgrade (Primary Fix)
+**Problem:** llama3.2 (3B parameters) consistently hallucinated tool output instead of waiting for real results.
+
+**Solution:** Upgraded the default model to `llama3.1:8b-instruct`, which is specifically fine-tuned for reliable tool/function calling.
+
+**Changes:**
+- `src/config.py`: Changed `ollama_model` default from `llama3.2` to `llama3.1:8b-instruct`
+- Added fallback logic: if the primary model is unavailable, fall back automatically to `qwen2.5:14b`
+- `README.md`: Updated setup instructions with the new model requirement
+
+**User Action Required:**
+```bash
+ollama pull llama3.1:8b-instruct
+```
+
+### Change 2: Structured Output Enforcement (Foundation)
+**Preparation:** Added infrastructure for two-phase tool calling with JSON schema enforcement.
+
+**Implementation:**
+- Session context tracking in `TimmyOrchestrator`
+- `_session_init()` runs on the first message to load real data
+
+### Change 3: Git Tool Working Directory Fix
+**Problem:** Git tools failed with "fatal: Not a git repository" because they ran from the wrong working directory.
+
+**Solution:**
+- Rewrote `src/tools/git_tools.py` to use subprocess with an explicit `cwd=REPO_ROOT`
+- Added a `REPO_ROOT` module-level constant auto-detected at import time
+- All git commands now run from the correct directory
+
+### Change 4: Session Init with Git Log
+**Problem:** Timmy couldn't answer "what's new?" from real data.
+
+**Solution:**
+- `_session_init()` now reads `git log --oneline -15` from the repo root on the first message
+- Recent commits are prepended to the system prompt
+- Timmy now grounds its self-description in actual commit history
+
+### Change 5: Documentation Updates
+- `README.md`: Updated Quickstart with the new model requirement
+- `README.md`: Configuration table reflects the new default model
+- Added notes explaining why llama3.1:8b-instruct is required
+
+### Files Modified
+- `src/config.py` — Model configuration with fallback
+- `src/tools/git_tools.py` — Complete rewrite with subprocess + cwd
+- `src/agents/timmy.py` — Session init with git log reading
+- `README.md` — Updated setup and configuration docs
+
+### Testing
+- All git tool tests pass with the new subprocess implementation
+- Git log correctly returns commits from the repo root
+- Session init loads context on the first message
@@ -8,7 +8,11 @@ class Settings(BaseSettings):
     ollama_url: str = "http://localhost:11434"
 
     # LLM model passed to Agno/Ollama — override with OLLAMA_MODEL
-    ollama_model: str = "llama3.2"
+    # llama3.1:8b-instruct is used instead of llama3.2 because it is
+    # specifically fine-tuned for reliable tool/function calling.
+    # llama3.2 (3B) hallucinated tool output consistently in testing.
+    # Fallback: qwen2.5:14b if llama3.1:8b-instruct not available.
+    ollama_model: str = "llama3.1:8b-instruct"
 
     # Set DEBUG=true to enable /docs and /redoc (disabled by default)
     debug: bool = False
@@ -145,6 +149,62 @@ class Settings(BaseSettings):
 
 settings = Settings()
 
+# ── Model fallback configuration ────────────────────────────────────────────
+# Primary model for reliable tool calling (llama3.1:8b-instruct)
+# Fallback if primary not available: qwen2.5:14b
+OLLAMA_MODEL_PRIMARY: str = "llama3.1:8b-instruct"
+OLLAMA_MODEL_FALLBACK: str = "qwen2.5:14b"
+
+
+def check_ollama_model_available(model_name: str) -> bool:
+    """Check if a specific Ollama model is available locally."""
+    try:
+        import urllib.request
+
+        url = settings.ollama_url.replace("localhost", "127.0.0.1")
+        req = urllib.request.Request(
+            f"{url}/api/tags",
+            method="GET",
+            headers={"Accept": "application/json"},
+        )
+        with urllib.request.urlopen(req, timeout=5) as response:
+            import json
+
+            data = json.loads(response.read().decode())
+            models = [m.get("name", "").split(":")[0] for m in data.get("models", [])]
+            # Check for exact match or model name without tag
+            return any(model_name in m or m in model_name for m in models)
+    except Exception:
+        return False
+
+
+def get_effective_ollama_model() -> str:
+    """Get the effective Ollama model, with fallback logic."""
+    # If the user has overridden the model, prefer their setting
+    user_model = settings.ollama_model
+
+    # Check if the user's model is available
+    if check_ollama_model_available(user_model):
+        return user_model
+
+    # Try primary
+    if check_ollama_model_available(OLLAMA_MODEL_PRIMARY):
+        _startup_logger.warning(
+            f"Requested model '{user_model}' not available. "
+            f"Using primary: {OLLAMA_MODEL_PRIMARY}"
+        )
+        return OLLAMA_MODEL_PRIMARY
+
+    # Try fallback
+    if check_ollama_model_available(OLLAMA_MODEL_FALLBACK):
+        _startup_logger.warning(
+            f"Primary model '{OLLAMA_MODEL_PRIMARY}' not available. "
+            f"Using fallback: {OLLAMA_MODEL_FALLBACK}"
+        )
+        return OLLAMA_MODEL_FALLBACK
+
+    # Last resort: return the user's setting and hope for the best
+    return user_model
+
+
 # ── Startup validation ───────────────────────────────────────────────────────
 # Enforce security requirements — fail fast in production.
 import logging as _logging
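The availability check above strips the tag from each installed name and then accepts a substring match in either direction, which is deliberately loose. That matching can be exercised without a running Ollama; this is a pure restatement of the diff's comparison, not an additional project API:

```python
def model_matches(requested: str, installed_names: list[str]) -> bool:
    """Mirror the comparison in check_ollama_model_available:
    strip ':tag' from installed names, then accept a substring
    match in either direction."""
    bases = [name.split(":")[0] for name in installed_names]
    return any(requested in base or base in requested for base in bases)
```

Note the looseness this implies: because `"llama3.1"` is a substring of `"llama3.1:8b-instruct"`, the tagged request matches the untagged base, but a bare prefix like `"llama3"` would also match `"llama3.1"`.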
@@ -1,7 +1,8 @@
 """Git operations tools for Forge, Helm, and Timmy personas.
 
 Provides a full set of git commands that agents can execute against
-local or remote repositories. Uses GitPython under the hood.
+the local repository. Uses subprocess with explicit working directory
+to ensure commands run from the repo root.
 
 All functions return plain dicts so they're easily serialisable for
 tool-call results, Spark event capture, and WebSocket broadcast.
@@ -10,134 +11,261 @@ tool-call results, Spark event capture, and WebSocket broadcast.
 from __future__ import annotations
 
 import logging
+import os
+import subprocess
 from pathlib import Path
 from typing import Optional
 
 logger = logging.getLogger(__name__)
 
-_GIT_AVAILABLE = True
-try:
-    from git import Repo, InvalidGitRepositoryError, GitCommandNotFound
-except ImportError:
-    _GIT_AVAILABLE = False
+
+def _find_repo_root() -> str:
+    """Walk up from this file's location to find the .git directory."""
+    path = os.path.dirname(os.path.abspath(__file__))
+    # Start from the project root (two levels up from src/tools/git_tools.py)
+    path = os.path.dirname(os.path.dirname(path))
+
+    while path != os.path.dirname(path):
+        if os.path.exists(os.path.join(path, '.git')):
+            return path
+        path = os.path.dirname(path)
+
+    # Fallback to config repo_root
+    try:
+        from config import settings
+        return settings.repo_root
+    except Exception:
+        return os.getcwd()
 
-def _require_git() -> None:
-    if not _GIT_AVAILABLE:
-        raise ImportError(
-            "GitPython is not installed. Run: pip install GitPython"
-        )
+
+# Module-level constant for repo root
+REPO_ROOT = _find_repo_root()
+logger.info(f"Git repo root: {REPO_ROOT}")
 
-def _open_repo(repo_path: str | Path) -> "Repo":
-    """Open an existing git repo at *repo_path*."""
-    _require_git()
-    return Repo(str(repo_path))
+
+def _run_git_command(args: list[str], cwd: Optional[str] = None) -> tuple[int, str, str]:
+    """Run a git command with the proper working directory.
+
+    Args:
+        args: Git command arguments (e.g., ["log", "--oneline", "-5"])
+        cwd: Working directory (defaults to REPO_ROOT)
+
+    Returns:
+        Tuple of (returncode, stdout, stderr)
+    """
+    cmd = ["git"] + args
+    working_dir = cwd or REPO_ROOT
+
+    try:
+        result = subprocess.run(
+            cmd,
+            cwd=working_dir,
+            capture_output=True,
+            text=True,
+            timeout=30,
+        )
+        return result.returncode, result.stdout, result.stderr
+    except subprocess.TimeoutExpired:
+        return -1, "", "Command timed out after 30 seconds"
+    except Exception as exc:
+        return -1, "", str(exc)
 
 
 # ── Repository management ────────────────────────────────────────────────────
 
 def git_clone(url: str, dest: str | Path) -> dict:
-    """Clone a remote repository to a local path.
-
-    Returns dict with ``path`` and ``default_branch``.
-    """
-    _require_git()
-    repo = Repo.clone_from(url, str(dest))
+    """Clone a remote repository to a local path."""
+    returncode, stdout, stderr = _run_git_command(
+        ["clone", url, str(dest)],
+        cwd=None,  # None falls back to REPO_ROOT inside _run_git_command
+    )
+
+    if returncode != 0:
+        return {"success": False, "error": stderr}
+
     return {
         "success": True,
         "path": str(dest),
-        "default_branch": repo.active_branch.name,
+        "message": f"Cloned {url} to {dest}",
     }
 
 
 def git_init(path: str | Path) -> dict:
     """Initialise a new git repository at *path*."""
-    _require_git()
-    Path(path).mkdir(parents=True, exist_ok=True)
-    repo = Repo.init(str(path))
-    return {"success": True, "path": str(path), "bare": repo.bare}
+    os.makedirs(path, exist_ok=True)
+    returncode, stdout, stderr = _run_git_command(["init"], cwd=str(path))
+
+    if returncode != 0:
+        return {"success": False, "error": stderr}
+
+    return {"success": True, "path": str(path)}
 
 
 # ── Status / inspection ──────────────────────────────────────────────────────
 
-def git_status(repo_path: str | Path) -> dict:
+def git_status(repo_path: Optional[str] = None) -> dict:
     """Return working-tree status: modified, staged, untracked files."""
-    repo = _open_repo(repo_path)
+    cwd = repo_path or REPO_ROOT
+    returncode, stdout, stderr = _run_git_command(
+        ["status", "--porcelain", "-b"], cwd=cwd
+    )
+
+    if returncode != 0:
+        return {"success": False, "error": stderr}
+
+    # Parse porcelain output
+    lines = stdout.strip().split("\n") if stdout else []
+    branch = "unknown"
+    modified = []
+    staged = []
+    untracked = []
+
+    for line in lines:
+        if line.startswith("## "):
+            branch = line[3:].split("...")[0].strip()
+        elif len(line) >= 2:
+            index_status = line[0]
+            worktree_status = line[1]
+            filename = line[3:].strip() if len(line) > 3 else ""
+
+            if index_status in "MADRC":
+                staged.append(filename)
+            if worktree_status in "MD":
+                modified.append(filename)
+            if worktree_status == "?":
+                untracked.append(filename)
+
     return {
         "success": True,
-        "branch": repo.active_branch.name,
-        "is_dirty": repo.is_dirty(untracked_files=True),
-        "untracked": repo.untracked_files,
-        "modified": [item.a_path for item in repo.index.diff(None)],
-        "staged": [item.a_path for item in repo.index.diff("HEAD")],
+        "branch": branch,
+        "is_dirty": bool(modified or staged or untracked),
+        "modified": modified,
+        "staged": staged,
+        "untracked": untracked,
     }
 
 
 def git_diff(
-    repo_path: str | Path,
+    repo_path: Optional[str] = None,
     staged: bool = False,
     file_path: Optional[str] = None,
 ) -> dict:
-    """Show diff of working tree or staged changes.
-
-    If *file_path* is given, scope diff to that file only.
-    """
-    repo = _open_repo(repo_path)
-    args: list[str] = []
+    """Show diff of working tree or staged changes."""
+    cwd = repo_path or REPO_ROOT
+    args = ["diff"]
     if staged:
         args.append("--cached")
     if file_path:
         args.extend(["--", file_path])
-    diff_text = repo.git.diff(*args)
-    return {"success": True, "diff": diff_text, "staged": staged}
+
+    returncode, stdout, stderr = _run_git_command(args, cwd=cwd)
+
+    if returncode != 0:
+        return {"success": False, "error": stderr}
+
+    return {"success": True, "diff": stdout, "staged": staged}
 
 
 def git_log(
-    repo_path: str | Path,
+    repo_path: Optional[str] = None,
     max_count: int = 20,
     branch: Optional[str] = None,
 ) -> dict:
     """Return recent commit history as a list of dicts."""
-    repo = _open_repo(repo_path)
-    ref = branch or repo.active_branch.name
+    cwd = repo_path or REPO_ROOT
+    args = ["log", f"--max-count={max_count}", "--format=%H|%h|%s|%an|%ai"]
+    if branch:
+        args.append(branch)
+
+    returncode, stdout, stderr = _run_git_command(args, cwd=cwd)
+
+    if returncode != 0:
+        return {"success": False, "error": stderr}
+
     commits = []
-    for commit in repo.iter_commits(ref, max_count=max_count):
-        commits.append({
-            "sha": commit.hexsha,
-            "short_sha": commit.hexsha[:8],
-            "message": commit.message.strip(),
-            "author": str(commit.author),
-            "date": commit.committed_datetime.isoformat(),
-            "files_changed": len(commit.stats.files),
-        })
-    return {"success": True, "branch": ref, "commits": commits}
+    for line in stdout.strip().split("\n"):
+        if not line:
+            continue
+        parts = line.split("|", 4)
+        if len(parts) >= 5:
+            commits.append({
+                "sha": parts[0],
+                "short_sha": parts[1],
+                "message": parts[2],
+                "author": parts[3],
+                "date": parts[4],
+            })
+
+    # Get the current branch
+    _, branch_out, _ = _run_git_command(["branch", "--show-current"], cwd=cwd)
+    current_branch = branch_out.strip() or "main"
+
+    return {
+        "success": True,
+        "branch": branch or current_branch,
+        "commits": commits,
+    }
 
 
-def git_blame(repo_path: str | Path, file_path: str) -> dict:
+def git_blame(repo_path: Optional[str] = None, file_path: str = "") -> dict:
     """Show line-by-line authorship for a file."""
-    repo = _open_repo(repo_path)
-    blame_text = repo.git.blame(file_path)
-    return {"success": True, "file": file_path, "blame": blame_text}
+    if not file_path:
+        return {"success": False, "error": "file_path is required"}
+
+    cwd = repo_path or REPO_ROOT
+    returncode, stdout, stderr = _run_git_command(
+        ["blame", "--porcelain", file_path], cwd=cwd
+    )
+
+    if returncode != 0:
+        return {"success": False, "error": stderr}
+
+    return {"success": True, "file": file_path, "blame": stdout}
 
 
 # ── Branching ─────────────────────────────────────────────────────────────────
 
 def git_branch(
-    repo_path: str | Path,
+    repo_path: Optional[str] = None,
     create: Optional[str] = None,
     switch: Optional[str] = None,
 ) -> dict:
     """List branches, optionally create or switch to one."""
-    repo = _open_repo(repo_path)
|
cwd = repo_path or REPO_ROOT
|
||||||
|
|
||||||
if create:
|
if create:
|
||||||
repo.create_head(create)
|
returncode, _, stderr = _run_git_command(
|
||||||
|
["branch", create], cwd=cwd
|
||||||
|
)
|
||||||
|
if returncode != 0:
|
||||||
|
return {"success": False, "error": stderr}
|
||||||
|
|
||||||
if switch:
|
if switch:
|
||||||
repo.heads[switch].checkout()
|
returncode, _, stderr = _run_git_command(
|
||||||
|
["checkout", switch], cwd=cwd
|
||||||
branches = [h.name for h in repo.heads]
|
)
|
||||||
active = repo.active_branch.name
|
if returncode != 0:
|
||||||
|
return {"success": False, "error": stderr}
|
||||||
|
|
||||||
|
# List branches
|
||||||
|
returncode, stdout, stderr = _run_git_command(
|
||||||
|
["branch", "-a", "--format=%(refname:short)%(if)%(HEAD)%(then)*%(end)"],
|
||||||
|
cwd=cwd
|
||||||
|
)
|
||||||
|
|
||||||
|
if returncode != 0:
|
||||||
|
return {"success": False, "error": stderr}
|
||||||
|
|
||||||
|
branches = []
|
||||||
|
active = ""
|
||||||
|
for line in stdout.strip().split("\n"):
|
||||||
|
line = line.strip()
|
||||||
|
if line.endswith("*"):
|
||||||
|
active = line[:-1]
|
||||||
|
branches.append(active)
|
||||||
|
elif line:
|
||||||
|
branches.append(line)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"branches": branches,
|
"branches": branches,
|
||||||
@@ -149,26 +277,47 @@ def git_branch(
# ── Staging & committing ─────────────────────────────────────────────────────


def git_add(repo_path: Optional[str] = None, paths: Optional[list[str]] = None) -> dict:
    """Stage files for commit. *paths* defaults to all modified files."""
    cwd = repo_path or REPO_ROOT

    if paths:
        args = ["add"] + paths
    else:
        # Stage all changes
        args = ["add", "-A"]

    returncode, _, stderr = _run_git_command(args, cwd=cwd)

    if returncode != 0:
        return {"success": False, "error": stderr}

    return {"success": True, "staged": paths or ["all"]}


def git_commit(
    repo_path: Optional[str] = None,
    message: str = "",
) -> dict:
    """Create a commit with the given message."""
    if not message:
        return {"success": False, "error": "commit message is required"}

    cwd = repo_path or REPO_ROOT
    returncode, stdout, stderr = _run_git_command(
        ["commit", "-m", message], cwd=cwd
    )

    if returncode != 0:
        return {"success": False, "error": stderr}

    # Get the commit hash
    _, hash_out, _ = _run_git_command(["rev-parse", "HEAD"], cwd=cwd)
    commit_hash = hash_out.strip()

    return {
        "success": True,
        "sha": commit_hash,
        "short_sha": commit_hash[:8],
        "message": message,
    }
@@ -176,47 +325,68 @@ def git_commit(repo_path: str | Path, message: str) -> dict:
# ── Remote operations ─────────────────────────────────────────────────────────


def git_push(
    repo_path: Optional[str] = None,
    remote: str = "origin",
    branch: Optional[str] = None,
) -> dict:
    """Push the current (or specified) branch to the remote."""
    cwd = repo_path or REPO_ROOT
    args = ["push", remote]
    if branch:
        args.append(branch)

    returncode, stdout, stderr = _run_git_command(args, cwd=cwd)

    if returncode != 0:
        return {"success": False, "error": stderr}

    return {"success": True, "remote": remote, "branch": branch or "current"}


def git_pull(
    repo_path: Optional[str] = None,
    remote: str = "origin",
    branch: Optional[str] = None,
) -> dict:
    """Pull from the remote into the working tree."""
    cwd = repo_path or REPO_ROOT
    args = ["pull", remote]
    if branch:
        args.append(branch)

    returncode, stdout, stderr = _run_git_command(args, cwd=cwd)

    if returncode != 0:
        return {"success": False, "error": stderr}

    return {"success": True, "remote": remote, "branch": branch or "current"}


# ── Stashing ──────────────────────────────────────────────────────────────────


def git_stash(
    repo_path: Optional[str] = None,
    pop: bool = False,
    message: Optional[str] = None,
) -> dict:
    """Stash or pop working-tree changes."""
    cwd = repo_path or REPO_ROOT

    if pop:
        returncode, _, stderr = _run_git_command(["stash", "pop"], cwd=cwd)
        if returncode != 0:
            return {"success": False, "error": stderr}
        return {"success": True, "action": "pop"}

    args = ["stash", "push"]
    if message:
        args.extend(["-m", message])

    returncode, _, stderr = _run_git_command(args, cwd=cwd)

    if returncode != 0:
        return {"success": False, "error": stderr}

    return {"success": True, "action": "stash", "message": message}
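The `git_log` helper above reads one commit per line in the `%H|%h|%s|%an|%ai` format and splits on `|`. That parsing step can be sketched in isolation (the hash, subject, and author below are made up for illustration):

```python
# Hypothetical log line in the "%H|%h|%s|%an|%ai" format used by git_log.
line = (
    "a1b2c3d4e5f6a7b8c9d0a1b2c3d4e5f6a7b8c9d0"
    "|a1b2c3d4|fix: handle empty diff|Alice|2024-05-01 12:00:00 +0000"
)

# maxsplit=4 yields at most 5 fields, so extra "|" stays in the last field
parts = line.split("|", 4)
commit = {
    "sha": parts[0],
    "short_sha": parts[1],
    "message": parts[2],
    "author": parts[3],
    "date": parts[4],
}
print(commit["short_sha"])  # a1b2c3d4
```

One caveat of this format: a commit subject containing `|` would shift the author and date fields, since the subject comes before them; a rarer separator such as `%x1f` would make the parsing more robust.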
@@ -0,0 +1,37 @@
"""Infrastructure models package."""

from infrastructure.models.registry import (
    CustomModel,
    ModelFormat,
    ModelRegistry,
    ModelRole,
    model_registry,
)
from infrastructure.models.multimodal import (
    ModelCapability,
    ModelInfo,
    MultiModalManager,
    get_model_for_capability,
    get_multimodal_manager,
    model_supports_tools,
    model_supports_vision,
    pull_model_with_fallback,
)

__all__ = [
    # Registry
    "CustomModel",
    "ModelFormat",
    "ModelRegistry",
    "ModelRole",
    "model_registry",
    # Multi-modal
    "ModelCapability",
    "ModelInfo",
    "MultiModalManager",
    "get_model_for_capability",
    "get_multimodal_manager",
    "model_supports_tools",
    "model_supports_vision",
    "pull_model_with_fallback",
]
445
src/infrastructure/models/multimodal.py
Normal file
@@ -0,0 +1,445 @@
"""Multi-modal model support with automatic capability detection and fallbacks.
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
- Model capability detection (vision, audio, etc.)
|
||||||
|
- Automatic model pulling with fallback chains
|
||||||
|
- Content-type aware model selection
|
||||||
|
- Graceful degradation when primary models unavailable
|
||||||
|
|
||||||
|
No cloud by default — tries local first, falls back through configured options.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum, auto
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelCapability(Enum):
|
||||||
|
"""Capabilities a model can have."""
|
||||||
|
TEXT = auto() # Standard text completion
|
||||||
|
VISION = auto() # Image understanding
|
||||||
|
AUDIO = auto() # Audio/speech processing
|
||||||
|
TOOLS = auto() # Function calling / tool use
|
||||||
|
JSON = auto() # Structured output / JSON mode
|
||||||
|
STREAMING = auto() # Streaming responses
|
||||||
|
|
||||||
|
|
||||||
|
# Known model capabilities (local Ollama models)
|
||||||
|
# These are used when we can't query the model directly
|
||||||
|
KNOWN_MODEL_CAPABILITIES: dict[str, set[ModelCapability]] = {
|
||||||
|
# Llama 3.x series
|
||||||
|
"llama3.1": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"llama3.1:8b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"llama3.1:8b-instruct": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"llama3.1:70b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"llama3.1:405b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"llama3.2": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
||||||
|
"llama3.2:1b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"llama3.2:3b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
||||||
|
"llama3.2-vision": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
||||||
|
"llama3.2-vision:11b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
||||||
|
|
||||||
|
# Qwen series
|
||||||
|
"qwen2.5": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"qwen2.5:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"qwen2.5:14b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"qwen2.5:32b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"qwen2.5:72b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"qwen2.5-vl": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
||||||
|
"qwen2.5-vl:3b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
||||||
|
"qwen2.5-vl:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING, ModelCapability.VISION},
|
||||||
|
|
||||||
|
# DeepSeek series
|
||||||
|
"deepseek-r1": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"deepseek-r1:1.5b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"deepseek-r1:7b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"deepseek-r1:14b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"deepseek-r1:32b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"deepseek-r1:70b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"deepseek-v3": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
|
||||||
|
# Gemma series
|
||||||
|
"gemma2": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"gemma2:2b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"gemma2:9b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"gemma2:27b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
|
||||||
|
# Mistral series
|
||||||
|
"mistral": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"mistral:7b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"mistral-nemo": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"mistral-small": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"mistral-large": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
|
||||||
|
# Vision-specific models
|
||||||
|
"llava": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||||
|
"llava:7b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||||
|
"llava:13b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||||
|
"llava:34b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||||
|
"llava-phi3": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||||
|
"llava-llama3": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||||
|
"bakllava": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||||
|
"moondream": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||||
|
"moondream:1.8b": {ModelCapability.TEXT, ModelCapability.VISION, ModelCapability.STREAMING},
|
||||||
|
|
||||||
|
# Phi series
|
||||||
|
"phi3": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"phi3:3.8b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"phi3:14b": {ModelCapability.TEXT, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"phi4": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
|
||||||
|
# Command R
|
||||||
|
"command-r": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"command-r:35b": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"command-r-plus": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
|
||||||
|
# Granite (IBM)
|
||||||
|
"granite3-dense": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
"granite3-moe": {ModelCapability.TEXT, ModelCapability.TOOLS, ModelCapability.JSON, ModelCapability.STREAMING},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Default fallback chains for each capability
|
||||||
|
# These are tried in order when the primary model doesn't support a capability
|
||||||
|
DEFAULT_FALLBACK_CHAINS: dict[ModelCapability, list[str]] = {
|
||||||
|
ModelCapability.VISION: [
|
||||||
|
"llama3.2:3b", # Fast vision model
|
||||||
|
"llava:7b", # Classic vision model
|
||||||
|
"qwen2.5-vl:3b", # Qwen vision
|
||||||
|
"moondream:1.8b", # Tiny vision model (last resort)
|
||||||
|
],
|
||||||
|
ModelCapability.TOOLS: [
|
||||||
|
"llama3.1:8b-instruct", # Best tool use
|
||||||
|
"llama3.2:3b", # Smaller but capable
|
||||||
|
"qwen2.5:7b", # Reliable fallback
|
||||||
|
],
|
||||||
|
ModelCapability.AUDIO: [
|
||||||
|
# Audio models are less common in Ollama
|
||||||
|
# Would need specific audio-capable models here
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelInfo:
|
||||||
|
"""Information about a model's capabilities and availability."""
|
||||||
|
name: str
|
||||||
|
capabilities: set[ModelCapability] = field(default_factory=set)
|
||||||
|
is_available: bool = False
|
||||||
|
is_pulled: bool = False
|
||||||
|
size_mb: Optional[int] = None
|
||||||
|
description: str = ""
|
||||||
|
|
||||||
|
def supports(self, capability: ModelCapability) -> bool:
|
||||||
|
"""Check if model supports a specific capability."""
|
||||||
|
return capability in self.capabilities
|
||||||
|
|
||||||
|
|
||||||
|
class MultiModalManager:
|
||||||
|
"""Manages multi-modal model capabilities and fallback chains.
|
||||||
|
|
||||||
|
This class:
|
||||||
|
1. Detects what capabilities each model has
|
||||||
|
2. Maintains fallback chains for different capabilities
|
||||||
|
3. Pulls models on-demand with automatic fallback
|
||||||
|
4. Routes requests to appropriate models based on content type
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, ollama_url: Optional[str] = None) -> None:
|
||||||
|
self.ollama_url = ollama_url or settings.ollama_url
|
||||||
|
self._available_models: dict[str, ModelInfo] = {}
|
||||||
|
self._fallback_chains: dict[ModelCapability, list[str]] = dict(DEFAULT_FALLBACK_CHAINS)
|
||||||
|
self._refresh_available_models()
|
||||||
|
|
||||||
|
def _refresh_available_models(self) -> None:
|
||||||
|
"""Query Ollama for available models."""
|
||||||
|
try:
|
||||||
|
import urllib.request
|
||||||
|
import json
|
||||||
|
|
||||||
|
url = self.ollama_url.replace("localhost", "127.0.0.1")
|
||||||
|
req = urllib.request.Request(
|
||||||
|
f"{url}/api/tags",
|
||||||
|
method="GET",
|
||||||
|
headers={"Accept": "application/json"},
|
||||||
|
)
|
||||||
|
with urllib.request.urlopen(req, timeout=5) as response:
|
||||||
|
data = json.loads(response.read().decode())
|
||||||
|
|
||||||
|
for model_data in data.get("models", []):
|
||||||
|
name = model_data.get("name", "")
|
||||||
|
self._available_models[name] = ModelInfo(
|
||||||
|
name=name,
|
||||||
|
capabilities=self._detect_capabilities(name),
|
||||||
|
is_available=True,
|
||||||
|
is_pulled=True,
|
||||||
|
size_mb=model_data.get("size", 0) // (1024 * 1024),
|
||||||
|
description=model_data.get("details", {}).get("family", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Found %d models in Ollama", len(self._available_models))
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Could not refresh available models: %s", exc)
|
||||||
|
|
||||||
|
def _detect_capabilities(self, model_name: str) -> set[ModelCapability]:
|
||||||
|
"""Detect capabilities for a model based on known data."""
|
||||||
|
# Normalize model name (strip tags for lookup)
|
||||||
|
base_name = model_name.split(":")[0]
|
||||||
|
|
||||||
|
# Try exact match first
|
||||||
|
if model_name in KNOWN_MODEL_CAPABILITIES:
|
||||||
|
return set(KNOWN_MODEL_CAPABILITIES[model_name])
|
||||||
|
|
||||||
|
# Try base name match
|
||||||
|
if base_name in KNOWN_MODEL_CAPABILITIES:
|
||||||
|
return set(KNOWN_MODEL_CAPABILITIES[base_name])
|
||||||
|
|
||||||
|
# Default to text-only for unknown models
|
||||||
|
logger.debug("Unknown model %s, defaulting to TEXT only", model_name)
|
||||||
|
return {ModelCapability.TEXT, ModelCapability.STREAMING}
|
||||||
|
|
||||||
|
def get_model_capabilities(self, model_name: str) -> set[ModelCapability]:
|
||||||
|
"""Get capabilities for a specific model."""
|
||||||
|
if model_name in self._available_models:
|
||||||
|
return self._available_models[model_name].capabilities
|
||||||
|
return self._detect_capabilities(model_name)
|
||||||
|
|
||||||
|
def model_supports(self, model_name: str, capability: ModelCapability) -> bool:
|
||||||
|
"""Check if a model supports a specific capability."""
|
||||||
|
capabilities = self.get_model_capabilities(model_name)
|
||||||
|
return capability in capabilities
|
||||||
|
|
||||||
|
def get_models_with_capability(self, capability: ModelCapability) -> list[ModelInfo]:
|
||||||
|
"""Get all available models that support a capability."""
|
||||||
|
return [
|
||||||
|
info for info in self._available_models.values()
|
||||||
|
if capability in info.capabilities
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_best_model_for(
|
||||||
|
self,
|
||||||
|
capability: ModelCapability,
|
||||||
|
preferred_model: Optional[str] = None
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Get the best available model for a specific capability.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
capability: The required capability
|
||||||
|
preferred_model: Preferred model to use if available and capable
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Model name or None if no suitable model found
|
||||||
|
"""
|
||||||
|
# Check if preferred model supports this capability
|
||||||
|
if preferred_model:
|
||||||
|
if preferred_model in self._available_models:
|
||||||
|
if self.model_supports(preferred_model, capability):
|
||||||
|
return preferred_model
|
||||||
|
logger.debug(
|
||||||
|
"Preferred model %s doesn't support %s, checking fallbacks",
|
||||||
|
preferred_model, capability.name
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check fallback chain for this capability
|
||||||
|
fallback_chain = self._fallback_chains.get(capability, [])
|
||||||
|
for model_name in fallback_chain:
|
||||||
|
if model_name in self._available_models:
|
||||||
|
logger.debug("Using fallback model %s for %s", model_name, capability.name)
|
||||||
|
return model_name
|
||||||
|
|
||||||
|
# Find any available model with this capability
|
||||||
|
capable_models = self.get_models_with_capability(capability)
|
||||||
|
if capable_models:
|
||||||
|
# Sort by size (prefer smaller/faster models as fallback)
|
||||||
|
capable_models.sort(key=lambda m: m.size_mb or float('inf'))
|
||||||
|
return capable_models[0].name
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def pull_model_with_fallback(
|
||||||
|
self,
|
||||||
|
primary_model: str,
|
||||||
|
capability: Optional[ModelCapability] = None,
|
||||||
|
auto_pull: bool = True,
|
||||||
|
) -> tuple[str, bool]:
|
||||||
|
"""Pull a model with automatic fallback if unavailable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
primary_model: The desired model to use
|
||||||
|
capability: Required capability (for finding fallback)
|
||||||
|
auto_pull: Whether to attempt pulling missing models
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (model_name, is_fallback)
|
||||||
|
"""
|
||||||
|
# Check if primary model is already available
|
||||||
|
if primary_model in self._available_models:
|
||||||
|
return primary_model, False
|
||||||
|
|
||||||
|
# Try to pull the primary model
|
||||||
|
if auto_pull:
|
||||||
|
if self._pull_model(primary_model):
|
||||||
|
return primary_model, False
|
||||||
|
|
||||||
|
# Need to find a fallback
|
||||||
|
if capability:
|
||||||
|
fallback = self.get_best_model_for(capability, primary_model)
|
||||||
|
if fallback:
|
||||||
|
logger.info(
|
||||||
|
"Primary model %s unavailable, using fallback %s",
|
||||||
|
primary_model, fallback
|
||||||
|
)
|
||||||
|
return fallback, True
|
||||||
|
|
||||||
|
# Last resort: use the configured default model
|
||||||
|
default_model = settings.ollama_model
|
||||||
|
if default_model in self._available_models:
|
||||||
|
logger.warning(
|
||||||
|
"Falling back to default model %s (primary: %s unavailable)",
|
||||||
|
default_model, primary_model
|
||||||
|
)
|
||||||
|
return default_model, True
|
||||||
|
|
||||||
|
# Absolute last resort
|
||||||
|
return primary_model, False
|
||||||
|
|
||||||
|
def _pull_model(self, model_name: str) -> bool:
|
||||||
|
"""Attempt to pull a model from Ollama.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful or model already exists
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import urllib.request
|
||||||
|
import json
|
||||||
|
|
||||||
|
logger.info("Pulling model: %s", model_name)
|
||||||
|
|
||||||
|
url = self.ollama_url.replace("localhost", "127.0.0.1")
|
||||||
|
req = urllib.request.Request(
|
||||||
|
f"{url}/api/pull",
|
||||||
|
method="POST",
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
data=json.dumps({"name": model_name, "stream": False}).encode(),
|
||||||
|
)
|
||||||
|
|
||||||
|
with urllib.request.urlopen(req, timeout=300) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
logger.info("Successfully pulled model: %s", model_name)
|
||||||
|
# Refresh available models
|
||||||
|
self._refresh_available_models()
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.error("Failed to pull %s: HTTP %s", model_name, response.status)
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Error pulling model %s: %s", model_name, exc)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def configure_fallback_chain(
|
||||||
|
self,
|
||||||
|
capability: ModelCapability,
|
||||||
|
models: list[str]
|
||||||
|
) -> None:
|
||||||
|
"""Configure a custom fallback chain for a capability."""
|
||||||
|
self._fallback_chains[capability] = models
|
||||||
|
logger.info("Configured fallback chain for %s: %s", capability.name, models)
|
||||||
|
|
||||||
|
def get_fallback_chain(self, capability: ModelCapability) -> list[str]:
|
||||||
|
"""Get the fallback chain for a capability."""
|
||||||
|
return list(self._fallback_chains.get(capability, []))
|
||||||
|
|
||||||
|
def list_available_models(self) -> list[ModelInfo]:
|
||||||
|
"""List all available models with their capabilities."""
|
||||||
|
return list(self._available_models.values())
|
||||||
|
|
||||||
|
def refresh(self) -> None:
|
||||||
|
"""Refresh the list of available models."""
|
||||||
|
self._refresh_available_models()
|
||||||
|
|
||||||
|
def get_model_for_content(
|
||||||
|
self,
|
||||||
|
content_type: str, # "text", "image", "audio", "multimodal"
|
||||||
|
preferred_model: Optional[str] = None,
|
||||||
|
) -> tuple[str, bool]:
|
||||||
|
"""Get appropriate model based on content type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content_type: Type of content (text, image, audio, multimodal)
|
||||||
|
preferred_model: User's preferred model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (model_name, is_fallback)
|
||||||
|
"""
|
||||||
|
content_type = content_type.lower()
|
||||||
|
|
||||||
|
if content_type in ("image", "vision", "multimodal"):
|
||||||
|
# For vision content, we need a vision-capable model
|
+            return self.pull_model_with_fallback(
+                preferred_model or "llava:7b",
+                capability=ModelCapability.VISION,
+            )
+
+        elif content_type == "audio":
+            # Audio support is limited in Ollama
+            # Would need specific audio models
+            logger.warning("Audio support is limited, falling back to text model")
+            return self.pull_model_with_fallback(
+                preferred_model or settings.ollama_model,
+                capability=ModelCapability.TEXT,
+            )
+
+        else:
+            # Standard text content
+            return self.pull_model_with_fallback(
+                preferred_model or settings.ollama_model,
+                capability=ModelCapability.TEXT,
+            )
+
+
+# Module-level singleton
+_multimodal_manager: Optional[MultiModalManager] = None
+
+
+def get_multimodal_manager() -> MultiModalManager:
+    """Get or create the multi-modal manager singleton."""
+    global _multimodal_manager
+    if _multimodal_manager is None:
+        _multimodal_manager = MultiModalManager()
+    return _multimodal_manager
+
+
+def get_model_for_capability(
+    capability: ModelCapability,
+    preferred_model: Optional[str] = None,
+) -> Optional[str]:
+    """Convenience function to get the best model for a capability."""
+    return get_multimodal_manager().get_best_model_for(capability, preferred_model)
+
+
+def pull_model_with_fallback(
+    primary_model: str,
+    capability: Optional[ModelCapability] = None,
+    auto_pull: bool = True,
+) -> tuple[str, bool]:
+    """Convenience function to pull a model with fallback."""
+    return get_multimodal_manager().pull_model_with_fallback(
+        primary_model, capability, auto_pull
+    )
+
+
+def model_supports_vision(model_name: str) -> bool:
+    """Check if a model supports vision."""
+    return get_multimodal_manager().model_supports(model_name, ModelCapability.VISION)
+
+
+def model_supports_tools(model_name: str) -> bool:
+    """Check if a model supports tool calling."""
+    return get_multimodal_manager().model_supports(model_name, ModelCapability.TOOLS)
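The convenience wrappers above reduce to a capability lookup over a registry of known models. A minimal standalone sketch of the same selection logic — the `MODEL_CAPS` table and function names here are illustrative, not the module's real registry:

```python
from enum import Enum
from typing import Optional

class Capability(Enum):
    TEXT = "text"
    VISION = "vision"
    TOOLS = "tools"

# Hypothetical local capability registry; the real manager builds this
# from its model metadata.
MODEL_CAPS = {
    "llama3.1:8b-instruct": {Capability.TEXT, Capability.TOOLS},
    "llava:7b": {Capability.TEXT, Capability.VISION},
}

def best_model_for(cap: Capability, preferred: Optional[str] = None) -> Optional[str]:
    """Return the preferred model if it has the capability, else any model that does."""
    if preferred and cap in MODEL_CAPS.get(preferred, set()):
        return preferred
    for name, caps in MODEL_CAPS.items():
        if cap in caps:
            return name
    return None
```

A tool-calling request keeps the preferred model; a vision request on a text-only preferred model falls through to the first vision-capable entry.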
@@ -3,14 +3,18 @@
 Routes requests through an ordered list of LLM providers,
 automatically failing over on rate limits or errors.
 Tracks metrics for latency, errors, and cost.
+
+Now with multi-modal support — automatically selects vision-capable
+models for image inputs and falls back through capability chains.
 """

 import asyncio
+import base64
 import logging
 import time
 from dataclasses import dataclass, field
 from datetime import datetime, timedelta, timezone
 from enum import Enum
 from typing import Any, Optional

 from pathlib import Path
@@ -43,6 +48,14 @@ class CircuitState(Enum):
     HALF_OPEN = "half_open"  # Testing if recovered


+class ContentType(Enum):
+    """Type of content in the request."""
+
+    TEXT = "text"
+    VISION = "vision"  # Contains images
+    AUDIO = "audio"  # Contains audio
+    MULTIMODAL = "multimodal"  # Multiple content types
+
+
 @dataclass
 class ProviderMetrics:
     """Metrics for a single provider."""
@@ -67,6 +80,18 @@ class ProviderMetrics:
         return self.failed_requests / self.total_requests


+@dataclass
+class ModelCapability:
+    """Capabilities a model supports."""
+
+    name: str
+    supports_vision: bool = False
+    supports_audio: bool = False
+    supports_tools: bool = False
+    supports_json: bool = False
+    supports_streaming: bool = True
+    context_window: int = 4096
+
+
 @dataclass
 class Provider:
     """LLM provider configuration and state."""
@@ -94,6 +119,23 @@ class Provider:
         if self.models:
             return self.models[0]["name"]
         return None
+
+    def get_model_with_capability(self, capability: str) -> Optional[str]:
+        """Get a model that supports the given capability."""
+        for model in self.models:
+            capabilities = model.get("capabilities", [])
+            if capability in capabilities:
+                return model["name"]
+        # Fall back to the default model
+        return self.get_default_model()
+
+    def model_has_capability(self, model_name: str, capability: str) -> bool:
+        """Check if a specific model has a capability."""
+        for model in self.models:
+            if model["name"] == model_name:
+                capabilities = model.get("capabilities", [])
+                return capability in capabilities
+        return False


 @dataclass
@@ -107,19 +149,39 @@ class RouterConfig:
     circuit_breaker_half_open_max_calls: int = 2
     cost_tracking_enabled: bool = True
     budget_daily_usd: float = 10.0
+    # Multi-modal settings
+    auto_pull_models: bool = True
+    fallback_chains: dict = field(default_factory=dict)


 class CascadeRouter:
     """Routes LLM requests with automatic failover.

+    Now with multi-modal support:
+    - Automatically detects content type (text, vision, audio)
+    - Selects appropriate models based on capabilities
+    - Falls back through capability-specific model chains
+    - Supports image URLs and base64 encoding
+
     Usage:
         router = CascadeRouter()

+        # Text request
         response = await router.complete(
             messages=[{"role": "user", "content": "Hello"}],
             model="llama3.2"
         )

+        # Vision request (automatically detects and selects a vision model)
+        response = await router.complete(
+            messages=[{
+                "role": "user",
+                "content": "What's in this image?",
+                "images": ["path/to/image.jpg"]
+            }],
+            model="llava:7b"
+        )
+
         # Check metrics
         metrics = router.get_metrics()
    """
@@ -130,6 +192,14 @@ class CascadeRouter:
         self.config: RouterConfig = RouterConfig()
         self._load_config()

+        # Initialize the multi-modal manager if available
+        self._mm_manager: Optional[Any] = None
+        try:
+            from infrastructure.models.multimodal import get_multimodal_manager
+            self._mm_manager = get_multimodal_manager()
+        except Exception as exc:
+            logger.debug("Multi-modal manager not available: %s", exc)
+
         logger.info("CascadeRouter initialized with %d providers", len(self.providers))

     def _load_config(self) -> None:
@@ -149,6 +219,13 @@ class CascadeRouter:

         # Load cascade settings
         cascade = data.get("cascade", {})
+
+        # Load fallback chains
+        fallback_chains = data.get("fallback_chains", {})
+
+        # Load multi-modal settings
+        multimodal = data.get("multimodal", {})

         self.config = RouterConfig(
             timeout_seconds=cascade.get("timeout_seconds", 30),
             max_retries_per_provider=cascade.get("max_retries_per_provider", 2),
@@ -156,6 +233,8 @@ class CascadeRouter:
             circuit_breaker_failure_threshold=cascade.get("circuit_breaker", {}).get("failure_threshold", 5),
             circuit_breaker_recovery_timeout=cascade.get("circuit_breaker", {}).get("recovery_timeout", 60),
             circuit_breaker_half_open_max_calls=cascade.get("circuit_breaker", {}).get("half_open_max_calls", 2),
+            auto_pull_models=multimodal.get("auto_pull", True),
+            fallback_chains=fallback_chains,
         )

         # Load providers
@@ -226,6 +305,81 @@ class CascadeRouter:

         return True

+    def _detect_content_type(self, messages: list[dict]) -> ContentType:
+        """Detect the type of content in the messages.
+
+        Checks for images, audio, etc. in the message content.
+        """
+        has_image = False
+        has_audio = False
+
+        for msg in messages:
+            content = msg.get("content", "")
+
+            # Check for attached images
+            if msg.get("images"):
+                has_image = True
+
+            # Check for image URLs/paths or data URLs in string content
+            if isinstance(content, str):
+                image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp')
+                if any(ext in content.lower() for ext in image_extensions):
+                    has_image = True
+                if content.startswith("data:image/"):
+                    has_image = True
+
+            # Check for audio
+            if msg.get("audio"):
+                has_audio = True
+
+            # Check for a structured multi-modal content list
+            if isinstance(content, list):
+                for item in content:
+                    if isinstance(item, dict):
+                        if item.get("type") == "image_url":
+                            has_image = True
+                        elif item.get("type") == "audio":
+                            has_audio = True
+
+        if has_image and has_audio:
+            return ContentType.MULTIMODAL
+        elif has_image:
+            return ContentType.VISION
+        elif has_audio:
+            return ContentType.AUDIO
+        return ContentType.TEXT
+
+    def _get_fallback_model(
+        self,
+        provider: Provider,
+        original_model: str,
+        content_type: ContentType,
+    ) -> Optional[str]:
+        """Get a fallback model for the given content type."""
+        # Map content type to capability
+        capability_map = {
+            ContentType.VISION: "vision",
+            ContentType.AUDIO: "audio",
+            ContentType.MULTIMODAL: "vision",  # Vision models often handle both
+        }
+
+        capability = capability_map.get(content_type)
+        if not capability:
+            return None
+
+        # Check the provider's models for the capability
+        fallback_model = provider.get_model_with_capability(capability)
+        if fallback_model and fallback_model != original_model:
+            return fallback_model
+
+        # Use fallback chains from config
+        fallback_chain = self.config.fallback_chains.get(capability, [])
+        for model_name in fallback_chain:
+            if provider.model_has_capability(model_name, capability):
+                return model_name
+
+        return None
+
     async def complete(
         self,
         messages: list[dict],
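The detection heuristics above are easy to exercise standalone. A trimmed sketch of the same rules, simplified to the image cases only (function name and return values are illustrative):

```python
def detect_content_type(messages: list) -> str:
    """Return 'vision' if any message carries image data, else 'text'."""
    image_exts = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
    for msg in messages:
        # Explicit image attachments
        if msg.get("images"):
            return "vision"
        content = msg.get("content", "")
        if isinstance(content, str):
            # Data URLs or file-extension mentions in the text
            if content.startswith("data:image/") or any(
                ext in content.lower() for ext in image_exts
            ):
                return "vision"
        elif isinstance(content, list):
            # OpenAI-style structured content parts
            if any(isinstance(i, dict) and i.get("type") == "image_url" for i in content):
                return "vision"
    return "text"
```

Note the extension check is deliberately loose: a text message that merely mentions `photo.png` is classified as vision, which errs toward the more capable model.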
@@ -235,6 +389,11 @@ class CascadeRouter:
     ) -> dict:
         """Complete a chat conversation with automatic failover.

+        Multi-modal support:
+        - Automatically detects if messages contain images
+        - Falls back to vision-capable models when needed
+        - Supports image URLs, paths, and base64 encoding
+
         Args:
             messages: List of message dicts with role and content
             model: Preferred model (tries this first, then provider defaults)
@@ -247,6 +406,11 @@ class CascadeRouter:
         Raises:
             RuntimeError: If all providers fail
         """
+        # Detect content type for multi-modal routing
+        content_type = self._detect_content_type(messages)
+        if content_type != ContentType.TEXT:
+            logger.debug("Detected %s content, selecting appropriate model", content_type.value)
+
         errors = []

         for provider in self.providers:
@@ -266,15 +430,48 @@ class CascadeRouter:
                 logger.debug("Skipping %s (circuit open)", provider.name)
                 continue

+            # Determine which model to use
+            selected_model = model or provider.get_default_model()
+            is_fallback_model = False
+
+            # For non-text content, check whether the model supports it
+            if content_type != ContentType.TEXT and selected_model:
+                if provider.type == "ollama" and self._mm_manager:
+                    from infrastructure.models.multimodal import ModelCapability
+
+                    # Check if the selected model supports the required capability
+                    if content_type == ContentType.VISION:
+                        supports = self._mm_manager.model_supports(
+                            selected_model, ModelCapability.VISION
+                        )
+                        if not supports:
+                            # Find a fallback model
+                            fallback = self._get_fallback_model(
+                                provider, selected_model, content_type
+                            )
+                            if fallback:
+                                logger.info(
+                                    "Model %s doesn't support vision, falling back to %s",
+                                    selected_model, fallback
+                                )
+                                selected_model = fallback
+                                is_fallback_model = True
+                            else:
+                                logger.warning(
+                                    "No vision-capable model found on %s, trying anyway",
+                                    provider.name
+                                )
+
             # Try this provider
             for attempt in range(self.config.max_retries_per_provider):
                 try:
                     result = await self._try_provider(
                         provider=provider,
                         messages=messages,
-                        model=model,
+                        model=selected_model,
                         temperature=temperature,
                         max_tokens=max_tokens,
+                        content_type=content_type,
                     )

                     # Success! Update metrics and return
@@ -282,8 +479,9 @@ class CascadeRouter:
                     return {
                         "content": result["content"],
                         "provider": provider.name,
-                        "model": result.get("model", model or provider.get_default_model()),
+                        "model": result.get("model", selected_model or provider.get_default_model()),
                         "latency_ms": result.get("latency_ms", 0),
+                        "is_fallback_model": is_fallback_model,
                     }

                 except Exception as exc:
@@ -307,9 +505,10 @@ class CascadeRouter:
         self,
         provider: Provider,
         messages: list[dict],
-        model: Optional[str],
+        model: str,
         temperature: float,
         max_tokens: Optional[int],
+        content_type: ContentType = ContentType.TEXT,
     ) -> dict:
         """Try a single provider request."""
         start_time = time.time()
@@ -320,6 +519,7 @@ class CascadeRouter:
                 messages=messages,
                 model=model or provider.get_default_model(),
                 temperature=temperature,
+                content_type=content_type,
             )
         elif provider.type == "openai":
             result = await self._call_openai(
@@ -359,15 +559,19 @@ class CascadeRouter:
         messages: list[dict],
         model: str,
         temperature: float,
+        content_type: ContentType = ContentType.TEXT,
     ) -> dict:
-        """Call Ollama API."""
+        """Call Ollama API with multi-modal support."""
         import aiohttp

         url = f"{provider.url}/api/chat"

+        # Transform messages to Ollama format (including images)
+        transformed_messages = self._transform_messages_for_ollama(messages)
+
         payload = {
             "model": model,
-            "messages": messages,
+            "messages": transformed_messages,
             "stream": False,
             "options": {
                 "temperature": temperature,
@@ -388,6 +592,41 @@ class CascadeRouter:
             "model": model,
         }

+    def _transform_messages_for_ollama(self, messages: list[dict]) -> list[dict]:
+        """Transform messages to Ollama format, handling images."""
+        transformed = []
+
+        for msg in messages:
+            new_msg = {
+                "role": msg.get("role", "user"),
+                "content": msg.get("content", ""),
+            }
+
+            # Handle images
+            images = msg.get("images", [])
+            if images:
+                new_msg["images"] = []
+                for img in images:
+                    if isinstance(img, str):
+                        if img.startswith("data:image/"):
+                            # Base64-encoded data URL: keep only the payload
+                            new_msg["images"].append(img.split(",")[1])
+                        elif img.startswith("http://") or img.startswith("https://"):
+                            # URL - would need to be downloaded, skip for now
+                            logger.warning("Image URLs not yet supported, skipping: %s", img)
+                        elif Path(img).exists():
+                            # Local file path - read and base64-encode
+                            try:
+                                with open(img, "rb") as f:
+                                    img_data = base64.b64encode(f.read()).decode()
+                                new_msg["images"].append(img_data)
+                            except Exception as exc:
+                                logger.error("Failed to read image %s: %s", img, exc)
+
+            transformed.append(new_msg)
+
+        return transformed
+
     async def _call_openai(
         self,
         provider: Provider,
@@ -496,7 +735,7 @@ class CascadeRouter:
             "content": response.choices[0].message.content,
             "model": response.model,
         }

     def _record_success(self, provider: Provider, latency_ms: float) -> None:
         """Record a successful request."""
         provider.metrics.total_requests += 1
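The message transform above normalizes three image shapes for Ollama's `images` field: data URLs keep only the base64 payload after the comma, HTTP(S) URLs are skipped, and local files are read and base64-encoded. The per-image rule in isolation (a sketch, not the router's exact helper):

```python
import base64
from typing import Optional

def encode_image(img: str) -> Optional[str]:
    """Return the raw base64 payload Ollama expects, or None if unsupported."""
    if img.startswith("data:image/"):
        # Data URL: drop the "data:image/...;base64," prefix
        return img.split(",", 1)[1]
    if img.startswith(("http://", "https://")):
        # Remote URLs would need downloading first
        return None
    # Otherwise assume a local file path
    with open(img, "rb") as f:
        return base64.b64encode(f.read()).decode()
```

Ollama's chat API expects bare base64 strings, so the data-URL prefix must be stripped rather than passed through.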
@@ -598,6 +837,35 @@ class CascadeRouter:
                 for p in self.providers
             ],
         }

+    async def generate_with_image(
+        self,
+        prompt: str,
+        image_path: str,
+        model: Optional[str] = None,
+        temperature: float = 0.7,
+    ) -> dict:
+        """Convenience method for vision requests.
+
+        Args:
+            prompt: Text prompt about the image
+            image_path: Path to the image file
+            model: Vision-capable model (auto-selected if not provided)
+            temperature: Sampling temperature
+
+        Returns:
+            Response dict with content and metadata
+        """
+        messages = [{
+            "role": "user",
+            "content": prompt,
+            "images": [image_path],
+        }]
+        return await self.complete(
+            messages=messages,
+            model=model,
+            temperature=temperature,
+        )
+
+
 # Module-level singleton
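`generate_with_image` just wraps the prompt and path into the message shape that `complete()` expects; building that message is trivial to check on its own (helper name is illustrative):

```python
def vision_message(prompt: str, image_path: str) -> list:
    """Single-image vision request, in the message shape the router detects."""
    return [{"role": "user", "content": prompt, "images": [image_path]}]
```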
@@ -5,11 +5,16 @@ Memory Architecture:
 - Tier 2 (Vault): memory/ — structured markdown, append-only
 - Tier 3 (Semantic): Vector search over vault files

+Model Management:
+- Pulls the requested model automatically if not available
+- Falls back through capability-based model chains
+- Multi-modal support with vision model fallbacks
+
 Handoff Protocol maintains continuity across sessions.
 """

 import logging
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Optional, Union

 from agno.agent import Agent
 from agno.db.sqlite import SqliteDb
@@ -24,6 +29,23 @@ if TYPE_CHECKING:

 logger = logging.getLogger(__name__)

+# Fallback chain for text/tool models (in order of preference)
+DEFAULT_MODEL_FALLBACKS = [
+    "llama3.1:8b-instruct",
+    "llama3.1",
+    "qwen2.5:14b",
+    "qwen2.5:7b",
+    "llama3.2:3b",
+]
+
+# Fallback chain for vision models
+VISION_MODEL_FALLBACKS = [
+    "llama3.2:3b",
+    "llava:7b",
+    "qwen2.5-vl:3b",
+    "moondream:1.8b",
+]
+
 # Union type for callers that want to hint the return type.
 TimmyAgent = Union[Agent, "TimmyAirLLMAgent", "GrokBackend"]
@@ -40,6 +62,120 @@ _SMALL_MODEL_PATTERNS = (
 )


+def _check_model_available(model_name: str) -> bool:
+    """Check if an Ollama model is available locally."""
+    try:
+        import json
+        import urllib.request
+
+        url = settings.ollama_url.replace("localhost", "127.0.0.1")
+        req = urllib.request.Request(
+            f"{url}/api/tags",
+            method="GET",
+            headers={"Accept": "application/json"},
+        )
+        with urllib.request.urlopen(req, timeout=5) as response:
+            data = json.loads(response.read().decode())
+        models = [m.get("name", "") for m in data.get("models", [])]
+        # Check for an exact match or the model name without a tag
+        return any(
+            model_name == m or model_name == m.split(":")[0] or m.startswith(model_name)
+            for m in models
+        )
+    except Exception as exc:
+        logger.debug("Could not check model availability: %s", exc)
+        return False
+
+
+def _pull_model(model_name: str) -> bool:
+    """Attempt to pull a model from Ollama.
+
+    Returns:
+        True if successful or the model already exists
+    """
+    try:
+        import json
+        import urllib.request
+
+        logger.info("Pulling model: %s", model_name)
+
+        url = settings.ollama_url.replace("localhost", "127.0.0.1")
+        req = urllib.request.Request(
+            f"{url}/api/pull",
+            method="POST",
+            headers={"Content-Type": "application/json"},
+            data=json.dumps({"name": model_name, "stream": False}).encode(),
+        )
+
+        with urllib.request.urlopen(req, timeout=300) as response:
+            if response.status == 200:
+                logger.info("Successfully pulled model: %s", model_name)
+                return True
+            logger.error("Failed to pull %s: HTTP %s", model_name, response.status)
+            return False
+
+    except Exception as exc:
+        logger.error("Error pulling model %s: %s", model_name, exc)
+        return False
+
+
+def _resolve_model_with_fallback(
+    requested_model: Optional[str] = None,
+    require_vision: bool = False,
+    auto_pull: bool = True,
+) -> tuple[str, bool]:
+    """Resolve a model with automatic pulling and fallback.
+
+    Args:
+        requested_model: Preferred model to use
+        require_vision: Whether the model needs vision capabilities
+        auto_pull: Whether to attempt pulling missing models
+
+    Returns:
+        Tuple of (model_name, is_fallback)
+    """
+    model = requested_model or settings.ollama_model
+
+    # Check if the requested model is available
+    if _check_model_available(model):
+        logger.debug("Using available model: %s", model)
+        return model, False
+
+    # Try to pull the requested model
+    if auto_pull:
+        logger.info("Model %s not available locally, attempting to pull...", model)
+        if _pull_model(model):
+            return model, False
+        logger.warning("Failed to pull %s, checking fallbacks...", model)
+
+    # Use the appropriate fallback chain
+    fallback_chain = VISION_MODEL_FALLBACKS if require_vision else DEFAULT_MODEL_FALLBACKS
+
+    for fallback_model in fallback_chain:
+        if _check_model_available(fallback_model):
+            logger.warning(
+                "Using fallback model %s (requested: %s)",
+                fallback_model, model
+            )
+            return fallback_model, True
+
+        # Try to pull the fallback
+        if auto_pull and _pull_model(fallback_model):
+            logger.info(
+                "Pulled and using fallback model %s (requested: %s)",
+                fallback_model, model
+            )
+            return fallback_model, True
+
+    # Absolute last resort - return the requested model and hope for the best
+    logger.error(
+        "No models available in fallback chain. Requested: %s",
+        model
+    )
+    return model, False
+
+
 def _model_supports_tools(model_name: str) -> bool:
     """Check if the configured model can reliably handle tool calling.
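The resolution order above (requested model, then pull, then fallback chain) is easy to unit-test once the availability check is injected instead of hitting the Ollama API. A pure-logic sketch with a stubbed `available` set (names here are hypothetical, and the pull step is omitted):

```python
FALLBACKS = ["llama3.1:8b-instruct", "qwen2.5:14b", "llama3.2:3b"]

def resolve_model(requested, available, fallbacks=FALLBACKS):
    """Return (model, is_fallback): requested if available, else first available fallback."""
    if requested in available:
        return requested, False
    for candidate in fallbacks:
        if candidate in available:
            return candidate, True
    # Last resort: return the request unchanged and hope the backend can serve it
    return requested, False
```

Separating the decision from the HTTP calls keeps the fallback policy testable without a running Ollama instance.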
@@ -106,7 +242,16 @@ def create_timmy(
         return TimmyAirLLMAgent(model_size=size)

     # Default: Ollama via Agno.
-    model_name = settings.ollama_model
+    # Resolve the model with automatic pulling and fallback
+    model_name, is_fallback = _resolve_model_with_fallback(
+        requested_model=model,
+        require_vision=False,
+        auto_pull=True,
+    )
+
+    if is_fallback:
+        logger.info("Using fallback model %s (requested was unavailable)", model_name)
+
     use_tools = _model_supports_tools(model_name)

     # Conditionally include tools — small models get none
@@ -31,7 +31,7 @@ from timmy.agent_core.interface import (
     TimAgent,
     AgentEffect,
 )
-from timmy.agent import create_timmy
+from timmy.agent import create_timmy, _resolve_model_with_fallback


 class OllamaAgent(TimAgent):
@@ -53,18 +53,33 @@ class OllamaAgent(TimAgent):
         identity: AgentIdentity,
         model: Optional[str] = None,
         effect_log: Optional[str] = None,
+        require_vision: bool = False,
     ) -> None:
         """Initialize Ollama-based agent.

         Args:
             identity: Agent identity (persistent across sessions)
-            model: Ollama model to use (default from config)
+            model: Ollama model to use (auto-resolves with fallback)
             effect_log: Path to log agent effects (optional)
+            require_vision: Whether to select a vision-capable model
         """
         super().__init__(identity)

+        # Resolve the model with automatic pulling and fallback
+        resolved_model, is_fallback = _resolve_model_with_fallback(
+            requested_model=model,
+            require_vision=require_vision,
+            auto_pull=True,
+        )
+
+        if is_fallback:
+            import logging
+            logging.getLogger(__name__).info(
+                "OllamaAdapter using fallback model %s", resolved_model
+            )
+
         # Initialize underlying Ollama agent
-        self._timmy = create_timmy(model=model)
+        self._timmy = create_timmy(model=resolved_model)

         # Set capabilities based on what Ollama can do
         self._capabilities = {
|||||||
@@ -5,6 +5,8 @@ Uses the three-tier memory system and MCP tools.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from agno.agent import Agent
|
from agno.agent import Agent
|
||||||
@@ -17,8 +19,166 @@ from mcp.registry import tool_registry
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Dynamic context that gets built at startup
|
||||||
|
_timmy_context: dict[str, Any] = {
|
||||||
|
"git_log": "",
|
||||||
|
"agents": [],
|
||||||
|
"hands": [],
|
||||||
|
"memory": "",
|
||||||
|
}
|
||||||
|
|
||||||
TIMMY_ORCHESTRATOR_PROMPT = """You are Timmy, a sovereign AI orchestrator running locally on this Mac.
|
|
||||||
|
async def _load_hands_async() -> list[dict]:
|
||||||
|
"""Async helper to load hands."""
|
||||||
|
try:
|
||||||
|
from hands.registry import HandRegistry
|
||||||
|
reg = HandRegistry()
|
||||||
|
hands_dict = await reg.load_all()
|
||||||
|
return [
|
||||||
|
{"name": h.name, "schedule": h.schedule.cron if h.schedule else "manual", "enabled": h.enabled}
|
||||||
|
for h in hands_dict.values()
|
||||||
|
]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Could not load hands for context: %s", exc)
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
+def build_timmy_context_sync() -> dict[str, Any]:
+    """Build Timmy's self-awareness context at startup (synchronous version).
+
+    This function gathers:
+    - Recent git commits (last 20)
+    - Active sub-agents
+    - Hot memory from MEMORY.md
+
+    Note: Hands are loaded separately in async context.
+
+    Returns a dict that can be formatted into the system prompt.
+    """
+    global _timmy_context
+
+    ctx: dict[str, Any] = {
+        "timestamp": datetime.now(timezone.utc).isoformat(),
+        "repo_root": settings.repo_root,
+        "git_log": "",
+        "agents": [],
+        "hands": [],
+        "memory": "",
+    }
+
+    # 1. Get recent git commits
+    try:
+        from tools.git_tools import git_log
+        result = git_log(max_count=20)
+        if result.get("success"):
+            commits = result.get("commits", [])
+            ctx["git_log"] = "\n".join([
+                f"{c['short_sha']} {c['message'].split(chr(10))[0]}"
+                for c in commits[:20]
+            ])
+    except Exception as exc:
+        logger.warning("Could not load git log for context: %s", exc)
+        ctx["git_log"] = "(Git log unavailable)"
+
+    # 2. Get active sub-agents
+    try:
+        from swarm import registry as swarm_registry
+        conn = swarm_registry._get_conn()
+        rows = conn.execute(
+            "SELECT id, name, status, capabilities FROM agents ORDER BY name"
+        ).fetchall()
+        ctx["agents"] = [
+            {"id": r["id"], "name": r["name"], "status": r["status"], "capabilities": r["capabilities"]}
+            for r in rows
+        ]
+        conn.close()
+    except Exception as exc:
+        logger.warning("Could not load agents for context: %s", exc)
+        ctx["agents"] = []
+
+    # 3. Read hot memory
+    try:
+        memory_path = Path(settings.repo_root) / "MEMORY.md"
+        if memory_path.exists():
+            ctx["memory"] = memory_path.read_text()[:2000]  # First 2000 chars
+        else:
+            ctx["memory"] = "(MEMORY.md not found)"
+    except Exception as exc:
+        logger.warning("Could not load memory for context: %s", exc)
+        ctx["memory"] = "(Memory unavailable)"
+
+    _timmy_context.update(ctx)
+    logger.info("Timmy context built (sync): %d agents", len(ctx["agents"]))
+    return ctx
+
+
+async def build_timmy_context_async() -> dict[str, Any]:
+    """Build complete Timmy context including hands (async version)."""
+    ctx = build_timmy_context_sync()
+    ctx["hands"] = await _load_hands_async()
+    _timmy_context.update(ctx)
+    logger.info("Timmy context built (async): %d agents, %d hands", len(ctx["agents"]), len(ctx["hands"]))
+    return ctx
+
+
+# Keep old name for backwards compatibility
+build_timmy_context = build_timmy_context_sync
+
+
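The first-line extraction above uses `c['message'].split(chr(10))[0]` because, before Python 3.12, an f-string expression could not contain a backslash escape such as `"\n"`; `chr(10)` is the standard workaround. A minimal sketch (the sample commit data is made up):

```python
# Before Python 3.12, f-string expressions may not contain backslashes,
# so chr(10) stands in for "\n" when splitting inside the expression.
commits = [
    {"short_sha": "a1b2c3d", "message": "feat: add context builder\n\nLonger body text"},
    {"short_sha": "d4e5f6a", "message": "fix: handle missing MEMORY.md"},
]

oneline = "\n".join(
    f"{c['short_sha']} {c['message'].split(chr(10))[0]}"  # first line of each message
    for c in commits
)
```

On Python 3.12+ the same expression could simply be written with a literal `"\n"` inside the f-string.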
+def format_timmy_prompt(base_prompt: str, context: dict[str, Any]) -> str:
+    """Format the system prompt with dynamic context."""
+
+    # Format agents list
+    agents_list = "\n".join([
+        f"| {a['name']} | {a['capabilities'] or 'general'} | {a['status']} |"
+        for a in context.get("agents", [])
+    ]) or "(No agents registered yet)"
+
+    # Format hands list
+    hands_list = "\n".join([
+        f"| {h['name']} | {h['schedule']} | {'enabled' if h['enabled'] else 'disabled'} |"
+        for h in context.get("hands", [])
+    ]) or "(No hands configured)"
+
+    repo_root = context.get('repo_root', settings.repo_root)
+
+    context_block = f"""
+## Current System Context (as of {context.get('timestamp', datetime.now(timezone.utc).isoformat())})
+
+### Repository
+**Root:** `{repo_root}`
+
+### Recent Commits (last 20):
+```
+{context.get('git_log', '(unavailable)')}
+```
+
+### Active Sub-Agents:
+| Name | Capabilities | Status |
+|------|--------------|--------|
+{agents_list}
+
+### Hands (Scheduled Tasks):
+| Name | Schedule | Status |
+|------|----------|--------|
+{hands_list}
+
+### Hot Memory:
+{context.get('memory', '(unavailable)')[:1000]}
+"""
+
+    # Replace {REPO_ROOT} placeholder with actual path
+    base_prompt = base_prompt.replace("{REPO_ROOT}", repo_root)
+
+    # Insert context after the first line (You are Timmy...)
+    lines = base_prompt.split("\n")
+    if lines:
+        return lines[0] + "\n" + context_block + "\n" + "\n".join(lines[1:])
+    return base_prompt
+
+
+# Base prompt with anti-hallucination hard rules
+TIMMY_ORCHESTRATOR_PROMPT_BASE = """You are Timmy, a sovereign AI orchestrator running locally on this Mac.
 
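Two small idioms in `format_timmy_prompt` are worth noting: `str.replace` is used for the `{REPO_ROOT}` placeholder rather than `str.format`, since `str.format` would choke on any other literal braces in the prompt (JSON examples, table syntax, and so on), and the persona sentence is kept as the first line with the dynamic context spliced in after it. A sketch of the same assembly pattern, with illustrative names not taken from the codebase:

```python
# Hypothetical miniature of the prompt-assembly pattern used above.
BASE = "You are Timmy, a sovereign AI orchestrator.\nRepo: {REPO_ROOT}\n## Role\n..."


def assemble(base: str, context_block: str, repo_root: str) -> str:
    # str.replace, not str.format: str.format would fail on any other
    # literal braces that happen to appear in the prompt text.
    base = base.replace("{REPO_ROOT}", repo_root)
    # Keep the persona sentence first, then splice in the dynamic context.
    first, _, rest = base.partition("\n")
    return first + "\n" + context_block + "\n" + rest


prompt = assemble(BASE, "## Context\n(no agents yet)", "/srv/timmy")
```

Keeping the persona line first matters because some models weight the opening instruction heavily; the fresh context then follows immediately, before the static role description.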
 ## Your Role
 
@@ -62,6 +222,20 @@ You have three tiers of memory:
 
 Use `memory_search` when the user refers to past conversations.
 
+## Hard Rules — Non-Negotiable
+
+1. **NEVER fabricate tool output.** If you need data from a tool, call the tool and wait for the real result. Do not write what you think the result might be.
+
+2. **If a tool call returns an error, report the exact error message.** Do not retry with invented data.
+
+3. **If you do not know something about your own system, say:** "I don't have that information — let me check." Then use a tool. Do not guess.
+
+4. **Never say "I'll wait for the output" and then immediately provide fake output.** These are contradictory. Wait means wait — no output until the tool returns.
+
+5. **When corrected, use memory_write to save the correction immediately.**
+
+6. **Your source code lives at the repository root shown above.** When using git tools, you don't need to specify a path — they automatically run from {REPO_ROOT}.
+
 ## Principles
 
 1. **Sovereignty** — Everything local, no cloud
@@ -78,21 +252,31 @@ class TimmyOrchestrator(BaseAgent):
     """Main orchestrator agent that coordinates the swarm."""
 
     def __init__(self) -> None:
+        # Build initial context (sync) and format prompt
+        # Full context including hands will be loaded on first async call
+        context = build_timmy_context_sync()
+        formatted_prompt = format_timmy_prompt(TIMMY_ORCHESTRATOR_PROMPT_BASE, context)
+
         super().__init__(
             agent_id="timmy",
             name="Timmy",
             role="orchestrator",
-            system_prompt=TIMMY_ORCHESTRATOR_PROMPT,
-            tools=["web_search", "read_file", "write_file", "python", "memory_search"],
+            system_prompt=formatted_prompt,
+            tools=["web_search", "read_file", "write_file", "python", "memory_search", "memory_write"],
         )
 
         # Sub-agent registry
        self.sub_agents: dict[str, BaseAgent] = {}
 
+        # Session tracking for init behavior
+        self._session_initialized = False
+        self._session_context: dict[str, Any] = {}
+        self._context_fully_loaded = False
+
         # Connect to event bus
         self.connect_event_bus(event_bus)
 
-        logger.info("Timmy Orchestrator initialized")
+        logger.info("Timmy Orchestrator initialized with context-aware prompt")
 
     def register_sub_agent(self, agent: BaseAgent) -> None:
         """Register a sub-agent with the orchestrator."""
@@ -100,11 +284,102 @@ class TimmyOrchestrator(BaseAgent):
         agent.connect_event_bus(event_bus)
         logger.info("Registered sub-agent: %s", agent.name)
 
+    async def _session_init(self) -> None:
+        """Initialize session context on first user message.
+
+        Silently reads git log and AGENTS.md to ground self-description in real data.
+        This runs once per session before the first response.
+
+        The git log is prepended to Timmy's context so he can answer "what's new?"
+        from actual commit data rather than hallucinating.
+        """
+        if self._session_initialized:
+            return
+
+        logger.debug("Running session init...")
+
+        # Load full context including hands if not already done
+        if not self._context_fully_loaded:
+            await build_timmy_context_async()
+            self._context_fully_loaded = True
+
+        # Read recent git log (equivalent of `git log --oneline -15`) from repo root
+        try:
+            from tools.git_tools import git_log
+            git_result = git_log(max_count=15)
+            if git_result.get("success"):
+                commits = git_result.get("commits", [])
+                self._session_context["git_log_commits"] = commits
+                # Format as oneline for easy reading
+                self._session_context["git_log_oneline"] = "\n".join([
+                    f"{c['short_sha']} {c['message'].split(chr(10))[0]}"
+                    for c in commits
+                ])
+                logger.debug("Session init: loaded %d commits from git log", len(commits))
+            else:
+                self._session_context["git_log_oneline"] = "Git log unavailable"
+        except Exception as exc:
+            logger.warning("Session init: could not read git log: %s", exc)
+            self._session_context["git_log_oneline"] = "Git log unavailable"
+
+        # Read AGENTS.md for self-awareness
+        try:
+            agents_md_path = Path(settings.repo_root) / "AGENTS.md"
+            if agents_md_path.exists():
+                self._session_context["agents_md"] = agents_md_path.read_text()[:3000]
+        except Exception as exc:
+            logger.warning("Session init: could not read AGENTS.md: %s", exc)
+
+        # Read CHANGELOG for recent changes
+        try:
+            changelog_path = Path(settings.repo_root) / "docs" / "CHANGELOG_2026-02-26.md"
+            if changelog_path.exists():
+                self._session_context["changelog"] = changelog_path.read_text()[:2000]
+        except Exception:
+            pass  # Changelog is optional
+
+        # Build session-specific context block for the prompt
+        recent_changes = self._session_context.get("git_log_oneline", "")
+        if recent_changes and recent_changes != "Git log unavailable":
+            self._session_context["recent_changes_block"] = f"""
+## Recent Changes to Your Codebase (last 15 commits):
+```
+{recent_changes}
+```
+When asked "what's new?" or similar, refer to these commits for actual changes.
+"""
+        else:
+            self._session_context["recent_changes_block"] = ""
+
+        self._session_initialized = True
+        logger.debug("Session init complete")
+
+    def _get_enhanced_system_prompt(self) -> str:
+        """Get system prompt enhanced with session-specific context.
+
+        This prepends the recent git log to the system prompt so Timmy
+        can answer questions about what's new from real data.
+        """
+        base = self.system_prompt
+
+        # Add recent changes block if available
+        recent_changes = self._session_context.get("recent_changes_block", "")
+        if recent_changes:
+            # Insert after the first line
+            lines = base.split("\n")
+            if lines:
+                return lines[0] + "\n" + recent_changes + "\n" + "\n".join(lines[1:])
+
+        return base
+
     async def orchestrate(self, user_request: str) -> str:
         """Main entry point for user requests.
 
         Analyzes the request and either handles directly or delegates.
         """
+        # Run session init on first message (loads git log, etc.)
+        await self._session_init()
+
         # Quick classification
         request_lower = user_request.lower()
 
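`_session_init` guards re-entry with a plain boolean flag, which is fine as long as only one `orchestrate` call is in flight. If two requests could arrive concurrently, an `asyncio.Lock` with a double-check makes the once-only guarantee robust. A sketch of that pattern (the class and its names are illustrative, not from the codebase):

```python
import asyncio


class SessionAgent:
    """Illustrative once-per-session init guard, protected by a lock."""

    def __init__(self) -> None:
        self._session_initialized = False
        self._init_lock = asyncio.Lock()
        self.init_runs = 0  # counts how many times the expensive init ran

    async def _session_init(self) -> None:
        if self._session_initialized:  # fast path, no lock needed
            return
        async with self._init_lock:  # serialize concurrent first calls
            if self._session_initialized:  # re-check after acquiring the lock
                return
            await asyncio.sleep(0)  # stand-in for loading git log, AGENTS.md, ...
            self.init_runs += 1
            self._session_initialized = True

    async def handle(self, request: str) -> str:
        await self._session_init()
        return f"handled: {request}"


async def main() -> list[str]:
    agent = SessionAgent()
    # Ten concurrent requests; init must still run exactly once.
    results = await asyncio.gather(*(agent.handle(f"r{i}") for i in range(10)))
    assert agent.init_runs == 1
    return results


results = asyncio.run(main())
```

The second flag check inside the lock is the important part: every waiter that queued up behind the first caller re-checks and returns without repeating the work.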
@@ -171,7 +446,7 @@ def create_timmy_swarm() -> TimmyOrchestrator:
     from timmy.agents.echo import EchoAgent
     from timmy.agents.helm import HelmAgent
 
-    # Create orchestrator
+    # Create orchestrator (builds context automatically)
     timmy = TimmyOrchestrator()
 
     # Register sub-agents
@@ -182,3 +457,18 @@ def create_timmy_swarm() -> TimmyOrchestrator:
     timmy.register_sub_agent(HelmAgent())
 
     return timmy
+
+
+# Convenience functions for refreshing context (called by /api/timmy/refresh-context)
+def refresh_timmy_context_sync() -> dict[str, Any]:
+    """Refresh Timmy's context (sync version)."""
+    return build_timmy_context_sync()
+
+
+async def refresh_timmy_context_async() -> dict[str, Any]:
+    """Refresh Timmy's context including hands (async version)."""
+    return await build_timmy_context_async()
+
+
+# Keep old name for backwards compatibility
+refresh_timmy_context = refresh_timmy_context_sync
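The `refresh_timmy_context = refresh_timmy_context_sync` alias keeps old imports working, but gives callers no nudge to migrate. A thin wrapper that emits a `DeprecationWarning` is one possible alternative; here is a sketch with a stub standing in for the real context builder:

```python
import warnings


def refresh_timmy_context_sync() -> dict:
    """Stand-in stub for the real sync context refresher."""
    return {"refreshed": True}


def refresh_timmy_context() -> dict:
    """Deprecated alias; forwards to refresh_timmy_context_sync()."""
    warnings.warn(
        "refresh_timmy_context is deprecated; use refresh_timmy_context_sync",
        DeprecationWarning,
        stacklevel=2,  # point the warning at the caller, not this wrapper
    )
    return refresh_timmy_context_sync()


# Demonstrate that the alias still works and that the warning fires.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = refresh_timmy_context()
```

The bare assignment is the lighter-weight choice and is what the diff does; the warning variant trades a few lines for an explicit migration signal in callers' logs.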