Module keras.saving.saved_model.layer_serialization
Classes and functions implementing Layer SavedModel serialization.
Expand source code
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and functions implementing Layer SavedModel serialization."""
import tensorflow.compat.v2 as tf
from keras.mixed_precision import policy
from keras.saving.saved_model import base_serialization
from keras.saving.saved_model import constants
from keras.saving.saved_model import save_impl
from keras.saving.saved_model import serialized_attributes
from keras.utils import generic_utils
class LayerSavedModelSaver(base_serialization.SavedModelSaver):
  """Implements SavedModel serialization for Keras layers.

  Supplies the python-side metadata recorded for a layer
  (`python_properties`) and the trackable objects and functions attached to
  it in the SavedModel (`objects_to_serialize`, `functions_to_serialize`).
  """

  @property
  def object_identifier(self):
    """Identifier recorded for this object: `constants.LAYER_IDENTIFIER`."""
    return constants.LAYER_IDENTIFIER

  @property
  def python_properties(self):
    """Metadata dict built by `_python_properties_internal`."""
    # TODO(kathywu): Add python property validator
    return self._python_properties_internal()

  def _python_properties_internal(self):
    """Collects the python properties describing the wrapped layer.

    Returns:
      A dict with the layer's name, trainable flag, whether it expects a
      `training` argument, serialized dtype policy, batch input shape,
      `stateful` flag and `must_restore_from_config` flag, updated with the
      result of `get_serialized(self.obj)`. The `input_spec`,
      `activity_regularizer` and `build_input_shape` keys are added only when
      the layer provides those values.
    """
    # TODO(kathywu): Add support for metrics serialization.
    # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
    # the python config serialization has caught up.
    layer = self.obj
    metadata = {
        'name': layer.name,
        'trainable': layer.trainable,
        'expects_training_arg': layer._expects_training_arg,  # pylint: disable=protected-access
        'dtype': policy.serialize(layer._dtype_policy),  # pylint: disable=protected-access
        'batch_input_shape': getattr(layer, '_batch_input_shape', None),
        'stateful': layer.stateful,
        'must_restore_from_config': layer._must_restore_from_config,  # pylint: disable=protected-access
    }
    # dict.update: any key also returned by get_serialized overwrites the
    # value set above.
    metadata.update(get_serialized(layer))

    input_spec = layer.input_spec
    if input_spec is not None:
      # The input_spec property setter has already type-checked this value.
      def _serialize_spec(spec):
        return generic_utils.serialize_keras_object(spec) if spec else None

      metadata['input_spec'] = tf.nest.map_structure(_serialize_spec,
                                                     input_spec)

    regularizer = layer.activity_regularizer
    # Only regularizers that expose a config can be serialized.
    if regularizer is not None and hasattr(regularizer, 'get_config'):
      metadata['activity_regularizer'] = (
          generic_utils.serialize_keras_object(regularizer))

    build_shape = layer._build_input_shape  # pylint: disable=protected-access
    if build_shape is not None:
      metadata['build_input_shape'] = build_shape
    return metadata

  def objects_to_serialize(self, serialization_cache):
    """Trackable objects from the layer's (cached) serialized attributes."""
    attributes = self._get_serialized_attributes(serialization_cache)
    return attributes.objects_to_serialize

  def functions_to_serialize(self, serialization_cache):
    """Functions from the layer's (cached) serialized attributes."""
    attributes = self._get_serialized_attributes(serialization_cache)
    return attributes.functions_to_serialize

  def _get_serialized_attributes(self, serialization_cache):
    """Returns this layer's `SerializedAttributes`, creating them at most once.

    Results are memoized per layer in the dict stored under
    `constants.KERAS_CACHE_KEY` inside `serialization_cache`. Objects and
    functions are not populated when `save_impl.should_skip_serialization`
    is true or the layer has `_must_restore_from_config` set.
    """
    keras_cache = serialization_cache.setdefault(constants.KERAS_CACHE_KEY, {})
    if self.obj in keras_cache:
      return keras_cache[self.obj]

    serialized_attr = serialized_attributes.SerializedAttributes.new(self.obj)
    # The entry is cached before it is populated. Presumably this is so that
    # re-entrant lookups for this layer during wrapping hit the cache rather
    # than recursing -- NOTE(review): confirm.
    keras_cache[self.obj] = serialized_attr

    skip_contents = (save_impl.should_skip_serialization(self.obj) or
                     self.obj._must_restore_from_config)  # pylint: disable=protected-access
    if not skip_contents:
      object_dict, function_dict = self._get_serialized_attributes_internal(
          serialization_cache)
      serialized_attr.set_and_validate_objects(object_dict)
      serialized_attr.set_and_validate_functions(function_dict)
    return serialized_attr

  def _get_serialized_attributes_internal(self, serialization_cache):
    """Returns `(objects, functions)` dicts wrapped from the layer."""
    objects = save_impl.wrap_layer_objects(self.obj, serialization_cache)
    functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
    # The attribute validator expects the default save signature key to be
    # present in the function dict, even though its value is None.
    functions['_default_save_signature'] = None
    return objects, functions
# TODO(kathywu): Move serialization utils (and related utils from
# generic_utils.py) to a separate file.
def get_serialized(obj):
  """Serializes `obj` with `generic_utils.serialize_keras_object`.

  The call runs inside `generic_utils.skip_failed_serialization()`. The
  resulting config may be used when reviving the object: loading first tries
  to revive from this config and, if that fails, falls back to reviving the
  object from the SavedModel.
  """
  with generic_utils.skip_failed_serialization():
    serialized = generic_utils.serialize_keras_object(obj)
    return serialized
class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
  """SavedModel serialization for `InputLayer`.

  Only python-side metadata is recorded; no extra trackable objects or
  functions are attached.
  """

  @property
  def object_identifier(self):
    """Identifier recorded for this object: `constants.INPUT_LAYER_IDENTIFIER`."""
    return constants.INPUT_LAYER_IDENTIFIER

  @property
  def python_properties(self):
    """Class name, name, dtype, sparse/ragged flags, shape and config."""
    layer = self.obj
    return {
        'class_name': type(layer).__name__,
        'name': layer.name,
        'dtype': layer.dtype,
        'sparse': layer.sparse,
        'ragged': layer.ragged,
        'batch_input_shape': layer._batch_input_shape,  # pylint: disable=protected-access
        'config': layer.get_config(),
    }

  def objects_to_serialize(self, serialization_cache):
    """Always an empty dict; the cache is not consulted."""
    del serialization_cache  # Unused.
    return {}

  def functions_to_serialize(self, serialization_cache):
    """Always an empty dict; the cache is not consulted."""
    del serialization_cache  # Unused.
    return {}
class RNNSavedModelSaver(LayerSavedModelSaver):
  """SavedModel serialization for RNN layers; also saves `states`."""

  @property
  def object_identifier(self):
    """Identifier recorded for this object: `constants.RNN_LAYER_IDENTIFIER`."""
    return constants.RNN_LAYER_IDENTIFIER

  def _get_serialized_attributes_internal(self, serialization_cache):
    """Extends the base layer attributes with a trackable `states` entry."""
    objects, functions = super()._get_serialized_attributes_internal(
        serialization_cache)
    wrapped_states = tf.__internal__.tracking.wrap(self.obj.states)
    # SavedModel requires every saved object to be Trackable. A tuple that is
    # still a tuple after wrapping holds no trackable items (e.g. an empty
    # tuple, or (None, None) for a stateless ConvLSTM2D), so it is converted
    # to a list and wrapped again to obtain a Trackable. When loaded,
    # ConvLSTM2D is able to handle the tuple/list conversion.
    if isinstance(wrapped_states, tuple):
      wrapped_states = tf.__internal__.tracking.wrap(list(wrapped_states))
    objects['states'] = wrapped_states
    return objects, functions
class IndexLookupLayerSavedModelSaver(LayerSavedModelSaver):
  """SavedModel serialization for index lookup layers."""

  @property
  def python_properties(self):
    """Base layer metadata with the vocabulary removed from the config.

    The config's `vocabulary` is set to None and `has_input_vocabulary` is
    set from the layer's `_has_input_vocabulary` attribute.
    """
    # TODO(kathywu): Add python property validator
    metadata = self._python_properties_internal()
    config = metadata['config']
    # The vocabulary is saved with the lookup table itself, which correctly
    # handles saving vocabulary files as a SavedModel asset, so it is
    # cleared from the config here.
    config['vocabulary'] = None
    # Separately record whether a vocabulary was passed in rather than adapted.
    config['has_input_vocabulary'] = self.obj._has_input_vocabulary  # pylint: disable=protected-access
    return metadata
Functions
def get_serialized(obj)
-
Expand source code
def get_serialized(obj):
  """Serializes `obj` with `generic_utils.serialize_keras_object`.

  The call runs inside `generic_utils.skip_failed_serialization()`. The
  resulting config may be used when reviving the object: loading first tries
  to revive from this config and, if that fails, falls back to reviving the
  object from the SavedModel.
  """
  with generic_utils.skip_failed_serialization():
    serialized = generic_utils.serialize_keras_object(obj)
    return serialized
Classes
class IndexLookupLayerSavedModelSaver (obj)
-
Index lookup layer serialization.
Expand source code
class IndexLookupLayerSavedModelSaver(LayerSavedModelSaver):
  """SavedModel serialization for index lookup layers."""

  @property
  def python_properties(self):
    """Base layer metadata with the vocabulary removed from the config.

    The config's `vocabulary` is set to None and `has_input_vocabulary` is
    set from the layer's `_has_input_vocabulary` attribute.
    """
    # TODO(kathywu): Add python property validator
    metadata = self._python_properties_internal()
    config = metadata['config']
    # The vocabulary is saved with the lookup table itself, which correctly
    # handles saving vocabulary files as a SavedModel asset, so it is
    # cleared from the config here.
    config['vocabulary'] = None
    # Separately record whether a vocabulary was passed in rather than adapted.
    config['has_input_vocabulary'] = self.obj._has_input_vocabulary  # pylint: disable=protected-access
    return metadata
Ancestors: LayerSavedModelSaver, keras.saving.saved_model.base_serialization.SavedModelSaver
Inherited members
class InputLayerSavedModelSaver (obj)
-
InputLayer serialization.
Expand source code
class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
  """SavedModel serialization for `InputLayer`.

  Only python-side metadata is recorded; no extra trackable objects or
  functions are attached.
  """

  @property
  def object_identifier(self):
    """Identifier recorded for this object: `constants.INPUT_LAYER_IDENTIFIER`."""
    return constants.INPUT_LAYER_IDENTIFIER

  @property
  def python_properties(self):
    """Class name, name, dtype, sparse/ragged flags, shape and config."""
    layer = self.obj
    return {
        'class_name': type(layer).__name__,
        'name': layer.name,
        'dtype': layer.dtype,
        'sparse': layer.sparse,
        'ragged': layer.ragged,
        'batch_input_shape': layer._batch_input_shape,  # pylint: disable=protected-access
        'config': layer.get_config(),
    }

  def objects_to_serialize(self, serialization_cache):
    """Always an empty dict; the cache is not consulted."""
    del serialization_cache  # Unused.
    return {}

  def functions_to_serialize(self, serialization_cache):
    """Always an empty dict; the cache is not consulted."""
    del serialization_cache  # Unused.
    return {}
Ancestors: keras.saving.saved_model.base_serialization.SavedModelSaver
Inherited members
class LayerSavedModelSaver (obj)
-
Implements Layer SavedModel serialization.
Expand source code
class LayerSavedModelSaver(base_serialization.SavedModelSaver):
  """Implements SavedModel serialization for Keras layers.

  Supplies the python-side metadata recorded for a layer
  (`python_properties`) and the trackable objects and functions attached to
  it in the SavedModel (`objects_to_serialize`, `functions_to_serialize`).
  """

  @property
  def object_identifier(self):
    """Identifier recorded for this object: `constants.LAYER_IDENTIFIER`."""
    return constants.LAYER_IDENTIFIER

  @property
  def python_properties(self):
    """Metadata dict built by `_python_properties_internal`."""
    # TODO(kathywu): Add python property validator
    return self._python_properties_internal()

  def _python_properties_internal(self):
    """Collects the python properties describing the wrapped layer.

    Returns:
      A dict with the layer's name, trainable flag, whether it expects a
      `training` argument, serialized dtype policy, batch input shape,
      `stateful` flag and `must_restore_from_config` flag, updated with the
      result of `get_serialized(self.obj)`. The `input_spec`,
      `activity_regularizer` and `build_input_shape` keys are added only when
      the layer provides those values.
    """
    # TODO(kathywu): Add support for metrics serialization.
    # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
    # the python config serialization has caught up.
    layer = self.obj
    metadata = {
        'name': layer.name,
        'trainable': layer.trainable,
        'expects_training_arg': layer._expects_training_arg,  # pylint: disable=protected-access
        'dtype': policy.serialize(layer._dtype_policy),  # pylint: disable=protected-access
        'batch_input_shape': getattr(layer, '_batch_input_shape', None),
        'stateful': layer.stateful,
        'must_restore_from_config': layer._must_restore_from_config,  # pylint: disable=protected-access
    }
    # dict.update: any key also returned by get_serialized overwrites the
    # value set above.
    metadata.update(get_serialized(layer))

    input_spec = layer.input_spec
    if input_spec is not None:
      # The input_spec property setter has already type-checked this value.
      def _serialize_spec(spec):
        return generic_utils.serialize_keras_object(spec) if spec else None

      metadata['input_spec'] = tf.nest.map_structure(_serialize_spec,
                                                     input_spec)

    regularizer = layer.activity_regularizer
    # Only regularizers that expose a config can be serialized.
    if regularizer is not None and hasattr(regularizer, 'get_config'):
      metadata['activity_regularizer'] = (
          generic_utils.serialize_keras_object(regularizer))

    build_shape = layer._build_input_shape  # pylint: disable=protected-access
    if build_shape is not None:
      metadata['build_input_shape'] = build_shape
    return metadata

  def objects_to_serialize(self, serialization_cache):
    """Trackable objects from the layer's (cached) serialized attributes."""
    attributes = self._get_serialized_attributes(serialization_cache)
    return attributes.objects_to_serialize

  def functions_to_serialize(self, serialization_cache):
    """Functions from the layer's (cached) serialized attributes."""
    attributes = self._get_serialized_attributes(serialization_cache)
    return attributes.functions_to_serialize

  def _get_serialized_attributes(self, serialization_cache):
    """Returns this layer's `SerializedAttributes`, creating them at most once.

    Results are memoized per layer in the dict stored under
    `constants.KERAS_CACHE_KEY` inside `serialization_cache`. Objects and
    functions are not populated when `save_impl.should_skip_serialization`
    is true or the layer has `_must_restore_from_config` set.
    """
    keras_cache = serialization_cache.setdefault(constants.KERAS_CACHE_KEY, {})
    if self.obj in keras_cache:
      return keras_cache[self.obj]

    serialized_attr = serialized_attributes.SerializedAttributes.new(self.obj)
    # The entry is cached before it is populated. Presumably this is so that
    # re-entrant lookups for this layer during wrapping hit the cache rather
    # than recursing -- NOTE(review): confirm.
    keras_cache[self.obj] = serialized_attr

    skip_contents = (save_impl.should_skip_serialization(self.obj) or
                     self.obj._must_restore_from_config)  # pylint: disable=protected-access
    if not skip_contents:
      object_dict, function_dict = self._get_serialized_attributes_internal(
          serialization_cache)
      serialized_attr.set_and_validate_objects(object_dict)
      serialized_attr.set_and_validate_functions(function_dict)
    return serialized_attr

  def _get_serialized_attributes_internal(self, serialization_cache):
    """Returns `(objects, functions)` dicts wrapped from the layer."""
    objects = save_impl.wrap_layer_objects(self.obj, serialization_cache)
    functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
    # The attribute validator expects the default save signature key to be
    # present in the function dict, even though its value is None.
    functions['_default_save_signature'] = None
    return objects, functions
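Ancestors: keras.saving.saved_model.base_serialization.SavedModelSaver
Subclasses: IndexLookupLayerSavedModelSaver, RNNSavedModelSaver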
Inherited members
class RNNSavedModelSaver (obj)
-
RNN layer serialization.
Expand source code
class RNNSavedModelSaver(LayerSavedModelSaver):
  """SavedModel serialization for RNN layers; also saves `states`."""

  @property
  def object_identifier(self):
    """Identifier recorded for this object: `constants.RNN_LAYER_IDENTIFIER`."""
    return constants.RNN_LAYER_IDENTIFIER

  def _get_serialized_attributes_internal(self, serialization_cache):
    """Extends the base layer attributes with a trackable `states` entry."""
    objects, functions = super()._get_serialized_attributes_internal(
        serialization_cache)
    wrapped_states = tf.__internal__.tracking.wrap(self.obj.states)
    # SavedModel requires every saved object to be Trackable. A tuple that is
    # still a tuple after wrapping holds no trackable items (e.g. an empty
    # tuple, or (None, None) for a stateless ConvLSTM2D), so it is converted
    # to a list and wrapped again to obtain a Trackable. When loaded,
    # ConvLSTM2D is able to handle the tuple/list conversion.
    if isinstance(wrapped_states, tuple):
      wrapped_states = tf.__internal__.tracking.wrap(list(wrapped_states))
    objects['states'] = wrapped_states
    return objects, functions
Ancestors: LayerSavedModelSaver, keras.saving.saved_model.base_serialization.SavedModelSaver
Inherited members