Using a surrogate model to interpret a machine learning model
In my opinion, training a surrogate model is the easiest method of interpreting the behavior of an existing machine learning model.
To apply this method, we are going to need:
- an existing machine learning model
- input data that the existing model can process (for example, the test dataset held out when the model was trained, or a sample of real-world data from the production environment)
We don’t need to know anything about the existing model. It is just a black box. It has an input, and when we pass the data, we get an output. That is all we need.
In the first step, I am going to pass the data into the black box model and get the prediction.
        =======================
        =                     =
data => =      black box      = => prediction
        =        model        =
        =======================
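Here is a minimal sketch of that first step in Python with NumPy. The function `black_box_predict` and the random input data are hypothetical stand-ins that I made up for illustration. In practice, it would be a call to whatever existing model you want to interpret, and the data would be your test or production sample.

```python
import numpy as np


def black_box_predict(X):
    """Hypothetical stand-in for the existing model.

    The formula is invented for this example. The rest of the code never
    looks inside this function; it only calls it.
    """
    return 3.0 * X[:, 0] - 2.0 * X[:, 1] + 0.5 * np.sin(X[:, 2])


# Assumed input data: 200 rows with 3 independent variables.
rng = np.random.default_rng(seed=0)
X = rng.normal(size=(200, 3))

# Pass the data into the black box and keep its output.
black_box_prediction = black_box_predict(X)
print(black_box_prediction.shape)
```

The only thing I keep from this step is the pair of arrays: the input data and the black box output for every row.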
Now, I have to decide what kind of model I want to train as the surrogate model. It should be a model that I know how to interpret and explain to people who have no machine learning knowledge, for example, linear regression or decision trees.
I am going to train the surrogate model, using the independent variables from the input data as features and the prediction from the black box as the dependent variable.
independent variables         prediction
 from input dataset      from black box model
         ||                      ||
         ||                      ||
         \/                      \/
          =======================
          =                     =
          =      surrogate      =
          =        model        =
          =======================
                     ||
                     \/
          surrogate's prediction
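As a rough illustration, the training step could look like the sketch below. I picked linear regression as the surrogate and fit it with ordinary least squares through `np.linalg.lstsq`. That is just one possible choice, and a decision tree from a machine learning library would work the same way. The black box function and the data are again hypothetical.

```python
import numpy as np


def black_box_predict(X):
    """Hypothetical black box; the formula is invented for this example."""
    return 3.0 * X[:, 0] - 2.0 * X[:, 1] + 0.5 * np.sin(X[:, 2])


rng = np.random.default_rng(seed=0)
X = rng.normal(size=(200, 3))          # independent variables
y_black_box = black_box_predict(X)     # dependent variable for the surrogate

# Prepend a column of ones so that the surrogate can learn an intercept.
X_design = np.column_stack([np.ones(len(X)), X])

# Fit the linear surrogate to the black box output (not to the true labels).
coefficients, *_ = np.linalg.lstsq(X_design, y_black_box, rcond=None)

# The surrogate's own prediction for the same rows.
surrogate_prediction = X_design @ coefficients
```

Note that the true labels never appear here. The surrogate learns to imitate the black box, so it explains what the model does, not what the real world does.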
After that, I can calculate the prediction error of the surrogate model, treating the predictions of the black box as the target values. The smaller the error I get, the better the surrogate model explains the black box.
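One way to put a number on this is sketched below. The helper `surrogate_fidelity` is a name I made up. It returns the root mean squared error and the R² score of the surrogate, both measured against the black box output. These two metrics are my assumption; any regression error metric could be used instead.

```python
import numpy as np


def surrogate_fidelity(black_box_prediction, surrogate_prediction):
    """Return (rmse, r2) of the surrogate measured against the black box."""
    target = np.asarray(black_box_prediction, dtype=float)
    approx = np.asarray(surrogate_prediction, dtype=float)
    residual = target - approx

    # Average size of the disagreement, in the units of the prediction.
    rmse = float(np.sqrt(np.mean(residual ** 2)))

    # Share of the black box output variance reproduced by the surrogate.
    total = float(np.sum((target - target.mean()) ** 2))
    if total > 0:
        r2 = 1.0 - float(np.sum(residual ** 2)) / total
    else:
        r2 = float("nan")  # a constant black box output has no variance
    return rmse, r2


# Toy numbers, invented for the example.
rmse, r2 = surrogate_fidelity([1.0, 2.0, 3.0, 4.0], [1.1, 1.9, 3.2, 3.8])
print(f"RMSE: {rmse:.3f}, R^2: {r2:.3f}")
```

An R² close to 1 means the surrogate reproduces almost all of the variation in the black box output. What counts as "acceptable" depends on the use case, so I would not treat any threshold as universal.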
When I get a surrogate model which has an acceptable prediction error, I can look at its parameters to understand which features are important and how the black box model works.
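For a linear surrogate, looking at the parameters can be as simple as sorting the coefficients by their absolute value. The sketch below assumes that all features are on a similar scale, because otherwise the coefficient sizes are not comparable. The feature names, the black box function, and the data are hypothetical.

```python
import numpy as np

feature_names = ["age", "income", "tenure"]  # hypothetical names


def black_box_predict(X):
    """Hypothetical black box; the formula is invented for this example."""
    return 3.0 * X[:, 0] - 2.0 * X[:, 1] + 0.5 * np.sin(X[:, 2])


rng = np.random.default_rng(seed=0)
X = rng.normal(size=(200, 3))  # features already on the same scale

# Fit the linear surrogate to the black box output, with an intercept.
X_design = np.column_stack([np.ones(len(X)), X])
coefficients, *_ = np.linalg.lstsq(X_design, black_box_predict(X), rcond=None)

# Skip the intercept and rank features by the size of their coefficient.
ranked = sorted(
    zip(feature_names, coefficients[1:]),
    key=lambda pair: abs(pair[1]),
    reverse=True,
)
for name, weight in ranked:
    print(f"{name}: {weight:+.3f}")
```

The sign tells me in which direction a feature pushes the prediction, and the magnitude tells me how strongly. This reading is only as trustworthy as the surrogate's fit to the black box.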