# Save and restore a TensorFlow model using Keras for continuous model training

Sometimes we want to stop fitting the model and keep the current model weights, or the best weights found so far. When do we do it? Usually when fitting runs for too long and we don’t see any improvement.

Occasionally, we want to resume model training after a script failure.

On other occasions, we use cheap Amazon spot instances, or their equivalents offered by other cloud providers. These can be interrupted at any moment, so our code must be prepared to stop and resume at any time.

In all of those situations, we can use Keras checkpoints to store the intermediate state of the model and resume training later.

To use checkpoints, we first need to define the model. I am not going to do it in this example because the model structure is not relevant. Once we have called `model.compile`, we are ready to go.

So here is the last line of the model definition. We must write the model saving/restoring code after that line.

```python
model.compile(loss='categorical_crossentropy', optimizer='adam')
```

# Saving the Keras model into a file

To save the model, we are going to use the Keras `ModelCheckpoint` callback.

In this example, I am going to store only the best version of the model.

To decide which version should be stored, Keras checks the monitored value (here, the training loss reported at the end of each epoch) and overwrites the file only when that value improves, i.e. when the loss is lower than the lowest loss seen so far.

```python
from keras.callbacks import ModelCheckpoint

filepath = "model.h5"
# Overwrite model.h5 only when the training loss reaches a new minimum.
checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min')
# X and Y are the training data (not shown here).
model.fit(X, Y, epochs=5, batch_size=2000, verbose=1, callbacks=[checkpoint])
```
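To make the `save_best_only` decision concrete, here is a simplified, hypothetical re-implementation in plain Python. It is not the Keras source code, only a model of the comparison described above: with `mode='min'` a value counts as an improvement when it is lower than the best one seen so far, and with `mode='max'` when it is higher. Because nothing has been observed yet, the best value starts at infinity, which is where the `improved from inf` messages in the logs below come from.

```python
import math


class BestOnlyTracker:
    """Hypothetical, simplified model of ModelCheckpoint(save_best_only=True).

    Not the actual Keras implementation; it only mirrors the min/max comparison.
    """

    def __init__(self, mode="min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.mode = mode
        # Nothing has been observed yet, so any first value is an improvement.
        self.best = math.inf if mode == "min" else -math.inf

    def should_save(self, current):
        """Return True (and remember the value) if `current` beats the best so far."""
        if self.mode == "min":
            improved = current < self.best
        else:
            improved = current > self.best
        if improved:
            self.best = current
        return improved


# Hypothetical per-epoch training losses: epochs 2 and 4 are worse than the best so far.
tracker = BestOnlyTracker(mode="min")
print([tracker.should_save(loss) for loss in [3.0, 3.2, 2.9, 2.95]])
# [True, False, True, False] -> the file is written after epochs 1 and 3 only
```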

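If we want to track accuracy instead of loss, we must change both the `monitor` and `mode` parameters. The accuracy metric must also be requested in `model.compile(..., metrics=['accuracy'])`. Depending on the Keras version, it is logged as `acc` (older standalone Keras) or `accuracy` (Keras 2.3 and later, including `tf.keras`).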

```python
# Needs metrics=['accuracy'] in model.compile; newer Keras versions call it 'accuracy'.
checkpoint = ModelCheckpoint(filepath, monitor='acc', verbose=1, save_best_only=True, mode='max')
```
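Here is a minimal, self-contained sketch of accuracy monitoring. It is written against `tensorflow.keras` rather than the standalone `keras` imports used above (an assumption about your setup), and the toy data, layer sizes, and epoch count are hypothetical, chosen only so it runs. The key point is that `metrics=["accuracy"]` is passed to `compile`; without it, there is no accuracy value for the callback to monitor.

```python
import os

import numpy as np
from tensorflow import keras

# Hypothetical toy data: 100 samples, 4 features, 3 one-hot encoded classes.
rng = np.random.default_rng(0)
X = rng.random((100, 4)).astype("float32")
Y = keras.utils.to_categorical(rng.integers(0, 3, size=100), num_classes=3)

# Hypothetical architecture; only the compile call matters here.
model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(3, activation="softmax"),
])
# Without metrics=["accuracy"] there is no accuracy value to monitor.
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

filepath = "model.h5"
# Keep the model with the highest training accuracy seen so far.
checkpoint = keras.callbacks.ModelCheckpoint(
    filepath, monitor="accuracy", verbose=1, save_best_only=True, mode="max")
history = model.fit(X, Y, epochs=3, batch_size=32, verbose=0, callbacks=[checkpoint])
print("checkpoint written:", os.path.exists(filepath))
```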

# Restore a Keras model from a file and continue fitting the model

Now we can restore the model from the file. All we need is the `load_model` function. After loading the model, we can resume fitting it.

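```python
from keras.callbacks import ModelCheckpoint
from keras.models import load_model

filepath = "model.h5"
# Restores the architecture, the weights, and the optimizer state saved by the checkpoint.
new_model = load_model(filepath)
checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min')
new_model.fit(X, Y, epochs=5, batch_size=2000, callbacks=[checkpoint], verbose=1)
```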
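Below is a self-contained `tensorflow.keras` sketch of the same save-and-restore cycle. The toy data and architecture are hypothetical. It additionally passes the `initial_epoch` argument of `fit`, so the resumed run is numbered as epochs 6 to 10 instead of starting again at epoch 1, which is what happens in the second log in the next section. With a plain optimizer and no epoch-dependent callbacks (such as a learning-rate schedule), this only affects the numbering.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras

# Hypothetical toy data and model, only so the example runs end to end.
rng = np.random.default_rng(0)
X = rng.random((200, 4)).astype("float32")
Y = keras.utils.to_categorical(rng.integers(0, 3, size=200), num_classes=3)
model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(3, activation="softmax"),
])
model.compile(loss="categorical_crossentropy", optimizer="adam")

filepath = os.path.join(tempfile.mkdtemp(), "model.h5")

# First run: train for 5 epochs, keeping the best model on disk.
first_checkpoint = keras.callbacks.ModelCheckpoint(
    filepath, monitor="loss", save_best_only=True, mode="min")
model.fit(X, Y, epochs=5, batch_size=32, verbose=0, callbacks=[first_checkpoint])

# Second run (e.g. a new process): restore the model and continue.
new_model = keras.models.load_model(filepath)
checkpoint = keras.callbacks.ModelCheckpoint(
    filepath, monitor="loss", verbose=1, save_best_only=True, mode="min")
# With initial_epoch, `epochs` means the final epoch index, not "5 more epochs".
history = new_model.fit(X, Y, initial_epoch=5, epochs=10, batch_size=32,
                        verbose=0, callbacks=[checkpoint])
print(history.epoch)  # [5, 6, 7, 8, 9] (0-based), i.e. epochs 6-10 in the logs
```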

## Does it really resume fitting?

Yes. Let’s look at the log output. Here are the messages logged during the first run of the `fit` function (the run that created `model.h5`).

```
Epoch 1/5
74394/74394 [==============================] - 121s 2ms/step - loss: 3.1464
Epoch 00001: loss improved from inf to 3.14638, saving model to model.h5
Epoch 2/5
74394/74394 [==============================] - 115s 2ms/step - loss: 3.0030
Epoch 00002: loss improved from 3.14638 to 3.00302, saving model to model.h5
Epoch 3/5
74394/74394 [==============================] - 114s 2ms/step - loss: 2.9952
Epoch 00003: loss improved from 3.00302 to 2.99524, saving model to model.h5
Epoch 4/5
74394/74394 [==============================] - 115s 2ms/step - loss: 2.9812
Epoch 00004: loss improved from 2.99524 to 2.98121, saving model to model.h5
Epoch 5/5
74394/74394 [==============================] - 114s 2ms/step - loss: 2.9357
Epoch 00005: loss improved from 2.98121 to 2.93567, saving model to model.h5
```

Now, let’s look at the output of the second run (after loading the model from a file and calling the fit function again):

```
Epoch 1/5
74394/74394 [==============================] - 118s 2ms/step - loss: 2.8460
Epoch 00001: loss improved from inf to 2.84600, saving model to model.h5
Epoch 2/5
74394/74394 [==============================] - 115s 2ms/step - loss: 2.7889
Epoch 00002: loss improved from 2.84600 to 2.78892, saving model to model.h5
Epoch 3/5
74394/74394 [==============================] - 116s 2ms/step - loss: 2.7534
Epoch 00003: loss improved from 2.78892 to 2.75342, saving model to model.h5
Epoch 4/5
74394/74394 [==============================] - 115s 2ms/step - loss: 2.7183
Epoch 00004: loss improved from 2.75342 to 2.71827, saving model to model.h5
Epoch 5/5
74394/74394 [==============================] - 115s 2ms/step - loss: 2.6811
Epoch 00005: loss improved from 2.71827 to 2.68112, saving model to model.h5
```

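**It did not start from scratch. Keras continued fitting the model.** The first epoch of the second run already reports a loss of 2.8460. That is lower than the 2.9357 reached in the last epoch of the first run, and far below the 3.1464 of a fresh model’s first epoch, so training picked up from the stored weights.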
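In the second log, the first message says the loss improved *from inf*. The new `ModelCheckpoint` has no memory of the best value reached in the earlier run, so its first epoch always overwrites `model.h5`, even if that epoch happens to be worse than the stored model. One way to avoid this is the `initial_value_threshold` argument of `ModelCheckpoint`, which, to my knowledge, exists in recent `tf.keras` releases and in Keras 3. The sketch below uses hypothetical toy data. Obtaining the threshold by evaluating the restored model is an assumption of this sketch. The evaluated loss and the per-epoch training loss logged by `fit` are not computed in exactly the same way, so the comparison is only approximate.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras

# Hypothetical toy data and architecture, only so the example runs.
rng = np.random.default_rng(0)
X = rng.random((200, 4)).astype("float32")
Y = keras.utils.to_categorical(rng.integers(0, 3, size=200), num_classes=3)


def build_model():
    model = keras.Sequential([
        keras.Input(shape=(4,)),
        keras.layers.Dense(3, activation="softmax"),
    ])
    model.compile(loss="categorical_crossentropy", optimizer="adam")
    return model


filepath = os.path.join(tempfile.mkdtemp(), "model.h5")

# First run: store the best model.
build_model().fit(X, Y, epochs=3, batch_size=32, verbose=0, callbacks=[
    keras.callbacks.ModelCheckpoint(filepath, monitor="loss", save_best_only=True, mode="min")])

# Second run: measure the stored model's loss and use it as the value to beat,
# so the new callback does not start from +inf.
new_model = keras.models.load_model(filepath)
best_so_far = new_model.evaluate(X, Y, verbose=0, return_dict=True)["loss"]
checkpoint = keras.callbacks.ModelCheckpoint(
    filepath, monitor="loss", verbose=1, save_best_only=True, mode="min",
    initial_value_threshold=best_so_far)
new_model.fit(X, Y, epochs=3, batch_size=32, verbose=0, callbacks=[checkpoint])
```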

# Doing it all in one script

When we write a single script, we must somehow distinguish between the first run of the `fit` function and subsequent runs. It is necessary because during the first run we want to define the model structure, but during later runs all we need is the `load_model` function.

Don’t overthink it. It is simple. Just check if the model file exists. If not, define the model and run the fit function for the first time. If the file exists, load the model from it and call the fit function again.
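Here is a minimal sketch of that check with `tensorflow.keras`. The `build_model` and `train_or_resume` helpers, the toy data, and the layer sizes are hypothetical. Replace `build_model` with your own model definition and `compile` call.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras


def build_model():
    """Hypothetical architecture; replace with your own definition and compile call."""
    model = keras.Sequential([
        keras.Input(shape=(4,)),
        keras.layers.Dense(16, activation="relu"),
        keras.layers.Dense(3, activation="softmax"),
    ])
    model.compile(loss="categorical_crossentropy", optimizer="adam")
    return model


def train_or_resume(filepath, X, Y, epochs=5):
    """Hypothetical helper: resume from `filepath` if it exists, otherwise start fresh."""
    resumed = os.path.exists(filepath)
    # First run: define the model. Later runs: restore everything from the file.
    model = keras.models.load_model(filepath) if resumed else build_model()
    checkpoint = keras.callbacks.ModelCheckpoint(
        filepath, monitor="loss", verbose=1, save_best_only=True, mode="min")
    model.fit(X, Y, epochs=epochs, batch_size=32, verbose=0, callbacks=[checkpoint])
    return model, resumed


# Hypothetical toy data; the temporary directory keeps the demo from touching real files.
rng = np.random.default_rng(0)
X = rng.random((200, 4)).astype("float32")
Y = keras.utils.to_categorical(rng.integers(0, 3, size=200), num_classes=3)
filepath = os.path.join(tempfile.mkdtemp(), "model.h5")

_, first = train_or_resume(filepath, X, Y)   # no file yet: builds a new model
_, second = train_or_resume(filepath, X, Y)  # file exists: loads it and continues
print(first, second)  # False True
```

On a spot instance, the same script can simply be restarted after every interruption. It picks up whatever `model.h5` contains at that moment.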

# What about optimizer parameters (learning rate, momentum, etc.)?

When we use the code from the example above, the whole model is stored. The file contains the model structure (its architecture), the model weights, and the optimizer configuration and state, so training continues with the same optimizer settings.

To store only the model weights, we set the `save_weights_only` parameter of `ModelCheckpoint` to `True`.

```python
checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min', save_weights_only=True)
```

In that case, we can no longer use the `load_model` function, because the file does not contain the model architecture.

Instead, we must define the model architecture again, set the optimizer parameters, and compile the model.

Only after all of that can we call the `model.load_weights` function.
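For completeness, here is a `tensorflow.keras` sketch of the weights-only variant, with hypothetical toy data and architecture. The `.weights.h5` file name is a deliberate choice. As far as I know, Keras 3 requires that suffix when `save_weights_only=True`, while older versions also accept it.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras

# Hypothetical toy data.
rng = np.random.default_rng(0)
X = rng.random((200, 4)).astype("float32")
Y = keras.utils.to_categorical(rng.integers(0, 3, size=200), num_classes=3)


def build_model():
    """Must recreate exactly the architecture and compile settings used for training."""
    model = keras.Sequential([
        keras.Input(shape=(4,)),
        keras.layers.Dense(16, activation="relu"),
        keras.layers.Dense(3, activation="softmax"),
    ])
    model.compile(loss="categorical_crossentropy", optimizer="adam")
    return model


filepath = os.path.join(tempfile.mkdtemp(), "model.weights.h5")

# First run: only the weights of the best epoch are written to disk.
model = build_model()
checkpoint = keras.callbacks.ModelCheckpoint(
    filepath, monitor="loss", verbose=1, save_best_only=True, mode="min",
    save_weights_only=True)
model.fit(X, Y, epochs=3, batch_size=32, verbose=0, callbacks=[checkpoint])

# Later: rebuild and compile the model, then load the stored weights into it.
restored = build_model()
restored.load_weights(filepath)
restored.fit(X, Y, epochs=3, batch_size=32, verbose=0, callbacks=[
    keras.callbacks.ModelCheckpoint(filepath, monitor="loss", save_best_only=True,
                                    mode="min", save_weights_only=True)])
```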

So, we should probably stick to storing the whole model.
