From f951a63a3d939f4a58b5f84ee184e9b8520bfd19 Mon Sep 17 00:00:00 2001
From: Tanyo Kwok
Date: Tue, 9 Aug 2022 00:20:43 +0800
Subject: [PATCH] Add shape function and e2e test for aten.roll

---
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 35 +++++++++++++------
 lib/Dialect/Torch/Transforms/ShapeLibrary.cpp |  4 +++
 .../jit_ir/build_tools/shape_lib_gen.py       |  3 ++
 .../torch_mlir_e2e_test/test_suite/basic.py   | 21 +++++++++++
 test/Dialect/Torch/decompose-complex-ops.mlir | 26 ++++++++------
 5 files changed, 68 insertions(+), 21 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index e9bff69f115..77c0ec6d261 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -709,7 +709,7 @@ class DecomposeAtenTOp : public OpRewritePattern<AtenTOp> {
 };
 } // namespace
 
-// Decompose aten.roll into aten.expand and aten.slice and aten.cat ops.
+// Decompose aten.roll into aten.slice and aten.cat ops.
 // https://pytorch.org/docs/stable/generated/torch.roll.html
 namespace {
 class DecomposeAtenRollOp : public OpRewritePattern<AtenRollOp> {
@@ -736,28 +736,43 @@ class DecomposeAtenRollOp : public OpRewritePattern<AtenRollOp> {
     Value constOne = rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(1));
     auto self = op.self();
-    Type listType = Torch::ListType::get(self.getType());
+    auto selfTy = self.getType().cast<BaseTensorType>();
     // roll(input, shift, dim) = cat({
     //   slice(input, dim, -shift, none),
     //   slice(input, dim, 0, -shift)}, dim)
-    auto ImitateRoll = [&](Value input, Value shift, Value dim) {
+    auto imitateRoll = [&](Value input, Value shift, Value dim,
+                           int64_t cstDim) {
       Value negShift = rewriter.create<AtenNegIntOp>(loc, shift);
-      Type sliceType = computeReductionType(
-          rewriter, op, self.getType().cast<BaseTensorType>(), dim,
-          /*keepDim=*/true);
+      ArrayRef<int64_t> inputShape = selfTy.getSizes();
+      SmallVector<int64_t> sizes;
+      sizes.append(inputShape.begin(), inputShape.end());
+      sizes[cstDim] = ShapedType::kDynamicSize;
+      Type sliceTy = selfTy.getWithSizesAndDtype(llvm::makeArrayRef(sizes),
+                                                 selfTy.getDtype());
       Value slice0 = rewriter.create<AtenSliceTensorOp>(
-          loc, sliceType, input, dim, negShift, constNone, constOne);
+          loc, sliceTy, input, dim, negShift, constNone, constOne);
       Value slice1 = rewriter.create<AtenSliceTensorOp>(
-          loc, sliceType, input, dim, constZero, negShift, constOne);
+          loc, sliceTy, input, dim, constZero, negShift, constOne);
+      Type listType = Torch::ListType::get(sliceTy);
       Value slices = rewriter.create<PrimListConstructOp>(
           loc, listType, llvm::ArrayRef<Value>{slice0, slice1});
       return rewriter.create<AtenCatOp>(loc, self.getType(), slices, dim);
     };
 
-    auto output = self;
+    int rank = getTensorRank(self);
+    if (rank < 0)
+      return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor");
+    Value output = self;
     auto nShifts = shifts.size();
     for (size_t k = 0; k < nShifts; ++k) {
-      output = ImitateRoll(output, shifts[k], dims[k]);
+      auto dim = dims[k];
+      int64_t cstDim = -1;
+      if (!matchPattern(dim, m_TorchConstantInt(&cstDim)))
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: dim must be constant");
+
+      cstDim = toPositiveDim(cstDim, rank);
+      output = imitateRoll(output, shifts[k], dim, cstDim);
     }
     rewriter.replaceOp(op, output);
     return success();
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index 7f435531579..2afd5640c1d 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -4213,6 +4213,10 @@ module {
     }
     return %7 : !torch.list<int>
   }
+  func.func @__torch_mlir_shape_fn.aten.roll(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {
+    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
   func.func @__torch__.torch.jit._shape_functions.expand(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
     %int-1 = torch.constant.int -1
     %true = torch.constant.bool true
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index 2d8ba3b6e2c..41379eb66b5 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -635,6 +635,9 @@ def aten〇repeat(self: List[int], repeats: List[int]) -> List[int]:
             out.append(self[i] * repeats[i + leading_rank])
     return out
 
+def aten〇roll(self: List[int], shifts: List[int], dims: List[int] = ()) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇expand(self: List[int], size: List[int], implicit: bool = False) -> List[int]:
     return upstream_shape_functions.expand(self, size)
 
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index b2401f141b0..6c155fa1ee4 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -1047,6 +1047,27 @@ def BroadcastToModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class RollModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, -1, 2], torch.float32, True),
+    ])
+    def forward(self, x):
+        return x.roll([2, -1], [0, 2])
+
+
+@register_test_case(module_factory=lambda: RollModule())
+def RollModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 1, 2))
+
+# ==============================================================================
+
+
 class RepeatModule(torch.nn.Module):
 
     def __init__(self):
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index f5fc7cafa9f..9cd2d18538b 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -1355,26 +1355,30 @@ func.func @torch.aten.flatten.using_ints(%arg0: !torch.vtensor<[?,?,?,?],f32>) -
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.roll(
-// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int, %[[ARG2:.*]]: !torch.int, %[[ARG3:.*]]: !torch.int, %[[ARG4:.*]]: !torch.int) -> !torch.vtensor<[?,?],f32> {
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int, %[[ARG2:.*]]: !torch.int) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[ARG2]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[ARG2]], %[[ARG3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[INT:.*]]-2 = torch.constant.int -2
+// CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT]]-2 : (!torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[NONE:.*]] = torch.constant.none
 // CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[INT1_0:.*]] = torch.constant.int 1
 // CHECK: %[[T2:.*]] = torch.aten.neg.int %[[ARG1]] : !torch.int -> !torch.int
-// CHECK: %[[T3:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[ARG2]], %[[T2]], %[[NONE]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[?,?],f32>
-// CHECK: %[[T4:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[ARG2]], %[[INT0]], %[[T2]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
+// CHECK: %[[T3:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[INT1]], %[[T2]], %[[NONE]], %[[INT1]]_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[?,?],f32>
+// CHECK: %[[T4:.*]] = torch.aten.slice.Tensor %[[ARG0]], %[[INT1]], %[[INT0]], %[[T2]], %[[INT1]]_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
 // CHECK: %[[T5:.*]] = torch.prim.ListConstruct %[[T3]], %[[T4]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor<[?,?],f32>>
-// CHECK: %[[T6:.*]] = torch.aten.cat %[[T5]], %[[ARG2]] : !torch.list<vtensor<[?,?],f32>>, !torch.int -> !torch.vtensor<[?,?],f32>
+// CHECK: %[[T6:.*]] = torch.aten.cat %[[T5]], %[[INT1]] : !torch.list<vtensor<[?,?],f32>>, !torch.int -> !torch.vtensor<[?,?],f32>
 // CHECK: %[[T7:.*]] = torch.aten.neg.int %[[ARG2]] : !torch.int -> !torch.int
-// CHECK: %[[T8:.*]] = torch.aten.slice.Tensor %[[T6]], %[[ARG3]], %[[T7]], %[[NONE]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[?,?],f32>
-// CHECK: %[[T9:.*]] = torch.aten.slice.Tensor %[[T6]], %[[ARG3]], %[[INT0]], %[[T7]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
+// CHECK: %[[T8:.*]] = torch.aten.slice.Tensor %[[T6]], %[[INT]]-2, %[[T7]], %[[NONE]], %[[INT]]1_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.none, !torch.int -> !torch.vtensor<[?,?],f32>
+// CHECK: %[[T9:.*]] = torch.aten.slice.Tensor %[[T6]], %[[INT]]-2, %[[INT]]0, %[[T7]], %[[INT]]1_0 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
 // CHECK: %[[T10:.*]] = torch.prim.ListConstruct %[[T8]], %[[T9]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list<vtensor<[?,?],f32>>
-// CHECK: %[[T11:.*]] = torch.aten.cat %[[T10]], %[[ARG3]] : !torch.list<vtensor<[?,?],f32>>, !torch.int -> !torch.vtensor<[?,?],f32>
+// CHECK: %[[T11:.*]] = torch.aten.cat %[[T10]], %[[INT]]-2 : !torch.list<vtensor<[?,?],f32>>, !torch.int -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[T11]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.roll(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int) -> !torch.vtensor<[?,?],f32> {
+func.func @torch.aten.roll(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.vtensor<[?,?],f32> {
   %0 = torch.prim.ListConstruct %arg1, %arg2: (!torch.int, !torch.int) -> !torch.list<int>
-  %1 = torch.prim.ListConstruct %arg2, %arg3: (!torch.int, !torch.int) -> !torch.list<int>
+  %int1 = torch.constant.int 1
+  %int-2 = torch.constant.int -2
+  %1 = torch.prim.ListConstruct %int1, %int-2: (!torch.int, !torch.int) -> !torch.list<int>
   %2 = torch.aten.roll %arg0, %0, %1 : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
   return %2 : !torch.vtensor<[?,?],f32>
 }
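
For reference, the rewrite exercised by this patch implements
roll(input, shift, dim) as the concatenation of slice(input, dim, -shift, none)
and slice(input, dim, 0, -shift) along dim. The plain-PyTorch sketch below
mirrors the imitateRoll lambda; the helper name imitate_roll and the sample
values are illustrative only and not part of the patch. Like the
aten.slice-based rewrite, it relies on slicing that clamps out-of-range
bounds, so it assumes |shift| does not exceed the size of the rolled
dimension (torch.roll itself also normalizes larger shifts modulo the
dimension size, which the decomposition does not attempt):

import torch

def imitate_roll(x: torch.Tensor, shifts, dims):
    # Illustrative stand-in for the imitateRoll lambda above, applied once
    # per (shift, dim) pair exactly as the pattern's rewrite loop does.
    out = x
    for shift, dim in zip(shifts, dims):
        head = [slice(None)] * out.dim()
        tail = [slice(None)] * out.dim()
        # slice(input, dim, -shift, none): for shift > 0 the trailing `shift`
        # elements, for shift < 0 everything after the first |shift| elements,
        # and for shift == 0 the whole tensor.
        head[dim] = slice(-shift, None)
        # slice(input, dim, 0, -shift): the remaining leading elements
        # (empty when shift == 0, so the cat degenerates to the identity).
        tail[dim] = slice(0, -shift)
        out = torch.cat([out[tuple(head)], out[tuple(tail)]], dim=dim)
    return out

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
assert torch.equal(imitate_roll(x, [2, -1], [0, 1]),
                   torch.roll(x, [2, -1], [0, 1]))

Since a roll never changes the tensor's shape, the shape function added in
shape_lib_gen.py simply forwards to upstream_shape_functions.unary; in the
decomposition itself, only the intermediate slice results are given a dynamic
size along the rolled dimension, and each cat restores the input type.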