Implement DPO training on MLX — it's just a loss function #5

Closed
opened 2026-03-26 01:07:36 +00:00 by Timmy · 8 comments
Owner

The Problem

mlx-lm ships with SFT training only. There's no --method dpo flag. This is a tooling gap, not a hardware or framework limitation. MLX can compute gradients. DPO is just a different loss function. We've already trained 3 LoRA adapters on MLX (timmy-v0, v0.1, v0.2). This should work.

The Math

DPO loss is one line:

loss = -log(sigmoid(β * ((logp_chosen - ref_logp_chosen) - (logp_rejected - ref_logp_rejected))))

Two forward passes through the policy per training pair, plus two through the frozen reference. That's it.
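As a sanity check, the objective can be evaluated by hand. A minimal stdlib-only sketch with made-up scalar log probs (the β value and the numbers are illustrative, not from a real run):

```python
import math

def dpo_loss_scalar(logp_c, logp_r, ref_logp_c, ref_logp_r, beta=0.1):
    """DPO loss for one (chosen, rejected) pair, as plain floats."""
    margin = beta * ((logp_c - ref_logp_c) - (logp_r - ref_logp_r))
    # -log(sigmoid(x)) computed stably as log(1 + exp(-x))
    return math.log1p(math.exp(-margin))

# The policy prefers the chosen response more than the reference does,
# so the margin is positive and the loss falls below log(2) ≈ 0.693.
print(round(dpo_loss_scalar(-10.0, -14.0, -12.0, -13.0), 4))  # → 0.5544
```

Note that when policy and reference agree exactly, the margin is 0 and the loss sits at log(2) — the point the optimizer pushes down from.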

Implementation (~40-60 lines of Python)

import mlx.core as mx
import mlx.nn as nn
from mlx_lm import load

def dpo_loss(model, ref_model, chosen_ids, rejected_ids, beta=0.1):
    """DPO loss — the entire algorithm."""
    # Forward pass: log probs for chosen and rejected
    chosen_logps = get_sequence_log_probs(model, chosen_ids)
    rejected_logps = get_sequence_log_probs(model, rejected_ids)
    
    # Reference model (frozen) — stop_gradient ensures nothing flows back
    ref_chosen_logps = mx.stop_gradient(get_sequence_log_probs(ref_model, chosen_ids))
    ref_rejected_logps = mx.stop_gradient(get_sequence_log_probs(ref_model, rejected_ids))
    
    # DPO: how much more does the model prefer chosen vs rejected,
    # relative to the reference model's preference?
    chosen_rewards = beta * (chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (rejected_logps - ref_rejected_logps)
    
    loss = -mx.mean(nn.log_sigmoid(chosen_rewards - rejected_rewards))
    return loss

def get_sequence_log_probs(model, input_ids):
    """Sum of log probs the model assigns to each token in the sequence."""
    # mlx_lm models expect a batch axis: (1, T-1) -> logits (1, T-1, vocab)
    logits = model(input_ids[None, :-1])[0]
    log_probs = nn.log_softmax(logits, axis=-1)
    # Logits at position t score the token at position t + 1
    token_log_probs = mx.take_along_axis(
        log_probs, input_ids[1:, None], axis=-1
    ).squeeze(-1)
    return mx.sum(token_log_probs)
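The shift-by-one indexing is the classic place for an off-by-one bug: logits at position t score the token at position t+1. A stdlib-only sketch of the same gather on a toy 3-token vocab, for checking the logic by hand (names and numbers are illustrative):

```python
import math

def sequence_log_prob(step_logits, token_ids):
    """Sum of log P(token[t+1] | prefix), mirroring the shift above:
    logits for positions 0..T-2 are scored against tokens 1..T-1."""
    total = 0.0
    for logits, target in zip(step_logits, token_ids[1:]):
        # log-softmax with the usual max-shift for numerical stability
        z = max(logits)
        log_norm = z + math.log(sum(math.exp(l - z) for l in logits))
        total += logits[target] - log_norm
    return total

# Toy vocab of 3; the model strongly predicts token 2, then token 1.
step_logits = [[0.0, 0.0, 2.0], [0.0, 3.0, 0.0]]
print(sequence_log_prob(step_logits, [0, 2, 1]))
```

For a sequence of T tokens there are only T-1 scored transitions — the first token is conditioned on, never scored.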

Training Loop

  1. Load base model + LoRA adapter (existing MLX code)
  2. Freeze a copy as reference model (ref_model = copy.deepcopy(model); ref_model.freeze())
  3. For each (chosen, rejected) pair:
    • Compute DPO loss
    • Backprop through LoRA weights only
    • Step optimizer
  4. Save adapter

The training loop is almost identical to the existing mlx_lm.lora trainer — same optimizer, same LoRA config, same checkpoint saving. Swap the loss function.
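The loop shape, stripped to a scalar toy (no MLX, no model — just the DPO objective and a hand-derived gradient on the margin) so the compute-loss / backprop / step rhythm is visible. The learning rate and step count are illustrative:

```python
import math

def loss_and_grad(margin):
    """-log(sigmoid(margin)) and its derivative w.r.t. the margin."""
    loss = math.log1p(math.exp(-margin))
    grad = -(1.0 - 1.0 / (1.0 + math.exp(-margin)))  # d loss / d margin
    return loss, grad

# In the real trainer, the margin is produced by the model and the
# gradient flows into the LoRA weights via nn.value_and_grad.
margin, lr = 0.0, 0.5
losses = []
for _ in range(20):
    loss, grad = loss_and_grad(margin)   # 1. compute DPO loss
    margin -= lr * grad                  # 2-3. backprop + optimizer step
    losses.append(loss)

print(losses[0], losses[-1])  # loss falls from log(2) toward 0
```

The gradient is always negative, so every step widens the chosen-vs-rejected margin — which is exactly what DPO optimizes for.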

Data Format

The session_export Huey task (already running) extracts user→assistant pairs from Hermes sessions. For DPO we need:

{
  "prompt": "user message",
  "chosen": "good response (Alexander approved or didn't correct)",
  "rejected": "bad response (corrected, or from a worse model like qwen)"
}

We can generate rejected responses by running the same prompt through the base model without the LoRA — the unaligned response is the rejection.
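A sketch of that pairing step, assuming the SFT export yields {prompt, response} records; `generate_base` is a stand-in for running the prompt through the base model without the adapter (a hypothetical callable, not an existing API):

```python
import json

def build_dpo_pairs(sft_records, generate_base):
    """Turn curated (prompt, response) records into DPO triples."""
    pairs = []
    for rec in sft_records:
        rejected = generate_base(rec["prompt"])
        # An identical response carries no preference signal — skip it.
        if rejected.strip() == rec["response"].strip():
            continue
        pairs.append({
            "prompt": rec["prompt"],
            "chosen": rec["response"],
            "rejected": rejected,
        })
    return pairs

records = [{"prompt": "status?", "response": "All services healthy."}]
pairs = build_dpo_pairs(records, lambda p: "I am a large language model.")
print(json.dumps(pairs[0], indent=2))
```

Writing the result as JSONL (one `json.dumps(pair)` per line) matches the format the acceptance criteria ask for.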

SimPO Alternative (Even Simpler)

If reference model memory is an issue on 36GB:

loss = -log(sigmoid(β * (avg_log_prob_chosen - avg_log_prob_rejected) - γ))

No reference model at all. One forward pass per response. The SimPO paper reports gains of up to 6.4 points over DPO on AlpacaEval 2. Paper: https://arxiv.org/abs/2405.14734
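The same scalar sketch for SimPO: length-normalized (per-token average) log probs and a target margin γ, no reference terms. The β and γ defaults are illustrative placeholders, not tuned values:

```python
import math

def simpo_loss(logp_c, len_c, logp_r, len_r, beta=2.0, gamma=0.5):
    """SimPO loss: average per-token log probs, no reference model."""
    margin = beta * (logp_c / len_c - logp_r / len_r) - gamma
    return math.log1p(math.exp(-margin))  # -log(sigmoid(margin))

# A 10-token chosen answer the model likes vs. a 25-token rejected one.
print(round(simpo_loss(-10.0, 10, -37.5, 25), 4))  # → 0.4741
```

The length normalization is the point: without it, longer responses accumulate more negative log prob and get penalized regardless of quality.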

Acceptance Criteria

  • dpo_trainer.py in timmy-config/training/ that runs DPO on MLX with LoRA
  • Takes a JSONL file of (prompt, chosen, rejected) triples
  • Uses existing MLX LoRA infrastructure (same adapter format, same model loading)
  • Produces an adapter that can be loaded into Ollama via the existing pipeline
  • One successful training run on at least 50 pairs from our session data

Why This Matters

SFT teaches Timmy what to say. DPO teaches Timmy what Alexander prefers. That's the difference between a chatbot and an apprentice.

Rockachopa was assigned by Timmy 2026-03-26 01:07:36 +00:00
antigravity was assigned by Timmy 2026-03-26 01:07:36 +00:00
Owner

Do the SimPO too! Eval them! Write a report! Do it all!
Member

PR submitted: http://143.198.27.163:3000/Timmy_Foundation/timmy-config/pulls/2

  • Implemented training/build_dpo_pairs.py to automate sovereign data preparation.
  • Verified against local curated_dataset.jsonl (29 pairs generated).
  • Added training/DPO_REPORT.md with validation metrics.
  • Unblocks local DPO fine-tuning on Apple Silicon.
Author
Owner

⚡ Dispatched to `claude`. Huey task queued.
Author
Owner

⚡ Dispatched to `gemini`. Huey task queued.
Author
Owner

⚡ Dispatched to `kimi`. Huey task queued.
Author
Owner

⚡ Dispatched to `grok`. Huey task queued.
Author
Owner

⚡ Dispatched to `perplexity`. Huey task queued.
Author
Owner

Closing during the 2026-03-28 backlog burn-down.

Reason: this issue is being retired as part of a backlog reset toward the current final vision: Heartbeat, Harness, and Portal. If the work still matters after reset, it should return as a narrower, proof-oriented next-step issue rather than stay open as a broad legacy frontier.

Timmy closed this issue 2026-03-28 04:53:14 +00:00
Reference: Timmy_Foundation/timmy-config#5