Learn what cross-validation is — a fundamental technique for building generalizable models
The concept of cross-validation follows directly from that of overfitting, covered in my previous article.
Cross-validation is one of the most effective techniques for avoiding overfitting and for getting a reliable picture of a predictive model's performance.
When I wrote about overfitting, I divided my data into a training set and a test set: the training set was used to train the model, the test set to evaluate its performance. But this simple split should generally be avoided in real-world scenarios.
This is because we can end up overfitting the test set if we keep tuning our model until we find the configuration that performs best on it. This concept is called data leakage and is one of the most common and impactful problems in the field. In fact, if we tune our model to perform well on the test set, its performance is then valid only for that test set.
What do I mean by configuration? Each model is characterized by a series of hyperparameters. Here is a definition:
A model hyperparameter is an external configuration whose value cannot be estimated from the data. Changing a hyperparameter changes the behavior of the model on our data accordingly and can improve or worsen performance.
For example, Sklearn’s DecisionTreeClassifier has a max_depth hyperparameter that controls the depth of the tree. Changing this hyperparameter changes the performance of the model, for better or for worse. We cannot know the ideal value of max_depth beforehand except through experimentation. In addition to max_depth, the decision tree has many other hyperparameters.
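To make this concrete, here is a minimal sketch of how changing max_depth alone changes a DecisionTreeClassifier’s score. The breast cancer dataset bundled with Sklearn and the specific depth values are just illustrative choices:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Illustrative dataset: any classification dataset would work here
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Try two different values of the max_depth hyperparameter
for depth in [2, 10]:
    tree = DecisionTreeClassifier(max_depth=depth, random_state=42)
    tree.fit(X_train, y_train)
    print(f"max_depth={depth}: test accuracy = {tree.score(X_test, y_test):.3f}")
```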
When we select the model to use on our dataset, we then need to work out which hyperparameter configuration is best. This activity is called hyperparameter tuning.
Once we find the best configuration, we take the model, with that configuration, to the “real” world — that is, the test set, which is made up of data the model has never seen before.
In order to test a configuration without touching the test set directly, we introduce a third set of data, called the validation set.
The general flow is this:
- We train the model on the training set
- We test the performance of the current configuration on the validation set
- If and only if we are satisfied with the performance on the validation set, then we test on the test set.
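A minimal sketch of this flow might look like the following; the decision tree, the dataset, and the split proportions (roughly 60–20–20) are only illustrative choices:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)

# First split off the test set, then carve a validation set out of what remains
X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_temp, y_temp, test_size=0.25, random_state=42)

# Try different configurations and keep the one that does best on the validation set
best_depth, best_score = None, 0.0
for depth in [2, 4, 8, None]:
    model = DecisionTreeClassifier(max_depth=depth, random_state=42).fit(X_train, y_train)
    score = model.score(X_val, y_val)
    if score > best_score:
        best_depth, best_score = depth, score

# Only the chosen configuration is evaluated on the test set
final_model = DecisionTreeClassifier(max_depth=best_depth, random_state=42).fit(X_train, y_train)
print("Test accuracy:", final_model.score(X_test, y_test))
```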
But why complicate our lives by adding an additional set to evaluate performance? Why not use the classic training-test set split?
The reason is simple but extremely important.
Machine learning is an iterative process.
By iterative I mean that a model can and must be evaluated several times, with different configurations, so that we understand which configuration performs best. The validation set allows us to test different configurations and select the best one for our scenario, without the risk of overfitting the test set.
But there’s a problem. By dividing our dataset into three parts we also reduce the number of examples available to the model for training. Moreover, since the usual split is 50–30–20, the model’s results may depend on chance, that is, on how the data happened to be distributed across the various sets.
Cross-validation solves this problem by removing the validation set from the equation and preserving the number of examples available to the model for learning.
Cross-validation is one of the most important concepts in machine learning. This is because it allows us to create models capable of generalization — that is, capable of making consistent predictions even on data that does not belong to the training set.
A model that can generalize is a useful, powerful model.
Cross-validation means dividing our training data into different portions and testing our model on a subset of these portions. The test set continues to be used for the final evaluation, while the model’s performance is estimated on the portions generated by cross-validation. This method is called K-Fold cross-validation, which we will look at in more detail shortly.
Below is an image that summarizes what has been said so far.
Cross-validation can be done in different ways and each method is suitable for a different scenario. In this article we will look at K-Fold cross-validation, which is by far the most popular cross-validation technique. Other popular variants are stratified cross-validation and group-based cross-validation.
The training set is divided into K folds (read “portions”) and the model is trained on K-1 of them. The remaining fold is used to evaluate the model.
This all takes place in the so-called cross-validation loop. Here’s an image taken from Scikit-learn.org that clearly shows this concept.
After iterating through each split, the final result is the average of the performances. This increases the validity of the estimate, since a “new” model is trained and evaluated on each portion of the training dataset. We end up with a single score that summarizes the model’s performance across many validation steps — a much more reliable method than looking at the performance of a single iteration!
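In Sklearn, this averaging is exactly what cross_val_score gives us. A minimal sketch, where the model and dataset are again only illustrative:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
model = DecisionTreeClassifier(max_depth=5, random_state=42)

# One score per fold; the mean summarizes the model's performance across all folds
scores = cross_val_score(model, X, y, cv=5)
print("Fold scores:", scores)
print(f"Mean accuracy: {scores.mean():.3f} (+/- {scores.std():.3f})")
```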
Let’s break down the process:
- Shuffle the rows of the dataset
- Divide the dataset into k portions
- For each fold
1. Create a test portion
2. Allocate the remainder to training
3. Train the model on the training portion and evaluate it on the test portion
4. Save the performance
- Evaluate overall performance by taking the average of the scores at the end of the process
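Written out by hand with Sklearn’s KFold, the steps above might look like this (the model and dataset are, once more, illustrative):

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import KFold
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)

# shuffle=True randomizes the rows before dividing the dataset into k portions
kf = KFold(n_splits=5, shuffle=True, random_state=42)
scores = []

for train_idx, test_idx in kf.split(X):
    # one portion is held out, the remainder is used for training
    model = DecisionTreeClassifier(max_depth=5, random_state=42)
    model.fit(X[train_idx], y[train_idx])
    scores.append(model.score(X[test_idx], y[test_idx]))  # save the performance

# overall performance is the average of the fold scores
print("Mean accuracy:", np.mean(scores))
```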
The value of k is typically 5 or 10, but Sturges’ rule can be used to establish a more precise number of splits
number_of_splits = 1 + log2(N)
where N is the total number of samples.
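For example, assuming a hypothetical dataset of 5,000 samples (the rounding convention here is my own choice):

```python
import math

n_samples = 5000  # hypothetical dataset size
number_of_splits = round(1 + math.log2(n_samples))
print(number_of_splits)  # 13 for 5,000 samples
```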
I mentioned the cross-validation loop just now. Let’s go deeper into this concept, which is fundamental but often overlooked by junior analysts.
Doing cross-validation, in itself, is already very useful. But in some cases it is necessary to go further and test new ideas and hypotheses to further improve the model.
All this must be done within the cross-validation loop, which is point 3 of the flow mentioned above.
Each experiment must be performed within the cross-validation loop.
Since cross-validation allows us to train and test the model several times and to summarize the overall performance with an average at the end, we need to place all the logic that changes the behavior of the model inside the cross-validation loop. Failing to do so makes it impossible to measure the impact of our assumptions.
Let’s now look at some examples.
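Suppose, as an illustrative hypothesis, that we want to test whether standardizing the features improves a logistic regression. The scaler must be fitted on each training fold, inside the cross-validation loop; wrapping it in a Pipeline keeps it there automatically. A minimal sketch:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_breast_cancer(return_X_y=True)
cv = KFold(n_splits=5, shuffle=True, random_state=42)

# Baseline: no scaling
baseline = LogisticRegression(max_iter=10000)

# Hypothesis: scaling the features improves performance.
# The pipeline re-fits the scaler on each training fold, so the
# experiment lives entirely inside the cross-validation loop.
with_scaling = make_pipeline(StandardScaler(), LogisticRegression(max_iter=10000))

for name, model in [("baseline", baseline), ("with scaling", with_scaling)]:
    scores = cross_val_score(model, X, y, cv=cv)
    print(f"{name}: mean accuracy = {scores.mean():.3f}")
```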
Here’s a template for applying cross-validation in Python. We will use Sklearn to generate a dummy dataset for a classification task and use the accuracy and ROC-AUC score to evaluate our model.
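Something along these lines works as a starting point (a minimal sketch; the logistic regression model and the dataset parameters are illustrative choices):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import KFold

# Generate a dummy binary classification dataset
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)

model = LogisticRegression(max_iter=5000)
kf = KFold(n_splits=5, shuffle=True, random_state=42)

accuracies, aucs = [], []
for train_idx, val_idx in kf.split(X):
    model.fit(X[train_idx], y[train_idx])
    preds = model.predict(X[val_idx])
    probas = model.predict_proba(X[val_idx])[:, 1]
    accuracies.append(accuracy_score(y[val_idx], preds))
    aucs.append(roc_auc_score(y[val_idx], probas))

print(f"Accuracy: {np.mean(accuracies):.3f} +/- {np.std(accuracies):.3f}")
print(f"ROC-AUC:  {np.mean(aucs):.3f} +/- {np.std(aucs):.3f}")
```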
Cross-validation is the first, essential step to consider when doing machine learning.
Always remember: if we want to do feature engineering, add logic or test other hypotheses — always split the data first with KFold and apply that logic in the cross-validation loop.
If we have a good cross-validation framework, with validation data that is representative of the real world and of the training data, then we can create good, highly generalizable machine learning models.