Using a surrogate model to interpret a machine learning model

In my opinion, training a surrogate model is the easiest method of interpreting the behavior of an existing machine learning model.

To apply this method, we are going to need:

  • an existing machine learning model
  • input data that can be processed by the existing model (for example, the dataset used for training or testing the model, or a sample of real-world data from the production environment)

We don’t need to know anything about the existing model. It is just a black box. It has an input, and when we pass the data, we get an output. That is all we need.

In the first step, I am going to pass the data into the black box model and collect its predictions.

          =======================
          =                     =
data =>   =     black box       =  => prediction
          =       model         =
          =======================
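As a minimal sketch of this first step, assuming a scikit-learn style model with a `predict` method: the random forest and the synthetic data below are only stand-ins I made up for illustration. In practice, `black_box` is whatever model you already have, and `X` is your test set or production sample.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor

# Synthetic stand-in data (an assumption for this sketch).
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))
y = 2.0 * X[:, 0] - X[:, 1] ** 2 + rng.normal(scale=0.1, size=500)

# Stand-in for the existing model. From here on we only call predict() on it
# and never look inside.
black_box = RandomForestRegressor(n_estimators=50, random_state=0).fit(X, y)

# Step one: pass the data in and keep the output.
black_box_prediction = black_box.predict(X)
```

Note that the true labels `y` are only needed here to build the stand-in model. The surrogate method itself never uses them.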

Now, I have to decide what kind of model I want to train as the surrogate model. It should be a model that I know how to interpret and explain to people who have no machine learning knowledge, for example, linear regression or decision trees.

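I am going to train the surrogate model using the independent variables from the input data as features and the predictions from the black box as the dependent variable. Note that the true labels are not used at all: the surrogate learns to imitate the black box, not to solve the original problem.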

independent variables       prediction
from input dataset          from black box model
              ||              ||
              ||              ||
              \/              \/
          =======================
          =                     =
          =     surrogate       =
          =       model         =
          =======================
                    ||
                    \/
                 surrogate's
                 prediction
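The diagram above could look like this in code. It is a sketch under assumptions: `black_box_predict` is a hypothetical function standing in for the existing model, the data is synthetic, and I picked a shallow scikit-learn decision tree as the surrogate because it is easy to explain.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor


def black_box_predict(X):
    """Hypothetical stand-in for the existing model's prediction function."""
    return np.where(X[:, 0] > 0.0, 3.0, -1.0) + 0.5 * X[:, 1]


# Synthetic input data (an assumption for this sketch).
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 3))

# The black box output becomes the target of the surrogate.
black_box_prediction = black_box_predict(X)

# A shallow tree stays small enough to read and explain.
surrogate = DecisionTreeRegressor(max_depth=3, random_state=0)
surrogate.fit(X, black_box_prediction)

surrogate_prediction = surrogate.predict(X)
```

The `max_depth=3` limit is a design choice, not a rule. A deeper tree imitates the black box more closely, but it is also harder to explain.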

After that, I can calculate the prediction error of the surrogate model, using the predictions of the black box (not the true labels) as the reference values. The smaller the error I get, the better the surrogate model explains the black box.
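One way to put a number on that error is shown below. The function name `surrogate_fidelity` is my own, and RMSE and R² are just two common choices of metric. Both are computed against the black box output, so an R² close to 1 means the surrogate reproduces most of the variation in the black box predictions.

```python
import math


def surrogate_fidelity(black_box_prediction, surrogate_prediction):
    """Return (RMSE, R^2) of the surrogate, measured against the black box output."""
    n = len(black_box_prediction)
    if n == 0 or n != len(surrogate_prediction):
        raise ValueError("both prediction lists must be non-empty and of equal length")

    mean_bb = sum(black_box_prediction) / n
    # Squared differences between the two models' predictions.
    ss_res = sum(
        (b - s) ** 2 for b, s in zip(black_box_prediction, surrogate_prediction)
    )
    # Total variation of the black box predictions around their mean.
    ss_tot = sum((b - mean_bb) ** 2 for b in black_box_prediction)

    rmse = math.sqrt(ss_res / n)
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")
    return rmse, r2


# Toy numbers, only to show the call.
rmse, r2 = surrogate_fidelity([1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 5.0])
print(rmse, r2)
```

If enough data is available, it is worth computing these numbers on rows the surrogate was not trained on, so that the error does not look better than it really is.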

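When I get a surrogate model that has an acceptable prediction error, I can look at its parameters to understand which features are important and, approximately, how the black box model works. The explanation is only as good as the surrogate's fit: it describes the surrogate directly and the black box only to the extent that the two agree.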
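With a linear regression surrogate, "looking at the parameters" can be as simple as sorting the coefficients by absolute value. Everything below is an illustrative assumption: the feature names are hypothetical, and the black box output is simulated with a formula so that the example is self-contained.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

feature_names = ["age", "income", "visits"]  # hypothetical names

# Synthetic inputs on the same scale (standard normal).
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))

# Simulated black box output: strong effect of the first feature,
# weak effect of the second, no effect of the third.
black_box_prediction = 4.0 * X[:, 0] - 0.5 * X[:, 1]

surrogate = LinearRegression().fit(X, black_box_prediction)

# Rank features by the absolute size of their coefficient.
ranked = sorted(
    zip(feature_names, surrogate.coef_),
    key=lambda pair: abs(pair[1]),
    reverse=True,
)
for name, coef in ranked:
    print(f"{name}: {coef:+.3f}")
```

Comparing coefficient sizes like this only makes sense when the features are on similar scales, so real data usually needs to be standardized first. For a decision tree surrogate, the tree itself can be printed instead, for example with `sklearn.tree.export_text`.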
