From a4aa8e878cd54767cb12a06b47619f1474ba51c8 Mon Sep 17 00:00:00 2001
From: Henry Tu
Date: Thu, 14 Apr 2022 12:40:10 -0400
Subject: [PATCH] Added JIT to MLIR lowering (#724)

* Added JIT to MLIR lowering

Lowering to JIT IR is performed much as it is in the TS LTC backend:
once a jit::Graph has been constructed, it is converted to a
jit::Function, which is fed into torch-mlir's existing import utility to
generate an MlirModule. (A brief usage sketch follows the diff.)

* Renamed `csrc/backend` to `csrc/base_lazy_backend`
---
 .gitignore                                    |  10 +-
 build_tools/autogen_ltc_backend.py            |   2 +-
 python/torch_mlir/csrc/CMakeLists.txt         |  21 +-
 .../csrc/backend/mlir_lowering_context.cpp    | 107 -----
 .../csrc/backend/mlir_lowering_context.h      |  72 ---
 .../LazyShapeInference.cpp                    |   0
 .../LazyShapeInference.h                      |   0
 .../backend_impl.cpp                          |   3 +-
 .../backend_impl.h                            |   3 +-
 .../mlir_lowering_context.cpp                 | 255 ++++++++++
 .../base_lazy_backend/mlir_lowering_context.h | 136 ++++++
 .../mlir_native_functions.cpp                 |   0
 .../mlir_node.cpp                             |   4 +-
 .../mlir_node.h                               |   3 -
 .../base_lazy_backend/mlir_node_lowering.cpp  | 452 ++++++++++++++++++
 .../base_lazy_backend/mlir_node_lowering.h    |  30 ++
 python/torch_mlir/csrc/utils/debug.h          |   4 +
 .../torch/importer/jit_ir/csrc/CMakeLists.txt |   2 +-
 .../importer/jit_ir/csrc/function_importer.h  |   2 +-
 19 files changed, 904 insertions(+), 202 deletions(-)
 delete mode 100644 python/torch_mlir/csrc/backend/mlir_lowering_context.cpp
 delete mode 100644 python/torch_mlir/csrc/backend/mlir_lowering_context.h
 rename python/torch_mlir/csrc/{backend => base_lazy_backend}/LazyShapeInference.cpp (100%)
 rename python/torch_mlir/csrc/{backend => base_lazy_backend}/LazyShapeInference.h (100%)
 rename python/torch_mlir/csrc/{backend => base_lazy_backend}/backend_impl.cpp (98%)
 rename python/torch_mlir/csrc/{backend => base_lazy_backend}/backend_impl.h (98%)
 create mode 100644 python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp
 create mode 100644 python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h
 rename python/torch_mlir/csrc/{backend => base_lazy_backend}/mlir_native_functions.cpp (100%)
 rename python/torch_mlir/csrc/{backend => base_lazy_backend}/mlir_node.cpp (88%)
 rename python/torch_mlir/csrc/{backend => base_lazy_backend}/mlir_node.h (90%)
 create mode 100644 python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp
 create mode 100644 python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.h

diff --git a/.gitignore b/.gitignore
index f0fedfcc87a..200d6ee0382 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,8 +22,8 @@ __pycache__
 # Autogenerated files
 /generated_native_functions.yaml
 /generated_backend.hash
-/python/torch_mlir/csrc/backend/LazyIr.h
-/python/torch_mlir/csrc/backend/LazyNativeFunctions.cpp
-/python/torch_mlir/csrc/backend/LazyNativeFunctions.h
-/python/torch_mlir/csrc/backend/GenLazyShapeInference.cpp
-/python/torch_mlir/csrc/backend/RegisterLazy.cpp
+/python/torch_mlir/csrc/base_lazy_backend/LazyIr.h
+/python/torch_mlir/csrc/base_lazy_backend/LazyNativeFunctions.cpp
+/python/torch_mlir/csrc/base_lazy_backend/LazyNativeFunctions.h
+/python/torch_mlir/csrc/base_lazy_backend/GenLazyShapeInference.cpp
+/python/torch_mlir/csrc/base_lazy_backend/RegisterLazy.cpp
diff --git a/build_tools/autogen_ltc_backend.py b/build_tools/autogen_ltc_backend.py
index 815e1e7080d..5b59d32e08a 100644
--- a/build_tools/autogen_ltc_backend.py
+++ b/build_tools/autogen_ltc_backend.py
@@ -300,7 +300,7 @@ def main(args):
         "generated_native_functions.yaml"
     )
     backend_path = TORCH_MLIR_DIR.joinpath(
-        "python", "torch_mlir", "csrc", "backend"
"torch_mlir", "csrc", "backend" + "python", "torch_mlir", "csrc", "base_lazy_backend" ) assert backend_path.is_dir() diff --git a/python/torch_mlir/csrc/CMakeLists.txt b/python/torch_mlir/csrc/CMakeLists.txt index 538362d63f4..f42248cbd96 100644 --- a/python/torch_mlir/csrc/CMakeLists.txt +++ b/python/torch_mlir/csrc/CMakeLists.txt @@ -20,18 +20,23 @@ link_directories("${TORCH_INSTALL_PREFIX}/lib") add_library(torch_mlir_ltc_backend SHARED - backend/backend_impl.cpp - backend/LazyNativeFunctions.cpp - backend/LazyShapeInference.cpp - backend/GenLazyShapeInference.cpp - backend/mlir_lowering_context.cpp - backend/mlir_native_functions.cpp - backend/mlir_node.cpp - backend/RegisterLazy.cpp + base_lazy_backend/backend_impl.cpp + base_lazy_backend/LazyNativeFunctions.cpp + base_lazy_backend/LazyShapeInference.cpp + base_lazy_backend/GenLazyShapeInference.cpp + base_lazy_backend/mlir_lowering_context.cpp + base_lazy_backend/mlir_native_functions.cpp + base_lazy_backend/mlir_node.cpp + base_lazy_backend/mlir_node_lowering.cpp + base_lazy_backend/RegisterLazy.cpp ) +add_dependencies(torch_mlir_ltc_backend + TorchMLIRJITIRImporter +) target_link_libraries(torch_mlir_ltc_backend TorchMLIRAggregateCAPI + TorchMLIRJITIRImporter ${TORCH_LIBRARIES} ${Python3_LIBRARIES} torch_python diff --git a/python/torch_mlir/csrc/backend/mlir_lowering_context.cpp b/python/torch_mlir/csrc/backend/mlir_lowering_context.cpp deleted file mode 100644 index d366080e256..00000000000 --- a/python/torch_mlir/csrc/backend/mlir_lowering_context.cpp +++ /dev/null @@ -1,107 +0,0 @@ -//===- mlir_lowering_context.cpp ------------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// -// This file is adapted from pytorch/pytorch -// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp -//===----------------------------------------------------------------------===// - -#include - -#include "../utils/debug.h" -#include "../utils/exception.h" -#include "mlir_lowering_context.h" - -namespace torch { -namespace lazy { - -TorchMlirLoweringContext::TorchMlirLoweringContext( - const std::string& name, BackendDevice device) - : LoweringContext(name, std::forward(device)) {} - -TorchMlirLoweringContext::TorchMlirLoweringContext( - const std::string& name, BackendDevice device, - c10::ArrayRef post_order, Util::EmissionMap emit_status) - : LoweringContext( - name, std::forward(device), - std::forward>(post_order), - std::forward(emit_status)) {} - -int TorchMlirComputation::parameters_size() const { UNIMPLEMENTED_FUNCTION_ERROR(); } - -const std::vector& -TorchMlirComputation::parameter_shapes() const { - UNIMPLEMENTED_FUNCTION_ERROR(); -} - -const std::vector& TorchMlirComputation::parameter_names() const { - UNIMPLEMENTED_FUNCTION_ERROR(); -} - -const torch::lazy::Shape& TorchMlirComputation::result_shape() const { - UNIMPLEMENTED_FUNCTION_ERROR(); -} - -std::string TorchMlirComputation::to_string() const { - UNIMPLEMENTED_FUNCTION_ERROR(); -} - -// Get the shape of the result tuple component, given by index. 
-torch::lazy::Shape TorchMlirLoweringContext::GetResultShape(size_t index) const {
-  UNIMPLEMENTED_FUNCTION_ERROR();
-}
-
-// Adds the given output as a component of the result tuple and returns its
-// assigned position within the tuple.
-size_t TorchMlirLoweringContext::AddResult(const torch::lazy::Output& output) {
-  PRINT_FUNCTION();
-  const torch::lazy::Node* node;
-  auto it = emitted_outputs_.find(output);
-  if (it == emitted_outputs_.end()) {
-    node = output.node;
-
-    auto post_order = Util::ComputePostOrder(node, &emit_status_);
-    for (auto po_node : post_order) {
-      // TODO: uncomment after lowering is implemented
-      // bool ok = lowering_->Lower(node);
-      // TORCH_CHECK(ok, "Failed to lower: ", node->ToString());
-    }
-    emitted_outputs_[output] = node;
-  } else {
-    node = it->second;
-  }
-  result_tuple_.emplace_back(node);
-  return result_tuple_.size() - 1;
-}
-
-// Associates the given output with the input parameter of the given index and
-// shape. Only used for the operator-by-operator execution, mostly for
-// debugging purposes.
-void TorchMlirLoweringContext::AddParameter(
-    const torch::lazy::Output& output, size_t index,
-    const torch::lazy::Shape& shape, const std::string& name) {
-  UNIMPLEMENTED_FUNCTION_ERROR();
-}
-
-// Build the computation capturing all the operations created with the
-// embedded builder (returned by the builder() API).
-ComputationPtr TorchMlirLoweringContext::Build() {
-  PRINT_FUNCTION()
-  for (const torch::lazy::Node* output : result_tuple_) {
-  }
-  return std::make_shared<TorchMlirComputation>();
-}
-
-// Retrieves the lowered operation for an output. If the requested output is
-// not available yet, the graph behind the output's Node is lowered, and the
-// corresponding MLIR operation returned.
-torch::jit::Value* GetOutputOp(const Output& output) {
-  UNIMPLEMENTED_FUNCTION_ERROR();
-}
-
-} // namespace lazy
-} // namespace torch
diff --git a/python/torch_mlir/csrc/backend/mlir_lowering_context.h b/python/torch_mlir/csrc/backend/mlir_lowering_context.h
deleted file mode 100644
index c8e526b131f..00000000000
--- a/python/torch_mlir/csrc/backend/mlir_lowering_context.h
+++ /dev/null
@@ -1,72 +0,0 @@
-//===- mlir_lowering_context.h --------------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-// Also available under a BSD-style license. See LICENSE.
-//
-//===----------------------------------------------------------------------===//
-// This file is adapted from pytorch/pytorch
-// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_lowering_context.h
-//===----------------------------------------------------------------------===//
-
-#pragma once
-
-#include
-
-#include
-
-namespace torch {
-namespace lazy {
-
-class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
-public:
-  int parameters_size() const override;
-
-  virtual const std::vector<torch::lazy::Shape>&
-  parameter_shapes() const override;
-
-  virtual const std::vector<std::string>& parameter_names() const override;
-
-  virtual const torch::lazy::Shape& result_shape() const override;
-};
-
-class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
-public:
-  TorchMlirLoweringContext(
-      const std::string& name, torch::lazy::BackendDevice device);
-  TorchMlirLoweringContext(
-      const std::string& name, torch::lazy::BackendDevice device,
-      c10::ArrayRef<const torch::lazy::Node*> post_order,
-      torch::lazy::Util::EmissionMap emit_status);
-
-  // Get the shape of the result tuple component, given by index.
-  virtual torch::lazy::Shape GetResultShape(size_t index) const override;
-
-  // Adds the given output as a component of the result tuple and returns its
-  // assigned position within the tuple.
-  virtual size_t AddResult(const torch::lazy::Output& output) override;
-
-  // Associates the given output with the input parameter of the given index and
-  // shape. Only used for the operator-by-operator execution, mostly for
-  // debugging purposes.
-  virtual void AddParameter(
-      const torch::lazy::Output& output, size_t index,
-      const torch::lazy::Shape& shape, const std::string& name) override;
-
-  // Build the computation capturing all the operations created with the
-  // embedded builder (returned by the builder() API).
-  virtual torch::lazy::ComputationPtr Build() override;
-
-  // Retrieves the lowered operation for an output. If the requested output is
-  // not available yet, the graph behind the output's Node is lowered, and the
-  // corresponding MLIR operation returned.
-  torch::jit::Value* GetOutputOp(const Output& output);
-
-private:
-  std::vector<const torch::lazy::Node*> result_tuple_;
-  torch::lazy::OutputMap<const torch::lazy::Node*> emitted_outputs_;
-};
-
-} // namespace lazy
-} // namespace torch
diff --git a/python/torch_mlir/csrc/backend/LazyShapeInference.cpp b/python/torch_mlir/csrc/base_lazy_backend/LazyShapeInference.cpp
similarity index 100%
rename from python/torch_mlir/csrc/backend/LazyShapeInference.cpp
rename to python/torch_mlir/csrc/base_lazy_backend/LazyShapeInference.cpp
diff --git a/python/torch_mlir/csrc/backend/LazyShapeInference.h b/python/torch_mlir/csrc/base_lazy_backend/LazyShapeInference.h
similarity index 100%
rename from python/torch_mlir/csrc/backend/LazyShapeInference.h
rename to python/torch_mlir/csrc/base_lazy_backend/LazyShapeInference.h
diff --git a/python/torch_mlir/csrc/backend/backend_impl.cpp b/python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp
similarity index 98%
rename from python/torch_mlir/csrc/backend/backend_impl.cpp
rename to python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp
index dfd5f4c2401..d097c827f46 100644
--- a/python/torch_mlir/csrc/backend/backend_impl.cpp
+++ b/python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp
@@ -95,7 +95,8 @@ at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(
   TorchMlirBackendData::Info* info =
       dynamic_cast<TorchMlirBackendData::Info*>(data->info());
   TORCH_CHECK(
-      info, "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info.");
+      info,
+      "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info.");

   return info->tensor;
 }
diff --git a/python/torch_mlir/csrc/backend/backend_impl.h b/python/torch_mlir/csrc/base_lazy_backend/backend_impl.h
similarity index 98%
rename from python/torch_mlir/csrc/backend/backend_impl.h
rename to python/torch_mlir/csrc/base_lazy_backend/backend_impl.h
index 33e509f6fb4..685c6a7996d 100644
--- a/python/torch_mlir/csrc/backend/backend_impl.h
+++ b/python/torch_mlir/csrc/base_lazy_backend/backend_impl.h
@@ -41,7 +41,8 @@ class TORCH_API TorchMlirBackendData : public BackendData {
   TorchMlirBackendData(BackendDevice device, Shape shape);
   TorchMlirBackendData(const at::Scalar& scalar, BackendDevice device);
-  TorchMlirBackendData(const at::Tensor& tensor, BackendDevice device, Shape shape);
+  TorchMlirBackendData(
+      const at::Tensor& tensor, BackendDevice device, Shape shape);

   virtual BackendData::Handle GetHandle() override;

diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp
new file mode 100644
index 00000000000..3ec87e678a1
--- /dev/null
+++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.cpp
@@ -0,0 +1,255 @@
+//===- mlir_lowering_context.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+// This file is adapted from pytorch/pytorch
+// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp
+//===----------------------------------------------------------------------===//
+
+#include
+
+#include
+#include
+
+#include "../../dialects/torch/importer/jit_ir/csrc/function_importer.h"
+#include "../utils/debug.h"
+#include "../utils/exception.h"
+#include "backend_impl.h"
+#include "mlir-c/Registration.h"
+#include "mlir_lowering_context.h"
+#include "mlir_node.h"
+#include "torch-mlir-c/Registration.h"
+
+namespace torch {
+namespace lazy {
+
+///////////////////////////////////////////////////////////////////////////////
+// TorchMlir Computation
+///////////////////////////////////////////////////////////////////////////////
+
+TorchMlirComputation::TorchMlirComputation(
+    MlirOperation func_op, MlirContext mlir_context,
+    const std::shared_ptr<torch::jit::Graph>& graph)
+    : func_op_(std::move(func_op)), mlir_context_(std::move(mlir_context)),
+      graph_(graph), num_results_(graph_->outputs().size()) {
+
+  // TODO(henrytu): Save parameter shape information.
+
+  for (torch::jit::Value* input : graph_->inputs()) {
+    parameter_names_.push_back(input->debugName());
+  }
+}
+
+int TorchMlirComputation::parameters_size() const {
+  return parameter_names_.size();
+}
+
+const std::vector<torch::lazy::Shape>&
+TorchMlirComputation::parameter_shapes() const {
+  throw std::runtime_error(
+      "todo(whc) implement ts computation shapes or change interface");
+  return parameter_shapes_;
+}
+
+const std::vector<std::string>& TorchMlirComputation::parameter_names() const {
+  return parameter_names_;
+}
+
+const torch::lazy::Shape& TorchMlirComputation::result_shape() const {
+  throw std::runtime_error(
+      "todo(whc) implement ts computation shapes or change interface");
+  return result_shape_;
+}
+
+unsigned TorchMlirComputation::num_results() const { return num_results_; }
+
+MlirOperation TorchMlirComputation::func_op() const { return func_op_; }
+
+std::string TorchMlirComputation::to_string() const {
+  // Since we use the C-MLIR API, we need to use a callback to print.
+  MlirStringCallback print_callback = [](MlirStringRef part, void* user_data) {
+    // user_data is a void ptr to some data structure of our choice -- in this
+    // case, the string stream where we'll be accumulating the strings.
+    std::stringstream* ss_ptr = static_cast<std::stringstream*>(user_data);
+    *ss_ptr << std::string(part.data, part.length);
+  };
+
+  std::stringstream ss;
+  ss << "JIT Graph: \n"
+     << graph_->toString() << "\n\n"
+     << "MLIR: \n";
+  mlirOperationPrint(func_op_, print_callback, &ss);
+
+  return ss.str();
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// TorchMlir Lowering Context
+///////////////////////////////////////////////////////////////////////////////
+
+TorchMlirLoweringContext::TorchMlirLoweringContext(
+    const std::string& name, BackendDevice device)
+    : LoweringContext(name, std::forward<BackendDevice>(device)),
+      graph_(std::make_shared<torch::jit::Graph>()),
+      mlir_context_(mlirContextCreate()) {
+  lowering_ = TorchMlirNodeLoweringInterface::Create(this);
+  RegisterMlirDialects();
+}
+
+TorchMlirLoweringContext::TorchMlirLoweringContext(
+    const std::string& name, BackendDevice device,
+    c10::ArrayRef<const Node*> post_order, Util::EmissionMap emit_status)
+    : LoweringContext(
+          name, std::forward<BackendDevice>(device),
+          std::forward<c10::ArrayRef<const Node*>>(post_order),
+          std::forward<Util::EmissionMap>(emit_status)),
+      graph_(std::make_shared<torch::jit::Graph>()),
+      mlir_context_(mlirContextCreate()) {
+  lowering_ = TorchMlirNodeLoweringInterface::Create(this);
+  for (auto node : post_order) {
+    bool ok = lowering_->Lower(node);
+    CHECK(ok) << "Failed to lower: " << *node;
+  }
+
+  RegisterMlirDialects();
+}
+
+// Get the shape of the result tuple component, given by index.
+torch::lazy::Shape
+TorchMlirLoweringContext::GetResultShape(size_t index) const {
+  UNIMPLEMENTED_FUNCTION_ERROR();
+}
+
+size_t TorchMlirLoweringContext::AddResult(const Output& output) {
+  PRINT_FUNCTION();
+
+  return AddResult(GetOutputOp(output));
+}
+
+// Associates the given output with the input parameter of the given index and
+// shape. Only used for the operator-by-operator execution, mostly for
+// debugging purposes.
+void TorchMlirLoweringContext::AddParameter(
+    const torch::lazy::Output& output, size_t index,
+    const torch::lazy::Shape& shape, const std::string& name) {
+  UNIMPLEMENTED_FUNCTION_ERROR();
+}
+
+// Build the computation capturing all the operations created with the
+// embedded builder (returned by the builder() API).
+ComputationPtr TorchMlirLoweringContext::Build() {
+  PRINT_FUNCTION();
+
+  for (torch::jit::Value* output : root_tuple_) {
+    graph_->block()->registerOutput(output);
+  }
+
+  // Create jit::Function from jit::Graph.
+  c10::QualifiedName name("graph");
+  auto cu = std::make_shared<torch::jit::CompilationUnit>();
+  // IMPORTANT: We pass in a COPY of the graph into create_function, since it
+  // may get mutated in the process.
+  auto jit_fn = cu->create_function(std::move(name), std::move(graph_->copy()));
+
+  // Generate MLIR.
+  MlirOperation func_op =
+      torch_mlir::importJitFunctionAsFuncOp(mlir_context_, jit_fn);
+
+  // TODO(henrytu): Inject tensor shapes into func_op.
+  return std::make_shared<TorchMlirComputation>(func_op, mlir_context_, graph_);
+}
+
+torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) {
+  PRINT_FUNCTION();
+
+  auto it = emitted_outputs_.find(output);
+  if (it == emitted_outputs_.end()) {
+    auto post_order = Util::ComputePostOrder(output.node, &emit_status_);
+    for (auto node : post_order) {
+      bool ok = lowering_->Lower(node);
+      TORCH_CHECK(ok, "Failed to lower: ", node->ToString());
+    }
+    // At this point the output better be present, otherwise there is an issue
+    // with the lowering code.
+    it = emitted_outputs_.find(output);
+    TORCH_CHECK(
+        it != emitted_outputs_.end(),
+        "No MLIR operation emitted for output: ", output.ToString());
+  }
+  return it->second;
+}
+
+void TorchMlirLoweringContext::AssignOutputOp(
+    const Output& output, torch::jit::Value* op) {
+  PRINT_FUNCTION();
+
+  auto torch_mlir_node =
+      NodeCast<TorchMlirNode>(output.node, output.node->op());
+  if (!torch_mlir_node->getPythonStacktrace().empty()) {
+    op->node()->s_(
+        c10::Symbol::attr("source"), torch_mlir_node->getPythonStacktrace());
+  }
+  emitted_outputs_[output] = std::move(op);
+}
+
+torch::jit::Value* TorchMlirLoweringContext::GetParameter(BackendDataPtr data) {
+  PRINT_FUNCTION();
+
+  if (!dynamic_cast<TorchMlirBackendData*>(data.get())) {
+    TORCH_CHECK(
+        false,
+        "Expected TorchMlirBackendData. Got some other BackendData type");
+  }
+  const auto mlir_data = std::static_pointer_cast<TorchMlirBackendData>(data);
+
+  BackendData::Handle handle = mlir_data->GetHandle();
+  auto it = parameters_map_.find(handle);
+
+  if (it == parameters_map_.end()) {
+    torch::jit::Value* param =
+        graph_->addInput(c10::str("p", parameters_.size()));
+
+    auto info = mlir_data->mlir_info();
+    if (info->scalar.has_value()) {
+      auto& scalar = info->scalar.value();
+      if (scalar.isFloatingPoint()) {
+        param->setType(c10::FloatType::get());
+      } else if (scalar.isIntegral(true)) {
+        param->setType(c10::IntType::get());
+      } else {
+        TORCH_CHECK(
+            false, "Unhandled scalar type: ", c10::toString(scalar.type()));
+      }
+    }
+
+    it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()})
+             .first;
+    parameters_.push_back(mlir_data);
+  }
+
+  parameter_sequence_.push_back(it->second.index);
+  return it->second.param;
+}
+
+std::shared_ptr<torch::jit::Graph> TorchMlirLoweringContext::graph() const {
+  return graph_;
+}
+
+size_t TorchMlirLoweringContext::AddResult(torch::jit::Value* op) {
+  PRINT_FUNCTION();
+  root_tuple_.push_back(std::move(op));
+  return root_tuple_.size() - 1;
+}
+
+void TorchMlirLoweringContext::RegisterMlirDialects() {
+  // https://reviews.llvm.org/D88162
+  mlirRegisterAllDialects(mlir_context_);
+  torchMlirRegisterAllDialects(mlir_context_);
+}
+
+} // namespace lazy
+} // namespace torch
diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h
new file mode 100644
index 00000000000..76c9123cebb
--- /dev/null
+++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h
@@ -0,0 +1,136 @@
+//===- mlir_lowering_context.h --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+// This file is adapted from pytorch/pytorch
+// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_lowering_context.h
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include
+
+#include
+#include
+
+#include "mlir-c/IR.h"
+#include "mlir_node_lowering.h"
+
+namespace torch {
+namespace lazy {
+
+class TORCH_API TorchMlirNodeLoweringInterface {
+  /**
+   * This interface is only needed for legacy ops, and can be removed once all
+   * ops implement LtcMlirNode->lower().
+   * */
+public:
+  TorchMlirNodeLoweringInterface() = default;
+
+  virtual ~TorchMlirNodeLoweringInterface() = default;
+
+  virtual bool Lower(const Node* node) = 0;
+
+  static std::unique_ptr<TorchMlirNodeLoweringInterface>
+  Create(LoweringContext* loctx);
+};
+
+class TORCH_API TorchMlirComputation : public torch::lazy::Computation {
+public:
+  TorchMlirComputation(
+      MlirOperation func_op, MlirContext mlir_context,
+      const std::shared_ptr<torch::jit::Graph>& graph);
+
+  int parameters_size() const override;
+
+  const std::vector<torch::lazy::Shape>& parameter_shapes() const override;
+
+  const std::vector<std::string>& parameter_names() const override;
+
+  const torch::lazy::Shape& result_shape() const override;
+
+  unsigned num_results() const;
+
+  MlirOperation func_op() const;
+
+  std::string to_string() const;
+
+private:
+  std::vector<std::string> parameter_names_;
+  std::vector<Shape> parameter_shapes_;
+  Shape result_shape_;
+
+  MlirOperation func_op_;
+  MlirContext mlir_context_;
+  std::shared_ptr<torch::jit::Graph> graph_;
+  unsigned num_results_;
+};
+
+class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext {
+public:
+  TorchMlirLoweringContext(
+      const std::string& name, torch::lazy::BackendDevice device);
+  TorchMlirLoweringContext(
+      const std::string& name, torch::lazy::BackendDevice device,
+      c10::ArrayRef<const torch::lazy::Node*> post_order,
+      torch::lazy::Util::EmissionMap emit_status);
+
+  // Get the shape of the result tuple component, given by index.
+  torch::lazy::Shape GetResultShape(size_t index) const override;
+
+  // Adds the given output as a component of the result tuple and returns its
+  // assigned position within the tuple.
+  size_t AddResult(const torch::lazy::Output& output) override;
+
+  // Associates the given output with the input parameter of the given index and
+  // shape. Only used for the operator-by-operator execution, mostly for
+  // debugging purposes.
+  void AddParameter(
+      const torch::lazy::Output& output, size_t index,
+      const torch::lazy::Shape& shape, const std::string& name) override;
+
+  // Build the computation capturing all the operations created with the
+  // embedded builder (returned by the builder() API).
+  torch::lazy::ComputationPtr Build() override;
+
+  // Retrieves the lowered operation for an output. If the requested output is
+  // not available yet, the graph behind the output's Node is lowered, and the
+  // corresponding JIT IR operation returned.
+  torch::jit::Value* GetOutputOp(const Output& output);
+
+  // Assigns the given JIT IR operation to the specified output. As outputs are
+  // lowered in a post-order fashion, later nodes should always find their
+  // operands among the emitted outputs.
+  void AssignOutputOp(const Output& output, torch::jit::Value* op);
+
+  // If a parameter associated with data has already been declared, it will be
+  // returned. Otherwise a new one will be created, associated with the tensor
+  // held in data.
+  torch::jit::Value* GetParameter(BackendDataPtr data);
+
+  std::shared_ptr<torch::jit::Graph> graph() const;
+
+private:
+  struct Parameter {
+    torch::jit::Value* param;
+    size_t index = 0;
+  };
+
+  size_t AddResult(torch::jit::Value* op);
+
+  void RegisterMlirDialects();
+
+  std::shared_ptr<torch::jit::Graph> graph_;
+  MlirContext mlir_context_;
+  std::unordered_map<BackendData::Handle, Parameter> parameters_map_;
+  std::vector<torch::jit::Value*> root_tuple_;
+  OutputMap<torch::jit::Value*> emitted_outputs_;
+  std::unique_ptr<TorchMlirNodeLoweringInterface> lowering_;
+};
+
+} // namespace lazy
+} // namespace torch
diff --git a/python/torch_mlir/csrc/backend/mlir_native_functions.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp
similarity index 100%
rename from python/torch_mlir/csrc/backend/mlir_native_functions.cpp
rename to python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp
diff --git a/python/torch_mlir/csrc/backend/mlir_node.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp
similarity index 88%
rename from python/torch_mlir/csrc/backend/mlir_node.cpp
rename to python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp
index f0ba45b5ddb..cee4ae3f3cb 100644
--- a/python/torch_mlir/csrc/backend/mlir_node.cpp
+++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.cpp
@@ -18,8 +18,8 @@
 namespace torch {
 namespace lazy {

-TorchMlirOpVector
-TorchMlirNode::Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
+TorchMlirOpVector TorchMlirNode::Lower(
+    TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
   return {};
 }

diff --git a/python/torch_mlir/csrc/backend/mlir_node.h b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h
similarity index 90%
rename from python/torch_mlir/csrc/backend/mlir_node.h
rename to python/torch_mlir/csrc/base_lazy_backend/mlir_node.h
index 71068ac81e1..00a44f406e9 100644
--- a/python/torch_mlir/csrc/backend/mlir_node.h
+++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node.h
@@ -25,9 +25,6 @@
 namespace torch {
 namespace lazy {

-typedef std::vector<torch::jit::Value*> TorchMlirOpVector;
-typedef std::shared_ptr<torch::jit::GraphFunction> TorchMlirFunction;
-
 class TORCH_API TorchMlirNode : public torch::lazy::Node {
 public:
   using torch::lazy::Node::Node;
diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp b/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp
new file mode 100644
index 00000000000..12489c1b053
--- /dev/null
+++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.cpp
@@ -0,0 +1,452 @@
+//===- mlir_node_lowering.cpp ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+// This file is adapted from pytorch/pytorch
+// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_node_lowering.cpp
+//===----------------------------------------------------------------------===//
+
+#include "mlir_node_lowering.h"
+#include "mlir_lowering_context.h"
+#include "mlir_node.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace torch {
+namespace lazy {
+
+class TorchMlirNodeLowering : public TorchMlirNodeLoweringInterface {
+public:
+  TorchMlirNodeLowering(
+      const std::string& name, torch::lazy::TorchMlirLoweringContext* loctx)
+      : loctx_(loctx), function_(
+                           loctx ? std::make_shared<torch::jit::GraphFunction>(
+                                       name, loctx->graph(), nullptr)
+                                 : nullptr) {}
+
+  torch::lazy::TorchMlirLoweringContext* loctx() { return loctx_; }
+
+  bool Lower(const torch::lazy::Node* node) override {
+    if (auto* torch_mlir_node =
+            dynamic_cast<const torch::lazy::TorchMlirNode*>(node)) {
+      // First, we call the node lowering function, which exists for newly
+      // codegenned or refactored nodes
+      TorchMlirOpVector ops = torch_mlir_node->Lower(function_, loctx());
+      if (ops.empty()) {
+        // Then fall back to legacy lowering code, which should be gradually
+        // removed
+        ops = LowerNonCodegenOps(node);
+      }
+      if (ops.empty()) {
+        return false;
+      }
+      CHECK_EQ(node->num_outputs(), ops.size());
+      for (size_t i = 0; i < ops.size(); ++i) {
+        loctx()->AssignOutputOp(torch::lazy::Output(node, i), ops[i]);
+      }
+      return true;
+    } else {
+      TorchMlirOpVector ops = LowerNonCodegenOps(node);
+      if (!ops.empty()) {
+        CHECK_EQ(node->num_outputs(), ops.size());
+        for (size_t i = 0; i < ops.size(); ++i) {
+          loctx()->AssignOutputOp(torch::lazy::Output(node, i), ops[i]);
+        }
+        return true;
+      }
+    }
+    throw std::runtime_error(
+        "Expected torch::lazy::TorchMlirNode but could not dynamic cast");
+  }
+
+  // TODO(whc) this is for legacy/non-codegen Ops, and after moving most ops
+  // to codegen we should delete this and put all the lowering logic into Node
+  // classes
+  TorchMlirOpVector LowerNonCodegenOps(const torch::lazy::Node* node) {
+
+    if (node->op().op == at::aten::as_strided) {
+      return LowerAsStrided(torch::lazy::NodeCast<torch::lazy::AsStrided>(
+          node, torch::lazy::OpKind(at::aten::as_strided)));
+    }
+    if (node->op() == *torch::lazy::ltc_as_strided_view_update) {
+      return LowerAsStridedViewUpdate(
+          torch::lazy::NodeCast<torch::lazy::AsStridedViewUpdate>(
+              node, *torch::lazy::ltc_as_strided_view_update));
+    }
+    if (node->op() == *torch::lazy::ltc_cast) {
+      return LowerCast(torch::lazy::NodeCast<torch::lazy::Cast>(
+          node, *torch::lazy::ltc_cast));
+    }
+    if (node->op() == *torch::lazy::ltc_select_view_update) {
+      return LowerSelectViewUpdate(
+          torch::lazy::NodeCast<torch::lazy::SelectViewUpdate>(
+              node, *torch::lazy::ltc_select_view_update));
+    }
+    if (node->op() == *torch::lazy::ltc_narrow_view_update) {
+      return LowerNarrowViewUpdate(
+          torch::lazy::NodeCast<torch::lazy::NarrowViewUpdate>(
+              node, *torch::lazy::ltc_narrow_view_update));
+    }
+    if (node->op().op == at::prim::Constant) {
+      return LowerScalar(torch::lazy::NodeCast<torch::lazy::Scalar>(
+          node, torch::lazy::OpKind(at::prim::Constant)));
+    }
+    if (node->op().op == at::aten::bernoulli) {
+      std::vector<torch::jit::NamedValue> arguments;
+      arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
+      return LowerBuiltin(node, arguments);
+    }
+    if (node->op().op == at::aten::native_batch_norm) {
+      return LowerBatchNorm(
+          torch::lazy::NodeCast<torch::lazy::NativeBatchNormForward>(
+              node, torch::lazy::OpKind(at::aten::native_batch_norm)));
+    }
+    if (node->op().op == at::aten::native_batch_norm_backward) {
+      return LowerBatchNormBackward(
+          torch::lazy::NodeCast<torch::lazy::NativeBatchNormBackward>(
+              node, torch::lazy::OpKind(at::aten::native_batch_norm_backward)));
+    }
+    if (node->op().op == at::aten::expand) {
+      return LowerExpand(torch::lazy::NodeCast<torch::lazy::Expand>(
+          node, torch::lazy::OpKind(at::aten::expand)));
+    }
+    if (node->op().op == at::aten::narrow) {
+      return LowerNarrow(torch::lazy::NodeCast<torch::lazy::Narrow>(
+          node, torch::lazy::OpKind(at::aten::narrow)));
+    }
+    if (node->op().op == at::aten::permute) {
+      return LowerPermute(torch::lazy::NodeCast<torch::lazy::Permute>(
+          node, torch::lazy::OpKind(at::aten::permute)));
+    }
+    if (node->op().op == at::aten::select) {
+      return LowerSelect(torch::lazy::NodeCast<torch::lazy::Select>(
+          node, torch::lazy::OpKind(at::aten::select)));
+    }
+    if (node->op().op == at::aten::squeeze) {
+      return LowerSqueeze(torch::lazy::NodeCast<torch::lazy::Squeeze>(
+          node, torch::lazy::OpKind(at::aten::squeeze)));
+    }
+    if (node->op().op == at::aten::unsqueeze) {
+      return LowerUnsqueeze(torch::lazy::NodeCast<torch::lazy::Unsqueeze>(
+          node, torch::lazy::OpKind(at::aten::unsqueeze)));
+    }
+    if (node->op().op == at::aten::view) {
+      return LowerView(torch::lazy::NodeCast<torch::lazy::View>(
+          node, torch::lazy::OpKind(at::aten::view)));
+    }
+    if (node->op() == *torch::lazy::ltc_device_data) {
+      const torch::lazy::DeviceData* device_data_node =
+          torch::lazy::NodeCast<torch::lazy::DeviceData>(
+              node, *torch::lazy::ltc_device_data);
+      auto infoptr = device_data_node->data()->info();
+      auto deviceDataInfoPtr =
+          (torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr;
+      if (GRAPH_DUMP_ENABLED) {
+        LOG(ERROR) << "Lowering device data node, tensor id "
+                   << deviceDataInfoPtr->tensor_id << std::endl;
+      }
+      return {loctx()->GetParameter(device_data_node->data())};
+    }
+
+    std::vector<torch::jit::NamedValue> arguments;
+    for (const torch::lazy::Output& output : node->operands()) {
+      arguments.emplace_back(loctx()->GetOutputOp(output));
+    }
+    return LowerBuiltin(node, arguments);
+  }
+
+  TorchMlirOpVector LowerBuiltin(
+      const torch::lazy::Node* node,
+      const std::vector<torch::jit::NamedValue>& arguments,
+      const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
+    return LowerTorchMlirBuiltin(
+        function_, node->op().op, arguments, kwarguments);
+  }
+  TorchMlirOpVector LowerBuiltin(
+      c10::Symbol sym, const std::vector<torch::jit::NamedValue>& arguments,
+      const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
+    return LowerTorchMlirBuiltin(function_, sym, arguments, kwarguments);
+  }
+
+  TorchMlirOpVector LowerAsStrided(const torch::lazy::AsStrided* node) {
+    std::vector<torch::jit::NamedValue> arguments;
+    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
+    arguments.emplace_back(node->size());
+    arguments.emplace_back(node->stride());
+    arguments.emplace_back(node->storage_offset());
+    TorchMlirOpVector as_strided_out = LowerBuiltin(node, arguments);
+    CHECK_EQ(as_strided_out.size(), 1);
+    return {GenerateClone(as_strided_out.front())};
+  }
+
+  TorchMlirOpVector
+  LowerAsStridedViewUpdate(const torch::lazy::AsStridedViewUpdate* node) {
+    torch::jit::Value* destination =
+        GenerateClone(loctx()->GetOutputOp(node->operand(0)));
+    const torch::lazy::Output& input_op = node->operand(1);
+    const torch::lazy::Shape& input_shape = input_op.shape();
+    const auto input_dimensions = input_shape.sizes();
+    std::vector<torch::jit::NamedValue> dest_arguments;
+    dest_arguments.emplace_back(destination);
+    dest_arguments.emplace_back(
+        std::vector<int64_t>(input_dimensions.begin(), input_dimensions.end()));
+    dest_arguments.emplace_back(node->stride());
+    dest_arguments.emplace_back(node->storage_offset());
+    TorchMlirOpVector as_strided_out =
+        LowerBuiltin(at::aten::as_strided, dest_arguments);
+    CHECK_EQ(as_strided_out.size(), 1);
+    torch::jit::Value* as_strided = as_strided_out.front();
+    GenerateCopy(as_strided, loctx()->GetOutputOp(input_op));
+    return {destination};
+  }
+
+  TorchMlirOpVector
+  LowerBatchNorm(const torch::lazy::NativeBatchNormForward* node) {
+    std::vector<torch::jit::NamedValue> arguments;
+    for (size_t i = 0; i < 5; ++i) {
+      arguments.emplace_back(loctx()->GetOutputOp(node->operand(i)));
+    }
+    arguments.emplace_back(node->training());
+    arguments.emplace_back(node->momentum());
+    arguments.emplace_back(node->eps());
+    return LowerBuiltin(node, arguments);
+  }
+
+  TorchMlirOpVector
+  LowerBatchNormBackward(const torch::lazy::NativeBatchNormBackward* node) {
+    std::vector<torch::jit::NamedValue> arguments;
+    for (size_t i = 0; i < 3; ++i) {
+      arguments.emplace_back(loctx()->GetOutputOp(node->operand(i)));
+    }
+    const auto& operands = node->operands();
+    c10::optional<at::Tensor> null_arg;
+    if (operands.size() == 5) {
+      arguments.emplace_back(null_arg);
+      arguments.emplace_back(null_arg);
+    }
+    for (size_t i = 3; i < operands.size(); ++i) {
+      arguments.emplace_back(loctx()->GetOutputOp(node->operand(i)));
+    }
+    arguments.emplace_back(node->training());
+    arguments.emplace_back(node->eps());
+    arguments.emplace_back(node->output_mask());
+    return LowerBuiltin(node, arguments);
+  }
+
+  TorchMlirOpVector LowerCast(const torch::lazy::Cast* node) {
+    std::vector<torch::jit::NamedValue> arguments;
+    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
+    arguments.emplace_back(node->dtype());
+    return LowerBuiltin(at::aten::to, arguments);
+  }
+
+  TorchMlirOpVector LowerExpand(const torch::lazy::Expand* node) {
+    std::vector<torch::jit::NamedValue> arguments;
+    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
+    arguments.emplace_back(node->size());
+    auto expand_out = LowerBuiltin(node, arguments);
+    if (node->is_scalar_expand()) {
+      // The aten::expand operation sets all strides to 0 when the original
+      // is of rank 0. This leads to false positives when checking for
+      // internal memory overlap, because at::has_internal_overlap returns
+      // MemOverlap::YES when a stride is set to 0.
+      CHECK_EQ(expand_out.size(), 1);
+      return {GenerateClone(expand_out.front())};
+    }
+    return expand_out;
+  }
+
+  TorchMlirOpVector LowerNarrow(const torch::lazy::Narrow* node) {
+    const torch::lazy::Output& input = node->operand(0);
+    torch::jit::Value* base = loctx()->GetOutputOp(input);
+    const auto& base_indices = node->base_indices();
+    const auto& sizes = node->sizes();
+    const torch::lazy::Shape& input_shape = input.shape();
+    CHECK_EQ(sizes.size(), base_indices.size());
+    CHECK_EQ(input_shape.dim(), base_indices.size());
+    for (size_t dim = 0; dim < base_indices.size(); ++dim) {
+      int64_t start = base_indices[dim];
+      base = GenerateSlice(
+          /*base=*/base, /*dim=*/dim, /*start=*/start,
+          /*end=*/start + sizes[dim], /*step=*/1);
+    }
+    return {base};
+  }
+
+  TorchMlirOpVector LowerPermute(const torch::lazy::Permute* node) {
+    std::vector<torch::jit::NamedValue> arguments;
+    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
+    arguments.push_back(node->dims());
+    return LowerBuiltin(node, arguments);
+  }
+
+  TorchMlirOpVector LowerScalar(const torch::lazy::Scalar* node) {
+    const at::Scalar& value = node->value();
+    const torch::lazy::Shape& shape = node->shape();
+    auto options =
+        at::TensorOptions()
+            .device(torch::lazy::getBackend()->EagerFallbackDeviceType())
+            .dtype(shape.scalar_type());
+    return {
+        loctx()->graph()->insertConstant(at::scalar_tensor(value, options))};
+  }
+
+  TorchMlirOpVector LowerSelect(const torch::lazy::Select* node) {
+    int64_t step = torch::lazy::Select::GetStride(
+        node->start(), node->end(), node->stride());
+    torch::jit::Value* base = loctx()->GetOutputOp(node->operand(0));
+    return {GenerateSlice(
+        /*base=*/base, /*dim=*/node->dim(),
+        /*start=*/node->start(), /*end=*/node->end(),
+        /*step=*/step)};
+  }
+
+  TorchMlirOpVector LowerSqueeze(const torch::lazy::Squeeze* node) {
+    std::vector<torch::jit::NamedValue> arguments;
+    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
+    if (node->dim() != -1) {
+      arguments.push_back(node->dim());
+    }
+    return LowerBuiltin(node, arguments);
+  }
+
+  TorchMlirOpVector
+  LowerSelectViewUpdate(const torch::lazy::SelectViewUpdate* node) {
+    torch::jit::Value* dest =
+        GenerateClone(loctx()->GetOutputOp(node->operand(0)));
+    int64_t step = torch::lazy::Select::GetStride(
+        node->start(), node->end(), node->stride());
+    torch::jit::Value* selected = GenerateSlice(
+        /*base=*/dest, /*dim=*/node->dim(), /*start=*/node->start(),
+        /*end=*/node->end(), /*step=*/step);
+    GenerateCopy(selected, loctx()->GetOutputOp(node->operand(1)));
+    return {dest};
+  }
+
+  TorchMlirOpVector
+  LowerNarrowViewUpdate(const torch::lazy::NarrowViewUpdate* node) {
+    torch::jit::Value* dest =
+        GenerateClone(loctx()->GetOutputOp(node->operand(0)));
+    const auto& base_indices = node->base_indices();
+    const torch::lazy::Output& source_argument = node->operand(1);
+    const torch::lazy::Shape& source_shape = source_argument.shape();
+    CHECK_EQ(source_shape.dim(), base_indices.size());
+    torch::jit::Value* base = dest;
+    for (size_t dim = 0; dim < base_indices.size(); ++dim) {
+      int64_t start = base_indices[dim];
+      base = GenerateSlice(
+          /*base=*/base, /*dim=*/dim, /*start=*/start,
+          /*end=*/start + source_shape.size(dim),
+          /*step=*/1);
+    }
+    GenerateCopy(base, loctx()->GetOutputOp(source_argument));
+    return {dest};
+  }
+
+  TorchMlirOpVector LowerUnsqueeze(const torch::lazy::Unsqueeze* node) {
+    std::vector<torch::jit::NamedValue> arguments;
+    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
+    arguments.push_back(node->dim());
+    return LowerBuiltin(node, arguments);
+  }
+
+  TorchMlirOpVector LowerView(const torch::lazy::View* node) {
+    std::vector<torch::jit::NamedValue> arguments;
+    arguments.emplace_back(loctx()->GetOutputOp(node->operand(0)));
+    arguments.push_back(node->output_size());
+    return LowerBuiltin(at::aten::reshape, arguments);
+  }
+
+  torch::jit::Value* GenerateClone(torch::jit::Value* val) {
+    std::vector<torch::jit::NamedValue> clone_arguments;
+    clone_arguments.emplace_back(val);
+    TorchMlirOpVector cloned = LowerBuiltin(at::aten::clone, clone_arguments);
+    CHECK_EQ(cloned.size(), 1);
+    return cloned.front();
+  }
+
+  void GenerateCopy(torch::jit::Value* destination, torch::jit::Value* source) {
+    std::vector<torch::jit::NamedValue> arguments;
+    arguments.emplace_back(destination);
+    arguments.emplace_back(source);
+    LowerBuiltin(at::aten::copy_, arguments);
+  }
+
+  torch::jit::Value* GenerateSlice(
+      torch::jit::Value* base, int64_t dim, int64_t start, int64_t end,
+      int64_t step) {
+    std::vector<torch::jit::NamedValue> arguments;
+    arguments.emplace_back(base);
+    arguments.emplace_back(dim);
+    arguments.emplace_back(start);
+    arguments.emplace_back(end);
+    arguments.emplace_back(step);
+    TorchMlirOpVector selected = LowerBuiltin(at::aten::slice, arguments);
+    CHECK_EQ(selected.size(), 1);
+    return selected.front();
+  }
+  torch::lazy::TorchMlirLoweringContext* loctx_;
+  std::shared_ptr<torch::jit::GraphFunction> function_;
+};
+
+std::unique_ptr<TorchMlirNodeLoweringInterface>
+TorchMlirNodeLoweringInterface::Create(torch::lazy::LoweringContext* loctx) {
+  return std::make_unique<TorchMlirNodeLowering>(
+      "TorchMlirNodeLowering",
+      static_cast<torch::lazy::TorchMlirLoweringContext*>(loctx));
+}
+
+TorchMlirOpVector LowerTorchMlirBuiltin(
+    std::shared_ptr<torch::jit::GraphFunction> function, c10::Symbol sym,
+    const std::vector<torch::jit::NamedValue>& arguments,
+    const std::vector<torch::jit::NamedValue>& kwarguments) {
+  auto builtin =
+      std::make_shared<torch::jit::BuiltinFunction>(sym, at::nullopt);
+  auto magic_method = std::make_shared<torch::jit::MagicMethod>("", builtin);
+  auto ret = magic_method->call({}, *function, arguments, kwarguments, 0);
+  auto sv = dynamic_cast<torch::jit::SimpleValue*>(ret.get());
+  CHECK(sv);
+  if (sv->getValue()->type()->kind() == c10::TypeKind::TupleType) {
+    const auto tuple_call_result = sv->asTuple({}, *function);
+    TorchMlirOpVector tuple_result;
+    for (const auto& tuple_component : tuple_call_result) {
+      auto tuple_component_sv =
+          dynamic_cast<torch::jit::SimpleValue*>(tuple_component.get());
+      tuple_result.push_back(tuple_component_sv->getValue());
+    }
+    return tuple_result;
+  }
+  return {sv->getValue()};
+}
+
+} // namespace lazy
+} // namespace torch
diff --git a/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.h b/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.h
new file mode 100644
index 00000000000..2e7774b251f
--- /dev/null
+++ b/python/torch_mlir/csrc/base_lazy_backend/mlir_node_lowering.h
@@ -0,0 +1,30 @@
+//===- mlir_node_lowering.h -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+// This file is adapted from pytorch/pytorch
+// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_node_lowering.h
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include
+#include
+
+namespace torch {
+namespace lazy {
+
+typedef std::vector<torch::jit::Value*> TorchMlirOpVector;
+typedef std::shared_ptr<torch::jit::GraphFunction> TorchMlirFunction;
+
+TORCH_API TorchMlirOpVector LowerTorchMlirBuiltin(
+    TorchMlirFunction function, c10::Symbol sym,
+    const std::vector<torch::jit::NamedValue>& arguments,
+    const std::vector<torch::jit::NamedValue>& kwarguments = {});
+
+} // namespace lazy
+} // namespace torch
diff --git a/python/torch_mlir/csrc/utils/debug.h b/python/torch_mlir/csrc/utils/debug.h
index 59f98d99ac7..df61ca1aaac 100644
--- a/python/torch_mlir/csrc/utils/debug.h
+++ b/python/torch_mlir/csrc/utils/debug.h
@@ -21,3 +21,7 @@ static const bool verbose_print_function =
     std::cout << __PRETTY_FUNCTION__ << " (" << __FILE__ << ":" << __LINE__ \
               << ")" << std::endl;                                          \
   }
+
+#define PRINT_DEBUG(msg)                                                    \
+  std::cout << msg << " (" << __FILE__ << ":" << __LINE__ << ")"            \
+            << std::endl;
\ No newline at end of file
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/CMakeLists.txt b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/CMakeLists.txt
index 735a40f7486..48f1d83df5f 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/CMakeLists.txt
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/CMakeLists.txt
@@ -10,7 +10,7 @@ include_directories(BEFORE
 )
 link_directories("${TORCH_INSTALL_PREFIX}/lib")

-add_library(TorchMLIRJITIRImporter MODULE
+add_library(TorchMLIRJITIRImporter SHARED
   class_annotator.cpp
   get_registered_ops.cpp
   function_importer.cpp
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.h
index c749c26bc56..6a5652eb5d2 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.h
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/function_importer.h
@@ -39,7 +39,7 @@ namespace torch_mlir {
 /// will be attached as an argument attribute to the func op's argument. If a
 /// null MlirAttribute is returned, no attribute will be attached to that
 /// argument.
-MlirOperation importJitFunctionAsFuncOp(
+TORCH_API MlirOperation importJitFunctionAsFuncOp(
     MlirContext context, torch::jit::Function *function,
    std::function<MlirAttribute(int)> getArgAttribute =
         [](int) -> MlirAttribute { return {nullptr}; });
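
Usage sketch (illustration, not part of the diff): the commit message above
describes the flow graph -> jit::Function -> MlirModule, and the sketch below
shows how a lazy-tensor runtime would be expected to drive it. This is a
minimal sketch under stated assumptions: `LowerToMlir` is a hypothetical
helper, and `device`, `post_order`, `emit_status`, and `outputs` stand in for
values that the LazyGraphExecutor normally supplies. Only
TorchMlirLoweringContext and TorchMlirComputation come from this patch.

    #include <iostream>

    #include "base_lazy_backend/mlir_lowering_context.h"

    using namespace torch::lazy;

    ComputationPtr LowerToMlir(
        BackendDevice device, c10::ArrayRef<const Node*> post_order,
        Util::EmissionMap emit_status, c10::ArrayRef<Output> outputs) {
      // The constructor lowers every node in post_order into the embedded
      // torch::jit::Graph (TorchMlirNode::Lower first, legacy lowering as
      // the fallback).
      TorchMlirLoweringContext loctx(
          "MlirComputation", device, post_order, std::move(emit_status));

      // Register the graph roots; each output becomes one component of the
      // result tuple.
      for (const Output& output : outputs) {
        loctx.AddResult(output);
      }

      // Build() wraps the jit::Graph in a jit::Function and imports it as an
      // MLIR func op via torch_mlir::importJitFunctionAsFuncOp().
      ComputationPtr computation = loctx.Build();

      // The resulting TorchMlirComputation can print both IRs for debugging.
      auto* mlir_computation =
          static_cast<TorchMlirComputation*>(computation.get());
      std::cout << mlir_computation->to_string() << std::endl;
      return computation;
    }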