diff --git a/rtdetr_pytorch/src/zoo/rtdetr/hybrid_encoder.py b/rtdetr_pytorch/src/zoo/rtdetr/hybrid_encoder.py index 804db69c..9958e3ed 100644 --- a/rtdetr_pytorch/src/zoo/rtdetr/hybrid_encoder.py +++ b/rtdetr_pytorch/src/zoo/rtdetr/hybrid_encoder.py @@ -307,7 +307,12 @@ def forward(self, feats): feat_low = proj_feats[idx - 1] feat_high = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_high) inner_outs[0] = feat_high - upsample_feat = F.interpolate(feat_high, scale_factor=2., mode='nearest') + n_low, c_low, h_low, w_low = feat_low.shape + n_high, c_high, h_high, w_high = feat_high.shape + if h_low == 2 * h_high and w_low == 2 * w_high: + upsample_feat = F.interpolate(feat_high, scale_factor=2., mode='nearest') + else: + upsample_feat = F.interpolate(feat_high, size=(h_low, w_low), mode='bilinear') inner_out = self.fpn_blocks[len(self.in_channels)-1-idx](torch.concat([upsample_feat, feat_low], dim=1)) inner_outs.insert(0, inner_out) diff --git a/rtdetr_pytorch/src/zoo/rtdetr/rtdetr.py b/rtdetr_pytorch/src/zoo/rtdetr/rtdetr.py index 851d4f74..1bf82c4c 100644 --- a/rtdetr_pytorch/src/zoo/rtdetr/rtdetr.py +++ b/rtdetr_pytorch/src/zoo/rtdetr/rtdetr.py @@ -28,7 +28,11 @@ def __init__(self, backbone: nn.Module, encoder, decoder, multi_scale=None): def forward(self, x, targets=None): if self.multi_scale and self.training: sz = np.random.choice(self.multi_scale) - x = F.interpolate(x, size=[sz, sz]) + n, c, h, w = x.shape + if w > h: # assuming longer side matches sz + x = F.interpolate(x, size=[int(sz * h / w), int(sz)]) + else: + x = F.interpolate(x, size=[int(sz), int(sz * w / h)]) x = self.backbone(x) x = self.encoder(x)