Add additional matching logic to MMGroupQuantRewriterPass #473
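Summary: the quantization rewriter previously only recognized matmuls whose weight came through a plain 2D transpose of a global. This PR (1) lets the existing matcher also accept a [1, 0] torch.aten.permute, (2) adds a ViewTransposedMMMatcher for weights that are re-viewed to their transposed shape between the global load and the transpose/permute, and (3) adds an optional param_names set so callers control which parameters get rewritten. A minimal usage sketch (module_op and the parameter name below are hypothetical; the real call site is in stateless_llama.py further down):

```python
from shark_turbine.transforms.quantization import mm_group_quant

# Hypothetical inputs: module_op is the MLIR operation of a compiled module,
# param_names lists the safetensor parameter names eligible for group quantization.
mm_group_quant.MMGroupQuantRewriterPass(
    module_op,
    group_size=128,
    param_names={"model.layers.0.self_attn.q_proj.weight"},
).run()
```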

Open · wants to merge 2 commits into main
56 changes: 53 additions & 3 deletions core/shark_turbine/transforms/quantization/mm_group_quant.py
@@ -48,6 +48,8 @@ def __init__(self, globals: GlobalsDict, builder: Builder):

    def match(self, op: Operation):
        weight_transpose = Transpose2DMatcher()(op.operands[1])
        if not weight_transpose:
            weight_transpose = PermuteMatcher([1, 0])(op.operands[1])
        if not weight_transpose:
            return None
        weight_load = GlobalLoadMatcher(self.globals)(weight_transpose.input)
@@ -67,6 +69,38 @@ def match(self, op: Operation):
        )
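The hunk above makes the existing TransposedMMMatcher fall back to a [1, 0] permute when the weight operand is not produced by a 2D transpose. At the PyTorch level, both forms below compute the same matmul and (depending on how the exporter lowers them) should now both be matched; a small sketch:

```python
import torch

x = torch.randn(4, 16)
w = torch.randn(8, 16)

y1 = torch.mm(x, w.t())            # transpose form (existing Transpose2DMatcher path)
y2 = torch.mm(x, w.permute(1, 0))  # permute form (new PermuteMatcher([1, 0]) path)
assert torch.allclose(y1, y2)
```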


class ViewTransposedMMMatcher(NamedOpMatcher):
    def __init__(self, globals: GlobalsDict, builder: Builder):
        super().__init__("torch.aten.mm")
        self.globals = globals
        self.builder = builder

    def match(self, op: Operation):
        weight_transpose = Transpose2DMatcher()(op.operands[1])
        if not weight_transpose:
            weight_transpose = PermuteMatcher([1, 0])(op.operands[1])
        if not weight_transpose:
            return None
        weight_view = Transposed2DViewMatcher(self.builder)(weight_transpose.input)
        if not weight_view:
            return None
        weight_load = GlobalLoadMatcher(self.globals)(weight_view.input)
        if not weight_load or not weight_load.resolved_global:
            return None

        m, n = self.builder.get_tensor_dims(op.operands[0].type)
        _, k = self.builder.get_tensor_dims(op.operands[1].type)
        return TransposedMMResult(
            op,
            weight_global=weight_load.resolved_global,
            param_name=weight_load.global_ref,
            m=m,
            n=n,
            k=k,
            element_type=self.builder.get_tensor_element_type(op.operands[0].type),
        )
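ViewTransposedMMMatcher walks one extra link in the chain: mm -> transpose/permute -> view -> global load, accepting the view only when it reshapes a 2D global to its swapped dimensions (see Transposed2DViewMatcher below). A rough PyTorch-level sketch of the kind of graph this targets, with hypothetical shapes:

```python
import torch

x = torch.randn(4, 16)
p = torch.randn(16, 8)  # 2D weight parameter (the resolved global)

# The weight is re-viewed to the swapped shape, then permuted back before the matmul,
# so the op chain is mm -> permute([1, 0]) -> view([8, 16]) -> parameter.
y = torch.mm(x, p.view(8, 16).permute(1, 0))
```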


# TODO (ian): Make more generalizable using RenameParametersPass. Currently hardcoded for brevitas quantization
GROUP_MATMUL_TEMPLATE = r"""
module {{
@@ -125,14 +159,24 @@ def match(self, op: Operation):


class MMGroupQuantRewriterPass(Pass):
    def __init__(self, root_op: Operation, *, group_size: int = 128):
    def __init__(
        self,
        root_op: Operation,
        *,
        group_size: int = 128,
        param_names: Optional[set] = None,
    ):
        super().__init__(root_op)
        self.group_size = group_size
        self.context = root_op.context
        self.param_names = param_names

    def run(self):
        globals = self.globals
        mms = match_children(self.funcs, TransposedMMMatcher(globals, self.builder))
        view_mms = match_children(
            self.funcs, ViewTransposedMMMatcher(globals, self.builder)
        )

        for mr in mms:
            if mr.k is None or mr.n is None:
@@ -141,13 +185,19 @@ def run(self):
                continue
            self.rewrite(mr)

        for mr in view_mms:
            if mr.k is None or mr.n is None:
                continue
            if (mr.k % self.group_size) != 0 or (mr.n % self.group_size):
                continue
            self.rewrite(mr)

        self.inline()
        self.cleanup()

    def rewrite(self, mr: TransposedMMResult):
        none_to_q = lambda x: "?" if x is None else x
        # TODO (ian): make generalizable and not specific for brevitas
        if "lm_head.weight" not in mr.param_name:
        if self.param_names is None or mr.param_name[8:] in self.param_names:
            inline_module_asm = GROUP_MATMUL_TEMPLATE.format(
                # TODO (ian): Fix skipping the "_params." portion of the name to match safetensor format with RenameParametersPass
                param_name=mr.param_name[8:],
52 changes: 52 additions & 0 deletions core/shark_turbine/transforms/rewriter.py
@@ -38,6 +38,8 @@
"OpMatchResult",
"Pass",
"Transpose2DMatcher",
"PermuteMatcher",
"Transposed2DViewMatcher",
"match_children",
"pass_main",
]
@@ -194,6 +196,56 @@ def match(self, op: Operation) -> Optional[Transpose2DResult]:
        return result


class PermuteResult(OpMatchResult):
    @property
    def input(self) -> Value:
        return self.op.operands[0]


class PermuteMatcher(NamedOpMatcher):
    def __init__(self, permutation: list[int]):
        super().__init__("torch.aten.permute")
        self.permutation = permutation

    def match(self, op: Operation) -> Optional[PermuteResult]:
        list_construct = NamedOpMatcher("torch.prim.ListConstruct")(op.operands[1])
        if not list_construct:
            return None
        list_construct = list_construct.op
        if len(self.permutation) != len(list_construct.operands):
            return None
        for i, list_item in enumerate(list_construct.operands):
            if not ConstantIntMatcher(self.permutation[i])(list_item):
                return None
        return PermuteResult(op)
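PermuteMatcher succeeds only when the op's dims operand is a torch.prim.ListConstruct of constant ints equal to the expected permutation. A sketch of how the quantization matchers above invoke it (mm_op is hypothetical, standing in for a matched torch.aten.mm operation):

```python
# Hypothetical: mm_op.operands[1] is the weight operand of a torch.aten.mm.
permute = PermuteMatcher([1, 0])(mm_op.operands[1])
if permute:
    weight_source = permute.input  # the value being permuted
```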


class Transposed2DViewResult(OpMatchResult):
    @property
    def input(self) -> Value:
        return self.op.operands[0]


class Transposed2DViewMatcher(NamedOpMatcher):
    def __init__(self, builder: Builder):
        super().__init__("torch.aten.view")
        self.builder = builder

    def match(self, op: Operation) -> Optional[Transposed2DViewResult]:
        list_construct = NamedOpMatcher("torch.prim.ListConstruct")(op.operands[1])
        if not list_construct:
            return None
        list_construct = list_construct.op
        if len(list_construct.operands) != 2:
            return None
        tensor_dims = self.builder.get_tensor_dims(op.operands[0].type)
        if not ConstantIntMatcher(tensor_dims[0])(
            list_construct.operands[1]
        ) or not ConstantIntMatcher(tensor_dims[1])(list_construct.operands[0]):
            return None
        return Transposed2DViewResult(op)
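Transposed2DViewMatcher only accepts a two-element view whose target sizes are the input's dims swapped, i.e. a (d0, d1) tensor viewed as (d1, d0). The shape test it performs, restated as plain Python for illustration (the 4096/11008 numbers are made-up example dims):

```python
# Illustrative re-statement of the matcher's dimension check, not library code.
def is_transposed_view(input_dims, view_sizes):
    return len(view_sizes) == 2 and list(view_sizes) == [input_dims[1], input_dims[0]]

assert is_transposed_view([4096, 11008], [11008, 4096])
assert not is_transposed_view([4096, 11008], [4096, 11008])
```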


class ConstantIntMatcher(NamedOpMatcher):
    def __init__(self, value: int):
        super().__init__("torch.constant.int")
8 changes: 7 additions & 1 deletion models/turbine_models/custom_models/stateless_llama.py
@@ -140,11 +140,13 @@ def export_transformer_model(
    )

    mapper = {}
    param_set = set()
    if external_weights is not None:
        if external_weights == "safetensors":
            mod_params = dict(mod.named_parameters())
            for name in mod_params:
                mapper["params." + name] = name
                param_set.add(name)
            if external_weight_file:
                safetensors.torch.save_file(mod_params, external_weight_file)

@@ -313,8 +315,12 @@ def evict_kvcache_space(self):
    if quantization == "int4" and not compile_to == "linalg":
        from shark_turbine.transforms.quantization import mm_group_quant

        print(param_set)
if "lm_head.weight" in param_set:
param_set.remove("lm_head.weight")
mm_group_quant.MMGroupQuantRewriterPass(
CompiledModule.get_mlir_module(inst).operation
CompiledModule.get_mlir_module(inst).operation,
param_names=param_set,
).run()
module_str = str(CompiledModule.get_mlir_module(inst))
safe_name = hf_model_name.split("/")[-1].strip()