From 7a6d51d8d2d49f4f4a7af3e5f8fdf438feaa8530 Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Wed, 1 Nov 2023 08:22:54 -0700 Subject: [PATCH] add whitelist pass option to TorchToTcp add private method move addPatternIfOpInConvertTorchOpsSet to utility function add back typeConverter add inline function in Utils.h update Misc patterns as well update elementwise --- BUILD | 1 + include/mlir-tcp/Conversion/Passes.td | 7 +- .../Conversion/TorchToTcp/TorchToTcp.h | 3 +- lib/Conversion/TorchToTcp/DataMovement.cpp | 9 +- lib/Conversion/TorchToTcp/Elementwise.cpp | 129 ++++++++---------- lib/Conversion/TorchToTcp/Misc.cpp | 51 ++++--- lib/Conversion/TorchToTcp/PopulatePatterns.h | 21 +-- lib/Conversion/TorchToTcp/TorchToTcp.cpp | 37 +++-- lib/Conversion/TorchToTcp/Utils.h | 24 ++++ lib/Pipeline/Pipeline.cpp | 5 +- 10 files changed, 163 insertions(+), 124 deletions(-) diff --git a/BUILD b/BUILD index c5593dce..d22ec247 100644 --- a/BUILD +++ b/BUILD @@ -203,6 +203,7 @@ cc_library( "@torch-mlir//:TorchMLIRConversionUtils", "@torch-mlir//:TorchMLIRTorchBackendTypeConversion", "@torch-mlir//:TorchMLIRTorchConversionDialect", + "@torch-mlir//:TorchMLIRTorchPasses", ], ) diff --git a/include/mlir-tcp/Conversion/Passes.td b/include/mlir-tcp/Conversion/Passes.td index 73450d26..9e86e047 100644 --- a/include/mlir-tcp/Conversion/Passes.td +++ b/include/mlir-tcp/Conversion/Passes.td @@ -21,7 +21,12 @@ def ConvertTorchToTcp : Pass<"convert-torch-to-tcp", "func::FuncOp"> { let description = [{ Convert Torch ops to Tcp ops. }]; - let constructor = "mlir::tcp::createConvertTorchToTcpPass()"; + let constructor = "mlir::tcp::createConvertTorchToTcpPass(/*convertTorchOps=*/{})"; + let options = [ + ListOption<"convertTorchOps", "convert-torch-ops", "std::string", + "List of Torch operation names that should be converted to Tcp", + "llvm::cl::ZeroOrMore"> + ]; } //===----------------------------------------------------------------------===// diff --git a/include/mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h b/include/mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h index 856d0f7b..9e1ee0d6 100644 --- a/include/mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h +++ b/include/mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h @@ -19,7 +19,8 @@ namespace mlir { namespace tcp { -std::unique_ptr> createConvertTorchToTcpPass(); +std::unique_ptr> +createConvertTorchToTcpPass(llvm::ArrayRef convertTorchOps); } // namespace tcp } // namespace mlir diff --git a/lib/Conversion/TorchToTcp/DataMovement.cpp b/lib/Conversion/TorchToTcp/DataMovement.cpp index d6868dd8..0fc3403f 100644 --- a/lib/Conversion/TorchToTcp/DataMovement.cpp +++ b/lib/Conversion/TorchToTcp/DataMovement.cpp @@ -18,6 +18,8 @@ #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "llvm/ADT/StringSet.h" + using namespace mlir; using namespace mlir::tcp; using namespace mlir::torch; @@ -60,8 +62,7 @@ class ConvertAtenCatOp : public OpConversionPattern { void torch_to_tcp::populateDataMovementPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target) { - MLIRContext *context = patterns.getContext(); - target.addIllegalOp(); - patterns.add(typeConverter, context); + ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet) { + torch_to_tcp::addPatternIfOpInConvertTorchOpsSet( + typeConverter, patterns, target, convertTorchOpsSet); } diff --git a/lib/Conversion/TorchToTcp/Elementwise.cpp b/lib/Conversion/TorchToTcp/Elementwise.cpp index 91391630..a98dc545 100644 --- a/lib/Conversion/TorchToTcp/Elementwise.cpp +++ b/lib/Conversion/TorchToTcp/Elementwise.cpp @@ -650,78 +650,59 @@ class ConvertAtenToDtypeOp : public OpConversionPattern { void torch_to_tcp::populateElementwisePatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target) { - MLIRContext *context = patterns.getContext(); - - target.addIllegalOp(); - patterns.add(typeConverter, context); - - target.addIllegalOp(); - patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); - - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - - target.addIllegalOp(); - target.addIllegalOp(); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - - target.addIllegalOp(); - target.addIllegalOp(); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - patterns.add>( - typeConverter, context); - patterns.add>(typeConverter, - context); - patterns.add>( - typeConverter, context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - - target.addIllegalOp(); - target.addIllegalOp(); - patterns.add>(typeConverter, - context); - patterns.add>( - typeConverter, context); - - target.addIllegalOp(); - patterns.add(typeConverter, context); - - target.addIllegalOp(); - patterns.add(typeConverter, context); + ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet) { + +#define INSERT_ATEN_ELEMENTWISE_OP_PATTERN(ConvertAtenOpPattern, AtenOp) \ + torch_to_tcp::addPatternIfOpInConvertTorchOpsSet( \ + typeConverter, patterns, target, convertTorchOpsSet) + INSERT_ATEN_ELEMENTWISE_OP_PATTERN(ConvertAtenToDtypeOp, AtenToDtypeOp); + INSERT_ATEN_ELEMENTWISE_OP_PATTERN(ConvertAtenClampOp, AtenClampOp); + INSERT_ATEN_ELEMENTWISE_OP_PATTERN(ConvertAtenReluOp, AtenReluOp); + INSERT_ATEN_ELEMENTWISE_OP_PATTERN(ConvertAtenBatchNormOp, AtenBatchNormOp); + INSERT_ATEN_ELEMENTWISE_OP_PATTERN(ConvertAtenAtan2Op, AtenAtan2Op); +#undef INSERT_ATEN_ELEMENTWISE_OP_PATTERN + +#define INSERT_ATEN_ELEMENTWISE_ADD_SUB_PATTERN(AtenOp, TcpOp) \ + torch_to_tcp::addPatternIfOpInConvertTorchOpsSet< \ + ConvertAtenAddSubOp, AtenOp>(typeConverter, patterns, \ + target, convertTorchOpsSet) + INSERT_ATEN_ELEMENTWISE_ADD_SUB_PATTERN(AtenAddTensorOp, tcp::AddOp); + INSERT_ATEN_ELEMENTWISE_ADD_SUB_PATTERN(AtenSubTensorOp, tcp::SubOp); + INSERT_ATEN_ELEMENTWISE_ADD_SUB_PATTERN(AtenAddScalarOp, tcp::AddOp); + INSERT_ATEN_ELEMENTWISE_ADD_SUB_PATTERN(AtenSubScalarOp, tcp::SubOp); +#undef INSERT_ATEN_ELEMENTWISE_ADD_SUB_PATTERN + +#define INSERT_ATEN_ELEMENTWISE_MUL_DIV_PATTERN(ConvertAtenOpPattern, AtenOp) \ + torch_to_tcp::addPatternIfOpInConvertTorchOpsSet< \ + ConvertAtenOpPattern, AtenOp>(typeConverter, patterns, target, \ + convertTorchOpsSet) + INSERT_ATEN_ELEMENTWISE_MUL_DIV_PATTERN(ConvertAtenMulOp, AtenMulTensorOp); + INSERT_ATEN_ELEMENTWISE_MUL_DIV_PATTERN(ConvertAtenMulOp, AtenMulScalarOp); + INSERT_ATEN_ELEMENTWISE_MUL_DIV_PATTERN(ConvertAtenDivOp, AtenDivTensorOp); + INSERT_ATEN_ELEMENTWISE_MUL_DIV_PATTERN(ConvertAtenDivOp, AtenDivScalarOp); +#undef INSERT_ATEN_ELEMENTWISE_MUL_DIV_PATTERN + +#define INSERT_ATEN_UNARY_FP_ONLY_PATTERN(AtenOp, TcpOp) \ + torch_to_tcp::addPatternIfOpInConvertTorchOpsSet< \ + ConvertAtenUnaryFpOnlyOp, AtenOp>( \ + typeConverter, patterns, target, convertTorchOpsSet) + INSERT_ATEN_UNARY_FP_ONLY_PATTERN(AtenCeilOp, tcp::CeilOp); + INSERT_ATEN_UNARY_FP_ONLY_PATTERN(AtenFloorOp, tcp::FloorOp); + INSERT_ATEN_UNARY_FP_ONLY_PATTERN(AtenSigmoidOp, tcp::SigmoidOp); + INSERT_ATEN_UNARY_FP_ONLY_PATTERN(AtenTanhOp, tcp::TanhOp); + INSERT_ATEN_UNARY_FP_ONLY_PATTERN(AtenSinOp, tcp::SinOp); + INSERT_ATEN_UNARY_FP_ONLY_PATTERN(AtenCosOp, tcp::CosOp); + INSERT_ATEN_UNARY_FP_ONLY_PATTERN(AtenLogOp, tcp::LogOp); + INSERT_ATEN_UNARY_FP_ONLY_PATTERN(AtenNegOp, tcp::NegOp); + INSERT_ATEN_UNARY_FP_ONLY_PATTERN(AtenAtanOp, tcp::AtanOp); +#undef INSERT_ATEN_UNARY_FP_ONLY_PATTERN + +#define INSERT_ATEN_UNARY_INT_OR_FP_PATTERN(AtenOp, TcpOp) \ + torch_to_tcp::addPatternIfOpInConvertTorchOpsSet< \ + ConvertAtenUnaryIntOrFpOp, AtenOp>( \ + typeConverter, patterns, target, convertTorchOpsSet) + INSERT_ATEN_UNARY_INT_OR_FP_PATTERN(AtenAbsOp, tcp::AbsOp); + INSERT_ATEN_UNARY_INT_OR_FP_PATTERN(AtenSqrtOp, tcp::SqrtOp); +#undef INSERT_ATEN_UNARY_INT_OR_FP_PATTERN } diff --git a/lib/Conversion/TorchToTcp/Misc.cpp b/lib/Conversion/TorchToTcp/Misc.cpp index 9ad168c6..2e3ed199 100644 --- a/lib/Conversion/TorchToTcp/Misc.cpp +++ b/lib/Conversion/TorchToTcp/Misc.cpp @@ -134,7 +134,7 @@ class ConvertValueTensorLiteralOp }; template -class ConvertAtenZerosOnesPatternOp : public OpConversionPattern { +class ConvertAtenZerosOnesOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; @@ -180,7 +180,7 @@ class ConvertAtenZerosOnesPatternOp : public OpConversionPattern { }; template -class ConvertAtenZerosOnesLikePatternOp : public OpConversionPattern { +class ConvertAtenZerosOnesLikeOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; @@ -220,28 +220,27 @@ class ConvertAtenZerosOnesLikePatternOp : public OpConversionPattern { } // namespace -void torch_to_tcp::populateMiscPatternsAndLegality(TypeConverter &typeConverter, - RewritePatternSet &patterns, - ConversionTarget &target) { - MLIRContext *context = patterns.getContext(); - - target.addIllegalOp(); - patterns.add(typeConverter, context); - - target.addIllegalOp(); - patterns.add(typeConverter, context); - - target.addIllegalOp(); - patterns.add>(typeConverter, - context); - target.addIllegalOp(); - patterns.add>(typeConverter, - context); - - target.addIllegalOp(); - patterns.add>( - typeConverter, context); - target.addIllegalOp(); - patterns.add>( - typeConverter, context); +void torch_to_tcp::populateMiscPatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet) { + +#define INSERT_ATEN_MISC_OP_PATTERN(ConvertAtenOpPattern, AtenOp) \ + torch_to_tcp::addPatternIfOpInConvertTorchOpsSet( \ + typeConverter, patterns, target, convertTorchOpsSet) + INSERT_ATEN_MISC_OP_PATTERN(ConvertAtenBroadcastToOp, AtenBroadcastToOp); + INSERT_ATEN_MISC_OP_PATTERN(ConvertValueTensorLiteralOp, + ValueTensorLiteralOp); +#undef INSERT_ATEN_MISC_OP_PATTERN + +#define INSERT_ATEN_ZEROS_ONES_PATTERN(ConvertAtenOpPattern, AtenOp, Val) \ + torch_to_tcp::addPatternIfOpInConvertTorchOpsSet< \ + ConvertAtenOpPattern, AtenOp>(typeConverter, patterns, \ + target, convertTorchOpsSet) + INSERT_ATEN_ZEROS_ONES_PATTERN(ConvertAtenZerosOnesOp, AtenZerosOp, 0); + INSERT_ATEN_ZEROS_ONES_PATTERN(ConvertAtenZerosOnesOp, AtenOnesOp, 1); + INSERT_ATEN_ZEROS_ONES_PATTERN(ConvertAtenZerosOnesLikeOp, AtenZerosLikeOp, + 0); + INSERT_ATEN_ZEROS_ONES_PATTERN(ConvertAtenZerosOnesLikeOp, AtenOnesLikeOp, 1); +#undef INSERT_ATEN_ZEROS_ONES_PATTERN } diff --git a/lib/Conversion/TorchToTcp/PopulatePatterns.h b/lib/Conversion/TorchToTcp/PopulatePatterns.h index 64fb2b37..d62d0a02 100644 --- a/lib/Conversion/TorchToTcp/PopulatePatterns.h +++ b/lib/Conversion/TorchToTcp/PopulatePatterns.h @@ -9,19 +9,22 @@ #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/StringSet.h" + namespace mlir { namespace torch_to_tcp { -void populateElementwisePatternsAndLegality(TypeConverter &typeConverter, - RewritePatternSet &patterns, - ConversionTarget &target); -void populateMiscPatternsAndLegality(TypeConverter &typeConverter, - RewritePatternSet &patterns, - ConversionTarget &target); +void populateElementwisePatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet); + +void populateMiscPatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet); -void populateDataMovementPatternsAndLegality(TypeConverter &typeConverter, - RewritePatternSet &patterns, - ConversionTarget &target); +void populateDataMovementPatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet); } // namespace torch_to_tcp } // namespace mlir diff --git a/lib/Conversion/TorchToTcp/TorchToTcp.cpp b/lib/Conversion/TorchToTcp/TorchToTcp.cpp index 122ee98d..f8bfc8aa 100644 --- a/lib/Conversion/TorchToTcp/TorchToTcp.cpp +++ b/lib/Conversion/TorchToTcp/TorchToTcp.cpp @@ -26,6 +26,9 @@ #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringSet.h" + using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; @@ -40,7 +43,15 @@ namespace tcp { namespace { class ConvertTorchToTcp : public ConvertTorchToTcpBase { +private: + llvm::StringSet<> convertTorchOpsSet; + public: + ConvertTorchToTcp() = default; + ConvertTorchToTcp(ArrayRef convertTorchOps) { + this->convertTorchOps = convertTorchOps; + } + void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); registry.insert(); @@ -49,6 +60,15 @@ class ConvertTorchToTcp : public ConvertTorchToTcpBase { void runOnOperation() override { MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + + // The strings in the `convertTorchOps` ArrayRef don't exist during the call + // to the constructor `ConvertTorchToTcp`, so the creation of the + // `convertTorchOpsSet` must be delayed to when `runOnOperation` gets + // called. + convertTorchOpsSet.clear(); + convertTorchOpsSet.insert(convertTorchOps.begin(), convertTorchOps.end()); + ConversionTarget target(*context); target.addLegalDialect(); @@ -57,14 +77,14 @@ class ConvertTorchToTcp : public ConvertTorchToTcpBase { typeConverter.addConversion([](Type type) { return type; }); TorchConversion::setupBackendTypeConversion(target, typeConverter); - RewritePatternSet patterns(context); + torch_to_tcp::populateElementwisePatternsAndLegality( + typeConverter, patterns, target, convertTorchOpsSet); - torch_to_tcp::populateElementwisePatternsAndLegality(typeConverter, - patterns, target); torch_to_tcp::populateMiscPatternsAndLegality(typeConverter, patterns, - target); - torch_to_tcp::populateDataMovementPatternsAndLegality(typeConverter, - patterns, target); + target, convertTorchOpsSet); + + torch_to_tcp::populateDataMovementPatternsAndLegality( + typeConverter, patterns, target, convertTorchOpsSet); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { @@ -75,8 +95,9 @@ class ConvertTorchToTcp : public ConvertTorchToTcpBase { } // namespace -std::unique_ptr> createConvertTorchToTcpPass() { - return std::make_unique(); +std::unique_ptr> +createConvertTorchToTcpPass(llvm::ArrayRef convertTorchOps) { + return std::make_unique(convertTorchOps); } } // namespace tcp diff --git a/lib/Conversion/TorchToTcp/Utils.h b/lib/Conversion/TorchToTcp/Utils.h index f93f8350..68bff615 100644 --- a/lib/Conversion/TorchToTcp/Utils.h +++ b/lib/Conversion/TorchToTcp/Utils.h @@ -9,6 +9,10 @@ #include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" + +#include "llvm/ADT/StringSet.h" + namespace mlir { namespace torch_to_tcp { @@ -71,6 +75,26 @@ std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, bool getConstTensorWithType(ConversionPatternRewriter &rewriter, Operation *op, Value &constOp, Type resultType, int fillVal); +// Utility function to selectively add a torch->tcp pattern if whitelist op is +// provided +template +inline void addPatternIfOpInConvertTorchOpsSet( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet) { + MLIRContext *context = patterns.getContext(); + std::optional opName = + TorchToTcpPattern(context).getRootKind(); + assert(opName && "All TorchToTcp patterns must target a single op"); + // When no ops are specified, convert all. + // When ops are specified, convert those ops only. + if (convertTorchOpsSet.empty() || + convertTorchOpsSet.contains( + opName->getStringRef().ltrim(torch::Torch::kTorchOpPrefix))) { + target.addIllegalOp(); + patterns.add(typeConverter, context); + } +} + namespace impl { template std::optional diff --git a/lib/Pipeline/Pipeline.cpp b/lib/Pipeline/Pipeline.cpp index b174a72f..62cccb78 100644 --- a/lib/Pipeline/Pipeline.cpp +++ b/lib/Pipeline/Pipeline.cpp @@ -35,7 +35,10 @@ using namespace mlir; static void createTorchBackendToTcpBackendPipeline(OpPassManager &pm) { - pm.addNestedPass(tcp::createConvertTorchToTcpPass()); + ArrayRef emptyArrayRef; + + pm.addNestedPass( + tcp::createConvertTorchToTcpPass(emptyArrayRef)); // Clean up any non-canonical code introduced above. pm.addNestedPass(createCanonicalizerPass());