diff --git a/cinn/frontend/net_builder.cc b/cinn/frontend/net_builder.cc
index c00aed5eba..c2357f3af6 100644
--- a/cinn/frontend/net_builder.cc
+++ b/cinn/frontend/net_builder.cc
@@ -19,6 +19,7 @@
 #include
 
 #include "cinn/frontend/syntax.h"
+#include "cinn/hlir/pe/broadcast.h"
 
 namespace cinn {
 namespace frontend {
@@ -660,6 +661,42 @@ Variable NetBuilder::Cholesky(const Variable& x, bool upper) {
   return CustomInstr("cholesky", {x}, {{"upper", upper}}).front();
 }
 
+Variable NetBuilder::TriangularSolve(
+    const Variable& input1, const Variable& input2, bool left_side, bool upper, bool transpose_a, bool unit_diagonal) {
+  // broadcast
+  std::vector<Variable> inputs{input1, input2};
+  {
+    auto a_ndim = input1->shape.size();
+    auto b_ndim = input2->shape.size();
+    CHECK_GE(a_ndim, 2) << "The input matrix A shape size should >= 2! Please check again.";
+    CHECK_GE(b_ndim, 2) << "The input matrix B shape size should >= 2! Please check again.";
+    std::vector<int> input1_shape_cut(input1->shape.begin(), input1->shape.end() - 2);
+    std::vector<int> input2_shape_cut(input2->shape.begin(), input2->shape.end() - 2);
+    std::vector<int> common_shape;
+    hlir::pe::GetBroadcastOutShape(input1_shape_cut, input2_shape_cut, &common_shape);
+
+    // broadcast input1
+    std::vector<int> input1_shape(common_shape.begin(), common_shape.end());
+    input1_shape.push_back(input1->shape[a_ndim - 2]);
+    input1_shape.push_back(input1->shape[a_ndim - 1]);
+    inputs[0] = BroadcastTo(input1, input1_shape);
+
+    // broadcast input2
+    std::vector<int> input2_shape(common_shape.begin(), common_shape.end());
+    input2_shape.push_back(input2->shape[b_ndim - 2]);
+    input2_shape.push_back(input2->shape[b_ndim - 1]);
+    inputs[1] = BroadcastTo(input2, input2_shape);
+  }
+
+  return CustomInstr("triangular_solve",
+                     inputs,
+                     {{"left_side", left_side},
+                      {"upper", upper},
+                      {"transpose_a", transpose_a},
+                      {"unit_diagonal", unit_diagonal}})
+      .front();
+}
+
 Variable NetBuilder::Norm(const Variable& x, int axis, float epsilon) {
   Instruction instr("norm", {x});
   instr.SetAttr("axis", axis);
diff --git a/cinn/frontend/net_builder.h b/cinn/frontend/net_builder.h
old mode 100755
new mode 100644
index 498aed60dc..67226336b0
--- a/cinn/frontend/net_builder.h
+++ b/cinn/frontend/net_builder.h
@@ -999,6 +999,21 @@ class NetBuilder {
    */
   Variable Cholesky(const Variable& x, bool upper = false);
 
+  /**
+   * @brief Solve triangular linear systems with multiple right-hand sides.
+   * @param input1 The triangular matrix, stored in lower or upper mode.
+   * @param input2 The matrix on the right-hand side.
+   * @param left_side When left_side is true, compute A*X = B.
+   *                  When left_side is false, compute X*A = B.
+   * @param upper When upper is true, use the upper part of the triangular matrix.
+   *              When upper is false, use the lower part of the triangular matrix.
+   * @param transpose_a When transpose_a is true, use the transpose of matrix A.
+   * @param unit_diagonal When unit_diagonal is true, assume the elements on the main diagonal of matrix A are unity.
+   * @return The solution of the triangular linear systems.
+   */
+  Variable TriangularSolve(
+      const Variable& input1, const Variable& input2, bool left_side, bool upper, bool transpose_a, bool unit_diagonal);
+
   /**
    * @brief l2-Norm
    * @param x The input operand to be normed.
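Note on the frontend lowering above: NetBuilder::TriangularSolve broadcasts only the batch dimensions (all but the trailing two) of A and B to a common shape before emitting the custom instruction, so the runtime always sees matching batch counts; the trailing matrix dimensions are left untouched. A minimal NumPy sketch of that shape rule (the helper name is illustrative, not CINN API):

```python
import numpy as np

def triangular_solve_shapes(a_shape, b_shape):
    """Mirror of the batch-dim broadcasting done in NetBuilder::TriangularSolve."""
    assert len(a_shape) >= 2 and len(b_shape) >= 2
    batch = np.broadcast_shapes(tuple(a_shape[:-2]), tuple(b_shape[:-2]))
    a_bcast = batch + tuple(a_shape[-2:])  # A keeps its trailing (m, m)
    b_bcast = batch + tuple(b_shape[-2:])  # B keeps its trailing matrix dims
    return a_bcast, b_bcast                # the op's output has b_bcast's shape

# e.g. A (2, 1, 3, 3) with B (4, 3, 5) -> batch (2, 4), A (2, 4, 3, 3), B/X (2, 4, 3, 5)
print(triangular_solve_shapes((2, 1, 3, 3), (4, 3, 5)))
```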
diff --git a/cinn/hlir/op/contrib/CMakeLists.txt b/cinn/hlir/op/contrib/CMakeLists.txt
index e166342171..62af3e2830 100644
--- a/cinn/hlir/op/contrib/CMakeLists.txt
+++ b/cinn/hlir/op/contrib/CMakeLists.txt
@@ -14,6 +14,7 @@ gather_srcs(cinnapi_src SRCS
   gaussian_random.cc
   uniform_random.cc
   cholesky.cc
+  triangular_solve.cc
   )
 
 cc_test(test_gather_nd SRCS gather_nd_test.cc DEPS cinncore)
diff --git a/cinn/hlir/op/contrib/triangular_solve.cc b/cinn/hlir/op/contrib/triangular_solve.cc
new file mode 100644
index 0000000000..b7b1cfd794
--- /dev/null
+++ b/cinn/hlir/op/contrib/triangular_solve.cc
@@ -0,0 +1,121 @@
+// Copyright (c) 2022 CINN Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+
+#include "cinn/common/common.h"
+#include "cinn/common/macros.h"
+#include "cinn/hlir/framework/node.h"
+#include "cinn/hlir/framework/op.h"
+#include "cinn/hlir/framework/op_strategy.h"
+#include "cinn/hlir/op/op_util.h"
+#include "cinn/hlir/pe/elementwise.h"
+#include "cinn/hlir/pe/ir_schedule_pe.h"
+#include "cinn/ir/ir.h"
+#include "cinn/ir/ir_base.h"
+#include "cinn/ir/tensor.h"
+#include "cinn/lang/builtin.h"
+#include "cinn/lang/compute.h"
+
+namespace cinn {
+namespace hlir {
+namespace op {
+
+using common::CINNValue;
+using common::CINNValuePack;
+
+std::shared_ptr<framework::OpStrategy> StrategyForTriangularSolve(const framework::NodeAttr &attrs,
+                                                                  const std::vector<ir::Tensor> &inputs,
+                                                                  const std::vector<Type> &out_type,
+                                                                  const std::vector<std::vector<int>> &output_shapes,
+                                                                  const Target &target) {
+  framework::CINNCompute triangular_solve_compute([=](lang::Args args, lang::RetValue *ret) {
+    CHECK(!args.empty()) << "The input argument of triangular_solve is empty! Please check.";
+    CINNValuePack pack_args = args[0];
+    CHECK_GE(pack_args.size(), 2U) << "Two input tensors are required for the computation of triangular_solve.";
+    Expr a_expr = pack_args[0];
+    Expr b_expr = pack_args[1];
+    ir::Tensor a = a_expr.as_tensor_ref();
+    ir::Tensor b = b_expr.as_tensor_ref();
+    std::string tensor_name = "triangular_solve_out";
+    auto out = pe::Identity(b, tensor_name).front();
+    auto stages = CreateStages({out});
+    std::vector<CINNValue> res{CINNValue(out), CINNValue(stages)};
+    *ret = CINNValuePack{res};
+  });
+  auto strategy = std::make_shared<framework::OpStrategy>();
+  strategy->AddImpl(
+      triangular_solve_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.triangular_solve.x86", 1);
+  return strategy;
+}
+
+std::vector<framework::shape_t> InferShapeForTriangularSolve(const std::vector<framework::shape_t> &inputs_shape,
+                                                             const framework::AttrMapType &attrs) {
+  CHECK_EQ(inputs_shape.size(), 2U) << "The input's shape size should be 2! Please check again.";
+  framework::shape_t a_shape = inputs_shape[0];
+  framework::shape_t b_shape = inputs_shape[1];
+  int a_shape_size = a_shape.size();
+  int b_shape_size = b_shape.size();
+  CHECK_GE(a_shape_size, 2U) << "The input matrix A shape size should >= 2! Please check again.";
+  CHECK_GE(b_shape_size, 2U) << "The input matrix B shape size should >= 2! Please check again.";
+
+  int left_side = -1;
+  for (auto &iter : attrs) {
+    if (iter.first == "left_side") {
+      left_side = absl::get<bool>(iter.second);
+      break;
+    }
+  }
+
+  CHECK_EQ(a_shape[a_shape_size - 2], a_shape[a_shape_size - 1])
+      << "The last two dimensions of the input a must be the same!";
+  if (left_side) {
+    CHECK_EQ(a_shape[a_shape_size - 2], b_shape[b_shape_size - 2])
+        << "The last-but-one dimension of the two inputs must be consistent.";
+  } else {
+    CHECK_EQ(a_shape[a_shape_size - 1], b_shape[b_shape_size - 1])
+        << "The last dimension of the two inputs must be consistent.";
+  }
+
+  return {b_shape};
+}
+
+std::vector<Type> InferDtypeForTriangularSolve(const std::vector<Type> &inputs_type,
+                                               const framework::AttrMapType &attrs) {
+  CHECK_EQ(inputs_type.size(), 2U) << "The input's type size should be 2! Please check again.";
+  CHECK(inputs_type[0].is_float(32) || inputs_type[0].is_float(64))
+      << "The input's dtype should be float32 or float64! Please check again.";
+  CHECK(inputs_type[1].is_float(32) || inputs_type[1].is_float(64))
+      << "The input's dtype should be float32 or float64! Please check again.";
+  return std::vector<Type>{inputs_type[1]};
+}
+
+}  // namespace op
+}  // namespace hlir
+}  // namespace cinn
+
+CINN_REGISTER_HELPER(triangular_solve_ops) {
+  CINN_REGISTER_OP(triangular_solve)
+      .describe("TriangularSolve")
+      .set_num_inputs(2)
+      .set_num_outputs(1)
+      .set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForTriangularSolve)
+      .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForTriangularSolve))
+      .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForTriangularSolve))
+      .set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible)
+      .set_support_level(4);
+
+  return true;
+}
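Two things worth calling out in triangular_solve.cc: the CINNCompute body only forwards B through pe::Identity, because the actual solve is dispatched to the external cuBLAS call registered in custom_call.cc and external_api_registry.cc below, and InferShapeForTriangularSolve always returns B's shape after checking that A is square and that the contraction dimension lines up. A small sketch of that shape rule, assuming the same left_side semantics (hypothetical helper, not CINN code):

```python
def infer_triangular_solve_shape(a_shape, b_shape, left_side=True):
    """Sketch of InferShapeForTriangularSolve: the output always has B's shape."""
    assert a_shape[-1] == a_shape[-2], "A must be square"
    if left_side:                        # A @ X = B: rows of A match rows of B
        assert a_shape[-2] == b_shape[-2]
    else:                                # X @ A = B: cols of A match cols of B
        assert a_shape[-1] == b_shape[-1]
    return list(b_shape)

print(infer_triangular_solve_shape([1, 3, 3], [1, 3, 4]))  # -> [1, 3, 4]
```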
diff --git a/cinn/hlir/op/custom_call.cc b/cinn/hlir/op/custom_call.cc
index ceaa7c4b27..ada899f815 100644
--- a/cinn/hlir/op/custom_call.cc
+++ b/cinn/hlir/op/custom_call.cc
@@ -724,6 +724,44 @@ std::vector<ir::Expr> CustomCallArgsForCholesky(const framework::NodeAttr &attrs,
   return args;
 }
 
+std::vector<ir::Expr> CustomCallArgsForTriangularSolve(const framework::NodeAttr &attrs,
+                                                       const std::vector<ir::Tensor> &inputs,
+                                                       const std::vector<std::vector<int>> &output_shapes) {
+  CHECK_EQ(inputs.size(), 2UL);
+  auto attr_store = attrs.attr_store;
+  CHECK(attr_store.count("left_side"));
+  CHECK(attr_store.count("upper"));
+  CHECK(attr_store.count("transpose_a"));
+  CHECK(attr_store.count("unit_diagonal"));
+
+  ir::Tensor a = inputs[0];
+  ir::Tensor b = inputs[1];
+  int a_ndim = static_cast<int>(a->shape.size());
+  int b_ndim = static_cast<int>(b->shape.size());
+  int batch_size = 1;
+  for (int i = 0; i < a_ndim - 2; i++) {
+    batch_size *= a->shape[i].as_int32();
+  }
+
+  auto left_side = absl::get<bool>(attrs.attr_store.at("left_side"));
+  auto upper = absl::get<bool>(attrs.attr_store.at("upper"));
+  auto transpose_a = absl::get<bool>(attrs.attr_store.at("transpose_a"));
+  auto unit_diagonal = absl::get<bool>(attrs.attr_store.at("unit_diagonal"));
+
+  int m = a->shape[a_ndim - 1].as_int32();
+  int k = left_side ? b->shape[b_ndim - 1].as_int32() : b->shape[b_ndim - 2].as_int32();
+
+  std::vector<ir::Expr> args = {ir::Expr(batch_size),
+                                ir::Expr(m),
+                                ir::Expr(k),
+                                ir::Expr(left_side),
+                                ir::Expr(upper),
+                                ir::Expr(transpose_a),
+                                ir::Expr(unit_diagonal)};
+
+  return args;
+}
+
 bool RegisteryCustomCallArgsFunc() {
 #ifdef CINN_WITH_CUDA
   CustomCallArgsFuncRegistry::Global().Register(
@@ -736,6 +774,8 @@ bool RegisteryCustomCallArgsFunc() {
       "cinn_call_cholesky_nvgpu", common::DefaultNVGPUTarget(), CustomCallArgsForCholesky);
   CustomCallArgsFuncRegistry::Global().Register(
       "cinn_call_batched_cublas", common::DefaultNVGPUTarget(), CustomCallArgsForBatchedCublas);
+  CustomCallArgsFuncRegistry::Global().Register(
+      "cinn_call_triangular_solve_nvgpu", common::DefaultNVGPUTarget(), CustomCallArgsForTriangularSolve);
 #endif
 
 #ifdef CINN_WITH_CUDNN
diff --git a/cinn/hlir/op/external_api_registry.cc b/cinn/hlir/op/external_api_registry.cc
index 0011c8f243..2d2cf1702e 100644
--- a/cinn/hlir/op/external_api_registry.cc
+++ b/cinn/hlir/op/external_api_registry.cc
@@ -58,6 +58,7 @@ CINN_REGISTER_HELPER(op_external_api) {
   CINN_OP_REGISTER_EXTERNAL_API(uniform_random, default_nvgpu).set_api_name("cinn_call_uniform_random");
   CINN_OP_REGISTER_EXTERNAL_API(cholesky, default_nvgpu).set_api_name("cinn_call_cholesky_nvgpu");
   CINN_OP_REGISTER_EXTERNAL_API(cholesky, default_host).set_api_name("cinn_call_cholesky_host");
+  CINN_OP_REGISTER_EXTERNAL_API(triangular_solve, default_nvgpu).set_api_name("cinn_call_triangular_solve_nvgpu");
 #ifdef CINN_WITH_CUDNN
   CINN_OP_REGISTER_EXTERNAL_API(conv2d, default_nvgpu).set_trans_func([](const ::cinn::hlir::framework::Node* node) {
     CHECK(node->attrs.attr_store.count("conv_type"));
diff --git a/cinn/hlir/op/use_ops.h b/cinn/hlir/op/use_ops.h
index 4de19f8dba..f931c48312 100644
--- a/cinn/hlir/op/use_ops.h
+++ b/cinn/hlir/op/use_ops.h
@@ -37,4 +37,5 @@ CINN_USE_REGISTER(reciprocal_ops)
 CINN_USE_REGISTER(gaussian_random_ops)
 CINN_USE_REGISTER(uniform_random_ops)
 CINN_USE_REGISTER(cholesky_ops)
+CINN_USE_REGISTER(triangular_solve_ops)
 CINN_USE_REGISTER(op_external_api)
diff --git a/cinn/pybind/frontend.cc b/cinn/pybind/frontend.cc
index b65168a9a2..35d5ad9cd1 100644
--- a/cinn/pybind/frontend.cc
+++ b/cinn/pybind/frontend.cc
@@ -670,7 +670,15 @@ void BindFrontend(pybind11::module *m) {
            py::arg("seed") = 0,
            py::arg("dtype") = "float32")
       .def("norm", &NetBuilder::Norm, py::arg("x"), py::arg("axis") = -1, py::arg("epsilon") = 1e-12f)
-      .def("cholesky", &NetBuilder::Cholesky, py::arg("x"), py::arg("upper") = false);
+      .def("cholesky", &NetBuilder::Cholesky, py::arg("x"), py::arg("upper") = false)
+      .def("triangular_solve",
+           &NetBuilder::TriangularSolve,
+           py::arg("input1"),
+           py::arg("input2"),
+           py::arg("left_side") = true,
+           py::arg("upper") = false,
+           py::arg("transpose_a") = false,
+           py::arg("unit_diagonal") = false);
 
   auto computation = py::class_<CinnComputation, std::shared_ptr<CinnComputation>>(*m, "Computation");
   py::class_<CinnComputation::CompileOptions>(computation, "CompileOptions")
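CustomCallArgsForTriangularSolve above packs only scalars: batch_size is the product of A's batch dimensions (equal to B's after the frontend broadcast), m is the order of the triangular matrix, and k is the number of right-hand sides, read from B's last dimension for left-side solves and from its second-to-last dimension for right-side solves. A sketch of that derivation (hypothetical helper, for illustration only):

```python
def custom_call_args(a_shape, b_shape, left_side=True):
    """Scalar arguments handed to cinn_call_triangular_solve_nvgpu."""
    batch_size = 1
    for d in a_shape[:-2]:
        batch_size *= d
    m = a_shape[-1]                                # order of the triangular matrix A
    k = b_shape[-1] if left_side else b_shape[-2]  # number of right-hand sides
    return batch_size, m, k

print(custom_call_args([2, 2, 3, 3], [2, 2, 3, 5]))  # -> (4, 3, 5)
```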
diff --git a/cinn/runtime/cuda/cuda_intrinsics.cc b/cinn/runtime/cuda/cuda_intrinsics.cc
index 77082412fa..615610ecab 100644
--- a/cinn/runtime/cuda/cuda_intrinsics.cc
+++ b/cinn/runtime/cuda/cuda_intrinsics.cc
@@ -369,6 +369,21 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) {
       .AddInputType<void *>()  // stream
       .End();
 
+  using cinn::runtime::cuda::cinn_call_triangular_solve_nvgpu;
+  REGISTER_EXTERN_FUNC_HELPER(cinn_call_triangular_solve_nvgpu, cinn::common::DefaultNVGPUTarget())
+      .SetRetType<void>()
+      .AddInputType<void *>()  // v_args
+      .AddInputType<int>()     // num_args
+      .AddInputType<int>()     // batch_size
+      .AddInputType<int>()     // m
+      .AddInputType<int>()     // k
+      .AddInputType<bool>()    // left_side
+      .AddInputType<bool>()    // upper
+      .AddInputType<bool>()    // transpose_a
+      .AddInputType<bool>()    // unit_diagonal
+      .AddInputType<void *>()  // stream
+      .End();
+
 #ifdef CINN_WITH_CUDNN
   using cinn::runtime::cuda::cinn_call_cudnn_conv2d_forward;
   REGISTER_EXTERN_FUNC_HELPER(cinn_call_cudnn_conv2d_forward, cinn::common::DefaultHostTarget())
diff --git a/cinn/runtime/cuda/cuda_util.cc b/cinn/runtime/cuda/cuda_util.cc
index 95b9f34b7b..35ee0cc83d 100644
--- a/cinn/runtime/cuda/cuda_util.cc
+++ b/cinn/runtime/cuda/cuda_util.cc
@@ -23,6 +23,7 @@
 #include
 #include
+#include <thrust/device_vector.h>
 #ifdef CINN_WITH_CUDNN
 #include
 #endif
@@ -1214,6 +1215,101 @@ void cinn_call_cholesky_nvgpu(void *v_args, int num_args, int batch_size, int m,
   CUSOLVER_CALL(cusolverDnDestroy(handler));
 }
 
+void cinn_call_triangular_solve_nvgpu(void *v_args,
+                                      int num_args,
+                                      int batch_size,
+                                      int m,
+                                      int k,
+                                      bool left_side,
+                                      bool upper,
+                                      bool transpose_a,
+                                      bool unit_diagonal,
+                                      void *stream) {
+  cublasHandle_t &handle = CublasHandle::GetInstance().GetCublasHandle();
+  cudaStream_t custream = static_cast<cudaStream_t>(stream);
+  CUBLAS_CALL(cublasSetStream(handle, custream));
+
+  int b_rows = left_side ? k : m;
+  int b_cols = left_side ? m : k;
+  int lda = m;
+  int ldb = b_rows;
+  cublasSideMode_t side = left_side ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT;
+  cublasFillMode_t uplo = upper ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
+  cublasOperation_t transa = transpose_a ? CUBLAS_OP_T : CUBLAS_OP_N;
+  cublasDiagType_t diag = unit_diagonal ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
+
+  cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
+  cinn_buffer_t *input1 = args[0].operator cinn_buffer_t *();
+  cinn_buffer_t *input2 = args[1].operator cinn_buffer_t *();
+  cinn_buffer_t *output = args[2].operator cinn_buffer_t *();
+
+  CHECK_EQ(input1->type.code, cinn_type_code_t::cinn_type_float);
+  CHECK_EQ(input2->type.code, cinn_type_code_t::cinn_type_float);
+  CHECK_EQ(input1->type.bits, input2->type.bits);
+  uint8_t bits = input1->type.bits;
+  uint8_t bytes = bits / 8;
+  CHECK(bits == 32 || bits == 64) << "unsupported bits = " << bits << " float data type for triangular solve";
+
+  std::string debug_info =
+      "triangular solve op: left_side=" + std::to_string(left_side) + ", upper=" + std::to_string(uplo) +
+      ", transpose_a=" + std::to_string(transa) + ", unit_diagonal=" + std::to_string(unit_diagonal) +
+      ", batch_size=" + std::to_string(batch_size) + ", m=" + std::to_string(m) + ", k=" + std::to_string(k) +
+      ", input1_dtype={code: " + std::to_string(input1->type.code) + ", bits: " + std::to_string(input1->type.bits) +
+      "}" + ", input2_dtype={code: " + std::to_string(input2->type.code) +
+      ", bits: " + std::to_string(input2->type.bits) + "}";
+  VLOG(4) << debug_info;
+
+  void *a_ptr = reinterpret_cast<void *>(input1->memory);
+  void *b_ptr = reinterpret_cast<void *>(input2->memory);
+  void *x_ptr = reinterpret_cast<void *>(output->memory);
+
+  // The API cublasStrsmBatched overwrites the right-hand sides, so the right-hand sides should be copied to the
+  // output. The output can then be used directly for the calculation.
+  size_t numel = input2->num_elements();
+  CUDA_CALL(cudaMemcpyAsync(x_ptr, b_ptr, numel * bytes, cudaMemcpyDeviceToDevice, custream));
+
+  std::vector<void *> a_array(batch_size, nullptr);
+  std::vector<void *> x_array(batch_size, nullptr);
+  for (int i = 0; i < batch_size; ++i) {
+    a_array[i] = reinterpret_cast<uint8_t *>(a_ptr) + i * m * m * bytes;
+    x_array[i] = reinterpret_cast<uint8_t *>(x_ptr) + i * m * k * bytes;
+  }
+  thrust::device_vector<void *> dev_a_array(a_array.begin(), a_array.end());
+  thrust::device_vector<void *> dev_x_array(x_array.begin(), x_array.end());
+
+  if (bits == 32) {
+    std::vector<float> alpha(batch_size, 1.0f);
+    CUBLAS_CALL(cublasStrsmBatched(handle,
+                                   side,
+                                   uplo,
+                                   transa,
+                                   diag,
+                                   b_rows,
+                                   b_cols,
+                                   alpha.data(),
+                                   reinterpret_cast<float **>(dev_a_array.data().get()),
+                                   lda,
+                                   reinterpret_cast<float **>(dev_x_array.data().get()),
+                                   ldb,
+                                   batch_size));
+  } else if (bits == 64) {
+    std::vector<double> alpha(batch_size, 1.0);
+    CUBLAS_CALL(cublasDtrsmBatched(handle,
+                                   side,
+                                   uplo,
+                                   transa,
+                                   diag,
+                                   b_rows,
+                                   b_cols,
+                                   alpha.data(),
+                                   reinterpret_cast<double **>(dev_a_array.data().get()),
+                                   lda,
+                                   reinterpret_cast<double **>(dev_x_array.data().get()),
+                                   ldb,
+                                   batch_size));
+  }
+}
+
 void cinn_gpu_cublas_mul(const std::vector<int> &attrs,
                          cinn_buffer_t *input1,
                          cinn_buffer_t *input2,
diff --git a/cinn/runtime/cuda/cuda_util.h b/cinn/runtime/cuda/cuda_util.h
index e09b536fe9..6f67a814a5 100644
--- a/cinn/runtime/cuda/cuda_util.h
+++ b/cinn/runtime/cuda/cuda_util.h
@@ -48,6 +48,17 @@ void cinn_call_uniform_random(void* v_args, int num_args, float min, float max,
 
 void cinn_call_cholesky_nvgpu(void* v_args, int num_args, int batch_size, int m, bool upper, void* stream = nullptr);
 
+void cinn_call_triangular_solve_nvgpu(void* v_args,
+                                      int num_args,
+                                      int batch_size,
+                                      int m,
+                                      int k,
+                                      bool left_side,
+                                      bool upper,
+                                      bool transpose_a,
+                                      bool unit_diagonal,
+                                      void* stream = nullptr);
+
 #ifdef CINN_WITH_CUDNN
 void cinn_gpu_cudnn_conv2d(const absl::flat_hash_map<std::string, int>& attr,
                            cinn_buffer_t* x,
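The flipped CUBLAS_SIDE_* and CUBLAS_FILL_MODE_* constants in cinn_call_triangular_solve_nvgpu are the usual row-major/column-major trick: cuBLAS reads the buffers as column-major, and a row-major matrix reinterpreted as column-major is its transpose, so the left-side, upper-triangular problem A·X = B is solved as Xᵀ·Aᵀ = Bᵀ, where the triangular factor sits on the right and Aᵀ is lower-triangular. The cudaMemcpyAsync of B into the output buffer is needed because ?trsmBatched overwrites its right-hand-side argument in place. A NumPy check of the underlying identity (illustrative only):

```python
import numpy as np

rng = np.random.default_rng(0)
m, k = 3, 2
a = np.triu(rng.random((m, m))) + m * np.eye(m)  # upper-triangular, well conditioned
b = rng.random((m, k))

x = np.linalg.solve(a, b)  # row-major problem: A @ X = B

# Column-major view of the same buffers: X^T @ A^T = B^T, with A^T lower-triangular
# and the triangular factor acting from the right.
xt = b.T @ np.linalg.inv(a.T)
assert np.allclose(xt.T, x)
print("row-major A @ X = B  <=>  column-major X^T @ A^T = B^T")
```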
diff --git a/cinn/runtime/custom_function_test.cc b/cinn/runtime/custom_function_test.cc
index c11999720a..770d21401b 100644
--- a/cinn/runtime/custom_function_test.cc
+++ b/cinn/runtime/custom_function_test.cc
@@ -16,6 +16,7 @@
 #include
 #include
+#include
 #ifdef CINN_WITH_CUDA
 #include
 #endif
@@ -294,5 +295,47 @@ TEST(CustomCallCholesky, test) {
   }
 }
 
+#ifdef CINN_WITH_CUDA
+TEST(CustomCallTriangularSolve, test) {
+  Target target = common::DefaultNVGPUTarget();
+  Target host_target = common::DefaultHostTarget();
+
+  int batch_size = 1;
+  int m = 3;
+  int k = 1;
+  bool left_side = true;
+  bool upper = true;
+  bool transpose_a = false;
+  bool unit_diagonal = false;
+
+  double input_a_host[9] = {1.0, 1.0, 1.0, 0.0, 2.0, 1.0, 0.0, 0.0, -1.0};
+  double input_b_host[3] = {0.0, -9.0, 5.0};
+  CinnBufferAllocHelper a(cinn_x86_device, cinn_float64_t(), {m, m});
+  CinnBufferAllocHelper b(cinn_x86_device, cinn_float64_t(), {m, k});
+  auto* input_a = a.mutable_data<double>(target);
+  auto* input_b = b.mutable_data<double>(target);
+  SetInputValue(input_a, input_a_host, m * m, target);
+  SetInputValue(input_b, input_b_host, m * k, target);
+
+  // Output matrix out
+  CinnBufferAllocHelper out(cinn_x86_device, cinn_float64_t(), {m, k});
+  auto* output = out.mutable_data<double>(target);
+
+  // Result matrix res
+  double result[3] = {7.0, -2.0, -5.0};
+
+  constexpr int num_args = 3;
+  cinn_pod_value_t v_args[num_args] = {
+      cinn_pod_value_t(a.get()), cinn_pod_value_t(b.get()), cinn_pod_value_t(out.get())};
+  cinn::runtime::cuda::cinn_call_triangular_solve_nvgpu(
+      v_args, num_args, batch_size, m, k, left_side, upper, transpose_a, unit_diagonal);
+  std::vector<double> device_output(batch_size * m * k, 0.0);
+  cudaMemcpy(device_output.data(), output, batch_size * m * k * sizeof(double), cudaMemcpyDeviceToHost);
+  for (int i = 0; i < batch_size * m * k; i++) {
+    ASSERT_NEAR(device_output[i], result[i], 1e-5) << "The output of triangular solve should be the same as result";
+  }
+}
+#endif
+
 }  // namespace runtime
 }  // namespace cinn
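The expected output {7, -2, -5} in CustomCallTriangularSolve is plain back substitution on the upper-triangular system: x2 = 5 / (-1) = -5, x1 = (-9 - x2) / 2 = -2, x0 = 0 - x1 - x2 = 7. A quick NumPy cross-check:

```python
import numpy as np

a = np.array([[1.0, 1.0, 1.0],
              [0.0, 2.0, 1.0],
              [0.0, 0.0, -1.0]])
b = np.array([0.0, -9.0, 5.0])
print(np.linalg.solve(a, b))  # -> [ 7. -2. -5.]
```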
diff --git a/python/tests/ops/test_triangular_solve_op.py b/python/tests/ops/test_triangular_solve_op.py
new file mode 100644
index 0000000000..616d1fc9a9
--- /dev/null
+++ b/python/tests/ops/test_triangular_solve_op.py
@@ -0,0 +1,352 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2021 CINN Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from op_test import OpTest, OpTestTool
+import paddle
+from cinn.frontend import *
+from cinn.common import *
+
+
+@OpTestTool.skip_if(not is_compiled_with_cuda(),
+                    "triangular solve op support GPU only now.")
+class TestTriangularSolveOp(OpTest):
+    def setUp(self):
+        self.init_case()
+
+    def init_case(self):
+        self.inputs = {
+            "input1": np.random.random((1, 3, 3)).astype(np.float32),
+            "input2": np.random.random((1, 3, 1)).astype(np.float32),
+        }
+        self.left_side = True
+        self.upper = True
+        self.transpose_a = False
+        self.unit_diagonal = False
+
+    def build_paddle_program(self, target):
+        def transpose_last_two_dims(x):
+            shape = x.shape
+            last_dim_idx = len(shape) - 1
+            second_last_dim_idx = len(shape) - 2
+            perm = list(range(len(shape)))
+            perm[last_dim_idx], perm[second_last_dim_idx] = perm[
+                second_last_dim_idx], perm[last_dim_idx]
+            x_transposed = paddle.transpose(x, perm=perm)
+            return x_transposed
+
+        input1 = paddle.to_tensor(self.inputs["input1"], stop_gradient=True)
+        input2 = paddle.to_tensor(self.inputs["input2"], stop_gradient=True)
+        if self.left_side:
+            out = paddle.linalg.triangular_solve(input1, input2, self.upper,
+                                                 self.transpose_a,
+                                                 self.unit_diagonal)
+            self.paddle_outputs = [out]
+        else:
+            input1 = transpose_last_two_dims(input1)
+            input2 = transpose_last_two_dims(input2)
+            out = paddle.linalg.triangular_solve(
+                input1, input2, not self.upper, self.transpose_a,
+                self.unit_diagonal)
+            out = transpose_last_two_dims(out)
+            self.paddle_outputs = [out]
+
+    def build_cinn_program(self, target):
+        builder = NetBuilder("triangular_solve")
+        input1 = builder.create_input(
+            self.nptype2cinntype(self.inputs["input1"].dtype),
+            self.inputs["input1"].shape,
+            "input1",
+        )
+        input2 = builder.create_input(
+            self.nptype2cinntype(self.inputs["input2"].dtype),
+            self.inputs["input2"].shape,
+            "input2",
+        )
+        out = builder.triangular_solve(
+            input1,
+            input2,
+            self.left_side,
+            self.upper,
+            self.transpose_a,
+            self.unit_diagonal,
+        )
+        prog = builder.build()
+        res = self.get_cinn_output(
+            prog,
+            target,
+            [input1, input2],
+            [self.inputs["input1"], self.inputs["input2"]],
+            [out],
+            passes=[],
+        )
+        self.cinn_outputs = [res[0]]
+
+    def test_check_results(self):
+        self.check_outputs_and_grads()
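In build_paddle_program above, paddle.linalg.triangular_solve only solves A·X = B, so the left_side=False cases are checked through the identity X·A = B ⇔ Aᵀ·Xᵀ = Bᵀ: both operands are transposed, the system is solved with the opposite fill mode, and the result is transposed back. A NumPy check of that identity (illustrative, not the test code):

```python
import numpy as np

rng = np.random.default_rng(1)
m, k = 3, 4
a = np.triu(rng.random((m, m))) + m * np.eye(m)  # upper-triangular A
b = rng.random((k, m))                           # right-side RHS: X @ A = B, X is (k, m)

x_direct = b @ np.linalg.inv(a)        # right-side solution
x_via_t = np.linalg.solve(a.T, b.T).T  # solve A^T @ X^T = B^T (A^T is lower-triangular)
assert np.allclose(x_direct, x_via_t)
print("X @ A = B  <=>  A^T @ X^T = B^T")
```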
+
+
+@OpTestTool.skip_if(not is_compiled_with_cuda(),
+                    "triangular solve op support GPU only now.")
+class TestTriangularSolveOpUnitDiagonal(TestTriangularSolveOp):
+    def init_case(self):
+        self.inputs = {
+            "input1": np.random.random((1, 3, 3)).astype(np.float32),
+            "input2": np.random.random((1, 3, 1)).astype(np.float32),
+        }
+        self.left_side = True
+        self.upper = True
+        self.transpose_a = False
+        self.unit_diagonal = True
+
+
+@OpTestTool.skip_if(not is_compiled_with_cuda(),
+                    "triangular solve op support GPU only now.")
+class TestTriangularSolveOpLower(TestTriangularSolveOp):
+    def init_case(self):
+        self.inputs = {
+            "input1": np.random.random((1, 3, 3)).astype(np.float32),
+            "input2": np.random.random((1, 3, 1)).astype(np.float32),
+        }
+        self.left_side = True
+        self.upper = False
+        self.transpose_a = False
+        self.unit_diagonal = False
+
+
+@OpTestTool.skip_if(not is_compiled_with_cuda(),
+                    "triangular solve op support GPU only now.")
+class TestTriangularSolveOpZeroBatchDim1(TestTriangularSolveOp):
+    def init_case(self):
+        self.inputs = {
+            "input1": np.random.random((3, 3)).astype(np.float32),
+            "input2": np.random.random((3, 1)).astype(np.float32),
+        }
+        self.left_side = True
+        self.upper = True
+        self.transpose_a = False
+        self.unit_diagonal = False
+
+
+@OpTestTool.skip_if(not is_compiled_with_cuda(),
+                    "triangular solve op support GPU only now.")
+class TestTriangularSolveOpZeroBatchDim2(TestTriangularSolveOp):
+    def init_case(self):
+        self.inputs = {
+            "input1": np.random.random((1, 3, 3)).astype(np.float32),
+            "input2": np.random.random((3, 1)).astype(np.float32),
+        }
+        self.left_side = True
+        self.upper = True
+        self.transpose_a = False
+        self.unit_diagonal = False
+
+
+@OpTestTool.skip_if(not is_compiled_with_cuda(),
+                    "triangular solve op support GPU only now.")
+class TestTriangularSolveOpZeroBatchDim3(TestTriangularSolveOp):
+    def init_case(self):
+        self.inputs = {
+            "input1": np.random.random((3, 3)).astype(np.float32),
+            "input2": np.random.random((1, 3, 1)).astype(np.float32),
+        }
+        self.left_side = True
+        self.upper = True
+        self.transpose_a = False
+        self.unit_diagonal = False
+
+
+@OpTestTool.skip_if(not is_compiled_with_cuda(),
+                    "triangular solve op support GPU only now.")
+class TestTriangularSolveOpBroadCast(TestTriangularSolveOp):
+    def init_case(self):
+        self.inputs = {
+            "input1": np.random.random((2, 2, 3, 3, 3)).astype(np.float32),
+            "input2": np.random.random((1, 3, 4)).astype(np.float32),
+        }
+        self.left_side = True
+        self.upper = True
+        self.transpose_a = False
+        self.unit_diagonal = False
+
+
+@OpTestTool.skip_if(not is_compiled_with_cuda(),
+                    "triangular solve op support GPU only now.")
+class TestTriangularSolveOpBroadCast1(TestTriangularSolveOp):
+    def init_case(self):
+        self.inputs = {
+            "input1": np.random.random((3, 3, 3)).astype(np.float32),
+            "input2": np.random.random((2, 2, 3, 3, 4)).astype(np.float32),
+        }
+        self.left_side = True
+        self.upper = True
+        self.transpose_a = False
+        self.unit_diagonal = False
+
+
+@OpTestTool.skip_if(not is_compiled_with_cuda(),
+                    "triangular solve op support GPU only now.")
+class TestTriangularSolveOpBroadCast2(TestTriangularSolveOp):
+    def init_case(self):
+        self.inputs = {
+            "input1": np.random.random((2, 1, 3, 3, 3)).astype(np.float32),
+            "input2": np.random.random((2, 2, 3, 3, 4)).astype(np.float32),
+        }
+        self.left_side = True
+        self.upper = True
+        self.transpose_a = False
+        self.unit_diagonal = False
+
+
+@OpTestTool.skip_if(not is_compiled_with_cuda(),
+                    "triangular solve op support GPU only now.")
+class TestTriangularSolveOpBroadCast3(TestTriangularSolveOp):
+    def init_case(self):
+        self.inputs = {
+            "input1": np.random.random((5, 1, 3, 3, 3)).astype(np.float32),
+            "input2": np.random.random((1, 2, 1, 3, 4)).astype(np.float32),
+        }
+        self.left_side = True
+        self.upper = True
+        self.transpose_a = False
+        self.unit_diagonal = False
+
+
+@OpTestTool.skip_if(not is_compiled_with_cuda(),
+                    "triangular solve op support GPU only now.")
+class TestTriangularSolveOpTranspose(TestTriangularSolveOp):
+    def init_case(self):
+        self.inputs = {
+            "input1": np.random.random((1, 3, 3)).astype(np.float32),
+            "input2": np.random.random((1, 3, 1)).astype(np.float32),
+        }
+        self.left_side = True
+        self.upper = True
+        self.transpose_a = True
+        self.unit_diagonal = False
+
+
+@OpTestTool.skip_if(not is_compiled_with_cuda(),
+                    "triangular solve op support GPU only now.")
+class TestTriangularSolveOpRightSide(TestTriangularSolveOp):
+    def init_case(self):
+        self.inputs = {
+            "input1": np.random.random((2, 3, 3)).astype(np.float32),
+            "input2": np.random.random((2, 1, 3)).astype(np.float32),
+        }
+        self.left_side = False
+        self.upper = True
+        self.transpose_a = False
+        self.unit_diagonal = False
+
+
+@OpTestTool.skip_if(not is_compiled_with_cuda(),
+                    "triangular solve op support GPU only now.")
+class TestTriangularSolveOpRightSide1(TestTriangularSolveOp):
+    def init_case(self):
+        self.inputs = {
+            "input1": np.random.random((1, 3, 2, 3, 3)).astype(np.float32),
+            "input2": np.random.random((2, 1, 2, 1, 3)).astype(np.float32),
+        }
+        self.left_side = False
+        self.upper = True
+        self.transpose_a = False
+        self.unit_diagonal = False
+
+
+@OpTestTool.skip_if(not is_compiled_with_cuda(),
+                    "triangular solve op support GPU only now.")
+class TestTriangularSolveOpDoubleFloat(TestTriangularSolveOp):
+    def init_case(self):
+        self.inputs = {
+            "input1": np.random.random((1, 3, 3)).astype(np.float64),
+            "input2": np.random.random((1, 3, 1)).astype(np.float64),
+        }
+        self.left_side = True
+        self.upper = True
+        self.transpose_a = True
+        self.unit_diagonal = False
+
+
+@OpTestTool.skip_if(not is_compiled_with_cuda(),
+                    "triangular solve op support GPU only now.")
+class TestTriangularSolveOpBatch(TestTriangularSolveOp):
+    def init_case(self):
+        self.inputs = {
+            "input1": np.random.random((5, 3, 3)).astype(np.float32),
+            "input2": np.random.random((5, 3, 1)).astype(np.float32),
+        }
+        self.left_side = True
+        self.upper = True
+        self.transpose_a = False
+        self.unit_diagonal = False
+
+
+@OpTestTool.skip_if(not is_compiled_with_cuda(),
+                    "triangular solve op support GPU only now.")
+class TestTriangularSolveOpMultipleRightHandSides(TestTriangularSolveOp):
+    def init_case(self):
+        self.inputs = {
+            "input1": np.random.random((2, 3, 3)).astype(np.float32),
+            "input2": np.random.random((2, 3, 10)).astype(np.float32),
+        }
+        self.left_side = True
+        self.upper = True
+        self.transpose_a = False
+        self.unit_diagonal = False
+
+
+@OpTestTool.skip_if(not is_compiled_with_cuda(),
+                    "triangular solve op support GPU only now.")
+class TestTriangularSolveOpSingular(TestTriangularSolveOp):
+    def init_case(self):
+        self.inputs = {
+            "input1": np.random.random((1, 3, 3)).astype(np.float32),
+            "input2": np.random.random((1, 3, 1)).astype(np.float32),
+        }
+        # set one row to zeros to make a singular matrix
+        self.inputs["input1"][0][0] = 0
+        self.left_side = True
+        self.upper = True
+        self.transpose_a = True
+        self.unit_diagonal = False
+
+    def test_check_results(self):
+        self.check_outputs_and_grads(equal_nan=True)
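TestTriangularSolveOpSingular above and TestTriangularSolveOpSingular1 below zero out an entire row of the triangular factor, so back substitution divides by a zero diagonal entry; both the Paddle reference and the cuBLAS path are expected to return non-finite values (inf/NaN) rather than raise, which is why these tests compare with equal_nan=True. A tiny illustration of where the non-finite values come from (assumed behaviour, not the test code):

```python
import numpy as np

a = np.array([[0.0, 0.5],   # zero diagonal entry -> singular triangular matrix
              [0.0, 2.0]])
b = np.array([1.0, 1.0])
with np.errstate(divide="ignore", invalid="ignore"):
    x1 = b[1] / a[1, 1]                    # 0.5
    x0 = (b[0] - a[0, 1] * x1) / a[0, 0]   # 0.75 / 0.0 -> inf
print(x0, x1)  # inf 0.5 -- the comparison must tolerate non-finite outputs
```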
+
+
+@OpTestTool.skip_if(not is_compiled_with_cuda(),
+                    "triangular solve op support GPU only now.")
+class TestTriangularSolveOpSingular1(TestTriangularSolveOp):
+    def init_case(self):
+        self.inputs = {
+            "input1": np.random.random((1, 3, 3)).astype(np.float32),
+            "input2": np.random.random((1, 3, 1)).astype(np.float32),
+        }
+        # set one row to zeros to make a singular matrix
+        self.inputs["input1"][0][2] = 0
+        self.left_side = True
+        self.upper = True
+        self.transpose_a = True
+        self.unit_diagonal = False
+
+    def test_check_results(self):
+        self.check_outputs_and_grads(equal_nan=True)
+
+
+if __name__ == "__main__":
+    unittest.main()