Fix bugs in BackendImpl and codegen
antoniojkim committed Mar 31, 2022
1 parent 2d0c95a commit 397b604
Showing 10 changed files with 469 additions and 30 deletions.
51 changes: 41 additions & 10 deletions build_tools/autogen_ltc_backend.py
@@ -24,6 +24,10 @@
from codegen.model import NativeFunctionsGroup


+def isOptionalCType(arg):
+    return str(type(arg)) == "<class 'tools.codegen.api.types.OptionalCType'>"
+
+
def generate_native_functions(
config_path: Path, torch_ops_file: Path, out_file: Path
):
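Side note, not part of the commit: isOptionalCType matches the type by its printed name, which avoids importing codegen internals at module load. An isinstance-based equivalent could look like the sketch below, assuming tools.codegen.api.types is importable in the environment where this script runs:

from tools.codegen.api.types import OptionalCType

def is_optional_ctype(arg):
    # Same intent as isOptionalCType above, via isinstance instead of a string match.
    return isinstance(arg, OptionalCType)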
@@ -120,21 +124,47 @@ def get_native_function_name(f):

@dataclass(frozen=True)
class MlirLazyIr(codegen.gen_lazy_tensor.dest.LazyIR):
-    lowering_function_type: str = "torch::lazy::MlirFunction"
-    lowering_context_type: str = "torch::lazy::MlirLoweringContext*"
-    lowering_return_type: str = "torch::lazy::MlirOpVector"
+    lowering_function_type: str = "torch::lazy::TorchMlirFunction"
+    lowering_context_type: str = "torch::lazy::TorchMlirLoweringContext*"
+    lowering_return_type: str = "torch::lazy::TorchMlirOpVector"

    def lowering_body(self, f):
        func = (
            f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
        )
        schema = LazyIrSchema(func)

-        return f"""
-            UNIMPLEMENTED_ERROR(
-                "'{func}' lowering not yet implemented"
-            );
-        """.rstrip()
+        emplace_arguments = []
+        for arg in schema.positional_args:
+            if arg.is_lazy_value:
+                if isOptionalCType(arg.lazy_type):
+                    emplace_arguments.append(f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr")
+                    continue
+                emplace_arguments.append('loctx->GetOutputOp(operand(i++))')
+                continue
+            emplace_arguments.append(f'"{arg.name}", {arg.name}')
+
+        emplace_arguments_str = "\n    ".join(
+            [f"arguments.emplace_back({a});" for a in emplace_arguments])
+        emplace_kwarg_values = [f'"{t.name}", loctx->GetOutputOp(operand(i++))' for t in schema.keyword_values]
+        emplace_kwarg_scalars = [f'"{t.name}", {t.name}' for t in schema.keyword_scalars]
+        emplace_kwarguments = "\n    ".join(
+            [f"kwarguments.emplace_back({a});" for a in emplace_kwarg_values + emplace_kwarg_scalars])
+
+        return f"""\
+    PRINT_FUNCTION();
+    std::vector<torch::jit::NamedValue> arguments;
+    std::vector<torch::jit::NamedValue> kwarguments;
+    arguments.reserve({len(emplace_arguments)});
+    kwarguments.reserve({len(emplace_kwarg_values + emplace_kwarg_scalars)});
+    size_t i = 0;
+    {emplace_arguments_str}
+    {emplace_kwarguments}
+    torch::lazy::TorchMlirOpVector {schema.aten_name}_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, arguments, kwarguments);
+    CHECK_EQ({schema.aten_name}_out.size(), {len(func.returns)});
+    return {schema.aten_name}_out;
+"""


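For intuition about the lowering_body added above, here is a small, self-contained sketch, not part of the commit, that mirrors the emplace_arguments loop; MockArg and the example argument list are hypothetical stand-ins for LazyIrSchema's positional args, and only the classification logic follows the code above:

from dataclasses import dataclass

@dataclass
class MockArg:
    # Hypothetical stand-in for a LazyIrSchema positional argument.
    name: str
    is_lazy_value: bool
    is_optional: bool = False

positional_args = [
    MockArg("self", is_lazy_value=True),
    MockArg("weight", is_lazy_value=True, is_optional=True),
    MockArg("alpha", is_lazy_value=False),
]

emplace_arguments = []
for arg in positional_args:
    if arg.is_lazy_value:
        if arg.is_optional:
            # Optional lazy operands emit a guarded operand lookup.
            emplace_arguments.append(
                f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr")
            continue
        # Required lazy operands always consume the next operand.
        emplace_arguments.append("loctx->GetOutputOp(operand(i++))")
        continue
    # Non-lazy (scalar) arguments are passed through as named values.
    emplace_arguments.append(f'"{arg.name}", {arg.name}')

print("\n    ".join(f"arguments.emplace_back({a});" for a in emplace_arguments))
# Prints three C++ statements: a plain operand lookup for self, a guarded
# lookup for the optional weight, and a named scalar for alpha.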
def generate_backend(
@@ -159,7 +189,7 @@ def gen_fallback_code(*args, **kwargs):
dry_run=False,
impl_path=str(backend_path.joinpath("aten_ltc_mlir_type.cpp")),
gen_ts_lowerings=False,
-        node_base="torch::lazy::MlirNode",
+        node_base="torch::lazy::TorchMlirNode",
node_base_hdr=str(backend_path.joinpath("mlir_node.h")),
tensor_class="torch::lazy::LazyTensor",
tensor_class_hdr="torch/csrc/lazy/core/tensor.h",
@@ -299,7 +329,6 @@ def main(args):
new_hash = m.hexdigest().strip()

if args.force or new_hash != prev_hash:
-        hash_file.write_text(new_hash)
parsed_yaml, grouped_native_functions = generate_native_functions(
config_path, torch_ops_file, native_functions
)
@@ -311,6 +340,8 @@
grouped_native_functions,
)

+        hash_file.write_text(new_hash)
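Moving hash_file.write_text after code generation follows the usual stamp-file pattern: the new hash is recorded only once generation has succeeded, so a failed run leaves the old hash in place and the next invocation regenerates. A minimal, generic sketch of that pattern, with illustrative names that are not from this repository:

from pathlib import Path

def regenerate_if_changed(hash_file: Path, new_hash: str, generate) -> None:
    # Read the previously recorded hash, if any.
    prev_hash = hash_file.read_text().strip() if hash_file.exists() else None
    if new_hash != prev_hash:
        generate()  # may raise; the stale hash is then kept on disk
        # Record the hash only after generation has completed successfully.
        hash_file.write_text(new_hash)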


if __name__ == "__main__":
parser = argparse.ArgumentParser()
