Module keras.saving.saving_utils

Utils related to keras model saving.

Source code
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils related to keras model saving."""

# pylint: disable=g-bad-import-order, g-direct-tensorflow-import
import tensorflow.compat.v2 as tf

import copy
import os
from keras import backend as K
from keras import losses
from keras import optimizer_v1
from keras import optimizers
from keras.engine import base_layer_utils
from keras.utils import generic_utils
from keras.utils import version_utils
from keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.platform import tf_logging as logging
# pylint: enable=g-bad-import-order, g-direct-tensorflow-import


def extract_model_metrics(model):
  """Convert metrics from a Keras model `compile` API to dictionary.

  This is used for converting Keras models to Estimators and SavedModels.

  Args:
    model: A `tf.keras.Model` object.

  Returns:
    Dictionary mapping metric names to metric instances. May return `None` if
    the model does not contain any metrics.
  """
  if getattr(model, '_compile_metrics', None):
    # TODO(psv/kathywu): use this implementation in model to estimator flow.
    # We are not using model.metrics here because we want to exclude the metrics
    # added using `add_metric` API.
    return {m.name: m for m in model._compile_metric_functions}  # pylint: disable=protected-access
  return None


def model_call_inputs(model, keep_original_batch_size=False):
  """Inspect model to get its input signature.

  The model's input signature is a list with a single (possibly-nested) object.
  This is due to the Keras-enforced restriction that tensor inputs must be
  passed in as the first argument.

  For example, a model with input {'feature1': <Tensor>, 'feature2': <Tensor>}
  will have input signature: [{'feature1': TensorSpec, 'feature2': TensorSpec}]

  Args:
    model: Keras Model object.
    keep_original_batch_size: A boolean indicating whether we want to keep using
      the original batch size or set it to None. Default is `False`, which means
      that the batch dim of the returned input signature will always be set to
      `None`.

  Returns:
    A tuple containing `(args, kwargs)` TensorSpecs of the model call function
    inputs.
    `kwargs` does not contain the `training` argument.
  """
  input_specs = model.save_spec(dynamic_batch=not keep_original_batch_size)
  if input_specs is None:
    return None, None
  input_specs = _enforce_names_consistency(input_specs)
  return input_specs


def raise_model_input_error(model):
  raise ValueError(
      'Model {} cannot be saved because the input shapes have not been '
      'set. Usually, input shapes are automatically determined from calling'
      ' `.fit()` or `.predict()`. To manually set the shapes, call '
      '`model.build(input_shape)`.'.format(model))


def trace_model_call(model, input_signature=None):
  """Trace the model call to create a tf.function for exporting a Keras model.

  Args:
    model: A Keras model.
    input_signature: optional, a list of tf.TensorSpec objects specifying the
      inputs to the model.

  Returns:
    A concrete tf.function wrapping the model's call function with the input
    signature set.

  Raises:
    ValueError: if input signature cannot be inferred from the model.
  """
  if input_signature is None:
    if isinstance(model.call, tf.__internal__.function.Function):
      input_signature = model.call.input_signature

  if input_signature:
    model_args = input_signature
    model_kwargs = {}
  else:
    model_args, model_kwargs = model_call_inputs(model)
    input_signature = model_args  # store

    if model_args is None:
      raise_model_input_error(model)

  @tf.function
  def _wrapped_model(*args, **kwargs):
    """A concrete tf.function that wraps the model's call function."""
    kwargs['training'] = False
    with base_layer_utils.call_context().enter(
        model, inputs=None, build_graph=False, training=False, saving=True):
      outputs = model(*args, **kwargs)

    # The output always has to be a flat dict.
    output_names = model.output_names  # Functional Model.
    if output_names is None:  # Subclassed Model.
      from keras.engine import compile_utils  # pylint: disable=g-import-not-at-top
      output_names = compile_utils.create_pseudo_output_names(outputs)
    outputs = tf.nest.flatten(outputs)
    return {name: output for name, output in zip(output_names, outputs)}

  return _wrapped_model.get_concrete_function(*model_args, **model_kwargs)


def model_metadata(model, include_optimizer=True, require_config=True):
  """Returns a dictionary containing the model metadata."""
  from keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
  from keras.optimizer_v2 import optimizer_v2  # pylint: disable=g-import-not-at-top

  model_config = {'class_name': model.__class__.__name__}
  try:
    model_config['config'] = model.get_config()
  except NotImplementedError as e:
    if require_config:
      raise e

  metadata = dict(
      keras_version=str(keras_version),
      backend=K.backend(),
      model_config=model_config)
  if model.optimizer and include_optimizer:
    if isinstance(model.optimizer, optimizer_v1.TFOptimizer):
      logging.warning(
          'TensorFlow optimizers do not '
          'make it possible to access '
          'optimizer attributes or optimizer state '
          'after instantiation. '
          'As a result, we cannot save the optimizer '
          'as part of the model save file. '
          'You will have to compile your model again after loading it. '
          'Prefer using a Keras optimizer instead '
          '(see keras.io/optimizers).')
    elif model._compile_was_called:  # pylint: disable=protected-access
      training_config = model._get_compile_args(user_metrics=False)  # pylint: disable=protected-access
      training_config.pop('optimizer', None)  # Handled separately.
      metadata['training_config'] = _serialize_nested_config(training_config)
      if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer):
        raise NotImplementedError(
            'As of now, Optimizers loaded from SavedModel cannot be saved. '
            'If you\'re calling `model.save` or `tf.keras.models.save_model`,'
            ' please set the `include_optimizer` option to `False`. For '
            '`tf.saved_model.save`, delete the optimizer from the model.')
      else:
        optimizer_config = {
            'class_name':
                generic_utils.get_registered_name(model.optimizer.__class__),
            'config':
                model.optimizer.get_config()
        }
      metadata['training_config']['optimizer_config'] = optimizer_config
  return metadata


def should_overwrite(filepath, overwrite):
  """Returns whether the filepath should be overwritten."""
  # If file exists and should not be overwritten.
  if not overwrite and os.path.isfile(filepath):
    return ask_to_proceed_with_overwrite(filepath)
  return True


def compile_args_from_training_config(training_config, custom_objects=None):
  """Return model.compile arguments from training config."""
  if custom_objects is None:
    custom_objects = {}

  with generic_utils.CustomObjectScope(custom_objects):
    optimizer_config = training_config['optimizer_config']
    optimizer = optimizers.deserialize(optimizer_config)

    # Recover losses.
    loss = None
    loss_config = training_config.get('loss', None)
    if loss_config is not None:
      loss = _deserialize_nested_config(losses.deserialize, loss_config)

    # Recover metrics.
    metrics = None
    metrics_config = training_config.get('metrics', None)
    if metrics_config is not None:
      metrics = _deserialize_nested_config(_deserialize_metric, metrics_config)

    # Recover weighted metrics.
    weighted_metrics = None
    weighted_metrics_config = training_config.get('weighted_metrics', None)
    if weighted_metrics_config is not None:
      weighted_metrics = _deserialize_nested_config(_deserialize_metric,
                                                    weighted_metrics_config)

    # `training_config` is a plain dict, so look the key up directly
    # (`hasattr` on a dict never sees its keys).
    sample_weight_mode = training_config.get('sample_weight_mode', None)
    loss_weights = training_config['loss_weights']

  return dict(
      optimizer=optimizer,
      loss=loss,
      metrics=metrics,
      weighted_metrics=weighted_metrics,
      loss_weights=loss_weights,
      sample_weight_mode=sample_weight_mode)


def _deserialize_nested_config(deserialize_fn, config):
  """Deserializes arbitrary Keras `config` using `deserialize_fn`."""

  def _is_single_object(obj):
    if isinstance(obj, dict) and 'class_name' in obj:
      return True  # Serialized Keras object.
    if isinstance(obj, str):
      return True  # Serialized function or string.
    return False

  if config is None:
    return None
  if _is_single_object(config):
    return deserialize_fn(config)
  elif isinstance(config, dict):
    return {
        k: _deserialize_nested_config(deserialize_fn, v)
        for k, v in config.items()
    }
  elif isinstance(config, (tuple, list)):
    return [_deserialize_nested_config(deserialize_fn, obj) for obj in config]

  raise ValueError('Saved configuration not understood.')


def _serialize_nested_config(config):
  """Serialized a nested structure of Keras objects."""

  def _serialize_fn(obj):
    if callable(obj):
      return generic_utils.serialize_keras_object(obj)
    return obj

  return tf.nest.map_structure(_serialize_fn, config)


def _deserialize_metric(metric_config):
  """Deserialize metrics, leaving special strings untouched."""
  from keras import metrics as metrics_module  # pylint:disable=g-import-not-at-top
  if metric_config in ['accuracy', 'acc', 'crossentropy', 'ce']:
    # Do not deserialize accuracy and cross-entropy strings as we have special
    # case handling for these in compile, based on model output shape.
    return metric_config
  return metrics_module.deserialize(metric_config)


def _enforce_names_consistency(specs):
  """Enforces that either all specs have names or none do."""

  def _has_name(spec):
    return hasattr(spec, 'name') and spec.name is not None

  def _clear_name(spec):
    spec = copy.deepcopy(spec)
    if hasattr(spec, 'name'):
      spec._name = None  # pylint:disable=protected-access
    return spec

  flat_specs = tf.nest.flatten(specs)
  name_inconsistency = (
      any(_has_name(s) for s in flat_specs) and
      not all(_has_name(s) for s in flat_specs))

  if name_inconsistency:
    specs = tf.nest.map_structure(_clear_name, specs)
  return specs


def try_build_compiled_arguments(model):
  if (not version_utils.is_v1_layer_or_model(model) and
      model.outputs is not None):
    try:
      if not model.compiled_loss.built:
        model.compiled_loss.build(model.outputs)
      if not model.compiled_metrics.built:
        model.compiled_metrics.build(model.outputs, model.outputs)
    except:  # pylint: disable=bare-except
      logging.warning(
          'Compiled the loaded model, but the compiled metrics have yet to '
          'be built. `model.compile_metrics` will be empty until you train '
          'or evaluate the model.')


def is_hdf5_filepath(filepath):
  return (filepath.endswith('.h5') or filepath.endswith('.keras') or
          filepath.endswith('.hdf5'))

Functions

def compile_args_from_training_config(training_config, custom_objects=None)

Return model.compile arguments from training config.
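
A hedged roundtrip sketch: the training_config consumed here is the one produced by model_metadata (documented below) for a compiled model; the model itself is made up for illustration.

import tensorflow as tf
from keras.saving import saving_utils

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse', metrics=['mae'])

# Serialize the compile() configuration, then rebuild the arguments and
# apply them to a fresh clone, much as load_model does when restoring a
# compiled model.
meta = saving_utils.model_metadata(model)
compile_args = saving_utils.compile_args_from_training_config(
    meta['training_config'])

clone = tf.keras.models.clone_model(model)
clone.compile(**compile_args)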

def extract_model_metrics(model)

Convert metrics from a Keras model's compile API to a dictionary.

This is used for converting Keras models to Estimators and SavedModels.

Args

model: A tf.keras.Model object.

Returns

Dictionary mapping metric names to metric instances. May return None if the model does not contain any metrics.
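
A minimal sketch. The private attributes this helper reads are populated by the legacy (v1) training code path, so for a plain TF2 model the result may simply be None:

import tensorflow as tf
from keras.saving import saving_utils

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse', metrics=['mae'])

metrics = saving_utils.extract_model_metrics(model)
print(metrics)  # dict mapping metric name -> metric object, or None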

def is_hdf5_filepath(filepath)

Returns whether filepath ends with an '.h5', '.hdf5', or '.keras' extension.

def model_call_inputs(model, keep_original_batch_size=False)

Inspect model to get its input signature.

The model's input signature is a list with a single (possibly-nested) object. This is due to the Keras-enforced restriction that tensor inputs must be passed in as the first argument.

For example, a model with input {'feature1': <Tensor>, 'feature2': <Tensor>} will have input signature: [{'feature1': TensorSpec, 'feature2': TensorSpec}]

Args

model: Keras Model object.
keep_original_batch_size: A boolean indicating whether we want to keep using the original batch size or set it to None. Default is False, which means that the batch dim of the returned input signature will always be set to None.

Returns

A tuple containing (args, kwargs) TensorSpecs of the model call function inputs. kwargs does not contain the training argument.
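
A sketch of the nesting described above, assuming a Keras version where Model.save_spec (which this function relies on) is available:

import tensorflow as tf
from keras.saving import saving_utils

inputs = {
    'feature1': tf.keras.Input(shape=(3,), name='feature1'),
    'feature2': tf.keras.Input(shape=(5,), name='feature2'),
}
outputs = tf.keras.layers.Concatenate()(
    [inputs['feature1'], inputs['feature2']])
model = tf.keras.Model(inputs, outputs)

args, kwargs = saving_utils.model_call_inputs(model)
# args is a single-element list holding the nested structure, e.g.
# [{'feature1': TensorSpec(shape=(None, 3), ...), 'feature2': ...}];
# kwargs is {} here.
print(args, kwargs)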

def model_metadata(model, include_optimizer=True, require_config=True)

Returns a dictionary containing the model metadata.
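
A short sketch of the returned structure; exact contents vary with the model, but the keys below come straight from the source above:

import tensorflow as tf
from keras.saving import saving_utils

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='adam', loss='mse')

meta = saving_utils.model_metadata(model)
# ['backend', 'keras_version', 'model_config', 'training_config']
print(sorted(meta))
print(meta['training_config']['optimizer_config']['class_name'])  # 'Adam'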

def raise_model_input_error(model)

Raises a ValueError explaining that the model cannot be saved because its input shapes have not been set, and that they can be set manually via model.build(input_shape).

def should_overwrite(filepath, overwrite)

Returns whether the filepath should be overwritten.
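
A small usage sketch with a hypothetical path; the console prompt only appears when the file already exists and overwrite is False:

import tensorflow as tf
from keras.saving import saving_utils

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])
path = '/tmp/demo_model.h5'  # illustrative location

if saving_utils.should_overwrite(path, overwrite=False):
  model.save(path)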

def trace_model_call(model, input_signature=None)

Trace the model call to create a tf.function for exporting a Keras model.

Args

model: A Keras model.
input_signature: optional, a list of tf.TensorSpec objects specifying the inputs to the model.

Returns

A concrete tf.function wrapping the model's call function with the input signature set.

Raises

ValueError: if input signature cannot be inferred from the model.
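
A hedged sketch of exporting a serving function; the output name and shapes below are illustrative:

import tensorflow as tf
from keras.saving import saving_utils

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(2, input_shape=(8,), name='dense')])

serving_fn = saving_utils.trace_model_call(model)  # a ConcreteFunction
print(serving_fn.structured_outputs)  # flat dict keyed by output name
print(serving_fn(tf.zeros([1, 8])))   # {'dense': <tf.Tensor shape=(1, 2)>}
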
def try_build_compiled_arguments(model)

Attempts to build the model's compiled loss and metrics containers from model.outputs, logging a warning if they cannot be built yet.