From 107f7afa5c97ae665743c44146476ad04bfbb9bd Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Fri, 22 Jul 2022 18:12:14 +0530 Subject: [PATCH] [MLIR][TORCH] Add decomposition for aten.var.correction op This commit adds the decomposition for `aten.var.correction` op. Signed-Off By: Vivek Khandelwal { + let summary = "Generated op for `aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalListOfTorchIntType:$dim, + AnyTorchOptionalIntType:$correction, + Torch_BoolType:$keepdim + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenVarCorrectionOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenVarCorrectionOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenNllLossForwardOp : Torch_Op<"aten.nll_loss_forward", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index c978e3ef4505..22553901a8af 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -229,8 +229,7 @@ class DecomposeAtenSelectIntOp : public OpRewritePattern { } // namespace namespace { -class DecomposeAtenZeroOp - : public OpRewritePattern { +class DecomposeAtenZeroOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenZeroOp op, @@ -705,7 +704,8 @@ class DecomposeAtenTOp : public OpRewritePattern { // unsqueezed_sizes += [1, s] // expanded_sizes += [m, s] // reshaped_sizes += [m * s] -// return self.view(unsqueezed_sizes).expand(expanded_sizes).view(reshaped_sizes) +// return +// self.view(unsqueezed_sizes).expand(expanded_sizes).view(reshaped_sizes) // namespace { class DecomposeAtenRepeatOp : public OpRewritePattern { @@ -754,7 +754,8 @@ class DecomposeAtenRepeatOp : public OpRewritePattern { assert(leadingRank >= 0 && "leadingRank should greater than 0"); for (size_t i = 0; i < leadingRank; ++i) { insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, ArrayRef{one}); - insertDimSizes(expandedSizes, expandedIntSizes, ArrayRef{repeats[i]}); + insertDimSizes(expandedSizes, expandedIntSizes, + ArrayRef{repeats[i]}); reshapedSizes.push_back(repeats[i]); } @@ -772,18 +773,20 @@ class DecomposeAtenRepeatOp : public OpRewritePattern { loc, rewriter.getI64IntegerAttr(selfShape[i])); } - insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, ArrayRef{one, dimSize}); - insertDimSizes(expandedSizes, expandedIntSizes, ArrayRef{scale, dimSize}); + insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, + ArrayRef{one, dimSize}); + insertDimSizes(expandedSizes, expandedIntSizes, + ArrayRef{scale, dimSize}); Value scaledSize = rewriter.create(loc, dimSize, scale); reshapedSizes.push_back(scaledSize); } Type dtype = self.getType().cast().getDtype(); - Type unsqueezedType = - ValueTensorType::get(context, llvm::makeArrayRef(unsqueezedIntSizes), dtype); - Type expandedType = - ValueTensorType::get(context, llvm::makeArrayRef(expandedIntSizes), dtype); + Type unsqueezedType = ValueTensorType::get( + context, llvm::makeArrayRef(unsqueezedIntSizes), dtype); + Type expandedType = ValueTensorType::get( + context, llvm::makeArrayRef(expandedIntSizes), dtype); auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext())); Value unsqueezedDims = @@ -792,8 +795,8 @@ class DecomposeAtenRepeatOp : public OpRewritePattern { rewriter.create(loc, listType, expandedSizes); Value reshapedDims = rewriter.create(loc, listType, reshapedSizes); - auto reshaped = - rewriter.create(loc, unsqueezedType, op.self(), unsqueezedDims); + auto reshaped = rewriter.create(loc, unsqueezedType, op.self(), + unsqueezedDims); auto expanded = rewriter.create(loc, expandedType, reshaped, expandedDims); @@ -1184,14 +1187,14 @@ class DecomposeAtenSoftplusOp : public OpRewritePattern { Value input = op.self(); BaseTensorType inputType = input.getType().cast(); - Value inputTimesBeta = rewriter.create( - loc, inputType, input, op.beta()); + Value inputTimesBeta = + rewriter.create(loc, inputType, input, op.beta()); // out = log1p(exp(input * beta)) / beta Value exp = rewriter.create(loc, inputType, inputTimesBeta); Value log1p = rewriter.create(loc, inputType, exp); - Value out = rewriter.create( - loc, inputType, log1p, op.beta()); + Value out = + rewriter.create(loc, inputType, log1p, op.beta()); // Select where x * beta > threshold auto boolResType = inputType.getWithSizesAndDtype(inputType.getSizes(), @@ -1199,8 +1202,8 @@ class DecomposeAtenSoftplusOp : public OpRewritePattern { Value condition = rewriter.create( loc, boolResType, inputTimesBeta, op.threshold()); - rewriter.replaceOpWithNewOp( - op, op.getType(), condition, input, out); + rewriter.replaceOpWithNewOp(op, op.getType(), condition, + input, out); return success(); } }; @@ -2138,6 +2141,107 @@ class DecomposeAtenNumpyTOp : public OpRewritePattern { }; } // namespace +template +static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, + bool unbiased, int64_t correction) { + Location loc = op.getLoc(); + Value self = op.self(); + Value dimList = op.dim(); + Value keepDim = op.keepdim(); + BaseTensorType inputTensorTy = self.getType().cast(); + Type outputType = op.getType(); + BaseTensorType outputTensorType = outputType.cast(); + Type newOutputType = outputTensorType.getWithSizesAndDtype( + outputTensorType.getSizes(), rewriter.getF64Type()); + if (!inputTensorTy.hasDtype() || + !inputTensorTy.getDtype().isa()) { + return rewriter.notifyMatchFailure( + op, "support floating-point type input only"); + } + + // Upcasting the input tensor to `F64` dtype for higher precision during the + // computation of the result. + if (inputTensorTy.getDtype().getIntOrFloatBitWidth() != 64) { + self = convertTensorToDtype(rewriter, loc, self, rewriter.getF64Type()); + inputTensorTy = self.getType().cast(); + } + + unsigned inputRank = getTensorRank(self); + SmallVector dimListElements; + bool isNoneOrEmpty = true; + if (!dimList.getType().template isa()) { + if (!getListConstructElements(dimList, dimListElements)) + return rewriter.notifyMatchFailure( + op, "expect dimList to be constructed from list construct"); + if (dimListElements.size() != 0 || inputRank == 0) + isNoneOrEmpty = false; + } + if (isNoneOrEmpty) { + for (unsigned i = 0; i < inputRank; i++) + dimListElements.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(i))); + dimList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), + dimListElements); + } + Type meanDimResultType = inputTensorTy; + for (unsigned i = 0; i < dimListElements.size(); i++) + meanDimResultType = computeReductionType( + rewriter, op, meanDimResultType.cast(), + dimListElements[i], + /*keepDim=*/true); + + Value constantNone = rewriter.create(loc); + Value constantTrue = rewriter.create(loc, true); + Value meanAlongDims = rewriter.create( + loc, meanDimResultType, self, dimList, /*keepDim=*/constantTrue, + /*dtype=*/constantNone); + Value subMean = + createTensorSub(rewriter, loc, inputTensorTy, self, meanAlongDims); + Value square = rewriter.create(loc, inputTensorTy, subMean); + + if (!unbiased) { + Value result = rewriter.create( + loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone); + result = convertTensorToDtype(rewriter, loc, result, + outputTensorType.getDtype()); + rewriter.replaceOp(op, result); + return success(); + } + // Divide the square sum by productDimSize - correction. + Value squareSum = rewriter.create( + loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone); + + // `productDimSize` is product of sizes of dimensions to be reduced. + Value constantOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value productDimSize = constantOne; + for (Value dim : dimListElements) { + Value dimSize = rewriter.create(loc, self, dim); + productDimSize = + rewriter.create(loc, productDimSize, dimSize); + } + Value cstCorrection = rewriter.create( + loc, rewriter.getI64IntegerAttr(correction)); + // The `correction` value should be less than or equal to `productDimSize + + // 1`. + Value productDimSizePlusOne = + rewriter.create(loc, productDimSize, constantOne); + Value cond = + rewriter.create(loc, productDimSizePlusOne, cstCorrection); + rewriter.create( + loc, cond, + "correction value should be less than or equal to productDimSize + 1"); + Value productDimSizeSubCorrection = + rewriter.create(loc, productDimSize, cstCorrection); + Value result = rewriter.create(loc, newOutputType, squareSum, + productDimSizeSubCorrection); + result = + convertTensorToDtype(rewriter, loc, result, outputTensorType.getDtype()); + rewriter.replaceOp(op, result); + return success(); +} + // Decompose aten.var(x, dims) into: // sub = aten.sub(x, aten.mean(x, dims)) // square = aten.square(sub) @@ -2151,70 +2255,44 @@ class DecomposeAtenVarDimOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenVarDimOp op, PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value self = op.self(); - Value dimList = op.dim(); - Value keepDim = op.keepdim(); - Type outputType = op.getType(); - BaseTensorType inputTensorTy = self.getType().cast(); - if (!inputTensorTy.hasDtype() || - !inputTensorTy.getDtype().isa()) { - return rewriter.notifyMatchFailure(op, - "support floating type input only"); - } - - auto dimListConstruct = dimList.getDefiningOp(); - if (!dimListConstruct) { - return rewriter.notifyMatchFailure( - op, "expect dimList to be constructed from list construct"); - } - bool unbiased; if (!matchPattern(op.unbiased(), m_TorchConstantBool(&unbiased))) { return rewriter.notifyMatchFailure( op, "Only support constant unbiased for aten.var"); } + int64_t correction = unbiased ? 1 : 0; + if (failed(calculateVariance(op, rewriter, unbiased, + correction))) + return rewriter.notifyMatchFailure(op, "invalid variance parameters"); + return success(); + } +}; +} // namespace - SmallVector dimListElements = dimListConstruct.elements(); - Type meanDimResultType = inputTensorTy; - for (unsigned i = 0; i < dimListElements.size(); i++) - meanDimResultType = computeReductionType( - rewriter, op, meanDimResultType.cast(), - dimListElements[i], - /*keepDim=*/true); - - Value constantNone = rewriter.create(loc); - Value constantTrue = rewriter.create(loc, true); - Value meanAlongDims = rewriter.create( - loc, meanDimResultType, self, dimList, /*keepDim=*/constantTrue, - /*dtype=*/constantNone); - Value subMean = - createTensorSub(rewriter, loc, inputTensorTy, self, meanAlongDims); - Value square = rewriter.create(loc, inputTensorTy, subMean); - if (unbiased) { - // Bessel’s correction is used. Divide the square sum by - // productDimSize-1. - Value squareSum = rewriter.create( - loc, outputType, square, dimList, keepDim, /*dtype=*/constantNone); - - // `productDimSize` is product of sizes of dimensions to be reduced. - Value productDimSize = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - for (Value dim : dimListConstruct.elements()) { - Value dimSize = rewriter.create(loc, self, dim); - productDimSize = - rewriter.create(loc, productDimSize, dimSize); - } - Value constantOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - Value productDimSizeSubOne = - rewriter.create(loc, productDimSize, constantOne); - rewriter.replaceOpWithNewOp(op, outputType, squareSum, - productDimSizeSubOne); +// Decompose aten.var(x, dims) into: +// sub = aten.sub(x, aten.mean(x, dims)) +// square = aten.square(sub) +// out = aten.sum(square, dims) / (productDimSize - correction) +namespace { +class DecomposeAtenVarCorrectionOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenVarCorrectionOp op, + PatternRewriter &rewriter) const override { + int64_t correction; + if (!op.correction().getType().isa()) { + if (!matchPattern(op.correction(), m_TorchConstantInt(&correction))) + return rewriter.notifyMatchFailure( + op, "Only support constant int correction for aten.var"); } else { - rewriter.replaceOpWithNewOp( - op, outputType, square, dimList, keepDim, /*dtype=*/constantNone); + // The default value in case of `correction` being None is 1. + correction = 1; } + bool unbiased = correction == 0 ? false : true; + if (failed(calculateVariance(op, rewriter, unbiased, + correction))) + return rewriter.notifyMatchFailure(op, "invalid variance parameters"); return success(); } }; @@ -2426,6 +2504,8 @@ class DecomposeComplexOpsPass target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 1eff5b7cab3a..d54eb41b2280 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -931,7 +931,7 @@ void TypeAnalysis::visitOperation(Operation *op, Type dtype = operands[0]->getValue().dtype; visitReductionAlongAllDimsOp(max, dtype, operands); return; - } else if (isa(op)) { + } else if (isa(op)) { auto input = operands[0]->getValue(); visitReductionAlongAllDimsOp(op, input.dtype, operands); return; diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp index ad52744936c2..7a03168239ed 100644 --- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp @@ -5570,6 +5570,36 @@ module { %1 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list return %1 : !torch.list } + func.func @"__torch_mlir_shape_fn.aten.var.correction"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.list { + %true = torch.constant.bool true + %none = torch.constant.none + %int0 = torch.constant.int 0 + %0 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool + %1 = torch.prim.If %0 -> (!torch.bool) { + torch.prim.If.yield %true : !torch.bool + } else { + %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list + %6 = torch.aten.len.t %5 : !torch.list -> !torch.int + %7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool + torch.prim.If.yield %7 : !torch.bool + } + %2 = torch.prim.If %1 -> (!torch.list) { + %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int + %6 = torch.prim.ListConstruct : () -> !torch.list + torch.prim.Loop %5, %true, init() { + ^bb0(%arg4: !torch.int): + %7 = torch.aten.append.t %6, %arg4 : !torch.list, !torch.int -> !torch.list + torch.prim.Loop.condition %true, iter() + } : (!torch.int, !torch.bool) -> () + torch.prim.If.yield %6 : !torch.list + } else { + %5 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list + torch.prim.If.yield %5 : !torch.list + } + %3 = torch.derefine %none : !torch.none to !torch.any + %4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list, !torch.list, !torch.bool, !torch.any) -> !torch.list + return %4 : !torch.list + } func.func @"__torch_mlir_shape_fn.aten.std"(%arg0: !torch.list, %arg1: !torch.bool) -> !torch.list { %0 = torch.prim.ListConstruct : () -> !torch.list return %0 : !torch.list diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py index 1aa13ef5e45d..85cadbf12c84 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py @@ -492,6 +492,11 @@ def aten〇var(self: List[int], unbiased: bool = True) -> List[int]: def aten〇var〇dim(self: List[int], dim: List[int], unbiased: bool = True, keepdim: bool = False) -> List[int]: return upstream_shape_functions.mean_dim(self, dim, keepdim, None) +def aten〇var〇correction(self: List[int], dim: Optional[List[int]], correction: Optional[int], keepdim: bool = False) -> List[int]: + if dim is None or len(dim)==0: + dim = list(range(len(self))) + return upstream_shape_functions.mean_dim(self, dim, keepdim, None) + def aten〇std(self: List[int], unbiased: bool = True) -> List[int]: return [] diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 48ee9722f1af..f983b08aa228 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -383,6 +383,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::std : (Tensor, bool) -> (Tensor)") emit("aten::var : (Tensor, bool) -> (Tensor)") emit("aten::var.dim : (Tensor, int[], bool, bool) -> (Tensor)") + emit("aten::var.correction : (Tensor, int[]?, int?, bool) -> (Tensor)") emit("aten::nll_loss_forward : (Tensor, Tensor, Tensor?, int, int) -> (Tensor, Tensor)") emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)") emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)") diff --git a/python/torch_mlir_e2e_test/test_suite/stats.py b/python/torch_mlir_e2e_test/test_suite/stats.py index 550b96fe0547..01c07d87ab88 100644 --- a/python/torch_mlir_e2e_test/test_suite/stats.py +++ b/python/torch_mlir_e2e_test/test_suite/stats.py @@ -427,3 +427,160 @@ def forward(self, x): @register_test_case(module_factory=lambda: VarDimKeepDimFalseModule()) def VarDimKeepDimFalseModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + + +# ============================================================================== + + +class VarCorrectionModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.var(x, dim=None, correction=2) + + +@register_test_case(module_factory=lambda: VarCorrectionModule()) +def VarCorrectionModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 7)) + + +# ============================================================================== + + +class VarCorrectionSingleDimReduceModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.var(x, dim=[1], correction=1) + + +@register_test_case(module_factory=lambda: VarCorrectionSingleDimReduceModule()) +def VarCorrectionSingleDimReduceModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 7)) + + +# ============================================================================== + + +class VarCorrectionAllDimReduceModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.var(x, + dim=[0, 1, 2], + correction=10, + keepdim=False) + + +@register_test_case(module_factory=lambda: VarCorrectionAllDimReduceModule()) +def VarCorrectionAllDimReduceModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 7)) + + +# ============================================================================== + + +class VarCorrectionKeepDimModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.var(x, dim=[0, 1], correction=None, keepdim=True) + + +@register_test_case(module_factory=lambda: VarCorrectionKeepDimModule()) +def VarCorrectionKeepDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 7)) + + +# ============================================================================== + + +class VarCorrectionNoneModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.var(x, dim=None, correction=None) + + +@register_test_case(module_factory=lambda: VarCorrectionNoneModule()) +def VarCorrectionNoneModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 7)) + + +# ============================================================================== + + +class VarCorrectionEmptyDimModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.var(x, dim=[], correction=2) + + +@register_test_case(module_factory=lambda: VarCorrectionEmptyDimModule()) +def VarCorrectionEmptyDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 7)) + + +# ============================================================================== + + +class VarCorrectionLargeInputModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.var(x, dim=[2, 3], correction=2) + + +@register_test_case(module_factory=lambda: VarCorrectionLargeInputModule()) +def VarCorrectionLargeInputModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 1024, 8192)) diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index 123e01d1ba54..6eee2f29244d 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -233,28 +233,36 @@ func.func @torch.aten.square(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtens // CHECK: %[[CST1:.*]] = torch.constant.int 1 // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[KEEPDIM:.*]] = torch.constant.bool false +// CHECK: %[[CST7:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64> // CHECK: %[[DTYPE:.*]] = torch.constant.none // CHECK: %[[CST_TRUE_0:.*]] = torch.constant.bool true -// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[CST_TRUE_0]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f32> +// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[CST_TRUE_0]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64> // CHECK: %[[CST1_0:.*]] = torch.constant.int 1 -// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32> +// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64> // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[1,1],f32>, !torch.float -> !torch.vtensor<[?,?],f32> -// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> -// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f64> // CHECK: %[[CST1_1:.*]] = torch.constant.int 1 -// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[CST1_1]], %[[DIM0_0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[CST1_2:.*]] = torch.constant.int 1 // CHECK: %[[NUM_ELEMENTS_0_SUB_1:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST1_2]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> -// CHECK: return %[[UNBIASED_VAR]] : !torch.vtensor<[],f32> +// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64> +// CHECK: %[[CST6:.*]] = torch.constant.int 6 +// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false +// CHECK: %[[NONE_0:.*]] = torch.constant.none +// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[UNBIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[],f32> func.func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> { %true = torch.constant.bool true %0 = torch.aten.var %arg0, %true: !torch.vtensor<[?,?],f32>, !torch.bool -> !torch.vtensor<[],f32> @@ -269,26 +277,34 @@ func.func @torch.aten.var$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v // CHECK: %[[CST1:.*]] = torch.constant.int 1 // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[KEEPDIM:.*]] = torch.constant.bool false +// CHECK: %[[CST7:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64> // CHECK: %[[DTYPE:.*]] = torch.constant.none // CHECK: %[[CST_TRUE:.*]] = torch.constant.bool true -// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[CST_TRUE]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f32> +// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[CST_TRUE]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64> // CHECK: %[[CST1_0:.*]] = torch.constant.int 1 -// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32> +// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64> // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[1,1],f32>, !torch.float -> !torch.vtensor<[?,?],f32> -// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> -// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f64> // CHECK: %[[CST1_1:.*]] = torch.constant.int 1 -// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[CST1_1]], %[[DIM0_0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> -// CHECK: return %[[BIASED_VAR]] : !torch.vtensor<[],f32> +// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64> +// CHECK: %[[CST6:.*]] = torch.constant.int 6 +// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false +// CHECK: %[[NONE_0:.*]] = torch.constant.none +// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[BIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[],f32> func.func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> { %false = torch.constant.bool false %0 = torch.aten.var %arg0, %false: !torch.vtensor<[?,?],f32>, !torch.bool -> !torch.vtensor<[],f32> @@ -303,28 +319,36 @@ func.func @torch.aten.var$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte // CHECK: %[[CST1:.*]] = torch.constant.int 1 // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[KEEPDIM:.*]] = torch.constant.bool false +// CHECK: %[[CST7:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64> // CHECK: %[[DTYPE:.*]] = torch.constant.none // CHECK: %[[CST_TRUE_0:.*]] = torch.constant.bool true -// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[CST_TRUE_0]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f32> +// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[CST_TRUE_0]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64> // CHECK: %[[CST1_0:.*]] = torch.constant.int 1 -// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32> +// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64> // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[1,1],f32>, !torch.float -> !torch.vtensor<[?,?],f32> -// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> -// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f64> // CHECK: %[[CST1_1:.*]] = torch.constant.int 1 -// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[CST1_1]], %[[DIM0_0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[CST1_2:.*]] = torch.constant.int 1 // CHECK: %[[NUM_ELEMENTS_0_SUB_1:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST1_2]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> -// CHECK: %[[UNBIASED_STD:.*]] = torch.aten.sqrt %[[UNBIASED_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> +// CHECK: %[[UNBIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0_SUB_1]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64> +// CHECK: %[[CST6:.*]] = torch.constant.int 6 +// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false +// CHECK: %[[NONE_0:.*]] = torch.constant.none +// CHECK: %[[DOWNCAST_VAR:.*]] = torch.aten.to.dtype %[[UNBIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: %[[UNBIASED_STD:.*]] = torch.aten.sqrt %[[DOWNCAST_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> // CHECK: return %[[UNBIASED_STD]] : !torch.vtensor<[],f32> func.func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> { %true = torch.constant.bool true @@ -340,26 +364,34 @@ func.func @torch.aten.std$unbiased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.v // CHECK: %[[CST1:.*]] = torch.constant.int 1 // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST1]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[KEEPDIM:.*]] = torch.constant.bool false +// CHECK: %[[CST7:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f64> // CHECK: %[[DTYPE:.*]] = torch.constant.none // CHECK: %[[CST_TRUE:.*]] = torch.constant.bool true -// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[CST_TRUE]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f32> +// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[CST_TRUE]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1],f64> // CHECK: %[[CST1_0:.*]] = torch.constant.int 1 -// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[MUL]], %[[DIM1]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32> +// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[1,1],f64>, !torch.int -> !torch.vtensor<[1,1],f64> // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[1,1],f32>, !torch.float -> !torch.vtensor<[?,?],f32> -// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> -// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[1,1],f64>, !torch.float -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64> -> !torch.vtensor<[?,?],f64> +// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[DTYPE]] : !torch.vtensor<[?,?],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[],f64> // CHECK: %[[CST1_1:.*]] = torch.constant.int 1 -// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM0_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[CST1_1]], %[[DIM0_0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM1_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[MUL_0]], %[[DIM1_0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> -// CHECK: %[[BIASED_STD:.*]] = torch.aten.sqrt %[[BIASED_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> +// CHECK: %[[BIASED_VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64> +// CHECK: %[[CST6:.*]] = torch.constant.int 6 +// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false +// CHECK: %[[NONE_0:.*]] = torch.constant.none +// CHECK: %[[DOWNCAST_VAR:.*]] = torch.aten.to.dtype %[[BIASED_VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32> +// CHECK: %[[BIASED_STD:.*]] = torch.aten.sqrt %[[DOWNCAST_VAR]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> // CHECK: return %[[BIASED_STD]] : !torch.vtensor<[],f32> func.func @torch.aten.std$biased(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[],f32> { %false = torch.constant.bool false @@ -1165,22 +1197,30 @@ func.func @torch.aten.select_scatter(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !t // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST2]] : (!torch.int) -> !torch.list // CHECK: %[[UNBIASED:.*]] = torch.constant.bool false // CHECK: %[[KEEPDIM:.*]] = torch.constant.bool true +// CHECK: %[[CST7:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4,7],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,7],f64> +// CHECK: %[[NONE_0:.*]] = torch.constant.none // CHECK: %[[KEEPDIM_0:.*]] = torch.constant.bool true -// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[INPUT]], %[[DIMS]], %[[KEEPDIM_0]], %[[NONE]] : !torch.vtensor<[3,4,7],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32> +// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[KEEPDIM_0]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64> // CHECK: %[[CST1:.*]] = torch.constant.int 1 -// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[CST1]], %[[DIM2]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f32>, !torch.int -> !torch.vtensor<[3,4,1],f32> +// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64> // CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 -// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f32>, !torch.vtensor<[3,4,1],f32>, !torch.float -> !torch.vtensor<[3,4,7],f32> -// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,7],f32>, !torch.vtensor<[3,4,7],f32> -> !torch.vtensor<[3,4,7],f32> -// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE]] : !torch.vtensor<[3,4,7],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32> +// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,7],f64> +// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,7],f64> -> !torch.vtensor<[3,4,7],f64> +// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64> // CHECK: %[[CST1_0:.*]] = torch.constant.int 1 -// CHECK: %[[DIM2_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST2]] : !torch.vtensor<[3,4,7],f32>, !torch.int -> !torch.int +// CHECK: %[[DIM2_0:.*]] = torch.aten.size.int %[[SUB_MEAN_SQUARE]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int // CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM2_0]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[3,4,1],f32>, !torch.int -> !torch.vtensor<[3,4,1],f32> -// CHECK: return %[[VAR]] : !torch.vtensor<[3,4,1],f32> +// CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_0]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64> +// CHECK: %[[CST6:.*]] = torch.constant.int 6 +// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false +// CHECK: %[[NONE_1:.*]] = torch.constant.none +// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32> +// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[3,4,1],f32> func.func @torch.aten.var.dim(%arg0: !torch.vtensor<[3,4,7],f32>) -> !torch.vtensor<[3,4,1],f32> { %int2 = torch.constant.int 2 %dims = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list @@ -1208,3 +1248,46 @@ func.func @torch.aten.softplus(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) - %ret = torch.aten.softplus %t, %dim, %int0: !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor<[2,3],f32> return %ret : !torch.tensor<[2,3],f32> } + +// ----- +// CHECK-LABEL: func.func @torch.aten.var.correction( +// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[3,4,7],f32>) -> !torch.vtensor<[3,4,1],f32> { +// CHECK: %[[CST2:.*]] = torch.constant.int 2 +// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST2]] : (!torch.int) -> !torch.list +// CHECK: %[[KEEPDIM:.*]] = torch.constant.bool true +// CHECK: %[[CST7:.*]] = torch.constant.int 7 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[UPCAST_INPUT:.*]] = torch.aten.to.dtype %[[INPUT]], %[[CST7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4,7],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,7],f64> +// CHECK: %[[NONE_0:.*]] = torch.constant.none +// CHECK: %[[KEEPDIM_0:.*]] = torch.constant.bool true +// CHECK: %[[SUM:.*]] = torch.aten.sum.dim_IntList %[[UPCAST_INPUT]], %[[DIMS]], %[[KEEPDIM_0]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64> +// CHECK: %[[CST1:.*]] = torch.constant.int 1 +// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int +// CHECK: %[[NUM_ELEMENTS:.*]] = torch.aten.mul.int %[[CST1]], %[[DIM2]] : !torch.int, !torch.int -> !torch.int +// CHECK: %[[MEAN:.*]] = torch.aten.div.Scalar %[[SUM]], %[[NUM_ELEMENTS]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64> +// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 +// CHECK: %[[SUB_MEAN:.*]] = torch.aten.sub.Tensor %[[UPCAST_INPUT]], %[[MEAN]], %[[ALPHA]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,1],f64>, !torch.float -> !torch.vtensor<[3,4,7],f64> +// CHECK: %[[SUB_MEAN_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB_MEAN]], %[[SUB_MEAN]] : !torch.vtensor<[3,4,7],f64>, !torch.vtensor<[3,4,7],f64> -> !torch.vtensor<[3,4,7],f64> +// CHECK: %[[SUB_MEAN_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_MEAN_SQUARE]], %[[DIMS]], %[[KEEPDIM]], %[[NONE_0]] : !torch.vtensor<[3,4,7],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f64> +// CHECK: %[[CST1_0:.*]] = torch.constant.int 1 +// CHECK: %[[DIM2_0:.*]] = torch.aten.size.int %[[UPCAST_INPUT]], %[[CST2]] : !torch.vtensor<[3,4,7],f64>, !torch.int -> !torch.int +// CHECK: %[[NUM_ELEMENTS_0:.*]] = torch.aten.mul.int %[[CST1_0]], %[[DIM2_0]] : !torch.int, !torch.int -> !torch.int +// CHECK: %[[CST2_0:.*]] = torch.constant.int 2 +// CHECK: %[[NUM_ELEMENTS_PLUS_ONE:.*]] = torch.aten.add.int %[[NUM_ELEMENTS_0]], %[[CST1_0]] : !torch.int, !torch.int -> !torch.int +// CHECK: %[[PRED:.*]] = torch.aten.ge.int %[[NUM_ELEMENTS_PLUS_ONE]], %[[CST2_0]] : !torch.int, !torch.int -> !torch.bool +// CHECK: torch.runtime.assert %[[PRED]], "correction value should be less than or equal to productDimSize + 1" +// CHECK: %[[NUM_ELEMENTS_MINUS_CORRECTION:.*]] = torch.aten.sub.int %[[NUM_ELEMENTS_0]], %[[CST2_0]] : !torch.int, !torch.int -> !torch.int +// CHECK: %[[VAR:.*]] = torch.aten.div.Scalar %[[SUB_MEAN_SQUARE_SUM]], %[[NUM_ELEMENTS_MINUS_CORRECTION]] : !torch.vtensor<[3,4,1],f64>, !torch.int -> !torch.vtensor<[3,4,1],f64> +// CHECK: %[[CST6:.*]] = torch.constant.int 6 +// CHECK: %[[FALSE_0:.*]] = torch.constant.bool false +// CHECK: %[[NONE_1:.*]] = torch.constant.none +// CHECK: %[[DOWNCAST_RESULT:.*]] = torch.aten.to.dtype %[[VAR]], %[[CST6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_1]] : !torch.vtensor<[3,4,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,1],f32> +// CHECK: return %[[DOWNCAST_RESULT]] : !torch.vtensor<[3,4,1],f32> +func.func @torch.aten.var.correction(%arg0: !torch.vtensor<[3,4,7],f32>) -> !torch.vtensor<[3,4,1],f32> { + %int2 = torch.constant.int 2 + %dims = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list + %keepdim = torch.constant.bool true + %0 = torch.aten.var.correction %arg0, %dims, %int2, %keepdim: !torch.vtensor<[3,4,7],f32>, !torch.list, !torch.int, !torch.bool -> !torch.vtensor<[3,4,1],f32> + return %0 : !torch.vtensor<[3,4,1],f32> +}