Skip to content

Commit

Permalink
Fix up torch GPU failing test for mix up
Browse files Browse the repository at this point in the history
We need to make sure to use get any tensors places on cpu before using
them in the tensorflow backend during preprocessing.
  • Loading branch information
mattdangerw committed Dec 19, 2024
1 parent 9a3e173 commit 1926b8d
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions keras/src/layers/preprocessing/image_preprocessing/mix_up.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
Expand Down Expand Up @@ -114,6 +115,8 @@ def _mix_up_bounding_boxes(bounding_boxes, transformation):
self.backend.set_backend("tensorflow")

permutation_order = transformation["permutation_order"]
# Make sure we are on cpu for torch tensors.
permutation_order = ops.convert_to_numpy(permutation_order)

boxes, labels = bounding_boxes["boxes"], bounding_boxes["labels"]
boxes_for_mix_up = self.backend.numpy.take(
Expand Down Expand Up @@ -146,6 +149,8 @@ def transform_segmentation_masks(
):
def _mix_up_segmentation_masks(segmentation_masks, transformation):
mix_weight = transformation["mix_weight"]
# Make sure we are on cpu for torch tensors.
mix_weight = ops.convert_to_numpy(mix_weight)
permutation_order = transformation["permutation_order"]
mix_weight = self.backend.numpy.reshape(mix_weight, [-1, 1, 1, 1])
segmentation_masks_for_mix_up = self.backend.numpy.take(
Expand Down

0 comments on commit 1926b8d

Please sign in to comment.