From 6b3d5f5360c0f5c18059f725f00f044f7f5e8fe2 Mon Sep 17 00:00:00 2001 From: Antonio Kim Date: Wed, 30 Mar 2022 12:29:11 -0700 Subject: [PATCH] Fix bugs in BackendImpl and codegen --- build_tools/autogen_ltc_backend.py | 51 +++++++++++++++---- .../csrc/backend/aten_eager_fallback.cpp | 2 +- .../csrc/backend/aten_eager_fallback.h | 2 +- .../csrc/backend/aten_ltc_mlir_type.cpp | 2 +- .../torch_mlir/csrc/backend/backend_impl.cpp | 32 ++++++------ python/torch_mlir/csrc/backend/backend_impl.h | 7 ++- .../csrc/backend/mlir_lowering_context.cpp | 2 +- python/torch_mlir/csrc/backend/mlir_node.h | 1 + 8 files changed, 69 insertions(+), 30 deletions(-) diff --git a/build_tools/autogen_ltc_backend.py b/build_tools/autogen_ltc_backend.py index 0e924829dd1a..a279bdfafb94 100644 --- a/build_tools/autogen_ltc_backend.py +++ b/build_tools/autogen_ltc_backend.py @@ -24,6 +24,10 @@ from codegen.model import NativeFunctionsGroup +def isOptionalCType(arg): + return str(type(arg)) == "" + + def generate_native_functions( config_path: Path, torch_ops_file: Path, out_file: Path ): @@ -120,9 +124,9 @@ def get_native_function_name(f): @dataclass(frozen=True) class MlirLazyIr(codegen.gen_lazy_tensor.dest.LazyIR): - lowering_function_type: str = "torch::lazy::MlirFunction" - lowering_context_type: str = "torch::lazy::MlirLoweringContext*" - lowering_return_type: str = "torch::lazy::MlirOpVector" + lowering_function_type: str = "torch::lazy::TorchMlirFunction" + lowering_context_type: str = "torch::lazy::TorchMlirLoweringContext*" + lowering_return_type: str = "torch::lazy::TorchMlirOpVector" def lowering_body(self, f): func = ( @@ -130,11 +134,37 @@ def lowering_body(self, f): ) schema = LazyIrSchema(func) - return f""" - UNIMPLEMENTED_ERROR( - "'{func}' lowering not yet implemented" - ); - """.rstrip() + emplace_arguments = [] + for arg in schema.positional_args: + if arg.is_lazy_value: + if isOptionalCType(arg.lazy_type): + emplace_arguments.append(f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr") + continue + emplace_arguments.append('loctx->GetOutputOp(operand(i++))') + continue + emplace_arguments.append(f'"{arg.name}", {arg.name}') + + emplace_arguments_str = "\n ".join( + [f"arguments.emplace_back({a});" for a in emplace_arguments]) + emplace_kwarg_values = [f'"{t.name}", loctx->GetOutputOp(operand(i++))' for t in schema.keyword_values] + emplace_kwarg_scalars = [f'"{t.name}", {t.name}' for t in schema.keyword_scalars] + emplace_kwarguments = "\n ".join( + [f"kwarguments.emplace_back({a});" for a in emplace_kwarg_values + emplace_kwarg_scalars]) + + return f"""\ + PRINT_FUNCTION(); + std::vector arguments; + std::vector kwarguments; + arguments.reserve({len(emplace_arguments)}); + kwarguments.reserve({len(emplace_kwarg_values + emplace_kwarg_scalars)}); + size_t i = 0; + {emplace_arguments_str} + {emplace_kwarguments} + torch::lazy::TorchMlirOpVector {schema.aten_name}_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, arguments, kwarguments); + CHECK_EQ({schema.aten_name}_out.size(), {len(func.returns)}); + + return {schema.aten_name}_out; + """ def generate_backend( @@ -159,7 +189,7 @@ def gen_fallback_code(*args, **kwargs): dry_run=False, impl_path=str(backend_path.joinpath("aten_ltc_mlir_type.cpp")), gen_ts_lowerings=False, - node_base="torch::lazy::MlirNode", + node_base="torch::lazy::TorchMlirNode", node_base_hdr=str(backend_path.joinpath("mlir_node.h")), tensor_class="torch::lazy::LazyTensor", tensor_class_hdr="torch/csrc/lazy/core/tensor.h", @@ -299,7 +329,6 @@ def main(args): new_hash = m.hexdigest().strip() if args.force or new_hash != prev_hash: - hash_file.write_text(new_hash) parsed_yaml, grouped_native_functions = generate_native_functions( config_path, torch_ops_file, native_functions ) @@ -311,6 +340,8 @@ def main(args): grouped_native_functions, ) + hash_file.write_text(new_hash) + if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/python/torch_mlir/csrc/backend/aten_eager_fallback.cpp b/python/torch_mlir/csrc/backend/aten_eager_fallback.cpp index 25d4de7936b9..db4e210637ca 100644 --- a/python/torch_mlir/csrc/backend/aten_eager_fallback.cpp +++ b/python/torch_mlir/csrc/backend/aten_eager_fallback.cpp @@ -7,7 +7,7 @@ // //===----------------------------------------------------------------------===// // This file is adapted from pytorch/pytorch -// https://github.com/pytorch/pytorch/blob/torch/csrc/csrc/ts_backend/ts_eager_fallback.cpp +// https://github.com/pytorch/pytorch/blob/master/torch/csrc/csrc/ts_backend/ts_eager_fallback.cpp //===----------------------------------------------------------------------===// #include diff --git a/python/torch_mlir/csrc/backend/aten_eager_fallback.h b/python/torch_mlir/csrc/backend/aten_eager_fallback.h index 4e0f422260bd..7d9963b06261 100644 --- a/python/torch_mlir/csrc/backend/aten_eager_fallback.h +++ b/python/torch_mlir/csrc/backend/aten_eager_fallback.h @@ -9,7 +9,7 @@ // Facilitates eager fallback behaviour // // This file is adapted from pytorch/pytorch -// https://github.com/pytorch/pytorch/blob/torch/csrc/csrc/ts_backend/ts_eager_fallback.h +// https://github.com/pytorch/pytorch/blob/master/torch/csrc/csrc/ts_backend/ts_eager_fallback.h //===----------------------------------------------------------------------===// #pragma once diff --git a/python/torch_mlir/csrc/backend/aten_ltc_mlir_type.cpp b/python/torch_mlir/csrc/backend/aten_ltc_mlir_type.cpp index 8b8b3606ee1b..df31ff92a9fa 100644 --- a/python/torch_mlir/csrc/backend/aten_ltc_mlir_type.cpp +++ b/python/torch_mlir/csrc/backend/aten_ltc_mlir_type.cpp @@ -7,7 +7,7 @@ // //===----------------------------------------------------------------------===// // This file is adapted from pytorch/pytorch -// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/aten_ltc_ts_type.cpp +// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_native_functions.cpp //===----------------------------------------------------------------------===// #include diff --git a/python/torch_mlir/csrc/backend/backend_impl.cpp b/python/torch_mlir/csrc/backend/backend_impl.cpp index e0cce17d7813..dfd5f4c24019 100644 --- a/python/torch_mlir/csrc/backend/backend_impl.cpp +++ b/python/torch_mlir/csrc/backend/backend_impl.cpp @@ -7,7 +7,7 @@ // //===----------------------------------------------------------------------===// // This file is adapted from pytorch/pytorch -// https://github.com/pytorch/pytorch/blob/torch/csrc/lazy/ts_backend/backend_impl.cpp +// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp //===----------------------------------------------------------------------===// #include @@ -24,23 +24,21 @@ namespace torch { namespace lazy { TorchMlirBackendData::TorchMlirBackendData(BackendDevice device, Shape shape) - : BackendData(device, shape) { + : BackendData(device, shape), + info_(std::make_unique()) { PRINT_FUNCTION(); - auto info = std::make_shared(); - SetInfo(info); } -TorchMlirBackendData::TorchMlirBackendData(const at::Scalar& scalar, BackendDevice device) - : BackendData(device, Shape(scalar.type(), {})) { +TorchMlirBackendData::TorchMlirBackendData( + const at::Scalar& scalar, BackendDevice device) + : BackendData(device, Shape(scalar.type(), {})), + info_(std::make_unique(scalar)) { PRINT_FUNCTION(); - auto info = std::make_shared(scalar); - SetInfo(info); } TorchMlirBackendData::TorchMlirBackendData( const at::Tensor& tensor, BackendDevice device, Shape shape) - : BackendData(device, shape) { + : BackendData(device, shape), + info_(std::make_unique(tensor)) { PRINT_FUNCTION(); - auto info = std::make_shared(tensor); - SetInfo(info); } BackendData::Handle TorchMlirBackendData::GetHandle() { @@ -51,12 +49,16 @@ void TorchMlirBackendData::Assign(const BackendData& data) { TorchMlirBackendData::Info* info = dynamic_cast(data.info()); TORCH_CHECK( - info, "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info."); - auto new_info = std::make_shared(*info); - SetInfo(new_info); + info, + "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info."); + info_ = std::make_unique(*info); } -bool TorchMlirBackendData::HasValue() const { return bool(info()); } +bool TorchMlirBackendData::HasValue() const { return bool(info_); } + +TorchMlirBackendData::Info* TorchMlirBackendData::mlir_info() const { + return info_.get(); +} /** * Initialization/Teardown diff --git a/python/torch_mlir/csrc/backend/backend_impl.h b/python/torch_mlir/csrc/backend/backend_impl.h index 90fce516582e..33e509f6fb46 100644 --- a/python/torch_mlir/csrc/backend/backend_impl.h +++ b/python/torch_mlir/csrc/backend/backend_impl.h @@ -10,7 +10,7 @@ // using the Torch-MLIR ATen dialect // // This file is adapted from pytorch/pytorch -// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/backend_impl.h +// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_backend_impl.h //===----------------------------------------------------------------------===// #pragma once @@ -48,6 +48,11 @@ class TORCH_API TorchMlirBackendData : public BackendData { virtual void Assign(const BackendData& data) override; virtual bool HasValue() const override; + + TorchMlirBackendData::Info* mlir_info() const; + +private: + std::unique_ptr info_; }; class TORCH_API TorchMlirBackendImpl : public BackendImplInterface { diff --git a/python/torch_mlir/csrc/backend/mlir_lowering_context.cpp b/python/torch_mlir/csrc/backend/mlir_lowering_context.cpp index 7a99246ae298..0d5ec9b3dcd4 100644 --- a/python/torch_mlir/csrc/backend/mlir_lowering_context.cpp +++ b/python/torch_mlir/csrc/backend/mlir_lowering_context.cpp @@ -7,7 +7,7 @@ // //===----------------------------------------------------------------------===// // This file is adapted from pytorch/pytorch -// https://github.com/pytorch/pytorch/blob/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp +// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp //===----------------------------------------------------------------------===// #include diff --git a/python/torch_mlir/csrc/backend/mlir_node.h b/python/torch_mlir/csrc/backend/mlir_node.h index 94c797840a25..f03b6a807118 100644 --- a/python/torch_mlir/csrc/backend/mlir_node.h +++ b/python/torch_mlir/csrc/backend/mlir_node.h @@ -18,6 +18,7 @@ #include #include +#include "../utils/debug.h" #include "../utils/exception.h" #include "aten_eager_fallback.h" #include "mlir_lowering_context.h"