Learn how to build a model capable of performing multiple image classifications concurrently with Multi-Task Learning
Multi-task learning (MTL) is a subfield of Machine Learning in which multiple tasks are simultaneously learned by a shared model. This type of learning improves data efficiency and training speed, because the shared model learns several tasks from the same dataset and can learn faster thanks to the auxiliary information provided by the different tasks. In addition, it reduces overfitting, since it is more difficult for the model to fit the training data perfectly when the labels differ for each task [1].
This article will explain MTL at an introductory level and show how to implement and train a multi-task model on real data with the Keras module from TensorFlow. The complete code, along with a Jupyter Notebook in which you can experiment with what you have learned, can be found in my GitHub repository:
In order to keep the explanation simple and understandable, the CIFAR-10 dataset will be used, which is made available under the MIT license. This dataset consists of 60000 RGB images of 32×32 pixels, classified into 10 different classes. It is divided into 50000 training samples and 10000 testing samples and is perfectly balanced, meaning the dataset contains 6000 images per class. The dataset can be easily loaded by executing:
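```python
import tensorflow as tf

# Load CIFAR-10: 50000 training and 10000 testing RGB images of 32x32 pixels
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Scale the pixel values from [0, 255] to [0, 1] (a common preprocessing step)
x_train, x_test = x_train / 255.0, x_test / 255.0
```

Note that the scaling of the pixel values to [0, 1] is an extra preprocessing step I include here for convenience, not strictly part of loading the dataset.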
The dataset contains the following classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck. The two tasks to be learned by the multi-task model will be classifications on these labels:
- Task 1: multi-class classification on the modified CIFAR-10 dataset (airplane, automobile, bird, cat, dog, frog, ship, and truck labels; modifications explained below).
- Task 2: binary classification (labels are animal and vehicle).
A much more efficient option for achieving the two classification tasks mentioned above would be to train a model to learn only the first task, and then use its outputs to predict the binary class animal or vehicle. For example, given the image of a frog as input, the model would output the frog class; since a frog is an animal, the image would then be classified into the animal class (a schema of this solution can be seen below).
Despite this, this article will solve the problem by applying Multi-Task Learning: although it is not the most efficient approach for this example, it perfectly demonstrates the usefulness and application of MTL in this type of problem, and it is an excellent base from which to develop knowledge.
With this in mind, and in order to have a balanced dataset for training, the instances of the deer and horse classes will be removed. This is done because the dataset initially contains 30000 samples belonging to animals (5000 samples × 6 classes) and only 20000 belonging to vehicles (5000 samples × 4 classes), which would unbalance the dataset with regard to the binary classification task. Horse and deer instances are removed because they share very similar traits with cats and dogs, and could therefore add complexity to the training, as it would be more difficult to distinguish between instances of these classes.
Note that after this cleaning the labels take the following values: [0, 1, 2, 3, 5, 6, 8, 9] (4 and 7 are missing, corresponding to deer and horse, respectively). Therefore, it is necessary to update the labels so that they are numbered from 0 to 7, a step that can be found in the Jupyter Notebook in my repository. However, try to code it yourself for better learning!
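If you want to compare your attempt with a reference, the following sketch removes the deer and horse instances and renumbers the remaining labels (the function and variable names are my own; the notebook in the repository contains the original version):

```python
import numpy as np

def remove_and_relabel(x, y):
    """Drop the deer (4) and horse (7) instances and renumber the
    remaining labels so that they run from 0 to 7."""
    y = y.flatten()
    keep = ~np.isin(y, [4, 7])  # boolean mask of the instances to keep
    x, y = x[keep], y[keep]
    # Map the surviving original labels [0, 1, 2, 3, 5, 6, 8, 9] to [0..7]
    old_to_new = {old: new for new, old in enumerate([0, 1, 2, 3, 5, 6, 8, 9])}
    y = np.array([old_to_new[label] for label in y])
    return x, y

x_train, y_train = remove_and_relabel(x_train, y_train)
x_test, y_test = remove_and_relabel(x_test, y_test)
```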
It is essential for Multi-Task Learning that training labels are task-specific. Therefore, in an n-task training, n arrays of different labels will be defined. In this case, the first task requires the labels to be integers from 0 to 7 (one number for each class), and the second task requires the labels to be 0 and 1 (since it is a binary classification). The data was previously preprocessed so that the labels are digits from 0 to 7 and, as expected, the labels for the binary classification will be constructed from these 0-to-7 labels: the label will be 0 if the instance corresponds to an animal, and 1 if it corresponds to a vehicle.
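One possible way to build the binary labels from the relabeled 0-to-7 labels (the constant name is mine):

```python
# After relabeling: 0=airplane, 1=automobile, 2=bird, 3=cat,
# 4=dog, 5=frog, 6=ship, 7=truck
VEHICLE_CLASSES = [0, 1, 6, 7]

# Task 2 labels: 0 for animals, 1 for vehicles
y_train_binary = np.isin(y_train, VEHICLE_CLASSES).astype(np.int32)
y_test_binary = np.isin(y_test, VEHICLE_CLASSES).astype(np.int32)
```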
Create the Model
As both tasks use images as training data, a convolutional neural network (CNN) will be used as the shared model, so that it learns to extract the most significant features from the images. The output of the shared model is flattened and fed into the branch corresponding to each task. Both branches are composed of dense layers (since their inputs have been flattened) with a different number of neurons each, and their output layers are composed of 2 and 8 neurons for the binary and multi-class classification tasks, respectively. All layers use ReLU as the activation function except the output layers, which use softmax and sigmoid for the multi-class and binary classification, respectively. A low-level outline of the described model together with the code which defines it can be seen below.
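The following is a sketch of such a model built with the Keras functional API; the exact number of filters and of neurons in each branch are my assumptions, as the text above does not fix them:

```python
from tensorflow.keras import layers, Model

def build_multitask_model():
    inputs = layers.Input(shape=(32, 32, 3))

    # Shared convolutional model: extracts features useful for both tasks
    x = layers.Conv2D(32, (3, 3), activation='relu')(inputs)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Conv2D(64, (3, 3), activation='relu')(x)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Flatten()(x)

    # Branch 1: multi-class classification (8 classes, softmax output)
    b1 = layers.Dense(128, activation='relu')(x)
    out_multi = layers.Dense(8, activation='softmax', name='multiclass')(b1)

    # Branch 2: binary classification (2 neurons with sigmoid, as described above)
    b2 = layers.Dense(64, activation='relu')(x)
    out_binary = layers.Dense(2, activation='sigmoid', name='binary')(b2)

    return Model(inputs=inputs, outputs=[out_multi, out_binary])
```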
In the code, it can be seen that each layer of the neural network is called on the previous layer, i.e. the layer whose output is the input of the layer being defined. In this way the Keras model can be separated per task, making both branches/submodels start from the same shared model.
Models usually update their weights seeking to optimize a single loss function (see my previous article Digit Classification with Single-Layer Perceptron), but in MTL each branch of the overall model learns a different task, so a loss function must be specified for each task. However, TensorFlow only uses the result of a single loss function in the backpropagation process, so a joint loss function must also be defined, which in its simplest form is a weighted sum of the values of the different loss functions. In this case the joint loss function is defined as follows:
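loss = γ · loss_Task1 + (1 − γ) · loss_Task2,  with 0 ≤ γ ≤ 1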
Note that when γ equals 0 the model only receives the loss obtained from Task 2, whereas if γ equals 1 the model only receives the loss of Task 1. This makes it possible to train the model only for Task 1 (if γ = 1), only for Task 2 (if γ = 0), or for both tasks (if 0 < γ < 1), which allows the model to be used in multiple scenarios, depending on what is needed from it.
The joint loss function can be defined when compiling the previously created model, by specifying the weight for the loss function of each branch of the model in the loss_weights parameter. See the function below.
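A sketch of this step, wrapping the compilation in a helper so it can be reused for different gamma values (the helper name, the optimizer choice, and the loss choices are my assumptions):

```python
def compile_multitask_model(model, gamma):
    # loss_weights implements the joint loss:
    # gamma * loss_Task1 + (1 - gamma) * loss_Task2
    model.compile(
        optimizer='adam',
        loss={'multiclass': 'sparse_categorical_crossentropy',
              'binary': 'sparse_categorical_crossentropy'},
        loss_weights={'multiclass': gamma, 'binary': 1.0 - gamma},
        metrics=['accuracy'],
    )
```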
Train the Model
Once the data has been preprocessed and the model has been defined for both tasks, it is time to train it. The .fit() method is used for training which, unlike for a single-output model, receives as its y parameter as many label arrays as there are branches/outputs in the model. As with normal models, the batch size and the number of epochs must be specified: in this case a batch size of 128 will be used, and the model will be trained for 15 epochs.
The model is trained with 3 different values for gamma: 0, 0.5 and 1, and the code measures and prints the run time of each training.
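A sketch of what this training function could look like, reusing the helpers defined above (the article only gives the function's name, fit_batch_multitask_models(); the body here is my reconstruction):

```python
import time

def fit_batch_multitask_models(gammas, epochs=15, batch_size=128):
    """Train one multi-task model per gamma value, timing each run."""
    histories = []
    for gamma in gammas:
        model = build_multitask_model()
        compile_multitask_model(model, gamma)
        start = time.time()
        # One label array per output, matched to the outputs by name
        history = model.fit(
            x_train,
            {'multiclass': y_train, 'binary': y_train_binary},
            validation_data=(x_test,
                             {'multiclass': y_test, 'binary': y_test_binary}),
            batch_size=batch_size,
            epochs=epochs,
        )
        print(f'gamma = {gamma}: trained in {time.time() - start:.1f} s')
        histories.append(history)
    return histories

histories = fit_batch_multitask_models([0, 0.5, 1])
```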
Evaluate the Models’ Performance
Finally, we will look at the results of the MTL models for the different gamma values by plotting the accuracy over the 15 epochs for the two tasks. This can be done thanks to the history list returned by the fit_batch_multitask_models() function defined above.
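A minimal plotting sketch (note that the metric keys in the history depend on the output names given to the model; here they follow the names used above):

```python
import matplotlib.pyplot as plt

gammas = [0, 0.5, 1]
fig, axes = plt.subplots(1, len(gammas), figsize=(15, 4), sharey=True)
for ax, gamma, history in zip(axes, gammas, histories):
    # Keras prefixes each metric with the name of the corresponding output
    ax.plot(history.history['multiclass_accuracy'], label='Task 1 (multi-class)')
    ax.plot(history.history['binary_accuracy'], label='Task 2 (binary)')
    ax.set_title(f'gamma = {gamma}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.legend()
plt.show()
```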
The graphs obtained for the different values of gamma clearly show what was mentioned above: extreme values of gamma make the model learn to perform only one of the tasks, while an intermediate value of gamma makes the model learn both tasks. In addition, the model obtained an accuracy of over 90% for both tasks, performing better on the binary classification. It should be mentioned that binary classification is a simpler task than multi-class classification, so this result is to be expected.
Also, in the Jupyter Notebook random images from the test dataset are taken and predictions are extracted from the model to check that it works. Feel free to modify it and try new things!
Multi-Task Learning has achieved very good results on both tasks when the gamma value balances the weights of the two tasks' loss functions, and it also fulfills the role of a single-task model perfectly when extreme gamma values are used, so it is a very interesting option for situations in which the model only needs to perform certain tasks sporadically. Keep in mind that having a bifurcation in the model implies more computational cost, and therefore a longer execution time. Moreover, even with extreme gamma values the execution time remains high, since the model carries out the forward and backpropagation passes in both branches either way. This is important if the model only has to perform one task, since training a single-task model (without branches) will be much more efficient in terms of computational cost.
Finally, the enormous possibilities of this type of architecture should be mentioned. Applying a first stage (the shared model) to the input and then predicting with a task-specific branch makes it possible to greatly optimize the resources needed to train models that must perform tasks which are similar in certain aspects. Known and powerful models can be used as the base or shared model for this type of architecture, yielding models capable of very good results on many different tasks, always with a part of their nature in common.
[1] Crawshaw, M. (2020). Multi-task learning with deep neural networks: A survey. arXiv preprint arXiv:2009.09796.