Force double precision in AtenMeanDimOp
qedawkins committed Aug 29, 2022
1 parent 2623185 commit 831ff2e
Showing 2 changed files with 35 additions and 2 deletions.
17 changes: 15 additions & 2 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1214,9 +1214,18 @@ class DecomposeAtenMeanDimOp : public OpRewritePattern<AtenMeanDimOp> {
           op, "expected `dim` to be `None` or constructed from list construct");
     }
 
+    // Upcasting the input tensor to `F64` dtype for higher precision during
+    // the computation of the result.
+    BaseTensorType outputTensorType = outputType.cast<BaseTensorType>();
+    Type newOutputType = outputTensorType.getWithSizesAndDtype(
+        outputTensorType.getSizes(), rewriter.getF64Type());
+    if (inputType.getDtype().getIntOrFloatBitWidth() != 64) {
+      input = convertTensorToDtype(rewriter, loc, input, rewriter.getF64Type());
+    }
+
     // Compute sum along dimensions specified in `dimList`.
     Value sumAlongDims = rewriter.create<AtenSumDimIntListOp>(
-        loc, outputType, input, dimList, keepDim, dtype);
+        loc, newOutputType, input, dimList, keepDim, dtype);
 
     // `productDimSize` is product of sizes of dimensions to be reduced.
     Value productDimSize;
@@ -1232,8 +1241,12 @@ class DecomposeAtenMeanDimOp : public OpRewritePattern<AtenMeanDimOp> {
             rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
       }
     }
-    rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, outputType, sumAlongDims,
+    Value result = rewriter.create<AtenDivScalarOp>(loc, newOutputType, sumAlongDims,
                                                  productDimSize);
+    if (outputTensorType.getDtype().getIntOrFloatBitWidth() != 64) {
+      result = convertTensorToDtype(rewriter, loc, result, outputTensorType.getDtype());
+    }
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
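In plain terms, the rewritten pattern now computes the mean in double precision: it upcasts the input to `f64`, emits the sum/divide decomposition at an `f64` output type, and casts the quotient back to the op's original output dtype. Below is a rough PyTorch analogue of that flow; `mean_dim_decomposed` and its signature are illustrative only, not part of the repository, and the real pattern additionally handles the op's `dtype` operand and non-float cases.

# Illustrative sketch (not from the repo): upcast, sum, divide, cast back.
import math

import torch

def mean_dim_decomposed(x: torch.Tensor, dims, keepdim: bool = False) -> torch.Tensor:
    out_dtype = x.dtype
    if x.dtype != torch.float64:
        # Mirrors the upcast the patch inserts before the sum.
        x = x.to(torch.float64)
    summed = torch.ops.aten.sum.dim_IntList(x, list(dims), keepdim)
    # `productDimSize` in the pattern: the number of elements reduced away.
    count = math.prod(x.shape[d] for d in dims)
    result = summed / count
    # Cast back to the op's original output dtype, as the patch does.
    return result.to(out_dtype)

For a float32 input this still returns float32 results, but every intermediate lives in float64, which is the point of the change.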
20 changes: 20 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/stats.py
@@ -221,6 +221,26 @@ def MeanDimNoneDimModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class MeanDimLargeInputModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.mean(x, (2, 3))
+
+
+@register_test_case(module_factory=lambda: MeanDimLargeInputModule())
+def MeanDimLargeInputModule_basic(module, tu: TestUtils):
+    module.forward(100 + tu.rand(3, 4, 1024, 8192))
+
+# ==============================================================================
+
 class VarUnbiasedModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
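The new e2e test is shaped to expose exactly this failure mode: each reduced slice holds 1024 × 8192 ≈ 8.4 million values near 100.0, so a single-precision running sum climbs to roughly 8.4e8, where float32 spacing is about 64 and the low-order digits of each addend are rounded away. The sketch below uses `cumsum` as a stand-in for a strictly sequential float32 accumulator (an assumption for illustration; this commit does not show the lowered reduction's accumulation order) and compares the implied mean against a float64 reference.

# Illustrative only: emulate a naive sequential float32 accumulator with
# cumsum and measure how far the implied mean drifts from float64.
import torch

x = 100 + torch.rand(1024 * 8192)   # one reduced slice of the test input

seq32 = x.cumsum(0)[-1]             # left-to-right running sum in float32
ref64 = x.double().sum()            # double-precision reference

drift = abs(seq32.item() - ref64.item()) / x.numel()
print(f"absolute drift in the mean: {drift:.3e}")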
