Skip to content

Commit

Permalink
Move tf.keras to separate function in test
Browse files Browse the repository at this point in the history
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
  • Loading branch information
yongtang committed Jul 22, 2019
1 parent 04070cc commit 63b7b64
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,30 +58,6 @@ def test_decode_webp(self):

self.assertAllEqual(webp_v, png)

def test_webp_file_dataset(self):
"""Test case for WebPDataset.
"""
filename = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "test_image", "sample.webp")

num_repeats = 2

dataset = image_io.WebPDataset([filename, filename])
# Repeat 2 times (2 * 2 = 4 images)
dataset = dataset.repeat(num_repeats)
# Drop alpha channel
dataset = dataset.map(lambda x: x[:, :, :3])
# Resize to 224 * 224
dataset = dataset.map(lambda x: tf.keras.applications.resnet50.preprocess_input(tf.image.resize(x, (224, 224))))
# Batch to 3, still have 4 images (3 + 1)
dataset = dataset.batch(1)
model = tf.keras.applications.resnet50.ResNet50(weights='imagenet')
y = model.predict(dataset)
p = tf.keras.applications.resnet50.decode_predictions(y, top=1)
for i in p:
assert i[0][1] == 'pineapple' # not truly a pineapple, though
assert len(p) == 4

def test_tiff_file_dataset(self):
"""Test case for TIFFDataset.
"""
Expand Down Expand Up @@ -198,5 +174,29 @@ def test_draw_bounding_box(self):
# self.assertAllEqual(bb_image_v, ex_image_v)
_ = bb_image_p.eval()

def test_webp_file_dataset(self):
"""Test case for WebPDataset.
"""
filename = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "test_image", "sample.webp")

num_repeats = 2

dataset = image_io.WebPDataset([filename, filename])
# Repeat 2 times (2 * 2 = 4 images)
dataset = dataset.repeat(num_repeats)
# Drop alpha channel
dataset = dataset.map(lambda x: x[:, :, :3])
# Resize to 224 * 224
dataset = dataset.map(lambda x: tf.keras.applications.resnet50.preprocess_input(tf.image.resize(x, (224, 224))))
# Batch to 3, still have 4 images (3 + 1)
dataset = dataset.batch(1)
model = tf.keras.applications.resnet50.ResNet50(weights='imagenet')
y = model.predict(dataset)
p = tf.keras.applications.resnet50.decode_predictions(y, top=1)
for i in p:
assert i[0][1] == 'pineapple' # not truly a pineapple, though
assert len(p) == 4

if __name__ == "__main__":
test.main()

0 comments on commit 63b7b64

Please sign in to comment.