diff --git a/BUILD b/BUILD index c5593dce..d22ec247 100644 --- a/BUILD +++ b/BUILD @@ -203,6 +203,7 @@ cc_library( "@torch-mlir//:TorchMLIRConversionUtils", "@torch-mlir//:TorchMLIRTorchBackendTypeConversion", "@torch-mlir//:TorchMLIRTorchConversionDialect", + "@torch-mlir//:TorchMLIRTorchPasses", ], ) diff --git a/include/mlir-tcp/Conversion/Passes.td b/include/mlir-tcp/Conversion/Passes.td index 73450d26..9e86e047 100644 --- a/include/mlir-tcp/Conversion/Passes.td +++ b/include/mlir-tcp/Conversion/Passes.td @@ -21,7 +21,12 @@ def ConvertTorchToTcp : Pass<"convert-torch-to-tcp", "func::FuncOp"> { let description = [{ Convert Torch ops to Tcp ops. }]; - let constructor = "mlir::tcp::createConvertTorchToTcpPass()"; + let constructor = "mlir::tcp::createConvertTorchToTcpPass(/*convertTorchOps=*/{})"; + let options = [ + ListOption<"convertTorchOps", "convert-torch-ops", "std::string", + "List of Torch operation names that should be converted to Tcp", + "llvm::cl::ZeroOrMore"> + ]; } //===----------------------------------------------------------------------===// diff --git a/include/mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h b/include/mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h index 856d0f7b..9e1ee0d6 100644 --- a/include/mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h +++ b/include/mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h @@ -19,7 +19,8 @@ namespace mlir { namespace tcp { -std::unique_ptr> createConvertTorchToTcpPass(); +std::unique_ptr> +createConvertTorchToTcpPass(llvm::ArrayRef convertTorchOps); } // namespace tcp } // namespace mlir diff --git a/lib/Conversion/TorchToTcp/DataMovement.cpp b/lib/Conversion/TorchToTcp/DataMovement.cpp index d6868dd8..0fc3403f 100644 --- a/lib/Conversion/TorchToTcp/DataMovement.cpp +++ b/lib/Conversion/TorchToTcp/DataMovement.cpp @@ -18,6 +18,8 @@ #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "llvm/ADT/StringSet.h" + using namespace mlir; using namespace mlir::tcp; using namespace mlir::torch; @@ -60,8 +62,7 @@ class ConvertAtenCatOp : public OpConversionPattern { void torch_to_tcp::populateDataMovementPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target) { - MLIRContext *context = patterns.getContext(); - target.addIllegalOp(); - patterns.add(typeConverter, context); + ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet) { + torch_to_tcp::addPatternIfOpInConvertTorchOpsSet( + typeConverter, patterns, target, convertTorchOpsSet); } diff --git a/lib/Conversion/TorchToTcp/Elementwise.cpp b/lib/Conversion/TorchToTcp/Elementwise.cpp index 91391630..f1d98c1d 100644 --- a/lib/Conversion/TorchToTcp/Elementwise.cpp +++ b/lib/Conversion/TorchToTcp/Elementwise.cpp @@ -650,78 +650,58 @@ class ConvertAtenToDtypeOp : public OpConversionPattern { void torch_to_tcp::populateElementwisePatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target) { - MLIRContext *context = patterns.getContext(); - - target.addIllegalOp(); - patterns.add(typeConverter, context); - - target.addIllegalOp(); - patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); - - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - - target.addIllegalOp(); - target.addIllegalOp(); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - - target.addIllegalOp(); - target.addIllegalOp(); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addIllegalOp(); - patterns.add>( - typeConverter, context); - patterns.add>(typeConverter, - context); - patterns.add>( - typeConverter, context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - - target.addIllegalOp(); - target.addIllegalOp(); - patterns.add>(typeConverter, - context); - patterns.add>( - typeConverter, context); - - target.addIllegalOp(); - patterns.add(typeConverter, context); - - target.addIllegalOp(); - patterns.add(typeConverter, context); + ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet) { + +#define INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenOp) \ + torch_to_tcp::addPatternIfOpInConvertTorchOpsSet( \ + typeConverter, patterns, target, convertTorchOpsSet) + INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenToDtypeOp); + INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenClampOp); + INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenReluOp); + INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenBatchNormOp); + INSERT_ATEN_ELEMENTWISE_OP_PATTERN(AtenAtan2Op); +#undef INSERT_ATEN_ELEMENTWISE_OP_PATTERN + +#define INSERT_ATEN_ELEMENTWISE_ADD_SUB_PATTERN(AtenOp, TcpOp) \ + torch_to_tcp::addPatternIfOpInConvertTorchOpsSet< \ + ConvertAtenAddSubOp, AtenOp>(typeConverter, patterns, \ + target, convertTorchOpsSet) + INSERT_ATEN_ELEMENTWISE_ADD_SUB_PATTERN(AtenAddTensorOp, tcp::AddOp); + INSERT_ATEN_ELEMENTWISE_ADD_SUB_PATTERN(AtenSubTensorOp, tcp::SubOp); + INSERT_ATEN_ELEMENTWISE_ADD_SUB_PATTERN(AtenAddScalarOp, tcp::AddOp); + INSERT_ATEN_ELEMENTWISE_ADD_SUB_PATTERN(AtenSubScalarOp, tcp::SubOp); +#undef INSERT_ATEN_ELEMENTWISE_ADD_SUB_PATTERN + +#define INSERT_ATEN_ELEMENTWISE_MUL_DIV_PATTERN(ConvertAtenOpPattern, AtenOp) \ + torch_to_tcp::addPatternIfOpInConvertTorchOpsSet< \ + ConvertAtenOpPattern, AtenOp>(typeConverter, patterns, target, \ + convertTorchOpsSet) + INSERT_ATEN_ELEMENTWISE_MUL_DIV_PATTERN(ConvertAtenMulOp, AtenMulTensorOp); + INSERT_ATEN_ELEMENTWISE_MUL_DIV_PATTERN(ConvertAtenMulOp, AtenMulScalarOp); + INSERT_ATEN_ELEMENTWISE_MUL_DIV_PATTERN(ConvertAtenDivOp, AtenDivTensorOp); + INSERT_ATEN_ELEMENTWISE_MUL_DIV_PATTERN(ConvertAtenDivOp, AtenDivScalarOp); +#undef INSERT_ATEN_ELEMENTWISE_MUL_DIV_PATTERN + +#define INSERT_ATEN_UNARY_FP_ONLY_PATTERN(AtenOp, TcpOp) \ + torch_to_tcp::addPatternIfOpInConvertTorchOpsSet< \ + ConvertAtenUnaryFpOnlyOp, AtenOp>( \ + typeConverter, patterns, target, convertTorchOpsSet) + INSERT_ATEN_UNARY_FP_ONLY_PATTERN(AtenCeilOp, tcp::CeilOp); + INSERT_ATEN_UNARY_FP_ONLY_PATTERN(AtenFloorOp, tcp::FloorOp); + INSERT_ATEN_UNARY_FP_ONLY_PATTERN(AtenSigmoidOp, tcp::SigmoidOp); + INSERT_ATEN_UNARY_FP_ONLY_PATTERN(AtenTanhOp, tcp::TanhOp); + INSERT_ATEN_UNARY_FP_ONLY_PATTERN(AtenSinOp, tcp::SinOp); + INSERT_ATEN_UNARY_FP_ONLY_PATTERN(AtenCosOp, tcp::CosOp); + INSERT_ATEN_UNARY_FP_ONLY_PATTERN(AtenLogOp, tcp::LogOp); + INSERT_ATEN_UNARY_FP_ONLY_PATTERN(AtenNegOp, tcp::NegOp); + INSERT_ATEN_UNARY_FP_ONLY_PATTERN(AtenAtanOp, tcp::AtanOp); +#undef INSERT_ATEN_UNARY_FP_ONLY_PATTERN + +#define INSERT_ATEN_UNARY_INT_OR_FP_PATTERN(AtenOp, TcpOp) \ + torch_to_tcp::addPatternIfOpInConvertTorchOpsSet< \ + ConvertAtenUnaryIntOrFpOp, AtenOp>( \ + typeConverter, patterns, target, convertTorchOpsSet) + INSERT_ATEN_UNARY_INT_OR_FP_PATTERN(AtenAbsOp, tcp::AbsOp); + INSERT_ATEN_UNARY_INT_OR_FP_PATTERN(AtenSqrtOp, tcp::SqrtOp); +#undef INSERT_ATEN_UNARY_INT_OR_FP_PATTERN } diff --git a/lib/Conversion/TorchToTcp/Misc.cpp b/lib/Conversion/TorchToTcp/Misc.cpp index 9ad168c6..5e5c5f64 100644 --- a/lib/Conversion/TorchToTcp/Misc.cpp +++ b/lib/Conversion/TorchToTcp/Misc.cpp @@ -134,7 +134,7 @@ class ConvertValueTensorLiteralOp }; template -class ConvertAtenZerosOnesPatternOp : public OpConversionPattern { +class ConvertAtenZerosOnesOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; @@ -180,7 +180,7 @@ class ConvertAtenZerosOnesPatternOp : public OpConversionPattern { }; template -class ConvertAtenZerosOnesLikePatternOp : public OpConversionPattern { +class ConvertAtenZerosOnesLikeOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; @@ -220,28 +220,25 @@ class ConvertAtenZerosOnesLikePatternOp : public OpConversionPattern { } // namespace -void torch_to_tcp::populateMiscPatternsAndLegality(TypeConverter &typeConverter, - RewritePatternSet &patterns, - ConversionTarget &target) { - MLIRContext *context = patterns.getContext(); - - target.addIllegalOp(); - patterns.add(typeConverter, context); - - target.addIllegalOp(); - patterns.add(typeConverter, context); - - target.addIllegalOp(); - patterns.add>(typeConverter, - context); - target.addIllegalOp(); - patterns.add>(typeConverter, - context); - - target.addIllegalOp(); - patterns.add>( - typeConverter, context); - target.addIllegalOp(); - patterns.add>( - typeConverter, context); +void torch_to_tcp::populateMiscPatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet) { + +#define INSERT_ATEN_MISC_OP_PATTERN(AtenOp) \ + torch_to_tcp::addPatternIfOpInConvertTorchOpsSet( \ + typeConverter, patterns, target, convertTorchOpsSet) + INSERT_ATEN_MISC_OP_PATTERN(AtenBroadcastToOp); + INSERT_ATEN_MISC_OP_PATTERN(ValueTensorLiteralOp); +#undef INSERT_ATEN_MISC_OP_PATTERN + +#define INSERT_ATEN_ZEROS_ONES_PATTERN(ConvertAtenOpPattern, AtenOp, Val) \ + torch_to_tcp::addPatternIfOpInConvertTorchOpsSet< \ + ConvertAtenOpPattern, AtenOp>(typeConverter, patterns, \ + target, convertTorchOpsSet) + INSERT_ATEN_ZEROS_ONES_PATTERN(ConvertAtenZerosOnesOp, AtenZerosOp, 0); + INSERT_ATEN_ZEROS_ONES_PATTERN(ConvertAtenZerosOnesOp, AtenOnesOp, 1); + INSERT_ATEN_ZEROS_ONES_PATTERN(ConvertAtenZerosOnesLikeOp, AtenZerosLikeOp, + 0); + INSERT_ATEN_ZEROS_ONES_PATTERN(ConvertAtenZerosOnesLikeOp, AtenOnesLikeOp, 1); +#undef INSERT_ATEN_ZEROS_ONES_PATTERN } diff --git a/lib/Conversion/TorchToTcp/PopulatePatterns.h b/lib/Conversion/TorchToTcp/PopulatePatterns.h index 64fb2b37..d62d0a02 100644 --- a/lib/Conversion/TorchToTcp/PopulatePatterns.h +++ b/lib/Conversion/TorchToTcp/PopulatePatterns.h @@ -9,19 +9,22 @@ #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/StringSet.h" + namespace mlir { namespace torch_to_tcp { -void populateElementwisePatternsAndLegality(TypeConverter &typeConverter, - RewritePatternSet &patterns, - ConversionTarget &target); -void populateMiscPatternsAndLegality(TypeConverter &typeConverter, - RewritePatternSet &patterns, - ConversionTarget &target); +void populateElementwisePatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet); + +void populateMiscPatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet); -void populateDataMovementPatternsAndLegality(TypeConverter &typeConverter, - RewritePatternSet &patterns, - ConversionTarget &target); +void populateDataMovementPatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet); } // namespace torch_to_tcp } // namespace mlir diff --git a/lib/Conversion/TorchToTcp/TorchToTcp.cpp b/lib/Conversion/TorchToTcp/TorchToTcp.cpp index 122ee98d..988f40ac 100644 --- a/lib/Conversion/TorchToTcp/TorchToTcp.cpp +++ b/lib/Conversion/TorchToTcp/TorchToTcp.cpp @@ -26,6 +26,9 @@ #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringSet.h" + using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; @@ -40,7 +43,15 @@ namespace tcp { namespace { class ConvertTorchToTcp : public ConvertTorchToTcpBase { +private: + llvm::StringSet<> convertTorchOpsSet; + public: + ConvertTorchToTcp() = default; + ConvertTorchToTcp(ArrayRef convertTorchOps) { + this->convertTorchOps = convertTorchOps; + } + void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); registry.insert(); @@ -49,6 +60,14 @@ class ConvertTorchToTcp : public ConvertTorchToTcpBase { void runOnOperation() override { MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + + // Usually the default constructor is called which means `convertTorchOps` + // is usually unset. Doing this here allows the initialization of + // `convertTorchOpsSet` to be be delayed to when `runOnOperation` is called. + convertTorchOpsSet.clear(); + convertTorchOpsSet.insert(convertTorchOps.begin(), convertTorchOps.end()); + ConversionTarget target(*context); target.addLegalDialect(); @@ -57,14 +76,14 @@ class ConvertTorchToTcp : public ConvertTorchToTcpBase { typeConverter.addConversion([](Type type) { return type; }); TorchConversion::setupBackendTypeConversion(target, typeConverter); - RewritePatternSet patterns(context); + torch_to_tcp::populateElementwisePatternsAndLegality( + typeConverter, patterns, target, convertTorchOpsSet); - torch_to_tcp::populateElementwisePatternsAndLegality(typeConverter, - patterns, target); torch_to_tcp::populateMiscPatternsAndLegality(typeConverter, patterns, - target); - torch_to_tcp::populateDataMovementPatternsAndLegality(typeConverter, - patterns, target); + target, convertTorchOpsSet); + + torch_to_tcp::populateDataMovementPatternsAndLegality( + typeConverter, patterns, target, convertTorchOpsSet); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { @@ -75,8 +94,9 @@ class ConvertTorchToTcp : public ConvertTorchToTcpBase { } // namespace -std::unique_ptr> createConvertTorchToTcpPass() { - return std::make_unique(); +std::unique_ptr> +createConvertTorchToTcpPass(llvm::ArrayRef convertTorchOps) { + return std::make_unique(convertTorchOps); } } // namespace tcp diff --git a/lib/Conversion/TorchToTcp/Utils.h b/lib/Conversion/TorchToTcp/Utils.h index f93f8350..68bff615 100644 --- a/lib/Conversion/TorchToTcp/Utils.h +++ b/lib/Conversion/TorchToTcp/Utils.h @@ -9,6 +9,10 @@ #include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" + +#include "llvm/ADT/StringSet.h" + namespace mlir { namespace torch_to_tcp { @@ -71,6 +75,26 @@ std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, bool getConstTensorWithType(ConversionPatternRewriter &rewriter, Operation *op, Value &constOp, Type resultType, int fillVal); +// Utility function to selectively add a torch->tcp pattern if whitelist op is +// provided +template +inline void addPatternIfOpInConvertTorchOpsSet( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const llvm::StringSet<> &convertTorchOpsSet) { + MLIRContext *context = patterns.getContext(); + std::optional opName = + TorchToTcpPattern(context).getRootKind(); + assert(opName && "All TorchToTcp patterns must target a single op"); + // When no ops are specified, convert all. + // When ops are specified, convert those ops only. + if (convertTorchOpsSet.empty() || + convertTorchOpsSet.contains( + opName->getStringRef().ltrim(torch::Torch::kTorchOpPrefix))) { + target.addIllegalOp(); + patterns.add(typeConverter, context); + } +} + namespace impl { template std::optional diff --git a/lib/Pipeline/Pipeline.cpp b/lib/Pipeline/Pipeline.cpp index b174a72f..8f7fb45b 100644 --- a/lib/Pipeline/Pipeline.cpp +++ b/lib/Pipeline/Pipeline.cpp @@ -35,7 +35,7 @@ using namespace mlir; static void createTorchBackendToTcpBackendPipeline(OpPassManager &pm) { - pm.addNestedPass(tcp::createConvertTorchToTcpPass()); + pm.addNestedPass(tcp::createConvertTorchToTcpPass({})); // Clean up any non-canonical code introduced above. pm.addNestedPass(createCanonicalizerPass()); diff --git a/test/Conversion/TorchToTcp/partial_conversion.mlir b/test/Conversion/TorchToTcp/partial_conversion.mlir new file mode 100644 index 00000000..f282b476 --- /dev/null +++ b/test/Conversion/TorchToTcp/partial_conversion.mlir @@ -0,0 +1,40 @@ +// RUN: tcp-opt %s -convert-torch-to-tcp="convert-torch-ops=aten.atan" -verify-diagnostics | FileCheck %s -check-prefix=CHECK1 +// RUN: tcp-opt %s -convert-torch-to-tcp="convert-torch-ops=aten.log" -verify-diagnostics | FileCheck %s -check-prefix=CHECK2 +// RUN: tcp-opt %s -convert-torch-to-tcp="convert-torch-ops=aten.sigmoid" -verify-diagnostics | FileCheck %s -check-prefix=CHECK3 +// RUN: tcp-opt %s -convert-torch-to-tcp -verify-diagnostics | FileCheck %s -check-prefix=CHECK4 + +// CHECK1-LABEL: func.func @torch.aten.atan.log( +// CHECK1-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK1: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK1: %[[T1:.*]] = tcp.atan %[[T0]] : tensor -> tensor +// CHECK1: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK1: %[[T3:.*]] = torch.aten.log %[[T2]] : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> +// CHECK1: return %[[T3]] : !torch.vtensor<[?,?],f32> + +// CHECK2-LABEL: func.func @torch.aten.atan.log( +// CHECK2-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK2: %[[T0:.*]] = torch.aten.atan %[[ARG0]] : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> +// CHECK2: %[[T1:.*]] = torch_c.to_builtin_tensor %[[T0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK2: %[[T2:.*]] = tcp.log %[[T1]] : tensor -> tensor +// CHECK2: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK2: return %[[T3]] : !torch.vtensor<[?,?],f32> + +// CHECK3-LABEL: func.func @torch.aten.atan.log( +// CHECK3-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK3: %[[T0:.*]] = torch.aten.atan %[[ARG0]] : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> +// CHECK3: %[[T1:.*]] = torch.aten.log %[[T0]] : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> +// CHECK3: return %[[T1]] : !torch.vtensor<[?,?],f32> + +// CHECK4-LABEL: func.func @torch.aten.atan.log( +// CHECK4-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK4: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK4: %[[T1:.*]] = tcp.atan %[[T0]] : tensor -> tensor +// CHECK4: %[[T2:.*]] = tcp.log %[[T1]] : tensor -> tensor +// CHECK4: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK4: return %[[T3]] : !torch.vtensor<[?,?],f32> + +func.func @torch.aten.atan.log(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.atan %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> + %1 = torch.aten.log %0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> + return %1 : !torch.vtensor<[?,?],f32> +}