How to avoid bias against underrepresented target classes while training a machine learning model

In this article, I describe how to use stratification when splitting data into the training and test datasets. I start with an explanation of what a “K-fold split” is. Then, I describe when we need stratification. In the end, I show the difference between the results produced by the normal KFold and the StratifiedKFold.

K-fold

K-fold is a method of splitting the available data into the training and test datasets. In this method, the input dataset is split multiple times, and every split is different. As a result, the model is evaluated on a different test set in every step of cross-validation (in the case of cross-validation, it is called the validation set or development set). Creating the validation sets this way helps us reduce the chance of overfitting to one particular validation set.

Note that before using K-fold, we must split the original dataset into a training set and a test set. The training set becomes the input of the K-fold split, which produces multiple training and validation sets. The test set is not used for K-fold validation because we want to use it to estimate the performance of the model in a real-life scenario.
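The two-step workflow described above can be sketched as follows. This is a minimal sketch, assuming scikit-learn's train_test_split for the first step; the toy arrays, test_size=0.2, and random_state=0 are illustrative choices of mine, not values from this article.

```python
import numpy as np
from sklearn.model_selection import KFold, train_test_split

# Hypothetical toy data: 20 examples with 2 features each.
X_all = np.arange(40).reshape(20, 2)
y_all = np.array([1] * 4 + [0] * 16)

# Step 1: hold out a test set that K-fold never sees.
X_train, X_test, y_train, y_test = train_test_split(
    X_all, y_all, test_size=0.2, random_state=0
)

# Step 2: run K-fold on the training part only.
kfold = KFold(n_splits=4)
fold_sizes = []
for train_idx, val_idx in kfold.split(X_train):
    fold_sizes.append((len(train_idx), len(val_idx)))

print("held-out test examples:", len(X_test))
print("(train, validation) sizes per fold:", fold_sizes)
```

The held-out test set is touched only once, after the model has been selected using the validation folds.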



Stratification

Stratification is a method of selecting examples that preserves the percentage of every target class in the available data.

For example, imagine that I have a dataset of cats and dogs. In my dataset, there are only 4 cats and 16 dogs. I want to preserve that ratio in every K-fold split. If I don’t, the model may ignore the underrepresented target class because it occurs in only a few of the splits.
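To make the ratio concrete, here is a small sketch that computes the share of every class in a label array. The class_ratios helper is a hypothetical name I introduce for illustration, and the encoding (1 = cat, 0 = dog) is an assumption for this snippet.

```python
import numpy as np

# Labels for the example above: 1 = cat (4 examples), 0 = dog (16 examples).
labels = np.array([1] * 4 + [0] * 16)


def class_ratios(y):
    """Return a dict mapping each class label to its share of the examples."""
    classes, counts = np.unique(y, return_counts=True)
    return {int(c): float(n) / len(y) for c, n in zip(classes, counts)}


print(class_ratios(labels))
```

A stratified split should keep these shares (20% cats, 80% dogs) in every training and validation set.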

Example

First, I am going to create an imbalanced dataset. Imagine that the target class 1 denotes a cat and 0 denotes a dog.

import numpy as np

X = np.array([
    [1, 2], [3, 4], [1, 2], [3, 4], [5, 6],
    [6, 4], [2, 9], [3, 7], [2, 4], [6, 1],
    [5, 8], [4, 0], [1, 1], [2, 9], [3, 8],
    [0, 0], [6, 6], [6, 4], [5, 4], [8, 4]
])
y = np.array([1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

Now, I am going to use the standard KFold to generate 4 different splits of the input dataset. It would be better to shuffle the elements before the split (add the shuffle=True parameter), but the difference will be easier to notice if I don’t shuffle the data.

from sklearn.model_selection import KFold

kfold = KFold(n_splits=4)
for train_index, test_index in kfold.split(X):
    print("TRAIN (index):", train_index, "TEST (index):", test_index)
    print("TRAIN (target class value):", y[train_index], "TEST (target class value):", y[test_index])

The output:

TRAIN (index): [ 5  6  7  8  9 10 11 12 13 14 15 16 17 18 19] TEST (index): [0 1 2 3 4]
TRAIN (target class value): [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [1 1 1 1 0]
TRAIN (index): [ 0  1  2  3  4 10 11 12 13 14 15 16 17 18 19] TEST (index): [5 6 7 8 9]
TRAIN (target class value): [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [0 0 0 0 0]
TRAIN (index): [ 0  1  2  3  4  5  6  7  8  9 15 16 17 18 19] TEST (index): [10 11 12 13 14]
TRAIN (target class value): [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [0 0 0 0 0]
TRAIN (index): [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14] TEST (index): [15 16 17 18 19]
TRAIN (target class value): [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [0 0 0 0 0]

We see that the ratio of cats (target class = 1) in every split differs from the ratio in the input dataset. In the first split, all of the cats land in the test part, so the training part has none. In the other splits, the test part contains no cats at all. If I shuffled the data, I might get a fairer distribution, but there is no way to be sure.
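To see why shuffling alone gives no guarantee, the sketch below counts the cats in every validation fold for a few different seeds. The seeds (0 to 4) and the placeholder feature matrix are my own illustrative choices. I deliberately don't list the expected output because the counts depend on the seed.

```python
import numpy as np
from sklearn.model_selection import KFold

y_demo = np.array([1] * 4 + [0] * 16)
X_demo = np.zeros((20, 2))  # feature values do not affect which indices KFold picks

cats_per_fold_by_seed = {}
for seed in range(5):
    kfold = KFold(n_splits=4, shuffle=True, random_state=seed)
    # Every label is 0 or 1, so the sum is the number of cats in the fold.
    cats_per_fold_by_seed[seed] = [
        int(y_demo[val_idx].sum()) for _, val_idx in kfold.split(X_demo)
    ]

for seed, counts in cats_per_fold_by_seed.items():
    print("seed", seed, "cats in each validation fold:", counts)
```

Every row sums to 4 because each cat lands in exactly one validation fold, but nothing forces the cats to be spread evenly across the folds.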

That is, unless I use the StratifiedKFold, which keeps the underrepresented target class in each of the generated splits and preserves the ratio of examples from the input as closely as possible.

from sklearn.model_selection import StratifiedKFold

stratified = StratifiedKFold(n_splits=4)
# Note the difference in the next line: I had to pass the "y" parameter.
for train_index, test_index in stratified.split(X, y):
    print("TRAIN (index):", train_index, "TEST (index):", test_index)
    print("TRAIN (target class value):", y[train_index], "TEST (target class value):", y[test_index])

As a result, we see that in both the training and the validation datasets the ratio of cats is 20%, the same as in the input dataset.

TRAIN (index): [ 1  2  3  8  9 10 11 12 13 14 15 16 17 18 19] TEST (index): [0 4 5 6 7]
TRAIN (target class value): [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [1 0 0 0 0]
TRAIN (index): [ 0  2  3  4  5  6  7 12 13 14 15 16 17 18 19] TEST (index): [ 1  8  9 10 11]
TRAIN (target class value): [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [1 0 0 0 0]
TRAIN (index): [ 0  1  3  4  5  6  7  8  9 10 11 16 17 18 19] TEST (index): [ 2 12 13 14 15]
TRAIN (target class value): [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [1 0 0 0 0]
TRAIN (index): [ 0  1  2  4  5  6  7  8  9 10 11 12 13 14 15] TEST (index): [ 3 16 17 18 19]
TRAIN (target class value): [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [1 0 0 0 0]
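
Instead of reading the printed arrays, we can check the ratio in code. The following is a minimal sketch that rebuilds the same labels and computes the share of cats in the training and validation part of every stratified fold. The variable names and the placeholder feature matrix are mine. The exact indices chosen by StratifiedKFold may differ between scikit-learn versions, so the sketch looks only at the ratios.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

y_check = np.array([1] * 4 + [0] * 16)
X_check = np.zeros((20, 2))  # placeholder features; stratification uses only y

stratified = StratifiedKFold(n_splits=4)
cat_shares = []
for train_idx, val_idx in stratified.split(X_check, y_check):
    # The mean of 0/1 labels equals the share of cats in that part.
    cat_shares.append(
        (float(y_check[train_idx].mean()), float(y_check[val_idx].mean()))
    )

for fold, (train_share, val_share) in enumerate(cat_shares):
    print("fold", fold, "cats in train:", train_share, "cats in validation:", val_share)
```

With 4 cats and 4 folds, every validation part gets exactly one cat, so both shares equal the 20% of the input dataset.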



Bartosz Mikulski
Bartosz Mikulski * data/machine learning engineer * conference speaker * co-founder of Software Craft Poznan & Poznan Scala User Group
