diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index b40eaaafd88..8abe90935ee 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4047,6 +4047,32 @@ def Torch_AtenVarDimOp : Torch_Op<"aten.var.dim", [
   }];
 }
 
+def Torch_AtenVarCorrectionOp : Torch_Op<"aten.var.correction", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchOptionalListOfTorchIntType:$dim,
+    AnyTorchOptionalIntType:$correction,
+    Torch_BoolType:$keepdim
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenVarCorrectionOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenVarCorrectionOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp
index cf2d2beee6e..af5268a0656 100644
--- a/lib/Conversion/TorchToLinalg/Reduction.cpp
+++ b/lib/Conversion/TorchToLinalg/Reduction.cpp
@@ -270,6 +270,8 @@ class ConvertReductionOp : public ConversionPattern {
                                          "`keepdim` must be a constant bool");
 
     SmallVector<int64_t> dimList;
+    bool isNoneOrEmptyDimList =
+        op.dim().getType().template isa<Torch::NoneType>();
    if (matchPattern(op.dim(), m_TorchConstantIntList(dimList))) {
      // Fix negative dimensions, if any, before adding to the list.
for (int64_t dim : dimList) { @@ -278,13 +280,16 @@ class ConvertReductionOp : public ConversionPattern { if (isValidDim(dim, inputType.getRank())) opInfo.dimSet.insert(dim); } - } else if (op.dim().getType().template isa()) { + if (dimList.empty()) + isNoneOrEmptyDimList = true; + } else if (!isNoneOrEmptyDimList) { + return rewriter.notifyMatchFailure( + op, "`dim` argument must be a constant int list or None"); + } + if (isNoneOrEmptyDimList) { // If no dimensions were specified, reduce along all dimensions for (int64_t i = 0; i < inputType.getRank(); i++) opInfo.dimSet.insert(i); - } else { - return rewriter.notifyMatchFailure( - op, "`dim` argument must be a constant int list or None"); } return opInfo; diff --git a/lib/Conversion/TorchToStd/TorchToStd.cpp b/lib/Conversion/TorchToStd/TorchToStd.cpp index 00b969f79b8..974c1c0d2a6 100644 --- a/lib/Conversion/TorchToStd/TorchToStd.cpp +++ b/lib/Conversion/TorchToStd/TorchToStd.cpp @@ -323,7 +323,7 @@ class ConvertTorchToStd : public ConvertTorchToStdBase { patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); - target.addIllegalOp(); + target.addIllegalOp(); patterns .add>( typeConverter, context); @@ -333,6 +333,9 @@ class ConvertTorchToStd : public ConvertTorchToStdBase { patterns.add< ConvertAtenIntComparisonOp>( typeConverter, context); + patterns.add< + ConvertAtenIntComparisonOp>( + typeConverter, context); target.addIllegalOp(); patterns.add< diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index c978e3ef450..f80e92df30f 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -229,8 +229,7 @@ class DecomposeAtenSelectIntOp : public OpRewritePattern { } // namespace namespace { -class DecomposeAtenZeroOp - : public OpRewritePattern { +class DecomposeAtenZeroOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenZeroOp op, @@ -705,7 +704,8 @@ class DecomposeAtenTOp : public OpRewritePattern { // unsqueezed_sizes += [1, s] // expanded_sizes += [m, s] // reshaped_sizes += [m * s] -// return self.view(unsqueezed_sizes).expand(expanded_sizes).view(reshaped_sizes) +// return +// self.view(unsqueezed_sizes).expand(expanded_sizes).view(reshaped_sizes) // namespace { class DecomposeAtenRepeatOp : public OpRewritePattern { @@ -754,7 +754,8 @@ class DecomposeAtenRepeatOp : public OpRewritePattern { assert(leadingRank >= 0 && "leadingRank should greater than 0"); for (size_t i = 0; i < leadingRank; ++i) { insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, ArrayRef{one}); - insertDimSizes(expandedSizes, expandedIntSizes, ArrayRef{repeats[i]}); + insertDimSizes(expandedSizes, expandedIntSizes, + ArrayRef{repeats[i]}); reshapedSizes.push_back(repeats[i]); } @@ -772,18 +773,20 @@ class DecomposeAtenRepeatOp : public OpRewritePattern { loc, rewriter.getI64IntegerAttr(selfShape[i])); } - insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, ArrayRef{one, dimSize}); - insertDimSizes(expandedSizes, expandedIntSizes, ArrayRef{scale, dimSize}); + insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, + ArrayRef{one, dimSize}); + insertDimSizes(expandedSizes, expandedIntSizes, + ArrayRef{scale, dimSize}); Value scaledSize = rewriter.create(loc, dimSize, scale); reshapedSizes.push_back(scaledSize); } Type dtype = self.getType().cast().getDtype(); - Type unsqueezedType = - ValueTensorType::get(context, 
llvm::makeArrayRef(unsqueezedIntSizes), dtype); - Type expandedType = - ValueTensorType::get(context, llvm::makeArrayRef(expandedIntSizes), dtype); + Type unsqueezedType = ValueTensorType::get( + context, llvm::makeArrayRef(unsqueezedIntSizes), dtype); + Type expandedType = ValueTensorType::get( + context, llvm::makeArrayRef(expandedIntSizes), dtype); auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext())); Value unsqueezedDims = @@ -792,8 +795,8 @@ class DecomposeAtenRepeatOp : public OpRewritePattern { rewriter.create(loc, listType, expandedSizes); Value reshapedDims = rewriter.create(loc, listType, reshapedSizes); - auto reshaped = - rewriter.create(loc, unsqueezedType, op.self(), unsqueezedDims); + auto reshaped = rewriter.create(loc, unsqueezedType, op.self(), + unsqueezedDims); auto expanded = rewriter.create(loc, expandedType, reshaped, expandedDims); @@ -1009,6 +1012,7 @@ class DecomposeAtenMeanDimOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.self(); + unsigned inputRank = getTensorRank(input); Value dimList = op.dim(); Value keepDim = op.keepdim(); Value dtype = op.dtype(); @@ -1033,12 +1037,18 @@ class DecomposeAtenMeanDimOp : public OpRewritePattern { loc, outputType, input, dimList, keepDim, dtype); // `productDimSize` is product of sizes of dimensions to be reduced. - Value productDimSize = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - for (Value dim : dimListConstruct.elements()) { - Value dimSize = rewriter.create(loc, input, dim); - productDimSize = - rewriter.create(loc, productDimSize, dimSize); + Value productDimSize; + // Case: Reduce along all dims. + if (dimListConstruct.elements().empty() && inputRank != 0) { + productDimSize = rewriter.create(loc, input); + } else { + productDimSize = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + for (Value dim : dimListConstruct.elements()) { + Value dimSize = rewriter.create(loc, input, dim); + productDimSize = + rewriter.create(loc, productDimSize, dimSize); + } } rewriter.replaceOpWithNewOp(op, outputType, sumAlongDims, productDimSize); @@ -1184,14 +1194,14 @@ class DecomposeAtenSoftplusOp : public OpRewritePattern { Value input = op.self(); BaseTensorType inputType = input.getType().cast(); - Value inputTimesBeta = rewriter.create( - loc, inputType, input, op.beta()); + Value inputTimesBeta = + rewriter.create(loc, inputType, input, op.beta()); // out = log1p(exp(input * beta)) / beta Value exp = rewriter.create(loc, inputType, inputTimesBeta); Value log1p = rewriter.create(loc, inputType, exp); - Value out = rewriter.create( - loc, inputType, log1p, op.beta()); + Value out = + rewriter.create(loc, inputType, log1p, op.beta()); // Select where x * beta > threshold auto boolResType = inputType.getWithSizesAndDtype(inputType.getSizes(), @@ -1199,8 +1209,8 @@ class DecomposeAtenSoftplusOp : public OpRewritePattern { Value condition = rewriter.create( loc, boolResType, inputTimesBeta, op.threshold()); - rewriter.replaceOpWithNewOp( - op, op.getType(), condition, input, out); + rewriter.replaceOpWithNewOp(op, op.getType(), condition, + input, out); return success(); } }; @@ -2138,6 +2148,107 @@ class DecomposeAtenNumpyTOp : public OpRewritePattern { }; } // namespace +template +static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, + bool unbiased, int64_t correction) { + Location loc = op.getLoc(); + Value self = op.self(); + Value dimList = op.dim(); + Value keepDim = op.keepdim(); + 
BaseTensorType inputTensorTy = self.getType().cast(); + Type outputType = op.getType(); + BaseTensorType outputTensorType = outputType.cast(); + Type newOutputType = outputTensorType.getWithSizesAndDtype( + outputTensorType.getSizes(), rewriter.getF64Type()); + if (!inputTensorTy.hasDtype() || + !inputTensorTy.getDtype().isa()) { + return rewriter.notifyMatchFailure( + op, "support floating-point type input only"); + } + + // Upcasting the input tensor to `F64` dtype for higher precision during the + // computation of the result. + if (inputTensorTy.getDtype().getIntOrFloatBitWidth() != 64) { + self = convertTensorToDtype(rewriter, loc, self, rewriter.getF64Type()); + inputTensorTy = self.getType().cast(); + } + + unsigned inputRank = getTensorRank(self); + SmallVector dimListElements; + bool isNoneOrEmpty = true; + if (!dimList.getType().template isa()) { + if (!getListConstructElements(dimList, dimListElements)) + return rewriter.notifyMatchFailure( + op, "expect dimList to be constructed from list construct"); + if (!dimListElements.empty() || inputRank == 0) + isNoneOrEmpty = false; + } + if (isNoneOrEmpty) { + for (unsigned i = 0; i < inputRank; i++) + dimListElements.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(i))); + dimList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), + dimListElements); + } + Type meanDimResultType = inputTensorTy; + for (unsigned i = 0; i < dimListElements.size(); i++) + meanDimResultType = computeReductionType( + rewriter, op, meanDimResultType.cast(), + dimListElements[i], + /*keepDim=*/true); + + Value constantNone = rewriter.create(loc); + Value constantTrue = rewriter.create(loc, true); + Value meanAlongDims = rewriter.create( + loc, meanDimResultType, self, dimList, /*keepDim=*/constantTrue, + /*dtype=*/constantNone); + Value subMean = + createTensorSub(rewriter, loc, inputTensorTy, self, meanAlongDims); + Value square = rewriter.create(loc, inputTensorTy, subMean); + + if (!unbiased) { + Value result = rewriter.create( + loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone); + result = convertTensorToDtype(rewriter, loc, result, + outputTensorType.getDtype()); + rewriter.replaceOp(op, result); + return success(); + } + // Divide the square sum by productDimSize - correction. + Value squareSum = rewriter.create( + loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone); + + // `productDimSize` is product of sizes of dimensions to be reduced. + Value constantOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value productDimSize = constantOne; + for (Value dim : dimListElements) { + Value dimSize = rewriter.create(loc, self, dim); + productDimSize = + rewriter.create(loc, productDimSize, dimSize); + } + Value cstCorrection = rewriter.create( + loc, rewriter.getI64IntegerAttr(correction)); + // The `correction` value should be less than or equal to `productDimSize + + // 1`. 
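+  // For example, reducing an axis of size 7 with correction=2 divides the
+  // summed squared deviations by 7 - 2 = 5; the runtime assert emitted below
+  // rejects any correction larger than productDimSize + 1.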
+ Value productDimSizePlusOne = + rewriter.create(loc, productDimSize, constantOne); + Value cond = + rewriter.create(loc, productDimSizePlusOne, cstCorrection); + rewriter.create( + loc, cond, + "correction value should be less than or equal to productDimSize + 1"); + Value productDimSizeSubCorrection = + rewriter.create(loc, productDimSize, cstCorrection); + Value result = rewriter.create(loc, newOutputType, squareSum, + productDimSizeSubCorrection); + result = + convertTensorToDtype(rewriter, loc, result, outputTensorType.getDtype()); + rewriter.replaceOp(op, result); + return success(); +} + // Decompose aten.var(x, dims) into: // sub = aten.sub(x, aten.mean(x, dims)) // square = aten.square(sub) @@ -2151,70 +2262,44 @@ class DecomposeAtenVarDimOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenVarDimOp op, PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value self = op.self(); - Value dimList = op.dim(); - Value keepDim = op.keepdim(); - Type outputType = op.getType(); - BaseTensorType inputTensorTy = self.getType().cast(); - if (!inputTensorTy.hasDtype() || - !inputTensorTy.getDtype().isa()) { - return rewriter.notifyMatchFailure(op, - "support floating type input only"); - } - - auto dimListConstruct = dimList.getDefiningOp(); - if (!dimListConstruct) { - return rewriter.notifyMatchFailure( - op, "expect dimList to be constructed from list construct"); - } - bool unbiased; if (!matchPattern(op.unbiased(), m_TorchConstantBool(&unbiased))) { return rewriter.notifyMatchFailure( op, "Only support constant unbiased for aten.var"); } + int64_t correction = unbiased ? 1 : 0; + if (failed(calculateVariance(op, rewriter, unbiased, + correction))) + return rewriter.notifyMatchFailure(op, "invalid variance parameters"); + return success(); + } +}; +} // namespace - SmallVector dimListElements = dimListConstruct.elements(); - Type meanDimResultType = inputTensorTy; - for (unsigned i = 0; i < dimListElements.size(); i++) - meanDimResultType = computeReductionType( - rewriter, op, meanDimResultType.cast(), - dimListElements[i], - /*keepDim=*/true); - - Value constantNone = rewriter.create(loc); - Value constantTrue = rewriter.create(loc, true); - Value meanAlongDims = rewriter.create( - loc, meanDimResultType, self, dimList, /*keepDim=*/constantTrue, - /*dtype=*/constantNone); - Value subMean = - createTensorSub(rewriter, loc, inputTensorTy, self, meanAlongDims); - Value square = rewriter.create(loc, inputTensorTy, subMean); - if (unbiased) { - // Bessel’s correction is used. Divide the square sum by - // productDimSize-1. - Value squareSum = rewriter.create( - loc, outputType, square, dimList, keepDim, /*dtype=*/constantNone); - - // `productDimSize` is product of sizes of dimensions to be reduced. 
- Value productDimSize = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - for (Value dim : dimListConstruct.elements()) { - Value dimSize = rewriter.create(loc, self, dim); - productDimSize = - rewriter.create(loc, productDimSize, dimSize); - } - Value constantOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - Value productDimSizeSubOne = - rewriter.create(loc, productDimSize, constantOne); - rewriter.replaceOpWithNewOp(op, outputType, squareSum, - productDimSizeSubOne); +// Decompose aten.var(x, dims) into: +// sub = aten.sub(x, aten.mean(x, dims)) +// square = aten.square(sub) +// out = aten.sum(square, dims) / (productDimSize - correction) +namespace { +class DecomposeAtenVarCorrectionOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenVarCorrectionOp op, + PatternRewriter &rewriter) const override { + int64_t correction; + if (!op.correction().getType().isa()) { + if (!matchPattern(op.correction(), m_TorchConstantInt(&correction))) + return rewriter.notifyMatchFailure( + op, "Only support constant int correction for aten.var"); } else { - rewriter.replaceOpWithNewOp( - op, outputType, square, dimList, keepDim, /*dtype=*/constantNone); + // The default value in case of `correction` being None is 1. + correction = 1; } + bool unbiased = correction == 0 ? false : true; + if (failed(calculateVariance(op, rewriter, unbiased, + correction))) + return rewriter.notifyMatchFailure(op, "invalid variance parameters"); return success(); } }; @@ -2426,6 +2511,8 @@ class DecomposeComplexOpsPass target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 1eff5b7cab3..d54eb41b228 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -931,7 +931,7 @@ void TypeAnalysis::visitOperation(Operation *op, Type dtype = operands[0]->getValue().dtype; visitReductionAlongAllDimsOp(max, dtype, operands); return; - } else if (isa(op)) { + } else if (isa(op)) { auto input = operands[0]->getValue(); visitReductionAlongAllDimsOp(op, input.dtype, operands); return; diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp index ad52744936c..9a2f1142a07 100644 --- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp @@ -5566,9 +5566,55 @@ module { } func.func @"__torch_mlir_shape_fn.aten.var.dim"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list { %none = torch.constant.none - %0 = torch.derefine %none : !torch.none to !torch.any - %1 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list - return %1 : !torch.list + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int + %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool + %2 = torch.prim.If %1 -> (!torch.list) { + %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %6 = torch.prim.ListConstruct : () -> !torch.list + torch.prim.Loop %5, %true, init() { + ^bb0(%arg4: !torch.int): + %7 = torch.aten.append.t %6, %arg4 : !torch.list, !torch.int -> 
!torch.list + torch.prim.Loop.condition %true, iter() + } : (!torch.int, !torch.bool) -> () + torch.prim.If.yield %6 : !torch.list + } else { + torch.prim.If.yield %arg1 : !torch.list + } + %3 = torch.derefine %none : !torch.none to !torch.any + %4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list + return %4 : !torch.list + } + func.func @"__torch_mlir_shape_fn.aten.var.correction"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.list { + %true = torch.constant.bool true + %none = torch.constant.none + %int0 = torch.constant.int 0 + %0 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool + %1 = torch.prim.If %0 -> (!torch.bool) { + torch.prim.If.yield %true : !torch.bool + } else { + %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list + %6 = torch.aten.len.t %5 : !torch.list -> !torch.int + %7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool + torch.prim.If.yield %7 : !torch.bool + } + %2 = torch.prim.If %1 -> (!torch.list) { + %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %6 = torch.prim.ListConstruct : () -> !torch.list + torch.prim.Loop %5, %true, init() { + ^bb0(%arg4: !torch.int): + %7 = torch.aten.append.t %6, %arg4 : !torch.list, !torch.int -> !torch.list + torch.prim.Loop.condition %true, iter() + } : (!torch.int, !torch.bool) -> () + torch.prim.If.yield %6 : !torch.list + } else { + %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list + torch.prim.If.yield %5 : !torch.list + } + %3 = torch.derefine %none : !torch.none to !torch.any + %4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list + return %4 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.std"(%arg0: !torch.list, %arg1: !torch.bool) -> !torch.list { %0 = torch.prim.ListConstruct : () -> !torch.list @@ -5627,25 +5673,55 @@ module { return %1 : !torch.tuple, list> } func.func @"__torch_mlir_shape_fn.aten.mean.dim"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.list { - %0 = torch.derefine %arg3 : !torch.optional to !torch.any - %1 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list - return %1 : !torch.list + %true = torch.constant.bool true + %int0 = torch.constant.int 0 + %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int + %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool + %2 = torch.prim.If %1 -> (!torch.list) { + %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %6 = torch.prim.ListConstruct : () -> !torch.list + torch.prim.Loop %5, %true, init() { + ^bb0(%arg4: !torch.int): + %7 = torch.aten.append.t %6, %arg4 : !torch.list, !torch.int -> !torch.list + torch.prim.Loop.condition %true, iter() + } : (!torch.int, !torch.bool) -> () + torch.prim.If.yield %6 : !torch.list + } else { + torch.prim.If.yield %arg1 : !torch.list + } + %3 = torch.derefine %arg3 : !torch.optional to !torch.any + %4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list + return %4 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.sum.dim_IntList"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> 
!torch.list { + %true = torch.constant.bool true %none = torch.constant.none + %int0 = torch.constant.int 0 %0 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool - %1 = torch.prim.If %0 -> (!torch.list) { - %2 = torch.prim.ListConstruct : () -> !torch.list - %3 = torch.derefine %arg3 : !torch.optional to !torch.any - %4 = func.call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list - torch.prim.If.yield %4 : !torch.list + %1 = torch.prim.If %0 -> (!torch.bool) { + torch.prim.If.yield %true : !torch.bool } else { - %2 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list - %3 = torch.derefine %arg3 : !torch.optional to !torch.any - %4 = func.call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list - torch.prim.If.yield %4 : !torch.list + %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list + %6 = torch.aten.len.t %5 : !torch.list -> !torch.int + %7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool + torch.prim.If.yield %7 : !torch.bool } - return %1 : !torch.list + %2 = torch.prim.If %1 -> (!torch.list) { + %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %6 = torch.prim.ListConstruct : () -> !torch.list + torch.prim.Loop %5, %true, init() { + ^bb0(%arg4: !torch.int): + %7 = torch.aten.append.t %6, %arg4 : !torch.list, !torch.int -> !torch.list + torch.prim.Loop.condition %true, iter() + } : (!torch.int, !torch.bool) -> () + torch.prim.If.yield %6 : !torch.list + } else { + %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list + torch.prim.If.yield %5 : !torch.list + } + %3 = torch.derefine %arg3 : !torch.optional to !torch.any + %4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list + return %4 : !torch.list } func.func @"__torch_mlir_shape_fn.aten.permute"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { %0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py index 1aa13ef5e45..c03a0be3f4d 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py @@ -490,6 +490,13 @@ def aten〇var(self: List[int], unbiased: bool = True) -> List[int]: return [] def aten〇var〇dim(self: List[int], dim: List[int], unbiased: bool = True, keepdim: bool = False) -> List[int]: + if len(dim)==0: + dim = list(range(len(self))) + return upstream_shape_functions.mean_dim(self, dim, keepdim, None) + +def aten〇var〇correction(self: List[int], dim: Optional[List[int]], correction: Optional[int], keepdim: bool = False) -> List[int]: + if dim is None or len(dim)==0: + dim = list(range(len(self))) return upstream_shape_functions.mean_dim(self, dim, keepdim, None) def aten〇std(self: List[int], unbiased: bool = True) -> List[int]: @@ -528,13 +535,14 @@ def aten〇max〇dim(self: List[int], dim: int, keepdim: bool = False) -> Tuple[ return reduced_shape, reduced_shape def aten〇mean〇dim(self: List[int], dim: List[int], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: + if len(dim)==0: + dim = 
list(range(len(self))) return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype) def aten〇sum〇dim_IntList(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: - if dim is None: - return upstream_shape_functions.mean_dim(self, [], keepdim, dtype) - else: - return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype) + if dim is None or len(dim)==0: + dim = list(range(len(self))) + return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype) def aten〇permute(self: List[int], dims: List[int]) -> List[int]: return upstream_shape_functions.permute(self, dims) diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 48ee9722f1a..f983b08aa22 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -383,6 +383,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::std : (Tensor, bool) -> (Tensor)") emit("aten::var : (Tensor, bool) -> (Tensor)") emit("aten::var.dim : (Tensor, int[], bool, bool) -> (Tensor)") + emit("aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)") emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)") emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)") emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)") diff --git a/python/torch_mlir_e2e_test/test_suite/reduction.py b/python/torch_mlir_e2e_test/test_suite/reduction.py index be1ed507848..3fcf2953422 100644 --- a/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -106,6 +106,25 @@ def ReduceSumDimIntListKeepDimFloatModule_basic(module, tu: TestUtils): # ============================================================================== +class ReduceSumDimIntListEmptyDimModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.sum(a, dim=[]) + + +@register_test_case(module_factory=lambda: ReduceSumDimIntListEmptyDimModule()) +def ReduceSumDimIntListEmptyDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + +# ============================================================================== + class ReduceSumUnsignedIntModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py b/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py index 8a626d9625a..52266cb9d66 100644 --- a/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py +++ b/python/torch_mlir_e2e_test/test_suite/scalar_comparison.py @@ -81,6 +81,29 @@ def GtIntModule_basic(module, tu: TestUtils): # ============================================================================== +class GeIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([], torch.int64, True), + ([], torch.int64, True), + ]) + def forward(self, lhs, rhs): + return torch.ops.aten.ge(int(lhs), int(rhs)) + + +@register_test_case(module_factory=lambda: GeIntModule()) +def GeIntModule_basic(module, tu: TestUtils): + module.forward(torch.randint(-100, 100, ()), torch.randint(-100, 100, ())) + + +# 
============================================================================== + + class GeFloatModule(torch.nn.Module): def __init__(self): diff --git a/python/torch_mlir_e2e_test/test_suite/stats.py b/python/torch_mlir_e2e_test/test_suite/stats.py index 550b96fe054..4d27e5263ff 100644 --- a/python/torch_mlir_e2e_test/test_suite/stats.py +++ b/python/torch_mlir_e2e_test/test_suite/stats.py @@ -180,6 +180,26 @@ def forward(self, x): def MeanDimNegativeModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + +# ============================================================================== + +class MeanDimEmptyDimModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.mean(x, dim=[]) + + +@register_test_case(module_factory=lambda: MeanDimEmptyDimModule()) +def MeanDimEmptyDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + # ============================================================================== class VarUnbiasedModule(torch.nn.Module): @@ -410,7 +430,7 @@ def VarDimNegativeModule_basic(module, tu: TestUtils): # ============================================================================== -class VarDimKeepDimFalseModule(torch.nn.Module): +class VarDimEmptyDimModule(torch.nn.Module): def __init__(self): super().__init__() @@ -421,9 +441,166 @@ def __init__(self): ([-1, -1, -1], torch.float32, True), ]) def forward(self, x): - return torch.ops.aten.var(x, dim=(0, 1, 2), keepdim=False) + return torch.ops.aten.var(x, dim=[], keepdim=False) -@register_test_case(module_factory=lambda: VarDimKeepDimFalseModule()) -def VarDimKeepDimFalseModule_basic(module, tu: TestUtils): +@register_test_case(module_factory=lambda: VarDimEmptyDimModule()) +def VarDimEmptyDimModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + + +# ============================================================================== + + +class VarCorrectionModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.var(x, dim=None, correction=2) + + +@register_test_case(module_factory=lambda: VarCorrectionModule()) +def VarCorrectionModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 7)) + + +# ============================================================================== + + +class VarCorrectionSingleDimReduceModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.var(x, dim=[1], correction=1) + + +@register_test_case(module_factory=lambda: VarCorrectionSingleDimReduceModule()) +def VarCorrectionSingleDimReduceModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 7)) + + +# ============================================================================== + + +class VarCorrectionAllDimReduceModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.var(x, + dim=[0, 1, 2], + correction=10, + keepdim=False) + + +@register_test_case(module_factory=lambda: VarCorrectionAllDimReduceModule()) +def VarCorrectionAllDimReduceModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 7)) 
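+
+# Reference semantics for the var.correction tests above and below (informal
+# sketch): the decomposition computes
+#   sum((x - x.mean(dim, keepdim=True))**2, dim) / (n - correction)
+# where n is the number of elements reduced over, so reducing a [3, 4, 7]
+# input over all dims with correction=10 divides by 3 * 4 * 7 - 10 = 74.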
+ + +# ============================================================================== + + +class VarCorrectionKeepDimModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.var(x, dim=[0, 1], correction=None, keepdim=True) + + +@register_test_case(module_factory=lambda: VarCorrectionKeepDimModule()) +def VarCorrectionKeepDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 7)) + + +# ============================================================================== + + +class VarCorrectionNoneModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.var(x, dim=None, correction=None) + + +@register_test_case(module_factory=lambda: VarCorrectionNoneModule()) +def VarCorrectionNoneModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 7)) + + +# ============================================================================== + + +class VarCorrectionEmptyDimModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.var(x, dim=[], correction=2) + + +@register_test_case(module_factory=lambda: VarCorrectionEmptyDimModule()) +def VarCorrectionEmptyDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 7)) + + +# ============================================================================== + + +class VarCorrectionLargeInputModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.var(x, dim=[2, 3], correction=2) + + +@register_test_case(module_factory=lambda: VarCorrectionLargeInputModule()) +def VarCorrectionLargeInputModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 1024, 8192)) diff --git a/test/Conversion/TorchToStd/basic.mlir b/test/Conversion/TorchToStd/basic.mlir index f45388e4e08..755c34f69d5 100644 --- a/test/Conversion/TorchToStd/basic.mlir +++ b/test/Conversion/TorchToStd/basic.mlir @@ -66,6 +66,19 @@ func.func @torch.aten.gt.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.boo return %0 : !torch.bool } +// CHECK-LABEL: func.func @torch.aten.ge.int( +// CHECK-SAME: %[[LHS:.*]]: !torch.int, +// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { +// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] +// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] +// CHECK: %[[CMP:.*]] = arith.cmpi sge, %[[LHS_I64]], %[[RHS_I64]] : i64 +// CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] +// CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool +func.func @torch.aten.ge.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool { + %0 = torch.aten.ge.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.bool + return %0 : !torch.bool +} + // CHECK-LABEL: func.func @torch.vtensor.literal() -> !torch.vtensor<[],f32> { // CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : tensor // CHECK: %[[VTENSOR:.*]] = torch_c.from_builtin_tensor %[[CST]] : tensor -> !torch.vtensor<[],f32> diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index 123e01d1ba5..6eee2f29244 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir 
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -233,28 +233,36 @@ func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtens // CHECK: %[[CST1:.*]] = torch.constant.int 1 // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[KEEPDIM:.*]] = torch.constant.bool false +// CHECK: %[[CST7:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64> // CHECK: %[[DTYPE:.*]] = torch.constant.none // CHECK: %[[CST_TRUE_0:.*]] = torch.constant.bool true -// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[CST_TRUE_0]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f32> +// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[CST_TRUE_0]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64> // CHECK: %[[CST1_0:.*]] = torch.constant.int 1 -// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32> +// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64> // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[1,1],f32>, !torch.float -> !torch.vtensor<[?,?],f32> -// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> -// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f64> // CHECK: %[[CST1_1:.*]] = 
torch.constant.int 1 -// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[CST1_1]], %[[DIM0_0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[CST1_2:.*]] = torch.constant.int 1 // CHECK: %[[NUM_ELEMENTS_0_SUB_1:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST1_2]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> -// CHECK: return %[[UNBIASED_VAR]] : !torch.vtensor<[],f32> +// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64> +// CHECK: %[[CST6:.*]] = torch.constant.int 6 +// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false +// CHECK: %[[NONE_0:.*]] = torch.constant.none +// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[UNBIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[],f32> func.func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> { %true = torch.constant.bool true %0 = torch.aten.var %arg0, %true: !torch.vtensor<[?,?],f32>, !torch.bool -> !torch.vtensor<[],f32> @@ -269,26 +277,34 @@ func.func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v // CHECK: %[[CST1:.*]] = torch.constant.int 1 // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[KEEPDIM:.*]] = torch.constant.bool false +// CHECK: %[[CST7:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64> // CHECK: %[[DTYPE:.*]] = torch.constant.none // CHECK: %[[CST_TRUE:.*]] = torch.constant.bool true -// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[CST_TRUE]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f32> +// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[CST_TRUE]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64> // CHECK: %[[CST1_0:.*]] = torch.constant.int 1 -// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[CST1_0]], 
%[[DIM0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32> +// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64> // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[1,1],f32>, !torch.float -> !torch.vtensor<[?,?],f32> -// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> -// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f64> // CHECK: %[[CST1_1:.*]] = torch.constant.int 1 -// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[CST1_1]], %[[DIM0_0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> -// CHECK: return %[[BIASED_VAR]] : !torch.vtensor<[],f32> +// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64> +// CHECK: %[[CST6:.*]] = torch.constant.int 6 +// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false +// CHECK: %[[NONE_0:.*]] = torch.constant.none +// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[BIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[],f32> 
func.func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> { %false = torch.constant.bool false %0 = torch.aten.var %arg0, %false: !torch.vtensor<[?,?],f32>, !torch.bool -> !torch.vtensor<[],f32> @@ -303,28 +319,36 @@ func.func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte // CHECK: %[[CST1:.*]] = torch.constant.int 1 // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[KEEPDIM:.*]] = torch.constant.bool false +// CHECK: %[[CST7:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64> // CHECK: %[[DTYPE:.*]] = torch.constant.none // CHECK: %[[CST_TRUE_0:.*]] = torch.constant.bool true -// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[CST_TRUE_0]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f32> +// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[CST_TRUE_0]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64> // CHECK: %[[CST1_0:.*]] = torch.constant.int 1 -// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32> +// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64> // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[1,1],f32>, !torch.float -> !torch.vtensor<[?,?],f32> -// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> -// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList 
%[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
 // CHECK: %[[CST1_1:.*]] = torch.constant.int 1
-// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
 // CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[CST1_1]], %[[DIM0_0]] : !torch.int, !torch.int -> !torch.int
-// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
 // CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int
 // CHECK: %[[CST1_2:.*]] = torch.constant.int 1
 // CHECK: %[[NUM_ELEMENTS_0_SUB_1:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST1_2]] : !torch.int, !torch.int -> !torch.int
-// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
-// CHECK: %[[UNBIASED_STD:.*]] = torch.aten.sqrt %[[UNBIASED_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
+// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64>
+// CHECK: %[[CST6:.*]] = torch.constant.int 6
+// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
+// CHECK: %[[NONE_0:.*]] = torch.constant.none
+// CHECK: %[[DOWNCAST_VAR:.*]] = torch.aten.to.dtype %[[UNBIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
+// CHECK: %[[UNBIASED_STD:.*]] = torch.aten.sqrt %[[DOWNCAST_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
 // CHECK: return %[[UNBIASED_STD]] : !torch.vtensor<[],f32>
 func.func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> {
   %true = torch.constant.bool true
@@ -340,26 +364,34 @@ func.func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v
 // CHECK: %[[CST1:.*]] = torch.constant.int 1
 // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list
 // CHECK: %[[KEEPDIM:.*]] = torch.constant.bool false
+// CHECK: %[[CST7:.*]] = torch.constant.int 7
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64>
 // CHECK: %[[DTYPE:.*]] = torch.constant.none
 // CHECK: %[[CST_TRUE:.*]] = torch.constant.bool true
-// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[CST_TRUE]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f32>
+// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[CST_TRUE]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64>
 // CHECK: %[[CST1_0:.*]] = torch.constant.int 1
-// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
 // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM0]] : !torch.int, !torch.int -> !torch.int
-// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
 // CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int
-// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32>
+// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64>
 // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
-// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[1,1],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
-// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
-// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
+// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64>
+// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64>
+// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f64>
 // CHECK: %[[CST1_1:.*]] = torch.constant.int 1
-// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
 // CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[CST1_1]], %[[DIM0_0]] : !torch.int, !torch.int -> !torch.int
-// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
 // CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int
-// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
-// CHECK: %[[BIASED_STD:.*]] = torch.aten.sqrt %[[BIASED_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
+// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64>
+// CHECK: %[[CST6:.*]] = torch.constant.int 6
+// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
+// CHECK: %[[NONE_0:.*]] = torch.constant.none
+// CHECK: %[[DOWNCAST_VAR:.*]] = torch.aten.to.dtype %[[BIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
+// CHECK: %[[BIASED_STD:.*]] = torch.aten.sqrt %[[DOWNCAST_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
 // CHECK: return %[[BIASED_STD]] : !torch.vtensor<[],f32>
 func.func @torch.aten.std$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> {
   %false = torch.constant.bool false
@@ -1165,22 +1197,30 @@ func.func @torch.aten.select_scatter(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !t
 // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST2]] : (!torch.int) -> !torch.list
 // CHECK: %[[UNBIASED:.*]] = torch.constant.bool false
 // CHECK: %[[KEEPDIM:.*]] = torch.constant.bool true
+// CHECK: %[[CST7:.*]] = torch.constant.int 7
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
 // CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4,7],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,7],f64>
+// CHECK: %[[NONE_0:.*]] = torch.constant.none
 // CHECK: %[[KEEPDIM_0:.*]] = torch.constant.bool true
-// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[KEEPDIM_0]], %[[NONE]] : !torch.vtensor<[3,4,7],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
+// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[KEEPDIM_0]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
 // CHECK: %[[CST1:.*]] = torch.constant.int 1
-// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f32>, !torch.int -> !torch.int
+// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int
 // CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[CST1]], %[[DIM2]] : !torch.int, !torch.int -> !torch.int
-// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f32>, !torch.int -> !torch.vtensor<[3,4,1],f32>
+// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
 // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
-// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f32>, !torch.vtensor<[3,4,1],f32>, !torch.float -> !torch.vtensor<[3,4,7],f32>
-// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,7],f32>, !torch.vtensor<[3,4,7],f32> -> !torch.vtensor<[3,4,7],f32>
-// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE]] : !torch.vtensor<[3,4,7],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
+// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,7],f64>
+// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,7],f64> -> !torch.vtensor<[3,4,7],f64>
+// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
 // CHECK: %[[CST1_0:.*]] = torch.constant.int 1
-// CHECK: %[[DIM2_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST2]] : !torch.vtensor<[3,4,7],f32>, !torch.int -> !torch.int
+// CHECK: %[[DIM2_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int
 // CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM2_0]] : !torch.int, !torch.int -> !torch.int
-// CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[3,4,1],f32>, !torch.int -> !torch.vtensor<[3,4,1],f32>
-// CHECK: return %[[VAR]] : !torch.vtensor<[3,4,1],f32>
+// CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
+// CHECK: %[[CST6:.*]] = torch.constant.int 6
+// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
+// CHECK: %[[NONE_1:.*]] = torch.constant.none
+// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
+// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[3,4,1],f32>
 func.func @torch.aten.var.dim(%arg0: !torch.vtensor<[3,4,7],f32>) -> !torch.vtensor<[3,4,1],f32> {
   %int2 = torch.constant.int 2
   %dims = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list
@@ -1208,3 +1248,46 @@ func.func @torch.aten.softplus(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -
   %ret = torch.aten.softplus %t, %dim, %int0: !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor<[2,3],f32>
   return %ret : !torch.tensor<[2,3],f32>
 }
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.var.correction(
+// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[3,4,7],f32>) -> !torch.vtensor<[3,4,1],f32> {
+// CHECK: %[[CST2:.*]] = torch.constant.int 2
+// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST2]] : (!torch.int) -> !torch.list
+// CHECK: %[[KEEPDIM:.*]] = torch.constant.bool true
+// CHECK: %[[CST7:.*]] = torch.constant.int 7
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4,7],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,7],f64>
+// CHECK: %[[NONE_0:.*]] = torch.constant.none
+// CHECK: %[[KEEPDIM_0:.*]] = torch.constant.bool true
+// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[KEEPDIM_0]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
+// CHECK: %[[CST1:.*]] = torch.constant.int 1
+// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int
+// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[CST1]], %[[DIM2]] : !torch.int, !torch.int -> !torch.int
+// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
+// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00
+// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,7],f64>
+// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,7],f64> -> !torch.vtensor<[3,4,7],f64>
+// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64>
+// CHECK: %[[CST1_0:.*]] = torch.constant.int 1
+// CHECK: %[[DIM2_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int
+// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM2_0]] : !torch.int, !torch.int -> !torch.int
+// CHECK: %[[CST2_0:.*]] = torch.constant.int 2
+// CHECK: %[[NUM_ELEMENTS_PLUS_ONE:.*]] = torch.aten.add.int %[[NUM_ELEMENTS_0]], %[[CST1_0]] : !torch.int, !torch.int -> !torch.int
+// CHECK: %[[PRED:.*]] = torch.aten.ge.int %[[NUM_ELEMENTS_PLUS_ONE]], %[[CST2_0]] : !torch.int, !torch.int -> !torch.bool
+// CHECK: torch.runtime.assert %[[PRED]], "correction value should be less than or equal to productDimSize + 1"
+// CHECK: %[[NUM_ELEMENTS_MINUS_CORRECTION:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST2_0]] : !torch.int, !torch.int -> !torch.int
+// CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_MINUS_CORRECTION]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64>
+// CHECK: %[[CST6:.*]] = torch.constant.int 6
+// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false
+// CHECK: %[[NONE_1:.*]] = torch.constant.none
+// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32>
+// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[3,4,1],f32>
+func.func @torch.aten.var.correction(%arg0: !torch.vtensor<[3,4,7],f32>) -> !torch.vtensor<[3,4,1],f32> {
+  %int2 = torch.constant.int 2
+  %dims = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list
+  %keepdim = torch.constant.bool true
+  %0 = torch.aten.var.correction %arg0, %dims, %int2, %keepdim: !torch.vtensor<[3,4,7],f32>, !torch.list, !torch.int, !torch.bool -> !torch.vtensor<[3,4,1],f32>
+  return %0 : !torch.vtensor<[3,4,1],f32>
+}
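
For reference, the CHECK lines above encode the following computation: upcast the input to f64, take the mean over the reduced dims, sum the squared deviations, divide by (numElements - correction) after asserting correction <= numElements + 1, then cast back to the input dtype. Below is a minimal PyTorch-level sketch of that semantics; it is illustrative only (the function name and structure are not part of this patch).

import torch

def var_correction_reference(x, dims, correction, keepdim):
    # Compute in f64 and cast back, mirroring the UPCAST_INPUT / DOWNCAST_RESULT checks.
    x64 = x.to(torch.float64)
    num_elements = 1
    for d in dims:
        num_elements *= x64.shape[d]
    mean = x64.sum(dim=dims, keepdim=True) / num_elements
    squares = (x64 - mean) * (x64 - mean)
    sum_squares = squares.sum(dim=dims, keepdim=keepdim)
    # Mirrors the torch.runtime.assert emitted by the decomposition.
    assert correction <= num_elements + 1, \
        "correction value should be less than or equal to productDimSize + 1"
    var = sum_squares / (num_elements - correction)
    return var.to(x.dtype)

# Matches the test above: a [3,4,7] f32 input reduced over dim 2 with
# correction=2 and keepdim=True yields a [3,4,1] f32 result.
print(var_correction_reference(torch.randn(3, 4, 7), dims=[2], correction=2, keepdim=True).shape)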