styles (#6892)
sijunhe authored Sep 1, 2023
1 parent 96de177 commit a488002
Showing 3 changed files with 35 additions and 8 deletions.
22 changes: 15 additions & 7 deletions llm/finetune_generation.py
@@ -27,6 +27,7 @@
 from data import get_convert_example
 from utils import (
     CausalLMTrainer,
+    InTokensIterDatasetCallback,
     compute_metrics,
     get_lora_target_modules,
     get_prefix_tuning_params,
@@ -162,16 +163,22 @@ def main():
     train_ds, dev_ds = load_dataset(data_args.dataset_name_or_path, splits=["train", "dev"])
     # TODO(ZHUI & sijunhe): Temporary implementation. Generalize this logic and move to Trainer later.
     if training_args.resume_from_checkpoint is not None and data_args.lazy:
-        logger.warning(
-            f"Loading from '{training_args.resume_from_checkpoint}', manually skipping dataset and setting `ignore_data_skip` to True."
+        logger.info(
+            f"Loading from '{training_args.resume_from_checkpoint}' with `lazy=True`, manually skipping dataset and setting `ignore_data_skip` to True."
         )
         training_args.ignore_data_skip = True
         state = TrainerState.load_from_json(os.path.join(training_args.resume_from_checkpoint, "trainer_state.json"))
-        consumed_samples = (
-            state.global_step
-            * training_args.per_device_train_batch_size
-            * training_args.gradient_accumulation_steps
-            * training_args.dataset_world_size
-        )
+        if state.trial_params is not None and "intokens_global_step" in state.trial_params:
+            consumed_samples = state.trial_params["intokens_global_step"]
+        else:
+            consumed_samples = (
+                state.global_step
+                * training_args.per_device_train_batch_size
+                * training_args.gradient_accumulation_steps
+                * training_args.dataset_world_size
+            )
         logger.info(
             f"Skipping the first {consumed_samples} samples to warmup the dataset from checkpoint '{training_args.resume_from_checkpoint}'."
         )
         train_ds = train_ds.skip(consumed_samples)

@@ -299,6 +306,7 @@ def compute_metrics_do_generation(eval_preds):
             return_tensors="np",
         ),
         do_generation=data_args.eval_with_do_generation,
+        callbacks=[InTokensIterDatasetCallback()] if isinstance(train_ds, InTokensIterableDataset) else None,
         gen_args=gen_args,
         data_args=data_args,
     )
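Why the lookup order in the resume branch matters: with InTokens packing, one optimizer step can consume a variable number of raw examples, so `global_step * per_device_train_batch_size * gradient_accumulation_steps * dataset_world_size` is only an estimate, while the counter stored in `trial_params` is exact. The sketch below is a minimal, framework-free illustration of that fallback logic; `TrainerStateStub` and `resolve_consumed_samples` are illustrative names used here for clarity, not PaddleNLP APIs.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TrainerStateStub:
    """Stand-in for the two fields of TrainerState this logic reads."""
    global_step: int = 0
    trial_params: Optional[dict] = None


def resolve_consumed_samples(
    state: TrainerStateStub,
    per_device_train_batch_size: int,
    gradient_accumulation_steps: int,
    dataset_world_size: int,
) -> int:
    """Prefer the exact per-example counter saved by the callback; otherwise
    fall back to an estimate derived from the optimizer step count."""
    if state.trial_params is not None and "intokens_global_step" in state.trial_params:
        return state.trial_params["intokens_global_step"]
    return (
        state.global_step
        * per_device_train_batch_size
        * gradient_accumulation_steps
        * dataset_world_size
    )


# Example: 100 optimizer steps, batch 4, accumulation 8, 2 data-parallel ranks.
state = TrainerStateStub(global_step=100, trial_params={"intokens_global_step": 6150})
print(resolve_consumed_samples(state, 4, 8, 2))  # 6150 (exact, recorded by the callback)
state.trial_params = None
print(resolve_consumed_samples(state, 4, 8, 2))  # 6400 (estimate from step arithmetic)
```
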
18 changes: 17 additions & 1 deletion llm/utils.py
@@ -25,7 +25,8 @@
 from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler
 from sklearn.metrics import accuracy_score

-from paddlenlp.trainer import Trainer
+from paddlenlp.datasets import InTokensIterableDataset
+from paddlenlp.trainer import Trainer, TrainerCallback
 from paddlenlp.trainer.trainer_utils import has_length
 from paddlenlp.utils.log import logger

@@ -144,6 +145,21 @@ def get_lora_target_modules(model):
     return target_modules


+class InTokensIterDatasetCallback(TrainerCallback):
+    """
+    A [`TrainerCallback`] that records the InTokens dataset's consumed-example counter
+    (`intokens_global_step`) in the trainer state so that lazy training can resume correctly.
+    """
+
+    def on_step_end(self, args, state, control, **kwargs):
+        train_dataloader = kwargs["train_dataloader"]
+        if not isinstance(train_dataloader.dataset, InTokensIterableDataset):
+            raise ValueError("InTokensIterDatasetCallback expects `paddlenlp.datasets.InTokensIterableDataset`")
+        if state.trial_params is None:
+            state.trial_params = {}
+        state.trial_params["intokens_global_step"] = train_dataloader.dataset.intokens_global_step
+
+
 class CausalLMTrainer(Trainer):
     def __init__(self, do_generation: bool, gen_args, data_args, **kwargs):
         super().__init__(**kwargs)
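For context, the callback works because `state.trial_params` is part of the `TrainerState` that gets serialized to `trainer_state.json` at checkpoint time, which is exactly the file `finetune_generation.py` reads back on resume. A simplified sketch of that round trip, using stand-in classes rather than the real `paddlenlp.trainer` objects:

```python
import json
from dataclasses import dataclass
from typing import Optional


@dataclass
class MiniState:
    """Stand-in for TrainerState: only the fields the resume logic touches."""
    global_step: int = 0
    trial_params: Optional[dict] = None

    def save_to_json(self, path: str) -> None:
        with open(path, "w") as f:
            json.dump({"global_step": self.global_step, "trial_params": self.trial_params}, f)


@dataclass
class MiniDataset:
    intokens_global_step: int = 0


class MiniInTokensCallback:
    """Mirrors InTokensIterDatasetCallback.on_step_end: copy the dataset's
    per-example counter into the state so a resumed run can skip exactly
    that many samples."""

    def on_step_end(self, state: MiniState, dataset: MiniDataset) -> None:
        if state.trial_params is None:
            state.trial_params = {}
        state.trial_params["intokens_global_step"] = dataset.intokens_global_step


# One "training step": the dataset has consumed 37 packed examples so far.
state, ds = MiniState(global_step=1), MiniDataset(intokens_global_step=37)
MiniInTokensCallback().on_step_end(state, ds)
state.save_to_json("trainer_state.json")  # the resume path reads this file back
```
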
3 changes: 3 additions & 0 deletions paddlenlp/datasets/intokens_dataset.py
@@ -105,6 +105,7 @@ def __init__(self, data, tokenizer, max_length):
         self.data = data
         self.tokenizer = tokenizer
         self.max_length = max_length
+        self.intokens_global_step = 0

     def __iter__(self):
         batch_records, max_len = [], 0

@@ -114,6 +115,7 @@ def __iter__(self):
             to_append = (cur_len_so_far + len(record["input_ids"])) <= self.max_length
             if to_append:
                 batch_records.append(record)
+                self.intokens_global_step += 1
                 cur_len_so_far += len(record["input_ids"])
             else:
                 # exceed max length

@@ -124,6 +126,7 @@ def __iter__(self):
                 cur_len_so_far = 0
                 # append current data
                 batch_records.append(record)
+                self.intokens_global_step += 1
                 cur_len_so_far += len(record["input_ids"])
         if batch_records:
             padded_list = self._pad_batch_records(batch_records)
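The counter added above increments once for every raw record pulled from the stream, regardless of how records are grouped into packs. Below is an illustrative, framework-free sketch of that packing-plus-counting pattern; the real `InTokensIterableDataset` also pads and merges each pack via `_pad_batch_records`, which is elided here.

```python
from typing import Dict, Iterable, Iterator, List


class PackingIterable:
    """Greedy InTokens-style packing: each yielded batch holds as many records
    as fit within max_length tokens; intokens_global_step counts every raw
    record consumed so far."""

    def __init__(self, data: Iterable[Dict[str, list]], max_length: int):
        self.data = data
        self.max_length = max_length
        self.intokens_global_step = 0  # raw examples consumed so far

    def __iter__(self) -> Iterator[List[Dict[str, list]]]:
        batch_records, cur_len_so_far = [], 0
        for record in self.data:
            if cur_len_so_far + len(record["input_ids"]) <= self.max_length:
                batch_records.append(record)           # fits: keep packing
            else:
                yield batch_records                    # exceeds max length: flush
                batch_records, cur_len_so_far = [], 0
                batch_records.append(record)           # start a new pack
            self.intokens_global_step += 1
            cur_len_so_far += len(record["input_ids"])
        if batch_records:                              # flush the tail
            yield batch_records


# Usage: three records of lengths 6, 5 and 4 packed with max_length=10.
data = [{"input_ids": list(range(n))} for n in (6, 5, 4)]
ds = PackingIterable(data, max_length=10)
print([len(b) for b in ds])     # [1, 2] -> a pack of 1 record, then a pack of 2
print(ds.intokens_global_step)  # 3 raw records consumed
```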
