[Perf]Polish UniformRandom And Split it into ScheduleBlock (PaddlePaddle#1357)

Because the joint build tests with Paddle require changes on both sides, this PR is force-merged into CINN for now; CI will pass normally once the corresponding Paddle PR is merged.
Aurelius84 authored and jiahy0825 committed May 25, 2023
1 parent 299bb50 commit d5652c9
Showing 10 changed files with 287 additions and 8 deletions.
2 changes: 1 addition & 1 deletion cinn/backends/CMakeLists.txt
@@ -40,7 +40,7 @@ endif()

if (WITH_CUDA)
nv_test(test_codegen_cuda_generate SRCS codegen_cuda_generate_test.cc DEPS cinncore)
- nv_test(test_codegen_debug SRCS codegen_debug_test.cc DEPS cinncore)
+ nv_test(test_codegen_debug SRCS codegen_debug_test.cc DEPS cinncore cinn_runtime)

if (WITH_TESTING)
cc_library(generated1_cuda SRCS generated1.cu DEPS cinncore)
2 changes: 1 addition & 1 deletion cinn/backends/codegen_debug_test.cc
@@ -60,9 +60,9 @@ TEST(CodeGenDebug, RunCudaSourceCode) {
common::Context::Global().ResetNameId();

std::string source_code = R"ROC(
extern "C" {
#include "cinn_cuda_runtime_source.cuh"
extern "C" {
#ifdef __CUDACC_RTC__
typedef int int32_t;
12 changes: 11 additions & 1 deletion cinn/backends/extern_func_protos.cc
@@ -27,7 +27,9 @@ ExternFunctionProtoRegistry::ExternFunctionProtoRegistry() {
static const std::vector<std::string> extern_funcs_float_bool_unary = {"isnan", "isfinite", "isinf"};
static const std::vector<std::string> extern_funcs_int_binary = {
"left_shift", "right_shift", "bitwise_or", "bitwise_and", "bitwise_xor", "bitwise_not"};
- static const std::vector<std::string> extern_funcs_int_int_unary = {"bitwise_not"};
+ static const std::vector<std::string> extern_funcs_int_int_unary = {"bitwise_not"};
static const std::vector<std::string> extern_funcs_int_float_call = {"cinn_nvgpu_uniform_random_fp32"};
static const std::vector<std::string> extern_funcs_int_double_call = {"cinn_nvgpu_uniform_random_fp64"};
for (int i = 0; i < extern_funcs_fp32_unary.size(); ++i) {
auto* proto = new FunctionProto(extern_funcs_fp32_unary[i], {Float(32)}, Float(32));
Register(proto->name, proto);
@@ -44,6 +46,14 @@ ExternFunctionProtoRegistry::ExternFunctionProtoRegistry() {
auto* proto = new FunctionProto(extern_funcs_int_int_unary[i], {Int(32)}, Int(32));
Register(proto->name, proto);
}
for (int i = 0; i < extern_funcs_int_float_call.size(); ++i) {
auto* proto = new FunctionProto(extern_funcs_int_float_call[i], {Int(32)}, Float(32));
Register(proto->name, proto);
}
for (int i = 0; i < extern_funcs_int_double_call.size(); ++i) {
auto* proto = new FunctionProto(extern_funcs_int_double_call[i], {Int(32)}, Float(64));
Register(proto->name, proto);
}

auto* n = detail::CreateTanhVProto();
Register(n->name, n);
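
For orientation: the two new loops above register protos saying that each uniform-random intrinsic takes a single int32 seed and returns a float32 or float64 value. A minimal standalone sketch of that name-to-signature bookkeeping (illustrative only; Proto is a made-up struct here, not CINN's FunctionProto):

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Toy registry mapping an extern function name to its argument and return types,
// mirroring what the FunctionProto registrations above express.
struct Proto {
  std::vector<std::string> arg_types;
  std::string ret_type;
};

int main() {
  std::map<std::string, Proto> registry;
  registry["cinn_nvgpu_uniform_random_fp32"] = {{"int32"}, "float32"};
  registry["cinn_nvgpu_uniform_random_fp64"] = {{"int32"}, "float64"};

  // A CallExtern("cinn_nvgpu_uniform_random_fp32", {seed}) would be checked
  // against this entry: one int32 argument, float32 result.
  const Proto& p = registry.at("cinn_nvgpu_uniform_random_fp32");
  std::cout << p.arg_types[0] << " -> " << p.ret_type << "\n";
  return 0;
}
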
3 changes: 3 additions & 0 deletions cinn/hlir/op/contrib/CMakeLists.txt
@@ -31,3 +31,6 @@ cc_test(test_repeat SRCS repeat_test.cc DEPS cinncore)
cc_test(test_one_hot SRCS one_hot_test.cc DEPS cinncore)
cc_test(test_lookup_table SRCS lookup_table_test.cc DEPS cinncore)
cc_test(test_reciprocal SRCS reciprocal_test.cc DEPS cinncore)
if (WITH_CUDA)
cc_test(test_uniform_random_gpu SRCS uniform_random_test.cc DEPS cinncore)
endif()
52 changes: 48 additions & 4 deletions cinn/hlir/op/contrib/uniform_random.cc
@@ -11,6 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "cinn/hlir/op/contrib/uniform_random.h"

#include <gflags/gflags.h>

@@ -45,13 +46,43 @@
#include "cinn/poly/stage.h"
#include "glog/logging.h"

DECLARE_bool(cinn_ir_schedule);

namespace cinn {
namespace hlir {
namespace op {

using common::CINNValue;
using common::CINNValuePack;

// Only for min = 0. and max = 1.
ir::Tensor UniformRandom(const std::vector<int> &shape,
int seed,
const std::string &dtype,
const Target &target,
const std::string &tensor_name) {
std::string extern_func = "cinn_nvgpu_uniform_random_";
if (target != common::DefaultNVGPUTarget()) {
LOG(FATAL) << "Not Implemented UniformRandom for target: " << target;
}

if (dtype == "float32") {
extern_func += "fp32";
} else if (dtype == "float64") {
extern_func += "fp64";
} else {
LOG(FATAL) << "Not Implemented UniformRandom for dtype: " << dtype;
}

std::vector<Expr> new_shape;
for (auto item : shape) {
new_shape.push_back(Expr(item));
}

return lang::Compute(
new_shape, [=]() { return lang::CallExtern(extern_func, {Expr(seed)}); }, tensor_name);
}

std::shared_ptr<framework::OpStrategy> StrategyForUniformRandom(const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<Type> &out_type,
@@ -60,9 +91,22 @@ std::shared_ptr<framework::OpStrategy> StrategyForUniformRandom(const framework:
framework::CINNCompute uniform_random_compute([=](lang::Args args, lang::RetValue *ret) {
CHECK(attrs.attr_store.count("shape"));
ir::Tensor shape_tensor;
std::string tensor_name = "uniform_random_out";
auto out = pe::Identity(shape_tensor, tensor_name).front();
auto stages = CreateStages({out});
CHECK(output_shapes.size() == 1UL);
CHECK(attrs.attr_store.count("seed"));
int seed = absl::get<int>(attrs.attr_store.at("seed"));
std::string dtype = "float32";
if (attrs.attr_store.find("dtype") != attrs.attr_store.end()) {
dtype = absl::get<std::string>(attrs.attr_store.at("dtype"));
}
CINNValuePack arg_pack = args[0];
std::string tensor_name = UniqName("uniform_random_out");
if (FLAGS_cinn_ir_schedule) {
CHECK_EQ(arg_pack.size(), 1U);
CHECK(arg_pack[0].is_string());
tensor_name = arg_pack[0].operator std::string();
}
auto out = UniformRandom(output_shapes[0], seed, dtype, target, tensor_name);
auto stages = CreateStages({out});
std::vector<CINNValue> res{CINNValue(out), CINNValue(stages)};
*ret = CINNValuePack{res};
});
@@ -104,7 +148,7 @@ CINN_REGISTER_HELPER(uniform_random_ops) {
.set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForUniformRandom)
.set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForUniformRandom))
.set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForUniformRandom))
- .set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible)
+ .set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise)
.set_support_level(4);

return true;
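
A rough standalone sketch of the dispatch the new UniformRandom compute performs (NVGPU-only target check, then dtype-driven selection of the fp32 or fp64 intrinsic). The plain string target below stands in for CINN's Target type, so this is illustrative rather than the op's real interface; with the chosen name, the compute emits one CallExtern per output element, which is presumably why the host-side cinn_call_uniform_random external API registration is dropped in external_api_registry.cc further down.

#include <iostream>
#include <stdexcept>
#include <string>

// Illustrative only: mirrors how the compute above chooses the extern intrinsic name.
std::string UniformRandomExternName(const std::string& target, const std::string& dtype) {
  if (target != "nvgpu") {
    throw std::runtime_error("UniformRandom is only implemented for the NVGPU target");
  }
  std::string extern_func = "cinn_nvgpu_uniform_random_";
  if (dtype == "float32") {
    extern_func += "fp32";
  } else if (dtype == "float64") {
    extern_func += "fp64";
  } else {
    throw std::runtime_error("UniformRandom does not support dtype: " + dtype);
  }
  return extern_func;
}

int main() {
  std::cout << UniformRandomExternName("nvgpu", "float64") << "\n";  // cinn_nvgpu_uniform_random_fp64
  return 0;
}
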
37 changes: 37 additions & 0 deletions cinn/hlir/op/contrib/uniform_random.h
@@ -0,0 +1,37 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <vector>

#include "cinn/ir/ir.h"
#include "cinn/ir/ir_base.h"
#include "cinn/ir/tensor.h"

namespace cinn {
namespace hlir {
namespace op {

// Only for min = 0. and max = 1.
ir::Tensor UniformRandom(const std::vector<int>& shape,
int seed,
const std::string& dtype,
const Target& target,
const std::string& tensor_name);

} // namespace op
} // namespace hlir
} // namespace cinn
167 changes: 167 additions & 0 deletions cinn/hlir/op/contrib/uniform_random_test.cc
@@ -0,0 +1,167 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "cinn/hlir/op/contrib/uniform_random.h"

#include <glog/logging.h>
#include <gtest/gtest.h>

#include <string>
#include <vector>

#include "cinn/backends/codegen_c.h"
#include "cinn/backends/codegen_c_x86.h"
#include "cinn/backends/codegen_cuda_dev.h"
#include "cinn/backends/codegen_cuda_util.h"
#include "cinn/common/context.h"
#include "cinn/frontend/net_builder.h"
#include "cinn/frontend/optimize.h"
#include "cinn/hlir/framework/graph.h"
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/lang/lower.h"
#include "cinn/lang/placeholder.h"
#include "cinn/poly/stage.h"
#include "cinn/utils/data_util.h"

namespace cinn {
namespace hlir {
namespace op {

#ifdef CINN_WITH_CUDA
TEST(GenerateCode_CUDA, UniformRandomGPU) {
common::Context::Global().ResetNameId();

common::Target target = common::DefaultNVGPUTarget();

std::vector<int> shape = {128, 12};
int seed = 2023;
std::string dtype = "float32";

ir::Tensor res = UniformRandom(shape, seed, dtype, target, "uniform_random_out");

poly::StageMap stages = poly::CreateStages({res});
std::vector<ir::LoweredFunc> funcs =
lang::LowerVec("TestGenerateCodeGPU_UniformRandom", stages, {res}, {}, {}, nullptr, target, true);

VLOG(6) << "Expr before CUDA codegen:";
VLOG(6) << funcs[0]->body;

ir::Module::Builder builder("UniformRandom_Module", target);
for (auto& f : funcs) {
builder.AddFunction(f);
}

auto module = builder.Build();
auto host_module_device_module = backends::SplitCudaAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(host_module_device_module);
auto& device_module = std::get<1>(host_module_device_module);

backends::CodeGenCUDA_Dev codegen(target);
std::string source_code = codegen.Compile(device_module);
LOG(INFO) << "compiled code:\n" << source_code;
}

} // namespace op
} // namespace hlir

namespace frontend {

TEST(Builder, UniformRandomFP32) {
NetBuilder builder("net_builder");

std::vector<int> shape = {128, 12, 128, 128};
int seed = 2023;
std::string dtype = "float32";
auto out = builder.UniformRandom(shape, 0., 1., seed, dtype);
auto program = builder.Build();

for (int i = 0; i < program.size(); ++i) {
LOG(INFO) << "instruction: " << program[i];
}

Target target = common::DefaultNVGPUTarget();
std::unordered_set<std::string> fetch_ids;
auto graph = Optimize(&program, fetch_ids, target);

LOG(INFO) << "graph: \n" << graph->Visualize();

auto scope = BuildScope(target, graph);

hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build();

auto out_ten = scope->GetTensor(std::string(out->id));
runtime_program->Execute();

EXPECT_EQ(out_ten->type(), Float(32));

std::vector<float> data = GetTensorData<float>(out_ten, target);

int cnt = 0;
for (int i = 0; i < 128 * 12 * 128 * 128; ++i) {
if (data[i] > 0.5) cnt++;
}
float ratio = (float)cnt / (128 * 12 * 128 * 128);
LOG(INFO) << "count: " << cnt;
LOG(INFO) << "x > 0.5f ratio: " << ratio;
EXPECT_LE(ratio, 0.501f);
EXPECT_GE(ratio, 0.499f);
}

TEST(Builder, UniformRandomFP64) {
NetBuilder builder("net_builder");

std::vector<int> shape = {128, 12, 128, 128};
int seed = 2023;
std::string dtype = "float64";
auto out = builder.UniformRandom(shape, 0., 1., seed, dtype);
auto program = builder.Build();

for (int i = 0; i < program.size(); ++i) {
LOG(INFO) << "instruction: " << program[i];
}

Target target = common::DefaultNVGPUTarget();
std::unordered_set<std::string> fetch_ids;
auto graph = Optimize(&program, fetch_ids, target);

LOG(INFO) << "graph: \n" << graph->Visualize();

auto scope = BuildScope(target, graph);

hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build();

auto out_ten = scope->GetTensor(std::string(out->id));
runtime_program->Execute();

EXPECT_EQ(out_ten->type(), Float(64));

std::vector<double> data = GetTensorData<double>(out_ten, target);

int cnt = 0;
for (int i = 0; i < 128 * 12 * 128 * 128; ++i) {
if (data[i] > 0.5) cnt++;
}

float ratio = (float)cnt / (128 * 12 * 128 * 128);
LOG(INFO) << "count: " << cnt;
LOG(INFO) << "x > 0.5f ratio: " << ratio;
EXPECT_LE(ratio, 0.501f);
EXPECT_GE(ratio, 0.499f);
}
#endif

} // namespace frontend

} // namespace cinn
1 change: 0 additions & 1 deletion cinn/hlir/op/external_api_registry.cc
@@ -55,7 +55,6 @@ CINN_REGISTER_HELPER(op_external_api) {
CINN_OP_REGISTER_EXTERNAL_API(cublas_gemm, default_nvgpu).set_api_name("cinn_call_cublas");
CINN_OP_REGISTER_EXTERNAL_API(cublas_matmul, default_nvgpu).set_api_name("cinn_call_cublas");
CINN_OP_REGISTER_EXTERNAL_API(gaussian_random, default_nvgpu).set_api_name("cinn_call_gaussian_random");
- CINN_OP_REGISTER_EXTERNAL_API(uniform_random, default_nvgpu).set_api_name("cinn_call_uniform_random");
CINN_OP_REGISTER_EXTERNAL_API(randint, default_nvgpu).set_api_name("cinn_call_randint");
CINN_OP_REGISTER_EXTERNAL_API(cholesky, default_nvgpu).set_api_name("cinn_call_cholesky_nvgpu");
CINN_OP_REGISTER_EXTERNAL_API(cholesky, default_host).set_api_name("cinn_call_cholesky_host");
18 changes: 18 additions & 0 deletions cinn/runtime/cuda/cinn_cuda_runtime_source.cuh
@@ -1,6 +1,10 @@
/**
* \file This file contains all the intrinsics available to be used in CUDA code generated by CodeGen.
*/

#include <cuda_runtime.h>
#include <curand_kernel.h>

extern "C" {
// *************************************************************** //
// float32 unary and binary operator
@@ -342,6 +346,20 @@ __device__ inline bool cinn_any(const bool left, const bool right) { return left
shfl_res = __shfl_down_sync(mask, tmp_val, offset, 32); \
tmp_val = op((threadIdx.x & 0x1f) + offset < lane ? shfl_res : init, tmp_val);

__device__ inline float cinn_nvgpu_uniform_random_fp32(int seed){
curandStatePhilox4_32_10_t state;
int idx = threadIdx.x + blockIdx.x * blockDim.x;
curand_init(seed, idx, 1, &state);
return curand_uniform(&state);
}

__device__ inline double cinn_nvgpu_uniform_random_fp64(int seed){
curandStatePhilox4_32_10_t state;
int idx = threadIdx.x + blockIdx.x * blockDim.x;
curand_init(seed, idx, 1, &state);
return curand_uniform_double(&state);
}

#define CINN_WARP_SHUFFLE_INTERNAL_IMPL(REDUCE_TYPE, INITIAL_VALUE, DTYPE) \
__device__ inline DTYPE cinn_warp_shuffle_##REDUCE_TYPE##_internal(const DTYPE value) { \
DTYPE tmp_val = value, shfl_res; \
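
To exercise the per-thread Philox pattern above in isolation, here is a small self-contained CUDA program (an illustrative sketch, not CINN's generated code) that wraps the same curand_init(seed, idx, 1, &state) + curand_uniform calls in a kernel and checks that the sample mean lands near 0.5:

#include <cstdio>
#include <cuda_runtime.h>
#include <curand_kernel.h>

// Same pattern as the runtime intrinsics: each thread seeds its own Philox
// subsequence with its global thread index and draws one uniform sample.
__device__ inline float uniform_random_fp32(int seed) {
  curandStatePhilox4_32_10_t state;
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  curand_init(seed, idx, 1, &state);
  return curand_uniform(&state);
}

__global__ void fill_uniform(float* out, int seed, int n) {
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < n) out[idx] = uniform_random_fp32(seed);
}

int main() {
  const int n = 1 << 20;
  float* d_out = nullptr;
  cudaMalloc(&d_out, n * sizeof(float));
  fill_uniform<<<(n + 255) / 256, 256>>>(d_out, /*seed=*/2023, n);
  cudaDeviceSynchronize();

  float* h_out = new float[n];
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  double sum = 0.0;
  for (int i = 0; i < n; ++i) sum += h_out[i];
  std::printf("mean = %f (expected ~0.5)\n", sum / n);

  delete[] h_out;
  cudaFree(d_out);
  return 0;
}
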
1 change: 1 addition & 0 deletions cinn/utils/data_util.cc
@@ -116,6 +116,7 @@ std::vector<T> GetTensorData(const hlir::framework::Tensor& tensor, const common
}

template std::vector<float> GetTensorData<float>(const hlir::framework::Tensor& tensor, const common::Target& target);
template std::vector<double> GetTensorData<double>(const hlir::framework::Tensor& tensor, const common::Target& target);
template std::vector<int> GetTensorData<int>(const hlir::framework::Tensor& tensor, const common::Target& target);

} // namespace cinn
