Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
gzy19990617 committed Feb 27, 2023
1 parent 84e9632 commit a017a1d
Show file tree
Hide file tree
Showing 6 changed files with 8 additions and 24 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2560,7 +2560,7 @@ USE_TRT_CONVERTER(preln_groupnorm_act)
USE_TRT_CONVERTER(flash_multihead_matmul)
USE_TRT_CONVERTER(cross_multihead_matmul)
#endif
#if IS_TRT_VERSION_GE(8500)
#if IS_TRT_VERSION_GE(8510)
USE_TRT_CONVERTER(grid_sampler)
#endif
#if IS_TRT_VERSION_GE(8200)
Expand Down
16 changes: 5 additions & 11 deletions paddle/fluid/inference/tensorrt/convert/grid_sampler_op.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -14,16 +14,6 @@ limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

namespace paddle {
namespace framework {
class Scope;

namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle

namespace paddle {
namespace inference {
namespace tensorrt {
Expand All @@ -36,6 +26,7 @@ class GridSamplerOpConverter : public OpConverter {
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
#if IS_TRT_VERSION_GE(8510)
VLOG(3) << "convert a fluid grid_sampler op to tensorrt GridSample layer";
framework::OpDesc op_desc(op, nullptr);
std::string input_x_name = op_desc.Input("X").front();
Expand Down Expand Up @@ -83,6 +74,9 @@ class GridSamplerOpConverter : public OpConverter {
layer->setAlignCorners(align_corners);

RreplenishLayerAndOutput(layer, "grid_sampler", {output_name}, test_mode);
#else
VLOG(3) << "grid_sampler is not supported when TensorRT < 8.5.1";
#endif
}
};

Expand Down
4 changes: 0 additions & 4 deletions paddle/fluid/inference/tensorrt/dynamic_shape_infermeta.cc
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,6 @@ nvinfer1::DimsExprs PNormInferMeta(
return output;
}

#if !IS_TRT_VERSION_GE(8500)
nvinfer1::DimsExprs GridSamplerInferMeta(
int output_index,
const nvinfer1::DimsExprs* inputs,
Expand All @@ -387,7 +386,6 @@ nvinfer1::DimsExprs GridSamplerInferMeta(
}
return output;
}
#endif

PD_REGISTER_DYNAMIC_INFER_META_FN(gather_nd, GatherNdInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(yolo_box, YoloBoxInferMeta);
Expand All @@ -397,9 +395,7 @@ PD_REGISTER_DYNAMIC_INFER_META_FN(scatter_nd_add, ScatterNdAddInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(inverse, UnchangedInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(moe, MoeInferMeta);
PD_REGISTER_DYNAMIC_INFER_META_FN(pad3d, Pad3dInferMeta);
#if !IS_TRT_VERSION_GE(8500)
PD_REGISTER_DYNAMIC_INFER_META_FN(grid_sampler, GridSamplerInferMeta);
#endif
} // namespace tensorrt
} // namespace inference
} // namespace paddle
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#pragma once

#include "paddle/fluid/inference/tensorrt/dynamic_shape_infermeta_factory.h"
#include "paddle/fluid/inference/tensorrt/helper.h"

namespace paddle {
namespace inference {
Expand All @@ -28,9 +27,7 @@ USE_TRT_DYNAMIC_INFER_META_FN(unfold);
USE_TRT_DYNAMIC_INFER_META_FN(scatter_nd_add);
USE_TRT_DYNAMIC_INFER_META_FN(pad3d);
USE_TRT_DYNAMIC_INFER_META_FN(inverse);
#if !IS_TRT_VERSION_GE(8500)
USE_TRT_DYNAMIC_INFER_META_FN(grid_sampler);
#endif
} // namespace tensorrt
} // namespace inference
} // namespace paddle
4 changes: 2 additions & 2 deletions paddle/fluid/inference/tensorrt/op_teller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2543,8 +2543,8 @@ struct SimpleOpTypeSetTeller : public Teller {
}

if (op_type == "grid_sampler") {
#if !IS_TRT_VERSION_GE(8500)
VLOG(3) << "grid_sampler is not supported when TensorRT < 8.5";
#if !IS_TRT_VERSION_GE(8510)
VLOG(3) << "grid_sampler is not supported when TensorRT < 8.5.1";
return false;
#else
if (!with_dynamic_shape) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@
class TrtConvertGridSampler(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
self.trt_param.workspace_size = 1073741824
ver = paddle_infer.get_trt_compile_version()
if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 8500:
return False
return True

def sample_program_configs(self):
Expand Down

0 comments on commit a017a1d

Please sign in to comment.