"""
Provider Trait Spike - Python PoC

Based on Claw Code's Provider trait pattern.
"""

import json
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, AsyncIterator, Dict, List, Optional


class MessageRole(Enum):
    """Closed set of roles that a chat ``Message`` can carry.

    Every member's value is the lowercase role string, so a member can be
    looked up from its string form, e.g. ``MessageRole("user")``.
    """

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"


@dataclass
class Message:
    """One conversation entry passed to ``Provider.send_message``.

    Only ``role`` and ``content`` are required. Both tool-related fields
    default to ``None`` when a message carries no tool information.
    """

    role: MessageRole
    content: str
    # Tool calls attached to this message, kept as plain dicts.
    # NOTE(review): presumably only set on ASSISTANT messages -- confirm.
    tool_calls: Optional[List[Dict]] = None
    # Identifier of the tool call this message refers to.
    # NOTE(review): presumably only set on TOOL messages -- confirm.
    tool_call_id: Optional[str] = None


@dataclass
class ToolCall:
    """A single tool invocation, as carried in a provider response."""

    # Identifier for this particular call.
    id: str
    # Name of the tool to invoke.
    name: str
    # Arguments for the tool, held as a dict (not a serialized string).
    arguments: Dict[str, Any]


@dataclass
class ProviderResponse:
    """Result returned by ``Provider.send_message``.

    ``tool_calls`` and ``usage`` are optional: a plain text reply can be
    built as ``ProviderResponse("...")``. Their defaults come from
    ``default_factory``, so every instance gets its own fresh list/dict
    rather than one shared mutable default.
    """

    content: str
    tool_calls: List[ToolCall] = field(default_factory=list)
    # Token counts; expected keys: prompt_tokens, completion_tokens.
    usage: Dict[str, int] = field(default_factory=dict)


class Provider(ABC):
    """Abstract interface implemented by every LLM backend.

    Mirrors Claw Code's Provider trait. Concrete subclasses must supply:

    * ``send_message`` -- the main (async) interaction point;
    * ``name`` -- a read-only identifier property;
    * ``max_context`` -- a read-only property giving the context window
      size in tokens.

    ``supports_tools`` is a capability flag that defaults to ``True``;
    override it for backends without function calling.
    """

    @abstractmethod
    async def send_message(
        self,
        messages: List[Message],
        tools: Optional[List[Dict]] = None,
        temperature: float = 0.7,
    ) -> ProviderResponse:
        """Send *messages* to the model and return its reply.

        When *tools* is given, the returned response may include tool
        calls the model wants executed.
        """
        ...

    @property
    @abstractmethod
    def name(self) -> str:
        """Identifier for this provider."""
        ...

    @property
    @abstractmethod
    def max_context(self) -> int:
        """Size of the context window, in tokens."""
        ...

    @property
    def supports_tools(self) -> bool:
        """Whether function/tool calling is available (``True`` by default)."""
        return True


class ProviderFactory:
    """Name-based registry and factory for provider implementations.

    Classes are registered under a string key with ``register`` and
    instantiated on demand with ``create``.
    """

    # Defined once on the class and mutated in place by ``register``, so
    # all users of the factory (and subclasses that don't override it)
    # share this single registry.
    _registry: Dict[str, type] = {}

    @classmethod
    def register(cls, name: str, provider_class: type) -> None:
        """Register *provider_class* under *name*.

        Registering an existing name replaces the previous class.
        """
        cls._registry[name] = provider_class

    @classmethod
    def create(cls, name: str, config: Optional[Dict[str, Any]] = None) -> Provider:
        """Instantiate the provider registered under *name*.

        Args:
            name: Key previously passed to ``register``.
            config: Keyword arguments forwarded to the provider's
                constructor. ``None`` (the default) means no arguments,
                for providers whose constructors need none.

        Returns:
            A new instance of the registered class.

        Raises:
            ValueError: If *name* has not been registered.
            Any exception raised by the provider's constructor (e.g. a
            ``TypeError`` for unexpected config keys) propagates unchanged.
        """
        try:
            provider_class = cls._registry[name]
        except KeyError:
            # The bare KeyError adds nothing beyond this message, so drop it
            # from the traceback.
            raise ValueError(
                f"Unknown provider: {name}. Registered: {list(cls._registry.keys())}"
            ) from None
        if config is None:
            config = {}
        return provider_class(**config)

    @classmethod
    def list_providers(cls) -> List[str]:
        """Return the registered provider names, in insertion order."""
        return list(cls._registry.keys())