Skip to content

Commit

Permalink
[MLIR][TORCH} Fix empty dim cases for the .dim ops
Browse files Browse the repository at this point in the history
This commit fixes the shape calculation for:
1.) aten.mean.dim
2.) aten.var.dim
3.) aten.sum.dim_IntList op

Also, it fixes the lowering of `aten.mean.dim` and
`aten.sum.dim_IntList` for handling the cases of empty dim list.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com
  • Loading branch information
vivekkhandelwal1 committed Jul 29, 2022
1 parent d386b8f commit c681c34
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 34 deletions.
13 changes: 9 additions & 4 deletions lib/Conversion/TorchToLinalg/Reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,8 @@ class ConvertReductionOp : public ConversionPattern {
"`keepdim` must be a constant bool");

SmallVector<int64_t> dimList;
bool isNoneOrEmptyDimList =
op.dim().getType().template isa<Torch::NoneType>();
if (matchPattern(op.dim(), m_TorchConstantIntList(dimList))) {
// Fix negative dimensions, if any, before adding to the list.
for (int64_t dim : dimList) {
Expand All @@ -278,13 +280,16 @@ class ConvertReductionOp : public ConversionPattern {
if (isValidDim(dim, inputType.getRank()))
opInfo.dimSet.insert(dim);
}
} else if (op.dim().getType().template isa<Torch::NoneType>()) {
if (dimList.empty())
isNoneOrEmptyDimList = true;
} else if (!isNoneOrEmptyDimList) {
return rewriter.notifyMatchFailure(
op, "`dim` argument must be a constant int list or None");
}
if (isNoneOrEmptyDimList) {
// If no dimensions were specified, reduce along all dimensions
for (int64_t i = 0; i < inputType.getRank(); i++)
opInfo.dimSet.insert(i);
} else {
return rewriter.notifyMatchFailure(
op, "`dim` argument must be a constant int list or None");
}

return opInfo;
Expand Down
19 changes: 13 additions & 6 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,7 @@ class DecomposeAtenMeanDimOp : public OpRewritePattern<AtenMeanDimOp> {
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.self();
unsigned inputRank = getTensorRank(input);
Value dimList = op.dim();
Value keepDim = op.keepdim();
Value dtype = op.dtype();
Expand All @@ -1036,12 +1037,18 @@ class DecomposeAtenMeanDimOp : public OpRewritePattern<AtenMeanDimOp> {
loc, outputType, input, dimList, keepDim, dtype);

// `productDimSize` is product of sizes of dimensions to be reduced.
Value productDimSize = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
for (Value dim : dimListConstruct.elements()) {
Value dimSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
productDimSize =
rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
Value productDimSize;
// Case: Reduce along all dims.
if (dimListConstruct.elements().empty() && inputRank != 0) {
productDimSize = rewriter.create<AtenNumelOp>(loc, input);
} else {
productDimSize = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
for (Value dim : dimListConstruct.elements()) {
Value dimSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
productDimSize =
rewriter.create<AtenMulIntOp>(loc, productDimSize, dimSize);
}
}
rewriter.replaceOpWithNewOp<AtenDivScalarOp>(op, outputType, sumAlongDims,
productDimSize);
Expand Down
78 changes: 62 additions & 16 deletions lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5566,9 +5566,25 @@ module {
}
func.func @"__torch_mlir_shape_fn.aten.var.dim"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.bool) -> !torch.list<int> {
%none = torch.constant.none
%0 = torch.derefine %none : !torch.none to !torch.any
%1 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
return %1 : !torch.list<int>
%true = torch.constant.bool true
%int0 = torch.constant.int 0
%0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%2 = torch.prim.If %1 -> (!torch.list<int>) {
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %5, %true, init() {
^bb0(%arg4: !torch.int):
%7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.prim.If.yield %6 : !torch.list<int>
} else {
torch.prim.If.yield %arg1 : !torch.list<int>
}
%3 = torch.derefine %none : !torch.none to !torch.any
%4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg3, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
return %4 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.var.correction"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>, %arg3: !torch.bool) -> !torch.list<int> {
%true = torch.constant.bool true
Expand Down Expand Up @@ -5657,25 +5673,55 @@ module {
return %1 : !torch.tuple<list<int>, list<int>>
}
func.func @"__torch_mlir_shape_fn.aten.mean.dim"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {
%0 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
%1 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
return %1 : !torch.list<int>
%true = torch.constant.bool true
%int0 = torch.constant.int 0
%0 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%2 = torch.prim.If %1 -> (!torch.list<int>) {
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %5, %true, init() {
^bb0(%arg4: !torch.int):
%7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.prim.If.yield %6 : !torch.list<int>
} else {
torch.prim.If.yield %arg1 : !torch.list<int>
}
%3 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
%4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
return %4 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.sum.dim_IntList"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {
%true = torch.constant.bool true
%none = torch.constant.none
%int0 = torch.constant.int 0
%0 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool
%1 = torch.prim.If %0 -> (!torch.list<int>) {
%2 = torch.prim.ListConstruct : () -> !torch.list<int>
%3 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
%4 = func.call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
torch.prim.If.yield %4 : !torch.list<int>
%1 = torch.prim.If %0 -> (!torch.bool) {
torch.prim.If.yield %true : !torch.bool
} else {
%2 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%3 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
%4 = func.call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
torch.prim.If.yield %4 : !torch.list<int>
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
%6 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int
%7 = torch.aten.eq.int %6, %int0 : !torch.int, !torch.int -> !torch.bool
torch.prim.If.yield %7 : !torch.bool
}
return %1 : !torch.list<int>
%2 = torch.prim.If %1 -> (!torch.list<int>) {
%5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
torch.prim.Loop %5, %true, init() {
^bb0(%arg4: !torch.int):
%7 = torch.aten.append.t %6, %arg4 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.Loop.condition %true, iter()
} : (!torch.int, !torch.bool) -> ()
torch.prim.If.yield %6 : !torch.list<int>
} else {
%5 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>
torch.prim.If.yield %5 : !torch.list<int>
}
%3 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
%4 = call @__torch__.torch.jit._shape_functions.mean_dim(%arg0, %2, %arg2, %3) : (!torch.list<int>, !torch.list<int>, !torch.bool, !torch.any) -> !torch.list<int>
return %4 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.permute"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,8 @@ def aten〇var(self: List[int], unbiased: bool = True) -> List[int]:
return []

def aten〇var〇dim(self: List[int], dim: List[int], unbiased: bool = True, keepdim: bool = False) -> List[int]:
if len(dim)==0:
dim = list(range(len(self)))
return upstream_shape_functions.mean_dim(self, dim, keepdim, None)

def aten〇var〇correction(self: List[int], dim: Optional[List[int]], correction: Optional[int], keepdim: bool = False) -> List[int]:
Expand Down Expand Up @@ -533,13 +535,14 @@ def aten〇max〇dim(self: List[int], dim: int, keepdim: bool = False) -> Tuple[
return reduced_shape, reduced_shape

def aten〇mean〇dim(self: List[int], dim: List[int], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
if len(dim)==0:
dim = list(range(len(self)))
return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype)

def aten〇sum〇dim_IntList(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
if dim is None:
return upstream_shape_functions.mean_dim(self, [], keepdim, dtype)
else:
return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype)
if dim is None or len(dim)==0:
dim = list(range(len(self)))
return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype)

def aten〇permute(self: List[int], dims: List[int]) -> List[int]:
return upstream_shape_functions.permute(self, dims)
Expand Down
19 changes: 19 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,25 @@ def ReduceSumDimIntListKeepDimFloatModule_basic(module, tu: TestUtils):

# ==============================================================================

class ReduceSumDimIntListEmptyDimModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.sum(a, dim=[])


@register_test_case(module_factory=lambda: ReduceSumDimIntListEmptyDimModule())
def ReduceSumDimIntListEmptyDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))

# ==============================================================================

class ReduceSumUnsignedIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
28 changes: 24 additions & 4 deletions python/torch_mlir_e2e_test/test_suite/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,26 @@ def forward(self, x):
def MeanDimNegativeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))


# ==============================================================================

class MeanDimEmptyDimModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.mean(x, dim=[])


@register_test_case(module_factory=lambda: MeanDimEmptyDimModule())
def MeanDimEmptyDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))

# ==============================================================================

class VarUnbiasedModule(torch.nn.Module):
Expand Down Expand Up @@ -410,7 +430,7 @@ def VarDimNegativeModule_basic(module, tu: TestUtils):
# ==============================================================================


class VarDimKeepDimFalseModule(torch.nn.Module):
class VarDimEmptyDimModule(torch.nn.Module):

def __init__(self):
super().__init__()
Expand All @@ -421,11 +441,11 @@ def __init__(self):
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.var(x, dim=(0, 1, 2), keepdim=False)
return torch.ops.aten.var(x, dim=[], keepdim=False)


@register_test_case(module_factory=lambda: VarDimKeepDimFalseModule())
def VarDimKeepDimFalseModule_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: VarDimEmptyDimModule())
def VarDimEmptyDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))


Expand Down

0 comments on commit c681c34

Please sign in to comment.