[fea] Cherry-picked MOE updates from develop #8531

Merged
merged 2 commits into from Jun 3, 2024
4 changes: 4 additions & 0 deletions docs/trainer.md
@@ -705,4 +705,8 @@ Trainer is a simple but feature-complete Paddle training and evaluation module, and
Whether to use the flatten_param_grads method in the optimizer,
only used on NPU devices. (default: False)

--use_expert_parallel
Whether to enable MoE (Mixture of Experts) expert parallel training.
(default: False)

```
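
The new `--use_expert_parallel` flag is an ordinary `TrainingArguments` boolean. Below is a minimal usage sketch, not part of this diff; it assumes the standard PaddleNLP argument-parsing flow, and the argument values are illustrative only.

```python
# Hypothetical usage sketch, not part of this PR's diff.
# Assumes the usual PdArgumentParser / TrainingArguments flow; paths are illustrative.
from paddlenlp.trainer import PdArgumentParser, TrainingArguments

parser = PdArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "./checkpoints", "--use_expert_parallel", "true"]
)
print(training_args.use_expert_parallel)  # True
```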
104 changes: 72 additions & 32 deletions paddlenlp/trainer/trainer.py
@@ -143,6 +143,7 @@
from .utils import reshard as reshard_util
from .utils.helper import ( # nested_truncate,
broadcast_dp_optimizer,
broadcast_moe_optimizer,
distributed_concat,
distributed_file,
distributed_isfile,
@@ -565,7 +566,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
)
self.model.set_state_dict(state_dict)
else:
if resume_from_checkpoint is not None and self.args.dataset_rank == 0:
if resume_from_checkpoint is not None and (self.args.dataset_rank == 0 or self.args.use_expert_parallel):

weights_file = os.path.join(
resume_from_checkpoint, _add_variant(weight_name, self.args.weight_name_suffix)
@@ -581,7 +582,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
weights_index_file,
]
):
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint} -- {weights_file}")

logger.info(f"Loading model from {resume_from_checkpoint} .")

@@ -930,22 +931,17 @@ def _inner_training_loop(
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
self.timers and self.timers("forward-backward").start()

dp_enabled = (
self.args.data_parallel_degree > 1 if self.args.use_hybrid_parallel else args.local_rank != -1
)
forbidden_no_sync = False
# stage2 and stage3 should not use no_sync, because there is no DDP wrapper and no_sync API
# hybrid_parallel (tp or pp or sharding stage 1) should not no_sync
if self.args.use_hybrid_parallel:
forbidden_no_sync = True

availiable_no_sync = dp_enabled and not forbidden_no_sync

availiable_no_sync = hasattr(model, "no_sync")
is_no_sync = (
((step_control + 1) % args.gradient_accumulation_steps != 0)
and availiable_no_sync
and args._no_sync_in_gradient_accumulation
) or (args.recompute and availiable_no_sync)
(
((step_control + 1) % args.gradient_accumulation_steps != 0)
and args._no_sync_in_gradient_accumulation
)
or args.recompute
or args.use_expert_parallel
) and availiable_no_sync
# sharding
# stage1. the same as ddp
# stage2. manually collect gradient on dp group
@@ -965,6 +961,14 @@

tr_loss += tr_loss_step

def fused_allreduce_gradients_no_sync(paramlist, hcg):
paramlist = list(paramlist)
nonmoe_list = [p for p in paramlist if not getattr(p, "no_sync", False)]
moelist = [p for p in paramlist if getattr(p, "no_sync", False)]
if moelist and not self.args.use_expert_parallel:
logger.warning("found `no sync` param when `use_expert_parallel=False`")
fused_allreduce_gradients(nonmoe_list, hcg)

if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
steps_in_epoch <= args.gradient_accumulation_steps
@@ -983,12 +987,12 @@

# Case 1: Use recompute and dp / sharding stage1,
# manually collect gradient for dp.
if args.recompute and availiable_no_sync:
fused_allreduce_gradients(list(model.parameters()), None)
if (args.recompute or args.use_expert_parallel) and availiable_no_sync:
fused_allreduce_gradients_no_sync(list(model.parameters()), None)

# Case 2: hack dp with master_grad
if dp_master_grad and not (args.recompute and availiable_no_sync):
fused_allreduce_gradients(list(model.parameters()), None)
elif dp_master_grad:
fused_allreduce_gradients_no_sync(list(model.parameters()), None)

# Pipeline parallel mode, handle gradient reduce here to overlap
pipeline_parallel_config = (
@@ -1007,8 +1011,7 @@
self.optimizer._inner_opt.reduce_gradients(list(parameters_list), self.optimizer._hcg)

if self.optimizer._dp_enable or getattr(self.optimizer, "_sep_enable", False):
fused_allreduce_gradients(list(parameters_list), self.optimizer._hcg)

fused_allreduce_gradients_no_sync(list(parameters_list), self.optimizer._hcg)
self.timers and self.timers("all-reduce").stop()
self.timers and self.timers("optimizer-step").start()

@@ -1028,6 +1031,8 @@
)
optimizer_was_run = True
if self.do_grad_scaling:
if args.pipeline_parallel_degree > 1:
assert not self.args.use_expert_parallel, "pipeline moe not work under fp16"
scale_before = paddle.assign(self.scaler._scale)
self.scaler.step(self.optimizer)
self.scaler.update()
@@ -2042,7 +2047,6 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,

model.train()
inputs = self._prepare_inputs(inputs)

with self.autocast_smart_context_manager():
loss = self.compute_loss(model, inputs)

@@ -2053,7 +2057,6 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
self.scaler.scale(loss).backward()
else:
loss.backward()

return loss.detach()

def training_pipeline_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
@@ -2143,6 +2146,26 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
# For ckpt integrity
paddle.save(self.state.global_step, os.path.join(output_dir, ".model_done"))

def _filter_moe_no_sync_optimizer_params(self):
"""
Filter the optimizer state down to parameters marked `no_sync` (MoE expert parameters that are not synchronized across the data-parallel group).
"""
state_dict = self.model.state_dict()
optimzier_state_dict = self.optimizer.state_dict()
filter_optimzier_state_dict = OrderedDict()
param_names_in_master_weights = list(optimzier_state_dict["master_weights"].keys()) if self.args.bf16 else []
filter_optimzier_state_dict["master_weights"] = OrderedDict()
for k, v in state_dict.items():
if getattr(v, "no_sync", False):
if v.name in param_names_in_master_weights:
filter_optimzier_state_dict["master_weights"][v.name] = optimzier_state_dict["master_weights"][
v.name
]
for op_k, op_v in optimzier_state_dict.items():
if op_k.startswith(v.name):
filter_optimzier_state_dict[op_k] = op_v
return filter_optimzier_state_dict

def _save_checkpoint(self, model, metrics=None):
# assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
self.runtime_timer.start("checkpoint saving time")
@@ -2165,7 +2188,7 @@ def _save_checkpoint(self, model, metrics=None):
optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)

if self.args.use_hybrid_parallel:
if self.dp_group.rank <= 0:
if self.dp_group.rank <= 0 or self.args.use_expert_parallel:
os.makedirs(output_dir, exist_ok=True)
logger.info("Saving optimizer files.")
if self.args.unified_checkpoint:
@@ -2177,12 +2200,18 @@
safe_serialization=True,
)
else:
self._save_ckpt_func(
self.optimizer.state_dict(),
os.path.join(output_dir, optimizer_name),
)
if self.dp_group.rank > 0: # this should only work for MoE saving
self._save_ckpt_func(
self._filter_moe_no_sync_optimizer_params(),
os.path.join(output_dir, optimizer_name),
)
else:
self._save_ckpt_func(
self.optimizer.state_dict(),
os.path.join(output_dir, optimizer_name),
)

if self.args.should_save:
if self.args.should_save or self.args.use_expert_parallel:
if not self.args.use_hybrid_parallel:
logger.info("Saving optimizer files.")
if self.args.unified_checkpoint:
@@ -2194,7 +2223,12 @@
safe_serialization=True,
)
else:
self._save_ckpt_func(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
self._save_ckpt_func(
self._filter_moe_no_sync_optimizer_params(), os.path.join(output_dir, OPTIMIZER_NAME)
)
else:
self._save_ckpt_func(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

# FIXME: maybe only save one copy
paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
@@ -2452,7 +2486,7 @@ def _load_optimizer_and_scheduler(self, checkpoint):
logger.info("Loading checkpoint, the next checkpoint will be saved as unified checkpoint")

if not use_unified_checkpoint:
if self.args.data_parallel_rank == 0:
if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel:
optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
path = os.path.join(checkpoint, optimizer_name)
if os.path.isfile(path):
@@ -2476,7 +2510,13 @@ def _load_optimizer_and_scheduler(self, checkpoint):
# broadcast optimizer state in dp group
if self.args.local_rank != -1:
dist.barrier()
opt_state_dict = broadcast_dp_optimizer(opt_state_dict)
if self.args.use_expert_parallel:
opt_state_dict = broadcast_moe_optimizer(
opt_state_dict, broadcast_dp=not self.args.should_load_sharding_stage1_model
)
else:
if not self.args.should_load_sharding_stage1_model:
opt_state_dict = broadcast_dp_optimizer(opt_state_dict)

if opt_state_dict is not None:
# Load in optimizer and scheduler states
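The trainer changes above hinge on expert parameters carrying a boolean `no_sync` attribute (checked via `getattr(p, "no_sync", False)`): `fused_allreduce_gradients_no_sync` excludes such parameters from the data-parallel all-reduce, and `_filter_moe_no_sync_optimizer_params` keeps only their optimizer state when a non-zero dp rank saves a checkpoint. The sketch below shows how a model might tag its expert parameters; the helper and its name-matching rule are assumptions for illustration, not part of this PR.

```python
# Hypothetical helper, not part of this PR. It illustrates the convention the
# new Trainer code relies on: expert (MoE) parameters carry a `no_sync`
# attribute, and only parameters without it are all-reduced across dp ranks.
import paddle.nn as nn


def mark_expert_params_no_sync(model: nn.Layer, expert_keyword: str = "expert") -> nn.Layer:
    # The substring match on the parameter name is an illustrative assumption;
    # a real MoE layer may expose its expert parameters differently.
    for name, param in model.named_parameters():
        if expert_keyword in name:
            param.no_sync = True  # read by getattr(p, "no_sync", False) in the Trainer
    return model
```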
30 changes: 26 additions & 4 deletions paddlenlp/trainer/training_args.py
@@ -791,6 +791,10 @@ class TrainingArguments:
default=False,
metadata={"help": "whether to output logits in distributed status"},
)
use_expert_parallel: Optional[bool] = field(
default=False,
metadata={"help": "Enable MoE (Mixture of Experts) expert parallel training"},
)

def __post_init__(self):
env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
@@ -1117,6 +1121,8 @@ def is_segment_parallel_supported():
order = ["dp", "sharding", "pp", "sep", "mp"]
else:
order = ["dp", "sharding", "pp", "mp"]
if self.use_expert_parallel:
order = order[1:-1] + ["dp", "mp"]

if is_segment_parallel_supported():
hybrid_configs = {
@@ -1598,9 +1604,12 @@ def optimizer_name_suffix(self):
if self.sharding_parallel_degree > 1:
assert self.sharding_parallel_degree < 100, "sharding parallel degree should be less than 100."
name.append(f"shard{self.sharding_parallel_rank:0>2d}")

if self.use_expert_parallel:
name.append(f"moe{self.data_parallel_rank:0>2d}")
return "_".join(name)
else:
if self.use_expert_parallel:
return f"moe{self.data_parallel_rank:0>2d}"
return None

@property
@@ -1613,12 +1622,16 @@ def weight_name_suffix(self):
if self.pipeline_parallel_degree > 1:
assert self.pipeline_parallel_degree < 100, "pipeline parallel degree should be less than 100."
name.append(f"pp{self.pipeline_parallel_rank:0>2d}")
if self.use_expert_parallel:
name.append(f"moe{self.data_parallel_rank:0>2d}")
return "_".join(name)

else:
if self.use_expert_parallel:
return f"moe{self.data_parallel_rank:0>2d}"
return None

def sharded_name_suffix(self, shard_id=None, pp_id=None):
def sharded_name_suffix(self, shard_id=None, pp_id=None, moe_id=None):
if self.use_hybrid_parallel:
name = []
if self.tensor_parallel_degree > 1:
@@ -1636,8 +1649,17 @@ def sharded_name_suffix(self, shard_id=None, pp_id=None):
assert isinstance(shard_id, int)
assert shard_id < 100, "shard_id should be less than 100."
name.append(f"shard{shard_id:0>2d}")
if self.use_expert_parallel:
if moe_id is None:
moe_id = self.data_parallel_rank
assert isinstance(moe_id, int)
name.append(f"moe{moe_id:0>2d}")
return "_".join(name)
else:
if self.use_expert_parallel:
if moe_id is None:
moe_id = self.data_parallel_rank
return self._format_name("moe", moe_id, self.data_parallel_degree)
return None

@property
@@ -1730,9 +1752,9 @@ def should_save_model_state(self):
return True
elif self.use_hybrid_parallel:
# save on dataset rank 0
return self.sharding_parallel_rank == 0 and self.data_parallel_rank == 0
return self.sharding_parallel_rank == 0 and (self.data_parallel_rank == 0 or self.use_expert_parallel)
else:
return self.process_index == 0
return self.process_index == 0 or self.use_expert_parallel

@property
def _no_sync_in_gradient_accumulation(self):
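With `use_expert_parallel` enabled, the suffix properties above append a `moe{data_parallel_rank:0>2d}` component, so each expert-parallel (data-parallel) rank writes its own optimizer and weight files instead of one copy per dp group. A standalone sketch of the naming convention follows; it is not the actual `TrainingArguments` code, the ranks are illustrative, and which components appear depends on which parallelism degrees are greater than 1.

```python
# Standalone sketch of the checkpoint-suffix convention; illustrative only.
def checkpoint_suffix(tp_rank=None, pp_rank=None, shard_rank=None,
                      dp_rank=0, use_expert_parallel=False):
    name = []
    if tp_rank is not None:
        name.append(f"tp{tp_rank:0>2d}")
    if pp_rank is not None:
        name.append(f"pp{pp_rank:0>2d}")
    if shard_rank is not None:
        name.append(f"shard{shard_rank:0>2d}")
    if use_expert_parallel:
        name.append(f"moe{dp_rank:0>2d}")
    return "_".join(name) if name else None


print(checkpoint_suffix(tp_rank=1, shard_rank=0, dp_rank=3, use_expert_parallel=True))
# tp01_shard00_moe03  (inserted into the optimizer/weight file names by _add_variant)
```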
59 changes: 59 additions & 0 deletions paddlenlp/trainer/utils/helper.py
@@ -226,3 +226,62 @@ def broadcast_dp_optimizer(state_dict):
state_dict = nested_broadcast_tensor(state_dict, src=src_rank, group=dp_group)

return state_dict


def broadcast_moe_optimizer(state_dict, broadcast_dp=True):

try:
hcg = fleet.get_hybrid_communicate_group()
dp_group = hcg.get_data_parallel_group()
src_rank = hcg.get_data_parallel_group_src_rank()
data_parallel_rank = hcg.get_data_parallel_rank()
# Don't broadcast the optimizer state when the data-parallel degree is 1.
if dp_group.nranks <= 1:
return state_dict
except:
dp_group = None
src_rank = 0
data_parallel_rank = 0

def _broadcast_moe_optimizer_state(state_dict):
# broadcast keys
base_state_dict = {"master_weights": {}}
buf = [
{i: j.shape for i, j in state_dict.items() if i not in ["master_weights", "LR_Scheduler"]},
{i: j.shape for i, j in state_dict["master_weights"].items()},
{"LR_Scheduler": state_dict.get("LR_Scheduler", {})},
]

dist.broadcast_object_list(buf, src=src_rank, group=dp_group)
# logger.info(f"moe-optimizer-gather-keys{buf}")
for k, s in buf[0].items():
v = state_dict.get(k, paddle.zeros(s, "float32")).cuda()
v.name = k
# k = k.replace("_fp32_master_0", "")
dist.broadcast(v, src=src_rank, group=dp_group)
logger.info(f"broadcast moe optimizer {k} from {src_rank}")
base_state_dict[k] = v.cpu()
for k, s in buf[1].items():
v = state_dict["master_weights"].get(k, paddle.zeros(s, "float32")).cuda()
v.name = k
dist.broadcast(v, src=src_rank, group=dp_group)
logger.info(f"broadcast moe optimizer-master_weights {k} from {src_rank}")
base_state_dict["master_weights"][k] = v.cpu()
base_state_dict.update(buf[2])
return base_state_dict

if broadcast_dp:
base_state_dict = broadcast_dp_optimizer(state_dict)
else:
base_state_dict = _broadcast_moe_optimizer_state(state_dict)
if data_parallel_rank > 0:
master_weight = state_dict.pop("master_weights", {})
base_state_dict.update(state_dict)
if master_weight:
if "master_weights" in base_state_dict:
base_state_dict["master_weights"].update(master_weight)
else:
base_state_dict["master_weights"] = master_weight
state_dict = base_state_dict
del base_state_dict
return state_dict
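
`broadcast_moe_optimizer` works in two phases: it first broadcasts the key-to-shape maps (plus the LR scheduler state) with `broadcast_object_list`, then broadcasts each tensor with `broadcast`, letting non-source ranks allocate zero placeholders for entries they do not hold; ranks with dp rank > 0 finally overlay their local MoE-only state on top. Below is a minimal sketch of just the two-phase pattern, stripped of the master-weights and LR-scheduler bookkeeping; it assumes an already initialized data-parallel group and is not the actual helper.

```python
# Minimal sketch of the two-phase broadcast pattern, not the actual helper.
# Assumes paddle.distributed is initialized and `group` is a data-parallel group.
import paddle
import paddle.distributed as dist


def broadcast_missing_tensors(local_tensors, src_rank, group):
    # Phase 1: every rank learns the source rank's name -> shape map.
    buf = [{name: t.shape for name, t in local_tensors.items()}]
    dist.broadcast_object_list(buf, src=src_rank, group=group)

    # Phase 2: broadcast each tensor; ranks missing an entry receive into zeros.
    out = {}
    for name, shape in buf[0].items():
        t = local_tensors.get(name, paddle.zeros(shape, "float32"))
        dist.broadcast(t, src=src_rank, group=group)
        out[name] = t
    return out
```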