Module keras.saving.saved_model.load
Keras SavedModel deserialization.
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras SavedModel deserialization."""
import tensorflow.compat.v2 as tf
import os
import re
import types
from google.protobuf import message
from keras import backend
from keras import regularizers
from keras.engine import input_spec
from keras.optimizer_v2 import optimizer_v2
from keras.protobuf import saved_metadata_pb2
from keras.protobuf import versions_pb2
from keras.saving import saving_utils
from keras.saving.saved_model import constants
from keras.saving.saved_model import json_utils
from keras.saving.saved_model import utils
from keras.saving.saved_model.serialized_attributes import CommonEndpoints
from keras.utils import generic_utils
from keras.utils import metrics_utils
from keras.utils.generic_utils import LazyLoader
from tensorflow.python.platform import tf_logging as logging
# To avoid circular dependencies between keras/engine and keras/saving,
# code in keras/saving must delay imports.

# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.
# pylint:disable=g-inconsistent-quotes
models_lib = LazyLoader("models_lib", globals(),
                        "keras.models")
base_layer = LazyLoader(
    "base_layer", globals(),
    "keras.engine.base_layer")
layers_module = LazyLoader(
    "layers_module", globals(),
    "keras.layers")
input_layer = LazyLoader(
    "input_layer", globals(),
    "keras.engine.input_layer")
functional_lib = LazyLoader(
    "functional_lib", globals(),
    "keras.engine.functional")
training_lib = LazyLoader(
    "training_lib", globals(),
    "keras.engine.training")
training_lib_v1 = LazyLoader(
    "training_lib_v1", globals(),
    "keras.engine.training_v1")
metrics = LazyLoader("metrics", globals(),
                     "keras.metrics")
recurrent = LazyLoader(
    "recurrent", globals(),
    "keras.layers.recurrent")
# pylint:enable=g-inconsistent-quotes

# Names of the SavedModel attributes (functions and checkpointable objects)
# attached to every saved Keras object. These trackings are only needed while
# loading and are removed afterwards (see `KerasObjectLoader.del_tracking`).
PUBLIC_ATTRIBUTES = CommonEndpoints.all_functions.union(
    CommonEndpoints.all_checkpointable_objects)
PUBLIC_ATTRIBUTES.add(constants.KERAS_ATTR)
def load(path, compile=True, options=None):  # pylint: disable=redefined-builtin
  """Loads Keras objects from a SavedModel.

  Any Keras layer or model saved to the SavedModel will be loaded back
  as Keras objects. Other objects are loaded as regular trackable objects (same
  as `tf.saved_model.load`).

  Currently, Keras saving/loading only retains the Keras object's weights,
  losses, and call function.

  The loaded model can be re-compiled, but the original optimizer, compiled loss
  functions, and metrics are not retained. This is temporary, and `model.save`
  will soon be able to serialize compiled models.

  Args:
    path: Path to SavedModel.
    compile: If true, compile the model after loading it.
    options: Optional `tf.saved_model.LoadOptions` object that specifies
      options for loading from SavedModel.

  Returns:
    Object loaded from SavedModel.

  Raises:
    IOError: If the `keras_metadata.pb` file exists but cannot be parsed.
  """
  # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
  # TODO(kathywu): Add code to load from objects that contain all endpoints

  # Look for metadata file or parse the SavedModel
  metadata = saved_metadata_pb2.SavedMetadata()
  meta_graph_def = tf.__internal__.saved_model.parse_saved_model(
      path).meta_graphs[0]
  object_graph_def = meta_graph_def.object_graph_def
  path_to_metadata_pb = os.path.join(path, constants.SAVED_METADATA_PATH)
  if tf.compat.v1.gfile.Exists(path_to_metadata_pb):
    try:
      with tf.io.gfile.GFile(path_to_metadata_pb, 'rb') as f:
        file_content = f.read()
      metadata.ParseFromString(file_content)
    except message.DecodeError as e:
      # Chain the protobuf DecodeError so the root cause is preserved in the
      # traceback instead of being flattened into the message string only.
      raise IOError('Cannot parse keras metadata {}: {}.'
                    .format(path_to_metadata_pb, str(e))) from e
  else:
    logging.warning('SavedModel saved prior to TF 2.5 detected when loading '
                    'Keras model. Please ensure that you are saving the model '
                    'with model.save() or tf.keras.models.save_model(), *NOT* '
                    'tf.saved_model.save(). To confirm, there should be a file '
                    'named "keras_metadata.pb" in the SavedModel directory.')
    _read_legacy_metadata(object_graph_def, metadata)

  if not metadata.nodes:
    # When there are no Keras objects, return the results from the core loader
    return tf.saved_model.load(path, options=options)

  metadata = _update_to_current_version(metadata)
  # Recreate layers and metrics using the info stored in the metadata.
  keras_loader = KerasObjectLoader(metadata, object_graph_def)
  keras_loader.load_layers(compile=compile)

  # Generate a dictionary of all loaded nodes.
  nodes_to_load = {'root': None}
  for node_id, loaded_node in keras_loader.loaded_nodes.items():
    nodes_to_load[keras_loader.get_path(node_id)] = loaded_node
  loaded = tf.__internal__.saved_model.load_partial(
      path, nodes_to_load, options=options)

  # Finalize the loaded layers and remove the extra tracked dependencies.
  keras_loader.finalize_objects()
  keras_loader.del_tracking()

  model = loaded['root']

  # pylint: disable=protected-access
  if isinstance(model, training_lib.Model) and compile:
    # TODO(kathywu): Use compiled objects from SavedModel, instead of
    # creating new objects from the training config.
    training_config = model._serialized_attributes['metadata'].get(
        'training_config', None)
    if training_config is not None:
      model.compile(**saving_utils.compile_args_from_training_config(
          training_config), from_serialized=True)
      saving_utils.try_build_compiled_arguments(model)
      if isinstance(model.optimizer, optimizer_v2.OptimizerV2):
        if model.optimizer.get_slot_names():
          logging.warning('Your optimizer uses slots. '
                          'Slots cannot be restored from saved_model, '
                          'as a result, your model is starting with '
                          'a new initialized optimizer.')
    else:
      logging.warning('No training configuration found in save file, so the '
                      'model was *not* compiled. Compile it manually.')
  # pylint: enable=protected-access

  # Force variables and resources to initialize.
  if not tf.executing_eagerly():
    sess = backend.get_session()  # Variables are initialized by this call.
    sess.run(
        tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.TABLE_INITIALIZERS))

  return model
def _update_to_current_version(metadata):
  """Applies version updates to the metadata proto for backwards compat."""
  # Identifiers of graph-network nodes whose v1 metadata carried only a
  # single-input 'save_spec' that must be widened to 'full_save_spec'.
  graph_network_ids = (constants.MODEL_IDENTIFIER,
                       constants.SEQUENTIAL_IDENTIFIER,
                       constants.NETWORK_IDENTIFIER)
  for node in metadata.nodes:
    # Only version-1 producers of model/sequential/network nodes need fixing.
    if node.version.producer != 1 or node.identifier not in graph_network_ids:
      continue
    decoded = json_utils.decode(node.metadata)
    save_spec = decoded.get('save_spec')
    if save_spec is None:
      continue
    # v1 stored a bare input spec; current format stores (args, kwargs).
    decoded['full_save_spec'] = ([save_spec], {})
    node.metadata = json_utils.Encoder().encode(decoded)
  return metadata
def _read_legacy_metadata(object_graph_def, metadata):
  """Builds a KerasMetadata proto from the SavedModel ObjectGraphDef.

  Older SavedModels store the Keras metadata directly in the object graph
  proto instead of the separate `keras_metadata.pb` file; this copies it
  into `metadata` in the new format.

  Args:
    object_graph_def: The `ObjectGraphDef` proto read from the SavedModel.
    metadata: A `SavedMetadata` proto that nodes are added to (mutated in
      place).

  Raises:
    ValueError: If a Keras object node has no serialized metadata (i.e. the
      SavedModel was written with `tf.saved_model.save`).
  """
  node_paths = _generate_object_paths(object_graph_def)
  for node_id, proto in enumerate(object_graph_def.nodes):
    if (proto.WhichOneof('kind') == 'user_object' and
        proto.user_object.identifier in constants.KERAS_OBJECT_IDENTIFIERS):
      if not proto.user_object.metadata:
        # Fixed message: the original concatenated literals lacked separating
        # spaces ("metadata.Please", "`model.save`or").
        raise ValueError('Unable to create a Keras model from this SavedModel. '
                         'This SavedModel was created with '
                         '`tf.saved_model.save`, and lacks the Keras metadata. '
                         'Please save your Keras model by calling `model.save` '
                         'or `tf.keras.models.save_model`.')
      metadata.nodes.add(
          node_id=node_id,
          node_path=node_paths[node_id],
          # Legacy metadata is always treated as produced by version 1.
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]),
          identifier=proto.user_object.identifier,
          metadata=proto.user_object.metadata)
def _generate_object_paths(object_graph_def):
"""Traverses through an ObjectGraphDef and builds a map of all node paths."""
paths = {0: 'root'}
nodes_to_visit = [0]
while nodes_to_visit:
current_node = nodes_to_visit.pop()
current_path = paths[current_node]
for reference in object_graph_def.nodes[current_node].children:
if reference.node_id in paths:
continue
paths[reference.node_id] = '{}.{}'.format(current_path,
reference.local_name)
nodes_to_visit.append(reference.node_id)
return paths
def _is_graph_network(layer):
  """Determines whether the layer is a graph network."""
  # pylint: disable=protected-access
  # Objects revived straight from the SavedModel (not from config) are never
  # graph networks, even if they subclass Functional.
  if isinstance(layer, RevivedNetwork):
    return False
  if isinstance(layer, functional_lib.Functional):
    is_sequential = isinstance(layer, models_lib.Sequential)
    return layer._is_graph_network or is_sequential
  return False
class KerasObjectLoader(object):
  """Loader that recreates Keras objects (e.g. layers, models).

  Layers and models are revived from either the config or SavedModel following
  these rules:
  1. If object is a graph network (i.e. Sequential or Functional) then it will
     be initialized using the structure from the config only after the children
     layers have been created. Graph networks must be initialized with inputs
     and outputs, so all child layers must be created beforehand.
  2. If object's config exists and the class can be found, then revive from
     config.
  3. Object may have already been created if its parent was revived from
     config. In this case, do nothing.
  4. If nothing of the above applies, compose the various artifacts from the
     SavedModel to create a subclassed layer or model. At this time, custom
     metrics are not supported.
  """

  def __init__(self, metadata, object_graph_def):
    # Maps node id -> KerasMetadata node proto.
    self._metadata = {x.node_id: x for x in metadata.nodes}
    # The SavedModel's ObjectGraphDef.
    self._proto = object_graph_def

    self._node_paths = {node_data.node_id: node_data.node_path
                        for node_data in metadata.nodes}
    # Maps node id -> (loaded node, setter). (Populated in `load_layers` and
    # `_add_children_recreated_from_config`.)
    self.loaded_nodes = {}

    # Store all node ids that have already been traversed when tracking nodes
    # that were recreated from the config.
    self._traversed_nodes_from_config = set()

    # Maps model id -> (blank model obj, list of child layer or their node ids)
    # This tracks all layers in functional and sequential models. These models
    # are only reconstructed after all of their child layers have been created.
    self.model_layer_dependencies = {}
    self._models_to_reconstruct = []

  def del_tracking(self):
    """Removes tracked references that are only used when loading the model."""
    # Now that the node object has been fully loaded, and the checkpoint has
    # been restored, the object no longer needs to track objects added from
    # SerializedAttributes. (Note that saving a training checkpoint still
    # functions correctly, because layers and variables are tracked separately
    # by the Layer object.)
    # TODO(kathywu): Instead of outright deleting these nodes (which would
    # make restoring from a different checkpoint tricky), mark them as extra
    # dependencies that are OK to overwrite.
    for node in self.loaded_nodes.values():
      node = node[0]  # loaded_nodes values are (node, setter) pairs.
      if not isinstance(node, base_layer.Layer):
        # Loaded nodes can contain other trackable objects created when
        # loading layers from the config, such as variables.
        continue
      for name in PUBLIC_ATTRIBUTES:
        node._delete_tracking(name)  # pylint: disable=protected-access

      if isinstance(node, functional_lib.Functional):
        # Delete the temporary layer dependencies, which were used to restore
        # the checkpointed values. When the model is live, the user can delete
        # or add layers to the model at any time, so these layer dependencies
        # may be obsolete.
        dependencies = list(node._self_unconditional_dependency_names)  # pylint: disable=protected-access
        for name in dependencies:
          # NOTE(review): '[\d+]' is a character class matching one digit or a
          # literal '+'; likely intended r'\d+'. Behavior matches in practice
          # because these tracked names end in decimal indices — confirm.
          if re.match(r'^layer(_with_weights)?-[\d+]', name) is not None:
            node._delete_tracking(name)  # pylint: disable=protected-access

  def _add_children_recreated_from_config(self, obj, proto, node_id):
    """Recursively records objects recreated from config."""
    # pylint: disable=protected-access
    if node_id in self._traversed_nodes_from_config:
      return

    parent_path = self._node_paths[node_id]
    self._traversed_nodes_from_config.add(node_id)
    obj._maybe_initialize_trackable()
    if isinstance(obj, base_layer.Layer) and not obj.built:
      metadata = json_utils.decode(self._metadata[node_id].metadata)
      self._try_build_layer(obj, node_id, metadata.get('build_input_shape'))

    # Create list of all possible children
    children = []
    # Look for direct children
    for reference in proto.children:
      obj_child = obj._lookup_dependency(reference.local_name)
      children.append((obj_child, reference.node_id, reference.local_name))

    # Add metrics that may have been added to the layer._metrics list.
    # This is stored in the SavedModel as layer.keras_api.layer_metrics in
    # SavedModels created after Tf 2.2.
    metric_list_node_id = self._search_for_child_node(
        node_id, [constants.KERAS_ATTR, 'layer_metrics'])
    if metric_list_node_id is not None and hasattr(obj, '_metrics'):
      obj_metrics = {m.name: m for m in obj._metrics}
      for reference in self._proto.nodes[metric_list_node_id].children:
        metric = obj_metrics.get(reference.local_name)
        if metric is not None:
          metric_path = '{}.layer_metrics.{}'.format(constants.KERAS_ATTR,
                                                     reference.local_name)
          children.append((metric, reference.node_id, metric_path))

    for (obj_child, child_id, child_name) in children:
      child_proto = self._proto.nodes[child_id]

      if not isinstance(obj_child, tf.__internal__.tracking.Trackable):
        continue
      # Pick the setter used to attach restored attributes onto this child.
      if (child_proto.user_object.identifier in
          tf.__internal__.saved_model.load.registered_identifiers()):
        setter = tf.__internal__.saved_model.load.get_setter(
            child_proto.user_object)
      elif obj_child._object_identifier in constants.KERAS_OBJECT_IDENTIFIERS:
        setter = _revive_setter
      else:
        setter = setattr
        # pylint: enable=protected-access

      if child_id in self.loaded_nodes:
        if self.loaded_nodes[child_id][0] is not obj_child:
          # This means that the same trackable object is referenced by two
          # different objects that were recreated from the config.
          logging.warning(
              'Looks like there is an object (perhaps variable or '
              'layer) that is shared between different layers/models. '
              'This may cause issues when restoring the variable '
              'values. Object: {}'.format(obj_child))
        continue

      # Overwrite variable names with the ones saved in the SavedModel.
      if (child_proto.WhichOneof('kind') == 'variable' and
          child_proto.variable.name):
        obj_child._handle_name = child_proto.variable.name + ':0'  # pylint: disable=protected-access

      if isinstance(obj_child,
                    tf.__internal__.tracking.TrackableDataStructure):
        # Data structures are restored elsewhere; make the setter a no-op.
        setter = lambda *args: None

      child_path = '{}.{}'.format(parent_path, child_name)
      self._node_paths[child_id] = child_path
      self._add_children_recreated_from_config(
          obj_child, child_proto, child_id)
      self.loaded_nodes[child_id] = obj_child, setter

  def load_layers(self, compile=True):  # pylint: disable=redefined-builtin
    """Load all layer nodes from the metadata."""
    # Load metrics after models and layers, since it's likely that models
    # and layers will create the metric when initialized (this avoids wasting
    # time by creating objects multiple times).
    metric_list = []
    for node_metadata in self._metadata.values():
      if node_metadata.identifier == constants.METRIC_IDENTIFIER:
        metric_list.append(node_metadata)
        continue

      self.loaded_nodes[node_metadata.node_id] = self._load_layer(
          node_metadata.node_id, node_metadata.identifier,
          node_metadata.metadata)

    for node_metadata in metric_list:
      try:
        self.loaded_nodes[node_metadata.node_id] = self._load_layer(
            node_metadata.node_id, node_metadata.identifier,
            node_metadata.metadata)
      except ValueError:
        # Metrics are only needed when the model is compiled later. We ignore
        # errors when trying to load custom metrics when `compile=False` until
        # custom metrics are serialized properly (b/135550038).
        if compile:
          raise
        logging.warning('Unable to restore custom metric. Please ensure that '
                        'the layer implements `get_config` and `from_config` '
                        'when saving. In addition, please use the '
                        '`custom_objects` arg when calling `load_model()`.')

  def _load_layer(self, node_id, identifier, metadata):
    """Load a single layer from a SavedUserObject proto."""
    metadata = json_utils.decode(metadata)

    # If node was already created
    if node_id in self.loaded_nodes:
      node, setter = self.loaded_nodes[node_id]

      # Revive setter requires the object to have a `_serialized_attributes`
      # property. Add it here.
      _maybe_add_serialized_attributes(node, metadata)

      config = metadata.get('config')
      if _is_graph_network(node) and generic_utils.validate_config(config):
        child_nodes = self._get_child_layer_node_ids(node_id)
        self.model_layer_dependencies[node_id] = (node, child_nodes)
        if not child_nodes:
          self._models_to_reconstruct.append(node_id)
      return node, setter

    # Detect whether this object can be revived from the config. If not, then
    # revive from the SavedModel instead.
    obj, setter = self._revive_from_config(identifier, metadata, node_id)
    if obj is None:
      obj, setter = revive_custom_object(identifier, metadata)

    # Add an attribute that stores the extra functions/objects saved in the
    # SavedModel. Most of these functions/objects are ignored, but some are
    # used later in the loading process (e.g. the list of regularization
    # losses, or the training config of compiled models).
    _maybe_add_serialized_attributes(obj, metadata)
    return obj, setter

  def _revive_from_config(self, identifier, metadata, node_id):
    """Revives a layer/model from config, or returns None."""
    if identifier == constants.METRIC_IDENTIFIER:
      obj = self._revive_metric_from_config(metadata)
    else:
      obj = (
          self._revive_graph_network(identifier, metadata, node_id) or
          self._revive_layer_or_model_from_config(metadata, node_id))

    if obj is None:
      return None, None
    setter = self._config_node_setter(_revive_setter)
    self._add_children_recreated_from_config(
        obj, self._proto.nodes[node_id], node_id)
    return obj, setter

  def _revive_graph_network(self, identifier, metadata, node_id):
    """Revives a graph network from config."""
    # Determine whether the metadata contains information for reviving a
    # functional or Sequential model.
    config = metadata.get('config')
    if not generic_utils.validate_config(config):
      return None

    class_name = tf.compat.as_str(metadata['class_name'])
    if generic_utils.get_registered_object(class_name) is not None:
      # Registered custom classes are revived elsewhere (from config or
      # SavedModel), not as blank graph networks.
      return None
    model_is_functional_or_sequential = (
        metadata.get('is_graph_network', False) or
        class_name == 'Sequential' or
        class_name == 'Functional')
    if not model_is_functional_or_sequential:
      return None

    # Revive functional and sequential models as blank model objects for now (
    # must be initialized to enable setattr tracking and attribute caching).
    # Reconstruction of the network is deferred until all of the model's
    # layers have been revived.
    if class_name == 'Sequential':
      model = models_lib.Sequential(name=config['name'])
    # The model is a custom Sequential model.
    elif identifier == constants.SEQUENTIAL_IDENTIFIER:
      # Uses the custom class name, since the config does not have one.
      model = models_lib.Sequential(name=class_name)
    else:
      model = models_lib.Functional(
          inputs=[], outputs=[], name=config['name'])

    # Record this model and its layers. This will later be used to reconstruct
    # the model.
    layers = self._get_child_layer_node_ids(node_id)
    self.model_layer_dependencies[node_id] = (model, layers)
    if not layers:
      self._models_to_reconstruct.append(node_id)
    return model

  def _revive_layer_or_model_from_config(self, metadata, node_id):
    """Revives a layer/custom model from config; returns None if infeasible."""
    # Check that the following requirements are met for reviving from config:
    #    1. Object can be deserialized from config.
    #    2. If the object needs to be built, then the build input shape can be
    #       found.
    class_name = metadata.get('class_name')
    config = metadata.get('config')
    shared_object_id = metadata.get('shared_object_id')
    must_restore_from_config = metadata.get('must_restore_from_config')
    if not generic_utils.validate_config(config):
      return None

    try:
      obj = layers_module.deserialize(
          generic_utils.serialize_keras_class_and_config(
              class_name, config, shared_object_id=shared_object_id))
    except ValueError:
      if must_restore_from_config:
        raise RuntimeError(
            'Unable to restore a layer of class {cls}. Layers of '
            'class {cls} require that the class be provided to '
            'the model loading code, either by registering the '
            'class using @keras.utils.register_keras_serializable '
            'on the class def and including that file in your '
            'program, or by passing the class in a '
            'keras.utils.CustomObjectScope that wraps this load '
            'call.'.format(cls=class_name))
      else:
        return None

    # Use the dtype, name, and trainable status. Often times these are not
    # specified in custom configs, so retrieve their values from the metadata.
    # pylint: disable=protected-access
    obj._name = metadata['name']
    if metadata.get('trainable') is not None:
      obj.trainable = metadata['trainable']
    if metadata.get('dtype') is not None:
      obj._set_dtype_policy(metadata['dtype'])
    if metadata.get('stateful') is not None:
      obj.stateful = metadata['stateful']
    # Restore model save spec for subclassed models. (layers do not store a
    # SaveSpec)
    if isinstance(obj, training_lib.Model):
      full_save_spec = metadata.get('full_save_spec')
      if full_save_spec is not None:
        args_spec, kwargs_spec = full_save_spec
        inputs_spec = args_spec.pop(0)
        obj._set_save_spec(inputs_spec, args_spec, kwargs_spec)
    # pylint: enable=protected-access

    build_input_shape = metadata.get('build_input_shape')
    built = self._try_build_layer(obj, node_id, build_input_shape)
    if not built:
      # If the layer cannot be built, revive a custom layer instead.
      return None
    return obj

  def _revive_metric_from_config(self, metadata):
    """Revives a metric object using the config saved in the metadata."""
    class_name = tf.compat.as_str(metadata['class_name'])
    config = metadata.get('config')

    if not generic_utils.validate_config(config):
      return None

    try:
      obj = metrics.deserialize(
          generic_utils.serialize_keras_class_and_config(class_name, config))
    except ValueError:
      return None

    build_input_shape = metadata.get('build_input_shape')
    if build_input_shape is not None and hasattr(obj, '_build'):
      obj._build(build_input_shape)  # pylint: disable=protected-access
    return obj

  def _try_build_layer(self, obj, node_id, build_input_shape):
    """Attempts to build the layer; returns True on success."""
    if obj.built or hasattr(obj.build, '_is_default'):
      obj.built = True
      return True

    if build_input_shape is None:
      # Fall back to the shape recorded in the saved call function.
      build_input_shape = self._infer_inputs(node_id, convert_to_shapes=True)

    if build_input_shape is not None:
      obj.build(build_input_shape)
      # Also run the base Layer.build so `built`/input spec bookkeeping is set.
      base_layer.Layer.build(obj, build_input_shape)
      return True

    return False

  def _load_edges(self):
    """Add edges for all nodes that are not waiting on initialization."""
    # NOTE(review): `_add_object_graph_edges` is defined on this loader but is
    # not visible in this chunk.
    for node_id, proto in enumerate(self._proto.nodes):
      if node_id not in self.model_layer_dependencies:
        self._add_object_graph_edges(proto, node_id)

  def get_path(self, node_id):
    """Returns the object-graph path string recorded for `node_id`."""
    return self._node_paths[node_id]

  def finalize_objects(self):
    """Finish setting up Keras objects.

    This function is executed after all objects and functions have been
    created. Call functions and losses are attached to each layer, and once
    all layers have been fully set up, graph networks are initialized.

    Subclassed models that are revived from the SavedModel are treated like
    layers, and have their call/loss functions attached here.
    """
    # Finish setting up layers and subclassed models. This step attaches call
    # functions and losses to each object, and sets model inputs/outputs.
    layers_revived_from_config = []
    layers_revived_from_saved_model = []
    for node_id, (node, _) in self.loaded_nodes.items():
      if (not isinstance(node, base_layer.Layer) or
          # Don't finalize models until all layers have finished loading.
          node_id in self.model_layer_dependencies):
        continue

      self._unblock_model_reconstruction(node_id, node)

      if isinstance(node, input_layer.InputLayer):
        continue
      elif isinstance(node, metrics.Metric):
        continue

      if isinstance(node, (RevivedLayer, RevivedInputLayer)):
        layers_revived_from_saved_model.append(node)
      else:
        layers_revived_from_config.append(node)

    _finalize_saved_model_layers(layers_revived_from_saved_model)
    _finalize_config_layers(layers_revived_from_config)

    # Initialize graph networks, now that layer dependencies have been
    # resolved.
    self._reconstruct_all_models()

  def _unblock_model_reconstruction(self, layer_id, layer):
    """Removes layer from blocking model reconstruction."""
    for model_id, v in self.model_layer_dependencies.items():
      _, layers = v
      if layer_id not in layers:
        continue
      # Replace the node id placeholder with the actual layer object.
      layers[layers.index(layer_id)] = layer
      if all(isinstance(x, base_layer.Layer) for x in layers):
        self._models_to_reconstruct.append(model_id)

  def _reconstruct_all_models(self):
    """Reconstructs the network structure of all models."""
    all_initialized_models = set()
    while self._models_to_reconstruct:
      model_id = self._models_to_reconstruct.pop(0)
      all_initialized_models.add(model_id)
      model, layers = self.model_layer_dependencies[model_id]
      self._reconstruct_model(model_id, model, layers)
      _finalize_config_layers([model])

    if all_initialized_models != set(self.model_layer_dependencies.keys()):
      # This should not happen.
      uninitialized_model_ids = (
          set(self.model_layer_dependencies.keys()) - all_initialized_models)
      uninitialized_model_names = [
          self.model_layer_dependencies[model_id][0].name
          for model_id in uninitialized_model_ids]
      raise ValueError('Error when loading from SavedModel -- the following '
                       'models could not be initialized: {}'
                       .format(uninitialized_model_names))

  def _reconstruct_model(self, model_id, model, layers):
    """Reconstructs the network structure."""
    config = json_utils.decode(self._metadata[model_id].metadata)['config']

    # Set up model inputs
    if model.inputs:
      # Inputs may already be created if the model is instantiated in another
      # object's __init__.
      pass
    elif isinstance(model, models_lib.Sequential):
      if not layers or not isinstance(layers[0], input_layer.InputLayer):
        if config['layers'][0]['class_name'] == 'InputLayer':
          layers.insert(0, input_layer.InputLayer.from_config(
              config['layers'][0]['config']))
        elif 'batch_input_shape' in config['layers'][0]['config']:
          batch_input_shape = config['layers'][0]['config'][
              'batch_input_shape']
          layers.insert(0, input_layer.InputLayer(
              input_shape=batch_input_shape[1:],
              batch_size=batch_input_shape[0],
              dtype=layers[0].dtype,
              name=layers[0].name + '_input'))
      model.__init__(layers, name=config['name'])
      if not model.inputs:
        first_layer = self._get_child_layer_node_ids(model_id)[0]
        input_specs = self._infer_inputs(first_layer)
        input_shapes = self._infer_inputs(first_layer, convert_to_shapes=True)
        model._set_inputs(input_specs)  # pylint: disable=protected-access
        if not model.built and not isinstance(input_specs, dict):
          model.build(input_shapes)
    else:  # Reconstruct functional model
      (inputs, outputs,
       created_layers) = functional_lib.reconstruct_from_config(
           config, created_layers={layer.name: layer for layer in layers})
      model.__init__(inputs, outputs, name=config['name'])
      functional_lib.connect_ancillary_layers(model, created_layers)

    # Set model dtype.
    _set_network_attributes_from_metadata(model)

    # Unblock models that are dependent on this model.
    self._unblock_model_reconstruction(model_id, model)

  def _get_child_layer_node_ids(self, node_id):
    """Returns the node ids of each layer in a Sequential/Functional model."""
    # Sequential and Functional track layers with names following the format
    # "layer-N". Use this to generate the list of layers.
    num_layers = 0
    child_layers = {}
    pattern = re.compile('layer-(\\d+)')

    for child in self._proto.nodes[node_id].children:
      m = pattern.match(child.local_name)
      if m is None:
        continue
      layer_n = int(m.group(1))
      num_layers = max(layer_n + 1, num_layers)
      child_layers[layer_n] = child.node_id

    ordered = []
    for n in range(num_layers):
      child = child_layers.get(n)
      if child is None:
        # Stop at the first gap in the layer numbering.
        break
      ordered.append(child)
    return ordered

  def _search_for_child_node(self, parent_id, path_to_child):
    """Returns node id of child node.

    A helper method for traversing the object graph proto.

    As an example, say that the object graph proto in the SavedModel contains
    an object with the following child and grandchild attributes:

    `parent.child_a.child_b`

    This method can be used to retrieve the node id of `child_b` using the
    parent's node id by calling:

    `_search_for_child_node(parent_id, ['child_a', 'child_b'])`.

    Args:
      parent_id: node id of parent node
      path_to_child: list of children names.

    Returns:
      node_id of child, or None if child isn't found.
    """
    if not path_to_child:
      return parent_id

    for child in self._proto.nodes[parent_id].children:
      if child.local_name == path_to_child[0]:
        return self._search_for_child_node(child.node_id, path_to_child[1:])
    return None

  def _infer_inputs(self, layer_node_id, convert_to_shapes=False):
    """Infers input shape of layer from SavedModel functions."""
    coder = tf.__internal__.saved_model.StructureCoder()
    call_fn_id = self._search_for_child_node(
        layer_node_id, ['call_and_return_all_conditional_losses'])
    if call_fn_id is None:
      return None

    concrete_functions = (
        self._proto.nodes[call_fn_id].function.concrete_functions)
    if not concrete_functions:
      return None
    call_fn_name = concrete_functions[0]
    call_fn_proto = self._proto.concrete_functions[call_fn_name]
    structured_input_signature = coder.decode_proto(
        call_fn_proto.canonicalized_input_signature)
    # The first positional argument of the call function is the input.
    inputs = structured_input_signature[0][0]
    if convert_to_shapes:
      return tf.nest.map_structure(lambda spec: spec.shape, inputs)
    else:
      return inputs

  def _config_node_setter(self, setter):
    """Creates edges for nodes that are recreated from config."""
    def setattr_wrapper(obj, name, value):
      # Avoid overwriting attributes of objects recreated from the config.
      if obj._lookup_dependency(name) is None:  # pylint: disable=protected-access
        setter(obj, name, value)
    return setattr_wrapper
def _finalize_saved_model_layers(layers):
  """Runs the final steps of loading Keras Layers from SavedModel.

  Attaches restored call functions (or an error-raising stub), save specs,
  losses and metrics to each layer revived from the SavedModel.
  """
  # pylint: disable=protected-access
  # 1. Set up call functions for all layers initialized from the SavedModel (
  # and not the config)
  for layer in layers:
    layer.built = True
    layer_call = getattr(_get_keras_attr(layer),
                         'call_and_return_conditional_losses', None)
    if layer_call and layer_call.concrete_functions:
      layer.call = utils.use_wrapped_call(
          layer, layer_call, return_method=True)
      expects_training_arg = layer._serialized_attributes['metadata'][
          'expects_training_arg']
      if 'training' in layer_call.function_spec.arg_names:
        # This could change the value of `expects_training_arg` if this layer
        # doesn't expect a training arg, but has a child layer that does.
        expects_training_arg = True
      layer._init_call_fn_args(expects_training_arg)
    else:
      # No traced call function was saved; calling this layer raises a
      # descriptive ValueError instead of failing obscurely.
      layer.call = types.MethodType(
          _unable_to_call_layer_due_to_serialization_issue, layer)

  for layer in layers:
    # 2. Set model inputs and outputs.
    if isinstance(layer, RevivedNetwork):
      _set_network_attributes_from_metadata(layer)

      if hasattr(_get_keras_attr(layer),
                 'call_and_return_conditional_losses'):
        call_fn = _get_keras_attr(layer).call_and_return_conditional_losses
        if not call_fn.concrete_functions:
          continue
        if call_fn.input_signature is None:
          # No explicit signature: infer args/kwargs from the traced
          # concrete function.
          args, kwargs = infer_inputs_from_restored_call_function(call_fn)
          args = list(args)
          inputs = args.pop(0)
        else:
          args = call_fn.input_signature
          args = list(args)
          inputs = args.pop(0)
          kwargs = None
        layer._set_save_spec(inputs, args, kwargs)  # pylint: disable=protected-access

        # V1 models require calling _set_inputs to set the `.inputs` attr.
        # Skip this step when there are multiple tensor inputs (this behavior
        # is not well supported in V1 models).
        if not any(isinstance(x, tf.TensorSpec)
                   for x in tf.nest.flatten([args, kwargs])):
          layer._set_inputs(inputs)

    # 3. Add losses that aren't generated by the layer.call function.
    _restore_layer_unconditional_losses(layer)
    _restore_layer_activation_loss(layer)

    # 4. Restore metrics list
    _restore_layer_metrics(layer)
  # pylint: enable=protected-access
def _unable_to_call_layer_due_to_serialization_issue(
layer, *unused_args, **unused_kwargs):
"""Replaces the `layer.call` if the layer was not fully serialized.
Keras Model/Layer serialization is relatively relaxed because SavedModels
are not always loaded back as keras models. Thus, when there is an issue
tracing a non-signature function, a warning is logged instead of raising an
error. This results in a SavedModel where the model's call function is saved,
but the internal layer call functions are not.
When deserialized with `tf.keras.models.load_model`, the internal layers
which do not have serialized call functions should raise an error when called.
Args:
layer: Layer without the serialized call function.
Raises:
ValueError
"""
raise ValueError(
'Cannot call custom layer {} of type {}, because the call function was '
'not serialized to the SavedModel.'
'Please try one of the following methods to fix this issue:'
'\n\n(1) Implement `get_config` and `from_config` in the layer/model '
'class, and pass the object to the `custom_objects` argument when '
'loading the model. For more details, see: '
'https://www.tensorflow.org/guide/keras/save_and_serialize'
'\n\n(2) Ensure that the subclassed model or layer overwrites `call` '
'and not `__call__`. The input shape and dtype will be automatically '
'recorded when the object is called, and used when saving. To manually '
'specify the input shape/dtype, decorate the call function with '
'`@tf.function(input_signature=...)`.'.format(layer.name, type(layer)))
def _finalize_config_layers(layers):
  """Runs the final steps of loading Keras Layers from config."""
  for lyr in layers:
    # Layers are assumed to define their unconditional losses after being
    # recreated from the config and built. The exceptions are Functional and
    # Sequential models, which only store conditional (input-dependent)
    # losses in the config; unconditional losses like weight regularization
    # must be revived from the SavedModel instead.
    if _is_graph_network(lyr):
      _restore_layer_unconditional_losses(lyr)

    # Some layers (e.g. Dense) record their activation loss function in the
    # config, but not all do, so it may be missing when restored from the
    # config/hdf5.
    # TODO(kathywu): Investigate ways to improve the config to ensure
    # consistent loading behavior between HDF5 and SavedModel.
    _restore_layer_activation_loss(lyr)

    # Restore metrics list.
    _restore_layer_metrics(lyr)

    # Restore the saved states of stateful RNN layers.
    is_stateful_rnn = isinstance(lyr, recurrent.RNN) and lyr.stateful
    if is_stateful_rnn and hasattr(_get_keras_attr(lyr), 'states'):
      lyr.states = getattr(_get_keras_attr(lyr), 'states', None)
      for state_var in tf.nest.flatten(lyr.states):
        backend.track_variable(state_var)

    # Perform any layer-defined finalization of the layer state.
    lyr.finalize_state()
def _finalize_metric(metric):
  """Restores `update_state` and `result` on a revived metric object."""
  keras_api = metric.keras_api
  wrapped_update = metrics_utils.update_state_wrapper(keras_api.update_state)
  metric.update_state = types.MethodType(wrapped_update, metric)
  metric.result = keras_api.result
def _restore_layer_unconditional_losses(layer):
  """Restore unconditional losses from SavedModel."""
  keras_attr = _get_keras_attr(layer)
  if hasattr(keras_attr, 'layer_regularization_losses'):
    saved_losses = getattr(keras_attr, 'layer_regularization_losses', [])
  else:
    # Some earlier SavedModels may not have layer_regularization_losses
    # serialized separately. Fall back to using the regularization_losses
    # list if it does not exist.
    saved_losses = layer._serialized_attributes.get(
        'regularization_losses', [])  # pylint: disable=protected-access
  for loss_fn in saved_losses:
    layer.add_loss(loss_fn)
def _restore_layer_activation_loss(layer):
  """Restore activation loss from SavedModel."""
  # Use the wrapped activity regularizer function from the SavedModel only
  # when the layer's own activity regularizer wasn't created during
  # initialization.
  saved_regularizer = getattr(
      _get_keras_attr(layer), 'activity_regularizer_fn', None)
  if not saved_regularizer or layer.activity_regularizer:
    return
  try:
    layer.activity_regularizer = saved_regularizer
  except AttributeError:
    # This may happen if a layer wrapper is saved with an activity
    # regularizer. The wrapper object's activity regularizer is unsettable.
    pass
def revive_custom_object(identifier, metadata):
  """Revives object from SavedModel.

  Args:
    identifier: String that identifies the kind of serialized Keras object
      (input layer, layer, model, network, or sequential).
    metadata: Dict of metadata recorded in the SavedModel proto; must contain
      'class_name'.

  Returns:
    The (revived object, setter) pair produced by the revived class's
    `_init_from_metadata`.

  Raises:
    ValueError: If `identifier` does not map to a known revivable class.
  """
  if tf.compat.v1.executing_eagerly_outside_functions():
    model_class = training_lib.Model
  else:
    model_class = training_lib_v1.Model

  revived_classes = {
      constants.INPUT_LAYER_IDENTIFIER: (
          RevivedInputLayer, input_layer.InputLayer),
      constants.LAYER_IDENTIFIER: (RevivedLayer, base_layer.Layer),
      constants.MODEL_IDENTIFIER: (RevivedNetwork, model_class),
      constants.NETWORK_IDENTIFIER: (RevivedNetwork, functional_lib.Functional),
      constants.SEQUENTIAL_IDENTIFIER: (RevivedNetwork, models_lib.Sequential),
  }
  parent_classes = revived_classes.get(identifier, None)

  if parent_classes is not None:
    # Dynamically build a subclass combining the revived mixin with the
    # matching Keras base class, named after the original saved class.
    # (The previous redundant re-lookup of revived_classes[identifier] after
    # the .get() above has been removed.)
    revived_cls = type(
        tf.compat.as_str(metadata['class_name']), parent_classes, {})
    return revived_cls._init_from_metadata(metadata)  # pylint: disable=protected-access
  else:
    # Note the trailing space after '`get_config`' — without it the message
    # read "implements `get_config`and `from_config`".
    raise ValueError('Unable to restore custom object of type {} currently. '
                     'Please make sure that the layer implements `get_config` '
                     'and `from_config` when saving. In addition, please use '
                     'the `custom_objects` arg when calling `load_model()`.'
                     .format(identifier))
def _restore_layer_metrics(layer):
  """Appends SavedModel metrics missing from the layer's metric list."""
  saved_metrics = getattr(_get_keras_attr(layer), 'layer_metrics', {})
  existing_names = {m.name for m in layer._metrics}  # pylint: disable=protected-access
  for metric_name, metric in saved_metrics.items():
    # Metrics may be added during initialization/building of custom layers.
    if metric_name not in existing_names:
      layer._metrics.append(metric)  # pylint: disable=protected-access
# TODO(kathywu): Centrally define keys and functions for both serialization and
# deserialization.
class RevivedLayer(object):
  """Keras layer loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Create revived layer from metadata stored in the SavedModel proto."""
    init_args = dict(
        name=metadata['name'], trainable=metadata['trainable'])
    # Only forward optional constructor args that were actually recorded.
    for arg_name in ('dtype', 'batch_input_shape'):
      if metadata.get(arg_name) is not None:
        init_args[arg_name] = metadata[arg_name]

    revived_obj = cls(**init_args)

    with utils.no_automatic_dependency_tracking_scope(revived_obj):
      # pylint:disable=protected-access
      revived_obj._expects_training_arg = metadata['expects_training_arg']
      saved_config = metadata.get('config')
      if generic_utils.validate_config(saved_config):
        revived_obj._config = saved_config
      if metadata.get('input_spec') is not None:
        revived_obj.input_spec = recursively_deserialize_keras_object(
            metadata['input_spec'],
            module_objects={'InputSpec': input_spec.InputSpec})
      if metadata.get('activity_regularizer') is not None:
        revived_obj.activity_regularizer = regularizers.deserialize(
            metadata['activity_regularizer'])
      if metadata.get('_is_feature_layer') is not None:
        revived_obj._is_feature_layer = metadata['_is_feature_layer']
      if metadata.get('stateful') is not None:
        revived_obj.stateful = metadata['stateful']
      # pylint:enable=protected-access

    return revived_obj, _revive_setter

  @property
  def keras_api(self):
    """Serialized keras attributes restored from the SavedModel, if any."""
    return self._serialized_attributes.get(constants.KERAS_ATTR, None)

  def get_config(self):
    """Returns the config saved in the SavedModel, if one was recorded."""
    if not hasattr(self, '_config'):
      raise NotImplementedError
    return self._config
def _revive_setter(layer, name, value):
  """Setter function that saves some attributes to separate dictionary."""
  # Many attributes in the SavedModel conflict with properties defined in
  # Layer and Model. Save these attributes to a separate dictionary.
  if name in PUBLIC_ATTRIBUTES:
    # pylint: disable=protected-access
    if isinstance(value, tf.__internal__.tracking.Trackable):
      layer._track_trackable(value, name=name)
    layer._serialized_attributes[name] = value
    # pylint: enable=protected-access
    return

  is_functional = isinstance(layer, functional_lib.Functional)
  if is_functional and re.match(r'^layer(_with_weights)?-[\d+]', name):
    # Edges named "layer-n" or "layer_with_weights-n", which are tracked in
    # network._track_layers, should not be added as an attribute. They should
    # be temporarily added as a dependency so that checkpointed values can be
    # restored. These dependencies are manually deleted in
    # KerasObjectLoader.del_tracking.
    # Set `overwrite=True` in the case that `layer` already tracks a
    # different layer-n. This may cause variable values to not be loaded
    # properly in the original layer-n, but we already warn the users about
    # this (ctrl-f "shared between different layers/models").
    layer._track_trackable(value, name, overwrite=True)  # pylint: disable=protected-access
  elif getattr(layer, name, None) is None:
    setattr(layer, name, value)
  # Otherwise the attribute is already defined on the layer; leave it alone.
class RevivedInputLayer(object):
  """InputLayer loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Revives the saved InputLayer from the Metadata."""
    revived_obj = cls(
        name=metadata['name'],
        dtype=metadata['dtype'],
        sparse=metadata['sparse'],
        ragged=metadata['ragged'],
        batch_input_shape=metadata['batch_input_shape'])
    with utils.no_automatic_dependency_tracking_scope(revived_obj):
      revived_obj._config = metadata['config']  # pylint:disable=protected-access
    # Input layers have no attributes that conflict with Layer properties, so
    # the plain built-in `setattr` serves as the setter function.
    return revived_obj, setattr

  def get_config(self):
    """Returns the config recorded in the SavedModel."""
    return self._config
def recursively_deserialize_keras_object(config, module_objects=None):
  """Deserialize Keras object from a nested structure."""
  if isinstance(config, dict):
    # A dict with a 'class_name' key is a serialized Keras object; any other
    # dict is treated as a container whose values are deserialized
    # recursively.
    if 'class_name' in config:
      return generic_utils.deserialize_keras_object(
          config, module_objects=module_objects)
    return {
        k: recursively_deserialize_keras_object(v, module_objects)
        for k, v in config.items()
    }
  if isinstance(config, (tuple, list)):
    # Both tuples and lists come back as lists.
    return [recursively_deserialize_keras_object(item, module_objects)
            for item in config]
  raise ValueError('Unable to decode config: {}'.format(config))
def get_common_shape(x, y):
  """Find a `TensorShape` that is compatible with both `x` and `y`.

  Args:
    x: `tf.TensorShape` or None.
    y: `tf.TensorShape` or None.

  Returns:
    None when both inputs are None (the associated input was not a Tensor),
    otherwise a `TensorShape` keeping only the dimensions on which `x` and
    `y` agree.

  Raises:
    RuntimeError: If exactly one of `x` and `y` is None.
    TypeError: If either input is neither None nor a `TensorShape`.
  """
  # BUG FIX: the original `if x is None != y is None:` was a chained
  # comparison, i.e. `(x is None) and (None != y) and (y is None)`, which can
  # never be true, so the mismatch check below never fired. Parenthesizing
  # yields the intended XOR.
  if (x is None) != (y is None):
    raise RuntimeError(
        'Cannot find a common shape when LHS shape is None but RHS shape '
        'is not (or vice versa): %s vs. %s' % (x, y))
  if x is None:
    return None  # The associated input was not a Tensor, no shape generated.
  if not isinstance(x, tf.TensorShape):
    raise TypeError('Expected x to be a TensorShape but saw %s' % (x,))
  if not isinstance(y, tf.TensorShape):
    raise TypeError('Expected y to be a TensorShape but saw %s' % (y,))
  if x.rank != y.rank or x.rank is None:
    # Ranks disagree or are unknown: the only common shape is fully unknown.
    return tf.TensorShape(None)
  dims = []
  for dim_x, dim_y in zip(x.dims, y.dims):
    if (dim_x != dim_y
        or tf.compat.dimension_value(dim_x) is None
        or tf.compat.dimension_value(dim_y) is None):
      dims.append(None)
    else:
      dims.append(tf.compat.dimension_value(dim_x))
  return tf.TensorShape(dims)
def infer_inputs_from_restored_call_function(fn):
  """Returns TensorSpec of inputs from a restored call function.

  Args:
    fn: Restored layer call function. It is assumed that `fn` has at least
      one concrete function and that the inputs are in the first argument.

  Returns:
    TensorSpec of call function inputs in the form of (args, kwargs)
  """

  def common_spec(x, y):
    if not isinstance(x, tf.TensorSpec):
      # Doesn't particularly matter what is returned in this case because the
      # result will be filtered out in _set_input_shape.
      return x
    shape = get_common_shape(x.shape, y.shape)
    if isinstance(x, tf.SparseTensorSpec):
      return tf.SparseTensorSpec(shape, x.dtype)
    if isinstance(x, tf.RaggedTensorSpec):
      return tf.RaggedTensorSpec(shape, x.dtype)
    return tf.TensorSpec(shape, x.dtype, x.name)

  # Merge the structured input signatures of all concrete functions pairwise,
  # keeping only the dimensions on which they all agree.
  concrete_fns = fn.concrete_functions
  merged = concrete_fns[0].structured_input_signature
  for concrete_fn in concrete_fns[1:]:
    merged = tf.nest.map_structure(
        common_spec, merged, concrete_fn.structured_input_signature)
  return merged
class RevivedNetwork(RevivedLayer):
  """Keras network of layers loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Create revived network from metadata stored in the SavedModel proto."""
    revived_obj = cls(name=metadata['name'])

    # Store attributes revived from SerializedAttributes in an un-tracked
    # dictionary. The attributes are the ones listed in CommonEndpoints or
    # "keras_api" for keras-specific attributes.
    with utils.no_automatic_dependency_tracking_scope(revived_obj):
      # pylint:disable=protected-access
      revived_obj._expects_training_arg = metadata['expects_training_arg']
      saved_config = metadata.get('config')
      if generic_utils.validate_config(saved_config):
        revived_obj._config = saved_config
      if metadata.get('activity_regularizer') is not None:
        revived_obj.activity_regularizer = regularizers.deserialize(
            metadata['activity_regularizer'])
      # pylint:enable=protected-access

    return revived_obj, _revive_setter  # pylint:disable=protected-access
def _set_network_attributes_from_metadata(revived_obj):
  """Sets attributes recorded in the metadata."""
  with utils.no_automatic_dependency_tracking_scope(revived_obj):
    # pylint:disable=protected-access
    saved_metadata = revived_obj._serialized_attributes['metadata']
    dtype = saved_metadata.get('dtype')
    if dtype is not None:
      revived_obj._set_dtype_policy(dtype)
    revived_obj._trainable = saved_metadata['trainable']
    # pylint:enable=protected-access
def _maybe_add_serialized_attributes(layer, metadata):
  """Creates `layer._serialized_attributes` if missing, seeded with metadata.

  Attributes revived from SerializedAttributes are stored in an un-tracked
  dictionary; the attributes are the ones listed in CommonEndpoints or
  "keras_api" for keras-specific attributes.
  """
  if hasattr(layer, '_serialized_attributes'):
    return
  with utils.no_automatic_dependency_tracking_scope(layer):
    layer._serialized_attributes = {'metadata': metadata}  # pylint: disable=protected-access
def _get_keras_attr(layer):
  """Returns the layer's revived "keras_api" attributes, or None if absent."""
  serialized = getattr(layer, '_serialized_attributes', {})
  return serialized.get(constants.KERAS_ATTR, None)
Functions
def get_common_shape(x, y)
-
Find a `TensorShape` that is compatible with both `x` and `y`.
Expand source code
def get_common_shape(x, y): """Find a `TensorShape` that is compatible with both `x` and `y`.""" if x is None != y is None: raise RuntimeError( 'Cannot find a common shape when LHS shape is None but RHS shape ' 'is not (or vice versa): %s vs. %s' % (x, y)) if x is None: return None # The associated input was not a Tensor, no shape generated. if not isinstance(x, tf.TensorShape): raise TypeError('Expected x to be a TensorShape but saw %s' % (x,)) if not isinstance(y, tf.TensorShape): raise TypeError('Expected y to be a TensorShape but saw %s' % (y,)) if x.rank != y.rank or x.rank is None: return tf.TensorShape(None) dims = [] for dim_x, dim_y in zip(x.dims, y.dims): if (dim_x != dim_y or tf.compat.dimension_value(dim_x) is None or tf.compat.dimension_value(dim_y) is None): dims.append(None) else: dims.append(tf.compat.dimension_value(dim_x)) return tf.TensorShape(dims)
def infer_inputs_from_restored_call_function(fn)
-
Returns TensorSpec of inputs from a restored call function.
Args
fn
- Restored layer call function. It is assumed that `fn` has at least one concrete function and that the inputs are in the first argument.
Returns
TensorSpec of call function inputs in the form of (args, kwargs)
Expand source code
def infer_inputs_from_restored_call_function(fn): """Returns TensorSpec of inputs from a restored call function. Args: fn: Restored layer call function. It is assumed that `fn` has at least one concrete function and that the inputs are in the first argument. Returns: TensorSpec of call function inputs in the form of (args, kwargs) """ def common_spec(x, y): if not isinstance(x, tf.TensorSpec): # Doesn't particularly matter what is returned in this case because the # result will be filtered out in _set_input_shape. return x common_shape = get_common_shape(x.shape, y.shape) if isinstance(x, tf.SparseTensorSpec): return tf.SparseTensorSpec(common_shape, x.dtype) elif isinstance(x, tf.RaggedTensorSpec): return tf.RaggedTensorSpec(common_shape, x.dtype) return tf.TensorSpec(common_shape, x.dtype, x.name) spec = fn.concrete_functions[0].structured_input_signature for concrete in fn.concrete_functions[1:]: spec2 = concrete.structured_input_signature spec = tf.nest.map_structure(common_spec, spec, spec2) return spec
def load(path, compile=True, options=None)
-
Loads Keras objects from a SavedModel.
Any Keras layer or model saved to the SavedModel will be loaded back as Keras objects. Other objects are loaded as regular trackable objects (same as `tf.saved_model.load`).
Currently, Keras saving/loading only retains the Keras object's weights, losses, and call function.
The loaded model can be re-compiled, but the original optimizer, compiled loss functions, and metrics are not retained. This is temporary, and `model.save` will soon be able to serialize compiled models.
Args
path
- Path to SavedModel.
compile
- If true, compile the model after loading it.
options
- Optional
tf.saved_model.LoadOptions
object that specifies options for loading from SavedModel.
Returns
Object loaded from SavedModel.
Expand source code
def load(path, compile=True, options=None): # pylint: disable=redefined-builtin """Loads Keras objects from a SavedModel. Any Keras layer or model saved to the SavedModel will be loaded back as Keras objects. Other objects are loaded as regular trackable objects (same as `tf.saved_model.load`). Currently, Keras saving/loading only retains the Keras object's weights, losses, and call function. The loaded model can be re-compiled, but the original optimizer, compiled loss functions, and metrics are not retained. This is temporary, and `model.save` will soon be able to serialize compiled models. Args: path: Path to SavedModel. compile: If true, compile the model after loading it. options: Optional `tf.saved_model.LoadOptions` object that specifies options for loading from SavedModel. Returns: Object loaded from SavedModel. """ # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics. # TODO(kathywu): Add code to load from objects that contain all endpoints # Look for metadata file or parse the SavedModel metadata = saved_metadata_pb2.SavedMetadata() meta_graph_def = tf.__internal__.saved_model.parse_saved_model(path).meta_graphs[0] object_graph_def = meta_graph_def.object_graph_def path_to_metadata_pb = os.path.join(path, constants.SAVED_METADATA_PATH) if tf.compat.v1.gfile.Exists(path_to_metadata_pb): try: with tf.io.gfile.GFile(path_to_metadata_pb, 'rb') as f: file_content = f.read() metadata.ParseFromString(file_content) except message.DecodeError as e: raise IOError('Cannot parse keras metadata {}: {}.' .format(path_to_metadata_pb, str(e))) else: logging.warning('SavedModel saved prior to TF 2.5 detected when loading ' 'Keras model. Please ensure that you are saving the model ' 'with model.save() or tf.keras.models.save_model(), *NOT* ' 'tf.saved_model.save(). 
To confirm, there should be a file ' 'named "keras_metadata.pb" in the SavedModel directory.') _read_legacy_metadata(object_graph_def, metadata) if not metadata.nodes: # When there are no Keras objects, return the results from the core loader return tf.saved_model.load(path, options=options) metadata = _update_to_current_version(metadata) # Recreate layers and metrics using the info stored in the metadata. keras_loader = KerasObjectLoader(metadata, object_graph_def) keras_loader.load_layers(compile=compile) # Generate a dictionary of all loaded nodes. nodes_to_load = {'root': None} for node_id, loaded_node in keras_loader.loaded_nodes.items(): nodes_to_load[keras_loader.get_path(node_id)] = loaded_node loaded = tf.__internal__.saved_model.load_partial(path, nodes_to_load, options=options) # Finalize the loaded layers and remove the extra tracked dependencies. keras_loader.finalize_objects() keras_loader.del_tracking() model = loaded['root'] # pylint: disable=protected-access if isinstance(model, training_lib.Model) and compile: # TODO(kathywu): Use compiled objects from SavedModel, instead of # creating new objects from the training config. training_config = model._serialized_attributes['metadata'].get( 'training_config', None) if training_config is not None: model.compile(**saving_utils.compile_args_from_training_config( training_config), from_serialized=True) saving_utils.try_build_compiled_arguments(model) if isinstance(model.optimizer, optimizer_v2.OptimizerV2): if (model.optimizer.get_slot_names()): logging.warning('Your optimizer uses slots. ' 'Slots cannot be restored from saved_model, ' 'as a result, your model is starting with ' 'a new initialized optimizer.') else: logging.warning('No training configuration found in save file, so the ' 'model was *not* compiled. Compile it manually.') # pylint: enable=protected-access # Force variables and resources to initialize. 
if not tf.executing_eagerly(): sess = backend.get_session() # Variables are initialized by this call. sess.run(tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TABLE_INITIALIZERS)) return model
def recursively_deserialize_keras_object(config, module_objects=None)
-
Deserialize Keras object from a nested structure.
Expand source code
def recursively_deserialize_keras_object(config, module_objects=None): """Deserialize Keras object from a nested structure.""" if isinstance(config, dict): if 'class_name' in config: return generic_utils.deserialize_keras_object( config, module_objects=module_objects) else: return {key: recursively_deserialize_keras_object(config[key], module_objects) for key in config} if isinstance(config, (tuple, list)): return [recursively_deserialize_keras_object(x, module_objects) for x in config] else: raise ValueError('Unable to decode config: {}'.format(config))
def revive_custom_object(identifier, metadata)
-
Revives object from SavedModel.
Expand source code
def revive_custom_object(identifier, metadata): """Revives object from SavedModel.""" if tf.compat.v1.executing_eagerly_outside_functions(): model_class = training_lib.Model else: model_class = training_lib_v1.Model revived_classes = { constants.INPUT_LAYER_IDENTIFIER: ( RevivedInputLayer, input_layer.InputLayer), constants.LAYER_IDENTIFIER: (RevivedLayer, base_layer.Layer), constants.MODEL_IDENTIFIER: (RevivedNetwork, model_class), constants.NETWORK_IDENTIFIER: (RevivedNetwork, functional_lib.Functional), constants.SEQUENTIAL_IDENTIFIER: (RevivedNetwork, models_lib.Sequential), } parent_classes = revived_classes.get(identifier, None) if parent_classes is not None: parent_classes = revived_classes[identifier] revived_cls = type( tf.compat.as_str(metadata['class_name']), parent_classes, {}) return revived_cls._init_from_metadata(metadata) # pylint: disable=protected-access else: raise ValueError('Unable to restore custom object of type {} currently. ' 'Please make sure that the layer implements `get_config`' 'and `from_config` when saving. In addition, please use ' 'the `custom_objects` arg when calling `load_model()`.' .format(identifier))
Classes
class KerasObjectLoader (metadata, object_graph_def)
-
Loader that recreates Keras objects (e.g. layers, models).
Layers and models are revived from either the config or SavedModel following these rules: 1. If object is a graph network (i.e. Sequential or Functional) then it will be initialized using the structure from the config only after the children layers have been created. Graph networks must be initialized with inputs and outputs, so all child layers must be created beforehand. 2. If object's config exists and the class can be found, then revive from config. 3. Object may have already been created if its parent was revived from config. In this case, do nothing. 4. If nothing of the above applies, compose the various artifacts from the SavedModel to create a subclassed layer or model. At this time, custom metrics are not supported.
Expand source code
class KerasObjectLoader(object):
  """Loader that recreates Keras objects (e.g. layers, models).

  Layers and models are revived from either the config or SavedModel following
  these rules:
  1. If object is a graph network (i.e. Sequential or Functional) then it will
     be initialized using the structure from the config only after the children
     layers have been created. Graph networks must be initialized with inputs
     and outputs, so all child layers must be created beforehand.
  2. If object's config exists and the class can be found, then revive from
     config.
  3. Object may have already been created if its parent was revived from
     config. In this case, do nothing.
  4. If nothing of the above applies, compose the various artifacts from the
     SavedModel to create a subclassed layer or model. At this time, custom
     metrics are not supported.
  """

  def __init__(self, metadata, object_graph_def):
    # metadata: SavedMetadata proto; object_graph_def: SavedObjectGraph proto.
    self._metadata = {x.node_id: x for x in metadata.nodes}
    self._proto = object_graph_def

    self._node_paths = {node_data.node_id: node_data.node_path
                        for node_data in metadata.nodes}
    self.loaded_nodes = {}  # Maps node path -> loaded node

    # Store all node ids that have already been traversed when tracking nodes
    # that were recreated from the config.
    self._traversed_nodes_from_config = set()

    # Maps model id -> (blank model obj, list of child layer or their node ids)
    # This tracks all layers in functional and sequential models. These models
    # are only reconstructed after all of their child layers have been created.
    self.model_layer_dependencies = {}
    self._models_to_reconstruct = []

  def del_tracking(self):
    """Removes tracked references that are only used when loading the model."""
    # Now that the node object has been fully loaded, and the checkpoint has
    # been restored, the object no longer needs to track objects added from
    # SerializedAttributes. (Note that saving a training checkpoint still
    # functions correctly, because layers and variables are tracked separately
    # by the Layer object.)
    # TODO(kathywu): Instead of outright deleting these nodes (which would
    # make restoring from a different checkpoint tricky), mark them as extra
    # dependencies that are OK to overwrite.
    for node in self.loaded_nodes.values():
      node = node[0]
      if not isinstance(node, base_layer.Layer):
        # Loaded nodes can contain other trackable objects created when
        # loading layers from the config, such as variables.
        continue
      for name in PUBLIC_ATTRIBUTES:
        node._delete_tracking(name)  # pylint: disable=protected-access

      if isinstance(node, functional_lib.Functional):
        # Delete the temporary layer dependencies, which were used to restore
        # the checkpointed values. When the model is live, the user can delete
        # or add layers to the model at any time, so these layer dependencies
        # may be obsolete.
        dependencies = list(node._self_unconditional_dependency_names)  # pylint: disable=protected-access
        for name in dependencies:
          # NOTE(review): `[\d+]` is a character class (digit or '+'), not
          # "one or more digits"; `re.match` still accepts names like
          # 'layer-3' because the first digit matches — confirm intent.
          if re.match(r'^layer(_with_weights)?-[\d+]', name) is not None:
            node._delete_tracking(name)  # pylint: disable=protected-access

  def _add_children_recreated_from_config(self, obj, proto, node_id):
    """Recursively records objects recreated from config."""
    # pylint: disable=protected-access
    if node_id in self._traversed_nodes_from_config:
      return

    parent_path = self._node_paths[node_id]
    self._traversed_nodes_from_config.add(node_id)
    obj._maybe_initialize_trackable()
    if isinstance(obj, base_layer.Layer) and not obj.built:
      metadata = json_utils.decode(self._metadata[node_id].metadata)
      self._try_build_layer(obj, node_id, metadata.get('build_input_shape'))

    # Create list of all possible children
    children = []
    # Look for direct children
    for reference in proto.children:
      obj_child = obj._lookup_dependency(reference.local_name)
      children.append((obj_child, reference.node_id, reference.local_name))

    # Add metrics that may have been added to the layer._metrics list.
    # This is stored in the SavedModel as layer.keras_api.layer_metrics in
    # SavedModels created after Tf 2.2.
    metric_list_node_id = self._search_for_child_node(
        node_id, [constants.KERAS_ATTR, 'layer_metrics'])
    if metric_list_node_id is not None and hasattr(obj, '_metrics'):
      obj_metrics = {m.name: m for m in obj._metrics}
      for reference in self._proto.nodes[metric_list_node_id].children:
        metric = obj_metrics.get(reference.local_name)
        if metric is not None:
          metric_path = '{}.layer_metrics.{}'.format(constants.KERAS_ATTR,
                                                     reference.local_name)
          children.append((metric, reference.node_id, metric_path))

    for (obj_child, child_id, child_name) in children:
      child_proto = self._proto.nodes[child_id]

      if not isinstance(obj_child, tf.__internal__.tracking.Trackable):
        continue
      # Pick the setter used later when attaching SavedModel attributes:
      # registered revived types get their registered setter; Keras objects
      # get `_revive_setter`; everything else gets plain `setattr`.
      if (child_proto.user_object.identifier in
          tf.__internal__.saved_model.load.registered_identifiers()):
        setter = tf.__internal__.saved_model.load.get_setter(
            child_proto.user_object)
      elif obj_child._object_identifier in constants.KERAS_OBJECT_IDENTIFIERS:
        setter = _revive_setter
      else:
        setter = setattr
        # pylint: enable=protected-access

      if child_id in self.loaded_nodes:
        if self.loaded_nodes[child_id][0] is not obj_child:
          # This means that the same trackable object is referenced by two
          # different objects that were recreated from the config.
          logging.warning(
              'Looks like there is an object (perhaps variable or '
              'layer) that is shared between different layers/models. '
              'This may cause issues when restoring the variable '
              'values. Object: {}'.format(obj_child))
        continue

      # Overwrite variable names with the ones saved in the SavedModel.
      if (child_proto.WhichOneof('kind') == 'variable' and
          child_proto.variable.name):
        obj_child._handle_name = child_proto.variable.name + ':0'  # pylint: disable=protected-access

      if isinstance(obj_child,
                    tf.__internal__.tracking.TrackableDataStructure):
        # Data structures are already fully restored from the config; do not
        # overwrite their contents with SavedModel attributes.
        setter = lambda *args: None

      child_path = '{}.{}'.format(parent_path, child_name)
      self._node_paths[child_id] = child_path
      self._add_children_recreated_from_config(
          obj_child, child_proto, child_id)
      self.loaded_nodes[child_id] = obj_child, setter

  def load_layers(self, compile=True):  # pylint: disable=redefined-builtin
    """Load all layer nodes from the metadata."""
    # Load metrics after models and layers, since it's likely that models
    # and layers will create the metric when initialized (this avoids wasting
    # time by creating objects multiple times).
    metric_list = []
    for node_metadata in self._metadata.values():
      if node_metadata.identifier == constants.METRIC_IDENTIFIER:
        metric_list.append(node_metadata)
        continue

      self.loaded_nodes[node_metadata.node_id] = self._load_layer(
          node_metadata.node_id, node_metadata.identifier,
          node_metadata.metadata)

    for node_metadata in metric_list:
      try:
        self.loaded_nodes[node_metadata.node_id] = self._load_layer(
            node_metadata.node_id, node_metadata.identifier,
            node_metadata.metadata)
      except ValueError:
        # Metrics are only needed when the model is compiled later. We ignore
        # errors when trying to load custom metrics when `compile=False` until
        # custom metrics are serialized properly (b/135550038).
        if compile:
          raise
        logging.warning('Unable to restore custom metric. Please ensure that '
                        'the layer implements `get_config` and `from_config` '
                        'when saving. In addition, please use the '
                        '`custom_objects` arg when calling `load_model()`.')

  def _load_layer(self, node_id, identifier, metadata):
    """Load a single layer from a SavedUserObject proto."""
    metadata = json_utils.decode(metadata)

    # If node was already created
    if node_id in self.loaded_nodes:
      node, setter = self.loaded_nodes[node_id]

      # Revive setter requires the object to have a `_serialized_attributes`
      # property. Add it here.
      _maybe_add_serialized_attributes(node, metadata)

      config = metadata.get('config')
      if _is_graph_network(node) and generic_utils.validate_config(config):
        child_nodes = self._get_child_layer_node_ids(node_id)
        self.model_layer_dependencies[node_id] = (node, child_nodes)
        if not child_nodes:
          self._models_to_reconstruct.append(node_id)
      return node, setter

    # Detect whether this object can be revived from the config. If not, then
    # revive from the SavedModel instead.
    obj, setter = self._revive_from_config(identifier, metadata, node_id)
    if obj is None:
      obj, setter = revive_custom_object(identifier, metadata)

    # Add an attribute that stores the extra functions/objects saved in the
    # SavedModel. Most of these functions/objects are ignored, but some are
    # used later in the loading process (e.g. the list of regularization
    # losses, or the training config of compiled models).
    _maybe_add_serialized_attributes(obj, metadata)
    return obj, setter

  def _revive_from_config(self, identifier, metadata, node_id):
    """Revives a layer/model from config, or returns None."""
    if identifier == constants.METRIC_IDENTIFIER:
      obj = self._revive_metric_from_config(metadata)
    else:
      obj = (
          self._revive_graph_network(identifier, metadata, node_id) or
          self._revive_layer_or_model_from_config(metadata, node_id))

    if obj is None:
      return None, None

    setter = self._config_node_setter(_revive_setter)
    self._add_children_recreated_from_config(
        obj, self._proto.nodes[node_id], node_id)
    return obj, setter

  def _revive_graph_network(self, identifier, metadata, node_id):
    """Revives a graph network from config."""
    # Determine whether the metadata contains information for reviving a
    # functional or Sequential model.
    config = metadata.get('config')
    if not generic_utils.validate_config(config):
      return None

    class_name = tf.compat.as_str(metadata['class_name'])
    if generic_utils.get_registered_object(class_name) is not None:
      # A registered custom class takes precedence; do not revive as a plain
      # graph network.
      return None
    model_is_functional_or_sequential = (
        metadata.get('is_graph_network', False) or
        class_name == 'Sequential' or
        class_name == 'Functional')
    if not model_is_functional_or_sequential:
      return None

    # Revive functional and sequential models as blank model objects for now (
    # must be initialized to enable setattr tracking and attribute caching).
    # Reconstruction of the network is deferred until all of the model's layers
    # have been revived.
    if class_name == 'Sequential':
      model = models_lib.Sequential(name=config['name'])
    # The model is a custom Sequential model.
    elif identifier == constants.SEQUENTIAL_IDENTIFIER:
      # Uses the custom class name, since the config does not have one.
      model = models_lib.Sequential(name=class_name)
    else:
      model = models_lib.Functional(
          inputs=[], outputs=[], name=config['name'])

    # Record this model and its layers. This will later be used to reconstruct
    # the model.
    layers = self._get_child_layer_node_ids(node_id)
    self.model_layer_dependencies[node_id] = (model, layers)
    if not layers:
      self._models_to_reconstruct.append(node_id)
    return model

  def _revive_layer_or_model_from_config(self, metadata, node_id):
    """Revives a layer/custom model from config; returns None if infeasible."""
    # Check that the following requirements are met for reviving from config:
    #    1. Object can be deserialized from config.
    #    2. If the object needs to be built, then the build input shape can be
    #       found.
    class_name = metadata.get('class_name')
    config = metadata.get('config')
    shared_object_id = metadata.get('shared_object_id')
    must_restore_from_config = metadata.get('must_restore_from_config')
    if not generic_utils.validate_config(config):
      return None

    try:
      obj = layers_module.deserialize(
          generic_utils.serialize_keras_class_and_config(
              class_name, config, shared_object_id=shared_object_id))
    except ValueError:
      if must_restore_from_config:
        raise RuntimeError(
            'Unable to restore a layer of class {cls}. Layers of '
            'class {cls} require that the class be provided to '
            'the model loading code, either by registering the '
            'class using @keras.utils.register_keras_serializable '
            'on the class def and including that file in your '
            'program, or by passing the class in a '
            'keras.utils.CustomObjectScope that wraps this load '
            'call.'.format(cls=class_name))
      else:
        return None

    # Use the dtype, name, and trainable status. Often times these are not
    # specified in custom configs, so retrieve their values from the metadata.
    # pylint: disable=protected-access
    obj._name = metadata['name']
    if metadata.get('trainable') is not None:
      obj.trainable = metadata['trainable']
    if metadata.get('dtype') is not None:
      obj._set_dtype_policy(metadata['dtype'])
    if metadata.get('stateful') is not None:
      obj.stateful = metadata['stateful']
    # Restore model save spec for subclassed models. (layers do not store a
    # SaveSpec)
    if isinstance(obj, training_lib.Model):
      full_save_spec = metadata.get('full_save_spec')
      if full_save_spec is not None:
        args_spec, kwargs_spec = full_save_spec
        inputs_spec = args_spec.pop(0)
        obj._set_save_spec(inputs_spec, args_spec, kwargs_spec)
    # pylint: enable=protected-access

    build_input_shape = metadata.get('build_input_shape')
    built = self._try_build_layer(obj, node_id, build_input_shape)

    if not built:
      # If the layer cannot be built, revive a custom layer instead.
      return None
    return obj

  def _revive_metric_from_config(self, metadata):
    """Revives a metric object using the config saved in the metadata."""
    class_name = tf.compat.as_str(metadata['class_name'])
    config = metadata.get('config')

    if not generic_utils.validate_config(config):
      return None

    try:
      obj = metrics.deserialize(
          generic_utils.serialize_keras_class_and_config(class_name, config))
    except ValueError:
      return None

    build_input_shape = metadata.get('build_input_shape')
    if build_input_shape is not None and hasattr(obj, '_build'):
      obj._build(build_input_shape)  # pylint: disable=protected-access

    return obj

  def _try_build_layer(self, obj, node_id, build_input_shape):
    """Attempts to build the layer."""
    if obj.built or hasattr(obj.build, '_is_default'):
      obj.built = True
      return True

    if build_input_shape is None:
      build_input_shape = self._infer_inputs(node_id, convert_to_shapes=True)

    if build_input_shape is not None:
      obj.build(build_input_shape)
      # Also run the base Layer build to set `built`/input spec bookkeeping,
      # in case the subclass build did not chain up.
      base_layer.Layer.build(obj, build_input_shape)
      return True

    return False

  def _load_edges(self):
    """Add edges for all nodes that are not waiting on initialization."""
    # NOTE(review): `_add_object_graph_edges` is not defined in this class's
    # visible source — presumably provided by a subclass or mixin; confirm.
    for node_id, proto in enumerate(self._proto.nodes):
      if node_id not in self.model_layer_dependencies:
        self._add_object_graph_edges(proto, node_id)

  def get_path(self, node_id):
    """Returns the saved node path for the node with id `node_id`."""
    return self._node_paths[node_id]

  def finalize_objects(self):
    """Finish setting up Keras objects.

    This function is executed after all objects and functions have been
    created. Call functions and losses are attached to each layer, and once
    all layers have been fully set up, graph networks are initialized.

    Subclassed models that are revived from the SavedModel are treated like
    layers, and have their call/loss functions attached here.
    """
    # Finish setting up layers and subclassed models. This step attaches call
    # functions and losses to each object, and sets model inputs/outputs.
    layers_revived_from_config = []
    layers_revived_from_saved_model = []
    for node_id, (node, _) in self.loaded_nodes.items():
      if (not isinstance(node, base_layer.Layer) or
          # Don't finalize models until all layers have finished loading.
          node_id in self.model_layer_dependencies):
        continue

      self._unblock_model_reconstruction(node_id, node)

      if isinstance(node, input_layer.InputLayer):
        continue
      elif isinstance(node, metrics.Metric):
        continue

      if isinstance(node, (RevivedLayer, RevivedInputLayer)):
        layers_revived_from_saved_model.append(node)
      else:
        layers_revived_from_config.append(node)

    _finalize_saved_model_layers(layers_revived_from_saved_model)
    _finalize_config_layers(layers_revived_from_config)

    # Initialize graph networks, now that layer dependencies have been
    # resolved.
    self._reconstruct_all_models()

  def _unblock_model_reconstruction(self, layer_id, layer):
    """Removes layer from blocking model reconstruction."""
    for model_id, v in self.model_layer_dependencies.items():
      _, layers = v
      if layer_id not in layers:
        continue
      layers[layers.index(layer_id)] = layer
      if all(isinstance(x, base_layer.Layer) for x in layers):
        self._models_to_reconstruct.append(model_id)

  def _reconstruct_all_models(self):
    """Reconstructs the network structure of all models."""
    all_initialized_models = set()
    while self._models_to_reconstruct:
      model_id = self._models_to_reconstruct.pop(0)
      all_initialized_models.add(model_id)
      model, layers = self.model_layer_dependencies[model_id]
      self._reconstruct_model(model_id, model, layers)
      _finalize_config_layers([model])

    if all_initialized_models != set(self.model_layer_dependencies.keys()):
      # This should not happen.
      uninitialized_model_ids = (
          set(self.model_layer_dependencies.keys()) - all_initialized_models)
      uninitialized_model_names = [
          self.model_layer_dependencies[model_id][0].name
          for model_id in uninitialized_model_ids]
      raise ValueError('Error when loading from SavedModel -- the following '
                       'models could not be initialized: {}'
                       .format(uninitialized_model_names))

  def _reconstruct_model(self, model_id, model, layers):
    """Reconstructs the network structure."""
    config = json_utils.decode(self._metadata[model_id].metadata)['config']

    # Set up model inputs
    if model.inputs:
      # Inputs may already be created if the model is instantiated in another
      # object's __init__.
      pass
    elif isinstance(model, models_lib.Sequential):
      if not layers or not isinstance(layers[0], input_layer.InputLayer):
        if config['layers'][0]['class_name'] == 'InputLayer':
          layers.insert(0, input_layer.InputLayer.from_config(
              config['layers'][0]['config']))
        elif 'batch_input_shape' in config['layers'][0]['config']:
          batch_input_shape = config['layers'][0]['config']['batch_input_shape']
          layers.insert(0, input_layer.InputLayer(
              input_shape=batch_input_shape[1:],
              batch_size=batch_input_shape[0],
              dtype=layers[0].dtype,
              name=layers[0].name + '_input'))
      model.__init__(layers, name=config['name'])
      if not model.inputs:
        first_layer = self._get_child_layer_node_ids(model_id)[0]
        input_specs = self._infer_inputs(first_layer)
        input_shapes = self._infer_inputs(first_layer, convert_to_shapes=True)
        model._set_inputs(input_specs)  # pylint: disable=protected-access
        if not model.built and not isinstance(input_specs, dict):
          model.build(input_shapes)
    else:  # Reconstruct functional model
      (inputs, outputs,
       created_layers) = functional_lib.reconstruct_from_config(
           config, created_layers={layer.name: layer for layer in layers})
      model.__init__(inputs, outputs, name=config['name'])
      functional_lib.connect_ancillary_layers(model, created_layers)

    # Set model dtype.
    _set_network_attributes_from_metadata(model)

    # Unblock models that are dependent on this model.
    self._unblock_model_reconstruction(model_id, model)

  def _get_child_layer_node_ids(self, node_id):
    """Returns the node ids of each layer in a Sequential/Functional model."""
    # Sequential and Functional track layers with names following the format
    # "layer-N". Use this to generate the list of layers.
    num_layers = 0
    child_layers = {}
    pattern = re.compile('layer-(\\d+)')

    for child in self._proto.nodes[node_id].children:
      m = pattern.match(child.local_name)
      if m is None:
        continue
      layer_n = int(m.group(1))
      num_layers = max(layer_n + 1, num_layers)
      child_layers[layer_n] = child.node_id

    ordered = []
    for n in range(num_layers):
      child = child_layers.get(n)
      if child is None:
        # A gap in "layer-N" numbering ends the ordered prefix.
        break
      ordered.append(child)
    return ordered

  def _search_for_child_node(self, parent_id, path_to_child):
    """Returns node id of child node.

    A helper method for traversing the object graph proto.

    As an example, say that the object graph proto in the SavedModel contains
    an object with the following child and grandchild attributes:

    `parent.child_a.child_b`

    This method can be used to retrieve the node id of `child_b` using the
    parent's node id by calling:

    `_search_for_child_node(parent_id, ['child_a', 'child_b'])`.

    Args:
      parent_id: node id of parent node
      path_to_child: list of children names.

    Returns:
      node_id of child, or None if child isn't found.
    """
    if not path_to_child:
      return parent_id

    for child in self._proto.nodes[parent_id].children:
      if child.local_name == path_to_child[0]:
        return self._search_for_child_node(child.node_id, path_to_child[1:])
    return None

  def _infer_inputs(self, layer_node_id, convert_to_shapes=False):
    """Infers input shape of layer from SavedModel functions."""
    coder = tf.__internal__.saved_model.StructureCoder()
    call_fn_id = self._search_for_child_node(
        layer_node_id, ['call_and_return_all_conditional_losses'])
    if call_fn_id is None:
      return None

    concrete_functions = (
        self._proto.nodes[call_fn_id].function.concrete_functions)
    if not concrete_functions:
      return None
    call_fn_name = concrete_functions[0]
    call_fn_proto = self._proto.concrete_functions[call_fn_name]
    structured_input_signature = coder.decode_proto(
        call_fn_proto.canonicalized_input_signature)
    inputs = structured_input_signature[0][0]
    if convert_to_shapes:
      return tf.nest.map_structure(lambda spec: spec.shape, inputs)
    else:
      return inputs

  def _config_node_setter(self, setter):
    """Creates edges for nodes that are recreated from config."""
    def setattr_wrapper(obj, name, value):
      # Avoid overwriting attributes of objects recreated from the config.
      if obj._lookup_dependency(name) is None:  # pylint: disable=protected-access
        setter(obj, name, value)
    return setattr_wrapper
Methods
def del_tracking(self)
-
Removes tracked references that are only used when loading the model.
Expand source code
def del_tracking(self):
  """Removes tracked references that are only used when loading the model."""
  # Once loading is complete and the checkpoint values have been restored,
  # the objects that were tracked for SerializedAttributes are no longer
  # needed. (Saving a training checkpoint still works: layers and variables
  # are tracked separately by the Layer object.)
  # TODO(kathywu): Rather than deleting these nodes outright (which would
  # complicate restoring from a different checkpoint), mark them as extra
  # dependencies that are OK to overwrite.
  layer_dep_pattern = re.compile(r'^layer(_with_weights)?-[\d+]')
  for loaded in self.loaded_nodes.values():
    tracked = loaded[0]
    if not isinstance(tracked, base_layer.Layer):
      # Loaded nodes may also hold other trackables (e.g. variables) that
      # were created while loading layers from the config.
      continue
    for attr_name in PUBLIC_ATTRIBUTES:
      tracked._delete_tracking(attr_name)  # pylint: disable=protected-access

    if isinstance(tracked, functional_lib.Functional):
      # Drop the temporary layer dependencies that were only needed to
      # restore checkpointed values; the user may add or remove layers at
      # any time, so these dependencies can become stale.
      for dep_name in list(tracked._self_unconditional_dependency_names):  # pylint: disable=protected-access
        if layer_dep_pattern.match(dep_name) is not None:
          tracked._delete_tracking(dep_name)  # pylint: disable=protected-access
def finalize_objects(self)
-
Finish setting up Keras objects.
This function is executed after all objects and functions have been created. Call functions and losses are attached to each layer, and once all layers have been fully set up, graph networks are initialized.
Subclassed models that are revived from the SavedModel are treated like layers, and have their call/loss functions attached here.
Expand source code
def finalize_objects(self):
  """Finish setting up Keras objects.

  Runs after every object and function has been created: call functions and
  losses are attached to each layer, and graph networks are initialized once
  all of their layers are fully set up.

  Subclassed models revived from the SavedModel are handled like layers and
  get their call/loss functions attached here as well.
  """
  revived_from_config = []
  revived_from_saved_model = []
  for node_id, (node, _) in self.loaded_nodes.items():
    if not isinstance(node, base_layer.Layer):
      continue
    if node_id in self.model_layer_dependencies:
      # Models are finalized only after all of their layers have loaded.
      continue

    self._unblock_model_reconstruction(node_id, node)

    if isinstance(node, (input_layer.InputLayer, metrics.Metric)):
      continue

    if isinstance(node, (RevivedLayer, RevivedInputLayer)):
      revived_from_saved_model.append(node)
    else:
      revived_from_config.append(node)

  _finalize_saved_model_layers(revived_from_saved_model)
  _finalize_config_layers(revived_from_config)

  # Layer dependencies are now resolved, so graph networks can be built.
  self._reconstruct_all_models()
def get_path(self, node_id)
-
Expand source code
def get_path(self, node_id):
  """Returns the saved node path recorded for `node_id`."""
  node_paths = self._node_paths
  return node_paths[node_id]
def load_layers(self, compile=True)
-
Load all layer nodes from the metadata.
Expand source code
def load_layers(self, compile=True):  # pylint: disable=redefined-builtin
  """Load all layer nodes from the metadata."""
  # Metrics are deferred to a second pass: models and layers typically
  # create their metrics during initialization, so loading metrics last
  # avoids constructing the same objects twice.
  deferred_metrics = []
  for node_metadata in self._metadata.values():
    if node_metadata.identifier == constants.METRIC_IDENTIFIER:
      deferred_metrics.append(node_metadata)
    else:
      self.loaded_nodes[node_metadata.node_id] = self._load_layer(
          node_metadata.node_id, node_metadata.identifier,
          node_metadata.metadata)

  for node_metadata in deferred_metrics:
    try:
      self.loaded_nodes[node_metadata.node_id] = self._load_layer(
          node_metadata.node_id, node_metadata.identifier,
          node_metadata.metadata)
    except ValueError:
      # Metrics only matter once the model is compiled; tolerate failures
      # when `compile=False` until custom metrics are serialized properly
      # (b/135550038).
      if compile:
        raise
      logging.warning('Unable to restore custom metric. Please ensure that '
                      'the layer implements `get_config` and `from_config` '
                      'when saving. In addition, please use the '
                      '`custom_objects` arg when calling `load_model()`.')
class RevivedInputLayer
-
InputLayer loaded from a SavedModel.
Expand source code
class RevivedInputLayer(object):
  """InputLayer loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Revives the saved InputLayer from the Metadata."""
    revived_obj = cls(
        name=metadata['name'],
        dtype=metadata['dtype'],
        sparse=metadata['sparse'],
        ragged=metadata['ragged'],
        batch_input_shape=metadata['batch_input_shape'])
    # Record the original config without adding a checkpoint dependency.
    with utils.no_automatic_dependency_tracking_scope(revived_obj):
      revived_obj._config = metadata['config']  # pylint:disable=protected-access

    return revived_obj, setattr

  def get_config(self):
    """Returns the config captured when the layer was saved."""
    return self._config
Methods
def get_config(self)
-
Expand source code
def get_config(self):
  """Returns the config recorded from the SavedModel metadata."""
  saved_config = self._config
  return saved_config
class RevivedLayer
-
Keras layer loaded from a SavedModel.
Expand source code
class RevivedLayer(object):
  """Keras layer loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Create revived layer from metadata stored in the SavedModel proto."""
    init_args = dict(name=metadata['name'], trainable=metadata['trainable'])
    # dtype and batch_input_shape are passed through only when present.
    for optional_key in ('dtype', 'batch_input_shape'):
      if metadata.get(optional_key) is not None:
        init_args[optional_key] = metadata[optional_key]

    layer = cls(**init_args)

    # Restore saved attributes without creating checkpoint dependencies.
    with utils.no_automatic_dependency_tracking_scope(layer):
      # pylint:disable=protected-access
      layer._expects_training_arg = metadata['expects_training_arg']
      config = metadata.get('config')
      if generic_utils.validate_config(config):
        layer._config = config
      if metadata.get('input_spec') is not None:
        layer.input_spec = recursively_deserialize_keras_object(
            metadata['input_spec'],
            module_objects={'InputSpec': input_spec.InputSpec})
      if metadata.get('activity_regularizer') is not None:
        layer.activity_regularizer = regularizers.deserialize(
            metadata['activity_regularizer'])
      if metadata.get('_is_feature_layer') is not None:
        layer._is_feature_layer = metadata['_is_feature_layer']
      if metadata.get('stateful') is not None:
        layer.stateful = metadata['stateful']
      # pylint:enable=protected-access

    return layer, _revive_setter

  @property
  def keras_api(self):
    """Serialized `keras_api` attributes, or None when absent."""
    return self._serialized_attributes.get(constants.KERAS_ATTR, None)

  def get_config(self):
    """Returns the saved config; raises if none was recorded."""
    if not hasattr(self, '_config'):
      raise NotImplementedError
    return self._config
Subclasses
Instance variables
var keras_api
-
Expand source code
@property
def keras_api(self):
  """The revived `keras_api` attribute bundle, or None when absent."""
  serialized = self._serialized_attributes
  return serialized.get(constants.KERAS_ATTR, None)
Methods
def get_config(self)
-
Expand source code
def get_config(self):
  """Returns the saved config, raising NotImplementedError when absent."""
  if not hasattr(self, '_config'):
    raise NotImplementedError
  return self._config
class RevivedNetwork
-
Keras network of layers loaded from a SavedModel.
Expand source code
class RevivedNetwork(RevivedLayer):
  """Keras network of layers loaded from a SavedModel."""

  @classmethod
  def _init_from_metadata(cls, metadata):
    """Create revived network from metadata stored in the SavedModel proto."""
    network = cls(name=metadata['name'])

    # Attributes revived from SerializedAttributes (the CommonEndpoints names
    # and the keras-specific "keras_api" names) are stored in an un-tracked
    # dictionary so they do not create checkpoint dependencies.
    with utils.no_automatic_dependency_tracking_scope(network):
      # pylint:disable=protected-access
      network._expects_training_arg = metadata['expects_training_arg']
      config = metadata.get('config')
      if generic_utils.validate_config(config):
        network._config = config

      regularizer_config = metadata.get('activity_regularizer')
      if regularizer_config is not None:
        network.activity_regularizer = regularizers.deserialize(
            regularizer_config)
      # pylint:enable=protected-access

    return network, _revive_setter  # pylint:disable=protected-access
Ancestors