Various fixes to enable my experimental llama.turbine port of llamacpp to compile. (nod-ai#269)

(https://github.com/stellaraccident/llama.turbine)

* Dedup captured tensors on fx import (see the sketch below).
* Allow import of complex64/128 now that DenseResourceElementsAttr is used.
* Add index casting to IREE.splat_tensor.
* Support getting a tensor dimension as a specific type.
* Preserve FxImporter's py_attr_tracker across an entire module build.
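
A minimal sketch of the dedup / py_attr_tracker change, assuming the shark_turbine package at this commit. `graph_modules` is a hypothetical iterable of torch.fx GraphModules standing in for the per-function graphs traced during a module build (in the real flow, jittable.py also hands the importer the builder's context and a literal resolver callback):

    # Share one RefTracker across FxImporter invocations so that a captured
    # torch.Tensor seen by several graphs is interned to a single
    # DenseResourceElementsAttr instead of being re-materialized each time.
    from shark_turbine.importers.fx_importer import FxImporter
    from shark_turbine.importers.utils import RefTracker

    shared_tracker = RefTracker()  # lives for the whole module build
    for i, gm in enumerate(graph_modules):  # hypothetical input graphs
        importer = FxImporter(py_attr_tracker=shared_tracker)
        importer.import_stateless_graph(gm.graph, func_name=f"jit_eval_{i}")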
stellaraccident authored Dec 18, 2023
1 parent 7110975 commit cfec6ae
Showing 10 changed files with 219 additions and 41 deletions.
1 change: 1 addition & 0 deletions python/shark_turbine/aot/builtins/jittable.py
@@ -235,6 +235,7 @@ def flat_wrapped_f(*args):
context=proc_trace.context,
config_check=False,
literal_resolver_callback=_make_literal_resolver(proc_trace.module_builder),
py_attr_tracker=proc_trace.module_builder.fx_py_attr_tracker,
)
fx_importer.import_stateless_graph(gm.graph, func_name=self.function_name)

22 changes: 12 additions & 10 deletions python/shark_turbine/aot/compiled_module.py
@@ -482,17 +482,19 @@ def __new__(
try:
context = Context.current
except ValueError:
raise ValueError(
"Neither an implicit context context handler not "
"context= or module= arguments specified"
pass

if not context:
context = Context()

if not module_op:
with context:
loc = Location.unknown(context=context)
module = Module.create(loc)
module_op = module.operation
module_op.attributes["sym_name"] = StringAttr.get(
class_info.ir_module_name, context=context
)
if context:
loc = Location.unknown(context=context)
module = Module.create(loc)
module_op = module.operation
module_op.attributes["sym_name"] = StringAttr.get(
class_info.ir_module_name, context=context
)
module_builder = ModuleBuilder(module_op)
info = CompiledModuleInstanceInfo(class_info, module_builder=module_builder)
_all_compiled_module_instance_infos[self] = info
9 changes: 9 additions & 0 deletions python/shark_turbine/aot/support/ir_utils.py
@@ -18,6 +18,10 @@
TORCH_DTYPE_TO_MLIR_TYPE_ASM,
)

from ...importers.utils import (
RefTracker as FxRefTracker,
)

from ...dynamo.type_conversion import (
NativeTypeConverter,
)
@@ -169,6 +173,7 @@ class ModuleBuilder:
"body",
"cache",
"context",
"fx_py_attr_tracker",
"global_ip",
"ip",
"module_op",
@@ -187,6 +192,10 @@ def __init__(self, module_op: Operation):
self.cache = ContextCache(self.context)
# Tracks global references to a MaterializedGlobal.
self.global_ref_tracker = RefTracker()
# Usually the FxImporter makes a new ref tracker for each invocation,
# but we want to preserve it across individual JIT evaluations so
# as to better intern tensors to attributes.
self.fx_py_attr_tracker = FxRefTracker()
self.native_type_converter = NativeTypeConverter(self.context)

def handle_mlir_error(self, op: Operation, e: MLIRError, message: str):
10 changes: 8 additions & 2 deletions python/shark_turbine/aot/support/procedural/globals.py
@@ -141,7 +141,10 @@ def track(self, module_builder: ModuleBuilder, export_namespace: str) -> Any:
continue
elif isinstance(value, AbstractTensor):
global_type = value.get_ir_type(module_builder)
(actual_symbol_name, global_op,) = module_builder.create_typed_global(
(
actual_symbol_name,
global_op,
) = module_builder.create_typed_global(
f"_{fq_name}",
global_type,
attrs=self._attrs,
@@ -160,7 +163,10 @@ def track(self, module_builder: ModuleBuilder, export_namespace: str) -> Any:
continue
elif isinstance(value, AbstractScalar):
global_type = value.get_ir_type(module_builder)
(actual_symbol_name, global_op,) = module_builder.create_typed_global(
(
actual_symbol_name,
global_op,
) = module_builder.create_typed_global(
f"_{fq_name}",
global_type,
attrs=self._attrs,
34 changes: 29 additions & 5 deletions python/shark_turbine/aot/support/procedural/iree_emitter.py
@@ -14,9 +14,12 @@

from ..ir_imports import (
IndexType,
IntegerType,
IrType,
RankedTensorType,
StringAttr,
Value,
arith_d,
flow_d,
)

@@ -103,6 +106,16 @@ def cast_tensor_dim_decl(
return dim_decls, dynamic_dim_values


def cast_scalar_to_element_type(scalar: Value, element_type: IrType) -> Value:
scalar_type = scalar.type
# Support cast from Index -> Integer.
if scalar_type == IndexType.get() and IntegerType.isinstance(element_type):
return arith_d.IndexCastUIOp(element_type, scalar).result
raise ValueError(
f"Provided splat value ({type(value)}) does not match dtype {dtype} (and cannot be cast)"
)


def assert_value_is_index(x: Value):
t = x.type
if not IndexType.isinstance(t):
@@ -131,11 +144,24 @@ def wrapper(*args, **kwargs):

class IREEEmitter:
@emitter
def tensor_dim(self, source: BuildableTensorType, index: int) -> "IrScalar":
def tensor_dim(
self,
source: BuildableTensorType,
index: int,
*,
dtype: Optional[torch.dtype] = None,
) -> "IrScalar":
"""Gets the dimension size of a tensor at a static position."""
source = cast_tensor_value(source)
index = cast_static_bounded_index(index, 0, source.rank - 1)
return IrImmediateScalar(source.get_dim_value(index))
dim_value = source.get_dim_value(index)
if dtype is not None:
try:
cast_type = TORCH_DTYPE_TO_IREE_TYPE[dtype]()
except KeyError:
raise ValueError(f"Could not map Torch dtype {dtype} to an IREE type")
dim_value = arith_d.IndexCastUIOp(cast_type, dim_value).result
return IrImmediateScalar(dim_value)

@emitter
def tensor_empty(
@@ -304,9 +330,7 @@ def tensor_splat(
raise ValueError(f"Could not map Torch dtype {dtype} to an IREE type")
value = cast_scalar_value(value)
if value.type != element_type:
raise ValueError(
f"Provided splat value ({type(value)}) does not match dtype {dtype}"
)
value = cast_scalar_to_element_type(value, element_type)
tensor_type = RankedTensorType.get(dim_decls, element_type)
raw_tensor = flow_d.TensorSplatOp(tensor_type, value, dyn_dim_values).result
result = IrImmediateTensor(raw_tensor, dtype=dtype)
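A hedged illustration of the two emitter changes above: tensor_dim gains a dtype= keyword (lowered via arith.index_castui), and tensor_splat now accepts an index-typed value by casting it to the element type instead of raising. `emitter` and `t` are hypothetical stand-ins for the IREEEmitter and tensor value available inside a real tracing context, and the tensor_splat call shape is assumed from the hunk above rather than taken from its full signature:

    import torch

    def emit_example(emitter, t):
        # New: ask for dimension 0 as an i64 scalar rather than an index value.
        d0_i64 = emitter.tensor_dim(t, 0, dtype=torch.int64)
        # The plain form still yields an index-typed scalar...
        d0_index = emitter.tensor_dim(t, 0)
        # ...which tensor_splat can now consume: the index value is cast to the
        # i32 element type instead of tripping the dtype-mismatch error.
        return emitter.tensor_splat(4, value=d0_index, dtype=torch.int32)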
67 changes: 44 additions & 23 deletions python/shark_turbine/importers/fx_importer.py
@@ -8,6 +8,7 @@
import re
from types import NoneType, BuiltinMethodType, BuiltinFunctionType
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union

import numpy as np

import torch
@@ -63,6 +64,7 @@
)

from .utils import (
RefTracker,
TypeSubclassMap,
)

@@ -131,8 +133,8 @@
torch.float64: np.float64,
torch.bool: np.bool_,
# torch.complex32: None, # no equivalent precision for numpy
# torch.complex64: np.complex64, # complex dtypes can't be parsed by DenseElementsAttr in the numpy buffer format
# torch.complex128: np.complex128,
torch.complex64: np.complex64,
torch.complex128: np.complex128,
}

# https://github.com/llvm/torch-mlir/blob/4c24472dea1c9102b898768b0b11e31487e50207/python/torch_mlir/_dynamo_fx_importer.py#L189
@@ -220,6 +222,7 @@ class FxImporter:
"_literal_resolver_callback",
"_m",
"_m_ip",
"_py_attr_tracker",
"symbol_table",
]

@@ -230,6 +233,7 @@ def __init__(
context: Optional[Context] = None,
config_check: bool = True,
literal_resolver_callback: Optional[LiteralResolverCallback] = None,
py_attr_tracker: Optional[RefTracker] = None,
):
if module is not None:
assert context is None, "If configuring with a Module, context must be None"
@@ -241,7 +245,8 @@
if config_check:
# Production code can disable this for a bit of a boost.
self._config_check()
self._cc = ContextCache(self._c)
self._py_attr_tracker = py_attr_tracker or RefTracker()
self._cc = ContextCache(self._c, py_attr_tracker=self._py_attr_tracker)
self._m_ip = InsertionPoint(self._m.body)
self._literal_resolver_callback = literal_resolver_callback
self.symbol_table = SymbolTable(self._m.operation)
@@ -329,6 +334,7 @@ class ContextCache:
"_c",
"_dtype_to_type",
"_tensor_metadata_cache",
"_py_attr_tracker",
# Types.
"torch_bool_type",
"torch_float_type",
@@ -338,10 +344,13 @@
"torch_device_type",
]

def __init__(self, context: Context):
def __init__(
self, context: Context, *, py_attr_tracker: Optional[RefTracker] = None
):
self._c = context
self._dtype_to_type: Dict[TorchDtype, IrType] = {}
self._tensor_metadata_cache: Dict[Tuple[torch.Size, torch.dtype], IrType] = {}
self._py_attr_tracker = py_attr_tracker or RefTracker()

# Common types.
with context:
@@ -714,9 +723,9 @@ def _import_torch_op_overload(
if not self._c.is_registered_operation(mlir_op_name):
operation = Operation.create(
"torch.operator",
attributes={"name": StringAttr.get(mlir_op_name)},
results=result_types,
operands=operands,
attributes={"name": StringAttr.get(mlir_op_name)},
loc=loc,
)
else:
@@ -899,23 +908,35 @@ def create_mlir_tensor_type(tensor: torch.Tensor) -> IrType:
raise TypeError(f"Could not map Torch dtype {dtype} to an IREE type")


def _make_vtensor_literal_op(tensor: torch.Tensor, vtensor_type: IrType) -> Operation:
npy_dtype = TORCH_DTYPE_TO_NPY_TYPE.get(tensor.dtype)
assert (
npy_dtype is not None
), f"Can not create literal tensor for unsupported datatype: {tensor.dtype}"
# We need a raw buffer of data in order to create an ElementsAttr for the invocation of torch.vtensor.literal,
# but torch.Tensor does not fulfill the python buffer/array interface hence we must convert to a numpy array to get
# a raw buffer of our data. We can't call torch.Tensor.numpy() directly because this internally forces a call to
# detach() which throws an error as we are operating in a FakeTensorMode, hence the simplest way to get this raw
# buffer is via the indirection: Tensor -> list -> numpy array. This allows us to create a vtensor literal as
# desired, but also limits which data types we can support in this function (see TORCH_DTYPE_TO_NPY_TYPE above)
np_tensor = np.array(tensor.tolist()).astype(npy_dtype)
bytes = memoryview(np_tensor)
tensor_type = create_mlir_tensor_type(tensor)
elements_attr = DenseResourceElementsAttr.get_from_buffer(
bytes, "from_py", tensor_type
)
def _make_vtensor_literal_op(
tensor: torch.Tensor, vtensor_type: IrType, py_attr_tracker: RefTracker
) -> Operation:
mapping = py_attr_tracker.track(tensor)
if mapping.is_empty:
# Resolve the attribute.
npy_dtype = TORCH_DTYPE_TO_NPY_TYPE.get(tensor.dtype)
assert (
npy_dtype is not None
), f"Can not create literal tensor for unsupported datatype: {tensor.dtype}"
# We need a raw buffer of data in order to create an ElementsAttr for the invocation of torch.vtensor.literal,
# but torch.Tensor does not fulfill the python buffer/array interface hence we must convert to a numpy array to get
# a raw buffer of our data. We can't call torch.Tensor.numpy() directly because this internally forces a call to
# detach() which throws an error as we are operating in a FakeTensorMode, hence the simplest way to get this raw
# buffer is via the indirection: Tensor -> list -> numpy array. This allows us to create a vtensor literal as
# desired, but also limits which data types we can support in this function (see TORCH_DTYPE_TO_NPY_TYPE above)
np_tensor = np.array(tensor.tolist()).astype(npy_dtype)
bytes_view = memoryview(np_tensor)
tensor_type = create_mlir_tensor_type(tensor)
shape_desc = "_".join([str(d) for d in tensor.shape])
blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}"
elements_attr = DenseResourceElementsAttr.get_from_buffer(
bytes_view,
blob_name,
tensor_type,
)
mapping.value = elements_attr
else:
elements_attr = mapping.value
return Operation.create(
name="torch.vtensor.literal",
results=[vtensor_type],
@@ -957,7 +978,7 @@ def _make_vtensor_literal_op(tensor: torch.Tensor, vtensor_type: IrType) -> Oper
LITERAL_CONVERTER_MAP.map(
torch.Tensor,
lambda arg, gni, cc: _make_vtensor_literal_op(
arg, cc.tensor_to_vtensor_type(arg)
arg, cc.tensor_to_vtensor_type(arg), cc._py_attr_tracker
).result,
)
LITERAL_CONVERTER_MAP.map(
59 changes: 59 additions & 0 deletions python/shark_turbine/importers/utils.py
@@ -6,6 +6,8 @@

from typing import Any, Dict, List, Tuple

import weakref


class TypeSubclassMap:
"""Mapping of super-types to values.
@@ -41,3 +43,60 @@ def lookup(self, t: type) -> Any:
else:
self._cache[t] = None
return None


###############################################################################
# Reference mapping
###############################################################################


# Opaque value to indicate something is empty. Used in cases where 'None'
# may have a different meaning.
class EmptyType:
...


Empty = EmptyType()


class RefMapping:
__slots__ = [
"_referrent",
"value",
]

def __init__(self, referrent: Any):
if referrent is not Empty:
self._referrent = weakref.ref(referrent)
self.value = Empty

@property
def is_empty(self):
return self.value is Empty

def __repr__(self):
return (
f"<RefMapping {id(self._referrent) if self._referrent is not Empty else 'empty'} -> "
f"{self.value if self.value is not Empty else 'empty'}>"
)


class RefTracker:
"""Tracks live references from Python values to symbolic associations."""

def __init__(self):
self._refs: Dict[int, RefMapping] = {}

def track(self, referrent: Any) -> RefMapping:
ref_id = id(referrent)
existing = self._refs.get(ref_id)
if existing:
return existing
info = RefMapping(referrent)
if referrent is not Empty:
weakref.finalize(referrent, self._ref_finalizer, ref_id)
self._refs[ref_id] = info
return info

def _ref_finalizer(self, ref_id: int):
del self._refs[ref_id]
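
A small usage sketch of the RefTracker / RefMapping pair added above, assuming the shark_turbine package at this commit; the string value is just a stand-in for the DenseResourceElementsAttr that the fx importer actually interns:

    import torch
    from shark_turbine.importers.utils import RefTracker

    tracker = RefTracker()
    t = torch.zeros(2, 2)

    mapping = tracker.track(t)
    if mapping.is_empty:                # first time this object is seen
        mapping.value = "materialized"  # stand-in for the interned attribute

    assert tracker.track(t) is mapping  # same id() -> same mapping (dedup)
    del t  # the weakref finalizer then drops the entry from the tracker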