diff --git a/scripts/export_trajectories.py b/scripts/export_trajectories.py
new file mode 100644
index 00000000..8cdb98fd
--- /dev/null
+++ b/scripts/export_trajectories.py
@@ -0,0 +1,333 @@
+#!/usr/bin/env python3
+"""Export Timmy session logs as LoRA training data (ChatML JSONL).
+
+Reads session JSONL files written by ``SessionLogger`` and converts them into
+conversation pairs suitable for fine-tuning with ``mlx_lm.lora``.
+
+Output format — one JSON object per line::
+
+    {"messages": [
+        {"role": "system", "content": "<system prompt>"},
+        {"role": "user", "content": "<user message>"},
+        {"role": "assistant", "content": "<assistant response>"}
+    ]}
+
+Tool calls that appear between a user turn and the next assistant message are
+embedded in the assistant content using the Hermes 4 ``<tool_call>`` XML format
+so the fine-tuned model learns both when to call tools and what JSON to emit.
+
+Usage::
+
+    # Export all session logs (default paths)
+    python scripts/export_trajectories.py
+
+    # Custom source / destination
+    python scripts/export_trajectories.py \\
+        --logs-dir ~/custom-logs \\
+        --output ~/timmy-training-data.jsonl \\
+        --min-turns 2 \\
+        --verbose
+
+Epic: #1091 Project Bannerlord — AutoLoRA Sovereignty Loop (Step 3 of 7)
+Refs: #1103
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import logging
+import sys
+from pathlib import Path
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+# ── Constants ─────────────────────────────────────────────────────────────────
+
+TIMMY_SYSTEM_PROMPT = (
+    "You are Timmy, Alexander's personal AI agent running on a local Mac. "
+    "You are concise, direct, and action-oriented. "
+    "You have access to a broad set of tools — use them proactively. "
+    "When you need to call a tool, output it in this format:\n"
+    "<tool_call>\n"
+    '{"name": "function_name", "arguments": {"param": "value"}}\n'
+    "</tool_call>\n\n"
+    "Always provide structured, accurate responses."
+)
+
+# ── Entry grouping ────────────────────────────────────────────────────────────
+
+
+def _load_entries(logs_dir: Path) -> list[dict[str, Any]]:
+    """Load all session log entries, sorted chronologically."""
+    entries: list[dict[str, Any]] = []
+    # Filenames embed the date, so sorted() yields chronological order.
+    log_files = sorted(logs_dir.glob("session_*.jsonl"))
+    for log_file in log_files:
+        try:
+            with open(log_file, encoding="utf-8") as f:  # JSONL is UTF-8
+                for line in f:
+                    line = line.strip()
+                    if not line:
+                        continue
+                    try:
+                        entries.append(json.loads(line))
+                    except json.JSONDecodeError:
+                        logger.warning("Skipping malformed line in %s", log_file.name)
+        except OSError as exc:
+            logger.warning("Cannot read %s: %s", log_file, exc)
+    return entries
+
+
+def _format_tool_call(entry: dict[str, Any]) -> str:
+    """Render a tool_call entry as a Hermes 4 <tool_call> XML block."""
+    payload = {"name": entry.get("tool", "unknown"), "arguments": entry.get("args", {})}
+    return f"<tool_call>\n{json.dumps(payload)}\n</tool_call>"
+
+
+def _format_tool_result(entry: dict[str, Any]) -> str:
+    """Render a tool result observation as a <tool_response> block."""
+    # Serialize the whole payload with json.dumps so a tool name containing
+    # quotes/backslashes cannot produce invalid JSON (the previous hand-built
+    # f-string only escaped the result, not the name).
+    payload = {"name": entry.get("tool", "unknown"), "result": entry.get("result", "")}
+    return f"<tool_response>\n{json.dumps(payload)}\n</tool_response>"
+
+
+def _group_into_turns(entries: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    """Group raw session entries into (user_text, assistant_parts) turn pairs.
+
+    Returns a list of dicts with keys:
+        ``user``      - user message content
+        ``assistant`` - assembled assistant content (responses + tool calls)
+    """
+    turns: list[dict[str, Any]] = []
+    pending_user: str | None = None
+    assistant_parts: list[str] = []
+
+    for entry in entries:
+        etype = entry.get("type", "")
+        role = entry.get("role", "")
+
+        if etype == "message" and role == "user":
+            # Flush any open turn
+            if pending_user is not None and assistant_parts:
+                turns.append(
+                    {
+                        "user": pending_user,
+                        "assistant": "\n".join(assistant_parts).strip(),
+                    }
+                )
+            elif pending_user is not None:
+                # User message with no assistant response — discard
+                pass
+            pending_user = entry.get("content", "").strip()
+            assistant_parts = []
+
+        elif etype == "message" and role == "timmy":
+            if pending_user is not None:
+                content = entry.get("content", "").strip()
+                if content:
+                    assistant_parts.append(content)
+
+        elif etype == "tool_call":
+            if pending_user is not None:
+                assistant_parts.append(_format_tool_call(entry))
+                # Also append tool result as context so model learns the full loop
+                if entry.get("result"):
+                    assistant_parts.append(_format_tool_result(entry))
+
+        # decision / error entries are skipped — they are meta-data, not conversation
+
+    # Flush final open turn
+    if pending_user is not None and assistant_parts:
+        turns.append(
+            {
+                "user": pending_user,
+                "assistant": "\n".join(assistant_parts).strip(),
+            }
+        )
+
+    return turns
+
+
+# ── Conversion ────────────────────────────────────────────────────────────────
+
+
+def turns_to_training_examples(
+    turns: list[dict[str, Any]],
+    system_prompt: str = TIMMY_SYSTEM_PROMPT,
+    min_assistant_len: int = 10,
+) -> list[dict[str, Any]]:
+    """Convert grouped turns into mlx-lm training examples.
+
+    Each example has a ``messages`` list in ChatML order:
+    ``[system, user, assistant]``.
+
+    Args:
+        turns: Output of ``_group_into_turns``.
+        system_prompt: System prompt prepended to every example.
+        min_assistant_len: Skip examples where the assistant turn is shorter
+            than this many characters (filters out empty/trivial turns).
+
+    Returns:
+        List of training example dicts.
+    """
+    examples: list[dict[str, Any]] = []
+    for turn in turns:
+        assistant_text = turn.get("assistant", "").strip()
+        user_text = turn.get("user", "").strip()
+        if not user_text or len(assistant_text) < min_assistant_len:
+            continue
+        examples.append(
+            {
+                "messages": [
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": user_text},
+                    {"role": "assistant", "content": assistant_text},
+                ]
+            }
+        )
+    return examples
+
+
+def export_training_data(
+    logs_dir: Path,
+    output_path: Path,
+    min_turns: int = 1,
+    min_assistant_len: int = 10,
+    verbose: bool = False,
+) -> int:
+    """Full export pipeline: load → group → convert → write.
+
+    Args:
+        logs_dir: Directory containing ``session_*.jsonl`` files.
+        output_path: Destination ``.jsonl`` file for training data.
+        min_turns: Minimum number of turns required (used for logging only).
+        min_assistant_len: Minimum assistant response length to include.
+        verbose: Print progress to stdout.
+
+    Returns:
+        Number of training examples written.
+    """
+    if verbose:
+        print(f"Loading session logs from: {logs_dir}")
+
+    entries = _load_entries(logs_dir)
+    if verbose:
+        print(f"  Loaded {len(entries)} raw entries")
+
+    turns = _group_into_turns(entries)
+    if verbose:
+        print(f"  Grouped into {len(turns)} conversation turns")
+
+    examples = turns_to_training_examples(
+        turns, min_assistant_len=min_assistant_len
+    )
+    if verbose:
+        print(f"  Generated {len(examples)} training examples")
+
+    if not examples:
+        print("WARNING: No training examples generated. Check that session logs exist.")
+        return 0
+
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    with open(output_path, "w", encoding="utf-8") as f:
+        for ex in examples:
+            f.write(json.dumps(ex) + "\n")
+
+    if verbose:
+        print(f"  Wrote {len(examples)} examples → {output_path}")
+
+    return len(examples)
+
+
+# ── CLI ───────────────────────────────────────────────────────────────────────
+
+
+def _default_logs_dir() -> Path:
+    """Return default logs directory (repo root / logs)."""
+    # Walk up from this script to find repo root (contains pyproject.toml)
+    candidate = Path(__file__).resolve().parent
+    for _ in range(5):
+        candidate = candidate.parent
+        if (candidate / "pyproject.toml").exists():
+            return candidate / "logs"
+    return Path.home() / "logs"
+
+
+def _default_output_path() -> Path:
+    return Path.home() / "timmy-training-data.jsonl"
+
+
+def main(argv: list[str] | None = None) -> int:
+    parser = argparse.ArgumentParser(
+        description="Export Timmy session logs as LoRA training data (ChatML JSONL)",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog=__doc__,
+    )
+    parser.add_argument(
+        "--logs-dir",
+        type=Path,
+        default=_default_logs_dir(),
+        help="Directory containing session_*.jsonl files (default: <repo>/logs)",
+    )
+    parser.add_argument(
+        "--output",
+        type=Path,
+        default=_default_output_path(),
+        help="Output JSONL path (default: ~/timmy-training-data.jsonl)",
+    )
+    parser.add_argument(
+        "--min-turns",
+        type=int,
+        default=1,
+        help="Minimum turns to process (informational, default: 1)",
+    )
+    parser.add_argument(
+        "--min-assistant-len",
+        type=int,
+        default=10,
+        help="Minimum assistant response length in chars (default: 10)",
+    )
+    parser.add_argument(
+        "--verbose",
+        "-v",
+        action="store_true",
+        help="Print progress information",
+    )
+
+    args = parser.parse_args(argv)
+
+    logging.basicConfig(
+        level=logging.DEBUG if args.verbose else logging.WARNING,
+        format="%(levelname)s: %(message)s",
+    )
+
+    if not args.logs_dir.exists():
+        print(f"ERROR: Logs directory not found: {args.logs_dir}")
+        print("Run the Timmy dashboard first to generate session logs.")
+        return 1
+
+    count = export_training_data(
+        logs_dir=args.logs_dir,
+        output_path=args.output,
+        min_turns=args.min_turns,
+        min_assistant_len=args.min_assistant_len,
+        verbose=args.verbose,
+    )
+
+    if count > 0:
+        print(f"Exported {count} training examples to: {args.output}")
+        print()
+        print("Next steps:")
+        print("  mkdir -p ~/timmy-lora-training")
+        print(f"  cp {args.output} ~/timmy-lora-training/train.jsonl")
+        print("  python scripts/lora_finetune.py --data ~/timmy-lora-training")
+    else:
+        print("No training examples exported.")
+        return 1
+
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/scripts/lora_finetune.py b/scripts/lora_finetune.py
new file mode 100644
index 00000000..049b1116
--- /dev/null
+++ b/scripts/lora_finetune.py
@@ -0,0 +1,399 @@
+#!/usr/bin/env python3
+"""LoRA fine-tuning launcher for Hermes 4 on Timmy trajectory data.
+
+Wraps ``mlx_lm.lora`` with project-specific defaults and pre-flight checks.
+Requires Apple Silicon (M-series) and the ``mlx-lm`` package.
+
+Usage::
+
+    # Minimal — uses defaults (expects data in ~/timmy-lora-training/)
+    python scripts/lora_finetune.py
+
+    # Custom model path and data
+    python scripts/lora_finetune.py \\
+        --model /path/to/hermes4-mlx \\
+        --data ~/timmy-lora-training \\
+        --iters 500 \\
+        --adapter-path ~/timmy-lora-adapter
+
+    # Dry run (print command, don't execute)
+    python scripts/lora_finetune.py --dry-run
+
+    # After training, test with the adapter
+    python scripts/lora_finetune.py --test \\
+        --prompt "List the open PRs on the Timmy Time Dashboard repo"
+
+    # Fuse adapter into base model for Ollama import
+    python scripts/lora_finetune.py --fuse \\
+        --save-path ~/timmy-fused-model
+
+Typical workflow::
+
+    # 1. Export trajectories
+    python scripts/export_trajectories.py --verbose
+
+    # 2. Prepare training dir
+    mkdir -p ~/timmy-lora-training
+    cp ~/timmy-training-data.jsonl ~/timmy-lora-training/train.jsonl
+
+    # 3. Fine-tune
+    python scripts/lora_finetune.py --verbose
+
+    # 4. Test
+    python scripts/lora_finetune.py --test
+
+    # 5. Fuse + import to Ollama
+    python scripts/lora_finetune.py --fuse
+    ollama create timmy-hermes4 -f Modelfile.timmy-hermes4
+
+Epic: #1091 Project Bannerlord — AutoLoRA Sovereignty Loop (Step 4 of 7)
+Refs: #1103
+"""
+
+from __future__ import annotations
+
+import argparse
+import platform
+import shutil
+import subprocess
+import sys
+from pathlib import Path
+
+# ── Defaults ──────────────────────────────────────────────────────────────────
+
+DEFAULT_DATA_DIR = Path.home() / "timmy-lora-training"
+DEFAULT_ADAPTER_PATH = Path.home() / "timmy-lora-adapter"
+DEFAULT_FUSED_PATH = Path.home() / "timmy-fused-model"
+
+# mlx-lm model path — local HuggingFace checkout of Hermes 4 in MLX format.
+# Set MLX_HERMES4_PATH env var or pass --model to override.
+DEFAULT_MODEL_PATH_ENV = "MLX_HERMES4_PATH"
+
+# Training hyperparameters (conservative for 36 GB M3 Max)
+DEFAULT_BATCH_SIZE = 1
+DEFAULT_LORA_LAYERS = 16
+DEFAULT_ITERS = 1000
+DEFAULT_LEARNING_RATE = 1e-5
+
+# Test prompt used after training
+DEFAULT_TEST_PROMPT = (
+    "List the open PRs on the Timmy Time Dashboard repo and triage them by priority."
+)
+
+
+# ── Pre-flight checks ─────────────────────────────────────────────────────────
+
+
+def _check_apple_silicon() -> bool:
+    """Return True if running on Apple Silicon."""
+    return platform.system() == "Darwin" and platform.machine() == "arm64"
+
+
+def _check_mlx_lm() -> bool:
+    """Return True if mlx-lm is installed and mlx_lm.lora is runnable."""
+    return shutil.which("mlx_lm.lora") is not None or _can_import("mlx_lm")
+
+
+def _can_import(module: str) -> bool:
+    try:
+        import importlib
+
+        importlib.import_module(module)
+        return True
+    except ImportError:
+        return False
+
+
+def _resolve_model_path(model_arg: str | None) -> str | None:
+    """Resolve model path from arg or environment variable."""
+    if model_arg:
+        return model_arg
+    import os
+
+    env_path = os.environ.get(DEFAULT_MODEL_PATH_ENV)
+    if env_path:
+        return env_path
+    return None
+
+
+def _preflight(model_path: str | None, data_dir: Path, verbose: bool) -> list[str]:
+    """Run pre-flight checks and return a list of warnings (empty = all OK)."""
+    warnings: list[str] = []
+
+    if not _check_apple_silicon():
+        warnings.append(
+            "Not running on Apple Silicon. mlx-lm requires an M-series Mac.\n"
+            "  Alternative: use Unsloth on Google Colab / RunPod / Modal."
+        )
+
+    if not _check_mlx_lm():
+        warnings.append(
+            "mlx-lm not found. Install with:\n  pip install mlx-lm"
+        )
+
+    if model_path is None:
+        warnings.append(
+            f"No model path specified. Set {DEFAULT_MODEL_PATH_ENV} or pass --model.\n"
+            "  Download Hermes 4 in MLX format from HuggingFace:\n"
+            "    https://huggingface.co/collections/NousResearch/hermes-4-collection-68a7\n"
+            "  or convert the GGUF:\n"
+            "    mlx_lm.convert --hf-path NousResearch/Hermes-4-14B --mlx-path ~/hermes4-mlx"
+        )
+    elif not Path(model_path).exists():
+        warnings.append(f"Model path does not exist: {model_path}")
+
+    train_file = data_dir / "train.jsonl"
+    if not train_file.exists():
+        warnings.append(
+            f"Training data not found: {train_file}\n"
+            "  Generate it with:\n"
+            "    python scripts/export_trajectories.py --verbose\n"
+            f"    mkdir -p {data_dir}\n"
+            f"    cp ~/timmy-training-data.jsonl {train_file}"
+        )
+
+    if verbose and not warnings:
+        print("Pre-flight checks: all OK")
+
+    return warnings
+
+
+# ── Command builders ──────────────────────────────────────────────────────────
+
+
+def _build_train_cmd(
+    model_path: str,
+    data_dir: Path,
+    adapter_path: Path,
+    batch_size: int,
+    lora_layers: int,
+    iters: int,
+    learning_rate: float,
+) -> list[str]:
+    return [
+        sys.executable, "-m", "mlx_lm.lora",
+        "--model", model_path,
+        "--train",
+        "--data", str(data_dir),
+        "--batch-size", str(batch_size),
+        "--lora-layers", str(lora_layers),
+        "--iters", str(iters),
+        "--learning-rate", str(learning_rate),
+        "--adapter-path", str(adapter_path),
+    ]
+
+
+def _build_test_cmd(
+    model_path: str,
+    adapter_path: Path,
+    prompt: str,
+) -> list[str]:
+    return [
+        sys.executable, "-m", "mlx_lm.generate",
+        "--model", model_path,
+        "--adapter-path", str(adapter_path),
+        "--prompt", prompt,
+        "--max-tokens", "512",
+    ]
+
+
+def _build_fuse_cmd(
+    model_path: str,
+    adapter_path: Path,
+    save_path: Path,
+) -> list[str]:
+    return [
+        sys.executable, "-m", "mlx_lm.fuse",
+        "--model", model_path,
+        "--adapter-path", str(adapter_path),
+        "--save-path", str(save_path),
+    ]
+
+
+# ── Runner ────────────────────────────────────────────────────────────────────
+
+
+def _run(cmd: list[str], dry_run: bool, verbose: bool) -> int:
+    """Print and optionally execute a command."""
+    print("\nCommand:")
+    print("  " + " \\\n    ".join(cmd))
+    if dry_run:
+        print("\n(dry-run — not executing)")
+        return 0
+
+    print()
+    result = subprocess.run(cmd)
+    return result.returncode
+
+
+# ── Main ──────────────────────────────────────────────────────────────────────
+
+
+def main(argv: list[str] | None = None) -> int:
+    parser = argparse.ArgumentParser(
+        description="LoRA fine-tuning launcher for Hermes 4 (AutoLoRA Step 4)",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog=__doc__,
+    )
+
+    # Mode flags (mutually exclusive-ish)
+    mode = parser.add_mutually_exclusive_group()
+    mode.add_argument(
+        "--test",
+        action="store_true",
+        help="Run inference test with trained adapter instead of training",
+    )
+    mode.add_argument(
+        "--fuse",
+        action="store_true",
+        help="Fuse adapter into base model (for Ollama import)",
+    )
+
+    # Paths
+    parser.add_argument(
+        "--model",
+        default=None,
+        help=f"Path to local MLX model (or set {DEFAULT_MODEL_PATH_ENV} env var)",
+    )
+    parser.add_argument(
+        "--data",
+        type=Path,
+        default=DEFAULT_DATA_DIR,
+        help=f"Training data directory (default: {DEFAULT_DATA_DIR})",
+    )
+    parser.add_argument(
+        "--adapter-path",
+        type=Path,
+        default=DEFAULT_ADAPTER_PATH,
+        help=f"LoRA adapter output path (default: {DEFAULT_ADAPTER_PATH})",
+    )
+    parser.add_argument(
+        "--save-path",
+        type=Path,
+        default=DEFAULT_FUSED_PATH,
+        help=f"Fused model output path (default: {DEFAULT_FUSED_PATH})",
+    )
+
+    # Hyperparameters
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=DEFAULT_BATCH_SIZE,
+        help=f"Training batch size (default: {DEFAULT_BATCH_SIZE}; reduce to 1 if OOM)",
+    )
+    parser.add_argument(
+        "--lora-layers",
+        type=int,
+        default=DEFAULT_LORA_LAYERS,
+        help=f"Number of LoRA layers (default: {DEFAULT_LORA_LAYERS}; reduce if OOM)",
+    )
+    parser.add_argument(
+        "--iters",
+        type=int,
+        default=DEFAULT_ITERS,
+        help=f"Training iterations (default: {DEFAULT_ITERS})",
+    )
+    parser.add_argument(
+        "--learning-rate",
+        type=float,
+        default=DEFAULT_LEARNING_RATE,
+        help=f"Learning rate (default: {DEFAULT_LEARNING_RATE})",
+    )
+
+    # Misc
+    parser.add_argument(
+        "--prompt",
+        default=DEFAULT_TEST_PROMPT,
+        help="Prompt for --test mode",
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        help="Print command without executing",
+    )
+    parser.add_argument(
+        "--verbose",
+        "-v",
+        action="store_true",
+        help="Print extra progress information",
+    )
+    parser.add_argument(
+        "--skip-preflight",
+        action="store_true",
+        help="Skip pre-flight checks (useful in CI)",
+    )
+
+    args = parser.parse_args(argv)
+    model_path = _resolve_model_path(args.model)
+
+    # ── Pre-flight ──────────────────────────────────────────────────────────
+    if not args.skip_preflight:
+        warnings = _preflight(model_path, args.data, args.verbose)
+        if warnings:
+            for w in warnings:
+                print(f"WARNING: {w}\n")
+            if not args.dry_run:
+                print("Aborting due to pre-flight warnings. Use --dry-run to see commands anyway.")
+                return 1
+
+    if model_path is None:
+        # Allow dry-run without a model for documentation purposes
+        model_path = ""
+
+    # ── Mode dispatch ───────────────────────────────────────────────────────
+    if args.test:
+        print(f"Testing fine-tuned model with adapter: {args.adapter_path}")
+        cmd = _build_test_cmd(model_path, args.adapter_path, args.prompt)
+        return _run(cmd, args.dry_run, args.verbose)
+
+    if args.fuse:
+        print(f"Fusing adapter {args.adapter_path} into base model → {args.save_path}")
+        cmd = _build_fuse_cmd(model_path, args.adapter_path, args.save_path)
+        rc = _run(cmd, args.dry_run, args.verbose)
+        if rc == 0 and not args.dry_run:
+            print(
+                f"\nFused model saved to: {args.save_path}\n"
+                "To import into Ollama:\n"
+                # NOTE(review): module docstring says Modelfile.timmy-hermes4 — confirm name
+                "  ollama create timmy-hermes4 -f Modelfile.hermes4-14b\n"
+                "  (edit Modelfile to point FROM to the fused GGUF path)"
+            )
+        return rc
+
+    # Default: train
+    print("Starting LoRA fine-tuning")
+    print(f"  Model:         {model_path}")
+    print(f"  Data:          {args.data}")
+    print(f"  Adapter path:  {args.adapter_path}")
+    print(f"  Iterations:    {args.iters}")
+    print(f"  Batch size:    {args.batch_size}")
+    print(f"  LoRA layers:   {args.lora_layers}")
+    print(f"  Learning rate: {args.learning_rate}")
+    print()
+    print("Estimated time: 2-8 hours on M3 Max (depends on dataset size).")
+    print("If OOM: reduce --lora-layers to 8 or --batch-size stays at 1.")
+
+    cmd = _build_train_cmd(
+        model_path=model_path,
+        data_dir=args.data,
+        adapter_path=args.adapter_path,
+        batch_size=args.batch_size,
+        lora_layers=args.lora_layers,
+        iters=args.iters,
+        learning_rate=args.learning_rate,
+    )
+    rc = _run(cmd, args.dry_run, args.verbose)
+
+    if rc == 0 and not args.dry_run:
+        print(
+            f"\nTraining complete! Adapter saved to: {args.adapter_path}\n"
+            "Test with:\n"
+            "  python scripts/lora_finetune.py --test\n"
+            "Then fuse + import to Ollama:\n"
+            "  python scripts/lora_finetune.py --fuse"
+        )
+
+    return rc
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/tests/scripts/test_export_trajectories.py b/tests/scripts/test_export_trajectories.py
new file mode 100644
index 00000000..f6ef580b
--- /dev/null
+++ b/tests/scripts/test_export_trajectories.py
@@ -0,0 +1,285 @@
+"""Unit tests for scripts/export_trajectories.py.
+
+Tests trajectory conversion logic — no I/O, no Ollama, no mlx.
+"""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+import pytest
+
+import scripts.export_trajectories as et
+
+
+# ── Fixtures ──────────────────────────────────────────────────────────────────
+
+
+@pytest.fixture()
+def simple_session(tmp_path: Path) -> Path:
+    """Write a minimal session JSONL file and return the logs dir."""
+    logs_dir = tmp_path / "logs"
+    logs_dir.mkdir()
+    entries = [
+        {"type": "message", "role": "user", "content": "What time is it?", "timestamp": "2026-03-01T10:00:00"},
+        {"type": "message", "role": "timmy", "content": "It is 10:00 AM.", "timestamp": "2026-03-01T10:00:01"},
+        {"type": "message", "role": "user", "content": "Thanks!", "timestamp": "2026-03-01T10:00:05"},
+        {"type": "message", "role": "timmy", "content": "You're welcome!", "timestamp": "2026-03-01T10:00:06"},
+    ]
+    session_file = logs_dir / "session_2026-03-01.jsonl"
+    session_file.write_text("\n".join(json.dumps(e) for e in entries) + "\n")
+    return logs_dir
+
+
+@pytest.fixture()
+def tool_call_session(tmp_path: Path) -> Path:
+    """Write a session JSONL with tool calls."""
+    logs_dir = tmp_path / "logs"
+    logs_dir.mkdir()
+    entries = [
+        {"type": "message", "role": "user", "content": "Read CLAUDE.md", "timestamp": "2026-03-01T10:00:00"},
+        {
+            "type": "tool_call",
+            "tool": "read_file",
+            "args": {"path": "CLAUDE.md"},
+            "result": "# CLAUDE.md content here",
+            "timestamp": "2026-03-01T10:00:01",
+        },
+        {"type": "message", "role": "timmy", "content": "Here is the content.", "timestamp": "2026-03-01T10:00:02"},
+    ]
+    session_file = logs_dir / "session_2026-03-01.jsonl"
+    session_file.write_text("\n".join(json.dumps(e) for e in entries) + "\n")
+    return logs_dir
+
+
+# ── _load_entries ─────────────────────────────────────────────────────────────
+
+
+@pytest.mark.unit
+def test_load_entries_returns_all(simple_session: Path) -> None:
+    entries = et._load_entries(simple_session)
+    assert len(entries) == 4
+
+
+@pytest.mark.unit
+def test_load_entries_skips_malformed(tmp_path: Path) -> None:
+    logs_dir = tmp_path / "logs"
+    logs_dir.mkdir()
+    session = logs_dir / "session_2026-03-01.jsonl"
+    session.write_text(
+        '{"type": "message", "role": "user", "content": "hi"}\n'
+        "NOT_JSON\n"
+        '{"type": "message", "role": "timmy", "content": "hello"}\n'
+    )
+    entries = et._load_entries(logs_dir)
+    assert len(entries) == 2  # malformed line skipped
+
+
+@pytest.mark.unit
+def test_load_entries_empty_dir(tmp_path: Path) -> None:
+    logs_dir = tmp_path / "logs"
+    logs_dir.mkdir()
+    entries = et._load_entries(logs_dir)
+    assert entries == []
+
+
+@pytest.mark.unit
+def test_load_entries_multiple_files(tmp_path: Path) -> None:
+    logs_dir = tmp_path / "logs"
+    logs_dir.mkdir()
+    for day in ("2026-03-01", "2026-03-02"):
+        entry = {"type": "message", "role": "user", "content": f"day {day}"}
+        (logs_dir / f"session_{day}.jsonl").write_text(json.dumps(entry) + "\n")
+    entries = et._load_entries(logs_dir)
+    assert len(entries) == 2
+
+
+# ── _format_tool_call ─────────────────────────────────────────────────────────
+
+
+@pytest.mark.unit
+def test_format_tool_call_structure() -> None:
+    entry = {
+        "type": "tool_call",
+        "tool": "read_file",
+        "args": {"path": "/tmp/foo.txt"},
+        "result": "file contents",
+    }
+    result = et._format_tool_call(entry)
+    assert result.startswith("<tool_call>")
+    assert result.endswith("</tool_call>")
+    payload = json.loads(result.split("\n")[1])
+    assert payload["name"] == "read_file"
+    assert payload["arguments"]["path"] == "/tmp/foo.txt"
+
+
+@pytest.mark.unit
+def test_format_tool_call_missing_tool() -> None:
+    entry = {"type": "tool_call", "args": {}}
+    result = et._format_tool_call(entry)
+    assert "unknown" in result
+
+
+# ── _group_into_turns ─────────────────────────────────────────────────────────
+
+
+@pytest.mark.unit
+def test_group_basic_conversation() -> None:
+    entries = [
+        {"type": "message", "role": "user", "content": "hello"},
+        {"type": "message", "role": "timmy", "content": "hi there"},
+        {"type": "message", "role": "user", "content": "bye"},
+        {"type": "message", "role": "timmy", "content": "goodbye"},
+    ]
+    turns = et._group_into_turns(entries)
+    assert len(turns) == 2
+    assert turns[0]["user"] == "hello"
+    assert turns[0]["assistant"] == "hi there"
+    assert turns[1]["user"] == "bye"
+    assert turns[1]["assistant"] == "goodbye"
+
+
+@pytest.mark.unit
+def test_group_with_tool_call() -> None:
+    entries = [
+        {"type": "message", "role": "user", "content": "check the file"},
+        {"type": "tool_call", "tool": "read_file", "args": {"path": "x"}, "result": "content"},
+        {"type": "message", "role": "timmy", "content": "Done."},
+    ]
+    turns = et._group_into_turns(entries)
+    assert len(turns) == 1
+    assert "<tool_call>" in turns[0]["assistant"]
+    assert "Done." in turns[0]["assistant"]
+
+
+@pytest.mark.unit
+def test_group_skips_user_without_response() -> None:
+    """User message with no timmy response should not create a turn."""
+    entries = [
+        {"type": "message", "role": "user", "content": "hello"},
+        # No timmy response
+        {"type": "message", "role": "user", "content": "are you there?"},
+        {"type": "message", "role": "timmy", "content": "Yes!"},
+    ]
+    turns = et._group_into_turns(entries)
+    assert len(turns) == 1
+    assert turns[0]["user"] == "are you there?"
+
+
+@pytest.mark.unit
+def test_group_ignores_errors_and_decisions() -> None:
+    entries = [
+        {"type": "message", "role": "user", "content": "hello"},
+        {"type": "error", "error": "something failed"},
+        {"type": "decision", "decision": "retry"},
+        {"type": "message", "role": "timmy", "content": "Got it."},
+    ]
+    turns = et._group_into_turns(entries)
+    assert len(turns) == 1
+    assert "error" not in turns[0]["assistant"]
+    assert "retry" not in turns[0]["assistant"]
+
+
+@pytest.mark.unit
+def test_group_empty_entries() -> None:
+    assert et._group_into_turns([]) == []
+
+
+# ── turns_to_training_examples ────────────────────────────────────────────────
+
+
+@pytest.mark.unit
+def test_training_examples_structure() -> None:
+    turns = [{"user": "hello", "assistant": "hi there, how can I help?"}]
+    examples = et.turns_to_training_examples(turns)
+    assert len(examples) == 1
+    msgs = examples[0]["messages"]
+    assert msgs[0]["role"] == "system"
+    assert msgs[1]["role"] == "user"
+    assert msgs[1]["content"] == "hello"
+    assert msgs[2]["role"] == "assistant"
+    assert msgs[2]["content"] == "hi there, how can I help?"
+
+
+@pytest.mark.unit
+def test_training_examples_filters_short_responses() -> None:
+    turns = [
+        {"user": "hello", "assistant": "ok"},  # too short
+        {"user": "hello", "assistant": "This is a longer response that passes."},
+    ]
+    examples = et.turns_to_training_examples(turns, min_assistant_len=10)
+    assert len(examples) == 1
+    assert examples[0]["messages"][2]["content"] == "This is a longer response that passes."
+
+
+@pytest.mark.unit
+def test_training_examples_filters_empty_user() -> None:
+    turns = [{"user": "", "assistant": "some response here"}]
+    examples = et.turns_to_training_examples(turns)
+    assert len(examples) == 0
+
+
+@pytest.mark.unit
+def test_training_examples_uses_custom_system_prompt() -> None:
+    turns = [{"user": "hi", "assistant": "hello there!"}]
+    examples = et.turns_to_training_examples(turns, system_prompt="Custom prompt.")
+    assert examples[0]["messages"][0]["content"] == "Custom prompt."
+
+
+# ── export_training_data (integration-style, uses tmp_path) ──────────────────
+
+
+@pytest.mark.unit
+def test_export_training_data_writes_jsonl(simple_session: Path, tmp_path: Path) -> None:
+    output = tmp_path / "train.jsonl"
+    count = et.export_training_data(logs_dir=simple_session, output_path=output)
+    assert count == 2
+    assert output.exists()
+    lines = [json.loads(raw) for raw in output.read_text().splitlines() if raw.strip()]
+    assert len(lines) == 2
+    for line in lines:
+        assert "messages" in line
+        roles = [m["role"] for m in line["messages"]]
+        assert roles == ["system", "user", "assistant"]
+
+
+@pytest.mark.unit
+def test_export_training_data_with_tool_calls(tool_call_session: Path, tmp_path: Path) -> None:
+    output = tmp_path / "train.jsonl"
+    count = et.export_training_data(logs_dir=tool_call_session, output_path=output)
+    assert count == 1
+    line = json.loads(output.read_text().strip())
+    assistant_content = line["messages"][2]["content"]
+    assert "<tool_call>" in assistant_content
+    assert "read_file" in assistant_content
+
+
+@pytest.mark.unit
+def test_export_training_data_returns_zero_for_empty_logs(tmp_path: Path) -> None:
+    logs_dir = tmp_path / "logs"
+    logs_dir.mkdir()
+    output = tmp_path / "train.jsonl"
+    count = et.export_training_data(logs_dir=logs_dir, output_path=output)
+    assert count == 0
+    assert not output.exists()
+
+
+# ── CLI ───────────────────────────────────────────────────────────────────────
+
+
+@pytest.mark.unit
+def test_cli_missing_logs_dir(tmp_path: Path) -> None:
+    rc = et.main(["--logs-dir", str(tmp_path / "nonexistent"), "--output", str(tmp_path / "out.jsonl")])
+    assert rc == 1
+
+
+@pytest.mark.unit
+def test_cli_exports_and_returns_zero(simple_session: Path, tmp_path: Path) -> None:
+    output = tmp_path / "out.jsonl"
+    rc = et.main([
+        "--logs-dir", str(simple_session),
+        "--output", str(output),
+    ])
+    assert rc == 0
+    assert output.exists()