From fbf1e64899dc4e174cc7aeb10b7793c6016fbf37 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 9 Aug 2022 10:55:02 +0800 Subject: [PATCH 1/2] fix lowering failed on i32 shape of reduction op --- lib/Conversion/TorchToMhlo/Reduction.cpp | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/lib/Conversion/TorchToMhlo/Reduction.cpp b/lib/Conversion/TorchToMhlo/Reduction.cpp index 7f8ab24a38f..eb03740116f 100644 --- a/lib/Conversion/TorchToMhlo/Reduction.cpp +++ b/lib/Conversion/TorchToMhlo/Reduction.cpp @@ -85,9 +85,14 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, Value initValue = createInitialValueForReduceOp(op, inputElemTy, rewriter); if (!initValue) return llvm::None; - - Value initIndex = - mhlo::getConstTensor(rewriter, op, {0}, {}).value(); + Value initIndex; + if (mlir::mhlo::kMhloDimSizeBits == 32) { + initIndex = + mhlo::getConstTensor(rewriter, op, {0}, {}).getValue(); + } else { + initIndex = + mhlo::getConstTensor(rewriter, op, {0}, {}).getValue(); + } DenseIntElementsAttr dimensions = DenseIntElementsAttr::get( RankedTensorType::get({}, rewriter.getI64Type()), dim); @@ -95,7 +100,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, auto inputShapeTensor = rewriter.create( op->getLoc(), inputShapeVec); auto indexTensor = rewriter.create( - op->getLoc(), RankedTensorType::get(inputShape, rewriter.getI64Type()), + op->getLoc(), RankedTensorType::get(inputShape, rewriter.getIntegerType(mlir::mhlo::kMhloDimSizeBits)), inputShapeTensor, static_cast(dim)); auto mhloReduceOp = rewriter.create( @@ -111,7 +116,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, // Add block arguments auto blockValArgumentType = RankedTensorType::get({}, inputTy.getElementType()); - auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getI64Type()); + auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getIntegerType(mlir::mhlo::kMhloDimSizeBits)); auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type()); block.addArgument(blockValArgumentType, op->getLoc()); block.addArgument(blockIdxArgumentType, op->getLoc()); @@ -231,7 +236,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( outShapeVec[dim] = rewriter.create( op->getLoc(), rewriter.getIntegerAttr( - rewriter.getIntegerType(mhlo::kMhloDimSizeBits), 1)); + rewriter.getIntegerType(mlir::mhlo::kMhloDimSizeBits), 1)); auto outShapeTensor = rewriter.create( op->getLoc(), outShapeVec); From b9deabf6e4e2a578b390388c0ac668340a7d6df4 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 9 Aug 2022 14:00:39 +0800 Subject: [PATCH 2/2] update --- lib/Conversion/TorchToMhlo/Reduction.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Conversion/TorchToMhlo/Reduction.cpp b/lib/Conversion/TorchToMhlo/Reduction.cpp index eb03740116f..7ae2c58be5e 100644 --- a/lib/Conversion/TorchToMhlo/Reduction.cpp +++ b/lib/Conversion/TorchToMhlo/Reduction.cpp @@ -88,7 +88,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, Value initIndex; if (mlir::mhlo::kMhloDimSizeBits == 32) { initIndex = - mhlo::getConstTensor(rewriter, op, {0}, {}).getValue(); + mhlo::getConstTensor(rewriter, op, {0}, {}).getValue(); } else { initIndex = mhlo::getConstTensor(rewriter, op, {0}, {}).getValue();