From 1926b8de0ba558852d2bb5e6cce20b63c1298e02 Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Wed, 18 Dec 2024 18:02:24 -0800 Subject: [PATCH] Fix up torch GPU failing test for mix up We need to make sure to use get any tensors places on cpu before using them in the tensorflow backend during preprocessing. --- keras/src/layers/preprocessing/image_preprocessing/mix_up.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/keras/src/layers/preprocessing/image_preprocessing/mix_up.py b/keras/src/layers/preprocessing/image_preprocessing/mix_up.py index 2c54127af0e..e5c49a4209f 100644 --- a/keras/src/layers/preprocessing/image_preprocessing/mix_up.py +++ b/keras/src/layers/preprocessing/image_preprocessing/mix_up.py @@ -1,3 +1,4 @@ +from keras.src import ops from keras.src.api_export import keras_export from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 BaseImagePreprocessingLayer, @@ -114,6 +115,8 @@ def _mix_up_bounding_boxes(bounding_boxes, transformation): self.backend.set_backend("tensorflow") permutation_order = transformation["permutation_order"] + # Make sure we are on cpu for torch tensors. + permutation_order = ops.convert_to_numpy(permutation_order) boxes, labels = bounding_boxes["boxes"], bounding_boxes["labels"] boxes_for_mix_up = self.backend.numpy.take( @@ -146,6 +149,8 @@ def transform_segmentation_masks( ): def _mix_up_segmentation_masks(segmentation_masks, transformation): mix_weight = transformation["mix_weight"] + # Make sure we are on cpu for torch tensors. + mix_weight = ops.convert_to_numpy(mix_weight) permutation_order = transformation["permutation_order"] mix_weight = self.backend.numpy.reshape(mix_weight, [-1, 1, 1, 1]) segmentation_masks_for_mix_up = self.backend.numpy.take(