Keras Load a Model and Continue Training

System Information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
    Yes

  • OS Platform and Distribution:
    CentOS Linux 7.6.1810

  • TensorFlow installed from :
    binary (pip)

  • TensorFlow version :
    2.2.0-rc3

  • Python version:
    3.6.4

  • CUDA/cuDNN version:
    CUDA 10.1 / cuDNN 7.6.0

  • GPU model and memory:
    2 x TitanX - 12GB

Describe the current behavior
I am using TensorFlow 2.2.0 on a multi-GPU system. Because I need to train large networks for several days, I save the model weights together with the optimizer state using model.save(). When I reload the model using tf.keras.models.load_model(), the loss spikes sharply on TensorBoard and the accuracy also shows a sudden drop. Although the loss recovers within the epoch, this does not match the intended behavior of saving the training state with model.save().

Describe the expected behavior
The API should be able to save and resume training from the very same point after loading a model from a '.h5' file.

Standalone code to reproduce the issue
This code is a minimal reproducible example. It was tested on multi-GPU systems with 8 GPUs. Stopping and restarting the training process is simulated within the script by deleting the current model and distribution strategy and then re-initializing them.

            # Minimal reproduction: train, checkpoint periodically, then reload and resume.
import glob
import os

import numpy as np
import tensorflow as tf

print(tf.__version__)

gpus = tf.config.experimental.list_logical_devices('GPU')
print(gpus)

RESULT_DIR = os.path.join(os.getcwd(), 'Test', 'Results')
CHECKPOINT_FREQUENCY = 16
LOG_EVERY = 1

BATCH_SIZE_PER_GPU = 16
# Clamp to at least one replica so the global batch size is never 0 on a
# machine without GPUs (presumably MirroredStrategy then runs on CPU -- see
# the tf.distribute.MirroredStrategy documentation).
NUM_GPUS = max(len(gpus), 1)
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_GPU * NUM_GPUS


def get_model():
    """Build the small CNN used for the reproduction.

    Returns:
        An uncompiled ``tf.keras.Sequential`` taking (28, 28, 1) inputs and
        ending in a 10-unit ``Dense`` layer with no activation.
    """
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(filters=32, strides=1, kernel_size=(4, 4),
                               input_shape=(28, 28, 1)),
        tf.keras.layers.Activation('relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10),
    ])
    return model


class SparseCategoricalLoss(tf.keras.losses.Loss):
    """Weighted sparse categorical cross-entropy restricted to leading columns.

    Args:
        num_classes: Number of leading columns of ``y_true`` / ``y_pred``
            that are sliced out before the loss is computed.
        name: Name given to the loss.
        from_logits: Forwarded to ``SparseCategoricalCrossentropy``.
        loss_weight: Scalar multiplier applied to the per-sample loss.
        *args, **kwargs: Forwarded to ``tf.keras.losses.Loss``.
    """

    def __init__(self, num_classes, name='SparseCategoricalLoss',
                 from_logits=False, loss_weight=1.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_classes = num_classes
        self.name = name
        self.from_logits = from_logits
        self.loss_weight = loss_weight

    def loss_fn(self, y_true, y_pred):
        """Return the unreduced, weighted cross-entropy for each sample."""
        label = y_true[:, 0:self.num_classes]
        logit = y_pred[:, 0:self.num_classes]
        loss = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=self.from_logits,
            name=self.name,
            reduction=tf.keras.losses.Reduction.NONE,
        )(label, logit)
        loss *= self.loss_weight
        return loss

    def call(self, y_true, y_pred):
        """Keras entry point; delegates to ``loss_fn``."""
        return self.loss_fn(y_true, y_pred)

    def get_config(self):
        """Return the constructor arguments needed to rebuild this loss."""
        config = super().get_config().copy()
        config.update({
            'num_classes': self.num_classes,
            'name': self.name,
            # from_logits must be serialized: a loss rebuilt from this config
            # otherwise falls back to the constructor default (False) and
            # would treat the model's logits as probabilities.
            'from_logits': self.from_logits,
            'loss_weight': self.loss_weight,
        })
        return config


loss = SparseCategoricalLoss(num_classes=10,
                             from_logits=True,
                             name='categorical_loss')

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = get_model()

    optimizer = tf.keras.optimizers.RMSprop(
        learning_rate=0.001,
        epsilon=1.0,
        momentum=0.9,
        rho=0.9,
    )

    model.compile(optimizer=optimizer, loss=loss, metrics=['acc'])

(X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.mnist.load_data()
X_train = np.expand_dims(X_train, 3)
X_test = np.expand_dims(X_test, 3)


class LoggingCallback(tf.keras.callbacks.Callback):
    """Write batch metrics to TensorBoard and save periodic '.h5' checkpoints.

    Args:
        result_dir: Root directory; 'checkpoint' and 'tensorboard'
            sub-directories are created inside it.
        log_every: Metrics are written every ``log_every`` steps.
        initial_step: Starting value of the global step counter (used when
            resuming from a checkpoint).
        checkpoint_frequency: Save the model every this many steps; a falsy
            value disables checkpointing.
        **kwargs: Forwarded to ``tf.keras.callbacks.Callback``.
    """

    def __init__(self, result_dir, log_every, initial_step=0,
                 checkpoint_frequency=None, **kwargs):
        super().__init__(**kwargs)

        # Create result directory
        self.result_dir = result_dir
        os.makedirs(result_dir, exist_ok=True)

        # Create checkpoint directory
        checkpoint_dir = os.path.join(self.result_dir, 'checkpoint')
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Create tensorboard directory. The existence test has to be on the
        # directory itself; os.path.join() returns a non-empty string and is
        # therefore always truthy.
        tensorboard_dir = os.path.join(self.result_dir, 'tensorboard')
        os.makedirs(tensorboard_dir, exist_ok=True)

        self.log_every = log_every
        self.checkpoint_frequency = checkpoint_frequency
        self.train_writer = tf.summary.create_file_writer(
            os.path.join(tensorboard_dir, 'train'))
        self.step = initial_step

    def write_metrics_tensorboard(self, logs):
        """Write every entry of ``logs`` except 'batch' and 'size' as a scalar."""
        with self.train_writer.as_default():
            for name, value in logs.items():
                if name in ['batch', 'size']:
                    continue
                tf.summary.scalar(name, value, step=self.step)

    def on_batch_end(self, batch, logs=None):
        """Advance the step counter, then log and checkpoint when due."""
        self.step += 1

        # Write metrics to tensorboard (logs may be None per the signature)
        if self.step % self.log_every == 0:
            self.write_metrics_tensorboard(logs or {})

        # Save model checkpoint (weights + optimizer state)
        if self.checkpoint_frequency and self.step % self.checkpoint_frequency == 0:
            name = 'model_step_%d.h5' % self.step
            path = os.path.join(self.result_dir, 'checkpoint', name)
            self.model.save(path)


callbacks = LoggingCallback(result_dir=RESULT_DIR, log_every=LOG_EVERY,
                            checkpoint_frequency=CHECKPOINT_FREQUENCY)

model.fit(
    x=X_train,
    y=Y_train,
    batch_size=GLOBAL_BATCH_SIZE,
    epochs=7,
    validation_data=(X_test, Y_test),
    callbacks=[callbacks],
    verbose=1,
)

# Simulate stopping the process: drop the model and the strategy.
del model
del strategy


def _checkpoint_step(path):
    """Return the integer step encoded in a 'model_step_<N>.h5' file name."""
    return int(os.path.basename(path).split('_')[2].replace('.h5', ''))


# Only match files written by LoggingCallback so that stray files in the
# directory cannot break the step parsing.
previous_checkpoints = glob.glob(
    os.path.join(RESULT_DIR, 'checkpoint', 'model_step_*.h5'))
if not previous_checkpoints:
    raise FileNotFoundError(
        'No checkpoints found in %s' % os.path.join(RESULT_DIR, 'checkpoint'))
previous_checkpoints.sort(key=_checkpoint_step)
latest_checkpoint = previous_checkpoints[-1]
print('Found Latest Checkpoint : %s' % latest_checkpoint)

initial_step = _checkpoint_step(latest_checkpoint)
print('Resuming training from step %d' % initial_step)

new_callback = LoggingCallback(result_dir=RESULT_DIR, log_every=LOG_EVERY,
                               initial_step=initial_step,
                               checkpoint_frequency=CHECKPOINT_FREQUENCY)

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = tf.keras.models.load_model(
        latest_checkpoint,
        custom_objects={'SparseCategoricalLoss': SparseCategoricalLoss},
    )

# NOTE(review): no initial_epoch is passed, so epoch numbering restarts at 0
# for the resumed run -- confirm this is intended.
model.fit(
    x=X_train,
    y=Y_train,
    batch_size=GLOBAL_BATCH_SIZE,
    epochs=10,
    validation_data=(X_test, Y_test),
    callbacks=[new_callback],
    verbose=1,
)
               

Here is a link to colab showing the output : https://colab.research.google.com/gist/suraj-maniyar/1a305d7249baee4393147cb479ea2933/restart_training.ipynb

Other info / logs

The TensorBoard entry looks like this :
tensorboard

This was a toy example using MNIST. After about 26k steps, when the training was restarted, the loss spiked, indicating that the last saved checkpoint did not save the training configuration correctly.
I am training an InceptionResNet network for several days, and the spike in the loss is very concerning when I restart the training (shown below).
tensorboard_inception

hernandezmuseltook1981.blogspot.com

Source: https://github.com/tensorflow/tensorflow/issues/40342

0 Response to "Keras Load a Model and Continue Training"

Post a Comment

Iklan Atas Artikel

Iklan Tengah Artikel 1

Iklan Tengah Artikel 2

Iklan Bawah Artikel