Implement enhanced response handling and tool call validation in run_agent

- Added methods to check for meaningful content after <think> blocks and to retrieve messages up to the last complete assistant turn.
- Introduced retry logic for handling truncated responses and invalid JSON arguments in tool calls, with a maximum retry limit.
- Improved logging for invalid JSON and empty responses, ensuring better error tracking and handling.
- Updated the batch data generation script to adjust dataset file, batch size, and ephemeral system prompt for improved context management.
This commit is contained in:
teknium
2026-01-10 13:04:43 +00:00
parent 4071ba29da
commit 66daebe88f
2 changed files with 194 additions and 7 deletions

View File

@@ -208,6 +208,60 @@ class AIAgent:
prompt_preview = self.ephemeral_system_prompt[:60] + "..." if len(self.ephemeral_system_prompt) > 60 else self.ephemeral_system_prompt
print(f"🔒 Ephemeral system prompt: '{prompt_preview}' (not saved to trajectories)")
def _has_content_after_think_block(self, content: str) -> bool:
"""
Check if content has actual text after any <think></think> blocks.
This detects cases where the model only outputs reasoning but no actual
response, which indicates an incomplete generation that should be retried.
Args:
content: The assistant message content to check
Returns:
True if there's meaningful content after think blocks, False otherwise
"""
if not content:
return False
import re
# Remove all <think>...</think> blocks (including nested ones, non-greedy)
cleaned = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL)
# Check if there's any non-whitespace content remaining
return bool(cleaned.strip())
def _get_messages_up_to_last_assistant(self, messages: List[Dict]) -> List[Dict]:
"""
Get messages up to (but not including) the last assistant turn.
This is used when we need to "roll back" to the last successful point
in the conversation, typically when the final assistant message is
incomplete or malformed.
Args:
messages: Full message list
Returns:
Messages up to the last complete assistant turn (ending with user/tool message)
"""
if not messages:
return []
# Find the index of the last assistant message
last_assistant_idx = None
for i in range(len(messages) - 1, -1, -1):
if messages[i].get("role") == "assistant":
last_assistant_idx = i
break
if last_assistant_idx is None:
# No assistant message found, return all messages
return messages.copy()
# Return everything up to (not including) the last assistant message
return messages[:last_assistant_idx]
def _format_tools_for_system_message(self) -> str:
"""
Format tool definitions for the system message in the trajectory format.
@@ -292,9 +346,19 @@ class AIAgent:
# Add tool calls wrapped in XML tags
for tool_call in msg["tool_calls"]:
# Parse arguments - should always succeed since we validate during conversation
# but keep try-except as safety net
try:
arguments = json.loads(tool_call["function"]["arguments"]) if isinstance(tool_call["function"]["arguments"], str) else tool_call["function"]["arguments"]
except json.JSONDecodeError:
# This shouldn't happen since we validate and retry during conversation,
# but if it does, log warning and use empty dict
logging.warning(f"Unexpected invalid JSON in trajectory conversion: {tool_call['function']['arguments'][:100]}")
arguments = {}
tool_call_json = {
"name": tool_call["function"]["name"],
"arguments": json.loads(tool_call["function"]["arguments"]) if isinstance(tool_call["function"]["arguments"], str) else tool_call["function"]["arguments"]
"arguments": arguments
}
content += f"<tool_call>\n{json.dumps(tool_call_json, ensure_ascii=False)}\n</tool_call>\n"
@@ -417,6 +481,12 @@ class AIAgent:
# Generate unique task_id if not provided to isolate VMs between concurrent tasks
import uuid
effective_task_id = task_id or str(uuid.uuid4())
# Reset retry counters at the start of each conversation to prevent state leakage
self._invalid_tool_retries = 0
self._invalid_json_retries = 0
self._empty_content_retries = 0
# Initialize conversation
messages = conversation_history or []
@@ -540,6 +610,45 @@ class AIAgent:
time.sleep(wait_time)
continue # Retry the API call
# Check finish_reason before proceeding
finish_reason = response.choices[0].finish_reason
# Handle "length" finish_reason - response was truncated
if finish_reason == "length":
print(f"{self.log_prefix}⚠️ Response truncated (finish_reason='length') - model hit max output tokens")
# If we have prior messages, roll back to last complete state
if len(messages) > 1:
print(f"{self.log_prefix} ⏪ Rolling back to last complete assistant turn")
rolled_back_messages = self._get_messages_up_to_last_assistant(messages)
# Clean up VM
try:
cleanup_vm(effective_task_id)
except Exception as e:
if self.verbose_logging:
logging.warning(f"Failed to cleanup VM for task {effective_task_id}: {e}")
return {
"final_response": None,
"messages": rolled_back_messages,
"api_calls": api_call_count,
"completed": False,
"partial": True,
"error": "Response truncated due to output length limit"
}
else:
# First message was truncated - mark as failed
print(f"{self.log_prefix}❌ First response truncated - cannot recover")
return {
"final_response": None,
"messages": messages,
"api_calls": api_call_count,
"completed": False,
"failed": True,
"error": "First response truncated due to output length limit"
}
break # Success, exit retry loop
except Exception as api_error:
@@ -638,6 +747,40 @@ class AIAgent:
if hasattr(self, '_invalid_tool_retries'):
self._invalid_tool_retries = 0
# Validate tool call arguments are valid JSON
invalid_json_args = []
for tc in assistant_message.tool_calls:
try:
json.loads(tc.function.arguments)
except json.JSONDecodeError as e:
invalid_json_args.append((tc.function.name, str(e)))
if invalid_json_args:
# Track retries for invalid JSON arguments
self._invalid_json_retries += 1
tool_name, error_msg = invalid_json_args[0]
print(f"{self.log_prefix}⚠️ Invalid JSON in tool call arguments for '{tool_name}': {error_msg}")
if self._invalid_json_retries < 3:
print(f"{self.log_prefix}🔄 Retrying API call ({self._invalid_json_retries}/3)...")
# Don't add anything to messages, just retry the API call
continue
else:
print(f"{self.log_prefix}❌ Max retries (3) for invalid JSON arguments exceeded. Stopping as partial.")
self._invalid_json_retries = 0 # Reset for next conversation
return {
"final_response": None,
"messages": messages, # Messages up to last valid point
"api_calls": api_call_count,
"completed": False,
"partial": True,
"error": f"Model generated invalid JSON arguments for tool '{tool_name}': {error_msg}"
}
# Reset retry counter on successful JSON validation
self._invalid_json_retries = 0
# Extract reasoning from response if available (for reasoning models like minimax, kimi, etc.)
reasoning_content = None
if hasattr(assistant_message, 'reasoning') and assistant_message.reasoning:
@@ -667,10 +810,12 @@ class AIAgent:
for i, tool_call in enumerate(assistant_message.tool_calls, 1):
function_name = tool_call.function.name
# Parse arguments - should always succeed since we validated above
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as e:
print(f"❌ Invalid JSON in tool call arguments: {e}")
# This shouldn't happen since we validate and retry above
logging.warning(f"Unexpected JSON error after validation: {e}")
function_args = {}
# Preview tool call arguments
@@ -712,6 +857,48 @@ class AIAgent:
# No tool calls - this is the final response
final_response = assistant_message.content or ""
# Check if response only has think block with no actual content after it
if not self._has_content_after_think_block(final_response):
# Track retries for empty-after-think responses
if not hasattr(self, '_empty_content_retries'):
self._empty_content_retries = 0
self._empty_content_retries += 1
content_preview = final_response[:80] + "..." if len(final_response) > 80 else final_response
print(f"{self.log_prefix}⚠️ Response only contains think block with no content after it")
print(f"{self.log_prefix} Content: '{content_preview}'")
if self._empty_content_retries < 3:
print(f"{self.log_prefix}🔄 Retrying API call ({self._empty_content_retries}/3)...")
# Don't add the incomplete message, just retry
continue
else:
# Max retries exceeded - roll back to last complete assistant turn
print(f"{self.log_prefix}❌ Max retries (3) for empty content exceeded. Rolling back to last complete turn.")
self._empty_content_retries = 0 # Reset for next conversation
rolled_back_messages = self._get_messages_up_to_last_assistant(messages)
# Clean up VM
try:
cleanup_vm(effective_task_id)
except Exception as e:
if self.verbose_logging:
logging.warning(f"Failed to cleanup VM for task {effective_task_id}: {e}")
return {
"final_response": None,
"messages": rolled_back_messages,
"api_calls": api_call_count,
"completed": False,
"partial": True,
"error": "Model generated only think blocks with no actual response after 3 retries"
}
# Reset retry counter on successful content
if hasattr(self, '_empty_content_retries'):
self._empty_content_retries = 0
# Extract reasoning from response if available
reasoning_content = None
if hasattr(assistant_message, 'reasoning') and assistant_message.reasoning:

View File

@@ -9,16 +9,16 @@ LOG_FILE="logs/glm4.7-thinking-sft1_$(date +%Y%m%d_%H%M%S).log"
echo "📝 Logging output to: $LOG_FILE"
python batch_runner.py \
--dataset_file="source-data/hermes-agent-agent-tasks-1/agent_tasks_sft_1.jsonl" \
--batch_size=25 \
--run_name="megascience_glm4.7-thinking-sft1" \
--dataset_file="source-data/hermes-agent-agent-tasks-1/agent_tasks_sft_2.jsonl" \
--batch_size=20 \
--run_name="megascience_glm4.7-thinking-sft2" \
--distribution="science" \
--model="z-ai/glm-4.7" \
--base_url="https://openrouter.ai/api/v1" \
--providers_allowed="gmicloud,siliconflow,atlas-cloud,z-ai,novita" \
--num_workers=10 \
--num_workers=15 \
--max_turns=60 \
--ephemeral_system_prompt="You have access to a variety of tools to help you solve scientific, math, and technology problems presented to you. You can use them in sequence and build off of the results of prior tools you've used results. Always use the terminal or search tool if it can provide additional context, verify formulas, double check concepts and recent studies and understanding, doing all calculations, etc. You should only be confident in your own reasoning, knowledge, or calculations if you've exhaustively used all tools available to you to that can help you verify or validate your work. Always pip install any packages you need to use the python scripts you want to run. If you need to use a tool that isn't available, you can use the terminal tool to install or create it in many cases as well. Do not use the terminal tool to communicate with the user, as they cannot see your commands, only your final response after completing the task. Search for at least 3 sources, but not more than 12." \
--ephemeral_system_prompt="You have access to a variety of tools to help you solve scientific, math, and technology problems presented to you. You can use them in sequence and build off of the results of prior tools you've used results. Always use the terminal or search tool if it can provide additional context, verify formulas, double check concepts and recent studies and understanding, doing all calculations, etc. You should only be confident in your own reasoning, knowledge, or calculations if you've exhaustively used all tools available to you to that can help you verify or validate your work. Always pip install any packages you need to use the python scripts you want to run. If you need to use a tool that isn't available, you can use the terminal tool to install or create it in many cases as well. Do not use the terminal tool to communicate with the user, as they cannot see your commands, only your final response after completing the task. Search for at least 3 sources, but not more than 12, so you can maintain focused context." \
2>&1 | tee "$LOG_FILE"
echo "✅ Log saved to: $LOG_FILE"