Commit 5e2edde: add

zhangbo9674 committed May 15, 2024
1 parent 300a3e2 commit 5e2edde
Showing 1 changed file with 9 additions and 0 deletions.
llm/llama/auto_parallel/run_pretrain_auto.py: 9 additions, 0 deletions
@@ -86,6 +86,10 @@ class PreTrainingArguments(TrainingArguments):
"help": "Enable fused_linear_param_grad pass, which should replace add_n_op with add_op for gradients accumulation."
},
)
fuse_allreduce_split_to_reducescatter: bool = field(
default=False,
metadata={"help": "Enable fuse_allreduce_split_to_reducescatter pass."},
)
eliminate_transpose: bool = field(
default=False,
metadata={
@@ -138,6 +142,11 @@ def __post_init__(self):
             fused_passes.enable = True
             fused_passes.fused_passes_list.append("fused_linear_param_grad_add_pass")
 
+        if self.fuse_allreduce_split_to_reducescatter:
+            fused_passes = self.strategy.fused_passes
+            fused_passes.enable = True
+            fused_passes.fused_passes_list.append("fuse_allreduce_split_to_reducescatter_pass")
+
         if self.eliminate_transpose:
             fused_passes = self.strategy.fused_passes
             fused_passes.enable = True
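
For context, the new flag follows the same pattern as the neighboring fused-pass toggles: a boolean dataclass field on `PreTrainingArguments`, plus a `__post_init__` branch that flips `strategy.fused_passes.enable` and appends the pass name to `fused_passes_list`. Below is a minimal, self-contained sketch of that wiring; the `FusedPasses` and `Strategy` stand-ins here are simplified assumptions for illustration, not PaddleNLP's actual distributed-strategy classes:

```python
# Minimal sketch of the toggle-to-pass wiring shown in the diff.
# FusedPasses and Strategy are hypothetical stand-ins, not the real
# Paddle distributed-strategy classes.
from dataclasses import dataclass, field


class FusedPasses:
    def __init__(self):
        self.enable = False
        self.fused_passes_list = []


class Strategy:
    def __init__(self):
        self.fused_passes = FusedPasses()


@dataclass
class ArgsSketch:
    fuse_allreduce_split_to_reducescatter: bool = field(
        default=False,
        metadata={"help": "Enable fuse_allreduce_split_to_reducescatter pass."},
    )

    def __post_init__(self):
        # In the real PreTrainingArguments, self.strategy comes from the
        # auto-parallel training setup; here we build a stand-in.
        self.strategy = Strategy()
        if self.fuse_allreduce_split_to_reducescatter:
            fused_passes = self.strategy.fused_passes
            fused_passes.enable = True
            fused_passes.fused_passes_list.append(
                "fuse_allreduce_split_to_reducescatter_pass"
            )


args = ArgsSketch(fuse_allreduce_split_to_reducescatter=True)
assert args.strategy.fused_passes.enable
assert (
    "fuse_allreduce_split_to_reducescatter_pass"
    in args.strategy.fused_passes.fused_passes_list
)
```

Since `PreTrainingArguments` fields are normally populated from the launch configuration (CLI flags or a JSON config parsed by the script's argument parser), the pass can presumably be switched on by setting `fuse_allreduce_split_to_reducescatter` to `true` in the pretrain config; the exact flag spelling depends on how run_pretrain_auto.py is invoked.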
