diff --git a/wolf/models.py b/wolf/models.py
new file mode 100644
index 0000000..599c283
--- /dev/null
+++ b/wolf/models.py
@@ -0,0 +1,120 @@
+import requests
+import logging
+import json
+
+class ModelClient:
+    """
+    Base model client for Wolf.
+    """
+    def __init__(self, api_key=None, base_url=None):
+        self.api_key = api_key
+        self.base_url = base_url
+
+    def generate(self, prompt, model_name, system_prompt=None):
+        raise NotImplementedError
+
+class OpenRouterClient(ModelClient):
+    def __init__(self, api_key):
+        super().__init__(api_key, "https://openrouter.ai/api/v1")
+
+    def generate(self, prompt, model_name, system_prompt=None):
+        url = f"{self.base_url}/chat/completions"
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json"
+        }
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": prompt})
+
+        data = {
+            "model": model_name,
+            "messages": messages
+        }
+        response = requests.post(url, headers=headers, json=data)
+        response.raise_for_status()
+        return response.json()["choices"][0]["message"]["content"]
+
+class GroqClient(ModelClient):
+    def __init__(self, api_key):
+        super().__init__(api_key, "https://api.groq.com/openai/v1")
+
+    def generate(self, prompt, model_name, system_prompt=None):
+        url = f"{self.base_url}/chat/completions"
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json"
+        }
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": prompt})
+
+        data = {
+            "model": model_name,
+            "messages": messages
+        }
+        response = requests.post(url, headers=headers, json=data)
+        response.raise_for_status()
+        return response.json()["choices"][0]["message"]["content"]
+
+class OllamaClient(ModelClient):
+    def __init__(self, base_url="http://localhost:11434"):
+        super().__init__(base_url=base_url)
+
+    def generate(self, prompt, model_name, system_prompt=None):
+        url = f"{self.base_url}/api/generate"
+        data = {
+            "model": model_name,
+            "prompt": prompt,
+            "system": system_prompt if system_prompt else "",
+            "stream": False
+        }
+        response = requests.post(url, json=data)
+        response.raise_for_status()
+        return response.json()["response"]
+
+class ModelFactory:
+    @staticmethod
+    def get_client(provider, api_key=None, base_url=None):
+        if provider == "openrouter":
+            return OpenRouterClient(api_key)
+        elif provider == "groq":
+            return GroqClient(api_key)
+        elif provider == "ollama":
+            return OllamaClient(base_url or "http://localhost:11434")
+        elif provider == "openai":
+            # Reuse Groq client as it's OpenAI-compatible
+            client = GroqClient(api_key)
+            client.base_url = "https://api.openai.com/v1"
+            return client
+        elif provider == "anthropic":
+            # Need a specific Anthropic client
+            return AnthropicClient(api_key)
+        else:
+            raise ValueError(f"Unknown provider: {provider}")
+
+class AnthropicClient(ModelClient):
+    def __init__(self, api_key):
+        super().__init__(api_key, "https://api.anthropic.com/v1")
+
+    def generate(self, prompt, model_name, system_prompt=None):
+        url = f"{self.base_url}/messages"
+        headers = {
+            "x-api-key": self.api_key,
+            "anthropic-version": "2023-06-01",
+            "Content-Type": "application/json"
+        }
+        messages = [{"role": "user", "content": prompt}]
+        data = {
+            "model": model_name,
+            "messages": messages,
+            "max_tokens": 4096
+        }
+        if system_prompt:
+            data["system"] = system_prompt
+
+        response = requests.post(url, headers=headers, json=data)
+        response.raise_for_status()
+        return response.json()["content"][0]["text"]
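
For review, a minimal usage sketch of the factory added above, assuming `wolf` is importable as a package; the provider choice, environment variable name, and model identifier are illustrative and not part of this patch:

# Usage sketch (not part of the diff): assumes `wolf` is on PYTHONPATH and
# OPENROUTER_API_KEY is set; the model id below is only an example.
import os

from wolf.models import ModelFactory

client = ModelFactory.get_client("openrouter", api_key=os.environ["OPENROUTER_API_KEY"])
reply = client.generate(
    prompt="Say hello in one short sentence.",
    model_name="meta-llama/llama-3.1-8b-instruct",  # example OpenRouter model id
    system_prompt="You are a concise assistant.",
)
print(reply)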