Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

convert to tensor before smart resize (#2184) #2185

Merged
merged 1 commit into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 67 additions & 3 deletions keras_cv/backend/tf_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,70 @@
from tensorflow import reduce_all as all # noqa: F403, F401
from tensorflow import reduce_max as max # noqa: F403, F401
from tensorflow import split # noqa: F403, F401
from tensorflow.keras.preprocessing.image import ( # noqa: F403, F401
smart_resize,
)

import numpy as np
import tensorflow as tf


def smart_resize(x, size, interpolation="bilinear"):
    """Center-crop images to the target aspect ratio, then resize them.

    Copied from `tf_keras` for Keras 3 and for use in `tf.data` pipeline.

    The input is first cropped around its center to the largest window that
    has the same aspect ratio as `size`, and that window is then resized to
    `size`, so the content is not stretched.

    Args:
        x: Image of shape `(height, width, channels)` or batch of images of
            shape `(batch_size, height, width, channels)`. It is converted
            with `tf.convert_to_tensor` before any processing.
        size: Sequence of two entries `(target_height, target_width)`.
        interpolation: Resize method, forwarded to `tf.image.resize`.

    Returns:
        The cropped and resized image(s). A NumPy array is returned when `x`
        is an `np.ndarray`; otherwise the result is a tensor.

    Raises:
        ValueError: If `size` does not have exactly two entries, or if the
            static rank of the input is known and is neither 3 nor 4.
    """
    if len(size) != 2:
        raise ValueError(
            f"Expected `size` to be a tuple of 2 integers, but got: {size}."
        )

    img = tf.convert_to_tensor(x)
    rank = img.shape.rank
    if rank is not None and not 3 <= rank <= 4:
        raise ValueError(
            "Expected an image array with shape `(height, width, "
            "channels)`, or `(batch_size, height, width, channels)`, but "
            f"got input with incorrect rank, of shape {img.shape}."
        )
    # Remember the static channel count so it can be restored after resize.
    static_num_channels = None if rank is None else img.shape[-1]

    def _float_div_to_int(numerator, denominator):
        # Divide in float32, then truncate back to int32.
        quotient = tf.cast(numerator, "float32") / denominator
        return tf.cast(quotient, "int32")

    dynamic_shape = tf.shape(img)
    height = dynamic_shape[-3]
    width = dynamic_shape[-2]
    target_height, target_width = size

    # Largest crop with the target aspect ratio; each side is capped at the
    # corresponding input dimension.
    crop_height = tf.minimum(
        height, _float_div_to_int(width * target_height, target_width)
    )
    crop_width = tf.minimum(
        width, _float_div_to_int(height * target_width, target_height)
    )

    # Offsets that center the crop window inside the input.
    top = _float_div_to_int(height - crop_height, 2)
    left = _float_div_to_int(width - crop_width, 2)

    begin = [top, left, 0]
    extent = [crop_height, crop_width, -1]
    if rank == 4:
        # Batched input: keep the whole batch axis.
        begin = [0] + begin
        extent = [-1] + extent

    img = tf.slice(img, tf.stack(begin), tf.stack(extent))
    img = tf.image.resize(images=img, size=size, method=interpolation)

    # Apparent bug in resize_images_v2 may cause shape to be lost
    resized_rank = img.shape.rank
    if resized_rank in (3, 4):
        img.set_shape((None,) * (resized_rank - 1) + (static_num_channels,))

    return img.numpy() if isinstance(x, np.ndarray) else img
4 changes: 2 additions & 2 deletions keras_cv/layers/preprocessing/resizing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from keras_cv import bounding_box
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import ops
from keras_cv.backend import tf_ops
from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
)
Expand Down Expand Up @@ -309,7 +309,7 @@ def _resize_with_crop(self, inputs):
def resize_with_crop_to_aspect(x, interpolation_method):
if isinstance(x, tf.RaggedTensor):
x = x.to_tensor()
return ops.smart_resize(
return tf_ops.smart_resize(
x,
size=size,
interpolation=interpolation_method,
Expand Down
Loading