Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add a subdirectory named cinn in operators and move releated files into it #37938

Merged
merged 7 commits into from
Dec 8, 2021
2 changes: 1 addition & 1 deletion paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/subgraph_detector.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/operators/cinn_launch_op.h"
#include "paddle/fluid/operators/cinn/cinn_launch_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ limitations under the License. */
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/operators/cinn_launch_op.h"
#include "paddle/fluid/operators/cinn/cinn_launch_op.h"

namespace paddle {
namespace framework {
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/cinn_launch_op.h"
#include "paddle/fluid/operators/cinn/cinn_launch_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"

Expand Down
16 changes: 6 additions & 10 deletions paddle/fluid/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ if (WITH_LITE)
add_subdirectory(lite)
endif()

if(WITH_CINN)
add_subdirectory(cinn)
endif()

SET(OP_HEADER_DEPS xxhash executor)

if (WITH_GPU)
Expand Down Expand Up @@ -82,7 +86,7 @@ endif()
set(OP_HEADER_DEPS ${OP_HEADER_DEPS} pten pten_api_utils)

register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op
recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op spectral_op cinn_launch_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})

op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS})
op_library(save_combine_op DEPS string_array)
Expand Down Expand Up @@ -167,14 +171,6 @@ if (WITH_ASCEND_CL)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} npu_op_runner)
endif()

if (WITH_CINN)
op_library(cinn_launch_op SRCS cinn_launch_op.cc cinn_launch_op.cu.cc DEPS transform_desc cinn_compiler cinn ${OP_HEADER_DEPS})
if (WITH_TESTING)
cc_test(cinn_launch_op_test SRCS cinn_launch_op_test.cc DEPS cinn_compiler cinn_launch_op elementwise_add_op)
set_tests_properties(cinn_launch_op_test PROPERTIES ENVIRONMENT OMP_NUM_THREADS=1)
endif()
endif()

# FIXME(typhoonzero): operator deps may not needed.
# op_library(lod_tensor_to_array_op DEPS lod_rank_table_op)
# op_library(array_to_lod_tensor_op DEPS lod_rank_table_op)
Expand Down Expand Up @@ -205,7 +201,7 @@ elseif(WITH_ROCM)
else()
cc_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc DEPS tensor device_context eigen3)
endif()
cc_test(share_buffer_op_cpp_test SRCS share_buffer_op_test.cc DEPS lod_tensor device_context share_buffer_op)
cc_test(share_buffer_op_cpp_test SRCS share_buffer_op_test.cc DEPS lod_tensor device_context share_buffer_op)

cc_library(tensor_formatter SRCS tensor_formatter.cc DEPS ${OP_HEADER_DEPS})
if (WITH_PYTHON)
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/operators/cinn/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
include(operators)
register_operators(EXCLUDES cinn_launch_op)

cc_library(cinn_launch_context SRCS cinn_launch_context.cc DEPS cinn)
cc_test(cinn_launch_context_test SRCS cinn_launch_context_test.cc DEPS scope lod_tensor cinn_launch_context)

op_library(cinn_launch_op SRCS cinn_launch_op.cc cinn_launch_op.cu.cc DEPS cinn cinn_compiler cinn_launch_context)
cc_test(cinn_launch_op_test SRCS cinn_launch_op_test.cc DEPS cinn_compiler cinn_launch_op elementwise_add_op)
set_tests_properties(cinn_launch_op_test PROPERTIES ENVIRONMENT "OMP_NUM_THREADS=1;runtime_include_dir=${PADDLE_BINARY_DIR}/third_party/CINN/src/external_cinn/cinn/runtime/cuda")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add WITH_TESTING for CINN tests, like what I have done

Otherwise, if WITH_CINN=ON and WITH_TESTING=OFF, the compilation will fail.

Copy link
Contributor Author

@CtfGo CtfGo Dec 8, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix it

Original file line number Diff line number Diff line change
Expand Up @@ -12,76 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/cinn_launch_op.h"

#include "paddle/fluid/operators/cinn/cinn_launch_context.h"
#include <functional>
#include <vector>

#include "paddle/fluid/string/string_helper.h"

DECLARE_bool(cudnn_deterministic);

namespace paddle {
namespace operators {

namespace details {

const ::cinn::common::Target& PlaceToCinnTarget(const platform::Place& place) {
if (platform::is_cpu_place(place)) {
return ::cinn::common::DefaultHostTarget();
} else if (platform::is_gpu_place(place)) {
return ::cinn::common::DefaultNVGPUTarget();
}

PADDLE_THROW(platform::errors::InvalidArgument(
"CINN is not supported on current place:%s", place));
return ::cinn::common::UnkTarget();
}

void DebugCinnCompiledResult(const CinnCompiledObject& result) {
if (!VLOG_IS_ON(4)) {
return;
}
const auto& cinn_runtime_program = result.runtime_program;
const auto& cinn_scope = *(result.scope);
const auto& paddle2cinn_varmap = result.paddle2cinn_varmap;

VLOG(4) << "Compiled runtime_program instrunction size:["
<< cinn_runtime_program->size() << "]";

std::vector<std::string> infos;
auto cinn_var_names = cinn_scope.var_names();
infos.reserve(cinn_var_names.size());
std::transform(cinn_var_names.begin(), cinn_var_names.end(),
std::back_inserter(infos),
[](const auto& name_view) { return name_view.data(); });
VLOG(4) << "Compiled scope variable names:["
<< string::join_strings(infos, ',') << "]";

infos.clear();
infos.reserve(paddle2cinn_varmap.size());
std::transform(paddle2cinn_varmap.begin(), paddle2cinn_varmap.end(),
std::back_inserter(infos), [](const auto& paddle2cinn) {
return paddle2cinn.first + "->" + paddle2cinn.second;
});
VLOG(4) << "Compiled paddle2cinn_varmap:[" << string::join_strings(infos, ',')
<< "]";
}

void LaunchCinnExecution(const CinnCompiledObject& compiled_obj,
const CinnLaunchContext& context, void* stream) {
compiled_obj.runtime_program->Execute(&context.FinalizeArguments(), stream);
}

void SetCinnRuntimeFlags() {
VLOG(4) << "Set FLAGS_cinn_cudnn_deterministic to "
<< FLAGS_cudnn_deterministic;
::cinn::runtime::SetCinnCudnnDeterministic(FLAGS_cudnn_deterministic);
}

CinnLaunchContext::CinnLaunchContext(const CinnCompiledObject& compiled_obj)
: paddle2cinn_varmap_(compiled_obj.paddle2cinn_varmap),
cinn_scope_(compiled_obj.scope) {
CinnLaunchContext::CinnLaunchContext(
const std::unordered_map<std::string, std::string>& paddle2cinn_varmap,
const std::shared_ptr<CinnScope>& cinn_scope)
: paddle2cinn_varmap_(paddle2cinn_varmap), cinn_scope_(cinn_scope) {
auto var_names = cinn_scope_->var_names();
cinn_variable_names_.reserve(var_names.size());
std::transform(
Expand Down Expand Up @@ -221,90 +163,5 @@ CinnLaunchContext::FinalizeArguments() const {
}

} // namespace details

class CinnLaunchOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs(kX), "Input", kX, "CinnLaunchOp");
OP_INOUT_CHECK(ctx->HasOutputs(kOutputs), "Output", kOutputs,
"CinnLaunchOp");
}

protected:
/* [Why use single type kernel]:
*
* This op is similar to a control flow op, it doses not need
* a op kernel, but in order to make it execute under dynamic
* graph mode, implement it with op kernel.
*
* So whether the kernel data type is int, float or other type,
* which has no effect on its execution logic, so directly
* specified a data type here.
*
* Of course, the data type here is also not important.
*/

framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.GetPlace());
}
};

class CinnLaunchOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(kX,
"(vector<LoDTensor>)"
"which are the input of graph inside the CinnLaunchOp.")
.AsDuplicable();
AddOutput(kOutputs,
"(vector<LoDTensor>)"
"which are the output of graph inside the CinnLaunchOp.")
.AsDuplicable();
AddAttr<std::string>(
kCompilationKey,
"(string)"
"a hash key used to get the graph object or its computation result.");
AddComment(R"DOC(
CinnLaunch Operator.

This operator is used to launch CINN(https://github.com/PaddlePaddle/CINN/blob/develop/README.md)
to compile a graph and execute the compiled object.

Both input and output of this operator are a set of variables
which are input and output of the graph respectively that will be
compiled and executed in this operator.
In addition, there is an attribute named 'compilation_key' should be
set necessarily to get corresponding ir::Graph object of the graph
or its computation result.

It accomplishes the computation of graph following several steps:
1. Fetch ir::Graph object from CinnCompiler using kCompilationKey
2. Compile the graph to a compiled object, and insert it to the
global cache so that we can directly query it from this cache next time
when shape of input variables are not changed at all.
3. Create and instantiate all variables used to execute compiled runtime program
if necessary according to the info(type,shape) included in the return scope.
4. Pack each tensor buffer of all above variables as execution arguments.
5. Launch execution of the runtime program with above arguments, then
the result would be output by writing value on underlying buffer address.

)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(
cinn_launch, ops::CinnLaunchOp, ops::CinnLaunchOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
/* see [Why use single type kernel] */
REGISTER_OP_CPU_KERNEL(
cinn_launch,
ops::CinnLaunchOpKernel<paddle::platform::CPUDeviceContext, float>);
104 changes: 104 additions & 0 deletions paddle/fluid/operators/cinn/cinn_launch_context.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "cinn/hlir/framework/scope.h"
#include "cinn/hlir/framework/tensor.h"
#include "cinn/runtime/cinn_runtime.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace operators {
namespace details {

using LoDTensor = framework::LoDTensor;
using CinnTensor = ::cinn::hlir::framework::Tensor;
using CinnScope = ::cinn::hlir::framework::Scope;

class CinnLaunchContext {
public:
explicit CinnLaunchContext(
const std::unordered_map<std::string, std::string>& paddle2cinn_varmap,
const std::shared_ptr<CinnScope>& cinn_scope);

// Return whether a Paddle variable used on compiled kernels
bool IsVariableUsed(const std::string& var_name);

// Assign tensor buffer to input or output variables
void AssignExternalVariable(const std::string& var_name,
const platform::Place& place, LoDTensor* tensor);

// Assign tensor buffer to internal variables
void AssignInternalVariable(const std::string& var_name,
const platform::Place& place, LoDTensor* tensor);

// Extract internal variable names from CinnScope
// by excluding used input and output variables
std::unordered_set<std::string> GetInternalVariableNames();

// Finalize all execution arguments and return them
const std::map<std::string, cinn_pod_value_t>& FinalizeArguments() const;

std::vector<std::unique_ptr<cinn_buffer_t>> HandoverBuffers() {
return std::move(hold_buffers_);
}

private:
// Get CinnTensor with CINN variable name
CinnTensor GetCinnTensor(const std::string& var_name);

// Check whether tensors from Paddle and CINN of the same variable
// are equivalent in type and dimension
void CheckTensorEquivalent(const std::string& var_name,
const LoDTensor& paddle_tensor,
const CinnTensor& cinn_tensor);

// Share the buffer of a Paddle tensor to CINN by delivering memory address
// to a cinn_buffer_t object
std::unique_ptr<cinn_buffer_t> ShareTensorWithCinnBuffer(
const platform::Place& place, bool free_mem_callback, LoDTensor* tensor);

// Set an argument with (cinn name)->(paddle tensor) pair
void SetArgument(const std::string& cinn_name, const platform::Place& place,
bool free_mem_callback, LoDTensor* paddle_tensor);

private:
// a variable name map from paddle to cinn
const std::unordered_map<std::string, std::string>& paddle2cinn_varmap_;
// the variable scope of cinn
const std::shared_ptr<CinnScope> cinn_scope_;

// all variables used by compiled executable program
std::unordered_set<std::string> cinn_variable_names_;

// because a cinn_pod_value_t does not own the cinn_buffer_t object,
// an extra stroage is necessary to keep the object and it can
// not be released until runtime program finish execution.
std::vector<std::unique_ptr<cinn_buffer_t>> hold_buffers_;

// name to execution argument
std::map<std::string, cinn_pod_value_t> name2argument_;
};

} // namespace details
} // namespace operators
} // namespace paddle
Loading