Skip to content

Commit

Permalink
Add decomposition and E2E support for Aten_EmbeddingBag (#1137)
Browse files Browse the repository at this point in the history
* Add decomposition and E2E support for Aten_EmbeddingBag
  • Loading branch information
vidsinghal authored Aug 8, 2022
1 parent f85ae9c commit b70548e
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 46 deletions.
1 change: 1 addition & 0 deletions e2e_testing/torchscript/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,4 +336,5 @@
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
"AtenEmbeddingBagSumExample_basic",
"Aten_EmbeddingBagExample_basic",
}
34 changes: 34 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -5117,6 +5117,40 @@ def Torch_AtenEmbeddingBagPaddingIdxOp : Torch_Op<"aten.embedding_bag.padding_id
}];
}

def Torch_Aten_EmbeddingBagOp : Torch_Op<"aten._embedding_bag", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_embedding_bag : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int) -> (Tensor, Tensor, Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$weight,
AnyTorchTensorType:$indices,
AnyTorchTensorType:$offsets,
Torch_BoolType:$scale_grad_by_freq,
Torch_IntType:$mode,
Torch_BoolType:$sparse,
AnyTorchOptionalTensorType:$per_sample_weights,
Torch_BoolType:$include_last_offset,
Torch_IntType:$padding_idx
);
let results = (outs
AnyTorchTensorType:$result0,
AnyTorchTensorType:$result1,
AnyTorchTensorType:$result2,
AnyTorchTensorType:$result3
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_EmbeddingBagOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 9, 4);
}
void Aten_EmbeddingBagOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 9, 4);
}
}];
}

def Torch_AtenEmptyLikeOp : Torch_Op<"aten.empty_like", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
38 changes: 38 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2409,6 +2409,42 @@ class DecomposeAtenSelectScatterOp
};
} // namespace

namespace {
class DecomposeAten_EmbeddingBagOp
: public OpRewritePattern<Aten_EmbeddingBagOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_EmbeddingBagOp op,
PatternRewriter &rewriter) const override {

Location loc = op.getLoc();
Value weight = op.weight();
Value indices = op.indices();
Value offsets = op.offsets();
Value scaleGradByFreq = op.scale_grad_by_freq();
Value mode = op.mode();
Value sparse = op.sparse();
Value perSampleWeights = op.per_sample_weights();
Value includeLastOffset = op.include_last_offset();
Value paddingIdx = op.padding_idx();

auto resultType0 = op->getResult(0).getType();
auto resultType1 = op->getResult(1).getType();
auto resultType2 = op->getResult(2).getType();
auto resultType3 = op->getResult(3).getType();

mlir::TypeRange returnTypes{resultType0, resultType1, resultType2,
resultType3};

rewriter.replaceOpWithNewOp<AtenEmbeddingBagPaddingIdxOp>(
op, returnTypes, weight, indices, offsets, scaleGradByFreq, mode,
sparse, perSampleWeights, includeLastOffset, paddingIdx);

return success();
}
};
} // namespace

namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
Expand Down Expand Up @@ -2572,6 +2608,8 @@ class DecomposeComplexOpsPass
target.addIllegalOp<AtenStdDimOp>();
patterns.add<DecomposeAtenNarrowOp>(context);
target.addIllegalOp<AtenNarrowOp>();
patterns.add<DecomposeAten_EmbeddingBagOp>(context);
target.addIllegalOp<Aten_EmbeddingBagOp>();

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
Expand Down
41 changes: 21 additions & 20 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,7 @@ class TypeAnalysis : public dataflow::SparseDataFlowAnalysis<
ArrayRef<const ValueState *> operands);
void visitAtenScalarImplicitOp(AtenScalarImplicitOp op,
ArrayRef<const ValueState *> operands);
void visitAtenEmbeddingBagOp(Operation *op);
};
} // namespace

Expand Down Expand Up @@ -1052,26 +1053,9 @@ void TypeAnalysis::visitOperation(Operation *op,
incorporateKnowledge(embedding.getResult(), knowledge);
return;
}

// case for Embedding bag padding idx.
if (auto embedding_bag_padding_idx =
dyn_cast<AtenEmbeddingBagPaddingIdxOp>(op)) {

auto resultFloatKnowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
resultFloatKnowledge.dtype = Float32Type::get(op->getContext());

incorporateKnowledge(embedding_bag_padding_idx.getResult(0),
resultFloatKnowledge);
auto resultIntKnowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
resultIntKnowledge.dtype =
IntegerType::get(op->getContext(), 64, IntegerType::Signed);

for (int64_t i = 1; i < 4; i++) {
incorporateKnowledge(embedding_bag_padding_idx.getResult(i),
resultIntKnowledge);
}

if (isa<Aten_EmbeddingBagOp, AtenEmbeddingBagPaddingIdxOp>(op)) {
visitAtenEmbeddingBagOp(op);
return;
}

Expand Down Expand Up @@ -1151,6 +1135,23 @@ void TypeAnalysis::visitAtenLinearOp(AtenLinearOp op,
incorporateKnowledge(op->getResult(0), knowledge);
}

void TypeAnalysis::visitAtenEmbeddingBagOp(Operation *op) {
auto resultFloatKnowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
resultFloatKnowledge.dtype = Float32Type::get(op->getContext());

incorporateKnowledge(op->getResult(0), resultFloatKnowledge);
auto resultIntKnowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
resultIntKnowledge.dtype =
IntegerType::get(op->getContext(), 64, IntegerType::Signed);

for (int64_t i = 1; i < 4; i++) {
incorporateKnowledge(op->getResult(i), resultIntKnowledge);
}
return;
}

// Arange like ops returns a 1-D tensor of size ceil(end - start).
void TypeAnalysis::visitAtenArangeLikeOpHelper(Operation *op,
llvm::Optional<Value> start,
Expand Down
10 changes: 9 additions & 1 deletion lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6470,6 +6470,10 @@ module {
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.embedding_bag.padding_idx"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.bool, %arg6: !torch.optional<list<int>>, %arg7: !torch.bool, %arg8: !torch.optional<int>) -> !torch.tuple<list<int>, list<int>, list<int>, list<int>> {
%0 = call @__torch__._embedding_bag_helper(%arg0, %arg1, %arg2, %arg7, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.int) -> !torch.tuple<list<int>, list<int>, list<int>, list<int>>
return %0 : !torch.tuple<list<int>, list<int>, list<int>, list<int>>
}
func.func @__torch__._embedding_bag_helper(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.bool, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>, list<int>, list<int>> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%int2 = torch.constant.int 2
Expand Down Expand Up @@ -6501,7 +6505,7 @@ module {
}
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
%7 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int
%8 = torch.prim.If %arg7 -> (!torch.int) {
%8 = torch.prim.If %arg3 -> (!torch.int) {
%19 = torch.aten.sub.int %7, %int1 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %19 : !torch.int
} else {
Expand Down Expand Up @@ -6531,6 +6535,10 @@ module {
%18 = torch.prim.TupleConstruct %6, %14, %15, %17 : !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>, list<int>>
return %18 : !torch.tuple<list<int>, list<int>, list<int>, list<int>>
}
func.func @"__torch_mlir_shape_fn.aten._embedding_bag"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.bool, %arg6: !torch.optional<list<int>>, %arg7: !torch.bool, %arg8: !torch.int) -> !torch.tuple<list<int>, list<int>, list<int>, list<int>> {
%0 = call @__torch__._embedding_bag_helper(%arg0, %arg1, %arg2, %arg7, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.int) -> !torch.tuple<list<int>, list<int>, list<int>, list<int>>
return %0 : !torch.tuple<list<int>, list<int>, list<int>, list<int>>
}
func.func @"__torch_mlir_shape_fn.aten.nll_loss_forward"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {
%int-1 = torch.constant.int -1
%true = torch.constant.bool true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,34 @@ def __repr__(self):
else:
return f"TensorOfShape({args_str}, dtype={self.dtype})"

def _embedding_bag_helper(weight: List[int], indices: List[int], offsets: List[int], include_last_offset: bool, mode: int):
assert len(weight) == 2
assert len(indices) == 1
assert len(offsets) == 1
output_bag_shape: List[int] = []
out_dim0 = offsets[0]
if (include_last_offset):
out_dim0 = out_dim0 - 1
out_dim1 = weight[1]
output_bag_shape.append(out_dim0)
output_bag_shape.append(out_dim1)

offset2bag_shape: List[int] = []
if mode == 1:
offset2bag_shape.append(0)
else:
offset2bag_shape = upstream_shape_functions._copy(indices)

bag_size_shape = upstream_shape_functions._copy(offsets)

max_indices_shape: List[int] = []
if mode == 2:
max_indices_shape = upstream_shape_functions._copy(output_bag_shape)
else:
max_indices_shape = upstream_shape_functions._copy(offsets)

return output_bag_shape, offset2bag_shape, bag_size_shape, max_indices_shape

def LongTensorOfShape(*args, **kwargs):
"""Helper for indicating a TensorOfShape with integer type."""
return TensorOfShape(*args, **kwargs, dtype=torch.long)
Expand Down Expand Up @@ -954,32 +982,10 @@ def aten〇embedding(weight: List[int], indices: List[int], padding_idx: int = -
return upstream_shape_functions.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse)

def aten〇embedding_bag〇padding_idx(weight: List[int], indices: List[int], offsets: List[int], scale_grad_by_freq: bool, mode: int, sparse: bool, per_sample_weights: Optional[List[int]], include_last_offset: bool, padding_idx: Optional[int]) -> Tuple[List[int], List[int], List[int], List[int]]:
assert len(weight) == 2
assert len(indices) == 1
assert len(offsets) == 1
output_bag_shape: List[int] = []
out_dim0 = offsets[0]
if (include_last_offset):
out_dim0 = out_dim0 - 1
out_dim1 = weight[1]
output_bag_shape.append(out_dim0)
output_bag_shape.append(out_dim1)

offset2bag_shape: List[int] = []
if mode == 1:
offset2bag_shape.append(0)
else:
offset2bag_shape = upstream_shape_functions._copy(indices)
return _embedding_bag_helper(weight, indices, offsets, include_last_offset, mode)

bag_size_shape = upstream_shape_functions._copy(offsets)

max_indices_shape: List[int] = []
if mode == 2:
max_indices_shape = upstream_shape_functions._copy(output_bag_shape)
else:
max_indices_shape = upstream_shape_functions._copy(offsets)

return output_bag_shape, offset2bag_shape, bag_size_shape, max_indices_shape
def aten〇_embedding_bag(weight: List[int], indices: List[int], offsets: List[int], scale_grad_by_freq: bool = False, mode: int = 0, sparse: bool = False, per_sample_weights: Optional[List[int]] = None, include_last_offset: bool = False, padding_idx: int = -1) -> Tuple[List[int], List[int], List[int], List[int]]:
return _embedding_bag_helper(weight, indices, offsets, include_last_offset, mode)

@check_shape_function([
Invocation(TensorOfShape(2, 3), LongTensorOfShape(2), None, 1, -100), # Basic case.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::detach : (Tensor) -> (Tensor)")
emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
emit("aten::embedding_bag.padding_idx : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int?) -> (Tensor, Tensor, Tensor, Tensor)")
emit("aten::_embedding_bag : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int) -> (Tensor, Tensor, Tensor, Tensor)")
emit("aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::new_empty : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
Expand Down
22 changes: 22 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2605,3 +2605,25 @@ def AtenEmbeddingBagSumExample_basic(module, tu: TestUtils):
indices = torch.LongTensor([0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54])
offsets = torch.LongTensor([0, 3, 5, 7, 9, 10, 15])
module.forward(weight, indices, offsets)

class Aten_EmbeddingBagExample(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1], torch.int64, True),
([-1], torch.int64, True),
])
def forward(self, weight, indices, offsets):
return torch.ops.aten._embedding_bag(weight, indices, offsets)

@register_test_case(module_factory=lambda: Aten_EmbeddingBagExample())
def Aten_EmbeddingBagExample_basic(module, tu: TestUtils):
weight = torch.rand(100, 10)
indices = torch.LongTensor([0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54])
offsets = torch.LongTensor([0, 3, 5, 7, 9, 10, 15])
module.forward(weight, indices, offsets)

0 comments on commit b70548e

Please sign in to comment.