Build and train a segmentation model with a few lines of code
Neural network models have proven to be highly effective in solving segmentation problems, achieving state-of-the-art accuracy. They have led to significant improvements in various applications, including medical image analysis, autonomous driving, robotics, satellite imagery, video surveillance, and much more. However, building these models usually takes a long time. After reading this guide, you will be able to build one with just a few lines of code.
Table of contents
- Building blocks
- Build a model
- Train the model
Segmentation is the task of dividing an image into multiple segments or regions based on certain characteristics or properties. A segmentation model takes an image as input and returns a segmentation mask:
Segmentation neural network models consist of two parts:
- An encoder: takes an input image and extracts features. Examples of encoders are ResNet, EfficientNet, and ViT.
- A decoder: takes the extracted features and generates a segmentation mask. The decoder depends on the architecture. Examples of architectures are U-Net, FPN, and DeepLab.
Thus, when building a segmentation model for a specific application, you need to choose an architecture and an encoder. However, it is difficult to choose the best combination without testing several. This usually takes a long time because changing the model requires writing a lot of boilerplate code. The Segmentation Models library solves this problem. It allows you to create a model in a single line by specifying the architecture and the encoder. Then you only need to modify that line to change either of them.
To install the latest version of Segmentation Models from PyPI, use:
pip install segmentation-models-pytorch
The library provides a class for most segmentation architectures and each of them can be used with any of the available encoders. In the next section, you will see that to build a model you need to instantiate the class of the chosen architecture and pass the string of the chosen encoder as a parameter. The figure below shows the class name of each architecture provided by the library:
The figure below shows the names of the most common encoders provided by the library:
There are over 400 encoders, so it is not possible to show them all, but you can find a comprehensive list here.
Once the architecture and the encoder have been chosen from the figures above, building the model is very simple:
- encoder_name is the name of the chosen encoder (e.g. resnet50, efficientnet-b7, mit_b5).
- encoder_weights is the dataset the encoder was pre-trained on. If encoder_weights is equal to "imagenet", the encoder weights are initialized with the ImageNet pre-trained weights. Every encoder has at least one set of pre-trained weights, and a comprehensive list is available here.
- in_channels is the channel count of the input image (3 if RGB). If in_channels is not 3, ImageNet pre-trained weights can still be used: the first layer will be initialized by reusing the weights from the pre-trained first convolutional layer (the procedure is described here).
- out_classes is the number of classes in the dataset.
- activation is the activation function for the output layer. The possible choices are
Note: when using a loss function that expects logits as input, the activation function must be None. For example, when using the
This section shows all the code required to perform training. The library doesn't change the usual pipeline for training and validating a model, but it simplifies the process by providing implementations of many loss functions, such as Jaccard Loss, Dice Loss, Dice Cross-Entropy Loss, and Focal Loss, and metrics, such as Accuracy, Precision, Recall, F1Score, and IOUScore. For a complete list of them and their parameters, check the Losses and Metrics sections of the documentation.
The training example is a binary segmentation task on the Oxford-IIIT Pet Dataset (the code downloads it automatically). These are two samples from the dataset:
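Binary segmentation needs a mask with only two values, while the dataset's annotations are trimaps. The sketch below shows one way to collapse a trimap into a binary mask. Two things are assumptions of this sketch, not statements about the article's code:

- the trimap encoding: 1 = pet, 2 = background, 3 = border;
- the choice to count border pixels as foreground.

```python
import numpy as np

def trimap_to_binary(trimap: np.ndarray) -> np.ndarray:
    """Convert an Oxford-IIIT Pet trimap into a binary foreground mask.

    Assumed encoding: 1 = pet, 2 = background, 3 = border.
    Pet and border pixels become 1.0, background pixels become 0.0.
    """
    return np.isin(trimap, (1, 3)).astype(np.float32)

trimap = np.array([[2, 2, 3],
                   [2, 1, 1],
                   [3, 1, 2]], dtype=np.uint8)
print(trimap_to_binary(trimap))
```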
Finally, these are the steps to perform this type of segmentation task:
1. Build the model.
Set the activation function of the last layer depending on the loss function you are going to use.
2. Define the parameters.
Remember that when using pre-trained weights, the input should be normalized with the mean and standard deviation of the data used for pre-training.
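A minimal sketch of this step in plain PyTorch follows. It hard-codes the standard ImageNet channel statistics, assuming the encoder was pre-trained on ImageNet. The library also offers `smp.encoders.get_preprocessing_params(encoder_name)` to look these values up per encoder. The hyperparameter values are hypothetical placeholders, not the article's settings.

```python
import torch

# Standard ImageNet statistics (assumed to match the encoder's pre-training).
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

def normalize(images: torch.Tensor) -> torch.Tensor:
    """Scale uint8 images (N, 3, H, W) to [0, 1], then standardize per channel."""
    images = images.float() / 255.0
    return (images - IMAGENET_MEAN) / IMAGENET_STD

# Hypothetical training parameters for the sketch.
EPOCHS = 10
LEARNING_RATE = 1e-4
BATCH_SIZE = 16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
```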
3. Define the train function.
Nothing changes here from the train function you would have written to train a model without using the library.
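As a reminder of what such a function usually looks like, here is a minimal plain-PyTorch sketch. It assumes a loader that yields `(images, masks)` batches and a model that returns logits. The function name `train_one_epoch` is hypothetical. In the usage part, a single convolution and random tensors stand in for the real model and dataset.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def train_one_epoch(model: nn.Module, loader: DataLoader, loss_fn, optimizer,
                    device: str = "cpu") -> float:
    """Run one training epoch and return the mean loss over batches."""
    model.train()
    total_loss = 0.0
    for images, masks in loader:
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        logits = model(images)           # (N, 1, H, W) raw scores
        loss = loss_fn(logits, masks)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

# Usage with a stand-in model and random data (for illustration only).
images = torch.rand(8, 3, 32, 32)
masks = torch.randint(0, 2, (8, 1, 32, 32)).float()
loader = DataLoader(TensorDataset(images, masks), batch_size=4)
model = nn.Conv2d(3, 1, kernel_size=3, padding=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
mean_loss = train_one_epoch(model, loader, nn.BCEWithLogitsLoss(), optimizer)
print(mean_loss)
```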
4. Define the validation function.
True positives, false positives, false negatives, and true negatives are accumulated over all batches, and the metrics are calculated only once, after the last batch. Note that logits must be converted to classes before the metrics can be calculated. Call the train function to start training.
5. Use the model.
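Once trained, the model is used in evaluation mode without gradients. The sketch below turns logits into a binary mask. It assumes a model built with `activation=None`, inputs that are already normalized, and a 0.5 threshold. The helper name `predict_mask` is hypothetical.

```python
import torch
from torch import nn

@torch.no_grad()
def predict_mask(model: nn.Module, images: torch.Tensor,
                 threshold: float = 0.5) -> torch.Tensor:
    """Return a binary mask (N, 1, H, W) for a batch of normalized images."""
    model.eval()
    probs = model(images).sigmoid()  # the model is assumed to output logits
    return (probs > threshold).long()
```

For example, with `nn.Identity()` standing in for a model that outputs the logits `[2.0, -2.0]`, the predicted mask is `[1, 0]`.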
These are some segmentations:
This library has everything you need to experiment with segmentation. It’s very easy to build a model and apply changes, and most loss functions and metrics are provided. In addition, using this library doesn’t change the pipeline we’re used to. See the official documentation for more information. I have also included some of the most common encoders and architectures in the references.
The Oxford-IIIT Pet Dataset is available to download for commercial/research purposes under a Creative Commons Attribution-ShareAlike 4.0 International License. The copyright remains with the original owners of the images.
All images, unless otherwise noted, are by the Author. Thanks for reading, I hope you have found this useful.
O. Ronneberger, P. Fischer, T. Brox, U-Net: Convolutional Networks for Biomedical Image Segmentation (2015)
Z. Zhou, Md. M. R. Siddiquee, N. Tajbakhsh, J. Liang, UNet++: A Nested U-Net Architecture for Medical Image Segmentation (2018)
L. Chen, G. Papandreou, F. Schroff, H. Adam, Rethinking Atrous Convolution for Semantic Image Segmentation (2017)
L. Chen, Y. Zhu, G. Papandreou, F. Schroff, H. Adam, Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation (2018)
R. Li, S. Zheng, C. Duan, C. Zhang, J. Su, P. M. Atkinson, Multi-Attention-Network for Semantic Segmentation of Fine Resolution Remote Sensing Images (2020)
A. Chaurasia, E. Culurciello, LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation (2017)
T. Lin, P. Dollár, R. Girshick, K. He, B. Hariharan, S. Belongie, Feature Pyramid Networks for Object Detection (2017)
H. Zhao, J. Shi, X. Qi, X. Wang, J. Jia, Pyramid Scene Parsing Network (2016)
H. Li, P. Xiong, J. An, L. Wang, Pyramid Attention Network for Semantic Segmentation (2018)
K. Simonyan, A. Zisserman, Very Deep Convolutional Networks for Large-Scale Image Recognition (2014)
K. He, X. Zhang, S. Ren, J. Sun, Deep Residual Learning for Image Recognition (2015)
S. Xie, R. Girshick, P. Dollár, Z. Tu, K. He, Aggregated Residual Transformations for Deep Neural Networks (2016)
J. Hu, L. Shen, S. Albanie, G. Sun, E. Wu, Squeeze-and-Excitation Networks (2017)
G. Huang, Z. Liu, L. van der Maaten, K. Q. Weinberger, Densely Connected Convolutional Networks (2016)
M. Tan, Q. V. Le, EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks (2019)
E. Xie, W. Wang, Z. Yu, A. Anandkumar, J. M. Alvarez, P. Luo, SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers (2021)