From 060bc45ca1e288c92cf44c141a82656553159220 Mon Sep 17 00:00:00 2001 From: "wujiawei.jw" Date: Sun, 31 Jul 2022 15:39:09 +0800 Subject: [PATCH 1/4] [MHLO] Support for dynamic shape in basic op conversion by introducing CHLO dialect Co-authored-by: Bairen Yi Co-authored-by: Jiawei Wu Co-authored-by: Tianyou Guo Co-authored-by: Xu Yan Co-authored-by: Ziheng Jiang --- lib/Conversion/TorchToMhlo/BasicOp.cpp | 469 ++++----- lib/Conversion/TorchToMhlo/CMakeLists.txt | 2 + .../TorchToMhlo/MhloLegalizeUtils.cpp | 84 +- .../TorchToMhlo/MhloLegalizeUtils.h | 8 +- lib/Conversion/TorchToMhlo/TorchToMhlo.cpp | 6 +- .../TorchConversion/Transforms/CMakeLists.txt | 30 +- .../TorchConversion/Transforms/Passes.cpp | 13 +- test/Conversion/TorchToMhlo/basic.mlir | 983 +++++++++--------- 8 files changed, 751 insertions(+), 844 deletions(-) diff --git a/lib/Conversion/TorchToMhlo/BasicOp.cpp b/lib/Conversion/TorchToMhlo/BasicOp.cpp index 4abc401e7aa9..8860192bf75a 100644 --- a/lib/Conversion/TorchToMhlo/BasicOp.cpp +++ b/lib/Conversion/TorchToMhlo/BasicOp.cpp @@ -12,7 +12,10 @@ #include "../PassDetail.h" #include "./MhloLegalizeUtils.h" #include "./PopulatePatterns.h" +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -51,7 +54,7 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern { auto selfTy = self.getType().cast(); if (!selfTy) - return op.emitError("Only Tensor types supported in MHLO"); + return op.emitError("only Tensor types supported in MHLO"); if (selfTy.getElementType().isa()) { rewriter.replaceOpWithNewOp( @@ -62,7 +65,7 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern { return success(); } else { return op.emitError( - "Only floating-point datatype legalization supported"); + "only floating-point datatype legalization supported"); } } }; @@ -85,29 +88,29 @@ class ConvertAtenConstPatternOp : public OpConversionPattern { .template dyn_cast(); if (!outType) - return op.emitError("Only Tensor types supported in MHLO"); + return op.emitError("only Tensor types supported in MHLO"); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) return op.emitError( - "Only floating-point or integer datatype legalization supported"); + "only floating-point or integer datatype legalization supported"); // FIXME: Handle layout, device and pin_memory. Assume dtype has been // processed to set output type correctly? if (!op.layout().getType().template isa()) - return op.emitError("Only default layout is supported"); + return op.emitError("only default layout is supported"); bool pinMemory; if (!op.pin_memory().getType().template isa() && (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)) || pinMemory)) { return op.emitError( - "Unsupported pin_memory, should be either None or false"); + "unsupported pin_memory, should be either None or false"); } SmallVector shape; if (!matchPattern(op.size(), m_TorchConstantIntList(shape))) { - return op.emitError("Shape must be a list of Scalar constants"); + return op.emitError("shape must be a list of Scalar constants"); } int64_t size = 1; @@ -128,7 +131,7 @@ class ConvertAtenConstPatternOp : public OpConversionPattern { // These binary op legalizations are specific to add/sub which have an // alpha multiplier. 
namespace { -template +template class ConvertAtenAddSubOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -137,12 +140,12 @@ class ConvertAtenAddSubOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.self(); - TensorType lhsType = lhs.getType().dyn_cast(); + RankedTensorType lhsType = lhs.getType().dyn_cast(); Value rhs = adaptor.other(); - TensorType rhsType = rhs.getType().dyn_cast(); + RankedTensorType rhsType = rhs.getType().dyn_cast(); if (!lhsType) - return op.emitError("Only Tensor types supported in MHLO"); + return op.emitError("only Tensor types supported in MHLO"); TensorType outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) @@ -151,53 +154,44 @@ class ConvertAtenAddSubOp : public OpConversionPattern { Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) { return op.emitError( - "Only floating-point or integer datatype legalization supported"); + "only floating-point or integer datatype legalization supported"); } - Value rhsAsTensor; if (!rhsType) { - if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), - rhsAsTensor, outElemTy, - outType.getShape()))) - return op.emitError("Currently only scalar constants are supported for " + if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs, + outElemTy, {}))) + return op.emitError("currently only scalar constants are supported for " "conversion in MHLO operation"); } - Value lhsTensor = lhs; - Value rhsTensor = rhsType ? rhs : rhsAsTensor; - - // Handle broadcasting. Since we have the output type already, here we - // just broodcast operands' shape to output shape. - lhsTensor = mhlo::promoteAndBroadcast(rewriter, lhsTensor, outType); - rhsTensor = mhlo::promoteAndBroadcast(rewriter, rhsTensor, outType); - - // Handle alpha. - Value multTensor; - if (skipMultiplyAlpha(op.alpha())) { - multTensor = rhsTensor; - } else { - Value alphaTensor; + + lhs = mhlo::promoteType(rewriter, lhs, outType); + rhs = mhlo::promoteType(rewriter, rhs, outType); + + if (!skipMultiplyAlpha(op.alpha())) { + Value alpha; if (failed(mhlo::torchAlphaToMhloTensor(rewriter, op.getOperation(), - op.alpha(), alphaTensor, - outElemTy, outType.getShape(), + op.alpha(), alpha, outElemTy, {}, /*checkForUnity=*/false))) { - return op.emitError("Currently only scalar constants are supported for " + return op.emitError("currently only scalar constants are supported for " "alpha in conversion to MHLO operation"); } - - multTensor = rewriter.create(op.getLoc(), outType, rhsTensor, - alphaTensor); + DenseIntElementsAttr bcastDimensions; + rhs = rewriter.create(op->getLoc(), rhs, alpha, + bcastDimensions); } - rewriter.replaceOpWithNewOp(op, outType, lhsTensor, multTensor); + DenseIntElementsAttr bcastDimensions; + rewriter.replaceOpWithNewOp(op, outType, lhs, rhs, + bcastDimensions); return success(); } }; } // namespace -// Binary op legalizations for Mul variants. +// Binary op legalizations for Mul/Div variants. 
namespace { -template -class ConvertAtenMulOp : public OpConversionPattern { +template +class ConvertAtenMulDivOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; @@ -210,7 +204,7 @@ class ConvertAtenMulOp : public OpConversionPattern { TensorType rhsType = rhs.getType().dyn_cast(); if (!lhsType) - return op.emitError("Only Tensor types supported in MHLO"); + return op.emitError("only Tensor types supported in MHLO"); auto outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) @@ -219,86 +213,23 @@ class ConvertAtenMulOp : public OpConversionPattern { Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) { return op.emitError( - "Only floating-point or integer datatype legalization supported"); - } - - Value lhsTensor = lhs; - Value rhsTensor; - if (std::is_same()) { - rhsTensor = lhs; - } else { - if (!rhsType) { - if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), - rhsTensor, outElemTy, - outType.getShape()))) - return op.emitError( - "Currently only scalar constants are supported for " - "conversion in MHLO operation"); - } else { - rhsTensor = rhs; - } - } - - // Handle broadcasting. Since we have the output type already, here we - // just broodcast operands' shape to output shape. - lhsTensor = mhlo::promoteAndBroadcast(rewriter, lhsTensor, outType); - rhsTensor = mhlo::promoteAndBroadcast(rewriter, rhsTensor, outType); - - rewriter.replaceOpWithNewOp(op, outType, lhsTensor, rhsTensor); - return success(); - } -}; -} // namespace - -// Binary op legalizations for Div variants. -namespace { -template -class ConvertAtenDivOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - using OpAdaptor = typename AtenOpT::Adaptor; - LogicalResult - matchAndRewrite(AtenOpT op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value lhs = adaptor.self(); - auto lhsTy = lhs.getType().dyn_cast(); - Value rhs = adaptor.other(); - auto rhsTy = rhs.getType().dyn_cast(); - - if (!lhsTy) - return op.emitError("Only Tensor types supported."); - - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); - Type outElemTy = outType.getElementType(); - if (!outElemTy.isIntOrFloat()) { - return op.emitError( - "Only floating-point or integer datatype legalization supported"); + "only floating-point or integer datatype legalization supported"); } - - Value lhsTensor = lhs; - Value rhsTensor; - if (!rhsTy) { - if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), - rhsTensor, outElemTy, - outType.getShape()))) - return op.emitError("Currently only scalar constants are supported for " + if (!rhsType) { + if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs, + outElemTy, {}))) + return op.emitError("currently only scalar constants are supported for " "conversion in MHLO operation"); - } else { - rhsTensor = rhs; } - // Handle broadcasting. Since we have the output type already, here we - // just broodcast operands' shape to output shape. 
- lhsTensor = mhlo::promoteAndBroadcast(rewriter, lhsTensor, outType); - rhsTensor = mhlo::promoteAndBroadcast(rewriter, rhsTensor, outType); - - rewriter.replaceOpWithNewOp(op, outType, lhsTensor, rhsTensor); + DenseIntElementsAttr bcastDimensions; + lhs = mhlo::promoteType(rewriter, lhs, outType); + rhs = mhlo::promoteType(rewriter, rhs, outType); + rewriter.replaceOpWithNewOp(op, outType, lhs, rhs, + bcastDimensions); return success(); } }; - } // namespace // Binary op legalizations for comparator ops. @@ -318,7 +249,7 @@ class ConvertAtenCompareOp : public OpConversionPattern { RankedTensorType rhsTy = rhs.getType().dyn_cast(); if (!lhsTy) - return op.emitError("Only Tensor types supported in MHLO"); + return op.emitError("only Tensor types supported in MHLO"); RankedTensorType outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) @@ -327,21 +258,19 @@ class ConvertAtenCompareOp : public OpConversionPattern { Type lhsElemTy = lhsTy.getElementType(); if (!lhsElemTy.isIntOrFloat()) { return op.emitError( - "Only floating-point or integer datatype legalization supported"); + "only floating-point or integer datatype legalization supported"); } - Value rhsAsTensor; if (!rhsTy) { - if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), - rhsAsTensor, lhsElemTy, {}))) { - return op.emitError("Currently only scalar constants are supported for " + if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.other(), rhs, + lhsElemTy, {}))) { + return op.emitError("currently only scalar constants are supported for " "conversion in MHLO operation"); } } - Value lhsTensor = lhs; - Value rhsTensor = rhsTy ? rhs : rhsAsTensor; - rhsTensor = mhlo::promoteAndBroadcast(rewriter, rhsTensor, lhsTy); + // TODO: what is the PyTorch default type promotion? 
+ rhs = mhlo::promoteType(rewriter, rhs, lhsTy); mhlo::ComparisonTypeAttr compareTypeAttr; mhlo::ComparisonDirectionAttr compareDirectionAttr; @@ -371,9 +300,9 @@ class ConvertAtenCompareOp : public OpConversionPattern { compareDirectionAttr = mhlo::ComparisonDirectionAttr::get( op->getContext(), mhlo::ComparisonDirection::NE); } - - rewriter.replaceOpWithNewOp( - op, outType, lhsTensor, rhsTensor, compareDirectionAttr, + DenseIntElementsAttr bcastDimensions; + rewriter.replaceOpWithNewOp( + op, outType, lhs, rhs, bcastDimensions, compareDirectionAttr, compareTypeAttr); return success(); } @@ -438,12 +367,54 @@ class ConvertAtenBroadcastToOp : public OpConversionPattern { matchAndRewrite(AtenBroadcastToOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value self = adaptor.self(); + auto selfTy = self.getType().cast(); auto outType = getTypeConverter() ->convertType(op->getResult(0).getType()) .cast(); - Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType); - rewriter.replaceOp(op, bcastOp); +#ifdef TORCH_MLIR_ENABLE_MHLO_STATIC_SHAPE + if (selfTy.hasStaticShape()) { + Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType); + rewriter.replaceOp(op, bcastOp); + return success(); + } +#endif + + SmallVector shape; + if (!(getListConstructElements(adaptor.size(), shape))) { + return op->emitError("desired shape must be a list of scalar"); + } + SmallVector bcastShapeVec; + int64_t totalRank = shape.size(); + int64_t selfRank = selfTy.getRank(); + int64_t leadingRank = totalRank - selfRank; + + for (int64_t i = 0; i < totalRank; ++i) { + Value dValue = shape[i]; + Value newD; + int64_t dInt; + if (!(matchPattern(dValue, m_TorchConstantInt(&dInt)))) { + return op->emitError("element of desired shape must be a scalar"); + } + if (i >= leadingRank && dInt == -1) { + newD = rewriter.create(op->getLoc(), self, + i - leadingRank); + } else { + dValue = rewriter.create(op->getLoc(), + dValue); + newD = rewriter.create( + op->getLoc(), rewriter.getIndexType(), dValue); + } + bcastShapeVec.push_back(newD); + } + + Value bcastShapeTensor = rewriter.create( + op->getLoc(), ValueRange{bcastShapeVec}); + auto dimensionNumbers = + llvm::to_vector<4>(llvm::seq(leadingRank, totalRank)); + rewriter.replaceOpWithNewOp( + op, outType, self, bcastShapeTensor, + rewriter.getI64TensorAttr(dimensionNumbers)); return success(); } }; @@ -464,19 +435,19 @@ class ConvertAtenPermuteOp : public OpConversionPattern { ->convertType(op->getResult(0).getType()) .cast(); if (!inType) - return op.emitError("Only ranked tensor types with static shapes are " + return op.emitError("only ranked tensor types with static shapes are " "currently supported"); SmallVector permValues; if (!matchPattern(adaptor.dims(), m_TorchConstantIntList(permValues))) return rewriter.notifyMatchFailure( - op, "Only constant dimensions are currently supported"); + op, "only constant dimensions are currently supported"); int64_t inRank = inType.getRank(); for (auto &d : permValues) { d = toPositiveDim(d, inRank); if (!isValidDim(d, inRank)) - return op.emitError("Not all dims are valid"); + return op.emitError("not all dims are valid"); } DenseIntElementsAttr permutation = DenseIntElementsAttr::get( @@ -517,7 +488,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } else { return op.emitError( - "Only floating-point datatype legalization currently supported"); + "only floating-point datatype legalization currently supported"); } } } // namespace @@ -535,6 +506,7 @@ LogicalResult 
ConvertAtenOp::matchAndRewrite( // Tensors with integer types need to be converted to signless integer // element type. All tensors with element types other than integer can reuse // existing elements attribute. + // TODO: what about unsigned integer? if (auto elements = op.valueAttr().dyn_cast()) { Type builtinTensorElemTy = resultType.getElementType(); unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth(); @@ -566,13 +538,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto outTy = getTypeConverter()->convertType(op.getType()).cast(); if (!inputTy.getElementType().isa()) { - return op.emitError("Only floating-point datatype legalization supported " + return op.emitError("only floating-point datatype legalization supported " "for AtenReciprocalOp"); } - Value oneTensor = - mhlo::getConstTensor(rewriter, op, {static_cast(1.0)}, {}) - .getValue(); - oneTensor = mhlo::promoteAndBroadcast(rewriter, oneTensor, inputTy); + + Value oneTensor = chlo::getConstantLike(rewriter, op->getLoc(), 1, input); rewriter.replaceOpWithNewOp(op, outTy, oneTensor, input); return success(); } @@ -593,7 +563,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (failed(mhlo::torchScalarToMhloTensor(rewriter, op, op.a(), mhloTensor, outputElemType, outputShape, false))) { - return op->emitError("Failed lowering PrimNumToTensorScalarOp to MHLO"); + return op->emitError("failed lowering PrimNumToTensorScalarOp to MHLO"); } rewriter.replaceOp(op, mhloTensor); return success(); @@ -611,7 +581,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Not a tensor type. auto selfType = adaptor.self().getType().dyn_cast(); if (!selfType) - return op.emitError("Only tensor types are currently supported"); + return op.emitError("only tensor types are currently supported"); // FIXME: memory_format is not handled. 
@@ -633,27 +603,17 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto lhsTy = lhs.getType().cast(); auto lhsElemTy = lhsTy.getElementType(); - int64_t lhsSize = 1; - for (auto &en : llvm::enumerate(lhsTy.getShape())) { - lhsSize *= en.value(); - } - auto constTy = RankedTensorType::get(lhsTy.getShape(), lhsElemTy); - DenseElementsAttr constAttr; - if (lhsElemTy.isa()) { - std::vector constVec( - lhsSize, - APFloat::getZero(lhsElemTy.cast().getFloatSemantics(), - /*negative=*/false)); - constAttr = DenseElementsAttr::get(constTy, constVec); - } else if (lhsElemTy.isa()) { - std::vector constVec( - lhsSize, APInt::getZero(lhsElemTy.getIntOrFloatBitWidth())); - constAttr = DenseElementsAttr::get(constTy, constVec); + if (!lhsElemTy.isa()) { + return op->emitError("only float tensor in relu op is supported"); } - Value rhs = - rewriter.create(op.getLoc(), constTy, constAttr); - rewriter.replaceOpWithNewOp(op, lhs, rhs); + Value zeroTensor; + zeroTensor = chlo::getConstantLike( + rewriter, op->getLoc(), + APFloat::getZero(lhsElemTy.cast().getFloatSemantics(), + false), + lhs); + rewriter.replaceOpWithNewOp(op, lhs, zeroTensor); return success(); } @@ -666,72 +626,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenErfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.self(); - auto inputType = input.getType().cast(); + auto inputType = input.getType().cast(); if (!inputType.getElementType().isa()) { - return rewriter.notifyMatchFailure(op, "Only support float data type"); + return rewriter.notifyMatchFailure(op, "only float tensor is supported"); } - auto outType = - getTypeConverter()->convertType(op.getType()).cast(); - - // Using: - // https://en.wikipedia.org/wiki/Error_function#Numerical_approximations with - // maximum error as 5 x 10^-4 where a1 = 0.278393, a2 = 0.230389, a3 = - // 0.000972, a4 = 0.078108. 
- // Erf = 1 - 1 / (1 + a1X + a2X^2 + a3X^3 + a4X^4)^4 - - auto loc = op->getLoc(); - auto zeroConst = - mhlo::getConstTensor(rewriter, op, {0.0}, {}).getValue(); - auto zero = mhlo::promoteAndBroadcast(rewriter, zeroConst, outType); - auto oneConst = - mhlo::getConstTensor(rewriter, op, {1.0}, {}).getValue(); - auto one = mhlo::promoteAndBroadcast(rewriter, oneConst, outType); - auto a1Const = - mhlo::getConstTensor(rewriter, op, {0.278393}, {}).getValue(); - auto a1 = mhlo::promoteAndBroadcast(rewriter, a1Const, outType); - auto a2Const = - mhlo::getConstTensor(rewriter, op, {0.230389}, {}).getValue(); - auto a2 = mhlo::promoteAndBroadcast(rewriter, a2Const, outType); - auto a3Const = - mhlo::getConstTensor(rewriter, op, {0.000972}, {}).getValue(); - auto a3 = mhlo::promoteAndBroadcast(rewriter, a3Const, outType); - auto a4Const = - mhlo::getConstTensor(rewriter, op, {0.078108}, {}).getValue(); - auto a4 = mhlo::promoteAndBroadcast(rewriter, a4Const, outType); - - auto absX = rewriter.create(loc, outType, input); - auto a1X = rewriter.create(loc, outType, a1, absX); - auto sum = rewriter.create(loc, outType, a1X, one); - - auto x2 = rewriter.create(loc, outType, absX, absX); - auto a2X = rewriter.create(loc, outType, a2, x2); - sum = rewriter.create(loc, outType, sum, a2X); - - auto x3 = rewriter.create(loc, outType, x2, absX); - auto a3X = rewriter.create(loc, outType, a3, x3); - sum = rewriter.create(loc, outType, sum, a3X); - - auto x4 = rewriter.create(loc, outType, x3, absX); - auto a4X = rewriter.create(loc, outType, a4, x4); - sum = rewriter.create(loc, outType, sum, a4X); - - auto rcprl = rewriter.create(loc, outType, one, sum); - auto rcprl2 = rewriter.create(loc, outType, rcprl, rcprl); - auto rcprl4 = rewriter.create(loc, outType, rcprl2, rcprl2); - auto erf = rewriter.create(loc, outType, one, rcprl4); - - // Deal with negative x. 
- mhlo::ComparisonDirectionAttr compareDirectionAttr = - mhlo::ComparisonDirectionAttr::get(op->getContext(), - mhlo::ComparisonDirection::GE); - mhlo::ComparisonTypeAttr compareTypeAttr = mhlo::ComparisonTypeAttr::get( - op->getContext(), mhlo::ComparisonType::FLOAT); - auto geZero = rewriter.create( - loc, RankedTensorType::get(outType.getShape(), rewriter.getI1Type()), - input, zero, compareDirectionAttr, compareTypeAttr); - auto negaErf = rewriter.create(loc, erf); - rewriter.replaceOpWithNewOp(op, outType, geZero, erf, - negaErf); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), input); return success(); } @@ -754,59 +654,63 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value momentum = adaptor.momentum(); (void)momentum; - // init weight, bias, runningVar, runningMean if they are none - auto initNoneValue = [&](Value &input, bool zero) { - SmallVector constVec(inputTy.getShape()[1], - APFloat::getZero(inputTy.getElementType() - .cast() - .getFloatSemantics())); - if (!zero) { - for (auto &item : constVec) { - item = APFloat(inputTy.getElementType() - .cast() - .getFloatSemantics(), - 1); - } - } - auto constType = RankedTensorType::get({inputTy.getShape()[1]}, - inputTy.getElementType()); - auto constAttr = DenseElementsAttr::get(constType, constVec); - input = - rewriter.create(op.getLoc(), constType, constAttr); - }; + if (inputTy.getRank() <= 2) { + return rewriter.notifyMatchFailure(op, + "input should have rank larger than 2"); + } + if (!inputTy.getElementType().template isa()) { + return op.emitError("only input tensor of float type is supported"); + } + auto inputElemTy = inputTy.getElementType().cast(); + + Value channelDim = rewriter.create(op->getLoc(), input, 1); + Value channelShape = rewriter.create( + op->getLoc(), ValueRange{channelDim}); if (failed(checkNotNone(rewriter, op, weight))) { - initNoneValue(weight, false); + weight = mhlo::getConstantOfShape( + rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1), + channelShape, + RankedTensorType::get({inputTy.getShape()[1]}, + inputTy.getElementType())); } if (failed(checkNotNone(rewriter, op, bias))) { - initNoneValue(bias, true); + bias = mhlo::getConstantOfShape( + rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0), + channelShape, + RankedTensorType::get({inputTy.getShape()[1]}, + inputTy.getElementType())); } if (failed(checkNotNone(rewriter, op, runningVar))) { - initNoneValue(runningVar, false); + runningVar = mhlo::getConstantOfShape( + rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1), + channelShape, + RankedTensorType::get({inputTy.getShape()[1]}, + inputTy.getElementType())); } if (failed(checkNotNone(rewriter, op, runningMean))) { - initNoneValue(runningMean, true); + runningMean = mhlo::getConstantOfShape( + rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0), + channelShape, + RankedTensorType::get({inputTy.getShape()[1]}, + inputTy.getElementType())); } auto weightTy = weight.getType().cast(); auto biasTy = bias.getType().cast(); auto runningMeanTy = runningMean.getType().cast(); auto runningVarTy = runningVar.getType().cast(); - if (inputTy.getRank() <= 2) { - return rewriter.notifyMatchFailure(op, - "input should have rank larger than 2"); - } + if (weightTy.getRank() != 1 || biasTy.getRank() != 1 || runningMeanTy.getRank() != 1 || runningVarTy.getRank() != 1) { return rewriter.notifyMatchFailure( op, "expect weight, bias, running_mean and running_var to be rank 1"); } - if 
(!inputTy.getElementType().template isa() || - !weightTy.getElementType().template isa() || + if (!weightTy.getElementType().template isa() || !biasTy.getElementType().template isa() || !runningMeanTy.getElementType().template isa() || !runningVarTy.getElementType().template isa()) { - return op.emitError( - "Only float element type is supported in MHLO BatchNormOp"); + return op.emitError("only float weight/bias/runningMean/runningVar tensor " + "of float type is supported"); } double eps = 0.0; @@ -858,6 +762,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value weight = adaptor.weight(); Value bias = adaptor.bias(); + if (!inputTy.hasStaticShape()) { + return op->emitError("dynamic shaped input is not supported"); + } + SmallVector normalizedShape; if (!matchPattern(op.normalized_shape(), m_TorchConstantIntList(normalizedShape))) { @@ -866,11 +774,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } double eps = 0; if (!matchPattern(op.eps(), m_TorchConstantFloat(&eps))) { - return rewriter.notifyMatchFailure(op, "non const float eps unsupported"); + return rewriter.notifyMatchFailure(op, + "non const float eps is unsupported"); } if (failed(checkNotNone(rewriter, op, weight)) || failed(checkNotNone(rewriter, op, bias))) { - return op->emitError("Unsupported None for weight or bias"); + return op->emitError("none weight or bias is unsupported"); } auto weightTy = weight.getType().cast(); auto biasTy = bias.getType().cast(); @@ -878,13 +787,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!inputTy.getElementType().isa() || !biasTy.getElementType().isa() || !weightTy.getElementType().isa()) { - return op->emitError("For now, only float data type are supported"); + return op->emitError("currently only float data type are supported"); } int64_t normalizedShapeRank = normalizedShape.size(); if (weightTy.getRank() != normalizedShapeRank || biasTy.getRank() != normalizedShapeRank || inputRank < normalizedShapeRank || normalizedShapeRank < 1) { - return rewriter.notifyMatchFailure(op, "Input or weight or bias shape or" + return rewriter.notifyMatchFailure(op, "input or weight or bias shape or" "normalized shape not compatible"); } for (int64_t i = 1; i <= normalizedShapeRank; i++) { @@ -897,7 +806,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } } - // flatten dims to fit batch_norm operation. + // Flatten dims to fit batch_norm operation. int64_t numFeatureDimSize = 1; int64_t numEmbeddingDimSize = 1; for (int64_t i = 0; i < inputRank - normalizedShapeRank; i++) { @@ -915,14 +824,14 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto mhloBathNormOutMeanOrVarTy = RankedTensorType::get(meanOrVarMhloOutShape, inputTy.getElementType()); - // reshape input + // Reshape input auto mhloInput = rewriter.create( op->getLoc(), mhloBatchNormOutTy, input, mhlo::getConstTensor(rewriter, op, llvm::makeArrayRef(inputFlattenShape), {static_cast(inputFlattenShape.size())}) .getValue()); - // generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp. + // Generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp. 
SmallVector zeroConstVec( numFeatureDimSize, APFloat::getZero(inputTy.getElementType() .cast() @@ -946,7 +855,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( mhloBathNormOutMeanOrVarTy, mhloInput, scale, offset, rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1)); - // reshape back + // Reshape back auto outputTy = getTypeConverter()->convertType(op.getType(0)).cast(); auto outputMeanOrVarTy = @@ -1016,32 +925,28 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); #undef INSERT_CONSTANT_FILL_PATTERN -#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, MhloOp) \ +#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, ChloOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, mhlo::AddOp); - INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, mhlo::AddOp); - INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, mhlo::SubtractOp); - INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, mhlo::SubtractOp); + patterns.add>(typeConverter, context); + INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, chlo::BroadcastAddOp); + INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, chlo::BroadcastAddOp); + INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, chlo::BroadcastSubOp); + INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, chlo::BroadcastSubOp); #undef INSERT_BINARY_ADDSUB_PATTERN -#define INSERT_BINARY_MUL_PATTERN(AtenOp) \ - target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp); - INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp); -#undef INSERT_BINARY_MUL_PATTERN - -#define INSERT_BINARY_DIV_PATTERN(AtenOp) \ +#define INSERT_BINARY_MULDIV_PATTERN(AtenOp, ChloOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp); - INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp); -#undef INSERT_BINARY_DIV_PATTERN + patterns.add>(typeConverter, context); + INSERT_BINARY_MULDIV_PATTERN(AtenMulTensorOp, chlo::BroadcastMulOp); + INSERT_BINARY_MULDIV_PATTERN(AtenMulScalarOp, chlo::BroadcastMulOp); + INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp); + INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarOp, chlo::BroadcastDivOp); +#undef INSERT_BINARY_MULDIV_PATTERN #define INSERT_BINARY_COMPARE_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context); + INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp); INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp); INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp); @@ -1067,4 +972,4 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenBatchNormOp); INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp); #undef INSERT_ATENOP_PATTERN -} +} \ No newline at end of file diff --git a/lib/Conversion/TorchToMhlo/CMakeLists.txt b/lib/Conversion/TorchToMhlo/CMakeLists.txt index 7798b8d18d05..93a50a61409f 100644 --- a/lib/Conversion/TorchToMhlo/CMakeLists.txt +++ b/lib/Conversion/TorchToMhlo/CMakeLists.txt @@ -10,6 +10,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo DEPENDS MhloDialect + ChloDialect TorchMLIRConversionPassIncGen LINK_COMPONENTS @@ -19,6 +20,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo MLIRIR MLIRPass MhloDialect + ChloDialect TorchMLIRTorchDialect ) diff --git a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp index 72d31802b7bb..8fff582e77e3 100644 --- a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp +++ b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp @@ 
-117,9 +117,9 @@ llvm::Optional getConstTensor(PatternRewriter &rewriter, } template <> -llvm::Optional getConstTensor(PatternRewriter &rewriter, - Operation *op, ArrayRef vec, - ArrayRef shape) { +llvm::Optional +getConstTensor(PatternRewriter &rewriter, Operation *op, + ArrayRef vec, ArrayRef shape) { uint64_t num_total_elements = 1; for (int64_t a : shape) { num_total_elements *= a; @@ -149,7 +149,6 @@ template llvm::Optional getConstTensor(PatternRewriter &, ArrayRef vec, ArrayRef shape); - template static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt, const int64_t &intValue) { @@ -166,20 +165,15 @@ static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt, } template -Value getSplatConstTensor(ConversionPatternRewriter &rewriter, - Operation *op, - T val, - Type dtype, - llvm::ArrayRef dshape) { - auto const_type = RankedTensorType::get( - dshape, dtype); +Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op, + T val, Type dtype, llvm::ArrayRef dshape) { + auto const_type = RankedTensorType::get(dshape, dtype); auto const_attr = SplatElementsAttr::get(const_type, val); auto const_op = rewriter.create(op->getLoc(), const_type, const_attr); return const_op.getResult(); } - LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op, Value torchScalarValue, Value &mhloTensor, Type dtype, @@ -198,9 +192,8 @@ LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter, if (dtype.isa()) { if (doBroadcast) { - mhloTensor = getSplatConstTensor(rewriter, op, - (isFloat ? doubleValue : intValue), - dtype, dshape); + mhloTensor = getSplatConstTensor( + rewriter, op, (isFloat ? doubleValue : intValue), dtype, dshape); } else { mhloTensor = mhlo::getConstTensor( rewriter, op, (isFloat ? doubleValue : intValue), dshape) @@ -219,7 +212,8 @@ LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter, int32_t d = isFloat ? static_cast(doubleValue) : static_cast(intValue); if (doBroadcast) { - mhloTensor = getSplatConstTensor(rewriter, op, d, dtype, dshape); + mhloTensor = + getSplatConstTensor(rewriter, op, d, dtype, dshape); } else { mhloTensor = mhlo::getConstTensor(rewriter, op, {d}, dshape).getValue(); @@ -231,7 +225,8 @@ LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter, } int64_t d = (isFloat ? 
static_cast(doubleValue) : intValue); if (doBroadcast) { - mhloTensor = getSplatConstTensor(rewriter, op, d, dtype, dshape); + mhloTensor = + getSplatConstTensor(rewriter, op, d, dtype, dshape); } else { mhloTensor = mhlo::getConstTensor(rewriter, op, {d}, dshape).getValue(); @@ -243,7 +238,6 @@ LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter, return success(); } - LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op, Value alphaScalar, Value &alphaTensor, Type dtype, @@ -268,20 +262,33 @@ LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter, return success(); } +Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) { + Operation *op = input.getDefiningOp(); + TensorType in_type = input.getType().dyn_cast(); -Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, - Value input, TensorType outType) { + if (in_type.getElementType() != outType.getElementType()) { + TensorType promotedType = + in_type.cloneWith(in_type.getShape(), outType.getElementType()); + return rewriter.create(op->getLoc(), promotedType, input); + } + return input; +} + +Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input, + TensorType outType) { // Two tensors are “broadcastable” if the following rules hold: // - Each tensor has at least one dimension. - // - When iterating over the dimension sizes, starting at the trailing dimension, - // the dimension sizes must either be equal, one of them is 1, or one of them - // does not exist. - Operation* op = input.getDefiningOp(); + // - When iterating over the dimension sizes, starting at the trailing + // dimension, the dimension sizes must either be equal, one of them is 1, or + // one of them does not exist. 
+  Operation *op = input.getDefiningOp();
   TensorType in_type = input.getType().dyn_cast();
 
   if (in_type.getElementType() != outType.getElementType()) {
-    TensorType promoted_type = in_type.cloneWith(in_type.getShape(), outType.getElementType());
-    input = rewriter.create(op->getLoc(), promoted_type, input);
+    TensorType promoted_type =
+        in_type.cloneWith(in_type.getShape(), outType.getElementType());
+    input =
+        rewriter.create(op->getLoc(), promoted_type, input);
   }
 
   ArrayRef inShape = in_type.getShape();
@@ -301,7 +308,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter,
       bcastDims.push_back(outPos);
       do_bcast = true;
     } else {
-      op->emitError("The size of tensor a (") << inDim << ")"
+      op->emitError("The size of tensor a (")
+          << inDim << ")"
           << "must match the size of tensor b (" << outDim << ")"
           << "at non-singleton dimension " << inPos;
     }
@@ -311,13 +319,15 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter,
     return input;
   }
   DenseIntElementsAttr bcast_attr = DenseIntElementsAttr::get(
-      RankedTensorType::get({static_cast(bcastDims.size())}, rewriter.getI64Type()),
+      RankedTensorType::get({static_cast(bcastDims.size())},
+                            rewriter.getI64Type()),
       bcastDims);
-  auto bcast_op =
-      rewriter.create(op->getLoc(), outType, input, bcast_attr);
+  auto bcast_op = rewriter.create(op->getLoc(), outType,
+                                                           input, bcast_attr);
   return bcast_op.getResult();
 }
 
 SmallVector toPositiveDims(ArrayRef dims, int64_t rank) {
   SmallVector posDims;
   posDims.reserve(rank);
@@ -418,5 +428,17 @@ FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
       .getResult();
 }
 
+Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
+                         const APFloat &constant, Value shape,
+                         TensorType outType) {
+  auto constAttr = rewriter.getFloatAttr(outType.getElementType(), constant);
+  auto constTensor = rewriter.create(loc, constAttr);
+  return rewriter
+      .create(loc, outType, constTensor, shape,
+                                          rewriter.getI64TensorAttr({}))
+      .getResult();
+}
 } // namespace mhlo
-} // namespace mlir
+} // namespace mlir
\ No newline at end of file
diff --git a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h
index acc1a244cd2c..850c7b75a08e 100644
--- a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h
+++ b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h
@@ -59,6 +59,8 @@ LogicalResult torchAlphaToMhloTensor(ConversionPatternRewriter &rewriter,
                                      llvm::ArrayRef dshape,
                                      bool checkForUnity);
 
+Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);
+
 Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
                           TensorType outType);
 
@@ -78,7 +80,11 @@ FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
                                  Value tensor,
                                  ArrayRef inputUnsqzDims);
 
+
+Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
+                         const APFloat &constant, Value shape,
+                         TensorType outType);
 } // namespace mhlo
 } // namespace mlir
 
-#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
+#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H
\ No newline at end of file
diff --git a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp
index 890b426909c2..1907a10c54b2 100644
--- a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp
+++ b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp
@@ -12,6 +12,7 @@
 #include "../PassDetail.h"
 #include "./PopulatePatterns.h"
 #include 
"mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Traits.h" @@ -32,6 +33,7 @@ namespace { class ConvertTorchToMhlo : public ConvertTorchToMhloBase { public: void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); registry.insert(); registry.insert(); registry.insert(); @@ -40,7 +42,7 @@ class ConvertTorchToMhlo : public ConvertTorchToMhloBase { void runOnOperation() override { MLIRContext *context = &getContext(); ConversionTarget target(*context); - target.addLegalDialect(); TypeConverter typeConverter; @@ -68,4 +70,4 @@ class ConvertTorchToMhlo : public ConvertTorchToMhloBase { std::unique_ptr> mlir::torch::createConvertTorchToMhloPass() { return std::make_unique(); -} +} \ No newline at end of file diff --git a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt index 77f58e53cd4f..e90caff5905e 100644 --- a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt +++ b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt @@ -1,3 +1,19 @@ +set(LinkedLibs MLIRIR + MLIRPass + MLIRFuncTransforms + TorchMLIRTorchConversionDialect + TorchMLIRTorchDialect + TorchMLIRTorchPasses + TorchMLIRTorchToLinalg + TorchMLIRTorchToTMTensor + TorchMLIRTorchToStd + TorchMLIRTorchToSCF + MLIRMemRefTransforms) + +if(TORCH_MLIR_ENABLE_MHLO) + list(APPEND LinkedLibs ChloPasses) +endif() + add_mlir_library(TorchMLIRTorchConversionPasses BackendTypeConversion.cpp BackendTypeConversionPasses.cpp @@ -17,15 +33,5 @@ add_mlir_library(TorchMLIRTorchConversionPasses Core LINK_LIBS PUBLIC - MLIRIR - MLIRPass - MLIRFuncTransforms - TorchMLIRTorchConversionDialect - TorchMLIRTorchDialect - TorchMLIRTorchPasses - TorchMLIRTorchToLinalg - TorchMLIRTorchToTMTensor - TorchMLIRTorchToStd - TorchMLIRTorchToSCF - MLIRMemRefTransforms -) + ${LinkedLibs} +) \ No newline at end of file diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 2b439a2a5672..69bfdb69d48a 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -21,6 +21,7 @@ #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" #ifdef TORCH_MLIR_ENABLE_MHLO +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #endif #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -145,10 +146,20 @@ void TorchConversion::createTorchBackendToMhloBackendPipeline( // The resolution of `dim` ops tends to create identical ops. CSE them. pm.addNestedPass(createCSEPass()); } + + // Convert CHLO ops to MHLO ops + pm.addNestedPass(mhlo::createChloLegalizeToHloPass()); + if (options.optimize) { + // Clean up any non-canonical code introduced above.. + pm.addNestedPass(createCanonicalizerPass()); + // The resolution of `dim` ops tends to create identical ops. CSE them. + pm.addNestedPass(createCSEPass()); + } + // Finish the type conversion from `torch` types to the types of the // MHLO backend contract. 
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); pm.addNestedPass( TorchConversion::createFinalizingBackendTypeConversionPass()); } -#endif +#endif \ No newline at end of file diff --git a/test/Conversion/TorchToMhlo/basic.mlir b/test/Conversion/TorchToMhlo/basic.mlir index 045e694ab236..e11e40fceba8 100644 --- a/test/Conversion/TorchToMhlo/basic.mlir +++ b/test/Conversion/TorchToMhlo/basic.mlir @@ -13,573 +13,610 @@ func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte // ----- -// CHECK-LABEL: func.func @torch.aten.addscalar$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %[[VAL_2:.*]] = torch.constant.int 9 -// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_4:.*]] = mhlo.constant dense<9.000000e+00> : tensor<4x64xf32> -// CHECK: %[[VAL_5:.*]] = mhlo.add %[[VAL_1]], %[[VAL_4]] : tensor<4x64xf32> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[4,64],f32> -func.func @torch.aten.addscalar$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { - %int9 = torch.constant.int 9 - %int1 = torch.constant.int 1 - %0 = torch.aten.add.Scalar %arg0, %int9, %int1 : !torch.vtensor<[4,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,64],f32> - return %0 : !torch.vtensor<[4,64],f32> +// CHECK-LABEL: func.func @torch.aten.log$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = mhlo.log %[[VAL_1]] : tensor +// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.log %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> } // ----- -// CHECK-LABEL: func.func @torch.aten.addtensor$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %[[VAL_4:.*]] = mhlo.add %[[VAL_2]], %[[VAL_3]] : tensor<4x64xf32> -// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> -// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,64],f32> -func.func @torch.aten.addtensor$basic(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { - %int1 = torch.constant.int 1 - %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32> - return %0 : !torch.vtensor<[4,64],f32> +// CHECK-LABEL: func.func @torch.aten.exp$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: 
%[[VAL_2:.*]] = mhlo.exponential %[[VAL_1]] : tensor +// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.exp %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> } // ----- -// CHECK-LABEL: func.func @torch.aten.addtensor$promote( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],si32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],si64>) -> !torch.vtensor<[4,64],si64> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],si32> -> tensor<4x64xi32> -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],si64> -> tensor<4x64xi64> -// CHECK: %[[VAL_4:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_5:.*]] = mhlo.convert(%[[VAL_2]]) : (tensor<4x64xi32>) -> tensor<4x64xi64> -// CHECK: %[[VAL_6:.*]] = mhlo.add %[[VAL_5]], %[[VAL_3]] : tensor<4x64xi64> -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<4x64xi64> -> !torch.vtensor<[4,64],si64> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[4,64],si64> -func.func @torch.aten.addtensor$promote(%arg0: !torch.vtensor<[4,64],si32>, %arg1: !torch.vtensor<[4,64],si64>) -> !torch.vtensor<[4,64],si64> { - %int1 = torch.constant.int 1 - %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[4,64],si32>, !torch.vtensor<[4,64],si64>, !torch.int -> !torch.vtensor<[4,64],si64> - return %0 : !torch.vtensor<[4,64],si64> +// CHECK-LABEL: func.func @torch.aten.clone$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch.constant.none +// CHECK: %[[VAL_3:.*]] = "mhlo.copy"(%[[VAL_1]]) : (tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %none = torch.constant.none + %0 = torch.aten.clone %arg0, %none : !torch.vtensor<[?,?],f32>, !torch.none -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> } // ----- -// CHECK-LABEL: func.func @torch.aten.addtensor$bcast( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[64],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[64],f32> -> tensor<64xf32> -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %[[VAL_4:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_5:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<4x64xf32> -// CHECK: %[[VAL_6:.*]] = mhlo.add %[[VAL_5]], %[[VAL_3]] : tensor<4x64xf32> -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[4,64],f32> -func.func @torch.aten.addtensor$bcast(%arg0: !torch.vtensor<[64],f32>, %arg1: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { - %int1 = torch.constant.int 1 - %0 = torch.aten.add.Tensor %arg0, %arg1, 
%int1 : !torch.vtensor<[64],f32>, !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32> - return %0 : !torch.vtensor<[4,64],f32> +// CHECK-LABEL: func.func @torch.aten.neg$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = mhlo.negate %[[VAL_1]] : tensor +// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> } // ----- -// CHECK-LABEL: func.func @torch.aten.addtensor$alpha( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %[[VAL_4:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<2.000000e+00> : tensor<4x64xf32> -// CHECK: %[[VAL_6:.*]] = mhlo.multiply %[[VAL_3]], %[[VAL_5]] : tensor<4x64xf32> -// CHECK: %[[VAL_7:.*]] = mhlo.add %[[VAL_2]], %[[VAL_6]] : tensor<4x64xf32> -// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %4 : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> -// CHECK: return %[[VAL_8]] : !torch.vtensor<[4,64],f32> -func.func @torch.aten.addtensor$alpha(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { - %int2 = torch.constant.int 2 - %0 = torch.aten.add.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32> - return %0 : !torch.vtensor<[4,64],f32> +// CHECK-LABEL: func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> { +// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor -> !torch.vtensor<[],f32> +// CHECK: return %[[VAL_1]] : !torch.vtensor<[],f32> +func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> { + %0 = torch.vtensor.literal(dense<0.0> : tensor) : !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> } // ----- -// CHECK-LABEL: func.func @torch.aten.mulscalar$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %[[VAL_2:.*]] = torch.constant.int 9 -// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<9.000000e+00> : tensor<4x64xf32> -// CHECK: %[[VAL_4:.*]] = mhlo.multiply %[[VAL_1]], %[[VAL_3]] : tensor<4x64xf32> -// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> -// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,64],f32> -func.func @torch.aten.mulscalar$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { - %int9 = torch.constant.int 9 - %0 = torch.aten.mul.Scalar %arg0, %int9 : !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32> - return %0 : !torch.vtensor<[4,64],f32> +// CHECK-LABEL: func.func 
@torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> { +// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor<2xi64> +// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<2xi64> -> !torch.vtensor<[2],si64> +// CHECK: return %[[VAL_1]] : !torch.vtensor<[2],si64> +func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> { + %0 = torch.vtensor.literal(dense<1> : tensor<2xsi64>) : !torch.vtensor<[2],si64> + return %0 : !torch.vtensor<[2],si64> } + // ----- -// CHECK-LABEL: func.func @torch.aten.multensor$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %[[VAL_4:.*]] = mhlo.multiply %[[VAL_2]], %[[VAL_3]] : tensor<4x64xf32> -// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> -// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,64],f32> -func.func @torch.aten.multensor$basic(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { - %0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[4,64],f32> -> !torch.vtensor<[4,64],f32> - return %0 : !torch.vtensor<[4,64],f32> +// CHECK-LABEL: func.func @torch.aten.addscalar$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %int9 = torch.constant.int 9 +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor +// CHECK: %[[VAL_3:.*]] = chlo.broadcast_add %[[VAL_1]], %[[VAL_2]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.addscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %int9 = torch.constant.int 9 + %int1 = torch.constant.int 1 + %0 = torch.aten.add.Scalar %arg0, %int9, %int1 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> } // ----- -// CHECK-LABEL: func.func @torch.aten.multensor$bcast( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[8,4,64],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[8,1,64],f32>) -> !torch.vtensor<[8,4,64],f32> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[8,4,64],f32> -> tensor<8x4x64xf32> -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[8,1,64],f32> -> tensor<8x1x64xf32> -// CHECK: %[[VAL_4:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x1x64xf32>) -> tensor<8x4x64xf32> -// CHECK: %[[VAL_5:.*]] = mhlo.multiply %[[VAL_2]], %[[VAL_4]] : tensor<8x4x64xf32> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<8x4x64xf32> -> !torch.vtensor<[8,4,64],f32> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[8,4,64],f32> -func.func @torch.aten.multensor$bcast(%arg0: !torch.vtensor<[8,4,64],f32>, %arg1: !torch.vtensor<[8,1,64],f32>) -> !torch.vtensor<[8,4,64],f32> { - %0 = 
torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[8,4,64],f32>, !torch.vtensor<[8,1,64],f32> -> !torch.vtensor<[8,4,64],f32> - return %0 : !torch.vtensor<[8,4,64],f32> +// CHECK-LABEL: func.func @torch.aten.addscalar$alpha( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %int9 = torch.constant.int 9 +// CHECK: %int2 = torch.constant.int 2 +// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor +// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<2.000000e+00> : tensor +// CHECK: %[[VAL_4:.*]] = chlo.broadcast_multiply %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = chlo.broadcast_add %[[VAL_1]], %[[VAL_4]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.addscalar$alpha(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %int9 = torch.constant.int 9 + %int2 = torch.constant.int 2 + %0 = torch.aten.add.Scalar %arg0, %int9, %int2 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> } // ----- -// CHECK-LABEL: func.func @torch.aten.subscalar$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %[[VAL_2:.*]] = torch.constant.int 9 -// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_4:.*]] = mhlo.constant dense<9.000000e+00> : tensor<4x64xf32> -// CHECK: %[[VAL_5:.*]] = mhlo.subtract %[[VAL_1]], %[[VAL_4]] : tensor<4x64xf32> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[4,64],f32> -func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { - %int9 = torch.constant.int 9 +// CHECK-LABEL: func.func @torch.aten.addtensor$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = chlo.broadcast_add %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.addtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { %int1 = torch.constant.int 1 - %0 = torch.aten.sub.Scalar %arg0, %int9, %int1 : !torch.vtensor<[4,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,64],f32> - return %0 : !torch.vtensor<[4,64],f32> + %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> } // ----- -// CHECK-LABEL: func.func @torch.aten.subtensor$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>, -// CHECK-SAME: 
%[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %[[VAL_4:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_3]] : tensor<4x64xf32> -// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> -// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,64],f32> -func.func @torch.aten.subtensor$basic(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { +// CHECK-LABEL: func.func @torch.aten.addtensor$alpha( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %int2 = torch.constant.int 2 +// CHECK: %[[VAL_4:.*]] = mhlo.constant dense<2.000000e+00> : tensor +// CHECK: %[[VAL_5:.*]] = chlo.broadcast_multiply %[[VAL_3]], %[[VAL_4]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = chlo.broadcast_add %[[VAL_2]], %[[VAL_5]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.addtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %int2 = torch.constant.int 2 + %0 = torch.aten.add.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.addtensor$promote( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = mhlo.convert(%[[VAL_2]]) : (tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = chlo.broadcast_add %[[VAL_4]], %[[VAL_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],si64> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],si64> +func.func @torch.aten.addtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> { %int1 = torch.constant.int 1 - %0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32> - return %0 : !torch.vtensor<[4,64],f32> + %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si64>, !torch.int -> !torch.vtensor<[?,?],si64> + return %0 : !torch.vtensor<[?,?],si64> } // ----- -// CHECK-LABEL: func.func @torch.aten.subtensor$promote( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],si32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],si64>) -> !torch.vtensor<[4,64],si64> { -// CHECK: %[[VAL_2:.*]] = 
torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],si32> -> tensor<4x64xi32> -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],si64> -> tensor<4x64xi64> -// CHECK: %[[VAL_4:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_5:.*]] = mhlo.convert(%[[VAL_2]]) : (tensor<4x64xi32>) -> tensor<4x64xi64> -// CHECK: %[[VAL_6:.*]] = mhlo.subtract %[[VAL_5]], %[[VAL_3]] : tensor<4x64xi64> -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<4x64xi64> -> !torch.vtensor<[4,64],si64> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[4,64],si64> -func.func @torch.aten.subtensor$promote(%arg0: !torch.vtensor<[4,64],si32>, %arg1: !torch.vtensor<[4,64],si64>) -> !torch.vtensor<[4,64],si64> { +// CHECK-LABEL: func.func @torch.aten.subscalar$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %int9 = torch.constant.int 9 +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor +// CHECK: %[[VAL_3:.*]] = chlo.broadcast_subtract %[[VAL_1]], %[[VAL_2]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %int9 = torch.constant.int 9 %int1 = torch.constant.int 1 - %0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[4,64],si32>, !torch.vtensor<[4,64],si64>, !torch.int -> !torch.vtensor<[4,64],si64> - return %0 : !torch.vtensor<[4,64],si64> + %0 = torch.aten.sub.Scalar %arg0, %int9, %int1 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.subscalar$alpha( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %int9 = torch.constant.int 9 +// CHECK: %int2 = torch.constant.int 2 +// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor +// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<2.000000e+00> : tensor +// CHECK: %[[VAL_4:.*]] = chlo.broadcast_multiply %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = chlo.broadcast_subtract %[[VAL_1]], %[[VAL_4]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.subscalar$alpha(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %int9 = torch.constant.int 9 + %int2 = torch.constant.int 2 + %0 = torch.aten.sub.Scalar %arg0, %int9, %int2 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> } // ----- -// CHECK-LABEL: func.func @torch.aten.subtensor$bcast( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[64],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[64],f32> -> tensor<64xf32> -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> 
tensor<4x64xf32> -// CHECK: %[[VAL_4:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_5:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_2]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<4x64xf32> -// CHECK: %[[VAL_6:.*]] = mhlo.subtract %[[VAL_5]], %[[VAL_3]] : tensor<4x64xf32> -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[4,64],f32> -func.func @torch.aten.subtensor$bcast(%arg0: !torch.vtensor<[64],f32>, %arg1: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { +// CHECK-LABEL: func.func @torch.aten.subtensor$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = chlo.broadcast_subtract %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.subtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { %int1 = torch.constant.int 1 - %0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[64],f32>, !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32> - return %0 : !torch.vtensor<[4,64],f32> + %0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> } // ----- // CHECK-LABEL: func.func @torch.aten.subtensor$alpha( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %[[VAL_4:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<2.000000e+00> : tensor<4x64xf32> -// CHECK: %[[VAL_6:.*]] = mhlo.multiply %[[VAL_3]], %[[VAL_5]] : tensor<4x64xf32> -// CHECK: %[[VAL_7:.*]] = mhlo.subtract %[[VAL_2]], %[[VAL_6]] : tensor<4x64xf32> -// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %4 : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> -// CHECK: return %[[VAL_8]] : !torch.vtensor<[4,64],f32> -func.func @torch.aten.subtensor$alpha(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %int2 = torch.constant.int 2 +// CHECK: %[[VAL_4:.*]] = mhlo.constant dense<2.000000e+00> : tensor +// CHECK: %[[VAL_5:.*]] = chlo.broadcast_multiply %[[VAL_3]], %[[VAL_4]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = chlo.broadcast_subtract %[[VAL_2]], %[[VAL_5]] : (tensor, 
tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.subtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { %int2 = torch.constant.int 2 - %0 = torch.aten.sub.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32> - return %0 : !torch.vtensor<[4,64],f32> + %0 = torch.aten.sub.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> } // ----- -// CHECK-LABEL: func.func @torch.aten.divscalar$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %[[VAL_2:.*]] = torch.constant.int 9 -// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<9.000000e+00> : tensor<4x64xf32> -// CHECK: %[[VAL_4:.*]] = mhlo.divide %[[VAL_1]], %[[VAL_3]] : tensor<4x64xf32> -// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> -// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,64],f32> -func.func @torch.aten.divscalar$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { - %int9 = torch.constant.int 9 - %0 = torch.aten.div.Scalar %arg0, %int9 : !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32> - return %0 : !torch.vtensor<[4,64],f32> +// CHECK-LABEL: func.func @torch.aten.subtensor$promote( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = mhlo.convert(%[[VAL_2]]) : (tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = chlo.broadcast_subtract %[[VAL_4]], %[[VAL_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],si64> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],si64> +func.func @torch.aten.subtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> { + %int1 = torch.constant.int 1 + %0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si64>, !torch.int -> !torch.vtensor<[?,?],si64> + return %0 : !torch.vtensor<[?,?],si64> } // ----- -// CHECK-LABEL: func.func @torch.aten.divtensor$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %[[VAL_4:.*]] = mhlo.divide %[[VAL_2]], %[[VAL_3]] : tensor<4x64xf32> -// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> -// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,64],f32> -func.func 
@torch.aten.divtensor$basic(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { - %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[4,64],f32> -> !torch.vtensor<[4,64],f32> - return %0 : !torch.vtensor<[4,64],f32> +// CHECK-LABEL: func.func @torch.aten.mulscalar$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %int9 = torch.constant.int 9 +// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor +// CHECK: %[[VAL_3:.*]] = chlo.broadcast_multiply %[[VAL_1]], %[[VAL_2]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.mulscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %int9 = torch.constant.int 9 + %0 = torch.aten.mul.Scalar %arg0, %int9 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> } // ----- -// CHECK-LABEL: func.func @torch.aten.divtensor$bcast( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[8,4,64],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[8,1,64],f32>) -> !torch.vtensor<[8,4,64],f32> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[8,4,64],f32> -> tensor<8x4x64xf32> -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[8,1,64],f32> -> tensor<8x1x64xf32> -// CHECK: %[[VAL_4:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<8x1x64xf32>) -> tensor<8x4x64xf32> -// CHECK: %[[VAL_5:.*]] = mhlo.divide %[[VAL_2]], %[[VAL_4]] : tensor<8x4x64xf32> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<8x4x64xf32> -> !torch.vtensor<[8,4,64],f32> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[8,4,64],f32> -func.func @torch.aten.divtensor$bcast(%arg0: !torch.vtensor<[8,4,64],f32>, %arg1: !torch.vtensor<[8,1,64],f32>) -> !torch.vtensor<[8,4,64],f32> { - %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[8,4,64],f32>, !torch.vtensor<[8,1,64],f32> -> !torch.vtensor<[8,4,64],f32> - return %0 : !torch.vtensor<[8,4,64],f32> +// CHECK-LABEL: func.func @torch.aten.multensor$basic( +// CHECK-SAME: %[[VLA_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VLA_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VLA_2:.*]] = torch_c.to_builtin_tensor %[[VLA_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VLA_3:.*]] = torch_c.to_builtin_tensor %[[VLA_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VLA_4:.*]] = chlo.broadcast_multiply %[[VLA_2]], %[[VLA_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VLA_5:.*]] = torch_c.from_builtin_tensor %[[VLA_4]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VLA_5]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.multensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> } // ----- -// CHECK-LABEL: func.func @torch.aten.log$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// 
CHECK-LABEL: func.func @torch.aten.divscalar$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_2:.*]] = mhlo.log %[[VAL_1]] : tensor -// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32> -func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { - %0 = torch.aten.log %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> +// CHECK: %int9 = torch.constant.int 9 +// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor +// CHECK: %[[VAL_3:.*]] = chlo.broadcast_divide %[[VAL_1]], %[[VAL_2]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.divscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %int9 = torch.constant.int 9 + %0 = torch.aten.div.Scalar %arg0, %int9 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32> } // ----- -// CHECK-LABEL: func.func @torch.aten.exp$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_2:.*]] = mhlo.exponential %[[VAL_1]] : tensor -// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32> -func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { - %0 = torch.aten.exp %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> +// CHECK-LABEL: func.func @torch.aten.divtensor$basic( +// CHECK-SAME: %[[VLA_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VLA_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VLA_2:.*]] = torch_c.to_builtin_tensor %[[VLA_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VLA_3:.*]] = torch_c.to_builtin_tensor %[[VLA_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VLA_4:.*]] = chlo.broadcast_divide %[[VLA_2]], %[[VLA_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VLA_5:.*]] = torch_c.from_builtin_tensor %[[VLA_4]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VLA_5]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.divtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32> } // ----- -// CHECK-LABEL: func.func @torch.aten.clone$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %[[VAL_2:.*]] = torch.constant.none -// CHECK: %[[VAL_3:.*]] = "mhlo.copy"(%[[VAL_1]]) : (tensor<4x64xf32>) -> tensor<4x64xf32> -// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> -// CHECK: return %[[VAL_4]] : !torch.vtensor<[4,64],f32> -func.func @torch.aten.clone$basic(%arg0: 
!torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { - %none = torch.constant.none - %0 = torch.aten.clone %arg0, %none : !torch.vtensor<[4,64],f32>, !torch.none -> !torch.vtensor<[4,64],f32> - return %0 : !torch.vtensor<[4,64],f32> +// CHECK-LABEL: func.func @torch.aten.gt.scalar( +// CHECK-SAME: %arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { +// CHECK: %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %int3 = torch.constant.int 3 +// CHECK: %1 = mhlo.constant dense<3.000000e+00> : tensor +// CHECK: %2 = chlo.broadcast_compare %0, %1 {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor) -> tensor +// CHECK: %3 = torch_c.from_builtin_tensor %2 : tensor -> !torch.vtensor<[?,?],i1> +// CHECK: return %3 : !torch.vtensor<[?,?],i1> +func.func @torch.aten.gt.scalar(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { + %int3 = torch.constant.int 3 + %0 = torch.aten.gt.Scalar %arg0, %int3 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],i1> + return %0 : !torch.vtensor<[?,?],i1> } // ----- -// CHECK-LABEL: func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> { -// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor -> !torch.vtensor<[],f32> -// CHECK: return %[[VAL_1]] : !torch.vtensor<[],f32> -func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> { - %0 = torch.vtensor.literal(dense<0.0> : tensor) : !torch.vtensor<[],f32> - return %0 : !torch.vtensor<[],f32> +// CHECK-LABEL: func.func @torch.aten.gt.tensor( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> +// CHECK: %[[VAL_4:.*]] = chlo.broadcast_compare %[[VAL_2]], %[[VAL_3]] {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor<64xf32>) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],i1> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> +func.func @torch.aten.gt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> { + %0 = torch.aten.gt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1> + return %0 : !torch.vtensor<[?,?],i1> } // ----- -// CHECK-LABEL: func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> { -// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor<2xi64> -// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<2xi64> -> !torch.vtensor<[2],si64> -// CHECK: return %[[VAL_1]] : !torch.vtensor<[2],si64> -func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> { - %0 = torch.vtensor.literal(dense<1> : tensor<2xsi64>) : !torch.vtensor<[2],si64> - return %0 : !torch.vtensor<[2],si64> +// CHECK-LABEL: func.func @torch.aten.lt.tensor( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> +// 
CHECK: %[[VAL_4:.*]] = chlo.broadcast_compare %[[VAL_2]], %[[VAL_3]] {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor<64xf32>) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],i1> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> +func.func @torch.aten.lt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> { + %0 = torch.aten.lt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1> + return %0 : !torch.vtensor<[?,?],i1> } // ----- -// CHECK-LABEL: func.func @torch.aten.gt.scalar( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],i1> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %[[VAL_2:.*]] = torch.constant.int 3 -// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<3.000000e+00> : tensor -// CHECK: %[[VAL_4:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<4x64xf32> -// CHECK: %[[VAL_5:.*]] = "mhlo.compare"(%[[VAL_1]], %[[VAL_4]]) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor<4x64xf32>, tensor<4x64xf32>) -> tensor<4x64xi1> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<4x64xi1> -> !torch.vtensor<[4,64],i1> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[4,64],i1> -func.func @torch.aten.gt.scalar(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],i1> { - %int3 = torch.constant.int 3 - %0 = torch.aten.gt.Scalar %arg0, %int3 : !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],i1> - return %0 : !torch.vtensor<[4,64],i1> +// CHECK-LABEL: func.func @torch.aten.eq.tensor( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> +// CHECK: %[[VAL_4:.*]] = chlo.broadcast_compare %[[VAL_2]], %[[VAL_3]] {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor<64xf32>) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],i1> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> +func.func @torch.aten.eq.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> { + %0 = torch.aten.eq.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1> + return %0 : !torch.vtensor<[?,?],i1> } // ----- -// CHECK-LABEL: func.func @torch.aten.gt.tensor( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> -// CHECK: %[[VAL_4:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<4x64xf32> -// CHECK: %[[VAL_5:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_4]]) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor<4x64xf32>, tensor<4x64xf32>) -> tensor<4x64xi1> -// CHECK: 
%[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<4x64xi1> -> !torch.vtensor<[4,64],i1> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[4,64],i1> -func.func @torch.aten.gt.tensor(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> { - %0 = torch.aten.gt.Tensor %arg0, %arg1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[4,64],i1> - return %0 : !torch.vtensor<[4,64],i1> +// CHECK-LABEL: func.func @torch.aten.ne.tensor( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> +// CHECK: %[[VAL_4:.*]] = chlo.broadcast_compare %[[VAL_2]], %[[VAL_3]] {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor<64xf32>) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],i1> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> +func.func @torch.aten.ne.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> { + %0 = torch.aten.ne.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1> + return %0 : !torch.vtensor<[?,?],i1> } // ----- -// CHECK-LABEL: func.func @torch.aten.gt.tensor$convert( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],si32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],si32> -> tensor<4x64xi32> -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> -// CHECK: %[[VAL_4:.*]] = mhlo.convert(%[[VAL_3]]) : (tensor<64xf32>) -> tensor<64xi32> -// CHECK: %[[VAL_5:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_4]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<64xi32>) -> tensor<4x64xi32> -// CHECK: %[[VAL_6:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_5]]) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor<4x64xi32>, tensor<4x64xi32>) -> tensor<4x64xi1> -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<4x64xi1> -> !torch.vtensor<[4,64],i1> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[4,64],i1> +// CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[],si64> { +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor +// CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor -> !torch.vtensor<[],si64> +// CHECK: return %[[VAL_1]] : !torch.vtensor<[],si64> +func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> { + %int1 = torch.constant.int 1 + %0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[], si64> + return %0 : !torch.vtensor<[], si64> +} + +// ----- -func.func @torch.aten.gt.tensor$convert(%arg0: !torch.vtensor<[4,64],si32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> { - %0 = torch.aten.gt.Tensor %arg0, %arg1 : !torch.vtensor<[4,64],si32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[4,64],i1> - return %0 : !torch.vtensor<[4,64],i1> +// CHECK-LABEL: func.func @torch.aten.contiguous( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { 
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %int0 = torch.constant.int 0 +// CHECK: %[[VAL_2:.*]] = torch_c.from_builtin_tensor %[[VAL_1]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> +// CHECK: return %[[VAL_2]] : !torch.vtensor<[4,64],f32> +func.func @torch.aten.contiguous(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { + %int0 = torch.constant.int 0 + %0 = torch.aten.contiguous %arg0, %int0 : !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32> + return %0 : !torch.vtensor<[4,64],f32> } // ----- -// CHECK-LABEL: func.func @torch.aten.lt.tensor( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> -// CHECK: %[[VAL_4:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<4x64xf32> -// CHECK: %[[VAL_5:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_4]]) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor<4x64xf32>, tensor<4x64xf32>) -> tensor<4x64xi1> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<4x64xi1> -> !torch.vtensor<[4,64],i1> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[4,64],i1> -func.func @torch.aten.lt.tensor(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> { - %0 = torch.aten.lt.Tensor %arg0, %arg1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[4,64],i1> - return %0 : !torch.vtensor<[4,64],i1> +// CHECK-LABEL: func.func @torch.aten.reciprocal( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) {value = 1.000000e+00 : f32} : (tensor) -> tensor +// CHECK: %[[VAL_3:.*]] = mhlo.divide %[[VAL_2]], %[[VAL_1]] : tensor +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor -> !torch.vtensor<[?,?,?],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?,?],f32> +func.func @torch.aten.reciprocal(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { + %0 = torch.aten.reciprocal %arg0 : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?],f32> } // ----- -// CHECK-LABEL: func.func @torch.aten.eq.tensor( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> -// CHECK: %[[VAL_4:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<4x64xf32> -// CHECK: %[[VAL_5:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_4]]) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor<4x64xf32>, tensor<4x64xf32>) -> tensor<4x64xi1> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<4x64xi1> -> !torch.vtensor<[4,64],i1> -// 
CHECK: return %[[VAL_6]] : !torch.vtensor<[4,64],i1> -func.func @torch.aten.eq.tensor(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> { - %0 = torch.aten.eq.Tensor %arg0, %arg1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[4,64],i1> - return %0 : !torch.vtensor<[4,64],i1> +// CHECK-LABEL: func.func @torch.aten.permute$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_5:.*]] = "mhlo.transpose"(%[[VAL_1]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x64xf32>) -> tensor<64x4xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<64x4xf32> -> !torch.vtensor<[64,4],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[64,4],f32> +func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %0 = torch.prim.ListConstruct %int1, %int0 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[4,64],f32>, !torch.list -> !torch.vtensor<[64,4],f32> + return %1 : !torch.vtensor<[64,4],f32> } // ----- -// CHECK-LABEL: func.func @torch.aten.ne.tensor( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> -// CHECK: %[[VAL_4:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<64xf32>) -> tensor<4x64xf32> -// CHECK: %[[VAL_5:.*]] = "mhlo.compare"(%[[VAL_2]], %[[VAL_4]]) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor<4x64xf32>, tensor<4x64xf32>) -> tensor<4x64xi1> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<4x64xi1> -> !torch.vtensor<[4,64],i1> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[4,64],i1> -func.func @torch.aten.ne.tensor(%arg0: !torch.vtensor<[4,64],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[4,64],i1> { - %0 = torch.aten.ne.Tensor %arg0, %arg1 : !torch.vtensor<[4,64],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[4,64],i1> - return %0 : !torch.vtensor<[4,64],i1> +// CHECK-LABEL: func.func @torch.aten.transpose$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],f32> -> tensor<4x3xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = "mhlo.transpose"(%[[VAL_1]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x3xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32> +func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> { + %int0 = 
torch.constant.int 0 + %int1 = torch.constant.int 1 + %0 = torch.aten.transpose.int %arg0, %int0, %int1 : !torch.vtensor<[4,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + + +// ----- + +// CHECK-LABEL: func.func @torch.aten.broadcast_to$dynamic_implicit( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[8,4,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %int-1 = torch.constant.int -1 +// CHECK: %int4 = torch.constant.int 4 +// CHECK: %int8 = torch.constant.int 8 +// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int8, %int4, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_3:.*]] = torch_c.to_i64 %int8 +// CHECK: %[[VAL_4:.*]] = arith.index_cast %[[VAL_3:.*]] : i64 to index +// CHECK: %[[VAL_5:.*]] = torch_c.to_i64 %int4 +// CHECK: %[[VAL_6:.*]] = arith.index_cast %[[VAL_5]] : i64 to index +// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1:.*]], %[[VAL_7]] : tensor +// CHECK: %[[VAL_9:.*]] = tensor.from_elements %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : tensor<3xindex> +// CHECK: %[[VAL_10:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_1]], %[[VAL_9]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor, tensor<3xindex>) -> tensor<8x4x?xf32> +// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<8x4x?xf32> -> !torch.vtensor<[8,4,?],f32> +// CHECK: return %[[VAL_11]] : !torch.vtensor<[8,4,?],f32> +func.func @torch.aten.broadcast_to$dynamic_implicit(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[8,4,?],f32> { + %int-1 = torch.constant.int -1 + %int4 = torch.constant.int 4 + %int8 = torch.constant.int 8 + %0 = torch.prim.ListConstruct %int8, %int4, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %1 = torch.aten.broadcast_to %arg0, %0 : !torch.vtensor<[?,?],f32>, !torch.list -> !torch.vtensor<[8,4,?],f32> + return %1 : !torch.vtensor<[8,4,?],f32> } // ----- -// CHECK-LABEL: func.func @torch.aten.batch_norm( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3,5,5],f32>) -> !torch.vtensor<[2,3,5,5],f32> { -// CEHCK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,5,5],f32> -> tensor<2x3x5x5xf32> -// CEHCK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32> -// CEHCK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> -// CEHCK: %true = torch.constant.bool true -// CEHCK: %[[VAL_4:.*]] = mhlo.constant dense<0> : tensor -// CEHCK: %float1.000000e-01 = torch.constant.float 1.000000e-01 -// CEHCK: %float1.000000e-05 = torch.constant.float 1.000000e-05 -// CEHCK: %int1 = torch.constant.int 1 -// CEHCK: %[[VAL_5:.*]] = mhlo.constant dense<1> : tensor -// CEHCK: %[[VAL_6:.*]] = mhlo.add %[[VAL_4]], %[[VAL_5]] : tensor -// CEHCK: %[[VAL_7:.*]], %batch_mean, %batch_var = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<2x3x5x5xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<2x3x5x5xf32>, tensor<3xf32>, tensor<3xf32>) -// CEHCK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<2x3x5x5xf32> -> !torch.vtensor<[2,3,5,5],f32> -// CEHCK: return %[[VAL_8]] : !torch.vtensor<[2,3,5,5],f32> +// CHECK-LABEL: func.func @torch.aten.relu( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor 
%[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) {value = 0.000000e+00 : f32} : (tensor) -> tensor +// CHECK: %[[VAL_3:.*]] = mhlo.maximum %[[VAL_1]], %[[VAL_2]] : tensor +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.relu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.relu %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} -func.func @torch.aten.batch_norm(%arg0: !torch.vtensor<[2,3,5,5],f32>) -> !torch.vtensor<[2,3,5,5],f32> { +// ----- + +// CHECK-LABEL: func.func @torch.aten.batch_norm$training( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32> +// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> +// CHECK: %true = torch.constant.bool true +// CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01 +// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05 +// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor +// CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex> +// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor, tensor<3xf32>, tensor<3xf32>) -> (tensor, tensor<3xf32>, tensor<3xf32>) +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,3,?,?],f32> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,3,?,?],f32> +func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> { %0 = torch.vtensor.literal(dense<0.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32> %1 = torch.vtensor.literal(dense<1.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32> %true = torch.constant.bool true - %2 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> - %float1.000000e-01 = torch.constant.float 1.000000e-01 - %float1.000000e-05 = torch.constant.float 1.000000e-05 - %int1 = torch.constant.int 1 - %3 = torch.aten.add.Scalar %2, %int1, %int1 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> - %4 = torch.aten.batch_norm %arg0, %1, %0, %0, %1, %true, %float1.000000e-01, %float1.000000e-05, %true : !torch.vtensor<[2,3,5,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[2,3,5,5],f32> - return %4 : !torch.vtensor<[2,3,5,5],f32> -} - -// CHECK-LABEL: func.func @torch.aten.batch_norm$none_bias_weight( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3,5,5],f32>) -> !torch.vtensor<[2,3,5,5],f32> { -// CEHCK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,5,5],f32> -> tensor<2x3x5x5xf32> -// CEHCK: %none = torch.constant.none -// CEHCK: %1 = mhlo.constant dense<1.000000e+00> : tensor<3xf32> -// CEHCK: %2 = mhlo.constant dense<0.000000e+00> : tensor<3xf32> -// CEHCK: %true = torch.constant.bool true -// CEHCK: %[[VAL_2:.*]] = mhlo.constant dense<0> : 
tensor -// CEHCK: %float1.000000e-01 = torch.constant.float 1.000000e-01 -// CEHCK: %float1.000000e-05 = torch.constant.float 1.000000e-05 -// CEHCK: %int1 = torch.constant.int 1 -// CEHCK: %[[VAL_3:.*]] = mhlo.constant dense<1> : tensor -// CEHCK: %[[VAL_4:.*]] = mhlo.add %[[VAL_2]], %[[VAL_3]] : tensor -// CEHCK: %[[VAL_5:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> -// CEHCK: %[[VAL_6:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32> -// CEHCK: %[[VAL_7:.*]], %batch_mean, %batch_var = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_5]], %[[VAL_6]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<2x3x5x5xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<2x3x5x5xf32>, tensor<3xf32>, tensor<3xf32>) -// CEHCK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<2x3x5x5xf32> -> !torch.vtensor<[2,3,5,5],f32> -// CEHCK: return %[[VAL_8]] : !torch.vtensor<[2,3,5,5],f32> -func.func @torch.aten.batch_norm$none_bias_weight(%arg0: !torch.vtensor<[2,3,5,5],f32>) -> !torch.vtensor<[2,3,5,5],f32> { - %none = torch.constant.none - %0 = torch.vtensor.literal(dense<1.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32> - %1 = torch.vtensor.literal(dense<0.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32> - %true = torch.constant.bool true - %2 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> %float1.000000e-01 = torch.constant.float 1.000000e-01 %float1.000000e-05 = torch.constant.float 1.000000e-05 - %int1 = torch.constant.int 1 - %3 = torch.aten.add.Scalar %2, %int1, %int1 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> - %4 = torch.aten.batch_norm %arg0, %none, %none, %1, %0, %true, %float1.000000e-01, %float1.000000e-05, %true : !torch.vtensor<[2,3,5,5],f32>, !torch.none, !torch.none, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[2,3,5,5],f32> - return %4 : !torch.vtensor<[2,3,5,5],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.batch_norm$inference( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3,5,5],f32>) -> !torch.vtensor<[2,3,5,5],f32> { -// CEHCK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,5,5],f32> -> tensor<2x3x5x5xf32> -// CEHCK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32> -// CEHCK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> -// CEHCK: %true = torch.constant.bool true -// CHECK: %false = torch.constant.bool false -// CEHCK: %[[VAL_4:.*]] = mhlo.constant dense<0> : tensor -// CEHCK: %float1.000000e-01 = torch.constant.float 1.000000e-01 -// CEHCK: %float1.000000e-05 = torch.constant.float 1.000000e-05 -// CEHCK: %int1 = torch.constant.int 1 -// CEHCK: %[[VAL_5:.*]] = mhlo.constant dense<1> : tensor -// CEHCK: %[[VAL_6:.*]] = mhlo.add %[[VAL_4]], %[[VAL_5]] : tensor -// CEHCK: %[[VAL_7:.*]], %batch_mean, %batch_var = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]], %[[VAL_2]], %[[VAL_3]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<2x3x5x5xf32>, tensor<3xf32>, tensor<3xf32>) -> (tensor<2x3x5x5xf32>, tensor<3xf32>, tensor<3xf32>) -// CEHCK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<2x3x5x5xf32> -> !torch.vtensor<[2,3,5,5],f32> -// CEHCK: return %[[VAL_8]] : !torch.vtensor<[2,3,5,5],f32> -func.func @torch.aten.batch_norm$inference(%arg0: !torch.vtensor<[2,3,5,5],f32>) -> !torch.vtensor<[2,3,5,5],f32> { + %2 = torch.aten.batch_norm %arg0, %1, %0, %0, %1, %true, 
%float1.000000e-01, %float1.000000e-05, %true : !torch.vtensor<[?,3,?,?],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[?,3,?,?],f32>
+ return %2 : !torch.vtensor<[?,3,?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.batch_norm$inference(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor
+// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32>
+// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
+// CHECK: %true = torch.constant.bool true
+// CHECK: %false = torch.constant.bool false
+// CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01
+// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor
+// CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex>
+// CHECK: %[[VAL_7:.*]] = "mhlo.batch_norm_inference"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]], %[[VAL_2]], %[[VAL_3]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor
+// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,3,?,?],f32>
+// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,3,?,?],f32>
+func.func @torch.aten.batch_norm$inference(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
 %0 = torch.vtensor.literal(dense<0.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32>
 %1 = torch.vtensor.literal(dense<1.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32>
 %true = torch.constant.bool true
 %false = torch.constant.bool false
- %2 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64>
 %float1.000000e-01 = torch.constant.float 1.000000e-01
 %float1.000000e-05 = torch.constant.float 1.000000e-05
- %int1 = torch.constant.int 1
- %3 = torch.aten.add.Scalar %2, %int1, %int1 : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
- %4 = torch.aten.batch_norm %arg0, %1, %0, %0, %1, %false, %float1.000000e-01, %float1.000000e-05, %true : !torch.vtensor<[2,3,5,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[2,3,5,5],f32>
- return %4 : !torch.vtensor<[2,3,5,5],f32>
-}
-
-// -----
-// CHECK-LABEL: func.func @torch.aten.relu(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,5],f32>) -> !torch.vtensor<[2,5],f32> {
-// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,5],f32> -> tensor<2x5xf32>
-// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<2x5xf32>
-// CHECK: %[[VAL_3:.*]] = mhlo.maximum %[[VAL_1]], %[[VAL_2]] : tensor<2x5xf32>
-// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<2x5xf32> -> !torch.vtensor<[2,5],f32>
-// CHECK: return %[[VAL_4]] : !torch.vtensor<[2,5],f32>
-func.func @torch.aten.relu(%arg0: !torch.vtensor<[2,5],f32>) -> !torch.vtensor<[2,5],f32> {
- %0 = torch.aten.relu %arg0 : !torch.vtensor<[2,5],f32> -> !torch.vtensor<[2,5],f32>
- return %0 : !torch.vtensor<[2,5],f32>
-}
-// -----
-// CHECK-LABEL: func.func @torch.aten.relu$int8(
-// CHECK-SAME:
%[[VAL_0:.*]]: !torch.vtensor<[2,5],si8>) -> !torch.vtensor<[2,5],si8> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,5],si8> -> tensor<2x5xi8> -// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0> : tensor<2x5xi8> -// CHECK: %[[VAL_3:.*]] = mhlo.maximum %[[VAL_1]], %[[VAL_2]] : tensor<2x5xi8> -// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<2x5xi8> -> !torch.vtensor<[2,5],si8> -// CHECK: return %[[VAL_4]] : !torch.vtensor<[2,5],si8> -func.func @torch.aten.relu$int8(%arg0: !torch.vtensor<[2,5],si8>) -> !torch.vtensor<[2,5],si8> { - %0 = torch.aten.relu %arg0 : !torch.vtensor<[2,5],si8> -> !torch.vtensor<[2,5],si8> - return %0 : !torch.vtensor<[2,5],si8> + %2 = torch.aten.batch_norm %arg0, %1, %0, %0, %1, %false, %float1.000000e-01, %float1.000000e-05, %true : !torch.vtensor<[?,3,?,?],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[?,3,?,?],f32> + return %2 : !torch.vtensor<[?,3,?,?],f32> } // ----- -// CHECK-LABEL: func.func @torch.aten.reciprocal( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,5,5],f32>) -> !torch.vtensor<[5,5,5],f32> { -// CEHCK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0:.*]] : !torch.vtensor<[5,5,5],f32> -> tensor<5x5x5xf32> -// CEHCK: %[[VAL_2:.*]] = mhlo.constant dense<1.000000e+00> : tensor -// CEHCK: %[[VAL_3:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_2]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<5x5x5xf32> -// CEHCK: %[[VAL_4:.*]] = mhlo.divide %[[VAL_3]], %[[VAL_1]] : tensor<5x5x5xf32> -// CEHCK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<5x5x5xf32> -> !torch.vtensor<[5,5,5],f32> -// CEHCK: return %[[VAL_5]] : !torch.vtensor<[5,5,5],f32> -func.func @torch.aten.reciprocal(%arg0: !torch.vtensor<[5,5,5],f32>) -> !torch.vtensor<[5,5,5],f32> { - %0 = torch.aten.reciprocal %arg0 : !torch.vtensor<[5,5,5],f32> -> !torch.vtensor<[5,5,5],f32> - return %0 : !torch.vtensor<[5,5,5],f32> +// CHECK-LABEL: func.func @torch.aten.batch_norm$no_bias_weight( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor +// CHECK: %none = torch.constant.none +// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32> +// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> +// CHECK: %true = torch.constant.bool true +// CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01 +// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05 +// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor +// CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex> +// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor +// CHECK: %[[VAL_8:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_7]], %[[VAL_6]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor<3xf32> +// CHECK: %[[VAL_9:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[VAL_10:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_9]], %[[VAL_6]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor<3xf32> +// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) {epsilon = 
9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor, tensor<3xf32>, tensor<3xf32>) -> (tensor, tensor<3xf32>, tensor<3xf32>) +// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor -> !torch.vtensor<[?,3,?,?],f32> +// CHECK: return %[[VAL_14]] : !torch.vtensor<[?,3,?,?],f32> +func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> { + %none = torch.constant.none + %0 = torch.vtensor.literal(dense<0.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32> + %1 = torch.vtensor.literal(dense<1.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32> + %true = torch.constant.bool true + %float1.000000e-01 = torch.constant.float 1.000000e-01 + %float1.000000e-05 = torch.constant.float 1.000000e-05 + %2 = torch.aten.batch_norm %arg0, %none, %none, %0, %1, %true, %float1.000000e-01, %float1.000000e-05, %true : !torch.vtensor<[?,3,?,?],f32>, !torch.none, !torch.none, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[?,3,?,?],f32> + return %2 : !torch.vtensor<[?,3,?,?],f32> } + // CHECK-LABEL: func @torch.aten.native_layer_norm( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,7,4,5],f32>) -> !torch.vtensor<[3,7,4,5],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,7,4,5],f32> -> tensor<3x7x4x5xf32> @@ -617,88 +654,4 @@ func.func @torch.aten.native_layer_norm(%arg0: !torch.vtensor<[3,7,4,5],f32>) -> %2 = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list %result0, %result1, %result2 = torch.aten.native_layer_norm %arg0, %2, %1, %0, %float1.000000e-05 : !torch.vtensor<[3,7,4,5],f32>, !torch.list, !torch.vtensor<[4,5],f32>, !torch.vtensor<[4,5],f32>, !torch.float -> !torch.vtensor<[3,7,4,5],f32>, !torch.vtensor<[3,7,1,1],f32>, !torch.vtensor<[3,7,1,1],f32> return %result0 : !torch.vtensor<[3,7,4,5],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.contiguous( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> -// CHECK: %int0 = torch.constant.int 0 -// CHECK: %[[VAL_2:.*]] = torch_c.from_builtin_tensor %[[VAL_1]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> -// CHECK: return %[[VAL_2]] : !torch.vtensor<[4,64],f32> -func.func @torch.aten.contiguous(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { - %int0 = torch.constant.int 0 - %0 = torch.aten.contiguous %arg0, %int0 : !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32> - return %0 : !torch.vtensor<[4,64],f32> -} - -// ----- - -// CEHCK-LABEL: func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[],si64> { -// CEHCK: %int1 = torch.constant.int 1 -// CEHCK: %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor -// CEHCK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor -> !torch.vtensor<[],si64> -// CEHCK: return %[[VAL_1]] : !torch.vtensor<[],si64> -func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> { - %int1 = torch.constant.int 1 - %0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[], si64> - return %0 : !torch.vtensor<[], si64> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.broadcast_to$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[8,4,64],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : 
-// CHECK: %[[VAL_2:.*]] = torch.constant.int 64
-// CHECK: %[[VAL_3:.*]] = torch.constant.int 4
-// CHECK: %[[VAL_4:.*]] = torch.constant.int 8
-// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[VAL_6:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_1]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<4x64xf32>) -> tensor<8x4x64xf32>
-// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<8x4x64xf32> -> !torch.vtensor<[8,4,64],f32>
-// CHECK: return %[[VAL_7]] : !torch.vtensor<[8,4,64],f32>
-func.func @torch.aten.broadcast_to$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[8,4,64],f32> {
-  %int64 = torch.constant.int 64
-  %int4 = torch.constant.int 4
-  %int8 = torch.constant.int 8
-  %0 = torch.prim.ListConstruct %int8, %int4, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-  %1 = torch.aten.broadcast_to %arg0, %0 : !torch.vtensor<[4,64],f32>, !torch.list<int> -> !torch.vtensor<[8,4,64],f32>
-  return %1 : !torch.vtensor<[8,4,64],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.permute$basic(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> {
-// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
-// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
-// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
-// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[VAL_5:.*]] = "mhlo.transpose"(%[[VAL_1]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x64xf32>) -> tensor<64x4xf32>
-// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<64x4xf32> -> !torch.vtensor<[64,4],f32>
-// CHECK: return %[[VAL_6]] : !torch.vtensor<[64,4],f32>
-func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> {
-  %int0 = torch.constant.int 0
-  %int1 = torch.constant.int 1
-  %0 = torch.prim.ListConstruct %int1, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
-  %1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[4,64],f32>, !torch.list<int> -> !torch.vtensor<[64,4],f32>
-  return %1 : !torch.vtensor<[64,4],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.transpose$basic(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> {
-// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],f32> -> tensor<4x3xf32>
-// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
-// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
-// CHECK: %[[VAL_4:.*]] = "mhlo.transpose"(%[[VAL_1]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x3xf32>) -> tensor<3x4xf32>
-// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
-// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32>
-func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> {
-  %int0 = torch.constant.int 0
-  %int1 = torch.constant.int 1
-  %0 = torch.aten.transpose.int %arg0, %int0, %int1 : !torch.vtensor<[4,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4],f32>
-  return %0 : !torch.vtensor<[3,4],f32>
 }
\ No newline at end of file

From 3751125a2291a9bb055e971dfda3c74a99c0bb47 Mon Sep 17 00:00:00 2001
From: "wujiawei.jw"
Date: Sun, 31 Jul 2022 15:50:16 +0800
Subject: [PATCH 2/4] [MHLO] Support I32 as shape tensor dtype

---
 lib/Conversion/TorchToMhlo/BasicOp.cpp | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/lib/Conversion/TorchToMhlo/BasicOp.cpp b/lib/Conversion/TorchToMhlo/BasicOp.cpp
index 8860192bf75a..1c33234b2449 100644
--- a/lib/Conversion/TorchToMhlo/BasicOp.cpp
+++ b/lib/Conversion/TorchToMhlo/BasicOp.cpp
@@ -408,6 +408,15 @@ class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
       bcastShapeVec.push_back(newD);
     }
 
+#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
+    for (auto &dsize : bcastShapeVec) {
+      auto dsizeI64 = rewriter.create<arith::IndexCastOp>(
+          op->getLoc(), rewriter.getI64Type(), dsize);
+      dsize = rewriter.create<arith::TruncIOp>(op->getLoc(),
+                                               rewriter.getI32Type(), dsizeI64);
+    }
+#endif
+
     Value bcastShapeTensor = rewriter.create<tensor::FromElementsOp>(
         op->getLoc(), ValueRange{bcastShapeVec});
     auto dimensionNumbers =
@@ -664,6 +673,14 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
   auto inputElemTy = inputTy.getElementType().cast<mlir::FloatType>();
   Value channelDim = rewriter.create<tensor::DimOp>(op->getLoc(), input, 1);
 
+#ifdef TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
+  auto channelDimI64 = rewriter.create<arith::IndexCastOp>(
+      op->getLoc(), rewriter.getI64Type(), channelDim);
+  channelDim = rewriter.create<arith::TruncIOp>(
+      op->getLoc(), rewriter.getI32Type(), channelDimI64);
+#endif
+
   Value channelShape = rewriter.create<tensor::FromElementsOp>(
       op->getLoc(), ValueRange{channelDim});
   if (failed(checkNotNone(rewriter, op, weight))) {

From cb8c417a4995d3fef36c6ec1fe6320baec2c4b02 Mon Sep 17 00:00:00 2001
From: "wujiawei.jw"
Date: Tue, 2 Aug 2022 00:29:00 +0800
Subject: [PATCH 3/4] [NFC] Add a 'TODO' annotation

---
 lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp
index 8fff582e77e3..bbac564fea9b 100644
--- a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp
@@ -174,6 +174,7 @@ Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op,
   return const_op.getResult();
 }
 
+// TODO: Support for variable scalar.
 LogicalResult torchScalarToMhloTensor(ConversionPatternRewriter &rewriter,
                                       Operation *op, Value torchScalarValue,
                                       Value &mhloTensor, Type dtype,

From 24264e917c8e1a7ce5595350d0eef710e961a8a2 Mon Sep 17 00:00:00 2001
From: "wujiawei.jw"
Date: Tue, 2 Aug 2022 09:57:23 +0800
Subject: [PATCH 4/4] [NFC] Merge update from main branch

---
 lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp
index bbac564fea9b..4a6c333ffc49 100644
--- a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp
@@ -328,7 +328,6 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
   return bcast_op.getResult();
 }
 
-<<<<<<< HEAD
 SmallVector<int64_t> toPositiveDims(ArrayRef<int64_t> dims, int64_t rank) {
   SmallVector<int64_t> posDims;
   posDims.reserve(rank);
@@ -429,7 +428,6 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
       .getResult();
 }
 
-=======
 Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
                          const APFloat &constant, Value shape,
                          TensorType outType) {
@@ -440,6 +438,5 @@ Value getConstantOfShape(PatternRewriter &rewriter, Location loc,
       rewriter.getI64TensorAttr({}))
       .getResult();
 }
->>>>>>> [MHLO] Support for dynamic shape in basic op conversion by introducing CHLO dialect
 } // namespace mhlo
 } // namespace mlir
\ No newline at end of file