Using a surrogate model to interpret a machine learning model
In my opinion, training a surrogate model is the easiest method of interpreting the behavior of an existing machine learning model.
To apply this method, we are going to need:
- an existing machine learning model
- input data that can be processed by the existing model (for example, the test dataset held out when the model was trained, or a sample of real-world data from the production environment)
We don’t need to know anything about the existing model. It is just a black box. It has an input, and when we pass the data, we get an output. That is all we need.
In the first step, I am going to pass the data into the black box model and get the prediction.
            =======================
            =                     =
    data => =      black box      = => prediction
            =        model        =
            =======================
Now, I have to decide what kind of model I want to train as the surrogate model. It should be a model that I know how to interpret and explain to people who have no machine learning knowledge, for example, linear regression or decision trees.
I am going to train the surrogate model, using the independent variables from the input data as features and the prediction from the black box as the dependent variable.
    independent variables      prediction
    from input dataset         from black box model
             ||                     ||
             ||                     ||
             \/                     \/
           =======================
           =                     =
           =      surrogate      =
           =        model        =
           =======================
                     ||
                     \/
           surrogate's prediction
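The two steps above can be sketched in Python. This is only a minimal illustration under a few assumptions: the `black_box` function is a hypothetical stand-in for the existing model (in practice it would be a call to the real model's prediction method or API), the input data is synthetic, and the surrogate is an ordinary least-squares linear regression fitted with NumPy.

```python
import numpy as np

rng = np.random.default_rng(seed=0)


def black_box(X):
    """Hypothetical stand-in for the existing model; only its output is used."""
    return 3.0 * X[:, 0] - 2.0 * X[:, 1] + 0.5 * np.sin(X[:, 2])


# Input data that the black box can process (synthetic in this sketch).
X = rng.normal(size=(500, 3))

# Step 1: pass the data through the black box and keep its predictions.
black_box_prediction = black_box(X)

# Step 2: fit the surrogate on (independent variables, black-box prediction).
# A column of ones is appended so the linear model also learns an intercept.
X_design = np.column_stack([X, np.ones(len(X))])
weights, *_ = np.linalg.lstsq(X_design, black_box_prediction, rcond=None)

# The surrogate's own prediction for the same inputs.
surrogate_prediction = X_design @ weights
```

Note that the true labels never appear here: the surrogate learns to imitate the black box, not to solve the original problem.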
After that, I can calculate the prediction error of the surrogate model, treating the predictions of the black box (not the true labels) as the target values. The smaller the error I get, the better the surrogate model explains the black box.
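One possible way to put a number on that error for a regression problem is sketched below. The helper name `surrogate_fidelity` and the choice of RMSE and R² are my assumptions for the example. For a classifier, the share of inputs on which both models agree would be a more natural measure.

```python
import numpy as np


def surrogate_fidelity(black_box_prediction, surrogate_prediction):
    """Return (RMSE, R^2) of the surrogate measured against the black box output."""
    y = np.asarray(black_box_prediction, dtype=float)
    y_hat = np.asarray(surrogate_prediction, dtype=float)
    residual = y - y_hat
    rmse = float(np.sqrt(np.mean(residual ** 2)))
    total = float(np.sum((y - y.mean()) ** 2))
    # R^2 is undefined when the black box output is constant.
    r2 = 1.0 - float(np.sum(residual ** 2)) / total if total > 0 else float("nan")
    return rmse, r2


# Toy values, only to show the call.
rmse, r2 = surrogate_fidelity([1.0, 2.0, 3.0, 4.0], [1.1, 1.9, 3.2, 3.8])
print(f"RMSE = {rmse:.3f}, R^2 = {r2:.3f}")
```

An R² close to 1 means the surrogate reproduces most of the variation in the black box output. What counts as "acceptable" depends on the use case.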
When I get a surrogate model which has an acceptable prediction error, I can look at its parameters to understand which features are important and how the black box model works.
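As a rough sketch of what "looking at the parameters" can mean for a linear surrogate, the example below standardizes the features and ranks them by the absolute size of their coefficients. The feature names, the `black_box` function, and the data are all hypothetical and exist only to make the example runnable.

```python
import numpy as np

rng = np.random.default_rng(seed=1)
feature_names = ["age", "income", "tenure"]  # hypothetical feature names


def black_box(X):
    """Hypothetical stand-in for the existing model."""
    return 4.0 * X[:, 1] - 1.0 * X[:, 0] + 0.1 * X[:, 2] ** 2


X = rng.normal(size=(1000, 3))
y_black_box = black_box(X)

# Standardize the features so that coefficient magnitudes are comparable.
X_std = (X - X.mean(axis=0)) / X.std(axis=0)
X_design = np.column_stack([X_std, np.ones(len(X_std))])
weights, *_ = np.linalg.lstsq(X_design, y_black_box, rcond=None)
coefficients = weights[:-1]  # drop the intercept

# Rank the features by the absolute size of their coefficient.
ranking = sorted(zip(feature_names, coefficients), key=lambda pair: -abs(pair[1]))
for name, coef in ranking:
    print(f"{name:>8}: {coef:+.3f}")
```

In this toy setup the third feature affects the black box only through a squared term, so the linear surrogate gives it a coefficient close to zero. A small coefficient therefore shows what the surrogate could capture, which is one more reason to check the surrogate's error before trusting the ranking.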