From d51c3a85e16db7e95d5930059f2db3a189053749 Mon Sep 17 00:00:00 2001 From: dwchoo <50813484+dwchoo@users.noreply.github.com> Date: Fri, 23 Aug 2024 07:56:31 +0000 Subject: [PATCH] FIX rtdetr postprocessing at use_focal_loss=False --- rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py b/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py index 7d70113a..24b42ba9 100644 --- a/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py +++ b/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py @@ -45,7 +45,7 @@ def forward(self, outputs, orig_target_sizes): boxes = bbox_pred.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bbox_pred.shape[-1])) else: - scores = F.softmax(logits)[:, :, :-1] + scores = F.softmax(logits, dim=-1) scores, labels = scores.max(dim=-1) boxes = bbox_pred if scores.shape[1] > self.num_top_queries: