Skip to content

Commit

Permalink
fix mrope
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga committed Dec 12, 2024
1 parent bcb4fb3 commit 2811814
Show file tree
Hide file tree
Showing 13 changed files with 32 additions and 9 deletions.
Binary file modified assets/wechat.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified assets/wechat_npu.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
#### Supervised Fine-Tuning on Multiple Nodes

```bash
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
```

#### Multimodal Supervised Fine-Tuning
Expand Down
4 changes: 2 additions & 2 deletions examples/README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
#### 在多机上进行指令监督微调

```bash
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_full/llama3_full_sft.yaml
```

#### 多模态指令监督微调
Expand Down
9 changes: 9 additions & 0 deletions src/llamafactory/data/collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,15 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tenso
feature["token_type_ids"] = token_type_ids[i]

features: Dict[str, "torch.Tensor"] = super().__call__(features)

if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
features["position_ids"], _ = self.model.get_rope_index(
input_ids=features["input_ids"],
image_grid_thw=mm_inputs.get("image_grid_thw", None),
video_grid_thw=mm_inputs.get("video_grid_thw", None),
attention_mask=features["attention_mask"],
)

if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled
cross_attention_mask = mm_inputs.pop("cross_attention_mask")
seq_len = features["input_ids"].size(1)
Expand Down
1 change: 1 addition & 0 deletions src/llamafactory/train/dpo/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def run_dpo(

data_collator = PairwiseDataCollatorWithPadding(
template=template,
model=model,
pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
**tokenizer_module,
Expand Down
1 change: 1 addition & 0 deletions src/llamafactory/train/kto/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def run_kto(

data_collator = KTODataCollatorWithPadding(
template=template,
model=model,
pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
**tokenizer_module,
Expand Down
2 changes: 1 addition & 1 deletion src/llamafactory/train/ppo/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def run_ppo(
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)

tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
data_collator = MultiModalDataCollatorForSeq2Seq(template=template, **tokenizer_module)
data_collator = MultiModalDataCollatorForSeq2Seq(template=template, model=model, **tokenizer_module)

# Create reference model and reward model
ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
Expand Down
5 changes: 4 additions & 1 deletion src/llamafactory/train/pt/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from transformers import Trainer
from typing_extensions import override

from ...extras.packages import is_transformers_version_equal_to_4_46
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

Expand All @@ -38,6 +38,9 @@ class CustomTrainer(Trainer):
def __init__(
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
) -> None:
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")

super().__init__(**kwargs)
self.finetuning_args = finetuning_args

Expand Down
5 changes: 4 additions & 1 deletion src/llamafactory/train/rm/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from typing_extensions import override

from ...extras import logging
from ...extras.packages import is_transformers_version_equal_to_4_46
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

Expand All @@ -48,6 +48,9 @@ class PairwiseTrainer(Trainer):
def __init__(
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
) -> None:
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")

super().__init__(**kwargs)
self.finetuning_args = finetuning_args
self.can_return_loss = True # override property to return eval_loss
Expand Down
4 changes: 3 additions & 1 deletion src/llamafactory/train/rm/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def run_rm(
template = get_template_and_fix_tokenizer(tokenizer, data_args)
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
data_collator = PairwiseDataCollatorWithPadding(template=template, pad_to_multiple_of=8, **tokenizer_module)
data_collator = PairwiseDataCollatorWithPadding(
template=template, model=model, pad_to_multiple_of=8, **tokenizer_module
)

# Update arguments
training_args.remove_unused_columns = False # important for multimodal and pairwise dataset
Expand Down
5 changes: 4 additions & 1 deletion src/llamafactory/train/sft/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

Expand All @@ -51,6 +51,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
def __init__(
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
) -> None:
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")

super().__init__(**kwargs)
self.finetuning_args = finetuning_args

Expand Down
1 change: 1 addition & 0 deletions src/llamafactory/train/sft/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def run_sft(

data_collator = SFTDataCollatorWith4DAttentionMask(
template=template,
model=model,
pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
block_diag_attn=model_args.block_diag_attn,
Expand Down

0 comments on commit 2811814

Please sign in to comment.