Using a surrogate model to interpret a machine learning model
In my opinion, training a surrogate model is the easiest method of interpreting the behavior of an existing machine learning model.
To apply this method, we are going to need:
- an existing machine learning model
- input data that can be processed by the existing model (for example, the test dataset used to evaluate the model or a sample of real-world data from the production environment)
We don’t need to know anything about the existing model. It is just a black box. It has an input, and when we pass the data, we get an output. That is all we need.
In the first step, I am going to pass the data into the black box model and get its predictions.
data => [ black box model ] => prediction
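This first step could look like the following minimal sketch, assuming NumPy is available. `black_box_predict` is a hypothetical stand-in for any existing model's prediction function (we only call it, we never inspect it), and the random matrix `X` stands in for the test set or production sample.

```python
import numpy as np


def black_box_predict(X: np.ndarray) -> np.ndarray:
    """Hypothetical black box: we only call it, we never look inside."""
    return 3.0 * X[:, 0] - 2.0 * X[:, 1] + 0.5 * np.sin(X[:, 2])


# Stand-in for the input data (e.g. a test set or a production sample).
rng = np.random.default_rng(seed=42)
X = rng.normal(size=(1000, 3))

# Step 1: pass the data through the black box and keep its predictions.
y_black_box = black_box_predict(X)
print(X.shape, y_black_box.shape)  # (1000, 3) (1000,)
```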
Now, I have to decide what kind of model I want to train as the surrogate model. It should be a model that I know how to interpret and explain to people who have no machine learning knowledge, for example, linear regression or decision trees.
I am going to train the surrogate model using the independent variables from the input data as features and the predictions of the black box as the dependent variable.
independent variables from input dataset + prediction from black box model => [ surrogate model ] => surrogate's prediction
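Training the surrogate can be sketched with a linear regression, one of the two interpretable models mentioned above, fitted by ordinary least squares in NumPy. The black box function and the helper `fit_linear_surrogate` are hypothetical illustrations; note that the target is the black box output, not the true labels.

```python
import numpy as np


def black_box_predict(X):
    """Hypothetical black box standing in for an existing model."""
    return 3.0 * X[:, 0] - 2.0 * X[:, 1] + 0.5 * np.sin(X[:, 2])


def fit_linear_surrogate(X, y_black_box):
    """Fit ordinary least squares from features to black box predictions.

    Returns (intercept, coefficients).
    """
    # Prepend a column of ones so the first weight is the intercept.
    X_design = np.column_stack([np.ones(len(X)), X])
    weights, *_ = np.linalg.lstsq(X_design, y_black_box, rcond=None)
    return weights[0], weights[1:]


rng = np.random.default_rng(seed=42)
X = rng.normal(size=(1000, 3))

# The dependent variable is what the black box predicts, not the ground truth.
y_black_box = black_box_predict(X)
intercept, coefs = fit_linear_surrogate(X, y_black_box)
print(intercept, coefs)
```

A decision tree (for example, scikit-learn's `DecisionTreeRegressor`) could be swapped in here; the only requirement is that the surrogate is something you can explain.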
After that, I can calculate the prediction error of the surrogate model by comparing its predictions with the predictions of the black box. The smaller the error, the better the surrogate model explains the black box.
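Measuring how closely the surrogate mimics the black box might look like the sketch below, assuming the same hypothetical NumPy setup. The error is computed against the black box predictions (not the true labels), and part of the data is held out so the surrogate is not judged on the rows it was fitted to. The choice of RMSE and R² as error measures is an assumption; any regression metric would do.

```python
import numpy as np


def black_box_predict(X):
    """Hypothetical black box standing in for an existing model."""
    return 3.0 * X[:, 0] - 2.0 * X[:, 1] + 0.5 * np.sin(X[:, 2])


def add_intercept(X):
    return np.column_stack([np.ones(len(X)), X])


def fidelity(y_black_box, y_surrogate):
    """Return (RMSE, R^2) of the surrogate measured against the black box output."""
    residuals = y_black_box - y_surrogate
    rmse = float(np.sqrt(np.mean(residuals ** 2)))
    total = np.sum((y_black_box - y_black_box.mean()) ** 2)
    r2 = 1.0 - float(np.sum(residuals ** 2) / total)
    return rmse, r2


rng = np.random.default_rng(seed=42)
X = rng.normal(size=(1000, 3))
y_black_box = black_box_predict(X)

# Hold out rows so fidelity is not measured on the data used for fitting.
X_train, X_test = X[:800], X[800:]
y_train, y_test = y_black_box[:800], y_black_box[800:]

weights, *_ = np.linalg.lstsq(add_intercept(X_train), y_train, rcond=None)
y_surrogate = add_intercept(X_test) @ weights

rmse, r2 = fidelity(y_test, y_surrogate)
print(f"RMSE vs black box: {rmse:.3f}, R^2 vs black box: {r2:.3f}")
```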
When I get a surrogate model which has an acceptable prediction error, I can look at its parameters to understand which features are important and how the black box model works.
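For a linear surrogate, reading its parameters can be as simple as ranking the coefficients by magnitude. This is a sketch under the same hypothetical setup; the feature names are made up, and comparing raw coefficients only makes sense because all features here share the same scale (standardize them otherwise).

```python
import numpy as np


def black_box_predict(X):
    """Hypothetical black box standing in for an existing model."""
    return 3.0 * X[:, 0] - 2.0 * X[:, 1] + 0.5 * np.sin(X[:, 2])


feature_names = ["feature_a", "feature_b", "feature_c"]  # hypothetical names

rng = np.random.default_rng(seed=42)
X = rng.normal(size=(1000, 3))
y_black_box = black_box_predict(X)

# Fit the linear surrogate (intercept column + features).
X_design = np.column_stack([np.ones(len(X)), X])
weights, *_ = np.linalg.lstsq(X_design, y_black_box, rcond=None)
coefs = weights[1:]

# Larger absolute coefficient => larger influence on the black box output,
# valid here because all features are standard normal (same scale).
ranking = sorted(zip(feature_names, coefs), key=lambda item: abs(item[1]), reverse=True)
for name, coef in ranking:
    direction = "increases" if coef > 0 else "decreases"
    print(f"{name}: {coef:+.2f} (a one-unit increase {direction} the prediction)")
```

The linear surrogate reports only an approximate linear effect for `feature_c`, because the hypothetical black box uses it through a sine; this is exactly the kind of detail the fidelity check above would flag if it mattered.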
- Data/MLOps engineer by day
- DevRel/copywriter by night
- Python and data engineering trainer
- Conference speaker
- Contributed a chapter to the book "97 Things Every Data Engineer Should Know"
- Twitter: @mikulskibartosz