diff --git a/.azure-pipelines/Windows-CI.yml b/.azure-pipelines/Windows-CI.yml
index aabb98e8d4..8303fee5aa 100644
--- a/.azure-pipelines/Windows-CI.yml
+++ b/.azure-pipelines/Windows-CI.yml
@@ -11,7 +11,8 @@ parameters:
 
 jobs:
 - job: Build_onnx_mlir_Windows
-  timeoutInMinutes: 240
+  # 4h timeout is sometimes a tiny bit short when llvm-project is rebuilt
+  timeoutInMinutes: 270
   pool:
     vmImage: 'windows-2019'
diff --git a/docs/BuildOnLinuxOSX.md b/docs/BuildOnLinuxOSX.md
index e28928189f..28bd140fb5 100644
--- a/docs/BuildOnLinuxOSX.md
+++ b/docs/BuildOnLinuxOSX.md
@@ -15,7 +15,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
 ``` bash
 git clone -n https://github.com/llvm/llvm-project.git
 # Check out a specific branch that is known to work with ONNX-MLIR.
-cd llvm-project && git checkout 91088978d712cd7b33610c59f69d87d5a39e3113 && cd ..
+cd llvm-project && git checkout 4acc3ffbb0af5631bc7916aeff3570f448899647 && cd ..
 ```
 
 [same-as-file]: <> (utils/build-mlir.sh)
diff --git a/docs/BuildOnWindows.md b/docs/BuildOnWindows.md
index e67317de87..658e455de4 100644
--- a/docs/BuildOnWindows.md
+++ b/docs/BuildOnWindows.md
@@ -52,7 +52,7 @@ Install MLIR (as a part of LLVM-Project):
 ```shell
 git clone -n https://github.com/llvm/llvm-project.git
 # Check out a specific branch that is known to work with ONNX-MLIR.
-cd llvm-project && git checkout 91088978d712cd7b33610c59f69d87d5a39e3113 && cd ..
+cd llvm-project && git checkout 4acc3ffbb0af5631bc7916aeff3570f448899647 && cd ..
 ```
 
 [same-as-file]: <> (utils/build-mlir.cmd)
diff --git a/docs/Testing.md b/docs/Testing.md
index 43e1bfa5c2..a3f7472678 100644
--- a/docs/Testing.md
+++ b/docs/Testing.md
@@ -36,7 +36,7 @@ The all_test_names.txt is automatically generated with command "make check-onnx-
 
 ### Adding ONNX-supported test cases to the current set of backend tests
 
-When the ONNX-to-Krnl conversion of an operator is added, the corresponding backend tests for this operator should be added to test.py. The available test cases can be found in `third_part/onnx/onnx/backend/test/case/node`. You can identify new tests by looking for the new operator in `test/backend/all_test_names.txt`. Once you have located new tests, you may add the new tests in the `test/backend/inference_backend.py.` Please note to add suffix `_cpu` to the onnx test name. Associated with the test, you can define how to run the tests for the new operator. For example:
+When the ONNX-to-Krnl conversion of an operator is added, the corresponding backend tests for this operator should be added to test.py. The available test cases can be found in `third_party/onnx/onnx/backend/test/case/node`. You can identify new tests by looking for the new operator in `test/backend/all_test_names.txt`. Once you have located new tests, you may add the new tests in the `test/backend/inference_backend.py.` Please note to add suffix `_cpu` to the onnx test name. Associated with the test, you can define how to run the tests for the new operator. For example:
 ```
 "test_and2d_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
 ```
diff --git a/docs/UpdatingLLVMCommit.md b/docs/UpdatingLLVMCommit.md
index bf970ef656..3497552626 100644
--- a/docs/UpdatingLLVMCommit.md
+++ b/docs/UpdatingLLVMCommit.md
@@ -2,7 +2,7 @@
 
 # Updating the LLVM commit or MLIR-HLO submodule
 
-ONNX-MLIR depends on `llvm-project` (among various other projects such as `mlir-hlo`). The `llvm-project` dependency is captured in [utils/clone-mlir.sh](clone-mlir.sh). `mlir-hlo` is a submodule found in the `third_party` directory.
+ONNX-MLIR depends on `llvm-project` (among various other projects such as `mlir-hlo`). The `llvm-project` dependency is captured in [../utils/clone-mlir.sh](clone-mlir.sh). `mlir-hlo` is a submodule found in the `third_party` directory.
 
 We plan to update `llvm-project` a couple of times a month in order to keep up-to-date with the advancements made in `mlir`, but also to decrease the complexity of each update. There is currently no plan to update `mlir-hlo` on any given schedule, though for a specific LLVM update it may be necessary to also update the `mlir-hlo` submodule for the build to continue working correctly. This is because `mlir-hlo` itself also has a dependency on `mlir`.
@@ -17,7 +17,7 @@ We've started an update rotation that is described [here](https://github.com/onn
 
 ## What is the update process?
 
 1. **Lookup green commit hashes**: From the Github issue https://github.com/llvm/torch-mlir/issues/1178, find the LLVM and MLIR-HLO green commits for the week when ONNX-MLIR is being updated.
-2. **Update the `llvm-project` commit**: Update the LLVM commit referenced in the source tree to the green commit hash for the LLVM project from Step 1. The current locations that need to be updated are [utils/clone-mlir.sh](clone-mlir.sh), [docs/BuildOnLinuxOSX.md](BuildOnLinuxOSX.md) and [docs/BuildOnWindows.md](BuildOnWindows.md).
+2. **Update the `llvm-project` commit**: Update the LLVM commit referenced in the source tree to the green commit hash for the LLVM project from Step 1. The current locations that need to be updated are [utils/clone-mlir.sh](../utils/clone-mlir.sh), [docs/BuildOnLinuxOSX.md](BuildOnLinuxOSX.md) and [docs/BuildOnWindows.md](BuildOnWindows.md).
 3. **Update the `mlir-hlo` submodule**: In the `third-party/mlir-hlo` directory, run `git fetch` followed by `git checkout <commit-hash>` (where `<commit-hash>` is the green commit hash for the MLIR-HLO project from Step 1).
 4. **Rebuild and test ONNX-MLIR**: This might involve fixing various API breakages introduced upstream (they are likely unrelated to what you are working on). If these fixes are too complex, please file a work-in-progress PR explaining the issues you are running into asking for help so that someone from the community can help.
diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td b/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td
index 1ceb7ea7de..1f55e59c32 100644
--- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td
+++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td
@@ -29,6 +29,7 @@ def ZHigh_Dialect : Dialect {
   let summary = "A high-level dialect for the ONNX NNPA accelerator ISA.";
   let cppNamespace = "::onnx_mlir::zhigh";
   let useDefaultAttributePrinterParser = 1;
+  let usePropertiesForAttributes = 0;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td b/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td
index 5081194bc5..f3e2148211 100644
--- a/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td
+++ b/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td
@@ -20,6 +20,7 @@ def ZLow_Dialect : Dialect {
   let name = "zlow";
   let summary = "A low-level dialect for the ONNX NNPA accelerator ISA.";
   let cppNamespace = "::onnx_mlir::zlow";
+  let usePropertiesForAttributes = 0;
 }
 
 // Base class for ZLow dialect operations. This operation inherits from the
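The LLVM commit picked up by this patch enables MLIR's op "properties" storage for inherent attributes by default; the `usePropertiesForAttributes = 0` opt-outs here (and in `Krnl.td` and `ONNX.td` further down) keep these dialects on the classic attribute-dictionary storage until they are migrated. With the opt-out, generic string-keyed attribute access behaves as before — a minimal MLIR sketch, not code from this patch:

```cpp
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"

// With usePropertiesForAttributes = 0, an op's inherent attributes remain in
// its attribute dictionary, so generic string-keyed lookups keep working.
static double readEpsilon(mlir::Operation *op) {
  if (auto eps = op->getAttrOfType<mlir::FloatAttr>("epsilon"))
    return eps.getValueAsDouble();
  return 1e-5; // assumed fallback: the ONNX default epsilon
}
```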
diff --git a/src/Builder/CMakeLists.txt b/src/Builder/CMakeLists.txt
index 589849e431..cfdd25751d 100644
--- a/src/Builder/CMakeLists.txt
+++ b/src/Builder/CMakeLists.txt
@@ -11,7 +11,6 @@ add_onnx_mlir_library(OMBuilder
   ModelInputShaper.cpp
 
   LINK_LIBS PUBLIC
-  OMCompilerOptions
  OMHasOnnxSubgraphOpInterface
   OMONNXOps
   OMResultTypeInferenceOpInterface
diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp
index ab51de133e..cae6f5b3a3 100644
--- a/src/Builder/FrontendDialectTransformer.cpp
+++ b/src/Builder/FrontendDialectTransformer.cpp
@@ -31,7 +31,6 @@
 #include "src/Builder/ImportONNXUtils.hpp"
 #include "src/Builder/ModelInputShaper.hpp"
 #include "src/Builder/SymbolTable.hpp"
-#include "src/Compiler/CompilerOptions.hpp"
 #include "src/Dialect/ONNX/DialectBuilder.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
 #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
@@ -165,7 +164,7 @@ class FrontendGenImpl {
     opset_map_ = GetOpsetImportsFromProto(model); // Which opsets to use.
     in_model_functions_ = GetModelLocalFunctions(model);
     importGraph(model.graph());
-    if (VerboseOutput) {
+    if (options_.verboseOutput) {
       llvm::outs() << "The ONNX model has " << num_of_parameters_
                    << " elements in its initializers. This value would be close to and "
diff --git a/src/Builder/FrontendDialectTransformer.hpp b/src/Builder/FrontendDialectTransformer.hpp
index 7d82993440..c45e286211 100644
--- a/src/Builder/FrontendDialectTransformer.hpp
+++ b/src/Builder/FrontendDialectTransformer.hpp
@@ -39,6 +39,7 @@ namespace onnx_mlir {
  * Options to control the translation of an ONNX model to ONNX-MLIR.
  */
 struct ImportOptions {
+  bool verboseOutput = false;
   // Use types/shapes in the input-model for translation (for intermediate
   // variables)
   bool useOnnxModelTypes = false;
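The new `verboseOutput` field replaces the importer's read of the global `VerboseOutput` flag from `CompilerOptions.hpp`, as seen in the `options_.verboseOutput` check above; the driver copies the flag in via `options.verboseOutput = VerboseOutput` (see `CompilerUtils.cpp` below). A minimal sketch of that options-threading pattern; `Importer` is a hypothetical stand-in for `FrontendGenImpl`, not the actual API:

```cpp
#include <iostream>

// Options travel with the call instead of being read from globals,
// mirroring the ImportOptions::verboseOutput field added above.
struct ImportOptions {
  bool verboseOutput = false;
  bool useOnnxModelTypes = false;
};

// Hypothetical consumer: stores the options it is constructed with, the way
// FrontendGenImpl consults options_.verboseOutput after importing a graph.
class Importer {
public:
  explicit Importer(ImportOptions options) : options_(options) {}
  void run() {
    if (options_.verboseOutput)
      std::cout << "verbose import enabled\n";
  }

private:
  ImportOptions options_;
};

int main() {
  ImportOptions options;
  options.verboseOutput = true; // previously: the global VerboseOutput flag
  Importer(options).run();
}
```

Decoupling the importer from `CompilerOptions.hpp` is also what allows the `OMCompilerOptions` link dependency to be dropped from `OMBuilder` above.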
diff --git a/src/Compiler/CMakeLists.txt b/src/Compiler/CMakeLists.txt
index 1b086237be..e75a55dd94 100644
--- a/src/Compiler/CMakeLists.txt
+++ b/src/Compiler/CMakeLists.txt
@@ -68,6 +68,7 @@ add_onnx_mlir_library(OMCompilerDialects
   OMKrnlOps
   OMONNXOps
   MLIRIR
+  MLIROpenMPDialect
   )
 
 add_onnx_mlir_library(OMCompilerPasses
diff --git a/src/Compiler/CompilerPasses.cpp b/src/Compiler/CompilerPasses.cpp
index ac300ca9ab..f873302cc5 100644
--- a/src/Compiler/CompilerPasses.cpp
+++ b/src/Compiler/CompilerPasses.cpp
@@ -90,7 +90,7 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU) {
     // Dynamic iterate in ONNXOpTransformPass
     pm.addPass(onnx_mlir::createONNXOpTransformPass(onnxOpTransformThreshold,
         onnxOpTransformReport, targetCPU,
-        enableSimdDataLayout && !disableSimdOption));
+        enableSimdDataLayout && !disableSimdOption, enableConvOptPass));
   } else {
     // Statically add extra passes
     for (int i = 0; i < repeatOnnxTransform; i++) {
diff --git a/src/Compiler/CompilerUtils.cpp b/src/Compiler/CompilerUtils.cpp
index 6e7a21a58a..37dddc0788 100644
--- a/src/Compiler/CompilerUtils.cpp
+++ b/src/Compiler/CompilerUtils.cpp
@@ -14,6 +14,7 @@
 
 #include "CompilerUtils.hpp"
 
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Parser/Parser.h"
 #include "mlir/Pass/PassManager.h"
@@ -35,6 +36,7 @@
 
 #include "src/Accelerators/Accelerator.hpp"
 #include "src/Builder/FrontendDialectTransformer.hpp"
+#include "src/Builder/ModelInputShaper.hpp"
 #include "src/Compiler/CompilerDialects.hpp"
 #include "src/Compiler/CompilerOptions.hpp"
 #include "src/Compiler/CompilerPasses.hpp"
@@ -183,6 +185,35 @@ static void loadMLIR(std::string inputFilename, mlir::MLIRContext &context,
     llvm::errs() << "Error can't load file " << inputFilename << "\n";
     exit(1);
   }
+
+  // Set shape information if required.
+  // Only set shape if the module has a single function.
+  uint64_t numOfFuncOp = 0;
+  func::FuncOp funcOp;
+  module->walk([&](func::FuncOp f) {
+    funcOp = f;
+    numOfFuncOp++;
+  });
+  if ((numOfFuncOp == 1) && (!shapeInformation.empty())) {
+    ModelInputShaper modelInputShaper_;
+    modelInputShaper_.setShapeInformation(shapeInformation);
+    auto funcType = dyn_cast<FunctionType>(funcOp.getFunctionType());
+    ArrayRef<Type> argTypes = funcType.getInputs();
+    SmallVector<Type, 4> newArgTypes;
+    for (uint64_t i = 0; i < argTypes.size(); ++i) {
+      Type argTy = argTypes[i];
+      // Get user's shape information.
+      argTy = modelInputShaper_.reshape(i, argTy);
+      // Update the arguments.
+      funcOp.getBody().back().getArgument(i).setType(argTy);
+      newArgTypes.emplace_back(argTy);
+    }
+    // Update the function type.
+    FunctionType newType =
+        FunctionType::get(&context, newArgTypes, funcType.getResults());
+    ConversionPatternRewriter rewriter(&context);
+    rewriter.updateRootInPlace(funcOp, [&] { funcOp.setType(newType); });
+  }
 }
 
 // Tailor LLVMIR to add features that cannot be done with MLIR LLVMIR.
@@ -589,6 +620,7 @@ int processInputFile(StringRef inputFilename, mlir::MLIRContext &context,
   if (inputIsSTDIN || inputIsONNX || inputIsONNXText || inputIsJSON) {
     ImportOptions options;
+    options.verboseOutput = VerboseOutput;
     options.useOnnxModelTypes = useOnnxModelTypes;
     options.invokeOnnxVersionConverter = invokeOnnxVersionConverter;
     options.shapeInformation = shapeInformation;
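`--shapeInformation=<id>:<shape>` overrides the type of function input `<id>`, with `-1` marking a dynamic dimension; the code above now applies it to `.mlir` inputs too, provided the module contains a single function. A self-contained sketch of the type rewrite that `ModelInputShaper::reshape` performs for an entry like `0:3x-1`; the `reshapeType` helper is illustrative, not the actual implementation:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

// Rebuild a ranked tensor type with user-supplied dims, where -1 requests a
// dynamic dimension -- the convention used by --shapeInformation.
static RankedTensorType reshapeType(
    RankedTensorType oldType, ArrayRef<int64_t> newShape) {
  SmallVector<int64_t> dims;
  for (int64_t d : newShape)
    dims.push_back(d == -1 ? ShapedType::kDynamic : d);
  return RankedTensorType::get(dims, oldType.getElementType());
}

int main() {
  MLIRContext context;
  auto i64 = IntegerType::get(&context, 64);
  auto oldType = RankedTensorType::get({3, 2}, i64);
  // "0:3x-1" keeps dim 0 at 3 and makes dim 1 dynamic.
  reshapeType(oldType, {3, -1}).print(llvm::outs()); // tensor<3x?xi64>
  llvm::outs() << "\n";
}
```

The new `test/mlir/driver/shape_information.mlir` below checks exactly this `tensor<3x2xi64>` to `tensor<3x?xi64>` rewrite through the driver.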
diff --git a/src/Conversion/KrnlToAffine/CMakeLists.txt b/src/Conversion/KrnlToAffine/CMakeLists.txt
index 59d5017f64..33aa9c0d5c 100644
--- a/src/Conversion/KrnlToAffine/CMakeLists.txt
+++ b/src/Conversion/KrnlToAffine/CMakeLists.txt
@@ -13,7 +13,6 @@ add_onnx_mlir_library(OMKrnlToAffine
 
   LINK_LIBS PUBLIC
   OMSpecializedKernelOpInterface
-  OMCompilerOptions
   OMONNXOps
   OMSupport
   MLIRTransforms
diff --git a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp
index 86338cbbb8..cf82a9b541 100644
--- a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp
+++ b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp
@@ -25,7 +25,6 @@
 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
 #include "llvm/Support/Debug.h"
 
-#include "src/Compiler/CompilerOptions.hpp"
 #include "src/Conversion/KrnlToAffine/ConvertKrnlToAffine.hpp"
 #include "src/Dialect/Krnl/KrnlOps.hpp"
 #include "src/Dialect/Mlir/VectorMachineSupport.hpp"
@@ -712,7 +711,6 @@ void ConvertKrnlToAffinePass::runOnOperation() {
 
   MLIRContext *ctx = &getContext();
   OpBuilder builder(ctx);
 
-  VectorMachineSupport::setGlobalVectorMachineSupport(march, mcpu, "");
   const auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
   LowerToLLVMOptions options(
@@ -821,7 +819,6 @@ void ConvertKrnlToAffinePass::runOnOperation() {
   }
 
   delete currUnrollAndJamList;
-  VectorMachineSupport::clearGlobalVectorMachineSupport();
 }
 
 std::unique_ptr<Pass> createConvertKrnlToAffinePass() {
diff --git a/src/Conversion/ONNXToKrnl/CMakeLists.txt b/src/Conversion/ONNXToKrnl/CMakeLists.txt
index ea68c9e04a..7818a32e96 100644
--- a/src/Conversion/ONNXToKrnl/CMakeLists.txt
+++ b/src/Conversion/ONNXToKrnl/CMakeLists.txt
@@ -79,7 +79,6 @@ add_onnx_mlir_library(OMONNXToKrnl
 
   LINK_LIBS PUBLIC
   OMAccelerator
-  OMCompilerOptions
   OMONNXOps
   OMSupport
   MLIRFuncDialect
diff --git a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
index 9160027281..1bd5256a51 100644
--- a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
+++ b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
@@ -19,7 +19,6 @@
 
 #include "src/Accelerators/Accelerator.hpp"
 #include "src/Builder/ModelInputShaper.hpp"
-#include "src/Compiler/CompilerOptions.hpp"
 #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
 #include "src/Dialect/Mlir/VectorMachineSupport.hpp"
 
@@ -335,8 +334,6 @@ struct FrontendToKrnlLoweringPass
 
 void FrontendToKrnlLoweringPass::runOnOperation() {
   ModuleOp module = getOperation();
-  // Define vector machine.
-  VectorMachineSupport::setGlobalVectorMachineSupport(march, mcpu, "");
   // Perform dim analysis (useful for SIMD but also to avoid broadcast
   // expressions in index access patterns).
   DimAnalysis *dimAnalysis = new DimAnalysis(module);
@@ -441,7 +438,6 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
   if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
     signalPassFailure();
   }
-  VectorMachineSupport::clearGlobalVectorMachineSupport();
   delete dimAnalysis;
 }
diff --git a/src/Dialect/Krnl/Krnl.td b/src/Dialect/Krnl/Krnl.td
index deb78a675a..24eab4f3ac 100644
--- a/src/Dialect/Krnl/Krnl.td
+++ b/src/Dialect/Krnl/Krnl.td
@@ -29,6 +29,7 @@ include "src/Interface/SpecializedKernelOpInterface.td"
 def Krnl_Dialect : Dialect {
   let name = "krnl";
   let cppNamespace = "::mlir";
+  let usePropertiesForAttributes = 0;
   let useDefaultTypePrinterParser = 1;
   let dependentDialects = [
     "affine::AffineDialect",
diff --git a/src/Dialect/ONNX/ONNX.td b/src/Dialect/ONNX/ONNX.td
index ce1b196d29..d9f7259ee1 100644
--- a/src/Dialect/ONNX/ONNX.td
+++ b/src/Dialect/ONNX/ONNX.td
@@ -39,6 +39,7 @@ def ONNX_Dialect : Dialect {
   let cppNamespace = "::mlir";
   let useDefaultTypePrinterParser = 1;
   let useDefaultAttributePrinterParser = 0;
+  let usePropertiesForAttributes = 0;
   let dependentDialects = ["func::FuncDialect"];
   let hasConstantMaterializer = 1;
   let extraClassDeclaration = [{
diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp
index 809a238b48..45ee7f8449 100644
--- a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp
+++ b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp
@@ -411,7 +411,11 @@ ArrayAttr createArrayAttrFromConstantOp(ONNXConstantOp constOp) {
 DenseElementsAttr createDenseElementsAttrFromFloatAttr(
     PatternRewriter &rewriter, Type elementType, FloatAttr attr) {
   auto tensorType = RankedTensorType::get({1}, elementType);
-  return DenseElementsAttr::get(tensorType, {attr.getValue()});
+  auto ftype = cast<FloatType>(elementType);
+  APFloat f = attr.getValue();
+  bool ignored;
+  f.convert(ftype.getFloatSemantics(), APFloat::rmNearestTiesToEven, &ignored);
+  return DenseElementsAttr::get(tensorType, {f});
 }
 
 //===----------------------------------------------------------------------===//
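Previously `createDenseElementsAttrFromFloatAttr` built a `DenseElementsAttr` from an `APFloat` whose semantics (f32) could disagree with the requested element type; the added `APFloat::convert` call rounds the value into the target semantics first. A small standalone example of what that conversion does to the batch-norm epsilon exercised by the new f16 test below:

```cpp
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/raw_ostream.h"

using llvm::APFloat;

int main() {
  // The f32 epsilon from the batch-norm canonicalization test: 1.00000007E-5.
  APFloat f(1.00000007e-5f);
  bool ignored;
  // Round into IEEE half semantics, as the patched helper does via
  // FloatType::getFloatSemantics().
  f.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &ignored);
  // 1e-5 is subnormal in f16, so it snaps to a multiple of 2^-24:
  // 168 * 2^-24 = 1.0013580322265625e-05, which MLIR prints as the
  // dense<1.001360e-05> seen in the CHECK lines below.
  llvm::SmallString<32> str;
  f.toString(str);
  llvm::errs() << str << "\n";
}
```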
 std::unique_ptr<mlir::Pass> createDecomposeONNXToONNXPass(
diff --git a/src/Transform/ONNX/CMakeLists.txt b/src/Transform/ONNX/CMakeLists.txt
index 9178eaf5d4..7b8557c3f5 100644
--- a/src/Transform/ONNX/CMakeLists.txt
+++ b/src/Transform/ONNX/CMakeLists.txt
@@ -38,7 +38,6 @@ add_onnx_mlir_library(OMShapeInferencePass
   OMShapeInferenceOpInterface
   MLIRFuncDialect
   MLIRPass
-  OMCompilerOptions
   OMShapeInference
   )
diff --git a/src/Transform/ONNX/ONNXOpTransformPass.cpp b/src/Transform/ONNX/ONNXOpTransformPass.cpp
index cb7c5a2d45..f371e65e25 100644
--- a/src/Transform/ONNX/ONNXOpTransformPass.cpp
+++ b/src/Transform/ONNX/ONNXOpTransformPass.cpp
@@ -17,7 +17,6 @@
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/Passes.h"
 
-#include "src/Compiler/CompilerOptions.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"
 #include "src/Pass/Passes.hpp"
 
@@ -47,17 +46,21 @@ struct ONNXOpTransformPass : public mlir::PassWrapper<ONNXOpTransformPass,
+  Option<bool> enableConvOptPass{*this, "enable-conv-opt-pass",
+      llvm::cl::desc("Enable the ConvOptPass. Default is true."),
+      llvm::cl::init(true)};
 
   ONNXOpTransformPass() = default;
   ONNXOpTransformPass(const ONNXOpTransformPass &pass)
       : mlir::PassWrapper<ONNXOpTransformPass,
             mlir::OperationPass<mlir::ModuleOp>>() {}
   ONNXOpTransformPass(int threshold, bool report, bool targetCPU,
-      bool enableSimdDataLayoutOpt) {
+      bool enableSimdDataLayoutOpt, bool enableConvOptPass) {
     this->onnxOpTransformThreshold = threshold;
     this->onnxOpTransformReport = report;
     this->onnxOpTransformTargetCPU = targetCPU;
     this->onnxOpTransformEnableSimdDataLayout = enableSimdDataLayoutOpt;
+    this->enableConvOptPass = enableConvOptPass;
   }
 
   void runOnOperation() final;
@@ -79,7 +82,7 @@ void ONNXOpTransformPass::runOnOperation() {
   dynamicPM.addNestedPass<func::FuncOp>(
       onnx_mlir::createShapeInferencePass());
   // Convolution Optimization currently only for CPU.
-  if (onnxOpTransformTargetCPU && onnx_mlir::enableConvOptPass) {
+  if (onnxOpTransformTargetCPU && enableConvOptPass) {
     dynamicPM.addNestedPass<func::FuncOp>(
         onnx_mlir::createConvOptONNXToONNXPass(
             onnxOpTransformEnableSimdDataLayout));
@@ -116,8 +119,9 @@ std::unique_ptr<mlir::Pass> onnx_mlir::createONNXOpTransformPass() {
   return std::make_unique<ONNXOpTransformPass>();
 }
 
-std::unique_ptr<mlir::Pass> onnx_mlir::createONNXOpTransformPass(
-    int threshold, bool report, bool targetCPU, bool enableSimdDataLayoutOpt) {
+std::unique_ptr<mlir::Pass> onnx_mlir::createONNXOpTransformPass(int threshold,
+    bool report, bool targetCPU, bool enableSimdDataLayoutOpt,
+    bool enableConvOptPass) {
   return std::make_unique<ONNXOpTransformPass>(
-      threshold, report, targetCPU, enableSimdDataLayoutOpt);
+      threshold, report, targetCPU, enableSimdDataLayoutOpt, enableConvOptPass);
 }
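With `enable-conv-opt-pass` defined as a pass `Option`, the setting travels with the pass instance (settable from the textual pipeline or from the `createONNXOpTransformPass` factory) instead of being read from the global `onnx_mlir::enableConvOptPass`. A trimmed-down sketch of the `Option`-plus-constructors pattern used above; `MyTransformPass` is a simplified stand-in with the pass logic elided:

```cpp
#include <memory>

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

// Minimal PassWrapper carrying a boolean pass option, following the
// enableConvOptPass change above.
struct MyTransformPass
    : public PassWrapper<MyTransformPass, OperationPass<ModuleOp>> {
  Option<bool> enableConvOptPass{*this, "enable-conv-opt-pass",
      llvm::cl::desc("Enable the ConvOptPass. Default is true."),
      llvm::cl::init(true)};

  MyTransformPass() = default;
  // Pass options are not copyable, so the copy constructor re-initializes
  // them; the explicit constructor assigns the value directly.
  MyTransformPass(const MyTransformPass &pass)
      : PassWrapper<MyTransformPass, OperationPass<ModuleOp>>() {}
  explicit MyTransformPass(bool enableConvOpt) {
    this->enableConvOptPass = enableConvOpt;
  }

  void runOnOperation() final {
    if (enableConvOptPass) {
      // ... schedule the convolution optimization pipeline here ...
    }
  }
};

std::unique_ptr<Pass> createMyTransformPass(bool enableConvOpt) {
  return std::make_unique<MyTransformPass>(enableConvOpt);
}
```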
diff --git a/test/mlir/driver/shape_information.mlir b/test/mlir/driver/shape_information.mlir
new file mode 100644
index 0000000000..9904bf9444
--- /dev/null
+++ b/test/mlir/driver/shape_information.mlir
@@ -0,0 +1,11 @@
+// RUN: onnx-mlir --EmitONNXIR --shapeInformation=0:3x-1 --printIR %s | FileCheck %s
+
+module {
+func.func @main_graph(%arg0: tensor<3x2xi64>, %arg1: tensor<3x2xi64>) -> tensor<3x2xi64> {
+  %0 = "onnx.Add"(%arg0, %arg1) : (tensor<3x2xi64>, tensor<3x2xi64>) -> tensor<3x2xi64>
+  onnx.Return %0 : tensor<3x2xi64>
+
+// CHECK-LABEL: main_graph
+// CHECK: "onnx.Add"(%arg0, %arg1) : (tensor<3x?xi64>, tensor<3x2xi64>) -> tensor<3x2xi64>
+}
+}
diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir
index 87c8f2a853..f58fe49142 100644
--- a/test/mlir/onnx/onnx_canonicalization.mlir
+++ b/test/mlir/onnx/onnx_canonicalization.mlir
@@ -697,6 +697,26 @@ func.func @test_rewrite_batchnormtestmode_1d(%arg0 : tensor<64xf32>, %scale : te
 
 // -----
 
+func.func @test_rewrite_batchnormtestmode_1d_f16(%arg0 : tensor<64xf16>, %scale : tensor<1xf32>, %bias : tensor<1xf32>, %mean : tensor<1xf32>, %var : tensor<1xf32>) -> tensor<64xf16> {
+  %0 = "onnx.BatchNormalizationInferenceMode"(%arg0, %scale, %bias, %mean, %var) {epsilon = 1.00000007E-5 : f32} : (tensor<64xf16>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<64xf16>
+  onnx.Return %0 : tensor<64xf16>
+
+// CHECK-LABEL:  func.func @test_rewrite_batchnormtestmode_1d_f16
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<64xf16>, [[PARAM_1_:%.+]]: tensor<1xf32>, [[PARAM_2_:%.+]]: tensor<1xf32>, [[PARAM_3_:%.+]]: tensor<1xf32>, [[PARAM_4_:%.+]]: tensor<1xf32>) -> tensor<64xf16> {
+// CHECK:           [[VAR_0_:%.+]] = onnx.Constant dense<1.001360e-05> : tensor<1xf16>
+// CHECK:           [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_4_]], [[VAR_0_]]) : (tensor<1xf32>, tensor<1xf16>) -> tensor<*xf32>
+// CHECK:           [[VAR_2_:%.+]] = "onnx.Sqrt"([[VAR_1_]]) : (tensor<*xf32>) -> tensor<*xf32>
+// CHECK:           [[VAR_3_:%.+]] = "onnx.Div"([[PARAM_1_]], [[VAR_2_]]) : (tensor<1xf32>, tensor<*xf32>) -> tensor<*xf32>
+// CHECK-DAG:       [[VAR_4_:%.+]] = "onnx.Mul"([[PARAM_0_]], [[VAR_3_]]) : (tensor<64xf16>, tensor<*xf32>) -> tensor<*xf16>
+// CHECK-DAG:       [[VAR_5_:%.+]] = "onnx.Mul"([[PARAM_3_]], [[VAR_3_]]) : (tensor<1xf32>, tensor<*xf32>) -> tensor<*xf32>
+// CHECK:           [[VAR_6_:%.+]] = "onnx.Sub"([[PARAM_2_]], [[VAR_5_]]) : (tensor<1xf32>, tensor<*xf32>) -> tensor<*xf32>
+// CHECK:           [[VAR_7_:%.+]] = "onnx.Add"([[VAR_4_]], [[VAR_6_]]) : (tensor<*xf16>, tensor<*xf32>) -> tensor<64xf16>
+// CHECK:           onnx.Return [[VAR_7_]] : tensor<64xf16>
+// CHECK:         }
+}
+
+// -----
+
 func.func @test_normalize_add(%arg0 : tensor<2xf32>) -> tensor<2xf32> {
   %cst = "onnx.NoValue"() {value} : () -> none
   %0 = onnx.Constant dense<[0.0, 1.0]> : tensor<2xf32>
diff --git a/test/mlir/onnx/onnx_constprop.mlir b/test/mlir/onnx/onnx_constprop.mlir
index e96f8eb238..7c68518db8 100644
--- a/test/mlir/onnx/onnx_constprop.mlir
+++ b/test/mlir/onnx/onnx_constprop.mlir
@@ -1618,6 +1618,22 @@ func.func @test_expand_broadcast() -> tensor<*xf32> {
 
 // -----
 
+// Expand's shape can be shorter than the data input shape.
+func.func @test_expand_2_broadcast() -> tensor<*xf32> {
+  %0 = onnx.Constant dense<[[[1.0], [3.0], [5.0]]]> : tensor<1x3x1xf32>
+  %1 = onnx.Constant dense<[1, 2]> : tensor<2xi64>
+  %2 = "onnx.Expand"(%0, %1) : (tensor<1x3x1xf32>, tensor<2xi64>) -> tensor<*xf32>
+  "onnx.Return"(%2) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: func.func @test_expand_2_broadcast
+  // CHECK-SAME:  () -> tensor<1x3x2xf32> {
+  // CHECK:       [[VAR_0_:%.+]] = onnx.Constant dense<{{.}}{{.}}[1.000000e+00, 1.000000e+00], [3.000000e+00, 3.000000e+00], [5.000000e+00, 5.000000e+00]{{.}}{{.}}> : tensor<1x3x2xf32>
+  // CHECK:       onnx.Return [[VAR_0_]] : tensor<1x3x2xf32>
+  // CHECK:       }
+}
+
+// -----
+
 func.func @test_gather_axis_0() -> tensor<*xf32>{
   %0 = onnx.Constant dense<[[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]]> : tensor<3x2xf32>
   %1 = onnx.Constant dense<[[0, 1], [1, 2]]> : tensor<2x2xi64>
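The new constant-folding test relies on ONNX `Expand` using NumPy-style multidirectional broadcasting between the data shape and the requested shape, aligned from the right: broadcast([1,3,1], [1,2]) = [1,3,2]. A tiny illustrative helper (not onnx-mlir code) computing that rule:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// NumPy-style broadcast of two shapes, right-aligned, as ONNX Expand applies
// between the data shape and the requested shape. Returns {} on mismatch.
static std::vector<int64_t> broadcastShapes(
    std::vector<int64_t> a, std::vector<int64_t> b) {
  if (a.size() < b.size())
    std::swap(a, b);
  // Left-pad the shorter shape with 1s so both have the same rank.
  b.insert(b.begin(), a.size() - b.size(), 1);
  std::vector<int64_t> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    if (a[i] != b[i] && a[i] != 1 && b[i] != 1)
      return {}; // incompatible dimensions
    out[i] = std::max(a[i], b[i]);
  }
  return out;
}

int main() {
  // The folded test above: data tensor<1x3x1xf32>, shape input [1, 2].
  for (int64_t d : broadcastShapes({1, 3, 1}, {1, 2}))
    std::cout << d << ' '; // prints: 1 3 2
  std::cout << '\n';
}
```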
diff --git a/third_party/mlir-hlo b/third_party/mlir-hlo
index 01402e75f8..5ae7673e4d 160000
--- a/third_party/mlir-hlo
+++ b/third_party/mlir-hlo
@@ -1 +1 @@
-Subproject commit 01402e75f8e8a95aa2cd9c267ce929095a2f2e54
+Subproject commit 5ae7673e4d90377481206fb8f8d0ca56d9b28b7e
diff --git a/utils/clone-mlir.sh b/utils/clone-mlir.sh
index 031e819d9d..8c145d1071 100644
--- a/utils/clone-mlir.sh
+++ b/utils/clone-mlir.sh
@@ -1,3 +1,3 @@
 git clone -n https://github.com/llvm/llvm-project.git
 # Check out a specific branch that is known to work with ONNX-MLIR.
-cd llvm-project && git checkout 91088978d712cd7b33610c59f69d87d5a39e3113 && cd ..
+cd llvm-project && git checkout 4acc3ffbb0af5631bc7916aeff3570f448899647 && cd ..