Improve error handling and type hints in session_search_tool
This commit is contained in:
@@ -20,7 +20,7 @@ import concurrent.futures
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
|
||||
@@ -46,8 +46,16 @@ MAX_SESSION_CHARS = 100_000
|
||||
MAX_SUMMARY_TOKENS = 2000
|
||||
|
||||
|
||||
def _format_timestamp(ts) -> str:
|
||||
"""Convert a Unix timestamp (float/int) or ISO string to a human-readable date."""
|
||||
def _format_timestamp(ts: Optional[Any]) -> str:
|
||||
"""
|
||||
Convert a Unix timestamp (float/int) or ISO string to a human-readable date.
|
||||
|
||||
Args:
|
||||
ts: Unix timestamp (int/float), ISO string, or None
|
||||
|
||||
Returns:
|
||||
Human-readable date string or "unknown" if conversion fails
|
||||
"""
|
||||
if ts is None:
|
||||
return "unknown"
|
||||
try:
|
||||
@@ -61,8 +69,11 @@ def _format_timestamp(ts) -> str:
|
||||
dt = datetime.fromtimestamp(float(ts))
|
||||
return dt.strftime("%B %d, %Y at %I:%M %p")
|
||||
return ts
|
||||
except Exception:
|
||||
pass
|
||||
except (ValueError, OSError, OverflowError) as e:
|
||||
# Log specific errors for debugging while gracefully handling edge cases
|
||||
logging.debug("Failed to format timestamp %s: %s", ts, e)
|
||||
except Exception as e:
|
||||
logging.debug("Unexpected error formatting timestamp %s: %s", ts, e)
|
||||
return str(ts)
|
||||
|
||||
|
||||
@@ -236,18 +247,31 @@ def session_search(
|
||||
|
||||
# Resolve child sessions to their parent — delegation stores detailed
|
||||
# content in child sessions, but the user's conversation is the parent.
|
||||
def _resolve_to_parent(session_id):
|
||||
def _resolve_to_parent(session_id: str) -> Optional[str]:
|
||||
"""
|
||||
Resolve a session ID to its parent session ID, handling delegation chains.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to resolve
|
||||
|
||||
Returns:
|
||||
Parent session ID or None if resolution fails
|
||||
"""
|
||||
visited = set()
|
||||
sid = session_id
|
||||
while sid and sid not in visited:
|
||||
visited.add(sid)
|
||||
session = db.get_session(sid)
|
||||
if not session:
|
||||
break
|
||||
parent = session.get("parent_session_id")
|
||||
if parent:
|
||||
sid = parent
|
||||
else:
|
||||
try:
|
||||
session = db.get_session(sid)
|
||||
if not session:
|
||||
break
|
||||
parent = session.get("parent_session_id")
|
||||
if parent:
|
||||
sid = parent
|
||||
else:
|
||||
break
|
||||
except Exception as e:
|
||||
logging.debug("Error resolving parent for session %s: %s", sid, e)
|
||||
break
|
||||
return sid
|
||||
|
||||
@@ -278,7 +302,8 @@ def session_search(
|
||||
logging.warning(f"Failed to prepare session {session_id}: {e}")
|
||||
|
||||
# Summarize all sessions in parallel
|
||||
async def _summarize_all():
|
||||
async def _summarize_all() -> List[Union[str, Exception]]:
|
||||
"""Summarize all sessions in parallel."""
|
||||
coros = [
|
||||
_summarize_session(text, query, meta)
|
||||
for _, _, text, meta in tasks
|
||||
@@ -290,7 +315,14 @@ def session_search(
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
results = pool.submit(lambda: asyncio.run(_summarize_all())).result(timeout=60)
|
||||
except RuntimeError:
|
||||
# No event loop running, create a new one
|
||||
results = asyncio.run(_summarize_all())
|
||||
except concurrent.futures.TimeoutError:
|
||||
logging.warning("Session summarization timed out after 60 seconds")
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"error": "Session summarization timed out. Try a more specific query or reduce the limit.",
|
||||
}, ensure_ascii=False)
|
||||
|
||||
summaries = []
|
||||
for (session_id, match_info, _, _), result in zip(tasks, results):
|
||||
|
||||
Reference in New Issue
Block a user