diff --git a/.github/workflows/buildAndTest.yml b/.github/workflows/buildAndTest.yml index 84cec511a350..f1bf75245607 100644 --- a/.github/workflows/buildAndTest.yml +++ b/.github/workflows/buildAndTest.yml @@ -155,6 +155,20 @@ jobs: run: | cmake --build build --target check-torch-mlir-all + - name: Ensure generated files are up to date + if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }} + run: | + ./build_tools/update_torch_ods.sh + ./build_tools/update_shape_lib.sh + if ! git diff --quiet; then + echo "#######################################################" + echo "Generated files are not up to date (see diff below)" + echo ">>> Please run ./build_tools/update_torch_ods.sh and ./build_tools/update_shape_lib.sh <<<" + echo "#######################################################" + git diff --color=always + exit 1 + fi + - name: Run refbackend e2e integration tests if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }} run: | diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp index 4868a711b47e..f7a8f69ca355 100644 --- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp @@ -3497,60 +3497,65 @@ module { func.func @__torch__.torch.jit._shape_functions.conv_forwards(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int) -> !torch.list { %true = torch.constant.bool true %int0 = torch.constant.int 0 - %int2 = torch.constant.int 2 %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 %0 = torch.aten.len.t %arg5 : !torch.list -> !torch.int %1 = torch.aten.gt.int %0, %int0 : !torch.int, !torch.int -> !torch.bool %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int %3 = torch.prim.ListConstruct : () -> !torch.list - %4 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int - %5 = torch.aten.append.t %3, %4 : !torch.list, !torch.int -> !torch.list - %6 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int - %7 = torch.aten.append.t %3, %6 : !torch.list, !torch.int -> !torch.list - %8 = torch.aten.__range_length %int2, %2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int - torch.prim.Loop %8, %true, init() { + %4 = torch.prim.If %arg6 -> (!torch.int) { + torch.prim.If.yield %int1 : !torch.int + } else { + torch.prim.If.yield %int0 : !torch.int + } + %5 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int + %6 = torch.aten.append.t %3, %5 : !torch.list, !torch.int -> !torch.list + %7 = torch.aten.__getitem__.t %arg1, %4 : !torch.list, !torch.int -> !torch.int + %8 = torch.aten.append.t %3, %7 : !torch.list, !torch.int -> !torch.list + %9 = torch.aten.__range_length %int2, %2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int + torch.prim.Loop %9, %true, init() { ^bb0(%arg9: !torch.int): - %9 = torch.aten.__derive_index %arg9, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int - %10 = torch.prim.If %1 -> (!torch.int) { - %11 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int - %12 = torch.aten.__getitem__.t %arg5, %11 : !torch.list, !torch.int -> !torch.int - torch.prim.If.yield %12 : !torch.int + %10 = torch.aten.__derive_index %arg9, %int2, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int + %11 = torch.prim.If %1 -> (!torch.int) { + %12 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int + %13 = torch.aten.__getitem__.t %arg5, %12 : !torch.list, !torch.int -> !torch.int + torch.prim.If.yield %13 : !torch.int } else { torch.prim.If.yield %int1 : !torch.int } torch.prim.If %arg6 -> () { - %11 = torch.aten.__getitem__.t %arg1, %9 : !torch.list, !torch.int -> !torch.int - %12 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int - %13 = torch.aten.mul.int %10, %12 : !torch.int, !torch.int -> !torch.int - %14 = torch.aten.__getitem__.t %arg0, %9 : !torch.list, !torch.int -> !torch.int - %15 = torch.aten.sub.int %14, %int1 : !torch.int, !torch.int -> !torch.int - %16 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int - %17 = torch.aten.__getitem__.t %arg3, %16 : !torch.list, !torch.int -> !torch.int - %18 = torch.aten.mul.int %15, %17 : !torch.int, !torch.int -> !torch.int - %19 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int - %20 = torch.aten.__getitem__.t %arg4, %19 : !torch.list, !torch.int -> !torch.int - %21 = torch.aten.mul.int %20, %int2 : !torch.int, !torch.int -> !torch.int - %22 = torch.aten.sub.int %18, %21 : !torch.int, !torch.int -> !torch.int - %23 = torch.aten.add.int %22, %13 : !torch.int, !torch.int -> !torch.int - %24 = torch.aten.add.int %23, %int1 : !torch.int, !torch.int -> !torch.int - %25 = torch.aten.append.t %3, %24 : !torch.list, !torch.int -> !torch.list + %12 = torch.aten.__getitem__.t %arg1, %10 : !torch.list, !torch.int -> !torch.int + %13 = torch.aten.sub.int %12, %int1 : !torch.int, !torch.int -> !torch.int + %14 = torch.aten.mul.int %11, %13 : !torch.int, !torch.int -> !torch.int + %15 = torch.aten.__getitem__.t %arg0, %10 : !torch.list, !torch.int -> !torch.int + %16 = torch.aten.sub.int %15, %int1 : !torch.int, !torch.int -> !torch.int + %17 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int + %18 = torch.aten.__getitem__.t %arg3, %17 : !torch.list, !torch.int -> !torch.int + %19 = torch.aten.mul.int %16, %18 : !torch.int, !torch.int -> !torch.int + %20 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int + %21 = torch.aten.__getitem__.t %arg4, %20 : !torch.list, !torch.int -> !torch.int + %22 = torch.aten.mul.int %21, %int2 : !torch.int, !torch.int -> !torch.int + %23 = torch.aten.sub.int %19, %22 : !torch.int, !torch.int -> !torch.int + %24 = torch.aten.add.int %23, %14 : !torch.int, !torch.int -> !torch.int + %25 = torch.aten.add.int %24, %int1 : !torch.int, !torch.int -> !torch.int + %26 = torch.aten.append.t %3, %25 : !torch.list, !torch.int -> !torch.list torch.prim.If.yield } else { - %11 = torch.aten.__getitem__.t %arg1, %9 : !torch.list, !torch.int -> !torch.int - %12 = torch.aten.sub.int %11, %int1 : !torch.int, !torch.int -> !torch.int - %13 = torch.aten.mul.int %10, %12 : !torch.int, !torch.int -> !torch.int - %14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int - %15 = torch.aten.__getitem__.t %arg0, %9 : !torch.list, !torch.int -> !torch.int - %16 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int - %17 = torch.aten.__getitem__.t %arg4, %16 : !torch.list, !torch.int -> !torch.int - %18 = torch.aten.mul.int %17, %int2 : !torch.int, !torch.int -> !torch.int - %19 = torch.aten.add.int %15, %18 : !torch.int, !torch.int -> !torch.int - %20 = torch.aten.sub.int %19, %14 : !torch.int, !torch.int -> !torch.int - %21 = torch.aten.sub.int %9, %int2 : !torch.int, !torch.int -> !torch.int - %22 = torch.aten.__getitem__.t %arg3, %21 : !torch.list, !torch.int -> !torch.int - %23 = torch.aten.floordiv.int %20, %22 : !torch.int, !torch.int -> !torch.int - %24 = torch.aten.add.int %23, %int1 : !torch.int, !torch.int -> !torch.int - %25 = torch.aten.append.t %3, %24 : !torch.list, !torch.int -> !torch.list + %12 = torch.aten.__getitem__.t %arg1, %10 : !torch.list, !torch.int -> !torch.int + %13 = torch.aten.sub.int %12, %int1 : !torch.int, !torch.int -> !torch.int + %14 = torch.aten.mul.int %11, %13 : !torch.int, !torch.int -> !torch.int + %15 = torch.aten.add.int %14, %int1 : !torch.int, !torch.int -> !torch.int + %16 = torch.aten.__getitem__.t %arg0, %10 : !torch.list, !torch.int -> !torch.int + %17 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int + %18 = torch.aten.__getitem__.t %arg4, %17 : !torch.list, !torch.int -> !torch.int + %19 = torch.aten.mul.int %18, %int2 : !torch.int, !torch.int -> !torch.int + %20 = torch.aten.add.int %16, %19 : !torch.int, !torch.int -> !torch.int + %21 = torch.aten.sub.int %20, %15 : !torch.int, !torch.int -> !torch.int + %22 = torch.aten.sub.int %10, %int2 : !torch.int, !torch.int -> !torch.int + %23 = torch.aten.__getitem__.t %arg3, %22 : !torch.list, !torch.int -> !torch.int + %24 = torch.aten.floordiv.int %21, %23 : !torch.int, !torch.int -> !torch.int + %25 = torch.aten.add.int %24, %int1 : !torch.int, !torch.int -> !torch.int + %26 = torch.aten.append.t %3, %25 : !torch.list, !torch.int -> !torch.list torch.prim.If.yield } torch.prim.Loop.condition %true, iter() @@ -3593,7 +3598,7 @@ module { %9 = torch.prim.ListConstruct : () -> !torch.list %10 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int %11 = torch.aten.append.t %9, %10 : !torch.list, !torch.int -> !torch.list - %12 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int + %12 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int %13 = torch.aten.append.t %9, %12 : !torch.list, !torch.int -> !torch.list %14 = torch.aten.__range_length %int2, %8, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int torch.prim.Loop %14, %true, init() { @@ -4344,10 +4349,6 @@ module { } return %7 : !torch.list } - func.func @__torch_mlir_shape_fn.aten.roll(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list { - %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list - return %0 : !torch.list - } func.func @__torch__.torch.jit._shape_functions.expand(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list { %int-1 = torch.constant.int -1 %true = torch.constant.bool true @@ -5998,6 +5999,10 @@ module { } return %6 : !torch.list } + func.func @"__torch_mlir_shape_fn.aten.roll"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list { + %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list + return %0 : !torch.list + } func.func @"__torch_mlir_shape_fn.aten.expand"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.list { %0 = call @__torch__.torch.jit._shape_functions.expand(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list return %0 : !torch.list