-
Notifications
You must be signed in to change notification settings - Fork 496
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
411 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -166,3 +166,4 @@ else() | |
endif() | ||
|
||
add_subdirectory(test) | ||
add_subdirectory(examples) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
add_subdirectory(ltc_backend) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
########################################################################### | ||
# Setup PyTorch | ||
########################################################################### | ||
|
||
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/python/torch_mlir/cmake/modules") | ||
include(TorchMLIRPyTorch) | ||
TorchMLIRProbeForPyTorchInstall() | ||
find_package(Torch 1.11 REQUIRED) | ||
|
||
TorchMLIRConfigurePyTorch() | ||
|
||
########################################################################### | ||
# Setup Python development | ||
########################################################################### | ||
|
||
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/external/llvm-project/mlir/cmake/modules") | ||
include(MLIRDetectPythonEnv) | ||
mlir_configure_python_dev_packages() | ||
|
||
########################################################################### | ||
# Library definition | ||
########################################################################### | ||
|
||
include_directories(BEFORE | ||
${TORCH_INCLUDE_DIRS} | ||
${CMAKE_CURRENT_SOURCE_DIR} | ||
${CMAKE_CURRENT_BINARY_DIR} | ||
${Python3_INCLUDE_DIRS} | ||
${PYTHON_H_DIR} | ||
${PROJECT_SOURCE_DIR}/python | ||
) | ||
link_directories("${TORCH_INSTALL_PREFIX}/lib") | ||
link_directories(${CMAKE_CURRENT_SOURCE_DIR}/ltc_backend/lib) | ||
add_link_options(-Wl,-rpath,$ORIGIN/ltc_backend/lib) | ||
|
||
file(GLOB LTC_BACKEND_CSRC CONFIGURE_DEPENDS | ||
"ltc_backend/csrc/*.h" | ||
"ltc_backend/csrc/*.cc" | ||
"ltc_backend/csrc/*.cpp" | ||
"ltc_backend/csrc/*/*.h" | ||
"ltc_backend/csrc/*/*.cc" | ||
"ltc_backend/csrc/*/*.cpp" | ||
) | ||
add_library(example_mlir_ltc_backend SHARED ${LTC_BACKEND_CSRC}) | ||
add_dependencies(example_mlir_ltc_backend | ||
torch_mlir_ltc_backend | ||
) | ||
target_link_libraries(example_mlir_ltc_backend | ||
${TORCH_LIBRARIES} | ||
${Python3_LIBRARIES} | ||
torch_python | ||
torch_mlir_ltc_backend | ||
) | ||
|
||
message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wno-pedantic") | ||
set_target_properties(example_mlir_ltc_backend PROPERTIES | ||
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/ltc_backend/" | ||
OUTPUT_NAME _EXAMPLE_MLIR_BACKEND | ||
PREFIX "${PYTHON_MODULE_PREFIX}" | ||
SUFFIX "${PYTHON_MODULE_EXTENSION}" | ||
CXX_VISIBILITY_PRESET "hidden" | ||
COMPILE_FLAGS "${TORCH_CXXFLAGS} -Wno-pedantic" | ||
) |
Empty file.
134 changes: 134 additions & 0 deletions
134
examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
//===- backend_impl.cpp ---------------------------------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include <torch/csrc/lazy/backend/backend_data.h> | ||
#include <torch/csrc/lazy/backend/backend_device.h> | ||
#include <torch/csrc/lazy/backend/lowering_context.h> | ||
#include <torch/csrc/lazy/core/shape.h> | ||
|
||
#include <torch_mlir/csrc/backend/LazyNativeFunctions.h> | ||
#include <torch_mlir/csrc/backend/aten_eager_fallback.h> | ||
#include <torch_mlir/csrc/backend/backend_impl.h> | ||
#include <torch_mlir/csrc/backend/mlir_lowering_context.h> | ||
#include <torch_mlir/csrc/utils/debug.h> | ||
#include <torch_mlir/csrc/utils/exception.h> | ||
|
||
#include "backend_impl.h" | ||
|
||
using namespace torch::lazy; | ||
|
||
namespace torch { | ||
namespace lazy { | ||
|
||
struct ExampleMlirBackendDeviceType : public BackendDeviceType { | ||
ExampleMlirBackendDeviceType(std::string device_type) | ||
: device_type_(device_type) {} | ||
|
||
std::string toString() const override { return device_type_; } | ||
|
||
std::string device_type_; | ||
}; | ||
|
||
class ExampleMlirBackendImpl : public torch::lazy::TorchMlirBackendImpl { | ||
public: | ||
ExampleMlirBackendImpl() : default_device_type_("Magic") {} | ||
|
||
/** | ||
* Configuration | ||
* */ | ||
void SetRngSeed(size_t seed) const override { | ||
std::cout << "RNG Seed Set to: " << seed << std::endl; | ||
} | ||
|
||
/** | ||
* Lowering, Compilation, Execution | ||
* */ | ||
std::vector<std::string> | ||
GetCompilationDevices(const std::string &device, | ||
c10::ArrayRef<std::string> devices) const override { | ||
return std::vector<std::string>(devices.begin(), devices.end()); | ||
}; | ||
|
||
std::vector<ComputationPtr> | ||
Compile(std::vector<ComputationPtr> instances) const override { | ||
PRINT_FUNCTION(); | ||
|
||
// Vendor backend specific lowering can be exec here before returning. | ||
for (const auto &instance : instances) { | ||
std::cout << "Instance received at Compile: \n" | ||
<< GetComputationBackendText(instance) << std::endl; | ||
} | ||
|
||
return instances; | ||
} | ||
|
||
std::vector<BackendDataPtr> | ||
ExecuteComputation(Computation &computation, | ||
c10::ArrayRef<BackendDataPtr> arguments, | ||
const BackendDevice &device) const override { | ||
PRINT_FUNCTION(); | ||
|
||
std::vector<torch::lazy::BackendDataPtr> results; | ||
|
||
// For this demo we aren't performing a full execution, so we generate | ||
// some dummy data to return based on the number of return vals in MLIR. | ||
auto mlir_computation = static_cast<TorchMlirComputation *>(&computation); | ||
for (unsigned i = 0; i < mlir_computation->num_results(); i++) { | ||
// TODO(henrytu): This is a hack to generate dummy data. It | ||
// currently results in a "tensor does not have a device" error. | ||
// Also noticed that device does not correspond to our example | ||
// device, so that will need to be investigated to see if it has | ||
// anything to do with this problem. | ||
results.push_back(std::make_shared<TorchMlirBackendData>( | ||
device, arguments[0]->shape())); | ||
} | ||
|
||
return results; | ||
} | ||
|
||
/** | ||
* Device Configuration | ||
* */ | ||
std::shared_ptr<torch::lazy::BackendDeviceType> GetDefaultDeviceType() const { | ||
return std::make_shared<BackendDeviceType>(default_device_type_); | ||
} | ||
|
||
void SetDefaultDeviceType(std::string device_type) { | ||
default_device_type_ = ExampleMlirBackendDeviceType(device_type); | ||
} | ||
|
||
/** | ||
* Debug/Metrics | ||
* */ | ||
std::string | ||
GetComputationBackendText(const ComputationPtr computation) const override { | ||
auto mlir_computation = | ||
static_cast<TorchMlirComputation *>(computation.get()); | ||
return mlir_computation->toString(); | ||
} | ||
|
||
private: | ||
ExampleMlirBackendDeviceType default_device_type_; | ||
}; | ||
|
||
BackendImplInterface *GetExampleMlirBackendImpl() { | ||
static ExampleMlirBackendImpl *example_mlir_backend_impl = | ||
new ExampleMlirBackendImpl(); | ||
return example_mlir_backend_impl; | ||
} | ||
|
||
void InitExampleMlirBackend() { | ||
at::RegisterTorchMlirLazyNativeFunctions(); | ||
register_mlir_ltc_eager_fallback(); | ||
static std::unique_ptr<BackendRegistrar> g_registrar; | ||
g_registrar.reset(new BackendRegistrar(GetExampleMlirBackendImpl())); | ||
} | ||
|
||
} // namespace lazy | ||
} // namespace torch |
27 changes: 27 additions & 0 deletions
27
examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
//===- backend_impl.h -----------------------------------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#pragma once | ||
|
||
#include <torch/csrc/lazy/backend/backend_interface.h> | ||
|
||
namespace at { | ||
// This function is defined in the codegenerated RegisterLazy.cpp file. | ||
extern TORCH_API void RegisterTorchMlirLazyNativeFunctions(); | ||
} // namespace at | ||
|
||
namespace torch { | ||
namespace lazy { | ||
|
||
torch::lazy::BackendImplInterface *GetExampleMlirBackendImpl(); | ||
|
||
void InitExampleMlirBackend(); | ||
|
||
} // namespace lazy | ||
} // namespace torch |
73 changes: 73 additions & 0 deletions
73
examples/ltc_backend/ltc_backend/csrc/example_mlir_backend_pybind.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
//===- example_mlir_backend_pybind.cpp ------------------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "torch/csrc/jit/python/pybind.h" | ||
#include "torch/csrc/lazy/backend/backend_interface.h" | ||
|
||
#include <exception> | ||
#include <iostream> | ||
#include <string> | ||
|
||
#include "backend/backend_impl.h" | ||
#include "utils/sys_utils.h" | ||
|
||
namespace py = pybind11; | ||
|
||
namespace { | ||
bool verbose = sys_util::GetEnv("VERBOSE", false); | ||
|
||
struct NoGilSection { | ||
NoGilSection() : state(PyEval_SaveThread()) {} | ||
~NoGilSection() { PyEval_RestoreThread(state); } | ||
PyThreadState *state = nullptr; | ||
}; | ||
|
||
/** | ||
* @brief Install the plugin | ||
*/ | ||
void Initialize() { | ||
// Initialize the Example MLIR LTC Backend | ||
torch::lazy::InitExampleMlirBackend(); | ||
|
||
// sanity check | ||
const torch::lazy::BackendImplInterface *mlir_backend = | ||
torch::lazy::GetExampleMlirBackendImpl(); | ||
const torch::lazy::BackendImplInterface *lazy_backend = | ||
torch::lazy::getBackend(); | ||
if (lazy_backend != mlir_backend) { | ||
std::cout << "Failed to initialize MLIR Lazy Backend" << std::endl; | ||
throw std::runtime_error("Failed to initialize MLIR Lazy Backend"); | ||
} | ||
|
||
if (verbose) { | ||
std::cout << "MLIR LTC PyTorch Plugin Initialized." << std::endl; | ||
} | ||
} | ||
|
||
/** | ||
* @brief Uninstall the plugin | ||
*/ | ||
void Shutdown() { | ||
if (verbose) { | ||
std::cout << "MLIR LTC PyTorch Plugin Shut down." << std::endl; | ||
} | ||
} | ||
} // anonymous namespace | ||
|
||
PYBIND11_MODULE(_EXAMPLE_MLIR_BACKEND, m) { | ||
m.doc() = ("pybind11 for example MLIR LTC backend."); | ||
m.def("_initialize", []() { | ||
NoGilSection gil; | ||
Initialize(); | ||
}); | ||
m.def("_shutdown", []() { | ||
NoGilSection gil; | ||
Shutdown(); | ||
}); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
//===- sys_utils.h --------------------------------------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#pragma once | ||
|
||
#include <cstdlib> | ||
#include <string> | ||
|
||
namespace sys_util { | ||
|
||
template <typename T> | ||
T GetEnv(const std::string &name, const T &default_value = T(0)) { | ||
const char *env = std::getenv(name.c_str()); | ||
if (!env) { | ||
return default_value; | ||
} | ||
return T(std::atoi(env)); | ||
} | ||
|
||
} // namespace sys_util |
Oops, something went wrong.