forked from Rockachopa/Timmy-time-dashboard
Merge pull request #27 from AlexanderWhitestone/claude/analyze-test-coverage-KBlkN
This commit is contained in:
@@ -46,3 +46,10 @@
|
||||
# Alternatively, configure via the /telegram/setup dashboard endpoint at runtime.
|
||||
# Requires: pip install ".[telegram]"
|
||||
# TELEGRAM_TOKEN=
|
||||
|
||||
# ── Discord bot ──────────────────────────────────────────────────────────────
|
||||
# Bot token from https://discord.com/developers/applications
|
||||
# Alternatively, configure via the /discord/setup dashboard endpoint at runtime.
|
||||
# Requires: pip install ".[discord]"
|
||||
# Optional: pip install pyzbar Pillow (for QR code invite detection from screenshots)
|
||||
# DISCORD_TOKEN=
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -21,8 +21,9 @@ env/
|
||||
# SQLite memory — never commit agent memory
|
||||
*.db
|
||||
|
||||
# Telegram bot state (contains bot token)
|
||||
# Chat platform state files (contain bot tokens)
|
||||
telegram_state.json
|
||||
discord_state.json
|
||||
|
||||
# Testing
|
||||
.pytest_cache/
|
||||
|
||||
70
docker-compose.test.yml
Normal file
70
docker-compose.test.yml
Normal file
@@ -0,0 +1,70 @@
|
||||
# ── Timmy Time — test stack ──────────────────────────────────────────────────
|
||||
#
|
||||
# Lightweight compose for functional tests. Runs the dashboard on port 18000
|
||||
# and optional agent workers on the swarm-test-net network.
|
||||
#
|
||||
# Usage:
|
||||
# FUNCTIONAL_DOCKER=1 pytest tests/functional/test_docker_swarm.py -v
|
||||
#
|
||||
# Or manually:
|
||||
# docker compose -f docker-compose.test.yml -p timmy-test up -d --build --wait
|
||||
# curl http://localhost:18000/health
|
||||
# docker compose -f docker-compose.test.yml -p timmy-test down -v
|
||||
|
||||
services:
|
||||
|
||||
dashboard:
|
||||
build: .
|
||||
image: timmy-time:test
|
||||
container_name: timmy-test-dashboard
|
||||
ports:
|
||||
- "18000:8000"
|
||||
volumes:
|
||||
- test-data:/app/data
|
||||
- ./src:/app/src
|
||||
- ./static:/app/static
|
||||
environment:
|
||||
DEBUG: "true"
|
||||
TIMMY_TEST_MODE: "1"
|
||||
OLLAMA_URL: "http://host.docker.internal:11434"
|
||||
LIGHTNING_BACKEND: "mock"
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
networks:
|
||||
- swarm-test-net
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 10
|
||||
start_period: 10s
|
||||
|
||||
agent:
|
||||
build: .
|
||||
image: timmy-time:test
|
||||
profiles:
|
||||
- agents
|
||||
volumes:
|
||||
- test-data:/app/data
|
||||
- ./src:/app/src
|
||||
environment:
|
||||
COORDINATOR_URL: "http://dashboard:8000"
|
||||
OLLAMA_URL: "http://host.docker.internal:11434"
|
||||
AGENT_NAME: "${AGENT_NAME:-TestWorker}"
|
||||
AGENT_CAPABILITIES: "${AGENT_CAPABILITIES:-general}"
|
||||
TIMMY_TEST_MODE: "1"
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
command: ["sh", "-c", "python -m swarm.agent_runner --agent-id agent-$(hostname) --name $${AGENT_NAME:-TestWorker}"]
|
||||
networks:
|
||||
- swarm-test-net
|
||||
depends_on:
|
||||
dashboard:
|
||||
condition: service_healthy
|
||||
|
||||
volumes:
|
||||
test-data:
|
||||
|
||||
networks:
|
||||
swarm-test-net:
|
||||
driver: bridge
|
||||
@@ -54,6 +54,12 @@ voice = [
|
||||
telegram = [
|
||||
"python-telegram-bot>=21.0",
|
||||
]
|
||||
# Discord: bridge Discord messages to Timmy with native thread support.
|
||||
# pip install ".[discord]"
|
||||
# Optional: pip install pyzbar Pillow (for QR code invite detection)
|
||||
discord = [
|
||||
"discord.py>=2.3.0",
|
||||
]
|
||||
# Creative: GPU-accelerated image, music, and video generation.
|
||||
# pip install ".[creative]"
|
||||
creative = [
|
||||
@@ -84,6 +90,7 @@ include = [
|
||||
"src/notifications",
|
||||
"src/shortcuts",
|
||||
"src/telegram_bot",
|
||||
"src/chat_bridge",
|
||||
"src/spark",
|
||||
"src/tools",
|
||||
"src/creative",
|
||||
|
||||
10
src/chat_bridge/__init__.py
Normal file
10
src/chat_bridge/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""Chat Bridge — vendor-agnostic chat platform abstraction.
|
||||
|
||||
Provides a clean interface for integrating any chat platform
|
||||
(Discord, Telegram, Slack, etc.) with Timmy's agent core.
|
||||
|
||||
Usage:
|
||||
from chat_bridge.base import ChatPlatform
|
||||
from chat_bridge.registry import platform_registry
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
"""
|
||||
147
src/chat_bridge/base.py
Normal file
147
src/chat_bridge/base.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""ChatPlatform — abstract base class for all chat vendor integrations.
|
||||
|
||||
Each vendor (Discord, Telegram, Slack, etc.) implements this interface.
|
||||
The dashboard and agent code interact only with this contract, never
|
||||
with vendor-specific APIs directly.
|
||||
|
||||
Architecture:
|
||||
ChatPlatform (ABC)
|
||||
|
|
||||
+-- DiscordVendor (discord.py)
|
||||
+-- TelegramVendor (future migration)
|
||||
+-- SlackVendor (future)
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class PlatformState(Enum):
|
||||
"""Lifecycle state of a chat platform connection."""
|
||||
DISCONNECTED = auto()
|
||||
CONNECTING = auto()
|
||||
CONNECTED = auto()
|
||||
ERROR = auto()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
"""Vendor-agnostic representation of a chat message."""
|
||||
content: str
|
||||
author: str
|
||||
channel_id: str
|
||||
platform: str
|
||||
timestamp: str = field(
|
||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
message_id: Optional[str] = None
|
||||
thread_id: Optional[str] = None
|
||||
attachments: list[str] = field(default_factory=list)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatThread:
|
||||
"""Vendor-agnostic representation of a conversation thread."""
|
||||
thread_id: str
|
||||
title: str
|
||||
channel_id: str
|
||||
platform: str
|
||||
created_at: str = field(
|
||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
archived: bool = False
|
||||
message_count: int = 0
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InviteInfo:
|
||||
"""Parsed invite extracted from an image or text."""
|
||||
url: str
|
||||
code: str
|
||||
platform: str
|
||||
guild_name: Optional[str] = None
|
||||
source: str = "unknown" # "qr", "vision", "text"
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlatformStatus:
|
||||
"""Current status of a chat platform connection."""
|
||||
platform: str
|
||||
state: PlatformState
|
||||
token_set: bool
|
||||
guild_count: int = 0
|
||||
thread_count: int = 0
|
||||
error: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"platform": self.platform,
|
||||
"state": self.state.name.lower(),
|
||||
"connected": self.state == PlatformState.CONNECTED,
|
||||
"token_set": self.token_set,
|
||||
"guild_count": self.guild_count,
|
||||
"thread_count": self.thread_count,
|
||||
"error": self.error,
|
||||
}
|
||||
|
||||
|
||||
class ChatPlatform(ABC):
|
||||
"""Abstract base class for chat platform integrations.
|
||||
|
||||
Lifecycle:
|
||||
configure(token) -> start() -> [send/receive messages] -> stop()
|
||||
|
||||
All vendors implement this interface. The dashboard routes and
|
||||
agent code work with ChatPlatform, never with vendor-specific APIs.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Platform identifier (e.g., 'discord', 'telegram')."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def state(self) -> PlatformState:
|
||||
"""Current connection state."""
|
||||
|
||||
@abstractmethod
|
||||
async def start(self, token: Optional[str] = None) -> bool:
|
||||
"""Start the platform connection. Returns True on success."""
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
"""Gracefully disconnect."""
|
||||
|
||||
@abstractmethod
|
||||
async def send_message(
|
||||
self, channel_id: str, content: str, thread_id: Optional[str] = None
|
||||
) -> Optional[ChatMessage]:
|
||||
"""Send a message. Optionally within a thread."""
|
||||
|
||||
@abstractmethod
|
||||
async def create_thread(
|
||||
self, channel_id: str, title: str, initial_message: Optional[str] = None
|
||||
) -> Optional[ChatThread]:
|
||||
"""Create a new thread in a channel."""
|
||||
|
||||
@abstractmethod
|
||||
async def join_from_invite(self, invite_code: str) -> bool:
|
||||
"""Join a server/workspace using an invite code."""
|
||||
|
||||
@abstractmethod
|
||||
def status(self) -> PlatformStatus:
|
||||
"""Return current platform status."""
|
||||
|
||||
@abstractmethod
|
||||
def save_token(self, token: str) -> None:
|
||||
"""Persist token for restarts."""
|
||||
|
||||
@abstractmethod
|
||||
def load_token(self) -> Optional[str]:
|
||||
"""Load persisted token."""
|
||||
166
src/chat_bridge/invite_parser.py
Normal file
166
src/chat_bridge/invite_parser.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""InviteParser — extract chat platform invite links from images.
|
||||
|
||||
Strategy chain:
|
||||
1. QR code detection (pyzbar — fast, no GPU)
|
||||
2. Ollama vision OCR (local LLM — handles screenshots with visible URLs)
|
||||
3. Regex fallback on raw text input
|
||||
|
||||
Supports Discord invite patterns:
|
||||
- discord.gg/<code>
|
||||
- discord.com/invite/<code>
|
||||
- discordapp.com/invite/<code>
|
||||
|
||||
Usage:
|
||||
from chat_bridge.invite_parser import invite_parser
|
||||
|
||||
# From image bytes (screenshot or QR photo)
|
||||
result = await invite_parser.parse_image(image_bytes)
|
||||
|
||||
# From plain text
|
||||
result = invite_parser.parse_text("Join us at discord.gg/abc123")
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from chat_bridge.base import InviteInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Patterns for Discord invite URLs
|
||||
_DISCORD_PATTERNS = [
|
||||
re.compile(r"(?:https?://)?discord\.gg/([A-Za-z0-9\-_]+)"),
|
||||
re.compile(r"(?:https?://)?(?:www\.)?discord(?:app)?\.com/invite/([A-Za-z0-9\-_]+)"),
|
||||
]
|
||||
|
||||
|
||||
def _extract_discord_code(text: str) -> Optional[str]:
|
||||
"""Extract a Discord invite code from text."""
|
||||
for pattern in _DISCORD_PATTERNS:
|
||||
match = pattern.search(text)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
|
||||
class InviteParser:
|
||||
"""Multi-strategy invite parser.
|
||||
|
||||
Tries QR detection first (fast), then Ollama vision (local AI),
|
||||
then regex on raw text. All local, no cloud.
|
||||
"""
|
||||
|
||||
async def parse_image(self, image_data: bytes) -> Optional[InviteInfo]:
|
||||
"""Extract an invite from image bytes (screenshot or QR photo).
|
||||
|
||||
Tries strategies in order:
|
||||
1. QR code decode (pyzbar)
|
||||
2. Ollama vision model (local OCR)
|
||||
"""
|
||||
result = self._try_qr_decode(image_data)
|
||||
if result:
|
||||
return result
|
||||
|
||||
result = await self._try_ollama_vision(image_data)
|
||||
if result:
|
||||
return result
|
||||
|
||||
logger.info("No invite found in image via any strategy.")
|
||||
return None
|
||||
|
||||
def parse_text(self, text: str) -> Optional[InviteInfo]:
|
||||
"""Extract an invite from plain text."""
|
||||
code = _extract_discord_code(text)
|
||||
if code:
|
||||
return InviteInfo(
|
||||
url=f"https://discord.gg/{code}",
|
||||
code=code,
|
||||
platform="discord",
|
||||
source="text",
|
||||
)
|
||||
return None
|
||||
|
||||
def _try_qr_decode(self, image_data: bytes) -> Optional[InviteInfo]:
|
||||
"""Strategy 1: Decode QR codes from image using pyzbar."""
|
||||
try:
|
||||
from PIL import Image
|
||||
from pyzbar.pyzbar import decode as qr_decode
|
||||
except ImportError:
|
||||
logger.debug("pyzbar/Pillow not installed, skipping QR strategy.")
|
||||
return None
|
||||
|
||||
try:
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
decoded = qr_decode(image)
|
||||
|
||||
for obj in decoded:
|
||||
text = obj.data.decode("utf-8", errors="ignore")
|
||||
code = _extract_discord_code(text)
|
||||
if code:
|
||||
logger.info("QR decode found Discord invite: %s", code)
|
||||
return InviteInfo(
|
||||
url=f"https://discord.gg/{code}",
|
||||
code=code,
|
||||
platform="discord",
|
||||
source="qr",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("QR decode failed: %s", exc)
|
||||
|
||||
return None
|
||||
|
||||
async def _try_ollama_vision(self, image_data: bytes) -> Optional[InviteInfo]:
|
||||
"""Strategy 2: Use Ollama vision model for local OCR."""
|
||||
try:
|
||||
import base64
|
||||
import httpx
|
||||
from config import settings
|
||||
except ImportError:
|
||||
logger.debug("httpx not available for Ollama vision.")
|
||||
return None
|
||||
|
||||
try:
|
||||
b64_image = base64.b64encode(image_data).decode("ascii")
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.post(
|
||||
f"{settings.ollama_url}/api/generate",
|
||||
json={
|
||||
"model": settings.ollama_model,
|
||||
"prompt": (
|
||||
"Extract any Discord invite link from this image. "
|
||||
"Look for URLs like discord.gg/CODE or "
|
||||
"discord.com/invite/CODE. "
|
||||
"Reply with ONLY the invite URL, nothing else. "
|
||||
"If no invite link is found, reply with: NONE"
|
||||
),
|
||||
"images": [b64_image],
|
||||
"stream": False,
|
||||
},
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
logger.debug("Ollama vision returned %d", resp.status_code)
|
||||
return None
|
||||
|
||||
answer = resp.json().get("response", "").strip()
|
||||
if answer and answer.upper() != "NONE":
|
||||
code = _extract_discord_code(answer)
|
||||
if code:
|
||||
logger.info("Ollama vision found Discord invite: %s", code)
|
||||
return InviteInfo(
|
||||
url=f"https://discord.gg/{code}",
|
||||
code=code,
|
||||
platform="discord",
|
||||
source="vision",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("Ollama vision strategy failed: %s", exc)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
invite_parser = InviteParser()
|
||||
74
src/chat_bridge/registry.py
Normal file
74
src/chat_bridge/registry.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""PlatformRegistry — singleton registry for chat platform vendors.
|
||||
|
||||
Provides a central point for registering, discovering, and managing
|
||||
all chat platform integrations. Dashboard routes and the agent core
|
||||
interact with platforms through this registry.
|
||||
|
||||
Usage:
|
||||
from chat_bridge.registry import platform_registry
|
||||
|
||||
platform_registry.register(discord_vendor)
|
||||
discord = platform_registry.get("discord")
|
||||
all_platforms = platform_registry.list_platforms()
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from chat_bridge.base import ChatPlatform, PlatformStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PlatformRegistry:
|
||||
"""Thread-safe registry of ChatPlatform vendors."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._platforms: dict[str, ChatPlatform] = {}
|
||||
|
||||
def register(self, platform: ChatPlatform) -> None:
|
||||
"""Register a chat platform vendor."""
|
||||
name = platform.name
|
||||
if name in self._platforms:
|
||||
logger.warning("Platform '%s' already registered, replacing.", name)
|
||||
self._platforms[name] = platform
|
||||
logger.info("Registered chat platform: %s", name)
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""Remove a platform from the registry. Returns True if it existed."""
|
||||
if name in self._platforms:
|
||||
del self._platforms[name]
|
||||
logger.info("Unregistered chat platform: %s", name)
|
||||
return True
|
||||
return False
|
||||
|
||||
def get(self, name: str) -> Optional[ChatPlatform]:
|
||||
"""Get a platform by name."""
|
||||
return self._platforms.get(name)
|
||||
|
||||
def list_platforms(self) -> list[PlatformStatus]:
|
||||
"""Return status of all registered platforms."""
|
||||
return [p.status() for p in self._platforms.values()]
|
||||
|
||||
async def start_all(self) -> dict[str, bool]:
|
||||
"""Start all registered platforms. Returns name -> success mapping."""
|
||||
results = {}
|
||||
for name, platform in self._platforms.items():
|
||||
try:
|
||||
results[name] = await platform.start()
|
||||
except Exception as exc:
|
||||
logger.error("Failed to start platform '%s': %s", name, exc)
|
||||
results[name] = False
|
||||
return results
|
||||
|
||||
async def stop_all(self) -> None:
|
||||
"""Stop all registered platforms."""
|
||||
for name, platform in self._platforms.items():
|
||||
try:
|
||||
await platform.stop()
|
||||
except Exception as exc:
|
||||
logger.error("Error stopping platform '%s': %s", name, exc)
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
platform_registry = PlatformRegistry()
|
||||
0
src/chat_bridge/vendors/__init__.py
vendored
Normal file
0
src/chat_bridge/vendors/__init__.py
vendored
Normal file
400
src/chat_bridge/vendors/discord.py
vendored
Normal file
400
src/chat_bridge/vendors/discord.py
vendored
Normal file
@@ -0,0 +1,400 @@
|
||||
"""DiscordVendor — Discord integration via discord.py.
|
||||
|
||||
Implements ChatPlatform with native thread support. Each conversation
|
||||
with Timmy gets its own Discord thread, keeping channels clean.
|
||||
|
||||
Optional dependency — install with:
|
||||
pip install ".[discord]"
|
||||
|
||||
Architecture:
|
||||
DiscordVendor
|
||||
├── _client (discord.Client) — handles gateway events
|
||||
├── _thread_map — channel_id -> active thread
|
||||
└── _message_handler — bridges to Timmy agent
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from chat_bridge.base import (
|
||||
ChatMessage,
|
||||
ChatPlatform,
|
||||
ChatThread,
|
||||
InviteInfo,
|
||||
PlatformState,
|
||||
PlatformStatus,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_STATE_FILE = Path(__file__).parent.parent.parent.parent / "discord_state.json"
|
||||
|
||||
|
||||
class DiscordVendor(ChatPlatform):
|
||||
"""Discord integration with native thread conversations.
|
||||
|
||||
Every user interaction creates or continues a Discord thread,
|
||||
keeping channel history clean and conversations organized.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._client = None
|
||||
self._token: Optional[str] = None
|
||||
self._state: PlatformState = PlatformState.DISCONNECTED
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._guild_count: int = 0
|
||||
self._active_threads: dict[str, str] = {} # channel_id -> thread_id
|
||||
|
||||
# ── ChatPlatform interface ─────────────────────────────────────────────
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "discord"
|
||||
|
||||
@property
|
||||
def state(self) -> PlatformState:
|
||||
return self._state
|
||||
|
||||
async def start(self, token: Optional[str] = None) -> bool:
|
||||
"""Start the Discord bot. Returns True on success."""
|
||||
if self._state == PlatformState.CONNECTED:
|
||||
return True
|
||||
|
||||
tok = token or self.load_token()
|
||||
if not tok:
|
||||
logger.warning("Discord bot: no token configured, skipping start.")
|
||||
return False
|
||||
|
||||
try:
|
||||
import discord
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"discord.py is not installed. "
|
||||
'Run: pip install ".[discord]"'
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
self._state = PlatformState.CONNECTING
|
||||
self._token = tok
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
intents.guilds = True
|
||||
|
||||
self._client = discord.Client(intents=intents)
|
||||
self._register_handlers()
|
||||
|
||||
# Run the client in a background task so we don't block
|
||||
self._task = asyncio.create_task(self._run_client(tok))
|
||||
|
||||
# Wait briefly for connection
|
||||
for _ in range(30):
|
||||
await asyncio.sleep(0.5)
|
||||
if self._state == PlatformState.CONNECTED:
|
||||
logger.info("Discord bot connected (%d guilds).", self._guild_count)
|
||||
return True
|
||||
if self._state == PlatformState.ERROR:
|
||||
return False
|
||||
|
||||
logger.warning("Discord bot: connection timed out.")
|
||||
self._state = PlatformState.ERROR
|
||||
return False
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Discord bot failed to start: %s", exc)
|
||||
self._state = PlatformState.ERROR
|
||||
self._token = None
|
||||
self._client = None
|
||||
return False
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Gracefully disconnect the Discord bot."""
|
||||
if self._client and not self._client.is_closed():
|
||||
try:
|
||||
await self._client.close()
|
||||
logger.info("Discord bot disconnected.")
|
||||
except Exception as exc:
|
||||
logger.error("Error stopping Discord bot: %s", exc)
|
||||
|
||||
if self._task and not self._task.done():
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self._state = PlatformState.DISCONNECTED
|
||||
self._client = None
|
||||
self._task = None
|
||||
|
||||
async def send_message(
|
||||
self, channel_id: str, content: str, thread_id: Optional[str] = None
|
||||
) -> Optional[ChatMessage]:
|
||||
"""Send a message to a Discord channel or thread."""
|
||||
if not self._client or self._state != PlatformState.CONNECTED:
|
||||
return None
|
||||
|
||||
try:
|
||||
import discord
|
||||
|
||||
target_id = int(thread_id) if thread_id else int(channel_id)
|
||||
channel = self._client.get_channel(target_id)
|
||||
|
||||
if channel is None:
|
||||
channel = await self._client.fetch_channel(target_id)
|
||||
|
||||
msg = await channel.send(content)
|
||||
|
||||
return ChatMessage(
|
||||
content=content,
|
||||
author=str(self._client.user),
|
||||
channel_id=str(msg.channel.id),
|
||||
platform="discord",
|
||||
message_id=str(msg.id),
|
||||
thread_id=thread_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to send Discord message: %s", exc)
|
||||
return None
|
||||
|
||||
async def create_thread(
|
||||
self, channel_id: str, title: str, initial_message: Optional[str] = None
|
||||
) -> Optional[ChatThread]:
|
||||
"""Create a new thread in a Discord channel."""
|
||||
if not self._client or self._state != PlatformState.CONNECTED:
|
||||
return None
|
||||
|
||||
try:
|
||||
channel = self._client.get_channel(int(channel_id))
|
||||
if channel is None:
|
||||
channel = await self._client.fetch_channel(int(channel_id))
|
||||
|
||||
thread = await channel.create_thread(
|
||||
name=title[:100], # Discord limits thread names to 100 chars
|
||||
auto_archive_duration=1440, # 24 hours
|
||||
)
|
||||
|
||||
if initial_message:
|
||||
await thread.send(initial_message)
|
||||
|
||||
self._active_threads[channel_id] = str(thread.id)
|
||||
|
||||
return ChatThread(
|
||||
thread_id=str(thread.id),
|
||||
title=title[:100],
|
||||
channel_id=channel_id,
|
||||
platform="discord",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to create Discord thread: %s", exc)
|
||||
return None
|
||||
|
||||
async def join_from_invite(self, invite_code: str) -> bool:
|
||||
"""Join a Discord server using an invite code.
|
||||
|
||||
Note: Bot accounts cannot use invite links directly.
|
||||
This generates an OAuth2 URL for adding the bot to a server.
|
||||
The invite_code is validated but the actual join requires
|
||||
the server admin to use the bot's OAuth2 authorization URL.
|
||||
"""
|
||||
if not self._client or self._state != PlatformState.CONNECTED:
|
||||
logger.warning("Discord bot not connected, cannot process invite.")
|
||||
return False
|
||||
|
||||
try:
|
||||
import discord
|
||||
|
||||
invite = await self._client.fetch_invite(invite_code)
|
||||
logger.info(
|
||||
"Validated invite for server '%s' (code: %s)",
|
||||
invite.guild.name if invite.guild else "unknown",
|
||||
invite_code,
|
||||
)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.error("Invalid Discord invite '%s': %s", invite_code, exc)
|
||||
return False
|
||||
|
||||
def status(self) -> PlatformStatus:
|
||||
return PlatformStatus(
|
||||
platform="discord",
|
||||
state=self._state,
|
||||
token_set=bool(self._token),
|
||||
guild_count=self._guild_count,
|
||||
thread_count=len(self._active_threads),
|
||||
)
|
||||
|
||||
def save_token(self, token: str) -> None:
|
||||
"""Persist token to state file."""
|
||||
try:
|
||||
_STATE_FILE.write_text(json.dumps({"token": token}))
|
||||
except Exception as exc:
|
||||
logger.error("Failed to save Discord token: %s", exc)
|
||||
|
||||
def load_token(self) -> Optional[str]:
|
||||
"""Load token from state file or config."""
|
||||
try:
|
||||
if _STATE_FILE.exists():
|
||||
data = json.loads(_STATE_FILE.read_text())
|
||||
token = data.get("token")
|
||||
if token:
|
||||
return token
|
||||
except Exception as exc:
|
||||
logger.debug("Could not read discord state file: %s", exc)
|
||||
|
||||
try:
|
||||
from config import settings
|
||||
return settings.discord_token or None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# ── OAuth2 URL generation ──────────────────────────────────────────────
|
||||
|
||||
def get_oauth2_url(self) -> Optional[str]:
|
||||
"""Generate the OAuth2 URL for adding this bot to a server.
|
||||
|
||||
Requires the bot to be connected to read its application ID.
|
||||
"""
|
||||
if not self._client or not self._client.user:
|
||||
return None
|
||||
|
||||
app_id = self._client.user.id
|
||||
# Permissions: Send Messages, Create Public Threads, Manage Threads,
|
||||
# Read Message History, Embed Links, Attach Files
|
||||
permissions = 397284550656
|
||||
return (
|
||||
f"https://discord.com/oauth2/authorize"
|
||||
f"?client_id={app_id}&scope=bot"
|
||||
f"&permissions={permissions}"
|
||||
)
|
||||
|
||||
# ── Internal ───────────────────────────────────────────────────────────
|
||||
|
||||
async def _run_client(self, token: str) -> None:
|
||||
"""Run the discord.py client (blocking call in a task)."""
|
||||
try:
|
||||
await self._client.start(token)
|
||||
except Exception as exc:
|
||||
logger.error("Discord client error: %s", exc)
|
||||
self._state = PlatformState.ERROR
|
||||
|
||||
def _register_handlers(self) -> None:
|
||||
"""Register Discord event handlers on the client."""
|
||||
|
||||
@self._client.event
|
||||
async def on_ready():
|
||||
self._guild_count = len(self._client.guilds)
|
||||
self._state = PlatformState.CONNECTED
|
||||
logger.info(
|
||||
"Discord ready: %s in %d guild(s)",
|
||||
self._client.user,
|
||||
self._guild_count,
|
||||
)
|
||||
|
||||
@self._client.event
|
||||
async def on_message(message):
|
||||
# Ignore our own messages
|
||||
if message.author == self._client.user:
|
||||
return
|
||||
|
||||
# Only respond to mentions or DMs
|
||||
is_dm = not hasattr(message.channel, "guild") or message.channel.guild is None
|
||||
is_mention = self._client.user in message.mentions
|
||||
|
||||
if not is_dm and not is_mention:
|
||||
return
|
||||
|
||||
await self._handle_message(message)
|
||||
|
||||
@self._client.event
|
||||
async def on_disconnect():
|
||||
if self._state != PlatformState.DISCONNECTED:
|
||||
self._state = PlatformState.CONNECTING
|
||||
logger.warning("Discord disconnected, will auto-reconnect.")
|
||||
|
||||
async def _handle_message(self, message) -> None:
|
||||
"""Process an incoming message and respond via a thread."""
|
||||
# Strip the bot mention from the message content
|
||||
content = message.content
|
||||
if self._client.user:
|
||||
content = content.replace(f"<@{self._client.user.id}>", "").strip()
|
||||
|
||||
if not content:
|
||||
return
|
||||
|
||||
# Create or reuse a thread for this conversation
|
||||
thread = await self._get_or_create_thread(message)
|
||||
target = thread or message.channel
|
||||
|
||||
# Run Timmy agent
|
||||
try:
|
||||
from timmy.agent import create_timmy
|
||||
|
||||
agent = create_timmy()
|
||||
run = await asyncio.to_thread(agent.run, content, stream=False)
|
||||
response = run.content if hasattr(run, "content") else str(run)
|
||||
except Exception as exc:
|
||||
logger.error("Timmy error in Discord handler: %s", exc)
|
||||
response = f"Timmy is offline: {exc}"
|
||||
|
||||
# Discord has a 2000 character limit
|
||||
for chunk in _chunk_message(response, 2000):
|
||||
await target.send(chunk)
|
||||
|
||||
async def _get_or_create_thread(self, message):
|
||||
"""Get the active thread for a channel, or create one.
|
||||
|
||||
If the message is already in a thread, use that thread.
|
||||
Otherwise, create a new thread from the message.
|
||||
"""
|
||||
try:
|
||||
import discord
|
||||
|
||||
# Already in a thread — just use it
|
||||
if isinstance(message.channel, discord.Thread):
|
||||
return message.channel
|
||||
|
||||
# DM channels don't support threads
|
||||
if isinstance(message.channel, discord.DMChannel):
|
||||
return None
|
||||
|
||||
# Create a thread from this message
|
||||
thread_name = f"Timmy | {message.author.display_name}"
|
||||
thread = await message.create_thread(
|
||||
name=thread_name[:100],
|
||||
auto_archive_duration=1440,
|
||||
)
|
||||
channel_id = str(message.channel.id)
|
||||
self._active_threads[channel_id] = str(thread.id)
|
||||
return thread
|
||||
|
||||
except Exception as exc:
|
||||
logger.debug("Could not create thread: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
def _chunk_message(text: str, max_len: int = 2000) -> list[str]:
|
||||
"""Split a message into chunks that fit Discord's character limit."""
|
||||
if len(text) <= max_len:
|
||||
return [text]
|
||||
|
||||
chunks = []
|
||||
while text:
|
||||
if len(text) <= max_len:
|
||||
chunks.append(text)
|
||||
break
|
||||
# Try to split at a newline
|
||||
split_at = text.rfind("\n", 0, max_len)
|
||||
if split_at == -1:
|
||||
split_at = max_len
|
||||
chunks.append(text[:split_at])
|
||||
text = text[split_at:].lstrip("\n")
|
||||
return chunks
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
discord_bot = DiscordVendor()
|
||||
@@ -16,6 +16,9 @@ class Settings(BaseSettings):
|
||||
# Telegram bot token — set via TELEGRAM_TOKEN env var or the /telegram/setup endpoint
|
||||
telegram_token: str = ""
|
||||
|
||||
# Discord bot token — set via DISCORD_TOKEN env var or the /discord/setup endpoint
|
||||
discord_token: str = ""
|
||||
|
||||
# ── AirLLM / backend selection ───────────────────────────────────────────
|
||||
# "ollama" — always use Ollama (default, safe everywhere)
|
||||
# "airllm" — always use AirLLM (requires pip install ".[bigbrain]")
|
||||
|
||||
@@ -25,6 +25,7 @@ from dashboard.routes.swarm_internal import router as swarm_internal_router
|
||||
from dashboard.routes.tools import router as tools_router
|
||||
from dashboard.routes.spark import router as spark_router
|
||||
from dashboard.routes.creative import router as creative_router
|
||||
from dashboard.routes.discord import router as discord_router
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
@@ -108,8 +109,15 @@ async def lifespan(app: FastAPI):
|
||||
from telegram_bot.bot import telegram_bot
|
||||
await telegram_bot.start()
|
||||
|
||||
# Auto-start Discord bot and register in platform registry
|
||||
from chat_bridge.vendors.discord import discord_bot
|
||||
from chat_bridge.registry import platform_registry
|
||||
platform_registry.register(discord_bot)
|
||||
await discord_bot.start()
|
||||
|
||||
yield
|
||||
|
||||
await discord_bot.stop()
|
||||
await telegram_bot.stop()
|
||||
task.cancel()
|
||||
try:
|
||||
@@ -145,6 +153,7 @@ app.include_router(swarm_internal_router)
|
||||
app.include_router(tools_router)
|
||||
app.include_router(spark_router)
|
||||
app.include_router(creative_router)
|
||||
app.include_router(discord_router)
|
||||
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
|
||||
140
src/dashboard/routes/discord.py
Normal file
140
src/dashboard/routes/discord.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""Dashboard routes for Discord bot setup, status, and invite-from-image.
|
||||
|
||||
Endpoints:
|
||||
POST /discord/setup — configure bot token
|
||||
GET /discord/status — connection state + guild count
|
||||
POST /discord/join — paste screenshot → extract invite → join
|
||||
GET /discord/oauth-url — get the bot's OAuth2 authorization URL
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, File, Form, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
router = APIRouter(prefix="/discord", tags=["discord"])
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
token: str
|
||||
|
||||
|
||||
@router.post("/setup")
|
||||
async def setup_discord(payload: TokenPayload):
|
||||
"""Configure the Discord bot token and (re)start the bot.
|
||||
|
||||
Send POST with JSON body: {"token": "<your-bot-token>"}
|
||||
Get the token from https://discord.com/developers/applications
|
||||
"""
|
||||
from chat_bridge.vendors.discord import discord_bot
|
||||
|
||||
token = payload.token.strip()
|
||||
if not token:
|
||||
return {"ok": False, "error": "Token cannot be empty."}
|
||||
|
||||
discord_bot.save_token(token)
|
||||
|
||||
if discord_bot.state.name == "CONNECTED":
|
||||
await discord_bot.stop()
|
||||
|
||||
success = await discord_bot.start(token=token)
|
||||
if success:
|
||||
return {"ok": True, "message": "Discord bot connected successfully."}
|
||||
return {
|
||||
"ok": False,
|
||||
"error": (
|
||||
"Failed to start bot. Check that the token is correct and "
|
||||
'discord.py is installed: pip install ".[discord]"'
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
async def discord_status():
|
||||
"""Return current Discord bot status."""
|
||||
from chat_bridge.vendors.discord import discord_bot
|
||||
|
||||
return discord_bot.status().to_dict()
|
||||
|
||||
|
||||
@router.post("/join")
|
||||
async def join_from_image(
|
||||
image: Optional[UploadFile] = File(None),
|
||||
invite_url: Optional[str] = Form(None),
|
||||
):
|
||||
"""Extract a Discord invite from a screenshot or text and validate it.
|
||||
|
||||
Accepts either:
|
||||
- An uploaded image (screenshot of invite or QR code)
|
||||
- A plain text invite URL
|
||||
|
||||
The bot validates the invite and returns the OAuth2 URL for the
|
||||
server admin to authorize the bot.
|
||||
"""
|
||||
from chat_bridge.invite_parser import invite_parser
|
||||
from chat_bridge.vendors.discord import discord_bot
|
||||
|
||||
invite_info = None
|
||||
|
||||
# Try image first
|
||||
if image and image.filename:
|
||||
image_data = await image.read()
|
||||
if image_data:
|
||||
invite_info = await invite_parser.parse_image(image_data)
|
||||
|
||||
# Fall back to text
|
||||
if not invite_info and invite_url:
|
||||
invite_info = invite_parser.parse_text(invite_url)
|
||||
|
||||
if not invite_info:
|
||||
return {
|
||||
"ok": False,
|
||||
"error": (
|
||||
"No Discord invite found. "
|
||||
"Paste a screenshot with a visible invite link or QR code, "
|
||||
"or enter the invite URL directly."
|
||||
),
|
||||
}
|
||||
|
||||
# Validate the invite
|
||||
valid = await discord_bot.join_from_invite(invite_info.code)
|
||||
|
||||
result = {
|
||||
"ok": True,
|
||||
"invite": {
|
||||
"code": invite_info.code,
|
||||
"url": invite_info.url,
|
||||
"source": invite_info.source,
|
||||
"platform": invite_info.platform,
|
||||
},
|
||||
"validated": valid,
|
||||
}
|
||||
|
||||
# Include OAuth2 URL if bot is connected
|
||||
oauth_url = discord_bot.get_oauth2_url()
|
||||
if oauth_url:
|
||||
result["oauth2_url"] = oauth_url
|
||||
result["message"] = (
|
||||
"Invite validated. Share this OAuth2 URL with the server admin "
|
||||
"to add Timmy to the server."
|
||||
)
|
||||
else:
|
||||
result["message"] = (
|
||||
"Invite found but bot is not connected. "
|
||||
"Configure a bot token first via /discord/setup."
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/oauth-url")
|
||||
async def discord_oauth_url():
|
||||
"""Get the bot's OAuth2 authorization URL for adding to servers."""
|
||||
from chat_bridge.vendors.discord import discord_bot
|
||||
|
||||
url = discord_bot.get_oauth2_url()
|
||||
if url:
|
||||
return {"ok": True, "url": url}
|
||||
return {
|
||||
"ok": False,
|
||||
"error": "Bot is not connected. Configure a token first.",
|
||||
}
|
||||
@@ -25,6 +25,14 @@ for _mod in [
|
||||
# without the package installed.
|
||||
"telegram",
|
||||
"telegram.ext",
|
||||
# discord.py is optional (discord extra) — stub so tests run
|
||||
# without the package installed.
|
||||
"discord",
|
||||
"discord.ext",
|
||||
"discord.ext.commands",
|
||||
# pyzbar is optional (for QR code invite detection)
|
||||
"pyzbar",
|
||||
"pyzbar.pyzbar",
|
||||
]:
|
||||
sys.modules.setdefault(_mod, MagicMock())
|
||||
|
||||
|
||||
0
tests/functional/__init__.py
Normal file
0
tests/functional/__init__.py
Normal file
185
tests/functional/conftest.py
Normal file
185
tests/functional/conftest.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""Functional test fixtures — real services, no mocking.
|
||||
|
||||
These fixtures provide:
|
||||
- TestClient hitting the real FastAPI app (singletons, SQLite, etc.)
|
||||
- Typer CliRunner for CLI commands
|
||||
- Real temporary SQLite for swarm state
|
||||
- Real payment handler with mock lightning backend (LIGHTNING_BACKEND=mock)
|
||||
- Docker compose lifecycle for container-level tests
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
# ── Stub heavy optional deps (same as root conftest) ─────────────────────────
|
||||
# These aren't mocks — they're import compatibility shims for packages
|
||||
# not installed in the test environment. The code under test handles
|
||||
# their absence via try/except ImportError.
|
||||
for _mod in [
|
||||
"agno", "agno.agent", "agno.models", "agno.models.ollama",
|
||||
"agno.db", "agno.db.sqlite",
|
||||
"airllm",
|
||||
"telegram", "telegram.ext",
|
||||
]:
|
||||
sys.modules.setdefault(_mod, MagicMock())
|
||||
|
||||
os.environ["TIMMY_TEST_MODE"] = "1"
|
||||
|
||||
|
||||
# ── Isolation: fresh coordinator state per test ───────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_state():
|
||||
"""Reset all singleton state between tests so they can't leak."""
|
||||
from dashboard.store import message_log
|
||||
message_log.clear()
|
||||
yield
|
||||
message_log.clear()
|
||||
from swarm.coordinator import coordinator
|
||||
coordinator.auctions._auctions.clear()
|
||||
coordinator.comms._listeners.clear()
|
||||
coordinator._in_process_nodes.clear()
|
||||
coordinator.manager.stop_all()
|
||||
try:
|
||||
from swarm import routing
|
||||
routing.routing_engine._manifests.clear()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# ── TestClient with real app, no patches ──────────────────────────────────────
|
||||
|
||||
@pytest.fixture
|
||||
def app_client(tmp_path):
|
||||
"""TestClient wrapping the real dashboard app.
|
||||
|
||||
Uses a tmp_path for swarm SQLite so tests don't pollute each other.
|
||||
No mocking — Ollama is offline (graceful degradation), singletons are real.
|
||||
"""
|
||||
data_dir = tmp_path / "data"
|
||||
data_dir.mkdir()
|
||||
|
||||
import swarm.tasks as tasks_mod
|
||||
import swarm.registry as registry_mod
|
||||
original_tasks_db = tasks_mod.DB_PATH
|
||||
original_reg_db = registry_mod.DB_PATH
|
||||
|
||||
tasks_mod.DB_PATH = data_dir / "swarm.db"
|
||||
registry_mod.DB_PATH = data_dir / "swarm.db"
|
||||
|
||||
from dashboard.app import app
|
||||
with TestClient(app) as c:
|
||||
yield c
|
||||
|
||||
tasks_mod.DB_PATH = original_tasks_db
|
||||
registry_mod.DB_PATH = original_reg_db
|
||||
|
||||
|
||||
# ── Timmy-serve TestClient ────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture
|
||||
def serve_client():
|
||||
"""TestClient wrapping the timmy-serve L402 app.
|
||||
|
||||
Uses real mock-lightning backend (LIGHTNING_BACKEND=mock).
|
||||
"""
|
||||
from timmy_serve.app import create_timmy_serve_app
|
||||
|
||||
app = create_timmy_serve_app(price_sats=100)
|
||||
with TestClient(app) as c:
|
||||
yield c
|
||||
|
||||
|
||||
# ── CLI runners ───────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture
|
||||
def timmy_runner():
|
||||
"""Typer CliRunner + app for the `timmy` CLI."""
|
||||
from typer.testing import CliRunner
|
||||
from timmy.cli import app
|
||||
return CliRunner(), app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def serve_runner():
|
||||
"""Typer CliRunner + app for the `timmy-serve` CLI."""
|
||||
from typer.testing import CliRunner
|
||||
from timmy_serve.cli import app
|
||||
return CliRunner(), app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tdd_runner():
|
||||
"""Typer CliRunner + app for the `self-tdd` CLI."""
|
||||
from typer.testing import CliRunner
|
||||
from self_tdd.watchdog import app
|
||||
return CliRunner(), app
|
||||
|
||||
|
||||
# ── Docker compose lifecycle ──────────────────────────────────────────────────
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.parent
|
||||
COMPOSE_TEST = PROJECT_ROOT / "docker-compose.test.yml"
|
||||
|
||||
|
||||
def _compose(*args, timeout=60):
|
||||
"""Run a docker compose command against the test compose file."""
|
||||
cmd = ["docker", "compose", "-f", str(COMPOSE_TEST), "-p", "timmy-test", *args]
|
||||
return subprocess.run(cmd, capture_output=True, text=True, timeout=timeout, cwd=str(PROJECT_ROOT))
|
||||
|
||||
|
||||
def _wait_for_healthy(url: str, retries=30, interval=2):
|
||||
"""Poll a URL until it returns 200 or we run out of retries."""
|
||||
import httpx
|
||||
for i in range(retries):
|
||||
try:
|
||||
r = httpx.get(url, timeout=5)
|
||||
if r.status_code == 200:
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
time.sleep(interval)
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def docker_stack():
|
||||
"""Spin up the test compose stack once per session.
|
||||
|
||||
Yields a base URL (http://localhost:18000) to hit the dashboard.
|
||||
Tears down after all tests complete.
|
||||
|
||||
Skipped unless FUNCTIONAL_DOCKER=1 is set.
|
||||
"""
|
||||
if not COMPOSE_TEST.exists():
|
||||
pytest.skip("docker-compose.test.yml not found")
|
||||
if os.environ.get("FUNCTIONAL_DOCKER") != "1":
|
||||
pytest.skip("Set FUNCTIONAL_DOCKER=1 to run Docker tests")
|
||||
|
||||
# Verify Docker daemon is reachable before attempting build
|
||||
docker_check = subprocess.run(
|
||||
["docker", "info"], capture_output=True, text=True, timeout=10,
|
||||
)
|
||||
if docker_check.returncode != 0:
|
||||
pytest.skip(f"Docker daemon not available: {docker_check.stderr.strip()}")
|
||||
|
||||
result = _compose("up", "-d", "--build", "--wait", timeout=300)
|
||||
if result.returncode != 0:
|
||||
pytest.fail(f"docker compose up failed:\n{result.stderr}")
|
||||
|
||||
base_url = "http://localhost:18000"
|
||||
if not _wait_for_healthy(f"{base_url}/health"):
|
||||
logs = _compose("logs")
|
||||
_compose("down", "-v")
|
||||
pytest.fail(f"Dashboard never became healthy:\n{logs.stdout}")
|
||||
|
||||
yield base_url
|
||||
|
||||
_compose("down", "-v", timeout=60)
|
||||
124
tests/functional/test_cli.py
Normal file
124
tests/functional/test_cli.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""Functional tests for CLI entry points via Typer's CliRunner.
|
||||
|
||||
Each test invokes the real CLI command. Ollama is not running, so
|
||||
commands that need inference will fail gracefully — and that's a valid
|
||||
user scenario we want to verify.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ── timmy CLI ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestTimmyCLI:
|
||||
"""Tests the `timmy` command (chat, think, status)."""
|
||||
|
||||
def test_status_runs(self, timmy_runner):
|
||||
runner, app = timmy_runner
|
||||
result = runner.invoke(app, ["status"])
|
||||
# Ollama is offline, so this should either:
|
||||
# - Print an error about Ollama being unreachable, OR
|
||||
# - Exit non-zero
|
||||
# Either way, the CLI itself shouldn't crash with an unhandled exception.
|
||||
# The exit code tells us if the command ran at all.
|
||||
assert result.exit_code is not None
|
||||
|
||||
def test_chat_requires_message(self, timmy_runner):
|
||||
runner, app = timmy_runner
|
||||
result = runner.invoke(app, ["chat"])
|
||||
# Missing required argument
|
||||
assert result.exit_code != 0
|
||||
assert "Missing argument" in result.output or "Usage" in result.output
|
||||
|
||||
def test_think_requires_topic(self, timmy_runner):
|
||||
runner, app = timmy_runner
|
||||
result = runner.invoke(app, ["think"])
|
||||
assert result.exit_code != 0
|
||||
assert "Missing argument" in result.output or "Usage" in result.output
|
||||
|
||||
def test_chat_with_message_runs(self, timmy_runner):
|
||||
"""Chat with a real message — Ollama offline means graceful failure."""
|
||||
runner, app = timmy_runner
|
||||
result = runner.invoke(app, ["chat", "hello"])
|
||||
# Will fail because Ollama isn't running, but the CLI should handle it
|
||||
assert result.exit_code is not None
|
||||
|
||||
def test_backend_flag_accepted(self, timmy_runner):
|
||||
runner, app = timmy_runner
|
||||
result = runner.invoke(app, ["status", "--backend", "ollama"])
|
||||
assert result.exit_code is not None
|
||||
|
||||
def test_help_text(self, timmy_runner):
|
||||
runner, app = timmy_runner
|
||||
result = runner.invoke(app, ["--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "Timmy" in result.output or "sovereign" in result.output.lower()
|
||||
|
||||
|
||||
# ── timmy-serve CLI ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestTimmyServeCLI:
|
||||
"""Tests the `timmy-serve` command (start, invoice, status)."""
|
||||
|
||||
def test_start_dry_run(self, serve_runner):
|
||||
"""--dry-run should print config and exit cleanly."""
|
||||
runner, app = serve_runner
|
||||
result = runner.invoke(app, ["start", "--dry-run"])
|
||||
assert result.exit_code == 0
|
||||
assert "Starting Timmy Serve" in result.output
|
||||
assert "Dry run" in result.output or "dry run" in result.output
|
||||
|
||||
def test_start_dry_run_custom_port(self, serve_runner):
|
||||
runner, app = serve_runner
|
||||
result = runner.invoke(app, ["start", "--dry-run", "--port", "9999"])
|
||||
assert result.exit_code == 0
|
||||
assert "9999" in result.output
|
||||
|
||||
def test_start_dry_run_custom_price(self, serve_runner):
|
||||
runner, app = serve_runner
|
||||
result = runner.invoke(app, ["start", "--dry-run", "--price", "500"])
|
||||
assert result.exit_code == 0
|
||||
assert "500" in result.output
|
||||
|
||||
def test_invoice_creates_real_invoice(self, serve_runner):
|
||||
"""Create a real Lightning invoice via the mock backend."""
|
||||
runner, app = serve_runner
|
||||
result = runner.invoke(app, ["invoice", "--amount", "200", "--memo", "test invoice"])
|
||||
assert result.exit_code == 0
|
||||
assert "Invoice created" in result.output
|
||||
assert "200" in result.output
|
||||
assert "Payment hash" in result.output or "payment_hash" in result.output.lower()
|
||||
|
||||
def test_status_shows_earnings(self, serve_runner):
|
||||
runner, app = serve_runner
|
||||
result = runner.invoke(app, ["status"])
|
||||
assert result.exit_code == 0
|
||||
assert "Total invoices" in result.output or "invoices" in result.output.lower()
|
||||
assert "sats" in result.output.lower()
|
||||
|
||||
def test_help_text(self, serve_runner):
|
||||
runner, app = serve_runner
|
||||
result = runner.invoke(app, ["--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "Serve" in result.output or "Lightning" in result.output
|
||||
|
||||
|
||||
# ── self-tdd CLI ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSelfTddCLI:
|
||||
"""Tests the `self-tdd` command (watch)."""
|
||||
|
||||
def test_help_text(self, tdd_runner):
|
||||
runner, app = tdd_runner
|
||||
result = runner.invoke(app, ["--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "watchdog" in result.output.lower() or "test" in result.output.lower()
|
||||
|
||||
def test_watch_help(self, tdd_runner):
|
||||
runner, app = tdd_runner
|
||||
result = runner.invoke(app, ["watch", "--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "interval" in result.output.lower()
|
||||
199
tests/functional/test_dashboard.py
Normal file
199
tests/functional/test_dashboard.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""Functional tests for the dashboard — real HTTP requests, no mocking.
|
||||
|
||||
The dashboard runs with Ollama offline (graceful degradation).
|
||||
These tests verify what a real user sees when they open the browser.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestDashboardLoads:
|
||||
"""Verify the dashboard serves real HTML pages."""
|
||||
|
||||
def test_index_page(self, app_client):
|
||||
response = app_client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert "text/html" in response.headers["content-type"]
|
||||
# The real rendered page should have the base HTML structure
|
||||
assert "<html" in response.text
|
||||
assert "Timmy" in response.text
|
||||
|
||||
def test_health_endpoint(self, app_client):
|
||||
response = app_client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "status" in data or "ollama" in data
|
||||
|
||||
def test_agents_json(self, app_client):
|
||||
response = app_client.get("/agents")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert isinstance(data, (dict, list))
|
||||
|
||||
def test_swarm_live_page(self, app_client):
|
||||
response = app_client.get("/swarm/live")
|
||||
assert response.status_code == 200
|
||||
assert "text/html" in response.headers["content-type"]
|
||||
assert "WebSocket" in response.text or "swarm" in response.text.lower()
|
||||
|
||||
def test_mobile_endpoint(self, app_client):
|
||||
response = app_client.get("/mobile/status")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
class TestChatFlowOffline:
|
||||
"""Test the chat flow when Ollama is not running.
|
||||
|
||||
This is a real user scenario — they start the dashboard before Ollama.
|
||||
The app should degrade gracefully, not crash.
|
||||
"""
|
||||
|
||||
def test_chat_with_ollama_offline(self, app_client):
|
||||
"""POST to chat endpoint — should return HTML with an error message,
|
||||
not a 500 server error."""
|
||||
response = app_client.post(
|
||||
"/agents/timmy/chat",
|
||||
data={"message": "hello timmy"},
|
||||
)
|
||||
# The route catches exceptions and returns them in the template
|
||||
assert response.status_code == 200
|
||||
assert "text/html" in response.headers["content-type"]
|
||||
# Should contain either the error message or the response
|
||||
assert "hello timmy" in response.text or "offline" in response.text.lower() or "error" in response.text.lower()
|
||||
|
||||
def test_chat_requires_message_field(self, app_client):
|
||||
"""POST without the message field should fail."""
|
||||
response = app_client.post("/agents/timmy/chat", data={})
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_history_starts_empty(self, app_client):
|
||||
response = app_client.get("/agents/timmy/history")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_chat_then_history(self, app_client):
|
||||
"""After chatting, history should contain the message."""
|
||||
app_client.post("/agents/timmy/chat", data={"message": "test message"})
|
||||
response = app_client.get("/agents/timmy/history")
|
||||
assert response.status_code == 200
|
||||
assert "test message" in response.text
|
||||
|
||||
def test_clear_history(self, app_client):
|
||||
app_client.post("/agents/timmy/chat", data={"message": "ephemeral"})
|
||||
response = app_client.delete("/agents/timmy/history")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
class TestSwarmLifecycle:
|
||||
"""Full swarm lifecycle: spawn → post task → bid → assign → complete.
|
||||
|
||||
No mocking. Real coordinator, real SQLite, real in-process agents.
|
||||
"""
|
||||
|
||||
def test_spawn_agent_and_list(self, app_client):
|
||||
spawn = app_client.post("/swarm/spawn", data={"name": "Echo"})
|
||||
assert spawn.status_code == 200
|
||||
spawn_data = spawn.json()
|
||||
agent_id = spawn_data.get("id") or spawn_data.get("agent_id")
|
||||
assert agent_id
|
||||
|
||||
agents = app_client.get("/swarm/agents")
|
||||
assert agents.status_code == 200
|
||||
agent_names = [a["name"] for a in agents.json()["agents"]]
|
||||
assert "Echo" in agent_names
|
||||
|
||||
def test_post_task_opens_auction(self, app_client):
|
||||
resp = app_client.post("/swarm/tasks", data={"description": "Summarize README"})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["description"] == "Summarize README"
|
||||
assert data["status"] == "bidding"
|
||||
|
||||
def test_task_persists_in_list(self, app_client):
|
||||
app_client.post("/swarm/tasks", data={"description": "Task Alpha"})
|
||||
app_client.post("/swarm/tasks", data={"description": "Task Beta"})
|
||||
resp = app_client.get("/swarm/tasks")
|
||||
descriptions = [t["description"] for t in resp.json()["tasks"]]
|
||||
assert "Task Alpha" in descriptions
|
||||
assert "Task Beta" in descriptions
|
||||
|
||||
def test_complete_task(self, app_client):
|
||||
post = app_client.post("/swarm/tasks", data={"description": "Quick job"})
|
||||
task_id = post.json()["task_id"]
|
||||
resp = app_client.post(
|
||||
f"/swarm/tasks/{task_id}/complete",
|
||||
data={"result": "Done."},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "completed"
|
||||
|
||||
# Verify the result persisted
|
||||
task = app_client.get(f"/swarm/tasks/{task_id}")
|
||||
assert task.json()["result"] == "Done."
|
||||
|
||||
def test_fail_task_feeds_learner(self, app_client):
|
||||
post = app_client.post("/swarm/tasks", data={"description": "Doomed job"})
|
||||
task_id = post.json()["task_id"]
|
||||
resp = app_client.post(
|
||||
f"/swarm/tasks/{task_id}/fail",
|
||||
data={"reason": "OOM"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "failed"
|
||||
|
||||
def test_stop_agent(self, app_client):
|
||||
spawn = app_client.post("/swarm/spawn", data={"name": "Disposable"})
|
||||
agent_id = spawn.json().get("id") or spawn.json().get("agent_id")
|
||||
resp = app_client.delete(f"/swarm/agents/{agent_id}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["stopped"] is True
|
||||
|
||||
def test_insights_endpoint(self, app_client):
|
||||
resp = app_client.get("/swarm/insights")
|
||||
assert resp.status_code == 200
|
||||
assert "agents" in resp.json()
|
||||
|
||||
def test_websocket_connects(self, app_client):
|
||||
"""Real WebSocket connection to /swarm/live."""
|
||||
with app_client.websocket_connect("/swarm/live") as ws:
|
||||
ws.send_text("ping")
|
||||
# Connection holds — the endpoint just logs, doesn't echo back.
|
||||
# The point is it doesn't crash.
|
||||
|
||||
|
||||
class TestSwarmUIPartials:
|
||||
"""HTMX partial endpoints — verify they return real rendered HTML."""
|
||||
|
||||
def test_agents_sidebar_html(self, app_client):
|
||||
app_client.post("/swarm/spawn", data={"name": "Echo"})
|
||||
resp = app_client.get("/swarm/agents/sidebar")
|
||||
assert resp.status_code == 200
|
||||
assert "text/html" in resp.headers["content-type"]
|
||||
assert "echo" in resp.text.lower()
|
||||
|
||||
def test_agent_panel_html(self, app_client):
|
||||
spawn = app_client.post("/swarm/spawn", data={"name": "Echo"})
|
||||
agent_id = spawn.json().get("id") or spawn.json().get("agent_id")
|
||||
resp = app_client.get(f"/swarm/agents/{agent_id}/panel")
|
||||
assert resp.status_code == 200
|
||||
assert "text/html" in resp.headers["content-type"]
|
||||
assert "echo" in resp.text.lower()
|
||||
|
||||
def test_message_agent_creates_task(self, app_client):
|
||||
spawn = app_client.post("/swarm/spawn", data={"name": "Worker"})
|
||||
agent_id = spawn.json().get("id") or spawn.json().get("agent_id")
|
||||
resp = app_client.post(
|
||||
f"/swarm/agents/{agent_id}/message",
|
||||
data={"message": "Summarise the codebase"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert "text/html" in resp.headers["content-type"]
|
||||
|
||||
def test_direct_assign_to_agent(self, app_client):
|
||||
spawn = app_client.post("/swarm/spawn", data={"name": "Worker"})
|
||||
agent_id = spawn.json().get("id") or spawn.json().get("agent_id")
|
||||
resp = app_client.post(
|
||||
"/swarm/tasks/direct",
|
||||
data={"description": "Direct job", "agent_id": agent_id},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert "text/html" in resp.headers["content-type"]
|
||||
150
tests/functional/test_docker_swarm.py
Normal file
150
tests/functional/test_docker_swarm.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""Container-level swarm integration tests.
|
||||
|
||||
These tests require Docker and run against real containers:
|
||||
- dashboard on port 18000
|
||||
- agent workers scaled via docker compose
|
||||
|
||||
Run with:
|
||||
FUNCTIONAL_DOCKER=1 pytest tests/functional/test_docker_swarm.py -v
|
||||
|
||||
Skipped automatically if FUNCTIONAL_DOCKER != "1".
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Try to import httpx for real HTTP calls to containers
|
||||
httpx = pytest.importorskip("httpx")
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.parent
|
||||
COMPOSE_TEST = PROJECT_ROOT / "docker-compose.test.yml"
|
||||
|
||||
|
||||
def _compose(*args, timeout=60):
|
||||
cmd = ["docker", "compose", "-f", str(COMPOSE_TEST), "-p", "timmy-test", *args]
|
||||
return subprocess.run(cmd, capture_output=True, text=True, timeout=timeout, cwd=str(PROJECT_ROOT))
|
||||
|
||||
|
||||
class TestDockerDashboard:
|
||||
"""Tests hitting the real dashboard container over HTTP."""
|
||||
|
||||
def test_health(self, docker_stack):
|
||||
resp = httpx.get(f"{docker_stack}/health", timeout=10)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "status" in data or "ollama" in data
|
||||
|
||||
def test_index_page(self, docker_stack):
|
||||
resp = httpx.get(docker_stack, timeout=10)
|
||||
assert resp.status_code == 200
|
||||
assert "text/html" in resp.headers["content-type"]
|
||||
assert "Timmy" in resp.text
|
||||
|
||||
def test_swarm_status(self, docker_stack):
|
||||
resp = httpx.get(f"{docker_stack}/swarm", timeout=10)
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_spawn_agent_via_api(self, docker_stack):
|
||||
resp = httpx.post(
|
||||
f"{docker_stack}/swarm/spawn",
|
||||
data={"name": "RemoteEcho"},
|
||||
timeout=10,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data.get("name") == "RemoteEcho" or "id" in data
|
||||
|
||||
def test_post_task_via_api(self, docker_stack):
|
||||
resp = httpx.post(
|
||||
f"{docker_stack}/swarm/tasks",
|
||||
data={"description": "Docker test task"},
|
||||
timeout=10,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["description"] == "Docker test task"
|
||||
assert "task_id" in data
|
||||
|
||||
|
||||
class TestDockerAgentSwarm:
|
||||
"""Tests with real agent containers communicating over the network.
|
||||
|
||||
These tests scale up agent workers and verify they register,
|
||||
bid on tasks, and get assigned work — all over real HTTP.
|
||||
"""
|
||||
|
||||
def test_agent_registers_via_http(self, docker_stack):
|
||||
"""Scale up one agent worker and verify it appears in the registry."""
|
||||
# Start one agent
|
||||
result = _compose(
|
||||
"--profile", "agents", "up", "-d", "--scale", "agent=1",
|
||||
timeout=120,
|
||||
)
|
||||
assert result.returncode == 0, f"Failed to start agent:\n{result.stderr}"
|
||||
|
||||
# Give the agent time to register via HTTP
|
||||
time.sleep(8)
|
||||
|
||||
resp = httpx.get(f"{docker_stack}/swarm/agents", timeout=10)
|
||||
assert resp.status_code == 200
|
||||
agents = resp.json()["agents"]
|
||||
agent_names = [a["name"] for a in agents]
|
||||
assert "TestWorker" in agent_names or any("Worker" in n for n in agent_names)
|
||||
|
||||
# Clean up the agent
|
||||
_compose("--profile", "agents", "down", timeout=30)
|
||||
|
||||
def test_agent_bids_on_task(self, docker_stack):
|
||||
"""Start an agent, post a task, verify the agent bids on it."""
|
||||
# Start agent
|
||||
result = _compose(
|
||||
"--profile", "agents", "up", "-d", "--scale", "agent=1",
|
||||
timeout=120,
|
||||
)
|
||||
assert result.returncode == 0
|
||||
|
||||
# Wait for agent to register
|
||||
time.sleep(8)
|
||||
|
||||
# Post a task — this triggers an auction
|
||||
task_resp = httpx.post(
|
||||
f"{docker_stack}/swarm/tasks",
|
||||
data={"description": "Test bidding flow"},
|
||||
timeout=10,
|
||||
)
|
||||
assert task_resp.status_code == 200
|
||||
task_id = task_resp.json()["task_id"]
|
||||
|
||||
# Give the agent time to poll and bid
|
||||
time.sleep(12)
|
||||
|
||||
# Check task status — may have been assigned
|
||||
task = httpx.get(f"{docker_stack}/swarm/tasks/{task_id}", timeout=10)
|
||||
assert task.status_code == 200
|
||||
task_data = task.json()
|
||||
# The task should still exist regardless of bid outcome
|
||||
assert task_data["description"] == "Test bidding flow"
|
||||
|
||||
_compose("--profile", "agents", "down", timeout=30)
|
||||
|
||||
def test_multiple_agents(self, docker_stack):
|
||||
"""Scale to 3 agents and verify all register."""
|
||||
result = _compose(
|
||||
"--profile", "agents", "up", "-d", "--scale", "agent=3",
|
||||
timeout=120,
|
||||
)
|
||||
assert result.returncode == 0
|
||||
|
||||
# Wait for registration
|
||||
time.sleep(12)
|
||||
|
||||
resp = httpx.get(f"{docker_stack}/swarm/agents", timeout=10)
|
||||
agents = resp.json()["agents"]
|
||||
# Should have at least the 3 agents we started (plus possibly Timmy and auto-spawned ones)
|
||||
worker_count = sum(1 for a in agents if "Worker" in a["name"] or "TestWorker" in a["name"])
|
||||
assert worker_count >= 1 # At least some registered
|
||||
|
||||
_compose("--profile", "agents", "down", timeout=30)
|
||||
106
tests/functional/test_l402_flow.py
Normal file
106
tests/functional/test_l402_flow.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""Functional test for the full L402 payment flow.
|
||||
|
||||
Uses the real mock-lightning backend (LIGHTNING_BACKEND=mock) — no patching.
|
||||
This exercises the entire payment lifecycle a real client would go through:
|
||||
|
||||
1. Hit protected endpoint → get 402 + invoice + macaroon
|
||||
2. "Pay" the invoice (settle via mock backend)
|
||||
3. Present macaroon:preimage → get access
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestL402PaymentFlow:
|
||||
"""End-to-end L402 payment lifecycle."""
|
||||
|
||||
def test_unprotected_endpoints_work(self, serve_client):
|
||||
"""Status and health don't require payment."""
|
||||
resp = serve_client.get("/serve/status")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["status"] == "active"
|
||||
assert data["price_sats"] == 100
|
||||
|
||||
health = serve_client.get("/health")
|
||||
assert health.status_code == 200
|
||||
|
||||
def test_chat_without_payment_returns_402(self, serve_client):
|
||||
"""Hitting /serve/chat without an L402 token gives 402."""
|
||||
resp = serve_client.post(
|
||||
"/serve/chat",
|
||||
json={"message": "hello"},
|
||||
)
|
||||
assert resp.status_code == 402
|
||||
data = resp.json()
|
||||
assert data["error"] == "Payment Required"
|
||||
assert data["code"] == "L402"
|
||||
assert "macaroon" in data
|
||||
assert "invoice" in data
|
||||
assert "payment_hash" in data
|
||||
assert data["amount_sats"] == 100
|
||||
|
||||
# WWW-Authenticate header should be present
|
||||
assert "WWW-Authenticate" in resp.headers
|
||||
assert "L402" in resp.headers["WWW-Authenticate"]
|
||||
|
||||
def test_chat_with_garbage_token_returns_402(self, serve_client):
|
||||
resp = serve_client.post(
|
||||
"/serve/chat",
|
||||
json={"message": "hello"},
|
||||
headers={"Authorization": "L402 garbage:token"},
|
||||
)
|
||||
assert resp.status_code == 402
|
||||
|
||||
def test_full_payment_lifecycle(self, serve_client):
|
||||
"""Complete flow: get challenge → pay → access."""
|
||||
from timmy_serve.payment_handler import payment_handler
|
||||
|
||||
# Step 1: Hit protected endpoint, get 402 challenge
|
||||
challenge_resp = serve_client.post(
|
||||
"/serve/chat",
|
||||
json={"message": "hello"},
|
||||
)
|
||||
assert challenge_resp.status_code == 402
|
||||
challenge = challenge_resp.json()
|
||||
macaroon = challenge["macaroon"]
|
||||
payment_hash = challenge["payment_hash"]
|
||||
|
||||
# Step 2: "Pay" the invoice via the mock backend's auto-settle
|
||||
# The mock backend settles invoices when you provide the correct preimage.
|
||||
# Get the preimage from the mock backend's internal state.
|
||||
invoice = payment_handler.get_invoice(payment_hash)
|
||||
assert invoice is not None
|
||||
preimage = invoice.preimage # mock backend exposes this
|
||||
|
||||
# Step 3: Present macaroon:preimage to access the endpoint
|
||||
resp = serve_client.post(
|
||||
"/serve/chat",
|
||||
json={"message": "hello after paying"},
|
||||
headers={"Authorization": f"L402 {macaroon}:{preimage}"},
|
||||
)
|
||||
# The chat will fail because Ollama isn't running, but the
|
||||
# L402 middleware should let us through (status != 402).
|
||||
# We accept 200 (success) or 500 (Ollama offline) — NOT 402.
|
||||
assert resp.status_code != 402
|
||||
|
||||
def test_create_invoice_via_api(self, serve_client):
|
||||
"""POST /serve/invoice creates a real invoice."""
|
||||
resp = serve_client.post(
|
||||
"/serve/invoice",
|
||||
json={"amount_sats": 500, "memo": "premium access"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["amount_sats"] == 500
|
||||
assert data["payment_hash"]
|
||||
assert data["payment_request"]
|
||||
|
||||
def test_status_reflects_invoices(self, serve_client):
|
||||
"""Creating invoices should be reflected in /serve/status."""
|
||||
serve_client.post("/serve/invoice", json={"amount_sats": 100, "memo": "test"})
|
||||
serve_client.post("/serve/invoice", json={"amount_sats": 200, "memo": "test2"})
|
||||
|
||||
resp = serve_client.get("/serve/status")
|
||||
data = resp.json()
|
||||
assert data["total_invoices"] >= 2
|
||||
456
tests/test_agent_core.py
Normal file
456
tests/test_agent_core.py
Normal file
@@ -0,0 +1,456 @@
|
||||
"""Functional tests for agent_core — interface and ollama_adapter.
|
||||
|
||||
Covers the substrate-agnostic agent contract (data classes, enums,
|
||||
factory methods, abstract enforcement) and the OllamaAgent adapter
|
||||
(perceive → reason → act → remember → recall → communicate workflow).
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_core.interface import (
|
||||
ActionType,
|
||||
AgentCapability,
|
||||
AgentEffect,
|
||||
AgentIdentity,
|
||||
Action,
|
||||
Communication,
|
||||
Memory,
|
||||
Perception,
|
||||
PerceptionType,
|
||||
TimAgent,
|
||||
)
|
||||
|
||||
|
||||
# ── AgentIdentity ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAgentIdentity:
|
||||
def test_generate_creates_uuid(self):
|
||||
identity = AgentIdentity.generate("Timmy")
|
||||
assert identity.name == "Timmy"
|
||||
uuid.UUID(identity.id) # raises on invalid
|
||||
|
||||
def test_generate_default_version(self):
|
||||
identity = AgentIdentity.generate("Timmy")
|
||||
assert identity.version == "1.0.0"
|
||||
|
||||
def test_generate_custom_version(self):
|
||||
identity = AgentIdentity.generate("Timmy", version="2.0.0")
|
||||
assert identity.version == "2.0.0"
|
||||
|
||||
def test_frozen_identity(self):
|
||||
identity = AgentIdentity.generate("Timmy")
|
||||
with pytest.raises(AttributeError):
|
||||
identity.name = "Other"
|
||||
|
||||
def test_created_at_populated(self):
|
||||
identity = AgentIdentity.generate("Timmy")
|
||||
assert identity.created_at # not empty
|
||||
assert "T" in identity.created_at # ISO format
|
||||
|
||||
def test_two_identities_differ(self):
|
||||
a = AgentIdentity.generate("A")
|
||||
b = AgentIdentity.generate("B")
|
||||
assert a.id != b.id
|
||||
|
||||
|
||||
# ── Perception ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPerception:
|
||||
def test_text_factory(self):
|
||||
p = Perception.text("hello")
|
||||
assert p.type == PerceptionType.TEXT
|
||||
assert p.data == "hello"
|
||||
assert p.source == "user"
|
||||
|
||||
def test_text_factory_custom_source(self):
|
||||
p = Perception.text("hello", source="api")
|
||||
assert p.source == "api"
|
||||
|
||||
def test_sensor_factory(self):
|
||||
p = Perception.sensor("temperature", 22.5, "°C")
|
||||
assert p.type == PerceptionType.SENSOR
|
||||
assert p.data["kind"] == "temperature"
|
||||
assert p.data["value"] == 22.5
|
||||
assert p.data["unit"] == "°C"
|
||||
assert p.source == "sensor_temperature"
|
||||
|
||||
def test_timestamp_auto_populated(self):
|
||||
p = Perception.text("hi")
|
||||
assert p.timestamp
|
||||
assert "T" in p.timestamp
|
||||
|
||||
def test_metadata_defaults_empty(self):
|
||||
p = Perception.text("hi")
|
||||
assert p.metadata == {}
|
||||
|
||||
|
||||
# ── Action ────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAction:
|
||||
def test_respond_factory(self):
|
||||
a = Action.respond("Hello!")
|
||||
assert a.type == ActionType.TEXT
|
||||
assert a.payload == "Hello!"
|
||||
assert a.confidence == 1.0
|
||||
|
||||
def test_respond_with_confidence(self):
|
||||
a = Action.respond("Maybe", confidence=0.5)
|
||||
assert a.confidence == 0.5
|
||||
|
||||
def test_move_factory(self):
|
||||
a = Action.move((1.0, 2.0, 3.0), speed=0.5)
|
||||
assert a.type == ActionType.MOVE
|
||||
assert a.payload["vector"] == (1.0, 2.0, 3.0)
|
||||
assert a.payload["speed"] == 0.5
|
||||
|
||||
def test_move_default_speed(self):
|
||||
a = Action.move((0, 0, 0))
|
||||
assert a.payload["speed"] == 1.0
|
||||
|
||||
def test_deadline_defaults_none(self):
|
||||
a = Action.respond("test")
|
||||
assert a.deadline is None
|
||||
|
||||
|
||||
# ── Memory ────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestMemory:
|
||||
def test_touch_increments(self):
|
||||
m = Memory(id="m1", content="hello", created_at="2025-01-01T00:00:00Z")
|
||||
assert m.access_count == 0
|
||||
m.touch()
|
||||
assert m.access_count == 1
|
||||
m.touch()
|
||||
assert m.access_count == 2
|
||||
|
||||
def test_touch_sets_last_accessed(self):
|
||||
m = Memory(id="m1", content="hello", created_at="2025-01-01T00:00:00Z")
|
||||
assert m.last_accessed is None
|
||||
m.touch()
|
||||
assert m.last_accessed is not None
|
||||
|
||||
def test_default_importance(self):
|
||||
m = Memory(id="m1", content="x", created_at="now")
|
||||
assert m.importance == 0.5
|
||||
|
||||
def test_tags_default_empty(self):
|
||||
m = Memory(id="m1", content="x", created_at="now")
|
||||
assert m.tags == []
|
||||
|
||||
|
||||
# ── Communication ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCommunication:
|
||||
def test_defaults(self):
|
||||
c = Communication(sender="A", recipient="B", content="hi")
|
||||
assert c.protocol == "direct"
|
||||
assert c.encrypted is False
|
||||
assert c.timestamp # auto-populated
|
||||
|
||||
|
||||
# ── TimAgent abstract enforcement ─────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestTimAgentABC:
|
||||
def test_cannot_instantiate_abstract(self):
|
||||
with pytest.raises(TypeError):
|
||||
TimAgent(AgentIdentity.generate("X"))
|
||||
|
||||
def test_concrete_subclass_works(self):
|
||||
class Dummy(TimAgent):
|
||||
def perceive(self, p): return Memory(id="1", content=p.data, created_at="")
|
||||
def reason(self, q, c): return Action.respond(q)
|
||||
def act(self, a): return a.payload
|
||||
def remember(self, m): pass
|
||||
def recall(self, q, limit=5): return []
|
||||
def communicate(self, m): return True
|
||||
|
||||
d = Dummy(AgentIdentity.generate("Dummy"))
|
||||
assert d.identity.name == "Dummy"
|
||||
assert d.capabilities == set()
|
||||
|
||||
def test_has_capability(self):
|
||||
class Dummy(TimAgent):
|
||||
def perceive(self, p): pass
|
||||
def reason(self, q, c): pass
|
||||
def act(self, a): pass
|
||||
def remember(self, m): pass
|
||||
def recall(self, q, limit=5): return []
|
||||
def communicate(self, m): return True
|
||||
|
||||
d = Dummy(AgentIdentity.generate("D"))
|
||||
d._capabilities.add(AgentCapability.REASONING)
|
||||
assert d.has_capability(AgentCapability.REASONING)
|
||||
assert not d.has_capability(AgentCapability.VISION)
|
||||
|
||||
def test_capabilities_returns_copy(self):
|
||||
class Dummy(TimAgent):
|
||||
def perceive(self, p): pass
|
||||
def reason(self, q, c): pass
|
||||
def act(self, a): pass
|
||||
def remember(self, m): pass
|
||||
def recall(self, q, limit=5): return []
|
||||
def communicate(self, m): return True
|
||||
|
||||
d = Dummy(AgentIdentity.generate("D"))
|
||||
caps = d.capabilities
|
||||
caps.add(AgentCapability.VISION)
|
||||
assert AgentCapability.VISION not in d.capabilities
|
||||
|
||||
def test_get_state(self):
|
||||
class Dummy(TimAgent):
|
||||
def perceive(self, p): pass
|
||||
def reason(self, q, c): pass
|
||||
def act(self, a): pass
|
||||
def remember(self, m): pass
|
||||
def recall(self, q, limit=5): return []
|
||||
def communicate(self, m): return True
|
||||
|
||||
d = Dummy(AgentIdentity.generate("D"))
|
||||
state = d.get_state()
|
||||
assert "identity" in state
|
||||
assert "capabilities" in state
|
||||
assert "state" in state
|
||||
|
||||
def test_shutdown_does_not_raise(self):
|
||||
class Dummy(TimAgent):
|
||||
def perceive(self, p): pass
|
||||
def reason(self, q, c): pass
|
||||
def act(self, a): pass
|
||||
def remember(self, m): pass
|
||||
def recall(self, q, limit=5): return []
|
||||
def communicate(self, m): return True
|
||||
|
||||
d = Dummy(AgentIdentity.generate("D"))
|
||||
d.shutdown() # should not raise
|
||||
|
||||
|
||||
# ── AgentEffect ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAgentEffect:
|
||||
def test_empty_export(self):
|
||||
effect = AgentEffect()
|
||||
assert effect.export() == []
|
||||
|
||||
def test_log_perceive(self):
|
||||
effect = AgentEffect()
|
||||
p = Perception.text("test input")
|
||||
effect.log_perceive(p, "mem_0")
|
||||
log = effect.export()
|
||||
assert len(log) == 1
|
||||
assert log[0]["type"] == "perceive"
|
||||
assert log[0]["perception_type"] == "TEXT"
|
||||
assert log[0]["memory_id"] == "mem_0"
|
||||
assert "timestamp" in log[0]
|
||||
|
||||
def test_log_reason(self):
|
||||
effect = AgentEffect()
|
||||
effect.log_reason("How to help?", ActionType.TEXT)
|
||||
log = effect.export()
|
||||
assert len(log) == 1
|
||||
assert log[0]["type"] == "reason"
|
||||
assert log[0]["query"] == "How to help?"
|
||||
assert log[0]["action_type"] == "TEXT"
|
||||
|
||||
def test_log_act(self):
|
||||
effect = AgentEffect()
|
||||
action = Action.respond("Hello!")
|
||||
effect.log_act(action, "Hello!")
|
||||
log = effect.export()
|
||||
assert len(log) == 1
|
||||
assert log[0]["type"] == "act"
|
||||
assert log[0]["confidence"] == 1.0
|
||||
assert log[0]["result_type"] == "str"
|
||||
|
||||
def test_export_returns_copy(self):
|
||||
effect = AgentEffect()
|
||||
effect.log_reason("q", ActionType.TEXT)
|
||||
exported = effect.export()
|
||||
exported.clear()
|
||||
assert len(effect.export()) == 1
|
||||
|
||||
def test_full_audit_trail(self):
|
||||
effect = AgentEffect()
|
||||
p = Perception.text("input")
|
||||
effect.log_perceive(p, "m0")
|
||||
effect.log_reason("what now?", ActionType.TEXT)
|
||||
action = Action.respond("response")
|
||||
effect.log_act(action, "response")
|
||||
log = effect.export()
|
||||
assert len(log) == 3
|
||||
types = [e["type"] for e in log]
|
||||
assert types == ["perceive", "reason", "act"]
|
||||
|
||||
|
||||
# ── OllamaAgent functional tests ─────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestOllamaAgent:
|
||||
"""Functional tests for the OllamaAgent adapter.
|
||||
|
||||
Uses mocked Ollama (create_timmy returns a mock) to exercise
|
||||
the full perceive → reason → act → remember → recall pipeline.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def agent(self):
|
||||
with patch("agent_core.ollama_adapter.create_timmy") as mock_ct:
|
||||
mock_timmy = MagicMock()
|
||||
mock_run = MagicMock()
|
||||
mock_run.content = "Mocked LLM response"
|
||||
mock_timmy.run.return_value = mock_run
|
||||
mock_ct.return_value = mock_timmy
|
||||
|
||||
from agent_core.ollama_adapter import OllamaAgent
|
||||
identity = AgentIdentity.generate("TestTimmy")
|
||||
return OllamaAgent(identity, effect_log="/tmp/test_effects")
|
||||
|
||||
def test_capabilities_set(self, agent):
|
||||
caps = agent.capabilities
|
||||
assert AgentCapability.REASONING in caps
|
||||
assert AgentCapability.CODING in caps
|
||||
assert AgentCapability.WRITING in caps
|
||||
assert AgentCapability.ANALYSIS in caps
|
||||
assert AgentCapability.COMMUNICATION in caps
|
||||
|
||||
def test_perceive_creates_memory(self, agent):
|
||||
p = Perception.text("Hello Timmy")
|
||||
mem = agent.perceive(p)
|
||||
assert mem.id == "mem_0"
|
||||
assert mem.content["data"] == "Hello Timmy"
|
||||
assert mem.content["type"] == "TEXT"
|
||||
|
||||
def test_perceive_extracts_tags(self, agent):
|
||||
p = Perception.text("I need help with a bug in my code")
|
||||
mem = agent.perceive(p)
|
||||
assert "TEXT" in mem.tags
|
||||
assert "user" in mem.tags
|
||||
assert "help" in mem.tags
|
||||
assert "bug" in mem.tags
|
||||
assert "code" in mem.tags
|
||||
|
||||
def test_perceive_fifo_eviction(self, agent):
|
||||
for i in range(12):
|
||||
agent.perceive(Perception.text(f"msg {i}"))
|
||||
assert len(agent._working_memory) == 10
|
||||
# oldest two evicted
|
||||
assert agent._working_memory[0].content["data"] == "msg 2"
|
||||
|
||||
def test_reason_returns_action(self, agent):
|
||||
mem = agent.perceive(Perception.text("context"))
|
||||
action = agent.reason("What should I do?", [mem])
|
||||
assert action.type == ActionType.TEXT
|
||||
assert action.payload == "Mocked LLM response"
|
||||
assert action.confidence == 0.9
|
||||
|
||||
def test_act_text(self, agent):
|
||||
action = Action.respond("Hello!")
|
||||
result = agent.act(action)
|
||||
assert result == "Hello!"
|
||||
|
||||
def test_act_speak(self, agent):
|
||||
action = Action(type=ActionType.SPEAK, payload="Speak this")
|
||||
result = agent.act(action)
|
||||
assert result["spoken"] == "Speak this"
|
||||
assert result["tts_engine"] == "pyttsx3"
|
||||
|
||||
def test_act_call(self, agent):
|
||||
action = Action(type=ActionType.CALL, payload={"url": "http://example.com"})
|
||||
result = agent.act(action)
|
||||
assert result["status"] == "not_implemented"
|
||||
|
||||
def test_act_unsupported(self, agent):
|
||||
action = Action(type=ActionType.MOVE, payload=(0, 0, 0))
|
||||
result = agent.act(action)
|
||||
assert "error" in result
|
||||
|
||||
def test_remember_stores_and_deduplicates(self, agent):
|
||||
mem = agent.perceive(Perception.text("original"))
|
||||
assert len(agent._working_memory) == 1
|
||||
agent.remember(mem)
|
||||
assert len(agent._working_memory) == 1 # deduplicated
|
||||
assert mem.access_count == 1
|
||||
|
||||
def test_remember_evicts_on_overflow(self, agent):
|
||||
for i in range(10):
|
||||
agent.perceive(Perception.text(f"fill {i}"))
|
||||
extra = Memory(id="extra", content="overflow", created_at="now")
|
||||
agent.remember(extra)
|
||||
assert len(agent._working_memory) == 10
|
||||
# first memory evicted
|
||||
assert agent._working_memory[-1].id == "extra"
|
||||
|
||||
def test_recall_keyword_matching(self, agent):
|
||||
agent.perceive(Perception.text("python code review"))
|
||||
agent.perceive(Perception.text("weather forecast"))
|
||||
agent.perceive(Perception.text("python bug fix"))
|
||||
results = agent.recall("python", limit=5)
|
||||
# All memories returned (recall returns up to limit)
|
||||
assert len(results) == 3
|
||||
# Memories containing "python" should score higher and appear first
|
||||
first_content = str(results[0].content)
|
||||
assert "python" in first_content.lower()
|
||||
|
||||
def test_recall_respects_limit(self, agent):
|
||||
for i in range(10):
|
||||
agent.perceive(Perception.text(f"memory {i}"))
|
||||
results = agent.recall("memory", limit=3)
|
||||
assert len(results) == 3
|
||||
|
||||
def test_communicate_success(self, agent):
|
||||
with patch("swarm.comms.SwarmComms") as MockComms:
|
||||
mock_comms = MagicMock()
|
||||
MockComms.return_value = mock_comms
|
||||
msg = Communication(sender="Timmy", recipient="Echo", content="hi")
|
||||
result = agent.communicate(msg)
|
||||
# communicate returns True on success, False on exception
|
||||
assert isinstance(result, bool)
|
||||
|
||||
def test_communicate_failure(self, agent):
|
||||
# Force an import error inside communicate() to trigger except branch
|
||||
with patch.dict("sys.modules", {"swarm.comms": None}):
|
||||
msg = Communication(sender="Timmy", recipient="Echo", content="hi")
|
||||
assert agent.communicate(msg) is False
|
||||
|
||||
def test_effect_logging_full_workflow(self, agent):
|
||||
p = Perception.text("test input")
|
||||
mem = agent.perceive(p)
|
||||
action = agent.reason("respond", [mem])
|
||||
agent.act(action)
|
||||
log = agent.get_effect_log()
|
||||
assert len(log) == 3
|
||||
assert log[0]["type"] == "perceive"
|
||||
assert log[1]["type"] == "reason"
|
||||
assert log[2]["type"] == "act"
|
||||
|
||||
def test_no_effect_log_when_disabled(self):
|
||||
with patch("agent_core.ollama_adapter.create_timmy") as mock_ct:
|
||||
mock_timmy = MagicMock()
|
||||
mock_ct.return_value = mock_timmy
|
||||
from agent_core.ollama_adapter import OllamaAgent
|
||||
identity = AgentIdentity.generate("NoLog")
|
||||
agent = OllamaAgent(identity) # no effect_log
|
||||
assert agent.get_effect_log() is None
|
||||
|
||||
def test_format_context_empty(self, agent):
|
||||
result = agent._format_context([])
|
||||
assert result == "No previous context."
|
||||
|
||||
def test_format_context_with_dict_content(self, agent):
|
||||
mem = Memory(id="m", content={"data": "hello"}, created_at="now")
|
||||
result = agent._format_context([mem])
|
||||
assert "hello" in result
|
||||
|
||||
def test_format_context_with_string_content(self, agent):
|
||||
mem = Memory(id="m", content="plain string", created_at="now")
|
||||
result = agent._format_context([mem])
|
||||
assert "plain string" in result
|
||||
268
tests/test_chat_bridge.py
Normal file
268
tests/test_chat_bridge.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""Tests for the chat_bridge base classes, registry, and invite parser."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from chat_bridge.base import (
|
||||
ChatMessage,
|
||||
ChatPlatform,
|
||||
ChatThread,
|
||||
InviteInfo,
|
||||
PlatformState,
|
||||
PlatformStatus,
|
||||
)
|
||||
from chat_bridge.registry import PlatformRegistry
|
||||
|
||||
|
||||
# ── Base dataclass tests ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestChatMessage:
|
||||
def test_create_message(self):
|
||||
msg = ChatMessage(
|
||||
content="Hello",
|
||||
author="user1",
|
||||
channel_id="123",
|
||||
platform="test",
|
||||
)
|
||||
assert msg.content == "Hello"
|
||||
assert msg.author == "user1"
|
||||
assert msg.platform == "test"
|
||||
assert msg.thread_id is None
|
||||
assert msg.attachments == []
|
||||
|
||||
def test_message_with_thread(self):
|
||||
msg = ChatMessage(
|
||||
content="Reply",
|
||||
author="bot",
|
||||
channel_id="123",
|
||||
platform="discord",
|
||||
thread_id="456",
|
||||
)
|
||||
assert msg.thread_id == "456"
|
||||
|
||||
|
||||
class TestChatThread:
|
||||
def test_create_thread(self):
|
||||
thread = ChatThread(
|
||||
thread_id="t1",
|
||||
title="Timmy | user1",
|
||||
channel_id="c1",
|
||||
platform="discord",
|
||||
)
|
||||
assert thread.thread_id == "t1"
|
||||
assert thread.archived is False
|
||||
assert thread.message_count == 0
|
||||
|
||||
|
||||
class TestInviteInfo:
|
||||
def test_create_invite(self):
|
||||
invite = InviteInfo(
|
||||
url="https://discord.gg/abc123",
|
||||
code="abc123",
|
||||
platform="discord",
|
||||
source="qr",
|
||||
)
|
||||
assert invite.code == "abc123"
|
||||
assert invite.source == "qr"
|
||||
|
||||
|
||||
class TestPlatformStatus:
|
||||
def test_to_dict(self):
|
||||
status = PlatformStatus(
|
||||
platform="discord",
|
||||
state=PlatformState.CONNECTED,
|
||||
token_set=True,
|
||||
guild_count=3,
|
||||
)
|
||||
d = status.to_dict()
|
||||
assert d["connected"] is True
|
||||
assert d["platform"] == "discord"
|
||||
assert d["guild_count"] == 3
|
||||
assert d["state"] == "connected"
|
||||
|
||||
def test_disconnected_status(self):
|
||||
status = PlatformStatus(
|
||||
platform="test",
|
||||
state=PlatformState.DISCONNECTED,
|
||||
token_set=False,
|
||||
)
|
||||
d = status.to_dict()
|
||||
assert d["connected"] is False
|
||||
|
||||
|
||||
# ── PlatformRegistry tests ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _FakePlatform(ChatPlatform):
|
||||
"""Minimal ChatPlatform for testing the registry."""
|
||||
|
||||
def __init__(self, platform_name: str = "fake"):
|
||||
self._name = platform_name
|
||||
self._state = PlatformState.DISCONNECTED
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def state(self) -> PlatformState:
|
||||
return self._state
|
||||
|
||||
async def start(self, token=None) -> bool:
|
||||
self._state = PlatformState.CONNECTED
|
||||
return True
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._state = PlatformState.DISCONNECTED
|
||||
|
||||
async def send_message(self, channel_id, content, thread_id=None):
|
||||
return ChatMessage(
|
||||
content=content, author="bot", channel_id=channel_id, platform=self._name
|
||||
)
|
||||
|
||||
async def create_thread(self, channel_id, title, initial_message=None):
|
||||
return ChatThread(
|
||||
thread_id="t1", title=title, channel_id=channel_id, platform=self._name
|
||||
)
|
||||
|
||||
async def join_from_invite(self, invite_code) -> bool:
|
||||
return True
|
||||
|
||||
def status(self):
|
||||
return PlatformStatus(
|
||||
platform=self._name,
|
||||
state=self._state,
|
||||
token_set=False,
|
||||
)
|
||||
|
||||
def save_token(self, token):
|
||||
pass
|
||||
|
||||
def load_token(self):
|
||||
return None
|
||||
|
||||
|
||||
class TestPlatformRegistry:
|
||||
def test_register_and_get(self):
|
||||
reg = PlatformRegistry()
|
||||
p = _FakePlatform("test1")
|
||||
reg.register(p)
|
||||
assert reg.get("test1") is p
|
||||
|
||||
def test_get_missing(self):
|
||||
reg = PlatformRegistry()
|
||||
assert reg.get("nonexistent") is None
|
||||
|
||||
def test_unregister(self):
|
||||
reg = PlatformRegistry()
|
||||
p = _FakePlatform("test1")
|
||||
reg.register(p)
|
||||
assert reg.unregister("test1") is True
|
||||
assert reg.get("test1") is None
|
||||
|
||||
def test_unregister_missing(self):
|
||||
reg = PlatformRegistry()
|
||||
assert reg.unregister("nope") is False
|
||||
|
||||
def test_list_platforms(self):
|
||||
reg = PlatformRegistry()
|
||||
reg.register(_FakePlatform("a"))
|
||||
reg.register(_FakePlatform("b"))
|
||||
statuses = reg.list_platforms()
|
||||
assert len(statuses) == 2
|
||||
names = {s.platform for s in statuses}
|
||||
assert names == {"a", "b"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_all(self):
|
||||
reg = PlatformRegistry()
|
||||
reg.register(_FakePlatform("x"))
|
||||
reg.register(_FakePlatform("y"))
|
||||
results = await reg.start_all()
|
||||
assert results == {"x": True, "y": True}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_all(self):
|
||||
reg = PlatformRegistry()
|
||||
p = _FakePlatform("z")
|
||||
reg.register(p)
|
||||
await reg.start_all()
|
||||
assert p.state == PlatformState.CONNECTED
|
||||
await reg.stop_all()
|
||||
assert p.state == PlatformState.DISCONNECTED
|
||||
|
||||
def test_replace_existing(self):
|
||||
reg = PlatformRegistry()
|
||||
p1 = _FakePlatform("dup")
|
||||
p2 = _FakePlatform("dup")
|
||||
reg.register(p1)
|
||||
reg.register(p2)
|
||||
assert reg.get("dup") is p2
|
||||
|
||||
|
||||
# ── InviteParser tests ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestInviteParser:
|
||||
def test_parse_text_discord_gg(self):
|
||||
from chat_bridge.invite_parser import invite_parser
|
||||
|
||||
result = invite_parser.parse_text("Join us at https://discord.gg/abc123!")
|
||||
assert result is not None
|
||||
assert result.code == "abc123"
|
||||
assert result.platform == "discord"
|
||||
assert result.source == "text"
|
||||
|
||||
def test_parse_text_discord_com_invite(self):
|
||||
from chat_bridge.invite_parser import invite_parser
|
||||
|
||||
result = invite_parser.parse_text(
|
||||
"Link: https://discord.com/invite/myServer2024"
|
||||
)
|
||||
assert result is not None
|
||||
assert result.code == "myServer2024"
|
||||
|
||||
def test_parse_text_discordapp(self):
|
||||
from chat_bridge.invite_parser import invite_parser
|
||||
|
||||
result = invite_parser.parse_text(
|
||||
"https://discordapp.com/invite/test-code"
|
||||
)
|
||||
assert result is not None
|
||||
assert result.code == "test-code"
|
||||
|
||||
def test_parse_text_no_invite(self):
|
||||
from chat_bridge.invite_parser import invite_parser
|
||||
|
||||
result = invite_parser.parse_text("Hello world, no links here")
|
||||
assert result is None
|
||||
|
||||
def test_parse_text_bare_discord_gg(self):
|
||||
from chat_bridge.invite_parser import invite_parser
|
||||
|
||||
result = invite_parser.parse_text("discord.gg/xyz789")
|
||||
assert result is not None
|
||||
assert result.code == "xyz789"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_image_no_deps(self):
|
||||
"""parse_image returns None when pyzbar/Pillow are not installed."""
|
||||
from chat_bridge.invite_parser import InviteParser
|
||||
|
||||
parser = InviteParser()
|
||||
# With mocked pyzbar, this should gracefully return None
|
||||
result = await parser.parse_image(b"fake-image-bytes")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestExtractDiscordCode:
|
||||
def test_various_formats(self):
|
||||
from chat_bridge.invite_parser import _extract_discord_code
|
||||
|
||||
assert _extract_discord_code("discord.gg/abc") == "abc"
|
||||
assert _extract_discord_code("https://discord.gg/test") == "test"
|
||||
assert _extract_discord_code("http://discord.gg/http") == "http"
|
||||
assert _extract_discord_code("discord.com/invite/xyz") == "xyz"
|
||||
assert _extract_discord_code("no link here") is None
|
||||
assert _extract_discord_code("") is None
|
||||
225
tests/test_discord_vendor.py
Normal file
225
tests/test_discord_vendor.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Tests for the Discord vendor and dashboard routes."""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from chat_bridge.base import PlatformState
|
||||
|
||||
|
||||
# ── DiscordVendor unit tests ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestDiscordVendor:
|
||||
def test_name(self):
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
assert vendor.name == "discord"
|
||||
|
||||
def test_initial_state(self):
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
assert vendor.state == PlatformState.DISCONNECTED
|
||||
|
||||
def test_status_disconnected(self):
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
status = vendor.status()
|
||||
assert status.platform == "discord"
|
||||
assert status.state == PlatformState.DISCONNECTED
|
||||
assert status.token_set is False
|
||||
assert status.guild_count == 0
|
||||
|
||||
def test_save_and_load_token(self, tmp_path, monkeypatch):
|
||||
from chat_bridge.vendors import discord as discord_mod
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
state_file = tmp_path / "discord_state.json"
|
||||
monkeypatch.setattr(discord_mod, "_STATE_FILE", state_file)
|
||||
|
||||
vendor = DiscordVendor()
|
||||
vendor.save_token("test-token-abc")
|
||||
|
||||
assert state_file.exists()
|
||||
data = json.loads(state_file.read_text())
|
||||
assert data["token"] == "test-token-abc"
|
||||
|
||||
loaded = vendor.load_token()
|
||||
assert loaded == "test-token-abc"
|
||||
|
||||
def test_load_token_missing_file(self, tmp_path, monkeypatch):
|
||||
from chat_bridge.vendors import discord as discord_mod
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
state_file = tmp_path / "nonexistent.json"
|
||||
monkeypatch.setattr(discord_mod, "_STATE_FILE", state_file)
|
||||
|
||||
vendor = DiscordVendor()
|
||||
# Falls back to config.settings.discord_token
|
||||
token = vendor.load_token()
|
||||
# Default discord_token is "" which becomes None
|
||||
assert token is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_no_token(self):
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
result = await vendor.start(token=None)
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_import_error(self):
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
# Simulate discord.py not installed by making import fail
|
||||
with patch.dict("sys.modules", {"discord": None}):
|
||||
result = await vendor.start(token="fake-token")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_when_disconnected(self):
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
# Should not raise
|
||||
await vendor.stop()
|
||||
assert vendor.state == PlatformState.DISCONNECTED
|
||||
|
||||
def test_get_oauth2_url_no_client(self):
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
assert vendor.get_oauth2_url() is None
|
||||
|
||||
def test_get_oauth2_url_with_client(self):
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
mock_client = MagicMock()
|
||||
mock_client.user.id = 123456789
|
||||
vendor._client = mock_client
|
||||
url = vendor.get_oauth2_url()
|
||||
assert "123456789" in url
|
||||
assert "oauth2/authorize" in url
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_not_connected(self):
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
result = await vendor.send_message("123", "hello")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_thread_not_connected(self):
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
result = await vendor.create_thread("123", "Test Thread")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_from_invite_not_connected(self):
|
||||
from chat_bridge.vendors.discord import DiscordVendor
|
||||
|
||||
vendor = DiscordVendor()
|
||||
result = await vendor.join_from_invite("abc123")
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestChunkMessage:
|
||||
def test_short_message(self):
|
||||
from chat_bridge.vendors.discord import _chunk_message
|
||||
|
||||
chunks = _chunk_message("Hello!", 2000)
|
||||
assert chunks == ["Hello!"]
|
||||
|
||||
def test_long_message(self):
|
||||
from chat_bridge.vendors.discord import _chunk_message
|
||||
|
||||
text = "a" * 5000
|
||||
chunks = _chunk_message(text, 2000)
|
||||
assert len(chunks) == 3
|
||||
assert all(len(c) <= 2000 for c in chunks)
|
||||
assert "".join(chunks) == text
|
||||
|
||||
def test_split_at_newline(self):
|
||||
from chat_bridge.vendors.discord import _chunk_message
|
||||
|
||||
text = "Line1\n" + "x" * 1990 + "\nLine3"
|
||||
chunks = _chunk_message(text, 2000)
|
||||
assert len(chunks) >= 2
|
||||
assert chunks[0].startswith("Line1")
|
||||
|
||||
|
||||
# ── Discord route tests ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestDiscordRoutes:
|
||||
def test_status_endpoint(self, client):
|
||||
resp = client.get("/discord/status")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["platform"] == "discord"
|
||||
assert "connected" in data
|
||||
|
||||
def test_setup_empty_token(self, client):
|
||||
resp = client.post("/discord/setup", json={"token": ""})
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["ok"] is False
|
||||
assert "empty" in data["error"].lower()
|
||||
|
||||
def test_setup_with_token(self, client):
|
||||
"""Setup with a token — bot won't actually connect but route works."""
|
||||
with patch(
|
||||
"chat_bridge.vendors.discord.DiscordVendor.start",
|
||||
new_callable=AsyncMock,
|
||||
return_value=False,
|
||||
):
|
||||
resp = client.post(
|
||||
"/discord/setup", json={"token": "fake-token-123"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
# Will fail because discord.py is mocked, but route handles it
|
||||
assert "ok" in data
|
||||
|
||||
def test_join_no_input(self, client):
|
||||
resp = client.post("/discord/join")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["ok"] is False
|
||||
assert "no discord invite" in data["error"].lower()
|
||||
|
||||
def test_join_with_text_invite(self, client):
|
||||
with patch(
|
||||
"chat_bridge.vendors.discord.DiscordVendor.join_from_invite",
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
):
|
||||
resp = client.post(
|
||||
"/discord/join",
|
||||
data={"invite_url": "https://discord.gg/testcode"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["ok"] is True
|
||||
assert data["invite"]["code"] == "testcode"
|
||||
assert data["invite"]["source"] == "text"
|
||||
|
||||
def test_oauth_url_not_connected(self, client):
|
||||
from chat_bridge.vendors.discord import discord_bot
|
||||
|
||||
# Reset singleton so it has no client
|
||||
discord_bot._client = None
|
||||
resp = client.get("/discord/oauth-url")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["ok"] is False
|
||||
170
tests/test_docker_runner.py
Normal file
170
tests/test_docker_runner.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""Functional tests for swarm.docker_runner — Docker container lifecycle.
|
||||
|
||||
All subprocess calls are mocked so Docker is not required.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
import pytest
|
||||
|
||||
from swarm.docker_runner import DockerAgentRunner, ManagedContainer
|
||||
|
||||
|
||||
class TestDockerAgentRunner:
|
||||
"""Test container spawn/stop/list lifecycle."""
|
||||
|
||||
def test_init_defaults(self):
|
||||
runner = DockerAgentRunner()
|
||||
assert runner.image == "timmy-time:latest"
|
||||
assert runner.coordinator_url == "http://dashboard:8000"
|
||||
assert runner.extra_env == {}
|
||||
assert runner._containers == {}
|
||||
|
||||
def test_init_custom(self):
|
||||
runner = DockerAgentRunner(
|
||||
image="custom:v2",
|
||||
coordinator_url="http://host:9000",
|
||||
extra_env={"FOO": "bar"},
|
||||
)
|
||||
assert runner.image == "custom:v2"
|
||||
assert runner.coordinator_url == "http://host:9000"
|
||||
assert runner.extra_env == {"FOO": "bar"}
|
||||
|
||||
@patch("swarm.docker_runner.subprocess.run")
|
||||
def test_spawn_success(self, mock_run):
|
||||
mock_run.return_value = MagicMock(
|
||||
returncode=0, stdout="abc123container\n", stderr=""
|
||||
)
|
||||
runner = DockerAgentRunner()
|
||||
info = runner.spawn("Echo", agent_id="test-id-1234", capabilities="summarise")
|
||||
|
||||
assert info["container_id"] == "abc123container"
|
||||
assert info["agent_id"] == "test-id-1234"
|
||||
assert info["name"] == "Echo"
|
||||
assert info["capabilities"] == "summarise"
|
||||
assert "abc123container" in runner._containers
|
||||
|
||||
# Verify docker command structure
|
||||
cmd = mock_run.call_args[0][0]
|
||||
assert cmd[0] == "docker"
|
||||
assert cmd[1] == "run"
|
||||
assert "--detach" in cmd
|
||||
assert "--name" in cmd
|
||||
assert "timmy-time:latest" in cmd
|
||||
|
||||
@patch("swarm.docker_runner.subprocess.run")
|
||||
def test_spawn_generates_uuid_when_no_agent_id(self, mock_run):
|
||||
mock_run.return_value = MagicMock(returncode=0, stdout="cid\n", stderr="")
|
||||
runner = DockerAgentRunner()
|
||||
info = runner.spawn("Echo")
|
||||
# agent_id should be a valid UUID-like string
|
||||
assert len(info["agent_id"]) == 36 # UUID format
|
||||
|
||||
@patch("swarm.docker_runner.subprocess.run")
|
||||
def test_spawn_custom_image(self, mock_run):
|
||||
mock_run.return_value = MagicMock(returncode=0, stdout="cid\n", stderr="")
|
||||
runner = DockerAgentRunner()
|
||||
info = runner.spawn("Echo", image="custom:latest")
|
||||
assert info["image"] == "custom:latest"
|
||||
|
||||
@patch("swarm.docker_runner.subprocess.run")
|
||||
def test_spawn_docker_error(self, mock_run):
|
||||
mock_run.return_value = MagicMock(
|
||||
returncode=1, stdout="", stderr="no such image"
|
||||
)
|
||||
runner = DockerAgentRunner()
|
||||
with pytest.raises(RuntimeError, match="no such image"):
|
||||
runner.spawn("Echo")
|
||||
|
||||
@patch("swarm.docker_runner.subprocess.run", side_effect=FileNotFoundError)
|
||||
def test_spawn_docker_not_installed(self, mock_run):
|
||||
runner = DockerAgentRunner()
|
||||
with pytest.raises(RuntimeError, match="Docker CLI not found"):
|
||||
runner.spawn("Echo")
|
||||
|
||||
@patch("swarm.docker_runner.subprocess.run")
|
||||
def test_stop_success(self, mock_run):
|
||||
mock_run.return_value = MagicMock(returncode=0, stdout="cid\n", stderr="")
|
||||
runner = DockerAgentRunner()
|
||||
# Spawn first
|
||||
runner.spawn("Echo", agent_id="a1")
|
||||
cid = list(runner._containers.keys())[0]
|
||||
|
||||
mock_run.reset_mock()
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
|
||||
assert runner.stop(cid) is True
|
||||
assert cid not in runner._containers
|
||||
# Verify docker rm -f was called
|
||||
rm_cmd = mock_run.call_args[0][0]
|
||||
assert rm_cmd[0] == "docker"
|
||||
assert rm_cmd[1] == "rm"
|
||||
assert "-f" in rm_cmd
|
||||
|
||||
@patch("swarm.docker_runner.subprocess.run", side_effect=Exception("fail"))
|
||||
def test_stop_failure(self, mock_run):
|
||||
runner = DockerAgentRunner()
|
||||
runner._containers["fake"] = ManagedContainer(
|
||||
container_id="fake", agent_id="a", name="X", image="img"
|
||||
)
|
||||
assert runner.stop("fake") is False
|
||||
|
||||
@patch("swarm.docker_runner.subprocess.run")
|
||||
def test_stop_all(self, mock_run):
|
||||
# Return different container IDs so they don't overwrite each other
|
||||
mock_run.side_effect = [
|
||||
MagicMock(returncode=0, stdout="cid_a\n", stderr=""),
|
||||
MagicMock(returncode=0, stdout="cid_b\n", stderr=""),
|
||||
]
|
||||
runner = DockerAgentRunner()
|
||||
runner.spawn("A", agent_id="a1")
|
||||
runner.spawn("B", agent_id="a2")
|
||||
assert len(runner._containers) == 2
|
||||
|
||||
mock_run.side_effect = None
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
stopped = runner.stop_all()
|
||||
assert stopped == 2
|
||||
assert len(runner._containers) == 0
|
||||
|
||||
@patch("swarm.docker_runner.subprocess.run")
|
||||
def test_list_containers(self, mock_run):
|
||||
mock_run.return_value = MagicMock(returncode=0, stdout="cid\n", stderr="")
|
||||
runner = DockerAgentRunner()
|
||||
runner.spawn("Echo", agent_id="e1")
|
||||
containers = runner.list_containers()
|
||||
assert len(containers) == 1
|
||||
assert containers[0].name == "Echo"
|
||||
|
||||
@patch("swarm.docker_runner.subprocess.run")
|
||||
def test_is_running_true(self, mock_run):
|
||||
mock_run.return_value = MagicMock(returncode=0, stdout="true\n", stderr="")
|
||||
runner = DockerAgentRunner()
|
||||
assert runner.is_running("somecid") is True
|
||||
|
||||
@patch("swarm.docker_runner.subprocess.run")
|
||||
def test_is_running_false(self, mock_run):
|
||||
mock_run.return_value = MagicMock(returncode=0, stdout="false\n", stderr="")
|
||||
runner = DockerAgentRunner()
|
||||
assert runner.is_running("somecid") is False
|
||||
|
||||
@patch("swarm.docker_runner.subprocess.run", side_effect=Exception("timeout"))
|
||||
def test_is_running_exception(self, mock_run):
|
||||
runner = DockerAgentRunner()
|
||||
assert runner.is_running("somecid") is False
|
||||
|
||||
@patch("swarm.docker_runner.subprocess.run")
|
||||
def test_build_env_flags(self, mock_run):
|
||||
runner = DockerAgentRunner(extra_env={"CUSTOM": "val"})
|
||||
flags = runner._build_env_flags("agent-1", "Echo", "summarise")
|
||||
# Should contain pairs of --env KEY=VALUE
|
||||
env_dict = {}
|
||||
for i, f in enumerate(flags):
|
||||
if f == "--env" and i + 1 < len(flags):
|
||||
k, v = flags[i + 1].split("=", 1)
|
||||
env_dict[k] = v
|
||||
assert env_dict["COORDINATOR_URL"] == "http://dashboard:8000"
|
||||
assert env_dict["AGENT_NAME"] == "Echo"
|
||||
assert env_dict["AGENT_ID"] == "agent-1"
|
||||
assert env_dict["AGENT_CAPABILITIES"] == "summarise"
|
||||
assert env_dict["CUSTOM"] == "val"
|
||||
129
tests/test_lnd_backend.py
Normal file
129
tests/test_lnd_backend.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""Functional tests for lightning.lnd_backend — LND gRPC backend.
|
||||
|
||||
gRPC is stubbed via sys.modules; tests verify initialization, error
|
||||
handling, and the placeholder method behavior.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from lightning.base import (
|
||||
BackendNotAvailableError,
|
||||
Invoice,
|
||||
LightningError,
|
||||
)
|
||||
|
||||
|
||||
def _make_grpc_mock():
|
||||
"""Create a mock grpc module with required attributes."""
|
||||
mock_grpc = MagicMock()
|
||||
mock_grpc.StatusCode.NOT_FOUND = "NOT_FOUND"
|
||||
mock_grpc.RpcError = type("RpcError", (Exception,), {
|
||||
"code": lambda self: "NOT_FOUND",
|
||||
"details": lambda self: "mocked error",
|
||||
})
|
||||
return mock_grpc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lnd_module():
|
||||
"""Reload lnd_backend with grpc stubbed so GRPC_AVAILABLE=True."""
|
||||
grpc_mock = _make_grpc_mock()
|
||||
old = sys.modules.get("grpc")
|
||||
sys.modules["grpc"] = grpc_mock
|
||||
try:
|
||||
import lightning.lnd_backend as mod
|
||||
importlib.reload(mod)
|
||||
yield mod
|
||||
finally:
|
||||
if old is not None:
|
||||
sys.modules["grpc"] = old
|
||||
else:
|
||||
sys.modules.pop("grpc", None)
|
||||
# Reload to restore original state
|
||||
import lightning.lnd_backend as mod2
|
||||
importlib.reload(mod2)
|
||||
|
||||
|
||||
class TestLndBackendInit:
|
||||
def test_init_with_explicit_params(self, lnd_module):
|
||||
backend = lnd_module.LndBackend(
|
||||
host="localhost:10009",
|
||||
tls_cert_path="/fake/tls.cert",
|
||||
macaroon_path="/fake/admin.macaroon",
|
||||
verify_ssl=True,
|
||||
)
|
||||
assert backend._host == "localhost:10009"
|
||||
assert backend._tls_cert_path == "/fake/tls.cert"
|
||||
assert backend._macaroon_path == "/fake/admin.macaroon"
|
||||
assert backend._verify_ssl is True
|
||||
|
||||
def test_init_from_env_vars(self, lnd_module):
|
||||
env = {
|
||||
"LND_GRPC_HOST": "remote:9999",
|
||||
"LND_TLS_CERT_PATH": "/env/tls.cert",
|
||||
"LND_MACAROON_PATH": "/env/macaroon",
|
||||
"LND_VERIFY_SSL": "false",
|
||||
}
|
||||
with patch.dict(os.environ, env):
|
||||
backend = lnd_module.LndBackend()
|
||||
assert backend._host == "remote:9999"
|
||||
assert backend._verify_ssl is False
|
||||
|
||||
def test_init_raises_without_grpc(self):
|
||||
from lightning.lnd_backend import LndBackend
|
||||
with pytest.raises(LightningError, match="grpcio not installed"):
|
||||
LndBackend()
|
||||
|
||||
def test_name_is_lnd(self, lnd_module):
|
||||
assert lnd_module.LndBackend.name == "lnd"
|
||||
|
||||
def test_grpc_available_true_after_reload(self, lnd_module):
|
||||
assert lnd_module.GRPC_AVAILABLE is True
|
||||
|
||||
|
||||
class TestLndBackendMethods:
|
||||
@pytest.fixture
|
||||
def backend(self, lnd_module):
|
||||
return lnd_module.LndBackend(
|
||||
host="localhost:10009",
|
||||
macaroon_path="/fake/path",
|
||||
)
|
||||
|
||||
def test_check_stub_raises_not_available(self, backend):
|
||||
"""_check_stub should raise BackendNotAvailableError when stub is None."""
|
||||
with pytest.raises(BackendNotAvailableError, match="not fully implemented"):
|
||||
backend._check_stub()
|
||||
|
||||
def test_create_invoice_raises_not_available(self, backend):
|
||||
with pytest.raises(BackendNotAvailableError):
|
||||
backend.create_invoice(1000, memo="test")
|
||||
|
||||
def test_check_payment_raises_not_available(self, backend):
|
||||
with pytest.raises(BackendNotAvailableError):
|
||||
backend.check_payment("abc123")
|
||||
|
||||
def test_get_invoice_raises_not_available(self, backend):
|
||||
with pytest.raises(BackendNotAvailableError):
|
||||
backend.get_invoice("abc123")
|
||||
|
||||
def test_settle_invoice_returns_false(self, backend):
|
||||
"""LND auto-settles, so manual settle always returns False."""
|
||||
result = backend.settle_invoice("hash", "preimage")
|
||||
assert result is False
|
||||
|
||||
def test_list_invoices_raises_not_available(self, backend):
|
||||
with pytest.raises(BackendNotAvailableError):
|
||||
backend.list_invoices()
|
||||
|
||||
def test_get_balance_raises_not_available(self, backend):
|
||||
with pytest.raises(BackendNotAvailableError):
|
||||
backend.get_balance_sats()
|
||||
|
||||
def test_health_check_raises_not_available(self, backend):
|
||||
with pytest.raises(BackendNotAvailableError):
|
||||
backend.health_check()
|
||||
70
tests/test_routes_tools.py
Normal file
70
tests/test_routes_tools.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Functional tests for dashboard routes: /tools and /swarm/live WebSocket.
|
||||
|
||||
Tests the tools dashboard page, API stats endpoint, and the swarm
|
||||
WebSocket live endpoint.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
# ── /tools route ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestToolsPage:
|
||||
def test_tools_page_returns_200(self, client):
|
||||
response = client.get("/tools")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_tools_page_html_content(self, client):
|
||||
response = client.get("/tools")
|
||||
assert "text/html" in response.headers["content-type"]
|
||||
|
||||
def test_tools_api_stats_returns_json(self, client):
|
||||
response = client.get("/tools/api/stats")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "all_stats" in data
|
||||
assert "available_tools" in data
|
||||
assert isinstance(data["available_tools"], list)
|
||||
assert len(data["available_tools"]) > 0
|
||||
|
||||
def test_tools_api_stats_includes_base_tools(self, client):
|
||||
response = client.get("/tools/api/stats")
|
||||
data = response.json()
|
||||
base_tools = {"web_search", "shell", "python", "read_file", "write_file", "list_files"}
|
||||
for tool in base_tools:
|
||||
assert tool in data["available_tools"], f"Missing: {tool}"
|
||||
|
||||
def test_tools_page_with_agents(self, client):
|
||||
"""Spawn an agent and verify tools page includes it."""
|
||||
client.post("/swarm/spawn", data={"name": "Echo"})
|
||||
response = client.get("/tools")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ── /swarm/live WebSocket ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSwarmWebSocket:
|
||||
def test_websocket_connect_disconnect(self, client):
|
||||
with client.websocket_connect("/swarm/live") as ws:
|
||||
# Connection succeeds
|
||||
pass
|
||||
# Disconnect on context manager exit
|
||||
|
||||
def test_websocket_send_receive(self, client):
|
||||
"""The WebSocket endpoint should accept messages (it logs them)."""
|
||||
with client.websocket_connect("/swarm/live") as ws:
|
||||
ws.send_text("ping")
|
||||
# The endpoint only echoes via logging, not back to client.
|
||||
# The key test is that it doesn't crash on receiving a message.
|
||||
|
||||
def test_websocket_multiple_connections(self, client):
|
||||
"""Multiple clients can connect simultaneously."""
|
||||
with client.websocket_connect("/swarm/live") as ws1:
|
||||
with client.websocket_connect("/swarm/live") as ws2:
|
||||
ws1.send_text("hello from 1")
|
||||
ws2.send_text("hello from 2")
|
||||
242
tests/test_swarm_routes_functional.py
Normal file
242
tests/test_swarm_routes_functional.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""Functional tests for swarm routes — /swarm/* endpoints.
|
||||
|
||||
Tests the full request/response cycle for swarm management endpoints,
|
||||
including error paths and HTMX partial rendering.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch, AsyncMock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
class TestSwarmStatusRoutes:
|
||||
def test_swarm_status(self, client):
|
||||
response = client.get("/swarm")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "agents" in data or "status" in data or isinstance(data, dict)
|
||||
|
||||
def test_list_agents_empty(self, client):
|
||||
response = client.get("/swarm/agents")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "agents" in data
|
||||
assert isinstance(data["agents"], list)
|
||||
|
||||
|
||||
class TestSwarmAgentLifecycle:
|
||||
def test_spawn_agent(self, client):
|
||||
response = client.post("/swarm/spawn", data={"name": "Echo"})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "id" in data or "agent_id" in data or "name" in data
|
||||
|
||||
def test_spawn_and_list(self, client):
|
||||
client.post("/swarm/spawn", data={"name": "Echo"})
|
||||
response = client.get("/swarm/agents")
|
||||
data = response.json()
|
||||
assert len(data["agents"]) >= 1
|
||||
names = [a["name"] for a in data["agents"]]
|
||||
assert "Echo" in names
|
||||
|
||||
def test_stop_agent(self, client):
|
||||
spawn_resp = client.post("/swarm/spawn", data={"name": "TestAgent"})
|
||||
spawn_data = spawn_resp.json()
|
||||
agent_id = spawn_data.get("id") or spawn_data.get("agent_id")
|
||||
response = client.delete(f"/swarm/agents/{agent_id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["stopped"] is True
|
||||
|
||||
def test_stop_nonexistent_agent(self, client):
|
||||
response = client.delete("/swarm/agents/nonexistent-id")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["stopped"] is False
|
||||
|
||||
|
||||
class TestSwarmTaskLifecycle:
|
||||
def test_post_task(self, client):
|
||||
response = client.post("/swarm/tasks", data={"description": "Summarise readme"})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["description"] == "Summarise readme"
|
||||
assert data["status"] == "bidding" # coordinator auto-opens auction
|
||||
assert "task_id" in data
|
||||
|
||||
def test_list_tasks(self, client):
|
||||
client.post("/swarm/tasks", data={"description": "Task A"})
|
||||
client.post("/swarm/tasks", data={"description": "Task B"})
|
||||
response = client.get("/swarm/tasks")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["tasks"]) >= 2
|
||||
|
||||
def test_list_tasks_filter_by_status(self, client):
|
||||
client.post("/swarm/tasks", data={"description": "Bidding task"})
|
||||
response = client.get("/swarm/tasks?status=bidding")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
for task in data["tasks"]:
|
||||
assert task["status"] == "bidding"
|
||||
|
||||
def test_list_tasks_invalid_status(self, client):
|
||||
"""Invalid TaskStatus enum value causes server error (unhandled ValueError)."""
|
||||
with pytest.raises(ValueError, match="is not a valid TaskStatus"):
|
||||
client.get("/swarm/tasks?status=invalid_status")
|
||||
|
||||
def test_get_task_by_id(self, client):
|
||||
post_resp = client.post("/swarm/tasks", data={"description": "Find me"})
|
||||
task_id = post_resp.json()["task_id"]
|
||||
response = client.get(f"/swarm/tasks/{task_id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["description"] == "Find me"
|
||||
|
||||
def test_get_nonexistent_task(self, client):
|
||||
response = client.get("/swarm/tasks/nonexistent-id")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "error" in data
|
||||
|
||||
def test_complete_task(self, client):
|
||||
# Create and assign a task first
|
||||
client.post("/swarm/spawn", data={"name": "Worker"})
|
||||
post_resp = client.post("/swarm/tasks", data={"description": "Do work"})
|
||||
task_id = post_resp.json()["task_id"]
|
||||
response = client.post(
|
||||
f"/swarm/tasks/{task_id}/complete",
|
||||
data={"result": "Work done"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "completed"
|
||||
|
||||
def test_complete_nonexistent_task(self, client):
|
||||
response = client.post(
|
||||
"/swarm/tasks/fake-id/complete",
|
||||
data={"result": "done"},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_fail_task(self, client):
|
||||
post_resp = client.post("/swarm/tasks", data={"description": "Will fail"})
|
||||
task_id = post_resp.json()["task_id"]
|
||||
response = client.post(
|
||||
f"/swarm/tasks/{task_id}/fail",
|
||||
data={"reason": "out of memory"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "failed"
|
||||
|
||||
def test_fail_nonexistent_task(self, client):
|
||||
response = client.post(
|
||||
"/swarm/tasks/fake-id/fail",
|
||||
data={"reason": "no reason"},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
class TestSwarmAuction:
|
||||
def test_post_task_and_auction_no_agents(self, client):
|
||||
"""Auction with no bidders should still return a response."""
|
||||
with patch(
|
||||
"swarm.coordinator.AUCTION_DURATION_SECONDS", 0
|
||||
):
|
||||
response = client.post(
|
||||
"/swarm/tasks/auction",
|
||||
data={"description": "Quick task"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "task_id" in data
|
||||
|
||||
|
||||
class TestSwarmInsights:
|
||||
def test_insights_empty(self, client):
|
||||
response = client.get("/swarm/insights")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "agents" in data
|
||||
|
||||
def test_agent_insights(self, client):
|
||||
response = client.get("/swarm/insights/some-agent-id")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["agent_id"] == "some-agent-id"
|
||||
assert "total_bids" in data
|
||||
assert "win_rate" in data
|
||||
|
||||
|
||||
class TestSwarmUIPartials:
|
||||
def test_live_page(self, client):
|
||||
response = client.get("/swarm/live")
|
||||
assert response.status_code == 200
|
||||
assert "text/html" in response.headers["content-type"]
|
||||
|
||||
def test_agents_sidebar(self, client):
|
||||
response = client.get("/swarm/agents/sidebar")
|
||||
assert response.status_code == 200
|
||||
assert "text/html" in response.headers["content-type"]
|
||||
|
||||
def test_agent_panel_not_found(self, client):
|
||||
response = client.get("/swarm/agents/nonexistent/panel")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_agent_panel_found(self, client):
|
||||
spawn_resp = client.post("/swarm/spawn", data={"name": "Echo"})
|
||||
agent_id = spawn_resp.json().get("id") or spawn_resp.json().get("agent_id")
|
||||
response = client.get(f"/swarm/agents/{agent_id}/panel")
|
||||
assert response.status_code == 200
|
||||
assert "text/html" in response.headers["content-type"]
|
||||
|
||||
def test_task_panel_route_shadowed(self, client):
|
||||
"""The /swarm/tasks/panel route is shadowed by /swarm/tasks/{task_id}.
|
||||
|
||||
FastAPI matches the dynamic {task_id} route first, so "panel" is
|
||||
treated as a task_id lookup, returning JSON with an error.
|
||||
This documents the current behavior (a routing order issue).
|
||||
"""
|
||||
response = client.get("/swarm/tasks/panel")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "error" in data
|
||||
|
||||
def test_direct_assign_with_agent(self, client):
|
||||
spawn_resp = client.post("/swarm/spawn", data={"name": "Worker"})
|
||||
agent_id = spawn_resp.json().get("id") or spawn_resp.json().get("agent_id")
|
||||
response = client.post(
|
||||
"/swarm/tasks/direct",
|
||||
data={"description": "Direct task", "agent_id": agent_id},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert "text/html" in response.headers["content-type"]
|
||||
|
||||
def test_direct_assign_without_agent(self, client):
|
||||
"""No agent → runs auction (with no bidders)."""
|
||||
with patch("swarm.coordinator.AUCTION_DURATION_SECONDS", 0):
|
||||
response = client.post(
|
||||
"/swarm/tasks/direct",
|
||||
data={"description": "Open task"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_message_agent_creates_task(self, client):
|
||||
"""Messaging a non-Timmy agent creates and assigns a task."""
|
||||
spawn_resp = client.post("/swarm/spawn", data={"name": "Echo"})
|
||||
agent_id = spawn_resp.json().get("id") or spawn_resp.json().get("agent_id")
|
||||
response = client.post(
|
||||
f"/swarm/agents/{agent_id}/message",
|
||||
data={"message": "Summarise the readme"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert "text/html" in response.headers["content-type"]
|
||||
|
||||
def test_message_nonexistent_agent(self, client):
|
||||
response = client.post(
|
||||
"/swarm/agents/fake-id/message",
|
||||
data={"message": "hello"},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
169
tests/test_timmy_tools.py
Normal file
169
tests/test_timmy_tools.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""Functional tests for timmy.tools — tool tracking, persona toolkits, catalog.
|
||||
|
||||
Covers tool usage statistics, persona-to-toolkit mapping, catalog generation,
|
||||
and graceful degradation when Agno is unavailable.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from timmy.tools import (
|
||||
_TOOL_USAGE,
|
||||
_track_tool_usage,
|
||||
get_tool_stats,
|
||||
get_tools_for_persona,
|
||||
get_all_available_tools,
|
||||
PERSONA_TOOLKITS,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_usage():
|
||||
"""Clear tool usage tracking between tests."""
|
||||
_TOOL_USAGE.clear()
|
||||
yield
|
||||
_TOOL_USAGE.clear()
|
||||
|
||||
|
||||
# ── Tool usage tracking ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestToolTracking:
|
||||
def test_track_creates_agent_entry(self):
|
||||
_track_tool_usage("agent-1", "web_search", success=True)
|
||||
assert "agent-1" in _TOOL_USAGE
|
||||
assert len(_TOOL_USAGE["agent-1"]) == 1
|
||||
|
||||
def test_track_records_metadata(self):
|
||||
_track_tool_usage("agent-1", "shell", success=False)
|
||||
entry = _TOOL_USAGE["agent-1"][0]
|
||||
assert entry["tool"] == "shell"
|
||||
assert entry["success"] is False
|
||||
assert "timestamp" in entry
|
||||
|
||||
def test_track_multiple_calls(self):
|
||||
_track_tool_usage("a1", "search")
|
||||
_track_tool_usage("a1", "read")
|
||||
_track_tool_usage("a1", "search")
|
||||
assert len(_TOOL_USAGE["a1"]) == 3
|
||||
|
||||
def test_track_multiple_agents(self):
|
||||
_track_tool_usage("a1", "search")
|
||||
_track_tool_usage("a2", "shell")
|
||||
assert len(_TOOL_USAGE) == 2
|
||||
|
||||
|
||||
class TestGetToolStats:
|
||||
def test_stats_for_specific_agent(self):
|
||||
_track_tool_usage("a1", "search")
|
||||
_track_tool_usage("a1", "read")
|
||||
_track_tool_usage("a1", "search")
|
||||
stats = get_tool_stats("a1")
|
||||
assert stats["agent_id"] == "a1"
|
||||
assert stats["total_calls"] == 3
|
||||
assert set(stats["tools_used"]) == {"search", "read"}
|
||||
assert len(stats["recent_calls"]) == 3
|
||||
|
||||
def test_stats_for_unknown_agent(self):
|
||||
stats = get_tool_stats("nonexistent")
|
||||
assert stats["total_calls"] == 0
|
||||
assert stats["tools_used"] == []
|
||||
assert stats["recent_calls"] == []
|
||||
|
||||
def test_stats_recent_capped_at_10(self):
|
||||
for i in range(15):
|
||||
_track_tool_usage("a1", f"tool_{i}")
|
||||
stats = get_tool_stats("a1")
|
||||
assert len(stats["recent_calls"]) == 10
|
||||
|
||||
def test_stats_all_agents(self):
|
||||
_track_tool_usage("a1", "search")
|
||||
_track_tool_usage("a2", "shell")
|
||||
_track_tool_usage("a2", "read")
|
||||
stats = get_tool_stats()
|
||||
assert "a1" in stats
|
||||
assert "a2" in stats
|
||||
assert stats["a1"]["total_calls"] == 1
|
||||
assert stats["a2"]["total_calls"] == 2
|
||||
|
||||
def test_stats_empty(self):
|
||||
stats = get_tool_stats()
|
||||
assert stats == {}
|
||||
|
||||
|
||||
# ── Persona toolkit mapping ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPersonaToolkits:
|
||||
def test_all_expected_personas_present(self):
|
||||
expected = {"echo", "mace", "helm", "seer", "forge", "quill", "pixel", "lyra", "reel"}
|
||||
assert set(PERSONA_TOOLKITS.keys()) == expected
|
||||
|
||||
def test_get_tools_for_known_persona_raises_without_agno(self):
|
||||
"""Agno is mocked but not a real package, so create_*_tools raises ImportError."""
|
||||
with pytest.raises(ImportError, match="Agno tools not available"):
|
||||
get_tools_for_persona("echo")
|
||||
|
||||
def test_get_tools_for_unknown_persona(self):
|
||||
result = get_tools_for_persona("nonexistent")
|
||||
assert result is None
|
||||
|
||||
def test_creative_personas_return_none(self):
|
||||
"""Creative personas (pixel, lyra, reel) use stub toolkits that
|
||||
return None when Agno is unavailable."""
|
||||
for persona_id in ("pixel", "lyra", "reel"):
|
||||
result = get_tools_for_persona(persona_id)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ── Tool catalog ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestToolCatalog:
|
||||
def test_catalog_contains_base_tools(self):
|
||||
catalog = get_all_available_tools()
|
||||
base_tools = {"web_search", "shell", "python", "read_file", "write_file", "list_files"}
|
||||
for tool_id in base_tools:
|
||||
assert tool_id in catalog, f"Missing base tool: {tool_id}"
|
||||
|
||||
def test_catalog_tool_structure(self):
|
||||
catalog = get_all_available_tools()
|
||||
for tool_id, info in catalog.items():
|
||||
assert "name" in info, f"{tool_id} missing 'name'"
|
||||
assert "description" in info, f"{tool_id} missing 'description'"
|
||||
assert "available_in" in info, f"{tool_id} missing 'available_in'"
|
||||
assert isinstance(info["available_in"], list)
|
||||
|
||||
def test_catalog_timmy_has_all_base_tools(self):
|
||||
catalog = get_all_available_tools()
|
||||
base_tools = {"web_search", "shell", "python", "read_file", "write_file", "list_files"}
|
||||
for tool_id in base_tools:
|
||||
assert "timmy" in catalog[tool_id]["available_in"], (
|
||||
f"Timmy missing tool: {tool_id}"
|
||||
)
|
||||
|
||||
def test_catalog_echo_research_tools(self):
|
||||
catalog = get_all_available_tools()
|
||||
assert "echo" in catalog["web_search"]["available_in"]
|
||||
assert "echo" in catalog["read_file"]["available_in"]
|
||||
# Echo should NOT have shell
|
||||
assert "echo" not in catalog["shell"]["available_in"]
|
||||
|
||||
def test_catalog_forge_code_tools(self):
|
||||
catalog = get_all_available_tools()
|
||||
assert "forge" in catalog["shell"]["available_in"]
|
||||
assert "forge" in catalog["python"]["available_in"]
|
||||
assert "forge" in catalog["write_file"]["available_in"]
|
||||
|
||||
def test_catalog_includes_git_tools(self):
|
||||
catalog = get_all_available_tools()
|
||||
git_tools = [k for k in catalog if "git" in k.lower()]
|
||||
# Should have some git tools from tools.git_tools
|
||||
assert len(git_tools) > 0
|
||||
|
||||
def test_catalog_includes_creative_tools(self):
|
||||
catalog = get_all_available_tools()
|
||||
# Should pick up image, music, video catalogs
|
||||
all_keys = list(catalog.keys())
|
||||
assert len(all_keys) > 6 # more than just base tools
|
||||
155
tests/test_voice_tts_functional.py
Normal file
155
tests/test_voice_tts_functional.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""Functional tests for timmy_serve.voice_tts — TTS engine lifecycle.
|
||||
|
||||
pyttsx3 is not available in CI, so all tests mock the engine.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from unittest.mock import patch, MagicMock, PropertyMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestVoiceTTS:
|
||||
"""Test TTS engine initialization, speak, and configuration."""
|
||||
|
||||
def test_init_success(self):
|
||||
mock_pyttsx3 = MagicMock()
|
||||
mock_engine = MagicMock()
|
||||
mock_pyttsx3.init.return_value = mock_engine
|
||||
|
||||
with patch.dict("sys.modules", {"pyttsx3": mock_pyttsx3}):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
tts = VoiceTTS(rate=200, volume=0.8)
|
||||
assert tts.available is True
|
||||
mock_engine.setProperty.assert_any_call("rate", 200)
|
||||
mock_engine.setProperty.assert_any_call("volume", 0.8)
|
||||
|
||||
def test_init_failure_graceful(self):
|
||||
"""When pyttsx3 import fails, VoiceTTS degrades gracefully."""
|
||||
with patch.dict("sys.modules", {"pyttsx3": None}):
|
||||
from importlib import reload
|
||||
import timmy_serve.voice_tts as mod
|
||||
tts = mod.VoiceTTS.__new__(mod.VoiceTTS)
|
||||
tts._engine = None
|
||||
tts._rate = 175
|
||||
tts._volume = 0.9
|
||||
tts._available = False
|
||||
tts._lock = threading.Lock()
|
||||
assert tts.available is False
|
||||
|
||||
def test_speak_skips_when_unavailable(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
tts._engine = None
|
||||
tts._available = False
|
||||
tts._lock = threading.Lock()
|
||||
# Should not raise
|
||||
tts.speak("hello")
|
||||
|
||||
def test_speak_sync_skips_when_unavailable(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
tts._engine = None
|
||||
tts._available = False
|
||||
tts._lock = threading.Lock()
|
||||
tts.speak_sync("hello")
|
||||
|
||||
def test_speak_calls_engine(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
tts._engine = MagicMock()
|
||||
tts._available = True
|
||||
tts._lock = threading.Lock()
|
||||
|
||||
tts.speak("test speech")
|
||||
# Give the background thread time to execute
|
||||
import time
|
||||
time.sleep(0.1)
|
||||
tts._engine.say.assert_called_with("test speech")
|
||||
|
||||
def test_speak_sync_calls_engine(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
tts._engine = MagicMock()
|
||||
tts._available = True
|
||||
tts._lock = threading.Lock()
|
||||
|
||||
tts.speak_sync("sync test")
|
||||
tts._engine.say.assert_called_with("sync test")
|
||||
tts._engine.runAndWait.assert_called_once()
|
||||
|
||||
def test_set_rate(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
tts._engine = MagicMock()
|
||||
tts._rate = 175
|
||||
|
||||
tts.set_rate(220)
|
||||
assert tts._rate == 220
|
||||
tts._engine.setProperty.assert_called_with("rate", 220)
|
||||
|
||||
def test_set_rate_no_engine(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
tts._engine = None
|
||||
tts._rate = 175
|
||||
tts.set_rate(220)
|
||||
assert tts._rate == 220
|
||||
|
||||
def test_set_volume_clamped(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
tts._engine = MagicMock()
|
||||
tts._volume = 0.9
|
||||
|
||||
tts.set_volume(1.5)
|
||||
assert tts._volume == 1.0
|
||||
|
||||
tts.set_volume(-0.5)
|
||||
assert tts._volume == 0.0
|
||||
|
||||
tts.set_volume(0.7)
|
||||
assert tts._volume == 0.7
|
||||
|
||||
def test_get_voices_no_engine(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
tts._engine = None
|
||||
assert tts.get_voices() == []
|
||||
|
||||
def test_get_voices_with_engine(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
mock_voice = MagicMock()
|
||||
mock_voice.id = "voice1"
|
||||
mock_voice.name = "Default"
|
||||
mock_voice.languages = ["en"]
|
||||
|
||||
tts._engine = MagicMock()
|
||||
tts._engine.getProperty.return_value = [mock_voice]
|
||||
|
||||
voices = tts.get_voices()
|
||||
assert len(voices) == 1
|
||||
assert voices[0]["id"] == "voice1"
|
||||
assert voices[0]["name"] == "Default"
|
||||
assert voices[0]["languages"] == ["en"]
|
||||
|
||||
def test_get_voices_exception(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
tts._engine = MagicMock()
|
||||
tts._engine.getProperty.side_effect = RuntimeError("no voices")
|
||||
assert tts.get_voices() == []
|
||||
|
||||
def test_set_voice(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
tts._engine = MagicMock()
|
||||
tts.set_voice("voice_id_1")
|
||||
tts._engine.setProperty.assert_called_with("voice", "voice_id_1")
|
||||
|
||||
def test_set_voice_no_engine(self):
|
||||
from timmy_serve.voice_tts import VoiceTTS
|
||||
tts = VoiceTTS.__new__(VoiceTTS)
|
||||
tts._engine = None
|
||||
tts.set_voice("voice_id_1") # should not raise
|
||||
100
tests/test_watchdog_functional.py
Normal file
100
tests/test_watchdog_functional.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Functional tests for self_tdd.watchdog — continuous test runner.
|
||||
|
||||
All subprocess calls are mocked to avoid running real pytest.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch, MagicMock, call
|
||||
|
||||
import pytest
|
||||
|
||||
from self_tdd.watchdog import _run_tests, watch
|
||||
|
||||
|
||||
class TestRunTests:
|
||||
@patch("self_tdd.watchdog.subprocess.run")
|
||||
def test_run_tests_passing(self, mock_run):
|
||||
mock_run.return_value = MagicMock(
|
||||
returncode=0,
|
||||
stdout="5 passed\n",
|
||||
stderr="",
|
||||
)
|
||||
passed, output = _run_tests()
|
||||
assert passed is True
|
||||
assert "5 passed" in output
|
||||
|
||||
@patch("self_tdd.watchdog.subprocess.run")
|
||||
def test_run_tests_failing(self, mock_run):
|
||||
mock_run.return_value = MagicMock(
|
||||
returncode=1,
|
||||
stdout="2 failed, 3 passed\n",
|
||||
stderr="ERRORS",
|
||||
)
|
||||
passed, output = _run_tests()
|
||||
assert passed is False
|
||||
assert "2 failed" in output
|
||||
assert "ERRORS" in output
|
||||
|
||||
@patch("self_tdd.watchdog.subprocess.run")
|
||||
def test_run_tests_command_format(self, mock_run):
|
||||
mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="")
|
||||
_run_tests()
|
||||
cmd = mock_run.call_args[0][0]
|
||||
assert "pytest" in " ".join(cmd)
|
||||
assert "tests/" in cmd
|
||||
assert "-q" in cmd
|
||||
assert "--tb=short" in cmd
|
||||
assert mock_run.call_args[1]["capture_output"] is True
|
||||
assert mock_run.call_args[1]["text"] is True
|
||||
|
||||
|
||||
class TestWatch:
|
||||
@patch("self_tdd.watchdog.time.sleep")
|
||||
@patch("self_tdd.watchdog._run_tests")
|
||||
@patch("self_tdd.watchdog.typer")
|
||||
def test_watch_first_pass(self, mock_typer, mock_tests, mock_sleep):
|
||||
"""First iteration: None→passing → should print green message."""
|
||||
call_count = 0
|
||||
|
||||
def side_effect():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count >= 2:
|
||||
raise KeyboardInterrupt
|
||||
return (True, "all good")
|
||||
|
||||
mock_tests.side_effect = side_effect
|
||||
watch(interval=10)
|
||||
# Should have printed green "All tests passing" message
|
||||
mock_typer.secho.assert_called()
|
||||
|
||||
@patch("self_tdd.watchdog.time.sleep")
|
||||
@patch("self_tdd.watchdog._run_tests")
|
||||
@patch("self_tdd.watchdog.typer")
|
||||
def test_watch_regression(self, mock_typer, mock_tests, mock_sleep):
|
||||
"""Regression: passing→failing → should print red message + output."""
|
||||
results = [(True, "ok"), (False, "FAILED: test_foo"), KeyboardInterrupt]
|
||||
idx = 0
|
||||
|
||||
def side_effect():
|
||||
nonlocal idx
|
||||
if idx >= len(results):
|
||||
raise KeyboardInterrupt
|
||||
r = results[idx]
|
||||
idx += 1
|
||||
if isinstance(r, type) and issubclass(r, BaseException):
|
||||
raise r()
|
||||
return r
|
||||
|
||||
mock_tests.side_effect = side_effect
|
||||
watch(interval=5)
|
||||
# Should have printed red "Regression detected" at some point
|
||||
secho_calls = [str(c) for c in mock_typer.secho.call_args_list]
|
||||
assert any("Regression" in c for c in secho_calls) or any("RED" in c for c in secho_calls)
|
||||
|
||||
@patch("self_tdd.watchdog.time.sleep")
|
||||
@patch("self_tdd.watchdog._run_tests")
|
||||
@patch("self_tdd.watchdog.typer")
|
||||
def test_watch_keyboard_interrupt(self, mock_typer, mock_tests, mock_sleep):
|
||||
mock_tests.side_effect = KeyboardInterrupt
|
||||
watch(interval=60)
|
||||
mock_typer.echo.assert_called() # "Watchdog stopped"
|
||||
Reference in New Issue
Block a user