Skip to content

Commit

Permalink
[Prim][PIR] Modify the function get_reduce_dims_from_out (#65666)
Browse files Browse the repository at this point in the history
* add some config and modify the function get_reduce_dims_from_out

* modify the code

* fix the bug

* modify the useless part
  • Loading branch information
zeroRains authored Jul 8, 2024
1 parent 036cfda commit 35effb3
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 7 deletions.
15 changes: 9 additions & 6 deletions paddle/fluid/primitive/base/decomp_trans.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,15 @@ using Program = pir::Program;

// some outputs like xshape will no longer used after decomp, and those outputs
// will skip checking.
std::unordered_set<std::string> decomp_op_contain_none = {"pd_op.squeeze",
"pd_op.unsqueeze",
"pd_op.flatten",
"pd_op.batch_norm",
"pd_op.batch_norm_",
"pd_op.dropout"};
std::unordered_set<std::string> decomp_op_contain_none = {
"pd_op.squeeze",
"pd_op.unsqueeze",
"pd_op.flatten",
"pd_op.batch_norm",
"pd_op.batch_norm_",
"pd_op.dropout",
"pd_op.instance_norm",
};
//
std::unordered_set<std::string> dynamic_shape_blacklist = {"pd_op.squeeze",
"pd_op.unsqueeze",
Expand Down
30 changes: 29 additions & 1 deletion paddle/fluid/primitive/utils/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,38 @@ static std::vector<int64_t> process_dims(const Tensor& origin,
}

// These method don't need to be specified
// These method only handle the static shape case
static phi::DDim get_reduce_dims_from_out(const phi::DDim& dout_dims,
const phi::DDim& in_dims) {
std::vector<int64_t> result;
bool has_dynamic_shape = false;
for (int i = 0; i < dout_dims.size(); i++) {
if (dout_dims[i] == -1) {
has_dynamic_shape = true;
break;
}
}
PADDLE_ENFORCE_EQ(
has_dynamic_shape,
false,
platform::errors::InvalidArgument(
"Function get_reduce_dims_from_out() only use in static shape case, "
"but the input [dout_dims] have the dynamic shape."));

for (int i = 0; i < in_dims.size(); i++) {
if (in_dims[i] == -1) {
has_dynamic_shape = true;
break;
}
}
PADDLE_ENFORCE_EQ(
has_dynamic_shape,
false,
platform::errors::InvalidArgument(
"Function get_reduce_dims_from_out() only use in static shape case, "
"but the input [in_dims] have the dynamic shape."));

int bat = dout_dims.size() - in_dims.size();
std::vector<int64_t> result;
for (int i = 0; i < bat; ++i) {
result.push_back(i);
}
Expand Down

0 comments on commit 35effb3

Please sign in to comment.