Add RoPE Interpolation #3564

Merged · 12 commits · Jul 12, 2023
29 changes: 29 additions & 0 deletions model/model_training/configs/config.yaml
@@ -779,3 +779,32 @@ debug:
verbose: true
num_train_epochs: 0.2
dtype: fp32

rope_scaling_test:
dtype: bf16
log_dir: "llama_log_7b"
learning_rate: 1e-5
model_name: "huggyllama/llama-7b"
deepspeed_config: configs/zero_config_falcon.json
output_dir: llama
weight_decay: 0.0
max_length: 4048
warmup_steps: 100
gradient_checkpointing: true
gradient_accumulation_steps: 2
per_device_train_batch_size: 1
per_device_eval_batch_size: 1
eval_steps: 100
save_steps: 500
num_train_epochs: 8
save_total_limit: 4
use_flash_attention: false
residual_dropout: 0.3
residual_dropout_lima: true
log_wandb: true
peft_model: true
peft_type: "lora"
superhot: true
superhot_config:
type: linear
scale: 2
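
For orientation: with `type: linear` and `scale: 2`, position indices are divided by two before the rotary frequencies are applied, so an input near the `max_length: 4048` above still lands inside the 2048-position range LLaMA was pretrained on. Below is a minimal sketch of that position scaling, assuming the usual LLaMA defaults (head dim 128, base 10000, 2048 trained positions); it is illustrative only, not part of the diff.

# Illustrative sketch of linear RoPE interpolation (not part of this PR's diff).
import torch

dim, base, trained_ctx, scale = 128, 10000, 2048, 2

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

# Linear interpolation divides the position index by `scale`, so position
# 4095 is embedded as if it were position 2047.5 -- inside the trained range.
positions = torch.arange(2 * trained_ctx).float() / scale
freqs = torch.einsum("i,j->ij", positions, inv_freq)  # (4096, 64) angle table
print(positions[-1])  # tensor(2047.5000)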
54 changes: 53 additions & 1 deletion model/model_training/models/patching.py
@@ -6,12 +6,13 @@

import torch.nn as nn
import transformers
from transformers import GPTNeoXForCausalLM, GPTNeoXModel, LlamaForCausalLM, LlamaModel
from transformers import AutoConfig, GPTNeoXForCausalLM, GPTNeoXModel, LlamaForCausalLM, LlamaModel
from trlx.models.modeling_ppo import AutoModelForCausalLMWithHydraValueHead

from .patching_llama import llama_forward_with_flash_attn
from .patching_neox import neox_forward_with_flash_attn
from .reward_model import GPTNeoXRewardModel
from .rope import LlamaDynamicScaledRotaryEmbedding, LlamaLinearScaledRope, LlamaNTKScaledRope, RWNTKScaledRope

SUPPORTED_MODELS = [
GPTNeoXModel,
@@ -176,3 +177,54 @@ def patch_model(
if resid_pdrop is not None and resid_pdrop > 0:
add_dropout(getattr(layer, attention_key), _patched_attn_forward, resid_pdrop)
add_dropout(getattr(layer, mlp_key), _patched_mlp_forward, resid_pdrop)


class RopePatch:
def __init__(self, model_name, **kwargs):
self.args = kwargs
rope_type = self.args.pop("type")
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
architecture = config.architectures
if architecture:
self.model_name = architecture[0]
if "RWForCausalLM" in architecture:
self.architecture = "RWForCausalLM"
if rope_type == "ntk":
self.patch_fun = RWNTKScaledRope
else:
raise NotImplementedError()
elif "LlamaForCausalLM" in architecture:
self.architecture = "LlamaForCausalLM"
if rope_type == "linear":
self.patch_fun = LlamaLinearScaledRope
elif rope_type == "ntk":
self.patch_fun = LlamaNTKScaledRope
elif rope_type == "dynamic-ntk":
self.patch_fun = LlamaDynamicScaledRotaryEmbedding
else:
raise NotImplementedError()
else:
raise NotImplementedError()

@classmethod
def from_config(cls, config):
model_name = config.model_name
args = config.superhot_config
return cls(model_name, **args)

def patch(self, model):
if self.architecture == "RWForCausalLM":
self.patch_rw_model(model, **self.args)
elif self.architecture == "LlamaForCausalLM":
self.patch_llama_model(model, **self.args)
else:
raise NotImplementedError()

def patch_rw_model(self, model, **kwargs):
for each in model.transformer.h:
each.self_attention.maybe_rotary = self.patch_fun(model.config.head_dim, **kwargs)

def patch_llama_model(self, model, **kwargs):
kwargs.update({"device": model.device})
for each in model.model.layers:
each.self_attn.rotary_emb = self.patch_fun(each.self_attn.head_dim, **kwargs)
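
The patch is a straight module swap: `patch_llama_model` replaces each decoder layer's `rotary_emb` with one of the scaled variants defined in `rope.py` below (and `patch_rw_model` does the same for `maybe_rotary` on RefinedWeb/Falcon-style models). Below is a minimal offline sketch of that swap on a tiny, randomly initialised LLaMA; the sizes are arbitrary and only for illustration.

# Illustrative sketch (not part of this PR's diff).
from transformers import LlamaConfig, LlamaForCausalLM

from model_training.models.rope import LlamaLinearScaledRope

cfg = LlamaConfig(hidden_size=256, intermediate_size=512, num_hidden_layers=2, num_attention_heads=4)
model = LlamaForCausalLM(cfg)

# The same operation RopePatch.patch_llama_model performs for every layer:
for layer in model.model.layers:
    layer.self_attn.rotary_emb = LlamaLinearScaledRope(
        layer.self_attn.head_dim, scale=2, device=model.device
    )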
187 changes: 187 additions & 0 deletions model/model_training/models/rope.py
@@ -0,0 +1,187 @@
import torch


# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
def rotate_half(x):
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in torch < 1.8.0


class RWNTKScaledRope(torch.nn.Module):

"""
NTK-Scaled RoPE for RefinedWebModel
"""

def __init__(
self,
head_dim: int,
base=10000,
alpha: int = 2,
):
super().__init__()
self.alpha = alpha
base = base * self.alpha ** (head_dim / (head_dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.head_dim = head_dim
self.seq_len_cached = None
self.batch_size_cached = None
self.cos_cached: torch.Tensor | None = None
self.sin_cached: torch.Tensor | None = None

def cos_sin(
self,
seq_len: int,
device="cuda",
dtype=torch.bfloat16,
) -> tuple[torch.Tensor, torch.Tensor]:
if seq_len != self.seq_len_cached:
self.seq_len_cached = seq_len
t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(device)

if dtype in [torch.float16, torch.bfloat16]:
emb = emb.float()

self.cos_cached = emb.cos()[None, :, :]
self.sin_cached = emb.sin()[None, :, :]

self.cos_cached = self.cos_cached.type(dtype)
self.sin_cached = self.sin_cached.type(dtype)

return self.cos_cached, self.sin_cached

def forward(self, q, k):
batch, seq_len, head_dim = q.shape
cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


class LlamaLinearScaledRope(torch.nn.Module):
"""
reference: https://huggingface.co/kaiokendev/superhot-13b-8k-no-rlhf-test
"""

def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):
super().__init__()
self.scale = 1 / scale
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
t *= self.scale
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
dtype = torch.get_default_dtype()
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
t *= self.scale
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)


class LlamaNTKScaledRope(torch.nn.Module):

"""
reference: https://github.com/jquesnelle/scaled-rope
"""

def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None):
super().__init__()
base = base * alpha ** (dim / (dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
dtype = torch.get_default_dtype()
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)


class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
"""
reference: https://github.com/jquesnelle/scaled-rope
"""

def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):
super().__init__()
self.ntk = ntk
self.base = base
self.dim = dim
self.max_position_embeddings = max_position_embeddings
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
dtype = torch.get_default_dtype()
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
if self.ntk:
base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (
self.dim / (self.dim - 2)
)
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
self.register_buffer("inv_freq", inv_freq)
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
if not self.ntk:
t *= self.max_position_embeddings / seq_len
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
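
To summarise the three LLaMA variants above: linear scaling compresses the position index `t` by `1/scale`; NTK scaling leaves positions alone but inflates the frequency base by `alpha ** (dim / (dim - 2))`; and the dynamic variant recomputes one of those adjustments from the actual sequence length once it exceeds `max_position_embeddings`. A small numeric comparison of the first two, assuming the usual LLaMA defaults (illustrative only):

# Illustrative comparison (not part of this PR's diff).
import torch

dim, base = 128, 10000
exponents = torch.arange(0, dim, 2).float() / dim

# Linear (LlamaLinearScaledRope): frequencies unchanged, positions halved.
inv_freq_linear = 1.0 / (base ** exponents)
t_linear = torch.arange(4096).float() / 2            # scale = 2

# NTK (LlamaNTKScaledRope): positions unchanged, base raised so the
# lowest frequencies rotate more slowly over long distances.
ntk_base = base * 2 ** (dim / (dim - 2))              # alpha = 2
inv_freq_ntk = 1.0 / (ntk_base ** exponents)
t_ntk = torch.arange(4096).float()

print(inv_freq_linear[-1].item(), inv_freq_ntk[-1].item())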
7 changes: 5 additions & 2 deletions model/model_training/trainer_sft.py
@@ -11,6 +11,7 @@
# from model_training.custom_datasets.formatting import DatasetEntry
from model_training.custom_datasets.dialogue_collator import DialogueDataCollator
from model_training.efficiency_utils import fuse_gelu
from model_training.models.patching import RopePatch
from model_training.models.peft_modeling import peft_model
from model_training.utils.utils import (
PerDatasetSampler,
@@ -362,7 +363,6 @@ def main():
)

train, evals = get_dataset(training_conf)

show_dataset_stats = (training_conf.verbose or training_conf.show_dataset_stats) and (
not training_conf.deepspeed or training_conf.local_rank == 0
)
@@ -416,9 +416,12 @@
sampler = None

metrics, preprocess_fns = get_metrics(training_conf, tokenizer)

model = get_model(training_conf, tokenizer)

superhot = RopePatch.from_config(training_conf) if training_conf.superhot else None
if superhot:
superhot.patch(model)

if training_conf.peft_model:
print("Using PEFT model")
model = peft_model(
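
With the `rope_scaling_test` config above, the flow in `main()` is: build the model, construct the patch from `superhot_config`, apply it, then wrap with PEFT. Below is a stand-in of the config plumbing; the `SimpleNamespace` only mimics the parsed training config (the real object comes from `config.yaml`), and fetching the LLaMA config requires Hub access.

# Illustrative sketch (not part of this PR's diff).
from types import SimpleNamespace

from model_training.models.patching import RopePatch

training_conf = SimpleNamespace(
    model_name="huggyllama/llama-7b",
    superhot=True,
    superhot_config={"type": "linear", "scale": 2},
)

superhot = RopePatch.from_config(training_conf) if training_conf.superhot else None
# superhot.patch_fun is now LlamaLinearScaledRope; superhot.patch(model) would
# swap it into every decoder layer before peft_model(...) wraps the model.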