[Trainer] Support MoE
DesmonDay committed May 30, 2024
1 parent efd29c0 commit d854330
Showing 3 changed files with 155 additions and 34 deletions.
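For context, a minimal sketch of how a training script could turn on the expert-parallel support added here. The output directory, batch sizes, and the commented-out model/dataset are placeholders; only the `use_expert_parallel` flag comes from this commit.

from paddlenlp.trainer import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",        # placeholder path
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    use_expert_parallel=True,          # new flag introduced by this commit (renamed from `use_moe`)
)
# trainer = Trainer(model=model, args=args, train_dataset=train_ds)  # model/dataset assumed to exist
# trainer.train()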
113 changes: 89 additions & 24 deletions paddlenlp/trainer/trainer.py
@@ -48,6 +48,7 @@
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import (
GroupShardedOptimizerStage2,
)
from paddle.utils import map_structure

try:
from paddle.distributed.fleet.utils.hybrid_parallel_util import (
@@ -143,6 +144,7 @@
from .utils import reshard as reshard_util
from .utils.helper import ( # nested_truncate,
broadcast_dp_optimizer,
broadcast_moe_optimizer,
distributed_concat,
distributed_file,
distributed_isfile,
@@ -565,7 +567,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
)
self.model.set_state_dict(state_dict)
else:
if resume_from_checkpoint is not None and (self.args.dataset_rank == 0 or self.args.use_moe):
if resume_from_checkpoint is not None and (self.args.dataset_rank == 0 or self.args.use_expert_parallel):

weights_file = os.path.join(
resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix)
@@ -940,12 +942,17 @@ def _inner_training_loop(
forbidden_no_sync = True

availiable_no_sync = dp_enabled and not forbidden_no_sync
has_no_sync = hasattr(model, "no_sync")

is_no_sync = (
((step_control + 1) % args.gradient_accumulation_steps != 0)
and availiable_no_sync
and args._no_sync_in_gradient_accumulation
) or (args.recompute and availiable_no_sync)
(
((step_control + 1) % args.gradient_accumulation_steps != 0)
and availiable_no_sync
and args._no_sync_in_gradient_accumulation
)
or (args.recompute and availiable_no_sync)
or args.use_expert_parallel
)
# sharding
# stage1. the same as ddp
# stage2. manually collect gradient on dp group
@@ -956,14 +963,25 @@
if dp_master_grad:
is_no_sync = True

if is_no_sync:
if is_no_sync and has_no_sync:

# Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
with model.no_sync():
tr_loss_step = self.training_step(model, inputs)
else:
tr_loss_step = self.training_step(model, inputs)

tr_loss += tr_loss_step
def fused_allreduce_gradients_no_sync(param_list, hcg):
param_list = list(param_list)
nonmoe_list = [p for p in param_list if not getattr(p, "no_sync", False)]
moe_list = [p for p in param_list if getattr(p, "no_sync", False)]
if moe_list and not self.args.use_expert_parallel:
logger.warning("found `no_sync` param when `use_expert_parallel=False`")
fused_allreduce_gradients(nonmoe_list, hcg)

if tr_loss_step is not None:
if tr_loss is None:
tr_loss = map_structure(lambda x: paddle.zeros_like(x), tr_loss_step)
map_structure(lambda x, y: x.add_(y), tr_loss, tr_loss_step)

if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
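To illustrate the `no_sync` convention that the nested `fused_allreduce_gradients_no_sync` helper above relies on, here is a small single-process sketch (illustrative layers, not part of this patch): expert parameters are tagged with a `no_sync` attribute so that the data-parallel all-reduce only touches the shared, replicated parameters.

import paddle
from paddle import nn

shared = nn.Linear(8, 8)    # dense weights, replicated across data-parallel ranks
expert = nn.Linear(8, 8)    # expert weights, different on every data-parallel rank
for p in expert.parameters():
    p.no_sync = True        # the same attribute the trainer checks via getattr(p, "no_sync", False)

params = list(shared.parameters()) + list(expert.parameters())
dense_params = [p for p in params if not getattr(p, "no_sync", False)]
moe_params = [p for p in params if getattr(p, "no_sync", False)]
print(len(dense_params), len(moe_params))  # 2 2 -> only dense_params would be handed to fused_allreduce_gradients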
@@ -983,12 +1001,13 @@ def _inner_training_loop(

# Case 1: Use recompute and dp / sharding stage1,
# manually collect gradient for dp.
if args.recompute and availiable_no_sync:
fused_allreduce_gradients(list(model.parameters()), None)
# Case 1.1: pure dp + moe should manually collect gradient here.
if (args.recompute or args.use_expert_parallel) and availiable_no_sync:
fused_allreduce_gradients_no_sync(list(model.parameters()), None)

# Case 2: hack dp with master_grad
if dp_master_grad and not (args.recompute and availiable_no_sync):
fused_allreduce_gradients(list(model.parameters()), None)
fused_allreduce_gradients_no_sync(list(model.parameters()), None)

# Pipeline parallel mode, handle gradient reduce here to overlap
pipeline_parallel_config = (
@@ -1007,7 +1026,9 @@ def _inner_training_loop(
self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg)

if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False):
fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg)
fused_allreduce_gradients_no_sync(list(parameters_list), self.optimizer._hcg)

else:
assert not self.args.use_expert_parallel, "moe should not use `enable_dp_comm_overlap`"

self.timers and self.timers("all-reduce").stop()
self.timers and self.timers("optimizer-step").start()
@@ -1132,7 +1153,7 @@ def _inner_training_loop(
"on multiple nodes, you should activate `--save_on_each_node`."
)

self._total_loss_scalar += tr_loss.item()
self._total_loss_scalar += tr_loss.pop("loss").item() if isinstance(tr_loss, dict) else tr_loss.item()

train_loss = self._total_loss_scalar / self.state.global_step

metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
@@ -1250,12 +1271,22 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
logs: Dict[str, float] = {}

# all_gather + mean() to get average loss over all processes
tr_loss_scalar = self._get_item_from_loss(self._nested_gather(tr_loss).mean())
tr_loss_scalar = map_structure(lambda x: self._get_item_from_loss(self._nested_gather(x).mean()), tr_loss)

# reset tr_loss to zero
tr_loss.subtract_(tr_loss)

logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 8)
map_structure(lambda x: x.zero_(), tr_loss)

if isinstance(tr_loss_scalar, dict):
for k, v in tr_loss_scalar.items():
logs[k] = round(v / (self.state.global_step - self._globalstep_last_logged), 8)
elif isinstance(tr_loss_scalar, (list, tuple)):
for i, v in enumerate(tr_loss_scalar):
logs[f"loss_{i}"] = round(v / (self.state.global_step - self._globalstep_last_logged), 8)

else:
logs["loss"] = round(

tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged),
8,
)
logs["learning_rate"] = float("{0:.3e}".format(self._get_learning_rate()))
logs["global_step"] = int(self.state.global_step)
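As a concrete single-process illustration (not part of this patch) of the `map_structure` bookkeeping used above when the loss is a dict:

import paddle
from paddle.utils import map_structure

tr_loss_step = {"loss": paddle.to_tensor(0.5), "aux_loss": paddle.to_tensor(0.1)}
tr_loss = map_structure(lambda x: paddle.zeros_like(x), tr_loss_step)  # create matching accumulators
map_structure(lambda x, y: x.add_(y), tr_loss, tr_loss_step)           # accumulate in place
scalars = map_structure(lambda x: float(x), tr_loss)                   # {'loss': 0.5, 'aux_loss': 0.1}
map_structure(lambda x: x.zero_(), tr_loss)                            # reset after logging, as above
print(scalars)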

@@ -1290,7 +1321,9 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
)
)

self._total_loss_scalar += tr_loss_scalar
self._total_loss_scalar += (

tr_loss_scalar.pop("loss") if isinstance(tr_loss_scalar, dict) else tr_loss_scalar
)
self._globalstep_last_logged = self.state.global_step
self._globalstep_last_start_time = time.time()

@@ -2047,14 +2080,19 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
loss = self.compute_loss(model, inputs)

if self.args.gradient_accumulation_steps > 1 and not self._enable_delay_scale_loss():
loss = loss / self.args.gradient_accumulation_steps
loss = map_structure(lambda x: x / self.args.gradient_accumulation_steps, loss)

if isinstance(loss, dict):
total_loss = loss["loss"]

else:
total_loss = loss

if self.do_grad_scaling:
self.scaler.scale(loss).backward()
self.scaler.scale(total_loss).backward()

else:
loss.backward()
total_loss.backward()

return loss.detach()
return map_structure(lambda v: v.detach(), loss)
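For reference, a sketch (toy model, not part of this patch) of a `compute_loss` result that the dict-aware `training_step` above can consume; only the "loss" entry is backpropagated, while the other entries are detached and kept purely for logging.

import paddle
from paddle import nn

layer = nn.Linear(4, 1)
x = paddle.randn([2, 4])
main_loss = layer(x).mean()
aux_loss = paddle.to_tensor(0.05)        # stand-in for an MoE load-balancing term already folded into "loss"
loss = {"loss": main_loss + aux_loss, "aux_loss": aux_loss.detach()}

total_loss = loss["loss"] if isinstance(loss, dict) else loss
total_loss.backward()                    # gradients flow through the "loss" entry only
per_step = {k: v.detach() for k, v in loss.items()}  # the structure training_step returns each step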

def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
"""
@@ -2113,6 +2151,18 @@ def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle

return loss.detach()

def _save_moe_weights(
self,
output_dir,
merge_tensor_parallel: Optional[bool] = False,
):
self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
if not self.args.ignore_save_lr_and_optim:
self._save_ckpt_func(

self.optimizer.state_dict(),
os.path.join(output_dir, _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)),
)

def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Optional[bool] = False):
"""
Will save the model, so you can reload it using `from_pretrained()`.
@@ -2126,7 +2176,12 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
if ShardingOption.FULL_SHARD in self.args.sharding:
self.model_wrapped.get_all_parameters(convert2cpu=True)

if self.args.should_save_model_state:
if not self.is_in_train and self.args.use_expert_parallel:
should_save_model_state = self.args.should_save_moe_model_state

else:
should_save_model_state = self.args.should_save_model_state

if should_save_model_state:

unified_checkpoint_config_backup = self.args.unified_checkpoint_config
# backup and remove unified_checkpoint_config when not in the training stage
if not self.is_in_train:
@@ -2245,6 +2300,10 @@ def _save_checkpoint(self, model, metrics=None):
os.makedirs(output_dir, exist_ok=True)
paddle.save(rng_states, os.path.join(output_dir, "rng_state.pth"))

if self.args.use_expert_parallel and self.args.data_parallel_rank > 0:
logger.info("Saving moe weights for data_parallel_rank > 0")
self._save_moe_weights(output_dir)

# Maybe delete some older checkpoints.
# For hybrid parallel training, the checkpoint files maybe on different node.
need_to_rotate_checkpoints = False
@@ -2452,7 +2511,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
logger.info("Loading checkpoint, the next checkpoint will be saved as unified checkpoint")

if not use_unified_checkpoint:
if self.args.data_parallel_rank == 0:
if self.args.use_expert_parallel or self.args.data_parallel_rank == 0:

optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
path = os.path.join(checkpoint, optimizer_name)
if os.path.isfile(path):
@@ -2476,7 +2535,11 @@ def _load_optimizer_and_scheduler(self, checkpoint):
# broadcast optimizer state in dp group
if self.args.local_rank != -1:
dist.barrier()
opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
if not self.args.use_expert_parallel:
opt_state_dict = broadcast_dp_optimizer(opt_state_dict)

else:
state_dict = self.model.state_dict()
opt_state_dict = broadcast_moe_optimizer(state_dict, opt_state_dict)

if opt_state_dict is not None:
# Load in optimizer and scheduler states
@@ -2939,6 +3002,8 @@ def prediction_step(
if has_labels:
with self.autocast_smart_context_manager():
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
if isinstance(loss, dict):
loss = loss.pop("loss")

loss = loss.mean().detach()

if isinstance(outputs, dict):
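One way to read the `_save_moe_weights` and `use_expert_parallel and data_parallel_rank > 0` additions above: with expert parallelism every data-parallel rank owns a distinct set of expert weights, so ranks other than dp rank 0 must also write their weights and optimizer state at checkpoint time. A tiny sketch with an illustrative helper (not from this patch):

def ranks_that_write_weights(dp_degree, use_expert_parallel):
    # dp rank 0 always saves; with expert parallelism every dp rank saves its own experts
    return list(range(dp_degree)) if use_expert_parallel else [0]

print(ranks_that_write_weights(4, use_expert_parallel=False))  # [0]
print(ranks_that_write_weights(4, use_expert_parallel=True))   # [0, 1, 2, 3]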
43 changes: 33 additions & 10 deletions paddlenlp/trainer/training_args.py
@@ -804,7 +804,7 @@ class TrainingArguments:
default=False,
metadata={"help": "whether to run distributed training in auto parallel mode"},
)
use_moe: Optional[bool] = field(
use_expert_parallel: Optional[bool] = field(
default=False,
metadata={"help": "Use MoE training."},
)
@@ -1154,7 +1154,7 @@ def is_segment_parallel_supported():
order = ["dp", "sharding", "pp", "sep", "mp"]
else:
order = ["dp", "sharding", "pp", "mp"]
if self.use_moe:
if self.use_expert_parallel:
order = order[1:-1] + ["dp", "mp"]

if is_segment_parallel_supported():
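The reordering above is plain list slicing; a quick check of what it produces for the two `order` values defined above shows that the data-parallel axis is moved to the inner position, right before "mp":

order = ["dp", "sharding", "pp", "sep", "mp"]
print(order[1:-1] + ["dp", "mp"])  # ['sharding', 'pp', 'sep', 'dp', 'mp']

order = ["dp", "sharding", "pp", "mp"]
print(order[1:-1] + ["dp", "mp"])  # ['sharding', 'pp', 'dp', 'mp']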
@@ -1649,12 +1649,12 @@ def optimizer_name_suffix(self):
name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
if self.sharding_parallel_degree > 1:
name.append(self._format_name("shard", self.sharding_parallel_rank, self.sharding_parallel_degree))
if self.use_moe:
name.append(f"moe{self.data_parallel_rank:0>2d}")
if self.use_expert_parallel:
name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))

return "_".join(name)
else:
if self.use_moe:
return f"moe{self.data_parallel_rank:0>2d}"
if self.use_expert_parallel:
return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)

return None

@property
@@ -1665,13 +1665,13 @@ def weight_name_suffix(self):
name.append(self._format_name("tp", self.tensor_parallel_rank, self.tensor_parallel_degree))
if self.pipeline_parallel_degree > 1:
name.append(self._format_name("pp", self.pipeline_parallel_rank, self.pipeline_parallel_degree))
if self.use_moe:
name.append(f"moe{self.data_parallel_rank:0>2d}")
if self.use_expert_parallel:
name.append(self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree))

return "_".join(name)

else:
if self.use_moe:
return f"moe{self.data_parallel_rank:0>2d}"
if self.use_expert_parallel:
return self._format_name("moe", self.data_parallel_rank, self.data_parallel_degree)

return None

def sharded_name_suffix(self, shard_id=None, pp_id=None):
Expand Down Expand Up @@ -1787,6 +1787,29 @@ def should_save_model_state(self):
else:
return self.process_index == 0

@property
def should_save_moe_model_state(self):
"""
Whether or not the current process should write to disk, e.g., to save moe models and checkpoints.
For model state:
works for data parallel, tensor parallel, and sharding.
For optimizer state:
works for data parallel and tensor parallel, but not for sharding.
"""
if self.save_on_each_node:
return self.local_process_index == 0

else:
if self.should_save_sharding_stage1_model:
return True
elif self.enable_auto_parallel:
return True
elif self.use_hybrid_parallel:
return self.sharding_parallel_rank == 0

else:
return self.process_index == 0

@property
def _no_sync_in_gradient_accumulation(self):
"""
33 changes: 33 additions & 0 deletions paddlenlp/trainer/utils/helper.py
@@ -226,3 +226,36 @@ def broadcast_dp_optimizer(state_dict):
state_dict = nested_broadcast_tensor(state_dict, src=src_rank, group=dp_group)

return state_dict


def broadcast_moe_optimizer(state_dict, opt_state_dict):
no_sync_vname = []
for k, v in state_dict.items():
if getattr(v, "no_sync", False):
no_sync_vname.append(v.name)
new_opt_state_dict = broadcast_dp_optimizer(opt_state_dict)

# 1. when updating opt_state_dict, skip broadcasting the states of parameters whose name matches a `no_sync` parameter.
# 2. if the keys of opt_state_dict and new_opt_state_dict are exactly the same, there is no need to update.
# 3. if they are different, the update should be based on the `no_sync_vname`.
if len(opt_state_dict.keys()) != len(new_opt_state_dict.keys()):
for op_k, op_v in new_opt_state_dict.items():
if op_k == "master_weights":
for k, v in new_opt_state_dict["master_weights"].items():
no_sync = False
for no_sync_v in no_sync_vname:
if k.startswith(no_sync_v):
no_sync = True
break
if not no_sync:
opt_state_dict["master_weights"][k] = v
elif op_k == "LR_Scheduler":
pass

else:
no_sync = False
for no_sync_v in no_sync_vname:
if op_k.startswith(no_sync_v):
no_sync = True
break
if not no_sync:
opt_state_dict[op_k] = op_v
return opt_state_dict
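The filtering rule in `broadcast_moe_optimizer` can be seen with plain dicts (no distributed setup, made-up names): optimizer states whose parameter name matches a `no_sync` variable keep the local copy, everything else is overwritten by the state broadcast from dp rank 0. The real function applies the same prefix rule inside "master_weights" and leaves "LR_Scheduler" untouched.

no_sync_vname = ["expert_0.w"]  # names of expert parameters that stay local to this rank
local = {"expert_0.w_moment1_0": "local", "shared.w_moment1_0": "local"}
from_rank0 = {"expert_0.w_moment1_0": "rank0", "shared.w_moment1_0": "rank0", "extra.w_moment1_0": "rank0"}

for key, value in from_rank0.items():
    if not any(key.startswith(name) for name in no_sync_vname):
        local[key] = value

print(local)
# {'expert_0.w_moment1_0': 'local', 'shared.w_moment1_0': 'rank0', 'extra.w_moment1_0': 'rank0'}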
