[MLIR][TORCH] Fix dynamic cases for aten.index.Tensor

vivekkhandelwal1 committed Aug 19, 2022
1 parent 1e1759c commit 65d811e
Showing 4 changed files with 198 additions and 12 deletions.
4 changes: 4 additions & 0 deletions e2e_testing/torchscript/xfail_sets.py
@@ -273,6 +273,10 @@
"IndexTensorMultiInputThreeIndexers_basic",
"IndexTensorMultiInput_basic",
"IndexTensorSelectDimModule_basic",
"IndexTensorMultiInputContiguousOneDimDynamic_basic",
"IndexTensorMultiInputNonContiguousOneDimDynamic_basic",
"IndexTensorMultiInputNonContiguousDynamic_basic",
"IndexTensorMultiInputNonContiguousMultipleStaticDims_basic",
"Matmul_dot",
"Matmul_matvec",
"MulIntModule_basic",
46 changes: 37 additions & 9 deletions lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
@@ -583,10 +583,11 @@ class ConvertAtenIndexTensorOp : public OpConversionPattern<AtenIndexTensorOp> {
     int replacedIndexCount = indexTensorDims.size();
     int64_t startIndex = contiguous ? firstIndexDim : 0;

-    // Currently we only support statically sized index tensors
-    // when there is more than one index tensor.
-    // TODO: Add support for dynamic size index tensors. This will probably
-    // require broadcasting the index tensors to a common shape.
+    // Currently we only support statically sized index tensors or dynamic size
+    // index tensors without overlapping dynamic dims when there is more than
+    // one index tensor.
+    // TODO: Add support for dynamic size index tensors with overlapping
+    // dynamic dims.
     SmallVector<Value> broadcastedIndexShape;
     if (indexTensors.size() > 1) {
       int maxRank = -1;
@@ -602,12 +603,39 @@ class ConvertAtenIndexTensorOp : public OpConversionPattern<AtenIndexTensorOp> {
       for (auto i : llvm::seq(startIndex, startIndex + maxRank)) {
         auto resultDimSize = refinedResultShape[i];
         if (ShapedType::isDynamic(resultDimSize)) {
-          return rewriter.notifyMatchFailure(
-              op, "unimplemented: index tensors must have static shape if "
-                  "there is more than one index tensor");
+          SmallVector<Value> dynamicDims;
+          int64_t staticDimSize = -1;
+          for (auto indexTensor : indexTensors) {
+            RankedTensorType indexTensorType =
+                indexTensor.getType().cast<RankedTensorType>();
+            int64_t indexTensorRank = indexTensorType.getRank();
+            if ((maxRank - indexTensorRank) > (i - startIndex))
+              continue;
+            int64_t dim = i - startIndex - maxRank + indexTensorRank;
+            if (ShapedType::isDynamic(indexTensorType.getShape()[dim]))
+              dynamicDims.push_back(getDimOp(rewriter, loc, indexTensor, dim));
+            else
+              staticDimSize =
+                  std::max(staticDimSize, indexTensorType.getShape()[dim]);
+          }
+          if (dynamicDims.size() >= 2)
+            return rewriter.notifyMatchFailure(
+                op,
+                "unimplemented: index tensors with overlapping dynamic dims");
+          if (staticDimSize > 1) {
+            Value cstStaticDimSize = getConstant(rewriter, loc, staticDimSize,
+                                                 rewriter.getIndexType());
+            auto equalToRunning = rewriter.create<arith::CmpIOp>(
+                loc, arith::CmpIPredicate::eq, cstStaticDimSize,
+                dynamicDims[0]);
+            rewriter.create<cf::AssertOp>(loc, equalToRunning,
+                                          "mismatched size for broadcast");
+          }
+          broadcastedIndexShape.push_back(dynamicDims[0]);
+        } else {
+          broadcastedIndexShape.push_back(getConstant(
+              rewriter, loc, resultDimSize, rewriter.getIndexType()));
         }
-        broadcastedIndexShape.push_back(
-            getConstant(rewriter, loc, resultDimSize, rewriter.getIndexType()));
       }
     } else {
       // For a single indexing tensor we can simply use its (dynamic) sizes
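The new code path mirrors PyTorch's broadcasting of advanced indices, but only lets one index tensor contribute a dynamic size to any given broadcast dimension; when a static size greater than one meets a dynamic size, the lowering emits a runtime cf.assert that the two agree. Below is a minimal Python sketch of that per-dimension rule; the helper name is illustrative and is not part of the patch.

```python
# Hypothetical helper mirroring the per-dimension broadcast rule in the
# lowering above; not part of the torch-mlir patch itself.
def broadcast_index_dim(sizes):
    """sizes: what each index tensor contributes to one result dim; None = dynamic."""
    dynamic = [s for s in sizes if s is None]
    static = [s for s in sizes if s is not None]
    if len(dynamic) >= 2:
        # Matches the notifyMatchFailure above: overlapping dynamic dims
        # remain unsupported.
        raise NotImplementedError("index tensors with overlapping dynamic dims")
    max_static = max(static, default=1)
    if dynamic:
        # The size cannot be proven at compile time; when max_static > 1 the
        # lowering emits cf.assert ("mismatched size for broadcast") comparing
        # max_static against the runtime size.
        return None
    return max_static

# A static 1 broadcast against a dynamic size stays dynamic (no runtime check
# is needed because the static size is not greater than 1).
assert broadcast_index_dim([1, None]) is None
# Fully static sizes broadcast at compile time.
assert broadcast_index_dim([4, 1]) == 4
```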
125 changes: 122 additions & 3 deletions python/torch_mlir_e2e_test/test_suite/basic.py
@@ -1748,6 +1748,125 @@ def IndexTensorMultiInputOneDim_basic(module, tu: TestUtils):
# ==============================================================================


class IndexTensorMultiInputContiguousOneDimDynamic(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1, 1], torch.int64, True),
([-1], torch.int64, True),
])
def forward(self, x, index1, index2):
return torch.ops.aten.index(x, (
None,
index1,
index2,
))


@register_test_case(
module_factory=lambda: IndexTensorMultiInputContiguousOneDimDynamic())
def IndexTensorMultiInputContiguousOneDimDynamic_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 3), torch.randint(4, (6, 1)),
torch.randint(3, (3, )))
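For reference, a small eager-mode sketch of what this test exercises, using the same shapes as the call above: None keeps dim 0, and the two adjacent index tensors broadcast to (6, 3).

```python
import torch

x = torch.rand(5, 4, 3)
index1 = torch.randint(4, (6, 1))   # leading dim is dynamic (-1) in the annotation
index2 = torch.randint(3, (3,))
# Contiguous advanced indices: the result keeps dim 0 of x and appends the
# broadcast index shape, giving (5, 6, 3).
result = torch.ops.aten.index(x, (None, index1, index2))
assert result.shape == (5, 6, 3)
```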


# ==============================================================================


class IndexTensorMultiInputNonContiguousOneDimDynamic(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1, 1], torch.int64, True),
([-1], torch.int64, True),
])
def forward(self, x, index1, index2):
return torch.ops.aten.index(x, (
index1,
None,
index2,
))


@register_test_case(
module_factory=lambda: IndexTensorMultiInputNonContiguousOneDimDynamic())
def IndexTensorMultiInputNonContiguousOneDimDynamic_basic(
module, tu: TestUtils):
module.forward(tu.rand(5, 4, 3), torch.randint(4, (6, 1)),
torch.randint(3, (3, )))
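The non-contiguous variant differs only in where None sits. Because the advanced indices are separated by a full slice over dim 1, PyTorch moves the broadcast index dimensions to the front of the result; a sketch with the same shapes as the test above:

```python
import torch

x = torch.rand(5, 4, 3)
index1 = torch.randint(4, (6, 1))
index2 = torch.randint(3, (3,))
# Non-contiguous advanced indices: the broadcast shape (6, 3) leads the result
# and the sliced dim of size 4 follows, giving (6, 3, 4).
result = torch.ops.aten.index(x, (index1, None, index2))
assert result.shape == (6, 3, 4)
```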


# ==============================================================================


class IndexTensorMultiInputNonContiguousDynamic(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1, 2], torch.int64, True),
([-1], torch.int64, True),
])
def forward(self, x, index1, index2):
return torch.ops.aten.index(x, (
index2,
None,
index1,
))


@register_test_case(
module_factory=lambda: IndexTensorMultiInputNonContiguousDynamic())
def IndexTensorMultiInputNonContiguousDynamic_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 4, 3), torch.randint(2, (6, 2)),
torch.randint(3, (2, )))


# ==============================================================================


class IndexTensorMultiInputNonContiguousMultipleStaticDims(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1, -1, -1], torch.float32, True),
([4, 1], torch.int64, True),
([1, 3], torch.int64, True),
([-1, 3], torch.int64, True),
])
def forward(self, x, index1, index2, index3):
return torch.ops.aten.index(x, (index1, index2, index3))


@register_test_case(module_factory=lambda:
IndexTensorMultiInputNonContiguousMultipleStaticDims())
def IndexTensorMultiInputNonContiguousMultipleStaticDims_basic(
module, tu: TestUtils):
module.forward(tu.rand(5, 4, 3, 2), torch.randint(3, (4, 1)),
torch.randint(1, (1, 3)), torch.randint(1, (4, 3)))
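This case mixes static and dynamic index shapes: (4, 1) and (1, 3) are fully static, while the third index tensor has a dynamic leading dim that the lowering checks against the static size 4 at runtime. A sketch of the eager behaviour with the same shapes:

```python
import torch

x = torch.rand(5, 4, 3, 2)
index1 = torch.randint(3, (4, 1))
index2 = torch.randint(1, (1, 3))
index3 = torch.randint(1, (4, 3))   # leading dim is dynamic (-1) in the annotation
# The index shapes (4, 1), (1, 3) and (4, 3) broadcast to (4, 3); the untouched
# last dim of x is appended, giving a result of shape (4, 3, 2).
result = torch.ops.aten.index(x, (index1, index2, index3))
assert result.shape == (4, 3, 2)
```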


# ==============================================================================


class IndexTensorMultiInputNonContiguous(torch.nn.Module):

def __init__(self):
@@ -2591,7 +2710,7 @@ def __init__(self):

@export
@annotate_args([
None,
None,
([-1, -1], torch.float32, True),
([-1], torch.int64, True),
([-1], torch.int64, True),
@@ -2613,7 +2732,7 @@ def __init__(self):

@export
@annotate_args([
None,
None,
([-1, -1], torch.float32, True),
([-1], torch.int64, True),
([-1], torch.int64, True),
@@ -2645,4 +2764,4 @@ def forward(self, val):

@register_test_case(module_factory=lambda: AtenToDeviceModule())
def AtenToDeviceModule_basic(module, tu: TestUtils):
module.forward(torch.randn(2, 4))
module.forward(torch.randn(2, 4))
35 changes: 35 additions & 0 deletions test/Conversion/TorchToLinalg/basic.mlir
@@ -233,3 +233,38 @@ func.func @torch.aten.neg.f16(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtenso
%0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f16> -> !torch.vtensor<[?,?],f16>
return %0 : !torch.vtensor<[?,?],f16>
}

// -----

// CHECK-LABEL: func.func @torch.aten.index.Tensor
// CHECK-SAME: (%[[INPUT:.*]]: !torch.vtensor<[?,?,?],f32>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?,1],si64>, %[[ARG2:.*]]: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[T:.*]] = torch_c.to_builtin_tensor %[[INPUT]] : !torch.vtensor<[?,?,?],f32> -> tensor<?x?x?xf32>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[INDICES:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[NONE]], %[[ARG2]] : (!torch.vtensor<[?,1],si64>, !torch.none, !torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
// CHECK: %[[INDEX1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,1],si64> -> tensor<?x1xi64>
// CHECK: %[[INDEX2:.*]] = torch_c.to_builtin_tensor %[[ARG2]] : !torch.vtensor<[?],si64> -> tensor<?xi64>
// CHECK: %[[CST0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM0:.*]] = tensor.dim %[[INDEX1]], %[[CST0]] : tensor<?x1xi64>
// CHECK: %[[CST0_0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM1:.*]] = tensor.dim %[[INDEX2]], %[[CST0_0]] : tensor<?xi64>
// CHECK: %[[CST1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM2:.*]] = tensor.dim %[[T]], %[[CST1]] : tensor<?x?x?xf32>
// CHECK: %[[OUT_T:.*]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]]] : tensor<?x?x?xf32>
// CHECK: %[[OUT:.*]] = linalg.generic {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[INDEX1]], %[[INDEX2]] : tensor<?x1xi64>, tensor<?xi64>) outs(%[[OUT_T]] : tensor<?x?x?xf32>) {
// CHECK: ^bb0(%[[IN1:.*]]: i64, %[[IN2:.*]]: i64, %[[IN3:.*]]: f32):
// CHECK: %[[INDEX_1:.*]] = arith.index_cast %[[IN1]] : i64 to index
// CHECK: %[[INDEX_2:.*]] = linalg.index 2 : index
// CHECK: %[[INDEX_3:.*]] = arith.index_cast %[[IN2]] : i64 to index
// CHECK: %[[RESULT:.*]] = tensor.extract %[[T]][%[[INDEX_1]], %[[INDEX_2]], %[[INDEX_3]]] : tensor<?x?x?xf32>
// CHECK: linalg.yield %[[RESULT]] : f32
// CHECK: } -> tensor<?x?x?xf32>
// CHECK: %[[OUT_CAST:.*]] = tensor.cast %[[OUT]] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
// CHECK: %[[VALUE_OUT_CAST:.*]] = torch_c.from_builtin_tensor %[[OUT_CAST]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[VALUE_OUT_CAST]] : !torch.vtensor<[?,?,?],f32>
func.func @torch.aten.index.Tensor(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,1],si64>, %arg2: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?,?,?],f32> {
%none = torch.constant.none
%1 = torch.prim.ListConstruct %arg1, %none, %arg2 : (!torch.vtensor<[?,1],si64>, !torch.none, !torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
%2 = torch.aten.index.Tensor %arg0, %1 : !torch.vtensor<[?,?,?],f32>, !torch.list<optional<vtensor>> -> !torch.vtensor<[?,?,?],f32>
return %2 : !torch.vtensor<[?,?,?],f32>
}
