(If you would like to use flash attention and are not afraid to go into the transformers source code, look at the flash attention section below before starting training.)
If you are not using multi-GPU, you should be able to install the requirements from pip; if you want to use multi-GPU, you need to install the requirements from their GitHub repos.
- transformers
  pip install transformers
  or install from https://github.com/huggingface/transformers: pip install git+https://github.com/huggingface/transformers.git
  (if using flash attention, you need to git clone the repo, edit the modeling_llama.py file, and then install from the local repo)
- peft
  pip install peft
  or install from https://github.com/huggingface/peft: pip install git+https://github.com/huggingface/peft.git
- bitsandbytes
  pip install bitsandbytes
  if on Linux you're done; if on Windows, follow the steps from this repo:
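Once everything is installed, a quick sanity check like the one below (just a suggestion, not part of the repo) confirms the imports resolve and shows which versions you ended up with:

from importlib.metadata import version

# Print the installed version of each dependency to confirm the environment is set up.
for pkg in ("transformers", "peft", "bitsandbytes"):
    print(pkg, version(pkg))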
You can either edit the training script directly to load your desired data and process it into the correct format, or, if it is a dataset available in datasets, you can use the --dataset_name flag and edit the training script to use the correct dataset keys. The section to edit for custom data loading is around line 210 in finetune_peft_8bit.py.
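For reference, a custom loader in that section might look roughly like the sketch below. The file name and column keys are placeholders, and the exact shape the script expects depends on its own tokenization code:

from datasets import load_dataset

# Placeholder file and column names; swap in whatever your data actually uses.
raw = load_dataset("json", data_files={"train": "my_data.json"})

def to_text(example):
    # Collapse your fields into the single text column the tokenizer will see.
    return {"text": example["prompt"] + example["response"]}

raw = raw.map(to_text, remove_columns=raw["train"].column_names)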
Recommended starting parameters:
- epochs: 2-4
- batch_size: larger effective batch sizes like 128 or 256 are recommended (effective batch size = batch_size * gradient_accumulation_steps * num_gpus; see the small example after this list)
- block_size: 512 if not using flash attention, 600 if using flash attention (assuming 24 GB of VRAM and a batch of 2 with 64 gradient accumulation steps; you can push it up to 2048 if you have more VRAM)
- learning_rate: 2e-4 (from the LoRA paper)
- warmup_ratio: 0.06 (from the LoRA paper)
- lr_scheduler_type: linear (from the LoRA paper)
- weight_decay: 0.1 (worked well in experiments, but you may want to tune it as it's not based on any paper)
- optim: adamw or adamw_torch_fused (requires PyTorch 2.0)
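To make the effective batch size concrete, here is the arithmetic for the single-GPU example command further down (2 per device, 64 accumulation steps, 1 GPU):

# Effective batch size = per-device batch * gradient accumulation steps * number of GPUs.
per_device_train_batch_size = 2
gradient_accumulation_steps = 64
num_gpus = 1
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
print(effective_batch_size)  # 128, inside the recommended 128-256 range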
Launch the training script with your desired parameters.
(For LLaMA 2, if you get a 'private repo' error, add the flag --use_auth_token=<huggingface_auth_token> or log in with the Hugging Face CLI.)
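If you'd rather authenticate from Python than the CLI, the huggingface_hub login helper does the same thing (the token string below is a placeholder):

from huggingface_hub import login

# Paste your Hugging Face access token here (placeholder value shown).
login(token="hf_your_token_here")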
Example launch command:
python finetune_peft_8bit.py --num_train_epochs=2 --model_name_or_path=meta-llama/Llama-2-7b-hf --model_output_dir=LLaMA/LoRA/7B --output_dir=LLaMA/LoRA/train/7B --block_size=600 --per_device_train_batch_size=2 --gradient_accumulation_steps=64 --fp16=true --logging_steps=1 --log_level=info --learning_rate=2.0e-04 --lr_scheduler_type=linear --warmup_ratio=0.06 --weight_decay=0.1 --optim=adamw_torch_fused --evaluation_strategy=steps --save_strategy=steps --eval_steps=400 --save_steps=400 --output_dir="LoRA" --save_total_limit=3 --load_best_model_at_end=True --dataset_name=Dahoas/full-hh-rlhf --r=64 --lora_alpha=32 --lora_dropout=0.05
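The --r, --lora_alpha, and --lora_dropout flags correspond to a peft LoraConfig roughly like the one below. This is only a sketch of the mapping; the target_modules shown are the usual LLaMA attention projections and an assumption about what the script targets:

from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=64,              # --r
    lora_alpha=32,     # --lora_alpha
    lora_dropout=0.05, # --lora_dropout
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # assumption, not a script flag
    bias="none",
    task_type="CAUSAL_LM",
)
# model = get_peft_model(model, lora_config)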
4-bit training example:
python finetune_peft_8bit.py --num_train_epochs=1 --model_name_or_path=meta-llama/Llama-2-7b-hf --model_output_dir=LLaMA/LoRA/7B --output_dir=LLaMA/LoRA/train/7B --bits=4 --bf16 --quant_type=nf4 --double_quant=True --gradient_checkpointing=True --block_size=2048 --per_device_train_batch_size=4 --gradient_accumulation_steps=32 --logging_steps=1 --log_level=info --learning_rate=2.0e-04 --lr_scheduler_type=linear --warmup_ratio=0.06 --weight_decay=0.1 --optim=paged_adamw_32bit --evaluation_strategy=steps --save_strategy=steps --eval_steps=400 --save_steps=400 --output_dir="LoRA" --save_total_limit=3 --load_best_model_at_end=True --dataset_name=Dahoas/full-hh-rlhf --r=64 --lora_alpha=32 --lora_dropout=0.05 --max_grad_norm=0.3
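The 4-bit flags (--bits=4, --quant_type=nf4, --double_quant) line up with transformers' BitsAndBytesConfig. A rough sketch of the equivalent model load, assuming the script does something along these lines:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # --bits=4
    bnb_4bit_quant_type="nf4",              # --quant_type=nf4
    bnb_4bit_use_double_quant=True,         # --double_quant=True
    bnb_4bit_compute_dtype=torch.bfloat16,  # pairs with --bf16
)
# model = AutoModelForCausalLM.from_pretrained(
#     "meta-llama/Llama-2-7b-hf", quantization_config=bnb_config, device_map="auto"
# )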
Or, if using multiple GPUs (make sure you have the correct number of GPUs in the accelerate_config.yaml file):
accelerate launch --config_file=accelerate_config.yaml finetune_peft_8bit.py --multi_gpu=True --tensor_parallel=False --num_train_epochs=2 --model_name_or_path=meta-llama/Llama-2-13b-hf --model_output_dir=LLaMA/LoRA/13B --output_dir=LLaMA/LoRA/train/13B --block_size=600 --per_device_train_batch_size=2 --gradient_accumulation_steps=8 --fp16=true --logging_steps=1 --log_level=info --learning_rate=2.0e-04 --lr_scheduler_type=linear --warmup_ratio=0.06 --weight_decay=0.1 --optim=adamw_torch_fused --evaluation_strategy=steps --save_strategy=steps --eval_steps=400 --save_steps=400 --output_dir="LoRA" --save_total_limit=3 --load_best_model_at_end=True --remove_unused_columns=False --dataset_name=Dahoas/full-hh-rlhf --r=64 --lora_alpha=32 --lora_dropout=0.05
Or, if using multiple GPUs and tensor parallelism (go into the main training file and edit the get_device_map function with your number of devices and desired memory allocation):
python finetune_peft_8bit.py --num_train_epochs=2 --multi_gpu=True --tensor_parallel=True --model_name_or_path=meta-llama/Llama-2-70b-hf --model_output_dir=LLaMA/LoRA/70B --output_dir=LLaMA/LoRA/train/70B --block_size=600 --per_device_train_batch_size=2 --gradient_accumulation_steps=64 --fp16=true --logging_steps=1 --log_level=info --learning_rate=2.0e-04 --lr_scheduler_type=linear --warmup_ratio=0.06 --weight_decay=0.1 --optim=adamw_torch_fused --evaluation_strategy=steps --save_strategy=steps --eval_steps=400 --save_steps=400 --output_dir="LoRA" --save_total_limit=3 --load_best_model_at_end=True --dataset_name=Dahoas/full-hh-rlhf --r=64 --lora_alpha=32 --lora_dropout=0.05
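As a rough illustration of the get_device_map edit, the usual approach in transformers is a per-device memory budget passed to from_pretrained. The numbers below are placeholders, and the exact return value of the repo's get_device_map may differ:

# Hypothetical per-device memory budget for sharded/tensor-parallel loading.
# Tune the figures to what your GPUs can actually spare.
num_gpus = 4
max_memory = {i: "20GiB" for i in range(num_gpus)}
max_memory["cpu"] = "64GiB"  # optional CPU offload budget
# Typically passed as AutoModelForCausalLM.from_pretrained(..., device_map="auto", max_memory=max_memory)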
If you want to use flash attention, you need to change the source code of the transformers library (the repo you cloned earlier). Edit the modeling_llama.py file, located at transformers/src/transformers/models/llama/modeling_llama.py, and change the LlamaAttention class to the following:
class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        if (self.head_dim * num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {num_heads})."
            )
        self.q_proj = nn.Linear(
            hidden_size,
            num_heads * self.head_dim,
            bias=False,
        )
        self.k_proj = nn.Linear(
            hidden_size,
            num_heads * self.head_dim,
            bias=False,
        )
        self.v_proj = nn.Linear(
            hidden_size,
            num_heads * self.head_dim,
            bias=False,
        )
        self.o_proj = nn.Linear(
            num_heads * self.head_dim,
            hidden_size,
            bias=False,
        )
        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim)
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            self.register_buffer("bias", torch.tril(torch.ones(hidden_size, hidden_size))
                                 .view(1, 1, hidden_size, hidden_size))

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        offset = 0
        if past_key_value is not None:
            offset = past_key_value[0].shape[-2]
            kv_seq_len += offset
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, offset=offset)
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        if self.flash:
            attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask)
        else:
            past_key_value = (key_states, value_states)

            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
                    f" {attn_weights.size()}"
                )

            if attention_mask is not None:
                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                    raise ValueError(
                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                    )
                attn_weights = attn_weights + attention_mask
                attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

            # upcast attention to fp32
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_output = torch.matmul(attn_weights, value_states)

            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
                raise ValueError(
                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                    f" {attn_output.size()}"
                )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
The changes are:

self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if not self.flash:
    print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
    self.register_buffer("bias", torch.tril(torch.ones(hidden_size, hidden_size))
                         .view(1, 1, hidden_size, hidden_size))
and
if self.flash:
    attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask)
else:
    ...
After those changes, pip install the local repo by cd'ing into the transformers repo (root directory) and then running pip install . (note the trailing dot). Then, add/uncomment the following in the imports of the training script:
import torch.backends.cuda
torch.backends.cuda.enable_flash_sdp(enabled=True)
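To double-check that the fused kernel is actually available and enabled in your environment (a quick optional check, not part of the script):

import torch
import torch.backends.cuda

# True on PyTorch >= 2.0; if False, the patched attention falls back to the slow path.
print(hasattr(torch.nn.functional, "scaled_dot_product_attention"))
print(torch.backends.cuda.flash_sdp_enabled())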