
RT-DETR postprocessing bug #32579

Open
dwchoo wants to merge 3 commits into main

Conversation

@dwchoo commented Aug 10, 2024

What does this PR do?

This PR addresses two important issues:

  1. Fixes a critical bug in the post_process_object_detection method of RTDetrImageProcessor for the RT-DETR model.
  2. Improves the documentation for the labels parameter in RTDetrForObjectDetection.

Fixes #32578

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@amyeroberts

Improve RT-DETR documentation: Clarify bounding box format for labels

Current Documentation

  • The output bounding box format is clearly specified as (top_left_x, top_left_y, bottom_right_x, bottom_right_y) in the RTDetrImageProcessor documentation.
  • However, the format for input bounding boxes in the labels parameter is not explicitly stated in the RTDetrForObjectDetection documentation.

Missing Information

The labels parameter requires bounding boxes in the following format:

  • (center_x, center_y, width, height)
  • Values should be normalized to the range [0, 1]

This information is crucial for correctly calculating the loss but is currently missing from the documentation.

Proposed Solution

Add the following clarification to the documentation for the labels parameter in RTDetrForObjectDetection:

"The bounding box coordinates in the 'boxes' key should be in the format (center_x, center_y, width, height) and have normalized values in the range [0, 1]."

Impact

Adding this information will significantly improve the user experience by:

  1. Reducing confusion about the required input format
  2. Ensuring correct loss calculation
  3. Saving users time in debugging and troubleshooting

Additional Notes

  • This issue was discovered while attempting to use the RT-DETR model for custom training.
  • The lack of this information in the documentation led to difficulties in properly preparing the input data and calculating the loss.

geon0430 and others added 3 commits August 10, 2024 18:03
…RTDetrForObjectDetection

The format of labels for RTDetrForObjectDetection was not clearly specified,
leading to confusion. Added detailed comments explaining the label structure
to reduce ambiguity and improve ease of use.
@dwchoo dwchoo changed the title Rt detr postprocessing bug Rt-DETR postprocessing bug Aug 10, 2024
@dwchoo dwchoo changed the title Rt-DETR postprocessing bug RT-DETR postprocessing bug Aug 10, 2024
@qubvel qubvel added the Vision label Aug 11, 2024
@qubvel qubvel self-requested a review August 11, 2024 23:05
@SangbumChoi (Contributor)

@dwchoo Hi, thanks for bringing up this issue. Initially this model was aimed at use_focal_loss = True.
https://github.com/lyuwenyu/RT-DETR/blob/bc0cf9f16c1ae98e925a7495e32c81319a624088/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py#L48

Visualizing with use_focal_loss = False on a model trained with use_focal_loss = True might not be a recommended pipeline. Can you also share the fine-tuned result?

@dwchoo (Author) commented Aug 12, 2024

@SangbumChoi Thank you for your response. I'd like to clarify that I haven't performed any fine-tuning on the model. My concerns are based on an analysis of the existing code.

I believe there's an issue with the line scores = F.softmax(logits)[:, :, :-1]. This operation excludes the last class from the softmax results, which in this case is class 79 (toothbrush).

Here's my reasoning:

  1. By using [:, :, :-1], we're effectively dropping the last class from the softmax output.
  2. This causes the maximum class ID to be 78 instead of 79, ignoring the 'toothbrush' class entirely (there are 80 classes in total, IDs 0–79).
  3. I believe this is a bug unrelated to fine-tuning or the use of focal loss.

When I modify the code to softmax(out_logits, dim=-1), the model correctly classifies toothbrushes.

Given this, I suggest that the issue isn't about visualizing a use_focal_loss = False scenario with a model trained with use_focal_loss = True. Instead, it appears to be an oversight in the post-processing step that affects the model's ability to detect all classes correctly.

Could you please review this specific part of the code? I believe addressing this issue would improve the model's performance regardless of the use_focal_loss setting.
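To make the effect concrete, here is a minimal, self-contained illustration of the slice (random logits; shapes are only for illustration):

import torch
import torch.nn.functional as F

# Hypothetical logits: batch 1, 2 queries, 80 classes (COCO IDs 0-79)
logits = torch.randn(1, 2, 80)

# Current code: the [:, :, :-1] slice silently drops class 79 ("toothbrush");
# F.softmax without an explicit dim also triggers a PyTorch warning
scores_old = F.softmax(logits)[:, :, :-1]
print(scores_old.shape)  # torch.Size([1, 2, 79]) -- one class short

# Suggested fix: softmax over the class dimension, keeping all 80 classes
scores_new = F.softmax(logits, dim=-1)
print(scores_new.shape)  # torch.Size([1, 2, 80])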

@SangbumChoi (Contributor)

@dwchoo Thanks for the detailed description. I think this is a legacy of the additional index that handles the non-object class. I think it is good to merge, but usually we add a test script to check these changes.

@qubvel Should we add some additional logic such as test_use_focal_loss_false?

@dwchoo (Author) commented Aug 12, 2024

@SangbumChoi @qubvel

Here is some code you can test.

import torch
import requests

from PIL import Image
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor

# toothbrush coco image
url = 'http://images.cocodataset.org/train2017/000000464902.jpg' 
#url = 'http://images.cocodataset.org/val2017/000000039769.jpg' 

image = Image.open(requests.get(url, stream=True).raw)

image_processor = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
model = RTDetrForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd")

inputs = image_processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

results = image_processor.post_process_object_detection(
    outputs, 
    target_sizes=torch.tensor([image.size[::-1]]), 
    threshold=0.3, 
    use_focal_loss=True) ###### Change here True -> False

for result in results:
    for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
        score, label = score.item(), label_id.item()
        box = [round(i, 2) for i in box.tolist()]
        print(f"{model.config.id2label[label]}: {score:.2f} {box}")

Here's a visualization code you can use:

import matplotlib.pyplot as plt


plt.figure(figsize=(10, 10))
plt.imshow(image)
ax = plt.gca()

for result in results:
    for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
        label = model.config.id2label[label_id.item()]
        score = score.item()
        box = [round(i, 2) for i in box.tolist()]
        
        x_min, y_min, x_max, y_max = box
        rect = plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                             linewidth=2, edgecolor='r', facecolor='none')
        ax.add_patch(rect)
        
        plt.text(x_min, y_min, f'{label} {score:.2f}', color='white',
                 fontsize=12, bbox=dict(facecolor='red', alpha=0.5))

plt.axis('off')
plt.show()

@qubvel (Member) commented Aug 12, 2024

Hi @dwchoo and @SangbumChoi, thanks for taking the time to dive into the problem!

As far as I understand, RT-DETR, both in its current transformers implementation and in the original implementation, was designed to be trained with focal loss only. That is, it applies a sigmoid function to the logits, and an additional "void" class is not added to the class logits, as it would be with a softmax activation + cross-entropy loss.
Thus, the model cannot be tuned with use_focal_loss = False; even though we can compute the loss for this option, I suppose the model will fail to converge. @SangbumChoi correct me if I'm wrong 🙂

The solution might be one of:

  1. Fix the modeling code to add an additional "void" class -> then the postprocessing would be fine with [..., :-1] (see the sketch after this list).
  2. Remove the option to use use_focal_loss = False in the modeling and postprocessing code.
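A minimal sketch of what option 1's postprocessing would look like, assuming the model emits num_classes + 1 logits with the last slot acting as "void" (illustrative shapes, not the PR's code):

import torch

num_classes = 80
logits = torch.randn(1, 300, num_classes + 1)  # last column = hypothetical "void"

probs = logits.softmax(dim=-1)  # softmax over real classes + void
scores = probs[..., :-1]        # drop the void column -> shape (1, 300, 80)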

@dwchoo (Author) commented Aug 12, 2024

@qubvel @SangbumChoi
Thank you for your explanation. I'd like to offer a different perspective:

  1. In practice, using softmax instead of sigmoid for class calculation doesn't significantly affect the model's output.

  2. The postprocessing code is only for result interpretation, not for training. Using softmax here doesn't impact model performance or training convergence.

  3. The current implementation seems to be an oversight. Using dim=-1 in the softmax calculation produces nearly identical results and aligns with common practices.

  4. Keeping the use_focal_loss = False option could be valuable for experimentation and comparison.

  5. I suggest maintaining both sigmoid and softmax options in postprocessing, allowing users to choose based on their specific needs (a small sketch follows below).

I'm open to further discussion on this approach. What are your thoughts?
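A minimal sketch of what keeping both options could look like (compute_scores is a made-up helper for illustration, not the library's API):

import torch

def compute_scores(logits: torch.Tensor, use_focal_loss: bool) -> torch.Tensor:
    if use_focal_loss:
        # Focal-loss-trained model: independent per-class probabilities
        return logits.sigmoid()
    # Softmax over the class dimension, keeping all classes (no [:, :, :-1])
    return logits.softmax(dim=-1)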

@SangbumChoi (Contributor)

Since the transformers-based RT-DETR uses the COCO format, I also generally agree with @dwchoo's comments (which means I agree with the current PR).

Also, I think I have to correct myself about the [..., :-1]. I thought it was an additional "void" class, but it could just be a pure mistake in the original repo.

@qubvel (Member) commented Aug 13, 2024

@dwchoo @SangbumChoi thanks for the discussion!

I've experimented a bit and fixed the ability of RT-DETR to be trained with a "void" class + cross-entropy loss for labels (and other losses too).
Here is a PR

I guess this "use_focal_loss" is not the right name for the argument here, it should be more about how the model was trained and what activation function it is expected to be applied. For the model that will be fine-tuned with cross-entropy loss + "void" class, we should later apply softmax + remove the last "void" class, so I would argue not to remove [..., :-1]. The image processor should be consistent with model settings, for experiments over logits you can always avoid using post post-processing function and apply softmax or any other custom function / filtering / thresholding over logits.

I agree that it should be specified in the docs to avoid some misunderstanding of the parameter.

Please let me know what you think!

@dwchoo (Author) commented Aug 14, 2024

@qubvel @SangbumChoi
Thank you for your thoughtful proposal. While I appreciate the effort to enhance the model's flexibility, I respectfully disagree with the addition of a "void" class. Allow me to explain my perspective:

RT-DETR's architecture fundamentally differs from traditional YOLO models. Unlike YOLO, which uses a grid system with a fixed number of anchors necessitating a "void" class, RT-DETR employs 'Object Queries' and 'Uncertainty-minimal Query Selection' to identify and locate objects. According to the paper, this approach enables RT-DETR to effectively distinguish between background and objects without the need for a "void" class.

Given this structural difference, the [..., :-1] operation to remove a "void" class seems unnecessary and potentially contradicts the original authors' design intentions for RT-DETR. Adding a "void" class and then removing it in post-processing might introduce unnecessary complexity without clear benefits.

I believe that maintaining the model's original design without a "void" class aligns better with RT-DETR's architecture and intended functionality. Perhaps we could explore alternative ways to enhance the model's flexibility that are more in line with its core design principles?

I'm open to further discussion and would be interested in hearing your thoughts on these points.

@SangbumChoi (Contributor) commented Aug 14, 2024

I guess this "use_focal_loss" is not the right name for the argument here

I agree with this. The reason we are having a bit of confusion is this naming issue. Let me make this discussion clear. Also be aware that @qubvel added an additional loss function, which covers the multiclass vs. multilabel scenario.

A small counterpoint to the following statement:

RT-DETR employs 'Object Queries' and 'Uncertainty-minimal Query Selection' to identify and locate objects. This paper approach enables RT-DETR to effectively distinguish between background and objects without the need for a "void" class.

We cannot be sure that RT-DETR was designed to work without a "void" class. Uncertainty-minimal Query Selection provides more accurate query selection than the vanilla model, while a void class adds foreground/background classification; the effects are similar, but they operate at independent stages (query selection vs. the loss).

https://github.com/lyuwenyu/RT-DETR/blob/bc0cf9f16c1ae98e925a7495e32c81319a624088/rtdetr_pytorch/src/zoo/rtdetr/matcher.py#L76

In conclusion, it all comes down to the choice between sigmoid and softmax. Since focal loss is a hard-negative-example-focused cross-entropy loss, a traditional CE loss might need a "void" class. The best way forward: could @dwchoo run a simple fine-tuning test comparing neglecting vs. including the "void" class when focal loss is not used?

@dwchoo (Author) commented Aug 14, 2024

@SangbumChoi
After thoroughly examining the RT-DETR paper, particularly section "4.3 Uncertainty-minimal Query Selection", I respectfully argue that RT-DETR was intentionally designed to function without a void class.

[RT-DETR, 4.3 Uncertainty-minimal Query Selection]
The confidence score represents the likelihood that the feature includes foreground objects. Nevertheless, the detector are required to simultaneously model the category and location of objects, both of which determine the quality of the features. Hence, the performance score of the feature is a latent variable that is jointly correlated with both classification and localization. Based on the analysis, the current query selection lead to a considerable level of uncertainty in the selected features, resulting in sub-optimal initialization for the decoder and hindering the performance of the detector.
To address this problem, we propose the uncertainty minimal query selection scheme, which explicitly constructs and optimizes the epistemic uncertainty to model the joint latent variable of encoder features, thereby providing high-quality queries for the decoder. Specifically, the feature uncertainty U is defined as the discrepancy between the predicted distributions of localization P and classification C in Eq. (2). To minimize the uncertainty of the queries, we integrate the uncertainty into the loss function for the gradient-based optimization in Eq. (3).

The authors address foreground object detection through the uncertainty U (Eq. 2), which minimizes discrepancies between object localization (P) and classification (C). This approach, integrated into the loss function (Eq. 3), effectively distinguishes background from foreground without an explicit void class.

We cannot be sure that RT-DETR was designed to work without a "void" class. Uncertainty-minimal Query Selection provides more accurate query selection than the vanilla model, while a void class adds foreground/background classification; the effects are similar, but they operate at independent stages (query selection vs. the loss).

Unlike YOLO-based models that use fixed anchors to differentiate objects and background, DETR-series models, including RT-DETR, directly locate and classify objects through a Transformer encoder-decoder architecture. RT-DETR advances this concept by replacing DETR's 'no object' class with the uncertainty U mechanism.
Adding a "void" class to RT-DETR would contradict its core design principles and potentially compromise its structural advantages. While performance comparisons are valuable, fine-tuning RT-DETR with an added void class may not align with the model's original objectives.

I believe optimizing RT-DETR's performance within its original framework would be more beneficial. However, I'm open to further discussion on this matter, as diverse perspectives can lead to valuable insights in our field.

@qubvel (Member) commented Aug 19, 2024

Hi @dwchoo, thanks for your answer and for providing details from the original paper.

As you mentioned, RT-DETR was designed to be trained without a void class, but it was also designed to use a sigmoid function in postprocessing. As far as I can see, all model configs in the original repo specify use_focal_loss=True, so the original model is not intended to be used with softmax in postprocessing as suggested in the current PR.

However, the original code does contain a cross-entropy loss function and postprocessing that may be used with it (I mean softmax + [:, :, :-1]), despite this combination not being used in any of the current model configurations.

So one option is to strictly follow the original implementation and remove all unused code, including the unused loss functions and postprocessing steps.

The second option, while keeping the default implementation aligned with the original one, is to fix the optional loss functions and allow the user to decide which losses to use and whether the void class should be added.

I don't see how this would compromise RT-DETR's advantages, given that the default behavior stays aligned with the original one.

Please let me know what you think!

@dwchoo (Author) commented Aug 20, 2024

@qubvel
Thank you so much for your detailed response and insights! I greatly appreciate your time and effort in addressing this matter.

Could I ask for a bit more time on this? I've contacted the RT-DETR authors by email to get their direct feedback on this issue. While I'm still waiting for their response, I'm also planning to try contacting them through the original RT-DETR repository.

The question of the void class seems to be quite significant for this model, and I'm eager to ensure that we're proceeding with the most accurate understanding possible. I'm hoping to clarify whether this is simply an oversight or if perhaps I've misunderstood some aspect of the model's design.

I think hearing from the authors themselves will help clear things up and show us what to do next. I hope it's okay if we wait a little bit for their response.

Thank you again for your patience and understanding. I'm looking forward to continuing our discussion once I have more information to share.

@dwchoo (Author) commented Aug 29, 2024

@SangbumChoi @qubvel
Hi there! Sorry for the radio silence.

Quick update: I raised a PR with the original RT-DETR repo. The author acknowledged the issue but preferred to keep things as is for now, given the extensive changes required.

That said, I think we could still improve things in the transformers package.

How about we:

  1. Remove the use_focal_loss argument from post_process_object_detection
  2. Remove the softmax calculation code

This should clear up confusion and potential bugs. What do you think? If you're on board, I'm happy to update the PR with these changes.

Looking forward to hearing your thoughts!

@SangbumChoi (Contributor) commented Aug 29, 2024

@dwchoo After reviewing the conversation with the author and your PR, I think this current PR and suggestion look like a kind of workaround.
(I have never heard of training with sigmoid and then evaluating with softmax; if there is an example, I would love to see it!
Sigmoid treats each output independently, while softmax normalizes across all outputs. Applying softmax to logits trained with sigmoid is unusual because softmax expects outputs tailored for multi-class classification.)

As you mentioned, you are right that the author described the 0, 1, 2 situations but only shared code for solution 0. (The 2nd solution does not work properly at the moment in the original + this huggingface code.)
So instead of changing this postprocessing pipeline, I recommend changing the config.num_labels pipeline to add an arbitrary background class in order to keep the current postprocessing behavior; a small sketch follows below. I can help you if you want to implement it this way. However, let's see what Pavel suggests.
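A minimal sketch of that idea, assuming one extra label slot is reserved as a background class (the values and the ignore_mismatched_sizes usage are illustrative, not a tested recipe):

from transformers import RTDetrConfig, RTDetrForObjectDetection

# Reserve one extra slot to act as a background/"void" class
config = RTDetrConfig.from_pretrained("PekingU/rtdetr_r50vd")
config.num_labels = 81  # 80 COCO classes + 1 background
model = RTDetrForObjectDetection.from_pretrained(
    "PekingU/rtdetr_r50vd",
    config=config,
    ignore_mismatched_sizes=True,  # classification head is re-initialized
)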

@dwchoo (Author) commented Sep 5, 2024

@qubvel Thank you for your ongoing engagement. I'd like to propose the following based on my analysis:

  1. Training Process:
    The model uses RTDetrHungarianMatcher for loss calculation during training. This function doesn't consider a void class, even when use_focal_loss=False (where it uses softmax without accounting for a void class).

  2. Implications of Adding Void Class:
    Incorporating a void class option throughout the codebase would require extensive changes to the training process.

  3. Current Issue Scope:
    The problem we're addressing is specifically related to post_process_object_detection, which isn't directly linked to the model's core functionality.

  4. Proposed Solution:
    I suggest removing the use_focal_loss argument from post_process_object_detection. This change wouldn't affect the RT-DETR model's overall structure.

  5. Flexibility for Users:
    Users can still create and attach their own post_process_object_detection function if needed.

  6. Reducing Confusion:
    Eliminating the use_focal_loss argument from post_process_object_detection would likely reduce user confusion.

This approach maintains the model's integrity while addressing the specific issue at hand. I believe it offers a good balance between fixing the problem and minimizing changes to the codebase.

I'm looking forward to your thoughts on this proposal.

@qubvel (Member) commented Sep 5, 2024

Hi @dwchoo, thanks a lot for working on this and for your structured responses, I really appreciate this discussion and hope it will also be useful for the community members who will find it 🤗

Here are my thoughts on these points:

  1. Training Process:
    The model uses RTDetrHungarianMatcher for loss calculation during training. This function doesn't consider a void class, even when use_focal_loss=False (where it uses softmax without accounting for a void class).

The RTDetrHungarianMatcher is an extended version of DetrHungarianMatcher, and it can handle the void class in the same way:

out_prob = outputs["logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
# Also concat the target labels and boxes
target_ids = torch.cat([v["class_labels"] for v in targets])
target_bbox = torch.cat([v["boxes"] for v in targets])
# Compute the classification cost. Contrary to the loss, we don't use the NLL,
# but approximate it in 1 - proba[target class].
# The 1 is a constant that doesn't change the matching, it can be ommitted.
class_cost = -out_prob[:, target_ids]
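Note that with a void class the logits simply gain one column; since target_ids only ever contain real class IDs, the class_cost indexing above is unaffected (illustrative shapes):

import torch

num_queries, num_classes = 300, 80
out_prob = torch.randn(2 * num_queries, num_classes + 1).softmax(-1)  # +1 = void
target_ids = torch.tensor([3, 17, 79])  # real classes only, never the void index
class_cost = -out_prob[:, target_ids]   # shape: [600, 3]; the void column is unused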

  2. Implications of Adding Void Class:
    Incorporating a void class option throughout the codebase would require extensive changes to the training process.

As you can see in this PR, not too many changes are needed to add a void class; it's more about fixing already-existing code. And I was able to run some experiments with it without any issues.

  3. Current Issue Scope:
    The problem we're addressing is specifically related to post_process_object_detection, which isn't directly linked to the model's core functionality.

Agreed that we are addressing the postprocessing problem; however, it also touches the main functionality, because each model has its own pre- and post-processing pipelines that can change the obtained results significantly.

  4. Proposed Solution:
    I suggest removing the use_focal_loss argument from post_process_object_detection. This change wouldn't affect the RT-DETR model's overall structure.
  5. Flexibility for Users:
    Users can still create and attach their own post_process_object_detection function if needed.
  6. Reducing Confusion:
    Eliminating the use_focal_loss argument from post_process_object_detection would likely reduce user confusion.

We could potentially remove the use_focal_loss argument in postprocessing; however, I'm more inclined towards renaming it. As I noted before, I don't see any problem with extending RT-DETR's functionality while keeping the default behavior aligned with the original implementation and documenting the additional features well.

If you have the bandwidth and agree with my thoughts, you can handle renaming + deprecating the argument in this PR to keep your contribution to the library (a sketch of a typical deprecation pattern follows below). Please let me know what you think 🤗
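For reference, a common deprecation pattern for such a rename (the new name use_sigmoid is purely hypothetical):

import warnings

def post_process_object_detection(
    self, outputs, threshold=0.5, target_sizes=None,
    use_focal_loss=None, use_sigmoid=True,
):
    # Hypothetical rename: accept the old argument but warn about it
    if use_focal_loss is not None:
        warnings.warn(
            "`use_focal_loss` is deprecated; use `use_sigmoid` instead.",
            FutureWarning,
        )
        use_sigmoid = use_focal_loss
    ...  # rest of the postprocessing unchanged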

@dwchoo (Author) commented Sep 19, 2024

@qubvel, I apologize for the delayed reply.

I believe your perspective is correct. Given the mix of various opinions, it might be beneficial to clarify and organize our approach. I would greatly appreciate your guidance on how to proceed with the PR modifications (in conjunction with #32658).

May I suggest summarizing our approach as follows?

  1. Provide users with the option to train with a void class (reflecting this in both training and pre/post-processing).
  2. Modify the name of the 'use_focal_loss' variable.

Could you please confirm if this summary aligns with your vision, or if there are any adjustments needed?

Thank you for your patience and guidance throughout this process.
