Skip to content

Commit

Permalink
[MHLO]fix lowering failed on reduction op with i32 shape (#1185)
Browse files Browse the repository at this point in the history
fixed lowering failed on torch::max.dim while shape type is i32
  • Loading branch information
Yancey1989 authored Aug 9, 2022
1 parent e55fc4d commit f83a905
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions lib/Conversion/TorchToMhlo/Reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,22 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,

Value initValue = createInitialValueForReduceOp(op, inputElemTy, rewriter);
if (!initValue) return llvm::None;

Value initIndex =
mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).value();
Value initIndex;
if (mlir::mhlo::kMhloDimSizeBits == 32) {
initIndex =
mhlo::getConstTensor<int32_t>(rewriter, op, {0}, {}).getValue();
} else {
initIndex =
mhlo::getConstTensor<int64_t>(rewriter, op, {0}, {}).getValue();
}

DenseIntElementsAttr dimensions = DenseIntElementsAttr::get(
RankedTensorType::get({}, rewriter.getI64Type()), dim);

auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), inputShapeVec);
auto indexTensor = rewriter.create<mhlo::DynamicIotaOp>(
op->getLoc(), RankedTensorType::get(inputShape, rewriter.getI64Type()),
op->getLoc(), RankedTensorType::get(inputShape, rewriter.getIntegerType(mlir::mhlo::kMhloDimSizeBits)),

This comment has been minimized.

Copy link
@pashu123

pashu123 Aug 9, 2022

Member

Please clang-format!

inputShapeTensor, static_cast<uint64_t>(dim));

auto mhloReduceOp = rewriter.create<mhlo::ReduceOp>(
Expand All @@ -111,7 +116,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
// Add block arguments
auto blockValArgumentType =
RankedTensorType::get({}, inputTy.getElementType());
auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getI64Type());
auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getIntegerType(mlir::mhlo::kMhloDimSizeBits));
auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type());
block.addArgument(blockValArgumentType, op->getLoc());
block.addArgument(blockIdxArgumentType, op->getLoc());
Expand Down Expand Up @@ -231,7 +236,7 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(

outShapeVec[dim] = rewriter.create<mlir::arith::ConstantOp>(
op->getLoc(), rewriter.getIntegerAttr(
rewriter.getIntegerType(mhlo::kMhloDimSizeBits), 1));
rewriter.getIntegerType(mlir::mhlo::kMhloDimSizeBits), 1));

auto outShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), outShapeVec);
Expand Down

0 comments on commit f83a905

Please sign in to comment.