"""LoRA trainer — triggers fine-tune job and loads the resulting adapter. Supports two backends: 1. mlx-lm (default, Apple Silicon) — `mlx_lm.lora` CLI 2. Ollama create (adapter packaging into a new Ollama model) Graceful degradation: if neither backend is available, logs a warning and returns a skipped result — the rest of the loop continues. Refs: #1105 """ from __future__ import annotations import json import logging import os import shutil import subprocess from dataclasses import dataclass from datetime import UTC, datetime from pathlib import Path logger = logging.getLogger(__name__) _DEFAULT_BASE_MODEL = "hermes4-14b" _DEFAULT_ADAPTER_DIR = ".loop/retrain/adapters" _MLX_LM_BIN = "mlx_lm.lora" _OLLAMA_BIN = "ollama" @dataclass class TrainResult: """Result of a LoRA fine-tune run.""" status: str # "completed" | "skipped" | "failed" adapter_path: str | None model_name: str | None iteration: int duration_seconds: float message: str train_loss: float | None = None class LoRATrainer: """Orchestrates LoRA fine-tuning and adapter loading. Workflow: 1. Run mlx_lm.lora fine-tune on the training dataset 2. Save the resulting adapter to .loop/retrain/adapters// 3. Create (or update) an Ollama model that uses the new adapter """ def __init__( self, base_model: str = _DEFAULT_BASE_MODEL, adapter_dir: str | Path | None = None, repo_root: str | Path | None = None, dry_run: bool = False, ): if repo_root is None: repo_root = Path(__file__).resolve().parent.parent.parent self._repo_root = Path(repo_root) self._base_model = base_model self._adapter_dir = self._repo_root / (adapter_dir or _DEFAULT_ADAPTER_DIR) self._adapter_dir.mkdir(parents=True, exist_ok=True) self._dry_run = dry_run def train(self, dataset_path: Path, iteration: int) -> TrainResult: """Run LoRA fine-tuning on the dataset. Args: dataset_path: Path to the JSONL training dataset. iteration: Current fine-tune iteration number (used for naming). Returns: TrainResult with status, adapter path, and metrics. """ started = datetime.now(tz=UTC) if not dataset_path.exists() or dataset_path.stat().st_size == 0: return TrainResult( status="skipped", adapter_path=None, model_name=None, iteration=iteration, duration_seconds=0.0, message="Training dataset is empty — skipping fine-tune", ) if self._dry_run: logger.info("[dry-run] Would fine-tune %s on %s", self._base_model, dataset_path) adapter_path = self._adapter_dir / f"iter_{iteration:04d}" / "adapters.npz" return TrainResult( status="skipped", adapter_path=str(adapter_path), model_name=f"{self._base_model}-ft-{iteration:04d}", iteration=iteration, duration_seconds=0.0, message="dry-run mode — no training performed", ) # Determine which backend is available if shutil.which(_MLX_LM_BIN): return self._train_mlx(dataset_path, iteration, started) else: logger.warning( "%s not found — skipping LoRA fine-tune (install mlx-lm to enable)", _MLX_LM_BIN, ) return TrainResult( status="skipped", adapter_path=None, model_name=None, iteration=iteration, duration_seconds=0.0, message=( f"{_MLX_LM_BIN} not available. " "Install mlx-lm on Apple Silicon to enable LoRA fine-tuning." ), ) def _train_mlx( self, dataset_path: Path, iteration: int, started: datetime ) -> TrainResult: """Run mlx_lm.lora fine-tune.""" adapter_out = self._adapter_dir / f"iter_{iteration:04d}" adapter_out.mkdir(parents=True, exist_ok=True) cmd = [ _MLX_LM_BIN, "--model", self._base_model, "--data", str(dataset_path), "--adapter-path", str(adapter_out), "--train", "--iters", "100", "--batch-size", "1", "--learning-rate", "1e-5", ] logger.info("Starting mlx-lm LoRA fine-tune: iteration %d", iteration) logger.info("Command: %s", " ".join(cmd)) try: result = subprocess.run( cmd, capture_output=True, text=True, timeout=3600, # 1 hour max env={**os.environ, "PYTHONUNBUFFERED": "1"}, ) except subprocess.TimeoutExpired: duration = (datetime.now(tz=UTC) - started).total_seconds() return TrainResult( status="failed", adapter_path=None, model_name=None, iteration=iteration, duration_seconds=duration, message="Fine-tune timed out after 1 hour", ) except Exception as exc: duration = (datetime.now(tz=UTC) - started).total_seconds() return TrainResult( status="failed", adapter_path=None, model_name=None, iteration=iteration, duration_seconds=duration, message=f"Fine-tune subprocess error: {exc}", ) duration = (datetime.now(tz=UTC) - started).total_seconds() if result.returncode != 0: logger.error("mlx-lm fine-tune failed: %s", result.stderr[:500]) return TrainResult( status="failed", adapter_path=None, model_name=None, iteration=iteration, duration_seconds=duration, message=f"mlx_lm.lora exited {result.returncode}: {result.stderr[:300]}", ) # Parse final train loss from stdout if available train_loss = _parse_train_loss(result.stdout) adapter_file = adapter_out / "adapters.npz" model_name = f"{self._base_model}-ft-{iteration:04d}" # Attempt to register with Ollama ollama_ok = self._register_ollama_adapter(adapter_out, model_name) if not ollama_ok: logger.warning("Ollama adapter registration failed — adapter saved locally") logger.info( "Fine-tune complete: iteration=%d loss=%.4f duration=%.1fs adapter=%s", iteration, train_loss or 0.0, duration, adapter_file, ) return TrainResult( status="completed", adapter_path=str(adapter_file), model_name=model_name, iteration=iteration, duration_seconds=duration, message=f"LoRA fine-tune completed successfully in {duration:.0f}s", train_loss=train_loss, ) def _register_ollama_adapter(self, adapter_dir: Path, model_name: str) -> bool: """Create an Ollama model entry for the new adapter. Writes a minimal Modelfile and runs `ollama create`. """ if not shutil.which(_OLLAMA_BIN): logger.debug("Ollama not found — skipping adapter registration") return False modelfile_content = ( f"FROM {self._base_model}\n" f"ADAPTER {adapter_dir}\n" ) modelfile_path = adapter_dir / "Modelfile" try: modelfile_path.write_text(modelfile_content) result = subprocess.run( [_OLLAMA_BIN, "create", model_name, "-f", str(modelfile_path)], capture_output=True, text=True, timeout=300, ) if result.returncode == 0: logger.info("Ollama model registered: %s", model_name) return True else: logger.warning("ollama create failed: %s", result.stderr[:200]) return False except Exception as exc: logger.warning("Ollama adapter registration error: %s", exc) return False def _parse_train_loss(stdout: str) -> float | None: """Extract the final training loss from mlx-lm stdout.""" loss: float | None = None for line in stdout.splitlines(): line_lower = line.lower() if "train loss" in line_lower or "loss:" in line_lower: parts = line.split() for i, part in enumerate(parts): if "loss" in part.lower() and i + 1 < len(parts): try: loss = float(parts[i + 1].strip(",:")) except ValueError: pass return loss