Module keras.saving.saving_utils
Utils related to keras model saving.
Expand source code
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils related to keras model saving."""
# pylint: disable=g-bad-import-order, g-direct-tensorflow-import
import tensorflow.compat.v2 as tf
import copy
import os
from keras import backend as K
from keras import losses
from keras import optimizer_v1
from keras import optimizers
from keras.engine import base_layer_utils
from keras.utils import generic_utils
from keras.utils import version_utils
from keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.platform import tf_logging as logging
# pylint: enable=g-bad-import-order, g-direct-tensorflow-import
def extract_model_metrics(model):
  """Collects the metrics given to the Keras `compile` API into a dictionary.

  This is used for converting Keras models to Estimators and SavedModels.
  `model.metrics` is deliberately not consulted, because metrics added with
  the `add_metric` API must be excluded.

  Args:
    model: A `tf.keras.Model` object.

  Returns:
    A dict mapping each compiled metric's `name` to the metric object, or
    `None` when the model's `_compile_metrics` attribute is missing or empty.
  """
  # TODO(psv/kathywu): use this implementation in model to estimator flow.
  if not getattr(model, '_compile_metrics', None):
    return None
  compiled = model._compile_metric_functions  # pylint: disable=protected-access
  return {metric.name: metric for metric in compiled}
def model_call_inputs(model, keep_original_batch_size=False):
  """Inspect model to get its input signature.

  Keras requires tensor inputs to be passed in as the first argument, so the
  model's input signature is a list holding a single (possibly nested)
  object. For example, a model with input
  {'feature1': <Tensor>, 'feature2': <Tensor>} will have input signature:
  [{'feature1': TensorSpec, 'feature2': TensorSpec}]

  Args:
    model: Keras Model object.
    keep_original_batch_size: Whether to keep the original batch size in the
      returned specs. Defaults to `False`, in which case the batch dimension
      of the returned input signature is always `None`.

  Returns:
    A tuple containing `(args, kwargs)` TensorSpecs of the model call function
    inputs, with spec names made all-or-nothing, or `(None, None)` when
    `model.save_spec` returns `None`. `kwargs` does not contain the `training`
    argument.
  """
  saved_specs = model.save_spec(dynamic_batch=not keep_original_batch_size)
  if saved_specs is not None:
    return _enforce_names_consistency(saved_specs)
  return None, None
def raise_model_input_error(model):
  """Raises a `ValueError` stating that `model` has no known input shapes.

  Args:
    model: The model that could not be saved; it is interpolated into the
      error message via `str.format`.

  Raises:
    ValueError: always.
  """
  message = (
      'Model {} cannot be saved because the input shapes have not been '
      'set. Usually, input shapes are automatically determined from calling'
      ' `.fit()` or `.predict()`. To manually set the shapes, call '
      '`model.build(input_shape)`.')
  raise ValueError(message.format(model))
def trace_model_call(model, input_signature=None):
  """Trace the model call to create a concrete function for exporting a model.

  Args:
    model: A Keras model.
    input_signature: optional, a list of tf.TensorSpec objects specifying the
      inputs to the model. When `None`, the signature of `model.call` is used
      if `model.call` is a `tf.function`; if that is still empty, the specs
      come from `model_call_inputs(model)`.

  Returns:
    A concrete function obtained from a `tf.function` that wraps the model's
    call function with the input signature set. Its outputs are a flat dict
    mapping output names to tensors.

  Raises:
    ValueError: if input signature cannot be inferred from the model.
  """
  if input_signature is None:
    if isinstance(model.call, tf.__internal__.function.Function):
      input_signature = model.call.input_signature

  if input_signature:
    model_args = input_signature
    model_kwargs = {}
  else:
    model_args, model_kwargs = model_call_inputs(model)
    if model_args is None:
      raise_model_input_error(model)

  @tf.function
  def _wrapped_model(*args, **kwargs):
    """A concrete tf.function that wraps the model's call function."""
    # Exported signatures always run the model in inference mode.
    kwargs['training'] = False
    with base_layer_utils.call_context().enter(
        model, inputs=None, build_graph=False, training=False, saving=True):
      outputs = model(*args, **kwargs)

    # Outputs always has to be a flat dict.
    output_names = model.output_names  # Functional Model.
    if output_names is None:  # Subclassed Model.
      from keras.engine import compile_utils  # pylint: disable=g-import-not-at-top
      output_names = compile_utils.create_pseudo_output_names(outputs)
    outputs = tf.nest.flatten(outputs)
    return {name: output for name, output in zip(output_names, outputs)}

  return _wrapped_model.get_concrete_function(*model_args, **model_kwargs)
def model_metadata(model, include_optimizer=True, require_config=True):
  """Builds the metadata dictionary stored alongside a saved Keras model.

  Args:
    model: The Keras model being saved.
    include_optimizer: Whether to record the training config (compile
      arguments plus optimizer config) when the model has an optimizer.
    require_config: If True, a `NotImplementedError` raised by
      `model.get_config()` propagates; otherwise the config is omitted.

  Returns:
    A dict with `keras_version`, `backend` and `model_config` entries. A
    `training_config` entry is added when `include_optimizer` is set, the
    model has a non-`TFOptimizer` optimizer, and `compile` was called.

  Raises:
    NotImplementedError: if `model.get_config()` is not implemented and
      `require_config` is True, or if the optimizer is a
      `RestoredOptimizer` (loaded from a SavedModel).
  """
  from keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
  from keras.optimizer_v2 import optimizer_v2  # pylint: disable=g-import-not-at-top

  model_config = {'class_name': model.__class__.__name__}
  try:
    model_config['config'] = model.get_config()
  except NotImplementedError as err:
    if require_config:
      raise err

  metadata = {
      'keras_version': str(keras_version),
      'backend': K.backend(),
      'model_config': model_config,
  }

  if not (model.optimizer and include_optimizer):
    return metadata

  if isinstance(model.optimizer, optimizer_v1.TFOptimizer):
    logging.warning(
        'TensorFlow optimizers do not '
        'make it possible to access '
        'optimizer attributes or optimizer state '
        'after instantiation. '
        'As a result, we cannot save the optimizer '
        'as part of the model save file. '
        'You will have to compile your model again after loading it. '
        'Prefer using a Keras optimizer instead '
        '(see keras.io/optimizers).')
    return metadata

  if not model._compile_was_called:  # pylint: disable=protected-access
    return metadata

  training_config = model._get_compile_args(user_metrics=False)  # pylint: disable=protected-access
  # The optimizer is serialized on its own below.
  training_config.pop('optimizer', None)
  metadata['training_config'] = _serialize_nested_config(training_config)

  if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer):
    raise NotImplementedError(
        'As of now, Optimizers loaded from SavedModel cannot be saved. '
        'If you\'re calling `model.save` or `tf.keras.models.save_model`,'
        ' please set the `include_optimizer` option to `False`. For '
        '`tf.saved_model.save`, delete the optimizer from the model.')

  metadata['training_config']['optimizer_config'] = {
      'class_name':
          generic_utils.get_registered_name(model.optimizer.__class__),
      'config':
          model.optimizer.get_config(),
  }
  return metadata
def should_overwrite(filepath, overwrite):
  """Decides whether writing to `filepath` may proceed.

  Args:
    filepath: Destination path.
    overwrite: If truthy, writing is always allowed.

  Returns:
    True when `overwrite` is set or no regular file exists at `filepath`;
    otherwise the result of `ask_to_proceed_with_overwrite(filepath)`.
  """
  if overwrite or not os.path.isfile(filepath):
    return True
  # An existing file would be clobbered without permission: ask first.
  return ask_to_proceed_with_overwrite(filepath)
def compile_args_from_training_config(training_config, custom_objects=None):
  """Return model.compile arguments from training config.

  Args:
    training_config: The `training_config` dict recorded when the model was
      saved. `optimizer_config` and `loss_weights` entries are required;
      `loss`, `metrics`, `weighted_metrics` and `sample_weight_mode` are
      optional.
    custom_objects: Optional dict of custom objects made available while
      deserializing.

  Returns:
    A dict with `optimizer`, `loss`, `metrics`, `weighted_metrics`,
    `loss_weights` and `sample_weight_mode` keys, suitable for
    `model.compile(**...)`.

  Raises:
    KeyError: if `optimizer_config` or `loss_weights` is missing.
  """
  if custom_objects is None:
    custom_objects = {}
  with generic_utils.CustomObjectScope(custom_objects):
    optimizer_config = training_config['optimizer_config']
    optimizer = optimizers.deserialize(optimizer_config)

    # Recover losses.
    loss = None
    loss_config = training_config.get('loss', None)
    if loss_config is not None:
      loss = _deserialize_nested_config(losses.deserialize, loss_config)

    # Recover metrics.
    metrics = None
    metrics_config = training_config.get('metrics', None)
    if metrics_config is not None:
      metrics = _deserialize_nested_config(_deserialize_metric, metrics_config)

    # Recover weighted metrics.
    weighted_metrics = None
    weighted_metrics_config = training_config.get('weighted_metrics', None)
    if weighted_metrics_config is not None:
      weighted_metrics = _deserialize_nested_config(_deserialize_metric,
                                                    weighted_metrics_config)

    # `training_config` is a plain dict, so look the key up directly:
    # `hasattr` inspects attributes and never sees dict keys, which made a
    # saved `sample_weight_mode` get silently dropped.
    sample_weight_mode = training_config.get('sample_weight_mode', None)
    loss_weights = training_config['loss_weights']

  return dict(
      optimizer=optimizer,
      loss=loss,
      metrics=metrics,
      weighted_metrics=weighted_metrics,
      loss_weights=loss_weights,
      sample_weight_mode=sample_weight_mode)
def _deserialize_nested_config(deserialize_fn, config):
"""Deserializes arbitrary Keras `config` using `deserialize_fn`."""
def _is_single_object(obj):
if isinstance(obj, dict) and 'class_name' in obj:
return True # Serialized Keras object.
if isinstance(obj, str):
return True # Serialized function or string.
return False
if config is None:
return None
if _is_single_object(config):
return deserialize_fn(config)
elif isinstance(config, dict):
return {
k: _deserialize_nested_config(deserialize_fn, v)
for k, v in config.items()
}
elif isinstance(config, (tuple, list)):
return [_deserialize_nested_config(deserialize_fn, obj) for obj in config]
raise ValueError('Saved configuration not understood.')
def _serialize_nested_config(config):
  """Serializes a nested structure of Keras objects.

  Every callable leaf is converted with
  `generic_utils.serialize_keras_object`; all other leaves pass through
  unchanged. The nesting itself is walked with `tf.nest.map_structure`.
  """
  return tf.nest.map_structure(
      lambda leaf: (generic_utils.serialize_keras_object(leaf)
                    if callable(leaf) else leaf),
      config)
def _deserialize_metric(metric_config):
  """Deserializes a metric config, passing special metric strings through.

  'accuracy', 'acc', 'crossentropy' and 'ce' are returned as-is because
  `compile` has special-case handling for them, based on model output shape.
  Anything else goes through `keras.metrics.deserialize`.
  """
  from keras import metrics as metrics_module  # pylint:disable=g-import-not-at-top
  # A list (not a set) so unhashable configs such as dicts can be tested.
  passthrough = ['accuracy', 'acc', 'crossentropy', 'ce']
  if metric_config not in passthrough:
    return metrics_module.deserialize(metric_config)
  return metric_config
def _enforce_names_consistency(specs):
  """Enforces that either all specs have names or none do.

  If some, but not all, of the flattened specs have a non-None `name`, every
  leaf is replaced by a deep copy, and copies that have a `name` attribute
  get `_name` set to None. Otherwise `specs` is returned unchanged.
  """
  def _named(spec):
    return hasattr(spec, 'name') and spec.name is not None

  def _without_name(spec):
    unnamed = copy.deepcopy(spec)
    if hasattr(unnamed, 'name'):
      unnamed._name = None  # pylint:disable=protected-access
    return unnamed

  leaves = tf.nest.flatten(specs)
  some_named = any(_named(leaf) for leaf in leaves)
  if some_named and not all(_named(leaf) for leaf in leaves):
    return tf.nest.map_structure(_without_name, specs)
  return specs
def try_build_compiled_arguments(model):
  """Best-effort build of a compiled model's loss and metric containers.

  Runs only for non-v1 models whose `outputs` are known. `compiled_loss` is
  built from `model.outputs`, and `compiled_metrics` is built with
  `model.outputs` passed for both of its arguments. Containers that are
  already built are skipped. A failure is logged as a warning instead of
  being raised. Per the warning text, the metrics then stay empty until the
  model is trained or evaluated.

  Args:
    model: A compiled Keras model.
  """
  if (not version_utils.is_v1_layer_or_model(model) and
      model.outputs is not None):
    try:
      if not model.compiled_loss.built:
        model.compiled_loss.build(model.outputs)
      if not model.compiled_metrics.built:
        model.compiled_metrics.build(model.outputs, model.outputs)
    except Exception:  # pylint: disable=broad-except
      # Building here is optional, so an ordinary failure is only logged.
      # Unlike a bare `except:`, this lets KeyboardInterrupt and SystemExit
      # propagate.
      logging.warning(
          'Compiled the loaded model, but the compiled metrics have yet to '
          'be built. `model.compile_metrics` will be empty until you train '
          'or evaluate the model.')
def is_hdf5_filepath(filepath):
  """Returns whether `filepath` has an HDF5 extension.

  The recognized extensions are `.h5`, `.keras` and `.hdf5`.

  Args:
    filepath: A string, or an `os.PathLike` object such as `pathlib.Path`.
      Path-like objects are converted with `os.fspath` first.

  Returns:
    True if the path ends with one of the recognized extensions.
  """
  if isinstance(filepath, os.PathLike):
    filepath = os.fspath(filepath)
  return filepath.endswith(('.h5', '.keras', '.hdf5'))
Functions
def compile_args_from_training_config(training_config, custom_objects=None)
-
Return model.compile arguments from training config.
Expand source code
def compile_args_from_training_config(training_config, custom_objects=None):
  """Return model.compile arguments from training config.

  Args:
    training_config: The `training_config` dict recorded when the model was
      saved. `optimizer_config` and `loss_weights` entries are required;
      `loss`, `metrics`, `weighted_metrics` and `sample_weight_mode` are
      optional.
    custom_objects: Optional dict of custom objects made available while
      deserializing.

  Returns:
    A dict with `optimizer`, `loss`, `metrics`, `weighted_metrics`,
    `loss_weights` and `sample_weight_mode` keys, suitable for
    `model.compile(**...)`.

  Raises:
    KeyError: if `optimizer_config` or `loss_weights` is missing.
  """
  if custom_objects is None:
    custom_objects = {}
  with generic_utils.CustomObjectScope(custom_objects):
    optimizer_config = training_config['optimizer_config']
    optimizer = optimizers.deserialize(optimizer_config)

    # Recover losses.
    loss = None
    loss_config = training_config.get('loss', None)
    if loss_config is not None:
      loss = _deserialize_nested_config(losses.deserialize, loss_config)

    # Recover metrics.
    metrics = None
    metrics_config = training_config.get('metrics', None)
    if metrics_config is not None:
      metrics = _deserialize_nested_config(_deserialize_metric, metrics_config)

    # Recover weighted metrics.
    weighted_metrics = None
    weighted_metrics_config = training_config.get('weighted_metrics', None)
    if weighted_metrics_config is not None:
      weighted_metrics = _deserialize_nested_config(_deserialize_metric,
                                                    weighted_metrics_config)

    # `training_config` is a plain dict, so look the key up directly:
    # `hasattr` inspects attributes and never sees dict keys, which made a
    # saved `sample_weight_mode` get silently dropped.
    sample_weight_mode = training_config.get('sample_weight_mode', None)
    loss_weights = training_config['loss_weights']

  return dict(
      optimizer=optimizer,
      loss=loss,
      metrics=metrics,
      weighted_metrics=weighted_metrics,
      loss_weights=loss_weights,
      sample_weight_mode=sample_weight_mode)
def extract_model_metrics(model)
-
Convert metrics from a Keras model `compile` API to a dictionary.
This is used for converting Keras models to Estimators and SavedModels.
Args
model
- A `tf.keras.Model` object.
Returns
Dictionary mapping metric names to metric instances. May return `None` if the model does not contain any metrics.
Expand source code
def extract_model_metrics(model):
  """Collects the metrics given to the Keras `compile` API into a dictionary.

  This is used for converting Keras models to Estimators and SavedModels.
  `model.metrics` is deliberately not consulted, because metrics added with
  the `add_metric` API must be excluded.

  Args:
    model: A `tf.keras.Model` object.

  Returns:
    A dict mapping each compiled metric's `name` to the metric object, or
    `None` when the model's `_compile_metrics` attribute is missing or empty.
  """
  # TODO(psv/kathywu): use this implementation in model to estimator flow.
  if not getattr(model, '_compile_metrics', None):
    return None
  compiled = model._compile_metric_functions  # pylint: disable=protected-access
  return {metric.name: metric for metric in compiled}
def is_hdf5_filepath(filepath)
-
Expand source code
def is_hdf5_filepath(filepath):
  """Returns whether `filepath` has an HDF5 extension.

  The recognized extensions are `.h5`, `.keras` and `.hdf5`.

  Args:
    filepath: A string, or an `os.PathLike` object such as `pathlib.Path`.
      Path-like objects are converted with `os.fspath` first.

  Returns:
    True if the path ends with one of the recognized extensions.
  """
  if isinstance(filepath, os.PathLike):
    filepath = os.fspath(filepath)
  return filepath.endswith(('.h5', '.keras', '.hdf5'))
def model_call_inputs(model, keep_original_batch_size=False)
-
Inspect model to get its input signature.
The model's input signature is a list with a single (possibly-nested) object. This is due to the Keras-enforced restriction that tensor inputs must be passed in as the first argument.
For example, a model with input {'feature1': <Tensor>, 'feature2': <Tensor>} will have input signature: [{'feature1': TensorSpec, 'feature2': TensorSpec}]
Args
model
- Keras Model object.
keep_original_batch_size
- A boolean indicating whether we want to keep using the original batch size or set it to `None`. Default is `False`, which means that the batch dim of the returned input signature will always be set to `None`.
Returns
A tuple containing `(args, kwargs)` TensorSpecs of the model call function inputs. `kwargs` does not contain the `training` argument.
Expand source code
def model_call_inputs(model, keep_original_batch_size=False):
  """Inspect model to get its input signature.

  Keras requires tensor inputs to be passed in as the first argument, so the
  model's input signature is a list holding a single (possibly nested)
  object. For example, a model with input
  {'feature1': <Tensor>, 'feature2': <Tensor>} will have input signature:
  [{'feature1': TensorSpec, 'feature2': TensorSpec}]

  Args:
    model: Keras Model object.
    keep_original_batch_size: Whether to keep the original batch size in the
      returned specs. Defaults to `False`, in which case the batch dimension
      of the returned input signature is always `None`.

  Returns:
    A tuple containing `(args, kwargs)` TensorSpecs of the model call function
    inputs, with spec names made all-or-nothing, or `(None, None)` when
    `model.save_spec` returns `None`. `kwargs` does not contain the `training`
    argument.
  """
  saved_specs = model.save_spec(dynamic_batch=not keep_original_batch_size)
  if saved_specs is not None:
    return _enforce_names_consistency(saved_specs)
  return None, None
def model_metadata(model, include_optimizer=True, require_config=True)
-
Returns a dictionary containing the model metadata.
Expand source code
def model_metadata(model, include_optimizer=True, require_config=True):
  """Builds the metadata dictionary stored alongside a saved Keras model.

  Args:
    model: The Keras model being saved.
    include_optimizer: Whether to record the training config (compile
      arguments plus optimizer config) when the model has an optimizer.
    require_config: If True, a `NotImplementedError` raised by
      `model.get_config()` propagates; otherwise the config is omitted.

  Returns:
    A dict with `keras_version`, `backend` and `model_config` entries. A
    `training_config` entry is added when `include_optimizer` is set, the
    model has a non-`TFOptimizer` optimizer, and `compile` was called.

  Raises:
    NotImplementedError: if `model.get_config()` is not implemented and
      `require_config` is True, or if the optimizer is a
      `RestoredOptimizer` (loaded from a SavedModel).
  """
  from keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
  from keras.optimizer_v2 import optimizer_v2  # pylint: disable=g-import-not-at-top

  model_config = {'class_name': model.__class__.__name__}
  try:
    model_config['config'] = model.get_config()
  except NotImplementedError as err:
    if require_config:
      raise err

  metadata = {
      'keras_version': str(keras_version),
      'backend': K.backend(),
      'model_config': model_config,
  }

  if not (model.optimizer and include_optimizer):
    return metadata

  if isinstance(model.optimizer, optimizer_v1.TFOptimizer):
    logging.warning(
        'TensorFlow optimizers do not '
        'make it possible to access '
        'optimizer attributes or optimizer state '
        'after instantiation. '
        'As a result, we cannot save the optimizer '
        'as part of the model save file. '
        'You will have to compile your model again after loading it. '
        'Prefer using a Keras optimizer instead '
        '(see keras.io/optimizers).')
    return metadata

  if not model._compile_was_called:  # pylint: disable=protected-access
    return metadata

  training_config = model._get_compile_args(user_metrics=False)  # pylint: disable=protected-access
  # The optimizer is serialized on its own below.
  training_config.pop('optimizer', None)
  metadata['training_config'] = _serialize_nested_config(training_config)

  if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer):
    raise NotImplementedError(
        'As of now, Optimizers loaded from SavedModel cannot be saved. '
        'If you\'re calling `model.save` or `tf.keras.models.save_model`,'
        ' please set the `include_optimizer` option to `False`. For '
        '`tf.saved_model.save`, delete the optimizer from the model.')

  metadata['training_config']['optimizer_config'] = {
      'class_name':
          generic_utils.get_registered_name(model.optimizer.__class__),
      'config':
          model.optimizer.get_config(),
  }
  return metadata
def raise_model_input_error(model)
-
Expand source code
def raise_model_input_error(model):
  """Raises a `ValueError` stating that `model` has no known input shapes.

  Args:
    model: The model that could not be saved; it is interpolated into the
      error message via `str.format`.

  Raises:
    ValueError: always.
  """
  message = (
      'Model {} cannot be saved because the input shapes have not been '
      'set. Usually, input shapes are automatically determined from calling'
      ' `.fit()` or `.predict()`. To manually set the shapes, call '
      '`model.build(input_shape)`.')
  raise ValueError(message.format(model))
def should_overwrite(filepath, overwrite)
-
Returns whether the filepath should be overwritten.
Expand source code
def should_overwrite(filepath, overwrite):
  """Decides whether writing to `filepath` may proceed.

  Args:
    filepath: Destination path.
    overwrite: If truthy, writing is always allowed.

  Returns:
    True when `overwrite` is set or no regular file exists at `filepath`;
    otherwise the result of `ask_to_proceed_with_overwrite(filepath)`.
  """
  if overwrite or not os.path.isfile(filepath):
    return True
  # An existing file would be clobbered without permission: ask first.
  return ask_to_proceed_with_overwrite(filepath)
def trace_model_call(model, input_signature=None)
-
Trace the model call to create a tf.function for exporting a Keras model.
Args
model
- A Keras model.
input_signature
- optional, a list of tf.TensorSpec objects specifying the inputs to the model.
Returns
A concrete function, traced from a tf.function wrapping the model's call function with input signatures set.
Raises
ValueError
- if input signature cannot be inferred from the model.
Expand source code
def trace_model_call(model, input_signature=None):
  """Trace the model call to create a concrete function for exporting a model.

  Args:
    model: A Keras model.
    input_signature: optional, a list of tf.TensorSpec objects specifying the
      inputs to the model. When `None`, the signature of `model.call` is used
      if `model.call` is a `tf.function`; if that is still empty, the specs
      come from `model_call_inputs(model)`.

  Returns:
    A concrete function obtained from a `tf.function` that wraps the model's
    call function with the input signature set. Its outputs are a flat dict
    mapping output names to tensors.

  Raises:
    ValueError: if input signature cannot be inferred from the model.
  """
  if input_signature is None:
    if isinstance(model.call, tf.__internal__.function.Function):
      input_signature = model.call.input_signature

  if input_signature:
    model_args = input_signature
    model_kwargs = {}
  else:
    model_args, model_kwargs = model_call_inputs(model)
    if model_args is None:
      raise_model_input_error(model)

  @tf.function
  def _wrapped_model(*args, **kwargs):
    """A concrete tf.function that wraps the model's call function."""
    # Exported signatures always run the model in inference mode.
    kwargs['training'] = False
    with base_layer_utils.call_context().enter(
        model, inputs=None, build_graph=False, training=False, saving=True):
      outputs = model(*args, **kwargs)

    # Outputs always has to be a flat dict.
    output_names = model.output_names  # Functional Model.
    if output_names is None:  # Subclassed Model.
      from keras.engine import compile_utils  # pylint: disable=g-import-not-at-top
      output_names = compile_utils.create_pseudo_output_names(outputs)
    outputs = tf.nest.flatten(outputs)
    return {name: output for name, output in zip(output_names, outputs)}

  return _wrapped_model.get_concrete_function(*model_args, **model_kwargs)
def try_build_compiled_arguments(model)
-
Expand source code
def try_build_compiled_arguments(model):
  """Best-effort build of a compiled model's loss and metric containers.

  Runs only for non-v1 models whose `outputs` are known. `compiled_loss` is
  built from `model.outputs`, and `compiled_metrics` is built with
  `model.outputs` passed for both of its arguments. Containers that are
  already built are skipped. A failure is logged as a warning instead of
  being raised. Per the warning text, the metrics then stay empty until the
  model is trained or evaluated.

  Args:
    model: A compiled Keras model.
  """
  if (not version_utils.is_v1_layer_or_model(model) and
      model.outputs is not None):
    try:
      if not model.compiled_loss.built:
        model.compiled_loss.build(model.outputs)
      if not model.compiled_metrics.built:
        model.compiled_metrics.build(model.outputs, model.outputs)
    except Exception:  # pylint: disable=broad-except
      # Building here is optional, so an ordinary failure is only logged.
      # Unlike a bare `except:`, this lets KeyboardInterrupt and SystemExit
      # propagate.
      logging.warning(
          'Compiled the loaded model, but the compiled metrics have yet to '
          'be built. `model.compile_metrics` will be empty until you train '
          'or evaluate the model.')