support qqq(w4a8) for lmdeploy
fix llama_kernels compile issue on windows

fix conflicts

format
HandH1998 committed Aug 12, 2024
1 parent c685f77 commit 779caa8
Showing 44 changed files with 2,891 additions and 292 deletions.
5 changes: 3 additions & 2 deletions lmdeploy/cli/utils.py
@@ -106,9 +106,10 @@ def model_format(parser, default: str = None):
'--model-format',
type=str,
default=default,
choices=['hf', 'llama', 'awq'],
choices=['hf', 'llama', 'awq', 'qqq'],
help='The format of input model. `hf` meaning `hf_llama`, `llama` '
'meaning `meta_llama`, `awq` meaning the quantized model by awq')
'meaning `meta_llama`, `awq` meaning the quantized model by awq, '
'`qqq` meaning the quantized model by qqq')

@staticmethod
def revision(parser, default: str = None):
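For reference, the new choice can be exercised on its own. A minimal standalone argparse sketch that mirrors the option above (it does not go through lmdeploy's real CLI entry point; the parsed value is just printed):

import argparse

# Mirror of the `--model-format` option from lmdeploy/cli/utils.py.
parser = argparse.ArgumentParser('model-format-sketch')
parser.add_argument(
    '--model-format',
    type=str,
    default=None,
    choices=['hf', 'llama', 'awq', 'qqq'],
    help='The format of input model')

args = parser.parse_args(['--model-format', 'qqq'])
print(args.model_format)  # -> qqq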
5 changes: 3 additions & 2 deletions lmdeploy/messages.py
@@ -115,8 +115,9 @@ class TurbomindEngineConfig:
"""TurboMind Engine config.
Args:
model_format (str): the layout of the deployed model. It can be one of the following values [hf, meta_llama, awq],
`hf` meaning huggingface model(.bin, .safetensors), `meta_llama` being meta llama's format(.pth), awq` meaning the quantized model by AWQ.
model_format (str): the layout of the deployed model. It can be one of the following values [hf, meta_llama, awq, qqq],
`hf` meaning huggingface model(.bin, .safetensors), `meta_llama` being meta llama's format(.pth), `awq` meaning the quantized model by AWQ,
`qqq` meaning the quantized model by QQQ.
tp (int): the number of GPU cards used in tensor parallelism, default to 1
session_len (int): the max session length of a sequence, default to None
max_batch_size (int): the max batch size during inference, default to 128
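A hedged usage sketch of the new engine option from the Python API; the checkpoint path is a placeholder for a Llama model quantized with https://github.com/HandH1998/QQQ:

from lmdeploy import pipeline, TurbomindEngineConfig

# Ask TurboMind to load the checkpoint as a QQQ (W4A8) quantized model.
engine_config = TurbomindEngineConfig(model_format='qqq', tp=1)
pipe = pipeline('path/to/llama-qqq-w4a8', backend_config=engine_config)
print(pipe(['Introduce yourself.']))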
36 changes: 25 additions & 11 deletions lmdeploy/turbomind/deploy/converter.py
@@ -15,7 +15,7 @@
from .source_model.base import INPUT_MODELS
from .target_model.base import OUTPUT_MODELS, TurbomindModelConfig

SUPPORTED_FORMATS = ['meta_llama', 'hf', 'awq', None]
SUPPORTED_FORMATS = ['meta_llama', 'hf', 'awq', 'qqq', None]
logger = get_logger('lmdeploy')


@@ -26,12 +26,14 @@ def get_input_model_registered_name(model_path: str, model_format: str):
Args:
model_path (str): the path of the input model
model_format (str): the format of the model, which can be one of
['meta_llama', 'hf', 'awq']
['meta_llama', 'hf', 'awq', 'qqq']
"""
arch = get_model_arch(model_path)[0]
register_name = SUPPORTED_ARCHS[arch]
if model_format == 'awq':
register_name = register_name + '-awq'
elif model_format == 'qqq':
register_name = register_name + '-qqq'
return register_name


@@ -92,8 +94,9 @@ def get_output_model_registered_name_and_config(model_path: str,
Args:
model_path (str): the path of the input model
model_format (str): the format of the model, which can be one of
['meta_llama', 'hf', 'awq']
group_size (int): the size of group used by awq model
['meta_llama', 'hf', 'awq', 'qqq']
group_size (int): the size of group used by quantization methods,
including `awq` and `qqq`
"""
register_name = 'fp16'
turbomind_model_arch = 'llama'
@@ -113,6 +116,15 @@ def get_output_model_registered_name_and_config(model_path: str,
register_name = 'plora-w4' \
if turbomind_model_arch == 'xcomposer2' else 'w4'
group_size = 128 if group_size == 0 else group_size
config.quantization = 'awq'
elif model_format == 'qqq':
weight_type = 'int4'
register_name = 'qqq-w4'
from transformers import AutoConfig
quant_config = AutoConfig.from_pretrained(
model_path).quantization_config
group_size = quant_config['group_size']
config.quantization = 'qqq'
else:
torch_dtype = getattr(model_config, 'torch_dtype', 'float16')
TORCH_DTYPE_MAP = {torch.bfloat16: 'bf16', torch.float16: 'fp16'}
@@ -212,17 +224,19 @@ def main(model_name: str,
model_name (str): unused any longer
model_path (str): the directory path of the model
model_format (str): the format of the model, should choose from
['meta_llama', 'hf', 'awq', None]. 'meta_llama' stands for META's
llama format, 'hf' means huggingface llama format, and 'awq' means
llama(hf) model quantized by lmdeploy/lite/quantization/awq.py.
The default value is None
chat_template (str): the name of the built-in chat template.
['meta_llama', 'hf', 'awq', 'qqq', None]. 'meta_llama' stands for
META's llama format, 'hf' means huggingface llama format,
'awq' means llama(hf) model quantized by
lmdeploy/lite/quantization/awq.py,
and 'qqq' means llama(hf) model quantized by the repo
https://github.com/HandH1998/QQQ,
the default value is None
tokenizer_path (str): the path of tokenizer model
dst_path (str): the destination path that saves outputs
tp (int): the number of GPUs used for tensor parallelism, should be 2^n
quant_path (str): Path of the quantized model, which can be None.
group_size (int): a parameter used in AWQ to quantize fp16 weights
to 4 bits
group_size (int): a parameter used in AWQ or QQQ to quantize fp16
weights to 4 bits
revision (str): The specific model version to use. It can be a branch
name, a tag name, or a commit id. If unspecified, will use
the default version.
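For a QQQ checkpoint the converter registers the input model as `llama-qqq` and the output model as `qqq-w4`, and, unlike AWQ, it takes the group size from the checkpoint's Hugging Face quantization config rather than from the command line. A sketch of that lookup with a placeholder model path:

from transformers import AutoConfig

# QQQ stores its settings in the HF config of the quantized checkpoint;
# `group_size` drives the layout of the per-group scales.
quant_config = AutoConfig.from_pretrained('path/to/llama-qqq').quantization_config
print(quant_config['group_size'])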
1 change: 1 addition & 0 deletions lmdeploy/turbomind/deploy/source_model/__init__.py
@@ -9,6 +9,7 @@
from .internvl import InternVLModel # noqa: F401
from .llama import LlamaModel # noqa: F401
from .llama_awq import LlamaAwqModel # noqa: F401
from .llama_qqq import LlamaQQQModel # noqa: F401
from .meta_llama import MetaLlamaModel # noqa: F401
from .minicpmv import MiniCPMVModel # noqa: F401
from .minicpmv_awq import MiniCPMVAwqModel # noqa: F401
69 changes: 69 additions & 0 deletions lmdeploy/turbomind/deploy/source_model/llama_qqq.py
@@ -0,0 +1,69 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from .base import INPUT_MODELS
from .llama import LlamaModel, LlamaReader


def ensure_dtype(tensors: torch.Tensor, dtype: torch.dtype):
"""Ensure tensors in the specified dytpe."""
result = []
for tensor in tensors:
if tensor is not None and tensor.numel() > 0:
if tensor.dtype in [torch.float16, torch.float32, torch.bfloat16]:
result.append(tensor.to(dtype))
else:
assert tensor.dtype == torch.int32
result.append(tensor)
else:
result.append(None)
return (*result, )


class LlamaQQQReader(LlamaReader):
"""LlamaQQQReader."""

def __init__(self, new_params: dict, unused_params: dict, last_bin: bool,
model_cfg: dict):
super().__init__(new_params, unused_params, last_bin, model_cfg)

def attn(self, i: int):
"""Get q, k, v, o qweight for layer i."""
return ensure_dtype(self._attn(i, 'B'), torch.int32)

def attn_scale_group(self, i: int):
"""Get q, k, v, o per-group scales for layer i."""
return ensure_dtype(self._attn(i, 's_group'), torch.float16)

def attn_scale_channel(self, i: int):
"""Get q, k, v, o per-channel scales for layer i."""
return ensure_dtype(self._attn(i, 's_channel'), torch.float32)

def ffn(self, i: int):
"""Get ffn qweight for layer i."""
return ensure_dtype(self._ffn(i, 'B'), torch.int32)

def ffn_scale_group(self, i: int):
"""Get ffn per-group scales for layer i."""
return ensure_dtype(self._ffn(i, 's_group'), torch.float16)

def ffn_scale_channel(self, i: int):
"""Get ffn per-channel scales for layer i."""
return ensure_dtype(self._ffn(i, 's_channel'), torch.float32)


@INPUT_MODELS.register_module(name='llama-qqq')
class LlamaQQQModel(LlamaModel):
"""Llama QQQ model in hf format."""

Reader = LlamaQQQReader

def __init__(self,
model_path: str,
tokenizer_path: str,
ckpt_path: str = None,
**kwargs):
super().__init__(model_path,
tokenizer_path,
ckpt_path=ckpt_path,
**kwargs)
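A small self-contained check of `ensure_dtype` with toy shapes (not real QQQ tensor layouts): packed int32 weights pass through unchanged, float scales are cast to the requested dtype, and missing entries collapse to None.

import torch

from lmdeploy.turbomind.deploy.source_model.llama_qqq import ensure_dtype

qweight = torch.zeros((16, 8), dtype=torch.int32)   # packed 4-bit weights
s_group = torch.ones((2, 16), dtype=torch.float32)  # per-group scales

qweight_out, s_group_out, missing_out = ensure_dtype(
    (qweight, s_group, None), torch.float16)
assert qweight_out.dtype == torch.int32    # int32 weights are untouched
assert s_group_out.dtype == torch.float16  # float scales get cast
assert missing_out is None                 # absent tensors stay None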
1 change: 1 addition & 0 deletions lmdeploy/turbomind/deploy/target_model/__init__.py
@@ -2,4 +2,5 @@
from .fp import TurbomindModel # noqa: F401
from .plora import TurbomindPloraModel # noqa: F401
from .plora_w4 import TurbomindPloraW4Model # noqa: F401
from .qqq_w4 import TurbomindQQQW4Model # noqa: F401
from .w4 import TurbomindW4Model # noqa: F401
1 change: 1 addition & 0 deletions lmdeploy/turbomind/deploy/target_model/base.py
@@ -66,6 +66,7 @@ class TurbomindModelConfig:
max_prefill_iters: int = 1
use_context_fmha: int = 1
quant_policy: int = 0
quantization: str = ''
max_position_embeddings: int = 0
original_max_position_embeddings: int = 0
rope_scaling_type: str = ''