Segmentation models with pretrained backbones. Keras and TensorFlow Keras.
qubvel/segmentation_models
A Python library with neural networks for image segmentation based on Keras and TensorFlow.
The main features of this library are:
- High-level API (just two lines of code to create a segmentation model)
- 4 model architectures for binary and multi-class image segmentation (including the legendary Unet)
- 25 available backbones for each architecture
- All backbones have pre-trained weights for faster and better convergence
- Helpful segmentation losses (Jaccard, Dice, Focal) and metrics (IoU, F-score)
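The metrics in the last item can be illustrated on a single flattened binary mask. The following is a minimal pure-Python sketch, not the library's implementation: the function names, the 0.5 threshold and the `eps` smoothing term are assumptions chosen for illustration.

```python
def _binarize(y_pred, threshold):
    # Turn predicted probabilities into hard 0/1 labels.
    return [1 if p > threshold else 0 for p in y_pred]


def iou_score(y_true, y_pred, threshold=0.5, eps=1e-7):
    """IoU (Jaccard index) = intersection / union for one flattened binary mask."""
    pred = _binarize(y_pred, threshold)
    intersection = sum(t * p for t, p in zip(y_true, pred))
    union = sum(y_true) + sum(pred) - intersection
    return (intersection + eps) / (union + eps)


def f_score(y_true, y_pred, beta=1.0, threshold=0.5, eps=1e-7):
    """F-beta score from true/false positives; beta=1 gives the Dice coefficient."""
    pred = _binarize(y_pred, threshold)
    tp = sum(t * p for t, p in zip(y_true, pred))
    fp = sum(pred) - tp
    fn = sum(y_true) - tp
    b2 = beta ** 2
    return ((1 + b2) * tp + eps) / ((1 + b2) * tp + b2 * fn + fp + eps)


# Toy example: 4 pixels, ground truth vs. predicted probabilities.
y_true = [1, 1, 0, 0]
y_pred = [0.9, 0.2, 0.8, 0.1]
iou = iou_score(y_true, y_pred)
f1 = f_score(y_true, y_pred)
```

Here one of the two foreground pixels is found and one background pixel is wrongly marked, so IoU is 1/3 and F1 (Dice) is 0.5. Jaccard and Dice losses are commonly built from soft, un-thresholded versions of these quantities (for example 1 - IoU).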
Important note
Some models of version 1.* are not compatible with previously trained models. If you have such models and want to load them, roll back with:
$ pip install -U segmentation-models==0.2.1
The library is built to work with both the Keras and TensorFlow Keras frameworks:
```python
import segmentation_models as sm
# prints: Segmentation Models: using `keras` framework.
```
By default it tries to import `keras`; if that is not installed, it falls back to the `tensorflow.keras` framework. There are several ways to choose the framework:
- Provide the environment variable `SM_FRAMEWORK=keras` / `SM_FRAMEWORK=tf.keras` before importing `segmentation_models`
- Change the framework with `sm.set_framework('keras')` / `sm.set_framework('tf.keras')`
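The selection order described above can be sketched in plain Python. `choose_framework` and `_module_available` are hypothetical helpers written only for illustration; they are not part of segmentation_models, whose actual import logic may differ in details.

```python
import importlib.util
import os

# Hypothetical sketch of the documented selection order; not library code.
VALID_FRAMEWORKS = ("keras", "tf.keras")


def _module_available(name):
    # find_spec returns None when the top-level module cannot be found.
    return importlib.util.find_spec(name) is not None


def choose_framework(environ=None, is_installed=None):
    """Pick 'keras' or 'tf.keras' following the order the README describes."""
    environ = os.environ if environ is None else environ
    is_installed = is_installed or _module_available
    requested = environ.get("SM_FRAMEWORK")
    if requested is not None:
        # An explicit environment variable wins over auto-detection.
        if requested not in VALID_FRAMEWORKS:
            raise ValueError(
                f"SM_FRAMEWORK must be one of {VALID_FRAMEWORKS}, got {requested!r}"
            )
        return requested
    # Default: standalone Keras if installed, otherwise tf.keras
    # (which in turn still needs TensorFlow to be installed).
    return "keras" if is_installed("keras") else "tf.keras"
```

Because the choice happens at import time, the environment variable has to be set before the import, e.g. `os.environ["SM_FRAMEWORK"] = "tf.keras"` followed by `import segmentation_models as sm`.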
You can also specify which `image_data_format` to use; segmentation-models works with both `channels_last` and `channels_first`. This can be useful for later model conversion to the NVIDIA TensorRT format or for optimizing the model for CPU/GPU computation.
```python
import keras  # or: from tensorflow import keras

keras.backend.set_image_data_format('channels_last')
# or keras.backend.set_image_data_format('channels_first')
```
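The two formats differ only in where the channel axis sits: `channels_last` stores an image as (height, width, channels), `channels_first` as (channels, height, width). The pure-Python sketch below shows that reordering on nested lists; `channels_last_to_first` and `shape` are hypothetical helpers for illustration, not part of the library.

```python
def channels_last_to_first(image):
    """Reorder one image from (H, W, C) to (C, H, W) nested lists."""
    height, width, channels = len(image), len(image[0]), len(image[0][0])
    return [
        [[image[h][w][c] for w in range(width)] for h in range(height)]
        for c in range(channels)
    ]


def shape(nested):
    # Walk down the first element of each level to read off the dimensions.
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


# A 2x2 RGB image in channels_last layout: image[row][col] = [r, g, b].
image_hwc = [
    [[0, 1, 2], [3, 4, 5]],
    [[6, 7, 8], [9, 10, 11]],
]
image_chw = channels_last_to_first(image_hwc)
```

With NumPy, the same reordering for a whole batch is `np.transpose(batch, (0, 3, 1, 2))`, turning (N, H, W, C) into (N, C, H, W).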
The created segmentation model is just an instance of a Keras `Model` and can be built as easily as:
```python
model = sm.Unet()
```
Depending on the task, you can change the network architecture by choosing a backbone with fewer or more parameters, and use pretrained weights to initialize it:
```python
model = sm.Unet('resnet34', encoder_weights='imagenet')
```
Change the number of output classes in the model (choose your case):
```python
# binary segmentation (these parameters are the defaults when you call Unet('resnet34'))
model = sm.Unet('resnet34', classes=1, activation='sigmoid')
```
```python
# multiclass segmentation with non-overlapping class masks (your classes + background)
model = sm.Unet('resnet34', classes=3, activation='softmax')
```
```python
# multiclass segmentation with independent overlapping/non-overlapping class masks
model = sm.Unet('resnet34', classes=3, activation='sigmoid')
```
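The choice between `softmax` and `sigmoid` decides how the per-pixel class scores relate to each other. This pure-Python sketch applies both activations to made-up scores for a single pixel; it only illustrates the math and is not the library's code.

```python
import math


def softmax(logits):
    """Mutually exclusive classes: one pixel's probabilities sum to 1."""
    m = max(logits)  # shift by the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(logits):
    """Independent classes: each channel is its own yes/no probability."""
    return [1.0 / (1.0 + math.exp(-z)) for z in logits]


# Made-up raw scores for a single pixel and three output channels.
pixel_logits = [2.0, 1.0, 0.1]
p_softmax = softmax(pixel_logits)  # pick one class per pixel (argmax)
p_sigmoid = sigmoid(pixel_logits)  # threshold each class separately
```

With `softmax`, the target mask has one channel per class including background and each pixel belongs to exactly one of them; with `sigmoid`, a pixel may be positive in several channels at once, or in none.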
Change the input shape of the model:
```python
# if you set input channels not equal to 3, you have to set encoder_weights=None
# how to handle such a case with encoder_weights='imagenet' is described in the docs
model = sm.Unet('resnet34', input_shape=(None, None, 6), encoder_weights=None)
```
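The ImageNet-pretrained encoder's first convolution expects exactly 3 input channels, which is why `encoder_weights` must be `None` for other inputs. A common workaround (an assumption here; check the docs for the recommended recipe) keeps `encoder_weights='imagenet'` and places a trainable 1x1 convolution, such as a Keras `Conv2D(3, (1, 1))` layer, in front of the model to project N channels down to 3. The pure-Python sketch below only shows what such a 1x1 convolution computes per pixel; `pointwise_conv` is a hypothetical helper and the fixed weights are made up, whereas in a real model they would be learned.

```python
def pointwise_conv(image, weights, bias):
    """1x1 convolution on an (H, W, C_in) image given as nested lists.

    weights[i][o] connects input channel i to output channel o and
    bias[o] is added to output channel o, so every pixel is projected
    independently from C_in to C_out = len(bias) channels.
    """
    c_out = len(bias)
    return [
        [
            [
                sum(pixel[i] * weights[i][o] for i in range(len(pixel))) + bias[o]
                for o in range(c_out)
            ]
            for pixel in row
        ]
        for row in image
    ]


# Made-up fixed weights mapping 6 channels to 3:
# output channel o = input channel o + input channel o + 3.
weights_6_to_3 = [[1.0 if i % 3 == o else 0.0 for o in range(3)] for i in range(6)]
bias = [0.0, 0.0, 0.0]
image_6ch = [[[1, 2, 3, 4, 5, 6]]]  # a single pixel with 6 channels
image_3ch = pointwise_conv(image_6ch, weights_6_to_3, bias)
```

Because the kernel is 1x1, spatial size is untouched; only the channel count changes, which is exactly what the pretrained encoder needs.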
```python
import segmentation_models as sm

BACKBONE = 'resnet34'
preprocess_input = sm.get_preprocessing(BACKBONE)

# load your data (`load_data` stands for your own loading code)
x_train, y_train, x_val, y_val = load_data(...)

# preprocess input
x_train = preprocess_input(x_train)
x_val = preprocess_input(x_val)

# define model
model = sm.Unet(BACKBONE, encoder_weights='imagenet')
model.compile(
    'Adam',
    loss=sm.losses.bce_jaccard_loss,
    metrics=[sm.metrics.iou_score],
)

# fit model
# if you use a data generator with standalone Keras, use model.fit_generator(...) instead of model.fit(...)
# (with TensorFlow >= 2.1, tf.keras model.fit accepts generators directly and fit_generator is deprecated)
# more about `fit_generator` here: https://keras.io/models/sequential/#fit_generator
model.fit(
    x=x_train,
    y=y_train,
    batch_size=16,
    epochs=100,
    validation_data=(x_val, y_val),
)
```
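The training example optimizes `sm.losses.bce_jaccard_loss`, which, as the name suggests, combines binary cross-entropy with a Jaccard (IoU) term. The sketch below shows the idea on flattened lists in pure Python; the function names, the equal weighting of the two terms, the smoothing constant and the clipping epsilon are assumptions, and the library's own tensor-based implementation may differ in these details.

```python
import math


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over flattened pixels."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(y_true)


def soft_jaccard_loss(y_true, y_pred, smooth=1e-5):
    """1 - soft IoU, computed on probabilities (no thresholding)."""
    intersection = sum(t * p for t, p in zip(y_true, y_pred))
    union = sum(y_true) + sum(y_pred) - intersection
    return 1.0 - (intersection + smooth) / (union + smooth)


def bce_jaccard_loss(y_true, y_pred):
    """Sum of both terms with equal weights (an assumption for this sketch)."""
    return binary_crossentropy(y_true, y_pred) + soft_jaccard_loss(y_true, y_pred)
```

The cross-entropy term gives smooth per-pixel gradients, while the Jaccard term scores overlap of the whole mask, which helps when the foreground covers only a small fraction of the image.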
The same manipulations can be done with Linknet, PSPNet and FPN. For more detailed information about the models API and use cases, see the documentation on Read the Docs.
- Models training examples:
Models
- Unet
- Linknet
- PSPNet
- FPN
Backbones
Type | Names |
---|---|
VGG | 'vgg16' 'vgg19' |
ResNet | 'resnet18' 'resnet34' 'resnet50' 'resnet101' 'resnet152' |
SE-ResNet | 'seresnet18' 'seresnet34' 'seresnet50' 'seresnet101' 'seresnet152' |
ResNeXt | 'resnext50' 'resnext101' |
SE-ResNeXt | 'seresnext50' 'seresnext101' |
SENet154 | 'senet154' |
DenseNet | 'densenet121' 'densenet169' 'densenet201' |
Inception | 'inceptionv3' 'inceptionresnetv2' |
MobileNet | 'mobilenet' 'mobilenetv2' |
EfficientNet | 'efficientnetb0' 'efficientnetb1' 'efficientnetb2' 'efficientnetb3' 'efficientnetb4' 'efficientnetb5' 'efficientnetb6' 'efficientnetb7' |
All backbones have weights trained on the 2012 ILSVRC ImageNet dataset (`encoder_weights='imagenet'`).
Requirements
- python 3
- keras >= 2.2.0 or tensorflow >= 1.13
- keras-applications >= 1.0.7, <=1.0.8
- image-classifiers == 1.0.*
- efficientnet == 1.0.*
PyPI stable package
$ pip install -U segmentation-models
PyPI latest package
$ pip install -U --pre segmentation-models
Source latest version
$ pip install git+https://github.com/qubvel/segmentation_models
The latest documentation is available on Read the Docs.
To see important changes between versions, look at CHANGELOG.md.
```bibtex
@misc{Yakubovskiy:2019,
  Author = {Pavel Iakubovskii},
  Title = {Segmentation Models},
  Year = {2019},
  Publisher = {GitHub},
  Journal = {GitHub repository},
  Howpublished = {\url{https://github.com/qubvel/segmentation_models}}
}
```
The project is distributed under the MIT License.