diff --git a/keras_cv/models/object_detection/retinanet/retinanet_label_encoder.py b/keras_cv/models/object_detection/retinanet/retinanet_label_encoder.py index c2410d2ac1..4142f7bac5 100644 --- a/keras_cv/models/object_detection/retinanet/retinanet_label_encoder.py +++ b/keras_cv/models/object_detection/retinanet/retinanet_label_encoder.py @@ -124,7 +124,7 @@ def _encode_sample(self, box_labels, anchor_boxes, image_shape): gt_boxes, matched_gt_idx ) matched_gt_boxes = ops.reshape( - matched_gt_boxes, (-1, matched_gt_boxes.shape[1], 4) + matched_gt_boxes, (-1, ops.shape(matched_gt_boxes)[1], 4) ) box_target = bounding_box._encode_box_to_deltas( @@ -166,7 +166,7 @@ def _encode_sample(self, box_labels, anchor_boxes, image_shape): box_shape = ops.shape(gt_boxes) batch_size = box_shape[0] n_boxes = box_shape[1] - box_ids = ops.arange(gt_boxes.shape[1], dtype=matched_gt_idx.dtype) + box_ids = ops.arange(n_boxes, dtype=matched_gt_idx.dtype) matched_ids = ops.expand_dims(matched_gt_idx, axis=-1) matches = box_ids == matched_ids matches = ops.any(matches, axis=1) @@ -197,7 +197,8 @@ def call(self, images, box_labels): f"Received `type(images)={type(images)}`." ) - image_shape = tuple(images[0].shape) + image_shape = ops.shape(images) + image_shape = (image_shape[1], image_shape[2], image_shape[3]) box_labels = bounding_box.to_dense(box_labels) if len(box_labels["classes"].shape) == 2: box_labels["classes"] = ops.expand_dims( @@ -215,7 +216,7 @@ def call(self, images, box_labels): result = self._encode_sample(box_labels, anchor_boxes, image_shape) encoded_box_targets = result["boxes"] encoded_box_targets = ops.reshape( - encoded_box_targets, (-1, encoded_box_targets.shape[1], 4) + encoded_box_targets, (-1, ops.shape(encoded_box_targets)[1], 4) ) class_targets = result["classes"] return encoded_box_targets, class_targets diff --git a/keras_cv/models/object_detection/retinanet/retinanet_test.py b/keras_cv/models/object_detection/retinanet/retinanet_test.py index cc0f7c9131..45026262f4 100644 --- a/keras_cv/models/object_detection/retinanet/retinanet_test.py +++ b/keras_cv/models/object_detection/retinanet/retinanet_test.py @@ -20,6 +20,7 @@ from absl.testing import parameterized import keras_cv +from keras_cv import backend from keras_cv.backend import keras from keras_cv.backend import ops from keras_cv.models.backbones.test_backbone_presets import ( @@ -249,6 +250,49 @@ def test_call_with_custom_label_encoder(self): ) model(ops.ones(shape=(2, 224, 224, 3))) + def test_tf_dataset_data_generator(self): + if backend.multi_backend() and keras.backend.backend() != "tensorflow": + pytest.skip("TensorFlow required for `tf.data.Dataset` test.") + + def data_generator(): + image = tf.ones((512, 512, 3), dtype=tf.float32) + + bounding_boxes = { + "boxes": tf.ones((3, 4), dtype=tf.float32), + "classes": tf.ones((3,), dtype=tf.float32), + } + + yield {"images": image, "bounding_boxes": bounding_boxes} + + data = tf.data.Dataset.from_generator( + generator=data_generator, + output_signature={ + "images": tf.TensorSpec(shape=(512, 512, 3), dtype=tf.float32), + "bounding_boxes": { + "boxes": tf.TensorSpec(shape=(None, 4), dtype=tf.float32), + "classes": tf.TensorSpec(shape=(None,), dtype=tf.float32), + }, + }, + ).batch(1) + + model = keras_cv.models.RetinaNet( + num_classes=2, + bounding_box_format="xyxy", + backbone=keras_cv.models.ResNet50Backbone.from_preset( + "resnet50_imagenet", + load_weights=False, + ), + ) + + model.compile( + classification_loss="focal", + box_loss="smoothl1", + optimizer="adam", + jit_compile=False, + ) + + model.fit(data, epochs=1, batch_size=1, steps_per_epoch=1) + @pytest.mark.large class RetinaNetSmokeTest(TestCase):