Skip to content

Commit

Permalink
E2E support for AtenEmbeddingBagPaddingIdxOp SUM Mode
Browse files Browse the repository at this point in the history
  • Loading branch information
vidsinghal committed Aug 1, 2022
1 parent 3c9addf commit a54e3ea
Show file tree
Hide file tree
Showing 8 changed files with 436 additions and 0 deletions.
34 changes: 34 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -4984,6 +4984,40 @@ def Torch_AtenEmbeddingOp : Torch_Op<"aten.embedding", [
}];
}

def Torch_AtenEmbeddingBagPaddingIdxOp : Torch_Op<"aten.embedding_bag.padding_idx", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::embedding_bag.padding_idx : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int?) -> (Tensor, Tensor, Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$weight,
AnyTorchTensorType:$indices,
AnyTorchTensorType:$offsets,
Torch_BoolType:$scale_grad_by_freq,
Torch_IntType:$mode,
Torch_BoolType:$sparse,
AnyTorchOptionalTensorType:$per_sample_weights,
Torch_BoolType:$include_last_offset,
AnyTorchOptionalIntType:$padding_idx
);
let results = (outs
AnyTorchTensorType:$result0,
AnyTorchTensorType:$result1,
AnyTorchTensorType:$result2,
AnyTorchTensorType:$result3
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenEmbeddingBagPaddingIdxOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 9, 4);
}
void AtenEmbeddingBagPaddingIdxOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 9, 4);
}
}];
}

def Torch_AtenEmptyLikeOp : Torch_Op<"aten.empty_like", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
7 changes: 7 additions & 0 deletions include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ enum MemoryFormat {
//===----------------------------------------------------------------------===//
enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };

//===----------------------------------------------------------------------===//
// Possible value for `EmbeddingBag Mode` argument for Embedding bag ops.
// Source:
// https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
//===-----------------------------------------------------------------------===//
enum EmbeddingBagMode { MODE_SUM, MODE_MEAN, MODE_MAX };

} // namespace torch_upstream
} // namespace torch
} // namespace mlir
Expand Down
260 changes: 260 additions & 0 deletions lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,264 @@ class ConvertAtenEmbeddingOp : public OpConversionPattern<AtenEmbeddingOp> {
};
} // namespace

namespace {
// AtenEmbeddingPaddingIdxOp
// SUM mode == integer 0
// Sums bags of embeddings together from a weight tensor based on an index and
// offset Vector. Example arguments weight = [[1, 3, 5, 3],
// [3, 4, 2, 1],
// [2, 2, 3, 2],
// [0, 4, 2, 1]]
//
// indices = [0, 2, 3, 1, 2, 3, 2, 1, 0, 1]
// offsets = [0, 3, 5]
//
// output_tensor = initZeroTensor(offsets_length, embedding_size)
//
// for i in range(offsets_length): <- dim0
// for j in range(indices_length): <- dim1
// for k in range(embedding_size): <- dim2
// if(offsets[i] <= j and j < offsets[i+1]):
// output_tensor[i][k] = output_tensor[i][k] +
// weight[indices[j]][k]
// else:
// break
//
// Indexing maps for linalg::Generic ops
//
//
// indices_indexing_map = (d0, d1, d2) -> (d1)
// offset_indexing_map = (d0, d1, d2) -> (d0)
// output_indexing_map = (d0, d1, d2) -> (d0, d2)
//
// TODO: Find an optimal lowering.
// current lowering is not optimal for bags of large embeddings.
// Since it traverses the output tensor multiple times.
//
//

class ConvertAtenEmbeddingBagPaddingIdxOp
: public OpConversionPattern<AtenEmbeddingBagPaddingIdxOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenEmbeddingBagPaddingIdxOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op->getLoc();
auto context = op->getContext();
Value weight = adaptor.weight();
Value indices = adaptor.indices();
Value offsets = adaptor.offsets();
Value scaleGradByFreq = adaptor.scale_grad_by_freq();
Value mode = op.mode();
Value sparse = op.sparse();
Value includeLastOffset = op.include_last_offset();

int64_t modeInt;
if (!matchPattern(mode, m_TorchConstantInt(&modeInt))) {
return rewriter.notifyMatchFailure(
op, "mode is expected to be a constant integer value.");
}

if (modeInt != torch_upstream::EmbeddingBagMode::MODE_SUM) {
return rewriter.notifyMatchFailure(
op,
"Unimplemented: Mean and Max mode are not supported yet for EmbeddingBag.");
}

bool isSparse;
if (!matchPattern(sparse, m_TorchConstantBool(&isSparse))) {
return rewriter.notifyMatchFailure(
op, "sparse is expected to be a constant boolean value.");
}

if (isSparse) {
return rewriter.notifyMatchFailure(
op,
"Unimplemented: Sparse mode is not supported yet for EmbeddingBag.");
}

bool discardLastOffset;
if (!matchPattern(includeLastOffset,
m_TorchConstantBool(&discardLastOffset))) {
return rewriter.notifyMatchFailure(
op,
"include_last_offset is expected to be a constant boolean value.");
}

auto weightTy = weight.getType().cast<RankedTensorType>();
if (weightTy.getRank() != 2)
return rewriter.notifyMatchFailure(op, "weight must be rank 2");

auto indicesTy = indices.getType().cast<RankedTensorType>();
if (indicesTy.getRank() != 1)
return rewriter.notifyMatchFailure(op, "indices must be a vector");

auto offsetsTy = offsets.getType().cast<RankedTensorType>();
if (offsetsTy.getRank() != 1)
return rewriter.notifyMatchFailure(op, "offsets much be a vector");

Type weightElemTy = weightTy.getElementType();

int64_t iterationMapDimension = weightTy.getRank() + indicesTy.getRank();
SmallVector<AffineExpr> indicesExpr;
indicesExpr.push_back(mlir::getAffineDimExpr(1, context));
auto indicesIndexingMap =
AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0,
indicesExpr, context);

SmallVector<AffineExpr> offsetsExpr;
offsetsExpr.push_back(mlir::getAffineDimExpr(0, context));

auto offsetIndexingMap =
AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0,
offsetsExpr, context);

SmallVector<AffineExpr> outputExpr;
outputExpr.push_back(mlir::getAffineDimExpr(0, context));
outputExpr.push_back(mlir::getAffineDimExpr(2, context));

auto outputIndexingMap =
AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0,
outputExpr, context);

SmallVector<AffineMap, 3> indexingMaps = {
indicesIndexingMap,
offsetIndexingMap,
outputIndexingMap,
};

SmallVector<StringRef> iteratorTypes(iterationMapDimension,
getParallelIteratorTypeName());

Value embeddingDim = getDimOp(rewriter, loc, weight, 1);
Value initTensor;
Value offsetsLength;
Value indicesLength;
if (!discardLastOffset) {
SmallVector<Value> sizes{getDimOp(rewriter, loc, offsets, 0),
embeddingDim};

initTensor = createZeroInitTensor(rewriter, loc, sizes, weightElemTy);
offsetsLength = getDimOp(rewriter, loc, offsets, 0);
indicesLength = getDimOp(rewriter, loc, indices, 0);
} else {
return rewriter.notifyMatchFailure(
op, "Unimplemented: include last offset is not yet "
"supported for EmbeddingBag.");
}

Value embeddingBagResult =
rewriter
.create<linalg::GenericOp>(
loc, initTensor.getType(), ValueRange{indices, offsets},
initTensor,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value indexInIndices = args[0];
Value offsetsI = args[1];
Value initTensorElem = args[2];

Value indexI = b.create<linalg::IndexOp>(loc, /*value=*/0);
Value indexIToInt = castIndexToInt64(b, loc, indexI);
Value one = getConstant(
b, loc, 1,
mlir::IntegerType::get(getContext(), 64,
IntegerType::Signless));
Value offsetIndexPlusOneInt =
b.create<arith::AddIOp>(loc, indexIToInt, one);

Value offsetIndexPlusOne =
castIntToIndex(b, loc, offsetIndexPlusOneInt);
Value checkLast = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq,
castIndexToInt64(b, loc, offsetsLength),
offsetIndexPlusOneInt);
Value nextOffset = b.create<arith::SelectOp>(
loc, checkLast, castIndexToInt64(b, loc, indicesLength),
b.create<tensor::ExtractOp>(loc, offsets,
offsetIndexPlusOne));

Value indicesIndex = castIndexToInt64(
b, loc, b.create<linalg::IndexOp>(loc, /*value=*/1));

Value offsetLessThanIndicesIndex = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, offsetsI, indicesIndex);
Value offsetEqualToIndicesIndex = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, offsetsI, indicesIndex);
Value offsetLessThanOrEqualToIndicesIndex =
b.create<arith::OrIOp>(loc, offsetLessThanIndicesIndex,
offsetEqualToIndicesIndex);

Value indicesIndexLessThanNextOffset =
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
indicesIndex, nextOffset);

Value indicesIndexWithinBounds = b.create<arith::AndIOp>(
loc, offsetLessThanOrEqualToIndicesIndex,
indicesIndexLessThanNextOffset);

SmallVector<Value> indexIntoWeight;
indexIntoWeight.push_back(
castIntToIndex(b, loc, indexInIndices));
indexIntoWeight.push_back(
b.create<linalg::IndexOp>(loc, /*value=*/2));
Value weightElem = b.create<tensor::ExtractOp>(
loc, weight, indexIntoWeight);

Value addResult = b.create<arith::AddFOp>(loc, weightElem,
initTensorElem);
Value select =
b.create<arith::SelectOp>(loc, indicesIndexWithinBounds,
addResult, initTensorElem);
b.create<linalg::YieldOp>(loc, select);
})
.getResult(0);

// cast outputType.
auto restulType0 = typeConverter->convertType(op->getResult(0).getType());
Value castedEmbeddingBagResult =
rewriter.create<tensor::CastOp>(loc, restulType0, embeddingBagResult);

// offset2 tensor, this should be an empty tensor for the sum mode
SmallVector<Value> offsetResultSize;
Type offsetElemTy = offsetsTy.getElementType();
Value zeroDim = rewriter.create<arith::ConstantIndexOp>(loc, /*value=*/0);
offsetResultSize.push_back(zeroDim);
Value offsetResult = rewriter.create<linalg::InitTensorOp>(
loc, offsetResultSize, offsetElemTy);
auto resultType1 = typeConverter->convertType(op->getResult(1).getType());
Value castedOffsetResult =
rewriter.create<tensor::CastOp>(loc, resultType1, offsetResult);

SmallVector<Value> offsetSize = getTensorSizes(rewriter, loc, offsets);
// bagsize, vector of size offset with zeros, I think this is always just
// a vector of zeros in the sum mode
Value bagSize =
createZeroInitTensor(rewriter, loc, offsetSize, offsetElemTy);
auto resultType2 = typeConverter->convertType(op->getResult(2).getType());
Value castedBagSizeResult =
rewriter.create<tensor::CastOp>(loc, resultType2, bagSize);

// max indices, vector of size offset with zeros, this is also always a
// vector of zeros in the sum mode. Its mainly used in the max mode.
Value indicesOut =
createZeroInitTensor(rewriter, loc, offsetSize, offsetElemTy);
auto resultType3 = typeConverter->convertType(op->getResult(3).getType());
Value castedMaxIndices =
rewriter.create<tensor::CastOp>(loc, resultType3, indicesOut);

rewriter.replaceOp(op, {castedEmbeddingBagResult, castedOffsetResult,
castedBagSizeResult, castedMaxIndices});

return success();
}
};
} // namespace

namespace {
// Let's say we have an input tensor: initialized with some random values of
// size [4, 5, 6]. An index tensor (always 1-d): [0, 2] of size [2], and an
Expand Down Expand Up @@ -365,4 +623,6 @@ void mlir::torch::torch_to_linalg::
patterns.add<ConvertAtenIndexSelectOp>(typeConverter, context);
target.addIllegalOp<AtenIndexTensorOp>();
patterns.add<ConvertAtenIndexTensorOp>(typeConverter, context);
target.addIllegalOp<AtenEmbeddingBagPaddingIdxOp>();
patterns.add<ConvertAtenEmbeddingBagPaddingIdxOp>(typeConverter, context);
}
22 changes: 22 additions & 0 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,28 @@ void TypeAnalysis::visitOperation(Operation *op,
return;
}

// case for Embedding bag padding idx.
if (auto embedding_bag_padding_idx =
dyn_cast<AtenEmbeddingBagPaddingIdxOp>(op)) {

auto resultFloatKnowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
resultFloatKnowledge.dtype = Float32Type::get(op->getContext());

incorporateKnowledge(embedding_bag_padding_idx.getResult(0),
resultFloatKnowledge);
auto resultIntKnowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
resultIntKnowledge.dtype =
IntegerType::get(op->getContext(), 64, IntegerType::Signed);

for (int64_t i = 1; i < 4; i++) {
incorporateKnowledge(embedding_bag_padding_idx.getResult(i),
resultIntKnowledge);
}
return;
}

if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
visitAtenSoftmaxLikeOp(softmaxIntOp, operands);
return;
Expand Down
Loading

0 comments on commit a54e3ea

Please sign in to comment.