Skip to content

Commit

Permalink
[MHLO] Support I32 as shape tensor dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
wujiawei.jw committed Aug 2, 2022
1 parent 060bc45 commit 3751125
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions lib/Conversion/TorchToMhlo/BasicOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,15 @@ class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
bcastShapeVec.push_back(newD);
}

#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
for (auto &dsize : bcastShapeVec) {
auto dsizeI64 = rewriter.create<mlir::arith::IndexCastOp>(
op->getLoc(), rewriter.getI64Type(), dsize);
dsize = rewriter.create<arith::TruncIOp>(op->getLoc(),
rewriter.getI32Type(), dsizeI64);
}
#endif

Value bcastShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), ValueRange{bcastShapeVec});
auto dimensionNumbers =
Expand Down Expand Up @@ -664,6 +673,14 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
auto inputElemTy = inputTy.getElementType().cast<mlir::FloatType>();

Value channelDim = rewriter.create<tensor::DimOp>(op->getLoc(), input, 1);

#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
auto channelDimI64 = rewriter.create<mlir::arith::IndexCastOp>(
op->getLoc(), rewriter.getI64Type(), channelDim);
channelDim = rewriter.create<arith::TruncIOp>(
op->getLoc(), rewriter.getI32Type(), channelDimI64);
#endif

Value channelShape = rewriter.create<tensor::FromElementsOp>(
op->getLoc(), ValueRange{channelDim});
if (failed(checkNotNone(rewriter, op, weight))) {
Expand Down

0 comments on commit 3751125

Please sign in to comment.