From 831ff2e94656fed901a806b75a09843d9c31effc Mon Sep 17 00:00:00 2001
From: Quinn Dawkins
Date: Mon, 29 Aug 2022 13:42:41 -0400
Subject: [PATCH] Force double precision in AtenMeanDimOp

---
 .../Torch/Transforms/DecomposeComplexOps.cpp | 17 ++++++++++++++--
 .../torch_mlir_e2e_test/test_suite/stats.py  | 20 +++++++++++++++++++
 2 files changed, 35 insertions(+), 2 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 0f9e5d149b0b..a63d6f56de56 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1214,9 +1214,18 @@ class DecomposeAtenMeanDimOp : public OpRewritePattern<AtenMeanDimOp> {
           op, "expected `dim` to be `None` or constructed from list construct");
     }
 
+    // Upcast the input tensor to `F64` dtype for higher precision during the
+    // computation of the result.
+    BaseTensorType outputTensorType = outputType.cast<BaseTensorType>();
+    Type newOutputType = outputTensorType.getWithSizesAndDtype(
+        outputTensorType.getSizes(), rewriter.getF64Type());
+    if (inputType.getDtype().getIntOrFloatBitWidth() != 64) {
+      input = convertTensorToDtype(rewriter, loc, input, rewriter.getF64Type());
+    }
+
     // Compute sum along dimensions specified in `dimList`.
     Value sumAlongDims = rewriter.create<AtenSumDimIntListOp>(
-        loc, outputType, input, dimList, keepDim, dtype);
+        loc, newOutputType, input, dimList, keepDim, dtype);
 
     // `productDimSize` is product of sizes of dimensions to be reduced.
     Value productDimSize;
@@ -1232,8 +1241,12 @@
             rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
       }
     }
-    rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, outputType, sumAlongDims,
+    Value result = rewriter.create<AtenDivScalarOp>(loc, newOutputType, sumAlongDims,
                                                  productDimSize);
+    if (outputTensorType.getDtype().getIntOrFloatBitWidth() != 64) {
+      result = convertTensorToDtype(rewriter, loc, result, outputTensorType.getDtype());
+    }
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
diff --git a/python/torch_mlir_e2e_test/test_suite/stats.py b/python/torch_mlir_e2e_test/test_suite/stats.py
index ebf9573bbdd0..33d6390bde29 100644
--- a/python/torch_mlir_e2e_test/test_suite/stats.py
+++ b/python/torch_mlir_e2e_test/test_suite/stats.py
@@ -221,6 +221,26 @@ def MeanDimNoneDimModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class MeanDimLargeInputModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.mean(x, (2, 3))
+
+
+@register_test_case(module_factory=lambda: MeanDimLargeInputModule())
+def MeanDimLargeInputModule_basic(module, tu: TestUtils):
+    module.forward(100 + tu.rand(3, 4, 1024, 8192))
+
+# ==============================================================================
+
 class VarUnbiasedModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
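
Note: a minimal standalone sketch (plain PyTorch, not part of the patch or the e2e harness; shapes copied from the new test) of the precision issue MeanDimLargeInputModule targets. It compares a sum-then-divide mean kept entirely in float32, as the decomposition previously emitted, against one that upcasts to float64 before the reduction and casts back, as the patched decomposition does. The size of the observed error depends on the accumulation order a given backend uses, so the printed value is illustrative only.

import torch

# Input shaped like the new e2e test: ~8.4M elements reduced per (dim 2, dim 3)
# slice, offset to ~100 so float32 rounding error is large relative to the spread.
x = 100 + torch.rand(3, 4, 1024, 8192, dtype=torch.float32)

# Mean as sum-then-divide entirely in float32.
mean_f32 = x.sum(dim=(2, 3)) / (1024 * 8192)

# Mean computed after upcasting to float64, then cast back to the original dtype.
mean_f64 = (x.to(torch.float64).sum(dim=(2, 3)) / (1024 * 8192)).to(torch.float32)

# Maximum absolute deviation between the two approaches; with a naive float32
# accumulation this can exceed typical e2e comparison tolerances.
print((mean_f32 - mean_f64).abs().max())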