Skip to content

Commit

Permalink
Fix dtype support for SegmentAnythingModel (#2207)
Browse files Browse the repository at this point in the history
* Fix dtype support for SAM

* Update keras_cv/models/segmentation/segment_anything/sam_test.py

* Fix Keras 2 failures

* Fix F401 lint error; remove unused import
  • Loading branch information
tirthasheshpatel authored Dec 4, 2023
1 parent 431e97c commit 37ffac0
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 54 deletions.
16 changes: 7 additions & 9 deletions keras_cv/layers/vit_det_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import keras
from keras_cv.backend import ops
Expand Down Expand Up @@ -123,16 +121,16 @@ def _get_rel_pos(self, query_size, key_size, rel_pos):
return rel_pos_resized
else:
rel_pos_resized = rel_pos
query_coordinates = np.arange(query_size, dtype="float32")[:, None] * (
max(key_size / query_size, 1.0)
)
key_coordinates = np.arange(key_size, dtype="float32")[None, :] * (
max(query_size / key_size, 1.0)
)
query_coordinates = ops.cast(
ops.arange(query_size), dtype=self.compute_dtype
)[:, None] * (max(key_size / query_size, 1.0))
key_coordinates = ops.cast(
ops.arange(key_size), dtype=self.compute_dtype
)[None, :] * (max(query_size / key_size, 1.0))
relative_coordinates = (query_coordinates - key_coordinates) + (
key_size - 1
) * max(query_size / key_size, 1.0)
relative_coordinates = relative_coordinates.astype("int32")
relative_coordinates = ops.cast(relative_coordinates, dtype="int32")
return ops.take(rel_pos_resized, relative_coordinates, 0)

def call(self, attention_map, queries, query_size, key_size):
Expand Down
14 changes: 8 additions & 6 deletions keras_cv/models/segmentation/segment_anything/sam_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def call(self, query, value, key):
# Attention
C_PH = ops.shape(query)[-1]
out = query @ ops.transpose(key, (0, 1, 3, 2))
out = out / ops.sqrt(ops.cast(C_PH, dtype=self.dtype))
out = out / ops.sqrt(ops.cast(C_PH, dtype=self.compute_dtype))
out = ops.softmax(out, axis=-1)

# Get output
Expand Down Expand Up @@ -278,7 +278,7 @@ def __init__(self, num_positional_features, scale, **kwargs):
self.positional_encoding_gaussian_matrix = self.add_weight(
name="positional_encoding_gaussian_matrix",
shape=(2, self.num_positional_features),
dtype=self.dtype,
dtype=self.variable_dtype,
trainable=False,
initializer=keras.initializers.get("normal"),
)
Expand All @@ -288,7 +288,9 @@ def build(self, input_shape=None):

def __positional_encodings(self, coords):
coords = coords * 2 - 1
coords = coords @ self.positional_encoding_gaussian_matrix
coords = coords @ ops.cast(
self.positional_encoding_gaussian_matrix, dtype=self.compute_dtype
)
coords = coords * (2 * math.pi)
return ops.concatenate([ops.sin(coords), ops.cos(coords)], axis=-1)

Expand All @@ -305,11 +307,11 @@ def encode_image(self, size):
tensor: Positional encoding of the image.
"""
H, W = size
grid = ops.ones(shape=(H, W), dtype=self.dtype)
grid = ops.ones(shape=(H, W), dtype=self.compute_dtype)
y_embed = ops.cumsum(grid, axis=0) - 0.5
x_embed = ops.cumsum(grid, axis=1) - 0.5
y_embed = y_embed / ops.cast(H, self.dtype)
x_embed = x_embed / ops.cast(W, self.dtype)
y_embed = y_embed / ops.cast(H, self.compute_dtype)
x_embed = x_embed / ops.cast(W, self.compute_dtype)
return self.__positional_encodings(
ops.stack([x_embed, y_embed], axis=-1)
)
Expand Down
98 changes: 59 additions & 39 deletions keras_cv/models/segmentation/segment_anything/sam_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,48 +220,68 @@ def test_mask_decoder(self):
self.assertEqual(num_parameters, 4_058_340)

@pytest.mark.large
def test_end_to_end_model_predict(self):
model = SegmentAnythingModel(
backbone=self.image_encoder,
prompt_encoder=self.prompt_encoder,
mask_decoder=self.mask_decoder,
)

# We use box-only prompting for this test.
mask_prompts = self.get_prompts(1, "boxes")
inputs = {
"images": np.ones((1, 1024, 1024, 3)),
}
inputs.update(mask_prompts)

# Check the number of parameters
num_parameters = np.sum([np.prod(x.shape) for x in model.weights])
self.assertEqual(num_parameters, 89_670_912 + 6_476 + 4_058_340)
@parameterized.named_parameters(
[
("float32", "float32"),
("mixed_float16", "mixed_float16"),
("bfloat16", "bfloat16"),
]
)
def test_end_to_end_model_predict(self, dtype_policy):
import threading

with threading.Lock():
# We are changing the global dtype policy here but don't want any
# other tests to use that policy, so compute under a lock until
# we reset the global policy.
old_policy = getattr(
keras.mixed_precision, "dtype_policy", lambda: "float32"
)()
keras.mixed_precision.set_global_policy(dtype_policy)
model = SegmentAnythingModel(
backbone=self.image_encoder,
prompt_encoder=self.prompt_encoder,
mask_decoder=self.mask_decoder,
)

# Forward pass through the model
outputs = model.predict(inputs)
masks, iou_pred = outputs["masks"], outputs["iou_pred"]
# We use box-only prompting for this test.
mask_prompts = self.get_prompts(1, "boxes")
inputs = {
"images": np.ones((1, 1024, 1024, 3)),
}
inputs.update(mask_prompts)

# Check the number of parameters
num_parameters = np.sum([np.prod(x.shape) for x in model.weights])
self.assertEqual(num_parameters, 89_670_912 + 6_476 + 4_058_340)

# Forward pass through the model
outputs = model.predict(inputs)
masks, iou_pred = outputs["masks"], outputs["iou_pred"]

# Check the output is equal to the one we expect if we
# run each component separately. This is to confirm that
# the graph is getting compiled correctly i.e. the jitted
# execution is equivalent to the eager execution.
features = self.image_encoder(inputs["images"])
outputs_ex = self.prompt_encoder(
{k: v for k, v in inputs.items() if k != "images"}
)
outputs_ex = self.mask_decoder(
{
"image_embeddings": features,
"image_pe": outputs_ex["dense_positional_embeddings"],
"sparse_prompt_embeddings": outputs_ex["sparse_embeddings"],
"dense_prompt_embeddings": outputs_ex["dense_embeddings"],
},
)
masks_ex, iou_pred_ex = outputs_ex["masks"], outputs_ex["iou_pred"]

# Check the output is equal to the one we expect if we
# run each component separately. This is to confirm that
# the graph is getting compiled correctly i.e. the jitted
# execution is equivalent to the eager execution.
features = self.image_encoder(inputs["images"])
outputs_ex = self.prompt_encoder(
{k: v for k, v in inputs.items() if k != "images"}
)
outputs_ex = self.mask_decoder(
{
"image_embeddings": features,
"image_pe": outputs_ex["dense_positional_embeddings"],
"sparse_prompt_embeddings": outputs_ex["sparse_embeddings"],
"dense_prompt_embeddings": outputs_ex["dense_embeddings"],
},
)
masks_ex, iou_pred_ex = outputs_ex["masks"], outputs_ex["iou_pred"]
self.assertAllClose(masks, masks_ex, atol=1e-4)
self.assertAllClose(iou_pred, iou_pred_ex, atol=1e-4)

self.assertAllClose(masks, masks_ex, atol=1e-4)
self.assertAllClose(iou_pred, iou_pred_ex, atol=1e-4)
# Reset the global policy
keras.mixed_precision.set_global_policy(old_policy)

@pytest.mark.extra_large
def test_end_to_end_model_save(self):
Expand Down

0 comments on commit 37ffac0

Please sign in to comment.