Skip to content

Commit

Permalink
Fix bugs in BackendImpl and codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniojkim committed Mar 31, 2022
1 parent 2d0c95a commit 6b3d5f5
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 30 deletions.
51 changes: 41 additions & 10 deletions build_tools/autogen_ltc_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
from codegen.model import NativeFunctionsGroup


def isOptionalCType(arg):
    """Return True iff *arg* is a ``tools.codegen.api.types.OptionalCType``.

    The check compares the fully-qualified type name instead of using
    ``isinstance`` so it works without importing the codegen type here.
    """
    cls = type(arg)
    return f"{cls.__module__}.{cls.__qualname__}" == "tools.codegen.api.types.OptionalCType"


def generate_native_functions(
config_path: Path, torch_ops_file: Path, out_file: Path
):
Expand Down Expand Up @@ -120,21 +124,47 @@ def get_native_function_name(f):

@dataclass(frozen=True)
class MlirLazyIr(codegen.gen_lazy_tensor.dest.LazyIR):
    """LazyIR specialization that emits Torch-MLIR lowering code for each
    generated lazy-tensor node class.

    NOTE(review): this region of the capture interleaved the pre- and
    post-commit diff lines (duplicate field defaults and a dead
    ``UNIMPLEMENTED_ERROR`` body before the real one); this is the
    post-commit version. Whitespace inside the emitted C++ template was
    lost by the capture — confirm against the upstream file.
    """

    # C++ type names spliced by the codegen into each generated Lower().
    lowering_function_type: str = "torch::lazy::TorchMlirFunction"
    lowering_context_type: str = "torch::lazy::TorchMlirLoweringContext*"
    lowering_return_type: str = "torch::lazy::TorchMlirOpVector"

    def lowering_body(self, f):
        """Return the C++ statements forming the generated Lower() body.

        ``f`` is either a ``NativeFunctionsGroup`` (structured kernels) or a
        single native function; in the grouped case the functional variant's
        schema drives the lowering.
        """
        func = (
            f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
        )
        schema = LazyIrSchema(func)

        # Positional args: lazy values become operand lookups (optional lazy
        # values are guarded by their has_<name> flag and fall back to
        # nullptr); non-lazy args are passed through as named values.
        emplace_arguments = []
        for arg in schema.positional_args:
            if arg.is_lazy_value:
                if isOptionalCType(arg.lazy_type):
                    emplace_arguments.append(f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr")
                    continue
                emplace_arguments.append('loctx->GetOutputOp(operand(i++))')
                continue
            emplace_arguments.append(f'"{arg.name}", {arg.name}')

        emplace_arguments_str = "\n ".join(
            [f"arguments.emplace_back({a});" for a in emplace_arguments])
        # Keyword args: lazy values consume operands in order; scalars are
        # forwarded by name.
        emplace_kwarg_values = [f'"{t.name}", loctx->GetOutputOp(operand(i++))' for t in schema.keyword_values]
        emplace_kwarg_scalars = [f'"{t.name}", {t.name}' for t in schema.keyword_scalars]
        emplace_kwarguments = "\n ".join(
            [f"kwarguments.emplace_back({a});" for a in emplace_kwarg_values + emplace_kwarg_scalars])

        return f"""\
PRINT_FUNCTION();
std::vector<torch::jit::NamedValue> arguments;
std::vector<torch::jit::NamedValue> kwarguments;
arguments.reserve({len(emplace_arguments)});
kwarguments.reserve({len(emplace_kwarg_values + emplace_kwarg_scalars)});
size_t i = 0;
{emplace_arguments_str}
{emplace_kwarguments}
torch::lazy::TorchMlirOpVector {schema.aten_name}_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, arguments, kwarguments);
CHECK_EQ({schema.aten_name}_out.size(), {len(func.returns)});
return {schema.aten_name}_out;
"""


def generate_backend(
Expand All @@ -159,7 +189,7 @@ def gen_fallback_code(*args, **kwargs):
dry_run=False,
impl_path=str(backend_path.joinpath("aten_ltc_mlir_type.cpp")),
gen_ts_lowerings=False,
node_base="torch::lazy::MlirNode",
node_base="torch::lazy::TorchMlirNode",
node_base_hdr=str(backend_path.joinpath("mlir_node.h")),
tensor_class="torch::lazy::LazyTensor",
tensor_class_hdr="torch/csrc/lazy/core/tensor.h",
Expand Down Expand Up @@ -299,7 +329,6 @@ def main(args):
new_hash = m.hexdigest().strip()

if args.force or new_hash != prev_hash:
hash_file.write_text(new_hash)
parsed_yaml, grouped_native_functions = generate_native_functions(
config_path, torch_ops_file, native_functions
)
Expand All @@ -311,6 +340,8 @@ def main(args):
grouped_native_functions,
)

hash_file.write_text(new_hash)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
Expand Down
2 changes: 1 addition & 1 deletion python/torch_mlir/csrc/backend/aten_eager_fallback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/torch/csrc/csrc/ts_backend/ts_eager_fallback.cpp
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/csrc/ts_backend/ts_eager_fallback.cpp
//===----------------------------------------------------------------------===//

#include <iostream>
Expand Down
2 changes: 1 addition & 1 deletion python/torch_mlir/csrc/backend/aten_eager_fallback.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
// Facilitates eager fallback behaviour
//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/torch/csrc/csrc/ts_backend/ts_eager_fallback.h
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/csrc/ts_backend/ts_eager_fallback.h
//===----------------------------------------------------------------------===//

#pragma once
Expand Down
2 changes: 1 addition & 1 deletion python/torch_mlir/csrc/backend/aten_ltc_mlir_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/aten_ltc_ts_type.cpp
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_native_functions.cpp
//===----------------------------------------------------------------------===//

#include <ATen/Operators.h>
Expand Down
32 changes: 17 additions & 15 deletions python/torch_mlir/csrc/backend/backend_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/torch/csrc/lazy/ts_backend/backend_impl.cpp
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp
//===----------------------------------------------------------------------===//

#include <torch/csrc/lazy/backend/backend_data.h>
Expand All @@ -24,23 +24,21 @@ namespace torch {
namespace lazy {

TorchMlirBackendData::TorchMlirBackendData(BackendDevice device, Shape shape)
: BackendData(device, shape) {
: BackendData(device, shape),
info_(std::make_unique<TorchMlirBackendData::Info>()) {
PRINT_FUNCTION();
auto info = std::make_shared<TorchMlirBackendData::Info>();
SetInfo(info);
}
TorchMlirBackendData::TorchMlirBackendData(const at::Scalar& scalar, BackendDevice device)
: BackendData(device, Shape(scalar.type(), {})) {
TorchMlirBackendData::TorchMlirBackendData(
const at::Scalar& scalar, BackendDevice device)
: BackendData(device, Shape(scalar.type(), {})),
info_(std::make_unique<TorchMlirBackendData::Info>(scalar)) {
PRINT_FUNCTION();
auto info = std::make_shared<TorchMlirBackendData::Info>(scalar);
SetInfo(info);
}
TorchMlirBackendData::TorchMlirBackendData(
const at::Tensor& tensor, BackendDevice device, Shape shape)
: BackendData(device, shape) {
: BackendData(device, shape),
info_(std::make_unique<TorchMlirBackendData::Info>(tensor)) {
PRINT_FUNCTION();
auto info = std::make_shared<TorchMlirBackendData::Info>(tensor);
SetInfo(info);
}

BackendData::Handle TorchMlirBackendData::GetHandle() {
Expand All @@ -51,12 +49,16 @@ void TorchMlirBackendData::Assign(const BackendData& data) {
TorchMlirBackendData::Info* info =
dynamic_cast<TorchMlirBackendData::Info*>(data.info());
TORCH_CHECK(
info, "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info.");
auto new_info = std::make_shared<TorchMlirBackendData::Info>(*info);
SetInfo(new_info);
info,
"Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info.");
info_ = std::make_unique<TorchMlirBackendData::Info>(*info);
}

// NOTE(review): the capture shows both the pre-commit (info()) and
// post-commit (info_) one-liners; keeping the post-commit version, which
// reads the unique_ptr member directly.
// True iff this backend data owns an Info payload.
bool TorchMlirBackendData::HasValue() const { return bool(info_); }

// Typed accessor for the Torch-MLIR-specific Info payload. Returns a
// non-owning pointer; the object remains owned by this data's info_ member.
TorchMlirBackendData::Info* TorchMlirBackendData::mlir_info() const {
return info_.get();
}

/**
* Initialization/Teardown
Expand Down
7 changes: 6 additions & 1 deletion python/torch_mlir/csrc/backend/backend_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
// using the Torch-MLIR ATen dialect
//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/lazy_tensor_staging/lazy_tensor_core/lazy_tensor_core/csrc/ts_backend/backend_impl.h
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_backend_impl.h
//===----------------------------------------------------------------------===//

#pragma once
Expand Down Expand Up @@ -48,6 +48,11 @@ class TORCH_API TorchMlirBackendData : public BackendData {
virtual void Assign(const BackendData& data) override;

virtual bool HasValue() const override;

TorchMlirBackendData::Info* mlir_info() const;

private:
std::unique_ptr<TorchMlirBackendData::Info> info_;
};

class TORCH_API TorchMlirBackendImpl : public BackendImplInterface {
Expand Down
2 changes: 1 addition & 1 deletion python/torch_mlir/csrc/backend/mlir_lowering_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//
//===----------------------------------------------------------------------===//
// This file is adapted from pytorch/pytorch
// https://github.com/pytorch/pytorch/blob/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp
//===----------------------------------------------------------------------===//

#include <iostream>
Expand Down
1 change: 1 addition & 0 deletions python/torch_mlir/csrc/backend/mlir_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/shape.h>

#include "../utils/debug.h"
#include "../utils/exception.h"
#include "aten_eager_fallback.h"
#include "mlir_lowering_context.h"
Expand Down

0 comments on commit 6b3d5f5

Please sign in to comment.