Module keras.distribute.keras_correctness_test_base

Correctness tests for tf.keras using DistributionStrategy.
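The helpers below are driven from concrete test subclasses. As a rough orientation, here is a minimal, hypothetical sketch (the TestDenseModelCorrectness class, its one-layer model, and the optimizer/loss choices are assumptions, not part of this module) of how the base class and the combination generators are typically wired together:

import tensorflow.compat.v2 as tf
import keras
from keras.distribute import keras_correctness_test_base


class TestDenseModelCorrectness(
    keras_correctness_test_base.TestDistributionStrategyCorrectnessBase):

  def get_model(self, initial_weights=None, distribution=None,
                input_shapes=None):
    # Build (and optionally re-initialize) the model under the given strategy.
    with keras_correctness_test_base.MaybeDistributionScope(distribution):
      model = keras.Sequential(
          [keras.layers.Dense(1, input_shape=(1,))])
      model.compile(optimizer='sgd', loss='mse', metrics=['mse'])
      if initial_weights:
        model.set_weights(initial_weights)
      return model

  @tf.__internal__.distribute.combinations.generate(
      keras_correctness_test_base.all_strategy_and_input_config_combinations())
  def test_dense_model_correctness(self, distribution, use_numpy,
                                   use_validation_data):
    # Runs fit/evaluate/predict with and without `distribution` and compares.
    self.run_correctness_test(distribution, use_numpy, use_validation_data)


if __name__ == '__main__':
  tf.test.main()

A subclass only has to supply get_model(); run_correctness_test() then builds the model once, trains it with and without the strategy, and compares the two runs via compare_results().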

Source code
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Correctness tests for tf.keras using DistributionStrategy."""

import tensorflow.compat.v2 as tf

import functools
from absl.testing import parameterized
import numpy as np

import keras
from keras.distribute import distributed_training_utils
from keras.distribute.strategy_combinations import all_strategies
from keras.distribute.strategy_combinations import multi_worker_mirrored_strategies
from keras.distribute.strategy_combinations import strategies_minus_tpu
from keras.mixed_precision import policy
from keras.preprocessing import sequence

_RANDOM_SEED = 1337
_EVAL_STEPS = 20
_GLOBAL_BATCH_SIZE = 64

# Note: Please make sure the tests in this file are also covered in
# keras_backward_compat_test for features that are supported with both APIs.


def eager_mode_test_configuration():
  return tf.__internal__.test.combinations.combine(
      mode='eager', use_numpy=[True, False], use_validation_data=[True, False])


def graph_mode_test_configuration():
  return tf.__internal__.test.combinations.combine(
      mode='graph', use_numpy=[True, False], use_validation_data=[True, False])


def all_strategy_and_input_config_combinations():
  return (tf.__internal__.test.combinations.times(
      tf.__internal__.test.combinations.combine(distribution=all_strategies),
      eager_mode_test_configuration() + graph_mode_test_configuration()))


def all_strategy_and_input_config_combinations_eager():
  return (tf.__internal__.test.combinations.times(
      tf.__internal__.test.combinations.combine(distribution=all_strategies),
      eager_mode_test_configuration()))


def strategy_minus_tpu_and_input_config_combinations_eager():
  return (tf.__internal__.test.combinations.times(
      tf.__internal__.test.combinations.combine(distribution=strategies_minus_tpu),
      eager_mode_test_configuration()))


def strategies_for_embedding_models():
  """Returns distribution strategies to test for embedding models.

  Since embedding models take longer to train, we disregard DefaultStrategy
  in order to prevent testing timeouts.
  """

  return [
      s for s in all_strategies if s.required_tpu or s.required_gpus or
      s is tf.__internal__.distribute.combinations.one_device_strategy
  ]


def test_combinations_for_embedding_model():
  # TODO(sourabhbajaj): Enable tests for eager mode
  eager_mode_strategies = [
      s for s in strategies_for_embedding_models() if not s.required_tpu
  ]

  return (tf.__internal__.test.combinations.times(
      tf.__internal__.test.combinations.combine(
          distribution=strategies_for_embedding_models()),
      (graph_mode_test_configuration())) + tf.__internal__.test.combinations.times(
          tf.__internal__.test.combinations.combine(
              distribution=eager_mode_strategies),
          (eager_mode_test_configuration())))


def test_combinations_with_tpu_strategies_graph():
  tpu_strategies = [
      tf.__internal__.distribute.combinations.tpu_strategy,
  ]

  return (tf.__internal__.test.combinations.times(
      tf.__internal__.test.combinations.combine(distribution=tpu_strategies),
      graph_mode_test_configuration()))


def multi_worker_mirrored_eager():
  return tf.__internal__.test.combinations.times(
      tf.__internal__.test.combinations.combine(distribution=multi_worker_mirrored_strategies),
      eager_mode_test_configuration())


def multi_worker_mirrored_eager_and_graph():
  return tf.__internal__.test.combinations.times(
      tf.__internal__.test.combinations.combine(distribution=multi_worker_mirrored_strategies),
      eager_mode_test_configuration() + graph_mode_test_configuration())


class MaybeDistributionScope(object):
  """Provides a context allowing no distribution strategy."""

  def __init__(self, distribution):
    self._distribution = distribution
    self._scope = None

  def __enter__(self):
    if self._distribution:
      self._scope = self._distribution.scope()
      self._scope.__enter__()

  def __exit__(self, exc_type, value, traceback):
    if self._distribution:
      self._scope.__exit__(exc_type, value, traceback)
      self._scope = None


def batch_wrapper(dataset, batch_size, repeat=None):
  if repeat:
    dataset = dataset.repeat(repeat)
  return dataset.batch(batch_size)


def get_batch_size(global_batch_size, distribution):
  batch_size = global_batch_size
  # TODO(b/118776054): Use global batch size for Keras/DS support.
  use_per_core_batch_size = (
      distribution and
      not distributed_training_utils.global_batch_size_supported(distribution))
  if use_per_core_batch_size:
    batch_size //= distribution.num_replicas_in_sync
  return batch_size


def get_data_size(data):
  """Gets the size of data in list, tuple, dict, or a numpy array."""
  assert isinstance(data, (np.ndarray, list, dict, tuple))

  if isinstance(data, np.ndarray):
    return len(data)

  if isinstance(data, (list, tuple)):
    return len(data[0])

  return len(data.values())


def get_shapes(data):
  shapes = None
  if all(hasattr(x, 'shape') for x in tf.nest.flatten(data)):
    shapes = tf.nest.map_structure(lambda x: x.shape, data)
  return shapes


def get_correctness_test_inputs(use_numpy, use_validation_data,
                                with_distribution, x_train, y_train, x_eval,
                                y_eval, x_predict, training_epochs):
  """Generates the inputs for correctness check when enable Keras with DS."""
  global_batch_size = _GLOBAL_BATCH_SIZE
  batch_size = get_batch_size(global_batch_size, with_distribution)

  if use_numpy:
    training_inputs = {
        'batch_size': batch_size,
        'x': x_train,
        'y': y_train,
        'epochs': training_epochs,
        'shuffle': False,
    }

    if use_validation_data:
      eval_inputs = None
      training_inputs['validation_data'] = (x_eval, y_eval)
    else:
      eval_inputs = {
          'batch_size': batch_size,
          'x': x_eval,
          'y': y_eval,
      }
    predict_inputs = {'x': x_predict}
  else:
    training_data_size = get_data_size(x_train)
    # For dataset inputs, we do not pass batch_size to
    # keras.fit/evaluate/predict. The batch size is part of the dataset.
    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    x = batch_wrapper(train_dataset, batch_size, repeat=training_epochs)

    steps_per_epoch = int(np.ceil(1.0 * training_data_size / global_batch_size))
    training_inputs = {
        'batch_size': None,
        'x': x,
        'y': None,
        'epochs': training_epochs,
        'shuffle': False,
        'steps_per_epoch': steps_per_epoch
    }
    if use_validation_data:
      eval_inputs = None  # Remove the eval_inputs
      eval_dataset = tf.data.Dataset.from_tensor_slices((x_eval, y_eval))
      x = batch_wrapper(eval_dataset, batch_size)
      training_inputs['validation_data'] = x
      training_inputs['validation_steps'] = 5
    else:
      eval_dataset = tf.data.Dataset.from_tensor_slices((x_eval, y_eval))
      x = batch_wrapper(eval_dataset, batch_size)
      eval_steps = int(np.ceil(1.0 * get_data_size(x_eval) / global_batch_size))
      eval_inputs = {
          'batch_size': None,
          'x': x,
          'y': None,
          'steps': eval_steps,
      }

    predict_batch_size = get_batch_size(
        get_data_size(x_predict), with_distribution)
    predict_dataset = tf.data.Dataset.from_tensor_slices(x_predict)
    predict_dataset = batch_wrapper(predict_dataset, predict_batch_size)
    predict_inputs = {
        'steps': 1,
        'x': predict_dataset,
    }

  return training_inputs, eval_inputs, predict_inputs


def fit_eval_and_predict(initial_weights,
                         input_fn,
                         model_fn,
                         distribution=None,
                         is_stateful_model=False):
  """Generates results for fit/predict/evaluate for given model."""
  training_inputs, eval_inputs, predict_inputs = input_fn()
  model = model_fn(
      initial_weights=initial_weights,
      distribution=distribution,
      input_shapes=get_shapes(training_inputs['x']))

  result = {}
  result['training_history_1'] = model.fit(**training_inputs).history

  if eval_inputs is not None:
    result['eval_result_1'] = model.evaluate(**eval_inputs)

  result['weights_1'] = model.get_weights()

  if predict_inputs is not None:
    # Check correctness of the result of predict() invoked
    # multiple times -- as for stateful models, result of
    # predict may differ for each batch.
    predict_length = 1
    if is_stateful_model:
      predict_length = 3
    for i in range(predict_length):
      result_key = 'predict_result_{}'.format(i)
      result[result_key] = model.predict(**predict_inputs)

  # Train and eval again to mimic user's flow.

  result['training_history_2'] = model.fit(**training_inputs).history

  if eval_inputs is not None:
    result['eval_result_2'] = model.evaluate(**eval_inputs)

  result['weights_2'] = model.get_weights()

  return result


def compare_results(results_with_ds,
                    results_without_ds,
                    distribution,
                    testcase,
                    partial_last_batch=None):
  """Compares results of model compiled with/without distribution strategy."""
  if policy.global_policy().compute_dtype in ('float16', 'bfloat16'):
    default_tolerance = 1e-2
    relaxed_tolerance = 1e-2
  elif partial_last_batch == 'train_and_eval':
    # We relax the tolerance a lot in the partial last batch case as
    #   1. the examples in uneven batches may have different weights when
    #      applying the gradients in the distributed case.
    #   2. TF Keras and TF Keras DS handle training with epochs > 1 on numpy
    #      inputs differently. In TF Keras, every epoch may have a partial
    #      batch, while TF Keras DS converts the numpy inputs into a dataset,
    #      calls repeat() first and computes steps_per_epoch, so there is at
    #      most one partial batch overall. This makes even the single-CPU
    #      results differ.
    default_tolerance = 1e-3
    relaxed_tolerance = 1e-3
  else:
    default_tolerance = 4e-5
    relaxed_tolerance = 1e-4

  def _get_compare_result_tolerance(key):
    """Returns tolerance to compare results."""
    # See b/119257215 for more details. DS tests run on GPU could have larger
    # variance than tests on CPU.
    if (tf.test.is_gpu_available() and
        key.startswith(('weights_1', 'weights_2', 'predict_result'))):
      return relaxed_tolerance

    return default_tolerance

  for key in sorted(results_with_ds.keys()):
    if (key.startswith('training_history') and
        isinstance(distribution,
                   (tf.distribute.experimental.TPUStrategy, tf.compat.v1.distribute.experimental.TPUStrategy)) and
        distribution.extended.steps_per_run > 1):
      # TODO(b/119894254): Enable this test for all cases once the
      # underlying bug is fixed.
      continue

    tolerance = _get_compare_result_tolerance(key)

    # We don't compare the loss, as loss is currently not computed as a metric
    # in Keras; the loss value is inaccurate for the last partial batch because
    # its samples are weighted more heavily.
    if partial_last_batch is not None:
      if key.startswith('eval_result'):
        results_with_ds[key] = results_with_ds[key][1:]
        results_without_ds[key] = results_without_ds[key][1:]
      if key.startswith('training_history'):
        results_with_ds[key]['val_loss'] = 0
        results_without_ds[key]['val_loss'] = 0

    testcase.assertAllClose(
        results_with_ds[key],
        results_without_ds[key],
        atol=tolerance,
        rtol=tolerance,
        msg='Fail to assert {}.'.format(key))


def should_skip_tpu_with_eager(distribution):
  return (tf.executing_eagerly() and
          isinstance(distribution,
                     (tf.distribute.experimental.TPUStrategy, tf.compat.v1.distribute.experimental.TPUStrategy)))


class LearningRateBatchScheduler(keras.callbacks.Callback):
  """Scheduler that dynamically sets the learning rate of model."""

  def __init__(self, update_freq=None):
    self._update_freq = update_freq

  def on_batch_begin(self, batch, logs=None):
    if self._update_freq and batch % self._update_freq != 0:
      return

    # To avoid divergence, limit the value range.
    lr = 0.001 * (batch % 10)
    keras.backend.set_value(self.model.optimizer.lr, lr)


class TestDistributionStrategyCorrectnessBase(tf.test.TestCase,
                                              parameterized.TestCase):
  """Model agnostic testing infra to test correctness of Keras models."""

  def set_up_test_config(self,
                         use_numpy=False,
                         use_validation_data=False,
                         with_batch_norm=None):
    self.use_numpy = use_numpy
    self.use_validation_data = use_validation_data
    self.with_batch_norm = with_batch_norm

    keras.backend.set_image_data_format('channels_last')
    np.random.seed(_RANDOM_SEED)
    tf.compat.v1.set_random_seed(_RANDOM_SEED)

  def get_data(self):
    num_samples = 10000
    x_train = np.random.randint(0, 2, num_samples)
    x_train = np.reshape(x_train, (num_samples, 1))
    y_train = x_train
    return (x_train.astype('float32'), y_train.astype('float32'), None)

  def get_data_with_partial_last_batch(self):
    raise NotImplementedError

  def get_data_with_partial_last_batch_eval(self):
    raise NotImplementedError

  def get_input_for_correctness_test(self, **kwargs):
    """Generates inputs that are dictionaries.

    We only provide a default implementation of this method here. If you need
    a more customized way of providing input to your model, override this
    method.

    Args:
      **kwargs: keyword arguments describing how to create the input
        dictionaries

    Returns:
      Three dictionaries representing the input for fit(), evaluate() and
      predict()
    """

    return get_correctness_test_inputs(**kwargs)

  def get_model(self,
                distribution=None,
                input_shapes=None):
    raise NotImplementedError

  def run_correctness_test(self,
                           distribution,
                           use_numpy,
                           use_validation_data,
                           with_batch_norm=None,
                           is_stateful_model=False,
                           partial_last_batch=None,
                           training_epochs=2):
    with self.cached_session():
      self.set_up_test_config(use_numpy, use_validation_data, with_batch_norm)

      if partial_last_batch == 'eval':
        x_train, y_train, x_eval, y_eval, x_predict = (
            self.get_data_with_partial_last_batch_eval())
      elif partial_last_batch == 'train_and_eval':
        x_train, y_train, x_eval, y_eval, x_predict = (
            self.get_data_with_partial_last_batch())
      else:
        x_train, y_train, x_predict = self.get_data()
        x_eval = x_train
        y_eval = y_train

      # The model is built once and the initial weights are saved.
      # This is used to initialize the model for both the distribution and
      # non-distribution run.
      model = self.get_model(
          input_shapes=get_shapes(x_train))
      initial_weights = model.get_weights()

      ds_input_fn = functools.partial(
          self.get_input_for_correctness_test,
          use_numpy=use_numpy,
          use_validation_data=use_validation_data,
          with_distribution=distribution,
          x_train=x_train,
          y_train=y_train,
          x_eval=x_eval,
          y_eval=y_eval,
          x_predict=x_predict,
          training_epochs=training_epochs)

      nods_input_fn = functools.partial(
          self.get_input_for_correctness_test,
          use_numpy=use_numpy,
          use_validation_data=use_validation_data,
          with_distribution=None,
          x_train=x_train,
          y_train=y_train,
          x_eval=x_eval,
          y_eval=y_eval,
          x_predict=x_predict,
          training_epochs=training_epochs)

      results_with_ds = fit_eval_and_predict(
          initial_weights,
          input_fn=ds_input_fn,
          model_fn=self.get_model,
          distribution=distribution,
          is_stateful_model=is_stateful_model)
      results_without_ds = fit_eval_and_predict(
          initial_weights,
          input_fn=nods_input_fn,
          model_fn=self.get_model,
          distribution=None,
          is_stateful_model=is_stateful_model)

      # First, special case, for multi-replica distributed training, batch
      # norm is not aggregated globally. So it is expected to have different
      # weights.
      if (self.with_batch_norm == 'regular' and
          distribution.num_replicas_in_sync > 1):
        with self.assertRaises(AssertionError):
          compare_results(
              results_with_ds,
              results_without_ds,
              distribution,
              testcase=self,
              partial_last_batch=partial_last_batch)
      else:
        compare_results(
            results_with_ds,
            results_without_ds,
            distribution,
            testcase=self,
            partial_last_batch=partial_last_batch)

  def get_input_for_dynamic_lr_test(self, **kwargs):
    """Generates inputs that are dictionaries.

    We only provide a default implementation of this method here. If you need
    a more customized way of providing input to your model, override this
    method.

    Args:
      **kwargs: keyword arguments describing how to create the input
        dictionaries

    Returns:
      Three dictionaries representing the input for fit(), evaluate() and
      predict()
    """

    training_input = kwargs
    return training_input, None, None

  def run_dynamic_lr_test(self,
                          distribution):
    with self.cached_session():
      self.set_up_test_config()

      x_train, y_train, _ = self.get_data()
      model = self.get_model(
          input_shapes=get_shapes(x_train))
      initial_weights = model.get_weights()
      update_freq = None

      if (isinstance(distribution, tf.compat.v1.distribute.experimental.TPUStrategy) and
          distribution.extended.steps_per_run > 1):
        # For TPUStrategy with steps_per_run > 1, the callback is not invoked
        # every step. So, to compare the CPU and TPU runs, we make the CPU
        # behave the same way as the TPU.
        update_freq = distribution.extended.steps_per_run

      training_epochs = 2
      global_batch_size = 64

      ds_batch_size = get_batch_size(global_batch_size, distribution)
      nods_batch_size = get_batch_size(global_batch_size, None)

      ds_input_fn = functools.partial(
          self.get_input_for_dynamic_lr_test,
          x=x_train,
          y=y_train,
          batch_size=ds_batch_size,
          shuffle=False,
          epochs=training_epochs,
          callbacks=[LearningRateBatchScheduler(update_freq)],
          validation_data=(x_train, y_train))

      nods_input_fn = functools.partial(
          self.get_input_for_dynamic_lr_test,
          x=x_train,
          y=y_train,
          batch_size=nods_batch_size,
          shuffle=False,
          epochs=training_epochs,
          callbacks=[LearningRateBatchScheduler(update_freq)],
          validation_data=(x_train, y_train))

      results_with_ds = fit_eval_and_predict(
          initial_weights,
          input_fn=ds_input_fn,
          model_fn=self.get_model,
          distribution=distribution)
      results_without_ds = fit_eval_and_predict(
          initial_weights,
          input_fn=nods_input_fn,
          model_fn=self.get_model,
          distribution=None)
      compare_results(
          results_with_ds, results_without_ds, distribution, testcase=self)


class TestDistributionStrategyEmbeddingModelCorrectnessBase(
    TestDistributionStrategyCorrectnessBase):
  """Base class to test correctness of Keras models with embedding layers."""

  def get_data(self,
               count=(_GLOBAL_BATCH_SIZE * _EVAL_STEPS),
               min_words=5,
               max_words=10,
               max_word_id=19,
               num_classes=2):
    distribution = []
    for _ in range(num_classes):
      dist = np.abs(np.random.randn(max_word_id))
      dist /= np.sum(dist)
      distribution.append(dist)

    features = []
    labels = []
    for _ in range(count):
      label = np.random.randint(0, num_classes, size=1)[0]
      num_words = np.random.randint(min_words, max_words, size=1)[0]
      word_ids = np.random.choice(
          max_word_id, size=num_words, replace=True, p=distribution[label])
      labels.append(label)
      features.append(word_ids)

    features = sequence.pad_sequences(
        features, maxlen=max_words)
    x_train = np.asarray(features, dtype=np.float32)
    y_train = np.asarray(labels, dtype=np.int32).reshape((count, 1))
    x_predict = x_train[:_GLOBAL_BATCH_SIZE]
    return x_train, y_train, x_predict


if __name__ == '__main__':
  tf.test.main()

Functions

def all_strategy_and_input_config_combinations()
def all_strategy_and_input_config_combinations_eager()
def batch_wrapper(dataset, batch_size, repeat=None)
def compare_results(results_with_ds, results_without_ds, distribution, testcase, partial_last_batch=None)

Compares results of model compiled with/without distribution strategy.

def eager_mode_test_configuration()
def fit_eval_and_predict(initial_weights, input_fn, model_fn, distribution=None, is_stateful_model=False)

Generates results for fit/predict/evaluate for the given model.

def get_batch_size(global_batch_size, distribution)
def get_correctness_test_inputs(use_numpy, use_validation_data, with_distribution, x_train, y_train, x_eval, y_eval, x_predict, training_epochs)

Generates the inputs for a correctness check when using Keras with DS.

def get_data_size(data)

Gets the size of data in a list, tuple, dict, or numpy array.
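
A small illustration (the arrays are arbitrary examples) of what the function returns for each accepted container type:

import numpy as np
from keras.distribute import keras_correctness_test_base as base

base.get_data_size(np.zeros((8, 3)))                # 8: rows of the array
base.get_data_size([np.zeros((8, 3)), np.ones(8)])  # 8: length of the first element
base.get_data_size({'labels': np.zeros(8)})         # 1: number of dict entries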

def get_shapes(data)
def graph_mode_test_configuration()
def multi_worker_mirrored_eager()
def multi_worker_mirrored_eager_and_graph()
def should_skip_tpu_with_eager(distribution)
def strategies_for_embedding_models()

Returns distribution strategies to test for embedding models.

Since embedding models take longer to train, we disregard DefaultStrategy in order to prevent testing timeouts.

def strategy_minus_tpu_and_input_config_combinations_eager()
def test_combinations_for_embedding_model()
def test_combinations_with_tpu_strategies_graph()

Classes

class LearningRateBatchScheduler (update_freq=None)

Scheduler that dynamically sets the learning rate of the model.


Ancestors

  • keras.callbacks.Callback

class MaybeDistributionScope (distribution)

Provides a context allowing no distribution strategy.
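
A brief sketch of the two ways the scope is used (the MirroredStrategy here is just an example value):

import tensorflow.compat.v2 as tf
import keras
from keras.distribute import keras_correctness_test_base

strategy = tf.distribute.MirroredStrategy()  # example; any strategy or None works
with keras_correctness_test_base.MaybeDistributionScope(strategy):
  model = keras.Sequential([keras.layers.Dense(1, input_shape=(1,))])

# Passing None (or any falsy value) runs the block without entering any
# distribution strategy scope.
with keras_correctness_test_base.MaybeDistributionScope(None):
  baseline = keras.Sequential([keras.layers.Dense(1, input_shape=(1,))])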

class TestDistributionStrategyCorrectnessBase (methodName='runTest')

Model-agnostic testing infra to test correctness of Keras models.

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.


Ancestors

  • tensorflow.python.framework.test_util.TensorFlowTestCase
  • absl.testing.parameterized.TestCase
  • absl.testing.absltest.TestCase
  • absl.third_party.unittest3_backport.case.TestCase
  • unittest.case.TestCase

Subclasses

  • TestDistributionStrategyEmbeddingModelCorrectnessBase

Methods

def get_data(self)
def get_data_with_partial_last_batch(self)
def get_data_with_partial_last_batch_eval(self)
def get_input_for_correctness_test(self, **kwargs)

Generates inputs that are dictionaries.

We only provide a default implementation of this method here. If you need a more customized way of providing input to your model, override this method.

Args

**kwargs
keyword arguments describing how to create the input dictionaries

Returns

Three dictionaries representing the input for fit(), evaluate() and predict()

def get_input_for_dynamic_lr_test(self, **kwargs)

Generates inputs that are dictionaries.

We only provide a default implementation of this method here. If you need a more customized way of providing input to your model, override this method.

Args

**kwargs
keyword arguments describing how to create the input dictionaries

Returns

Three dictionaries representing the input for fit(), evaluate() and predict()

def get_model(self, distribution=None, input_shapes=None)
def run_correctness_test(self, distribution, use_numpy, use_validation_data, with_batch_norm=None, is_stateful_model=False, partial_last_batch=None, training_epochs=2)
def run_dynamic_lr_test(self, distribution)
def set_up_test_config(self, use_numpy=False, use_validation_data=False, with_batch_norm=None)
class TestDistributionStrategyEmbeddingModelCorrectnessBase (methodName='runTest')

Base class to test correctness of Keras models with embedding layers.

Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.
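
As with the parent class, concrete tests subclass this and supply a model. Here is a hypothetical sketch (the TestEmbeddingModelCorrectness class and its layer, optimizer, and loss choices are assumptions, not part of this module) pairing it with test_combinations_for_embedding_model():

import tensorflow.compat.v2 as tf
import keras
from keras.distribute import keras_correctness_test_base


class TestEmbeddingModelCorrectness(
    keras_correctness_test_base
    .TestDistributionStrategyEmbeddingModelCorrectnessBase):

  def get_model(self, initial_weights=None, distribution=None,
                input_shapes=None):
    with keras_correctness_test_base.MaybeDistributionScope(distribution):
      # Mirrors the defaults of get_data() below: sequences padded to
      # max_words=10, word ids below max_word_id=19, num_classes=2 labels.
      word_ids = keras.layers.Input(shape=(10,), name='words')
      embeddings = keras.layers.Embedding(input_dim=20, output_dim=4)(word_ids)
      pooled = keras.layers.GlobalAveragePooling1D()(embeddings)
      preds = keras.layers.Dense(2, activation='softmax')(pooled)
      model = keras.Model(inputs=[word_ids], outputs=[preds])
      if initial_weights:
        model.set_weights(initial_weights)
      model.compile(optimizer='sgd',
                    loss='sparse_categorical_crossentropy',
                    metrics=['sparse_categorical_accuracy'])
      return model

  @tf.__internal__.distribute.combinations.generate(
      keras_correctness_test_base.test_combinations_for_embedding_model())
  def test_embedding_model_correctness(self, distribution, use_numpy,
                                       use_validation_data):
    self.run_correctness_test(distribution, use_numpy, use_validation_data)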


Ancestors

  • TestDistributionStrategyCorrectnessBase
  • tensorflow.python.framework.test_util.TensorFlowTestCase
  • absl.testing.parameterized.TestCase
  • absl.testing.absltest.TestCase
  • absl.third_party.unittest3_backport.case.TestCase
  • unittest.case.TestCase

Methods

def get_data(self, count=1280, min_words=5, max_words=10, max_word_id=19, num_classes=2)
