From 61dfe97ed09d95d1665e1722c594e5c9954d172d Mon Sep 17 00:00:00 2001
From: Vremold <xremold@gmail.com>
Date: Tue, 16 Aug 2022 15:49:10 +0800
Subject: [PATCH] [MHLO] Add AtenRSubScalarOp conversion pattern

Co-authored-by: Bairen Yi <yibairen.jackpot@bytedance.com>
Co-authored-by: Jiawei Wu <xremold@gmail.com>
Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com>
Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
---
 lib/Conversion/TorchToMhlo/Basic.cpp         |  4 ++++
 test/Conversion/TorchToMhlo/elementwise.mlir | 21 +++++++++++++++++++
 2 files changed, 25 insertions(+)

diff --git a/lib/Conversion/TorchToMhlo/Basic.cpp b/lib/Conversion/TorchToMhlo/Basic.cpp
index ef084d2dc083..4a6a3e43cd57 100644
--- a/lib/Conversion/TorchToMhlo/Basic.cpp
+++ b/lib/Conversion/TorchToMhlo/Basic.cpp
@@ -178,6 +178,9 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {
     if (!rhsType) {
       rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(),
                                      outElemTy);
+      if (isa<AtenRsubScalarOp>(op)) {
+        std::swap(lhs, rhs);
+      }
     }
 
     lhs = mhlo::promoteType(rewriter, lhs, outType);
@@ -1120,6 +1123,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
   INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, chlo::BroadcastAddOp);
   INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, chlo::BroadcastSubOp);
   INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, chlo::BroadcastSubOp);
+  INSERT_BINARY_ADDSUB_PATTERN(AtenRsubScalarOp, chlo::BroadcastSubOp);
 #undef INSERT_BINARY_ADDSUB_PATTERN
 
 #define INSERT_BINARY_MULDIV_PATTERN(AtenOp, ChloOp) \
diff --git a/test/Conversion/TorchToMhlo/elementwise.mlir b/test/Conversion/TorchToMhlo/elementwise.mlir
index 77aaea093ad5..fb22507d43fe 100644
--- a/test/Conversion/TorchToMhlo/elementwise.mlir
+++ b/test/Conversion/TorchToMhlo/elementwise.mlir
@@ -197,6 +197,27 @@ func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc
 
 // -----
 
+// CHECK-LABEL:  func.func @torch.aten.rsubscalar$basic(
+// CHECK-SAME:         %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[INT9:.*]] = torch.constant.int 9
+// CHECK:         %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK:         %[[INT1:.*]] = torch.constant.int 1
+// CHECK:         %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK:         %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK:         %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf32>) -> tensor<f32>
+// CHECK:         %[[T5:.*]] = chlo.broadcast_subtract %[[T4]], %[[T0]] : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:         %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[T6]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.rsubscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int9 = torch.constant.int 9
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.rsub.Scalar %arg0, %int9, %int1 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
 // CHECK-LABEL:  func.func @torch.aten.subscalar$alpha(
 // CHECK-SAME:         %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>