Module keras.initializers
Keras initializer serialization / deserialization.
Expand source code
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras initializer serialization / deserialization."""
import tensorflow.compat.v2 as tf
import threading
from tensorflow.python import tf2
from keras.initializers import initializers_v1
from keras.initializers import initializers_v2
from keras.utils import generic_utils
from keras.utils import tf_inspect as inspect
from tensorflow.python.ops import init_ops
from tensorflow.python.util.tf_export import keras_export
# LOCAL.ALL_OBJECTS is meant to be a global mutable. Hence we need to make it
# thread-local to avoid concurrent mutations.
# `populate_deserializable_objects()` lazily creates two attributes on it:
#   - `ALL_OBJECTS`: dict mapping initializer names/aliases to initializers.
#   - `GENERATED_WITH_V2`: the value of `tf.__internal__.tf2.enabled()` that
#     `ALL_OBJECTS` was generated for (`None` until the dict is first built).
LOCAL = threading.local()
def populate_deserializable_objects():
    """Populates dict ALL_OBJECTS with every built-in initializer.

    The dict lives on the thread-local `LOCAL` next to `GENERATED_WITH_V2`,
    which records the value of `tf.__internal__.tf2.enabled()` the dict was
    built for. The dict is rebuilt only when it is empty or when that flag no
    longer matches the current value; otherwise this function returns without
    touching it.

    Registered names, in insertion order:
      * `<ClassName>V2` and `<snake_case>V2` compatibility aliases, always
        pointing at the `initializers_v2` classes.
      * The class names and their snake_case functional aliases, taken from
        `initializers_v2` when TF2 behavior is enabled and from the V1
        initializers otherwise.
      * The short aliases `normal`, `uniform`, `one` and `zero`.
    """
    if not hasattr(LOCAL, 'ALL_OBJECTS'):
        # First call on this thread: create the attributes.
        LOCAL.ALL_OBJECTS = {}
        LOCAL.GENERATED_WITH_V2 = None

    v2_enabled = tf.__internal__.tf2.enabled()
    if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == v2_enabled:
        # Objects dict is already generated for the proper TF version:
        # do nothing.
        return

    # `registry` aliases the dict stored on LOCAL, so every insertion below is
    # immediately visible through LOCAL.ALL_OBJECTS.
    registry = {}
    LOCAL.ALL_OBJECTS = registry
    LOCAL.GENERATED_WITH_V2 = v2_enabled

    # Compatibility aliases (need to exist in both V1 and V2).
    registry.update({
        'ConstantV2': initializers_v2.Constant,
        'GlorotNormalV2': initializers_v2.GlorotNormal,
        'GlorotUniformV2': initializers_v2.GlorotUniform,
        'HeNormalV2': initializers_v2.HeNormal,
        'HeUniformV2': initializers_v2.HeUniform,
        'IdentityV2': initializers_v2.Identity,
        'LecunNormalV2': initializers_v2.LecunNormal,
        'LecunUniformV2': initializers_v2.LecunUniform,
        'OnesV2': initializers_v2.Ones,
        'OrthogonalV2': initializers_v2.Orthogonal,
        'RandomNormalV2': initializers_v2.RandomNormal,
        'RandomUniformV2': initializers_v2.RandomUniform,
        'TruncatedNormalV2': initializers_v2.TruncatedNormal,
        'VarianceScalingV2': initializers_v2.VarianceScaling,
        'ZerosV2': initializers_v2.Zeros,
    })

    # Out of an abundance of caution we also include these aliases that have
    # a non-zero probability of having been included in saved configs in the
    # past.
    registry.update({
        'glorot_normalV2': initializers_v2.GlorotNormal,
        'glorot_uniformV2': initializers_v2.GlorotUniform,
        'he_normalV2': initializers_v2.HeNormal,
        'he_uniformV2': initializers_v2.HeUniform,
        'lecun_normalV2': initializers_v2.LecunNormal,
        'lecun_uniformV2': initializers_v2.LecunUniform,
    })

    if v2_enabled:
        # For V2, entries are generated automatically based on the content of
        # initializers_v2.py.
        initializer_base = initializers_v2.Initializer

        def _is_initializer_class(candidate):
            return (inspect.isclass(candidate) and
                    issubclass(candidate, initializer_base))

        named_initializers = {}
        generic_utils.populate_dict_with_module_objects(
            named_initializers,
            [initializers_v2],
            obj_filter=_is_initializer_class)
    else:
        # V1 initializers.
        named_initializers = {
            'Constant': tf.compat.v1.constant_initializer,
            'GlorotNormal': tf.compat.v1.glorot_normal_initializer,
            'GlorotUniform': tf.compat.v1.glorot_uniform_initializer,
            'Identity': tf.compat.v1.initializers.identity,
            'Ones': tf.compat.v1.ones_initializer,
            'Orthogonal': tf.compat.v1.orthogonal_initializer,
            'VarianceScaling': tf.compat.v1.variance_scaling_initializer,
            'Zeros': tf.compat.v1.zeros_initializer,
            'HeNormal': initializers_v1.HeNormal,
            'HeUniform': initializers_v1.HeUniform,
            'LecunNormal': initializers_v1.LecunNormal,
            'LecunUniform': initializers_v1.LecunUniform,
            'RandomNormal': initializers_v1.RandomNormal,
            'RandomUniform': initializers_v1.RandomUniform,
            'TruncatedNormal': initializers_v1.TruncatedNormal,
        }

    for class_name, initializer in named_initializers.items():
        registry[class_name] = initializer
        # Functional aliases.
        registry[generic_utils.to_snake_case(class_name)] = initializer

    # More compatibility aliases.
    short_aliases = (
        ('normal', 'random_normal'),
        ('uniform', 'random_uniform'),
        ('one', 'ones'),
        ('zero', 'zeros'),
    )
    for alias, target in short_aliases:
        registry[alias] = registry[target]
# For backwards compatibility, we populate this module's namespace with the
# objects from ALL_OBJECTS. We make no guarantees as to whether these objects
# will be using their correct version: the names are copied once, right here,
# from the dict built by the call below.
populate_deserializable_objects()
globals().update(LOCAL.ALL_OBJECTS)
# Utility functions
@keras_export('keras.initializers.serialize')
def serialize(initializer):
    """Serialize an initializer with `generic_utils.serialize_keras_object`.

    Args:
      initializer: The initializer to serialize.

    Returns:
      The value returned by `generic_utils.serialize_keras_object` for
      `initializer`.
    """
    serialized = generic_utils.serialize_keras_object(initializer)
    return serialized
@keras_export('keras.initializers.deserialize')
def deserialize(config, custom_objects=None):
    """Return an `Initializer` object from its config.

    Args:
      config: The initializer identifier to resolve; `get` in this module
        passes either a string name or a config dict.
      custom_objects: Optional value forwarded unchanged to
        `generic_utils.deserialize_keras_object` as `custom_objects`.

    Returns:
      The result of `generic_utils.deserialize_keras_object`, looked up
      against the built-in initializers in `LOCAL.ALL_OBJECTS`.
    """
    # Make sure the lookup table exists on this thread and matches the
    # current TF version mode before resolving names against it.
    populate_deserializable_objects()
    lookup_kwargs = {
        'module_objects': LOCAL.ALL_OBJECTS,
        'custom_objects': custom_objects,
        'printable_module_name': 'initializer',
    }
    return generic_utils.deserialize_keras_object(config, **lookup_kwargs)
@keras_export('keras.initializers.get')
def get(identifier):
    """Retrieve a Keras initializer by the identifier.

    The `identifier` may be the (case-sensitive) string name of an initializer
    function or class.

    >>> identifier = 'Ones'
    >>> tf.keras.initializers.get(identifier)
    <...keras.initializers.initializers_v2.Ones...>

    You can also specify the `config` of the initializer by passing a dict
    containing `class_name` and `config` as the identifier. Also note that the
    `class_name` must map to an `Initializer` class.

    >>> cfg = {'class_name': 'Ones', 'config': {}}
    >>> tf.keras.initializers.get(cfg)
    <...keras.initializers.initializers_v2.Ones...>

    In the case that the `identifier` is a class, this method will return a
    new instance of the class by calling its constructor without arguments.
    Any other callable is returned unchanged, and `None` is passed through as
    `None`.

    Args:
      identifier: `None`, a string name, a config dict, a class, or any other
        callable identifying the initializer.

    Returns:
      `None` if `identifier` is `None`; otherwise the initializer resolved
      from the input identifier.

    Raises:
      ValueError: If the input identifier is not a supported type or in a bad
        format.
    """
    if identifier is None:
        return None

    if isinstance(identifier, dict):
        return deserialize(identifier)

    if isinstance(identifier, str):
        # Kept from the original code: `deserialize` receives the result of
        # `str(identifier)` rather than the object that was passed in.
        return deserialize(str(identifier))

    if callable(identifier):
        # Classes are instantiated with their default constructor; every
        # other callable is handed back untouched.
        if inspect.isclass(identifier):
            identifier = identifier()
        return identifier

    raise ValueError('Could not interpret initializer identifier: ' +
                     str(identifier))
Sub-modules
keras.initializers.initializers_v1
-
Keras initializers for TF 1.
keras.initializers.initializers_v2
-
Keras initializers for TF 2.
Functions
def deserialize(config, custom_objects=None)
-
Return an
Initializer
object from its config.
Expand source code
@keras_export('keras.initializers.deserialize')
def deserialize(config, custom_objects=None):
    """Return an `Initializer` object from its config.

    Args:
      config: The initializer identifier to resolve; `get` in this module
        passes either a string name or a config dict.
      custom_objects: Optional value forwarded unchanged to
        `generic_utils.deserialize_keras_object` as `custom_objects`.

    Returns:
      The result of `generic_utils.deserialize_keras_object`, looked up
      against the built-in initializers in `LOCAL.ALL_OBJECTS`.
    """
    # Make sure the lookup table exists on this thread and matches the
    # current TF version mode before resolving names against it.
    populate_deserializable_objects()
    lookup_kwargs = {
        'module_objects': LOCAL.ALL_OBJECTS,
        'custom_objects': custom_objects,
        'printable_module_name': 'initializer',
    }
    return generic_utils.deserialize_keras_object(config, **lookup_kwargs)
def get(identifier)
-
Retrieve a Keras initializer by the identifier.
The
identifier
may be the string name of an initializer function or class (case-sensitive).
>>> identifier = 'Ones'
>>> tf.keras.initializers.deserialize(identifier)
<...keras.initializers.initializers_v2.Ones...>
You can also specify
config
of the initializer to this function by passing a dict containing
class_name
and
config
as an identifier. Also note that the
class_name
must map to an
Initializer
class.
>>> cfg = {'class_name': 'Ones', 'config': {}}
>>> tf.keras.initializers.deserialize(cfg)
<...keras.initializers.initializers_v2.Ones...>
In the case that the
identifier
is a class, this method will return a new instance of the class by its constructor.
Args
identifier
- String or dict that contains the initializer name or configurations.
Returns
Initializer instance based on the input identifier.
Raises
ValueError
- If the input identifier is not a supported type or in a bad format.
Expand source code
@keras_export('keras.initializers.get')
def get(identifier):
    """Retrieve a Keras initializer by the identifier.

    The `identifier` may be the (case-sensitive) string name of an initializer
    function or class.

    >>> identifier = 'Ones'
    >>> tf.keras.initializers.get(identifier)
    <...keras.initializers.initializers_v2.Ones...>

    You can also specify the `config` of the initializer by passing a dict
    containing `class_name` and `config` as the identifier. Also note that the
    `class_name` must map to an `Initializer` class.

    >>> cfg = {'class_name': 'Ones', 'config': {}}
    >>> tf.keras.initializers.get(cfg)
    <...keras.initializers.initializers_v2.Ones...>

    In the case that the `identifier` is a class, this method will return a
    new instance of the class by calling its constructor without arguments.
    Any other callable is returned unchanged, and `None` is passed through as
    `None`.

    Args:
      identifier: `None`, a string name, a config dict, a class, or any other
        callable identifying the initializer.

    Returns:
      `None` if `identifier` is `None`; otherwise the initializer resolved
      from the input identifier.

    Raises:
      ValueError: If the input identifier is not a supported type or in a bad
        format.
    """
    if identifier is None:
        return None

    if isinstance(identifier, dict):
        return deserialize(identifier)

    if isinstance(identifier, str):
        # Kept from the original code: `deserialize` receives the result of
        # `str(identifier)` rather than the object that was passed in.
        return deserialize(str(identifier))

    if callable(identifier):
        # Classes are instantiated with their default constructor; every
        # other callable is handed back untouched.
        if inspect.isclass(identifier):
            identifier = identifier()
        return identifier

    raise ValueError('Could not interpret initializer identifier: ' +
                     str(identifier))
def populate_deserializable_objects()
-
Populates dict ALL_OBJECTS with every built-in initializer.
Expand source code
def populate_deserializable_objects():
    """Populates dict ALL_OBJECTS with every built-in initializer.

    The dict lives on the thread-local `LOCAL` next to `GENERATED_WITH_V2`,
    which records the value of `tf.__internal__.tf2.enabled()` the dict was
    built for. The dict is rebuilt only when it is empty or when that flag no
    longer matches the current value; otherwise this function returns without
    touching it.
    """
    if not hasattr(LOCAL, 'ALL_OBJECTS'):
        # First call on this thread: create the attributes.
        LOCAL.ALL_OBJECTS = {}
        LOCAL.GENERATED_WITH_V2 = None

    v2_enabled = tf.__internal__.tf2.enabled()
    if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == v2_enabled:
        # Objects dict is already generated for the proper TF version:
        # do nothing.
        return

    # `registry` aliases the dict stored on LOCAL, so every insertion below is
    # immediately visible through LOCAL.ALL_OBJECTS.
    registry = {}
    LOCAL.ALL_OBJECTS = registry
    LOCAL.GENERATED_WITH_V2 = v2_enabled

    # Compatibility aliases (need to exist in both V1 and V2).
    registry.update({
        'ConstantV2': initializers_v2.Constant,
        'GlorotNormalV2': initializers_v2.GlorotNormal,
        'GlorotUniformV2': initializers_v2.GlorotUniform,
        'HeNormalV2': initializers_v2.HeNormal,
        'HeUniformV2': initializers_v2.HeUniform,
        'IdentityV2': initializers_v2.Identity,
        'LecunNormalV2': initializers_v2.LecunNormal,
        'LecunUniformV2': initializers_v2.LecunUniform,
        'OnesV2': initializers_v2.Ones,
        'OrthogonalV2': initializers_v2.Orthogonal,
        'RandomNormalV2': initializers_v2.RandomNormal,
        'RandomUniformV2': initializers_v2.RandomUniform,
        'TruncatedNormalV2': initializers_v2.TruncatedNormal,
        'VarianceScalingV2': initializers_v2.VarianceScaling,
        'ZerosV2': initializers_v2.Zeros,
    })

    # Out of an abundance of caution we also include these aliases that have
    # a non-zero probability of having been included in saved configs in the
    # past.
    registry.update({
        'glorot_normalV2': initializers_v2.GlorotNormal,
        'glorot_uniformV2': initializers_v2.GlorotUniform,
        'he_normalV2': initializers_v2.HeNormal,
        'he_uniformV2': initializers_v2.HeUniform,
        'lecun_normalV2': initializers_v2.LecunNormal,
        'lecun_uniformV2': initializers_v2.LecunUniform,
    })

    if v2_enabled:
        # For V2, entries are generated automatically based on the content of
        # initializers_v2.py.
        initializer_base = initializers_v2.Initializer

        def _is_initializer_class(candidate):
            return (inspect.isclass(candidate) and
                    issubclass(candidate, initializer_base))

        named_initializers = {}
        generic_utils.populate_dict_with_module_objects(
            named_initializers,
            [initializers_v2],
            obj_filter=_is_initializer_class)
    else:
        # V1 initializers.
        named_initializers = {
            'Constant': tf.compat.v1.constant_initializer,
            'GlorotNormal': tf.compat.v1.glorot_normal_initializer,
            'GlorotUniform': tf.compat.v1.glorot_uniform_initializer,
            'Identity': tf.compat.v1.initializers.identity,
            'Ones': tf.compat.v1.ones_initializer,
            'Orthogonal': tf.compat.v1.orthogonal_initializer,
            'VarianceScaling': tf.compat.v1.variance_scaling_initializer,
            'Zeros': tf.compat.v1.zeros_initializer,
            'HeNormal': initializers_v1.HeNormal,
            'HeUniform': initializers_v1.HeUniform,
            'LecunNormal': initializers_v1.LecunNormal,
            'LecunUniform': initializers_v1.LecunUniform,
            'RandomNormal': initializers_v1.RandomNormal,
            'RandomUniform': initializers_v1.RandomUniform,
            'TruncatedNormal': initializers_v1.TruncatedNormal,
        }

    for class_name, initializer in named_initializers.items():
        registry[class_name] = initializer
        # Functional aliases.
        registry[generic_utils.to_snake_case(class_name)] = initializer

    # More compatibility aliases.
    short_aliases = (
        ('normal', 'random_normal'),
        ('uniform', 'random_uniform'),
        ('one', 'ones'),
        ('zero', 'zeros'),
    )
    for alias, target in short_aliases:
        registry[alias] = registry[target]
def serialize(initializer)
-
Expand source code
@keras_export('keras.initializers.serialize')
def serialize(initializer):
    """Serialize an initializer with `generic_utils.serialize_keras_object`.

    Args:
      initializer: The initializer to serialize.

    Returns:
      The value returned by `generic_utils.serialize_keras_object` for
      `initializer`.
    """
    serialized = generic_utils.serialize_keras_object(initializer)
    return serialized