Skip to content

Commit

Permalink
Update LLVM level (llvm#1095)
Browse files Browse the repository at this point in the history
* Update LLVM level to 700997aef8c1f2f08c9ac5fca61650b57a01e8b1

Signed-off-by: Ettore Tiotto <etiotto@ca.ibm.com>
  • Loading branch information
Ettore Tiotto authored Jan 19, 2022
1 parent 184d1a6 commit 5211d6c
Show file tree
Hide file tree
Showing 30 changed files with 349 additions and 278 deletions.
2 changes: 1 addition & 1 deletion docs/BuildOnLinuxOSX.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
``` bash
git clone https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX-MLIR.
cd llvm-project && git checkout 0bf230d4220660af8b2667506f8905df2f716bdf && cd ..
cd llvm-project && git checkout 700997aef8c1f2f08c9ac5fca61650b57a01e8b1 && cd ..
```

[same-as-file]: <> (utils/build-mlir.sh)
Expand Down
2 changes: 1 addition & 1 deletion docs/BuildOnWindows.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Install MLIR (as a part of LLVM-Project):
```shell
git clone https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX-MLIR.
cd llvm-project && git checkout 0bf230d4220660af8b2667506f8905df2f716bdf && cd ..
cd llvm-project && git checkout 700997aef8c1f2f08c9ac5fca61650b57a01e8b1 && cd ..
```

[same-as-file]: <> (utils/build-mlir.cmd)
Expand Down
6 changes: 3 additions & 3 deletions src/Builder/FrontendDialectTransformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,7 @@ class FrontendGenImpl {
const auto elementType = builder_.getIntegerType(64);
const auto attributes = ImportNodeAttributes(node);
for (auto attr : attributes) {
if (auto arrayAttr = attr.second.dyn_cast<ArrayAttr>()) {
if (auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>()) {
const auto tensorType =
RankedTensorType::get({(int64_t)arrayAttr.size()}, elementType);
auto constantDenseAttribute =
Expand All @@ -904,7 +904,7 @@ class FrontendGenImpl {

// Map from ONNX attributes to indices, which are
// matched with ONNXSliceOp::build ordering.
auto inputIdx = llvm::StringSwitch<int>(attr.first)
auto inputIdx = llvm::StringSwitch<int>(attr.getName())
.Case("starts", 1)
.Case("ends", 2)
.Case("axes", 3)
Expand Down Expand Up @@ -956,7 +956,7 @@ class FrontendGenImpl {
auto attributes = ImportNodeAttributes(node);
bool hasAxisAttribute = false;
for (auto &attr : attributes)
if (attr.first.strref().equals_insensitive("axis")) {
if (attr.getName().strref().equals_insensitive("axis")) {
hasAxisAttribute = true;
break;
}
Expand Down
3 changes: 2 additions & 1 deletion src/Compiler/CompilerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
Expand Down Expand Up @@ -619,7 +620,7 @@ void addKrnlToLLVMPasses(mlir::OpPassManager &pm) {
// Currently this has to be done *after* lowering the affine dialect because
// operations in that dialect do not conform to the requirements explained in
// https://mlir.llvm.org/docs/BufferDeallocationInternals.
pm.addNestedPass<FuncOp>(mlir::createBufferDeallocationPass());
pm.addNestedPass<FuncOp>(mlir::bufferization::createBufferDeallocationPass());
if (enableMemoryBundling) {
pm.addNestedPass<FuncOp>(mlir::createKrnlEnableMemoryPoolPass());
pm.addNestedPass<FuncOp>(mlir::createKrnlBundleMemoryPoolsPass());
Expand Down
8 changes: 4 additions & 4 deletions src/Conversion/KrnlToLLVM/KrnlToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
Expand Down Expand Up @@ -505,17 +505,17 @@ class KrnlGlobalOpLowering : public ConvertToLLVMPattern {
// attempt to set the alignment based on the module datalayout (if it
// exists).
if (alignmentAttr && alignmentAttr.getValue().getSExtValue() != 0)
global.alignmentAttr(alignmentAttr);
global.setAlignmentAttr(alignmentAttr);
else if (module->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName())) {
// TODO: use MLIR data layout when it becomes available.
llvm::LLVMContext llvmContext;
int32_t align = LLVM::TypeToLLVMIRTranslator(llvmContext)
.getPreferredAlignment(global.getType(),
getTypeConverter()->getDataLayout());
align = std::max(align, MinGlobalAlign);
global.alignmentAttr(rewriter.getI64IntegerAttr(align));
global.setAlignmentAttr(rewriter.getI64IntegerAttr(align));
} else
global.alignmentAttr(rewriter.getI64IntegerAttr(MinGlobalAlign));
global.setAlignmentAttr(rewriter.getI64IntegerAttr(MinGlobalAlign));

// Prepare data to be inserted into a MemRefDescriptor (a struct).
Value globalOpAddr = rewriter.create<LLVM::AddressOfOp>(loc, global);
Expand Down
19 changes: 10 additions & 9 deletions src/Conversion/ONNXToKrnl/ControlFlow/Loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ struct ONNXLoopOpLowering : public ConversionPattern {

auto condReg = rewriter.create<KrnlLoadOp>(loc, cond).getResult();
auto ifOp = rewriter.create<scf::IfOp>(loc, condReg, false);
rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());

// Create a scalar tensor out of loop iteration variable, as the first
// argument passed to the body graph function.
Expand Down Expand Up @@ -98,7 +98,7 @@ struct ONNXLoopOpLowering : public ConversionPattern {
mapper.map(regionArg, params[i]);
}

auto &thenRegion = ifOp.thenRegion();
auto &thenRegion = ifOp.getThenRegion();
auto &thenBlock = thenRegion.front();

// Split the insertion block into two, where the second block
Expand Down Expand Up @@ -131,23 +131,24 @@ struct ONNXLoopOpLowering : public ConversionPattern {
rewriter.setInsertionPointToStart(postInsertBlock);

// Cast loop body outputs from tensor type to memref type in case it has
// not already been lowered via dummy_cast. Eventually, dummy cast becomes
// a cast from memref type to a memref type when everything is lowered and
// thus becomes redundant.
// not already been lowered. Eventually, 'UnrealizedConversionCastOp'
// becomes a cast from memref type to a memref type when everything is
// lowered and thus becomes redundant.
SmallVector<Value, 4> bodyOutputs(
resultsRange.begin(), resultsRange.end());
for (unsigned int i = 0; i < bodyOutputs.size(); i++) {
for (unsigned i = 0; i < bodyOutputs.size(); i++) {
auto output = bodyOutputs[i];
assert((output.getType().isa<TensorType>() ||
output.getType().isa<MemRefType>()) &&
"Expecting loop body function output to consist of "
"tensors/memrefs.");
auto outputTy = output.getType().cast<ShapedType>();
bodyOutputs[i] = rewriter
.create<KrnlDummyCastOp>(loc, output,
.create<UnrealizedConversionCastOp>(loc,
MemRefType::get(outputTy.getShape(),
outputTy.getElementType()))
.getResult();
outputTy.getElementType()),
output)
.getResult(0);
}

// Copy the newly computed loop condition to pre-allocated buffer.
Expand Down
28 changes: 16 additions & 12 deletions src/Conversion/ONNXToKrnl/ControlFlow/Scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,21 @@ struct ONNXScanOpLowering : public ConversionPattern {
// Copy content of vInit to vFinal, which is used to host intermediate
// values produced by scan body function invocation in a scope accessible by
// all scan iterations.
auto v_initials = llvm::make_range(
operands.begin(), operands.end() - scanOp.num_scan_inputs());
int64_t numInputs = scanOp.num_scan_inputs();
auto v_initials =
llvm::make_range(operands.begin(), operands.end() - numInputs);
for (const auto &vInitAndFinal : llvm::zip(v_initials, outputs))
emitCopy(rewriter, loc, std::get<0>(vInitAndFinal),
std::get<1>(vInitAndFinal));

auto inputOperands = llvm::make_range(
operands.begin() + (operands.size() - numInputs), operands.end());
MemRefBuilder createMemRef(rewriter, loc);
Value maxTripCount = createMemRef.dim(*inputOperands.begin(), 0);

// Create the scan iteration.
BuildKrnlLoop loop(rewriter, loc, 1);
loop.createDefineOp();
Value maxTripCount =
rewriter.create<memref::DimOp>(loc, scanOp.scan_inputs().front(), 0);

loop.pushBounds(0, maxTripCount);
loop.createIterateOp();
rewriter.setInsertionPointToStart(loop.getIterateBlock());
Expand Down Expand Up @@ -126,23 +129,24 @@ struct ONNXScanOpLowering : public ConversionPattern {
rewriter.setInsertionPointToStart(postInsertBlock);

// Cast scan body outputs from tensor type to memref type in case it has
// not already been lowered via dummy_cast. Eventually, dummy cast becomes
// a cast from memref type to a memref type when everything is lowered and
// thus becomes redundant.
// not already been lowered. Eventually, 'UnrealizedConversionCastOp'
// becomes a cast from memref type to a memref type when everything is
// lowered and thus becomes redundant.
SmallVector<Value, 4> bodyOutputs(
resultsRange.begin(), resultsRange.end());
for (unsigned int i = 0; i < bodyOutputs.size(); i++) {
for (unsigned i = 0; i < bodyOutputs.size(); i++) {
auto output = bodyOutputs[i];
assert((output.getType().isa<TensorType>() ||
output.getType().isa<MemRefType>()) &&
"Expecting scan body function output to consist of"
"tensors/memrefs.");
auto outputTy = output.getType().cast<ShapedType>();
bodyOutputs[i] = rewriter
.create<KrnlDummyCastOp>(loc, output,
.create<UnrealizedConversionCastOp>(loc,
MemRefType::get(outputTy.getShape(),
outputTy.getElementType()))
.getResult();
outputTy.getElementType()),
output)
.getResult(0);
}

// Copy intermediate values of scan carried dependencies to MemRef outside
Expand Down
12 changes: 9 additions & 3 deletions src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

//====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering -------===//
//
// Copyright 2019 The IBM Research Authors.
// Copyright 2019-2022 The IBM Research Authors.
//
// =============================================================================
//
Expand Down Expand Up @@ -59,7 +59,7 @@ void populateONNXToKrnlConversionPattern(RewritePatternSet &patterns,
// Math
populateLoweringONNXClipOpPattern(patterns, ctx);
populateLoweringONNXCumSumOpPattern(patterns, ctx);
populateLoweringONNXElementwiseOpPattern(patterns, ctx);
populateLoweringONNXElementwiseOpPattern(patterns, typeConverter, ctx);
populateLoweringONNXGemmOpPattern(patterns, ctx);
populateLoweringONNXHardmaxOpPattern(patterns, ctx);
populateLoweringONNXReductionOpPattern(patterns, ctx);
Expand All @@ -80,7 +80,7 @@ void populateONNXToKrnlConversionPattern(RewritePatternSet &patterns,
populateLoweringONNXUnsqueezeV11OpPattern(patterns, ctx);
populateLoweringONNXTransposeOpPattern(patterns, ctx);
populateLoweringONNXGatherOpPattern(patterns, ctx);
populateLoweringONNXIdentityOpPattern(patterns, ctx);
populateLoweringONNXIdentityOpPattern(patterns, typeConverter, ctx);
populateLoweringONNXConstantOfShapeOpPattern(patterns, ctx);
populateLoweringONNXConstantOpPattern(patterns, ctx);
populateLoweringONNXConcatOpPattern(patterns, ctx);
Expand Down Expand Up @@ -232,6 +232,12 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
return krnlTypeConverter.isLegal(op);
});

// Operations that are legal only if types are not tensors.
target.addDynamicallyLegalOp<mlir::ReturnOp>([&](Operation *op) {
return llvm::none_of(op->getOperandTypes(),
[](Type type) { return type.isa<TensorType>(); });
});

// Define patterns.
populateONNXToKrnlConversionPattern(
patterns, &getContext(), krnlTypeConverter);
Expand Down
8 changes: 4 additions & 4 deletions src/Conversion/ONNXToKrnl/ML/CategoryMapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,23 +268,23 @@ struct ONNXCategoryMapperOpLowering : public ConversionPattern {
TypeSwitch<Type>(elementType)
.Case<IntegerType>([&](IntegerType type) {
// index is valid: retrieve the value from 'cat_strings'.
rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
Value loadData = createKrnl.load(constantForCatsStrings, {index});
createKrnl.store(loadData, alloc, loopInd);

// index is not valid: store the default value.
rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
Value loadDefault = createKrnl.load(defaultString);
createKrnl.store(loadDefault, alloc, loopInd);
})
.Case<StringType>([&](StringType type) {
// index is valid: retrieve the value from 'cat_int64s'.
rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
Value loadData = createKrnl.load(constantForCatsInt64s, {index});
createKrnl.store(loadData, alloc, loopInd);

// index is not valid: store the default value.
rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
createKrnl.store(defaultInt64, alloc, loopInd);
})
.Default([&](Type type) { llvm_unreachable("Illegal KeyTy"); });
Expand Down
31 changes: 18 additions & 13 deletions src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

//===---------------- Elementwise.cpp - Elementwise Ops -------------------===//
//
// Copyright 2019 The IBM Research Authors.
// Copyright 2019-2022 The IBM Research Authors.
//
// =============================================================================
//
Expand Down Expand Up @@ -856,8 +856,9 @@ Value emitScalarOpFor<ONNXRoundOp>(ConversionPatternRewriter &rewriter,
//===----------------------------------------------------------------------===//
template <typename ElementwiseUnaryOp>
struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
: ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
ONNXElementwiseUnaryOpLowering(TypeConverter &typeConverter, MLIRContext *ctx)
: ConversionPattern(
typeConverter, ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = ONNXLoc<ElementwiseUnaryOp>(op);
Expand Down Expand Up @@ -920,9 +921,10 @@ template <typename ElementwiseBinaryOp>
struct ONNXElementwiseBinaryOpLowering : public ConversionPattern {
bool isUniBroadcasting = false;

ONNXElementwiseBinaryOpLowering(
ONNXElementwiseBinaryOpLowering(TypeConverter &typeConverter,
MLIRContext *ctx, bool isUniBroadcasting = false)
: ConversionPattern(ElementwiseBinaryOp::getOperationName(), 1, ctx) {
: ConversionPattern(
typeConverter, ElementwiseBinaryOp::getOperationName(), 1, ctx) {
this->isUniBroadcasting = isUniBroadcasting;
}

Expand Down Expand Up @@ -998,8 +1000,10 @@ struct ONNXElementwiseBinaryOpLowering : public ConversionPattern {
//===----------------------------------------------------------------------===//
template <typename ElementwiseVariadicOp>
struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
: ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
ONNXElementwiseVariadicOpLowering(
TypeConverter &typeConverter, MLIRContext *ctx)
: ConversionPattern(
typeConverter, ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc =
Expand Down Expand Up @@ -1081,8 +1085,9 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
// where op lowering to Krnl dialect.
//===----------------------------------------------------------------------===//
struct ONNXWhereOpLowering : public ConversionPattern {
ONNXWhereOpLowering(MLIRContext *ctx)
: ConversionPattern(ONNXWhereOp::getOperationName(), 1, ctx) {}
ONNXWhereOpLowering(TypeConverter &typeConverter, MLIRContext *ctx)
: ConversionPattern(
typeConverter, ONNXWhereOp::getOperationName(), 1, ctx) {}

LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
Expand Down Expand Up @@ -1155,8 +1160,8 @@ struct ONNXWhereOpLowering : public ConversionPattern {
}
};

void populateLoweringONNXElementwiseOpPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
void populateLoweringONNXElementwiseOpPattern(RewritePatternSet &patterns,
TypeConverter &typeConverter, MLIRContext *ctx) {
patterns.insert<ONNXElementwiseUnaryOpLowering<mlir::ONNXAbsOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXAndOp>,
Expand Down Expand Up @@ -1207,7 +1212,7 @@ void populateLoweringONNXElementwiseOpPattern(
ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXTanOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>, ONNXWhereOpLowering,
ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>>(ctx);
ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>>(typeConverter, ctx);
patterns.insert<ONNXElementwiseBinaryOpLowering<mlir::ONNXPReluOp>>(
ctx, /*isUniBroadcasting=*/true);
typeConverter, ctx, /*isUniBroadcasting=*/true);
}
1 change: 1 addition & 0 deletions src/Conversion/ONNXToKrnl/Math/Gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/Krnl/KrnlHelper.hpp"
#include "src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp"
#include "llvm/Support/Debug.h"

// Used to trace which op are used, good for profiling apps.
#define DEBUG_TYPE "gemm"
Expand Down
9 changes: 4 additions & 5 deletions src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

//====------ ONNXToKrnlCommon.hpp - ONNX dialects to Krnl lowering --------===//
//
// Copyright 2019-2021 The IBM Research Authors.
// Copyright 2019-2022 The IBM Research Authors.
//
// =============================================================================
//
Expand All @@ -18,8 +18,7 @@
#include <map>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
Expand Down Expand Up @@ -269,7 +268,7 @@ void populateLoweringONNXCumSumOpPattern(
RewritePatternSet &patterns, MLIRContext *ctx);

void populateLoweringONNXElementwiseOpPattern(
RewritePatternSet &patterns, MLIRContext *ctx);
RewritePatternSet &, TypeConverter &, MLIRContext *);

void populateLoweringONNXGemmOpPattern(
RewritePatternSet &patterns, MLIRContext *ctx);
Expand Down Expand Up @@ -351,7 +350,7 @@ void populateLoweringONNXReshapeOpPattern(
RewritePatternSet &patterns, MLIRContext *ctx);

void populateLoweringONNXIdentityOpPattern(
RewritePatternSet &patterns, MLIRContext *ctx);
RewritePatternSet &, TypeConverter &, MLIRContext *);

void populateLoweringONNXConstantOfShapeOpPattern(
RewritePatternSet &patterns, MLIRContext *ctx);
Expand Down
Loading

0 comments on commit 5211d6c

Please sign in to comment.