Save and restore a TensorFlow model using Keras for continuous model training

Sometimes, we want to stop fitting the model and get the current model weights or the best weights we have found so far. When do we do it? Usually, when fitting runs for too long and we don’t see any improvement.

Occasionally, we want to resume model training after a script failure.

On other occasions, we use cheap Amazon spot instances (or their equivalents offered by other cloud providers), and we must prepare our code to be interrupted and resumed at any time.

In all of those situations, we can use Keras checkpoints to store the intermediate state of the model and resume training later.

To use checkpoints, we first need to define the model. I am not going to do it in this example because the model structure is not relevant. Once we have called the model.compile function, we are ready to go.

So here is the last line of the model definition. We must write the model saving/restoring code after that line.

model.compile(loss='categorical_crossentropy', optimizer='adam')

Saving the Keras model into a file

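To save the model, we are going to use the Keras checkpoint feature (the ModelCheckpoint callback).
In this example, I am going to store only the best version of the model.

To decide which version should be stored, Keras observes the value of the loss function at the end of every epoch and saves the model only when the loss is lower than the best value seen so far, so the file always holds the version with the minimal loss.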

from keras.callbacks import ModelCheckpoint

filepath = "model.h5"

# Save the model after an epoch only if the training loss improved on the best value so far
checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min')
model.fit(X, Y, epochs=5, batch_size=2000, verbose=1, callbacks=[checkpoint])
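The decision the callback makes with save_best_only=True can be sketched in plain Python. This is a simplified illustration, not Keras's actual implementation: it ignores details such as mode='auto' and saving more often than once per epoch, and the function name checkpoint_decisions is hypothetical.

```python
import math


def checkpoint_decisions(epoch_values, mode="min"):
    """Return, for each epoch, whether the model would be saved.

    Simplified sketch of the save_best_only rule: save only when the
    monitored value beats the best value seen so far.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    # Start from the worst possible value (the "inf" visible in the Keras log).
    best = math.inf if mode == "min" else -math.inf
    decisions = []
    for value in epoch_values:
        improved = value < best if mode == "min" else value > best
        if improved:
            best = value
        decisions.append(improved)
    return decisions


# Losses from the first training run in the log below: every epoch improves.
print(checkpoint_decisions([3.1464, 3.0030, 2.9952, 2.9812, 2.9357]))
# → [True, True, True, True, True]

# Hypothetical losses: the third epoch is worse, so it is not saved.
print(checkpoint_decisions([3.1, 3.0, 3.2, 2.9]))
# → [True, True, False, True]
```

Because the best value only moves when the monitored value improves, a bad epoch never overwrites a better file.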

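If we want to track accuracy instead of loss, we must change both the monitor and the mode parameters. Because the model above is compiled without any metrics, we must also pass an accuracy metric to the compile function; otherwise, Keras never logs accuracy, and the checkpoint has no value to compare.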

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])
checkpoint = ModelCheckpoint(filepath, monitor='acc', verbose=1, save_best_only=True, mode='max')

Restore a Keras model from a file and continue fitting the model

Now, we can restore the model from the file. All we need is the load_model function. After loading the model, we can resume fitting it.

from keras.callbacks import ModelCheckpoint
from keras.models import load_model

filepath = "model.h5"

# load_model restores the architecture, the weights, and the optimizer state
new_model = load_model(filepath)

checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min')
new_model.fit(X, Y, epochs=5, batch_size=2000, callbacks=[checkpoint], verbose=1)
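Besides reading the logs, we can check directly that a saved and reloaded model has the same weights: it should produce exactly the same predictions as the original. A minimal sketch, assuming TensorFlow 2 with tf.keras (or Keras 3); the toy data, the architecture, and the file name roundtrip_check.keras are hypothetical.

```python
import numpy as np
from tensorflow import keras

# Hypothetical toy data and model, only to make the check self-contained.
X = np.random.rand(128, 4).astype("float32")
Y = keras.utils.to_categorical(np.random.randint(0, 3, size=128), num_classes=3)

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(3, activation="softmax"),
])
model.compile(loss="categorical_crossentropy", optimizer="adam")
model.fit(X, Y, epochs=2, batch_size=32, verbose=0)

path = "roundtrip_check.keras"  # hypothetical file name
model.save(path)
restored = keras.models.load_model(path)

# Identical weights give identical predictions on the same input.
same_predictions = np.allclose(model.predict(X, verbose=0), restored.predict(X, verbose=0))
print("Predictions match:", same_predictions)
```

The restored model also comes back compiled, with its optimizer, so it can be passed straight to fit.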

Does it really resume fitting?

Yes. Let’s look at the log output. Here are the messages logged during the first run of the fit function (before the model was restored from the file).

Epoch 1/5
74394/74394 [==============================] - 121s 2ms/step - loss: 3.1464

Epoch 00001: loss improved from inf to 3.14638, saving model to model.h5
Epoch 2/5
74394/74394 [==============================] - 115s 2ms/step - loss: 3.0030

Epoch 00002: loss improved from 3.14638 to 3.00302, saving model to model.h5
Epoch 3/5
74394/74394 [==============================] - 114s 2ms/step - loss: 2.9952

Epoch 00003: loss improved from 3.00302 to 2.99524, saving model to model.h5
Epoch 4/5
74394/74394 [==============================] - 115s 2ms/step - loss: 2.9812

Epoch 00004: loss improved from 2.99524 to 2.98121, saving model to model.h5
Epoch 5/5
74394/74394 [==============================] - 114s 2ms/step - loss: 2.9357

Epoch 00005: loss improved from 2.98121 to 2.93567, saving model to model.h5

Now, let’s look at the output of the second run (after loading the model from a file and calling the fit function again):

Epoch 1/5
74394/74394 [==============================] - 118s 2ms/step - loss: 2.8460

Epoch 00001: loss improved from inf to 2.84600, saving model to model.h5
Epoch 2/5
74394/74394 [==============================] - 115s 2ms/step - loss: 2.7889

Epoch 00002: loss improved from 2.84600 to 2.78892, saving model to model.h5
Epoch 3/5
74394/74394 [==============================] - 116s 2ms/step - loss: 2.7534

Epoch 00003: loss improved from 2.78892 to 2.75342, saving model to model.h5
Epoch 4/5
74394/74394 [==============================] - 115s 2ms/step - loss: 2.7183

Epoch 00004: loss improved from 2.75342 to 2.71827, saving model to model.h5
Epoch 5/5
74394/74394 [==============================] - 115s 2ms/step - loss: 2.6811

Epoch 00005: loss improved from 2.71827 to 2.68112, saving model to model.h5

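It did not start from scratch. The first epoch of the second run reports a loss of 2.8460, already lower than the 2.93567 reached in the last epoch of the first run, so Keras continued fitting the restored model instead of starting with fresh weights.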
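Two details in the second log deserve attention: the epoch counter starts again at 1, and the checkpoint reports an improvement "from inf". A new ModelCheckpoint callback does not know the best loss of the previous run, so after the first resumed epoch it overwrites the file regardless of the loss. A sketch of carrying both over, assuming TensorFlow 2 with tf.keras (or Keras 3): fit accepts an initial_epoch argument, and recent ModelCheckpoint versions accept initial_value_threshold (older Keras versions do not have it). The toy data, the model, and the file name resume_demo.keras are hypothetical; for brevity, the first run saves with model.save instead of a checkpoint.

```python
import numpy as np
from tensorflow import keras

# Hypothetical toy data and model, only so that the sketch runs on its own.
X = np.random.rand(256, 4).astype("float32")
Y = keras.utils.to_categorical(np.random.randint(0, 3, size=256), num_classes=3)
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(3, activation="softmax")])
model.compile(loss="categorical_crossentropy", optimizer="adam")

# First run: epochs 1-5, then save the whole model (hypothetical file name).
first = model.fit(X, Y, epochs=5, batch_size=32, verbose=0)
model.save("resume_demo.keras")

# Second run: restore the model and continue with epochs 6-10.
new_model = keras.models.load_model("resume_demo.keras")
checkpoint = keras.callbacks.ModelCheckpoint(
    "resume_demo.keras",
    monitor="loss",
    verbose=1,
    save_best_only=True,
    mode="min",
    # Start from the best loss of the first run instead of inf.
    initial_value_threshold=min(first.history["loss"]),
)
second = new_model.fit(
    X, Y,
    initial_epoch=5,  # number of epochs already completed
    epochs=10,        # index of the final epoch, not the number of additional epochs
    batch_size=32,
    verbose=1,
    callbacks=[checkpoint],
)
```

With verbose=1, the resumed run labels its epochs 6/10 to 10/10, and the checkpoint saves only when the loss beats the first run's best value.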

Doing it all in one script

When we write a single script, we must somehow distinguish the first run of the fit function from subsequent runs. This is necessary because during the first run we must define the model structure, while in later runs all we need is the load_model function.

Don’t overthink it. It is simple. Just check if the model file exists. If not, define the model and run the fit function for the first time. If the file exists, load the model from it and call the fit function again.
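A minimal sketch of the "check if the file exists" approach, assuming TensorFlow 2 with tf.keras (or Keras 3). The build_model and get_model helpers, the toy data, and the file name model.keras are hypothetical; newer Keras releases expect the .keras extension for full-model checkpoints, which is why it is used here instead of model.h5.

```python
import os

import numpy as np
from tensorflow import keras

MODEL_PATH = "model.keras"  # hypothetical file name


def build_model():
    """Hypothetical architecture; replace it with your own model definition."""
    model = keras.Sequential([
        keras.Input(shape=(4,)),
        keras.layers.Dense(16, activation="relu"),
        keras.layers.Dense(3, activation="softmax"),
    ])
    model.compile(loss="categorical_crossentropy", optimizer="adam")
    return model


def get_model(path=MODEL_PATH):
    """Return (model, resumed): a restored model if the file exists, else a new one."""
    if os.path.exists(path):
        # Later runs: the file holds the architecture, weights, and optimizer state.
        return keras.models.load_model(path), True
    # First run: define and compile the model from scratch.
    return build_model(), False


# Hypothetical toy data standing in for the real X and Y.
X = np.random.rand(256, 4).astype("float32")
Y = keras.utils.to_categorical(np.random.randint(0, 3, size=256), num_classes=3)

model, resumed = get_model()
print("Resumed from file" if resumed else "Starting from scratch")
checkpoint = keras.callbacks.ModelCheckpoint(
    MODEL_PATH, monitor="loss", verbose=1, save_best_only=True, mode="min"
)
model.fit(X, Y, epochs=5, batch_size=32, verbose=1, callbacks=[checkpoint])
```

Running the script a second time takes the load_model branch, because the checkpoint wrote the file during the first run.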

What about optimizer parameters (learning rate, momentum, etc.)?

When we use the code from the examples above, the whole model is stored. The file contains the model structure (its architecture), the model weights, the training configuration (loss and optimizer), and the optimizer configuration and state (for example, the learning rate and the momentum values).

To store only the model weights, we set the save_weights_only parameter of ModelCheckpoint to True.

checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min', save_weights_only=True)

Obviously, in that case, we can no longer use the load_model function, because the file does not contain the model architecture.

Instead, we must define the model architecture again, set the optimizer parameters, and compile the model.

Only after all of that can we call the model.load_weights function.
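The weights-only workflow can be sketched as follows, assuming TensorFlow 2 with tf.keras (or Keras 3). The build_model helper, the toy data, and the file name model.weights.h5 are hypothetical; newer Keras releases require the .weights.h5 suffix when save_weights_only=True.

```python
import numpy as np
from tensorflow import keras

WEIGHTS_PATH = "model.weights.h5"  # hypothetical file name


def build_model():
    """Hypothetical architecture; it must match the one that produced the weights."""
    model = keras.Sequential([
        keras.Input(shape=(4,)),
        keras.layers.Dense(8, activation="relu"),
        keras.layers.Dense(3, activation="softmax"),
    ])
    model.compile(loss="categorical_crossentropy", optimizer="adam")
    return model


# Hypothetical toy data standing in for the real X and Y.
X = np.random.rand(128, 4).astype("float32")
Y = keras.utils.to_categorical(np.random.randint(0, 3, size=128), num_classes=3)

# First run: store only the weights of the best epoch.
checkpoint = keras.callbacks.ModelCheckpoint(
    WEIGHTS_PATH, monitor="loss", verbose=1, save_best_only=True, mode="min",
    save_weights_only=True,
)
model = build_model()
model.fit(X, Y, epochs=3, batch_size=32, verbose=0, callbacks=[checkpoint])

# Later: define and compile the same architecture again, then load the weights.
restored = build_model()
restored.load_weights(WEIGHTS_PATH)
restored.fit(X, Y, epochs=3, batch_size=32, verbose=0)
```

The build_model function has to be kept in sync with the weights file by hand, which is the extra work that storing the whole model avoids.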

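So, if we want to resume training, we should probably stick to storing the whole model: it requires less code, and the optimizer state is restored together with the weights.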


Bartosz Mikulski * data scientist / software engineer * conference speaker * organizer of School of A.I. meetups in Poznań * co-founder of Software Craftsmanship Poznan & Poznan Scala User Group