diff --git a/src/Accelerators/NNPA/Compiler/CMakeLists.txt b/src/Accelerators/NNPA/Compiler/CMakeLists.txt
index 6a8c29a45d..83e4bdd9a2 100644
--- a/src/Accelerators/NNPA/Compiler/CMakeLists.txt
+++ b/src/Accelerators/NNPA/Compiler/CMakeLists.txt
@@ -1,5 +1,3 @@
-get_property(OMLibs GLOBAL PROPERTY ONNX_MLIR_LIBS)
-
 add_onnx_mlir_library(OMNNPACompilerOptions
   NNPACompilerOptions.cpp
@@ -12,7 +10,6 @@ add_onnx_mlir_library(OMNNPACompilerOptions
   ${NNPA_ONNX_MLIR_BIN_ROOT}

   LINK_LIBS PUBLIC
-  ${OMLibs}
   OMCompilerOptions

   ACCEL_INCLUDE_DIRS PRIVATE
@@ -32,7 +29,6 @@ add_onnx_mlir_library(OMNNPACompilerUtils
   ${NNPA_ONNX_MLIR_BIN_ROOT}

   LINK_LIBS PUBLIC
-  ${OMLibs}
   OMNNPACompilerOptions
   OMCompilerPasses

diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp
index 5f033fff9e..ccc494e226 100644
--- a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp
+++ b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp
@@ -55,6 +55,13 @@ llvm::cl::opt<bool> nnpaEnableCompilerStickUnstick(
         "stick/unstick code. Default is false."),
     llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions));

+llvm::cl::opt<bool> nnpaEnableScalarBcastBinary(
+    "nnpa-enable-scalar-bcast-binary",
+    llvm::cl::desc("Enable lowering to NNPA for broadcasting binary ops "
+                   "where one of the operands is a scalar. Currently "
+                   "supports ONNXDiv only. Default is false."),
+    llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions));
+
 llvm::cl::opt<std::string> nnpaLoadDevicePlacementFile{
     "nnpa-load-device-placement-file",
     llvm::cl::desc(
diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp
index d7eee71707..67c3151317 100644
--- a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp
+++ b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp
@@ -49,11 +49,13 @@ typedef enum {
 } NNPAPlacementHeuristic;

 extern llvm::cl::OptionCategory OnnxMlirOptions;
+extern llvm::cl::OptionCategory OnnxMlirCommonOptions;
 extern llvm::cl::opt<NNPAEmissionTargetType> nnpaEmissionTarget;
 extern llvm::cl::opt<bool> nnpaClipToDLFloatRange;
 extern llvm::cl::opt<bool> nnpaEnableZHighToOnnx;
 extern llvm::cl::opt<bool> nnpaEnableZHighDecomposeStickUnstick;
 extern llvm::cl::opt<bool> nnpaEnableCompilerStickUnstick;
+extern llvm::cl::opt<bool> nnpaEnableScalarBcastBinary;
 extern llvm::cl::opt<NNPAPlacementHeuristic> nnpaPlacementHeuristic;
 extern llvm::cl::opt<bool> profileZHighIR;
 extern llvm::cl::opt<std::string> nnpaLoadDevicePlacementFile;
diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt
index de58e1277e..2e50d0797b 100644
--- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt
@@ -11,7 +11,7 @@ add_onnx_mlir_library(OMONNXToZHigh
   libzdnn

   LINK_LIBS PUBLIC
-  OMCompilerOptions
+  OMNNPACompilerOptions
   OMONNXOps
   OMONNXToKrnl
   OMZHighOps
@@ -32,7 +32,7 @@ add_onnx_mlir_library(OMRewriteONNXForZHigh
   libzdnn

   LINK_LIBS PUBLIC
-  OMCompilerOptions
+  OMNNPACompilerOptions
   OMONNXOps
   OMONNXToKrnl
   OMZHighOps
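
The new option is registered under OnnxMlirCommonOptions and defaults to off, so the scalar-broadcast lowering is strictly opt-in. Matching the flags used by the RUN lines of the tests further down, a driver invocation that enables it would look like this (model.onnx is a placeholder input file, not part of this patch):

    onnx-mlir --mcpu=z16 --maccel=NNPA --nnpa-enable-scalar-bcast-binary model.onnx
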
diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp
index 26cc6d20e7..9f8c44efdd 100644
--- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp
@@ -324,12 +324,19 @@ bool isSuitableForZDNN<ONNXDivOp>(
   // Check NNPA level.
   if (!isCompatibleWithNNPALevel(NNPA_Z16))
     return false;
-  if (!isF32ScalarConstantTensor(A) && !isValidElementTypeAndRank(A))
+  // Broadcast with a scalar operand.
+  if (isEnableScalarBcastBinary()) {
+    if (isF32ScalarConstantTensor(A) && isValidElementTypeAndRank(B))
+      return true;
+    if (isF32ScalarConstantTensor(B) && isValidElementTypeAndRank(A))
+      return true;
+  }
+  // Non-broadcast cases.
+  if (!isValidElementTypeAndRank(A))
     return false;
-  if (!isF32ScalarConstantTensor(B) && !isValidElementTypeAndRank(B))
+  if (!isValidElementTypeAndRank(B))
     return false;
-  return isF32ScalarConstantTensor(A) || isF32ScalarConstantTensor(B) ||
-         dimAnalysis->sameShape(A, B);
+  return dimAnalysis->sameShape(A, B);
 }

 /// Check legality for ONNXSum.
diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td
index ee5a1b3940..7dd088bb3c 100644
--- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td
@@ -29,6 +29,8 @@ include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.td"
 ///   dag benefitsAdded = (addBenefit 0)
 /// >;

+def IsEnableScalarBcastBinary: Constraint<CPred<"isEnableScalarBcastBinary()">>;
+
 def IsNoneType : Constraint<CPred<"(($_self).getType().isa<NoneType>())">>;

 def IsNotNoneType : Constraint<CPred<"(!($_self).getType().isa<NoneType>())">>;
@@ -227,7 +229,7 @@ def replaceONNXDivBroadcastPattern1 : Pat<
       (GetScalarF32AttrFromConstant $y),
       (NoneLayoutAttr)),
     (returnType $s_x))),
-  [(IsF32ScalarConstantTensor $y)], [],
+  [(IsEnableScalarBcastBinary), (IsF32ScalarConstantTensor $y)], [],
   (addBenefit 1)
 >;
@@ -241,7 +243,7 @@ def replaceONNXDivBroadcastPattern2 : Pat<
       (NoneLayoutAttr)),
     (ZHighStickOp:$s_y $y,
       (NoneLayoutAttr)),
     (returnType $s_y))),
-  [(IsF32ScalarConstantTensor $x)], [],
+  [(IsEnableScalarBcastBinary), (IsF32ScalarConstantTensor $x)], [],
   (addBenefit 1)
 >;
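
The IsEnableScalarBcastBinary constraint threads the runtime flag into the two DRR patterns above. Conceptually, the matcher that mlir-tblgen generates for replaceONNXDivBroadcastPattern1 is guarded along these lines (a simplified sketch of the generated control flow, not the literal tblgen output):

    // Sketch: every constraint must hold before the rewrite fires.
    if (!isEnableScalarBcastBinary())
      return failure(); // flag is off; leave onnx.Div untouched
    if (!isF32ScalarConstantTensor(y))
      return failure(); // divisor is not a scalar f32 constant
    // ... otherwise rewrite onnx.Div into zhigh.Stick /
    // zhigh.StickifiedConstantOfShape / zhigh.Div / zhigh.Unstick,
    // exactly as the Pat's result dag describes ...
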
diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.cpp
index 150b040f00..f99e9737ac 100644
--- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.cpp
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.cpp
@@ -14,11 +14,14 @@
 //===----------------------------------------------------------------------===//

 #include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp"
+#include "src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp"
 #include "src/Dialect/ONNX/DialectBuilder.hpp"

 using namespace mlir;

 namespace onnx_mlir {

+bool isEnableScalarBcastBinary() { return nnpaEnableScalarBcastBinary; }
+
 /// Get transposed tensor by using a permutation array.
 Value emitONNXTranspose(
     Location loc, PatternRewriter &rewriter, Value x, ArrayRef<int64_t> perms) {
diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp
index ec95a7a1be..2eef0b9646 100644
--- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp
@@ -27,6 +27,8 @@ const std::string DEVICE_ATTRIBUTE = "device";
 const std::string CPU_DEVICE = "cpu";
 const std::string NNPA_DEVICE = "nnpa";

+bool isEnableScalarBcastBinary();
+
 template <typename OP_TYPE>
 void addDynamicallyLegalOpFor(mlir::ConversionTarget *target,
     const onnx_mlir::DimAnalysis *dimAnalysis,
diff --git a/src/Accelerators/NNPA/Pass/NNPAPasses.hpp b/src/Accelerators/NNPA/Pass/NNPAPasses.hpp
index b097c5d1eb..9e25e44fa0 100644
--- a/src/Accelerators/NNPA/Pass/NNPAPasses.hpp
+++ b/src/Accelerators/NNPA/Pass/NNPAPasses.hpp
@@ -29,11 +29,9 @@ std::unique_ptr<mlir::Pass> createDevicePlacementPass(

 /// Add pass for lowering ONNX ops to ZHigh ops.
 std::unique_ptr<mlir::Pass> createONNXToZHighPass();
-std::unique_ptr<mlir::Pass> createONNXToZHighPass();

 /// Add pass for rewriting ONNX ops for ZHigh.
 std::unique_ptr<mlir::Pass> createRewriteONNXForZHighPass();
-std::unique_ptr<mlir::Pass> createRewriteONNXForZHighPass();

 /// Add pass for re-construct ONNX ops from ZHigh ops.
 std::unique_ptr<mlir::Pass> createZHighToONNXPass();
diff --git a/src/Conversion/KrnlToAffine/KrnlGetLinearOffsetIndex.cpp b/src/Conversion/KrnlToAffine/KrnlGetLinearOffsetIndex.cpp
index 18b7048cca..99dc1037f2 100644
--- a/src/Conversion/KrnlToAffine/KrnlGetLinearOffsetIndex.cpp
+++ b/src/Conversion/KrnlToAffine/KrnlGetLinearOffsetIndex.cpp
@@ -53,7 +53,7 @@ class KrnlGetLinearOffsetIndexLowering : public ConversionPattern {
     auto memrefTy = llvm::dyn_cast<MemRefType>(memref.getType());
     int64_t rank = memrefTy.getRank();

-    assert(mapResults.value().size() == rank && "Invalid indices");
+    assert((int64_t)mapResults.value().size() == rank && "Invalid indices");

     // Only lower this op after the memref is normalized.
     if (!memrefTy.getLayout().isIdentity())
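
The assert tweak in KrnlGetLinearOffsetIndex.cpp fixes a signed/unsigned comparison: mapResults.value().size() yields an unsigned size_t, while rank is int64_t. A minimal standalone illustration of the same fix (the values are hypothetical; only the cast mirrors the patch):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      std::vector<int64_t> mapResults = {0, 1, 2}; // stand-in for the affine map results
      int64_t rank = 3;                            // stand-in for memrefTy.getRank()
      // size() is unsigned (size_t); comparing it to a signed int64_t triggers
      // -Wsign-compare under -Werror. Casting to int64_t keeps the check intact.
      assert((int64_t)mapResults.size() == rank && "Invalid indices");
      return 0;
    }
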
diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/div-bcast.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/div-bcast.mlir
new file mode 100644
index 0000000000..7df7a2cb2a
--- /dev/null
+++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/div-bcast.mlir
@@ -0,0 +1,44 @@
+// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh --nnpa-enable-scalar-bcast-binary %s -split-input-file | FileCheck %s
+
+// COM: Division by a scalar in case of dynamic dimensions.
+func.func @test_div_unknown_scalar1(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
+  %0 = onnx.Constant dense<8.000000e+00> : tensor<f32>
+  %1 = "onnx.Div"(%arg0, %0) : (tensor<?x?xf32>, tensor<f32>) -> tensor<*xf32>
+  "func.return"(%1) : (tensor<*xf32>) -> ()
+
+// CHECK-LABEL:  func.func @test_div_unknown_scalar1
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-DAG:       [[VAR_0_:%.+]] = onnx.Constant dense<8.000000e+00> : tensor<f32>
+// CHECK-DAG:       [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<?x?xf32>) -> tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>
+// CHECK-DAG:       [[VAR_2_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x?xf32>) -> tensor<1xi64>
+// CHECK-DAG:       [[VAR_3_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?xf32>) -> tensor<1xi64>
+// CHECK:           [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_2_]], [[VAR_3_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
+// CHECK:           [[VAR_5_:%.+]] = "zhigh.StickifiedConstantOfShape"([[VAR_4_]]) {layout = "2D", value = 8.000000e+00 : f32} : (tensor<2xi64>) -> tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>
+// CHECK:           [[VAR_6_:%.+]] = "zhigh.Div"([[VAR_1_]], [[VAR_5_]]) : (tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>
+// CHECK:           [[VAR_7_:%.+]] = "zhigh.Unstick"([[VAR_6_]]) : (tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x?xf32>
+// CHECK:           return [[VAR_7_]] : tensor<?x?xf32>
+// CHECK:         }
+}
+
+// -----
+
+// COM: Division by a scalar in case of dynamic dimensions.
+func.func @test_div_unknown_scalar2(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
+  %0 = onnx.Constant dense<8.000000e+00> : tensor<f32>
+  %1 = "onnx.Div"(%0, %arg0) : (tensor<f32>, tensor<?x?xf32>) -> tensor<*xf32>
+  "func.return"(%1) : (tensor<*xf32>) -> ()
+
+// CHECK-LABEL:  func.func @test_div_unknown_scalar2
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-DAG:       [[VAR_0_:%.+]] = onnx.Constant dense<8.000000e+00> : tensor<f32>
+// CHECK-DAG:       [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x?xf32>) -> tensor<1xi64>
+// CHECK-DAG:       [[VAR_2_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?xf32>) -> tensor<1xi64>
+// CHECK:           [[VAR_3_:%.+]] = "onnx.Concat"([[VAR_1_]], [[VAR_2_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
+// CHECK-DAG:       [[VAR_4_:%.+]] = "zhigh.StickifiedConstantOfShape"([[VAR_3_]]) {layout = "2D", value = 8.000000e+00 : f32} : (tensor<2xi64>) -> tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>
+// CHECK-DAG:       [[VAR_5_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<?x?xf32>) -> tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>
+// CHECK:           [[VAR_6_:%.+]] = "zhigh.Div"([[VAR_4_]], [[VAR_5_]]) : (tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>
+// CHECK:           [[VAR_7_:%.+]] = "zhigh.Unstick"([[VAR_6_]]) : (tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x?xf32>
+// CHECK:           return [[VAR_7_]] : tensor<?x?xf32>
+// CHECK:         }
+}
+
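
The two functions above are relocated rather than newly written: the next diff removes the identical tests from div.mlir, whose own RUN line is untouched and so keeps covering the default (flag off) behavior. Condensed from the FileCheck expectations above, the lowering they verify is (types abbreviated; %cst and %shape are shorthand for the matched values):

    // Input: scalar-broadcast division with dynamic dimensions.
    %1 = "onnx.Div"(%arg0, %cst) : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
    // Output with --nnpa-enable-scalar-bcast-binary: materialize the scalar
    // at the operand's runtime shape, then divide on NNPA.
    %s = "zhigh.Stick"(%arg0) {layout = "2D"}
    %c = "zhigh.StickifiedConstantOfShape"(%shape) {layout = "2D", value = 8.000000e+00 : f32}
    %d = "zhigh.Div"(%s, %c)
    %r = "zhigh.Unstick"(%d)
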
diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/div.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/div.mlir
index 5cbb42a2f1..9cad7a6915 100644
--- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/div.mlir
+++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/div.mlir
@@ -32,50 +32,6 @@ func.func @test_div_3ds(%arg0 : tensor<10x10x10xf32>, %arg1 : tensor<10x10x10xf32>) -> tensor<*xf32> {

 // -----

-// COM: Division by a scalar in case of dynamic dimensions.
-func.func @test_div_unknown_scalar1(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
-  %0 = onnx.Constant dense<8.000000e+00> : tensor<f32>
-  %1 = "onnx.Div"(%arg0, %0) : (tensor<?x?xf32>, tensor<f32>) -> tensor<*xf32>
-  "func.return"(%1) : (tensor<*xf32>) -> ()
-
-// CHECK-LABEL:  func.func @test_div_unknown_scalar1
-// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK-DAG:       [[VAR_0_:%.+]] = onnx.Constant dense<8.000000e+00> : tensor<f32>
-// CHECK-DAG:       [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<?x?xf32>) -> tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>
-// CHECK-DAG:       [[VAR_2_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x?xf32>) -> tensor<1xi64>
-// CHECK-DAG:       [[VAR_3_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?xf32>) -> tensor<1xi64>
-// CHECK:           [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_2_]], [[VAR_3_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
-// CHECK:           [[VAR_5_:%.+]] = "zhigh.StickifiedConstantOfShape"([[VAR_4_]]) {layout = "2D", value = 8.000000e+00 : f32} : (tensor<2xi64>) -> tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>
-// CHECK:           [[VAR_6_:%.+]] = "zhigh.Div"([[VAR_1_]], [[VAR_5_]]) : (tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>
-// CHECK:           [[VAR_7_:%.+]] = "zhigh.Unstick"([[VAR_6_]]) : (tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x?xf32>
-// CHECK:           return [[VAR_7_]] : tensor<?x?xf32>
-// CHECK:         }
-}
-
-// -----
-
-// COM: Division by a scalar in case of dynamic dimensions.
-func.func @test_div_unknown_scalar2(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
-  %0 = onnx.Constant dense<8.000000e+00> : tensor<f32>
-  %1 = "onnx.Div"(%0, %arg0) : (tensor<f32>, tensor<?x?xf32>) -> tensor<*xf32>
-  "func.return"(%1) : (tensor<*xf32>) -> ()
-
-// CHECK-LABEL:  func.func @test_div_unknown_scalar2
-// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK-DAG:       [[VAR_0_:%.+]] = onnx.Constant dense<8.000000e+00> : tensor<f32>
-// CHECK-DAG:       [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x?xf32>) -> tensor<1xi64>
-// CHECK-DAG:       [[VAR_2_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x?xf32>) -> tensor<1xi64>
-// CHECK:           [[VAR_3_:%.+]] = "onnx.Concat"([[VAR_1_]], [[VAR_2_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
-// CHECK-DAG:       [[VAR_4_:%.+]] = "zhigh.StickifiedConstantOfShape"([[VAR_3_]]) {layout = "2D", value = 8.000000e+00 : f32} : (tensor<2xi64>) -> tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>
-// CHECK-DAG:       [[VAR_5_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<?x?xf32>) -> tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>
-// CHECK:           [[VAR_6_:%.+]] = "zhigh.Div"([[VAR_4_]], [[VAR_5_]]) : (tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>
-// CHECK:           [[VAR_7_:%.+]] = "zhigh.Unstick"([[VAR_6_]]) : (tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x?xf32>
-// CHECK:           return [[VAR_7_]] : tensor<?x?xf32>
-// CHECK:         }
-}
-
-// -----
-
 // COM: Do not lower broadcasting onnx.Div to zHigh.
 func.func @test_div_not_lowered_diff_shape(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10xf32>) -> tensor<*xf32> {
   %0 = "onnx.Div"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10xf32>) -> tensor<*xf32>
diff --git a/test/mlir/accelerators/nnpa/driver/matmul-div-in-attention-layer.mlir b/test/mlir/accelerators/nnpa/driver/matmul-div-in-attention-layer.mlir
index 2f222ba37e..cfe6e2b611 100644
--- a/test/mlir/accelerators/nnpa/driver/matmul-div-in-attention-layer.mlir
+++ b/test/mlir/accelerators/nnpa/driver/matmul-div-in-attention-layer.mlir
@@ -1,4 +1,4 @@
-// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA --EmitMLIR --printIR %s | FileCheck %s
+// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA --EmitMLIR --nnpa-enable-scalar-bcast-binary --printIR %s | FileCheck %s

 // Check whether the compiler can remove unstick/stick so that the output of zdnn matmul is passed directly to zdnn div.
 func.func @matmul_div(%arg0: tensor) -> tensor {
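
The driver test keeps its FileCheck body but now opts into the new flag: the attention pattern it checks divides a matmul result by a scalar constant, and without the flag that onnx.Div would no longer lower to NNPA, so the unstick/stick removal under test would not arise. A hypothetical, abridged shape of such a graph (operand names and the scalar value are illustrative, not taken from the test):

    %qk  = "onnx.MatMul"(%q, %k)                          // attention scores
    %cst = onnx.Constant dense<8.000000e+00> : tensor<f32> // e.g. sqrt(head_dim)
    %out = "onnx.Div"(%qk, %cst)                           // scalar-broadcast division
    // With --nnpa-enable-scalar-bcast-binary, both ops lower to zhigh/zdnn, and
    // the compiler can feed the matmul output straight into the division.
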