Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into support_prim_in_new_ir

Browse files Browse the repository at this point in the history
  • Loading branch information
Charles-hit committed Aug 17, 2023
2 parents c161003 + dcfe2f1 commit 95642e3
Show file tree
Hide file tree
Showing 22 changed files with 295 additions and 351 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,13 @@ void FakeInitializeOutputsForFunctionKernel(
? DataType::INT64
: in_dtype;
}
} else if (op_type == "searchsorted") {
bool out_int32 = op.Attr<bool>("out_int32");
if (out_int32) {
dtype = DataType::INT32;
} else {
dtype = DataType::INT64;
}
} else {
VLOG(4) << "Get dtype result from InferMeta";
RuntimeInferShapeContext infer_shape_ctx(op, runtime_ctx);
Expand Down
1 change: 0 additions & 1 deletion paddle/fluid/framework/new_executor/program_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ ProgramInterpreter::ProgramInterpreter(const platform::Place& place,

static_build_ = FLAGS_new_executor_static_build &&
!FLAGS_new_executor_use_cuda_graph &&
!execution_config.used_for_control_flow_op &&
interpreter::BlockCanBeStaticBuilt(block);

exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught);
Expand Down
7 changes: 1 addition & 6 deletions paddle/fluid/operators/collective/c_allreduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,11 +265,6 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
if (map->has(rid)) {
// Use ProcessGroup
distributed::ProcessGroup* pg = map->get(rid);
std::vector<phi::DenseTensor> in_tensor;
std::vector<phi::DenseTensor> out_tensor;
in_tensor.push_back(*in);
out_tensor.push_back(*out);

distributed::AllreduceOptions opts;
switch (red_type) {
case kRedSum:
Expand All @@ -293,7 +288,7 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
"Invalid reduce type: %d", red_type));
}

auto task = pg->AllReduce(in_tensor, out_tensor, opts);
auto task = pg->AllReduce(out, *in, opts, false, true);
task->Wait();
return;
}
Expand Down
7 changes: 7 additions & 0 deletions paddle/fluid/prim/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ void PrimCommonUtils::AddSkipCompOps(const std::string& op_type) {
StaticCompositeContext::Instance().AddSkipCompOps(op_type);
}

void PrimCommonUtils::SetPrimBackwardBlacklist(
    const std::unordered_set<std::string>& op_types) {
  // Blacklisting an op for composite (prim) backward decomposition is
  // implemented by registering each name in the context's skip-comp set,
  // exactly as AddSkipCompOps does for a single op.
  for (auto it = op_types.begin(); it != op_types.end(); ++it) {
    StaticCompositeContext::Instance().AddSkipCompOps(*it);
  }
}

// Remove `op_type` from the skip-comp set, i.e. undo a previous
// AddSkipCompOps / SetPrimBackwardBlacklist registration so the op is
// eligible for composite decomposition again.
void PrimCommonUtils::RemoveSkipCompOps(const std::string& op_type) {
  StaticCompositeContext::Instance().RemoveSkipCompOps(op_type);
}
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/prim/utils/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class PrimCommonUtils {
static void SetAllPrimEnabled(bool enabled);
static size_t CheckSkipCompOps(const std::string& op_type);
static void AddSkipCompOps(const std::string& op_type);
static void SetPrimBackwardBlacklist(
const std::unordered_set<std::string>& op_types);
static void RemoveSkipCompOps(const std::string& op_type);
static void SetTargetGradName(const std::map<std::string, std::string>& m);
};
Expand Down
Loading

0 comments on commit 95642e3

Please sign in to comment.