Enhance tool normalization and API integration across modules
- Introduced normalization functions for tool statistics and error counts to ensure a consistent schema across all trajectory entries, facilitating compatibility with HuggingFace datasets (see the loading sketch below).
- Updated batch processing to use the normalized tool stats and error counts, improving data integrity.
- Refactored vision tools and the mixture-of-agents tool to integrate with the OpenRouter API, replacing Nous Research API references and updating model configurations.
- Enabled reasoning capabilities in API calls for enhanced response quality across various tools.
- Improved error handling and API key validation for the OpenRouter integration.
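Why the normalization matters, in brief: Arrow/Parquet (which backs HuggingFace datasets) infers one struct type per column, so tool_stats dicts with differing key sets across JSONL rows can fail to load. A minimal sketch of the intended loading path, with a hypothetical file name (use whatever path the batch writer actually produces):

    # Sketch: load the normalized trajectory JSONL as a HuggingFace dataset.
    # "trajectories.jsonl" is a placeholder for the real batch output file.
    from datasets import load_dataset

    ds = load_dataset("json", data_files="trajectories.jsonl", split="train")

    # With every entry carrying the same tool keys, Arrow infers a single
    # struct type for tool_stats instead of raising a schema mismatch error.
    print(ds.features["tool_stats"])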
@@ -44,6 +44,70 @@ from toolset_distributions import (
 # Global configuration for worker processes
 _WORKER_CONFIG = {}
 
+# All possible tools - used to ensure consistent schema across all trajectory entries
+# This is required because Arrow/Parquet (used by HuggingFace datasets) needs identical schemas
+ALL_POSSIBLE_TOOLS = {
+    'terminal', 'web_search', 'web_extract', 'web_crawl',
+    'vision_analyze', 'image_generate', 'mixture_of_agents'
+}
+
+# Default stats for tools that weren't used
+DEFAULT_TOOL_STATS = {'count': 0, 'success': 0, 'failure': 0}
+
+
+def _normalize_tool_stats(tool_stats: Dict[str, Dict[str, int]]) -> Dict[str, Dict[str, int]]:
+    """
+    Normalize tool_stats to include all possible tools with consistent schema.
+
+    This ensures HuggingFace datasets can load the JSONL without schema mismatch errors.
+    Tools that weren't used get zero counts.
+
+    Args:
+        tool_stats (Dict): Raw tool statistics from extraction
+
+    Returns:
+        Dict: Normalized tool statistics with all tools present
+    """
+    normalized = {}
+
+    # Add all possible tools with defaults
+    for tool in ALL_POSSIBLE_TOOLS:
+        if tool in tool_stats:
+            normalized[tool] = tool_stats[tool].copy()
+        else:
+            normalized[tool] = DEFAULT_TOOL_STATS.copy()
+
+    # Also include any unexpected tools (in case new tools are added)
+    for tool, stats in tool_stats.items():
+        if tool not in normalized:
+            normalized[tool] = stats.copy()
+
+    return normalized
+
+
+def _normalize_tool_error_counts(tool_error_counts: Dict[str, int]) -> Dict[str, int]:
+    """
+    Normalize tool_error_counts to include all possible tools.
+
+    Args:
+        tool_error_counts (Dict): Raw error counts mapping
+
+    Returns:
+        Dict: Normalized error counts with all tools present
+    """
+    normalized = {}
+
+    # Add all possible tools with zero defaults
+    for tool in ALL_POSSIBLE_TOOLS:
+        normalized[tool] = tool_error_counts.get(tool, 0)
+
+    # Also include any unexpected tools
+    for tool, count in tool_error_counts.items():
+        if tool not in normalized:
+            normalized[tool] = count
+
+    return normalized
+
+
 def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]:
     """
@@ -273,12 +337,16 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
 
     # Save trajectory if successful
     if result["success"] and result["trajectory"]:
-        # Create tool_error_counts mapping tool names to their failure counts
-        tool_stats = result.get("tool_stats", {})
-        tool_error_counts = {
+        # Get and normalize tool stats for consistent schema across all entries
+        raw_tool_stats = result.get("tool_stats", {})
+        tool_stats = _normalize_tool_stats(raw_tool_stats)
+
+        # Create normalized tool_error_counts mapping tool names to their failure counts
+        raw_error_counts = {
             tool_name: stats.get("failure", 0)
-            for tool_name, stats in tool_stats.items()
+            for tool_name, stats in raw_tool_stats.items()
         }
+        tool_error_counts = _normalize_tool_error_counts(raw_error_counts)
 
         trajectory_entry = {
             "prompt_index": prompt_index,
@@ -288,8 +356,8 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
             "partial": result.get("partial", False),  # True if stopped due to invalid tool calls
             "api_calls": result["api_calls"],
             "toolsets_used": result["toolsets_used"],
-            "tool_stats": tool_stats,  # Full stats: {tool: {count, success, failure}}
-            "tool_error_counts": tool_error_counts  # Simple: {tool: failure_count}
+            "tool_stats": tool_stats,  # Full stats: {tool: {count, success, failure}} - normalized
+            "tool_error_counts": tool_error_counts  # Simple: {tool: failure_count} - normalized
         }
 
         # Append to batch output file
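Usage note: a quick sanity check of the new helpers (hypothetical snippet, assuming they are imported from this module) shows how tools a run never touched are padded with zero counts so every trajectory entry shares one schema:

    # Hypothetical sanity check for _normalize_tool_stats / _normalize_tool_error_counts.
    raw = {"terminal": {"count": 3, "success": 2, "failure": 1}}

    stats = _normalize_tool_stats(raw)
    errors = _normalize_tool_error_counts({t: s["failure"] for t, s in raw.items()})

    assert set(stats) == ALL_POSSIBLE_TOOLS            # every tool key is present
    assert stats["web_search"] == DEFAULT_TOOL_STATS   # unused tools get zero counts
    assert errors["terminal"] == 1 and errors["web_search"] == 0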