Have you ever wondered how to train a deep neural network to do many things? Such a model is referred to as a Multi-Task Architecture, and it can have benefits over the traditional approach of using a separate model for each task. A Multi-Task Architecture is a subset of Multi-Task Learning, a general approach to training a model (or set of models) to perform multiple tasks simultaneously.
In this post we will learn how to train a single model to perform both classification and regression tasks simultaneously. Code for this post can be found on GitHub.
Why would we want to use a lightweight model? Won’t that decrease performance? If we’re not deploying to the edge, shouldn’t we use as big a model as possible?
Edge applications need lightweight models to perform real-time inference with low power consumption. Other applications can benefit from them as well, but how? An overlooked benefit of lightweight models is their lower compute requirement. In general this can lower server usage and therefore decrease power consumption. This has the overall effect of reducing costs and lowering carbon emissions, the latter of which could become a major issue in the future of AI.
Lightweight models can help reduce costs and lower carbon emissions via less power consumption
With all this being said, a multi-task architecture is just a tool in the toolbox, and all project requirements should be considered before deciding which tools to use. Now let’s dive into an example of how to train one of these!
To build our Multi-Task Architecture, we will loosely cover the approach from this paper, where a single model was trained for simultaneous segmentation and depth estimation. The underlying goal was to perform these tasks in a fast and efficient manner with a trade-off being an acceptable loss of performance. In Multi-Task Learning, we typically group similar tasks together. During training, we can also add an auxiliary task that may assist our model’s learning, but we may decide not to use it during inference [1, 2]. For simplicity, we will not use any auxiliary tasks during training.
Depth estimation and segmentation are both dense prediction tasks, and they have similarities. For example, the depth of a single object will likely be consistent across all areas of the object, forming a very narrow distribution. The main idea is that each individual object should have its own depth value, and we should be able to recognize individual objects just by looking at a depth map. In the same manner, we should be able to recognize the same individual objects by looking at a segmentation map. While there are likely to be some outliers, we will assume that this relationship holds.
Dataset
We will use the CityScapes dataset [3] to provide the (left camera) input images, segmentation masks, and depth maps. For the segmentation masks, we choose to use the standard training labels, with 19 classes + 1 unlabeled category.
Depth Map Preparation — default disparity
Disparity maps created with StereoSGBM are readily available from the CityScapes website. The disparity describes the pixel difference of objects as viewed from each stereo camera’s perspective, and it is inversely proportional to depth, which can be computed as: depth = (baseline × focal length) / disparity.
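Here’s a minimal sketch of that conversion, assuming the standard CityScapes disparity encoding (16-bit PNGs storing d = (p − 1)/256, with p = 0 marking invalid pixels) and approximate calibration values for the stereo rig; in practice the baseline and focal length should be read from the per-image camera files.

import numpy as np

def disparity_to_depth(disparity_png, baseline=0.21, focal_length=2262.52):
    """Convert a raw CityScapes disparity image to depth in meters.

    baseline (m) and focal_length (px) are approximate values for the
    CityScapes stereo rig and are assumptions here; read them from the
    per-image camera calibration files in practice.
    """
    disparity = disparity_png.astype(np.float32)
    valid = disparity > 0                      # p == 0 marks invalid pixels
    disparity[valid] = (disparity[valid] - 1.0) / 256.0

    depth = np.zeros_like(disparity)           # invalid pixels stay at 0
    nonzero = disparity > 0
    depth[nonzero] = baseline * focal_length / disparity[nonzero]
    return depth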
However, the default disparity maps are noisy, with many holes corresponding to infinite depth as well as a region at the bottom where the ego vehicle is always visible. A common approach to cleaning these disparity maps involves the following steps (a rough code sketch follows the list):
- Crop out the bottom 20% along with parts of the left and top edges
- Resize to original scale
- Apply a smoothing filter
- Perform inpainting
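Below is a rough OpenCV sketch of that cleanup; the crop fractions, filter size, and inpainting radius are illustrative assumptions rather than the exact values used to prepare the data.

import cv2
import numpy as np

def clean_disparity(disparity):
    """Rough cleanup of a noisy disparity map (crop fractions and filter sizes are illustrative)."""
    h, w = disparity.shape

    # crop out the ego vehicle at the bottom and the noisy left/top borders
    cropped = disparity[int(0.03 * h): int(0.8 * h), int(0.05 * w):]

    # resize back to the original resolution
    resized = cv2.resize(cropped, (w, h), interpolation=cv2.INTER_LINEAR)

    # smooth out speckle noise
    smoothed = cv2.medianBlur(resized.astype(np.float32), 5)

    # inpaint holes (zero disparity -> unknown/infinite depth); cv2.inpaint
    # expects an 8-bit image, so inpaint a scaled copy and rescale back
    holes = (smoothed <= 0).astype(np.uint8)
    scale = 255.0 / max(float(smoothed.max()), 1e-6)
    filled = cv2.inpaint((smoothed * scale).astype(np.uint8), holes, 5, cv2.INPAINT_TELEA)
    return filled.astype(np.float32) / scale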
Once we clean the disparity we can compute the depth, which results in:
The fine details of this approach are outside the scope of this post, but if you’re interested, here’s a video explanation on YouTube.
The crop and resize step means that the disparity (as well as the depth) map won’t exactly align with the input image. Even though we could do the same crop and resize with the input image to correct for this, we opted to explore a new approach.
Depth Map Preparation — CreStereo disparity
We explored using CreStereo to produce high quality disparity maps from the left and right images. CreStereo is an advanced model that is able to predict smooth disparity maps from stereo image pairs. This approach introduces a paradigm known as knowledge distillation, where CreStereo is the teacher network and our model will be the student network (at least for the depth estimation). The details of this approach are outside the scope of this post, but here’s a YouTube link if you’re interested.
In general, the CreStereo depth maps have minimal noise, so there’s no need to crop and resize. However, the ego vehicle present in the segmentation masks could cause issues with generalization, so the bottom 20% was removed from all training images. A training sample is shown below:
Now that we have our data, let’s see the architecture.
Following [1], the architecture will consist of a MobileNet backbone/encoder, a LightWeight RefineNet Decoder, and Heads for each individual task. The overall architecture is shown below in figure 3.
For the encoder/backbone, we will use a MobileNetV3 and pass skip connections at 1/4, 1/8, 1/16, and 1/32 resolutions to the LightWeight RefineNet. Finally, the output is passed to each head responsible for a different task. Notice that we could even add more tasks to this architecture if we wanted to.
To implement the encoder, we use a pre-trained MobileNetV3 encoder, which we pass into a custom PyTorch Module. The output of its forward function is a ParameterDict of skip connections for input to the LightWeight RefineNet. The code snippet below shows how to do this.
class MobileNetV3Backbone(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, x):
        """ Passes input through the MobileNetV3 backbone feature extraction layers
            layers to add skip connections from:
            - 1: 1/4 res
            - 3: 1/8 res
            - 7, 8: 1/16 res
            - 10, 11: 1/32 res
        """
        skips = nn.ParameterDict()
        for i in range(len(self.backbone) - 1):
            x = self.backbone[i](x)
            # save the outputs used as skip connections
            if i in [1, 3, 7, 8, 10, 11]:
                skips.update({f"l{i}_out" : x})

        return skips
The LightWeight RefineNet Decoder is very similar to the one implemented in [1], except with a few modifications to make it compatible with MobileNetV3 as opposed to MobileNetV2. We also note that the decoder portion consists of the Segmentation and Depth heads. The full code for the model is available on GitHub. We can piece together the model as follows:
from torchvision.models import mobilenet_v3_small

mobilenet = mobilenet_v3_small(weights='IMAGENET1K_V1')

encoder = MobileNetV3Backbone(mobilenet.features)
decoder = LightWeightRefineNet(num_seg_classes)
model = MultiTaskNetwork(encoder, decoder, freeze_encoder=False).to(device)
We divide training into three phases, the first at 1/4 resolution, the second at 1/2 resolution, and the final at full resolution. All of the weights were updated, since freezing the encoder weights didn’t seem to produce good results.
Transformations
During each phase, we perform random crop resize, color jitter, random flips, and normalization. The left input image is normalized with the standard ImageNet mean and standard deviation.
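As a rough illustration, here is what the input-image side of that pipeline could look like with torchvision; the crop size and jitter strengths are assumptions, and in the actual training code the random crop/resize and flips have to be applied identically to the image, segmentation mask, and depth map.

import torchvision.transforms as T

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# illustrative augmentations for the left input image only
train_transform = T.Compose([
    T.RandomResizedCrop((256, 512), scale=(0.5, 1.0)),   # crop size is an assumption
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    T.RandomHorizontalFlip(p=0.5),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])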
Depth Transformation
In general, depth maps contain mostly small values, since most of the information in a depth map consists of objects and surfaces close to the camera. Since the depth map has most of its mass concentrated around lower values (see the left of figure 4 below), it needs to be transformed to be effectively learned by a neural network. The depth map is clipped between 0 and 250; this is because stereo disparity/depth data at large distances is typically unreliable, and we want a way to discard it. Then we take the natural log and divide it by 5 to condense the distribution into a smaller range of numbers. See this notebook for more details.
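In code, that transformation (and its inverse) looks roughly like the snippet below; adding 1 before the log to avoid log(0) is my own assumption, since pixels at unknown/infinite depth can end up at 0 after clipping.

import numpy as np

def transform_depth(depth):
    """Compress the depth distribution before training: clip, log, scale."""
    depth = np.clip(depth, 0.0, 250.0)    # discard unreliable far-range depth
    return np.log(depth + 1.0) / 5.0      # +1 (assumed) avoids log(0); /5 condenses the range

def inverse_transform_depth(depth_t):
    """Map network predictions back to metric depth."""
    return np.exp(depth_t * 5.0) - 1.0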
I’ll be honest, I wasn’t sure of the best way to transform the depth data. If there’s a better way, or if you would do it differently, I would be interested to learn more in the comments :).
Loss Functions
We keep the loss functions simple: Cross Entropy Loss for segmentation and Mean Squared Error for depth estimation. We add them together with no weighting and optimize them jointly.
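A minimal sketch of the joint loss, assuming the model returns segmentation logits and a depth prediction from its two heads:

import torch.nn as nn

seg_criterion = nn.CrossEntropyLoss()
depth_criterion = nn.MSELoss()

def multitask_loss(seg_logits, depth_pred, seg_target, depth_target):
    """Unweighted sum of the segmentation and depth losses."""
    seg_loss = seg_criterion(seg_logits, seg_target)
    # in practice, invalid/ignored depth pixels may need to be masked out
    depth_loss = depth_criterion(depth_pred, depth_target)
    return seg_loss + depth_loss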
Learning Rate
We use a One Cycle Cosine Annealed Learning Rate with a max at 5e-4 and train for 150 epochs at 1/4 resolution. The notebook used for training is located here.
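A sketch of that schedule using PyTorch’s built-in OneCycleLR; the optimizer choice (AdamW) and the train_loader name are assumptions on my part.

import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)   # optimizer choice assumed

epochs = 150
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=5e-4,
    epochs=epochs,
    steps_per_epoch=len(train_loader),   # assumes a DataLoader named train_loader
    anneal_strategy='cos',               # cosine-annealed one-cycle schedule
)

# in the training loop, call scheduler.step() after each optimizer.step()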
We then fine-tune at 1/2 resolution for 25 epochs and again at full resolution for another 25 epochs, both with a learning rate of 5e-6. Note that we needed to reduce the batch size each time we fine-tuned at an increased resolution.
Figure 6 shows results from the trained model on both validation and test data. For inference, we simply normalized the input image and ran a forward pass through the model.
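Here’s a minimal inference sketch, assuming the model returns (segmentation logits, transformed depth), that left_image, device, and model are already defined, and using the same +1 offset assumed in the depth transform above:

import torch
import torchvision.transforms as T

preprocess = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

model.eval()
with torch.no_grad():
    x = preprocess(left_image).unsqueeze(0).to(device)   # left_image: PIL image or HxWx3 array
    seg_logits, depth_pred = model(x)
    seg_mask = seg_logits.argmax(dim=1).squeeze(0)       # per-pixel class ids
    depth = torch.exp(depth_pred * 5.0) - 1.0            # invert the log/5 transform (+1 offset assumed)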
In general it seems like the model is able to segment and estimate depth when there are larger objects in an image. When more finely detailed objects such as pedestrians are present, the model tends to struggle to segment them entirely. The model is able to estimate their depth to some degree of accuracy.
An Interesting Failure
The bottom of figure 6 shows an interesting failure case: the model fails to fully segment the light pole on the left side of the image. The segmentation only covers the bottom half of the light pole, while the depth shows that the bottom half of the light pole is much closer than the top half. The depth failure could be due to a bias of bottom pixels generally corresponding to closer depth; notice the horizon line around pixel row 500, where there is a clear divide between closer and farther pixels. It seems like this bias could have leaked into the model’s segmentation task. This type of task leakage should be considered when training multi-task models.
In Multi-task Learning, training data from one task can impact performance on another task
Depth Distributions
Let’s check how the predicted depth is distributed compared to the truth. For simplicity, we will just use a sample of 94 true/predicted full resolution depth map pairs.
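For reference, the comparison can be as simple as pooling the pixel values and histogramming them (variable names here are hypothetical):

import numpy as np

# true_depths, pred_depths: lists of full-resolution depth maps as numpy arrays
true_vals = np.concatenate([d.ravel() for d in true_depths])
pred_vals = np.concatenate([d.ravel() for d in pred_depths])

bins = np.linspace(0, 250, 100)
true_hist, _ = np.histogram(true_vals, bins=bins, density=True)
pred_hist, _ = np.histogram(pred_vals, bins=bins, density=True)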
It seems like the model has learned two distributions: one with a peak around 4 and another with a peak around 30. Notice that the artifact from clipping the depth did not seem to make a difference. The overall distribution has a long tail, reflecting the fact that only a small portion of any image contains far-away depth data.
The predicted depth distribution is much smoother than the ground truth. The roughness of the ground truth distribution likely comes from the fact that each object contains similar depth values. It may be possible to use this information to apply some sort of regularization that forces the model to follow this paradigm, but that will be for another time.
Bonus: Inference Speed
Since this is a lightweight model intended for speed, let’s see how fast it runs inference on a GPU. The code below has been modified from this article. In this test, the input image has been scaled down to 400×1024.
import numpy as np
import torch

# find optimal backend for performing convolutions
torch.backends.cudnn.benchmark = True

# rescale to half size
rescaled_sample = Rescale(400, 1024)(sample)
rescaled_left = rescaled_sample['left'].to(DEVICE)

# INIT LOGGERS
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
repetitions = 300
timings = np.zeros((repetitions, 1))

# GPU WARM-UP
for _ in range(10):
    _, _ = model(rescaled_left.unsqueeze(0))

# MEASURE PERFORMANCE
with torch.no_grad():
    for rep in range(repetitions):
        starter.record()
        _, _ = model(rescaled_left.unsqueeze(0))
        ender.record()
        # WAIT FOR GPU SYNC
        torch.cuda.synchronize()
        curr_time = starter.elapsed_time(ender)
        timings[rep] = curr_time

mean_syn = np.sum(timings) / repetitions
std_syn = np.std(timings)
print(mean_syn, std_syn)
The inference test shows that this model can run at 18.69 +/- 0.44 ms, or about 55 Hz. It’s important to note that this is just a Python prototype run on a laptop with an NVIDIA RTX 3060 GPU; different hardware will change inference speed. We should also note that an SDK like Torch-TensorRT could provide a significant speed-up if deployed on an NVIDIA GPU.
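For example, a Torch-TensorRT compilation could look roughly like this (untested here, and the exact arguments may differ between versions):

import torch
import torch_tensorrt

trt_model = torch_tensorrt.compile(
    model.eval(),
    inputs=[torch_tensorrt.Input((1, 3, 400, 1024))],   # the benchmark input shape
    enabled_precisions={torch.half},                    # allow FP16 kernels
)

with torch.no_grad():
    seg_logits, depth_pred = trt_model(rescaled_left.unsqueeze(0))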
In this post we learned how Multi-Task Learning can save costs and reduce carbon emissions. We learned how to build a lightweight multi-task architecture capable of performing classification and regression simultaneously on the CityScapes dataset. We also leveraged CreStereo and knowledge distillation to help our model learn to predict better depth maps.
This lightweight model presents a trade-off, where we sacrifice some performance for speed and efficiency. Even with this trade-off, the trained model was able to predict reasonable depth and segmentation results on test data. Furthermore, it was able to learn to predict a depth distribution similar to that of the ground truth depth maps.
[1] Nekrasov, Vladimir, et al. ‘Real-Time Joint Semantic Segmentation and Depth Estimation Using Asymmetric Annotations’. CoRR, vol. abs/1809.04766, 2018, http://arxiv.org/abs/1809.04766
[2] Standley, Trevor, et al. ‘Which Tasks Should Be Learned Together in Multi-Task Learning?’ CoRR, vol. abs/1905.07553, 2019, http://arxiv.org/abs/1905.07553
[3] Cordts, Marius, et al. ‘The Cityscapes Dataset for Semantic Urban Scene Understanding’. 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016, https://doi.org/10.1109/cvpr.2016.350