Merge pull request onnx#128 from C-P2PN897/merge-main-5aca454
Merge onnx/onnx-mlir 5aca454 into zosdev/onnx-mlir metis
cjvolzka authored and GitHub Enterprise committed Sep 18, 2023
2 parents 9bbec14 + 1b89604 commit ea6d77e
Showing 28 changed files with 114 additions and 31 deletions.
3 changes: 2 additions & 1 deletion .azure-pipelines/Windows-CI.yml
@@ -11,7 +11,8 @@ parameters:

jobs:
- job: Build_onnx_mlir_Windows
- timeoutInMinutes: 240
+ # 4h timeout is sometimes a tiny bit short when llvm-project is rebuilt
+ timeoutInMinutes: 270
pool:
vmImage: 'windows-2019'

2 changes: 1 addition & 1 deletion docs/BuildOnLinuxOSX.md
@@ -15,7 +15,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
``` bash
git clone -n https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX-MLIR.
- cd llvm-project && git checkout 91088978d712cd7b33610c59f69d87d5a39e3113 && cd ..
+ cd llvm-project && git checkout 4acc3ffbb0af5631bc7916aeff3570f448899647 && cd ..
```
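(`git clone -n` clones without checking out a working tree, so the pinned commit above is the only checkout that happens.)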

[same-as-file]: <> (utils/build-mlir.sh)
2 changes: 1 addition & 1 deletion docs/BuildOnWindows.md
@@ -52,7 +52,7 @@ Install MLIR (as a part of LLVM-Project):
```shell
git clone -n https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX-MLIR.
- cd llvm-project && git checkout 91088978d712cd7b33610c59f69d87d5a39e3113 && cd ..
+ cd llvm-project && git checkout 4acc3ffbb0af5631bc7916aeff3570f448899647 && cd ..
```

[same-as-file]: <> (utils/build-mlir.cmd)
2 changes: 1 addition & 1 deletion docs/Testing.md
@@ -36,7 +36,7 @@ The all_test_names.txt is automatically generated with command "make check-onnx-

### Adding ONNX-supported test cases to the current set of backend tests

- When the ONNX-to-Krnl conversion of an operator is added, the corresponding backend tests for this operator should be added to test.py. The available test cases can be found in `third_part/onnx/onnx/backend/test/case/node`. You can identify new tests by looking for the new operator in `test/backend/all_test_names.txt`. Once you have located new tests, you may add the new tests in the `test/backend/inference_backend.py.` Please note to add suffix `_cpu` to the onnx test name. Associated with the test, you can define how to run the tests for the new operator. For example:
+ When the ONNX-to-Krnl conversion of an operator is added, the corresponding backend tests for this operator should be added to test.py. The available test cases can be found in `third_party/onnx/onnx/backend/test/case/node`. You can identify new tests by looking for the new operator in `test/backend/all_test_names.txt`. Once you have located new tests, you may add the new tests in the `test/backend/inference_backend.py.` Please note to add suffix `_cpu` to the onnx test name. Associated with the test, you can define how to run the tests for the new operator. For example:
```
"test_and2d_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
```
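As an interpretation (the comments in `test/backend/inference_backend.py` are authoritative): the keys `STATIC_SHAPE`, `DYNAMIC_SHAPE`, and `CONSTANT_INPUT` select which compilation modes the test runs under; in `DYNAMIC_SHAPE:{-1:{-1}}` the outer `-1` appears to mean every input and the inner `{-1}` every dimension made dynamic, while `CONSTANT_INPUT:{-1}` marks all inputs as candidates for constant folding at compile time.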
4 changes: 2 additions & 2 deletions docs/UpdatingLLVMCommit.md
@@ -2,7 +2,7 @@

# Updating the LLVM commit or MLIR-HLO submodule

- ONNX-MLIR depends on `llvm-project` (among various other projects such as `mlir-hlo`). The `llvm-project` dependency is captured in [utils/clone-mlir.sh](clone-mlir.sh). `mlir-hlo` is a submodule found in the `third_party` directory.
+ ONNX-MLIR depends on `llvm-project` (among various other projects such as `mlir-hlo`). The `llvm-project` dependency is captured in [../utils/clone-mlir.sh](clone-mlir.sh). `mlir-hlo` is a submodule found in the `third_party` directory.

We plan to update `llvm-project` a couple of times a month in order to keep up-to-date with the advancements made in `mlir`, but also to decrease the complexity of each update. There is currently no plan to update `mlir-hlo` on any given schedule, though for a specific LLVM update it may be necessary to also update the `mlir-hlo` submodule for the build to continue working correctly. This is because `mlir-hlo` itself also has a dependency on `mlir`.

@@ -17,7 +17,7 @@ We've started an update rotation that is described [here](https://github.com/onn
## What is the update process?

1. **Lookup green commit hashes**: From the Github issue https://github.com/llvm/torch-mlir/issues/1178, find the LLVM and MLIR-HLO green commits for the week when ONNX-MLIR is being updated.
- 2. **Update the `llvm-project` commit**: Update the LLVM commit referenced in the source tree to the green commit hash for the LLVM project from Step 1. The current locations that need to be updated are [utils/clone-mlir.sh](clone-mlir.sh), [docs/BuildOnLinuxOSX.md](BuildOnLinuxOSX.md) and [docs/BuildOnWindows.md](BuildOnWindows.md).
+ 2. **Update the `llvm-project` commit**: Update the LLVM commit referenced in the source tree to the green commit hash for the LLVM project from Step 1. The current locations that need to be updated are [utils/clone-mlir.sh](../utils/clone-mlir.sh), [docs/BuildOnLinuxOSX.md](BuildOnLinuxOSX.md) and [docs/BuildOnWindows.md](BuildOnWindows.md).
3. **Update the `mlir-hlo` submodule**: In the `third-party/mlir-hlo` directory, run `git fetch` followed by `git checkout <mlir-hlo-commit-hash>` (where `<mlir-hlo-commit-hash>` is the green commit hash for the MLIR-HLO project from Step 1).
4. **Rebuild and test ONNX-MLIR**: This might involve fixing various API breakages introduced upstream (they are likely unrelated to what you are working on). If these fixes are too complex, please file a work-in-progress PR explaining the issues you are running into asking for help so that someone from the community can help.
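
A minimal sketch of Steps 2 and 3, assuming the commands run from the onnx-mlir source root and that `LLVM_GREEN` and `MHLO_GREEN` hold the green commit hashes from Step 1 (both variable names are illustrative):

```bash
# Step 2: pin llvm-project to the green commit. The same hash must also be
# updated in utils/clone-mlir.sh, docs/BuildOnLinuxOSX.md, and
# docs/BuildOnWindows.md.
cd llvm-project
git fetch origin
git checkout "$LLVM_GREEN"
cd ..

# Step 3: move the mlir-hlo submodule to its green commit.
cd third_party/mlir-hlo
git fetch
git checkout "$MHLO_GREEN"
cd ../..
```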

1 change: 1 addition & 0 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td
@@ -29,6 +29,7 @@ def ZHigh_Dialect : Dialect {
let summary = "A high-level dialect for the ONNX NNPA accelerator ISA.";
let cppNamespace = "::onnx_mlir::zhigh";
let useDefaultAttributePrinterParser = 1;
+ let usePropertiesForAttributes = 0;
}
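Recent MLIR flipped `usePropertiesForAttributes` to default to true, storing inherent attributes as native properties instead of in the attribute dictionary; explicitly setting it to 0 here, and in the ZLow, Krnl, and ONNX dialects below, appears intended to keep the old attribute-based behavior across this LLVM bump.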

//===----------------------------------------------------------------------===//
1 change: 1 addition & 0 deletions src/Accelerators/NNPA/Dialect/ZLow/ZLow.td
@@ -20,6 +20,7 @@ def ZLow_Dialect : Dialect {
let name = "zlow";
let summary = "A low-level dialect for the ONNX NNPA accelerator ISA.";
let cppNamespace = "::onnx_mlir::zlow";
+ let usePropertiesForAttributes = 0;
}

// Base class for ZLow dialect operations. This operation inherits from the
1 change: 0 additions & 1 deletion src/Builder/CMakeLists.txt
@@ -11,7 +11,6 @@ add_onnx_mlir_library(OMBuilder
ModelInputShaper.cpp

LINK_LIBS PUBLIC
- OMCompilerOptions
OMHasOnnxSubgraphOpInterface
OMONNXOps
OMResultTypeInferenceOpInterface
3 changes: 1 addition & 2 deletions src/Builder/FrontendDialectTransformer.cpp
@@ -31,7 +31,6 @@
#include "src/Builder/ImportONNXUtils.hpp"
#include "src/Builder/ModelInputShaper.hpp"
#include "src/Builder/SymbolTable.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
@@ -165,7 +164,7 @@ class FrontendGenImpl {
opset_map_ = GetOpsetImportsFromProto(model); // Which opsets to use.
in_model_functions_ = GetModelLocalFunctions(model);
importGraph(model.graph());
- if (VerboseOutput) {
+ if (options_.verboseOutput) {
llvm::outs()
<< "The ONNX model has " << num_of_parameters_
<< " elements in its initializers. This value would be close to and "
1 change: 1 addition & 0 deletions src/Builder/FrontendDialectTransformer.hpp
@@ -39,6 +39,7 @@ namespace onnx_mlir {
* Options to control the translation of an ONNX model to ONNX-MLIR.
*/
struct ImportOptions {
+ bool verboseOutput = false;
// Use types/shapes in the input-model for translation (for intermediate
// variables)
bool useOnnxModelTypes = false;
1 change: 1 addition & 0 deletions src/Compiler/CMakeLists.txt
@@ -68,6 +68,7 @@ add_onnx_mlir_library(OMCompilerDialects
OMKrnlOps
OMONNXOps
MLIRIR
+ MLIROpenMPDialect
)

add_onnx_mlir_library(OMCompilerPasses
2 changes: 1 addition & 1 deletion src/Compiler/CompilerPasses.cpp
@@ -90,7 +90,7 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU) {
// Dynamic iterate in ONNXOpTransformPass
pm.addPass(onnx_mlir::createONNXOpTransformPass(onnxOpTransformThreshold,
onnxOpTransformReport, targetCPU,
- enableSimdDataLayout && !disableSimdOption));
+ enableSimdDataLayout && !disableSimdOption, enableConvOptPass));
} else {
// Statically add extra passes
for (int i = 0; i < repeatOnnxTransform; i++) {
32 changes: 32 additions & 0 deletions src/Compiler/CompilerUtils.cpp
@@ -14,6 +14,7 @@

#include "CompilerUtils.hpp"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
@@ -35,6 +36,7 @@

#include "src/Accelerators/Accelerator.hpp"
#include "src/Builder/FrontendDialectTransformer.hpp"
#include "src/Builder/ModelInputShaper.hpp"
#include "src/Compiler/CompilerDialects.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Compiler/CompilerPasses.hpp"
@@ -183,6 +185,35 @@ static void loadMLIR(std::string inputFilename, mlir::MLIRContext &context,
llvm::errs() << "Error can't load file " << inputFilename << "\n";
exit(1);
}

+ // Set shape information if required.
+ // Only set shape if the module has a single function.
+ uint64_t numOfFuncOp = 0;
+ func::FuncOp funcOp;
+ module->walk([&](func::FuncOp f) {
+ funcOp = f;
+ numOfFuncOp++;
+ });
+ if ((numOfFuncOp == 1) && (!shapeInformation.empty())) {
+ ModelInputShaper modelInputShaper_;
+ modelInputShaper_.setShapeInformation(shapeInformation);
+ auto funcType = dyn_cast<FunctionType>(funcOp.getFunctionType());
+ ArrayRef<Type> argTypes = funcType.getInputs();
+ SmallVector<Type, 4> newArgTypes;
+ for (uint64_t i = 0; i < argTypes.size(); ++i) {
+ Type argTy = argTypes[i];
+ // Get user's shape information.
+ argTy = modelInputShaper_.reshape(i, argTy);
+ // Update the arguments.
+ funcOp.getBody().back().getArgument(i).setType(argTy);
+ newArgTypes.emplace_back(argTy);
+ }
+ // Update the function type.
+ FunctionType newType =
+ FunctionType::get(&context, newArgTypes, funcType.getResults());
+ ConversionPatternRewriter rewriter(&context);
+ rewriter.updateRootInPlace(funcOp, [&] { funcOp.setType(newType); });
+ }
}
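This added block lets `--shapeInformation` also take effect when onnx-mlir is fed an `.mlir` file directly; previously the flag was only forwarded to the ONNX importer (see `options.shapeInformation` below). A usage sketch, borrowing the syntax exercised by the new lit test near the end of this diff, where `0:3x-1` forces argument 0 to shape `3x?` and `-1` marks a dynamic dimension (the input file name is illustrative):

```bash
onnx-mlir --EmitONNXIR --shapeInformation=0:3x-1 --printIR model.mlir
```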

// Tailor LLVMIR to add features that cannot be done with MLIR LLVMIR.
@@ -589,6 +620,7 @@ int processInputFile(StringRef inputFilename, mlir::MLIRContext &context,

if (inputIsSTDIN || inputIsONNX || inputIsONNXText || inputIsJSON) {
ImportOptions options;
+ options.verboseOutput = VerboseOutput;
options.useOnnxModelTypes = useOnnxModelTypes;
options.invokeOnnxVersionConverter = invokeOnnxVersionConverter;
options.shapeInformation = shapeInformation;
1 change: 0 additions & 1 deletion src/Conversion/KrnlToAffine/CMakeLists.txt
@@ -13,7 +13,6 @@ add_onnx_mlir_library(OMKrnlToAffine

LINK_LIBS PUBLIC
OMSpecializedKernelOpInterface
- OMCompilerOptions
OMONNXOps
OMSupport
MLIRTransforms
3 changes: 0 additions & 3 deletions src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp
@@ -25,7 +25,6 @@
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "llvm/Support/Debug.h"

#include "src/Compiler/CompilerOptions.hpp"
#include "src/Conversion/KrnlToAffine/ConvertKrnlToAffine.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Dialect/Mlir/VectorMachineSupport.hpp"
@@ -712,7 +711,6 @@ void ConvertKrnlToAffinePass::runOnOperation() {

MLIRContext *ctx = &getContext();
OpBuilder builder(ctx);
- VectorMachineSupport::setGlobalVectorMachineSupport(march, mcpu, "");

const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
LowerToLLVMOptions options(
@@ -821,7 +819,6 @@ void ConvertKrnlToAffinePass::runOnOperation() {
}

delete currUnrollAndJamList;
- VectorMachineSupport::clearGlobalVectorMachineSupport();
}

std::unique_ptr<Pass> createConvertKrnlToAffinePass() {
1 change: 0 additions & 1 deletion src/Conversion/ONNXToKrnl/CMakeLists.txt
@@ -79,7 +79,6 @@ add_onnx_mlir_library(OMONNXToKrnl

LINK_LIBS PUBLIC
OMAccelerator
- OMCompilerOptions
OMONNXOps
OMSupport
MLIRFuncDialect
4 changes: 0 additions & 4 deletions src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
@@ -19,7 +19,6 @@

#include "src/Accelerators/Accelerator.hpp"
#include "src/Builder/ModelInputShaper.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/Mlir/VectorMachineSupport.hpp"

@@ -335,8 +334,6 @@ struct FrontendToKrnlLoweringPass

void FrontendToKrnlLoweringPass::runOnOperation() {
ModuleOp module = getOperation();
- // Define vector machine.
- VectorMachineSupport::setGlobalVectorMachineSupport(march, mcpu, "");
// Perform dim analysis (useful for SIMD but also to avoid broadcast
// expressions in index access patterns).
DimAnalysis *dimAnalysis = new DimAnalysis(module);
@@ -441,7 +438,6 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
}
- VectorMachineSupport::clearGlobalVectorMachineSupport();
delete dimAnalysis;
}

1 change: 1 addition & 0 deletions src/Dialect/Krnl/Krnl.td
@@ -29,6 +29,7 @@ include "src/Interface/SpecializedKernelOpInterface.td"
def Krnl_Dialect : Dialect {
let name = "krnl";
let cppNamespace = "::mlir";
+ let usePropertiesForAttributes = 0;
let useDefaultTypePrinterParser = 1;
let dependentDialects = [
"affine::AffineDialect",
1 change: 1 addition & 0 deletions src/Dialect/ONNX/ONNX.td
@@ -39,6 +39,7 @@ def ONNX_Dialect : Dialect {
let cppNamespace = "::mlir";
let useDefaultTypePrinterParser = 1;
let useDefaultAttributePrinterParser = 0;
+ let usePropertiesForAttributes = 0;
let dependentDialects = ["func::FuncDialect"];
let hasConstantMaterializer = 1;
let extraClassDeclaration = [{
6 changes: 5 additions & 1 deletion src/Dialect/ONNX/ONNXOps/OpHelper.cpp
@@ -411,7 +411,11 @@ ArrayAttr createArrayAttrFromConstantOp(ONNXConstantOp constOp) {
DenseElementsAttr createDenseElementsAttrFromFloatAttr(
PatternRewriter &rewriter, Type elementType, FloatAttr attr) {
auto tensorType = RankedTensorType::get({1}, elementType);
- return DenseElementsAttr::get(tensorType, {attr.getValue()});
+ auto ftype = cast<FloatType>(elementType);
+ APFloat f = attr.getValue();
+ bool ignored;
+ f.convert(ftype.getFloatSemantics(), APFloat::rmNearestTiesToEven, &ignored);
+ return DenseElementsAttr::get(tensorType, {f});
}
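The explicit conversion matters because `DenseElementsAttr::get` expects the `APFloat`'s semantics to match the element type; building an f16 attribute directly from an f32 `FloatAttr` would otherwise trip an assertion. Concretely, rounding the batch-norm epsilon `1.00000007E-5` to f16 with round-nearest-ties-to-even gives about `1.001360e-05`, the constant checked by the new canonicalization test at the end of this diff.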

//===----------------------------------------------------------------------===//
5 changes: 3 additions & 2 deletions src/Pass/Passes.hpp
@@ -31,8 +31,9 @@ std::unique_ptr<mlir::Pass> createScrubDisposablePass(bool closeAfter = true);

/// Pass for ONNX graph level optimization
std::unique_ptr<mlir::Pass> createONNXOpTransformPass();
- std::unique_ptr<mlir::Pass> createONNXOpTransformPass(
- int threshold, bool report, bool targetCPU, bool enableSimdDataLayoutOpt);
+ std::unique_ptr<mlir::Pass> createONNXOpTransformPass(int threshold,
+ bool report, bool targetCPU, bool enableSimdDataLayoutOpt,
+ bool enableConvOptPass);

/// Pass for rewriting inside frontend dialect.
std::unique_ptr<mlir::Pass> createDecomposeONNXToONNXPass(
1 change: 0 additions & 1 deletion src/Transform/ONNX/CMakeLists.txt
@@ -38,7 +38,6 @@ add_onnx_mlir_library(OMShapeInferencePass
OMShapeInferenceOpInterface
MLIRFuncDialect
MLIRPass
- OMCompilerOptions
OMShapeInference
)

16 changes: 10 additions & 6 deletions src/Transform/ONNX/ONNXOpTransformPass.cpp
@@ -17,7 +17,6 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Pass/Passes.hpp"

@@ -47,17 +46,21 @@ struct ONNXOpTransformPass : public mlir::PassWrapper<ONNXOpTransformPass,
"onnx-op-transform-simd-data-layout",
llvm::cl::desc("Enable SIMD data layout opt in op transform passes."),
llvm::cl::init(false)};
+ Option<bool> enableConvOptPass{*this, "enable-conv-opt-pass",
+ llvm::cl::desc("Enable the ConvOptPass. Default is true."),
+ llvm::cl::init(true)};

ONNXOpTransformPass() = default;
ONNXOpTransformPass(const ONNXOpTransformPass &pass)
: mlir::PassWrapper<ONNXOpTransformPass,
OperationPass<mlir::ModuleOp>>() {}
ONNXOpTransformPass(int threshold, bool report, bool targetCPU,
- bool enableSimdDataLayoutOpt) {
+ bool enableSimdDataLayoutOpt, bool enableConvOptPass) {
this->onnxOpTransformThreshold = threshold;
this->onnxOpTransformReport = report;
this->onnxOpTransformTargetCPU = targetCPU;
this->onnxOpTransformEnableSimdDataLayout = enableSimdDataLayoutOpt;
+ this->enableConvOptPass = enableConvOptPass;
}

void runOnOperation() final;
@@ -79,7 +82,7 @@ void ONNXOpTransformPass::runOnOperation() {
dynamicPM.addNestedPass<func::FuncOp>(
onnx_mlir::createShapeInferencePass());
// Convolution Optimization currently only for CPU.
- if (onnxOpTransformTargetCPU && onnx_mlir::enableConvOptPass) {
+ if (onnxOpTransformTargetCPU && enableConvOptPass) {
dynamicPM.addNestedPass<func::FuncOp>(
onnx_mlir::createConvOptONNXToONNXPass(
onnxOpTransformEnableSimdDataLayout));
@@ -116,8 +119,9 @@ std::unique_ptr<mlir::Pass> onnx_mlir::createONNXOpTransformPass() {
return std::make_unique<ONNXOpTransformPass>();
}

- std::unique_ptr<mlir::Pass> onnx_mlir::createONNXOpTransformPass(
- int threshold, bool report, bool targetCPU, bool enableSimdDataLayoutOpt) {
+ std::unique_ptr<mlir::Pass> onnx_mlir::createONNXOpTransformPass(int threshold,
+ bool report, bool targetCPU, bool enableSimdDataLayoutOpt,
+ bool enableConvOptPass) {
return std::make_unique<ONNXOpTransformPass>(
- threshold, report, targetCPU, enableSimdDataLayoutOpt);
+ threshold, report, targetCPU, enableSimdDataLayoutOpt, enableConvOptPass);
}
11 changes: 11 additions & 0 deletions test/mlir/driver/shape_information.mlir
@@ -0,0 +1,11 @@
// RUN: onnx-mlir --EmitONNXIR --shapeInformation=0:3x-1 --printIR %s | FileCheck %s

module {
func.func @main_graph(%arg0: tensor<3x2xi64>, %arg1: tensor<3x2xi64>) -> tensor<3x2xi64> {
%0 = "onnx.Add"(%arg0, %arg1) : (tensor<3x2xi64>, tensor<3x2xi64>) -> tensor<3x2xi64>
onnx.Return %0 : tensor<3x2xi64>

// CHECK-LABEL main_graph
// CHECK: "onnx.Add"(%arg0, %arg1) : (tensor<3x?xi64>, tensor<3x2xi64>) -> tensor<3x2xi64
}
}
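Reading the RUN line: in `--shapeInformation=0:3x-1`, the `0` selects function argument 0 and `3x-1` is the shape to impose, with `-1` denoting a dynamic dimension; that is why the CHECK line expects the first `onnx.Add` operand to become `tensor<3x?xi64>`.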
20 changes: 20 additions & 0 deletions test/mlir/onnx/onnx_canonicalization.mlir
@@ -697,6 +697,26 @@ func.func @test_rewrite_batchnormtestmode_1d(%arg0 : tensor<64xf32>, %scale : te

// -----

+ func.func @test_rewrite_batchnormtestmode_1d_f16(%arg0 : tensor<64xf16>, %scale : tensor<1xf32>, %bias : tensor<1xf32>, %mean : tensor<1xf32>, %var : tensor<1xf32>) -> tensor<64xf16> {
+ %0 = "onnx.BatchNormalizationInferenceMode"(%arg0, %scale, %bias, %mean, %var) {epsilon = 1.00000007E-5 : f32} : (tensor<64xf16>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<64xf16>
+ onnx.Return %0 : tensor<64xf16>
+
+ // CHECK-LABEL: func.func @test_rewrite_batchnormtestmode_1d_f16
+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<64xf16>, [[PARAM_1_:%.+]]: tensor<1xf32>, [[PARAM_2_:%.+]]: tensor<1xf32>, [[PARAM_3_:%.+]]: tensor<1xf32>, [[PARAM_4_:%.+]]: tensor<1xf32>) -> tensor<64xf16> {
+ // CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<1.001360e-05> : tensor<1xf16>
+ // CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_4_]], [[VAR_0_]]) : (tensor<1xf32>, tensor<1xf16>) -> tensor<*xf32>
+ // CHECK: [[VAR_2_:%.+]] = "onnx.Sqrt"([[VAR_1_]]) : (tensor<*xf32>) -> tensor<*xf32>
+ // CHECK: [[VAR_3_:%.+]] = "onnx.Div"([[PARAM_1_]], [[VAR_2_]]) : (tensor<1xf32>, tensor<*xf32>) -> tensor<*xf32>
+ // CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Mul"([[PARAM_0_]], [[VAR_3_]]) : (tensor<64xf16>, tensor<*xf32>) -> tensor<*xf16>
+ // CHECK-DAG: [[VAR_5_:%.+]] = "onnx.Mul"([[PARAM_3_]], [[VAR_3_]]) : (tensor<1xf32>, tensor<*xf32>) -> tensor<*xf32>
+ // CHECK: [[VAR_6_:%.+]] = "onnx.Sub"([[PARAM_2_]], [[VAR_5_]]) : (tensor<1xf32>, tensor<*xf32>) -> tensor<*xf32>
+ // CHECK: [[VAR_7_:%.+]] = "onnx.Add"([[VAR_4_]], [[VAR_6_]]) : (tensor<*xf16>, tensor<*xf32>) -> tensor<64xf16>
+ // CHECK: onnx.Return [[VAR_7_]] : tensor<64xf16>
+ // CHECK: }
+ }
+
+ // -----
+
func.func @test_normalize_add(%arg0 : tensor<2xf32>) -> tensor<2xf32> {
%cst = "onnx.NoValue"() {value} : () -> none
%0 = onnx.Constant dense<[0.0, 1.0]> : tensor<2xf32>