[Machine Learning Fundamentals] Training, validation and test sets

Hello everyone! Today I would like to talk about how the training, validation and test sets are used in machine learning. This is a basic but very important topic that is usually covered at the beginning of any machine learning course. However, mistakes are still very common. In this article, I hope to provide a clear explanation of how to use a dataset to train and evaluate a machine learning model.

For any task, a dataset is usually split into three non-overlapping subsets:

  • The training set: this is the data used to train the machine learning model. Because machine learning models typically need a large amount of data to train well, the training set is usually the largest of the three subsets.

  • The validation set (also called the development set): this is the subset of the data used to select the best hyperparameters (e.g., learning rate, batch size, weight decay, number of hidden layers…). It can also be used for early stopping, i.e., stopping training once the accuracy on the validation set stops improving (usually the point at which the model starts overfitting the training set).

  • The test set: this is the subset of the data used to measure how good the final model is and how well it is likely to generalize to real-world data. The error on this set is an estimate of the model's generalization error.
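To make the split concrete, here is a minimal sketch in plain Python (standard library only). It shuffles the sample indices and cuts them into three disjoint groups. The function name `train_val_test_split` and the 70/15/15 proportions are my own illustrative choices, not a fixed rule; libraries such as scikit-learn provide a similar `train_test_split` helper that can be applied twice to obtain three subsets.

```python
import random


def train_val_test_split(n_samples, val_frac=0.15, test_frac=0.15, seed=0):
    """Return three disjoint lists of indices: (train, val, test).

    Hypothetical helper for illustration; the default fractions are
    arbitrary and the training set receives all remaining samples.
    """
    if not 0 <= val_frac + test_frac < 1:
        raise ValueError("val_frac + test_frac must be in [0, 1)")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # each subset becomes a random sample
    n_test = round(n_samples * test_frac)
    n_val = round(n_samples * val_frac)
    test_idx = indices[:n_test]
    val_idx = indices[n_test:n_test + n_val]
    train_idx = indices[n_test + n_val:]  # the largest of the three subsets
    return train_idx, val_idx, test_idx


train_idx, val_idx, test_idx = train_val_test_split(1000)
print(len(train_idx), len(val_idx), len(test_idx))  # 700 150 150
```

Fixing the `seed` makes the split reproducible, so every candidate model is trained and evaluated on exactly the same subsets.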

Crucially, the test set should only be used at the very end, once the hyperparameters that give the best accuracy on the validation set have been found.

In addition, the distribution of the test set should be as close as possible to the distribution of the real-world data that the model will encounter once it is deployed. Ideally, the validation and test sets should also follow similar distributions, so that the accuracies on the two sets are similar and correlated (i.e., an improvement on the validation set leads to an improvement on the test set).
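One simple way to keep at least the label distribution similar across the three subsets is a stratified split: each class is shuffled and split separately using the same proportions. The sketch below uses a hypothetical `stratified_split` function and toy `cat`/`dog` labels. Note that it only matches class proportions; it cannot make the test set resemble real-world data if the collected data itself does not.

```python
import random
from collections import Counter, defaultdict


def stratified_split(labels, val_frac=0.15, test_frac=0.15, seed=0):
    """Split indices into train/val/test so that each subset keeps roughly
    the same class proportions as the full dataset (hypothetical helper)."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for i, y in enumerate(labels):
        by_class[y].append(i)
    train, val, test = [], [], []
    for idx in by_class.values():
        rng.shuffle(idx)  # shuffle within the class only
        n_test = round(len(idx) * test_frac)
        n_val = round(len(idx) * val_frac)
        test.extend(idx[:n_test])
        val.extend(idx[n_test:n_test + n_val])
        train.extend(idx[n_test + n_val:])
    return train, val, test


labels = ["cat"] * 800 + ["dog"] * 200  # imbalanced toy labels: 20% dogs
train, val, test = stratified_split(labels)
for name, subset in [("train", train), ("val", val), ("test", test)]:
    counts = Counter(labels[i] for i in subset)
    print(name, counts["dog"] / len(subset))  # 0.2 in every subset
```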

In summary, the pipeline for training a machine learning model works as follows:

  • 1 - Train several models (with different hyperparameters) on the training set and compute the accuracy of each model on the validation set.

  • 2 - Take the model with the best accuracy on the validation set and retrain it on the combined training and validation data. This retraining step is sometimes skipped, but it can be useful because the final model then also learns from the data that was held out for validation.

  • 3 - Evaluate this model (and only this one) on the test set. Its accuracy on the test set is the final accuracy that should be reported, for instance in a research paper.
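The three steps above can be sketched end to end on a toy problem. As assumptions for illustration only, I use synthetic 1-D data with noisy binary labels and a k-nearest-neighbour classifier whose single hyperparameter is `k`; all function and variable names are hypothetical. With k-NN, "training" simply means storing the labelled points, so the optional retraining step amounts to concatenating the training and validation data.

```python
import random


def knn_predict(train_x, train_y, x, k):
    """Majority vote among the k nearest training points (1-D features, 0/1 labels)."""
    nearest = sorted(range(len(train_x)), key=lambda i: abs(train_x[i] - x))[:k]
    votes = sum(train_y[i] for i in nearest)
    return 1 if 2 * votes > k else 0  # odd k avoids ties


def accuracy(train_x, train_y, eval_x, eval_y, k):
    correct = sum(knn_predict(train_x, train_y, x, k) == y for x, y in zip(eval_x, eval_y))
    return correct / len(eval_y)


# Hypothetical toy data: label is 1 when x > 0.5, with about 10% of labels flipped.
# Samples are generated i.i.d., so slicing them below is already a random split.
rng = random.Random(0)
xs = [rng.random() for _ in range(600)]
ys = [int(x > 0.5) if rng.random() > 0.1 else int(x <= 0.5) for x in xs]
train_x, train_y = xs[:400], ys[:400]
val_x, val_y = xs[400:500], ys[400:500]
test_x, test_y = xs[500:], ys[500:]

# Step 1: one model per hyperparameter value, scored on the validation set only.
candidate_ks = [1, 3, 5, 9, 15, 25]
val_scores = {k: accuracy(train_x, train_y, val_x, val_y, k) for k in candidate_ks}
best_k = max(val_scores, key=val_scores.get)

# Step 2 (optional): refit the selected configuration on training + validation data.
final_x, final_y = train_x + val_x, train_y + val_y

# Step 3: a single evaluation of that one model on the held-out test set.
test_acc = accuracy(final_x, final_y, test_x, test_y, best_k)
print(f"best k = {best_k}, val acc = {val_scores[best_k]:.2f}, test acc = {test_acc:.2f}")
```

The test set is touched exactly once, after `best_k` has been fixed; this is the single number to report.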

Evaluating only the model with the best validation accuracy on the test set, rather than every model, avoids a common mistake when reporting results: training several models on the training set, evaluating each of them on the test set, and picking the one with the best test accuracy. Doing so uses the test set for model selection, which leads to overfitting on the test set and therefore to inflated test results. In that case, the accuracy on the test set is no longer a reliable estimate of the model's generalization error.
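The effect of selecting on the test set can be shown with a small simulation. As a deliberately extreme assumption, every hypothetical "model" below guesses labels at random, so none of them is better than chance (expected accuracy 50%). Picking the best of them by test accuracy typically reports a number well above 50%, whereas picking by validation accuracy and then testing once gives an honest estimate close to 50%. The exact values depend on the random seed, so I do not list them here.

```python
import random


def acc(pred, y):
    """Fraction of predictions that match the labels."""
    return sum(p == t for p, t in zip(pred, y)) / len(y)


rng = random.Random(42)
n_val, n_test, n_models = 200, 200, 100

# Random binary labels for the validation and test sets.
val_y = [rng.randint(0, 1) for _ in range(n_val)]
test_y = [rng.randint(0, 1) for _ in range(n_test)]

# Each hypothetical "model" guesses at random: its expected accuracy is 50%.
models = [
    ([rng.randint(0, 1) for _ in range(n_val)], [rng.randint(0, 1) for _ in range(n_test)])
    for _ in range(n_models)
]
val_accs = [acc(v, val_y) for v, _ in models]
test_accs = [acc(t, test_y) for _, t in models]

# Wrong: pick the model with the best *test* accuracy and report that number.
cheated = max(test_accs)

# Right: pick by validation accuracy, then report that one model's test accuracy.
chosen = max(range(n_models), key=lambda i: val_accs[i])
honest = test_accs[chosen]

print(f"selected on test: {cheated:.3f}  selected on validation: {honest:.3f}")
```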

This is a brief overview of good practices for using training, validation and test sets. Although it may sound simple, the topic is fundamental to any machine learning task, and getting it right matters. For anyone interested in learning more, I recommend watching the very interesting tutorial Andrew Ng gave on the topic.