[Trainer] Support skip data intervals #8989

Merged
merged 12 commits on Sep 23, 2024
20 changes: 18 additions & 2 deletions paddlenlp/trainer/argparser.py
@@ -24,7 +24,17 @@
from enum import Enum
from inspect import isclass
from pathlib import Path
from typing import Any, Dict, Iterable, NewType, Optional, Tuple, Union, get_type_hints
from typing import (
Any,
Dict,
Iterable,
NewType,
Optional,
Tuple,
Union,
get_args,
get_type_hints,
)

DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)
@@ -129,7 +139,13 @@ def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field):
# This is the value that will get picked if we do --field_name (without value)
kwargs["const"] = True
elif isclass(origin_type) and issubclass(origin_type, list):
kwargs["type"] = field.type.__args__[0]
# support one-dimensional and two-dimensional lists
if hasattr(get_args(field.type)[0], "__args__"):
kwargs["type"] = field.type.__args__[0].__args__[0]
kwargs["action"] = "append"
else:
kwargs["type"] = field.type.__args__[0]

kwargs["nargs"] = "+"
if field.default_factory is not dataclasses.MISSING:
kwargs["default"] = field.default_factory()
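To illustrate the effect of the two-dimensional list handling above, here is a minimal standalone sketch of the equivalent argparse configuration. It mirrors the patched branch rather than the actual PaddleNLP code path, and the field name is only an example:

import argparse
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class ExampleArgs:
    # A List[List[int]] field now maps to type=int, nargs="+", action="append".
    skip_data_intervals: Optional[List[List[int]]] = field(default=None)

parser = argparse.ArgumentParser()
parser.add_argument("--skip_data_intervals", type=int, nargs="+", action="append", default=None)

ns = parser.parse_args("--skip_data_intervals 10 20 --skip_data_intervals 30 40".split())
print(ns.skip_data_intervals)  # -> [[10, 20], [30, 40]]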
93 changes: 77 additions & 16 deletions paddlenlp/trainer/trainer.py
@@ -127,6 +127,7 @@
PREFIX_CHECKPOINT_DIR,
EvalLoopOutput,
EvalPrediction,
IntervalStrategy,
IterableDatasetShard,
OptimizerNames,
PredictionOutput,
@@ -139,6 +140,7 @@
get_scheduler,
has_length,
set_seed,
should_skip_data,
speed_metrics,
)
from .training_args import TrainingArguments
@@ -287,9 +289,16 @@

# Seed must be set before instantiating the model when using model
set_seed(seed=self.args.seed)

self._skip_global_steps = 0 # total skip global steps
self._skip_steps_since_last_logged = 0 # skip steps since last logged
if model is None:
raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
logger.warning("Model is None.")
self.model = None
self.train_dataset = train_dataset
self.tokenizer = tokenizer
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
self.data_collator = data_collator if data_collator is not None else default_collator
return


if self.args.to_static:
model = paddle.jit.to_static(model)
@@ -945,6 +954,7 @@
step_control = 0 # used in loop control, reset to 0 after every step
self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

step = -1
for step, inputs in enumerate(epoch_iterator):
if self.args.use_hybrid_parallel and self.args.sep_parallel_degree > 1:
inputs = split_inputs_sequence_dim(inputs)
@@ -981,6 +991,44 @@
steps_trained_progress_bar.close()
steps_trained_progress_bar = None

if should_skip_data(self.state.global_step, self.args.skip_data_intervals):
# skip this step

if (step_control + 1) % self.args.gradient_accumulation_steps == 0 or (

# last step in epoch but step is always smaller than gradient_accumulation_steps
steps_in_epoch <= args.gradient_accumulation_steps
and (step + 1) == steps_in_epoch
):
# update current global step and skip step
self.state.global_step += 1
self._skip_global_steps += 1
self._skip_steps_since_last_logged += 1


self.state.epoch = epoch + (step + 1) / steps_in_epoch


if self.state.global_step == 1 and self.args.logging_first_step:
self.control.should_log = True
if (

self.args.logging_strategy == IntervalStrategy.STEPS
and self.state.global_step % self.args.logging_steps == 0
):
self.control.should_log = True


self.control.should_evaluate = False
self.control.should_save = False


# log loss and memory usage
self._maybe_log_save_evaluate(tr_loss, model, epoch, ignore_keys_for_eval, inputs=inputs)
Collaborator

This probably isn't needed anymore either, is it?

Contributor Author
@greycooker greycooker Aug 27, 2024

_maybe_log_save_evaluate is kept here so that we still go through:

1. Resetting tr_loss:
tr_loss.subtract_(tr_loss)

2. Updating _globalstep_last_logged:
self._globalstep_last_logged = self.state.global_step

3. The normal eval flow. Otherwise the consumed_samples calculation in the final eval would be wrong: https://github.com/PaddlePaddle/PaddleNLP/blob/48820cbc1fe986004f817c0517886735675732d2/paddlenlp/trainer/trainer.py#L2792C6-L2797C18

Collaborator

My main concern is whether hitting all the various callbacks (eval, save, and so on) while skipping data could cause problems.
Or should we only handle the data here and trigger nothing else, while still applying the updates to step and the like?

self._print_timer()
step_control = 0

else:
step_control += 1
if self.state.global_step >= self.state.max_steps:
break


self.timers and self.timers("read-data").start()
Collaborator

I suspect a lot of this may not be needed. Without any computation, is it safe to trigger some of these callbacks?

Contributor Author

This is here to make decisions such as whether to run eval, save, or stop training. Running the callbacks directly without a forward/backward pass did not raise errors in my tests, though there may well be potential risks the tests did not cover.
https://github.com/PaddlePaddle/PaddleNLP/blob/48820cbc1fe986004f817c0517886735675732d2/paddlenlp/trainer/trainer_callback.py#L432C1-L460C23

continue


if step_control % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
self.timers and self.timers("forward-backward").start()
@@ -1202,7 +1250,13 @@
)

self._total_loss_scalar += tr_loss.item()
train_loss = self._total_loss_scalar / self.state.global_step

# In case all steps were skipped, the total loss is set to 0.
if self.state.global_step == self._skip_global_steps:
logger.info("All steps were skipped, the total loss is set to 0.")
train_loss = 0.0

else:
train_loss = self._total_loss_scalar / (self.state.global_step - self._skip_global_steps)

metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)

@@ -1321,15 +1375,20 @@
if self.control.should_log:

logs: Dict[str, float] = {}

num_steps = self.state.global_step - self._globalstep_last_logged - self._skip_steps_since_last_logged
self._skip_steps_since_last_logged = 0
# all_gather + mean() to get average loss over all processes
avg_loss = self._nested_gather(tr_loss).mean()
tr_loss_scalar = self._get_item_from_loss(avg_loss)

# reset tr_loss to zero
tr_loss.subtract_(tr_loss)
# set loss to zero if all steps were skipped since the last log
if num_steps == 0:
logs["loss"] = 0.0

else:
logs["loss"] = round(tr_loss_scalar / num_steps, 8)

logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 8)
logs["learning_rate"] = float("{0:.3e}".format(self._get_learning_rate()))
logs["global_step"] = int(self.state.global_step)
if in_auto_parallel_align_mode():
@@ -1352,7 +1411,7 @@
total_train_batch_size = (
self.args.train_batch_size * self.args.gradient_accumulation_steps * self.args.dataset_world_size
)
num_steps = self.state.global_step - self._globalstep_last_logged

seq_length = None
model_flops = None
if getattr(self, "is_pretraining", False) and hasattr(self.model, "config"):
@@ -1362,16 +1421,18 @@
except NotImplementedError:
model_flops = None

logs.update(
speed_metrics(
"interval",
self._globalstep_last_start_time,
num_samples=total_train_batch_size * num_steps,
num_steps=num_steps,
seq_length=seq_length,
model_flops=model_flops,
# Do not log speed metrics if all steps were skipped since the last log.
if num_steps > 0:
logs.update(
speed_metrics(
"interval",
self._globalstep_last_start_time,
num_samples=total_train_batch_size * num_steps,
num_steps=num_steps,
seq_length=seq_length,
model_flops=model_flops,
)
)
)

self._total_loss_scalar += tr_loss_scalar
self._globalstep_last_logged = self.state.global_step
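A small worked example of the adjusted averaging (the numbers are hypothetical, chosen only to show the arithmetic): with 10 global steps since the last log and 4 of them skipped, the logged loss is the gathered loss divided by 6; if all 10 were skipped, the loss is reported as 0.0 and the interval speed metrics are omitted.

# Hypothetical numbers illustrating the logging change above.
global_step, globalstep_last_logged, skipped_since_last_logged = 110, 100, 4
tr_loss_scalar = 3.6  # gathered loss accumulated since the last log

num_steps = global_step - globalstep_last_logged - skipped_since_last_logged  # 6
loss = 0.0 if num_steps == 0 else round(tr_loss_scalar / num_steps, 8)
print(loss)  # -> 0.6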
@@ -3255,7 +3316,7 @@
self._signature_columns += list(set(["label", "label_ids"] + self.label_names))

def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
if not self.args.remove_unused_columns:
if not self.args.remove_unused_columns or self.model is None:

return dataset
if self._signature_columns is None:
# Inspect model forward signature to keep only the arguments it accepts.
17 changes: 17 additions & 0 deletions paddlenlp/trainer/trainer_utils.py
@@ -1105,3 +1105,20 @@
tracker.add("global_seed", global_seed)
if "local_seed" not in tracker.states_ and local_seed not in tracker.seeds_:
tracker.add("local_seed", local_seed)


def should_skip_data(global_step, skip_data_intervals):
"""Whether to skip current step data"""

if skip_data_intervals is None:
return False
skip_flag = False
for interval in skip_data_intervals:
if len(interval) != 2 or interval[0] > interval[1] or interval[0] <= 0:
raise ValueError(f"Please check your skip interval {interval}")
start_global_step, end_global_step = interval[0], interval[1]

# start_global_step and end_global_step start from 1, while global_step starts from 0
if start_global_step <= global_step + 1 <= end_global_step:
skip_flag = True
break
return skip_flag
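A quick check of the helper above, illustrating the 1-based, inclusive interval convention against the 0-based global_step (the values are illustrative only):

# Skip the 3rd through 5th optimizer steps, i.e. global_step 2, 3 and 4.
intervals = [[3, 5]]
print([should_skip_data(step, intervals) for step in range(6)])
# -> [False, False, True, True, True, False]
print(should_skip_data(0, None))  # -> False, skipping is disabled by default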

4 changes: 4 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -867,6 +867,10 @@ class TrainingArguments:
release_grads: Optional[bool] = field(
default=False, metadata={"help": "Whether to release gradients during training. Default is `False`."}
)
skip_data_intervals: Optional[List[List[int]]] = field(
default=None,
metadata={"help": "Intervals of data to skip; each interval is a [start_global_step, end_global_step] pair of global steps (1-based, inclusive)."},
)

def __post_init__(self):
env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
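Finally, a usage sketch for the new argument. Assuming TrainingArguments is importable from paddlenlp.trainer as elsewhere in the library, and with placeholder values for everything except skip_data_intervals:

from paddlenlp.trainer import TrainingArguments

# Skip the data consumed by global steps 100-200 and 500-520 (1-based, inclusive).
args = TrainingArguments(
    output_dir="./checkpoints",  # placeholder
    skip_data_intervals=[[100, 200], [500, 520]],
)

# Command-line equivalent (each --skip_data_intervals occurrence adds one interval):
#   --skip_data_intervals 100 200 --skip_data_intervals 500 520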