Files
hermes-config/skills/mlops/pytorch-lightning/references/callbacks.md
Alexander Whitestone 11cc14d707 init: Hermes config, skills, memories, cron
Sovereign backup of all Hermes Agent configuration and data.
Excludes: secrets, auth tokens, sessions, caches, code (separate repo).

Tracked:
- config.yaml (model, fallback chain, toolsets, display prefs)
- SOUL.md (Timmy personality charter)
- memories/ (persistent MEMORY.md + USER.md)
- skills/ (371 files — full skill library)
- cron/jobs.json (scheduled tasks)
- channel_directory.json (platform channels)
- hooks/ (custom hooks)
2026-03-14 14:42:33 -04:00

12 KiB

PyTorch Lightning Callbacks

Overview

Callbacks add functionality to training without modifying the LightningModule. They capture non-essential logic like checkpointing, early stopping, and logging.

Built-In Callbacks

1. ModelCheckpoint

Saves best models during training:

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint

# Save top 3 models based on validation loss.
# The monitored key ('val_loss') must match a metric name logged by the model.
checkpoint = ModelCheckpoint(
    dirpath='checkpoints/',
    filename='model-{epoch:02d}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    save_last=True,  # Also save last epoch
    verbose=True
)

# `L` is the conventional alias for the top-level `lightning` package
trainer = L.Trainer(callbacks=[checkpoint])
trainer.fit(model, train_loader, val_loader)

Configuration options:

checkpoint = ModelCheckpoint(
    filename="best-{epoch}-{val_acc:.3f}",  # Checkpoint file name pattern
    auto_insert_metric_name=False,          # Keep metric names out of the file name
    monitor="val_acc",                      # Quantity used to rank checkpoints
    mode="max",                             # "max" for accuracy, "min" for loss
    save_top_k=5,                           # Retain the 5 best checkpoints
    save_last=True,                         # Keep the last epoch as a separate checkpoint
    every_n_epochs=1,                       # Checkpoint interval, in epochs
    save_on_train_epoch_end=False,          # Save on validation end instead
)

Load checkpoint:

# Restore the weights of the best checkpoint tracked by the callback
best_model_path = checkpoint.best_model_path
model = LitModel.load_from_checkpoint(best_model_path)

# Continue an interrupted run from the most recent checkpoint
callbacks = [checkpoint]
trainer = L.Trainer(callbacks=callbacks)
trainer.fit(
    model,
    train_loader,
    val_loader,
    ckpt_path='checkpoints/last.ckpt',
)

2. EarlyStopping

Stops training when a monitored metric stops improving:

from lightning.pytorch.callbacks import EarlyStopping

early_stop = EarlyStopping(
    monitor='val_loss',
    patience=5,               # Tolerate 5 validation checks without improvement (not necessarily 5 epochs)
    mode='min',
    min_delta=0.001,          # Minimum change to qualify as improvement
    verbose=True,
    strict=True,              # Crash if monitored metric not found
    check_on_train_epoch_end=False  # Check on validation end
)

trainer = L.Trainer(callbacks=[early_stop])
trainer.fit(model, train_loader, val_loader)
# Stops automatically after 5 consecutive validation checks without improvement
# (this equals 5 epochs only when validation runs once per epoch)

Advanced usage:

early_stop = EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=10,
    min_delta=0.0,
    stopping_threshold=0.1,    # Stop if val_loss < 0.1
    divergence_threshold=5.0,  # Stop if val_loss > 5.0
    check_finite=True,         # Stop on NaN/Inf
    verbose=True,
)

3. LearningRateMonitor

Logs learning rate:

from lightning.pytorch.callbacks import LearningRateMonitor

# logging_interval may be 'epoch' or 'step'; log_momentum also records momentum
lr_monitor = LearningRateMonitor(log_momentum=True, logging_interval='epoch')

callbacks = [lr_monitor]
trainer = L.Trainer(callbacks=callbacks)
# Learning rate automatically logged to TensorBoard/WandB

4. TQDMProgressBar

Customizes progress bar:

from lightning.pytorch.callbacks import TQDMProgressBar

# refresh_rate=10 redraws the bar every 10 batches
progress_bar = TQDMProgressBar(process_position=0, refresh_rate=10)

callbacks = [progress_bar]
trainer = L.Trainer(callbacks=callbacks)

5. GradientAccumulationScheduler

Dynamic gradient accumulation:

from lightning.pytorch.callbacks import GradientAccumulationScheduler

# Accumulate fewer batches as training progresses (8 -> 4 -> 2).
# Each key is the epoch at which the given accumulation factor takes effect.
accumulator = GradientAccumulationScheduler(
    scheduling={
        0: 8,   # Epochs 0-4: accumulate 8 batches
        5: 4,   # Epochs 5-9: accumulate 4 batches
        10: 2   # Epochs 10+: accumulate 2 batches
    }
)

trainer = L.Trainer(callbacks=[accumulator])

6. StochasticWeightAveraging (SWA)

Averages weights for better generalization:

from lightning.pytorch.callbacks import StochasticWeightAveraging

swa = StochasticWeightAveraging(
    swa_epoch_start=0.8,       # Start at 80% of training
    swa_lrs=1e-2,              # SWA learning rate
    annealing_strategy="cos",  # "cos" or "linear"
    annealing_epochs=10,       # Annealing period
)

callbacks = [swa]
trainer = L.Trainer(callbacks=callbacks)

Custom Callbacks

Basic Custom Callback

from lightning.pytorch.callbacks import Callback

class PrintingCallback(Callback):
    """Print a message when training starts, after each training epoch, and when training ends."""

    def on_train_start(self, trainer, pl_module):
        print("Training is starting!")

    def on_train_end(self, trainer, pl_module):
        print("Training is done!")

    def on_train_epoch_end(self, trainer, pl_module):
        # `on_epoch_end` is not a Callback hook in Lightning 2.x and would never
        # be invoked; the per-epoch training hook is `on_train_epoch_end`.
        print(f"Epoch {trainer.current_epoch} ended")

# Register the callback with the Trainer
printing_callback = PrintingCallback()
trainer = L.Trainer(callbacks=[printing_callback])

Advanced Custom Callback

class MetricsCallback(Callback):
    """Record and log a custom metric once every N training batches."""

    def __init__(self, log_every_n_batches=100):
        self.metrics = []
        self.log_every_n_batches = log_every_n_batches

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Only act on batches that fall on the logging interval
        if batch_idx % self.log_every_n_batches != 0:
            return

        value = self.compute_metric(outputs)
        self.metrics.append(value)

        # Forward the value to Lightning's logging
        pl_module.log('custom_metric', value)

    def compute_metric(self, outputs):
        # Replace with your own logic; here the metric is the scalar training loss
        loss = outputs['loss']
        return loss.item()

    def state_dict(self):
        """Return the collected metrics so they are stored in the checkpoint."""
        return dict(metrics=self.metrics)

    def load_state_dict(self, state_dict):
        """Reload the collected metrics from a checkpoint."""
        self.metrics = state_dict['metrics']

Gradient Monitoring Callback

class GradientMonitorCallback(Callback):
    """Log the global L2 gradient norm after every backward pass."""

    def on_after_backward(self, trainer, pl_module):
        # Squared L2 norm of each parameter's gradient; parameters without a
        # gradient are skipped
        squared_norms = [
            p.grad.data.norm(2).item() ** 2
            for p in pl_module.parameters()
            if p.grad is not None
        ]
        total_norm = sum(squared_norms, 0.0) ** 0.5

        # Log
        pl_module.log('grad_norm', total_norm)

        # Warn if exploding
        exploding = total_norm > 100
        if exploding:
            print(f"Warning: Large gradient norm: {total_norm:.2f}")

Model Inspection Callback

class ModelInspectionCallback(Callback):
    """Inspect model activations on the first training batch of each epoch.

    Forward hooks are attached just before the first batch and removed again
    once their statistics have been logged. Without the removal, a new pair of
    hooks would be added every epoch and all of them would keep firing on
    every forward pass.
    """

    def __init__(self):
        super().__init__()
        self.activations = {}
        self._hook_handles = []

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if batch_idx != 0:  # Only the first batch of an epoch is inspected
            return

        # Drop any hooks left over from an interrupted earlier epoch
        self._remove_hooks()
        self.activations = {}

        def get_activation(name):
            def hook(module, inputs, output):
                self.activations[name] = output.detach()
            return hook

        # Attach to specific layers and keep the handles for later removal
        self._hook_handles = [
            pl_module.model.layer1.register_forward_hook(get_activation('layer1')),
            pl_module.model.layer2.register_forward_hook(get_activation('layer2')),
        ]

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if batch_idx != 0:
            return

        # Log activation statistics
        for name, activation in self.activations.items():
            pl_module.log(f'{name}_mean', activation.mean().item())
            pl_module.log(f'{name}_std', activation.std().item())

        # Detach the hooks so later batches run without the extra overhead
        self._remove_hooks()

    def _remove_hooks(self):
        """Remove every forward hook this callback has registered."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

Callback Hooks

All available hooks:

class MyCallback(Callback):
    """No-op callback listing the available hooks and their signatures."""

    # Setup/Teardown
    def setup(self, trainer, pl_module, stage):
        """Called at beginning of fit/test/predict."""
        pass

    def teardown(self, trainer, pl_module, stage):
        """Called at end of fit/test/predict."""
        pass

    # Training
    def on_train_start(self, trainer, pl_module):
        pass

    def on_train_epoch_start(self, trainer, pl_module):
        pass

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        pass

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        pass

    def on_train_epoch_end(self, trainer, pl_module):
        pass

    def on_train_end(self, trainer, pl_module):
        pass

    # Validation
    def on_validation_start(self, trainer, pl_module):
        pass

    def on_validation_epoch_start(self, trainer, pl_module):
        pass

    # NOTE: `dataloader_idx` defaults to 0 to match the base-class signature;
    # the argument may be omitted when only a single dataloader is used.
    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
        pass

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        pass

    def on_validation_epoch_end(self, trainer, pl_module):
        pass

    def on_validation_end(self, trainer, pl_module):
        pass

    # Test (same structure as validation)
    def on_test_start(self, trainer, pl_module):
        pass
    # ... (test_epoch_start, test_batch_start, etc.)

    # Predict
    def on_predict_start(self, trainer, pl_module):
        pass
    # ... (predict_epoch_start, predict_batch_start, etc.)

    # Backward
    def on_before_backward(self, trainer, pl_module, loss):
        pass

    def on_after_backward(self, trainer, pl_module):
        pass

    # Optimizer
    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        pass

    # Checkpointing
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        """Add data to checkpoint."""
        pass

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        """Restore data from checkpoint."""
        pass

Combining Multiple Callbacks

from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor

# Create all callbacks
checkpoint = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=3)
early_stop = EarlyStopping(monitor='val_loss', patience=5)
lr_monitor = LearningRateMonitor(logging_interval='epoch')
custom_callback = MyCallback()  # Any user-defined Callback subclass works here

# Add all to Trainer
trainer = L.Trainer(
    callbacks=[checkpoint, early_stop, lr_monitor, custom_callback]
)

trainer.fit(model, train_loader, val_loader)

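Execution order: Callbacks execute in the order they're added, with one exception: ModelCheckpoint callbacks are always moved to the end of the list and run last.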

Best Practices

1. Keep Callbacks Independent

Bad (dependent on other callback):

class BadCallback(Callback):
    """Anti-pattern: relies on another callback being configured."""

    def on_train_end(self, trainer, pl_module):
        # Assumes ModelCheckpoint is present
        checkpoint_cb = trainer.checkpoint_callback
        best_path = checkpoint_cb.best_model_path  # Fragile!

Good (self-contained):

class GoodCallback(Callback):
    """Look up the checkpoint callback defensively instead of assuming it exists."""

    def on_train_end(self, trainer, pl_module):
        # Use the first ModelCheckpoint if one is registered; otherwise fall
        # back to None so `best_path` is always bound.
        best_path = next(
            (
                callback.best_model_path
                for callback in trainer.callbacks
                if isinstance(callback, ModelCheckpoint)
            ),
            None,
        )

2. Use State Dict for Persistence

class StatefulCallback(Callback):
    """Count training batches and keep a loss history that is stored in checkpoints."""

    def __init__(self):
        self.history = []
        self.counter = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.counter += 1
        batch_loss = outputs['loss'].item()
        self.history.append(batch_loss)

    def state_dict(self):
        """Return the values to persist."""
        return dict(counter=self.counter, history=self.history)

    def load_state_dict(self, state_dict):
        """Reload the persisted values."""
        self.counter = state_dict['counter']
        self.history = state_dict['history']

3. Handle Distributed Training

class DistributedCallback(Callback):
    """Show rank-zero-only work next to work that runs on every process."""

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Guarded section: executed by the main process only
        on_main_process = trainer.is_global_zero
        if on_main_process:
            print("This only prints once in distributed training")

        # Unguarded section: executed by every process
        loss = outputs['loss']
        # ... do something with loss on each GPU

Resources