Add/Update wolf/models.py by Wolf
This commit is contained in:
120
wolf/models.py
Normal file
120
wolf/models.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import requests
|
||||
import logging
|
||||
import json
|
||||
|
||||
class ModelClient:
|
||||
"""
|
||||
Base model client for Wolf.
|
||||
"""
|
||||
def __init__(self, api_key=None, base_url=None):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
|
||||
def generate(self, prompt, model_name, system_prompt=None):
|
||||
raise NotImplementedError
|
||||
|
||||
class OpenRouterClient(ModelClient):
|
||||
def __init__(self, api_key):
|
||||
super().__init__(api_key, "https://openrouter.ai/api/v1")
|
||||
|
||||
def generate(self, prompt, model_name, system_prompt=None):
|
||||
url = f"{self.base_url}/chat/completions"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
data = {
|
||||
"model": model_name,
|
||||
"messages": messages
|
||||
}
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
return response.json()["choices"][0]["message"]["content"]
|
||||
|
||||
class GroqClient(ModelClient):
|
||||
def __init__(self, api_key):
|
||||
super().__init__(api_key, "https://api.groq.com/openai/v1")
|
||||
|
||||
def generate(self, prompt, model_name, system_prompt=None):
|
||||
url = f"{self.base_url}/chat/completions"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
data = {
|
||||
"model": model_name,
|
||||
"messages": messages
|
||||
}
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
return response.json()["choices"][0]["message"]["content"]
|
||||
|
||||
class OllamaClient(ModelClient):
|
||||
def __init__(self, base_url="http://localhost:11434"):
|
||||
super().__init__(base_url=base_url)
|
||||
|
||||
def generate(self, prompt, model_name, system_prompt=None):
|
||||
url = f"{self.base_url}/api/generate"
|
||||
data = {
|
||||
"model": model_name,
|
||||
"prompt": prompt,
|
||||
"system": system_prompt if system_prompt else "",
|
||||
"stream": False
|
||||
}
|
||||
response = requests.post(url, json=data)
|
||||
response.raise_for_status()
|
||||
return response.json()["response"]
|
||||
|
||||
class ModelFactory:
|
||||
@staticmethod
|
||||
def get_client(provider, api_key=None, base_url=None):
|
||||
if provider == "openrouter":
|
||||
return OpenRouterClient(api_key)
|
||||
elif provider == "groq":
|
||||
return GroqClient(api_key)
|
||||
elif provider == "ollama":
|
||||
return OllamaClient(base_url or "http://localhost:11434")
|
||||
elif provider == "openai":
|
||||
# Reuse Groq client as it's OpenAI-compatible
|
||||
client = GroqClient(api_key)
|
||||
client.base_url = "https://api.openai.com/v1"
|
||||
return client
|
||||
elif provider == "anthropic":
|
||||
# Need a specific Anthropic client
|
||||
return AnthropicClient(api_key)
|
||||
else:
|
||||
raise ValueError(f"Unknown provider: {provider}")
|
||||
|
||||
class AnthropicClient(ModelClient):
|
||||
def __init__(self, api_key):
|
||||
super().__init__(api_key, "https://api.anthropic.com/v1")
|
||||
|
||||
def generate(self, prompt, model_name, system_prompt=None):
|
||||
url = f"{self.base_url}/messages"
|
||||
headers = {
|
||||
"x-api-key": self.api_key,
|
||||
"anthropic-version": "2023-06-01",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
data = {
|
||||
"model": model_name,
|
||||
"messages": messages,
|
||||
"max_tokens": 4096
|
||||
}
|
||||
if system_prompt:
|
||||
data["system"] = system_prompt
|
||||
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
return response.json()["content"][0]["text"]
|
||||
Reference in New Issue
Block a user