Module keras.layers.rnn_cell_wrapper_v2

Module implementing RNN cell wrappers for TF v2.

Expand source code
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module implementing for RNN wrappers for TF v2."""

# Note that all the APIs under this module are exported as tf.nn.*. This is due
# to the fact that those APIs were from tf.nn.rnn_cell_impl. They are ported
# here to avoid the cyclic dependency issue for serialization. These APIs will
# probably be deprecated and removed in future since similar API is available in
# existing Keras RNN API.


from keras.layers import recurrent
from keras.layers.legacy_rnn import rnn_cell_wrapper_impl
from keras.utils import tf_inspect
from tensorflow.python.util.tf_export import tf_export


class _RNNCellWrapperV2(recurrent.AbstractRNNCell):
  """Base class for cell wrappers with V2 compatibility.

  Together with `rnn_cell_impl._RNNCellWrapperV1`, this class makes it possible
  to define wrappers that are compatible with both V1 and V2, and it defines
  helper methods for that purpose.

  NOTE(review): `call` delegates to `self._call_wrapped_cell`, which is not
  defined in this class; presumably it is supplied by the
  `rnn_cell_wrapper_impl.*WrapperBase` classes listed before this one in the
  subclasses' bases — confirm.
  """

  def __init__(self, cell, *args, **kwargs):
    """Wraps `cell` and records whether its `call` can receive `training`.

    Args:
      cell: The cell to wrap. Its `call` signature is inspected to decide
        whether a `training` argument can be forwarded to it.
      *args: Positional arguments forwarded to the base class constructor.
      **kwargs: Keyword arguments forwarded to the base class constructor.
    """
    super(_RNNCellWrapperV2, self).__init__(*args, **kwargs)
    self.cell = cell
    cell_call_spec = tf_inspect.getfullargspec(cell.call)
    # `training` can be declared as a regular argument, as a keyword-only
    # argument (e.g. `def call(self, inputs, states, *, training=None)`), or be
    # absorbed by `**kwargs`. Looking only at `args` would leave this flag
    # False for the keyword-only form, so keyword-only names are included.
    accepted_arg_names = (
        list(cell_call_spec.args) + list(cell_call_spec.kwonlyargs or []))
    self._expects_training_arg = ("training" in accepted_arg_names) or (
        cell_call_spec.varkw is not None
    )

  def call(self, inputs, state, **kwargs):
    """Runs the RNN cell step computation.

    When `call` is used, the wrapper object is assumed to have been built, so
    the wrapped cell has also been built via its `build` method and its `call`
    method can be used directly.

    This allows the wrapped cell and the non-wrapped cell to be used
    interchangeably when using `call` and `build`.

    Args:
      inputs: A tensor with wrapped cell's input.
      state: A tensor or tuple of tensors with wrapped cell's state.
      **kwargs: Additional arguments passed to the wrapped cell's `call`.

    Returns:
      A pair containing:

      - Output: A tensor with cell's output.
      - New state: A tensor or tuple of tensors with new wrapped cell's state.
    """
    return self._call_wrapped_cell(
        inputs, state, cell_call_fn=self.cell.call, **kwargs)

  def build(self, inputs_shape):
    """Builds the wrapped cell and marks this wrapper as built.

    Args:
      inputs_shape: Input shape forwarded unchanged to `self.cell.build`.
    """
    self.cell.build(inputs_shape)
    self.built = True

  def get_config(self):
    """Returns the base config extended with a serialized `cell` entry.

    The `cell` entry holds the wrapped cell's class name and its own
    `get_config()` output, so `from_config` can rebuild it.
    """
    config = {
        "cell": {
            "class_name": self.cell.__class__.__name__,
            "config": self.cell.get_config()
        },
    }
    base_config = super(_RNNCellWrapperV2, self).get_config()
    # Entries from `config` take precedence over same-named base entries.
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates a wrapper from the output of `get_config`.

    Args:
      config: Dict produced by `get_config`; it is copied, not mutated.
      custom_objects: Optional mapping forwarded to layer deserialization so
        that custom cell classes can be resolved.

    Returns:
      A new instance of `cls` wrapping the deserialized cell.
    """
    config = config.copy()
    # Imported lazily, presumably to avoid a circular import with the
    # serialization module.
    from keras.layers.serialization import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    cell = deserialize_layer(config.pop("cell"), custom_objects=custom_objects)
    return cls(cell, **config)


@tf_export("nn.RNNCellDropoutWrapper", v1=[])
class DropoutWrapper(rnn_cell_wrapper_impl.DropoutWrapperBase,
                     _RNNCellWrapperV2):
  """Operator adding dropout to inputs and outputs of the given cell."""

  def __init__(self, *args, **kwargs):
    # Run the full base-class construction first; `self.cell` is read only
    # afterwards (presumably assigned by `_RNNCellWrapperV2.__init__`).
    super().__init__(*args, **kwargs)
    wrapped_cell = self.cell
    if not isinstance(wrapped_cell, recurrent.LSTMCell):
      return
    # Keras LSTMCell is rejected; it takes its own dropout arguments instead.
    raise ValueError(
        "keras LSTM cell does not work with DropoutWrapper. "
        "Please use LSTMCell(dropout=x, recurrent_dropout=y) instead.")

  # Expose the base-class constructor documentation on this constructor.
  __init__.__doc__ = rnn_cell_wrapper_impl.DropoutWrapperBase.__init__.__doc__


@tf_export("nn.RNNCellResidualWrapper", v1=[])
class ResidualWrapper(rnn_cell_wrapper_impl.ResidualWrapperBase,
                      _RNNCellWrapperV2):
  """RNNCell wrapper that ensures cell inputs are added to the outputs."""

  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    # Pure delegation: the override exists only so the class body has its own
    # `__init__` to which the base-class docstring can be attached below.
    super().__init__(*args, **kwargs)

  # Expose the base-class constructor documentation on this constructor.
  __init__.__doc__ = rnn_cell_wrapper_impl.ResidualWrapperBase.__init__.__doc__


@tf_export("nn.RNNCellDeviceWrapper", v1=[])
class DeviceWrapper(rnn_cell_wrapper_impl.DeviceWrapperBase,
                    _RNNCellWrapperV2):
  """Operator that ensures an RNNCell runs on a particular device."""

  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    # Pure delegation: the override exists only so the class body has its own
    # `__init__` to which the base-class docstring can be attached below.
    super().__init__(*args, **kwargs)

  # Expose the base-class constructor documentation on this constructor.
  __init__.__doc__ = rnn_cell_wrapper_impl.DeviceWrapperBase.__init__.__doc__

Classes

class DeviceWrapper (*args, **kwargs)

Operator that ensures an RNNCell runs on a particular device.

Construct a `DeviceWrapper` for `cell` with device `device`.

Ensures the wrapped cell is called with `tf.device(device)`.

Args

cell
An instance of RNNCell.
device
A device string or function, for passing to tf.device.
**kwargs
dict of keyword arguments for base layer.
Expand source code
class DeviceWrapper(rnn_cell_wrapper_impl.DeviceWrapperBase,
                    _RNNCellWrapperV2):
  """Operator that ensures an RNNCell runs on a particular device."""

  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    # Pure delegation: the override exists only so the class body has its own
    # `__init__` to which the base-class docstring can be attached below.
    super().__init__(*args, **kwargs)

  # Expose the base-class constructor documentation on this constructor.
  __init__.__doc__ = rnn_cell_wrapper_impl.DeviceWrapperBase.__init__.__doc__

Ancestors

Inherited members

class DropoutWrapper (*args, **kwargs)

Operator adding dropout to inputs and outputs of the given cell.

Create a cell with added input, state, and/or output dropout.

If variational_recurrent is set to True (NOT the default behavior), then the same dropout mask is applied at every step, as described in: A Theoretically Grounded Application of Dropout in Recurrent Neural Networks. Y. Gal, Z. Ghahramani.

Otherwise a different dropout mask is applied at every time step.

Note, by default (unless a custom dropout_state_filter is provided), the memory state (c component of any LSTMStateTuple) passing through a DropoutWrapper is never modified. This behavior is described in the above article.

Args

cell
an RNNCell to which input, state, and/or output dropout is added.
input_keep_prob
unit Tensor or float between 0 and 1, input keep probability; if it is constant and 1, no input dropout will be added.
output_keep_prob
unit Tensor or float between 0 and 1, output keep probability; if it is constant and 1, no output dropout will be added.
state_keep_prob
unit Tensor or float between 0 and 1, state keep probability; if it is constant and 1, no state dropout will be added. State dropout is performed on the outgoing states of the cell. Note the state components to which dropout is applied when state_keep_prob is in (0, 1) are also determined by the argument dropout_state_filter_visitor (e.g. by default dropout is never applied to the c component of an LSTMStateTuple).
variational_recurrent
Python bool. If True, then the same dropout pattern is applied across all time steps per run call. If this parameter is set, input_size must be provided.
input_size
(optional) (possibly nested tuple of) TensorShape objects containing the depth(s) of the input tensors expected to be passed in to the DropoutWrapper. Required and used iff variational_recurrent = True and input_keep_prob < 1.
dtype
(optional) The dtype of the input, state, and output tensors. Required and used iff variational_recurrent = True.
seed
(optional) integer, the randomness seed.
dropout_state_filter_visitor
(optional), default: (see below). Function that takes any hierarchical level of the state and returns a scalar or depth=1 structure of Python booleans describing which terms in the state should be dropped out. In addition, if the function returns True, dropout is applied across this sublevel. If the function returns False, dropout is not applied across this entire sublevel. Default behavior: perform dropout on all terms except the memory (c) state of LSTMCellState objects, and don't try to apply dropout to TensorArray objects:

    def dropout_state_filter_visitor(s):
      if isinstance(s, LSTMCellState):
        # Never perform dropout on the c state.
        return LSTMCellState(c=False, h=True)
      elif isinstance(s, TensorArray):
        return False
      return True
**kwargs
dict of keyword arguments for base layer.

Raises

TypeError
if cell is not an RNNCell, or dropout_state_filter_visitor is provided but not callable.
ValueError
if any of the keep_probs are not between 0 and 1, or if cell is a Keras LSTMCell (use LSTMCell(dropout=x, recurrent_dropout=y) instead).
Expand source code
class DropoutWrapper(rnn_cell_wrapper_impl.DropoutWrapperBase,
                     _RNNCellWrapperV2):
  """Operator adding dropout to inputs and outputs of the given cell."""

  def __init__(self, *args, **kwargs):
    # Run the full base-class construction first; `self.cell` is read only
    # afterwards (presumably assigned by `_RNNCellWrapperV2.__init__`).
    super().__init__(*args, **kwargs)
    wrapped_cell = self.cell
    if not isinstance(wrapped_cell, recurrent.LSTMCell):
      return
    # Keras LSTMCell is rejected; it takes its own dropout arguments instead.
    raise ValueError(
        "keras LSTM cell does not work with DropoutWrapper. "
        "Please use LSTMCell(dropout=x, recurrent_dropout=y) instead.")

  # Expose the base-class constructor documentation on this constructor.
  __init__.__doc__ = rnn_cell_wrapper_impl.DropoutWrapperBase.__init__.__doc__

Ancestors

Inherited members

class ResidualWrapper (*args, **kwargs)

RNNCell wrapper that ensures cell inputs are added to the outputs.

Constructs a ResidualWrapper for cell.

Args

cell
An instance of RNNCell.
residual_fn
(Optional) The function to map raw cell inputs and raw cell outputs to the actual cell outputs of the residual network. Defaults to calling nest.map_structure on (lambda i, o: i + o), inputs and outputs.
**kwargs
dict of keyword arguments for base layer.
Expand source code
class ResidualWrapper(rnn_cell_wrapper_impl.ResidualWrapperBase,
                      _RNNCellWrapperV2):
  """RNNCell wrapper that ensures cell inputs are added to the outputs."""

  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    # Pure delegation: the override exists only so the class body has its own
    # `__init__` to which the base-class docstring can be attached below.
    super().__init__(*args, **kwargs)

  # Expose the base-class constructor documentation on this constructor.
  __init__.__doc__ = rnn_cell_wrapper_impl.ResidualWrapperBase.__init__.__doc__

Ancestors

Inherited members