refactor: streamline scratchpad handling in AIAgent

- Removed static methods for converting and checking <REASONING_SCRATCHPAD> tags, simplifying the codebase.
- Replaced calls to the removed methods with direct function calls for better clarity and maintainability.
- Updated trajectory saving logic to utilize a dedicated function for improved organization and readability.
This commit is contained in:
teknium1
2026-02-23 09:55:09 -08:00
parent 8fedbf87d9
commit d18c753b3c

View File

@@ -637,43 +637,6 @@ class AIAgent:
return json.dumps(formatted_tools, ensure_ascii=False)
@staticmethod
def _convert_scratchpad_to_think(content: str) -> str:
"""
Convert <REASONING_SCRATCHPAD> tags to <think> tags in content.
When native thinking/reasoning is disabled and the model is prompted to
reason inside <REASONING_SCRATCHPAD> XML tags instead, this converts those
to the standard <think> format used in our trajectory storage.
Args:
content: Assistant message content that may contain scratchpad tags
Returns:
Content with scratchpad tags replaced by think tags
"""
if not content or "<REASONING_SCRATCHPAD>" not in content:
return content
return content.replace("<REASONING_SCRATCHPAD>", "<think>").replace("</REASONING_SCRATCHPAD>", "</think>")
@staticmethod
def _has_incomplete_scratchpad(content: str) -> bool:
"""
Check if content has an opening <REASONING_SCRATCHPAD> without a closing tag.
This indicates the model ran out of output tokens mid-reasoning, producing
a broken turn that shouldn't be saved. The caller should retry or discard.
Args:
content: Assistant message content to check
Returns:
True if there's an unclosed scratchpad tag
"""
if not content:
return False
return "<REASONING_SCRATCHPAD>" in content and "</REASONING_SCRATCHPAD>" not in content
def _convert_to_trajectory_format(self, messages: List[Dict[str, Any]], user_query: str, completed: bool) -> List[Dict[str, Any]]:
"""
Convert internal message format to trajectory format for saving.
@@ -738,7 +701,7 @@ class AIAgent:
if msg.get("content") and msg["content"].strip():
# Convert any <REASONING_SCRATCHPAD> tags to <think> tags
# (used when native thinking is disabled and model reasons via XML)
content += self._convert_scratchpad_to_think(msg["content"]) + "\n"
content += convert_scratchpad_to_think(msg["content"]) + "\n"
# Add tool calls wrapped in XML tags
for tool_call in msg["tool_calls"]:
@@ -813,7 +776,7 @@ class AIAgent:
# Convert any <REASONING_SCRATCHPAD> tags to <think> tags
# (used when native thinking is disabled and model reasons via XML)
raw_content = msg["content"] or ""
content += self._convert_scratchpad_to_think(raw_content)
content += convert_scratchpad_to_think(raw_content)
# Ensure every gpt turn has a <think> block (empty if no reasoning)
if "<think>" not in content:
@@ -846,27 +809,8 @@ class AIAgent:
if not self.save_trajectories:
return
# Convert messages to trajectory format
trajectory = self._convert_to_trajectory_format(messages, user_query, completed)
# Determine which file to save to
filename = "trajectory_samples.jsonl" if completed else "failed_trajectories.jsonl"
# Create trajectory entry
entry = {
"conversations": trajectory,
"timestamp": datetime.now().isoformat(),
"model": self.model,
"completed": completed
}
# Append to JSONL file
try:
with open(filename, "a", encoding="utf-8") as f:
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
logger.info("Trajectory saved to %s", filename)
except Exception as e:
logger.warning("Failed to save trajectory: %s", e)
_save_trajectory_to_file(trajectory, self.model, completed)
def _mask_api_key_for_logs(self, key: Optional[str]) -> Optional[str]:
if not key:
@@ -2134,7 +2078,7 @@ class AIAgent:
# Check for incomplete <REASONING_SCRATCHPAD> (opened but never closed)
# This means the model ran out of output tokens mid-reasoning — retry up to 2 times
if self._has_incomplete_scratchpad(assistant_message.content or ""):
if has_incomplete_scratchpad(assistant_message.content or ""):
if not hasattr(self, '_incomplete_scratchpad_retries'):
self._incomplete_scratchpad_retries = 0
self._incomplete_scratchpad_retries += 1