diff --git a/keras_cv/layers/vit_det_layers.py b/keras_cv/layers/vit_det_layers.py
index 78c0b0bfb6..9311a957f5 100644
--- a/keras_cv/layers/vit_det_layers.py
+++ b/keras_cv/layers/vit_det_layers.py
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import numpy as np
-
 from keras_cv.api_export import keras_cv_export
 from keras_cv.backend import keras
 from keras_cv.backend import ops
@@ -123,16 +121,19 @@ def _get_rel_pos(self, query_size, key_size, rel_pos):
             return rel_pos_resized
         else:
             rel_pos_resized = rel_pos
-        query_coordinates = np.arange(query_size, dtype="float32")[:, None] * (
-            max(key_size / query_size, 1.0)
-        )
-        key_coordinates = np.arange(key_size, dtype="float32")[None, :] * (
-            max(query_size / key_size, 1.0)
-        )
+        # Coordinates are gather indices, so build them in float32 instead of
+        # the compute dtype; float16/bfloat16 cannot represent larger integer
+        # positions exactly, which would select the wrong embedding rows.
+        query_coordinates = ops.cast(
+            ops.arange(query_size), dtype="float32"
+        )[:, None] * (max(key_size / query_size, 1.0))
+        key_coordinates = ops.cast(
+            ops.arange(key_size), dtype="float32"
+        )[None, :] * (max(query_size / key_size, 1.0))
         relative_coordinates = (query_coordinates - key_coordinates) + (
             key_size - 1
         ) * max(query_size / key_size, 1.0)
-        relative_coordinates = relative_coordinates.astype("int32")
+        relative_coordinates = ops.cast(relative_coordinates, dtype="int32")
         return ops.take(rel_pos_resized, relative_coordinates, 0)
 
     def call(self, attention_map, queries, query_size, key_size):
diff --git a/keras_cv/models/segmentation/segment_anything/sam_layers.py b/keras_cv/models/segmentation/segment_anything/sam_layers.py
index 577031c63c..fffc4faee5 100644
--- a/keras_cv/models/segmentation/segment_anything/sam_layers.py
+++ b/keras_cv/models/segmentation/segment_anything/sam_layers.py
@@ -99,7 +99,7 @@ def call(self, query, value, key):
         # Attention
         C_PH = ops.shape(query)[-1]
         out = query @ ops.transpose(key, (0, 1, 3, 2))
-        out = out / ops.sqrt(ops.cast(C_PH, dtype=self.dtype))
+        out = out / ops.sqrt(ops.cast(C_PH, dtype=self.compute_dtype))
         out = ops.softmax(out, axis=-1)
 
         # Get output
@@ -278,7 +278,7 @@ def __init__(self, num_positional_features, scale, **kwargs):
         self.positional_encoding_gaussian_matrix = self.add_weight(
             name="positional_encoding_gaussian_matrix",
             shape=(2, self.num_positional_features),
-            dtype=self.dtype,
+            dtype=self.variable_dtype,
             trainable=False,
             initializer=keras.initializers.get("normal"),
         )
@@ -288,7 +288,9 @@ def build(self, input_shape=None):
 
     def __positional_encodings(self, coords):
         coords = coords * 2 - 1
-        coords = coords @ self.positional_encoding_gaussian_matrix
+        coords = coords @ ops.cast(
+            self.positional_encoding_gaussian_matrix, dtype=self.compute_dtype
+        )
         coords = coords * (2 * math.pi)
         return ops.concatenate([ops.sin(coords), ops.cos(coords)], axis=-1)
 
@@ -305,11 +307,11 @@ def encode_image(self, size):
             tensor: Positional encoding of the image.
         """
         H, W = size
-        grid = ops.ones(shape=(H, W), dtype=self.dtype)
+        grid = ops.ones(shape=(H, W), dtype=self.compute_dtype)
         y_embed = ops.cumsum(grid, axis=0) - 0.5
         x_embed = ops.cumsum(grid, axis=1) - 0.5
-        y_embed = y_embed / ops.cast(H, self.dtype)
-        x_embed = x_embed / ops.cast(W, self.dtype)
+        y_embed = y_embed / ops.cast(H, self.compute_dtype)
+        x_embed = x_embed / ops.cast(W, self.compute_dtype)
         return self.__positional_encodings(
             ops.stack([x_embed, y_embed], axis=-1)
         )
diff --git a/keras_cv/models/segmentation/segment_anything/sam_test.py b/keras_cv/models/segmentation/segment_anything/sam_test.py
index 3546cb906f..295355a716 100644
--- a/keras_cv/models/segmentation/segment_anything/sam_test.py
+++ b/keras_cv/models/segmentation/segment_anything/sam_test.py
@@ -220,48 +220,66 @@ def test_mask_decoder(self):
         self.assertEqual(num_parameters, 4_058_340)
 
     @pytest.mark.large
-    def test_end_to_end_model_predict(self):
-        model = SegmentAnythingModel(
-            backbone=self.image_encoder,
-            prompt_encoder=self.prompt_encoder,
-            mask_decoder=self.mask_decoder,
-        )
-
-        # We use box-only prompting for this test.
-        mask_prompts = self.get_prompts(1, "boxes")
-        inputs = {
-            "images": np.ones((1, 1024, 1024, 3)),
-        }
-        inputs.update(mask_prompts)
-
-        # Check the number of parameters
-        num_parameters = np.sum([np.prod(x.shape) for x in model.weights])
-        self.assertEqual(num_parameters, 89_670_912 + 6_476 + 4_058_340)
+    @parameterized.named_parameters(
+        [
+            ("float32", "float32"),
+            ("mixed_float16", "mixed_float16"),
+            ("bfloat16", "bfloat16"),
+        ]
+    )
+    def test_end_to_end_model_predict(self, dtype_policy):
+        # The global dtype policy is process-wide state, so it is restored in
+        # a `finally` block: a failing assertion must not leak the policy
+        # under test into the tests that run afterwards.
+        old_policy = getattr(
+            keras.mixed_precision, "dtype_policy", lambda: "float32"
+        )()
+        keras.mixed_precision.set_global_policy(dtype_policy)
+        try:
+            model = SegmentAnythingModel(
+                backbone=self.image_encoder,
+                prompt_encoder=self.prompt_encoder,
+                mask_decoder=self.mask_decoder,
+            )
 
-        # Forward pass through the model
-        outputs = model.predict(inputs)
-        masks, iou_pred = outputs["masks"], outputs["iou_pred"]
+            # We use box-only prompting for this test.
+            mask_prompts = self.get_prompts(1, "boxes")
+            inputs = {
+                "images": np.ones((1, 1024, 1024, 3)),
+            }
+            inputs.update(mask_prompts)
+
+            # Check the number of parameters
+            num_parameters = np.sum([np.prod(x.shape) for x in model.weights])
+            self.assertEqual(num_parameters, 89_670_912 + 6_476 + 4_058_340)
+
+            # Forward pass through the model
+            outputs = model.predict(inputs)
+            masks, iou_pred = outputs["masks"], outputs["iou_pred"]
+
+            # Check the output is equal to the one we expect if we
+            # run each component separately. This is to confirm that
+            # the graph is getting compiled correctly i.e. the jitted
+            # execution is equivalent to the eager execution.
+            features = self.image_encoder(inputs["images"])
+            outputs_ex = self.prompt_encoder(
+                {k: v for k, v in inputs.items() if k != "images"}
+            )
+            outputs_ex = self.mask_decoder(
+                {
+                    "image_embeddings": features,
+                    "image_pe": outputs_ex["dense_positional_embeddings"],
+                    "sparse_prompt_embeddings": outputs_ex["sparse_embeddings"],
+                    "dense_prompt_embeddings": outputs_ex["dense_embeddings"],
+                },
+            )
+            masks_ex, iou_pred_ex = outputs_ex["masks"], outputs_ex["iou_pred"]
 
-        # Check the output is equal to the one we expect if we
-        # run each component separately. This is to confirm that
-        # the graph is getting compiled correctly i.e. the jitted
-        # execution is equivalent to the eager execution.
-        features = self.image_encoder(inputs["images"])
-        outputs_ex = self.prompt_encoder(
-            {k: v for k, v in inputs.items() if k != "images"}
-        )
-        outputs_ex = self.mask_decoder(
-            {
-                "image_embeddings": features,
-                "image_pe": outputs_ex["dense_positional_embeddings"],
-                "sparse_prompt_embeddings": outputs_ex["sparse_embeddings"],
-                "dense_prompt_embeddings": outputs_ex["dense_embeddings"],
-            },
-        )
-        masks_ex, iou_pred_ex = outputs_ex["masks"], outputs_ex["iou_pred"]
+            self.assertAllClose(masks, masks_ex, atol=1e-4)
+            self.assertAllClose(iou_pred, iou_pred_ex, atol=1e-4)
+        finally:
+            # Reset the global policy
+            keras.mixed_precision.set_global_policy(old_policy)
 
-        self.assertAllClose(masks, masks_ex, atol=1e-4)
-        self.assertAllClose(iou_pred, iou_pred_ex, atol=1e-4)
-
     @pytest.mark.extra_large
     def test_end_to_end_model_save(self):