Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

phi_multiclass_nms3 #44613

Merged
merged 1 commit into from
Jul 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions paddle/fluid/operators/detection/multiclass_nms_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ limitations under the License. */

#include <glog/logging.h>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detection/nms_util.h"
#include "paddle/phi/infermeta/ternary.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -609,12 +611,6 @@ class MultiClassNMS3Op : public MultiClassNMS2Op {
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: MultiClassNMS2Op(type, inputs, outputs, attrs) {}

void InferShape(framework::InferShapeContext* ctx) const override {
MultiClassNMS2Op::InferShape(ctx);

ctx->SetOutputDim("NmsRoisNum", {-1});
}
};

class MultiClassNMS3OpMaker : public MultiClassNMS2OpMaker {
Expand All @@ -633,6 +629,10 @@ class MultiClassNMS3OpMaker : public MultiClassNMS2OpMaker {
} // namespace operators
} // namespace paddle

DECLARE_INFER_SHAPE_FUNCTOR(multiclass_nms3,
MultiClassNMSShapeFunctor,
PD_INFER_META(phi::MultiClassNMSInferMeta));

namespace ops = paddle::operators;
REGISTER_OPERATOR(
multiclass_nms,
Expand All @@ -658,7 +658,5 @@ REGISTER_OPERATOR(
ops::MultiClassNMS3Op,
ops::MultiClassNMS3OpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(multiclass_nms3,
ops::MultiClassNMSKernel<float>,
ops::MultiClassNMSKernel<double>);
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
MultiClassNMSShapeFunctor);
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/legacy_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1592,6 +1592,15 @@
func : multi_dot
backward : multi_dot_grad

- api : multiclass_nms3
args : (Tensor bboxes, Tensor scores, Tensor rois_num, float score_threshold, int nms_top_k, int keep_top_k, float nms_threshold=0.3, bool normalized=true, float nms_eta=1.0, int background_label=0)
output : Tensor(out), Tensor(index), Tensor(nms_rois_num)
infer_meta :
func : MultiClassNMSInferMeta
kernel :
func : multiclass_nms3
optional : rois_num

# multinomial
- api : multinomial
args : (Tensor x, int num_samples, bool replacement)
Expand Down
93 changes: 93 additions & 0 deletions paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,99 @@ void LinspaceInferMeta(const MetaTensor& start,
LinspaceRawInferMeta(start, stop, number, out);
}

void MultiClassNMSInferMeta(const MetaTensor& bboxes,
const MetaTensor& scores,
const MetaTensor& rois_num,
float score_threshold,
int nms_top_k,
int keep_top_k,
float nms_threshold,
bool normalized,
float nms_eta,
int background_label,
MetaTensor* out,
MetaTensor* index,
MetaTensor* nms_rois_num,
MetaConfig config) {
auto box_dims = bboxes.dims();
auto score_dims = scores.dims();
auto score_size = score_dims.size();

if (config.is_runtime) {
PADDLE_ENFORCE_EQ(
score_size == 2 || score_size == 3,
true,
errors::InvalidArgument("The rank of Input(Scores) must be 2 or 3"
". But received rank = %d",
score_size));
PADDLE_ENFORCE_EQ(
box_dims.size(),
3,
errors::InvalidArgument("The rank of Input(BBoxes) must be 3"
". But received rank = %d",
box_dims.size()));
if (score_size == 3) {
PADDLE_ENFORCE_EQ(box_dims[2] == 4 || box_dims[2] == 8 ||
box_dims[2] == 16 || box_dims[2] == 24 ||
box_dims[2] == 32,
true,
errors::InvalidArgument(
"The last dimension of Input"
"(BBoxes) must be 4 or 8, "
"represents the layout of coordinate "
"[xmin, ymin, xmax, ymax] or "
"4 points: [x1, y1, x2, y2, x3, y3, x4, y4] or "
"8 points: [xi, yi] i= 1,2,...,8 or "
"12 points: [xi, yi] i= 1,2,...,12 or "
"16 points: [xi, yi] i= 1,2,...,16"));
PADDLE_ENFORCE_EQ(
box_dims[1],
score_dims[2],
errors::InvalidArgument(
"The 2nd dimension of Input(BBoxes) must be equal to "
"last dimension of Input(Scores), which represents the "
"predicted bboxes."
"But received box_dims[1](%s) != socre_dims[2](%s)",
box_dims[1],
score_dims[2]));
} else {
PADDLE_ENFORCE_EQ(box_dims[2],
4,
errors::InvalidArgument(
"The last dimension of Input"
"(BBoxes) must be 4. But received dimension = %d",
box_dims[2]));
PADDLE_ENFORCE_EQ(
box_dims[1],
score_dims[1],
errors::InvalidArgument(
"The 2nd dimension of Input"
"(BBoxes) must be equal to the 2nd dimension of Input(Scores). "
"But received box dimension = %d, score dimension = %d",
box_dims[1],
score_dims[1]));
}
}
PADDLE_ENFORCE_NE(out,
nullptr,
errors::InvalidArgument(
"The out in MultiClassNMSInferMeta can't be nullptr."));
PADDLE_ENFORCE_NE(
index,
nullptr,
errors::InvalidArgument(
"The index in MultiClassNMSInferMeta can't be nullptr."));
// Here the box_dims[0] is not the real dimension of output.
// It will be rewritten in the computing kernel.

out->set_dims(phi::make_ddim({-1, box_dims[2] + 2}));
out->set_dtype(bboxes.dtype());
index->set_dims(phi::make_ddim({-1, box_dims[2] + 2}));
index->set_dtype(DataType::INT32);
nms_rois_num->set_dims(phi::make_ddim({-1}));
nms_rois_num->set_dtype(DataType::INT32);
}

void NllLossRawInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& weight,
Expand Down
15 changes: 15 additions & 0 deletions paddle/phi/infermeta/ternary.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,21 @@ void LinspaceInferMeta(const MetaTensor& start,
DataType dtype,
MetaTensor* out);

void MultiClassNMSInferMeta(const MetaTensor& bboxes,
const MetaTensor& scores,
const MetaTensor& rois_num,
float score_threshold,
int nms_top_k,
int keep_top_k,
float nms_threshold,
bool normalized,
float nms_eta,
int background_label,
MetaTensor* out,
MetaTensor* index,
MetaTensor* nms_rois_num,
MetaConfig config = MetaConfig());

void NllLossRawInferMeta(const MetaTensor& input,
const MetaTensor& label,
const MetaTensor& weight,
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ set(COMMON_KERNEL_DEPS
lod_utils
custom_kernel
string_infermeta
utf8proc)
utf8proc
gpc)

copy_if_different(${kernel_declare_file} ${kernel_declare_file_final})

Expand Down
Loading