From d847721b0e764696d8be352259f8523b0ac4dffc Mon Sep 17 00:00:00 2001 From: "Jae Hoon (Antonio) Kim" <17433012+antoniojkim@users.noreply.github.com> Date: Wed, 13 Apr 2022 15:42:02 -0400 Subject: [PATCH] Fix Torch-MLIR LTC Backend based off latest PyTorch master (#723) * Changes as a result of the LTC TS backend decoupling * Fix bugs in BackendImpl and codegen * Fix based on latest PyTorch master --- .gitignore | 2 +- build_tools/autogen_ltc_backend.py | 58 ++++- build_tools/autogen_ltc_backend.yaml | 4 +- python/torch_mlir/csrc/CMakeLists.txt | 4 +- .../csrc/backend/LazyShapeInference.h | 1 + .../csrc/backend/aten_eager_fallback.cpp | 58 ----- .../csrc/backend/aten_eager_fallback.h | 27 -- .../torch_mlir/csrc/backend/backend_impl.cpp | 80 +++--- python/torch_mlir/csrc/backend/backend_impl.h | 19 +- .../csrc/backend/mlir_lowering_context.cpp | 38 ++- .../csrc/backend/mlir_lowering_context.h | 15 +- ...lir_type.cpp => mlir_native_functions.cpp} | 158 +++++------- python/torch_mlir/csrc/backend/mlir_node.cpp | 109 +------- python/torch_mlir/csrc/backend/mlir_node.h | 61 +---- python/torch_mlir/csrc/tensor_aten_ops.cpp | 242 ------------------ python/torch_mlir/csrc/tensor_aten_ops.h | 79 ------ 16 files changed, 218 insertions(+), 737 deletions(-) delete mode 100644 python/torch_mlir/csrc/backend/aten_eager_fallback.cpp delete mode 100644 python/torch_mlir/csrc/backend/aten_eager_fallback.h rename python/torch_mlir/csrc/backend/{aten_ltc_mlir_type.cpp => mlir_native_functions.cpp} (71%) delete mode 100644 python/torch_mlir/csrc/tensor_aten_ops.cpp delete mode 100644 python/torch_mlir/csrc/tensor_aten_ops.h diff --git a/.gitignore b/.gitignore index 9b0f788538d..846e39906dc 100644 --- a/.gitignore +++ b/.gitignore @@ -26,7 +26,7 @@ bazel-* # Autogenerated files /generated_native_functions.yaml /generated_backend.hash -/python/torch_mlir/csrc/backend/LazyLazyIr.h +/python/torch_mlir/csrc/backend/LazyIr.h /python/torch_mlir/csrc/backend/LazyNativeFunctions.cpp /python/torch_mlir/csrc/backend/LazyNativeFunctions.h /python/torch_mlir/csrc/backend/GenLazyShapeInference.cpp diff --git a/build_tools/autogen_ltc_backend.py b/build_tools/autogen_ltc_backend.py index fe2b5b31dba..815e1e7080d 100644 --- a/build_tools/autogen_ltc_backend.py +++ b/build_tools/autogen_ltc_backend.py @@ -24,6 +24,10 @@ from codegen.model import NativeFunctionsGroup +def isOptionalCType(arg): + return str(type(arg)) == "" + + def generate_native_functions( config_path: Path, torch_ops_file: Path, out_file: Path ): @@ -98,7 +102,7 @@ def get_native_function_name(f): yaml.dump( { "backend": "Lazy", - "cpp_namespace": "torch_lazy_tensors", + "cpp_namespace": "torch::lazy", "full_codegen": opnames, "supported": sorted(supported_ops), }, @@ -120,21 +124,46 @@ def get_native_function_name(f): @dataclass(frozen=True) class MlirLazyIr(codegen.gen_lazy_tensor.dest.LazyIR): - lowering_function_type: str = "torch::lazy::MlirFunction" - lowering_context_type: str = "torch::lazy::MlirLoweringContext*" - lowering_return_type: str = "torch::lazy::MlirOpVector" - def lowering_body(self, f): + def lowering_function(self, f): func = ( f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func ) schema = LazyIrSchema(func) + emplace_arguments = [] + for arg in schema.positional_args: + if arg.is_lazy_value: + if isOptionalCType(arg.lazy_type): + emplace_arguments.append(f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr") + continue + emplace_arguments.append('loctx->GetOutputOp(operand(i++))') + continue + emplace_arguments.append(f'"{arg.name}", {arg.name}') + + emplace_arguments_str = "\n ".join( + [f"arguments.emplace_back({a});" for a in emplace_arguments]) + emplace_kwarg_values = [f'"{t.name}", loctx->GetOutputOp(operand(i++))' for t in schema.keyword_values] + emplace_kwarg_scalars = [f'"{t.name}", {t.name}' for t in schema.keyword_scalars] + emplace_kwarguments = "\n ".join( + [f"kwarguments.emplace_back({a});" for a in emplace_kwarg_values + emplace_kwarg_scalars]) + return f""" - UNIMPLEMENTED_ERROR( - "'{func}' lowering not yet implemented" - ); - """.rstrip() + TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override {{ + PRINT_FUNCTION(); + std::vector arguments; + std::vector kwarguments; + arguments.reserve({len(emplace_arguments)}); + kwarguments.reserve({len(emplace_kwarg_values + emplace_kwarg_scalars)}); + size_t i = 0; + {emplace_arguments_str} + {emplace_kwarguments} + torch::lazy::TorchMlirOpVector {schema.aten_name}_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, arguments, kwarguments); + CHECK_EQ({schema.aten_name}_out.size(), {len(func.returns)}); + + return {schema.aten_name}_out; + }} + """.strip() def generate_backend( @@ -151,14 +180,14 @@ def gen_fallback_code(*args, **kwargs): codegen.dest.lazy_ir.gen_fallback_code = gen_fallback_code - codegen.gen_lazy_tensor.run( + codegen.gen_lazy_tensor.run_gen_lazy_tensor( backend_name="TorchMlir", + aten_path=str(TORCH_DIR.joinpath("aten", "src", "ATen")), source_yaml=str(source_yaml), output_dir=str(backend_path), dry_run=False, - impl_path=str(backend_path.joinpath("aten_ltc_mlir_type.cpp")), - gen_ts_lowerings=False, - node_base="torch::lazy::MlirNode", + impl_path=str(backend_path.joinpath("mlir_native_functions.cpp")), + node_base="torch::lazy::TorchMlirNode", node_base_hdr=str(backend_path.joinpath("mlir_node.h")), tensor_class="torch::lazy::LazyTensor", tensor_class_hdr="torch/csrc/lazy/core/tensor.h", @@ -298,7 +327,6 @@ def main(args): new_hash = m.hexdigest().strip() if args.force or new_hash != prev_hash: - hash_file.write_text(new_hash) parsed_yaml, grouped_native_functions = generate_native_functions( config_path, torch_ops_file, native_functions ) @@ -310,6 +338,8 @@ def main(args): grouped_native_functions, ) + hash_file.write_text(new_hash) + if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/build_tools/autogen_ltc_backend.yaml b/build_tools/autogen_ltc_backend.yaml index 573eee944b2..f0dee793625 100644 --- a/build_tools/autogen_ltc_backend.yaml +++ b/build_tools/autogen_ltc_backend.yaml @@ -37,10 +37,9 @@ supported: - empty - expand - fill_ -# - native_batch_norm_backward - native_batch_norm +# - native_batch_norm_backward - permute -- repeat - squeeze - t - unsqueeze @@ -50,3 +49,4 @@ additional_ops: # Additional ops to support that are not supported by Torch-MLIR explicitly - _copy_from - _copy_from_and_resize +- native_batch_norm_backward diff --git a/python/torch_mlir/csrc/CMakeLists.txt b/python/torch_mlir/csrc/CMakeLists.txt index 9f52a062773..538362d63f4 100644 --- a/python/torch_mlir/csrc/CMakeLists.txt +++ b/python/torch_mlir/csrc/CMakeLists.txt @@ -20,16 +20,14 @@ link_directories("${TORCH_INSTALL_PREFIX}/lib") add_library(torch_mlir_ltc_backend SHARED - backend/aten_eager_fallback.cpp - backend/aten_ltc_mlir_type.cpp backend/backend_impl.cpp backend/LazyNativeFunctions.cpp backend/LazyShapeInference.cpp backend/GenLazyShapeInference.cpp backend/mlir_lowering_context.cpp + backend/mlir_native_functions.cpp backend/mlir_node.cpp backend/RegisterLazy.cpp - tensor_aten_ops.cpp ) target_link_libraries(torch_mlir_ltc_backend diff --git a/python/torch_mlir/csrc/backend/LazyShapeInference.h b/python/torch_mlir/csrc/backend/LazyShapeInference.h index 5aae29c8399..c614a535540 100644 --- a/python/torch_mlir/csrc/backend/LazyShapeInference.h +++ b/python/torch_mlir/csrc/backend/LazyShapeInference.h @@ -69,6 +69,7 @@ TORCH_API std::vector compute_shape_new_zeros(const at::Tensor & self, at TORCH_API std::vector compute_shape_rand_like(const at::Tensor & self, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, c10::optional memory_format); TORCH_API std::vector compute_shape_relu(const at::Tensor & self); TORCH_API std::vector compute_shape_relu_(at::Tensor & self); +TORCH_API std::vector compute_shape_repeat(const at::Tensor & self, at::IntArrayRef repeats); TORCH_API std::vector compute_shape_reshape(const at::Tensor & self, at::IntArrayRef shape); TORCH_API std::vector compute_shape_rsub(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); TORCH_API std::vector compute_shape_select(const at::Tensor & self, int64_t dim, int64_t index); diff --git a/python/torch_mlir/csrc/backend/aten_eager_fallback.cpp b/python/torch_mlir/csrc/backend/aten_eager_fallback.cpp deleted file mode 100644 index fd7319c18e6..00000000000 --- a/python/torch_mlir/csrc/backend/aten_eager_fallback.cpp +++ /dev/null @@ -1,58 +0,0 @@ -//===- aten_eager_fallback.cpp --------------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// -// This file is adapted from pytorch/pytorch -// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/aten_eager_fallback.cpp -//===----------------------------------------------------------------------===// - -#include -#include - -#include -#include - -#include "../utils/exception.h" -#include "aten_eager_fallback.h" - -namespace torch_lazy_tensors { - -static std::unordered_map - _eager_fallback_counters; - -bool force_eager_fallback(c10::Symbol op) { - static char* force_str = std::getenv("LTC_FORCE_FALLBACK"); - if (force_str != nullptr) { - static auto force_sym = c10::Symbol::fromQualString(std::string(force_str)); - if (op == force_sym) { - std::cout << "MLIR force_eager_fallback(" << force_str << "): true" - << std::endl; - return true; - } - } - std::cout << "MLIR force_eager_fallback(" << op.toQualString() << "): false" - << std::endl; - return false; -} - -void ltc_eager_fallback( - const c10::OperatorHandle& op, torch::jit::Stack* stack) { - const auto name = c10::toString(op.operator_name()); - UNSUPPORTED_ERROR( - "MLIR ltc_eager_fallback is not supported for op: " << name); -} - -std::function register_mlir_ltc_eager_fallback; - -TORCH_LIBRARY_IMPL(_, Lazy, m) { - register_mlir_ltc_eager_fallback = [&]() { - m.fallback( - torch::CppFunction::makeFromBoxedFunction<<c_eager_fallback>()); - }; -} - -} // namespace torch_lazy_tensors diff --git a/python/torch_mlir/csrc/backend/aten_eager_fallback.h b/python/torch_mlir/csrc/backend/aten_eager_fallback.h deleted file mode 100644 index 7c6218a3c89..00000000000 --- a/python/torch_mlir/csrc/backend/aten_eager_fallback.h +++ /dev/null @@ -1,27 +0,0 @@ -//===- aten_eager_fallback.h ----------------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// -// Facilitates eager fallback behaviour -// -// This file is adapted from pytorch/pytorch -// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/aten_eager_fallback.h -//===----------------------------------------------------------------------===// - -#pragma once - -#include - -namespace torch_lazy_tensors { - -bool force_eager_fallback(c10::Symbol op); -void ltc_eager_fallback( - const c10::OperatorHandle& op, torch::jit::Stack* stack); - -extern TORCH_API std::function register_mlir_ltc_eager_fallback; - -} // namespace torch_lazy_tensors diff --git a/python/torch_mlir/csrc/backend/backend_impl.cpp b/python/torch_mlir/csrc/backend/backend_impl.cpp index 75896416531..dfd5f4c2401 100644 --- a/python/torch_mlir/csrc/backend/backend_impl.cpp +++ b/python/torch_mlir/csrc/backend/backend_impl.cpp @@ -7,7 +7,7 @@ // //===----------------------------------------------------------------------===// // This file is adapted from pytorch/pytorch -// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/backend_impl.cpp +// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp //===----------------------------------------------------------------------===// #include @@ -23,77 +23,79 @@ namespace torch { namespace lazy { -MlirBackendData::MlirBackendData(BackendDevice device, Shape shape) - : BackendData(device, shape) { +TorchMlirBackendData::TorchMlirBackendData(BackendDevice device, Shape shape) + : BackendData(device, shape), + info_(std::make_unique()) { PRINT_FUNCTION(); - auto info = std::make_shared(); - SetInfo(info); } -MlirBackendData::MlirBackendData(const at::Scalar& scalar, BackendDevice device) - : BackendData(device, Shape(scalar.type(), {})) { +TorchMlirBackendData::TorchMlirBackendData( + const at::Scalar& scalar, BackendDevice device) + : BackendData(device, Shape(scalar.type(), {})), + info_(std::make_unique(scalar)) { PRINT_FUNCTION(); - auto info = std::make_shared(scalar); - SetInfo(info); } -MlirBackendData::MlirBackendData( +TorchMlirBackendData::TorchMlirBackendData( const at::Tensor& tensor, BackendDevice device, Shape shape) - : BackendData(device, shape) { + : BackendData(device, shape), + info_(std::make_unique(tensor)) { PRINT_FUNCTION(); - auto info = std::make_shared(tensor); - SetInfo(info); } -BackendData::Handle MlirBackendData::GetHandle() { +BackendData::Handle TorchMlirBackendData::GetHandle() { return reinterpret_cast(this); } -void MlirBackendData::Assign(const BackendData& data) { - MlirBackendData::Info* info = - dynamic_cast(data.info()); +void TorchMlirBackendData::Assign(const BackendData& data) { + TorchMlirBackendData::Info* info = + dynamic_cast(data.info()); TORCH_CHECK( - info, "Invalid Backend Data Pointer. Expected MlirBackendData::Info."); - auto new_info = std::make_shared(*info); - SetInfo(new_info); + info, + "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info."); + info_ = std::make_unique(*info); } -bool MlirBackendData::HasValue() const { return bool(info()); } +bool TorchMlirBackendData::HasValue() const { return bool(info_); } + +TorchMlirBackendData::Info* TorchMlirBackendData::mlir_info() const { + return info_.get(); +} /** * Initialization/Teardown * */ -void MlirBackendImpl::PrepareToExit() const {} +void TorchMlirBackendImpl::PrepareToExit() const {} /** * Data Transfer * */ -BackendDataPtr MlirBackendImpl::MakeComputationDataFromTensor( +BackendDataPtr TorchMlirBackendImpl::MakeComputationDataFromTensor( const at::Tensor& tensor, const Shape& shape, const BackendDevice& device) const { PRINT_FUNCTION(); - return std::make_shared(tensor, device, shape); + return std::make_shared(tensor, device, shape); } -BackendDataPtr MlirBackendImpl::MakeComputationDataFromScalar( +BackendDataPtr TorchMlirBackendImpl::MakeComputationDataFromScalar( const at::Scalar& scalar, const BackendDevice& device) const { PRINT_FUNCTION(); - return std::make_shared(scalar, device); + return std::make_shared(scalar, device); } -BackendDataPtr MlirBackendImpl::CreateDataPlaceholder( +BackendDataPtr TorchMlirBackendImpl::CreateDataPlaceholder( const BackendDevice& device, const Shape& shape) const { PRINT_FUNCTION(); - return std::make_shared(device, shape); + return std::make_shared(device, shape); } -at::Tensor MlirBackendImpl::MakeTensorFromComputationData( +at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData( const BackendDataPtr data, c10::optional logical_scalar_type) const { PRINT_FUNCTION(); - MlirBackendData::Info* info = - dynamic_cast(data->info()); + TorchMlirBackendData::Info* info = + dynamic_cast(data->info()); TORCH_CHECK( - info, "Invalid Backend Data Pointer. Expected MlirBackendData::Info."); + info, "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info."); return info->tensor; } @@ -101,20 +103,20 @@ at::Tensor MlirBackendImpl::MakeTensorFromComputationData( * Lowering, Compilation, Execution * */ -std::unique_ptr MlirBackendImpl::CreateLoweringContext( +std::unique_ptr TorchMlirBackendImpl::CreateLoweringContext( const std::string& name, BackendDevice device, c10::ArrayRef post_order, Util::EmissionMap emit_status) const { PRINT_FUNCTION(); - return std::make_unique( + return std::make_unique( name, std::forward(device), std::forward>(post_order), std::forward(emit_status)); } -std::unique_ptr MlirBackendImpl::CreateLoweringContext( +std::unique_ptr TorchMlirBackendImpl::CreateLoweringContext( const std::string& name, BackendDevice device) const { PRINT_FUNCTION(); - return std::make_unique( + return std::make_unique( name, std::forward(device)); } @@ -129,13 +131,13 @@ std::unique_ptr MlirBackendImpl::CreateLoweringContext( // Specify which aten device should be used for eager fallback // may change depending on current 'Default' DeviceType -at::DeviceType MlirBackendImpl::EagerFallbackDeviceType() const { +at::DeviceType TorchMlirBackendImpl::EagerFallbackDeviceType() const { PRINT_FUNCTION(); return at::DeviceType::CPU; } // Query all available backend devices -std::vector MlirBackendImpl::GetBackendDevices() const { +std::vector TorchMlirBackendImpl::GetBackendDevices() const { PRINT_FUNCTION(); return { GetBackendDevice(c10::Device(c10::kLazy, 0)), @@ -148,7 +150,7 @@ std::vector MlirBackendImpl::GetBackendDevices() const { // scenes. In the future, non-virtual c10:: devices may also use lazy tensors // through a mode, in which case these APIs should still work, but should be // identity mappings. -BackendDevice MlirBackendImpl::GetBackendDevice(c10::Device device) const { +BackendDevice TorchMlirBackendImpl::GetBackendDevice(c10::Device device) const { PRINT_FUNCTION(); return BackendDevice(GetDefaultDeviceType(), device.index()); } diff --git a/python/torch_mlir/csrc/backend/backend_impl.h b/python/torch_mlir/csrc/backend/backend_impl.h index 5b0e0d6331d..33e509f6fb4 100644 --- a/python/torch_mlir/csrc/backend/backend_impl.h +++ b/python/torch_mlir/csrc/backend/backend_impl.h @@ -10,7 +10,7 @@ // using the Torch-MLIR ATen dialect // // This file is adapted from pytorch/pytorch -// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/backend_impl.h +// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_backend_impl.h //===----------------------------------------------------------------------===// #pragma once @@ -23,7 +23,7 @@ namespace torch { namespace lazy { -class TORCH_API MlirBackendData : public BackendData { +class TORCH_API TorchMlirBackendData : public BackendData { public: struct Info : public BackendData::Info { at::Tensor tensor; @@ -39,20 +39,25 @@ class TORCH_API MlirBackendData : public BackendData { Info(const at::Scalar& scalar) : scalar{scalar}, requires_grad(false) {} }; - MlirBackendData(BackendDevice device, Shape shape); - MlirBackendData(const at::Scalar& scalar, BackendDevice device); - MlirBackendData(const at::Tensor& tensor, BackendDevice device, Shape shape); + TorchMlirBackendData(BackendDevice device, Shape shape); + TorchMlirBackendData(const at::Scalar& scalar, BackendDevice device); + TorchMlirBackendData(const at::Tensor& tensor, BackendDevice device, Shape shape); virtual BackendData::Handle GetHandle() override; virtual void Assign(const BackendData& data) override; virtual bool HasValue() const override; + + TorchMlirBackendData::Info* mlir_info() const; + +private: + std::unique_ptr info_; }; -class TORCH_API MlirBackendImpl : public BackendImplInterface { +class TORCH_API TorchMlirBackendImpl : public BackendImplInterface { public: - virtual ~MlirBackendImpl() = default; + virtual ~TorchMlirBackendImpl() = default; /** * Initialization/Teardown diff --git a/python/torch_mlir/csrc/backend/mlir_lowering_context.cpp b/python/torch_mlir/csrc/backend/mlir_lowering_context.cpp index e04374dbab0..d366080e256 100644 --- a/python/torch_mlir/csrc/backend/mlir_lowering_context.cpp +++ b/python/torch_mlir/csrc/backend/mlir_lowering_context.cpp @@ -7,22 +7,23 @@ // //===----------------------------------------------------------------------===// // This file is adapted from pytorch/pytorch -// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/ts_lowering_context.cpp +// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp //===----------------------------------------------------------------------===// #include +#include "../utils/debug.h" #include "../utils/exception.h" #include "mlir_lowering_context.h" namespace torch { namespace lazy { -MlirLoweringContext::MlirLoweringContext( +TorchMlirLoweringContext::TorchMlirLoweringContext( const std::string& name, BackendDevice device) : LoweringContext(name, std::forward(device)) {} -MlirLoweringContext::MlirLoweringContext( +TorchMlirLoweringContext::TorchMlirLoweringContext( const std::string& name, BackendDevice device, c10::ArrayRef post_order, Util::EmissionMap emit_status) : LoweringContext( @@ -30,29 +31,34 @@ MlirLoweringContext::MlirLoweringContext( std::forward>(post_order), std::forward(emit_status)) {} -int MlirComputation::parameters_size() const { UNIMPLEMENTED_FUNCTION_ERROR(); } +int TorchMlirComputation::parameters_size() const { UNIMPLEMENTED_FUNCTION_ERROR(); } const std::vector& -MlirComputation::parameter_shapes() const { +TorchMlirComputation::parameter_shapes() const { UNIMPLEMENTED_FUNCTION_ERROR(); } -const std::vector& MlirComputation::parameter_names() const { +const std::vector& TorchMlirComputation::parameter_names() const { UNIMPLEMENTED_FUNCTION_ERROR(); } -const torch::lazy::Shape& MlirComputation::result_shape() const { +const torch::lazy::Shape& TorchMlirComputation::result_shape() const { + UNIMPLEMENTED_FUNCTION_ERROR(); +} + +std::string TorchMlirComputation::to_string() const { UNIMPLEMENTED_FUNCTION_ERROR(); } // Get the shape of the result tuple component, given by index. -torch::lazy::Shape MlirLoweringContext::GetResultShape(size_t index) const { +torch::lazy::Shape TorchMlirLoweringContext::GetResultShape(size_t index) const { UNIMPLEMENTED_FUNCTION_ERROR(); } // Adds the given output as a component of the result tuple and returns its // assigned position within the tuple. -size_t MlirLoweringContext::AddResult(const torch::lazy::Output& output) { +size_t TorchMlirLoweringContext::AddResult(const torch::lazy::Output& output) { + PRINT_FUNCTION(); const torch::lazy::Node* node; auto it = emitted_outputs_.find(output); if (it == emitted_outputs_.end()) { @@ -75,7 +81,7 @@ size_t MlirLoweringContext::AddResult(const torch::lazy::Output& output) { // Associates the given output with the input parameter of the given index and // shape. Only used for the operator-by-operator execution, mostly for // debugging purposes. -void MlirLoweringContext::AddParameter( +void TorchMlirLoweringContext::AddParameter( const torch::lazy::Output& output, size_t index, const torch::lazy::Shape& shape, const std::string& name) { UNIMPLEMENTED_FUNCTION_ERROR(); @@ -83,10 +89,18 @@ void MlirLoweringContext::AddParameter( // Build the computation capturing all the operations created with the // embedded builder (returned by the builder() API). -ComputationPtr MlirLoweringContext::Build() { +ComputationPtr TorchMlirLoweringContext::Build() { + PRINT_FUNCTION() for (const torch::lazy::Node* output : result_tuple_) { } - return std::make_shared(); + return std::make_shared(); +} + +// Retrieves the lowered operation for an output. If the requested output is +// not available yet, the graph behind the output's Node is lowered, and the +// corresponding MLIR operation returned. +torch::jit::Value* GetOutputOp(const Output& output) { + UNIMPLEMENTED_FUNCTION_ERROR(); } } // namespace lazy diff --git a/python/torch_mlir/csrc/backend/mlir_lowering_context.h b/python/torch_mlir/csrc/backend/mlir_lowering_context.h index 963d2759795..c8e526b131f 100644 --- a/python/torch_mlir/csrc/backend/mlir_lowering_context.h +++ b/python/torch_mlir/csrc/backend/mlir_lowering_context.h @@ -7,7 +7,7 @@ // //===----------------------------------------------------------------------===// // This file is adapted from pytorch/pytorch -// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/torch/csrc/lazy/ts_backend/ts_lowering_context.h +// https://github.com/pytorch/pytorch/blob/torch/csrc/lazy/ts_backend/ts_lowering_context.h //===----------------------------------------------------------------------===// #pragma once @@ -19,7 +19,7 @@ namespace torch { namespace lazy { -class TORCH_API MlirComputation : public torch::lazy::Computation { +class TORCH_API TorchMlirComputation : public torch::lazy::Computation { public: int parameters_size() const override; @@ -31,11 +31,11 @@ class TORCH_API MlirComputation : public torch::lazy::Computation { virtual const torch::lazy::Shape& result_shape() const override; }; -class TORCH_API MlirLoweringContext : public torch::lazy::LoweringContext { +class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext { public: - MlirLoweringContext( + TorchMlirLoweringContext( const std::string& name, torch::lazy::BackendDevice device); - MlirLoweringContext( + TorchMlirLoweringContext( const std::string& name, torch::lazy::BackendDevice device, c10::ArrayRef post_order, torch::lazy::Util::EmissionMap emit_status); @@ -58,6 +58,11 @@ class TORCH_API MlirLoweringContext : public torch::lazy::LoweringContext { // embedded builder (returned by the builder() API). virtual torch::lazy::ComputationPtr Build() override; + // Retrieves the lowered operation for an output. If the requested output is + // not available yet, the graph behind the output's Node is lowered, and the + // corresponding MLIR operation returned. + torch::jit::Value* GetOutputOp(const Output& output); + private: std::vector result_tuple_; torch::lazy::OutputMap emitted_outputs_; diff --git a/python/torch_mlir/csrc/backend/aten_ltc_mlir_type.cpp b/python/torch_mlir/csrc/backend/mlir_native_functions.cpp similarity index 71% rename from python/torch_mlir/csrc/backend/aten_ltc_mlir_type.cpp rename to python/torch_mlir/csrc/backend/mlir_native_functions.cpp index 91d352414ef..b3726b1d6ba 100644 --- a/python/torch_mlir/csrc/backend/aten_ltc_mlir_type.cpp +++ b/python/torch_mlir/csrc/backend/mlir_native_functions.cpp @@ -7,7 +7,7 @@ // //===----------------------------------------------------------------------===// // This file is adapted from pytorch/pytorch -// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/aten_ltc_ts_type.cpp +// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_native_functions.cpp //===----------------------------------------------------------------------===// #include @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -24,13 +25,13 @@ #include "ATen/MetaFunctions.h" #include -#include "../tensor_aten_ops.h" #include "../utils/exception.h" #include "../utils/sys_utils.h" #include "LazyNativeFunctions.h" #include "LazyShapeInference.h" -namespace torch_lazy_tensors { +namespace torch { +namespace lazy { namespace { @@ -121,7 +122,7 @@ GetLtcDevice(const c10::optional& device) { // UNIMPLEMENTED_FUNCTION_ERROR(); // // return torch::lazy::CreateAtenFromLtcTensor( -// // lazy_tensor_aten_ops::bernoulli(self_tensor)); +// // torch::lazy::bernoulli(self_tensor)); // } // at::Tensor& LazyNativeFunctions::bernoulli_( @@ -133,7 +134,7 @@ GetLtcDevice(const c10::optional& device) { // auto self_tensor = torch::lazy::TryGetLtcTensor(self); // UNIMPLEMENTED_FUNCTION_ERROR(); -// // lazy_tensor_aten_ops::bernoulli_(self_tensor, p); +// // torch::lazy::bernoulli_(self_tensor, p); // // return self; // } @@ -208,7 +209,7 @@ at::Tensor LazyNativeFunctions::_copy_from( dst_tensor_data->copy_(self_tensor->ToTensor(/*detached=*/false)); } } else { - lazy_tensor_aten_ops::copy_(dst_tensor, self_tensor); + torch::lazy::copy_(dst_tensor, self_tensor); auto* impl = dynamic_cast(dst.unsafeGetTensorImpl()); impl->set_tensor(dst_tensor); @@ -260,15 +261,15 @@ at::Tensor LazyNativeFunctions::expand( const at::Tensor& self, at::IntArrayRef size, bool implicit) { TORCH_LAZY_FN_COUNTER("lazy::"); UNIMPLEMENTED_FUNCTION_ERROR(); - // return torch::lazy::CreateAtenFromLtcTensor(lazy_tensor_aten_ops::expand( - // torch::lazy::TryGetLtcTensor(self), size.vec())); + return torch::lazy::CreateAtenFromLtcTensor( + torch::lazy::expand(torch::lazy::TryGetLtcTensor(self), size.vec())); } at::Tensor& LazyNativeFunctions::fill_(at::Tensor& self, const at::Scalar& value) { TORCH_LAZY_FN_COUNTER("lazy::"); auto self_tensor = torch::lazy::TryGetLtcTensor(self); - lazy_tensor_aten_ops::fill_(self_tensor, value); + torch::lazy::fill_(self_tensor, value); return self; } @@ -280,110 +281,86 @@ LazyNativeFunctions::native_batch_norm( const c10::optional& running_var, bool training, double momentum, double eps) { TORCH_LAZY_FN_COUNTER("lazy::"); - torch::lazy::LazyTensorPtr input_tensor = torch::lazy::TryGetLtcTensor(input); + auto input_tensor = torch::lazy::TryGetLtcTensor(input); const torch::lazy::BackendDevice& device = input_tensor->GetDevice(); - torch::lazy::LazyTensorPtr running_mean_tensor = - GetOrCreateLtcTensor(running_mean, device); - torch::lazy::LazyTensorPtr running_var_tensor = - GetOrCreateLtcTensor(running_var, device); - UNIMPLEMENTED_FUNCTION_ERROR(); - // auto outputs = lazy_tensor_aten_ops::ts_native_batch_norm( - // torch::lazy::TryGetLtcTensor(input), GetOrCreateLtcTensor(weight, - // device), - // GetOrCreateLtcTensor(bias, device), running_mean_tensor, - // running_var_tensor, training, momentum, eps); - // return - // std::make_tuple(torch::lazy::CreateAtenFromLtcTensor(std::get<0>(outputs)), - // torch::lazy::CreateAtenFromLtcTensor(std::get<1>(outputs)), - // torch::lazy::CreateAtenFromLtcTensor(std::get<2>(outputs))); + auto running_mean_tensor = GetOrCreateLtcTensor(running_mean, device); + auto running_var_tensor = GetOrCreateLtcTensor(running_var, device); + auto outputs = torch::lazy::native_batch_norm( + torch::lazy::TryGetLtcTensor(input), GetOrCreateLtcTensor(weight, device), + GetOrCreateLtcTensor(bias, device), running_mean_tensor, + running_var_tensor, training, momentum, eps); + return std::make_tuple( + torch::lazy::CreateAtenFromLtcTensor(std::get<0>(outputs)), + torch::lazy::CreateAtenFromLtcTensor(std::get<1>(outputs)), + torch::lazy::CreateAtenFromLtcTensor(std::get<2>(outputs))); } -// std::tuple -// LazyNativeFunctions::native_batch_norm_backward( -// const at::Tensor& grad_out, const at::Tensor& input, -// const c10::optional& weight, -// const c10::optional& running_mean, -// const c10::optional& running_var, -// const c10::optional& save_mean, -// const c10::optional& save_invstd, bool train, double eps, -// std::array output_mask) { -// TORCH_LAZY_FN_COUNTER("lazy::"); -// torch::lazy::LazyTensor grad_out_tensor = -// torch::lazy::TryGetLtcTensor(grad_out); -// const torch::lazy::BackendDevice& device = grad_out_tensor.GetDevice(); -// torch::lazy::LazyTensor null_tensor; -// bool running_stats = running_mean && running_mean->defined(); -// CHECK_EQ(running_var && running_var->defined(), running_stats); -// UNIMPLEMENTED_FUNCTION_ERROR(); -// // auto gradients = lazy_tensor_aten_ops::ts_native_batch_norm_backward( -// // torch::lazy::TryGetLtcTensor(grad_out), -// torch::lazy::TryGetLtcTensor(input), -// // GetOrCreateLtcTensor(weight, device), -// // running_stats ? GetOrCreateLtcTensor(running_mean, device) -// // : null_tensor, -// // running_stats ? GetOrCreateLtcTensor(running_var, device) -// // : null_tensor, -// // GetOrCreateLtcTensor(save_mean, device), -// // GetOrCreateLtcTensor(save_invstd, device), train, eps, -// // output_mask); -// // at::Tensor undefined; -// // return std::make_tuple( -// // output_mask[0] ? -// torch::lazy::CreateAtenFromLtcTensor(std::get<0>(gradients)) -// // : undefined, -// // output_mask[1] ? -// torch::lazy::CreateAtenFromLtcTensor(std::get<1>(gradients)) -// // : undefined, -// // output_mask[2] ? -// torch::lazy::CreateAtenFromLtcTensor(std::get<2>(gradients)) -// // : undefined); -// } - -at::Tensor -LazyNativeFunctions::permute(const at::Tensor& self, at::IntArrayRef dims) { +std::tuple +LazyNativeFunctions::native_batch_norm_backward( + const at::Tensor& grad_out, const at::Tensor& input, + const c10::optional& weight, + const c10::optional& running_mean, + const c10::optional& running_var, + const c10::optional& save_mean, + const c10::optional& save_invstd, bool train, double eps, + std::array output_mask) { TORCH_LAZY_FN_COUNTER("lazy::"); - torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self); - UNIMPLEMENTED_FUNCTION_ERROR(); - // return torch::lazy::CreateAtenFromLtcTensor(lazy_tensor_aten_ops::permute( - // self_tensor, torch::lazy::ToI64Vector(dims))); + auto grad_out_tensor = torch::lazy::TryGetLtcTensor(grad_out); + const torch::lazy::BackendDevice& device = grad_out_tensor->GetDevice(); + torch::lazy::LazyTensorPtr null_tensor; + bool running_stats = running_mean && running_mean->defined(); + CHECK_EQ(running_var && running_var->defined(), running_stats); + auto gradients = torch::lazy::native_batch_norm_backward( + torch::lazy::TryGetLtcTensor(grad_out), + torch::lazy::TryGetLtcTensor(input), GetOrCreateLtcTensor(weight, device), + running_stats ? GetOrCreateLtcTensor(running_mean, device) : null_tensor, + running_stats ? GetOrCreateLtcTensor(running_var, device) : null_tensor, + GetOrCreateLtcTensor(save_mean, device), + GetOrCreateLtcTensor(save_invstd, device), train, eps, output_mask); + at::Tensor undefined; + return std::make_tuple( + output_mask[0] + ? torch::lazy::CreateAtenFromLtcTensor(std::get<0>(gradients)) + : undefined, + output_mask[1] + ? torch::lazy::CreateAtenFromLtcTensor(std::get<1>(gradients)) + : undefined, + output_mask[2] + ? torch::lazy::CreateAtenFromLtcTensor(std::get<2>(gradients)) + : undefined); } at::Tensor -LazyNativeFunctions::repeat(const at::Tensor& self, at::IntArrayRef repeats) { +LazyNativeFunctions::permute(const at::Tensor& self, at::IntArrayRef dims) { TORCH_LAZY_FN_COUNTER("lazy::"); + torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self); UNIMPLEMENTED_FUNCTION_ERROR(); - // return torch::lazy::CreateAtenFromLtcTensor(lazy_tensor_aten_ops::repeat( - // torch::lazy::TryGetLtcTensor(self), - // torch::lazy::ToI64Vector(repeats))); + return torch::lazy::CreateAtenFromLtcTensor( + torch::lazy::permute(self_tensor, torch::lazy::ToI64Vector(dims))); } at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER("lazy::"); - UNIMPLEMENTED_FUNCTION_ERROR(); - // return torch::lazy::CreateAtenFromLtcTensor( - // lazy_tensor_aten_ops::squeeze(torch::lazy::TryGetLtcTensor(self))); + return torch::lazy::CreateAtenFromLtcTensor( + torch::lazy::squeeze(torch::lazy::TryGetLtcTensor(self))); } at::Tensor LazyNativeFunctions::squeeze(const at::Tensor& self, int64_t dim) { TORCH_LAZY_FN_COUNTER("lazy::"); - UNIMPLEMENTED_FUNCTION_ERROR(); - // return torch::lazy::CreateAtenFromLtcTensor( - // lazy_tensor_aten_ops::squeeze(torch::lazy::TryGetLtcTensor(self), - // dim)); + return torch::lazy::CreateAtenFromLtcTensor( + torch::lazy::squeeze(torch::lazy::TryGetLtcTensor(self), dim)); } at::Tensor LazyNativeFunctions::t(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER("lazy::"); - return torch::lazy::CreateAtenFromLtcTensor(lazy_tensor_aten_ops::transpose( - torch::lazy::TryGetLtcTensor(self), 0, 1)); + return torch::lazy::CreateAtenFromLtcTensor( + torch::lazy::transpose(torch::lazy::TryGetLtcTensor(self), 0, 1)); } at::Tensor LazyNativeFunctions::unsqueeze(const at::Tensor& self, int64_t dim) { TORCH_LAZY_FN_COUNTER("lazy::"); - UNIMPLEMENTED_FUNCTION_ERROR(); - // return torch::lazy::CreateAtenFromLtcTensor( - // lazy_tensor_aten_ops::unsqueeze(torch::lazy::TryGetLtcTensor(self), - // dim)); + return torch::lazy::CreateAtenFromLtcTensor( + torch::lazy::unsqueeze(torch::lazy::TryGetLtcTensor(self), dim)); } at::Tensor @@ -391,9 +368,10 @@ LazyNativeFunctions::view(const at::Tensor& self, at::IntArrayRef size) { TORCH_LAZY_FN_COUNTER("lazy::"); torch::lazy::LazyTensorPtr self_tensor = torch::lazy::TryGetLtcTensor(self); return torch::lazy::CreateAtenFromLtcTensor( - lazy_tensor_aten_ops::view(self_tensor, torch::lazy::ToI64Vector(size))); + torch::lazy::view(self_tensor, torch::lazy::ToI64Vector(size))); } void InitializeAtenBindings() {} -} // namespace torch_lazy_tensors +} // namespace lazy +} // namespace torch diff --git a/python/torch_mlir/csrc/backend/mlir_node.cpp b/python/torch_mlir/csrc/backend/mlir_node.cpp index e1b910bf955..f0ba45b5ddb 100644 --- a/python/torch_mlir/csrc/backend/mlir_node.cpp +++ b/python/torch_mlir/csrc/backend/mlir_node.cpp @@ -18,112 +18,9 @@ namespace torch { namespace lazy { -namespace { - -hash_t OperandHashes( - const OpList& operands, const hash_t& seed, const bool bakeInSizes) { - hash_t hash = seed; - for (auto& operand : operands) { - if (!operand) { - hash = HashCombine(hash, static_cast(kNullOpt)); - continue; - } - auto operand_hash = - bakeInSizes ? operand.hash_with_sizes() : operand.hash_without_sizes(); - hash = HashCombine(hash, operand_hash); - } - return hash; -} - -hash_t GetOpHash( - OpKind op, const Shape& shape, hash_t hash_seed, const bool bakeInSizes) { - hash_t h = HashCombine(op.hash(), shape.hash(bakeInSizes)); - return HashCombine(h, hash_seed); -} - -} // namespace - -MlirNode::MlirNode( - OpKind op, OpList operands, std::vector&& shapes, size_t num_outputs, - hash_t hash_seed) - : Node( - op, num_outputs, - /* node_hash */ HashCombine(op.hash(), hash_seed), - /* dag_hash */ - [&](bool bakeInSizes) -> hash_t { - return OperandHashes( - operands, HashCombine(op.hash(), hash_seed), bakeInSizes); - }), - shapes_(std::move(shapes)) { - for (auto& operand : operands) { - // Ideally, optional operands should be filtered by the leaf node classes, - // but it's just much easier to do it here. - if (!operand) { - continue; - } - - AddOperand(operand.node, operand.index); - } -} - -MlirNode::MlirNode( - OpKind op, OpList operands, const std::function& shape_fn, - size_t num_outputs, hash_t hash_seed) - : MlirNode(op, operands, std::vector{}, num_outputs, hash_seed) { - shapes_.push_back(GetOpShape(shape_fn)); -} - -MlirNode::MlirNode( - OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed) - : MlirNode(op, operands, std::vector{}, num_outputs, hash_seed) {} - -void MlirNode::SetShapeDeferred(const std::function& shape_fn) { - shapes_.push_back(GetOpShape(shape_fn)); -} - -MlirNode::MlirNode(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed) - : Node(op, num_outputs, [&](bool bakeInSizes) -> hash_t { - return GetOpHash(op, shape, hash_seed, bakeInSizes); - }) { - shapes_.push_back(std::move(shape)); -} - -using ShapeCache = Cache; - -constexpr const int torch_lazy_shape_cache_size = 4096; - -ShapeCache* GetShapeCache() { - static ShapeCache* cache = new ShapeCache(torch_lazy_shape_cache_size); - return cache; -} - -Shape MlirNode::GetOpShape(const std::function& shape_fn) const { - ShapeCache* shape_cache = GetShapeCache(); - auto shape = shape_cache->Get(hash()); - if (shape == nullptr) { - shape = shape_cache->Add(hash(), std::make_shared(shape_fn())); - } - return *shape; -} - -c10::ArrayRef MlirNode::shapes() const { return shapes_; } - -const Shape& MlirNode::shape(size_t output_index) const { - return shapes_.at(output_index); -} - -const std::vector& MlirNode::operands() const { - return operands_as_outputs_; -} - -const Output& MlirNode::operand(size_t i) const { - return operands_as_outputs_.at(i); -} - -void MlirNode::AddOperand(NodePtr node, size_t index) { - CHECK_LT(index, node->num_outputs()); - operands_.push_back(std::move(node)); - operands_as_outputs_.emplace_back(operands_.back().get(), index); +TorchMlirOpVector +TorchMlirNode::Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { + return {}; } } // namespace lazy diff --git a/python/torch_mlir/csrc/backend/mlir_node.h b/python/torch_mlir/csrc/backend/mlir_node.h index e024ba3ba09..71068ac81e1 100644 --- a/python/torch_mlir/csrc/backend/mlir_node.h +++ b/python/torch_mlir/csrc/backend/mlir_node.h @@ -7,76 +7,33 @@ // //===----------------------------------------------------------------------===// // This file is adapted from pytorch/pytorch -// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/torch/csrc/lazy/ts_backend/ts_node.h +// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_node.h //===----------------------------------------------------------------------===// #pragma once #include +#include #include #include #include +#include "../utils/debug.h" #include "../utils/exception.h" -#include "aten_eager_fallback.h" #include "mlir_lowering_context.h" namespace torch { namespace lazy { -typedef std::vector MlirOpVector; -typedef NodePtr MlirFunction; - -class TORCH_API MlirNode : public torch::lazy::Node { +typedef std::vector TorchMlirOpVector; +typedef std::shared_ptr TorchMlirFunction; +class TORCH_API TorchMlirNode : public torch::lazy::Node { public: - MlirNode( - OpKind op, OpList operands, std::vector&& shapes, - size_t num_outputs = 1, hash_t hash_seed = kHashSeed); - - // Same as the constructor above, but the shape is generated by a function, - // only if needed (shape cache miss). - MlirNode( - OpKind op, OpList operands, const std::function& shape_fn, - size_t num_outputs = 1, hash_t hash_seed = kHashSeed); - - // The shape is set later. - MlirNode( - OpKind op, OpList operands, size_t num_outputs = 1, - hash_t hash_seed = kHashSeed); - - void SetShapeDeferred(const std::function& shape_fn); - - // Contructor used to create leaf nodes. - MlirNode( - OpKind op, Shape shape, size_t num_outputs = 1, - hash_t hash_seed = kHashSeed); - - Shape GetOpShape(const std::function& shape_fn) const; - - // Retrieves the full shape of the IR Node. - c10::ArrayRef shapes() const override; - - // Retrieves the shape of the output at a given index. - const Shape& shape(size_t output_index = 0) const override; - - const std::vector& operands() const override; - - const Output& operand(size_t i) const override; - - virtual MlirOpVector - Lower(MlirFunction function, MlirLoweringContext* loctx) const = 0; - -private: - // Adds node's index output number as operand. - void AddOperand(NodePtr node, size_t index = 0); + using torch::lazy::Node::Node; - std::vector shapes_; - // A node holds a real reference to its operands. - std::vector operands_; - // Outputs do not hold references on the nodes, and neither do the uses, since - // otherwise we get into circular reference counting. - std::vector operands_as_outputs_; + virtual TorchMlirOpVector + Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const; }; } // namespace lazy diff --git a/python/torch_mlir/csrc/tensor_aten_ops.cpp b/python/torch_mlir/csrc/tensor_aten_ops.cpp deleted file mode 100644 index e70d3066d19..00000000000 --- a/python/torch_mlir/csrc/tensor_aten_ops.cpp +++ /dev/null @@ -1,242 +0,0 @@ -//===- tensor_aten_ops.cpp ------------------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// -// This file is adapted from pytorch/pytorch -// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/tensor_aten_ops.cpp -//===----------------------------------------------------------------------===// - -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "tensor_aten_ops.h" - -namespace torch_lazy_tensors { -namespace lazy_tensor_aten_ops { -namespace { - -// to enable operator+-*/ for Value -using namespace torch::lazy; - -torch::lazy::Value MaybeExpand( - const torch::lazy::Value& input, const torch::lazy::Shape& target_shape) { - if (input.shape().sizes() == target_shape.sizes()) { - return input; - } - return torch::lazy::MakeNode( - input, target_shape.sizes().vec(), - /*is_scalar_expand=*/false); -} - -std::vector GetExpandDimensions( - const torch::lazy::Shape& shape, std::vector dimensions) { - CHECK_GE(dimensions.size(), shape.dim()) << shape; - int64_t base = dimensions.size() - shape.dim(); - for (size_t i = 0; i < shape.dim(); ++i) { - if (dimensions[base + i] == -1) { - dimensions[base + i] = shape.size(i); - } - } - return dimensions; -} - -// Returns a 1-D shape for batch norm weight or bias based on the input shape. -torch::lazy::Shape -BatchNormFeaturesShape(const torch::lazy::LazyTensorPtr& input) { - CHECK(input); - auto input_shape = input->shape().Get(); - return torch::lazy::Shape(input_shape.scalar_type(), input_shape.sizes()[1]); -} - -// Returns the IR for the given input or the provided default value broadcasted -// to the default shape, if the input is undefined. -torch::lazy::Value GetIrValueOrDefault( - const torch::lazy::LazyTensorPtr& input, const at::Scalar& default_value, - const torch::lazy::Shape& default_shape, - const torch::lazy::BackendDevice& device) { - return input ? input->GetIrValue() - : torch::lazy::LazyGraphExecutor::Get() - ->GetIrValueForExpandedScalar( - default_value, default_shape, device); -} - -torch::lazy::ViewInfo CreateAsStridedViewInfo( - const torch::lazy::Shape& input_shape, std::vector size, - std::vector stride, c10::optional storage_offset) { - torch::lazy::Shape result_shape = - torch::lazy::Shape(input_shape.scalar_type(), size); - torch::lazy::AsStridedInfo as_strided_info; - as_strided_info.stride = std::move(stride); - if (storage_offset) { - as_strided_info.offset = *storage_offset; - } - return torch::lazy::ViewInfo( - torch::lazy::ViewInfo::Type::kAsStrided, std::move(result_shape), - input_shape, std::move(as_strided_info)); -} - -} // namespace - -////////////////////////////////////////////////////////////////////////////// -// ATEN operators follows here, listed in alphabetical order. -////////////////////////////////////////////////////////////////////////////// -torch::lazy::LazyTensorPtr as_strided( - const torch::lazy::LazyTensorPtr& input, std::vector size, - std::vector stride, c10::optional storage_offset) { - auto input_shape = input->shape(); - return input->CreateViewTensor(CreateAsStridedViewInfo( - input_shape, std::move(size), std::move(stride), storage_offset)); -} - -void as_strided_( - torch::lazy::LazyTensorPtr& input, std::vector size, - std::vector stride, c10::optional storage_offset) { - if (input->data()->view == nullptr) { - input->SetIrValue(torch::lazy::MakeNode( - input->GetIrValue(), std::move(size), std::move(stride), - storage_offset.value_or(0))); - } else { - auto input_shape = input->shape(); - input->SetSubView(CreateAsStridedViewInfo( - input_shape, std::move(size), std::move(stride), storage_offset)); - } -} - -torch::lazy::LazyTensorPtr -expand(const torch::lazy::LazyTensorPtr& input, std::vector size) { - auto input_shape = input->shape(); - return torch::lazy::LazyTensor::Create( - torch::lazy::MakeNode( - input->GetIrValue(), - GetExpandDimensions(input_shape.Get(), std::move(size)), - /*is_scalar_expand=*/false), - input->GetDevice()); -} - -void fill_(torch::lazy::LazyTensorPtr& input, const at::Scalar& value) { - torch::lazy::Value constant = - torch::lazy::LazyGraphExecutor::Get()->GetIrValueForExpandedScalar( - value, input->shape(), input->GetDevice()); - input->SetInPlaceIrValue(std::move(constant)); -} - -torch::lazy::LazyTensorPtr narrow( - const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start, - int64_t length) { - auto input_shape = input->shape(); - dim = torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.Get().dim()); - torch::lazy::Shape narrow_shape = input_shape; - narrow_shape.set_size(dim, length); - - torch::lazy::ViewInfo::Type view_type = - (input_shape.Get().numel() == narrow_shape.numel()) - ? torch::lazy::ViewInfo::Type::kReshape - : torch::lazy::ViewInfo::Type::kNarrow; - torch::lazy::ViewInfo view_info( - view_type, std::move(narrow_shape), input_shape); - view_info.indices[dim] = - torch::lazy::GetCanonicalPosition(input_shape.Get().sizes(), dim, start); - return input->CreateViewTensor(std::move(view_info)); -} - -torch::lazy::LazyTensorPtr -permute(const torch::lazy::LazyTensorPtr& input, c10::ArrayRef dims) { - auto input_shape = input->shape(); - torch::lazy::ViewInfo view_info( - torch::lazy::ViewInfo::Type::kPermute, input_shape, - torch::lazy::GetCanonicalDimensionIndices(dims, input_shape.Get().dim())); - return input->CreateViewTensor(std::move(view_info)); -} - -void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) { - if (input->GetDevice() == src->GetDevice()) { - torch::lazy::Value copy_value; - if (input->dtype() == src->dtype()) { - copy_value = src->GetIrValue(); - } else { - copy_value = torch::lazy::MakeNode( - src->GetIrValue(), input->dtype(), src->dtype()); - } - input->SetIrValue(MaybeExpand(copy_value, input->shape())); - } else { - auto input_shape = input->shape(); - at::Tensor src_tensor = src->ToTensor(/*detached=*/true); - if (src_tensor.sizes() != input_shape.Get().sizes()) { - src_tensor = src_tensor.expand(input_shape.Get().sizes().vec()); - } - input->UpdateFromTensor(std::move(src_tensor), /*sync=*/false); - } -} - -torch::lazy::LazyTensorPtr slice( - const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start, - int64_t end, int64_t step) { - auto input_shape = input->shape(); - dim = torch::lazy::GetCanonicalDimensionIndex(dim, input_shape.Get().dim()); - start = - torch::lazy::GetCanonicalPosition(input_shape.Get().sizes(), dim, start); - end = torch::lazy::GetCanonicalPosition(input_shape.Get().sizes(), dim, end); - // PyTorch allows tensor[-1:0] to return a 0-dim tensor. - if (start > end) { - end = start; - } - step = std::min(step, end - start); - - torch::lazy::SelectInfo select = {dim, start, end, step}; - torch::lazy::ViewInfo view_info( - torch::lazy::ViewInfo::Type::kSelect, input_shape, std::move(select)); - return input->CreateViewTensor(std::move(view_info)); -} - -torch::lazy::LazyTensorPtr -transpose(const torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1) { - auto input_shape = input->shape(); - auto permute_dims = torch::lazy::MakeTransposePermutation( - /*dim0=*/dim0, /*dim1=*/dim1, /*rank=*/input_shape.Get().dim()); - torch::lazy::ViewInfo view_info( - torch::lazy::ViewInfo::Type::kPermute, input_shape, permute_dims); - return input->CreateViewTensor(std::move(view_info)); -} - -void transpose_(torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1) { - auto input_shape = input->shape(); - auto permute_dims = torch::lazy::MakeTransposePermutation( - /*dim0=*/dim0, /*dim1=*/dim1, /*rank=*/input_shape.Get().dim()); - torch::lazy::ViewInfo view_info( - torch::lazy::ViewInfo::Type::kPermute, input_shape, permute_dims); - return input->ModifyCurrentView(std::move(view_info)); -} - -torch::lazy::LazyTensorPtr view( - const torch::lazy::LazyTensorPtr& input, - c10::ArrayRef output_size) { - auto input_shape = input->shape().Get(); - torch::lazy::Shape shape = torch::lazy::Shape( - input_shape.scalar_type(), - at::infer_size(output_size, input_shape.numel())); - torch::lazy::ViewInfo view_info( - torch::lazy::ViewInfo::Type::kReshape, std::move(shape), input_shape); - return input->CreateViewTensor(std::move(view_info)); -} - -} // namespace lazy_tensor_aten_ops -} // namespace torch_lazy_tensors diff --git a/python/torch_mlir/csrc/tensor_aten_ops.h b/python/torch_mlir/csrc/tensor_aten_ops.h deleted file mode 100644 index 3342e1ec9a4..00000000000 --- a/python/torch_mlir/csrc/tensor_aten_ops.h +++ /dev/null @@ -1,79 +0,0 @@ -//===- tensor_aten_ops.h --------------------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// -// This file is adapted from pytorch/pytorch -// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/tensor_aten_ops.h -//===----------------------------------------------------------------------===// - -#pragma once - -#include - -namespace torch_lazy_tensors { -namespace lazy_tensor_aten_ops { - -////////////////////////////////////////////////////////////////////////////// -// ATEN operators follows here, listed in alphabetical order. -////////////////////////////////////////////////////////////////////////////// -// Takes a slice from the input as R1 at the specified offset and reshapes it -// into the provided size. -torch::lazy::LazyTensorPtr as_strided( - const torch::lazy::LazyTensorPtr& input, std::vector size, - std::vector stride, c10::optional storage_offset); - -// In-place version of the method above. -void as_strided_( - torch::lazy::LazyTensorPtr& input, std::vector size, - std::vector stride, c10::optional storage_offset); - -torch::lazy::LazyTensorPtr -expand(const torch::lazy::LazyTensorPtr& input, std::vector size); - -// Fills the input with the given value. -void fill_(torch::lazy::LazyTensorPtr& input, const at::Scalar& value); - -// Returns a new tensor that is a narrowed view of the input in the given -// dimension. -torch::lazy::LazyTensorPtr narrow( - const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start, - int64_t length); - -// Permute the dimensions of this tensor according to the given permutation. -torch::lazy::LazyTensorPtr -permute(const torch::lazy::LazyTensorPtr& input, c10::ArrayRef dims); - -// Repeats the input tensor along each dimension by the given number of -// repeats. -torch::lazy::LazyTensorPtr -repeat(const torch::lazy::LazyTensorPtr& input, std::vector repeats); - -void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src); - -torch::lazy::LazyTensorPtr slice( - const torch::lazy::LazyTensorPtr& input, int64_t dim, int64_t start, - int64_t end, int64_t step); - -std::tuple< - torch::lazy::LazyTensorPtr, torch::lazy::LazyTensorPtr, - torch::lazy::LazyTensorPtr> -svd(const torch::lazy::LazyTensorPtr& input, bool some, bool compute_uv); - -// Swap given dimensions of the input. -torch::lazy::LazyTensorPtr -transpose(const torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1); - -// In-place version of the method above. -void transpose_(torch::lazy::LazyTensorPtr& input, int64_t dim0, int64_t dim1); - -// Like reshape, but it returns a view into the original tensor. -torch::lazy::LazyTensorPtr view( - const torch::lazy::LazyTensorPtr& input, - c10::ArrayRef output_size); - -} // namespace lazy_tensor_aten_ops -} // namespace torch_lazy_tensors