From 1f1f61fa211a5b432d17b3e755f1dcb59e25a9a0 Mon Sep 17 00:00:00 2001
From: Max Dawkins
Date: Thu, 22 Feb 2024 18:49:50 -0500
Subject: [PATCH 1/2] Add additional matching logic to MMGroupQuantRewriterPass

---
 .../transforms/quantization/mm_group_quant.py | 48 +++++++++++++++--
 core/shark_turbine/transforms/rewriter.py    | 51 +++++++++++++++++++
 .../custom_models/stateless_llama.py          |  8 ++-
 3 files changed, 103 insertions(+), 4 deletions(-)
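
Notes: the rewriter previously only recognized mms whose weight came through
Transpose2DMatcher's 2-D transpose pattern. This patch also accepts
torch.aten.permute with permutation [1, 0] (PermuteMatcher), adds
ViewTransposedMMMatcher to see through a dim-swapping torch.aten.view
between the permute and the weight global, and lets callers restrict
rewriting to an explicit parameter set. A minimal usage sketch of the new
keyword, mirroring the stateless_llama.py call site below (`inst`,
`param_set`, and `CompiledModule` are assumed from that file; names in
`param_set` are safetensor-style, without the "_params." prefix):

    from shark_turbine.transforms.quantization import mm_group_quant

    # Keep lm_head.weight unquantized, as stateless_llama.py does below.
    param_set.discard("lm_head.weight")
    mm_group_quant.MMGroupQuantRewriterPass(
        CompiledModule.get_mlir_module(inst).operation,
        group_size=128,  # default; a matched mm needs k % group_size == 0
        param_names=param_set,  # None (default) rewrites every matched mm
    ).run()
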
diff --git a/core/shark_turbine/transforms/quantization/mm_group_quant.py b/core/shark_turbine/transforms/quantization/mm_group_quant.py
index 6db522820..5c0e833b2 100644
--- a/core/shark_turbine/transforms/quantization/mm_group_quant.py
+++ b/core/shark_turbine/transforms/quantization/mm_group_quant.py
@@ -48,6 +48,8 @@ def __init__(self, globals: GlobalsDict, builder: Builder):
 
     def match(self, op: Operation):
         weight_transpose = Transpose2DMatcher()(op.operands[1])
+        if not weight_transpose:
+            weight_transpose = PermuteMatcher([1, 0])(op.operands[1])
         if not weight_transpose:
             return None
         weight_load = GlobalLoadMatcher(self.globals)(weight_transpose.input)
@@ -67,6 +69,38 @@ def match(self, op: Operation):
         )
 
 
+class ViewTransposedMMMatcher(NamedOpMatcher):
+    def __init__(self, globals: GlobalsDict, builder: Builder):
+        super().__init__("torch.aten.mm")
+        self.globals = globals
+        self.builder = builder
+
+    def match(self, op: Operation):
+        weight_transpose = Transpose2DMatcher()(op.operands[1])
+        if not weight_transpose:
+            weight_transpose = PermuteMatcher([1, 0])(op.operands[1])
+        if not weight_transpose:
+            return None
+        weight_view = Transposed2DViewMatcher(self.builder)(weight_transpose.input)
+        if not weight_view:
+            return None
+        weight_load = GlobalLoadMatcher(self.globals)(weight_view.input)
+        if not weight_load or not weight_load.resolved_global:
+            return None
+
+        m, n = self.builder.get_tensor_dims(op.operands[0].type)
+        _, k = self.builder.get_tensor_dims(op.operands[1].type)
+        return TransposedMMResult(
+            op,
+            weight_global=weight_load.resolved_global,
+            param_name=weight_load.global_ref,
+            m=m,
+            n=n,
+            k=k,
+            element_type=self.builder.get_tensor_element_type(op.operands[0].type),
+        )
+
+
 # TODO (ian): Make more generalizable using RenameParametersPass. Currently hardcoded for brevitas quantization
 GROUP_MATMUL_TEMPLATE = r"""
 module {{
@@ -125,14 +159,16 @@
 
 
 class MMGroupQuantRewriterPass(Pass):
-    def __init__(self, root_op: Operation, *, group_size: int = 128):
+    def __init__(self, root_op: Operation, *, group_size: int = 128, param_names: Optional[set] = None):
         super().__init__(root_op)
         self.group_size = group_size
         self.context = root_op.context
+        self.param_names = param_names
 
     def run(self):
         globals = self.globals
         mms = match_children(self.funcs, TransposedMMMatcher(globals, self.builder))
+        view_mms = match_children(self.funcs, ViewTransposedMMMatcher(globals, self.builder))
 
         for mr in mms:
             if mr.k is None or mr.n is None:
@@ -140,14 +176,20 @@
             if (mr.k % self.group_size) != 0:
                 continue
             self.rewrite(mr)
+
+        for mr in view_mms:
+            if mr.k is None or mr.n is None:
+                continue
+            if (mr.k % self.group_size) != 0 or (mr.n % self.group_size):
+                continue
+            self.rewrite(mr)
 
         self.inline()
         self.cleanup()
 
     def rewrite(self, mr: TransposedMMResult):
         none_to_q = lambda x: "?" if x is None else x
-        # TODO (ian): make generalizable and not specific for brevitas
-        if "lm_head.weight" not in mr.param_name:
+        if self.param_names is None or mr.param_name[8:] in self.param_names:
             inline_module_asm = GROUP_MATMUL_TEMPLATE.format(
                 # TODO (ian): Fix skipping the "_params." portion of the name to match safetensor format with RenameParametersPass
                 param_name=mr.param_name[8:],
diff --git a/core/shark_turbine/transforms/rewriter.py b/core/shark_turbine/transforms/rewriter.py
index 4bcb19ce5..3fce5c2b3 100644
--- a/core/shark_turbine/transforms/rewriter.py
+++ b/core/shark_turbine/transforms/rewriter.py
@@ -38,6 +38,8 @@
     "OpMatchResult",
     "Pass",
     "Transpose2DMatcher",
+    "PermuteMatcher",
+    "Transposed2DViewMatcher",
     "match_children",
     "pass_main",
 ]
@@ -194,6 +196,55 @@ def match(self, op: Operation) -> Optional[Transpose2DResult]:
         return result
 
 
+class PermuteResult(OpMatchResult):
+    @property
+    def input(self) -> Value:
+        return self.op.operands[0]
+
+
+class PermuteMatcher(NamedOpMatcher):
+    def __init__(self, permutation: list[int]):
+        super().__init__("torch.aten.permute")
+        self.permutation = permutation
+
+    def match(self, op: Operation) -> Optional[PermuteResult]:
+        list_construct = NamedOpMatcher("torch.prim.ListConstruct")(op.operands[1])
+        if not list_construct:
+            return None
+        list_construct = list_construct.op
+        if len(self.permutation) != len(list_construct.operands):
+            return None
+        for i, list_item in enumerate(list_construct.operands):
+            if not ConstantIntMatcher(self.permutation[i])(list_item):
+                return None
+        return PermuteResult(op)
+
+
+class Transposed2DViewResult(OpMatchResult):
+    @property
+    def input(self) -> Value:
+        return self.op.operands[0]
+
+
+class Transposed2DViewMatcher(NamedOpMatcher):
+    def __init__(self, builder: Builder):
+        super().__init__("torch.aten.view")
+        self.builder = builder
+
+    def match(self, op: Operation) -> Optional[Transposed2DViewResult]:
+        list_construct = NamedOpMatcher("torch.prim.ListConstruct")(op.operands[1])
+        if not list_construct:
+            return None
+        list_construct = list_construct.op
+        if len(list_construct.operands) != 2:
+            return None
+        tensor_dims = self.builder.get_tensor_dims(op.operands[0].type)
+        if not ConstantIntMatcher(tensor_dims[0])(list_construct.operands[1]) or \
+            not ConstantIntMatcher(tensor_dims[1])(list_construct.operands[0]):
+            return None
+        return Transposed2DViewResult(op)
+
+
 class ConstantIntMatcher(NamedOpMatcher):
     def __init__(self, value: int):
         super().__init__("torch.constant.int")
diff --git a/models/turbine_models/custom_models/stateless_llama.py b/models/turbine_models/custom_models/stateless_llama.py
index 762690603..ac735cda3 100644
--- a/models/turbine_models/custom_models/stateless_llama.py
+++ b/models/turbine_models/custom_models/stateless_llama.py
@@ -140,11 +140,13 @@ def export_transformer_model(
     )
 
     mapper = {}
+    param_set = set()
    if external_weights is not None:
         if external_weights == "safetensors":
             mod_params = dict(mod.named_parameters())
             for name in mod_params:
                 mapper["params." + name] = name
+                param_set.add(name)
 
         if external_weight_file:
             safetensors.torch.save_file(mod_params, external_weight_file)
@@ -313,8 +315,12 @@ def evict_kvcache_space(self):
     if quantization == "int4" and not compile_to == "linalg":
         from shark_turbine.transforms.quantization import mm_group_quant
 
+        print(param_set)
+        if "lm_head.weight" in param_set:
+            param_set.remove("lm_head.weight")
         mm_group_quant.MMGroupQuantRewriterPass(
-            CompiledModule.get_mlir_module(inst).operation
+            CompiledModule.get_mlir_module(inst).operation,
+            param_names=param_set,
         ).run()
     module_str = str(CompiledModule.get_mlir_module(inst))
     safe_name = hf_model_name.split("/")[-1].strip()
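
For reference, the new matchers compose on their own as well; a sketch of
walking the pattern ViewTransposedMMMatcher expects, assuming a torch-mlir
torch.aten.mm operation `op` and a rewriter Builder instance `builder`
(both hypothetical here):

    from shark_turbine.transforms.rewriter import (
        PermuteMatcher,
        Transposed2DViewMatcher,
    )

    # Fires only when operand 1 is a torch.aten.permute whose size list is
    # the constant permutation [1, 0].
    permuted = PermuteMatcher([1, 0])(op.operands[1])
    if permuted:
        # Fires only when the permute input is a torch.aten.view whose
        # target sizes are the source dims swapped, e.g. [m, n] -> [n, m].
        viewed = Transposed2DViewMatcher(builder)(permuted.input)
        if viewed:
            weight_source = viewed.input  # walk back toward the weight global
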
From 373c068fd64a9753761a7c572b4bfd8201190a70 Mon Sep 17 00:00:00 2001
From: Max Dawkins
Date: Thu, 22 Feb 2024 18:57:33 -0500
Subject: [PATCH 2/2] black

---
 .../transforms/quantization/mm_group_quant.py | 14 +++++++++++---
 core/shark_turbine/transforms/rewriter.py    |  5 +++--
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/core/shark_turbine/transforms/quantization/mm_group_quant.py b/core/shark_turbine/transforms/quantization/mm_group_quant.py
index 5c0e833b2..8106f823f 100644
--- a/core/shark_turbine/transforms/quantization/mm_group_quant.py
+++ b/core/shark_turbine/transforms/quantization/mm_group_quant.py
@@ -159,7 +159,13 @@
 
 
 class MMGroupQuantRewriterPass(Pass):
-    def __init__(self, root_op: Operation, *, group_size: int = 128, param_names: Optional[set] = None):
+    def __init__(
+        self,
+        root_op: Operation,
+        *,
+        group_size: int = 128,
+        param_names: Optional[set] = None,
+    ):
         super().__init__(root_op)
         self.group_size = group_size
         self.context = root_op.context
@@ -168,7 +174,9 @@ def __init__(self, root_op: Operation, *, group_size: int = 128, param_names: Op
     def run(self):
         globals = self.globals
         mms = match_children(self.funcs, TransposedMMMatcher(globals, self.builder))
-        view_mms = match_children(self.funcs, ViewTransposedMMMatcher(globals, self.builder))
+        view_mms = match_children(
+            self.funcs, ViewTransposedMMMatcher(globals, self.builder)
+        )
 
         for mr in mms:
             if mr.k is None or mr.n is None:
@@ -176,7 +184,7 @@
             if (mr.k % self.group_size) != 0:
                 continue
             self.rewrite(mr)
-
+
         for mr in view_mms:
             if mr.k is None or mr.n is None:
                 continue
diff --git a/core/shark_turbine/transforms/rewriter.py b/core/shark_turbine/transforms/rewriter.py
index 3fce5c2b3..2a8fea752 100644
--- a/core/shark_turbine/transforms/rewriter.py
+++ b/core/shark_turbine/transforms/rewriter.py
@@ -239,8 +239,9 @@ def match(self, op: Operation) -> Optional[Transposed2DViewResult]:
         if len(list_construct.operands) != 2:
             return None
         tensor_dims = self.builder.get_tensor_dims(op.operands[0].type)
-        if not ConstantIntMatcher(tensor_dims[0])(list_construct.operands[1]) or \
-            not ConstantIntMatcher(tensor_dims[1])(list_construct.operands[0]):
+        if not ConstantIntMatcher(tensor_dims[0])(
+            list_construct.operands[1]
+        ) or not ConstantIntMatcher(tensor_dims[1])(list_construct.operands[0]):
             return None
         return Transposed2DViewResult(op)
 
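
A quick sanity check of the group-size guard in run() (illustrative numbers,
not taken from the patch): with the default group_size of 128, a 4096x11008
weight qualifies on the view-matched path, while e.g. k = 4100 would be
skipped:

    group_size = 128
    k, n = 4096, 11008
    # mms loop: requires k % group_size == 0
    # view_mms loop: additionally requires n % group_size == 0
    assert (k % group_size) == 0 and (n % group_size) == 0
    assert (4100 % group_size) != 0  # such an mm is left unrewritten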