Compare commits
2 Commits
gemini/iss
...
claude/iss
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a2f8989c39 | ||
| 852fec3681 |
726
poetry.lock
generated
726
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -68,7 +68,7 @@ voice = ["pyttsx3", "openai-whisper", "piper-tts", "sounddevice"]
|
||||
celery = ["celery"]
|
||||
embeddings = ["sentence-transformers", "numpy"]
|
||||
git = ["GitPython"]
|
||||
research = ["requests", "trafilatura"]
|
||||
research = ["requests", "trafilatura", "google-search-results"]
|
||||
dev = ["pytest", "pytest-asyncio", "pytest-cov", "pytest-timeout", "pytest-randomly", "pytest-xdist", "selenium"]
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
||||
358
scripts/export_trajectories.py
Normal file
358
scripts/export_trajectories.py
Normal file
@@ -0,0 +1,358 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Export Claude conversation trajectories to ShareGPT JSONL format for LoRA fine-tuning.
|
||||
|
||||
Reads from two sources (in priority order):
|
||||
1. logs/session_*.jsonl — rich logs with tool calls (preferred)
|
||||
2. data/chat.db — SQLite chat history (fallback)
|
||||
|
||||
Output is a ShareGPT-compatible JSONL file where each line is one conversation:
|
||||
{"conversations": [
|
||||
{"from": "human", "value": "..."},
|
||||
{"from": "gpt", "value": "...", "tool_calls": [...]},
|
||||
{"from": "tool", "value": "..."},
|
||||
{"from": "gpt", "value": "..."}
|
||||
]}
|
||||
|
||||
Epic: #1091 Project Bannerlord — AutoLoRA Sovereignty Loop (Step 3 of 7)
|
||||
Refs: #1102
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sqlite3
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
# ── Constants ────────────────────────────────────────────────────────────────

# Repository root: this script lives in <repo>/scripts/, so one parent up.
REPO_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_LOGS_DIR = REPO_ROOT / "logs"
DEFAULT_DB_PATH = REPO_ROOT / "data" / "chat.db"
# Exported training data defaults to the user's home directory.
DEFAULT_OUTPUT = Path.home() / "timmy-training-data.jsonl"

# Time gap that signals a new conversation boundary
CONVERSATION_GAP_MINUTES = 30

# Role mappings → ShareGPT "from" values
# NOTE: lookups fall back to "gpt" for unknown roles (see ROLE_MAP.get calls).
ROLE_MAP = {
    "user": "human",
    "timmy": "gpt",
    "agent": "gpt",
    "assistant": "gpt",
    "system": "system",
}
|
||||
|
||||
|
||||
# ── Session log reader ───────────────────────────────────────────────────────
|
||||
|
||||
def _parse_ts(ts: str) -> datetime | None:
|
||||
"""Parse an ISO timestamp string, returning None on failure."""
|
||||
try:
|
||||
return datetime.fromisoformat(ts)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
def _group_into_conversations(
    entries: list[dict],
    gap_minutes: int = CONVERSATION_GAP_MINUTES,
) -> list[list[dict]]:
    """Split a flat list of session entries into conversation windows.

    A new conversation starts whenever there is a gap ≥ *gap_minutes* between
    consecutive timestamped entries.  Entries whose timestamp is missing or
    unparseable never trigger a split; they are appended to the current
    window.

    (The docstring previously also claimed a split on "a user message after
    an agent reply" — the code has never implemented that, so the claim was
    removed to match actual behavior.)
    """
    if not entries:
        return []

    conversations: list[list[dict]] = []
    current: list[dict] = []
    last_ts: datetime | None = None

    for entry in entries:
        ts = _parse_ts(entry.get("timestamp", ""))

        # Only compare when both this entry and a previous one had timestamps.
        if last_ts is not None and ts is not None:
            gap = ts - last_ts
            if gap >= timedelta(minutes=gap_minutes):
                if current:
                    conversations.append(current)
                current = []

        current.append(entry)
        if ts is not None:
            last_ts = ts

    # Flush the trailing window.
    if current:
        conversations.append(current)

    return conversations
|
||||
|
||||
|
||||
def _conversation_to_sharegpt(entries: list[dict]) -> dict[str, Any] | None:
    """Convert a list of session entries into a ShareGPT conversation dict.

    Returns None when fewer than two human/gpt turns survive — too short to
    be a useful training example.
    """
    turns: list[dict[str, Any]] = []
    queued_calls: list[dict] = []

    for entry in entries:
        kind = entry.get("type")

        if kind == "message":
            text = entry.get("content", "")
            if not text:
                continue  # empty messages carry no training signal

            speaker = ROLE_MAP.get(entry.get("role", ""), "gpt")
            turn: dict[str, Any] = {"from": speaker, "value": text}

            # Any tool calls accumulated since the last reply ride along on
            # this gpt turn.
            if speaker == "gpt" and queued_calls:
                turn["tool_calls"] = queued_calls
                queued_calls = []

            turns.append(turn)

        elif kind == "tool_call":
            name = entry.get("tool", "unknown")

            # Remember the call so it can be attached to the next gpt turn…
            queued_calls.append({
                "name": name,
                "arguments": entry.get("args", {}),
            })

            # …and emit the tool's output as its own turn immediately.
            turns.append({
                "from": "tool",
                "value": str(entry.get("result", "")),
                "tool": name,
            })

    # Require at least one real exchange between human and model.
    meaningful_count = sum(1 for t in turns if t["from"] in ("human", "gpt"))
    if meaningful_count < 2:
        return None

    return {"conversations": turns}
|
||||
|
||||
|
||||
def load_from_session_logs(logs_dir: Path) -> list[dict[str, Any]]:
    """Load all session JSONL logs and return ShareGPT-formatted conversations."""
    files = sorted(logs_dir.glob("session_*.jsonl"))
    if not files:
        return []

    entries: list[dict] = []
    for path in files:
        try:
            raw_lines = path.read_text().splitlines()
        except OSError:
            continue  # unreadable file: skip rather than abort the export
        for raw in raw_lines:
            raw = raw.strip()
            if not raw:
                continue
            try:
                entries.append(json.loads(raw))
            except json.JSONDecodeError:
                continue  # tolerate the odd corrupt line

    # Interleave entries from all files chronologically.
    entries.sort(key=lambda e: e.get("timestamp", ""))

    converted: list[dict[str, Any]] = []
    for group in _group_into_conversations(entries):
        sharegpt = _conversation_to_sharegpt(group)
        if sharegpt is not None:
            converted.append(sharegpt)
    return converted
|
||||
|
||||
|
||||
# ── SQLite fallback reader ───────────────────────────────────────────────────
|
||||
|
||||
def load_from_sqlite(db_path: Path) -> list[dict[str, Any]]:
    """Read chat.db and return ShareGPT-formatted conversations.

    Returns [] when the database is missing or unreadable.
    """
    if not db_path.exists():
        return []

    try:
        conn = sqlite3.connect(str(db_path))
    except sqlite3.Error:
        return []
    try:
        conn.row_factory = sqlite3.Row
        rows = conn.execute(
            "SELECT role, content, timestamp FROM chat_messages ORDER BY id"
        ).fetchall()
    except sqlite3.Error:
        return []
    finally:
        # FIX: previously conn.close() was skipped when the query raised,
        # leaking the connection; finally guarantees it is released.
        conn.close()

    # Normalize rows into the same entry shape the session-log reader emits.
    entries = [
        {
            "type": "message",
            "role": row["role"],
            "content": row["content"],
            "timestamp": row["timestamp"],
        }
        for row in rows
    ]

    conversation_groups = _group_into_conversations(entries)
    results: list[dict[str, Any]] = []
    for group in conversation_groups:
        conv = _conversation_to_sharegpt(group)
        if conv is not None:
            results.append(conv)

    return results
|
||||
|
||||
|
||||
# ── Validation ───────────────────────────────────────────────────────────────
|
||||
|
||||
def validate_output(output_path: Path) -> dict[str, Any]:
    """Validate the exported JSONL and return summary statistics."""
    if not output_path.exists():
        return {"error": "Output file not found"}

    total = 0
    with_tools = 0
    turn_counts: list[int] = []

    for raw in output_path.read_text().splitlines():
        raw = raw.strip()
        if not raw:
            continue
        try:
            record = json.loads(raw)
        except json.JSONDecodeError:
            continue  # corrupt lines are skipped, not counted

        total += 1
        turns = record.get("conversations", [])
        turn_counts.append(len(turns))

        # A conversation "uses tools" when any turn is a tool result or
        # carries attached tool_calls.
        uses_tools = any(
            t.get("from") == "tool" or t.get("tool_calls") for t in turns
        )
        if uses_tools:
            with_tools += 1

    mean_turns = sum(turn_counts) / len(turn_counts) if turn_counts else 0

    return {
        "total_conversations": total,
        "with_tool_calls": with_tools,
        "avg_turns_per_conversation": round(mean_turns, 1),
        "output_path": str(output_path),
    }
|
||||
|
||||
|
||||
# ── Main ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser for the export script."""
    parser = argparse.ArgumentParser(
        description="Export Timmy conversation trajectories to ShareGPT JSONL",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add = parser.add_argument
    add("--logs-dir", type=Path, default=DEFAULT_LOGS_DIR,
        help="Directory containing session_*.jsonl files")
    add("--db", type=Path, default=DEFAULT_DB_PATH,
        help="Path to chat.db (used if no session logs found)")
    add("--output", type=Path, default=DEFAULT_OUTPUT,
        help="Output JSONL file path")
    add("--gap-minutes", type=int, default=CONVERSATION_GAP_MINUTES,
        help="Time gap (minutes) between entries that marks a new conversation")
    add("--validate-only", action="store_true",
        help="Skip export; just validate an existing output file")
    add("--min-examples", type=int, default=0,
        help="Exit non-zero if fewer than this many examples are exported")
    return parser
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: export → write JSONL → validate; returns exit code."""
    args = build_parser().parse_args(argv)

    # --validate-only short-circuits everything: just print stats and leave.
    if args.validate_only:
        stats = validate_output(args.output)
        print(json.dumps(stats, indent=2))
        return 0

    # ── Load conversations ───────────────────────────────────────────────
    print(f"[1/3] Loading from session logs: {args.logs_dir}")
    conversations = load_from_session_logs(args.logs_dir)

    # Session logs are preferred (they include tool calls); SQLite is the
    # fallback source.
    if not conversations:
        print(f"[1/3] No session logs found — falling back to SQLite: {args.db}")
        conversations = load_from_sqlite(args.db)

    if not conversations:
        print(
            "WARNING: No conversation data found.\n"
            " • Run the dashboard and have some conversations first.\n"
            " • Session logs are written to logs/session_YYYY-MM-DD.jsonl\n"
            " • Chat history is stored in data/chat.db",
            file=sys.stderr,
        )
        # Still write empty file so downstream steps don't error on missing file
        args.output.parent.mkdir(parents=True, exist_ok=True)
        args.output.write_text("")
        return 0

    # ── Write output ─────────────────────────────────────────────────────
    print(f"[2/3] Writing {len(conversations)} conversations → {args.output}")
    args.output.parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w") as f:
        for conv in conversations:
            f.write(json.dumps(conv) + "\n")

    # ── Validate ─────────────────────────────────────────────────────────
    print("[3/3] Validating output…")
    stats = validate_output(args.output)
    print(json.dumps(stats, indent=2))

    # Optional gate for CI / pipeline use: fail when the dataset is too small.
    if args.min_examples and stats.get("total_conversations", 0) < args.min_examples:
        print(
            f"ERROR: Only {stats['total_conversations']} examples exported "
            f"(need ≥ {args.min_examples})",
            file=sys.stderr,
        )
        return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
@@ -375,13 +375,21 @@ def _startup_init() -> None:
|
||||
|
||||
def _startup_background_tasks() -> list[asyncio.Task]:
    """Spawn all recurring background tasks (non-blocking).

    Returns the created tasks so the caller can hold references to them —
    an unreferenced asyncio task may be garbage collected.
    """
    # FIX: the scraped diff contained both the old `return [` and the new
    # `bg_tasks = [` lines; this is the intended merged version.
    bg_tasks = [
        asyncio.create_task(_briefing_scheduler()),
        asyncio.create_task(_thinking_scheduler()),
        asyncio.create_task(_loop_qa_scheduler()),
        asyncio.create_task(_presence_watcher()),
        asyncio.create_task(_start_chat_integrations_background()),
    ]

    # Paperclip is an optional integration: absence of the module is normal.
    try:
        from timmy.paperclip import start_paperclip_poller

        bg_tasks.append(asyncio.create_task(start_paperclip_poller()))
        logger.info("Paperclip poller started")
    except ImportError:
        logger.debug("Paperclip module not found, skipping poller")

    return bg_tasks
|
||||
|
||||
|
||||
def _try_prune(label: str, prune_fn, days: int) -> None:
|
||||
|
||||
175
src/timmy/paperclip.py
Normal file
175
src/timmy/paperclip.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""Paperclip integration for Timmy.
|
||||
|
||||
This module provides a client for the Paperclip API, and a poller for
|
||||
running research tasks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
import httpx
|
||||
|
||||
from config import settings
|
||||
from timmy.research_triage import triage_research_report
|
||||
from timmy.research_tools import google_web_search, get_llm_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class PaperclipTask:
    """A task from the Paperclip API."""

    # Server-assigned task identifier, echoed back on status updates.
    id: str
    # Task discriminator; the poller currently dispatches only "research".
    kind: str
    # Free-form payload; research tasks are expected to carry "issue_number".
    context: dict
|
||||
|
||||
|
||||
class PaperclipClient:
    """Thin async HTTP client for the Paperclip task API."""

    def __init__(self) -> None:
        self.base_url = settings.paperclip_url
        self.api_key = settings.paperclip_api_key
        self.agent_id = settings.paperclip_agent_id
        self.company_id = settings.paperclip_company_id
        self.timeout = settings.paperclip_timeout

    def _auth_headers(self) -> dict[str, str]:
        """Bearer-token auth header sent with every request."""
        return {"Authorization": f"Bearer {self.api_key}"}

    async def get_tasks(self) -> list[PaperclipTask]:
        """Fetch the queued tasks for this agent/company from the API."""
        query = {
            "agent_id": self.agent_id,
            "company_id": self.company_id,
            "status": "queued",
        }
        async with httpx.AsyncClient(timeout=self.timeout) as http:
            response = await http.get(
                f"{self.base_url}/api/tasks",
                headers=self._auth_headers(),
                params=query,
            )
            response.raise_for_status()
            payload = response.json()
        return [
            PaperclipTask(id=item["id"], kind=item["kind"], context=item["context"])
            for item in payload
        ]

    async def update_task_status(
        self, task_id: str, status: str, result: str | None = None
    ) -> None:
        """PATCH a task's status (and optional result) back to the API."""
        async with httpx.AsyncClient(timeout=self.timeout) as http:
            # NOTE(review): the response status is not checked here — a
            # failed update is silently ignored; confirm that is intended.
            await http.patch(
                f"{self.base_url}/api/tasks/{task_id}",
                headers=self._auth_headers(),
                json={"status": status, "result": result},
            )
|
||||
|
||||
|
||||
class ResearchOrchestrator:
    """Orchestrates research tasks: fetch issue → research → triage → comment."""

    async def get_gitea_issue(self, issue_number: int) -> dict:
        """Get a Gitea issue by its number.

        Raises httpx.HTTPStatusError on a non-2xx response.
        """
        owner, repo = settings.gitea_repo.split("/", 1)
        api_url = f"{settings.gitea_url}/api/v1/repos/{owner}/{repo}/issues/{issue_number}"
        async with httpx.AsyncClient(timeout=15) as client:
            resp = await client.get(
                api_url,
                headers={"Authorization": f"token {settings.gitea_token}"},
            )
            resp.raise_for_status()
            return resp.json()

    async def post_gitea_comment(self, issue_number: int, comment: str) -> None:
        """Post a comment to a Gitea issue.

        NOTE(review): the response status is not checked — a failed post is
        silently dropped; confirm that is intended.
        """
        owner, repo = settings.gitea_repo.split("/", 1)
        api_url = f"{settings.gitea_url}/api/v1/repos/{owner}/{repo}/issues/{issue_number}/comments"
        async with httpx.AsyncClient(timeout=15) as client:
            await client.post(
                api_url,
                headers={"Authorization": f"token {settings.gitea_token}"},
                json={"body": comment},
            )

    async def run_research_pipeline(self, issue_title: str) -> str:
        """Search the web for the issue title and summarize via the LLM."""
        search_results = await google_web_search(issue_title)

        llm_client = get_llm_client()
        # FIX: the prompt previously contained escaped "\\n" sequences, so the
        # LLM saw literal backslash-n instead of blank lines.
        response = await llm_client.completion(
            f"Summarize the following search results and generate a research report:\n\n{search_results}",
            max_tokens=2048,
        )
        return response.text

    async def run(self, context: dict) -> str:
        """Run one research task described by *context*; return a status string."""
        issue_number = context.get("issue_number")
        if not issue_number:
            return "Missing issue_number in task context"

        issue = await self.get_gitea_issue(issue_number)

        report = await self.run_research_pipeline(issue["title"])

        triage_results = await triage_research_report(report, source_issue=issue_number)

        # FIX: real newlines (the originals were escaped "\\n"), so the Gitea
        # comment renders as separate markdown lines instead of literal "\n".
        comment = f"Research complete for issue #{issue_number}.\n\n"
        if triage_results:
            comment += "Created the following issues:\n"
            for result in triage_results:
                if result["gitea_issue"]:
                    comment += f"- #{result['gitea_issue']['number']}: {result['action_item'].title}\n"
        else:
            comment += "No new issues were created.\n"

        await self.post_gitea_comment(issue_number, comment)

        return f"Research complete for issue #{issue_number}"
|
||||
|
||||
|
||||
class PaperclipPoller:
    """Polls the Paperclip API for new tasks and dispatches research work."""

    def __init__(self) -> None:
        self.client = PaperclipClient()
        self.orchestrator = ResearchOrchestrator()
        self.poll_interval = settings.paperclip_poll_interval

    async def poll(self) -> None:
        """Poll the Paperclip API forever, handling research tasks as they appear."""
        # An interval of 0 disables polling entirely.
        if self.poll_interval == 0:
            return

        while True:
            try:
                for task in await self.client.get_tasks():
                    if task.kind != "research":
                        continue  # other task kinds are not handled yet
                    await self.run_research_task(task)
            except httpx.HTTPError as exc:
                logger.warning("Error polling Paperclip: %s", exc)

            await asyncio.sleep(self.poll_interval)

    async def run_research_task(self, task: PaperclipTask) -> None:
        """Execute one research task, reporting status transitions to the API."""
        await self.client.update_task_status(task.id, "running")
        try:
            outcome = await self.orchestrator.run(task.context)
        except Exception as exc:
            logger.error("Error running research task: %s", exc, exc_info=True)
            await self.client.update_task_status(task.id, "failed", str(exc))
        else:
            await self.client.update_task_status(task.id, "completed", outcome)
|
||||
|
||||
|
||||
async def start_paperclip_poller() -> None:
    """Run the Paperclip poller if the integration is enabled.

    FIX: previously this spawned ``asyncio.create_task(poller.poll())`` and
    dropped the reference — a task whose only reference is discarded can be
    garbage collected mid-flight (documented asyncio pitfall).  The caller
    already wraps this coroutine in its own task, so awaiting the poll loop
    directly keeps it alive for the task's lifetime.
    """
    if settings.paperclip_enabled:
        poller = PaperclipPoller()
        await poller.poll()
|
||||
|
||||
42
src/timmy/research_tools.py
Normal file
42
src/timmy/research_tools.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""Tools for the research pipeline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from config import settings
|
||||
from serpapi import GoogleSearch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def google_web_search(query: str) -> str:
    """Perform a Google search via SerpAPI and return the raw result dict as text.

    Returns "" when SERPAPI_API_KEY is not configured.
    """
    api_key = os.environ.get("SERPAPI_API_KEY")
    if api_key is None:
        logger.warning("SERPAPI_API_KEY not set, skipping web search")
        return ""

    params = {
        "q": query,
        "api_key": api_key,
    }

    def _blocking_search() -> str:
        # GoogleSearch performs a synchronous HTTP request.
        return str(GoogleSearch(params).get_dict())

    # FIX: the blocking SerpAPI call previously ran directly inside this
    # coroutine, stalling the event loop for the whole HTTP round-trip.
    import asyncio  # local import keeps the module's import block unchanged

    return await asyncio.to_thread(_blocking_search)
|
||||
|
||||
|
||||
def get_llm_client() -> Any:
    """Return an LLM client exposing ``completion(prompt, max_tokens)``.

    Placeholder implementation: the returned mock echoes a canned summary
    string.  Swap in a real client (OpenAI, Anthropic, or a local model)
    for production use.
    """

    class _MockCompletion:
        def __init__(self, text: str) -> None:
            self.text = text

    class _MockLLMClient:
        async def completion(self, prompt: str, max_tokens: int) -> Any:
            return _MockCompletion(
                f"This is a summary of the search results for '{prompt}'."
            )

    return _MockLLMClient()
|
||||
306
tests/scripts/test_export_trajectories.py
Normal file
306
tests/scripts/test_export_trajectories.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""Unit tests for scripts/export_trajectories.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
import scripts.export_trajectories as et
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def _ts(base: datetime, offset_minutes: int = 0) -> str:
|
||||
return (base + timedelta(minutes=offset_minutes)).isoformat()
|
||||
|
||||
|
||||
# Fixed reference time shared by the fixtures below.
BASE = datetime(2026, 3, 1, 10, 0, 0)


def _make_session_entries(base: datetime = BASE) -> list[dict]:
    """Minimal session log entries: user → tool_call → timmy reply."""
    user_msg = {
        "type": "message", "role": "user",
        "content": "list my files", "timestamp": _ts(base, 0),
    }
    tool_call = {
        "type": "tool_call", "tool": "shell",
        "args": {"cmd": "ls"}, "result": "a.py\nb.py", "timestamp": _ts(base, 1),
    }
    reply = {
        "type": "message", "role": "timmy",
        "content": "You have two files.", "timestamp": _ts(base, 2),
    }
    return [user_msg, tool_call, reply]
|
||||
|
||||
|
||||
# ── _group_into_conversations ─────────────────────────────────────────────────
|
||||
|
||||
class TestGroupIntoConversations:
    """Behavior of the time-gap based conversation splitter."""

    def test_empty(self):
        assert et._group_into_conversations([]) == []

    def test_single_group_no_gap(self):
        entries = _make_session_entries()
        groups = et._group_into_conversations(entries, gap_minutes=30)
        assert len(groups) == 1
        assert groups[0] == entries

    def test_split_on_large_gap(self):
        entries_a = _make_session_entries(BASE)
        # Second set starts 60 minutes later
        entries_b = _make_session_entries(BASE + timedelta(hours=1))
        groups = et._group_into_conversations(entries_a + entries_b, gap_minutes=30)
        assert len(groups) == 2
        assert len(groups[0]) == 3
        assert len(groups[1]) == 3

    def test_no_split_within_gap(self):
        # The fixture entries are 1 minute apart, well under 60 minutes.
        entries = _make_session_entries()
        groups = et._group_into_conversations(entries, gap_minutes=60)
        assert len(groups) == 1

    def test_entries_without_timestamp(self):
        # Timestamp-less entries never trigger a split.
        entries = [
            {"type": "message", "role": "user", "content": "hello"},
            {"type": "message", "role": "timmy", "content": "hi"},
        ]
        groups = et._group_into_conversations(entries, gap_minutes=30)
        assert len(groups) == 1
|
||||
|
||||
|
||||
# ── _conversation_to_sharegpt ─────────────────────────────────────────────────
|
||||
|
||||
class TestConversationToSharegpt:
    """ShareGPT conversion: role mapping, tool attachment, minimum length."""

    def test_basic_exchange(self):
        entries = _make_session_entries()
        result = et._conversation_to_sharegpt(entries)
        assert result is not None
        turns = result["conversations"]

        human_turns = [t for t in turns if t["from"] == "human"]
        gpt_turns = [t for t in turns if t["from"] == "gpt"]
        tool_turns = [t for t in turns if t["from"] == "tool"]

        assert len(human_turns) == 1
        assert len(gpt_turns) == 1
        assert len(tool_turns) == 1

    def test_tool_calls_attached_to_gpt_turn(self):
        # A tool_call between user and reply should ride on the gpt turn.
        entries = [
            {"type": "message", "role": "user", "content": "run ls", "timestamp": _ts(BASE, 0)},
            {"type": "tool_call", "tool": "shell", "args": {}, "result": "ok", "timestamp": _ts(BASE, 1)},
            {"type": "message", "role": "timmy", "content": "done", "timestamp": _ts(BASE, 2)},
        ]
        result = et._conversation_to_sharegpt(entries)
        assert result is not None
        gpt_turns = [t for t in result["conversations"] if t["from"] == "gpt"]
        assert len(gpt_turns) == 1
        assert "tool_calls" in gpt_turns[0]
        assert gpt_turns[0]["tool_calls"][0]["name"] == "shell"

    def test_too_short_returns_none(self):
        # Only one meaningful turn → not useful for training
        entries = [{"type": "message", "role": "user", "content": "hi", "timestamp": _ts(BASE)}]
        assert et._conversation_to_sharegpt(entries) is None

    def test_empty_content_skipped(self):
        entries = [
            {"type": "message", "role": "user", "content": "", "timestamp": _ts(BASE, 0)},
            {"type": "message", "role": "timmy", "content": "pong", "timestamp": _ts(BASE, 1)},
        ]
        # Only one non-empty turn → should return None
        assert et._conversation_to_sharegpt(entries) is None

    def test_role_mapping(self):
        # "user" → "human" and "assistant" → "gpt".
        entries = [
            {"type": "message", "role": "user", "content": "q", "timestamp": _ts(BASE, 0)},
            {"type": "message", "role": "assistant", "content": "a", "timestamp": _ts(BASE, 1)},
        ]
        result = et._conversation_to_sharegpt(entries)
        assert result is not None
        roles = [t["from"] for t in result["conversations"]]
        assert "human" in roles
        assert "gpt" in roles

    def test_decision_entries_ignored(self):
        """Non-message, non-tool entries (decisions, errors) should be skipped."""
        entries = _make_session_entries() + [
            {"type": "decision", "decision": "do something", "timestamp": _ts(BASE, 10)},
        ]
        result = et._conversation_to_sharegpt(entries)
        assert result is not None
        assert all(t["from"] != "decision" for t in result["conversations"])
|
||||
|
||||
|
||||
# ── load_from_session_logs ────────────────────────────────────────────────────
|
||||
|
||||
class TestLoadFromSessionLogs:
    """Reading session_*.jsonl files: globbing, merging, malformed lines."""

    def test_empty_directory(self, tmp_path):
        assert et.load_from_session_logs(tmp_path) == []

    def test_missing_directory(self, tmp_path):
        # A nonexistent directory globs to nothing rather than raising.
        assert et.load_from_session_logs(tmp_path / "nonexistent") == []

    def test_reads_single_log(self, tmp_path):
        entries = _make_session_entries()
        log = tmp_path / "session_2026-03-01.jsonl"
        log.write_text("\n".join(json.dumps(e) for e in entries) + "\n")

        result = et.load_from_session_logs(tmp_path)
        assert len(result) == 1
        assert result[0]["conversations"][0]["from"] == "human"

    def test_reads_multiple_logs(self, tmp_path):
        for day in range(3):
            entries = _make_session_entries(BASE + timedelta(days=day, hours=2 * day))
            log = tmp_path / f"session_2026-03-0{day + 1}.jsonl"
            log.write_text("\n".join(json.dumps(e) for e in entries) + "\n")

        result = et.load_from_session_logs(tmp_path)
        # 3 log files, each a separate conversation (days apart)
        assert len(result) == 3

    def test_skips_malformed_lines(self, tmp_path):
        log = tmp_path / "session_2026-03-01.jsonl"
        entries = _make_session_entries()
        lines = [json.dumps(e) for e in entries]
        lines.insert(1, "not valid json{{{")
        log.write_text("\n".join(lines) + "\n")

        # Should still parse valid entries
        result = et.load_from_session_logs(tmp_path)
        assert len(result) == 1
|
||||
|
||||
|
||||
# ── load_from_sqlite ──────────────────────────────────────────────────────────
|
||||
|
||||
class TestLoadFromSqlite:
    """SQLite fallback reader: schema mirrors the app's chat_messages table."""

    def _make_db(self, tmp_path: Path, rows: list[tuple]) -> Path:
        # Build a throwaway chat.db with the expected schema and rows.
        db = tmp_path / "chat.db"
        conn = sqlite3.connect(str(db))
        conn.execute("""
            CREATE TABLE IF NOT EXISTS chat_messages (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                role TEXT, content TEXT, timestamp TEXT, source TEXT
            )
        """)
        conn.executemany(
            "INSERT INTO chat_messages (role, content, timestamp, source) VALUES (?,?,?,?)",
            rows,
        )
        conn.commit()
        conn.close()
        return db

    def test_missing_db(self, tmp_path):
        assert et.load_from_sqlite(tmp_path / "missing.db") == []

    def test_reads_conversation(self, tmp_path):
        rows = [
            ("user", "hello", _ts(BASE, 0), "browser"),
            ("agent", "hi there", _ts(BASE, 5), "browser"),
        ]
        db = self._make_db(tmp_path, rows)
        result = et.load_from_sqlite(db)
        assert len(result) == 1
        turns = result[0]["conversations"]
        assert turns[0]["from"] == "human"
        assert turns[1]["from"] == "gpt"

    def test_splits_on_gap(self, tmp_path):
        rows = [
            ("user", "a", _ts(BASE, 0), "browser"),
            ("agent", "b", _ts(BASE, 5), "browser"),
            ("user", "c", _ts(BASE, 120), "browser"),  # 2h gap
            ("agent", "d", _ts(BASE, 125), "browser"),
        ]
        db = self._make_db(tmp_path, rows)
        result = et.load_from_sqlite(db)
        assert len(result) == 2
|
||||
|
||||
|
||||
# ── validate_output ───────────────────────────────────────────────────────────
|
||||
|
||||
class TestValidateOutput:
    """Stats reported by validate_output over an exported JSONL file."""

    def test_missing_file(self, tmp_path):
        stats = et.validate_output(tmp_path / "missing.jsonl")
        assert "error" in stats

    def test_counts_conversations(self, tmp_path):
        out = tmp_path / "out.jsonl"
        convs = [
            {"conversations": [{"from": "human", "value": "hi"}, {"from": "gpt", "value": "ho"}]},
            {"conversations": [{"from": "human", "value": "a"}, {"from": "gpt", "value": "b"}]},
        ]
        out.write_text("\n".join(json.dumps(c) for c in convs) + "\n")
        stats = et.validate_output(out)
        assert stats["total_conversations"] == 2
        assert stats["with_tool_calls"] == 0

    def test_counts_tool_calls(self, tmp_path):
        # Either a tool turn or attached tool_calls marks the conversation.
        out = tmp_path / "out.jsonl"
        conv = {"conversations": [
            {"from": "human", "value": "run"},
            {"from": "gpt", "value": "ok", "tool_calls": [{"name": "shell", "arguments": {}}]},
            {"from": "tool", "value": "done", "tool": "shell"},
        ]}
        out.write_text(json.dumps(conv) + "\n")
        stats = et.validate_output(out)
        assert stats["with_tool_calls"] == 1
|
||||
|
||||
|
||||
# ── CLI (main) ────────────────────────────────────────────────────────────────
|
||||
|
||||
class TestMain:
    """End-to-end CLI behavior of main(): exit codes and output files."""

    def test_no_data_exits_0(self, tmp_path):
        # No logs and no db is a soft failure: exit 0 with an empty file.
        out = tmp_path / "out.jsonl"
        code = et.main([
            "--logs-dir", str(tmp_path / "logs"),
            "--db", str(tmp_path / "missing.db"),
            "--output", str(out),
        ])
        assert code == 0
        assert out.exists()

    def test_exports_from_logs(self, tmp_path):
        logs = tmp_path / "logs"
        logs.mkdir()
        entries = _make_session_entries()
        (logs / "session_2026-03-01.jsonl").write_text(
            "\n".join(json.dumps(e) for e in entries) + "\n"
        )
        out = tmp_path / "out.jsonl"
        code = et.main([
            "--logs-dir", str(logs),
            "--db", str(tmp_path / "missing.db"),
            "--output", str(out),
        ])
        assert code == 0
        lines = [l for l in out.read_text().splitlines() if l.strip()]
        assert len(lines) == 1

    def test_validate_only(self, tmp_path, capsys):
        # --validate-only prints the stats JSON and never writes.
        out = tmp_path / "out.jsonl"
        conv = {"conversations": [
            {"from": "human", "value": "x"},
            {"from": "gpt", "value": "y"},
        ]}
        out.write_text(json.dumps(conv) + "\n")
        code = et.main(["--validate-only", "--output", str(out)])
        assert code == 0
        captured = capsys.readouterr()
        stats = json.loads(captured.out)
        assert stats["total_conversations"] == 1

    def test_min_examples_fails(self, tmp_path):
        # A single exported conversation is below the 100 threshold → exit 1.
        logs = tmp_path / "logs"
        logs.mkdir()
        entries = _make_session_entries()
        (logs / "session_2026-03-01.jsonl").write_text(
            "\n".join(json.dumps(e) for e in entries) + "\n"
        )
        out = tmp_path / "out.jsonl"
        code = et.main([
            "--logs-dir", str(logs),
            "--db", str(tmp_path / "missing.db"),
            "--output", str(out),
            "--min-examples", "100",
        ])
        assert code == 1
|
||||
Reference in New Issue
Block a user