Skip to content

Commit

Permalink
Add example Torch MLIR LTC Backend
Browse files Browse the repository at this point in the history
  • Loading branch information
henrytwo committed Apr 1, 2022
1 parent 2de2f11 commit eb1737d
Show file tree
Hide file tree
Showing 9 changed files with 411 additions and 0 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,4 @@ else()
endif()

add_subdirectory(test)
add_subdirectory(examples)
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(ltc_backend)
63 changes: 63 additions & 0 deletions examples/ltc_backend/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
###########################################################################
# Setup PyTorch
###########################################################################

list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/python/torch_mlir/cmake/modules")
include(TorchMLIRPyTorch)
TorchMLIRProbeForPyTorchInstall()
find_package(Torch 1.11 REQUIRED)

TorchMLIRConfigurePyTorch()

###########################################################################
# Setup Python development
###########################################################################

list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/external/llvm-project/mlir/cmake/modules")
include(MLIRDetectPythonEnv)
mlir_configure_python_dev_packages()

###########################################################################
# Library definition
###########################################################################

include_directories(BEFORE
${TORCH_INCLUDE_DIRS}
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_BINARY_DIR}
${Python3_INCLUDE_DIRS}
${PYTHON_H_DIR}
${PROJECT_SOURCE_DIR}/python
)
link_directories("${TORCH_INSTALL_PREFIX}/lib")
link_directories(${CMAKE_CURRENT_SOURCE_DIR}/ltc_backend/lib)
add_link_options(-Wl,-rpath,$ORIGIN/ltc_backend/lib)

file(GLOB LTC_BACKEND_CSRC CONFIGURE_DEPENDS
"ltc_backend/csrc/*.h"
"ltc_backend/csrc/*.cc"
"ltc_backend/csrc/*.cpp"
"ltc_backend/csrc/*/*.h"
"ltc_backend/csrc/*/*.cc"
"ltc_backend/csrc/*/*.cpp"
)
add_library(example_mlir_ltc_backend SHARED ${LTC_BACKEND_CSRC})
add_dependencies(example_mlir_ltc_backend
torch_mlir_ltc_backend
)
target_link_libraries(example_mlir_ltc_backend
${TORCH_LIBRARIES}
${Python3_LIBRARIES}
torch_python
torch_mlir_ltc_backend
)

message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wno-pedantic")
set_target_properties(example_mlir_ltc_backend PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/ltc_backend/"
OUTPUT_NAME _EXAMPLE_MLIR_BACKEND
PREFIX "${PYTHON_MODULE_PREFIX}"
SUFFIX "${PYTHON_MODULE_EXTENSION}"
CXX_VISIBILITY_PRESET "hidden"
COMPILE_FLAGS "${TORCH_CXXFLAGS} -Wno-pedantic"
)
Empty file.
134 changes: 134 additions & 0 deletions examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
//===- backend_impl.cpp ---------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/shape.h>

#include <torch_mlir/csrc/backend/LazyNativeFunctions.h>
#include <torch_mlir/csrc/backend/aten_eager_fallback.h>
#include <torch_mlir/csrc/backend/backend_impl.h>
#include <torch_mlir/csrc/backend/mlir_lowering_context.h>
#include <torch_mlir/csrc/utils/debug.h>
#include <torch_mlir/csrc/utils/exception.h>

#include "backend_impl.h"

using namespace torch::lazy;

namespace torch {
namespace lazy {

struct ExampleMlirBackendDeviceType : public BackendDeviceType {
ExampleMlirBackendDeviceType(std::string device_type)
: device_type_(device_type) {}

std::string toString() const override { return device_type_; }

std::string device_type_;
};

class ExampleMlirBackendImpl : public torch::lazy::TorchMlirBackendImpl {
public:
ExampleMlirBackendImpl() : default_device_type_("Magic") {}

/**
* Configuration
* */
void SetRngSeed(size_t seed) const override {
std::cout << "RNG Seed Set to: " << seed << std::endl;
}

/**
* Lowering, Compilation, Execution
* */
std::vector<std::string>
GetCompilationDevices(const std::string &device,
c10::ArrayRef<std::string> devices) const override {
return std::vector<std::string>(devices.begin(), devices.end());
};

std::vector<ComputationPtr>
Compile(std::vector<ComputationPtr> instances) const override {
PRINT_FUNCTION();

// Vendor backend specific lowering can be exec here before returning.
for (const auto &instance : instances) {
std::cout << "Instance received at Compile: \n"
<< GetComputationBackendText(instance) << std::endl;
}

return instances;
}

std::vector<BackendDataPtr>
ExecuteComputation(Computation &computation,
c10::ArrayRef<BackendDataPtr> arguments,
const BackendDevice &device) const override {
PRINT_FUNCTION();

std::vector<torch::lazy::BackendDataPtr> results;

// For this demo we aren't performing a full execution, so we generate
// some dummy data to return based on the number of return vals in MLIR.
auto mlir_computation = static_cast<TorchMlirComputation *>(&computation);
for (unsigned i = 0; i < mlir_computation->num_results(); i++) {
// TODO(henrytu): This is a hack to generate dummy data. It
// currently results in a "tensor does not have a device" error.
// Also noticed that device does not correspond to our example
// device, so that will need to be investigated to see if it has
// anything to do with this problem.
results.push_back(std::make_shared<TorchMlirBackendData>(
device, arguments[0]->shape()));
}

return results;
}

/**
* Device Configuration
* */
std::shared_ptr<torch::lazy::BackendDeviceType> GetDefaultDeviceType() const {
return std::make_shared<BackendDeviceType>(default_device_type_);
}

void SetDefaultDeviceType(std::string device_type) {
default_device_type_ = ExampleMlirBackendDeviceType(device_type);
}

/**
* Debug/Metrics
* */
std::string
GetComputationBackendText(const ComputationPtr computation) const override {
auto mlir_computation =
static_cast<TorchMlirComputation *>(computation.get());
return mlir_computation->toString();
}

private:
ExampleMlirBackendDeviceType default_device_type_;
};

BackendImplInterface *GetExampleMlirBackendImpl() {
static ExampleMlirBackendImpl *example_mlir_backend_impl =
new ExampleMlirBackendImpl();
return example_mlir_backend_impl;
}

void InitExampleMlirBackend() {
at::RegisterTorchMlirLazyNativeFunctions();
register_mlir_ltc_eager_fallback();
static std::unique_ptr<BackendRegistrar> g_registrar;
g_registrar.reset(new BackendRegistrar(GetExampleMlirBackendImpl()));
}

} // namespace lazy
} // namespace torch
27 changes: 27 additions & 0 deletions examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//===- backend_impl.h -----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#pragma once

#include <torch/csrc/lazy/backend/backend_interface.h>

namespace at {
// This function is defined in the codegenerated RegisterLazy.cpp file.
extern TORCH_API void RegisterTorchMlirLazyNativeFunctions();
} // namespace at

namespace torch {
namespace lazy {

torch::lazy::BackendImplInterface *GetExampleMlirBackendImpl();

void InitExampleMlirBackend();

} // namespace lazy
} // namespace torch
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
//===- example_mlir_backend_pybind.cpp ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "torch/csrc/jit/python/pybind.h"
#include "torch/csrc/lazy/backend/backend_interface.h"

#include <exception>
#include <iostream>
#include <string>

#include "backend/backend_impl.h"
#include "utils/sys_utils.h"

namespace py = pybind11;

namespace {
bool verbose = sys_util::GetEnv("VERBOSE", false);

struct NoGilSection {
NoGilSection() : state(PyEval_SaveThread()) {}
~NoGilSection() { PyEval_RestoreThread(state); }
PyThreadState *state = nullptr;
};

/**
* @brief Install the plugin
*/
void Initialize() {
// Initialize the Example MLIR LTC Backend
torch::lazy::InitExampleMlirBackend();

// sanity check
const torch::lazy::BackendImplInterface *mlir_backend =
torch::lazy::GetExampleMlirBackendImpl();
const torch::lazy::BackendImplInterface *lazy_backend =
torch::lazy::getBackend();
if (lazy_backend != mlir_backend) {
std::cout << "Failed to initialize MLIR Lazy Backend" << std::endl;
throw std::runtime_error("Failed to initialize MLIR Lazy Backend");
}

if (verbose) {
std::cout << "MLIR LTC PyTorch Plugin Initialized." << std::endl;
}
}

/**
* @brief Uninstall the plugin
*/
void Shutdown() {
if (verbose) {
std::cout << "MLIR LTC PyTorch Plugin Shut down." << std::endl;
}
}
} // anonymous namespace

PYBIND11_MODULE(_EXAMPLE_MLIR_BACKEND, m) {
m.doc() = ("pybind11 for example MLIR LTC backend.");
m.def("_initialize", []() {
NoGilSection gil;
Initialize();
});
m.def("_shutdown", []() {
NoGilSection gil;
Shutdown();
});
}
26 changes: 26 additions & 0 deletions examples/ltc_backend/ltc_backend/csrc/utils/sys_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//===- sys_utils.h --------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#pragma once

#include <cstdlib>
#include <string>

namespace sys_util {

template <typename T>
T GetEnv(const std::string &name, const T &default_value = T(0)) {
const char *env = std::getenv(name.c_str());
if (!env) {
return default_value;
}
return T(std::atoi(env));
}

} // namespace sys_util
Loading

0 comments on commit eb1737d

Please sign in to comment.