Module keras.layers.rnn_cell_wrapper_v2
Module implementing for RNN wrappers for TF v2.
Expand source code
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module implementing for RNN wrappers for TF v2."""
# Note that all the APIs under this module are exported as tf.nn.*. This is due
# to the fact that those APIs were from tf.nn.rnn_cell_impl. They are ported
# here to avoid the cyclic dependency issue for serialization. These APIs will
# probably be deprecated and removed in future since similar API is available in
# existing Keras RNN API.
from keras.layers import recurrent
from keras.layers.legacy_rnn import rnn_cell_wrapper_impl
from keras.utils import tf_inspect
from tensorflow.python.util.tf_export import tf_export
class _RNNCellWrapperV2(recurrent.AbstractRNNCell):
  """Base class providing V2 compatibility for RNN cell wrappers.

  Together with `rnn_cell_impl._RNNCellWrapperV1`, this class lets wrapper
  cells be defined once and work under both V1 and V2, supplying the shared
  helper methods needed for that purpose.
  """

  def __init__(self, cell, *args, **kwargs):
    super(_RNNCellWrapperV2, self).__init__(*args, **kwargs)
    self.cell = cell
    # Forward a `training` argument only if the wrapped cell can receive one,
    # either explicitly or through a **kwargs catch-all.
    call_argspec = tf_inspect.getfullargspec(cell.call)
    takes_training = "training" in call_argspec.args
    takes_kwargs = call_argspec.varkw is not None
    self._expects_training_arg = takes_training or takes_kwargs

  def call(self, inputs, state, **kwargs):
    """Runs one step of the wrapped RNN cell.

    By the time `call` is invoked the wrapper has been built, which implies
    the inner cell was built through its own `build` method, so the inner
    cell's `call` can be used directly. This keeps wrapped and unwrapped
    cells interchangeable with respect to `call` and `build`.

    Args:
      inputs: A tensor with wrapped cell's input.
      state: A tensor or tuple of tensors with wrapped cell's state.
      **kwargs: Additional arguments passed to the wrapped cell's `call`.

    Returns:
      A pair containing:
      - Output: A tensor with cell's output.
      - New state: A tensor or tuple of tensors with new wrapped cell's state.
    """
    return self._call_wrapped_cell(
        inputs, state, cell_call_fn=self.cell.call, **kwargs)

  def build(self, inputs_shape):
    """Builds the wrapped cell."""
    self.cell.build(inputs_shape)
    self.built = True

  def get_config(self):
    # Serialize the inner cell alongside the wrapper's own base config.
    cell_config = {
        "class_name": self.cell.__class__.__name__,
        "config": self.cell.get_config(),
    }
    base_config = super(_RNNCellWrapperV2, self).get_config()
    return {**base_config, "cell": cell_config}

  @classmethod
  def from_config(cls, config, custom_objects=None):
    config = config.copy()
    from keras.layers.serialization import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    inner_cell = deserialize_layer(
        config.pop("cell"), custom_objects=custom_objects)
    return cls(inner_cell, **config)
@tf_export("nn.RNNCellDropoutWrapper", v1=[])
class DropoutWrapper(rnn_cell_wrapper_impl.DropoutWrapperBase,
                     _RNNCellWrapperV2):
  """Operator adding dropout to inputs and outputs of the given cell."""

  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    super(DropoutWrapper, self).__init__(*args, **kwargs)
    # The Keras LSTMCell handles dropout internally, so wrapping it here is
    # rejected rather than producing conflicting dropout behavior.
    if isinstance(self.cell, recurrent.LSTMCell):
      message = ("keras LSTM cell does not work with DropoutWrapper. "
                 "Please use LSTMCell(dropout=x, recurrent_dropout=y) "
                 "instead.")
      raise ValueError(message)

  __init__.__doc__ = rnn_cell_wrapper_impl.DropoutWrapperBase.__init__.__doc__
@tf_export("nn.RNNCellResidualWrapper", v1=[])
class ResidualWrapper(rnn_cell_wrapper_impl.ResidualWrapperBase,
                      _RNNCellWrapperV2):
  """RNNCell wrapper that ensures cell inputs are added to the outputs."""

  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    # Pure delegation; defined only so the base-class docstring can be
    # attached to this class's __init__ below.
    super(ResidualWrapper, self).__init__(*args, **kwargs)

  __init__.__doc__ = rnn_cell_wrapper_impl.ResidualWrapperBase.__init__.__doc__
@tf_export("nn.RNNCellDeviceWrapper", v1=[])
class DeviceWrapper(rnn_cell_wrapper_impl.DeviceWrapperBase,
                    _RNNCellWrapperV2):
  """Operator that ensures an RNNCell runs on a particular device."""

  def __init__(self, *args, **kwargs):  # pylint: disable=useless-super-delegation
    # Pure delegation; defined only so the base-class docstring can be
    # attached to this class's __init__ below.
    super(DeviceWrapper, self).__init__(*args, **kwargs)

  __init__.__doc__ = rnn_cell_wrapper_impl.DeviceWrapperBase.__init__.__doc__
Classes
class DeviceWrapper (*args, **kwargs)
-
Operator that ensures an RNNCell runs on a particular device.
Construct a `DeviceWrapper` for `cell` with device `device`. Ensures the wrapped `cell` is called with `tf.device(device)`.
Args
cell
- An instance of
RNNCell
. device
- A device string or function, for passing to
tf.device
. **kwargs
- dict of keyword arguments for base layer.
Expand source code
class DeviceWrapper(rnn_cell_wrapper_impl.DeviceWrapperBase, _RNNCellWrapperV2): """Operator that ensures an RNNCell runs on a particular device.""" def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation super(DeviceWrapper, self).__init__(*args, **kwargs) __init__.__doc__ = rnn_cell_wrapper_impl.DeviceWrapperBase.__init__.__doc__
Ancestors
- DeviceWrapperBase
- keras.layers.rnn_cell_wrapper_v2._RNNCellWrapperV2
- AbstractRNNCell
- Layer
- tensorflow.python.module.module.Module
- tensorflow.python.training.tracking.tracking.AutoTrackable
- tensorflow.python.training.tracking.base.Trackable
- LayerVersionSelector
Inherited members
AbstractRNNCell
:activity_regularizer
add_loss
add_metric
add_update
add_variable
add_weight
apply
build
call
compute_dtype
compute_mask
compute_output_shape
compute_output_signature
count_params
dtype
dtype_policy
dynamic
finalize_state
from_config
get_config
get_input_at
get_input_mask_at
get_input_shape_at
get_losses_for
get_output_at
get_output_mask_at
get_output_shape_at
get_updates_for
get_weights
inbound_nodes
input
input_mask
input_shape
input_spec
losses
metrics
name
non_trainable_variables
non_trainable_weights
outbound_nodes
output
output_mask
output_shape
output_size
set_weights
state_size
supports_masking
trainable_variables
trainable_weights
variable_dtype
variables
weights
class DropoutWrapper (*args, **kwargs)
-
Operator adding dropout to inputs and outputs of the given cell.
Create a cell with added input, state, and/or output dropout.
If `variational_recurrent` is set to `True` (NOT the default behavior), then the same dropout mask is applied at every step, as described in: "A Theoretically Grounded Application of Dropout in Recurrent Neural Networks" by Y. Gal and Z. Ghahramani. Otherwise a different dropout mask is applied at every time step.
Note, by default (unless a custom
dropout_state_filter
is provided), the memory state (c
component of anyLSTMStateTuple
) passing through aDropoutWrapper
is never modified. This behavior is described in the above article.Args
cell
- an RNNCell, a projection to output_size is added to it.
input_keep_prob
- unit Tensor or float between 0 and 1, input keep probability; if it is constant and 1, no input dropout will be added.
output_keep_prob
- unit Tensor or float between 0 and 1, output keep probability; if it is constant and 1, no output dropout will be added.
state_keep_prob
- unit Tensor or float between 0 and 1, output keep
probability; if it is constant and 1, no output dropout will be added.
State dropout is performed on the outgoing states of the cell. Note
the state components to which dropout is applied when
state_keep_prob
is in(0, 1)
are also determined by the argumentdropout_state_filter_visitor
(e.g. by default dropout is never applied to thec
component of anLSTMStateTuple
). variational_recurrent
- Python bool.
If
True
, then the same dropout pattern is applied across all time steps per run call. If this parameter is set,input_size
must be provided. input_size
- (optional) (possibly nested tuple of)
TensorShape
objects containing the depth(s) of the input tensors expected to be passed in to theDropoutWrapper
. Required and used iff `variational_recurrent = True` and `input_keep_prob < 1`
. dtype
- (optional) The
dtype
of the input, state, and output tensors. Required and used iffvariational_recurrent = True
. seed
- (optional) integer, the randomness seed.
dropout_state_filter_visitor
- (optional), default: (see below).
Function
that takes any hierarchical level of the state and returns a scalar or
depth=1 structure of Python booleans describing which terms in the state
should be dropped out.
In addition, if the function returns
True
, dropout is applied across this sublevel. If the function returnsFalse
, dropout is not applied across this entire sublevel. Default behavior: perform dropout on all terms except the memory (c
) state ofLSTMCellState
objects, and don't try to apply dropout toTensorArray
objects:def dropout_state_filter_visitor(s): if isinstance(s, LSTMCellState): # Never perform dropout on the c state. return LSTMCellState(c=False, h=True) elif isinstance(s, TensorArray): return False return True
**kwargs
- dict of keyword arguments for base layer.
Raises
TypeError
- if
cell
is not anRNNCell
, orkeep_state_fn
is provided but notcallable
. ValueError
- if any of the keep_probs are not between 0 and 1.
Expand source code
class DropoutWrapper(rnn_cell_wrapper_impl.DropoutWrapperBase, _RNNCellWrapperV2): """Operator adding dropout to inputs and outputs of the given cell.""" def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation super(DropoutWrapper, self).__init__(*args, **kwargs) if isinstance(self.cell, recurrent.LSTMCell): raise ValueError("keras LSTM cell does not work with DropoutWrapper. " "Please use LSTMCell(dropout=x, recurrent_dropout=y) " "instead.") __init__.__doc__ = rnn_cell_wrapper_impl.DropoutWrapperBase.__init__.__doc__
Ancestors
- DropoutWrapperBase
- keras.layers.rnn_cell_wrapper_v2._RNNCellWrapperV2
- AbstractRNNCell
- Layer
- tensorflow.python.module.module.Module
- tensorflow.python.training.tracking.tracking.AutoTrackable
- tensorflow.python.training.tracking.base.Trackable
- LayerVersionSelector
Inherited members
DropoutWrapperBase
:AbstractRNNCell
:activity_regularizer
add_loss
add_metric
add_update
add_variable
add_weight
apply
build
call
compute_dtype
compute_mask
compute_output_shape
compute_output_signature
count_params
dtype
dtype_policy
dynamic
finalize_state
from_config
get_input_at
get_input_mask_at
get_input_shape_at
get_losses_for
get_output_at
get_output_mask_at
get_output_shape_at
get_updates_for
get_weights
inbound_nodes
input
input_mask
input_shape
input_spec
losses
metrics
name
non_trainable_variables
non_trainable_weights
outbound_nodes
output
output_mask
output_shape
output_size
set_weights
state_size
supports_masking
trainable_variables
trainable_weights
variable_dtype
variables
weights
class ResidualWrapper (*args, **kwargs)
-
RNNCell wrapper that ensures cell inputs are added to the outputs.
Constructs a `ResidualWrapper` for `cell`.
Args
cell
- An instance of
RNNCell
. residual_fn
- (Optional) The function to map raw cell inputs and raw cell outputs to the actual cell outputs of the residual network. Defaults to calling nest.map_structure on (lambda i, o: i + o), inputs and outputs.
**kwargs
- dict of keyword arguments for base layer.
Expand source code
class ResidualWrapper(rnn_cell_wrapper_impl.ResidualWrapperBase, _RNNCellWrapperV2): """RNNCell wrapper that ensures cell inputs are added to the outputs.""" def __init__(self, *args, **kwargs): # pylint: disable=useless-super-delegation super(ResidualWrapper, self).__init__(*args, **kwargs) __init__.__doc__ = rnn_cell_wrapper_impl.ResidualWrapperBase.__init__.__doc__
Ancestors
- ResidualWrapperBase
- keras.layers.rnn_cell_wrapper_v2._RNNCellWrapperV2
- AbstractRNNCell
- Layer
- tensorflow.python.module.module.Module
- tensorflow.python.training.tracking.tracking.AutoTrackable
- tensorflow.python.training.tracking.base.Trackable
- LayerVersionSelector
Inherited members
ResidualWrapperBase
:AbstractRNNCell
:activity_regularizer
add_loss
add_metric
add_update
add_variable
add_weight
apply
build
call
compute_dtype
compute_mask
compute_output_shape
compute_output_signature
count_params
dtype
dtype_policy
dynamic
finalize_state
from_config
get_input_at
get_input_mask_at
get_input_shape_at
get_losses_for
get_output_at
get_output_mask_at
get_output_shape_at
get_updates_for
get_weights
inbound_nodes
input
input_mask
input_shape
input_spec
losses
metrics
name
non_trainable_variables
non_trainable_weights
outbound_nodes
output
output_mask
output_shape
output_size
set_weights
state_size
supports_masking
trainable_variables
trainable_weights
variable_dtype
variables
weights