Files
hermes-agent/skills/mlops/torchtitan/references/custom-models.md
teknium f172f7d4aa Add skills tools and enhance model integration
- Introduced new skills tools: `skills_categories`, `skills_list`, and `skill_view` in `model_tools.py`, allowing for better organization and access to skill-related functionalities.
- Updated `toolsets.py` to include a new `skills` toolset, providing a dedicated space for skill tools.
- Enhanced `batch_runner.py` to recognize and validate skills tools during batch processing.
- Added comprehensive tool definitions for skills tools, ensuring compatibility with OpenAI's expected format.
- Created new shell script `test_skills_kimi.sh` for testing skills tool functionality with Kimi K2.5.
- Added example skill files demonstrating the structure and usage of skills within the Hermes-Agent framework, including `SKILL.md` for example and audiocraft skills.
- Improved documentation for skills tools and their integration into the existing tool framework, ensuring clarity for future development and usage.
2026-01-30 07:39:55 +00:00

7.2 KiB

Adding Custom Models to TorchTitan

This guide explains how to add a new model to TorchTitan following the established patterns.

Directory Structure

torchtitan/models/your_model/
├── model/
│   ├── __init__.py
│   ├── args.py          # Model arguments
│   ├── model.py         # Model definition
│   └── state_dict_adapter.py  # HF conversion (optional)
├── infra/
│   ├── __init__.py
│   ├── parallelize.py   # TP, FSDP, compile application
│   └── pipeline.py      # PP application (optional)
├── train_configs/
│   ├── debug_model.toml
│   └── your_model_XB.toml
├── __init__.py          # TrainSpec registration
└── README.md

Step 1: Define Model Arguments

Inherit from BaseModelArgs:

# model/args.py
from dataclasses import dataclass

from torchtitan.protocols.model import BaseModelArgs

@dataclass
class YourModelArgs(BaseModelArgs):
    """Hyperparameters for YourModel; defaults sketch an ~8B Llama-style config."""

    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    vocab_size: int = 128256

    def get_nparams_and_flops(self, seq_len: int) -> tuple[int, int]:
        """Return (num_params, flops_per_token) for throughput/MFU reporting.

        These numbers only feed metrics/logging, so a rough estimate is
        acceptable -- refine the formula for your actual architecture.
        """
        # Per layer: ~4*dim^2 for attention Q/K/V/O projections plus ~8*dim^2
        # for a 4x-dim feed-forward. Adjust to match your real block.
        per_layer = 4 * self.dim * self.dim + 8 * self.dim * self.dim
        nparams = self.vocab_size * self.dim + self.n_layers * per_layer
        # ~6 FLOPs per parameter per token (2 forward + 4 backward).
        flops = 6 * nparams
        return nparams, flops

    def update_from_config(self, job_config) -> "YourModelArgs":
        """Update args from training config; returns self for chaining."""
        # Override specific args from job_config if needed
        return self

Step 2: Define Model

Inherit from ModelProtocol:

# model/model.py
import torch
import torch.nn as nn

from torchtitan.protocols.model import ModelProtocol

from .args import YourModelArgs

class YourModel(ModelProtocol):
    """Single-device model definition; parallelism is applied externally.

    TransformerBlock and RMSNorm are your own modules (not shown here).
    """

    def __init__(self, args: YourModelArgs):
        super().__init__()
        self.args = args
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        # ModuleDict keyed by stringified index (not ModuleList): preserves
        # layer FQNs when stages are deleted for pipeline parallelism.
        self.layers = nn.ModuleDict({
            str(i): TransformerBlock(args) for i in range(args.n_layers)
        })
        self.norm = RMSNorm(args.dim)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Map integer token ids to logits (last dim = vocab_size)."""
        h = self.tok_embeddings(tokens)
        for layer in self.layers.values():
            h = layer(h)
        h = self.norm(h)
        return self.output(h)

    def init_weights(self):
        """Initialize weights recursively.

        Submodules defining their own init_weights() are delegated to;
        remaining bare nn.Linear modules get a normal(std=0.02) init.
        """
        for module in self.modules():
            if hasattr(module, 'init_weights') and module is not self:
                module.init_weights()
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)

Important guidelines:

  • Write single-device model code (parallelism applied externally)
  • Use nn.ModuleDict for layers (preserves FQNs when deleting for PP)
  • Make input/output layers optional for PP compatibility
  • Define init_weights() recursively

Step 3: Parallelize Function

# infra/parallelize.py
import torch
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.tensor.parallel import parallelize_module

def parallelize_your_model(
    model: YourModel,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """Apply parallelism and optimizations to a single-device model.

    Ordering matters: TP first, then activation checkpointing, then
    torch.compile (so it sees the final module structure), then FSDP.
    Returns the (possibly re-wrapped) model.
    """
    # Apply in this order: TP -> AC -> compile -> FSDP

    # 1. Tensor Parallelism
    if parallel_dims.tp_enabled:
        apply_tp(model, world_mesh["tp"], job_config)

    # 2. Activation Checkpointing
    if job_config.activation_checkpoint.mode == "full":
        apply_ac(model, job_config)

    # 3. torch.compile
    if job_config.compile.enable:
        model = torch.compile(model)

    # 4. FSDP
    if parallel_dims.dp_enabled:
        apply_fsdp(model, world_mesh["dp"], job_config)

    return model

Step 4: Create TrainSpec

# __init__.py
from torchtitan.protocols.train_spec import TrainSpec, register_train_spec
from .model.model import YourModel
from .model.args import YourModelArgs
from .infra.parallelize import parallelize_your_model

# NOTE(review): build_optimizer, build_lr_scheduler, build_dataloader,
# build_tokenizer, and build_loss below are torchtitan's shared component
# builders; import them from the appropriate torchtitan modules (imports
# omitted in this snippet).

# Named size variants ("flavors"), selected via [model].flavor in the TOML.
MODEL_CONFIGS = {
    "8B": YourModelArgs(dim=4096, n_layers=32, n_heads=32),
    "70B": YourModelArgs(dim=8192, n_layers=80, n_heads=64),
}

def get_train_spec(flavor: str) -> TrainSpec:
    # Bundles everything train.py needs; raises KeyError for unknown flavors.
    return TrainSpec(
        model_cls=YourModel,
        model_args=MODEL_CONFIGS[flavor],
        parallelize_fn=parallelize_your_model,
        pipeline_fn=None,  # Or your_pipeline_fn for PP
        build_optimizer_fn=build_optimizer,  # Reuse existing
        build_lr_scheduler_fn=build_lr_scheduler,  # Reuse existing
        build_dataloader_fn=build_dataloader,  # Reuse existing
        build_tokenizer_fn=build_tokenizer,  # Reuse existing
        build_loss_fn=build_loss,  # Reuse existing
        state_dict_adapter=None,  # Or YourStateDictAdapter
    )

# Register so train.py can find it
register_train_spec("your_model", get_train_spec)

Step 5: State Dict Adapter (Optional)

For HuggingFace checkpoint conversion:

# model/state_dict_adapter.py
from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter

class YourStateDictAdapter(BaseStateDictAdapter):
    """Translates checkpoint keys between torchtitan and HuggingFace layouts."""

    def to_hf(self, state_dict: dict) -> dict:
        """Convert torchtitan state dict to HF format."""
        hf_state_dict = {}
        for key, value in state_dict.items():
            # _convert_key_to_hf is your own key-mapping helper (not shown).
            hf_key = self._convert_key_to_hf(key)
            hf_state_dict[hf_key] = value
        return hf_state_dict

    def from_hf(self, state_dict: dict) -> dict:
        """Convert HF state dict to torchtitan format."""
        tt_state_dict = {}
        for key, value in state_dict.items():
            # Inverse of the mapping used in to_hf (helper not shown).
            tt_key = self._convert_key_from_hf(key)
            tt_state_dict[tt_key] = value
        return tt_state_dict

Step 6: Training Config

# train_configs/your_model_8b.toml
[job]
dump_folder = "./outputs"
description = "Your Model 8B training"

[model]
name = "your_model"   # must match the name passed to register_train_spec
flavor = "8B"         # selects an entry from MODEL_CONFIGS

[optimizer]
name = "AdamW"
lr = 3e-4

[training]
local_batch_size = 2  # per-rank batch size
seq_len = 8192
steps = 1000
dataset = "c4"

[parallelism]
data_parallel_shard_degree = -1  # presumably -1 = all remaining ranks; verify against torchtitan docs
tensor_parallel_degree = 1       # 1 disables tensor parallelism

Step 7: Register Model

Add to torchtitan/models/__init__.py:

from .your_model import get_train_spec as get_your_model_train_spec

# NOTE(review): Step 4's __init__.py already calls register_train_spec();
# confirm which registration mechanism (registry dict vs. register call)
# your torchtitan version expects -- you likely only need one.
MODEL_REGISTRY["your_model"] = get_your_model_train_spec

Testing

Numerics Test

Compare output with HuggingFace implementation:

def test_numerics():
    """Parity check: same weights in both implementations must give matching logits."""
    # Load same checkpoint into both implementations
    # NOTE(review): assumes load_checkpoint returns the model for chaining -- confirm.
    tt_model = YourModel(args).load_checkpoint(...)
    hf_model = HFYourModel.from_pretrained(...)

    # Compare outputs
    input_ids = torch.randint(0, vocab_size, (1, 128))
    tt_output = tt_model(input_ids)
    hf_output = hf_model(input_ids).logits

    # Tolerances allow for floating-point accumulation-order differences.
    torch.testing.assert_close(tt_output, hf_output, atol=1e-4, rtol=1e-4)

Loss Convergence

Compare loss curves with verified baseline (see docs/converging.md).

Performance Benchmark

Add benchmark config to benchmarks/ folder.

Guiding Principles

  1. Readability over flexibility: Don't over-abstract
  2. Minimal model changes: Parallelism applied externally
  3. Clean, minimal codebase: Reuse existing components where possible
  4. Single-device semantics: Model code should work on single GPU