Support torch2onnx for maskformer series (#10782)
RunningLeon authored Aug 15, 2023
1 parent 884aad0 commit a98f36e
Showing 3 changed files with 10 additions and 15 deletions.
10 changes: 3 additions & 7 deletions mmdet/models/dense_heads/mask2former_head.py
@@ -398,10 +398,7 @@ def forward(self, x: List[Tensor],
             decoder layer. Each with shape (batch_size, num_queries, \
                 h, w).
         """
-        batch_img_metas = [
-            data_sample.metainfo for data_sample in batch_data_samples
-        ]
-        batch_size = len(batch_img_metas)
+        batch_size = x[0].shape[0]
         mask_features, multi_scale_memorys = self.pixel_decoder(x)
         # multi_scale_memorys (from low resolution to high resolution)
         decoder_inputs = []
@@ -438,9 +435,8 @@ def forward(self, x: List[Tensor],
         for i in range(self.num_transformer_decoder_layers):
             level_idx = i % self.num_transformer_feat_level
             # if a mask is all True(all background), then set it all False.
-            attn_mask[torch.where(
-                attn_mask.sum(-1) == attn_mask.shape[-1])] = False
-
+            mask_sum = (attn_mask.sum(-1) != attn_mask.shape[-1]).unsqueeze(-1)
+            attn_mask = attn_mask & mask_sum
             # cross_attn + self_attn
             layer = self.transformer_decoder.layers[i]
             query_feat = layer(
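Both hunks make the head traceable for ONNX export: the batch size now comes from x[0].shape[0] instead of the Python-side meta list, and the all-True mask reset no longer uses a data-dependent indexed assignment through torch.where, which the exporter handles poorly, but a broadcasted boolean AND. A minimal sketch, with an invented toy attn_mask, checking that the old and new formulations agree:

    import torch

    # toy mask: row 0 is all True (all background), row 1 is mixed
    attn_mask = torch.tensor([[True, True, True],
                              [True, False, True]])

    # old: data-dependent indexed assignment via torch.where
    old = attn_mask.clone()
    old[torch.where(old.sum(-1) == old.shape[-1])] = False

    # new: AND with a broadcast (rows, 1) keep-mask, as in this commit
    mask_sum = (attn_mask.sum(-1) != attn_mask.shape[-1]).unsqueeze(-1)
    new = attn_mask & mask_sum

    assert torch.equal(old, new)  # the all-True row becomes all False

The broadcast & clears exactly the rows whose True-count equals the row length and leaves every other row untouched, so the numerics are unchanged while the graph stays export-friendly.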
2 changes: 1 addition & 1 deletion mmdet/models/dense_heads/maskformer_head.py
@@ -477,7 +477,7 @@ def forward(self, x: Tuple[Tensor],
         batch_img_metas = [
             data_sample.metainfo for data_sample in batch_data_samples
         ]
-        batch_size = len(batch_img_metas)
+        batch_size = x[0].shape[0]
         input_img_h, input_img_w = batch_img_metas[0]['batch_input_shape']
         padding_mask = x[-1].new_ones((batch_size, input_img_h, input_img_w),
                                       dtype=torch.float32)
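This is the same batch-size fix as above: len(batch_img_metas) is a plain Python int that tracing would freeze into the exported graph, while x[0].shape[0] is recorded as a shape op and follows the runtime input. A hedged sketch of why that enables a dynamic batch dimension (the toy module, file name, and axis label are illustrative, not from the repo):

    import torch

    class ToyHead(torch.nn.Module):
        def forward(self, x):
            # traced as a size/shape op, so the exported graph keeps the
            # batch dimension symbolic instead of hard-coding it
            batch_size = x.shape[0]
            return x.new_ones((batch_size, 8))

    torch.onnx.export(
        ToyHead(), torch.rand(2, 256, 50, 50), 'toy_head.onnx',
        input_names=['x'], dynamic_axes={'x': {0: 'batch'}})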
13 changes: 6 additions & 7 deletions mmdet/models/layers/msdeformattn_pixel_decoder.py
@@ -165,7 +165,7 @@ def forward(self, feats: List[Tensor]) -> Tuple[Tensor, Tensor]:
             level_idx = self.num_input_levels - i - 1
             feat = feats[level_idx]
             feat_projected = self.input_convs[i](feat)
-            h, w = feat.shape[-2:]
+            feat_hw = torch._shape_as_tensor(feat)[2:].to(feat.device)
 
             # no padding
             padding_mask_resized = feat.new_zeros(
@@ -177,7 +177,8 @@ def forward(self, feats: List[Tensor]) -> Tuple[Tensor, Tensor]:
             reference_points = self.point_generator.single_level_grid_priors(
                 feat.shape[-2:], level_idx, device=feat.device)
             # normalize
-            factor = feat.new_tensor([[w, h]]) * self.strides[level_idx]
+            feat_wh = feat_hw.unsqueeze(0).flip(dims=[0, 1])
+            factor = feat_wh * self.strides[level_idx]
             reference_points = reference_points / factor
 
             # shape (batch_size, c, h_i, w_i) -> (h_i * w_i, batch_size, c)
@@ -188,7 +189,7 @@ def forward(self, feats: List[Tensor]) -> Tuple[Tensor, Tensor]:
             encoder_input_list.append(feat_projected)
             padding_mask_list.append(padding_mask_resized)
             level_positional_encoding_list.append(level_pos_embed)
-            spatial_shapes.append(feat.shape[-2:])
+            spatial_shapes.append(feat_hw)
             reference_points_list.append(reference_points)
         # shape (batch_size, total_num_queries),
         # total_num_queries=sum([., h_i * w_i,.])
@@ -197,11 +198,10 @@ def forward(self, feats: List[Tensor]) -> Tuple[Tensor, Tensor]:
         encoder_inputs = torch.cat(encoder_input_list, dim=1)
         level_positional_encodings = torch.cat(
             level_positional_encoding_list, dim=1)
-        device = encoder_inputs.device
         # shape (num_encoder_levels, 2), from low
         # resolution to high resolution
-        spatial_shapes = torch.as_tensor(
-            spatial_shapes, dtype=torch.long, device=device)
+        num_queries_per_level = [e[0] * e[1] for e in spatial_shapes]
+        spatial_shapes = torch.cat(spatial_shapes).view(-1, 2)
         # shape (0, h_0*w_0, h_0*w_0+h_1*w_1, ...)
         level_start_index = torch.cat((spatial_shapes.new_zeros(
             (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
@@ -223,7 +223,6 @@ def forward(self, feats: List[Tensor]) -> Tuple[Tensor, Tensor]:
         memory = memory.permute(0, 2, 1)
 
         # from low resolution to high resolution
-        num_queries_per_level = [e[0] * e[1] for e in spatial_shapes]
         outs = torch.split(memory, num_queries_per_level, dim=-1)
         outs = [
             x.reshape(batch_size, -1, spatial_shapes[i][0],
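The pixel-decoder edits follow the same theme of keeping spatial sizes as tensors: torch._shape_as_tensor replaces feat.shape[-2:] (Python ints the tracer would bake in), the (w, h) normalization factor is built by flipping that tensor instead of feat.new_tensor([[w, h]]), and spatial_shapes is assembled with torch.cat rather than torch.as_tensor over a list of ints, which would constant-fold at export time. A small sketch of the new helpers with invented feature sizes:

    import torch

    feat = torch.rand(1, 256, 16, 24)  # hypothetical (B, C, H, W) level

    # full shape as an int64 tensor; [2:] keeps (H, W) as traced values
    feat_hw = torch._shape_as_tensor(feat)[2:].to(feat.device)
    print(feat_hw)         # tensor([16, 24])

    # (h, w) -> a (1, 2) tensor holding (w, h); flipping dim 0 of a
    # size-1 dim is a no-op, flipping dim 1 reverses the pair
    feat_wh = feat_hw.unsqueeze(0).flip(dims=[0, 1])
    print(feat_wh)         # tensor([[24, 16]])

    # stack per-level shape tensors into (num_levels, 2), mirroring
    # torch.cat(spatial_shapes).view(-1, 2); the (8, 12) level is invented
    spatial_shapes = torch.cat([feat_hw, torch.tensor([8, 12])]).view(-1, 2)
    print(spatial_shapes)  # tensor([[16, 24], [ 8, 12]])

Note that torch._shape_as_tensor is a private PyTorch helper: it exists in current releases but is not part of the stable API.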
