This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Add triangular solve op #1224

Merged
merged 11 commits on Mar 3, 2023
37 changes: 37 additions & 0 deletions cinn/frontend/net_builder.cc
@@ -19,6 +19,7 @@
#include <vector>

#include "cinn/frontend/syntax.h"
#include "cinn/hlir/pe/broadcast.h"

namespace cinn {
namespace frontend {
@@ -660,6 +661,42 @@ Variable NetBuilder::Cholesky(const Variable& x, bool upper) {
return CustomInstr("cholesky", {x}, {{"upper", upper}}).front();
}

Variable NetBuilder::TriangularSolve(
const Variable& input1, const Variable& input2, bool left_side, bool upper, bool transpose_a, bool unit_diagonal) {
// broadcast
std::vector<Variable> inputs{input1, input2};
{
auto a_ndim = input1->shape.size();
auto b_ndim = input2->shape.size();
CHECK_GE(a_ndim, 2) << "The input matrix A shape size should >= 2! Please check again.";
CHECK_GE(b_ndim, 2) << "The input matrix B shape size should >= 2! Please check again.";
std::vector<int> input1_shape_cut(input1->shape.begin(), input1->shape.end() - 2);
std::vector<int> input2_shape_cut(input2->shape.begin(), input2->shape.end() - 2);
std::vector<int> common_shape;
hlir::pe::GetBroadcastOutShape(input1_shape_cut, input2_shape_cut, &common_shape);

// broadcast input1
std::vector<int> input1_shape(common_shape.begin(), common_shape.end());
input1_shape.push_back(input1->shape[a_ndim - 2]);
input1_shape.push_back(input1->shape[a_ndim - 1]);
inputs[0] = BroadcastTo(input1, input1_shape);

// broadcast input2
std::vector<int> input2_shape(common_shape.begin(), common_shape.end());
input2_shape.push_back(input2->shape[b_ndim - 2]);
input2_shape.push_back(input2->shape[b_ndim - 1]);
inputs[1] = BroadcastTo(input2, input2_shape);
}

return CustomInstr("triangular_solve",
inputs,
{{"left_side", left_side},
{"upper", upper},
{"transpose_a", transpose_a},
{"unit_diagonal", unit_diagonal}})
.front();
}

Variable NetBuilder::Norm(const Variable& x, int axis, float epsilon) {
Instruction instr("norm", {x});
instr.SetAttr<int32_t>("axis", axis);
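A note on the hunk above: NetBuilder::TriangularSolve broadcasts only the batch dimensions (everything before the trailing two matrix dimensions) of the two inputs to a common shape before emitting the custom instruction; the matrix dimensions themselves are appended back unchanged. A minimal self-contained sketch of that rule, assuming the standard NumPy-style broadcast semantics that hlir::pe::GetBroadcastOutShape implements:

#include <algorithm>
#include <cassert>
#include <vector>

// Broadcast two batch-dimension shapes (the input shapes with the last two
// matrix dimensions already stripped off) to a common shape.
std::vector<int> BroadcastBatchDims(const std::vector<int>& a, const std::vector<int>& b) {
  std::vector<int> out(std::max(a.size(), b.size()), 1);
  for (size_t i = 0; i < out.size(); ++i) {
    // Right-align both shapes; a missing leading dimension behaves like 1.
    int da = i < out.size() - a.size() ? 1 : a[i - (out.size() - a.size())];
    int db = i < out.size() - b.size() ? 1 : b[i - (out.size() - b.size())];
    assert(da == db || da == 1 || db == 1);  // standard broadcast compatibility
    out[i] = std::max(da, db);
  }
  return out;
}

int main() {
  // A: [3, 1, 4, 4] has batch dims [3, 1]; B: [2, 4, 5] has batch dims [2].
  // The common batch shape is [3, 2], so A is broadcast to [3, 2, 4, 4]
  // and B to [3, 2, 4, 5] before the solve.
  assert(BroadcastBatchDims({3, 1}, {2}) == (std::vector<int>{3, 2}));
  return 0;
}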
15 changes: 15 additions & 0 deletions cinn/frontend/net_builder.h
100755 → 100644
@@ -999,6 +999,21 @@ class NetBuilder {
*/
Variable Cholesky(const Variable& x, bool upper = false);

/**
 * @brief Solve triangular linear systems with multiple right-hand sides.
 * @param input1 triangular matrix stored in lower or upper mode.
 * @param input2 matrix on the right-hand side.
 * @param left_side When left_side is true, compute A*X = B.
 When left_side is false, compute X*A = B.
 * @param upper When upper is true, use the upper part of the triangular matrix.
 When upper is false, use the lower part of the triangular matrix.
 * @param transpose_a When transpose_a is true, use the transpose of matrix A.
 * @param unit_diagonal When unit_diagonal is true, assume the elements on the main diagonal of matrix A are unity.
* @return The solution for the triangular linear systems.
*/
Variable TriangularSolve(
const Variable& input1, const Variable& input2, bool left_side, bool upper, bool transpose_a, bool unit_diagonal);

/**
* @brief l2-Norm
* @param x The input operand to be normed.
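To make the declaration above concrete: with the defaults left_side = true, upper = false, transpose_a = false, and unit_diagonal = false, the op solves A * X = B where A is lower triangular, which reduces to forward substitution. A minimal sketch of those semantics (an illustration, not CINN's implementation):

#include <cassert>
#include <vector>

// Solve A * X = B in place, where A is n x n lower triangular and B is n x k;
// on return B holds X. With unit_diagonal = true the divide would be skipped.
void SolveLowerTriangular(const std::vector<std::vector<double>>& A,
                          std::vector<std::vector<double>>& B) {
  const size_t n = A.size();
  const size_t k = B[0].size();
  for (size_t col = 0; col < k; ++col) {
    for (size_t i = 0; i < n; ++i) {
      double s = B[i][col];
      for (size_t j = 0; j < i; ++j) s -= A[i][j] * B[j][col];
      B[i][col] = s / A[i][i];
    }
  }
}

int main() {
  std::vector<std::vector<double>> A = {{2, 0}, {3, 4}};
  std::vector<std::vector<double>> B = {{2}, {11}};
  SolveLowerTriangular(A, B);
  // X = [[1], [2]]: 2*1 = 2 and 3*1 + 4*2 = 11.
  assert(B[0][0] == 1.0 && B[1][0] == 2.0);
  return 0;
}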
1 change: 1 addition & 0 deletions cinn/hlir/op/contrib/CMakeLists.txt
@@ -14,6 +14,7 @@ gather_srcs(cinnapi_src SRCS
gaussian_random.cc
uniform_random.cc
cholesky.cc
triangular_solve.cc
)

cc_test(test_gather_nd SRCS gather_nd_test.cc DEPS cinncore)
121 changes: 121 additions & 0 deletions cinn/hlir/op/contrib/triangular_solve.cc
@@ -0,0 +1,121 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <vector>

#include "cinn/common/common.h"
#include "cinn/common/macros.h"
#include "cinn/hlir/framework/node.h"
#include "cinn/hlir/framework/op.h"
#include "cinn/hlir/framework/op_strategy.h"
#include "cinn/hlir/op/op_util.h"
#include "cinn/hlir/pe/elementwise.h"
#include "cinn/hlir/pe/ir_schedule_pe.h"
#include "cinn/ir/ir.h"
#include "cinn/ir/ir_base.h"
#include "cinn/ir/tensor.h"
#include "cinn/lang/builtin.h"
#include "cinn/lang/compute.h"

namespace cinn {
namespace hlir {
namespace op {

using common::CINNValue;
using common::CINNValuePack;

std::shared_ptr<framework::OpStrategy> StrategyForTriangularSolve(const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<Type> &out_type,
const std::vector<std::vector<int>> &output_shapes,
const Target &target) {
framework::CINNCompute triangular_solve_compute([=](lang::Args args, lang::RetValue *ret) {
CHECK(!args.empty()) << "The input argument of triangular_solve is empty! Please check.";
CINNValuePack pack_args = args[0];
CHECK_GE(pack_args.size(), 2U) << "Two input tensors are required for the computation of triangular_solve.";
Expr a_expr = pack_args[0];
Expr b_expr = pack_args[1];
ir::Tensor a = a_expr.as_tensor_ref();
ir::Tensor b = b_expr.as_tensor_ref();
std::string tensor_name = "triangular_solve_out";
auto out = pe::Identity(b, tensor_name).front();
auto stages = CreateStages({out});
std::vector<CINNValue> res{CINNValue(out), CINNValue(stages)};
*ret = CINNValuePack{res};
});
auto strategy = std::make_shared<framework::OpStrategy>();
strategy->AddImpl(
triangular_solve_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.triangular_solve.x86", 1);
return strategy;
}

std::vector<framework::shape_t> InferShapeForTriangularSolve(const std::vector<framework::shape_t> &inputs_shape,
const framework::AttrMapType &attrs) {
CHECK_EQ(inputs_shape.size(), 2U) << "The input's shape size should be 2! Please check again.";
framework::shape_t a_shape = inputs_shape[0];
framework::shape_t b_shape = inputs_shape[1];
int a_shape_size = a_shape.size();
int b_shape_size = b_shape.size();
CHECK_GE(a_shape_size, 2U) << "The input matrix A shape size should >= 2! Please check again.";
CHECK_GE(b_shape_size, 2U) << "The input matrix B shape size should >= 2! Please check again.";

int left_side = -1;
for (auto &iter : attrs) {
if (iter.first == "left_side") {
left_side = absl::get<bool>(iter.second);
break;
}
}

CHECK_EQ(a_shape[a_shape_size - 2], a_shape[a_shape_size - 1])
<< "The last two dimensions of the input a must be the same!";
if (left_side) {
CHECK_EQ(a_shape[a_shape_size - 2], b_shape[b_shape_size - 2])
<< "The last-but-one dimension of the two vectors must be consistent.";
} else {
CHECK_EQ(a_shape[a_shape_size - 1], b_shape[b_shape_size - 1])
<< "The last dimension of the two vectors must be consistent.";
}

return {b_shape};
}

std::vector<Type> InferDtypeForTriangularSolve(const std::vector<Type> &inputs_type,
const framework::AttrMapType &attrs) {
CHECK_EQ(inputs_type.size(), 2U) << "The input's type size should be 2! Please check again.";
CHECK(inputs_type[0].is_float(32) || inputs_type[0].is_float(64))
<< "The input's dtype should be float32 or float64! Please check again.";
CHECK(inputs_type[1].is_float(32) || inputs_type[1].is_float(64))
<< "The input's dtype should be float32 or float64! Please check again.";
return std::vector<Type>{inputs_type[1]};
}

} // namespace op
} // namespace hlir
} // namespace cinn

CINN_REGISTER_HELPER(triangular_solve_ops) {
CINN_REGISTER_OP(triangular_solve)
.describe("TriangularSolve")
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForTriangularSolve)
.set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForTriangularSolve))
.set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForTriangularSolve))
.set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible)
.set_support_level(4);

return true;
}
40 changes: 40 additions & 0 deletions cinn/hlir/op/custom_call.cc
@@ -724,6 +724,44 @@ std::vector<ir::Expr> CustomCallArgsForCholesky(const framework::NodeAttr &attrs
return args;
}

std::vector<ir::Expr> CustomCallArgsForTriangularSolve(const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<std::vector<int>> &output_shapes) {
CHECK_EQ(inputs.size(), 2UL);
auto attr_store = attrs.attr_store;
CHECK(attr_store.count("left_side"));
CHECK(attr_store.count("upper"));
CHECK(attr_store.count("transpose_a"));
CHECK(attr_store.count("unit_diagonal"));

ir::Tensor a = inputs[0];
ir::Tensor b = inputs[1];
int a_ndim = static_cast<int>(a->shape.size());
int b_ndim = static_cast<int>(b->shape.size());
int batch_size = 1;
for (int i = 0; i < a_ndim - 2; i++) {
batch_size *= a->shape[i].as_int32();
}

auto left_side = absl::get<bool>(attrs.attr_store.at("left_side"));
auto upper = absl::get<bool>(attrs.attr_store.at("upper"));
auto transpose_a = absl::get<bool>(attrs.attr_store.at("transpose_a"));
auto unit_diagonal = absl::get<bool>(attrs.attr_store.at("unit_diagonal"));

int m = a->shape[a_ndim - 1].as_int32();
int k = left_side ? b->shape[b_ndim - 1].as_int32() : b->shape[b_ndim - 2].as_int32();

std::vector<ir::Expr> args = {ir::Expr(batch_size),
ir::Expr(m),
ir::Expr(k),
ir::Expr(left_side),
ir::Expr(upper),
ir::Expr(transpose_a),
ir::Expr(unit_diagonal)};

return args;
}

bool RegisteryCustomCallArgsFunc() {
#ifdef CINN_WITH_CUDA
CustomCallArgsFuncRegistry::Global().Register(
@@ -736,6 +774,8 @@ bool RegisteryCustomCallArgsFunc() {
"cinn_call_cholesky_nvgpu", common::DefaultNVGPUTarget(), CustomCallArgsForCholesky);
CustomCallArgsFuncRegistry::Global().Register(
"cinn_call_batched_cublas", common::DefaultNVGPUTarget(), CustomCallArgsForBatchedCublas);
CustomCallArgsFuncRegistry::Global().Register(
"cinn_call_triangular_solve_nvgpu", common::DefaultNVGPUTarget(), CustomCallArgsForTriangularSolve);
#endif

#ifdef CINN_WITH_CUDNN
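One subtlety in CustomCallArgsForTriangularSolve above is the choice of k: m is always the order of the square triangular matrix A, while k counts the right-hand sides in B, so it comes from a different axis of B depending on which side A sits on. A small sketch of that convention (illustrative only; the helper name is hypothetical):

#include <cassert>
#include <vector>

struct SolveDims { int m; int k; };

SolveDims GetSolveDims(const std::vector<int>& a_shape,
                       const std::vector<int>& b_shape, bool left_side) {
  int m = a_shape.back();  // A is [..., m, m]
  // left_side:  A(m x m) * X(m x k) = B(m x k) -> k is B's last dimension.
  // otherwise:  X(k x m) * A(m x m) = B(k x m) -> k is B's second-to-last.
  int k = left_side ? b_shape[b_shape.size() - 1] : b_shape[b_shape.size() - 2];
  return {m, k};
}

int main() {
  SolveDims d = GetSolveDims({4, 4}, {4, 7}, /*left_side=*/true);
  assert(d.m == 4 && d.k == 7);
  d = GetSolveDims({4, 4}, {7, 4}, /*left_side=*/false);
  assert(d.m == 4 && d.k == 7);
  return 0;
}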
1 change: 1 addition & 0 deletions cinn/hlir/op/external_api_registry.cc
@@ -58,6 +58,7 @@ CINN_REGISTER_HELPER(op_external_api) {
CINN_OP_REGISTER_EXTERNAL_API(uniform_random, default_nvgpu).set_api_name("cinn_call_uniform_random");
CINN_OP_REGISTER_EXTERNAL_API(cholesky, default_nvgpu).set_api_name("cinn_call_cholesky_nvgpu");
CINN_OP_REGISTER_EXTERNAL_API(cholesky, default_host).set_api_name("cinn_call_cholesky_host");
CINN_OP_REGISTER_EXTERNAL_API(triangular_solve, default_nvgpu).set_api_name("cinn_call_triangular_solve_nvgpu");
#ifdef CINN_WITH_CUDNN
CINN_OP_REGISTER_EXTERNAL_API(conv2d, default_nvgpu).set_trans_func([](const ::cinn::hlir::framework::Node* node) {
CHECK(node->attrs.attr_store.count("conv_type"));
1 change: 1 addition & 0 deletions cinn/hlir/op/use_ops.h
@@ -37,4 +37,5 @@ CINN_USE_REGISTER(reciprocal_ops)
CINN_USE_REGISTER(gaussian_random_ops)
CINN_USE_REGISTER(uniform_random_ops)
CINN_USE_REGISTER(cholesky_ops)
CINN_USE_REGISTER(triangular_solve_ops)
CINN_USE_REGISTER(op_external_api)
10 changes: 9 additions & 1 deletion cinn/pybind/frontend.cc
@@ -670,7 +670,15 @@ void BindFrontend(pybind11::module *m) {
py::arg("seed") = 0,
py::arg("dtype") = "float32")
.def("norm", &NetBuilder::Norm, py::arg("x"), py::arg("axis") = -1, py::arg("epsilon") = 1e-12f)
.def("cholesky", &NetBuilder::Cholesky, py::arg("x"), py::arg("upper") = false);
.def("cholesky", &NetBuilder::Cholesky, py::arg("x"), py::arg("upper") = false)
.def("triangular_solve",
&NetBuilder::TriangularSolve,
py::arg("input1"),
py::arg("input2"),
py::arg("left_side") = true,
py::arg("upper") = false,
py::arg("transpose_a") = false,
py::arg("unit_diagonal") = false);

auto computation = py::class_<CinnComputation, std::shared_ptr<CinnComputation>>(*m, "Computation");
py::class_<CinnComputation::CompileOptions>(computation, "CompileOptions")
15 changes: 15 additions & 0 deletions cinn/runtime/cuda/cuda_intrinsics.cc
@@ -369,6 +369,21 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) {
.AddInputType<void *>() // stream
.End();

using cinn::runtime::cuda::cinn_call_triangular_solve_nvgpu;
REGISTER_EXTERN_FUNC_HELPER(cinn_call_triangular_solve_nvgpu, cinn::common::DefaultNVGPUTarget())
.SetRetType<void>()
.AddInputType<void *>() // v_args
.AddInputType<int>() // num_args
.AddInputType<int>() // batch_size
.AddInputType<int>() // m
.AddInputType<int>() // k
.AddInputType<bool>() // left_side
.AddInputType<bool>() // upper
.AddInputType<bool>() // transpose_a
.AddInputType<bool>() // unit_diagonal
.AddInputType<void *>() // stream
.End();

#ifdef CINN_WITH_CUDNN
using cinn::runtime::cuda::cinn_call_cudnn_conv2d_forward;
REGISTER_EXTERN_FUNC_HELPER(cinn_call_cudnn_conv2d_forward, cinn::common::DefaultHostTarget())
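The AddInputType sequence in the registration above must match, one-to-one, the argument vector built by CustomCallArgsForTriangularSolve, with v_args and num_args prepended by the custom-call lowering. A mock of the implied host-side signature (an assumption for illustration; the real declaration lives in the CUDA runtime sources):

// Illustrative mock; not the actual CINN declaration.
void cinn_call_triangular_solve_nvgpu(void *v_args,       // packed buffer args (inputs and output)
                                      int num_args,       // number of packed buffers
                                      int batch_size,     // product of the broadcast batch dims
                                      int m,              // order of the square triangular matrix A
                                      int k,              // number of right-hand sides in B
                                      bool left_side,     // true: A*X = B; false: X*A = B
                                      bool upper,         // use the upper or lower triangle of A
                                      bool transpose_a,   // solve with A transposed
                                      bool unit_diagonal, // treat A's diagonal as ones
                                      void *stream);      // CUDA stream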