Module keras.distribute.worker_training_state

Training state management.

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training state management."""

import tensorflow.compat.v2 as tf

import os
from keras import backend
from keras.distribute import distributed_file_utils
from keras.utils import mode_keys

# Constant for `tf.keras.Model` attribute to store the epoch at which the most
# recently saved checkpoint was saved.
CKPT_SAVED_EPOCH = '_ckpt_saved_epoch'

CKPT_SAVED_EPOCH_UNUSED_VALUE = -1


class WorkerTrainingState(object):
  """Training state management class.

  This class provides APIs for backing up and restoring the training state.
  It allows the model and epoch information to be saved periodically and
  restored for fault tolerance, also known as preemption recovery.
  """

  def __init__(self, model, checkpoint_dir):
    self._model = model

    # The epoch at which the checkpoint is saved. Used for fault tolerance.
    # int64 is used because GPU devices only register a VarHandleOp kernel for
    # int64, not int32.
    self._ckpt_saved_epoch = tf.Variable(
        initial_value=tf.constant(
            CKPT_SAVED_EPOCH_UNUSED_VALUE, dtype=tf.int64),
        name='ckpt_saved_epoch')

    # Variable initialization.
    backend.set_value(self._ckpt_saved_epoch, CKPT_SAVED_EPOCH_UNUSED_VALUE)

    # _ckpt_saved_epoch gets tracked and is included in the checkpoint file
    # when backing up.
    checkpoint = tf.train.Checkpoint(
        model=self._model, ckpt_saved_epoch=self._ckpt_saved_epoch)

    # If this is single-worker training, checkpoint_dir is the same for
    # write_checkpoint_manager and read_checkpoint_manager.
    #
    # If this is multi-worker training, and this worker should not save
    # checkpoints, we replace the write_checkpoint_manager's checkpoint_dir
    # with a temp filepath, so it writes to a file that will be removed at the
    # end of the back_up() call. This is necessary because the
    # SyncOnReadVariable needs to be synced across all the workers in order to
    # be read, and all workers need to perform `save()`.
    # But all workers should restore from the same checkpoint_dir as passed to
    # read_checkpoint_manager.
    self.read_checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=os.path.join(checkpoint_dir, 'chief'),
        max_to_keep=1)
    write_checkpoint_dir = distributed_file_utils.write_dirpath(
        checkpoint_dir, self._model.distribute_strategy)
    if self._model.distribute_strategy.extended.should_checkpoint:
      self.write_checkpoint_manager = self.read_checkpoint_manager
    else:
      self.write_checkpoint_manager = tf.train.CheckpointManager(
          checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

  def back_up(self, epoch):
    """Back up the current state of training into a checkpoint file.

    Args:
      epoch: The current epoch information to be saved.
    """
    backend.set_value(self._ckpt_saved_epoch, epoch)
    # Save the model plus CKPT_SAVED_EPOCH variable.
    if self.write_checkpoint_manager.save():
      distributed_file_utils.remove_temp_dirpath(
          self.write_checkpoint_manager.directory,
          self._model.distribute_strategy)

  def restore(self):
    """Restore the training state from the backed up checkpoint file.

    Returns:
      True if the training state is successfully restored. False if the
      training state doesn't need to be restored, or an error occurred and it
      couldn't be.
    """
    self.read_checkpoint_manager.restore_or_initialize()

  def delete_backup(self):
    """Delete the backup directories.

    Delete the backup directories which should not exist after `fit()`
    successfully finishes.
    """
    if self.write_checkpoint_manager is self.read_checkpoint_manager:
      try:
        tf.io.gfile.rmtree(self.write_checkpoint_manager.directory)
      except tf.errors.NotFoundError:
        pass

  def maybe_load_initial_epoch_from_ckpt(self, initial_epoch, mode):
    """Maybe load initial epoch from ckpt considering possible worker recovery.

    When the `_ckpt_saved_epoch` attribute exists and is not
    `CKPT_SAVED_EPOCH_UNUSED_VALUE`, this is a multi-worker training setting
    and the worker is recovering from a previous failure. In this case, infer
    `initial_epoch` from `self._ckpt_saved_epoch` to continue the previously
    unfinished training from that epoch.

    Args:
      initial_epoch: The original initial_epoch the user passed to `fit()`.
      mode: The mode for running `model.fit()`.

    Returns:
      If the training is recovering from a previous failure under a
      multi-worker training setting, the epoch at which training should
      continue. Otherwise, the `initial_epoch` the user passed in.
    """

    epoch = backend.eval(self._ckpt_saved_epoch)
    if mode == mode_keys.ModeKeys.TRAIN and epoch >= 0:
      # The most recently saved epoch is one epoch prior to the epoch it
      # failed at, so return the value of 'self._ckpt_saved_epoch' plus one.
      return epoch + 1
    return initial_epoch

Classes

class WorkerTrainingState (model, checkpoint_dir)

Training state management class.

This class provides APIs for backing up and restoring the training state. It allows the model and epoch information to be saved periodically and restored for fault tolerance, also known as preemption recovery.
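
For orientation, the sketch below shows how this class could be driven directly. In normal use it is created and driven internally by Keras (for example by the tf.keras.callbacks.BackupAndRestore callback) rather than instantiated by user code; the toy model, the '/tmp/backup' directory, and the bare training loop are illustrative placeholders only.

import tensorflow as tf
from keras.distribute import worker_training_state
from keras.utils import mode_keys

# Any Keras model carries the distribute_strategy attribute that
# WorkerTrainingState reads (the default strategy when none is set).
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])

state = worker_training_state.WorkerTrainingState(model, '/tmp/backup')

# Restore any earlier backup and work out which epoch to resume from.
state.restore()
initial_epoch = state.maybe_load_initial_epoch_from_ckpt(
    initial_epoch=0, mode=mode_keys.ModeKeys.TRAIN)

for epoch in range(initial_epoch, 5):
  # ... run one epoch of training here ...
  state.back_up(epoch)  # checkpoint the model plus the finished epoch

# Training finished cleanly, so the fault-tolerance backup can be discarded.
state.delete_backup()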


Methods

def back_up(self, epoch)

Back up the current state of training into a checkpoint file.

Args

epoch
The current epoch information to be saved.
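
As a sketch of where this method sits in the training flow, a hypothetical callback (not part of Keras) could back the state up at the end of every epoch:

import tensorflow as tf

class EpochBackup(tf.keras.callbacks.Callback):
  """Hypothetical callback that backs up training state after each epoch."""

  def __init__(self, training_state):
    super().__init__()
    self._training_state = training_state  # a WorkerTrainingState instance

  def on_epoch_end(self, epoch, logs=None):
    # Record the finished epoch so a restarted worker can resume at epoch + 1.
    self._training_state.back_up(epoch)
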
def delete_backup(self)

Delete the backup directories.

Delete the backup directories which should not exist after fit() successfully finishes.
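
Continuing the hypothetical callback sketch shown under back_up, cleanup would typically happen once fit() has finished without interruption:

import tensorflow as tf

class BackupCleanup(tf.keras.callbacks.Callback):
  """Hypothetical callback that discards the backup after a successful run."""

  def __init__(self, training_state):
    super().__init__()
    self._training_state = training_state  # a WorkerTrainingState instance

  def on_train_end(self, logs=None):
    # Training completed, so the fault-tolerance checkpoints are obsolete.
    self._training_state.delete_backup()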

def maybe_load_initial_epoch_from_ckpt(self, initial_epoch, mode)

Maybe load initial epoch from ckpt considering possible worker recovery.

When the _ckpt_saved_epoch attribute exists and is not CKPT_SAVED_EPOCH_UNUSED_VALUE, this is a multi-worker training setting and the worker is recovering from a previous failure. In this case, infer initial_epoch from self._ckpt_saved_epoch to continue the previously unfinished training from that epoch.

Args

initial_epoch
The original initial_epoch the user passed to fit().
mode
The mode for running model.fit().

Returns

If the training is recovering from a previous failure under a multi-worker training setting, the epoch at which training should continue. Otherwise, the initial_epoch the user passed in.
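
The resume rule reduces to the following standalone sketch; the resume_epoch helper is illustrative and not part of the module.

CKPT_SAVED_EPOCH_UNUSED_VALUE = -1

def resume_epoch(ckpt_saved_epoch, initial_epoch, is_training):
  # A non-negative saved epoch means that epoch finished and was backed up,
  # so a recovering worker resumes at the next epoch.
  if is_training and ckpt_saved_epoch >= 0:
    return ckpt_saved_epoch + 1
  return initial_epoch

assert resume_epoch(3, 0, True) == 4   # recovered worker continues at epoch 4
assert resume_epoch(CKPT_SAVED_EPOCH_UNUSED_VALUE, 0, True) == 0  # fresh run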

def restore(self)

Restore the training state from the backed up checkpoint file.

Returns

True if the training state is successfully restored. False if the training state doesn't need to be restored, or an error occurred and it couldn't be.
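
In terms of the underlying TensorFlow API, restore() delegates to tf.train.CheckpointManager.restore_or_initialize() on the read manager. A bare-bones sketch of that call, with a placeholder variable standing in for the tracked model state:

import tensorflow as tf

# A stand-in for the tracked training state (model weights, saved epoch, ...).
step = tf.Variable(0, dtype=tf.int64)
checkpoint = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(
    checkpoint, directory='/tmp/demo_ckpt', max_to_keep=1)

# Restores from the latest checkpoint in the directory if one exists;
# otherwise the tracked variables simply keep their initial values.
manager.restore_or_initialize()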
