Skip to content

Commit

Permalink
Added JIT to MLIR lowering (#724)
Browse files Browse the repository at this point in the history
* Added JIT to MLIR lowering

Lowering to JIT is performed in a way similar to how it's done in the TS LTC backend. After a jit::Graph is constructed, it gets converted to a jit::Function, which is fed into the existing utility to generate an MlirModule in torch-mlir.

* Renamed `csrc/backend` to `csrc/base_lazy_backend`
  • Loading branch information
henrytwo authored and antoniojkim committed Jul 19, 2022
1 parent d847721 commit c4ce4a9
Show file tree
Hide file tree
Showing 19 changed files with 908 additions and 206 deletions.
18 changes: 9 additions & 9 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ __pycache__
# Bazel
bazel-*

# Autogenerated files
/generated_native_functions.yaml
/generated_backend.hash
/python/torch_mlir/csrc/backend/LazyIr.h
/python/torch_mlir/csrc/backend/LazyNativeFunctions.cpp
/python/torch_mlir/csrc/backend/LazyNativeFunctions.h
/python/torch_mlir/csrc/backend/GenLazyShapeInference.cpp
/python/torch_mlir/csrc/backend/RegisterLazy.cpp

# Libraries
*.so
*.a

# Autogenerated files
/generated_native_functions.yaml
/generated_backend.hash
/python/torch_mlir/csrc/base_lazy_backend/LazyIr.h
/python/torch_mlir/csrc/base_lazy_backend/LazyNativeFunctions.cpp
/python/torch_mlir/csrc/base_lazy_backend/LazyNativeFunctions.h
/python/torch_mlir/csrc/base_lazy_backend/GenLazyShapeInference.cpp
/python/torch_mlir/csrc/base_lazy_backend/RegisterLazy.cpp
2 changes: 1 addition & 1 deletion build_tools/autogen_ltc_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def main(args):
"generated_native_functions.yaml"
)
backend_path = TORCH_MLIR_DIR.joinpath(
"python", "torch_mlir", "csrc", "backend"
"python", "torch_mlir", "csrc", "base_lazy_backend"
)
assert backend_path.is_dir()

Expand Down
21 changes: 13 additions & 8 deletions python/torch_mlir/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,23 @@ link_directories("${TORCH_INSTALL_PREFIX}/lib")


add_library(torch_mlir_ltc_backend SHARED
backend/backend_impl.cpp
backend/LazyNativeFunctions.cpp
backend/LazyShapeInference.cpp
backend/GenLazyShapeInference.cpp
backend/mlir_lowering_context.cpp
backend/mlir_native_functions.cpp
backend/mlir_node.cpp
backend/RegisterLazy.cpp
base_lazy_backend/backend_impl.cpp
base_lazy_backend/LazyNativeFunctions.cpp
base_lazy_backend/LazyShapeInference.cpp
base_lazy_backend/GenLazyShapeInference.cpp
base_lazy_backend/mlir_lowering_context.cpp
base_lazy_backend/mlir_native_functions.cpp
base_lazy_backend/mlir_node.cpp
base_lazy_backend/mlir_node_lowering.cpp
base_lazy_backend/RegisterLazy.cpp
)

add_dependencies(torch_mlir_ltc_backend
TorchMLIRJITIRImporter
)
target_link_libraries(torch_mlir_ltc_backend
TorchMLIRAggregateCAPI
TorchMLIRJITIRImporter
${TORCH_LIBRARIES}
${Python3_LIBRARIES}
torch_python
Expand Down
107 changes: 0 additions & 107 deletions python/torch_mlir/csrc/backend/mlir_lowering_context.cpp

This file was deleted.

72 changes: 0 additions & 72 deletions python/torch_mlir/csrc/backend/mlir_lowering_context.h

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(
TorchMlirBackendData::Info* info =
dynamic_cast<TorchMlirBackendData::Info*>(data->info());
TORCH_CHECK(
info, "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info.");
info,
"Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info.");
return info->tensor;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ class TORCH_API TorchMlirBackendData : public BackendData {

TorchMlirBackendData(BackendDevice device, Shape shape);
TorchMlirBackendData(const at::Scalar& scalar, BackendDevice device);
TorchMlirBackendData(const at::Tensor& tensor, BackendDevice device, Shape shape);
TorchMlirBackendData(
const at::Tensor& tensor, BackendDevice device, Shape shape);

virtual BackendData::Handle GetHandle() override;

Expand Down
Loading

0 comments on commit c4ce4a9

Please sign in to comment.