Fix transformers rtn layer-wise quant (#2008)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Kaihui-intel and pre-commit-ci[bot] authored Sep 30, 2024
1 parent 802a5af commit a0066d4
Showing 7 changed files with 78 additions and 18 deletions.
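For reference, a minimal sketch of the layer-wise RTN weight-only quantization flow this commit repairs, following the API exercised in the test added below (`RtnConfig` with `use_layer_wise=True` and the transformers-style `AutoModelForCausalLM`). The import path is assumed to match the test file's, and the model name and output directory are placeholders, not taken from the commit.

```python
# Minimal sketch (not part of the commit): layer-wise RTN quantization through the
# neural_compressor transformers-style API, mirroring the added test case.
from neural_compressor.transformers import AutoModelForCausalLM, RtnConfig

model_name = "facebook/opt-125m"  # placeholder model id

# use_layer_wise=True loads and quantizes weights module by module,
# which is the code path fixed by this commit.
woq_config = RtnConfig(bits=4, group_size=16, use_layer_wise=True)
woq_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=woq_config,
)

# Saving and reloading the quantized model is the recommended way to run inference
# when layer-wise quantization is used (see the warning added in
# neural_compressor/transformers/quantization/utils.py below).
woq_model.save_pretrained("./transformers_tmp")  # placeholder output dir
loaded_model = AutoModelForCausalLM.from_pretrained("./transformers_tmp")
```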
@@ -116,13 +116,18 @@ PyTorch and Intel Extension for PyTorch versions > 2.1 are required for Intel GPU
```bash
pip install -r requirements_GPU.txt
pip install transformers==4.38.1 # llama use 4.38.1
source /opt/intel/oneapi/setvars.sh
git clone https://github.com/intel/intel-extension-for-pytorch.git ipex-gpu
cd ipex-gpu
git submodule update --init --recursive
export USE_AOT_DEVLIST='pvc,ats-m150'
export BUILD_WITH_CPU=OFF

export LD_LIBRARY_PATH=${CONDA_PREFIX}/lib/:$LD_LIBRARY_PATH
export OCL_ICD_VENDORS=/etc/OpenCL/vendors
export CCL_ROOT=${CONDA_PREFIX}
source /opt/intel/oneapi/setvars.sh --force
export LLM_ACC_TEST=1

python setup.py install
```
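A quick sanity check (not part of the original instructions) can confirm that the source build exposes the XPU backend; this assumes the standard `intel_extension_for_pytorch` import path.

```python
# Optional sanity check, not from the original instructions: verify that the
# freshly built IPEX registers the torch.xpu backend and an Intel GPU is visible.
import torch
import intel_extension_for_pytorch as ipex  # importing IPEX registers torch.xpu

print(ipex.__version__)
print(torch.xpu.is_available(), torch.xpu.device_count())
```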

@@ -200,7 +200,7 @@
tokenizer.save_pretrained(args.output_dir)

enable_optimize_transformers = False
opt_gpu_model_type_list = ["llama", "gptj", "mistral", "qwen"]
opt_gpu_model_type_list = ["llama", "gptj", "mistral", "qwen", "phi3"]

if config.model_type in opt_gpu_model_type_list:
enable_optimize_transformers = True
10 changes: 3 additions & 7 deletions neural_compressor/torch/algorithms/weight_only/rtn.py
@@ -130,17 +130,16 @@ def convert(

if use_layer_wise:
from neural_compressor.common.utils import DEFAULT_WORKSPACE
from neural_compressor.torch.algorithms.layer_wise.utils import get_path, load_module, register_weight_hooks
from neural_compressor.torch.algorithms.layer_wise.utils import get_path, load_module

if model_path == "":
model_path = model.path
assert model_path, "model_path should not be None."
model_path = get_path(model_path)

register_weight_hooks(model, model_path, device=device, clean_weight=True)

for name, m in model.named_modules():

if use_layer_wise and len(list(m.named_children())) == 0:
load_module(model, name, model_path, device=device)
if not isinstance(m, supported_layers):
continue
if name in weight_config: # pragma: no cover
@@ -192,9 +191,6 @@ def convert(
logger.debug(f"RTN quantized module:{name, m}")
logger.debug(log_msg)

if use_layer_wise:
load_module(model, name, model_path, device=device)

# for only group_dim is 0 or only `transformers.Conv1D`, we need transpose weight.
if is_transformers_imported():
transpose = (group_dim == 0) ^ (isinstance(m, transformers.Conv1D))
4 changes: 2 additions & 2 deletions neural_compressor/torch/utils/utility.py
@@ -331,11 +331,11 @@ def load_empty_model(pretrained_model_name_or_path, cls=None, **kwargs):
if cls.__base__ == _BaseAutoModelClass:
config = AutoConfig.from_pretrained(path, **kwargs)
with init_empty_weights():
model = cls.from_config(config)
model = cls.from_config(config, **kwargs)
else: # pragma: no cover
config = cls.config_class.from_pretrained(path, **kwargs)
with init_empty_weights():
model = cls(config)
model = cls(config, **kwargs)
model.tie_weights()
model.eval()
model.path = pretrained_model_name_or_path
28 changes: 27 additions & 1 deletion neural_compressor/transformers/models/modeling_auto.py
@@ -134,7 +134,33 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
(RtnConfig, AwqConfig, TeqConfig, GPTQConfig, AutoRoundConfig),
):
logger.info("Applying Weight Only Quantization.")
if use_xpu:
# set use_layer_wise on client
if hasattr(quantization_config, "use_layer_wise"):
import neural_compressor.torch.utils as torch_utils

process_type = torch_utils.get_processor_type_from_user_config()
if process_type == torch_utils.ProcessorType.Client:
quantization_config.use_layer_wise = True

if hasattr(quantization_config, "use_layer_wise") and quantization_config.use_layer_wise:
from transformers.dynamic_module_utils import resolve_trust_remote_code

from neural_compressor.torch import load_empty_model

trust_remote_code = kwargs.get("trust_remote_code", None)
has_remote_code = hasattr(config, "auto_map") and cls.ORIG_MODEL.__name__ in config.auto_map
has_local_code = type(config) in cls.ORIG_MODEL._model_mapping.keys()
trust_remote_code = resolve_trust_remote_code(
trust_remote_code,
pretrained_model_name_or_path,
has_local_code,
has_remote_code,
)

model = load_empty_model(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
if use_cpu:
quantization_config.post_init_cpu()
elif use_xpu:
# TODO: if low_cpu_mem_usage is True, gptj will have an accuracy issue on the CPU device.
kwargs["low_cpu_mem_usage"] = True
kwargs["device_map"] = "cpu"
12 changes: 6 additions & 6 deletions neural_compressor/transformers/quantization/utils.py
@@ -153,7 +153,6 @@ def _replace_linear(
"fp16": ipex.quantization.WoqLowpMode.FP16,
"int8": ipex.quantization.WoqLowpMode.INT8,
}

ipex_qconfig_mapping = ipex.quantization.get_weight_only_quant_qconfig_mapping(
weight_dtype=weight_dtype[quantization_config.bits],
lowp_mode=compute_dtype[quantization_config.compute_dtype],
@@ -366,11 +365,6 @@ def convert_to_quantized_model(model, config, device="cpu"):

# mapping to INC config
dtype = "int4" if config.weight_dtype == "int4_fullrange" else config.weight_dtype
import neural_compressor.torch.utils as torch_utils

process_type = torch_utils.get_processor_type_from_user_config()
if process_type == torch_utils.ProcessorType.Client:
config.use_layer_wise = True
if config.quant_method.value == "rtn":
quant_config = RTNConfig(
dtype=dtype,
@@ -529,6 +523,12 @@ def convert_to_quantized_model(model, config, device="cpu"):
if orig_dtype != torch.float32:
q_model.to(dtype=orig_dtype)

if config.use_layer_wise and not (q_model.device == device or q_model.device.type == device):
logger.warning(
"Do not convert device to avoid out of memory. Recommend using saved quantized model to inference."
)
return q_model

return q_model.to(device)


33 changes: 33 additions & 0 deletions test/3x/torch/quantization/weight_only/test_transfomers.py
@@ -115,6 +115,39 @@ def test_save_load(self):
loaded_output = loaded_model(dummy_input)[0]
assert torch.equal(woq_output, loaded_output), "loaded output should be the same. Please double check."

def test_use_layer_wise(self):
model_name_or_path = self.model_name_or_path

fp32_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
dummy_input = fp32_model.dummy_inputs["input_ids"]

# RTN
# use_layer_wise=True
woq_config = RtnConfig(bits=4, group_size=16, use_layer_wise=True)
woq_model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
quantization_config=woq_config,
)
woq_output = woq_model(dummy_input)[0]

# save
output_dir = "./transformers_tmp"
woq_model.save_pretrained(output_dir)

# load
loaded_model = AutoModelForCausalLM.from_pretrained(output_dir)
loaded_output = loaded_model(dummy_input)[0]
assert torch.equal(woq_output, loaded_output), "loaded output should be the same. Please double check."

# use_layer_wise=False
woq_config = RtnConfig(bits=4, group_size=16, use_layer_wise=False)
woq_model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
quantization_config=woq_config,
)
woq_output2 = woq_model(dummy_input)[0]
assert torch.equal(woq_output, woq_output2), "use_layer_wise output should be the same. Please double check."

def test_loading_autoawq_model(self):
user_model = AutoModelForCausalLM.from_pretrained(self.autoawq_model)
tokenizer = AutoTokenizer.from_pretrained(self.autoawq_model)