Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support yarn in turbomind backend #2519

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions lmdeploy/pytorch/backends/default/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,12 @@ def __init__(self,
self.register_buffer('inv_freq', inv_freq, persistent=False)

# get mscale
self.mscale = float(
yarn_get_mscale(self.scaling_factor, self.mscale) /
yarn_get_mscale(self.scaling_factor, self.mscale_all_dim))
if yarn_params.attention_factor is not None:
self.mscale = yarn_params.attention_factor
else:
self.mscale = float(
yarn_get_mscale(self.scaling_factor, self.mscale) /
yarn_get_mscale(self.scaling_factor, self.mscale_all_dim))
if self.mscale == 1.0:
self.mscale = None

Expand Down Expand Up @@ -334,10 +337,10 @@ def build(
return LlamaDynamicNTKScalingRotaryEmbedding(
dim, base, scaling_factor, max_position_embeddings)
elif emb_type == RopeType.Llama3:
return Llama3RotaryEmbeddingImpl(dim, base, scaling_factor,
llama3_params.low_freq_factor,
llama3_params.high_freq_factor,
max_position_embeddings)
return Llama3RotaryEmbeddingImpl(
dim, base, scaling_factor, llama3_params.low_freq_factor,
llama3_params.high_freq_factor,
llama3_params.original_max_position_embeddings)
elif emb_type == RopeType.Yarn:
return YarnRotaryEmbeddingImpl(dim,
base,
Expand Down
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/backends/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class YarnParameters:
beta_slow: float = 1
mscale: int = 1
mscale_all_dim: int = 0
attention_factor: int = None


@dataclass
Expand All @@ -39,6 +40,7 @@ class Llama3Parameters:
"""llama3 rope parameters."""
low_freq_factor: float = 1.0
high_freq_factor: float = 4.0
original_max_position_embeddings: int = 8192


class RotaryEmbeddingImpl(ABC):
Expand Down
10 changes: 6 additions & 4 deletions lmdeploy/pytorch/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from transformers.configuration_utils import PretrainedConfig

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType,
SiluAndMul, build_rotary_embedding)
from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm,
SiluAndMul, build_rotary_embedding,
build_rotary_params)
from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear,
build_qkv_proj, build_rowwise_linear)
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
Expand Down Expand Up @@ -239,15 +240,16 @@ def __init__(self,
device=device)

# build rotary embedding
emb_type = RopeType.LinearScaling
# emb_type = RopeType.LinearScaling
rope_params = build_rotary_params(config)
rope_dim = config.hidden_size // config.num_attention_heads
rope_max_pos_emb = config.max_position_embeddings
rope_base = config.rope_theta
self.rotary_emb = build_rotary_embedding(
rope_dim,
rope_max_pos_emb,
rope_base,
emb_type=emb_type,
**rope_params,
)

def forward(
Expand Down
1 change: 1 addition & 0 deletions lmdeploy/pytorch/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from .rotary_embedding import RopeType # noqa: F401
from .rotary_embedding import YarnParameters # noqa: F401
from .rotary_embedding import build_rotary_embedding # noqa: F401
from .rotary_embedding import build_rotary_params # noqa: F401
79 changes: 79 additions & 0 deletions lmdeploy/pytorch/nn/rotary_embedding.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,91 @@
# Copyright (c) OpenMMLab. All rights reserved.
from torch import Tensor, nn
from transformers import PretrainedConfig

from ..backends import OpType, get_backend
from ..backends.rotary_embedding import (Llama3Parameters,
LongRoPEScalingParameters, RopeType,
YarnParameters)


def _get_default_rope_parameters(config: PretrainedConfig):
"""get default rope parameters."""
return dict(emb_type=RopeType.Default, scaling_factor=1.0)


def _get_linear_scaling_rope_parameters(config: PretrainedConfig):
"""get linear rope parameters."""
rope_scaling = config.rope_scaling
scaling_factor = rope_scaling['factor']
return dict(emb_type=RopeType.LinearScaling, scaling_factor=scaling_factor)


def _get_dynamic_ntk_parameters(config: PretrainedConfig):
"""get dynamic ntk parameters."""
rope_scaling = config.rope_scaling
scaling_factor = rope_scaling['factor']
return dict(emb_type=RopeType.DynamicNTKScaling,
scaling_factor=scaling_factor)


def _get_yarn_parameters(config: PretrainedConfig):
"""get yarn parameters."""
rope_scaling = config.rope_scaling
scaling_factor = rope_scaling['factor']
params = YarnParameters()
params.attention_factor = rope_scaling.get('attention_factor',
params.attention_factor)
params.beta_fast = rope_scaling.get('beta_fast', params.beta_fast)
params.beta_slow = rope_scaling.get('beta_slow', params.beta_slow)
return dict(emb_type=RopeType.Yarn,
scaling_factor=scaling_factor,
yarn_params=params)


def _get_longrope_parameters(config: PretrainedConfig):
"""get longrope parameters."""
rope_scaling = config.rope_scaling
params = LongRoPEScalingParameters()
scaling_factor = rope_scaling['factor']
params.long_factor = rope_scaling.long_factor
params.short_factor = rope_scaling.long_factor
params.original_max_position_embeddings = rope_scaling.get(
'original_max_position_embeddings', config.max_position_embeddings)
return dict(emb_type=RopeType.Yarn,
scaling_factor=scaling_factor,
longrope_params=params)


def _get_llama3_parameters(config: PretrainedConfig):
"""get llama rope parameters."""
rope_scaling = config.rope_scaling
params = Llama3Parameters()
scaling_factor = rope_scaling['factor']
params.low_freq_factor = rope_scaling['low_freq_factor']
params.high_freq_factor = rope_scaling['high_freq_factor']
params.original_max_position_embeddings = rope_scaling.get(
'original_max_position_embeddings',
params.original_max_position_embeddings)
return dict(emb_type=RopeType.Llama3,
scaling_factor=scaling_factor,
llama3_params=params)


def build_rotary_params(config: PretrainedConfig):
"""get scaling_factor rotary params, and emb_type."""
params = dict(emb_type=RopeType.Default)
if config.rope_scaling is not None:
rope_type_str = config.rope_scaling.get('rope_type', 'default')
build_funcs = dict(default=_get_default_rope_parameters,
linear=_get_linear_scaling_rope_parameters,
dynamic=_get_dynamic_ntk_parameters,
yarn=_get_yarn_parameters,
longrope=_get_longrope_parameters,
llama3=_get_llama3_parameters)
params.update(build_funcs[rope_type_str](config))
return params


def build_rotary_embedding(dim: int,
max_position_embeddings: int = 2048,
base: int = 10000,
Expand Down
3 changes: 3 additions & 0 deletions lmdeploy/turbomind/deploy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,16 @@ class ModelConfig:
class AttentionConfig:
rotary_embedding: int = 128
rope_theta: float = 10000.0
attention_factor: float = None
max_position_embeddings: int = 0
original_max_position_embeddings: int = 0
rope_scaling_type: str = ''
rope_scaling_factor: float = 0.0
use_dynamic_ntk: int = 0
low_freq_factor: float = 1.0
high_freq_factor: float = 1.0
beta_fast: float = 32.0
beta_slow: float = 1.0
use_logn_attn: int = 0
cache_block_seq_len: int = 64

Expand Down
12 changes: 11 additions & 1 deletion lmdeploy/turbomind/deploy/source_model/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,9 @@ def model_info(self):
scaling_type = ''
low_freq_factor = 1.0
high_freq_factor = 1.0
attention_factor = -1.0
beta_fast = 32.0
beta_slow = 1.0
original_max_position_embeddings = 0
if isinstance(rope_scaling, dict):
llama2_scaling_type = model_arg['rope_scaling'].get('type', '')
Expand All @@ -236,6 +239,10 @@ def model_info(self):
else llama3_scaling_type
if scaling_type == 'dynamic':
use_dynamic_ntk = 1
attention_factor = model_arg['rope_scaling'].get(
'attention_factor', None)
beta_fast = model_arg['rope_scaling'].get('beta_fast', 32.0)
beta_slow = model_arg['rope_scaling'].get('beta_slow', 1.0)

return dict(
num_layer=num_layer,
Expand All @@ -250,4 +257,7 @@ def model_info(self):
rope_scaling_type=scaling_type,
rope_scaling_factor=scaling_factor,
low_freq_factor=low_freq_factor,
high_freq_factor=high_freq_factor)
high_freq_factor=high_freq_factor,
attention_factor=attention_factor,
beta_fast=beta_fast,
beta_slow=beta_slow)
5 changes: 5 additions & 0 deletions src/turbomind/kernels/attention/attention_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,17 @@ struct AttentionParams {
// rotary embedding
int rotary_embedding_dim;
float rotary_embedding_base;
float rope_scaling_factor;
float attention_scaling;
int max_position_embeddings;
float rope_ti_scale; // used for linear RoPE scaling
// the following 3 parameters are used by llama3
float llama3_inv_scaling_factor;
float llama3_alpha;
float llama3_beta;
// the following are use by yarn
float yarn_ramp_min;
float yarn_ramp_max;

// log(n) attention
bool use_logn_attn;
Expand Down
4 changes: 4 additions & 0 deletions src/turbomind/kernels/attention/attention_universal.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,13 @@ struct AttentionUniversal {
params.rotary_embedding_dim,
rope_base,
params.rope_ti_scale,
params.rope_scaling_factor,
params.llama3_inv_scaling_factor,
params.llama3_alpha,
params.llama3_beta,
params.yarn_ramp_min,
params.yarn_ramp_max,
params.attention_scaling,
std::integral_constant<int, kVecSize>{});
PRAGMA_UNROLL
for (int s = 0; s < ITER_S; ++s) {
Expand Down
Loading
Loading