How to avoid bias against underrepresented target classes while training a machine learning model
In this article, I am going to describe how to use stratification while splitting data into training and test datasets. I will start with an explanation of what a “K-fold split” is, then describe when we need stratification, and finally show the difference between the results produced by the standard KFold and the StratifiedKFold.
K-fold
K-fold is a method of splitting the available data into training and validation datasets. The input dataset is split multiple times, and every time a different split is produced, so during cross-validation the model is evaluated using a different validation set (also called the development set) in every step. This way of creating validation sets helps us reduce the chance of overfitting.
Note that before using K-fold, we must split the original dataset into a training set and a test set. The training set becomes the input of the K-fold split, which produces multiple pairs of training and validation sets. The test set is not used during K-fold cross-validation because we want to use it to estimate the performance of the model in a real-life scenario.
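A minimal sketch of that workflow, assuming scikit-learn and a made-up toy dataset (the arrays and parameter values below are placeholders, not part of the original example):
import numpy as np
from sklearn.model_selection import train_test_split, KFold

X = np.arange(40).reshape(20, 2)  # 20 toy examples with 2 features each
y = np.array([1] * 4 + [0] * 16)  # toy labels

# Step 1: hold out a test set; it is never used during cross-validation.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

# Step 2: K-fold splits only the training part into training/validation folds.
kfold = KFold(n_splits=3, shuffle=True, random_state=42)
for fold_train_index, fold_val_index in kfold.split(X_train):
    print(len(fold_train_index), "training rows,", len(fold_val_index), "validation rows")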

Stratification
Stratification is a method of selecting examples that preserves the percentage of each target class in the available data.
For example, imagine that I have a dataset of cats and dogs in which there are only 4 cats and 16 dogs. I want to preserve that ratio in every K-fold split. If I don’t, the model may ignore the underrepresented target class because it occurs in only a few of the splits.
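The same idea applies to a single split as well: scikit-learn’s train_test_split accepts a stratify argument that preserves the class ratio. A quick sketch with a made-up 4-cat/16-dog label array:
import numpy as np
from sklearn.model_selection import train_test_split

y = np.array([1] * 4 + [0] * 16)  # 4 cats (class 1), 16 dogs (class 0)
X = np.arange(40).reshape(20, 2)  # placeholder features

# stratify=y keeps the 20% cat ratio in both parts of the split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, stratify=y, random_state=42)
print("cats in the training set:", y_train.mean())  # 0.2
print("cats in the test set:", y_test.mean())       # 0.2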
Example
First, I am going to create an imbalanced dataset. Imagine that the target class 1 denotes a cat and 0 denotes a dog.
import numpy as np

# 20 examples with 2 features each
X = np.array([
    [1, 2], [3, 4], [1, 2], [3, 4], [5, 6],
    [6, 4], [2, 9], [3, 7], [2, 4], [6, 1],
    [5, 8], [4, 0], [1, 1], [2, 9], [3, 8],
    [0, 0], [6, 6], [6, 4], [5, 4], [8, 4]
])
# 4 cats (class 1) followed by 16 dogs (class 0)
y = np.array([1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Now, I am going to use the standard KFold to generate 4 different splits of the input dataset. It would be better to shuffle the elements before the split (add the shuffle=True parameter), but the difference will be easier to notice if I don’t shuffle the data.
from sklearn.model_selection import KFold

kfold = KFold(n_splits=4)
for train_index, test_index in kfold.split(X):
    print("TRAIN (index):", train_index, "TEST (index):", test_index)
    print("TRAIN (target class value):", y[train_index], "TEST (target class value):", y[test_index])
The output:
TRAIN (index): [ 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19] TEST (index): [0 1 2 3 4]
TRAIN (target class value): [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [1 1 1 1 0]
TRAIN (index): [ 0 1 2 3 4 10 11 12 13 14 15 16 17 18 19] TEST (index): [5 6 7 8 9]
TRAIN (target class value): [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [0 0 0 0 0]
TRAIN (index): [ 0 1 2 3 4 5 6 7 8 9 15 16 17 18 19] TEST (index): [10 11 12 13 14]
TRAIN (target class value): [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [0 0 0 0 0]
TRAIN (index): [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14] TEST (index): [15 16 17 18 19]
TRAIN (target class value): [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [0 0 0 0 0]
We see that the ratio of cats (target class = 1) in every split is different from the ratio in the input dataset. The first test set gets all of the cats; the others don’t contain any at all. Again, if I shuffled the data, I might get a fairer distribution, but there is no way to be sure.
Unless I use the StratifiedKFold, which keeps the underrepresented target class in every generated split and ensures that the ratio of examples is the same as in the input.
from sklearn.model_selection import StratifiedKFold

stratified = StratifiedKFold(n_splits=4)
# note the difference in the next line: I had to pass the "y" parameter
for train_index, test_index in stratified.split(X, y):
    print("TRAIN (index):", train_index, "TEST (index):", test_index)
    print("TRAIN (target class value):", y[train_index], "TEST (target class value):", y[test_index])
In the output, we see that in both the training and validation datasets the ratio of cats is around 20%, the same as in the input dataset.
TRAIN (index): [ 1 2 3 8 9 10 11 12 13 14 15 16 17 18 19] TEST (index): [0 4 5 6 7]
TRAIN (target class value): [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [1 0 0 0 0]
TRAIN (index): [ 0 2 3 4 5 6 7 12 13 14 15 16 17 18 19] TEST (index): [ 1 8 9 10 11]
TRAIN (target class value): [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [1 0 0 0 0]
TRAIN (index): [ 0 1 3 4 5 6 7 8 9 10 11 16 17 18 19] TEST (index): [ 2 12 13 14 15]
TRAIN (target class value): [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [1 0 0 0 0]
TRAIN (index): [ 0 1 2 4 5 6 7 8 9 10 11 12 13 14 15] TEST (index): [ 3 16 17 18 19]
TRAIN (target class value): [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [1 0 0 0 0]
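To double-check that claim, we can compute the fraction of cats in every fold. A short sketch reusing the X, y, and stratified objects defined above:
for train_index, test_index in stratified.split(X, y):
    # the mean of a 0/1 label array is the fraction of cats in that subset
    print("cats in TRAIN: %.2f, cats in TEST: %.2f" % (y[train_index].mean(), y[test_index].mean()))
Every fold prints 0.20 for both subsets, matching the 4/20 ratio of the input dataset.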