Pretrain, finetune ANY AI model of ANY size on 1 or 10,000+ GPUs with zero code changes.
Lightning-AI/pytorch-lightning
The deep learning framework to pretrain and finetune AI models.
Deploying models? Check out LitServe, the PyTorch Lightning for inference engines.
Over 340,000 developers use Lightning Cloud, purpose-built for PyTorch and PyTorch Lightning.
- GPUs from $0.19.
- Clusters: frontier-grade training/inference clusters.
- AI Studio (vibe train): workspaces where AI helps you debug, tune and vibe train.
- AI Studio (vibe deploy): workspaces where AI helps you optimize and deploy models.
- Notebooks: persistent GPU workspaces where AI helps you code and analyze.
- Inference: Deploy models as inference APIs.
Training models in plain PyTorch is tedious and error-prone: you have to handle things like backprop, mixed precision, multi-GPU and distributed training yourself, often rewriting code for every new project. PyTorch Lightning organizes PyTorch code to automate those complexities so you can focus on your model and data, while keeping full control and scaling from CPU to multi-node without changing your core code. If you do want to manage those details yourself, you can still opt into expert-level control.
Fun analogy: if PyTorch is JavaScript, PyTorch Lightning is React or Next.js.
PyTorch Lightning: Train and deploy PyTorch at scale.
Lightning Fabric: Expert control.
Lightning gives you granular control over how much abstraction you want to add over PyTorch.
Install Lightning:
pip install lightning
Advanced install options
pip install "lightning[extra]"
conda install lightning -c conda-forge
Install future release from the source
pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/release/stable.zip -U
Install nightly from the source (no guarantees)
pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U
or from testing PyPI
pip install -U -i https://test.pypi.org/simple/ pytorch-lightning
Define the training workflow. Here's a toy example (explore real examples):
# main.py
# ! pip install torchvision
import torch, torch.nn as nn, torch.utils.data as data, torchvision as tv, torch.nn.functional as F
import lightning as L

# --------------------------------
# Step 1: Define a LightningModule
# --------------------------------
# A LightningModule (nn.Module subclass) defines a full *system*
# (e.g. an LLM, diffusion model, autoencoder, or simple image classifier).


class LitAutoEncoder(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def forward(self, x):
        # in Lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# -------------------
# Step 2: Define data
# -------------------
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
train, val = data.random_split(dataset, [55000, 5000])

# -------------------
# Step 3: Train
# -------------------
autoencoder = LitAutoEncoder()
trainer = L.Trainer()
trainer.fit(autoencoder, data.DataLoader(train), data.DataLoader(val))
Run the model from your terminal:
pip install torchvision
python main.py
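`data.random_split(dataset, [55000, 5000])` needs lengths that add up to `len(dataset)`; the MNIST training set has 60,000 images, so the split yields 55,000 training and 5,000 validation samples. The stdlib sketch below shows the idea (shuffle the indices once, then cut them into consecutive chunks). `random_split_indices` is a hypothetical helper that uses Python's `random` module, not torch's generator, so it will not reproduce torch's exact split.

```python
import random


def random_split_indices(n_items, lengths, seed=0):
    """Hypothetical helper: shuffle range(n_items) and cut it into chunks of `lengths`."""
    if sum(lengths) != n_items:
        raise ValueError("lengths must sum to the number of items")
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)  # one shuffle, then consecutive slices
    splits, start = [], 0
    for length in lengths:
        splits.append(indices[start:start + length])
        start += length
    return splits


train_idx, val_idx = random_split_indices(60_000, [55_000, 5_000])
print(len(train_idx), len(val_idx))  # 55000 5000
```

Because the two chunks come from a single permutation, no sample lands in both sets. Note also that `data.DataLoader(train)` uses the default `batch_size=1`, so one epoch of the toy example is 55,000 optimizer steps.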
PyTorch Lightning is just organized PyTorch - Lightning disentangles PyTorch code to decouple the science from the engineering.
Explore various types of training possible with PyTorch Lightning. Pretrain and finetune ANY kind of model to perform ANY task like classification, segmentation, summarization and more:
Task | Description |
---|---|
Hello world | Pretrain - Hello world example |
Image classification | Finetune - ResNet-34 model to classify images of cars |
Image segmentation | Finetune - ResNet-50 model to segment images |
Object detection | Finetune - Faster R-CNN model to detect objects |
Text classification | Finetune - text classifier (BERT model) |
Text summarization | Finetune - text summarization (Hugging Face transformer model) |
Audio generation | Finetune - audio generator (transformer model) |
LLM finetuning | Finetune - LLM (Meta Llama 3.1 8B) |
Image generation | Pretrain - Image generator (diffusion model) |
Recommendation system | Train - recommendation system (factorization and embedding) |
Time-series forecasting | Train - Time-series forecasting with LSTM |
Lightning has over 40 advanced features designed for professional AI research at scale.
Here are some examples:
Train on 1000s of GPUs without code changes
from lightning.pytorch import Trainer

# 8 GPUs
# no code changes needed
trainer = Trainer(accelerator="gpu", devices=8)

# 256 GPUs
trainer = Trainer(accelerator="gpu", devices=8, num_nodes=32)
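With `devices=8, num_nodes=32` the job runs 8 × 32 = 256 processes. The sketch below shows the usual numbering convention, assuming one process per GPU (as with the default DDP launch): each process has a `local_rank` on its node and a `global_rank` computed node by node. `rank_layout` is a hypothetical helper for illustration, not a Lightning API.

```python
def rank_layout(num_nodes, devices_per_node):
    """Map every (node_rank, local_rank) pair to its global rank.

    Assumes one process per device, numbered node by node.
    """
    layout = {}
    for node_rank in range(num_nodes):
        for local_rank in range(devices_per_node):
            layout[(node_rank, local_rank)] = node_rank * devices_per_node + local_rank
    return layout


layout = rank_layout(num_nodes=32, devices_per_node=8)
world_size = len(layout)
print(world_size)  # 256
print(layout[(31, 7)])  # 255, the last process on the last node
```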
Train on other accelerators like TPUs without code changes
from lightning.pytorch import Trainer

# no code changes needed
trainer = Trainer(accelerator="tpu", devices=8)
16-bit precision
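from lightning.pytorch import Trainer

# no code changes needed; "16-mixed" is the current name for 16-bit mixed precision
trainer = Trainer(precision="16-mixed")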
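16-bit mixed precision runs many operations in float16, which has a narrow range (the largest finite value is 65504) and coarse resolution; mixed precision therefore keeps float32 weights and, for float16, scales the loss so tiny gradients do not flush to zero. The stdlib sketch below only imitates the float16 number format via `struct`'s `"e"` format; it does not reproduce `torch.autocast` or a gradient scaler, and the scale factor `2**16` is an arbitrary illustrative choice.

```python
import struct


def to_float16(value):
    """Round a Python float through IEEE 754 half precision (struct format 'e')."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


print(to_float16(65504.0))       # 65504.0, the largest finite float16
print(to_float16(1.0 + 2**-12))  # 1.0, the increment is below float16 resolution near 1.0
print(to_float16(1e-8))          # 0.0, a tiny gradient underflows

# Loss scaling: multiply before the float16 cast, divide afterwards in higher precision.
scale = 2.0**16
recovered = to_float16(1e-8 * scale) / scale
print(recovered)  # close to 1e-8 instead of 0.0
```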
Experiment managers
from lightning.pytorch import Trainer, loggers

# each logger needs its own package installed (e.g. tensorboard, wandb)

# tensorboard
trainer = Trainer(logger=loggers.TensorBoardLogger("logs/"))

# weights and biases
trainer = Trainer(logger=loggers.WandbLogger())

# comet
trainer = Trainer(logger=loggers.CometLogger())

# mlflow
trainer = Trainer(logger=loggers.MLFlowLogger())

# neptune
trainer = Trainer(logger=loggers.NeptuneLogger())

# ... and dozens more
Early Stopping
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping

es = EarlyStopping(monitor="val_loss")
trainer = Trainer(callbacks=[es])
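`EarlyStopping` watches a logged metric and stops training once it has failed to improve for `patience` consecutive checks (by default `patience=3` in `"min"` mode). The monitored key must actually be logged, e.g. with `self.log("val_loss", loss)` inside a `validation_step`; the toy `LitAutoEncoder` only logs `train_loss`. The pure-Python sketch below imitates just the patience counter; `PatienceStopper` and `stopping_epoch` are hypothetical and ignore Lightning details such as `check_finite`, `stopping_threshold` and syncing across processes.

```python
import math


class PatienceStopper:
    """Hypothetical stand-in for patience-based early stopping in "min" mode."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = math.inf
        self.wait_count = 0

    def should_stop(self, value):
        # An improvement must beat the best value by more than min_delta.
        if value < self.best - self.min_delta:
            self.best = value
            self.wait_count = 0
        else:
            self.wait_count += 1
        return self.wait_count >= self.patience


def stopping_epoch(val_losses, patience=3):
    """Return the epoch index at which training would stop, or None."""
    stopper = PatienceStopper(patience=patience)
    for epoch, loss in enumerate(val_losses):
        if stopper.should_stop(loss):
            return epoch
    return None  # never triggered


print(stopping_epoch([1.0, 0.8, 0.85, 0.9, 0.95, 0.7]))  # 4
```

The best value is 0.8 at epoch 1; epochs 2, 3 and 4 do not improve on it, so the third miss triggers the stop and the later 0.7 is never seen.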
Checkpointing
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

checkpointing = ModelCheckpoint(monitor="val_loss")
trainer = Trainer(callbacks=[checkpointing])
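`ModelCheckpoint(monitor="val_loss")` keeps the best checkpoint according to the monitored metric (`save_top_k` defaults to 1 and `mode` to `"min"`); as with early stopping, `val_loss` must be logged by the module. The sketch below shows the bookkeeping for keeping the `k` best scores with a heap. `TopKTracker` is hypothetical: it writes no files and is not how `ModelCheckpoint` is implemented internally.

```python
import heapq


class TopKTracker:
    """Hypothetical sketch: keep the k lowest monitored scores and their epochs."""

    def __init__(self, k=1):
        self.k = k
        self._heap = []  # entries are (-score, epoch): the worst kept score sits at the root

    def update(self, score, epoch):
        """Return True if this epoch's checkpoint would be kept."""
        if len(self._heap) < self.k:
            heapq.heappush(self._heap, (-score, epoch))
            return True
        worst_kept = -self._heap[0][0]
        if score < worst_kept:
            heapq.heapreplace(self._heap, (-score, epoch))  # evict the worst kept checkpoint
            return True
        return False

    def kept(self):
        """Kept (score, epoch) pairs, best first."""
        return sorted((-neg_score, epoch) for neg_score, epoch in self._heap)


tracker = TopKTracker(k=2)
for epoch, val_loss in enumerate([0.5, 0.4, 0.6, 0.3]):
    tracker.update(val_loss, epoch)
print(tracker.kept())  # [(0.3, 3), (0.4, 1)]
```

A max-heap (via negated scores) makes "is this better than the worst one I kept?" an O(1) check, with O(log k) work to swap a checkpoint out.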
Export to torchscript (JIT) (production use)
import torch

# torchscript (LitAutoEncoder is the module from the toy example)
autoencoder = LitAutoEncoder()
torch.jit.save(autoencoder.to_torchscript(), "model.pt")
Export to ONNX (production use)
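import os
import tempfile

import torch

# onnx (needs the onnx package installed)
with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmpfile:
    autoencoder = LitAutoEncoder()
    # the encoder expects flattened 28x28 MNIST images
    input_sample = torch.randn((1, 28 * 28))
    autoencoder.to_onnx(tmpfile.name, input_sample, export_params=True)
    assert os.path.isfile(tmpfile.name)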
Advantages over unstructured PyTorch:
- Models become hardware-agnostic
- Code is clear to read because engineering code is abstracted away
- Easier to reproduce
- Make fewer mistakes because Lightning handles the tricky engineering
- Keeps all the flexibility (LightningModules are still PyTorch modules), but removes a ton of boilerplate
- Lightning has dozens of integrations with popular machine learning tools.
- Tested rigorously with every new PR. We test every supported combination of PyTorch and Python versions, every OS, multi-GPU setups and even TPUs.
- Minimal running speed overhead (about 300 ms per epoch compared with pure PyTorch).
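The ~300 ms figure above is a fixed per-epoch cost, so its relative weight depends on how long an epoch takes. A quick calculation, using hypothetical epoch durations:

```python
def overhead_fraction(epoch_seconds, overhead_seconds=0.3):
    """Fraction of an epoch's wall time spent on a fixed per-epoch overhead."""
    return overhead_seconds / (epoch_seconds + overhead_seconds)


# hypothetical pure-PyTorch epoch durations, in seconds
for epoch_seconds in (3, 60, 600):
    print(f"{epoch_seconds:>4} s epoch -> {overhead_fraction(epoch_seconds):.2%} overhead")
```

For very short epochs the fixed cost is noticeable; for epochs of a minute or more it drops well under 1%.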
Run on any device at any scale with expert-level control over the PyTorch training loop and scaling strategy. You can even write your own Trainer.
Fabric is designed for the most complex models and workloads, such as foundation model scaling, LLMs, diffusion, transformers, reinforcement learning and active learning, at any size.
What to change:

+ import lightning as L
  import torch; import torchvision as tv

  dataset = tv.datasets.CIFAR10("data", download=True, train=True, transform=tv.transforms.ToTensor())

+ fabric = L.Fabric()
+ fabric.launch()

  model = tv.models.resnet18()
  optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model.to(device)
+ model, optimizer = fabric.setup(model, optimizer)

  dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
+ dataloader = fabric.setup_dataloaders(dataloader)

  model.train()
  num_epochs = 10
  for epoch in range(num_epochs):
      for batch in dataloader:
          inputs, labels = batch
-         inputs, labels = inputs.to(device), labels.to(device)
          optimizer.zero_grad()
          outputs = model(inputs)
          loss = torch.nn.functional.cross_entropy(outputs, labels)
-         loss.backward()
+         fabric.backward(loss)
          optimizer.step()
          print(loss.data)

Resulting Fabric code (copy me!):

import lightning as L
import torch; import torchvision as tv

dataset = tv.datasets.CIFAR10("data", download=True, train=True, transform=tv.transforms.ToTensor())

fabric = L.Fabric()
fabric.launch()

model = tv.models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
model, optimizer = fabric.setup(model, optimizer)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
dataloader = fabric.setup_dataloaders(dataloader)

model.train()
num_epochs = 10
for epoch in range(num_epochs):
    for batch in dataloader:
        inputs, labels = batch
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        fabric.backward(loss)
        optimizer.step()
        print(loss.data)
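When Fabric runs several processes (for example with `strategy="ddp"`), `fabric.setup_dataloaders` by default swaps in a distributed sampler so that each process reads a different slice of the data. The stdlib sketch below imitates torch's `DistributedSampler` with shuffling turned off: pad the index list by repeating leading indices until it divides evenly, then give each rank every `num_replicas`-th index starting at its own rank. `shard_indices` is a hypothetical helper, not a Fabric or torch API.

```python
import math


def shard_indices(dataset_len, num_replicas, rank):
    """Indices one process would read, mimicking DistributedSampler(shuffle=False, drop_last=False)."""
    indices = list(range(dataset_len))
    total_size = math.ceil(dataset_len / num_replicas) * num_replicas
    # Pad by repeating leading indices so every rank gets the same number of samples.
    padding = total_size - dataset_len
    if padding:
        indices += (indices * math.ceil(padding / len(indices)))[:padding]
    return indices[rank:total_size:num_replicas]


for rank in range(4):
    print(rank, shard_indices(10, num_replicas=4, rank=rank))
```

With 10 samples and 4 processes, the list is padded to 12 entries, so rank 0 reads [0, 4, 8] and rank 2 reads [2, 6, 0]: every rank sees 3 samples, and a couple of samples are seen twice per epoch.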
Easily switch from running on CPU to GPU (Apple Silicon, CUDA, …), TPU, multi-GPU or even multi-node training
from lightning.fabric import Fabric

# Use your available hardware
# no code changes needed
fabric = Fabric()

# Run on GPUs (CUDA or MPS)
fabric = Fabric(accelerator="gpu")

# 8 GPUs
fabric = Fabric(accelerator="gpu", devices=8)

# 256 GPUs, multi-node
fabric = Fabric(accelerator="gpu", devices=8, num_nodes=32)

# Run on TPUs
fabric = Fabric(accelerator="tpu")
Use state-of-the-art distributed training strategies (DDP, FSDP, DeepSpeed) and mixed precision out of the box
from lightning.fabric import Fabric

# Use state-of-the-art distributed training techniques
fabric = Fabric(strategy="ddp")
fabric = Fabric(strategy="deepspeed")  # needs the deepspeed package installed
fabric = Fabric(strategy="fsdp")

# Switch the precision
fabric = Fabric(precision="16-mixed")
fabric = Fabric(precision="64-true")
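With `strategy="ddp"`, every process holds a full copy of the model, computes gradients on its own batch, and the gradients are averaged across processes with an all-reduce before `optimizer.step()`. With equal per-process batch sizes and a mean-reduced loss, that update matches one large batch of size (per-device batch × number of processes). The pure-Python sketch below checks this identity for a one-parameter linear model with made-up data; it illustrates the math only and says nothing about how DDP schedules its communication.

```python
def mse_grad(w, xs, ys):
    """d/dw of mean((w*x - y)**2) for a one-parameter linear model."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


# made-up data and weight
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 7.8, 10.1, 12.0, 13.8, 16.2]
w = 0.5

# Full-batch gradient, as a single process would compute it.
full = mse_grad(w, xs, ys)

# DDP-style: 4 ranks each take an equal shard, compute a local gradient,
# then an all-reduce averages the local gradients.
world_size = 4
local = [mse_grad(w, xs[r::world_size], ys[r::world_size]) for r in range(world_size)]
averaged = sum(local) / world_size

print(full, averaged)  # equal up to floating-point rounding
```

The identity relies on equal shard sizes; with uneven shards the plain average would weight the smaller shards too heavily, which is one reason distributed samplers pad every rank to the same length.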
All the device logic boilerplate is handled for you
# no more of this!
- model.to(device)
- batch.to(device)
Build your own custom Trainer using Fabric primitives for training, checkpointing, logging, and more
import lightning as L
import torch.nn.functional as F


class MyCustomTrainer:
    def __init__(self, accelerator="auto", strategy="auto", devices="auto", precision="32-true"):
        self.fabric = L.Fabric(accelerator=accelerator, strategy=strategy, devices=devices, precision=precision)

    def fit(self, model, optimizer, dataloader, max_epochs, loss_fn=F.cross_entropy):
        # loss_fn: any callable(output, target) -> scalar loss; cross-entropy by default
        self.fabric.launch()

        model, optimizer = self.fabric.setup(model, optimizer)
        dataloader = self.fabric.setup_dataloaders(dataloader)
        model.train()

        for epoch in range(max_epochs):
            for batch in dataloader:
                inputs, target = batch
                optimizer.zero_grad()
                output = model(inputs)
                loss = loss_fn(output, target)
                self.fabric.backward(loss)
                optimizer.step()
You can find a more extensive example in our examples.
Lightning is rigorously tested across multiple CPUs, GPUs and TPUs and against major Python and PyTorch versions.
The Lightning community is maintained by:
- 10+ core contributors, a mix of professional engineers, research scientists and Ph.D. students from top AI labs.
- 800+ community contributors.
Want to help us build Lightning and reduce boilerplate for thousands of researchers? Learn how to make your first contribution here.
Lightning is also part of the PyTorch ecosystem, which requires projects to have solid testing, documentation and support.
If you have any questions please: