# How to avoid bias against underrepresented target classes while training a machine learning model

In this article, I am going to describe how to use stratification while splitting data into training and test datasets. Let’s start with an explanation of what a “K-fold split” is. Later, I am going to describe when we need stratification. At the end, I will show the difference between the results produced by the standard KFold and StratifiedKFold.

# K-fold

K-fold is a method of splitting the available data into training and test datasets. In this method, the input dataset is divided into K parts (folds), and the split is repeated K times. Each time, a different fold is held out, so during cross-validation the model is evaluated on a different test set in every step (in the context of cross-validation, this held-out part is called the validation set or development set). Evaluating the model on several validation sets instead of a single one gives a more reliable estimate of its performance and reduces the chance of overfitting to one particular split.
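A minimal sketch of that loop, assuming scikit-learn's `KFold` together with a `LogisticRegression` model and a synthetic dataset from `make_classification` (the model, the data, and `n_splits=5` are illustrative choices, not part of the original example). Each iteration trains a fresh model on four folds and scores it on the remaining one:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold

# Synthetic data standing in for the training set (illustrative only)
X, y = make_classification(n_samples=100, n_features=4, random_state=0)

kfold = KFold(n_splits=5, shuffle=True, random_state=0)
scores = []
for train_index, val_index in kfold.split(X):
    model = LogisticRegression()
    # Train on four folds...
    model.fit(X[train_index], y[train_index])
    # ...and measure accuracy on the held-out fold
    scores.append(model.score(X[val_index], y[val_index]))

print("Accuracy per fold:", np.round(scores, 2))
print("Mean accuracy:", np.mean(scores))
```

Averaging the per-fold scores gives a single estimate that does not depend on one lucky or unlucky split.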

Note that before using K-fold, we must split the original dataset into a training set and a test set. The training set becomes the input of the K-fold split, which produces multiple training and validation sets. The test set is not used for K-fold validation because we want to use it to estimate the performance of the model in a real-life scenario.
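A sketch of that order of operations, assuming scikit-learn's `train_test_split` with an arbitrary 20% hold-out and toy data: the test set is set aside first, and only the remaining training portion is passed to `KFold`.

```python
import numpy as np
from sklearn.model_selection import KFold, train_test_split

# Toy data: 50 examples with 2 features each (illustrative only)
X = np.arange(100).reshape(50, 2)
y = np.array([0, 1] * 25)

# 1. Hold out a test set that K-fold will never see
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# 2. Run K-fold only on the training portion
kfold = KFold(n_splits=4)
for train_index, val_index in kfold.split(X_train):
    X_fold_train, X_val = X_train[train_index], X_train[val_index]
    y_fold_train, y_val = y_train[train_index], y_train[val_index]
    print(len(X_fold_train), "training /", len(X_val), "validation examples")

print(len(X_test), "examples kept aside for the final test")
```

The 10 test examples are used only once, after model selection on the 4 K-fold splits is finished.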

# Stratification

Stratification is a method of selecting examples that preserves the percentage of each target class found in the available data.

For example, imagine that I have a dataset of cats and dogs that contains only 4 cats and 16 dogs. I want to preserve that ratio in every K-fold split. If I don’t, the model may ignore the underrepresented target class because it occurs in only a few of the splits.
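The arithmetic behind preserving the ratio is simple: with 20 examples and 4 folds, each validation fold holds 5 examples, so keeping the 1:4 cat-to-dog ratio means exactly 1 cat and 4 dogs per fold. A minimal sketch of that calculation, plus the same idea for a single split through the `stratify` parameter of scikit-learn's `train_test_split` (the placeholder features and `random_state` are arbitrary):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# 4 cats (label 1) and 16 dogs (label 0), as in the example above
y = np.array([1] * 4 + [0] * 16)
X = np.arange(40).reshape(20, 2)  # placeholder features

# Expected composition of each validation fold when the ratio is preserved
n_splits = 4
cats_per_fold = int((y == 1).sum()) // n_splits  # 4 // 4
dogs_per_fold = int((y == 0).sum()) // n_splits  # 16 // 4
print("Cats per fold:", cats_per_fold, "Dogs per fold:", dogs_per_fold)

# A single stratified split: 25% of the data (5 examples) goes to the test set,
# and stratify=y keeps the 20% cat ratio in both parts
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0
)
print("Cats in test set:", int((y_test == 1).sum()), "of", len(y_test))
```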

## Example

First, I am going to create an imbalanced dataset. Imagine that the target class 1 denotes a cat and 0 denotes a dog.

```python
import numpy as np

X = np.array([
    [1, 2], [3, 4], [1, 2], [3, 4], [5, 6],
    [6, 4], [2, 9], [3, 7], [2, 4], [6, 1],
    [5, 8], [4, 0], [1, 1], [2, 9], [3, 8],
    [0, 0], [6, 6], [6, 4], [5, 4], [8, 4]
])
# 1 = cat (4 examples), 0 = dog (16 examples)
y = np.array([1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
```


Now, I am going to use the standard KFold to generate 4 different splits of the input dataset. It would be better to shuffle the elements before the split (by adding the `shuffle=True` parameter), but the difference is easier to notice if I don’t shuffle the data.

```python
from sklearn.model_selection import KFold

kfold = KFold(n_splits=4)
for train_index, test_index in kfold.split(X):
    print("TRAIN (index):", train_index, "TEST (index):", test_index)
    print("TRAIN (target class value):", y[train_index], "TEST (target class value):", y[test_index])
```


The output:

```
TRAIN (index): [ 5  6  7  8  9 10 11 12 13 14 15 16 17 18 19] TEST (index): [0 1 2 3 4]
TRAIN (target class value): [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [1 1 1 1 0]
TRAIN (index): [ 0  1  2  3  4 10 11 12 13 14 15 16 17 18 19] TEST (index): [5 6 7 8 9]
TRAIN (target class value): [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [0 0 0 0 0]
TRAIN (index): [ 0  1  2  3  4  5  6  7  8  9 15 16 17 18 19] TEST (index): [10 11 12 13 14]
TRAIN (target class value): [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [0 0 0 0 0]
TRAIN (index): [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14] TEST (index): [15 16 17 18 19]
TRAIN (target class value): [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [0 0 0 0 0]
```


We see that the ratio of cats (target class = 1) in every split is different from the ratio in the input dataset. The first validation set gets all of the cats, so the model trained in that step never sees a cat, while the other validation sets contain no cats at all. If I shuffled the data, I might get a fairer distribution, but there is no way to be sure.
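A quick way to check this is to repeat the plain `KFold` split with `shuffle=True` for a few arbitrarily chosen `random_state` values and count how many cats land in each validation fold. This is a sketch that reuses the labels from the example above with placeholder features:

```python
import numpy as np
from sklearn.model_selection import KFold

# Same labels as in the example: 4 cats (1) followed by 16 dogs (0)
y = np.array([1, 1, 1, 1] + [0] * 16)
X = np.zeros((len(y), 2))  # placeholder features; they do not affect KFold

cat_counts_per_seed = {}
for seed in [0, 1, 2, 3, 4]:
    kfold = KFold(n_splits=4, shuffle=True, random_state=seed)
    # Number of cats that land in each validation fold for this seed
    counts = [int(y[test_index].sum()) for _, test_index in kfold.split(X)]
    cat_counts_per_seed[seed] = counts
    print(f"random_state={seed}: cats per validation fold = {counts}")
```

Depending on the seed, one fold may end up with two cats while another gets none. Nothing in plain `KFold` forces exactly one cat per fold, so shuffling alone gives no guarantee.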

The StratifiedKFold class solves this problem: it keeps the underrepresented target class in every generated split and keeps the class ratio of each split as close as possible to the ratio in the input dataset.

```python
from sklearn.model_selection import StratifiedKFold

stratified = StratifiedKFold(n_splits=4)
# Note the difference in the next line: I had to pass the "y" parameter.
for train_index, test_index in stratified.split(X, y):
    print("TRAIN (index):", train_index, "TEST (index):", test_index)
    print("TRAIN (target class value):", y[train_index], "TEST (target class value):", y[test_index])
```


In the output, we see that the ratio of cats is 20% in every training set (3 of 15 examples) and every validation set (1 of 5 examples), the same as in the input dataset.

```
TRAIN (index): [ 1  2  3  8  9 10 11 12 13 14 15 16 17 18 19] TEST (index): [0 4 5 6 7]
TRAIN (target class value): [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [1 0 0 0 0]
TRAIN (index): [ 0  2  3  4  5  6  7 12 13 14 15 16 17 18 19] TEST (index): [ 1  8  9 10 11]
TRAIN (target class value): [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [1 0 0 0 0]
TRAIN (index): [ 0  1  3  4  5  6  7  8  9 10 11 16 17 18 19] TEST (index): [ 2 12 13 14 15]
TRAIN (target class value): [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [1 0 0 0 0]
TRAIN (index): [ 0  1  2  4  5  6  7  8  9 10 11 12 13 14 15] TEST (index): [ 3 16 17 18 19]
TRAIN (target class value): [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0] TEST (target class value): [1 0 0 0 0]
```
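The 20% ratio can also be checked programmatically instead of by reading the printed arrays. A short sketch, reusing the labels from the example with placeholder features (only `y` drives the stratification):

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

y = np.array([1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
X = np.zeros((len(y), 2))  # placeholder features

overall_ratio = y.mean()  # 4 cats / 20 examples
stratified = StratifiedKFold(n_splits=4)
for train_index, test_index in stratified.split(X, y):
    train_ratio = y[train_index].mean()  # fraction of cats in the training set
    test_ratio = y[test_index].mean()    # fraction of cats in the validation set
    # prints "train: 0.20, validation: 0.20" for every split
    print(f"train: {train_ratio:.2f}, validation: {test_ratio:.2f}")
    assert np.isclose(train_ratio, overall_ratio)
    assert np.isclose(test_ratio, overall_ratio)
```

Here the ratios match exactly because both 4 cats and 16 dogs divide evenly into 4 folds. When the class counts do not divide evenly, StratifiedKFold can only keep each fold's ratio approximately equal to the overall one, so the assertions would need a looser tolerance.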



Bartosz Mikulski * data/machine learning engineer * conference speaker * co-founder of Software Craftsmanship Poznan & Poznan Scala User Group