From 63b7b64db27d92e6a5cbc2e6a8eec4998f3e39fb Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 22 Jul 2019 19:38:37 +0000 Subject: [PATCH] Move tf.keras to separate function in test Signed-off-by: Yong Tang --- tests/test_image.py | 48 ++++++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/tests/test_image.py b/tests/test_image.py index 57f7b4a09f..cf8655e040 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -58,30 +58,6 @@ def test_decode_webp(self): self.assertAllEqual(webp_v, png) - def test_webp_file_dataset(self): - """Test case for WebPDataset. - """ - filename = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "test_image", "sample.webp") - - num_repeats = 2 - - dataset = image_io.WebPDataset([filename, filename]) - # Repeat 2 times (2 * 2 = 4 images) - dataset = dataset.repeat(num_repeats) - # Drop alpha channel - dataset = dataset.map(lambda x: x[:, :, :3]) - # Resize to 224 * 224 - dataset = dataset.map(lambda x: tf.keras.applications.resnet50.preprocess_input(tf.image.resize(x, (224, 224)))) - # Batch to 3, still have 4 images (3 + 1) - dataset = dataset.batch(1) - model = tf.keras.applications.resnet50.ResNet50(weights='imagenet') - y = model.predict(dataset) - p = tf.keras.applications.resnet50.decode_predictions(y, top=1) - for i in p: - assert i[0][1] == 'pineapple' # not truly a pineapple, though - assert len(p) == 4 - def test_tiff_file_dataset(self): """Test case for TIFFDataset. """ @@ -198,5 +174,29 @@ def test_draw_bounding_box(self): # self.assertAllEqual(bb_image_v, ex_image_v) _ = bb_image_p.eval() +def test_webp_file_dataset(self): + """Test case for WebPDataset. + """ + filename = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "test_image", "sample.webp") + + num_repeats = 2 + + dataset = image_io.WebPDataset([filename, filename]) + # Repeat 2 times (2 * 2 = 4 images) + dataset = dataset.repeat(num_repeats) + # Drop alpha channel + dataset = dataset.map(lambda x: x[:, :, :3]) + # Resize to 224 * 224 + dataset = dataset.map(lambda x: tf.keras.applications.resnet50.preprocess_input(tf.image.resize(x, (224, 224)))) + # Batch to 3, still have 4 images (3 + 1) + dataset = dataset.batch(1) + model = tf.keras.applications.resnet50.ResNet50(weights='imagenet') + y = model.predict(dataset) + p = tf.keras.applications.resnet50.decode_predictions(y, top=1) + for i in p: + assert i[0][1] == 'pineapple' # not truly a pineapple, though + assert len(p) == 4 + if __name__ == "__main__": test.main()