diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index 468f10115..27e609691 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -106,9 +106,10 @@ def model_format(parser, default: str = None): '--model-format', type=str, default=default, - choices=['hf', 'llama', 'awq'], + choices=['hf', 'llama', 'awq', 'qqq'], help='The format of input model. `hf` meaning `hf_llama`, `llama` ' - 'meaning `meta_llama`, `awq` meaning the quantized model by awq') + 'meaning `meta_llama`, `awq` meaning the quantized model by awq, ' + '`qqq` meaning the quantized model by qqq') @staticmethod def revision(parser, default: str = None): diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index a8d22bad1..4ce259f9f 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -115,8 +115,9 @@ class TurbomindEngineConfig: """TurboMind Engine config. Args: - model_format (str): the layout of the deployed model. It can be one of the following values [hf, meta_llama, awq], - `hf` meaning huggingface model(.bin, .safetensors), `meta_llama` being meta llama's format(.pth), awq` meaning the quantized model by AWQ. + model_format (str): the layout of the deployed model. It can be one of the following values [hf, meta_llama, awq, qqq], + `hf` meaning huggingface model(.bin, .safetensors), `meta_llama` being meta llama's format(.pth), `awq` meaning the quantized model by AWQ, + `qqq` meaning the quantized model by QQQ. tp (int): the number of GPU cards used in tensor parallelism, default to 1 session_len (int): the max session length of a sequence, default to None max_batch_size (int): the max batch size during inference, default to 128 diff --git a/lmdeploy/turbomind/deploy/converter.py b/lmdeploy/turbomind/deploy/converter.py index 59a038da9..4a25053b1 100644 --- a/lmdeploy/turbomind/deploy/converter.py +++ b/lmdeploy/turbomind/deploy/converter.py @@ -15,7 +15,7 @@ from .source_model.base import INPUT_MODELS from .target_model.base import OUTPUT_MODELS, TurbomindModelConfig -SUPPORTED_FORMATS = ['meta_llama', 'hf', 'awq', None] +SUPPORTED_FORMATS = ['meta_llama', 'hf', 'awq', 'qqq', None] logger = get_logger('lmdeploy') @@ -26,12 +26,14 @@ def get_input_model_registered_name(model_path: str, model_format: str): Args: model_path (str): the path of the input model model_format (str): the format of the model, which can be one of - ['meta_llama', 'hf', 'awq'] + ['meta_llama', 'hf', 'awq', 'qqq'] """ arch = get_model_arch(model_path)[0] register_name = SUPPORTED_ARCHS[arch] if model_format == 'awq': register_name = register_name + '-awq' + elif model_format == 'qqq': + register_name = register_name + '-qqq' return register_name @@ -92,8 +94,9 @@ def get_output_model_registered_name_and_config(model_path: str, Args: model_path (str): the path of the input model model_format (str): the format of the model, which can be one of - ['meta_llama', 'hf', 'awq'] - group_size (int): the size of group used by awq model + ['meta_llama', 'hf', 'awq', 'qqq'] + group_size (int): the size of group used by quantization methods, + including `awq` and `qqq` """ register_name = 'fp16' turbomind_model_arch = 'llama' @@ -113,6 +116,15 @@ def get_output_model_registered_name_and_config(model_path: str, register_name = 'plora-w4' \ if turbomind_model_arch == 'xcomposer2' else 'w4' group_size = 128 if group_size == 0 else group_size + config.quantization = 'awq' + elif model_format == 'qqq': + weight_type = 'int4' + register_name = 'qqq-w4' + from transformers import AutoConfig + quant_config = 
AutoConfig.from_pretrained( + model_path).quantization_config + group_size = quant_config['group_size'] + config.quantization = 'qqq' else: torch_dtype = getattr(model_config, 'torch_dtype', 'float16') TORCH_DTYPE_MAP = {torch.bfloat16: 'bf16', torch.float16: 'fp16'} @@ -212,17 +224,19 @@ def main(model_name: str, model_name (str): unused any longer model_path (str): the directory path of the model model_format (str): the format of the model, should choose from - ['meta_llama', 'hf', 'awq', None]. 'meta_llama' stands for META's - llama format, 'hf' means huggingface llama format, and 'awq' means - llama(hf) model quantized by lmdeploy/lite/quantization/awq.py. - The default value is None - chat_template (str): the name of the built-in chat template. + ['meta_llama', 'hf', 'awq', 'qqq', None]. 'meta_llama' stands for + META's llama format, 'hf' means huggingface llama format, + 'awq' means llama(hf) model quantized by + lmdeploy/lite/quantization/awq.py, + and 'qqq' means llama(hf) model quantized by the repo + https://github.com/HandH1998/QQQ, + the default value is None tokenizer_path (str): the path of tokenizer model dst_path (str): the destination path that saves outputs tp (int): the number of GPUs used for tensor parallelism, should be 2^n quant_path (str): Path of the quantized model, which can be None. - group_size (int): a parameter used in AWQ to quantize fp16 weights - to 4 bits + group_size (int): a parameter used in AWQ or QQQ to quantize fp16 + weights to 4 bits revision (str): The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. diff --git a/lmdeploy/turbomind/deploy/source_model/__init__.py b/lmdeploy/turbomind/deploy/source_model/__init__.py index 5ba4badb5..aa86e3a1a 100644 --- a/lmdeploy/turbomind/deploy/source_model/__init__.py +++ b/lmdeploy/turbomind/deploy/source_model/__init__.py @@ -9,6 +9,7 @@ from .internvl import InternVLModel # noqa: F401 from .llama import LlamaModel # noqa: F401 from .llama_awq import LlamaAwqModel # noqa: F401 +from .llama_qqq import LlamaQQQModel # noqa: F401 from .meta_llama import MetaLlamaModel # noqa: F401 from .minicpmv import MiniCPMVModel # noqa: F401 from .minicpmv_awq import MiniCPMVAwqModel # noqa: F401 diff --git a/lmdeploy/turbomind/deploy/source_model/llama_qqq.py b/lmdeploy/turbomind/deploy/source_model/llama_qqq.py new file mode 100644 index 000000000..efcdf7d07 --- /dev/null +++ b/lmdeploy/turbomind/deploy/source_model/llama_qqq.py @@ -0,0 +1,69 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
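
For context, a minimal sketch (not part of the patch) of how the converter-side changes above fit together for a QQQ checkpoint; the model path is hypothetical and the arch-to-name mapping is simplified to llama.

# Illustrative only: mirrors how the converter resolves registered names and
# the group size for a QQQ checkpoint (model path is hypothetical).
from transformers import AutoConfig

def resolve_qqq_model(model_path: str, base_name: str = 'llama'):
    # input model name gets a '-qqq' suffix, cf. get_input_model_registered_name
    input_name = base_name + '-qqq'
    # output model is 'qqq-w4'; group_size comes from the HF quantization config,
    # cf. get_output_model_registered_name_and_config above
    quant_config = AutoConfig.from_pretrained(model_path).quantization_config
    group_size = quant_config['group_size']  # QQQ supports -1 (per-channel) or 128
    return input_name, 'qqq-w4', group_size
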
+import torch + +from .base import INPUT_MODELS +from .llama import LlamaModel, LlamaReader + + +def ensure_dtype(tensors: torch.Tensor, dtype: torch.dtype): + """Ensure tensors in the specified dytpe.""" + result = [] + for tensor in tensors: + if tensor is not None and tensor.numel() > 0: + if tensor.dtype in [torch.float16, torch.float32, torch.bfloat16]: + result.append(tensor.to(dtype)) + else: + assert tensor.dtype == torch.int32 + result.append(tensor) + else: + result.append(None) + return (*result, ) + + +class LlamaQQQReader(LlamaReader): + """LlamaQQQReader.""" + + def __init__(self, new_params: dict, unused_params: dict, last_bin: bool, + model_cfg: dict): + super().__init__(new_params, unused_params, last_bin, model_cfg) + + def attn(self, i: int): + """Get q, k, v, o qweight for layer i.""" + return ensure_dtype(self._attn(i, 'B'), torch.int32) + + def attn_scale_group(self, i: int): + """Get q, k, v, o per-group scales for layer i.""" + return ensure_dtype(self._attn(i, 's_group'), torch.float16) + + def attn_scale_channel(self, i: int): + """Get q, k, v, o per-channel scales for layer i.""" + return ensure_dtype(self._attn(i, 's_channel'), torch.float32) + + def ffn(self, i: int): + """Get ffn qweight for layer i.""" + return ensure_dtype(self._ffn(i, 'B'), torch.int32) + + def ffn_scale_group(self, i: int): + """Get ffn per-group scales for layer i.""" + return ensure_dtype(self._ffn(i, 's_group'), torch.float16) + + def ffn_scale_channel(self, i: int): + """Get ffn per-channel scales for layer i.""" + return ensure_dtype(self._ffn(i, 's_channel'), torch.float32) + + +@INPUT_MODELS.register_module(name='llama-qqq') +class LlamaQQQModel(LlamaModel): + """Llama QQQ model in hf format.""" + + Reader = LlamaQQQReader + + def __init__(self, + model_path: str, + tokenizer_path: str, + ckpt_path: str = None, + **kwargs): + super().__init__(model_path, + tokenizer_path, + ckpt_path=ckpt_path, + **kwargs) diff --git a/lmdeploy/turbomind/deploy/target_model/__init__.py b/lmdeploy/turbomind/deploy/target_model/__init__.py index c40f9224f..fecefc643 100644 --- a/lmdeploy/turbomind/deploy/target_model/__init__.py +++ b/lmdeploy/turbomind/deploy/target_model/__init__.py @@ -2,4 +2,5 @@ from .fp import TurbomindModel # noqa: F401 from .plora import TurbomindPloraModel # noqa: F401 from .plora_w4 import TurbomindPloraW4Model # noqa: F401 +from .qqq_w4 import TurbomindQQQW4Model # noqa: F401 from .w4 import TurbomindW4Model # noqa: F401 diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py index ef1473bbe..05eda0948 100644 --- a/lmdeploy/turbomind/deploy/target_model/base.py +++ b/lmdeploy/turbomind/deploy/target_model/base.py @@ -66,6 +66,7 @@ class TurbomindModelConfig: max_prefill_iters: int = 1 use_context_fmha: int = 1 quant_policy: int = 0 + quantization: str = '' max_position_embeddings: int = 0 original_max_position_embeddings: int = 0 rope_scaling_type: str = '' diff --git a/lmdeploy/turbomind/deploy/target_model/qqq_w4.py b/lmdeploy/turbomind/deploy/target_model/qqq_w4.py new file mode 100644 index 000000000..75b699cf9 --- /dev/null +++ b/lmdeploy/turbomind/deploy/target_model/qqq_w4.py @@ -0,0 +1,379 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
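
As a quick sanity check (a sketch, not part of the patch) of the dtype contract the QQQ reader enforces via ensure_dtype in llama_qqq.py above: packed qweights stay int32, floating-point scales are cast to the requested dtype, and empty/None entries come back as None.

# Sketch mirroring ensure_dtype from llama_qqq.py above.
import torch

def ensure_dtype(tensors, dtype):
    result = []
    for tensor in tensors:
        if tensor is not None and tensor.numel() > 0:
            if tensor.dtype in (torch.float16, torch.float32, torch.bfloat16):
                result.append(tensor.to(dtype))
            else:
                assert tensor.dtype == torch.int32  # packed 4-bit qweight
                result.append(tensor)
        else:
            result.append(None)
    return (*result, )

qweight = torch.zeros(128, 16, dtype=torch.int32)
s_group = torch.ones(32, 128, dtype=torch.float32)
w, s = ensure_dtype((qweight, s_group), torch.float16)
assert w.dtype == torch.int32 and s.dtype == torch.float16
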
+import os.path as osp +from typing import List + +import numpy as np +import torch + +from ..source_model.base import BaseInputModel, BaseReader +from .base import (OUTPUT_MODELS, BaseOutputModel, TurbomindModelConfig, + merge_qkv, permute, tprint) + + +def get_cuda_tensor(tensors): + """Get cuda tensor.""" + result = map(lambda x: x.cuda() if x is not None else x, tensors) + return (*result, ) + + +def get_qqq_perms(group_size: int): + perm = [] + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), 4 * (i % 4) + 1, 4 * (i % 4) + 2, + 4 * (i % 4) + 3 + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm) + if group_size == -1: + interleave = np.array([4, 0, 5, 1, 6, 2, 7, 3]) + else: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + perm = perm.reshape((-1, 8))[:, interleave].ravel() + perm = torch.from_numpy(perm) + scale_perm = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + scale_perm = torch.from_numpy(np.array(scale_perm)) + scale_perm_single = torch.from_numpy(np.array(scale_perm_single)) + return perm, scale_perm, scale_perm_single + + +def pack(w: torch.Tensor, + s_channel: torch.Tensor, + s_group: torch.Tensor, + group_size: int, + tile: int = 16): + assert w.dim() == 2 + infeatures, outfeatures = w.shape[0], w.shape[1] + _perm, _scale_perm, _scale_perm_single = get_qqq_perms(group_size) + org_device = w.device + # permute scales + if group_size != -1 and group_size < infeatures: + s_group = s_group.reshape((-1, len(_scale_perm)))[:, _scale_perm] + s_group = s_group.reshape((-1, outfeatures)).contiguous() + s_channel = s_channel.reshape( + (-1, len(_scale_perm_single)))[:, _scale_perm_single] + s_channel = s_channel.reshape((-1, outfeatures)).contiguous() + # permute and pack weight + w = w.reshape(( + infeatures // tile, + tile, + outfeatures // tile, + tile, + )) + w = w.permute((0, 2, 1, 3)) + w = w.reshape((infeatures // tile, outfeatures * tile)) + res = w + res = res.reshape((-1, _perm.numel()))[:, _perm].reshape(res.shape) + q = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32) + res = res.cpu().numpy().astype(np.uint32) + if group_size != -1 and group_size < infeatures: + for i in range(8): + q |= res[:, i::8] << 4 * i + else: + for i in range(8): + q |= (res[:, i::8] & 0xF) << 4 * i + q = torch.from_numpy(q.astype(np.int32)).to(org_device) + return q, s_channel, s_group + + +def unpack(w: torch.Tensor, + s_channel: torch.Tensor, + s_group: torch.Tensor, + group_size: int, + tile: int = 16, + wbits: int = 4): + assert w.dim() == 2 + pack_factor = 32 // wbits + infeatures = w.shape[0] * tile + outfeatures = w.shape[1] * pack_factor // tile + org_device = w.device + _perm, _scale_perm, _scale_perm_single = get_qqq_perms(group_size) + wf = torch.tensor(list(range(0, 32, 4)), + dtype=torch.int32).unsqueeze(0).to(org_device) + # unpack weight + weight = torch.bitwise_right_shift( + torch.unsqueeze(w, 2).expand(-1, -1, 32 // wbits), + wf.unsqueeze(0), + ) + weight = torch.bitwise_and(weight, (2**wbits) - 1) + weight = weight.reshape(weight.shape[0], weight.shape[1] * weight.shape[2]) + + # reshape weight and scale + _perm_inv = torch.argsort(_perm) + _scale_perm_inv = torch.argsort(_scale_perm) + _scale_perm_single_inv = torch.argsort(_scale_perm_single) + + weight = 
weight.reshape(-1, _perm.numel())[:, _perm_inv] + weight = weight.reshape(( + infeatures // tile, + outfeatures // tile, + tile, + tile, + )) + weight = weight.permute((0, 2, 1, 3)) + weight = weight.reshape((infeatures, outfeatures)) + s_channel = s_channel.reshape( + -1, len(_scale_perm_single))[:, _scale_perm_single_inv].reshape( + -1, outfeatures) + if group_size != -1 and group_size < infeatures: + s_group = s_group.reshape( + -1, len(_scale_perm))[:, _scale_perm_inv].reshape(-1, outfeatures) + + return weight, s_channel, s_group + + +def permute_qk(w: torch.Tensor, + s_channel: torch.Tensor, + s_group: torch.Tensor, + group_size: int, + size_per_head: int = 128): + unp_w, unp_s_channel, unp_s_group = unpack(w, s_channel, s_group, + group_size) + dim = unp_w.shape[-1] + n_heads = dim // size_per_head + perm_w = unp_w.view(-1, n_heads, 2, + dim // n_heads // 2).transpose(2, 3).reshape(-1, dim) + perm_s_channel = unp_s_channel.view(-1, n_heads, + 2, dim // n_heads // 2).transpose( + 2, 3).reshape(-1, dim) + perm_s_group = unp_s_group + if group_size != -1 and group_size < unp_w.shape[0]: + perm_s_group = unp_s_group.view(-1, n_heads, + 2, dim // n_heads // 2).transpose( + 2, 3).reshape(-1, dim) + p_w, p_s_channel, p_s_group = pack(perm_w, perm_s_channel, perm_s_group, + group_size) + return p_w, p_s_channel, p_s_group + + +@OUTPUT_MODELS.register_module(name='qqq-w4') +class TurbomindQQQW4Model(BaseOutputModel): + """Export to turbomind QQQ w4a8 format.""" + + def __init__(self, + input_model: BaseInputModel, + cfg: TurbomindModelConfig, + out_dir: str = ''): + self.weight_bits = 4 + # 4 Bits packed into 32 bit datatype. + self.pack_factor = 32 // self.weight_bits + self.tile_size = 16 + # supported group size + self.supported_group_size = [-1, 128] + # Min out_features dim + self.min_n_threads = 64 + # Min in_features dim + self.min_k_threads = 128 + # Permutation length used by the QQQ kernels. 
+ self.perm_len = 1024 + super().__init__(input_model, cfg, out_dir) + + def get_config(self, cfg: TurbomindModelConfig): + """Get turbomind config.""" + final_cfg = super().get_config(cfg).__dict__ + + # attn_bias, inter_size + visit = False + attn_bias = 0 + for bin in self.input_model.bins(): + for i in range(bin.start_layer_id, bin.end_layer_id): + visit = True + w1s, _, _ = bin.ffn_scale_channel(i) + inter_size = w1s.shape[-1] + qb, _, _, _ = bin.attn_bias(i) + if qb is not None: + attn_bias = 1 + break + if visit: + break + final_cfg.update(dict(attn_bias=attn_bias, inter_size=inter_size)) + final_cfg = TurbomindModelConfig.from_dict(final_cfg) + + if final_cfg.group_size not in self.supported_group_size: + raise ValueError(f'The group_size of QQQ should be in' + f'{self.supported_group_size}') + # check weight size + hidden_size = final_cfg.head_num * final_cfg.size_per_head + merge_qkv_size = (final_cfg.head_num + + 2 * final_cfg.kv_head_num) * final_cfg.size_per_head + tp = final_cfg.tensor_para_size + weight_info = { + 'mgrge_qkv': { + 'weight_size': [hidden_size, merge_qkv_size], + 'split_dim': -1 + }, + 'o': { + 'weight_size': [hidden_size, hidden_size], + 'split_dim': 0 + }, + 'w1': { + 'weight_size': [hidden_size, inter_size], + 'split_dim': -1 + }, + 'w2': { + 'weight_size': [hidden_size, inter_size], + 'split_dim': -1 + }, + 'w3': { + 'weight_size': [inter_size, hidden_size], + 'split_dim': 0 + }, + } + for weight_name, split_info in weight_info.items(): + self.check_weight_size(weight_name, split_info['weight_size'], tp, + final_cfg.group_size, + split_info['split_dim']) + return final_cfg + + def check_weight_size(self, weight_name: str, weight_size: List[int], + tp: int, group_size: int, split_dim: int): + assert weight_size[ + split_dim] % tp == 0, 'The split size must be divisible by tp size' + input_size_per_partition = weight_size[ + 0] // tp if split_dim == 0 else weight_size[0] + output_size_per_partition = weight_size[ + -1] // tp if split_dim == -1 else weight_size[-1] + # Validate output_size_per_partition + if output_size_per_partition % self.min_n_threads != 0: + raise ValueError( + f'{weight_name} weight output_size_per_partition = ' + f'{output_size_per_partition} is not divisible by ' + f'min_n_threads = {self.min_n_threads}.') + if output_size_per_partition % self.pack_factor != 0: + raise ValueError( + f'{weight_name} weight output_size_per_partition = ' + f'{output_size_per_partition} is not divisible by ' + f'pack_factor = {self.pack_factor}.') + + # Validate input_size_per_partition + if input_size_per_partition % self.min_k_threads != 0: + raise ValueError( + f'{weight_name} weight input_size_per_partition = ' + f'{input_size_per_partition} is not divisible by ' + f'min_k_threads = {self.min_k_threads}.') + if (group_size != -1 and input_size_per_partition % group_size != 0): + raise ValueError( + f'{weight_name} weight input_size_per_partition = ' + f'{input_size_per_partition} is not divisible by ' + f'group_size = {group_size}.') + + # Check that we have at least 4 tiles horizontally in the shard + num_tiles_per_perm = self.perm_len // (self.tile_size**2) + if output_size_per_partition % num_tiles_per_perm != 0: + raise ValueError( + 'Each permutation group must reside on the same gpu') + + def export_transformer_block(self, bin: BaseReader, i: int): + """Export transformer layer i.""" + group_size = self.cfg.group_size + tp = self.cfg.tensor_para_size + size_per_head = self.cfg.size_per_head + # attn + q_qw, k_qw, v_qw, o_qw = 
get_cuda_tensor(bin.attn(i)) + q_sc, k_sc, v_sc, o_sc = get_cuda_tensor(bin.attn_scale_channel(i)) + q_sg, k_sg = None, None + if group_size != -1: + q_sg, k_sg, v_sg, o_sg = get_cuda_tensor(bin.attn_scale_group(i)) + + # TODO(HandH1998): verify correctness + q_qw, q_sc, q_sg = permute_qk(q_qw, q_sc, q_sg, group_size, + size_per_head) + k_qw, k_sc, k_sg = permute_qk(k_qw, k_sc, k_sg, group_size, + size_per_head) + + qkv_qw = merge_qkv(q_qw, k_qw, v_qw, tp, dim=2) + qkv_sc = merge_qkv(q_sc, k_sc, v_sc, tp, dim=2) + + self.save_split(qkv_qw, f'layers.{i}.attention.w_qkv.qweight', -1) + self.save_split(qkv_sc, f'layers.{i}.attention.w_qkv.scales_channel', + -1) + + self.save_split(o_qw, f'layers.{i}.attention.wo.qweight', 0) + # TODO(HandH1998): verify tp > 1 + self.save_split(o_sc, + f'layers.{i}.attention.wo.scales_channel', + copy=True) + + if group_size != -1: + qkv_sg = merge_qkv(q_sg, k_sg, v_sg, tp, dim=2) + self.save_split(qkv_sg, f'layers.{i}.attention.w_qkv.scales_zeros', + -1) + self.save_split(o_sg, f'layers.{i}.attention.wo.scales_zeros', 0) + + q_b, k_b, v_b, o_b = get_cuda_tensor(bin.attn_bias(i)) + if q_b is not None: + q_b = permute(q_b, size_per_head) + k_b = permute(k_b, size_per_head) + qkv_b = merge_qkv(q_b, k_b, v_b, tp, dim=1) + self.save_split(qkv_b, f'layers.{i}.attention.w_qkv.bias', -1) + self.save_split(o_b, f'layers.{i}.attention.wo.bias', copy=True) + + # ffn weights + w1_qw, w2_qw, w3_qw = get_cuda_tensor(bin.ffn(i)) + w1_sc, w2_sc, w3_sc = get_cuda_tensor(bin.ffn_scale_channel(i)) + + self.save_split(w1_qw, f'layers.{i}.feed_forward.w1.qweight', -1) + self.save_split(w1_sc, f'layers.{i}.feed_forward.w1.scales_channel', + -1) + self.save_split(w3_qw, f'layers.{i}.feed_forward.w3.qweight', -1) + self.save_split(w3_sc, f'layers.{i}.feed_forward.w3.scales_channel', + -1) + self.save_split(w2_qw, f'layers.{i}.feed_forward.w2.qweight', 0) + # TODO(HandH1998): verify tp > 1 + self.save_split(w2_sc, + f'layers.{i}.feed_forward.w2.scales_channel', + copy=True) + + if group_size != -1: + w1_sg, w2_sg, w3_sg = get_cuda_tensor(bin.ffn_scale_group(i)) + self.save_split(w1_sg, f'layers.{i}.feed_forward.w1.scales_zeros', + -1) + self.save_split(w3_sg, f'layers.{i}.feed_forward.w3.scales_zeros', + -1) + self.save_split(w2_sg, f'layers.{i}.feed_forward.w2.scales_zeros', + 0) + + # norm + attn_norm = bin.attn_norm(i) + ffn_norm = bin.ffn_norm(i) + self.save_split(attn_norm, f'layers.{i}.attention_norm.weight') + self.save_split(ffn_norm, f'layers.{i}.ffn_norm.weight') + + # override `export_weight` + def export_weight(self, param: torch.Tensor, name: str) -> None: + """export turbomind weight.""" + + def _tofile(tensor, path): + """to file.""" + if tensor.dtype == torch.bfloat16: + tensor = tensor.view(torch.half) + tensor.contiguous().cpu().numpy().tofile(path) + + if self.to_file: + tprint(name, param.shape) + _tofile(param, osp.join(self.out_dir, name)) + elif len(self.tm_params) > 0: + tm_params = self.tm_params + # currently, the tensor type should in + # [torch.float, torch.half, torch.bfloat16, torch.int32] + torch_tensor = param.cuda().contiguous() + assert torch_tensor.dtype in [ + torch.int32, torch.float, torch.half, torch.bfloat16 + ] + for tm_tensor in tm_params[name]: + tm_tensor.copy_from(torch_tensor) + tm_params.pop(name) + else: + tprint('skip export', name, param.shape) diff --git a/src/turbomind/kernels/CMakeLists.txt b/src/turbomind/kernels/CMakeLists.txt index b7ef1c725..01a7914b0 100644 --- a/src/turbomind/kernels/CMakeLists.txt +++ 
b/src/turbomind/kernels/CMakeLists.txt @@ -58,7 +58,12 @@ add_library(custom_ar_kernels STATIC custom_ar_kernels.cu) set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +add_library(quant_kernels STATIC quant_kernels.cu) +set_property(TARGET quant_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET quant_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) + add_subdirectory(gemm_s_f16) +add_subdirectory(marlin_qqq_gemm) if (BUILD_TEST) add_subdirectory(flash_attention) endif () diff --git a/src/turbomind/kernels/activation_kernels.cu b/src/turbomind/kernels/activation_kernels.cu index 0bd76c36a..1ea86f1c9 100644 --- a/src/turbomind/kernels/activation_kernels.cu +++ b/src/turbomind/kernels/activation_kernels.cu @@ -15,6 +15,7 @@ */ #include "src/turbomind/kernels/activation_kernels.h" +#include "src/turbomind/kernels/reduce_kernel_utils.cuh" #include "src/turbomind/macro.h" #include "src/turbomind/utils/cuda_type_utils.cuh" #include "src/turbomind/utils/cuda_utils.h" @@ -169,13 +170,15 @@ struct IdentityActivation { }; // clang-format off -template class Activation, typename T, typename BT> +template class Activation, typename T, typename BT, typename QT, bool enable_quant> __global__ void generic_activation(T* out, const BT* __restrict bias, const T* __restrict gated_weights, const BT* __restrict gated_bias, const int* __restrict ia3_tasks, const T* __restrict ia3_weights, + QT* __restrict quant_out, + float* __restrict quant_scale, const int int8_mode, const float* __restrict activation_in, const float* __restrict activation_out, @@ -188,57 +191,49 @@ __global__ void generic_activation(T* out, const bool with_bias = bias != nullptr; const bool with_gate = gated_weights != nullptr; - // const bool with_ia3 = ia3_tasks != nullptr; using Act_T = typename Activation::return_type; - using Float_T = typename packed_as::type; - using Packed_Int8_t = typename packed_as::type; - - for (int64_t id = blockIdx.x * blockDim.x + threadIdx.x; id < 1LL * m * n; id += blockDim.x * gridDim.x) { - T val; - if (int8_mode == 2) { - // val = cuda_cast(cuda_cast(reinterpret_cast(out)[id]) * activation_in[0]); + using Single_T = typename packed_as::type; + using Packed_Float = typename packed_as::type; + + if constexpr (enable_quant) { + __shared__ float s_amax; + float amax_val = 0.0f; + for (int64_t id = threadIdx.x; id < n; id += blockDim.x) { + T val = out[blockIdx.x * n + id]; + T gated_val; + if (with_gate) { + gated_val = gated_weights[blockIdx.x * n + id]; + val = cuda_cast(Activation::apply(val) * cuda_cast(gated_val)); + } + if (int8_mode != 2) { + out[blockIdx.x * n + id] = val; + } + amax_val = cuda_max(amax_val, cuda_cast(cuda_max(cuda_abs(val)))); } - else { - val = out[id]; + amax_val = blockReduceMax(amax_val); + if (threadIdx.x == 0) { + s_amax = amax_val; + quant_scale[blockIdx.x] = amax_val / 127.0f; } - - T gated_val; - if (with_gate) { - gated_val = gated_weights[id]; + __syncthreads(); + const float tmp_scale = 127.0f / s_amax; + for (int64_t id = threadIdx.x; id < n; id += blockDim.x) { + T val = out[blockIdx.x * n + id]; + Packed_Float tmp = cuda_cast(val) * tmp_scale; + quant_out[blockIdx.x * n + id] = cuda_cast(tmp); } - - // if (with_bias) { - // const T reg_bias = static_cast(bias[id % n]); - // val = val + reg_bias; - - // if (with_gate) { - // const T reg_gated_bias = static_cast(gated_bias[id % n]); - // gated_val = gated_val + reg_gated_bias; - // } - // } - - if 
(with_gate) { - val = cuda_cast(Activation::apply(val) * cuda_cast(gated_val)); - } - else { - // val = cuda_cast(Activation::apply(val)); - } - - // if (with_ia3) { - // const int word_id = id / n; - // const int offset = padding_offset == nullptr ? 0 : padding_offset[word_id]; - // const int batch_id = (word_id + offset) / seq_len; - // const int task = ia3_tasks[batch_id]; - // val = val * ia3_weights[task * n + (id % n)]; - // } - - if (int8_mode != 2) { - out[id] = val; - } - else { - // reinterpret_cast(out)[id] = - // cuda_cast(cuda_cast(val) * activation_out[0]); + } else { + for (int64_t id = threadIdx.x; id < n; id += blockDim.x) { + T val = out[blockIdx.x * n + id]; + T gated_val; + if (with_gate) { + gated_val = gated_weights[blockIdx.x * n + id]; + val = cuda_cast(Activation::apply(val) * cuda_cast(gated_val)); + } + if (int8_mode != 2) { + out[blockIdx.x * n + id] = val; + } } } } @@ -251,6 +246,8 @@ void invokeGenericActivation(T* out, const BT* gated_bias, const int* ia3_tasks, const T* ia3_weights, + int8_t* quant_out, + float* quant_scale, const int m, const int n, const int int8_mode, @@ -265,33 +262,53 @@ void invokeGenericActivation(T* out, using PT = typename packed_type::type; constexpr int packed_elems = num_elems::value; using PBT = typename packed_as::type; + using PQT = typename packed_as::type; + + if (packed_elems > 1) { + FT_CHECK(n % packed_elems == 0); + } - const int n_threads = 512; + dim3 grid(m); + dim3 block(std::min(n, 1024)); - dim3 block, grid; - if (n / 4 / packed_elems <= n_threads) { - block.x = n / 4 / packed_elems; - grid.x = m; + TM_LOG_DEBUG("%d %d", grid.x, block.x); + sync_check_cuda_error(); + if (quant_out == nullptr) { + generic_activation + <<>>(reinterpret_cast(out), + reinterpret_cast(bias), + reinterpret_cast(gated_weights), + reinterpret_cast(gated_bias), + ia3_tasks, + reinterpret_cast(ia3_weights), + reinterpret_cast(quant_out), + quant_scale, + int8_mode, + activation_in, + activation_out, + padding_offset, + seq_len, + m, + n / packed_elems); } else { - block.x = n_threads; - grid.x = ceil(1LL * m * n / double(n_threads)); + generic_activation + <<>>(reinterpret_cast(out), + reinterpret_cast(bias), + reinterpret_cast(gated_weights), + reinterpret_cast(gated_bias), + ia3_tasks, + reinterpret_cast(ia3_weights), + reinterpret_cast(quant_out), + quant_scale, + int8_mode, + activation_in, + activation_out, + padding_offset, + seq_len, + m, + n / packed_elems); } - TM_LOG_DEBUG("%d %d", grid.x, block.x); - sync_check_cuda_error(); - generic_activation<<>>(reinterpret_cast(out), - reinterpret_cast(bias), - reinterpret_cast(gated_weights), - reinterpret_cast(gated_bias), - ia3_tasks, - reinterpret_cast(ia3_weights), - int8_mode, - activation_in, - activation_out, - padding_offset, - seq_len, - m, - n / packed_elems); sync_check_cuda_error(); } @@ -302,6 +319,8 @@ void invokeGenericActivation(T* out, const BT* gated_bias, \ const int* ia3_tasks, \ const T* ia3_weights, \ + int8_t* quant_out, \ + float* quant_scale, \ const int m, \ const int n, \ const int int8_mode, \ diff --git a/src/turbomind/kernels/activation_kernels.h b/src/turbomind/kernels/activation_kernels.h index 776b614c9..e754dbc80 100644 --- a/src/turbomind/kernels/activation_kernels.h +++ b/src/turbomind/kernels/activation_kernels.h @@ -37,6 +37,8 @@ void invokeGenericActivation(T* out, const BT* gated_bias, const int* ia3_tasks, const T* ia3_weights, + int8_t* quant_out, + float* quant_scale, const int m, const int n, const int int8_mode, @@ -53,6 +55,8 @@ void 
invokeGenericActivation(T* out, const BT* gated_bias, const int* ia3_tasks, const T* ia3_weights, + int8_t* quant_out, + float* quant_scale, const int m, const int n, const int int8_mode, @@ -66,6 +70,8 @@ void invokeGenericActivation(T* out, gated_bias, ia3_tasks, ia3_weights, + nullptr, + nullptr, m, n, int8_mode, diff --git a/src/turbomind/kernels/marlin_qqq_gemm/CMakeLists.txt b/src/turbomind/kernels/marlin_qqq_gemm/CMakeLists.txt new file mode 100644 index 000000000..b7bcc0367 --- /dev/null +++ b/src/turbomind/kernels/marlin_qqq_gemm/CMakeLists.txt @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +add_library(marlin_qqq_gemm STATIC marlin_qqq_gemm.cc marlin_qqq_gemm_kernel.cu) +set_property(TARGET marlin_qqq_gemm PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET marlin_qqq_gemm PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) diff --git a/src/turbomind/kernels/marlin_qqq_gemm/common.h b/src/turbomind/kernels/marlin_qqq_gemm/common.h new file mode 100644 index 000000000..326cce57b --- /dev/null +++ b/src/turbomind/kernels/marlin_qqq_gemm/common.h @@ -0,0 +1,255 @@ +/* + * Modified by HandH1998 + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include + +namespace turbomind { +namespace marlin_qqq { + +constexpr int ceildiv(int a, int b) +{ + return (a + b - 1) / b; +} + +template +inline std::string str(T x) +{ + return std::to_string(x); +} + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed +// for instance as inputs to tensor core operations. Consequently, all +// corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee +// this. +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) + { + return elems[i]; + } +}; + +using I4 = Vec; +// Matrix fragments for tensor core instructions; their precise layout is +// documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-integer-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS_GROUP = Vec; // weight per-group quantization scales +using FragS_CHANNEL = Vec; // weight per-channel quantization scales or activaton + // per-token quantization scales + +// Predicated asynchronous global->shared copy; used for inputs A where we apply +// predication to handle batchsizes that are not multiples of 16. 
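
For reference, a NumPy sketch (illustrative, not part of the patch) of the per-token symmetric quantization the rewritten generic_activation kernel performs when enable_quant is set: each row's absolute max gives quant_scale[row] = amax / 127 and the activations are rescaled into int8. Exact rounding/saturation follows cuda_cast in the kernel and may differ slightly from this reference.

# Reference for the per-token int8 quantization added to generic_activation.
import numpy as np

def per_token_quant(x: np.ndarray):
    # x: [m, n] activations after the (optionally gated) activation function
    amax = np.abs(x).max(axis=1, keepdims=True)          # s_amax in the kernel
    quant_scale = amax / 127.0                           # written to quant_scale[row]
    q = np.clip(np.rint(x * (127.0 / amax)), -128, 127)  # quant_out
    return q.astype(np.int8), quant_scale.squeeze(1)

x = np.random.randn(4, 16).astype(np.float32)
q, s = per_token_quant(x)
assert np.allclose(q * s[:, None], x, atol=s.max())
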
+__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) +{ + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), + "l"(glob_ptr), + "n"(BYTES)); +} + +// Asynchronous global->shared copy +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) +{ + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), + "n"(BYTES)); +} + +// Async copy fence. +__device__ inline void cp_async_fence() +{ + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() +{ + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) +{ + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) +{ + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); + } +} + +// NOTE(HandH1998): cp.async.cg only support BYTES = 16, however, +// cp.async.ca can support BYTES = 4, 8, 16; +// as s1's shape is equal to prob_m, we need set s1 to float type, +// and cp_size = 1 float, i.e., 4 BYTES +// Asynchronous global->shared copy for activation quantizaton scales s1 +__device__ inline void cp_async1(void* smem_ptr, const void* glob_ptr) +{ + const int BYTES = 4; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("{\n" + " cp.async.ca.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), + "n"(BYTES)); +} + +// m16n8k16 tensor core mma instruction with int8 inputs and int32 +// output/accumulation. +__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag_c) +{ + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + int* c = reinterpret_cast(&frag_c); + asm volatile("mma.sync.aligned.m16n8k16.row.col.satfinite.s32.s8.s8.s32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in int8 tensor core layout. 
+__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) +{ + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" : "=r"(a[0]), "=r"(a[1]) : "r"(smem)); +} + +inline __device__ half2 float2_to_half2(float2 f) +{ + uint32_t res; + // NOTE(HandH1998): h0,h1 should be uint16_t, not half + uint16_t h0, h1; + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h0) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h1) : "f"(f.y)); + asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(res) : "h"(h0), "h"(h1)); + return reinterpret_cast(res); +} + +inline __device__ float int32_to_float(int h) +{ + float res; + asm volatile("cvt.rn.f32.s32 %0, %1;\n" : "=f"(res) : "r"(h)); + return res; +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) +{ + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values +// for weight per channel dequant. +__device__ inline FragB dequant_per_channel(int q) +{ + static constexpr int MASK = 0xf0f0f0f0; + FragB frag_b; + frag_b[0] = (q & MASK); + return frag_b; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values +// for weight per group dequant. +__device__ inline FragB dequant_per_group(int q, FragS_GROUP& frag_s, int i) +{ + static constexpr uint32_t LO = 0x000f000f; + static constexpr uint32_t HI = 0x00f000f0; + static constexpr uint32_t EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + uint32_t t0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + uint32_t t1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. 
+ static constexpr uint32_t SUB = 0x64086408; + static constexpr uint32_t MUL = 0x2c002c00; + static constexpr uint32_t ADD = 0xd480d480; + *reinterpret_cast(&t0) = __hsub2(*reinterpret_cast(&t0), *reinterpret_cast(&SUB)); + *reinterpret_cast(&t1) = __hfma2( + *reinterpret_cast(&t1), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); + + uint16_t s = reinterpret_cast(&frag_s)[i]; + uint32_t double_s; + // pack 2xfp16 to half2 + asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(double_s) : "h"(s), "h"(s)); + // dequant and convert 4 half to 4 uint8 (be placed at the low 8 bits of 4 + // half, respectively) + static constexpr uint32_t MAGIC_NUM = 0x64806480; + *reinterpret_cast(&t0) = __hfma2(*reinterpret_cast(&t0), + *reinterpret_cast(&double_s), + *reinterpret_cast(&MAGIC_NUM)); + *reinterpret_cast(&t1) = __hfma2(*reinterpret_cast(&t1), + *reinterpret_cast(&double_s), + *reinterpret_cast(&MAGIC_NUM)); + // take out the 4 uint8 from 4 half, then convert them to 4 int8 and pack 4 + // int8 into 1 uint32 + FragB frag_b; + uint32_t uint8s; + static constexpr uint32_t MASK_0246 = 0x6420; + static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(uint8s) : "r"(t0), "r"(t1), "n"(MASK_0246)); + frag_b[0] = (uint8s ^ UINT8s_TO_INT8s_MASK); + return frag_b; +} + +} // namespace marlin_qqq +} // namespace turbomind diff --git a/src/turbomind/kernels/marlin_qqq_gemm/marlin_qqq_gemm.cc b/src/turbomind/kernels/marlin_qqq_gemm/marlin_qqq_gemm.cc new file mode 100644 index 000000000..5c80669dc --- /dev/null +++ b/src/turbomind/kernels/marlin_qqq_gemm/marlin_qqq_gemm.cc @@ -0,0 +1,199 @@ +/* + * Adapted from + * https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda_kernel.cu + * https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda.cpp + * Modified by HandH1998 + * Copyright (C) 2024 HandH1998 + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "marlin_qqq_gemm.h" +#include "marlin_qqq_gemm_kernel.h" +#include "src/turbomind/utils/cuda_utils.h" +#include "src/turbomind/utils/string_utils.h" + +namespace turbomind { +namespace marlin_qqq { + +void MarlinQQQGemm::allocateBuffer(size_t workspace_size, size_t reduce_buf_size) +{ + workspace_buf_ = (int*)allocator_->reMalloc(workspace_buf_, sizeof(int) * workspace_size, true); + reduce_buf_ = (int*)allocator_->reMalloc(reduce_buf_, sizeof(int) * reduce_buf_size, false); + is_allocate_buffer_ = true; +} + +void MarlinQQQGemm::freeBuffer() +{ + if (is_allocate_buffer_) { + allocator_->free((void**)&workspace_buf_); + allocator_->free((void**)&reduce_buf_); + is_allocate_buffer_ = false; + } +} + +bool MarlinQQQGemm::is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, int prob_k) +{ + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // thread_k can be only 128 or 64 (because it must be less than groupsize + // which is 128) + if (th_config.thread_k != 128 && th_config.thread_k != 64) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + return true; +} + +MarlinQQQGemm::thread_config_t MarlinQQQGemm::determine_thread_config(int prob_m, int prob_n, int prob_k) +{ + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { + return th_config; + } + } + } + else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { + return th_config; + } + } + } + + return thread_config_t{-1, -1, -1}; +} + +#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS, NUM_THREADS) \ + else if (thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS \ + && thread_k_blocks == THREAD_K_BLOCKS && group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) \ + { \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + max_shared_mem); \ + Marlin \ + <<>>( \ + A_ptr, B_ptr, C_ptr, D_ptr, s1_ptr, s2_ptr, s3_ptr, prob_m, prob_n, prob_k, locks); \ + } + +#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) + +void MarlinQQQGemm::Run(half* D, + const int8_t* A, + const uint* B, + const float* s1, + const float* s2, + const half* s3, + int prob_m, + int prob_n, + int prob_k, + int groupsize, + cudaStream_t stream) +{ + + // Set thread config + thread_config_t th_config = determine_thread_config(prob_m, prob_n, prob_k); + if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) { + throw std::runtime_error( + fmtstr("Invalid 
thread config: thread_k = %d, thread_n = %d, num_threads = %d for MKN = [%d, %d, %d]", + th_config.thread_k, + th_config.thread_n, + th_config.num_threads, + prob_m, + prob_k, + prob_n)); + } + + int num_threads = th_config.num_threads; + int thread_k = th_config.thread_k; + int thread_n = th_config.thread_n; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + // QQQ only supports groupsize = -1 or 128 + int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; + + if (group_blocks != -1) { + FT_CHECK_WITH_INFO(prob_k % group_blocks == 0, + fmtstr("prob_k = %d is not divisible by group_blocks = %d", prob_k, group_blocks)); + } + FT_CHECK_WITH_INFO(prob_k % tile_size == 0, + fmtstr("prob_k = %d is not divisible by tile_size = %d", prob_k, tile_size)); + FT_CHECK_WITH_INFO(prob_n % min_thread_n == 0, + fmtstr("prob_n = %d is not divisible by min_thread_n = %d", prob_n, min_thread_n)); + + if (reduce_buf_ == nullptr || workspace_buf_ == nullptr) { + size_t workspace_size = (prob_n / min_thread_n) * max_par; + size_t reduce_buf_size = max_par * 64 * prob_n; + allocateBuffer(workspace_size, reduce_buf_size); + } + + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)reduce_buf_; + int4* D_ptr = (int4*)D; + const float* s1_ptr = (const float*)s1; + const int4* s2_ptr = (const int4*)s2; + const int4* s3_ptr = (const int4*)s3; + int* locks = workspace_buf_; + invokeMarlinQQQGemm(A_ptr, + B_ptr, + C_ptr, + D_ptr, + s1_ptr, + s2_ptr, + s3_ptr, + prob_m, + prob_n, + prob_k, + locks, + thread_n_blocks, + thread_k_blocks, + group_blocks, + num_threads, + stream); + sync_check_cuda_error(); +} +} // namespace marlin_qqq +} // namespace turbomind diff --git a/src/turbomind/kernels/marlin_qqq_gemm/marlin_qqq_gemm.h b/src/turbomind/kernels/marlin_qqq_gemm/marlin_qqq_gemm.h new file mode 100644 index 000000000..f2a80fd46 --- /dev/null +++ b/src/turbomind/kernels/marlin_qqq_gemm/marlin_qqq_gemm.h @@ -0,0 +1,98 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
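
A small Python sketch (illustrative) mirroring is_valid_config/determine_thread_config in marlin_qqq_gemm.cc above, showing how a thread configuration is picked for a given GEMM shape; min_thread_n/min_thread_k are left as parameters because their actual values live in the header rather than in this hunk.

# Sketch of the thread-config selection in MarlinQQQGemm (defaults are assumptions).
SMALL_BATCH = [(128, 128, 256), (128, 64, 128), (64, 256, 256), (64, 128, 128)]
LARGE_BATCH = [(64, 256, 256), (128, 128, 256), (64, 128, 128), (128, 64, 128)]

def is_valid(cfg, prob_m, prob_n, prob_k, min_thread_n, min_thread_k):
    thread_k, thread_n, num_threads = cfg
    if prob_k % thread_k or prob_n % thread_n:   # K/N must tile evenly
        return False
    if thread_k not in (64, 128):                # must stay below groupsize 128
        return False
    if thread_n < min_thread_n or thread_k < min_thread_k:
        return False
    return num_threads >= 128                    # at least 4 warps

def determine_thread_config(prob_m, prob_n, prob_k, min_thread_n=64, min_thread_k=64):
    table = SMALL_BATCH if prob_m <= 16 else LARGE_BATCH
    for cfg in table:                            # ordered by priority
        if is_valid(cfg, prob_m, prob_n, prob_k, min_thread_n, min_thread_k):
            return cfg
    return None

print(determine_thread_config(16, 4096, 4096))   # -> (128, 128, 256)
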
+ +#pragma once + +#include "src/turbomind/utils/allocator.h" +#include +#include +#include + +namespace turbomind { +namespace marlin_qqq { + +class MarlinQQQGemm { +public: + // The parameterless constructor is only called for test by Python through pybind + MarlinQQQGemm() + { + auto allocator = std::make_unique>(0); + allocator_ = allocator.get(); + // The unique_ptr must be saved, or it will be released after constructing + allocator_holder_ = std::move(allocator); + } + + MarlinQQQGemm(IAllocator* allocator): allocator_(allocator) {} + + ~MarlinQQQGemm() + { + freeBuffer(); + } + + void Run(half* D, + const int8_t* A, + const uint* B, + const float* s1, + const float* s2, + const half* s3, + int prob_m, + int prob_n, + int prob_k, + int groupsize, + cudaStream_t stream); + + void setBuffer(int* reduce_buf, int* workspace_buf) + { + reduce_buf_ = reduce_buf; + workspace_buf_ = workspace_buf; + } + + std::pair getBuffer() + { + return std::make_pair(reduce_buf_, workspace_buf_); + } + +private: + // normally the allocation is performed by self-attn or ffn + // allocateBuffer is only called when testing marlin qqq gemm + void allocateBuffer(size_t workspace_size, size_t reduce_buf_size); + + void freeBuffer(); + + int* workspace_buf_{}; + int* reduce_buf_{}; + bool is_allocate_buffer_{}; + // allocator_holder_ is only for test + std::unique_ptr> allocator_holder_; + IAllocator* allocator_; + typedef struct { + int thread_k; + int thread_n; + int num_threads; + } thread_config_t; + + bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, int prob_k); + + thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k); + + thread_config_t small_batch_thread_configs[4] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, // Default + {128, 64, 128}, // Reduce N 2X, same K + {64, 256, 256}, // Reduce K 2X, increase N 2X + {64, 128, 128}, // Reduce K 2X, same N + }; + + thread_config_t large_batch_thread_configs[4] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, // Default + {128, 128, 256}, // Reduce N 2X, increase K 2X + {64, 128, 128}, // Reduce N 2X, same K + {128, 64, 128}, // Reduce N 4X, increase K 2X + }; +}; +} // namespace marlin_qqq +} // namespace turbomind diff --git a/src/turbomind/kernels/marlin_qqq_gemm/marlin_qqq_gemm_kernel.cu b/src/turbomind/kernels/marlin_qqq_gemm/marlin_qqq_gemm_kernel.cu new file mode 100644 index 000000000..f56358414 --- /dev/null +++ b/src/turbomind/kernels/marlin_qqq_gemm/marlin_qqq_gemm_kernel.cu @@ -0,0 +1,808 @@ +/* + * Adapted from + * https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda_kernel.cu + * https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda.cpp + * Modified by HandH1998 + * Copyright (C) 2024 HandH1998 + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "common.h" +#include "marlin_qqq_gemm_kernel.h" +#include +#include +#include +#include +#include + +namespace turbomind { +namespace marlin_qqq { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +templateshared + // fetch pipeline + const int group_blocks // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin(const int4* __restrict__ A, // int8 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // int32 global_reduce buffer of shape + // (max_par*16*4)xn, as int8 tensor core's output is + // int32 dtype + int4* __restrict__ D, // fp16 output buffer of shape mxn + const float* __restrict__ s1, // fp32 activation per-token quantization + // scales of shape mx1 + const int4* __restrict__ s2, // fp32 weight per-channel quantization scales + // of shape 1xn + const int4* __restrict__ s3, // fp16 weight per-group quantization scales + // of shape (k/groupsize)xn, when + // group_blocks=-1, it should be nullptr + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) +{ + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts in + // the middle of group. + if constexpr (group_blocks != -1) + iters = (group_blocks / thread_k_blocks) * ceildiv(iters, (group_blocks / thread_k_blocks)); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 16; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 4; + D += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + s1 += (slice_col_par / n_tiles) * 16 * thread_m_blocks; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for + // synchronization. 
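
As a reading aid (a sketch, not part of the patch), the stripe bookkeeping computed above can be reproduced on the host; it shows how a threadblock's starting (slice_row, slice_col) follows from its block index. The SM count and tile sizes in the example call are illustrative.

# Host-side sketch of the stripe partitioning at the top of the Marlin kernel.
def ceildiv(a, b):
    return (a + b - 1) // b

def stripe_start(block_idx, grid_dim, prob_m, prob_n, prob_k,
                 thread_m_blocks, thread_n_blocks, thread_k_blocks,
                 group_blocks=-1):
    parallel = 1
    if prob_m > 16 * thread_m_blocks:
        parallel = prob_m // (16 * thread_m_blocks)
    k_tiles = prob_k // 16 // thread_k_blocks
    n_tiles = prob_n // 16 // thread_n_blocks
    iters = ceildiv(k_tiles * n_tiles * parallel, grid_dim)
    if group_blocks != -1:
        # keep every stripe a whole number of quantization groups
        iters = (group_blocks // thread_k_blocks) * \
            ceildiv(iters, group_blocks // thread_k_blocks)
    slice_row = (iters * block_idx) % k_tiles
    slice_col_par = (iters * block_idx) // k_tiles
    return slice_row, slice_col_par % n_tiles

# e.g. a 108-block launch over a 64x4096x4096 problem with 128x128 tiles
print(stripe_start(0, 108, 64, 4096, 4096,
                   thread_m_blocks=4, thread_n_blocks=8, thread_k_blocks=8))
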
+ auto init_slice = [&]() { + slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) + slice_iters = 0; + if (slice_iters == 0) + return; + if (slice_row + slice_iters > k_tiles) + slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) + slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) + slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 16; + C += 16 * thread_m_blocks * prob_n / 4; + D += 16 * thread_m_blocks * prob_n / 8; + s1 += 16 * thread_m_blocks; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + int a_gl_stride = prob_k / 16; // stride of the A matrix in global memory + // We typically use `constexpr` to indicate that this value is a compile-time + // constant + constexpr int a_sh_stride = 16 * thread_k_blocks / 16; // stride of an A matrix tile in shared memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 16; // delta between subsequent A tiles in global memory + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes + constexpr int a_sh_rd_delta_o = 1 * ((threads / 32) / (thread_n_blocks / 4)); // between shared memory tile reads + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // within a shared memory tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); // overall size of a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, + a_sh_wr_delta); // number of shared write iterations for a tile + + int b_gl_stride = 16 * prob_n / 32; + constexpr int b_sh_stride = 32 * thread_n_blocks / 4; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); + constexpr int b_sh_wr_delta = threads; + constexpr int b_sh_rd_delta = threads; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + constexpr int s1_sh_stride = 16 * thread_m_blocks; + + constexpr int s2_sh_stride = 16 * thread_n_blocks / 4; + + int s3_gl_stride = prob_n / 8; + constexpr int s3_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s3_sh_stage = s3_sh_stride; + int s3_gl_rd_delta = s3_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. 
+ // NOTE(HandH1998): int8 input a only need 16 threads to load 16x16 matrix + int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16); + a_sh_rd += 1 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x; + int b_sh_rd = threadIdx.x; + + int s1_gl_rd = threadIdx.x; + // NOTE(HandH1998): activation scale s1 need shuffle to [0, 8, 1, 9, 2, 10, 3, + // 11, 4, 12, 5, 13, 6, 14, 7, 15] for example, 0, 8 row scales serve for + // thread 0, 1, 2, 3. For more details, refer to mma operand A layout as s1's + // size is not fixed, we can not shuffle before inference we shuffle it when + // fetching s1 from global memory to shared memory, that's why s1_sh_wr is + // like this + int s1_sh_wr = (threadIdx.x / 16) * 16 + (threadIdx.x % 8) * 2 + (threadIdx.x % 16) / 8; + int s1_sh_rd = (threadIdx.x % 32) / 4; + bool s1_sh_wr_pred = threadIdx.x < prob_m; + + int s2_gl_rd = s2_sh_stride * slice_col + threadIdx.x; + int s2_sh_wr = threadIdx.x; + int s2_sh_rd = 16 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + 2 * ((threadIdx.x % 32) % 4); + bool s2_sh_wr_pred = threadIdx.x < s2_sh_stride; + + int s3_gl_rd, s3_sh_wr, s3_sh_rd; + bool s3_sh_wr_pred; + if constexpr (group_blocks != -1) { + s3_gl_rd = + s3_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s3_sh_stride * slice_col + threadIdx.x; + s3_sh_wr = threadIdx.x; + // NOTE(HandH1998): s3_sh_rd is related to mma output C + s3_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + s3_sh_wr_pred = threadIdx.x < s3_sh_stride; + } + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. 
+ const int4* B_ptr[b_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + // NOTE(HandH1998): stages need >= 4, otherwise, sh_s1 = sh + max(stages * + // a_sh_stage + stages * b_sh_stage, 4 * stages * a_sh_stage) + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_s1 = sh_b + (stages * b_sh_stage); + int4* sh_s2 = sh_s1 + s1_sh_stride; + int4* sh_s3 = sh_s2 + s2_sh_stride; + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS_GROUP frag_s3[2][4]; + FragS_CHANNEL frag_s1[thread_m_blocks]; + FragS_CHANNEL frag_s2[2][4]; + + // Zero accumulators. + auto zero_accums = [&]() { +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); + B_ptr[i] += b_gl_rd_delta_o; + } + // Only fetch scales if this tile starts a new group + if constexpr (group_blocks != -1) { + if (pipe % (group_blocks / thread_k_blocks) == 0) { + int4* sh_s3_stage = sh_s3 + s3_sh_stage * pipe; + if (s3_sh_wr_pred) + cp_async4(&sh_s3_stage[s3_sh_wr], &s3[s3_gl_rd]); + s3_gl_rd += s3_gl_rd_delta; + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + // It may seem inefficient that we reload the groups for every sub-tile; + // however, this does not seem to be a significant bottleneck, while some + // theoretically better attempts have lead to bad instruction ordering by + // the compiler and correspondingly a noticeable drop in performance. 
+ if constexpr (group_blocks != -1) { + int4* sh_s3_stage = + sh_s3 + s3_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s3[k % 2])[0] = sh_s3_stage[s3_sh_rd]; + } + int4* sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { +// We have the m dimension as the inner loop in order to encourage overlapping +// dequantization and matmul operations. +#pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant = frag_b_quant[k % 2][j]; + // int b_quant_shift = b_quant << 4; + FragB frag_b0, frag_b1; + // If there are no groups, we can just scale the final output once and can + // avoid doing so for each weight. + if constexpr (group_blocks != -1) { + int b_quant_shift = b_quant >> 8; + frag_b0 = dequant_per_group(b_quant, frag_s3[k % 2][j], 0); + frag_b1 = dequant_per_group(b_quant_shift, frag_s3[k % 2][j], 1); + } + else { + int b_quant_shift = b_quant << 4; + frag_b0 = dequant_per_channel(b_quant); + frag_b1 = dequant_per_channel(b_quant_shift); + } +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride; + constexpr int red_sh_stride = b_sh_stride * 4 * 2; + constexpr int red_sh_delta = b_sh_stride; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + +#pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { +#pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { +#pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + int* c_rd = reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + int* c_wr = reinterpret_cast(&sh[red_sh_wr]); +#pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { +#pragma unroll + for (int i = 0; i < 4 * 2; i++) { + int* c_rd = reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); +#pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. 
As the striped
+    // partitioning minimizes the number of such reductions and our outputs are
+    // usually rather small, we perform this reduction serially in L2 cache.
+    // global_reduce works on INT32 elements, which are the results of INT8 GEMM.
+    // This is why we need another INT32 matrix `C` to reduce into, instead of the
+    // original half matrix `D`.
+    auto global_reduce = [&](bool first = false, bool last = false) {
+        // We are very careful here to reduce directly in the output buffer to
+        // maximize L2 cache utilization in this step. To do this, we write out
+        // the partial INT32 results and also reduce with INT32 compute.
+        constexpr int active_threads = 32 * thread_n_blocks / 4;
+        if (threadIdx.x < active_threads) {
+            int c_gl_stride     = prob_n / 4;
+            int c_gl_wr_delta_o = 8 * c_gl_stride;
+            int c_gl_wr_delta_i = 8 * (active_threads / 32);
+            int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 8 * (threadIdx.x / 32) + (threadIdx.x % 4) * 2;
+            c_gl_wr += (4 * thread_n_blocks) * slice_col;
+            constexpr int c_sh_wr_delta = active_threads * 2;
+            int           c_sh_wr       = 2 * threadIdx.x;
+
+            int row = (threadIdx.x % 32) / 4;
+
+            if (!first) {
+                // Interestingly, doing direct global accesses here really seems to mess up
+                // the compiler and lead to slowdowns, hence we also use async-copies even
+                // though these fetches are not actually asynchronous.
+#pragma unroll
+                for (int i = 0; i < thread_m_blocks * 4; i++) {
+                    cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i],
+                                   &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)],
+                                   i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
+                    cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i + 1],
+                                   &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2) + 1],
+                                   i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
+                }
+                cp_async_fence();
+                cp_async_wait<0>();
+            }
+
+#pragma unroll
+            for (int i = 0; i < thread_m_blocks * 4; i++) {
+                if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
+                    if (!first) {
+                        int4 d_red1 = sh[c_sh_wr + i * c_sh_wr_delta];
+                        int4 d_red2 = sh[c_sh_wr + i * c_sh_wr_delta + 1];
+#pragma unroll
+                        for (int j = 0; j < 4; j++) {
+                            reinterpret_cast<int*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
+                                reinterpret_cast<int*>(&d_red1)[j];
+                        }
+#pragma unroll
+                        for (int j = 0; j < 4; j++) {
+                            reinterpret_cast<int*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)] +=
+                                reinterpret_cast<int*>(&d_red2)[j];
+                        }
+                    }
+                    if (!last) {
+                        int4 d1, d2;
+#pragma unroll
+                        for (int j = 0; j < 4; j++) {
+                            reinterpret_cast<int*>(&d1)[j] =
+                                reinterpret_cast<int*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)];
+                        }
+#pragma unroll
+                        for (int j = 0; j < 4; j++) {
+                            reinterpret_cast<int*>(&d2)[j] =
+                                reinterpret_cast<int*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)];
+                        }
+                        C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)]     = d1;
+                        C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2) + 1] = d2;
+                    }
+                }
+            }
+        }
+    };
+
+    // Write out the final reduction result in the correct layout. We only actually
+    // reshuffle matrix fragments in this step; the reduction above is performed
+    // in fragment layout.
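
Before the write-out lambda that follows, it may help to spell out the scaling it applies: each INT32 accumulator becomes an FP16 output by multiplying with the per-token activation scale s1[m] and the per-channel weight scale s2[n]; in the grouped path the per-group scales s3 were already folded in during weight dequantization. Below is a minimal host-side reference of the channel-wise case; the function and variable names are hypothetical and it is only a sketch of the math, not part of the patch.

    #include <cstdint>
    #include <vector>

    // Reference epilogue for the channel-wise (group_blocks == -1) QQQ path:
    // out[m][n] = float(acc[m][n]) * s1[m] * s2[n]
    std::vector<float> dequant_epilogue(const std::vector<int32_t>& acc,  // m x n INT32 accumulators
                                        const std::vector<float>&   s1,   // per-token activation scales, size m
                                        const std::vector<float>&   s2,   // per-channel weight scales, size n
                                        int m, int n)
    {
        std::vector<float> out(m * n);
        for (int i = 0; i < m; ++i)
            for (int j = 0; j < n; ++j)
                out[i * n + j] = static_cast<float>(acc[i * n + j]) * s1[i] * s2[j];
        return out;  // the kernel additionally packs pairs into half2 on store
    }
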
+ auto write_result = [&]() { + int d_gl_stride = prob_n / 8; + constexpr int d_sh_stride = 2 * thread_n_blocks + 1; + int d_gl_wr_delta = d_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int d_sh_rd_delta = d_sh_stride * (threads / (2 * thread_n_blocks)); + + int d_gl_wr = d_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + d_gl_wr += (2 * thread_n_blocks) * slice_col; + int d_sh_wr = (4 * d_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + d_sh_wr += 32 * (threadIdx.x / 32); + int d_sh_rd = d_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + + int d_gl_wr_end = d_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, int c0, int c1, float a_s, FragS_CHANNEL& w_s) { + float2 deq_res; + deq_res.x = int32_to_float(c0) * w_s[0] * a_s; + deq_res.y = int32_to_float(c1) * w_s[1] * a_s; + ((half2*)sh)[idx] = float2_to_half2(deq_res); + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { + int wr = d_sh_wr + 8 * j; + write(wr + (4 * d_sh_stride) * 0 + 0, + frag_c[i][j][0][0], + frag_c[i][j][0][1], + frag_s1[i][0], + frag_s2[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * d_sh_stride) * 8 + 0, + frag_c[i][j][0][2], + frag_c[i][j][0][3], + frag_s1[i][1], + frag_s2[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * d_sh_stride) * 0 + 4, + frag_c[i][j][1][0], + frag_c[i][j][1][1], + frag_s1[i][0], + frag_s2[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * d_sh_stride) * 8 + 4, + frag_c[i][j][1][2], + frag_c[i][j][1][3], + frag_s1[i][1], + frag_s2[j / 2][2 * (j % 2) + 1]); + } + d_sh_wr += 16 * (4 * d_sh_stride); + } + } + __syncthreads(); + +#pragma unroll + for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { + if (d_gl_wr < d_gl_wr_end) { + D[d_gl_wr] = sh[d_sh_rd]; + d_gl_wr += d_gl_wr_delta; + d_sh_rd += d_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { +#pragma unroll + for (int i = 0; i < stages - 1; i++) + fetch_to_shared(i, i, i < slice_iters); + zero_accums(); + wait_for_stage(); + fetch_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + }; + start_pipes(); + + // Main loop. + while (slice_iters) { +// We unroll over both the global fetch and the register load pipeline to +// ensure all shared memory accesses are static. Note that both pipelines have +// even length meaning that the next iteration will always start at index 0. +#pragma unroll + for (int pipe = 0; pipe < stages;) { +#pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); + pipe++; + wait_for_stage(); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) + break; + } + a_gl_rd += a_gl_rd_delta_o * stages; + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. 
+ if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if (last) { + if (s1_sh_wr_pred) { + cp_async1(&sh_s1[s1_sh_wr], &s1[s1_gl_rd]); + } + if (s2_sh_wr_pred) { + cp_async4(&sh_s2[s2_sh_wr], &s2[s2_gl_rd]); + } + cp_async_fence(); + } + thread_block_reduce(); + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + frag_s1[i][0] = *reinterpret_cast(&sh_s1[16 * i + 2 * s1_sh_rd]); + frag_s1[i][1] = *reinterpret_cast(&sh_s1[16 * i + 2 * s1_sh_rd + 1]); + } + reinterpret_cast(&frag_s2)[0] = sh_s2[s2_sh_rd + 0]; + reinterpret_cast(&frag_s2)[1] = sh_s2[s2_sh_rd + 1]; + reinterpret_cast(&frag_s2)[2] = sh_s2[s2_sh_rd + 8]; + reinterpret_cast(&frag_s2)[3] = sh_s2[s2_sh_rd + 9]; + } + } + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] -= b_gl_stride; + } + s3_gl_rd = s3_sh_stride * slice_col + threadIdx.x; + s2_gl_rd = s2_sh_stride * slice_col + threadIdx.x; + start_pipes(); + } + } + } +} + +#else + +templateshared + // fetch pipeline + const int group_blocks // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin(const int4* __restrict__ A, // int8 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // int32 global_reduce buffer of shape + // (max_par*16*4)xn, as int8 tensor core's output is + // int32 dtype + int4* __restrict__ D, // fp16 output buffer of shape mxn + const float* __restrict__ s1, // fp32 activation per-token quantization + // scales of shape mx1 + const int4* __restrict__ s2, // fp32 weight per-channel quantization scales + // of shape 1xn + const int4* __restrict__ s3, // fp16 weight per-group quantization scales + // of shape (k/groupsize)xn, when + // group_blocks=-1, it should be nullptr + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) +{ + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +#endif + +#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS, NUM_THREADS) \ + else if (thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS \ + && thread_k_blocks == THREAD_K_BLOCKS && group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) \ + { \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + max_shared_mem); \ + Marlin \ + <<>>(A, B, C, D, s1, s2, s3, prob_m, prob_n, prob_k, locks); \ + } + +#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, 
K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) + +void invokeMarlinQQQGemm(const int4* A, + const int4* B, + int4* C, + int4* D, + const float* s1, + const int4* s2, + const int4* s3, + int prob_m, + int prob_n, + int prob_k, + int* locks, + int thread_n_blocks, + int thread_k_blocks, + int group_blocks, + int num_threads, + cudaStream_t stream) +{ + int tot_m = prob_m; + int tot_m_blocks = ceildiv(tot_m, 16); + int pad = 16 * tot_m_blocks - tot_m; + + int dev = 0; + cudaGetDevice(&dev); + + int sms = 0; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + int blocks = sms; + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + + for (int i = 0; i < tot_m_blocks; i += 4) { + int thread_m_blocks = tot_m_blocks - i; + prob_m = tot_m - 16 * i; + int par = 1; + if (thread_m_blocks > 4) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * thread_m_blocks - pad) / 64; + if (par > max_par) + par = max_par; + prob_m = 64 * par; + i += 4 * (par - 1); + thread_m_blocks = 4; + } + + // For compilation speed, we only define the kernel configurations that have + // seemed useful (in terms of performance) in our testing, however many more + // are, in principle, possible. + if (false) {} + CALL_IF(8, 8, 256) + CALL_IF(16, 4, 256) + CALL_IF(8, 4, 128) + CALL_IF(4, 8, 128) + else + { + throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " + + str(prob_n) + "]" + ", thread_m_blocks = " + str(thread_m_blocks) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + } + + A += 16 * thread_m_blocks * (prob_k / 16) * par; + D += 16 * thread_m_blocks * (prob_n / 8) * par; + s1 += 16 * thread_m_blocks * par; + } +} + +} // namespace marlin_qqq +} // namespace turbomind diff --git a/src/turbomind/kernels/marlin_qqq_gemm/marlin_qqq_gemm_kernel.h b/src/turbomind/kernels/marlin_qqq_gemm/marlin_qqq_gemm_kernel.h new file mode 100644 index 000000000..b457425f1 --- /dev/null +++ b/src/turbomind/kernels/marlin_qqq_gemm/marlin_qqq_gemm_kernel.h @@ -0,0 +1,39 @@ +#pragma once + +namespace turbomind { +namespace marlin_qqq { + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +static constexpr int default_threads = 256; + +static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory + +static constexpr int min_thread_n = 64; +// TODO(HandH1998): 64 or 128? 
+static constexpr int min_thread_k = 64; + +static constexpr int tile_size = 16; +static constexpr int max_par = 16; +static constexpr int pack_factor_4bit = 8; // We have 8 4-bit vals inside a 32 bit + +void invokeMarlinQQQGemm(const int4* A, + const int4* B, + int4* C, + int4* D, + const float* s1, + const int4* s2, + const int4* s3, + int prob_m, + int prob_n, + int prob_k, + int* locks, + int thread_n_blocks, + int thread_k_blocks, + int group_blocks, + int num_threads, + cudaStream_t stream); + +} // namespace marlin_qqq +} // namespace turbomind diff --git a/src/turbomind/kernels/quant_kernels.cu b/src/turbomind/kernels/quant_kernels.cu new file mode 100644 index 000000000..b4b9f5384 --- /dev/null +++ b/src/turbomind/kernels/quant_kernels.cu @@ -0,0 +1,94 @@ +#include "quant_kernels.h" +#include "reduce_kernel_utils.cuh" +#include "src/turbomind/utils/cuda_type_utils.cuh" +#include "src/turbomind/utils/cuda_utils.h" + +namespace turbomind { + +template +__global__ void +int8_quant_kernel(const T* __restrict__ input, int8_t* __restrict__ out, float* scale, const int hidden_size) +{ + using T2 = typename TypeConverter::Type; + const int tid = threadIdx.x; + const int token_idx = blockIdx.x; + float absmax_val = 0.0f; + + int16_t* out_ptr = (int16_t*)out; + const T2* input_ptr = (const T2*)input; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + T2 val = cuda_abs(input_ptr[token_idx * hidden_size + i]); + T val_max = cuda_max(val); + absmax_val = cuda_max(absmax_val, cuda_cast(val_max)); + } + + const float block_absmax_val_maybe = blockReduceMax(absmax_val); + __shared__ float block_absmax_val; + if (tid == 0) { + block_absmax_val = block_absmax_val_maybe; + scale[token_idx] = block_absmax_val / 127.0f; + } + __syncthreads(); + + const float tmp_scale = 127.0f / block_absmax_val; + for (int i = tid; i < hidden_size; i += blockDim.x) { + float2 val = cuda_cast(input_ptr[token_idx * hidden_size + i]); + float2 tmp = val * tmp_scale; + out_ptr[token_idx * hidden_size + i] = cuda_cast(tmp); + } +} + +template<> +__global__ void +int8_quant_kernel(const float* __restrict__ input, int8_t* __restrict__ out, float* scale, const int hidden_size) +{ + const int tid = threadIdx.x; + const int token_idx = blockIdx.x; + float absmax_val = 0.0f; + float const zero = 0.0f; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + absmax_val = cuda_max(absmax_val, cuda_abs(input[token_idx * hidden_size + i])); + } + + const float block_absmax_val_maybe = blockReduceMax(absmax_val); + __shared__ float block_absmax_val; + if (tid == 0) { + block_absmax_val = block_absmax_val_maybe; + scale[token_idx] = block_absmax_val / 127.0f; + } + __syncthreads(); + + const float tmp_scale = 127.0f / block_absmax_val; + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[token_idx * hidden_size + i] = cuda_cast(input[token_idx * hidden_size + i] * tmp_scale); + } +} + +template +void invokeI8Quant(const T* input, int8_t* out, float* scale, const int token_num, int hidden_size, cudaStream_t stream) +{ + if (sizeof(T) == 2) { + FT_CHECK(hidden_size % 2 == 0); + hidden_size /= 2; + } + dim3 grid(token_num); + dim3 block(std::min(hidden_size, 1024)); + int8_quant_kernel<<>>(input, out, scale, hidden_size); + sync_check_cuda_error(); +} + +#define INSTANTIATE_I8QUANT(T) \ + template void invokeI8Quant( \ + const T* input, int8_t* out, float* scale, const int token_num, const int hidden_size, cudaStream_t stream) + +INSTANTIATE_I8QUANT(half); +#ifdef ENABLE_FP32 +INSTANTIATE_I8QUANT(float); +#endif 
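
For reference, the per-token symmetric quantization that `int8_quant_kernel` implements boils down to scale = absmax / 127 and q = round(x * 127 / absmax). Below is a minimal CPU sketch of the same math; the helper name is hypothetical and the rounding/saturation behaviour of `cuda_cast<int8_t>` may differ in the last bit, so treat it as a reference, not a drop-in.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // One scale per token (row); values are quantized symmetrically to int8.
    // (Like the kernel, this assumes the per-token absmax is non-zero.)
    void i8_quant_reference(const std::vector<float>& x,      // [token_num * hidden_size]
                            std::vector<int8_t>&      q,      // same shape as x
                            std::vector<float>&       scale,  // [token_num]
                            int token_num, int hidden_size)
    {
        for (int t = 0; t < token_num; ++t) {
            float absmax = 0.f;
            for (int i = 0; i < hidden_size; ++i)
                absmax = std::max(absmax, std::fabs(x[t * hidden_size + i]));
            scale[t] = absmax / 127.f;
            const float inv = 127.f / absmax;  // same factor the kernel applies
            for (int i = 0; i < hidden_size; ++i) {
                float v = std::nearbyint(x[t * hidden_size + i] * inv);
                q[t * hidden_size + i] = static_cast<int8_t>(std::clamp(v, -127.f, 127.f));
            }
        }
    }
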
+#ifdef ENABLE_BF16 +INSTANTIATE_I8QUANT(__nv_bfloat16); +#endif + +} // namespace turbomind diff --git a/src/turbomind/kernels/quant_kernels.h b/src/turbomind/kernels/quant_kernels.h new file mode 100644 index 000000000..939d7e16d --- /dev/null +++ b/src/turbomind/kernels/quant_kernels.h @@ -0,0 +1,13 @@ + +#pragma once + +#include +#include + +namespace turbomind { + +template +void invokeI8Quant( + const T* input, int8_t* out, float* scale, const int token_num, const int hidden_size, cudaStream_t stream); + +} // namespace turbomind diff --git a/src/turbomind/models/llama/CMakeLists.txt b/src/turbomind/models/llama/CMakeLists.txt index 898a74252..e72e3cc5f 100644 --- a/src/turbomind/models/llama/CMakeLists.txt +++ b/src/turbomind/models/llama/CMakeLists.txt @@ -23,6 +23,8 @@ set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(Llama PUBLIC CUDA::cudart gemm_s4_f16 + marlin_qqq_gemm + quant_kernels cublasMMWrapper DynamicDecodeLayer activation_kernels diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc index a522c6c1a..d1b9ff343 100644 --- a/src/turbomind/models/llama/LlamaBatch.cc +++ b/src/turbomind/models/llama/LlamaBatch.cc @@ -719,6 +719,13 @@ void LlamaBatch::AllocateBuffer(size_t batch_size, size_t session_len, int ca context_decoder_output_buf_, sizeof(T) * max_forward_token_num_ * hidden_units, false); } + if (model_->quantization_ == QuantMethod::QQQ) { + context_decoder_quant_output_buf_ = (int8_t*)allocator_->reMalloc( + context_decoder_quant_output_buf_, sizeof(int8_t) * max_context_token_num_ * hidden_units, false); + context_decoder_quant_scale_buf_ = (float*)allocator_->reMalloc( + context_decoder_quant_scale_buf_, sizeof(float) * max_context_token_num_, false); + } + context_decoder_input_buf_ = (T*)allocator_->reMalloc(context_decoder_input_buf_, sizeof(T) * max_forward_token_num_ * hidden_units, false); context_decoder_ids_buf_ = @@ -857,6 +864,8 @@ void LlamaBatch::FreeBuffer() allocator_->free((void**)&decoder_input_buf_); allocator_->free((void**)&decoder_output_buf_); + allocator_->free((void**)&context_decoder_quant_output_buf_); + allocator_->free((void**)&context_decoder_quant_scale_buf_); allocator_->free((void**)&input_ids_buf_); allocator_->free((void**)&input_length_buf_); @@ -1644,8 +1653,10 @@ bool LlamaBatch::Forward(GenerationState& g) } model_->forwardUnified(decoder_output_buf_ + first * model_->hidden_units_, - context_decoder_output_buf_, // temp - context_decoder_input_buf_, // temp + context_decoder_output_buf_, // temp + context_decoder_input_buf_, // temp + context_decoder_quant_output_buf_, // temp + context_decoder_quant_scale_buf_, // temp (void**)block_ptrs_, cu_block_counts_ + first, context_decoder_ids_buf_, // temp diff --git a/src/turbomind/models/llama/LlamaBatch.h b/src/turbomind/models/llama/LlamaBatch.h index f0345af6d..80f1ec735 100644 --- a/src/turbomind/models/llama/LlamaBatch.h +++ b/src/turbomind/models/llama/LlamaBatch.h @@ -205,10 +205,12 @@ class LlamaBatch { //////////////////////////////////////////////////////////////////// // context decoding temp buffers - T* context_decoder_input_buf_{}; - T* context_decoder_output_buf_{}; - int* context_decoder_ids_buf_{}; - int* input_ids_buf_{}; + T* context_decoder_input_buf_{}; + T* context_decoder_output_buf_{}; + int8_t* context_decoder_quant_output_buf_{}; + float* context_decoder_quant_scale_buf_{}; + int* context_decoder_ids_buf_{}; + 
int* input_ids_buf_{}; // lengths int* input_length_buf_{}; // input + cache missed length int* context_length_buf_{}; // history length + input_length diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc index d055248d5..aca36e6c2 100644 --- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc +++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc @@ -19,6 +19,7 @@ // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoderLayerWeight.cc #include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h" +#include "src/turbomind/kernels/marlin_qqq_gemm/marlin_qqq_gemm_kernel.h" #include "src/turbomind/models/llama/LlamaDenseWeight.h" #include "src/turbomind/utils/logger.h" #include "src/turbomind/utils/memory_utils.h" @@ -27,17 +28,18 @@ namespace turbomind { template -LlamaDecoderLayerWeight::LlamaDecoderLayerWeight(int layer_idx, - size_t head_num, - size_t kv_head_num, - size_t size_per_head, - size_t inter_size, - WeightType weight_type, - int group_size, - LoraParams lora_params, - bool attn_bias, - size_t tensor_para_size, - size_t tensor_para_rank): +LlamaDecoderLayerWeight::LlamaDecoderLayerWeight(int layer_idx, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t inter_size, + WeightType weight_type, + int group_size, + QuantMethod quantization, + LoraParams lora_params, + bool attn_bias, + size_t tensor_para_size, + size_t tensor_para_rank): head_num_(head_num), kv_head_num_(kv_head_num), size_per_head_(size_per_head), @@ -84,37 +86,45 @@ LlamaDecoderLayerWeight::LlamaDecoderLayerWeight(int layer_idx, } } } - fused_up_and_gate_ = weight_type_ == WeightType::kINT4 && ffn_weights.gating.lora.policy != LoraPolicy::kPlora; - - self_attn_weights.qkv.input_dims = hidden_units_; - self_attn_weights.qkv.output_dims = (head_num + 2 * kv_head_num) * size_per_head / tensor_para_size_; - self_attn_weights.qkv.type = weight_type; - self_attn_weights.qkv.group_size = group_size; - - self_attn_weights.output.input_dims = hidden_units_ / tensor_para_size_; - self_attn_weights.output.output_dims = hidden_units_; - self_attn_weights.output.type = weight_type; - self_attn_weights.output.group_size = group_size; - - ffn_weights.gating.input_dims = hidden_units_; - ffn_weights.gating.output_dims = inter_size_ / tensor_para_size_; - ffn_weights.gating.type = weight_type; - ffn_weights.gating.group_size = group_size; - - ffn_weights.intermediate.input_dims = hidden_units_; - ffn_weights.intermediate.output_dims = inter_size_ / tensor_para_size_; - ffn_weights.intermediate.type = weight_type; - ffn_weights.intermediate.group_size = group_size; - - ffn_weights.fused_gating_intermediate.input_dims = hidden_units_; - ffn_weights.fused_gating_intermediate.output_dims = inter_size_ / tensor_para_size_ * 2; - ffn_weights.fused_gating_intermediate.type = weight_type; - ffn_weights.fused_gating_intermediate.group_size = group_size; - - ffn_weights.output.input_dims = inter_size_ / tensor_para_size_; - ffn_weights.output.output_dims = hidden_units_; - ffn_weights.output.type = weight_type; - ffn_weights.output.group_size = group_size; + // only fuse up and gate for AWQ + // fused_up_and_gate_ = weight_type_ == WeightType::kINT4 && ffn_weights.gating.lora.policy != LoraPolicy::kPlora; + fused_up_and_gate_ = quantization == QuantMethod::AWQ && ffn_weights.gating.lora.policy != LoraPolicy::kPlora; + + self_attn_weights.qkv.input_dims = hidden_units_; + 
self_attn_weights.qkv.output_dims = (head_num + 2 * kv_head_num) * size_per_head / tensor_para_size_; + self_attn_weights.qkv.type = weight_type; + self_attn_weights.qkv.group_size = group_size; + self_attn_weights.qkv.quantization = quantization; + + self_attn_weights.output.input_dims = hidden_units_ / tensor_para_size_; + self_attn_weights.output.output_dims = hidden_units_; + self_attn_weights.output.type = weight_type; + self_attn_weights.output.group_size = group_size; + self_attn_weights.output.quantization = quantization; + + ffn_weights.gating.input_dims = hidden_units_; + ffn_weights.gating.output_dims = inter_size_ / tensor_para_size_; + ffn_weights.gating.type = weight_type; + ffn_weights.gating.group_size = group_size; + ffn_weights.gating.quantization = quantization; + + ffn_weights.intermediate.input_dims = hidden_units_; + ffn_weights.intermediate.output_dims = inter_size_ / tensor_para_size_; + ffn_weights.intermediate.type = weight_type; + ffn_weights.intermediate.group_size = group_size; + ffn_weights.intermediate.quantization = quantization; + + ffn_weights.fused_gating_intermediate.input_dims = hidden_units_; + ffn_weights.fused_gating_intermediate.output_dims = inter_size_ / tensor_para_size_ * 2; + ffn_weights.fused_gating_intermediate.type = weight_type; + ffn_weights.fused_gating_intermediate.group_size = group_size; + ffn_weights.fused_gating_intermediate.quantization = quantization; + + ffn_weights.output.input_dims = inter_size_ / tensor_para_size_; + ffn_weights.output.output_dims = hidden_units_; + ffn_weights.output.type = weight_type; + ffn_weights.output.group_size = group_size; + ffn_weights.output.quantization = quantization; mallocWeights(); } @@ -149,11 +159,34 @@ void mallocWeights(LlamaDenseWeight& weights, bool bias) } else { // int8, int4 const int factor = sizeof(float) * 8 / bit_size; - FT_CHECK(weights.input_dims % factor == 0); + // TODO(HandH1998): check if it should use `output_dims % factor` + // FT_CHECK(weights.input_dims % factor == 0); + FT_CHECK(weights.output_dims % factor == 0); + if (weights.group_size != -1) + FT_CHECK(weights.input_dims % weights.group_size == 0); + if (weights.quantization == QuantMethod::AWQ) { + // interleaved scales/zeros + deviceMalloc((T**)&weights.scales_and_zeros, + weights.input_dims / weights.group_size * weights.output_dims * 2); + } + else if (weights.quantization == QuantMethod::QQQ) { + FT_CHECK(weights.input_dims % marlin_qqq::min_thread_k == 0); + FT_CHECK(weights.output_dims % marlin_qqq::min_thread_n == 0); + // QQQ employs sym quantization + if (weights.group_size == -1) { + deviceMalloc((float**)&weights.scales_channel, weights.output_dims); + } + else { + deviceMalloc((T**)&weights.scales_and_zeros, + weights.input_dims / weights.group_size * weights.output_dims); + deviceMalloc((float**)&weights.scales_channel, weights.output_dims); + } + } + else { + FT_CHECK(0); + } deviceMalloc((int**)&weights.kernel, weights.input_dims * weights.output_dims / factor); deviceMemSetZero((int*)weights.kernel, weights.input_dims * weights.output_dims / factor); - // interleaved scales/zeros - deviceMalloc((T**)&weights.scales_and_zeros, weights.input_dims / weights.group_size * weights.output_dims * 2); } if (weights.lora.r > 0) { @@ -196,11 +229,33 @@ void getWeightTensor(LlamaDenseWeight& weights, bool bias, const std::string& TYPE_INT32, {weights.input_dims * weights.output_dims * sizeof(int) / factor}, weights.kernel}); - output.insert(get_name("scales_zeros"), - Tensor{MEMORY_GPU, - getTensorType(), - 
{weights.input_dims / weights.group_size * weights.output_dims * 2 * sizeof(T)}, - weights.scales_and_zeros}); + if (weights.quantization == QuantMethod::AWQ) { + output.insert(get_name("scales_zeros"), + Tensor{MEMORY_GPU, + getTensorType(), + {weights.input_dims / weights.group_size * weights.output_dims * 2 * sizeof(T)}, + weights.scales_and_zeros}); + } + else if (weights.quantization == QuantMethod::QQQ) { + if (weights.group_size == -1) { + output.insert( + get_name("scales_channel"), + Tensor{MEMORY_GPU, TYPE_FP32, {weights.output_dims * sizeof(float)}, weights.scales_channel}); + } + else { + output.insert(get_name("scales_zeros"), + Tensor{MEMORY_GPU, + getTensorType(), + {weights.input_dims / weights.group_size * weights.output_dims * sizeof(T)}, + weights.scales_and_zeros}); + output.insert( + get_name("scales_channel"), + Tensor{MEMORY_GPU, TYPE_FP32, {weights.output_dims * sizeof(float)}, weights.scales_channel}); + } + } + else { + FT_CHECK(0); + } } if (weights.lora.r) { @@ -318,15 +373,31 @@ void loadWeights(LlamaDenseWeight& w, } else { // int8, int4 const int factor = sizeof(float) * 8 / bit_size; - FT_CHECK(dim1 % factor == 0); - + if (w.group_size != -1) + FT_CHECK(dim0 % w.group_size == 0); + const size_t group_count = w.group_size > 0 ? dim0 / w.group_size : 1; + if (w.quantization == QuantMethod::AWQ) { + loadWeightFromBin((half*)w.scales_and_zeros, {group_count, dim1 * 2}, prefix + ".scales_zeros", type, {}); + } + else if (w.quantization == QuantMethod::QQQ) { + FT_CHECK(dim0 % marlin_qqq::min_thread_k == 0); + FT_CHECK(dim1 % marlin_qqq::min_thread_n == 0); + if (w.group_size == -1) { + loadWeightFromBin( + (float*)w.scales_channel, {1, dim1}, prefix + ".scales_channel", FtCudaDataType::FP32, {}); + } + else { + loadWeightFromBin((half*)w.scales_and_zeros, {group_count, dim1}, prefix + ".scales_zeros", type, {}); + loadWeightFromBin( + (float*)w.scales_channel, {1, dim1}, prefix + ".scales_channel", FtCudaDataType::FP32, {}); + } + } + else { + FT_CHECK(0); + } std::vector w_shape{dim0, dim1 / factor * sizeof(uint32_t)}; loadWeightFromBin((int8_t*)w.kernel, w_shape, prefix + ".qweight", FtCudaDataType::INT8, {}); - - const size_t group_count = w.group_size > 0 ? 
dim0 / w.group_size : 1; - - loadWeightFromBin((half*)w.scales_and_zeros, {group_count, dim1 * 2}, prefix + ".scales_zeros", type, {}); } } diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.h b/src/turbomind/models/llama/LlamaDecoderLayerWeight.h index 5086adf8e..ac743a637 100644 --- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.h +++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.h @@ -22,6 +22,7 @@ #include "src/turbomind/models/llama/LlamaDenseWeight.h" #include "src/turbomind/models/llama/llama_params.h" +#include "src/turbomind/models/llama/llama_utils.h" #include "src/turbomind/utils/Tensor.h" namespace turbomind { @@ -30,17 +31,18 @@ template struct LlamaDecoderLayerWeight { public: LlamaDecoderLayerWeight() = delete; - LlamaDecoderLayerWeight(int layer_idx, - size_t head_num, - size_t kv_head_num, - size_t size_per_head, - size_t inter_size, - WeightType weight_type, - int group_size, - LoraParams lora_params, - bool attn_bias, - size_t tensor_para_size, - size_t tensor_para_rank); + LlamaDecoderLayerWeight(int layer_idx, + size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t inter_size, + WeightType weight_type, + int group_size, + QuantMethod quantization, + LoraParams lora_params, + bool attn_bias, + size_t tensor_para_size, + size_t tensor_para_rank); ~LlamaDecoderLayerWeight(); LlamaDecoderLayerWeight(const LlamaDecoderLayerWeight& other) = delete; LlamaDecoderLayerWeight& operator=(const LlamaDecoderLayerWeight& other) = delete; diff --git a/src/turbomind/models/llama/LlamaDenseWeight.h b/src/turbomind/models/llama/LlamaDenseWeight.h index 220185234..a17d9d27f 100644 --- a/src/turbomind/models/llama/LlamaDenseWeight.h +++ b/src/turbomind/models/llama/LlamaDenseWeight.h @@ -19,6 +19,7 @@ #pragma once +#include "src/turbomind/models/llama/llama_utils.h" #include "src/turbomind/utils/cuda_utils.h" namespace turbomind { @@ -76,14 +77,16 @@ struct LoraWeight { template struct LlamaDenseWeight { - size_t input_dims; - size_t output_dims; - void* kernel; - LoraWeight lora; - WeightType type; - T* bias; - T* scales_and_zeros; - int group_size; + size_t input_dims; + size_t output_dims; + void* kernel; + LoraWeight lora; + WeightType type; + T* bias; + T* scales_and_zeros; + float* scales_channel; + int group_size; + QuantMethod quantization; }; template diff --git a/src/turbomind/models/llama/LlamaFfnLayer.cc b/src/turbomind/models/llama/LlamaFfnLayer.cc index 464e9a2ce..5e6058b25 100644 --- a/src/turbomind/models/llama/LlamaFfnLayer.cc +++ b/src/turbomind/models/llama/LlamaFfnLayer.cc @@ -19,6 +19,8 @@ #include "src/turbomind/models/llama/LlamaFfnLayer.h" #include "src/turbomind/kernels/activation_kernels.h" +#include "src/turbomind/kernels/marlin_qqq_gemm/marlin_qqq_gemm_kernel.h" +#include "src/turbomind/kernels/quant_kernels.h" #include "src/turbomind/models/llama/LlamaNcclGuard.h" #include "src/turbomind/models/llama/llama_utils.h" #include "src/turbomind/utils/anomaly_handler.h" @@ -31,24 +33,59 @@ void LlamaFfnLayer::allocateBuffer(size_t token_num, const LlamaDenseWeight* gating, const LlamaDenseWeight* inter) { - size_t sz = sizeof(T) * token_num * inter_size_; - size_t sz_gate = (gating->lora.r > 0) ? sz + sz / inter_size_ * gating->lora.r : sz; - size_t sz_inter = (inter->lora.r > 0) ? 
sz + sz / inter_size_ * inter->lora.r : sz; - inter_buf_ = (T*)allocator_->reMalloc(inter_buf_, sz_inter, false); - gating_buf_ = (T*)allocator_->reMalloc(gating_buf_, sz_gate, false); + size_t sz = sizeof(T) * token_num * inter_size_; + size_t sz_gate = (gating->lora.r > 0) ? sz + sz / inter_size_ * gating->lora.r : sz; + size_t sz_inter = (inter->lora.r > 0) ? sz + sz / inter_size_ * inter->lora.r : sz; + inter_buf_ = (T*)allocator_->reMalloc(inter_buf_, sz_inter, false); + gating_buf_ = (T*)allocator_->reMalloc(gating_buf_, sz_gate, false); + if (quantization_ == QuantMethod::QQQ) { + quant_buf_ = (int8_t*)allocator_->reMalloc(quant_buf_, sizeof(int8_t) * token_num * inter_size_, false); + act_scale_buf_ = (float*)allocator_->reMalloc(act_scale_buf_, sizeof(float) * token_num, false); + } is_allocate_buffer_ = true; } +template +void LlamaFfnLayer::allocateWorkspace() +{ + if (quantization_ == QuantMethod::QQQ) { + size_t max_dims = std::max(inter_size_, hidden_units_); + size_t sz_reduction = sizeof(int) * marlin_qqq::max_par * 64 * max_dims; + size_t sz_workspace = sizeof(int) * marlin_qqq::max_par * (max_dims / marlin_qqq::min_thread_n); + + auto [reduce_buf, workspace_buf] = linear_.getQQQBuffer(); + reduce_buf = (int*)allocator_->malloc(sz_reduction, false); + workspace_buf = (int*)allocator_->malloc(sz_workspace, true); + linear_.setQQQBuffer(reduce_buf, workspace_buf); + } + is_allocate_workspace_ = true; +} + template void LlamaFfnLayer::freeBuffer() { if (is_allocate_buffer_) { allocator_->free((void**)&inter_buf_); allocator_->free((void**)&gating_buf_); + allocator_->free((void**)&quant_buf_); + allocator_->free((void**)&act_scale_buf_); is_allocate_buffer_ = false; } } +template +void LlamaFfnLayer::freeWorkspace() +{ + if (is_allocate_workspace_) { + // free qqq workspace + auto [reduce_buf, workspace_buf] = linear_.getQQQBuffer(); + allocator_->free((void**)&reduce_buf); + allocator_->free((void**)&workspace_buf); + + is_allocate_workspace_ = false; + } +} + template void LlamaFfnLayer::activation(int num_token) { @@ -59,6 +96,8 @@ void LlamaFfnLayer::activation(int num_token) (const T*)nullptr, // gated_bias nullptr, // ia3_tasks (const T*)nullptr, // ia3_weights + quant_buf_, // quant_out + act_scale_buf_, // quant_scale num_token, // m inter_size_, // n 0, // int8_mode @@ -78,7 +117,8 @@ void LlamaFfnLayer::forward(TensorMap* output_tensors, /** * input_tensors: * \param ffn_input [token_num, hidden_dimension] - * + * \param ffn_quant_input [token_num, hidden_dimension] + * \param ffn_quant_scale [token_num, hidden_dimension] * output_tensors: * \param ffn_output [token_num, hidden_dimension] */ @@ -91,29 +131,50 @@ void LlamaFfnLayer::forward(TensorMap* output_tensors, allocateBuffer(num_token, &weights->gating, &weights->intermediate); - const T* ffn_input_data = input_tensors->at("ffn_input").getPtr(); - T* ffn_output_data = output_tensors->at("ffn_output").getPtr(); + const T* ffn_input_data = input_tensors->at("ffn_input").getPtr(); + int8_t* ffn_quant_input_data = input_tensors->at("ffn_quant_input").getPtr(); + float* ffn_quant_scale_data = input_tensors->at("ffn_quant_scale").getPtr(); + T* ffn_output_data = output_tensors->at("ffn_output").getPtr(); int* lora_mask = input_tensors->at("lora_mask", Tensor{MEMORY_GPU, TYPE_INVALID, {}, nullptr}).getPtr(); if (weights->fused_gating_intermediate.kernel) { NvtxScope scope("fused_silu_ffn"); - linear_.forward( - gating_buf_, ffn_input_data, num_token, weights->fused_gating_intermediate, LlamaLinear::kFusedSiluFfn); + 
linear_.forward(gating_buf_, + ffn_input_data, + ffn_quant_input_data, + ffn_quant_scale_data, + num_token, + weights->fused_gating_intermediate, + LlamaLinear::kFusedSiluFfn); count_and_fix(gating_buf_, num_token * weights->output.input_dims, Concat("w1_w3_silu", layer_id), 3); } else { { // w1(x) NvtxScope scope("w1"); - linear_.forward(gating_buf_, ffn_input_data, num_token, weights->gating, LlamaLinear::kGemm, lora_mask); + linear_.forward(gating_buf_, + ffn_input_data, + ffn_quant_input_data, + ffn_quant_scale_data, + num_token, + weights->gating, + LlamaLinear::kGemm, + lora_mask); } - count_and_fix(gating_buf_, num_token * weights->gating.output_dims, Concat("w1", layer_id), 3); { // w3(x) NvtxScope scope("w3"); - linear_.forward( - inter_buf_, ffn_input_data, num_token, weights->intermediate, LlamaLinear::kGemm, lora_mask); + linear_.forward(inter_buf_, + ffn_input_data, + ffn_quant_input_data, + ffn_quant_scale_data, + num_token, + weights->intermediate, + LlamaLinear::kGemm, + lora_mask); } + + count_and_fix(gating_buf_, num_token * weights->gating.output_dims, Concat("w1", layer_id), 3); count_and_fix(inter_buf_, num_token * weights->intermediate.output_dims, Concat("w3", layer_id), 3); // silu(w1(x)) * w3(x) @@ -124,7 +185,14 @@ void LlamaFfnLayer::forward(TensorMap* output_tensors, { // w2(x) NvtxScope scope("w2"); - linear_.forward(ffn_output_data, gating_buf_, num_token, weights->output, LlamaLinear::kGemm, lora_mask); + linear_.forward(ffn_output_data, + gating_buf_, + quant_buf_, + act_scale_buf_, + num_token, + weights->output, + LlamaLinear::kGemm, + lora_mask); } count_and_fix(ffn_output_data, num_token * weights->output.output_dims, Concat("w2", layer_id), 3); diff --git a/src/turbomind/models/llama/LlamaFfnLayer.h b/src/turbomind/models/llama/LlamaFfnLayer.h index 6a414305c..9d25091ea 100644 --- a/src/turbomind/models/llama/LlamaFfnLayer.h +++ b/src/turbomind/models/llama/LlamaFfnLayer.h @@ -38,22 +38,26 @@ class LlamaFfnLayer { cudaStream_t stream, cublasMMWrapper* cublas_wrapper, IAllocator* allocator, - bool is_free_buffer_after_forward): + bool is_free_buffer_after_forward, + QuantMethod quantization): head_num_(head_num), size_per_head_(size_per_head), inter_size_(inter_size / tensor_para.world_size_), hidden_units_(head_num * size_per_head), stream_(stream), - linear_(cublas_wrapper, stream), + linear_(cublas_wrapper, stream, allocator), allocator_(allocator), tensor_para_(tensor_para), - is_free_buffer_after_forward_(is_free_buffer_after_forward) + is_free_buffer_after_forward_(is_free_buffer_after_forward), + quantization_(quantization) { + allocateWorkspace(); } ~LlamaFfnLayer() { freeBuffer(); + freeWorkspace(); } void forward(TensorMap* output_tensors, const TensorMap* input_tensors, const LlamaFfnWeight* weights); @@ -63,6 +67,10 @@ class LlamaFfnLayer { void freeBuffer(); + void allocateWorkspace(); + + void freeWorkspace(); + void activation(int num_token); size_t head_num_; @@ -73,13 +81,17 @@ class LlamaFfnLayer { LlamaLinear linear_; IAllocator* allocator_; bool is_free_buffer_after_forward_; + QuantMethod quantization_; - T* gating_buf_{}; - T* inter_buf_{}; + T* gating_buf_{}; + T* inter_buf_{}; + int8_t* quant_buf_{}; + float* act_scale_buf_{}; NcclParam tensor_para_; bool is_allocate_buffer_{}; + bool is_allocate_workspace_{}; }; } // namespace turbomind diff --git a/src/turbomind/models/llama/LlamaLinear.h b/src/turbomind/models/llama/LlamaLinear.h index b00fb58c5..181b75e95 100644 --- a/src/turbomind/models/llama/LlamaLinear.h +++ 
b/src/turbomind/models/llama/LlamaLinear.h @@ -3,10 +3,13 @@ #pragma once #include "src/turbomind/kernels/gemm_s_f16/gemm_s4_f16.h" +#include "src/turbomind/kernels/marlin_qqq_gemm/marlin_qqq_gemm.h" #include "src/turbomind/models/llama/LlamaDenseWeight.h" #include "src/turbomind/models/llama/llama_decoder_kernels.h" #include "src/turbomind/models/llama/llama_kernels.h" #include "src/turbomind/models/llama/llama_params.h" +#include "src/turbomind/models/llama/llama_utils.h" +#include "src/turbomind/utils/allocator.h" #include "src/turbomind/utils/cublasMMWrapper.h" #include "src/turbomind/utils/cuda_utils.h" #include "src/turbomind/utils/logger.h" @@ -24,12 +27,15 @@ class LlamaLinear { kFusedAdd }; - LlamaLinear(cublasMMWrapper* cublas_wrapper, cudaStream_t stream): cublas_wrapper_(cublas_wrapper), stream_(stream) + LlamaLinear(cublasMMWrapper* cublas_wrapper, cudaStream_t stream, IAllocator* allocator): + cublas_wrapper_(cublas_wrapper), stream_(stream), gemm_s4_s8_(allocator) { } void forward(T* output_data, const T* input_data, + int8_t* quant_input_data, + float* quant_scale, int batch_size, const LlamaDenseWeight& weight, Type type = kGemm, @@ -69,21 +75,31 @@ class LlamaLinear { invokeMask(output_data, lora_mask, batch_size, weight.output_dims, stream_); type = kFusedAdd; } - switch (weight.type) { - case WeightType::kFP16: - case WeightType::kFP32: - case WeightType::kBF16: + switch (weight.quantization) { + case QuantMethod::QNone: forwardFp(output_data, input_data, batch_size, weight, type); break; - case WeightType::kINT4: + case QuantMethod::AWQ: forwardInt4(output_data, input_data, batch_size, weight, type); break; + case QuantMethod::QQQ: + forwardQQQ(output_data, quant_input_data, quant_scale, batch_size, weight, type); break; default: FT_CHECK(0); } } + void setQQQBuffer(int* reduce_buf, int* workspace_buf) + { + gemm_s4_s8_.setBuffer(reduce_buf, workspace_buf); + } + + std::pair getQQQBuffer() + { + return gemm_s4_s8_.getBuffer(); + } + private: void forwardFp(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight& weight, Type type) { @@ -129,10 +145,40 @@ class LlamaLinear { } } + // w4a8 + void forwardQQQ(T* output_data, + const int8_t* input_data, + const float* act_scale, + int batch_size, + const LlamaDenseWeight& weight, + Type type) + { + // qqq only supports kGemm + FT_CHECK(type == kGemm); + if constexpr (std::is_same_v) { + gemm_s4_s8_.Run(output_data, + input_data, + (const uint*)weight.kernel, + act_scale, + (const float*)weight.scales_channel, + (const half*)weight.scales_and_zeros, + batch_size, + weight.output_dims, + weight.input_dims, + weight.group_size, + stream_); + sync_check_cuda_error(); + } + else { + FT_CHECK_WITH_INFO(0, "Not implemented"); + } + } + private: - cublasMMWrapper* cublas_wrapper_; - cudaStream_t stream_{}; - GemmS4F16 gemm_s4_f16_; + cublasMMWrapper* cublas_wrapper_; + cudaStream_t stream_{}; + GemmS4F16 gemm_s4_f16_; + marlin_qqq::MarlinQQQGemm gemm_s4_s8_; }; } // namespace turbomind diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc index 0387338c1..ab1c30ee1 100644 --- a/src/turbomind/models/llama/LlamaV2.cc +++ b/src/turbomind/models/llama/LlamaV2.cc @@ -55,6 +55,7 @@ LlamaV2::LlamaV2(size_t head_num, int start_id, int end_id, int cache_block_seq_len, + QuantMethod quantization, int quant_policy, bool use_context_fmha, const EngineParams& engine_params, @@ -83,6 +84,7 @@ LlamaV2::LlamaV2(size_t head_num, local_kv_head_num_(kv_head_num / tensor_para.world_size_), 
weights_(weights), tensor_para_(tensor_para), + quantization_(quantization), stream_(stream), cublas_wrapper_(cublas_wrapper), allocator_(allocator), @@ -141,7 +143,8 @@ void LlamaV2::initialize(const LlamaAttentionParams& attn_params, is_free_buffer_after_forward_, use_context_fmha, cache_block_seq_len, - quant_policy)); + quant_policy, + quantization_)); dynamic_decode_layer_ = new DynamicDecodeLayer(vocab_size_, vocab_size_padded_, @@ -237,6 +240,8 @@ template void LlamaV2::forwardUnified(T* out, T* decoder_output, T* decoder_input, + int8_t* decoder_quant_output, + float* decoder_quant_scale, void** block_ptrs, const int* cu_block_cnts, const int* input_ids, @@ -298,7 +303,10 @@ void LlamaV2::forwardUnified(T* out, if (lora_mask != nullptr && have_embeddings) { inputs.insert({"lora_mask", {MEMORY_GPU, TYPE_INT32, {token_num}, lora_mask}}); } - + // decoder_quant_output and decoder_quant_scale may be non Tensor, so we use insert pair to avoid tensor validity + // check + outputs.insert({"decoder_quant_output", {MEMORY_GPU, TYPE_INT8, {token_num, hidden_units_}, decoder_quant_output}}); + outputs.insert({"decoder_quant_scale", {MEMORY_GPU, TYPE_FP32, {token_num}, decoder_quant_scale}}); unified_decoder_->forward(&outputs, &inputs, &weights_->decoder_layer_weights); } diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h index 61d83b90e..c0366f500 100644 --- a/src/turbomind/models/llama/LlamaV2.h +++ b/src/turbomind/models/llama/LlamaV2.h @@ -28,6 +28,7 @@ #include "src/turbomind/models/llama/Request.h" #include "src/turbomind/models/llama/SequenceManager.h" #include "src/turbomind/models/llama/llama_params.h" +#include "src/turbomind/models/llama/llama_utils.h" #include "src/turbomind/models/llama/unified_decoder.h" #include "src/turbomind/utils/allocator.h" #include "src/turbomind/utils/cublasMMWrapper.h" @@ -65,6 +66,7 @@ class LlamaV2 { int start_id, int end_id, int cache_block_seq_len, + QuantMethod quantization, int quant_policy, bool use_context_fmha, const EngineParams& engine_params, @@ -122,6 +124,8 @@ class LlamaV2 { void forwardUnified(T* out, T* decoder_output, T* decoder_input, + int8_t* decoder_quant_output, + float* decoder_quant_scale, void** block_ptrs, const int* cu_block_cnts, const int* input_ids, @@ -195,6 +199,7 @@ class LlamaV2 { ffi_api_lock_ctrl_t ffi_lock_; std::unique_ptr> batch_; LoraParams lora_params_; + QuantMethod quantization_; }; } // namespace turbomind diff --git a/src/turbomind/models/llama/LlamaWeight.cc b/src/turbomind/models/llama/LlamaWeight.cc index c87bc4089..3115b5639 100644 --- a/src/turbomind/models/llama/LlamaWeight.cc +++ b/src/turbomind/models/llama/LlamaWeight.cc @@ -23,18 +23,19 @@ namespace turbomind { template -LlamaWeight::LlamaWeight(size_t head_num, - size_t kv_head_num, - size_t size_per_head, - size_t inter_size, - size_t vocab_size, - size_t num_layer, - bool attn_bias, - WeightType weight_type, - int group_size, - LoraParams lora_params, - size_t tensor_para_size, - size_t tensor_para_rank): +LlamaWeight::LlamaWeight(size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t inter_size, + size_t vocab_size, + size_t num_layer, + bool attn_bias, + WeightType weight_type, + int group_size, + QuantMethod quantization, + LoraParams lora_params, + size_t tensor_para_size, + size_t tensor_para_rank): hidden_units_(head_num * size_per_head), inter_size_(inter_size), vocab_size_(vocab_size), @@ -57,6 +58,7 @@ LlamaWeight::LlamaWeight(size_t head_num, inter_size_, weight_type_, group_size, 
+ quantization, lora_params, attn_bias, tensor_para_size_, diff --git a/src/turbomind/models/llama/LlamaWeight.h b/src/turbomind/models/llama/LlamaWeight.h index 65eb986d8..401ca0143 100644 --- a/src/turbomind/models/llama/LlamaWeight.h +++ b/src/turbomind/models/llama/LlamaWeight.h @@ -22,6 +22,7 @@ #include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h" #include "src/turbomind/models/llama/llama_params.h" +#include "src/turbomind/models/llama/llama_utils.h" #include "src/turbomind/utils/memory_utils.h" namespace turbomind { @@ -29,18 +30,19 @@ namespace turbomind { template struct LlamaWeight { LlamaWeight() = default; - LlamaWeight(size_t head_num, - size_t kv_head_num, - size_t size_per_head, - size_t inter_size, - size_t vocab_size, - size_t num_layer, - bool attn_bias, - WeightType weight_type, - int group_size, - LoraParams lora_params, - size_t tensor_para_size, - size_t tensor_para_rank); + LlamaWeight(size_t head_num, + size_t kv_head_num, + size_t size_per_head, + size_t inter_size, + size_t vocab_size, + size_t num_layer, + bool attn_bias, + WeightType weight_type, + int group_size, + QuantMethod quantization, + LoraParams lora_params, + size_t tensor_para_size, + size_t tensor_para_rank); ~LlamaWeight(); diff --git a/src/turbomind/models/llama/llama_decoder_kernels.cu b/src/turbomind/models/llama/llama_decoder_kernels.cu index f0ed63ca7..a592548c2 100644 --- a/src/turbomind/models/llama/llama_decoder_kernels.cu +++ b/src/turbomind/models/llama/llama_decoder_kernels.cu @@ -12,11 +12,26 @@ namespace cg = cooperative_groups; namespace turbomind { +template +struct QuantTypeConverter { + using Type = short4; +}; + +template<> +struct QuantTypeConverter<8> { + using Type = short4; +}; + +template<> +struct QuantTypeConverter<4> { + using Type = short2; +}; + template struct res_norm_ops_t { }; -template +template struct res_norm_t { res_norm_ops_t f; __device__ uint4 addvec(const uint4& a, const uint4& b, const uint4& bias, float& accum) const @@ -37,6 +52,17 @@ struct res_norm_t { v.w = f.cast(f.norm(f.cast(u.w), f.cast(s.w), factor)); return v; } + __device__ T absmax(const uint4& x) const + { + auto v1 = f.max(f.max(f.abs(x.x)), f.max(f.abs(x.y))); + auto v2 = f.max(f.max(f.abs(x.z)), f.max(f.abs(x.w))); + auto v = f.max(v1, v2); + return v; + } + __device__ T_Q quant(const uint4& x, float scale) const + { + return f.quant(x, scale); + } }; template<> @@ -50,6 +76,19 @@ struct res_norm_ops_t { auto y = __float22half2_rn(x); return reinterpret_cast(y); } + __device__ half2 abs(const uint& x) const + { + uint t = const_cast(x); + return cuda_abs(reinterpret_cast(t)); + } + __device__ half max(const half2& x) const + { + return (x.x > x.y) ? x.x : x.y; + } + __device__ half max(const half& x, const half& y) const + { + return (x > y) ? 
x : y; + } __device__ float2 add(const float2& a, const float2& b, const float2& bias, float& accum) const { float2 c{a.x + b.x + bias.x, a.y + b.y + bias.y}; @@ -60,6 +99,22 @@ struct res_norm_ops_t { { return {a.x * s.x * factor, a.y * s.y * factor}; } + __device__ float2 mul(const uint& x, float scale) const + { + float2 res = cast(x); + res.x *= scale; + res.y *= scale; + return res; + } + __device__ short4 quant(const uint4& x, float scale) const + { + short4 res; + res.x = cuda_cast(mul(x.x, scale)); + res.y = cuda_cast(mul(x.y, scale)); + res.z = cuda_cast(mul(x.z, scale)); + res.w = cuda_cast(mul(x.w, scale)); + return res; + } }; template<> @@ -72,6 +127,20 @@ struct res_norm_ops_t { { return reinterpret_cast(x); } + __device__ float abs(const uint& x) const + { + uint t = const_cast(x); + return cuda_abs(reinterpret_cast(t)); + } + // for generality + __device__ float max(const float& x) const + { + return x; + } + __device__ float max(const float& x, const float& y) const + { + return (x > y) ? x : y; + } __device__ float add(const float& a, const float& b, const float& bias, float& accum) const { float c = a + b + bias; @@ -82,6 +151,20 @@ struct res_norm_ops_t { { return a * s * factor; } + __device__ float2 mul(const uint& x, const uint& y, float scale) const + { + float2 res; + res.x = cast(x) * scale; + res.y = cast(y) * scale; + return res; + } + __device__ short2 quant(const uint4& x, float scale) const + { + short2 res; + res.x = cuda_cast(mul(x.x, x.y, scale)); + res.y = cuda_cast(mul(x.z, x.w, scale)); + return res; + } }; #ifdef ENABLE_BF16 @@ -96,6 +179,19 @@ struct res_norm_ops_t<__nv_bfloat16> { auto y = cuda_cast<__nv_bfloat162, float2>(x); return reinterpret_cast(y); } + __device__ __nv_bfloat162 abs(const uint& x) const + { + uint t = const_cast(x); + return cuda_abs(reinterpret_cast<__nv_bfloat162&>(t)); + } + __device__ __nv_bfloat16 max(const __nv_bfloat162& x) const + { + return (x.x > x.y) ? x.x : x.y; + } + __device__ __nv_bfloat16 max(const __nv_bfloat16& x, const __nv_bfloat16& y) const + { + return (x > y) ? x : y; + } __device__ float2 add(const float2& a, const float2& b, const float2& bias, float& accum) const { float2 c{a.x + b.x + bias.x, a.y + b.y + bias.y}; @@ -106,6 +202,22 @@ struct res_norm_ops_t<__nv_bfloat16> { { return {a.x * s.x * factor, a.y * s.y * factor}; } + __device__ float2 mul(const uint& x, float scale) const + { + float2 res = cast(x); + res.x *= scale; + res.y *= scale; + return res; + } + __device__ short4 quant(const uint4& x, float scale) const + { + short4 res; + res.x = cuda_cast(mul(x.x, scale)); + res.y = cuda_cast(mul(x.y, scale)); + res.z = cuda_cast(mul(x.z, scale)); + res.w = cuda_cast(mul(x.w, scale)); + return res; + } }; #endif @@ -128,13 +240,33 @@ __device__ T blockReduceSum(const cg::thread_block& block, T value) return cg::reduce(tile, value, cg::plus{}); } +template +__device__ T blockReduceMax(const cg::thread_block& block, T value) +{ + __shared__ float partial[32]; + + auto tile = cg::tiled_partition<32>(block); + value = cg::reduce(tile, value, cg::greater{}); + + if (tile.thread_rank() == 0) { + partial[tile.meta_group_rank()] = value; + } + + block.sync(); + + value = tile.thread_rank() < tile.meta_group_size() ? 
partial[tile.thread_rank()] : T(-1e20f); + return cg::reduce(tile, value, cg::greater{}); +} + // r' = r + x // x' = norm(r') * scales -template +template __global__ void fusedAddBiasResidualNorm(T* __restrict__ r_data, T* __restrict__ x_data, const T* __restrict__ bias, const T* __restrict__ scale, + int8_t* __restrict__ quant_out, + float* __restrict__ quant_scale, float eps, int batch_size, int n_dims) @@ -148,8 +280,10 @@ __global__ void fusedAddBiasResidualNorm(T* __restrict__ r_data, uint4* __restrict__ r_ptr = reinterpret_cast(r_data + batch_idx * n_dims); uint4* __restrict__ x_ptr = reinterpret_cast(x_data + batch_idx * n_dims); const uint4* __restrict__ b_ptr = reinterpret_cast(bias); + using T_Q = typename QuantTypeConverter::Type; + T_Q* __restrict__ q_ptr = reinterpret_cast(quant_out + batch_idx * n_dims); - res_norm_t ops; + res_norm_t ops; float thread_sum{}; for (auto i = block.thread_rank(); i < n_dims / PACK_DIM; i += block.size()) { @@ -165,17 +299,47 @@ __global__ void fusedAddBiasResidualNorm(T* __restrict__ r_data, float s_inv_mean = rsqrt(total_sum / n_dims + eps); const uint4* __restrict__ s_ptr = reinterpret_cast(scale); - for (uint i = block.thread_rank(); i < n_dims / PACK_DIM; i += block.size()) { - auto r = r_ptr[i]; - auto s = s_ptr[i]; - auto o = ops.normvec(r, s, s_inv_mean); - x_ptr[i] = o; + if constexpr (enable_quant) { + float thread_max{}; + for (uint i = block.thread_rank(); i < n_dims / PACK_DIM; i += block.size()) { + auto r = r_ptr[i]; + auto s = s_ptr[i]; + auto o = ops.normvec(r, s, s_inv_mean); + x_ptr[i] = o; + thread_max = cuda_max(thread_max, cuda_cast(ops.absmax(o))); + } + auto total_max = blockReduceMax(block, thread_max); + if (block.thread_rank() == 0) { + quant_scale[batch_idx] = total_max / 127.0f; + } + const float tmp_scale = 127.0f / total_max; + for (uint i = block.thread_rank(); i < n_dims / PACK_DIM; i += block.size()) { + auto x = x_ptr[i]; + auto q = ops.quant(x, tmp_scale); + q_ptr[i] = q; + } + } + else { + for (uint i = block.thread_rank(); i < n_dims / PACK_DIM; i += block.size()) { + auto r = r_ptr[i]; + auto s = s_ptr[i]; + auto o = ops.normvec(r, s, s_inv_mean); + x_ptr[i] = o; + } } } template -void invokeFusedAddBiasResidualRMSNorm( - T* residual, T* in_out, const T* bias, const T* scale, float eps, int batch_size, int n_dims, cudaStream_t stream) +void invokeFusedAddBiasResidualRMSNorm(T* residual, + T* in_out, + const T* bias, + const T* scale, + int8_t* quant_out, + float* quant_scale, + float eps, + int batch_size, + int n_dims, + cudaStream_t stream) { constexpr int PACK_DIM = sizeof(uint4) / sizeof(T); FT_CHECK(n_dims % PACK_DIM == 0); @@ -184,8 +348,15 @@ void invokeFusedAddBiasResidualRMSNorm( int n_threads = (n_pack + n_iter - 1) / n_iter; // adjust block size to avoid tail effect n_threads = (n_threads + 31) / 32 * 32; // round up to the nearest multiple of warp size - fusedAddBiasResidualNorm<<>>( - residual, in_out, bias, scale, eps, batch_size, n_dims); + if (quant_out == nullptr) { + fusedAddBiasResidualNorm<<>>( + residual, in_out, bias, scale, quant_out, quant_scale, eps, batch_size, n_dims); + } + else { + fusedAddBiasResidualNorm<<>>( + residual, in_out, bias, scale, quant_out, quant_scale, eps, batch_size, n_dims); + } + sync_check_cuda_error(); } template @@ -206,15 +377,24 @@ void invokeMask(T* output, const int* mask, int batch_size, int dim, cudaStream_ } #ifdef ENABLE_FP32 -template void -invokeFusedAddBiasResidualRMSNorm(float*, float*, const float*, const float*, float, int, int, cudaStream_t); 
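// Plain C++ reference of the enable_quant branch of fusedAddBiasResidualNorm above,
// written per token row (row-major [batch_size, n_dims]; no zero-amax guard, mirroring
// the kernel): r' = r + x + bias, x' = rmsnorm(r') * gamma, then symmetric per-token
// int8 quantization with scale = max|x'| / 127.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

void fused_add_bias_residual_rmsnorm_quant_ref(std::vector<float>& residual,     // r, updated in place
                                               std::vector<float>& out,          // x in, x' out
                                               const std::vector<float>& bias,
                                               const std::vector<float>& gamma,  // RMSNorm scale
                                               std::vector<int8_t>& quant_out,
                                               std::vector<float>& quant_scale,  // one float per token
                                               float eps, int batch_size, int n_dims)
{
    for (int b = 0; b < batch_size; ++b) {
        float sum_sq = 0.f;
        for (int i = 0; i < n_dims; ++i) {
            float r = residual[b * n_dims + i] + out[b * n_dims + i] + bias[i];
            residual[b * n_dims + i] = r;
            sum_sq += r * r;
        }
        const float inv_rms = 1.f / std::sqrt(sum_sq / n_dims + eps);
        float amax = 0.f;
        for (int i = 0; i < n_dims; ++i) {
            float y = residual[b * n_dims + i] * inv_rms * gamma[i];
            out[b * n_dims + i] = y;
            amax = std::max(amax, std::fabs(y));
        }
        quant_scale[b] = amax / 127.f;
        const float inv_scale = 127.f / amax;
        for (int i = 0; i < n_dims; ++i) {
            quant_out[b * n_dims + i] = static_cast<int8_t>(std::nearbyint(out[b * n_dims + i] * inv_scale));
        }
    }
}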
+template void invokeFusedAddBiasResidualRMSNorm( + float*, float*, const float*, const float*, int8_t*, float*, float, int, int, cudaStream_t); template void invokeMask(float* output, const int* mask, int batch_size, int dim, cudaStream_t stream); #endif -template void invokeFusedAddBiasResidualRMSNorm(half*, half*, const half*, const half*, float, int, int, cudaStream_t); +template void invokeFusedAddBiasResidualRMSNorm( + half*, half*, const half*, const half*, int8_t*, float*, float, int, int, cudaStream_t); template void invokeMask(half* output, const int* mask, int batch_size, int dim, cudaStream_t stream); #ifdef ENABLE_BF16 -template void invokeFusedAddBiasResidualRMSNorm( - __nv_bfloat16*, __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, float, int, int, cudaStream_t); +template void invokeFusedAddBiasResidualRMSNorm(__nv_bfloat16*, + __nv_bfloat16*, + const __nv_bfloat16*, + const __nv_bfloat16*, + int8_t*, + float*, + float, + int, + int, + cudaStream_t); template void invokeMask(__nv_bfloat16* output, const int* mask, int batch_size, int dim, cudaStream_t stream); #endif } // namespace turbomind diff --git a/src/turbomind/models/llama/llama_decoder_kernels.h b/src/turbomind/models/llama/llama_decoder_kernels.h index 9d4dc51fe..f2cd636e5 100644 --- a/src/turbomind/models/llama/llama_decoder_kernels.h +++ b/src/turbomind/models/llama/llama_decoder_kernels.h @@ -5,8 +5,16 @@ namespace turbomind { template -void invokeFusedAddBiasResidualRMSNorm( - T* residual, T* in_out, const T* bias, const T* scale, float eps, int batch_size, int n_dims, cudaStream_t stream); +void invokeFusedAddBiasResidualRMSNorm(T* residual, + T* in_out, + const T* bias, + const T* scale, + int8_t*, + float*, + float eps, + int batch_size, + int n_dims, + cudaStream_t stream); template void invokeMask(T* output, const int* mask, int batch_size, int dim, cudaStream_t stream); diff --git a/src/turbomind/models/llama/llama_kernels.cu b/src/turbomind/models/llama/llama_kernels.cu index 11be59d0c..fd7fa7eaa 100644 --- a/src/turbomind/models/llama/llama_kernels.cu +++ b/src/turbomind/models/llama/llama_kernels.cu @@ -20,16 +20,18 @@ namespace turbomind { // fp16, bf16 // n is divided by 2 for this impl -template -__global__ void rootMeanSquareNorm(T* out, const T* input, const T* scale, float eps, int m, int n) +template +__global__ typename std::enable_if::value>::type rootMeanSquareNorm( + T* out, const T* input, const T* scale, int8_t* quant_out, float* quant_scale, float eps, int m, int n) { using T2 = typename TypeConverter::Type; __shared__ float s_inv_mean; float mean = 0.f; - T2* out_ptr = (T2*)out; - const T2* input_ptr = (const T2*)input; - const T2* scale_ptr = (const T2*)scale; + T2* out_ptr = (T2*)out; + const T2* input_ptr = (const T2*)input; + const T2* scale_ptr = (const T2*)scale; + uint16_t* quant_out_ptr = (uint16_t*)quant_out; for (uint idx = threadIdx.x; idx < n; idx += blockDim.x) { float2 tmp2 = cuda_cast(input_ptr[blockIdx.x * n + idx]); @@ -43,17 +45,42 @@ __global__ void rootMeanSquareNorm(T* out, const T* input, const T* scale, float } __syncthreads(); - for (uint idx = threadIdx.x; idx < n; idx += blockDim.x) { - float2 tmp2 = cuda_cast(input_ptr[blockIdx.x * n + idx]); - float2 sca2 = cuda_cast(scale_ptr[idx]); - tmp2.x = tmp2.x * s_inv_mean * sca2.x; - tmp2.y = tmp2.y * s_inv_mean * sca2.y; - out_ptr[blockIdx.x * n + idx] = cuda_cast(tmp2); + if constexpr (enable_quant) { + __shared__ float s_amax; + float amax_val = 0.0f; + for (uint idx = threadIdx.x; idx < n; idx += 
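// Usage sketch of the widened launcher (argument names are illustrative): null quant
// pointers keep the original fp16/bf16 behaviour, while non-null pointers additionally
// emit int8 activations plus one float scale per token for the following QQQ GEMM.
// invokeRootMeanSquareNorm below follows the same null-pointer convention.
invokeFusedAddBiasResidualRMSNorm(residual, hidden_states, attn_out_bias, ffn_norm_weight,
                                  /*quant_out=*/nullptr, /*quant_scale=*/nullptr,
                                  rmsnorm_eps, token_num, hidden_units, stream);

invokeFusedAddBiasResidualRMSNorm(residual, hidden_states, attn_out_bias, ffn_norm_weight,
                                  decoder_quant_output, decoder_quant_scale,
                                  rmsnorm_eps, token_num, hidden_units, stream);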
blockDim.x) { + float2 tmp2 = cuda_cast(input_ptr[blockIdx.x * n + idx]); + float2 sca2 = cuda_cast(scale_ptr[idx]); + tmp2 = tmp2 * s_inv_mean * sca2; + out_ptr[blockIdx.x * n + idx] = cuda_cast(tmp2); + amax_val = cuda_max(amax_val, cuda_max(cuda_abs(tmp2.x), cuda_abs(tmp2.y))); + } + amax_val = blockReduceMax(amax_val); + if (threadIdx.x == 0) { + s_amax = amax_val; + quant_scale[blockIdx.x] = amax_val / 127.0f; + } + __syncthreads(); + const float tmp_scale = 127.0f / s_amax; + for (uint idx = threadIdx.x; idx < n; idx += blockDim.x) { + float2 tmp2 = cuda_cast(out_ptr[blockIdx.x * n + idx]); + tmp2 = tmp2 * tmp_scale; + quant_out_ptr[blockIdx.x * n + idx] = cuda_cast(tmp2); + } + } + else { + for (uint idx = threadIdx.x; idx < n; idx += blockDim.x) { + float2 tmp2 = cuda_cast(input_ptr[blockIdx.x * n + idx]); + float2 sca2 = cuda_cast(scale_ptr[idx]); + tmp2 = tmp2 * s_inv_mean * sca2; + out_ptr[blockIdx.x * n + idx] = cuda_cast(tmp2); + } } } -template<> -__global__ void rootMeanSquareNorm(float* out, const float* input, const float* scale, float eps, int m, int n) +template +__global__ typename std::enable_if::value>::type rootMeanSquareNorm( + T* out, const T* input, const T* scale, int8_t* quant_out, float* quant_scale, float eps, int m, int n) { __shared__ float s_inv_mean; float mean = 0.f; @@ -69,14 +96,44 @@ __global__ void rootMeanSquareNorm(float* out, const float* input, const float* } __syncthreads(); - for (uint idx = threadIdx.x; idx < n; idx += blockDim.x) { - float tmp = input[blockIdx.x * n + idx]; - out[blockIdx.x * n + idx] = tmp * s_inv_mean * scale[idx]; + if constexpr (enable_quant) { + __shared__ float s_amax; + float amax_val = 0.0f; + for (uint idx = threadIdx.x; idx < n; idx += blockDim.x) { + float tmp = input[blockIdx.x * n + idx] * s_inv_mean * scale[idx]; + out[blockIdx.x * n + idx] = tmp; + amax_val = cuda_max(amax_val, cuda_abs(tmp)); + } + amax_val = blockReduceMax(amax_val); + if (threadIdx.x == 0) { + s_amax = amax_val; + quant_scale[blockIdx.x] = amax_val / 127.0f; + } + __syncthreads(); + const float tmp_scale = 127.0f / s_amax; + for (uint idx = threadIdx.x; idx < n; idx += blockDim.x) { + float tmp = out[blockIdx.x * n + idx]; + quant_out[blockIdx.x * n + idx] = cuda_cast(tmp * tmp_scale); + } + } + else { + for (uint idx = threadIdx.x; idx < n; idx += blockDim.x) { + float tmp = input[blockIdx.x * n + idx]; + out[blockIdx.x * n + idx] = tmp * s_inv_mean * scale[idx]; + } } } template -void invokeRootMeanSquareNorm(T* out, const T* input, const T* scale, float eps, int m, int n, cudaStream_t stream) +void invokeRootMeanSquareNorm(T* out, + const T* input, + const T* scale, + int8_t* quant_out, + float* quant_scale, + float eps, + int m, + int n, + cudaStream_t stream) { if (sizeof(T) == 2) { FT_CHECK(n % 2 == 0); @@ -84,22 +141,23 @@ void invokeRootMeanSquareNorm(T* out, const T* input, const T* scale, float eps, } dim3 grid(m); dim3 block(std::min(n, 1024)); - rootMeanSquareNorm<<>>(out, input, scale, eps, m, n); + if (quant_out == nullptr) { + rootMeanSquareNorm<<>>(out, input, scale, quant_out, quant_scale, eps, m, n); + } + else { + rootMeanSquareNorm<<>>(out, input, scale, quant_out, quant_scale, eps, m, n); + } + sync_check_cuda_error(); } -template void invokeRootMeanSquareNorm(float*, const float*, const float*, float, int, int, cudaStream_t); -template void invokeRootMeanSquareNorm(half*, const half*, const half*, float, int, int, cudaStream_t); -#ifdef ENABLE_BF16 template void -invokeRootMeanSquareNorm(__nv_bfloat16*, const 
__nv_bfloat16*, const __nv_bfloat16*, float, int, int, cudaStream_t); +invokeRootMeanSquareNorm(float*, const float*, const float*, int8_t*, float*, float, int, int, cudaStream_t); +template void invokeRootMeanSquareNorm(half*, const half*, const half*, int8_t*, float*, float, int, int, cudaStream_t); +#ifdef ENABLE_BF16 +template void invokeRootMeanSquareNorm( + __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, int8_t*, float*, float, int, int, cudaStream_t); #endif -// #ifdef ENABLE_BF16 - -// template void invokeRootMeanSquareNorm(__nv_bfloat16*, const __nv_bfloat16*, float, int, int, cudaStream_t); - -// #endif - template __device__ T saturate_cast(T0 x) { diff --git a/src/turbomind/models/llama/llama_kernels.h b/src/turbomind/models/llama/llama_kernels.h index 3b01dee60..d9cc55b81 100644 --- a/src/turbomind/models/llama/llama_kernels.h +++ b/src/turbomind/models/llama/llama_kernels.h @@ -13,7 +13,15 @@ namespace turbomind { template -void invokeRootMeanSquareNorm(T* out, const T* input, const T* scale, float eps, int m, int n, cudaStream_t stream); +void invokeRootMeanSquareNorm(T* out, + const T* input, + const T* scale, + int8_t* quant_out, + float* quant_scale, + float eps, + int m, + int n, + cudaStream_t stream); template void invokeAddResidual(T* out, const T* in, int m, int n, cudaStream_t stream); diff --git a/src/turbomind/models/llama/llama_utils.h b/src/turbomind/models/llama/llama_utils.h index a97b94c37..6d353c676 100644 --- a/src/turbomind/models/llama/llama_utils.h +++ b/src/turbomind/models/llama/llama_utils.h @@ -21,6 +21,13 @@ enum QuantPolicy kCacheKVInt4 = 0x04, }; +enum QuantMethod +{ + QNone, + AWQ, + QQQ, +}; + enum CmpMode { kCmpNone, diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc index 89c9cf4d7..84c3dd581 100644 --- a/src/turbomind/models/llama/unified_attention_layer.cc +++ b/src/turbomind/models/llama/unified_attention_layer.cc @@ -23,6 +23,8 @@ #include "src/turbomind/kernels/attention/attention.h" #include "src/turbomind/kernels/attention/decoding.h" #include "src/turbomind/kernels/attention/kv_cache_utils_v2.h" +#include "src/turbomind/kernels/marlin_qqq_gemm/marlin_qqq_gemm_kernel.h" +#include "src/turbomind/kernels/quant_kernels.h" #include "src/turbomind/macro.h" #include "src/turbomind/models/llama/LlamaNcclGuard.h" #include "src/turbomind/models/llama/llama_kernels.h" @@ -57,6 +59,12 @@ void UnifiedAttentionLayer::allocateBuffer(size_t q_count, (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * q_count * local_q_kv_head_num * size_per_head_, false); } + if (quantization_ == QuantMethod::QQQ) { + quant_buf_ = (int8_t*)allocator_->reMalloc( + quant_buf_, sizeof(int8_t) * q_count * local_head_num_ * size_per_head_, false); + act_scale_buf_ = (float*)allocator_->reMalloc(act_scale_buf_, sizeof(float) * q_count, false); + } + qkv_buf_3_ = (T*)allocator_->reMalloc(qkv_buf_3_, sizeof(T) * q_count * local_head_num_ * size_per_head_, false); // Pad the tmp buffer for linear KV cache by `MAX_CTA_S` to avoid illegal accesses @@ -76,6 +84,16 @@ void UnifiedAttentionLayer::allocateWorkspace() partial_O_ = (float*)allocator_->malloc(sizeof(float) * kMaxWorkspaceTokens * local_head_num_ * size_per_head_); split_cnt_ = (int*)allocator_->malloc(sizeof(int) * kMaxWorkspaceTokens); barriers_ = (int*)allocator_->malloc(sizeof(int) * kMaxWorkspaceTokens * local_head_num_, true, false); + if (quantization_ == QuantMethod::QQQ) { + const size_t local_q_kv_head_num = local_head_num_ + 2 * 
local_kv_head_num_; + size_t max_dims = std::max(local_q_kv_head_num * size_per_head_, hidden_units_); + size_t sz_reduction = sizeof(int) * marlin_qqq::max_par * 64 * max_dims; + size_t sz_workspace = sizeof(int) * marlin_qqq::max_par * (max_dims / marlin_qqq::min_thread_n); + auto [reduce_buf, workspace_buf] = linear_.getQQQBuffer(); + reduce_buf = (int*)allocator_->malloc(sz_reduction, false); + workspace_buf = (int*)allocator_->malloc(sz_workspace, true); + linear_.setQQQBuffer(reduce_buf, workspace_buf); + } is_allocate_workspace_ = true; } @@ -90,6 +108,10 @@ void UnifiedAttentionLayer::freeWorkspace() allocator_->free((void**)&partial_O_); allocator_->free((void**)&split_cnt_); allocator_->free((void**)&barriers_); + // free qqq workspace + auto [reduce_buf, workspace_buf] = linear_.getQQQBuffer(); + allocator_->free((void**)&reduce_buf); + allocator_->free((void**)&workspace_buf); is_allocate_workspace_ = false; } @@ -103,6 +125,8 @@ void UnifiedAttentionLayer::freeBuffer() allocator_->free((void**)&qkv_buf_); allocator_->free((void**)&qkv_buf_3_); + allocator_->free((void**)&quant_buf_); + allocator_->free((void**)&act_scale_buf_); allocator_->free((void**)&tmp_kv_buf_); is_allocate_buffer_ = false; @@ -117,6 +141,8 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa /** * input_tensors: * \param input_query [token_num, hidden_dim] + * \param input_quant_query [token_num, hidden_dim] or empty, int8 + * \param input_quant_scale [token_num] or empty, float * \param cu_q_len [batch_size+1], int * \param cu_k_len [batch_size+1], int * \param cu_block_counts [batch_size+1], int @@ -157,8 +183,10 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa void** block_ptrs = outputs->getPtr("block_ptrs"); int* cu_block_count = inputs->getPtr("cu_block_counts"); - T* attention_input = inputs->getPtr("input_query"); - T* attention_out = outputs->getPtr("hidden_features"); + T* attention_input = inputs->getPtr("input_query"); + int8_t* attention_quant_input = inputs->getPtr("input_quant_query"); + float* attention_quant_scale = inputs->getPtr("input_quant_scale"); + T* attention_out = outputs->getPtr("hidden_features"); ///////////////////////////////////////////// /// allocate buffers @@ -177,10 +205,14 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa // } int* lora_mask = inputs->at("lora_mask", Tensor{MEMORY_GPU, TYPE_INVALID, {}, nullptr}).getPtr(); - ////////////////////////////////////////////// - /// qkv gemm - // [token_num, hidden_dim] -> [token_num, 3, local_hidden_dim] - linear_.forward(qkv_buf_, attention_input, token_num, weights->qkv, LlamaLinear::kGemm, lora_mask); + linear_.forward(qkv_buf_, + attention_input, + attention_quant_input, + attention_quant_scale, + token_num, + weights->qkv, + LlamaLinear::kGemm, + lora_mask); count_and_fix(qkv_buf_, token_num * weights->qkv.output_dims, Concat("qkv", layer_id), 3); @@ -317,12 +349,19 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa // } count_and_fix(qkv_buf_3_, token_num * weights->output.input_dims, Concat("attn", layer_id), 3); - + if (quantization_ == QuantMethod::QQQ) { + invokeI8Quant(qkv_buf_3_, quant_buf_, act_scale_buf_, token_num, local_head_num_ * size_per_head_, stream_); + } ////////////////////////////////////////////// /// output gemm -> - linear_.forward(attention_out, qkv_buf_3_, token_num, weights->output, LlamaLinear::kGemm, lora_mask); - - // ++count; + linear_.forward(attention_out, + 
qkv_buf_3_, + quant_buf_, + act_scale_buf_, + token_num, + weights->output, + LlamaLinear::kGemm, + lora_mask); count_and_fix(attention_out, token_num * weights->output.output_dims, Concat("wo", layer_id), 3); diff --git a/src/turbomind/models/llama/unified_attention_layer.h b/src/turbomind/models/llama/unified_attention_layer.h index 6b6bbba56..f4c9ba8e1 100644 --- a/src/turbomind/models/llama/unified_attention_layer.h +++ b/src/turbomind/models/llama/unified_attention_layer.h @@ -61,7 +61,8 @@ class UnifiedAttentionLayer { IAllocator* allocator, bool is_free_buffer_after_forward, int cache_block_seq_len, - int quant_policy): + int quant_policy, + QuantMethod quantization): head_num_(head_num), size_per_head_(size_per_head), hidden_units_(head_num * size_per_head), @@ -73,11 +74,12 @@ class UnifiedAttentionLayer { lora_params_(lora_params), stream_(stream), cublas_wrapper_(cublas_wrapper), - linear_(cublas_wrapper, stream), + linear_(cublas_wrapper, stream, allocator), allocator_(allocator), kv_cache_block_len_(cache_block_seq_len), is_free_buffer_after_forward_(is_free_buffer_after_forward), - quant_policy_(quant_policy) + quant_policy_(quant_policy), + quantization_(quantization) { FT_CHECK(head_num % kv_head_num == 0); arch_ = getSMVersion(); @@ -141,7 +143,8 @@ class UnifiedAttentionLayer { const LlamaAttentionParams params_; - const int quant_policy_; + const int quant_policy_; + QuantMethod quantization_; NcclParam tensor_para_; @@ -178,6 +181,9 @@ class UnifiedAttentionLayer { int* barriers_{}; // always zero T* tmp_kv_buf_{}; + // online act quant + int8_t* quant_buf_{}; + float* act_scale_buf_{}; bool is_allocate_buffer_ = false; bool is_allocate_workspace_ = false; diff --git a/src/turbomind/models/llama/unified_decoder.cc b/src/turbomind/models/llama/unified_decoder.cc index 914436b34..6f8cc4fe2 100644 --- a/src/turbomind/models/llama/unified_decoder.cc +++ b/src/turbomind/models/llama/unified_decoder.cc @@ -34,7 +34,8 @@ template void UnifiedDecoder::initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, int cache_block_seq_len, - int quant_policy) + int quant_policy, + QuantMethod quantization) { attn_layer_ = new UnifiedAttentionLayer(head_num_, kv_head_num, @@ -47,7 +48,8 @@ void UnifiedDecoder::initialize(const LlamaAttentionParams& attn_params, allocator_, is_free_buffer_after_forward_, cache_block_seq_len, - quant_policy); + quant_policy, + quantization); ffn_layer_ = new LlamaFfnLayer(head_num_, size_per_head_, @@ -56,13 +58,16 @@ void UnifiedDecoder::initialize(const LlamaAttentionParams& attn_params, stream_, cublas_wrapper_, allocator_, - is_free_buffer_after_forward_); + is_free_buffer_after_forward_, + quantization); check_cuda_error(cudaEventCreateWithFlags(&ev_h_cu_x_, cudaEventDisableTiming)); } template void UnifiedDecoder::forwardSelfAttn(T* attn_io, + int8_t* attn_qi, + float* attn_qs, TensorMap* _outputs, const TensorMap* _inputs, size_t token_num, @@ -72,6 +77,8 @@ void UnifiedDecoder::forwardSelfAttn(T* attn_io, { TensorMap inputs(*_inputs); inputs.insert("input_query", {MEMORY_GPU, dtype_, {token_num, hidden_units_}, attn_io}); + inputs.insert({"input_quant_query", {MEMORY_GPU, TYPE_INT8, {token_num, hidden_units_}, attn_qi}}); + inputs.insert({"input_quant_scale", {MEMORY_GPU, TYPE_FP32, {token_num}, attn_qs}}); inputs.insert("layer_id", {MEMORY_CPU, TYPE_INT32, {1}, &layer_id}); inputs.insert("cu_q_len", {MEMORY_GPU, TYPE_INT32, {batch_size + 1}, cu_q_len_}); inputs.insert("cu_k_len", {MEMORY_GPU, TYPE_INT32, {batch_size + 1}, 
cu_k_len_}); @@ -110,6 +117,8 @@ void UnifiedDecoder::forward(TensorMap* outputs, const TensorMap* inputs, con * * output tensors: * \param decoder_output [num_token, hidden_units], + * \param decoder_quant_output [num_token, hidden_units] or empty, int8 + * \param decoder_quant_scale [num_token] or empty, float * \param last_token_hidden_units [batch_size, hidden_units] * \param block_ptrs [total_block_counts], void* */ @@ -123,9 +132,11 @@ void UnifiedDecoder::forward(TensorMap* outputs, const TensorMap* inputs, con const int* h_q_len = inputs->getPtr("h_q_len"); const int* h_k_len = inputs->getPtr("h_k_len"); - T* decoder_input_output = inputs->getPtr("decoder_input"); - T* decoder_output = outputs->getPtr("decoder_output"); - T* last_token_hidden_units = outputs->getPtr("last_token_hidden_units"); + T* decoder_input_output = inputs->getPtr("decoder_input"); + T* decoder_output = outputs->getPtr("decoder_output"); + int8_t* decoder_quant_output = outputs->getPtr("decoder_quant_output"); + float* decoder_quant_scale = outputs->getPtr("decoder_quant_scale"); + T* last_token_hidden_units = outputs->getPtr("last_token_hidden_units"); { // compute cumulative lengths @@ -156,6 +167,8 @@ void UnifiedDecoder::forward(TensorMap* outputs, const TensorMap* inputs, con invokeRootMeanSquareNorm(decoder_output, decoder_input_output, weights->at(0)->self_attn_norm_weights, + decoder_quant_output, + decoder_quant_scale, rmsnorm_eps_, token_num, hidden_units_, @@ -171,6 +184,8 @@ void UnifiedDecoder::forward(TensorMap* outputs, const TensorMap* inputs, con ///////////////////////////////////////////// /// self-attention forwardSelfAttn(decoder_output, // + decoder_quant_output, + decoder_quant_scale, outputs, inputs, token_num, @@ -184,6 +199,8 @@ void UnifiedDecoder::forward(TensorMap* outputs, const TensorMap* inputs, con decoder_output, weights->at(layer)->self_attn_weights.output.bias, weights->at(layer)->ffn_norm_weights, + decoder_quant_output, + decoder_quant_scale, rmsnorm_eps_, token_num, hidden_units_, @@ -198,6 +215,9 @@ void UnifiedDecoder::forward(TensorMap* outputs, const TensorMap* inputs, con int layer_id = layer; // int is needed TensorMap ffn_inputs{{"ffn_input", {MEMORY_GPU, dtype_, {token_num, hidden_units_}, decoder_output}}, {"layer_id", {MEMORY_CPU, TYPE_INT32, {1}, &layer_id}}}; + ffn_inputs.insert( + {"ffn_quant_input", {MEMORY_GPU, TYPE_INT8, {token_num, hidden_units_}, decoder_quant_output}}); + ffn_inputs.insert({"ffn_quant_scale", {MEMORY_GPU, TYPE_INT8, {token_num}, decoder_quant_scale}}); TensorMap ffn_outputs{{"ffn_output", {MEMORY_GPU, dtype_, {token_num, hidden_units_}, decoder_output}}}; if (inputs->isExist("lora_mask")) { ffn_inputs.insert({"lora_mask", inputs->at("lora_mask")}); @@ -215,6 +235,8 @@ void UnifiedDecoder::forward(TensorMap* outputs, const TensorMap* inputs, con decoder_output, weights->at(layer)->ffn_weights.output.bias, scale_weight, + decoder_quant_output, + decoder_quant_scale, rmsnorm_eps_, token_num, hidden_units_, diff --git a/src/turbomind/models/llama/unified_decoder.h b/src/turbomind/models/llama/unified_decoder.h index 7dde36cb9..2401803c5 100644 --- a/src/turbomind/models/llama/unified_decoder.h +++ b/src/turbomind/models/llama/unified_decoder.h @@ -15,8 +15,11 @@ class UnifiedDecoder { protected: void freeBuffer(); - void - initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, int cache_block_seq_len, int quant_policy); + void initialize(const LlamaAttentionParams& attn_params, + size_t kv_head_num, + int 
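// The LlamaFfnLayer consumer of the two entries inserted above is outside this hunk;
// given that the layer is constructed with the same `quantization` flag, it can be
// expected to read them and hand them to its GEMMs much like the attention layer does.
// A sketch only (tensor-map variable name is illustrative, not the actual FFN code):
int8_t* ffn_quant_input = input_tensors->getPtr<int8_t>("ffn_quant_input");
float*  ffn_quant_scale = input_tensors->getPtr<float>("ffn_quant_scale");
// ... then passed to linear_.forward alongside the fp activations, mirroring the
// attention projections shown earlier in this patch.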
cache_block_seq_len, + int quant_policy, + QuantMethod quantization); cudaStream_t stream_; cublasMMWrapper* cublas_wrapper_; @@ -52,6 +55,8 @@ class UnifiedDecoder { using WeightType = LlamaDecoderLayerWeight; void forwardSelfAttn(T* attn_io, + int8_t* attn_qi, + float* attn_qs, TensorMap* _outputs, const TensorMap* _inputs, size_t token_num, @@ -75,7 +80,8 @@ class UnifiedDecoder { bool is_free_buffer_after_forward, bool use_fmha, int cache_block_seq_len, - int quant_policy): + int quant_policy, + QuantMethod quantization): stream_(stream), cublas_wrapper_(cublas_wrapper), allocator_(allocator), @@ -90,7 +96,7 @@ class UnifiedDecoder { tensor_para_(tensor_para), dtype_(getTensorType()) { - initialize(attn_params, kv_head_num, cache_block_seq_len, quant_policy); + initialize(attn_params, kv_head_num, cache_block_seq_len, quant_policy, quantization); } void allocateBuffer(size_t max_batch_size); diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index 87fd2cdf5..424b1b4cf 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -274,6 +274,16 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, std::cout << "[ERROR] Unsupported weight type: '" << weight_type_str << "'\n"; ft::FT_CHECK(0); } + const std::string quantization_str = reader.Get("llama", "quantization"); + if (quantization_str == "awq") { + quantization_ = turbomind::QuantMethod::AWQ; + } + else if (quantization_str == "qqq") { + quantization_ = turbomind::QuantMethod::QQQ; + } + else { + quantization_ = turbomind::QuantMethod::QNone; + } TM_LOG_INFO("%s", toString().c_str()); } @@ -344,6 +354,7 @@ std::unique_ptr> LlamaTritonModel::createSh start_id_, end_id_, cache_block_seq_len_, + quantization_, quant_policy_, use_context_fmha_, engine_params_, @@ -415,6 +426,7 @@ void LlamaTritonModel::createSharedWeights(int device_id, int rank) attn_bias_, weight_type_, group_size_, + quantization_, lora_params_, tensor_para_size_, tensor_para_rank); diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.h b/src/turbomind/triton_backend/llama/LlamaTritonModel.h index c0a0ebf3a..47f1fb101 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.h +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.h @@ -22,6 +22,7 @@ #include "src/turbomind/models/llama/LlamaV2.h" #include "src/turbomind/models/llama/llama_params.h" +#include "src/turbomind/models/llama/llama_utils.h" #include "src/turbomind/triton_backend/llama/LlamaTritonModelInstance.h" #include "src/turbomind/triton_backend/transformer_triton_backend.hpp" #include "src/turbomind/utils/cuda_utils.h" @@ -100,6 +101,7 @@ struct LlamaTritonModel: public AbstractTransformerModel { ft::WeightType weight_type_; bool attn_bias_; int quant_policy_; + turbomind::QuantMethod quantization_; int group_size_; turbomind::LoraParams lora_params_; diff --git a/src/turbomind/utils/cuda_type_utils.cuh b/src/turbomind/utils/cuda_type_utils.cuh index f7f7b9527..69c873f52 100644 --- a/src/turbomind/utils/cuda_type_utils.cuh +++ b/src/turbomind/utils/cuda_type_utils.cuh @@ -309,6 +309,16 @@ __device__ inline half2 cuda_cast(half val) { return __half2half2(val); } +template<> +__device__ inline half cuda_cast(float val) +{ + return __float2half(val); +} +template<> +__device__ inline float cuda_cast(half val) +{ + return __half2float(val); +} template<> __device__ inline int8_t cuda_cast(half val)
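// The string-to-enum mapping done inline in LlamaTritonModel above, written out as a
// standalone helper for clarity (a sketch; unknown or empty values fall back to QNone,
// matching the else branch of the patch).
#include <string>
#include "src/turbomind/models/llama/llama_utils.h"  // QuantMethod

static turbomind::QuantMethod parseQuantMethod(const std::string& s)
{
    if (s == "awq") {
        return turbomind::QuantMethod::AWQ;
    }
    if (s == "qqq") {
        return turbomind::QuantMethod::QQQ;
    }
    return turbomind::QuantMethod::QNone;  // unknown / missing -> no weight quantization
}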