Add/Update wolf/models.py by Wolf

This commit is contained in:
2026-04-05 17:59:24 +00:00
parent a6a2330d4a
commit f9e4f40f05

120
wolf/models.py Normal file
View File

@@ -0,0 +1,120 @@
import requests
import logging
import json
class ModelClient:
    """Abstract base for Wolf model-provider clients.

    Holds the credential and endpoint shared by all providers; concrete
    subclasses implement :meth:`generate`.
    """

    def __init__(self, api_key=None, base_url=None):
        # Stored as-is; each subclass decides how (or whether) to use them.
        self.api_key = api_key
        self.base_url = base_url

    def generate(self, prompt, model_name, system_prompt=None):
        """Produce a completion for *prompt* with *model_name*.

        Raises:
            NotImplementedError: always — subclasses must override.
        """
        raise NotImplementedError
class OpenRouterClient(ModelClient):
    """Chat-completion client for the OpenRouter API (OpenAI-compatible)."""

    def __init__(self, api_key):
        super().__init__(api_key, "https://openrouter.ai/api/v1")

    def generate(self, prompt, model_name, system_prompt=None, timeout=60):
        """Return the assistant's reply text for *prompt*.

        Args:
            prompt: User message content.
            model_name: OpenRouter model identifier.
            system_prompt: Optional system message prepended to the chat.
            timeout: Seconds before the HTTP request is aborted.

        Raises:
            requests.HTTPError: on a non-2xx response.
            requests.Timeout: if the request exceeds *timeout*.
        """
        url = f"{self.base_url}/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        data = {
            "model": model_name,
            "messages": messages,
        }
        # Fix: requests has no default timeout — without one a stalled
        # connection blocks the caller forever.
        response = requests.post(url, headers=headers, json=data, timeout=timeout)
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
class GroqClient(ModelClient):
    """Chat-completion client for the Groq API (OpenAI-compatible)."""

    def __init__(self, api_key):
        super().__init__(api_key, "https://api.groq.com/openai/v1")

    def generate(self, prompt, model_name, system_prompt=None, timeout=60):
        """Return the assistant's reply text for *prompt*.

        Args:
            prompt: User message content.
            model_name: Groq model identifier.
            system_prompt: Optional system message prepended to the chat.
            timeout: Seconds before the HTTP request is aborted.

        Raises:
            requests.HTTPError: on a non-2xx response.
            requests.Timeout: if the request exceeds *timeout*.
        """
        url = f"{self.base_url}/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        data = {
            "model": model_name,
            "messages": messages,
        }
        # Fix: requests has no default timeout — without one a stalled
        # connection blocks the caller forever.
        response = requests.post(url, headers=headers, json=data, timeout=timeout)
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
class OllamaClient(ModelClient):
    """Client for a local Ollama server's /api/generate endpoint."""

    def __init__(self, base_url="http://localhost:11434"):
        # Ollama requires no API key; only the server URL is configurable.
        super().__init__(base_url=base_url)

    def generate(self, prompt, model_name, system_prompt=None, timeout=60):
        """Return the generated text for *prompt*.

        Args:
            prompt: Text to complete.
            model_name: Name of a locally available Ollama model.
            system_prompt: Optional system instruction (sent as "" when absent).
            timeout: Seconds before the HTTP request is aborted.

        Raises:
            requests.HTTPError: on a non-2xx response.
            requests.Timeout: if the request exceeds *timeout*.
        """
        url = f"{self.base_url}/api/generate"
        data = {
            "model": model_name,
            "prompt": prompt,
            "system": system_prompt if system_prompt else "",
            "stream": False,  # ask for one complete JSON body, not chunks
        }
        # Fix: requests has no default timeout — without one a stalled
        # connection blocks the caller forever.
        response = requests.post(url, json=data, timeout=timeout)
        response.raise_for_status()
        return response.json()["response"]
class ModelFactory:
    """Maps a provider name to a configured ModelClient instance."""

    @staticmethod
    def get_client(provider, api_key=None, base_url=None):
        """Return a ready-to-use client for *provider*.

        Args:
            provider: One of "openrouter", "groq", "ollama", "openai",
                "anthropic". Matching is case-insensitive (generalization;
                the previously accepted lowercase names still work).
            api_key: Credential for hosted providers; unused by "ollama".
            base_url: Server URL for "ollama"; defaults to the local daemon.

        Raises:
            ValueError: if *provider* is not recognized.
        """
        # Normalize so "Groq" / "OPENAI" etc. also resolve; guard None.
        provider = (provider or "").lower()
        if provider == "openrouter":
            return OpenRouterClient(api_key)
        if provider == "groq":
            return GroqClient(api_key)
        if provider == "ollama":
            return OllamaClient(base_url or "http://localhost:11434")
        if provider == "openai":
            # Groq's client speaks the OpenAI-compatible chat API, so reuse
            # it and point it at OpenAI's endpoint.
            client = GroqClient(api_key)
            client.base_url = "https://api.openai.com/v1"
            return client
        if provider == "anthropic":
            return AnthropicClient(api_key)
        raise ValueError(f"Unknown provider: {provider}")
class AnthropicClient(ModelClient):
    """Client for the Anthropic Messages API."""

    def __init__(self, api_key):
        super().__init__(api_key, "https://api.anthropic.com/v1")

    def generate(self, prompt, model_name, system_prompt=None, timeout=60):
        """Return the first text block of the model's reply to *prompt*.

        Args:
            prompt: User message content.
            model_name: Anthropic model identifier.
            system_prompt: Optional system instruction.
            timeout: Seconds before the HTTP request is aborted.

        Raises:
            requests.HTTPError: on a non-2xx response.
            requests.Timeout: if the request exceeds *timeout*.
        """
        url = f"{self.base_url}/messages"
        headers = {
            "x-api-key": self.api_key,
            "anthropic-version": "2023-06-01",
            "Content-Type": "application/json",
        }
        messages = [{"role": "user", "content": prompt}]
        data = {
            "model": model_name,
            "messages": messages,
            "max_tokens": 4096,  # Messages API requires an explicit cap
        }
        if system_prompt:
            # Anthropic takes the system prompt as a top-level field,
            # not as a chat message.
            data["system"] = system_prompt
        # Fix: requests has no default timeout — without one a stalled
        # connection blocks the caller forever.
        response = requests.post(url, headers=headers, json=data, timeout=timeout)
        response.raise_for_status()
        return response.json()["content"][0]["text"]