Module keras.layers.recurrent
Recurrent layers and their base classes.
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
# pylint: disable=g-classes-have-attributes
"""Recurrent layers and their base classes."""
import tensorflow.compat.v2 as tf
import collections
import warnings
import numpy as np
from keras import activations
from keras import backend
from keras import constraints
from keras import initializers
from keras import regularizers
from keras.engine.base_layer import Layer
from keras.engine.input_spec import InputSpec
from keras.saving.saved_model import layer_serialization
from keras.utils import control_flow_util
from keras.utils import generic_utils
from keras.utils import tf_utils
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls
RECURRENT_DROPOUT_WARNING_MSG = (
'RNN `implementation=2` is not supported when `recurrent_dropout` is set. '
'Using `implementation=1`.')
@keras_export('keras.layers.StackedRNNCells')
class StackedRNNCells(Layer):
"""Wrapper allowing a stack of RNN cells to behave as a single cell.
Used to implement efficient stacked RNNs.
Args:
cells: List of RNN cell instances.
Examples:
```python
batch_size = 3
sentence_max_length = 5
n_features = 2
new_shape = (batch_size, sentence_max_length, n_features)
x = tf.constant(np.reshape(np.arange(30), new_shape), dtype = tf.float32)
rnn_cells = [tf.keras.layers.LSTMCell(128) for _ in range(2)]
stacked_lstm = tf.keras.layers.StackedRNNCells(rnn_cells)
lstm_layer = tf.keras.layers.RNN(stacked_lstm)
result = lstm_layer(x)
```
"""
def __init__(self, cells, **kwargs):
for cell in cells:
if not 'call' in dir(cell):
raise ValueError('All cells must have a `call` method. '
'received cells:', cells)
if not 'state_size' in dir(cell):
raise ValueError('All cells must have a '
'`state_size` attribute. '
'received cells:', cells)
self.cells = cells
# reverse_state_order determines whether the state size will be in the
# reverse order of the cells' states. Users might want to set this to True
# to keep the existing behavior. This is only useful when using
# RNN(return_state=True), since the states are returned in the same order
# as state_size.
self.reverse_state_order = kwargs.pop('reverse_state_order', False)
if self.reverse_state_order:
logging.warning('reverse_state_order=True in StackedRNNCells will soon '
'be deprecated. Please update the code to work with the '
'natural order of states if you rely on the RNN states, '
'eg RNN(return_state=True).')
super(StackedRNNCells, self).__init__(**kwargs)
@property
def state_size(self):
return tuple(c.state_size for c in
(self.cells[::-1] if self.reverse_state_order else self.cells))
@property
def output_size(self):
if getattr(self.cells[-1], 'output_size', None) is not None:
return self.cells[-1].output_size
elif _is_multiple_state(self.cells[-1].state_size):
return self.cells[-1].state_size[0]
else:
return self.cells[-1].state_size
def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
initial_states = []
for cell in self.cells[::-1] if self.reverse_state_order else self.cells:
get_initial_state_fn = getattr(cell, 'get_initial_state', None)
if get_initial_state_fn:
initial_states.append(get_initial_state_fn(
inputs=inputs, batch_size=batch_size, dtype=dtype))
else:
initial_states.append(_generate_zero_filled_state_for_cell(
cell, inputs, batch_size, dtype))
return tuple(initial_states)
def call(self, inputs, states, constants=None, training=None, **kwargs):
# Recover per-cell states.
state_size = (self.state_size[::-1]
if self.reverse_state_order else self.state_size)
nested_states = tf.nest.pack_sequence_as(state_size, tf.nest.flatten(states))
# Call the cells in order and store the returned states.
new_nested_states = []
for cell, states in zip(self.cells, nested_states):
states = states if tf.nest.is_nested(states) else [states]
# TF cells do not wrap the state into a list when there is only one state.
is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None
states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
if generic_utils.has_arg(cell.call, 'training'):
kwargs['training'] = training
else:
kwargs.pop('training', None)
# Use the __call__ function for callable objects, eg layers, so that it
# will have the proper name scopes for the ops, etc.
cell_call_fn = cell.__call__ if callable(cell) else cell.call
if generic_utils.has_arg(cell.call, 'constants'):
inputs, states = cell_call_fn(inputs, states,
constants=constants, **kwargs)
else:
inputs, states = cell_call_fn(inputs, states, **kwargs)
new_nested_states.append(states)
return inputs, tf.nest.pack_sequence_as(state_size,
tf.nest.flatten(new_nested_states))
@tf_utils.shape_type_conversion
def build(self, input_shape):
if isinstance(input_shape, list):
input_shape = input_shape[0]
for cell in self.cells:
if isinstance(cell, Layer) and not cell.built:
with backend.name_scope(cell.name):
cell.build(input_shape)
cell.built = True
if getattr(cell, 'output_size', None) is not None:
output_dim = cell.output_size
elif _is_multiple_state(cell.state_size):
output_dim = cell.state_size[0]
else:
output_dim = cell.state_size
input_shape = tuple([input_shape[0]] +
tf.TensorShape(output_dim).as_list())
self.built = True
def get_config(self):
cells = []
for cell in self.cells:
cells.append(generic_utils.serialize_keras_object(cell))
config = {'cells': cells}
base_config = super(StackedRNNCells, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config, custom_objects=None):
from keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
cells = []
for cell_config in config.pop('cells'):
cells.append(
deserialize_layer(cell_config, custom_objects=custom_objects))
return cls(cells, **config)
@keras_export('keras.layers.RNN')
class RNN(Layer):
"""Base class for recurrent layers.
See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
for details about the usage of RNN API.
Args:
cell: An RNN cell instance or a list of RNN cell instances.
An RNN cell is a class that has:
- A `call(input_at_t, states_at_t)` method, returning
`(output_at_t, states_at_t_plus_1)`. The call method of the
cell can also take the optional argument `constants`, see
section "Note on passing external constants" below.
- A `state_size` attribute. This can be a single integer
(single state) in which case it is the size of the recurrent
state. This can also be a list/tuple of integers (one size per state).
The `state_size` can also be TensorShape or tuple/list of
TensorShape, to represent high dimension state.
- An `output_size` attribute. This can be a single integer or a
TensorShape, which represents the shape of the output. For backward
compatibility, if this attribute is not available for the
cell, the value will be inferred from the first element of the
`state_size`.
- A `get_initial_state(inputs=None, batch_size=None, dtype=None)`
method that creates a tensor meant to be fed to `call()` as the
initial state, if the user didn't specify any initial state via other
means. The returned initial state should have a shape of
[batch_size, cell.state_size]. The cell might choose to create a
tensor full of zeros, or full of other values based on the cell's
implementation.
`inputs` is the input tensor to the RNN layer, which should
contain the batch size as its shape[0], and also the dtype. Note that
shape[0] might be `None` during graph construction. Either
`inputs` or the pair of `batch_size` and `dtype` is provided.
`batch_size` is a scalar tensor that represents the batch size
of the inputs. `dtype` is `tf.DType` that represents the dtype of
the inputs.
For backward compatibility, if this method is not implemented
by the cell, the RNN layer will create a zero filled tensor with the
size of [batch_size, cell.state_size].
In the case that `cell` is a list of RNN cell instances, the cells
will be stacked on top of each other in the RNN, resulting in an
efficient stacked RNN.
return_sequences: Boolean (default `False`). Whether to return the last
output in the output sequence, or the full sequence.
return_state: Boolean (default `False`). Whether to return the last state
in addition to the output.
go_backwards: Boolean (default `False`).
If True, process the input sequence backwards and return the
reversed sequence.
stateful: Boolean (default `False`). If True, the last state
for each sample at index i in a batch will be used as initial
state for the sample of index i in the following batch.
unroll: Boolean (default `False`).
If True, the network will be unrolled, else a symbolic loop will be used.
Unrolling can speed up an RNN, although it tends to be more
memory-intensive. Unrolling is only suitable for short sequences.
time_major: The shape format of the `inputs` and `outputs` tensors.
If True, the inputs and outputs will be in shape
`(timesteps, batch, ...)`, whereas in the False case, it will be
`(batch, timesteps, ...)`. Using `time_major = True` is a bit more
efficient because it avoids transposes at the beginning and end of the
RNN calculation. However, most TensorFlow data is batch-major, so by
default this function accepts input and emits output in batch-major
form.
zero_output_for_mask: Boolean (default `False`).
Whether the output should use zeros for the masked timesteps. Note that
this field is only used when `return_sequences` is True and mask is
provided. It can be useful if you want to reuse the raw output sequence of
the RNN without interference from the masked timesteps, e.g. when merging
bidirectional RNNs.
Call arguments:
inputs: Input tensor.
mask: Binary tensor of shape `[batch_size, timesteps]` indicating whether
a given timestep should be masked. An individual `True` entry indicates
that the corresponding timestep should be utilized, while a `False`
entry indicates that the corresponding timestep should be ignored.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. This argument is passed to the cell
when calling it. This is for use with cells that use dropout.
initial_state: List of initial state tensors to be passed to the first
call of the cell.
constants: List of constant tensors to be passed to the cell at each
timestep.
Input shape:
N-D tensor with shape `[batch_size, timesteps, ...]` or
`[timesteps, batch_size, ...]` when time_major is True.
Output shape:
- If `return_state`: a list of tensors. The first tensor is
the output. The remaining tensors are the last states,
each with shape `[batch_size, state_size]`, where `state_size` could
be a high dimension tensor shape.
- If `return_sequences`: N-D tensor with shape
`[batch_size, timesteps, output_size]`, where `output_size` could
be a high dimension tensor shape, or
`[timesteps, batch_size, output_size]` when `time_major` is True.
- Else, N-D tensor with shape `[batch_size, output_size]`, where
`output_size` could be a high dimension tensor shape.
Masking:
This layer supports masking for input data with a variable number
of timesteps. To introduce masks to your data,
use a [tf.keras.layers.Embedding] layer with the `mask_zero` parameter
set to `True`.
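For example, a minimal sketch (the vocabulary size and dimensions below are
placeholders):

```python
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(input_dim=1000, output_dim=16, mask_zero=True),
    tf.keras.layers.SimpleRNN(32),
])
```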
Note on using statefulness in RNNs:
You can set RNN layers to be 'stateful', which means that the states
computed for the samples in one batch will be reused as initial states
for the samples in the next batch. This assumes a one-to-one mapping
between samples in different successive batches.
To enable statefulness:
- Specify `stateful=True` in the layer constructor.
- Specify a fixed batch size for your model, by passing
if using a Sequential model:
`batch_input_shape=(...)` to the first layer in your model;
else, for a functional model with 1 or more Input layers:
`batch_shape=(...)` to all the first layers in your model.
This is the expected shape of your inputs
*including the batch size*.
It should be a tuple of integers, e.g. `(32, 10, 100)`.
- Specify `shuffle=False` when calling `fit()`.
To reset the states of your model, call `.reset_states()` on either
a specific layer, or on your entire model.
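For example, a minimal stateful setup might look like this (the batch size,
timesteps and feature dimension below are placeholders):

```python
model = tf.keras.Sequential([
    tf.keras.layers.LSTM(16, stateful=True, batch_input_shape=(32, 10, 8)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')
# Train with shuffle=False so the sample-to-sample mapping between batches
# holds, then clear the carried-over states when starting new sequences:
model.reset_states()
```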
Note on specifying the initial state of RNNs:
You can specify the initial state of RNN layers symbolically by
calling them with the keyword argument `initial_state`. The value of
`initial_state` should be a tensor or list of tensors representing
the initial state of the RNN layer.
You can specify the initial state of RNN layers numerically by
calling `reset_states` with the keyword argument `states`. The value of
`states` should be a numpy array or list of numpy arrays representing
the initial state of the RNN layer.
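For example, a symbolic initial state can be passed as follows (the shapes
below are illustrative):

```python
inputs = tf.keras.Input(shape=(None, 8))
initial_state = tf.keras.Input(shape=(16,))
outputs = tf.keras.layers.SimpleRNN(16)(inputs, initial_state=[initial_state])
```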
Note on passing external constants to RNNs:
You can pass "external" constants to the cell using the `constants`
keyword argument of the `RNN.__call__` (as well as `RNN.call`) method. This
requires that the `cell.call` method accepts the same keyword argument
`constants`. Such constants can be used to condition the cell
transformation on additional static inputs (not changing over time),
a.k.a. an attention mechanism.
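A minimal sketch of a cell whose `call` accepts `constants` (the names and
shapes below are illustrative):

```python
class AttendingCell(keras.layers.Layer):

  def __init__(self, units, **kwargs):
    super(AttendingCell, self).__init__(**kwargs)
    self.units = units
    self.state_size = units

  def build(self, input_shape):
    self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                  name='kernel')
    self.built = True

  def call(self, inputs, states, constants=None):
    output = backend.dot(inputs, self.kernel) + states[0]
    if constants is not None:
      # Assumes the constant's last dimension matches `units`.
      output = output + constants[0]
    return output, [output]

x = keras.Input((None, 5))
attended = keras.Input((16,))
y = RNN(AttendingCell(16))(x, constants=[attended])
```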
Examples:
```python
# First, let's define a RNN Cell, as a layer subclass.
class MinimalRNNCell(keras.layers.Layer):
def __init__(self, units, **kwargs):
self.units = units
self.state_size = units
super(MinimalRNNCell, self).__init__(**kwargs)
def build(self, input_shape):
self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
initializer='uniform',
name='kernel')
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units),
initializer='uniform',
name='recurrent_kernel')
self.built = True
def call(self, inputs, states):
prev_output = states[0]
h = backend.dot(inputs, self.kernel)
output = h + backend.dot(prev_output, self.recurrent_kernel)
return output, [output]
# Let's use this cell in a RNN layer:
cell = MinimalRNNCell(32)
x = keras.Input((None, 5))
layer = RNN(cell)
y = layer(x)
# Here's how to use the cell to build a stacked RNN:
cells = [MinimalRNNCell(32), MinimalRNNCell(64)]
x = keras.Input((None, 5))
layer = RNN(cells)
y = layer(x)
```
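A cell can also carry more than one state tensor by making `state_size` a
list; the sketch below (with illustrative sizes) returns both states when
`return_state=True`:

```python
class TwoStateCell(keras.layers.Layer):

  def __init__(self, units, **kwargs):
    super(TwoStateCell, self).__init__(**kwargs)
    self.units = units
    self.state_size = [units, units]
    self.output_size = units

  def build(self, input_shape):
    self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                  name='kernel')
    self.built = True

  def call(self, inputs, states):
    h, c = states
    output = backend.dot(inputs, self.kernel) + h
    return output, [output, c + output]

layer = RNN(TwoStateCell(16), return_state=True)
outputs, state_h, state_c = layer(keras.Input((None, 5)))
```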
"""
def __init__(self,
cell,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
unroll=False,
time_major=False,
**kwargs):
if isinstance(cell, (list, tuple)):
cell = StackedRNNCells(cell)
if not 'call' in dir(cell):
raise ValueError('`cell` should have a `call` method. '
'The RNN was passed:', cell)
if not 'state_size' in dir(cell):
raise ValueError('The RNN cell should have '
'an attribute `state_size` '
'(tuple of integers, '
'one integer per RNN state).')
# If True, the output for masked timesteps will be zeros, whereas in the
# False case, the output from the previous timestep is returned for masked
# timesteps.
self.zero_output_for_mask = kwargs.pop('zero_output_for_mask', False)
if 'input_shape' not in kwargs and (
'input_dim' in kwargs or 'input_length' in kwargs):
input_shape = (kwargs.pop('input_length', None),
kwargs.pop('input_dim', None))
kwargs['input_shape'] = input_shape
super(RNN, self).__init__(**kwargs)
self.cell = cell
self.return_sequences = return_sequences
self.return_state = return_state
self.go_backwards = go_backwards
self.stateful = stateful
self.unroll = unroll
self.time_major = time_major
self.supports_masking = True
# The input shape is unknown yet, it could have nested tensor inputs, and
# the input spec will be the list of specs for nested inputs, the structure
# of the input_spec will be the same as the input.
self.input_spec = None
self.state_spec = None
self._states = None
self.constants_spec = None
self._num_constants = 0
if stateful:
if tf.distribute.has_strategy():
raise ValueError('RNNs with stateful=True not yet supported with '
'tf.distribute.Strategy.')
@property
def _use_input_spec_as_call_signature(self):
if self.unroll:
# When the RNN layer is unrolled, the time step shape cannot be unknown.
# The input spec does not define the time step (because this layer can be
# called with any time step value, as long as it is not None), so it
# cannot be used as the call function signature when saving to SavedModel.
return False
return super(RNN, self)._use_input_spec_as_call_signature
@property
def states(self):
if self._states is None:
state = tf.nest.map_structure(lambda _: None, self.cell.state_size)
return state if tf.nest.is_nested(self.cell.state_size) else [state]
return self._states
@states.setter
# Automatic tracking catches "self._states" which adds an extra weight and
# breaks HDF5 checkpoints.
@tf.__internal__.tracking.no_automatic_dependency_tracking
def states(self, states):
self._states = states
def compute_output_shape(self, input_shape):
if isinstance(input_shape, list):
input_shape = input_shape[0]
# Check whether the input shape contains any nested shapes. It could be
# (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from numpy
# inputs.
try:
input_shape = tf.TensorShape(input_shape)
except (ValueError, TypeError):
# A nested tensor input
input_shape = tf.nest.flatten(input_shape)[0]
batch = input_shape[0]
time_step = input_shape[1]
if self.time_major:
batch, time_step = time_step, batch
if _is_multiple_state(self.cell.state_size):
state_size = self.cell.state_size
else:
state_size = [self.cell.state_size]
def _get_output_shape(flat_output_size):
output_dim = tf.TensorShape(flat_output_size).as_list()
if self.return_sequences:
if self.time_major:
output_shape = tf.TensorShape(
[time_step, batch] + output_dim)
else:
output_shape = tf.TensorShape(
[batch, time_step] + output_dim)
else:
output_shape = tf.TensorShape([batch] + output_dim)
return output_shape
if getattr(self.cell, 'output_size', None) is not None:
# cell.output_size could be nested structure.
output_shape = tf.nest.flatten(tf.nest.map_structure(
_get_output_shape, self.cell.output_size))
output_shape = output_shape[0] if len(output_shape) == 1 else output_shape
else:
# Note that state_size[0] could be a tensor_shape or int.
output_shape = _get_output_shape(state_size[0])
if self.return_state:
def _get_state_shape(flat_state):
state_shape = [batch] + tf.TensorShape(flat_state).as_list()
return tf.TensorShape(state_shape)
state_shape = tf.nest.map_structure(_get_state_shape, state_size)
return generic_utils.to_list(output_shape) + tf.nest.flatten(state_shape)
else:
return output_shape
def compute_mask(self, inputs, mask):
# Time step masks must be the same for each input.
# This is because the mask for an RNN is of size [batch, time_steps, 1],
# and specifies which time steps should be skipped, and a time step
# must be skipped for all inputs.
# TODO(scottzhu): Should we accept multiple different masks?
mask = tf.nest.flatten(mask)[0]
output_mask = mask if self.return_sequences else None
if self.return_state:
state_mask = [None for _ in self.states]
return [output_mask] + state_mask
else:
return output_mask
def build(self, input_shape):
if isinstance(input_shape, list):
input_shape = input_shape[0]
# The input_shape here could be a nested structure.
# Convert TensorShapes to plain shapes here. The input could be a single
# tensor, or a nested structure of tensors.
def get_input_spec(shape):
"""Convert input shape to InputSpec."""
if isinstance(shape, tf.TensorShape):
input_spec_shape = shape.as_list()
else:
input_spec_shape = list(shape)
batch_index, time_step_index = (1, 0) if self.time_major else (0, 1)
if not self.stateful:
input_spec_shape[batch_index] = None
input_spec_shape[time_step_index] = None
return InputSpec(shape=tuple(input_spec_shape))
def get_step_input_shape(shape):
if isinstance(shape, tf.TensorShape):
shape = tuple(shape.as_list())
# remove the timestep from the input_shape
return shape[1:] if self.time_major else (shape[0],) + shape[2:]
# Check whether the input shape contains any nested shapes. It could be
# (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from numpy
# inputs.
try:
input_shape = tf.TensorShape(input_shape)
except (ValueError, TypeError):
# A nested tensor input
pass
if not tf.nest.is_nested(input_shape):
# This indicates that there is only one input.
if self.input_spec is not None:
self.input_spec[0] = get_input_spec(input_shape)
else:
self.input_spec = [get_input_spec(input_shape)]
step_input_shape = get_step_input_shape(input_shape)
else:
if self.input_spec is not None:
self.input_spec[0] = tf.nest.map_structure(get_input_spec, input_shape)
else:
self.input_spec = generic_utils.to_list(
tf.nest.map_structure(get_input_spec, input_shape))
step_input_shape = tf.nest.map_structure(get_step_input_shape, input_shape)
# allow cell (if layer) to build before we set or validate state_spec.
if isinstance(self.cell, Layer) and not self.cell.built:
with backend.name_scope(self.cell.name):
self.cell.build(step_input_shape)
self.cell.built = True
# set or validate state_spec
if _is_multiple_state(self.cell.state_size):
state_size = list(self.cell.state_size)
else:
state_size = [self.cell.state_size]
if self.state_spec is not None:
# initial_state was passed in call, check compatibility
self._validate_state_spec(state_size, self.state_spec)
else:
self.state_spec = [
InputSpec(shape=[None] + tf.TensorShape(dim).as_list())
for dim in state_size
]
if self.stateful:
self.reset_states()
self.built = True
@staticmethod
def _validate_state_spec(cell_state_sizes, init_state_specs):
"""Validate the state spec between the initial_state and the state_size.
Args:
cell_state_sizes: list, the `state_size` attribute from the cell.
init_state_specs: list, the `state_spec` from the initial_state that is
passed in `call()`.
Raises:
ValueError: When initial state spec is not compatible with the state size.
"""
validation_error = ValueError(
'An `initial_state` was passed that is not compatible with '
'`cell.state_size`. Received `state_spec`={}; '
'however `cell.state_size` is '
'{}'.format(init_state_specs, cell_state_sizes))
flat_cell_state_sizes = tf.nest.flatten(cell_state_sizes)
flat_state_specs = tf.nest.flatten(init_state_specs)
if len(flat_cell_state_sizes) != len(flat_state_specs):
raise validation_error
for cell_state_spec, cell_state_size in zip(flat_state_specs,
flat_cell_state_sizes):
if not tf.TensorShape(
# Ignore the first axis for init_state which is for batch
cell_state_spec.shape[1:]).is_compatible_with(
tf.TensorShape(cell_state_size)):
raise validation_error
@doc_controls.do_not_doc_inheritable
def get_initial_state(self, inputs):
get_initial_state_fn = getattr(self.cell, 'get_initial_state', None)
if tf.nest.is_nested(inputs):
# The inputs are nested sequences. Use the first element in the sequence to
# get the batch size and dtype.
inputs = tf.nest.flatten(inputs)[0]
input_shape = tf.shape(inputs)
batch_size = input_shape[1] if self.time_major else input_shape[0]
dtype = inputs.dtype
if get_initial_state_fn:
init_state = get_initial_state_fn(
inputs=None, batch_size=batch_size, dtype=dtype)
else:
init_state = _generate_zero_filled_state(batch_size, self.cell.state_size,
dtype)
# Keras RNN expects the states in a list, even if it's a single state tensor.
if not tf.nest.is_nested(init_state):
init_state = [init_state]
# Force the state to be a list in case it is a namedtuple eg LSTMStateTuple.
return list(init_state)
def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
inputs, initial_state, constants = _standardize_args(inputs,
initial_state,
constants,
self._num_constants)
if initial_state is None and constants is None:
return super(RNN, self).__call__(inputs, **kwargs)
# If any of `initial_state` or `constants` are specified and are Keras
# tensors, then add them to the inputs and temporarily modify the
# input_spec to include them.
additional_inputs = []
additional_specs = []
if initial_state is not None:
additional_inputs += initial_state
self.state_spec = tf.nest.map_structure(
lambda s: InputSpec(shape=backend.int_shape(s)), initial_state)
additional_specs += self.state_spec
if constants is not None:
additional_inputs += constants
self.constants_spec = [
InputSpec(shape=backend.int_shape(constant)) for constant in constants
]
self._num_constants = len(constants)
additional_specs += self.constants_spec
# additional_inputs can be empty if initial_state or constants are provided
# but empty (e.g. the cell is stateless).
flat_additional_inputs = tf.nest.flatten(additional_inputs)
is_keras_tensor = backend.is_keras_tensor(
flat_additional_inputs[0]) if flat_additional_inputs else True
for tensor in flat_additional_inputs:
if backend.is_keras_tensor(tensor) != is_keras_tensor:
raise ValueError('The initial state or constants of an RNN'
' layer cannot be specified with a mix of'
' Keras tensors and non-Keras tensors'
' (a "Keras tensor" is a tensor that was'
' returned by a Keras layer, or by `Input`)')
if is_keras_tensor:
# Compute the full input spec, including state and constants
full_input = [inputs] + additional_inputs
if self.built:
# Keep the input_spec since it has been populated in build() method.
full_input_spec = self.input_spec + additional_specs
else:
# The original input_spec is None since there could be a nested tensor
# input. Update the input_spec to match the inputs.
full_input_spec = generic_utils.to_list(
tf.nest.map_structure(lambda _: None, inputs)) + additional_specs
# Perform the call with temporarily replaced input_spec
self.input_spec = full_input_spec
output = super(RNN, self).__call__(full_input, **kwargs)
# Remove the additional_specs from the input spec and keep the rest. It is
# important to keep the rest since the input spec was populated by build(),
# and will be reused in the stateful=True case.
self.input_spec = self.input_spec[:-len(additional_specs)]
return output
else:
if initial_state is not None:
kwargs['initial_state'] = initial_state
if constants is not None:
kwargs['constants'] = constants
return super(RNN, self).__call__(inputs, **kwargs)
def call(self,
inputs,
mask=None,
training=None,
initial_state=None,
constants=None):
# The input should be dense, padded with zeros. If a ragged input is fed
# into the layer, it is padded and the row lengths are used for masking.
inputs, row_lengths = backend.convert_inputs_if_ragged(inputs)
is_ragged_input = (row_lengths is not None)
self._validate_args_if_ragged(is_ragged_input, mask)
inputs, initial_state, constants = self._process_inputs(
inputs, initial_state, constants)
self._maybe_reset_cell_dropout_mask(self.cell)
if isinstance(self.cell, StackedRNNCells):
for cell in self.cell.cells:
self._maybe_reset_cell_dropout_mask(cell)
if mask is not None:
# Time step masks must be the same for each input.
# TODO(scottzhu): Should we accept multiple different masks?
mask = tf.nest.flatten(mask)[0]
if tf.nest.is_nested(inputs):
# In the case of nested input, use the first element for shape check.
input_shape = backend.int_shape(tf.nest.flatten(inputs)[0])
else:
input_shape = backend.int_shape(inputs)
timesteps = input_shape[0] if self.time_major else input_shape[1]
if self.unroll and timesteps is None:
raise ValueError('Cannot unroll a RNN if the '
'time dimension is undefined. \n'
'- If using a Sequential model, '
'specify the time dimension by passing '
'an `input_shape` or `batch_input_shape` '
'argument to your first layer. If your '
'first layer is an Embedding, you can '
'also use the `input_length` argument.\n'
'- If using the functional API, specify '
'the time dimension by passing a `shape` '
'or `batch_shape` argument to your Input layer.')
kwargs = {}
if generic_utils.has_arg(self.cell.call, 'training'):
kwargs['training'] = training
# TF RNN cells expect a single tensor as state instead of a list-wrapped tensor.
is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None
# Use the __call__ function for callable objects, eg layers, so that it
# will have the proper name scopes for the ops, etc.
cell_call_fn = self.cell.__call__ if callable(self.cell) else self.cell.call
if constants:
if not generic_utils.has_arg(self.cell.call, 'constants'):
raise ValueError('RNN cell does not support constants')
def step(inputs, states):
constants = states[-self._num_constants:] # pylint: disable=invalid-unary-operand-type
states = states[:-self._num_constants] # pylint: disable=invalid-unary-operand-type
states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
output, new_states = cell_call_fn(
inputs, states, constants=constants, **kwargs)
if not tf.nest.is_nested(new_states):
new_states = [new_states]
return output, new_states
else:
def step(inputs, states):
states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
output, new_states = cell_call_fn(inputs, states, **kwargs)
if not tf.nest.is_nested(new_states):
new_states = [new_states]
return output, new_states
last_output, outputs, states = backend.rnn(
step,
inputs,
initial_state,
constants=constants,
go_backwards=self.go_backwards,
mask=mask,
unroll=self.unroll,
input_length=row_lengths if row_lengths is not None else timesteps,
time_major=self.time_major,
zero_output_for_mask=self.zero_output_for_mask)
if self.stateful:
updates = [
tf.compat.v1.assign(self_state, state) for self_state, state in zip(
tf.nest.flatten(self.states), tf.nest.flatten(states))
]
self.add_update(updates)
if self.return_sequences:
output = backend.maybe_convert_to_ragged(
is_ragged_input, outputs, row_lengths, go_backwards=self.go_backwards)
else:
output = last_output
if self.return_state:
if not isinstance(states, (list, tuple)):
states = [states]
else:
states = list(states)
return generic_utils.to_list(output) + states
else:
return output
def _process_inputs(self, inputs, initial_state, constants):
# input shape: `(samples, time (padded with zeros), input_dim)`
# note that the .build() method of subclasses MUST define
# self.input_spec and self.state_spec with complete input shapes.
if (isinstance(inputs, collections.abc.Sequence)
and not isinstance(inputs, tuple)):
# Get initial_state from the full input spec,
# as they could be copied to multiple GPUs.
if not self._num_constants:
initial_state = inputs[1:]
else:
initial_state = inputs[1:-self._num_constants]
constants = inputs[-self._num_constants:]
if len(initial_state) == 0:
initial_state = None
inputs = inputs[0]
if self.stateful:
if initial_state is not None:
# When the layer is stateful and initial_state is provided, check if the
# recorded state is the same as the default value (zeros). Use the recorded
# state if it is not the same as the default.
non_zero_count = tf.add_n([tf.math.count_nonzero(s)
for s in tf.nest.flatten(self.states)])
# Set strict = True to keep the original structure of the state.
initial_state = tf.compat.v1.cond(non_zero_count > 0,
true_fn=lambda: self.states,
false_fn=lambda: initial_state,
strict=True)
else:
initial_state = self.states
elif initial_state is None:
initial_state = self.get_initial_state(inputs)
if len(initial_state) != len(self.states):
raise ValueError('Layer has ' + str(len(self.states)) +
' states but was passed ' + str(len(initial_state)) +
' initial states.')
return inputs, initial_state, constants
def _validate_args_if_ragged(self, is_ragged_input, mask):
if not is_ragged_input:
return
if mask is not None:
raise ValueError('The mask that was passed in was ' + str(mask) +
' and cannot be applied to RaggedTensor inputs. Please '
'make sure that there is no mask passed in by upstream '
'layers.')
if self.unroll:
raise ValueError('The input received contains RaggedTensors and does '
'not support unrolling. Disable unrolling by passing '
'`unroll=False` in the RNN Layer constructor.')
def _maybe_reset_cell_dropout_mask(self, cell):
if isinstance(cell, DropoutRNNCellMixin):
cell.reset_dropout_mask()
cell.reset_recurrent_dropout_mask()
def reset_states(self, states=None):
"""Reset the recorded states for the stateful RNN layer.
Can only be used when RNN layer is constructed with `stateful` = `True`.
Args:
states: Numpy arrays that contain the values for the initial state, which
will be fed to the cell at the first time step. When the value is None, a
zero-filled numpy array will be created based on the cell state size.
Raises:
AttributeError: When the RNN layer is not stateful.
ValueError: When the batch size of the RNN layer is unknown.
ValueError: When the input numpy array is not compatible with the RNN
layer state, either size-wise or dtype-wise.
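Example (a sketch; the batch size, timesteps and sizes below are
placeholders):

```python
layer = tf.keras.layers.SimpleRNN(4, stateful=True,
                                  batch_input_shape=(8, 5, 2))
layer(np.zeros((8, 5, 2), dtype=np.float32))  # build the layer and its states
layer.reset_states()                          # zero-fill the recorded states
layer.reset_states(states=np.ones((8, 4)))    # or supply explicit values
```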
"""
if not self.stateful:
raise AttributeError('Layer must be stateful.')
spec_shape = None
if self.input_spec is not None:
spec_shape = tf.nest.flatten(self.input_spec[0])[0].shape
if spec_shape is None:
# It is possible for the spec shape to be None, e.g. when constructing an RNN
# with a custom cell, or with standard RNN layers (LSTM/GRU), where we only
# know the input is 3-dimensional, but not its full shape spec, before build().
batch_size = None
else:
batch_size = spec_shape[1] if self.time_major else spec_shape[0]
if not batch_size:
raise ValueError('If a RNN is stateful, it needs to know '
'its batch size. Specify the batch size '
'of your input tensors: \n'
'- If using a Sequential model, '
'specify the batch size by passing '
'a `batch_input_shape` '
'argument to your first layer.\n'
'- If using the functional API, specify '
'the batch size by passing a '
'`batch_shape` argument to your Input layer.')
# initialize state if None
if tf.nest.flatten(self.states)[0] is None:
if getattr(self.cell, 'get_initial_state', None):
flat_init_state_values = tf.nest.flatten(self.cell.get_initial_state(
inputs=None, batch_size=batch_size,
dtype=self.dtype or backend.floatx()))
else:
flat_init_state_values = tf.nest.flatten(_generate_zero_filled_state(
batch_size, self.cell.state_size, self.dtype or backend.floatx()))
flat_states_variables = tf.nest.map_structure(
backend.variable, flat_init_state_values)
self.states = tf.nest.pack_sequence_as(self.cell.state_size,
flat_states_variables)
if not tf.nest.is_nested(self.states):
self.states = [self.states]
elif states is None:
for state, size in zip(tf.nest.flatten(self.states),
tf.nest.flatten(self.cell.state_size)):
backend.set_value(
state,
np.zeros([batch_size] + tf.TensorShape(size).as_list()))
else:
flat_states = tf.nest.flatten(self.states)
flat_input_states = tf.nest.flatten(states)
if len(flat_input_states) != len(flat_states):
raise ValueError('Layer ' + self.name + ' expects ' +
str(len(flat_states)) + ' states, '
'but it received ' + str(len(flat_input_states)) +
' state values. Input received: ' + str(states))
set_value_tuples = []
for i, (value, state) in enumerate(zip(flat_input_states,
flat_states)):
if value.shape != state.shape:
raise ValueError(
'State ' + str(i) + ' is incompatible with layer ' +
self.name + ': expected shape=' + str(
(batch_size, state)) + ', found shape=' + str(value.shape))
set_value_tuples.append((state, value))
backend.batch_set_value(set_value_tuples)
def get_config(self):
config = {
'return_sequences': self.return_sequences,
'return_state': self.return_state,
'go_backwards': self.go_backwards,
'stateful': self.stateful,
'unroll': self.unroll,
'time_major': self.time_major
}
if self._num_constants:
config['num_constants'] = self._num_constants
if self.zero_output_for_mask:
config['zero_output_for_mask'] = self.zero_output_for_mask
config['cell'] = generic_utils.serialize_keras_object(self.cell)
base_config = super(RNN, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config, custom_objects=None):
from keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
cell = deserialize_layer(config.pop('cell'), custom_objects=custom_objects)
num_constants = config.pop('num_constants', 0)
layer = cls(cell, **config)
layer._num_constants = num_constants
return layer
@property
def _trackable_saved_model_saver(self):
return layer_serialization.RNNSavedModelSaver(self)
@keras_export('keras.layers.AbstractRNNCell')
class AbstractRNNCell(Layer):
"""Abstract object representing an RNN cell.
See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
for details about the usage of RNN API.
This is the base class for implementing RNN cells with custom behavior.
Every `RNNCell` must have the properties below and implement `call` with
the signature `(output, next_state) = call(input, state)`.
Examples:
```python
class MinimalRNNCell(AbstractRNNCell):
def __init__(self, units, **kwargs):
self.units = units
super(MinimalRNNCell, self).__init__(**kwargs)
@property
def state_size(self):
return self.units
def build(self, input_shape):
self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
initializer='uniform',
name='kernel')
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units),
initializer='uniform',
name='recurrent_kernel')
self.built = True
def call(self, inputs, states):
prev_output = states[0]
h = backend.dot(inputs, self.kernel)
output = h + backend.dot(prev_output, self.recurrent_kernel)
return output, output
```
This definition of cell differs from the definition used in the literature.
In the literature, 'cell' refers to an object with a single scalar output.
This definition refers to a horizontal array of such units.
An RNN cell, in the most abstract setting, is anything that has
a state and performs some operation that takes a matrix of inputs.
This operation results in an output matrix with `self.output_size` columns.
If `self.state_size` is an integer, this operation also results in a new
state matrix with `self.state_size` columns. If `self.state_size` is a
(possibly nested tuple of) TensorShape object(s), then it should return a
matching structure of Tensors having shape `[batch_size].concatenate(s)`
for each `s` in `self.state_size`.
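A concrete cell such as the `MinimalRNNCell` above can then be wrapped in a
`tf.keras.layers.RNN` layer to process a whole sequence, e.g. (shapes are
illustrative):

```python
cell = MinimalRNNCell(32)
x = keras.Input((None, 5))
layer = keras.layers.RNN(cell)
y = layer(x)
```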
"""
def call(self, inputs, states):
"""The function that contains the logic for one RNN step calculation.
Args:
inputs: the input tensor, which is a slice from the overall RNN input along
the time dimension (usually the second dimension).
states: the state tensor from the previous step, which has shape
`(batch, state_size)`. In the case of timestep 0, it will be the
initial state specified by the user, or a zero-filled tensor otherwise.
Returns:
A tuple of two tensors:
1. The output tensor for the current timestep, with size `output_size`.
2. The state tensor for the next step, with shape `state_size`.
"""
raise NotImplementedError('Abstract method')
@property
def state_size(self):
"""size(s) of state(s) used by this cell.
It can be represented by an Integer, a TensorShape or a tuple of Integers
or TensorShapes.
"""
raise NotImplementedError('Abstract method')
@property
def output_size(self):
"""Integer or TensorShape: size of outputs produced by this cell."""
raise NotImplementedError('Abstract method')
def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
@doc_controls.do_not_generate_docs
class DropoutRNNCellMixin(object):
"""Object that hold dropout related fields for RNN Cell.
This class is not a standalone RNN cell. It suppose to be used with a RNN cell
by multiple inheritance. Any cell that mix with class should have following
fields:
dropout: a float number within range [0, 1). The ratio that the input
tensor need to dropout.
recurrent_dropout: a float number within range [0, 1). The ratio that the
recurrent state weights need to dropout.
This object will create and cache created dropout masks, and reuse them for
the incoming data, so that the same mask is used for every batch input.
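A minimal sketch of a custom cell that mixes in this class (essentially a
trimmed-down `SimpleRNNCell`; the names below are illustrative):

```python
class MyDropoutCell(DropoutRNNCellMixin, Layer):

  def __init__(self, units, dropout=0., recurrent_dropout=0., **kwargs):
    super(MyDropoutCell, self).__init__(**kwargs)
    self.units = units
    self.state_size = units
    self.output_size = units
    # Fields this mixin expects:
    self.dropout = min(1., max(0., dropout))
    self.recurrent_dropout = min(1., max(0., recurrent_dropout))

  def build(self, input_shape):
    self.kernel = self.add_weight(
        shape=(input_shape[-1], self.units), name='kernel')
    self.recurrent_kernel = self.add_weight(
        shape=(self.units, self.units), name='recurrent_kernel')
    self.built = True

  def call(self, inputs, states, training=None):
    prev_output = states[0]
    # The masks are created once per batch and cached by the mixin, so the
    # same mask is reused at every timestep.
    dp_mask = self.get_dropout_mask_for_cell(inputs, training)
    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(prev_output,
                                                           training)
    if dp_mask is not None:
      inputs = inputs * dp_mask
    if rec_dp_mask is not None:
      prev_output = prev_output * rec_dp_mask
    output = backend.dot(inputs, self.kernel) + backend.dot(
        prev_output, self.recurrent_kernel)
    return output, [output]
```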
"""
def __init__(self, *args, **kwargs):
self._create_non_trackable_mask_cache()
super(DropoutRNNCellMixin, self).__init__(*args, **kwargs)
@tf.__internal__.tracking.no_automatic_dependency_tracking
def _create_non_trackable_mask_cache(self):
"""Create the cache for dropout and recurrent dropout mask.
Note that the following two masks will be used in "graph function" mode,
e.g. these masks are symbolic tensors. In eager mode, the `eager_*_mask`
tensors will be generated differently than in the "graph function" case,
and they will be cached.
Also note that in graph mode, we still cache those masks only because the
RNN could be created with `unroll=True`. In that case, the `cell.call()`
function will be invoked multiple times, and we want to ensure the same mask
is used every time.
Also, the caches are created without tracking. Since they are not picklable
by Python during deepcopy, we don't want `layer._obj_reference_counts_dict`
to track them by default.
"""
self._dropout_mask_cache = backend.ContextValueCache(
self._create_dropout_mask)
self._recurrent_dropout_mask_cache = backend.ContextValueCache(
self._create_recurrent_dropout_mask)
def reset_dropout_mask(self):
"""Reset the cached dropout masks if any.
It is important for the RNN layer to invoke this in its `call()` method so
that the cached mask is cleared before calling `cell.call()`. The mask
should be cached across the timesteps within the same batch, but shouldn't
be cached between batches. Otherwise it will introduce an unreasonable bias
against certain indices of data within the batch.
"""
self._dropout_mask_cache.clear()
def reset_recurrent_dropout_mask(self):
"""Reset the cached recurrent dropout masks if any.
It is important for the RNN layer to invoke this in its `call()` method so
that the cached mask is cleared before calling `cell.call()`. The mask
should be cached across the timesteps within the same batch, but shouldn't
be cached between batches. Otherwise it will introduce an unreasonable bias
against certain indices of data within the batch.
"""
self._recurrent_dropout_mask_cache.clear()
def _create_dropout_mask(self, inputs, training, count=1):
return _generate_dropout_mask(
tf.ones_like(inputs),
self.dropout,
training=training,
count=count)
def _create_recurrent_dropout_mask(self, inputs, training, count=1):
return _generate_dropout_mask(
tf.ones_like(inputs),
self.recurrent_dropout,
training=training,
count=count)
def get_dropout_mask_for_cell(self, inputs, training, count=1):
"""Get the dropout mask for RNN cell's input.
It will create a mask based on the context if there isn't an existing cached
mask. If a new mask is generated, it will update the cache in the cell.
Args:
inputs: The input tensor whose shape will be used to generate dropout
mask.
training: Boolean tensor, whether it is in training mode; dropout will be
ignored in non-training mode.
count: Int, how many dropout masks will be generated. This is useful for
cells that have internal weights fused together.
Returns:
List of mask tensors, generated or cached masks based on the context.
"""
if self.dropout == 0:
return None
init_kwargs = dict(inputs=inputs, training=training, count=count)
return self._dropout_mask_cache.setdefault(kwargs=init_kwargs)
def get_recurrent_dropout_mask_for_cell(self, inputs, training, count=1):
"""Get the recurrent dropout mask for RNN cell.
It will create a mask based on the context if there isn't an existing cached
mask. If a new mask is generated, it will update the cache in the cell.
Args:
inputs: The input tensor whose shape will be used to generate dropout
mask.
training: Boolean tensor, whether it is in training mode; dropout will be
ignored in non-training mode.
count: Int, how many dropout masks will be generated. This is useful for
cells that have internal weights fused together.
Returns:
List of mask tensors, generated or cached masks based on the context.
"""
if self.recurrent_dropout == 0:
return None
init_kwargs = dict(inputs=inputs, training=training, count=count)
return self._recurrent_dropout_mask_cache.setdefault(kwargs=init_kwargs)
def __getstate__(self):
# Used for deepcopy. The caching can't be pickled by Python, since it will
# contain tensors and graphs.
state = super(DropoutRNNCellMixin, self).__getstate__()
state.pop('_dropout_mask_cache', None)
state.pop('_recurrent_dropout_mask_cache', None)
return state
def __setstate__(self, state):
state['_dropout_mask_cache'] = backend.ContextValueCache(
self._create_dropout_mask)
state['_recurrent_dropout_mask_cache'] = backend.ContextValueCache(
self._create_recurrent_dropout_mask)
super(DropoutRNNCellMixin, self).__setstate__(state)
@keras_export('keras.layers.SimpleRNNCell')
class SimpleRNNCell(DropoutRNNCellMixin, Layer):
"""Cell class for SimpleRNN.
See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
for details about the usage of RNN API.
This class processes one step within the whole time sequence input, whereas
`tf.keras.layers.SimpleRNN` processes the whole sequence.
Args:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
Default: hyperbolic tangent (`tanh`).
If you pass `None`, no activation is applied
(ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, (default `True`), whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs. Default:
`glorot_uniform`.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix, used for the linear transformation of the recurrent state.
Default: `orthogonal`.
bias_initializer: Initializer for the bias vector. Default: `zeros`.
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_regularizer: Regularizer function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_regularizer: Regularizer function applied to the bias vector. Default:
`None`.
kernel_constraint: Constraint function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_constraint: Constraint function applied to the `recurrent_kernel`
weights matrix. Default: `None`.
bias_constraint: Constraint function applied to the bias vector. Default:
`None`.
dropout: Float between 0 and 1. Fraction of the units to drop for the linear
transformation of the inputs. Default: 0.
recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for
the linear transformation of the recurrent state. Default: 0.
Call arguments:
inputs: A 2D tensor, with shape of `[batch, feature]`.
states: A 2D tensor with shape of `[batch, units]`, which is the state from
the previous time step. For timestep 0, the initial state provided by the
user will be fed to the cell.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. Only relevant when `dropout` or
`recurrent_dropout` is used.
Examples:
```python
inputs = np.random.random([32, 10, 8]).astype(np.float32)
rnn = tf.keras.layers.RNN(tf.keras.layers.SimpleRNNCell(4))
output = rnn(inputs) # The output has shape `[32, 4]`.
rnn = tf.keras.layers.RNN(
tf.keras.layers.SimpleRNNCell(4),
return_sequences=True,
return_state=True)
# whole_sequence_output has shape `[32, 10, 4]`.
# final_state has shape `[32, 4]`.
whole_sequence_output, final_state = rnn(inputs)
```
"""
def __init__(self,
units,
activation='tanh',
use_bias=True,
kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal',
bias_initializer='zeros',
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
dropout=0.,
recurrent_dropout=0.,
**kwargs):
if units < 0:
raise ValueError(f'Received an invalid value for units, expected '
f'a positive integer, got {units}.')
# By default use cached variable under v2 mode, see b/143699808.
if tf.compat.v1.executing_eagerly_outside_functions():
self._enable_caching_device = kwargs.pop('enable_caching_device', True)
else:
self._enable_caching_device = kwargs.pop('enable_caching_device', False)
super(SimpleRNNCell, self).__init__(**kwargs)
self.units = units
self.activation = activations.get(activation)
self.use_bias = use_bias
self.kernel_initializer = initializers.get(kernel_initializer)
self.recurrent_initializer = initializers.get(recurrent_initializer)
self.bias_initializer = initializers.get(bias_initializer)
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.recurrent_constraint = constraints.get(recurrent_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.dropout = min(1., max(0., dropout))
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
self.state_size = self.units
self.output_size = self.units
@tf_utils.shape_type_conversion
def build(self, input_shape):
default_caching_device = _caching_device(self)
self.kernel = self.add_weight(
shape=(input_shape[-1], self.units),
name='kernel',
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
caching_device=default_caching_device)
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units),
name='recurrent_kernel',
initializer=self.recurrent_initializer,
regularizer=self.recurrent_regularizer,
constraint=self.recurrent_constraint,
caching_device=default_caching_device)
if self.use_bias:
self.bias = self.add_weight(
shape=(self.units,),
name='bias',
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint,
caching_device=default_caching_device)
else:
self.bias = None
self.built = True
def call(self, inputs, states, training=None):
prev_output = states[0] if tf.nest.is_nested(states) else states
dp_mask = self.get_dropout_mask_for_cell(inputs, training)
rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
prev_output, training)
if dp_mask is not None:
h = backend.dot(inputs * dp_mask, self.kernel)
else:
h = backend.dot(inputs, self.kernel)
if self.bias is not None:
h = backend.bias_add(h, self.bias)
if rec_dp_mask is not None:
prev_output = prev_output * rec_dp_mask
output = h + backend.dot(prev_output, self.recurrent_kernel)
if self.activation is not None:
output = self.activation(output)
new_state = [output] if tf.nest.is_nested(states) else output
return output, new_state
def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
def get_config(self):
config = {
'units':
self.units,
'activation':
activations.serialize(self.activation),
'use_bias':
self.use_bias,
'kernel_initializer':
initializers.serialize(self.kernel_initializer),
'recurrent_initializer':
initializers.serialize(self.recurrent_initializer),
'bias_initializer':
initializers.serialize(self.bias_initializer),
'kernel_regularizer':
regularizers.serialize(self.kernel_regularizer),
'recurrent_regularizer':
regularizers.serialize(self.recurrent_regularizer),
'bias_regularizer':
regularizers.serialize(self.bias_regularizer),
'kernel_constraint':
constraints.serialize(self.kernel_constraint),
'recurrent_constraint':
constraints.serialize(self.recurrent_constraint),
'bias_constraint':
constraints.serialize(self.bias_constraint),
'dropout':
self.dropout,
'recurrent_dropout':
self.recurrent_dropout
}
config.update(_config_for_enable_caching_device(self))
base_config = super(SimpleRNNCell, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@keras_export('keras.layers.SimpleRNN')
class SimpleRNN(RNN):
"""Fully-connected RNN where the output is to be fed back to input.
See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
for details about the usage of RNN API.
Args:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
Default: hyperbolic tangent (`tanh`).
If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, (default `True`), whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs. Default:
`glorot_uniform`.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix, used for the linear transformation of the recurrent state.
Default: `orthogonal`.
bias_initializer: Initializer for the bias vector. Default: `zeros`.
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_regularizer: Regularizer function applied to the
`recurrent_kernel` weights matrix. Default: `None`.
bias_regularizer: Regularizer function applied to the bias vector. Default:
`None`.
activity_regularizer: Regularizer function applied to the output of the
layer (its "activation"). Default: `None`.
kernel_constraint: Constraint function applied to the `kernel` weights
matrix. Default: `None`.
recurrent_constraint: Constraint function applied to the `recurrent_kernel`
weights matrix. Default: `None`.
bias_constraint: Constraint function applied to the bias vector. Default:
`None`.
dropout: Float between 0 and 1.
Fraction of the units to drop for the linear transformation of the inputs.
Default: 0.
recurrent_dropout: Float between 0 and 1.
Fraction of the units to drop for the linear transformation of the
recurrent state. Default: 0.
return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence. Default: `False`.
return_state: Boolean. Whether to return the last state
in addition to the output. Default: `False`.
go_backwards: Boolean (default False).
If True, process the input sequence backwards and return the
reversed sequence.
stateful: Boolean (default False). If True, the last state
for each sample at index i in a batch will be used as initial
state for the sample of index i in the following batch.
unroll: Boolean (default False).
If True, the network will be unrolled,
else a symbolic loop will be used.
Unrolling can speed-up a RNN,
although it tends to be more memory-intensive.
Unrolling is only suitable for short sequences.
Call arguments:
inputs: A 3D tensor, with shape `[batch, timesteps, feature]`.
mask: Binary tensor of shape `[batch, timesteps]` indicating whether
a given timestep should be masked. An individual `True` entry indicates
that the corresponding timestep should be utilized, while a `False` entry
indicates that the corresponding timestep should be ignored.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. This argument is passed to the cell
when calling it. This is only relevant if `dropout` or
`recurrent_dropout` is used.
initial_state: List of initial state tensors to be passed to the first
call of the cell.
Examples:
```python
inputs = np.random.random([32, 10, 8]).astype(np.float32)
simple_rnn = tf.keras.layers.SimpleRNN(4)
output = simple_rnn(inputs) # The output has shape `[32, 4]`.
simple_rnn = tf.keras.layers.SimpleRNN(
4, return_sequences=True, return_state=True)
# whole_sequence_output has shape `[32, 10, 4]`.
# final_state has shape `[32, 4]`.
whole_sequence_output, final_state = simple_rnn(inputs)
```
"""
def __init__(self,
units,
activation='tanh',
use_bias=True,
kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal',
bias_initializer='zeros',
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
dropout=0.,
recurrent_dropout=0.,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
unroll=False,
**kwargs):
if 'implementation' in kwargs:
kwargs.pop('implementation')
logging.warning('The `implementation` argument '
'in `SimpleRNN` has been deprecated. '
'Please remove it from your layer call.')
if 'enable_caching_device' in kwargs:
cell_kwargs = {'enable_caching_device':
kwargs.pop('enable_caching_device')}
else:
cell_kwargs = {}
cell = SimpleRNNCell(
units,
activation=activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
recurrent_initializer=recurrent_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
recurrent_regularizer=recurrent_regularizer,
bias_regularizer=bias_regularizer,
kernel_constraint=kernel_constraint,
recurrent_constraint=recurrent_constraint,
bias_constraint=bias_constraint,
dropout=dropout,
recurrent_dropout=recurrent_dropout,
dtype=kwargs.get('dtype'),
trainable=kwargs.get('trainable', True),
**cell_kwargs)
super(SimpleRNN, self).__init__(
cell,
return_sequences=return_sequences,
return_state=return_state,
go_backwards=go_backwards,
stateful=stateful,
unroll=unroll,
**kwargs)
self.activity_regularizer = regularizers.get(activity_regularizer)
self.input_spec = [InputSpec(ndim=3)]
def call(self, inputs, mask=None, training=None, initial_state=None):
return super(SimpleRNN, self).call(
inputs, mask=mask, training=training, initial_state=initial_state)
@property
def units(self):
return self.cell.units
@property
def activation(self):
return self.cell.activation
@property
def use_bias(self):
return self.cell.use_bias
@property
def kernel_initializer(self):
return self.cell.kernel_initializer
@property
def recurrent_initializer(self):
return self.cell.recurrent_initializer
@property
def bias_initializer(self):
return self.cell.bias_initializer
@property
def kernel_regularizer(self):
return self.cell.kernel_regularizer
@property
def recurrent_regularizer(self):
return self.cell.recurrent_regularizer
@property
def bias_regularizer(self):
return self.cell.bias_regularizer
@property
def kernel_constraint(self):
return self.cell.kernel_constraint
@property
def recurrent_constraint(self):
return self.cell.recurrent_constraint
@property
def bias_constraint(self):
return self.cell.bias_constraint
@property
def dropout(self):
return self.cell.dropout
@property
def recurrent_dropout(self):
return self.cell.recurrent_dropout
def get_config(self):
config = {
'units':
self.units,
'activation':
activations.serialize(self.activation),
'use_bias':
self.use_bias,
'kernel_initializer':
initializers.serialize(self.kernel_initializer),
'recurrent_initializer':
initializers.serialize(self.recurrent_initializer),
'bias_initializer':
initializers.serialize(self.bias_initializer),
'kernel_regularizer':
regularizers.serialize(self.kernel_regularizer),
'recurrent_regularizer':
regularizers.serialize(self.recurrent_regularizer),
'bias_regularizer':
regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
'kernel_constraint':
constraints.serialize(self.kernel_constraint),
'recurrent_constraint':
constraints.serialize(self.recurrent_constraint),
'bias_constraint':
constraints.serialize(self.bias_constraint),
'dropout':
self.dropout,
'recurrent_dropout':
self.recurrent_dropout
}
base_config = super(SimpleRNN, self).get_config()
config.update(_config_for_enable_caching_device(self.cell))
del base_config['cell']
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config):
if 'implementation' in config:
config.pop('implementation')
return cls(**config)
@keras_export(v1=['keras.layers.GRUCell'])
class GRUCell(DropoutRNNCellMixin, Layer):
"""Cell class for the GRU layer.
Args:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
Default: hyperbolic tangent (`tanh`).
If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
recurrent_activation: Activation function to use
for the recurrent step.
Default: hard sigmoid (`hard_sigmoid`).
If you pass `None`, no activation is applied
(ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix,
used for the linear transformation of the recurrent state.
bias_initializer: Initializer for the bias vector.
kernel_regularizer: Regularizer function applied to
the `kernel` weights matrix.
recurrent_regularizer: Regularizer function applied to
the `recurrent_kernel` weights matrix.
bias_regularizer: Regularizer function applied to the bias vector.
kernel_constraint: Constraint function applied to
the `kernel` weights matrix.
recurrent_constraint: Constraint function applied to
the `recurrent_kernel` weights matrix.
bias_constraint: Constraint function applied to the bias vector.
dropout: Float between 0 and 1.
Fraction of the units to drop for the linear transformation of the inputs.
recurrent_dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the recurrent state.
reset_after: GRU convention (whether to apply reset gate after or
before matrix multiplication). False = "before" (default),
True = "after" (CuDNN compatible).
Call arguments:
inputs: A 2D tensor.
states: List of state tensors corresponding to the previous timestep.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. Only relevant when `dropout` or
`recurrent_dropout` is used.
"""
def __init__(self,
units,
activation='tanh',
recurrent_activation='hard_sigmoid',
use_bias=True,
kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal',
bias_initializer='zeros',
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
dropout=0.,
recurrent_dropout=0.,
reset_after=False,
**kwargs):
if units < 0:
raise ValueError(f'Received an invalid value for units, expected '
f'a positive integer, got {units}.')
# By default use cached variable under v2 mode, see b/143699808.
if tf.compat.v1.executing_eagerly_outside_functions():
self._enable_caching_device = kwargs.pop('enable_caching_device', True)
else:
self._enable_caching_device = kwargs.pop('enable_caching_device', False)
super(GRUCell, self).__init__(**kwargs)
self.units = units
self.activation = activations.get(activation)
self.recurrent_activation = activations.get(recurrent_activation)
self.use_bias = use_bias
self.kernel_initializer = initializers.get(kernel_initializer)
self.recurrent_initializer = initializers.get(recurrent_initializer)
self.bias_initializer = initializers.get(bias_initializer)
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.recurrent_constraint = constraints.get(recurrent_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.dropout = min(1., max(0., dropout))
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
implementation = kwargs.pop('implementation', 1)
if self.recurrent_dropout != 0 and implementation != 1:
logging.debug(RECURRENT_DROPOUT_WARNING_MSG)
self.implementation = 1
else:
self.implementation = implementation
self.reset_after = reset_after
self.state_size = self.units
self.output_size = self.units
@tf_utils.shape_type_conversion
def build(self, input_shape):
input_dim = input_shape[-1]
default_caching_device = _caching_device(self)
self.kernel = self.add_weight(
shape=(input_dim, self.units * 3),
name='kernel',
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
caching_device=default_caching_device)
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units * 3),
name='recurrent_kernel',
initializer=self.recurrent_initializer,
regularizer=self.recurrent_regularizer,
constraint=self.recurrent_constraint,
caching_device=default_caching_device)
if self.use_bias:
if not self.reset_after:
bias_shape = (3 * self.units,)
else:
# separate biases for input and recurrent kernels
# Note: the shape is intentionally different from CuDNNGRU biases
# `(2 * 3 * self.units,)`, so that we can distinguish the classes
# when loading and converting saved weights.
bias_shape = (2, 3 * self.units)
self.bias = self.add_weight(shape=bias_shape,
name='bias',
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint,
caching_device=default_caching_device)
else:
self.bias = None
self.built = True
def call(self, inputs, states, training=None):
h_tm1 = states[0] if tf.nest.is_nested(states) else states # previous memory
dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=3)
rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
h_tm1, training, count=3)
if self.use_bias:
if not self.reset_after:
input_bias, recurrent_bias = self.bias, None
else:
input_bias, recurrent_bias = tf.unstack(self.bias)
if self.implementation == 1:
if 0. < self.dropout < 1.:
inputs_z = inputs * dp_mask[0]
inputs_r = inputs * dp_mask[1]
inputs_h = inputs * dp_mask[2]
else:
inputs_z = inputs
inputs_r = inputs
inputs_h = inputs
x_z = backend.dot(inputs_z, self.kernel[:, :self.units])
x_r = backend.dot(inputs_r, self.kernel[:, self.units:self.units * 2])
x_h = backend.dot(inputs_h, self.kernel[:, self.units * 2:])
if self.use_bias:
x_z = backend.bias_add(x_z, input_bias[:self.units])
x_r = backend.bias_add(x_r, input_bias[self.units: self.units * 2])
x_h = backend.bias_add(x_h, input_bias[self.units * 2:])
if 0. < self.recurrent_dropout < 1.:
h_tm1_z = h_tm1 * rec_dp_mask[0]
h_tm1_r = h_tm1 * rec_dp_mask[1]
h_tm1_h = h_tm1 * rec_dp_mask[2]
else:
h_tm1_z = h_tm1
h_tm1_r = h_tm1
h_tm1_h = h_tm1
recurrent_z = backend.dot(h_tm1_z, self.recurrent_kernel[:, :self.units])
recurrent_r = backend.dot(
h_tm1_r, self.recurrent_kernel[:, self.units:self.units * 2])
if self.reset_after and self.use_bias:
recurrent_z = backend.bias_add(recurrent_z, recurrent_bias[:self.units])
recurrent_r = backend.bias_add(
recurrent_r, recurrent_bias[self.units:self.units * 2])
z = self.recurrent_activation(x_z + recurrent_z)
r = self.recurrent_activation(x_r + recurrent_r)
# reset gate applied after/before matrix multiplication
if self.reset_after:
recurrent_h = backend.dot(
h_tm1_h, self.recurrent_kernel[:, self.units * 2:])
if self.use_bias:
recurrent_h = backend.bias_add(
recurrent_h, recurrent_bias[self.units * 2:])
recurrent_h = r * recurrent_h
else:
recurrent_h = backend.dot(
r * h_tm1_h, self.recurrent_kernel[:, self.units * 2:])
hh = self.activation(x_h + recurrent_h)
else:
if 0. < self.dropout < 1.:
inputs = inputs * dp_mask[0]
# inputs projected by all gate matrices at once
matrix_x = backend.dot(inputs, self.kernel)
if self.use_bias:
# biases: bias_z_i, bias_r_i, bias_h_i
matrix_x = backend.bias_add(matrix_x, input_bias)
x_z, x_r, x_h = tf.split(matrix_x, 3, axis=-1)
if self.reset_after:
# hidden state projected by all gate matrices at once
matrix_inner = backend.dot(h_tm1, self.recurrent_kernel)
if self.use_bias:
matrix_inner = backend.bias_add(matrix_inner, recurrent_bias)
else:
# hidden state projected separately for update/reset and new
matrix_inner = backend.dot(
h_tm1, self.recurrent_kernel[:, :2 * self.units])
recurrent_z, recurrent_r, recurrent_h = tf.split(
matrix_inner, [self.units, self.units, -1], axis=-1)
z = self.recurrent_activation(x_z + recurrent_z)
r = self.recurrent_activation(x_r + recurrent_r)
if self.reset_after:
recurrent_h = r * recurrent_h
else:
recurrent_h = backend.dot(
r * h_tm1, self.recurrent_kernel[:, 2 * self.units:])
hh = self.activation(x_h + recurrent_h)
# previous and candidate state mixed by update gate
h = z * h_tm1 + (1 - z) * hh
new_state = [h] if tf.nest.is_nested(states) else h
return h, new_state
def get_config(self):
config = {
'units': self.units,
'activation': activations.serialize(self.activation),
'recurrent_activation':
activations.serialize(self.recurrent_activation),
'use_bias': self.use_bias,
'kernel_initializer': initializers.serialize(self.kernel_initializer),
'recurrent_initializer':
initializers.serialize(self.recurrent_initializer),
'bias_initializer': initializers.serialize(self.bias_initializer),
'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
'recurrent_regularizer':
regularizers.serialize(self.recurrent_regularizer),
'bias_regularizer': regularizers.serialize(self.bias_regularizer),
'kernel_constraint': constraints.serialize(self.kernel_constraint),
'recurrent_constraint':
constraints.serialize(self.recurrent_constraint),
'bias_constraint': constraints.serialize(self.bias_constraint),
'dropout': self.dropout,
'recurrent_dropout': self.recurrent_dropout,
'implementation': self.implementation,
'reset_after': self.reset_after
}
config.update(_config_for_enable_caching_device(self))
base_config = super(GRUCell, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
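# Illustrative sketch, not part of the upstream Keras module: how a cell
# defined in this file can be stepped manually for a single timestep, or
# wrapped in an `RNN` layer that iterates it over the time dimension.
# Assumes eager execution; the helper name below is made up for illustration.
def _example_gru_cell_usage():
  cell = GRUCell(4)
  step_input = tf.ones([8, 16])  # one timestep: [batch, features]
  state = cell.get_initial_state(batch_size=8, dtype=tf.float32)
  output, new_state = cell(step_input, state)  # output shape: [8, 4]
  # Wrapping a fresh cell in `RNN` runs it over a whole sequence.
  layer = RNN(GRUCell(4))
  sequence = tf.ones([8, 10, 16])  # [batch, timesteps, features]
  return output, new_state, layer(sequence)  # layer output shape: [8, 4]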
@keras_export(v1=['keras.layers.GRU'])
class GRU(RNN):
"""Gated Recurrent Unit - Cho et al. 2014.
There are two variants. The default one is based on 1406.1078v3 and
has the reset gate applied to the hidden state before matrix multiplication.
The other one is based on the original 1406.1078v1 and has the order reversed.
The second variant is compatible with CuDNNGRU (GPU-only) and allows
inference on CPU. It therefore has separate biases for `kernel` and
`recurrent_kernel`. To use this variant, pass `reset_after=True` and
`recurrent_activation='sigmoid'`.
Args:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
Default: hyperbolic tangent (`tanh`).
If you pass `None`, no activation is applied
(ie. "linear" activation: `a(x) = x`).
recurrent_activation: Activation function to use
for the recurrent step.
Default: hard sigmoid (`hard_sigmoid`).
If you pass `None`, no activation is applied
(ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix, used for the linear transformation of the recurrent state.
bias_initializer: Initializer for the bias vector.
kernel_regularizer: Regularizer function applied to
the `kernel` weights matrix.
recurrent_regularizer: Regularizer function applied to
the `recurrent_kernel` weights matrix.
bias_regularizer: Regularizer function applied to the bias vector.
activity_regularizer: Regularizer function applied to
the output of the layer (its "activation").
kernel_constraint: Constraint function applied to
the `kernel` weights matrix.
recurrent_constraint: Constraint function applied to
the `recurrent_kernel` weights matrix.
bias_constraint: Constraint function applied to the bias vector.
dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the inputs.
recurrent_dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the recurrent state.
return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence.
return_state: Boolean. Whether to return the last state
in addition to the output.
go_backwards: Boolean (default False).
If True, process the input sequence backwards and return the
reversed sequence.
stateful: Boolean (default False). If True, the last state
for each sample at index i in a batch will be used as initial
state for the sample of index i in the following batch.
unroll: Boolean (default False).
If True, the network will be unrolled,
else a symbolic loop will be used.
Unrolling can speed-up a RNN,
although it tends to be more memory-intensive.
Unrolling is only suitable for short sequences.
time_major: The shape format of the `inputs` and `outputs` tensors.
If True, the inputs and outputs will be in shape
`(timesteps, batch, ...)`, whereas in the False case, it will be
`(batch, timesteps, ...)`. Using `time_major = True` is a bit more
efficient because it avoids transposes at the beginning and end of the
RNN calculation. However, most TensorFlow data is batch-major, so by
default this function accepts input and emits output in batch-major
form.
reset_after: GRU convention (whether to apply reset gate after or
before matrix multiplication). False = "before" (default),
True = "after" (CuDNN compatible).
Call arguments:
inputs: A 3D tensor.
mask: Binary tensor of shape `(samples, timesteps)` indicating whether
a given timestep should be masked. An individual `True` entry indicates
that the corresponding timestep should be utilized, while a `False`
entry indicates that the corresponding timestep should be ignored.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. This argument is passed to the cell
when calling it. This is only relevant if `dropout` or
`recurrent_dropout` is used.
initial_state: List of initial state tensors to be passed to the first
call of the cell.
"""
def __init__(self,
units,
activation='tanh',
recurrent_activation='hard_sigmoid',
use_bias=True,
kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal',
bias_initializer='zeros',
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
dropout=0.,
recurrent_dropout=0.,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
unroll=False,
reset_after=False,
**kwargs):
implementation = kwargs.pop('implementation', 1)
if implementation == 0:
logging.warning('`implementation=0` has been deprecated, '
'and now defaults to `implementation=1`. '
'Please update your layer call.')
if 'enable_caching_device' in kwargs:
cell_kwargs = {'enable_caching_device':
kwargs.pop('enable_caching_device')}
else:
cell_kwargs = {}
cell = GRUCell(
units,
activation=activation,
recurrent_activation=recurrent_activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
recurrent_initializer=recurrent_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
recurrent_regularizer=recurrent_regularizer,
bias_regularizer=bias_regularizer,
kernel_constraint=kernel_constraint,
recurrent_constraint=recurrent_constraint,
bias_constraint=bias_constraint,
dropout=dropout,
recurrent_dropout=recurrent_dropout,
implementation=implementation,
reset_after=reset_after,
dtype=kwargs.get('dtype'),
trainable=kwargs.get('trainable', True),
**cell_kwargs)
super(GRU, self).__init__(
cell,
return_sequences=return_sequences,
return_state=return_state,
go_backwards=go_backwards,
stateful=stateful,
unroll=unroll,
**kwargs)
self.activity_regularizer = regularizers.get(activity_regularizer)
self.input_spec = [InputSpec(ndim=3)]
def call(self, inputs, mask=None, training=None, initial_state=None):
return super(GRU, self).call(
inputs, mask=mask, training=training, initial_state=initial_state)
@property
def units(self):
return self.cell.units
@property
def activation(self):
return self.cell.activation
@property
def recurrent_activation(self):
return self.cell.recurrent_activation
@property
def use_bias(self):
return self.cell.use_bias
@property
def kernel_initializer(self):
return self.cell.kernel_initializer
@property
def recurrent_initializer(self):
return self.cell.recurrent_initializer
@property
def bias_initializer(self):
return self.cell.bias_initializer
@property
def kernel_regularizer(self):
return self.cell.kernel_regularizer
@property
def recurrent_regularizer(self):
return self.cell.recurrent_regularizer
@property
def bias_regularizer(self):
return self.cell.bias_regularizer
@property
def kernel_constraint(self):
return self.cell.kernel_constraint
@property
def recurrent_constraint(self):
return self.cell.recurrent_constraint
@property
def bias_constraint(self):
return self.cell.bias_constraint
@property
def dropout(self):
return self.cell.dropout
@property
def recurrent_dropout(self):
return self.cell.recurrent_dropout
@property
def implementation(self):
return self.cell.implementation
@property
def reset_after(self):
return self.cell.reset_after
def get_config(self):
config = {
'units':
self.units,
'activation':
activations.serialize(self.activation),
'recurrent_activation':
activations.serialize(self.recurrent_activation),
'use_bias':
self.use_bias,
'kernel_initializer':
initializers.serialize(self.kernel_initializer),
'recurrent_initializer':
initializers.serialize(self.recurrent_initializer),
'bias_initializer':
initializers.serialize(self.bias_initializer),
'kernel_regularizer':
regularizers.serialize(self.kernel_regularizer),
'recurrent_regularizer':
regularizers.serialize(self.recurrent_regularizer),
'bias_regularizer':
regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
'kernel_constraint':
constraints.serialize(self.kernel_constraint),
'recurrent_constraint':
constraints.serialize(self.recurrent_constraint),
'bias_constraint':
constraints.serialize(self.bias_constraint),
'dropout':
self.dropout,
'recurrent_dropout':
self.recurrent_dropout,
'implementation':
self.implementation,
'reset_after':
self.reset_after
}
config.update(_config_for_enable_caching_device(self.cell))
base_config = super(GRU, self).get_config()
del base_config['cell']
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config):
if 'implementation' in config and config['implementation'] == 0:
config['implementation'] = 1
return cls(**config)
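# Illustrative sketch, not part of the upstream Keras module: building the
# CuDNN-compatible GRU variant described in the class docstring above, i.e.
# `reset_after=True` with `recurrent_activation='sigmoid'`, which gives the
# layer separate input and recurrent biases. The helper name is made up.
def _example_cudnn_compatible_gru():
  layer = GRU(32, reset_after=True, recurrent_activation='sigmoid',
              return_sequences=True)
  inputs = tf.ones([4, 10, 8])  # [batch, timesteps, features]
  return layer(inputs)  # shape: [4, 10, 32]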
@keras_export(v1=['keras.layers.LSTMCell'])
class LSTMCell(DropoutRNNCellMixin, Layer):
"""Cell class for the LSTM layer.
Args:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
Default: hyperbolic tangent (`tanh`).
If you pass `None`, no activation is applied
(ie. "linear" activation: `a(x) = x`).
recurrent_activation: Activation function to use
for the recurrent step.
Default: hard sigmoid (`hard_sigmoid`).
If you pass `None`, no activation is applied
(ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix,
used for the linear transformation of the recurrent state.
bias_initializer: Initializer for the bias vector.
unit_forget_bias: Boolean.
If True, add 1 to the bias of the forget gate at initialization.
Setting it to true will also force `bias_initializer="zeros"`.
This is recommended in [Jozefowicz et al., 2015](
http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
kernel_regularizer: Regularizer function applied to
the `kernel` weights matrix.
recurrent_regularizer: Regularizer function applied to
the `recurrent_kernel` weights matrix.
bias_regularizer: Regularizer function applied to the bias vector.
kernel_constraint: Constraint function applied to
the `kernel` weights matrix.
recurrent_constraint: Constraint function applied to
the `recurrent_kernel` weights matrix.
bias_constraint: Constraint function applied to the bias vector.
dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the inputs.
recurrent_dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the recurrent state.
Call arguments:
inputs: A 2D tensor.
states: List of state tensors corresponding to the previous timestep.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. Only relevant when `dropout` or
`recurrent_dropout` is used.
"""
def __init__(self,
units,
activation='tanh',
recurrent_activation='hard_sigmoid',
use_bias=True,
kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal',
bias_initializer='zeros',
unit_forget_bias=True,
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
dropout=0.,
recurrent_dropout=0.,
**kwargs):
if units < 0:
raise ValueError(f'Received an invalid value for units, expected '
f'a positive integer, got {units}.')
# By default use cached variable under v2 mode, see b/143699808.
if tf.compat.v1.executing_eagerly_outside_functions():
self._enable_caching_device = kwargs.pop('enable_caching_device', True)
else:
self._enable_caching_device = kwargs.pop('enable_caching_device', False)
super(LSTMCell, self).__init__(**kwargs)
self.units = units
self.activation = activations.get(activation)
self.recurrent_activation = activations.get(recurrent_activation)
self.use_bias = use_bias
self.kernel_initializer = initializers.get(kernel_initializer)
self.recurrent_initializer = initializers.get(recurrent_initializer)
self.bias_initializer = initializers.get(bias_initializer)
self.unit_forget_bias = unit_forget_bias
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.recurrent_constraint = constraints.get(recurrent_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.dropout = min(1., max(0., dropout))
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
implementation = kwargs.pop('implementation', 1)
if self.recurrent_dropout != 0 and implementation != 1:
logging.debug(RECURRENT_DROPOUT_WARNING_MSG)
self.implementation = 1
else:
self.implementation = implementation
self.state_size = [self.units, self.units]
self.output_size = self.units
@tf_utils.shape_type_conversion
def build(self, input_shape):
default_caching_device = _caching_device(self)
input_dim = input_shape[-1]
self.kernel = self.add_weight(
shape=(input_dim, self.units * 4),
name='kernel',
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
caching_device=default_caching_device)
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units * 4),
name='recurrent_kernel',
initializer=self.recurrent_initializer,
regularizer=self.recurrent_regularizer,
constraint=self.recurrent_constraint,
caching_device=default_caching_device)
if self.use_bias:
if self.unit_forget_bias:
def bias_initializer(_, *args, **kwargs):
return backend.concatenate([
self.bias_initializer((self.units,), *args, **kwargs),
initializers.get('ones')((self.units,), *args, **kwargs),
self.bias_initializer((self.units * 2,), *args, **kwargs),
])
else:
bias_initializer = self.bias_initializer
self.bias = self.add_weight(
shape=(self.units * 4,),
name='bias',
initializer=bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint,
caching_device=default_caching_device)
else:
self.bias = None
self.built = True
def _compute_carry_and_output(self, x, h_tm1, c_tm1):
"""Computes carry and output using split kernels."""
x_i, x_f, x_c, x_o = x
h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
i = self.recurrent_activation(
x_i + backend.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]))
f = self.recurrent_activation(x_f + backend.dot(
h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]))
c = f * c_tm1 + i * self.activation(x_c + backend.dot(
h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3]))
o = self.recurrent_activation(
x_o + backend.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]))
return c, o
def _compute_carry_and_output_fused(self, z, c_tm1):
"""Computes carry and output using fused kernels."""
z0, z1, z2, z3 = z
i = self.recurrent_activation(z0)
f = self.recurrent_activation(z1)
c = f * c_tm1 + i * self.activation(z2)
o = self.recurrent_activation(z3)
return c, o
def call(self, inputs, states, training=None):
h_tm1 = states[0] # previous memory state
c_tm1 = states[1] # previous carry state
dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
h_tm1, training, count=4)
if self.implementation == 1:
if 0 < self.dropout < 1.:
inputs_i = inputs * dp_mask[0]
inputs_f = inputs * dp_mask[1]
inputs_c = inputs * dp_mask[2]
inputs_o = inputs * dp_mask[3]
else:
inputs_i = inputs
inputs_f = inputs
inputs_c = inputs
inputs_o = inputs
k_i, k_f, k_c, k_o = tf.split(
self.kernel, num_or_size_splits=4, axis=1)
x_i = backend.dot(inputs_i, k_i)
x_f = backend.dot(inputs_f, k_f)
x_c = backend.dot(inputs_c, k_c)
x_o = backend.dot(inputs_o, k_o)
if self.use_bias:
b_i, b_f, b_c, b_o = tf.split(
self.bias, num_or_size_splits=4, axis=0)
x_i = backend.bias_add(x_i, b_i)
x_f = backend.bias_add(x_f, b_f)
x_c = backend.bias_add(x_c, b_c)
x_o = backend.bias_add(x_o, b_o)
if 0 < self.recurrent_dropout < 1.:
h_tm1_i = h_tm1 * rec_dp_mask[0]
h_tm1_f = h_tm1 * rec_dp_mask[1]
h_tm1_c = h_tm1 * rec_dp_mask[2]
h_tm1_o = h_tm1 * rec_dp_mask[3]
else:
h_tm1_i = h_tm1
h_tm1_f = h_tm1
h_tm1_c = h_tm1
h_tm1_o = h_tm1
x = (x_i, x_f, x_c, x_o)
h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o)
c, o = self._compute_carry_and_output(x, h_tm1, c_tm1)
else:
if 0. < self.dropout < 1.:
inputs = inputs * dp_mask[0]
z = backend.dot(inputs, self.kernel)
z += backend.dot(h_tm1, self.recurrent_kernel)
if self.use_bias:
z = backend.bias_add(z, self.bias)
z = tf.split(z, num_or_size_splits=4, axis=1)
c, o = self._compute_carry_and_output_fused(z, c_tm1)
h = o * self.activation(c)
return h, [h, c]
def get_config(self):
config = {
'units':
self.units,
'activation':
activations.serialize(self.activation),
'recurrent_activation':
activations.serialize(self.recurrent_activation),
'use_bias':
self.use_bias,
'kernel_initializer':
initializers.serialize(self.kernel_initializer),
'recurrent_initializer':
initializers.serialize(self.recurrent_initializer),
'bias_initializer':
initializers.serialize(self.bias_initializer),
'unit_forget_bias':
self.unit_forget_bias,
'kernel_regularizer':
regularizers.serialize(self.kernel_regularizer),
'recurrent_regularizer':
regularizers.serialize(self.recurrent_regularizer),
'bias_regularizer':
regularizers.serialize(self.bias_regularizer),
'kernel_constraint':
constraints.serialize(self.kernel_constraint),
'recurrent_constraint':
constraints.serialize(self.recurrent_constraint),
'bias_constraint':
constraints.serialize(self.bias_constraint),
'dropout':
self.dropout,
'recurrent_dropout':
self.recurrent_dropout,
'implementation':
self.implementation
}
config.update(_config_for_enable_caching_device(self))
base_config = super(LSTMCell, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
return list(_generate_zero_filled_state_for_cell(
self, inputs, batch_size, dtype))
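# Illustrative sketch, not part of the upstream Keras module: a single
# LSTMCell step. The cell returns its hidden state as the output and `[h, c]`
# as the new state pair, matching `state_size = [units, units]` above.
# The helper name is made up for illustration.
def _example_lstm_cell_step():
  cell = LSTMCell(4)
  step_input = tf.ones([8, 16])  # [batch, features]
  h0, c0 = cell.get_initial_state(batch_size=8, dtype=tf.float32)
  output, (h1, c1) = cell(step_input, [h0, c0])  # output is h1, shape [8, 4]
  return output, h1, c1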
@keras_export('keras.experimental.PeepholeLSTMCell')
class PeepholeLSTMCell(LSTMCell):
"""Equivalent to LSTMCell class but adds peephole connections.
Peephole connections allow the gates to utilize the previous internal state as
well as the previous hidden state (which is what LSTMCell is limited to).
This allows PeepholeLSTMCell to learn precise timings more reliably than LSTMCell.
From [Gers et al., 2002](
http://www.jmlr.org/papers/volume3/gers02a/gers02a.pdf):
"We find that LSTM augmented by 'peephole connections' from its internal
cells to its multiplicative gates can learn the fine distinction between
sequences of spikes spaced either 50 or 49 time steps apart without the help
of any short training exemplars."
The peephole implementation is based on:
[Sak et al., 2014](https://research.google.com/pubs/archive/43905.pdf)
Example:
```python
# Create 2 PeepholeLSTMCells
peephole_lstm_cells = [PeepholeLSTMCell(size) for size in [128, 256]]
# Create a layer composed sequentially of the peephole LSTM cells.
layer = RNN(peephole_lstm_cells)
input = keras.Input((timesteps, input_dim))
output = layer(input)
```
"""
def __init__(self,
units,
activation='tanh',
recurrent_activation='hard_sigmoid',
use_bias=True,
kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal',
bias_initializer='zeros',
unit_forget_bias=True,
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
dropout=0.,
recurrent_dropout=0.,
**kwargs):
warnings.warn('`tf.keras.experimental.PeepholeLSTMCell` is deprecated '
'and will be removed in a future version. '
'Please use tensorflow_addons.rnn.PeepholeLSTMCell '
'instead.')
super(PeepholeLSTMCell, self).__init__(
units=units,
activation=activation,
recurrent_activation=recurrent_activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
recurrent_initializer=recurrent_initializer,
bias_initializer=bias_initializer,
unit_forget_bias=unit_forget_bias,
kernel_regularizer=kernel_regularizer,
recurrent_regularizer=recurrent_regularizer,
bias_regularizer=bias_regularizer,
kernel_constraint=kernel_constraint,
recurrent_constraint=recurrent_constraint,
bias_constraint=bias_constraint,
dropout=dropout,
recurrent_dropout=recurrent_dropout,
implementation=kwargs.pop('implementation', 1),
**kwargs)
def build(self, input_shape):
super(PeepholeLSTMCell, self).build(input_shape)
# The following are the weight matrices for the peephole connections. These
# are multiplied with the previous internal state during the computation of
# carry and output.
self.input_gate_peephole_weights = self.add_weight(
shape=(self.units,),
name='input_gate_peephole_weights',
initializer=self.kernel_initializer)
self.forget_gate_peephole_weights = self.add_weight(
shape=(self.units,),
name='forget_gate_peephole_weights',
initializer=self.kernel_initializer)
self.output_gate_peephole_weights = self.add_weight(
shape=(self.units,),
name='output_gate_peephole_weights',
initializer=self.kernel_initializer)
def _compute_carry_and_output(self, x, h_tm1, c_tm1):
x_i, x_f, x_c, x_o = x
h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
i = self.recurrent_activation(
x_i + backend.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]) +
self.input_gate_peephole_weights * c_tm1)
f = self.recurrent_activation(x_f + backend.dot(
h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]) +
self.forget_gate_peephole_weights * c_tm1)
c = f * c_tm1 + i * self.activation(x_c + backend.dot(
h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3]))
o = self.recurrent_activation(
x_o + backend.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]) +
self.output_gate_peephole_weights * c)
return c, o
def _compute_carry_and_output_fused(self, z, c_tm1):
z0, z1, z2, z3 = z
i = self.recurrent_activation(z0 +
self.input_gate_peephole_weights * c_tm1)
f = self.recurrent_activation(z1 +
self.forget_gate_peephole_weights * c_tm1)
c = f * c_tm1 + i * self.activation(z2)
o = self.recurrent_activation(z3 + self.output_gate_peephole_weights * c)
return c, o
@keras_export(v1=['keras.layers.LSTM'])
class LSTM(RNN):
"""Long Short-Term Memory layer - Hochreiter 1997.
Note that this cell is not optimized for performance on GPU. Please use
`tf.compat.v1.keras.layers.CuDNNLSTM` for better performance on GPU.
Args:
units: Positive integer, dimensionality of the output space.
activation: Activation function to use.
Default: hyperbolic tangent (`tanh`).
If you pass `None`, no activation is applied
(ie. "linear" activation: `a(x) = x`).
recurrent_activation: Activation function to use
for the recurrent step.
Default: hard sigmoid (`hard_sigmoid`).
If you pass `None`, no activation is applied
(ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
used for the linear transformation of the inputs.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix,
used for the linear transformation of the recurrent state.
bias_initializer: Initializer for the bias vector.
unit_forget_bias: Boolean.
If True, add 1 to the bias of the forget gate at initialization.
Setting it to true will also force `bias_initializer="zeros"`.
This is recommended in [Jozefowicz et al., 2015](
http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf).
kernel_regularizer: Regularizer function applied to
the `kernel` weights matrix.
recurrent_regularizer: Regularizer function applied to
the `recurrent_kernel` weights matrix.
bias_regularizer: Regularizer function applied to the bias vector.
activity_regularizer: Regularizer function applied to
the output of the layer (its "activation").
kernel_constraint: Constraint function applied to
the `kernel` weights matrix.
recurrent_constraint: Constraint function applied to
the `recurrent_kernel` weights matrix.
bias_constraint: Constraint function applied to the bias vector.
dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the inputs.
recurrent_dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the recurrent state.
return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence.
return_state: Boolean. Whether to return the last state
in addition to the output.
go_backwards: Boolean (default False).
If True, process the input sequence backwards and return the
reversed sequence.
stateful: Boolean (default False). If True, the last state
for each sample at index i in a batch will be used as initial
state for the sample of index i in the following batch.
unroll: Boolean (default False).
If True, the network will be unrolled,
else a symbolic loop will be used.
Unrolling can speed-up a RNN,
although it tends to be more memory-intensive.
Unrolling is only suitable for short sequences.
time_major: The shape format of the `inputs` and `outputs` tensors.
If True, the inputs and outputs will be in shape
`(timesteps, batch, ...)`, whereas in the False case, it will be
`(batch, timesteps, ...)`. Using `time_major = True` is a bit more
efficient because it avoids transposes at the beginning and end of the
RNN calculation. However, most TensorFlow data is batch-major, so by
default this function accepts input and emits output in batch-major
form.
Call arguments:
inputs: A 3D tensor.
mask: Binary tensor of shape `(samples, timesteps)` indicating whether
a given timestep should be masked. An individual `True` entry indicates
that the corresponding timestep should be utilized, while a `False`
entry indicates that the corresponding timestep should be ignored.
training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. This argument is passed to the cell
when calling it. This is only relevant if `dropout` or
`recurrent_dropout` is used.
initial_state: List of initial state tensors to be passed to the first
call of the cell.
"""
def __init__(self,
units,
activation='tanh',
recurrent_activation='hard_sigmoid',
use_bias=True,
kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal',
bias_initializer='zeros',
unit_forget_bias=True,
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
dropout=0.,
recurrent_dropout=0.,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
unroll=False,
**kwargs):
implementation = kwargs.pop('implementation', 1)
if implementation == 0:
logging.warning('`implementation=0` has been deprecated, '
'and now defaults to `implementation=1`. '
'Please update your layer call.')
if 'enable_caching_device' in kwargs:
cell_kwargs = {'enable_caching_device':
kwargs.pop('enable_caching_device')}
else:
cell_kwargs = {}
cell = LSTMCell(
units,
activation=activation,
recurrent_activation=recurrent_activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
recurrent_initializer=recurrent_initializer,
unit_forget_bias=unit_forget_bias,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
recurrent_regularizer=recurrent_regularizer,
bias_regularizer=bias_regularizer,
kernel_constraint=kernel_constraint,
recurrent_constraint=recurrent_constraint,
bias_constraint=bias_constraint,
dropout=dropout,
recurrent_dropout=recurrent_dropout,
implementation=implementation,
dtype=kwargs.get('dtype'),
trainable=kwargs.get('trainable', True),
**cell_kwargs)
super(LSTM, self).__init__(
cell,
return_sequences=return_sequences,
return_state=return_state,
go_backwards=go_backwards,
stateful=stateful,
unroll=unroll,
**kwargs)
self.activity_regularizer = regularizers.get(activity_regularizer)
self.input_spec = [InputSpec(ndim=3)]
def call(self, inputs, mask=None, training=None, initial_state=None):
return super(LSTM, self).call(
inputs, mask=mask, training=training, initial_state=initial_state)
@property
def units(self):
return self.cell.units
@property
def activation(self):
return self.cell.activation
@property
def recurrent_activation(self):
return self.cell.recurrent_activation
@property
def use_bias(self):
return self.cell.use_bias
@property
def kernel_initializer(self):
return self.cell.kernel_initializer
@property
def recurrent_initializer(self):
return self.cell.recurrent_initializer
@property
def bias_initializer(self):
return self.cell.bias_initializer
@property
def unit_forget_bias(self):
return self.cell.unit_forget_bias
@property
def kernel_regularizer(self):
return self.cell.kernel_regularizer
@property
def recurrent_regularizer(self):
return self.cell.recurrent_regularizer
@property
def bias_regularizer(self):
return self.cell.bias_regularizer
@property
def kernel_constraint(self):
return self.cell.kernel_constraint
@property
def recurrent_constraint(self):
return self.cell.recurrent_constraint
@property
def bias_constraint(self):
return self.cell.bias_constraint
@property
def dropout(self):
return self.cell.dropout
@property
def recurrent_dropout(self):
return self.cell.recurrent_dropout
@property
def implementation(self):
return self.cell.implementation
def get_config(self):
config = {
'units':
self.units,
'activation':
activations.serialize(self.activation),
'recurrent_activation':
activations.serialize(self.recurrent_activation),
'use_bias':
self.use_bias,
'kernel_initializer':
initializers.serialize(self.kernel_initializer),
'recurrent_initializer':
initializers.serialize(self.recurrent_initializer),
'bias_initializer':
initializers.serialize(self.bias_initializer),
'unit_forget_bias':
self.unit_forget_bias,
'kernel_regularizer':
regularizers.serialize(self.kernel_regularizer),
'recurrent_regularizer':
regularizers.serialize(self.recurrent_regularizer),
'bias_regularizer':
regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
'kernel_constraint':
constraints.serialize(self.kernel_constraint),
'recurrent_constraint':
constraints.serialize(self.recurrent_constraint),
'bias_constraint':
constraints.serialize(self.bias_constraint),
'dropout':
self.dropout,
'recurrent_dropout':
self.recurrent_dropout,
'implementation':
self.implementation
}
config.update(_config_for_enable_caching_device(self.cell))
base_config = super(LSTM, self).get_config()
del base_config['cell']
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config):
if 'implementation' in config and config['implementation'] == 0:
config['implementation'] = 1
return cls(**config)
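# Illustrative sketch, not part of the upstream Keras module: with
# `return_sequences=True` and `return_state=True`, the LSTM layer returns the
# full output sequence plus the final hidden and carry states. The helper
# name is made up for illustration.
def _example_lstm_layer_usage():
  layer = LSTM(16, return_sequences=True, return_state=True)
  inputs = tf.ones([4, 10, 8])  # [batch, timesteps, features]
  whole_sequence, final_h, final_c = layer(inputs)
  # whole_sequence: [4, 10, 16]; final_h and final_c: [4, 16]
  return whole_sequence, final_h, final_c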
def _generate_dropout_mask(ones, rate, training=None, count=1):
def dropped_inputs():
return backend.dropout(ones, rate)
if count > 1:
return [
backend.in_train_phase(dropped_inputs, ones, training=training)
for _ in range(count)
]
return backend.in_train_phase(dropped_inputs, ones, training=training)
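# Illustrative sketch, not part of the upstream Keras module:
# `_generate_dropout_mask` returns one mask (or a list of `count` masks)
# shaped like `ones`; in training mode each mask is a scaled binary dropout
# mask, otherwise it is simply `ones`. The helper name is made up.
def _example_generate_dropout_mask():
  ones = tf.ones([8, 4])
  masks = _generate_dropout_mask(ones, 0.5, training=True, count=3)
  return masks  # list of 3 tensors, each shaped [8, 4]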
def _standardize_args(inputs, initial_state, constants, num_constants):
"""Standardizes `__call__` to a single list of tensor inputs.
When running a model loaded from a file, the input tensors
`initial_state` and `constants` can be passed to `RNN.__call__()` as part
of `inputs` instead of by the dedicated keyword arguments. This method
makes sure the arguments are separated and that `initial_state` and
`constants` are lists of tensors (or None).
Args:
inputs: Tensor or list/tuple of tensors, which may include constants
and initial states. In that case `num_constants` must be specified.
initial_state: Tensor or list of tensors or None, initial states.
constants: Tensor or list of tensors or None, constant tensors.
num_constants: Expected number of constants (if constants are passed as
part of the `inputs` list).
Returns:
inputs: Single tensor or tuple of tensors.
initial_state: List of tensors or None.
constants: List of tensors or None.
"""
if isinstance(inputs, list):
# There are several situations here:
# In graph mode, __call__ will only be called once. The initial_state
# and constants could be in inputs (from file loading).
# In eager mode, __call__ will be called twice: once during
# rnn_layer(inputs=input_t, constants=c_t, ...), and a second time during
# model.fit/train_on_batch/predict with real np data. In the second case,
# the inputs will contain initial_state and constants as eager tensors.
#
# In either case, the real input is the first item in the list, which
# could be a nested structure itself. It is followed by the initial states,
# which could be a list of items (or a list of lists if the initial_state
# is a complex structure), and finally by the constants, which is a flat list.
assert initial_state is None and constants is None
if num_constants:
constants = inputs[-num_constants:]
inputs = inputs[:-num_constants]
if len(inputs) > 1:
initial_state = inputs[1:]
inputs = inputs[:1]
if len(inputs) > 1:
inputs = tuple(inputs)
else:
inputs = inputs[0]
def to_list_or_none(x):
if x is None or isinstance(x, list):
return x
if isinstance(x, tuple):
return list(x)
return [x]
initial_state = to_list_or_none(initial_state)
constants = to_list_or_none(constants)
return inputs, initial_state, constants
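# Illustrative sketch, not part of the upstream Keras module: when initial
# states and constants arrive packed inside the `inputs` list (as happens
# when running a model loaded from a file), `_standardize_args` splits them
# back out into the three separate arguments. The helper name is made up.
def _example_standardize_args():
  x = tf.ones([4, 10, 8])
  state = tf.zeros([4, 16])
  constant = tf.zeros([4, 3])
  inputs, initial_state, constants = _standardize_args(
      [x, state, constant], initial_state=None, constants=None,
      num_constants=1)
  # inputs is `x`, initial_state == [state], constants == [constant]
  return inputs, initial_state, constants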
def _is_multiple_state(state_size):
"""Check whether the state_size contains multiple states."""
return (hasattr(state_size, '__len__') and
not isinstance(state_size, tf.TensorShape))
def _generate_zero_filled_state_for_cell(cell, inputs, batch_size, dtype):
if inputs is not None:
batch_size = tf.shape(inputs)[0]
dtype = inputs.dtype
return _generate_zero_filled_state(batch_size, cell.state_size, dtype)
def _generate_zero_filled_state(batch_size_tensor, state_size, dtype):
"""Generate a zero filled tensor with shape [batch_size, state_size]."""
if batch_size_tensor is None or dtype is None:
raise ValueError(
'batch_size and dtype cannot be None while constructing initial state: '
'batch_size={}, dtype={}'.format(batch_size_tensor, dtype))
def create_zeros(unnested_state_size):
flat_dims = tf.TensorShape(unnested_state_size).as_list()
init_state_size = [batch_size_tensor] + flat_dims
return tf.zeros(init_state_size, dtype=dtype)
if tf.nest.is_nested(state_size):
return tf.nest.map_structure(create_zeros, state_size)
else:
return create_zeros(state_size)
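# Illustrative sketch, not part of the upstream Keras module: zero-filled
# initial states follow the structure of `state_size`; an LSTM-style
# `[units, units]` state size yields two `[batch, units]` tensors. The
# helper name is made up for illustration.
def _example_zero_filled_state():
  states = _generate_zero_filled_state(
      batch_size_tensor=8, state_size=[16, 16], dtype=tf.float32)
  return states  # list of two zero tensors, each shaped [8, 16]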
def _caching_device(rnn_cell):
"""Returns the caching device for the RNN variable.
This is useful for distributed training, when the variable is not located on
the same device as the training worker. By enabling the device cache, the
worker can read the variable once and cache it locally, rather than reading
it from the remote device at every time step.
Note that this assumes the variable the cell needs at each time step has the
same value in the forward pass and is only updated during backprop. This
holds for all the default cells (SimpleRNN, GRU, LSTM). If the cell body
relies on any variable that gets updated at every time step, then the caching
device will cause it to read a stale value.
Args:
rnn_cell: the rnn cell instance.
"""
if tf.executing_eagerly():
# caching_device is not supported in eager mode.
return None
if not getattr(rnn_cell, '_enable_caching_device', False):
return None
# Don't set a caching device when running in a loop, since it is possible that
# train steps could be wrapped in a tf.while_loop. In that scenario caching
# prevents forward computations in loop iterations from re-reading the
# updated weights.
if control_flow_util.IsInWhileLoop(tf.compat.v1.get_default_graph()):
logging.warning(
'Variable read device caching has been disabled because the '
'RNN is in a tf.while_loop context, which would cause stale '
'values to be read in the forward path. This could slow down '
'training due to duplicated variable reads. Please '
'consider updating your code to remove tf.while_loop if possible.')
return None
if (rnn_cell._dtype_policy.compute_dtype !=
rnn_cell._dtype_policy.variable_dtype):
logging.warning(
'Variable read device caching has been disabled since it '
'doesn\'t work with the mixed precision API. This is '
'likely to cause a slowdown for RNN training due to '
'duplicated reads of the variable for each timestep, which '
'will be significant in a multi remote worker setting. '
'Please consider disabling the mixed precision API if '
'performance has been affected.')
return None
# Cache the value on the device that accesses the variable.
return lambda op: op.device
def _config_for_enable_caching_device(rnn_cell):
"""Return the dict config for RNN cell wrt to enable_caching_device field.
Since enable_caching_device is a internal implementation detail for speed up
the RNN variable read when running on the multi remote worker setting, we
don't want this config to be serialized constantly in the JSON. We will only
serialize this field when a none default value is used to create the cell.
Args:
rnn_cell: the RNN cell for serialize.
Returns:
A dict which contains the JSON config for enable_caching_device value or
empty dict if the enable_caching_device value is same as the default value.
"""
default_enable_caching_device = tf.compat.v1.executing_eagerly_outside_functions()
if rnn_cell._enable_caching_device != default_enable_caching_device:
return {'enable_caching_device': rnn_cell._enable_caching_device}
return {}
Classes
class AbstractRNNCell (trainable=True, name=None, dtype=None, dynamic=False, **kwargs)
-
Abstract object representing an RNN cell.
See the Keras RNN API guide for details about the usage of RNN API.
This is the base class for implementing RNN cells with custom behavior.
Every `RNNCell` must have the properties below and implement `call` with the signature `(output, next_state) = call(input, state)`.
Examples:
class MinimalRNNCell(AbstractRNNCell):

  def __init__(self, units, **kwargs):
    self.units = units
    super(MinimalRNNCell, self).__init__(**kwargs)

  @property
  def state_size(self):
    return self.units

  def build(self, input_shape):
    self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                  initializer='uniform',
                                  name='kernel')
    self.recurrent_kernel = self.add_weight(
        shape=(self.units, self.units),
        initializer='uniform',
        name='recurrent_kernel')
    self.built = True

  def call(self, inputs, states):
    prev_output = states[0]
    h = backend.dot(inputs, self.kernel)
    output = h + backend.dot(prev_output, self.recurrent_kernel)
    return output, output
This definition of cell differs from the definition used in the literature. In the literature, 'cell' refers to an object with a single scalar output. This definition refers to a horizontal array of such units.
An RNN cell, in the most abstract setting, is anything that has a state and performs some operation that takes a matrix of inputs. This operation results in an output matrix with `self.output_size` columns. If `self.state_size` is an integer, this operation also results in a new state matrix with `self.state_size` columns. If `self.state_size` is a (possibly nested tuple of) TensorShape object(s), then it should return a matching structure of Tensors having shape `[batch_size].concatenate(s)` for each `s` in `self.state_size`.
Expand source code
class AbstractRNNCell(Layer):
  """Abstract object representing an RNN cell.

  See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
  for details about the usage of RNN API.

  This is the base class for implementing RNN cells with custom behavior.

  Every `RNNCell` must have the properties below and implement `call` with
  the signature `(output, next_state) = call(input, state)`.

  Examples:

  ```python
    class MinimalRNNCell(AbstractRNNCell):

      def __init__(self, units, **kwargs):
        self.units = units
        super(MinimalRNNCell, self).__init__(**kwargs)

      @property
      def state_size(self):
        return self.units

      def build(self, input_shape):
        self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                      initializer='uniform',
                                      name='kernel')
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer='uniform',
            name='recurrent_kernel')
        self.built = True

      def call(self, inputs, states):
        prev_output = states[0]
        h = backend.dot(inputs, self.kernel)
        output = h + backend.dot(prev_output, self.recurrent_kernel)
        return output, output
  ```

  This definition of cell differs from the definition used in the literature.
  In the literature, 'cell' refers to an object with a single scalar output.
  This definition refers to a horizontal array of such units.

  An RNN cell, in the most abstract setting, is anything that has a state and
  performs some operation that takes a matrix of inputs. This operation
  results in an output matrix with `self.output_size` columns. If
  `self.state_size` is an integer, this operation also results in a new state
  matrix with `self.state_size` columns. If `self.state_size` is a (possibly
  nested tuple of) TensorShape object(s), then it should return a matching
  structure of Tensors having shape `[batch_size].concatenate(s)` for each
  `s` in `self.batch_size`.
  """

  def call(self, inputs, states):
    """The function that contains the logic for one RNN step calculation.

    Args:
      inputs: the input tensor, which is a slide from the overall RNN input by
        the time dimension (usually the second dimension).
      states: the state tensor from previous step, which has the same shape
        as `(batch, state_size)`. In the case of timestep 0, it will be the
        initial state user specified, or zero filled tensor otherwise.

    Returns:
      A tuple of two tensors:
        1. output tensor for the current timestep, with size `output_size`.
        2. state tensor for next step, which has the shape of `state_size`.
    """
    raise NotImplementedError('Abstract method')

  @property
  def state_size(self):
    """size(s) of state(s) used by this cell.

    It can be represented by an Integer, a TensorShape or a tuple of Integers
    or TensorShapes.
    """
    raise NotImplementedError('Abstract method')

  @property
  def output_size(self):
    """Integer or TensorShape: size of outputs produced by this cell."""
    raise NotImplementedError('Abstract method')

  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
    return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
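A cell defined this way is not used on its own; it is wrapped in `tf.keras.layers.RNN`, which iterates it over the time dimension. A minimal usage sketch, assuming the `MinimalRNNCell` class from the example above has already been defined:
```python
import numpy as np
import tensorflow as tf

# Wrap the custom cell in an RNN layer; the layer handles the loop over time.
cell = MinimalRNNCell(32)
layer = tf.keras.layers.RNN(cell)

x = np.random.random((4, 10, 8)).astype(np.float32)  # [batch, timesteps, features]
y = layer(x)  # shape: [4, 32]
```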
Ancestors
- Layer
- tensorflow.python.module.module.Module
- tensorflow.python.training.tracking.tracking.AutoTrackable
- tensorflow.python.training.tracking.base.Trackable
- LayerVersionSelector
Subclasses
- keras.layers.rnn_cell_wrapper_v2._RNNCellWrapperV2
Instance variables
var output_size
-
Integer or TensorShape: size of outputs produced by this cell.
Expand source code
@property
def output_size(self):
  """Integer or TensorShape: size of outputs produced by this cell."""
  raise NotImplementedError('Abstract method')
var state_size
-
size(s) of state(s) used by this cell.
It can be represented by an Integer, a TensorShape or a tuple of Integers or TensorShapes.
Expand source code
@property
def state_size(self):
  """size(s) of state(s) used by this cell.

  It can be represented by an Integer, a TensorShape or a tuple of Integers
  or TensorShapes.
  """
  raise NotImplementedError('Abstract method')
Methods
def call(self, inputs, states)
-
The function that contains the logic for one RNN step calculation.
Args
inputs
- the input tensor, which is a slice of the overall RNN input along the time dimension (usually the second dimension).
states
- the state tensor from the previous step, which has the shape `(batch, state_size)`. At timestep 0, it will be the initial state specified by the user, or a zero-filled tensor otherwise.
Returns
A tuple of two tensors: 1. the output tensor for the current timestep, with size `output_size`; 2. the state tensor for the next step, with the shape of `state_size`.
Expand source code
def call(self, inputs, states):
  """The function that contains the logic for one RNN step calculation.

  Args:
    inputs: the input tensor, which is a slide from the overall RNN input by
      the time dimension (usually the second dimension).
    states: the state tensor from previous step, which has the same shape
      as `(batch, state_size)`. In the case of timestep 0, it will be the
      initial state user specified, or zero filled tensor otherwise.

  Returns:
    A tuple of two tensors:
      1. output tensor for the current timestep, with size `output_size`.
      2. state tensor for next step, which has the shape of `state_size`.
  """
  raise NotImplementedError('Abstract method')
def get_initial_state(self, inputs=None, batch_size=None, dtype=None)
-
Expand source code
def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
  return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
Inherited members
Layer:
activity_regularizer
add_loss
add_metric
add_update
add_variable
add_weight
apply
build
compute_dtype
compute_mask
compute_output_shape
compute_output_signature
count_params
dtype
dtype_policy
dynamic
finalize_state
from_config
get_config
get_input_at
get_input_mask_at
get_input_shape_at
get_losses_for
get_output_at
get_output_mask_at
get_output_shape_at
get_updates_for
get_weights
inbound_nodes
input
input_mask
input_shape
input_spec
losses
metrics
name
non_trainable_variables
non_trainable_weights
outbound_nodes
output
output_mask
output_shape
set_weights
supports_masking
trainable_variables
trainable_weights
variable_dtype
variables
weights
class DropoutRNNCellMixin (*args, **kwargs)
-
Object that holds dropout-related fields for an RNN cell.
This class is not a standalone RNN cell. It is meant to be used with an RNN cell via multiple inheritance. Any cell that mixes in this class should have the following fields:
dropout
- a float within the range [0, 1). The fraction of the input tensor to drop out.
recurrent_dropout
- a float within the range [0, 1). The fraction of the recurrent state to drop out.
This object will create and cache dropout masks, and reuse them for the incoming data, so that the same mask is used for every batch input.
Expand source code
class DropoutRNNCellMixin(object): """Object that hold dropout related fields for RNN Cell. This class is not a standalone RNN cell. It suppose to be used with a RNN cell by multiple inheritance. Any cell that mix with class should have following fields: dropout: a float number within range [0, 1). The ratio that the input tensor need to dropout. recurrent_dropout: a float number within range [0, 1). The ratio that the recurrent state weights need to dropout. This object will create and cache created dropout masks, and reuse them for the incoming data, so that the same mask is used for every batch input. """ def __init__(self, *args, **kwargs): self._create_non_trackable_mask_cache() super(DropoutRNNCellMixin, self).__init__(*args, **kwargs) @tf.__internal__.tracking.no_automatic_dependency_tracking def _create_non_trackable_mask_cache(self): """Create the cache for dropout and recurrent dropout mask. Note that the following two masks will be used in "graph function" mode, e.g. these masks are symbolic tensors. In eager mode, the `eager_*_mask` tensors will be generated differently than in the "graph function" case, and they will be cached. Also note that in graph mode, we still cache those masks only because the RNN could be created with `unroll=True`. In that case, the `cell.call()` function will be invoked multiple times, and we want to ensure same mask is used every time. Also the caches are created without tracking. Since they are not picklable by python when deepcopy, we don't want `layer._obj_reference_counts_dict` to track it by default. """ self._dropout_mask_cache = backend.ContextValueCache( self._create_dropout_mask) self._recurrent_dropout_mask_cache = backend.ContextValueCache( self._create_recurrent_dropout_mask) def reset_dropout_mask(self): """Reset the cached dropout masks if any. This is important for the RNN layer to invoke this in it `call()` method so that the cached mask is cleared before calling the `cell.call()`. The mask should be cached across the timestep within the same batch, but shouldn't be cached between batches. Otherwise it will introduce unreasonable bias against certain index of data within the batch. """ self._dropout_mask_cache.clear() def reset_recurrent_dropout_mask(self): """Reset the cached recurrent dropout masks if any. This is important for the RNN layer to invoke this in it call() method so that the cached mask is cleared before calling the cell.call(). The mask should be cached across the timestep within the same batch, but shouldn't be cached between batches. Otherwise it will introduce unreasonable bias against certain index of data within the batch. """ self._recurrent_dropout_mask_cache.clear() def _create_dropout_mask(self, inputs, training, count=1): return _generate_dropout_mask( tf.ones_like(inputs), self.dropout, training=training, count=count) def _create_recurrent_dropout_mask(self, inputs, training, count=1): return _generate_dropout_mask( tf.ones_like(inputs), self.recurrent_dropout, training=training, count=count) def get_dropout_mask_for_cell(self, inputs, training, count=1): """Get the dropout mask for RNN cell's input. It will create mask based on context if there isn't any existing cached mask. If a new mask is generated, it will update the cache in the cell. Args: inputs: The input tensor whose shape will be used to generate dropout mask. training: Boolean tensor, whether its in training mode, dropout will be ignored in non-training mode. count: Int, how many dropout mask will be generated. 
It is useful for cell that has internal weights fused together. Returns: List of mask tensor, generated or cached mask based on context. """ if self.dropout == 0: return None init_kwargs = dict(inputs=inputs, training=training, count=count) return self._dropout_mask_cache.setdefault(kwargs=init_kwargs) def get_recurrent_dropout_mask_for_cell(self, inputs, training, count=1): """Get the recurrent dropout mask for RNN cell. It will create mask based on context if there isn't any existing cached mask. If a new mask is generated, it will update the cache in the cell. Args: inputs: The input tensor whose shape will be used to generate dropout mask. training: Boolean tensor, whether its in training mode, dropout will be ignored in non-training mode. count: Int, how many dropout mask will be generated. It is useful for cell that has internal weights fused together. Returns: List of mask tensor, generated or cached mask based on context. """ if self.recurrent_dropout == 0: return None init_kwargs = dict(inputs=inputs, training=training, count=count) return self._recurrent_dropout_mask_cache.setdefault(kwargs=init_kwargs) def __getstate__(self): # Used for deepcopy. The caching can't be pickled by python, since it will # contain tensor and graph. state = super(DropoutRNNCellMixin, self).__getstate__() state.pop('_dropout_mask_cache', None) state.pop('_recurrent_dropout_mask_cache', None) return state def __setstate__(self, state): state['_dropout_mask_cache'] = backend.ContextValueCache( self._create_dropout_mask) state['_recurrent_dropout_mask_cache'] = backend.ContextValueCache( self._create_recurrent_dropout_mask) super(DropoutRNNCellMixin, self).__setstate__(state)
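To illustrate how the mixin is meant to be combined with a cell, here is a hypothetical sketch (the `DroppyCell` class below is not part of this module): the cell mixes in `DropoutRNNCellMixin`, exposes the expected `dropout`/`recurrent_dropout` fields, and consumes the cached mask in its `call`:
```python
import tensorflow as tf
from keras import backend
from keras.engine.base_layer import Layer
from keras.layers.recurrent import DropoutRNNCellMixin


class DroppyCell(DropoutRNNCellMixin, Layer):
  """Hypothetical minimal cell that applies input dropout via the mixin."""

  def __init__(self, units, dropout=0.2, recurrent_dropout=0., **kwargs):
    super(DroppyCell, self).__init__(**kwargs)
    self.units = units
    # Fields the mixin expects to find on the cell.
    self.dropout = dropout
    self.recurrent_dropout = recurrent_dropout
    self.state_size = units

  def build(self, input_shape):
    self.kernel = self.add_weight(
        shape=(input_shape[-1], self.units), name='kernel')
    self.recurrent_kernel = self.add_weight(
        shape=(self.units, self.units), name='recurrent_kernel')
    self.built = True

  def call(self, inputs, states, training=None):
    prev_output = states[0] if tf.nest.is_nested(states) else states
    # The same cached mask is reused for every timestep of the current batch.
    dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=1)
    if dp_mask is not None:
      inputs = inputs * dp_mask[0]
    output = (backend.dot(inputs, self.kernel) +
              backend.dot(prev_output, self.recurrent_kernel))
    return output, [output]
```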
Subclasses
Methods
def get_dropout_mask_for_cell(self, inputs, training, count=1)
-
Get the dropout mask for RNN cell's input.
It will create a mask based on context if there is no existing cached mask. If a new mask is generated, it will update the cache in the cell.
Args
inputs
- The input tensor whose shape will be used to generate the dropout mask.
training
- Boolean tensor, whether it is in training mode; dropout will be ignored in non-training mode.
count
- Int, how many dropout masks will be generated. Useful for cells that have internal weights fused together.
Returns
List of mask tensors, generated or cached based on context.
Expand source code
def get_dropout_mask_for_cell(self, inputs, training, count=1):
  """Get the dropout mask for RNN cell's input.

  It will create mask based on context if there isn't any existing cached
  mask. If a new mask is generated, it will update the cache in the cell.

  Args:
    inputs: The input tensor whose shape will be used to generate dropout
      mask.
    training: Boolean tensor, whether its in training mode, dropout will be
      ignored in non-training mode.
    count: Int, how many dropout mask will be generated. It is useful for
      cell that has internal weights fused together.

  Returns:
    List of mask tensor, generated or cached mask based on context.
  """
  if self.dropout == 0:
    return None
  init_kwargs = dict(inputs=inputs, training=training, count=count)
  return self._dropout_mask_cache.setdefault(kwargs=init_kwargs)
def get_recurrent_dropout_mask_for_cell(self, inputs, training, count=1)
-
Get the recurrent dropout mask for RNN cell.
It will create a mask based on context if there is no existing cached mask. If a new mask is generated, it will update the cache in the cell.
Args
inputs
- The input tensor whose shape will be used to generate the dropout mask.
training
- Boolean tensor, whether it is in training mode; dropout will be ignored in non-training mode.
count
- Int, how many dropout masks will be generated. Useful for cells that have internal weights fused together.
Returns
List of mask tensors, generated or cached based on context.
Expand source code
def get_recurrent_dropout_mask_for_cell(self, inputs, training, count=1):
  """Get the recurrent dropout mask for RNN cell.

  It will create mask based on context if there isn't any existing cached
  mask. If a new mask is generated, it will update the cache in the cell.

  Args:
    inputs: The input tensor whose shape will be used to generate dropout
      mask.
    training: Boolean tensor, whether its in training mode, dropout will be
      ignored in non-training mode.
    count: Int, how many dropout mask will be generated. It is useful for
      cell that has internal weights fused together.

  Returns:
    List of mask tensor, generated or cached mask based on context.
  """
  if self.recurrent_dropout == 0:
    return None
  init_kwargs = dict(inputs=inputs, training=training, count=count)
  return self._recurrent_dropout_mask_cache.setdefault(kwargs=init_kwargs)
def reset_dropout_mask(self)
-
Reset the cached dropout masks if any.
It is important for the RNN layer to invoke this in its `call()` method so that the cached mask is cleared before calling `cell.call()`. The mask should be cached across timesteps within the same batch, but shouldn't be cached between batches. Otherwise it will introduce an unreasonable bias against certain indices of data within the batch.
Expand source code
def reset_dropout_mask(self):
  """Reset the cached dropout masks if any.

  This is important for the RNN layer to invoke this in it `call()` method
  so that the cached mask is cleared before calling the `cell.call()`. The
  mask should be cached across the timestep within the same batch, but
  shouldn't be cached between batches. Otherwise it will introduce
  unreasonable bias against certain index of data within the batch.
  """
  self._dropout_mask_cache.clear()
def reset_recurrent_dropout_mask(self)
-
Reset the cached recurrent dropout masks if any.
It is important for the RNN layer to invoke this in its `call()` method so that the cached mask is cleared before calling `cell.call()`. The mask should be cached across timesteps within the same batch, but shouldn't be cached between batches. Otherwise it will introduce an unreasonable bias against certain indices of data within the batch.
Expand source code
def reset_recurrent_dropout_mask(self):
  """Reset the cached recurrent dropout masks if any.

  This is important for the RNN layer to invoke this in it call() method
  so that the cached mask is cleared before calling the cell.call(). The
  mask should be cached across the timestep within the same batch, but
  shouldn't be cached between batches. Otherwise it will introduce
  unreasonable bias against certain index of data within the batch.
  """
  self._recurrent_dropout_mask_cache.clear()
class GRU (units, activation='tanh', recurrent_activation='hard_sigmoid', use_bias=True, kernel_initializer='glorot_uniform', recurrent_initializer='orthogonal', bias_initializer='zeros', kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, recurrent_constraint=None, bias_constraint=None, dropout=0.0, recurrent_dropout=0.0, return_sequences=False, return_state=False, go_backwards=False, stateful=False, unroll=False, reset_after=False, **kwargs)
-
Gated Recurrent Unit - Cho et al. 2014.
There are two variants. The default one is based on 1406.1078v3 and applies the reset gate to the hidden state before the matrix multiplication. The other one is based on the original 1406.1078v1 and has the order reversed.
The second variant is compatible with CuDNNGRU (GPU-only) and allows inference on CPU. Thus it has separate biases for `kernel` and `recurrent_kernel`. Use `reset_after=True` and `recurrent_activation='sigmoid'`.
Args
units
- Positive integer, dimensionality of the output space.
activation
- Activation function to use. Default: hyperbolic tangent (`tanh`). If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`).
recurrent_activation
- Activation function to use for the recurrent step. Default: hard sigmoid (`hard_sigmoid`). If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`).
use_bias
- Boolean, whether the layer uses a bias vector.
kernel_initializer
- Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs.
recurrent_initializer
- Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state.
bias_initializer
- Initializer for the bias vector.
kernel_regularizer
- Regularizer function applied to the `kernel` weights matrix.
recurrent_regularizer
- Regularizer function applied to the `recurrent_kernel` weights matrix.
bias_regularizer
- Regularizer function applied to the bias vector.
activity_regularizer
- Regularizer function applied to the output of the layer (its "activation").
kernel_constraint
- Constraint function applied to the `kernel` weights matrix.
recurrent_constraint
- Constraint function applied to the `recurrent_kernel` weights matrix.
bias_constraint
- Constraint function applied to the bias vector.
dropout
- Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs.
recurrent_dropout
- Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state.
return_sequences
- Boolean. Whether to return the last output in the output sequence, or the full sequence.
return_state
- Boolean. Whether to return the last state in addition to the output.
go_backwards
- Boolean (default False). If True, process the input sequence backwards and return the reversed sequence.
stateful
- Boolean (default False). If True, the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch.
unroll
- Boolean (default False). If True, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed up an RNN, although it tends to be more memory-intensive. Unrolling is only suitable for short sequences.
time_major
- The shape format of the `inputs` and `outputs` tensors. If True, the inputs and outputs will be in shape `(timesteps, batch, ...)`, whereas in the False case, it will be `(batch, timesteps, ...)`. Using `time_major = True` is a bit more efficient because it avoids transposes at the beginning and end of the RNN calculation. However, most TensorFlow data is batch-major, so by default this function accepts input and emits output in batch-major form.
reset_after
- GRU convention (whether to apply the reset gate after or before the matrix multiplication). False = "before" (default), True = "after" (CuDNN compatible).
Call arguments: inputs: A 3D tensor. mask: Binary tensor of shape `(samples, timesteps)` indicating whether a given timestep should be masked. An individual `True` entry indicates that the corresponding timestep should be utilized, while a `False` entry indicates that the corresponding timestep should be ignored. training: Python boolean indicating whether the layer should behave in training mode or in inference mode. This argument is passed to the cell when calling it. This is only relevant if `dropout` or `recurrent_dropout` is used. initial_state: List of initial state tensors to be passed to the first call of the cell.
Expand source code
class GRU(RNN): """Gated Recurrent Unit - Cho et al. 2014. There are two variants. The default one is based on 1406.1078v3 and has reset gate applied to hidden state before matrix multiplication. The other one is based on original 1406.1078v1 and has the order reversed. The second variant is compatible with CuDNNGRU (GPU-only) and allows inference on CPU. Thus it has separate biases for `kernel` and `recurrent_kernel`. Use `'reset_after'=True` and `recurrent_activation='sigmoid'`. Args: units: Positive integer, dimensionality of the output space. activation: Activation function to use. Default: hyperbolic tangent (`tanh`). If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). recurrent_activation: Activation function to use for the recurrent step. Default: hard sigmoid (`hard_sigmoid`). If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. recurrent_initializer: Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state. bias_initializer: Initializer for the bias vector. kernel_regularizer: Regularizer function applied to the `kernel` weights matrix. recurrent_regularizer: Regularizer function applied to the `recurrent_kernel` weights matrix. bias_regularizer: Regularizer function applied to the bias vector. activity_regularizer: Regularizer function applied to the output of the layer (its "activation").. kernel_constraint: Constraint function applied to the `kernel` weights matrix. recurrent_constraint: Constraint function applied to the `recurrent_kernel` weights matrix. bias_constraint: Constraint function applied to the bias vector. dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state. return_sequences: Boolean. Whether to return the last output in the output sequence, or the full sequence. return_state: Boolean. Whether to return the last state in addition to the output. go_backwards: Boolean (default False). If True, process the input sequence backwards and return the reversed sequence. stateful: Boolean (default False). If True, the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch. unroll: Boolean (default False). If True, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed-up a RNN, although it tends to be more memory-intensive. Unrolling is only suitable for short sequences. time_major: The shape format of the `inputs` and `outputs` tensors. If True, the inputs and outputs will be in shape `(timesteps, batch, ...)`, whereas in the False case, it will be `(batch, timesteps, ...)`. Using `time_major = True` is a bit more efficient because it avoids transposes at the beginning and end of the RNN calculation. However, most TensorFlow data is batch-major, so by default this function accepts input and emits output in batch-major form. reset_after: GRU convention (whether to apply reset gate after or before matrix multiplication). False = "before" (default), True = "after" (CuDNN compatible). Call arguments: inputs: A 3D tensor. 
mask: Binary tensor of shape `(samples, timesteps)` indicating whether a given timestep should be masked. An individual `True` entry indicates that the corresponding timestep should be utilized, while a `False` entry indicates that the corresponding timestep should be ignored. training: Python boolean indicating whether the layer should behave in training mode or in inference mode. This argument is passed to the cell when calling it. This is only relevant if `dropout` or `recurrent_dropout` is used. initial_state: List of initial state tensors to be passed to the first call of the cell. """ def __init__(self, units, activation='tanh', recurrent_activation='hard_sigmoid', use_bias=True, kernel_initializer='glorot_uniform', recurrent_initializer='orthogonal', bias_initializer='zeros', kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, recurrent_constraint=None, bias_constraint=None, dropout=0., recurrent_dropout=0., return_sequences=False, return_state=False, go_backwards=False, stateful=False, unroll=False, reset_after=False, **kwargs): implementation = kwargs.pop('implementation', 1) if implementation == 0: logging.warning('`implementation=0` has been deprecated, ' 'and now defaults to `implementation=1`.' 'Please update your layer call.') if 'enable_caching_device' in kwargs: cell_kwargs = {'enable_caching_device': kwargs.pop('enable_caching_device')} else: cell_kwargs = {} cell = GRUCell( units, activation=activation, recurrent_activation=recurrent_activation, use_bias=use_bias, kernel_initializer=kernel_initializer, recurrent_initializer=recurrent_initializer, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, recurrent_regularizer=recurrent_regularizer, bias_regularizer=bias_regularizer, kernel_constraint=kernel_constraint, recurrent_constraint=recurrent_constraint, bias_constraint=bias_constraint, dropout=dropout, recurrent_dropout=recurrent_dropout, implementation=implementation, reset_after=reset_after, dtype=kwargs.get('dtype'), trainable=kwargs.get('trainable', True), **cell_kwargs) super(GRU, self).__init__( cell, return_sequences=return_sequences, return_state=return_state, go_backwards=go_backwards, stateful=stateful, unroll=unroll, **kwargs) self.activity_regularizer = regularizers.get(activity_regularizer) self.input_spec = [InputSpec(ndim=3)] def call(self, inputs, mask=None, training=None, initial_state=None): return super(GRU, self).call( inputs, mask=mask, training=training, initial_state=initial_state) @property def units(self): return self.cell.units @property def activation(self): return self.cell.activation @property def recurrent_activation(self): return self.cell.recurrent_activation @property def use_bias(self): return self.cell.use_bias @property def kernel_initializer(self): return self.cell.kernel_initializer @property def recurrent_initializer(self): return self.cell.recurrent_initializer @property def bias_initializer(self): return self.cell.bias_initializer @property def kernel_regularizer(self): return self.cell.kernel_regularizer @property def recurrent_regularizer(self): return self.cell.recurrent_regularizer @property def bias_regularizer(self): return self.cell.bias_regularizer @property def kernel_constraint(self): return self.cell.kernel_constraint @property def recurrent_constraint(self): return self.cell.recurrent_constraint @property def bias_constraint(self): return self.cell.bias_constraint @property def dropout(self): return self.cell.dropout 
@property def recurrent_dropout(self): return self.cell.recurrent_dropout @property def implementation(self): return self.cell.implementation @property def reset_after(self): return self.cell.reset_after def get_config(self): config = { 'units': self.units, 'activation': activations.serialize(self.activation), 'recurrent_activation': activations.serialize(self.recurrent_activation), 'use_bias': self.use_bias, 'kernel_initializer': initializers.serialize(self.kernel_initializer), 'recurrent_initializer': initializers.serialize(self.recurrent_initializer), 'bias_initializer': initializers.serialize(self.bias_initializer), 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer), 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 'activity_regularizer': regularizers.serialize(self.activity_regularizer), 'kernel_constraint': constraints.serialize(self.kernel_constraint), 'recurrent_constraint': constraints.serialize(self.recurrent_constraint), 'bias_constraint': constraints.serialize(self.bias_constraint), 'dropout': self.dropout, 'recurrent_dropout': self.recurrent_dropout, 'implementation': self.implementation, 'reset_after': self.reset_after } config.update(_config_for_enable_caching_device(self.cell)) base_config = super(GRU, self).get_config() del base_config['cell'] return dict(list(base_config.items()) + list(config.items())) @classmethod def from_config(cls, config): if 'implementation' in config and config['implementation'] == 0: config['implementation'] = 1 return cls(**config)
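A short usage sketch with illustrative shapes (passing `reset_after=True` and `recurrent_activation='sigmoid'` selects the CuDNN-compatible variant described above):
```python
import tensorflow as tf

# Illustrative input: batch of 32 sequences, 10 timesteps, 8 features each.
inputs = tf.random.normal([32, 10, 8])

gru = tf.keras.layers.GRU(4, return_sequences=True, return_state=True,
                          reset_after=True, recurrent_activation='sigmoid')
whole_sequence_output, final_state = gru(inputs)
print(whole_sequence_output.shape)  # (32, 10, 4)
print(final_state.shape)            # (32, 4)
```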
Ancestors
- RNN
- Layer
- tensorflow.python.module.module.Module
- tensorflow.python.training.tracking.tracking.AutoTrackable
- tensorflow.python.training.tracking.base.Trackable
- LayerVersionSelector
Subclasses
Instance variables
var activation
-
Expand source code
@property def activation(self): return self.cell.activation
var bias_constraint
-
Expand source code
@property def bias_constraint(self): return self.cell.bias_constraint
var bias_initializer
-
Expand source code
@property def bias_initializer(self): return self.cell.bias_initializer
var bias_regularizer
-
Expand source code
@property def bias_regularizer(self): return self.cell.bias_regularizer
var dropout
-
Expand source code
@property def dropout(self): return self.cell.dropout
var implementation
-
Expand source code
@property def implementation(self): return self.cell.implementation
var kernel_constraint
-
Expand source code
@property def kernel_constraint(self): return self.cell.kernel_constraint
var kernel_initializer
-
Expand source code
@property def kernel_initializer(self): return self.cell.kernel_initializer
var kernel_regularizer
-
Expand source code
@property def kernel_regularizer(self): return self.cell.kernel_regularizer
var recurrent_activation
-
Expand source code
@property def recurrent_activation(self): return self.cell.recurrent_activation
var recurrent_constraint
-
Expand source code
@property def recurrent_constraint(self): return self.cell.recurrent_constraint
var recurrent_dropout
-
Expand source code
@property def recurrent_dropout(self): return self.cell.recurrent_dropout
var recurrent_initializer
-
Expand source code
@property def recurrent_initializer(self): return self.cell.recurrent_initializer
var recurrent_regularizer
-
Expand source code
@property def recurrent_regularizer(self): return self.cell.recurrent_regularizer
var reset_after
-
Expand source code
@property def reset_after(self): return self.cell.reset_after
var units
-
Expand source code
@property def units(self): return self.cell.units
var use_bias
-
Expand source code
@property def use_bias(self): return self.cell.use_bias
Inherited members
RNN
:activity_regularizer
add_loss
add_metric
add_update
add_variable
add_weight
apply
build
call
compute_dtype
compute_mask
compute_output_shape
compute_output_signature
count_params
dtype
dtype_policy
dynamic
finalize_state
from_config
get_config
get_input_at
get_input_mask_at
get_input_shape_at
get_losses_for
get_output_at
get_output_mask_at
get_output_shape_at
get_updates_for
get_weights
inbound_nodes
input
input_mask
input_shape
input_spec
losses
metrics
name
non_trainable_variables
non_trainable_weights
outbound_nodes
output
output_mask
output_shape
reset_states
set_weights
supports_masking
trainable_variables
trainable_weights
variable_dtype
variables
weights
class GRUCell (units, activation='tanh', recurrent_activation='hard_sigmoid', use_bias=True, kernel_initializer='glorot_uniform', recurrent_initializer='orthogonal', bias_initializer='zeros', kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, kernel_constraint=None, recurrent_constraint=None, bias_constraint=None, dropout=0.0, recurrent_dropout=0.0, reset_after=False, **kwargs)
-
Cell class for the GRU layer.
Args
units
- Positive integer, dimensionality of the output space.
activation
- Activation function to use. Default: hyperbolic tangent (`tanh`). If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`).
recurrent_activation
- Activation function to use for the recurrent step. Default: hard sigmoid (`hard_sigmoid`). If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`).
use_bias
- Boolean, whether the layer uses a bias vector.
kernel_initializer
- Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs.
recurrent_initializer
- Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state.
bias_initializer
- Initializer for the bias vector.
kernel_regularizer
- Regularizer function applied to the `kernel` weights matrix.
recurrent_regularizer
- Regularizer function applied to the `recurrent_kernel` weights matrix.
bias_regularizer
- Regularizer function applied to the bias vector.
kernel_constraint
- Constraint function applied to the `kernel` weights matrix.
recurrent_constraint
- Constraint function applied to the `recurrent_kernel` weights matrix.
bias_constraint
- Constraint function applied to the bias vector.
dropout
- Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs.
recurrent_dropout
- Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state.
reset_after
- GRU convention (whether to apply the reset gate after or before the matrix multiplication). False = "before" (default), True = "after" (CuDNN compatible).
Call arguments: inputs: A 2D tensor. states: List of state tensors corresponding to the previous timestep. training: Python boolean indicating whether the layer should behave in training mode or in inference mode. Only relevant when `dropout` or `recurrent_dropout` is used.
Expand source code
class GRUCell(DropoutRNNCellMixin, Layer): """Cell class for the GRU layer. Args: units: Positive integer, dimensionality of the output space. activation: Activation function to use. Default: hyperbolic tangent (`tanh`). If you pass None, no activation is applied (ie. "linear" activation: `a(x) = x`). recurrent_activation: Activation function to use for the recurrent step. Default: hard sigmoid (`hard_sigmoid`). If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. recurrent_initializer: Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state. bias_initializer: Initializer for the bias vector. kernel_regularizer: Regularizer function applied to the `kernel` weights matrix. recurrent_regularizer: Regularizer function applied to the `recurrent_kernel` weights matrix. bias_regularizer: Regularizer function applied to the bias vector. kernel_constraint: Constraint function applied to the `kernel` weights matrix. recurrent_constraint: Constraint function applied to the `recurrent_kernel` weights matrix. bias_constraint: Constraint function applied to the bias vector. dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state. reset_after: GRU convention (whether to apply reset gate after or before matrix multiplication). False = "before" (default), True = "after" (CuDNN compatible). Call arguments: inputs: A 2D tensor. states: List of state tensors corresponding to the previous timestep. training: Python boolean indicating whether the layer should behave in training mode or in inference mode. Only relevant when `dropout` or `recurrent_dropout` is used. """ def __init__(self, units, activation='tanh', recurrent_activation='hard_sigmoid', use_bias=True, kernel_initializer='glorot_uniform', recurrent_initializer='orthogonal', bias_initializer='zeros', kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, kernel_constraint=None, recurrent_constraint=None, bias_constraint=None, dropout=0., recurrent_dropout=0., reset_after=False, **kwargs): if units < 0: raise ValueError(f'Received an invalid value for units, expected ' f'a positive integer, got {units}.') # By default use cached variable under v2 mode, see b/143699808. 
if tf.compat.v1.executing_eagerly_outside_functions(): self._enable_caching_device = kwargs.pop('enable_caching_device', True) else: self._enable_caching_device = kwargs.pop('enable_caching_device', False) super(GRUCell, self).__init__(**kwargs) self.units = units self.activation = activations.get(activation) self.recurrent_activation = activations.get(recurrent_activation) self.use_bias = use_bias self.kernel_initializer = initializers.get(kernel_initializer) self.recurrent_initializer = initializers.get(recurrent_initializer) self.bias_initializer = initializers.get(bias_initializer) self.kernel_regularizer = regularizers.get(kernel_regularizer) self.recurrent_regularizer = regularizers.get(recurrent_regularizer) self.bias_regularizer = regularizers.get(bias_regularizer) self.kernel_constraint = constraints.get(kernel_constraint) self.recurrent_constraint = constraints.get(recurrent_constraint) self.bias_constraint = constraints.get(bias_constraint) self.dropout = min(1., max(0., dropout)) self.recurrent_dropout = min(1., max(0., recurrent_dropout)) implementation = kwargs.pop('implementation', 1) if self.recurrent_dropout != 0 and implementation != 1: logging.debug(RECURRENT_DROPOUT_WARNING_MSG) self.implementation = 1 else: self.implementation = implementation self.reset_after = reset_after self.state_size = self.units self.output_size = self.units @tf_utils.shape_type_conversion def build(self, input_shape): input_dim = input_shape[-1] default_caching_device = _caching_device(self) self.kernel = self.add_weight( shape=(input_dim, self.units * 3), name='kernel', initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, constraint=self.kernel_constraint, caching_device=default_caching_device) self.recurrent_kernel = self.add_weight( shape=(self.units, self.units * 3), name='recurrent_kernel', initializer=self.recurrent_initializer, regularizer=self.recurrent_regularizer, constraint=self.recurrent_constraint, caching_device=default_caching_device) if self.use_bias: if not self.reset_after: bias_shape = (3 * self.units,) else: # separate biases for input and recurrent kernels # Note: the shape is intentionally different from CuDNNGRU biases # `(2 * 3 * self.units,)`, so that we can distinguish the classes # when loading and converting saved weights. bias_shape = (2, 3 * self.units) self.bias = self.add_weight(shape=bias_shape, name='bias', initializer=self.bias_initializer, regularizer=self.bias_regularizer, constraint=self.bias_constraint, caching_device=default_caching_device) else: self.bias = None self.built = True def call(self, inputs, states, training=None): h_tm1 = states[0] if tf.nest.is_nested(states) else states # previous memory dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=3) rec_dp_mask = self.get_recurrent_dropout_mask_for_cell( h_tm1, training, count=3) if self.use_bias: if not self.reset_after: input_bias, recurrent_bias = self.bias, None else: input_bias, recurrent_bias = tf.unstack(self.bias) if self.implementation == 1: if 0. 
< self.dropout < 1.: inputs_z = inputs * dp_mask[0] inputs_r = inputs * dp_mask[1] inputs_h = inputs * dp_mask[2] else: inputs_z = inputs inputs_r = inputs inputs_h = inputs x_z = backend.dot(inputs_z, self.kernel[:, :self.units]) x_r = backend.dot(inputs_r, self.kernel[:, self.units:self.units * 2]) x_h = backend.dot(inputs_h, self.kernel[:, self.units * 2:]) if self.use_bias: x_z = backend.bias_add(x_z, input_bias[:self.units]) x_r = backend.bias_add(x_r, input_bias[self.units: self.units * 2]) x_h = backend.bias_add(x_h, input_bias[self.units * 2:]) if 0. < self.recurrent_dropout < 1.: h_tm1_z = h_tm1 * rec_dp_mask[0] h_tm1_r = h_tm1 * rec_dp_mask[1] h_tm1_h = h_tm1 * rec_dp_mask[2] else: h_tm1_z = h_tm1 h_tm1_r = h_tm1 h_tm1_h = h_tm1 recurrent_z = backend.dot(h_tm1_z, self.recurrent_kernel[:, :self.units]) recurrent_r = backend.dot( h_tm1_r, self.recurrent_kernel[:, self.units:self.units * 2]) if self.reset_after and self.use_bias: recurrent_z = backend.bias_add(recurrent_z, recurrent_bias[:self.units]) recurrent_r = backend.bias_add( recurrent_r, recurrent_bias[self.units:self.units * 2]) z = self.recurrent_activation(x_z + recurrent_z) r = self.recurrent_activation(x_r + recurrent_r) # reset gate applied after/before matrix multiplication if self.reset_after: recurrent_h = backend.dot( h_tm1_h, self.recurrent_kernel[:, self.units * 2:]) if self.use_bias: recurrent_h = backend.bias_add( recurrent_h, recurrent_bias[self.units * 2:]) recurrent_h = r * recurrent_h else: recurrent_h = backend.dot( r * h_tm1_h, self.recurrent_kernel[:, self.units * 2:]) hh = self.activation(x_h + recurrent_h) else: if 0. < self.dropout < 1.: inputs = inputs * dp_mask[0] # inputs projected by all gate matrices at once matrix_x = backend.dot(inputs, self.kernel) if self.use_bias: # biases: bias_z_i, bias_r_i, bias_h_i matrix_x = backend.bias_add(matrix_x, input_bias) x_z, x_r, x_h = tf.split(matrix_x, 3, axis=-1) if self.reset_after: # hidden state projected by all gate matrices at once matrix_inner = backend.dot(h_tm1, self.recurrent_kernel) if self.use_bias: matrix_inner = backend.bias_add(matrix_inner, recurrent_bias) else: # hidden state projected separately for update/reset and new matrix_inner = backend.dot( h_tm1, self.recurrent_kernel[:, :2 * self.units]) recurrent_z, recurrent_r, recurrent_h = tf.split( matrix_inner, [self.units, self.units, -1], axis=-1) z = self.recurrent_activation(x_z + recurrent_z) r = self.recurrent_activation(x_r + recurrent_r) if self.reset_after: recurrent_h = r * recurrent_h else: recurrent_h = backend.dot( r * h_tm1, self.recurrent_kernel[:, 2 * self.units:]) hh = self.activation(x_h + recurrent_h) # previous and candidate state mixed by update gate h = z * h_tm1 + (1 - z) * hh new_state = [h] if tf.nest.is_nested(states) else h return h, new_state def get_config(self): config = { 'units': self.units, 'activation': activations.serialize(self.activation), 'recurrent_activation': activations.serialize(self.recurrent_activation), 'use_bias': self.use_bias, 'kernel_initializer': initializers.serialize(self.kernel_initializer), 'recurrent_initializer': initializers.serialize(self.recurrent_initializer), 'bias_initializer': initializers.serialize(self.bias_initializer), 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer), 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 'kernel_constraint': constraints.serialize(self.kernel_constraint), 'recurrent_constraint': 
constraints.serialize(self.recurrent_constraint), 'bias_constraint': constraints.serialize(self.bias_constraint), 'dropout': self.dropout, 'recurrent_dropout': self.recurrent_dropout, 'implementation': self.implementation, 'reset_after': self.reset_after } config.update(_config_for_enable_caching_device(self)) base_config = super(GRUCell, self).get_config() return dict(list(base_config.items()) + list(config.items())) def get_initial_state(self, inputs=None, batch_size=None, dtype=None): return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
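Since a cell only computes a single timestep, it is usually wrapped in a `tf.keras.layers.RNN` layer to run over a full sequence. A sketch with illustrative shapes:
```python
import tensorflow as tf

inputs = tf.random.normal([32, 10, 8])   # (batch, timesteps, features)

rnn = tf.keras.layers.RNN(tf.keras.layers.GRUCell(4), return_state=True)
output, final_state = rnn(inputs)
print(output.shape)       # (32, 4) - output at the last timestep
print(final_state.shape)  # (32, 4) - final hidden state
```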
Ancestors
- DropoutRNNCellMixin
- Layer
- tensorflow.python.module.module.Module
- tensorflow.python.training.tracking.tracking.AutoTrackable
- tensorflow.python.training.tracking.base.Trackable
- LayerVersionSelector
Subclasses
Methods
def get_initial_state(self, inputs=None, batch_size=None, dtype=None)
-
Expand source code
def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
  return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
Inherited members
DropoutRNNCellMixin
:Layer
:Layer
:activity_regularizer
add_loss
add_metric
add_update
add_variable
add_weight
apply
compute_dtype
compute_mask
compute_output_shape
compute_output_signature
count_params
dtype
dtype_policy
dynamic
finalize_state
from_config
get_input_at
get_input_mask_at
get_input_shape_at
get_losses_for
get_output_at
get_output_mask_at
get_output_shape_at
get_updates_for
get_weights
inbound_nodes
input
input_mask
input_shape
input_spec
losses
metrics
name
non_trainable_variables
non_trainable_weights
outbound_nodes
output
output_mask
output_shape
set_weights
supports_masking
trainable_variables
trainable_weights
variable_dtype
variables
weights
class LSTM (units, activation='tanh', recurrent_activation='hard_sigmoid', use_bias=True, kernel_initializer='glorot_uniform', recurrent_initializer='orthogonal', bias_initializer='zeros', unit_forget_bias=True, kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, recurrent_constraint=None, bias_constraint=None, dropout=0.0, recurrent_dropout=0.0, return_sequences=False, return_state=False, go_backwards=False, stateful=False, unroll=False, **kwargs)
-
Long Short-Term Memory layer - Hochreiter 1997.
Note that this cell is not optimized for performance on GPU. Please use `tf.compat.v1.keras.layers.CuDNNLSTM` for better performance on GPU.
Args
units
- Positive integer, dimensionality of the output space.
activation
- Activation function to use. Default: hyperbolic tangent (`tanh`). If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`).
recurrent_activation
- Activation function to use for the recurrent step. Default: hard sigmoid (`hard_sigmoid`). If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`).
use_bias
- Boolean, whether the layer uses a bias vector.
kernel_initializer
- Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs.
recurrent_initializer
- Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state.
bias_initializer
- Initializer for the bias vector.
unit_forget_bias
- Boolean. If True, add 1 to the bias of the forget gate at initialization. Setting it to true will also force `bias_initializer="zeros"`. This is recommended in Jozefowicz et al., 2015.
kernel_regularizer
- Regularizer function applied to the `kernel` weights matrix.
recurrent_regularizer
- Regularizer function applied to the `recurrent_kernel` weights matrix.
bias_regularizer
- Regularizer function applied to the bias vector.
activity_regularizer
- Regularizer function applied to the output of the layer (its "activation").
kernel_constraint
- Constraint function applied to the `kernel` weights matrix.
recurrent_constraint
- Constraint function applied to the `recurrent_kernel` weights matrix.
bias_constraint
- Constraint function applied to the bias vector.
dropout
- Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs.
recurrent_dropout
- Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state.
return_sequences
- Boolean. Whether to return the last output in the output sequence, or the full sequence.
return_state
- Boolean. Whether to return the last state in addition to the output.
go_backwards
- Boolean (default False). If True, process the input sequence backwards and return the reversed sequence.
stateful
- Boolean (default False). If True, the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch.
unroll
- Boolean (default False). If True, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed up an RNN, although it tends to be more memory-intensive. Unrolling is only suitable for short sequences.
time_major
- The shape format of the `inputs` and `outputs` tensors. If True, the inputs and outputs will be in shape `(timesteps, batch, ...)`, whereas in the False case, it will be `(batch, timesteps, ...)`. Using `time_major = True` is a bit more efficient because it avoids transposes at the beginning and end of the RNN calculation. However, most TensorFlow data is batch-major, so by default this function accepts input and emits output in batch-major form.
Call arguments: inputs: A 3D tensor. mask: Binary tensor of shape `(samples, timesteps)` indicating whether a given timestep should be masked. An individual `True` entry indicates that the corresponding timestep should be utilized, while a `False` entry indicates that the corresponding timestep should be ignored. training: Python boolean indicating whether the layer should behave in training mode or in inference mode. This argument is passed to the cell when calling it. This is only relevant if `dropout` or `recurrent_dropout` is used. initial_state: List of initial state tensors to be passed to the first call of the cell.
Expand source code
class LSTM(RNN): """Long Short-Term Memory layer - Hochreiter 1997. Note that this cell is not optimized for performance on GPU. Please use `tf.compat.v1.keras.layers.CuDNNLSTM` for better performance on GPU. Args: units: Positive integer, dimensionality of the output space. activation: Activation function to use. Default: hyperbolic tangent (`tanh`). If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). recurrent_activation: Activation function to use for the recurrent step. Default: hard sigmoid (`hard_sigmoid`). If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs.. recurrent_initializer: Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state. bias_initializer: Initializer for the bias vector. unit_forget_bias: Boolean. If True, add 1 to the bias of the forget gate at initialization. Setting it to true will also force `bias_initializer="zeros"`. This is recommended in [Jozefowicz et al., 2015]( http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf). kernel_regularizer: Regularizer function applied to the `kernel` weights matrix. recurrent_regularizer: Regularizer function applied to the `recurrent_kernel` weights matrix. bias_regularizer: Regularizer function applied to the bias vector. activity_regularizer: Regularizer function applied to the output of the layer (its "activation"). kernel_constraint: Constraint function applied to the `kernel` weights matrix. recurrent_constraint: Constraint function applied to the `recurrent_kernel` weights matrix. bias_constraint: Constraint function applied to the bias vector. dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state. return_sequences: Boolean. Whether to return the last output. in the output sequence, or the full sequence. return_state: Boolean. Whether to return the last state in addition to the output. go_backwards: Boolean (default False). If True, process the input sequence backwards and return the reversed sequence. stateful: Boolean (default False). If True, the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch. unroll: Boolean (default False). If True, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed-up a RNN, although it tends to be more memory-intensive. Unrolling is only suitable for short sequences. time_major: The shape format of the `inputs` and `outputs` tensors. If True, the inputs and outputs will be in shape `(timesteps, batch, ...)`, whereas in the False case, it will be `(batch, timesteps, ...)`. Using `time_major = True` is a bit more efficient because it avoids transposes at the beginning and end of the RNN calculation. However, most TensorFlow data is batch-major, so by default this function accepts input and emits output in batch-major form. Call arguments: inputs: A 3D tensor. mask: Binary tensor of shape `(samples, timesteps)` indicating whether a given timestep should be masked. 
An individual `True` entry indicates that the corresponding timestep should be utilized, while a `False` entry indicates that the corresponding timestep should be ignored. training: Python boolean indicating whether the layer should behave in training mode or in inference mode. This argument is passed to the cell when calling it. This is only relevant if `dropout` or `recurrent_dropout` is used. initial_state: List of initial state tensors to be passed to the first call of the cell. """ def __init__(self, units, activation='tanh', recurrent_activation='hard_sigmoid', use_bias=True, kernel_initializer='glorot_uniform', recurrent_initializer='orthogonal', bias_initializer='zeros', unit_forget_bias=True, kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, recurrent_constraint=None, bias_constraint=None, dropout=0., recurrent_dropout=0., return_sequences=False, return_state=False, go_backwards=False, stateful=False, unroll=False, **kwargs): implementation = kwargs.pop('implementation', 1) if implementation == 0: logging.warning('`implementation=0` has been deprecated, ' 'and now defaults to `implementation=1`.' 'Please update your layer call.') if 'enable_caching_device' in kwargs: cell_kwargs = {'enable_caching_device': kwargs.pop('enable_caching_device')} else: cell_kwargs = {} cell = LSTMCell( units, activation=activation, recurrent_activation=recurrent_activation, use_bias=use_bias, kernel_initializer=kernel_initializer, recurrent_initializer=recurrent_initializer, unit_forget_bias=unit_forget_bias, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, recurrent_regularizer=recurrent_regularizer, bias_regularizer=bias_regularizer, kernel_constraint=kernel_constraint, recurrent_constraint=recurrent_constraint, bias_constraint=bias_constraint, dropout=dropout, recurrent_dropout=recurrent_dropout, implementation=implementation, dtype=kwargs.get('dtype'), trainable=kwargs.get('trainable', True), **cell_kwargs) super(LSTM, self).__init__( cell, return_sequences=return_sequences, return_state=return_state, go_backwards=go_backwards, stateful=stateful, unroll=unroll, **kwargs) self.activity_regularizer = regularizers.get(activity_regularizer) self.input_spec = [InputSpec(ndim=3)] def call(self, inputs, mask=None, training=None, initial_state=None): return super(LSTM, self).call( inputs, mask=mask, training=training, initial_state=initial_state) @property def units(self): return self.cell.units @property def activation(self): return self.cell.activation @property def recurrent_activation(self): return self.cell.recurrent_activation @property def use_bias(self): return self.cell.use_bias @property def kernel_initializer(self): return self.cell.kernel_initializer @property def recurrent_initializer(self): return self.cell.recurrent_initializer @property def bias_initializer(self): return self.cell.bias_initializer @property def unit_forget_bias(self): return self.cell.unit_forget_bias @property def kernel_regularizer(self): return self.cell.kernel_regularizer @property def recurrent_regularizer(self): return self.cell.recurrent_regularizer @property def bias_regularizer(self): return self.cell.bias_regularizer @property def kernel_constraint(self): return self.cell.kernel_constraint @property def recurrent_constraint(self): return self.cell.recurrent_constraint @property def bias_constraint(self): return self.cell.bias_constraint @property def dropout(self): return self.cell.dropout @property def 
recurrent_dropout(self): return self.cell.recurrent_dropout @property def implementation(self): return self.cell.implementation def get_config(self): config = { 'units': self.units, 'activation': activations.serialize(self.activation), 'recurrent_activation': activations.serialize(self.recurrent_activation), 'use_bias': self.use_bias, 'kernel_initializer': initializers.serialize(self.kernel_initializer), 'recurrent_initializer': initializers.serialize(self.recurrent_initializer), 'bias_initializer': initializers.serialize(self.bias_initializer), 'unit_forget_bias': self.unit_forget_bias, 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer), 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 'activity_regularizer': regularizers.serialize(self.activity_regularizer), 'kernel_constraint': constraints.serialize(self.kernel_constraint), 'recurrent_constraint': constraints.serialize(self.recurrent_constraint), 'bias_constraint': constraints.serialize(self.bias_constraint), 'dropout': self.dropout, 'recurrent_dropout': self.recurrent_dropout, 'implementation': self.implementation } config.update(_config_for_enable_caching_device(self.cell)) base_config = super(LSTM, self).get_config() del base_config['cell'] return dict(list(base_config.items()) + list(config.items())) @classmethod def from_config(cls, config): if 'implementation' in config and config['implementation'] == 0: config['implementation'] = 1 return cls(**config)
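A short usage sketch with illustrative shapes, showing the extra memory and carry states returned when `return_state=True`:
```python
import tensorflow as tf

inputs = tf.random.normal([32, 10, 8])   # (batch, timesteps, features)

lstm = tf.keras.layers.LSTM(4, return_sequences=True, return_state=True)
whole_seq_output, final_memory_state, final_carry_state = lstm(inputs)
print(whole_seq_output.shape)    # (32, 10, 4)
print(final_memory_state.shape)  # (32, 4)
print(final_carry_state.shape)   # (32, 4)
```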
Ancestors
- RNN
- Layer
- tensorflow.python.module.module.Module
- tensorflow.python.training.tracking.tracking.AutoTrackable
- tensorflow.python.training.tracking.base.Trackable
- LayerVersionSelector
Subclasses
Instance variables
var activation
-
Expand source code
@property def activation(self): return self.cell.activation
var bias_constraint
-
Expand source code
@property def bias_constraint(self): return self.cell.bias_constraint
var bias_initializer
-
Expand source code
@property def bias_initializer(self): return self.cell.bias_initializer
var bias_regularizer
-
Expand source code
@property def bias_regularizer(self): return self.cell.bias_regularizer
var dropout
-
Expand source code
@property def dropout(self): return self.cell.dropout
var implementation
-
Expand source code
@property def implementation(self): return self.cell.implementation
var kernel_constraint
-
Expand source code
@property def kernel_constraint(self): return self.cell.kernel_constraint
var kernel_initializer
-
Expand source code
@property def kernel_initializer(self): return self.cell.kernel_initializer
var kernel_regularizer
-
Expand source code
@property def kernel_regularizer(self): return self.cell.kernel_regularizer
var recurrent_activation
-
Expand source code
@property def recurrent_activation(self): return self.cell.recurrent_activation
var recurrent_constraint
-
Expand source code
@property def recurrent_constraint(self): return self.cell.recurrent_constraint
var recurrent_dropout
-
Expand source code
@property def recurrent_dropout(self): return self.cell.recurrent_dropout
var recurrent_initializer
-
Expand source code
@property def recurrent_initializer(self): return self.cell.recurrent_initializer
var recurrent_regularizer
-
Expand source code
@property def recurrent_regularizer(self): return self.cell.recurrent_regularizer
var unit_forget_bias
-
Expand source code
@property def unit_forget_bias(self): return self.cell.unit_forget_bias
var units
-
Expand source code
@property def units(self): return self.cell.units
var use_bias
-
Expand source code
@property def use_bias(self): return self.cell.use_bias
Inherited members
RNN
:activity_regularizer
add_loss
add_metric
add_update
add_variable
add_weight
apply
build
call
compute_dtype
compute_mask
compute_output_shape
compute_output_signature
count_params
dtype
dtype_policy
dynamic
finalize_state
from_config
get_config
get_input_at
get_input_mask_at
get_input_shape_at
get_losses_for
get_output_at
get_output_mask_at
get_output_shape_at
get_updates_for
get_weights
inbound_nodes
input
input_mask
input_shape
input_spec
losses
metrics
name
non_trainable_variables
non_trainable_weights
outbound_nodes
output
output_mask
output_shape
reset_states
set_weights
supports_masking
trainable_variables
trainable_weights
variable_dtype
variables
weights
class LSTMCell (units, activation='tanh', recurrent_activation='hard_sigmoid', use_bias=True, kernel_initializer='glorot_uniform', recurrent_initializer='orthogonal', bias_initializer='zeros', unit_forget_bias=True, kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, kernel_constraint=None, recurrent_constraint=None, bias_constraint=None, dropout=0.0, recurrent_dropout=0.0, **kwargs)
-
Cell class for the LSTM layer.
Args
units
- Positive integer, dimensionality of the output space.
activation
- Activation function to use. Default: hyperbolic tangent (`tanh`). If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`).
recurrent_activation
- Activation function to use for the recurrent step. Default: hard sigmoid (`hard_sigmoid`). If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`).
use_bias
- Boolean, whether the layer uses a bias vector.
kernel_initializer
- Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs.
recurrent_initializer
- Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state.
bias_initializer
- Initializer for the bias vector.
unit_forget_bias
- Boolean. If True, add 1 to the bias of the forget gate at initialization. Setting it to true will also force `bias_initializer="zeros"`. This is recommended in Jozefowicz et al., 2015.
kernel_regularizer
- Regularizer function applied to the `kernel` weights matrix.
recurrent_regularizer
- Regularizer function applied to the `recurrent_kernel` weights matrix.
bias_regularizer
- Regularizer function applied to the bias vector.
kernel_constraint
- Constraint function applied to the `kernel` weights matrix.
recurrent_constraint
- Constraint function applied to the `recurrent_kernel` weights matrix.
bias_constraint
- Constraint function applied to the bias vector.
dropout
- Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs.
recurrent_dropout
- Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state.
Call arguments: inputs: A 2D tensor. states: List of state tensors corresponding to the previous timestep. training: Python boolean indicating whether the layer should behave in training mode or in inference mode. Only relevant when `dropout` or `recurrent_dropout` is used.
Expand source code
class LSTMCell(DropoutRNNCellMixin, Layer): """Cell class for the LSTM layer. Args: units: Positive integer, dimensionality of the output space. activation: Activation function to use. Default: hyperbolic tangent (`tanh`). If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). recurrent_activation: Activation function to use for the recurrent step. Default: hard sigmoid (`hard_sigmoid`). If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. recurrent_initializer: Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state. bias_initializer: Initializer for the bias vector. unit_forget_bias: Boolean. If True, add 1 to the bias of the forget gate at initialization. Setting it to true will also force `bias_initializer="zeros"`. This is recommended in [Jozefowicz et al., 2015]( http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) kernel_regularizer: Regularizer function applied to the `kernel` weights matrix. recurrent_regularizer: Regularizer function applied to the `recurrent_kernel` weights matrix. bias_regularizer: Regularizer function applied to the bias vector. kernel_constraint: Constraint function applied to the `kernel` weights matrix. recurrent_constraint: Constraint function applied to the `recurrent_kernel` weights matrix. bias_constraint: Constraint function applied to the bias vector. dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state. Call arguments: inputs: A 2D tensor. states: List of state tensors corresponding to the previous timestep. training: Python boolean indicating whether the layer should behave in training mode or in inference mode. Only relevant when `dropout` or `recurrent_dropout` is used. """ def __init__(self, units, activation='tanh', recurrent_activation='hard_sigmoid', use_bias=True, kernel_initializer='glorot_uniform', recurrent_initializer='orthogonal', bias_initializer='zeros', unit_forget_bias=True, kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, kernel_constraint=None, recurrent_constraint=None, bias_constraint=None, dropout=0., recurrent_dropout=0., **kwargs): if units < 0: raise ValueError(f'Received an invalid value for units, expected ' f'a positive integer, got {units}.') # By default use cached variable under v2 mode, see b/143699808. 
if tf.compat.v1.executing_eagerly_outside_functions(): self._enable_caching_device = kwargs.pop('enable_caching_device', True) else: self._enable_caching_device = kwargs.pop('enable_caching_device', False) super(LSTMCell, self).__init__(**kwargs) self.units = units self.activation = activations.get(activation) self.recurrent_activation = activations.get(recurrent_activation) self.use_bias = use_bias self.kernel_initializer = initializers.get(kernel_initializer) self.recurrent_initializer = initializers.get(recurrent_initializer) self.bias_initializer = initializers.get(bias_initializer) self.unit_forget_bias = unit_forget_bias self.kernel_regularizer = regularizers.get(kernel_regularizer) self.recurrent_regularizer = regularizers.get(recurrent_regularizer) self.bias_regularizer = regularizers.get(bias_regularizer) self.kernel_constraint = constraints.get(kernel_constraint) self.recurrent_constraint = constraints.get(recurrent_constraint) self.bias_constraint = constraints.get(bias_constraint) self.dropout = min(1., max(0., dropout)) self.recurrent_dropout = min(1., max(0., recurrent_dropout)) implementation = kwargs.pop('implementation', 1) if self.recurrent_dropout != 0 and implementation != 1: logging.debug(RECURRENT_DROPOUT_WARNING_MSG) self.implementation = 1 else: self.implementation = implementation self.state_size = [self.units, self.units] self.output_size = self.units @tf_utils.shape_type_conversion def build(self, input_shape): default_caching_device = _caching_device(self) input_dim = input_shape[-1] self.kernel = self.add_weight( shape=(input_dim, self.units * 4), name='kernel', initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, constraint=self.kernel_constraint, caching_device=default_caching_device) self.recurrent_kernel = self.add_weight( shape=(self.units, self.units * 4), name='recurrent_kernel', initializer=self.recurrent_initializer, regularizer=self.recurrent_regularizer, constraint=self.recurrent_constraint, caching_device=default_caching_device) if self.use_bias: if self.unit_forget_bias: def bias_initializer(_, *args, **kwargs): return backend.concatenate([ self.bias_initializer((self.units,), *args, **kwargs), initializers.get('ones')((self.units,), *args, **kwargs), self.bias_initializer((self.units * 2,), *args, **kwargs), ]) else: bias_initializer = self.bias_initializer self.bias = self.add_weight( shape=(self.units * 4,), name='bias', initializer=bias_initializer, regularizer=self.bias_regularizer, constraint=self.bias_constraint, caching_device=default_caching_device) else: self.bias = None self.built = True def _compute_carry_and_output(self, x, h_tm1, c_tm1): """Computes carry and output using split kernels.""" x_i, x_f, x_c, x_o = x h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1 i = self.recurrent_activation( x_i + backend.dot(h_tm1_i, self.recurrent_kernel[:, :self.units])) f = self.recurrent_activation(x_f + backend.dot( h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2])) c = f * c_tm1 + i * self.activation(x_c + backend.dot( h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3])) o = self.recurrent_activation( x_o + backend.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:])) return c, o def _compute_carry_and_output_fused(self, z, c_tm1): """Computes carry and output using fused kernels.""" z0, z1, z2, z3 = z i = self.recurrent_activation(z0) f = self.recurrent_activation(z1) c = f * c_tm1 + i * self.activation(z2) o = self.recurrent_activation(z3) return c, o def call(self, inputs, states, training=None): 
h_tm1 = states[0] # previous memory state c_tm1 = states[1] # previous carry state dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4) rec_dp_mask = self.get_recurrent_dropout_mask_for_cell( h_tm1, training, count=4) if self.implementation == 1: if 0 < self.dropout < 1.: inputs_i = inputs * dp_mask[0] inputs_f = inputs * dp_mask[1] inputs_c = inputs * dp_mask[2] inputs_o = inputs * dp_mask[3] else: inputs_i = inputs inputs_f = inputs inputs_c = inputs inputs_o = inputs k_i, k_f, k_c, k_o = tf.split( self.kernel, num_or_size_splits=4, axis=1) x_i = backend.dot(inputs_i, k_i) x_f = backend.dot(inputs_f, k_f) x_c = backend.dot(inputs_c, k_c) x_o = backend.dot(inputs_o, k_o) if self.use_bias: b_i, b_f, b_c, b_o = tf.split( self.bias, num_or_size_splits=4, axis=0) x_i = backend.bias_add(x_i, b_i) x_f = backend.bias_add(x_f, b_f) x_c = backend.bias_add(x_c, b_c) x_o = backend.bias_add(x_o, b_o) if 0 < self.recurrent_dropout < 1.: h_tm1_i = h_tm1 * rec_dp_mask[0] h_tm1_f = h_tm1 * rec_dp_mask[1] h_tm1_c = h_tm1 * rec_dp_mask[2] h_tm1_o = h_tm1 * rec_dp_mask[3] else: h_tm1_i = h_tm1 h_tm1_f = h_tm1 h_tm1_c = h_tm1 h_tm1_o = h_tm1 x = (x_i, x_f, x_c, x_o) h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o) c, o = self._compute_carry_and_output(x, h_tm1, c_tm1) else: if 0. < self.dropout < 1.: inputs = inputs * dp_mask[0] z = backend.dot(inputs, self.kernel) z += backend.dot(h_tm1, self.recurrent_kernel) if self.use_bias: z = backend.bias_add(z, self.bias) z = tf.split(z, num_or_size_splits=4, axis=1) c, o = self._compute_carry_and_output_fused(z, c_tm1) h = o * self.activation(c) return h, [h, c] def get_config(self): config = { 'units': self.units, 'activation': activations.serialize(self.activation), 'recurrent_activation': activations.serialize(self.recurrent_activation), 'use_bias': self.use_bias, 'kernel_initializer': initializers.serialize(self.kernel_initializer), 'recurrent_initializer': initializers.serialize(self.recurrent_initializer), 'bias_initializer': initializers.serialize(self.bias_initializer), 'unit_forget_bias': self.unit_forget_bias, 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer), 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 'kernel_constraint': constraints.serialize(self.kernel_constraint), 'recurrent_constraint': constraints.serialize(self.recurrent_constraint), 'bias_constraint': constraints.serialize(self.bias_constraint), 'dropout': self.dropout, 'recurrent_dropout': self.recurrent_dropout, 'implementation': self.implementation } config.update(_config_for_enable_caching_device(self)) base_config = super(LSTMCell, self).get_config() return dict(list(base_config.items()) + list(config.items())) def get_initial_state(self, inputs=None, batch_size=None, dtype=None): return list(_generate_zero_filled_state_for_cell( self, inputs, batch_size, dtype))
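A minimal usage sketch for the cell above, wrapping it in the generic RNN layer (the batch size, sequence length, feature count, and unit count are illustrative assumptions, not part of the API):

```python
import tensorflow as tf

# Assumed shapes: batch=32, timesteps=10, features=8; 4 units.
inputs = tf.random.normal([32, 10, 8])

# The cell handles one timestep; RNN loops it over the time dimension.
rnn = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(4))
output = rnn(inputs)  # shape (32, 4)

# return_sequences/return_state expose per-timestep outputs and final states.
rnn = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(4),
                          return_sequences=True,
                          return_state=True)
whole_seq_output, final_h, final_c = rnn(inputs)
# whole_seq_output: (32, 10, 4); final_h and final_c: (32, 4)
```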
Ancestors
- DropoutRNNCellMixin
- Layer
- tensorflow.python.module.module.Module
- tensorflow.python.training.tracking.tracking.AutoTrackable
- tensorflow.python.training.tracking.base.Trackable
- LayerVersionSelector
Methods
def get_initial_state(self, inputs=None, batch_size=None, dtype=None)
-
Expand source code
def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
  return list(_generate_zero_filled_state_for_cell(
      self, inputs, batch_size, dtype))
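A short sketch of what this returns (the batch size and dtype below are arbitrary choices for illustration):

```python
import tensorflow as tf

cell = tf.keras.layers.LSTMCell(4)
# Returns a list of two zero tensors [h0, c0], each of shape (batch_size, units).
h0, c0 = cell.get_initial_state(batch_size=32, dtype=tf.float32)
print(h0.shape, c0.shape)  # (32, 4) (32, 4)
```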
Inherited members
DropoutRNNCellMixin
:
Layer
:activity_regularizer
add_loss
add_metric
add_update
add_variable
add_weight
apply
compute_dtype
compute_mask
compute_output_shape
compute_output_signature
count_params
dtype
dtype_policy
dynamic
finalize_state
from_config
get_input_at
get_input_mask_at
get_input_shape_at
get_losses_for
get_output_at
get_output_mask_at
get_output_shape_at
get_updates_for
get_weights
inbound_nodes
input
input_mask
input_shape
input_spec
losses
metrics
name
non_trainable_variables
non_trainable_weights
outbound_nodes
output
output_mask
output_shape
set_weights
supports_masking
trainable_variables
trainable_weights
variable_dtype
variables
weights
class PeepholeLSTMCell (units, activation='tanh', recurrent_activation='hard_sigmoid', use_bias=True, kernel_initializer='glorot_uniform', recurrent_initializer='orthogonal', bias_initializer='zeros', unit_forget_bias=True, kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, kernel_constraint=None, recurrent_constraint=None, bias_constraint=None, dropout=0.0, recurrent_dropout=0.0, **kwargs)
-
Equivalent to LSTMCell class but adds peephole connections.
Peephole connections allow the gates to utilize the previous internal state as well as the previous hidden state (which is what LSTMCell is limited to). This allows PeepholeLSTMCell to better learn precise timings over LSTMCell.
From Gers et al., 2002 (http://www.jmlr.org/papers/volume3/gers02a/gers02a.pdf):
"We find that LSTM augmented by 'peephole connections' from its internal cells to its multiplicative gates can learn the fine distinction between sequences of spikes spaced either 50 or 49 time steps apart without the help of any short training exemplars."
The peephole implementation is based on Sak et al., 2014 (https://research.google.com/pubs/archive/43905.pdf). Note: as the constructor warning in the source below indicates, tf.keras.experimental.PeepholeLSTMCell is deprecated and will be removed in a future version; tensorflow_addons.rnn.PeepholeLSTMCell is the suggested replacement.
Example:
# Create 2 PeepholeLSTMCells
peephole_lstm_cells = [PeepholeLSTMCell(size) for size in [128, 256]]
# Create a layer composed sequentially of the peephole LSTM cells.
layer = RNN(peephole_lstm_cells)
input = keras.Input((timesteps, input_dim))
output = layer(input)
Expand source code
class PeepholeLSTMCell(LSTMCell): """Equivalent to LSTMCell class but adds peephole connections. Peephole connections allow the gates to utilize the previous internal state as well as the previous hidden state (which is what LSTMCell is limited to). This allows PeepholeLSTMCell to better learn precise timings over LSTMCell. From [Gers et al., 2002]( http://www.jmlr.org/papers/volume3/gers02a/gers02a.pdf): "We find that LSTM augmented by 'peephole connections' from its internal cells to its multiplicative gates can learn the fine distinction between sequences of spikes spaced either 50 or 49 time steps apart without the help of any short training exemplars." The peephole implementation is based on: [Sak et al., 2014](https://research.google.com/pubs/archive/43905.pdf) Example: ```python # Create 2 PeepholeLSTMCells peephole_lstm_cells = [PeepholeLSTMCell(size) for size in [128, 256]] # Create a layer composed sequentially of the peephole LSTM cells. layer = RNN(peephole_lstm_cells) input = keras.Input((timesteps, input_dim)) output = layer(input) ``` """ def __init__(self, units, activation='tanh', recurrent_activation='hard_sigmoid', use_bias=True, kernel_initializer='glorot_uniform', recurrent_initializer='orthogonal', bias_initializer='zeros', unit_forget_bias=True, kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, kernel_constraint=None, recurrent_constraint=None, bias_constraint=None, dropout=0., recurrent_dropout=0., **kwargs): warnings.warn('`tf.keras.experimental.PeepholeLSTMCell` is deprecated ' 'and will be removed in a future version. ' 'Please use tensorflow_addons.rnn.PeepholeLSTMCell ' 'instead.') super(PeepholeLSTMCell, self).__init__( units=units, activation=activation, recurrent_activation=recurrent_activation, use_bias=use_bias, kernel_initializer=kernel_initializer, recurrent_initializer=recurrent_initializer, bias_initializer=bias_initializer, unit_forget_bias=unit_forget_bias, kernel_regularizer=kernel_regularizer, recurrent_regularizer=recurrent_regularizer, bias_regularizer=bias_regularizer, kernel_constraint=kernel_constraint, recurrent_constraint=recurrent_constraint, bias_constraint=bias_constraint, dropout=dropout, recurrent_dropout=recurrent_dropout, implementation=kwargs.pop('implementation', 1), **kwargs) def build(self, input_shape): super(PeepholeLSTMCell, self).build(input_shape) # The following are the weight matrices for the peephole connections. These # are multiplied with the previous internal state during the computation of # carry and output. 
self.input_gate_peephole_weights = self.add_weight( shape=(self.units,), name='input_gate_peephole_weights', initializer=self.kernel_initializer) self.forget_gate_peephole_weights = self.add_weight( shape=(self.units,), name='forget_gate_peephole_weights', initializer=self.kernel_initializer) self.output_gate_peephole_weights = self.add_weight( shape=(self.units,), name='output_gate_peephole_weights', initializer=self.kernel_initializer) def _compute_carry_and_output(self, x, h_tm1, c_tm1): x_i, x_f, x_c, x_o = x h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1 i = self.recurrent_activation( x_i + backend.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]) + self.input_gate_peephole_weights * c_tm1) f = self.recurrent_activation(x_f + backend.dot( h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]) + self.forget_gate_peephole_weights * c_tm1) c = f * c_tm1 + i * self.activation(x_c + backend.dot( h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3])) o = self.recurrent_activation( x_o + backend.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]) + self.output_gate_peephole_weights * c) return c, o def _compute_carry_and_output_fused(self, z, c_tm1): z0, z1, z2, z3 = z i = self.recurrent_activation(z0 + self.input_gate_peephole_weights * c_tm1) f = self.recurrent_activation(z1 + self.forget_gate_peephole_weights * c_tm1) c = f * c_tm1 + i * self.activation(z2) o = self.recurrent_activation(z3 + self.output_gate_peephole_weights * c) return c, o
Ancestors
- LSTMCell
- DropoutRNNCellMixin
- Layer
- tensorflow.python.module.module.Module
- tensorflow.python.training.tracking.tracking.AutoTrackable
- tensorflow.python.training.tracking.base.Trackable
- LayerVersionSelector
Inherited members
LSTMCell
:activity_regularizer
add_loss
add_metric
add_update
add_variable
add_weight
apply
call
compute_dtype
compute_mask
compute_output_shape
compute_output_signature
count_params
dtype
dtype_policy
dynamic
finalize_state
from_config
get_config
get_dropout_mask_for_cell
get_input_at
get_input_mask_at
get_input_shape_at
get_losses_for
get_output_at
get_output_mask_at
get_output_shape_at
get_recurrent_dropout_mask_for_cell
get_updates_for
get_weights
inbound_nodes
input
input_mask
input_shape
input_spec
losses
metrics
name
non_trainable_variables
non_trainable_weights
outbound_nodes
output
output_mask
output_shape
reset_dropout_mask
reset_recurrent_dropout_mask
set_weights
supports_masking
trainable_variables
trainable_weights
variable_dtype
variables
weights
Layer
:
class RNN (cell, return_sequences=False, return_state=False, go_backwards=False, stateful=False, unroll=False, time_major=False, **kwargs)
-
Base class for recurrent layers.
See the Keras RNN API guide (https://www.tensorflow.org/guide/keras/rnn) for details about the usage of the RNN API.
Args
cell
- An RNN cell instance or a list of RNN cell instances. An RNN cell is a class that has:
  - A call(input_at_t, states_at_t) method, returning (output_at_t, states_at_t_plus_1). The call method of the cell can also take the optional argument constants; see the section "Note on passing external constants" below.
  - A state_size attribute. This can be a single integer (single state), in which case it is the size of the recurrent state. This can also be a list/tuple of integers (one size per state). The state_size can also be a TensorShape or a tuple/list of TensorShapes, to represent a high-dimension state.
  - An output_size attribute. This can be a single integer or a TensorShape, which represents the shape of the output. For backward compatibility, if this attribute is not available for the cell, the value will be inferred from the first element of the state_size.
  - A get_initial_state(inputs=None, batch_size=None, dtype=None) method that creates a tensor meant to be fed to call() as the initial state, if the user didn't specify any initial state via other means. The returned initial state should have a shape of [batch_size, cell.state_size]. The cell might choose to create a tensor full of zeros, or full of other values, based on the cell's implementation. inputs is the input tensor to the RNN layer, which should contain the batch size as its shape[0], and also dtype. Note that the shape[0] might be None during graph construction. Either the inputs or the pair of batch_size and dtype are provided. batch_size is a scalar tensor that represents the batch size of the inputs. dtype is a tf.DType that represents the dtype of the inputs. For backward compatibility, if this method is not implemented by the cell, the RNN layer will create a zero-filled tensor with the size of [batch_size, cell.state_size].
  In the case that cell is a list of RNN cell instances, the cells will be stacked on top of each other in the RNN, resulting in an efficient stacked RNN.
return_sequences
- Boolean (default False). Whether to return the last output in the output sequence, or the full sequence.
return_state
- Boolean (default False). Whether to return the last state in addition to the output.
go_backwards
- Boolean (default False). If True, process the input sequence backwards and return the reversed sequence.
stateful
- Boolean (default False). If True, the last state for each sample at index i in a batch will be used as the initial state for the sample of index i in the following batch.
unroll
- Boolean (default False). If True, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed up an RNN, although it tends to be more memory-intensive. Unrolling is only suitable for short sequences.
time_major
- The shape format of the inputs and outputs tensors. If True, the inputs and outputs will be in shape (timesteps, batch, ...), whereas in the False case, it will be (batch, timesteps, ...). Using time_major = True is a bit more efficient because it avoids transposes at the beginning and end of the RNN calculation. However, most TensorFlow data is batch-major, so by default this function accepts input and emits output in batch-major form.
zero_output_for_mask
- Boolean (default False). Whether the output should use zeros for the masked timesteps. Note that this field is only used when return_sequences is True and a mask is provided. It can be useful if you want to reuse the raw output sequence of the RNN without interference from the masked timesteps, e.g. when merging bidirectional RNNs.
Call arguments:
inputs: Input tensor.
mask: Binary tensor of shape [batch_size, timesteps] indicating whether a given timestep should be masked. An individual True entry indicates that the corresponding timestep should be utilized, while a False entry indicates that the corresponding timestep should be ignored.
training: Python boolean indicating whether the layer should behave in training mode or in inference mode. This argument is passed to the cell when calling it. This is for use with cells that use dropout.
initial_state: List of initial state tensors to be passed to the first call of the cell.
constants: List of constant tensors to be passed to the cell at each timestep.
Input shape:
N-D tensor with shape [batch_size, timesteps, ...], or [timesteps, batch_size, ...] when time_major is True.
Output shape:
- If return_state: a list of tensors. The first tensor is the output. The remaining tensors are the last states, each with shape [batch_size, state_size], where state_size could be a high-dimension tensor shape.
- If return_sequences: N-D tensor with shape [batch_size, timesteps, output_size], where output_size could be a high-dimension tensor shape, or [timesteps, batch_size, output_size] when time_major is True.
- Else, N-D tensor with shape [batch_size, output_size], where output_size could be a high-dimension tensor shape.
Masking:
This layer supports masking for input data with a variable number of timesteps. To introduce masks to your data, use a tf.keras.layers.Embedding layer with the mask_zero parameter set to True.
Note on using statefulness in RNNs:
You can set RNN layers to be 'stateful', which means that the states computed for the samples in one batch will be reused as initial states for the samples in the next batch. This assumes a one-to-one mapping between samples in different successive batches.
To enable statefulness:
- Specify stateful=True in the layer constructor.
- Specify a fixed batch size for your model: for a Sequential model, pass batch_input_shape=(...) to the first layer in your model; for a functional model with one or more Input layers, pass batch_shape=(...) to all the first layers in your model. This is the expected shape of your inputs, including the batch size. It should be a tuple of integers, e.g. (32, 10, 100).
- Specify shuffle=False when calling fit().
To reset the states of your model, call .reset_states() on either a specific layer, or on your entire model. (A short stateful sketch follows the examples below.)
Note on specifying the initial state of RNNs:
You can specify the initial state of RNN layers symbolically by calling them with the keyword argument initial_state. The value of initial_state should be a tensor or list of tensors representing the initial state of the RNN layer. (A short sketch follows the source listing below.)
You can specify the initial state of RNN layers numerically by calling reset_states with the keyword argument states. The value of states should be a numpy array or list of numpy arrays representing the initial state of the RNN layer.
Note on passing external constants to RNNs:
You can pass "external" constants to the cell using the constants keyword argument of RNN.__call__ (as well as RNN.call()). This requires that the cell.call method accepts the same keyword argument constants. Such constants can be used to condition the cell transformation on additional static inputs (not changing over time), a.k.a. an attention mechanism.
Examples:
# First, let's define a RNN Cell, as a layer subclass.
class MinimalRNNCell(keras.layers.Layer):

  def __init__(self, units, **kwargs):
    self.units = units
    self.state_size = units
    super(MinimalRNNCell, self).__init__(**kwargs)

  def build(self, input_shape):
    self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                  initializer='uniform',
                                  name='kernel')
    self.recurrent_kernel = self.add_weight(
        shape=(self.units, self.units),
        initializer='uniform',
        name='recurrent_kernel')
    self.built = True

  def call(self, inputs, states):
    prev_output = states[0]
    h = backend.dot(inputs, self.kernel)
    output = h + backend.dot(prev_output, self.recurrent_kernel)
    return output, [output]

# Let's use this cell in a RNN layer:
cell = MinimalRNNCell(32)
x = keras.Input((None, 5))
layer = RNN(cell)
y = layer(x)

# Here's how to use the cell to build a stacked RNN:
cells = [MinimalRNNCell(32), MinimalRNNCell(64)]
x = keras.Input((None, 5))
layer = RNN(cells)
y = layer(x)
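The stateful workflow described in the note above, as a hedged sketch (the batch size, sequence length, and feature count are assumptions chosen for illustration):

```python
import numpy as np
import tensorflow as tf

# Fixed batch size declared up front; shuffle=False so that batch i+1
# continues the sequences started in batch i.
model = tf.keras.Sequential([
    tf.keras.layers.LSTM(16, stateful=True, batch_input_shape=(32, 10, 8)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')

x = np.random.random((64, 10, 8)).astype('float32')  # 2 batches of 32
y = np.random.random((64, 1)).astype('float32')
model.fit(x, y, batch_size=32, shuffle=False, epochs=1)

# Clear the carried-over states before training on unrelated sequences.
model.reset_states()
```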
Expand source code
class RNN(Layer): """Base class for recurrent layers. See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) for details about the usage of RNN API. Args: cell: A RNN cell instance or a list of RNN cell instances. A RNN cell is a class that has: - A `call(input_at_t, states_at_t)` method, returning `(output_at_t, states_at_t_plus_1)`. The call method of the cell can also take the optional argument `constants`, see section "Note on passing external constants" below. - A `state_size` attribute. This can be a single integer (single state) in which case it is the size of the recurrent state. This can also be a list/tuple of integers (one size per state). The `state_size` can also be TensorShape or tuple/list of TensorShape, to represent high dimension state. - A `output_size` attribute. This can be a single integer or a TensorShape, which represent the shape of the output. For backward compatible reason, if this attribute is not available for the cell, the value will be inferred by the first element of the `state_size`. - A `get_initial_state(inputs=None, batch_size=None, dtype=None)` method that creates a tensor meant to be fed to `call()` as the initial state, if the user didn't specify any initial state via other means. The returned initial state should have a shape of [batch_size, cell.state_size]. The cell might choose to create a tensor full of zeros, or full of other values based on the cell's implementation. `inputs` is the input tensor to the RNN layer, which should contain the batch size as its shape[0], and also dtype. Note that the shape[0] might be `None` during the graph construction. Either the `inputs` or the pair of `batch_size` and `dtype` are provided. `batch_size` is a scalar tensor that represents the batch size of the inputs. `dtype` is `tf.DType` that represents the dtype of the inputs. For backward compatibility, if this method is not implemented by the cell, the RNN layer will create a zero filled tensor with the size of [batch_size, cell.state_size]. In the case that `cell` is a list of RNN cell instances, the cells will be stacked on top of each other in the RNN, resulting in an efficient stacked RNN. return_sequences: Boolean (default `False`). Whether to return the last output in the output sequence, or the full sequence. return_state: Boolean (default `False`). Whether to return the last state in addition to the output. go_backwards: Boolean (default `False`). If True, process the input sequence backwards and return the reversed sequence. stateful: Boolean (default `False`). If True, the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch. unroll: Boolean (default `False`). If True, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed-up a RNN, although it tends to be more memory-intensive. Unrolling is only suitable for short sequences. time_major: The shape format of the `inputs` and `outputs` tensors. If True, the inputs and outputs will be in shape `(timesteps, batch, ...)`, whereas in the False case, it will be `(batch, timesteps, ...)`. Using `time_major = True` is a bit more efficient because it avoids transposes at the beginning and end of the RNN calculation. However, most TensorFlow data is batch-major, so by default this function accepts input and emits output in batch-major form. zero_output_for_mask: Boolean (default `False`). Whether the output should use zeros for the masked timesteps. 
Note that this field is only used when `return_sequences` is True and mask is provided. It can useful if you want to reuse the raw output sequence of the RNN without interference from the masked timesteps, eg, merging bidirectional RNNs. Call arguments: inputs: Input tensor. mask: Binary tensor of shape `[batch_size, timesteps]` indicating whether a given timestep should be masked. An individual `True` entry indicates that the corresponding timestep should be utilized, while a `False` entry indicates that the corresponding timestep should be ignored. training: Python boolean indicating whether the layer should behave in training mode or in inference mode. This argument is passed to the cell when calling it. This is for use with cells that use dropout. initial_state: List of initial state tensors to be passed to the first call of the cell. constants: List of constant tensors to be passed to the cell at each timestep. Input shape: N-D tensor with shape `[batch_size, timesteps, ...]` or `[timesteps, batch_size, ...]` when time_major is True. Output shape: - If `return_state`: a list of tensors. The first tensor is the output. The remaining tensors are the last states, each with shape `[batch_size, state_size]`, where `state_size` could be a high dimension tensor shape. - If `return_sequences`: N-D tensor with shape `[batch_size, timesteps, output_size]`, where `output_size` could be a high dimension tensor shape, or `[timesteps, batch_size, output_size]` when `time_major` is True. - Else, N-D tensor with shape `[batch_size, output_size]`, where `output_size` could be a high dimension tensor shape. Masking: This layer supports masking for input data with a variable number of timesteps. To introduce masks to your data, use an [tf.keras.layers.Embedding] layer with the `mask_zero` parameter set to `True`. Note on using statefulness in RNNs: You can set RNN layers to be 'stateful', which means that the states computed for the samples in one batch will be reused as initial states for the samples in the next batch. This assumes a one-to-one mapping between samples in different successive batches. To enable statefulness: - Specify `stateful=True` in the layer constructor. - Specify a fixed batch size for your model, by passing If sequential model: `batch_input_shape=(...)` to the first layer in your model. Else for functional model with 1 or more Input layers: `batch_shape=(...)` to all the first layers in your model. This is the expected shape of your inputs *including the batch size*. It should be a tuple of integers, e.g. `(32, 10, 100)`. - Specify `shuffle=False` when calling `fit()`. To reset the states of your model, call `.reset_states()` on either a specific layer, or on your entire model. Note on specifying the initial state of RNNs: You can specify the initial state of RNN layers symbolically by calling them with the keyword argument `initial_state`. The value of `initial_state` should be a tensor or list of tensors representing the initial state of the RNN layer. You can specify the initial state of RNN layers numerically by calling `reset_states` with the keyword argument `states`. The value of `states` should be a numpy array or list of numpy arrays representing the initial state of the RNN layer. Note on passing external constants to RNNs: You can pass "external" constants to the cell using the `constants` keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This requires that the `cell.call` method accepts the same keyword argument `constants`. 
Such constants can be used to condition the cell transformation on additional static inputs (not changing over time), a.k.a. an attention mechanism. Examples: ```python # First, let's define a RNN Cell, as a layer subclass. class MinimalRNNCell(keras.layers.Layer): def __init__(self, units, **kwargs): self.units = units self.state_size = units super(MinimalRNNCell, self).__init__(**kwargs) def build(self, input_shape): self.kernel = self.add_weight(shape=(input_shape[-1], self.units), initializer='uniform', name='kernel') self.recurrent_kernel = self.add_weight( shape=(self.units, self.units), initializer='uniform', name='recurrent_kernel') self.built = True def call(self, inputs, states): prev_output = states[0] h = backend.dot(inputs, self.kernel) output = h + backend.dot(prev_output, self.recurrent_kernel) return output, [output] # Let's use this cell in a RNN layer: cell = MinimalRNNCell(32) x = keras.Input((None, 5)) layer = RNN(cell) y = layer(x) # Here's how to use the cell to build a stacked RNN: cells = [MinimalRNNCell(32), MinimalRNNCell(64)] x = keras.Input((None, 5)) layer = RNN(cells) y = layer(x) ``` """ def __init__(self, cell, return_sequences=False, return_state=False, go_backwards=False, stateful=False, unroll=False, time_major=False, **kwargs): if isinstance(cell, (list, tuple)): cell = StackedRNNCells(cell) if not 'call' in dir(cell): raise ValueError('`cell` should have a `call` method. ' 'The RNN was passed:', cell) if not 'state_size' in dir(cell): raise ValueError('The RNN cell should have ' 'an attribute `state_size` ' '(tuple of integers, ' 'one integer per RNN state).') # If True, the output for masked timestep will be zeros, whereas in the # False case, output from previous timestep is returned for masked timestep. self.zero_output_for_mask = kwargs.pop('zero_output_for_mask', False) if 'input_shape' not in kwargs and ( 'input_dim' in kwargs or 'input_length' in kwargs): input_shape = (kwargs.pop('input_length', None), kwargs.pop('input_dim', None)) kwargs['input_shape'] = input_shape super(RNN, self).__init__(**kwargs) self.cell = cell self.return_sequences = return_sequences self.return_state = return_state self.go_backwards = go_backwards self.stateful = stateful self.unroll = unroll self.time_major = time_major self.supports_masking = True # The input shape is unknown yet, it could have nested tensor inputs, and # the input spec will be the list of specs for nested inputs, the structure # of the input_spec will be the same as the input. self.input_spec = None self.state_spec = None self._states = None self.constants_spec = None self._num_constants = 0 if stateful: if tf.distribute.has_strategy(): raise ValueError('RNNs with stateful=True not yet supported with ' 'tf.distribute.Strategy.') @property def _use_input_spec_as_call_signature(self): if self.unroll: # When the RNN layer is unrolled, the time step shape cannot be unknown. # The input spec does not define the time step (because this layer can be # called with any time step value, as long as it is not None), so it # cannot be used as the call function signature when saving to SavedModel. return False return super(RNN, self)._use_input_spec_as_call_signature @property def states(self): if self._states is None: state = tf.nest.map_structure(lambda _: None, self.cell.state_size) return state if tf.nest.is_nested(self.cell.state_size) else [state] return self._states @states.setter # Automatic tracking catches "self._states" which adds an extra weight and # breaks HDF5 checkpoints. 
@tf.__internal__.tracking.no_automatic_dependency_tracking def states(self, states): self._states = states def compute_output_shape(self, input_shape): if isinstance(input_shape, list): input_shape = input_shape[0] # Check whether the input shape contains any nested shapes. It could be # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from numpy # inputs. try: input_shape = tf.TensorShape(input_shape) except (ValueError, TypeError): # A nested tensor input input_shape = tf.nest.flatten(input_shape)[0] batch = input_shape[0] time_step = input_shape[1] if self.time_major: batch, time_step = time_step, batch if _is_multiple_state(self.cell.state_size): state_size = self.cell.state_size else: state_size = [self.cell.state_size] def _get_output_shape(flat_output_size): output_dim = tf.TensorShape(flat_output_size).as_list() if self.return_sequences: if self.time_major: output_shape = tf.TensorShape( [time_step, batch] + output_dim) else: output_shape = tf.TensorShape( [batch, time_step] + output_dim) else: output_shape = tf.TensorShape([batch] + output_dim) return output_shape if getattr(self.cell, 'output_size', None) is not None: # cell.output_size could be nested structure. output_shape = tf.nest.flatten(tf.nest.map_structure( _get_output_shape, self.cell.output_size)) output_shape = output_shape[0] if len(output_shape) == 1 else output_shape else: # Note that state_size[0] could be a tensor_shape or int. output_shape = _get_output_shape(state_size[0]) if self.return_state: def _get_state_shape(flat_state): state_shape = [batch] + tf.TensorShape(flat_state).as_list() return tf.TensorShape(state_shape) state_shape = tf.nest.map_structure(_get_state_shape, state_size) return generic_utils.to_list(output_shape) + tf.nest.flatten(state_shape) else: return output_shape def compute_mask(self, inputs, mask): # Time step masks must be the same for each input. # This is because the mask for an RNN is of size [batch, time_steps, 1], # and specifies which time steps should be skipped, and a time step # must be skipped for all inputs. # TODO(scottzhu): Should we accept multiple different masks? mask = tf.nest.flatten(mask)[0] output_mask = mask if self.return_sequences else None if self.return_state: state_mask = [None for _ in self.states] return [output_mask] + state_mask else: return output_mask def build(self, input_shape): if isinstance(input_shape, list): input_shape = input_shape[0] # The input_shape here could be a nest structure. # do the tensor_shape to shapes here. The input could be single tensor, or a # nested structure of tensors. def get_input_spec(shape): """Convert input shape to InputSpec.""" if isinstance(shape, tf.TensorShape): input_spec_shape = shape.as_list() else: input_spec_shape = list(shape) batch_index, time_step_index = (1, 0) if self.time_major else (0, 1) if not self.stateful: input_spec_shape[batch_index] = None input_spec_shape[time_step_index] = None return InputSpec(shape=tuple(input_spec_shape)) def get_step_input_shape(shape): if isinstance(shape, tf.TensorShape): shape = tuple(shape.as_list()) # remove the timestep from the input_shape return shape[1:] if self.time_major else (shape[0],) + shape[2:] # Check whether the input shape contains any nested shapes. It could be # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from numpy # inputs. try: input_shape = tf.TensorShape(input_shape) except (ValueError, TypeError): # A nested tensor input pass if not tf.nest.is_nested(input_shape): # This indicates the there is only one input. 
if self.input_spec is not None: self.input_spec[0] = get_input_spec(input_shape) else: self.input_spec = [get_input_spec(input_shape)] step_input_shape = get_step_input_shape(input_shape) else: if self.input_spec is not None: self.input_spec[0] = tf.nest.map_structure(get_input_spec, input_shape) else: self.input_spec = generic_utils.to_list( tf.nest.map_structure(get_input_spec, input_shape)) step_input_shape = tf.nest.map_structure(get_step_input_shape, input_shape) # allow cell (if layer) to build before we set or validate state_spec. if isinstance(self.cell, Layer) and not self.cell.built: with backend.name_scope(self.cell.name): self.cell.build(step_input_shape) self.cell.built = True # set or validate state_spec if _is_multiple_state(self.cell.state_size): state_size = list(self.cell.state_size) else: state_size = [self.cell.state_size] if self.state_spec is not None: # initial_state was passed in call, check compatibility self._validate_state_spec(state_size, self.state_spec) else: self.state_spec = [ InputSpec(shape=[None] + tf.TensorShape(dim).as_list()) for dim in state_size ] if self.stateful: self.reset_states() self.built = True @staticmethod def _validate_state_spec(cell_state_sizes, init_state_specs): """Validate the state spec between the initial_state and the state_size. Args: cell_state_sizes: list, the `state_size` attribute from the cell. init_state_specs: list, the `state_spec` from the initial_state that is passed in `call()`. Raises: ValueError: When initial state spec is not compatible with the state size. """ validation_error = ValueError( 'An `initial_state` was passed that is not compatible with ' '`cell.state_size`. Received `state_spec`={}; ' 'however `cell.state_size` is ' '{}'.format(init_state_specs, cell_state_sizes)) flat_cell_state_sizes = tf.nest.flatten(cell_state_sizes) flat_state_specs = tf.nest.flatten(init_state_specs) if len(flat_cell_state_sizes) != len(flat_state_specs): raise validation_error for cell_state_spec, cell_state_size in zip(flat_state_specs, flat_cell_state_sizes): if not tf.TensorShape( # Ignore the first axis for init_state which is for batch cell_state_spec.shape[1:]).is_compatible_with( tf.TensorShape(cell_state_size)): raise validation_error @doc_controls.do_not_doc_inheritable def get_initial_state(self, inputs): get_initial_state_fn = getattr(self.cell, 'get_initial_state', None) if tf.nest.is_nested(inputs): # The input are nested sequences. Use the first element in the seq to get # batch size and dtype. inputs = tf.nest.flatten(inputs)[0] input_shape = tf.shape(inputs) batch_size = input_shape[1] if self.time_major else input_shape[0] dtype = inputs.dtype if get_initial_state_fn: init_state = get_initial_state_fn( inputs=None, batch_size=batch_size, dtype=dtype) else: init_state = _generate_zero_filled_state(batch_size, self.cell.state_size, dtype) # Keras RNN expect the states in a list, even if it's a single state tensor. if not tf.nest.is_nested(init_state): init_state = [init_state] # Force the state to be a list in case it is a namedtuple eg LSTMStateTuple. return list(init_state) def __call__(self, inputs, initial_state=None, constants=None, **kwargs): inputs, initial_state, constants = _standardize_args(inputs, initial_state, constants, self._num_constants) if initial_state is None and constants is None: return super(RNN, self).__call__(inputs, **kwargs) # If any of `initial_state` or `constants` are specified and are Keras # tensors, then add them to the inputs and temporarily modify the # input_spec to include them. 
additional_inputs = [] additional_specs = [] if initial_state is not None: additional_inputs += initial_state self.state_spec = tf.nest.map_structure( lambda s: InputSpec(shape=backend.int_shape(s)), initial_state) additional_specs += self.state_spec if constants is not None: additional_inputs += constants self.constants_spec = [ InputSpec(shape=backend.int_shape(constant)) for constant in constants ] self._num_constants = len(constants) additional_specs += self.constants_spec # additional_inputs can be empty if initial_state or constants are provided # but empty (e.g. the cell is stateless). flat_additional_inputs = tf.nest.flatten(additional_inputs) is_keras_tensor = backend.is_keras_tensor( flat_additional_inputs[0]) if flat_additional_inputs else True for tensor in flat_additional_inputs: if backend.is_keras_tensor(tensor) != is_keras_tensor: raise ValueError('The initial state or constants of an RNN' ' layer cannot be specified with a mix of' ' Keras tensors and non-Keras tensors' ' (a "Keras tensor" is a tensor that was' ' returned by a Keras layer, or by `Input`)') if is_keras_tensor: # Compute the full input spec, including state and constants full_input = [inputs] + additional_inputs if self.built: # Keep the input_spec since it has been populated in build() method. full_input_spec = self.input_spec + additional_specs else: # The original input_spec is None since there could be a nested tensor # input. Update the input_spec to match the inputs. full_input_spec = generic_utils.to_list( tf.nest.map_structure(lambda _: None, inputs)) + additional_specs # Perform the call with temporarily replaced input_spec self.input_spec = full_input_spec output = super(RNN, self).__call__(full_input, **kwargs) # Remove the additional_specs from input spec and keep the rest. It is # important to keep since the input spec was populated by build(), and # will be reused in the stateful=True. self.input_spec = self.input_spec[:-len(additional_specs)] return output else: if initial_state is not None: kwargs['initial_state'] = initial_state if constants is not None: kwargs['constants'] = constants return super(RNN, self).__call__(inputs, **kwargs) def call(self, inputs, mask=None, training=None, initial_state=None, constants=None): # The input should be dense, padded with zeros. If a ragged input is fed # into the layer, it is padded and the row lengths are used for masking. inputs, row_lengths = backend.convert_inputs_if_ragged(inputs) is_ragged_input = (row_lengths is not None) self._validate_args_if_ragged(is_ragged_input, mask) inputs, initial_state, constants = self._process_inputs( inputs, initial_state, constants) self._maybe_reset_cell_dropout_mask(self.cell) if isinstance(self.cell, StackedRNNCells): for cell in self.cell.cells: self._maybe_reset_cell_dropout_mask(cell) if mask is not None: # Time step masks must be the same for each input. # TODO(scottzhu): Should we accept multiple different masks? mask = tf.nest.flatten(mask)[0] if tf.nest.is_nested(inputs): # In the case of nested input, use the first element for shape check. input_shape = backend.int_shape(tf.nest.flatten(inputs)[0]) else: input_shape = backend.int_shape(inputs) timesteps = input_shape[0] if self.time_major else input_shape[1] if self.unroll and timesteps is None: raise ValueError('Cannot unroll a RNN if the ' 'time dimension is undefined. \n' '- If using a Sequential model, ' 'specify the time dimension by passing ' 'an `input_shape` or `batch_input_shape` ' 'argument to your first layer. 
If your ' 'first layer is an Embedding, you can ' 'also use the `input_length` argument.\n' '- If using the functional API, specify ' 'the time dimension by passing a `shape` ' 'or `batch_shape` argument to your Input layer.') kwargs = {} if generic_utils.has_arg(self.cell.call, 'training'): kwargs['training'] = training # TF RNN cells expect single tensor as state instead of list wrapped tensor. is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None # Use the __call__ function for callable objects, eg layers, so that it # will have the proper name scopes for the ops, etc. cell_call_fn = self.cell.__call__ if callable(self.cell) else self.cell.call if constants: if not generic_utils.has_arg(self.cell.call, 'constants'): raise ValueError('RNN cell does not support constants') def step(inputs, states): constants = states[-self._num_constants:] # pylint: disable=invalid-unary-operand-type states = states[:-self._num_constants] # pylint: disable=invalid-unary-operand-type states = states[0] if len(states) == 1 and is_tf_rnn_cell else states output, new_states = cell_call_fn( inputs, states, constants=constants, **kwargs) if not tf.nest.is_nested(new_states): new_states = [new_states] return output, new_states else: def step(inputs, states): states = states[0] if len(states) == 1 and is_tf_rnn_cell else states output, new_states = cell_call_fn(inputs, states, **kwargs) if not tf.nest.is_nested(new_states): new_states = [new_states] return output, new_states last_output, outputs, states = backend.rnn( step, inputs, initial_state, constants=constants, go_backwards=self.go_backwards, mask=mask, unroll=self.unroll, input_length=row_lengths if row_lengths is not None else timesteps, time_major=self.time_major, zero_output_for_mask=self.zero_output_for_mask) if self.stateful: updates = [ tf.compat.v1.assign(self_state, state) for self_state, state in zip( tf.nest.flatten(self.states), tf.nest.flatten(states)) ] self.add_update(updates) if self.return_sequences: output = backend.maybe_convert_to_ragged( is_ragged_input, outputs, row_lengths, go_backwards=self.go_backwards) else: output = last_output if self.return_state: if not isinstance(states, (list, tuple)): states = [states] else: states = list(states) return generic_utils.to_list(output) + states else: return output def _process_inputs(self, inputs, initial_state, constants): # input shape: `(samples, time (padded with zeros), input_dim)` # note that the .build() method of subclasses MUST define # self.input_spec and self.state_spec with complete input shapes. if (isinstance(inputs, collections.abc.Sequence) and not isinstance(inputs, tuple)): # get initial_state from full input spec # as they could be copied to multiple GPU. if not self._num_constants: initial_state = inputs[1:] else: initial_state = inputs[1:-self._num_constants] constants = inputs[-self._num_constants:] if len(initial_state) == 0: initial_state = None inputs = inputs[0] if self.stateful: if initial_state is not None: # When layer is stateful and initial_state is provided, check if the # recorded state is same as the default value (zeros). Use the recorded # state if it is not same as the default. non_zero_count = tf.add_n([tf.math.count_nonzero(s) for s in tf.nest.flatten(self.states)]) # Set strict = True to keep the original structure of the state. 
initial_state = tf.compat.v1.cond(non_zero_count > 0, true_fn=lambda: self.states, false_fn=lambda: initial_state, strict=True) else: initial_state = self.states elif initial_state is None: initial_state = self.get_initial_state(inputs) if len(initial_state) != len(self.states): raise ValueError('Layer has ' + str(len(self.states)) + ' states but was passed ' + str(len(initial_state)) + ' initial states.') return inputs, initial_state, constants def _validate_args_if_ragged(self, is_ragged_input, mask): if not is_ragged_input: return if mask is not None: raise ValueError('The mask that was passed in was ' + str(mask) + ' and cannot be applied to RaggedTensor inputs. Please ' 'make sure that there is no mask passed in by upstream ' 'layers.') if self.unroll: raise ValueError('The input received contains RaggedTensors and does ' 'not support unrolling. Disable unrolling by passing ' '`unroll=False` in the RNN Layer constructor.') def _maybe_reset_cell_dropout_mask(self, cell): if isinstance(cell, DropoutRNNCellMixin): cell.reset_dropout_mask() cell.reset_recurrent_dropout_mask() def reset_states(self, states=None): """Reset the recorded states for the stateful RNN layer. Can only be used when RNN layer is constructed with `stateful` = `True`. Args: states: Numpy arrays that contains the value for the initial state, which will be feed to cell at the first time step. When the value is None, zero filled numpy array will be created based on the cell state size. Raises: AttributeError: When the RNN layer is not stateful. ValueError: When the batch size of the RNN layer is unknown. ValueError: When the input numpy array is not compatible with the RNN layer state, either size wise or dtype wise. """ if not self.stateful: raise AttributeError('Layer must be stateful.') spec_shape = None if self.input_spec is not None: spec_shape = tf.nest.flatten(self.input_spec[0])[0].shape if spec_shape is None: # It is possible to have spec shape to be None, eg when construct a RNN # with a custom cell, or standard RNN layers (LSTM/GRU) which we only know # it has 3 dim input, but not its full shape spec before build(). batch_size = None else: batch_size = spec_shape[1] if self.time_major else spec_shape[0] if not batch_size: raise ValueError('If a RNN is stateful, it needs to know ' 'its batch size. 
Specify the batch size ' 'of your input tensors: \n' '- If using a Sequential model, ' 'specify the batch size by passing ' 'a `batch_input_shape` ' 'argument to your first layer.\n' '- If using the functional API, specify ' 'the batch size by passing a ' '`batch_shape` argument to your Input layer.') # initialize state if None if tf.nest.flatten(self.states)[0] is None: if getattr(self.cell, 'get_initial_state', None): flat_init_state_values = tf.nest.flatten(self.cell.get_initial_state( inputs=None, batch_size=batch_size, dtype=self.dtype or backend.floatx())) else: flat_init_state_values = tf.nest.flatten(_generate_zero_filled_state( batch_size, self.cell.state_size, self.dtype or backend.floatx())) flat_states_variables = tf.nest.map_structure( backend.variable, flat_init_state_values) self.states = tf.nest.pack_sequence_as(self.cell.state_size, flat_states_variables) if not tf.nest.is_nested(self.states): self.states = [self.states] elif states is None: for state, size in zip(tf.nest.flatten(self.states), tf.nest.flatten(self.cell.state_size)): backend.set_value( state, np.zeros([batch_size] + tf.TensorShape(size).as_list())) else: flat_states = tf.nest.flatten(self.states) flat_input_states = tf.nest.flatten(states) if len(flat_input_states) != len(flat_states): raise ValueError('Layer ' + self.name + ' expects ' + str(len(flat_states)) + ' states, ' 'but it received ' + str(len(flat_input_states)) + ' state values. Input received: ' + str(states)) set_value_tuples = [] for i, (value, state) in enumerate(zip(flat_input_states, flat_states)): if value.shape != state.shape: raise ValueError( 'State ' + str(i) + ' is incompatible with layer ' + self.name + ': expected shape=' + str( (batch_size, state)) + ', found shape=' + str(value.shape)) set_value_tuples.append((state, value)) backend.batch_set_value(set_value_tuples) def get_config(self): config = { 'return_sequences': self.return_sequences, 'return_state': self.return_state, 'go_backwards': self.go_backwards, 'stateful': self.stateful, 'unroll': self.unroll, 'time_major': self.time_major } if self._num_constants: config['num_constants'] = self._num_constants if self.zero_output_for_mask: config['zero_output_for_mask'] = self.zero_output_for_mask config['cell'] = generic_utils.serialize_keras_object(self.cell) base_config = super(RNN, self).get_config() return dict(list(base_config.items()) + list(config.items())) @classmethod def from_config(cls, config, custom_objects=None): from keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top cell = deserialize_layer(config.pop('cell'), custom_objects=custom_objects) num_constants = config.pop('num_constants', 0) layer = cls(cell, **config) layer._num_constants = num_constants return layer @property def _trackable_saved_model_saver(self): return layer_serialization.RNNSavedModelSaver(self)
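A short sketch of passing a symbolic initial_state, as described in the note above (layer sizes are arbitrary assumptions):

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(10, 8))
initial_h = tf.keras.Input(shape=(16,))
initial_c = tf.keras.Input(shape=(16,))

# LSTM carries two state tensors, so initial_state is a list of two tensors.
outputs = tf.keras.layers.LSTM(16)(inputs,
                                   initial_state=[initial_h, initial_c])
model = tf.keras.Model([inputs, initial_h, initial_c], outputs)
```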
Ancestors
- Layer
- tensorflow.python.module.module.Module
- tensorflow.python.training.tracking.tracking.AutoTrackable
- tensorflow.python.training.tracking.base.Trackable
- LayerVersionSelector
Instance variables
var states
-
Expand source code
@property
def states(self):
  if self._states is None:
    state = tf.nest.map_structure(lambda _: None, self.cell.state_size)
    return state if tf.nest.is_nested(self.cell.state_size) else [state]
  return self._states
Methods
def get_initial_state(self, inputs)
-
Expand source code
@doc_controls.do_not_doc_inheritable
def get_initial_state(self, inputs):
  get_initial_state_fn = getattr(self.cell, 'get_initial_state', None)

  if tf.nest.is_nested(inputs):
    # The input are nested sequences. Use the first element in the seq to get
    # batch size and dtype.
    inputs = tf.nest.flatten(inputs)[0]

  input_shape = tf.shape(inputs)
  batch_size = input_shape[1] if self.time_major else input_shape[0]
  dtype = inputs.dtype
  if get_initial_state_fn:
    init_state = get_initial_state_fn(
        inputs=None, batch_size=batch_size, dtype=dtype)
  else:
    init_state = _generate_zero_filled_state(batch_size, self.cell.state_size,
                                             dtype)

  # Keras RNN expect the states in a list, even if it's a single state tensor.
  if not tf.nest.is_nested(init_state):
    init_state = [init_state]
  # Force the state to be a list in case it is a namedtuple eg LSTMStateTuple.
  return list(init_state)
def reset_states(self, states=None)
-
Reset the recorded states for the stateful RNN layer.
Can only be used when the RNN layer is constructed with `stateful=True`.
Args
states
- Numpy arrays that contain the values for the initial states, which will be fed to the cell at the first time step. When the value is None, zero-filled numpy arrays will be created based on the cell state size.
Raises
AttributeError
- When the RNN layer is not stateful.
ValueError
- When the batch size of the RNN layer is unknown.
ValueError
- When the input numpy array is not compatible with the RNN layer state, either size-wise or dtype-wise.
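As a usage illustration (not part of the original docs), the sketch below assumes a Sequential model with a single stateful `SimpleRNN(4)` layer and a fixed batch of 32 sequences of length 10 with 8 features:

```python
import numpy as np
import tensorflow as tf

# Stateful layers need a fixed batch size; batch_input_shape supplies it.
model = tf.keras.Sequential([
    tf.keras.layers.SimpleRNN(4, stateful=True,
                              batch_input_shape=(32, 10, 8)),
])

batch = np.random.random([32, 10, 8]).astype(np.float32)
model(batch)  # Final states are kept and reused by the next call.
model(batch)  # Continues from where the previous batch left off.

# Reset the carried-over states to zeros ...
model.layers[0].reset_states()
# ... or seed them explicitly with arrays matching the state shape
# (for SimpleRNN, a single [batch, units] = [32, 4] state).
model.layers[0].reset_states(states=np.zeros((32, 4), dtype=np.float32))
```

Passing explicit arrays only works when their shapes match the layer's state shapes exactly, as described under Raises above.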
Expand source code
def reset_states(self, states=None):
  """Reset the recorded states for the stateful RNN layer.

  Can only be used when RNN layer is constructed with `stateful` = `True`.
  Args:
    states: Numpy arrays that contains the value for the initial state, which
      will be feed to cell at the first time step. When the value is None,
      zero filled numpy array will be created based on the cell state size.

  Raises:
    AttributeError: When the RNN layer is not stateful.
    ValueError: When the batch size of the RNN layer is unknown.
    ValueError: When the input numpy array is not compatible with the RNN
      layer state, either size wise or dtype wise.
  """
  if not self.stateful:
    raise AttributeError('Layer must be stateful.')
  spec_shape = None
  if self.input_spec is not None:
    spec_shape = tf.nest.flatten(self.input_spec[0])[0].shape
  if spec_shape is None:
    # It is possible to have spec shape to be None, eg when construct a RNN
    # with a custom cell, or standard RNN layers (LSTM/GRU) which we only know
    # it has 3 dim input, but not its full shape spec before build().
    batch_size = None
  else:
    batch_size = spec_shape[1] if self.time_major else spec_shape[0]
  if not batch_size:
    raise ValueError('If a RNN is stateful, it needs to know '
                     'its batch size. Specify the batch size '
                     'of your input tensors: \n'
                     '- If using a Sequential model, '
                     'specify the batch size by passing '
                     'a `batch_input_shape` '
                     'argument to your first layer.\n'
                     '- If using the functional API, specify '
                     'the batch size by passing a '
                     '`batch_shape` argument to your Input layer.')
  # initialize state if None
  if tf.nest.flatten(self.states)[0] is None:
    if getattr(self.cell, 'get_initial_state', None):
      flat_init_state_values = tf.nest.flatten(self.cell.get_initial_state(
          inputs=None, batch_size=batch_size,
          dtype=self.dtype or backend.floatx()))
    else:
      flat_init_state_values = tf.nest.flatten(_generate_zero_filled_state(
          batch_size, self.cell.state_size, self.dtype or backend.floatx()))
    flat_states_variables = tf.nest.map_structure(backend.variable,
                                                  flat_init_state_values)
    self.states = tf.nest.pack_sequence_as(self.cell.state_size,
                                           flat_states_variables)
    if not tf.nest.is_nested(self.states):
      self.states = [self.states]
  elif states is None:
    for state, size in zip(tf.nest.flatten(self.states),
                           tf.nest.flatten(self.cell.state_size)):
      backend.set_value(
          state, np.zeros([batch_size] + tf.TensorShape(size).as_list()))
  else:
    flat_states = tf.nest.flatten(self.states)
    flat_input_states = tf.nest.flatten(states)
    if len(flat_input_states) != len(flat_states):
      raise ValueError('Layer ' + self.name + ' expects ' +
                       str(len(flat_states)) + ' states, '
                       'but it received ' + str(len(flat_input_states)) +
                       ' state values. Input received: ' + str(states))
    set_value_tuples = []
    for i, (value, state) in enumerate(zip(flat_input_states, flat_states)):
      if value.shape != state.shape:
        raise ValueError(
            'State ' + str(i) + ' is incompatible with layer ' +
            self.name + ': expected shape=' + str(
                (batch_size, state)) + ', found shape=' + str(value.shape))
      set_value_tuples.append((state, value))
    backend.batch_set_value(set_value_tuples)
Inherited members
Layer
:activity_regularizer
add_loss
add_metric
add_update
add_variable
add_weight
apply
build
call
compute_dtype
compute_mask
compute_output_shape
compute_output_signature
count_params
dtype
dtype_policy
dynamic
finalize_state
from_config
get_config
get_input_at
get_input_mask_at
get_input_shape_at
get_losses_for
get_output_at
get_output_mask_at
get_output_shape_at
get_updates_for
get_weights
inbound_nodes
input
input_mask
input_shape
input_spec
losses
metrics
name
non_trainable_variables
non_trainable_weights
outbound_nodes
output
output_mask
output_shape
set_weights
supports_masking
trainable_variables
trainable_weights
variable_dtype
variables
weights
class SimpleRNN (units, activation='tanh', use_bias=True, kernel_initializer='glorot_uniform', recurrent_initializer='orthogonal', bias_initializer='zeros', kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, recurrent_constraint=None, bias_constraint=None, dropout=0.0, recurrent_dropout=0.0, return_sequences=False, return_state=False, go_backwards=False, stateful=False, unroll=False, **kwargs)
-
Fully-connected RNN where the output is to be fed back to input.
See the Keras RNN API guide for details about the usage of RNN API.
Args
units
- Positive integer, dimensionality of the output space.
activation
- Activation function to use. Default: hyperbolic tangent (`tanh`). If you pass `None`, no activation is applied (i.e. "linear" activation: `a(x) = x`).
use_bias
- Boolean (default `True`), whether the layer uses a bias vector.
kernel_initializer
- Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. Default: `glorot_uniform`.
recurrent_initializer
- Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state. Default: `orthogonal`.
bias_initializer
- Initializer for the bias vector. Default: `zeros`.
kernel_regularizer
- Regularizer function applied to the `kernel` weights matrix. Default: `None`.
recurrent_regularizer
- Regularizer function applied to the `recurrent_kernel` weights matrix. Default: `None`.
bias_regularizer
- Regularizer function applied to the bias vector. Default: `None`.
activity_regularizer
- Regularizer function applied to the output of the layer (its "activation"). Default: `None`.
kernel_constraint
- Constraint function applied to the `kernel` weights matrix. Default: `None`.
recurrent_constraint
- Constraint function applied to the `recurrent_kernel` weights matrix. Default: `None`.
bias_constraint
- Constraint function applied to the bias vector. Default: `None`.
dropout
- Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. Default: 0.
recurrent_dropout
- Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state. Default: 0.
return_sequences
- Boolean. Whether to return the last output in the output sequence, or the full sequence. Default: `False`.
return_state
- Boolean. Whether to return the last state in addition to the output. Default: `False`.
go_backwards
- Boolean (default False). If True, process the input sequence backwards and return the reversed sequence.
stateful
- Boolean (default False). If True, the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch.
unroll
- Boolean (default False). If True, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed-up a RNN, although it tends to be more memory-intensive. Unrolling is only suitable for short sequences.
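The constructor arguments above can be combined freely. A minimal sketch follows, assuming arbitrary illustrative values for the initializers, regularizer, and dropout rates (none of these particular choices are prescribed by the docs):

```python
import tensorflow as tf

# Illustrative configuration only; every value here is an arbitrary choice.
layer = tf.keras.layers.SimpleRNN(
    units=16,
    activation='tanh',
    kernel_initializer='glorot_uniform',
    recurrent_initializer='orthogonal',
    kernel_regularizer=tf.keras.regularizers.l2(1e-4),
    dropout=0.1,
    recurrent_dropout=0.1,
    return_sequences=True)

# With return_sequences=True the full sequence is returned: [8, 20, 16].
outputs = layer(tf.random.normal([8, 20, 32]))
```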
Call arguments
inputs
- A 3D tensor, with shape `[batch, timesteps, feature]`.
mask
- Binary tensor of shape `[batch, timesteps]` indicating whether a given timestep should be masked. An individual `True` entry indicates that the corresponding timestep should be utilized, while a `False` entry indicates that the corresponding timestep should be ignored.
training
- Python boolean indicating whether the layer should behave in training mode or in inference mode. This argument is passed to the cell when calling it. This is only relevant if `dropout` or `recurrent_dropout` is used.
initial_state
- List of initial state tensors to be passed to the first call of the cell.
Examples:
inputs = np.random.random([32, 10, 8]).astype(np.float32)
simple_rnn = tf.keras.layers.SimpleRNN(4)
output = simple_rnn(inputs)  # The output has shape `[32, 4]`.

simple_rnn = tf.keras.layers.SimpleRNN(
    4, return_sequences=True, return_state=True)

# whole_sequence_output has shape `[32, 10, 4]`.
# final_state has shape `[32, 4]`.
whole_sequence_output, final_state = simple_rnn(inputs)
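Building on the example above, the following sketch exercises the `mask` and `initial_state` call arguments described earlier; the padding pattern and the zero initial state are assumptions made purely for illustration:

```python
import numpy as np
import tensorflow as tf

inputs = np.random.random([32, 10, 8]).astype(np.float32)

# Assume the last two timesteps of every sequence are padding to be ignored.
mask = np.ones([32, 10], dtype=bool)
mask[:, 8:] = False

# One initial state tensor of shape [batch, units]; here simply zeros.
initial_state = [tf.zeros([32, 4])]

simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences=True,
                                       return_state=True)
whole_sequence_output, final_state = simple_rnn(
    inputs, mask=tf.constant(mask),
    initial_state=initial_state, training=False)
# whole_sequence_output: [32, 10, 4]; final_state: [32, 4].
```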
Expand source code
class SimpleRNN(RNN): """Fully-connected RNN where the output is to be fed back to input. See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) for details about the usage of RNN API. Args: units: Positive integer, dimensionality of the output space. activation: Activation function to use. Default: hyperbolic tangent (`tanh`). If you pass None, no activation is applied (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, (default `True`), whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. Default: `glorot_uniform`. recurrent_initializer: Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state. Default: `orthogonal`. bias_initializer: Initializer for the bias vector. Default: `zeros`. kernel_regularizer: Regularizer function applied to the `kernel` weights matrix. Default: `None`. recurrent_regularizer: Regularizer function applied to the `recurrent_kernel` weights matrix. Default: `None`. bias_regularizer: Regularizer function applied to the bias vector. Default: `None`. activity_regularizer: Regularizer function applied to the output of the layer (its "activation"). Default: `None`. kernel_constraint: Constraint function applied to the `kernel` weights matrix. Default: `None`. recurrent_constraint: Constraint function applied to the `recurrent_kernel` weights matrix. Default: `None`. bias_constraint: Constraint function applied to the bias vector. Default: `None`. dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. Default: 0. recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state. Default: 0. return_sequences: Boolean. Whether to return the last output in the output sequence, or the full sequence. Default: `False`. return_state: Boolean. Whether to return the last state in addition to the output. Default: `False` go_backwards: Boolean (default False). If True, process the input sequence backwards and return the reversed sequence. stateful: Boolean (default False). If True, the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch. unroll: Boolean (default False). If True, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed-up a RNN, although it tends to be more memory-intensive. Unrolling is only suitable for short sequences. Call arguments: inputs: A 3D tensor, with shape `[batch, timesteps, feature]`. mask: Binary tensor of shape `[batch, timesteps]` indicating whether a given timestep should be masked. An individual `True` entry indicates that the corresponding timestep should be utilized, while a `False` entry indicates that the corresponding timestep should be ignored. training: Python boolean indicating whether the layer should behave in training mode or in inference mode. This argument is passed to the cell when calling it. This is only relevant if `dropout` or `recurrent_dropout` is used. initial_state: List of initial state tensors to be passed to the first call of the cell. Examples: ```python inputs = np.random.random([32, 10, 8]).astype(np.float32) simple_rnn = tf.keras.layers.SimpleRNN(4) output = simple_rnn(inputs) # The output has shape `[32, 4]`. 
simple_rnn = tf.keras.layers.SimpleRNN( 4, return_sequences=True, return_state=True) # whole_sequence_output has shape `[32, 10, 4]`. # final_state has shape `[32, 4]`. whole_sequence_output, final_state = simple_rnn(inputs) ``` """ def __init__(self, units, activation='tanh', use_bias=True, kernel_initializer='glorot_uniform', recurrent_initializer='orthogonal', bias_initializer='zeros', kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, recurrent_constraint=None, bias_constraint=None, dropout=0., recurrent_dropout=0., return_sequences=False, return_state=False, go_backwards=False, stateful=False, unroll=False, **kwargs): if 'implementation' in kwargs: kwargs.pop('implementation') logging.warning('The `implementation` argument ' 'in `SimpleRNN` has been deprecated. ' 'Please remove it from your layer call.') if 'enable_caching_device' in kwargs: cell_kwargs = {'enable_caching_device': kwargs.pop('enable_caching_device')} else: cell_kwargs = {} cell = SimpleRNNCell( units, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer, recurrent_initializer=recurrent_initializer, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, recurrent_regularizer=recurrent_regularizer, bias_regularizer=bias_regularizer, kernel_constraint=kernel_constraint, recurrent_constraint=recurrent_constraint, bias_constraint=bias_constraint, dropout=dropout, recurrent_dropout=recurrent_dropout, dtype=kwargs.get('dtype'), trainable=kwargs.get('trainable', True), **cell_kwargs) super(SimpleRNN, self).__init__( cell, return_sequences=return_sequences, return_state=return_state, go_backwards=go_backwards, stateful=stateful, unroll=unroll, **kwargs) self.activity_regularizer = regularizers.get(activity_regularizer) self.input_spec = [InputSpec(ndim=3)] def call(self, inputs, mask=None, training=None, initial_state=None): return super(SimpleRNN, self).call( inputs, mask=mask, training=training, initial_state=initial_state) @property def units(self): return self.cell.units @property def activation(self): return self.cell.activation @property def use_bias(self): return self.cell.use_bias @property def kernel_initializer(self): return self.cell.kernel_initializer @property def recurrent_initializer(self): return self.cell.recurrent_initializer @property def bias_initializer(self): return self.cell.bias_initializer @property def kernel_regularizer(self): return self.cell.kernel_regularizer @property def recurrent_regularizer(self): return self.cell.recurrent_regularizer @property def bias_regularizer(self): return self.cell.bias_regularizer @property def kernel_constraint(self): return self.cell.kernel_constraint @property def recurrent_constraint(self): return self.cell.recurrent_constraint @property def bias_constraint(self): return self.cell.bias_constraint @property def dropout(self): return self.cell.dropout @property def recurrent_dropout(self): return self.cell.recurrent_dropout def get_config(self): config = { 'units': self.units, 'activation': activations.serialize(self.activation), 'use_bias': self.use_bias, 'kernel_initializer': initializers.serialize(self.kernel_initializer), 'recurrent_initializer': initializers.serialize(self.recurrent_initializer), 'bias_initializer': initializers.serialize(self.bias_initializer), 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer), 'bias_regularizer': 
regularizers.serialize(self.bias_regularizer), 'activity_regularizer': regularizers.serialize(self.activity_regularizer), 'kernel_constraint': constraints.serialize(self.kernel_constraint), 'recurrent_constraint': constraints.serialize(self.recurrent_constraint), 'bias_constraint': constraints.serialize(self.bias_constraint), 'dropout': self.dropout, 'recurrent_dropout': self.recurrent_dropout } base_config = super(SimpleRNN, self).get_config() config.update(_config_for_enable_caching_device(self.cell)) del base_config['cell'] return dict(list(base_config.items()) + list(config.items())) @classmethod def from_config(cls, config): if 'implementation' in config: config.pop('implementation') return cls(**config)
Ancestors
- RNN
- Layer
- tensorflow.python.module.module.Module
- tensorflow.python.training.tracking.tracking.AutoTrackable
- tensorflow.python.training.tracking.base.Trackable
- LayerVersionSelector
Instance variables
var activation
-
Expand source code
@property
def activation(self):
  return self.cell.activation
var bias_constraint
-
Expand source code
@property
def bias_constraint(self):
  return self.cell.bias_constraint
var bias_initializer
-
Expand source code
@property
def bias_initializer(self):
  return self.cell.bias_initializer
var bias_regularizer
-
Expand source code
@property
def bias_regularizer(self):
  return self.cell.bias_regularizer
var dropout
-
Expand source code
@property
def dropout(self):
  return self.cell.dropout
var kernel_constraint
-
Expand source code
@property
def kernel_constraint(self):
  return self.cell.kernel_constraint
var kernel_initializer
-
Expand source code
@property
def kernel_initializer(self):
  return self.cell.kernel_initializer
var kernel_regularizer
-
Expand source code
@property
def kernel_regularizer(self):
  return self.cell.kernel_regularizer
var recurrent_constraint
-
Expand source code
@property
def recurrent_constraint(self):
  return self.cell.recurrent_constraint
var recurrent_dropout
-
Expand source code
@property
def recurrent_dropout(self):
  return self.cell.recurrent_dropout
var recurrent_initializer
-
Expand source code
@property
def recurrent_initializer(self):
  return self.cell.recurrent_initializer
var recurrent_regularizer
-
Expand source code
@property
def recurrent_regularizer(self):
  return self.cell.recurrent_regularizer
var units
-
Expand source code
@property
def units(self):
  return self.cell.units
var use_bias
-
Expand source code
@property
def use_bias(self):
  return self.cell.use_bias
Inherited members
RNN
:activity_regularizer
add_loss
add_metric
add_update
add_variable
add_weight
apply
build
call
compute_dtype
compute_mask
compute_output_shape
compute_output_signature
count_params
dtype
dtype_policy
dynamic
finalize_state
from_config
get_config
get_input_at
get_input_mask_at
get_input_shape_at
get_losses_for
get_output_at
get_output_mask_at
get_output_shape_at
get_updates_for
get_weights
inbound_nodes
input
input_mask
input_shape
input_spec
losses
metrics
name
non_trainable_variables
non_trainable_weights
outbound_nodes
output
output_mask
output_shape
reset_states
set_weights
supports_masking
trainable_variables
trainable_weights
variable_dtype
variables
weights
class SimpleRNNCell (units, activation='tanh', use_bias=True, kernel_initializer='glorot_uniform', recurrent_initializer='orthogonal', bias_initializer='zeros', kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, kernel_constraint=None, recurrent_constraint=None, bias_constraint=None, dropout=0.0, recurrent_dropout=0.0, **kwargs)
-
Cell class for SimpleRNN.
See the Keras RNN API guide for details about the usage of RNN API.
This class processes one step within the whole time sequence input, whereas `tf.keras.layers.SimpleRNN` processes the whole sequence.
Args
units
- Positive integer, dimensionality of the output space.
activation
- Activation function to use. Default: hyperbolic tangent (`tanh`). If you pass `None`, no activation is applied (i.e. "linear" activation: `a(x) = x`).
use_bias
- Boolean (default `True`), whether the layer uses a bias vector.
kernel_initializer
- Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. Default: `glorot_uniform`.
recurrent_initializer
- Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state. Default: `orthogonal`.
bias_initializer
- Initializer for the bias vector. Default: `zeros`.
kernel_regularizer
- Regularizer function applied to the `kernel` weights matrix. Default: `None`.
recurrent_regularizer
- Regularizer function applied to the `recurrent_kernel` weights matrix. Default: `None`.
bias_regularizer
- Regularizer function applied to the bias vector. Default: `None`.
kernel_constraint
- Constraint function applied to the `kernel` weights matrix. Default: `None`.
recurrent_constraint
- Constraint function applied to the `recurrent_kernel` weights matrix. Default: `None`.
bias_constraint
- Constraint function applied to the bias vector. Default: `None`.
dropout
- Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. Default: 0.
recurrent_dropout
- Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state. Default: 0.
Call arguments
inputs
- A 2D tensor, with shape of `[batch, feature]`.
states
- A 2D tensor with shape of `[batch, units]`, which is the state from the previous time step. For timestep 0, the initial state provided by the user will be fed to the cell.
training
- Python boolean indicating whether the layer should behave in training mode or in inference mode. Only relevant when `dropout` or `recurrent_dropout` is used.
Examples:
inputs = np.random.random([32, 10, 8]).astype(np.float32)
rnn = tf.keras.layers.RNN(tf.keras.layers.SimpleRNNCell(4))
output = rnn(inputs)  # The output has shape `[32, 4]`.

rnn = tf.keras.layers.RNN(
    tf.keras.layers.SimpleRNNCell(4),
    return_sequences=True,
    return_state=True)

# whole_sequence_output has shape `[32, 10, 4]`.
# final_state has shape `[32, 4]`.
whole_sequence_output, final_state = rnn(inputs)
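Because the cell handles a single timestep, it can also be stepped manually outside the RNN wrapper. A rough sketch, assuming a batch of 32, 8 input features, and 4 units (all illustrative choices):

```python
import numpy as np
import tensorflow as tf

cell = tf.keras.layers.SimpleRNNCell(4)

# A single timestep of input with shape [batch, feature].
x_t = tf.constant(np.random.random([32, 8]).astype(np.float32))

# Initial state of shape [batch, units], passed as a list of one tensor.
states_t = [tf.zeros([32, 4])]

# Each call consumes one timestep and returns (output, new_states).
for _ in range(10):
    output_t, states_t = cell(x_t, states_t)

# output_t has shape [32, 4]; states_t is a list with one [32, 4] tensor.
```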
Expand source code
class SimpleRNNCell(DropoutRNNCellMixin, Layer): """Cell class for SimpleRNN. See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) for details about the usage of RNN API. This class processes one step within the whole time sequence input, whereas `tf.keras.layer.SimpleRNN` processes the whole sequence. Args: units: Positive integer, dimensionality of the output space. activation: Activation function to use. Default: hyperbolic tangent (`tanh`). If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, (default `True`), whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. Default: `glorot_uniform`. recurrent_initializer: Initializer for the `recurrent_kernel` weights matrix, used for the linear transformation of the recurrent state. Default: `orthogonal`. bias_initializer: Initializer for the bias vector. Default: `zeros`. kernel_regularizer: Regularizer function applied to the `kernel` weights matrix. Default: `None`. recurrent_regularizer: Regularizer function applied to the `recurrent_kernel` weights matrix. Default: `None`. bias_regularizer: Regularizer function applied to the bias vector. Default: `None`. kernel_constraint: Constraint function applied to the `kernel` weights matrix. Default: `None`. recurrent_constraint: Constraint function applied to the `recurrent_kernel` weights matrix. Default: `None`. bias_constraint: Constraint function applied to the bias vector. Default: `None`. dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. Default: 0. recurrent_dropout: Float between 0 and 1. Fraction of the units to drop for the linear transformation of the recurrent state. Default: 0. Call arguments: inputs: A 2D tensor, with shape of `[batch, feature]`. states: A 2D tensor with shape of `[batch, units]`, which is the state from the previous time step. For timestep 0, the initial state provided by user will be feed to cell. training: Python boolean indicating whether the layer should behave in training mode or in inference mode. Only relevant when `dropout` or `recurrent_dropout` is used. Examples: ```python inputs = np.random.random([32, 10, 8]).astype(np.float32) rnn = tf.keras.layers.RNN(tf.keras.layers.SimpleRNNCell(4)) output = rnn(inputs) # The output has shape `[32, 4]`. rnn = tf.keras.layers.RNN( tf.keras.layers.SimpleRNNCell(4), return_sequences=True, return_state=True) # whole_sequence_output has shape `[32, 10, 4]`. # final_state has shape `[32, 4]`. whole_sequence_output, final_state = rnn(inputs) ``` """ def __init__(self, units, activation='tanh', use_bias=True, kernel_initializer='glorot_uniform', recurrent_initializer='orthogonal', bias_initializer='zeros', kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, kernel_constraint=None, recurrent_constraint=None, bias_constraint=None, dropout=0., recurrent_dropout=0., **kwargs): if units < 0: raise ValueError(f'Received an invalid value for units, expected ' f'a positive integer, got {units}.') # By default use cached variable under v2 mode, see b/143699808. 
if tf.compat.v1.executing_eagerly_outside_functions(): self._enable_caching_device = kwargs.pop('enable_caching_device', True) else: self._enable_caching_device = kwargs.pop('enable_caching_device', False) super(SimpleRNNCell, self).__init__(**kwargs) self.units = units self.activation = activations.get(activation) self.use_bias = use_bias self.kernel_initializer = initializers.get(kernel_initializer) self.recurrent_initializer = initializers.get(recurrent_initializer) self.bias_initializer = initializers.get(bias_initializer) self.kernel_regularizer = regularizers.get(kernel_regularizer) self.recurrent_regularizer = regularizers.get(recurrent_regularizer) self.bias_regularizer = regularizers.get(bias_regularizer) self.kernel_constraint = constraints.get(kernel_constraint) self.recurrent_constraint = constraints.get(recurrent_constraint) self.bias_constraint = constraints.get(bias_constraint) self.dropout = min(1., max(0., dropout)) self.recurrent_dropout = min(1., max(0., recurrent_dropout)) self.state_size = self.units self.output_size = self.units @tf_utils.shape_type_conversion def build(self, input_shape): default_caching_device = _caching_device(self) self.kernel = self.add_weight( shape=(input_shape[-1], self.units), name='kernel', initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, constraint=self.kernel_constraint, caching_device=default_caching_device) self.recurrent_kernel = self.add_weight( shape=(self.units, self.units), name='recurrent_kernel', initializer=self.recurrent_initializer, regularizer=self.recurrent_regularizer, constraint=self.recurrent_constraint, caching_device=default_caching_device) if self.use_bias: self.bias = self.add_weight( shape=(self.units,), name='bias', initializer=self.bias_initializer, regularizer=self.bias_regularizer, constraint=self.bias_constraint, caching_device=default_caching_device) else: self.bias = None self.built = True def call(self, inputs, states, training=None): prev_output = states[0] if tf.nest.is_nested(states) else states dp_mask = self.get_dropout_mask_for_cell(inputs, training) rec_dp_mask = self.get_recurrent_dropout_mask_for_cell( prev_output, training) if dp_mask is not None: h = backend.dot(inputs * dp_mask, self.kernel) else: h = backend.dot(inputs, self.kernel) if self.bias is not None: h = backend.bias_add(h, self.bias) if rec_dp_mask is not None: prev_output = prev_output * rec_dp_mask output = h + backend.dot(prev_output, self.recurrent_kernel) if self.activation is not None: output = self.activation(output) new_state = [output] if tf.nest.is_nested(states) else output return output, new_state def get_initial_state(self, inputs=None, batch_size=None, dtype=None): return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype) def get_config(self): config = { 'units': self.units, 'activation': activations.serialize(self.activation), 'use_bias': self.use_bias, 'kernel_initializer': initializers.serialize(self.kernel_initializer), 'recurrent_initializer': initializers.serialize(self.recurrent_initializer), 'bias_initializer': initializers.serialize(self.bias_initializer), 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer), 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 'kernel_constraint': constraints.serialize(self.kernel_constraint), 'recurrent_constraint': constraints.serialize(self.recurrent_constraint), 'bias_constraint': constraints.serialize(self.bias_constraint), 
'dropout': self.dropout, 'recurrent_dropout': self.recurrent_dropout } config.update(_config_for_enable_caching_device(self)) base_config = super(SimpleRNNCell, self).get_config() return dict(list(base_config.items()) + list(config.items()))
Ancestors
- DropoutRNNCellMixin
- Layer
- tensorflow.python.module.module.Module
- tensorflow.python.training.tracking.tracking.AutoTrackable
- tensorflow.python.training.tracking.base.Trackable
- LayerVersionSelector
Methods
def get_initial_state(self, inputs=None, batch_size=None, dtype=None)
-
Expand source code
def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
  return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)
Inherited members
DropoutRNNCellMixin
:Layer
:activity_regularizer
add_loss
add_metric
add_update
add_variable
add_weight
apply
compute_dtype
compute_mask
compute_output_shape
compute_output_signature
count_params
dtype
dtype_policy
dynamic
finalize_state
from_config
get_input_at
get_input_mask_at
get_input_shape_at
get_losses_for
get_output_at
get_output_mask_at
get_output_shape_at
get_updates_for
get_weights
inbound_nodes
input
input_mask
input_shape
input_spec
losses
metrics
name
non_trainable_variables
non_trainable_weights
outbound_nodes
output
output_mask
output_shape
set_weights
supports_masking
trainable_variables
trainable_weights
variable_dtype
variables
weights
class StackedRNNCells (cells, **kwargs)
-
Wrapper allowing a stack of RNN cells to behave as a single cell.
Used to implement efficient stacked RNNs.
Args
cells
- List of RNN cell instances.
Examples:
batch_size = 3
sentence_max_length = 5
n_features = 2
new_shape = (batch_size, sentence_max_length, n_features)
x = tf.constant(np.reshape(np.arange(30), new_shape), dtype=tf.float32)

rnn_cells = [tf.keras.layers.LSTMCell(128) for _ in range(2)]
stacked_lstm = tf.keras.layers.StackedRNNCells(rnn_cells)
lstm_layer = tf.keras.layers.RNN(stacked_lstm)
result = lstm_layer(x)
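As an additional sketch not taken from the original docs, the same stack also works with `return_state=True`; the shapes below assume the toy input and the two `LSTMCell(128)` instances from the example above:

```python
import numpy as np
import tensorflow as tf

x = tf.constant(np.reshape(np.arange(30), (3, 5, 2)), dtype=tf.float32)

rnn_cells = [tf.keras.layers.LSTMCell(128) for _ in range(2)]
stacked_lstm = tf.keras.layers.StackedRNNCells(rnn_cells)
lstm_layer = tf.keras.layers.RNN(stacked_lstm, return_state=True)

# The layer returns the output first, followed by the final states of the
# wrapped cells in their natural order.
output, *final_states = lstm_layer(x)
# output has shape [3, 128]; final_states contains the final h and c states
# of each LSTMCell, each tensor of shape [3, 128].
```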
Expand source code
class StackedRNNCells(Layer):
  """Wrapper allowing a stack of RNN cells to behave as a single cell.

  Used to implement efficient stacked RNNs.

  Args:
    cells: List of RNN cell instances.

  Examples:

  ```python
  batch_size = 3
  sentence_max_length = 5
  n_features = 2
  new_shape = (batch_size, sentence_max_length, n_features)
  x = tf.constant(np.reshape(np.arange(30), new_shape), dtype = tf.float32)

  rnn_cells = [tf.keras.layers.LSTMCell(128) for _ in range(2)]
  stacked_lstm = tf.keras.layers.StackedRNNCells(rnn_cells)
  lstm_layer = tf.keras.layers.RNN(stacked_lstm)
  result = lstm_layer(x)
  ```
  """

  def __init__(self, cells, **kwargs):
    for cell in cells:
      if not 'call' in dir(cell):
        raise ValueError('All cells must have a `call` method. '
                         'received cells:', cells)
      if not 'state_size' in dir(cell):
        raise ValueError('All cells must have a '
                         '`state_size` attribute. '
                         'received cells:', cells)
    self.cells = cells
    # reverse_state_order determines whether the state size will be in a reverse
    # order of the cells' state. User might want to set this to True to keep the
    # existing behavior. This is only useful when use RNN(return_state=True)
    # since the state will be returned as the same order of state_size.
    self.reverse_state_order = kwargs.pop('reverse_state_order', False)
    if self.reverse_state_order:
      logging.warning('reverse_state_order=True in StackedRNNCells will soon '
                      'be deprecated. Please update the code to work with the '
                      'natural order of states if you rely on the RNN states, '
                      'eg RNN(return_state=True).')
    super(StackedRNNCells, self).__init__(**kwargs)

  @property
  def state_size(self):
    return tuple(c.state_size for c in
                 (self.cells[::-1] if self.reverse_state_order else self.cells))

  @property
  def output_size(self):
    if getattr(self.cells[-1], 'output_size', None) is not None:
      return self.cells[-1].output_size
    elif _is_multiple_state(self.cells[-1].state_size):
      return self.cells[-1].state_size[0]
    else:
      return self.cells[-1].state_size

  def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
    initial_states = []
    for cell in self.cells[::-1] if self.reverse_state_order else self.cells:
      get_initial_state_fn = getattr(cell, 'get_initial_state', None)
      if get_initial_state_fn:
        initial_states.append(get_initial_state_fn(
            inputs=inputs, batch_size=batch_size, dtype=dtype))
      else:
        initial_states.append(_generate_zero_filled_state_for_cell(
            cell, inputs, batch_size, dtype))
    return tuple(initial_states)

  def call(self, inputs, states, constants=None, training=None, **kwargs):
    # Recover per-cell states.
    state_size = (self.state_size[::-1]
                  if self.reverse_state_order else self.state_size)
    nested_states = tf.nest.pack_sequence_as(state_size,
                                             tf.nest.flatten(states))

    # Call the cells in order and store the returned states.
    new_nested_states = []
    for cell, states in zip(self.cells, nested_states):
      states = states if tf.nest.is_nested(states) else [states]
      # TF cell does not wrap the state into list when there is only one state.
      is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None
      states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
      if generic_utils.has_arg(cell.call, 'training'):
        kwargs['training'] = training
      else:
        kwargs.pop('training', None)
      # Use the __call__ function for callable objects, eg layers, so that it
      # will have the proper name scopes for the ops, etc.
      cell_call_fn = cell.__call__ if callable(cell) else cell.call
      if generic_utils.has_arg(cell.call, 'constants'):
        inputs, states = cell_call_fn(inputs, states,
                                      constants=constants, **kwargs)
      else:
        inputs, states = cell_call_fn(inputs, states, **kwargs)
      new_nested_states.append(states)

    return inputs, tf.nest.pack_sequence_as(state_size,
                                            tf.nest.flatten(new_nested_states))

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    if isinstance(input_shape, list):
      input_shape = input_shape[0]
    for cell in self.cells:
      if isinstance(cell, Layer) and not cell.built:
        with backend.name_scope(cell.name):
          cell.build(input_shape)
          cell.built = True
      if getattr(cell, 'output_size', None) is not None:
        output_dim = cell.output_size
      elif _is_multiple_state(cell.state_size):
        output_dim = cell.state_size[0]
      else:
        output_dim = cell.state_size
      input_shape = tuple([input_shape[0]] +
                          tf.TensorShape(output_dim).as_list())
    self.built = True

  def get_config(self):
    cells = []
    for cell in self.cells:
      cells.append(generic_utils.serialize_keras_object(cell))
    config = {'cells': cells}
    base_config = super(StackedRNNCells, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    from keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
    cells = []
    for cell_config in config.pop('cells'):
      cells.append(
          deserialize_layer(cell_config, custom_objects=custom_objects))
    return cls(cells, **config)
Ancestors
- Layer
- tensorflow.python.module.module.Module
- tensorflow.python.training.tracking.tracking.AutoTrackable
- tensorflow.python.training.tracking.base.Trackable
- LayerVersionSelector
Instance variables
var output_size
-
Expand source code
@property
def output_size(self):
  if getattr(self.cells[-1], 'output_size', None) is not None:
    return self.cells[-1].output_size
  elif _is_multiple_state(self.cells[-1].state_size):
    return self.cells[-1].state_size[0]
  else:
    return self.cells[-1].state_size
var state_size
-
Expand source code
@property
def state_size(self):
  return tuple(c.state_size for c in
               (self.cells[::-1] if self.reverse_state_order else self.cells))
Methods
def get_initial_state(self, inputs=None, batch_size=None, dtype=None)
-
Expand source code
def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
  initial_states = []
  for cell in self.cells[::-1] if self.reverse_state_order else self.cells:
    get_initial_state_fn = getattr(cell, 'get_initial_state', None)
    if get_initial_state_fn:
      initial_states.append(get_initial_state_fn(
          inputs=inputs, batch_size=batch_size, dtype=dtype))
    else:
      initial_states.append(_generate_zero_filled_state_for_cell(
          cell, inputs, batch_size, dtype))
  return tuple(initial_states)
Inherited members
Layer
:activity_regularizer
add_loss
add_metric
add_update
add_variable
add_weight
apply
build
call
compute_dtype
compute_mask
compute_output_shape
compute_output_signature
count_params
dtype
dtype_policy
dynamic
finalize_state
from_config
get_config
get_input_at
get_input_mask_at
get_input_shape_at
get_losses_for
get_output_at
get_output_mask_at
get_output_shape_at
get_updates_for
get_weights
inbound_nodes
input
input_mask
input_shape
input_spec
losses
metrics
name
non_trainable_variables
non_trainable_weights
outbound_nodes
output
output_mask
output_shape
set_weights
supports_masking
trainable_variables
trainable_weights
variable_dtype
variables
weights