diff --git a/llm/finetune_generation.py b/llm/finetune_generation.py
index db3d2abcd76d..ba6cccbf6e8d 100644
--- a/llm/finetune_generation.py
+++ b/llm/finetune_generation.py
@@ -27,6 +27,7 @@
 from data import get_convert_example
 from utils import (
     CausalLMTrainer,
+    InTokensIterDatasetCallback,
     compute_metrics,
     get_lora_target_modules,
     get_prefix_tuning_params,
@@ -162,16 +163,22 @@ def main():
     train_ds, dev_ds = load_dataset(data_args.dataset_name_or_path, splits=["train", "dev"])
     # TODO(ZHUI & sijunhe): Temporary implementation. Generalize this logic and move to Trainer later.
     if training_args.resume_from_checkpoint is not None and data_args.lazy:
-        logger.warning(
-            f"Loading from '{training_args.resume_from_checkpoint}', manually skipping dataset and setting `ignore_data_skip` to True."
+        logger.info(
+            f"Loading from '{training_args.resume_from_checkpoint}' with `lazy=True`, manually skipping dataset and setting `ignore_data_skip` to True."
         )
         training_args.ignore_data_skip = True
         state = TrainerState.load_from_json(os.path.join(training_args.resume_from_checkpoint, "trainer_state.json"))
-        consumed_samples = (
-            state.global_step
-            * training_args.per_device_train_batch_size
-            * training_args.gradient_accumulation_steps
-            * training_args.dataset_world_size
+        if state.trial_params is not None and "intokens_global_step" in state.trial_params:
+            consumed_samples = state.trial_params["intokens_global_step"]
+        else:
+            consumed_samples = (
+                state.global_step
+                * training_args.per_device_train_batch_size
+                * training_args.gradient_accumulation_steps
+                * training_args.dataset_world_size
+            )
+        logger.info(
+            f"Skipping the first {consumed_samples} samples to warm up the dataset from checkpoint '{training_args.resume_from_checkpoint}'."
         )
         train_ds = train_ds.skip(consumed_samples)
 
@@ -299,6 +306,7 @@ def compute_metrics_do_generation(eval_preds):
             return_tensors="np",
         ),
         do_generation=data_args.eval_with_do_generation,
+        callbacks=[InTokensIterDatasetCallback()] if isinstance(train_ds, InTokensIterableDataset) else None,
         gen_args=gen_args,
         data_args=data_args,
     )
diff --git a/llm/utils.py b/llm/utils.py
index 55c96d644f12..263046859828 100644
--- a/llm/utils.py
+++ b/llm/utils.py
@@ -25,7 +25,8 @@
 from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler
 from sklearn.metrics import accuracy_score
 
-from paddlenlp.trainer import Trainer
+from paddlenlp.datasets import InTokensIterableDataset
+from paddlenlp.trainer import Trainer, TrainerCallback
 from paddlenlp.trainer.trainer_utils import has_length
 from paddlenlp.utils.log import logger
 
@@ -144,6 +145,21 @@ def get_lora_target_modules(model):
     return target_modules
 
 
+class InTokensIterDatasetCallback(TrainerCallback):
+    """
+    A [`TrainerCallback`] that records the number of samples consumed by an [`InTokensIterableDataset`]
+    in `TrainerState.trial_params`, so that resuming from a checkpoint can skip the already-consumed samples.
+    """
+
+    def on_step_end(self, args, state, control, **kwargs):
+        train_dataloader = kwargs["train_dataloader"]
+        if not isinstance(train_dataloader.dataset, InTokensIterableDataset):
+            raise ValueError("InTokensIterDatasetCallback expects the train dataset to be a `paddlenlp.datasets.InTokensIterableDataset`")
+        if state.trial_params is None:
+            state.trial_params = {}
+        state.trial_params["intokens_global_step"] = train_dataloader.dataset.intokens_global_step
+
+
 class CausalLMTrainer(Trainer):
     def __init__(self, do_generation: bool, gen_args, data_args, **kwargs):
         super().__init__(**kwargs)
diff --git a/paddlenlp/datasets/intokens_dataset.py b/paddlenlp/datasets/intokens_dataset.py
index 96faf538cbf0..795d82d93e8f 100644
--- a/paddlenlp/datasets/intokens_dataset.py
+++ b/paddlenlp/datasets/intokens_dataset.py
@@ -105,6 +105,7 @@ def __init__(self, data, tokenizer, max_length):
         self.data = data
         self.tokenizer = tokenizer
         self.max_length = max_length
+        self.intokens_global_step = 0
 
     def __iter__(self):
         batch_records, max_len = [], 0
@@ -114,6 +115,7 @@ def __iter__(self):
             to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
             if to_append:
                 batch_records.append(record)
+                self.intokens_global_step += 1
                 cur_len_so_far += len(record["input_ids"])
             else:
                 # exceed max length
@@ -124,6 +126,7 @@ def __iter__(self):
             cur_len_so_far = 0
             # append current data
             batch_records.append(record)
+            self.intokens_global_step += 1
             cur_len_so_far += len(record["input_ids"])
         if batch_records:
             padded_list = self._pad_batch_records(batch_records)