From 23c11c642ffc1650a9b4257d84b25df09f409b03 Mon Sep 17 00:00:00 2001
From: Vidush Singhal
Date: Fri, 15 Jul 2022 16:43:25 -0400
Subject: [PATCH] E2E support for AtenEmbeddingBagPaddingIdxOp SUM Mode

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  34 +++
 .../Dialect/Torch/Utils/TorchUpstream.h       |   7 +
 .../TorchToLinalg/IndirectDataMovement.cpp    | 255 ++++++++++++++++++
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |  22 ++
 lib/Dialect/Torch/Transforms/ShapeLibrary.cpp |  62 +++++
 .../jit_ir/build_tools/shape_lib_gen.py       |  28 ++
 .../jit_ir/build_tools/torch_ods_gen.py       |   1 +
 .../torch_mlir_e2e_test/test_suite/basic.py   |  22 ++
 8 files changed, 431 insertions(+)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index b40eaaafd88c..b2bf98cbebcd 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4984,6 +4984,40 @@ def Torch_AtenEmbeddingOp : Torch_Op<"aten.embedding", [
   }];
 }
 
+def Torch_AtenEmbeddingBagPaddingIdxOp : Torch_Op<"aten.embedding_bag.padding_idx", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::embedding_bag.padding_idx : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int?) -> (Tensor, Tensor, Tensor, Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$weight,
+    AnyTorchTensorType:$indices,
+    AnyTorchTensorType:$offsets,
+    Torch_BoolType:$scale_grad_by_freq,
+    Torch_IntType:$mode,
+    Torch_BoolType:$sparse,
+    AnyTorchOptionalTensorType:$per_sample_weights,
+    Torch_BoolType:$include_last_offset,
+    AnyTorchOptionalIntType:$padding_idx
+  );
+  let results = (outs
+    AnyTorchTensorType:$result0,
+    AnyTorchTensorType:$result1,
+    AnyTorchTensorType:$result2,
+    AnyTorchTensorType:$result3
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenEmbeddingBagPaddingIdxOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 9, 4);
+    }
+    void AtenEmbeddingBagPaddingIdxOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 9, 4);
+    }
+  }];
+}
+
 def Torch_AtenEmptyLikeOp : Torch_Op<"aten.empty_like", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
index 0d2d75b7818f..d6bc0a699172 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
@@ -153,6 +153,13 @@ enum MemoryFormat {
 //===----------------------------------------------------------------------===//
 enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };
 
+//===----------------------------------------------------------------------===//
+// Possible values for the `mode` argument of embedding bag ops.
+// Source:
+// https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
+//===----------------------------------------------------------------------===//
+enum EmbeddingBagMode { MODE_SUM, MODE_MEAN, MODE_MAX };
+
 } // namespace torch_upstream
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
index 13ae325f5e86..115feef429c1 100644
--- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
@@ -168,6 +168,259 @@ class ConvertAtenEmbeddingOp : public OpConversionPattern<AtenEmbeddingOp> {
 };
 } // namespace
 
+namespace {
+// AtenEmbeddingBagPaddingIdxOp
+// SUM mode == integer 0
+// Sums bags of embeddings together from a weight tensor, based on an indices
+// vector and an offsets vector. Example arguments: weight = [[1, 3, 5, 3],
+//                                                            [3, 4, 2, 1],
+//                                                            [2, 2, 3, 2],
+//                                                            [0, 4, 2, 1]]
+//
+// indices = [0, 2, 3, 1, 2, 3, 2, 1, 0, 1]
+// offsets = [0, 3, 5]
+//
+// output_tensor = initZeroTensor(offsets_length, embedding_size)
+//
+// for i in range(offsets_length):         <- dim0
+//   for j in range(indices_length):       <- dim1
+//     for k in range(embedding_size):     <- dim2
+//       if(offsets[i] <= j and j < offsets[i+1]):
+//         output_tensor[i][k] = output_tensor[i][k] +
+//                               weight[indices[j]][k]
+//       else:
+//         break
+//
+// Indexing maps for the linalg::GenericOp:
+//
+//
+// indices_indexing_map = (d0, d1, d2) -> (d1)
+// offset_indexing_map  = (d0, d1, d2) -> (d0)
+// output_indexing_map  = (d0, d1, d2) -> (d0, d2)
+//
+class ConvertAtenEmbeddingBagPaddingIdxOp
+    : public OpConversionPattern<AtenEmbeddingBagPaddingIdxOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenEmbeddingBagPaddingIdxOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+    Location loc = op->getLoc();
+    auto context = op->getContext();
+    Value weight = adaptor.weight();
+    Value indices = adaptor.indices();
+    Value offsets = adaptor.offsets();
+    Value scaleGradByFreq = adaptor.scale_grad_by_freq();
+    Value mode = op.mode();
+    Value sparse = op.sparse();
+    Value includeLastOffset = op.include_last_offset();
+
+    int64_t modeInt;
+    if (!matchPattern(mode, m_TorchConstantInt(&modeInt))) {
+      return rewriter.notifyMatchFailure(
+          op, "mode is expected to be a constant integer value.");
+    }
+
+    if (modeInt != torch_upstream::EmbeddingBagMode::MODE_SUM) {
+      return rewriter.notifyMatchFailure(
+          op,
+          "Unimplemented: Mean and Max mode are not supported yet for EmbeddingBag.");
+    }
+
+    bool isSparse;
+    if (!matchPattern(sparse, m_TorchConstantBool(&isSparse))) {
+      return rewriter.notifyMatchFailure(
+          op, "sparse is expected to be a constant boolean value.");
+    }
+
+    if (isSparse) {
+      return rewriter.notifyMatchFailure(
+          op,
+          "Unimplemented: Sparse mode is not supported yet for EmbeddingBag.");
+    }
+
+    bool discardLastOffset;
+    if (!matchPattern(includeLastOffset,
+                      m_TorchConstantBool(&discardLastOffset))) {
+      return rewriter.notifyMatchFailure(
+          op,
+          "include_last_offset is expected to be a constant boolean value.");
+    }
+
+    auto weightTy = weight.getType().cast<RankedTensorType>();
+    if (weightTy.getRank() != 2)
+      return rewriter.notifyMatchFailure(op, "weight must be rank 2");
+
+    auto indicesTy = indices.getType().cast<RankedTensorType>();
+    if (indicesTy.getRank() != 1)
+      return rewriter.notifyMatchFailure(op, "indices must be a vector");
+
+    auto offsetsTy = offsets.getType().cast<RankedTensorType>();
+    if (offsetsTy.getRank() != 1)
+      return rewriter.notifyMatchFailure(op, "offsets must be a vector");
+
+    Type weightElemTy = weightTy.getElementType();
+
+    int64_t iterationMapDimension = weightTy.getRank() + indicesTy.getRank();
+    SmallVector<AffineExpr> indicesExpr;
+    indicesExpr.push_back(mlir::getAffineDimExpr(1, context));
+    auto indicesIndexingMap =
+        AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0,
+                       indicesExpr, context);
+
+    SmallVector<AffineExpr> offsetsExpr;
+    offsetsExpr.push_back(mlir::getAffineDimExpr(0, context));
+
+    auto offsetIndexingMap =
+        AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0,
+                       offsetsExpr, context);
+
+    SmallVector<AffineExpr> outputExpr;
+    outputExpr.push_back(mlir::getAffineDimExpr(0, context));
+    outputExpr.push_back(mlir::getAffineDimExpr(2, context));
+
+    auto outputIndexingMap =
+        AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0,
+                       outputExpr, context);
+
+    SmallVector<AffineMap, 3> indexingMaps = {
+        indicesIndexingMap,
+        offsetIndexingMap,
+        outputIndexingMap,
+    };
+
+    SmallVector<StringRef> iteratorTypes(iterationMapDimension,
+                                         getParallelIteratorTypeName());
+
+    Value embeddingDim = getDimOp(rewriter, loc, weight, 1);
+    Value initTensor;
+    Value offsetsLength;
+    Value indicesLength;
+    if (!discardLastOffset) {
+      SmallVector<Value> sizes{getDimOp(rewriter, loc, offsets, 0),
+                               embeddingDim};
+
+      initTensor = createZeroInitTensor(rewriter, loc, sizes, weightElemTy);
+      offsetsLength = getDimOp(rewriter, loc, offsets, 0);
+      indicesLength = getDimOp(rewriter, loc, indices, 0);
+    } else {
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented: include last offset is not yet "
+              "supported for EmbeddingBag.");
+    }
+
+    Value embeddingBagResult =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, initTensor.getType(), ValueRange{indices, offsets},
+                initTensor,
+                /*indexingMaps=*/indexingMaps,
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value indexInIndices = args[0];
+                  Value offsetsI = args[1];
+                  Value initTensorElem = args[2];
+
+                  Value indexI = b.create<linalg::IndexOp>(loc, /*value=*/0);
+                  Value indexIToInt = castIndexToInt64(b, loc, indexI);
+                  Value one = getConstant(
+                      b, loc, 1,
+                      mlir::IntegerType::get(getContext(), 64,
+                                             IntegerType::Signless));
+                  Value offsetIndexPlusOneInt =
+                      b.create<arith::AddIOp>(loc, indexIToInt, one);
+
+                  Value offsetIndexPlusOne =
+                      castIntToIndex(b, loc, offsetIndexPlusOneInt);
+                  Value checkLast = b.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::eq,
+                      castIndexToInt64(b, loc, offsetsLength),
+                      offsetIndexPlusOneInt);
+                  Value nextOffset = b.create<arith::SelectOp>(
+                      loc, checkLast, castIndexToInt64(b, loc, indicesLength),
+                      b.create<tensor::ExtractOp>(loc, offsets,
+                                                  offsetIndexPlusOne));
+
+                  Value indicesIndex = castIndexToInt64(
+                      b, loc, b.create<linalg::IndexOp>(loc, /*value=*/1));
+
+                  Value offsetLessThanIndicesIndex = b.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::slt, offsetsI, indicesIndex);
+                  Value offsetEqualToIndicesIndex = b.create<arith::CmpIOp>(
+                      loc, arith::CmpIPredicate::eq, offsetsI, indicesIndex);
+                  Value offsetLessThanOrEqualToIndicesIndex =
+                      b.create<arith::OrIOp>(loc, offsetLessThanIndicesIndex,
+                                             offsetEqualToIndicesIndex);
+
+                  Value indicesIndexLessThanNextOffset =
+                      b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                                              indicesIndex, nextOffset);
+
+                  Value indicesIndexWithinBounds = b.create<arith::AndIOp>(
+                      loc, offsetLessThanOrEqualToIndicesIndex,
+                      indicesIndexLessThanNextOffset);
+
+                  SmallVector<Value> indexIntoWeight;
+                  indexIntoWeight.push_back(
+                      castIntToIndex(b, loc, indexInIndices));
+                  indexIntoWeight.push_back(
+                      b.create<linalg::IndexOp>(loc, /*value=*/2));
+                  Value weightElem = b.create<tensor::ExtractOp>(
+                      loc, weight, indexIntoWeight);
+
+                  Value addResult = b.create<arith::AddFOp>(loc, weightElem,
+                                                            initTensorElem);
+                  Value select =
+                      b.create<arith::SelectOp>(loc, indicesIndexWithinBounds,
+                                                addResult, initTensorElem);
+                  b.create<linalg::YieldOp>(loc, select);
+                })
+            .getResult(0);
+
+    // Cast to the converted output type.
+    auto resultType0 = typeConverter->convertType(op->getResult(0).getType());
+    Value castedEmbeddingBagResult =
+        rewriter.create<tensor::CastOp>(loc, resultType0, embeddingBagResult);
+
+    // offset2bag: this should be an empty tensor in the sum mode.
+    SmallVector<Value> offsetResultSize;
+    Type offsetElemTy = offsetsTy.getElementType();
+    Value zeroDim = rewriter.create<arith::ConstantIndexOp>(loc, /*value=*/0);
+    offsetResultSize.push_back(zeroDim);
+    Value offsetResult = rewriter.create<linalg::InitTensorOp>(
+        loc, offsetResultSize, offsetElemTy);
+    auto resultType1 = typeConverter->convertType(op->getResult(1).getType());
+    Value castedOffsetResult =
+        rewriter.create<tensor::CastOp>(loc, resultType1, offsetResult);
+
+    SmallVector<Value> offsetSize = getTensorSizes(rewriter, loc, offsets);
+    // bag_size: a vector with the same size as offsets, filled with zeros; in
+    // the sum mode this is always just a vector of zeros.
+    Value bagSize =
+        createZeroInitTensor(rewriter, loc, offsetSize, offsetElemTy);
+    auto resultType2 = typeConverter->convertType(op->getResult(2).getType());
+    Value castedBagSizeResult =
+        rewriter.create<tensor::CastOp>(loc, resultType2, bagSize);
+
+    // max_indices: a vector of the same size as offsets, filled with zeros; it
+    // is always zeros in the sum mode and is mainly used in the max mode.
+    Value indicesOut =
+        createZeroInitTensor(rewriter, loc, offsetSize, offsetElemTy);
+    auto resultType3 = typeConverter->convertType(op->getResult(3).getType());
+    Value castedMaxIndices =
+        rewriter.create<tensor::CastOp>(loc, resultType3, indicesOut);
+
+    rewriter.replaceOp(op, {castedEmbeddingBagResult, castedOffsetResult,
+                            castedBagSizeResult, castedMaxIndices});
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Let's say we have an input tensor: initialized with some random values of
 // size [4, 5, 6]. An index tensor (always 1-d): [0, 2] of size [2], and an
@@ -365,4 +618,6 @@ void mlir::torch::torch_to_linalg::
   patterns.add<ConvertAtenIndexSelectOp>(typeConverter, context);
   target.addIllegalOp<AtenIndexTensorOp>();
   patterns.add<ConvertAtenIndexTensorOp>(typeConverter, context);
+  target.addIllegalOp<AtenEmbeddingBagPaddingIdxOp>();
+  patterns.add<ConvertAtenEmbeddingBagPaddingIdxOp>(typeConverter, context);
 }
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 1eff5b7cab3a..9057be4ea6a4 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -1036,6 +1036,28 @@ void TypeAnalysis::visitOperation(Operation *op,
     return;
   }
 
+  // Case for AtenEmbeddingBagPaddingIdxOp.
+  if (auto embedding_bag_padding_idx =
+          dyn_cast<AtenEmbeddingBagPaddingIdxOp>(op)) {
+
+    auto resultFloatKnowledge =
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
+    resultFloatKnowledge.dtype = Float32Type::get(op->getContext());
+
+    incorporateKnowledge(embedding_bag_padding_idx.getResult(0),
+                         resultFloatKnowledge);
+    auto resultIntKnowledge =
+        ValueKnowledge::getTensorPessimisticValueState(op->getContext());
+    resultIntKnowledge.dtype =
+        IntegerType::get(op->getContext(), 64, IntegerType::Signed);
+
+    for (int64_t i = 1; i < 4; i++) {
+      incorporateKnowledge(embedding_bag_padding_idx.getResult(i),
+                           resultIntKnowledge);
+    }
+    return;
+  }
+
   if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
     visitAtenSoftmaxLikeOp(softmaxIntOp, operands);
     return;
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index ad52744936c2..df7adcfd6866 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -6334,6 +6334,68 @@ module {
     %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
+  func.func @"__torch_mlir_shape_fn.aten.embedding_bag.padding_idx"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.bool, %arg6: !torch.optional<list<int>>, %arg7: !torch.bool, %arg8: !torch.optional<int>) -> !torch.tuple<list<int>, list<int>, list<int>, list<int>> {
+    %none = torch.constant.none
+    %str = torch.constant.str "AssertionError: "
+    %int2 = torch.constant.int 2
+    %int1 = torch.constant.int 1
+    %int0 = torch.constant.int 0
+    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
+    %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool
+    torch.prim.If %1 -> () {
+      torch.prim.If.yield
+    } else {
+      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
+      torch.prim.If.yield
+    }
+    %2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
+    %3 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool
+    torch.prim.If %3 -> () {
+      torch.prim.If.yield
+    } else {
+      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
+      torch.prim.If.yield
+    }
+    %4 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
+    %5 = torch.aten.eq.int %4, %int1 : !torch.int, !torch.int -> !torch.bool
+    torch.prim.If %5 -> () {
+      torch.prim.If.yield
+    } else {
+      torch.prim.RaiseException %str, %none : !torch.str, !torch.none
+      torch.prim.If.yield
+    }
+    %6 = torch.prim.ListConstruct : () -> !torch.list<int>
+    %7 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int
+    %8 = torch.prim.If %arg7 -> (!torch.int) {
+      %19 = torch.aten.sub.int %7, %int1 : !torch.int, !torch.int -> !torch.int
+      torch.prim.If.yield %19 : !torch.int
+    } else {
+      torch.prim.If.yield %7 : !torch.int
+    }
+    %9 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int
+    %10 = torch.aten.append.t %6, %8 : !torch.list<int>, !torch.int -> !torch.list<int>
+    %11 = torch.aten.append.t %6, %9 : !torch.list<int>, !torch.int -> !torch.list<int>
+    %12 = torch.prim.ListConstruct : () -> !torch.list<int>
+    %13 = torch.aten.eq.int %arg4, %int1 : !torch.int, !torch.int -> !torch.bool
+    %14 = torch.prim.If %13 -> (!torch.list<int>) {
+      %19 = torch.aten.append.t %12, %int0 : !torch.list<int>, !torch.int -> !torch.list<int>
+      torch.prim.If.yield %12 : !torch.list<int>
+    } else {
+      %19 = func.call @__torch__.torch.jit._shape_functions._copy(%arg1) : (!torch.list<int>) -> !torch.list<int>
+      torch.prim.If.yield %19 : !torch.list<int>
+    }
+    %15 = call @__torch__.torch.jit._shape_functions._copy(%arg2) : (!torch.list<int>) -> !torch.list<int>
+    %16 = torch.aten.eq.int %arg4, %int2 : !torch.int, !torch.int -> !torch.bool
+    %17 = torch.prim.If %16 -> (!torch.list<int>) {
+      %19 = func.call @__torch__.torch.jit._shape_functions._copy(%6) : (!torch.list<int>) -> !torch.list<int>
+      torch.prim.If.yield %19 : !torch.list<int>
+    } else {
+      %19 = func.call @__torch__.torch.jit._shape_functions._copy(%arg2) : (!torch.list<int>) -> !torch.list<int>
+      torch.prim.If.yield %19 : !torch.list<int>
+    }
+    %18 = torch.prim.TupleConstruct %6, %14, %15, %17 : !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>, list<int>>
+    return %18 : !torch.tuple<list<int>, list<int>, list<int>, list<int>>
+  }
   func.func @"__torch_mlir_shape_fn.aten.nll_loss_forward"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {
     %int-1 = torch.constant.int -1
     %true = torch.constant.bool true
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index 1aa13ef5e45d..d37c3331ab12 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -934,6 +934,34 @@ def aten〇index_put〇hacked_twin(self: List[int], indices: List[List[int]], va
 def aten〇embedding(weight: List[int], indices: List[int], padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False) -> List[int]:
     return upstream_shape_functions.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse)
 
+def aten〇embedding_bag〇padding_idx(weight: List[int], indices: List[int], offsets: List[int], scale_grad_by_freq: bool, mode: int, sparse: bool, per_sample_weights: Optional[List[int]], include_last_offset: bool, padding_idx: Optional[int]) -> Tuple[List[int], List[int], List[int], List[int]]:
+    assert len(weight) == 2
+    assert len(indices) == 1
+    assert len(offsets) == 1
+    output_bag_shape: List[int] = []
+    out_dim0 = offsets[0]
+    if include_last_offset:
+        out_dim0 = out_dim0 - 1
+    out_dim1 = weight[1]
+    output_bag_shape.append(out_dim0)
+    output_bag_shape.append(out_dim1)
+
+    offset2bag_shape: List[int] = []
+    if mode == 1:
+        offset2bag_shape.append(0)
+    else:
+        offset2bag_shape = upstream_shape_functions._copy(indices)
+
+    bag_size_shape = upstream_shape_functions._copy(offsets)
+
+    max_indices_shape: List[int] = []
+    if mode == 2:
+        max_indices_shape = upstream_shape_functions._copy(output_bag_shape)
+    else:
+        max_indices_shape = upstream_shape_functions._copy(offsets)
+
+    return output_bag_shape, offset2bag_shape, bag_size_shape, max_indices_shape
+
 @check_shape_function([
     Invocation(TensorOfShape(2, 3), LongTensorOfShape(2), None, 1, -100), # Basic case.
     Invocation(TensorOfShape(3), LongTensorOfShape(), None, 1, -100), # No batch dim.
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 48ee9722f1af..5e4f7a2763bb 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -422,6 +422,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::_to_copy : (Tensor, int?, int?, Device?, bool?, bool, int?) -> (Tensor)")
-> (Tensor)") emit("aten::detach : (Tensor) -> (Tensor)") emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)") + emit("aten::embedding_bag.padding_idx : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int?) -> (Tensor, Tensor, Tensor, Tensor)") emit("aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") emit("aten::new_empty : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)") diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index 1100f2532ead..b3e44ac4177b 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -2459,3 +2459,25 @@ def forward(self, lhs): @register_test_case(module_factory=lambda: NumpyTRank0Module()) def NumpyTRank0Module_basic(module, tu: TestUtils): module.forward(torch.tensor(7, dtype=torch.float32)) + +class AtenEmbeddingBagSumExample(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.int64, True), + ([-1], torch.int64, True), + ]) + def forward(self, weight, indices, offsets): + return torch.ops.aten.embedding_bag(weight, indices, offsets, scale_grad_by_freq=False, mode=0, sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=None) + +@register_test_case(module_factory=lambda: AtenEmbeddingBagSumExample()) +def AtenEmbeddingBagSumExample_basic(module, tu: TestUtils): + weight = torch.rand(100, 10) + indices = torch.LongTensor([0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54]) + offsets = torch.LongTensor([0, 3, 5, 7, 9, 10, 15]) + module.forward(weight, indices, offsets)