Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Rope scaling #3598

Merged
merged 5 commits into from
Jul 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions model/model_training/configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ rope_scaling_test:
dtype: bf16
log_dir: "llama_log_7b"
learning_rate: 1e-5
model_name: "huggyllama/llama-7b"
model_name: "meta-llama/Llama-2-13b-hf"
deepspeed_config: configs/zero_config.json
output_dir: llama
weight_decay: 0.0
Expand All @@ -811,7 +811,7 @@ rope_scaling_test:
superhot: true
superhot_config:
type: linear
scale: 2
scaling_factor: 2
datasets:
- dolly15k

Expand Down
29 changes: 20 additions & 9 deletions model/model_training/models/patching.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@
LlamaForCausalLM,
LlamaModel,
)
from transformers.models.llama.modeling_llama import (
LlamaDynamicNTKScalingRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
)
from trlx.models.modeling_ppo import AutoModelForCausalLMWithHydraValueHead

from .patching_falcon import falcon_forward_with_flash_attn
from .patching_llama import llama_forward_with_flash_attn
from .patching_neox import neox_forward_with_flash_attn
from .reward_model import GPTNeoXRewardModel
from .rope import LlamaDynamicScaledRotaryEmbedding, LlamaLinearScaledRope, LlamaNTKScaledRope, RWNTKScaledRope
from .rope import RWNTKScaledRope

SUPPORTED_MODELS = [
GPTNeoXModel,
Expand Down Expand Up @@ -200,25 +204,27 @@ def patch_model(
class RopePatch:
def __init__(self, model_name, **kwargs):
self.args = kwargs
rope_type = self.args.pop("type")
self.rope_type = self.args.pop("type")
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
if hasattr(config, "max_position_embeddings"):
self.args["max_position_embeddings"] = config.max_position_embeddings
if hasattr(config, "base"):
self.args["base"] = config.base
architecture = config.architectures
if architecture:
self.model_name = architecture[0]
if "FalconForCausalLM" in architecture or "RWForCausalLM" in architecture:
self.architecture = "FalconForCausalLM"
if rope_type == "ntk":
if self.rope_type == "ntk":
self.patch_fun = RWNTKScaledRope
else:
raise NotImplementedError()
elif "LlamaForCausalLM" in architecture:
self.architecture = "LlamaForCausalLM"
if rope_type == "linear":
self.patch_fun = LlamaLinearScaledRope
elif rope_type == "ntk":
self.patch_fun = LlamaNTKScaledRope
elif rope_type == "dynamic-ntk":
self.patch_fun = LlamaDynamicScaledRotaryEmbedding
if self.rope_type == "linear":
self.patch_fun = LlamaLinearScalingRotaryEmbedding
elif self.rope_type == "dynamic":
self.patch_fun = LlamaDynamicNTKScalingRotaryEmbedding
else:
raise NotImplementedError()
else:
Expand All @@ -230,6 +236,9 @@ def from_config(cls, config):
args = config.superhot_config
return cls(model_name, **args)

def update_config(self, model, scaling_factor):
model.config["rope_scaling"] = {"type": self.rope_type, "factor": scaling_factor}

def patch(self, model):
if self.architecture == "FalconForCausalLM":
self.patch_falcon_model(model, **self.args)
Expand All @@ -238,6 +247,8 @@ def patch(self, model):
else:
raise NotImplementedError()

self.update_config(model, self.args.get("scaling_factor"))

def patch_falcon_model(self, model, **kwargs):
for each in model.transformer.h:
each.self_attention.maybe_rotary = self.patch_fun(model.config.head_dim, **kwargs)
Expand Down
128 changes: 0 additions & 128 deletions model/model_training/models/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,131 +62,3 @@ def forward(self, q, k, past_key_values_length=0):
batch, seq_len, head_dim = q.shape
cos, sin = self.cos_sin(seq_len, past_key_values_length, q.device, q.dtype)
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


class LlamaLinearScaledRope(torch.nn.Module):
"""
reference: https://huggingface.co/kaiokendev/superhot-13b-8k-no-rlhf-test
"""

def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):
super().__init__()
self.scale = 1 / scale
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
t *= self.scale
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
dtype = torch.get_default_dtype()
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
t *= self.scale
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)


class LlamaNTKScaledRope(torch.nn.Module):

"""
reference: https://github.com/jquesnelle/scaled-rope
"""

def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None):
super().__init__()
base = base * alpha ** (dim / (dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
dtype = torch.get_default_dtype()
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)


class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
"""
reference: https://github.com/jquesnelle/scaled-rope
"""

def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):
super().__init__()
self.ntk = ntk
self.base = base
self.dim = dim
self.max_position_embeddings = max_position_embeddings
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
dtype = torch.get_default_dtype()
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
if self.ntk:
base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (
self.dim / (self.dim - 2)
)
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
self.register_buffer("inv_freq", inv_freq)
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
if not self.ntk:
t *= self.max_position_embeddings / seq_len
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)