Module keras.distribute.keras_stateful_lstm_model_correctness_test
Tests for stateful tf.keras LSTM models using DistributionStrategy.
Expand source code
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for stateful tf.keras LSTM models using DistributionStrategy."""
import tensorflow.compat.v2 as tf
import numpy as np
import keras
from keras.distribute import keras_correctness_test_base
from keras.optimizer_v2 import gradient_descent as gradient_descent_keras
def strategies_for_stateful_embedding_model():
  """Lists the distribution strategies used for the stateful embedding model.

  Returns:
    A one-element list holding the `tpu_strategy_one_core` combination, i.e.
    a TPUStrategy with a single-core device assignment.
  """
  strategy_combos = tf.__internal__.distribute.combinations
  return [strategy_combos.tpu_strategy_one_core]
def test_combinations_for_stateful_embedding_model():
  """Builds the parameter combinations for the stateful LSTM correctness test.

  Returns:
    The result of `tf.__internal__.test.combinations.combine` over the
    strategies from `strategies_for_stateful_embedding_model()`, restricted to
    graph mode with `use_numpy=False` and `use_validation_data=False`.
  """
  combine = tf.__internal__.test.combinations.combine
  return combine(
      distribution=strategies_for_stateful_embedding_model(),
      mode='graph',
      use_numpy=False,
      use_validation_data=False)
class DistributionStrategyStatefulLstmModelCorrectnessTest(
    keras_correctness_test_base
    .TestDistributionStrategyEmbeddingModelCorrectnessBase):
  """Correctness tests for a stateful LSTM embedding model under distribution.

  The model under test is Embedding -> stateful LSTM -> softmax Dense. Its
  input is built with a fixed batch size of `_GLOBAL_BATCH_SIZE`.
  """

  def get_model(self,
                max_words=10,
                initial_weights=None,
                distribution=None,
                input_shapes=None):
    """Builds and compiles the stateful LSTM classifier.

    Args:
      max_words: Length of each input sequence of word ids.
      initial_weights: Optional weights passed to `Model.set_weights`; they
        are only applied when the value is truthy.
      distribution: Optional strategy. The model is built and compiled inside
        `MaybeDistributionScope(distribution)`.
      input_shapes: Unused; accepted to match the base-class interface.

    Returns:
      A compiled `keras.Model` that maps int32 word ids of shape
      `(_GLOBAL_BATCH_SIZE, max_words)` to a 2-way softmax.
    """
    del input_shapes  # Unused.
    global_batch = keras_correctness_test_base._GLOBAL_BATCH_SIZE
    with keras_correctness_test_base.MaybeDistributionScope(distribution):
      # Stateful RNN layers need a static batch dimension, so pin it here.
      token_ids = keras.layers.Input(
          shape=(max_words,),
          batch_size=global_batch,
          dtype=np.int32,
          name='words')
      embedded = keras.layers.Embedding(input_dim=20, output_dim=10)(token_ids)
      final_state = keras.layers.LSTM(
          units=4, return_sequences=False, stateful=True)(embedded)
      probabilities = keras.layers.Dense(2, activation='softmax')(final_state)
      model = keras.Model(inputs=[token_ids], outputs=[probabilities])

      if initial_weights:
        model.set_weights(initial_weights)

      model.compile(
          optimizer=gradient_descent_keras.SGD(learning_rate=0.1),
          loss='sparse_categorical_crossentropy',
          metrics=['sparse_categorical_accuracy'])
    return model

  # TODO(jhseu): Disabled to fix b/130808953. Need to investigate why it
  # doesn't work and enable for DistributionStrategy more generally.
  # NOTE(review): the name lacks a leading `test` prefix, so these combinations
  # are never expanded into runnable test methods.
  @tf.__internal__.distribute.combinations.generate(
      test_combinations_for_stateful_embedding_model())
  def disabled_test_stateful_lstm_model_correctness(
      self, distribution, use_numpy, use_validation_data):
    """Runs the base-class correctness check on the stateful model."""
    self.run_correctness_test(
        distribution,
        use_numpy,
        use_validation_data,
        is_stateful_model=True)

  @tf.__internal__.distribute.combinations.generate(
      tf.__internal__.test.combinations.times(
          keras_correctness_test_base
          .test_combinations_with_tpu_strategies_graph()))
  def test_incorrectly_use_multiple_cores_for_stateful_lstm_model(
      self, distribution, use_numpy, use_validation_data):
    """Expects the stateful model to be rejected on the TPU strategies."""
    # Used as a regex by assertRaisesRegex (its dots match any character).
    expected_message = ('RNNs with stateful=True not yet supported with '
                        'tf.distribute.Strategy.')
    with self.assertRaisesRegex(ValueError, expected_message):
      self.run_correctness_test(
          distribution,
          use_numpy,
          use_validation_data,
          is_stateful_model=True)
if __name__ == "__main__":
  # Only when executed as a script: hand control to TensorFlow's test runner.
  tf.test.main()
Functions
def strategies_for_stateful_embedding_model()
-
Returns a one-element list holding the TPUStrategy combination with a single-core device assignment.
Expand source code
def strategies_for_stateful_embedding_model():
  """Lists the distribution strategies used for the stateful embedding model.

  Returns:
    A one-element list holding the `tpu_strategy_one_core` combination, i.e.
    a TPUStrategy with a single-core device assignment.
  """
  strategy_combos = tf.__internal__.distribute.combinations
  return [strategy_combos.tpu_strategy_one_core]
def test_combinations_for_stateful_embedding_model()
-
Expand source code
def test_combinations_for_stateful_embedding_model():
  """Builds the parameter combinations for the stateful LSTM correctness test.

  Returns:
    The result of `tf.__internal__.test.combinations.combine` over the
    strategies from `strategies_for_stateful_embedding_model()`, restricted to
    graph mode with `use_numpy=False` and `use_validation_data=False`.
  """
  combine = tf.__internal__.test.combinations.combine
  return combine(
      distribution=strategies_for_stateful_embedding_model(),
      mode='graph',
      use_numpy=False,
      use_validation_data=False)
Classes
class DistributionStrategyStatefulLstmModelCorrectnessTest (methodName='runTest')
-
Tests the correctness of a stateful LSTM Keras model with an embedding layer. (The text above is inherited from the embedding-model correctness base class.)
Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.
Expand source code
class DistributionStrategyStatefulLstmModelCorrectnessTest(
    keras_correctness_test_base
    .TestDistributionStrategyEmbeddingModelCorrectnessBase):
  """Correctness tests for a stateful LSTM embedding model under distribution.

  The model under test is Embedding -> stateful LSTM -> softmax Dense. Its
  input is built with a fixed batch size of `_GLOBAL_BATCH_SIZE`.
  """

  def get_model(self,
                max_words=10,
                initial_weights=None,
                distribution=None,
                input_shapes=None):
    """Builds and compiles the stateful LSTM classifier.

    Args:
      max_words: Length of each input sequence of word ids.
      initial_weights: Optional weights passed to `Model.set_weights`; they
        are only applied when the value is truthy.
      distribution: Optional strategy. The model is built and compiled inside
        `MaybeDistributionScope(distribution)`.
      input_shapes: Unused; accepted to match the base-class interface.

    Returns:
      A compiled `keras.Model` that maps int32 word ids of shape
      `(_GLOBAL_BATCH_SIZE, max_words)` to a 2-way softmax.
    """
    del input_shapes  # Unused.
    global_batch = keras_correctness_test_base._GLOBAL_BATCH_SIZE
    with keras_correctness_test_base.MaybeDistributionScope(distribution):
      # Stateful RNN layers need a static batch dimension, so pin it here.
      token_ids = keras.layers.Input(
          shape=(max_words,),
          batch_size=global_batch,
          dtype=np.int32,
          name='words')
      embedded = keras.layers.Embedding(input_dim=20, output_dim=10)(token_ids)
      final_state = keras.layers.LSTM(
          units=4, return_sequences=False, stateful=True)(embedded)
      probabilities = keras.layers.Dense(2, activation='softmax')(final_state)
      model = keras.Model(inputs=[token_ids], outputs=[probabilities])

      if initial_weights:
        model.set_weights(initial_weights)

      model.compile(
          optimizer=gradient_descent_keras.SGD(learning_rate=0.1),
          loss='sparse_categorical_crossentropy',
          metrics=['sparse_categorical_accuracy'])
    return model

  # TODO(jhseu): Disabled to fix b/130808953. Need to investigate why it
  # doesn't work and enable for DistributionStrategy more generally.
  # NOTE(review): the name lacks a leading `test` prefix, so these combinations
  # are never expanded into runnable test methods.
  @tf.__internal__.distribute.combinations.generate(
      test_combinations_for_stateful_embedding_model())
  def disabled_test_stateful_lstm_model_correctness(
      self, distribution, use_numpy, use_validation_data):
    """Runs the base-class correctness check on the stateful model."""
    self.run_correctness_test(
        distribution,
        use_numpy,
        use_validation_data,
        is_stateful_model=True)

  @tf.__internal__.distribute.combinations.generate(
      tf.__internal__.test.combinations.times(
          keras_correctness_test_base
          .test_combinations_with_tpu_strategies_graph()))
  def test_incorrectly_use_multiple_cores_for_stateful_lstm_model(
      self, distribution, use_numpy, use_validation_data):
    """Expects the stateful model to be rejected on the TPU strategies."""
    # Used as a regex by assertRaisesRegex (its dots match any character).
    expected_message = ('RNNs with stateful=True not yet supported with '
                        'tf.distribute.Strategy.')
    with self.assertRaisesRegex(ValueError, expected_message):
      self.run_correctness_test(
          distribution,
          use_numpy,
          use_validation_data,
          is_stateful_model=True)
Ancestors
- TestDistributionStrategyEmbeddingModelCorrectnessBase
- TestDistributionStrategyCorrectnessBase
- tensorflow.python.framework.test_util.TensorFlowTestCase
- absl.testing.parameterized.TestCase
- absl.testing.absltest.TestCase
- absl.third_party.unittest3_backport.case.TestCase
- unittest.case.TestCase
Class variables
var disabled_test_stateful_lstm_model_correctness
Methods
def get_model(self, max_words=10, initial_weights=None, distribution=None, input_shapes=None)
-
Expand source code
def get_model(self,
              max_words=10,
              initial_weights=None,
              distribution=None,
              input_shapes=None):
  """Builds and compiles the stateful LSTM classifier.

  Args:
    max_words: Length of each input sequence of word ids.
    initial_weights: Optional weights passed to `Model.set_weights`; they
      are only applied when the value is truthy.
    distribution: Optional strategy. The model is built and compiled inside
      `MaybeDistributionScope(distribution)`.
    input_shapes: Unused; accepted to match the base-class interface.

  Returns:
    A compiled `keras.Model` that maps int32 word ids of shape
    `(_GLOBAL_BATCH_SIZE, max_words)` to a 2-way softmax.
  """
  del input_shapes  # Unused.
  global_batch = keras_correctness_test_base._GLOBAL_BATCH_SIZE
  with keras_correctness_test_base.MaybeDistributionScope(distribution):
    # Stateful RNN layers need a static batch dimension, so pin it here.
    token_ids = keras.layers.Input(
        shape=(max_words,),
        batch_size=global_batch,
        dtype=np.int32,
        name='words')
    embedded = keras.layers.Embedding(input_dim=20, output_dim=10)(token_ids)
    final_state = keras.layers.LSTM(
        units=4, return_sequences=False, stateful=True)(embedded)
    probabilities = keras.layers.Dense(2, activation='softmax')(final_state)
    model = keras.Model(inputs=[token_ids], outputs=[probabilities])

    if initial_weights:
      model.set_weights(initial_weights)

    model.compile(
        optimizer=gradient_descent_keras.SGD(learning_rate=0.1),
        loss='sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy'])
  return model
def test_incorrectly_use_multiple_cores_for_stateful_lstm_model_test_distribution_TPU_mode_graph_usenumpy_False_usevalidationdata_False(self, **kwargs)
-
A wrapped test method that can treat some arguments in a special way.
Expand source code
def decorated(self, **kwargs):
  """A wrapped test method that can treat some arguments in a special way.

  Runs one parameter combination of a generated test. It first skips the test
  if any combination reports that it should not execute here. It then applies
  every combination's parameter modifiers to `kwargs` and checks that the
  resulting keyword arguments exactly match the wrapped test method's
  parameters. Finally, it calls the test method inside every context manager
  the combinations request.

  NOTE(review): `test_combinations`, `test_method`, `tf_inspect`,
  `ParameterModifier` and `contextlib` are free variables. They resolve from
  the enclosing decorator/module scope, which this page does not show.

  Args:
    self: The test case instance the generated test method is bound to.
    **kwargs: The parameter values of this combination.

  Raises:
    ValueError: If the test method declares a parameter that was not
      supplied, or is supplied an argument it does not declare.
  """
  # Keep an untouched copy. `kwargs` itself is mutated inside
  # `execute_test_method`, but the combination hooks always receive a fresh
  # copy of the original values.
  original_kwargs = kwargs.copy()

  # Skip combinations that are going to be executed in a different testing
  # environment.
  reasons_to_skip = []
  for combination in test_combinations:
    should_execute, reason = combination.should_execute_combination(
        original_kwargs.copy())
    if not should_execute:
      reasons_to_skip.append(" - " + reason)

  if reasons_to_skip:
    # Report every collected reason at once. `skipTest` raises, so nothing
    # below runs for a skipped combination.
    self.skipTest("\n".join(reasons_to_skip))

  # Gather the de-duplicated set of parameter modifiers from all combinations.
  customized_parameters = []
  for combination in test_combinations:
    customized_parameters.extend(combination.parameter_modifiers())
  customized_parameters = set(customized_parameters)

  # The function for running the test under the total set of
  # `context_managers`:
  def execute_test_method():
    requested_parameters = tf_inspect.getfullargspec(test_method).args
    # Let each modifier rewrite or drop arguments before the test sees them.
    for customized_parameter in customized_parameters:
      for argument, value in customized_parameter.modified_arguments(
          original_kwargs.copy(), requested_parameters).items():
        if value is ParameterModifier.DO_NOT_PASS_TO_THE_TEST:
          kwargs.pop(argument, None)
        else:
          kwargs[argument] = value

    # Every parameter the test declares must have been supplied ...
    omitted_arguments = set(requested_parameters).difference(
        set(list(kwargs.keys()) + ["self"]))
    if omitted_arguments:
      raise ValueError("The test requires parameters whose arguments "
                       "were not passed: {} .".format(omitted_arguments))
    # ... and every supplied argument must be a declared parameter. Despite
    # its name, `missing_arguments` holds the unexpected extras.
    missing_arguments = set(list(kwargs.keys()) + ["self"]).difference(
        set(requested_parameters))
    if missing_arguments:
      raise ValueError("The test does not take parameters that were passed "
                       ": {} .".format(missing_arguments))

    # Map each declared parameter to its value; `self` is the test case.
    kwargs_to_pass = {}
    for parameter in requested_parameters:
      if parameter == "self":
        kwargs_to_pass[parameter] = self
      else:
        kwargs_to_pass[parameter] = kwargs[parameter]
    test_method(**kwargs_to_pass)

  # Install `context_managers` before running the test:
  context_managers = []
  for combination in test_combinations:
    for manager in combination.context_managers(
        original_kwargs.copy()):
      context_managers.append(manager)

  if hasattr(contextlib, "nested"):  # Python 2
    # TODO(isaprykin): Switch to ExitStack when contextlib2 is available.
    with contextlib.nested(*context_managers):
      execute_test_method()
  else:  # Python 3
    with contextlib.ExitStack() as context_stack:
      for manager in context_managers:
        context_stack.enter_context(manager)
      execute_test_method()
def test_incorrectly_use_multiple_cores_for_stateful_lstm_model_test_distribution_TPU_mode_graph_usenumpy_False_usevalidationdata_True(self, **kwargs)
-
A wrapped test method that can treat some arguments in a special way.
Expand source code
def decorated(self, **kwargs):
  """A wrapped test method that can treat some arguments in a special way.

  Runs one parameter combination of a generated test. It first skips the test
  if any combination reports that it should not execute here. It then applies
  every combination's parameter modifiers to `kwargs` and checks that the
  resulting keyword arguments exactly match the wrapped test method's
  parameters. Finally, it calls the test method inside every context manager
  the combinations request.

  NOTE(review): `test_combinations`, `test_method`, `tf_inspect`,
  `ParameterModifier` and `contextlib` are free variables. They resolve from
  the enclosing decorator/module scope, which this page does not show.

  Args:
    self: The test case instance the generated test method is bound to.
    **kwargs: The parameter values of this combination.

  Raises:
    ValueError: If the test method declares a parameter that was not
      supplied, or is supplied an argument it does not declare.
  """
  # Keep an untouched copy. `kwargs` itself is mutated inside
  # `execute_test_method`, but the combination hooks always receive a fresh
  # copy of the original values.
  original_kwargs = kwargs.copy()

  # Skip combinations that are going to be executed in a different testing
  # environment.
  reasons_to_skip = []
  for combination in test_combinations:
    should_execute, reason = combination.should_execute_combination(
        original_kwargs.copy())
    if not should_execute:
      reasons_to_skip.append(" - " + reason)

  if reasons_to_skip:
    # Report every collected reason at once. `skipTest` raises, so nothing
    # below runs for a skipped combination.
    self.skipTest("\n".join(reasons_to_skip))

  # Gather the de-duplicated set of parameter modifiers from all combinations.
  customized_parameters = []
  for combination in test_combinations:
    customized_parameters.extend(combination.parameter_modifiers())
  customized_parameters = set(customized_parameters)

  # The function for running the test under the total set of
  # `context_managers`:
  def execute_test_method():
    requested_parameters = tf_inspect.getfullargspec(test_method).args
    # Let each modifier rewrite or drop arguments before the test sees them.
    for customized_parameter in customized_parameters:
      for argument, value in customized_parameter.modified_arguments(
          original_kwargs.copy(), requested_parameters).items():
        if value is ParameterModifier.DO_NOT_PASS_TO_THE_TEST:
          kwargs.pop(argument, None)
        else:
          kwargs[argument] = value

    # Every parameter the test declares must have been supplied ...
    omitted_arguments = set(requested_parameters).difference(
        set(list(kwargs.keys()) + ["self"]))
    if omitted_arguments:
      raise ValueError("The test requires parameters whose arguments "
                       "were not passed: {} .".format(omitted_arguments))
    # ... and every supplied argument must be a declared parameter. Despite
    # its name, `missing_arguments` holds the unexpected extras.
    missing_arguments = set(list(kwargs.keys()) + ["self"]).difference(
        set(requested_parameters))
    if missing_arguments:
      raise ValueError("The test does not take parameters that were passed "
                       ": {} .".format(missing_arguments))

    # Map each declared parameter to its value; `self` is the test case.
    kwargs_to_pass = {}
    for parameter in requested_parameters:
      if parameter == "self":
        kwargs_to_pass[parameter] = self
      else:
        kwargs_to_pass[parameter] = kwargs[parameter]
    test_method(**kwargs_to_pass)

  # Install `context_managers` before running the test:
  context_managers = []
  for combination in test_combinations:
    for manager in combination.context_managers(
        original_kwargs.copy()):
      context_managers.append(manager)

  if hasattr(contextlib, "nested"):  # Python 2
    # TODO(isaprykin): Switch to ExitStack when contextlib2 is available.
    with contextlib.nested(*context_managers):
      execute_test_method()
  else:  # Python 3
    with contextlib.ExitStack() as context_stack:
      for manager in context_managers:
        context_stack.enter_context(manager)
      execute_test_method()
def test_incorrectly_use_multiple_cores_for_stateful_lstm_model_test_distribution_TPU_mode_graph_usenumpy_True_usevalidationdata_False(self, **kwargs)
-
A wrapped test method that can treat some arguments in a special way.
Expand source code
def decorated(self, **kwargs):
  """A wrapped test method that can treat some arguments in a special way.

  Runs one parameter combination of a generated test. It first skips the test
  if any combination reports that it should not execute here. It then applies
  every combination's parameter modifiers to `kwargs` and checks that the
  resulting keyword arguments exactly match the wrapped test method's
  parameters. Finally, it calls the test method inside every context manager
  the combinations request.

  NOTE(review): `test_combinations`, `test_method`, `tf_inspect`,
  `ParameterModifier` and `contextlib` are free variables. They resolve from
  the enclosing decorator/module scope, which this page does not show.

  Args:
    self: The test case instance the generated test method is bound to.
    **kwargs: The parameter values of this combination.

  Raises:
    ValueError: If the test method declares a parameter that was not
      supplied, or is supplied an argument it does not declare.
  """
  # Keep an untouched copy. `kwargs` itself is mutated inside
  # `execute_test_method`, but the combination hooks always receive a fresh
  # copy of the original values.
  original_kwargs = kwargs.copy()

  # Skip combinations that are going to be executed in a different testing
  # environment.
  reasons_to_skip = []
  for combination in test_combinations:
    should_execute, reason = combination.should_execute_combination(
        original_kwargs.copy())
    if not should_execute:
      reasons_to_skip.append(" - " + reason)

  if reasons_to_skip:
    # Report every collected reason at once. `skipTest` raises, so nothing
    # below runs for a skipped combination.
    self.skipTest("\n".join(reasons_to_skip))

  # Gather the de-duplicated set of parameter modifiers from all combinations.
  customized_parameters = []
  for combination in test_combinations:
    customized_parameters.extend(combination.parameter_modifiers())
  customized_parameters = set(customized_parameters)

  # The function for running the test under the total set of
  # `context_managers`:
  def execute_test_method():
    requested_parameters = tf_inspect.getfullargspec(test_method).args
    # Let each modifier rewrite or drop arguments before the test sees them.
    for customized_parameter in customized_parameters:
      for argument, value in customized_parameter.modified_arguments(
          original_kwargs.copy(), requested_parameters).items():
        if value is ParameterModifier.DO_NOT_PASS_TO_THE_TEST:
          kwargs.pop(argument, None)
        else:
          kwargs[argument] = value

    # Every parameter the test declares must have been supplied ...
    omitted_arguments = set(requested_parameters).difference(
        set(list(kwargs.keys()) + ["self"]))
    if omitted_arguments:
      raise ValueError("The test requires parameters whose arguments "
                       "were not passed: {} .".format(omitted_arguments))
    # ... and every supplied argument must be a declared parameter. Despite
    # its name, `missing_arguments` holds the unexpected extras.
    missing_arguments = set(list(kwargs.keys()) + ["self"]).difference(
        set(requested_parameters))
    if missing_arguments:
      raise ValueError("The test does not take parameters that were passed "
                       ": {} .".format(missing_arguments))

    # Map each declared parameter to its value; `self` is the test case.
    kwargs_to_pass = {}
    for parameter in requested_parameters:
      if parameter == "self":
        kwargs_to_pass[parameter] = self
      else:
        kwargs_to_pass[parameter] = kwargs[parameter]
    test_method(**kwargs_to_pass)

  # Install `context_managers` before running the test:
  context_managers = []
  for combination in test_combinations:
    for manager in combination.context_managers(
        original_kwargs.copy()):
      context_managers.append(manager)

  if hasattr(contextlib, "nested"):  # Python 2
    # TODO(isaprykin): Switch to ExitStack when contextlib2 is available.
    with contextlib.nested(*context_managers):
      execute_test_method()
  else:  # Python 3
    with contextlib.ExitStack() as context_stack:
      for manager in context_managers:
        context_stack.enter_context(manager)
      execute_test_method()
def test_incorrectly_use_multiple_cores_for_stateful_lstm_model_test_distribution_TPU_mode_graph_usenumpy_True_usevalidationdata_True(self, **kwargs)
-
A wrapped test method that can treat some arguments in a special way.
Expand source code
def decorated(self, **kwargs):
  """A wrapped test method that can treat some arguments in a special way.

  Runs one parameter combination of a generated test. It first skips the test
  if any combination reports that it should not execute here. It then applies
  every combination's parameter modifiers to `kwargs` and checks that the
  resulting keyword arguments exactly match the wrapped test method's
  parameters. Finally, it calls the test method inside every context manager
  the combinations request.

  NOTE(review): `test_combinations`, `test_method`, `tf_inspect`,
  `ParameterModifier` and `contextlib` are free variables. They resolve from
  the enclosing decorator/module scope, which this page does not show.

  Args:
    self: The test case instance the generated test method is bound to.
    **kwargs: The parameter values of this combination.

  Raises:
    ValueError: If the test method declares a parameter that was not
      supplied, or is supplied an argument it does not declare.
  """
  # Keep an untouched copy. `kwargs` itself is mutated inside
  # `execute_test_method`, but the combination hooks always receive a fresh
  # copy of the original values.
  original_kwargs = kwargs.copy()

  # Skip combinations that are going to be executed in a different testing
  # environment.
  reasons_to_skip = []
  for combination in test_combinations:
    should_execute, reason = combination.should_execute_combination(
        original_kwargs.copy())
    if not should_execute:
      reasons_to_skip.append(" - " + reason)

  if reasons_to_skip:
    # Report every collected reason at once. `skipTest` raises, so nothing
    # below runs for a skipped combination.
    self.skipTest("\n".join(reasons_to_skip))

  # Gather the de-duplicated set of parameter modifiers from all combinations.
  customized_parameters = []
  for combination in test_combinations:
    customized_parameters.extend(combination.parameter_modifiers())
  customized_parameters = set(customized_parameters)

  # The function for running the test under the total set of
  # `context_managers`:
  def execute_test_method():
    requested_parameters = tf_inspect.getfullargspec(test_method).args
    # Let each modifier rewrite or drop arguments before the test sees them.
    for customized_parameter in customized_parameters:
      for argument, value in customized_parameter.modified_arguments(
          original_kwargs.copy(), requested_parameters).items():
        if value is ParameterModifier.DO_NOT_PASS_TO_THE_TEST:
          kwargs.pop(argument, None)
        else:
          kwargs[argument] = value

    # Every parameter the test declares must have been supplied ...
    omitted_arguments = set(requested_parameters).difference(
        set(list(kwargs.keys()) + ["self"]))
    if omitted_arguments:
      raise ValueError("The test requires parameters whose arguments "
                       "were not passed: {} .".format(omitted_arguments))
    # ... and every supplied argument must be a declared parameter. Despite
    # its name, `missing_arguments` holds the unexpected extras.
    missing_arguments = set(list(kwargs.keys()) + ["self"]).difference(
        set(requested_parameters))
    if missing_arguments:
      raise ValueError("The test does not take parameters that were passed "
                       ": {} .".format(missing_arguments))

    # Map each declared parameter to its value; `self` is the test case.
    kwargs_to_pass = {}
    for parameter in requested_parameters:
      if parameter == "self":
        kwargs_to_pass[parameter] = self
      else:
        kwargs_to_pass[parameter] = kwargs[parameter]
    test_method(**kwargs_to_pass)

  # Install `context_managers` before running the test:
  context_managers = []
  for combination in test_combinations:
    for manager in combination.context_managers(
        original_kwargs.copy()):
      context_managers.append(manager)

  if hasattr(contextlib, "nested"):  # Python 2
    # TODO(isaprykin): Switch to ExitStack when contextlib2 is available.
    with contextlib.nested(*context_managers):
      execute_test_method()
  else:  # Python 3
    with contextlib.ExitStack() as context_stack:
      for manager in context_managers:
        context_stack.enter_context(manager)
      execute_test_method()
Inherited members