What is the difference between training, validation, and test sets in machine learning
When we train a machine learning model or a neural network, we split the available data into three sets: a training data set, a validation data set, and a test data set. In this article, I describe different methods of splitting data and explain why we do it at all.
Three kinds of datasets
Training of a machine learning model or a neural network is performed iteratively. We train the model, check the result, tweak the hyperparameters, and train the model again. All of that is repeated until we get satisfactory results. However, how do we know that the results are good?
In every iteration of training, we use the training dataset as the examples for the model. The model “learns” by extracting patterns from the training data and generalizing from the training examples.
After training, we use the validation data set to check the performance of the model. If we are not satisfied with the result, we modify the hyperparameters and train the model on the training dataset again.
At the end, when the training is finished, we use the test dataset to check the final performance of the model. Note that the test dataset must be used only once!
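We cannot continue tuning the model after checking the test set result. If we change the model because of its test score, we start fitting the model to the test set, and the test score no longer tells us how the model performs on data it has never seen. That makes the test set useless.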
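A minimal sketch of a three-way split, assuming the observations fit in a Python list and that a random shuffle is appropriate (for time series or grouped data it usually is not). The function name `train_val_test_split` and the default fractions are illustrative choices, not a standard API.

```python
import random


def train_val_test_split(items, val_fraction=0.15, test_fraction=0.15, seed=42):
    """Shuffle `items` and split them into disjoint training, validation, and test lists."""
    if val_fraction < 0 or test_fraction < 0 or val_fraction + test_fraction >= 1:
        raise ValueError("fractions must be non-negative and sum to less than 1")
    shuffled = list(items)
    # A fixed seed makes the split reproducible, so the test set does not change between runs.
    random.Random(seed).shuffle(shuffled)
    n_test = round(len(shuffled) * test_fraction)
    n_val = round(len(shuffled) * val_fraction)
    test = shuffled[:n_test]
    val = shuffled[n_test:n_test + n_val]
    train = shuffled[n_test + n_val:]
    return train, val, test


train, val, test = train_val_test_split(range(100))
print(len(train), len(val), len(test))  # 70 15 15
```

In a real project, the `test` list would be set aside and evaluated exactly once, after all tuning on `train` and `val` is finished.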
Why do we even bother with splitting the data?
Imagine that you are a math teacher, and you want your students to prepare for an algebra exam. You gave them a book with 50 exercises. The book not only gives them a task to do but also shows the steps to calculate the solution and allows them to check whether their solution is right or not.
How would you check whether they learned anything or not? Would you randomly pick 10 exercises from the book and use them as the exam assignments?
The book you gave them is not very long. The students could have learned the exercises by heart. It is an algebra exam, not a poetry class. You want your students to understand the rules of algebra, not memorize the book.
Because of that, you take a completely different book and pick the exam assignments from that book, not from the one used by your students.
We do precisely the same thing in machine learning. The training data set is the book used by the students; the test data set is the one used by the teacher to prepare the exam.
What about the validation data set?
How do smart students learn? They read one or two exercises in the book, follow the steps carefully, try to understand the rules, take another exercise, and try to solve it on their own. If they fail, they look at the solution and try to find their mistake. After that, they write down another exercise and try to solve it. They don’t look at the solution before trying to solve the problem!
That is the validation data set. It is used by the students to check whether they have learned the rules of algebra.
The students who aren’t smart look at the exercises and follow the steps, but they never try to solve a problem without looking at the book. They feel very confident. After all, they have “done” all exercises in the book correctly. In reality, they have just followed the instructions without learning anything.
By the way, last year, I wrote an article about efficient learning. Check it out if you want to speed up your learning.
Which student will get a better score during the exam? Obviously, the one who used the validation data set because that student recognized the patterns and generalized knowledge. The second one has just memorized the book.
There are two common ways of splitting data into training and validation data sets.
We can decide upfront which observations are going to form the validation data set. In every iteration of training, the same observations are used for validation. That method is called hold-out validation.
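The hold-out approach can be sketched with a toy problem. Everything here is hypothetical: `make_toy_data` generates a synthetic 1-D dataset, a simple k-nearest-neighbors classifier stands in for the model, and its number of neighbors `k` plays the role of the hyperparameter we tweak. The point is that the validation observations are chosen once and reused for every candidate.

```python
import random


def make_toy_data(n, seed):
    """Hypothetical 1-D dataset: label is 1 when x > 0.5, with about 10% of labels flipped."""
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        x = rng.random()
        label = int(x > 0.5)
        if rng.random() < 0.1:  # label noise makes memorizing single points (k=1) risky
            label = 1 - label
        data.append((x, label))
    return data


def knn_predict(train, x, k):
    """Predict by majority vote among the k training points closest to x (use an odd k)."""
    neighbors = sorted(train, key=lambda point: abs(point[0] - x))[:k]
    ones = sum(label for _, label in neighbors)
    return int(2 * ones > k)


def accuracy(train, examples, k):
    """Fraction of `examples` that the k-NN model built from `train` classifies correctly."""
    correct = sum(knn_predict(train, x, k) == label for x, label in examples)
    return correct / len(examples)


data = make_toy_data(300, seed=0)
# Hold-out validation: the split is decided once, upfront, and the same
# 60 observations validate every candidate value of k.
train, val = data[:240], data[240:]

scores = {k: accuracy(train, val, k) for k in (1, 3, 5, 7, 9, 15)}
best_k = max(scores, key=scores.get)
print("validation accuracy per k:", scores)
print("chosen k:", best_k)
```

A real project would also keep a separate test set aside and check it only after `best_k` has been chosen.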
The second method is called cross-validation. In that method, we decide how many parts (folds) to split the data into, and therefore how many observations end up in each validation set, but we don't need to decide which observations will be used for validation. Instead, we train the model several times from scratch. Each time, a different fold is the validation data set and the remaining folds are the training data, so every observation is used for validation exactly once. At the end, we average the validation scores.
If we do 5-fold cross-validation, we split the observations into 5 parts. In the first round, the first part is used for validation, and the other four parts are the training data. In the second round, the second part is the validation data set, and so on, until each of the five parts has served as the validation data once.
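A minimal sketch of k-fold cross-validation in plain Python, assuming the data has already been shuffled (the folds are contiguous slices). `k_fold_indices`, `cross_validate`, and `majority_baseline` are hypothetical helper names; `train_and_score` stands for whatever code trains a fresh model on one list and scores it on another.

```python
import random


def k_fold_indices(n, k):
    """Split indices 0..n-1 into k contiguous folds whose sizes differ by at most one."""
    if not 1 < k <= n:
        raise ValueError("k must be between 2 and n")
    folds, start = [], 0
    for i in range(k):
        size = n // k + (1 if i < n % k else 0)
        folds.append(list(range(start, start + size)))
        start += size
    return folds


def cross_validate(data, k, train_and_score):
    """Run k rounds; in each round one fold validates a fresh model trained on the other folds.

    `train_and_score(train, val)` is supplied by the caller: it should train a new
    model on `train` and return its score on `val`. Returns (mean score, per-fold scores).
    """
    scores = []
    for val_idx in k_fold_indices(len(data), k):
        val_set = set(val_idx)
        train = [row for j, row in enumerate(data) if j not in val_set]
        val = [data[j] for j in val_idx]
        scores.append(train_and_score(train, val))
    return sum(scores) / len(scores), scores


print(k_fold_indices(10, 5))  # [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]


def majority_baseline(train, val):
    """Toy model: always predict the most common label seen in `train`."""
    labels = [label for _, label in train]
    prediction = max(set(labels), key=labels.count)
    return sum(label == prediction for _, label in val) / len(val)


# Hypothetical labelled data, shuffled first because the folds are contiguous.
data = [(x, int(x >= 7)) for x in range(20)]
random.Random(0).shuffle(data)
mean_score, fold_scores = cross_validate(data, 5, majority_baseline)
print(mean_score, fold_scores)
```

Because each round trains a new model, no round's validation fold has been seen during that round's training.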
We do all of that to detect overfitting. If we keep training long enough, the model is eventually going to overfit, because we have a finite number of training examples and, at some point, the model starts “memorizing” them. We can see it happening when the evaluation score on the training set gets almost perfect while the score on the validation set stops improving and starts to decline.
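One way to act on that signal is early stopping: track the validation score after every epoch and keep the epoch where it peaked. The sketch below assumes higher scores are better and uses a hypothetical `patience` parameter (how many epochs without improvement we tolerate). The scores in the usage lines are made-up numbers for illustration, not results.

```python
def best_epoch(val_scores, patience=3):
    """Return the index of the epoch with the highest validation score.

    Scanning stops early once the score has not improved for `patience`
    consecutive epochs, a common sign that overfitting has started.
    """
    if not val_scores:
        raise ValueError("val_scores must not be empty")
    best_idx, best_score, waited = 0, val_scores[0], 0
    for epoch in range(1, len(val_scores)):
        if val_scores[epoch] > best_score:
            best_idx, best_score, waited = epoch, val_scores[epoch], 0
        else:
            waited += 1
            if waited >= patience:
                break  # no improvement for `patience` epochs: stop looking
    return best_idx


# Made-up learning curves: the training score keeps rising toward 1.0,
# while the validation score peaks and then declines.
train_scores = [0.70, 0.80, 0.87, 0.92, 0.96, 0.98, 0.99, 0.995]
val_scores = [0.68, 0.76, 0.81, 0.83, 0.82, 0.80, 0.78, 0.77]
stop = best_epoch(val_scores)
print(stop, train_scores[stop], val_scores[stop])  # 3 0.92 0.83
```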
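If we train a model for too long, it will overfit. Validation does not prevent that by itself, but it shows us when to stop training or change the hyperparameters. Cross-validation makes that decision more reliable, because every observation is used for validation once and the result does not depend on a single lucky or unlucky split.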