E2E HuggingFace Bert using LTC Backend (#912)
* Update native function definitions

* Add ops to support bert lowering

- Add empty_strided and as_strided

- Restore zeros_like to the op blacklist (without this, tensors would be unintentionally created with a CPU device rather than a lazy one)

- Check for composite implicit ops and add device data IR

- Also fix codegen for functionalization

* Add autogen to CMakeLists

* Remove PyTorch submodule

* Reduced BERT model size

* Print Mark Step status in Torch MLIR LTC debug string

* Apply fixes to work with latest upstream/main

- Pass importOptions into getMlirTypeFromTorchType during NodeImporter::importNode

  Without this, the created tensor type may be mismatched, since ImportOptions may cause vtensor to be used instead of tensor
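
  A minimal, self-contained C++ sketch of the pattern (the struct, field, and function names below are illustrative stand-ins for the torch-mlir importer internals, not the exact API):

  #include <iostream>
  #include <string>

  // Stand-in for the importer's ImportOptions; the field mirrors the idea that
  // the options decide whether value-semantic tensors ("vtensor") are emitted.
  struct ImportOptions {
    bool assumeTensorsHaveValueSemantics = false;
  };

  // Stand-in for getMlirTypeFromTorchType: the produced type depends on the options.
  std::string getMlirTypeFromTorchType(const ImportOptions &importOptions) {
    return importOptions.assumeTensorsHaveValueSemantics ? "!torch.vtensor"
                                                         : "!torch.tensor";
  }

  // Stand-in for NodeImporter::importNode: the fix is to forward importOptions
  // instead of implicitly using default-constructed options, which could yield
  // a result type that disagrees with types created elsewhere during the import.
  std::string importNodeResultType(const ImportOptions &importOptions) {
    return getMlirTypeFromTorchType(importOptions);
  }

  int main() {
    ImportOptions importOptions;
    importOptions.assumeTensorsHaveValueSemantics = true;
    std::cout << importNodeResultType(importOptions) << "\n";  // prints !torch.vtensor
    return 0;
  }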

* Update shape inference functions

- Fixed compute_shape_native_batch_norm when mean and var are uninitialized

  Previously, fewer than 3 shapes would be returned if either mean or var didn't exist. Instead, we now initialize them with a vector matching the number of channels (a sketch of the intended logic follows this list).

- Implemented compute_shape_mul

- Fixed bug in reshape shape inference error message
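
  A self-contained C++ sketch of the intended shape logic (the function name and the Shape stand-in below are assumptions for illustration; the real implementation lives in LazyShapeInference.cpp):

  #include <cstdint>
  #include <optional>
  #include <vector>

  using Shape = std::vector<int64_t>;  // stand-in for torch::lazy::Shape

  // Always return three shapes (output, mean, var); when running_mean/var are
  // absent, fall back to a 1-D shape sized by the channel dimension of the input.
  std::vector<Shape> compute_shape_native_batch_norm_sketch(
      const Shape &input, const std::optional<Shape> &running_mean,
      const std::optional<Shape> &running_var) {
    std::vector<Shape> shapes;
    shapes.push_back(input);  // the normalized output matches the input shape

    const int64_t num_features = input.size() >= 2 ? input[1] : 0;
    shapes.push_back(running_mean.value_or(Shape{num_features}));
    shapes.push_back(running_var.value_or(Shape{num_features}));
    return shapes;
  }

  int main() {
    // Even with no running statistics, three shapes come back: [8,3,32,32], [3], [3].
    const auto shapes = compute_shape_native_batch_norm_sketch(
        {8, 3, 32, 32}, std::nullopt, std::nullopt);
    return shapes.size() == 3 ? 0 : 1;
  }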

* Get MLIR backend more consistent with TS backend

- Remove LazyNativeFunctions::_unsafe_view from autogen

- Blacklist ops to make the JIT graph more like the output of the TS backend

- Print the graph when an SSA value has a mismatch between types and results

- Remove normalize_index from LazyShapeInference

- Fix seeds for LTC example models

* Update and clean up shape inference functions

- Prune shape inference functions

- Add shape inference function for GenerateSlice

- Add shape inference function for GenerateCopy

Co-authored-by: Henry Tu <henry.tu@cerebras.net>
antoniojkim and henrytwo committed Jul 8, 2022
1 parent 189d7b5 commit 2da6913
Showing 21 changed files with 824 additions and 415 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -11,6 +11,7 @@ libtorch*

/build/
__pycache__
*.pyc

.pytype

4 changes: 0 additions & 4 deletions .gitmodules
@@ -1,7 +1,3 @@
[submodule "external/llvm-project"]
path = externals/llvm-project
url = https://github.com/llvm/llvm-project.git
[submodule "externals/pytorch"]
path = externals/pytorch
url = https://github.com/pytorch/pytorch.git
shallow = true
120 changes: 73 additions & 47 deletions build_tools/autogen_ltc_backend.py
@@ -1,28 +1,28 @@
import argparse
import hashlib
import importlib
import os
import subprocess
import sys
import warnings
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from shutil import which
from textwrap import dedent

import yaml

TORCH_MLIR_DIR = Path(__file__).parent.parent.resolve()
TORCH_DIR = TORCH_MLIR_DIR.joinpath("externals", "pytorch")

sys.path.append(str(TORCH_DIR))

# PyTorch's LTC backend autogen script
import torchgen
import torchgen.dest.lazy_ir
import torchgen.gen_lazy_tensor
from torchgen.api.lazy import LazyIrSchema
from torchgen.gen import get_grouped_native_functions, parse_native_yaml
from torchgen.model import NativeFunctionsGroup

TORCH_DIR = Path(importlib.util.find_spec('torch').origin).resolve().parent.parent
TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent

def isOptionalCType(arg):
return str(type(arg)) == "<class 'torchgen.api.types.OptionalCType'>"
@@ -42,20 +42,29 @@ def generate_native_functions(
grouped_native_functions = get_grouped_native_functions(native_functions)

def get_native_function_name(f):
func = f.func if hasattr(f, "func") else f.functional.func
return str(func.name)
func = f if hasattr(f, "func") else f.functional
return str(func.func.name), func

def get_opnames(ops):
opnames = defaultdict(set)
for op in ops:
opname = op.split(".")[0]
opnames[opname].add(op)
return opnames

aten_funcs = set(map(get_native_function_name, grouped_native_functions))
native_functions = dict(map(get_native_function_name, native_functions))
grouped_native_functions = dict(map(get_native_function_name, grouped_native_functions))
aten_funcs = get_opnames(set(grouped_native_functions.keys()))

with config_path.open() as f:
config = yaml.load(f, yaml.CLoader)

# List of unsupported ops in LTC autogen because of some error
blacklist = config.get("blacklist", [])
blacklist = set(config.get("blacklist", []))

# List of supported ops that we don't want to do the full codegen for
# primarily view ops
supported = config.get("supported", [])
supported = set(config.get("supported", []))

# List of non-native ops to do IR codegen for
non_native = config.get("non_native", [])
@@ -65,49 +74,54 @@ def get_native_function_name(f):
else:
cmd = ["grep", "-o", r"aten::[0-9a-zA-Z_\.]\+"]

output = (
subprocess.check_output(
torch_ops = set(
op[6:]
for op in subprocess.check_output(
cmd + [str(torch_ops_file)],
encoding="utf-8",
)
.strip()
.split(os.linesep)
)
torch_opnames = get_opnames(torch_ops)

# process ops list
ops = []
supported_ops = []
skipped = []
ops = set()
composite_implicit = set()

for op in output:
op = op[6:]
opname = op.split(".")[0]

if opname in blacklist or op in blacklist:
for op in torch_ops:
if op not in native_functions:
continue

if opname in supported:
supported_ops.append(op)
continue
func = native_functions[op]
base = func.func.name.name.base

if op not in aten_funcs:
skipped.append(op)
if base in blacklist or op in blacklist:
continue
if base in supported or op in supported:
continue

ops.append(op)
if func.has_composite_implicit_autograd_kernel and f"{op}_backward" not in torch_ops:
composite_implicit.add(op)
elif func.func.name.name.inplace:
for autogen in func.autogen:
if "functional" in autogen.overload_name:
ops.add(str(autogen))
else:
ops.add(op)

opnames = sorted(set(ops))
skipped = set(torch_ops) - ops - supported - composite_implicit

# Additional ops to support that are not supported by Torch-MLIR explicitly
supported_ops.extend(config.get("additional_ops", []))
supported |= set(config.get("additional_ops", []))

with out_file.open("w") as f:
yaml.dump(
{
"backend": "Lazy",
"cpp_namespace": "torch::lazy",
"full_codegen": opnames,
"supported": sorted(supported_ops),
"full_codegen": sorted(ops),
"supported": sorted(supported),
"non_native": non_native,
},
f,
@@ -117,10 +131,15 @@ def get_native_function_name(f):
dedent(
"""
# Composite implicit ops (supported by Torch-MLIR but not differentiable)
{composite_implicit}
# Skipped ops (supported by Torch-MLIR but no equivalent native function)
{skipped}
"""
).format(
composite_implicit=os.linesep.join(f"# - {op}" for op in sorted(composite_implicit)),
skipped=os.linesep.join(f"# - {op}" for op in sorted(skipped)),
)
+ os.linesep.join(f"# - {op}" for op in sorted(skipped))
)

return parsed_yaml, grouped_native_functions
@@ -129,11 +148,13 @@ def get_native_function_name(f):
@dataclass(frozen=True)
class GenMlirLazyIr(torchgen.dest.GenLazyIR):

def lowering_function(self, schema, declaration_only=True):
def lowering_function(self, schema):
signature = "TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override"

if declaration_only:
if schema.properties.LowerDeclOnly:
return f"{signature};"
elif not schema.properties.Lower:
return ""

emplace_arguments = []
for arg in schema.positional_args:
@@ -213,7 +234,7 @@ def gen_fallback_code(*args, **kwargs):
import re

sig_re = re.compile(
r"std::vector<Shape>\s+(?P<name>\w+)\((?P<signature>[^\)]+)\)"
r"std::vector<torch::lazy::Shape>\s+(?P<name>\w+)\((?P<signature>[^\)]+)\)"
)
global_signatures = {}

@@ -307,25 +328,30 @@ def main(args):
)
assert backend_path.is_dir()

torchgen_path = Path(torchgen.__path__[0]).resolve()
assert torchgen_path.is_dir()

prev_hash = None
hash_file = TORCH_MLIR_DIR.joinpath("generated_backend.hash")
if hash_file.exists():
prev_hash = hash_file.read_text().strip()

m = hashlib.sha256()
m.update(script_path.read_bytes())
m.update(config_path.read_bytes())
m.update(torch_ops_file.read_bytes())
if native_functions.exists():
m.update(native_functions.read_bytes())

shape_inference_headers = backend_path.joinpath("LazyShapeInference.h")
if shape_inference_headers.exists():
m.update(shape_inference_headers.read_bytes())

shape_inference_defs = backend_path.joinpath("LazyShapeInference.cpp")
if shape_inference_defs.exists():
m.update(shape_inference_defs.read_bytes())

# Add file contents to hash
for path in (
script_path,
config_path,
torch_ops_file,
native_functions,
backend_path.joinpath("LazyShapeInference.h"),
backend_path.joinpath("LazyShapeInference.cpp"),
torchgen_path.joinpath("dest", "lazy_ir.py"),
torchgen_path.joinpath("api", "lazy.py"),
torchgen_path.joinpath("model.py"),
):
if path.exists():
m.update(path.read_bytes())

new_hash = m.hexdigest().strip()

78 changes: 39 additions & 39 deletions build_tools/autogen_ltc_backend.yaml
@@ -1,49 +1,46 @@
blacklist:
# List of unsupported ops in LTC autogen because of some error
- arange # Error: Code below assumes there is at least one tensor arg
- contiguous # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- empty_like # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- full # Error: Code below assumes there is at least one tensor arg
- index.Tensor # Error: TODO not sure if there are other valid types to handle here
- index_put # Error: TODO not sure if there are other valid types to handle here
- index_put_ # Error: TODO not sure if there are other valid types to handle here
- _index_put_impl_ # Error: TODO not sure if there are other valid types to handle here
- ones # Error: Code below assumes there is at least one tensor arg
- ones_like # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- resize_ # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- stack # Error: TODO not sure if there are other valid types to handle here
- to.dtype # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- to.other # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- uniform_ # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)
- zeros # Error: Code below assumes there is at least one tensor arg
- zeros_like # Error: TODO add support for type BaseType(name=<BaseTy.MemoryFormat: 12>)

# Additional ops which autogen is supported for but don't compile yet
- detach
- item
- size
- where
- copy_
- _to_copy
- log_softmax # Not inherently differentiable. Needs to be decomposed.
- linear # Not inherently differentiable. Needs to be decomposed.

# Disabled for consistency with TS backend
- rsub

# List of supported ops that we don't want to do the full codegen for
# primarily view ops
supported:
# - bernoulli
# - bernoulli_
- as_strided
- as_strided_
- _to_copy
- cat
- clone
- empty
- empty.memory_format
- empty_strided
- expand
- fill_
- native_batch_norm_backward
- fill_.Scalar
- permute
- select.int
- slice.Tensor
- squeeze
- squeeze.dim
- t
- transpose.int
- unsqueeze
- view
- _unsafe_view

additional_ops:
# Additional ops to support that are not supported by Torch-MLIR explicitly
@@ -53,35 +50,38 @@ additional_ops:

# List of non native ops that we only want to do IR node class generation for
non_native:
- func: device_data(std::shared_ptr<BackendData> data) -> Tensor
opkind: ltc_device_data
cache_shape: false
- func: scalar(at::Scalar value, at::ScalarType type) -> Tensor
- func: scalar(Scalar value, ScalarType type) -> Tensor
opkind: at::prim::Constant
cache_shape: false
- func: expand(Tensor input, std::vector<int64_t> size, bool is_scalar_expand) -> Tensor
- func: view(Tensor input, std::vector<int64_t> output_size) -> Tensor
cache_shape: false
- func: cast(Tensor input, at::ScalarType dtype, optional<at::ScalarType> stype) -> Tensor
properties:
- ShapeCompute
- TreatScalarsAsConstants
- func: expand(Tensor input, int[] size, bool is_scalar_expand) -> Tensor
- func: view(Tensor input, int[] output_size) -> Tensor
properties:
- ShapeCompute
- func: cast(Tensor input, ScalarType dtype, ScalarType? stype) -> Tensor
opkind: ltc_cast
cache_shape: false
properties:
- ShapeCompute

# View ops only required until proper functionalization pass is introduced into LTC
- func: as_strided_view_update(Tensor target, Tensor input, std::vector<int64_t> size, std::vector<int64_t> stride, int64_t storage_offset) -> Tensor
- func: as_strided_view_update(Tensor target, Tensor input, int[] size, int[] stride, int storage_offset) -> Tensor
opkind: ltc_as_strided_view_update
- func: as_strided(Tensor input, std::vector<int64_t> size, std::vector<int64_t> stride, int64_t storage_offset) -> Tensor
- func: diagonal_view_update(Tensor target, Tensor input, int64_t offset, int64_t dim1, int64_t dim2) -> Tensor
- func: as_strided(Tensor input, int[] size, int[] stride, int storage_offset) -> Tensor
- func: diagonal_view_update(Tensor target, Tensor input, int offset, int dim1, int dim2) -> Tensor
opkind: ltc_diagonal_view_update
cache_shape: false
- func: diagonal(Tensor input, int64_t offset, int64_t dim1, int64_t dim2) -> Tensor
- func: narrow_view_update(Tensor input, Tensor source, std::vector<int64_t> base_indices) -> Tensor
properties:
- ShapeCompute
- func: diagonal(Tensor input, int offset, int dim1, int dim2) -> Tensor
- func: narrow_view_update(Tensor input, Tensor source, int[] base_indices) -> Tensor
opkind: ltc_narrow_view_update
- func: narrow(Tensor input, std::vector<int64_t> base_indices, std::vector<int64_t> sizes) -> Tensor
- func: permute(Tensor input, std::vector<int64_t> dims) -> Tensor
- func: resize(Tensor input, std::vector<int64_t> size) -> Tensor
- func: select_view_update(Tensor target, Tensor source, int64_t dim, int64_t start, int64_t end, int64_t stride) -> Tensor
- func: narrow(Tensor input, int[] base_indices, int[] sizes) -> Tensor
- func: permute(Tensor input, int[] dims) -> Tensor
- func: resize(Tensor input, int[] size) -> Tensor
- func: select_view_update(Tensor target, Tensor source, int dim, int start, int end, int stride) -> Tensor
opkind: ltc_select_view_update
cache_shape: false
- func: select(Tensor input, int64_t dim, int64_t start, int64_t end, int64_t stride) -> Tensor
properties:
- ShapeCompute
- func: select(Tensor input, int dim, int start, int end, int stride) -> Tensor
- func: squeeze(Tensor input, int dim) -> Tensor
- func: unsqueeze(Tensor input, int dim) -> Tensor