From 7f909cd498ff3502654f605050f364c1ce31c5a8 Mon Sep 17 00:00:00 2001 From: Ettore Tiotto Date: Fri, 25 Feb 2022 17:07:47 -0500 Subject: [PATCH] Split krnlToLLVM.cpp file (#1199) Signed-off-by: Ettore Tiotto --- src/Compiler/CompilerUtils.cpp | 36 +- src/Conversion/KrnlToAffine/KrnlToAffine.cpp | 2 +- src/Conversion/KrnlToLLVM/CMakeLists.txt | 17 +- .../KrnlToLLVM/ConvertKrnlToLLVM.cpp | 526 ++++++++++++++++++ .../KrnlToLLVM/ConvertKrnlToLLVM.hpp | 85 +++ src/Conversion/KrnlToLLVM/KrnlEntryPoint.cpp | 350 ++++++++++++ src/Conversion/KrnlToLLVM/KrnlFindIndex.cpp | 134 +++++ src/Conversion/KrnlToLLVM/KrnlGetRef.cpp | 163 ++++++ src/Conversion/KrnlToLLVM/KrnlGlobal.cpp | 274 +++++++++ src/Conversion/KrnlToLLVM/KrnlInstrument.cpp | 95 ++++ src/Conversion/KrnlToLLVM/KrnlMemcpy.cpp | 123 ++++ src/Conversion/KrnlToLLVM/KrnlPrint.cpp | 119 ++-- src/Conversion/KrnlToLLVM/KrnlPrint.hpp | 44 -- src/Conversion/KrnlToLLVM/KrnlPrintTensor.cpp | 101 ++-- src/Conversion/KrnlToLLVM/KrnlPrintTensor.hpp | 40 -- .../KrnlToLLVM/KrnlRandomNormal.cpp | 111 ++++ src/Conversion/KrnlToLLVM/KrnlStrlen.cpp | 102 ++++ src/Conversion/KrnlToLLVM/KrnlStrncmp.cpp | 78 +++ src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp | 251 +-------- src/Conversion/KrnlToLLVM/KrnlToLLVM.hpp | 46 -- .../KrnlToLLVM/KrnlToLLVMHelper.cpp | 55 +- .../KrnlToLLVM/KrnlToLLVMHelper.hpp | 10 + src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp | 199 +++++++ .../KrnlToLLVM/KrnlVectorTypeCast.cpp | 171 ++++++ .../ONNXToKrnl/ConvertONNXToKrnl.cpp | 6 +- src/InitOMPasses.hpp | 35 +- src/Pass/Passes.hpp | 29 +- src/Transform/BundleMemoryPools.cpp | 2 +- src/Transform/DisconnectKrnlDimFromAlloc.cpp | 2 +- src/Transform/ElideKrnlGlobalConstants.cpp | 2 +- src/Transform/EnableMemoryPool.cpp | 2 +- src/Transform/LowerKrnlShape.cpp | 2 +- src/Transform/ONNX/ConstProp.cpp | 2 +- src/Transform/ONNX/Decompose.cpp | 2 +- src/Transform/ONNX/ElideConstants.cpp | 2 +- src/Transform/ONNX/InstrumentONNXPass.cpp | 2 +- src/Transform/ONNX/ONNXOpTransformPass.cpp | 11 +- src/Transform/ONNX/ONNXPreKrnlVerifyPass.cpp | 2 +- src/Transform/ONNX/ShapeInferencePass.cpp | 2 +- src/Transform/OptimizeMemoryPools.cpp | 2 +- test/onnx2mlir/CustomFnTest.cpp | 2 +- 41 files changed, 2698 insertions(+), 541 deletions(-) create mode 100644 src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp create mode 100644 src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp create mode 100644 src/Conversion/KrnlToLLVM/KrnlEntryPoint.cpp create mode 100644 src/Conversion/KrnlToLLVM/KrnlFindIndex.cpp create mode 100644 src/Conversion/KrnlToLLVM/KrnlGetRef.cpp create mode 100644 src/Conversion/KrnlToLLVM/KrnlGlobal.cpp create mode 100644 src/Conversion/KrnlToLLVM/KrnlInstrument.cpp create mode 100644 src/Conversion/KrnlToLLVM/KrnlMemcpy.cpp delete mode 100644 src/Conversion/KrnlToLLVM/KrnlPrint.hpp delete mode 100644 src/Conversion/KrnlToLLVM/KrnlPrintTensor.hpp create mode 100644 src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp create mode 100644 src/Conversion/KrnlToLLVM/KrnlStrlen.cpp create mode 100644 src/Conversion/KrnlToLLVM/KrnlStrncmp.cpp delete mode 100644 src/Conversion/KrnlToLLVM/KrnlToLLVM.hpp create mode 100644 src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp create mode 100644 src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp diff --git a/src/Compiler/CompilerUtils.cpp b/src/Compiler/CompilerUtils.cpp index 385a5c36eb3a..306c6d17c68b 100644 --- a/src/Compiler/CompilerUtils.cpp +++ b/src/Compiler/CompilerUtils.cpp @@ -30,6 +30,7 @@ #include "ExternalUtil.hpp" #include "src/Compiler/CompilerUtils.hpp" +#include 
"src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp" #include "src/Support/OMOptions.hpp" #define DEBUG_TYPE "compiler_utils" @@ -713,23 +714,23 @@ void addONNXToMLIRPasses(mlir::PassManager &pm) { // In future, only the dynamic pass, ONNXOpTransformPass, will be used for // this function. - pm.addNestedPass(mlir::createDecomposeONNXToONNXPass()); - pm.addPass(mlir::createShapeInferencePass()); + pm.addNestedPass(onnx_mlir::createDecomposeONNXToONNXPass()); + pm.addPass(onnx_mlir::createShapeInferencePass()); pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mlir::createShapeInferencePass()); + pm.addPass(onnx_mlir::createShapeInferencePass()); // There are more opportunities for const propagation once all tensors have // inferred shapes. - pm.addNestedPass(mlir::createConstPropONNXToONNXPass()); + pm.addNestedPass(onnx_mlir::createConstPropONNXToONNXPass()); if (onnxOpTransformThreshold > 0) { // Dynamic iterate in ONNXOpTransformPass - pm.addPass(mlir::createONNXOpTransformPass(onnxOpTransformThreshold)); + pm.addPass(onnx_mlir::createONNXOpTransformPass(onnxOpTransformThreshold)); } else { // Statically add extra passes for (int i = 0; i < repeatOnnxTransform; i++) { pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mlir::createShapeInferencePass()); - pm.addNestedPass(mlir::createConstPropONNXToONNXPass()); + pm.addPass(onnx_mlir::createShapeInferencePass()); + pm.addNestedPass(onnx_mlir::createConstPropONNXToONNXPass()); } } @@ -738,10 +739,10 @@ void addONNXToMLIRPasses(mlir::PassManager &pm) { } void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel) { - pm.addNestedPass(mlir::createONNXPreKrnlVerifyPass()); + pm.addNestedPass(onnx_mlir::createONNXPreKrnlVerifyPass()); // Add instrumentation for Onnx Ops - pm.addNestedPass(mlir::createInstrumentONNXPass()); - pm.addPass(mlir::createLowerToKrnlPass(optLevel)); + pm.addNestedPass(onnx_mlir::createInstrumentONNXPass()); + pm.addPass(onnx_mlir::createLowerToKrnlPass(optLevel)); // An additional pass of canonicalization is helpful because lowering // from ONNX dialect to Standard dialect exposes additional canonicalization // opportunities. @@ -751,7 +752,7 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel) { } void addKrnlToAffinePasses(mlir::PassManager &pm) { - pm.addNestedPass(mlir::createConvertKrnlToAffinePass()); + pm.addNestedPass(onnx_mlir::createConvertKrnlToAffinePass()); // Fuse loops in Affine dialect. // pm.addPass(mlir::createLoopFusionPass()); } @@ -766,15 +767,15 @@ void addKrnlToLLVMPasses(mlir::OpPassManager &pm) { // https://mlir.llvm.org/docs/BufferDeallocationInternals. 
pm.addNestedPass(mlir::bufferization::createBufferDeallocationPass()); if (enableMemoryBundling) { - pm.addNestedPass(mlir::createKrnlEnableMemoryPoolPass()); - pm.addNestedPass(mlir::createKrnlBundleMemoryPoolsPass()); + pm.addNestedPass(krnl::createKrnlEnableMemoryPoolPass()); + pm.addNestedPass(krnl::createKrnlBundleMemoryPoolsPass()); pm.addPass(mlir::createCanonicalizerPass()); - pm.addNestedPass(mlir::createKrnlOptimizeMemoryPoolsPass()); + pm.addNestedPass(krnl::createKrnlOptimizeMemoryPoolsPass()); } pm.addNestedPass(mlir::createConvertSCFToCFPass()); - pm.addPass(mlir::createConvertKrnlToLLVMPass()); + pm.addPass(krnl::createConvertKrnlToLLVMPass()); pm.addPass(mlir::createReconcileUnrealizedCastsPass()); pm.addPass(mlir::createCanonicalizerPass()); } @@ -910,10 +911,11 @@ void emitOutputFiles(string outputBaseName, EmissionTargetType emissionTarget, mlir::PassManager cleanSourcePM( &context, mlir::OpPassManager::Nesting::Implicit); if (emissionTarget == EmitONNXIR || emissionTarget == EmitONNXBasic) - cleanSourcePM.addNestedPass(mlir::createElideConstantValuePass()); + cleanSourcePM.addNestedPass( + onnx_mlir::createElideConstantValuePass()); if (emissionTarget == EmitMLIR) cleanSourcePM.addNestedPass( - mlir::createElideConstGlobalValuePass()); + onnx_mlir::createElideConstGlobalValuePass()); if (emissionTarget == EmitONNXBasic || emissionTarget == EmitONNXIR || emissionTarget == EmitMLIR) { diff --git a/src/Conversion/KrnlToAffine/KrnlToAffine.cpp b/src/Conversion/KrnlToAffine/KrnlToAffine.cpp index 1f48f02bd851..3d712ea52ed9 100644 --- a/src/Conversion/KrnlToAffine/KrnlToAffine.cpp +++ b/src/Conversion/KrnlToAffine/KrnlToAffine.cpp @@ -1533,6 +1533,6 @@ void ConvertKrnlToAffinePass::runOnOperation() { } // namespace -std::unique_ptr mlir::createConvertKrnlToAffinePass() { +std::unique_ptr onnx_mlir::createConvertKrnlToAffinePass() { return std::make_unique(); } diff --git a/src/Conversion/KrnlToLLVM/CMakeLists.txt b/src/Conversion/KrnlToLLVM/CMakeLists.txt index db09f8b8c48a..41e97aa15093 100644 --- a/src/Conversion/KrnlToLLVM/CMakeLists.txt +++ b/src/Conversion/KrnlToLLVM/CMakeLists.txt @@ -1,10 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 add_onnx_mlir_library(OMKrnlToLLVM - KrnlToLLVM.cpp - KrnlToLLVMHelper.cpp + ConvertKrnlToLLVM.cpp + KrnlFindIndex.cpp + KrnlEntryPoint.cpp + KrnlGetRef.cpp + KrnlGlobal.cpp + KrnlInstrument.cpp + KrnlMemcpy.cpp KrnlPrintTensor.cpp - KrnlPrint.cpp + KrnlPrint.cpp + KrnlRandomNormal.cpp + KrnlStrlen.cpp + KrnlStrncmp.cpp + KrnlToLLVMHelper.cpp + KrnlUnaryMath.cpp + KrnlVectorTypeCast.cpp RuntimeAPI.cpp LINK_LIBS PUBLIC diff --git a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp new file mode 100644 index 000000000000..8c147ac6bd12 --- /dev/null +++ b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp @@ -0,0 +1,526 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//====------ ConvertKrnlToLLVM.cpp - Krnl Dialect Lowering ---------------===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// ============================================================================= +// +// This file implements the lowering of Krnl operations to a combination of +// other dialects (affine, std, LLVM). 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/DataLayoutAnalysis.h" +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Math/Transforms/Passes.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/StandardOps/Transforms/Passes.h" +#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Endian.h" + +#include "onnx/onnx_pb.h" + +#include "src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp" +#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp" +#include "src/Conversion/KrnlToLLVM/RuntimeAPI.hpp" +#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "src/Pass/Passes.hpp" +#include "src/Support/Common.hpp" + +using namespace mlir; + +#define DEBUG_TYPE "krnl_to_llvm" + +namespace onnx_mlir { +namespace krnl { + +static void checkConstantOutputs( + ModuleOp &module, SmallVectorImpl &constantOutputs) { + Operation *entryPointOp; + auto walkResult = module->walk([&](mlir::Operation *op) -> WalkResult { + if (llvm::dyn_cast(op)) { + entryPointOp = op; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + // Do nothing if there is no EntryPoint. + if (!walkResult.wasInterrupted()) + return; + + // Get entry function name. + StringRef entryPointFuncName = + entryPointOp + ->getAttrOfType( + KrnlEntryPointOp::getEntryPointFuncAttrName()) + .getLeafReference() + .getValue(); + + // Get entry function op. + Operation *entryFunc; + module->walk([&](FuncOp op) -> WalkResult { + if (SymbolRefAttr::get(op).getValue() == entryPointFuncName) { + entryFunc = op; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + assert(entryFunc && "Entry function not found"); + + // Get ReturnOp of the entry function op. + Operation *returnOp; + entryFunc->walk([&](Operation *op) -> WalkResult { + if (llvm::dyn_cast(op)) { + returnOp = op; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + // Check, for each output, if it was transitively produced by a constant or + // not. + for (Value v : returnOp->getOperands()) { + bool isConstant = false; + Operation *definingOp = v.getDefiningOp(); + if (!definingOp) + // Block argument, not a constant. 
+ isConstant = false; + else { + // If output is just a view, trace back to find which op was producing the + // source memref. + while (auto viewOp = llvm::dyn_cast(definingOp)) { + Value source = viewOp.getViewSource(); + definingOp = source.getDefiningOp(); + // Block argument, stop. + if (!definingOp) + break; + } + if (!definingOp) + // Block argument, not a constant. + isConstant = false; + else if (llvm::dyn_cast(definingOp)) + // A constant defined by KrnlGlobalOp. + isConstant = true; + } + constantOutputs.emplace_back(isConstant); + LLVM_DEBUG(llvm::dbgs() + << "Is entry function output constant? " << isConstant << "\n"); + } +} + +static void populateAffineAndKrnlToLLVMConversion(RewritePatternSet &patterns, + LLVMTypeConverter &typeConverter, MLIRContext *ctx, + ArrayRef constantOutputs, bool singleEntryPoint) { + // TODO: look at what is done in + // mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp in function + // LowerVectorToLLVMPass::runOnOperation() and see what we should do about it. + // They run it in two steps, and add additional lowerings. + + vector::populateVectorToVectorCanonicalizationPatterns(patterns); + // Removed in upgrade of LLVM: + // vector::populateVectorSlicesLoweringPatterns(patterns); + vector::populateVectorBroadcastLoweringPatterns(patterns); + vector::populateVectorContractLoweringPatterns(patterns); + vector::populateVectorTransposeLoweringPatterns(patterns); + + populateAffineToStdConversionPatterns(patterns); + populateSCFToControlFlowConversionPatterns(patterns); + + populateShapeToStandardConversionPatterns(patterns); + populateVectorToLLVMMatrixConversionPatterns(typeConverter, patterns); + populateVectorToLLVMConversionPatterns(typeConverter, patterns); + populateVectorToLLVMMatrixConversionPatterns(typeConverter, patterns); + memref::populateExpandOpsPatterns(patterns); + // Use polynomial approximation for math.{tanh, sin, cos and exp} for better + // performance. + populateMathPolynomialApproximationPatterns(patterns); + arith::populateArithmeticExpandOpsPatterns(patterns); + populateMathToLLVMConversionPatterns(typeConverter, patterns); + populateStdToLLVMConversionPatterns(typeConverter, patterns); + populateMemRefToLLVMConversionPatterns(typeConverter, patterns); + arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns); + cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); + + populateReconcileUnrealizedCastsPatterns(patterns); + krnl::populateKrnlToLLVMConversion( + typeConverter, patterns, ctx, constantOutputs, singleEntryPoint); +} + +void recordEntryPointSignatures(ModuleOp &module, + SmallVectorImpl &entryPointNames, + SmallVectorImpl &inSignatures, + SmallVectorImpl &outSignatures) { + module->walk([&](KrnlEntryPointOp entryOp) -> WalkResult { + Operation *op = entryOp.getOperation(); + // Entry point name. + llvm::StringRef entryPointName = + op->getAttrOfType( + KrnlEntryPointOp::getEntryPointFuncAttrName()) + .getLeafReference() + .getValue(); + std::string terminatedEntryPointName = "run_" + entryPointName.str(); + terminatedEntryPointName.push_back('\0'); // null terminate the string. + entryPointNames.emplace_back(terminatedEntryPointName); + + // Input/output signatures. 
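// (Illustrative note: the signature attribute read below is a single string
//  holding the input description and the output description separated by '@',
//  which is why it is split on '@' here. A purely hypothetical example of its
//  shape: "[{\"type\":\"f32\",\"dims\":[1,3]}]@[{\"type\":\"f32\",\"dims\":[1]}]";
//  the exact encoding is defined where the attribute is produced, not in this file.)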
+ StringAttr sigAttr = + op->getAttrOfType(KrnlEntryPointOp::getSignatureAttrName()); + llvm::StringRef signature = sigAttr.getValue(); + auto splitSig = signature.split('@'); + llvm::StringRef inSig = splitSig.first; + llvm::StringRef outSig = splitSig.second; + inSignatures.emplace_back(inSig.str()); + outSignatures.emplace_back(outSig.str()); + + return WalkResult::advance(); + }); + + // When there is only a single entry point function in a model, use + // DEFAULT_DYN_ENTRY_POINT. + if (entryPointNames.size() == 1) { + entryPointNames[0] = DEFAULT_DYN_ENTRY_POINT; + entryPointNames[0].push_back('\0'); // null terminate the string. + } +} + +/// This function emits three functions: omQueryEntryPoints, omInputSignature +/// and omOutputSignature. +/// - omQueryEntryPoints has type of `**i8 ()` to query an array of entry point +/// names. +/// - omInputSignature and omOutputSignature have type of type `*i8 (*i8)` to +/// return input and output signatures of the given entry point. +void genSignatureFunction(ModuleOp module, + const ArrayRef entryPointNames, + const ArrayRef inSignatures, + const ArrayRef outSignatures) { + MLIRContext *context = module.getContext(); + Location loc = module.getLoc(); + OpBuilder b(context); + + // Common information. + Type i8Type = IntegerType::get(context, 8); + Type i32Type = IntegerType::get(context, 32); + Type i64Type = IntegerType::get(context, 64); + Type i8PtrTy = LLVM::LLVMPointerType::get(i8Type); + Type i8PtrPtrTy = LLVM::LLVMPointerType::get(i8PtrTy); + IntegerAttr zeroI32Attr = b.getI32IntegerAttr(0); + IntegerAttr zeroI64Attr = b.getI64IntegerAttr(0); + IntegerAttr oneI64Attr = b.getI64IntegerAttr(1); + + uint64_t numOfEntryPoints = entryPointNames.size(); + + // A helper function to emit a global constant operation storing a string. + auto emitGlobalOp = [&context, &b, &loc, &i8Type]( + std::string name, std::string value) { + mlir::StringAttr valueAttr = mlir::StringAttr::get(context, value); + Type valueArrayType = LLVM::LLVMArrayType::get(i8Type, value.size()); + LLVM::GlobalOp globalOp = b.create(loc, valueArrayType, + /*isConstant=*/true, LLVM::Linkage::External, name, valueAttr); + return globalOp; + }; + + // A helper function to get a pointer to the first element in an array. + auto getGlobalOpGEP = [&loc, &b, &i8PtrTy, &i64Type, &zeroI64Attr]( + LLVM::GlobalOp op) { + Value zeroI64 = b.create(loc, i64Type, zeroI64Attr); + Value address = b.create(loc, op); + LLVM::GEPOp gepOp = b.create( + loc, i8PtrTy, address, ArrayRef({zeroI64, zeroI64})); + return gepOp; + }; + + // For each entry point name, emit three global constants to store the entry + // point name and input/output signatures. For the i-th entry point, these + // constants are named as follows: + // - Entry point name: `_entry_point_i`. + // - Input signature: `_entry_point_i_in_sig`. + // - Output signature: `_entry_point_i_out_sig`. + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToStart(module.getBody()); + SmallVector entryOps, inSigOps, outSigOps; + for (uint64_t i = 0; i < numOfEntryPoints; ++i) { + // Global constants for entry point names. + std::string entryVarName = "_entry_point_" + std::to_string(i); + LLVM::GlobalOp entryOp = emitGlobalOp(entryVarName, entryPointNames[i]); + entryOps.emplace_back(entryOp); + + // Global constants for input signatures. 
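// (Rough sketch, assuming a model with a single entry point using the default
//  name, of the globals emitted by emitGlobalOp here; the lowered module would
//  contain roughly:
//    llvm.mlir.global external constant @_entry_point_0("run_main_graph\00")
//    llvm.mlir.global external constant @_entry_point_0_in_sig(...)
//    llvm.mlir.global external constant @_entry_point_0_out_sig(...)
//  with the signature strings taken from recordEntryPointSignatures.)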
+ std::string inSigVarName = entryVarName + "_in_sig"; + LLVM::GlobalOp inSigOp = emitGlobalOp(inSigVarName, inSignatures[i]); + inSigOps.emplace_back(inSigOp); + + // Global constants for output signatures. + std::string outSigVarName = entryVarName + "_out_sig"; + LLVM::GlobalOp outSigOp = emitGlobalOp(outSigVarName, outSignatures[i]); + outSigOps.emplace_back(outSigOp); + } + + // Emit a global constant to store an array of pointers pointing to each entry + // point constants. The array ends with NULL. + auto arrayType = LLVM::LLVMArrayType::get(i8PtrTy, entryOps.size() + 1); + auto entryArrayOp = b.create(loc, arrayType, + /*isConstant=*/true, LLVM::Linkage::Internal, "_entry_point_arrays", + Attribute()); + { // Fill the initializer with pointers to entry point constants. + Region ®ion = entryArrayOp.getInitializerRegion(); + Block *block = b.createBlock(®ion); + + // Initialize an array with the addresses of the global strings. + b.setInsertionPointToStart(block); + Value array = b.create(loc, arrayType); + + uint32_t index = 0; + Value lastValue = array; + for (const LLVM::GlobalOp &globalOp : entryOps) { + LLVM::GEPOp strAddr = getGlobalOpGEP(globalOp); + lastValue = b.create(loc, arrayType, lastValue, + strAddr, b.getArrayAttr({b.getIndexAttr(index++)})); + } + + // The last element of the array is NULL. + Value nullPtr = b.create(loc, i8PtrTy); + lastValue = b.create(loc, arrayType, lastValue, + nullPtr, b.getArrayAttr({b.getIndexAttr(index++)})); + b.create(loc, ArrayRef({lastValue})); + } + + // Emit a function, omQueryEntryPoints, of type `**8 ()` to query an array of + // entry point names. + { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToEnd(module.getBody()); + // Emit the function type. + Type llvmFnType = LLVM::LLVMFunctionType::get(i8PtrPtrTy, {}, false); + LLVM::LLVMFuncOp funcOp = + b.create(loc, "omQueryEntryPoints", llvmFnType); + // Emit the body of the function. + Block *entryBlock = funcOp.addEntryBlock(); + OpBuilder::InsertionGuard bodyGuard(b); + b.setInsertionPointToStart(entryBlock); + Value entryAddr = b.create(loc, entryArrayOp); + Value entryI8Ptr = b.create(loc, i8PtrPtrTy, entryAddr); + b.create(loc, ArrayRef({entryI8Ptr})); + } + + // Emit two signature functions, omInputSignature and omOutputSignature, of + // type `*i8 (*i8)` at the end of the module. + SmallVector funcNames = { + "omInputSignature", "omOutputSignature"}; + SmallVector, 2> sigOps = {inSigOps, outSigOps}; + for (uint64_t i = 0; i < funcNames.size(); ++i) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToEnd(module.getBody()); + // 1. Emit the function type. + Type llvmFnType = LLVM::LLVMFunctionType::get(i8PtrTy, {i8PtrTy}, false); + LLVM::LLVMFuncOp funcOp = + b.create(loc, funcNames[i], llvmFnType); + + // 2. Emit the body of the function. + Block *entryBlock = funcOp.addEntryBlock(); + OpBuilder::InsertionGuard bodyGuard(b); + b.setInsertionPointToStart(entryBlock); + + Value zeroI32 = b.create(loc, i32Type, zeroI32Attr); + Value oneI64 = b.create(loc, i64Type, oneI64Attr); + + // 2.1 A buffer to keep a pointer pointing to the return signature string. + Value ptrToReturnSig = b.create(loc, i8PtrPtrTy, oneI64, + /*alignment=*/0); + + // 2.2 The name of the entry point that we want to return its signature. + Value input = entryBlock->getArgument(0); + + // 2.3 Emit code to find the signature of the given entry point. + // Iterate over the list of the entry points and check string equality. 
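// (Hedged sketch of the control flow generated below, as C-like pseudocode;
//  everything except the omInputSignature/omOutputSignature names and the
//  _entry_point_* globals is illustrative:
//    const char *omInputSignature(const char *name) {
//      if (strncmp(name, _entry_point_0, len0) == 0) return _entry_point_0_in_sig;
//      if (strncmp(name, _entry_point_1, len1) == 0) return _entry_point_1_in_sig;
//      ...
//      return *ptrToReturnSig; // set by the matching branch, read in the end block
//    }
//  The condition/true/false/end blocks constructed below realize this chain of
//  comparisons with conditional branches.)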
+ + // Split the current block into condition, true, false, and end blocks. + // - If the user's entry point name is found, go to the true block, then the + // end block. + // - Otherwise, recursively split the false block. + Block *condBlock, *trueBlock, *falseBlock, *endBlock; + condBlock = b.getInsertionBlock(); + trueBlock = condBlock->splitBlock(b.getInsertionPoint()); + falseBlock = b.createBlock( + trueBlock->getParent(), std::next(Region::iterator(trueBlock))); + endBlock = b.createBlock( + falseBlock->getParent(), std::next(Region::iterator(falseBlock))); + + // Emit code for the end block. + b.setInsertionPointToStart(endBlock); + Value res = b.create(loc, i8PtrTy, ptrToReturnSig); + b.create(loc, ArrayRef({res})); + + // Emit code for the condition, true and false blocks. + for (uint64_t j = 0; j < numOfEntryPoints; ++j) { + LLVM::GlobalOp globalEntryPoint = entryOps[j]; + LLVM::GlobalOp globalSignature = sigOps[i][j]; + std::string entryPointName = entryPointNames[j]; + // Emit code for the condition block. + b.setInsertionPointToEnd(condBlock); + // Read an entry point name. + Value entryI8Ptr = getGlobalOpGEP(globalEntryPoint).getResult(); + // Compare it with the user's entry point name. + FlatSymbolRefAttr StrncmpRef = krnl::getOrInsertStrncmp(b, module); + Value length = b.create( + loc, i64Type, b.getI64IntegerAttr(entryPointName.size())); + Value strncmpResult = b.create(loc, i32Type, StrncmpRef, + ArrayRef({input, entryI8Ptr, length})) + .getResult(0); + // Equal if strncmp returns `0`. + Value found = b.create( + loc, LLVM::ICmpPredicate::eq, strncmpResult, zeroI32); + llvm::SmallVector results = {entryI8Ptr}; + // Branch the block into the true and false blocks. + b.create( + loc, found, trueBlock, ValueRange(), falseBlock, ValueRange()); + + // Emit code for the true block. + b.setInsertionPointToStart(trueBlock); + Value sigAddr = b.create(loc, globalSignature); + Value sigI8Ptr = b.create(loc, i8PtrTy, sigAddr); + b.create(loc, sigI8Ptr, ptrToReturnSig); + b.create(loc, ValueRange(), endBlock); + + // Emit code for the false block. + b.setInsertionPointToStart(falseBlock); + if (j == numOfEntryPoints - 1) + b.create(loc, ValueRange(), endBlock); + else { + // Recursively do with the other entry point names. + condBlock = b.getInsertionBlock(); + trueBlock = condBlock->splitBlock(b.getInsertionPoint()); + falseBlock = b.createBlock( + trueBlock->getParent(), std::next(Region::iterator(trueBlock))); + } + } + } +} + +//===----------------------------------------------------------------------===// +// Krnl Dialect lowering pass +//===----------------------------------------------------------------------===// + +struct ConvertKrnlToLLVMPass + : public PassWrapper> { + + StringRef getArgument() const override { return "convert-krnl-to-llvm"; } + + StringRef getDescription() const override { + return "Lower the Krnl Affine and Std dialects to LLVM."; + } + + void runOnOperation() final; +}; + +void ConvertKrnlToLLVMPass::runOnOperation() { + ModuleOp module = getOperation(); + const auto &dataLayoutAnalysis = getAnalysis(); + LowerToLLVMOptions options( + &getContext(), dataLayoutAnalysis.getAtOrAbove(module)); + options.emitCWrappers = true; + + // Determine, for each output, whether it is a constant or not. + SmallVector constantOutputs; + checkConstantOutputs(module, constantOutputs); + + // Record entry point names and their input/output signatures. + // This info is used to generate global signature functions. 
+ SmallVector entryPointNames, inSignatures, outSignatures; + recordEntryPointSignatures( + module, entryPointNames, inSignatures, outSignatures); + + // Define the target for this lowering i.e. the LLVM dialect. + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + // Convert types to legal types for the LLVM dialect. + LLVMTypeConverter typeConverter(&getContext(), options); + + typeConverter.addConversion([&](MemRefType type) -> llvm::Optional { + Type elementType = type.getElementType(); + if (!elementType.isa()) + return llvm::None; + + elementType = elementType.cast().getLLVMType(type.getContext()); + return typeConverter.convertType( + MemRefType::get(type.getShape(), elementType)); + }); + + typeConverter.addConversion([&](StringType type) -> Type { + return typeConverter.convertType(type.getLLVMType(type.getContext())); + }); + + // We have a combination of `krnl`, `affine`, `vector`, and `std` operations. + // We lower in stages until all the code is in the LLVM dialect. + RewritePatternSet patterns(&getContext()); + + populateAffineAndKrnlToLLVMConversion(patterns, typeConverter, &getContext(), + constantOutputs, /*singleEntryPoint=*/entryPointNames.size() == 1); + + // We want to completely lower to LLVM, so we use a `FullConversion`. This + // ensures that only legal operations will remain after the conversion. + if (failed( + applyFullConversion(getOperation(), target, std::move(patterns)))) { + signalPassFailure(); + } + + // Generate signature functions. + if (entryPointNames.size() >= 1) + genSignatureFunction(module, entryPointNames, inSignatures, outSignatures); +} + +/// Create the pass for lowering `Krnl`, `Affine` and `Std` dialects to LLVM. +std::unique_ptr createConvertKrnlToLLVMPass() { + return std::make_unique(); +} + +void populateKrnlToLLVMConversion(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx, + ArrayRef constantOutputs, bool singleEntryPoint) { + krnl::populateLoweringKrnlEntryPointOpPattern( + typeConverter, patterns, ctx, constantOutputs, singleEntryPoint); + krnl::populateLoweringKrnlFindIndexOpPattern(typeConverter, patterns, ctx); + krnl::populateLoweringKrnlGlobalOpPattern(typeConverter, patterns, ctx); + krnl::populateLoweringKrnlGetRefOpPattern(typeConverter, patterns, ctx); + krnl::populateLoweringKrnlInstrumentOpPattern(typeConverter, patterns, ctx); + krnl::populateLoweringKrnlMemcpyOpPattern(typeConverter, patterns, ctx); + krnl::populateLoweringKrnlVectorTypeCastOpPattern( + typeConverter, patterns, ctx); + krnl::populateLoweringKrnlRandomNormalOpPattern(typeConverter, patterns, ctx); + krnl::populateLoweringKrnlStrlenOpPattern(typeConverter, patterns, ctx); + krnl::populateLoweringKrnlUnaryMathOpPattern(typeConverter, patterns, ctx); + krnl::populateLoweringKrnlStrncmpOpPattern(typeConverter, patterns, ctx); +} + +} // namespace krnl +} // namespace onnx_mlir diff --git a/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp new file mode 100644 index 000000000000..f6ee369cb86a --- /dev/null +++ b/src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp @@ -0,0 +1,85 @@ + +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//====------ ConvertKrnlToLLVM.hpp - Krnl Dialect Lowering ---------------===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// ============================================================================= +// +// Lowering of Krnl operations to a combination of other dialects. 
+// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp" +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "src/Pass/Passes.hpp" +#include "src/Support/Common.hpp" + +const std::string DEFAULT_DYN_ENTRY_POINT = "run_main_graph"; + +using namespace mlir; + +namespace onnx_mlir { +namespace krnl { + +void populateKrnlToLLVMConversion(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx, + ArrayRef constantOutputs, bool singleEntryPoint); + +void populateLoweringKrnlEntryPointOpPattern(TypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx, + ArrayRef constantOutputs, bool singleEntryPoint); + +void populateLoweringKrnlFindIndexOpPattern(TypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx); + +void populateLoweringKrnlGetRefOpPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx); + +void populateLoweringKrnlGlobalOpPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx); + +void populateLoweringKrnlInstrumentOpPattern(TypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx); + +void populateLoweringKrnlMemcpyOpPattern(TypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx); + +void populateLoweringKrnlPrintOpPattern(TypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx); + +void populateLoweringKrnlPrintTensorOpPattern(TypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx); + +void populateLoweringKrnlRandomNormalOpPattern(TypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx); + +void populateLoweringKrnlStrlenOpPattern(TypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx); + +void populateLoweringKrnlStrncmpOpPattern(TypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx); + +void populateLoweringKrnlUnaryMathOpPattern(TypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx); + +void populateLoweringKrnlVectorTypeCastOpPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + MLIRContext *ctx); + +void recordEntryPointSignatures(ModuleOp &module, + SmallVectorImpl &entryPointNames, + SmallVectorImpl &inSignatures, + SmallVectorImpl &outSignatures); + +void genSignatureFunction(ModuleOp module, + const ArrayRef entryPointNames, + const ArrayRef inSignatures, + const ArrayRef outSignatures); +} // namespace krnl +} // namespace onnx_mlir diff --git a/src/Conversion/KrnlToLLVM/KrnlEntryPoint.cpp b/src/Conversion/KrnlToLLVM/KrnlEntryPoint.cpp new file mode 100644 index 000000000000..e68db512214f --- /dev/null +++ b/src/Conversion/KrnlToLLVM/KrnlEntryPoint.cpp @@ -0,0 +1,350 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------ KrnlEntryPoint.cpp - Lower KrnlEntryPointOp -------------------===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the KrnlEntryPointOp operator. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" + +#include "src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.hpp" +#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp" +#include "src/Dialect/Krnl/KrnlHelper.hpp" +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "krnl_to_llvm" + +using namespace mlir; +using namespace onnx_mlir; + +namespace onnx_mlir { +namespace krnl { + +class KrnlEntryPointOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + ArrayRef constantOutputs; + bool singleEntryPoint; + + KrnlEntryPointOpLowering(TypeConverter typeConverter, MLIRContext *ctx, + ArrayRef constantOutputs, bool singleEntryPoint) + : OpRewritePattern(ctx), + constantOutputs(constantOutputs), singleEntryPoint(singleEntryPoint) {} + + LogicalResult matchAndRewrite( + KrnlEntryPointOp op, PatternRewriter &rewriter) const override { + + auto module = op->getParentOfType(); + auto *context = module.getContext(); + const RuntimeAPIRegistry &apiRegistry = + RuntimeAPIRegistry::build(module, rewriter); + auto loc = op.getLoc(); + auto numOutputs = op->getAttrOfType( + KrnlEntryPointOp::getNumOutputsAttrName()) + .getInt(); + + auto opaquePtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); + auto int64Ty = IntegerType::get(context, 64); + + // Rewrite Krnl Entry Point Operation to an LLVM function with a dynamic + // signature. The signature is dynamic because it remains the same no matter + // what the model input/output schema look like. Such dynamic signature + // takes a opaque ptr as input, representing a ptr to a data structure + // containing a set of dynamic memrefs wrapped in a vector; similarly the + // output is also a opaque ptr to a data structure with output memrefs + // wrapped within it. + auto staticEntryPointFuncName = + op->getAttrOfType( + KrnlEntryPointOp::getEntryPointFuncAttrName()) + .getLeafReference() + .getValue(); + + // When there is only a single entry point function in a model, use + // DEFAULT_DYN_ENTRY_POINT. + std::string dynEntryPointName = "run_" + staticEntryPointFuncName.str(); + if (singleEntryPoint) + dynEntryPointName = DEFAULT_DYN_ENTRY_POINT; + rewriter.eraseOp(op); + auto dynEntryPointFuncTy = + LLVM::LLVMFunctionType::get(opaquePtrTy, {opaquePtrTy}, false); + auto dynamicEntryPointFunc = rewriter.create( + loc, dynEntryPointName, dynEntryPointFuncTy); + auto &entryPointEntryBlock = + createEntryBlock(dynEntryPointFuncTy, dynamicEntryPointFunc, loc); + rewriter.setInsertionPointToStart(&entryPointEntryBlock); + + // Based on the static entry point type signature, unpack dynamic memory + // refs to corresponding static memory refs. + auto wrappedStaticEntryPointFuncName = + "_mlir_ciface_" + staticEntryPointFuncName.lower(); + auto *staticEntryPointFunc = + module.lookupSymbol(wrappedStaticEntryPointFuncName); + assert(staticEntryPointFunc && + isa(staticEntryPointFunc) && + "entry point func must exist and be an llvm func op"); + auto staticEntryPointTy = dyn_cast(staticEntryPointFunc) + .getType() + .dyn_cast(); + + // Retrieve dynamic mem refs from wrapped input, and convert every one of + // them to static mem refs. 
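// (For orientation — a hedged sketch of the dynamic signature this wrapper
//  exposes at the C level when there is a single entry point and the default
//  name is used; OMTensorList is the opaque runtime type referred to above,
//  and the parameter name is illustrative:
//    OMTensorList *run_main_graph(OMTensorList *inputs);
//  i.e. one opaque pointer in, one opaque pointer out.)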
+ SmallVector staticInputs; + auto wrappedInput = entryPointEntryBlock.getArgument(0); + + Value omTensorPtrArr = RuntimeAPI::callApi(rewriter, loc, apiRegistry, + RuntimeAPI::API::GET_OMT_ARRAY, {wrappedInput}); + auto one = rewriter.create( + loc, int64Ty, rewriter.getI64IntegerAttr(1)); + + // Create a memref type for the return argument of the iface call + Type memRefOutPtrTy = staticEntryPointTy.getParamType(0); + Value ptrToOutMemRef = + rewriter.create(loc, memRefOutPtrTy, one, + /*alignment=*/0); + staticInputs.emplace_back(ptrToOutMemRef); + + // Start with param 1 because 0 is the return value + for (size_t i = 1; i < staticEntryPointTy.getNumParams(); i++) { + // Call API function to retrieve the i-th dynamic memref. + auto idxVal = rewriter.create( + loc, int64Ty, rewriter.getI64IntegerAttr(i - 1)); + + auto omTensorPtrAddrTy = LLVM::LLVMPointerType::get(opaquePtrTy); + auto omTensorPtrAddr = rewriter + .create(loc, omTensorPtrAddrTy, + omTensorPtrArr, ArrayRef({idxVal})) + .getResult(); + auto omTensorPtr = + rewriter.create(loc, opaquePtrTy, omTensorPtrAddr) + .getResult(); + + // Create a (static) memref type corresponding to the i-th memref input to + // the inference function on stack, and load it to memRef. + auto memRefPtrTy = staticEntryPointTy.getParamType(i); + + Value ptrToMemRef = rewriter.create(loc, memRefPtrTy, one, + /*alignment=*/0); + + // Fill in the memref underlying ptrToMemRef with information extracted + // from omTensorPtr. + fillPtrToMemRefWithOMTensor( + omTensorPtr, ptrToMemRef, rewriter, loc, apiRegistry, module); + + // ptrToMemRef will be an input to main computation graph function. + staticInputs.emplace_back(ptrToMemRef); + } + + // Call static entry point with the memref ptrs created, and get output. + rewriter.create( + loc, ArrayRef({}), wrappedStaticEntryPointFuncName, staticInputs); + auto outMemRefs = rewriter.create(loc, ptrToOutMemRef); + auto outMemRefsType = outMemRefs.getType().dyn_cast(); + + std::vector outMemRefList; + if (numOutputs == 1) { + // If only one output tensor exists, the tensor's corresponding memref + // descriptor will be returned as is. + outMemRefList.emplace_back(outMemRefs); + } else { + // Otherwise, if multiple tensors are to be returned, the returned value + // is a struct. Multiple tensors' memref descriptors are packed into the + // same struct. So we unpack them iteratively to outMemRefList. + for (int i = 0; i < numOutputs; i++) { + auto position = rewriter.getArrayAttr({rewriter.getI64IntegerAttr(i)}); + auto type = outMemRefsType.getBody()[i]; + auto extractOp = rewriter.create(loc, + /*res=*/type, + /*type=*/outMemRefs, + /*position=*/position); + outMemRefList.emplace_back(extractOp.getResult()); + } + } + + auto numOutput = rewriter.create( + loc, int64Ty, rewriter.getI64IntegerAttr(outMemRefList.size())); + + auto mallocSym = getOrInsertMalloc(rewriter, module); + // TODO(tjingrant): get pointer size from data layout. 
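// (One possible way to address the TODO above, sketched under the assumption
//  that a data layout for the module is queryable at this point — not code in
//  this patch:
//    mlir::DataLayout dl(module);
//    size_t kPtrSize = dl.getTypeSize(opaquePtrTy); // pointer size in bytes
//  For now the hard-coded 64-bit assumption below is kept.)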
+ size_t kPtrSize = 8; + auto outputOmtPtrsArraySizeInByte = rewriter.create(loc, + int64Ty, rewriter.getI64IntegerAttr(outMemRefList.size() * kPtrSize)); + auto outOmtPtrsArr = + rewriter + .create(loc, + LLVM::LLVMPointerType::get( + IntegerType::get(module.getContext(), 8)), + mallocSym, ArrayRef(outputOmtPtrsArraySizeInByte)) + .getResult(0); + outOmtPtrsArr = rewriter + .create(loc, + LLVM::LLVMPointerType::get( + LLVM::LLVMPointerType::get( + IntegerType::get(module.getContext(), 8)), + 0), + outOmtPtrsArr) + .getResult(); + + for (unsigned int i = 0; i < outMemRefList.size(); i++) { + // Get the i-th memref returned, convert to a dynamic memref and store it + // in the wrappedOutput. + + auto memRef = outMemRefList.at(i); + auto outMemRefTy = memRef.getType().dyn_cast(); + auto outMemRefRank = krnl::getRankFromMemRefType(outMemRefTy); + auto outMemRefRankVal = rewriter.create( + loc, int64Ty, rewriter.getI64IntegerAttr(outMemRefRank)); + Value outOMTensor = RuntimeAPI::callApi(rewriter, loc, apiRegistry, + RuntimeAPI::API::CREATE_OMTENSOR, {outMemRefRankVal}); + // If output is a constant tensor, OMTensor does not own it. + bool outOwning = constantOutputs[i] ? false : true; + LLVM_DEBUG(llvm::dbgs() << "Output OMTensor " << i + << " with owning = " << outOwning << "\n"); + krnl::fillOMTensorWithMemRef( + memRef, outOMTensor, outOwning, rewriter, loc, apiRegistry, module); + + auto idxVal = rewriter.create( + loc, int64Ty, rewriter.getI64IntegerAttr(i)); + + auto omTensorPtrAddrTy = LLVM::LLVMPointerType::get(opaquePtrTy); + auto omTensorPtrAddr = rewriter + .create(loc, omTensorPtrAddrTy, + outOmtPtrsArr, ArrayRef{idxVal}) + .getResult(); + + rewriter.create(loc, outOMTensor, omTensorPtrAddr); + } + + // Create wrapped output. + Value wrappedOutput = RuntimeAPI::callApi(rewriter, loc, apiRegistry, + RuntimeAPI::API::CREATE_OMTENSOR_LIST, {outOmtPtrsArr, numOutput, one}); + + // Return wrapped output. + rewriter.create( + loc, SmallVector(1, wrappedOutput)); + return success(); + } + +private: + // Helper function to insert an entry block to LLVM function. + // (TODO): upstream this to MLIR. 
+ Block &createEntryBlock(Type &dynEntryPoint, + LLVM::LLVMFuncOp &dynamicEntryPointFunc, Location &loc) const { + // Add entry block: + auto *entryPointEntryBlock = new Block(); + auto dynEntryPointFuncType = dynEntryPoint.cast(); + dynamicEntryPointFunc.push_back(entryPointEntryBlock); + llvm::SmallVector argTypes; + for (size_t i = 0; i < dynEntryPointFuncType.getNumParams(); i++) + argTypes.emplace_back(dynEntryPointFuncType.getParamType(i)); + auto argLocs = llvm::SmallVector( + dynEntryPointFuncType.getNumParams(), loc); + entryPointEntryBlock->addArguments(argTypes, argLocs); + return *entryPointEntryBlock; + } + + void fillPtrToMemRefWithOMTensor(Value &rtMemRef, Value &ptrToMemRef, + PatternRewriter &rewriter, const Location &loc, + const RuntimeAPIRegistry &apiRegistry, ModuleOp &module) const { + auto *context = module.getContext(); + auto memRefPtrTy = ptrToMemRef.getType().dyn_cast(); + auto memRefTy = memRefPtrTy.getElementType(); + auto int64Ty = IntegerType::get(context, 64); + + Value memRef = rewriter.create(loc, memRefTy); + + // Set dataPtr and alignedDataPtr; + Value dataPtr = RuntimeAPI::callApi( + rewriter, loc, apiRegistry, RuntimeAPI::API::GET_DATA, {rtMemRef}); + dataPtr = rewriter.create( + loc, memRefTy.cast().getBody()[0], dataPtr); + memRef = rewriter.create(loc, memRefTy, memRef, + dataPtr, rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)})); + memRef = rewriter.create(loc, memRefTy, memRef, + dataPtr, rewriter.getArrayAttr({rewriter.getI64IntegerAttr(1)})); + + // Use zero offset now. + auto zero = rewriter.create( + loc, int64Ty, rewriter.getI64IntegerAttr(0)); + memRef = rewriter.create(loc, memRefTy, memRef, zero, + rewriter.getArrayAttr({rewriter.getI64IntegerAttr(2)})); + + // Get rank, sizes array ptr and strides array ptr. + auto rank = + krnl::getRankFromMemRefType(memRefTy.cast()); + Value sizesArrayPtr = RuntimeAPI::callApi(rewriter, loc, apiRegistry, + RuntimeAPI::API::GET_DATA_SHAPE, {rtMemRef}); + Value stridesArrayPtr = RuntimeAPI::callApi(rewriter, loc, apiRegistry, + RuntimeAPI::API::GET_DATA_STRIDES, {rtMemRef}); + + for (decltype(rank) i = 0; i < rank; i++) { + auto dimIdx = rewriter.create( + loc, int64Ty, rewriter.getI64IntegerAttr(i)); + + // Insert size of the dimension. + auto dimSizePtr = + rewriter.create(loc, LLVM::LLVMPointerType::get(int64Ty), + sizesArrayPtr, ArrayRef({dimIdx})); + auto dimSize = rewriter.create(loc, int64Ty, dimSizePtr); + memRef = rewriter.create(loc, memRefTy, memRef, + dimSize, + rewriter.getArrayAttr( + {rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)})); + + // Insert stride of the dimension. + auto dimStridePtr = + rewriter.create(loc, LLVM::LLVMPointerType::get(int64Ty), + stridesArrayPtr, ArrayRef({dimIdx})); + auto dimStride = + rewriter.create(loc, int64Ty, dimStridePtr); + memRef = rewriter.create(loc, memRefTy, memRef, + dimStride, + rewriter.getArrayAttr( + {rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)})); + } + + rewriter.create(loc, memRef, ptrToMemRef); + } + +private: + FlatSymbolRefAttr getOrInsertMalloc( + PatternRewriter &rewriter, ModuleOp module) const { + // Insert the malloc/aligned_alloc declaration if it is not already present. 
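// (The declaration emitted below corresponds, roughly, to
//    declare i8* @malloc(i64)
//  in the final LLVM IR, with i64 standing in for the converter's index type
//  on typical 64-bit targets.)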
+ auto allocFunc = module.lookupSymbol("malloc"); + auto ctx = rewriter.getContext(); + LLVMTypeConverter converter(ctx); + if (!allocFunc) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + SmallVector callArgTypes = {converter.getIndexType()}; + // aligned_alloc(size_t alignment, size_t size) + auto voidPtrType = LLVM::LLVMPointerType::get( + IntegerType::get(&converter.getContext(), 8)); + allocFunc = + rewriter.create(rewriter.getUnknownLoc(), "malloc", + LLVM::LLVMFunctionType::get(voidPtrType, callArgTypes, + /*isVarArg=*/false)); + } + return SymbolRefAttr::get(ctx, "malloc"); + } +}; + +void populateLoweringKrnlEntryPointOpPattern(TypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx, + ArrayRef constantOutputs, bool singleEntryPoint) { + patterns.insert( + typeConverter, ctx, constantOutputs, singleEntryPoint); +} + +} // namespace krnl +} // namespace onnx_mlir diff --git a/src/Conversion/KrnlToLLVM/KrnlFindIndex.cpp b/src/Conversion/KrnlToLLVM/KrnlFindIndex.cpp new file mode 100644 index 000000000000..b0e64a13fb60 --- /dev/null +++ b/src/Conversion/KrnlToLLVM/KrnlFindIndex.cpp @@ -0,0 +1,134 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------------ KrnlFindIndex.cpp - Lowering KrnlFindIndexOp ------------===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the KrnlFindIndexOp operator. +// +//===----------------------------------------------------------------------===// + +#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp" +#include "src/Dialect/Krnl/KrnlOps.hpp" + +using namespace mlir; +using namespace onnx_mlir; + +namespace onnx_mlir { +namespace krnl { + +class KrnlFindIndexOpLowering : public ConversionPattern { +public: + explicit KrnlFindIndexOpLowering( + TypeConverter &typeConverter, MLIRContext *context) + : ConversionPattern( + typeConverter, KrnlFindIndexOp::getOperationName(), 1, context) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto findIndexOp = cast(op); + MLIRContext *context = findIndexOp.getContext(); + Location loc = findIndexOp.getLoc(); + KrnlFindIndexOpAdaptor operandAdaptor(operands); + + // Get a symbol reference to the runtime function to use, creating one if + // necessary. + ModuleOp module = findIndexOp->getParentOfType(); + FlatSymbolRefAttr findIndexRef = + getOrInsertFindIndex(rewriter, module, findIndexOp.input().getType()); + + // Select the value to pass to as the first argument based on the operator + // input type. + Value firstOperand; + TypeSwitch(findIndexOp.input().getType()) + .Case([&](IntegerType type) { + assert(type.getWidth() == 64 && "expecting an i64 type"); + firstOperand = operandAdaptor.input(); + }) + .Case([&](StringType type) { + Type ptrType = operandAdaptor.input() + .getType() + .cast() + .getBody()[1]; + firstOperand = rewriter.create(loc, ptrType, + operandAdaptor.input(), rewriter.getI64ArrayAttr(1)); + }) + .Default([](Type) { llvm_unreachable("unexpected inputType"); }); + + Type GType = + operandAdaptor.G().getType().cast().getBody()[1]; + Type VType = + operandAdaptor.V().getType().cast().getBody()[1]; + + // Remaining operands. 
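// (Hedged sketch of the runtime entry points this lowering calls into,
//  matching the `i64 ([i8*|i64], i32*, i32*, i32)` signature built in
//  getOrInsertFindIndex below; parameter names are illustrative:
//    int64_t find_index_i64(int64_t value, int32_t *G, int32_t *V, int32_t len);
//    int64_t find_index_str(const char *str, int32_t *G, int32_t *V, int32_t len);
//  )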
+ Value extractedGPtr = rewriter.create( + loc, GType, operandAdaptor.G(), rewriter.getI64ArrayAttr(1)); + Value extractedVPtr = rewriter.create( + loc, VType, operandAdaptor.V(), rewriter.getI64ArrayAttr(1)); + Value length = operandAdaptor.len(); + + // Generate the call to the runtime function. + Type retType = IntegerType::get(context, 64); + auto funcCall = rewriter.create(loc, findIndexRef, retType, + ArrayRef({firstOperand, extractedGPtr, extractedVPtr, length})); + + rewriter.replaceOp(op, funcCall.getResults()[0]); + return success(); + } + +private: + /// Return a symbol reference to the appropriate 'find_index_*' runtime + /// function, inserting it into the module if necessary. + static FlatSymbolRefAttr getOrInsertFindIndex( + PatternRewriter &rewriter, ModuleOp module, Type inputType) { + MLIRContext *ctx = module.getContext(); + Type i8Type = IntegerType::get(ctx, 8); + Type i32Type = IntegerType::get(ctx, 32); + Type i64Type = IntegerType::get(ctx, 64); + Type i8PtrType = LLVM::LLVMPointerType::get(i8Type); + Type i32PtrType = LLVM::LLVMPointerType::get(i32Type); + + // Select the runtime function to use based on the input type. + std::string funcName = "find_index_"; + Type firstArgType; + TypeSwitch(inputType) + .Case([&](IntegerType type) { + assert(type.getWidth() == 64 && "expecting an i64 type"); + funcName += "i64"; + firstArgType = i64Type; + }) + .Case([&](StringType type) { + funcName += "str"; + firstArgType = i8PtrType; + }) + .Default([](Type) { llvm_unreachable("unexpected type"); }); + + Optional optFuncDecl = + krnl::getFunctionDeclaration(module, funcName); + if (optFuncDecl.hasValue()) + return optFuncDecl.getValue(); + + // Create 'find_index_*' signature: `i64 ([i8*|i64], i32*, i32*, i32)` + Type fnType = LLVM::LLVMFunctionType::get(i64Type, + ArrayRef({firstArgType, i32PtrType, i32PtrType, i32Type}), false); + + // Insert the function declaration the module. + PatternRewriter::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + rewriter.create(module.getLoc(), funcName, fnType); + + return SymbolRefAttr::get(ctx, funcName); + } +}; + +void populateLoweringKrnlFindIndexOpPattern(TypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace krnl +} // namespace onnx_mlir diff --git a/src/Conversion/KrnlToLLVM/KrnlGetRef.cpp b/src/Conversion/KrnlToLLVM/KrnlGetRef.cpp new file mode 100644 index 000000000000..9933a15e29c9 --- /dev/null +++ b/src/Conversion/KrnlToLLVM/KrnlGetRef.cpp @@ -0,0 +1,163 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------ KrnlGetRefOp.cpp - Lower KrnlGetRefOp -------------------------===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the KrnlGetRefOp operator. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" + +#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp" +#include "src/Dialect/Krnl/KrnlHelper.hpp" +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "src/Support/KrnlSupport.hpp" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "krnl_to_llvm" + +using namespace mlir; +using namespace onnx_mlir; + +namespace onnx_mlir { +namespace krnl { + +class KrnlGetRefOpLowering : public ConvertToLLVMPattern { +public: + using ConvertToLLVMPattern::createIndexConstant; + + explicit KrnlGetRefOpLowering( + LLVMTypeConverter &typeConverter, MLIRContext *context) + : ConvertToLLVMPattern( + KrnlGetRefOp::getOperationName(), context, typeConverter) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + KrnlGetRefOpAdaptor operandAdaptor(operands); + + // This is the type of the krnl.getref output. This type is used + // for the type of the internal MemRef. + auto type = op->getResult(0).getType(); + auto memRefTy = type.cast(); + + // auto llvmMemRefType = typeConverter->convertType(type).cast(); + auto outputElementType = + typeConverter->convertType(memRefTy.getElementType()); + + // This is the start of the memory pool containing the output MemRef. + Type memPoolType = operandAdaptor.mempool() + .getType() + .cast() + .getBody()[1]; + Value alignedMemPoolBase = rewriter.create(loc, + memPoolType, operandAdaptor.mempool(), rewriter.getI64ArrayAttr(1)); + + // Get pointer using the offset. + auto offset = operandAdaptor.offset(); + auto llvmMemPoolType = typeConverter->convertType(memPoolType).cast(); + auto outputMemPoolTypePtrAlloc = rewriter.create( + loc, llvmMemPoolType, alignedMemPoolBase, ArrayRef({offset})); + + // Bitcast to output MemRef type i.e. from i8* to the element type + // of the output MemRef. + auto llvmOutputElementType = outputElementType.cast(); + Value outputTypedPtrAlloc = rewriter.create(loc, + LLVM::LLVMPointerType::get(llvmOutputElementType), + outputMemPoolTypePtrAlloc); + + // Handle the static case. + if (hasAllConstantDimensions(memRefTy)) { + // Create llvm MemRef from original MemRef and fill the data pointers. + auto llvmMemRef = MemRefDescriptor::fromStaticShape( + rewriter, loc, *getTypeConverter(), memRefTy, outputTypedPtrAlloc); + + rewriter.replaceOp(op, {llvmMemRef}); + return success(); + } + + // Handle the dynamic case. + + // Compute strides and offset based on MemRef type. + int64_t alignmentOffset; + SmallVector strides; + auto successStrides = + getStridesAndOffset(memRefTy, strides, alignmentOffset); + (void)successStrides; + assert(succeeded(successStrides) && "unexpected non-strided memref"); + + // Create the memRef descriptor. + auto structType = typeConverter->convertType(memRefTy); + auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType); + + // Allocated pointer, used for malloc/free. + memRefDescriptor.setAllocatedPtr(rewriter, loc, outputTypedPtrAlloc); + + // Actual aligned pointer to payload. + // TODO: support aligned MemRefs. + memRefDescriptor.setAlignedPtr(rewriter, loc, outputTypedPtrAlloc); + + // Offset in aligned pointer. + // TODO: support non-zero here in the aligned case. 
+ memRefDescriptor.setOffset( + rewriter, loc, createIndexConstant(rewriter, loc, 0)); + + if (memRefTy.getRank() != 0) { + // Prepare sizes. + SmallVector sizes; + sizes.reserve(memRefTy.getRank()); + unsigned i = 0; + for (int64_t s : memRefTy.getShape()) + sizes.push_back(s == ShapedType::kDynamicSize + ? operands[2 + i++] + : createIndexConstant(rewriter, loc, s)); + + // Store all sizes in the descriptor. Only dynamic sizes are passed in as + // operands to AllocOp. + Value runningStride = nullptr; + auto nStrides = strides.size(); + SmallVector strideValues(nStrides, nullptr); + for (unsigned i = 0; i < nStrides; ++i) { + int64_t index = nStrides - 1 - i; + if (strides[index] == MemRefType::getDynamicStrideOrOffset()) + // Identity layout map is enforced in the match function, so we + // compute: + // `runningStride *= sizes[index + 1]` + runningStride = runningStride ? rewriter.create(loc, + runningStride, sizes[index + 1]) + : createIndexConstant(rewriter, loc, 1); + else + runningStride = createIndexConstant(rewriter, loc, strides[index]); + strideValues[index] = runningStride; + } + // Fill size and stride descriptors in memref. + for (auto indexedSize : llvm::enumerate(sizes)) { + int64_t index = indexedSize.index(); + memRefDescriptor.setSize(rewriter, loc, index, indexedSize.value()); + memRefDescriptor.setStride(rewriter, loc, index, strideValues[index]); + } + } + + rewriter.replaceOp(op, {memRefDescriptor}); + return success(); + } +}; + +void populateLoweringKrnlGetRefOpPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace krnl +} // namespace onnx_mlir diff --git a/src/Conversion/KrnlToLLVM/KrnlGlobal.cpp b/src/Conversion/KrnlToLLVM/KrnlGlobal.cpp new file mode 100644 index 000000000000..478120e897e0 --- /dev/null +++ b/src/Conversion/KrnlToLLVM/KrnlGlobal.cpp @@ -0,0 +1,274 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------ KrnlGlobal.cpp - Lower KrnlGlobalOp ---------------------------===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the KrnlGlobalOp operator. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" + +#include "onnx/onnx_pb.h" + +#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp" +#include "src/Conversion/KrnlToLLVM/RuntimeAPI.hpp" +#include "src/Dialect/Krnl/KrnlHelper.hpp" +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "src/Support/KrnlSupport.hpp" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "krnl_to_llvm" + +using namespace mlir; +using namespace onnx_mlir; + +namespace onnx_mlir { +namespace krnl { + +class KrnlGlobalOpLowering : public ConvertToLLVMPattern { +public: + explicit KrnlGlobalOpLowering( + LLVMTypeConverter &typeConverter, MLIRContext *context) + : ConvertToLLVMPattern( + KrnlGlobalOp::getOperationName(), context, typeConverter) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto krnlGlobalOp = llvm::dyn_cast(op); + + // The element type of the array. 
+ const auto type = op->getResult(0).getType(); + const auto memRefTy = type.cast(); + const auto constantElementType = + typeConverter->convertType(memRefTy.getElementType()); + auto globalType = constantElementType; + + // The llvm type of the global (example: [2 x [8 x float]]). + const auto shape = (krnlGlobalOp.shape()).dyn_cast(); + if (shape.empty()) + globalType = LLVM::LLVMArrayType::get(globalType.cast(), 1); + else { + for (int i = shape.size() - 1; i >= 0; i--) + globalType = LLVM::LLVMArrayType::get( + globalType.cast(), ArrayAttrIntVal(shape, i)); + } + + // Create the global at the entry of the module. + assert(krnlGlobalOp.value().hasValue() && + "Krnl Global must always have a value"); + auto value = krnlGlobalOp.value().getValue(); + LLVM::GlobalOp global; + TypeSwitch(value) + .Case([&](OpaqueElementsAttr attr) { + global = lowerOpaqueConstant(krnlGlobalOp, globalType, rewriter); + }) + .Case([&](DenseElementsAttr attr) { + global = lowerDenseConstant(krnlGlobalOp, globalType, rewriter); + }) + .Default([&](Attribute attr) { + llvm_unreachable("Unsupported attribute type"); + }); + + // Set the global alignment based on the alignment attribute if it exists, + // otherwise use the module datalayout info. + krnl::setAlignment(global, krnlGlobalOp.alignmentAttr(), + krnlGlobalOp->getParentOfType(), rewriter, + *getTypeConverter()); + + // Prepare data to be inserted into a MemRefDescriptor (a struct). + Value globalOpAddr = + rewriter.create(krnlGlobalOp.getLoc(), global); + MemRefDescriptor memRefDescr = createMemRefDescriptor( + globalOpAddr, memRefTy, krnlGlobalOp.getLoc(), rewriter); + + rewriter.replaceOp(op, {memRefDescr}); + + return success(); + } + +private: + static int64_t ArrayAttrIntVal(ArrayAttr a, int i) { + return (a.getValue()[i]).cast().getInt(); + } + + // LLVM::GlobalOp does not support OpaqueElementsAttr. + // Both StringAttr and OpaqueElementsAttr use StringRef for internal data + // array. Thus, it looks safe to use StringAtrr instead of + // OpaqueElementsAttr. + LLVM::GlobalOp lowerOpaqueConstant(KrnlGlobalOp &krnlGlobalOp, + Type globalType, ConversionPatternRewriter &rewriter) const { + assert(krnlGlobalOp.value().hasValue() && + "Expecting KrnlGlobalOp with a valid value"); + assert(krnlGlobalOp.value().getValue().isa() && + "Expecting a global with an opaque elements attribute"); + + MLIRContext *context = krnlGlobalOp.getContext(); + Location loc = krnlGlobalOp.getLoc(); + ModuleOp module = krnlGlobalOp->getParentOfType(); + + OpBuilder::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + + StringRef data = + krnlGlobalOp.value().getValue().cast().getValue(); + // Check data size. 
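+ // The opaque attribute stores the constant as raw bytes, so its length must
+ // match numElements * eltSizeInBytes as derived from the memref shape (see
+ // computeSizeInBytes below); the assert guards against a mismatched
+ // serialization.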
+ int64_t sizeInBytes = computeSizeInBytes(krnlGlobalOp); + assert(((int64_t)data.size() == sizeInBytes) && "Data size mismatch."); + + StringAttr llvmStringAttr = StringAttr::get(context, data); + auto llvmArrayI8Ty = + LLVM::LLVMArrayType::get(IntegerType::get(context, 8), sizeInBytes); + LLVM::GlobalOp global = rewriter.create(loc, llvmArrayI8Ty, + /*isConstant=*/true, LLVM::Linkage::Internal, krnlGlobalOp.name(), + llvmStringAttr); + + LLVM_DEBUG(llvm::dbgs() << "global: " << global << "\n";); + return global; + } + + LLVM::GlobalOp lowerDenseConstant(KrnlGlobalOp &krnlGlobalOp, Type globalType, + ConversionPatternRewriter &rewriter) const { + assert(krnlGlobalOp.value().hasValue() && + "Expecting KrnlGlobalOp with a valid value"); + assert(krnlGlobalOp.value().getValue().isa() && + "Expecting a global with an dense elements attribute"); + + MLIRContext *context = krnlGlobalOp.getContext(); + Location loc = krnlGlobalOp.getLoc(); + ModuleOp module = krnlGlobalOp->getParentOfType(); + + OpBuilder::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + + DenseElementsAttr denseAttr = + krnlGlobalOp.value().getValue().cast(); + + int64_t sizeInBytes = computeSizeInBytes(krnlGlobalOp); + LLVM::GlobalOp global; + if ((!denseAttr.isSplat()) && (sizeInBytes > 1024)) { + ArrayRef rawData = denseAttr.getRawData(); + assert(((int64_t)rawData.size() == sizeInBytes) && "Data size mismatch."); + + StringRef data(rawData.data(), rawData.size()); + StringAttr llvmStringAttr = StringAttr::get(context, data); + auto llvmArrayI8Ty = + LLVM::LLVMArrayType::get(IntegerType::get(context, 8), sizeInBytes); + global = rewriter.create(loc, llvmArrayI8Ty, + /*isConstant=*/true, LLVM::Linkage::Internal, krnlGlobalOp.name(), + llvmStringAttr); + } else { + if (denseAttr.getElementType().isa()) + global = lowerStringLiteral(krnlGlobalOp, globalType, rewriter); + else + global = rewriter.create(loc, globalType, + /*isConstant=*/true, LLVM::Linkage::Internal, krnlGlobalOp.name(), + krnlGlobalOp.value().getValue()); + } + + // LLVM_DEBUG(llvm::dbgs() << "global: " << global << "\n";); + return global; + } + + int64_t computeSizeInBytes(KrnlGlobalOp &krnlGlobalOp) const { + // Compute total number of elements. + const auto shape = (krnlGlobalOp.shape()).dyn_cast(); + int64_t numElements = 1; + for (unsigned int i = 0; i < shape.size(); ++i) + numElements *= ArrayAttrIntVal(shape, i); + + const auto type = krnlGlobalOp.getResult().getType(); + const auto memRefTy = type.cast(); + + return numElements * getMemRefEltSizeInBytes(memRefTy); + } + + // Store the given address into a MemRefDescriptor (a struct). + MemRefDescriptor createMemRefDescriptor(Value address, MemRefType memRefType, + Location loc, OpBuilder &builder) const { + Type elementType = memRefType.getElementType(); + LLVMTypeConverter &typeConverter = *getTypeConverter(); + Type llvmElemType = typeConverter.convertType(elementType); + + // Prepare data to be inserted into a MemRefDescriptor (a struct). + auto ptrType = LLVM::LLVMPointerType::get(llvmElemType); + // Bitcast the address to the MemRefType's element type. + Value bitCastOp = builder.create(loc, ptrType, address); + // Create llvm MemRef from original MemRef and fill the data pointers. + return MemRefDescriptor::fromStaticShape( + builder, loc, typeConverter, memRefType, bitCastOp); + } + + // Generate a global string for each krnlGlobalOp string value, and store + // the address of the global strings into an array. Return the array address. 
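+ // As an illustrative sketch (symbol names invented), a two-string constant
+ // lowers to roughly:
+ //   llvm.mlir.global internal constant @str0("cat")
+ //   llvm.mlir.global internal constant @str1("dog")
+ //   llvm.mlir.global internal constant @strings() : !llvm.array<2 x ptr<i8>>
+ // where @strings carries an initializer region that takes the address of each
+ // string global, inserts it into the array value, and ends with llvm.return.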
+ LLVM::GlobalOp lowerStringLiteral( + KrnlGlobalOp &krnlGlobalOp, Type globalType, OpBuilder &builder) const { + assert(krnlGlobalOp.value().getValue().isa() && + "Expecting a dense value"); + + Location loc = krnlGlobalOp.getLoc(); + ModuleOp module = krnlGlobalOp->getParentOfType(); + DenseElementsAttr denseAttr = + krnlGlobalOp.value().getValue().cast(); + + Type i8Type = IntegerType::get(builder.getContext(), 8); + Type i8PtrType = LLVM::LLVMPointerType::get(i8Type); + + int64_t numStrings = denseAttr.getValues().size(); + if (numStrings == 1) { + StringRef str = *denseAttr.getValues().begin(); + return krnl::getOrCreateGlobalString( + str, loc, builder, module, getTypeConverter()); + } + + // Generate LLVM GlobalOps for each string in the KrnlGlobalOp dense + // attribute. + SmallVector globalOps; + for (StringRef str : denseAttr.getValues()) { + LLVM::GlobalOp globalOp = krnl::getOrCreateGlobalString( + str, loc, builder, module, getTypeConverter()); + globalOps.push_back(globalOp); + } + + // Generate an LLVM GlobalOps with an initializer region containing one + // block. + auto arrayType = LLVM::LLVMArrayType::get(i8PtrType, globalOps.size()); + auto global = builder.create(loc, arrayType, + /*isConstant=*/true, LLVM::Linkage::Internal, krnlGlobalOp.name(), + Attribute()); + Region ®ion = global.getInitializerRegion(); + Block *block = builder.createBlock(®ion); + + // Initialize an array with the addresses of the global strings. + builder.setInsertionPoint(block, block->begin()); + Value array = builder.create(loc, arrayType); + + int32_t index = 0; + Value lastValue = array; + for (const LLVM::GlobalOp &globalOp : globalOps) { + Value strAddr = krnl::getPtrToGlobalString(globalOp, loc, builder); + lastValue = builder.create(loc, arrayType, lastValue, + strAddr, builder.getArrayAttr({builder.getIndexAttr(index++)})); + } + + builder.create(loc, ArrayRef({lastValue})); + return global; + } +}; + +void populateLoweringKrnlGlobalOpPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace krnl +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/Conversion/KrnlToLLVM/KrnlInstrument.cpp b/src/Conversion/KrnlToLLVM/KrnlInstrument.cpp new file mode 100644 index 000000000000..51a28d71a3f6 --- /dev/null +++ b/src/Conversion/KrnlToLLVM/KrnlInstrument.cpp @@ -0,0 +1,95 @@ + +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------ KrnlInstrument.cpp - Lower KrnlInstrumentOp -------------------===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the KrnlInstrumentOp operator. 
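+// Each krnl.instrument op is rewritten into a call to the OMInstrumentPoint
+// runtime function, passing the instrumented op's ID and tag as two i64 values.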
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" + +#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp" +#include "src/Dialect/Krnl/KrnlHelper.hpp" +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "krnl_to_llvm" + +using namespace mlir; +using namespace onnx_mlir; + +namespace onnx_mlir { +namespace krnl { + +class KrnlInstrumentOpLowering : public ConversionPattern { +public: + explicit KrnlInstrumentOpLowering( + TypeConverter &typeConverter, MLIRContext *context) + : ConversionPattern( + typeConverter, KrnlInstrumentOp::getOperationName(), 1, context) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto *context = op->getContext(); + KrnlInstrumentOpAdaptor operandAdaptor(operands); + auto loc = op->getLoc(); + KrnlInstrumentOp instrumentOp = llvm::dyn_cast(op); + + // Get a symbol reference to the memcpy function, inserting it if necessary. + ModuleOp parentModule = op->getParentOfType(); + auto instrumentRef = getOrInsertInstrument(rewriter, parentModule); + + Value nodeName = + rewriter.create(loc, IntegerType::get(context, 64), + rewriter.getIntegerAttr( + rewriter.getIntegerType(64), instrumentOp.opID())); + Value tag = + rewriter.create(loc, IntegerType::get(context, 64), + rewriter.getIntegerAttr( + rewriter.getIntegerType(64), instrumentOp.tag())); + + rewriter.create(loc, instrumentRef, ArrayRef({}), + ArrayRef({nodeName, tag})); + + rewriter.eraseOp(op); + return success(); + } + +private: + // Create a function declaration for OMInstrumentPoint, the signature is: + // `void (i64, i64)` + FlatSymbolRefAttr getOrInsertInstrument( + PatternRewriter &rewriter, ModuleOp module) const { + auto *context = module.getContext(); + std::string funcName("OMInstrumentPoint"); + if (module.lookupSymbol(funcName)) + return SymbolRefAttr::get(context, funcName); + auto llvmVoidTy = LLVM::LLVMVoidType::get(context); + auto llvmI64Ty = IntegerType::get(context, 64); + auto llvmFnType = LLVM::LLVMFunctionType::get( + llvmVoidTy, ArrayRef({llvmI64Ty, llvmI64Ty}), false); + + PatternRewriter::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + rewriter.create(module.getLoc(), funcName, llvmFnType); + return SymbolRefAttr::get(context, funcName); + } +}; + +void populateLoweringKrnlInstrumentOpPattern(TypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace krnl +} // namespace onnx_mlir diff --git a/src/Conversion/KrnlToLLVM/KrnlMemcpy.cpp b/src/Conversion/KrnlToLLVM/KrnlMemcpy.cpp new file mode 100644 index 000000000000..c57ad8994890 --- /dev/null +++ b/src/Conversion/KrnlToLLVM/KrnlMemcpy.cpp @@ -0,0 +1,123 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------ KrnlMemcpy.cpp - Lower KrnlMemcpyOp ---------------------------===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the KrnlMemcpyOp operator. 
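+// The op is rewritten into a call to the llvm.memcpy.p0i8.p0i8.i64 intrinsic:
+// the aligned source and destination pointers are extracted from the MemRef
+// descriptors, bitcast to i8*, and passed along with the byte count and a
+// constant-false "isvolatile" flag.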
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" + +#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp" +#include "src/Dialect/Krnl/KrnlHelper.hpp" +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "krnl_to_llvm" + +using namespace mlir; +using namespace onnx_mlir; + +namespace onnx_mlir { +namespace krnl { + +class KrnlMemcpyOpLowering : public ConversionPattern { +public: + explicit KrnlMemcpyOpLowering( + TypeConverter &typeConverter, MLIRContext *context) + : ConversionPattern( + typeConverter, KrnlMemcpyOp::getOperationName(), 1, context) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto *context = op->getContext(); + KrnlMemcpyOpAdaptor operandAdaptor(operands); + auto loc = op->getLoc(); + + // Get a symbol reference to the memcpy function, inserting it if necessary. + ModuleOp parentModule = op->getParentOfType(); + auto memcpyRef = getOrInsertMemcpy(rewriter, parentModule); + + // First operand. + Type dstType = operandAdaptor.dest() + .getType() + .cast() + .getBody()[1]; + Value alignedDstMemory = rewriter.create( + loc, dstType, operandAdaptor.dest(), rewriter.getI64ArrayAttr(1)); + Value alignedInt8PtrDstMemory = rewriter.create(loc, + LLVM::LLVMPointerType::get(IntegerType::get(context, 8)), + alignedDstMemory); + + // Second operand. + Type srcType = operandAdaptor.src() + .getType() + .cast() + .getBody()[1]; + Value alignedSrcMemory = rewriter.create( + loc, srcType, operandAdaptor.src(), rewriter.getI64ArrayAttr(1)); + Value alignedInt8PtrSrcMemory = rewriter.create(loc, + LLVM::LLVMPointerType::get(IntegerType::get(context, 8)), + alignedSrcMemory); + + // Size. + Value int64Size = rewriter.create( + loc, IntegerType::get(context, 64), operandAdaptor.size()); + + // Is volatile (set to false). + Value isVolatile = + rewriter.create(loc, IntegerType::get(context, 1), + rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0)); + + // Memcpy call + rewriter.create(loc, memcpyRef, ArrayRef({}), + ArrayRef({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, + int64Size, isVolatile})); + + rewriter.eraseOp(op); + return success(); + } + +private: + /// Return a symbol reference to the memcpy function, inserting it into the + /// module if necessary. + FlatSymbolRefAttr getOrInsertMemcpy( + PatternRewriter &rewriter, ModuleOp module) const { + auto *context = module.getContext(); + if (module.lookupSymbol("llvm.memcpy.p0i8.p0i8.i64")) + return SymbolRefAttr::get(context, "llvm.memcpy.p0i8.p0i8.i64"); + // Create a function declaration for memcpy, the signature is: + // * `void (i8*, i8* , i64, i1)` + auto llvmVoidTy = LLVM::LLVMVoidType::get(context); + auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); + auto llvmI64Ty = IntegerType::get(context, 64); + auto llvmI1Ty = IntegerType::get(context, 1); + auto llvmFnType = LLVM::LLVMFunctionType::get(llvmVoidTy, + ArrayRef({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}), + false); + + // Insert the memcpy function into the body of the parent module. 
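+ // The declaration is created once at module scope and reused by every
+ // lowered krnl.memcpy; illustratively, it reads:
+ //   llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm.ptr<i8>, !llvm.ptr<i8>, i64, i1)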
+ PatternRewriter::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + rewriter.create( + module.getLoc(), "llvm.memcpy.p0i8.p0i8.i64", llvmFnType); + return SymbolRefAttr::get(context, "llvm.memcpy.p0i8.p0i8.i64"); + } +}; + +void populateLoweringKrnlMemcpyOpPattern(TypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace krnl +} // namespace onnx_mlir diff --git a/src/Conversion/KrnlToLLVM/KrnlPrint.cpp b/src/Conversion/KrnlToLLVM/KrnlPrint.cpp index 587d94b91749..d9f2cd15be76 100644 --- a/src/Conversion/KrnlToLLVM/KrnlPrint.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlPrint.cpp @@ -12,73 +12,82 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Conversion/LLVMCommon/Pattern.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" - -#include "src/Conversion/KrnlToLLVM/KrnlPrint.hpp" -#include "src/Conversion/KrnlToLLVM/KrnlToLLVM.hpp" #include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp" #include "src/Dialect/Krnl/KrnlHelper.hpp" #include "src/Dialect/Krnl/KrnlOps.hpp" - #include "llvm/Support/Debug.h" #define DEBUG_TYPE "krnl_to_llvm" using namespace mlir; +using namespace onnx_mlir; namespace onnx_mlir { +namespace krnl { + +class KrnlPrintOpLowering : public ConversionPattern { +public: + explicit KrnlPrintOpLowering( + TypeConverter &typeConverter, MLIRContext *context) + : ConversionPattern( + typeConverter, KrnlPrintOp::getOperationName(), 1, context) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto printOp = cast(op); + Location loc = printOp.getLoc(); + KrnlPrintOpAdaptor operandAdaptor(operands); + + Value input = operandAdaptor.input(); + StringRef format = printOp.format(); + ModuleOp module = printOp->getParentOfType(); + + // Get a symbol reference to the runtime function to use, creating one if + // necessary. + auto printfFuncRef = getOrInsertPrintf(rewriter, module); + + // Printf call. + LLVM::GlobalOp formatSpec = getOrCreateGlobalString(format, loc, rewriter, + module, static_cast(getTypeConverter())); + Value formatSpecPtr = getPtrToGlobalString(formatSpec, loc, rewriter); + + if (input) + rewriter.create(loc, printfFuncRef, ArrayRef({}), + ArrayRef({formatSpecPtr, input})); + else + rewriter.create(loc, printfFuncRef, ArrayRef({}), + ArrayRef({formatSpecPtr})); + + rewriter.eraseOp(op); + return success(); + } -LogicalResult KrnlPrintOpLowering::matchAndRewrite(Operation *op, - ArrayRef operands, ConversionPatternRewriter &rewriter) const { - auto printOp = cast(op); - Location loc = printOp.getLoc(); - KrnlPrintOpAdaptor operandAdaptor(operands); - - Value input = operandAdaptor.input(); - StringRef format = printOp.format(); - ModuleOp module = printOp->getParentOfType(); - - // Get a symbol reference to the runtime function to use, creating one if - // necessary. - auto printfFuncRef = getOrInsertPrintf(rewriter, module); - - // Printf call. 
- LLVM::GlobalOp formatSpec = getOrCreateGlobalString(format, loc, rewriter, - module, static_cast(getTypeConverter())); - Value formatSpecPtr = getPtrToGlobalString(formatSpec, loc, rewriter); - - if (input) - rewriter.create(loc, printfFuncRef, ArrayRef({}), - ArrayRef({formatSpecPtr, input})); - else - rewriter.create(loc, printfFuncRef, ArrayRef({}), - ArrayRef({formatSpecPtr})); - - rewriter.eraseOp(op); - return success(); -} - -FlatSymbolRefAttr KrnlPrintOpLowering::getOrInsertPrintf( - PatternRewriter &rewriter, ModuleOp module) { - // Insert the printf declaration if it is not already present. - auto printfFunc = module.lookupSymbol("printf"); - MLIRContext *ctx = rewriter.getContext(); - - if (!printfFunc) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(module.getBody()); - auto voidType = LLVM::LLVMVoidType::get(ctx); - Type i8Type = IntegerType::get(ctx, 8); - Type i8PtrType = LLVM::LLVMPointerType::get(i8Type); - printfFunc = - rewriter.create(rewriter.getUnknownLoc(), "printf", - LLVM::LLVMFunctionType::get(voidType, i8PtrType, - /*isVarArg=*/true)); +private: + static FlatSymbolRefAttr getOrInsertPrintf( + PatternRewriter &rewriter, ModuleOp module) { + // Insert the printf declaration if it is not already present. + auto printfFunc = module.lookupSymbol("printf"); + MLIRContext *ctx = rewriter.getContext(); + + if (!printfFunc) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + auto voidType = LLVM::LLVMVoidType::get(ctx); + Type i8Type = IntegerType::get(ctx, 8); + Type i8PtrType = LLVM::LLVMPointerType::get(i8Type); + printfFunc = + rewriter.create(rewriter.getUnknownLoc(), "printf", + LLVM::LLVMFunctionType::get(voidType, i8PtrType, + /*isVarArg=*/true)); + } + return SymbolRefAttr::get(ctx, "printf"); } - return SymbolRefAttr::get(ctx, "printf"); +}; + +void populateLoweringKrnlPrintOpPattern(TypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); } +} // namespace krnl } // namespace onnx_mlir diff --git a/src/Conversion/KrnlToLLVM/KrnlPrint.hpp b/src/Conversion/KrnlToLLVM/KrnlPrint.hpp deleted file mode 100644 index 5f5b14b0da44..000000000000 --- a/src/Conversion/KrnlToLLVM/KrnlPrint.hpp +++ /dev/null @@ -1,44 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - */ - -//===------ KrnlPrint.hpp - Lower KrnlPrintOp -----------------------------===// -// -// Copyright 2022 The IBM Research Authors. -// -// ============================================================================= -// -// This file declares the lowering class for the KrnlPrintOp operator. 
-// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Pass/Pass.h" - -#include "src/Conversion/KrnlToLLVM/RuntimeAPI.hpp" -#include "src/Dialect/Krnl/KrnlOps.hpp" -#include "src/Support/Common.hpp" - -using namespace mlir; - -namespace onnx_mlir { - -class KrnlPrintOpLowering : public ConversionPattern { -public: - explicit KrnlPrintOpLowering( - MLIRContext *context, TypeConverter &typeConverter) - : ConversionPattern( - typeConverter, KrnlPrintOp::getOperationName(), 1, context) {} - - LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override; - -private: - static FlatSymbolRefAttr getOrInsertPrintf( - PatternRewriter &rewriter, ModuleOp module); -}; - -} // namespace onnx_mlir diff --git a/src/Conversion/KrnlToLLVM/KrnlPrintTensor.cpp b/src/Conversion/KrnlToLLVM/KrnlPrintTensor.cpp index 4c4e175210f8..8f40dd3855bc 100644 --- a/src/Conversion/KrnlToLLVM/KrnlPrintTensor.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlPrintTensor.cpp @@ -12,64 +12,69 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Conversion/LLVMCommon/Pattern.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" - -#include "onnx/onnx_pb.h" - -#include "src/Conversion/KrnlToLLVM/KrnlPrintTensor.hpp" -#include "src/Conversion/KrnlToLLVM/KrnlToLLVM.hpp" #include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp" #include "src/Conversion/KrnlToLLVM/RuntimeAPI.hpp" -#include "src/Dialect/Krnl/KrnlHelper.hpp" #include "src/Dialect/Krnl/KrnlOps.hpp" - #include "llvm/Support/Debug.h" #define DEBUG_TYPE "krnl_to_llvm" using namespace mlir; +using namespace onnx_mlir; namespace onnx_mlir { - -LogicalResult KrnlPrintTensorOpLowering::matchAndRewrite(Operation *op, - ArrayRef operands, ConversionPatternRewriter &rewriter) const { - auto printTensorOp = cast(op); - MLIRContext *context = printTensorOp.getContext(); - Location loc = printTensorOp.getLoc(); - KrnlPrintTensorOpAdaptor operandAdaptor(operands); - - StringRef msg = printTensorOp.msg(); - Value input = operandAdaptor.input(); - assert(input.getType().isa() && - "expecting LLVMStructType"); - - ModuleOp module = printTensorOp->getParentOfType(); - const auto &apiRegistry = RuntimeAPIRegistry::build(module, rewriter); - - // Get a symbol reference to the runtime function to use, creating one if - // necessary. 
- auto int64Ty = IntegerType::get(context, 64); - auto memRefTy = input.getType().dyn_cast(); - auto memRefRank = onnx_mlir::getRankFromMemRefType(memRefTy); - auto memRefRankVal = rewriter.create( - loc, int64Ty, rewriter.getI64IntegerAttr(memRefRank)); - Value omTensor = RuntimeAPI::callApi(rewriter, loc, apiRegistry, - RuntimeAPI::API::CREATE_OMTENSOR, {memRefRankVal}); - - onnx_mlir::fillOMTensorWithMemRef( - input, omTensor, false /*outOwning*/, rewriter, loc, apiRegistry, module); - LLVM::GlobalOp globalStr = getOrCreateGlobalString(msg, loc, rewriter, module, - static_cast(getTypeConverter())); - Value strPtr = getPtrToGlobalString(globalStr, loc, rewriter); - - RuntimeAPI::callApi(rewriter, loc, apiRegistry, - RuntimeAPI::API::PRINT_OMTENSOR, {strPtr, omTensor}); - - rewriter.eraseOp(op); - return success(); +namespace krnl { + +class KrnlPrintTensorOpLowering : public ConversionPattern { +public: + explicit KrnlPrintTensorOpLowering( + TypeConverter &typeConverter, MLIRContext *context) + : ConversionPattern( + typeConverter, KrnlPrintTensorOp::getOperationName(), 1, context) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto printTensorOp = cast(op); + MLIRContext *context = printTensorOp.getContext(); + Location loc = printTensorOp.getLoc(); + KrnlPrintTensorOpAdaptor operandAdaptor(operands); + + StringRef msg = printTensorOp.msg(); + Value input = operandAdaptor.input(); + assert(input.getType().isa() && + "expecting LLVMStructType"); + + ModuleOp module = printTensorOp->getParentOfType(); + const auto &apiRegistry = RuntimeAPIRegistry::build(module, rewriter); + + // Get a symbol reference to the runtime function to use, creating one if + // necessary. + auto int64Ty = IntegerType::get(context, 64); + auto memRefTy = input.getType().dyn_cast(); + auto memRefRank = krnl::getRankFromMemRefType(memRefTy); + auto memRefRankVal = rewriter.create( + loc, int64Ty, rewriter.getI64IntegerAttr(memRefRank)); + Value omTensor = RuntimeAPI::callApi(rewriter, loc, apiRegistry, + RuntimeAPI::API::CREATE_OMTENSOR, {memRefRankVal}); + + krnl::fillOMTensorWithMemRef(input, omTensor, false /*outOwning*/, rewriter, + loc, apiRegistry, module); + LLVM::GlobalOp globalStr = krnl::getOrCreateGlobalString(msg, loc, rewriter, + module, static_cast(getTypeConverter())); + Value strPtr = krnl::getPtrToGlobalString(globalStr, loc, rewriter); + + RuntimeAPI::callApi(rewriter, loc, apiRegistry, + RuntimeAPI::API::PRINT_OMTENSOR, {strPtr, omTensor}); + + rewriter.eraseOp(op); + return success(); + } +}; + +void populateLoweringKrnlPrintTensorOpPattern(TypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); } +} // namespace krnl } // namespace onnx_mlir diff --git a/src/Conversion/KrnlToLLVM/KrnlPrintTensor.hpp b/src/Conversion/KrnlToLLVM/KrnlPrintTensor.hpp deleted file mode 100644 index f6f6413a976d..000000000000 --- a/src/Conversion/KrnlToLLVM/KrnlPrintTensor.hpp +++ /dev/null @@ -1,40 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - */ - -//===------ KrnlPrintTensor.hpp - Lower KrnlPrintTensorOp -----------------===// -// -// Copyright 2022 The IBM Research Authors. -// -// ============================================================================= -// -// This file declares the lowering class for the KrnlPrintTensorOp operator. 
-// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Pass/Pass.h" - -#include "src/Conversion/KrnlToLLVM/RuntimeAPI.hpp" -#include "src/Dialect/Krnl/KrnlOps.hpp" -#include "src/Support/Common.hpp" - -using namespace mlir; - -namespace onnx_mlir { - -class KrnlPrintTensorOpLowering : public ConversionPattern { -public: - explicit KrnlPrintTensorOpLowering( - MLIRContext *context, TypeConverter &typeConverter) - : ConversionPattern( - typeConverter, KrnlPrintTensorOp::getOperationName(), 1, context) {} - - LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override; -}; - -} // namespace onnx_mlir diff --git a/src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp b/src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp new file mode 100644 index 000000000000..df30650f8983 --- /dev/null +++ b/src/Conversion/KrnlToLLVM/KrnlRandomNormal.cpp @@ -0,0 +1,111 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------ KrnlRandomNormal.cpp - Lower KrnlRandomNormalOp ---------------===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the KrnlRandomNormalOp operator. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" + +#include "src/Dialect/Krnl/KrnlHelper.hpp" +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "krnl_to_llvm" + +using namespace mlir; +using namespace onnx_mlir; + +namespace onnx_mlir { +namespace krnl { + +class KrnlRandomNormalOpLowering : public ConversionPattern { +public: + explicit KrnlRandomNormalOpLowering( + TypeConverter &typeConverter, MLIRContext *context) + : ConversionPattern(typeConverter, KrnlRandomNormalOp::getOperationName(), + 1, context) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + KrnlRandomNormalOpAdaptor operandAdaptor(operands); + auto loc = op->getLoc(); + mlir::Type inType = op->getOperand(2).getType(); + + // Get a symbol reference to the memcpy function, inserting it if necessary. + ModuleOp parentModule = op->getParentOfType(); + auto randomNormalFuncRef = + getOrInsertRandomNormal(rewriter, parentModule, inType); + + // First operand. + Type outputType = operandAdaptor.output() + .getType() + .cast() + .getBody()[1]; + Value alignedOutput = rewriter.create( + loc, outputType, operandAdaptor.output(), rewriter.getI64ArrayAttr(1)); + + // Memcpy call + rewriter.create(loc, randomNormalFuncRef, ArrayRef({}), + ArrayRef({alignedOutput, operandAdaptor.numberOfValues(), + operandAdaptor.mean(), operandAdaptor.scale(), + operandAdaptor.seed()})); + + rewriter.eraseOp(op); + return success(); + } + +private: + FlatSymbolRefAttr getOrInsertRandomNormal( + PatternRewriter &rewriter, ModuleOp module, Type inType) const { + MLIRContext *context = module.getContext(); + StringRef functionName = inType.isF64() ? 
"get_random_normal_value_f64" + : "get_random_normal_value_f32"; + if (module.lookupSymbol(functionName.str())) + return SymbolRefAttr::get(context, functionName.str()); + + // Signature of the input is: + // "krnl.random_normal"(%0, %c60, %cst, %cst_0, %cst_1) + // with types: + // (memref<3x4x5xf32>, index, f32, f32, f32) + // or + // (memref<3x4x5xf64>, index, f64, f64, f64) + auto llvmVoidTy = LLVM::LLVMVoidType::get(context); + auto llvmOptionsTy = FloatType::getF32(context); + auto llvmOutputTy = LLVM::LLVMPointerType::get(llvmOptionsTy); + if (inType.isF64()) { + llvmOptionsTy = FloatType::getF64(context); + llvmOutputTy = LLVM::LLVMPointerType::get(llvmOptionsTy); + } + auto llvmI64Ty = IntegerType::get(context, 64); + auto llvmFnType = LLVM::LLVMFunctionType::get(llvmVoidTy, + ArrayRef({llvmOutputTy, llvmI64Ty, llvmOptionsTy, + llvmOptionsTy, llvmOptionsTy}), + false); + + // Insert the random normal function into the body of the parent module. + PatternRewriter::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + rewriter.create( + module.getLoc(), functionName.str(), llvmFnType); + return SymbolRefAttr::get(context, functionName.str()); + } +}; + +void populateLoweringKrnlRandomNormalOpPattern(TypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace krnl +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/Conversion/KrnlToLLVM/KrnlStrlen.cpp b/src/Conversion/KrnlToLLVM/KrnlStrlen.cpp new file mode 100644 index 000000000000..655f58753082 --- /dev/null +++ b/src/Conversion/KrnlToLLVM/KrnlStrlen.cpp @@ -0,0 +1,102 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------ KrnlStrlend.cpp - Lower KrnlStrlenOp --------------------------===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the KrnlStrlenOp operator. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" + +#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp" +#include "src/Dialect/Krnl/KrnlHelper.hpp" +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "krnl_to_llvm" + +using namespace mlir; +using namespace onnx_mlir; + +namespace onnx_mlir { +namespace krnl { + +class KrnlStrlenOpLowering : public ConversionPattern { +public: + explicit KrnlStrlenOpLowering( + TypeConverter &typeConverter, MLIRContext *context) + : ConversionPattern( + typeConverter, KrnlStrlenOp::getOperationName(), 1, context) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *context = op->getContext(); + KrnlStrlenOpAdaptor operandAdaptor(operands); + Location loc = op->getLoc(); + + // Get a symbol reference to the strlen function, inserting it if necessary. + ModuleOp parentModule = op->getParentOfType(); + auto strlenRef = getOrInsertStrlen(rewriter, parentModule); + + // Operand. + Type strType = operandAdaptor.str() + .getType() + .cast() + .getBody()[1]; + Value extractedStrPtr = rewriter.create( + loc, strType, operandAdaptor.str(), rewriter.getI64ArrayAttr(1)); + + // Strlen call. 
+ // TODO: should return a size_t + Type retType = IntegerType::get(context, 64); + auto funcCall = rewriter.create( + loc, strlenRef, retType, ArrayRef({extractedStrPtr})); + + rewriter.replaceOp(op, funcCall.getResults()[0]); + return success(); + } + +private: + /// Return a symbol reference to the strlen function, inserting it into the + /// module if necessary. + static FlatSymbolRefAttr getOrInsertStrlen( + PatternRewriter &rewriter, ModuleOp module) { + constexpr const char *funcName = "strlen"; + Optional optFuncDecl = + krnl::getFunctionDeclaration(module, funcName); + if (optFuncDecl.hasValue()) + return optFuncDecl.getValue(); + + // Create 'strlen' function signature: `size_t (i8*)` + // TODO: need to create size_t not i64. + MLIRContext *ctx = module.getContext(); + Type i8Type = IntegerType::get(ctx, 8); + Type i8PtrType = LLVM::LLVMPointerType::get(i8Type); + Type fnType = LLVM::LLVMFunctionType::get( + rewriter.getI64Type(), ArrayRef({i8PtrType}), false); + + // Insert the function declaration the module. + PatternRewriter::InsertionGuard insertGuard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + rewriter.create(module.getLoc(), funcName, fnType); + + return SymbolRefAttr::get(ctx, funcName); + } +}; + +void populateLoweringKrnlStrlenOpPattern(TypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace krnl +} // namespace onnx_mlir diff --git a/src/Conversion/KrnlToLLVM/KrnlStrncmp.cpp b/src/Conversion/KrnlToLLVM/KrnlStrncmp.cpp new file mode 100644 index 000000000000..64054f680cce --- /dev/null +++ b/src/Conversion/KrnlToLLVM/KrnlStrncmp.cpp @@ -0,0 +1,78 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------ KrnlStrncmp.cpp - Lower KrnlStrncmpOp -------------------------===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the KrnlStrncmpOp operator. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" + +#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp" +#include "src/Dialect/Krnl/KrnlHelper.hpp" +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "krnl_to_llvm" + +using namespace mlir; +using namespace onnx_mlir; + +namespace onnx_mlir { +namespace krnl { + +class KrnlStrncmpOpLowering : public ConversionPattern { +public: + explicit KrnlStrncmpOpLowering( + TypeConverter &typeConverter, MLIRContext *context) + : ConversionPattern( + typeConverter, KrnlStrncmpOp::getOperationName(), 1, context) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + KrnlStrncmpOpAdaptor operandAdaptor(operands); + Location loc = op->getLoc(); + + // Get a symbol reference to the strncmp function, inserting it if + // necessary. + ModuleOp parentModule = op->getParentOfType(); + auto StrncmpRef = getOrInsertStrncmp(rewriter, parentModule); + + // Operands. 
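+ // Both strings arrive as MemRef descriptors; field 1 (the aligned data
+ // pointer) is extracted from each and handed to strncmp together with the
+ // length operand.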
+ Type strType = operandAdaptor.str1() + .getType() + .cast() + .getBody()[1]; + Value extractedStrPtr1 = rewriter.create( + loc, strType, operandAdaptor.str1(), rewriter.getI64ArrayAttr(1)); + Value extractedStrPtr2 = rewriter.create( + loc, strType, operandAdaptor.str2(), rewriter.getI64ArrayAttr(1)); + Value length = operandAdaptor.len(); + + // Strncmp call. + MLIRContext *ctx = op->getContext(); + Type i32Type = IntegerType::get(ctx, 32); + auto funcCall = rewriter.create(loc, StrncmpRef, i32Type, + ArrayRef({extractedStrPtr1, extractedStrPtr2, length})); + + rewriter.replaceOp(op, funcCall.getResults()[0]); + return success(); + } +}; + +void populateLoweringKrnlStrncmpOpPattern(TypeConverter &typeConverter, + RewritePatternSet &patterns, MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace krnl +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp b/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp index 98b6c2ae90f9..73a1a9870476 100644 --- a/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp @@ -44,8 +44,6 @@ #include "onnx/onnx_pb.h" -#include "src/Conversion/KrnlToLLVM/KrnlPrint.hpp" -#include "src/Conversion/KrnlToLLVM/KrnlPrintTensor.hpp" #include "src/Conversion/KrnlToLLVM/KrnlToLLVM.hpp" #include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp" #include "src/Conversion/KrnlToLLVM/RuntimeAPI.hpp" @@ -59,118 +57,10 @@ const std::string DEFAULT_DYN_ENTRY_POINT = "run_main_graph"; using namespace mlir; +using namespace onnx_mlir; namespace { -// Create a function declaration for OMInstrumentPoint, the signature is: -// `void (i64, i64)` -static FlatSymbolRefAttr getOrInsertInstrument( - PatternRewriter &rewriter, ModuleOp module) { - auto *context = module.getContext(); - std::string funcName("OMInstrumentPoint"); - if (module.lookupSymbol(funcName)) - return SymbolRefAttr::get(context, funcName); - auto llvmVoidTy = LLVM::LLVMVoidType::get(context); - auto llvmI64Ty = IntegerType::get(context, 64); - auto llvmFnType = LLVM::LLVMFunctionType::get( - llvmVoidTy, ArrayRef({llvmI64Ty, llvmI64Ty}), false); - - PatternRewriter::InsertionGuard insertGuard(rewriter); - rewriter.setInsertionPointToStart(module.getBody()); - rewriter.create(module.getLoc(), funcName, llvmFnType); - return SymbolRefAttr::get(context, funcName); -} - -/// Return a symbol reference to the memcpy function, inserting it into the -/// module if necessary. -static FlatSymbolRefAttr getOrInsertMemcpy( - PatternRewriter &rewriter, ModuleOp module) { - auto *context = module.getContext(); - if (module.lookupSymbol("llvm.memcpy.p0i8.p0i8.i64")) - return SymbolRefAttr::get(context, "llvm.memcpy.p0i8.p0i8.i64"); - // Create a function declaration for memcpy, the signature is: - // * `void (i8*, i8* , i64, i1)` - auto llvmVoidTy = LLVM::LLVMVoidType::get(context); - auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); - auto llvmI64Ty = IntegerType::get(context, 64); - auto llvmI1Ty = IntegerType::get(context, 1); - auto llvmFnType = LLVM::LLVMFunctionType::get(llvmVoidTy, - ArrayRef({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}), - false); - - // Insert the memcpy function into the body of the parent module. 
- PatternRewriter::InsertionGuard insertGuard(rewriter); - rewriter.setInsertionPointToStart(module.getBody()); - rewriter.create( - module.getLoc(), "llvm.memcpy.p0i8.p0i8.i64", llvmFnType); - return SymbolRefAttr::get(context, "llvm.memcpy.p0i8.p0i8.i64"); -} - -static Optional getFunctionDeclaration( - ModuleOp module, const char *funcName) { - assert(funcName && "Missing function name"); - if (module.lookupSymbol(funcName)) - return SymbolRefAttr::get(module.getContext(), funcName); - - return None; -} - -static FlatSymbolRefAttr getOrInsertRandomNormal( - PatternRewriter &rewriter, ModuleOp module, Type inType) { - MLIRContext *context = module.getContext(); - StringRef functionName = inType.isF64() ? "get_random_normal_value_f64" - : "get_random_normal_value_f32"; - if (module.lookupSymbol(functionName.str())) - return SymbolRefAttr::get(context, functionName.str()); - - // Signature of the input is: - // "krnl.random_normal"(%0, %c60, %cst, %cst_0, %cst_1) - // with types: - // (memref<3x4x5xf32>, index, f32, f32, f32) - // or - // (memref<3x4x5xf64>, index, f64, f64, f64) - auto llvmVoidTy = LLVM::LLVMVoidType::get(context); - auto llvmOptionsTy = FloatType::getF32(context); - auto llvmOutputTy = LLVM::LLVMPointerType::get(llvmOptionsTy); - if (inType.isF64()) { - llvmOptionsTy = FloatType::getF64(context); - llvmOutputTy = LLVM::LLVMPointerType::get(llvmOptionsTy); - } - auto llvmI64Ty = IntegerType::get(context, 64); - auto llvmFnType = LLVM::LLVMFunctionType::get(llvmVoidTy, - ArrayRef({llvmOutputTy, llvmI64Ty, llvmOptionsTy, - llvmOptionsTy, llvmOptionsTy}), - false); - - // Insert the random normal function into the body of the parent module. - PatternRewriter::InsertionGuard insertGuard(rewriter); - rewriter.setInsertionPointToStart(module.getBody()); - rewriter.create( - module.getLoc(), functionName.str(), llvmFnType); - return SymbolRefAttr::get(context, functionName.str()); -} - -static FlatSymbolRefAttr getOrInsertMalloc( - PatternRewriter &rewriter, ModuleOp module) { - // Insert the malloc/aligned_alloc declaration if it is not already present. 
- auto allocFunc = module.lookupSymbol("malloc"); - auto ctx = rewriter.getContext(); - LLVMTypeConverter converter(ctx); - if (!allocFunc) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(module.getBody()); - SmallVector callArgTypes = {converter.getIndexType()}; - // aligned_alloc(size_t alignment, size_t size) - auto voidPtrType = LLVM::LLVMPointerType::get( - IntegerType::get(&converter.getContext(), 8)); - allocFunc = - rewriter.create(rewriter.getUnknownLoc(), "malloc", - LLVM::LLVMFunctionType::get(voidPtrType, callArgTypes, - /*isVarArg=*/false)); - } - return SymbolRefAttr::get(ctx, "malloc"); -} - ATTRIBUTE(unused) static FlatSymbolRefAttr getOrInsertDealloc( PatternRewriter &rewriter, ModuleOp module) { @@ -193,6 +83,7 @@ static FlatSymbolRefAttr getOrInsertDealloc( return SymbolRefAttr::get(ctx, "free"); } +#if 0 // This function emits a declaration of the form: // // declare float (float) @@ -631,135 +522,21 @@ class KrnlInstrumentOpLowering : public ConversionPattern { } }; +======= +>>>>>>> 1289f43 (Split krnlToLLVM.cpp file) //===----------------------------------------------------------------------===// // KRNL to LLVM: KrnlMemcpyOpLowering //===----------------------------------------------------------------------===// -class KrnlMemcpyOpLowering : public ConversionPattern { -public: - explicit KrnlMemcpyOpLowering(MLIRContext *context) - : ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {} - - LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto *context = op->getContext(); - KrnlMemcpyOpAdaptor operandAdaptor(operands); - auto loc = op->getLoc(); - - // Get a symbol reference to the memcpy function, inserting it if necessary. - ModuleOp parentModule = op->getParentOfType(); - auto memcpyRef = getOrInsertMemcpy(rewriter, parentModule); - - // First operand. - Type dstType = operandAdaptor.dest() - .getType() - .cast() - .getBody()[1]; - Value alignedDstMemory = rewriter.create( - loc, dstType, operandAdaptor.dest(), rewriter.getI64ArrayAttr(1)); - Value alignedInt8PtrDstMemory = rewriter.create(loc, - LLVM::LLVMPointerType::get(IntegerType::get(context, 8)), - alignedDstMemory); - - // Second operand. - Type srcType = operandAdaptor.src() - .getType() - .cast() - .getBody()[1]; - Value alignedSrcMemory = rewriter.create( - loc, srcType, operandAdaptor.src(), rewriter.getI64ArrayAttr(1)); - Value alignedInt8PtrSrcMemory = rewriter.create(loc, - LLVM::LLVMPointerType::get(IntegerType::get(context, 8)), - alignedSrcMemory); - - // Size. - Value int64Size = rewriter.create( - loc, IntegerType::get(context, 64), operandAdaptor.size()); - - // Is volatile (set to false). 
- Value isVolatile = - rewriter.create(loc, IntegerType::get(context, 1), - rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0)); - - // Memcpy call - rewriter.create(loc, memcpyRef, ArrayRef({}), - ArrayRef({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, - int64Size, isVolatile})); - - rewriter.eraseOp(op); - return success(); - } -}; - //===----------------------------------------------------------------------===// // KRNL to LLVM: KrnlStrlenOpLowering //===----------------------------------------------------------------------===// -class KrnlStrlenOpLowering : public ConversionPattern { -public: - explicit KrnlStrlenOpLowering(MLIRContext *context) - : ConversionPattern(KrnlStrlenOp::getOperationName(), 1, context) {} - - LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - MLIRContext *context = op->getContext(); - KrnlStrlenOpAdaptor operandAdaptor(operands); - Location loc = op->getLoc(); - - // Get a symbol reference to the strlen function, inserting it if necessary. - ModuleOp parentModule = op->getParentOfType(); - auto strlenRef = getOrInsertStrlen(rewriter, parentModule); - - // Operand. - Type strType = operandAdaptor.str() - .getType() - .cast() - .getBody()[1]; - Value extractedStrPtr = rewriter.create( - loc, strType, operandAdaptor.str(), rewriter.getI64ArrayAttr(1)); - - // Strlen call. - // TODO: should return a size_t - Type retType = IntegerType::get(context, 64); - auto funcCall = rewriter.create( - loc, strlenRef, retType, ArrayRef({extractedStrPtr})); - - rewriter.replaceOp(op, funcCall.getResults()[0]); - return success(); - } - -private: - /// Return a symbol reference to the strlen function, inserting it into the - /// module if necessary. - static FlatSymbolRefAttr getOrInsertStrlen( - PatternRewriter &rewriter, ModuleOp module) { - constexpr const char *funcName = "strlen"; - Optional optFuncDecl = - getFunctionDeclaration(module, funcName); - if (optFuncDecl.hasValue()) - return optFuncDecl.getValue(); - - // Create 'strlen' function signature: `size_t (i8*)` - // TODO: need to create size_t not i64. - MLIRContext *ctx = module.getContext(); - Type i8Type = IntegerType::get(ctx, 8); - Type i8PtrType = LLVM::LLVMPointerType::get(i8Type); - Type fnType = LLVM::LLVMFunctionType::get( - rewriter.getI64Type(), ArrayRef({i8PtrType}), false); - - // Insert the function declaration the module. 
- PatternRewriter::InsertionGuard insertGuard(rewriter); - rewriter.setInsertionPointToStart(module.getBody()); - rewriter.create(module.getLoc(), funcName, fnType); - - return SymbolRefAttr::get(ctx, funcName); - } -}; - //===----------------------------------------------------------------------===// // KRNL to LLVM: KrnlStrncmpOpLowering //===----------------------------------------------------------------------===// +<<<<<<< HEAD class KrnlStrncmpOpLowering : public ConversionPattern { public: explicit KrnlStrncmpOpLowering(MLIRContext *context) @@ -1392,10 +1169,13 @@ class KrnlRandomNormalOpLowering : public ConversionPattern { } }; +#endif + //===----------------------------------------------------------------------===// // KRNL to LLVM: KrnlFindIndexOpLowering //===----------------------------------------------------------------------===// +#if 0 class KrnlFindIndexOpLowering : public ConversionPattern { public: explicit KrnlFindIndexOpLowering(MLIRContext *context) @@ -1481,7 +1261,7 @@ class KrnlFindIndexOpLowering : public ConversionPattern { .Default([](Type) { llvm_unreachable("unexpected type"); }); Optional optFuncDecl = - getFunctionDeclaration(module, funcName.c_str()); + krnl::getFunctionDeclaration(module, funcName.c_str()); if (optFuncDecl.hasValue()) return optFuncDecl.getValue(); @@ -1498,8 +1278,12 @@ class KrnlFindIndexOpLowering : public ConversionPattern { } }; +#endif + } // end namespace +#if 0 + void mlir::populateAffineAndKrnlToLLVMConversion(RewritePatternSet &patterns, MLIRContext *ctx, LLVMTypeConverter &typeConverter, ArrayRef constantOutputs, bool singleEntryPoint) { @@ -1563,7 +1347,9 @@ void mlir::populateAffineAndKrnlToLLVMConversion(RewritePatternSet &patterns, patterns.insert(ctx); patterns.insert(ctx); } +#endif +#if 0 void mlir::checkConstantOutputs( ModuleOp &module, SmallVectorImpl &constantOutputs) { Operation *entryPointOp; @@ -1638,7 +1424,9 @@ void mlir::checkConstantOutputs( << "Is entry function output constant? " << isConstant << "\n"); } } +#endif +<<<<<<< HEAD void mlir::recordEntryPointSignatures(ModuleOp &module, SmallVectorImpl &entryPointNames, SmallVectorImpl &inSignatures, @@ -1888,6 +1676,9 @@ void mlir::genSignatureFunction(ModuleOp module, } } +======= +#if 0 +>>>>>>> 1289f43 (Split krnlToLLVM.cpp file) //===----------------------------------------------------------------------===// // KRNL + Standard + Vector + Affine dialects lowering to LLVM. //===----------------------------------------------------------------------===// @@ -1970,3 +1761,5 @@ void ConvertKrnlToLLVMPass::runOnOperation() { std::unique_ptr mlir::createConvertKrnlToLLVMPass() { return std::make_unique(); } + +#endif \ No newline at end of file diff --git a/src/Conversion/KrnlToLLVM/KrnlToLLVM.hpp b/src/Conversion/KrnlToLLVM/KrnlToLLVM.hpp deleted file mode 100644 index d16cd6842ef8..000000000000 --- a/src/Conversion/KrnlToLLVM/KrnlToLLVM.hpp +++ /dev/null @@ -1,46 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - */ - -//===------ KrnlToLLVM.hpp - Lowering from KRNL+Affine+Std to LLVM -------===// -// -// Copyright 2019-2020 The IBM Research Authors. 
-// -// ============================================================================= -// -// -// -//===----------------------------------------------------------------------===// - -#ifndef KRNL_TO_LLVM_H -#define KRNL_TO_LLVM_H - -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" - -namespace mlir { - -class MLIRContext; -class LLVMTypeConverter; - -class RewritePatternSet; - -void checkConstantOutputs( - ModuleOp &module, SmallVectorImpl &constantOutputs); - -void recordEntryPointSignatures(ModuleOp &module, - SmallVectorImpl &entryPointNames, - SmallVectorImpl &inSignatures, - SmallVectorImpl &outSignatures); - -void genSignatureFunction(ModuleOp module, - const ArrayRef entryPointNames, - const ArrayRef inSignatures, - const ArrayRef outSignatures); - -void populateAffineAndKrnlToLLVMConversion(RewritePatternSet &patterns, - MLIRContext *ctx, LLVMTypeConverter &typeConverter, - ArrayRef constantOutputs, bool singleEntryPoint); - -} // namespace mlir - -#endif // KRNL_TO_LLVM_H diff --git a/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.cpp b/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.cpp index abc802369813..b84ba212ce6e 100644 --- a/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.cpp +++ b/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.cpp @@ -18,8 +18,10 @@ #include "llvm/ADT/TypeSwitch.h" using namespace mlir; +using namespace onnx_mlir; namespace onnx_mlir { +namespace krnl { static const int32_t MinGlobalAlign = 16; @@ -30,14 +32,14 @@ int64_t getRankFromMemRefType(LLVM::LLVMStructType memRefTy) { // elements will have 0-length, which in turn causes the MemRef struct to // degenerate into a 3-element struct. For more information, refer to // https://github.com/llvm/llvm-project/blob/main/mlir/docs/ConversionToLLVMDialect.md#memref-types. - auto numElems = memRefTy.getBody().size(); + size_t numElems = memRefTy.getBody().size(); assert((numElems == 3 || numElems == 5) && "Expect MemRef type to contain either 3 or 5 elements."); - if (numElems == 3) - return 0; // MemRef refers to a scalar. - else - return memRefTy.getBody()[3].cast().getNumElements(); + return (numElems == 3) ? 0 // MemRef refers to a scalar. + : memRefTy.getBody()[3] + .cast() + .getNumElements(); } // Convert an MLIR type to the correspoding ONNX type. 
@@ -47,7 +49,7 @@ onnx::TensorProto::DataType mlirTypeToOnnxType(Type elemType) { TypeSwitch(elemType) .Case( [&](BFloat16Type) { onnxType = onnx::TensorProto::BFLOAT16; }) - .Case([&](ComplexType type) { + .Case([&](ComplexType type) { if (type.getElementType().isa()) onnxType = onnx::TensorProto::COMPLEX64; else if (type.getElementType().isa()) @@ -99,7 +101,7 @@ onnx::TensorProto::DataType mlirTypeToOnnxType(Type elemType) { void fillOMTensorWithMemRef(Value &outMemRef, Value &outOMTensor, int64_t outOwning, PatternRewriter &rewriter, const Location &loc, const RuntimeAPIRegistry &apiRegistry, ModuleOp &module) { - auto *context = module.getContext(); + MLIRContext *context = module.getContext(); auto outMemRefTy = outMemRef.getType().dyn_cast(); auto int64Ty = IntegerType::get(context, 64); @@ -130,13 +132,13 @@ void fillOMTensorWithMemRef(Value &outMemRef, Value &outOMTensor, Type elemTy = outMemRefTy.getBody()[0].cast().getElementType(); - onnx::TensorProto::DataType onnxTy = onnx_mlir::mlirTypeToOnnxType(elemTy); + onnx::TensorProto::DataType onnxTy = krnl::mlirTypeToOnnxType(elemTy); auto onnxTyVal = rewriter.create( loc, int64Ty, rewriter.getI64IntegerAttr(onnxTy)); RuntimeAPI::callApi(rewriter, loc, apiRegistry, RuntimeAPI::API::SET_DATA_TYPE, {outOMTensor, onnxTyVal}); - int64_t rank = onnx_mlir::getRankFromMemRefType(outMemRefTy); + int64_t rank = krnl::getRankFromMemRefType(outMemRefTy); Value sizesArrayPtr = RuntimeAPI::callApi(rewriter, loc, apiRegistry, RuntimeAPI::API::GET_DATA_SHAPE, {outOMTensor}); Value stridesArrayPtr = RuntimeAPI::callApi(rewriter, loc, apiRegistry, @@ -182,7 +184,7 @@ LLVM::GlobalOp getOrCreateGlobalString(StringRef str, Location loc, global = builder.create(loc, type, /*isConstant=*/true, LLVM::Linkage::Internal, str, builder.getStringAttr(str)); - setAlignment(global, nullptr, module, builder, *typeConverter); + krnl::setAlignment(global, nullptr, module, builder, *typeConverter); } return global; @@ -218,4 +220,37 @@ void setAlignment(LLVM::GlobalOp &global, IntegerAttr alignmentAttr, global.setAlignmentAttr(builder.getI64IntegerAttr(MinGlobalAlign)); } +Optional getFunctionDeclaration( + ModuleOp module, StringRef funcName) { + if (module.lookupSymbol(funcName)) + return SymbolRefAttr::get(module.getContext(), funcName); + else + return None; +} + +/// Return a symbol reference to the strncmp function, inserting it into the +/// module if necessary. +FlatSymbolRefAttr getOrInsertStrncmp(OpBuilder &builder, ModuleOp module) { + constexpr const char *funcName = "strncmp"; + Optional optFuncDecl = + krnl::getFunctionDeclaration(module, funcName); + if (optFuncDecl.hasValue()) + return optFuncDecl.getValue(); + + // Create 'strncmp' function signature: `i32 (i8*, i8*, i64)` + MLIRContext *ctx = module.getContext(); + Type i8Type = IntegerType::get(ctx, 8); + Type i8PtrTy = LLVM::LLVMPointerType::get(i8Type); + Type fnType = LLVM::LLVMFunctionType::get(builder.getI32Type(), + ArrayRef({i8PtrTy, i8PtrTy, builder.getI64Type()}), false); + + // Insert the function declaration the module. 
+ PatternRewriter::InsertionGuard insertGuard(builder); + builder.setInsertionPointToStart(module.getBody()); + builder.create(module.getLoc(), funcName, fnType); + + return SymbolRefAttr::get(ctx, funcName); +} + +} // namespace krnl } // namespace onnx_mlir diff --git a/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp b/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp index 307a6d63e7bf..8d9322de2d1e 100644 --- a/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp +++ b/src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp @@ -23,6 +23,7 @@ using namespace mlir; namespace onnx_mlir { +namespace krnl { /// Get the rank of the given tensor (represented as a memref). int64_t getRankFromMemRefType(LLVM::LLVMStructType memRefTy); @@ -48,4 +49,13 @@ Value getPtrToGlobalString( void setAlignment(LLVM::GlobalOp &global, IntegerAttr alignmentAttr, ModuleOp module, OpBuilder &builder, LLVMTypeConverter &typeConverter); +/// Retrieve the declaration of a function in the given module. +Optional getFunctionDeclaration( + ModuleOp module, StringRef funcName); + +/// Return a symbol reference to the strncmp function, inserting it into the +/// module if necessary. +FlatSymbolRefAttr getOrInsertStrncmp(OpBuilder &builder, ModuleOp module); + +} // namespace krnl } // namespace onnx_mlir diff --git a/src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp b/src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp new file mode 100644 index 000000000000..a8334d6f260e --- /dev/null +++ b/src/Conversion/KrnlToLLVM/KrnlUnaryMath.cpp @@ -0,0 +1,199 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------ KrnlUnaryMath.cpp - Lower KrnlUnaryMath Ops -------------------===// +// +// Copyright 2019-2022 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the KrnlUnaryMath operators. 
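+// Each krnl unary math op (erf, acos, acosh, asin, asinh, atan, atanh, tan) is
+// lowered to a call to the matching libm-style function, selected by element
+// type (e.g. "erff" for f32 and "erf" for f64).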
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+
+#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp"
+#include "src/Dialect/Krnl/KrnlHelper.hpp"
+#include "src/Dialect/Krnl/KrnlOps.hpp"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "krnl_to_llvm"
+
+using namespace mlir;
+using namespace onnx_mlir;
+
+namespace onnx_mlir {
+namespace krnl {
+
+template <typename Op>
+struct MathFunctionName {
+  static std::string functionName() { return "none"; };
+};
+
+template <>
+struct MathFunctionName<KrnlErfOp> {
+  static std::string functionName(mlir::Type type) {
+    if (type.isF32())
+      return "erff";
+    if (type.isF64())
+      return "erf";
+    llvm_unreachable("Currently unsupported type for erf");
+  }
+};
+
+template <>
+struct MathFunctionName<KrnlAcosOp> {
+  static std::string functionName(mlir::Type type) {
+    if (type.isF32())
+      return "acosf";
+    if (type.isF64())
+      return "acos";
+    llvm_unreachable("Unsupported type for acos");
+  }
+};
+
+template <>
+struct MathFunctionName<KrnlAcoshOp> {
+  static std::string functionName(mlir::Type type) {
+    if (type.isF32())
+      return "acoshf";
+    if (type.isF64())
+      return "acosh";
+    llvm_unreachable("Unsupported type for acosh");
+  }
+};
+
+template <>
+struct MathFunctionName<KrnlAsinOp> {
+  static std::string functionName(mlir::Type type) {
+    if (type.isF32())
+      return "asinf";
+    if (type.isF64())
+      return "asin";
+    llvm_unreachable("Unsupported type for asin");
+  }
+};
+
+template <>
+struct MathFunctionName<KrnlAsinhOp> {
+  static std::string functionName(mlir::Type type) {
+    if (type.isF32())
+      return "asinhf";
+    if (type.isF64())
+      return "asinh";
+    llvm_unreachable("Unsupported type for asinh");
+  }
+};
+
+template <>
+struct MathFunctionName<KrnlAtanOp> {
+  static std::string functionName(mlir::Type type) {
+    if (type.isF32())
+      return "atanf";
+    if (type.isF64())
+      return "atan";
+    llvm_unreachable("Unsupported type for atan");
+  }
+};
+
+template <>
+struct MathFunctionName<KrnlTanOp> {
+  static std::string functionName(mlir::Type type) {
+    if (type.isF32())
+      return "tanf";
+    if (type.isF64())
+      return "tan";
+    llvm_unreachable("Unsupported type for tan");
+  }
+};
+
+template <>
+struct MathFunctionName<KrnlAtanhOp> {
+  static std::string functionName(mlir::Type type) {
+    if (type.isF32())
+      return "atanhf";
+    if (type.isF64())
+      return "atanh";
+    llvm_unreachable("Unsupported type for atanh");
+  }
+};
+
+template <typename KrnlScalarMathOp>
+class KrnlUnaryMathOpLowering : public ConversionPattern {
+public:
+  explicit KrnlUnaryMathOpLowering(
+      TypeConverter &typeConverter, MLIRContext *context)
+      : ConversionPattern(
+            typeConverter, KrnlScalarMathOp::getOperationName(), 1, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    MLIRContext *context = op->getContext();
+    Location loc = op->getLoc();
+
+    // Get the LLVM type for the function args and result.
+    mlir::Type inType = op->getOperand(0).getType();
+    mlir::Type llvmType;
+    if (inType.isF32())
+      llvmType = FloatType::getF32(context);
+    else if (inType.isF64())
+      llvmType = FloatType::getF64(context);
+
+    // Insert and/or get reference to elementary math function declaration.
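// Illustrative example (a sketch, not mandated by the code above): lowering a
// krnl.erf whose operand is f32 resolves MathFunctionName<KrnlErfOp> to
// "erff"; if no declaration exists yet, a module-level declaration roughly
// equivalent to `llvm.func @erff(f32) -> f32` is inserted, and the op is then
// replaced by a call to that symbol.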
+    assert(
+        inType.isIntOrFloat() && "Type for math function must be int or float");
+    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
+    auto mathFunctionRef = getOrInsertUnaryMathFunction(rewriter, parentModule,
+        MathFunctionName<KrnlScalarMathOp>().functionName(inType), llvmType);
+
+    // Emit function call.
+    auto funcCall = rewriter.create<CallOp>(
+        loc, mathFunctionRef, llvmType, ArrayRef<Value>({operands[0]}));
+    rewriter.replaceOp(op, funcCall.getResults()[0]);
+    return success();
+  }
+
+private:
+  // This function emits a declaration of the form:
+  //
+  //   declare float <mathFuncName>(float)
+  //
+  FlatSymbolRefAttr getOrInsertUnaryMathFunction(PatternRewriter &rewriter,
+      ModuleOp module, std::string mathFuncName, mlir::Type llvmType) const {
+    auto *context = module.getContext();
+    if (module.lookupSymbol<LLVM::LLVMFuncOp>(mathFuncName))
+      return SymbolRefAttr::get(context, mathFuncName);
+
+    // Create the function declaration.
+    auto llvmFnType =
+        LLVM::LLVMFunctionType::get(llvmType, ArrayRef<mlir::Type>({llvmType}));
+
+    // Insert the unary math function into the body of the parent module.
+    PatternRewriter::InsertionGuard insertGuard(rewriter);
+    rewriter.setInsertionPointToStart(module.getBody());
+    rewriter.create<LLVM::LLVMFuncOp>(
+        module.getLoc(), mathFuncName, llvmFnType);
+    return SymbolRefAttr::get(context, mathFuncName);
+  }
+};
+
+void populateLoweringKrnlUnaryMathOpPattern(TypeConverter &typeConverter,
+    RewritePatternSet &patterns, MLIRContext *ctx) {
+  patterns.insert<KrnlUnaryMathOpLowering<KrnlErfOp>>(typeConverter, ctx);
+  patterns.insert<KrnlUnaryMathOpLowering<KrnlAcosOp>>(typeConverter, ctx);
+  patterns.insert<KrnlUnaryMathOpLowering<KrnlAcoshOp>>(typeConverter, ctx);
+  patterns.insert<KrnlUnaryMathOpLowering<KrnlAsinOp>>(typeConverter, ctx);
+  patterns.insert<KrnlUnaryMathOpLowering<KrnlAsinhOp>>(typeConverter, ctx);
+  patterns.insert<KrnlUnaryMathOpLowering<KrnlAtanOp>>(typeConverter, ctx);
+  patterns.insert<KrnlUnaryMathOpLowering<KrnlTanOp>>(typeConverter, ctx);
+  patterns.insert<KrnlUnaryMathOpLowering<KrnlAtanhOp>>(typeConverter, ctx);
+}
+
+} // namespace krnl
+} // namespace onnx_mlir
diff --git a/src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp b/src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp
new file mode 100644
index 000000000000..7fc09b5ab2be
--- /dev/null
+++ b/src/Conversion/KrnlToLLVM/KrnlVectorTypeCast.cpp
@@ -0,0 +1,171 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+//===------ KrnlVectorTypeCast.cpp - Lower KrnlVectorTypeCastOp -----------===//
+//
+// Copyright 2019-2022 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file lowers the KrnlVectorTypeCastOp operator.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" + +#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp" +#include "src/Dialect/Krnl/KrnlHelper.hpp" +#include "src/Dialect/Krnl/KrnlOps.hpp" +#include "src/Support/KrnlSupport.hpp" + +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "krnl_to_llvm" + +using namespace mlir; +using namespace onnx_mlir; + +namespace onnx_mlir { +namespace krnl { + +class KrnlVectorTypeCastOpLowering : public ConvertToLLVMPattern { +public: + explicit KrnlVectorTypeCastOpLowering( + LLVMTypeConverter &typeConverter, MLIRContext *context) + : ConvertToLLVMPattern( + KrnlVectorTypeCastOp::getOperationName(), context, typeConverter) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto krnlVectorTypeCastOp = cast(op); + MemRefType sourceType = + krnlVectorTypeCastOp.getOperand().getType().cast(); + MemRefType targetType = krnlVectorTypeCastOp.getType(); + if (!isSupportedMemRefType(targetType) || + !isSupportedMemRefType(sourceType)) + return failure(); + + KrnlVectorTypeCastOp::Adaptor transformed(operands); + MemRefDescriptor srcMemRefDesc(transformed.source()); + + Type targetStructType = + typeConverter->convertType(krnlVectorTypeCastOp.getType()); + if (!targetStructType) + return failure(); + Location loc = op->getLoc(); + // Get memRefDescriptor, the new memref descriptor. + MemRefDescriptor memRefDescriptor = + MemRefDescriptor::undef(rewriter, loc, targetStructType); + auto targetElementPtrType = memRefDescriptor.getElementPtrType(); + + // Set the new memref to the same buffer as the source memref. + Value srcBuffer = srcMemRefDesc.allocatedPtr(rewriter, loc); + Value targetBuffer = rewriter.create( + loc, targetElementPtrType, ArrayRef(srcBuffer)); + memRefDescriptor.setAllocatedPtr(rewriter, loc, targetBuffer); + + // Set the new memref alignment to the same value as source memref. + Value srcBufferAligned = srcMemRefDesc.alignedPtr(rewriter, loc); + Value targetBufAligned = rewriter.create( + loc, targetElementPtrType, ArrayRef(srcBufferAligned)); + memRefDescriptor.setAlignedPtr(rewriter, loc, targetBufAligned); + + int64_t offset; + SmallVector strides; + if (failed(getStridesAndOffset(targetType, strides, offset))) + return failure(); + + // Unhandled dynamic offset. + if (offset == MemRefType::getDynamicStrideOrOffset()) + return failure(); + + memRefDescriptor.setOffset( + rewriter, loc, createIndexConstant(rewriter, loc, offset)); + + // Get the sizes of the memref: all but the last one are copied from the + // source memref. If the dimension size was static, the target memref would + // have the same size. + SmallVector sizes; + sizes.reserve(targetType.getRank()); + for (unsigned pos = 0, e = targetType.getRank() - 1; pos < e; ++pos) { + int64_t dimSize = targetType.getDimSize(pos); + if (ShapedType::isDynamic(dimSize)) + sizes.push_back(srcMemRefDesc.size(rewriter, loc, pos)); + else + sizes.push_back(createIndexConstant(rewriter, loc, dimSize)); + } + + if (!ShapedType::isDynamic(targetType.getShape().back())) { + // The op is already verified to have the right size for the last + // dimension. 
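// Illustrative example (a sketch of the intended semantics, not part of the
// code): viewing memref<10x16xf32> as memref<10x4xvector<4xf32>> keeps the
// leading size 10; the last size is either taken from the static target shape
// (4 here), or, when the source's last dimension is dynamic, computed by
// dividing that dynamic size by the vector width (4).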
+ sizes.push_back( + createIndexConstant(rewriter, loc, targetType.getShape().back())); + } else { + // We need to divide the dynamic size on the source by the vector width. + // There is the implicit expectation that the last dimension of the + // original memory is a multiple of the vector length. + Value vecWidth = createIndexConstant(rewriter, loc, + targetType.getElementType().cast().getNumElements()); + sizes.push_back(rewriter.create(loc, + srcMemRefDesc.size(rewriter, loc, sourceType.getRank() - 1), + vecWidth)); + } + + assert(!sizes.empty() && "target memref rank can't be zero"); + + // Compute the total number of memref elements. + Value cumulativeSize = sizes.front(); + for (unsigned i = 1, e = sizes.size(); i < e; ++i) + cumulativeSize = rewriter.create( + loc, getIndexType(), ArrayRef{cumulativeSize, sizes[i]}); + + // Calculate the strides. + Value runningStride = nullptr; + // Iterate strides in reverse order, compute runningStride and strideValues. + unsigned nStrides = strides.size(); + SmallVector strideValues(nStrides, nullptr); + for (auto indexedStride : llvm::enumerate(llvm::reverse(strides))) { + int64_t index = nStrides - 1 - indexedStride.index(); + if (strides[index] == MemRefType::getDynamicStrideOrOffset()) + // Identity layout map is enforced in the match function, so we compute: + // `runningStride *= sizes[index + 1]`. + runningStride = runningStride ? rewriter.create(loc, + runningStride, sizes[index + 1]) + : createIndexConstant(rewriter, loc, 1); + else + runningStride = createIndexConstant(rewriter, loc, strides[index]); + strideValues[index] = runningStride; + } + + // Fill size and stride descriptors in memref. + for (auto indexedSize : llvm::enumerate(sizes)) { + int64_t index = indexedSize.index(); + memRefDescriptor.setSize(rewriter, loc, index, indexedSize.value()); + memRefDescriptor.setStride(rewriter, loc, index, strideValues[index]); + } + + rewriter.replaceOp(op, {memRefDescriptor}); + return success(); + } + + // Check if the MemRefType `type` is supported by the lowering. We currently + // only support memrefs with identity maps. 
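// For instance (illustrative): memref<4x8xf32> uses the identity layout and
// passes this check, whereas a permuted layout such as
// memref<4x8xf32, affine_map<(d0, d1) -> (d1, d0)>> is rejected.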
+ bool isSupportedMemRefType(MemRefType type) const { + if (!typeConverter->convertType(type.getElementType())) + return false; + return type.getLayout().isIdentity(); + } +}; + +void populateLoweringKrnlVectorTypeCastOpPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace krnl +} // namespace onnx_mlir \ No newline at end of file diff --git a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp index 4c9ba3d71525..0ed539a39bd3 100644 --- a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp +++ b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp @@ -262,15 +262,15 @@ void FrontendToKrnlLoweringPass::runOnOperation() { } } -std::unique_ptr mlir::createLowerToKrnlPass() { +std::unique_ptr onnx_mlir::createLowerToKrnlPass() { return std::make_unique(); } -std::unique_ptr mlir::createLowerToKrnlPass(int optLevel) { +std::unique_ptr onnx_mlir::createLowerToKrnlPass(int optLevel) { return std::make_unique(optLevel); } -std::unique_ptr mlir::createLowerToKrnlPass( +std::unique_ptr onnx_mlir::createLowerToKrnlPass( bool emitDealloc, bool enableTiling) { return std::make_unique( emitDealloc, enableTiling); diff --git a/src/InitOMPasses.hpp b/src/InitOMPasses.hpp index 7339c85ce0a5..9398fd42c930 100644 --- a/src/InitOMPasses.hpp +++ b/src/InitOMPasses.hpp @@ -3,76 +3,77 @@ */ #include "mlir/Pass/Pass.h" - #include "src/Pass/Passes.hpp" +using namespace onnx_mlir; + namespace onnx_mlir { void initOMPasses(int optLevel) { // All passes implemented within onnx-mlir should register within this // function to make themselves available as a command-line option. mlir::registerPass([]() -> std::unique_ptr { - return mlir::createONNXOpTransformPass(); + return createONNXOpTransformPass(); }); mlir::registerPass([]() -> std::unique_ptr { - return mlir::createDecomposeONNXToONNXPass(); + return createDecomposeONNXToONNXPass(); }); mlir::registerPass([]() -> std::unique_ptr { - return mlir::createShapeInferencePass(); + return createShapeInferencePass(); }); mlir::registerPass([]() -> std::unique_ptr { - return mlir::createConstPropONNXToONNXPass(); + return createConstPropONNXToONNXPass(); }); mlir::registerPass([]() -> std::unique_ptr { - return mlir::createElideConstantValuePass(); + return createElideConstantValuePass(); }); mlir::registerPass([]() -> std::unique_ptr { - return mlir::createInstrumentONNXPass(); + return createInstrumentONNXPass(); }); mlir::registerPass([]() -> std::unique_ptr { - return mlir::createONNXPreKrnlVerifyPass(); + return createONNXPreKrnlVerifyPass(); }); mlir::registerPass([]() -> std::unique_ptr { - return mlir::createKrnlEnableMemoryPoolPass(); + return krnl::createKrnlEnableMemoryPoolPass(); }); mlir::registerPass([]() -> std::unique_ptr { - return mlir::createKrnlBundleMemoryPoolsPass(); + return krnl::createKrnlBundleMemoryPoolsPass(); }); mlir::registerPass([]() -> std::unique_ptr { - return mlir::createKrnlOptimizeMemoryPoolsPass(); + return krnl::createKrnlOptimizeMemoryPoolsPass(); }); mlir::registerPass([]() -> std::unique_ptr { - return mlir::createConvertKrnlToAffinePass(); + return createConvertKrnlToAffinePass(); }); mlir::registerPass([optLevel]() -> std::unique_ptr { - return mlir::createLowerToKrnlPass(optLevel); + return createLowerToKrnlPass(optLevel); }); mlir::registerPass([]() -> std::unique_ptr { - return mlir::createElideConstGlobalValuePass(); + return createElideConstGlobalValuePass(); }); mlir::registerPass([]() 
-> std::unique_ptr { - return mlir::createConvertKrnlToLLVMPass(); + return krnl::createConvertKrnlToLLVMPass(); }); mlir::registerPass([]() -> std::unique_ptr { - return mlir::createDisconnectKrnlDimFromAllocPass(); + return createDisconnectKrnlDimFromAllocPass(); }); mlir::registerPass([]() -> std::unique_ptr { - return mlir::createLowerKrnlShapePass(); + return createLowerKrnlShapePass(); }); } } // namespace onnx_mlir diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp index a34c3a26f2c9..1cfb865a6dec 100644 --- a/src/Pass/Passes.hpp +++ b/src/Pass/Passes.hpp @@ -16,8 +16,13 @@ #include +using namespace mlir; + namespace mlir { class Pass; +} + +namespace onnx_mlir { /// Pass for ONNX graph level optimization std::unique_ptr createONNXOpTransformPass(); @@ -34,15 +39,6 @@ std::unique_ptr createConstPropONNXToONNXPass(); /// Pass for eliding the values of constant operations. std::unique_ptr createElideConstantValuePass(); -/// Pass for enabling a memory pool for MemRefs. -std::unique_ptr createKrnlEnableMemoryPoolPass(); - -/// Pass for enabling a memory pool for MemRefs. -std::unique_ptr createKrnlBundleMemoryPoolsPass(); - -/// Pass for optimizing memory pools. -std::unique_ptr createKrnlOptimizeMemoryPoolsPass(); - /// Pass for instrument the Onnx ops std::unique_ptr createInstrumentONNXPass(); @@ -67,7 +63,20 @@ std::unique_ptr createLowerKrnlShapePass(); /// Pass for eliding the values of global Krnl operations. std::unique_ptr createElideConstGlobalValuePass(); +namespace krnl { + +/// Pass for enabling a memory pool for MemRefs. +std::unique_ptr createKrnlEnableMemoryPoolPass(); + +/// Pass for enabling a memory pool for MemRefs. +std::unique_ptr createKrnlBundleMemoryPoolsPass(); + +/// Pass for optimizing memory pools. +std::unique_ptr createKrnlOptimizeMemoryPoolsPass(); + /// Pass for lowering Krnl dialect to LLVM dialect. 
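/// Illustrative call site after the move into onnx_mlir::krnl (a sketch that
/// mirrors the registration code in InitOMPasses.hpp above):
///   pm.addPass(onnx_mlir::krnl::createConvertKrnlToLLVMPass());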
std::unique_ptr createConvertKrnlToLLVMPass(); -} // end namespace mlir +} // namespace krnl + +} // namespace onnx_mlir diff --git a/src/Transform/BundleMemoryPools.cpp b/src/Transform/BundleMemoryPools.cpp index be7be4d72026..2b1dd3322478 100644 --- a/src/Transform/BundleMemoryPools.cpp +++ b/src/Transform/BundleMemoryPools.cpp @@ -543,6 +543,6 @@ class KrnlBundleMemoryPoolsPass }; } // namespace -std::unique_ptr mlir::createKrnlBundleMemoryPoolsPass() { +std::unique_ptr onnx_mlir::krnl::createKrnlBundleMemoryPoolsPass() { return std::make_unique(); } diff --git a/src/Transform/DisconnectKrnlDimFromAlloc.cpp b/src/Transform/DisconnectKrnlDimFromAlloc.cpp index f48d2f0aa851..253d3b3c8495 100644 --- a/src/Transform/DisconnectKrnlDimFromAlloc.cpp +++ b/src/Transform/DisconnectKrnlDimFromAlloc.cpp @@ -151,6 +151,6 @@ class DisconnectKrnlDimFromAllocPass }; } // namespace -std::unique_ptr mlir::createDisconnectKrnlDimFromAllocPass() { +std::unique_ptr onnx_mlir::createDisconnectKrnlDimFromAllocPass() { return std::make_unique(); } diff --git a/src/Transform/ElideKrnlGlobalConstants.cpp b/src/Transform/ElideKrnlGlobalConstants.cpp index 9d06ce6cdc81..340b3b3f5ded 100644 --- a/src/Transform/ElideKrnlGlobalConstants.cpp +++ b/src/Transform/ElideKrnlGlobalConstants.cpp @@ -104,6 +104,6 @@ class ElideConstGlobalValuePass } // namespace -std::unique_ptr mlir::createElideConstGlobalValuePass() { +std::unique_ptr onnx_mlir::createElideConstGlobalValuePass() { return std::make_unique(); } diff --git a/src/Transform/EnableMemoryPool.cpp b/src/Transform/EnableMemoryPool.cpp index 6ee57721bcff..7382a6a4b80c 100644 --- a/src/Transform/EnableMemoryPool.cpp +++ b/src/Transform/EnableMemoryPool.cpp @@ -209,6 +209,6 @@ class KrnlEnableMemoryPoolPass }; } // namespace -std::unique_ptr mlir::createKrnlEnableMemoryPoolPass() { +std::unique_ptr onnx_mlir::krnl::createKrnlEnableMemoryPoolPass() { return std::make_unique(); } diff --git a/src/Transform/LowerKrnlShape.cpp b/src/Transform/LowerKrnlShape.cpp index 781198acf853..b336981c8ece 100644 --- a/src/Transform/LowerKrnlShape.cpp +++ b/src/Transform/LowerKrnlShape.cpp @@ -100,6 +100,6 @@ class LowerKrnlShapePass } // namespace // TODO: integrate with other passes if needed. -std::unique_ptr mlir::createLowerKrnlShapePass() { +std::unique_ptr onnx_mlir::createLowerKrnlShapePass() { return std::make_unique(); } diff --git a/src/Transform/ONNX/ConstProp.cpp b/src/Transform/ONNX/ConstProp.cpp index fb9f897dc6d8..32a383003b45 100644 --- a/src/Transform/ONNX/ConstProp.cpp +++ b/src/Transform/ONNX/ConstProp.cpp @@ -798,6 +798,6 @@ void ConstPropONNXToONNXPass::runOnOperation() { /*! * Create a ConstPropONNX pass. */ -std::unique_ptr mlir::createConstPropONNXToONNXPass() { +std::unique_ptr onnx_mlir::createConstPropONNXToONNXPass() { return std::make_unique(); } diff --git a/src/Transform/ONNX/Decompose.cpp b/src/Transform/ONNX/Decompose.cpp index d795ea685e1f..4684bc07bd6d 100644 --- a/src/Transform/ONNX/Decompose.cpp +++ b/src/Transform/ONNX/Decompose.cpp @@ -176,6 +176,6 @@ void DecomposeONNXToONNXPass::runOnOperation() { /*! * Create a DecomposeONNX pass. 
*/ -std::unique_ptr mlir::createDecomposeONNXToONNXPass() { +std::unique_ptr onnx_mlir::createDecomposeONNXToONNXPass() { return std::make_unique(); } diff --git a/src/Transform/ONNX/ElideConstants.cpp b/src/Transform/ONNX/ElideConstants.cpp index dc1a458311c4..9f3f193633f1 100644 --- a/src/Transform/ONNX/ElideConstants.cpp +++ b/src/Transform/ONNX/ElideConstants.cpp @@ -88,6 +88,6 @@ class ElideConstantValuePass /*! * Create a Constant Value Elision pass. */ -std::unique_ptr mlir::createElideConstantValuePass() { +std::unique_ptr onnx_mlir::createElideConstantValuePass() { return std::make_unique(); } diff --git a/src/Transform/ONNX/InstrumentONNXPass.cpp b/src/Transform/ONNX/InstrumentONNXPass.cpp index 3675131ec7e1..8712414f7572 100644 --- a/src/Transform/ONNX/InstrumentONNXPass.cpp +++ b/src/Transform/ONNX/InstrumentONNXPass.cpp @@ -122,6 +122,6 @@ class InstrumentONNXPass /*! * Create an instrumentation pass. */ -std::unique_ptr mlir::createInstrumentONNXPass() { +std::unique_ptr onnx_mlir::createInstrumentONNXPass() { return std::make_unique(); } diff --git a/src/Transform/ONNX/ONNXOpTransformPass.cpp b/src/Transform/ONNX/ONNXOpTransformPass.cpp index a344b09de86a..7e5c9e8dea37 100644 --- a/src/Transform/ONNX/ONNXOpTransformPass.cpp +++ b/src/Transform/ONNX/ONNXOpTransformPass.cpp @@ -128,10 +128,10 @@ void ONNXOpTransformPass::runOnOperation() { do { previousTag = currentTag; OpPassManager dynamicPM("builtin.module"); - dynamicPM.addNestedPass(mlir::createDecomposeONNXToONNXPass()); - dynamicPM.addPass(mlir::createShapeInferencePass()); + dynamicPM.addNestedPass(onnx_mlir::createDecomposeONNXToONNXPass()); + dynamicPM.addPass(onnx_mlir::createShapeInferencePass()); dynamicPM.addPass(mlir::createCanonicalizerPass()); - dynamicPM.addNestedPass(mlir::createConstPropONNXToONNXPass()); + dynamicPM.addNestedPass(onnx_mlir::createConstPropONNXToONNXPass()); if (failed(runPipeline(dynamicPM, module))) return signalPassFailure(); currentTag = createTagForIR(module); @@ -154,10 +154,11 @@ void ONNXOpTransformPass::runOnOperation() { /*! * Create an instrumentation pass. */ -std::unique_ptr mlir::createONNXOpTransformPass() { +std::unique_ptr onnx_mlir::createONNXOpTransformPass() { return std::make_unique(); } -std::unique_ptr mlir::createONNXOpTransformPass(int threshold) { +std::unique_ptr onnx_mlir::createONNXOpTransformPass( + int threshold) { return std::make_unique(threshold); } diff --git a/src/Transform/ONNX/ONNXPreKrnlVerifyPass.cpp b/src/Transform/ONNX/ONNXPreKrnlVerifyPass.cpp index 4b97e82a67a5..4bfe7fd54aa7 100644 --- a/src/Transform/ONNX/ONNXPreKrnlVerifyPass.cpp +++ b/src/Transform/ONNX/ONNXPreKrnlVerifyPass.cpp @@ -74,6 +74,6 @@ class ONNXPreKrnlVerifyPass /*! * Create an instrumentation pass. 
*/ -std::unique_ptr mlir::createONNXPreKrnlVerifyPass() { +std::unique_ptr onnx_mlir::createONNXPreKrnlVerifyPass() { return std::make_unique(); } diff --git a/src/Transform/ONNX/ShapeInferencePass.cpp b/src/Transform/ONNX/ShapeInferencePass.cpp index 4f3485e23891..2542b6979b75 100644 --- a/src/Transform/ONNX/ShapeInferencePass.cpp +++ b/src/Transform/ONNX/ShapeInferencePass.cpp @@ -174,7 +174,7 @@ class ShapeInferencePass : public mlir::PassWrapper mlir::createShapeInferencePass( +std::unique_ptr onnx_mlir::createShapeInferencePass( bool analyzeAllFunctions) { return std::make_unique(analyzeAllFunctions); } diff --git a/src/Transform/OptimizeMemoryPools.cpp b/src/Transform/OptimizeMemoryPools.cpp index 557e9db8ecef..6c432c80252b 100644 --- a/src/Transform/OptimizeMemoryPools.cpp +++ b/src/Transform/OptimizeMemoryPools.cpp @@ -897,6 +897,6 @@ class KrnlOptimizeMemoryPoolsPass }; } // namespace -std::unique_ptr mlir::createKrnlOptimizeMemoryPoolsPass() { +std::unique_ptr onnx_mlir::krnl::createKrnlOptimizeMemoryPoolsPass() { return std::make_unique(); } diff --git a/test/onnx2mlir/CustomFnTest.cpp b/test/onnx2mlir/CustomFnTest.cpp index 7c5e16b8dc3e..59a48d06292a 100644 --- a/test/onnx2mlir/CustomFnTest.cpp +++ b/test/onnx2mlir/CustomFnTest.cpp @@ -86,7 +86,7 @@ int check(ModelProto &model) { onnx_mlir::ImportFrontendModel(model, context, module, options); mlir::PassManager pm(&context, mlir::OpPassManager::Nesting::Implicit); - pm.addPass(mlir::createShapeInferencePass(true)); + pm.addPass(onnx_mlir::createShapeInferencePass(true)); mlir::applyPassManagerCLOptions(pm); if (mlir::failed(pm.run(*module))) { module->dump();