How to increase accuracy of a deep learning model
In this article, I am going to describe techniques for debugging deep learning models and increasing their accuracy. I learned these techniques from Andrew Ng’s Coursera course, which I highly recommend. It is awesome! ;)
Note that these hints work best when we are building a model that will be used in production. Most of them make no sense if we are participating in a Kaggle competition or building models just for the sake of building them, with no intent of using them in real life.
First, we must decide which metric we want to optimize. It is crucial to choose a single metric to optimize because otherwise we will not be able to say unambiguously which model performs better.
If we need not only high accuracy but also a short response time, we should decide which of them is the optimizing metric. This is the value we compare to choose the better model.
Additionally, we should set some “satisficing” metrics, which act as our “usefulness thresholds.” If a model does not meet those thresholds, we reject it, even if it is the best model according to the optimizing metric.
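A minimal sketch of this selection rule, assuming accuracy is the optimizing metric and response time (latency) is the only satisficing metric. The `ModelResult` class, the `select_model` helper, and all numbers are hypothetical and exist only for illustration.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class ModelResult:
    name: str
    accuracy: float    # optimizing metric: higher is better
    latency_ms: float  # satisficing metric: must not exceed a threshold


def select_model(
    results: Sequence[ModelResult], max_latency_ms: float
) -> Optional[ModelResult]:
    """Pick the most accurate model among those that meet the latency threshold."""
    # Satisficing metric: reject every model that is too slow, however accurate it is.
    acceptable = [r for r in results if r.latency_ms <= max_latency_ms]
    if not acceptable:
        return None
    # Optimizing metric: compare only accuracy among the remaining models.
    return max(acceptable, key=lambda r: r.accuracy)


# Hypothetical measurements, for illustration only.
candidates = [
    ModelResult("small", accuracy=0.90, latency_ms=20.0),
    ModelResult("medium", accuracy=0.93, latency_ms=80.0),
    ModelResult("large", accuracy=0.95, latency_ms=300.0),
]
print(select_model(candidates, max_latency_ms=100.0).name)  # medium
```

The most accurate model ("large") loses because it fails the 100 ms threshold; only the models that pass the threshold are compared by accuracy.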
Test data distribution
In the course, Andrew Ng uses the example of a model that detects cats. The available data consists of 100,000 good-quality images downloaded from websites and 10,000 images from mobile devices. The mobile images are of low quality, but the model is going to be deployed as a mobile application, so this is the quality we should expect in production.
Andrew Ng suggests putting all of the “website images” and half of the “mobile images” in the training set. The development and test datasets contain only the remaining images from mobile devices.
That suggestion sounds counter-intuitive because we are usually told to keep the same data distribution in the training and test datasets. Nevertheless, it makes sense: we want the development and test datasets to have the same distribution as the data the model will receive in production after deployment.
After all, we don’t want to artificially inflate the measured accuracy and end up with a model that cannot perform well in real life.
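The split described above can be sketched as follows. The function name and the file names are hypothetical, and dividing the held-out mobile images evenly between the dev and test sets is an assumption; the source only says both sets contain mobile images.

```python
import random
from typing import List, Sequence, Tuple


def split_for_mobile_deployment(
    web_images: Sequence[str], mobile_images: Sequence[str], seed: int = 0
) -> Tuple[List[str], List[str], List[str]]:
    """Return (train, dev, test) lists of image identifiers."""
    rng = random.Random(seed)
    mobile = list(mobile_images)
    rng.shuffle(mobile)  # avoid any ordering bias before splitting

    half = len(mobile) // 2
    dev_end = half + (len(mobile) - half) // 2  # assumption: even dev/test split

    # All web images plus half of the mobile images go to training.
    train = list(web_images) + mobile[:half]
    # Dev and test contain only mobile images, i.e. the production distribution.
    dev = mobile[half:dev_end]
    test = mobile[dev_end:]
    return train, dev, test


web = [f"web_{i}.jpg" for i in range(100_000)]
mobile = [f"mobile_{i}.jpg" for i in range(10_000)]
train, dev, test = split_for_mobile_deployment(web, mobile)
print(len(train), len(dev), len(test))  # 105000 2500 2500
```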
Types of errors (and how to fix them)
First of all, we must define the baseline error we expect the model to reach. In most tasks, it is sufficient to set it at the human error level (it is possible to train a model that performs better than humans, but in that case, we need a new baseline for the error).
The difference between the baseline error (the human error) and the training set error is called avoidable bias. High avoidable bias is usually caused by underfitting of the neural network.
To reduce it, we should consider: training a bigger model (adding more neurons to layers or adding new layers), increasing the training time, using a more sophisticated optimizer (RMSprop or Adam), adding momentum to gradient descent, or changing the neural network architecture to something that better fits the problem.
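To make “adding momentum to gradient descent” concrete, here is a minimal sketch of one common formulation of the momentum update (an exponentially weighted average of past gradients) on a toy one-dimensional function, f(w) = (w - 3)². The function name and hyperparameter values are illustrative assumptions; in a real project, you would use the optimizer built into your deep learning framework instead.

```python
from typing import Callable


def gradient_descent_with_momentum(
    grad: Callable[[float], float],
    w0: float,
    learning_rate: float = 0.1,
    beta: float = 0.9,
    steps: int = 500,
) -> float:
    """Minimize a one-dimensional function using gradient descent with momentum."""
    w = w0
    velocity = 0.0
    for _ in range(steps):
        # Exponentially weighted average of past gradients (the "momentum" term).
        velocity = beta * velocity + (1 - beta) * grad(w)
        # Step in the direction of the averaged gradient instead of the raw one.
        w -= learning_rate * velocity
    return w


# Gradient of f(w) = (w - 3)^2 is 2 * (w - 3); the minimum is at w = 3.
result = gradient_descent_with_momentum(lambda w: 2 * (w - 3.0), w0=0.0)
print(round(result, 4))  # 3.0
```

Averaging the gradients smooths out oscillations between steps, which is why momentum often lets training converge faster than plain gradient descent.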
The difference between the training set error and the development set error is called variance. It measures overfitting. To reduce it (and get a smaller development set error), we should get more data, use regularization (L2 or dropout), use data augmentation to generate more training examples, or try a different neural network architecture.
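A minimal sketch of this diagnosis, assuming errors are expressed in percent and human error is the baseline. The `diagnose` helper is hypothetical, and the rule “work on the larger gap first” is a simplifying heuristic; the suggestion lists simply restate the remedies from the two paragraphs above.

```python
from dataclasses import dataclass
from typing import List

BIAS_FIXES = [
    "train a bigger model",
    "train longer",
    "use a more sophisticated optimizer (RMSprop, Adam) or momentum",
    "try a different architecture",
]
VARIANCE_FIXES = [
    "get more data",
    "add regularization (L2, dropout)",
    "use data augmentation",
    "try a different architecture",
]


@dataclass
class Diagnosis:
    avoidable_bias: float
    variance: float
    suggestions: List[str]


def diagnose(human_error: float, train_error: float, dev_error: float) -> Diagnosis:
    """All errors are percentages; human error is used as the baseline."""
    avoidable_bias = train_error - human_error
    variance = dev_error - train_error
    # Heuristic: work on whichever gap is larger first.
    suggestions = BIAS_FIXES if avoidable_bias >= variance else VARIANCE_FIXES
    return Diagnosis(avoidable_bias, variance, suggestions)


d = diagnose(human_error=1.0, train_error=8.0, dev_error=10.0)
print(d.avoidable_bias, d.variance)  # 7.0 2.0
print(d.suggestions[0])  # train a bigger model
```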
What to do when the training set error is low, but the development set error is significantly larger?
That may be a symptom of a data mismatch between the training and development sets. To verify it, we should hold out a train-dev set: a random subset of the training set that has the same distribution as the training data but is not used for training. Then we compare the model’s error on the train-dev set with its error on the dev set.
If the train-dev set error is about the same as the dev set error, there is no data mismatch, and the gap is caused by variance. In this case, we should continue using the techniques for lowering the variance.
If the train-dev set error is significantly smaller than the dev set error, we know that the distribution of the training dataset differs from the distribution of the dev dataset. To fix the problem, we should make the datasets more similar (but remember to keep the dev distribution the same as the production data). It is recommended to perform error analysis to understand how those datasets differ.
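Both steps can be sketched in a few lines. The helper names, the 5% train-dev fraction, and the 1-percentage-point tolerance for treating a gap as noise are all assumptions chosen for illustration.

```python
import random
from typing import List, Sequence, Tuple


def carve_train_dev(
    train: Sequence, fraction: float = 0.05, seed: int = 0
) -> Tuple[List, List]:
    """Split training data into (data to train on, train-dev set).

    The train-dev set has the training distribution but is never trained on.
    """
    items = list(train)
    random.Random(seed).shuffle(items)
    n_train_dev = int(len(items) * fraction)
    return items[n_train_dev:], items[:n_train_dev]


def diagnose_dev_gap(
    train_error: float, train_dev_error: float, dev_error: float, tolerance: float = 1.0
) -> List[str]:
    """Errors are percentages; tolerance is the gap (in points) treated as noise."""
    problems = []
    # Same distribution, unseen examples: a gap here is variance.
    if train_dev_error - train_error > tolerance:
        problems.append("variance")
    # Same model, different distribution: a gap here is data mismatch.
    if dev_error - train_dev_error > tolerance:
        problems.append("data mismatch")
    return problems


fit_set, train_dev = carve_train_dev(range(1000))
print(len(fit_set), len(train_dev))  # 950 50
print(diagnose_dev_gap(1.0, 9.0, 10.0))  # ['variance']
print(diagnose_dev_gap(1.0, 1.5, 10.0))  # ['data mismatch']
```

In the first call, the train-dev error is close to the dev error, so the problem is variance; in the second, the train-dev error is close to the training error, so the dev set gap comes from a data mismatch.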
Dev set error vs. test set error
First of all, we should not overuse the test set. Evaluating performance on the test set should be the last step, and we should not continue training the model afterward, because we don’t want to overfit to the test dataset. This dataset exists to tell us the expected performance of the model in production.
Nevertheless, if there is a significant difference between the development set error and the test set error, the model has probably overfitted to the development data. The only solution to this problem seems to be getting a bigger development dataset.
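One way to enforce “evaluate on the test set only once” in code is a small guard object, sketched below together with the dev/test gap check. The `OneShotTestSet` class and `dev_test_gap` function are hypothetical helpers, not part of any library, and the error values are made up.

```python
from typing import Callable


class OneShotTestSet:
    """Allow the test set to be evaluated exactly once."""

    def __init__(self, evaluate: Callable[[], float]):
        self._evaluate = evaluate
        self._used = False

    def final_error(self) -> float:
        if self._used:
            raise RuntimeError("Test set already used; tuning on it would overfit to it.")
        self._used = True
        return self._evaluate()


def dev_test_gap(dev_error: float, test_error: float) -> float:
    """A large positive gap suggests the model has overfitted to the dev set."""
    return test_error - dev_error


# Hypothetical errors (in percent) for illustration only.
test_set = OneShotTestSet(lambda: 6.0)
print(dev_test_gap(dev_error=4.0, test_error=test_set.final_error()))  # 2.0
```

A second call to `final_error` raises an exception, which makes accidental tuning against the test set visible instead of silent.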
Excellent test set performance but poor results in production
That is the most annoying kind of error. To fix it, Andrew Ng suggests getting a larger development dataset (that is similar to the production data).
If that does not help, we probably used an incorrect cost function, which means that our model is very good at solving the wrong problem. We should find a better way of encoding the problem we want to solve (consider a new cost function, metric, or data labels) and start from scratch.
- MLOps engineer by day
- AI and data engineering consultant by night
- Python and data engineering trainer
- Conference speaker
- Contributed a chapter to the book "97 Things Every Data Engineer Should Know"
- Twitter: @mikulskibartosz
- Mastodon: @firstname.lastname@example.org