Skip to content

Commit

Permalink
edits
Browse files Browse the repository at this point in the history
  • Loading branch information
vidsinghal committed Jul 28, 2022
1 parent c47f0f9 commit 575b0d9
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 12 deletions.
4 changes: 2 additions & 2 deletions include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ enum MemoryFormat {
enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };

//===----------------------------------------------------------------------===//
// Possible value for `EmbeddingBag Mode` argument for Embedding bag ops.
// Source:
// Possible value for `EmbeddingBag Mode` argument for Embedding bag ops.
// Source:
// https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
//===-----------------------------------------------------------------------===//
enum EmbeddingBagMode { MODE_SUM, MODE_MEAN, MODE_MAX };
Expand Down
10 changes: 4 additions & 6 deletions lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,8 @@ class ConvertAtenEmbeddingBagPaddingIdxOp
indicesLength = getDimOp(rewriter, loc, indices, 0);
} else {
return rewriter.notifyMatchFailure(
op,"Unimplemented: include last offset is not yet "
"supported for EmbeddingBag.");
op, "Unimplemented: include last offset is not yet "
"supported for EmbeddingBag.");
}

Value embeddingBagResult =
Expand Down Expand Up @@ -423,12 +423,10 @@ class ConvertAtenEmbeddingBagPaddingIdxOp
"Unimplemented: Mean mode is not supported yet for EmbeddingBag.");
} else if (modeInt == torch_upstream::EmbeddingBagMode::MODE_MAX) {
return rewriter.notifyMatchFailure(
op,
"Unimplemented: Max mode is not supported yet for EmbeddingBag.");
op, "Unimplemented: Max mode is not supported yet for EmbeddingBag.");
} else {
return rewriter.notifyMatchFailure(
op,
"Unimplemented: Unknown mode encountered for EmbeddingBag.");
op, "Unimplemented: Unknown mode encountered for EmbeddingBag.");
}
}
};
Expand Down
9 changes: 5 additions & 4 deletions lib/Conversion/TorchToMhlo/GatherOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,12 @@ class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
ConversionPatternRewriter &rewriter) const override;
};

// Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
// Ref:
// https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html
// padding_idx (int, optional)
// – If specified, the entries at padding_idx do not contribute to the gradient;
// therefore, the embedding vector at padding_idx is not updated during training,
// i.e. it remains as a fixed “pad”.
// – If specified, the entries at padding_idx do not contribute to the
// gradient; therefore, the embedding vector at padding_idx is not updated
// during training, i.e. it remains as a fixed “pad”.
// scale_grad_by_freq (boolean, optional)
// – If given, this will scale gradients by the inverse of frequency of the
// words in the mini-batch. Default False.
Expand Down

0 comments on commit 575b0d9

Please sign in to comment.