diff --git a/README.md b/README.md index 31dd3ec2cc7..5ccf94a32fa 100644 --- a/README.md +++ b/README.md @@ -26,9 +26,8 @@ We have few paths to lower down to the Torch MLIR Dialect. - TorchScript This is the most tested path down to Torch MLIR Dialect, and the PyTorch ecosystem is converging on using TorchScript IR as a lingua franca. - - LazyTensorCore (Based on the PyTorch [`lazy_tensor_staging` branch](https://github.com/pytorch/pytorch/tree/lazy_tensor_staging/lazy_tensor_core)) - This path provides the upcoming LTC path of capture. It is based of an unstable devel branch but is the closest way for you to adapt any existing `torch/xla` derivatives. - + - LazyTensorCore + Read more details [here](docs/ltc_backend.md). ## Project Communication - `#torch-mlir` channel on the LLVM [Discord](https://discord.gg/xS7Z362) - this is the most active communication channel @@ -71,11 +70,9 @@ torch-mlir prediction [('Labrador retriever', 70.66320037841797), ('golden retriever', 4.956601619720459), ('Chesapeake Bay retriever', 4.195651531219482)] ``` -### LazyTensorCore - -Lazy Tensor Core support is provided through an abstract [`TorchMlirBackendImpl`](python/torch_mlir/csrc/base_lazy_backend/backend_impl.h) class. An example implementation is available [here](examples/ltc_backend/ltc_backend). +### Lazy Tensor Core -There are also examples of a [HuggingFace BERT](examples/ltc_backend_bert.py) and [MNIST model](examples/ltc_backend_mnist.py) running on the example/reference LTC backend. +View examples [here](docs/ltc_examples.md). ### Eager Mode diff --git a/docs/ltc_backend.md b/docs/ltc_backend.md new file mode 100644 index 00000000000..1e60fffe40d --- /dev/null +++ b/docs/ltc_backend.md @@ -0,0 +1,132 @@ +# Torch-MLIR Lazy Tensor Core Backend + +## Table of Contents +- [Introduction](#introduction) +- [Examples](#examples) +- [Code Structure](#code-structure) +- [Architecture](#architecture) +- [Implementing a custom backend](#implementing-a-custom-backend) +- [Future Expansion](#future-expansion) + +## Introduction +[Lazy Tensor Core](https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/tutorial.md) is a tracing system in PyTorch which is supported as an entry point to Torch-MLIR. +After registering an LTC backend, all operations performed on lazy tensors are recorded and handed off to the backend implementation. + +LTC support is provided through an abstract [`TorchMlirBackendImpl`](../python/torch_mlir/csrc/base_lazy_backend/backend_impl.h) class, which handles the conversion to MLIR. +Implementations based on this abstract class will be able to specify their own compile and execution workflows. +Additional details about how to implement a custom backend is available [below](#Implementing-a-custom-backend). + +## Examples +View examples [here](ltc_examples.md). + +## Code Structure + +### Autogen Build Tools ([`build_tools`](../build_tools)) + +- `autogen_ltc_backend.{py,yaml}` + - The [autogen files](#autogen-files) are generated by this script based on the list of supported ops, which includes all ops from [`GeneratedTorchOps.td`](https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td), + excluding those explicitly blacklisted in the YAML file + +### Autogen Files ([`python/torch_mlir/csrc/base_lazy_backend/generated`](../python/torch_mlir/csrc/base_lazy_backend/generated)) +Generated files are created in this directory, which is ignored by version control. + +- `LazyIr.h` + - Definitions of `torch::lazy:TorchMlirNode` subclasses for each supported autogen op +- `LazyNativeFunctions.{cpp,h}` + - Native function definitions for each supported op (handles `at::Tensor -> at::Tensor` data flow and creation of `torch::lazy:TorchMlirNode`) +- `LazyNonNativeIr.h` + - Non-native `torch::lazy:TorchMlirNode` subclasses +- `RegisterLazy.cpp` + - Registers PyTorch kernels under the `lazy` dispatch key for all supported ops, which map to our native functions +- `shape_inference.{cpp,h}` + - Shape inference headers for supported ops and autogen'd placeholders for unimplemented functions + +### Base Backend ([`python/torch_mlir/csrc/base_lazy_backend`](../python/torch_mlir/csrc/base_lazy_backend)) + +- `backend_impl.{cpp,h}` + - Base LTC backend to setup Torch-MLIR lowering context +- `dynamic_ir.{cpp,h}` + - Manually implemented "dynamic" nodes +- `ir_builder.h` + - Torch-MLIR implementation of `torch::lazy::IrBuilder` +- `mlir_lowering_context.h` + - Handles conversion from `torch::lazy::Node` to MLIR via JIT and Torch-MLIR infrastructure +- `mlir_native_functions.cpp` + - Manually implemented native functions +- `mlir_node.{cpp,h}` + - Torch-MLIR implementation of `torch::lazy::Node` +- `mlir_node_lowering.{cpp,h}` + - Lower a `torch::lazy::Node` to JIT graph in preparation for MLIR generation +- `shape_inference.cpp` + - Implementation of select shape inference functions (most functions are [implemented upstream](https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/core/shape_inference.cpp)) + +### Examples ([`examples`](../examples)) + +- `examples/ltc_backend/ltc_backend/csrc/backend/backend_impl.{cpp,h}` + - Example Torch-MLIR LTC backend implementation, which simply stores the MLIR as a string and executes computation on CPU +- `examples/ltc_backend/ltc_backend/csrc/example_mlir_backend_pybind.cpp` + - pybind for example Torch-MLIR LTC backend +- `ltc_backend_bert.py` + - Example HuggingFace BERT model traced by LTC to MLIR +- `ltc_backend_mnist.py` + - Example MNIST model traced by LTC to MLIR + +## Architecture + +### Tracing LTC graph + +The journey begins with a tensor in PyTorch on the `lazy` device, which may undergo a number of operations during its lifetime. +```python +>>> ltc_backend._initialize() +>>> x = torch.tensor(..., device='lazy') +>>> y = torch.tanh(x) +... +``` +The call to `torch.tanh` triggers a chain of events. PyTorch checks the dispatch table under the `lazy` key and finds the kernel for `tanh` +previously registered in `RegisterLazy.cpp`. + +Next, `LazyNativeFunctions::tanh` from `LazyNativeFunctions.cpp` is called, which triggers the creation of a `Tanh` node, which is a subclass of `TorchMlirNode` and `torch::lazy::Node`, defined in `LazyIr.h`. +These nodes are then tracked internally by LTC as the computation graph is traced out. + +![Tracing Tensors](ltc_images/tracing_tensors.jpg) + +### Syncing Tensors + +At some point, the tensors will be synced in order to execute the computation -- either explicitly via `mark_step`, or implicitly through some operation that requires the contents of the tensors (e.g. printing to console). + +```python +>>> torch._lazy.mark_step() +``` + +This triggers a call to `LazyGraphExecutor::SyncLiveTensorsGraph` somewhere in the guts of LTC, which collects all the `TorchMlirNode`s (technically `torch::lazy::Node`s at this point) from the current trace and +creates an instance of `TorchMlirLoweringContext`. Here, the `TorchMlirNode`s are lowered to JIT via `mlir_node_lowering.cpp` and inserted into a `jit::Graph`. + +Next, `TorchMlirLoweringContext::Build` is executed and the final `jit::Graph` is sent to `torch_mlir::importJitFunctionAsFuncOp` to generate MLIR using the existing infrastructure from Torch-MLIR. +At this point, a `TorchMlirComputation` is created containing the final `mlir::FuncOp`. + +![Syncing Tensors](ltc_images/syncing_tensors.jpg) + +### Final Compilation and Execution + +The `TorchMlirComputation` is sent to the vendor specific implementation of `TorchMlirBackendImpl::Compile` to be handed off to the vendor's compilation stack (if applicable). + +Finally, the compiled computation is sent to `TorchMlirBackendImpl::ExecuteComputation` to be executed on the vendor device, which produces some results to be send back to PyTorch. + +![Vendor Execution](ltc_images/vendor_execution.jpg) + +## Implementing a custom backend + +An example implementation of a custom backend is available [here](../examples/ltc_backend/ltc_backend). +All the work involved with generating MLIR is handled in the base LTC backend, so vendors only need to worry about implementing `Compile`, `ExecuteComputation`, and some other minor methods to interface with the device. + +A pybind is needed to invoke C++ code to register the autogen PyTorch kernels and the custom backend itself. +Most of the code in the example implementation should be reusable, excluding some debug related function (e.g. `get_latest_computation`). + +## Future Expansion + +There are a number of areas for future improvement: +- Generate source information in `jit::Graph` so it can be embedded in the MLIR +- Currently the example backend implementation executes via the `jit::Graph` instead of the MLIR since we currently lack lowerings for many ops, which would make it difficult to run models such as HF BERT + - In the future, we should change the implementation to lower the MLIR to linalg and execute on a reference backend +- As new models get tested, we will inevitably run into errors related to unimplemented shape inference functions. +This problem is simply solved by implementing the missing function, or adding a structured kernel to PyTorch. diff --git a/docs/ltc_examples.md b/docs/ltc_examples.md new file mode 100644 index 00000000000..306dabb8a72 --- /dev/null +++ b/docs/ltc_examples.md @@ -0,0 +1,54 @@ +# Torch-MLIR Lazy Tensor Core Backend Examples + +Refer to the main documentation [here](ltc_backend.md). + +## Example Usage +```python +import torch +import torch._lazy +import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend + +# Register the example LTC backend. +ltc_backend._initialize() + +device = 'lazy' + +# Create some tensors and perform operations. +inputs = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.float32, device=device) +outputs = torch.tanh(inputs) + +# Mark end of training/evaluation iteration and lower traced graph. +torch._lazy.mark_step() +print('Results:', outputs) + +# Optionally dump MLIR graph generated from LTC trace. +computation = ltc_backend.get_latest_computation() +if computation: + print(computation.debug_string()) +``` + +``` +Received 1 computation instances at Compile! +Received 1 arguments, and returned 2 results during ExecuteCompile! + +Results: tensor([[0.7616, 0.9640, 0.9951, 0.9993, 0.9999]], device='lazy:0') + +JIT Graph: +graph(%p0 : Float(1, 5)): + %1 : Float(1, 5) = aten::tanh(%p0) + return (%p0, %1) + +MLIR: +func.func @graph(%arg0: !torch.vtensor<[1,5],f32>) -> (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,5],f32>) { + %0 = torch.aten.tanh %arg0 : !torch.vtensor<[1,5],f32> -> !torch.vtensor<[1,5],f32> + return %arg0, %0 : !torch.vtensor<[1,5],f32>, !torch.vtensor<[1,5],f32> +} + +Input/Output Alias Mapping: +Output: 0 -> Input param: 0 + +In Mark Step: true +``` + +## Example Models +There are also examples of a [HuggingFace BERT](../examples/ltc_backend_bert.py) and [MNIST](../examples/ltc_backend_mnist.py) model running on the example LTC backend. diff --git a/docs/ltc_images/syncing_tensors.jpg b/docs/ltc_images/syncing_tensors.jpg new file mode 100644 index 00000000000..8965f75f428 Binary files /dev/null and b/docs/ltc_images/syncing_tensors.jpg differ diff --git a/docs/ltc_images/tracing_tensors.jpg b/docs/ltc_images/tracing_tensors.jpg new file mode 100644 index 00000000000..d59d58e8598 Binary files /dev/null and b/docs/ltc_images/tracing_tensors.jpg differ diff --git a/docs/ltc_images/vendor_execution.jpg b/docs/ltc_images/vendor_execution.jpg new file mode 100644 index 00000000000..9a96dd29d06 Binary files /dev/null and b/docs/ltc_images/vendor_execution.jpg differ