Enhance tool normalization and API integration across modules

- Introduced normalization functions for tool statistics and error counts to ensure consistent schema across all trajectory entries, facilitating compatibility with HuggingFace datasets.
- Updated batch processing to utilize normalized tool stats and error counts, improving data integrity.
- Refactored vision tools and mixture of agents tool to integrate with OpenRouter API, replacing Nous Research API references and updating model configurations.
- Enabled reasoning capabilities in API calls for enhanced response quality across various tools.
- Improved error handling and API key validation for OpenRouter integration.
This commit is contained in:
teknium
2026-01-14 13:40:10 +00:00
parent 66daebe88f
commit 13d360030f
6 changed files with 172 additions and 61 deletions

View File

@@ -44,6 +44,70 @@ from toolset_distributions import (
# Global configuration for worker processes # Global configuration for worker processes
_WORKER_CONFIG = {} _WORKER_CONFIG = {}
# All possible tools - used to ensure consistent schema across all trajectory entries
# This is required because Arrow/Parquet (used by HuggingFace datasets) needs identical schemas
ALL_POSSIBLE_TOOLS = {
'terminal', 'web_search', 'web_extract', 'web_crawl',
'vision_analyze', 'image_generate', 'mixture_of_agents'
}
# Default stats for tools that weren't used
DEFAULT_TOOL_STATS = {'count': 0, 'success': 0, 'failure': 0}
def _normalize_tool_stats(tool_stats: Dict[str, Dict[str, int]]) -> Dict[str, Dict[str, int]]:
"""
Normalize tool_stats to include all possible tools with consistent schema.
This ensures HuggingFace datasets can load the JSONL without schema mismatch errors.
Tools that weren't used get zero counts.
Args:
tool_stats (Dict): Raw tool statistics from extraction
Returns:
Dict: Normalized tool statistics with all tools present
"""
normalized = {}
# Add all possible tools with defaults
for tool in ALL_POSSIBLE_TOOLS:
if tool in tool_stats:
normalized[tool] = tool_stats[tool].copy()
else:
normalized[tool] = DEFAULT_TOOL_STATS.copy()
# Also include any unexpected tools (in case new tools are added)
for tool, stats in tool_stats.items():
if tool not in normalized:
normalized[tool] = stats.copy()
return normalized
def _normalize_tool_error_counts(tool_error_counts: Dict[str, int]) -> Dict[str, int]:
    """
    Normalize tool_error_counts so every possible tool has an entry.
    Tools without a recorded error count default to zero; tool names outside
    ALL_POSSIBLE_TOOLS are preserved so new tools aren't silently dropped.
    Args:
        tool_error_counts (Dict): Raw error counts mapping
    Returns:
        Dict: Normalized error counts with all tools present
    """
    # Zero-fill the known tool set, keeping any observed counts.
    normalized = {name: tool_error_counts.get(name, 0) for name in ALL_POSSIBLE_TOOLS}
    # Append unexpected tool names without overwriting existing entries.
    for name, count in tool_error_counts.items():
        normalized.setdefault(name, count)
    return normalized
def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]: def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]:
""" """
@@ -273,12 +337,16 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
# Save trajectory if successful # Save trajectory if successful
if result["success"] and result["trajectory"]: if result["success"] and result["trajectory"]:
# Create tool_error_counts mapping tool names to their failure counts # Get and normalize tool stats for consistent schema across all entries
tool_stats = result.get("tool_stats", {}) raw_tool_stats = result.get("tool_stats", {})
tool_error_counts = { tool_stats = _normalize_tool_stats(raw_tool_stats)
# Create normalized tool_error_counts mapping tool names to their failure counts
raw_error_counts = {
tool_name: stats.get("failure", 0) tool_name: stats.get("failure", 0)
for tool_name, stats in tool_stats.items() for tool_name, stats in raw_tool_stats.items()
} }
tool_error_counts = _normalize_tool_error_counts(raw_error_counts)
trajectory_entry = { trajectory_entry = {
"prompt_index": prompt_index, "prompt_index": prompt_index,
@@ -288,8 +356,8 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
"partial": result.get("partial", False), # True if stopped due to invalid tool calls "partial": result.get("partial", False), # True if stopped due to invalid tool calls
"api_calls": result["api_calls"], "api_calls": result["api_calls"],
"toolsets_used": result["toolsets_used"], "toolsets_used": result["toolsets_used"],
"tool_stats": tool_stats, # Full stats: {tool: {count, success, failure}} "tool_stats": tool_stats, # Full stats: {tool: {count, success, failure}} - normalized
"tool_error_counts": tool_error_counts # Simple: {tool: failure_count} "tool_error_counts": tool_error_counts # Simple: {tool: failure_count} - normalized
} }
# Append to batch output file # Append to batch output file

View File

@@ -513,7 +513,7 @@ def handle_vision_function_call(function_name: str, function_args: Dict[str, Any
full_prompt = f"Fully describe and explain everything about this image, then answer the following question:\n\n{question}" full_prompt = f"Fully describe and explain everything about this image, then answer the following question:\n\n{question}"
# Run async function in event loop # Run async function in event loop
return asyncio.run(vision_analyze_tool(image_url, full_prompt, "gemini-2.5-flash")) return asyncio.run(vision_analyze_tool(image_url, full_prompt, "google/gemini-3-flash-preview"))
else: else:
return json.dumps({"error": f"Unknown vision function: {function_name}"}, ensure_ascii=False) return json.dumps({"error": f"Unknown vision function: {function_name}"}, ensure_ascii=False)

View File

@@ -555,9 +555,22 @@ class AIAgent:
"timeout": 600.0 # 10 minute timeout for very long responses "timeout": 600.0 # 10 minute timeout for very long responses
} }
# Add provider preferences for OpenRouter via extra_body # Add extra_body for OpenRouter (provider preferences + reasoning)
extra_body = {}
# Add provider preferences if specified
if provider_preferences: if provider_preferences:
api_kwargs["extra_body"] = {"provider": provider_preferences} extra_body["provider"] = provider_preferences
# Enable reasoning with xhigh effort for OpenRouter
if "openrouter" in self.base_url.lower():
extra_body["reasoning"] = {
"enabled": True,
"effort": "xhigh"
}
if extra_body:
api_kwargs["extra_body"] = extra_body
response = self.client.chat.completions.create(**api_kwargs) response = self.client.chat.completions.create(**api_kwargs)

View File

@@ -24,9 +24,9 @@ Architecture:
2. Aggregator model synthesizes responses into a high-quality output 2. Aggregator model synthesizes responses into a high-quality output
3. Multiple layers can be used for iterative refinement (future enhancement) 3. Multiple layers can be used for iterative refinement (future enhancement)
Models Used: Models Used (via OpenRouter):
- Reference Models: claude-opus-4-20250514, gemini-2.5-pro, o4-mini, deepseek-r1 - Reference Models: claude-opus-4, gemini-2.5-pro, gpt-4.1, deepseek-r1
- Aggregator Model: claude-opus-4-20250514 (highest capability for synthesis) - Aggregator Model: claude-opus-4 (highest capability for synthesis)
Configuration: Configuration:
To customize the MoA setup, modify the configuration constants at the top of this file: To customize the MoA setup, modify the configuration constants at the top of this file:
@@ -54,23 +54,23 @@ from pathlib import Path
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
from openai import AsyncOpenAI from openai import AsyncOpenAI
# Initialize Nous Research API client for MoA processing # Initialize OpenRouter API client for MoA processing
nous_client = AsyncOpenAI( openrouter_client = AsyncOpenAI(
api_key=os.getenv("NOUS_API_KEY"), api_key=os.getenv("OPENROUTER_API_KEY"),
base_url="https://inference-api.nousresearch.com/v1" base_url="https://openrouter.ai/api/v1"
) )
# Configuration for MoA processing # Configuration for MoA processing
# Reference models - these generate diverse initial responses in parallel # Reference models - these generate diverse initial responses in parallel (OpenRouter slugs)
REFERENCE_MODELS = [ REFERENCE_MODELS = [
"claude-opus-4-20250514", "anthropic/claude-opus-4.5",
"gemini-2.5-pro", "google/gemini-3-pro-preview",
"gpt-5", "openai/gpt-5.2-pro",
"deepseek-r1" "deepseek/deepseek-v3.2"
] ]
# Aggregator model - synthesizes reference responses into final output # Aggregator model - synthesizes reference responses into final output
AGGREGATOR_MODEL = "claude-opus-4-20250514" # Use highest capability model for aggregation AGGREGATOR_MODEL = "anthropic/claude-opus-4.5" # Use highest capability model for aggregation
# Temperature settings optimized for MoA performance # Temperature settings optimized for MoA performance
REFERENCE_TEMPERATURE = 0.6 # Balanced creativity for diverse perspectives REFERENCE_TEMPERATURE = 0.6 # Balanced creativity for diverse perspectives
@@ -187,7 +187,13 @@ async def _run_reference_model_safe(
# Build parameters for the API call # Build parameters for the API call
api_params = { api_params = {
"model": model, "model": model,
"messages": [{"role": "user", "content": user_prompt}] "messages": [{"role": "user", "content": user_prompt}],
"extra_body": {
"reasoning": {
"enabled": True,
"effort": "xhigh"
}
}
} }
# GPT models (especially gpt-4o-mini) don't support custom temperature values # GPT models (especially gpt-4o-mini) don't support custom temperature values
@@ -195,7 +201,7 @@ async def _run_reference_model_safe(
if not model.lower().startswith('gpt-'): if not model.lower().startswith('gpt-'):
api_params["temperature"] = temperature api_params["temperature"] = temperature
response = await nous_client.chat.completions.create(**api_params) response = await openrouter_client.chat.completions.create(**api_params)
content = response.choices[0].message.content.strip() content = response.choices[0].message.content.strip()
print(f"{model} responded ({len(content)} characters)") print(f"{model} responded ({len(content)} characters)")
@@ -248,7 +254,13 @@ async def _run_aggregator_model(
"messages": [ "messages": [
{"role": "system", "content": system_prompt}, {"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt} {"role": "user", "content": user_prompt}
] ],
"extra_body": {
"reasoning": {
"enabled": True,
"effort": "xhigh"
}
}
} }
# GPT models (especially gpt-4o-mini) don't support custom temperature values # GPT models (especially gpt-4o-mini) don't support custom temperature values
@@ -256,7 +268,7 @@ async def _run_aggregator_model(
if not AGGREGATOR_MODEL.lower().startswith('gpt-'): if not AGGREGATOR_MODEL.lower().startswith('gpt-'):
api_params["temperature"] = temperature api_params["temperature"] = temperature
response = await nous_client.chat.completions.create(**api_params) response = await openrouter_client.chat.completions.create(**api_params)
content = response.choices[0].message.content.strip() content = response.choices[0].message.content.strip()
print(f"✅ Aggregation complete ({len(content)} characters)") print(f"✅ Aggregation complete ({len(content)} characters)")
@@ -330,8 +342,8 @@ async def mixture_of_agents_tool(
print(f"📝 Query: {user_prompt[:100]}{'...' if len(user_prompt) > 100 else ''}") print(f"📝 Query: {user_prompt[:100]}{'...' if len(user_prompt) > 100 else ''}")
# Validate API key availability # Validate API key availability
if not os.getenv("NOUS_API_KEY"): if not os.getenv("OPENROUTER_API_KEY"):
raise ValueError("NOUS_API_KEY environment variable not set") raise ValueError("OPENROUTER_API_KEY environment variable not set")
# Use provided models or defaults # Use provided models or defaults
ref_models = reference_models or REFERENCE_MODELS ref_models = reference_models or REFERENCE_MODELS
@@ -439,14 +451,14 @@ async def mixture_of_agents_tool(
return json.dumps(result, indent=2, ensure_ascii=False) return json.dumps(result, indent=2, ensure_ascii=False)
def check_nous_api_key() -> bool: def check_openrouter_api_key() -> bool:
""" """
Check if the Nous Research API key is available in environment variables. Check if the OpenRouter API key is available in environment variables.
Returns: Returns:
bool: True if API key is set, False otherwise bool: True if API key is set, False otherwise
""" """
return bool(os.getenv("NOUS_API_KEY")) return bool(os.getenv("OPENROUTER_API_KEY"))
def check_moa_requirements() -> bool: def check_moa_requirements() -> bool:
@@ -456,7 +468,7 @@ def check_moa_requirements() -> bool:
Returns: Returns:
bool: True if requirements are met, False otherwise bool: True if requirements are met, False otherwise
""" """
return check_nous_api_key() return check_openrouter_api_key()
def get_debug_session_info() -> Dict[str, Any]: def get_debug_session_info() -> Dict[str, Any]:
@@ -522,15 +534,15 @@ if __name__ == "__main__":
print("=" * 50) print("=" * 50)
# Check if API key is available # Check if API key is available
api_available = check_nous_api_key() api_available = check_openrouter_api_key()
if not api_available: if not api_available:
print("NOUS_API_KEY environment variable not set") print("OPENROUTER_API_KEY environment variable not set")
print("Please set your API key: export NOUS_API_KEY='your-key-here'") print("Please set your API key: export OPENROUTER_API_KEY='your-key-here'")
print("Get API key at: https://inference-api.nousresearch.com/") print("Get API key at: https://openrouter.ai/")
exit(1) exit(1)
else: else:
print("Nous Research API key found") print("OpenRouter API key found")
print("🛠️ MoA tools ready for use!") print("🛠️ MoA tools ready for use!")

View File

@@ -3,7 +3,7 @@
Vision Tools Module Vision Tools Module
This module provides vision analysis tools that work with image URLs. This module provides vision analysis tools that work with image URLs.
Uses Gemini Flash via Nous Research API for intelligent image understanding. Uses Gemini 3 Flash Preview via OpenRouter API for intelligent image understanding.
Available tools: Available tools:
- vision_analyze_tool: Analyze images from URLs with custom prompts - vision_analyze_tool: Analyze images from URLs with custom prompts
@@ -38,14 +38,14 @@ from typing import Dict, Any, Optional
from openai import AsyncOpenAI from openai import AsyncOpenAI
import httpx # Use httpx for async HTTP requests import httpx # Use httpx for async HTTP requests
# Initialize Nous Research API client for vision processing # Initialize OpenRouter API client for vision processing
nous_client = AsyncOpenAI( openrouter_client = AsyncOpenAI(
api_key=os.getenv("NOUS_API_KEY"), api_key=os.getenv("OPENROUTER_API_KEY"),
base_url="https://inference-api.nousresearch.com/v1" base_url="https://openrouter.ai/api/v1"
) )
# Configuration for vision processing # Configuration for vision processing
DEFAULT_VISION_MODEL = "gemini-2.5-flash" DEFAULT_VISION_MODEL = "google/gemini-3-flash-preview"
# Debug mode configuration # Debug mode configuration
DEBUG_MODE = os.getenv("VISION_TOOLS_DEBUG", "false").lower() == "true" DEBUG_MODE = os.getenv("VISION_TOOLS_DEBUG", "false").lower() == "true"
@@ -220,7 +220,7 @@ async def vision_analyze_tool(
Analyze an image from a URL using vision AI. Analyze an image from a URL using vision AI.
This tool downloads images from URLs, converts them to base64, and processes This tool downloads images from URLs, converts them to base64, and processes
them using Gemini Flash via Nous Research API. The image is downloaded to a them using Gemini 3 Flash Preview via OpenRouter API. The image is downloaded to a
temporary location and automatically cleaned up after processing. temporary location and automatically cleaned up after processing.
The user_prompt parameter is expected to be pre-formatted by the calling The user_prompt parameter is expected to be pre-formatted by the calling
@@ -230,7 +230,7 @@ async def vision_analyze_tool(
Args: Args:
image_url (str): The URL of the image to analyze (must be http:// or https://) image_url (str): The URL of the image to analyze (must be http:// or https://)
user_prompt (str): The pre-formatted prompt for the vision model user_prompt (str): The pre-formatted prompt for the vision model
model (str): The vision model to use (default: gemini-2.5-flash) model (str): The vision model to use (default: google/gemini-3-flash-preview)
Returns: Returns:
str: JSON string containing the analysis results with the following structure: str: JSON string containing the analysis results with the following structure:
@@ -271,8 +271,8 @@ async def vision_analyze_tool(
raise ValueError("Invalid image URL format. Must start with http:// or https://") raise ValueError("Invalid image URL format. Must start with http:// or https://")
# Check API key availability # Check API key availability
if not os.getenv("NOUS_API_KEY"): if not os.getenv("OPENROUTER_API_KEY"):
raise ValueError("NOUS_API_KEY environment variable not set") raise ValueError("OPENROUTER_API_KEY environment variable not set")
# Download the image to a temporary location # Download the image to a temporary location
print(f"⬇️ Downloading image from URL...", flush=True) print(f"⬇️ Downloading image from URL...", flush=True)
@@ -319,12 +319,18 @@ async def vision_analyze_tool(
print(f"🧠 Processing image with {model}...", flush=True) print(f"🧠 Processing image with {model}...", flush=True)
# Call the vision API # Call the vision API with reasoning enabled
response = await nous_client.chat.completions.create( response = await openrouter_client.chat.completions.create(
model=model, model=model,
messages=messages, messages=messages,
temperature=0.1, # Low temperature for consistent analysis temperature=0.1, # Low temperature for consistent analysis
max_tokens=2000 # Generous limit for detailed analysis max_tokens=2000, # Generous limit for detailed analysis
extra_body={
"reasoning": {
"enabled": True,
"effort": "xhigh"
}
}
) )
# Extract the analysis # Extract the analysis
@@ -374,14 +380,14 @@ async def vision_analyze_tool(
print(f"⚠️ Warning: Could not delete temporary file: {cleanup_error}", flush=True) print(f"⚠️ Warning: Could not delete temporary file: {cleanup_error}", flush=True)
def check_nous_api_key() -> bool: def check_openrouter_api_key() -> bool:
""" """
Check if the Nous Research API key is available in environment variables. Check if the OpenRouter API key is available in environment variables.
Returns: Returns:
bool: True if API key is set, False otherwise bool: True if API key is set, False otherwise
""" """
return bool(os.getenv("NOUS_API_KEY")) return bool(os.getenv("OPENROUTER_API_KEY"))
def check_vision_requirements() -> bool: def check_vision_requirements() -> bool:
@@ -391,7 +397,7 @@ def check_vision_requirements() -> bool:
Returns: Returns:
bool: True if requirements are met, False otherwise bool: True if requirements are met, False otherwise
""" """
return check_nous_api_key() return check_openrouter_api_key()
def get_debug_session_info() -> Dict[str, Any]: def get_debug_session_info() -> Dict[str, Any]:
@@ -425,15 +431,15 @@ if __name__ == "__main__":
print("=" * 40) print("=" * 40)
# Check if API key is available # Check if API key is available
api_available = check_nous_api_key() api_available = check_openrouter_api_key()
if not api_available: if not api_available:
print("NOUS_API_KEY environment variable not set") print("OPENROUTER_API_KEY environment variable not set")
print("Please set your API key: export NOUS_API_KEY='your-key-here'") print("Please set your API key: export OPENROUTER_API_KEY='your-key-here'")
print("Get API key at: https://inference-api.nousresearch.com/") print("Get API key at: https://openrouter.ai/")
exit(1) exit(1)
else: else:
print("Nous Research API key found") print("OpenRouter API key found")
print("🛠️ Vision tools ready for use!") print("🛠️ Vision tools ready for use!")
print(f"🧠 Using model: {DEFAULT_VISION_MODEL}") print(f"🧠 Using model: {DEFAULT_VISION_MODEL}")

View File

@@ -285,7 +285,13 @@ Create a markdown summary that captures all key information in a well-organized,
{"role": "user", "content": user_prompt} {"role": "user", "content": user_prompt}
], ],
temperature=0.1, temperature=0.1,
max_tokens=max_tokens max_tokens=max_tokens,
extra_body={
"reasoning": {
"enabled": True,
"effort": "xhigh"
}
}
) )
return response.choices[0].message.content.strip() return response.choices[0].message.content.strip()
except Exception as api_error: except Exception as api_error:
@@ -398,7 +404,13 @@ Create a single, unified markdown summary."""
{"role": "user", "content": synthesis_prompt} {"role": "user", "content": synthesis_prompt}
], ],
temperature=0.1, temperature=0.1,
max_tokens=4000 max_tokens=4000,
extra_body={
"reasoning": {
"enabled": True,
"effort": "xhigh"
}
}
) )
final_summary = response.choices[0].message.content.strip() final_summary = response.choices[0].message.content.strip()