Skip to content

Commit

Permalink
[PaddleTRT] Fixed bug in generating quantization calibration table fo…
Browse files Browse the repository at this point in the history
…r yolo_box (PaddlePaddle#61596)

* 修复yolov3量化表yolo_box的bug
  • Loading branch information
lizexu123 committed Feb 5, 2024
1 parent 60325a1 commit 0f23c62
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
7 changes: 5 additions & 2 deletions paddle/fluid/inference/tensorrt/convert/yolo_box_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,13 @@ class YoloBoxOpConverter : public OpConverter {
? PADDLE_GET_CONST(float, op_desc.GetAttr("iou_aware_factor"))
: 0.5;

int type_id = static_cast<int>(engine_->WithFp16());
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
if (engine_->precision() == phi::DataType::INT8) {
with_fp16 = true;
}
auto input_dim = X_tensor->getDimensions();
auto* yolo_box_plugin = new plugin::YoloBoxPlugin(
type_id ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT,
with_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT,
anchors,
class_num,
conf_thresh,
Expand Down
7 changes: 4 additions & 3 deletions paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.h"
#include <algorithm>
#include <cassert>

#include "paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
Expand Down Expand Up @@ -102,7 +101,9 @@ nvinfer1::Dims YoloBoxPlugin::getOutputDimensions(

bool YoloBoxPlugin::supportsFormat(
nvinfer1::DataType type, nvinfer1::TensorFormat format) const TRT_NOEXCEPT {
return ((type == data_type_ || type == nvinfer1::DataType::kINT32) &&
return ((type == nvinfer1::DataType::kFLOAT ||
type == nvinfer1::DataType::kHALF ||
type == nvinfer1::DataType::kINT32) &&
format == nvinfer1::TensorFormat::kLINEAR);
}

Expand Down

0 comments on commit 0f23c62

Please sign in to comment.