How to avoid bias against underrepresented target classes while training a machine learning model

In this article, I am going to describe how to use stratification while splitting data into training and test datasets. Let’s start with an explanation of what a “K-fold split” is. Later, I am going to describe when we need stratification. Finally, I will show how the results produced by the regular KFold differ from those produced by StratifiedKFold.

K-fold

K-fold is a method of splitting the available data into training and test datasets. In this method, the input dataset is divided into K parts (folds), and the split is repeated K times: each time, a different fold is held out for evaluation and the remaining K-1 folds are used for training. That way, during cross-validation, the model is evaluated on a different held-out set in every step (in cross-validation, this set is called the validation set or development set). Evaluating the model on several different validation sets gives a more reliable estimate of its performance than a single split and makes it less likely that we overfit to one particular validation set.
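To make the mechanics concrete, here is a minimal sketch of how K-fold splitting can assign indices, assuming no shuffling. The helper kfold_indices is hypothetical and written only for illustration: it mimics the unshuffled behavior of scikit-learn’s KFold (contiguous folds, with the first folds one element larger when the data does not divide evenly), but it is not the library’s code.

```python
import numpy as np


def kfold_indices(n_samples, n_splits):
    """Yield (train_index, validation_index) pairs for K-fold cross-validation.

    Hypothetical helper for illustration: the index range is cut into n_splits
    contiguous folds, and each fold is used exactly once as the validation set.
    """
    indices = np.arange(n_samples)
    # Every fold gets n_samples // n_splits elements; the first
    # n_samples % n_splits folds get one extra element.
    fold_sizes = np.full(n_splits, n_samples // n_splits, dtype=int)
    fold_sizes[: n_samples % n_splits] += 1
    start = 0
    for size in fold_sizes:
        stop = start + size
        validation_index = indices[start:stop]
        # Everything outside the current fold is used for training.
        train_index = np.concatenate([indices[:start], indices[stop:]])
        yield train_index, validation_index
        start = stop


for train_index, validation_index in kfold_indices(10, 3):
    print("TRAIN:", train_index, "VALIDATION:", validation_index)
```

For 10 examples and 3 folds, the validation sets are the contiguous index ranges 0-3, 4-6, and 7-9, and every example ends up in exactly one validation set.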

Note that before using K-fold, we must split the original dataset into a training set and a test set. The training set becomes the input of the K-fold split, which produces multiple pairs of training and validation sets. The test set is not used during K-fold cross-validation because we want to use it to estimate how the model performs on data it has never seen, as it would in a real-life scenario.
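The order of operations described above can be sketched with scikit-learn as follows. The dataset X_data/y_data is made up for illustration (100 rows, 20% of them in class 1), and passing stratify=y_data to train_test_split is my addition: it keeps the class ratio in the held-out test set, which matters for the imbalanced data discussed in this article.

```python
import numpy as np
from sklearn.model_selection import KFold, train_test_split

# Made-up data for illustration: 100 examples, 2 features,
# 20 examples of class 1 and 80 examples of class 0.
X_data = np.arange(200).reshape(100, 2)
y_data = np.array([1] * 20 + [0] * 80)

# Step 1: hold out a test set. It is never used during cross-validation.
# stratify=y_data keeps the 20% ratio of class 1 in both parts.
X_train, X_test, y_train, y_test = train_test_split(
    X_data, y_data, test_size=0.2, stratify=y_data, random_state=42
)

# Step 2: run K-fold only on the training part to get training/validation pairs.
kfold = KFold(n_splits=4)
fold_sizes = []
for train_index, validation_index in kfold.split(X_train):
    X_fold_train, y_fold_train = X_train[train_index], y_train[train_index]
    X_fold_val, y_fold_val = X_train[validation_index], y_train[validation_index]
    # A model would be trained on the fold's training part and scored
    # on its validation part here.
    fold_sizes.append((len(X_fold_train), len(X_fold_val)))

print("test set size:", len(X_test))
print("fold sizes (train, validation):", fold_sizes)
```

The 20 test examples stay untouched until the very end, and each of the 4 folds trains on 60 examples and validates on 20.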

Stratification

Stratification is a method of selecting examples that preserves the percentage of each target class found in the available data.

For example, imagine that I have a dataset of cats and dogs. In my dataset, there are only 4 cats and 16 dogs, so cats make up 20% of the examples. I want to preserve this ratio in every K-fold split. If I don’t, the model may ignore the underrepresented target class because it occurs in only a few of the splits.
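The idea can be sketched in a few lines of NumPy. The helper stratified_fold_ids is hypothetical and uses a simple round-robin assignment within each class, which is one way to stratify, not necessarily the exact algorithm scikit-learn uses. With 4 cats (label 1) and 16 dogs (label 0) split into 4 folds, every fold should receive 1 cat and 4 dogs, so each fold keeps the 20% cat ratio.

```python
import numpy as np


def stratified_fold_ids(y, n_splits):
    """Assign every example to a fold so that each class is spread evenly.

    Hypothetical helper illustrating the idea behind stratification: the
    examples of each class are dealt out to the folds in round-robin order.
    """
    y = np.asarray(y)
    fold_ids = np.empty(len(y), dtype=int)
    for label in np.unique(y):
        # Positions of all examples with this label, in their original order.
        class_positions = np.flatnonzero(y == label)
        # Deal them out to folds 0, 1, ..., n_splits - 1, 0, 1, ...
        fold_ids[class_positions] = np.arange(len(class_positions)) % n_splits
    return fold_ids


# 4 cats (1) and 16 dogs (0).
labels = np.array([1] * 4 + [0] * 16)
fold_ids = stratified_fold_ids(labels, n_splits=4)
for fold in range(4):
    in_fold = labels[fold_ids == fold]
    print(f"fold {fold}: {np.sum(in_fold == 1)} cat(s), {np.sum(in_fold == 0)} dog(s)")
```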

Example

First, I am going to create an imbalanced dataset. Imagine that the target class 1 denotes a cat and 0 denotes a dog.

import numpy as np

X = np.array([
    [1, 2], [3, 4], [1, 2], [3, 4], [5, 6],
    [6, 4], [2, 9], [3, 7], [2, 4], [6, 1],
    [5, 8], [4, 0], [1, 1], [2, 9], [3, 8],
    [0, 0], [6, 6], [6, 4], [5, 4], [8, 4]
])
y = np.array([1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

Now, I am going to use the standard KFold to generate 4 different splits of the input dataset. It would be better to shuffle the elements before the split (add the shuffle=True parameter), but the difference will be easier to notice if I don’t shuffle the data.

from sklearn.model_selection import KFold

kfold = KFold(n_splits=4)
for train_index, test_index in kfold.split(X):
    print("TRAIN (index):", train_index, "TEST (index):", test_index)
    print("TRAIN (target class value):", y[train_index], "TEST (target class value):", y[test_index])

The output:

TRAIN (index): [ 5  6  7  8  9 10 11 12 13 14 15 16 17 18 19] TEST (index): [0 1 2 3 4]
TRAIN (target class value): [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [1 1 1 1 0]
TRAIN (index): [ 0  1  2  3  4 10 11 12 13 14 15 16 17 18 19] TEST (index): [5 6 7 8 9]
TRAIN (target class value): [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [0 0 0 0 0]
TRAIN (index): [ 0  1  2  3  4  5  6  7  8  9 15 16 17 18 19] TEST (index): [10 11 12 13 14]
TRAIN (target class value): [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [0 0 0 0 0]
TRAIN (index): [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14] TEST (index): [15 16 17 18 19]
TRAIN (target class value): [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [0 0 0 0 0]

We see that the ratio of cats (target class = 1) in every split differs from the ratio in the input dataset. In the first split, all of the cats end up in the validation set, so the training set contains none of them, while the validation sets of the other splits contain no cats at all. If I shuffled the data, I might get a fairer distribution, but there would be no guarantee.

StratifiedKFold provides that guarantee. It keeps examples of the underrepresented target class in every generated split and makes the class ratio in each split the same as in the input (or as close to it as the fold sizes allow).

from sklearn.model_selection import StratifiedKFold

stratified = StratifiedKFold(n_splits=4)
# Note the difference in the next line: I had to pass the "y" parameter.
for train_index, test_index in stratified.split(X, y):
    print("TRAIN (index):", train_index, "TEST (index):", test_index)
    print("TRAIN (target class value):", y[train_index], "TEST (target class value):", y[test_index])

As a result, the ratio of cats is 20% in every training and validation set (3 of 15 training examples and 1 of 5 validation examples), the same as in the input dataset.

TRAIN (index): [ 1  2  3  8  9 10 11 12 13 14 15 16 17 18 19] TEST (index): [0 4 5 6 7]
TRAIN (target class value): [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [1 0 0 0 0]
TRAIN (index): [ 0  2  3  4  5  6  7 12 13 14 15 16 17 18 19] TEST (index): [ 1  8  9 10 11]
TRAIN (target class value): [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [1 0 0 0 0]
TRAIN (index): [ 0  1  3  4  5  6  7  8  9 10 11 16 17 18 19] TEST (index): [ 2 12 13 14 15]
TRAIN (target class value): [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [1 0 0 0 0]
TRAIN (index): [ 0  1  2  4  5  6  7  8  9 10 11 12 13 14 15] TEST (index): [ 3 16 17 18 19]
TRAIN (target class value): [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [1 0 0 0 0]
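Shuffling alone does not give the guarantee that stratification does. The following sketch compares how many cats land in each validation fold when both splitters are run with shuffle=True and a few random_state values. It reuses the target vector from the example under the new name y_pets; X_dummy is a placeholder, because both splitters only use the number of rows of X (StratifiedKFold additionally uses y).

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Same targets as in the example: 4 cats (1) followed by 16 dogs (0).
# X_dummy is a placeholder; both splitters only need its number of rows.
y_pets = np.array([1] * 4 + [0] * 16)
X_dummy = np.zeros((len(y_pets), 2))

results = {}
for seed in range(3):
    splitters = {
        "KFold": KFold(n_splits=4, shuffle=True, random_state=seed),
        "StratifiedKFold": StratifiedKFold(n_splits=4, shuffle=True, random_state=seed),
    }
    for name, splitter in splitters.items():
        # Count the cats that land in each validation fold.
        cats = [int(y_pets[test_index].sum()) for _, test_index in splitter.split(X_dummy, y_pets)]
        results[(name, seed)] = cats
        print(f"{name:>15}, random_state={seed}: cats per validation fold = {cats}")
```

Depending on the seed, the shuffled KFold may or may not spread the four cats evenly, while StratifiedKFold puts exactly one cat into every validation fold (and therefore three cats into every training set of 15 examples) for every seed.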

Did you enjoy reading this article?
Would you like to learn more about software craft in data engineering and MLOps?

Subscribe to the newsletter or add this blog to your RSS reader (does anyone still use them?) to get a notification when I publish a new essay!

Newsletter

Do you enjoy reading my articles?
Subscribe to the newsletter if you don't want to miss the new content, business offers, and free training materials.

Bartosz Mikulski

  • Data/MLOps engineer by day
  • DevRel/copywriter by night
  • Python and data engineering trainer
  • Conference speaker
  • Contributed a chapter to the book "97 Things Every Data Engineer Should Know"
  • Twitter: @mikulskibartosz