Skip to content

Commit

Permalink
Fix RetinaNet shape issue (#2119)
Browse files Browse the repository at this point in the history
  • Loading branch information
tirthasheshpatel authored Nov 3, 2023
1 parent 2ac0943 commit 937e163
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _encode_sample(self, box_labels, anchor_boxes, image_shape):
gt_boxes, matched_gt_idx
)
matched_gt_boxes = ops.reshape(
matched_gt_boxes, (-1, matched_gt_boxes.shape[1], 4)
matched_gt_boxes, (-1, ops.shape(matched_gt_boxes)[1], 4)
)

box_target = bounding_box._encode_box_to_deltas(
Expand Down Expand Up @@ -166,7 +166,7 @@ def _encode_sample(self, box_labels, anchor_boxes, image_shape):
box_shape = ops.shape(gt_boxes)
batch_size = box_shape[0]
n_boxes = box_shape[1]
box_ids = ops.arange(gt_boxes.shape[1], dtype=matched_gt_idx.dtype)
box_ids = ops.arange(n_boxes, dtype=matched_gt_idx.dtype)
matched_ids = ops.expand_dims(matched_gt_idx, axis=-1)
matches = box_ids == matched_ids
matches = ops.any(matches, axis=1)
Expand Down Expand Up @@ -197,7 +197,8 @@ def call(self, images, box_labels):
f"Received `type(images)={type(images)}`."
)

image_shape = tuple(images[0].shape)
image_shape = ops.shape(images)
image_shape = (image_shape[1], image_shape[2], image_shape[3])
box_labels = bounding_box.to_dense(box_labels)
if len(box_labels["classes"].shape) == 2:
box_labels["classes"] = ops.expand_dims(
Expand All @@ -215,7 +216,7 @@ def call(self, images, box_labels):
result = self._encode_sample(box_labels, anchor_boxes, image_shape)
encoded_box_targets = result["boxes"]
encoded_box_targets = ops.reshape(
encoded_box_targets, (-1, encoded_box_targets.shape[1], 4)
encoded_box_targets, (-1, ops.shape(encoded_box_targets)[1], 4)
)
class_targets = result["classes"]
return encoded_box_targets, class_targets
Expand Down
44 changes: 44 additions & 0 deletions keras_cv/models/object_detection/retinanet/retinanet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from absl.testing import parameterized

import keras_cv
from keras_cv import backend
from keras_cv.backend import keras
from keras_cv.backend import ops
from keras_cv.models.backbones.test_backbone_presets import (
Expand Down Expand Up @@ -249,6 +250,49 @@ def test_call_with_custom_label_encoder(self):
)
model(ops.ones(shape=(2, 224, 224, 3)))

def test_tf_dataset_data_generator(self):
if backend.multi_backend() and keras.backend.backend() != "tensorflow":
pytest.skip("TensorFlow required for `tf.data.Dataset` test.")

def data_generator():
image = tf.ones((512, 512, 3), dtype=tf.float32)

bounding_boxes = {
"boxes": tf.ones((3, 4), dtype=tf.float32),
"classes": tf.ones((3,), dtype=tf.float32),
}

yield {"images": image, "bounding_boxes": bounding_boxes}

data = tf.data.Dataset.from_generator(
generator=data_generator,
output_signature={
"images": tf.TensorSpec(shape=(512, 512, 3), dtype=tf.float32),
"bounding_boxes": {
"boxes": tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
"classes": tf.TensorSpec(shape=(None,), dtype=tf.float32),
},
},
).batch(1)

model = keras_cv.models.RetinaNet(
num_classes=2,
bounding_box_format="xyxy",
backbone=keras_cv.models.ResNet50Backbone.from_preset(
"resnet50_imagenet",
load_weights=False,
),
)

model.compile(
classification_loss="focal",
box_loss="smoothl1",
optimizer="adam",
jit_compile=False,
)

model.fit(data, epochs=1, batch_size=1, steps_per_epoch=1)


@pytest.mark.large
class RetinaNetSmokeTest(TestCase):
Expand Down

0 comments on commit 937e163

Please sign in to comment.