Fixing the location of DimAnalysis in onnx-to-zhigh pass and some rules in zhigh-to-onnx pass (onnx#2794)

* Fix a typo in device placement pass

Signed-off-by: Tung D. Le <tung@jp.ibm.com>

* Fix onnx-to-zhigh

Signed-off-by: Tung D. Le <tung@jp.ibm.com>

* fix lit tests

Signed-off-by: Tung D. Le <tung@jp.ibm.com>

* Clean up

Signed-off-by: Tung D. Le <tung@jp.ibm.com>

* cleanup

Signed-off-by: Tung D. Le <tung@jp.ibm.com>

---------

Signed-off-by: Tung D. Le <tung@jp.ibm.com>
tungld authored Apr 11, 2024
1 parent 230c79e commit 80a63f2
Showing 9 changed files with 103 additions and 48 deletions.
@@ -14,6 +14,7 @@ add_onnx_mlir_library(OMONNXToZHigh
OMNNPACompilerOptions
OMONNXOps
OMONNXToKrnl
+  OMShapeInferencePass
OMZHighOps

ACCEL_INCLUDE_DIRS PRIVATE
@@ -196,7 +196,7 @@ void DevicePlacementPass::runOnOperation() {
// Call ONNXToZHigh pass for lowering multiple ONNX ops at once to ZHigh.
// E.g. `onnx.ReLu (onnx.Conv)` to zhigh.Conv.
RewritePatternSet Patterns2(context);
-  getONNXToZHighOneOpPatterns(Patterns2);
+  getONNXToZHighMultipleOpPatterns(Patterns2);
(void)applyAnalysisConversion(module, target, std::move(Patterns2),
ConversionConfig{.legalizableOps = &legalizedOps2});

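For reference, the multiple-op patterns registered here fuse a small ONNX subgraph into a single ZHigh op, as the comment above notes for `onnx.ReLu (onnx.Conv)`. A minimal MLIR sketch of that case, with illustrative shapes borrowed from the relu-conv fusion lit test further down:

// Before: an onnx.Conv whose only user is an onnx.Relu (illustrative shapes).
%0 = "onnx.Conv"(%x, %w, %b) {kernel_shape = [2, 2]} : (tensor<5x3x32x32xf32>, tensor<2x3x2x2xf32>, tensor<2xf32>) -> tensor<5x2x31x31xf32>
%1 = "onnx.Relu"(%0) : (tensor<5x2x31x31xf32>) -> tensor<5x2x31x31xf32>
// After the multiple-op patterns: a single zhigh.Conv2D with the activation
// folded in (act_func = "ACT_RELU"), wrapped in zhigh.Stick/zhigh.Unstick
// layout conversions, as the updated lit test below checks.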
13 changes: 8 additions & 5 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp
@@ -22,6 +22,7 @@
#include "src/Dialect/ONNX/ONNXDimAnalysis.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
#include "src/Dialect/ONNX/Transforms/ShapeInference.hpp"

using namespace mlir;

@@ -328,16 +329,13 @@ void getONNXToZHighMultipleOpPatterns(RewritePatternSet &patterns) {
patterns.insert<replaceONNXMatMulAddPattern2>(context);
patterns.insert<replaceONNXReluConvPattern>(context);
patterns.insert<replaceONNXLogSoftmaxPattern>(context);
+  // Shape inference for newly-added operations.
+  getShapeInferencePatterns(patterns);
}

void ONNXToZHighLoweringPass::runOnOperation() {
ModuleOp module = getOperation();

-  // Run the unknown dimension analysis to help check equality of unknown
-  // dimensions at compile time.
-  onnx_mlir::DimAnalysis dimAnalysis(module);
-  dimAnalysis.analyze();

// The first thing to define is the conversion target. This will define the
// final target for this lowering.
ConversionTarget target(getContext());
@@ -363,6 +361,11 @@ void ONNXToZHighLoweringPass::runOnOperation() {
// It's ok to fail.
(void)applyPatternsAndFoldGreedily(module, std::move(combinedPatterns));

+  // Run the unknown dimension analysis to help check equality of unknown
+  // dimensions at compile time.
+  onnx_mlir::DimAnalysis dimAnalysis(module);
+  dimAnalysis.analyze();

// Single ONNX to ZHigh operation lowering.
RewritePatternSet patterns(&getContext());
onnx_mlir::getONNXToZHighOneOpPatterns(patterns);
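The moved block above is the core of this commit: the greedy multiple-op rewrite creates new ZHigh ops, the shape-inference patterns now registered with it refine the result types of those ops within the same rewrite, and DimAnalysis runs only afterwards, so it sees ranked types instead of tensor<*xf16>. A sketch of the effect, mirroring the MatMul+Add lit-test updates below (types copied from those tests):

// Previously the fused op still had an unranked result when the analysis ran:
%3 = "zhigh.MatMul"(%0, %1, %2) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16>
// Now shape inference fires as part of the same greedy rewrite, so DimAnalysis
// (and the CHECK lines) see the inferred type instead:
%3 = "zhigh.MatMul"(%0, %1, %2) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>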
2 changes: 2 additions & 0 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.cpp
@@ -59,6 +59,8 @@ void ZHighToONNXLoweringPass::runOnOperation() {

RewritePatternSet patterns(&getContext());
populateWithGenerated(patterns);
+  zhigh::ZHighStickOp::getCanonicalizationPatterns(patterns, &getContext());
+  zhigh::ZHighUnstickOp::getCanonicalizationPatterns(patterns, &getContext());

(void)applyPatternsAndFoldGreedily(function, std::move(patterns));
}
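Registering the ZHighStickOp/ZHighUnstickOp canonicalization patterns here complements the relaxed rewrite rules below in ZHighToONNX.td: when both operands of a binary ZHigh op are sticked, pattern 1 still fires but leaves an unstick-of-stick round-trip behind, which these canonicalizers can then fold away. A sketch under that assumption (the folding itself lives in the ops' canonicalizers, not in this diff):

// replaceZHighAddPattern1 alone would leave a stick/unstick round-trip on %y:
%s = "zhigh.Stick"(%y) {layout = "2D"} : (tensor<4x16xf32>) -> tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>
%u = "zhigh.Unstick"(%s) : (tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<4x16xf32>
%r = "onnx.Add"(%x, %u) : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>
// With the canonicalization patterns in the same set, the round-trip folds and
// the add reads %y directly:
%r = "onnx.Add"(%x, %y) : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>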
95 changes: 70 additions & 25 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/ZHighToONNX.td
@@ -37,62 +37,107 @@ def CreateONNXMaxOp : NativeCodeCall<"$_builder.create<ONNXMaxOp>($_loc, $0.getT
// ONNXAddOp %X = ZHighUnstickOp (ZHighAddOp (ZHighStickOp %X),
// (ZHighStickOp %Y))
//===----------------------------------------------------------------------===//
-def replaceZHighAddPattern : Pat<
-  (ZHighUnstickOp (ZHighAddOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
-  (ONNXAddOp $x, $y),
-  [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
+def replaceZHighAddPattern1 : Pat<
+  (ZHighUnstickOp (ZHighAddOp (ZHighStickOp:$s_x $x, $_), $y)),
+  (ONNXAddOp $x, (ZHighUnstickOp $y)),
+  [(NotBlockArgument:$x), (HasOneUse:$s_x)]
>;

+def replaceZHighAddPattern2 : Pat<
+  (ZHighUnstickOp (ZHighAddOp $x, (ZHighStickOp:$s_y $y, $_))),
+  (ONNXAddOp (ZHighUnstickOp $x), $y),
+  [(NotBlockArgument:$y), (HasOneUse:$s_y)]
+>;

//===----------------------------------------------------------------------===//
// ONNXMulOp %X = ZHighUnstickOp (ZHighMulOp (ZHighStickOp %X),
// (ZHighStickOp %Y))
//===----------------------------------------------------------------------===//
-def replaceZHighMulPattern : Pat<
-  (ZHighUnstickOp (ZHighMulOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
-  (ONNXMulOp $x, $y),
-  [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
+def replaceZHighMulPattern1 : Pat<
+  (ZHighUnstickOp (ZHighMulOp (ZHighStickOp:$s_x $x, $_), $y)),
+  (ONNXMulOp $x, (ZHighUnstickOp $y)),
+  [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
+  (addBenefit 1)
>;

+def replaceZHighMulPattern2 : Pat<
+  (ZHighUnstickOp (ZHighMulOp $x, (ZHighStickOp:$s_y $y, $_))),
+  (ONNXMulOp (ZHighUnstickOp $x), $y),
+  [(NotBlockArgument:$y), (HasOneUse:$s_y)], [],
+  (addBenefit 0)
+>;

//===----------------------------------------------------------------------===//
// ONNXSubOp %X = ZHighUnstickOp (ZHighSubOp (ZHighStickOp %X),
// (ZHighStickOp %Y))
//===----------------------------------------------------------------------===//
-def replaceZHighSubPattern : Pat<
-  (ZHighUnstickOp (ZHighSubOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
-  (ONNXSubOp $x, $y),
-  [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
+def replaceZHighSubPattern1 : Pat<
+  (ZHighUnstickOp (ZHighSubOp (ZHighStickOp:$s_x $x, $_), $y)),
+  (ONNXSubOp $x, (ZHighUnstickOp $y)),
+  [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
+  (addBenefit 1)
>;

+def replaceZHighSubPattern2 : Pat<
+  (ZHighUnstickOp (ZHighSubOp $x, (ZHighStickOp:$s_y $y, $_))),
+  (ONNXSubOp (ZHighUnstickOp $x), $y),
+  [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
+  (addBenefit 0)
+>;

//===----------------------------------------------------------------------===//
// ONNXDivOp %X = ZHighUnstickOp (ZHighDivOp (ZHighStickOp
// %X),(ZHighStickOp %Y))
// Note: turn off this pattern since NNPA is faster at this moment.
//===----------------------------------------------------------------------===//
-// def replaceZHighDivPattern : Pat<
-//   (ZHighUnstickOp (ZHighDivOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
-//   (ONNXDivOp $x, $y),
-//   [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
-// >;
+//def replaceZHighDivPattern1 : Pat<
+//  (ZHighUnstickOp (ZHighDivOp (ZHighStickOp:$s_x $x, $_), $y)),
+//  (ONNXDivOp $x, (ZHighUnstickOp $y)),
+//  [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
+//  (addBenefit 1)
+//>;
+//
+//def replaceZHighDivPattern2 : Pat<
+//  (ZHighUnstickOp (ZHighDivOp $x, (ZHighStickOp:$s_y $y, $_))),
+//  (ONNXDivOp (ZHighUnstickOp $x), $y),
+//  [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
+//  (addBenefit 0)
+//>;

//===----------------------------------------------------------------------===//
// ONNXMinOp %X = ZHighUnstickOp (ZHighMinOp (ZHighStickOp %X),
// (ZHighStickOp %Y))
//===----------------------------------------------------------------------===//
-def replaceZHighMinPattern : Pat<
-  (ZHighUnstickOp:$u (ZHighMinOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
-  (CreateONNXMinOp $u, $x, $y),
-  [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
+def replaceZHighMinPattern1 : Pat<
+  (ZHighUnstickOp:$u (ZHighMinOp (ZHighStickOp:$s_x $x, $_), $y)),
+  (CreateONNXMinOp $u, $x, (ZHighUnstickOp $y)),
+  [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
+  (addBenefit 1)
>;

+def replaceZHighMinPattern2 : Pat<
+  (ZHighUnstickOp:$u (ZHighMinOp $x, (ZHighStickOp:$s_y $y, $_))),
+  (CreateONNXMinOp $u, (ZHighUnstickOp $x), $y),
+  [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
+  (addBenefit 0)
+>;

//===----------------------------------------------------------------------===//
// ONNXMaxOp %X = ZHighUnstickOp (ZHighMaxOp (ZHighStickOp %X),
// (ZHighStickOp %Y))
//===----------------------------------------------------------------------===//
-def replaceZHighMaxPattern : Pat<
-  (ZHighUnstickOp:$u (ZHighMaxOp (ZHighStickOp:$s_x $x, $_), (ZHighStickOp:$s_y $y, $_))),
-  (CreateONNXMaxOp $u, $x, $y),
-  [(NotBlockArgument:$x), (HasOneUse:$s_x), (HasOneUse:$s_y)]
+def replaceZHighMaxPattern1 : Pat<
+  (ZHighUnstickOp:$u (ZHighMaxOp (ZHighStickOp:$s_x $x, $_), $y)),
+  (CreateONNXMaxOp $u, $x, (ZHighUnstickOp $y)),
+  [(NotBlockArgument:$x), (HasOneUse:$s_x)], [ ],
+  (addBenefit 1)
>;

+def replaceZHighMaxPattern2 : Pat<
+  (ZHighUnstickOp:$u (ZHighMaxOp $x, (ZHighStickOp:$s_y $y, $_))),
+  (CreateONNXMaxOp $u, (ZHighUnstickOp $x), $y),
+  [(NotBlockArgument:$y), (HasOneUse:$s_y)], [ ],
+  (addBenefit 0)
+>;
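Unlike Add/Mul/Sub, the Min and Max rewrites build the ONNX op through the CreateONNXMinOp/CreateONNXMaxOp NativeCodeCall helpers defined at the top of this file, passing the matched zhigh.Unstick value as $u; judging from the (truncated) builder call in the hunk header above, the new op reuses that value's type, presumably because ONNXMinOp/ONNXMaxOp are variadic and DRR cannot infer the result type on its own. A sketch of pattern 1 for Min, with illustrative types:

// Before: zhigh.Unstick(zhigh.Min(zhigh.Stick(%x), %y)), with the unstick bound to $u.
// After replaceZHighMinPattern1: %y is unsticked and onnx.Min takes $u's type.
%uy = "zhigh.Unstick"(%y) : (tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<4x16xf32>
%r = "onnx.Min"(%x, %uy) : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>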

//===----------------------------------------------------------------------===//
@@ -5,7 +5,8 @@
func.func @test_instrument_add_onnx_zhigh(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x1xf32>) -> tensor<*xf32> {
%0 = "onnx.Add"(%arg0, %arg1) {onnx_node_name = "onnx.Add1"} : (tensor<10x10xf32>, tensor<10x1xf32>) -> tensor<*xf32>
%1 = "onnx.Add"(%arg0, %0) {onnx_node_name = "onnx.Add2"} : (tensor<10x10xf32>, tensor<*xf32>) -> tensor<*xf32>
"onnx.Return"(%1) : (tensor<*xf32>) -> ()
%2 = "onnx.Relu"(%1) {onnx_node_name = "onnx.Relu"} : (tensor<*xf32>) -> tensor<*xf32>
"onnx.Return"(%2) : (tensor<*xf32>) -> ()
}

// CHECK-LABEL: func.func @test_instrument_add_onnx_zhigh
@@ -21,6 +22,9 @@ func.func @test_instrument_add_onnx_zhigh
// CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Add", tag = 5 : i64} : () -> ()
// CHECK: "zhigh.Add"
// CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Add", tag = 6 : i64} : () -> ()
// CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Relu", tag = 5 : i64} : () -> ()
// CHECK: "zhigh.Relu"
// CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Relu", tag = 6 : i64} : () -> ()
// CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Unstick", tag = 5 : i64} : () -> ()
// CHECK: "zhigh.Unstick"
// CHECK: "krnl.runtime_instrument"() {opName = "zhigh.Unstick", tag = 6 : i64} : () -> ()
@@ -107,8 +107,8 @@ func.func @test_fuse_onnx_relu_conv2d
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_3_:%.+]] = "zhigh.Stick"([[VAR_2_]]) {layout = "HWCK"} : (tensor<2x2x3x2xf32>) -> tensor<2x2x3x2xf16, #zhigh.layout<{dataLayout = "HWCK"}>>
// CHECK-DAG: [[VAR_4_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<2xf32>) -> tensor<2xf16, #zhigh.layout<{dataLayout = "1D"}>>
-// CHECK: [[VAR_5_:%.+]] = "zhigh.Conv2D"([[VAR_1_]], [[VAR_3_]], [[VAR_4_]]) {act_func = "ACT_RELU", kernel_shape = [2, 2], padding_type = "VALID_PADDING", strides = [1, 1]} : (tensor<5x32x32x3xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<2x2x3x2xf16, #zhigh.layout<{dataLayout = "HWCK"}>>, tensor<2xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16>
-// CHECK: [[VAR_6_:%.+]] = "zhigh.Unstick"([[VAR_5_]]) : (tensor<*xf16>) -> tensor<5x2x31x31xf32>
+// CHECK: [[VAR_5_:%.+]] = "zhigh.Conv2D"([[VAR_1_]], [[VAR_3_]], [[VAR_4_]]) {act_func = "ACT_RELU", kernel_shape = [2, 2], padding_type = "VALID_PADDING", strides = [1, 1]} : (tensor<5x32x32x3xf16, #zhigh.layout<{dataLayout = "NHWC"}>>, tensor<2x2x3x2xf16, #zhigh.layout<{dataLayout = "HWCK"}>>, tensor<2xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<5x31x31x2xf16, #zhigh.layout<{dataLayout = "NHWC"}>>
+// CHECK: [[VAR_6_:%.+]] = "zhigh.Unstick"([[VAR_5_]]) : (tensor<5x31x31x2xf16, #zhigh.layout<{dataLayout = "NHWC"}>>) -> tensor<5x2x31x31xf32>
// CHECK: return [[VAR_6_]] : tensor<5x2x31x31xf32>
// CHECK: }
}
@@ -79,8 +79,8 @@ func.func @test_onnx_matmul_add_to_zhigh_1D_bias(
// CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<4x8xf32>) -> tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>
// CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<8x16xf32>) -> tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>
// CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<16xf32>) -> tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>
-// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16>
-// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<*xf16>) -> tensor<4x16xf32>
+// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>
+// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<4x16xf32>
// CHECK: return [[VAR_4_]] : tensor<4x16xf32>
// CHECK: }
// CHECK-NOT: "onnx.Add"
@@ -105,8 +105,8 @@ func.func @test_onnx_matmul_add_to_zhigh_1D_bias_normalized(
// CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<4x8xf32>) -> tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>
// CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<8x16xf32>) -> tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>
// CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "1D"} : (tensor<16xf32>) -> tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>
-// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<*xf16>
-// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<*xf16>) -> tensor<4x16xf32>
+// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<4x8xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<8x16xf16, #zhigh.layout<{dataLayout = "2D"}>>, tensor<16xf16, #zhigh.layout<{dataLayout = "1D"}>>) -> tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>
+// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<4x16xf16, #zhigh.layout<{dataLayout = "2D"}>>) -> tensor<4x16xf32>
// CHECK: return [[VAR_4_]] : tensor<4x16xf32>
// CHECK: }
// CHECK-NOT: "onnx.Add"
20 changes: 10 additions & 10 deletions test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/softmax.mlir
@@ -39,11 +39,11 @@ func.func @test_onnx_logsoftmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {

// CHECK-LABEL: func @test_onnx_logsoftmax
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> {
-// CHECK: [[VAR_0_:%.+]] = "onnx.UnsqueezeV11"([[PARAM_0_]]) {axes = [0]} : (tensor<10x10xf32>) -> tensor<*xf32>
-// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "3DS"} : (tensor<*xf32>) -> tensor<*xf16>
-// CHECK: [[VAR_2_:%.+]] = "zhigh.Softmax"([[VAR_1_]]) {act_func = "ACT_LOG"} : (tensor<*xf16>) -> tensor<*xf16>
-// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<*xf16>) -> tensor<*xf32>
-// CHECK: [[VAR_4_:%.+]] = "onnx.SqueezeV11"([[VAR_3_]]) {axes = [0]} : (tensor<*xf32>) -> tensor<10x10xf32>
+// CHECK: [[VAR_0_:%.+]] = "onnx.UnsqueezeV11"([[PARAM_0_]]) {axes = [0]} : (tensor<10x10xf32>) -> tensor<1x10x10xf32>
+// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "3DS"} : (tensor<1x10x10xf32>) -> tensor<1x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK: [[VAR_2_:%.+]] = "zhigh.Softmax"([[VAR_1_]]) {act_func = "ACT_LOG"} : (tensor<1x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<1x10x10xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x10x10xf32>
+// CHECK: [[VAR_4_:%.+]] = "onnx.SqueezeV11"([[VAR_3_]]) {axes = [0]} : (tensor<1x10x10xf32>) -> tensor<10x10xf32>
// CHECK: return [[VAR_4_]] : tensor<10x10xf32>
// CHECK: }
}
@@ -57,11 +57,11 @@ func.func @test_onnx_logsoftmax_dyn(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {

// CHECK-LABEL: func @test_onnx_logsoftmax_dyn
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK: [[VAR_0_:%.+]] = "onnx.UnsqueezeV11"([[PARAM_0_]]) {axes = [0]} : (tensor<?x?xf32>) -> tensor<*xf32>
-// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "3DS"} : (tensor<*xf32>) -> tensor<*xf16>
-// CHECK: [[VAR_2_:%.+]] = "zhigh.Softmax"([[VAR_1_]]) {act_func = "ACT_LOG"} : (tensor<*xf16>) -> tensor<*xf16>
-// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<*xf16>) -> tensor<*xf32>
-// CHECK: [[VAR_4_:%.+]] = "onnx.SqueezeV11"([[VAR_3_]]) {axes = [0]} : (tensor<*xf32>) -> tensor<?x?xf32>
+// CHECK: [[VAR_0_:%.+]] = "onnx.UnsqueezeV11"([[PARAM_0_]]) {axes = [0]} : (tensor<?x?xf32>) -> tensor<1x?x?xf32>
+// CHECK: [[VAR_1_:%.+]] = "zhigh.Stick"([[VAR_0_]]) {layout = "3DS"} : (tensor<1x?x?xf32>) -> tensor<1x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK: [[VAR_2_:%.+]] = "zhigh.Softmax"([[VAR_1_]]) {act_func = "ACT_LOG"} : (tensor<1x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK: [[VAR_3_:%.+]] = "zhigh.Unstick"([[VAR_2_]]) : (tensor<1x?x?xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x?x?xf32>
+// CHECK: [[VAR_4_:%.+]] = "onnx.SqueezeV11"([[VAR_3_]]) {axes = [0]} : (tensor<1x?x?xf32>) -> tensor<?x?xf32>
// CHECK: return [[VAR_4_]] : tensor<?x?xf32>
// CHECK: }
}