mberkay0/pretrained-backbones-unet
A PyTorch-based Python library with UNet architecture and multiple backbones for Image Semantic Segmentation.
This is a simple package for semantic segmentation with UNet and pretrained backbones. It uses timm models for the pre-trained encoders.
When working with relatively limited datasets, initializing the model with weights pre-trained on a large dataset is an excellent way to ensure the network trains successfully. Using a state-of-the-art model such as ConvNeXt as the encoder lets you solve the task at hand with minimal effort while achieving strong performance.
The primary characteristics of this library are as follows:
- 430 pre-trained backbone networks are available for the UNet semantic segmentation model.
- Supports popular, state-of-the-art backbone families such as ConvNeXt, ResNet, EfficientNet, DenseNet, RegNet, and VGG for the UNet model.
- The trainable layers of the model's backbone can be adjusted parametrically (see the sketch after this list).
- Includes a Dataset class for binary and multi-class semantic segmentation.
- Comes with a pre-built trainer class for rapid custom training.
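As a concrete illustration of the point about trainable backbone layers: layers can always be frozen with plain PyTorch by turning off their gradients. The sketch below is not the library's own mechanism; it assumes the pretrained encoder is reachable as an attribute on the Unet model (the attribute name `encoder` is an assumption — check the model definition, which also exposes its own parametric options for this).

```python
import torch.nn as nn

def freeze_backbone(model: nn.Module, backbone_attr: str = "encoder", trainable_last_n: int = 0) -> None:
    """Freeze the backbone's parameters, optionally leaving the last N child blocks trainable.

    `backbone_attr` is a hypothetical attribute name for the pretrained encoder on the
    Unet model; replace it with the attribute the library actually uses.
    """
    backbone = getattr(model, backbone_attr)
    children = list(backbone.children())
    cutoff = len(children) - trainable_last_n
    for i, child in enumerate(children):
        trainable = i >= cutoff  # only the last N blocks keep their gradients
        for p in child.parameters():
            p.requires_grad = trainable
```

With the backbone frozen this way, only the parameters that still have `requires_grad=True` (e.g., the decoder) end up in the optimizer in the training example further below.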
Install the package from PyPI:

```bash
pip install pretrained-backbones-unet
```

Or install the latest version directly from the GitHub repository:

```bash
pip install git+https://github.com/mberkay0/pretrained-backbones-unet
```
A typical training run looks like this:

```python
import torch
from torch.utils.data import DataLoader

from backbones_unet.model.unet import Unet
from backbones_unet.utils.dataset import SemanticSegmentationDataset
from backbones_unet.model.losses import DiceLoss
from backbones_unet.utils.trainer import Trainer

# create a torch.utils.data.Dataset/DataLoader
train_img_path = 'example_data/train/images'
train_mask_path = 'example_data/train/masks'

val_img_path = 'example_data/val/images'
val_mask_path = 'example_data/val/masks'

train_dataset = SemanticSegmentationDataset(train_img_path, train_mask_path)
val_dataset = SemanticSegmentationDataset(val_img_path, val_mask_path)

train_loader = DataLoader(train_dataset, batch_size=2)
val_loader = DataLoader(val_dataset, batch_size=2)

model = Unet(
    backbone='convnext_base',  # backbone network name
    in_channels=3,             # input channels (1 for grayscale images, 3 for RGB, etc.)
    num_classes=1,             # output channels (number of classes in your dataset)
)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(params, 1e-4)

trainer = Trainer(
    model,                     # UNet model with a pretrained backbone
    criterion=DiceLoss(),      # loss function for model convergence
    optimizer=optimizer,       # optimizer for weight updates
    epochs=10                  # number of epochs for model training
)

trainer.fit(train_loader, val_loader)
```
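After training, the model behaves like any other `torch.nn.Module`, so inference is a plain forward pass. The sketch below assumes the binary setup from the example above (`num_classes=1`, logits as output) and uses a random tensor as a stand-in for a preprocessed image batch; it is not a dedicated inference API of the library.

```python
import torch

model.eval()
with torch.no_grad():
    # stand-in for a preprocessed RGB image batch of shape (N, C, H, W)
    image = torch.randn(1, 3, 256, 256)
    logits = model(image)           # raw, unnormalized per-pixel scores
    probs = torch.sigmoid(logits)   # binary case: foreground probability per pixel
    mask = (probs > 0.5).float()    # hard 0/1 segmentation mask
```

For a multi-class setup, replace the sigmoid/threshold step with `logits.softmax(dim=1).argmax(dim=1)`.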
To see all available backbone names:

```python
import backbones_unet

print(backbones_unet.__available_models__)
```
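Since these are timm encoder names, the list can be filtered by family; a minimal sketch, assuming the collection is an iterable of plain strings:

```python
import backbones_unet

# pick out one backbone family, e.g. all ConvNeXt variants
convnext_names = [name for name in backbones_unet.__available_models__ if "convnext" in name]
print(len(convnext_names), convnext_names[:5])
```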