Module keras.layers.preprocessing.preprocessing_test_utils
Tests for Keras' base preprocessing layer.
Expand source code
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras' base preprocessing layer."""
import tensorflow.compat.v2 as tf
import collections
import numpy as np
class PreprocessingLayerTest(tf.test.TestCase):
  """Base test class for preprocessing layer API validation."""

  # TODO(b/137303934): Consider incorporating something like this Close vs All
  # behavior into core tf.test.TestCase.
  def assertAllCloseOrEqual(self, a, b, msg=None):
    """Asserts that elements are close (if numeric) or equal (if string)."""
    if a is None or b is None:
      # A missing side can only be compared for equality.
      self.assertAllEqual(a, b, msg=msg)
    elif isinstance(a, (list, tuple)):
      # Compare sequences element-wise after checking lengths agree.
      self.assertEqual(len(a), len(b))
      for left, right in zip(a, b):
        self.assertAllCloseOrEqual(left, right, msg=msg)
    elif isinstance(a, collections.abc.Mapping):
      # Compare dict-likes key-by-key, tagging failures with the key name.
      self.assertEqual(len(a), len(b))
      for key, left in a.items():
        right = b[key]
        if msg:
          error_message = "{} ({})".format(msg, key)
        else:
          error_message = None
        self.assertAllCloseOrEqual(left, right, error_message)
    elif isinstance(a, float) or (hasattr(a, "dtype") and
                                  np.issubdtype(a.dtype, np.number)):
      # Numeric leaves are compared with floating-point tolerance.
      self.assertAllClose(a, b, msg=msg)
    else:
      self.assertAllEqual(a, b, msg=msg)

  def assert_extracted_output_equal(self, combiner, acc1, acc2, msg=None):
    """Asserts that two accumulators produce equivalent extracted data."""
    self.assertAllCloseOrEqual(
        combiner.extract(acc1), combiner.extract(acc2), msg=msg)

  # This is an injection seam so that tests like TextVectorizationTest can
  # define their own methods for asserting that accumulators are equal.
  compare_accumulators = assertAllCloseOrEqual

  def validate_accumulator_computation(self, combiner, data, expected):
    """Validate that various combinations of compute and merge are identical."""
    if len(data) < 4:
      raise AssertionError("Data must have at least 4 elements.")
    # Shard the input: two single-element pieces plus the remainder.
    shard_a = np.array([data[0]])
    shard_b = np.array([data[1]])
    shard_c = np.array(data[2:])

    whole = combiner.compute(data)
    merged = combiner.merge([
        combiner.compute(shard_a),
        combiner.compute(shard_b),
        combiner.compute(shard_c)
    ])
    self.compare_accumulators(
        whole, merged, msg="Sharding data should not change the data output.")

    shuffled = combiner.merge([
        combiner.compute(shard_b),
        combiner.compute(shard_c),
        combiner.compute(shard_a)
    ])
    self.compare_accumulators(
        merged,
        shuffled,
        msg="The order of merge arguments should not change the data "
        "output.")

    nested_merge = combiner.merge([
        combiner.compute(shard_b),
        combiner.merge(
            [combiner.compute(shard_c),
             combiner.compute(shard_a)])
    ])
    self.compare_accumulators(
        merged,
        nested_merge,
        msg="Nesting merge arguments should not change the data output.")

    chained = combiner.compute(
        shard_a, combiner.compute(shard_b, combiner.compute(shard_c)))
    self.compare_accumulators(
        merged,
        chained,
        msg="Nesting compute arguments should not change the data output.")

    interleaved = combiner.merge([
        combiner.compute(shard_a),
        combiner.compute(shard_b, combiner.compute(shard_c))
    ])
    self.compare_accumulators(
        merged,
        interleaved,
        msg="Mixing merge and compute calls should not change the data "
        "output.")

    singleton_merge = combiner.merge([
        combiner.merge([combiner.compute(shard_a)]),
        combiner.compute(shard_b, combiner.compute(shard_c))
    ])
    self.compare_accumulators(
        merged,
        singleton_merge,
        msg="Calling merge with a data length of 1 should not change the data "
        "output.")

    self.compare_accumulators(
        expected,
        merged,
        msg="Calculated accumulators "
        "did not match expected accumulator.")

  def validate_accumulator_extract(self, combiner, data, expected):
    """Validate the expected results of computing and extracting."""
    extracted = combiner.extract(combiner.compute(data))
    self.assertAllCloseOrEqual(expected, extracted)

  def validate_accumulator_extract_and_restore(self, combiner, data, expected):
    """Validate that the extract<->restore loop loses no data."""
    original = combiner.compute(data)
    restored = combiner.restore(combiner.extract(original))
    self.assert_extracted_output_equal(combiner, original, restored)
    self.assertAllCloseOrEqual(expected, combiner.extract(restored))

  def validate_accumulator_serialize_and_deserialize(self, combiner, data,
                                                     expected):
    """Validate that the serialize<->deserialize loop loses no data."""
    original = combiner.compute(data)
    round_tripped = combiner.deserialize(combiner.serialize(original))
    self.compare_accumulators(original, round_tripped)
    self.compare_accumulators(expected, round_tripped)

  def validate_accumulator_uniqueness(self, combiner, data):
    """Validate that every call to compute creates a unique accumulator."""
    first = combiner.compute(data)
    second = combiner.compute(data)
    # Accumulators must be distinct objects, yet hold equivalent contents.
    self.assertIsNot(first, second)
    self.compare_accumulators(first, second)
Classes
class PreprocessingLayerTest (methodName='runTest')
-
Base test class for preprocessing layer API validation.
Create an instance of the class that will use the named test method when executed. Raises a ValueError if the instance does not have a method with the specified name.
Expand source code
class PreprocessingLayerTest(tf.test.TestCase): """Base test class for preprocessing layer API validation.""" # TODO(b/137303934): Consider incorporating something like this Close vs All # behavior into core tf.test.TestCase. def assertAllCloseOrEqual(self, a, b, msg=None): """Asserts that elements are close (if numeric) or equal (if string).""" if a is None or b is None: self.assertAllEqual(a, b, msg=msg) elif isinstance(a, (list, tuple)): self.assertEqual(len(a), len(b)) for a_value, b_value in zip(a, b): self.assertAllCloseOrEqual(a_value, b_value, msg=msg) elif isinstance(a, collections.abc.Mapping): self.assertEqual(len(a), len(b)) for key, a_value in a.items(): b_value = b[key] error_message = "{} ({})".format(msg, key) if msg else None self.assertAllCloseOrEqual(a_value, b_value, error_message) elif (isinstance(a, float) or hasattr(a, "dtype") and np.issubdtype(a.dtype, np.number)): self.assertAllClose(a, b, msg=msg) else: self.assertAllEqual(a, b, msg=msg) def assert_extracted_output_equal(self, combiner, acc1, acc2, msg=None): data_1 = combiner.extract(acc1) data_2 = combiner.extract(acc2) self.assertAllCloseOrEqual(data_1, data_2, msg=msg) # This is an injection seam so that tests like TextVectorizationTest can # define their own methods for asserting that accumulators are equal. 
compare_accumulators = assertAllCloseOrEqual def validate_accumulator_computation(self, combiner, data, expected): """Validate that various combinations of compute and merge are identical.""" if len(data) < 4: raise AssertionError("Data must have at least 4 elements.") data_0 = np.array([data[0]]) data_1 = np.array([data[1]]) data_2 = np.array(data[2:]) single_compute = combiner.compute(data) all_merge = combiner.merge([ combiner.compute(data_0), combiner.compute(data_1), combiner.compute(data_2) ]) self.compare_accumulators( single_compute, all_merge, msg="Sharding data should not change the data output.") unordered_all_merge = combiner.merge([ combiner.compute(data_1), combiner.compute(data_2), combiner.compute(data_0) ]) self.compare_accumulators( all_merge, unordered_all_merge, msg="The order of merge arguments should not change the data " "output.") hierarchical_merge = combiner.merge([ combiner.compute(data_1), combiner.merge([combiner.compute(data_2), combiner.compute(data_0)]) ]) self.compare_accumulators( all_merge, hierarchical_merge, msg="Nesting merge arguments should not change the data output.") nested_compute = combiner.compute( data_0, combiner.compute(data_1, combiner.compute(data_2))) self.compare_accumulators( all_merge, nested_compute, msg="Nesting compute arguments should not change the data output.") mixed_compute = combiner.merge([ combiner.compute(data_0), combiner.compute(data_1, combiner.compute(data_2)) ]) self.compare_accumulators( all_merge, mixed_compute, msg="Mixing merge and compute calls should not change the data " "output.") single_merge = combiner.merge([ combiner.merge([combiner.compute(data_0)]), combiner.compute(data_1, combiner.compute(data_2)) ]) self.compare_accumulators( all_merge, single_merge, msg="Calling merge with a data length of 1 should not change the data " "output.") self.compare_accumulators( expected, all_merge, msg="Calculated accumulators " "did not match expected accumulator.") def 
validate_accumulator_extract(self, combiner, data, expected): """Validate that the expected results of computing and extracting.""" acc = combiner.compute(data) extracted_data = combiner.extract(acc) self.assertAllCloseOrEqual(expected, extracted_data) def validate_accumulator_extract_and_restore(self, combiner, data, expected): """Validate that the extract<->restore loop loses no data.""" acc = combiner.compute(data) extracted_data = combiner.extract(acc) restored_acc = combiner.restore(extracted_data) self.assert_extracted_output_equal(combiner, acc, restored_acc) self.assertAllCloseOrEqual(expected, combiner.extract(restored_acc)) def validate_accumulator_serialize_and_deserialize(self, combiner, data, expected): """Validate that the serialize<->deserialize loop loses no data.""" acc = combiner.compute(data) serialized_data = combiner.serialize(acc) deserialized_data = combiner.deserialize(serialized_data) self.compare_accumulators(acc, deserialized_data) self.compare_accumulators(expected, deserialized_data) def validate_accumulator_uniqueness(self, combiner, data): """Validate that every call to compute creates a unique accumulator.""" acc = combiner.compute(data) acc2 = combiner.compute(data) self.assertIsNot(acc, acc2) self.compare_accumulators(acc, acc2)
Ancestors
- tensorflow.python.framework.test_util.TensorFlowTestCase
- absl.testing.absltest.TestCase
- absl.third_party.unittest3_backport.case.TestCase
- unittest.case.TestCase
Methods
def assertAllCloseOrEqual(self, a, b, msg=None)
-
Asserts that elements are close (if numeric) or equal (if string).
Expand source code
def assertAllCloseOrEqual(self, a, b, msg=None): """Asserts that elements are close (if numeric) or equal (if string).""" if a is None or b is None: self.assertAllEqual(a, b, msg=msg) elif isinstance(a, (list, tuple)): self.assertEqual(len(a), len(b)) for a_value, b_value in zip(a, b): self.assertAllCloseOrEqual(a_value, b_value, msg=msg) elif isinstance(a, collections.abc.Mapping): self.assertEqual(len(a), len(b)) for key, a_value in a.items(): b_value = b[key] error_message = "{} ({})".format(msg, key) if msg else None self.assertAllCloseOrEqual(a_value, b_value, error_message) elif (isinstance(a, float) or hasattr(a, "dtype") and np.issubdtype(a.dtype, np.number)): self.assertAllClose(a, b, msg=msg) else: self.assertAllEqual(a, b, msg=msg)
def assert_extracted_output_equal(self, combiner, acc1, acc2, msg=None)
-
Expand source code
def assert_extracted_output_equal(self, combiner, acc1, acc2, msg=None): data_1 = combiner.extract(acc1) data_2 = combiner.extract(acc2) self.assertAllCloseOrEqual(data_1, data_2, msg=msg)
def compare_accumulators(self, a, b, msg=None)
-
Asserts that elements are close (if numeric) or equal (if string).
Expand source code
def assertAllCloseOrEqual(self, a, b, msg=None): """Asserts that elements are close (if numeric) or equal (if string).""" if a is None or b is None: self.assertAllEqual(a, b, msg=msg) elif isinstance(a, (list, tuple)): self.assertEqual(len(a), len(b)) for a_value, b_value in zip(a, b): self.assertAllCloseOrEqual(a_value, b_value, msg=msg) elif isinstance(a, collections.abc.Mapping): self.assertEqual(len(a), len(b)) for key, a_value in a.items(): b_value = b[key] error_message = "{} ({})".format(msg, key) if msg else None self.assertAllCloseOrEqual(a_value, b_value, error_message) elif (isinstance(a, float) or hasattr(a, "dtype") and np.issubdtype(a.dtype, np.number)): self.assertAllClose(a, b, msg=msg) else: self.assertAllEqual(a, b, msg=msg)
def validate_accumulator_computation(self, combiner, data, expected)
-
Validate that various combinations of compute and merge are identical.
Expand source code
def validate_accumulator_computation(self, combiner, data, expected): """Validate that various combinations of compute and merge are identical.""" if len(data) < 4: raise AssertionError("Data must have at least 4 elements.") data_0 = np.array([data[0]]) data_1 = np.array([data[1]]) data_2 = np.array(data[2:]) single_compute = combiner.compute(data) all_merge = combiner.merge([ combiner.compute(data_0), combiner.compute(data_1), combiner.compute(data_2) ]) self.compare_accumulators( single_compute, all_merge, msg="Sharding data should not change the data output.") unordered_all_merge = combiner.merge([ combiner.compute(data_1), combiner.compute(data_2), combiner.compute(data_0) ]) self.compare_accumulators( all_merge, unordered_all_merge, msg="The order of merge arguments should not change the data " "output.") hierarchical_merge = combiner.merge([ combiner.compute(data_1), combiner.merge([combiner.compute(data_2), combiner.compute(data_0)]) ]) self.compare_accumulators( all_merge, hierarchical_merge, msg="Nesting merge arguments should not change the data output.") nested_compute = combiner.compute( data_0, combiner.compute(data_1, combiner.compute(data_2))) self.compare_accumulators( all_merge, nested_compute, msg="Nesting compute arguments should not change the data output.") mixed_compute = combiner.merge([ combiner.compute(data_0), combiner.compute(data_1, combiner.compute(data_2)) ]) self.compare_accumulators( all_merge, mixed_compute, msg="Mixing merge and compute calls should not change the data " "output.") single_merge = combiner.merge([ combiner.merge([combiner.compute(data_0)]), combiner.compute(data_1, combiner.compute(data_2)) ]) self.compare_accumulators( all_merge, single_merge, msg="Calling merge with a data length of 1 should not change the data " "output.") self.compare_accumulators( expected, all_merge, msg="Calculated accumulators " "did not match expected accumulator.")
def validate_accumulator_extract(self, combiner, data, expected)
-
Validate the expected results of computing and extracting.
Expand source code
def validate_accumulator_extract(self, combiner, data, expected): """Validate that the expected results of computing and extracting.""" acc = combiner.compute(data) extracted_data = combiner.extract(acc) self.assertAllCloseOrEqual(expected, extracted_data)
def validate_accumulator_extract_and_restore(self, combiner, data, expected)
-
Validate that the extract<->restore loop loses no data.
Expand source code
def validate_accumulator_extract_and_restore(self, combiner, data, expected): """Validate that the extract<->restore loop loses no data.""" acc = combiner.compute(data) extracted_data = combiner.extract(acc) restored_acc = combiner.restore(extracted_data) self.assert_extracted_output_equal(combiner, acc, restored_acc) self.assertAllCloseOrEqual(expected, combiner.extract(restored_acc))
def validate_accumulator_serialize_and_deserialize(self, combiner, data, expected)
-
Validate that the serialize<->deserialize loop loses no data.
Expand source code
def validate_accumulator_serialize_and_deserialize(self, combiner, data, expected): """Validate that the serialize<->deserialize loop loses no data.""" acc = combiner.compute(data) serialized_data = combiner.serialize(acc) deserialized_data = combiner.deserialize(serialized_data) self.compare_accumulators(acc, deserialized_data) self.compare_accumulators(expected, deserialized_data)
def validate_accumulator_uniqueness(self, combiner, data)
-
Validate that every call to compute creates a unique accumulator.
Expand source code
def validate_accumulator_uniqueness(self, combiner, data): """Validate that every call to compute creates a unique accumulator.""" acc = combiner.compute(data) acc2 = combiner.compute(data) self.assertIsNot(acc, acc2) self.compare_accumulators(acc, acc2)