Fix dtype support for SegmentAnythingModel #2207

Merged
14 changes: 8 additions & 6 deletions keras_cv/models/segmentation/segment_anything/sam_layers.py
@@ -99,7 +99,7 @@ def call(self, query, value, key):
         # Attention
         C_PH = ops.shape(query)[-1]
         out = query @ ops.transpose(key, (0, 1, 3, 2))
-        out = out / ops.sqrt(ops.cast(C_PH, dtype=self.dtype))
+        out = out / ops.sqrt(ops.cast(C_PH, dtype=self.compute_dtype))
         out = ops.softmax(out, axis=-1)

         # Get output
@@ -278,7 +278,7 @@ def __init__(self, num_positional_features, scale, **kwargs):
         self.positional_encoding_gaussian_matrix = self.add_weight(
             name="positional_encoding_gaussian_matrix",
             shape=(2, self.num_positional_features),
-            dtype=self.dtype,
+            dtype=self.variable_dtype,
             trainable=False,
             initializer=keras.initializers.get("normal"),
         )
@@ -288,7 +288,9 @@ def build(self, input_shape=None):

     def __positional_encodings(self, coords):
         coords = coords * 2 - 1
-        coords = coords @ self.positional_encoding_gaussian_matrix
+        coords = coords @ ops.cast(
+            self.positional_encoding_gaussian_matrix, dtype=self.compute_dtype
+        )
         coords = coords * (2 * math.pi)
         return ops.concatenate([ops.sin(coords), ops.cos(coords)], axis=-1)

@@ -305,11 +307,11 @@ def encode_image(self, size):
             tensor: Positional encoding of the image.
         """
         H, W = size
-        grid = ops.ones(shape=(H, W), dtype=self.dtype)
+        grid = ops.ones(shape=(H, W), dtype=self.compute_dtype)
         y_embed = ops.cumsum(grid, axis=0) - 0.5
         x_embed = ops.cumsum(grid, axis=1) - 0.5
-        y_embed = y_embed / ops.cast(H, self.dtype)
-        x_embed = x_embed / ops.cast(W, self.dtype)
+        y_embed = y_embed / ops.cast(H, self.compute_dtype)
+        x_embed = x_embed / ops.cast(W, self.compute_dtype)
         return self.__positional_encodings(
             ops.stack([x_embed, y_embed], axis=-1)
         )
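The changes above follow Keras' standard mixed-precision split: under a policy such as "mixed_float16", a layer's weights are created in variable_dtype (float32), while activations and intermediate math run in compute_dtype (float16), so weights and other non-autocast values must be cast explicitly before mixing them with activations. Below is a minimal sketch of that pattern, assuming Keras 3-style imports; the ScaledShift layer is made up purely for illustration and is not part of KerasCV.

import keras
from keras import ops


class ScaledShift(keras.layers.Layer):
    # Hypothetical layer for illustration only: one non-trainable weight
    # stored in variable_dtype, with the math done in compute_dtype.
    def build(self, input_shape):
        self.shift = self.add_weight(
            name="shift",
            shape=(input_shape[-1],),
            # Weight storage follows the policy's variable dtype
            # (float32 under "mixed_float16").
            dtype=self.variable_dtype,
            initializer="zeros",
            trainable=False,
        )

    def call(self, x):
        # Inputs are autocast to compute_dtype (float16 under
        # "mixed_float16"); cast the weight to match before adding it.
        scale = ops.cast(ops.shape(x)[-1], dtype=self.compute_dtype)
        return x / ops.sqrt(scale) + ops.cast(self.shift, self.compute_dtype)


keras.mixed_precision.set_global_policy("mixed_float16")
layer = ScaledShift()
y = layer(ops.ones((2, 4)))
print(layer.shift.dtype, y.dtype)  # float32 float16
keras.mixed_precision.set_global_policy("float32")  # restore the default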
98 changes: 59 additions & 39 deletions keras_cv/models/segmentation/segment_anything/sam_test.py
@@ -220,48 +220,68 @@ def test_mask_decoder(self):
         self.assertEqual(num_parameters, 4_058_340)

     @pytest.mark.large
-    def test_end_to_end_model_predict(self):
-        model = SegmentAnythingModel(
-            backbone=self.image_encoder,
-            prompt_encoder=self.prompt_encoder,
-            mask_decoder=self.mask_decoder,
-        )
-
-        # We use box-only prompting for this test.
-        mask_prompts = self.get_prompts(1, "boxes")
-        inputs = {
-            "images": np.ones((1, 1024, 1024, 3)),
-        }
-        inputs.update(mask_prompts)
-
-        # Check the number of parameters
-        num_parameters = np.sum([np.prod(x.shape) for x in model.weights])
-        self.assertEqual(num_parameters, 89_670_912 + 6_476 + 4_058_340)
+    @parameterized.named_parameters(
Review comments on this line:

Member: Why is this test marked as large? (Just for my own learning.)

Contributor (PR author): The model initialized here is a ViT Base model with 130M parameters. Creating and evaluating it takes about 15-20 seconds, which is significantly more than the small unit tests in KerasCV.

Member: Gotcha, thanks! No need on this PR, but in general it will be good to separate small checks (like dtype stuff) into fast-running tests, and keep the large tests only for the things that are inherently slow due to large parameter counts (like preset tests). Did a big rewrite of KerasNLP backbones to this effect a bit ago, e.g. https://github.com/keras-team/keras-nlp/blob/a05f411a27eab437e71a1651c97e9addf26298ef/keras_nlp/models/bert/bert_backbone_test.py#L38-L80

+        [
+            ("float32", "float32"),
+            ("mixed_float16", "mixed_float16"),
+            ("bfloat16", "bfloat16"),
+        ]
+    )
+    def test_end_to_end_model_predict(self, dtype_policy):
+        import threading
+
+        with threading.Lock():
Review comments on this line:

Member: What's with this? Are we ever running our CV testing multi-processed?

@tirthasheshpatel (PR author), Dec 1, 2023: It can be multi-processed with the -n <num_threads> argument in pytest. pytest uses multi-processing rather than multi-threading, so locking should not be necessary here; I just added it as a safeguard in case anyone ever tries to run these tests using Python threads.

Member: Long term, we could move towards Model(dtype=policy) support, so that these tests can run effectively without mutating global state.

-        # Forward pass through the model
-        outputs = model.predict(inputs)
-        masks, iou_pred = outputs["masks"], outputs["iou_pred"]
-
-        # Check the output is equal to the one we expect if we
-        # run each component separately. This is to confirm that
-        # the graph is getting compiled correctly i.e. the jitted
-        # execution is equivalent to the eager execution.
-        features = self.image_encoder(inputs["images"])
-        outputs_ex = self.prompt_encoder(
-            {k: v for k, v in inputs.items() if k != "images"}
-        )
-        outputs_ex = self.mask_decoder(
-            {
-                "image_embeddings": features,
-                "image_pe": outputs_ex["dense_positional_embeddings"],
-                "sparse_prompt_embeddings": outputs_ex["sparse_embeddings"],
-                "dense_prompt_embeddings": outputs_ex["dense_embeddings"],
-            },
-        )
-        masks_ex, iou_pred_ex = outputs_ex["masks"], outputs_ex["iou_pred"]
-
-        self.assertAllClose(masks, masks_ex, atol=1e-4)
-        self.assertAllClose(iou_pred, iou_pred_ex, atol=1e-4)
+            # We are changing the global dtype policy here but don't want any
+            # other tests to use that policy, so compute under a lock until
+            # we reset the global policy.
+            old_policy = getattr(
+                keras.mixed_precision, "dtype_policy", lambda: "float32"
+            )()
+            keras.mixed_precision.set_global_policy(dtype_policy)
+            model = SegmentAnythingModel(
+                backbone=self.image_encoder,
+                prompt_encoder=self.prompt_encoder,
+                mask_decoder=self.mask_decoder,
+            )
+
+            # We use box-only prompting for this test.
+            mask_prompts = self.get_prompts(1, "boxes")
+            inputs = {
+                "images": np.ones((1, 1024, 1024, 3)),
+            }
+            inputs.update(mask_prompts)
+
+            # Check the number of parameters
+            num_parameters = np.sum([np.prod(x.shape) for x in model.weights])
+            self.assertEqual(num_parameters, 89_670_912 + 6_476 + 4_058_340)
+
+            # Forward pass through the model
+            outputs = model.predict(inputs)
+            masks, iou_pred = outputs["masks"], outputs["iou_pred"]
+
+            # Check the output is equal to the one we expect if we
+            # run each component separately. This is to confirm that
+            # the graph is getting compiled correctly i.e. the jitted
+            # execution is equivalent to the eager execution.
+            features = self.image_encoder(inputs["images"])
+            outputs_ex = self.prompt_encoder(
+                {k: v for k, v in inputs.items() if k != "images"}
+            )
+            outputs_ex = self.mask_decoder(
+                {
+                    "image_embeddings": features,
+                    "image_pe": outputs_ex["dense_positional_embeddings"],
+                    "sparse_prompt_embeddings": outputs_ex["sparse_embeddings"],
+                    "dense_prompt_embeddings": outputs_ex["dense_embeddings"],
+                },
+            )
+            masks_ex, iou_pred_ex = outputs_ex["masks"], outputs_ex["iou_pred"]
+            self.assertAllClose(masks, masks_ex, atol=1e-4)
+            self.assertAllClose(iou_pred, iou_pred_ex, atol=1e-4)
+
+            # Reset the global policy
+            keras.mixed_precision.set_dtype_policy(old_policy)

     @pytest.mark.extra_large
     def test_end_to_end_model_save(self):
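Following up on the thread about the threading lock and the reviewer's Model(dtype=policy) suggestion: until per-model dtype policies remove the need to mutate global state, one way to make the reset robust is to restore the policy in a context manager, so a failing assertion cannot leak a mixed-precision policy into other tests. This is a sketch of that idea, not code from this PR; it reuses the same keras.mixed_precision calls the test above already relies on.

import contextlib

import keras


@contextlib.contextmanager
def global_dtype_policy(policy_name):
    # Temporarily switch the global Keras dtype policy, then restore it.
    # Mirrors the PR's getattr fallback for Keras versions that don't
    # expose keras.mixed_precision.dtype_policy().
    old_policy = getattr(
        keras.mixed_precision, "dtype_policy", lambda: "float32"
    )()
    keras.mixed_precision.set_global_policy(policy_name)
    try:
        yield
    finally:
        # Runs even if the test body raises, so the global policy always
        # returns to its previous value.
        keras.mixed_precision.set_global_policy(old_policy)


# Hypothetical usage inside the parameterized test body:
#     with global_dtype_policy(dtype_policy):
#         model = SegmentAnythingModel(...)
#         outputs = model.predict(inputs)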