Module keras.distribute.keras_stateful_lstm_model_correctness_test

Tests for stateful tf.keras LSTM models using DistributionStrategy.

Expand source code
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for stateful tf.keras LSTM models using DistributionStrategy."""

import tensorflow.compat.v2 as tf

import numpy as np

import keras
from keras.distribute import keras_correctness_test_base
from keras.optimizer_v2 import gradient_descent as gradient_descent_keras


def strategies_for_stateful_embedding_model():
  """Lists the strategies the stateful embedding model is tested under.

  Returns:
    A single-element list holding the TPUStrategy combination with a
    single-core device assignment (`tpu_strategy_one_core`).
  """
  one_core_tpu = tf.__internal__.distribute.combinations.tpu_strategy_one_core
  return [one_core_tpu]


def test_combinations_for_stateful_embedding_model():
  """Builds the parameter grid for the stateful LSTM correctness test.

  Uses the strategies from `strategies_for_stateful_embedding_model()`, runs
  only in graph mode, and pins both `use_numpy` and `use_validation_data`
  to False.

  Returns:
    The combinations produced by `tf.__internal__.test.combinations.combine`.
  """
  combine = tf.__internal__.test.combinations.combine
  strategies = strategies_for_stateful_embedding_model()
  return combine(
      distribution=strategies,
      mode='graph',
      use_numpy=False,
      use_validation_data=False)


class DistributionStrategyStatefulLstmModelCorrectnessTest(
    keras_correctness_test_base
    .TestDistributionStrategyEmbeddingModelCorrectnessBase):
  """Correctness tests for a stateful LSTM model under DistributionStrategy.

  The model under test is Embedding -> LSTM(stateful=True) -> softmax Dense,
  built with a fixed global batch size on its input.
  """

  def get_model(self,
                max_words=10,
                initial_weights=None,
                distribution=None,
                input_shapes=None):
    """Builds and compiles the stateful LSTM classifier.

    Args:
      max_words: Length of each input sequence of int32 word ids.
      initial_weights: Optional weights; loaded with `set_weights` when truthy.
      distribution: Optional strategy; the model is built and compiled under
        `MaybeDistributionScope(distribution)`.
      input_shapes: Ignored.

    Returns:
      A compiled `keras.Model` with SGD (learning rate 0.1), sparse
      categorical cross-entropy loss and sparse categorical accuracy.
    """
    del input_shapes  # Not used by this model.
    global_batch = keras_correctness_test_base._GLOBAL_BATCH_SIZE

    with keras_correctness_test_base.MaybeDistributionScope(distribution):
      # The batch size is pinned on the Input because the LSTM is stateful.
      tokens = keras.layers.Input(
          shape=(max_words,),
          batch_size=global_batch,
          dtype=np.int32,
          name='words')
      embedded = keras.layers.Embedding(input_dim=20, output_dim=10)(tokens)
      encoded = keras.layers.LSTM(
          units=4, return_sequences=False, stateful=True)(embedded)
      probs = keras.layers.Dense(2, activation='softmax')(encoded)
      model = keras.Model(inputs=[tokens], outputs=[probs])

      if initial_weights:
        model.set_weights(initial_weights)

      model.compile(
          optimizer=gradient_descent_keras.SGD(learning_rate=0.1),
          loss='sparse_categorical_crossentropy',
          metrics=['sparse_categorical_accuracy'])
    return model

  # TODO(jhseu): Disabled to fix b/130808953. Need to investigate why it
  # doesn't work and enable for DistributionStrategy more generally.
  @tf.__internal__.distribute.combinations.generate(
      test_combinations_for_stateful_embedding_model())
  def disabled_test_stateful_lstm_model_correctness(
      self, distribution, use_numpy, use_validation_data):
    """Runs the base correctness check with `is_stateful_model=True`."""
    self.run_correctness_test(
        distribution, use_numpy, use_validation_data, is_stateful_model=True)

  @tf.__internal__.distribute.combinations.generate(
      tf.__internal__.test.combinations.times(
          keras_correctness_test_base
          .test_combinations_with_tpu_strategies_graph()))
  def test_incorrectly_use_multiple_cores_for_stateful_lstm_model(
      self, distribution, use_numpy, use_validation_data):
    """Expects a ValueError when the stateful model runs under these TPUs."""
    expected_message = ('RNNs with stateful=True not yet supported with '
                        'tf.distribute.Strategy.')
    with self.assertRaisesRegex(ValueError, expected_message):
      self.run_correctness_test(
          distribution, use_numpy, use_validation_data,
          is_stateful_model=True)


if __name__ == '__main__':
  # Run the tests in this module through TensorFlow's test runner when the
  # file is executed directly as a script.
  tf.test.main()

Functions

def strategies_for_stateful_embedding_model()

Returns a list containing the TPUStrategy combination with a single-core device assignment.

Expand source code
def strategies_for_stateful_embedding_model():
  """Lists the strategies the stateful embedding model is tested under.

  Returns:
    A single-element list holding the TPUStrategy combination with a
    single-core device assignment (`tpu_strategy_one_core`).
  """
  one_core_tpu = tf.__internal__.distribute.combinations.tpu_strategy_one_core
  return [one_core_tpu]
def test_combinations_for_stateful_embedding_model()
Expand source code
def test_combinations_for_stateful_embedding_model():
  """Builds the parameter grid for the stateful LSTM correctness test.

  Uses the strategies from `strategies_for_stateful_embedding_model()`, runs
  only in graph mode, and pins both `use_numpy` and `use_validation_data`
  to False.

  Returns:
    The combinations produced by `tf.__internal__.test.combinations.combine`.
  """
  combine = tf.__internal__.test.combinations.combine
  strategies = strategies_for_stateful_embedding_model()
  return combine(
      distribution=strategies,
      mode='graph',
      use_numpy=False,
      use_validation_data=False)

Classes

class DistributionStrategyStatefulLstmModelCorrectnessTest (methodName='runTest')

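Correctness tests for a stateful Keras LSTM model with an embedding layer under DistributionStrategy (this text is inherited from the base class for testing correctness of Keras models with embedding layers).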

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.

Expand source code
class DistributionStrategyStatefulLstmModelCorrectnessTest(
    keras_correctness_test_base
    .TestDistributionStrategyEmbeddingModelCorrectnessBase):
  """Correctness tests for a stateful LSTM model under DistributionStrategy.

  The model under test is Embedding -> LSTM(stateful=True) -> softmax Dense,
  built with a fixed global batch size on its input.
  """

  def get_model(self,
                max_words=10,
                initial_weights=None,
                distribution=None,
                input_shapes=None):
    """Builds and compiles the stateful LSTM classifier.

    Args:
      max_words: Length of each input sequence of int32 word ids.
      initial_weights: Optional weights; loaded with `set_weights` when truthy.
      distribution: Optional strategy; the model is built and compiled under
        `MaybeDistributionScope(distribution)`.
      input_shapes: Ignored.

    Returns:
      A compiled `keras.Model` with SGD (learning rate 0.1), sparse
      categorical cross-entropy loss and sparse categorical accuracy.
    """
    del input_shapes  # Not used by this model.
    global_batch = keras_correctness_test_base._GLOBAL_BATCH_SIZE

    with keras_correctness_test_base.MaybeDistributionScope(distribution):
      # The batch size is pinned on the Input because the LSTM is stateful.
      tokens = keras.layers.Input(
          shape=(max_words,),
          batch_size=global_batch,
          dtype=np.int32,
          name='words')
      embedded = keras.layers.Embedding(input_dim=20, output_dim=10)(tokens)
      encoded = keras.layers.LSTM(
          units=4, return_sequences=False, stateful=True)(embedded)
      probs = keras.layers.Dense(2, activation='softmax')(encoded)
      model = keras.Model(inputs=[tokens], outputs=[probs])

      if initial_weights:
        model.set_weights(initial_weights)

      model.compile(
          optimizer=gradient_descent_keras.SGD(learning_rate=0.1),
          loss='sparse_categorical_crossentropy',
          metrics=['sparse_categorical_accuracy'])
    return model

  # TODO(jhseu): Disabled to fix b/130808953. Need to investigate why it
  # doesn't work and enable for DistributionStrategy more generally.
  @tf.__internal__.distribute.combinations.generate(
      test_combinations_for_stateful_embedding_model())
  def disabled_test_stateful_lstm_model_correctness(
      self, distribution, use_numpy, use_validation_data):
    """Runs the base correctness check with `is_stateful_model=True`."""
    self.run_correctness_test(
        distribution, use_numpy, use_validation_data, is_stateful_model=True)

  @tf.__internal__.distribute.combinations.generate(
      tf.__internal__.test.combinations.times(
          keras_correctness_test_base
          .test_combinations_with_tpu_strategies_graph()))
  def test_incorrectly_use_multiple_cores_for_stateful_lstm_model(
      self, distribution, use_numpy, use_validation_data):
    """Expects a ValueError when the stateful model runs under these TPUs."""
    expected_message = ('RNNs with stateful=True not yet supported with '
                        'tf.distribute.Strategy.')
    with self.assertRaisesRegex(ValueError, expected_message):
      self.run_correctness_test(
          distribution, use_numpy, use_validation_data,
          is_stateful_model=True)

Ancestors

Class variables

var disabled_test_stateful_lstm_model_correctness

Methods

def get_model(self, max_words=10, initial_weights=None, distribution=None, input_shapes=None)
Expand source code
def get_model(self,
              max_words=10,
              initial_weights=None,
              distribution=None,
              input_shapes=None):
  """Builds and compiles the stateful LSTM classifier.

  Args:
    max_words: Length of each input sequence of int32 word ids.
    initial_weights: Optional weights; loaded with `set_weights` when truthy.
    distribution: Optional strategy; the model is built and compiled under
      `MaybeDistributionScope(distribution)`.
    input_shapes: Ignored.

  Returns:
    A compiled `keras.Model` with SGD (learning rate 0.1), sparse categorical
    cross-entropy loss and sparse categorical accuracy.
  """
  del input_shapes  # Not used by this model.
  global_batch = keras_correctness_test_base._GLOBAL_BATCH_SIZE

  with keras_correctness_test_base.MaybeDistributionScope(distribution):
    # The batch size is pinned on the Input because the LSTM is stateful.
    tokens = keras.layers.Input(
        shape=(max_words,),
        batch_size=global_batch,
        dtype=np.int32,
        name='words')
    embedded = keras.layers.Embedding(input_dim=20, output_dim=10)(tokens)
    encoded = keras.layers.LSTM(
        units=4, return_sequences=False, stateful=True)(embedded)
    probs = keras.layers.Dense(2, activation='softmax')(encoded)
    model = keras.Model(inputs=[tokens], outputs=[probs])

    if initial_weights:
      model.set_weights(initial_weights)

    model.compile(
        optimizer=gradient_descent_keras.SGD(learning_rate=0.1),
        loss='sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy'])
  return model
def test_incorrectly_use_multiple_cores_for_stateful_lstm_model_test_distribution_TPU_mode_graph_usenumpy_False_usevalidationdata_False(self, **kwargs)

A wrapped test method that can treat some arguments in a special way.

Expand source code
def decorated(self, **kwargs):
  """A wrapped test method that can treat some arguments in a special way.

  Skips the test if any combination declines to execute it, then applies
  every combination's parameter modifiers to `kwargs`. It calls
  `test_method` with exactly the parameters it declares, inside every
  context manager the combinations provide.

  NOTE(review): `test_combinations`, `test_method`, `tf_inspect` and
  `ParameterModifier` are free names here; presumably they are bound by the
  enclosing decorator factory / module — confirm against the caller.

  Raises:
    ValueError: if the test declares a parameter that received no argument,
      or if it is handed an argument it does not declare.
  """
  original_kwargs = kwargs.copy()

  # Skip combinations that are going to be executed in a different testing
  # environment.
  reasons_to_skip = []
  for combination in test_combinations:
    should_execute, reason = combination.should_execute_combination(
        original_kwargs.copy())
    if not should_execute:
      reasons_to_skip.append(" - " + reason)

  if reasons_to_skip:
    self.skipTest("\n".join(reasons_to_skip))

  # Deduplicate modifiers contributed by several combinations.
  customized_parameters = []
  for combination in test_combinations:
    customized_parameters.extend(combination.parameter_modifiers())
  customized_parameters = set(customized_parameters)

  # The function for running the test under the total set of
  # `context_managers`:
  def execute_test_method():
    requested_parameters = tf_inspect.getfullargspec(test_method).args
    for customized_parameter in customized_parameters:
      for argument, value in customized_parameter.modified_arguments(
          original_kwargs.copy(), requested_parameters).items():
        if value is ParameterModifier.DO_NOT_PASS_TO_THE_TEST:
          kwargs.pop(argument, None)
        else:
          kwargs[argument] = value

    supplied = set(kwargs) | {"self"}
    declared = set(requested_parameters)

    # Parameters the test declares but that received no argument.
    omitted_arguments = declared - supplied
    if omitted_arguments:
      raise ValueError("The test requires parameters whose arguments "
                       "were not passed: {} .".format(omitted_arguments))
    # Arguments that were supplied but the test does not declare.
    unexpected_arguments = supplied - declared
    if unexpected_arguments:
      raise ValueError("The test does not take parameters that were passed "
                       ": {} .".format(unexpected_arguments))

    kwargs_to_pass = {
        parameter: self if parameter == "self" else kwargs[parameter]
        for parameter in requested_parameters
    }
    test_method(**kwargs_to_pass)

  # Install `context_managers` before running the test:
  context_managers = []
  for combination in test_combinations:
    context_managers.extend(
        combination.context_managers(original_kwargs.copy()))

  # `contextlib.nested` existed only on Python 2; ExitStack enters a
  # dynamic number of context managers and unwinds them in reverse order.
  with contextlib.ExitStack() as context_stack:
    for manager in context_managers:
      context_stack.enter_context(manager)
    execute_test_method()
def test_incorrectly_use_multiple_cores_for_stateful_lstm_model_test_distribution_TPU_mode_graph_usenumpy_False_usevalidationdata_True(self, **kwargs)

A wrapped test method that can treat some arguments in a special way.

Expand source code
def decorated(self, **kwargs):
  """A wrapped test method that can treat some arguments in a special way.

  Skips the test if any combination declines to execute it, then applies
  every combination's parameter modifiers to `kwargs`. It calls
  `test_method` with exactly the parameters it declares, inside every
  context manager the combinations provide.

  NOTE(review): `test_combinations`, `test_method`, `tf_inspect` and
  `ParameterModifier` are free names here; presumably they are bound by the
  enclosing decorator factory / module — confirm against the caller.

  Raises:
    ValueError: if the test declares a parameter that received no argument,
      or if it is handed an argument it does not declare.
  """
  original_kwargs = kwargs.copy()

  # Skip combinations that are going to be executed in a different testing
  # environment.
  reasons_to_skip = []
  for combination in test_combinations:
    should_execute, reason = combination.should_execute_combination(
        original_kwargs.copy())
    if not should_execute:
      reasons_to_skip.append(" - " + reason)

  if reasons_to_skip:
    self.skipTest("\n".join(reasons_to_skip))

  # Deduplicate modifiers contributed by several combinations.
  customized_parameters = []
  for combination in test_combinations:
    customized_parameters.extend(combination.parameter_modifiers())
  customized_parameters = set(customized_parameters)

  # The function for running the test under the total set of
  # `context_managers`:
  def execute_test_method():
    requested_parameters = tf_inspect.getfullargspec(test_method).args
    for customized_parameter in customized_parameters:
      for argument, value in customized_parameter.modified_arguments(
          original_kwargs.copy(), requested_parameters).items():
        if value is ParameterModifier.DO_NOT_PASS_TO_THE_TEST:
          kwargs.pop(argument, None)
        else:
          kwargs[argument] = value

    supplied = set(kwargs) | {"self"}
    declared = set(requested_parameters)

    # Parameters the test declares but that received no argument.
    omitted_arguments = declared - supplied
    if omitted_arguments:
      raise ValueError("The test requires parameters whose arguments "
                       "were not passed: {} .".format(omitted_arguments))
    # Arguments that were supplied but the test does not declare.
    unexpected_arguments = supplied - declared
    if unexpected_arguments:
      raise ValueError("The test does not take parameters that were passed "
                       ": {} .".format(unexpected_arguments))

    kwargs_to_pass = {
        parameter: self if parameter == "self" else kwargs[parameter]
        for parameter in requested_parameters
    }
    test_method(**kwargs_to_pass)

  # Install `context_managers` before running the test:
  context_managers = []
  for combination in test_combinations:
    context_managers.extend(
        combination.context_managers(original_kwargs.copy()))

  # `contextlib.nested` existed only on Python 2; ExitStack enters a
  # dynamic number of context managers and unwinds them in reverse order.
  with contextlib.ExitStack() as context_stack:
    for manager in context_managers:
      context_stack.enter_context(manager)
    execute_test_method()
def test_incorrectly_use_multiple_cores_for_stateful_lstm_model_test_distribution_TPU_mode_graph_usenumpy_True_usevalidationdata_False(self, **kwargs)

A wrapped test method that can treat some arguments in a special way.

Expand source code
def decorated(self, **kwargs):
  """A wrapped test method that can treat some arguments in a special way.

  Skips the test if any combination declines to execute it, then applies
  every combination's parameter modifiers to `kwargs`. It calls
  `test_method` with exactly the parameters it declares, inside every
  context manager the combinations provide.

  NOTE(review): `test_combinations`, `test_method`, `tf_inspect` and
  `ParameterModifier` are free names here; presumably they are bound by the
  enclosing decorator factory / module — confirm against the caller.

  Raises:
    ValueError: if the test declares a parameter that received no argument,
      or if it is handed an argument it does not declare.
  """
  original_kwargs = kwargs.copy()

  # Skip combinations that are going to be executed in a different testing
  # environment.
  reasons_to_skip = []
  for combination in test_combinations:
    should_execute, reason = combination.should_execute_combination(
        original_kwargs.copy())
    if not should_execute:
      reasons_to_skip.append(" - " + reason)

  if reasons_to_skip:
    self.skipTest("\n".join(reasons_to_skip))

  # Deduplicate modifiers contributed by several combinations.
  customized_parameters = []
  for combination in test_combinations:
    customized_parameters.extend(combination.parameter_modifiers())
  customized_parameters = set(customized_parameters)

  # The function for running the test under the total set of
  # `context_managers`:
  def execute_test_method():
    requested_parameters = tf_inspect.getfullargspec(test_method).args
    for customized_parameter in customized_parameters:
      for argument, value in customized_parameter.modified_arguments(
          original_kwargs.copy(), requested_parameters).items():
        if value is ParameterModifier.DO_NOT_PASS_TO_THE_TEST:
          kwargs.pop(argument, None)
        else:
          kwargs[argument] = value

    supplied = set(kwargs) | {"self"}
    declared = set(requested_parameters)

    # Parameters the test declares but that received no argument.
    omitted_arguments = declared - supplied
    if omitted_arguments:
      raise ValueError("The test requires parameters whose arguments "
                       "were not passed: {} .".format(omitted_arguments))
    # Arguments that were supplied but the test does not declare.
    unexpected_arguments = supplied - declared
    if unexpected_arguments:
      raise ValueError("The test does not take parameters that were passed "
                       ": {} .".format(unexpected_arguments))

    kwargs_to_pass = {
        parameter: self if parameter == "self" else kwargs[parameter]
        for parameter in requested_parameters
    }
    test_method(**kwargs_to_pass)

  # Install `context_managers` before running the test:
  context_managers = []
  for combination in test_combinations:
    context_managers.extend(
        combination.context_managers(original_kwargs.copy()))

  # `contextlib.nested` existed only on Python 2; ExitStack enters a
  # dynamic number of context managers and unwinds them in reverse order.
  with contextlib.ExitStack() as context_stack:
    for manager in context_managers:
      context_stack.enter_context(manager)
    execute_test_method()
def test_incorrectly_use_multiple_cores_for_stateful_lstm_model_test_distribution_TPU_mode_graph_usenumpy_True_usevalidationdata_True(self, **kwargs)

A wrapped test method that can treat some arguments in a special way.

Expand source code
def decorated(self, **kwargs):
  """A wrapped test method that can treat some arguments in a special way.

  Skips the test if any combination declines to execute it, then applies
  every combination's parameter modifiers to `kwargs`. It calls
  `test_method` with exactly the parameters it declares, inside every
  context manager the combinations provide.

  NOTE(review): `test_combinations`, `test_method`, `tf_inspect` and
  `ParameterModifier` are free names here; presumably they are bound by the
  enclosing decorator factory / module — confirm against the caller.

  Raises:
    ValueError: if the test declares a parameter that received no argument,
      or if it is handed an argument it does not declare.
  """
  original_kwargs = kwargs.copy()

  # Skip combinations that are going to be executed in a different testing
  # environment.
  reasons_to_skip = []
  for combination in test_combinations:
    should_execute, reason = combination.should_execute_combination(
        original_kwargs.copy())
    if not should_execute:
      reasons_to_skip.append(" - " + reason)

  if reasons_to_skip:
    self.skipTest("\n".join(reasons_to_skip))

  # Deduplicate modifiers contributed by several combinations.
  customized_parameters = []
  for combination in test_combinations:
    customized_parameters.extend(combination.parameter_modifiers())
  customized_parameters = set(customized_parameters)

  # The function for running the test under the total set of
  # `context_managers`:
  def execute_test_method():
    requested_parameters = tf_inspect.getfullargspec(test_method).args
    for customized_parameter in customized_parameters:
      for argument, value in customized_parameter.modified_arguments(
          original_kwargs.copy(), requested_parameters).items():
        if value is ParameterModifier.DO_NOT_PASS_TO_THE_TEST:
          kwargs.pop(argument, None)
        else:
          kwargs[argument] = value

    supplied = set(kwargs) | {"self"}
    declared = set(requested_parameters)

    # Parameters the test declares but that received no argument.
    omitted_arguments = declared - supplied
    if omitted_arguments:
      raise ValueError("The test requires parameters whose arguments "
                       "were not passed: {} .".format(omitted_arguments))
    # Arguments that were supplied but the test does not declare.
    unexpected_arguments = supplied - declared
    if unexpected_arguments:
      raise ValueError("The test does not take parameters that were passed "
                       ": {} .".format(unexpected_arguments))

    kwargs_to_pass = {
        parameter: self if parameter == "self" else kwargs[parameter]
        for parameter in requested_parameters
    }
    test_method(**kwargs_to_pass)

  # Install `context_managers` before running the test:
  context_managers = []
  for combination in test_combinations:
    context_managers.extend(
        combination.context_managers(original_kwargs.copy()))

  # `contextlib.nested` existed only on Python 2; ExitStack enters a
  # dynamic number of context managers and unwinds them in reverse order.
  with contextlib.ExitStack() as context_stack:
    for manager in context_managers:
      context_stack.enter_context(manager)
    execute_test_method()

Inherited members