Module keras.layers.preprocessing.preprocessing_stage

Preprocessing stage.

Source code
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preprocessing stage."""

import tensorflow.compat.v2 as tf
# pylint: disable=g-classes-have-attributes

import numpy as np
from keras.engine import base_preprocessing_layer
from keras.engine import functional
from keras.engine import sequential
from keras.utils import tf_utils


# Sequential methods should take precedence.
class PreprocessingStage(sequential.Sequential,
                         base_preprocessing_layer.PreprocessingLayer):
  """A sequential preprocessing stage.

  This preprocessing stage wraps a list of preprocessing layers into a
  Sequential-like object that enables you to `adapt()` the whole list via
  a single `adapt()` call on the preprocessing stage.

  Args:
    layers: List of layers. Can include layers that aren't preprocessing layers.
    name: String. Optional name for the preprocessing stage object.
  """

  def adapt(self, data, reset_state=True):
    """Adapt the state of the layers of the preprocessing stage to the data.

    Args:
      data: A batched Dataset object, a NumPy array, or an EagerTensor.
        Data to be iterated over to adapt the state of the layers in this
        preprocessing stage.
      reset_state: Whether this call to `adapt` should reset the state of
        the layers in this preprocessing stage.
    """
    if not isinstance(data,
                      (tf.data.Dataset, np.ndarray, tf.__internal__.EagerTensor)):
      raise ValueError(
          '`adapt()` requires a batched Dataset, an EagerTensor, '
          'or a Numpy array as input, '
          'got {}'.format(type(data)))
    if isinstance(data, tf.data.Dataset):
      # Validate the datasets to try and ensure we haven't been passed one with
      # infinite size. That would cause an infinite loop here.
      if tf_utils.dataset_is_infinite(data):
        raise ValueError(
            'The dataset passed to `adapt()` has an infinite number of '
            'elements. Please use dataset.take(...) to make the number '
            'of elements finite.')

    for current_layer_index in range(0, len(self.layers)):
      if not hasattr(self.layers[current_layer_index], 'adapt'):
        # Skip any layer that does not need adapting.
        continue

      def map_fn(x):
        """Maps `PreprocessingStage` inputs to inputs at `current_layer_index`.

        Args:
          x: Batch of inputs seen in entry of the `PreprocessingStage` instance.

        Returns:
          Batch of inputs to be processed by layer
            `self.layers[current_layer_index]`
        """
        if current_layer_index == 0:  # pylint: disable=cell-var-from-loop
          return x
        for i in range(current_layer_index):  # pylint: disable=cell-var-from-loop
          x = self.layers[i](x)
        return x

      if isinstance(data, tf.data.Dataset):
        current_layer_data = data.map(map_fn)
      else:
        current_layer_data = map_fn(data)
      self.layers[current_layer_index].adapt(current_layer_data,
                                             reset_state=reset_state)


# Functional methods should take precedence.
class FunctionalPreprocessingStage(functional.Functional,
                                   base_preprocessing_layer.PreprocessingLayer):
  """A functional preprocessing stage.

  This preprocessing stage wraps a graph of preprocessing layers into a
  Functional-like object that enables you to `adapt()` the whole graph via
  a single `adapt()` call on the preprocessing stage.

  A preprocessing stage is not a complete model, so it cannot be trained with
  `fit()`. However, it is possible to add regular layers that may be trainable
  to a preprocessing stage.

  A functional preprocessing stage is created in the same way as `Functional`
  models. A stage can be instantiated by passing two arguments to
  `__init__`. The first argument is the `keras.Input` Tensors that represent
  the inputs to the stage. The second argument specifies the output
  tensors that represent the outputs of this stage. Both arguments can be a
  nested structure of tensors.

  Example:

  >>> inputs = {'x2': tf.keras.Input(shape=(5,)),
  ...           'x1': tf.keras.Input(shape=(1,))}
  >>> norm_layer = tf.keras.layers.experimental.preprocessing.Normalization()
  >>> y = norm_layer(inputs['x2'])
  >>> y, z = tf.keras.layers.Lambda(lambda x: (x, x))(inputs['x1'])
  >>> outputs = [inputs['x1'], [y, z]]
  >>> stage = FunctionalPreprocessingStage(inputs, outputs)

  Args:
    inputs: An input tensor (must be created via `tf.keras.Input()`), or a list,
      a dict, or a nested structure of input tensors.
    outputs: An output tensor, or a list, a dict or a nested structure of output
      tensors.
    name: String, optional. Name of the preprocessing stage.
  """

  def fit(self, *args, **kwargs):
    raise ValueError(
        'Preprocessing stage is not a complete model, and hence should not be '
        '`fit`. Instead, you may feed data to `adapt` the stage to set '
        'appropriate states of the layers in the stage.')

  def adapt(self, data, reset_state=True):
    """Adapt the state of the layers of the preprocessing stage to the data.

    Args:
      data: A batched Dataset object, a NumPy array, an EagerTensor, or a list,
        dict, or nested structure of NumPy arrays or EagerTensors. The elements
        of a Dataset object must conform to the inputs of the stage. The first
        dimension of NumPy arrays or EagerTensors is understood to be the batch
        dimension. Data is iterated over to adapt the state of the layers in
        this preprocessing stage.
      reset_state: Whether this call to `adapt` should reset the state of the
        layers in this preprocessing stage.

    Examples:

    >>> # For a stage with dict input
    >>> inputs = {'x2': tf.keras.Input(shape=(5,)),
    ...           'x1': tf.keras.Input(shape=(1,))}
    >>> outputs = [inputs['x1'], inputs['x2']]
    >>> stage = FunctionalPreprocessingStage(inputs, outputs)
    >>> ds = tf.data.Dataset.from_tensor_slices({'x1': tf.ones((4, 1)),
    ...                                          'x2': tf.ones((4, 5))})
    >>> sorted(ds.element_spec.items()) # Check element_spec
    [('x1', TensorSpec(shape=(1,), dtype=tf.float32, name=None)),
     ('x2', TensorSpec(shape=(5,), dtype=tf.float32, name=None))]
    >>> stage.adapt(ds)
    >>> data_np = {'x1': np.ones((4, 5)), 'x2': np.ones((4, 1))}
    >>> stage.adapt(data_np)

    """
    if not isinstance(data, tf.data.Dataset):
      data = self._flatten_to_reference_inputs(data)
      if any(not isinstance(datum, (np.ndarray, tf.__internal__.EagerTensor))
             for datum in data):
        raise ValueError(
            '`adapt()` requires a batched Dataset, a list of EagerTensors '
            'or Numpy arrays as input, got {}'.format(type(data)))
      ds_input = [
          tf.data.Dataset.from_tensor_slices(x).batch(1) for x in data
      ]

    if isinstance(data, tf.data.Dataset):
      # Validate the datasets to try and ensure we haven't been passed one with
      # infinite size. That would cause an infinite loop here.
      if tf_utils.dataset_is_infinite(data):
        raise ValueError(
            'The dataset passed to `adapt()` has an infinite number of '
            'elements. Please use dataset.take(...) to make the number '
            'of elements finite.')
      # Unzip dataset object to a list of single input dataset.
      ds_input = _unzip_dataset(data)

    # Dictionary mapping reference tensors to datasets
    ds_dict = {}
    tensor_usage_count = self._tensor_usage_count
    for x, y in zip(self.inputs, ds_input):
      x_id = str(id(x))
      ds_dict[x_id] = [y] * tensor_usage_count[x_id]

    nodes_by_depth = self._nodes_by_depth
    depth_keys = sorted(nodes_by_depth.keys(), reverse=True)

    def build_map_fn(node, args, kwargs):
      if not isinstance(args.element_spec, tuple):

        def map_fn(*x):
          return tf.nest.flatten(node.layer(*x, **kwargs))
      else:

        def map_fn(*x):
          return tf.nest.flatten(node.layer(x, **kwargs))

      return map_fn

    for depth in depth_keys:
      for node in nodes_by_depth[depth]:
        # Input node
        if node.is_input:
          continue

        # Node with input not computed yet
        if any(t_id not in ds_dict for t_id in node.flat_input_ids):
          continue

        args, kwargs = node.map_arguments(ds_dict)
        args = tf.data.Dataset.zip(tf.__internal__.nest.list_to_tuple(*args))

        if node.layer.stateful and hasattr(node.layer, 'adapt'):
          node.layer.adapt(args, reset_state=reset_state)

        map_fn = build_map_fn(node, args, kwargs)
        outputs = args.map(map_fn)
        outputs = _unzip_dataset(outputs)

        # Update ds_dict.
        for x_id, y in zip(node.flat_output_ids, outputs):
          ds_dict[x_id] = [y] * tensor_usage_count[x_id]


def _unzip_dataset(ds):
  """Unzip dataset into a list of single element datasets.

  Args:
    ds: A Dataset object.

  Returns:
    A list of Dataset objects, each corresponding to one element of the
    `element_spec` of the input Dataset object.

  Example:

  >>> ds1 = tf.data.Dataset.from_tensor_slices([1, 2, 3])
  >>> ds2 = tf.data.Dataset.from_tensor_slices([4, 5, 6])
  >>> ds_zipped_tuple = tf.data.Dataset.zip((ds1, ds2))
  >>> ds_unzipped_tuple = _unzip_dataset(ds_zipped_tuple)
  >>> ds_zipped_dict = tf.data.Dataset.zip({'ds1': ds1, 'ds2': ds2})
  >>> ds_unzipped_dict = _unzip_dataset(ds_zipped_dict)

  Both `ds_unzipped_tuple` and `ds_unzipped_dict` then contain two datasets
  that are the same as `ds1` and `ds2`.
  """
  element_count = len(tf.nest.flatten(ds.element_spec))
  ds_unzipped = []
  for i in range(element_count):

    def map_fn(*x, j=i):
      return tf.nest.flatten(x)[j]

    ds_unzipped.append(ds.map(map_fn))
  return ds_unzipped

Classes

class FunctionalPreprocessingStage (inputs, outputs, name=None, trainable=True, **kwargs)

A functional preprocessing stage.

This preprocessing stage wraps a graph of preprocessing layers into a Functional-like object that enables you to adapt() the whole graph via a single adapt() call on the preprocessing stage.

A preprocessing stage is not a complete model, so it cannot be trained with fit(). However, it is possible to add regular layers that may be trainable to a preprocessing stage.

A functional preprocessing stage is created in the same way as Functional models. A stage can be instantiated by passing two arguments to __init__. The first argument is the keras.Input Tensors that represent the inputs to the stage. The second argument specifies the output tensors that represent the outputs of this stage. Both arguments can be a nested structure of tensors.

Example:

>>> inputs = {'x2': tf.keras.Input(shape=(5,)),
...           'x1': tf.keras.Input(shape=(1,))}
>>> norm_layer = tf.keras.layers.experimental.preprocessing.Normalization()
>>> y = norm_layer(inputs['x2'])
>>> y, z = tf.keras.layers.Lambda(lambda x: (x, x))(inputs['x1'])
>>> outputs = [inputs['x1'], [y, z]]
>>> stage = FunctionalPreprocessingStage(inputs, outputs)

Args

inputs
An input tensor (must be created via tf.keras.Input()), or a list, a dict, or a nested structure of input tensors.
outputs
An output tensor, or a list, a dict or a nested structure of output tensors.
name
String, optional. Name of the preprocessing stage.
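
After construction, adapt() sets the state of the adaptable layers, and the stage can then be called like any functional model. The following is a sketch, not part of the original docstring: it assumes the class is imported from keras.layers.preprocessing.preprocessing_stage and reuses the experimental Normalization layer from the example above, on hypothetical data.

>>> import numpy as np
>>> import tensorflow as tf
>>> from keras.layers.preprocessing.preprocessing_stage import (
...     FunctionalPreprocessingStage)
>>> inputs = {'x2': tf.keras.Input(shape=(5,)),
...           'x1': tf.keras.Input(shape=(1,))}
>>> y = tf.keras.layers.experimental.preprocessing.Normalization()(inputs['x2'])
>>> stage = FunctionalPreprocessingStage(inputs, [inputs['x1'], y])
>>> # Hypothetical data matching the input structure; the first dimension
>>> # is the batch dimension.
>>> stage.adapt({'x1': np.ones((8, 1)), 'x2': np.random.rand(8, 5)})
>>> # Once adapted, the stage maps inputs to outputs like a model.
>>> x1_out, y_out = stage({'x1': tf.ones((2, 1)), 'x2': tf.ones((2, 5))})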

Methods

def adapt(self, data, reset_state=True)

Adapt the state of the layers of the preprocessing stage to the data.

Args

data
A batched Dataset object, a NumPy array, an EagerTensor, or a list, dict, or nested structure of NumPy arrays or EagerTensors. The elements of a Dataset object must conform to the inputs of the stage. The first dimension of NumPy arrays or EagerTensors is understood to be the batch dimension. Data is iterated over to adapt the state of the layers in this preprocessing stage.
reset_state
Whether this call to adapt should reset the state of the layers in this preprocessing stage.

Examples:

>>> # For a stage with dict input
>>> inputs = {'x2': tf.keras.Input(shape=(5,)),
...           'x1': tf.keras.Input(shape=(1,))}
>>> outputs = [inputs['x1'], inputs['x2']]
>>> stage = FunctionalPreprocessingStage(inputs, outputs)
>>> ds = tf.data.Dataset.from_tensor_slices({'x1': tf.ones((4, 1)),
...                                          'x2': tf.ones((4, 5))})
>>> sorted(ds.element_spec.items()) # Check element_spec
[('x1', TensorSpec(shape=(1,), dtype=tf.float32, name=None)),
 ('x2', TensorSpec(shape=(5,), dtype=tf.float32, name=None))]
>>> stage.adapt(ds)
>>> data_np = {'x1': np.ones((4, 5)), 'x2': np.ones((4, 1))}
>>> stage.adapt(data_np)

class PreprocessingStage (layers=None, name=None)

A sequential preprocessing stage.

This preprocessing stage wraps a list of preprocessing layers into a Sequential-like object that enables you to adapt() the whole list via a single adapt() call on the preprocessing stage.

Args

layers
List of layers. Can include layers that aren't preprocessing layers.
name
String. Optional name for the preprocessing stage object.

Creates a Sequential model instance.

Args

layers
Optional list of layers to add to the model.
name
Optional name for the model.
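
The original docstring gives no usage example. Below is a minimal sketch under stated assumptions: the class is imported from keras.layers.preprocessing.preprocessing_stage, and the experimental Rescaling and Normalization layers stand in for a real pipeline. It shows that a stage can mix a stateless layer with an adaptable one; adapt() feeds each layer the output of the layers before it.

>>> import numpy as np
>>> import tensorflow as tf
>>> from keras.layers.preprocessing.preprocessing_stage import (
...     PreprocessingStage)
>>> stage = PreprocessingStage([
...     tf.keras.layers.experimental.preprocessing.Rescaling(1. / 255),
...     tf.keras.layers.experimental.preprocessing.Normalization(),
... ])
>>> # Rescaling has no adapt() and is skipped; Normalization is adapted on
>>> # the rescaled data (hypothetical values).
>>> stage.adapt(np.random.rand(100, 3).astype('float32'))
>>> out = stage(np.ones((2, 3), dtype='float32'))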

Methods

def adapt(self, data, reset_state=True)

Adapt the state of the layers of the preprocessing stage to the data.

Args

data
A batched Dataset object, a NumPy array, or an EagerTensor. Data is iterated over to adapt the state of the layers in this preprocessing stage.
reset_state
Whether this call to adapt should reset the state of the layers in this preprocessing stage.
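
As an illustration of the accepted input kinds, a sketch with hypothetical data follows; note that a Dataset must be batched and finite (use dataset.take(...) if needed).

>>> import numpy as np
>>> import tensorflow as tf
>>> norm = tf.keras.layers.experimental.preprocessing.Normalization()
>>> stage = PreprocessingStage([norm])
>>> ds = tf.data.Dataset.from_tensor_slices(
...     np.random.rand(8, 3).astype('float32')).batch(2)
>>> stage.adapt(ds)  # a batched, finite Dataset
>>> stage.adapt(np.ones((4, 3)))  # a NumPy array; state resets by default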