How to automatically select the hyperparameters of a ResNet neural network
In this article, I am going to show how to automatically tune the hyperparameters of a ResNet network used for multiclass image classification.
I use the keras-tuner project, which is currently in pre-alpha.
Before we start, I have to load the dataset. I'm going to use the Fashion-MNIST dataset, which is built into Keras, so loading it is straightforward.
```python
import tensorflow as tf
from tensorflow import keras

fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
```
The dataset consists of grayscale 28x28 pixel images, and it requires some simple preprocessing. First, I scale the values to the range between 0 and 1 by dividing the pixel values by 255. Second, I reshape the array to fit the input shape expected by a convolutional layer. (28, 28, 1) means 28 x 28 pixels with one color channel.
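```python
# Scale pixel values from 0-255 to 0-1
train_images = train_images / 255.0
test_images = test_images / 255.0

# Add the channel dimension expected by a convolutional layer: (samples, 28, 28, 1)
train_images = train_images.reshape(len(train_images), 28, 28, 1)
test_images = test_images.reshape(len(test_images), 28, 28, 1)
```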
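If you want to check the preprocessing without downloading the dataset, here is a small sketch that applies the same two steps to a synthetic batch of random uint8 images. The random data is only a stand-in for Fashion-MNIST. The sketch also shows that `np.expand_dims(..., axis=-1)` produces the same array as the explicit reshape.

```python
import numpy as np

# Synthetic stand-in for a batch of Fashion-MNIST images: uint8 values in 0..255
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(5, 28, 28)).astype(np.uint8)

# Same two steps as above: scale to [0, 1], then add a single channel axis
scaled = fake_images / 255.0
reshaped = scaled.reshape(len(scaled), 28, 28, 1)

# np.expand_dims adds the channel axis without repeating the image size
alt = np.expand_dims(scaled, axis=-1)

print(reshaped.dtype, reshaped.shape)  # float64 (5, 28, 28, 1)
```

Dividing a uint8 array by `255.0` yields floats, so no explicit cast is needed before scaling.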
I must also preprocess the target labels. The ResNet works with one-hot encoded labels, so I have to call the to_categorical function from Keras utils.
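```python
# Import from tensorflow.keras to avoid mixing standalone Keras with tf.keras
from tensorflow.keras.utils import to_categorical

# Convert integer class labels (0-9) into one-hot vectors
train_labels_binary = to_categorical(train_labels)
```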
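Conceptually, one-hot encoding replaces each integer label with a vector of zeros that has a single 1 at the label's index. The NumPy sketch below illustrates the idea. It is not how Keras implements `to_categorical`, and `one_hot` is a hypothetical helper name of my own.

```python
import numpy as np

def one_hot(labels, num_classes):
    # Row i of the identity matrix is the one-hot vector for class i,
    # so indexing with the label array picks one row per label.
    return np.eye(num_classes, dtype=np.float32)[labels]

labels = np.array([9, 0, 3])
encoded = one_hot(labels, num_classes=10)
print(encoded.shape)           # (3, 10)
print(encoded.argmax(axis=1))  # [9 0 3]
```

When `num_classes` is not given, `to_categorical` infers it as the largest label plus one. If you later evaluate the model on the test set, encode `test_labels` the same way.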
When all of that is done, I can finally import the Hyperband tuner and the ResNet implementation. In this example, I let the tuner change every parameter of the neural network, so I don’t specify any of them manually.
```python
from kerastuner.applications import HyperResNet
from kerastuner.tuners import Hyperband

# No hyperparameters are fixed here, so the tuner may change all of them
hypermodel = HyperResNet(input_shape=(28, 28, 1), classes=10)

tuner = Hyperband(
    hypermodel,
    objective='val_accuracy',
    max_trials=20,  # pre-alpha argument; newer keras-tuner releases use max_epochs instead
    directory='FashionMnistResNet',
    project_name='FashionMNIST')

# Hold out 10% of the training data to compute val_accuracy
tuner.search(train_images, train_labels_binary, validation_split=0.1)
```
To get a list of tunable parameters, we have to look at the source code of the HyperResNet class. In the code, we see that we can choose the version of the ResNet, the depth of the convolutional layer blocks, the learning rate, and the optimization algorithm.
If I want to override the options specified in the source code, I can pass my own HyperParameters object to the tuner. In the following code snippet, I define the acceptable learning rate values (the tuner will choose one of the given values) and specify that the Adam optimizer should always be used.
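```python
from kerastuner import HyperParameters

hp = HyperParameters()
# The tuner picks one of these two learning rates in each trial
hp.Choice('learning_rate', values=[1e-3, 1e-4])
# Always use the Adam optimizer
hp.Fixed('optimizer', value='adam')
```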
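The two definitions above leave very little to search: a Choice contributes each listed value, and a Fixed contributes exactly one. This plain-Python sketch uses a toy dictionary (not keras-tuner objects) to enumerate the combinations they allow. It does not cover the other hyperparameters that HyperResNet requests, because those depend on `tune_new_entries`.

```python
from itertools import product

# Toy description of the restricted search space defined above:
# a Choice contributes its list of values, a Fixed contributes exactly one.
search_space = {
    "learning_rate": [1e-3, 1e-4],  # hp.Choice('learning_rate', values=[1e-3, 1e-4])
    "optimizer": ["adam"],          # hp.Fixed('optimizer', value='adam')
}

names = list(search_space)
combinations = [dict(zip(names, values)) for values in product(*search_space.values())]

for combo in combinations:
    print(combo)
# {'learning_rate': 0.001, 'optimizer': 'adam'}
# {'learning_rate': 0.0001, 'optimizer': 'adam'}
```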
Additionally, I have to set tune_new_entries to False. As I understand it, this parameter makes the tuner use the tuning configuration from the HyperResNet source code whenever it needs a parameter that is not defined in my HyperParameters object.
```python
tuner = Hyperband(
    hypermodel,  # the HyperResNet instance created earlier
    objective='val_accuracy',
    hyperparameters=hp,  # my restricted search space
    tune_new_entries=False,
    max_trials=20,
    directory='FashionMnistResNet',
    project_name='FashionMNIST')
```
I may be wrong about tune_new_entries, because the documentation says it should be set to True. When I run the tuner with tune_new_entries=True, it does not even list the parameters that exist in the code but are not specified in my HyperParameters object.
On the other hand, when I set that argument to False, the tuner lists those parameters in the output, and their values change with every new trial. I reported this as a bug, because it is either an error in the code or in the documentation, which needs clarification.
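For reference, the keras-tuner documentation describes tune_new_entries like this: when it is True, hyperparameters that the hypermodel requests but that are missing from my HyperParameters object are added to the search space. When it is False, their default values are used. The toy model below encodes that documented rule in plain Python. It is my own simplification, not keras-tuner's code, and it does not necessarily match what the pre-alpha version did in my runs. The `resolve` helper, the `'version'` choices, and the `'v2'` default are assumptions for illustration.

```python
import random

def resolve(name, hypermodel_choices, hypermodel_default, user_space, tune_new_entries, rng):
    """Toy model (hypothetical) of the documented rule for one requested hyperparameter."""
    if name in user_space:
        # Entries defined in my HyperParameters object use my definition.
        return rng.choice(user_space[name])
    if tune_new_entries:
        # A new entry joins the search space with the hypermodel's own choices.
        return rng.choice(hypermodel_choices)
    # A new entry is not tuned, so the hypermodel's default value is used.
    return hypermodel_default


rng = random.Random(0)
user_space = {"learning_rate": [1e-3, 1e-4], "optimizer": ["adam"]}

# Hypothetical entry requested by the hypermodel but missing from user_space
version_choices, version_default = ["v1", "v2", "next"], "v2"

fixed = {resolve("version", version_choices, version_default, user_space, False, rng)
         for _ in range(20)}
tuned = {resolve("version", version_choices, version_default, user_space, True, rng)
         for _ in range(20)}

print(fixed)  # {'v2'}
print(tuned)  # values drawn from version_choices
```

Under the documented rule, my setup with tune_new_entries=False would keep every unlisted HyperResNet parameter at its default value. The behavior I describe above differs from this, which is why I reported it.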
Hopefully, the bug that causes Hyperband to crash has already been fixed by the time you read this. If you get a TypeError saying "'<' not supported between instances of 'NoneType' and 'float'", take a look at my previous post and copy the fix.