Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MHLO] Add configurable op decomposition rules #1172

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions include/torch-mlir/Dialect/Torch/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,20 @@ struct TorchLoweringPipelineOptions
// If this option is false, only do the bare minimum for correctness.
Option<bool> optimize{*this, "optimize", llvm::cl::desc("Do optimizations."),
llvm::cl::init(true)};

// If this option is false, decompose complex operations.
// If this option is true, skip decomposition of complex operations.
Option<bool> decompose{*this, "decompose-complex-ops", llvm::cl::desc("Decompose complex operations."),
llvm::cl::init(true)};
Option<bool> decompose{*this, "decompose-complex-ops",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer that we have something a bit more principled here. Like one option which is an allowlist and another which is a denylist so users can control it more precisely.

llvm::cl::desc("Decompose complex operations."),
llvm::cl::init(true)};

// Disable decomposition rules for torch ops. Useless when
// decompose-complex-ops set to false.
Option<std::string> torchDecompositionSkipOps{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be a list<string> I think. And it should not have torch in the name.

*this, "torch-decomposition-skip-ops",
llvm::cl::desc("Disable decomposition rules for these torch ops. Useless "
"when decompose-complex-ops set to false."),
llvm::cl::init(",")};
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we change the default value to "" simply?

};

/// Creates a pipeline that lowers the object graph IR that is produced by
Expand Down Expand Up @@ -68,6 +77,12 @@ std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();

std::unique_ptr<OperationPass<func::FuncOp>> createDecomposeComplexOpsPass();

std::unique_ptr<OperationPass<func::FuncOp>>
createDecomposeComplexOpsPass(llvm::ArrayRef<std::string> skipOps);

std::unique_ptr<OperationPass<func::FuncOp>>
createDecomposeComplexOpsPass(std::string skipOps);

std::unique_ptr<OperationPass<ModuleOp>> createPreprocessShapeLibraryPass();

std::unique_ptr<OperationPass<ModuleOp>> createReifyShapeCalculationsPass();
Expand Down
8 changes: 7 additions & 1 deletion include/torch-mlir/Dialect/Torch/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,13 @@ def RefinePublicReturn : Pass<"torch-refine-public-return", "ModuleOp"> {

def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "func::FuncOp"> {
let summary = "Decompose complicated torch operations";
let constructor = "mlir::torch::Torch::createDecomposeComplexOpsPass()";
let constructor = "mlir::torch::Torch::createDecomposeComplexOpsPass()";
let options = [
ListOption<"skipOps", "skipOps", "std::string",
"Disable decompositions rules for these operations",
"llvm::cl::ZeroOrMore">
];

let description = [{
Decompose torch operation that are losslessly represented as combinations of
other operations, modulo appropropriate compiler fusion. Note that this pass
Expand Down
Loading