From 9ce118c711c744a6ad4dd5bffe3ca78c78102e3e Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 6 Jun 2023 09:33:48 +0300 Subject: [PATCH 001/114] add file for mpt-7b-instruct model --- mlc_llm/relax_model/mpt.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 mlc_llm/relax_model/mpt.py diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py new file mode 100644 index 0000000000..e69de29bb2 From f922562889fda6e94fead7ac29c0d296bfe171be Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 6 Jun 2023 09:34:41 +0300 Subject: [PATCH 002/114] update build.py for mpt --- build.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/build.py b/build.py index 315733e561..6a4f705b2a 100644 --- a/build.py +++ b/build.py @@ -11,7 +11,7 @@ import mlc_llm from mlc_llm import utils -from mlc_llm.relax_model import gpt_bigcode, gpt_neox, llama, moss, rwkv +from mlc_llm.relax_model import gpt_bigcode, gpt_neox, llama, moss, rwkv, mpt def _parse_args(): @@ -407,6 +407,8 @@ def main(): mod, params = moss.get_model(ARGS, config) elif ARGS.model_category == "rwkv": mod, params = rwkv.get_model(ARGS, config) + elif ARGS.model_category == "mpt": + mod, params = mpt.get_model(ARGS, config) else: raise ValueError(f"Model {ARGS.model} not supported") mod = mod_transform_before_build(mod, params, ARGS) From 14eb7e17f699ac0ff2465a957440c03423d4540d Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 6 Jun 2023 09:38:46 +0300 Subject: [PATCH 003/114] update utils --- mlc_llm/relax_model/__init__.py | 1 + mlc_llm/utils.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/mlc_llm/relax_model/__init__.py b/mlc_llm/relax_model/__init__.py index 9ee3d0db52..1ee1adbe07 100644 --- a/mlc_llm/relax_model/__init__.py +++ b/mlc_llm/relax_model/__init__.py @@ -1 +1,2 @@ from . import llama +from . import mpt diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py index 7ae411015e..41519aec01 100644 --- a/mlc_llm/utils.py +++ b/mlc_llm/utils.py @@ -51,7 +51,7 @@ class Quantization: ), } -supported_model_types = set(["llama", "gpt_neox", "gpt_bigcode", "moss", "rwkv"]) +supported_model_types = set(["llama", "gpt_neox", "gpt_bigcode", "moss", "rwkv", "mpt"]) def argparse_postproc_common(args: argparse.Namespace) -> None: @@ -78,6 +78,7 @@ def argparse_postproc_common(args: argparse.Namespace) -> None: "gorilla-": ("gorilla", "llama"), "starcoder": ("code_gpt", "gpt_bigcode"), "wizardcoder-": ("code_gpt", "gpt_bigcode"), + "mpt-": ("mpt", "mpt-7b", "mpt-7b-instruct"), } model = args.model.lower() for prefix, (conv_template, model_category) in supported_model_prefix.items(): From 255ec66bc15123eeda68a65ee1b538e9ebdc07c0 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 6 Jun 2023 10:08:14 +0300 Subject: [PATCH 004/114] add MPTConfig --- mlc_llm/relax_model/mpt_config.py | 164 ++++++++++++++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 mlc_llm/relax_model/mpt_config.py diff --git a/mlc_llm/relax_model/mpt_config.py b/mlc_llm/relax_model/mpt_config.py new file mode 100644 index 0000000000..ecd851b440 --- /dev/null +++ b/mlc_llm/relax_model/mpt_config.py @@ -0,0 +1,164 @@ +""" +It is simply copy from https://huggingface.co/mosaicml/mpt-7b-instruct/blob/main/configuration_mpt.py +A HuggingFace-style model configuration. 
+""" +from typing import Dict, Optional, Union +from transformers import PretrainedConfig + +attn_config_defaults: Dict = { + 'attn_type': 'multihead_attention', + 'attn_pdrop': 0.0, + 'attn_impl': 'triton', + 'qk_ln': False, + 'clip_qkv': None, + 'softmax_scale': None, + 'prefix_lm': False, + 'attn_uses_sequence_id': False, + 'alibi': False, + 'alibi_bias_max': 8 +} +init_config_defaults: Dict = { + 'name': 'kaiming_normal_', + 'fan_mode': 'fan_in', + 'init_nonlinearity': 'relu', + 'init_div_is_residual': True, + 'emb_init_std': None, + 'emb_init_uniform_lim': None, + 'init_std': None, + 'init_gain': 0.0 +} + + +class MPTConfig(PretrainedConfig): + model_type = 'mpt' + + def __init__( + self, + d_model: int=2048, + n_heads: int=16, + n_layers: int=24, + expansion_ratio: int=4, + max_seq_len: int=2048, + vocab_size: int=50368, + resid_pdrop: float=0.0, + emb_pdrop: float=0.0, + learned_pos_emb: bool=True, + attn_config: Dict=attn_config_defaults, + init_device: str='cpu', + logit_scale: Optional[Union[float, str]]=None, + no_bias: bool=False, + verbose: int=0, + embedding_fraction: float=1.0, + norm_type: str='low_precision_layernorm', + use_cache: bool=False, + init_config: Dict=init_config_defaults, + **kwargs + ): + """The MPT configuration class. + + Args: + d_model (int): The size of the embedding dimension of the model. + n_heads (int): The number of attention heads. + n_layers (int): The number of layers in the model. + expansion_ratio (int): The ratio of the up/down scale in the MLP. + max_seq_len (int): The maximum sequence length of the model. + vocab_size (int): The size of the vocabulary. + resid_pdrop (float): The dropout probability applied to the attention output before combining with residual. + emb_pdrop (float): The dropout probability for the embedding layer. + learned_pos_emb (bool): Whether to use learned positional embeddings + attn_config (Dict): A dictionary used to configure the model's attention module: + attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention + attn_pdrop (float): The dropout probability for the attention layers. + attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'. + qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer. + clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to + this value. + softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None, + use the default scale of ``1/sqrt(d_keys)``. + prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an + extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix + can attend to one another bi-directionally. Tokens outside the prefix use causal attention. + attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id. + When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates + which sub-sequence each token belongs to. + Defaults to ``False`` meaning any provided `sequence_id` will be ignored. + alibi (bool): Whether to use the alibi bias instead of position embeddings. + alibi_bias_max (int): The maximum value of the alibi bias. + init_device (str): The device to use for parameter initialization. + logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value. 
+ no_bias (bool): Whether to use bias in all layers. + verbose (int): The verbosity level. 0 is silent. + embedding_fraction (float): The fraction to scale the gradients of the embedding layer by. + norm_type (str): choose type of norm to use + multiquery_attention (bool): Whether to use multiquery attention implementation. + use_cache (bool): Whether or not the model should return the last key/values attentions + init_config (Dict): A dictionary used to configure the model initialization: + init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_', + 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or + 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch. + init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True. + emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer. + emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution + used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``. + init_std (float): The standard deviation of the normal distribution used to initialize the model, + if using the baseline_ parameter initialization scheme. + init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes. + fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes. + init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes. + --- + See llmfoundry.models.utils.param_init_fns.py for info on other param init config options + """ + self.d_model = d_model + self.n_heads = n_heads + self.n_layers = n_layers + self.expansion_ratio = expansion_ratio + self.max_seq_len = max_seq_len + self.vocab_size = vocab_size + self.resid_pdrop = resid_pdrop + self.emb_pdrop = emb_pdrop + self.learned_pos_emb = learned_pos_emb + self.attn_config = attn_config + self.init_device = init_device + self.logit_scale = logit_scale + self.no_bias = no_bias + self.verbose = verbose + self.embedding_fraction = embedding_fraction + self.norm_type = norm_type + self.use_cache = use_cache + self.init_config = init_config + if 'name' in kwargs: + del kwargs['name'] + if 'loss_fn' in kwargs: + del kwargs['loss_fn'] + super().__init__(**kwargs) + self._validate_config() + + def _set_config_defaults(self, config, config_defaults): + for (k, v) in config_defaults.items(): + if k not in config: + config[k] = v + return config + + def _validate_config(self): + self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults) + self.init_config = self._set_config_defaults(self.init_config, init_config_defaults) + if self.d_model % self.n_heads != 0: + raise ValueError('d_model must be divisible by n_heads') + if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])): + raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1") + if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']: + raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}") + if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: + raise NotImplementedError('prefix_lm only implemented with torch and 
triton attention.') + if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: + raise NotImplementedError('alibi only implemented with torch and triton attention.') + if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: + raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.') + if self.embedding_fraction > 1 or self.embedding_fraction <= 0: + raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!') + if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model': + raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") + if self.init_config.get('name', None) is None: + raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.") + if not self.learned_pos_emb and (not self.attn_config['alibi']): + raise ValueError(f'Positional information must be provided to the model using either learned_pos_emb or alibi.') From eb09dc299b8a5a4cb04c2c6440c941219b5bfa5d Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 6 Jun 2023 10:48:17 +0300 Subject: [PATCH 005/114] add get_model function like implementation for t5 --- mlc_llm/relax_model/mpt.py | 63 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index e69de29bb2..aca8c11ec4 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -0,0 +1,63 @@ +from typing import Optional, Tuple +import numpy as np + +import torch + +import tvm +from tvm import relax, te + +from .mpt_config import MPTConfig + + +def create_encoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> None: + pass + + +def get_model(args, hf_config): + from transformers import AutoModelForCausalLM # type: ignore[import] + + model_name = args.model + # TODO: download model and use model_path instead of args for from_pretrained + # model_path = args.model_path + dtype = args.quantization.model_dtype + # Recommendation from https://huggingface.co/mosaicml/mpt-7b-instruct + max_seq_len = args.max_seq_len if args.max_seq_len is not None else 4096 # 4096 recommended + + config.update({"max_seq_len": max_seq_len}) + config.update({"max_new_tokens": args.seq_len}) + + if model_name.startswith("mpt-"): + config = MPTConfig(**hf_config) + + bb = relax.BlockBuilder() + create_encoding_func(bb, config) + + mod = bb.get() + + device = tvm.cpu() + # TODO: get default mpt-7b-instruct from HF. Possibly it should be downloaded earlier + # and use model_path instead + hf_model = AutoModelForCausalLM.from_pretrained( + 'mosaicml/mpt-7b-instruct', + config=config, + torch_dtype=torch.bfloat16, + trust_remote_code=True + ) + # Get a list of parameters in advance, then delete the model to save memory + # param_list = [param for _, param in hf_model.named_parameters()] + for name, param in hf_model.named_parameters(): + print(name, param.shape) + # Get a list of parameters in advance, then delete the model to save memory + param_list = [param for _, param in hf_model.named_parameters()] + + for i, param in enumerate(param_list): + # TODO: dtype? what is about mix-precision? 
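        # A note on the dtype TODO above (an illustrative sketch, not taken from the
        # upstream MPT sources): the Hugging Face weights are loaded with
        # torch_dtype=torch.bfloat16, and NumPy has no bfloat16 dtype, so calling
        # param.detach().cpu().numpy() directly raises a TypeError. Upcasting to
        # float32 before the NumPy round-trip is the usual workaround, e.g.:
        #     param_np = param.detach().cpu().to(torch.float32).numpy().astype(dtype)
        #     param_list[i] = tvm.nd.array(param_np, device)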
+ param_list[i] = tvm.nd.array( + param.detach().cpu().numpy().astype(dtype), device + ) + del hf_model + + print(mod) + return mod, param_list + + raise ValueError(f"Unsupported model: {model_name}") \ No newline at end of file From 5b211136398d21ad9771dd550238f672106160f8 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 6 Jun 2023 15:41:36 +0300 Subject: [PATCH 006/114] add MPTMLP, get Linear from Llama for a moment --- mlc_llm/relax_model/mpt.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index aca8c11ec4..db75f8c831 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -5,10 +5,38 @@ import tvm from tvm import relax, te +from tvm.relax.testing import nn +from tvm.script import relax as R from .mpt_config import MPTConfig +# TODO: it is identical to Linear from llama.py +class Linear(nn.Module): + def __init__(self, in_features, out_features, dtype: str, bias=True): + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter( + (out_features, in_features), dtype=dtype, name="linear_weight" + ) + if bias: + self.bias = nn.Parameter((out_features,), dtype=dtype, name="linear_bias") + else: + self.bias = None + + def forward(self, input: relax.Expr) -> relax.Var: + return nn.emit(relax.op.linear(input, self.weight, self.bias)) + + +class MPTMLP(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, dtype: str): + self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype) + self.up_proj = Linear(hidden_size, intermediate_size, dtype=dtype) + + def forward(self, x): + return self.down_proj(relax.op.nn.gelu(self.up_proj(x))) + + def create_encoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> None: pass From a7371b8a004b60168820f96bca5763fe17bef19f Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 6 Jun 2023 21:06:43 +0300 Subject: [PATCH 007/114] add low-precision layer norm, need to correct it further --- mlc_llm/relax_model/mpt.py | 84 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index db75f8c831..dd8a70ee6b 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -11,6 +11,37 @@ from .mpt_config import MPTConfig +def _cast_if_autocast_enabled(tensor): + if torch.is_autocast_enabled(): + if tensor.device.type == 'cuda': + dtype = torch.get_autocast_gpu_dtype() + elif tensor.device.type == 'cpu': + dtype = torch.get_autocast_cpu_dtype() + else: + raise NotImplementedError() + return tensor.to(dtype=dtype) + return tensor + +class LPLayerNorm(torch.nn.LayerNorm): + def __init__(self, normalized_shape, eps=1e-05, dtype=None): + self.weight = nn.Parameter((normalized_shape,), dtype=dtype, name="low_precision_layernorm_weight") + self.bias = nn.Parameter((normalized_shape,), dtype=dtype, name="low_precision_layernorm_bias") + # TODO: check + self.weight = relax.op.ones((normalized_shape,), dtype) + self.bias = relax.op.zeros((normalized_shape,), dtype) + self.variance_epsilon = tvm.tir.const(eps, dtype) + + def forward(self, x): + module_device = x.device + downcast_x = _cast_if_autocast_enabled(x) + downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight + downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias + with torch.autocast(enabled=False, device_type=module_device.type): + return 
torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps) + +NORM_CLASS_REGISTRY = {'low_precision_layernorm': LPLayerNorm} + + # TODO: it is identical to Linear from llama.py class Linear(nn.Module): def __init__(self, in_features, out_features, dtype: str, bias=True): @@ -37,6 +68,59 @@ def forward(self, x): return self.down_proj(relax.op.nn.gelu(self.up_proj(x))) +class MPTBlock(nn.Module): + def __init__(self, config: MPTConfig): + self.hidden_size = config.d_model + self.self_attn = LlamaAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + dtype=config.dtype, + ) + self.mlp = MPTMLP( + hidden_size=self.hidden_size, + intermediate_size=config.expansion_ratio*self.hidden_size, + dtype=config.dtype, + ) + self.input_layernorm = LlamaRMSNorm( + config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = LlamaRMSNorm( + config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: relax.Expr, + cos_cached: relax.Expr, + sin_cached: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_value: Tuple[relax.Expr], + attention_mask: Optional[relax.Expr] = None, + ) -> Tuple[relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, present_key_value = self.self_attn( + hidden_states=hidden_states, + cos_cached=cos_cached, + sin_cached=sin_cached, + past_key_value=past_key_value, + attention_mask=attention_mask, + all_seq_len_shape=all_seq_len_shape, + ) + hidden_states = nn.emit(residual + hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = nn.emit(residual + hidden_states) + + return hidden_states, attn_weights, present_key_value + + def create_encoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> None: pass From 2502c2dcb13d51287503ac4a21c27898b2d8e843 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 6 Jun 2023 21:34:45 +0300 Subject: [PATCH 008/114] MPTBlock was implemented --- mlc_llm/relax_model/mpt.py | 108 +++++++++++++++++++------------------ 1 file changed, 57 insertions(+), 51 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index dd8a70ee6b..c6da0bc3e3 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -8,7 +8,7 @@ from tvm.relax.testing import nn from tvm.script import relax as R -from .mpt_config import MPTConfig +from .mpt_config import MPTConfig, attn_config_defaults def _cast_if_autocast_enabled(tensor): @@ -69,56 +69,62 @@ def forward(self, x): class MPTBlock(nn.Module): - def __init__(self, config: MPTConfig): - self.hidden_size = config.d_model - self.self_attn = LlamaAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - dtype=config.dtype, - ) - self.mlp = MPTMLP( - hidden_size=self.hidden_size, - intermediate_size=config.expansion_ratio*self.hidden_size, - dtype=config.dtype, - ) - self.input_layernorm = LlamaRMSNorm( - config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps - ) - self.post_attention_layernorm = LlamaRMSNorm( - config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps - ) - - def forward( - self, - hidden_states: relax.Expr, - cos_cached: relax.Expr, - sin_cached: relax.Expr, - all_seq_len_shape: relax.Expr, - past_key_value: 
Tuple[relax.Expr], - attention_mask: Optional[relax.Expr] = None, - ) -> Tuple[relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]: - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, present_key_value = self.self_attn( - hidden_states=hidden_states, - cos_cached=cos_cached, - sin_cached=sin_cached, - past_key_value=past_key_value, - attention_mask=attention_mask, - all_seq_len_shape=all_seq_len_shape, - ) - hidden_states = nn.emit(residual + hidden_states) - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = nn.emit(residual + hidden_states) - - return hidden_states, attn_weights, present_key_value + def __init__(self, config: MPTConfig): + # Get values from config or defaults + attn_config = config.attn_config if config.attn_config is not None else attn_config_defaults + norm_type = config.norm_type if config.norm_type is not None else 'low_precision_layernorm' + verbose = config.verbose if config.verbose is not None else 0 + # Define layer norm and attention classes + norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] + attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] + + self.hidden_size = config.d_model + # Init layers + self.self_attn = attn_class( + attn_impl=attn_config['attn_impl'], + clip_qkv=attn_config['clip_qkv'], + qk_ln=attn_config['qk_ln'], + softmax_scale=attn_config['softmax_scale'], + attn_pdrop=attn_config['attn_pdrop'], + d_model=self.hidden_size, + n_heads=config.n_heads, + verbose=verbose, + ) + self.mlp = MPTMLP( + hidden_size=self.hidden_size, + intermediate_size=config.expansion_ratio*self.hidden_size, + dtype=config.dtype, + ) + self.input_layernorm = norm_class(self.hidden_size) + self.post_attention_layernorm = norm_class(self.hidden_size) + + def forward( + self, + hidden_states: relax.Expr, + past_key_value: Tuple[relax.Expr], + attn_bias: Optional[relax.Expr] = None, + attention_mask: Optional[relax.Expr] = None, + is_causal: bool=True, + ) -> Tuple[relax.Expr, relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + (hidden_states, attn_weights, present_key_value) = self.self_attn( + hidden_states, + past_key_value=past_key_value, + attn_bias=attn_bias, + attention_mask=attention_mask, + is_causal=is_causal + ) + residual = nn.emit(residual + hidden_states) + + # Fully Connected + hidden_states = self.post_attention_layernorm(residual) + hidden_states = self.mlp(hidden_states) + hidden_states = nn.emit(residual + hidden_states) + + return (hidden_states, attn_weights, present_key_value) def create_encoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> None: From cabd1a81186cd755ea9544e50c36fd713a242084 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 7 Jun 2023 10:46:35 +0300 Subject: [PATCH 009/114] update MPTConfig by dtype --- mlc_llm/relax_model/mpt_config.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mlc_llm/relax_model/mpt_config.py b/mlc_llm/relax_model/mpt_config.py index ecd851b440..18a5d7e6d5 100644 --- a/mlc_llm/relax_model/mpt_config.py +++ b/mlc_llm/relax_model/mpt_config.py @@ -1,5 +1,6 @@ """ -It is simply copy from https://huggingface.co/mosaicml/mpt-7b-instruct/blob/main/configuration_mpt.py +It is practicaly copy from https://huggingface.co/mosaicml/mpt-7b-instruct/blob/main/configuration_mpt.py +but 
`dtype` field is added A HuggingFace-style model configuration. """ from typing import Dict, Optional, Union @@ -52,6 +53,7 @@ def __init__( norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, + dtype=None, **kwargs ): """The MPT configuration class. @@ -126,6 +128,7 @@ def __init__( self.norm_type = norm_type self.use_cache = use_cache self.init_config = init_config + self.dtype = dtype if 'name' in kwargs: del kwargs['name'] if 'loss_fn' in kwargs: From adb4e39c602f4b57137698204b0468f59e2cfd16 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 7 Jun 2023 10:47:15 +0300 Subject: [PATCH 010/114] draft for attentions layers of MPT. some updates --- mlc_llm/relax_model/mpt.py | 281 ++++++++++++++++++++++++++++++++++--- 1 file changed, 260 insertions(+), 21 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index c6da0bc3e3..ae3445fa00 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -1,3 +1,6 @@ +import math # TODO: replace +from einops import rearrange # TODO: replace +import warnings from typing import Optional, Tuple import numpy as np @@ -9,7 +12,12 @@ from tvm.script import relax as R from .mpt_config import MPTConfig, attn_config_defaults - +from .modules import ( + Embedding, + LayerNorm, + Linear, + ModuleList, +) def _cast_if_autocast_enabled(tensor): if torch.is_autocast_enabled(): @@ -42,21 +50,256 @@ def forward(self, x): NORM_CLASS_REGISTRY = {'low_precision_layernorm': LPLayerNorm} -# TODO: it is identical to Linear from llama.py -class Linear(nn.Module): - def __init__(self, in_features, out_features, dtype: str, bias=True): - self.in_features = in_features - self.out_features = out_features - self.weight = nn.Parameter( - (out_features, in_features), dtype=dtype, name="linear_weight" - ) - if bias: - self.bias = nn.Parameter((out_features,), dtype=dtype, name="linear_bias") +def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool): + if original_is_causal and num_query_tokens != num_key_tokens: + if num_query_tokens != 1: + raise NotImplementedError('MPT does not support query and key with different number of tokens, unless number of query tokens is 1.') + else: + return False + return original_is_causal + + +def scaled_multihead_dot_product_attention( + query, + key, + value, + n_heads, + past_key_value=None, + softmax_scale=None, + attn_bias=None, + key_padding_mask=None, + is_causal=False, + needs_weights=False, + multiquery=False +): + q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads) + kv_n_heads = 1 if multiquery else n_heads + k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads) + v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads) + if past_key_value is not None: + if len(past_key_value) != 0: + k = torch.cat([past_key_value[0], k], dim=3) + v = torch.cat([past_key_value[1], v], dim=2) + past_key_value = (k, v) + (b, _, s_q, d) = q.shape + s_k = k.size(-1) + if softmax_scale is None: + softmax_scale = 1 / math.sqrt(d) + attn_weight = q.matmul(k) * softmax_scale + if attn_bias is not None: + _s_q = max(0, attn_bias.size(2) - s_q) + _s_k = max(0, attn_bias.size(3) - s_k) + attn_bias = attn_bias[:, :, _s_q:, _s_k:] + if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q): + raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.') + attn_weight = attn_weight + attn_bias + min_val = 
torch.finfo(q.dtype).min + if key_padding_mask is not None: + if attn_bias is not None: + warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.') + attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val) + if is_causal and (not q.size(2) == 1): + s = max(s_q, s_k) + causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16) + causal_mask = causal_mask.tril() + causal_mask = causal_mask.to(torch.bool) + causal_mask = ~causal_mask + causal_mask = causal_mask[-s_q:, -s_k:] + attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val) + attn_weight = torch.softmax(attn_weight, dim=-1) + out = attn_weight.matmul(v) + out = rearrange(out, 'b h s d -> b s (h d)') + if needs_weights: + return (out, attn_weight, past_key_value) + return (out, None, past_key_value) + + +def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]): + for tensor in tensors: + if tensor.dtype not in valid_dtypes: + raise TypeError(f'tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.') + if not tensor.is_cuda: + raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).') + + +def flash_attn_fn( + query, + key, + value, + n_heads, + past_key_value=None, + softmax_scale=None, + attn_bias=None, + key_padding_mask=None, + is_causal=False, + needs_weights=False, + multiquery=False +): + try: + from flash_attn import bert_padding, flash_attn_interface + except: + raise RuntimeError('Please install flash-attn==1.0.3.post0') + check_valid_inputs(query, key, value) + if past_key_value is not None: + if len(past_key_value) != 0: + key = torch.cat([past_key_value[0], key], dim=1) + value = torch.cat([past_key_value[1], value], dim=1) + past_key_value = (key, value) + if attn_bias is not None: + _s_q = max(0, attn_bias.size(2) - query.size(1)) + _s_k = max(0, attn_bias.size(3) - key.size(1)) + attn_bias = attn_bias[:, :, _s_q:, _s_k:] + if attn_bias is not None: + raise NotImplementedError(f'attn_bias not implemented for flash attn.') + (batch_size, seqlen) = query.shape[:2] + if key_padding_mask is None: + key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) + query_padding_mask = key_padding_mask[:, -query.size(1):] + (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask) + query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads) + (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask) + key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads) + (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask) + value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads) + if multiquery: + key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1)) + value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1)) + reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) + output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, 0.0, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights) + output = 
bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen) + return (output, None, past_key_value) + + +def triton_flash_attn_fn( + query, + key, + value, + n_heads, + past_key_value=None, + softmax_scale=None, + attn_bias=None, + key_padding_mask=None, + is_causal=False, + needs_weights=False, + multiquery=False): + try: + from .flash_attn_triton import flash_attn_func + except: + _installed = False + if version.parse(torch.__version__) < version.parse('2.0.0'): + _installed = True + try: + from flash_attn.flash_attn_triton import flash_attn_func + except: + _installed = False + if not _installed: + raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.') + check_valid_inputs(query, key, value) + if past_key_value is not None: + if len(past_key_value) != 0: + key = torch.cat([past_key_value[0], key], dim=1) + value = torch.cat([past_key_value[1], value], dim=1) + past_key_value = (key, value) + if attn_bias is not None: + _s_q = max(0, attn_bias.size(2) - query.size(1)) + _s_k = max(0, attn_bias.size(3) - key.size(1)) + attn_bias = attn_bias[:, :, _s_q:, _s_k:] + if needs_weights: + raise NotImplementedError(f'attn_impl: triton cannot return attn weights.') + if key_padding_mask is not None: + warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.') + (b_size, s_k) = key_padding_mask.shape[:2] + if attn_bias is None: + attn_bias = query.new_zeros(b_size, 1, 1, s_k) + attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min) + query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads) + key = rearrange(key, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads) + value = rearrange(value, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads) + if multiquery: + key = key.expand(*key.shape[:2], n_heads, key.size(-1)) + value = value.expand(*value.shape[:2], n_heads, value.size(-1)) + reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) + attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale) + output = attn_output.view(*attn_output.shape[:2], -1) + return (output, None, past_key_value) + + +class MultiheadAttention(nn.Module): + """Multi-head self attention. + Using torch or triton attention implemetation enables user to also use + additive bias. 
+ """ + + def __init__( + self, + d_model: int, + n_heads: int, + attn_impl: str='triton', + clip_qkv: Optional[float]=None, + qk_ln: bool=False, + softmax_scale: Optional[float]=None, + low_precision_layernorm: bool=False, + device: Optional[str]=None + ): + # Init fields + self.d_model = d_model + self.n_heads = n_heads + self.attn_impl = attn_impl + self.clip_qkv = clip_qkv + self.qk_ln = qk_ln + self.softmax_scale = softmax_scale + + if self.softmax_scale is None: + self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) + self.Wqkv = Linear(self.d_model, 3 * self.d_model, device=device) + fuse_splits = (d_model, 2 * d_model) + self.Wqkv._fused = (0, fuse_splits) + if self.qk_ln: + layernorm_class = LPLayerNorm if low_precision_layernorm else LayerNorm + self.q_ln = layernorm_class(self.d_model, device=device) + self.k_ln = layernorm_class(self.d_model, device=device) + if self.attn_impl == 'flash': + self.attn_fn = flash_attn_fn + elif self.attn_impl == 'triton': + # While `attn_impl: triton` can be faster than `attn_impl: flash` it uses more memory. + # When training larger models this can trigger alloc retries which hurts performance. + # If encountered, we recommend using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`. + self.attn_fn = triton_flash_attn_fn + elif self.attn_impl == 'torch': + # Using `attn_impl: torch`. If your model does not use `alibi` or `prefix_lm` we recommend using `attn_impl: flash` + # otherwise we recommend using `attn_impl: triton`. + self.attn_fn = scaled_multihead_dot_product_attention else: - self.bias = None + raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') + self.out_proj = Linear(self.d_model, self.d_model, device=device) + # TODO: Does field _is_residual exist? 
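        # Regarding the TODO above: `_is_residual` is not a torch.nn field; it is a plain
        # attribute that llm-foundry's parameter-initialization code checks (see
        # `init_div_is_residual` in the MPTConfig docstring) to down-scale residual
        # projection weights at init time. Since this port loads pretrained weights and
        # never runs that init path, the flag should be inert here.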
+ self.out_proj._is_residual = True + + def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False): + qkv = self.Wqkv(x) + if self.clip_qkv: + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) + (query, key, value) = qkv.chunk(3, dim=2) + key_padding_mask = attention_mask + if self.qk_ln: + dtype = query.dtype + query = self.q_ln(query).to(dtype) + key = self.k_ln(key).to(dtype) + (context, attn_weights, past_key_value) = self.attn_fn( + query, + key, + value, + self.n_heads, + past_key_value=past_key_value, + softmax_scale=self.softmax_scale, + attn_bias=attn_bias, + key_padding_mask=key_padding_mask, + is_causal=is_causal, + needs_weights=needs_weights + ) + return (self.out_proj(context), attn_weights, past_key_value) - def forward(self, input: relax.Expr) -> relax.Var: - return nn.emit(relax.op.linear(input, self.weight, self.bias)) +ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention} class MPTMLP(nn.Module): @@ -73,7 +316,6 @@ def __init__(self, config: MPTConfig): # Get values from config or defaults attn_config = config.attn_config if config.attn_config is not None else attn_config_defaults norm_type = config.norm_type if config.norm_type is not None else 'low_precision_layernorm' - verbose = config.verbose if config.verbose is not None else 0 # Define layer norm and attention classes norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] @@ -81,14 +323,13 @@ def __init__(self, config: MPTConfig): self.hidden_size = config.d_model # Init layers self.self_attn = attn_class( + d_model=self.hidden_size, + n_heads=config.n_heads, attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], - d_model=self.hidden_size, - n_heads=config.n_heads, - verbose=verbose, ) self.mlp = MPTMLP( hidden_size=self.hidden_size, @@ -145,7 +386,7 @@ def get_model(args, hf_config): config.update({"max_new_tokens": args.seq_len}) if model_name.startswith("mpt-"): - config = MPTConfig(**hf_config) + config = MPTConfig(**hf_config, dtype=dtype) bb = relax.BlockBuilder() create_encoding_func(bb, config) @@ -161,8 +402,6 @@ def get_model(args, hf_config): torch_dtype=torch.bfloat16, trust_remote_code=True ) - # Get a list of parameters in advance, then delete the model to save memory - # param_list = [param for _, param in hf_model.named_parameters()] for name, param in hf_model.named_parameters(): print(name, param.shape) # Get a list of parameters in advance, then delete the model to save memory From bd7edfa6342745349ca797554e5c33c36479c6de Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 7 Jun 2023 14:47:09 +0300 Subject: [PATCH 011/114] _reset_is_causal and attn_bias_shape methods were added. 
__init__ of MPTModel was refactored --- mlc_llm/relax_model/mpt.py | 280 +++++++++++++++++++++++++++++++++++-- 1 file changed, 267 insertions(+), 13 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index ae3445fa00..7c24cf4384 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -4,8 +4,6 @@ from typing import Optional, Tuple import numpy as np -import torch - import tvm from tvm import relax, te from tvm.relax.testing import nn @@ -30,13 +28,12 @@ def _cast_if_autocast_enabled(tensor): return tensor.to(dtype=dtype) return tensor -class LPLayerNorm(torch.nn.LayerNorm): +# Low-precision layer norm for mpt-7b-instruct, where are no biases expected +class LPLayerNormWOBias(nn.Module): def __init__(self, normalized_shape, eps=1e-05, dtype=None): self.weight = nn.Parameter((normalized_shape,), dtype=dtype, name="low_precision_layernorm_weight") - self.bias = nn.Parameter((normalized_shape,), dtype=dtype, name="low_precision_layernorm_bias") # TODO: check self.weight = relax.op.ones((normalized_shape,), dtype) - self.bias = relax.op.zeros((normalized_shape,), dtype) self.variance_epsilon = tvm.tir.const(eps, dtype) def forward(self, x): @@ -47,16 +44,16 @@ def forward(self, x): with torch.autocast(enabled=False, device_type=module_device.type): return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps) -NORM_CLASS_REGISTRY = {'low_precision_layernorm': LPLayerNorm} +NORM_CLASS_REGISTRY = {'low_precision_layernorm': LPLayerNormWOBias} def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool): - if original_is_causal and num_query_tokens != num_key_tokens: - if num_query_tokens != 1: - raise NotImplementedError('MPT does not support query and key with different number of tokens, unless number of query tokens is 1.') - else: - return False - return original_is_causal + if original_is_causal and num_query_tokens != num_key_tokens: + if num_query_tokens != 1: + raise NotImplementedError('MPT does not support query and key with different number of tokens, unless number of query tokens is 1.') + else: + return False + return original_is_causal def scaled_multihead_dot_product_attention( @@ -255,7 +252,7 @@ def __init__( fuse_splits = (d_model, 2 * d_model) self.Wqkv._fused = (0, fuse_splits) if self.qk_ln: - layernorm_class = LPLayerNorm if low_precision_layernorm else LayerNorm + layernorm_class = LPLayerNormWOBias if low_precision_layernorm else LayerNorm self.q_ln = layernorm_class(self.d_model, device=device) self.k_ln = layernorm_class(self.d_model, device=device) if self.attn_impl == 'flash': @@ -302,6 +299,21 @@ def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, i ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention} +def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id): + if attn_impl == 'flash': + return None + elif attn_impl in ['torch', 'triton']: + if alibi: + if (prefix_lm or not causal) or use_sequence_id: + return (1, n_heads, seq_len, seq_len) + return (1, n_heads, 1, seq_len) + elif prefix_lm or use_sequence_id: + return (1, 1, seq_len, seq_len) + return None + else: + raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') + + class MPTMLP(nn.Module): def __init__(self, hidden_size: int, intermediate_size: int, dtype: str): self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype) @@ -368,6 +380,248 @@ def forward( return (hidden_states, 
attn_weights, present_key_value) +class MPTModel(nn.Module): + def __init__(self, config: MPTConfig): + config._validate_config() + # Init fields from config + self.attn_impl = config.attn_config['attn_impl'] + self.prefix_lm = config.attn_config['prefix_lm'] + self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id'] + self.alibi = config.attn_config['alibi'] + self.alibi_bias_max = config.attn_config['alibi_bias_max'] + self.is_causal = not self.prefix_lm + + self._attn_bias_initialized = False + self.attn_bias = None + self.attn_bias_shape = attn_bias_shape( + self.attn_impl, + config.n_heads, + config.max_seq_len, + self.alibi, + prefix_lm=self.prefix_lm, + causal=self.is_causal, + use_sequence_id=self.attn_uses_sequence_id + ) + + # Define layer norm type + if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys(): + norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys()) + raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).') + norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()] + + # Init layers + self.wte = Embedding(config.vocab_size, config.d_model, dtype=config.dtype) + if not self.alibi: + self.wpe = Embedding(config.max_seq_len, config.d_model, dtype=config.dtype) + self.blocks = ModuleList([MPTBlock(config) for _ in range(config.n_layers)]) + self.norm_f = norm_class(config.d_model, dtype=config.dtype) + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, value): + self.wte = value + + def _attn_bias(self, device, dtype, attention_mask: Optional[relax.Expr]=None, prefix_mask: Optional[relax.Expr]=None, sequence_id: Optional[relax.Expr]=None): + if not self._attn_bias_initialized: + if self.attn_bias_shape: + self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype) + self.attn_bias = build_attn_bias(self.attn_impl, self.attn_bias, self.config.n_heads, self.config.max_seq_len, causal=self.is_causal, alibi=self.alibi, alibi_bias_max=self.alibi_bias_max) + self._attn_bias_initialized = True + if self.attn_impl == 'flash': + return (self.attn_bias, attention_mask) + if self.attn_bias is not None: + self.attn_bias = self.attn_bias.to(dtype=dtype, device=device) + attn_bias = self.attn_bias + if self.prefix_lm: + assert isinstance(attn_bias, torch.Tensor) + assert isinstance(prefix_mask, torch.Tensor) + attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask) + if self.attn_uses_sequence_id and sequence_id is not None: + assert isinstance(attn_bias, torch.Tensor) + attn_bias = self._apply_sequence_id(attn_bias, sequence_id) + if attention_mask is not None: + s_k = attention_mask.shape[-1] + if attn_bias is None: + attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype) + else: + _s_k = max(0, attn_bias.size(-1) - s_k) + attn_bias = attn_bias[:, :, :, _s_k:] + if prefix_mask is not None and attention_mask.shape != prefix_mask.shape: + raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.') + min_val = torch.finfo(attn_bias.dtype).min + attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val) + return (attn_bias, None) + + def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor): + (s_k, s_q) = attn_bias.shape[-2:] + if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len: + raise ValueError('attn_bias does not match the expected shape. 
' + f'The last two dimensions should both be {self.config.max_length} ' + f'but are {s_k} and {s_q}.') + seq_len = prefix_mask.shape[-1] + if seq_len > self.config.max_seq_len: + raise ValueError(f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}') + attn_bias = attn_bias[..., :seq_len, :seq_len] + causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)).view(1, 1, seq_len, seq_len) + prefix = prefix_mask.view(-1, 1, 1, seq_len) + cannot_attend = ~torch.logical_or(causal, prefix.bool()) + min_val = torch.finfo(attn_bias.dtype).min + attn_bias = attn_bias.masked_fill(cannot_attend, min_val) + return attn_bias + + def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor): + seq_len = sequence_id.shape[-1] + if seq_len > self.config.max_seq_len: + raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}') + attn_bias = attn_bias[..., :seq_len, :seq_len] + cannot_attend = torch.logical_not(torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))).unsqueeze(1) + min_val = torch.finfo(attn_bias.dtype).min + attn_bias = attn_bias.masked_fill(cannot_attend, min_val) + return attn_bias + + def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None): + return_dict = return_dict if return_dict is not None else self.config.return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache + if attention_mask is not None: + attention_mask = attention_mask.bool() + if prefix_mask is not None: + prefix_mask = prefix_mask.bool() + if not return_dict: + raise NotImplementedError('return_dict False is not implemented yet for MPT') + if output_attentions: + if self.attn_impl != 'torch': + raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.') + if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training: + raise NotImplementedError('MPT does not support training with left padding.') + if self.prefix_lm and prefix_mask is None: + raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.') + if self.training: + if self.attn_uses_sequence_id and sequence_id is None: + raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.') + elif self.attn_uses_sequence_id is False and sequence_id is not None: + warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. 
If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.') + S = input_ids.size(1) + assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' + tok_emb = self.wte(input_ids) + if self.alibi: + x = tok_emb + else: + past_position = 0 + if past_key_values is not None: + if len(past_key_values) != self.config.n_layers: + raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).') + past_position = past_key_values[0][0].size(1) + if self.attn_impl == 'torch': + past_position = past_key_values[0][0].size(3) + if S + past_position > self.config.max_seq_len: + raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.') + pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0) + if attention_mask is not None: + pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0) + pos_emb = self.wpe(pos) + x = tok_emb + pos_emb + (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id) + if use_cache and past_key_values is None: + past_key_values = [() for _ in range(self.config.n_layers)] + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + for (b_idx, block) in enumerate(self.blocks): + if output_hidden_states: + assert all_hidden_states is not None + all_hidden_states = all_hidden_states + (x,) + past_key_value = past_key_values[b_idx] if past_key_values is not None else None + (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal) + if past_key_values is not None: + past_key_values[b_idx] = past_key_value + if output_attentions: + assert all_self_attns is not None + all_self_attns = all_self_attns + (attn_weights,) + x = self.norm_f(x) + if output_hidden_states: + assert all_hidden_states is not None + all_hidden_states = all_hidden_states + (x,) + return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns) + + def fsdp_wrap_fn(self, module): + return isinstance(module, MPTBlock) + + def activation_checkpointing_fn(self, module): + return isinstance(module, MPTBlock) + + +class MPTForCausalLM(nn.Module): + def __init__(self, config: MPTConfig): + if not config.tie_word_embeddings: + raise ValueError('MPTForCausalLM only supports tied word embeddings') + self.transformer = MPTModel(config) + + def get_input_embeddings(self): + return self.transformer.wte + + def set_input_embeddings(self, value): + self.transformer.wte = value + + def get_output_embeddings(self): + return self.transformer.wte + + def set_output_embeddings(self, new_embeddings): + self.transformer.wte = new_embeddings + + def set_decoder(self, decoder): + self.transformer = decoder + + def get_decoder(self): + return self.transformer + + def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: 
Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None): + return_dict = return_dict if return_dict is not None else self.config.return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache + outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache) + logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight) + loss = None + if labels is not None: + labels = torch.roll(labels, shifts=-1) + labels[:, -1] = -100 + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)) + return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions) + + def fsdp_wrap_fn(self, module): + return isinstance(module, MPTBlock) + + def activation_checkpointing_fn(self, module): + return isinstance(module, MPTBlock) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): + if inputs_embeds is not None: + raise NotImplementedError('inputs_embeds is not implemented for MPT yet') + attention_mask = kwargs['attention_mask'].bool() + if attention_mask[:, -1].sum() != attention_mask.shape[0]: + raise NotImplementedError('MPT does not support generation with right padding.') + if self.transformer.attn_uses_sequence_id and self.training: + sequence_id = torch.zeros_like(input_ids[:1]) + else: + sequence_id = None + if past_key_values is not None: + input_ids = input_ids[:, -1].unsqueeze(-1) + if self.transformer.prefix_lm: + prefix_mask = torch.ones_like(attention_mask) + if kwargs.get('use_cache') == False: + raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.') + else: + prefix_mask = None + return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True)} + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + """Used by HuggingFace generate when using beam search with kv-caching. + See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133 + for an example in transformers. 
+ """ + reordered_past = [] + for layer_past in past_key_values: + reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))] + return reordered_past + + def create_encoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> None: pass From feccf591c69e596dbe2ed1c0e9e09e750f5b5c53 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 8 Jun 2023 10:10:12 +0300 Subject: [PATCH 012/114] MPTForCausalLM was implemented on Relax, some TODOs were added --- mlc_llm/relax_model/mpt.py | 58 ++++++++++++++++++++++++++++---------- 1 file changed, 43 insertions(+), 15 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 7c24cf4384..742987e4f1 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -1,7 +1,7 @@ -import math # TODO: replace +import math from einops import rearrange # TODO: replace import warnings -from typing import Optional, Tuple +from typing import Optional, Tuple, List import numpy as np import tvm @@ -553,6 +553,7 @@ def __init__(self, config: MPTConfig): if not config.tie_word_embeddings: raise ValueError('MPTForCausalLM only supports tied word embeddings') self.transformer = MPTModel(config) + self.dtype = config.dtype def get_input_embeddings(self): return self.transformer.wte @@ -572,17 +573,34 @@ def set_decoder(self, decoder): def get_decoder(self): return self.transformer - def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None): + def forward( + self, + input_ids: relax.Expr, + past_key_values: Optional[List[Tuple[relax.Expr]]]=None, + attention_mask: Optional[relax.Expr]=None, + prefix_mask: Optional[relax.Expr]=None, + sequence_id: Optional[relax.Expr]=None, + return_dict: Optional[bool]=None, + output_attentions: Optional[bool]=None, + output_hidden_states: Optional[bool]=None, + use_cache: Optional[bool]=None + ): return_dict = return_dict if return_dict is not None else self.config.return_dict use_cache = use_cache if use_cache is not None else self.config.use_cache - outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache) - logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight) - loss = None - if labels is not None: - labels = torch.roll(labels, shifts=-1) - labels[:, -1] = -100 - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)) - return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions) + outputs = self.transformer( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + prefix_mask=prefix_mask, + sequence_id=sequence_id, + return_dict=return_dict, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + use_cache=use_cache + ) + logits = nn.emit(relax.op.matmul(outputs.last_hidden_state, 
self.transformer.wte.weight)) + + return logits, outputs.past_key_values def fsdp_wrap_fn(self, module): return isinstance(module, MPTBlock) @@ -597,18 +615,27 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ if attention_mask[:, -1].sum() != attention_mask.shape[0]: raise NotImplementedError('MPT does not support generation with right padding.') if self.transformer.attn_uses_sequence_id and self.training: - sequence_id = torch.zeros_like(input_ids[:1]) + # TODO: [:1] in Relax? + sequence_id = nn.emit(relax.op.zeros_like(input_ids[:1])) else: sequence_id = None if past_key_values is not None: + # TODO: Relax implementation? input_ids = input_ids[:, -1].unsqueeze(-1) if self.transformer.prefix_lm: - prefix_mask = torch.ones_like(attention_mask) + prefix_mask = nn.emit(relax.op.ones_like(attention_mask, self.dtype)) if kwargs.get('use_cache') == False: raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.') else: prefix_mask = None - return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True)} + return { + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'prefix_mask': prefix_mask, + 'sequence_id': sequence_id, + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache', True) + } @staticmethod def _reorder_cache(past_key_values, beam_idx): @@ -618,7 +645,8 @@ def _reorder_cache(past_key_values, beam_idx): """ reordered_past = [] for layer_past in past_key_values: - reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))] + # TODO: Relax implementation? + reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))] return reordered_past From b9e5adbe26ade11a3466a3e5e0ea0b3c31d877e0 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 8 Jun 2023 12:56:50 +0300 Subject: [PATCH 013/114] rearrange from einops was replaced by relax ops --- mlc_llm/relax_model/mpt.py | 93 ++++++++++++++++++++++++++++++++------ 1 file changed, 78 insertions(+), 15 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 742987e4f1..b03390927f 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -1,5 +1,4 @@ import math -from einops import rearrange # TODO: replace import warnings from typing import Optional, Tuple, List import numpy as np @@ -56,11 +55,36 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_cau return original_is_causal +def reshape_and_permute(hidden_states: relax.Expr, n_heads: int, d_model: int): + ''' + Transform shape of input: b s (h d) -> b h d s + ''' + batch_size, seqlen, _ = hidden_states.struct_info.shape + inter = nn.emit(relax.op.reshape( + hidden_states, + (batch_size, seqlen, n_heads, d_model), + )) + return nn.emit(relax.op.permute_dims(inter, [0, 2, 1, 3])) + + +def reverse_reshape_and_permute(hidden_states: relax.Expr): + ''' + Transform shape of input: b h s d -> b s (h d) + ''' + batch_size, n_heads, seqlen, d_model = hidden_states.struct_info.shape + inter = nn.emit(relax.op.permute_dims(hidden_states, [0, 2, 1, 3])) + return nn.emit(relax.op.reshape( + inter, + (batch_size, seqlen, n_heads*d_model), + )) + + def scaled_multihead_dot_product_attention( query, key, value, n_heads, + d_model, past_key_value=None, softmax_scale=None, attn_bias=None, @@ -69,10 +93,10 @@ def 
scaled_multihead_dot_product_attention( needs_weights=False, multiquery=False ): - q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads) + q = reshape_and_permute(query, n_heads, d_model) kv_n_heads = 1 if multiquery else n_heads - k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads) - v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads) + k = reshape_and_permute(key, kv_n_heads, d_model) + v = reshape_and_permute(value, kv_n_heads, d_model) if past_key_value is not None: if len(past_key_value) != 0: k = torch.cat([past_key_value[0], k], dim=3) @@ -105,7 +129,7 @@ def scaled_multihead_dot_product_attention( attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val) attn_weight = torch.softmax(attn_weight, dim=-1) out = attn_weight.matmul(v) - out = rearrange(out, 'b h s d -> b s (h d)') + out = reverse_reshape_and_permute(out) if needs_weights: return (out, attn_weight, past_key_value) return (out, None, past_key_value) @@ -124,6 +148,7 @@ def flash_attn_fn( key, value, n_heads, + d_model, past_key_value=None, softmax_scale=None, attn_bias=None, @@ -150,20 +175,40 @@ def flash_attn_fn( raise NotImplementedError(f'attn_bias not implemented for flash attn.') (batch_size, seqlen) = query.shape[:2] if key_padding_mask is None: - key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) + key_shape = key.struct_info.shape[:2] + key_padding_mask = nn.emit(relax.op.ones(Tuple(*key_shape, 1), dtype="bool")) query_padding_mask = key_padding_mask[:, -query.size(1):] (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask) - query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads) + + nnz, _, _ = query_unpad.struct_info.shape + query_unpad = nn.emit(relax.op.reshape( + query_unpad, + (nnz, n_heads, d_model), + )) # (nnz, (h d)) -> (nnz, h, d) + + nnz, _, _ = key_unpad.struct_info.shape + kv_n_heads = 1 if multiquery else n_heads (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask) - key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads) + key_unpad = nn.emit(relax.op.reshape( + key_unpad, + (nnz, kv_n_heads, d_model), + )) # (nnz, (h d)) -> (nnz, h, d) (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask) - value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads) + value_unpad = nn.emit(relax.op.reshape( + value_unpad, + (nnz, kv_n_heads, d_model), + )) # (nnz, (h d)) -> (nnz, h, d) + if multiquery: key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1)) value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1)) - reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) + reset_is_causal = _reset_is_causal(query.struct_info.shape[1], key.struct_info.shape[1], is_causal) output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, 0.0, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights) - output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen) + output_unpad = nn.emit(relax.op.reshape( + output_unpad, + (nnz, n_heads*d_model), + )) # (nnz, h, d)) -> (nnz, (h d)) + output = bert_padding.pad_input(output_unpad, indices_q, batch_size, seqlen) return (output, None, past_key_value) @@ -172,6 +217,7 @@ def triton_flash_attn_fn( 
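These hunks replace einops' rearrange with explicit reshape + permute_dims. A small NumPy check (mine, not part of the patch) that reshape followed by a transpose reproduces the 'b s (h d) -> b h s d' layout used for queries/values, and the 'b s (h d) -> b h d s' layout used for keys, which the series later handles by passing a different permutation order to reshape_and_permute:

import numpy as np

b, s, h, d = 2, 5, 4, 8
x = np.random.randn(b, s, h * d)

q_layout = x.reshape(b, s, h, d).transpose(0, 2, 1, 3)   # b h s d (queries/values)
k_layout = x.reshape(b, s, h, d).transpose(0, 2, 3, 1)   # b h d s (keys, ready for q @ k)

# Relax equivalents of the two steps:
#   inter = relax.op.reshape(x, (b, s, h, d))
#   q = relax.op.permute_dims(inter, [0, 2, 1, 3])
#   k = relax.op.permute_dims(inter, [0, 2, 3, 1])
assert q_layout.shape == (b, h, s, d) and k_layout.shape == (b, h, d, s)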
key, value, n_heads, + d_model, past_key_value=None, softmax_scale=None, attn_bias=None, @@ -209,13 +255,27 @@ def triton_flash_attn_fn( if attn_bias is None: attn_bias = query.new_zeros(b_size, 1, 1, s_k) attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min) - query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads) - key = rearrange(key, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads) - value = rearrange(value, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads) + + batch_size, seq_len, _ = query.struct_info.shape + query = nn.emit(relax.op.reshape( + query, + (batch_size, seq_len, n_heads, d_model), + )) # b s (h d) -> b s h d + + batch_size, seq_len, _ = key.struct_info.shape + kv_n_heads = 1 if multiquery else n_heads + key = nn.emit(relax.op.reshape( + key, + (batch_size, seq_len, kv_n_heads, d_model), + )) # b s (h d) -> b s h d + value = nn.emit(relax.op.reshape( + value, + (batch_size, seq_len, kv_n_heads, d_model), + )) # b s (h d) -> b s h d if multiquery: key = key.expand(*key.shape[:2], n_heads, key.size(-1)) value = value.expand(*value.shape[:2], n_heads, value.size(-1)) - reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) + reset_is_causal = _reset_is_causal(query.struct_info.shape[1], key.struct_info.shape[1], is_causal) attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale) output = attn_output.view(*attn_output.shape[:2], -1) return (output, None, past_key_value) @@ -287,6 +347,7 @@ def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, i key, value, self.n_heads, + self.d_model, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, @@ -656,6 +717,7 @@ def create_encoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> None: def get_model(args, hf_config): from transformers import AutoModelForCausalLM # type: ignore[import] + import torch # type: ignore[import] model_name = args.model # TODO: download model and use model_path instead of args for from_pretrained @@ -678,6 +740,7 @@ def get_model(args, hf_config): device = tvm.cpu() # TODO: get default mpt-7b-instruct from HF. 
Possibly it should be downloaded earlier # and use model_path instead + hf_model = AutoModelForCausalLM.from_pretrained( 'mosaicml/mpt-7b-instruct', config=config, From f86c101b6e84fcd6963f4d0b5a80db73315173ee Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 8 Jun 2023 15:35:04 +0300 Subject: [PATCH 014/114] reimplement scaled_multihead_dot_product_attention by relax --- mlc_llm/relax_model/mpt.py | 39 +++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index b03390927f..a5ed8292a7 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -98,40 +98,45 @@ def scaled_multihead_dot_product_attention( k = reshape_and_permute(key, kv_n_heads, d_model) v = reshape_and_permute(value, kv_n_heads, d_model) if past_key_value is not None: - if len(past_key_value) != 0: - k = torch.cat([past_key_value[0], k], dim=3) - v = torch.cat([past_key_value[1], v], dim=2) - past_key_value = (k, v) - (b, _, s_q, d) = q.shape - s_k = k.size(-1) + if len(past_key_value) != 0: + k = nn.emit(relax.op.concat([past_key_value[0], k], axis=3)) + v = nn.emit(relax.op.concat([past_key_value[1], v], axis=2)) + past_key_value = (k, v) + (b, _, s_q, d) = q.struct_info.shape + s_k = k.struct_info.shape[-1] if softmax_scale is None: softmax_scale = 1 / math.sqrt(d) - attn_weight = q.matmul(k) * softmax_scale + attn_weight = nn.emit(relax.op.matmul(q, k) * softmax_scale) if attn_bias is not None: - _s_q = max(0, attn_bias.size(2) - s_q) - _s_k = max(0, attn_bias.size(3) - s_k) - attn_bias = attn_bias[:, :, _s_q:, _s_k:] - if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q): - raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.') - attn_weight = attn_weight + attn_bias + _s_q = max(0, attn_bias.struct_info.shape[2] - s_q) + _s_k = max(0, attn_bias.struct_info.shape[3] - s_k) + # TODO: use split + attn_bias = attn_bias[:, :, _s_q:, _s_k:] + if (attn_bias.struct_info.shape[-1] != 1 and + attn_bias.struct_info.shape[-1] != s_k or + (attn_bias.struct_info.shape[-2] != 1 and + attn_bias.struct_info.shape[-2] != s_q)): + raise RuntimeError(f'attn_bias (shape: {attn_bias.struct_info.shape}) is expected to broadcast to shape: {attn_weight.struct_info.shape}.') + attn_weight = attn_weight + attn_bias min_val = torch.finfo(q.dtype).min if key_padding_mask is not None: if attn_bias is not None: warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. 
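Patch 014 expresses the core attention math with relax.op.matmul and relax.op.nn.softmax. A minimal NumPy sketch (mine) of that dataflow, ignoring masks, the ALiBi bias and the KV cache:

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

b, h, s_q, s_k, d = 1, 2, 3, 3, 4
q = np.random.randn(b, h, s_q, d)
k = np.random.randn(b, h, d, s_k)               # keys already laid out as b h d s
v = np.random.randn(b, h, s_k, d)

scale = 1.0 / np.sqrt(d)
attn_weight = softmax(q @ k * scale, axis=-1)   # relax: matmul + nn.softmax
out = attn_weight @ v                           # relax: matmul
assert out.shape == (b, h, s_q, d)
# With a cache, new keys are concatenated on the last axis of k (b h d s) and new
# values on the sequence axis of v (b h s d), matching the axis=3 / axis=2 concats.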
Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.') attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val) - if is_causal and (not q.size(2) == 1): + if is_causal and (not q.struct_info.shape[2] == 1): s = max(s_q, s_k) + causal_mask = nn.emit causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16) causal_mask = causal_mask.tril() causal_mask = causal_mask.to(torch.bool) causal_mask = ~causal_mask causal_mask = causal_mask[-s_q:, -s_k:] attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val) - attn_weight = torch.softmax(attn_weight, dim=-1) - out = attn_weight.matmul(v) + attn_weight = nn.emit(relax.op.nn.softmax(attn_weight)) + out = nn.emit(relax.op.matmul(attn_weight, v)) out = reverse_reshape_and_permute(out) if needs_weights: - return (out, attn_weight, past_key_value) + return (out, attn_weight, past_key_value) return (out, None, past_key_value) From 81f8712c13373d576228a9c60969308271ab0b84 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 8 Jun 2023 16:30:46 +0300 Subject: [PATCH 015/114] replace torch.finfo --- mlc_llm/relax_model/mpt.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index a5ed8292a7..85f2bb7d69 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -118,7 +118,7 @@ def scaled_multihead_dot_product_attention( attn_bias.struct_info.shape[-2] != s_q)): raise RuntimeError(f'attn_bias (shape: {attn_bias.struct_info.shape}) is expected to broadcast to shape: {attn_weight.struct_info.shape}.') attn_weight = attn_weight + attn_bias - min_val = torch.finfo(q.dtype).min + min_val = tvm.tir.min_value(q.struct_info.dtype) if key_padding_mask is not None: if attn_bias is not None: warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. 
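Patch 015 swaps torch.finfo(dtype).min for tvm.tir.min_value(dtype). One caveat worth flagging (my observation, not something the patch states): tir.min_value returns a TIR immediate rather than a Python float, so code that later wraps the fill value in relax.const usually needs the plain scalar first:

import tvm

min_imm = tvm.tir.min_value("float16")     # a tir.FloatImm
min_scalar = float(min_imm.value)          # plain Python float, -65504.0 for float16
# min_scalar is what a where()-based masked fill would hand to
# relax.const(min_scalar, "float16") when building the fill constant.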
Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.') @@ -259,7 +259,10 @@ def triton_flash_attn_fn( (b_size, s_k) = key_padding_mask.shape[:2] if attn_bias is None: attn_bias = query.new_zeros(b_size, 1, 1, s_k) - attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min) + attn_bias = attn_bias.masked_fill( + ~key_padding_mask.view((b_size, 1, 1, s_k)), + tvm.tir.min_value(query.struct_info.dtype) + ) batch_size, seq_len, _ = query.struct_info.shape query = nn.emit(relax.op.reshape( @@ -515,7 +518,7 @@ def _attn_bias(self, device, dtype, attention_mask: Optional[relax.Expr]=None, p attn_bias = attn_bias[:, :, :, _s_k:] if prefix_mask is not None and attention_mask.shape != prefix_mask.shape: raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.') - min_val = torch.finfo(attn_bias.dtype).min + min_val = tvm.tir.min_value(attn_bias.struct_info.dtype) attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val) return (attn_bias, None) @@ -530,7 +533,7 @@ def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor) causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)).view(1, 1, seq_len, seq_len) prefix = prefix_mask.view(-1, 1, 1, seq_len) cannot_attend = ~torch.logical_or(causal, prefix.bool()) - min_val = torch.finfo(attn_bias.dtype).min + min_val = tvm.tir.min_value(attn_bias.struct_info.dtype) attn_bias = attn_bias.masked_fill(cannot_attend, min_val) return attn_bias @@ -540,7 +543,7 @@ def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTen raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}') attn_bias = attn_bias[..., :seq_len, :seq_len] cannot_attend = torch.logical_not(torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))).unsqueeze(1) - min_val = torch.finfo(attn_bias.dtype).min + min_val = tvm.tir.min_value(attn_bias.struct_info.dtype) attn_bias = attn_bias.masked_fill(cannot_attend, min_val) return attn_bias From 86e849b19824759f4e54cc19eb2f7c39e3786048 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 8 Jun 2023 21:40:18 +0300 Subject: [PATCH 016/114] finish scaled_multihead_dot_product_attention, some TODOs are still there --- mlc_llm/relax_model/mpt.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 85f2bb7d69..e0cf11ca49 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -120,18 +120,22 @@ def scaled_multihead_dot_product_attention( attn_weight = attn_weight + attn_bias min_val = tvm.tir.min_value(q.struct_info.dtype) if key_padding_mask is not None: - if attn_bias is not None: - warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.') - attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val) + if attn_bias is not None: + warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. 
Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.') + key_mask = nn.emit(tvm.tir.bitwise_not(relax.op.reshape(key_padding_mask, (b, 1, 1, s_k)))) + # TODO: implement masked_fill by relax ops + attn_weight = attn_weight.masked_fill(key_mask, min_val) if is_causal and (not q.struct_info.shape[2] == 1): s = max(s_q, s_k) - causal_mask = nn.emit - causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16) - causal_mask = causal_mask.tril() - causal_mask = causal_mask.to(torch.bool) - causal_mask = ~causal_mask + causal_mask = nn.emit(relax.op.ones((s, s,), dtype="float16")) + causal_mask = nn.emit(relax.op.tril(causal_mask)) + causal_mask = tvm.tir.Cast("bool", causal_mask) + causal_mask = tvm.tir.bitwise_not(causal_mask) + # TODO: use split causal_mask = causal_mask[-s_q:, -s_k:] - attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val) + causal_mask = nn.emit(relax.op.reshape(causal_mask, (1, 1, s_q, s_k))) + # TODO: implement masked_fill by relax ops + attn_weight = attn_weight.masked_fill(causal_mask, min_val) attn_weight = nn.emit(relax.op.nn.softmax(attn_weight)) out = nn.emit(relax.op.matmul(attn_weight, v)) out = reverse_reshape_and_permute(out) From 0df22545411f298f8f908614f9b9384e2b3cadb2 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 8 Jun 2023 21:53:07 +0300 Subject: [PATCH 017/114] replace torch from flash_attn_fn --- mlc_llm/relax_model/mpt.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index e0cf11ca49..e2f7a6edf1 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -108,6 +108,7 @@ def scaled_multihead_dot_product_attention( softmax_scale = 1 / math.sqrt(d) attn_weight = nn.emit(relax.op.matmul(q, k) * softmax_scale) if attn_bias is not None: + # TODO: dynamic max _s_q = max(0, attn_bias.struct_info.shape[2] - s_q) _s_k = max(0, attn_bias.struct_info.shape[3] - s_k) # TODO: use split @@ -173,46 +174,51 @@ def flash_attn_fn( check_valid_inputs(query, key, value) if past_key_value is not None: if len(past_key_value) != 0: - key = torch.cat([past_key_value[0], key], dim=1) - value = torch.cat([past_key_value[1], value], dim=1) + key = nn.emit(relax.op.concat([past_key_value[0], key], axis=1)) + value = nn.emit(relax.op.concat([past_key_value[1], value], axis=1)) past_key_value = (key, value) if attn_bias is not None: - _s_q = max(0, attn_bias.size(2) - query.size(1)) - _s_k = max(0, attn_bias.size(3) - key.size(1)) + # TODO: dynamic max + _s_q = max(0, attn_bias.struct_info.shape[2] - query.struct_info.shape[1]) + _s_k = max(0, attn_bias.struct_info.shape[3] - key.struct_info.shape[1]) + # TODO: use split attn_bias = attn_bias[:, :, _s_q:, _s_k:] if attn_bias is not None: raise NotImplementedError(f'attn_bias not implemented for flash attn.') (batch_size, seqlen) = query.shape[:2] if key_padding_mask is None: key_shape = key.struct_info.shape[:2] + # TODO: dynamic shape key_padding_mask = nn.emit(relax.op.ones(Tuple(*key_shape, 1), dtype="bool")) - query_padding_mask = key_padding_mask[:, -query.size(1):] + # TODO: use split + query_padding_mask = key_padding_mask[:, -query.struct_info.shape[1]:] (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask) - nnz, _, _ = query_unpad.struct_info.shape + qnnz, _, _ = query_unpad.struct_info.shape query_unpad = nn.emit(relax.op.reshape( query_unpad, - (nnz, n_heads, 
d_model), + (qnnz, n_heads, d_model), )) # (nnz, (h d)) -> (nnz, h, d) - nnz, _, _ = key_unpad.struct_info.shape + kv_nnz, _, _ = key_unpad.struct_info.shape kv_n_heads = 1 if multiquery else n_heads (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask) key_unpad = nn.emit(relax.op.reshape( key_unpad, - (nnz, kv_n_heads, d_model), + (kv_nnz, kv_n_heads, d_model), )) # (nnz, (h d)) -> (nnz, h, d) (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask) value_unpad = nn.emit(relax.op.reshape( value_unpad, - (nnz, kv_n_heads, d_model), + (kv_nnz, kv_n_heads, d_model), )) # (nnz, (h d)) -> (nnz, h, d) if multiquery: - key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1)) - value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1)) + key_unpad = relax.op.broadcast_to(key_unpad, (kv_nnz, n_heads, d_model)) + value_unpad = relax.op.broadcast_to(value_unpad, (kv_nnz, n_heads, d_model)) reset_is_causal = _reset_is_causal(query.struct_info.shape[1], key.struct_info.shape[1], is_causal) output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, 0.0, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights) + nnz, _, _ = output_unpad.struct_info.shape output_unpad = nn.emit(relax.op.reshape( output_unpad, (nnz, n_heads*d_model), From 0c1c73705bcf646398f3ec99ec97e620f9d5312e Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 9 Jun 2023 08:57:07 +0300 Subject: [PATCH 018/114] replace torch from triton_flash_attn_fn --- mlc_llm/relax_model/mpt.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index e2f7a6edf1..c2ee890411 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -255,24 +255,25 @@ def triton_flash_attn_fn( check_valid_inputs(query, key, value) if past_key_value is not None: if len(past_key_value) != 0: - key = torch.cat([past_key_value[0], key], dim=1) - value = torch.cat([past_key_value[1], value], dim=1) + key = nn.emit(relax.op.concat([past_key_value[0], key], axis=1)) + value = nn.emit(relax.op.concat([past_key_value[1], value], axis=1)) past_key_value = (key, value) if attn_bias is not None: - _s_q = max(0, attn_bias.size(2) - query.size(1)) - _s_k = max(0, attn_bias.size(3) - key.size(1)) + # TODO: dynamic max + _s_q = max(0, attn_bias.struct_info.shape[2] - query.struct_info.shape[1]) + _s_k = max(0, attn_bias.struct_info.shape[3] - key.struct_info.shape[1]) + # TODO: use split attn_bias = attn_bias[:, :, _s_q:, _s_k:] if needs_weights: raise NotImplementedError(f'attn_impl: triton cannot return attn weights.') if key_padding_mask is not None: warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. 
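Several hunks keep Python-style slicing such as attn_bias[:, :, _s_q:, _s_k:] behind a "TODO: use split" comment. One possible Relax formulation is sketched below; it is an assumption of mine, not what the patch does, and the exact strided_slice argument convention depends on the TVM revision. nn here is assumed to be the same emit helper used throughout mpt.py and must run inside its block-builder scope.

from tvm import relax
from tvm.relax.testing import nn

def slice_bias_tail(attn_bias: relax.Expr, s_q_offset, s_k_offset) -> relax.Expr:
    # attn_bias[:, :, s_q_offset:, s_k_offset:] expressed without Python slicing
    full_q = attn_bias.struct_info.shape[2]
    full_k = attn_bias.struct_info.shape[3]
    return nn.emit(relax.op.strided_slice(
        attn_bias,
        axes=[2, 3],
        begin=[s_q_offset, s_k_offset],
        end=[full_q, full_k],
    ))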
Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.') - (b_size, s_k) = key_padding_mask.shape[:2] + (b_size, s_k) = key_padding_mask.struct_info.shape[:2] if attn_bias is None: - attn_bias = query.new_zeros(b_size, 1, 1, s_k) - attn_bias = attn_bias.masked_fill( - ~key_padding_mask.view((b_size, 1, 1, s_k)), - tvm.tir.min_value(query.struct_info.dtype) - ) + attn_bias = nn.emit(relax.op.zeros((b_size, 1, 1, s_k), dtype=query.struct_info.dtype)) + key_mask = nn.emit(tvm.tir.bitwise_not(relax.op.reshape(key_padding_mask, (b_size, 1, 1, s_k)))) + # TODO: implement masked_fill by relax ops + attn_bias = attn_bias.masked_fill(key_mask, tvm.tir.min_value(query.struct_info.dtype)) batch_size, seq_len, _ = query.struct_info.shape query = nn.emit(relax.op.reshape( @@ -291,11 +292,15 @@ def triton_flash_attn_fn( (batch_size, seq_len, kv_n_heads, d_model), )) # b s (h d) -> b s h d if multiquery: - key = key.expand(*key.shape[:2], n_heads, key.size(-1)) - value = value.expand(*value.shape[:2], n_heads, value.size(-1)) + key = relax.op.broadcast_to(key, (batch_size, seq_len, n_heads, d_model)) + value = relax.op.broadcast_to(value, (batch_size, seq_len, n_heads, d_model)) reset_is_causal = _reset_is_causal(query.struct_info.shape[1], key.struct_info.shape[1], is_causal) attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale) - output = attn_output.view(*attn_output.shape[:2], -1) + batch_size, seq_len, _, _ = attn_output.struct_info.shape + output = nn.emit(relax.op.reshape( + attn_output, + (batch_size, seq_len, n_heads*d_model), + )) # (b, s, h, d)) -> (b, s, (h d)) return (output, None, past_key_value) From 893c51dd689dff3570f4a4742fb7ac36125329d7 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 9 Jun 2023 09:43:06 +0300 Subject: [PATCH 019/114] update MPTModel forward, replace all torch operations --- mlc_llm/relax_model/mpt.py | 97 +++++++++++++++++++++----------------- 1 file changed, 53 insertions(+), 44 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index c2ee890411..a49431a72b 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -562,68 +562,77 @@ def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTen attn_bias = attn_bias.masked_fill(cannot_attend, min_val) return attn_bias - def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None): + def forward( + self, + input_ids: relax.Expr, + past_key_values: Optional[List[Tuple[relax.Expr]]]=None, + attention_mask: Optional[relax.Expr]=None, + prefix_mask: Optional[relax.Expr]=None, + sequence_id: Optional[relax.Expr]=None, + return_dict: Optional[bool]=None, + output_attentions: Optional[bool]=None, + output_hidden_states: Optional[bool]=None, + use_cache: Optional[bool]=None + ): return_dict = return_dict if return_dict is not None else self.config.return_dict use_cache = use_cache if use_cache is not None else self.config.use_cache if attention_mask is not None: - attention_mask = attention_mask.bool() + attention_mask = nn.emit(tvm.tir.Cast("bool", attention_mask)) if prefix_mask is not None: - prefix_mask = prefix_mask.bool() + 
prefix_mask = nn.emit(tvm.tir.Cast("bool", prefix_mask)) if not return_dict: - raise NotImplementedError('return_dict False is not implemented yet for MPT') + raise NotImplementedError('return_dict False is not implemented yet for MPT') if output_attentions: - if self.attn_impl != 'torch': - raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.') - if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training: - raise NotImplementedError('MPT does not support training with left padding.') + if self.attn_impl != 'torch': + raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.') if self.prefix_lm and prefix_mask is None: - raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.') - if self.training: - if self.attn_uses_sequence_id and sequence_id is None: - raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.') - elif self.attn_uses_sequence_id is False and sequence_id is not None: - warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.') - S = input_ids.size(1) + raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.') + + S = input_ids.struct_info.shape[1] assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' + tok_emb = self.wte(input_ids) if self.alibi: - x = tok_emb + x = tok_emb else: - past_position = 0 - if past_key_values is not None: - if len(past_key_values) != self.config.n_layers: - raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).') - past_position = past_key_values[0][0].size(1) - if self.attn_impl == 'torch': - past_position = past_key_values[0][0].size(3) - if S + past_position > self.config.max_seq_len: - raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.') - pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0) - if attention_mask is not None: - pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0) - pos_emb = self.wpe(pos) - x = tok_emb + pos_emb + past_position = 0 + if past_key_values is not None: + if len(past_key_values) != self.config.n_layers: + raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).') + past_position = past_key_values[0][0].struct_info.shape[1] + if self.attn_impl == 'torch': + past_position = past_key_values[0][0].struct_info.shape[3] + if S + past_position > self.config.max_seq_len: + raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.') + pos = 
nn.emit(relax.op.expand_dims(relax.op.arange(past_position, S + past_position, dtype="long"), axis=0)) + if attention_mask is not None: + #TODO use split + pos_diff = nn.emit(relax.op.cumsum(tvm.tir.Cast("int32", tvm.tir.bitwise_not(attention_mask)), axis=1)[:, past_position:]) + pos = nn.emit(relax.op.clip(pos - pos_diff, min=0)) + pos_emb = self.wpe(pos) + x = tok_emb + pos_emb + # TODO: reimplement _attn_bias, check removed args (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id) if use_cache and past_key_values is None: - past_key_values = [() for _ in range(self.config.n_layers)] + past_key_values = [() for _ in range(self.config.n_layers)] all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None for (b_idx, block) in enumerate(self.blocks): - if output_hidden_states: - assert all_hidden_states is not None - all_hidden_states = all_hidden_states + (x,) - past_key_value = past_key_values[b_idx] if past_key_values is not None else None - (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal) - if past_key_values is not None: - past_key_values[b_idx] = past_key_value - if output_attentions: - assert all_self_attns is not None - all_self_attns = all_self_attns + (attn_weights,) - x = self.norm_f(x) - if output_hidden_states: + if output_hidden_states: assert all_hidden_states is not None all_hidden_states = all_hidden_states + (x,) - return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns) + past_key_value = past_key_values[b_idx] if past_key_values is not None else None + (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal) + if past_key_values is not None: + past_key_values[b_idx] = past_key_value + if output_attentions: + assert all_self_attns is not None + all_self_attns = all_self_attns + (attn_weights,) + x = self.norm_f(x) + if output_hidden_states: + assert all_hidden_states is not None + all_hidden_states = all_hidden_states + (x,) + return x, past_key_values, all_hidden_states, all_self_attns def fsdp_wrap_fn(self, module): return isinstance(module, MPTBlock) From db4aadadf4f2da752a896ae7180985a3d6a51e0c Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 9 Jun 2023 11:10:04 +0300 Subject: [PATCH 020/114] implement masked_fill by relax, replace torch masked_fill by it. remove corresponding TODOs. 
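Patch 019 ports the position-id computation to relax.op.arange, relax.op.cumsum and relax.op.clip. A NumPy sketch (mine) of the same logic for one left-padded sequence, showing what the cumsum/clip combination produces:

import numpy as np

attention_mask = np.array([[0, 0, 1, 1, 1]], dtype=bool)   # two left-pad positions
S = attention_mask.shape[1]
past_position = 0

pos = np.arange(past_position, S + past_position)[None, :]
pad_counts = np.cumsum(~attention_mask, axis=1)[:, past_position:]
pos = np.clip(pos - pad_counts, 0, None)
# pos == [[0, 0, 0, 1, 2]]: padded slots collapse to 0 and real tokens count from 0,
# which is what the learned positional embedding lookup expects.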
other torch replacements --- mlc_llm/relax_model/mpt.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index a49431a72b..63e6c9a59b 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -16,6 +16,13 @@ ModuleList, ) + +def masked_fill_relax(input, mask, value): + rx_value = relax.const(value) + values = nn.emit(relax.op.full_like(input, rx_value)) + return nn.emit(relax.op.where(mask, values, input)) + + def _cast_if_autocast_enabled(tensor): if torch.is_autocast_enabled(): if tensor.device.type == 'cuda': @@ -124,8 +131,7 @@ def scaled_multihead_dot_product_attention( if attn_bias is not None: warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.') key_mask = nn.emit(tvm.tir.bitwise_not(relax.op.reshape(key_padding_mask, (b, 1, 1, s_k)))) - # TODO: implement masked_fill by relax ops - attn_weight = attn_weight.masked_fill(key_mask, min_val) + attn_weight = masked_fill_relax(attn_weight, key_mask, min_val) if is_causal and (not q.struct_info.shape[2] == 1): s = max(s_q, s_k) causal_mask = nn.emit(relax.op.ones((s, s,), dtype="float16")) @@ -135,8 +141,7 @@ def scaled_multihead_dot_product_attention( # TODO: use split causal_mask = causal_mask[-s_q:, -s_k:] causal_mask = nn.emit(relax.op.reshape(causal_mask, (1, 1, s_q, s_k))) - # TODO: implement masked_fill by relax ops - attn_weight = attn_weight.masked_fill(causal_mask, min_val) + attn_weight = masked_fill_relax(attn_weight, causal_mask, min_val) attn_weight = nn.emit(relax.op.nn.softmax(attn_weight)) out = nn.emit(relax.op.matmul(attn_weight, v)) out = reverse_reshape_and_permute(out) @@ -272,8 +277,7 @@ def triton_flash_attn_fn( if attn_bias is None: attn_bias = nn.emit(relax.op.zeros((b_size, 1, 1, s_k), dtype=query.struct_info.dtype)) key_mask = nn.emit(tvm.tir.bitwise_not(relax.op.reshape(key_padding_mask, (b_size, 1, 1, s_k)))) - # TODO: implement masked_fill by relax ops - attn_bias = attn_bias.masked_fill(key_mask, tvm.tir.min_value(query.struct_info.dtype)) + attn_bias = masked_fill_relax(attn_bias, key_mask, tvm.tir.min_value(query.struct_info.dtype)) batch_size, seq_len, _ = query.struct_info.shape query = nn.emit(relax.op.reshape( @@ -534,11 +538,11 @@ def _attn_bias(self, device, dtype, attention_mask: Optional[relax.Expr]=None, p if prefix_mask is not None and attention_mask.shape != prefix_mask.shape: raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.') min_val = tvm.tir.min_value(attn_bias.struct_info.dtype) - attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val) + attn_bias = masked_fill_relax(attn_bias, ~attention_mask.view(-1, 1, 1, s_k), min_val) return (attn_bias, None) - def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor): - (s_k, s_q) = attn_bias.shape[-2:] + def _apply_prefix_mask(self, attn_bias: relax.Expr, prefix_mask: relax.Expr): + (s_k, s_q) = attn_bias.struct_info.shape[-2:] if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len: raise ValueError('attn_bias does not match the expected shape. 
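Patch 020's masked_fill_relax builds the fill tensor with full_like and selects with where. A quick NumPy check (mine) that this reproduces torch.Tensor.masked_fill semantics, i.e. positions where the mask is True receive the fill value and everything else is left untouched:

import numpy as np

x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype="float32")
mask = np.array([[True, False], [False, True]])
fill = np.float32(-65504.0)                 # roughly float16's most negative value

filled = np.where(mask, np.full_like(x, fill), x)
# filled == [[-65504., 2.], [3., -65504.]]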
' + f'The last two dimensions should both be {self.config.max_length} ' + f'but are {s_k} and {s_q}.') seq_len = prefix_mask.shape[-1] @@ -549,17 +553,18 @@ def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor) prefix = prefix_mask.view(-1, 1, 1, seq_len) cannot_attend = ~torch.logical_or(causal, prefix.bool()) min_val = tvm.tir.min_value(attn_bias.struct_info.dtype) - attn_bias = attn_bias.masked_fill(cannot_attend, min_val) + attn_bias = masked_fill_relax(attn_bias, cannot_attend, min_val) return attn_bias - def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor): - seq_len = sequence_id.shape[-1] + def _apply_sequence_id(self, attn_bias: relax.Expr, sequence_id: relax.Expr): + seq_len = sequence_id.struct_info.shape[-1] if seq_len > self.config.max_seq_len: - raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}') + raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}') + # TODO: use split attn_bias = attn_bias[..., :seq_len, :seq_len] cannot_attend = torch.logical_not(torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))).unsqueeze(1) min_val = tvm.tir.min_value(attn_bias.struct_info.dtype) - attn_bias = attn_bias.masked_fill(cannot_attend, min_val) + attn_bias = masked_fill_relax(attn_bias, cannot_attend, min_val) return attn_bias def forward( @@ -714,7 +719,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ sequence_id = None if past_key_values is not None: # TODO: Relax implementation? - input_ids = input_ids[:, -1].unsqueeze(-1) + input_ids = nn.emit(relax.op.expand_dims(input_ids[:, -1], axis=-1)) if self.transformer.prefix_lm: prefix_mask = nn.emit(relax.op.ones_like(attention_mask, self.dtype)) if kwargs.get('use_cache') == False: From dbf143ab300383f19c69250dae2bffe08d11d48d Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 9 Jun 2023 11:28:05 +0300 Subject: [PATCH 021/114] fix max on dynamic values --- mlc_llm/relax_model/mpt.py | 43 ++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 63e6c9a59b..6ecb75f0e9 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -38,9 +38,9 @@ def _cast_if_autocast_enabled(tensor): class LPLayerNormWOBias(nn.Module): def __init__(self, normalized_shape, eps=1e-05, dtype=None): self.weight = nn.Parameter((normalized_shape,), dtype=dtype, name="low_precision_layernorm_weight") - # TODO: check + # TODO: check default filling of weights self.weight = relax.op.ones((normalized_shape,), dtype) - self.variance_epsilon = tvm.tir.const(eps, dtype) + self.variance_epsilon = relax.const(eps, dtype) def forward(self, x): module_device = x.device @@ -62,7 +62,7 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_cau return original_is_causal -def reshape_and_permute(hidden_states: relax.Expr, n_heads: int, d_model: int): +def reshape_and_permute(hidden_states: relax.Expr, n_heads: int, d_model: int, indeces: List[int] = [0, 2, 1, 3]): ''' Transform shape of input: b s (h d) -> b h d s ''' @@ -71,7 +71,7 @@ def reshape_and_permute(hidden_states: relax.Expr, n_heads: int, d_model: int): hidden_states, (batch_size, seqlen, n_heads, d_model), )) - return nn.emit(relax.op.permute_dims(inter, [0, 2, 1, 3])) + return nn.emit(relax.op.permute_dims(inter, indeces)) def 
reverse_reshape_and_permute(hidden_states: relax.Expr): @@ -102,7 +102,7 @@ def scaled_multihead_dot_product_attention( ): q = reshape_and_permute(query, n_heads, d_model) kv_n_heads = 1 if multiquery else n_heads - k = reshape_and_permute(key, kv_n_heads, d_model) + k = reshape_and_permute(key, kv_n_heads, d_model, [0, 2, 3, 1]) v = reshape_and_permute(value, kv_n_heads, d_model) if past_key_value is not None: if len(past_key_value) != 0: @@ -115,15 +115,14 @@ def scaled_multihead_dot_product_attention( softmax_scale = 1 / math.sqrt(d) attn_weight = nn.emit(relax.op.matmul(q, k) * softmax_scale) if attn_bias is not None: - # TODO: dynamic max - _s_q = max(0, attn_bias.struct_info.shape[2] - s_q) - _s_k = max(0, attn_bias.struct_info.shape[3] - s_k) + _s_q = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[2] - s_q) + _s_k = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[3] - s_k) # TODO: use split attn_bias = attn_bias[:, :, _s_q:, _s_k:] if (attn_bias.struct_info.shape[-1] != 1 and - attn_bias.struct_info.shape[-1] != s_k or + attn_bias.struct_info.shape[-1] != s_k or # dynamic condition? (attn_bias.struct_info.shape[-2] != 1 and - attn_bias.struct_info.shape[-2] != s_q)): + attn_bias.struct_info.shape[-2] != s_q)): # dynamic condition? raise RuntimeError(f'attn_bias (shape: {attn_bias.struct_info.shape}) is expected to broadcast to shape: {attn_weight.struct_info.shape}.') attn_weight = attn_weight + attn_bias min_val = tvm.tir.min_value(q.struct_info.dtype) @@ -133,7 +132,7 @@ def scaled_multihead_dot_product_attention( key_mask = nn.emit(tvm.tir.bitwise_not(relax.op.reshape(key_padding_mask, (b, 1, 1, s_k)))) attn_weight = masked_fill_relax(attn_weight, key_mask, min_val) if is_causal and (not q.struct_info.shape[2] == 1): - s = max(s_q, s_k) + s = relax.op.maximum(s_q, s_k) causal_mask = nn.emit(relax.op.ones((s, s,), dtype="float16")) causal_mask = nn.emit(relax.op.tril(causal_mask)) causal_mask = tvm.tir.Cast("bool", causal_mask) @@ -183,9 +182,8 @@ def flash_attn_fn( value = nn.emit(relax.op.concat([past_key_value[1], value], axis=1)) past_key_value = (key, value) if attn_bias is not None: - # TODO: dynamic max - _s_q = max(0, attn_bias.struct_info.shape[2] - query.struct_info.shape[1]) - _s_k = max(0, attn_bias.struct_info.shape[3] - key.struct_info.shape[1]) + _s_q = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[2] - query.struct_info.shape[1]) + _s_k = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[3] - key.struct_info.shape[1]) # TODO: use split attn_bias = attn_bias[:, :, _s_q:, _s_k:] if attn_bias is not None: @@ -264,9 +262,8 @@ def triton_flash_attn_fn( value = nn.emit(relax.op.concat([past_key_value[1], value], axis=1)) past_key_value = (key, value) if attn_bias is not None: - # TODO: dynamic max - _s_q = max(0, attn_bias.struct_info.shape[2] - query.struct_info.shape[1]) - _s_k = max(0, attn_bias.struct_info.shape[3] - key.struct_info.shape[1]) + _s_q = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[2] - query.struct_info.shape[1]) + _s_k = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[3] - key.struct_info.shape[1]) # TODO: use split attn_bias = attn_bias[:, :, _s_q:, _s_k:] if needs_weights: @@ -522,20 +519,20 @@ def _attn_bias(self, device, dtype, attention_mask: Optional[relax.Expr]=None, p self.attn_bias = self.attn_bias.to(dtype=dtype, device=device) attn_bias = self.attn_bias if self.prefix_lm: - assert isinstance(attn_bias, torch.Tensor) - assert isinstance(prefix_mask, 
torch.Tensor) + assert isinstance(attn_bias, relax.Expr) + assert isinstance(prefix_mask, relax.Expr) attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask) if self.attn_uses_sequence_id and sequence_id is not None: - assert isinstance(attn_bias, torch.Tensor) + assert isinstance(attn_bias, relax.Expr) attn_bias = self._apply_sequence_id(attn_bias, sequence_id) if attention_mask is not None: - s_k = attention_mask.shape[-1] + s_k = attention_mask.struct_info.shape[-1] if attn_bias is None: attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype) else: - _s_k = max(0, attn_bias.size(-1) - s_k) + _s_k = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[-1] - s_k) attn_bias = attn_bias[:, :, :, _s_k:] - if prefix_mask is not None and attention_mask.shape != prefix_mask.shape: + if prefix_mask is not None and attention_mask.struct_info.shape != prefix_mask.struct_info.shape: raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.') min_val = tvm.tir.min_value(attn_bias.struct_info.dtype) attn_bias = masked_fill_relax(attn_bias, ~attention_mask.view(-1, 1, 1, s_k), min_val) From 1fab0ba696d7b64ed0bb31f977f9890b709c562f Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 9 Jun 2023 12:21:38 +0300 Subject: [PATCH 022/114] implement build_attn_bias with dependencies --- mlc_llm/relax_model/mpt.py | 49 +++++++++++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 6 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 6ecb75f0e9..8a9bbee5c1 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -191,7 +191,6 @@ def flash_attn_fn( (batch_size, seqlen) = query.shape[:2] if key_padding_mask is None: key_shape = key.struct_info.shape[:2] - # TODO: dynamic shape key_padding_mask = nn.emit(relax.op.ones(Tuple(*key_shape, 1), dtype="bool")) # TODO: use split query_padding_mask = key_padding_mask[:, -query.struct_info.shape[1]:] @@ -465,6 +464,42 @@ def forward( return (hidden_states, attn_weights, present_key_value) +def gen_slopes(n_heads, alibi_bias_max=8): + _n_heads = 2 ** math.ceil(math.log2(n_heads)) + m = nn.emit(relax.op.arange(1, _n_heads + 1, dtype="float32")) + m = nn.emit(m * (alibi_bias_max / _n_heads)) + slopes = 1.0 / math.pow(2, m) + if _n_heads != n_heads: + # TODO: relax [::] + slopes = nn.emit(relax.op.concat([slopes[1::2], slopes[::2]])[:n_heads]) + return nn.emit(relax.op.reshape(slopes, (1, n_heads, 1, 1))) + + +def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, dtype=None): + alibi_bias = nn.emit(relax.op.reshape(relax.op.arange(1 - seq_len, 1, dtype="int32"), (1, 1, 1, seq_len))) + if full: + alibi_bias = nn.emit(alibi_bias - relax.op.reshape(relax.op.arange(1 - seq_len, 1, dtype="int32"), (1, 1, seq_len, 1))) + alibi_bias = nn.emit(relax.op.negative(relax.op.abs(alibi_bias))) + slopes = gen_slopes(n_heads, alibi_bias_max) + alibi_bias = nn.emit(alibi_bias * slopes) + if dtype is not None: + alibi_bias = nn.emit(tvm.tir.Cast(dtype, alibi_bias)) + return alibi_bias + + +def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8): + if attn_impl == 'flash': + return None + elif attn_impl in ['torch', 'triton']: + if alibi: + attn_bias = nn.emit(relax.op.add(attn_bias, build_alibi_bias( + n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, dtype=attn_bias.struct_info.dtype + ))) + return attn_bias + else: + raise 
ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') + + class MPTModel(nn.Module): def __init__(self, config: MPTConfig): config._validate_config() @@ -507,12 +542,14 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.wte = value - def _attn_bias(self, device, dtype, attention_mask: Optional[relax.Expr]=None, prefix_mask: Optional[relax.Expr]=None, sequence_id: Optional[relax.Expr]=None): + def _attn_bias(self, dtype, attention_mask: Optional[relax.Expr]=None, prefix_mask: Optional[relax.Expr]=None, sequence_id: Optional[relax.Expr]=None): if not self._attn_bias_initialized: - if self.attn_bias_shape: - self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype) - self.attn_bias = build_attn_bias(self.attn_impl, self.attn_bias, self.config.n_heads, self.config.max_seq_len, causal=self.is_causal, alibi=self.alibi, alibi_bias_max=self.alibi_bias_max) - self._attn_bias_initialized = True + if self.attn_bias_shape: + self.attn_bias = nn.emit(relax.op.zeros(self.attn_bias_shape, dtype=dtype)) + self.attn_bias = build_attn_bias( + self.attn_impl, self.attn_bias, self.config.n_heads, self.config.max_seq_len, causal=self.is_causal, alibi=self.alibi, alibi_bias_max=self.alibi_bias_max + ) + self._attn_bias_initialized = True if self.attn_impl == 'flash': return (self.attn_bias, attention_mask) if self.attn_bias is not None: From 6a40c7b049f71311086a4d2bda34b5b65aeb6c98 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 9 Jun 2023 12:48:22 +0300 Subject: [PATCH 023/114] transfer of code for the sake of convenience --- mlc_llm/relax_model/mpt.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 8a9bbee5c1..ced6b9d52d 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -383,21 +383,6 @@ def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, i ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention} -def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id): - if attn_impl == 'flash': - return None - elif attn_impl in ['torch', 'triton']: - if alibi: - if (prefix_lm or not causal) or use_sequence_id: - return (1, n_heads, seq_len, seq_len) - return (1, n_heads, 1, seq_len) - elif prefix_lm or use_sequence_id: - return (1, 1, seq_len, seq_len) - return None - else: - raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') - - class MPTMLP(nn.Module): def __init__(self, hidden_size: int, intermediate_size: int, dtype: str): self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype) @@ -464,6 +449,21 @@ def forward( return (hidden_states, attn_weights, present_key_value) +def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id): + if attn_impl == 'flash': + return None + elif attn_impl in ['torch', 'triton']: + if alibi: + if (prefix_lm or not causal) or use_sequence_id: + return (1, n_heads, seq_len, seq_len) + return (1, n_heads, 1, seq_len) + elif prefix_lm or use_sequence_id: + return (1, 1, seq_len, seq_len) + return None + else: + raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') + + def gen_slopes(n_heads, alibi_bias_max=8): _n_heads = 2 ** math.ceil(math.log2(n_heads)) m = nn.emit(relax.op.arange(1, _n_heads + 1, dtype="float32")) From 01d4b0738f6a69e0893a831dcd192fa1ff038cce Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 9 Jun 2023 13:21:43 +0300 
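Patch 022 ports gen_slopes and build_alibi_bias to relax ops. A NumPy sketch (mine) of the causal (full=False) bias with the default alibi_bias_max=8, handy for eyeballing the relax version:

import math
import numpy as np

def gen_slopes_np(n_heads, alibi_bias_max=8):
    _n_heads = 2 ** math.ceil(math.log2(n_heads))
    m = np.arange(1, _n_heads + 1) * (alibi_bias_max / _n_heads)
    slopes = 1.0 / np.power(2.0, m)
    if _n_heads != n_heads:
        slopes = np.concatenate([slopes[1::2], slopes[::2]])[:n_heads]
    return slopes.reshape(1, n_heads, 1, 1)

n_heads, seq_len = 4, 5
bias = np.arange(1 - seq_len, 1).reshape(1, 1, 1, seq_len) * gen_slopes_np(n_heads)
# bias[0, head, 0, :] is 0 at the newest position and increasingly negative for
# older ones, with a per-head slope; this is the tensor added to attn_weight.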
Subject: [PATCH 024/114] _attn_bias of MPTModel was implemented --- mlc_llm/relax_model/mpt.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index ced6b9d52d..3be92a385a 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -551,9 +551,9 @@ def _attn_bias(self, dtype, attention_mask: Optional[relax.Expr]=None, prefix_ma ) self._attn_bias_initialized = True if self.attn_impl == 'flash': - return (self.attn_bias, attention_mask) + return (self.attn_bias, attention_mask) if self.attn_bias is not None: - self.attn_bias = self.attn_bias.to(dtype=dtype, device=device) + self.attn_bias = nn.emit(tvm.tir.Cast(dtype, self.attn_bias)) attn_bias = self.attn_bias if self.prefix_lm: assert isinstance(attn_bias, relax.Expr) @@ -563,16 +563,18 @@ def _attn_bias(self, dtype, attention_mask: Optional[relax.Expr]=None, prefix_ma assert isinstance(attn_bias, relax.Expr) attn_bias = self._apply_sequence_id(attn_bias, sequence_id) if attention_mask is not None: - s_k = attention_mask.struct_info.shape[-1] - if attn_bias is None: - attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype) - else: - _s_k = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[-1] - s_k) - attn_bias = attn_bias[:, :, :, _s_k:] - if prefix_mask is not None and attention_mask.struct_info.shape != prefix_mask.struct_info.shape: - raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.') - min_val = tvm.tir.min_value(attn_bias.struct_info.dtype) - attn_bias = masked_fill_relax(attn_bias, ~attention_mask.view(-1, 1, 1, s_k), min_val) + s_k = attention_mask.struct_info.shape[-1] + if attn_bias is None: + attn_bias = nn.emit(relax.op.zeros((1, 1, 1, s_k), dtype=dtype)) + else: + _s_k = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[-1] - s_k) + # TODO: use split + attn_bias = attn_bias[:, :, :, _s_k:] + if prefix_mask is not None and attention_mask.struct_info.shape != prefix_mask.struct_info.shape: + raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.') + min_val = tvm.tir.min_value(attn_bias.struct_info.dtype) + attn_mask = nn.emit(tvm.tir.bitwise_not(relax.op.reshape(attention_mask, (-1, 1, 1, s_k)))) + attn_bias = masked_fill_relax(attn_bias, attn_mask, min_val) return (attn_bias, None) def _apply_prefix_mask(self, attn_bias: relax.Expr, prefix_mask: relax.Expr): From b256f38f65404ac1499c548df38119b8ebd6ac6b Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 9 Jun 2023 14:27:45 +0300 Subject: [PATCH 025/114] _apply_prefix_mask of MPTModel was implemented on relax --- mlc_llm/relax_model/mpt.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 3be92a385a..de2289f994 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -581,13 +581,15 @@ def _apply_prefix_mask(self, attn_bias: relax.Expr, prefix_mask: relax.Expr): (s_k, s_q) = attn_bias.struct_info.shape[-2:] if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len: raise ValueError('attn_bias does not match the expected shape. 
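A recurring pattern in these hunks applies tvm.tir.Cast and tvm.tir.bitwise_not directly to Relax expressions, which appears to mix TIR node constructors with Relax values. Relax-level equivalents do exist; a hedged sketch (mine, assuming the same nn emit helper used throughout mpt.py) of the attention-mask inversion written with them:

from tvm import relax
from tvm.relax.testing import nn

def invert_attention_mask(attention_mask: relax.Expr, s_k) -> relax.Expr:
    as_bool = nn.emit(relax.op.astype(attention_mask, "bool"))   # instead of tir.Cast
    inverted = nn.emit(relax.op.logical_not(as_bool))            # instead of tir.bitwise_not
    return nn.emit(relax.op.reshape(inverted, (-1, 1, 1, s_k)))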
' + f'The last two dimensions should both be {self.config.max_length} ' + f'but are {s_k} and {s_q}.') - seq_len = prefix_mask.shape[-1] + seq_len = prefix_mask.struct_info.shape[-1] if seq_len > self.config.max_seq_len: raise ValueError(f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}') + # TODO: use split attn_bias = attn_bias[..., :seq_len, :seq_len] - causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)).view(1, 1, seq_len, seq_len) - prefix = prefix_mask.view(-1, 1, 1, seq_len) - cannot_attend = ~torch.logical_or(causal, prefix.bool()) + causal = nn.emit(relax.op.reshape(relax.op.tril(relax.op.ones((seq_len, seq_len), dtype="bool")), (1, 1, seq_len, seq_len))) + prefix = nn.emit(relax.op.reshape(prefix_mask, (-1, 1, 1, seq_len))) + # TODO: logical_or on relax + cannot_attend = nn.emit(tvm.tir.bitwise_not(torch.logical_or(causal, tvm.tir.Cast("bool", prefix)))) min_val = tvm.tir.min_value(attn_bias.struct_info.dtype) attn_bias = masked_fill_relax(attn_bias, cannot_attend, min_val) return attn_bias From 7d5b1d28bbccffb88f461e54506df012d0e0c355 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 9 Jun 2023 14:34:31 +0300 Subject: [PATCH 026/114] _apply_sequence_id of MPTModel was implemented on relax --- mlc_llm/relax_model/mpt.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index de2289f994..7cd17a1d37 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -600,7 +600,10 @@ def _apply_sequence_id(self, attn_bias: relax.Expr, sequence_id: relax.Expr): raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}') # TODO: use split attn_bias = attn_bias[..., :seq_len, :seq_len] - cannot_attend = torch.logical_not(torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))).unsqueeze(1) + # TODO: logical_not on relax + seq_id_l = nn.emit(relax.op.reshape(sequence_id, (-1, seq_len, 1))) + seq_id_r = nn.emit(relax.op.reshape(sequence_id, (-1, 1, seq_len))) + cannot_attend = nn.emit(relax.op.expand_dims(torch.logical_not(relax.op.equal(seq_id_l, seq_id_r)), axis=1)) min_val = tvm.tir.min_value(attn_bias.struct_info.dtype) attn_bias = masked_fill_relax(attn_bias, cannot_attend, min_val) return attn_bias From f7b604f91af0a7bb474e4daa7532df5804488db2 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 9 Jun 2023 14:57:38 +0300 Subject: [PATCH 027/114] fix layer norm --- mlc_llm/relax_model/mpt.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 7cd17a1d37..f5ad6f82e9 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -24,15 +24,15 @@ def masked_fill_relax(input, mask, value): def _cast_if_autocast_enabled(tensor): - if torch.is_autocast_enabled(): - if tensor.device.type == 'cuda': - dtype = torch.get_autocast_gpu_dtype() - elif tensor.device.type == 'cpu': - dtype = torch.get_autocast_cpu_dtype() - else: - raise NotImplementedError() - return tensor.to(dtype=dtype) - return tensor + # # TODO: how to check device? 
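Patch 027 ("fix layer norm") moves the low-precision layer norm onto relax.op.nn.layer_norm. A hedged sketch (mine) of that call; I am assuming epsilon is taken as a plain Python float attribute rather than a relax constant, and axes as a list, which differs slightly from the hunk as written:

from tvm import relax
from tvm.relax.testing import nn

def lp_layer_norm(x: relax.Expr, weight: relax.Expr, bias: relax.Expr,
                  eps: float = 1e-5) -> relax.Expr:
    return nn.emit(relax.op.nn.layer_norm(x, weight, bias, axes=[-1], epsilon=eps))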
+ # if tensor.device.type == 'cuda': + # dtype = "float16" + # elif tensor.device.type == 'cpu': + # dtype = "bfloat16" + # else: + # raise NotImplementedError() + dtype = "float32" # TODO: temporal workaround + return nn.emit(tvm.tir.Cast(dtype, tensor)) # Low-precision layer norm for mpt-7b-instruct, where are no biases expected class LPLayerNormWOBias(nn.Module): @@ -40,15 +40,14 @@ def __init__(self, normalized_shape, eps=1e-05, dtype=None): self.weight = nn.Parameter((normalized_shape,), dtype=dtype, name="low_precision_layernorm_weight") # TODO: check default filling of weights self.weight = relax.op.ones((normalized_shape,), dtype) - self.variance_epsilon = relax.const(eps, dtype) + self.bias = relax.op.zeros((normalized_shape,), dtype) + self.eps = relax.const(eps, dtype) def forward(self, x): - module_device = x.device downcast_x = _cast_if_autocast_enabled(x) downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias - with torch.autocast(enabled=False, device_type=module_device.type): - return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps) + return nn.emit(relax.op.nn.layer_norm(downcast_x, downcast_weight, downcast_bias, axes=-1, epsilon=self.eps)) NORM_CLASS_REGISTRY = {'low_precision_layernorm': LPLayerNormWOBias} @@ -149,12 +148,13 @@ def scaled_multihead_dot_product_attention( return (out, None, past_key_value) -def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]): - for tensor in tensors: - if tensor.dtype not in valid_dtypes: - raise TypeError(f'tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.') - if not tensor.is_cuda: - raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).') +def check_valid_inputs(*tensors, valid_dtypes=["float16", "bfloat16"]): + for tensor in tensors: + if tensor.struct_info.dtype not in valid_dtypes: + raise TypeError(f'tensor.dtype={tensor.struct_info.dtype!r} must be in valid_dtypes={valid_dtypes!r}.') + # TODO: check on relax that CUDA is used + # if not tensor.is_cuda: + # raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).') def flash_attn_fn( From 9a8a331eef7e46fa9a5aa528a1e93e27eaf59bfa Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 9 Jun 2023 15:02:58 +0300 Subject: [PATCH 028/114] remove device --- mlc_llm/relax_model/mpt.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index f5ad6f82e9..b1a6303b29 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -318,8 +318,7 @@ def __init__( clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, - low_precision_layernorm: bool=False, - device: Optional[str]=None + low_precision_layernorm: bool=False ): # Init fields self.d_model = d_model @@ -331,13 +330,13 @@ def __init__( if self.softmax_scale is None: self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) - self.Wqkv = Linear(self.d_model, 3 * self.d_model, device=device) + self.Wqkv = Linear(self.d_model, 3 * self.d_model) fuse_splits = (d_model, 2 * d_model) self.Wqkv._fused = (0, fuse_splits) if self.qk_ln: layernorm_class = LPLayerNormWOBias if low_precision_layernorm else LayerNorm - self.q_ln = layernorm_class(self.d_model, device=device) - self.k_ln = 
layernorm_class(self.d_model, device=device) + self.q_ln = layernorm_class(self.d_model) + self.k_ln = layernorm_class(self.d_model) if self.attn_impl == 'flash': self.attn_fn = flash_attn_fn elif self.attn_impl == 'triton': @@ -351,7 +350,7 @@ def __init__( self.attn_fn = scaled_multihead_dot_product_attention else: raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') - self.out_proj = Linear(self.d_model, self.d_model, device=device) + self.out_proj = Linear(self.d_model, self.d_model) # TODO: Does field _is_residual exist? self.out_proj._is_residual = True @@ -657,8 +656,7 @@ def forward( pos = nn.emit(relax.op.clip(pos - pos_diff, min=0)) pos_emb = self.wpe(pos) x = tok_emb + pos_emb - # TODO: reimplement _attn_bias, check removed args - (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id) + (attn_bias, attention_mask) = self._attn_bias(dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id) if use_cache and past_key_values is None: past_key_values = [() for _ in range(self.config.n_layers)] all_hidden_states = () if output_hidden_states else None From 589af326fc033a4dfb27c2e098cfaa16640d8090 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 9 Jun 2023 15:34:31 +0300 Subject: [PATCH 029/114] unroll flash_attn implementation using sources --- mlc_llm/relax_model/mpt.py | 156 ++++++++++++++++++++++++++++++++++--- 1 file changed, 147 insertions(+), 9 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index b1a6303b29..dc4d52d3e4 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -157,6 +157,148 @@ def check_valid_inputs(*tensors, valid_dtypes=["float16", "bfloat16"]): # raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).') +class IndexFirstAxis(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + # return input[indices] + return torch.gather(rearrange(input, 'b ... -> b (...)'), 0, + repeat(indices, 'z -> z d', d=second_dim)).reshape(-1, *other_shape) + + @staticmethod + def backward(ctx, grad_output): + indices, = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + grad_output = rearrange(grad_output, 'b ... -> b (...)') + grad_input = torch.zeros([ctx.first_axis_dim, grad_output.shape[1]], + device=grad_output.device, dtype=grad_output.dtype) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + # grad_input[indices] = grad_output + grad_input.scatter_(0, repeat(indices, 'z -> z d', d=grad_output.shape[1]), grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis = IndexFirstAxis.apply + +import torch +class IndexPutFirstAxis(torch.autograd.Function): + + @staticmethod + def forward(ctx, values, indices, first_axis_dim): + ctx.save_for_backward(indices) + assert indices.ndim == 1 + assert values.ndim >= 2 + output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, + dtype=values.dtype) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. 
+ output[indices] = values + # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) + return output + + @staticmethod + def backward(ctx, grad_output): + indices, = ctx.saved_tensors + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + grad_values = grad_output[indices] + # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) + return grad_values, None, None + + +index_put_first_axis = IndexPutFirstAxis.apply + + +def bert_padding_unpad_input(hidden_states, attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the + # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim + # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to + # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, + # so we write custom forward and backward to make it a bit faster. + return (index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'), indices), indices, + cu_seqlens, max_seqlen_in_batch) + + +def bert_padding_pad_input(hidden_states, indices, batch, seqlen): + """ + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz) + Return: + hidden_states: (batch, seqlen, ...) + """ + dim = hidden_states.shape[-1] + # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) + # output[indices] = hidden_states + output = index_put_first_axis(hidden_states, indices, batch * seqlen) + return rearrange(output, '(b s) ... -> b s ...', b=batch) + +import flash_attn_cuda +def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal, return_softmax, num_splits=0, + generator=None): + """ + num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means + it will be set by an internal heuristic. We're exposing num_splits mostly for benchmarking. + Don't change it unless you know what you're doing. 
+ """ + softmax_lse, rng_state, *rest = flash_attn_cuda.fwd( + q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, + softmax_scale, False, causal, return_softmax, num_splits, generator + ) + # if out.isnan().any() or softmax_lse.isnan().any(): + # breakpoint() + S_dmask = rest[0] if return_softmax else None + return out, softmax_lse, rng_state, S_dmask + + +class FlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, + softmax_scale, causal, return_softmax, deterministic): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, softmax_lse, rng_state, S_dmask = _flash_attn_forward( + q, k, v, torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax + ) + ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.deterministic = deterministic + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + _flash_attn_backward( + dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, + ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal, + rng_state=rng_state, num_splits=1 if ctx.deterministic else 0, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None, None + + +def flash_attn_interface_flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale=None, causal=False, return_attn_probs=False, + deterministic=False +): + return FlashAttnFunc.apply(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal, return_attn_probs, deterministic) + + def flash_attn_fn( query, key, @@ -171,10 +313,6 @@ def flash_attn_fn( needs_weights=False, multiquery=False ): - try: - from flash_attn import bert_padding, flash_attn_interface - except: - raise RuntimeError('Please install flash-attn==1.0.3.post0') check_valid_inputs(query, key, value) if past_key_value is not None: if len(past_key_value) != 0: @@ -194,7 +332,7 @@ def flash_attn_fn( key_padding_mask = nn.emit(relax.op.ones(Tuple(*key_shape, 1), dtype="bool")) # TODO: use split query_padding_mask = key_padding_mask[:, -query.struct_info.shape[1]:] - (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask) + (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding_unpad_input(query, query_padding_mask) qnnz, _, _ = query_unpad.struct_info.shape query_unpad = nn.emit(relax.op.reshape( @@ -204,12 +342,12 @@ def flash_attn_fn( kv_nnz, _, _ = key_unpad.struct_info.shape kv_n_heads = 1 if multiquery else n_heads - (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask) + (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding_unpad_input(key, key_padding_mask) key_unpad = nn.emit(relax.op.reshape( key_unpad, (kv_nnz, kv_n_heads, d_model), )) # (nnz, (h d)) -> (nnz, h, d) - (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask) + (value_unpad, 
_, _, _) = bert_padding_unpad_input(value, key_padding_mask) value_unpad = nn.emit(relax.op.reshape( value_unpad, (kv_nnz, kv_n_heads, d_model), @@ -219,13 +357,13 @@ def flash_attn_fn( key_unpad = relax.op.broadcast_to(key_unpad, (kv_nnz, n_heads, d_model)) value_unpad = relax.op.broadcast_to(value_unpad, (kv_nnz, n_heads, d_model)) reset_is_causal = _reset_is_causal(query.struct_info.shape[1], key.struct_info.shape[1], is_causal) - output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, 0.0, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights) + output_unpad = flash_attn_interface_flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, 0.0, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights) nnz, _, _ = output_unpad.struct_info.shape output_unpad = nn.emit(relax.op.reshape( output_unpad, (nnz, n_heads*d_model), )) # (nnz, h, d)) -> (nnz, (h d)) - output = bert_padding.pad_input(output_unpad, indices_q, batch_size, seqlen) + output = bert_padding_pad_input(output_unpad, indices_q, batch_size, seqlen) return (output, None, past_key_value) From 4f5ef3fc78df4b6e162d52564b938d3ca85ca9d7 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Sat, 10 Jun 2023 15:59:46 +0300 Subject: [PATCH 030/114] slicings were reimplemented from python style to relax. corresponding TODOs were removed --- mlc_llm/relax_model/mpt.py | 68 ++++++++++++++++++++++++-------------- 1 file changed, 43 insertions(+), 25 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index dc4d52d3e4..847255fbd6 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -116,8 +116,9 @@ def scaled_multihead_dot_product_attention( if attn_bias is not None: _s_q = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[2] - s_q) _s_k = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[3] - s_k) - # TODO: use split - attn_bias = attn_bias[:, :, _s_q:, _s_k:] + # slicing attn_bias[:, :, _s_q:, _s_k:] + s_q_end, s_k_end = attn_bias.struct_info.shape[-2:] + attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [2, 3], [_s_q, _s_k], [s_q_end, s_k_end])) if (attn_bias.struct_info.shape[-1] != 1 and attn_bias.struct_info.shape[-1] != s_k or # dynamic condition? 
(attn_bias.struct_info.shape[-2] != 1 and @@ -136,8 +137,9 @@ def scaled_multihead_dot_product_attention( causal_mask = nn.emit(relax.op.tril(causal_mask)) causal_mask = tvm.tir.Cast("bool", causal_mask) causal_mask = tvm.tir.bitwise_not(causal_mask) - # TODO: use split - causal_mask = causal_mask[-s_q:, -s_k:] + # slicing causal_mask[-s_q:, -s_k:] + s_q_end, s_k_end = causal_mask.struct_info.shape + causal_mask = nn.emit(relax.op.strided_slice(causal_mask, [0, 1], [s_q_end - s_q, s_k_end - s_k], [s_q_end, s_k_end])) causal_mask = nn.emit(relax.op.reshape(causal_mask, (1, 1, s_q, s_k))) attn_weight = masked_fill_relax(attn_weight, causal_mask, min_val) attn_weight = nn.emit(relax.op.nn.softmax(attn_weight)) @@ -322,16 +324,18 @@ def flash_attn_fn( if attn_bias is not None: _s_q = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[2] - query.struct_info.shape[1]) _s_k = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[3] - key.struct_info.shape[1]) - # TODO: use split - attn_bias = attn_bias[:, :, _s_q:, _s_k:] + # slicing attn_bias[:, :, _s_q:, _s_k:] + s_q_end, s_k_end = attn_bias.struct_info.shape[-2:] + attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [2, 3], [_s_q, _s_k], [s_q_end, s_k_end])) if attn_bias is not None: raise NotImplementedError(f'attn_bias not implemented for flash attn.') - (batch_size, seqlen) = query.shape[:2] + batch_size, seqlen = query.struct_info.shape[:2] if key_padding_mask is None: key_shape = key.struct_info.shape[:2] key_padding_mask = nn.emit(relax.op.ones(Tuple(*key_shape, 1), dtype="bool")) - # TODO: use split - query_padding_mask = key_padding_mask[:, -query.struct_info.shape[1]:] + # slicing key_padding_mask[:, -query.struct_info.shape[1]:] + dim1_length = key_padding_mask.struct_info.shape[1] + query_padding_mask = nn.emit(relax.op.strided_slice(key_padding_mask, [1], [dim1_length - seqlen], [dim1_length])) (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding_unpad_input(query, query_padding_mask) qnnz, _, _ = query_unpad.struct_info.shape @@ -401,8 +405,9 @@ def triton_flash_attn_fn( if attn_bias is not None: _s_q = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[2] - query.struct_info.shape[1]) _s_k = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[3] - key.struct_info.shape[1]) - # TODO: use split - attn_bias = attn_bias[:, :, _s_q:, _s_k:] + # slicing attn_bias[:, :, _s_q:, _s_k:] + s_q_end, s_k_end = attn_bias.struct_info.shape[-2:] + attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [2, 3], [_s_q, _s_k], [s_q_end, s_k_end])) if needs_weights: raise NotImplementedError(f'attn_impl: triton cannot return attn weights.') if key_padding_mask is not None: @@ -607,8 +612,13 @@ def gen_slopes(n_heads, alibi_bias_max=8): m = nn.emit(m * (alibi_bias_max / _n_heads)) slopes = 1.0 / math.pow(2, m) if _n_heads != n_heads: - # TODO: relax [::] - slopes = nn.emit(relax.op.concat([slopes[1::2], slopes[::2]])[:n_heads]) + slopes_len = slopes.struct_info.shape[0] + slopes = nn.emit(relax.op.strided_slice( + relax.op.concat( + [relax.op.strided_slice(slopes, [0], [relax.const(1)], [slopes_len], [relax.const(2)]), # [1::2] + relax.op.strided_slice(slopes, [0], [relax.const(0)], [slopes_len], [relax.const(2)])] # [::2] + ), [0], [relax.const(0)], [relax.const(n_heads)]) # slicing [:n_heads] + ) return nn.emit(relax.op.reshape(slopes, (1, n_heads, 1, 1))) @@ -705,8 +715,9 @@ def _attn_bias(self, dtype, attention_mask: Optional[relax.Expr]=None, prefix_ma attn_bias = nn.emit(relax.op.zeros((1, 1, 
1, s_k), dtype=dtype)) else: _s_k = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[-1] - s_k) - # TODO: use split - attn_bias = attn_bias[:, :, :, _s_k:] + # slicing attn_bias[:, :, :, _s_k:] + s_k_end = attn_bias.struct_info.shape[3] + attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [3], [_s_k], [s_k_end])) if prefix_mask is not None and attention_mask.struct_info.shape != prefix_mask.struct_info.shape: raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.') min_val = tvm.tir.min_value(attn_bias.struct_info.dtype) @@ -721,8 +732,9 @@ def _apply_prefix_mask(self, attn_bias: relax.Expr, prefix_mask: relax.Expr): seq_len = prefix_mask.struct_info.shape[-1] if seq_len > self.config.max_seq_len: raise ValueError(f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}') - # TODO: use split - attn_bias = attn_bias[..., :seq_len, :seq_len] + # slicing attn_bias[..., :seq_len, :seq_len] + dims_len = len(attn_bias.struct_info.shape) # TODO: rank? + attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [dims_len - 2, dims_len - 1], [relax.const(0), relax.const(0)], [seq_len, seq_len])) causal = nn.emit(relax.op.reshape(relax.op.tril(relax.op.ones((seq_len, seq_len), dtype="bool")), (1, 1, seq_len, seq_len))) prefix = nn.emit(relax.op.reshape(prefix_mask, (-1, 1, 1, seq_len))) # TODO: logical_or on relax @@ -735,8 +747,9 @@ def _apply_sequence_id(self, attn_bias: relax.Expr, sequence_id: relax.Expr): seq_len = sequence_id.struct_info.shape[-1] if seq_len > self.config.max_seq_len: raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}') - # TODO: use split - attn_bias = attn_bias[..., :seq_len, :seq_len] + # slicing attn_bias[..., :seq_len, :seq_len] + dims_len = len(attn_bias.struct_info.shape) # TODO: rank? + attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [dims_len - 2, dims_len - 1], [relax.const(0), relax.const(0)], [seq_len, seq_len])) # TODO: logical_not on relax seq_id_l = nn.emit(relax.op.reshape(sequence_id, (-1, seq_len, 1))) seq_id_r = nn.emit(relax.op.reshape(sequence_id, (-1, 1, seq_len))) @@ -789,8 +802,10 @@ def forward( raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.') pos = nn.emit(relax.op.expand_dims(relax.op.arange(past_position, S + past_position, dtype="long"), axis=0)) if attention_mask is not None: - #TODO use split - pos_diff = nn.emit(relax.op.cumsum(tvm.tir.Cast("int32", tvm.tir.bitwise_not(attention_mask)), axis=1)[:, past_position:]) + pos_diff_to_slice = nn.emit(relax.op.cumsum(tvm.tir.Cast("int32", tvm.tir.bitwise_not(attention_mask)), axis=1)) + dim1_len = pos_diff_to_slice.struct_info.shape[1] + # slicing [:, past_position:] + pos_diff = nn.emit(relax.op.strided_slice(pos_diff_to_slice, [1], [past_position], [dim1_len])) pos = nn.emit(relax.op.clip(pos - pos_diff, min=0)) pos_emb = self.wpe(pos) x = tok_emb + pos_emb @@ -890,13 +905,16 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ if attention_mask[:, -1].sum() != attention_mask.shape[0]: raise NotImplementedError('MPT does not support generation with right padding.') if self.transformer.attn_uses_sequence_id and self.training: - # TODO: [:1] in Relax? 
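# All of the slicing rewrites in this patch reduce to one call,
# relax.op.strided_slice(tensor, axes, begin, end[, strides]), applied only to the axes
# actually being sliced. A minimal standalone sketch of the mapping (x, n and s are
# illustrative names, not part of the model code):
#
#     first = nn.emit(relax.op.strided_slice(x, [0], [relax.const(0)], [relax.const(1)]))      # x[:1]
#     tail  = nn.emit(relax.op.strided_slice(x, [1], [n - s], [n]))                            # x[:, -s:], n = x.struct_info.shape[1]
#     even  = nn.emit(relax.op.strided_slice(x, [0], [relax.const(0)], [n], [relax.const(2)])) # x[::2]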
- sequence_id = nn.emit(relax.op.zeros_like(input_ids[:1])) + # slicing input_ids[:1] + input_ids_slice = nn.emit(relax.op.strided_slice(input_ids, [0], [relax.const(0)], [relax.const(1)])) + sequence_id = nn.emit(relax.op.zeros_like(input_ids_slice)) else: sequence_id = None if past_key_values is not None: - # TODO: Relax implementation? - input_ids = nn.emit(relax.op.expand_dims(input_ids[:, -1], axis=-1)) + # slicing input_ids[:, -1] + dim1_len = input_ids.struct_info.shape[1] + input_ids_slice = nn.emit(relax.op.strided_slice(input_ids, [1], [dim1_len - 1], [dim1_len])) + input_ids = nn.emit(relax.op.expand_dims(input_ids_slice, axis=-1)) if self.transformer.prefix_lm: prefix_mask = nn.emit(relax.op.ones_like(attention_mask, self.dtype)) if kwargs.get('use_cache') == False: From 1b58d3ab430a85c1f47d192890782d6528d555f1 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 12 Jun 2023 09:25:02 +0300 Subject: [PATCH 031/114] add draft for create_decoding_func. Fix two TODOs related to rank --- mlc_llm/relax_model/mpt.py | 39 ++++++++++++++++++++++++++++++++------ 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 847255fbd6..95501a68ef 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -1,6 +1,6 @@ import math import warnings -from typing import Optional, Tuple, List +from typing import Optional, Tuple, List, Dict import numpy as np import tvm @@ -14,6 +14,7 @@ LayerNorm, Linear, ModuleList, + named_parameters, ) @@ -733,7 +734,7 @@ def _apply_prefix_mask(self, attn_bias: relax.Expr, prefix_mask: relax.Expr): if seq_len > self.config.max_seq_len: raise ValueError(f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}') # slicing attn_bias[..., :seq_len, :seq_len] - dims_len = len(attn_bias.struct_info.shape) # TODO: rank? + dims_len = attn_bias.struct_info.ndim attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [dims_len - 2, dims_len - 1], [relax.const(0), relax.const(0)], [seq_len, seq_len])) causal = nn.emit(relax.op.reshape(relax.op.tril(relax.op.ones((seq_len, seq_len), dtype="bool")), (1, 1, seq_len, seq_len))) prefix = nn.emit(relax.op.reshape(prefix_mask, (-1, 1, 1, seq_len))) @@ -748,7 +749,7 @@ def _apply_sequence_id(self, attn_bias: relax.Expr, sequence_id: relax.Expr): if seq_len > self.config.max_seq_len: raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}') # slicing attn_bias[..., :seq_len, :seq_len] - dims_len = len(attn_bias.struct_info.shape) # TODO: rank? 
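# Hedged note on the rank fix just below: struct_info.ndim reads the rank recorded in the
# expression's TensorStructInfo (a TVM name assumed here), so it does not depend on the
# shape expression itself being materialized. A tiny sketch:
#
#     assert isinstance(attn_bias.struct_info, relax.TensorStructInfo)
#     rank = attn_bias.struct_info.ndim   # e.g. 4 for a (b, h, s_q, s_k) bias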
+ dims_len = attn_bias.struct_info.ndim attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [dims_len - 2, dims_len - 1], [relax.const(0), relax.const(0)], [seq_len, seq_len])) # TODO: logical_not on relax seq_id_l = nn.emit(relax.op.reshape(sequence_id, (-1, seq_len, 1))) @@ -943,9 +944,35 @@ def _reorder_cache(past_key_values, beam_idx): return reordered_past -def create_encoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> None: - pass - +def create_decoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> Dict[int, str]: + pidx2pname: Dict[int, str] = {} + with bb.function("decode"): + model = MPTForCausalLM(config) + input_ids = nn.Placeholder((1, 1), dtype="int32", name="input_ids") + # Placeholder for compatibility to LLAMA + all_seq_len_shape = relax.Var("place_holder", R.Object()) + state = relax.Var("state", R.Tuple([R.Object()] * config.n_layers * 5)) + with bb.dataflow(): + logits, states = model(input_ids, state) + params = [ + input_ids, + all_seq_len_shape, + state, + ] + model.parameters() + + named_params = named_parameters(model) + for i, (name, param) in enumerate(named_params.items()): + pidx2pname[i] = name + assert param.same_as(params[i + 3]) + + gv = bb.emit_output((logits, relax.Tuple(states))) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var("decode") + bb.update_func(gv, mod[gv].with_attr("num_input", 3)) + + return pidx2pname def get_model(args, hf_config): from transformers import AutoModelForCausalLM # type: ignore[import] From 3bbd62c1ab5bd75c04cbe18b8c498223ed0a5976 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 12 Jun 2023 15:42:56 +0300 Subject: [PATCH 032/114] replace handmade masked_filled by relax op implemented in mlc-relax --- mlc_llm/relax_model/mpt.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 95501a68ef..3a2c875b86 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -18,12 +18,6 @@ ) -def masked_fill_relax(input, mask, value): - rx_value = relax.const(value) - values = nn.emit(relax.op.full_like(input, rx_value)) - return nn.emit(relax.op.where(mask, values, input)) - - def _cast_if_autocast_enabled(tensor): # # TODO: how to check device? # if tensor.device.type == 'cuda': @@ -131,7 +125,7 @@ def scaled_multihead_dot_product_attention( if attn_bias is not None: warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. 
Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.') key_mask = nn.emit(tvm.tir.bitwise_not(relax.op.reshape(key_padding_mask, (b, 1, 1, s_k)))) - attn_weight = masked_fill_relax(attn_weight, key_mask, min_val) + attn_weight = nn.emit(relax.op.masked_fill(attn_weight, key_mask, min_val)) if is_causal and (not q.struct_info.shape[2] == 1): s = relax.op.maximum(s_q, s_k) causal_mask = nn.emit(relax.op.ones((s, s,), dtype="float16")) @@ -142,7 +136,7 @@ def scaled_multihead_dot_product_attention( s_q_end, s_k_end = causal_mask.struct_info.shape causal_mask = nn.emit(relax.op.strided_slice(causal_mask, [0, 1], [s_q_end - s_q, s_k_end - s_k], [s_q_end, s_k_end])) causal_mask = nn.emit(relax.op.reshape(causal_mask, (1, 1, s_q, s_k))) - attn_weight = masked_fill_relax(attn_weight, causal_mask, min_val) + attn_weight = nn.emit(relax.op.masked_fill(attn_weight, causal_mask, min_val)) attn_weight = nn.emit(relax.op.nn.softmax(attn_weight)) out = nn.emit(relax.op.matmul(attn_weight, v)) out = reverse_reshape_and_permute(out) @@ -417,7 +411,7 @@ def triton_flash_attn_fn( if attn_bias is None: attn_bias = nn.emit(relax.op.zeros((b_size, 1, 1, s_k), dtype=query.struct_info.dtype)) key_mask = nn.emit(tvm.tir.bitwise_not(relax.op.reshape(key_padding_mask, (b_size, 1, 1, s_k)))) - attn_bias = masked_fill_relax(attn_bias, key_mask, tvm.tir.min_value(query.struct_info.dtype)) + attn_bias = nn.emit(relax.op.masked_fill(attn_bias, key_mask, tvm.tir.min_value(query.struct_info.dtype))) batch_size, seq_len, _ = query.struct_info.shape query = nn.emit(relax.op.reshape( @@ -723,7 +717,7 @@ def _attn_bias(self, dtype, attention_mask: Optional[relax.Expr]=None, prefix_ma raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.') min_val = tvm.tir.min_value(attn_bias.struct_info.dtype) attn_mask = nn.emit(tvm.tir.bitwise_not(relax.op.reshape(attention_mask, (-1, 1, 1, s_k)))) - attn_bias = masked_fill_relax(attn_bias, attn_mask, min_val) + attn_bias = nn.emit(relax.op.masked_fill(attn_bias, attn_mask, min_val)) return (attn_bias, None) def _apply_prefix_mask(self, attn_bias: relax.Expr, prefix_mask: relax.Expr): @@ -741,7 +735,7 @@ def _apply_prefix_mask(self, attn_bias: relax.Expr, prefix_mask: relax.Expr): # TODO: logical_or on relax cannot_attend = nn.emit(tvm.tir.bitwise_not(torch.logical_or(causal, tvm.tir.Cast("bool", prefix)))) min_val = tvm.tir.min_value(attn_bias.struct_info.dtype) - attn_bias = masked_fill_relax(attn_bias, cannot_attend, min_val) + attn_bias = nn.emit(relax.op.masked_fill(attn_bias, cannot_attend, min_val)) return attn_bias def _apply_sequence_id(self, attn_bias: relax.Expr, sequence_id: relax.Expr): @@ -756,7 +750,7 @@ def _apply_sequence_id(self, attn_bias: relax.Expr, sequence_id: relax.Expr): seq_id_r = nn.emit(relax.op.reshape(sequence_id, (-1, 1, seq_len))) cannot_attend = nn.emit(relax.op.expand_dims(torch.logical_not(relax.op.equal(seq_id_l, seq_id_r)), axis=1)) min_val = tvm.tir.min_value(attn_bias.struct_info.dtype) - attn_bias = masked_fill_relax(attn_bias, cannot_attend, min_val) + attn_bias = nn.emit(relax.op.masked_fill(attn_bias, cannot_attend, min_val)) return attn_bias def forward( From cc3b128102dce11431c1602f2786ad55a17e7da4 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 12 Jun 2023 15:45:27 +0300 Subject: [PATCH 033/114] replace torch logical_or by relax op implemented in mlc-relax --- mlc_llm/relax_model/mpt.py | 3 +-- 1 file changed, 1 
insertion(+), 2 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 3a2c875b86..cf37ccc121 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -732,8 +732,7 @@ def _apply_prefix_mask(self, attn_bias: relax.Expr, prefix_mask: relax.Expr): attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [dims_len - 2, dims_len - 1], [relax.const(0), relax.const(0)], [seq_len, seq_len])) causal = nn.emit(relax.op.reshape(relax.op.tril(relax.op.ones((seq_len, seq_len), dtype="bool")), (1, 1, seq_len, seq_len))) prefix = nn.emit(relax.op.reshape(prefix_mask, (-1, 1, 1, seq_len))) - # TODO: logical_or on relax - cannot_attend = nn.emit(tvm.tir.bitwise_not(torch.logical_or(causal, tvm.tir.Cast("bool", prefix)))) + cannot_attend = nn.emit(tvm.tir.bitwise_not(relax.op.logical_or(causal, tvm.tir.Cast("bool", prefix)))) min_val = tvm.tir.min_value(attn_bias.struct_info.dtype) attn_bias = nn.emit(relax.op.masked_fill(attn_bias, cannot_attend, min_val)) return attn_bias From ce460934c29c17416d21413f65a50d21c8890c6f Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 13 Jun 2023 14:38:09 +0300 Subject: [PATCH 034/114] replace torch logical_not by relax op implemented in mlc-relax --- mlc_llm/relax_model/mpt.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index cf37ccc121..5fd89379d2 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -744,10 +744,9 @@ def _apply_sequence_id(self, attn_bias: relax.Expr, sequence_id: relax.Expr): # slicing attn_bias[..., :seq_len, :seq_len] dims_len = attn_bias.struct_info.ndim attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [dims_len - 2, dims_len - 1], [relax.const(0), relax.const(0)], [seq_len, seq_len])) - # TODO: logical_not on relax seq_id_l = nn.emit(relax.op.reshape(sequence_id, (-1, seq_len, 1))) seq_id_r = nn.emit(relax.op.reshape(sequence_id, (-1, 1, seq_len))) - cannot_attend = nn.emit(relax.op.expand_dims(torch.logical_not(relax.op.equal(seq_id_l, seq_id_r)), axis=1)) + cannot_attend = nn.emit(relax.op.expand_dims(relax.op.logical_not(relax.op.equal(seq_id_l, seq_id_r)), axis=1)) min_val = tvm.tir.min_value(attn_bias.struct_info.dtype) attn_bias = nn.emit(relax.op.masked_fill(attn_bias, cannot_attend, min_val)) return attn_bias From 9daf6ac3717cbc0aec9a2f4e28579c95e42b0510 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 13 Jun 2023 14:53:23 +0300 Subject: [PATCH 035/114] fix TODO with index_select --- mlc_llm/relax_model/mpt.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 5fd89379d2..aec84a5105 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -931,8 +931,7 @@ def _reorder_cache(past_key_values, beam_idx): """ reordered_past = [] for layer_past in past_key_values: - # TODO: Relax implementation? 
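# relax.op.take(x, indices, axis) gathers slices of x along the given axis, much like
# numpy.take, so reordering cached keys/values for beam search becomes a take along the
# batch axis. Minimal sketch (illustrative usage; beam_idx is a 1-D tensor of beam indices):
#
#     reordered = nn.emit(relax.op.take(past_state, beam_idx, 0))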
- reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))] + reordered_past += [tuple((nn.emit(relax.op.take(past_state, beam_idx, 0)) for past_state in layer_past))] return reordered_past From 829442352b875354dd95f6b6064535363c428993 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 13 Jun 2023 15:08:06 +0300 Subject: [PATCH 036/114] small fixes --- mlc_llm/relax_model/mpt.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index aec84a5105..bf9d351d3b 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -490,7 +490,7 @@ def __init__( raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') self.out_proj = Linear(self.d_model, self.d_model) # TODO: Does field _is_residual exist? - self.out_proj._is_residual = True + # self.out_proj._is_residual = True def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False): qkv = self.Wqkv(x) @@ -983,7 +983,7 @@ def get_model(args, hf_config): config = MPTConfig(**hf_config, dtype=dtype) bb = relax.BlockBuilder() - create_encoding_func(bb, config) + create_decoding_func(bb, config) mod = bb.get() @@ -1003,7 +1003,6 @@ def get_model(args, hf_config): param_list = [param for _, param in hf_model.named_parameters()] for i, param in enumerate(param_list): - # TODO: dtype? what is about mix-precision? param_list[i] = tvm.nd.array( param.detach().cpu().numpy().astype(dtype), device ) From 498bdb874cb615be061c68977b76337af9c927d7 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 13 Jun 2023 15:58:37 +0300 Subject: [PATCH 037/114] remove backwards --- mlc_llm/relax_model/mpt.py | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index bf9d351d3b..9d75cb6808 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -167,19 +167,6 @@ def forward(ctx, input, indices): return torch.gather(rearrange(input, 'b ... -> b (...)'), 0, repeat(indices, 'z -> z d', d=second_dim)).reshape(-1, *other_shape) - @staticmethod - def backward(ctx, grad_output): - indices, = ctx.saved_tensors - assert grad_output.ndim >= 2 - other_shape = grad_output.shape[1:] - grad_output = rearrange(grad_output, 'b ... -> b (...)') - grad_input = torch.zeros([ctx.first_axis_dim, grad_output.shape[1]], - device=grad_output.device, dtype=grad_output.dtype) - # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. - # grad_input[indices] = grad_output - grad_input.scatter_(0, repeat(indices, 'z -> z d', d=grad_output.shape[1]), grad_output) - return grad_input.reshape(ctx.first_axis_dim, *other_shape), None - index_first_axis = IndexFirstAxis.apply @@ -198,14 +185,6 @@ def forward(ctx, values, indices, first_axis_dim): # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) return output - @staticmethod - def backward(ctx, grad_output): - indices, = ctx.saved_tensors - # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. 
- grad_values = grad_output[indices] - # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) - return grad_values, None, None - index_put_first_axis = IndexPutFirstAxis.apply @@ -276,17 +255,6 @@ def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k ctx.deterministic = deterministic return out if not return_softmax else (out, softmax_lse, S_dmask) - @staticmethod - def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors - dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) - _flash_attn_backward( - dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, - ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal, - rng_state=rng_state, num_splits=1 if ctx.deterministic else 0, - ) - return dq, dk, dv, None, None, None, None, None, None, None, None, None - def flash_attn_interface_flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=None, causal=False, return_attn_probs=False, From 0f5f64f5bbbde163a0b1fb953346c4952f83fa91 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 13 Jun 2023 16:04:21 +0300 Subject: [PATCH 038/114] zone different types of flash attention implementation --- mlc_llm/relax_model/mpt.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 9d75cb6808..4cf021501e 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -80,6 +80,8 @@ def reverse_reshape_and_permute(hidden_states: relax.Expr): )) +######################### FLASH ATTENTION IMPLEMENTATION TYPE TORCH (BEGIN) ########################## + def scaled_multihead_dot_product_attention( query, key, @@ -144,6 +146,8 @@ def scaled_multihead_dot_product_attention( return (out, attn_weight, past_key_value) return (out, None, past_key_value) +######################### FLASH ATTENTION IMPLEMENTATION TYPE TORCH (END) ########################## + def check_valid_inputs(*tensors, valid_dtypes=["float16", "bfloat16"]): for tensor in tensors: @@ -154,6 +158,8 @@ def check_valid_inputs(*tensors, valid_dtypes=["float16", "bfloat16"]): # raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).') +######################### FLASH ATTENTION IMPLEMENTATION TYPE FLASH (BEGIN) ########################## + class IndexFirstAxis(torch.autograd.Function): @staticmethod @@ -333,6 +339,10 @@ def flash_attn_fn( output = bert_padding_pad_input(output_unpad, indices_q, batch_size, seqlen) return (output, None, past_key_value) +######################### FLASH ATTENTION IMPLEMENTATION TYPE FLASH (END) ########################## + + +######################### FLASH ATTENTION IMPLEMENTATION TYPE TRITON (BEGIN) ########################## def triton_flash_attn_fn( query, @@ -409,6 +419,8 @@ def triton_flash_attn_fn( )) # (b, s, h, d)) -> (b, s, (h d)) return (output, None, past_key_value) +######################### FLASH ATTENTION IMPLEMENTATION TYPE TRITON (END) ########################## + class MultiheadAttention(nn.Module): """Multi-head self attention. 
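The three banner-delimited zones introduced above mirror the three values of attn_config['attn_impl'] that MultiheadAttention dispatches on in its __init__; the next patch then comments out the flash and triton bodies, leaving only the pure-Relax 'torch' path active. A condensed sketch of that dispatch (pick_attn_fn is an illustrative wrapper name; the three callees are the functions defined in this file):

    def pick_attn_fn(attn_impl: str):
        # Mirrors the selection in MultiheadAttention.__init__ shown in the patches above.
        if attn_impl == 'flash':
            return flash_attn_fn                            # unrolled flash-attn sources
        elif attn_impl == 'triton':
            return triton_flash_attn_fn                     # triton kernel wrapper
        elif attn_impl == 'torch':
            return scaled_multihead_dot_product_attention   # implemented with Relax ops
        raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')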
From 03426f645cb2ba1c3f24daa47caba3fdc17b3189 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 13 Jun 2023 16:06:22 +0300 Subject: [PATCH 039/114] commented flash attention implementations with types flash and triton --- mlc_llm/relax_model/mpt.py | 504 ++++++++++++++++++------------------- 1 file changed, 252 insertions(+), 252 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 4cf021501e..97344f5551 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -160,264 +160,264 @@ def check_valid_inputs(*tensors, valid_dtypes=["float16", "bfloat16"]): ######################### FLASH ATTENTION IMPLEMENTATION TYPE FLASH (BEGIN) ########################## -class IndexFirstAxis(torch.autograd.Function): - - @staticmethod - def forward(ctx, input, indices): - ctx.save_for_backward(indices) - assert input.ndim >= 2 - ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] - second_dim = other_shape.numel() - # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. - # return input[indices] - return torch.gather(rearrange(input, 'b ... -> b (...)'), 0, - repeat(indices, 'z -> z d', d=second_dim)).reshape(-1, *other_shape) - - -index_first_axis = IndexFirstAxis.apply - -import torch -class IndexPutFirstAxis(torch.autograd.Function): - - @staticmethod - def forward(ctx, values, indices, first_axis_dim): - ctx.save_for_backward(indices) - assert indices.ndim == 1 - assert values.ndim >= 2 - output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, - dtype=values.dtype) - # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. - output[indices] = values - # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) - return output - - -index_put_first_axis = IndexPutFirstAxis.apply - - -def bert_padding_unpad_input(hidden_states, attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) - # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the - # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim - # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to - # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, - # so we write custom forward and backward to make it a bit faster. - return (index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'), indices), indices, - cu_seqlens, max_seqlen_in_batch) - - -def bert_padding_pad_input(hidden_states, indices, batch, seqlen): - """ - Arguments: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. - indices: (total_nnz) - Return: - hidden_states: (batch, seqlen, ...) - """ - dim = hidden_states.shape[-1] - # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) - # output[indices] = hidden_states - output = index_put_first_axis(hidden_states, indices, batch * seqlen) - return rearrange(output, '(b s) ... 
-> b s ...', b=batch) - -import flash_attn_cuda -def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale, causal, return_softmax, num_splits=0, - generator=None): - """ - num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means - it will be set by an internal heuristic. We're exposing num_splits mostly for benchmarking. - Don't change it unless you know what you're doing. - """ - softmax_lse, rng_state, *rest = flash_attn_cuda.fwd( - q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, - softmax_scale, False, causal, return_softmax, num_splits, generator - ) - # if out.isnan().any() or softmax_lse.isnan().any(): - # breakpoint() - S_dmask = rest[0] if return_softmax else None - return out, softmax_lse, rng_state, S_dmask - - -class FlashAttnFunc(torch.autograd.Function): - @staticmethod - def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, - softmax_scale, causal, return_softmax, deterministic): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - out, softmax_lse, rng_state, S_dmask = _flash_attn_forward( - q, k, v, torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax - ) - ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state) - ctx.dropout_p = dropout_p - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.deterministic = deterministic - return out if not return_softmax else (out, softmax_lse, S_dmask) - - -def flash_attn_interface_flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale=None, causal=False, return_attn_probs=False, - deterministic=False -): - return FlashAttnFunc.apply(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale, causal, return_attn_probs, deterministic) - - -def flash_attn_fn( - query, - key, - value, - n_heads, - d_model, - past_key_value=None, - softmax_scale=None, - attn_bias=None, - key_padding_mask=None, - is_causal=False, - needs_weights=False, - multiquery=False -): - check_valid_inputs(query, key, value) - if past_key_value is not None: - if len(past_key_value) != 0: - key = nn.emit(relax.op.concat([past_key_value[0], key], axis=1)) - value = nn.emit(relax.op.concat([past_key_value[1], value], axis=1)) - past_key_value = (key, value) - if attn_bias is not None: - _s_q = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[2] - query.struct_info.shape[1]) - _s_k = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[3] - key.struct_info.shape[1]) - # slicing attn_bias[:, :, _s_q:, _s_k:] - s_q_end, s_k_end = attn_bias.struct_info.shape[-2:] - attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [2, 3], [_s_q, _s_k], [s_q_end, s_k_end])) - if attn_bias is not None: - raise NotImplementedError(f'attn_bias not implemented for flash attn.') - batch_size, seqlen = query.struct_info.shape[:2] - if key_padding_mask is None: - key_shape = key.struct_info.shape[:2] - key_padding_mask = nn.emit(relax.op.ones(Tuple(*key_shape, 1), dtype="bool")) - # slicing key_padding_mask[:, -query.struct_info.shape[1]:] - dim1_length = key_padding_mask.struct_info.shape[1] - query_padding_mask = nn.emit(relax.op.strided_slice(key_padding_mask, [1], [dim1_length - seqlen], 
[dim1_length])) - (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding_unpad_input(query, query_padding_mask) - - qnnz, _, _ = query_unpad.struct_info.shape - query_unpad = nn.emit(relax.op.reshape( - query_unpad, - (qnnz, n_heads, d_model), - )) # (nnz, (h d)) -> (nnz, h, d) - - kv_nnz, _, _ = key_unpad.struct_info.shape - kv_n_heads = 1 if multiquery else n_heads - (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding_unpad_input(key, key_padding_mask) - key_unpad = nn.emit(relax.op.reshape( - key_unpad, - (kv_nnz, kv_n_heads, d_model), - )) # (nnz, (h d)) -> (nnz, h, d) - (value_unpad, _, _, _) = bert_padding_unpad_input(value, key_padding_mask) - value_unpad = nn.emit(relax.op.reshape( - value_unpad, - (kv_nnz, kv_n_heads, d_model), - )) # (nnz, (h d)) -> (nnz, h, d) - - if multiquery: - key_unpad = relax.op.broadcast_to(key_unpad, (kv_nnz, n_heads, d_model)) - value_unpad = relax.op.broadcast_to(value_unpad, (kv_nnz, n_heads, d_model)) - reset_is_causal = _reset_is_causal(query.struct_info.shape[1], key.struct_info.shape[1], is_causal) - output_unpad = flash_attn_interface_flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, 0.0, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights) - nnz, _, _ = output_unpad.struct_info.shape - output_unpad = nn.emit(relax.op.reshape( - output_unpad, - (nnz, n_heads*d_model), - )) # (nnz, h, d)) -> (nnz, (h d)) - output = bert_padding_pad_input(output_unpad, indices_q, batch_size, seqlen) - return (output, None, past_key_value) +# class IndexFirstAxis(torch.autograd.Function): + +# @staticmethod +# def forward(ctx, input, indices): +# ctx.save_for_backward(indices) +# assert input.ndim >= 2 +# ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] +# second_dim = other_shape.numel() +# # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. +# # return input[indices] +# return torch.gather(rearrange(input, 'b ... -> b (...)'), 0, +# repeat(indices, 'z -> z d', d=second_dim)).reshape(-1, *other_shape) + + +# index_first_axis = IndexFirstAxis.apply + +# import torch +# class IndexPutFirstAxis(torch.autograd.Function): + +# @staticmethod +# def forward(ctx, values, indices, first_axis_dim): +# ctx.save_for_backward(indices) +# assert indices.ndim == 1 +# assert values.ndim >= 2 +# output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, +# dtype=values.dtype) +# # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. +# output[indices] = values +# # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) +# return output + + +# index_put_first_axis = IndexPutFirstAxis.apply + + +# def bert_padding_unpad_input(hidden_states, attention_mask): +# seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) +# indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() +# max_seqlen_in_batch = seqlens_in_batch.max().item() +# cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) +# # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the +# # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim +# # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to +# # index with integer indices. 
Moreover, torch's index is a bit slower than it needs to be, +# # so we write custom forward and backward to make it a bit faster. +# return (index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'), indices), indices, +# cu_seqlens, max_seqlen_in_batch) + + +# def bert_padding_pad_input(hidden_states, indices, batch, seqlen): +# """ +# Arguments: +# hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. +# indices: (total_nnz) +# Return: +# hidden_states: (batch, seqlen, ...) +# """ +# dim = hidden_states.shape[-1] +# # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) +# # output[indices] = hidden_states +# output = index_put_first_axis(hidden_states, indices, batch * seqlen) +# return rearrange(output, '(b s) ... -> b s ...', b=batch) + +# import flash_attn_cuda +# def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, +# dropout_p, softmax_scale, causal, return_softmax, num_splits=0, +# generator=None): +# """ +# num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means +# it will be set by an internal heuristic. We're exposing num_splits mostly for benchmarking. +# Don't change it unless you know what you're doing. +# """ +# softmax_lse, rng_state, *rest = flash_attn_cuda.fwd( +# q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, +# softmax_scale, False, causal, return_softmax, num_splits, generator +# ) +# # if out.isnan().any() or softmax_lse.isnan().any(): +# # breakpoint() +# S_dmask = rest[0] if return_softmax else None +# return out, softmax_lse, rng_state, S_dmask + + +# class FlashAttnFunc(torch.autograd.Function): +# @staticmethod +# def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, +# softmax_scale, causal, return_softmax, deterministic): +# if softmax_scale is None: +# softmax_scale = q.shape[-1] ** (-0.5) +# out, softmax_lse, rng_state, S_dmask = _flash_attn_forward( +# q, k, v, torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, +# dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax +# ) +# ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state) +# ctx.dropout_p = dropout_p +# ctx.max_seqlen_q = max_seqlen_q +# ctx.max_seqlen_k = max_seqlen_k +# ctx.softmax_scale = softmax_scale +# ctx.causal = causal +# ctx.deterministic = deterministic +# return out if not return_softmax else (out, softmax_lse, S_dmask) + + +# def flash_attn_interface_flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, +# dropout_p, softmax_scale=None, causal=False, return_attn_probs=False, +# deterministic=False +# ): +# return FlashAttnFunc.apply(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, +# dropout_p, softmax_scale, causal, return_attn_probs, deterministic) + + +# def flash_attn_fn( +# query, +# key, +# value, +# n_heads, +# d_model, +# past_key_value=None, +# softmax_scale=None, +# attn_bias=None, +# key_padding_mask=None, +# is_causal=False, +# needs_weights=False, +# multiquery=False +# ): +# check_valid_inputs(query, key, value) +# if past_key_value is not None: +# if len(past_key_value) != 0: +# key = nn.emit(relax.op.concat([past_key_value[0], key], axis=1)) +# value = nn.emit(relax.op.concat([past_key_value[1], value], axis=1)) +# past_key_value = (key, value) +# if attn_bias is not None: +# _s_q = 
relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[2] - query.struct_info.shape[1]) +# _s_k = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[3] - key.struct_info.shape[1]) +# # slicing attn_bias[:, :, _s_q:, _s_k:] +# s_q_end, s_k_end = attn_bias.struct_info.shape[-2:] +# attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [2, 3], [_s_q, _s_k], [s_q_end, s_k_end])) +# if attn_bias is not None: +# raise NotImplementedError(f'attn_bias not implemented for flash attn.') +# batch_size, seqlen = query.struct_info.shape[:2] +# if key_padding_mask is None: +# key_shape = key.struct_info.shape[:2] +# key_padding_mask = nn.emit(relax.op.ones(Tuple(*key_shape, 1), dtype="bool")) +# # slicing key_padding_mask[:, -query.struct_info.shape[1]:] +# dim1_length = key_padding_mask.struct_info.shape[1] +# query_padding_mask = nn.emit(relax.op.strided_slice(key_padding_mask, [1], [dim1_length - seqlen], [dim1_length])) +# (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding_unpad_input(query, query_padding_mask) + +# qnnz, _, _ = query_unpad.struct_info.shape +# query_unpad = nn.emit(relax.op.reshape( +# query_unpad, +# (qnnz, n_heads, d_model), +# )) # (nnz, (h d)) -> (nnz, h, d) + +# kv_nnz, _, _ = key_unpad.struct_info.shape +# kv_n_heads = 1 if multiquery else n_heads +# (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding_unpad_input(key, key_padding_mask) +# key_unpad = nn.emit(relax.op.reshape( +# key_unpad, +# (kv_nnz, kv_n_heads, d_model), +# )) # (nnz, (h d)) -> (nnz, h, d) +# (value_unpad, _, _, _) = bert_padding_unpad_input(value, key_padding_mask) +# value_unpad = nn.emit(relax.op.reshape( +# value_unpad, +# (kv_nnz, kv_n_heads, d_model), +# )) # (nnz, (h d)) -> (nnz, h, d) + +# if multiquery: +# key_unpad = relax.op.broadcast_to(key_unpad, (kv_nnz, n_heads, d_model)) +# value_unpad = relax.op.broadcast_to(value_unpad, (kv_nnz, n_heads, d_model)) +# reset_is_causal = _reset_is_causal(query.struct_info.shape[1], key.struct_info.shape[1], is_causal) +# output_unpad = flash_attn_interface_flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, 0.0, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights) +# nnz, _, _ = output_unpad.struct_info.shape +# output_unpad = nn.emit(relax.op.reshape( +# output_unpad, +# (nnz, n_heads*d_model), +# )) # (nnz, h, d)) -> (nnz, (h d)) +# output = bert_padding_pad_input(output_unpad, indices_q, batch_size, seqlen) +# return (output, None, past_key_value) ######################### FLASH ATTENTION IMPLEMENTATION TYPE FLASH (END) ########################## ######################### FLASH ATTENTION IMPLEMENTATION TYPE TRITON (BEGIN) ########################## -def triton_flash_attn_fn( - query, - key, - value, - n_heads, - d_model, - past_key_value=None, - softmax_scale=None, - attn_bias=None, - key_padding_mask=None, - is_causal=False, - needs_weights=False, - multiquery=False): - try: - from .flash_attn_triton import flash_attn_func - except: - _installed = False - if version.parse(torch.__version__) < version.parse('2.0.0'): - _installed = True - try: - from flash_attn.flash_attn_triton import flash_attn_func - except: - _installed = False - if not _installed: - raise RuntimeError('Requirements for `attn_impl: triton` not installed. 
Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.') - check_valid_inputs(query, key, value) - if past_key_value is not None: - if len(past_key_value) != 0: - key = nn.emit(relax.op.concat([past_key_value[0], key], axis=1)) - value = nn.emit(relax.op.concat([past_key_value[1], value], axis=1)) - past_key_value = (key, value) - if attn_bias is not None: - _s_q = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[2] - query.struct_info.shape[1]) - _s_k = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[3] - key.struct_info.shape[1]) - # slicing attn_bias[:, :, _s_q:, _s_k:] - s_q_end, s_k_end = attn_bias.struct_info.shape[-2:] - attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [2, 3], [_s_q, _s_k], [s_q_end, s_k_end])) - if needs_weights: - raise NotImplementedError(f'attn_impl: triton cannot return attn weights.') - if key_padding_mask is not None: - warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.') - (b_size, s_k) = key_padding_mask.struct_info.shape[:2] - if attn_bias is None: - attn_bias = nn.emit(relax.op.zeros((b_size, 1, 1, s_k), dtype=query.struct_info.dtype)) - key_mask = nn.emit(tvm.tir.bitwise_not(relax.op.reshape(key_padding_mask, (b_size, 1, 1, s_k)))) - attn_bias = nn.emit(relax.op.masked_fill(attn_bias, key_mask, tvm.tir.min_value(query.struct_info.dtype))) - - batch_size, seq_len, _ = query.struct_info.shape - query = nn.emit(relax.op.reshape( - query, - (batch_size, seq_len, n_heads, d_model), - )) # b s (h d) -> b s h d - - batch_size, seq_len, _ = key.struct_info.shape - kv_n_heads = 1 if multiquery else n_heads - key = nn.emit(relax.op.reshape( - key, - (batch_size, seq_len, kv_n_heads, d_model), - )) # b s (h d) -> b s h d - value = nn.emit(relax.op.reshape( - value, - (batch_size, seq_len, kv_n_heads, d_model), - )) # b s (h d) -> b s h d - if multiquery: - key = relax.op.broadcast_to(key, (batch_size, seq_len, n_heads, d_model)) - value = relax.op.broadcast_to(value, (batch_size, seq_len, n_heads, d_model)) - reset_is_causal = _reset_is_causal(query.struct_info.shape[1], key.struct_info.shape[1], is_causal) - attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale) - batch_size, seq_len, _, _ = attn_output.struct_info.shape - output = nn.emit(relax.op.reshape( - attn_output, - (batch_size, seq_len, n_heads*d_model), - )) # (b, s, h, d)) -> (b, s, (h d)) - return (output, None, past_key_value) +# def triton_flash_attn_fn( +# query, +# key, +# value, +# n_heads, +# d_model, +# past_key_value=None, +# softmax_scale=None, +# attn_bias=None, +# key_padding_mask=None, +# is_causal=False, +# needs_weights=False, +# multiquery=False): +# try: +# from .flash_attn_triton import flash_attn_func +# except: +# _installed = False +# if version.parse(torch.__version__) < version.parse('2.0.0'): +# _installed = True +# try: +# from flash_attn.flash_attn_triton import flash_attn_func +# except: +# _installed = False +# if not _installed: +# raise RuntimeError('Requirements 
for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.') +# check_valid_inputs(query, key, value) +# if past_key_value is not None: +# if len(past_key_value) != 0: +# key = nn.emit(relax.op.concat([past_key_value[0], key], axis=1)) +# value = nn.emit(relax.op.concat([past_key_value[1], value], axis=1)) +# past_key_value = (key, value) +# if attn_bias is not None: +# _s_q = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[2] - query.struct_info.shape[1]) +# _s_k = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[3] - key.struct_info.shape[1]) +# # slicing attn_bias[:, :, _s_q:, _s_k:] +# s_q_end, s_k_end = attn_bias.struct_info.shape[-2:] +# attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [2, 3], [_s_q, _s_k], [s_q_end, s_k_end])) +# if needs_weights: +# raise NotImplementedError(f'attn_impl: triton cannot return attn weights.') +# if key_padding_mask is not None: +# warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.') +# (b_size, s_k) = key_padding_mask.struct_info.shape[:2] +# if attn_bias is None: +# attn_bias = nn.emit(relax.op.zeros((b_size, 1, 1, s_k), dtype=query.struct_info.dtype)) +# key_mask = nn.emit(tvm.tir.bitwise_not(relax.op.reshape(key_padding_mask, (b_size, 1, 1, s_k)))) +# attn_bias = nn.emit(relax.op.masked_fill(attn_bias, key_mask, tvm.tir.min_value(query.struct_info.dtype))) + +# batch_size, seq_len, _ = query.struct_info.shape +# query = nn.emit(relax.op.reshape( +# query, +# (batch_size, seq_len, n_heads, d_model), +# )) # b s (h d) -> b s h d + +# batch_size, seq_len, _ = key.struct_info.shape +# kv_n_heads = 1 if multiquery else n_heads +# key = nn.emit(relax.op.reshape( +# key, +# (batch_size, seq_len, kv_n_heads, d_model), +# )) # b s (h d) -> b s h d +# value = nn.emit(relax.op.reshape( +# value, +# (batch_size, seq_len, kv_n_heads, d_model), +# )) # b s (h d) -> b s h d +# if multiquery: +# key = relax.op.broadcast_to(key, (batch_size, seq_len, n_heads, d_model)) +# value = relax.op.broadcast_to(value, (batch_size, seq_len, n_heads, d_model)) +# reset_is_causal = _reset_is_causal(query.struct_info.shape[1], key.struct_info.shape[1], is_causal) +# attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale) +# batch_size, seq_len, _, _ = attn_output.struct_info.shape +# output = nn.emit(relax.op.reshape( +# attn_output, +# (batch_size, seq_len, n_heads*d_model), +# )) # (b, s, h, d)) -> (b, s, (h d)) +# return (output, None, past_key_value) ######################### FLASH ATTENTION IMPLEMENTATION TYPE TRITON (END) ########################## From 2b8bbbfc4a69f99f27ac23552502f6aa16682c6b Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 14 Jun 2023 10:10:58 +0300 Subject: [PATCH 040/114] some fixes --- mlc_llm/relax_model/mpt.py | 9 ++++----- mlc_llm/utils.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 97344f5551..2e074013d6 
100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -527,7 +527,6 @@ def __init__(self, config: MPTConfig): clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], - attn_pdrop=attn_config['attn_pdrop'], ) self.mlp = MPTMLP( hidden_size=self.hidden_size, @@ -861,9 +860,9 @@ def forward( output_hidden_states=output_hidden_states, use_cache=use_cache ) - logits = nn.emit(relax.op.matmul(outputs.last_hidden_state, self.transformer.wte.weight)) + logits = nn.emit(relax.op.matmul(outputs[0], self.transformer.wte.weight)) - return logits, outputs.past_key_values + return logits, outputs[1] def fsdp_wrap_fn(self, module): return isinstance(module, MPTBlock) @@ -956,8 +955,8 @@ def get_model(args, hf_config): # Recommendation from https://huggingface.co/mosaicml/mpt-7b-instruct max_seq_len = args.max_seq_len if args.max_seq_len is not None else 4096 # 4096 recommended - config.update({"max_seq_len": max_seq_len}) - config.update({"max_new_tokens": args.seq_len}) + hf_config.update({"max_seq_len": max_seq_len}) + # hf_config.update({"max_new_tokens": args.seq_len}) if model_name.startswith("mpt-"): config = MPTConfig(**hf_config, dtype=dtype) diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py index 41519aec01..c653ee6fb0 100644 --- a/mlc_llm/utils.py +++ b/mlc_llm/utils.py @@ -78,7 +78,7 @@ def argparse_postproc_common(args: argparse.Namespace) -> None: "gorilla-": ("gorilla", "llama"), "starcoder": ("code_gpt", "gpt_bigcode"), "wizardcoder-": ("code_gpt", "gpt_bigcode"), - "mpt-": ("mpt", "mpt-7b", "mpt-7b-instruct"), + "mpt-": ("mpt", "mpt"), } model = args.model.lower() for prefix, (conv_template, model_category) in supported_model_prefix.items(): From b1e34bf7be67e71b2a2e578d0707cc0572c6bc94 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 14 Jun 2023 10:35:05 +0300 Subject: [PATCH 041/114] fix dtype in Linear layers --- mlc_llm/relax_model/mpt.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 2e074013d6..42f3aa1788 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -432,6 +432,7 @@ def __init__( self, d_model: int, n_heads: int, + dtype: str, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, @@ -448,7 +449,7 @@ def __init__( if self.softmax_scale is None: self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) - self.Wqkv = Linear(self.d_model, 3 * self.d_model) + self.Wqkv = Linear(self.d_model, 3 * self.d_model, dtype) fuse_splits = (d_model, 2 * d_model) self.Wqkv._fused = (0, fuse_splits) if self.qk_ln: @@ -456,19 +457,21 @@ def __init__( self.q_ln = layernorm_class(self.d_model) self.k_ln = layernorm_class(self.d_model) if self.attn_impl == 'flash': - self.attn_fn = flash_attn_fn + raise NotImplemented("Flash type of flash attention has not been implemented yet") + # self.attn_fn = flash_attn_fn elif self.attn_impl == 'triton': # While `attn_impl: triton` can be faster than `attn_impl: flash` it uses more memory. # When training larger models this can trigger alloc retries which hurts performance. # If encountered, we recommend using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`. - self.attn_fn = triton_flash_attn_fn + raise NotImplemented("Triton type of flash attention has not been implemented yet") + # self.attn_fn = triton_flash_attn_fn elif self.attn_impl == 'torch': # Using `attn_impl: torch`. 
If your model does not use `alibi` or `prefix_lm` we recommend using `attn_impl: flash` # otherwise we recommend using `attn_impl: triton`. self.attn_fn = scaled_multihead_dot_product_attention else: raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') - self.out_proj = Linear(self.d_model, self.d_model) + self.out_proj = Linear(self.d_model, self.d_model, dtype) # TODO: Does field _is_residual exist? # self.out_proj._is_residual = True @@ -523,6 +526,7 @@ def __init__(self, config: MPTConfig): self.self_attn = attn_class( d_model=self.hidden_size, n_heads=config.n_heads, + dtype=config.dtype, attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], From 10c7cd6a73204002eb8a5b621280f6eb865e0bfe Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 14 Jun 2023 10:40:49 +0300 Subject: [PATCH 042/114] fix dtype in layer norm --- mlc_llm/relax_model/mpt.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 42f3aa1788..b6faf95c8f 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -31,7 +31,7 @@ def _cast_if_autocast_enabled(tensor): # Low-precision layer norm for mpt-7b-instruct, where are no biases expected class LPLayerNormWOBias(nn.Module): - def __init__(self, normalized_shape, eps=1e-05, dtype=None): + def __init__(self, normalized_shape, dtype, eps=1e-05): self.weight = nn.Parameter((normalized_shape,), dtype=dtype, name="low_precision_layernorm_weight") # TODO: check default filling of weights self.weight = relax.op.ones((normalized_shape,), dtype) @@ -454,8 +454,8 @@ def __init__( self.Wqkv._fused = (0, fuse_splits) if self.qk_ln: layernorm_class = LPLayerNormWOBias if low_precision_layernorm else LayerNorm - self.q_ln = layernorm_class(self.d_model) - self.k_ln = layernorm_class(self.d_model) + self.q_ln = layernorm_class(self.d_model, dtype) + self.k_ln = layernorm_class(self.d_model, dtype) if self.attn_impl == 'flash': raise NotImplemented("Flash type of flash attention has not been implemented yet") # self.attn_fn = flash_attn_fn @@ -537,8 +537,8 @@ def __init__(self, config: MPTConfig): intermediate_size=config.expansion_ratio*self.hidden_size, dtype=config.dtype, ) - self.input_layernorm = norm_class(self.hidden_size) - self.post_attention_layernorm = norm_class(self.hidden_size) + self.input_layernorm = norm_class(self.hidden_size, config.dtype) + self.post_attention_layernorm = norm_class(self.hidden_size, config.dtype) def forward( self, From fcd8d7d44aa7980996efd206c942732d6c7c2781 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 14 Jun 2023 10:45:52 +0300 Subject: [PATCH 043/114] fix config using --- mlc_llm/relax_model/mpt.py | 47 +++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index b6faf95c8f..006b9e3472 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -636,12 +636,18 @@ def __init__(self, config: MPTConfig): self.alibi_bias_max = config.attn_config['alibi_bias_max'] self.is_causal = not self.prefix_lm + self.n_heads = config.n_heads + self.n_layers = config.n_layers + self.max_seq_len = config.max_seq_len + self.return_dict = config.return_dict + self.use_cache = config.use_cache + self._attn_bias_initialized = False self.attn_bias = None self.attn_bias_shape = attn_bias_shape( self.attn_impl, - config.n_heads, - config.max_seq_len, + self.n_heads, + 
self.max_seq_len, self.alibi, prefix_lm=self.prefix_lm, causal=self.is_causal, @@ -672,7 +678,7 @@ def _attn_bias(self, dtype, attention_mask: Optional[relax.Expr]=None, prefix_ma if self.attn_bias_shape: self.attn_bias = nn.emit(relax.op.zeros(self.attn_bias_shape, dtype=dtype)) self.attn_bias = build_attn_bias( - self.attn_impl, self.attn_bias, self.config.n_heads, self.config.max_seq_len, causal=self.is_causal, alibi=self.alibi, alibi_bias_max=self.alibi_bias_max + self.attn_impl, self.attn_bias, self.n_heads, self.max_seq_len, causal=self.is_causal, alibi=self.alibi, alibi_bias_max=self.alibi_bias_max ) self._attn_bias_initialized = True if self.attn_impl == 'flash': @@ -705,11 +711,11 @@ def _attn_bias(self, dtype, attention_mask: Optional[relax.Expr]=None, prefix_ma def _apply_prefix_mask(self, attn_bias: relax.Expr, prefix_mask: relax.Expr): (s_k, s_q) = attn_bias.struct_info.shape[-2:] - if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len: - raise ValueError('attn_bias does not match the expected shape. ' + f'The last two dimensions should both be {self.config.max_length} ' + f'but are {s_k} and {s_q}.') + if s_k != self.max_seq_len or s_q != self.max_seq_len: + raise ValueError('attn_bias does not match the expected shape. ' + f'The last two dimensions should both be {self.max_seq_len} ' + f'but are {s_k} and {s_q}.') seq_len = prefix_mask.struct_info.shape[-1] - if seq_len > self.config.max_seq_len: - raise ValueError(f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}') + if seq_len > self.max_seq_len: + raise ValueError(f'prefix_mask sequence length cannot exceed max_seq_len={self.max_seq_len}') # slicing attn_bias[..., :seq_len, :seq_len] dims_len = attn_bias.struct_info.ndim attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [dims_len - 2, dims_len - 1], [relax.const(0), relax.const(0)], [seq_len, seq_len])) @@ -722,8 +728,8 @@ def _apply_prefix_mask(self, attn_bias: relax.Expr, prefix_mask: relax.Expr): def _apply_sequence_id(self, attn_bias: relax.Expr, sequence_id: relax.Expr): seq_len = sequence_id.struct_info.shape[-1] - if seq_len > self.config.max_seq_len: - raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}') + if seq_len > self.max_seq_len: + raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.max_seq_len}') # slicing attn_bias[..., :seq_len, :seq_len] dims_len = attn_bias.struct_info.ndim attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [dims_len - 2, dims_len - 1], [relax.const(0), relax.const(0)], [seq_len, seq_len])) @@ -746,8 +752,8 @@ def forward( output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None ): - return_dict = return_dict if return_dict is not None else self.config.return_dict - use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.return_dict + use_cache = use_cache if use_cache is not None else self.use_cache if attention_mask is not None: attention_mask = nn.emit(tvm.tir.Cast("bool", attention_mask)) if prefix_mask is not None: @@ -761,7 +767,7 @@ def forward( raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.') S = input_ids.struct_info.shape[1] - assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' + assert S <= self.max_seq_len, f'Cannot forward input with seq_len={S}, this model only 
supports seq_len<={self.max_seq_len}' tok_emb = self.wte(input_ids) if self.alibi: @@ -769,13 +775,13 @@ def forward( else: past_position = 0 if past_key_values is not None: - if len(past_key_values) != self.config.n_layers: - raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).') + if len(past_key_values) != self.n_layers: + raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.n_layers!r}).') past_position = past_key_values[0][0].struct_info.shape[1] if self.attn_impl == 'torch': past_position = past_key_values[0][0].struct_info.shape[3] - if S + past_position > self.config.max_seq_len: - raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.') + if S + past_position > self.max_seq_len: + raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.max_seq_len}.') pos = nn.emit(relax.op.expand_dims(relax.op.arange(past_position, S + past_position, dtype="long"), axis=0)) if attention_mask is not None: pos_diff_to_slice = nn.emit(relax.op.cumsum(tvm.tir.Cast("int32", tvm.tir.bitwise_not(attention_mask)), axis=1)) @@ -787,7 +793,7 @@ def forward( x = tok_emb + pos_emb (attn_bias, attention_mask) = self._attn_bias(dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id) if use_cache and past_key_values is None: - past_key_values = [() for _ in range(self.config.n_layers)] + past_key_values = [() for _ in range(self.n_layers)] all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None for (b_idx, block) in enumerate(self.blocks): @@ -821,6 +827,9 @@ def __init__(self, config: MPTConfig): self.transformer = MPTModel(config) self.dtype = config.dtype + self.return_dict = config.return_dict + self.use_cache = config.use_cache + def get_input_embeddings(self): return self.transformer.wte @@ -851,8 +860,8 @@ def forward( output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None ): - return_dict = return_dict if return_dict is not None else self.config.return_dict - use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.return_dict + use_cache = use_cache if use_cache is not None else self.use_cache outputs = self.transformer( input_ids=input_ids, past_key_values=past_key_values, From 30ce7b540fec9812d4672fc1067ab424a4e2d3de Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 14 Jun 2023 13:39:44 +0300 Subject: [PATCH 044/114] small fixes --- mlc_llm/relax_model/mpt.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 006b9e3472..4c1012a6c3 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -587,8 +587,8 @@ def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_s def gen_slopes(n_heads, alibi_bias_max=8): _n_heads = 2 ** math.ceil(math.log2(n_heads)) m = nn.emit(relax.op.arange(1, _n_heads + 1, dtype="float32")) - m = nn.emit(m * (alibi_bias_max 
/ _n_heads)) - slopes = 1.0 / math.pow(2, m) + m = nn.emit(m * relax.const(alibi_bias_max / _n_heads)) + slopes = relax.const(1.0) / relax.op.power(m, 2) if _n_heads != n_heads: slopes_len = slopes.struct_info.shape[0] slopes = nn.emit(relax.op.strided_slice( @@ -791,7 +791,7 @@ def forward( pos = nn.emit(relax.op.clip(pos - pos_diff, min=0)) pos_emb = self.wpe(pos) x = tok_emb + pos_emb - (attn_bias, attention_mask) = self._attn_bias(dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id) + (attn_bias, attention_mask) = self._attn_bias(dtype=x.struct_info.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id) if use_cache and past_key_values is None: past_key_values = [() for _ in range(self.n_layers)] all_hidden_states = () if output_hidden_states else None @@ -966,7 +966,7 @@ def get_model(args, hf_config): # model_path = args.model_path dtype = args.quantization.model_dtype # Recommendation from https://huggingface.co/mosaicml/mpt-7b-instruct - max_seq_len = args.max_seq_len if args.max_seq_len is not None else 4096 # 4096 recommended + max_seq_len = args.max_seq_len if args.max_seq_len is not None and args.max_seq_len > 0 else 4096 # 4096 recommended hf_config.update({"max_seq_len": max_seq_len}) # hf_config.update({"max_new_tokens": args.seq_len}) From 983f5239aafe8e8969c03aa275ab8fc0a7deab3b Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 14 Jun 2023 14:20:54 +0300 Subject: [PATCH 045/114] tir.Cast was replaced by relax.op.astype --- mlc_llm/relax_model/mpt.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 4c1012a6c3..719706d1d8 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -27,7 +27,7 @@ def _cast_if_autocast_enabled(tensor): # else: # raise NotImplementedError() dtype = "float32" # TODO: temporal workaround - return nn.emit(tvm.tir.Cast(dtype, tensor)) + return nn.emit(relax.op.astype(tensor, dtype)) # Low-precision layer norm for mpt-7b-instruct, where are no biases expected class LPLayerNormWOBias(nn.Module): @@ -132,7 +132,7 @@ def scaled_multihead_dot_product_attention( s = relax.op.maximum(s_q, s_k) causal_mask = nn.emit(relax.op.ones((s, s,), dtype="float16")) causal_mask = nn.emit(relax.op.tril(causal_mask)) - causal_mask = tvm.tir.Cast("bool", causal_mask) + causal_mask = nn.emit(relax.op.astype(causal_mask, "bool")) causal_mask = tvm.tir.bitwise_not(causal_mask) # slicing causal_mask[-s_q:, -s_k:] s_q_end, s_k_end = causal_mask.struct_info.shape @@ -588,7 +588,8 @@ def gen_slopes(n_heads, alibi_bias_max=8): _n_heads = 2 ** math.ceil(math.log2(n_heads)) m = nn.emit(relax.op.arange(1, _n_heads + 1, dtype="float32")) m = nn.emit(m * relax.const(alibi_bias_max / _n_heads)) - slopes = relax.const(1.0) / relax.op.power(m, 2) + slopes = nn.emit(relax.op.divide(relax.const(1.0), relax.op.power(m, relax.const(2.0)))) + if _n_heads != n_heads: slopes_len = slopes.struct_info.shape[0] slopes = nn.emit(relax.op.strided_slice( @@ -606,9 +607,10 @@ def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, dtype=None) alibi_bias = nn.emit(alibi_bias - relax.op.reshape(relax.op.arange(1 - seq_len, 1, dtype="int32"), (1, 1, seq_len, 1))) alibi_bias = nn.emit(relax.op.negative(relax.op.abs(alibi_bias))) slopes = gen_slopes(n_heads, alibi_bias_max) + alibi_bias = nn.emit(relax.op.astype(alibi_bias, slopes.struct_info.dtype.value)) alibi_bias = nn.emit(alibi_bias * 
slopes) if dtype is not None: - alibi_bias = nn.emit(tvm.tir.Cast(dtype, alibi_bias)) + alibi_bias = nn.emit(relax.op.astype(alibi_bias, dtype)) return alibi_bias @@ -684,7 +686,7 @@ def _attn_bias(self, dtype, attention_mask: Optional[relax.Expr]=None, prefix_ma if self.attn_impl == 'flash': return (self.attn_bias, attention_mask) if self.attn_bias is not None: - self.attn_bias = nn.emit(tvm.tir.Cast(dtype, self.attn_bias)) + self.attn_bias = nn.emit(relax.op.astype(self.attn_bias, dtype)) attn_bias = self.attn_bias if self.prefix_lm: assert isinstance(attn_bias, relax.Expr) @@ -721,7 +723,7 @@ def _apply_prefix_mask(self, attn_bias: relax.Expr, prefix_mask: relax.Expr): attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [dims_len - 2, dims_len - 1], [relax.const(0), relax.const(0)], [seq_len, seq_len])) causal = nn.emit(relax.op.reshape(relax.op.tril(relax.op.ones((seq_len, seq_len), dtype="bool")), (1, 1, seq_len, seq_len))) prefix = nn.emit(relax.op.reshape(prefix_mask, (-1, 1, 1, seq_len))) - cannot_attend = nn.emit(tvm.tir.bitwise_not(relax.op.logical_or(causal, tvm.tir.Cast("bool", prefix)))) + cannot_attend = nn.emit(tvm.tir.bitwise_not(relax.op.logical_or(causal, relax.op.astype(prefix, "bool")))) min_val = tvm.tir.min_value(attn_bias.struct_info.dtype) attn_bias = nn.emit(relax.op.masked_fill(attn_bias, cannot_attend, min_val)) return attn_bias @@ -755,9 +757,9 @@ def forward( return_dict = return_dict if return_dict is not None else self.return_dict use_cache = use_cache if use_cache is not None else self.use_cache if attention_mask is not None: - attention_mask = nn.emit(tvm.tir.Cast("bool", attention_mask)) + attention_mask = nn.emit(relax.op.astype(attention_mask, "bool")) if prefix_mask is not None: - prefix_mask = nn.emit(tvm.tir.Cast("bool", prefix_mask)) + prefix_mask = nn.emit(relax.op.astype(prefix_mask, "bool")) if not return_dict: raise NotImplementedError('return_dict False is not implemented yet for MPT') if output_attentions: @@ -784,7 +786,7 @@ def forward( raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.max_seq_len}.') pos = nn.emit(relax.op.expand_dims(relax.op.arange(past_position, S + past_position, dtype="long"), axis=0)) if attention_mask is not None: - pos_diff_to_slice = nn.emit(relax.op.cumsum(tvm.tir.Cast("int32", tvm.tir.bitwise_not(attention_mask)), axis=1)) + pos_diff_to_slice = nn.emit(relax.op.cumsum(relax.op.astype(tvm.tir.bitwise_not(attention_mask), "int32"), axis=1)) dim1_len = pos_diff_to_slice.struct_info.shape[1] # slicing [:, past_position:] pos_diff = nn.emit(relax.op.strided_slice(pos_diff_to_slice, [1], [past_position], [dim1_len])) From 62f4a396ac86cb940749c3c63a4ac56984007e47 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 14 Jun 2023 14:49:12 +0300 Subject: [PATCH 046/114] update downcast workaround for lplayernorm --- mlc_llm/relax_model/mpt.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 719706d1d8..db61bb2838 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -18,7 +18,7 @@ ) -def _cast_if_autocast_enabled(tensor): +def _cast_if_autocast_enabled(tensor: relax.Expr, dtype="float32"): # # TODO: how to check device? 
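# For orientation, a minimal NumPy sketch of the semantics this low-precision
# layernorm port is assumed to follow (mirroring the upstream MPT LPLayerNorm:
# downcast the input and parameters to a compute dtype, then apply a standard
# layer norm over the last axis). Names here are illustrative only.
import numpy as np

def lp_layer_norm_reference(x, weight, bias, eps=1e-5, compute_dtype=np.float32):
    x = x.astype(compute_dtype)
    weight = weight.astype(compute_dtype)
    bias = bias.astype(compute_dtype)
    mean = x.mean(axis=-1, keepdims=True)   # per-token mean over the hidden dim
    var = x.var(axis=-1, keepdims=True)     # population variance, as in layer norm
    return (x - mean) / np.sqrt(var + eps) * weight + bias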
# if tensor.device.type == 'cuda': # dtype = "float16" @@ -26,7 +26,6 @@ def _cast_if_autocast_enabled(tensor): # dtype = "bfloat16" # else: # raise NotImplementedError() - dtype = "float32" # TODO: temporal workaround return nn.emit(relax.op.astype(tensor, dtype)) # Low-precision layer norm for mpt-7b-instruct, where are no biases expected @@ -36,12 +35,15 @@ def __init__(self, normalized_shape, dtype, eps=1e-05): # TODO: check default filling of weights self.weight = relax.op.ones((normalized_shape,), dtype) self.bias = relax.op.zeros((normalized_shape,), dtype) - self.eps = relax.const(eps, dtype) + self.eps = eps + + self.dtype = dtype def forward(self, x): - downcast_x = _cast_if_autocast_enabled(x) - downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight - downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias + dtype = self.dtype # TODO: temporal workaround + downcast_x = _cast_if_autocast_enabled(x, dtype) + downcast_weight = _cast_if_autocast_enabled(self.weight, dtype) if self.weight is not None else self.weight + downcast_bias = _cast_if_autocast_enabled(self.bias, dtype) if self.bias is not None else self.bias return nn.emit(relax.op.nn.layer_norm(downcast_x, downcast_weight, downcast_bias, axes=-1, epsilon=self.eps)) NORM_CLASS_REGISTRY = {'low_precision_layernorm': LPLayerNormWOBias} From 7ad3f4f8b3f32231e7a79282c3549ee4c7e5659b Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 14 Jun 2023 15:36:44 +0300 Subject: [PATCH 047/114] more torch group were replaced by relax ops --- mlc_llm/relax_model/mpt.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index db61bb2838..96f6cd099d 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -480,14 +480,17 @@ def __init__( def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False): qkv = self.Wqkv(x) if self.clip_qkv: - qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) - (query, key, value) = qkv.chunk(3, dim=2) + qkv = nn.emit(relax.op.clip(qkv, min=relax.const(-self.clip_qkv), max=relax.const(self.clip_qkv))) + qkv_out = relax.op.split(qkv, 3, axis=2) + query = nn.emit(qkv_out[0]) + key = nn.emit(qkv_out[1]) + value = nn.emit(qkv_out[2]) key_padding_mask = attention_mask if self.qk_ln: - dtype = query.dtype - query = self.q_ln(query).to(dtype) - key = self.k_ln(key).to(dtype) - (context, attn_weights, past_key_value) = self.attn_fn( + dtype = query.struct_info.dtype + query = nn.emit(relax.op.astype(self.q_ln(query), dtype)) + key = nn.emit(relax.op.astype(self.k_ln(key), dtype)) + attn_out = self.attn_fn( query, key, value, @@ -500,7 +503,7 @@ def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, i is_causal=is_causal, needs_weights=needs_weights ) - return (self.out_proj(context), attn_weights, past_key_value) + return (self.out_proj(attn_out[0]), attn_out[1], attn_out[2]) ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention} @@ -609,7 +612,7 @@ def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, dtype=None) alibi_bias = nn.emit(alibi_bias - relax.op.reshape(relax.op.arange(1 - seq_len, 1, dtype="int32"), (1, 1, seq_len, 1))) alibi_bias = nn.emit(relax.op.negative(relax.op.abs(alibi_bias))) slopes = gen_slopes(n_heads, alibi_bias_max) - alibi_bias = nn.emit(relax.op.astype(alibi_bias, slopes.struct_info.dtype.value)) + 
alibi_bias = nn.emit(relax.op.astype(alibi_bias, slopes.struct_info.dtype)) alibi_bias = nn.emit(alibi_bias * slopes) if dtype is not None: alibi_bias = nn.emit(relax.op.astype(alibi_bias, dtype)) From aa187bbae74d4b2e1172918851041dca1d878351 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 15 Jun 2023 08:31:00 +0300 Subject: [PATCH 048/114] correct rearrange --- mlc_llm/relax_model/mpt.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 96f6cd099d..08dff6865c 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -60,12 +60,12 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_cau def reshape_and_permute(hidden_states: relax.Expr, n_heads: int, d_model: int, indeces: List[int] = [0, 2, 1, 3]): ''' - Transform shape of input: b s (h d) -> b h d s + Transform shape of input: b s (h d) -> b s h d -> b h s d or b h d s ''' batch_size, seqlen, _ = hidden_states.struct_info.shape inter = nn.emit(relax.op.reshape( hidden_states, - (batch_size, seqlen, n_heads, d_model), + (batch_size, seqlen, n_heads, int(d_model / n_heads)), )) return nn.emit(relax.op.permute_dims(inter, indeces)) @@ -74,11 +74,11 @@ def reverse_reshape_and_permute(hidden_states: relax.Expr): ''' Transform shape of input: b h s d -> b s (h d) ''' - batch_size, n_heads, seqlen, d_model = hidden_states.struct_info.shape + batch_size, n_heads, seqlen, head_len = hidden_states.struct_info.shape inter = nn.emit(relax.op.permute_dims(hidden_states, [0, 2, 1, 3])) return nn.emit(relax.op.reshape( inter, - (batch_size, seqlen, n_heads*d_model), + (batch_size, seqlen, n_heads*head_len), )) From fb1138b5f5ddc1a0ae49e7455f615fcb55c27e0a Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 15 Jun 2023 08:35:56 +0300 Subject: [PATCH 049/114] switch on model_path in get_model method. need to redo due to update in other models code --- mlc_llm/relax_model/mpt.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 08dff6865c..8890edd344 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -966,11 +966,9 @@ def create_decoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> Dict[int, def get_model(args, hf_config): from transformers import AutoModelForCausalLM # type: ignore[import] - import torch # type: ignore[import] model_name = args.model - # TODO: download model and use model_path instead of args for from_pretrained - # model_path = args.model_path + model_path = args.model_path dtype = args.quantization.model_dtype # Recommendation from https://huggingface.co/mosaicml/mpt-7b-instruct max_seq_len = args.max_seq_len if args.max_seq_len is not None and args.max_seq_len > 0 else 4096 # 4096 recommended @@ -987,14 +985,9 @@ def get_model(args, hf_config): mod = bb.get() device = tvm.cpu() - # TODO: get default mpt-7b-instruct from HF. 
Possibly it should be downloaded earlier - # and use model_path instead hf_model = AutoModelForCausalLM.from_pretrained( - 'mosaicml/mpt-7b-instruct', - config=config, - torch_dtype=torch.bfloat16, - trust_remote_code=True + model_path, ) for name, param in hf_model.named_parameters(): print(name, param.shape) From 4860be49ef4e127d6e684b7038122986d6e8f01b Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 15 Jun 2023 09:06:19 +0300 Subject: [PATCH 050/114] small fixes --- mlc_llm/relax_model/mpt.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 8890edd344..d5a460fe0f 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -110,13 +110,14 @@ def scaled_multihead_dot_product_attention( (b, _, s_q, d) = q.struct_info.shape s_k = k.struct_info.shape[-1] if softmax_scale is None: - softmax_scale = 1 / math.sqrt(d) + softmax_scale = 1 / math.sqrt(d) + softmax_scale = relax.op.astype(relax.const(softmax_scale), q.struct_info.dtype) attn_weight = nn.emit(relax.op.matmul(q, k) * softmax_scale) + _, _, s_q_end, s_k_end = attn_bias.struct_info.shape if attn_bias is not None: - _s_q = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[2] - s_q) - _s_k = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[3] - s_k) + _s_q = np.maximum(0, s_q_end - s_q) + _s_k = np.maximum(0, s_k_end - s_k) # slicing attn_bias[:, :, _s_q:, _s_k:] - s_q_end, s_k_end = attn_bias.struct_info.shape[-2:] attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [2, 3], [_s_q, _s_k], [s_q_end, s_k_end])) if (attn_bias.struct_info.shape[-1] != 1 and attn_bias.struct_info.shape[-1] != s_k or # dynamic condition? @@ -717,7 +718,8 @@ def _attn_bias(self, dtype, attention_mask: Optional[relax.Expr]=None, prefix_ma return (attn_bias, None) def _apply_prefix_mask(self, attn_bias: relax.Expr, prefix_mask: relax.Expr): - (s_k, s_q) = attn_bias.struct_info.shape[-2:] + s_k = attn_bias.struct_info.shape[-2] + s_q = attn_bias.struct_info.shape[-1] if s_k != self.max_seq_len or s_q != self.max_seq_len: raise ValueError('attn_bias does not match the expected shape. 
' + f'The last two dimensions should both be {self.max_seq_len} ' + f'but are {s_k} and {s_q}.') seq_len = prefix_mask.struct_info.shape[-1] @@ -939,15 +941,11 @@ def create_decoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> Dict[int, with bb.function("decode"): model = MPTForCausalLM(config) input_ids = nn.Placeholder((1, 1), dtype="int32", name="input_ids") - # Placeholder for compatibility to LLAMA - all_seq_len_shape = relax.Var("place_holder", R.Object()) - state = relax.Var("state", R.Tuple([R.Object()] * config.n_layers * 5)) + with bb.dataflow(): - logits, states = model(input_ids, state) + logits, states = model(input_ids) params = [ input_ids, - all_seq_len_shape, - state, ] + model.parameters() named_params = named_parameters(model) @@ -960,7 +958,7 @@ def create_decoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> Dict[int, mod = bb.get() gv = mod.get_global_var("decode") - bb.update_func(gv, mod[gv].with_attr("num_input", 3)) + bb.update_func(gv, mod[gv].with_attr("num_input", 1)) return pidx2pname From b888dd41fcdfe0f66d52dbc83a7760d320398810 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 15 Jun 2023 10:02:01 +0300 Subject: [PATCH 051/114] replace matmul by linear for weight transposition from the box --- mlc_llm/relax_model/mpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index d5a460fe0f..d0a3d4ed5c 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -882,7 +882,7 @@ def forward( output_hidden_states=output_hidden_states, use_cache=use_cache ) - logits = nn.emit(relax.op.matmul(outputs[0], self.transformer.wte.weight)) + logits = nn.emit(relax.op.linear(outputs[0], self.transformer.wte.weight)) return logits, outputs[1] From be73cd9b2c2d2e30306d8708a0b7af8c528679b2 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 15 Jun 2023 11:39:26 +0300 Subject: [PATCH 052/114] fixes --- mlc_llm/relax_model/mpt.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index d0a3d4ed5c..94d5d19d20 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -951,8 +951,9 @@ def create_decoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> Dict[int, named_params = named_parameters(model) for i, (name, param) in enumerate(named_params.items()): pidx2pname[i] = name - assert param.same_as(params[i + 3]) - + assert param.same_as(params[i + 1]) + if states is None: + states = () gv = bb.emit_output((logits, relax.Tuple(states))) bb.emit_func_output(gv, params) @@ -979,13 +980,13 @@ def get_model(args, hf_config): bb = relax.BlockBuilder() create_decoding_func(bb, config) - mod = bb.get() device = tvm.cpu() hf_model = AutoModelForCausalLM.from_pretrained( model_path, + trust_remote_code=True, ) for name, param in hf_model.named_parameters(): print(name, param.shape) From f173728a042e286195151bf2eff9381a46f0a3a7 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 15 Jun 2023 16:45:40 +0300 Subject: [PATCH 053/114] check decode only for mpt models --- build.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/build.py b/build.py index 6a4f705b2a..eaabf794b8 100644 --- a/build.py +++ b/build.py @@ -274,6 +274,10 @@ def mod_transform_before_build( "get_metadata", "reset_kv_cache", ] + elif ARGS.model.startswith("mpt-"): + model_names = [ + "decode", + ] else: model_names = [ "prefill", From 9885b5473508d54f18b23da6c3f199227949f435 Mon Sep 17 00:00:00 2001 From: 
Valery Chernov Date: Fri, 16 Jun 2023 10:40:34 +0300 Subject: [PATCH 054/114] add desc for mpt-7b-instruct --- docs/models/mpt/README.md | 62 ++++++++++ docs/models/mpt/mpt_topology.txt | 198 +++++++++++++++++++++++++++++++ 2 files changed, 260 insertions(+) create mode 100644 docs/models/mpt/README.md create mode 100644 docs/models/mpt/mpt_topology.txt diff --git a/docs/models/mpt/README.md b/docs/models/mpt/README.md new file mode 100644 index 0000000000..ce9b912495 --- /dev/null +++ b/docs/models/mpt/README.md @@ -0,0 +1,62 @@ +# MPT-7b-instruct + +There is brief description of mpt-7b-instruct model. It is needed for correct Relax implementation of the model and weights mapping. +MPT-7b-instruct is decoder-like kv_cache free model using flash attention. +The list of Tensor name - tensor size for the original (pytorch) model can be found in mpt_topology.txt file. +The original config for the model: +{ + "architectures": [ + "MPTForCausalLM" + ], + "attn_config": { + "alibi": true, + "alibi_bias_max": 8, + "attn_impl": "torch", + "attn_pdrop": 0, + "attn_type": "multihead_attention", + "attn_uses_sequence_id": false, + "clip_qkv": null, + "prefix_lm": false, + "qk_ln": false, + "softmax_scale": null + }, + "auto_map": { + "AutoConfig": "configuration_mpt.MPTConfig", + "AutoModelForCausalLM": "modeling_mpt.MPTForCausalLM" + }, + "d_model": 4096, + "emb_pdrop": 0, + "embedding_fraction": 1.0, + "expansion_ratio": 4, + "init_config": { + "emb_init_std": null, + "emb_init_uniform_lim": null, + "fan_mode": "fan_in", + "init_div_is_residual": true, + "init_gain": 0, + "init_nonlinearity": "relu", + "init_std": 0.02, + "name": "kaiming_normal_", + "verbose": 0 + }, + "init_device": "cpu", + "learned_pos_emb": true, + "logit_scale": null, + "max_seq_len": 2048, + "model_type": "mpt", + "n_heads": 32, + "n_layers": 32, + "no_bias": true, + "norm_type": "low_precision_layernorm", + "resid_pdrop": 0, + "tokenizer_name": "EleutherAI/gpt-neox-20b", + "torch_dtype": "bfloat16", + "transformers_version": "4.28.1", + "use_cache": false, + "verbose": 0, + "vocab_size": 50432 +} + +This config wraps default one. 
It should highlight two defaults parameters: +"is_encoder_decoder": false, +"use_cache": false, \ No newline at end of file diff --git a/docs/models/mpt/mpt_topology.txt b/docs/models/mpt/mpt_topology.txt new file mode 100644 index 0000000000..e2d911d9c0 --- /dev/null +++ b/docs/models/mpt/mpt_topology.txt @@ -0,0 +1,198 @@ +transformer.wte.weight torch.Size([50432, 4096]) + +transformer.blocks.0.norm_1.weight torch.Size([4096]) +transformer.blocks.0.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.0.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.0.norm_2.weight torch.Size([4096]) +transformer.blocks.0.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.0.ffn.down_proj.weight torch.Size([4096, 16384]) + +transformer.blocks.1.norm_1.weight torch.Size([4096]) +transformer.blocks.1.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.1.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.1.norm_2.weight torch.Size([4096]) +transformer.blocks.1.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.1.ffn.down_proj.weight torch.Size([4096, 16384]) + +transformer.blocks.2.norm_1.weight torch.Size([4096]) +transformer.blocks.2.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.2.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.2.norm_2.weight torch.Size([4096]) +transformer.blocks.2.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.2.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.3.norm_1.weight torch.Size([4096]) +transformer.blocks.3.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.3.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.3.norm_2.weight torch.Size([4096]) +transformer.blocks.3.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.3.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.4.norm_1.weight torch.Size([4096]) +transformer.blocks.4.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.4.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.4.norm_2.weight torch.Size([4096]) +transformer.blocks.4.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.4.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.5.norm_1.weight torch.Size([4096]) +transformer.blocks.5.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.5.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.5.norm_2.weight torch.Size([4096]) +transformer.blocks.5.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.5.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.6.norm_1.weight torch.Size([4096]) +transformer.blocks.6.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.6.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.6.norm_2.weight torch.Size([4096]) +transformer.blocks.6.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.6.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.7.norm_1.weight torch.Size([4096]) +transformer.blocks.7.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.7.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.7.norm_2.weight torch.Size([4096]) +transformer.blocks.7.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.7.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.8.norm_1.weight torch.Size([4096]) +transformer.blocks.8.attn.Wqkv.weight torch.Size([12288, 4096]) 
+transformer.blocks.8.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.8.norm_2.weight torch.Size([4096]) +transformer.blocks.8.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.8.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.9.norm_1.weight torch.Size([4096]) +transformer.blocks.9.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.9.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.9.norm_2.weight torch.Size([4096]) +transformer.blocks.9.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.9.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.10.norm_1.weight torch.Size([4096]) +transformer.blocks.10.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.10.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.10.norm_2.weight torch.Size([4096]) +transformer.blocks.10.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.10.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.11.norm_1.weight torch.Size([4096]) +transformer.blocks.11.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.11.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.11.norm_2.weight torch.Size([4096]) +transformer.blocks.11.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.11.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.12.norm_1.weight torch.Size([4096]) +transformer.blocks.12.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.12.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.12.norm_2.weight torch.Size([4096]) +transformer.blocks.12.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.12.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.13.norm_1.weight torch.Size([4096]) +transformer.blocks.13.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.13.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.13.norm_2.weight torch.Size([4096]) +transformer.blocks.13.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.13.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.14.norm_1.weight torch.Size([4096]) +transformer.blocks.14.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.14.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.14.norm_2.weight torch.Size([4096]) +transformer.blocks.14.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.14.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.15.norm_1.weight torch.Size([4096]) +transformer.blocks.15.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.15.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.15.norm_2.weight torch.Size([4096]) +transformer.blocks.15.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.15.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.16.norm_1.weight torch.Size([4096]) +transformer.blocks.16.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.16.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.16.norm_2.weight torch.Size([4096]) +transformer.blocks.16.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.16.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.17.norm_1.weight torch.Size([4096]) +transformer.blocks.17.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.17.attn.out_proj.weight torch.Size([4096, 4096]) 
+transformer.blocks.17.norm_2.weight torch.Size([4096]) +transformer.blocks.17.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.17.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.18.norm_1.weight torch.Size([4096]) +transformer.blocks.18.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.18.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.18.norm_2.weight torch.Size([4096]) +transformer.blocks.18.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.18.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.19.norm_1.weight torch.Size([4096]) +transformer.blocks.19.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.19.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.19.norm_2.weight torch.Size([4096]) +transformer.blocks.19.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.19.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.20.norm_1.weight torch.Size([4096]) +transformer.blocks.20.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.20.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.20.norm_2.weight torch.Size([4096]) +transformer.blocks.20.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.20.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.21.norm_1.weight torch.Size([4096]) +transformer.blocks.21.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.21.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.21.norm_2.weight torch.Size([4096]) +transformer.blocks.21.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.21.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.22.norm_1.weight torch.Size([4096]) +transformer.blocks.22.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.22.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.22.norm_2.weight torch.Size([4096]) +transformer.blocks.22.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.22.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.23.norm_1.weight torch.Size([4096]) +transformer.blocks.23.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.23.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.23.norm_2.weight torch.Size([4096]) +transformer.blocks.23.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.23.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.24.norm_1.weight torch.Size([4096]) +transformer.blocks.24.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.24.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.24.norm_2.weight torch.Size([4096]) +transformer.blocks.24.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.24.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.25.norm_1.weight torch.Size([4096]) +transformer.blocks.25.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.25.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.25.norm_2.weight torch.Size([4096]) +transformer.blocks.25.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.25.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.26.norm_1.weight torch.Size([4096]) +transformer.blocks.26.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.26.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.26.norm_2.weight torch.Size([4096]) 
+transformer.blocks.26.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.26.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.27.norm_1.weight torch.Size([4096]) +transformer.blocks.27.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.27.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.27.norm_2.weight torch.Size([4096]) +transformer.blocks.27.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.27.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.28.norm_1.weight torch.Size([4096]) +transformer.blocks.28.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.28.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.28.norm_2.weight torch.Size([4096]) +transformer.blocks.28.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.28.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.29.norm_1.weight torch.Size([4096]) +transformer.blocks.29.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.29.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.29.norm_2.weight torch.Size([4096]) +transformer.blocks.29.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.29.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.30.norm_1.weight torch.Size([4096]) +transformer.blocks.30.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.30.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.30.norm_2.weight torch.Size([4096]) +transformer.blocks.30.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.30.ffn.down_proj.weight torch.Size([4096, 16384]) +transformer.blocks.31.norm_1.weight torch.Size([4096]) +transformer.blocks.31.attn.Wqkv.weight torch.Size([12288, 4096]) +transformer.blocks.31.attn.out_proj.weight torch.Size([4096, 4096]) +transformer.blocks.31.norm_2.weight torch.Size([4096]) +transformer.blocks.31.ffn.up_proj.weight torch.Size([16384, 4096]) +transformer.blocks.31.ffn.down_proj.weight torch.Size([4096, 16384]) + +transformer.norm_f.weight torch.Size([4096]) \ No newline at end of file From 29800ea697b0796f90ebfd192c436207a416bf54 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 16 Jun 2023 14:11:27 +0300 Subject: [PATCH 055/114] upstream weights mapping --- build.py | 1 + mlc_llm/relax_model/mpt.py | 48 +++++++++++++++++++------------------- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/build.py b/build.py index eaabf794b8..2abbf00426 100644 --- a/build.py +++ b/build.py @@ -277,6 +277,7 @@ def mod_transform_before_build( elif ARGS.model.startswith("mpt-"): model_names = [ "decode", + "get_metadata", ] else: model_names = [ diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 94d5d19d20..22d5919c3a 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -9,6 +9,8 @@ from tvm.script import relax as R from .mpt_config import MPTConfig, attn_config_defaults +from ..utils import load_torch_pname2binname_map +from .commons import create_metadata_func from .modules import ( Embedding, LayerNorm, @@ -964,7 +966,7 @@ def create_decoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> Dict[int, return pidx2pname def get_model(args, hf_config): - from transformers import AutoModelForCausalLM # type: ignore[import] + assert model_name.startswith("mpt-"), f"Unsupported model name: {args.model_name}" model_name = args.model model_path = args.model_path @@ -975,31 +977,29 @@ def get_model(args, hf_config): 
hf_config.update({"max_seq_len": max_seq_len}) # hf_config.update({"max_new_tokens": args.seq_len}) - if model_name.startswith("mpt-"): - config = MPTConfig(**hf_config, dtype=dtype) + config = MPTConfig(**hf_config, dtype=dtype) - bb = relax.BlockBuilder() - create_decoding_func(bb, config) - mod = bb.get() + bb = relax.BlockBuilder() + pidx2pname = create_decoding_func(bb, config) + create_metadata_func( + bb, + model_name=model_name, + max_window_size=-1, # TODO: check + stop_tokens=[0], # TODO: check for mpt embeddings + add_prefix_space=False, # TODO: what is it? + ) - device = tvm.cpu() + mod = bb.get() - hf_model = AutoModelForCausalLM.from_pretrained( - model_path, - trust_remote_code=True, - ) - for name, param in hf_model.named_parameters(): - print(name, param.shape) - # Get a list of parameters in advance, then delete the model to save memory - param_list = [param for _, param in hf_model.named_parameters()] - - for i, param in enumerate(param_list): - param_list[i] = tvm.nd.array( - param.detach().cpu().numpy().astype(dtype), device - ) - del hf_model + pname2binname = load_torch_pname2binname_map( + model_path, set(pidx2pname.values()) + ) + + # device = tvm.cpu() - print(mod) - return mod, param_list + args.pidx2pname = pidx2pname + args.pname2binname = pname2binname + # args.f_convert_pname_fwd = f_convert_pname_fwd + # args.f_convert_param_bkwd = f_convert_param_bkwd - raise ValueError(f"Unsupported model: {model_name}") \ No newline at end of file + return mod, [None] * len(pidx2pname) \ No newline at end of file From 569310871ae512e9be4305f0c5287f6acc12e7b1 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 16 Jun 2023 14:14:53 +0300 Subject: [PATCH 056/114] fix assert check --- mlc_llm/relax_model/mpt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 22d5919c3a..8263ea1c85 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -966,9 +966,9 @@ def create_decoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> Dict[int, return pidx2pname def get_model(args, hf_config): - assert model_name.startswith("mpt-"), f"Unsupported model name: {args.model_name}" - model_name = args.model + assert model_name.startswith("mpt-") , f"Unsupported model name: {model_name}" + model_path = args.model_path dtype = args.quantization.model_dtype # Recommendation from https://huggingface.co/mosaicml/mpt-7b-instruct From a16753b85625457017c6c4fa1ba219e70bc8c132 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 16 Jun 2023 14:41:41 +0300 Subject: [PATCH 057/114] add custom f_convert_pname_fwd --- mlc_llm/relax_model/mpt.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 8263ea1c85..64f4b86704 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -991,15 +991,23 @@ def get_model(args, hf_config): mod = bb.get() + def f_convert_pname_fwd(pname: str) -> str: + if ( + "self_attn" in pname + ): + return pname.replace("self_attn", "attn") + else: + return pname + pname2binname = load_torch_pname2binname_map( - model_path, set(pidx2pname.values()) + model_path, set(pidx2pname.values()), f_convert_pname_fwd ) # device = tvm.cpu() args.pidx2pname = pidx2pname args.pname2binname = pname2binname - # args.f_convert_pname_fwd = f_convert_pname_fwd + args.f_convert_pname_fwd = f_convert_pname_fwd # args.f_convert_param_bkwd = f_convert_param_bkwd return mod, [None] * 
len(pidx2pname) \ No newline at end of file From 2c091f6f661c1e03fd101e5b2f901ee7c86e135a Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 16 Jun 2023 14:44:59 +0300 Subject: [PATCH 058/114] once more update of f_convert_pname_fwd --- mlc_llm/relax_model/mpt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 64f4b86704..48afdeaef1 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -992,10 +992,10 @@ def get_model(args, hf_config): mod = bb.get() def f_convert_pname_fwd(pname: str) -> str: - if ( - "self_attn" in pname - ): + if "self_attn" in pname: return pname.replace("self_attn", "attn") + elif "mlp" in pname: + return pname.replace("mlp", "ffn") else: return pname From fb0fc911c78e95b5c184e267a5055fdab79ee2c9 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 16 Jun 2023 14:57:10 +0300 Subject: [PATCH 059/114] skip bias from weights --- mlc_llm/relax_model/mpt.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 48afdeaef1..492b141605 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -454,7 +454,7 @@ def __init__( if self.softmax_scale is None: self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) - self.Wqkv = Linear(self.d_model, 3 * self.d_model, dtype) + self.Wqkv = Linear(self.d_model, 3 * self.d_model, dtype, bias=False) fuse_splits = (d_model, 2 * d_model) self.Wqkv._fused = (0, fuse_splits) if self.qk_ln: @@ -476,7 +476,7 @@ def __init__( self.attn_fn = scaled_multihead_dot_product_attention else: raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') - self.out_proj = Linear(self.d_model, self.d_model, dtype) + self.out_proj = Linear(self.d_model, self.d_model, dtype, bias=False) # TODO: Does field _is_residual exist? 
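The hunk above packs the query, key and value projections into a single `Wqkv` matrix of shape `(3 * d_model, d_model)`, and `fuse_splits = (d_model, 2 * d_model)` records where the fused output has to be cut back apart. A small NumPy sketch of that split, assuming MPT-7B-like sizes (`d_model=4096`, `n_heads=32`):

```python
import numpy as np

batch, seq, d_model, n_heads = 1, 4, 4096, 32
qkv = np.random.rand(batch, seq, 3 * d_model).astype("float32")  # fused Wqkv output

# fuse_splits = (d_model, 2 * d_model) marks the two cut points of the fused projection.
query, key, value = np.split(qkv, [d_model, 2 * d_model], axis=-1)
assert query.shape == key.shape == value.shape == (batch, seq, d_model)

# Attention then views each projection as (batch, seq, n_heads, head_dim).
head_dim = d_model // n_heads
q_heads = query.reshape(batch, seq, n_heads, head_dim)
print(q_heads.shape)  # (1, 4, 32, 128)
```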
# self.out_proj._is_residual = True @@ -513,11 +513,11 @@ def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, i class MPTMLP(nn.Module): def __init__(self, hidden_size: int, intermediate_size: int, dtype: str): - self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype) - self.up_proj = Linear(hidden_size, intermediate_size, dtype=dtype) + self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False) + self.up_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False) def forward(self, x): - return self.down_proj(relax.op.nn.gelu(self.up_proj(x))) + return self.down_proj(relax.op.nn.gelu(self.up_proj(x))) class MPTBlock(nn.Module): From 0b46f6f07605f49786ed8471fd16dd4d0b8accd0 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 16 Jun 2023 15:08:33 +0300 Subject: [PATCH 060/114] add f_convert_param_bkwd --- mlc_llm/relax_model/mpt.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 492b141605..095aabfdac 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -1003,11 +1003,18 @@ def f_convert_pname_fwd(pname: str) -> str: model_path, set(pidx2pname.values()), f_convert_pname_fwd ) - # device = tvm.cpu() + def f_convert_param_bkwd(torch_pname: str, raw_param): + if "attn" in pname: + pname = torch_pname.replace("attn", "self_attn") + elif "ffn" in pname: + pname = torch_pname.replace("ffn", "mlp") + else: + pname = torch_pname + return [(pname, raw_param)] args.pidx2pname = pidx2pname args.pname2binname = pname2binname args.f_convert_pname_fwd = f_convert_pname_fwd - # args.f_convert_param_bkwd = f_convert_param_bkwd + args.f_convert_param_bkwd = f_convert_param_bkwd return mod, [None] * len(pidx2pname) \ No newline at end of file From ffed1d51f24bc7fedb172207a5da56f253ba054f Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 16 Jun 2023 15:20:25 +0300 Subject: [PATCH 061/114] try to fix bfloat16 --- mlc_llm/relax_model/mpt.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt.py index 095aabfdac..fb7264fe2a 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt.py @@ -1010,6 +1010,10 @@ def f_convert_param_bkwd(torch_pname: str, raw_param): pname = torch_pname.replace("ffn", "mlp") else: pname = torch_pname + + # TVM does not support bfloat16 + if raw_param.dtype == "bfloat16": + raw_param = raw_param.astype("float16") return [(pname, raw_param)] args.pidx2pname = pidx2pname From 7b6138717f19cf8ba0e0982dccce207e09178f99 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 19 Jun 2023 09:08:03 +0300 Subject: [PATCH 062/114] file structure for mpt model was refactored --- mlc_llm/relax_model/__init__.py | 2 +- {docs/models => mlc_llm/relax_model}/mpt/README.md | 0 mlc_llm/relax_model/{ => mpt}/mpt.py | 6 +++--- mlc_llm/relax_model/{ => mpt}/mpt_config.py | 0 {docs/models => mlc_llm/relax_model}/mpt/mpt_topology.txt | 0 5 files changed, 4 insertions(+), 4 deletions(-) rename {docs/models => mlc_llm/relax_model}/mpt/README.md (100%) rename mlc_llm/relax_model/{ => mpt}/mpt.py (99%) rename mlc_llm/relax_model/{ => mpt}/mpt_config.py (100%) rename {docs/models => mlc_llm/relax_model}/mpt/mpt_topology.txt (100%) diff --git a/mlc_llm/relax_model/__init__.py b/mlc_llm/relax_model/__init__.py index 1ee1adbe07..d50967ec9c 100644 --- a/mlc_llm/relax_model/__init__.py +++ b/mlc_llm/relax_model/__init__.py @@ -1,2 +1,2 @@ from . 
import llama -from . import mpt +from .mpt import mpt diff --git a/docs/models/mpt/README.md b/mlc_llm/relax_model/mpt/README.md similarity index 100% rename from docs/models/mpt/README.md rename to mlc_llm/relax_model/mpt/README.md diff --git a/mlc_llm/relax_model/mpt.py b/mlc_llm/relax_model/mpt/mpt.py similarity index 99% rename from mlc_llm/relax_model/mpt.py rename to mlc_llm/relax_model/mpt/mpt.py index fb7264fe2a..a5c16e14f7 100644 --- a/mlc_llm/relax_model/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -9,9 +9,9 @@ from tvm.script import relax as R from .mpt_config import MPTConfig, attn_config_defaults -from ..utils import load_torch_pname2binname_map -from .commons import create_metadata_func -from .modules import ( +from ...utils import load_torch_pname2binname_map +from ..commons import create_metadata_func +from ..modules import ( Embedding, LayerNorm, Linear, diff --git a/mlc_llm/relax_model/mpt_config.py b/mlc_llm/relax_model/mpt/mpt_config.py similarity index 100% rename from mlc_llm/relax_model/mpt_config.py rename to mlc_llm/relax_model/mpt/mpt_config.py diff --git a/docs/models/mpt/mpt_topology.txt b/mlc_llm/relax_model/mpt/mpt_topology.txt similarity index 100% rename from docs/models/mpt/mpt_topology.txt rename to mlc_llm/relax_model/mpt/mpt_topology.txt From fcdc9158e8ddf13f463bead97493809c7386ec23 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 19 Jun 2023 10:27:59 +0300 Subject: [PATCH 063/114] add script to convert model from bfloat16 to float16 --- .../relax_model/mpt/bfloat16_to_float16.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 mlc_llm/relax_model/mpt/bfloat16_to_float16.py diff --git a/mlc_llm/relax_model/mpt/bfloat16_to_float16.py b/mlc_llm/relax_model/mpt/bfloat16_to_float16.py new file mode 100644 index 0000000000..f2c54f205f --- /dev/null +++ b/mlc_llm/relax_model/mpt/bfloat16_to_float16.py @@ -0,0 +1,50 @@ +# The procedure is based on https://huggingface.co/transformers/v1.2.0/serialization.html#serialization-best-practices +from pathlib import Path +import argparse + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import WEIGHTS_NAME, CONFIG_NAME + + +def load_bf16_model(dir_path, tokenizer_name): + model = AutoModelForCausalLM.from_pretrained( + dir_path, + trust_remote_code=True + ) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + + return model, tokenizer + + +def save_fp16_model(dir_path, model, tokenizer): + model_to_save = model.module if hasattr(model, 'module') else model + + output_model_file = Path.joinpath(dir_path, WEIGHTS_NAME) + output_config_file = Path.joinpath(dir_path, CONFIG_NAME) + + torch.save(model_to_save.state_dict(), output_model_file) + model_to_save.config.to_json_file(output_config_file) + tokenizer.save_vocabulary(dir_path) + + +def main(args): + model_root_dir = Path(args.model_path) + new_name = model_root_dir.name + "-float16" + out_path = model_root_dir.parent.joinpath(new_name) + + model, tokenizer = load_bf16_model(model_root_dir, args.tokenizer) + model.to(dtype=torch.float16) + model.save_pretrained(out_path, from_pt=True) + # save_fp16_model(out_path, model, tokenizer) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('-m', '--model_path', type=str, default="../../../dist/models/mpt-7b-instruct", + help="The path to directory with bfloat16 model") + parser.add_argument('-t', '--tokenizer', type=str, default="EleutherAI/gpt-neox-20b", + help="Tag for transformers to upload 
correct tokenizer") + + args = parser.parse_args() + main(args) From 4126a8e4e9b843882465514ea20af8f103977757 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 19 Jun 2023 10:44:45 +0300 Subject: [PATCH 064/114] fix f_convert_param_bkwd --- mlc_llm/relax_model/mpt/mpt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index a5c16e14f7..79afe6cdb3 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -1004,9 +1004,9 @@ def f_convert_pname_fwd(pname: str) -> str: ) def f_convert_param_bkwd(torch_pname: str, raw_param): - if "attn" in pname: + if "attn" in torch_pname: pname = torch_pname.replace("attn", "self_attn") - elif "ffn" in pname: + elif "ffn" in torch_pname: pname = torch_pname.replace("ffn", "mlp") else: pname = torch_pname From 8467179c902abb07713f0544ce90a2a86576f1fa Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 19 Jun 2023 11:51:57 +0300 Subject: [PATCH 065/114] clean code for conversion script, add desc --- mlc_llm/relax_model/mpt/README.md | 9 +++++- .../relax_model/mpt/bfloat16_to_float16.py | 30 +++++++++++-------- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/mlc_llm/relax_model/mpt/README.md b/mlc_llm/relax_model/mpt/README.md index ce9b912495..6e55766d8e 100644 --- a/mlc_llm/relax_model/mpt/README.md +++ b/mlc_llm/relax_model/mpt/README.md @@ -2,6 +2,13 @@ There is brief description of mpt-7b-instruct model. It is needed for correct Relax implementation of the model and weights mapping. MPT-7b-instruct is decoder-like kv_cache free model using flash attention. +Data type is brain float16 by default. But numpy used inside scripts and TVM do not support this type. Due to this to compile MPT-like model use following script: +```bash +python3 bfloat16_to_float16.py +``` +It is saved converted model in `dist/models/-float16` directory. +**Note:** After conversion to float16, only weights and config will be saved. Transfer other files (like tokenizer vocab) from the original directory. + The list of Tensor name - tensor size for the original (pytorch) model can be found in mpt_topology.txt file. The original config for the model: { @@ -59,4 +66,4 @@ The original config for the model: This config wraps default one. 
It should highlight two defaults parameters: "is_encoder_decoder": false, -"use_cache": false, \ No newline at end of file +"use_cache": false, diff --git a/mlc_llm/relax_model/mpt/bfloat16_to_float16.py b/mlc_llm/relax_model/mpt/bfloat16_to_float16.py index f2c54f205f..380e7887f0 100644 --- a/mlc_llm/relax_model/mpt/bfloat16_to_float16.py +++ b/mlc_llm/relax_model/mpt/bfloat16_to_float16.py @@ -1,4 +1,3 @@ -# The procedure is based on https://huggingface.co/transformers/v1.2.0/serialization.html#serialization-best-practices from pathlib import Path import argparse @@ -17,26 +16,33 @@ def load_bf16_model(dir_path, tokenizer_name): return model, tokenizer -def save_fp16_model(dir_path, model, tokenizer): - model_to_save = model.module if hasattr(model, 'module') else model +def save_fp16_model(dir_path, model, tokenizer, manually=False): + new_name = dir_path.name + "-float16" + out_path = dir_path.parent.joinpath(new_name) - output_model_file = Path.joinpath(dir_path, WEIGHTS_NAME) - output_config_file = Path.joinpath(dir_path, CONFIG_NAME) + if manually: + # Manual saving + output_model_file = Path.joinpath(out_path, WEIGHTS_NAME) + output_config_file = Path.joinpath(out_path, CONFIG_NAME) - torch.save(model_to_save.state_dict(), output_model_file) - model_to_save.config.to_json_file(output_config_file) - tokenizer.save_vocabulary(dir_path) + model_to_save = model.module if hasattr(model, 'module') else model + torch.save(model_to_save.state_dict(), output_model_file) + model_to_save.config.to_json_file(output_config_file) + tokenizer.save_vocabulary(out_path) + else: + # Use transformer API + model.save_pretrained(out_path, from_pt=True) def main(args): model_root_dir = Path(args.model_path) - new_name = model_root_dir.name + "-float16" - out_path = model_root_dir.parent.joinpath(new_name) + # Load original model (bfloat16) model, tokenizer = load_bf16_model(model_root_dir, args.tokenizer) + # Convert data type to float 16 model.to(dtype=torch.float16) - model.save_pretrained(out_path, from_pt=True) - # save_fp16_model(out_path, model, tokenizer) + # Save converted model + save_fp16_model(model_root_dir, model, tokenizer) if __name__ == "__main__": From 896cdc92a4bf7950d19520ff829e4f877fe9687f Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 19 Jun 2023 16:35:37 +0300 Subject: [PATCH 066/114] workaround for lookup func --- mlc_llm/dispatch/dispatch_tir_operator.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mlc_llm/dispatch/dispatch_tir_operator.py b/mlc_llm/dispatch/dispatch_tir_operator.py index 93b72256c2..5aafa6d8fd 100644 --- a/mlc_llm/dispatch/dispatch_tir_operator.py +++ b/mlc_llm/dispatch/dispatch_tir_operator.py @@ -19,6 +19,9 @@ def __init__(self, model: str): elif model == "rwkv": lookup = None + elif model == "mpt": + lookup = None + else: raise ValueError(f"Model {model} not supported") self.lookup = lookup From eb8ea82fca67a97d617a04829f198a91e643c472 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 20 Jun 2023 10:04:41 +0300 Subject: [PATCH 067/114] add dummy create_kv_cache_func method to support mlc_chat_cli --- mlc_llm/relax_model/mpt/mpt.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index 79afe6cdb3..851b6a1db7 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -965,6 +965,29 @@ def create_decoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> Dict[int, return pidx2pname + +# Dummy func to support mlc_chat_cli 
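A side note on the dtype handling above: the `bfloat16_to_float16.py` script from patches 063 and 065 exists because NumPy, and therefore the torch-to-NumPy-to-TVM import path, has no bfloat16 representation. The same cast can be applied per tensor; a minimal sketch assuming only that torch is installed:

```python
import torch


def to_fp16_numpy(t: torch.Tensor):
    """Return a float16 NumPy copy of a tensor that may be stored as bfloat16."""
    if t.dtype == torch.bfloat16:
        t = t.to(torch.float16)  # NumPy cannot hold bfloat16, so cast in torch first
    return t.detach().cpu().numpy()


w = torch.ones(4, 4, dtype=torch.bfloat16)
print(to_fp16_numpy(w).dtype)  # float16
```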
+def create_kv_cache_func(bb: relax.BlockBuilder, config: MPTConfig) -> None: + init_shape = relax.ShapeExpr((1,)) + with bb.function("create_kv_cache", []): + with bb.dataflow(): + zeros = bb.emit(relax.op.zeros(init_shape, config.dtype)) + caches = [] + f_kv_cache_create = relax.extern("vm.builtin.attention_kv_cache_create") + for _ in range(config.n_layers * 2): + caches.append( + bb.emit( + relax.Call( + f_kv_cache_create, + args=[zeros, init_shape, relax.PrimValue(0)], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + ) + gv = bb.emit_output(caches) + bb.emit_func_output(gv) + + def get_model(args, hf_config): model_name = args.model assert model_name.startswith("mpt-") , f"Unsupported model name: {model_name}" @@ -981,6 +1004,7 @@ def get_model(args, hf_config): bb = relax.BlockBuilder() pidx2pname = create_decoding_func(bb, config) + create_kv_cache_func(bb, config) create_metadata_func( bb, model_name=model_name, From 4b8f4e8450a2670fe8a5233c6085cfa9ea4a85f4 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 20 Jun 2023 16:59:50 +0300 Subject: [PATCH 068/114] debug log --- build.py | 1 + 1 file changed, 1 insertion(+) diff --git a/build.py b/build.py index 2abbf00426..8330bdf6e1 100644 --- a/build.py +++ b/build.py @@ -430,6 +430,7 @@ def main(): mod = pickle.load(pkl) dump_split_tir(mod, ARGS) if not ARGS.reuse_lib: + print("MOD before BUILD:", mod) build(mod, ARGS) else: print("Reuse existing prebuilt lib {ARGS.reuse_lib}...") From fa6db2d4917369d70e5491e5871d0e5332554c15 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 20 Jun 2023 17:04:22 +0300 Subject: [PATCH 069/114] add create_softmax_func for MPT --- mlc_llm/relax_model/mpt/mpt.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index 851b6a1db7..084beb36e4 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -988,6 +988,19 @@ def create_kv_cache_func(bb: relax.BlockBuilder, config: MPTConfig) -> None: bb.emit_func_output(gv) +def create_softmax_func(bb: relax.BlockBuilder, config: MPTConfig) -> None: + with bb.function("softmax_with_temperature"): + logits = nn.Placeholder( + (1, 1, config.vocab_size), dtype="float32", name="logits" + ) + temperature = nn.Placeholder((), dtype="float32", name="temperature") + with bb.dataflow(): + div = bb.emit(relax.op.divide(logits, temperature)) + softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) + gv = bb.emit_output(softmax) + bb.emit_func_output(gv, [logits, temperature]) + + def get_model(args, hf_config): model_name = args.model assert model_name.startswith("mpt-") , f"Unsupported model name: {model_name}" @@ -1005,6 +1018,7 @@ def get_model(args, hf_config): bb = relax.BlockBuilder() pidx2pname = create_decoding_func(bb, config) create_kv_cache_func(bb, config) + create_softmax_func(bb, config) create_metadata_func( bb, model_name=model_name, From 8a72a9edb71e9ce08e6cf8266831ac4b285f0738 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 20 Jun 2023 19:02:15 +0300 Subject: [PATCH 070/114] update transform --- build.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/build.py b/build.py index 8330bdf6e1..c2141316c0 100644 --- a/build.py +++ b/build.py @@ -277,6 +277,8 @@ def mod_transform_before_build( elif ARGS.model.startswith("mpt-"): model_names = [ "decode", + "create_kv_cache", + "softmax_with_temperature", "get_metadata", ] else: From 20ab52dfeade09edf2a753ae7fe3263c1911820e Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 20 Jun 
2023 19:13:01 +0300 Subject: [PATCH 071/114] remove debug log --- build.py | 1 - 1 file changed, 1 deletion(-) diff --git a/build.py b/build.py index c2141316c0..7ff61e1373 100644 --- a/build.py +++ b/build.py @@ -432,7 +432,6 @@ def main(): mod = pickle.load(pkl) dump_split_tir(mod, ARGS) if not ARGS.reuse_lib: - print("MOD before BUILD:", mod) build(mod, ARGS) else: print("Reuse existing prebuilt lib {ARGS.reuse_lib}...") From 7499a90d71ac4789341d05fb5335a9e131ca8d34 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 20 Jun 2023 19:43:04 +0300 Subject: [PATCH 072/114] add conversation for MPT --- cpp/conv_templates.cc | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/cpp/conv_templates.cc b/cpp/conv_templates.cc index 91b6893b46..98ef7424c7 100644 --- a/cpp/conv_templates.cc +++ b/cpp/conv_templates.cc @@ -295,6 +295,25 @@ Conversation CodeGPT() { return conv; } +Conversation MPT() { + Conversation conv; + conv.name = "mpt"; + conv.system = ""; + conv.roles = {"client", "instructor"}; + conv.messages = {}; + conv.separator_style = SeparatorStyle::kSepRoleMsg; + conv.offset = 0; + conv.seps = {"\n"}; + conv.role_msg_sep = ": "; + conv.role_empty_sep = "?"; + conv.stop_str = "stop"; + // TODO(mlc-team): add eos to mlc-chat-config + // and remove eos from stop token setting. + conv.stop_tokens = {0}; + conv.add_bos = false; + return conv; +} + } // namespace using ConvFactory = Conversation (*)(); @@ -312,6 +331,7 @@ Conversation Conversation::FromTemplate(const std::string& name) { {"moss", MOSS}, {"LM", VanillaLM}, {"code_gpt", CodeGPT}, + {"mpt", MPT}, }; auto it = factory.find(name); if (it == factory.end()) { From be4977f6b2a7a3a97d5f0935fbc8a3ebfb9b4500 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 21 Jun 2023 09:51:28 +0300 Subject: [PATCH 073/114] remove kv_cache_func --- mlc_llm/relax_model/mpt/mpt.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index 084beb36e4..e6bdf17e0a 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -966,28 +966,6 @@ def create_decoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> Dict[int, return pidx2pname -# Dummy func to support mlc_chat_cli -def create_kv_cache_func(bb: relax.BlockBuilder, config: MPTConfig) -> None: - init_shape = relax.ShapeExpr((1,)) - with bb.function("create_kv_cache", []): - with bb.dataflow(): - zeros = bb.emit(relax.op.zeros(init_shape, config.dtype)) - caches = [] - f_kv_cache_create = relax.extern("vm.builtin.attention_kv_cache_create") - for _ in range(config.n_layers * 2): - caches.append( - bb.emit( - relax.Call( - f_kv_cache_create, - args=[zeros, init_shape, relax.PrimValue(0)], - sinfo_args=[relax.ObjectStructInfo()], - ) - ) - ) - gv = bb.emit_output(caches) - bb.emit_func_output(gv) - - def create_softmax_func(bb: relax.BlockBuilder, config: MPTConfig) -> None: with bb.function("softmax_with_temperature"): logits = nn.Placeholder( @@ -1017,7 +995,6 @@ def get_model(args, hf_config): bb = relax.BlockBuilder() pidx2pname = create_decoding_func(bb, config) - create_kv_cache_func(bb, config) create_softmax_func(bb, config) create_metadata_func( bb, From 9e5857777dc29686183e8af6165f0758f0d371f8 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 21 Jun 2023 10:13:52 +0300 Subject: [PATCH 074/114] remove kv_cache from the list of funcs for mpt --- build.py | 1 - 1 file changed, 1 deletion(-) diff --git a/build.py b/build.py index 
7ff61e1373..16602042c4 100644 --- a/build.py +++ b/build.py @@ -277,7 +277,6 @@ def mod_transform_before_build( elif ARGS.model.startswith("mpt-"): model_names = [ "decode", - "create_kv_cache", "softmax_with_temperature", "get_metadata", ] From 89b216c7a50c7005fc715fae814b2d5f5cb4d457 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 21 Jun 2023 10:31:22 +0300 Subject: [PATCH 075/114] cast logits to float32 before softmax with temperature --- mlc_llm/relax_model/mpt/mpt.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index e6bdf17e0a..b3364626c1 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -886,6 +886,9 @@ def forward( ) logits = nn.emit(relax.op.linear(outputs[0], self.transformer.wte.weight)) + if logits.struct_info.dtype != "float32": + logits = nn.emit(relax.op.astype(logits, "float32")) + return logits, outputs[1] def fsdp_wrap_fn(self, module): From e7d1ef87767a73a41eb1a9c5846e54d5c9ebbecb Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 21 Jun 2023 12:51:35 +0300 Subject: [PATCH 076/114] debug log: check contiguous weight --- mlc_llm/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py index c653ee6fb0..7ae6c09c0e 100644 --- a/mlc_llm/utils.py +++ b/mlc_llm/utils.py @@ -229,6 +229,8 @@ def get_item(i): raw_param = torch_params[torch_param_name].detach().cpu().numpy() del torch_params[torch_param_name] + if not raw_param.flags['C_CONTIGUOUS']: + print("NON_CONTIGUOUS TENSOR WAS FOUND:", torch_param_name) for param_name, param in f_convert_param_bkwd( torch_param_name, raw_param ): From bee5ffb783f1a9907d32743c1267550106b47daa Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 21 Jun 2023 15:42:39 +0300 Subject: [PATCH 077/114] debug log for logits --- cpp/llm_chat.cc | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index b56fd88c68..19e3358f33 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -783,6 +783,18 @@ class LLMChat { ICHECK(logits_on_cpu_.defined()) << "logits_on_cpu_ is not defined"; ICHECK_EQ(logits_on_cpu_->ndim, 3) << "logits_on_cpu_ should be 3D"; ICHECK_EQ(logits_on_cpu_->shape[0], 1) << "logits_on_cpu_ should be 1 batch"; + + for (int i = 0; i < logits_on_cpu_->ndim; ++i) { + std::cout << "LOGITS SHAPE[" << i << "] = " << logits_on_cpu_->shape[i] << " "; + } + std::cout << std::endl; + int64_t ndata = logits_on_cpu_->shape[logits_on_cpu_->ndim - 1]; + const float* p_prob = static_cast(logits_on_cpu_->data); + std::cout << "Logits data: "; + for (int i = 0; i < ndata; ++i) { + std::cout << p_prob[i] << " "; + } + std::cout << std::endl; return fsample_topp_from_prob_(logits_on_cpu_, top_p_, GetRandomNumber()); } From c96d074a69f012fc06500df8c6546cb6beef0c46 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 22 Jun 2023 16:23:32 +0300 Subject: [PATCH 078/114] update readme and prepare_inputs_for_generation based on mpt model specific --- mlc_llm/relax_model/mpt/README.md | 122 +++++++++++++++++++++++++++++- mlc_llm/relax_model/mpt/mpt.py | 16 +--- 2 files changed, 122 insertions(+), 16 deletions(-) diff --git a/mlc_llm/relax_model/mpt/README.md b/mlc_llm/relax_model/mpt/README.md index 6e55766d8e..0169b3a68f 100644 --- a/mlc_llm/relax_model/mpt/README.md +++ b/mlc_llm/relax_model/mpt/README.md @@ -59,11 +59,129 @@ The original config for the model: "tokenizer_name": "EleutherAI/gpt-neox-20b", "torch_dtype": "bfloat16", 
"transformers_version": "4.28.1", - "use_cache": false, + **"use_cache": false,** "verbose": 0, "vocab_size": 50432 } -This config wraps default one. It should highlight two defaults parameters: +This config wraps default one (see below). It should highlight two defaults parameters: "is_encoder_decoder": false, "use_cache": false, + +Default config parameters (PretrainedConfig): +"return_dict": True +"output_hidden_states": False +"output_attentions": False +"torchscript": False +"torch_dtype": None +"use_bfloat16": False +"tf_legacy_loss": False +"pruned_heads": {} +"tie_word_embeddings": True + +**"is_encoder_decoder": False** +"is_decoder": False +"cross_attention_hidden_size": None +"add_cross_attention": False +"tie_encoder_decoder": False + +"max_length": 20 +"min_length": 0 +"do_sample": False +"early_stopping": False +"num_beams": 1 +"num_beam_groups": 1 +"diversity_penalty": 0.0 +"temperature": 1.0 +"top_k": 50 +"top_p": 1.0 +"typical_p": 1.0 +"repetition_penalty": 1.0 +"length_penalty": 1.0 +"no_repeat_ngram_size": 0 +"encoder_no_repeat_ngram_size": 0 +"bad_words_ids": None +"num_return_sequences": 1 +"chunk_size_feed_forward": 0 +"output_scores": False +"return_dict_in_generate": False +"forced_bos_token_id": None +"forced_eos_token_id": None +"remove_invalid_values": False +"exponential_decay_length_penalty": None +"suppress_tokens": None +"begin_suppress_tokens": None + +"architectures": None +"finetuning_task": None +"id2label": None +"label2id": None +if self.id2label is not None: + "num_labels": None + id2label = dict((int(key), value) for key, value in id2label.items()) +else: + "num_labels": 2 + +"tokenizer_class": None +"prefix": None +"bos_token_id": None +"pad_token_id": None +"eos_token_id": None +"sep_token_id": None + +"decoder_start_token_id": None + +"task_specific_params": None + + +Refactored greedy_search method for MPT-7b-instruct: +```python +def greedy_search(...): + # init values + logits_processor = LogitsProcessorList() + stopping_criteria = stopping_criteria # max_length and max_time criteria + pad_token_id = None + eos_token_id = None + output_scores = False + output_attentions = False + output_hidden_states = False + return_dict_in_generate = False + + # init attention / hidden states / scores tuples + scores = None + decoder_attentions = None + decoder_hidden_states = None + + # keep track of which sequences are already finished + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + + while True: + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution. 
Due to logits_processor is empty next_tokens_scores = next_token_logits + next_tokens_scores = logits_processor(input_ids, next_token_logits) + + # argmax + next_tokens = torch.argmax(next_tokens_scores, dim=-1) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=False + ) + + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + break +``` \ No newline at end of file diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index b3364626c1..3f3c4cb101 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -898,28 +898,16 @@ def activation_checkpointing_fn(self, module): return isinstance(module, MPTBlock) def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): - if inputs_embeds is not None: - raise NotImplementedError('inputs_embeds is not implemented for MPT yet') attention_mask = kwargs['attention_mask'].bool() if attention_mask[:, -1].sum() != attention_mask.shape[0]: raise NotImplementedError('MPT does not support generation with right padding.') - if self.transformer.attn_uses_sequence_id and self.training: - # slicing input_ids[:1] - input_ids_slice = nn.emit(relax.op.strided_slice(input_ids, [0], [relax.const(0)], [relax.const(1)])) - sequence_id = nn.emit(relax.op.zeros_like(input_ids_slice)) - else: - sequence_id = None + sequence_id = None if past_key_values is not None: # slicing input_ids[:, -1] dim1_len = input_ids.struct_info.shape[1] input_ids_slice = nn.emit(relax.op.strided_slice(input_ids, [1], [dim1_len - 1], [dim1_len])) input_ids = nn.emit(relax.op.expand_dims(input_ids_slice, axis=-1)) - if self.transformer.prefix_lm: - prefix_mask = nn.emit(relax.op.ones_like(attention_mask, self.dtype)) - if kwargs.get('use_cache') == False: - raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.') - else: - prefix_mask = None + prefix_mask = None return { 'input_ids': input_ids, 'attention_mask': attention_mask, From 6b6fdf171d29de37b410e033a0e2523bf94fd530 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 27 Jun 2023 21:56:33 +0300 Subject: [PATCH 079/114] flash attn implementation was transferred to outside mpt model implementation --- mlc_llm/relax_model/mpt/mpt.py | 257 ++++++++++----------------------- 1 file changed, 73 insertions(+), 184 deletions(-) diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index 3f3c4cb101..e22fcf1605 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -163,188 +163,78 @@ def check_valid_inputs(*tensors, valid_dtypes=["float16", "bfloat16"]): # raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).') -######################### FLASH ATTENTION IMPLEMENTATION TYPE FLASH (BEGIN) ########################## - -# class IndexFirstAxis(torch.autograd.Function): - -# @staticmethod -# def forward(ctx, input, indices): -# ctx.save_for_backward(indices) -# assert input.ndim >= 2 -# ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] -# second_dim = other_shape.numel() -# # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. -# # return input[indices] -# return torch.gather(rearrange(input, 'b ... 
-> b (...)'), 0, -# repeat(indices, 'z -> z d', d=second_dim)).reshape(-1, *other_shape) - - -# index_first_axis = IndexFirstAxis.apply - -# import torch -# class IndexPutFirstAxis(torch.autograd.Function): - -# @staticmethod -# def forward(ctx, values, indices, first_axis_dim): -# ctx.save_for_backward(indices) -# assert indices.ndim == 1 -# assert values.ndim >= 2 -# output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, -# dtype=values.dtype) -# # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. -# output[indices] = values -# # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) -# return output - - -# index_put_first_axis = IndexPutFirstAxis.apply - - -# def bert_padding_unpad_input(hidden_states, attention_mask): -# seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) -# indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() -# max_seqlen_in_batch = seqlens_in_batch.max().item() -# cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) -# # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the -# # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim -# # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to -# # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, -# # so we write custom forward and backward to make it a bit faster. -# return (index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'), indices), indices, -# cu_seqlens, max_seqlen_in_batch) - - -# def bert_padding_pad_input(hidden_states, indices, batch, seqlen): -# """ -# Arguments: -# hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. -# indices: (total_nnz) -# Return: -# hidden_states: (batch, seqlen, ...) -# """ -# dim = hidden_states.shape[-1] -# # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) -# # output[indices] = hidden_states -# output = index_put_first_axis(hidden_states, indices, batch * seqlen) -# return rearrange(output, '(b s) ... -> b s ...', b=batch) - -# import flash_attn_cuda -# def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, -# dropout_p, softmax_scale, causal, return_softmax, num_splits=0, -# generator=None): -# """ -# num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means -# it will be set by an internal heuristic. We're exposing num_splits mostly for benchmarking. -# Don't change it unless you know what you're doing. 
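Both the commented-out code removed here and the new `flash_attn_fn` that replaces it below rest on the same unpadding bookkeeping: flatten the batch, keep only the non-padding tokens, and hand flash attention the cumulative sequence lengths. A NumPy sketch of that bookkeeping (names and shapes are illustrative, not the `mha_flash_attn` API):

```python
import numpy as np

# (batch, seq) attention mask: 1 marks a real token, 0 marks padding.
attention_mask = np.array([[1, 1, 1, 0],
                           [1, 1, 0, 0]], dtype=np.int32)
hidden = np.random.rand(2, 4, 8).astype("float32")       # (batch, seq, d_model)

seqlens = attention_mask.sum(axis=-1)                     # [3, 2]
indices = np.flatnonzero(attention_mask.reshape(-1))      # flat positions of real tokens
cu_seqlens = np.concatenate([[0], np.cumsum(seqlens)])    # [0, 3, 5] sequence boundaries
max_seqlen = int(seqlens.max())

# "Unpadded" activations: (total_tokens, d_model) with padding rows dropped,
# which is the layout the unpadded flash-attention kernel consumes.
hidden_unpad = hidden.reshape(-1, 8)[indices]
assert hidden_unpad.shape == (5, 8) and max_seqlen == 3
```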
-# """ -# softmax_lse, rng_state, *rest = flash_attn_cuda.fwd( -# q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, -# softmax_scale, False, causal, return_softmax, num_splits, generator -# ) -# # if out.isnan().any() or softmax_lse.isnan().any(): -# # breakpoint() -# S_dmask = rest[0] if return_softmax else None -# return out, softmax_lse, rng_state, S_dmask - - -# class FlashAttnFunc(torch.autograd.Function): -# @staticmethod -# def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, -# softmax_scale, causal, return_softmax, deterministic): -# if softmax_scale is None: -# softmax_scale = q.shape[-1] ** (-0.5) -# out, softmax_lse, rng_state, S_dmask = _flash_attn_forward( -# q, k, v, torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, -# dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax -# ) -# ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state) -# ctx.dropout_p = dropout_p -# ctx.max_seqlen_q = max_seqlen_q -# ctx.max_seqlen_k = max_seqlen_k -# ctx.softmax_scale = softmax_scale -# ctx.causal = causal -# ctx.deterministic = deterministic -# return out if not return_softmax else (out, softmax_lse, S_dmask) - - -# def flash_attn_interface_flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, -# dropout_p, softmax_scale=None, causal=False, return_attn_probs=False, -# deterministic=False -# ): -# return FlashAttnFunc.apply(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, -# dropout_p, softmax_scale, causal, return_attn_probs, deterministic) - - -# def flash_attn_fn( -# query, -# key, -# value, -# n_heads, -# d_model, -# past_key_value=None, -# softmax_scale=None, -# attn_bias=None, -# key_padding_mask=None, -# is_causal=False, -# needs_weights=False, -# multiquery=False -# ): -# check_valid_inputs(query, key, value) -# if past_key_value is not None: -# if len(past_key_value) != 0: -# key = nn.emit(relax.op.concat([past_key_value[0], key], axis=1)) -# value = nn.emit(relax.op.concat([past_key_value[1], value], axis=1)) -# past_key_value = (key, value) -# if attn_bias is not None: -# _s_q = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[2] - query.struct_info.shape[1]) -# _s_k = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[3] - key.struct_info.shape[1]) -# # slicing attn_bias[:, :, _s_q:, _s_k:] -# s_q_end, s_k_end = attn_bias.struct_info.shape[-2:] -# attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [2, 3], [_s_q, _s_k], [s_q_end, s_k_end])) -# if attn_bias is not None: -# raise NotImplementedError(f'attn_bias not implemented for flash attn.') -# batch_size, seqlen = query.struct_info.shape[:2] -# if key_padding_mask is None: -# key_shape = key.struct_info.shape[:2] -# key_padding_mask = nn.emit(relax.op.ones(Tuple(*key_shape, 1), dtype="bool")) -# # slicing key_padding_mask[:, -query.struct_info.shape[1]:] -# dim1_length = key_padding_mask.struct_info.shape[1] -# query_padding_mask = nn.emit(relax.op.strided_slice(key_padding_mask, [1], [dim1_length - seqlen], [dim1_length])) -# (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding_unpad_input(query, query_padding_mask) - -# qnnz, _, _ = query_unpad.struct_info.shape -# query_unpad = nn.emit(relax.op.reshape( -# query_unpad, -# (qnnz, n_heads, d_model), -# )) # (nnz, (h d)) -> (nnz, h, d) - -# kv_nnz, _, _ = key_unpad.struct_info.shape -# kv_n_heads = 1 if multiquery else n_heads -# (key_unpad, 
_, cu_seqlens_k, max_seqlen_k) = bert_padding_unpad_input(key, key_padding_mask) -# key_unpad = nn.emit(relax.op.reshape( -# key_unpad, -# (kv_nnz, kv_n_heads, d_model), -# )) # (nnz, (h d)) -> (nnz, h, d) -# (value_unpad, _, _, _) = bert_padding_unpad_input(value, key_padding_mask) -# value_unpad = nn.emit(relax.op.reshape( -# value_unpad, -# (kv_nnz, kv_n_heads, d_model), -# )) # (nnz, (h d)) -> (nnz, h, d) - -# if multiquery: -# key_unpad = relax.op.broadcast_to(key_unpad, (kv_nnz, n_heads, d_model)) -# value_unpad = relax.op.broadcast_to(value_unpad, (kv_nnz, n_heads, d_model)) -# reset_is_causal = _reset_is_causal(query.struct_info.shape[1], key.struct_info.shape[1], is_causal) -# output_unpad = flash_attn_interface_flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, 0.0, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights) -# nnz, _, _ = output_unpad.struct_info.shape -# output_unpad = nn.emit(relax.op.reshape( -# output_unpad, -# (nnz, n_heads*d_model), -# )) # (nnz, h, d)) -> (nnz, (h d)) -# output = bert_padding_pad_input(output_unpad, indices_q, batch_size, seqlen) -# return (output, None, past_key_value) - -######################### FLASH ATTENTION IMPLEMENTATION TYPE FLASH (END) ########################## +def flash_attn_fn( + query, + key, + value, + n_heads, + d_model, + past_key_value=None, + softmax_scale=None, + attn_bias=None, + key_padding_mask=None, + is_causal=False, + needs_weights=False, + multiquery=False +): + from ..flash_attn import bert_padding_unpad_input, bert_padding_pad_input, flash_attn_unpadded_func + check_valid_inputs(query, key, value) + if past_key_value is not None: + if len(past_key_value) != 0: + key = nn.emit(relax.op.concat([past_key_value[0], key], axis=1)) + value = nn.emit(relax.op.concat([past_key_value[1], value], axis=1)) + past_key_value = (key, value) + batch_size = query.struct_info.shape[0] + seqlen = query.struct_info.shape[1] + key_shape_d1 = key.struct_info.shape[1] + if attn_bias is not None: + _s_q = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[2] - seqlen) + _s_k = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[3] - key_shape_d1) + # slicing attn_bias[:, :, _s_q:, _s_k:] + s_q_end = attn_bias.struct_info.shape[2] + s_k_end = attn_bias.struct_info.shape[3] + attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [2, 3], [_s_q, _s_k], [s_q_end, s_k_end])) + if attn_bias is not None: + raise NotImplementedError(f'attn_bias not implemented for flash attn.') + if key_padding_mask is None: + key_shape_d0 = key.struct_info.shape[0] + key_padding_mask = nn.emit(relax.op.ones(Tuple(key_shape_d0, key_shape_d1, 1), dtype="bool")) + # slicing key_padding_mask[:, -query.struct_info.shape[1]:] + dim1_length = key_padding_mask.struct_info.shape[1] + query_padding_mask = nn.emit(relax.op.strided_slice(key_padding_mask, [1], [dim1_length - seqlen], [dim1_length])) + (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding_unpad_input(query, query_padding_mask) + + qnnz = query_unpad.struct_info.shape[0] + query_unpad = nn.emit(relax.op.reshape( + query_unpad, + (qnnz, n_heads, d_model), + )) # (nnz, (h d)) -> (nnz, h, d) + + kv_nnz = key_unpad.struct_info.shape[0] + kv_n_heads = 1 if multiquery else n_heads + (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding_unpad_input(key, key_padding_mask) + key_unpad = nn.emit(relax.op.reshape( + key_unpad, + (kv_nnz, kv_n_heads, d_model), + )) # (nnz, (h d)) -> 
(nnz, h, d) + (value_unpad, _, _, _) = bert_padding_unpad_input(value, key_padding_mask) + value_unpad = nn.emit(relax.op.reshape( + value_unpad, + (kv_nnz, kv_n_heads, d_model), + )) # (nnz, (h d)) -> (nnz, h, d) + + if multiquery: + key_unpad = relax.op.broadcast_to(key_unpad, (kv_nnz, n_heads, d_model)) + value_unpad = relax.op.broadcast_to(value_unpad, (kv_nnz, n_heads, d_model)) + reset_is_causal = _reset_is_causal(query.struct_info.shape[1], key_shape_d1, is_causal) + output_unpad = flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, 0.0, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights) + nnz = output_unpad.struct_info.shape[0] + output_unpad = nn.emit(relax.op.reshape( + output_unpad, + (nnz, n_heads*d_model), + )) # (nnz, h, d)) -> (nnz, (h d)) + output = bert_padding_pad_input(output_unpad, indices_q, batch_size, seqlen) + return (output, None, past_key_value) ######################### FLASH ATTENTION IMPLEMENTATION TYPE TRITON (BEGIN) ########################## @@ -462,8 +352,7 @@ def __init__( self.q_ln = layernorm_class(self.d_model, dtype) self.k_ln = layernorm_class(self.d_model, dtype) if self.attn_impl == 'flash': - raise NotImplemented("Flash type of flash attention has not been implemented yet") - # self.attn_fn = flash_attn_fn + self.attn_fn = flash_attn_fn elif self.attn_impl == 'triton': # While `attn_impl: triton` can be faster than `attn_impl: flash` it uses more memory. # When training larger models this can trigger alloc retries which hurts performance. From 0be3602b51be9819d094de89bd076e0eb2ab3175 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 28 Jun 2023 08:53:51 +0300 Subject: [PATCH 080/114] correct import after rename file --- mlc_llm/relax_model/mpt/mpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index e22fcf1605..f8f229c516 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -177,7 +177,7 @@ def flash_attn_fn( needs_weights=False, multiquery=False ): - from ..flash_attn import bert_padding_unpad_input, bert_padding_pad_input, flash_attn_unpadded_func + from ..mha_flash_attn import bert_padding_unpad_input, bert_padding_pad_input, flash_attn_unpadded_func check_valid_inputs(query, key, value) if past_key_value is not None: if len(past_key_value) != 0: From 51ce1adcb89ef55ede7efe7dfd3fb231a4e41ebe Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 3 Jul 2023 09:02:00 +0300 Subject: [PATCH 081/114] set temperature to zero by default for mpt model --- build.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/build.py b/build.py index 16602042c4..48ba7fc4d3 100644 --- a/build.py +++ b/build.py @@ -343,6 +343,10 @@ def dump_default_mlc_chat_config(args): config["shift_fill_factor"] = 0.3 config["tokenizer_files"] = utils.get_tokenizer_files(params_path) + # TODO(vchernov): create mechanism which gets default config prepared for specific model and covers this one + if args.model_category == "mpt": + config["temperature"] = 0.0 + dump_path = os.path.join(params_path, "mlc-chat-config.json") with open(dump_path, "w", encoding="utf-8") as outfile: json.dump(config, outfile, indent=4) From a870ba111dfe248c4d0630f737c64a683f898aa0 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 4 Jul 2023 13:55:02 +0300 Subject: [PATCH 082/114] transfer attention_mask preprocessing to decoder forward. 
update prepare_inputs_for_generation --- mlc_llm/relax_model/mpt/mpt.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index f8f229c516..04806467fe 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -656,6 +656,13 @@ def forward( use_cache = use_cache if use_cache is not None else self.use_cache if attention_mask is not None: attention_mask = nn.emit(relax.op.astype(attention_mask, "bool")) + # TODO(vchernov): I'm not sure we should calculate it and can compare in Relax + # it is part from prepare_inputs_for_generation + dim1_len = attention_mask.struct_info.shape[1] + if relax.op.sum( + relax.op.strided_slice(attention_mask, [1], [dim1_len - 1], [dim1_len]) + ) != attention_mask.struct_info.shape[0]: + raise NotImplementedError('MPT does not support generation with right padding.') if prefix_mask is not None: prefix_mask = nn.emit(relax.op.astype(prefix_mask, "bool")) if not return_dict: @@ -787,21 +794,13 @@ def activation_checkpointing_fn(self, module): return isinstance(module, MPTBlock) def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): - attention_mask = kwargs['attention_mask'].bool() - if attention_mask[:, -1].sum() != attention_mask.shape[0]: - raise NotImplementedError('MPT does not support generation with right padding.') - sequence_id = None if past_key_values is not None: # slicing input_ids[:, -1] dim1_len = input_ids.struct_info.shape[1] input_ids_slice = nn.emit(relax.op.strided_slice(input_ids, [1], [dim1_len - 1], [dim1_len])) input_ids = nn.emit(relax.op.expand_dims(input_ids_slice, axis=-1)) - prefix_mask = None return { 'input_ids': input_ids, - 'attention_mask': attention_mask, - 'prefix_mask': prefix_mask, - 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True) } From f44c4db39e33942faf4cf022c14ed74b604b8c73 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 4 Jul 2023 14:16:40 +0300 Subject: [PATCH 083/114] prepare_inputs_for_generation was fully transferred to forward pass --- mlc_llm/relax_model/mpt/mpt.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index 04806467fe..b78a799512 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -657,7 +657,7 @@ def forward( if attention_mask is not None: attention_mask = nn.emit(relax.op.astype(attention_mask, "bool")) # TODO(vchernov): I'm not sure we should calculate it and can compare in Relax - # it is part from prepare_inputs_for_generation + # It is part from prepare_inputs_for_generation dim1_len = attention_mask.struct_info.shape[1] if relax.op.sum( relax.op.strided_slice(attention_mask, [1], [dim1_len - 1], [dim1_len]) @@ -769,6 +769,13 @@ def forward( ): return_dict = return_dict if return_dict is not None else self.return_dict use_cache = use_cache if use_cache is not None else self.use_cache + + # It is part from prepare_inputs_for_generation + if past_key_values is not None: + # slicing input_ids[:, -1] + dim1_len = input_ids.struct_info.shape[1] + input_ids_slice = nn.emit(relax.op.strided_slice(input_ids, [1], [dim1_len - 1], [dim1_len])) + input_ids = nn.emit(relax.op.expand_dims(input_ids_slice, axis=-1)) outputs = self.transformer( input_ids=input_ids, past_key_values=past_key_values, @@ -793,18 +800,6 @@ def fsdp_wrap_fn(self, 
module): def activation_checkpointing_fn(self, module): return isinstance(module, MPTBlock) - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): - if past_key_values is not None: - # slicing input_ids[:, -1] - dim1_len = input_ids.struct_info.shape[1] - input_ids_slice = nn.emit(relax.op.strided_slice(input_ids, [1], [dim1_len - 1], [dim1_len])) - input_ids = nn.emit(relax.op.expand_dims(input_ids_slice, axis=-1)) - return { - 'input_ids': input_ids, - 'past_key_values': past_key_values, - 'use_cache': kwargs.get('use_cache', True) - } - @staticmethod def _reorder_cache(past_key_values, beam_idx): """Used by HuggingFace generate when using beam search with kv-caching. From 2188310eacfecee2a7817387646e0b282a7df140 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 4 Jul 2023 14:24:09 +0300 Subject: [PATCH 084/114] remove excess methods --- mlc_llm/relax_model/mpt/mpt.py | 36 ---------------------------------- 1 file changed, 36 deletions(-) diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index b78a799512..6cfaa2ce52 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -737,24 +737,6 @@ def __init__(self, config: MPTConfig): self.return_dict = config.return_dict self.use_cache = config.use_cache - def get_input_embeddings(self): - return self.transformer.wte - - def set_input_embeddings(self, value): - self.transformer.wte = value - - def get_output_embeddings(self): - return self.transformer.wte - - def set_output_embeddings(self, new_embeddings): - self.transformer.wte = new_embeddings - - def set_decoder(self, decoder): - self.transformer = decoder - - def get_decoder(self): - return self.transformer - def forward( self, input_ids: relax.Expr, @@ -794,24 +776,6 @@ def forward( return logits, outputs[1] - def fsdp_wrap_fn(self, module): - return isinstance(module, MPTBlock) - - def activation_checkpointing_fn(self, module): - return isinstance(module, MPTBlock) - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - """Used by HuggingFace generate when using beam search with kv-caching. - See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133 - for an example in transformers. 
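With beam-search helpers such as `_reorder_cache` removed, what remains is the plain greedy loop described in the README: the growing token sequence is re-fed on every step, since this MPT build runs without a KV cache, and only the last position's logits choose the next token. A toy sketch of that loop, where `toy_decode` is a hypothetical stand-in for the compiled `decode` function:

```python
import numpy as np


def toy_decode(token_ids):
    """Stand-in for the compiled `decode` function: fake (1, seq, vocab) logits."""
    vocab_size = 16
    rng = np.random.default_rng(sum(token_ids))
    return rng.random((1, len(token_ids), vocab_size)).astype("float32")


tokens = [3, 7, 1]                               # prompt token ids (illustrative)
for _ in range(5):
    logits = toy_decode(tokens)                  # full sequence re-fed: no KV cache
    next_token = int(np.argmax(logits[0, -1]))   # greedy pick from the last position
    tokens.append(next_token)
print(tokens)
```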
- """ - reordered_past = [] - for layer_past in past_key_values: - reordered_past += [tuple((nn.emit(relax.op.take(past_state, beam_idx, 0)) for past_state in layer_past))] - return reordered_past - - def create_decoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> Dict[int, str]: pidx2pname: Dict[int, str] = {} with bb.function("decode"): From d56c11dfa2eb7fa6138fd89165f7f67c1070ade3 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 4 Jul 2023 14:27:12 +0300 Subject: [PATCH 085/114] remove excess methods once more --- mlc_llm/relax_model/mpt/mpt.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index 6cfaa2ce52..da3d1650fa 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -566,12 +566,6 @@ def __init__(self, config: MPTConfig): self.blocks = ModuleList([MPTBlock(config) for _ in range(config.n_layers)]) self.norm_f = norm_class(config.d_model, dtype=config.dtype) - def get_input_embeddings(self): - return self.wte - - def set_input_embeddings(self, value): - self.wte = value - def _attn_bias(self, dtype, attention_mask: Optional[relax.Expr]=None, prefix_mask: Optional[relax.Expr]=None, sequence_id: Optional[relax.Expr]=None): if not self._attn_bias_initialized: if self.attn_bias_shape: @@ -720,12 +714,6 @@ def forward( all_hidden_states = all_hidden_states + (x,) return x, past_key_values, all_hidden_states, all_self_attns - def fsdp_wrap_fn(self, module): - return isinstance(module, MPTBlock) - - def activation_checkpointing_fn(self, module): - return isinstance(module, MPTBlock) - class MPTForCausalLM(nn.Module): def __init__(self, config: MPTConfig): From e811c34fec41e202823a3a6b0198df8baef4368a Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 21 Jun 2023 08:34:24 +0300 Subject: [PATCH 086/114] init kv_cache only if need --- cpp/llm_chat.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 19e3358f33..35140a65c1 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -289,8 +289,11 @@ class LLMChat { << "Cannot find env function vm.builtin.attention_kv_cache_array_popn"; fkvcache_array_popn_ = *fkvcache_array_popn; - // Step 4. KV cache creation. - kv_cache_ = vm_->GetFunction("create_kv_cache")(); + // Step 4. KV cache creation if need. + auto kv_cache_func = vm_->GetFunction("create_kv_cache"); + if (kv_cache_func.defined()) { + kv_cache_ = kv_cache_func(); + } // Step 5. KV cache reset. 
reset_kv_cache_func_ = vm_->GetFunction("reset_kv_cache"); From e7c9350405f8ec9c8294f764ac5bf331ad51a41c Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 21 Jun 2023 08:53:52 +0300 Subject: [PATCH 087/114] do not use kv_cache during Forward if it is empty --- cpp/llm_chat.cc | 6 +++++- mlc_llm/relax_model/mpt/mpt.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 35140a65c1..88523a60b6 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -705,7 +705,11 @@ class LLMChat { for (int i = 0; i < input_tokens.size(); ++i) { NDArray input_data = this->GetInputTokenNDArray({input_tokens[i]}); int64_t pos = cur_pos + i + 1 - input_tokens.size(); - ret = decode_func_(input_data, ShapeTuple({pos}), kv_cache_, params_); + if (kv_cache_.empty()){ + ret = decode_func_(input_data, params_); + } else { + ret = decode_func_(input_data, ShapeTuple({pos}), kv_cache_, params_); + } } } return Downcast(ret[0]); diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index da3d1650fa..5ed8629406 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -862,4 +862,4 @@ def f_convert_param_bkwd(torch_pname: str, raw_param): args.f_convert_pname_fwd = f_convert_pname_fwd args.f_convert_param_bkwd = f_convert_param_bkwd - return mod, [None] * len(pidx2pname) \ No newline at end of file + return mod, [None] * len(pidx2pname) From 399bce4f9fa8bde64ecea07b7afd9061dd0766c3 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 4 Jul 2023 14:35:55 +0300 Subject: [PATCH 088/114] remove input_data from Decode_step due to it is not used and recalculated in Forward --- cpp/llm_chat.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 88523a60b6..f147c91889 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -532,7 +532,6 @@ class LLMChat { void DecodeStep() { ICHECK(!output_ids_.empty()); int32_t last_token = output_ids_.back(); - tvm::runtime::NDArray input_data = GetInputTokenNDArray({last_token}); auto tstart = std::chrono::high_resolution_clock::now(); From 991f9652a1903e2afbe3f7f8e6197af579b784ba Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 4 Jul 2023 14:56:07 +0300 Subject: [PATCH 089/114] test fixes --- cpp/llm_chat.cc | 3 +-- mlc_llm/relax_model/mpt/mpt.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index f147c91889..98509dbec6 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -535,8 +535,7 @@ class LLMChat { auto tstart = std::chrono::high_resolution_clock::now(); - NDArray logits_on_device = this->Forward({last_token}, total_seq_len_ + 1); - total_seq_len_ += 1; + NDArray logits_on_device = this->Forward({last_token}, ++total_seq_len_); int32_t next_token = this->SampleTokenFromLogits(logits_on_device, temperature_, top_p_); diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index 5ed8629406..8c8bdb1aaf 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -825,7 +825,7 @@ def get_model(args, hf_config): create_metadata_func( bb, model_name=model_name, - max_window_size=-1, # TODO: check + max_window_size=128, # TODO: temporal limit for max output length, change to -1 after tests stop_tokens=[0], # TODO: check for mpt embeddings add_prefix_space=False, # TODO: what is it? 
) From f940301b5f318b2417d99b229761565fb867a29f Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 4 Jul 2023 15:25:17 +0300 Subject: [PATCH 090/114] unroll method from generate in README --- mlc_llm/relax_model/mpt/README.md | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/mlc_llm/relax_model/mpt/README.md b/mlc_llm/relax_model/mpt/README.md index 0169b3a68f..7fc0629cfd 100644 --- a/mlc_llm/relax_model/mpt/README.md +++ b/mlc_llm/relax_model/mpt/README.md @@ -177,9 +177,25 @@ def greedy_search(...): # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=False - ) + # START model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False) + # update past + if "past_key_values" in outputs: + model_kwargs["past"] = outputs.past_key_values + else: + model_kwargs["past"] = None + + # update token_type_ids with last value + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + # FINISH self._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False) # stop when each sentence is finished, or if we exceed the maximum length if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): From 6a54257a2b4fd6c1e10ca8aa0d3bda23a18f1744 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 5 Jul 2023 16:19:37 +0300 Subject: [PATCH 091/114] remove test logs. 
add PrintLogits method --- cpp/llm_chat.cc | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 98509dbec6..76e52a51aa 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -589,12 +589,7 @@ class LLMChat { auto decoding_end = std::chrono::high_resolution_clock::now(); // print first few logits for eyeballs - std::ostringstream os; - for (int i = 0; i < 10; ++i) { - if (i != 0) os << ", "; - os << static_cast(logits_on_cpu_->data)[i]; - } - LOG(INFO) << "logits[:10] =[" << os.str() << "]"; + PrintLogits(10); double encoding_ms = static_cast((decoding_start - encoding_start).count()) / 1e6; double decoding_ms = static_cast((decoding_end - decoding_start).count()) / 1e6; @@ -603,6 +598,23 @@ class LLMChat { << "decoding-time=" << decoding_ms << "ms."; } +void PrintLogits(int logits_num = -1) { + std::string logits_num_tag = std::to_string(logits_num); + if (logits_num == -1) { + logits_num = logits_on_cpu_->shape[logits_on_cpu_->ndim - 1]; + logits_num_tag = ""; + } + std::ostringstream os; + const float* p_data = static_cast(logits_on_cpu_->data); + for (int i = 0; i < logits_num; ++i) { + if (i != 0) os << ", "; + os << p_data[i]; + } + // TODO(vchernov): after test return LOG(INFO) + std::cout << "LOGITS[:" << logits_num_tag << "] = [" << os.str() << "]" << std::endl; + // LOG(INFO) << "logits[:" << logits_num_tag << "] = [" << os.str() << "]"; +} + private: picojson::value SerializeConfigToJSONValue() const { picojson::object config; @@ -789,17 +801,6 @@ class LLMChat { ICHECK_EQ(logits_on_cpu_->ndim, 3) << "logits_on_cpu_ should be 3D"; ICHECK_EQ(logits_on_cpu_->shape[0], 1) << "logits_on_cpu_ should be 1 batch"; - for (int i = 0; i < logits_on_cpu_->ndim; ++i) { - std::cout << "LOGITS SHAPE[" << i << "] = " << logits_on_cpu_->shape[i] << " "; - } - std::cout << std::endl; - int64_t ndata = logits_on_cpu_->shape[logits_on_cpu_->ndim - 1]; - const float* p_prob = static_cast(logits_on_cpu_->data); - std::cout << "Logits data: "; - for (int i = 0; i < ndata; ++i) { - std::cout << p_prob[i] << " "; - } - std::cout << std::endl; return fsample_topp_from_prob_(logits_on_cpu_, top_p_, GetRandomNumber()); } From 6755d86cb57a6c380c9cbb4caabe3c6a461988f9 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 5 Jul 2023 16:26:59 +0300 Subject: [PATCH 092/114] print logits after copy to cpu --- cpp/llm_chat.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 76e52a51aa..928b6eb0b6 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -775,6 +775,7 @@ void PrintLogits(int logits_num = -1) { logits_on_cpu_.CopyFrom(logits_or_prob); } TVMSynchronize(device_.device_type, device_.device_id, nullptr); + this->PrintLogits(); } // Clear kv cache From aea1d83b9f6857e6d16b0774eafe33e619d0b7df Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 6 Jul 2023 08:53:56 +0300 Subject: [PATCH 093/114] print shape together with logits --- cpp/llm_chat.cc | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 928b6eb0b6..299a4d189d 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -599,11 +599,23 @@ class LLMChat { } void PrintLogits(int logits_num = -1) { + size_t ndim = logits_on_cpu_->ndim; std::string logits_num_tag = std::to_string(logits_num); if (logits_num == -1) { - logits_num = logits_on_cpu_->shape[logits_on_cpu_->ndim - 1]; + logits_num = logits_on_cpu_->shape[ndim - 1]; logits_num_tag = ""; } + // 
Print shape + std::ostringstream os_shape; + for (size_t i = 0; i < ndim; ++i) { + if (i != 0) os_shape << ", "; + os_shape << logits_on_cpu_->shape[i]; + } + // TODO(vchernov): after test return LOG(INFO) + std::cout << "LOGITS SHAPE = [" << os_shape.str() << "]" << std::endl; + // LOG(INFO) << "logits shape = [" << os_shape.str() << "]"; + + // Print specified number of values from logits std::ostringstream os; const float* p_data = static_cast(logits_on_cpu_->data); for (int i = 0; i < logits_num; ++i) { From 9cc2494a23329d2085dad6799cba9990d98fd262 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 6 Jul 2023 10:25:29 +0300 Subject: [PATCH 094/114] test log --- mlc_llm/relax_model/mpt/mpt.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index 8c8bdb1aaf..14cac27468 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -757,7 +757,9 @@ def forward( output_hidden_states=output_hidden_states, use_cache=use_cache ) - logits = nn.emit(relax.op.linear(outputs[0], self.transformer.wte.weight)) + # TODO: test workaround + logits = outputs[0] + # logits = nn.emit(relax.op.linear(outputs[0], self.transformer.wte.weight)) if logits.struct_info.dtype != "float32": logits = nn.emit(relax.op.astype(logits, "float32")) From 8880ed5894021a669a84021e4538e22161269a30 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 6 Jul 2023 12:08:00 +0300 Subject: [PATCH 095/114] print decode input --- cpp/llm_chat.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 299a4d189d..b4435e5cd3 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -533,6 +533,8 @@ class LLMChat { ICHECK(!output_ids_.empty()); int32_t last_token = output_ids_.back(); + std::cout << "LAST TOKEN TO DECODE: " << last_token << std::endl; + auto tstart = std::chrono::high_resolution_clock::now(); NDArray logits_on_device = this->Forward({last_token}, ++total_seq_len_); From 8abc8c24f0f95b9be304bea4af7a47fa087eebfd Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 6 Jul 2023 12:28:20 +0300 Subject: [PATCH 096/114] test log in prefill step --- cpp/llm_chat.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index b4435e5cd3..f840018539 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -516,6 +516,10 @@ class LLMChat { auto tstart = std::chrono::high_resolution_clock::now(); + for (int64_t i = 0; i < token_len; ++i) { + std::cout << "PROMPT TOKEN [" << i << "]: " << prompt_tokens[i] << std::endl; + } + int32_t new_seq_len = total_seq_len_ + token_len; NDArray logits_on_device = this->Forward(prompt_tokens, new_seq_len); total_seq_len_ = new_seq_len; From 1b1458ff7a0d70515a8651bcedc67008321236de Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 6 Jul 2023 12:36:08 +0300 Subject: [PATCH 097/114] print intermediate tensors in topology to catch nan generation --- mlc_llm/relax_model/mpt/mpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index 14cac27468..d2ae126baf 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -712,7 +712,7 @@ def forward( if output_hidden_states: assert all_hidden_states is not None all_hidden_states = all_hidden_states + (x,) - return x, past_key_values, all_hidden_states, all_self_attns + return tok_emb, past_key_values # x, past_key_values, all_hidden_states, all_self_attns class 
MPTForCausalLM(nn.Module): From bff07c322811a39e4f2344a6968076723b4a9559 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 6 Jul 2023 12:49:02 +0300 Subject: [PATCH 098/114] print intermediate tensors in topology to catch nan generation: attn_bias --- mlc_llm/relax_model/mpt/mpt.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index d2ae126baf..0332dbdc5c 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -712,7 +712,7 @@ def forward( if output_hidden_states: assert all_hidden_states is not None all_hidden_states = all_hidden_states + (x,) - return tok_emb, past_key_values # x, past_key_values, all_hidden_states, all_self_attns + return x, past_key_values, attn_bias # x, past_key_values, all_hidden_states, all_self_attns class MPTForCausalLM(nn.Module): @@ -764,7 +764,7 @@ def forward( if logits.struct_info.dtype != "float32": logits = nn.emit(relax.op.astype(logits, "float32")) - return logits, outputs[1] + return logits, outputs[1], outputs[2] def create_decoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> Dict[int, str]: pidx2pname: Dict[int, str] = {} @@ -773,7 +773,7 @@ def create_decoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> Dict[int, input_ids = nn.Placeholder((1, 1), dtype="int32", name="input_ids") with bb.dataflow(): - logits, states = model(input_ids) + logits, states, debug_tensor = model(input_ids) params = [ input_ids, ] + model.parameters() @@ -784,7 +784,7 @@ def create_decoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> Dict[int, assert param.same_as(params[i + 1]) if states is None: states = () - gv = bb.emit_output((logits, relax.Tuple(states))) + gv = bb.emit_output((debug_tensor, relax.Tuple(states))) bb.emit_func_output(gv, params) mod = bb.get() From 69c344c1692dcde5e9a50587961e0669f45b49f7 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 6 Jul 2023 14:41:44 +0300 Subject: [PATCH 099/114] revert some debug logs --- mlc_llm/relax_model/mpt/mpt.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index 0332dbdc5c..d2ae126baf 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -712,7 +712,7 @@ def forward( if output_hidden_states: assert all_hidden_states is not None all_hidden_states = all_hidden_states + (x,) - return x, past_key_values, attn_bias # x, past_key_values, all_hidden_states, all_self_attns + return tok_emb, past_key_values # x, past_key_values, all_hidden_states, all_self_attns class MPTForCausalLM(nn.Module): @@ -764,7 +764,7 @@ def forward( if logits.struct_info.dtype != "float32": logits = nn.emit(relax.op.astype(logits, "float32")) - return logits, outputs[1], outputs[2] + return logits, outputs[1] def create_decoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> Dict[int, str]: pidx2pname: Dict[int, str] = {} @@ -773,7 +773,7 @@ def create_decoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> Dict[int, input_ids = nn.Placeholder((1, 1), dtype="int32", name="input_ids") with bb.dataflow(): - logits, states, debug_tensor = model(input_ids) + logits, states = model(input_ids) params = [ input_ids, ] + model.parameters() @@ -784,7 +784,7 @@ def create_decoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> Dict[int, assert param.same_as(params[i + 1]) if states is None: states = () - gv = bb.emit_output((debug_tensor, relax.Tuple(states))) + gv = 
bb.emit_output((logits, relax.Tuple(states))) bb.emit_func_output(gv, params) mod = bb.get() From e116601e3b4c214a397404c093325d63c8db07c2 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 6 Jul 2023 14:51:07 +0300 Subject: [PATCH 100/114] artificial error --- mlc_llm/relax_model/mpt/mpt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index d2ae126baf..f734d42a37 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -670,7 +670,8 @@ def forward( S = input_ids.struct_info.shape[1] assert S <= self.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.max_seq_len}' - tok_emb = self.wte(input_ids) + tok_emb = nn.emit(self.wte(input_ids)) + tok_emb = tvm.tir.bitwise_not(tok_emb) if self.alibi: x = tok_emb else: From c66a5a2e84ea3a3cb00a57aaf5dd457366b0a1cd Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 6 Jul 2023 16:20:18 +0300 Subject: [PATCH 101/114] reimplement all remaining tir funcs to relax --- mlc_llm/relax_model/mpt/mpt.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index f734d42a37..e0132e4efc 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -127,18 +127,18 @@ def scaled_multihead_dot_product_attention( attn_bias.struct_info.shape[-2] != s_q)): # dynamic condition? raise RuntimeError(f'attn_bias (shape: {attn_bias.struct_info.shape}) is expected to broadcast to shape: {attn_weight.struct_info.shape}.') attn_weight = attn_weight + attn_bias - min_val = tvm.tir.min_value(q.struct_info.dtype) + min_val = get_type_min_val(q) if key_padding_mask is not None: if attn_bias is not None: warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. 
Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.') - key_mask = nn.emit(tvm.tir.bitwise_not(relax.op.reshape(key_padding_mask, (b, 1, 1, s_k)))) + key_mask = nn.emit(relax.op.bitwise_not(relax.op.reshape(key_padding_mask, (b, 1, 1, s_k)))) attn_weight = nn.emit(relax.op.masked_fill(attn_weight, key_mask, min_val)) if is_causal and (not q.struct_info.shape[2] == 1): s = relax.op.maximum(s_q, s_k) causal_mask = nn.emit(relax.op.ones((s, s,), dtype="float16")) causal_mask = nn.emit(relax.op.tril(causal_mask)) causal_mask = nn.emit(relax.op.astype(causal_mask, "bool")) - causal_mask = tvm.tir.bitwise_not(causal_mask) + causal_mask = nn.emit(relax.op.bitwise_not(causal_mask)) # slicing causal_mask[-s_q:, -s_k:] s_q_end, s_k_end = causal_mask.struct_info.shape causal_mask = nn.emit(relax.op.strided_slice(causal_mask, [0, 1], [s_q_end - s_q, s_k_end - s_k], [s_q_end, s_k_end])) @@ -283,8 +283,8 @@ def flash_attn_fn( # (b_size, s_k) = key_padding_mask.struct_info.shape[:2] # if attn_bias is None: # attn_bias = nn.emit(relax.op.zeros((b_size, 1, 1, s_k), dtype=query.struct_info.dtype)) -# key_mask = nn.emit(tvm.tir.bitwise_not(relax.op.reshape(key_padding_mask, (b_size, 1, 1, s_k)))) -# attn_bias = nn.emit(relax.op.masked_fill(attn_bias, key_mask, tvm.tir.min_value(query.struct_info.dtype))) +# key_mask = nn.emit(relax.op.bitwise_not(relax.op.reshape(key_padding_mask, (b_size, 1, 1, s_k)))) +# attn_bias = nn.emit(relax.op.masked_fill(attn_bias, key_mask, get_type_min_val(query))) # batch_size, seq_len, _ = query.struct_info.shape # query = nn.emit(relax.op.reshape( @@ -524,6 +524,13 @@ def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi= raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') +def get_type_min_val(tensor): + return relax.const( + tvm.tir.min_value(tensor.struct_info.dtype).value, + tensor.struct_info.dtype, + ) + + class MPTModel(nn.Module): def __init__(self, config: MPTConfig): config._validate_config() @@ -597,8 +604,8 @@ def _attn_bias(self, dtype, attention_mask: Optional[relax.Expr]=None, prefix_ma attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [3], [_s_k], [s_k_end])) if prefix_mask is not None and attention_mask.struct_info.shape != prefix_mask.struct_info.shape: raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.') - min_val = tvm.tir.min_value(attn_bias.struct_info.dtype) - attn_mask = nn.emit(tvm.tir.bitwise_not(relax.op.reshape(attention_mask, (-1, 1, 1, s_k)))) + min_val = get_type_min_val(attn_bias) + attn_mask = nn.emit(relax.op.bitwise_not(relax.op.reshape(attention_mask, (-1, 1, 1, s_k)))) attn_bias = nn.emit(relax.op.masked_fill(attn_bias, attn_mask, min_val)) return (attn_bias, None) @@ -615,8 +622,8 @@ def _apply_prefix_mask(self, attn_bias: relax.Expr, prefix_mask: relax.Expr): attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [dims_len - 2, dims_len - 1], [relax.const(0), relax.const(0)], [seq_len, seq_len])) causal = nn.emit(relax.op.reshape(relax.op.tril(relax.op.ones((seq_len, seq_len), dtype="bool")), (1, 1, seq_len, seq_len))) prefix = nn.emit(relax.op.reshape(prefix_mask, (-1, 1, 1, seq_len))) - cannot_attend = nn.emit(tvm.tir.bitwise_not(relax.op.logical_or(causal, relax.op.astype(prefix, "bool")))) - min_val = tvm.tir.min_value(attn_bias.struct_info.dtype) + cannot_attend = nn.emit(relax.op.bitwise_not(relax.op.logical_or(causal, relax.op.astype(prefix, "bool")))) + min_val = 
get_type_min_val(attn_bias) attn_bias = nn.emit(relax.op.masked_fill(attn_bias, cannot_attend, min_val)) return attn_bias @@ -630,7 +637,7 @@ def _apply_sequence_id(self, attn_bias: relax.Expr, sequence_id: relax.Expr): seq_id_l = nn.emit(relax.op.reshape(sequence_id, (-1, seq_len, 1))) seq_id_r = nn.emit(relax.op.reshape(sequence_id, (-1, 1, seq_len))) cannot_attend = nn.emit(relax.op.expand_dims(relax.op.logical_not(relax.op.equal(seq_id_l, seq_id_r)), axis=1)) - min_val = tvm.tir.min_value(attn_bias.struct_info.dtype) + min_val = get_type_min_val(attn_bias) attn_bias = nn.emit(relax.op.masked_fill(attn_bias, cannot_attend, min_val)) return attn_bias @@ -670,8 +677,7 @@ def forward( S = input_ids.struct_info.shape[1] assert S <= self.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.max_seq_len}' - tok_emb = nn.emit(self.wte(input_ids)) - tok_emb = tvm.tir.bitwise_not(tok_emb) + tok_emb = self.wte(input_ids) if self.alibi: x = tok_emb else: @@ -686,7 +692,7 @@ def forward( raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.max_seq_len}.') pos = nn.emit(relax.op.expand_dims(relax.op.arange(past_position, S + past_position, dtype="long"), axis=0)) if attention_mask is not None: - pos_diff_to_slice = nn.emit(relax.op.cumsum(relax.op.astype(tvm.tir.bitwise_not(attention_mask), "int32"), axis=1)) + pos_diff_to_slice = nn.emit(relax.op.cumsum(relax.op.astype(relax.op.bitwise_not(attention_mask), "int32"), axis=1)) dim1_len = pos_diff_to_slice.struct_info.shape[1] # slicing [:, past_position:] pos_diff = nn.emit(relax.op.strided_slice(pos_diff_to_slice, [1], [past_position], [dim1_len])) From 63ae7c9544dca56e92588097f5c1ad2f74588f09 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 6 Jul 2023 16:26:20 +0300 Subject: [PATCH 102/114] print only 10 values from logits --- cpp/llm_chat.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index f840018539..97f1dc8ea2 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -793,7 +793,7 @@ void PrintLogits(int logits_num = -1) { logits_on_cpu_.CopyFrom(logits_or_prob); } TVMSynchronize(device_.device_type, device_.device_id, nullptr); - this->PrintLogits(); + this->PrintLogits(10); } // Clear kv cache From e68f4fbe30a220974ac3147d7aba34e55d528e48 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 6 Jul 2023 16:30:41 +0300 Subject: [PATCH 103/114] revert test logits transform --- mlc_llm/relax_model/mpt/mpt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index e0132e4efc..741bf57e7a 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -765,8 +765,8 @@ def forward( use_cache=use_cache ) # TODO: test workaround - logits = outputs[0] - # logits = nn.emit(relax.op.linear(outputs[0], self.transformer.wte.weight)) + #logits = outputs[0] + logits = nn.emit(relax.op.linear(outputs[0], self.transformer.wte.weight)) if logits.struct_info.dtype != "float32": logits = nn.emit(relax.op.astype(logits, "float32")) From 0ed398b36e2294373d4c6b4fb1d198e74149df18 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 6 Jul 2023 16:41:57 +0300 Subject: [PATCH 104/114] remove debug logs and workaround --- cpp/llm_chat.cc | 1 - mlc_llm/relax_model/mpt/mpt.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git 
a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 97f1dc8ea2..bd0a423956 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -793,7 +793,6 @@ void PrintLogits(int logits_num = -1) { logits_on_cpu_.CopyFrom(logits_or_prob); } TVMSynchronize(device_.device_type, device_.device_id, nullptr); - this->PrintLogits(10); } // Clear kv cache diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index 741bf57e7a..b2c63dd189 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -764,8 +764,7 @@ def forward( output_hidden_states=output_hidden_states, use_cache=use_cache ) - # TODO: test workaround - #logits = outputs[0] + logits = nn.emit(relax.op.linear(outputs[0], self.transformer.wte.weight)) if logits.struct_info.dtype != "float32": From 24b9a9cfa9e66ed2d0f2a488d04ec6acedfe385e Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 6 Jul 2023 16:46:34 +0300 Subject: [PATCH 105/114] return correct output --- mlc_llm/relax_model/mpt/mpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index b2c63dd189..465cd38a17 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -719,7 +719,7 @@ def forward( if output_hidden_states: assert all_hidden_states is not None all_hidden_states = all_hidden_states + (x,) - return tok_emb, past_key_values # x, past_key_values, all_hidden_states, all_self_attns + return x, past_key_values, all_hidden_states, all_self_attns class MPTForCausalLM(nn.Module): From fb7870ee21bd21340598d19624b2806eef65367b Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 6 Jul 2023 16:57:24 +0300 Subject: [PATCH 106/114] continue debug --- mlc_llm/relax_model/mpt/mpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index 465cd38a17..105e8b3e37 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -715,7 +715,7 @@ def forward( if output_attentions: assert all_self_attns is not None all_self_attns = all_self_attns + (attn_weights,) - x = self.norm_f(x) + # x = self.norm_f(x) if output_hidden_states: assert all_hidden_states is not None all_hidden_states = all_hidden_states + (x,) From 418aadba2f57e63f250398c1c03891e9133fac37 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 6 Jul 2023 18:36:03 +0300 Subject: [PATCH 107/114] remove unneccessary parts from mpt topology. calculate query-key matmul in float32 to avoid inf generation --- cpp/conv_templates.cc | 8 ++-- cpp/llm_chat.cc | 77 ++++++++++++++++++++-------------- mlc_llm/relax_model/mpt/mpt.py | 67 +++++++++++++---------------- mlc_llm/utils.py | 18 +++++--- 4 files changed, 89 insertions(+), 81 deletions(-) diff --git a/cpp/conv_templates.cc b/cpp/conv_templates.cc index 98ef7424c7..924573fa1d 100644 --- a/cpp/conv_templates.cc +++ b/cpp/conv_templates.cc @@ -299,17 +299,17 @@ Conversation MPT() { Conversation conv; conv.name = "mpt"; conv.system = ""; - conv.roles = {"client", "instructor"}; + conv.roles = {"", ""}; conv.messages = {}; conv.separator_style = SeparatorStyle::kSepRoleMsg; conv.offset = 0; conv.seps = {"\n"}; - conv.role_msg_sep = ": "; - conv.role_empty_sep = "?"; - conv.stop_str = "stop"; + conv.role_msg_sep = ""; + conv.role_empty_sep = ""; // TODO(mlc-team): add eos to mlc-chat-config // and remove eos from stop token setting. 
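With every role tag and separator emptied out, the MPT template effectively passes the instruction text through as-is and relies on token 0 / "<|endoftext|>" to stop. Roughly, in Python (the joining rules here are a simplification for illustration, not the C++ Conversation logic):

```python
def build_prompt(messages, system="", sep="\n"):
    # Empty system prompt, empty role tags, empty role/message separators:
    # the conversation reduces to the raw messages joined by the separator.
    return system + sep.join(messages) + sep

def is_stop(token_id, decoded_text, stop_tokens=(0,), stop_str="<|endoftext|>"):
    return token_id in stop_tokens or stop_str in decoded_text

print(build_prompt(["What is a quaternion?"]))
```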
conv.stop_tokens = {0}; + conv.stop_str = "<|endoftext|>"; conv.add_bos = false; return conv; } diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index bd0a423956..cf69827ecb 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -516,10 +516,6 @@ class LLMChat { auto tstart = std::chrono::high_resolution_clock::now(); - for (int64_t i = 0; i < token_len; ++i) { - std::cout << "PROMPT TOKEN [" << i << "]: " << prompt_tokens[i] << std::endl; - } - int32_t new_seq_len = total_seq_len_ + token_len; NDArray logits_on_device = this->Forward(prompt_tokens, new_seq_len); total_seq_len_ = new_seq_len; @@ -537,8 +533,6 @@ class LLMChat { ICHECK(!output_ids_.empty()); int32_t last_token = output_ids_.back(); - std::cout << "LAST TOKEN TO DECODE: " << last_token << std::endl; - auto tstart = std::chrono::high_resolution_clock::now(); NDArray logits_on_device = this->Forward({last_token}, ++total_seq_len_); @@ -595,7 +589,7 @@ class LLMChat { auto decoding_end = std::chrono::high_resolution_clock::now(); // print first few logits for eyeballs - PrintLogits(10); + PrintNDArray(logits_on_cpu_, 10, "Logits"); double encoding_ms = static_cast((decoding_start - encoding_start).count()) / 1e6; double decoding_ms = static_cast((decoding_end - decoding_start).count()) / 1e6; @@ -604,34 +598,52 @@ class LLMChat { << "decoding-time=" << decoding_ms << "ms."; } -void PrintLogits(int logits_num = -1) { - size_t ndim = logits_on_cpu_->ndim; - std::string logits_num_tag = std::to_string(logits_num); - if (logits_num == -1) { - logits_num = logits_on_cpu_->shape[ndim - 1]; - logits_num_tag = ""; - } - // Print shape - std::ostringstream os_shape; - for (size_t i = 0; i < ndim; ++i) { - if (i != 0) os_shape << ", "; - os_shape << logits_on_cpu_->shape[i]; + NDArray getArrayToPrint(NDArray array) const { + ICHECK(array->data != nullptr) << "Array data is nullptr"; + // Check that the data on CPU and copy if need + if (array->device.device_type != kDLCPU) { + NDArray array_cpu; + array_cpu = array.CopyTo(DLDevice{kDLCPU, 0}); + TVMSynchronize(device_.device_type, device_.device_id, nullptr); + return array_cpu; + } else { + return array; + } } - // TODO(vchernov): after test return LOG(INFO) - std::cout << "LOGITS SHAPE = [" << os_shape.str() << "]" << std::endl; - // LOG(INFO) << "logits shape = [" << os_shape.str() << "]"; - // Print specified number of values from logits - std::ostringstream os; - const float* p_data = static_cast(logits_on_cpu_->data); - for (int i = 0; i < logits_num; ++i) { - if (i != 0) os << ", "; - os << p_data[i]; + void PrintNDArray(NDArray array, int64_t num = -1, std::string tensor_tag = "Tensor") const { + NDArray array_cpu = getArrayToPrint(array); + + size_t ndim = array_cpu->ndim; + int64_t numel = 1; + // Print shape and calculate numel + std::ostringstream os_shape; + for (size_t i = 0; i < ndim; ++i) { + if (i != 0) os_shape << ", "; + numel *= array_cpu->shape[i]; + os_shape << array_cpu->shape[i]; + } + + std::string num_tag = std::to_string(num); + if (num == -1 || num >= numel) { + num = numel; + num_tag = ""; + } + // TODO(vchernov): after test return LOG(INFO) + std::cout << tensor_tag << " shape = [" << os_shape.str() << "]" << std::endl; + // LOG(INFO) << tensor_tag << " shape = [" << os_shape.str() << "]"; + + // Print specified number of values from tensor + std::ostringstream os; + const float* p_data = static_cast(array_cpu->data); + for (int64_t i = 0; i < num; ++i) { + if (i != 0) os << ", "; + os << p_data[i]; + } + // TODO(vchernov): after test return LOG(INFO) + 
std::cout << tensor_tag << "[:" << num_tag << "] = [" << os.str() << "]" << std::endl; + // LOG(INFO) << tensor_tag << "[:" << num_tag << "] = [" << os.str() << "]"; } - // TODO(vchernov): after test return LOG(INFO) - std::cout << "LOGITS[:" << logits_num_tag << "] = [" << os.str() << "]" << std::endl; - // LOG(INFO) << "logits[:" << logits_num_tag << "] = [" << os.str() << "]"; -} private: picojson::value SerializeConfigToJSONValue() const { @@ -793,6 +805,7 @@ void PrintLogits(int logits_num = -1) { logits_on_cpu_.CopyFrom(logits_or_prob); } TVMSynchronize(device_.device_type, device_.device_id, nullptr); + // PrintNDArray(logits_on_cpu_, 100, "Logits"); } // Clear kv cache diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index 105e8b3e37..9fbff247a1 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -98,8 +98,9 @@ def scaled_multihead_dot_product_attention( key_padding_mask=None, is_causal=False, needs_weights=False, - multiquery=False + multiquery=False, ): + dtype = query.struct_info.dtype q = reshape_and_permute(query, n_heads, d_model) kv_n_heads = 1 if multiquery else n_heads k = reshape_and_permute(key, kv_n_heads, d_model, [0, 2, 3, 1]) @@ -113,6 +114,10 @@ def scaled_multihead_dot_product_attention( s_k = k.struct_info.shape[-1] if softmax_scale is None: softmax_scale = 1 / math.sqrt(d) + # TODO(vchernov): matmul(q, k) generates inf when float16 is used. There is workaround + if dtype != "float32": + q = nn.emit(relax.op.astype(q, "float32")) + k = nn.emit(relax.op.astype(k, "float32")) softmax_scale = relax.op.astype(relax.const(softmax_scale), q.struct_info.dtype) attn_weight = nn.emit(relax.op.matmul(q, k) * softmax_scale) _, _, s_q_end, s_k_end = attn_bias.struct_info.shape @@ -126,6 +131,9 @@ def scaled_multihead_dot_product_attention( (attn_bias.struct_info.shape[-2] != 1 and attn_bias.struct_info.shape[-2] != s_q)): # dynamic condition? raise RuntimeError(f'attn_bias (shape: {attn_bias.struct_info.shape}) is expected to broadcast to shape: {attn_weight.struct_info.shape}.') + # TODO(vchernov): matmul(q, k) generates inf when float16 is used. + if dtype != "float32": + attn_bias = nn.emit(relax.op.astype(attn_bias, "float32")) attn_weight = attn_weight + attn_bias min_val = get_type_min_val(q) if key_padding_mask is not None: @@ -144,12 +152,15 @@ def scaled_multihead_dot_product_attention( causal_mask = nn.emit(relax.op.strided_slice(causal_mask, [0, 1], [s_q_end - s_q, s_k_end - s_k], [s_q_end, s_k_end])) causal_mask = nn.emit(relax.op.reshape(causal_mask, (1, 1, s_q, s_k))) attn_weight = nn.emit(relax.op.masked_fill(attn_weight, causal_mask, min_val)) + # TODO(vchernov): matmul(q, k) generates inf when float16 is used. 
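The same upcast/downcast pattern outside of Relax, as a minimal NumPy sketch: scores are accumulated in float32 (float16 tops out around 65504, so QK^T can overflow to inf), and only the softmax output, which is bounded in [0, 1], is cast back to the model dtype.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def fp16_safe_attn_weights(q, k, softmax_scale):
    """softmax(q @ k * scale) for float16 q/k, accumulated in float32; the
    result lies in [0, 1], so casting back to float16 cannot overflow."""
    scores = (q.astype("float32") @ k.astype("float32")) * softmax_scale
    return softmax(scores).astype(q.dtype)

d = 128
q = np.random.randn(1, 32, 7, d).astype("float16")   # (batch, heads, s_q, head_dim)
k = np.random.randn(1, 32, d, 7).astype("float16")   # (batch, heads, head_dim, s_k)
w = fp16_safe_attn_weights(q, k, 1.0 / np.sqrt(d))
assert w.dtype == np.float16 and np.isfinite(w).all()
```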
+ # There is uncast after workaround with float calculation due to softmax range = [0, 1] attn_weight = nn.emit(relax.op.nn.softmax(attn_weight)) + if dtype != "float32": + attn_weight = nn.emit(relax.op.astype(attn_weight, dtype)) out = nn.emit(relax.op.matmul(attn_weight, v)) out = reverse_reshape_and_permute(out) - if needs_weights: - return (out, attn_weight, past_key_value) - return (out, None, past_key_value) + + return (out, past_key_value) ######################### FLASH ATTENTION IMPLEMENTATION TYPE TORCH (END) ########################## @@ -331,8 +342,7 @@ def __init__( attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, - softmax_scale: Optional[float]=None, - low_precision_layernorm: bool=False + softmax_scale: Optional[float]=None ): # Init fields self.d_model = d_model @@ -348,9 +358,8 @@ def __init__( fuse_splits = (d_model, 2 * d_model) self.Wqkv._fused = (0, fuse_splits) if self.qk_ln: - layernorm_class = LPLayerNormWOBias if low_precision_layernorm else LayerNorm - self.q_ln = layernorm_class(self.d_model, dtype) - self.k_ln = layernorm_class(self.d_model, dtype) + self.q_ln = LayerNorm(self.d_model, dtype) + self.k_ln = LayerNorm(self.d_model, dtype) if self.attn_impl == 'flash': self.attn_fn = flash_attn_fn elif self.attn_impl == 'triton': @@ -369,7 +378,7 @@ def __init__( # TODO: Does field _is_residual exist? # self.out_proj._is_residual = True - def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False): + def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True): qkv = self.Wqkv(x) if self.clip_qkv: qkv = nn.emit(relax.op.clip(qkv, min=relax.const(-self.clip_qkv), max=relax.const(self.clip_qkv))) @@ -393,9 +402,9 @@ def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, i attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, - needs_weights=needs_weights + needs_weights=False, ) - return (self.out_proj(attn_out[0]), attn_out[1], attn_out[2]) + return (self.out_proj(attn_out[0]), attn_out[1]) ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention} @@ -446,10 +455,10 @@ def forward( is_causal: bool=True, ) -> Tuple[relax.Expr, relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.input_layernorm(hidden_states) # TODO: debug comment: not nan # Self Attention - (hidden_states, attn_weights, present_key_value) = self.self_attn( + (hidden_states, present_key_value) = self.self_attn( hidden_states, past_key_value=past_key_value, attn_bias=attn_bias, @@ -463,7 +472,7 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = nn.emit(residual + hidden_states) - return (hidden_states, attn_weights, present_key_value) + return (hidden_states, present_key_value) def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id): @@ -649,8 +658,6 @@ def forward( prefix_mask: Optional[relax.Expr]=None, sequence_id: Optional[relax.Expr]=None, return_dict: Optional[bool]=None, - output_attentions: Optional[bool]=None, - output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None ): return_dict = return_dict if return_dict is not None else self.return_dict @@ -668,9 +675,6 @@ def forward( prefix_mask = nn.emit(relax.op.astype(prefix_mask, "bool")) if not return_dict: raise NotImplementedError('return_dict False is not implemented yet for MPT') - if 
output_attentions: - if self.attn_impl != 'torch': - raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.') if self.prefix_lm and prefix_mask is None: raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.') @@ -702,24 +706,13 @@ def forward( (attn_bias, attention_mask) = self._attn_bias(dtype=x.struct_info.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id) if use_cache and past_key_values is None: past_key_values = [() for _ in range(self.n_layers)] - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None for (b_idx, block) in enumerate(self.blocks): - if output_hidden_states: - assert all_hidden_states is not None - all_hidden_states = all_hidden_states + (x,) past_key_value = past_key_values[b_idx] if past_key_values is not None else None - (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal) + (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal) if past_key_values is not None: past_key_values[b_idx] = past_key_value - if output_attentions: - assert all_self_attns is not None - all_self_attns = all_self_attns + (attn_weights,) - # x = self.norm_f(x) - if output_hidden_states: - assert all_hidden_states is not None - all_hidden_states = all_hidden_states + (x,) - return x, past_key_values, all_hidden_states, all_self_attns + x = self.norm_f(x) + return x, past_key_values class MPTForCausalLM(nn.Module): @@ -740,8 +733,6 @@ def forward( prefix_mask: Optional[relax.Expr]=None, sequence_id: Optional[relax.Expr]=None, return_dict: Optional[bool]=None, - output_attentions: Optional[bool]=None, - output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None ): return_dict = return_dict if return_dict is not None else self.return_dict @@ -760,8 +751,6 @@ def forward( prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, use_cache=use_cache ) @@ -834,7 +823,7 @@ def get_model(args, hf_config): bb, model_name=model_name, max_window_size=128, # TODO: temporal limit for max output length, change to -1 after tests - stop_tokens=[0], # TODO: check for mpt embeddings + stop_tokens=[0], add_prefix_space=False, # TODO: what is it? ) diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py index 7ae6c09c0e..6cebcd2724 100644 --- a/mlc_llm/utils.py +++ b/mlc_llm/utils.py @@ -221,16 +221,19 @@ def get_item(i): torch_param_names = list(torch_params.keys()) for torch_param_name in torch_param_names: if str(torch_params[torch_param_name].dtype) == "torch.bfloat16": - # Convert to float32 first. - raw_param = ( - torch_params[torch_param_name].detach().cpu().float().numpy() - ) + if args.quantization.mode == "no" and args.quantization.model_dtype == "float16": + raw_param = ( + torch_params[torch_param_name].detach().cpu().to(dtype=torch.float16).numpy() + ) + else: + # Convert to float32 first. 
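The intent of this branch, sketched with plain PyTorch (the `build_dtype` argument is a stand-in for the quantization settings checked above): bfloat16 tensors cannot be handed to NumPy directly, so they are cast either straight to the float16 build dtype or to float32 first.

```python
import numpy as np
import torch

def export_param(t: torch.Tensor, build_dtype: str = "float16") -> np.ndarray:
    """bfloat16 has no NumPy counterpart; cast before calling .numpy()."""
    if t.dtype == torch.bfloat16:
        t = t.to(torch.float16 if build_dtype == "float16" else torch.float32)
    return t.detach().cpu().numpy()

arr = export_param(torch.randn(4, 4, dtype=torch.bfloat16))
assert arr.dtype == np.float16
```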
+ raw_param = ( + torch_params[torch_param_name].detach().cpu().float().numpy() + ) else: raw_param = torch_params[torch_param_name].detach().cpu().numpy() del torch_params[torch_param_name] - if not raw_param.flags['C_CONTIGUOUS']: - print("NON_CONTIGUOUS TENSOR WAS FOUND:", torch_param_name) for param_name, param in f_convert_param_bkwd( torch_param_name, raw_param ): @@ -292,7 +295,10 @@ def load_params(artifact_path: str, device) -> List[tvm.nd.NDArray]: params, meta = tvmjs.load_ndarray_cache(f"{artifact_path}/params", device) plist = [] size = meta["ParamSize"] + print("META:", meta) for i in range(size): + if i == 2: + print("PARAM FROM BIN:", params[f"param_{i}"]) plist.append(params[f"param_{i}"]) return plist From 6ee82c74db365b8c1e31a52a1ac10d251419b2fd Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 24 Jul 2023 14:13:13 +0300 Subject: [PATCH 108/114] create comparator --- mlc_llm/relax_model/mpt/compare.py | 79 ++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 mlc_llm/relax_model/mpt/compare.py diff --git a/mlc_llm/relax_model/mpt/compare.py b/mlc_llm/relax_model/mpt/compare.py new file mode 100644 index 0000000000..21bdbda2b9 --- /dev/null +++ b/mlc_llm/relax_model/mpt/compare.py @@ -0,0 +1,79 @@ +from pathlib import Path +import argparse + +import torch +import numpy as np + +# std::ofstream fs("tensor.bin", std::ios::out | std::ios::binary | std::ios::app); +# fs.write(reinterpret_cast(&tensor), sizeof tensor); +# fs.close(); + +def save_torch_tensor(t: torch.tensor, path=Path("./orig_input.pt")): + torch.save(t, path) + +def load_torch_tensor(path=Path("./orig_input.pt")): + return torch.load(path) + +def advanced_compare(lft, rht, atol=1e-5, rtol=1e-5): + if len(lft.shape) > 1: + lft = lft.flatten() + if len(rht.shape) > 1: + lft = rht.flatten() + numel = lft.shape[0] + assert numel == rht.shape[0] + counter = 0 + rtols=[rtol] + for i in range(numel): + diff = np.abs(lft[i]-rht[i]) + exp_diff = atol + rtol*np.abs(rht[i]) + if diff > exp_diff: + new_rtol = (diff - atol)/np.abs(rht[i]) + rtols.append(new_rtol) + print("Elements with index", i, " are not the same left:", lft[i], " right:", rht[i]) + counter = counter + 1 + print("Number of diverged values:", counter, " Percent is", 100*float(counter)/numel,"%") + max_rtol = np.max(rtols) + print("Current rtol:", rtol, "Maximum rtol:", max_rtol) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('-r', '--rtol', type=float, default=5e-3, + help="Relative tolerance") + parser.add_argument('-a', '--atol', type=float, default=1e-6, + help="Absolute tolerance") + parser.add_argument('-w', '--check_weight', default=False, action="store_true", + help="Compare weights. 
Corresponding files are required") + + args = parser.parse_args() + + check_num = 10 + # Load data from Relax model + np_input = np.fromfile(Path("./relax_input.bin"), dtype="float32") + np_weight = np.fromfile(Path("./relax_weight.bin"), dtype="float32") + print("RELAX INPUT TYPE:", np_input.dtype, "SHAPE:", np_input.shape) + print("RELAX WEIGHT TYPE:", np_weight.dtype, "SHAPE:", np_weight.shape) + + # Load data from original model + orig_input = load_torch_tensor() + orig_weight = load_torch_tensor(Path("./orig_weight.pt")) + + orig_np_input = orig_input.numpy() + orig_np_weight = orig_weight.numpy() + print("ORIG INPUT TYPE:", orig_np_input.dtype, "SHAPE:", orig_np_input.shape) + print("ORIG WEIGHT TYPE:", orig_np_weight.dtype, "SHAPE:", orig_np_weight.shape) + + print("Compare inputs") + print("ORIG INPUT:", orig_np_input[:check_num]) + print("RELAX INPUT:", np_input[:check_num]) + # np.testing.assert_allclose(orig_np_input, np_input, rtol=rtol, atol=atol, verbose=True) + advanced_compare(orig_np_input, np_input, rtol=args.rtol, atol=args.atol) + + if args.check_weight: + print("Compare weights") + orig_np_line = orig_np_weight[0,:] + print("ORIG WEIGHT:", orig_np_line[:check_num]) + print("RELAX WEIGHT:", np_weight[:check_num]) + np.testing.assert_allclose(orig_np_line, np_weight, rtol=args.rtol, atol=args.atol, verbose=True) + +if __name__ == "__main__": + main() From 16392c9638d4ab5896c0fb07780979f2f90c8080 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 24 Jul 2023 14:16:56 +0300 Subject: [PATCH 109/114] update README --- mlc_llm/relax_model/mpt/README.md | 58 +++++++++++++++++++++++++------ 1 file changed, 48 insertions(+), 10 deletions(-) diff --git a/mlc_llm/relax_model/mpt/README.md b/mlc_llm/relax_model/mpt/README.md index 7fc0629cfd..b0ccf83ed1 100644 --- a/mlc_llm/relax_model/mpt/README.md +++ b/mlc_llm/relax_model/mpt/README.md @@ -133,6 +133,43 @@ else: "task_specific_params": None +Some parameters from generate() function from transformers: +```python +is_greedy_gen_mode = True +``` + +Start greedy_search method in generate() from transformers: +```python +self.greedy_search( + input_ids, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + synced_gpus=synced_gpus, + streamer=streamer, + **model_kwargs, +) +``` +Where parameters for MPT-7b-instruct: +```python +logits_processor? +stopping_criteria? 
+pad_token_id = None +eos_token_id None +output_scores = False +return_dict_in_generate = False +synced_gpus = False +streamer = None +model_kwargs = { + 'output_attentions': False, + 'output_hidden_states': False, + 'use_cache': False, + 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0') +} +``` Refactored greedy_search method for MPT-7b-instruct: ```python @@ -142,6 +179,7 @@ def greedy_search(...): stopping_criteria = stopping_criteria # max_length and max_time criteria pad_token_id = None eos_token_id = None + eos_token_id_tensor = None output_scores = False output_attentions = False output_hidden_states = False @@ -150,6 +188,7 @@ def greedy_search(...): # init attention / hidden states / scores tuples scores = None decoder_attentions = None + cross_attentions = None decoder_hidden_states = None # keep track of which sequences are already finished @@ -158,6 +197,13 @@ def greedy_search(...): while True: # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # model_inputs = { + # 'input_ids': tensor([[...]], device='cuda:0'), + # 'attention_mask': tensor([[True, ..., True]], device='cuda:0'), + # 'prefix_mask': None, + # 'sequence_id': None, + # 'past_key_values': None, + # 'use_cache': False} # forward pass to get next token outputs = self( @@ -178,16 +224,8 @@ def greedy_search(...): # update generated ids, model inputs, and length for next step input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) # START model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False) - # update past - if "past_key_values" in outputs: - model_kwargs["past"] = outputs.past_key_values - else: - model_kwargs["past"] = None - - # update token_type_ids with last value - if "token_type_ids" in model_kwargs: - token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) + + model_kwargs["past"] = None # update attention mask if "attention_mask" in model_kwargs: From ddcea189866afaf38724c92eb5c40082a63fb0af Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 24 Jul 2023 14:33:43 +0300 Subject: [PATCH 110/114] update mpt model file: fix layernorm, remove some TODOs, remove excess code, comment unneccessary code parts, upstream layer names for correct mapping --- mlc_llm/relax_model/mpt/mpt.py | 108 +++++++++++++-------------------- 1 file changed, 43 insertions(+), 65 deletions(-) diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index 9fbff247a1..f099a43bea 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -4,9 +4,8 @@ import numpy as np import tvm -from tvm import relax, te +from tvm import relax from tvm.relax.testing import nn -from tvm.script import relax as R from .mpt_config import MPTConfig, attn_config_defaults from ...utils import load_torch_pname2binname_map @@ -34,9 +33,8 @@ def _cast_if_autocast_enabled(tensor: relax.Expr, dtype="float32"): class LPLayerNormWOBias(nn.Module): def __init__(self, normalized_shape, dtype, eps=1e-05): self.weight = nn.Parameter((normalized_shape,), dtype=dtype, name="low_precision_layernorm_weight") - # TODO: check default filling of weights - self.weight = relax.op.ones((normalized_shape,), dtype) - self.bias = relax.op.zeros((normalized_shape,), dtype) + # TODO(vchernov): need to set something to layer_norm, but not use + self.dummy_bias = relax.op.zeros((normalized_shape,), dtype) 
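What the reworked `LPLayerNormWOBias` computes, written out in NumPy as a sketch of the math only (not the Relax lowering): normalization over the last axis with a learned scale and no bias/offset term.

```python
import numpy as np

def layernorm_wo_bias(x, weight, eps=1e-5):
    """Normalize over the last axis, apply the learned scale, add no offset."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * weight

x = np.random.randn(2, 4096).astype("float32")
w = np.ones(4096, dtype="float32")
y = layernorm_wo_bias(x, w, eps=1e-5)
```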
self.eps = eps self.dtype = dtype @@ -44,9 +42,8 @@ def __init__(self, normalized_shape, dtype, eps=1e-05): def forward(self, x): dtype = self.dtype # TODO: temporal workaround downcast_x = _cast_if_autocast_enabled(x, dtype) - downcast_weight = _cast_if_autocast_enabled(self.weight, dtype) if self.weight is not None else self.weight - downcast_bias = _cast_if_autocast_enabled(self.bias, dtype) if self.bias is not None else self.bias - return nn.emit(relax.op.nn.layer_norm(downcast_x, downcast_weight, downcast_bias, axes=-1, epsilon=self.eps)) + downcast_weight = _cast_if_autocast_enabled(self.weight, dtype) + return nn.emit(relax.op.nn.layer_norm(downcast_x, downcast_weight, self.dummy_bias, axes=-1, epsilon=self.eps, center=False)) NORM_CLASS_REGISTRY = {'low_precision_layernorm': LPLayerNormWOBias} @@ -355,8 +352,6 @@ def __init__( if self.softmax_scale is None: self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) self.Wqkv = Linear(self.d_model, 3 * self.d_model, dtype, bias=False) - fuse_splits = (d_model, 2 * d_model) - self.Wqkv._fused = (0, fuse_splits) if self.qk_ln: self.q_ln = LayerNorm(self.d_model, dtype) self.k_ln = LayerNorm(self.d_model, dtype) @@ -375,8 +370,6 @@ def __init__( else: raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') self.out_proj = Linear(self.d_model, self.d_model, dtype, bias=False) - # TODO: Does field _is_residual exist? - # self.out_proj._is_residual = True def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True): qkv = self.Wqkv(x) @@ -429,7 +422,7 @@ def __init__(self, config: MPTConfig): self.hidden_size = config.d_model # Init layers - self.self_attn = attn_class( + self.attn = attn_class( d_model=self.hidden_size, n_heads=config.n_heads, dtype=config.dtype, @@ -438,13 +431,13 @@ def __init__(self, config: MPTConfig): qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], ) - self.mlp = MPTMLP( + self.ffn = MPTMLP( hidden_size=self.hidden_size, intermediate_size=config.expansion_ratio*self.hidden_size, dtype=config.dtype, ) - self.input_layernorm = norm_class(self.hidden_size, config.dtype) - self.post_attention_layernorm = norm_class(self.hidden_size, config.dtype) + self.norm_1 = norm_class(self.hidden_size, config.dtype) + self.norm_2 = norm_class(self.hidden_size, config.dtype) def forward( self, @@ -455,10 +448,10 @@ def forward( is_causal: bool=True, ) -> Tuple[relax.Expr, relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) # TODO: debug comment: not nan + hidden_states = self.norm_1(hidden_states) # Self Attention - (hidden_states, present_key_value) = self.self_attn( + (hidden_states, present_key_value) = self.attn( hidden_states, past_key_value=past_key_value, attn_bias=attn_bias, @@ -468,8 +461,8 @@ def forward( residual = nn.emit(residual + hidden_states) # Fully Connected - hidden_states = self.post_attention_layernorm(residual) - hidden_states = self.mlp(hidden_states) + hidden_states = self.norm_2(residual) + hidden_states = self.ffn(hidden_states) hidden_states = nn.emit(residual + hidden_states) return (hidden_states, present_key_value) @@ -660,7 +653,7 @@ def forward( return_dict: Optional[bool]=None, use_cache: Optional[bool]=None ): - return_dict = return_dict if return_dict is not None else self.return_dict + # return_dict = return_dict if return_dict is not None else self.return_dict use_cache = use_cache if use_cache is not None else self.use_cache if 
attention_mask is not None: attention_mask = nn.emit(relax.op.astype(attention_mask, "bool")) @@ -671,12 +664,12 @@ def forward( relax.op.strided_slice(attention_mask, [1], [dim1_len - 1], [dim1_len]) ) != attention_mask.struct_info.shape[0]: raise NotImplementedError('MPT does not support generation with right padding.') - if prefix_mask is not None: - prefix_mask = nn.emit(relax.op.astype(prefix_mask, "bool")) - if not return_dict: - raise NotImplementedError('return_dict False is not implemented yet for MPT') - if self.prefix_lm and prefix_mask is None: - raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.') + # if prefix_mask is not None: + # prefix_mask = nn.emit(relax.op.astype(prefix_mask, "bool")) + # if not return_dict: + # raise NotImplementedError('return_dict False is not implemented yet for MPT') + # if self.prefix_lm and prefix_mask is None: + # raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.') S = input_ids.struct_info.shape[1] assert S <= self.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.max_seq_len}' @@ -684,25 +677,25 @@ def forward( tok_emb = self.wte(input_ids) if self.alibi: x = tok_emb - else: - past_position = 0 - if past_key_values is not None: - if len(past_key_values) != self.n_layers: - raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.n_layers!r}).') - past_position = past_key_values[0][0].struct_info.shape[1] - if self.attn_impl == 'torch': - past_position = past_key_values[0][0].struct_info.shape[3] - if S + past_position > self.max_seq_len: - raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.max_seq_len}.') - pos = nn.emit(relax.op.expand_dims(relax.op.arange(past_position, S + past_position, dtype="long"), axis=0)) - if attention_mask is not None: - pos_diff_to_slice = nn.emit(relax.op.cumsum(relax.op.astype(relax.op.bitwise_not(attention_mask), "int32"), axis=1)) - dim1_len = pos_diff_to_slice.struct_info.shape[1] - # slicing [:, past_position:] - pos_diff = nn.emit(relax.op.strided_slice(pos_diff_to_slice, [1], [past_position], [dim1_len])) - pos = nn.emit(relax.op.clip(pos - pos_diff, min=0)) - pos_emb = self.wpe(pos) - x = tok_emb + pos_emb + # else: + # past_position = 0 + # if past_key_values is not None: + # if len(past_key_values) != self.n_layers: + # raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.n_layers!r}).') + # past_position = past_key_values[0][0].struct_info.shape[1] + # if self.attn_impl == 'torch': + # past_position = past_key_values[0][0].struct_info.shape[3] + # if S + past_position > self.max_seq_len: + # raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.max_seq_len}.') + # pos = nn.emit(relax.op.expand_dims(relax.op.arange(past_position, S + past_position, dtype="long"), axis=0)) + # if attention_mask is not None: + # pos_diff_to_slice = nn.emit(relax.op.cumsum(relax.op.astype(relax.op.bitwise_not(attention_mask), "int32"), axis=1)) + # dim1_len = 
pos_diff_to_slice.struct_info.shape[1] + # # slicing [:, past_position:] + # pos_diff = nn.emit(relax.op.strided_slice(pos_diff_to_slice, [1], [past_position], [dim1_len])) + # pos = nn.emit(relax.op.clip(pos - pos_diff, min=0)) + # pos_emb = self.wpe(pos) + # x = tok_emb + pos_emb (attn_bias, attention_mask) = self._attn_bias(dtype=x.struct_info.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id) if use_cache and past_key_values is None: past_key_values = [() for _ in range(self.n_layers)] @@ -735,7 +728,7 @@ def forward( return_dict: Optional[bool]=None, use_cache: Optional[bool]=None ): - return_dict = return_dict if return_dict is not None else self.return_dict + # return_dict = return_dict if return_dict is not None else self.return_dict use_cache = use_cache if use_cache is not None else self.use_cache # It is part from prepare_inputs_for_generation @@ -750,7 +743,7 @@ def forward( attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, - return_dict=return_dict, + return_dict=self.return_dict, use_cache=use_cache ) @@ -830,29 +823,14 @@ def get_model(args, hf_config): mod = bb.get() def f_convert_pname_fwd(pname: str) -> str: - if "self_attn" in pname: - return pname.replace("self_attn", "attn") - elif "mlp" in pname: - return pname.replace("mlp", "ffn") - else: - return pname + return pname pname2binname = load_torch_pname2binname_map( model_path, set(pidx2pname.values()), f_convert_pname_fwd ) def f_convert_param_bkwd(torch_pname: str, raw_param): - if "attn" in torch_pname: - pname = torch_pname.replace("attn", "self_attn") - elif "ffn" in torch_pname: - pname = torch_pname.replace("ffn", "mlp") - else: - pname = torch_pname - - # TVM does not support bfloat16 - if raw_param.dtype == "bfloat16": - raw_param = raw_param.astype("float16") - return [(pname, raw_param)] + return [(torch_pname, raw_param)] args.pidx2pname = pidx2pname args.pname2binname = pname2binname From 7b52b9934b9007c4d21a88bd7e0a18bc99da8940 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 24 Jul 2023 14:34:26 +0300 Subject: [PATCH 111/114] remove debug prints --- mlc_llm/utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py index 6cebcd2724..2508de68c9 100644 --- a/mlc_llm/utils.py +++ b/mlc_llm/utils.py @@ -295,10 +295,7 @@ def load_params(artifact_path: str, device) -> List[tvm.nd.NDArray]: params, meta = tvmjs.load_ndarray_cache(f"{artifact_path}/params", device) plist = [] size = meta["ParamSize"] - print("META:", meta) for i in range(size): - if i == 2: - print("PARAM FROM BIN:", params[f"param_{i}"]) plist.append(params[f"param_{i}"]) return plist From a8575837cbe7cf6420a58e055d5ecc191741c672 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Mon, 24 Jul 2023 14:39:29 +0300 Subject: [PATCH 112/114] update PrintNDArray method --- cpp/llm_chat.cc | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index cf69827ecb..6ba9362fb3 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -128,7 +128,7 @@ class LLMChat { friend class LLMChatModule; public: - explicit LLMChat(DLDevice device) : device_(device) {} + explicit LLMChat(DLDevice device) : device_(device), pass_index_(0) {} /*! * \return Text describing runtime stats. 
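The hunk below wires up the dump side of the debugging loop that `compare.py` reads back. The whole round trip, in NumPy terms (the file name and tolerances mirror the defaults used here):

```python
import numpy as np

# The C++ side writes raw float32 bytes (fs.write(p_data, 4 * numel)); emulate it:
ref = np.random.randn(1, 10, 4096).astype("float32")
ref.tofile("tensor_0.bin")

# The Python side reads the flat buffer back and compares against the reference
got = np.fromfile("tensor_0.bin", dtype="float32")
np.testing.assert_allclose(got, ref.ravel(), rtol=5e-3, atol=1e-6)
```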
@@ -611,7 +611,7 @@ class LLMChat { } } - void PrintNDArray(NDArray array, int64_t num = -1, std::string tensor_tag = "Tensor") const { + void PrintNDArray(NDArray array, int64_t num = -1, std::string tensor_tag = "Tensor", bool save = false) { NDArray array_cpu = getArrayToPrint(array); size_t ndim = array_cpu->ndim; @@ -643,6 +643,15 @@ class LLMChat { // TODO(vchernov): after test return LOG(INFO) std::cout << tensor_tag << "[:" << num_tag << "] = [" << os.str() << "]" << std::endl; // LOG(INFO) << tensor_tag << "[:" << num_tag << "] = [" << os.str() << "]"; + + if (save) { + // Save to binary file + std::string file_name = "tensor_" + std::to_string(pass_index_++) + ".bin"; + std::cout << tensor_tag << " is saved in " << file_name << std::endl; + std::ofstream fs(file_name, std::ios::out | std::ios::binary | std::ios::app); + fs.write(reinterpret_cast(p_data), 4 * numel); + fs.close(); + } } private: @@ -915,6 +924,8 @@ class LLMChat { Array kv_cache_; // Temp logits on cpu NDArray logits_on_cpu_{nullptr}; + // Counter of prefill-decode passes + int32_t pass_index_; }; /*! From 81f92a5c42635eb13aef0165dd55f1de6748ed49 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 25 Aug 2023 11:52:34 +0400 Subject: [PATCH 113/114] support mlc-llm chat using with or without kv cache --- build.py | 24 +++++++++++++++++----- cpp/llm_chat.cc | 54 ++++++++++++++++++++++++++++++++----------------- 2 files changed, 55 insertions(+), 23 deletions(-) diff --git a/build.py b/build.py index 48ba7fc4d3..4f75a99b5d 100644 --- a/build.py +++ b/build.py @@ -57,6 +57,12 @@ def _parse_args(): default=1, help="Whether to use previously pickled IRModule and skip trace.", ) + args.add_argument( + "--use-kv-cache", + action="store_false", + default=True, + help="Forcely replace use_cache hyperparameter in model config", + ) args.add_argument("--debug-dump", action="store_true", default=False) args.add_argument("--debug-load-script", action="store_true", default=False) args.add_argument( @@ -275,11 +281,19 @@ def mod_transform_before_build( "reset_kv_cache", ] elif ARGS.model.startswith("mpt-"): - model_names = [ - "decode", - "softmax_with_temperature", - "get_metadata", - ] + if ARGS.use_kv_cache: + model_names = [ + "decode", + "create_kv_cache", + "softmax_with_temperature", + "get_metadata", + ] + else: + model_names = [ + "decode", + "softmax_with_temperature", + "get_metadata", + ] else: model_names = [ "prefill", diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 6ba9362fb3..7e397fd407 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -128,7 +128,7 @@ class LLMChat { friend class LLMChatModule; public: - explicit LLMChat(DLDevice device) : device_(device), pass_index_(0) {} + explicit LLMChat(DLDevice device) : device_(device), debug_index_(0) {} /*! * \return Text describing runtime stats. 
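A note on the --use-kv-cache option introduced in the build.py hunk above: because it is declared with action="store_false" and default=True, passing --use-kv-cache on the command line actually sets the value to False (disabling the KV cache), while omitting it leaves the cache enabled. Below is a minimal standalone argparse sketch of that behaviour; the flag name is reused from the patch purely for illustration, and the snippet is not part of the patch itself.

# Standalone sketch (not part of the patch) of argparse's store_false
# semantics for a flag shaped like --use-kv-cache: the destination
# defaults to True and flips to False only when the flag is supplied.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use-kv-cache",
    action="store_false",
    default=True,
    help="Supplying this flag disables the KV cache",
)

print(parser.parse_args([]).use_kv_cache)                  # True: cache enabled
print(parser.parse_args(["--use-kv-cache"]).use_kv_cache)  # False: cache disabled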
@@ -511,6 +511,9 @@ class LLMChat { } std::vector prompt_tokens = this->GetInputTokens(); + if (kv_cache_.empty()) { + full_output_ids_.insert(full_output_ids_.end(), prompt_tokens.begin(), prompt_tokens.end()); + } int64_t token_len = static_cast(prompt_tokens.size()); if (token_len == 0) return; @@ -530,12 +533,18 @@ class LLMChat { } void DecodeStep() { - ICHECK(!output_ids_.empty()); - int32_t last_token = output_ids_.back(); + std::vector input_tokens; + if (kv_cache_.empty()) { + ICHECK(!full_output_ids_.empty()); + input_tokens = full_output_ids_; + } else { + ICHECK(!output_ids_.empty()); + input_tokens = {output_ids_.back()}; + } auto tstart = std::chrono::high_resolution_clock::now(); - NDArray logits_on_device = this->Forward({last_token}, ++total_seq_len_); + NDArray logits_on_device = this->Forward(input_tokens, ++total_seq_len_); int32_t next_token = this->SampleTokenFromLogits(logits_on_device, temperature_, top_p_); @@ -611,7 +620,7 @@ class LLMChat { } } - void PrintNDArray(NDArray array, int64_t num = -1, std::string tensor_tag = "Tensor", bool save = false) { + void PrintNDArray(NDArray array, int64_t num = -1, std::string tensor_tag = "Tensor", bool to_save = false) { NDArray array_cpu = getArrayToPrint(array); size_t ndim = array_cpu->ndim; @@ -644,9 +653,9 @@ class LLMChat { std::cout << tensor_tag << "[:" << num_tag << "] = [" << os.str() << "]" << std::endl; // LOG(INFO) << tensor_tag << "[:" << num_tag << "] = [" << os.str() << "]"; - if (save) { - // Save to binary file - std::string file_name = "tensor_" + std::to_string(pass_index_++) + ".bin"; + // Save to binary file + if (to_save) { + std::string file_name = "tensor_" + std::to_string(debug_index_++) + ".bin"; std::cout << tensor_tag << " is saved in " << file_name << std::endl; std::ofstream fs(file_name, std::ios::out | std::ios::binary | std::ios::app); fs.write(reinterpret_cast(p_data), 4 * numel); @@ -708,6 +717,9 @@ class LLMChat { if (!stop_triggered_) { output_ids_.push_back(next_token); + if (kv_cache_.empty()) { + full_output_ids_.push_back(next_token); + } appeared_token_ids_.insert(next_token); } @@ -751,12 +763,14 @@ class LLMChat { ret = prefill_func_(input_data, ShapeTuple({cur_pos}), kv_cache_, params_); } else { // running decode function when prefill is not available - for (int i = 0; i < input_tokens.size(); ++i) { - NDArray input_data = this->GetInputTokenNDArray({input_tokens[i]}); - int64_t pos = cur_pos + i + 1 - input_tokens.size(); - if (kv_cache_.empty()){ - ret = decode_func_(input_data, params_); - } else { + if (kv_cache_.empty()){ + // Without kv_cache full sequence of tokens is used + NDArray input_data = this->GetInputTokenNDArray(input_tokens); + ret = decode_func_(input_data, params_); + } else { + for (int i = 0; i < input_tokens.size(); ++i) { + NDArray input_data = this->GetInputTokenNDArray({input_tokens[i]}); + int64_t pos = cur_pos + i + 1 - input_tokens.size(); ret = decode_func_(input_data, ShapeTuple({pos}), kv_cache_, params_); } } @@ -814,13 +828,15 @@ class LLMChat { logits_on_cpu_.CopyFrom(logits_or_prob); } TVMSynchronize(device_.device_type, device_.device_id, nullptr); - // PrintNDArray(logits_on_cpu_, 100, "Logits"); } // Clear kv cache void ResetKVCache() { reset_kv_cache_func_(kv_cache_); } - void ProcessSystemPrompts() { this->PrefillStep(/*inp=*/"", /*append_conversation=*/false); } + void ProcessSystemPrompts() { + full_output_ids_.clear(); + this->PrefillStep(/*inp=*/"", /*append_conversation=*/false); + } // Utils static double GetRandomNumber() { @@ 
-874,6 +890,8 @@ class LLMChat { double top_p_{0.95}; // output ids till now (refresh after encoding step) std::vector output_ids_; + // output ids till now (sys and client prompt + generated by decoder) + std::vector full_output_ids_; // appeared token ids till now (refresh after encoding step) std::unordered_set appeared_token_ids_; // output message till now (refresh after encoding step) @@ -924,8 +942,8 @@ class LLMChat { Array kv_cache_; // Temp logits on cpu NDArray logits_on_cpu_{nullptr}; - // Counter of prefill-decode passes - int32_t pass_index_; + // Debug index + int32_t debug_index_; }; /*! From bb00d1ddcf0423b6fd179367064e61a3168a567f Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Fri, 25 Aug 2023 14:53:37 +0400 Subject: [PATCH 114/114] strong refactor based on vc/dev of mpt-like relax model to support using with/without kv cache --- mlc_llm/relax_model/mpt/mpt.py | 501 +++++++++++++++++++-------------- 1 file changed, 294 insertions(+), 207 deletions(-) diff --git a/mlc_llm/relax_model/mpt/mpt.py b/mlc_llm/relax_model/mpt/mpt.py index f099a43bea..08c27c508e 100644 --- a/mlc_llm/relax_model/mpt/mpt.py +++ b/mlc_llm/relax_model/mpt/mpt.py @@ -1,11 +1,11 @@ import math import warnings from typing import Optional, Tuple, List, Dict -import numpy as np import tvm -from tvm import relax +from tvm import relax, tir, te from tvm.relax.testing import nn +from tvm.script import relax as R from .mpt_config import MPTConfig, attn_config_defaults from ...utils import load_torch_pname2binname_map @@ -57,107 +57,141 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_cau return original_is_causal -def reshape_and_permute(hidden_states: relax.Expr, n_heads: int, d_model: int, indeces: List[int] = [0, 2, 1, 3]): - ''' - Transform shape of input: b s (h d) -> b s h d -> b h s d or b h d s - ''' - batch_size, seqlen, _ = hidden_states.struct_info.shape - inter = nn.emit(relax.op.reshape( - hidden_states, - (batch_size, seqlen, n_heads, int(d_model / n_heads)), - )) - return nn.emit(relax.op.permute_dims(inter, indeces)) - - -def reverse_reshape_and_permute(hidden_states: relax.Expr): - ''' - Transform shape of input: b h s d -> b s (h d) - ''' - batch_size, n_heads, seqlen, head_len = hidden_states.struct_info.shape - inter = nn.emit(relax.op.permute_dims(hidden_states, [0, 2, 1, 3])) - return nn.emit(relax.op.reshape( - inter, - (batch_size, seqlen, n_heads*head_len), - )) - - ######################### FLASH ATTENTION IMPLEMENTATION TYPE TORCH (BEGIN) ########################## def scaled_multihead_dot_product_attention( - query, - key, - value, - n_heads, - d_model, - past_key_value=None, - softmax_scale=None, - attn_bias=None, - key_padding_mask=None, - is_causal=False, - needs_weights=False, - multiquery=False, + query: relax.Expr, + key: relax.Expr, + value: relax.Expr, + n_heads: int, + d_model: int, + all_seq_len_shape: Optional[relax.Expr]=None, + past_key_value: Optional[Tuple[relax.Expr]]=None, + softmax_scale: Optional[float]=None, + attn_bias: Optional[relax.Expr]=None, + key_padding_mask: Optional[relax.Expr]=None, + is_causal: bool=False, + needs_weights: bool=False, ): + head_dim = d_model // n_heads dtype = query.struct_info.dtype - q = reshape_and_permute(query, n_heads, d_model) - kv_n_heads = 1 if multiquery else n_heads - k = reshape_and_permute(key, kv_n_heads, d_model, [0, 2, 3, 1]) - v = reshape_and_permute(value, kv_n_heads, d_model) + + b, s_q, _ = query.struct_info.shape + assert b == 1, "Only support batch size 1 at this moment." 
+ + q = nn.emit(relax.op.reshape(query, (b, s_q, n_heads, head_dim))) + k = nn.emit(relax.op.reshape(key, (b, -1, n_heads, head_dim))) + v = nn.emit(relax.op.reshape(value, (b, -1, n_heads, head_dim))) + if past_key_value is not None: - if len(past_key_value) != 0: - k = nn.emit(relax.op.concat([past_key_value[0], k], axis=3)) - v = nn.emit(relax.op.concat([past_key_value[1], v], axis=2)) - past_key_value = (k, v) - (b, _, s_q, d) = q.struct_info.shape - s_k = k.struct_info.shape[-1] + kv_seq_len = all_seq_len_shape.struct_info.values[0] + + kv_shape = k.struct_info.shape + kv_dtype = k.struct_info.dtype + assert kv_shape[0] == 1 # batch size + kv_shape = R.shape( + [kv_shape[0], kv_seq_len, kv_shape[2], kv_shape[3]] + ) + kv_cache_shape = R.shape([kv_seq_len, kv_shape[2], kv_shape[3]]) + + # There is requirement b == 1 used + squeezed_key = nn.emit(relax.op.squeeze(k, axis=0)) + squeezed_value = nn.emit(relax.op.squeeze(v, axis=0)) + k_cache, v_cache = past_key_value + f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append") + k_cache = nn.emit( + relax.Call( + f_kv_cache_append, + args=[k_cache, squeezed_key], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + v_cache = nn.emit( + relax.Call( + f_kv_cache_append, + args=[v_cache, squeezed_value], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + past_key_value = (k_cache, v_cache) + f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view") + k_cache = nn.emit( + relax.Call( + f_kv_cache_view, + args=[k_cache, kv_cache_shape], + sinfo_args=[R.Tensor(kv_cache_shape, kv_dtype)], + ) + ) + v_cache = nn.emit( + relax.Call( + f_kv_cache_view, + args=[v_cache, kv_cache_shape], + sinfo_args=[R.Tensor(kv_cache_shape, kv_dtype)], + ) + ) + k = nn.emit(relax.op.reshape(k_cache, kv_shape)) + v = nn.emit(relax.op.reshape(v_cache, kv_shape)) + s_k = k.struct_info.shape[1] if softmax_scale is None: - softmax_scale = 1 / math.sqrt(d) + softmax_scale = 1 / math.sqrt(head_dim) # TODO(vchernov): matmul(q, k) generates inf when float16 is used. 
There is workaround if dtype != "float32": q = nn.emit(relax.op.astype(q, "float32")) k = nn.emit(relax.op.astype(k, "float32")) softmax_scale = relax.op.astype(relax.const(softmax_scale), q.struct_info.dtype) - attn_weight = nn.emit(relax.op.matmul(q, k) * softmax_scale) - _, _, s_q_end, s_k_end = attn_bias.struct_info.shape + + q = nn.emit(relax.op.permute_dims(q, [0, 2, 1, 3])) + k = nn.emit(relax.op.permute_dims(k, [0, 2, 1, 3])) + v = nn.emit(relax.op.permute_dims(v, [0, 2, 1, 3])) + + attn_weight = nn.emit(relax.op.matmul(q, relax.op.permute_dims(k, [0, 1, 3, 2])) * softmax_scale) + # TODO(vchernov): attn_bias.shape is None due to it is not calculated in strided_slice with dynamic input + # _, _, s_q_end, s_k_end = attn_bias.struct_info.shape # shape = [1, 32, 1, seq_len] if attn_bias is not None: - _s_q = np.maximum(0, s_q_end - s_q) - _s_k = np.maximum(0, s_k_end - s_k) + # s_q = 1 for use_cache = True and = seq_len otherwise + # s_k = seq_len always + # TODO(vchernov): _s_q, _s_k can not be calculated due to reason above, but + # Trivial symbolic arithmetic shows that: + # _s_q = 0 always (s_q_end - s_q <= 0) + # _s_k = 0 + # _s_q = relax.op.maximum(0, s_q_end - s_q) + # _s_k = relax.op.maximum(0, s_k_end - s_k) + # TODO(vchernov): due to _s_q = 0 and _s_k = 0 the below slicing can be skipped # slicing attn_bias[:, :, _s_q:, _s_k:] - attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [2, 3], [_s_q, _s_k], [s_q_end, s_k_end])) - if (attn_bias.struct_info.shape[-1] != 1 and - attn_bias.struct_info.shape[-1] != s_k or # dynamic condition? - (attn_bias.struct_info.shape[-2] != 1 and - attn_bias.struct_info.shape[-2] != s_q)): # dynamic condition? - raise RuntimeError(f'attn_bias (shape: {attn_bias.struct_info.shape}) is expected to broadcast to shape: {attn_weight.struct_info.shape}.') + # attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [2, 3], [_s_q, _s_k], [s_q_end, s_k_end])) # TODO(vchernov): matmul(q, k) generates inf when float16 is used. if dtype != "float32": attn_bias = nn.emit(relax.op.astype(attn_bias, "float32")) - attn_weight = attn_weight + attn_bias + attn_weight = nn.emit(attn_weight + attn_bias) min_val = get_type_min_val(q) if key_padding_mask is not None: if attn_bias is not None: warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. 
Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.') - key_mask = nn.emit(relax.op.bitwise_not(relax.op.reshape(key_padding_mask, (b, 1, 1, s_k)))) + key_mask = nn.emit(relax.op.logical_not(relax.op.reshape(key_padding_mask, (b, 1, 1, s_k)))) attn_weight = nn.emit(relax.op.masked_fill(attn_weight, key_mask, min_val)) - if is_causal and (not q.struct_info.shape[2] == 1): - s = relax.op.maximum(s_q, s_k) - causal_mask = nn.emit(relax.op.ones((s, s,), dtype="float16")) - causal_mask = nn.emit(relax.op.tril(causal_mask)) - causal_mask = nn.emit(relax.op.astype(causal_mask, "bool")) - causal_mask = nn.emit(relax.op.bitwise_not(causal_mask)) - # slicing causal_mask[-s_q:, -s_k:] - s_q_end, s_k_end = causal_mask.struct_info.shape - causal_mask = nn.emit(relax.op.strided_slice(causal_mask, [0, 1], [s_q_end - s_q, s_k_end - s_k], [s_q_end, s_k_end])) - causal_mask = nn.emit(relax.op.reshape(causal_mask, (1, 1, s_q, s_k))) - attn_weight = nn.emit(relax.op.masked_fill(attn_weight, causal_mask, min_val)) + if is_causal and (not s_q == 1): + # It is the case where is no kv cache, thus s_q == s_k + # s = relax.op.maximum(s_q, s_k) + s = s_q + causal_mask = nn.emit(relax.op.ones((s, s,), dtype="bool")) + causal_mask = nn.emit(relax.op.triu(causal_mask, 1)) + # Due to the case the slicing below can be skipped + # slicing causal_mask[-s_q:, -s_k:] + # s_q_end, s_k_end = causal_mask.struct_info.shape + # causal_mask = nn.emit(relax.op.strided_slice(causal_mask, [0, 1], [s_q_end - s_q, s_k_end - s_k], [s_q_end, s_k_end])) + causal_mask = nn.emit(relax.op.broadcast_to(causal_mask, (b, n_heads, s, s))) + attn_weight = nn.emit(relax.op.masked_fill(attn_weight, causal_mask, min_val)) # TODO(vchernov): matmul(q, k) generates inf when float16 is used. 
# There is uncast after workaround with float calculation due to softmax range = [0, 1] attn_weight = nn.emit(relax.op.nn.softmax(attn_weight)) if dtype != "float32": attn_weight = nn.emit(relax.op.astype(attn_weight, dtype)) out = nn.emit(relax.op.matmul(attn_weight, v)) - out = reverse_reshape_and_permute(out) - return (out, past_key_value) + out = nn.emit(relax.op.permute_dims(out, [0, 2, 1, 3])) + out = nn.emit(relax.op.reshape(out, (b, tir.const(1, dtype="int64"), tir.const(d_model, dtype="int64")))) + + return out, past_key_value ######################### FLASH ATTENTION IMPLEMENTATION TYPE TORCH (END) ########################## @@ -291,7 +325,7 @@ def flash_attn_fn( # (b_size, s_k) = key_padding_mask.struct_info.shape[:2] # if attn_bias is None: # attn_bias = nn.emit(relax.op.zeros((b_size, 1, 1, s_k), dtype=query.struct_info.dtype)) -# key_mask = nn.emit(relax.op.bitwise_not(relax.op.reshape(key_padding_mask, (b_size, 1, 1, s_k)))) +# key_mask = nn.emit(relax.op.logical_not(relax.op.reshape(key_padding_mask, (b_size, 1, 1, s_k)))) # attn_bias = nn.emit(relax.op.masked_fill(attn_bias, key_mask, get_type_min_val(query))) # batch_size, seq_len, _ = query.struct_info.shape @@ -371,7 +405,15 @@ def __init__( raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.') self.out_proj = Linear(self.d_model, self.d_model, dtype, bias=False) - def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True): + def forward( + self, + x: relax.Expr, + all_seq_len_shape: Optional[relax.Expr]=None, + past_key_value: Optional[Tuple[relax.Expr]]=None, + attn_bias: Optional[relax.Expr]=None, + attention_mask: Optional[relax.Expr] = None, + is_causal: bool=True, + ): qkv = self.Wqkv(x) if self.clip_qkv: qkv = nn.emit(relax.op.clip(qkv, min=relax.const(-self.clip_qkv), max=relax.const(self.clip_qkv))) @@ -384,12 +426,13 @@ def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, i dtype = query.struct_info.dtype query = nn.emit(relax.op.astype(self.q_ln(query), dtype)) key = nn.emit(relax.op.astype(self.k_ln(key), dtype)) - attn_out = self.attn_fn( + attn_out, past_key_value = self.attn_fn( query, key, value, self.n_heads, self.d_model, + all_seq_len_shape=all_seq_len_shape, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, @@ -397,7 +440,7 @@ def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, i is_causal=is_causal, needs_weights=False, ) - return (self.out_proj(attn_out[0]), attn_out[1]) + return self.out_proj(attn_out), past_key_value ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention} @@ -442,7 +485,8 @@ def __init__(self, config: MPTConfig): def forward( self, hidden_states: relax.Expr, - past_key_value: Tuple[relax.Expr], + all_seq_len_shape: Optional[relax.Expr]=None, + past_key_value: Optional[Tuple[relax.Expr]]=None, attn_bias: Optional[relax.Expr] = None, attention_mask: Optional[relax.Expr] = None, is_causal: bool=True, @@ -451,8 +495,9 @@ def forward( hidden_states = self.norm_1(hidden_states) # Self Attention - (hidden_states, present_key_value) = self.attn( + hidden_states, present_key_value = self.attn( hidden_states, + all_seq_len_shape=all_seq_len_shape, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, @@ -465,7 +510,7 @@ def forward( hidden_states = self.ffn(hidden_states) hidden_states = nn.emit(residual + hidden_states) - return (hidden_states, present_key_value) + return hidden_states, present_key_value def 
attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id): @@ -493,9 +538,9 @@ def gen_slopes(n_heads, alibi_bias_max=8): slopes_len = slopes.struct_info.shape[0] slopes = nn.emit(relax.op.strided_slice( relax.op.concat( - [relax.op.strided_slice(slopes, [0], [relax.const(1)], [slopes_len], [relax.const(2)]), # [1::2] - relax.op.strided_slice(slopes, [0], [relax.const(0)], [slopes_len], [relax.const(2)])] # [::2] - ), [0], [relax.const(0)], [relax.const(n_heads)]) # slicing [:n_heads] + [relax.op.strided_slice(slopes, [0], [relax.const(1, dtype="int64")], [slopes_len], [relax.const(2)]), # [1::2] + relax.op.strided_slice(slopes, [0], [relax.const(0, dtype="int64")], [slopes_len], [relax.const(2)])] # [::2] + ), [0], [relax.const(0, dtype="int64")], [relax.const(n_heads, dtype="int64")]) # slicing [:n_heads] ) return nn.emit(relax.op.reshape(slopes, (1, n_heads, 1, 1))) @@ -528,7 +573,7 @@ def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi= def get_type_min_val(tensor): return relax.const( - tvm.tir.min_value(tensor.struct_info.dtype).value, + tir.min_value(tensor.struct_info.dtype).value, tensor.struct_info.dtype, ) @@ -547,7 +592,6 @@ def __init__(self, config: MPTConfig): self.n_heads = config.n_heads self.n_layers = config.n_layers self.max_seq_len = config.max_seq_len - self.return_dict = config.return_dict self.use_cache = config.use_cache self._attn_bias_initialized = False @@ -575,7 +619,7 @@ def __init__(self, config: MPTConfig): self.blocks = ModuleList([MPTBlock(config) for _ in range(config.n_layers)]) self.norm_f = norm_class(config.d_model, dtype=config.dtype) - def _attn_bias(self, dtype, attention_mask: Optional[relax.Expr]=None, prefix_mask: Optional[relax.Expr]=None, sequence_id: Optional[relax.Expr]=None): + def _attn_bias(self, dtype, attention_mask: Optional[relax.Expr]=None): if not self._attn_bias_initialized: if self.attn_bias_shape: self.attn_bias = nn.emit(relax.op.zeros(self.attn_bias_shape, dtype=dtype)) @@ -588,92 +632,38 @@ def _attn_bias(self, dtype, attention_mask: Optional[relax.Expr]=None, prefix_ma if self.attn_bias is not None: self.attn_bias = nn.emit(relax.op.astype(self.attn_bias, dtype)) attn_bias = self.attn_bias - if self.prefix_lm: - assert isinstance(attn_bias, relax.Expr) - assert isinstance(prefix_mask, relax.Expr) - attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask) - if self.attn_uses_sequence_id and sequence_id is not None: - assert isinstance(attn_bias, relax.Expr) - attn_bias = self._apply_sequence_id(attn_bias, sequence_id) if attention_mask is not None: - s_k = attention_mask.struct_info.shape[-1] + s_k = attention_mask.struct_info.shape[1] # seq_len if attn_bias is None: attn_bias = nn.emit(relax.op.zeros((1, 1, 1, s_k), dtype=dtype)) else: - _s_k = relax.op.maximum(relax.const(0), attn_bias.struct_info.shape[-1] - s_k) + def attn_bias_te_slicing(x: te.Tensor, seq_len: tvm.tir.Var): + return te.compute( + shape=(x.shape[0], x.shape[1], x.shape[2], seq_len), + fcompute=lambda i, j, k, m: x[i, j, k, x.shape[3] - seq_len + m], + name="attn_bias_slice", + ) + + s_k_end = attn_bias.struct_info.shape[3] # config.max_seq_len = 2048 + # TODO(vchernov): it can not be calculated in relax + # _s_k = relax.op.maximum(relax.const(0), s_k_end - s_k) # slicing attn_bias[:, :, :, _s_k:] - s_k_end = attn_bias.struct_info.shape[3] - attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [3], [_s_k], [s_k_end])) - if prefix_mask is not None and attention_mask.struct_info.shape != 
prefix_mask.struct_info.shape: - raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.') + # Need to use _s_k instead of s_k_end - s_k (attn_bias.shape = [1, 32, 1, seq_len]) + # attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [3], [s_k_end - s_k], [s_k_end])) + attn_bias = nn.emit_te(attn_bias_te_slicing, attn_bias, s_k, primfunc_name_hint="attn_bias_slice") min_val = get_type_min_val(attn_bias) - attn_mask = nn.emit(relax.op.bitwise_not(relax.op.reshape(attention_mask, (-1, 1, 1, s_k)))) + attn_mask = nn.emit(relax.op.logical_not(relax.op.reshape(attention_mask, (-1, 1, 1, s_k)))) attn_bias = nn.emit(relax.op.masked_fill(attn_bias, attn_mask, min_val)) return (attn_bias, None) - def _apply_prefix_mask(self, attn_bias: relax.Expr, prefix_mask: relax.Expr): - s_k = attn_bias.struct_info.shape[-2] - s_q = attn_bias.struct_info.shape[-1] - if s_k != self.max_seq_len or s_q != self.max_seq_len: - raise ValueError('attn_bias does not match the expected shape. ' + f'The last two dimensions should both be {self.max_seq_len} ' + f'but are {s_k} and {s_q}.') - seq_len = prefix_mask.struct_info.shape[-1] - if seq_len > self.max_seq_len: - raise ValueError(f'prefix_mask sequence length cannot exceed max_seq_len={self.max_seq_len}') - # slicing attn_bias[..., :seq_len, :seq_len] - dims_len = attn_bias.struct_info.ndim - attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [dims_len - 2, dims_len - 1], [relax.const(0), relax.const(0)], [seq_len, seq_len])) - causal = nn.emit(relax.op.reshape(relax.op.tril(relax.op.ones((seq_len, seq_len), dtype="bool")), (1, 1, seq_len, seq_len))) - prefix = nn.emit(relax.op.reshape(prefix_mask, (-1, 1, 1, seq_len))) - cannot_attend = nn.emit(relax.op.bitwise_not(relax.op.logical_or(causal, relax.op.astype(prefix, "bool")))) - min_val = get_type_min_val(attn_bias) - attn_bias = nn.emit(relax.op.masked_fill(attn_bias, cannot_attend, min_val)) - return attn_bias - - def _apply_sequence_id(self, attn_bias: relax.Expr, sequence_id: relax.Expr): - seq_len = sequence_id.struct_info.shape[-1] - if seq_len > self.max_seq_len: - raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.max_seq_len}') - # slicing attn_bias[..., :seq_len, :seq_len] - dims_len = attn_bias.struct_info.ndim - attn_bias = nn.emit(relax.op.strided_slice(attn_bias, [dims_len - 2, dims_len - 1], [relax.const(0), relax.const(0)], [seq_len, seq_len])) - seq_id_l = nn.emit(relax.op.reshape(sequence_id, (-1, seq_len, 1))) - seq_id_r = nn.emit(relax.op.reshape(sequence_id, (-1, 1, seq_len))) - cannot_attend = nn.emit(relax.op.expand_dims(relax.op.logical_not(relax.op.equal(seq_id_l, seq_id_r)), axis=1)) - min_val = get_type_min_val(attn_bias) - attn_bias = nn.emit(relax.op.masked_fill(attn_bias, cannot_attend, min_val)) - return attn_bias - def forward( self, input_ids: relax.Expr, - past_key_values: Optional[List[Tuple[relax.Expr]]]=None, + all_seq_len_shape: Optional[relax.Expr]=None, + past_key_values: Optional[relax.Expr]=None, attention_mask: Optional[relax.Expr]=None, - prefix_mask: Optional[relax.Expr]=None, - sequence_id: Optional[relax.Expr]=None, - return_dict: Optional[bool]=None, use_cache: Optional[bool]=None ): - # return_dict = return_dict if return_dict is not None else self.return_dict - use_cache = use_cache if use_cache is not None else self.use_cache - if attention_mask is not None: - attention_mask = nn.emit(relax.op.astype(attention_mask, "bool")) - # TODO(vchernov): I'm not sure we 
should calculate it and can compare in Relax - # It is part from prepare_inputs_for_generation - dim1_len = attention_mask.struct_info.shape[1] - if relax.op.sum( - relax.op.strided_slice(attention_mask, [1], [dim1_len - 1], [dim1_len]) - ) != attention_mask.struct_info.shape[0]: - raise NotImplementedError('MPT does not support generation with right padding.') - # if prefix_mask is not None: - # prefix_mask = nn.emit(relax.op.astype(prefix_mask, "bool")) - # if not return_dict: - # raise NotImplementedError('return_dict False is not implemented yet for MPT') - # if self.prefix_lm and prefix_mask is None: - # raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.') - - S = input_ids.struct_info.shape[1] - assert S <= self.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.max_seq_len}' - tok_emb = self.wte(input_ids) if self.alibi: x = tok_emb @@ -696,16 +686,30 @@ def forward( # pos = nn.emit(relax.op.clip(pos - pos_diff, min=0)) # pos_emb = self.wpe(pos) # x = tok_emb + pos_emb - (attn_bias, attention_mask) = self._attn_bias(dtype=x.struct_info.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id) - if use_cache and past_key_values is None: - past_key_values = [() for _ in range(self.n_layers)] + (attn_bias, attention_mask) = self._attn_bias(dtype=x.struct_info.dtype, attention_mask=attention_mask) + + # decoder layers + if past_key_values is not None: + next_decoder_cache = () + else: + next_decoder_cache = None + for (b_idx, block) in enumerate(self.blocks): - past_key_value = past_key_values[b_idx] if past_key_values is not None else None - (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal) + past_key_value = (past_key_values[b_idx * 2], past_key_values[b_idx * 2 + 1]) if past_key_values is not None else None + x, key_value_cache = block( + x, + all_seq_len_shape=all_seq_len_shape, + past_key_value=past_key_value, + attn_bias=attn_bias, + attention_mask=attention_mask, + is_causal=self.is_causal + ) if past_key_values is not None: - past_key_values[b_idx] = past_key_value + next_decoder_cache += key_value_cache x = self.norm_f(x) - return x, past_key_values + if past_key_values is not None: + assert len(next_decoder_cache) == len(self.blocks) * 2 + return x, next_decoder_cache class MPTForCausalLM(nn.Module): @@ -715,50 +719,126 @@ def __init__(self, config: MPTConfig): self.transformer = MPTModel(config) self.dtype = config.dtype - self.return_dict = config.return_dict self.use_cache = config.use_cache + def prepare_attention_mask_for_generation(self, input_ids=None, src_len=None): + if src_len is not None: + seq_len = src_len.struct_info.values[0] + shape = R.shape([1, seq_len]) + return nn.emit(relax.op.ones(shape, dtype="bool")) + else: + return nn.emit(relax.op.astype(relax.op.ones_like(input_ids), dtype="bool")) + def forward( self, input_ids: relax.Expr, - past_key_values: Optional[List[Tuple[relax.Expr]]]=None, - attention_mask: Optional[relax.Expr]=None, - prefix_mask: Optional[relax.Expr]=None, - sequence_id: Optional[relax.Expr]=None, - return_dict: Optional[bool]=None, - use_cache: Optional[bool]=None + all_seq_len_shape: Optional[relax.Expr]=None, + past_key_values: Optional[relax.Expr]=None, ): - # return_dict = return_dict if return_dict is not None else self.return_dict - use_cache = use_cache if use_cache is not None else self.use_cache + attention_mask = 
self.prepare_attention_mask_for_generation(input_ids, all_seq_len_shape) - # It is part from prepare_inputs_for_generation - if past_key_values is not None: - # slicing input_ids[:, -1] - dim1_len = input_ids.struct_info.shape[1] - input_ids_slice = nn.emit(relax.op.strided_slice(input_ids, [1], [dim1_len - 1], [dim1_len])) - input_ids = nn.emit(relax.op.expand_dims(input_ids_slice, axis=-1)) - outputs = self.transformer( - input_ids=input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - prefix_mask=prefix_mask, - sequence_id=sequence_id, - return_dict=self.return_dict, - use_cache=use_cache + logits, key_value_cache = self.transformer( + input_ids=input_ids, + all_seq_len_shape=all_seq_len_shape, + past_key_values=past_key_values, + attention_mask=attention_mask, + use_cache = self.use_cache, ) - logits = nn.emit(relax.op.linear(outputs[0], self.transformer.wte.weight)) + def te_slicing(x: te.Tensor): + return te.compute( + shape=(1, 1, x.shape[-1]), + fcompute=lambda i, j, k: x[i, x.shape[1] - 1, k], + name="slice", + ) + + logits = nn.emit_te(te_slicing, logits, primfunc_name_hint="slice") + + logits = nn.emit(relax.op.linear(logits, self.transformer.wte.weight)) if logits.struct_info.dtype != "float32": logits = nn.emit(relax.op.astype(logits, "float32")) - return logits, outputs[1] + return logits, key_value_cache -def create_decoding_func(bb: relax.BlockBuilder, config: MPTConfig) -> Dict[int, str]: + +def create_kv_cache_func(bb: relax.BlockBuilder, config: MPTConfig) -> None: + init_shape = relax.ShapeExpr( + ( + config.max_seq_len, + config.n_heads, + config.d_model // config.n_heads, + ) + ) + with bb.function("create_kv_cache", []): + with bb.dataflow(): + zeros = bb.emit(relax.op.zeros(init_shape, config.dtype)) + caches = [] + f_kv_cache_create = relax.extern("vm.builtin.attention_kv_cache_create") + for _ in range(config.n_layers * 2): + caches.append( + bb.emit( + relax.Call( + f_kv_cache_create, + args=[zeros, init_shape, relax.PrimValue(0)], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + ) + gv = bb.emit_output(caches) + bb.emit_func_output(gv) + + +def create_decoding_func_with_kv_cache(bb: relax.BlockBuilder, config: MPTConfig) -> Dict[int, str]: pidx2pname: Dict[int, str] = {} + bsz = 1 + all_seq_len = tvm.tir.Var("n", "int64") + with bb.function("decode"): model = MPTForCausalLM(config) - input_ids = nn.Placeholder((1, 1), dtype="int32", name="input_ids") + input_ids = nn.Placeholder((bsz, 1), dtype="int32", name="input_ids") + all_seq_len_shape = relax.Var( + "all_seq_len", relax.ShapeStructInfo((all_seq_len,)) + ) + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo( + [relax.ObjectStructInfo() for _ in range(config.n_layers * 2)] + ), + ) + with bb.dataflow(): + logits, key_value_cache = model( + input_ids, all_seq_len_shape, past_key_values=past_key_values + ) + params = [ + input_ids, + all_seq_len_shape, + past_key_values, + ] + model.parameters() + + named_params = named_parameters(model) + for i, (name, param) in enumerate(named_params.items()): + pidx2pname[i] = name + assert param.same_as(params[i + 3]) + + gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var("decode") + bb.update_func(gv, mod[gv].with_attr("num_input", 3)) + + return pidx2pname + + +def create_decoding_func_wo_kv_cache(bb: relax.BlockBuilder, config: MPTConfig) -> Dict[int, str]: + pidx2pname: Dict[int, str] = {} + bsz = 1 + seq_len = tvm.tir.Var("n", "int64") + 
+ with bb.function("decode"): + model = MPTForCausalLM(config) + input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") with bb.dataflow(): logits, states = model(input_ids) @@ -801,40 +881,47 @@ def get_model(args, hf_config): model_path = args.model_path dtype = args.quantization.model_dtype - # Recommendation from https://huggingface.co/mosaicml/mpt-7b-instruct - max_seq_len = args.max_seq_len if args.max_seq_len is not None and args.max_seq_len > 0 else 4096 # 4096 recommended + + if args.max_seq_len is not None and args.max_seq_len > 0: + max_seq_len = args.max_seq_len + elif hf_config["max_seq_len"] > 0: + max_seq_len = hf_config["max_seq_len"] + else: + # Recommendation from https://huggingface.co/mosaicml/mpt-7b-instruct + max_seq_len = 4096 hf_config.update({"max_seq_len": max_seq_len}) - # hf_config.update({"max_new_tokens": args.seq_len}) + hf_config.update({"use_cache": args.use_kv_cache}) config = MPTConfig(**hf_config, dtype=dtype) bb = relax.BlockBuilder() - pidx2pname = create_decoding_func(bb, config) + pidx2pname = None + if config.use_cache: + create_kv_cache_func(bb, config) + pidx2pname = create_decoding_func_with_kv_cache(bb, config) + else: + pidx2pname = create_decoding_func_wo_kv_cache(bb, config) create_softmax_func(bb, config) create_metadata_func( bb, model_name=model_name, - max_window_size=128, # TODO: temporal limit for max output length, change to -1 after tests + max_window_size=-1, stop_tokens=[0], - add_prefix_space=False, # TODO: what is it? + add_prefix_space=False, ) mod = bb.get() - def f_convert_pname_fwd(pname: str) -> str: - return pname - pname2binname = load_torch_pname2binname_map( - model_path, set(pidx2pname.values()), f_convert_pname_fwd + model_path, set(pidx2pname.values()) ) - def f_convert_param_bkwd(torch_pname: str, raw_param): - return [(torch_pname, raw_param)] - args.pidx2pname = pidx2pname args.pname2binname = pname2binname - args.f_convert_pname_fwd = f_convert_pname_fwd - args.f_convert_param_bkwd = f_convert_param_bkwd + args.f_convert_pname_fwd = lambda pname: pname + args.f_convert_param_bkwd = lambda torch_pname, raw_param: [ + (torch_pname, raw_param.astype(dtype)) + ] return mod, [None] * len(pidx2pname)
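To summarize what the final pair of patches change at the decode-loop level: with a KV cache the runtime feeds only the newest token (plus its absolute position) into the compiled "decode" function, while without a cache it re-feeds the full accumulated token sequence on every step, mirroring the kv_cache_.empty() branches added to llm_chat.cc. The following is a conceptual, pure-Python sketch of that control flow only; the decode_with_cache / decode_full_sequence callables are toy stand-ins for the compiled relax functions, not real MLC-LLM APIs.

# Conceptual sketch of the two decode strategies; not MLC-LLM code.
from typing import Callable, List, Optional

def generate(prompt: List[int],
             steps: int,
             decode_with_cache: Optional[Callable[[int, int], int]] = None,
             decode_full_sequence: Optional[Callable[[List[int]], int]] = None) -> List[int]:
    """Toy generation loop: with a cache only the last token is fed,
    without one the whole sequence so far is recomputed each step."""
    output = list(prompt)          # plays the role of full_output_ids_
    total_seq_len = len(prompt)
    for _ in range(steps):
        if decode_with_cache is not None:
            total_seq_len += 1
            # KV-cache path: newest token plus its absolute position.
            next_token = decode_with_cache(output[-1], total_seq_len)
        else:
            # No-cache path: the full accumulated sequence is re-fed.
            next_token = decode_full_sequence(output)
        output.append(next_token)
    return output

# Toy stand-in model that just "predicts" previous token + 1.
print(generate([1, 2, 3], 3, decode_with_cache=lambda tok, pos: tok + 1))
print(generate([1, 2, 3], 3, decode_full_sequence=lambda toks: toks[-1] + 1))

Re-feeding the whole sequence makes each step redo the work for all previous tokens, which is why the KV-cache build remains the default in the patch (use_cache stays True unless the flag above overrides it).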