Skip to content
This repository has been archived by the owner on Jun 19, 2024. It is now read-only.

Commit

Permalink
fix divide_floor & export promoteTypes api (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tanyo Kwok committed Jun 24, 2022
1 parent 4d8382a commit f081e3e
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 2 deletions.
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ enum MemoryFormat {
//===----------------------------------------------------------------------===//
enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };

ScalarType promoteTypes(ScalarType a, ScalarType b);
} // namespace torch_upstream
} // namespace torch
} // namespace mlir
Expand Down
5 changes: 4 additions & 1 deletion lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2068,8 +2068,11 @@ class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenFloorDivideOp op,
PatternRewriter &rewriter) const override {
// https://pytorch.org/docs/stable/generated/torch.floor_divide.html
// PyTorch aten.floor_divide is a misnomer because it actually rounds
// the quotient towards zero instead of taking its floor.
Value cstStrFloor =
rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "floor");
rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "trunc");
rewriter.replaceOpWithNewOp<AtenDivTensorModeOp>(
op, op.getType(), op.self(), op.other(),
/*rounding_mode=*/cstStrFloor);
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Utils/TorchUpstream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ static inline bool isQIntType(ScalarType t) {
// Type promotion related code are copied from
// aten/src/ATen/native/TypeProperties.*.
//===----------------------------------------------------------------------===//
static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
ScalarType promoteTypes(ScalarType a, ScalarType b) {
// This is generated according to NumPy's promote_types
constexpr auto u1 = ScalarType::Byte;
constexpr auto i1 = ScalarType::Char;
Expand Down

0 comments on commit f081e3e

Please sign in to comment.