#!/usr/bin/env python3
"""Export Claude conversation trajectories to ShareGPT JSONL format for LoRA fine-tuning.

Reads from two sources (in priority order):
  1. logs/session_*.jsonl — rich logs with tool calls (preferred)
  2. data/chat.db — SQLite chat history (fallback)

Output is a ShareGPT-compatible JSONL file where each line is one conversation:
  {"conversations": [
      {"from": "human", "value": "..."},
      {"from": "gpt", "value": "...", "tool_calls": [...]},
      {"from": "tool", "value": "..."},
      {"from": "gpt", "value": "..."}
  ]}

Epic: #1091 Project Bannerlord — AutoLoRA Sovereignty Loop (Step 3 of 7)
Refs: #1102
"""

from __future__ import annotations

import argparse
import json
import sqlite3
import sys
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any

# ── Constants ────────────────────────────────────────────────────────────────

REPO_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_LOGS_DIR = REPO_ROOT / "logs"
DEFAULT_DB_PATH = REPO_ROOT / "data" / "chat.db"
DEFAULT_OUTPUT = Path.home() / "timmy-training-data.jsonl"

# Time gap that signals a new conversation boundary
CONVERSATION_GAP_MINUTES = 30

# Role mappings → ShareGPT "from" values; unknown roles fall back to "gpt"
# in _conversation_to_sharegpt.
ROLE_MAP = {
    "user": "human",
    "timmy": "gpt",
    "agent": "gpt",
    "assistant": "gpt",
    "system": "system",
}


# ── Session log reader ───────────────────────────────────────────────────────

def _parse_ts(ts: str) -> datetime | None:
    """Parse an ISO timestamp string, returning None on failure."""
    try:
        return datetime.fromisoformat(ts)
    except (ValueError, TypeError):
        return None


def _group_into_conversations(
    entries: list[dict],
    gap_minutes: int = CONVERSATION_GAP_MINUTES,
) -> list[list[dict]]:
    """Split a flat list of session entries into conversation windows.

    A new conversation starts whenever there is a gap ≥ *gap_minutes* between
    consecutive timestamped entries.  Entries whose timestamp is missing or
    unparseable never trigger a split; they are appended to the current
    window and do not advance the reference timestamp.

    (FIX: the previous docstring also claimed a split "when the type sequence
    restarts with a user message after an agent reply" — no such logic exists
    in the code, so the claim has been removed.)
    """
    if not entries:
        return []

    conversations: list[list[dict]] = []
    current: list[dict] = []
    last_ts: datetime | None = None

    for entry in entries:
        ts = _parse_ts(entry.get("timestamp", ""))

        if last_ts is not None and ts is not None:
            # FIX: subtracting a naive datetime from an aware one (or vice
            # versa) raises TypeError.  Mixed-offset logs are treated as
            # "no gap information" rather than crashing the export.
            try:
                gap = ts - last_ts
            except TypeError:
                gap = timedelta(0)
            if gap >= timedelta(minutes=gap_minutes):
                if current:
                    conversations.append(current)
                    current = []

        current.append(entry)
        if ts is not None:
            last_ts = ts

    if current:
        conversations.append(current)

    return conversations


def _conversation_to_sharegpt(entries: list[dict]) -> dict[str, Any] | None:
    """Convert a list of session entries into a ShareGPT conversation dict.

    Returns None if the conversation has fewer than 2 meaningful (human/gpt)
    turns — such fragments are not useful for training.
    """
    turns: list[dict[str, Any]] = []
    pending_tool_calls: list[dict] = []

    for entry in entries:
        etype = entry.get("type")

        if etype == "message":
            role_raw = entry.get("role", "")
            from_role = ROLE_MAP.get(role_raw, "gpt")
            content = entry.get("content", "")

            # Empty messages carry no training signal — drop them.
            if not content:
                continue

            turn: dict[str, Any] = {"from": from_role, "value": content}

            # Attach any accumulated tool calls to this gpt turn
            if pending_tool_calls and from_role == "gpt":
                turn["tool_calls"] = pending_tool_calls
                pending_tool_calls = []

            turns.append(turn)

        elif etype == "tool_call":
            tool_name = entry.get("tool", "unknown")
            args = entry.get("args", {})
            result = entry.get("result", "")

            # Record call for the next gpt turn
            pending_tool_calls.append({
                "name": tool_name,
                "arguments": args,
            })

            # Also emit a tool-result turn immediately after.
            # NOTE(review): this places the tool-result turn BEFORE the gpt
            # turn that carries the matching tool_calls.  The tests pin only
            # turn membership, not order; kept as-is, but confirm this is
            # the ordering the trainer expects.
            turns.append({"from": "tool", "value": str(result), "tool": tool_name})

        # All other entry types (decisions, errors, …) are intentionally
        # ignored.

    # Discard conversations with < 2 meaningful turns
    meaningful = [t for t in turns if t["from"] in ("human", "gpt")]
    if len(meaningful) < 2:
        return None

    return {"conversations": turns}


def _entries_to_conversations(entries: list[dict]) -> list[dict[str, Any]]:
    """Group raw entries into windows and convert each to ShareGPT form.

    Shared tail of both loaders (previously duplicated verbatim in
    load_from_session_logs and load_from_sqlite).
    """
    results: list[dict[str, Any]] = []
    for group in _group_into_conversations(entries):
        conv = _conversation_to_sharegpt(group)
        if conv is not None:
            results.append(conv)
    return results


def load_from_session_logs(logs_dir: Path) -> list[dict[str, Any]]:
    """Load all session JSONL logs and return ShareGPT-formatted conversations.

    Unreadable files and malformed JSONL lines are skipped silently — the
    export is best-effort over whatever logs are parseable.
    """
    log_files = sorted(logs_dir.glob("session_*.jsonl"))
    if not log_files:
        return []

    all_entries: list[dict] = []
    for log_file in log_files:
        try:
            with open(log_file) as f:
                for line in f:
                    line = line.strip()
                    if line:
                        try:
                            all_entries.append(json.loads(line))
                        except json.JSONDecodeError:
                            continue
        except OSError:
            continue

    # Sort by timestamp for correct ordering across files (ISO-8601 strings
    # sort lexicographically; entries without a timestamp sort first).
    all_entries.sort(key=lambda e: e.get("timestamp", ""))

    return _entries_to_conversations(all_entries)


# ── SQLite fallback reader ───────────────────────────────────────────────────

def load_from_sqlite(db_path: Path) -> list[dict[str, Any]]:
    """Read chat.db and return ShareGPT-formatted conversations.

    Returns [] if the database is missing or unreadable.
    """
    if not db_path.exists():
        return []

    try:
        conn = sqlite3.connect(str(db_path))
        conn.row_factory = sqlite3.Row
        rows = conn.execute(
            "SELECT role, content, timestamp FROM chat_messages ORDER BY id"
        ).fetchall()
        conn.close()
    except sqlite3.Error:
        return []

    # Normalize rows into the same entry shape the session-log path uses.
    entries = [
        {
            "type": "message",
            "role": row["role"],
            "content": row["content"],
            "timestamp": row["timestamp"],
        }
        for row in rows
    ]

    return _entries_to_conversations(entries)
validate_output(output_path: Path) -> dict[str, Any]: + """Validate the exported JSONL and return stats.""" + if not output_path.exists(): + return {"error": "Output file not found"} + + total = 0 + with_tools = 0 + turn_counts: list[int] = [] + + with open(output_path) as f: + for line in f: + line = line.strip() + if not line: + continue + try: + obj = json.loads(line) + except json.JSONDecodeError: + continue + + total += 1 + turns = obj.get("conversations", []) + turn_counts.append(len(turns)) + + has_tool = any( + t.get("from") == "tool" or t.get("tool_calls") + for t in turns + ) + if has_tool: + with_tools += 1 + + avg_turns = sum(turn_counts) / len(turn_counts) if turn_counts else 0 + + return { + "total_conversations": total, + "with_tool_calls": with_tools, + "avg_turns_per_conversation": round(avg_turns, 1), + "output_path": str(output_path), + } + + +# ── Main ───────────────────────────────────────────────────────────────────── + +def build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser( + description="Export Timmy conversation trajectories to ShareGPT JSONL", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + p.add_argument( + "--logs-dir", + type=Path, + default=DEFAULT_LOGS_DIR, + help="Directory containing session_*.jsonl files", + ) + p.add_argument( + "--db", + type=Path, + default=DEFAULT_DB_PATH, + help="Path to chat.db (used if no session logs found)", + ) + p.add_argument( + "--output", + type=Path, + default=DEFAULT_OUTPUT, + help="Output JSONL file path", + ) + p.add_argument( + "--gap-minutes", + type=int, + default=CONVERSATION_GAP_MINUTES, + help="Time gap (minutes) between entries that marks a new conversation", + ) + p.add_argument( + "--validate-only", + action="store_true", + help="Skip export; just validate an existing output file", + ) + p.add_argument( + "--min-examples", + type=int, + default=0, + help="Exit non-zero if fewer than this many examples are exported", + ) + return p + + +def 
main(argv: list[str] | None = None) -> int: + args = build_parser().parse_args(argv) + + if args.validate_only: + stats = validate_output(args.output) + print(json.dumps(stats, indent=2)) + return 0 + + # ── Load conversations ─────────────────────────────────────────────────── + print(f"[1/3] Loading from session logs: {args.logs_dir}") + conversations = load_from_session_logs(args.logs_dir) + + if not conversations: + print(f"[1/3] No session logs found — falling back to SQLite: {args.db}") + conversations = load_from_sqlite(args.db) + + if not conversations: + print( + "WARNING: No conversation data found.\n" + " • Run the dashboard and have some conversations first.\n" + " • Session logs are written to logs/session_YYYY-MM-DD.jsonl\n" + " • Chat history is stored in data/chat.db", + file=sys.stderr, + ) + # Still write empty file so downstream steps don't error on missing file + args.output.parent.mkdir(parents=True, exist_ok=True) + args.output.write_text("") + return 0 + + # ── Write output ───────────────────────────────────────────────────────── + print(f"[2/3] Writing {len(conversations)} conversations → {args.output}") + args.output.parent.mkdir(parents=True, exist_ok=True) + with open(args.output, "w") as f: + for conv in conversations: + f.write(json.dumps(conv) + "\n") + + # ── Validate ───────────────────────────────────────────────────────────── + print("[3/3] Validating output…") + stats = validate_output(args.output) + print(json.dumps(stats, indent=2)) + + if args.min_examples and stats.get("total_conversations", 0) < args.min_examples: + print( + f"ERROR: Only {stats['total_conversations']} examples exported " + f"(need ≥ {args.min_examples})", + file=sys.stderr, + ) + return 1 + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/scripts/test_export_trajectories.py b/tests/scripts/test_export_trajectories.py new file mode 100644 index 00000000..476ecfd2 --- /dev/null +++ b/tests/scripts/test_export_trajectories.py @@ 
"""Unit tests for scripts/export_trajectories.py."""

from __future__ import annotations

import json
import sqlite3
from datetime import datetime, timedelta
from pathlib import Path

# NOTE(review): `pytest` is not referenced directly below (tmp_path/capsys
# are injected fixtures) — confirm whether this import is intentional
# before removing it.
import pytest

import scripts.export_trajectories as et


# ── Helpers ──────────────────────────────────────────────────────────────────

def _ts(base: datetime, offset_minutes: int = 0) -> str:
    """Return an ISO-8601 timestamp *offset_minutes* after *base*."""
    return (base + timedelta(minutes=offset_minutes)).isoformat()


# Fixed reference instant shared by all tests (naive, deterministic).
BASE = datetime(2026, 3, 1, 10, 0, 0)


def _make_session_entries(base: datetime = BASE) -> list[dict]:
    """Minimal session log entries: user → tool_call → timmy reply."""
    return [
        {"type": "message", "role": "user", "content": "list my files", "timestamp": _ts(base, 0)},
        {"type": "tool_call", "tool": "shell", "args": {"cmd": "ls"}, "result": "a.py\nb.py", "timestamp": _ts(base, 1)},
        {"type": "message", "role": "timmy", "content": "You have two files.", "timestamp": _ts(base, 2)},
    ]


# ── _group_into_conversations ─────────────────────────────────────────────────

class TestGroupIntoConversations:
    """Grouping flat entry lists into time-gap-delimited windows."""

    def test_empty(self):
        assert et._group_into_conversations([]) == []

    def test_single_group_no_gap(self):
        entries = _make_session_entries()
        groups = et._group_into_conversations(entries, gap_minutes=30)
        assert len(groups) == 1
        assert groups[0] == entries

    def test_split_on_large_gap(self):
        entries_a = _make_session_entries(BASE)
        # Second set starts 60 minutes later
        entries_b = _make_session_entries(BASE + timedelta(hours=1))
        groups = et._group_into_conversations(entries_a + entries_b, gap_minutes=30)
        assert len(groups) == 2
        assert len(groups[0]) == 3
        assert len(groups[1]) == 3

    def test_no_split_within_gap(self):
        entries = _make_session_entries()
        groups = et._group_into_conversations(entries, gap_minutes=60)
        assert len(groups) == 1

    def test_entries_without_timestamp(self):
        # Timestamp-less entries must never trigger a split.
        entries = [
            {"type": "message", "role": "user", "content": "hello"},
            {"type": "message", "role": "timmy", "content": "hi"},
        ]
        groups = et._group_into_conversations(entries, gap_minutes=30)
        assert len(groups) == 1


# ── _conversation_to_sharegpt ─────────────────────────────────────────────────

class TestConversationToSharegpt:
    """Conversion of entry windows into ShareGPT conversation dicts."""

    def test_basic_exchange(self):
        entries = _make_session_entries()
        result = et._conversation_to_sharegpt(entries)
        assert result is not None
        turns = result["conversations"]

        human_turns = [t for t in turns if t["from"] == "human"]
        gpt_turns = [t for t in turns if t["from"] == "gpt"]
        tool_turns = [t for t in turns if t["from"] == "tool"]

        assert len(human_turns) == 1
        assert len(gpt_turns) == 1
        assert len(tool_turns) == 1

    def test_tool_calls_attached_to_gpt_turn(self):
        entries = [
            {"type": "message", "role": "user", "content": "run ls", "timestamp": _ts(BASE, 0)},
            {"type": "tool_call", "tool": "shell", "args": {}, "result": "ok", "timestamp": _ts(BASE, 1)},
            {"type": "message", "role": "timmy", "content": "done", "timestamp": _ts(BASE, 2)},
        ]
        result = et._conversation_to_sharegpt(entries)
        assert result is not None
        gpt_turns = [t for t in result["conversations"] if t["from"] == "gpt"]
        assert len(gpt_turns) == 1
        assert "tool_calls" in gpt_turns[0]
        assert gpt_turns[0]["tool_calls"][0]["name"] == "shell"

    def test_too_short_returns_none(self):
        # Only one meaningful turn → not useful for training
        entries = [{"type": "message", "role": "user", "content": "hi", "timestamp": _ts(BASE)}]
        assert et._conversation_to_sharegpt(entries) is None

    def test_empty_content_skipped(self):
        entries = [
            {"type": "message", "role": "user", "content": "", "timestamp": _ts(BASE, 0)},
            {"type": "message", "role": "timmy", "content": "pong", "timestamp": _ts(BASE, 1)},
        ]
        # Only one non-empty turn → should return None
        assert et._conversation_to_sharegpt(entries) is None

    def test_role_mapping(self):
        entries = [
            {"type": "message", "role": "user", "content": "q", "timestamp": _ts(BASE, 0)},
            {"type": "message", "role": "assistant", "content": "a", "timestamp": _ts(BASE, 1)},
        ]
        result = et._conversation_to_sharegpt(entries)
        assert result is not None
        roles = [t["from"] for t in result["conversations"]]
        assert "human" in roles
        assert "gpt" in roles

    def test_decision_entries_ignored(self):
        """Non-message, non-tool entries (decisions, errors) should be skipped."""
        entries = _make_session_entries() + [
            {"type": "decision", "decision": "do something", "timestamp": _ts(BASE, 10)},
        ]
        result = et._conversation_to_sharegpt(entries)
        assert result is not None
        assert all(t["from"] != "decision" for t in result["conversations"])


# ── load_from_session_logs ────────────────────────────────────────────────────

class TestLoadFromSessionLogs:
    """Reading session_*.jsonl files from a logs directory."""

    def test_empty_directory(self, tmp_path):
        assert et.load_from_session_logs(tmp_path) == []

    def test_missing_directory(self, tmp_path):
        assert et.load_from_session_logs(tmp_path / "nonexistent") == []

    def test_reads_single_log(self, tmp_path):
        entries = _make_session_entries()
        log = tmp_path / "session_2026-03-01.jsonl"
        log.write_text("\n".join(json.dumps(e) for e in entries) + "\n")

        result = et.load_from_session_logs(tmp_path)
        assert len(result) == 1
        assert result[0]["conversations"][0]["from"] == "human"

    def test_reads_multiple_logs(self, tmp_path):
        for day in range(3):
            entries = _make_session_entries(BASE + timedelta(days=day, hours=2 * day))
            log = tmp_path / f"session_2026-03-0{day + 1}.jsonl"
            log.write_text("\n".join(json.dumps(e) for e in entries) + "\n")

        result = et.load_from_session_logs(tmp_path)
        # 3 log files, each a separate conversation (days apart)
        assert len(result) == 3

    def test_skips_malformed_lines(self, tmp_path):
        log = tmp_path / "session_2026-03-01.jsonl"
        entries = _make_session_entries()
        lines = [json.dumps(e) for e in entries]
        lines.insert(1, "not valid json{{{")
        log.write_text("\n".join(lines) + "\n")

        # Should still parse valid entries
        result = et.load_from_session_logs(tmp_path)
        assert len(result) == 1


# ── load_from_sqlite ──────────────────────────────────────────────────────────

class TestLoadFromSqlite:
    """Fallback reader over the chat.db SQLite schema."""

    def _make_db(self, tmp_path: Path, rows: list[tuple]) -> Path:
        # Build a throwaway chat.db matching the production schema.
        db = tmp_path / "chat.db"
        conn = sqlite3.connect(str(db))
        conn.execute("""
            CREATE TABLE IF NOT EXISTS chat_messages (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                role TEXT, content TEXT, timestamp TEXT, source TEXT
            )
        """)
        conn.executemany(
            "INSERT INTO chat_messages (role, content, timestamp, source) VALUES (?,?,?,?)",
            rows,
        )
        conn.commit()
        conn.close()
        return db

    def test_missing_db(self, tmp_path):
        assert et.load_from_sqlite(tmp_path / "missing.db") == []

    def test_reads_conversation(self, tmp_path):
        rows = [
            ("user", "hello", _ts(BASE, 0), "browser"),
            ("agent", "hi there", _ts(BASE, 5), "browser"),
        ]
        db = self._make_db(tmp_path, rows)
        result = et.load_from_sqlite(db)
        assert len(result) == 1
        turns = result[0]["conversations"]
        assert turns[0]["from"] == "human"
        assert turns[1]["from"] == "gpt"

    def test_splits_on_gap(self, tmp_path):
        rows = [
            ("user", "a", _ts(BASE, 0), "browser"),
            ("agent", "b", _ts(BASE, 5), "browser"),
            ("user", "c", _ts(BASE, 120), "browser"),  # 2h gap
            ("agent", "d", _ts(BASE, 125), "browser"),
        ]
        db = self._make_db(tmp_path, rows)
        result = et.load_from_sqlite(db)
        assert len(result) == 2


# ── validate_output ───────────────────────────────────────────────────────────

class TestValidateOutput:
    """Stats computed over an exported JSONL file."""

    def test_missing_file(self, tmp_path):
        stats = et.validate_output(tmp_path / "missing.jsonl")
        assert "error" in stats

    def test_counts_conversations(self, tmp_path):
        out = tmp_path / "out.jsonl"
        convs = [
            {"conversations": [{"from": "human", "value": "hi"}, {"from": "gpt", "value": "ho"}]},
            {"conversations": [{"from": "human", "value": "a"}, {"from": "gpt", "value": "b"}]},
        ]
        out.write_text("\n".join(json.dumps(c) for c in convs) + "\n")
        stats = et.validate_output(out)
        assert stats["total_conversations"] == 2
        assert stats["with_tool_calls"] == 0

    def test_counts_tool_calls(self, tmp_path):
        out = tmp_path / "out.jsonl"
        conv = {"conversations": [
            {"from": "human", "value": "run"},
            {"from": "gpt", "value": "ok", "tool_calls": [{"name": "shell", "arguments": {}}]},
            {"from": "tool", "value": "done", "tool": "shell"},
        ]}
        out.write_text(json.dumps(conv) + "\n")
        stats = et.validate_output(out)
        assert stats["with_tool_calls"] == 1


# ── CLI (main) ────────────────────────────────────────────────────────────────

class TestMain:
    """End-to-end CLI behavior via et.main(argv)."""

    def test_no_data_exits_0(self, tmp_path):
        out = tmp_path / "out.jsonl"
        code = et.main([
            "--logs-dir", str(tmp_path / "logs"),
            "--db", str(tmp_path / "missing.db"),
            "--output", str(out),
        ])
        assert code == 0
        # An empty output file is still written for downstream steps.
        assert out.exists()

    def test_exports_from_logs(self, tmp_path):
        logs = tmp_path / "logs"
        logs.mkdir()
        entries = _make_session_entries()
        (logs / "session_2026-03-01.jsonl").write_text(
            "\n".join(json.dumps(e) for e in entries) + "\n"
        )
        out = tmp_path / "out.jsonl"
        code = et.main([
            "--logs-dir", str(logs),
            "--db", str(tmp_path / "missing.db"),
            "--output", str(out),
        ])
        assert code == 0
        lines = [l for l in out.read_text().splitlines() if l.strip()]
        assert len(lines) == 1

    def test_validate_only(self, tmp_path, capsys):
        out = tmp_path / "out.jsonl"
        conv = {"conversations": [
            {"from": "human", "value": "x"},
            {"from": "gpt", "value": "y"},
        ]}
        out.write_text(json.dumps(conv) + "\n")
        code = et.main(["--validate-only", "--output", str(out)])
        assert code == 0
        captured = capsys.readouterr()
        stats = json.loads(captured.out)
        assert stats["total_conversations"] == 1

    def test_min_examples_fails(self, tmp_path):
        logs = tmp_path / "logs"
        logs.mkdir()
        entries = _make_session_entries()
        (logs / "session_2026-03-01.jsonl").write_text(
            "\n".join(json.dumps(e) for e in entries) + "\n"
        )
        out = tmp_path / "out.jsonl"
        code = et.main([
            "--logs-dir", str(logs),
            "--db", str(tmp_path / "missing.db"),
            "--output", str(out),
            "--min-examples", "100",
        ])
        assert code == 1