Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a flag to turn on/off the lowering of scalar broadcasting binary ops to NNPA #2778

Merged
merged 8 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions src/Accelerators/NNPA/Compiler/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
get_property(OMLibs GLOBAL PROPERTY ONNX_MLIR_LIBS)

add_onnx_mlir_library(OMNNPACompilerOptions
NNPACompilerOptions.cpp

Expand All @@ -12,7 +10,6 @@ add_onnx_mlir_library(OMNNPACompilerOptions
${NNPA_ONNX_MLIR_BIN_ROOT}

LINK_LIBS PUBLIC
${OMLibs}
OMCompilerOptions

ACCEL_INCLUDE_DIRS PRIVATE
Expand All @@ -32,7 +29,6 @@ add_onnx_mlir_library(OMNNPACompilerUtils
${NNPA_ONNX_MLIR_BIN_ROOT}

LINK_LIBS PUBLIC
${OMLibs}
OMNNPACompilerOptions
OMCompilerPasses

Expand Down
7 changes: 7 additions & 0 deletions src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ llvm::cl::opt<bool> nnpaEnableCompilerStickUnstick(
"stick/unstick code. Default is false."),
llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions));

llvm::cl::opt<bool> nnpaEnableScalarBcastBinary(
"nnpa-enable-scalar-bcast-binary",
llvm::cl::desc("Enable the lowering to NNPA the broadcasting binary ops "
"whose one of the operands is scalar. Currently support "
"ONNXDiv only. Default is false."),
llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions));

llvm::cl::opt<std::string> nnpaLoadDevicePlacementFile{
"nnpa-load-device-placement-file",
llvm::cl::desc(
Expand Down
2 changes: 2 additions & 0 deletions src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,13 @@ typedef enum {
} NNPAPlacementHeuristic;

extern llvm::cl::OptionCategory OnnxMlirOptions;
extern llvm::cl::OptionCategory OnnxMlirCommonOptions;
extern llvm::cl::opt<onnx_mlir::NNPAEmissionTargetType> nnpaEmissionTarget;
extern llvm::cl::opt<bool> nnpaClipToDLFloatRange;
extern llvm::cl::opt<bool> nnpaEnableZHighToOnnx;
extern llvm::cl::opt<bool> nnpaEnableZHighDecomposeStickUnstick;
extern llvm::cl::opt<bool> nnpaEnableCompilerStickUnstick;
extern llvm::cl::opt<bool> nnpaEnableScalarBcastBinary;
extern llvm::cl::opt<NNPAPlacementHeuristic> nnpaPlacementHeuristic;
extern llvm::cl::opt<bool> profileZHighIR;
extern llvm::cl::opt<std::string> nnpaLoadDevicePlacementFile;
Expand Down
4 changes: 2 additions & 2 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ add_onnx_mlir_library(OMONNXToZHigh
libzdnn

LINK_LIBS PUBLIC
OMCompilerOptions
OMNNPACompilerOptions
OMONNXOps
OMONNXToKrnl
OMZHighOps
Expand All @@ -32,7 +32,7 @@ add_onnx_mlir_library(OMRewriteONNXForZHigh
libzdnn

LINK_LIBS PUBLIC
OMCompilerOptions
OMNNPACompilerOptions
OMONNXOps
OMONNXToKrnl
OMZHighOps
Expand Down
15 changes: 11 additions & 4 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,12 +324,19 @@ bool isSuitableForZDNN<ONNXDivOp>(
// Check NNPA level.
if (!isCompatibleWithNNPALevel(NNPA_Z16))
return false;
if (!isF32ScalarConstantTensor(A) && !isValidElementTypeAndRank(A))
// Broadcast with a scalar operand.
if (isEnableScalarBcastBinary()) {
if (isF32ScalarConstantTensor(A) && isValidElementTypeAndRank(B))
return true;
if (isF32ScalarConstantTensor(B) && isValidElementTypeAndRank(A))
return true;
}
// Non-broadcast cases.
if (!isValidElementTypeAndRank(A))
return false;
if (!isF32ScalarConstantTensor(B) && !isValidElementTypeAndRank(B))
if (!isValidElementTypeAndRank(B))
return false;
return isF32ScalarConstantTensor(A) || isF32ScalarConstantTensor(B) ||
dimAnalysis->sameShape(A, B);
return dimAnalysis->sameShape(A, B);
}

/// Check legality for ONNXSum.
Expand Down
6 changes: 4 additions & 2 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.td"
/// dag benefitsAdded = (addBenefit 0)
/// >;

def IsEnableScalarBcastBinary: Constraint<CPred<"isEnableScalarBcastBinary()">>;

def IsNoneType : Constraint<CPred<"(($_self).getType().isa<NoneType>())">>;

def IsNotNoneType : Constraint<CPred<"(!($_self).getType().isa<NoneType>())">>;
Expand Down Expand Up @@ -227,7 +229,7 @@ def replaceONNXDivBroadcastPattern1 : Pat<
(GetScalarF32AttrFromConstant $y),
(NoneLayoutAttr)),
(returnType $s_x))),
[(IsF32ScalarConstantTensor $y)], [],
[(IsEnableScalarBcastBinary), (IsF32ScalarConstantTensor $y)], [],
(addBenefit 1)
>;

Expand All @@ -241,7 +243,7 @@ def replaceONNXDivBroadcastPattern2 : Pat<
(NoneLayoutAttr)),
(ZHighStickOp:$s_y $y, (NoneLayoutAttr)),
(returnType $s_y))),
[(IsF32ScalarConstantTensor $x)], [],
[(IsEnableScalarBcastBinary), (IsF32ScalarConstantTensor $x)], [],
(addBenefit 1)
>;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@
//===----------------------------------------------------------------------===//

#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp"
#include "src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp"
#include "src/Dialect/ONNX/DialectBuilder.hpp"

using namespace mlir;
namespace onnx_mlir {

bool isEnableScalarBcastBinary() { return nnpaEnableScalarBcastBinary; }

/// Get transposed tensor by using a permutation array.
Value emitONNXTranspose(
Location loc, PatternRewriter &rewriter, Value x, ArrayRef<int64_t> perms) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ const std::string DEVICE_ATTRIBUTE = "device";
const std::string CPU_DEVICE = "cpu";
const std::string NNPA_DEVICE = "nnpa";

bool isEnableScalarBcastBinary();

template <typename OP_TYPE>
void addDynamicallyLegalOpFor(mlir::ConversionTarget *target,
const onnx_mlir::DimAnalysis *dimAnalysis,
Expand Down
2 changes: 0 additions & 2 deletions src/Accelerators/NNPA/Pass/NNPAPasses.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,9 @@ std::unique_ptr<mlir::Pass> createDevicePlacementPass(

/// Add pass for lowering ONNX ops to ZHigh ops.
std::unique_ptr<mlir::Pass> createONNXToZHighPass();
std::unique_ptr<mlir::Pass> createONNXToZHighPass();

/// Add pass for rewriting ONNX ops for ZHigh.
std::unique_ptr<mlir::Pass> createRewriteONNXForZHighPass();
std::unique_ptr<mlir::Pass> createRewriteONNXForZHighPass();

/// Add pass for re-construct ONNX ops from ZHigh ops.
std::unique_ptr<mlir::Pass> createZHighToONNXPass();
Expand Down
2 changes: 1 addition & 1 deletion src/Conversion/KrnlToAffine/KrnlGetLinearOffsetIndex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class KrnlGetLinearOffsetIndexLowering : public ConversionPattern {

auto memrefTy = llvm::dyn_cast<MemRefType>(memref.getType());
int64_t rank = memrefTy.getRank();
assert(mapResults.value().size() == rank && "Invalid indices");
assert((int64_t)mapResults.value().size() == rank && "Invalid indices");

// Only lower this op after the memref is normalized.
if (!memrefTy.getLayout().isIdentity())
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// RUN: onnx-mlir-opt --mcpu=z16 --maccel=NNPA --shape-inference --convert-onnx-to-zhigh --nnpa-enable-scalar-bcast-binary %s -split-input-file | FileCheck %s

// COM: Division by a scalar in case of dynamic dimensions.
func.func @test_div_unknown_scalar1(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = onnx.Constant dense<8.000000e+00> : tensor<f32>
%1 = "onnx.Div"(%arg0, %0) : (tensor<?x10xf32>, tensor<f32>) -> tensor<*xf32>
"func.return"(%1) : (tensor<*xf32>) -> ()

// CHECK-LABEL: func.func @test_div_unknown_scalar1
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x10xf32>) -> tensor<?x10xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<8.000000e+00> : tensor<f32>
// CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<?x10xf32>) -> tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>
// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x10xf32>) -> tensor<1xi64>
// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x10xf32>) -> tensor<1xi64>
// CHECK: [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_2_]], [[VAR_3_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
// CHECK: [[VAR_5_:%.+]] = "zhigh.StickifiedConstantOfShape"([[VAR_4_]]) {layout = "2D", value = 8.000000e+00 : f32} : (tensor<2xi64>) -> tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>
// CHECK: [[VAR_6_:%.+]] = "zhigh.Div"([[VAR_1_]], [[VAR_5_]]) : (tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>
// CHECK: [[VAR_7_:%.+]] = "zhigh.Unstick"([[VAR_6_]]) : (tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x10xf32>
// CHECK: return [[VAR_7_]] : tensor<?x10xf32>
// CHECK: }
}

// -----

// COM: Division by a scalar in case of dynamic dimensions.
func.func @test_div_unknown_scalar2(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = onnx.Constant dense<8.000000e+00> : tensor<f32>
%1 = "onnx.Div"(%0, %arg0) : (tensor<f32>, tensor<?x10xf32>) -> tensor<*xf32>
"func.return"(%1) : (tensor<*xf32>) -> ()

// CHECK-LABEL: func.func @test_div_unknown_scalar2
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x10xf32>) -> tensor<?x10xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<8.000000e+00> : tensor<f32>
// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x10xf32>) -> tensor<1xi64>
// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x10xf32>) -> tensor<1xi64>
// CHECK: [[VAR_3_:%.+]] = "onnx.Concat"([[VAR_1_]], [[VAR_2_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
// CHECK-DAG: [[VAR_4_:%.+]] = "zhigh.StickifiedConstantOfShape"([[VAR_3_]]) {layout = "2D", value = 8.000000e+00 : f32} : (tensor<2xi64>) -> tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>
// CHECK-DAG: [[VAR_5_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<?x10xf32>) -> tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>
// CHECK: [[VAR_6_:%.+]] = "zhigh.Div"([[VAR_4_]], [[VAR_5_]]) : (tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>
// CHECK: [[VAR_7_:%.+]] = "zhigh.Unstick"([[VAR_6_]]) : (tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x10xf32>
// CHECK: return [[VAR_7_]] : tensor<?x10xf32>
// CHECK: }
}

44 changes: 0 additions & 44 deletions test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/div.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -32,50 +32,6 @@ func.func @test_div_3ds(%arg0 : tensor<10x10x10xf32>, %arg1 : tensor<10x10x10xf3

// -----

// COM: Division by a scalar in case of dynamic dimensions.
func.func @test_div_unknown_scalar1(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = onnx.Constant dense<8.000000e+00> : tensor<f32>
%1 = "onnx.Div"(%arg0, %0) : (tensor<?x10xf32>, tensor<f32>) -> tensor<*xf32>
"func.return"(%1) : (tensor<*xf32>) -> ()

// CHECK-LABEL: func.func @test_div_unknown_scalar1
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x10xf32>) -> tensor<?x10xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<8.000000e+00> : tensor<f32>
// CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<?x10xf32>) -> tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>
// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x10xf32>) -> tensor<1xi64>
// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x10xf32>) -> tensor<1xi64>
// CHECK: [[VAR_4_:%.+]] = "onnx.Concat"([[VAR_2_]], [[VAR_3_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
// CHECK: [[VAR_5_:%.+]] = "zhigh.StickifiedConstantOfShape"([[VAR_4_]]) {layout = "2D", value = 8.000000e+00 : f32} : (tensor<2xi64>) -> tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>
// CHECK: [[VAR_6_:%.+]] = "zhigh.Div"([[VAR_1_]], [[VAR_5_]]) : (tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>
// CHECK: [[VAR_7_:%.+]] = "zhigh.Unstick"([[VAR_6_]]) : (tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x10xf32>
// CHECK: return [[VAR_7_]] : tensor<?x10xf32>
// CHECK: }
}

// -----

// COM: Division by a scalar in case of dynamic dimensions.
func.func @test_div_unknown_scalar2(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = onnx.Constant dense<8.000000e+00> : tensor<f32>
%1 = "onnx.Div"(%0, %arg0) : (tensor<f32>, tensor<?x10xf32>) -> tensor<*xf32>
"func.return"(%1) : (tensor<*xf32>) -> ()

// CHECK-LABEL: func.func @test_div_unknown_scalar2
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x10xf32>) -> tensor<?x10xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<8.000000e+00> : tensor<f32>
// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 0 : si64} : (tensor<?x10xf32>) -> tensor<1xi64>
// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Dim"([[PARAM_0_]]) {axis = 1 : si64} : (tensor<?x10xf32>) -> tensor<1xi64>
// CHECK: [[VAR_3_:%.+]] = "onnx.Concat"([[VAR_1_]], [[VAR_2_]]) {axis = 0 : si64} : (tensor<1xi64>, tensor<1xi64>) -> tensor<2xi64>
// CHECK-DAG: [[VAR_4_:%.+]] = "zhigh.StickifiedConstantOfShape"([[VAR_3_]]) {layout = "2D", value = 8.000000e+00 : f32} : (tensor<2xi64>) -> tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>
// CHECK-DAG: [[VAR_5_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<?x10xf32>) -> tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>
// CHECK: [[VAR_6_:%.+]] = "zhigh.Div"([[VAR_4_]], [[VAR_5_]]) : (tensor<?x?xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>
// CHECK: [[VAR_7_:%.+]] = "zhigh.Unstick"([[VAR_6_]]) : (tensor<?x10xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<?x10xf32>
// CHECK: return [[VAR_7_]] : tensor<?x10xf32>
// CHECK: }
}

// -----

// COM: Do not lower broadcasting onnx.Div to zHigh.
func.func @test_div_not_lowered_diff_shape(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10xf32>) -> tensor<*xf32> {
%0 = "onnx.Div"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10xf32>) -> tensor<*xf32>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA --EmitMLIR --printIR %s | FileCheck %s
// RUN: onnx-mlir --mcpu=z16 --maccel=NNPA --EmitMLIR --nnpa-enable-scalar-bcast-binary --printIR %s | FileCheck %s

// Check whether the compiler can remove unstick/stick so that the output of zdnn matmul is passed directly to zdnn div.
func.func @matmul_div(%arg0: tensor<?x12x?x64xf32>) -> tensor<?x?x?x?xf32> {
Expand Down
Loading