
[Bug] RT-DETR post-processing yields incorrect results when use_focal_loss=False #32578

Closed
2 of 4 tasks
dwchoo opened this issue Aug 10, 2024 · 2 comments


dwchoo commented Aug 10, 2024

System Info

  • transformers version: 4.45.0.dev0
  • Platform: Linux-5.19.0-1010-nvidia-lowlatency-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.24.5
  • Safetensors version: 0.4.4
  • Accelerate version: 0.33.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.0a0+07cecf4168.nv24.05 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA GeForce RTX 4090

Who can help?

@amyeroberts

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The official RT-DETR usage example from the documentation produces incorrect results when use_focal_loss=False is passed to post_process_object_detection. The script below is that example with only this argument changed:

import torch
import requests

from PIL import Image
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor

url = 'http://images.cocodataset.org/val2017/000000039769.jpg' 
image = Image.open(requests.get(url, stream=True).raw)

image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")

inputs = image_processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Changed from the documentation example: use_focal_loss defaults to True
results = image_processor.post_process_object_detection(
    outputs,
    target_sizes=torch.tensor([image.size[::-1]]),
    threshold=0.3,
    use_focal_loss=False,
)

for result in results:
    for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
        score, label = score.item(), label_id.item()
        box = [round(i, 2) for i in box.tolist()]
        print(f"{model.config.id2label[label]}: {score:.2f} {box}")

Expected behavior

With use_focal_loss=True (the default), the detections are correct:

sofa: 0.97 [0.14, 0.38, 640.13, 476.21]
cat: 0.96 [343.38, 24.28, 640.14, 371.5]
cat: 0.96 [13.23, 54.18, 318.98, 472.22]
remote: 0.95 [40.11, 73.44, 175.96, 118.48]
remote: 0.92 [333.73, 76.58, 369.97, 186.99]
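For contrast, the default focal-loss path scores each (query, class) pair with an elementwise sigmoid, so there is no dim argument that can be chosen wrongly. The following is a minimal sketch of that scoring step, roughly mirroring the processor's focal branch but not copied from it; `focal_topk` is a hypothetical helper name and the input is assumed to be raw class logits of shape (batch, num_queries, num_classes).

```python
import torch


def focal_topk(logits: torch.Tensor, num_top_queries: int):
    """Hypothetical helper: sigmoid (focal-loss style) scoring plus top-k.

    logits: (batch, num_queries, num_classes) raw class logits.
    Returns scores, class labels and query indices, each (batch, num_top_queries).
    """
    num_classes = logits.shape[-1]
    # Sigmoid is applied elementwise, so no normalisation dimension is involved.
    probs = torch.sigmoid(logits)
    # Rank all (query, class) pairs jointly by flattening the last two dims.
    scores, flat_index = torch.topk(probs.flatten(1), num_top_queries, dim=-1)
    # Recover the class and the query from the flattened index.
    labels = flat_index % num_classes
    query_index = torch.div(flat_index, num_classes, rounding_mode="floor")
    return scores, labels, query_index
```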

With use_focal_loss=False, all 300 returned queries are labelled "person" with a score of 1.00:

person: 1.00 [40.11, 73.44, 175.96, 118.48]
person: 1.00 [343.38, 24.28, 640.14, 371.5]
person: 1.00 [13.23, 54.18, 318.98, 472.22]
person: 1.00 [333.73, 76.58, 369.97, 186.99]
person: 1.00 [-10.18, -66.86, 629.35, 412.62]
person: 1.00 [0.85, -69.13, 640.82, 410.67]
person: 1.00 [4.19, -24.31, 644.17, 378.63]
person: 1.00 [0.04, -62.35, 640.01, 417.49]
person: 1.00 [-8.1, 136.34, 631.0, 385.72]
person: 1.00 [11.32, 54.2, 317.21, 473.39]
person: 1.00 [-5.81, 139.17, 620.73, 380.67]
person: 1.00 [2.41, 18.83, 642.39, 378.26]
person: 1.00 [-9.84, 21.47, 629.93, 322.49]
person: 1.00 [0.36, -1.19, 640.35, 478.7]
person: 1.00 [-44.53, -104.03, 595.36, 373.99]
person: 1.00 [0.91, -39.17, 640.89, 440.75]
person: 1.00 [-2.88, 216.51, 637.06, 434.71]
person: 1.00 [1.13, -0.58, 331.65, 476.0]
person: 1.00 [1.15, 58.44, 428.54, 471.67]
person: 1.00 [5.91, -66.87, 645.79, 378.84]
person: 1.00 [0.9, -44.8, 640.88, 435.03]
person: 1.00 [-0.19, -1.78, 639.8, 476.89]
person: 1.00 [1.72, 39.05, 641.1, 472.89]
person: 1.00 [343.53, 24.69, 639.36, 370.48]
person: 1.00 [1.8, 30.17, 640.73, 470.69]
person: 1.00 [6.8, -69.75, 646.77, 389.73]
person: 1.00 [0.6, 3.42, 640.59, 389.43]
person: 1.00 [-45.64, -107.19, 594.27, 372.2]
person: 1.00 [0.68, 0.79, 640.59, 314.32]
person: 1.00 [1.17, 1.47, 641.16, 384.91]
person: 1.00 [0.14, 0.38, 640.13, 476.21]
person: 1.00 [15.19, 6.0, 655.03, 271.6]
person: 1.00 [-35.92, -98.61, 602.61, 379.19]
person: 1.00 [1.21, -27.17, 641.2, 452.78]
person: 1.00 [-42.0, -93.72, 597.85, 385.56]
person: 1.00 [11.33, 127.73, 481.84, 470.83]
person: 1.00 [0.78, 2.29, 640.77, 375.53]
person: 1.00 [-38.3, -109.69, 599.97, 369.88]
person: 1.00 [-8.05, 198.77, 631.16, 385.42]
person: 1.00 [28.85, 277.41, 638.55, 475.06]
person: 1.00 [1.96, 262.19, 641.8, 475.22]
person: 1.00 [0.15, 234.33, 639.66, 465.15]
person: 1.00 [100.86, 25.73, 590.97, 369.99]
person: 1.00 [16.91, -39.15, 198.13, 427.58]
person: 1.00 [8.48, 56.71, 317.32, 377.75]
person: 1.00 [-24.16, 44.59, 609.98, 324.35]
person: 1.00 [1.67, -21.88, 641.65, 457.15]
person: 1.00 [1.58, 47.78, 263.77, 316.35]
person: 1.00 [66.05, 40.36, 641.24, 475.55]
person: 1.00 [8.08, 34.69, 559.91, 368.08]
person: 1.00 [0.03, 9.21, 185.75, 465.75]
person: 1.00 [0.51, 3.66, 640.51, 381.63]
person: 1.00 [-39.93, -109.05, 600.01, 370.51]
person: 1.00 [69.6, 72.99, 172.19, 112.08]
person: 1.00 [-38.78, -102.11, 600.59, 370.73]
person: 1.00 [59.02, 29.3, 615.04, 463.61]
person: 1.00 [37.15, 53.98, 308.09, 124.21]
person: 1.00 [1.09, 221.75, 638.09, 448.28]
person: 1.00 [-56.47, 27.89, 419.31, 455.42]
person: 1.00 [0.78, -0.79, 217.89, 478.34]
person: 1.00 [8.94, 25.11, 638.34, 472.57]
person: 1.00 [18.73, 35.29, 619.38, 458.91]
person: 1.00 [345.32, 23.68, 645.02, 269.5]
person: 1.00 [28.8, -47.42, 631.41, 350.5]
person: 1.00 [1.26, 42.63, 172.07, 298.25]
person: 1.00 [-12.0, -0.47, 598.94, 346.74]
person: 1.00 [-5.09, 173.18, 617.97, 383.35]
person: 1.00 [1.47, -0.27, 641.42, 415.41]
person: 1.00 [56.78, 13.05, 640.34, 473.68]
person: 1.00 [0.19, 0.81, 640.18, 130.65]
person: 1.00 [73.93, 121.67, 637.92, 474.11]
person: 1.00 [1.52, 72.09, 329.32, 478.16]
person: 1.00 [-8.07, -74.78, 594.32, 387.01]
person: 1.00 [19.6, -17.51, 635.97, 320.83]
person: 1.00 [-40.61, -73.53, 599.33, 371.17]
person: 1.00 [243.94, 26.74, 594.06, 309.69]
person: 1.00 [446.15, 18.67, 581.79, 310.19]
person: 1.00 [4.56, 285.25, 644.46, 469.4]
person: 1.00 [151.32, 62.36, 367.71, 326.84]
person: 1.00 [25.66, 283.26, 322.72, 474.62]
person: 1.00 [3.91, -25.77, 643.89, 399.53]
person: 1.00 [0.46, 191.1, 225.5, 458.27]
person: 1.00 [1.85, -17.32, 641.01, 460.72]
person: 1.00 [162.58, 50.65, 598.12, 381.86]
person: 1.00 [0.65, 236.97, 640.58, 449.68]
person: 1.00 [528.98, -1.92, 638.65, 62.44]
person: 1.00 [363.64, -0.13, 638.31, 82.46]
person: 1.00 [41.2, 73.73, 175.34, 118.12]
person: 1.00 [0.97, -1.93, 367.2, 477.25]
person: 1.00 [487.96, -0.82, 639.21, 82.64]
person: 1.00 [2.84, 37.55, 363.03, 340.35]
person: 1.00 [1.45, 331.03, 641.42, 476.01]
person: 1.00 [3.64, 66.66, 353.08, 363.45]
person: 1.00 [346.94, 25.13, 639.3, 471.17]
person: 1.00 [-0.81, -1.63, 336.41, 478.3]
person: 1.00 [11.35, 60.53, 399.96, 474.83]
person: 1.00 [202.02, 56.34, 609.75, 367.46]
person: 1.00 [64.17, 301.09, 637.44, 475.32]
person: 1.00 [43.04, 40.22, 549.0, 324.92]
person: 1.00 [0.49, 0.75, 639.51, 123.66]
person: 1.00 [2.91, 1.34, 496.53, 247.87]
person: 1.00 [-43.47, -95.43, 596.13, 383.73]
person: 1.00 [-0.32, 0.54, 639.64, 121.65]
person: 1.00 [2.36, 1.03, 480.97, 123.27]
person: 1.00 [1.69, 101.47, 625.26, 384.55]
person: 1.00 [10.54, 169.03, 224.0, 460.98]
person: 1.00 [-2.45, 46.63, 371.29, 441.16]
person: 1.00 [-25.4, -99.08, 593.66, 379.06]
person: 1.00 [490.38, 20.15, 639.81, 471.27]
person: 1.00 [1.95, 2.5, 641.92, 319.84]
person: 1.00 [0.54, 0.91, 640.43, 125.01]
person: 1.00 [1.31, 130.35, 121.19, 469.66]
person: 1.00 [-0.02, 0.68, 639.98, 124.31]
person: 1.00 [344.29, 93.63, 576.88, 370.13]
person: 1.00 [0.48, 37.4, 94.03, 470.18]
person: 1.00 [71.13, -0.23, 639.15, 81.83]
person: 1.00 [367.52, 37.94, 601.49, 230.9]
person: 1.00 [-30.56, -102.78, 600.38, 376.28]
person: 1.00 [5.34, 269.57, 643.09, 465.62]
person: 1.00 [369.59, -0.1, 638.84, 72.72]
person: 1.00 [173.63, 34.99, 579.36, 436.05]
person: 1.00 [11.33, 25.64, 638.05, 424.99]
person: 1.00 [4.19, 23.88, 644.11, 387.04]
person: 1.00 [-0.95, -73.24, 637.02, 368.35]
person: 1.00 [336.04, 77.28, 372.21, 257.29]
person: 1.00 [14.94, -66.54, 647.64, 379.76]
person: 1.00 [6.49, 132.64, 303.64, 466.5]
person: 1.00 [223.85, 45.22, 607.61, 361.32]
person: 1.00 [1.19, 0.52, 640.01, 122.91]
person: 1.00 [12.19, 69.9, 319.91, 312.88]
person: 1.00 [7.95, 7.4, 646.76, 442.46]
person: 1.00 [2.6, -38.88, 642.58, 440.3]
person: 1.00 [2.97, 127.14, 640.23, 475.11]
person: 1.00 [4.5, 94.64, 311.44, 313.71]
person: 1.00 [-0.41, -0.06, 639.58, 313.42]
person: 1.00 [17.37, 57.79, 215.01, 412.27]
person: 1.00 [10.01, 50.71, 269.38, 446.99]
person: 1.00 [-37.33, -48.81, 601.75, 404.09]
person: 1.00 [492.12, -0.21, 639.69, 70.05]
person: 1.00 [-21.18, -73.4, 590.9, 401.15]
person: 1.00 [408.74, 24.02, 643.94, 265.61]
person: 1.00 [17.23, 186.59, 247.22, 472.83]
person: 1.00 [151.44, 190.11, 297.58, 321.04]
person: 1.00 [3.88, 302.06, 643.83, 475.11]
person: 1.00 [0.27, 0.25, 640.26, 123.1]
person: 1.00 [492.85, 16.35, 640.26, 471.57]
person: 1.00 [14.62, 55.53, 316.26, 291.06]
person: 1.00 [14.1, -0.33, 631.16, 74.81]
person: 1.00 [334.63, 78.56, 470.02, 356.54]
person: 1.00 [3.77, 18.34, 479.68, 473.46]
person: 1.00 [344.6, 102.96, 638.63, 470.47]
person: 1.00 [2.75, 8.23, 640.96, 339.47]
person: 1.00 [108.61, 17.33, 635.49, 368.47]
person: 1.00 [487.9, -0.24, 638.91, 461.87]
person: 1.00 [158.74, 117.42, 639.63, 473.97]
person: 1.00 [1.58, 19.88, 305.79, 360.04]
person: 1.00 [498.59, 28.67, 640.16, 472.73]
person: 1.00 [-2.19, 201.11, 637.61, 382.93]
person: 1.00 [0.14, 49.01, 71.28, 435.95]
person: 1.00 [-15.64, 102.72, 604.85, 368.72]
person: 1.00 [494.15, 32.05, 628.72, 340.45]
person: 1.00 [1.41, 0.66, 639.71, 297.89]
person: 1.00 [2.69, 14.26, 640.87, 388.07]
person: 1.00 [340.11, 0.93, 638.18, 194.54]
person: 1.00 [0.94, 0.74, 639.67, 124.13]
person: 1.00 [0.18, 0.32, 640.17, 122.62]
person: 1.00 [500.6, 6.68, 639.57, 471.04]
person: 1.00 [348.77, 204.83, 628.61, 413.85]
person: 1.00 [-37.45, -96.32, 599.52, 380.37]
person: 1.00 [346.7, 93.69, 578.03, 370.37]
person: 1.00 [7.68, 56.58, 318.09, 370.17]
person: 1.00 [-39.59, -92.42, 598.33, 386.99]
person: 1.00 [1.42, 59.61, 315.34, 474.65]
person: 1.00 [7.24, -22.12, 645.44, 405.42]
person: 1.00 [148.65, 51.8, 632.96, 372.74]
person: 1.00 [7.99, 57.89, 427.42, 468.69]
person: 1.00 [256.47, 53.29, 602.2, 369.29]
person: 1.00 [417.25, 45.64, 579.57, 289.73]
person: 1.00 [8.96, 218.43, 623.48, 375.79]
person: 1.00 [345.58, 24.0, 640.83, 212.95]
person: 1.00 [2.6, 196.52, 221.6, 472.77]
person: 1.00 [1.54, -0.22, 309.6, 129.08]
person: 1.00 [252.79, 46.48, 636.74, 375.81]
person: 1.00 [0.68, 0.06, 638.26, 121.15]
person: 1.00 [0.6, -0.06, 346.16, 246.56]
person: 1.00 [538.08, 52.16, 638.37, 173.23]
person: 1.00 [491.73, 12.13, 638.8, 456.34]
person: 1.00 [425.46, 25.33, 640.99, 373.4]
person: 1.00 [0.28, 131.58, 104.88, 471.05]
person: 1.00 [-28.17, -6.42, 603.09, 363.32]
person: 1.00 [334.83, 92.94, 366.77, 247.36]
person: 1.00 [1.08, 59.42, 316.2, 472.36]
person: 1.00 [314.86, 296.47, 636.58, 475.68]
person: 1.00 [-15.84, -83.0, 592.57, 396.41]
person: 1.00 [336.01, 76.63, 396.59, 319.83]
person: 1.00 [382.79, 72.34, 573.73, 280.94]
person: 1.00 [128.7, 100.72, 482.46, 342.77]
person: 1.00 [-0.86, 50.07, 608.39, 301.51]
person: 1.00 [364.86, 73.13, 388.2, 106.67]
person: 1.00 [1.21, 32.98, 179.92, 253.96]
person: 1.00 [180.18, 54.78, 318.81, 239.11]
person: 1.00 [9.61, 38.79, 639.13, 472.02]
person: 1.00 [1.54, 51.77, 316.49, 325.46]
person: 1.00 [340.6, 76.51, 370.03, 98.49]
person: 1.00 [17.32, 215.85, 218.81, 472.48]
person: 1.00 [1.8, 260.97, 74.21, 472.16]
person: 1.00 [162.18, 101.62, 638.44, 474.12]
person: 1.00 [0.21, -0.03, 640.19, 121.75]
person: 1.00 [8.45, 243.09, 194.34, 468.54]
person: 1.00 [0.37, 30.73, 83.28, 471.32]
person: 1.00 [336.06, 77.69, 440.21, 353.98]
person: 1.00 [343.71, 98.27, 585.57, 352.21]
person: 1.00 [569.03, 54.91, 638.27, 87.99]
person: 1.00 [349.85, 199.06, 580.5, 371.36]
person: 1.00 [-5.68, 0.19, 615.74, 478.37]
person: 1.00 [19.14, 66.82, 596.88, 306.63]
person: 1.00 [-0.11, 57.04, 49.94, 226.28]
person: 1.00 [-7.56, -21.22, 319.02, 381.06]
person: 1.00 [-1.29, 53.2, 309.89, 473.24]
person: 1.00 [12.82, 15.03, 248.0, 322.41]
person: 1.00 [516.15, 119.23, 639.08, 474.0]
person: 1.00 [18.63, 199.56, 221.1, 473.27]
person: 1.00 [497.05, 117.41, 639.74, 473.1]
person: 1.00 [9.79, 118.04, 297.19, 454.02]
person: 1.00 [3.34, 358.19, 641.11, 477.52]
person: 1.00 [172.61, 57.34, 320.07, 239.83]
person: 1.00 [1.02, 54.99, 183.33, 473.55]
person: 1.00 [55.08, 396.24, 425.23, 477.13]
person: 1.00 [41.17, 76.11, 106.77, 116.85]
person: 1.00 [6.81, 44.78, 39.15, 108.71]
person: 1.00 [175.53, 55.39, 319.92, 238.32]
person: 1.00 [143.4, 25.21, 636.62, 372.52]
person: 1.00 [233.7, 53.43, 382.07, 192.45]
person: 1.00 [348.71, 189.82, 580.6, 371.47]
person: 1.00 [220.17, 54.97, 315.42, 141.13]
person: 1.00 [0.11, 2.01, 213.73, 472.97]
person: 1.00 [569.18, 105.69, 632.75, 223.87]
person: 1.00 [347.64, 129.3, 581.54, 372.06]
person: 1.00 [36.19, 57.23, 315.95, 240.78]
person: 1.00 [333.96, 77.02, 369.67, 186.11]
person: 1.00 [8.43, 161.46, 631.11, 443.6]
person: 1.00 [349.32, 197.01, 585.32, 310.25]
person: 1.00 [179.19, 9.84, 637.71, 314.29]
person: 1.00 [567.7, 54.52, 638.05, 87.81]
person: 1.00 [343.33, 76.47, 368.09, 116.57]
person: 1.00 [334.01, 25.98, 640.51, 210.25]
person: 1.00 [-0.24, -0.37, 476.16, 58.04]
person: 1.00 [365.76, 72.45, 388.16, 97.67]
person: 1.00 [5.58, 48.1, 556.11, 321.49]
person: 1.00 [479.51, 0.24, 639.99, 138.34]
person: 1.00 [112.97, 54.48, 320.09, 319.25]
person: 1.00 [96.5, 170.52, 591.23, 385.25]
person: 1.00 [18.35, 208.82, 634.48, 435.54]
person: 1.00 [203.47, 54.36, 315.21, 181.25]
person: 1.00 [6.79, 38.45, 95.96, 202.38]
person: 1.00 [35.03, 83.94, 585.53, 336.66]
person: 1.00 [-1.06, 46.68, 322.56, 431.87]
person: 1.00 [333.98, 76.84, 369.62, 186.06]
person: 1.00 [340.02, 76.48, 582.81, 371.45]
person: 1.00 [115.9, 44.38, 574.69, 256.77]
person: 1.00 [347.19, 182.82, 584.57, 326.97]
person: 1.00 [292.2, 99.52, 312.06, 109.58]
person: 1.00 [130.86, 25.02, 634.26, 374.85]
person: 1.00 [20.14, -56.34, 569.14, 422.56]
person: 1.00 [493.6, 23.8, 639.67, 203.12]
person: 1.00 [340.95, 90.61, 366.48, 126.22]
person: 1.00 [338.55, 76.9, 369.83, 129.09]
person: 1.00 [0.4, -1.09, 397.33, 47.79]
person: 1.00 [-5.2, 183.06, 631.77, 380.65]
person: 1.00 [2.09, 53.29, 318.05, 266.63]
person: 1.00 [-17.94, -90.57, 591.21, 388.28]
person: 1.00 [494.47, 0.84, 640.16, 158.92]
person: 1.00 [39.41, 20.58, 622.99, 320.53]
person: 1.00 [349.8, 216.39, 586.93, 371.11]
person: 1.00 [3.41, 72.11, 168.76, 463.4]
person: 1.00 [333.96, 127.41, 354.65, 186.02]
person: 1.00 [21.7, -50.29, 546.78, 413.27]
person: 1.00 [349.67, 201.48, 579.65, 371.36]
person: 1.00 [44.24, 73.55, 177.53, 102.62]
person: 1.00 [260.45, 100.7, 309.76, 141.39]
person: 1.00 [334.24, 127.85, 354.47, 185.53]
person: 1.00 [11.41, 167.25, 587.53, 383.86]
person: 1.00 [338.72, 76.53, 369.79, 121.35]
person: 1.00 [334.4, 154.26, 351.33, 186.84]
person: 1.00 [0.52, 0.42, 638.42, 86.97]
person: 1.00 [336.65, 74.82, 380.12, 162.23]
person: 1.00 [248.96, 27.01, 637.61, 474.17]
person: 1.00 [313.42, 59.9, 500.07, 257.06]
person: 1.00 [5.76, 6.61, 291.94, 327.65]
person: 1.00 [113.73, 21.36, 633.36, 379.54]
person: 1.00 [16.32, 82.46, 68.45, 241.11]
person: 1.00 [516.34, -4.28, 593.78, 48.09]
person: 1.00 [1.01, 0.37, 207.69, 239.27]
person: 1.00 [-0.22, 52.25, 70.18, 130.13]
person: 1.00 [26.83, 323.37, 73.32, 471.25]
person: 1.00 [352.95, 25.03, 640.83, 201.57]
person: 1.00 [-0.55, -1.35, 305.24, 42.25]
person: 1.00 [358.3, 24.66, 640.36, 209.21]
person: 1.00 [4.82, 60.29, 387.41, 474.41]
person: 1.00 [442.45, 218.22, 493.46, 290.22]
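The symptom can be reproduced without downloading the model. This is a minimal sketch assuming random stand-in logits of shape (1, 300, 80), i.e. one image, 300 queries and 80 classes; it passes dim=0 explicitly because that is where PyTorch's implicit-dimension fallback lands for a 3-D tensor.

```python
import torch

torch.manual_seed(0)
# Stand-in for RT-DETR class logits: 1 image, 300 queries, 80 classes (assumed shape).
logits = torch.randn(1, 300, 80)

# softmax() without `dim` on a 3-D tensor falls back (with a deprecation
# warning) to dim=0, the batch dimension. Passing dim=0 explicitly here
# reproduces that behaviour without triggering the warning.
buggy = torch.nn.functional.softmax(logits, dim=0)[:, :, :-1]
scores, labels = buggy.max(dim=-1)

# Normalising over a batch of size 1 turns every entry into exactly 1.0 ...
print(torch.all(buggy == 1.0).item())   # True
# ... and max() resolves the tie by returning the first index, class 0.
print(torch.all(labels == 0).item())    # True
```

Class 0 is what this checkpoint's id2label maps to "person", which matches the output above.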

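Solved

This fixes a critical bug in post_process_object_detection: the use_focal_loss=False branch calls softmax without a dim argument. For a 3-D input, PyTorch then falls back to dim=0 (the batch dimension) and emits a deprecation warning, so with a single image every probability becomes 1.0 and the max over classes returns class 0 ("person") for every query.

scores = torch.nn.functional.softmax(out_logits)[:, :, :-1]
-> scores = torch.nn.functional.softmax(out_logits, dim=-1)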
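A sketch of the non-focal scoring path with the proposed change applied is shown below. It is not the library's code: `softmax_topk` is a hypothetical helper, the shapes are assumed to be (batch, num_queries, num_classes) for logits and (batch, num_queries, 4) for boxes, and it follows the proposed line in also dropping the `[:, :, :-1]` slice.

```python
import torch


def softmax_topk(logits: torch.Tensor, boxes: torch.Tensor, num_top_queries: int):
    """Hypothetical helper: softmax over classes, best class per query, then top-k.

    logits: (batch, num_queries, num_classes); boxes: (batch, num_queries, 4).
    """
    # Normalise over the class dimension, as in the proposed fix.
    probs = torch.nn.functional.softmax(logits, dim=-1)
    # Best class and its probability for each query.
    scores, labels = probs.max(dim=-1)
    if scores.shape[1] > num_top_queries:
        # Keep only the highest-scoring queries, with their labels and boxes.
        scores, index = torch.topk(scores, num_top_queries, dim=-1)
        labels = torch.gather(labels, dim=1, index=index)
        box_index = index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1])
        boxes = torch.gather(boxes, dim=1, index=box_index)
    return scores, labels, boxes
```

With the class dimension normalised, scores are genuine per-query class probabilities below 1.0, and the predicted label varies with the logits instead of collapsing to class 0.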

qubvel (Member) commented Aug 12, 2024

Hi @dwchoo thank you for reporting the issue and submitting a PR, I will check this out!

cc @SangbumChoi


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
