Skip to content

Commit

Permalink
Fix test errors that were introduced by upgrading to Keras Core 0.1.5 (
Browse files Browse the repository at this point in the history
…keras-team#2041)

* Fix build process for spatial pyramid pooling

* Fix label encoder for YOLOV8 for 0.1.5
  • Loading branch information
ianstenbit authored Aug 26, 2023
1 parent 6d87ffd commit 3e0e7a5
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
10 changes: 9 additions & 1 deletion keras_cv/layers/spatial_pyramid.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def build(self, input_shape):
keras.layers.Activation(self.activation),
]
)
conv_sequential.build(input_shape)
self.aspp_parallel_channels.append(conv_sequential)

# Channel 2 and afterwards are based on self.dilation_rates, and each of
Expand All @@ -109,6 +110,7 @@ def build(self, input_shape):
keras.layers.Activation(self.activation),
]
)
conv_sequential.build(input_shape)
self.aspp_parallel_channels.append(conv_sequential)

# Last channel is the global average pooling with conv2D 1x1 kernel.
Expand All @@ -125,10 +127,11 @@ def build(self, input_shape):
keras.layers.Activation(self.activation),
]
)
pool_sequential.build(input_shape)
self.aspp_parallel_channels.append(pool_sequential)

# Final projection layers
self.projection = keras.Sequential(
projection = keras.Sequential(
[
keras.layers.Conv2D(
filters=self.num_channels,
Expand All @@ -140,6 +143,11 @@ def build(self, input_shape):
keras.layers.Dropout(rate=self.dropout),
],
)
projection_input_channels = (
2 + len(self.dilation_rates)
) * self.num_channels
projection.build(tuple(input_shape[:-1]) + (projection_input_channels,))
self.projection = projection

def call(self, inputs, training=None):
"""Calls the Atrous Spatial Pyramid Pooling layer on an input.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def encode_to_targets(

# return zeros if no gt boxes are present
return ops.cond(
max_num_boxes > 0,
ops.array(max_num_boxes > 0),
lambda: encode_to_targets(
pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt
),
Expand Down

0 comments on commit 3e0e7a5

Please sign in to comment.