Support LayerWise for RTN/GPTQ (#1883)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
Co-authored-by: chensuyue <suyue.chen@intel.com>
Kaihui-intel and chensuyue authored Jul 16, 2024
1 parent de43d85 commit 649e6b1
Showing 13 changed files with 440 additions and 38 deletions.
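As background for the diffs below: the commit threads a use_layer_wise option (plus a model_path for locating the checkpoint on disk) through the GPTQ quantizer so that weights are loaded block by block instead of holding the whole FP32 model on the device. A minimal usage sketch follows, assuming the 3.x neural_compressor.torch API; the GPTQConfig field names mirror the parameters added in this diff, while the checkpoint name and calibration text are placeholders.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from neural_compressor.torch.quantization import GPTQConfig, convert, prepare

model_name = "facebook/opt-125m"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# use_layer_wise streams one transformer block at a time from the checkpoint;
# a production run would typically start from a meta-device ("empty") model,
# but from_pretrained keeps this sketch short.
quant_config = GPTQConfig(bits=4, use_layer_wise=True, model_path=model_name)

model = prepare(model, quant_config)
calib = tokenizer("A single calibration sample.", return_tensors="pt")
with torch.no_grad():
    model(calib["input_ids"])  # a real run would iterate a calibration set
model = convert(model)  # quantized per-layer tensors are staged under LWQ_WORKSPACE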
1 change: 1 addition & 0 deletions .azure-pipelines/scripts/codeScan/pylint/pylint.sh
@@ -20,6 +20,7 @@ apt-get install -y --no-install-recommends --fix-missing \
build-essential

pip install -r /neural-compressor/requirements.txt
pip install -r /neural-compressor/requirements_pt.txt
pip install cmake

pip install torch \
2 changes: 1 addition & 1 deletion neural_compressor/torch/algorithms/layer_wise/load.py
@@ -32,7 +32,7 @@
_open_zipfile_reader,
)

from neural_compressor.adaptor.torch_utils.layer_wise_quant import modified_pickle as pickle
from neural_compressor.torch.algorithms.layer_wise import modified_pickle as pickle

from .utils import torch

12 changes: 11 additions & 1 deletion neural_compressor/torch/algorithms/layer_wise/utils.py
@@ -27,10 +27,11 @@
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from neural_compressor.common import options
from neural_compressor.torch.algorithms.weight_only.modules import WeightOnlyLinear

from .load import load

LWQ_WORKSPACE = os.path.join(options.workspace, "layer_wise_tmp")
LWQ_WORKSPACE = os.path.join(options.workspace, "lwq_tmpdir")


class QDQLayer(torch.nn.Module):
@@ -215,6 +216,9 @@ def _get_path(pretrained_model_name_or_path):
return path


get_path = _get_path


def load_value(model, param_name, path):
if "lm_head" in param_name and getattr(model.config, "tie_word_embeddings", True):
input_embeddings = model.get_input_embeddings()
@@ -281,6 +285,12 @@ def clean_module_weight(module):
else:
submodule = module

if isinstance(module, WeightOnlyLinear):
for n, m in submodule._buffers.items():
old_value = getattr(submodule, n)
with torch.no_grad():
submodule._buffers[n] = torch.zeros(old_value.shape, device="meta")

for n, m in submodule.named_parameters():
is_buffer = n in submodule._buffers
old_value = getattr(submodule, n)
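The utils.py changes above give layer-wise mode what it needs for a load-use-release cycle per module: a public get_path alias, a clean_module_weight that also zeroes WeightOnlyLinear buffers onto the meta device, and the relocated LWQ_WORKSPACE staging directory. The sketch below illustrates that cycle; the load_value and clean_module_weight signatures are taken from this diff, while their export from the layer_wise package and the CPU target device are assumptions.

from accelerate.utils import set_module_tensor_to_device

from neural_compressor.torch.algorithms.layer_wise import (
    clean_module_weight,
    get_path,
    load_value,
)


def run_module_with_offload(model, module_name, module, inputs, model_name_or_path):
    """Materialize one module's weights, run it, then release the memory."""
    path = get_path(model_name_or_path)  # resolve the local/cached checkpoint dir
    # 1. Load only this module's tensors from disk onto the compute device.
    for param_name, _ in module.named_parameters():
        full_name = f"{module_name}.{param_name}"
        set_module_tensor_to_device(model, full_name, "cpu", load_value(model, full_name, path))
    # 2. Run the module while its weights are materialized.
    output = module(*inputs)
    # 3. Park the weights back on the meta device so they occupy no real memory.
    clean_module_weight(module)
    return output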
46 changes: 36 additions & 10 deletions neural_compressor/torch/algorithms/weight_only/gptq.py
@@ -230,11 +230,13 @@ def __init__(

# device
self.device = get_accelerator(kwargs.pop("device", "auto")).current_device_name()
self.model.to(self.device)
if not use_layer_wise:
self.model.to(self.device)
self.is_ready = False

self.use_layer_wise = use_layer_wise
self.model_path = model_path
if use_layer_wise:
self.prepare_layer_wise(model_path)

# dataloader
self.use_max_length = use_max_length
@@ -243,6 +245,20 @@
self.dataloader = []
self.nsamples = nsamples

def prepare_layer_wise(self, model_path):
import os

from neural_compressor.torch.algorithms.layer_wise import LWQ_WORKSPACE, get_path, register_weight_hooks

os.makedirs(LWQ_WORKSPACE, exist_ok=True)
if model_path == "":
model_path = self.model.path
assert model_path, "model_path should not be None."
self.model_path = get_path(model_path)
register_weight_hooks(
self.model, self.model_path, device=self.device, clean_weight=True, saved_path=LWQ_WORKSPACE
)

def get_full_layer_name(self, sub_layer_name, block_idx):
transformer_name = self.gptq_related_blocks["transformers_name"]
return ".".join([transformer_name, str(block_idx), sub_layer_name])
@@ -413,7 +429,6 @@ def execute_quantization(self, means=None, stds=None):
# Step1: prepare quantization (calibration datasets)

logger.info("Begin ====>")
model_path = self.model_path

# Step2: run gptq quantization in a transformer block-wise manner.
gptq_config = {}
@@ -450,7 +465,7 @@
if self.use_layer_wise: # pragma: no cover
from neural_compressor.torch.algorithms.layer_wise import load_value

W = load_value(self.model, full_layer_name + ".weight", model_path)
W = load_value(self.model, full_layer_name + ".weight", self.model_path)
else:
W = sub_layers[layer_name].weight.data.clone()

@@ -489,7 +504,7 @@ def tmp(_, inp, out):
from neural_compressor.torch.algorithms.layer_wise import load_value

full_layer_name = self.get_full_layer_name(layer_name, block_idx)
W = load_value(self.model, full_layer_name + ".weight", model_path)
W = load_value(self.model, full_layer_name + ".weight", self.model_path)
else:
W = sub_layers[layer_name].weight.data.clone()
accelerator.mark_step()
@@ -518,7 +533,7 @@ def tmp(_, inp, out):
if n == "weight":
set_module_tensor_to_device(self.model, param_name, self.device, Q)
else:
value = load_value(self.model, param_name, model_path)
value = load_value(self.model, param_name, self.model_path)
set_module_tensor_to_device(self.model, param_name, self.device, value)
# sub_layer.weight.data = Q
torch.save(sub_layer.state_dict(), LWQ_WORKSPACE + f"/{full_layer_name}.pt")
@@ -562,7 +577,13 @@ def tmp(_, inp, out):
gptq_perm = gptq_config[self.get_full_layer_name(layer_name, block_idx)]["perm"]
else:
gptq_perm = None
Q = sub_layers[layer_name].weight.data
if self.use_layer_wise:
state_dict = torch.load(LWQ_WORKSPACE + f"/{self.get_full_layer_name(layer_name, block_idx)}.pt")
Q = state_dict["weight"].data
bias = state_dict["bias"] if "bias" in state_dict.keys() else None

else:
Q = sub_layers[layer_name].weight.data
if weight_config_this_layer["act_order"]:
Q.copy_(Q[:, gptq_perm])
if is_transformers_imported() and isinstance(sub_layers[layer_name], transformers.Conv1D):
@@ -591,18 +612,21 @@ def tmp(_, inp, out):
scale = scale.t_().contiguous()
zp = zp.t_().contiguous() if zp is not None else zp

if not self.use_layer_wise:
bias = sub_layers[layer_name].bias

new_module = WeightOnlyLinear(
in_features,
out_features,
dtype=weight_config_this_layer["dtype"],
bits=weight_config_this_layer["bits"],
group_size=weight_config_this_layer["group_size"],
zp=gptq_zp is not None,
bias=sub_layers[layer_name].bias is not None,
bias=bias is not None,
g_idx=gptq_perm is not None,
device=self.device,
)
new_module.pack(int_weight, gptq_scale, gptq_zp, sub_layers[layer_name].bias, gptq_perm)
new_module.pack(int_weight, gptq_scale, gptq_zp, bias, gptq_perm)
set_module(transformer_block, layer_name, new_module)
del gptq_for_this_block
torch.cuda.empty_cache()
@@ -1019,8 +1043,10 @@ def prepare(
def convert(self, model, *args, **kwargs):
self.gptq_quantizer.model = model
self.gptq_quantizer.remove_prepare_for_calibration()

q_model, gptq_config = self.gptq_quantizer.execute_quantization()
q_model = q_model.to(self.model_device)
if not self.gptq_quantizer.use_layer_wise:
q_model = q_model.to(self.model_device)
q_model.gptq_config = gptq_config
logger.info("GPTQ quantizing done.")
return q_model
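Finally, prepare_layer_wise delegates the heavy lifting to register_weight_hooks. The following is an illustrative stand-in for that mechanism, not the shipped implementation: only the keyword names from the call in this diff (device, clean_weight, saved_path) and the load_value/clean_module_weight signatures are grounded here; the Linear-only filter and the hook bodies are simplified assumptions.

import os

import torch
from accelerate.utils import set_module_tensor_to_device

from neural_compressor.torch.algorithms.layer_wise import clean_module_weight, load_value


def register_weight_hooks_sketch(model, path, device="cpu", clean_weight=True, saved_path=None):
    """Simplified stand-in for register_weight_hooks (illustration only)."""
    handles = []
    for name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear):  # simplification: hook Linear layers only
            continue

        def load_weights(mod, args, name=name):
            # Forward pre-hook: pull this layer's tensors from disk just before it runs.
            for pname, _ in mod.named_parameters():
                value = load_value(model, f"{name}.{pname}", path)
                set_module_tensor_to_device(model, f"{name}.{pname}", device, value)

        def release_weights(mod, args, output, name=name):
            # Forward hook: optionally stage the materialized weights, then free them again.
            if saved_path is not None:
                torch.save(mod.state_dict(), os.path.join(saved_path, f"{name}.pt"))
            if clean_weight:
                clean_module_weight(mod)  # back to zero-memory meta tensors

        handles.append(module.register_forward_pre_hook(load_weights))
        handles.append(module.register_forward_hook(release_weights))
    return handles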