forked from Rockachopa/Timmy-time-dashboard
256 lines
9.0 KiB
Python
256 lines
9.0 KiB
Python
"""Trajectory exporter — reads session JSONL logs and extracts conversation trajectories.
|
||
|
||
A trajectory is a coherent sequence of messages + tool calls that form
|
||
a single task attempt. Each trajectory becomes one training example.
|
||
|
||
Refs: #1105
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
from dataclasses import dataclass, field
|
||
from datetime import UTC, datetime, timedelta
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
logger = logging.getLogger(__name__)

# Name of the logs directory used when no explicit logs_dir is given; it is
# joined onto the repository root by TrajectoryExporter.__init__.
_LOGS_DIR_DEFAULT = "logs"

# Glob for per-day session logs; export_week expects names of the form
# session_YYYY-MM-DD.jsonl and skips files whose suffix is not such a date.
_SESSION_GLOB = "session_*.jsonl"
|
||
|
||
|
||
@dataclass
class Trajectory:
    """A single conversation trajectory extracted from session logs.

    The list fields hold the raw log entries (dicts parsed from JSONL),
    grouped by their ``type`` value. ``started_at`` / ``ended_at`` are
    timestamp strings bounding the trajectory (``_build_trajectory`` fills
    them from the min/max entry timestamp, or ``""`` when none exist).
    """

    session_date: str  # date part of the session log filename (after "session_")
    started_at: str
    ended_at: str
    messages: list[dict[str, Any]] = field(default_factory=list)  # type == "message"
    tool_calls: list[dict[str, Any]] = field(default_factory=list)  # type == "tool_call"
    errors: list[dict[str, Any]] = field(default_factory=list)  # type == "error"
    decisions: list[dict[str, Any]] = field(default_factory=list)  # type == "decision"

    @property
    def message_count(self) -> int:
        """Number of message entries."""
        return len(self.messages)

    @property
    def tool_call_count(self) -> int:
        """Number of tool-call entries."""
        return len(self.tool_calls)

    @property
    def error_count(self) -> int:
        """Number of error entries."""
        return len(self.errors)

    @property
    def has_successful_tool_call(self) -> bool:
        """True if the trajectory has at least one tool call and no errors at all.

        This is a conservative proxy: errors are not matched to individual
        tool calls, so a single error entry anywhere in the trajectory marks
        every tool call in it as unsuccessful.
        """
        return self.tool_call_count > 0 and self.error_count == 0

    @property
    def is_multi_step(self) -> bool:
        """True if the trajectory has at least two messages and at least one tool call."""
        return self.message_count >= 2 and self.tool_call_count >= 1

    def to_chat_format(self) -> list[dict[str, str]]:
        """Convert the trajectory to chat-format messages for training.

        Messages, tool calls and decisions are merged and emitted in timestamp
        order. Ordering compares the raw ``timestamp`` strings; entries whose
        timestamp is missing or not a string sort first. The sort is stable,
        so ties keep the messages -> tool calls -> decisions order. Error
        entries are not emitted.

        Mapping:
            * ``message``: role ``user`` if the entry's role is ``"user"``,
              otherwise ``assistant``; skipped when content is empty.
            * ``tool_call``: an ``assistant`` turn ``"[tool:<tool>] <result>"``.
            * ``decision``: an ``assistant`` turn ``"[decided] <decision>"``;
              skipped when the decision is empty.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts.
        """

        def _sort_key(entry: dict[str, Any]) -> str:
            # A None / numeric timestamp would make sorted() compare str with
            # another type and raise TypeError, so treat it as "".
            ts = entry.get("timestamp")
            return ts if isinstance(ts, str) else ""

        chat: list[dict[str, str]] = []
        merged = sorted(self.messages + self.tool_calls + self.decisions, key=_sort_key)
        for entry in merged:
            etype = entry.get("type")
            if etype == "message":
                role = "user" if entry.get("role") == "user" else "assistant"
                content = entry.get("content", "")
                if content:
                    chat.append({"role": role, "content": content})
            elif etype == "tool_call":
                tool = entry.get("tool", "unknown")
                result = entry.get("result", "")
                chat.append({"role": "assistant", "content": f"[tool:{tool}] {result}"})
            elif etype == "decision":
                decision = entry.get("decision", "")
                if decision:
                    chat.append({"role": "assistant", "content": f"[decided] {decision}"})
        return chat
|
||
|
||
|
||
class TrajectoryExporter:
|
||
"""Reads session JSONL logs and yields Trajectory objects for a date range."""
|
||
|
||
def __init__(self, logs_dir: str | Path | None = None, repo_root: str | Path | None = None):
|
||
if repo_root is None:
|
||
repo_root = Path(__file__).resolve().parent.parent.parent
|
||
self._repo_root = Path(repo_root)
|
||
|
||
if logs_dir is None:
|
||
self._logs_dir = self._repo_root / _LOGS_DIR_DEFAULT
|
||
else:
|
||
self._logs_dir = Path(logs_dir)
|
||
|
||
def export_week(self, weeks_ago: int = 0) -> list[Trajectory]:
|
||
"""Export all trajectories from the specified week.
|
||
|
||
Args:
|
||
weeks_ago: 0 = current week, 1 = last week, etc.
|
||
|
||
Returns:
|
||
List of Trajectory objects extracted from session logs.
|
||
"""
|
||
now = datetime.now(tz=UTC)
|
||
# Week boundaries: Mon–Sun
|
||
days_since_monday = now.weekday()
|
||
week_start = (now - timedelta(days=days_since_monday + 7 * weeks_ago)).replace(
|
||
hour=0, minute=0, second=0, microsecond=0
|
||
)
|
||
week_end = week_start + timedelta(days=7)
|
||
|
||
logger.info(
|
||
"Exporting trajectories for week %s–%s",
|
||
week_start.date().isoformat(),
|
||
week_end.date().isoformat(),
|
||
)
|
||
|
||
trajectories: list[Trajectory] = []
|
||
log_files = sorted(self._logs_dir.glob(_SESSION_GLOB))
|
||
|
||
for log_file in log_files:
|
||
# Parse date from filename: session_YYYY-MM-DD.jsonl
|
||
try:
|
||
date_str = log_file.stem.removeprefix("session_")
|
||
file_date = datetime.strptime(date_str, "%Y-%m-%d").replace(tzinfo=UTC)
|
||
except ValueError:
|
||
logger.debug("Skipping non-date session file: %s", log_file.name)
|
||
continue
|
||
|
||
if not (week_start <= file_date < week_end):
|
||
continue
|
||
|
||
file_trajectories = self._extract_from_file(log_file)
|
||
trajectories.extend(file_trajectories)
|
||
logger.info(
|
||
"Extracted %d trajectories from %s", len(file_trajectories), log_file.name
|
||
)
|
||
|
||
logger.info("Total trajectories exported: %d", len(trajectories))
|
||
return trajectories
|
||
|
||
def _extract_from_file(self, log_file: Path) -> list[Trajectory]:
|
||
"""Parse a single session JSONL file into trajectories.
|
||
|
||
Groups entries into trajectories by finding natural conversation
|
||
boundaries (gaps of inactivity or topic shifts in the message stream).
|
||
"""
|
||
entries: list[dict[str, Any]] = []
|
||
try:
|
||
with open(log_file) as f:
|
||
for line in f:
|
||
line = line.strip()
|
||
if not line:
|
||
continue
|
||
try:
|
||
entries.append(json.loads(line))
|
||
except json.JSONDecodeError:
|
||
logger.debug("Skipping malformed JSON line in %s", log_file.name)
|
||
except OSError as exc:
|
||
logger.warning("Could not read %s: %s", log_file, exc)
|
||
return []
|
||
|
||
if not entries:
|
||
return []
|
||
|
||
date_str = log_file.stem.removeprefix("session_")
|
||
return self._segment_trajectories(entries, date_str)
|
||
|
||
def _segment_trajectories(
|
||
self, entries: list[dict[str, Any]], session_date: str
|
||
) -> list[Trajectory]:
|
||
"""Split a flat list of session entries into discrete trajectories.
|
||
|
||
Segmentation rule: start a new trajectory when:
|
||
- A user message follows a Timmy message (new conversation turn)
|
||
- More than 5 minutes have elapsed between entries
|
||
|
||
This produces training examples that are coherent task attempts.
|
||
"""
|
||
if not entries:
|
||
return []
|
||
|
||
trajectories: list[Trajectory] = []
|
||
current_entries: list[dict[str, Any]] = []
|
||
prev_ts: datetime | None = None
|
||
_SEGMENT_GAP_MINUTES = 5
|
||
|
||
def _flush() -> None:
|
||
if current_entries:
|
||
traj = _build_trajectory(current_entries, session_date)
|
||
if traj.message_count > 0:
|
||
trajectories.append(traj)
|
||
|
||
for entry in entries:
|
||
ts_raw = entry.get("timestamp", "")
|
||
try:
|
||
ts = datetime.fromisoformat(ts_raw.replace("Z", "+00:00"))
|
||
except (ValueError, AttributeError):
|
||
ts = None
|
||
|
||
# Time-gap segmentation
|
||
if ts and prev_ts and (ts - prev_ts).total_seconds() > _SEGMENT_GAP_MINUTES * 60:
|
||
_flush()
|
||
current_entries = []
|
||
|
||
# New-turn segmentation: user message after assistant turn
|
||
etype = entry.get("type")
|
||
erole = entry.get("role")
|
||
if etype == "message" and erole == "user" and current_entries:
|
||
# Check if previous non-error entry was a Timmy message
|
||
for prev in reversed(current_entries):
|
||
if prev.get("type") == "message":
|
||
if prev.get("role") == "timmy":
|
||
_flush()
|
||
current_entries = []
|
||
break
|
||
|
||
current_entries.append(entry)
|
||
if ts:
|
||
prev_ts = ts
|
||
|
||
_flush()
|
||
return trajectories
|
||
|
||
|
||
def _build_trajectory(entries: list[dict[str, Any]], session_date: str) -> Trajectory:
    """Build a Trajectory from a flat, already-segmented list of log entries.

    Entries are bucketed by their ``type`` value (message / tool_call / error /
    decision); entries of any other type are not kept. ``started_at`` and
    ``ended_at`` are the lexicographic min / max of the entries' non-empty
    string timestamps, or ``""`` when no entry has one.

    NOTE(review): lexicographic order matches chronological order only when
    all timestamps share one ISO-8601 format and UTC offset — confirm the
    session logger guarantees this.

    Args:
        entries: Parsed JSONL entries belonging to one trajectory.
        session_date: Date string taken from the session log filename.

    Returns:
        The assembled Trajectory.
    """
    messages = [e for e in entries if e.get("type") == "message"]
    tool_calls = [e for e in entries if e.get("type") == "tool_call"]
    errors = [e for e in entries if e.get("type") == "error"]
    decisions = [e for e in entries if e.get("type") == "decision"]

    # Only non-empty string timestamps count: mixing types (e.g. an epoch
    # number with an ISO string) would make min()/max() raise TypeError, and
    # started_at / ended_at are typed as str.
    raw_timestamps = (e.get("timestamp") for e in entries)
    timestamps = [ts for ts in raw_timestamps if isinstance(ts, str) and ts]
    started_at = min(timestamps) if timestamps else ""
    ended_at = max(timestamps) if timestamps else ""

    return Trajectory(
        session_date=session_date,
        started_at=started_at,
        ended_at=ended_at,
        messages=messages,
        tool_calls=tool_calls,
        errors=errors,
        decisions=decisions,
    )
|