From 4c8fed186ac2ea37c0a5302f14c2f5e28b8d7b06 Mon Sep 17 00:00:00 2001 From: Henry Tu Date: Thu, 14 Apr 2022 12:53:00 -0400 Subject: [PATCH] Add example Torch MLIR LTC Backend (#725) --- CMakeLists.txt | 1 + examples/CMakeLists.txt | 1 + examples/ltc_backend/CMakeLists.txt | 63 ++++++++ examples/ltc_backend/ltc_backend/__init__.py | 0 .../ltc_backend/csrc/backend/backend_impl.cpp | 140 ++++++++++++++++++ .../ltc_backend/csrc/backend/backend_impl.h | 27 ++++ .../csrc/example_mlir_backend_pybind.cpp | 73 +++++++++ .../ltc_backend/csrc/utils/sys_utils.h | 26 ++++ examples/ltc_backend_mnist.py | 86 +++++++++++ .../csrc/base_lazy_backend/backend_impl.cpp | 19 ++- 10 files changed, 434 insertions(+), 2 deletions(-) create mode 100644 examples/CMakeLists.txt create mode 100644 examples/ltc_backend/CMakeLists.txt create mode 100644 examples/ltc_backend/ltc_backend/__init__.py create mode 100644 examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.cpp create mode 100644 examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.h create mode 100644 examples/ltc_backend/ltc_backend/csrc/example_mlir_backend_pybind.cpp create mode 100644 examples/ltc_backend/ltc_backend/csrc/utils/sys_utils.h create mode 100644 examples/ltc_backend_mnist.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 4b0f48c87805..685296420f10 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -165,3 +165,4 @@ else() endif() add_subdirectory(test) +add_subdirectory(examples) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt new file mode 100644 index 000000000000..d390ea366998 --- /dev/null +++ b/examples/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(ltc_backend) diff --git a/examples/ltc_backend/CMakeLists.txt b/examples/ltc_backend/CMakeLists.txt new file mode 100644 index 000000000000..e08d32d6de08 --- /dev/null +++ b/examples/ltc_backend/CMakeLists.txt @@ -0,0 +1,63 @@ +########################################################################### +# Setup PyTorch +########################################################################### + +list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/python/torch_mlir/cmake/modules") +include(TorchMLIRPyTorch) +TorchMLIRProbeForPyTorchInstall() +find_package(Torch 1.11 REQUIRED) + +TorchMLIRConfigurePyTorch() + +########################################################################### +# Setup Python development +########################################################################### + +list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/external/llvm-project/mlir/cmake/modules") +include(MLIRDetectPythonEnv) +mlir_configure_python_dev_packages() + +########################################################################### +# Library definition +########################################################################### + +include_directories(BEFORE + ${TORCH_INCLUDE_DIRS} + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_BINARY_DIR} + ${Python3_INCLUDE_DIRS} + ${PYTHON_H_DIR} + ${PROJECT_SOURCE_DIR}/python + ) +link_directories("${TORCH_INSTALL_PREFIX}/lib") +link_directories(${CMAKE_CURRENT_SOURCE_DIR}/ltc_backend/lib) +add_link_options(-Wl,-rpath,$ORIGIN/ltc_backend/lib) + +file(GLOB LTC_BACKEND_CSRC CONFIGURE_DEPENDS + "ltc_backend/csrc/*.h" + "ltc_backend/csrc/*.cc" + "ltc_backend/csrc/*.cpp" + "ltc_backend/csrc/*/*.h" + "ltc_backend/csrc/*/*.cc" + "ltc_backend/csrc/*/*.cpp" + ) +add_library(example_mlir_ltc_backend SHARED ${LTC_BACKEND_CSRC}) +add_dependencies(example_mlir_ltc_backend + torch_mlir_ltc_backend + ) +target_link_libraries(example_mlir_ltc_backend + ${TORCH_LIBRARIES} + ${Python3_LIBRARIES} + torch_python + torch_mlir_ltc_backend + ) + +message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wno-pedantic") +set_target_properties(example_mlir_ltc_backend PROPERTIES + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/ltc_backend/" + OUTPUT_NAME _EXAMPLE_MLIR_BACKEND + PREFIX "${PYTHON_MODULE_PREFIX}" + SUFFIX "${PYTHON_MODULE_EXTENSION}" + CXX_VISIBILITY_PRESET "hidden" + COMPILE_FLAGS "${TORCH_CXXFLAGS} -Wno-pedantic" + ) diff --git a/examples/ltc_backend/ltc_backend/__init__.py b/examples/ltc_backend/ltc_backend/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.cpp b/examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.cpp new file mode 100644 index 000000000000..c05d932073bb --- /dev/null +++ b/examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.cpp @@ -0,0 +1,140 @@ +//===- backend_impl.cpp ---------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "backend_impl.h" + +using namespace torch::lazy; + +namespace torch { +namespace lazy { + +struct ExampleMlirBackendDeviceType : public BackendDeviceType { + ExampleMlirBackendDeviceType(std::string device_type) + : device_type_(device_type) {} + + std::string toString() const override { return device_type_; } + + std::string device_type_; +}; + +class ExampleMlirBackendImpl : public torch::lazy::TorchMlirBackendImpl { +public: + ExampleMlirBackendImpl() : default_device_type_("Magic") {} + + /** + * Configuration + * */ + void SetRngSeed(size_t seed) const override { + std::cout << "RNG Seed Set to: " << seed << std::endl; + } + + /** + * Lowering, Compilation, Execution + * */ + std::vector + GetCompilationDevices(const std::string &device, + c10::ArrayRef devices) const override { + return std::vector(devices.begin(), devices.end()); + }; + + std::vector + Compile(std::vector instances) const override { + PRINT_FUNCTION(); + + // Vendor backend specific lowering can be exec here before returning. + for (const auto &instance : instances) { + std::cout << "Instance received at Compile: \n" + << GetComputationBackendText(instance) << std::endl; + } + + return instances; + } + + std::vector + ExecuteComputation(Computation &computation, + c10::ArrayRef arguments, + const BackendDevice &device) const override { + PRINT_FUNCTION(); + + // `arguments` maps 1:1 with the parameters in the generated MLIR. In this + // function, we will generate a list of BackendData that corresponds to the + // return values in the MLIR. + std::vector results; + + // "Borrow" some tensor data from arguments to reuse in return. This ensures + // that the tensor device is correctly configured. + TORCH_CHECK(arguments.size() > 0, + "Need at least one argument for example execution."); + const TorchMlirBackendData *torch_mlir_data = + dynamic_cast(arguments[0].get()); + TORCH_CHECK(torch_mlir_data, + "Invalid Backend Data Pointer. Expected TorchMlirBackendData."); + + // For this demo we aren't performing a legitimate execution, so we generate + // some dummy data to return based on the expected number of return values. + auto mlir_computation = static_cast(&computation); + for (unsigned i = 0; i < mlir_computation->num_results(); i++) { + results.push_back(std::make_shared( + torch_mlir_data->mlir_info()->tensor, device, + torch_mlir_data->shape())); + } + + return results; + } + + /** + * Device Configuration + * */ + std::shared_ptr GetDefaultDeviceType() const { + return std::make_shared(default_device_type_); + } + + void SetDefaultDeviceType(std::string device_type) { + default_device_type_ = ExampleMlirBackendDeviceType(device_type); + } + + /** + * Debug/Metrics + * */ + std::string + GetComputationBackendText(const ComputationPtr computation) const override { + auto mlir_computation = + static_cast(computation.get()); + return mlir_computation->to_string(); + } + +private: + ExampleMlirBackendDeviceType default_device_type_; +}; + +BackendImplInterface *GetExampleMlirBackendImpl() { + static ExampleMlirBackendImpl *example_mlir_backend_impl = + new ExampleMlirBackendImpl(); + return example_mlir_backend_impl; +} + +void InitExampleMlirBackend() { + at::RegisterTorchMlirLazyNativeFunctions(); + static std::unique_ptr g_registrar; + g_registrar.reset(new BackendRegistrar(GetExampleMlirBackendImpl())); +} + +} // namespace lazy +} // namespace torch diff --git a/examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.h b/examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.h new file mode 100644 index 000000000000..377ae4d219f9 --- /dev/null +++ b/examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.h @@ -0,0 +1,27 @@ +//===- backend_impl.h -----------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace at { +// This function is defined in the codegenerated RegisterLazy.cpp file. +TORCH_API void RegisterTorchMlirLazyNativeFunctions(); +} // namespace at + +namespace torch { +namespace lazy { + +torch::lazy::BackendImplInterface *GetExampleMlirBackendImpl(); + +void InitExampleMlirBackend(); + +} // namespace lazy +} // namespace torch diff --git a/examples/ltc_backend/ltc_backend/csrc/example_mlir_backend_pybind.cpp b/examples/ltc_backend/ltc_backend/csrc/example_mlir_backend_pybind.cpp new file mode 100644 index 000000000000..1474b4dc9077 --- /dev/null +++ b/examples/ltc_backend/ltc_backend/csrc/example_mlir_backend_pybind.cpp @@ -0,0 +1,73 @@ +//===- example_mlir_backend_pybind.cpp ------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch/csrc/jit/python/pybind.h" +#include "torch/csrc/lazy/backend/backend_interface.h" + +#include +#include +#include + +#include "backend/backend_impl.h" +#include "utils/sys_utils.h" + +namespace py = pybind11; + +namespace { +bool verbose = sys_util::GetEnv("VERBOSE", false); + +struct NoGilSection { + NoGilSection() : state(PyEval_SaveThread()) {} + ~NoGilSection() { PyEval_RestoreThread(state); } + PyThreadState *state = nullptr; +}; + +/** + * @brief Install the plugin + */ +void Initialize() { + // Initialize the Example MLIR LTC Backend + torch::lazy::InitExampleMlirBackend(); + + // sanity check + const torch::lazy::BackendImplInterface *mlir_backend = + torch::lazy::GetExampleMlirBackendImpl(); + const torch::lazy::BackendImplInterface *lazy_backend = + torch::lazy::getBackend(); + if (lazy_backend != mlir_backend) { + std::cout << "Failed to initialize MLIR Lazy Backend" << std::endl; + throw std::runtime_error("Failed to initialize MLIR Lazy Backend"); + } + + if (verbose) { + std::cout << "MLIR LTC PyTorch Plugin Initialized." << std::endl; + } +} + +/** + * @brief Uninstall the plugin + */ +void Shutdown() { + if (verbose) { + std::cout << "MLIR LTC PyTorch Plugin Shut down." << std::endl; + } +} +} // anonymous namespace + +PYBIND11_MODULE(_EXAMPLE_MLIR_BACKEND, m) { + m.doc() = ("pybind11 for example MLIR LTC backend."); + m.def("_initialize", []() { + NoGilSection gil; + Initialize(); + }); + m.def("_shutdown", []() { + NoGilSection gil; + Shutdown(); + }); +} diff --git a/examples/ltc_backend/ltc_backend/csrc/utils/sys_utils.h b/examples/ltc_backend/ltc_backend/csrc/utils/sys_utils.h new file mode 100644 index 000000000000..640d872cf5ad --- /dev/null +++ b/examples/ltc_backend/ltc_backend/csrc/utils/sys_utils.h @@ -0,0 +1,26 @@ +//===- sys_utils.h --------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include + +namespace sys_util { + +template +T GetEnv(const std::string &name, const T &default_value = T(0)) { + const char *env = std::getenv(name.c_str()); + if (!env) { + return default_value; + } + return T(std::atoi(env)); +} + +} // namespace sys_util diff --git a/examples/ltc_backend_mnist.py b/examples/ltc_backend_mnist.py new file mode 100644 index 000000000000..8e7af23b4fba --- /dev/null +++ b/examples/ltc_backend_mnist.py @@ -0,0 +1,86 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. +""" +Example use of the example Torch MLIR LTC backend. +""" +import argparse + +import torch.nn.functional as F + + +def main(device): + import torch + + if device in ("TS", "MLIR_EXAMPLE"): + import torch._lazy + + if device == "TS": + import torch._lazy.ts_backend + + torch._lazy.ts_backend.init() + + elif device == "MLIR_EXAMPLE": + import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend + + ltc_backend._initialize() + + device = "lazy" + print("Initialized backend") + else: + device = device.lower() + + inputs = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.float32, device=device) + assert inputs.device.type == device + + targets = torch.tensor([3], dtype=torch.int64, device=device) + assert targets.device.type == device + + print("Initialized data") + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(5, 5) + + def forward(self, x): + out = self.fc1(x) + out = F.relu(out) + return out + + model = Model().to(device) + model.train() + assert all(p.device.type == device for p in model.parameters()) + + print("Initialized model") + + criterion = torch.nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + optimizer.zero_grad() + + outputs = model(inputs) + loss = criterion(outputs, targets) + loss.backward() + optimizer.step() + + if device == "lazy": + print("Calling Mark Step") + torch._lazy.mark_step() + + print() + print(loss) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-d", + "--device", + type=str.upper, + choices=["CPU", "TS", "MLIR_EXAMPLE"], + default="MLIR_EXAMPLE", + help="The device type", + ) + args = parser.parse_args() + main(args.device) diff --git a/python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp b/python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp index d097c827f46b..8a8b121e2b23 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/backend_impl.cpp @@ -46,11 +46,18 @@ BackendData::Handle TorchMlirBackendData::GetHandle() { } void TorchMlirBackendData::Assign(const BackendData& data) { + const TorchMlirBackendData* torch_mlir_data = + dynamic_cast(&data); + TORCH_CHECK( + torch_mlir_data, + "Invalid Backend Data Pointer. Expected TorchMlirBackendData."); + TorchMlirBackendData::Info* info = - dynamic_cast(data.info()); + dynamic_cast(torch_mlir_data->mlir_info()); TORCH_CHECK( info, "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info."); + info_ = std::make_unique(*info); } @@ -92,11 +99,19 @@ at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData( const BackendDataPtr data, c10::optional logical_scalar_type) const { PRINT_FUNCTION(); + + TorchMlirBackendData* torch_mlir_data = + dynamic_cast(data.get()); + TORCH_CHECK( + torch_mlir_data, + "Invalid Backend Data Pointer. Expected TorchMlirBackendData."); + TorchMlirBackendData::Info* info = - dynamic_cast(data->info()); + dynamic_cast(torch_mlir_data->mlir_info()); TORCH_CHECK( info, "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info."); + return info->tensor; }