Skip to content

Commit

Permalink
[MHLO] refactor pass configurations (llvm#1315)
Browse files Browse the repository at this point in the history
Related to llvm#1227

1. Reduce MHLO #ifdefs
2. Dismiss compilation warnings
  • Loading branch information
Tanyo Kwok authored and AmosLewis committed Sep 2, 2022
1 parent ae75b03 commit c74f581
Show file tree
Hide file tree
Showing 13 changed files with 374 additions and 381 deletions.
8 changes: 0 additions & 8 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,6 @@ endmacro()
option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON)
if(TORCH_MLIR_ENABLE_MHLO)
add_definitions(-DTORCH_MLIR_ENABLE_MHLO)
# The i64 calculation is much slower than i32 on some devices, such as Nvidia GPU.
# One can truncate from i64 to i32 since dimension sizes are unlikely to exceed
# the range of i32(4GiB)
option(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32
"Enable truncate dimension size from i64 to i32(unsafely)" OFF)
if(TORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32)
add_definitions(-DTORCH_MLIR_ENABLE_MHLO_TRUNC_DIMSIZE_TO_I32)
endif()
endif()

option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)
Expand Down
11 changes: 11 additions & 0 deletions include/torch-mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,17 @@ def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> {
Convert Torch ops to mhlo ops.
}];
let constructor = "mlir::torch::createConvertTorchToMhloPass()";

// Specify any options.
let options = [
Option<"enableStaticShape", "enable-static-shape", "bool", /*default=*/"false",
"Enable static shape conversion">,
// The i64 calculation is much slower than i32 on some devices, such as
// Nvidia GPU. One can truncate from i64 to i32 since dimension sizes
// are unlikely to exceed the range of i32(4GiB)
Option<"enableI32Index", "enable-i32-index", "bool", /*default=*/"false",
"Enable truncate index from i64 to i32(unsafely)">,
];
}
#endif

Expand Down
2 changes: 2 additions & 0 deletions include/torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToMhloPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToMhloPass(bool enableStaticShape, bool enableI32Index);
} // namespace torch
} // namespace mlir

Expand Down
Loading

0 comments on commit c74f581

Please sign in to comment.