From 66daebe88f003bd635aa5b8b15d8d20cba049cbb Mon Sep 17 00:00:00 2001 From: teknium Date: Sat, 10 Jan 2026 13:04:43 +0000 Subject: [PATCH] Implement enhanced response handling and tool call validation in run_agent - Added methods to check for meaningful content after blocks and to retrieve messages up to the last complete assistant turn. - Introduced retry logic for handling truncated responses and invalid JSON arguments in tool calls, with a maximum retry limit. - Improved logging for invalid JSON and empty responses, ensuring better error tracking and handling. - Updated the batch data generation script to adjust dataset file, batch size, and ephemeral system prompt for improved context management. --- run_agent.py | 191 +++++++++++++++++++++++++++++++++++++++++- run_datagen_glm4.7.sh | 10 +-- 2 files changed, 194 insertions(+), 7 deletions(-) diff --git a/run_agent.py b/run_agent.py index 339547204..e4f23b6e7 100644 --- a/run_agent.py +++ b/run_agent.py @@ -208,6 +208,60 @@ class AIAgent: prompt_preview = self.ephemeral_system_prompt[:60] + "..." if len(self.ephemeral_system_prompt) > 60 else self.ephemeral_system_prompt print(f"🔒 Ephemeral system prompt: '{prompt_preview}' (not saved to trajectories)") + def _has_content_after_think_block(self, content: str) -> bool: + """ + Check if content has actual text after any blocks. + + This detects cases where the model only outputs reasoning but no actual + response, which indicates an incomplete generation that should be retried. + + Args: + content: The assistant message content to check + + Returns: + True if there's meaningful content after think blocks, False otherwise + """ + if not content: + return False + + import re + # Remove all ... blocks (including nested ones, non-greedy) + cleaned = re.sub(r'.*?', '', content, flags=re.DOTALL) + + # Check if there's any non-whitespace content remaining + return bool(cleaned.strip()) + + def _get_messages_up_to_last_assistant(self, messages: List[Dict]) -> List[Dict]: + """ + Get messages up to (but not including) the last assistant turn. + + This is used when we need to "roll back" to the last successful point + in the conversation, typically when the final assistant message is + incomplete or malformed. + + Args: + messages: Full message list + + Returns: + Messages up to the last complete assistant turn (ending with user/tool message) + """ + if not messages: + return [] + + # Find the index of the last assistant message + last_assistant_idx = None + for i in range(len(messages) - 1, -1, -1): + if messages[i].get("role") == "assistant": + last_assistant_idx = i + break + + if last_assistant_idx is None: + # No assistant message found, return all messages + return messages.copy() + + # Return everything up to (not including) the last assistant message + return messages[:last_assistant_idx] + def _format_tools_for_system_message(self) -> str: """ Format tool definitions for the system message in the trajectory format. @@ -292,9 +346,19 @@ class AIAgent: # Add tool calls wrapped in XML tags for tool_call in msg["tool_calls"]: + # Parse arguments - should always succeed since we validate during conversation + # but keep try-except as safety net + try: + arguments = json.loads(tool_call["function"]["arguments"]) if isinstance(tool_call["function"]["arguments"], str) else tool_call["function"]["arguments"] + except json.JSONDecodeError: + # This shouldn't happen since we validate and retry during conversation, + # but if it does, log warning and use empty dict + logging.warning(f"Unexpected invalid JSON in trajectory conversion: {tool_call['function']['arguments'][:100]}") + arguments = {} + tool_call_json = { "name": tool_call["function"]["name"], - "arguments": json.loads(tool_call["function"]["arguments"]) if isinstance(tool_call["function"]["arguments"], str) else tool_call["function"]["arguments"] + "arguments": arguments } content += f"\n{json.dumps(tool_call_json, ensure_ascii=False)}\n\n" @@ -417,6 +481,12 @@ class AIAgent: # Generate unique task_id if not provided to isolate VMs between concurrent tasks import uuid effective_task_id = task_id or str(uuid.uuid4()) + + # Reset retry counters at the start of each conversation to prevent state leakage + self._invalid_tool_retries = 0 + self._invalid_json_retries = 0 + self._empty_content_retries = 0 + # Initialize conversation messages = conversation_history or [] @@ -540,6 +610,45 @@ class AIAgent: time.sleep(wait_time) continue # Retry the API call + # Check finish_reason before proceeding + finish_reason = response.choices[0].finish_reason + + # Handle "length" finish_reason - response was truncated + if finish_reason == "length": + print(f"{self.log_prefix}⚠️ Response truncated (finish_reason='length') - model hit max output tokens") + + # If we have prior messages, roll back to last complete state + if len(messages) > 1: + print(f"{self.log_prefix} ⏪ Rolling back to last complete assistant turn") + rolled_back_messages = self._get_messages_up_to_last_assistant(messages) + + # Clean up VM + try: + cleanup_vm(effective_task_id) + except Exception as e: + if self.verbose_logging: + logging.warning(f"Failed to cleanup VM for task {effective_task_id}: {e}") + + return { + "final_response": None, + "messages": rolled_back_messages, + "api_calls": api_call_count, + "completed": False, + "partial": True, + "error": "Response truncated due to output length limit" + } + else: + # First message was truncated - mark as failed + print(f"{self.log_prefix}❌ First response truncated - cannot recover") + return { + "final_response": None, + "messages": messages, + "api_calls": api_call_count, + "completed": False, + "failed": True, + "error": "First response truncated due to output length limit" + } + break # Success, exit retry loop except Exception as api_error: @@ -638,6 +747,40 @@ class AIAgent: if hasattr(self, '_invalid_tool_retries'): self._invalid_tool_retries = 0 + # Validate tool call arguments are valid JSON + invalid_json_args = [] + for tc in assistant_message.tool_calls: + try: + json.loads(tc.function.arguments) + except json.JSONDecodeError as e: + invalid_json_args.append((tc.function.name, str(e))) + + if invalid_json_args: + # Track retries for invalid JSON arguments + self._invalid_json_retries += 1 + + tool_name, error_msg = invalid_json_args[0] + print(f"{self.log_prefix}⚠️ Invalid JSON in tool call arguments for '{tool_name}': {error_msg}") + + if self._invalid_json_retries < 3: + print(f"{self.log_prefix}🔄 Retrying API call ({self._invalid_json_retries}/3)...") + # Don't add anything to messages, just retry the API call + continue + else: + print(f"{self.log_prefix}❌ Max retries (3) for invalid JSON arguments exceeded. Stopping as partial.") + self._invalid_json_retries = 0 # Reset for next conversation + return { + "final_response": None, + "messages": messages, # Messages up to last valid point + "api_calls": api_call_count, + "completed": False, + "partial": True, + "error": f"Model generated invalid JSON arguments for tool '{tool_name}': {error_msg}" + } + + # Reset retry counter on successful JSON validation + self._invalid_json_retries = 0 + # Extract reasoning from response if available (for reasoning models like minimax, kimi, etc.) reasoning_content = None if hasattr(assistant_message, 'reasoning') and assistant_message.reasoning: @@ -667,10 +810,12 @@ class AIAgent: for i, tool_call in enumerate(assistant_message.tool_calls, 1): function_name = tool_call.function.name + # Parse arguments - should always succeed since we validated above try: function_args = json.loads(tool_call.function.arguments) except json.JSONDecodeError as e: - print(f"❌ Invalid JSON in tool call arguments: {e}") + # This shouldn't happen since we validate and retry above + logging.warning(f"Unexpected JSON error after validation: {e}") function_args = {} # Preview tool call arguments @@ -712,6 +857,48 @@ class AIAgent: # No tool calls - this is the final response final_response = assistant_message.content or "" + # Check if response only has think block with no actual content after it + if not self._has_content_after_think_block(final_response): + # Track retries for empty-after-think responses + if not hasattr(self, '_empty_content_retries'): + self._empty_content_retries = 0 + self._empty_content_retries += 1 + + content_preview = final_response[:80] + "..." if len(final_response) > 80 else final_response + print(f"{self.log_prefix}⚠️ Response only contains think block with no content after it") + print(f"{self.log_prefix} Content: '{content_preview}'") + + if self._empty_content_retries < 3: + print(f"{self.log_prefix}🔄 Retrying API call ({self._empty_content_retries}/3)...") + # Don't add the incomplete message, just retry + continue + else: + # Max retries exceeded - roll back to last complete assistant turn + print(f"{self.log_prefix}❌ Max retries (3) for empty content exceeded. Rolling back to last complete turn.") + self._empty_content_retries = 0 # Reset for next conversation + + rolled_back_messages = self._get_messages_up_to_last_assistant(messages) + + # Clean up VM + try: + cleanup_vm(effective_task_id) + except Exception as e: + if self.verbose_logging: + logging.warning(f"Failed to cleanup VM for task {effective_task_id}: {e}") + + return { + "final_response": None, + "messages": rolled_back_messages, + "api_calls": api_call_count, + "completed": False, + "partial": True, + "error": "Model generated only think blocks with no actual response after 3 retries" + } + + # Reset retry counter on successful content + if hasattr(self, '_empty_content_retries'): + self._empty_content_retries = 0 + # Extract reasoning from response if available reasoning_content = None if hasattr(assistant_message, 'reasoning') and assistant_message.reasoning: diff --git a/run_datagen_glm4.7.sh b/run_datagen_glm4.7.sh index 168cc422d..6224c481e 100755 --- a/run_datagen_glm4.7.sh +++ b/run_datagen_glm4.7.sh @@ -9,16 +9,16 @@ LOG_FILE="logs/glm4.7-thinking-sft1_$(date +%Y%m%d_%H%M%S).log" echo "📝 Logging output to: $LOG_FILE" python batch_runner.py \ - --dataset_file="source-data/hermes-agent-agent-tasks-1/agent_tasks_sft_1.jsonl" \ - --batch_size=25 \ - --run_name="megascience_glm4.7-thinking-sft1" \ + --dataset_file="source-data/hermes-agent-agent-tasks-1/agent_tasks_sft_2.jsonl" \ + --batch_size=20 \ + --run_name="megascience_glm4.7-thinking-sft2" \ --distribution="science" \ --model="z-ai/glm-4.7" \ --base_url="https://openrouter.ai/api/v1" \ --providers_allowed="gmicloud,siliconflow,atlas-cloud,z-ai,novita" \ - --num_workers=10 \ + --num_workers=15 \ --max_turns=60 \ - --ephemeral_system_prompt="You have access to a variety of tools to help you solve scientific, math, and technology problems presented to you. You can use them in sequence and build off of the results of prior tools you've used results. Always use the terminal or search tool if it can provide additional context, verify formulas, double check concepts and recent studies and understanding, doing all calculations, etc. You should only be confident in your own reasoning, knowledge, or calculations if you've exhaustively used all tools available to you to that can help you verify or validate your work. Always pip install any packages you need to use the python scripts you want to run. If you need to use a tool that isn't available, you can use the terminal tool to install or create it in many cases as well. Do not use the terminal tool to communicate with the user, as they cannot see your commands, only your final response after completing the task. Search for at least 3 sources, but not more than 12." \ + --ephemeral_system_prompt="You have access to a variety of tools to help you solve scientific, math, and technology problems presented to you. You can use them in sequence and build off of the results of prior tools you've used results. Always use the terminal or search tool if it can provide additional context, verify formulas, double check concepts and recent studies and understanding, doing all calculations, etc. You should only be confident in your own reasoning, knowledge, or calculations if you've exhaustively used all tools available to you to that can help you verify or validate your work. Always pip install any packages you need to use the python scripts you want to run. If you need to use a tool that isn't available, you can use the terminal tool to install or create it in many cases as well. Do not use the terminal tool to communicate with the user, as they cannot see your commands, only your final response after completing the task. Search for at least 3 sources, but not more than 12, so you can maintain focused context." \ 2>&1 | tee "$LOG_FILE" echo "✅ Log saved to: $LOG_FILE"