diff --git a/CMakeLists.txt b/CMakeLists.txt
index bf031b772c5d..8f6d4d932d19 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -39,14 +39,6 @@ endmacro()
 option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON)
 if(TORCH_MLIR_ENABLE_MHLO)
   add_definitions(-DTORCH_MLIR_ENABLE_MHLO)
-  # The i64 calculation is much slower than i32 on some devices, such as Nvidia GPU.
-  # One can truncate from i64 to i32 since dimension sizes are unlikely to exceed
-  # the range of i32(4GiB)
-  option(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
-         "Enable truncate dimension size from i64 to i32(unsafely)" OFF)
-  if(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32)
-    add_definitions(-DTORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32)
-  endif()
 endif()

 option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)
diff --git a/include/torch-mlir/Conversion/Passes.td b/include/torch-mlir/Conversion/Passes.td
index d5aa20532b6d..28138edcb478 100644
--- a/include/torch-mlir/Conversion/Passes.td
+++ b/include/torch-mlir/Conversion/Passes.td
@@ -132,6 +132,17 @@ def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> {
     Convert Torch ops to mhlo ops.
   }];
   let constructor = "mlir::torch::createConvertTorchToMhloPass()";
+
+  // Specify any options.
+  let options = [
+    Option<"enableStaticShape", "enable-static-shape", "bool", /*default=*/"false",
+           "Enable static shape conversion">,
+    // The i64 calculation is much slower than i32 on some devices, such as
+    // Nvidia GPU. One can truncate from i64 to i32 since dimension sizes
+    // are unlikely to exceed the range of i32(4GiB)
+    Option<"enableI32Index", "enable-i32-index", "bool", /*default=*/"false",
+           "Enable truncate index from i64 to i32(unsafely)">,
+  ];
 }

 #endif
diff --git a/include/torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h b/include/torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h
index ef1337770b8d..8e2f5fc8630d 100644
--- a/include/torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h
+++ b/include/torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h
@@ -17,6 +17,8 @@ namespace mlir {
 namespace torch {
 std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToMhloPass();
+std::unique_ptr<OperationPass<func::FuncOp>>
+createConvertTorchToMhloPass(bool enableStaticShape, bool enableI32Index);
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Conversion/TorchToMhlo/Basic.cpp b/lib/Conversion/TorchToMhlo/Basic.cpp
index 0fafbf1336cd..5d4d95c26b82 100644
--- a/lib/Conversion/TorchToMhlo/Basic.cpp
+++ b/lib/Conversion/TorchToMhlo/Basic.cpp
@@ -29,6 +29,7 @@ using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
+using namespace mlir::torch::torch_to_mhlo;

 bool skipMultiplyAlpha(Value alphaValue) {
   double doubleValue;
@@ -379,63 +380,58 @@ class ConvertAtenTransposeIntOp
 } // namespace

 // AtenBroadcastToOp
-namespace {
-class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenBroadcastToOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Value self = adaptor.self();
-    auto selfTy = self.getType().cast<RankedTensorType>();
-    auto outType = getTypeConverter()
-                       ->convertType(op->getResult(0).getType())
-                       .cast<RankedTensorType>();
-
-#ifdef TORCH_MLIR_ENABLE_MHLO_STATIC_SHAPE
-    if (selfTy.hasStaticShape()) {
-      Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType);
-      rewriter.replaceOp(op, bcastOp);
-      return success();
-    }
-#endif
+template <>
+LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
+    AtenBroadcastToOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value
self = adaptor.self(); + auto selfTy = self.getType().cast(); + auto outType = getTypeConverter() + ->convertType(op->getResult(0).getType()) + .cast(); + + if (options.enableStaticShape && selfTy.hasStaticShape()) { + Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType); + rewriter.replaceOp(op, bcastOp); + return success(); + } - SmallVector shape; - if (!(getListConstructElements(adaptor.size(), shape))) { - return op->emitError("desired shape must be a list of scalar"); + SmallVector shape; + if (!(getListConstructElements(adaptor.size(), shape))) { + return op->emitError("desired shape must be a list of scalar"); + } + SmallVector bcastShapeVec; + int64_t totalRank = shape.size(); + int64_t selfRank = selfTy.getRank(); + int64_t leadingRank = totalRank - selfRank; + + for (int64_t i = 0; i < totalRank; ++i) { + Value dValue = shape[i]; + Value newD; + int64_t dInt; + if (!(matchPattern(dValue, m_TorchConstantInt(&dInt)))) { + return op->emitError("element of desired shape must be a scalar"); } - SmallVector bcastShapeVec; - int64_t totalRank = shape.size(); - int64_t selfRank = selfTy.getRank(); - int64_t leadingRank = totalRank - selfRank; - - for (int64_t i = 0; i < totalRank; ++i) { - Value dValue = shape[i]; - Value newD; - int64_t dInt; - if (!(matchPattern(dValue, m_TorchConstantInt(&dInt)))) { - return op->emitError("element of desired shape must be a scalar"); - } - if (i >= leadingRank && dInt == -1) { - newD = rewriter.create(op->getLoc(), self, - i - leadingRank); - } else { - dValue = rewriter.create(op->getLoc(), - dValue); - newD = rewriter.create( - op->getLoc(), rewriter.getIndexType(), dValue); - } - bcastShapeVec.push_back(newD); + if (i >= leadingRank && dInt == -1) { + newD = rewriter.create(op->getLoc(), self, + i - leadingRank); + } else { + dValue = rewriter.create(op->getLoc(), + dValue); + newD = rewriter.create( + op->getLoc(), rewriter.getIndexType(), dValue); } + bcastShapeVec.push_back(newD); + } -#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32 + if (options.dimSizeIndexBits == 32) { for (auto &dsize : bcastShapeVec) { auto dsizeI64 = rewriter.create( op->getLoc(), rewriter.getI64Type(), dsize); dsize = rewriter.create(op->getLoc(), rewriter.getI32Type(), dsizeI64); } -#endif + } Value bcastShapeTensor = rewriter.create( op->getLoc(), ValueRange{bcastShapeVec}); @@ -445,66 +441,45 @@ class ConvertAtenBroadcastToOp : public OpConversionPattern { op, outType, self, bcastShapeTensor, rewriter.getI64TensorAttr(dimensionNumbers)); return success(); - } -}; -} // namespace +} // AtenPermuteOp -namespace { -class ConvertAtenPermuteOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(AtenPermuteOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value self = adaptor.self(); - // Not a ranked tensor type - auto inType = self.getType().dyn_cast(); - auto outType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); - if (!inType) - return op.emitError("only ranked tensor types with static shapes are " - "currently supported"); - - SmallVector permValues; - if (!matchPattern(adaptor.dims(), m_TorchConstantIntList(permValues))) - return rewriter.notifyMatchFailure( - op, "only constant dimensions are currently supported"); - - int64_t inRank = inType.getRank(); - for (auto &d : permValues) { - d = toPositiveDim(d, inRank); - if (!isValidDim(d, inRank)) - return op.emitError("not all dims are valid"); - } +template <> 
+LogicalResult ConvertAtenOp::matchAndRewrite( + AtenPermuteOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value self = adaptor.self(); + // Not a ranked tensor type + auto inType = self.getType().dyn_cast(); + auto outType = getTypeConverter() + ->convertType(op->getResult(0).getType()) + .cast(); + if (!inType) + return op.emitError("only ranked tensor types with static shapes are " + "currently supported"); + + SmallVector permValues; + if (!matchPattern(adaptor.dims(), m_TorchConstantIntList(permValues))) + return rewriter.notifyMatchFailure( + op, "only constant dimensions are currently supported"); - DenseIntElementsAttr permutation = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(permValues.size())}, - rewriter.getI64Type()), - permValues); - rewriter.replaceOpWithNewOp(op, outType, self, - permutation); - return success(); + int64_t inRank = inType.getRank(); + for (auto &d : permValues) { + d = toPositiveDim(d, inRank); + if (!isValidDim(d, inRank)) + return op.emitError("not all dims are valid"); } -}; - -} // namespace -namespace { -template -class ConvertAtenOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - using OpAdaptor = typename AtenOpT::Adaptor; - LogicalResult - matchAndRewrite(AtenOpT op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; -} // namespace + DenseIntElementsAttr permutation = DenseIntElementsAttr::get( + RankedTensorType::get({static_cast(permValues.size())}, + rewriter.getI64Type()), + permValues); + rewriter.replaceOpWithNewOp(op, outType, self, + permutation); + return success(); +} // AtenTanhOp -namespace { template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenTanhOp op, OpAdaptor adaptor, @@ -520,10 +495,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( "only floating-point datatype legalization currently supported"); } } -} // namespace // ValueTensorLiteralOp -namespace { template <> LogicalResult ConvertAtenOp::matchAndRewrite( ValueTensorLiteralOp op, OpAdaptor adaptor, @@ -553,11 +526,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -} // namespace // AtenReciprocalOp // Reciprocal(x) = Div(1, x) -namespace { template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenReciprocalOp op, OpAdaptor adaptor, @@ -575,10 +546,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp(op, outTy, oneTensor, input); return success(); } -} // namespace // PrimNumToTensorScalarOp -namespace { template <> LogicalResult ConvertAtenOp::matchAndRewrite( PrimNumToTensorScalarOp op, OpAdaptor adaptor, @@ -592,11 +561,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOp(op, mhloTensor); return success(); } -} // namespace // AtenContiguousOp // Ref: TosaToTosa.cpp for implementation details -namespace { template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenContiguousOp op, OpAdaptor adaptor, @@ -614,11 +581,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -} // namespace // AtenReluOp // Relu(x) = Max(0, x) -namespace { template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenReluOp op, OpAdaptor adaptor, @@ -641,11 +606,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -} // namespace // Convert a Aten::GELU to HLO // Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))] -namespace { template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenGeluOp op, OpAdaptor adaptor, @@ -668,10 +631,8 @@ LogicalResult 
ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp(op, input, halfMul); return success(); } -} // namespace // AtenErfOp -namespace { template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenErfOp op, OpAdaptor adaptor, @@ -686,10 +647,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -} // namespace // AtenBatchNormOp -namespace { template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenBatchNormOp op, OpAdaptor adaptor, @@ -716,12 +675,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value channelDim = rewriter.create(op->getLoc(), input, 1); -#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32 - auto channelDimI64 = rewriter.create( - op->getLoc(), rewriter.getI64Type(), channelDim); - channelDim = rewriter.create( - op->getLoc(), rewriter.getI32Type(), channelDimI64); -#endif + if (options.dimSizeIndexBits == 32) { + auto channelDimI64 = rewriter.create( + op->getLoc(), rewriter.getI64Type(), channelDim); + channelDim = rewriter.create( + op->getLoc(), rewriter.getI32Type(), channelDimI64); + } Value channelShape = rewriter.create( op->getLoc(), ValueRange{channelDim}); @@ -806,10 +765,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } } -} // namespace // AtenNativeLayerNormOp -namespace { template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenNativeLayerNormOp op, OpAdaptor adaptor, @@ -949,10 +906,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -} // namespace // AtenCatOp -namespace { template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenCatOp op, OpAdaptor adaptor, @@ -983,10 +938,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, outType, ValueRange(builtinTensors), posDim); return success(); } -} // namespace // AtenNumelOp -namespace { template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenNumelOp op, @@ -996,7 +949,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto selfTy = self.getType().dyn_cast(); size_t rank = selfTy.getRank(); - Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits); + Type intType = rewriter.getIntegerType(options.dimSizeIndexBits); auto loc = op->getLoc(); Value numel = rewriter.create(loc, rewriter.getIntegerAttr(intType, 1)); @@ -1015,26 +968,18 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } return success(); } -} // namespace void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target) { + ConversionTarget &target, const TorchToMhloOptions &options) { MLIRContext *context = patterns.getContext(); target.addIllegalOp(); patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); - - target.addIllegalOp(); - patterns.add(typeConverter, context); - #define INSERT_UNARY_FPONLY_PATTERN(AtenOp, MhloOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, \ - context); + patterns.add>(typeConverter, context) INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, mhlo::LogOp); INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp); INSERT_UNARY_FPONLY_PATTERN(AtenCloneOp, mhlo::CopyOp); @@ -1045,14 +990,14 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ target.addIllegalOp(); \ patterns.add>(typeConverter, \ - context); + context) INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1); INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); #undef INSERT_CONSTANT_FILL_PATTERN #define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, ChloOp) \ target.addIllegalOp(); \ - 
patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context) INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, chlo::BroadcastAddOp); INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, chlo::BroadcastAddOp); INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, chlo::BroadcastSubOp); @@ -1062,7 +1007,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( #define INSERT_BINARY_MULDIV_PATTERN(AtenOp, ChloOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context) INSERT_BINARY_MULDIV_PATTERN(AtenMulTensorOp, chlo::BroadcastMulOp); INSERT_BINARY_MULDIV_PATTERN(AtenMulScalarOp, chlo::BroadcastMulOp); INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp); @@ -1072,7 +1017,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( #define INSERT_BINARY_COMPARE_PATTERN(AtenOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context) INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp); INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp); @@ -1086,7 +1031,11 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( #define INSERT_ATENOP_PATTERN(AtenOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context, options) + + INSERT_ATENOP_PATTERN(AtenBroadcastToOp); + INSERT_ATENOP_PATTERN(AtenPermuteOp); + INSERT_ATENOP_PATTERN(AtenTanhOp); INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); INSERT_ATENOP_PATTERN(AtenReciprocalOp); diff --git a/lib/Conversion/TorchToMhlo/Gather.cpp b/lib/Conversion/TorchToMhlo/Gather.cpp index 05b38b9cca3d..a1185c2c14cf 100644 --- a/lib/Conversion/TorchToMhlo/Gather.cpp +++ b/lib/Conversion/TorchToMhlo/Gather.cpp @@ -23,18 +23,14 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; - -#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32 -static constexpr size_t kMhloDimSizeBits = 32; -#else -static constexpr size_t kMhloDimSizeBits = 64; -#endif +using namespace mlir::torch::torch_to_mhlo; namespace { Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, - Value input, Value indices, int64_t axis) { + Value input, Value indices, int64_t axis, + size_t dimSizeIndexBits) { auto loc = op->getLoc(); - Type intType = rewriter.getIntegerType(kMhloDimSizeBits); + Type intType = rewriter.getIntegerType(dimSizeIndexBits); Value one = rewriter.create( loc, rewriter.getIntegerAttr(intType, 1)); @@ -98,16 +94,7 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, sliceSizesTensor, dimsAttr) .getResult(); } - -template -class ConvertAtenOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - using OpAdaptor = typename AtenOpT::Adaptor; - LogicalResult - matchAndRewrite(AtenOpT op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; +} // namespace // Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html // padding_idx (int, optional) @@ -149,8 +136,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "sparse gradients is currently not supported"); - Value output = - gatherTensorAlongSingleAxis(rewriter, op, weight, adaptor.indices(), 0); + Value output = gatherTensorAlongSingleAxis( + rewriter, op, weight, adaptor.indices(), 0, options.dimSizeIndexBits); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), output); @@ -170,24 +157,23 @@ LogicalResult 
ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "only constant dim is currently supported"); - Value output = - gatherTensorAlongSingleAxis(rewriter, op, self, adaptor.index(), dim); + Value output = gatherTensorAlongSingleAxis( + rewriter, op, self, adaptor.index(), dim, options.dimSizeIndexBits); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), output); return success(); } -} // namespace void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target) { + ConversionTarget &target, const TorchToMhloOptions &options) { MLIRContext *context = patterns.getContext(); #define INSERT_ATENOP_PATTERN(AtenOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context, options) INSERT_ATENOP_PATTERN(AtenEmbeddingOp); INSERT_ATENOP_PATTERN(AtenIndexSelectOp); #undef INSERT_ATENOP_PATTERN diff --git a/lib/Conversion/TorchToMhlo/Linear.cpp b/lib/Conversion/TorchToMhlo/Linear.cpp index da54955c479f..3b428c98cda1 100644 --- a/lib/Conversion/TorchToMhlo/Linear.cpp +++ b/lib/Conversion/TorchToMhlo/Linear.cpp @@ -25,6 +25,7 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +using namespace mlir::torch::torch_to_mhlo; namespace { Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor, @@ -71,7 +72,8 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input, } void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs, - Value &inpRhs, int64_t leadingRank) { + Value &inpRhs, int64_t leadingRank, + size_t dimSizeIndexBits) { Value lhs = inpLhs; Value rhs = inpRhs; auto lhsRankTy = inpLhs.getType().dyn_cast(); @@ -92,9 +94,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs, std::vector newShape(rhsShape.begin(), rhsShape.begin() + leadingRank); newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end()); - auto newDimSizes = - *mhlo::getDimSizesOfTensor(rewriter, op, rhs, leadingDims); - auto lhsDimSizes = *mhlo::getDimSizesOfTensor(rewriter, op, lhs); + auto newDimSizes = *mhlo::getDimSizesOfTensor( + rewriter, op, rhs, leadingDims, dimSizeIndexBits); + auto lhsDimSizes = + *mhlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits); newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(), lhsDimSizes.end()); lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes, @@ -103,9 +106,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs, std::vector newShape(lhsShape.begin(), lhsShape.begin() + leadingRank); newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end()); - auto newDimSizes = - *mhlo::getDimSizesOfTensor(rewriter, op, lhs, leadingDims); - auto rhsDimSizes = *mhlo::getDimSizesOfTensor(rewriter, op, rhs); + auto newDimSizes = *mhlo::getDimSizesOfTensor( + rewriter, op, lhs, leadingDims, dimSizeIndexBits); + auto rhsDimSizes = + *mhlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits); newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(), rhsDimSizes.end()); rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes, @@ -122,9 +126,9 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs, // implement their specialized input processing (e.g transpose), and output // processing, e.g. GEMM or fully connected bias handling. 
template -class ConvertAtenMatmulBaseOp : public OpConversionPattern { +class ConvertAtenMatmulBaseOp : public ConvertAtenOp { public: - using OpConversionPattern::OpConversionPattern; + using ConvertAtenOp::ConvertAtenOp; using OpAdaptor = typename AtenOpT::Adaptor; // Each variant must implement corresponding parameter parsing options. // Maintain separate input read functions for each variant because it is not @@ -159,20 +163,24 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { return success(); } + const auto &options = ConvertAtenOp::getOptions(); int64_t nBatchDims; if (rhsRank <= 2) { auto leadingRank = lhsRank - 2; - getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank); + getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank, + options.dimSizeIndexBits); nBatchDims = leadingRank; } else if (lhsRank <= 2) { auto leadingRank = rhsRank - 2; - getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank); + getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank, + options.dimSizeIndexBits); nBatchDims = leadingRank; } else { assert(rhsRank > 2 && lhsRank > 2); auto leadingRank = std::max(lhsRank - rhsRank, rhsRank - lhsRank); nBatchDims = std::max(lhsRank - 2, rhsRank - 2); - getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank); + getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank, + options.dimSizeIndexBits); } auto batchDims = llvm::to_vector<4>(llvm::seq(0, nBatchDims)); auto lhsContractingDim = nBatchDims + 1; @@ -187,7 +195,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { /*rhsBatchingDimensions=*/batchDims, /*lhsContractingDimensions=*/{lhsContractingDim}, /*rhsContractingDimensions=*/{rhsContractingDim}); - auto resultTy = OpConversionPattern::getTypeConverter() + auto resultTy = ConvertAtenOp::getTypeConverter() ->convertType(op.getType()) .template cast(); @@ -215,7 +223,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { rewriter.replaceOpWithNewOp( op, - OpConversionPattern::getTypeConverter() + ConvertAtenOp::getTypeConverter() ->convertType(op.getType()) .template cast(), output); @@ -340,7 +348,10 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { auto rhsTy = rhs.getType().cast(); auto leadingRank = std::max(lhsTy.getRank() - rhsTy.getRank(), rhsTy.getRank() - lhsTy.getRank()); - getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank); + + const auto &options = ConvertAtenOp::getOptions(); + getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank, + options.dimSizeIndexBits); auto resultRank = std::max(lhsTy.getRank(), rhsTy.getRank()); auto nBatchDims = resultRank - 2; auto batchDims = llvm::to_vector<4>(llvm::seq(0, nBatchDims)); @@ -356,8 +367,7 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { /*rhsContractingDimensions=*/{rhsContractingDim}); auto resultTy = - OpConversionPattern::getTypeConverter()->convertType( - op.getType()); + ConvertAtenOp::getTypeConverter()->convertType(op.getType()); Value matmulOutput = rewriter.create( op->getLoc(), resultTy, lhs, rhs, dotDimensionNumbers, nullptr); @@ -377,12 +387,9 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { } }; -} // namespace - -namespace { -class ConvertAtenConvolutionOp : public OpConversionPattern { +class ConvertAtenConvolutionOp : public ConvertAtenOp { public: - using OpConversionPattern::OpConversionPattern; + using ConvertAtenOp::ConvertAtenOp; using OpAdaptor = typename AtenConvolutionOp::Adaptor; Value reshapeConvWeight(PatternRewriter &rewriter, Operation *op, @@ -390,8 +397,9 @@ class ConvertAtenConvolutionOp : public 
OpConversionPattern { auto weightTy = weight.getType().cast(); auto weightElemTy = weightTy.getElementType(); auto rank = weightTy.getRank(); - SmallVector weightShapeVec = - *mhlo::getDimSizesOfTensor(rewriter, op, weight); + const auto &options = getOptions(); + SmallVector weightShapeVec = *mhlo::getDimSizesOfTensor( + rewriter, op, weight, options.dimSizeIndexBits); auto weightShape = weightTy.getShape(); SmallVector weightShapeInt(rank); std::copy(weightShape.begin(), weightShape.end(), weightShapeInt.begin()); @@ -601,9 +609,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { return mhloConvOp.getResult(); } - LogicalResult - matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { Value input = adaptor.input(); Value weight = adaptor.weight(); @@ -714,7 +721,10 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { // Reshape and promote bias auto inputUnsqzDims = llvm::to_vector<4>(llvm::seq(-nSpatialDims, 0)); - bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims); + + const auto &options = getOptions(); + bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims, + options.dimSizeIndexBits); bias = mhlo::promoteType(rewriter, bias, outTy); DenseIntElementsAttr bcastDimensions; @@ -727,31 +737,31 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target) { + ConversionTarget &target, const TorchToMhloOptions &options) { MLIRContext *context = patterns.getContext(); #define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context, options) INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp); #undef INSERT_MATMUL_ATEMOP_PATTERN #define INSERT_MM_ATENOP_PATTERN(AtenOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context, options) INSERT_MM_ATENOP_PATTERN(AtenMmOp); INSERT_MM_ATENOP_PATTERN(AtenBmmOp); #undef INSERT_MM_ATEMOP_PATTERN #define INSERT_LINEAR_ATENOP_PATTERN(AtenOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context, options) INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp); #undef INSERT_LINEAR_ATEMOP_PATTERN #define INSERT_CONVOLUTION_ATENOP_PATTERN(AtenOp) \ target.addIllegalOp(); \ - patterns.add(typeConverter, context); + patterns.add(typeConverter, context, options) INSERT_CONVOLUTION_ATENOP_PATTERN(AtenConvolutionOp); #undef INSERT_CONVOLUTION_ATENOP_PATTERN } diff --git a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp index 8d646a73f6be..5cab3e7d17a9 100644 --- a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp +++ b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp @@ -259,9 +259,10 @@ SmallVector toPositiveDims(ArrayRef dims, int64_t rank) { return posDims; } -FailureOr> -getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value, - ArrayRef inpDims) { +FailureOr> getDimSizesOfTensor(PatternRewriter &rewriter, + Operation *op, Value value, + ArrayRef inpDims, + size_t dimSizeIndexBits) { auto valueTy = value.getType().dyn_cast(); if (!valueTy) { return rewriter.notifyMatchFailure( @@ -276,14 +277,15 @@ getDimSizesOfTensor(PatternRewriter 
&rewriter, Operation *op, Value value, auto loc = op->getLoc(); for (auto d : dims) { dimSizes.emplace_back(rewriter.create( - loc, rewriter.getIntegerType(kMhloDimSizeBits), + loc, rewriter.getIntegerType(dimSizeIndexBits), rewriter.create(loc, value, d))); } return dimSizes; } -FailureOr> -getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value) { +FailureOr> getDimSizesOfTensor(PatternRewriter &rewriter, + Operation *op, Value value, + size_t dimSizeIndexBits) { auto valueTy = value.getType().dyn_cast(); if (!valueTy) { return rewriter.notifyMatchFailure( @@ -294,12 +296,12 @@ getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value) { // Get int vector [0, 1, ..., rank-1] std::vector dims(rank); std::iota(dims.begin(), dims.end(), 0); - return getDimSizesOfTensor(rewriter, op, value, dims); + return getDimSizesOfTensor(rewriter, op, value, dims, dimSizeIndexBits); } FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, - Value tensor, - ArrayRef inputUnsqzDims) { + Value tensor, ArrayRef inputUnsqzDims, + size_t dimSizeIndexBits) { // Returns a new tensor with dims of size 1 inserted at the specified // position. // @@ -307,7 +309,8 @@ FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, // tensor) are specified with unsqzDims. Indices must be in-order, and in // range of tensor rank. Thus, unsqueeze a rank 1 tensor with {0, 2}, {0, 1, // 3}, {0, 1, 2} are all valid dimension sets, but {0, 3}, {2} are not. - auto dimSizesInfo = getDimSizesOfTensor(rewriter, op, tensor); + auto dimSizesInfo = + getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits); if (failed(dimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -324,7 +327,7 @@ FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, auto loc = op->getLoc(); auto rankTy = tensor.getType().dyn_cast(); auto oldShape = rankTy.getShape(); - Type intType = rewriter.getIntegerType(kMhloDimSizeBits); + Type intType = rewriter.getIntegerType(dimSizeIndexBits); auto one = rewriter.create( loc, rewriter.getIntegerAttr(intType, 1)); diff --git a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h index 5875d7baea29..466cadb81cf7 100644 --- a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h +++ b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h @@ -19,11 +19,6 @@ namespace mlir { namespace mhlo { -#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32 -static constexpr size_t kMhloDimSizeBits = 32; -#else -static constexpr size_t kMhloDimSizeBits = 64; -#endif using mlir::ConversionPatternRewriter; @@ -60,22 +55,23 @@ SmallVector toPositiveDims(ArrayRef dims, int64_t rank); // Get the dimension sizes of the input tensor, given the dimension axes FailureOr> getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value, - ArrayRef inpDims); + ArrayRef inpDims, + size_t dimSizeIndexBits); // Get the dimension sizes of the input tensor -FailureOr> -getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value); +FailureOr> getDimSizesOfTensor(PatternRewriter &rewriter, + Operation *op, Value value, + size_t dimSizeIndexBits); // Get a tensor that unsqueezed the specified dimensions of the input tensor FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, - Value tensor, - ArrayRef inputUnsqzDims); + Value tensor, ArrayRef inputUnsqzDims, + size_t dimSizeIndexBits); - Value getConstantOfShape(PatternRewriter &rewriter, Location loc, const APFloat 
&constant, Value shape, TensorType outType); } // namespace mhlo } // namespace mlir -#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H \ No newline at end of file +#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H diff --git a/lib/Conversion/TorchToMhlo/Pooling.cpp b/lib/Conversion/TorchToMhlo/Pooling.cpp index 6bb07579bb09..514f941a434b 100644 --- a/lib/Conversion/TorchToMhlo/Pooling.cpp +++ b/lib/Conversion/TorchToMhlo/Pooling.cpp @@ -28,6 +28,7 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +using namespace mlir::torch::torch_to_mhlo; static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, PatternRewriter &rewriter) { @@ -72,22 +73,9 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, return nullptr; } -namespace { -template -class ConvertAtenPoolingOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - using OpAdaptor = typename AtenOpT::Adaptor; - LogicalResult - matchAndRewrite(AtenOpT op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; -} // namespace - // AtenMaxPool2dOp -namespace { template <> -LogicalResult ConvertAtenPoolingOp::matchAndRewrite( +LogicalResult ConvertAtenOp::matchAndRewrite( AtenMaxPool2dOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.self(); @@ -186,12 +174,10 @@ LogicalResult ConvertAtenPoolingOp::matchAndRewrite( rewriter.replaceOp(op, reduceWindowOp.getResults()); return success(); } -} // namespace // AtenMaxPool2dWithIndicesOp -namespace { template <> -LogicalResult ConvertAtenPoolingOp::matchAndRewrite( +LogicalResult ConvertAtenOp::matchAndRewrite( AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.self(); @@ -269,7 +255,9 @@ LogicalResult ConvertAtenPoolingOp::matchAndRewrite( rewriter.getI64Type()), mhloPadding); - auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input); + const auto &options = getOptions(); + auto inputShapeInfo = + mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); if (failed(inputShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -379,12 +367,10 @@ LogicalResult ConvertAtenPoolingOp::matchAndRewrite( rewriter.replaceOp(op, reduceWindowOp.getResults()); return success(); } -} // namespace // AtenAvgPool2dOp -namespace { template <> -LogicalResult ConvertAtenPoolingOp::matchAndRewrite( +LogicalResult ConvertAtenOp::matchAndRewrite( AtenAvgPool2dOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.self(); @@ -502,7 +488,9 @@ LogicalResult ConvertAtenPoolingOp::matchAndRewrite( Value windowSizeConst = mhlo::getConstTensor(rewriter, op, {1.0}, {}).value(); windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy); - auto inputShapeVec = *mhlo::getDimSizesOfTensor(rewriter, op, input); + const auto &options = getOptions(); + auto inputShapeVec = + *mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); auto inputShapeTensor = rewriter.create( op->getLoc(), inputShapeVec); @@ -540,17 +528,15 @@ LogicalResult ConvertAtenPoolingOp::matchAndRewrite( return success(); } -} // namespace - void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target) { + ConversionTarget &target, const 
TorchToMhloOptions &options) { MLIRContext *context = patterns.getContext(); target.addIllegalOp(); - patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context, options); target.addIllegalOp(); - patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context, options); target.addIllegalOp(); - patterns.add>(typeConverter, - context); + patterns.add>(typeConverter, + context, options); } diff --git a/lib/Conversion/TorchToMhlo/PopulatePatterns.h b/lib/Conversion/TorchToMhlo/PopulatePatterns.h index 12db6b635a6b..2e195a87fb77 100644 --- a/lib/Conversion/TorchToMhlo/PopulatePatterns.h +++ b/lib/Conversion/TorchToMhlo/PopulatePatterns.h @@ -16,25 +16,56 @@ namespace mlir { namespace torch { namespace torch_to_mhlo { +struct TorchToMhloOptions { + bool enableStaticShape = false; + size_t dimSizeIndexBits = 64; +}; + +template +class ConvertAtenOp : public OpConversionPattern { +public: + using OpAdaptor = typename AtenOpT::Adaptor; + ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context, + const TorchToMhloOptions &options) + : OpConversionPattern(typeConverter, context) { + this->options = options; + } + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return rewriter.notifyMatchFailure(op, "haven't been implemented"); + } + const TorchToMhloOptions &getOptions() const { return options; } + +private: + TorchToMhloOptions options; +}; + void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target); + ConversionTarget &target, + const TorchToMhloOptions &options); void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target); + ConversionTarget &target, + const TorchToMhloOptions &options); void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target); + ConversionTarget &target, + const TorchToMhloOptions &options); void populateReductionOpPatternsAndLegality(TypeConverter &typeConverter, - RewritePatternSet &patterns, - ConversionTarget &target); + RewritePatternSet &patterns, + ConversionTarget &target, + const TorchToMhloOptions &options); void populateLinearOpPatternsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target); + ConversionTarget &target, + const TorchToMhloOptions &options); void populatePoolingOpPatternsAndLegality(TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target); + ConversionTarget &target, + const TorchToMhloOptions &options); } // namespace torch_to_mhlo } // namespace torch diff --git a/lib/Conversion/TorchToMhlo/Reduction.cpp b/lib/Conversion/TorchToMhlo/Reduction.cpp index 2f99e79f73c2..eb5422d58683 100644 --- a/lib/Conversion/TorchToMhlo/Reduction.cpp +++ b/lib/Conversion/TorchToMhlo/Reduction.cpp @@ -25,6 +25,7 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; +using namespace mlir::torch::torch_to_mhlo; static Value createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter) { @@ -72,7 +73,8 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, // Util for converting AtenArgmaxOp and AtenMaxDimOp static llvm::Optional getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, - ArrayRef inputShapeVec, int64_t dim) { + ArrayRef inputShapeVec, int64_t dim, + size_t 
dimSizeIndexBits) { auto inputTy = input.getType().template cast(); if (!inputTy) { return llvm::None; @@ -86,7 +88,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, Value initValue = createInitialValueForReduceOp(op, inputElemTy, rewriter); if (!initValue) return llvm::None; Value initIndex; - if (mlir::mhlo::kMhloDimSizeBits == 32) { + if (dimSizeIndexBits == 32) { initIndex = mhlo::getConstTensor(rewriter, op, {0}, {}).value(); } else { initIndex = mhlo::getConstTensor(rewriter, op, {0}, {}).value(); @@ -98,7 +100,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, auto inputShapeTensor = rewriter.create( op->getLoc(), inputShapeVec); auto indexTensor = rewriter.create( - op->getLoc(), RankedTensorType::get(inputShape, rewriter.getIntegerType(mlir::mhlo::kMhloDimSizeBits)), + op->getLoc(), + RankedTensorType::get(inputShape, + rewriter.getIntegerType(dimSizeIndexBits)), inputShapeTensor, static_cast(dim)); auto mhloReduceOp = rewriter.create( @@ -114,7 +118,8 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, // Add block arguments auto blockValArgumentType = RankedTensorType::get({}, inputTy.getElementType()); - auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getIntegerType(mlir::mhlo::kMhloDimSizeBits)); + auto blockIdxArgumentType = + RankedTensorType::get({}, rewriter.getIntegerType(dimSizeIndexBits)); auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type()); block.addArgument(blockValArgumentType, op->getLoc()); block.addArgument(blockIdxArgumentType, op->getLoc()); @@ -171,9 +176,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, namespace { template -class ConvertAtenReductionOp : public OpConversionPattern { +class ConvertAtenReductionOp : public ConvertAtenOp { public: - using OpConversionPattern::OpConversionPattern; + using ConvertAtenOp::ConvertAtenOp; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, @@ -220,21 +225,24 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); } - auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input); + const auto &options = getOptions(); + auto inputShapeInfo = + mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); if (failed(inputShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } auto inputShapeVec = *inputShapeInfo; - auto mhloReduceResults = - getMaxInDim(rewriter, op, input, inputShapeVec, dim).value(); + auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim, + options.dimSizeIndexBits) + .value(); if (keepDim) { auto outShapeVec = inputShapeVec; - outShapeVec[dim] = rewriter.create( - op->getLoc(), rewriter.getIntegerAttr( - rewriter.getIntegerType(mlir::mhlo::kMhloDimSizeBits), 1)); + op->getLoc(), + rewriter.getIntegerAttr( + rewriter.getIntegerType(options.dimSizeIndexBits), 1)); auto outShapeTensor = rewriter.create( op->getLoc(), outShapeVec); @@ -297,20 +305,24 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); } - auto inputShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input); + const auto &options = getOptions(); + auto inputShapeInfo = + mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); if (failed(inputShapeInfo)) { return 
rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } auto inputShapeVec = *inputShapeInfo; - auto mhloReduceResults = - getMaxInDim(rewriter, op, input, inputShapeVec, dim).value(); + auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim, + options.dimSizeIndexBits) + .value(); if (keepDim) { auto outShapeVec = inputShapeVec; outShapeVec[dim] = rewriter.create( - op->getLoc(), rewriter.getIntegerAttr( - rewriter.getIntegerType(mhlo::kMhloDimSizeBits), 1)); + op->getLoc(), + rewriter.getIntegerAttr( + rewriter.getIntegerType(options.dimSizeIndexBits), 1)); auto outShapeTensor = rewriter.create( op->getLoc(), outShapeVec); @@ -532,15 +544,18 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } if (keepDim) { - auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input); + const auto &options = getOptions(); + auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input, + options.dimSizeIndexBits); if (failed(outShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } auto outShapeVec = *outShapeInfo; auto one = rewriter.create( - op->getLoc(), rewriter.getIntegerAttr( - rewriter.getIntegerType(mhlo::kMhloDimSizeBits), 1)); + op->getLoc(), + rewriter.getIntegerAttr( + rewriter.getIntegerType(options.dimSizeIndexBits), 1)); for (int64_t i : dims) { outShapeVec[i] = one; } @@ -558,11 +573,11 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( void mlir::torch::torch_to_mhlo::populateReductionOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target) { + ConversionTarget &target, const TorchToMhloOptions &options) { MLIRContext *context = patterns.getContext(); #define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context, options) INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp); diff --git a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp index 741854c84102..67ff28c39ba3 100644 --- a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp +++ b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp @@ -32,6 +32,12 @@ namespace { class ConvertTorchToMhlo : public ConvertTorchToMhloBase { public: + ConvertTorchToMhlo() = default; + ConvertTorchToMhlo(bool enableStaticShape, bool enableI32Index) { + this->enableStaticShape = enableStaticShape; + this->enableI32Index = enableI32Index; + } + void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); registry.insert(); @@ -51,18 +57,20 @@ class ConvertTorchToMhlo : public ConvertTorchToMhloBase { RewritePatternSet patterns(context); + torch_to_mhlo::TorchToMhloOptions options{enableStaticShape, + enableI32Index ? 
32u : 64u}; torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns, - target); - torch_to_mhlo::populateViewLikeOpPatternsAndLegality(typeConverter, - patterns, target); + target, options); + torch_to_mhlo::populateViewLikeOpPatternsAndLegality( + typeConverter, patterns, target, options); torch_to_mhlo::populateGatherOpPatternsAndLegality(typeConverter, patterns, - target); - torch_to_mhlo::populateReductionOpPatternsAndLegality(typeConverter, - patterns, target); + target, options); + torch_to_mhlo::populateReductionOpPatternsAndLegality( + typeConverter, patterns, target, options); torch_to_mhlo::populateLinearOpPatternsAndLegality(typeConverter, patterns, - target); + target, options); torch_to_mhlo::populatePoolingOpPatternsAndLegality(typeConverter, patterns, - target); + target, options); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { @@ -75,5 +83,12 @@ class ConvertTorchToMhlo : public ConvertTorchToMhloBase { std::unique_ptr> mlir::torch::createConvertTorchToMhloPass() { - return std::make_unique(); + return std::make_unique(false, false); +} + +std::unique_ptr> +mlir::torch::createConvertTorchToMhloPass(bool enableStaticShape, + bool enableI32Index) { + return std::make_unique(enableStaticShape, + enableI32Index); } diff --git a/lib/Conversion/TorchToMhlo/ViewLike.cpp b/lib/Conversion/TorchToMhlo/ViewLike.cpp index b6cf840be699..e89f3f3895ce 100644 --- a/lib/Conversion/TorchToMhlo/ViewLike.cpp +++ b/lib/Conversion/TorchToMhlo/ViewLike.cpp @@ -28,6 +28,7 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; using namespace mlir::torch::TorchConversion; +using namespace mlir::torch::torch_to_mhlo; namespace { // A dimension index from torch.dialect might outside the range [0, dimSize]. 
@@ -55,10 +56,11 @@ Value getNormalizedDimSizeInternal(PatternRewriter &rewriter, Operation *op, Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op, Type outTy, Value input, Value startIndex, Value endIndex, Value step, size_t dimIndex, - ArrayRef dimSizes) { + ArrayRef dimSizes, + size_t dimSizeIndexBits) { auto loc = op->getLoc(); // startIndex & endIndex has been normailized into range [0, dSize] - Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits); + Type intType = rewriter.getIntegerType(dimSizeIndexBits); Value zero = rewriter.create( loc, rewriter.getIntegerAttr(intType, 0)); Value one = rewriter.create( @@ -109,7 +111,8 @@ FailureOr getDynamicSlice(PatternRewriter &rewriter, Operation *op, Type outTy, Value input, llvm::Optional startIndexOpt, llvm::Optional endIndexOpt, - llvm::Optional stepOpt, int64_t dim) { + llvm::Optional stepOpt, int64_t dim, + size_t dimSizeIndexBits) { auto loc = op->getLoc(); auto inputTy = input.getType().dyn_cast(); auto rank = inputTy.getRank(); @@ -133,77 +136,31 @@ FailureOr getDynamicSlice(PatternRewriter &rewriter, Operation *op, : rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1)); -#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32 - auto i32Type = rewriter.getIntegerType(mhlo::kMhloDimSizeBits); - normStartIndex = - rewriter.create(loc, i32Type, normStartIndex); - normEndIndex = rewriter.create(loc, i32Type, normEndIndex); - step = rewriter.create(loc, i32Type, step); -#endif + if (dimSizeIndexBits == 32) { + Type intType = rewriter.getIntegerType(dimSizeIndexBits); + normStartIndex = + rewriter.create(loc, intType, normStartIndex); + normEndIndex = rewriter.create(loc, intType, normEndIndex); + step = rewriter.create(loc, intType, step); + } FailureOr> dimSizesInfo = - mhlo::getDimSizesOfTensor(rewriter, op, input); + mhlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits); if (failed(dimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); auto dimSizes = *dimSizesInfo; return getDynamicSliceInternal(rewriter, op, outTy, input, normStartIndex, - normEndIndex, step, dim, dimSizes); -} - -template -class ConvertAtenOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - using OpAdaptor = typename AtenOpT::Adaptor; - LogicalResult - matchAndRewrite(AtenOpT op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenSliceTensorOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto self = adaptor.self(); - auto selfTy = self.getType().template cast(); - if (!selfTy) - return op.emitError("only ranked tensor types are supported"); - auto outTy = - getTypeConverter()->convertType(op.getType()).cast(); - int64_t dim; - if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) - return rewriter.notifyMatchFailure( - op, "only constant dim is currently supported"); - - auto getOptionalVal = [&](Value val) -> llvm::Optional { - if (val.getType().isa()) { - return llvm::None; - } else { - return val; - } - }; - - llvm::Optional start = getOptionalVal(adaptor.start()); - llvm::Optional end = getOptionalVal(adaptor.end()); - llvm::Optional step = getOptionalVal(adaptor.step()); - - FailureOr sliceInfo = - getDynamicSlice(rewriter, op, outTy, self, start, end, step, dim); - if (failed(sliceInfo)) - return op.emitError("can not create a dynmaic slice"); - - auto slice = *sliceInfo; - 
rewriter.replaceOp(op, slice); - return success(); + normEndIndex, step, dim, dimSizes, + dimSizeIndexBits); } // This defines a template to construct ops whose legalizations are // specialized. template -class ConvertAtenViewOp : public OpConversionPattern { +class ConvertAtenViewOp : public ConvertAtenOp { public: - using OpConversionPattern::OpConversionPattern; + using ConvertAtenOp::ConvertAtenOp; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult @@ -235,19 +192,19 @@ class ConvertAtenViewOp : public OpConversionPattern { return dSize; }); -#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32 - // The i64 calculation is much slower than i32 on some devices, such as - // Nvidia GPU. One can truncate from i64 to i32 since dimension sizes are - // unlikely to exceed the range of i32(4GiB) - std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) { - // dimSize: cast i64 -> i32 - dSize = - rewriter.create(loc, rewriter.getI32Type(), dSize); - return dSize; - }); -#endif + const auto &options = ConvertAtenOp::getOptions(); + Type intType = rewriter.getIntegerType(options.dimSizeIndexBits); + if (options.dimSizeIndexBits == 32) { + // The i64 calculation is much slower than i32 on some devices, such as + // Nvidia GPU. One can truncate from i64 to i32 since dimension sizes are + // unlikely to exceed the range of i32(4GiB) + std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) { + // dimSize: cast i64 -> i32 + dSize = rewriter.create(loc, intType, dSize); + return dSize; + }); + } - Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits); Value numel = rewriter.create( loc, rewriter.getIntegerAttr(intType, 1)); for (auto d : dimSizes) { @@ -293,6 +250,45 @@ bool ConvertAtenViewOp::getAtenViewOpSizes( SmallVector &dimSizes) const { return getListConstructElements(adaptor.shape(), dimSizes); } +} // namespace + +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenSliceTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.self(); + auto selfTy = self.getType().template cast(); + if (!selfTy) + return op.emitError("only ranked tensor types are supported"); + auto outTy = + getTypeConverter()->convertType(op.getType()).cast(); + int64_t dim; + if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure( + op, "only constant dim is currently supported"); + + auto getOptionalVal = [&](Value val) -> llvm::Optional { + if (val.getType().isa()) { + return llvm::None; + } else { + return val; + } + }; + + llvm::Optional start = getOptionalVal(adaptor.start()); + llvm::Optional end = getOptionalVal(adaptor.end()); + llvm::Optional step = getOptionalVal(adaptor.step()); + + FailureOr sliceInfo = + getDynamicSlice(rewriter, op, outTy, self, start, end, step, dim, + options.dimSizeIndexBits); + if (failed(sliceInfo)) + return op.emitError("can not create a dynmaic slice"); + + auto slice = *sliceInfo; + rewriter.replaceOp(op, slice); + return success(); +} template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -324,7 +320,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } - auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims); + auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims, + options.dimSizeIndexBits); if (failed(newDimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -372,7 +369,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, 
getTypeConverter()->convertType(op.getType()), self); return success(); } - auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims); + auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims, + options.dimSizeIndexBits); if (failed(newDimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -397,8 +395,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!matchPattern(op.dim(), m_TorchConstantInt(&dim))) return op->emitError("dim must be a Scalar constant"); - auto unsqzTensorInfo = - mhlo::unsqueezeTensor(rewriter, op, adaptor.self(), {dim}); + auto unsqzTensorInfo = mhlo::unsqueezeTensor(rewriter, op, adaptor.self(), + {dim}, options.dimSizeIndexBits); if (failed(unsqzTensorInfo)) return rewriter.notifyMatchFailure(op, "failed to create unsqueezed tensor"); @@ -406,16 +404,15 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOp(op, *unsqzTensorInfo); return success(); } -} // namespace void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target) { + ConversionTarget &target, const TorchToMhloOptions &options) { MLIRContext *context = patterns.getContext(); #define INSERT_ATENOP_PATTERN(AtenOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context, options) INSERT_ATENOP_PATTERN(AtenSliceTensorOp); INSERT_ATENOP_PATTERN(AtenSqueezeOp); INSERT_ATENOP_PATTERN(AtenSqueezeDimOp); @@ -424,7 +421,7 @@ void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality( #define INSERT_VIEW_OP_PATTERN(AtenOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context, options) INSERT_VIEW_OP_PATTERN(AtenViewOp); INSERT_VIEW_OP_PATTERN(AtenReshapeOp); #undef INSERT_VIEW_OP_PATTERN