"""LoRA trainer — triggers fine-tune job and loads the resulting adapter.

Supports two backends:

1. mlx-lm (default, Apple Silicon) — `mlx_lm.lora` CLI
2. Ollama create (adapter packaging into a new Ollama model)

Graceful degradation: if neither backend is available, logs a warning
and returns a skipped result — the rest of the loop continues.

Refs: #1105
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import shutil
|
|
import subprocess
|
|
from dataclasses import dataclass
|
|
from datetime import UTC, datetime
|
|
from pathlib import Path
|
|
|
|
# Module-level logger; handlers and levels are configured by the application.
logger = logging.getLogger(__name__)


# Default base model to fine-tune; also used as the Ollama `FROM` target
# when the adapter is registered.
_DEFAULT_BASE_MODEL = "hermes4-14b"

# Default repo-relative directory for per-iteration adapter output.
_DEFAULT_ADAPTER_DIR = ".loop/retrain/adapters"

# Executable name of the mlx-lm LoRA CLI (resolved via shutil.which).
_MLX_LM_BIN = "mlx_lm.lora"

# Executable name of the Ollama CLI (resolved via shutil.which).
_OLLAMA_BIN = "ollama"
|
|
|
|
|
|
@dataclass
class TrainResult:
    """Result of a LoRA fine-tune run."""

    # Outcome of the run: "completed" | "skipped" | "failed"
    status: str
    # Filesystem path to the saved adapter; None when no adapter exists
    # (e.g. empty dataset, missing backend, or a failed run).
    adapter_path: str | None
    # Name of the fine-tuned model (e.g. "<base>-ft-0001"); None when skipped/failed.
    model_name: str | None
    # Fine-tune iteration number this result belongs to.
    iteration: int
    # Wall-clock duration of the run in seconds (0.0 when skipped).
    duration_seconds: float
    # Human-readable summary of what happened.
    message: str
    # Final training loss parsed from mlx-lm stdout, when available.
    train_loss: float | None = None
|
|
|
|
|
|
class LoRATrainer:
    """Orchestrates LoRA fine-tuning and adapter loading.

    Workflow:
    1. Run mlx_lm.lora fine-tune on the training dataset
    2. Save the resulting adapter to .loop/retrain/adapters/<iteration>/
    3. Create (or update) an Ollama model that uses the new adapter
    """

    def __init__(
        self,
        base_model: str = _DEFAULT_BASE_MODEL,
        adapter_dir: str | Path | None = None,
        repo_root: str | Path | None = None,
        dry_run: bool = False,
        *,
        train_iters: int = 100,
        batch_size: int = 1,
        learning_rate: float = 1e-5,
        train_timeout_seconds: int = 3600,
    ):
        """Configure the trainer.

        Args:
            base_model: Base model to fine-tune; also used as the Ollama
                ``FROM`` target when registering the adapter.
            adapter_dir: Repo-relative directory for adapter output
                (defaults to ``_DEFAULT_ADAPTER_DIR``). Created if missing.
            repo_root: Repository root; defaults to three levels above
                this file.
            dry_run: If True, ``train`` logs what it would do and returns
                a skipped result without running any subprocess.
            train_iters: Number of mlx-lm training iterations.
            batch_size: mlx-lm training batch size.
            learning_rate: mlx-lm learning rate.
            train_timeout_seconds: Hard wall-clock limit for the
                fine-tune subprocess.
        """
        if repo_root is None:
            repo_root = Path(__file__).resolve().parent.parent.parent
        self._repo_root = Path(repo_root)

        self._base_model = base_model
        self._adapter_dir = self._repo_root / (adapter_dir or _DEFAULT_ADAPTER_DIR)
        self._adapter_dir.mkdir(parents=True, exist_ok=True)
        self._dry_run = dry_run
        self._train_iters = train_iters
        self._batch_size = batch_size
        self._learning_rate = learning_rate
        self._train_timeout = train_timeout_seconds

    def train(self, dataset_path: Path, iteration: int) -> TrainResult:
        """Run LoRA fine-tuning on the dataset.

        Args:
            dataset_path: Path to the JSONL training dataset.
            iteration: Current fine-tune iteration number (used for naming).

        Returns:
            TrainResult with status, adapter path, and metrics.
        """
        started = datetime.now(tz=UTC)

        # Nothing to learn from a missing or zero-byte dataset.
        if not dataset_path.exists() or dataset_path.stat().st_size == 0:
            return self._skipped(
                iteration, "Training dataset is empty — skipping fine-tune"
            )

        if self._dry_run:
            logger.info("[dry-run] Would fine-tune %s on %s", self._base_model, dataset_path)
            adapter_path = self._adapter_dir / f"iter_{iteration:04d}" / "adapters.npz"
            return TrainResult(
                status="skipped",
                adapter_path=str(adapter_path),
                model_name=f"{self._base_model}-ft-{iteration:04d}",
                iteration=iteration,
                duration_seconds=0.0,
                message="dry-run mode — no training performed",
            )

        # Graceful degradation: without the mlx-lm backend we skip, not fail,
        # so the rest of the loop keeps running.
        if shutil.which(_MLX_LM_BIN):
            return self._train_mlx(dataset_path, iteration, started)

        logger.warning(
            "%s not found — skipping LoRA fine-tune (install mlx-lm to enable)",
            _MLX_LM_BIN,
        )
        return self._skipped(
            iteration,
            f"{_MLX_LM_BIN} not available. "
            "Install mlx-lm on Apple Silicon to enable LoRA fine-tuning.",
        )

    @staticmethod
    def _skipped(iteration: int, message: str) -> TrainResult:
        """Build a 'skipped' result with no adapter and zero duration."""
        return TrainResult(
            status="skipped",
            adapter_path=None,
            model_name=None,
            iteration=iteration,
            duration_seconds=0.0,
            message=message,
        )

    def _train_mlx(
        self, dataset_path: Path, iteration: int, started: datetime
    ) -> TrainResult:
        """Run mlx_lm.lora fine-tune and register the resulting adapter."""
        adapter_out = self._adapter_dir / f"iter_{iteration:04d}"
        adapter_out.mkdir(parents=True, exist_ok=True)

        cmd = [
            _MLX_LM_BIN,
            "--model", self._base_model,
            "--data", str(dataset_path),
            "--adapter-path", str(adapter_out),
            "--train",
            "--iters", str(self._train_iters),
            "--batch-size", str(self._batch_size),
            "--learning-rate", str(self._learning_rate),
        ]

        logger.info("Starting mlx-lm LoRA fine-tune: iteration %d", iteration)
        logger.info("Command: %s", " ".join(cmd))

        def _failed(message: str) -> TrainResult:
            # Duration is measured from `started` so failures report real elapsed time.
            return TrainResult(
                status="failed",
                adapter_path=None,
                model_name=None,
                iteration=iteration,
                duration_seconds=(datetime.now(tz=UTC) - started).total_seconds(),
                message=message,
            )

        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=self._train_timeout,
                env={**os.environ, "PYTHONUNBUFFERED": "1"},
            )
        except subprocess.TimeoutExpired:
            return _failed(
                f"Fine-tune timed out after {self._train_timeout} seconds"
            )
        except Exception as exc:  # broad on purpose: any launch failure degrades gracefully
            logger.exception("Fine-tune subprocess failed to launch")
            return _failed(f"Fine-tune subprocess error: {exc}")

        duration = (datetime.now(tz=UTC) - started).total_seconds()

        if result.returncode != 0:
            logger.error("mlx-lm fine-tune failed: %s", result.stderr[:500])
            return _failed(
                f"mlx_lm.lora exited {result.returncode}: {result.stderr[:300]}"
            )

        # Parse final train loss from stdout if available.
        train_loss = _parse_train_loss(result.stdout)

        adapter_file = adapter_out / "adapters.npz"
        model_name = f"{self._base_model}-ft-{iteration:04d}"

        # Registration is best-effort: the adapter on disk is useful either way.
        ollama_ok = self._register_ollama_adapter(adapter_out, model_name)
        if not ollama_ok:
            logger.warning("Ollama adapter registration failed — adapter saved locally")

        logger.info(
            "Fine-tune complete: iteration=%d loss=%.4f duration=%.1fs adapter=%s",
            iteration,
            train_loss or 0.0,
            duration,
            adapter_file,
        )

        return TrainResult(
            status="completed",
            adapter_path=str(adapter_file),
            model_name=model_name,
            iteration=iteration,
            duration_seconds=duration,
            message=f"LoRA fine-tune completed successfully in {duration:.0f}s",
            train_loss=train_loss,
        )

    def _register_ollama_adapter(self, adapter_dir: Path, model_name: str) -> bool:
        """Create an Ollama model entry for the new adapter.

        Writes a minimal Modelfile and runs `ollama create`.

        Returns:
            True if the model was registered; False when Ollama is not
            installed, `ollama create` fails, or any filesystem/subprocess
            error occurs (registration is best-effort).
        """
        if not shutil.which(_OLLAMA_BIN):
            logger.debug("Ollama not found — skipping adapter registration")
            return False

        modelfile_content = (
            f"FROM {self._base_model}\n"
            f"ADAPTER {adapter_dir}\n"
        )
        modelfile_path = adapter_dir / "Modelfile"
        try:
            # Explicit encoding: default is locale-dependent.
            modelfile_path.write_text(modelfile_content, encoding="utf-8")
            result = subprocess.run(
                [_OLLAMA_BIN, "create", model_name, "-f", str(modelfile_path)],
                capture_output=True,
                text=True,
                timeout=300,
            )
        except Exception as exc:  # best-effort: never let registration kill the run
            logger.warning("Ollama adapter registration error: %s", exc)
            return False

        if result.returncode == 0:
            logger.info("Ollama model registered: %s", model_name)
            return True
        logger.warning("ollama create failed: %s", result.stderr[:200])
        return False
|
|
|
|
|
|
def _parse_train_loss(stdout: str) -> float | None:
|
|
"""Extract the final training loss from mlx-lm stdout."""
|
|
loss: float | None = None
|
|
for line in stdout.splitlines():
|
|
line_lower = line.lower()
|
|
if "train loss" in line_lower or "loss:" in line_lower:
|
|
parts = line.split()
|
|
for i, part in enumerate(parts):
|
|
if "loss" in part.lower() and i + 1 < len(parts):
|
|
try:
|
|
loss = float(parts[i + 1].strip(",:"))
|
|
except ValueError:
|
|
pass
|
|
return loss
|