Module keras.engine.base_layer_utils

Contains private utilities used mainly by the base Layer class.

Expand source code
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains private utilities used mainly by the base Layer class."""

import tensorflow.compat.v2 as tf

import functools
import threading
from keras import backend
from keras.utils import control_flow_util
from keras.utils import tf_inspect
from keras.utils import tf_utils
from tensorflow.python.util.tf_export import keras_export

# Thread-local storage. `call_context()` lazily stores one `CallContext` per
# thread in its `call_context` attribute.
_call_context = threading.local()


def create_mean_metric(value, name=None):
  """Builds a `Mean` metric and applies it to `value`.

  Args:
    value: The value to feed to the metric; its `dtype` becomes the metric's
      dtype.
    name: Optional name for the metric.

  Returns:
    A `(metric, result)` tuple: the new `Mean` metric object and whatever
    calling that metric on `value` returned.
  """
  # Deferred import: `import keras` pulls in base_layer, which pulls in this
  # module, while metrics itself depends on base_layer. Importing at module
  # level would therefore create an import cycle.
  from keras import metrics as metrics_module  # pylint: disable=g-import-not-at-top
  mean_metric = metrics_module.Mean(name=name, dtype=value.dtype)
  result = mean_metric(value)
  return mean_metric, result


def make_variable(name,
                  shape=None,
                  dtype=tf.float32,
                  initializer=None,
                  trainable=None,
                  caching_device=None,
                  validate_shape=True,
                  constraint=None,
                  use_resource=None,
                  collections=None,
                  synchronization=tf.VariableSynchronization.AUTO,
                  aggregation=tf.VariableAggregation.NONE,
                  partitioner=None):  # pylint: disable=unused-argument
  """Creates a `tf.compat.v1.Variable` on behalf of a layer.

  This is a stop-gap helper. Reuse-related technicalities rule out calling
  `variable_scope.get_variable()` directly, so a less constrained path is used
  instead. Longer term a "default variable creator" should live on `Trackable`,
  at which point this helper can go away.

  TODO(fchollet): remove this method when no longer needed.

  Args:
    name: Variable name.
    shape: Variable shape.
    dtype: The type of the variable. Defaults to `float32`.
    initializer: Either a callable initializer (an initializer class is
      instantiated first), or a non-callable value that is used directly as
      the initial value.
    trainable: Whether the variable should be part of the layer's
      "trainable_variables" (e.g. kernels, biases) or
      "non_trainable_variables" (e.g. BatchNorm mean, stddev). Forwarded
      unchanged to `tf.compat.v1.Variable`.
    caching_device: Passed to `tf.Variable`.
    validate_shape: Passed to `tf.Variable`.
    constraint: Constraint instance (callable).
    use_resource: Whether to use a `ResourceVariable`. `None` is treated as
      `True`.
    collections: List of graph collections keys the new variable is added to.
    synchronization: Indicates when a distributed variable will be
      aggregated. Accepted values are constants defined in the class
      `tf.VariableSynchronization`.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.
    partitioner: Not handled at this time.

  Returns:
    Variable instance.
  """
  if initializer is not None and not callable(initializer):
    # A concrete value was supplied: hand it over untouched and pass no
    # explicit dtype.
    initial_value = initializer
    var_dtype = None
  else:
    # An initializer given as a class is instantiated before use.
    if tf_inspect.isclass(initializer):
      initializer = initializer()
    # Bind shape and dtype now; the resulting callable takes no arguments.
    initial_value = functools.partial(initializer, shape, dtype=dtype)
    var_dtype = dtype.base_dtype

  tensor_shape = tf.TensorShape(shape)
  resource_flag = True if use_resource is None else use_resource

  # In theory, when `use_resource` is True and `collections` is empty (that is
  # to say, in TF2), plain `tf.Variable` would do. However that changes
  # variable names and so breaks legacy (Estimator) checkpoints. Remove this
  # when V1 is fully deprecated.
  return tf.compat.v1.Variable(
      initial_value=initial_value,
      name=name,
      trainable=trainable,
      caching_device=caching_device,
      dtype=var_dtype,
      validate_shape=validate_shape,
      constraint=constraint,
      use_resource=resource_flag,
      collections=collections,
      synchronization=synchronization,
      aggregation=aggregation,
      shape=tensor_shape or None)


def collect_previous_mask(input_tensors):
  """Gathers the `_keras_mask` attached to each input tensor.

  Args:
      input_tensors: An arbitrary structure of Tensors.

  Returns:
      A structure shaped like `input_tensors` in which every entry is that
      tensor's `_keras_mask` attribute, or `None` when it has none.
  """
  return tf.nest.map_structure(
      lambda tensor: getattr(tensor, '_keras_mask', None), input_tensors)


def have_all_keras_metadata(tensors):
  """Returns True when every flattened tensor has a `_keras_history` attribute."""
  for tensor in tf.nest.flatten(tensors):
    if not hasattr(tensor, '_keras_history'):
      return False
  return True


def generate_placeholders_from_shape(shape):
  """Creates a V1 placeholder of the given shape with dtype `backend.floatx()`."""
  float_dtype = backend.floatx()
  return tf.compat.v1.placeholder(shape=shape, dtype=float_dtype)


def create_keras_history(tensors):
  """Wraps raw TensorFlow ops so they can take part in the Functional API.

  For every Tensor in `tensors` that lacks Keras metadata but descends from a
  Keras `Input` layer, the raw TensorFlow operations that produced it are
  replaced by `TensorFlowOpLayer` instances creating identical operations.

  Tensors that do not descend from a Keras `Input` layer are treated as
  constants when the `TensorFlowOpLayer` instances are built.

  Args:
    tensors: A structure of Tensors, some of which come from raw TensorFlow
      operations and need to have Keras metadata assigned to them.

  Returns:
    created_layers: List. The `TensorFlowOpLayer` instances created to wrap
      the raw Tensorflow operations.
  """
  # Start from an empty "already processed" set and an empty layer list; only
  # the list of created layers is of interest to callers.
  helper_result = _create_keras_history_helper(tensors, set(), [])
  return helper_result[1]


# Unsafe internal attribute.
# If True, Keras will not evaluate the constant-foldable inputs to tf op
# layers in TF1 graphs. This *might* speed up model construction time in
# certain settings, but it means the models will not be
# serializable/deserializable via `get_config` (only via SavedModels). It may
# also change the semantics of whether generated random numbers are generated
# once and re-used, or recomputed each time.
# Note: This path triggers for TPUEstimators / XLA-compiled graphs regardless
# of this setting.
_UNSAFE_GRAPH_OP_LAYER_CREATION = False


def _create_keras_history_helper(tensors, processed_ops, created_layers):
  """Helper method for `create_keras_history`.

  Walks backwards from `tensors` through the ops that produced them. Each op
  not seen before is wrapped in a `TensorFlowOpLayer`, after first recursing
  into those of its inputs that originate from a Keras `Input`.

  Both `processed_ops` and `created_layers` are mutated in place and are also
  returned.

  Args:
    tensors: A structure of Tensors for which to create Keras metadata.
    processed_ops: Set. TensorFlow operations that have already been wrapped in
      `TensorFlowOpLayer` instances.
    created_layers: List. The `TensorFlowOpLayer` instances created.

  Returns:
    Tuple. First element is the updated set of TensorFlow Operations that
    have been wrapped in `TensorFlowOpLayer` instances. Second element is
    a list of the `TensorFlowOpLayer` instances created.

  Raises:
    ValueError: If called while executing eagerly outside functions, or if any
      of the (flattened) `tensors` is a sparse or ragged tensor without Keras
      history.
  """
  if tf.compat.v1.executing_eagerly_outside_functions():
    raise ValueError(
        '`create_keras_history` should only be called if eager is disabled!')
  # Import of `base_layer` needed in order to create `TensorFlowOpLayer`.
  # Cannot be imported at top because of circular dependencies.
  # TODO(omalleyt): Resolve circular dependency.
  from keras.engine import base_layer  # pylint: disable=g-import-not-at-top
  tensor_list = tf.nest.flatten(tensors)
  # Unsupported inputs are collected here and reported together once the loop
  # below has finished.
  sparse_ops = []
  ragged_tensors = []
  for tensor in tensor_list:
    # Tensors that already carry Keras history need no wrapping.
    if getattr(tensor, '_keras_history', None) is not None:
      continue
    if isinstance(
        tensor, (tf.SparseTensor, tf.compat.v1.SparseTensorValue)):
      sparse_ops.append(tensor.op)
      continue
    if tf_utils.is_ragged(tensor):
      # Ragged tensors don't have an op property
      ragged_tensors.append(tensor)
      continue
    op = tensor.op  # The Op that created this Tensor.
    if op not in processed_ops:
      # Recursively set `_keras_history`.
      op_inputs = list(op.inputs)
      # Maps the positional index of an op input to its constant value (or to
      # the raw tensor when evaluation is skipped).
      constants = {}
      # Op inputs that trace back to a `keras.Input`.
      layer_inputs = []
      for i, op_input in enumerate(op_inputs):
        if uses_keras_history(op_input):
          layer_inputs.append(op_input)
        else:
          # Treat any value not originating from a `keras.Input` as
          # a constant. Variables cannot be supported.
          ds_with_session = (
              tf.distribute.in_cross_replica_context() and
              not tf.compat.v1.executing_eagerly_outside_functions())
          using_xla = control_flow_util.GraphOrParentsInXlaContext(
              tf.compat.v1.get_default_graph())
          if ds_with_session or using_xla or _UNSAFE_GRAPH_OP_LAYER_CREATION:
            # In Legacy Graph mode, evaluating here makes Session be
            # configured improperly. The downside of this is that saving
            # via `get_config` breaks, but SavedModel still works.
            constants[i] = op_input
          else:
            # Evaluate the input to a concrete value via a backend function.
            with tf.init_scope():
              constants[i] = backend.function([], op_input)([])
      layer_inputs = unnest_if_single_tensor(layer_inputs)
      # Wrap the producers of this op's Keras-derived inputs before wrapping
      # the op itself.
      processed_ops, created_layers = _create_keras_history_helper(
          layer_inputs, processed_ops, created_layers)
      name = op.name
      node_def = op.node_def.SerializeToString()
      op_layer = base_layer.TensorFlowOpLayer(
          node_def, constants=constants, name=name)
      created_layers.append(op_layer)
      op_layer._set_connectivity_metadata(  # pylint: disable=protected-access
          args=(layer_inputs,),
          kwargs={},
          outputs=op.outputs)
      # Record the op so it is not wrapped a second time.
      processed_ops.update([op])
  if sparse_ops or ragged_tensors:
    lambda_example = """
    weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
    output = tf.keras.layers.Lambda(weights_mult)(input)
    """
    raise ValueError(
        'Tensorflow ops that generate ragged or sparse tensor '
        'outputs are currently not supported by Keras automatic '
        'op wrapping. Please wrap these ops in a Lambda layer: '
        '\n\n```\n{example}\n```\n'
        'Sparse ops encountered: {sparse_ops}\n'
        'Ragged tensors encountered: {ragged_tensors}\n'.format(
            example=lambda_example,
            sparse_ops=str(sparse_ops),
            ragged_tensors=str(ragged_tensors)))
  return processed_ops, created_layers


def unnest_if_single_tensor(input_tensors):
  """Unwraps a structure that holds exactly one tensor, unless it is a dict.

  Kept for compatibility with older configs. A dict is passed through
  untouched on the assumption that the first layer expects a dict (as is the
  case with a DenseFeatures layer).

  Args:
    input_tensors: An arbitrary structure of Tensors.

  Returns:
    The single contained element when `input_tensors` is a non-dict structure
    that flattens to one element; otherwise `input_tensors` itself.
  """
  flattened = tf.nest.flatten(input_tensors)
  is_lone_non_dict = (
      not isinstance(input_tensors, dict) and len(flattened) == 1)
  return flattened[0] if is_lone_non_dict else input_tensors


def needs_keras_history(tensors, ignore_call_context=False):
  """Tells whether any Tensor must be wrapped in a TensorFlowOpLayer.

  Inside a layer's `call` this is always False (unless `ignore_call_context`
  is set), because sublayers do not need to create Keras history. Otherwise
  the answer is True when at least one of `tensors` originates from a
  `keras.Input` and does not have `_keras_history` set.

  Args:
    tensors: An arbitrary nested structure of Tensors.
    ignore_call_context: Whether to skip the "currently inside a `call`"
      check. This is `True` when creating KerasHistory inside `Node`, where
      Tensors are always known to be used with the Functional API.

  Returns:
    Bool, whether at least one Tensor needs to be wrapped.
  """
  flat_tensors = tf.nest.flatten(tensors)
  inside_layer_call = call_context().in_call
  if inside_layer_call and not ignore_call_context:
    return False
  for candidate in flat_tensors:
    if getattr(candidate, '_keras_history', None) is None:
      # At least one tensor lacks history; whether wrapping is needed now
      # depends on whether anything traces back to a `keras.Input`.
      return uses_keras_history(tensors)
  # KerasHistory already set on every tensor.
  return False


def is_in_keras_graph():
  """Reports whether the active call context is inside a Keras graph."""
  active_context = call_context()
  return active_context.in_keras_graph


def is_in_eager_or_tf_function():
  """Reports whether running eagerly or inside of a tf.function."""
  if tf.executing_eagerly():
    return True
  return is_in_tf_function()


def is_in_tf_function():
  """Reports whether the current code is being traced inside a tf.function."""
  # All three must hold, checked in this order: not V1 graph mode, inside some
  # function graph, and that function graph is not the Keras FuncGraph.
  in_candidate_function = (
      tf.compat.v1.executing_eagerly_outside_functions() and
      tf.inside_function() and
      not is_in_keras_graph())
  if not in_candidate_function:
    return False
  # A v1 `wrap_function` FuncGraph does not count as a tf.function.
  graph_name = getattr(tf.compat.v1.get_default_graph(), 'name', False)
  return not (graph_name and graph_name.startswith('wrapped_function'))


def uses_keras_history(tensors):
  """Tells whether any Tensor traces back to a `keras.Input`.

  The producing ops are walked backwards level by level. A Tensor carrying a
  `_keras_history` attribute proves the origin. Tensors already flagged
  `_keras_history_checked` were previously found not to originate from a
  `keras.Input` and are not expanded again.

  Args:
    tensors: An arbitrary nested structure of Tensors.

  Returns:
    Bool, whether at least one Tensor originates from a `keras.Input`.
  """
  seen_ids = set()
  frontier = tf.nest.flatten(tensors)

  while frontier:
    next_frontier = []
    for candidate in frontier:
      candidate_id = id(candidate)
      if candidate_id in seen_ids:
        continue
      seen_ids.add(candidate_id)

      if getattr(candidate, '_keras_history_checked', None) is not None:
        continue
      if getattr(candidate, '_keras_history', None) is not None:
        return True

      try:
        next_frontier.extend(candidate.op.inputs)
      except AttributeError:
        # In case `candidate` is a Variable created in an Eager context.
        pass

    frontier = next_frontier

  # Nothing led back to a `keras.Input`: flag the inputs so this search is not
  # repeated for them, for performance reasons.
  mark_checked(tensors)
  return False


def mark_checked(tensors):
  """Flags every Tensor in `tensors` as not needing to be tracked.

  Setting `_keras_history_checked` stops Layers from attempting to create
  TensorFlowOpLayers for these Tensors.

  Args:
    tensors: An arbitrary structure of Tensors.
  """
  for tensor in tf.nest.flatten(tensors):
    tensor._keras_history_checked = True  # pylint: disable=protected-access


def call_context():
  """Fetches this thread's `CallContext`, creating it on first use."""
  existing = getattr(_call_context, 'call_context', None)
  if existing is not None:
    return existing
  fresh = CallContext()
  _call_context.call_context = fresh
  return fresh


# Register `call_context` with TensorFlow's keras_deps hook at import time.
# Injecting it this way removes the dependency from TFLite to Keras.
tf.__internal__.register_call_context_function(call_context)


class CallContext(object):
  """Records what is known about the Layer/Model `call` currently running.

  Attributes:
    in_call: Whether currently inside the `call` of a Layer.
    layer: The `Layer` whose `call` is currently active.
    inputs: The inputs to the currently active `Layer`.
    build_graph: Whether currently inside a Graph or FuncGraph.
    training: Whether currently executing in training or inference mode.
    saving: Whether currently saving to SavedModel.
    frozen: Whether currently executing inside a `Layer` with `trainable` set to
      `False`.
    in_keras_graph: Whether executing inside the Keras Graph.
  """

  def __init__(self):
    # `in_call` is the most-read attribute and sits on the hot path, so it is
    # a plain attribute rather than an entry of `_state`.
    self.in_call = False
    self._state = dict(
        layer=None,
        inputs=None,
        build_graph=False,
        training=None,
        saving=None)
    # TODO(b/150169018): This logic can be replaced after the Functional API
    # refactor.
    self._in_keras_graph = False

  def enter(self, layer, inputs, build_graph, training, saving=None):
    """Builds a context manager that installs a new call state.

    Args:
      layer: The `Layer` whose `call` is currently active.
      inputs: The inputs to the currently active `Layer`.
      build_graph: Whether currently inside a Graph or FuncGraph.
      training: Whether currently executing in training or inference mode.
      saving: Whether currently saving to SavedModel.

    Returns:
      Context manager.
    """
    new_state = dict(
        layer=layer,
        inputs=inputs,
        build_graph=build_graph,
        training=training,
        saving=saving)
    return CallContextManager(self, new_state)

  @property
  def layer(self):
    """The `Layer` stored in the current state."""
    return self._state['layer']

  @property
  def inputs(self):
    """The inputs stored in the current state."""
    return self._state['inputs']

  @property
  def build_graph(self):
    """The `build_graph` flag stored in the current state."""
    return self._state['build_graph']

  @property
  def training(self):
    """The `training` value stored in the current state."""
    return self._state['training']

  @property
  def saving(self):
    """The `saving` value stored in the current state."""
    return self._state['saving']

  @property
  def frozen(self):
    """True when the active layer exists and is not trainable."""
    active_layer = self._state['layer']
    return False if not active_layer else not active_layer.trainable

  @property
  def in_keras_graph(self):
    """Whether executing inside the Keras graph or one of its subgraphs."""
    # Returns True even if in a subgraph of the Keras graph, such as those
    # created by control flow ops.
    if tf.executing_eagerly():
      return False
    if self._in_keras_graph:
      return self._in_keras_graph
    current_graph_name = getattr(backend.get_graph(), 'name', None)
    return current_graph_name == 'keras_graph'


class CallContextManager(object):
  """Context manager that swaps a `CallContext`'s state in and back out."""

  def __init__(self, call_ctx, state):
    self._call_ctx = call_ctx
    self._state = state
    self._build_graph = state['build_graph']

  def __enter__(self):
    # pylint: disable=protected-access
    ctx = self._call_ctx
    # Remember what is being replaced so that `__exit__` can restore it.
    self._prev_in_call, self._prev_state = ctx.in_call, ctx._state
    ctx.in_call, ctx._state = True, self._state

    if not self._build_graph:
      return

    # TODO(b/150169018): This logic can be removed after the Functional API
    # refactor.
    self._prev_in_keras_graph = ctx._in_keras_graph
    ctx._in_keras_graph = (
        ctx._in_keras_graph or
        getattr(backend.get_graph(), 'name', None) == 'keras_graph')
    # pylint: enable=protected-access

  def __exit__(self, *exc_info):
    # pylint: disable=protected-access
    ctx = self._call_ctx
    ctx.in_call, ctx._state = self._prev_in_call, self._prev_state

    # The Keras-graph flag was only touched when `build_graph` was set.
    if self._build_graph:
      ctx._in_keras_graph = self._prev_in_keras_graph
    # pylint: enable=protected-access


def training_arg_passed_to_call(argspec, args, kwargs):
  """Tells whether `__call__` received a non-None `training` argument.

  Args:
    argspec: Argument spec whose `args` list starts with `['self', 'inputs']`.
    args: Positional arguments, matched against `argspec.args[2:]`.
    kwargs: Keyword arguments; these override positional matches.

  Returns:
    True if `training` was supplied, positionally or by keyword, with a value
    other than `None`.
  """
  # Skip the leading 'self' and 'inputs' entries when pairing names to values.
  positional_names = argspec.args[2:]
  supplied = dict(zip(positional_names, args))
  supplied.update(kwargs)
  return supplied.get('training') is not None


def is_subclassed(layer):
  """Tells whether `layer` is a subclassed layer or subclassed model.

  An object counts as subclassed when its `__module__` mentions neither
  'keras.engine' nor 'keras.layers'.
  """
  module_path = layer.__module__
  builtin_packages = ('keras.engine', 'keras.layers')
  return not any(package in module_path for package in builtin_packages)


def from_saved_model(layer):
  """Tells whether `layer`'s module path marks it as loaded from a SavedModel."""
  return 'keras.saving.saved_model' in layer.__module__


def check_graph_consistency(tensor=None, method='add_loss', force_raise=False):
  """Raises a descriptive error when an `add_*` call sits in control flow.

  When one of the `add_*` methods is called inside a V2 conditional branch,
  the underlying tensor gets created in a FuncGraph managed by
  control_flow_v2. This function detects that case and raises an error that
  shows an offending snippet and a corrected one.

  Args:
    tensor: Tensor to check. Only inspected when `force_raise` is falsy.
    method: Caller method. 'activity_regularizer', 'add_metric' and 'add_loss'
      have dedicated messages; any other value gets the `add_update` examples.
    force_raise: If an error should be raised regardless of `tensor`.

  Raises:
    RuntimeError: If `force_raise` is set, or if eager execution is enabled
      outside functions and `tensor.graph` is a control flow graph.
  """
  must_raise = force_raise or (
      tf.compat.v1.executing_eagerly_outside_functions() and
      hasattr(tensor, 'graph') and tensor.graph.is_control_flow_graph)
  if not must_raise:
    return

  # Activity regularizers get a message of their own.
  if method == 'activity_regularizer':
    bad_example = """
      class TestModel(tf.keras.Model):

        def __init__(self):
          super(TestModel, self).__init__(name='test_model')
          self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')

        def call(self, x, training=None):
          if training:
            return self.dense(x)
          else:
            return self.dense(x)
      """
    correct_example = """
      class TestModel(tf.keras.Model):

        def __init__(self):
          super(TestModel, self).__init__(name='test_model')
          self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')

        def call(self, x, training=None):
          return self.dense(x)
      """
    regularizer_template = (
        'You are using a layer with `activity_regularizer` in a control flow '
        'branch, e.g.:\n{bad_example}\nThis is currently not supported. '
        'Please move your call to the layer with `activity_regularizer` out '
        'of the control flow branch, e.g.:\n{correct_example}\n'
        'You can also resolve this by marking your outer model/layer dynamic'
        ' (eager-only) by passing `dynamic=True` to the layer constructor. '
        'Any kind of control flow is supported with dynamic layers. '
        'Note that using `dynamic=True` requires you to implement static '
        'shape inference in the `compute_output_shape(input_shape)` '
        'method.')
    raise RuntimeError(
        regularizer_template.format(
            bad_example=bad_example, correct_example=correct_example))

  # The remaining methods share one message and differ only in the examples.
  if method == 'add_metric':
    bad_example = """
      def call(self, inputs, training=None):
        if training:
          metric = compute_metric(inputs)
          self.add_metric(metric, name='my_metric', aggregation='mean')
        return inputs
      """
    correct_example = """
      def call(self, inputs, training=None):
        if training:
          metric = compute_metric(inputs)
        else:
          metric = 0.
        self.add_metric(metric, name='my_metric', aggregation='mean')
        return inputs
      """
  elif method == 'add_loss':
    bad_example = """
      def call(self, inputs, training=None):
        if training:
          loss = compute_loss(inputs)
          self.add_loss(loss)
        return inputs
      """
    correct_example = """
      def call(self, inputs, training=None):
        if training:
          loss = compute_loss(inputs)
        else:
          loss = 0.
        self.add_loss(loss)
        return inputs
      """
  else:
    bad_example = """
      def call(self, inputs, training=None):
        if training:
          self.add_update(self.w.assign_add(1))
        return inputs
      """
    correct_example = """
      def call(self, inputs, training=None):
        if training:
          increment = 1
        else:
          increment = 0
        self.add_update(self.w.assign_add(increment))
        return inputs
      """
  method_template = (
      'You are using the method `{method}` in a control flow branch '
      'in your layer, e.g.:\n{bad_example}\n'
      'This is not currently supported. '
      'Please move your call to {method} out of the control flow branch, '
      'e.g.:\n{correct_example}\n'
      'You can also resolve this by marking your layer '
      'as dynamic (eager-only) by passing '
      '`dynamic=True` to the layer constructor. '
      'Any kind of control flow is supported with dynamic layers. '
      'Note that using `dynamic=True` requires you '
      'to implement static shape inference '
      'in the `compute_output_shape(input_shape)` method.')
  raise RuntimeError(
      method_template.format(
          method=method,
          bad_example=bad_example,
          correct_example=correct_example))


def mark_as_return(outputs, acd):
  """Registers `outputs` as return values with automatic control deps.

  Args:
    outputs: An arbitrary structure; entries that are not tensors are passed
      through unchanged.
    acd: Object whose `mark_as_return(tensor)` is applied to each tensor (and
      to its Keras mask, when it has one).

  Returns:
    A structure shaped like `outputs` holding the marked tensors.
  """

  def _mark_one(tensor):
    """Marks a single tensor and carries its attached metadata across."""
    if not tf.is_tensor(tensor):
      return tensor

    # pylint: disable=protected-access
    marked = acd.mark_as_return(tensor)
    has_mask = getattr(tensor, '_keras_mask', None) is not None
    marked._keras_mask = (
        acd.mark_as_return(tensor._keras_mask) if has_mask else None)

    # Handle TensorFlow Probability attached metadata.
    # TODO(b/132076537): Remove this once TFP uses `CompositeTensor`.
    has_distribution = getattr(tensor, '_tfp_distribution', None) is not None
    if has_distribution:
      marked._tfp_distribution = tensor._tfp_distribution
    # pylint: enable=protected-access

    return marked

  return tf.nest.map_structure(_mark_one, outputs)


# Tri-state switch for the V2 dtype behavior: True/False once explicitly
# enabled/disabled; None means "not set", in which case
# `v2_dtype_behavior_enabled()` falls back to `tf.__internal__.tf2.enabled()`.
V2_DTYPE_BEHAVIOR = None


@keras_export(v1=['keras.layers.enable_v2_dtype_behavior'])
def enable_v2_dtype_behavior():
  """Enable the V2 dtype behavior for Keras layers.

  By default, the V2 dtype behavior is enabled in TensorFlow 2, so this function
  is only useful if `tf.compat.v1.disable_v2_behavior` has been called. Since
  mixed precision requires V2 dtype behavior to be enabled, this function allows
  you to use mixed precision in Keras layers if `disable_v2_behavior` has been
  called.

  When enabled, the dtype of Keras layers defaults to floatx (which is typically
  float32) instead of None. In addition, layers will automatically cast
  floating-point inputs to the layer's dtype.

  >>> x = tf.ones((4, 4, 4, 4), dtype='float64')
  >>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
  >>> print(layer.dtype)  # float32 since V2 dtype behavior is enabled
  float32
  >>> y = layer(x)  # Layer casts inputs since V2 dtype behavior is enabled
  >>> print(y.dtype.name)
  float32

  A layer author can opt-out their layer from the automatic input casting by
  passing `autocast=False` to the base Layer's constructor. This disables the
  autocasting part of the V2 behavior for that layer, but not the defaulting to
  floatx part of the V2 behavior.

  When a global `tf.keras.mixed_precision.Policy` is set, a Keras layer's dtype
  will default to the global policy instead of floatx. Layers will automatically
  cast inputs to the policy's compute_dtype.
  """
  # Overrides the module-level switch; from now on
  # `v2_dtype_behavior_enabled()` returns True.
  global V2_DTYPE_BEHAVIOR
  V2_DTYPE_BEHAVIOR = True


@keras_export(v1=['keras.layers.disable_v2_dtype_behavior'])
def disable_v2_dtype_behavior():
  """Disables the V2 dtype behavior for Keras layers.

  See `tf.compat.v1.keras.layers.enable_v2_dtype_behavior`.
  """
  # Overrides the module-level switch; from now on
  # `v2_dtype_behavior_enabled()` returns False.
  global V2_DTYPE_BEHAVIOR
  V2_DTYPE_BEHAVIOR = False


def v2_dtype_behavior_enabled():
  """Tells whether the V2 dtype behavior is currently in effect.

  An explicit enable/disable call wins; when neither has been made the answer
  follows `tf.__internal__.tf2.enabled()`.
  """
  if V2_DTYPE_BEHAVIOR is not None:
    return V2_DTYPE_BEHAVIOR
  return tf.__internal__.tf2.enabled()


class TrackableWeightHandler(object):
  """Keras adapter for saving and restoring the state of a Trackable.

  Works for Trackables in both V1 and V2 modes, making sure they are saved and
  restored with the correct data and without adding additional ops on every
  save.

  Attributes:
    trackable: The trackable to wrap.
    num_tensors: The number of tensors that this trackable requires for saving.
  """

  def __init__(self, trackable):
    if not isinstance(trackable, tf.__internal__.tracking.Trackable):
      raise ValueError('%s is not a Trackable object.' % (trackable,))
    self._trackable = trackable
    self._distribute_strategy = tf.distribute.get_strategy()

    # TODO(b/141682913): Figure out why this is private and fix it.
    saveables = trackable._gather_saveables_for_checkpoint().values()  # pylint: disable=protected-access
    saveable_count = len(saveables)

    if saveable_count > 1:
      raise ValueError('Only Trackables with one Saveable are supported. '
                       'The Trackable %s has %d Saveables.' %
                       (trackable, saveable_count))

    if saveable_count == 0:
      # A legacy TF1 table such as a StaticHashTable exposes no saveables, so
      # there is nothing to get or set.
      self._num_tensors = 0
      self._setter = lambda weights: None
      self._getter = lambda: []
      return

    saveable_factory = next(iter(saveables))

    if tf.compat.v1.executing_eagerly_outside_functions():
      # In eager mode, calling the Trackable's saveable() callable has to be
      # deferred until data export time. Calling it any number of times is
      # safe though, so it is called once here just to learn how many tensors
      # this Trackable produces.
      self._saveable = saveable_factory
      self._num_tensors = len(self._saveable().specs)
      self._setter = lambda weights: self._saveable().restore(weights, None)
      self._getter = lambda: [spec.tensor for spec in self._saveable().specs]
    else:
      # In Graph mode the Saveable must be evaluated only once and the
      # resulting restore graph cached; otherwise every set_weights() call
      # would add new assignment ops to the graph.
      self._saveable = saveable_factory()
      self._num_tensors = len(self._saveable.specs)
      placeholders = []
      for spec in self._saveable.specs:
        # Read `spec.tensor` exactly once per spec.
        spec_tensor = spec.tensor
        placeholders.append(
            tf.compat.v1.placeholder(spec_tensor.dtype, spec_tensor.shape))
      self._placeholder_tensors = placeholders
      self._assign_op = self._saveable.restore(self._placeholder_tensors,
                                               None)
      self._setter = self._set_weights_v1
      self._getter = lambda: [spec.tensor for spec in self._saveable.specs]

  @property
  def num_tensors(self):
    """Number of tensors this handler reads and expects to be given."""
    return self._num_tensors

  def set_weights(self, weights):
    """Restores the trackable from `weights`.

    Args:
      weights: Sequence of values, one per tensor of the trackable.

    Raises:
      ValueError: If `len(weights)` differs from `num_tensors`.
    """
    received = len(weights)
    if received != self._num_tensors:
      raise ValueError(
          ('Weight handler for trackable %s received the wrong number of ' +
           'weights: expected %s, got %s.') %
          (self._trackable, self._num_tensors, received))
    self._setter(weights)

  def get_tensors(self):
    """Returns the list of tensors currently backing the trackable."""
    return self._getter()

  def _set_weights_v1(self, weights):
    """Graph-mode setter: feeds `weights` to the cached assign op."""
    feed_dict = {
        self._placeholder_tensors[position]: value
        for position, value in enumerate(weights)
    }
    backend.get_session().run(self._assign_op, feed_dict)


def no_ragged_support(inputs, layer_name):
  """Rejects inputs that contain a `tf.RaggedTensor`.

  Args:
    inputs: An arbitrary structure of layer inputs.
    layer_name: Name of the layer, used in the error message.

  Raises:
    ValueError: If any flattened element of `inputs` is a `tf.RaggedTensor`.
  """
  for element in tf.nest.flatten(inputs):
    if isinstance(element, tf.RaggedTensor):
      raise ValueError('Layer %s does not support RaggedTensors as input. '
                       'Inputs received: %s. You can try converting your '
                       'input to an uniform tensor.' % (layer_name, inputs))


def is_split_variable(v):
  """Returns True if `v` is either a PartitionedVariable or a ShardedVariable.

  Detection is by attribute: `v` qualifies when it has a `_variable_list` or a
  `_variables` attribute.
  """
  return any(hasattr(v, marker) for marker in ('_variable_list', '_variables'))


def has_weights(obj):
  """Returns True if `obj` is an instance whose type exposes weight attributes.

  The type of `obj` must define both `trainable_weights` and
  `non_trainable_weights`, and `obj` itself must not be a class.
  """
  cls = type(obj)
  required_attrs = ('trainable_weights', 'non_trainable_weights')
  exposes_weights = all(hasattr(cls, attr) for attr in required_attrs)
  return exposes_weights and not isinstance(obj, type)


# TODO(kathywu): Temporary hack. When a network of layers is revived from
# SavedModel, only the top-level layer has losses. In eager mode this is a
# problem because the child layers may hold graph losses, so `model.losses`
# would return a mix of eager and graph tensors. As a workaround, whenever
# eager losses are added to one layer, eager losses are added to all of its
# child layers too, so that `.losses` only returns eager losses.
REVIVED_LOSS_PLACEHOLDER = (
    "This layer's losses have been added to the parent layer.")

Functions

def call_context()

Returns currently active CallContext.

Expand source code
def call_context():
  """Returns the currently active `CallContext`, creating it on first use.

  The context is stored as the `call_context` attribute of the module-level
  `_call_context` object; a new `CallContext` is created and stored when that
  attribute is missing or `None`.
  """
  active = getattr(_call_context, 'call_context', None)
  if active is not None:
    return active
  active = CallContext()
  _call_context.call_context = active
  return active
def check_graph_consistency(tensor=None, method='add_loss', force_raise=False)

Checks that tensors passed to add_* method match the Keras graph.

When one of the add_* methods is called inside a V2 conditional branch, the underlying tensor gets created in a FuncGraph managed by control_flow_v2. We need to raise clear error messages in such cases.

Args

tensor
Tensor to check, or False if it is known that an error should be raised.
method
Caller method, one of {'add_metric', 'add_loss', 'add_update', 'activity_regularizer'}.
force_raise
If an error should be raised regardless of tensor.

Raises

RuntimeError
In case of an out-of-graph tensor.
Expand source code
def check_graph_consistency(tensor=None, method='add_loss', force_raise=False):
  """Checks that tensors passed to `add_*` method match the Keras graph.

  When one of the `add_*` methods is called inside a V2 conditional branch,
  the underlying tensor gets created in a FuncGraph managed by control_flow_v2.
  We need to raise clear error messages in such cases.

  Nothing happens (the function returns `None`) when the check passes.

  Args:
    tensor: Tensor to check, or `False` if it is known that an error
      should be raised.
    method: Caller method, one of {'add_metric', 'add_loss', 'add_update',
      'activity_regularizer'}. Any value other than 'add_metric', 'add_loss'
      and 'activity_regularizer' is reported with the `add_update` examples.
    force_raise: If an error should be raised regardless of `tensor`.

  Raises:
    RuntimeError: In case of an out-of-graph tensor.
  """
  # `force_raise` short-circuits the condition, so `tensor` is only inspected
  # when no error is being forced.
  if (force_raise or
      (tf.compat.v1.executing_eagerly_outside_functions() and
       hasattr(tensor, 'graph') and tensor.graph.is_control_flow_graph)):
    # `activity_regularizer` has its own message (model-level examples) and
    # raises right away; the remaining cases share the message at the bottom.
    if method == 'activity_regularizer':
      bad_example = """
      class TestModel(tf.keras.Model):

        def __init__(self):
          super(TestModel, self).__init__(name='test_model')
          self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')

        def call(self, x, training=None):
          if training:
            return self.dense(x)
          else:
            return self.dense(x)
      """
      correct_example = """
      class TestModel(tf.keras.Model):

        def __init__(self):
          super(TestModel, self).__init__(name='test_model')
          self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')

        def call(self, x, training=None):
          return self.dense(x)
      """
      raise RuntimeError(
          'You are using a layer with `activity_regularizer` in a control flow '
          'branch, e.g.:\n{bad_example}\nThis is currently not supported. '
          'Please move your call to the layer with `activity_regularizer` out '
          'of the control flow branch, e.g.:\n{correct_example}\n'
          'You can also resolve this by marking your outer model/layer dynamic'
          ' (eager-only) by passing `dynamic=True` to the layer constructor. '
          'Any kind of control flow is supported with dynamic layers. '
          'Note that using `dynamic=True` requires you to implement static '
          'shape inference in the `compute_output_shape(input_shape)` '
          'method.'.format(
              bad_example=bad_example, correct_example=correct_example))

    # Pick the code samples that are embedded in the shared error message.
    if method == 'add_metric':
      bad_example = """
      def call(self, inputs, training=None):
        if training:
          metric = compute_metric(inputs)
          self.add_metric(metric, name='my_metric', aggregation='mean')
        return inputs
      """
      correct_example = """
      def call(self, inputs, training=None):
        if training:
          metric = compute_metric(inputs)
        else:
          metric = 0.
        self.add_metric(metric, name='my_metric', aggregation='mean')
        return inputs
      """
    elif method == 'add_loss':
      bad_example = """
      def call(self, inputs, training=None):
        if training:
          loss = compute_loss(inputs)
          self.add_loss(loss)
        return inputs
      """
      correct_example = """
      def call(self, inputs, training=None):
        if training:
          loss = compute_loss(inputs)
        else:
          loss = 0.
        self.add_loss(loss)
        return inputs
      """
    else:
      # Fallback for every other `method` value: `add_update` examples.
      bad_example = """
      def call(self, inputs, training=None):
        if training:
          self.add_update(self.w.assign_add(1))
        return inputs
      """
      correct_example = """
      def call(self, inputs, training=None):
        if training:
          increment = 1
        else:
          increment = 0
        self.add_update(self.w.assign_add(increment))
        return inputs
      """
    # Adjacent string literals are joined before `.format` is applied, so the
    # placeholders in every fragment below are substituted.
    raise RuntimeError(
        'You are using the method `{method}` in a control flow branch '
        'in your layer, e.g.:\n{bad_example}\n'
        'This is not currently supported. '
        'Please move your call to {method} out of the control flow branch, '
        'e.g.:\n{correct_example}\n'
        'You can also resolve this by marking your layer '
        'as dynamic (eager-only) by passing '
        '`dynamic=True` to the layer constructor. '
        'Any kind of control flow is supported with dynamic layers. '
        'Note that using `dynamic=True` requires you '
        'to implement static shape inference '
        'in the `compute_output_shape(input_shape)` method.'.format(
            method=method,
            bad_example=bad_example,
            correct_example=correct_example))
def collect_previous_mask(input_tensors)

Retrieves the output mask(s) of the previous node.

Args

input_tensors
An arbitrary structure of Tensors.

Returns

A mask tensor or list of mask tensors.

Expand source code
def collect_previous_mask(input_tensors):
  """Gathers the `_keras_mask` attached to each input tensor.

  Args:
      input_tensors: An arbitrary structure of Tensors.

  Returns:
      A structure shaped like `input_tensors` in which every tensor is
      replaced by its `_keras_mask` attribute, or by `None` when the tensor
      carries no such attribute.
  """
  return tf.nest.map_structure(
      lambda tensor: getattr(tensor, '_keras_mask', None), input_tensors)
def create_keras_history(tensors)

Wraps TensorFlow Operations for compatibility with the Functional API.

This method checks to see if a Tensor in tensors is missing Keras metadata and has its origin in a Keras Input Layer. If so, this method will replace the raw TensorFlow Operations that created this tensor with TensorFlowOpLayer instances that create identical operations.

Any Tensors not originating from a Keras Input Layer will be treated as constants when constructing TensorFlowOpLayer instances.

Args

tensors
A structure of Tensors, some of which come from raw TensorFlow operations and need to have Keras metadata assigned to them.

Returns

created_layers
List. The TensorFlowOpLayer instances created to wrap the raw Tensorflow operations.
Expand source code
def create_keras_history(tensors):
  """Wraps raw TensorFlow operations so they work with the Functional API.

  For every Tensor in `tensors` that lacks Keras metadata but originates from
  a Keras `Input` layer, the raw TensorFlow operations that produced it are
  replaced by `TensorFlowOpLayer` instances creating identical operations.

  Tensors that do not originate from a Keras `Input` layer are treated as
  constants when the `TensorFlowOpLayer` instances are built.

  Args:
    tensors: A structure of Tensors, some of which come from raw TensorFlow
      operations and need to have Keras metadata assigned to them.

  Returns:
    created_layers: List. The `TensorFlowOpLayer` instances created to wrap
      the raw Tensorflow operations.
  """
  # The helper returns a pair; only its second element, the list of created
  # layers, is handed back to the caller.
  helper_result = _create_keras_history_helper(tensors, set(), [])
  _, new_layers = helper_result
  return new_layers
def create_mean_metric(value, name=None)
Expand source code
def create_mean_metric(value, name=None):
  """Creates a `Mean` metric for `value` and calls it once on `value`.

  Args:
    value: Value to aggregate; its `dtype` is used as the metric's dtype.
    name: Optional name passed to the metric.

  Returns:
    A `(metric, result)` tuple: the new metric object and the result of
    calling it on `value`.
  """
  # Imported lazily: importing keras pulls in base_layer and then this module,
  # while metrics itself depends on base_layer, so a top-level import would be
  # cyclic.
  from keras import metrics as metrics_module  # pylint: disable=g-import-not-at-top
  mean_metric = metrics_module.Mean(name=name, dtype=value.dtype)
  result = mean_metric(value)
  return mean_metric, result
def disable_v2_dtype_behavior()

Disables the V2 dtype behavior for Keras layers.

See tf.compat.v1.keras.layers.enable_v2_dtype_behavior.

Expand source code
@keras_export(v1=['keras.layers.disable_v2_dtype_behavior'])
def disable_v2_dtype_behavior():
  """Disables the V2 dtype behavior for Keras layers.

  See `tf.compat.v1.keras.layers.enable_v2_dtype_behavior`.
  """
  # Rebinds the module-level flag; `v2_dtype_behavior_enabled` returns this
  # value whenever it is not `None`.
  global V2_DTYPE_BEHAVIOR
  V2_DTYPE_BEHAVIOR = False
def enable_v2_dtype_behavior()

Enable the V2 dtype behavior for Keras layers.

By default, the V2 dtype behavior is enabled in TensorFlow 2, so this function is only useful if tf.compat.v1.disable_v2_behavior has been called. Since mixed precision requires V2 dtype behavior to be enabled, this function allows you to use mixed precision in Keras layers if disable_v2_behavior has been called.

When enabled, the dtype of Keras layers defaults to floatx (which is typically float32) instead of None. In addition, layers will automatically cast floating-point inputs to the layer's dtype.

>>> x = tf.ones((4, 4, 4, 4), dtype='float64')
>>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
>>> print(layer.dtype)  # float32 since V2 dtype behavior is enabled
float32
>>> y = layer(x)  # Layer casts inputs since V2 dtype behavior is enabled
>>> print(y.dtype.name)
float32

A layer author can opt-out their layer from the automatic input casting by passing autocast=False to the base Layer's constructor. This disables the autocasting part of the V2 behavior for that layer, but not the defaulting to floatx part of the V2 behavior.

When a global tf.keras.mixed_precision.Policy is set, a Keras layer's dtype will default to the global policy instead of floatx. Layers will automatically cast inputs to the policy's compute_dtype.

Expand source code
@keras_export(v1=['keras.layers.enable_v2_dtype_behavior'])
def enable_v2_dtype_behavior():
  """Enable the V2 dtype behavior for Keras layers.

  By default, the V2 dtype behavior is enabled in TensorFlow 2, so this function
  is only useful if `tf.compat.v1.disable_v2_behavior` has been called. Since
  mixed precision requires V2 dtype behavior to be enabled, this function allows
  you to use mixed precision in Keras layers if `disable_v2_behavior` has been
  called.

  When enabled, the dtype of Keras layers defaults to floatx (which is typically
  float32) instead of None. In addition, layers will automatically cast
  floating-point inputs to the layer's dtype.

  >>> x = tf.ones((4, 4, 4, 4), dtype='float64')
  >>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
  >>> print(layer.dtype)  # float32 since V2 dtype behavior is enabled
  float32
  >>> y = layer(x)  # Layer casts inputs since V2 dtype behavior is enabled
  >>> print(y.dtype.name)
  float32

  A layer author can opt their layer out of the automatic input casting by
  passing `autocast=False` to the base Layer's constructor. This disables the
  autocasting part of the V2 behavior for that layer, but not the defaulting to
  floatx part of the V2 behavior.

  When a global `tf.keras.mixed_precision.Policy` is set, a Keras layer's dtype
  will default to the global policy instead of floatx. Layers will automatically
  cast inputs to the policy's compute_dtype.
  """
  # Rebinds the module-level flag; `v2_dtype_behavior_enabled` returns this
  # value whenever it is not `None`.
  global V2_DTYPE_BEHAVIOR
  V2_DTYPE_BEHAVIOR = True
def from_saved_model(layer)

Returns whether the layer is loaded from a SavedModel.

Expand source code
def from_saved_model(layer):
  """Returns whether the layer is loaded from a SavedModel.

  The decision is based solely on whether the layer's `__module__` contains
  the substring 'keras.saving.saved_model'.
  """
  return 'keras.saving.saved_model' in layer.__module__
def generate_placeholders_from_shape(shape)
Expand source code
def generate_placeholders_from_shape(shape):
  """Returns a v1 placeholder of the given `shape` with dtype `backend.floatx()`."""
  return tf.compat.v1.placeholder(shape=shape, dtype=backend.floatx())
def has_weights(obj)
Expand source code
def has_weights(obj):
  """Returns True if `obj` is an instance whose type exposes weight attributes.

  The type of `obj` must define both `trainable_weights` and
  `non_trainable_weights`, and `obj` itself must not be a class.
  """
  cls = type(obj)
  required_attrs = ('trainable_weights', 'non_trainable_weights')
  exposes_weights = all(hasattr(cls, attr) for attr in required_attrs)
  return exposes_weights and not isinstance(obj, type)
def have_all_keras_metadata(tensors)
Expand source code
def have_all_keras_metadata(tensors):
  """Returns True if every flattened element of `tensors` has `_keras_history`.

  Args:
    tensors: An arbitrary nested structure; it is flattened before checking.

  Returns:
    False as soon as one element lacks a `_keras_history` attribute, True
    otherwise (including when the flattened structure is empty).
  """
  for tensor in tf.nest.flatten(tensors):
    if not hasattr(tensor, '_keras_history'):
      return False
  return True
def is_in_eager_or_tf_function()

Returns if in eager mode or inside of a tf.function.

Expand source code
def is_in_eager_or_tf_function():
  """Returns if in eager mode or inside of a tf.function."""
  eager = tf.executing_eagerly()
  if eager:
    return eager
  # Not eager: the answer depends only on being inside a tf.function.
  return is_in_tf_function()
def is_in_keras_graph()

Returns if currently executing inside of a Keras graph.

Expand source code
def is_in_keras_graph():
  """Returns if currently executing inside of a Keras graph.

  Delegates to the `in_keras_graph` property of the active `CallContext`.
  """
  return call_context().in_keras_graph
def is_in_tf_function()

Returns if inside of a tf.function.

Expand source code
def is_in_tf_function():
  """Returns if inside of a tf.function.

  The answer is False in V1 graph mode, outside of any function, inside the
  Keras FuncGraph, and inside a v1 `wrap_function` FuncGraph (recognised by a
  graph name starting with 'wrapped_function').
  """
  # V1 graph mode, or not inside any function at all.
  if not (tf.compat.v1.executing_eagerly_outside_functions() and
          tf.inside_function()):
    return False
  # Inside the Keras FuncGraph.
  if is_in_keras_graph():
    return False
  # A v1 `wrap_function` FuncGraph does not count as a tf.function.
  graph_name = getattr(tf.compat.v1.get_default_graph(), 'name', False)
  return not (graph_name and graph_name.startswith('wrapped_function'))
def is_split_variable(v)

Returns True if v is either a PartitionedVariable or a ShardedVariable.

Expand source code
def is_split_variable(v):
  """Returns True if `v` is a PartitionedVariable or a ShardedVariable.

  Detection is duck-typed: `v` counts as split when it exposes either a
  `_variable_list` or a `_variables` attribute.
  """
  split_markers = ('_variable_list', '_variables')
  return any(hasattr(v, marker) for marker in split_markers)
def is_subclassed(layer)

Returns True if the object is a subclassed layer or subclassed model.

Expand source code
def is_subclassed(layer):
  """Returns True if the object is a subclassed layer or subclassed model.

  An object counts as subclassed when its `__module__` contains neither
  'keras.engine' nor 'keras.layers'.
  """
  builtin_packages = ('keras.engine', 'keras.layers')
  return not any(package in layer.__module__ for package in builtin_packages)
def make_variable(name, shape=None, dtype=tf.float32, initializer=None, trainable=None, caching_device=None, validate_shape=True, constraint=None, use_resource=None, collections=None, synchronization=VariableSynchronization.AUTO, aggregation=VariableAggregationV2.NONE, partitioner=None)

Temporary util to create a variable (relies on variable_scope.variable).

Some reuse-related technicalities prevent us from using variable_scope.get_variable() directly, so we use a subcomponent that has fewer constraints (variable_scope.variable()).

In the longer term, it seems like a similar "default variable creator" method should exist in Trackable instead. When this happens, we can get rid of this temporary solution.

TODO(fchollet): remove this method when no longer needed.

Args

name
Variable name.
shape
Variable shape.
dtype
The type of the variable. Defaults to self.dtype or float32.
initializer
Initializer instance (callable).
trainable
Whether the variable should be part of the layer's "trainable_variables" (e.g. variables, biases) or "non_trainable_variables" (e.g. BatchNorm mean, stddev). Note, if the current variable scope is marked as non-trainable then this parameter is ignored and any added variables are also marked as non-trainable. trainable defaults to True unless synchronization is set to ON_READ.
caching_device
Passed to tf.Variable.
validate_shape
Passed to tf.Variable.
constraint
Constraint instance (callable).
use_resource
Whether to use a ResourceVariable.
collections
List of graph collections keys. The new variable is added to these collections. Defaults to [GraphKeys.GLOBAL_VARIABLES].
synchronization
Indicates when a distributed variable will be aggregated. Accepted values are constants defined in the class tf.VariableSynchronization. By default the synchronization is set to AUTO and the current DistributionStrategy chooses when to synchronize. If synchronization is set to ON_READ, trainable must not be set to True.
aggregation
Indicates how a distributed variable will be aggregated. Accepted values are constants defined in the class tf.VariableAggregation.
partitioner
Not handled at this time.

Returns

Variable instance.

Expand source code
def make_variable(name,
                  shape=None,
                  dtype=tf.float32,
                  initializer=None,
                  trainable=None,
                  caching_device=None,
                  validate_shape=True,
                  constraint=None,
                  use_resource=None,
                  collections=None,
                  synchronization=tf.VariableSynchronization.AUTO,
                  aggregation=tf.VariableAggregation.NONE,
                  partitioner=None):  # pylint: disable=unused-argument
  """Temporary util to create a variable (relies on `variable_scope.variable`).

  Reuse-related technicalities prevent calling
  `variable_scope.get_variable()` directly, so a subcomponent with fewer
  constraints (`variable_scope.variable()`) is used instead.

  In the longer term a similar "default variable creator" method should live
  in `Trackable`; once it does, this temporary solution can be removed.

  TODO(fchollet): remove this method when no longer needed.

  Args:
    name: Variable name.
    shape: Variable shape.
    dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
    initializer: Initializer instance (callable), an initializer class (it is
      instantiated with no arguments), or a non-callable value that is used
      directly as the initial value.
    trainable: Whether the variable should be part of the layer's
      "trainable_variables" (e.g. variables, biases)
      or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
      Note, if the current variable scope is marked as non-trainable
      then this parameter is ignored and any added variables are also
      marked as non-trainable. `trainable` defaults to `True` unless
      `synchronization` is set to `ON_READ`.
    caching_device: Passed to `tf.Variable`.
    validate_shape: Passed to `tf.Variable`.
    constraint: Constraint instance (callable).
    use_resource: Whether to use a `ResourceVariable`. `None` is treated as
      `True`.
    collections: List of graph collections keys. The new variable is added to
      these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    synchronization: Indicates when a distributed variable will be
      aggregated. Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses
      when to synchronize. If `synchronization` is set to `ON_READ`,
      `trainable` must not be set to `True`.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.
    partitioner: Not handled at this time.

  Returns:
    Variable instance.
  """
  if initializer is None or callable(initializer):
    # Instantiate initializer if provided initializer is a type object.
    if tf_inspect.isclass(initializer):
      initializer = initializer()
    # Defer the call: the variable invokes this partial to get its value.
    init_val = functools.partial(initializer, shape, dtype=dtype)
    variable_dtype = dtype.base_dtype
  else:
    # A concrete, non-callable value is used as the initial value directly
    # and no explicit dtype is passed on.
    init_val = initializer
    variable_dtype = None

  variable_shape = tf.TensorShape(shape)

  # In theory, if `use_resource` is True and `collections` is empty
  # (that is to say, in TF2), we can use tf.Variable.
  # However, this breaks legacy (Estimator) checkpoints
  # because it changes variable names. Remove this when V1 is fully deprecated.
  return tf.compat.v1.Variable(
      initial_value=init_val,
      name=name,
      trainable=trainable,
      caching_device=caching_device,
      dtype=variable_dtype,
      validate_shape=validate_shape,
      constraint=constraint,
      use_resource=True if use_resource is None else use_resource,
      collections=collections,
      synchronization=synchronization,
      aggregation=aggregation,
      shape=variable_shape or None)
def mark_as_return(outputs, acd)

Marks outputs as the return values for automatic control deps.

Expand source code
def mark_as_return(outputs, acd):
  """Marks `outputs` as the return values for automatic control deps.

  Args:
    outputs: An arbitrary nested structure; non-tensor leaves are passed
      through unchanged.
    acd: Object whose `mark_as_return(tensor)` method is applied to every
      tensor leaf (and to its `_keras_mask`, when one is set).

  Returns:
    A structure shaped like `outputs` holding the marked tensors.
  """

  def _mark_leaf(leaf):
    """Marks a single leaf and carries its attached metadata over."""
    if not tf.is_tensor(leaf):
      return leaf

    # pylint: disable=protected-access
    marked = acd.mark_as_return(leaf)

    # The mask is marked as well; a missing or `None` mask stays `None`.
    mask = getattr(leaf, '_keras_mask', None)
    marked._keras_mask = None if mask is None else acd.mark_as_return(mask)

    # Handle TensorFlow Probability attached metadata.
    # TODO(b/132076537): Remove this once TFP uses `CompositeTensor`.
    distribution = getattr(leaf, '_tfp_distribution', None)
    if distribution is not None:
      marked._tfp_distribution = distribution
    # pylint: enable=protected-access

    return marked

  return tf.nest.map_structure(_mark_leaf, outputs)
def mark_checked(tensors)

Marks that these Tensors should not be tracked.

This prevents Layers from attempting to create TensorFlowOpLayers for these Tensors.

Args

tensors
An arbitrary structure of Tensors.
Expand source code
def mark_checked(tensors):
  """Flags every Tensor in `tensors` as already checked for Keras history.

  Each tensor gets `_keras_history_checked = True`, which prevents Layers
  from attempting to create TensorFlowOpLayers for these Tensors.

  Args:
    tensors: An arbitrary structure of Tensors.
  """

  def _flag(tensor):
    setattr(tensor, '_keras_history_checked', True)

  tf.nest.map_structure(_flag, tensors)
def needs_keras_history(tensors, ignore_call_context=False)

Check if any Tensors need to be wrapped in TensorFlowOpLayers.

This will never return True inside a sublayer, because sublayers do not need to create Keras History. Otherwise, this returns True if one or more of tensors originates from a keras.Input and does not have _keras_history set.

Args

tensors
An arbitrary nested structure of Tensors.
ignore_call_context
Whether to ignore the check of if currently outside of a call context. This is True when creating KerasHistory inside Node, where we always know that Tensors are being used with the Functional API.

Returns

Bool, whether at least one Tensor needs to be wrapped.

Expand source code
def needs_keras_history(tensors, ignore_call_context=False):
  """Check if any Tensors need to be wrapped in TensorFlowOpLayers.

  Inside a sublayer this never returns True, because sublayers do not need
  to create Keras History. Otherwise it returns True when at least one of
  `tensors` originates from a `keras.Input` and does not have
  `_keras_history` set.

  Args:
    tensors: An arbitrary nested structure of Tensors.
    ignore_call_context: Whether to skip the "currently inside a `call`"
      shortcut. This is `True` when creating KerasHistory inside `Node`,
      where we always know that Tensors are being used with the Functional
      API.

  Returns:
    Bool, whether at least one Tensor needs to be wrapped.
  """
  flat_tensors = tf.nest.flatten(tensors)
  inside_call = call_context().in_call
  if inside_call and not ignore_call_context:
    return False

  lacks_history = any(
      getattr(tensor, '_keras_history', None) is None
      for tensor in flat_tensors)
  if not lacks_history:
    # KerasHistory already set.
    return False
  return uses_keras_history(tensors)
def no_ragged_support(inputs, layer_name)
Expand source code
def no_ragged_support(inputs, layer_name):
  """Raises if any tensor in `inputs` is a `tf.RaggedTensor`.

  Args:
    inputs: An arbitrary nested structure; it is flattened before checking.
    layer_name: Name of the layer, used only in the error message.

  Raises:
    ValueError: If at least one flattened input is a `tf.RaggedTensor`.
  """
  for candidate in tf.nest.flatten(inputs):
    if isinstance(candidate, tf.RaggedTensor):
      raise ValueError('Layer %s does not support RaggedTensors as input. '
                       'Inputs received: %s. You can try converting your '
                       'input to an uniform tensor.' % (layer_name, inputs))
def training_arg_passed_to_call(argspec, args, kwargs)

Returns whether a user passed the training argument in __call__.

Expand source code
def training_arg_passed_to_call(argspec, args, kwargs):
  """Returns whether a user passed the `training` argument in `__call__`.

  Args:
    argspec: Object whose `args` attribute lists the parameter names of
      `call`; the first two names are skipped.
    args: Positional arguments, matched in order against the remaining names.
    kwargs: Keyword arguments; they take precedence over positional matches.

  Returns:
    True if `training` was supplied with a value other than `None`.
  """
  # `argspec.args` starts with ['self', 'inputs']
  positional_names = argspec.args[2:]
  bound = dict(zip(positional_names, args))
  bound.update(kwargs)
  return bound.get('training') is not None
def unnest_if_single_tensor(input_tensors)
Expand source code
def unnest_if_single_tensor(input_tensors):
  """Unwraps a non-dict structure that holds exactly one element.

  Args:
    input_tensors: An arbitrary nested structure.

  Returns:
    The single flattened element when `input_tensors` is not a dict and
    flattens to exactly one element; otherwise `input_tensors` unchanged.
  """
  # Preserve compatibility with older configs
  flattened = tf.nest.flatten(input_tensors)
  # A dict is passed through untouched: the first layer is assumed to expect
  # a dict (as is the case with a DenseFeatures layer).
  if isinstance(input_tensors, dict) or len(flattened) != 1:
    return input_tensors
  return flattened[0]
def uses_keras_history(tensors)

Check if at least one Tensor originates from a keras.Input.

This is True if at least one Tensor has its origin in a keras.Input. Any Tensor that originates from a keras.Input will have a dependency Tensor with a _keras_history attribute attached. Tensors that have already been checked to not originate from a keras.Input are marked as _keras_history_checked.

Args

tensors
An arbitrary nested structure of Tensors.

Returns

Bool, whether at least one Tensor originates from a keras.Input.

Expand source code
def uses_keras_history(tensors):
  """Check if at least one Tensor originates from a `keras.Input`.

  This is `True` if at least one Tensor has its origin in a `keras.Input`.
  Any Tensor that originates from a `keras.Input` will have a dependency
  Tensor with a `_keras_history` attribute attached. Tensors that have
  already been checked to not originate from a `keras.Input`
  are marked as `_keras_history_checked`.

  The search walks level by level from `tensors` through `tensor.op.inputs`,
  visiting each object at most once.

  Args:
    tensors: An arbitrary nested structure of Tensors.

  Returns:
    Bool, whether at least one Tensor originates from a `keras.Input`.
  """
  seen_ids = set()
  frontier = tf.nest.flatten(tensors)

  while frontier:
    next_frontier = []
    for candidate in frontier:
      candidate_id = id(candidate)
      if candidate_id in seen_ids:
        continue
      seen_ids.add(candidate_id)

      # Already known not to lead back to a `keras.Input`.
      if getattr(candidate, '_keras_history_checked', None) is not None:
        continue
      if getattr(candidate, '_keras_history', None) is not None:
        return True

      try:
        next_frontier.extend(candidate.op.inputs)
      except AttributeError:
        # In case `candidate` is a Variable created in an Eager context.
        pass

    frontier = next_frontier

  # Mark that these Tensors have been checked once for `_keras_history`,
  # and should not be checked again for performance reasons.
  mark_checked(tensors)
  return False
def v2_dtype_behavior_enabled()

Returns True if the V2 dtype behavior is enabled.

Expand source code
def v2_dtype_behavior_enabled():
  """Returns True if the V2 dtype behavior is enabled.

  An explicitly set module-level `V2_DTYPE_BEHAVIOR` wins; while it is still
  `None`, the answer is whatever `tf.__internal__.tf2.enabled()` reports.
  """
  if V2_DTYPE_BEHAVIOR is not None:
    return V2_DTYPE_BEHAVIOR
  return tf.__internal__.tf2.enabled()

Classes

class CallContext

Keeps track of properties currently inside a Layer/Model's call.

Attributes

in_call
Whether currently inside the call of a Layer.
layer
The Layer whose call is currently active.
inputs
The inputs to the currently active Layer.
build_graph
Whether currently inside a Graph or FuncGraph.
training
Whether currently executing in training or inference mode.
saving
Whether currently saving to SavedModel.
frozen
Whether currently executing inside a Layer with trainable set to False.
in_keras_graph
Whether executing inside the Keras Graph.
Expand source code
class CallContext(object):
  """Tracks the properties of the Layer/Model `call` currently in progress.

  Attributes:
    in_call: Whether currently inside the `call` of a Layer.
    layer: The `Layer` whose `call` is currently active.
    inputs: The inputs to the currently active `Layer`.
    build_graph: Whether currently inside a Graph or FuncGraph.
    training: Whether currently executing in training or inference mode.
    saving: Whether currently saving to SavedModel.
    frozen: Whether currently executing inside a `Layer` with `trainable` set to
      `False`.
    in_keras_graph: Whether executing inside the Keras Graph.
  """

  def __init__(self):
    # `in_call` is the most-read attribute and sits on the hot path, so it is
    # kept as a plain attribute instead of living in `_state`.
    self.in_call = False
    self._state = dict(
        layer=None,
        inputs=None,
        build_graph=False,
        training=None,
        saving=None)
    # TODO(b/150169018): This logic can be replaced after the Functional API
    # refactor.
    self._in_keras_graph = False

  def enter(self, layer, inputs, build_graph, training, saving=None):
    """Push a Layer and its inputs and state onto the current call context.

    Args:
      layer: The `Layer` whose `call` is currently active.
      inputs: The inputs to the currently active `Layer`.
      build_graph: Whether currently inside a Graph or FuncGraph.
      training: Whether currently executing in training or inference mode.
      saving: Whether currently saving to SavedModel.

    Returns:
      Context manager.
    """
    pending_state = dict(
        layer=layer,
        inputs=inputs,
        build_graph=build_graph,
        training=training,
        saving=saving)
    return CallContextManager(self, pending_state)

  @property
  def layer(self):
    """The `Layer` whose `call` is currently active."""
    return self._state['layer']

  @property
  def inputs(self):
    """The inputs to the currently active `Layer`."""
    return self._state['inputs']

  @property
  def build_graph(self):
    """Whether currently inside a Graph or FuncGraph."""
    return self._state['build_graph']

  @property
  def training(self):
    """Whether currently executing in training or inference mode."""
    return self._state['training']

  @property
  def saving(self):
    """Whether currently saving to SavedModel."""
    return self._state['saving']

  @property
  def frozen(self):
    """Whether the active layer is non-trainable; False without a layer."""
    active_layer = self._state['layer']
    return False if not active_layer else not active_layer.trainable

  @property
  def in_keras_graph(self):
    """Whether executing inside the Keras Graph; always False when eager.

    Returns True even if in a subgraph of the Keras graph, such as those
    created by control flow ops.
    """
    if tf.executing_eagerly():
      return False
    flagged = self._in_keras_graph
    if flagged:
      return flagged
    return getattr(backend.get_graph(), 'name', None) == 'keras_graph'

Instance variables

var build_graph
Expand source code
@property
def build_graph(self):
  """Whether currently inside a Graph or FuncGraph (read from `_state`)."""
  return self._state['build_graph']
var frozen
Expand source code
@property
def frozen(self):
  """Whether the active layer is non-trainable; False without a layer."""
  active_layer = self._state['layer']
  return False if not active_layer else not active_layer.trainable
var in_keras_graph
Expand source code
@property
def in_keras_graph(self):
  """Whether executing inside the Keras Graph; always False when eager.

  Returns True even if in a subgraph of the Keras graph, such as those
  created by control flow ops.
  """
  if tf.executing_eagerly():
    return False
  flagged = self._in_keras_graph
  if flagged:
    return flagged
  return getattr(backend.get_graph(), 'name', None) == 'keras_graph'
var inputs
Expand source code
@property
def inputs(self):
  """The inputs to the currently active `Layer` (read from `_state`)."""
  return self._state['inputs']
var layer
Expand source code
@property
def layer(self):
  """The `Layer` whose `call` is currently active (read from `_state`)."""
  return self._state['layer']
var saving
Expand source code
@property
def saving(self):
  """Returns the 'saving' entry of the current call state."""
  current_state = self._state
  return current_state['saving']
var training
Expand source code
@property
def training(self):
  """Returns the 'training' entry of the current call state."""
  current_state = self._state
  return current_state['training']

Methods

def enter(self, layer, inputs, build_graph, training, saving=None)

Push a Layer and its inputs and state onto the current call context.

Args

layer
The Layer whose call is currently active.
inputs
The inputs to the currently active Layer.
build_graph
Whether currently inside a Graph or FuncGraph.
training
Whether currently executing in training or inference mode.
saving
Whether currently saving to SavedModel.

Returns

Context manager.

Expand source code
def enter(self, layer, inputs, build_graph, training, saving=None):
  """Push a Layer and its inputs and state onto the current call context.

  The arguments are bundled into a state dict that is handed, together with
  this call context, to a new `CallContextManager`.

  Args:
    layer: The `Layer` whose `call` is currently active.
    inputs: The inputs to the currently active `Layer`.
    build_graph: Whether currently inside a Graph or FuncGraph.
    training: Whether currently executing in training or inference mode.
    saving: Whether currently saving to SavedModel.

  Returns:
    Context manager.
  """
  call_state = dict(
      layer=layer,
      inputs=inputs,
      build_graph=build_graph,
      training=training,
      saving=saving)
  return CallContextManager(self, call_state)
class CallContextManager (call_ctx, state)

Context manager for CallContext.

Expand source code
class CallContextManager(object):
  """Context manager for `CallContext`.

  On entry it installs `state` on the wrapped call context and marks it as
  being in a call; on exit it restores the values that were in place before
  entry.
  """

  def __init__(self, call_ctx, state):
    """Stores the call context and the state to install on entry.

    Args:
      call_ctx: The call context object whose attributes are swapped.
      state: Dict of call state; its 'build_graph' entry is read here.
    """
    self._call_ctx = call_ctx
    self._state = state
    self._build_graph = state['build_graph']

  def __enter__(self):
    """Saves the context's current values and installs the new state."""
    ctx = self._call_ctx
    # Remember what to restore in `__exit__`, then switch to the new state.
    self._prev_in_call, self._prev_state = ctx.in_call, ctx._state
    ctx.in_call, ctx._state = True, self._state

    if not self._build_graph:
      return

    # TODO(b/150169018): This logic can be removed after the Functional API
    # refactor.
    already_in_keras_graph = ctx._in_keras_graph
    self._prev_in_keras_graph = already_in_keras_graph
    # The `or` keeps the graph lookup from running when the flag is already
    # set.
    ctx._in_keras_graph = (
        already_in_keras_graph or
        getattr(backend.get_graph(), 'name', None) == 'keras_graph')

  def __exit__(self, *exc_info):
    """Restores the values saved by `__enter__`; exceptions propagate."""
    ctx = self._call_ctx
    ctx.in_call, ctx._state = self._prev_in_call, self._prev_state

    # `_prev_in_keras_graph` only exists when `__enter__` took the
    # build-graph path.
    if self._build_graph:
      ctx._in_keras_graph = self._prev_in_keras_graph
class TrackableWeightHandler (trackable)

Keras wrapper for handling tracking.Trackable object saving and restoring.

This class handles Trackables in both V1 and V2 modes, ensuring that they can be saved and restored with the correct data and without adding additional ops on every save.

Attributes

trackable
The trackable to wrap.
num_tensors
The number of tensors that this trackable requires for saving.
Expand source code
class TrackableWeightHandler(object):
  """Keras wrapper for handling tracking.Trackable object saving and restoring.

  This class handles Trackables in both V1 and V2 modes, ensuring that they can
  be saved and restored with the correct data and without adding additional ops
  on every save.

  Attributes:
    trackable: The trackable to wrap.
    num_tensors: The number of tensors that this trackable requires for saving.
  """

  def __init__(self, trackable):
    """Wraps `trackable` and prepares its getter and setter callables.

    Args:
      trackable: The Trackable object to wrap.

    Raises:
      ValueError: If `trackable` is not a Trackable, or if it exposes more
        than one Saveable.
    """
    if not isinstance(trackable, tf.__internal__.tracking.Trackable):
      raise ValueError('%s is not a Trackable object.' % (trackable,))
    self._trackable = trackable
    self._distribute_strategy = tf.distribute.get_strategy()

    # TODO(b/141682913): Figure out why this is private and fix it.
    saveables = trackable._gather_saveables_for_checkpoint().values()  # pylint: disable=protected-access
    saveable_count = len(saveables)

    if saveable_count > 1:
      raise ValueError('Only Trackables with one Saveable are supported. '
                       'The Trackable %s has %d Saveables.' %
                       (trackable, saveable_count))

    if saveable_count == 0:
      # There are no 'Saveables' when we're handed a legacy TF1 table like a
      # StaticHashTable, so getting and setting are both no-ops.
      self._num_tensors = 0
      self._setter = lambda weights: None
      self._getter = lambda: []
      return

    saveable = next(iter(saveables))

    if tf.compat.v1.executing_eagerly_outside_functions():
      # In eager mode the Trackable's saveable() callable must not be resolved
      # until data export time, so only the callable itself is kept. Calling
      # it repeatedly is safe, which is why it is invoked once here to count
      # the tensors this Trackable will produce.
      self._saveable = saveable
      self._num_tensors = len(self._saveable().specs)
      self._setter = lambda weights: self._saveable().restore(weights, None)
      self._getter = lambda: [spec.tensor for spec in self._saveable().specs]
      return

    # In Graph mode the Saveable is evaluated exactly once and the restore
    # graph is cached; otherwise every set_weights() call would add new
    # assignment ops to the graph.
    self._saveable = saveable()
    self._num_tensors = len(self._saveable.specs)
    # One placeholder per saved tensor, matching its dtype and shape. The
    # inner generator keeps tensor access interleaved with placeholder
    # creation.
    self._placeholder_tensors = [
        tf.compat.v1.placeholder(tensor.dtype, tensor.shape)
        for tensor in (spec.tensor for spec in self._saveable.specs)
    ]
    self._assign_op = self._saveable.restore(self._placeholder_tensors, None)
    self._setter = self._set_weights_v1
    self._getter = lambda: [spec.tensor for spec in self._saveable.specs]

  @property
  def num_tensors(self):
    """The number of tensors this handler expects in `set_weights`."""
    tensor_count = self._num_tensors
    return tensor_count

  def set_weights(self, weights):
    """Passes `weights` to the setter chosen at construction time.

    Args:
      weights: Sequence of values; its length must equal `num_tensors`.

    Raises:
      ValueError: If the number of weights differs from `num_tensors`.
    """
    if len(weights) == self._num_tensors:
      self._setter(weights)
      return
    raise ValueError(
        ('Weight handler for trackable %s received the wrong number of ' +
         'weights: expected %s, got %s.') %
        (self._trackable, self._num_tensors, len(weights)))

  def get_tensors(self):
    """Returns whatever the getter chosen at construction time produces."""
    fetch = self._getter
    return fetch()

  def _set_weights_v1(self, weights):
    """Runs the cached assign op, feeding each weight to its placeholder."""
    feed_dict = {
        self._placeholder_tensors[position]: value
        for position, value in enumerate(weights)
    }
    backend.get_session().run(self._assign_op, feed_dict)

Subclasses

Instance variables

var num_tensors
Expand source code
@property
def num_tensors(self):
  """The number of tensors this handler expects in `set_weights`."""
  tensor_count = self._num_tensors
  return tensor_count

Methods

def get_tensors(self)
Expand source code
def get_tensors(self):
  """Returns whatever the handler's stored `_getter` callable produces."""
  fetch = self._getter
  return fetch()
def set_weights(self, weights)
Expand source code
def set_weights(self, weights):
  """Passes `weights` to the handler's stored `_setter` callable.

  Args:
    weights: Sequence of values; its length must equal `self._num_tensors`.

  Raises:
    ValueError: If the number of weights differs from `self._num_tensors`.
  """
  if len(weights) == self._num_tensors:
    self._setter(weights)
    return
  raise ValueError(
      ('Weight handler for trackable %s received the wrong number of ' +
       'weights: expected %s, got %s.') %
      (self._trackable, self._num_tensors, len(weights)))