[MHLO] Add AtenRSubScalarOp conversion pattern
Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com>
Co-authored-by: Jiawei Wu <xremold@gmail.com>
Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com>
Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
Vremold committed Aug 16, 2022
1 parent 3b3cb99 commit 61dfe97
Showing 2 changed files with 25 additions and 0 deletions.
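
In short: torch.aten.rsub.Scalar(self, other, alpha) computes other - self * alpha, a subtraction with the operands reversed. The shared add/sub conversion pattern already materializes the scalar other as a tensor, so supporting rsub only takes swapping lhs and rhs before the subtract is emitted, plus registering the op with the existing pattern; the new lit test pins down the resulting chlo.broadcast_subtract %scalar, %self operand order.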
lib/Conversion/TorchToMhlo/Basic.cpp (4 additions, 0 deletions)

@@ -178,6 +178,9 @@ class ConvertAtenAddSubOp : public OpConversionPattern<AtenOpT> {

    if (!rhsType) {
      rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.other(), outElemTy);
+      if (isa<AtenRsubScalarOp>(op)) {
+        std::swap(lhs, rhs);
+      }
    }

    lhs = mhlo::promoteType(rewriter, lhs, outType);
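
Why a single swap suffices: the shared lowering computes lhs - alpha * rhs (with alpha applied to rhs later in the pattern, outside this hunk), so exchanging the operands up front yields rsub's other - alpha * self. A minimal scalar-arithmetic sketch, purely illustrative and not code from the patch (rsub_scalar is a hypothetical helper):

#include <utility>

// Illustrative only: models the operand order of the shared sub lowering,
// which evaluates lhs - alpha * rhs.
double rsub_scalar(double self, double other, double alpha) {
  double lhs = self;    // the tensor operand, as for a regular aten.sub
  double rhs = other;   // the scalar operand, materialized upstream
  std::swap(lhs, rhs);  // the one extra step AtenRsubScalarOp needs
  return lhs - alpha * rhs;  // = other - alpha * self, i.e. rsub
}

With alpha = 1, as in the test below, this reduces to other - self.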
@@ -1120,6 +1123,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
  INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, chlo::BroadcastAddOp);
  INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, chlo::BroadcastSubOp);
  INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, chlo::BroadcastSubOp);
+  INSERT_BINARY_ADDSUB_PATTERN(AtenRsubScalarOp, chlo::BroadcastSubOp);
#undef INSERT_BINARY_ADDSUB_PATTERN

#define INSERT_BINARY_MULDIV_PATTERN(AtenOp, ChloOp) \
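
For context, INSERT_BINARY_ADDSUB_PATTERN presumably expands along the following lines (a sketch inferred from its call sites; the real definition sits earlier in Basic.cpp, outside this hunk): it marks the aten op illegal on the conversion target and registers the templated pattern, which is how the isa<AtenRsubScalarOp> branch above becomes reachable.

// Presumed shape of the helper macro, inferred from usage; target, patterns,
// typeConverter, and context are locals of populateBasicOpPatternsAndLegality.
#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, ChloOp)                          \
  target.addIllegalOp<AtenOp>();                                              \
  patterns.add<ConvertAtenAddSubOp<AtenOp, ChloOp>>(typeConverter, context);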
test/Conversion/TorchToMhlo/elementwise.mlir (21 additions, 0 deletions)

@@ -197,6 +197,27 @@ func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {

// -----

+// CHECK-LABEL: func.func @torch.aten.rsubscalar$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[INT9:.*]] = torch.constant.int 9
+// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]]
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[T2:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64>
+// CHECK: %[[T3:.*]] = mhlo.convert(%[[T2]]) : (tensor<1xi64>) -> tensor<1xf32>
+// CHECK: %[[T4:.*]] = "mhlo.reshape"(%[[T3]]) : (tensor<1xf32>) -> tensor<f32>
+// CHECK: %[[T5:.*]] = chlo.broadcast_subtract %[[T4]], %[[T0]] : (tensor<f32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T6]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.rsubscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int9 = torch.constant.int 9
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.rsub.Scalar %arg0, %int9, %int1 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
// CHECK-LABEL: func.func @torch.aten.subscalar$alpha(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
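
These CHECK lines are verified with lit and FileCheck. The file's RUN line, which sits outside this hunk, presumably pipes the test through torch-mlir-opt's -convert-torch-to-mhlo pass with -split-input-file, and the // ----- separators delimit the per-test inputs. Note the checked operand order: the converted scalar (%[[T4]]) comes first in chlo.broadcast_subtract, confirming the swap.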

0 comments on commit 61dfe97

Please sign in to comment.