Skip to content

Commit

Permalink
Add support for returning more than one copy of the same tensor
Browse files Browse the repository at this point in the history
One of the simplifications made by the pass `RefinePublicReturn`
currently only happens if the tensor in question only has one
user. However, the current method of checking this does not correctly
handle the case of a user having multiple uses of the same
tensor. This commit makes sure only unique users are considered.
  • Loading branch information
ramiro050 committed Aug 15, 2022
1 parent c935795 commit 52402e4
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
5 changes: 4 additions & 1 deletion lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ class RefinePublicReturnPass
// If the return (or transitively other ops) are not the only users,
// then we can't be sure that the tensor hasn't been mutated, so stop
// here.
if (!llvm::hasSingleElement(copy->getUsers()))
SmallVector<Operation *> usersVector(copy->getUsers());
size_t numOfUniqueUsers = std::distance(
usersVector.begin(), llvm::unique(usersVector, std::equal_to()));
if (numOfUniqueUsers != 1)
break;
newOperand = copy.getOperand();
} else {
Expand Down
13 changes: 13 additions & 0 deletions test/Dialect/Torch/refine-public-return.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,16 @@ func.func @called(%arg0: tensor<*xf32>) -> tensor<*xf32> {
^bb2:
return %arg0 : tensor<*xf32>
}

// -----

// CHECK-LABEL: func.func @return_multiple_copies_of_tensor(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> (!torch.vtensor<[],f32>, !torch.vtensor<[],f32>) {
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[ARG]] : !torch.vtensor<[],f32> to !torch.vtensor
// CHECK: %[[TO_TENSOR:.*]] = torch.copy.to_tensor %[[CAST]] : !torch.tensor
// CHECK: return %[[ARG]], %[[ARG]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>
func.func @return_multiple_copies_of_tensor(%arg0: !torch.vtensor<[],f32>) -> (!torch.tensor, !torch.tensor) {
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[],f32> to !torch.vtensor
%1 = torch.copy.to_tensor %0 : !torch.tensor
return %1, %1 : !torch.tensor, !torch.tensor
}

0 comments on commit 52402e4

Please sign in to comment.