From 52402e47db87c1790ea87cc9a32253012f84227e Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Mon, 15 Aug 2022 22:23:57 +0000 Subject: [PATCH] Add support for returning more than one copy of the same tensor One of the simplifications made by the pass `RefinePublicReturn` currently only happens if the tensor in question only has one user. However, the current method of checking this does not correctly handle the case of a user having multiple uses of the same tensor. This commit makes sure only unique users are considered. --- lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp | 5 ++++- test/Dialect/Torch/refine-public-return.mlir | 13 +++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp b/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp index 4adf613466e0..3c681cc737a6 100644 --- a/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp +++ b/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp @@ -71,7 +71,10 @@ class RefinePublicReturnPass // If the return (or transitively other ops) are not the only users, // then we can't be sure that the tensor hasn't been mutated, so stop // here. - if (!llvm::hasSingleElement(copy->getUsers())) + SmallVector usersVector(copy->getUsers()); + size_t numOfUniqueUsers = std::distance( + usersVector.begin(), llvm::unique(usersVector, std::equal_to())); + if (numOfUniqueUsers != 1) break; newOperand = copy.getOperand(); } else { diff --git a/test/Dialect/Torch/refine-public-return.mlir b/test/Dialect/Torch/refine-public-return.mlir index 0cb97d1bd6d1..ad810ec97ccb 100644 --- a/test/Dialect/Torch/refine-public-return.mlir +++ b/test/Dialect/Torch/refine-public-return.mlir @@ -59,3 +59,16 @@ func.func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> { ^bb2: return %arg0 : tensor<*xf32> } + +// ----- + +// CHECK-LABEL: func.func @return_multiple_copies_of_tensor( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[],f32>, !torch.vtensor<[],f32>) { +// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[],f32> to !torch.vtensor +// CHECK: %[[TO_TENSOR:.*]] = torch.copy.to_tensor %[[CAST]] : !torch.tensor +// CHECK: return %[[ARG]], %[[ARG]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32> +func.func @return_multiple_copies_of_tensor(%arg0: !torch.vtensor<[],f32>) -> (!torch.tensor, !torch.tensor) { + %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[],f32> to !torch.vtensor + %1 = torch.copy.to_tensor %0 : !torch.tensor + return %1, %1 : !torch.tensor, !torch.tensor +}