Fix session saving so the DB receives the full conversation history, including tool calls and tool results (not just user/assistant messages without tool calls)

This commit is contained in:
teknium1
2026-02-22 17:10:24 -08:00
parent e1604b2b4a
commit 6037b6a5ab

View File

@@ -503,6 +503,42 @@ class AIAgent:
if self.verbose_logging:
logging.warning(f"Failed to cleanup browser for task {task_id}: {e}")
def _persist_session(self, messages: List[Dict], conversation_history: List[Dict] = None):
"""Save session state to both JSON log and SQLite on any exit path.
Ensures conversations are never lost, even on errors or early returns.
"""
self._session_messages = messages
self._save_session_log(messages)
self._flush_messages_to_session_db(messages, conversation_history)
def _log_msg_to_db(self, msg: Dict):
"""Log a single message to SQLite immediately. Called after each messages.append()."""
if not self._session_db:
return
try:
role = msg.get("role", "unknown")
content = msg.get("content")
tool_calls_data = None
if hasattr(msg, "tool_calls") and msg.tool_calls:
tool_calls_data = [
{"name": tc.function.name, "arguments": tc.function.arguments}
for tc in msg.tool_calls
]
elif isinstance(msg.get("tool_calls"), list):
tool_calls_data = msg["tool_calls"]
self._session_db.append_message(
session_id=self.session_id,
role=role,
content=content,
tool_name=msg.get("tool_name"),
tool_calls=tool_calls_data,
tool_call_id=msg.get("tool_call_id"),
finish_reason=msg.get("finish_reason"),
)
except Exception as e:
logger.debug("Session DB log_msg failed: %s", e)
def _flush_messages_to_session_db(self, messages: List[Dict], conversation_history: List[Dict] = None):
"""Persist any un-logged messages to the SQLite session store.
@@ -1490,11 +1526,13 @@ class AIAgent:
logging.debug(f"Tool {function_name} completed in {tool_duration:.2f}s")
logging.debug(f"Tool result preview: {result_preview}...")
messages.append({
tool_msg = {
"role": "tool",
"content": function_result,
"tool_call_id": tool_call.id
})
}
messages.append(tool_msg)
self._log_msg_to_db(tool_msg)
if not self.quiet_mode:
response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
@@ -1504,11 +1542,13 @@ class AIAgent:
remaining = len(assistant_message.tool_calls) - i
print(f"{self.log_prefix}⚡ Interrupt: skipping {remaining} remaining tool call(s)")
for skipped_tc in assistant_message.tool_calls[i:]:
messages.append({
skip_msg = {
"role": "tool",
"content": "[Tool execution skipped - user sent a new message]",
"tool_call_id": skipped_tc.id
})
}
messages.append(skip_msg)
self._log_msg_to_db(skip_msg)
break
if self.tool_delay > 0 and i < len(assistant_message.tool_calls):
@@ -1622,10 +1662,9 @@ class AIAgent:
)
# Add user message
messages.append({
"role": "user",
"content": user_message
})
user_msg = {"role": "user", "content": user_message}
messages.append(user_msg)
self._log_msg_to_db(user_msg)
if not self.quiet_mode:
print(f"💬 Starting conversation: '{user_message[:60]}{'...' if len(user_message) > 60 else ''}'")
@@ -1645,13 +1684,6 @@ class AIAgent:
active_system_prompt = self._cached_system_prompt
# Log user message to SQLite
if self._session_db:
try:
self._session_db.append_message(self.session_id, "user", user_message)
except Exception as e:
logger.debug("Session DB append_message failed: %s", e)
# Main conversation loop
api_call_count = 0
final_response = None
@@ -1822,7 +1854,7 @@ class AIAgent:
if retry_count > max_retries:
print(f"{self.log_prefix}❌ Max retries ({max_retries}) exceeded for invalid responses. Giving up.")
logging.error(f"{self.log_prefix}Invalid API response after {max_retries} retries.")
self._flush_messages_to_session_db(messages, conversation_history)
self._persist_session(messages, conversation_history)
return {
"messages": messages,
"completed": False,
@@ -1841,7 +1873,7 @@ class AIAgent:
while time.time() < sleep_end:
if self._interrupt_requested:
print(f"{self.log_prefix}⚡ Interrupt detected during retry wait, aborting.")
self._flush_messages_to_session_db(messages, conversation_history)
self._persist_session(messages, conversation_history)
return {
"final_response": "Operation interrupted.",
"messages": messages,
@@ -1865,7 +1897,7 @@ class AIAgent:
rolled_back_messages = self._get_messages_up_to_last_assistant(messages)
self._cleanup_task_resources(effective_task_id)
self._flush_messages_to_session_db(messages, conversation_history)
self._persist_session(messages, conversation_history)
return {
"final_response": None,
@@ -1878,7 +1910,7 @@ class AIAgent:
else:
# First message was truncated - mark as failed
print(f"{self.log_prefix}❌ First response truncated - cannot recover")
self._flush_messages_to_session_db(messages, conversation_history)
self._persist_session(messages, conversation_history)
return {
"final_response": None,
"messages": messages,
@@ -1917,7 +1949,7 @@ class AIAgent:
thinking_spinner.stop("")
thinking_spinner = None
print(f"{self.log_prefix}⚡ Interrupted during API call.")
self._flush_messages_to_session_db(messages, conversation_history)
self._persist_session(messages, conversation_history)
interrupted = True
final_response = "Operation interrupted."
break
@@ -1943,7 +1975,7 @@ class AIAgent:
# Check for interrupt before deciding to retry
if self._interrupt_requested:
print(f"{self.log_prefix}⚡ Interrupt detected during error handling, aborting retries.")
self._flush_messages_to_session_db(messages, conversation_history)
self._persist_session(messages, conversation_history)
return {
"final_response": "Operation interrupted.",
"messages": messages,
@@ -1972,7 +2004,7 @@ class AIAgent:
print(f"{self.log_prefix}❌ Non-retryable client error detected. Aborting immediately.")
print(f"{self.log_prefix} 💡 This type of error won't be fixed by retrying.")
logging.error(f"{self.log_prefix}Non-retryable client error: {api_error}")
self._flush_messages_to_session_db(messages, conversation_history)
self._persist_session(messages, conversation_history)
return {
"final_response": None,
"messages": messages,
@@ -2004,7 +2036,7 @@ class AIAgent:
print(f"{self.log_prefix}❌ Context length exceeded and cannot compress further.")
print(f"{self.log_prefix} 💡 The conversation has accumulated too much content.")
logging.error(f"{self.log_prefix}Context length exceeded: {approx_tokens:,} tokens. Cannot compress further.")
self._flush_messages_to_session_db(messages, conversation_history)
self._persist_session(messages, conversation_history)
return {
"messages": messages,
"completed": False,
@@ -2030,7 +2062,7 @@ class AIAgent:
while time.time() < sleep_end:
if self._interrupt_requested:
print(f"{self.log_prefix}⚡ Interrupt detected during retry wait, aborting.")
self._flush_messages_to_session_db(messages, conversation_history)
self._persist_session(messages, conversation_history)
return {
"final_response": "Operation interrupted.",
"messages": messages,
@@ -2071,7 +2103,7 @@ class AIAgent:
rolled_back_messages = self._get_messages_up_to_last_assistant(messages)
self._cleanup_task_resources(effective_task_id)
self._flush_messages_to_session_db(messages, conversation_history)
self._persist_session(messages, conversation_history)
return {
"final_response": None,
@@ -2119,7 +2151,7 @@ class AIAgent:
print(f"{self.log_prefix}❌ Max retries (3) for invalid tool calls exceeded. Stopping as partial.")
# Return partial result - don't include the bad tool call in messages
self._invalid_tool_retries = 0
self._flush_messages_to_session_db(messages, conversation_history)
self._persist_session(messages, conversation_history)
return {
"final_response": None,
"messages": messages,
@@ -2170,8 +2202,9 @@ class AIAgent:
f"For tools with no required parameters, use an empty object: {{}}. "
f"Please either retry the tool call with valid JSON, or respond without using that tool."
)
messages.append({"role": "user", "content": recovery_msg})
# Continue the loop - model will see this message and can recover
recovery_dict = {"role": "user", "content": recovery_msg}
messages.append(recovery_dict)
self._log_msg_to_db(recovery_dict)
continue
# Reset retry counter on successful JSON validation
@@ -2180,6 +2213,7 @@ class AIAgent:
assistant_msg = self._build_assistant_message(assistant_message, finish_reason)
messages.append(assistant_msg)
self._log_msg_to_db(assistant_msg)
self._execute_tool_calls(assistant_message, messages, effective_task_id)
@@ -2236,9 +2270,10 @@ class AIAgent:
"finish_reason": finish_reason,
}
messages.append(empty_msg)
self._log_msg_to_db(empty_msg)
self._cleanup_task_resources(effective_task_id)
self._flush_messages_to_session_db(messages, conversation_history)
self._persist_session(messages, conversation_history)
return {
"final_response": final_response or None,
@@ -2256,6 +2291,7 @@ class AIAgent:
final_msg = self._build_assistant_message(assistant_message, finish_reason)
messages.append(final_msg)
self._log_msg_to_db(final_msg)
if not self.quiet_mode:
print(f"🎉 Conversation completed after {api_call_count} OpenAI-compatible API call(s)")
@@ -2286,11 +2322,13 @@ class AIAgent:
}
for tc in msg["tool_calls"]:
if tc["id"] not in answered_ids:
messages.append({
err_msg = {
"role": "tool",
"tool_call_id": tc["id"],
"content": f"Error executing tool: {error_msg}",
})
}
messages.append(err_msg)
self._log_msg_to_db(err_msg)
pending_handled = True
break
@@ -2298,10 +2336,12 @@ class AIAgent:
# Error happened before tool processing (e.g. response parsing).
# Use a user-role message so the model can see what went wrong
# without confusing the API with a fabricated assistant turn.
messages.append({
sys_err_msg = {
"role": "user",
"content": f"[System error during processing: {error_msg}]",
})
}
messages.append(sys_err_msg)
self._log_msg_to_db(sys_err_msg)
# If we're near the limit, break to avoid infinite loops
if api_call_count >= self.max_iterations - 1:
@@ -2320,11 +2360,8 @@ class AIAgent:
# Clean up VM and browser for this task after conversation completes
self._cleanup_task_resources(effective_task_id)
# Update session messages and save session log
self._session_messages = messages
self._save_session_log(messages)
self._flush_messages_to_session_db(messages, conversation_history)
# Persist session to both JSON log and SQLite
self._persist_session(messages, conversation_history)
# Build result with interrupt info if applicable
result = {