torchtitan: A PyTorch native platform for training generative AI models
torchtitan is under extensive development. To use the latest features of torchtitan, we recommend using the most recent PyTorch nightly.
- [2025/11] AMD released an optimized fork of torchtitan for AMD GPUs.
- [2025/10] We released torchtitan v0.2.0.
- [2025/10] SkyPilot now supports torchtitan! See the tutorial here.
- [2025/07] We published instructions on how to add a model to torchtitan.
- [2025/04] Our paper was accepted by ICLR 2025.
- [2024/12] GPU MODE lecture on torchtitan.
- [2024/07] Presentation at PyTorch Conference 2024.
torchtitan is a PyTorch native platform designed for rapid experimentation and large-scale training of generative AI models. As a minimal clean-room implementation of PyTorch native scaling techniques, torchtitan provides a flexible foundation for developers to build upon. With torchtitan's extension points, one can easily create custom extensions tailored to specific needs.
Our mission is to accelerate innovation in the field of generative AI by empowering researchers and developers to explore new modeling architectures and infrastructure techniques.
The guiding principles when building torchtitan:
- Designed to be easy to understand, use and extend for different training purposes.
- Minimal changes to the model code when applying multi-dimensional parallelism.
- Bias towards a clean, minimal codebase while providing basic reusable / swappable components.
torchtitan has been showcasing PyTorch's latest distributed training features, via support for pretraining Llama 3.1 LLMs of various sizes.
We look forward to your contributions!
- To accelerate contributions to and innovations around torchtitan, we host an experiments folder. New ideas should start there. To contribute, follow the experiments guidelines.
- For fixes and contributions to core, follow these guidelines.
- Multi-dimensional composable parallelisms (see the composition sketch after this list)
- FSDP2 with per-parameter sharding
- Tensor Parallel (including async TP)
- Pipeline Parallel
- Context Parallel
- Meta device initialization
- Selective (layer or operator) and full activation checkpointing
- Distributed checkpointing (including async checkpointing)
- Interoperable checkpoints which can be loaded directly into torchtune for fine-tuning
- torch.compile support
- Float8 support (how-to)
- MXFP8 training for dense and MoE models on Blackwell GPUs.
- DDP and HSDP
- TorchFT integration
- Checkpointable data-loading, with the C4 dataset pre-configured (144M entries) and support for custom datasets
- Gradient accumulation, enabled by giving an additional --training.global_batch_size argument in configuration (see the gradient accumulation sketch after this list)
- Flexible learning rate scheduler (warmup-stable-decay)
- Loss, GPU memory, throughput (tokens/sec), TFLOPs, and MFU displayed and logged via TensorBoard or Weights & Biases
- Debugging tools including CPU/GPU profiling, memory profiling, Flight Recorder, etc.
- All options easily configured via TOML files
- Helper scripts to
- download tokenizers from Hugging Face
- convert original Llama 3 checkpoints into the expected DCP format
- estimate FSDP/HSDP memory usage without materializing the model
- run distributed inference with Tensor Parallel
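As a rough illustration of how the parallelism features above compose, here is a minimal sketch built only from public PyTorch APIs (init_device_mesh, parallelize_module, fully_shard). It is not torchtitan's actual code: the toy module, the 4 x 2 mesh shape, and the sharding plan are illustrative assumptions, and the script assumes a torchrun launch on 8 GPUs.

```python
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyBlock(nn.Module):
    """A stand-in for one transformer block (illustrative, not Llama)."""

    def __init__(self, dim: int = 1024):
        super().__init__()
        self.w1 = nn.Linear(dim, 4 * dim, bias=False)
        self.w2 = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


# Assumes torchrun on 8 GPUs: a 2D mesh of 4 data-parallel x 2 tensor-parallel ranks.
mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))

# Meta-device initialization: build the module without allocating real storage.
with torch.device("meta"):
    block = ToyBlock()

# Tensor Parallel: shard w1 column-wise and w2 row-wise over the "tp" sub-mesh.
parallelize_module(
    block,
    mesh["tp"],
    {"w1": ColwiseParallel(), "w2": RowwiseParallel()},
)

# FSDP2: per-parameter sharding over the "dp" sub-mesh, applied after TP.
fully_shard(block, mesh=mesh["dp"])

# Materialize the sharded parameters on GPU; real code would then initialize
# weights (e.g. via the model's init_weights / reset_parameters).
block.to_empty(device="cuda")
```

In torchtitan, the equivalent logic lives in each model's infra/parallelize.py (see the guided tour below).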
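For the gradient accumulation item above, the arithmetic below reflects one reading of the --training.global_batch_size semantics (an assumption, with illustrative numbers): the global batch size is split across data-parallel ranks, and whatever does not fit in a single forward/backward pass becomes extra micro-steps per optimizer step.

```python
# Illustrative numbers only; not code from torchtitan.
global_batch_size = 512   # --training.global_batch_size
local_batch_size = 8      # per-GPU micro-batch size
dp_degree = 16            # number of data-parallel ranks

# The global batch size must be divisible by the per-step batch across ranks.
assert global_batch_size % (local_batch_size * dp_degree) == 0
grad_accum_steps = global_batch_size // (local_batch_size * dp_degree)
print(grad_accum_steps)   # 4 micro-batches accumulated per optimizer step
```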
We report performance on up to 512 GPUs, and verify loss-convergence correctness of various techniques.
You may want to see how the model is defined or how parallelism techniques are applied. For a guided tour, see these files first:
- torchtitan/train.py - the main training loop and high-level setup code
- torchtitan/models/llama3/model/model.py - the Llama 3.1 model definition
- torchtitan/models/llama3/infra/parallelize.py - helpers for applying Data Parallel, Tensor Parallel, activation checkpointing, and torch.compile to the model
- torchtitan/models/llama3/infra/pipeline.py - helpers for applying Pipeline Parallel to the model
- torchtitan/components/checkpoint.py - utils for saving/loading distributed checkpoints (see the DCP sketch after this list)
- torchtitan/components/quantization/float8.py - utils for applying Float8 techniques (see the Float8 sketch after this list)
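As a companion to checkpoint.py above, here is a minimal sketch of the underlying torch.distributed.checkpoint (DCP) calls. The helper names and the "checkpoint/step-1000" folder are assumptions for illustration, not torchtitan's API; it assumes a torchrun-launched job with an initialized process group.

```python
import torch.nn as nn
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
)


def save_checkpoint(model: nn.Module, folder: str = "checkpoint/step-1000") -> None:
    # Each rank writes only the shards it owns; dcp.async_save can make this
    # non-blocking, which is what "async checkpointing" refers to above.
    state_dict = {"model": get_model_state_dict(model)}
    dcp.save(state_dict, checkpoint_id=folder)


def load_checkpoint(model: nn.Module, folder: str = "checkpoint/step-1000") -> None:
    # DCP loads into the provided state_dict in place and can reshard the
    # weights even if the parallelism layout changed between save and load.
    state_dict = {"model": get_model_state_dict(model)}
    dcp.load(state_dict, checkpoint_id=folder)
    set_model_state_dict(model, state_dict["model"])
```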
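And for float8.py above, a hedged sketch of the torchao conversion it builds on. The config value and the module filter (skipping a module named "output") are assumptions, not torchtitan's exact policy; see the Float8 how-to linked in the feature list for the real configuration.

```python
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training


def enable_float8(model: nn.Module) -> nn.Module:
    # Optionally perform the FSDP all-gather in float8 to save bandwidth.
    config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
    # Swap eligible nn.Linear modules for float8 training, skipping e.g. a
    # final output projection where extra precision tends to matter.
    convert_to_float8_training(
        model,
        config=config,
        module_filter_fn=lambda mod, fqn: fqn != "output",
    )
    return model
```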
One can directly run the source code, or install torchtitan from a nightly build or a stable release.
To run directly from source, you need the nightly build of PyTorch, or the latest PyTorch built from source.
```bash
git clone https://github.com/pytorch/torchtitan
cd torchtitan
pip install -r requirements.txt
```

To install torchtitan from a nightly build, you also need the nightly build of PyTorch. You can replace cu126 with another CUDA version (e.g. cu128) or an AMD ROCm build (e.g. rocm6.3).
```bash
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall
pip install --pre torchtitan --index-url https://download.pytorch.org/whl/nightly/cu126
```
One can install the latest stable release of torchtitan via pip or conda.
```bash
pip install torchtitan
```

```bash
conda install conda-forge::torchtitan
```
Note that each stable release pins the nightly versions of torch and torchao. Please see release.md for more details.
torchtitan currently supports training Llama 3.1 (8B, 70B, 405B) out of the box. To get started training these models, we need to download the tokenizer. Follow the instructions on the official meta-llama repository to ensure you have access to the Llama model weights.
Once you have confirmed access, you can run the following command to download the Llama 3.1 tokenizer to your local machine.
```bash
# Get your HF token from https://huggingface.co/settings/tokens
# Llama 3.1 tokenizer
python scripts/download_hf_assets.py --repo_id meta-llama/Llama-3.1-8B --assets tokenizer --hf_token=...
```
To train the Llama 3 8B model locally on 8 GPUs:
```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
```

For training on ParallelCluster/Slurm type configurations, you can use the multinode_trainer.slurm file to submit your sbatch job.
To get started, adjust the number of nodes and GPUs:
```bash
#SBATCH --ntasks=2
#SBATCH --nodes=2
```

Then start a run where nnodes is your total node count, matching the sbatch node count above.
```bash
srun torchrun --nnodes 2
```

If your GPU count per node is not 8, adjust --nproc_per_node in the torchrun command and #SBATCH --gpus-per-task in the SBATCH command section.
We provide a detailed look into the parallelisms and optimizations available in torchtitan, along with summary advice on when to use various techniques.
If you use torchtitan, please cite the paper "TorchTitan: One-stop PyTorch native solution for production ready LLM pre-training":

```bibtex
@inproceedings{
  liang2025torchtitan,
  title={TorchTitan: One-stop PyTorch native solution for production ready {LLM} pretraining},
  author={Wanchao Liang and Tianyu Liu and Less Wright and Will Constable and Andrew Gu and Chien-Chin Huang and Iris Zhang and Wei Feng and Howard Huang and Junjie Wang and Sanket Purandare and Gokul Nadathur and Stratos Idreos},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=SFN6Wm7YBI}
}
```

Source code is made available under a BSD 3 license; however, you may have other legal obligations that govern your use of other content linked in this repository, such as the license or terms of service for third-party data and models.