Module keras.engine.partial_batch_padding_handler
Utility object to handle partial batches for TPUStrategy.
Expand source code
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility object to handler partial batches for TPUStrategy."""
import tensorflow.compat.v2 as tf
# pylint: disable=protected-access
import numpy as np
from keras import backend
class PartialBatchPaddingHandler(object):
  """A container that holds info about partial batches for `predict()`.

  Every batch is padded up to `padded_batch_size`. A 1-D mask records which
  rows are real (1) and which are padding (0), and `apply_mask` uses that
  mask to strip the padded rows from the prediction output.
  """

  def __init__(self, output_shape):
    # Full batch size that `update_mask` and `pad_batch` pad up to. Starts at
    # 0; presumably assigned by the caller before those methods run - confirm.
    self.padded_batch_size = 0
    # 1-D mask over every example seen so far. The entries produced by
    # `update_mask` are 1 for real rows and 0 for padded rows.
    self.padding_mask = tf.zeros(0)
    # Only its length is read (in `apply_mask`), to tell single-output
    # predictions from multi-output ones.
    self.output_shape = output_shape

  def get_real_batch_size(self, dataset_batch):
    """Returns the number of elements in a potentially partial batch.

    Args:
      dataset_batch: A tensor, a nested structure of tensors, or a tuple/list
        whose first entry is such a structure.

    Returns:
      The size of the leading dimension of the first tensor found, cast to
      int64.

    Raises:
      ValueError: If no tensor can be found in `dataset_batch`. This includes
        an empty structure.
    """
    if isinstance(dataset_batch, (tuple, list)):
      # Only the first entry is inspected; the remaining entries are assumed
      # to have the same batch size.
      dataset_batch = dataset_batch[0]

    # An empty structure flattens to no leaves, so it lands in the same
    # explicit error as a structure without tensors.
    first_tensor = next(
        (x for x in tf.nest.flatten(dataset_batch) if tf.is_tensor(x)), None)
    if first_tensor is None:
      raise ValueError('Cannot find any Tensor in features dict.')
    return backend.cast(backend.shape(first_tensor)[0], dtype='int64')

  def update_mask(self, padding_mask, dataset_batch):
    """Extends `padding_mask` with the mask entries for one batch.

    The mask is returned, not stored on `self`.

    Args:
      padding_mask: 1-D mask accumulated over the previous batches.
      dataset_batch: A batch element whose real size is measured with
        `get_real_batch_size`.

    Returns:
      `padding_mask` followed by a 1 for each real example in `dataset_batch`
      and a 0 for each slot needed to pad it up to `self.padded_batch_size`.
    """
    real_count = self.get_real_batch_size(dataset_batch)
    padding_count = self.padded_batch_size - real_count
    batch_mask = backend.concatenate(
        [tf.ones(real_count), tf.zeros(padding_count)], axis=0)
    return backend.concatenate([padding_mask, batch_mask], axis=0)

  def pad_batch(self, *dataset_batch_elements):
    """Pads out the batch dimension of a tensor to the complete batch size.

    Every tensor in each element, including tensors nested inside dicts,
    tuples or lists, is zero-padded at the end of axis 0 so that its leading
    dimension equals `self.padded_batch_size`.

    Args:
      *dataset_batch_elements: One or more batch elements, each a tensor or a
        nested structure of tensors.

    Returns:
      The padded element when exactly one element is given, otherwise a
      tuple of the padded elements in their original order.

    Raises:
      ValueError: If a tensor is a scalar and therefore has no batch
        dimension to pad.
    """

    def _pad_tensor(batch):
      """Zero-pads one tensor along axis 0 up to the full batch size."""
      rank = len(batch.shape)
      if rank == 0:
        raise ValueError(
            'Cannot pad a scalar tensor along the batch dimension.')
      missing_count = (self.padded_batch_size -
                       self.get_real_batch_size(batch))
      # Pad only the trailing end of axis 0; other axes are left untouched.
      padding = backend.stack([[0, missing_count]] + [[0, 0]] * (rank - 1))
      return tf.pad(batch, padding, 'constant')

    if len(dataset_batch_elements) == 1:
      return tf.nest.map_structure(_pad_tensor, dataset_batch_elements[0])
    return tuple(
        tf.nest.map_structure(_pad_tensor, element)
        for element in dataset_batch_elements)

  def apply_mask(self, prediction_result):
    """Removes prediction output that corresponds to padded input.

    Args:
      prediction_result: If `self.output_shape` has length 1, an array-like
        whose axis 0 runs over the padded examples. Otherwise, a sequence
        indexed by output, where each entry is such an array-like.

    Returns:
      For a single output, the predictions with padded rows removed along
      axis 0. Otherwise, a list with one such array per output. Only the
      batch axis is filtered, so all other dimensions are preserved in both
      cases.
    """
    padding_mask = backend.get_value(self.padding_mask)
    assert len(padding_mask.shape) == 1

    def _drop_padding(prediction):
      # `np.flatnonzero` yields a 1-D index array, so `np.take` filters only
      # axis 0. No other size-1 axis is collapsed, such as a trailing output
      # dimension of 1 or the batch axis when a single example is real.
      keep = np.flatnonzero(padding_mask[:len(prediction)])
      return np.take(prediction, keep, axis=0)

    if len(self.output_shape) == 1:
      return _drop_padding(prediction_result)
    return [
        _drop_padding(prediction_result[i])
        for i in range(len(self.output_shape))
    ]
Classes
class PartialBatchPaddingHandler (output_shape)
-
A container that holds info about partial batches for `predict()`.
Expand source code
class PartialBatchPaddingHandler(object):
  """A container that holds info about partial batches for `predict()`.

  Every batch is padded up to `padded_batch_size`. A 1-D mask records which
  rows are real (1) and which are padding (0), and `apply_mask` uses that
  mask to strip the padded rows from the prediction output.
  """

  def __init__(self, output_shape):
    # Full batch size that `update_mask` and `pad_batch` pad up to. Starts at
    # 0; presumably assigned by the caller before those methods run - confirm.
    self.padded_batch_size = 0
    # 1-D mask over every example seen so far. The entries produced by
    # `update_mask` are 1 for real rows and 0 for padded rows.
    self.padding_mask = tf.zeros(0)
    # Only its length is read (in `apply_mask`), to tell single-output
    # predictions from multi-output ones.
    self.output_shape = output_shape

  def get_real_batch_size(self, dataset_batch):
    """Returns the number of elements in a potentially partial batch.

    Args:
      dataset_batch: A tensor, a nested structure of tensors, or a tuple/list
        whose first entry is such a structure.

    Returns:
      The size of the leading dimension of the first tensor found, cast to
      int64.

    Raises:
      ValueError: If no tensor can be found in `dataset_batch`. This includes
        an empty structure.
    """
    if isinstance(dataset_batch, (tuple, list)):
      # Only the first entry is inspected; the remaining entries are assumed
      # to have the same batch size.
      dataset_batch = dataset_batch[0]

    # An empty structure flattens to no leaves, so it lands in the same
    # explicit error as a structure without tensors.
    first_tensor = next(
        (x for x in tf.nest.flatten(dataset_batch) if tf.is_tensor(x)), None)
    if first_tensor is None:
      raise ValueError('Cannot find any Tensor in features dict.')
    return backend.cast(backend.shape(first_tensor)[0], dtype='int64')

  def update_mask(self, padding_mask, dataset_batch):
    """Extends `padding_mask` with the mask entries for one batch.

    The mask is returned, not stored on `self`.

    Args:
      padding_mask: 1-D mask accumulated over the previous batches.
      dataset_batch: A batch element whose real size is measured with
        `get_real_batch_size`.

    Returns:
      `padding_mask` followed by a 1 for each real example in `dataset_batch`
      and a 0 for each slot needed to pad it up to `self.padded_batch_size`.
    """
    real_count = self.get_real_batch_size(dataset_batch)
    padding_count = self.padded_batch_size - real_count
    batch_mask = backend.concatenate(
        [tf.ones(real_count), tf.zeros(padding_count)], axis=0)
    return backend.concatenate([padding_mask, batch_mask], axis=0)

  def pad_batch(self, *dataset_batch_elements):
    """Pads out the batch dimension of a tensor to the complete batch size.

    Every tensor in each element, including tensors nested inside dicts,
    tuples or lists, is zero-padded at the end of axis 0 so that its leading
    dimension equals `self.padded_batch_size`.

    Args:
      *dataset_batch_elements: One or more batch elements, each a tensor or a
        nested structure of tensors.

    Returns:
      The padded element when exactly one element is given, otherwise a
      tuple of the padded elements in their original order.

    Raises:
      ValueError: If a tensor is a scalar and therefore has no batch
        dimension to pad.
    """

    def _pad_tensor(batch):
      """Zero-pads one tensor along axis 0 up to the full batch size."""
      rank = len(batch.shape)
      if rank == 0:
        raise ValueError(
            'Cannot pad a scalar tensor along the batch dimension.')
      missing_count = (self.padded_batch_size -
                       self.get_real_batch_size(batch))
      # Pad only the trailing end of axis 0; other axes are left untouched.
      padding = backend.stack([[0, missing_count]] + [[0, 0]] * (rank - 1))
      return tf.pad(batch, padding, 'constant')

    if len(dataset_batch_elements) == 1:
      return tf.nest.map_structure(_pad_tensor, dataset_batch_elements[0])
    return tuple(
        tf.nest.map_structure(_pad_tensor, element)
        for element in dataset_batch_elements)

  def apply_mask(self, prediction_result):
    """Removes prediction output that corresponds to padded input.

    Args:
      prediction_result: If `self.output_shape` has length 1, an array-like
        whose axis 0 runs over the padded examples. Otherwise, a sequence
        indexed by output, where each entry is such an array-like.

    Returns:
      For a single output, the predictions with padded rows removed along
      axis 0. Otherwise, a list with one such array per output. Only the
      batch axis is filtered, so all other dimensions are preserved in both
      cases.
    """
    padding_mask = backend.get_value(self.padding_mask)
    assert len(padding_mask.shape) == 1

    def _drop_padding(prediction):
      # `np.flatnonzero` yields a 1-D index array, so `np.take` filters only
      # axis 0. No other size-1 axis is collapsed, such as a trailing output
      # dimension of 1 or the batch axis when a single example is real.
      keep = np.flatnonzero(padding_mask[:len(prediction)])
      return np.take(prediction, keep, axis=0)

    if len(self.output_shape) == 1:
      return _drop_padding(prediction_result)
    return [
        _drop_padding(prediction_result[i])
        for i in range(len(self.output_shape))
    ]
Methods
def apply_mask(self, prediction_result)
-
Removes prediction output that corresponds to padded input.
Expand source code
def apply_mask(self, prediction_result):
  """Removes prediction output that corresponds to padded input.

  Args:
    prediction_result: If `self.output_shape` has length 1, an array-like
      whose axis 0 runs over the padded examples. Otherwise, a sequence
      indexed by output, where each entry is such an array-like.

  Returns:
    For a single output, the predictions with padded rows removed along
    axis 0. Otherwise, a list with one such array per output. Only the batch
    axis is filtered, so all other dimensions are preserved in both cases.
  """
  padding_mask = backend.get_value(self.padding_mask)
  assert len(padding_mask.shape) == 1

  def _drop_padding(prediction):
    # `np.flatnonzero` yields a 1-D index array, so `np.take` filters only
    # axis 0. No other size-1 axis is collapsed, such as a trailing output
    # dimension of 1 or the batch axis when a single example is real.
    keep = np.flatnonzero(padding_mask[:len(prediction)])
    return np.take(prediction, keep, axis=0)

  if len(self.output_shape) == 1:
    return _drop_padding(prediction_result)
  return [
      _drop_padding(prediction_result[i])
      for i in range(len(self.output_shape))
  ]
def get_real_batch_size(self, dataset_batch)
-
Returns the number of elements in a potentially partial batch.
Expand source code
def get_real_batch_size(self, dataset_batch):
  """Returns the number of elements in a potentially partial batch.

  Args:
    dataset_batch: A tensor, a nested structure of tensors, or a tuple/list
      whose first entry is such a structure.

  Returns:
    The size of the leading dimension of the first tensor found, cast to
    int64.

  Raises:
    ValueError: If no tensor can be found in `dataset_batch`. This includes
      an empty structure.
  """
  if isinstance(dataset_batch, (tuple, list)):
    # Only the first entry is inspected; the remaining entries are assumed
    # to have the same batch size.
    dataset_batch = dataset_batch[0]

  # An empty structure flattens to no leaves, so it lands in the same
  # explicit error as a structure without tensors.
  first_tensor = next(
      (x for x in tf.nest.flatten(dataset_batch) if tf.is_tensor(x)), None)
  if first_tensor is None:
    raise ValueError('Cannot find any Tensor in features dict.')
  return backend.cast(backend.shape(first_tensor)[0], dtype='int64')
def pad_batch(self, *dataset_batch_elements)
-
Pads out the batch dimension of a tensor to the complete batch size.
Expand source code
def pad_batch(self, *dataset_batch_elements):
  """Pads out the batch dimension of a tensor to the complete batch size.

  Every tensor in each element, including tensors nested inside dicts,
  tuples or lists, is zero-padded at the end of axis 0 so that its leading
  dimension equals `self.padded_batch_size`.

  Args:
    *dataset_batch_elements: One or more batch elements, each a tensor or a
      nested structure of tensors.

  Returns:
    The padded element when exactly one element is given, otherwise a tuple
    of the padded elements in their original order.

  Raises:
    ValueError: If a tensor is a scalar and therefore has no batch dimension
      to pad.
  """

  def _pad_tensor(batch):
    """Zero-pads one tensor along axis 0 up to the full batch size."""
    rank = len(batch.shape)
    if rank == 0:
      raise ValueError(
          'Cannot pad a scalar tensor along the batch dimension.')
    missing_count = (self.padded_batch_size -
                     self.get_real_batch_size(batch))
    # Pad only the trailing end of axis 0; other axes are left untouched.
    padding = backend.stack([[0, missing_count]] + [[0, 0]] * (rank - 1))
    return tf.pad(batch, padding, 'constant')

  if len(dataset_batch_elements) == 1:
    return tf.nest.map_structure(_pad_tensor, dataset_batch_elements[0])
  return tuple(
      tf.nest.map_structure(_pad_tensor, element)
      for element in dataset_batch_elements)
def update_mask(self, padding_mask, dataset_batch)
-
Calculate and cache the amount of padding required for a batch.
Expand source code
def update_mask(self, padding_mask, dataset_batch):
  """Extends `padding_mask` with the mask entries for one batch.

  The mask is returned, not stored on `self`.

  Args:
    padding_mask: 1-D mask accumulated over the previous batches.
    dataset_batch: A batch element whose real size is measured with
      `get_real_batch_size`.

  Returns:
    `padding_mask` followed by a 1 for each real example in `dataset_batch`
    and a 0 for each slot needed to pad it up to `self.padded_batch_size`.
  """
  real_count = self.get_real_batch_size(dataset_batch)
  padding_count = self.padded_batch_size - real_count
  batch_mask = backend.concatenate(
      [tf.ones(real_count), tf.zeros(padding_count)], axis=0)
  return backend.concatenate([padding_mask, batch_mask], axis=0)