From dedd69bfb54b8bba97058eeef4350a5adc6cc31a Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Mon, 28 Aug 2023 08:53:01 +0000 Subject: [PATCH 01/11] B-F overlap --- .../auto_parallel/static/engine.py | 13 + .../paddle/distributed/passes/pass_utils.py | 42 ++-- .../distributed/passes/pipeline_pass_base.py | 35 ++- .../passes/pipeline_scheduler_pass.py | 228 ++++++++++++++++-- ...t_standalone_executor_multi_micro_batch.py | 6 +- 5 files changed, 254 insertions(+), 70 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/static/engine.py b/python/paddle/distributed/auto_parallel/static/engine.py index 6bfb094f28346..255b56fcf79f9 100644 --- a/python/paddle/distributed/auto_parallel/static/engine.py +++ b/python/paddle/distributed/auto_parallel/static/engine.py @@ -992,6 +992,19 @@ def fit( use_program_cache=self._strategy.use_cache, return_numpy=self._strategy.return_numpy, ) + print( + f"memory_allocated = {paddle.device.cuda.memory_allocated()/1024/1024}" + ) + print( + f"max_memory_allocated = {paddle.device.cuda.max_memory_allocated()/1024/1024}" + ) + print( + f"memory_reserved = {paddle.device.cuda.memory_reserved()/1024/1024}" + ) + print( + f"max_memory_reserved = {paddle.device.cuda.max_memory_reserved()/1024/1024}" + ) + except core.EOFException: break lr = auto_utils.get_lr(self.optimizer) diff --git a/python/paddle/distributed/passes/pass_utils.py b/python/paddle/distributed/passes/pass_utils.py index afb935008167a..00b8714b667f0 100644 --- a/python/paddle/distributed/passes/pass_utils.py +++ b/python/paddle/distributed/passes/pass_utils.py @@ -13,7 +13,6 @@ # limitations under the License. from collections import OrderedDict -from typing import List from paddle.distributed.auto_parallel.static.utils import ( is_backward_op, @@ -213,22 +212,23 @@ def var_can_be_deleted(var_name, block): return var is not None and not var.persistable -def get_skip_gc_vars(program_list: List[Program]): +def set_skip_gc_vars(num_micro_batches, type_to_program, jobs): """ - Get `skip_gc_vars` for every sub_program of program_list. + Set `skip_gc_vars` for every job in jobs. A whole_program is split up into sub_programs according to the schedule mode, thus a sub_program's vars might be used as the op's input of the later sub_program, and these vars cannot be gc after executing current sub_program. """ + assert num_micro_batches >= 1, "num_micro_batches needs to be >= 1" - # step1: Get all vars of every sub_program of program_list that are non-persistable and not in op's no_need_buffer. - required_vars = [set() for _ in range(len(program_list))] - for idx, program in enumerate(program_list): + # step1: Get all vars of every sub_program that are non-persistable and not in op's no_need_buffer. + type_to_required_vars = {} + for type, program in type_to_program.items(): + type_to_required_vars[type] = set() for block in program.blocks: for op in block.ops: - # NOTE(Ruibiao): Some vars maybe be the arguements of conditional_block op but no-need-buffer in the actual subblock, should not add them to the required_vars. 
- if op.type == "conditional_block": + if op.type in ["conditional_block", "while"]: continue op_info = OpInOutInfo() @@ -237,19 +237,19 @@ def get_skip_gc_vars(program_list: List[Program]): if var_can_be_deleted( arg_name, block ) and op_info.is_needed(arg_name): - required_vars[idx].add(arg_name) - - # step2: Get the `skip_gc_vars` that vars of current sub_program might be used in the later sub_program - suffixed_required_vars = set() - skip_gc_vars = [set()] * len(program_list) - for idx, vars_set in reversed(list(enumerate(required_vars))): - if idx < len(required_vars) - 1: - suffixed_required_vars = suffixed_required_vars.union( - required_vars[idx + 1] - ) - skip_gc_vars[idx] = vars_set & suffixed_required_vars - - return skip_gc_vars + type_to_required_vars[type].add(arg_name) + + # step2: Set `skip_gc_vars` for each job + suffixed_required_vars = [set() for i in range(num_micro_batches)] + num_jobs = len(jobs) + for job_id in reversed(range(num_jobs)): + job = jobs[job_id] + required_vars = type_to_required_vars[job.type()] + micro_batch_id = job.micro_batch_id() + job.set_skip_gc_vars( + required_vars & suffixed_required_vars[micro_batch_id] + ) + suffixed_required_vars[micro_batch_id] |= required_vars def _create_param(dst_block, src_var): diff --git a/python/paddle/distributed/passes/pipeline_pass_base.py b/python/paddle/distributed/passes/pipeline_pass_base.py index 14dce9065af0b..6fb2964ab9c39 100644 --- a/python/paddle/distributed/passes/pipeline_pass_base.py +++ b/python/paddle/distributed/passes/pipeline_pass_base.py @@ -15,7 +15,7 @@ from paddle.fluid import core from .pass_base import PassBase -from .pass_utils import get_skip_gc_vars +from .pass_utils import set_skip_gc_vars class PipelinePassBase(PassBase): @@ -28,21 +28,21 @@ def _check_self(self): def _check_conflict(self, other_pass): return True - def create_job_list(self): + def _create_job_list(self): """ An interface that MUST be implemented by subclasses. """ pass - def partial_programs(self, program): + def _partial_programs(self, program): """ An interface that MUST be implemented by subclasses. The return value MUST be two lists, one is a list of types(str), another is a list of sub programs. For example: - return ["lr", "forward", "backward", "optimizer"], [lr_prog, fwd_prog, bwd_prog, opt_prog] + return [LR, FORWARD, BACKWARD, OPT], [lr_prog, fwd_prog, bwd_prog, opt_prog] or - return ["forward"], [fwd_prog] + return [FORWARD], [fwd_prog] """ pass @@ -51,22 +51,15 @@ def _apply_single_impl(self, main_program, startup_program, context): The shared process is implemented in this function and new subclass only need to implement two interfaces above, 'create_job_list' and 'partial_programs'. """ - type_list, sub_program_list = self.partial_programs(main_program) + job_types, sub_programs = self._partial_programs(main_program) + jobs = self._create_job_list() - job_list = self.create_job_list() + type_to_program = dict(zip(job_types, sub_programs)) + set_skip_gc_vars( + self.get_attr("num_micro_batches"), type_to_program, jobs + ) - # Following is a shared gc process for base class. 
- gc_vars_list = get_skip_gc_vars(sub_program_list) - type_to_gc_vars = {} - for type, gc_var in zip(type_list, gc_vars_list): - type_to_gc_vars[type] = gc_var - - for job in job_list: - job.set_skip_gc_vars(type_to_gc_vars[job.type()]) - - type_to_program = {} - for type, sub_program in zip(type_list, sub_program_list): - type_to_program[type] = sub_program.desc - - plan = core.Plan(job_list, type_to_program) + for type in type_to_program.keys(): + type_to_program[type] = type_to_program[type].desc + plan = core.Plan(jobs, type_to_program) context.set_attr("plan", plan) diff --git a/python/paddle/distributed/passes/pipeline_scheduler_pass.py b/python/paddle/distributed/passes/pipeline_scheduler_pass.py index 9914b6e517c05..7229d61967448 100644 --- a/python/paddle/distributed/passes/pipeline_scheduler_pass.py +++ b/python/paddle/distributed/passes/pipeline_scheduler_pass.py @@ -12,10 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +import os + from paddle.fluid import core +from ..utils.log_utils import get_logger from .pass_base import PassContext, new_pass, register_pass -from .pass_utils import _program_for_fthenb_and_1f1b +from .pass_utils import _program_for_fthenb_and_1f1b, split_program from .pipeline_pass_base import PipelinePassBase __not_shape_var_type__ = [ @@ -26,35 +30,42 @@ core.VarDesc.VarType.FETCH_LIST, ] +LR = "lr" +FORWARD = "forward" +BACKWARD = "backward" +OPT = "optimizer" + +logger = get_logger(logging.INFO) + @register_pass("pipeline_scheduler_FThenB") class PipelineFThenBPass(PipelinePassBase): def __init__(self): super().__init__() - def create_job_list(self): + def _create_job_list(self): num_micro_batches = self.get_attr("num_micro_batches") job_list = [] - lr_job = core.Job("lr") + lr_job = core.Job(LR) job_list.append(lr_job) for i in range(num_micro_batches): - forward_job = core.Job("forward") + forward_job = core.Job(FORWARD) forward_job.set_micro_batch_id(i) job_list.append(forward_job) for i in range(num_micro_batches): - backward_job = core.Job("backward") + backward_job = core.Job(BACKWARD) backward_job.set_micro_batch_id(i) job_list.append(backward_job) - opt_job = core.Job("optimizer") + opt_job = core.Job(OPT) job_list.append(opt_job) return job_list - def partial_programs(self, program): - types = ["lr", "forward", "backward", "optimizer"] + def _partial_programs(self, program): + types = [LR, FORWARD, BACKWARD, OPT] sub_program_list = _program_for_fthenb_and_1f1b(program) return types, sub_program_list @@ -63,55 +74,216 @@ def partial_programs(self, program): class Pipeline1F1BPass(PipelinePassBase): def __init__(self): super().__init__() + self.jobs_in_stable_phase = [BACKWARD, FORWARD] + # Backward-forward overlapping splits and rearranges jobs for pattern Bi-Fj. 
+ # For example: jobs = {..., BACKWARD-i, FORWARD-j, ...}, i < j + # BACKWARD-i: OP1 - AllReduce - OP3 + # FORWARD-j: OP4 - AllReduce - OP6 + # Timeline: + # ===OP1===AllReduce===OP2===OP3===AllReduce===OP4 + # + # After backward-forward overlapping: jobs = {..., OP1, AllReduce, OP3, OP2, AllReduce, OP4} + # Timeline: + # === OP1 === OP3 =====OP2===========OP4 + # \ / + # \ / + # ========= AllReduce == AllReduce + self.set_attr("num_comm_op_in_backward_forward_overlap", 0) - def create_job_list(self): + def _create_job_list(self): num_micro_batches = self.get_attr("num_micro_batches") pp_stage = self.get_attr("pp_stage") pp_degree = self.get_attr("pp_degree") job_list = [] - lr_job = core.Job("lr") + lr_job = core.Job(LR) job_list.append(lr_job) assert ( pp_degree <= num_micro_batches - ), "Num of micro batches should larger than pp degree." + ), "Num of micro batches should larger than or equal to pp degree." micro_batch_in_warmup = pp_degree - pp_stage micro_batch_in_1f1b = num_micro_batches - micro_batch_in_warmup forward_micro_batch_id = 0 for i in range(micro_batch_in_warmup): - forward_job = core.Job("forward") + forward_job = core.Job(FORWARD) forward_job.set_micro_batch_id(forward_micro_batch_id) job_list.append(forward_job) forward_micro_batch_id += 1 backward_micro_batch_id = 0 for i in range(micro_batch_in_1f1b): - backward_job = core.Job("backward") - backward_job.set_micro_batch_id(backward_micro_batch_id) - job_list.append(backward_job) - backward_micro_batch_id += 1 - forward_job = core.Job("forward") - forward_job.set_micro_batch_id(forward_micro_batch_id) - job_list.append(forward_job) + for job_type in self.jobs_in_stable_phase: + job = core.Job(job_type) + micro_batch_id = ( + forward_micro_batch_id + if job_type.startswith(FORWARD) + else backward_micro_batch_id + ) + job.set_micro_batch_id(micro_batch_id) + job_list.append(job) forward_micro_batch_id += 1 + backward_micro_batch_id += 1 for i in range(micro_batch_in_warmup): - backward_job = core.Job("backward") + backward_job = core.Job(BACKWARD) backward_job.set_micro_batch_id(backward_micro_batch_id) job_list.append(backward_job) backward_micro_batch_id += 1 - opt_job = core.Job("optimizer") + opt_job = core.Job(OPT) job_list.append(opt_job) return job_list - def partial_programs(self, program): - types = ["lr", "forward", "backward", "optimizer"] - sub_program_list = _program_for_fthenb_and_1f1b(program) - return types, sub_program_list + def _cost(self, op_type): + cost = { + "recv_v2": 0.229, + "c_allreduce_sum": float( + "INF" + ), # ONLY for Forward, set the cost of c_allreduce_sum as INF so all of them will be splitted to the end of a chunk. 
+ "cast": 0.052, + "c_embedding": 0.061, + "lookup_table_v2": 0.047, + "elementwise_add": 0.051, + "layer_norm": 0.086, + "c_identity": 0.037, + "matmul_v2": 0.660, + "split": 0.070, + "transpose2": 0.030, + "scale": 0.019, + "fused_softmax_mask_upper_triangle": 0.284, + "gelu": 0.128, + } + return cost[op_type] if op_type in cost else 0.0 + + def _multistreaming_for_overlapping(self, programs): + for program in programs: + last_op = program.global_block().ops[-1] + if self.is_comm_op(last_op) and last_op.attr("use_calc_stream"): + last_op.dist_attr.execution_stream = "allreduce_stream" + + def _partial_programs(self, program): + types = [LR, FORWARD, BACKWARD, OPT] + sub_programs = _program_for_fthenb_and_1f1b(program) + + num_comm_op_in_backward_forward_overlap = self.get_attr( + "num_comm_op_in_backward_forward_overlap" + ) + assert ( + num_comm_op_in_backward_forward_overlap >= 0 + ), f"Get num_comm_op_in_backward_forward_overlap = {num_comm_op_in_backward_forward_overlap}, which should be >= 0." + + if num_comm_op_in_backward_forward_overlap > 0: + logger.info( + f"Backward forward overlap enabled in 1F1B, num_comm_op_in_backward_forward_overlap = {num_comm_op_in_backward_forward_overlap}." + ) + + # Split FORWARD + forward_program = sub_programs[1] + ops = forward_program.global_block().ops + num_ops = len(ops) + + costs = [self._cost(op.type) for op in ops] + prefix_cost = 0 + duration_for_overlap = 0.771 # cost of allreduce in BACKWARD + splitted_op_ids = [] + for op_id, op in enumerate(ops): + if prefix_cost > duration_for_overlap: + splitted_op_ids.append(op_id) + prefix_cost = 0 + if ( + len(splitted_op_ids) + 1 + >= num_comm_op_in_backward_forward_overlap + ): + break + + prefix_cost += self._cost(op.type) + + is_forward_split_point = ( + lambda program, op_id: op_id in splitted_op_ids + ) + + ( + splitted_forward_job_types, + splitted_forward_programs, + ) = self._split_program_for_overlapping( + FORWARD, forward_program, is_forward_split_point + ) + self._multistreaming_for_overlapping(splitted_forward_programs) + types += splitted_forward_job_types + sub_programs += splitted_forward_programs + + # Split BACKWARD + backward_program = sub_programs[2] + comm_op_ids = [ + op_id + for op_id, op in enumerate(backward_program.global_block().ops) + if self.is_comm_op(op) + ] + is_backward_split_point = ( + lambda program, op_id: op_id - 1 in comm_op_ids + and len(comm_op_ids) - comm_op_ids.index(op_id - 1) + < num_comm_op_in_backward_forward_overlap + ) + ( + splitted_backward_job_types, + splitted_backward_programs, + ) = self._split_program_for_overlapping( + BACKWARD, backward_program, is_backward_split_point + ) + self._multistreaming_for_overlapping(splitted_backward_programs) + types += splitted_backward_job_types + sub_programs += splitted_backward_programs + + # Rearrange splitted chunks for BACKWARD and FORWARD + self.jobs_in_stable_phase.clear() + num_splitted_forward_jobs = len(splitted_forward_job_types) + num_splitted_backward_jobs = len(splitted_backward_job_types) + for idx in range( + max(num_splitted_forward_jobs, num_splitted_backward_jobs) + ): + if idx < num_splitted_backward_jobs: + self.jobs_in_stable_phase.append( + splitted_backward_job_types[idx] + ) + if idx < num_splitted_forward_jobs: + self.jobs_in_stable_phase.append( + splitted_forward_job_types[idx] + ) + + for i in range(len(types)): + print(f"type = {types[i]}, sub_programs = {sub_programs[i]}\n") + logger.info(f"jobs_in_stable_phase = {self.jobs_in_stable_phase}") + + return types, sub_programs + + 
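Editor's aside (not part of this patch): the chunk-splitting rule used in _partial_programs above, isolated as a plain function under the same assumptions — per-op costs in milliseconds, a fixed 0.771 ms backward allreduce to hide, and at most num_comm_op_in_backward_forward_overlap chunks. All names below are illustrative.

def greedy_split_points(op_costs, duration_for_overlap, max_chunks):
    """Cut a new forward chunk whenever the accumulated chunk cost exceeds the duration to hide."""
    split_ids, prefix_cost = [], 0.0
    for op_id, cost in enumerate(op_costs):
        if prefix_cost > duration_for_overlap:
            split_ids.append(op_id)
            prefix_cost = 0.0
            if len(split_ids) + 1 >= max_chunks:
                break
        prefix_cost += cost
    return split_ids

# e.g. a [cast, matmul_v2, gelu, matmul_v2] prefix with the costs tabled above:
print(greedy_split_points([0.052, 0.660, 0.128, 0.660], 0.771, 3))  # -> [3]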
def _split_program_for_overlapping(self, job_type, program, is_split_point): + assert job_type in [ + FORWARD, + BACKWARD, + ], f"job_type should be one of {[FORWARD, BACKWARD]}" + + ops = program.global_block().ops + num_ops = len(ops) + + split_ids = [] + for op_id in range(1, num_ops): + if is_split_point(program, op_id): + split_ids.append(op_id) + + splitted_programs, __, __ = split_program(program, split_ids) + + splitted_job_types = [] + num_splitted_programs = len(splitted_programs) + for idx in range(num_splitted_programs): + splitted_job_types.append(f"{job_type}(chunk{idx})") + + return splitted_job_types, splitted_programs + + def is_comm_op(self, op): + return op.type == "c_allreduce_sum" def apply_pass(main_program, startup_program, pass_name, pass_attr={}): @@ -121,6 +293,12 @@ def apply_pass(main_program, startup_program, pass_name, pass_attr={}): ], "pipeline scheduler only support FThenB and 1F1B, but recieve {}".format( pass_name ) + + if pass_name == "1F1B": + pass_attr["num_comm_op_in_backward_forward_overlap"] = int( + os.environ.get("FLAGS_num_comm_op_in_backward_forward_overlap", 0) + ) + pipeline_pass = new_pass("pipeline_scheduler_" + pass_name, pass_attr) pass_context = PassContext() pipeline_pass.apply([main_program], [startup_program], pass_context) diff --git a/test/standalone_executor/test_standalone_executor_multi_micro_batch.py b/test/standalone_executor/test_standalone_executor_multi_micro_batch.py index 61b76559c0098..9d7af3eaaa24f 100644 --- a/test/standalone_executor/test_standalone_executor_multi_micro_batch.py +++ b/test/standalone_executor/test_standalone_executor_multi_micro_batch.py @@ -19,7 +19,7 @@ import numpy as np import paddle -from paddle.distributed.passes.pass_utils import get_skip_gc_vars, split_program +from paddle.distributed.passes.pass_utils import set_skip_gc_vars, split_program from paddle.fluid import core from paddle.fluid.core import Job, Plan from paddle.fluid.executor import _add_feed_fetch_ops, _StandaloneExecutor @@ -180,13 +180,11 @@ def run_train(self, split=False, micro_batch_num=1): job_list = [] program_num = len(programs) - skip_gc_vars = get_skip_gc_vars(programs) for micro_batch_id in range(micro_batch_num): for program_id in range(program_num): job = Job(f"P{program_id}") job.set_micro_batch_id(micro_batch_id) - job.set_skip_gc_vars(skip_gc_vars[program_id]) # Set col_attr info for fetch_op to fetch the correct data after running multiple micro batch if program_id == program_num - 1: fetch_op_id_to_col_attr = {} @@ -201,6 +199,8 @@ def run_train(self, split=False, micro_batch_num=1): for program_id in range(program_num): type_to_program[f"P{program_id}"] = programs[program_id].desc + set_skip_gc_vars(micro_batch_num, type_to_program, job_list) + plan = Plan(job_list, type_to_program) main_exe = _StandaloneExecutor(self.place, plan, scope) From ac49ea7a1616b95d528209b406b3609b1044cb42 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Fri, 1 Sep 2023 06:29:34 +0000 Subject: [PATCH 02/11] Add column_parallel_linear_backward_overlapping --- .../interpreter/stream_analyzer.cc | 2 +- .../auto_parallel/static/parallelizer_v2.py | 12 ++ python/paddle/distributed/passes/__init__.py | 7 +- .../passes/auto_parallel_sharding.py | 7 +- ...mn_parallel_linear_backward_overlapping.py | 121 ++++++++++++++++++ .../paddle/distributed/passes/pass_utils.py | 9 ++ .../distributed/passes/pipeline_pass_base.py | 4 - .../passes/pipeline_scheduler_pass.py | 10 +- 8 files changed, 159 insertions(+), 13 deletions(-) create mode 100644 
python/paddle/distributed/passes/column_parallel_linear_backward_overlapping.py diff --git a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc index bf7b8392d8ac1..faf29ab762bab 100644 --- a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc +++ b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc @@ -346,7 +346,7 @@ void analyse_event_info_for_two_instructions( if (has_data_dependency( instructions[cur_instr_id], instructions[next_instr_id]) || - !run_type_info[next_instr_id][DownstreamRunType::kEventRun].empty() || + //! run_type_info[next_instr_id][DownstreamRunType::kEventRun].empty() || instructions[next_instr_id]->OpBase()->Type() == "depend") { waiter_instr_ids->insert(next_instr_id); return; diff --git a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py index 8b5136a61b9f6..37b1d9524b65e 100644 --- a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py @@ -14,6 +14,7 @@ import copy import logging +import os import time from paddle.distributed.passes import PassManager, new_pass @@ -354,6 +355,17 @@ def _apply_post_optimization( ) params_grads = self._pass_context.get_attr("params_grads") + mp_async_allreduce_in_backward = os.getenv( + "FLAGS_mp_async_allreduce_in_backward" + ) in [1, "1", True, "True"] + if mp_async_allreduce_in_backward: + column_parallel_linear_backward_overlapping_pass = new_pass( + "column_parallel_linear_backward_overlapping", {} + ) + column_parallel_linear_backward_overlapping_pass.apply( + [main_program], [startup_program], self._pass_context + ) + if self.is_train: # GradClip is train-only optimization config = copy.deepcopy(self._strategy.sharding.to_dict()) diff --git a/python/paddle/distributed/passes/__init__.py b/python/paddle/distributed/passes/__init__.py index 206c158769b2c..e2f54d47a4e08 100644 --- a/python/paddle/distributed/passes/__init__.py +++ b/python/paddle/distributed/passes/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. 
from .pass_base import new_pass, PassManager, PassContext -from .fuse_all_reduce import * # noqa: F403 + from .auto_parallel_gradient_merge import * # noqa: F403 from .auto_parallel_sharding import * # noqa: F403 from .auto_parallel_amp import * # noqa: F403 @@ -24,11 +24,14 @@ from .auto_parallel_grad_clip import * # noqa: F403 from .auto_parallel_supplement_explicit_dependencies import * # noqa: F403 from .auto_parallel_pipeline import * # noqa: F403 -from .pipeline_scheduler_pass import * # noqa: F403 +from .column_parallel_linear_backward_overlapping import * # noqa: F403 from .cpp_pass import * # noqa: F403 +from .fuse_all_reduce import * # noqa: F403 +from .pipeline_scheduler_pass import * # noqa: F403 from .ps_trainer_pass import * # noqa: F403 from .ps_server_pass import * # noqa: F403 + __all__ = [ 'new_pass', 'PassManager', diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py index 44880cd6a3bfc..aefdda49aa9a4 100644 --- a/python/paddle/distributed/passes/auto_parallel_sharding.py +++ b/python/paddle/distributed/passes/auto_parallel_sharding.py @@ -43,6 +43,7 @@ from paddle.utils import unique_name from .pass_base import PassBase, register_pass +from .pass_utils import AutoParallelStreamType OpRole = core.op_proto_and_checker_maker.OpRole OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() @@ -748,13 +749,11 @@ def _fuse_overlap_parameter_comm_stage_two(self, sharding_info): group = sharding_info.group else: group = new_process_group(ranks, force_new_group=True) - # NOTE here stream is just a presentation with different name, - # it is up to executor to create the exact streams given the name. - stream = f"sharding_param_comm_stream{i}" + self.param_comm_group_stream_pairs.append( { "comm_group": group, - "comm_stream": stream, + "comm_stream": AutoParallelStreamType.SHARDING_STREAM.value, } ) _logger.info( diff --git a/python/paddle/distributed/passes/column_parallel_linear_backward_overlapping.py b/python/paddle/distributed/passes/column_parallel_linear_backward_overlapping.py new file mode 100644 index 0000000000000..76f4ca628c9b9 --- /dev/null +++ b/python/paddle/distributed/passes/column_parallel_linear_backward_overlapping.py @@ -0,0 +1,121 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections + +from .pass_base import PassBase, register_pass +from .pass_utils import AutoParallelStreamType + + +# For allreduce pattern in the backward phase of column parallel linear: +# dX, dY = matmul_grad(X, Y, dOut) +# dX = c_allreduce_sum(dX) +# Split matmul_grad to 2 matmul: +# dX = mutmul(dOut, Y^T) +# dX = c_allreduce_sum(dX) +# dY = matmul(X^T, dOut) +# +# Then the c_allreduce_sum can overlap with the compute of dY. 
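Editor's aside (not part of this patch): a minimal numpy sketch of the algebra behind the split described in the comment above. For Out = X @ Y with no transposes, matmul_v2_grad computes dX = dOut @ Y^T and dY = X^T @ dOut; only dX needs the c_allreduce_sum across model-parallel ranks, so the dY matmul is free to overlap with that communication. All names below are illustrative.

import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 3))     # input activation
Y = rng.standard_normal((3, 5))     # column-parallel weight shard
dOut = rng.standard_normal((4, 5))  # upstream gradient

dX = dOut @ Y.T  # first matmul; its output is what c_allreduce_sum reduces
dY = X.T @ dOut  # second matmul; independent of the allreduce, so it can overlap
assert dX.shape == X.shape and dY.shape == Y.shape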
+@register_pass("column_parallel_linear_backward_overlapping") +class ColumnParallelLinearBackwardOverlappingPass(PassBase): + def __init__(self): + super().__init__() + self.set_attr("allreduce_stream", None) + + def _check_self(self): + return True + + def _check_conflict(self, other_pass): + return True + + def _apply_single_impl(self, main_program, startup_program, context): + block = main_program.global_block() + matmul_grad_id_to_allreduce_id = ( + self._get_all_matmul_grad_and_allreduce_pairs(block) + ) + self._split_matmul_grad_and_multi_streaming_allreduce( + block, matmul_grad_id_to_allreduce_id + ) + + def _get_all_matmul_grad_and_allreduce_pairs(self, block): + ops = block.ops + op_num = len(ops) + matmul_grad_id_to_allreduce_id = collections.OrderedDict() + for i, op_i in enumerate(ops): + if ( + op_i.type == 'matmul_v2_grad' + and op_i.attr("trans_x") is False + and op_i.attr("trans_y") is False + ): + x_grad = op_i.output("X@GRAD") + for j in range(i + 1, op_num): + op_j = ops[j] + if ( + op_j.type == 'c_allreduce_sum' + and op_j.input("X") == x_grad + ): + matmul_grad_id_to_allreduce_id[i] = j + return matmul_grad_id_to_allreduce_id + + def _split_matmul_grad_and_multi_streaming_allreduce( + self, block, matmul_grad_id_to_allreduce_id + ): + ops = block.ops + + for matmul_grad_id, allreduce_id in reversed( + matmul_grad_id_to_allreduce_id.items() + ): + matmul_grad_op = ops[matmul_grad_id] + allreduce_op = ops[allreduce_id] + + tran_x = matmul_grad_op.attr("trans_x") + assert ( + not tran_x + ), f"matmul_grad(id={matmul_grad_id}) with tran_x == True is not supported for column parallel linear backward overlapping" + tran_y = matmul_grad_op.attr("trans_y") + assert ( + not tran_y + ), f"matmul_grad(id={matmul_grad_id}) with tran_y == True is not supported for column parallel linear backward overlapping" + + allreduce_op.dist_attr.execution_stream = ( + AutoParallelStreamType.MP_STREAM.value + ) + + x = matmul_grad_op.input("X") + y = matmul_grad_op.input("Y") + out_grad = matmul_grad_op.input("Out@GRAD") + x_grad = matmul_grad_op.output("X@GRAD") + y_grad = matmul_grad_op.output("Y@GRAD") + op_role = matmul_grad_op.attr("op_role") + + # NOTE(Ruibiao): Required OP scheduling order: mutmul(dOut, Y^T) -> c_allreduce_sum(dX) -> matmul(X^T, dOut). + # c_allreduce_sum(dX) and matmul(X^T, dOut) cannot be swapped. Otherwise, after buffer_shared_inplace_pass + # adding share_buffer OP before c_allreduce_sum, c_allreduce_sum will synchronous with comp-stream, and then + # the matmul op before it cannot be overlapped. 
+ block._insert_op_without_sync( + allreduce_id + 1, + type="matmul_v2", + inputs={"X": x, "Y": out_grad}, + outputs={"Out": y_grad}, + attrs={"trans_x": True, "trans_y": False, "op_role": op_role}, + ) + block._insert_op_without_sync( + matmul_grad_id + 1, + type="matmul_v2", + inputs={"X": out_grad, "Y": y}, + outputs={"Out": x_grad}, + attrs={"trans_x": False, "trans_y": True, "op_role": op_role}, + ) + block._remove_op(matmul_grad_id) + block._sync_with_cpp() diff --git a/python/paddle/distributed/passes/pass_utils.py b/python/paddle/distributed/passes/pass_utils.py index e6b4a490ff3f4..83b5726d87d77 100644 --- a/python/paddle/distributed/passes/pass_utils.py +++ b/python/paddle/distributed/passes/pass_utils.py @@ -14,6 +14,7 @@ import logging from collections import OrderedDict +from enum import Enum from paddle.distributed.auto_parallel.static.utils import ( get_logger, @@ -539,3 +540,11 @@ def _add_ops_into_block(src_block, dst_block, ops): # It MUST return in this order return [lr_prog, fwd_prog, bwd_prog, opt_prog] + + +# NOTE here stream is just a presentation with different name, +# it is up to executor to create the exact streams given the name. +class AutoParallelStreamType(Enum): + CALC_STREAM = "default" + MP_STREAM = "auto_parallel_mp" + SHARDING_STREAM = "auto_parallel_sharding" diff --git a/python/paddle/distributed/passes/pipeline_pass_base.py b/python/paddle/distributed/passes/pipeline_pass_base.py index fdaac1144029a..0561dcd0e93e2 100644 --- a/python/paddle/distributed/passes/pipeline_pass_base.py +++ b/python/paddle/distributed/passes/pipeline_pass_base.py @@ -12,16 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging -from paddle.distributed.auto_parallel.static.utils import get_logger from paddle.fluid import core from .pass_base import PassBase from .pass_utils import set_skip_gc_vars -_logger = get_logger(logging.INFO) - class PipelinePassBase(PassBase): def __init__(self): diff --git a/python/paddle/distributed/passes/pipeline_scheduler_pass.py b/python/paddle/distributed/passes/pipeline_scheduler_pass.py index a976df9f6ef3d..1911d164f1d5f 100644 --- a/python/paddle/distributed/passes/pipeline_scheduler_pass.py +++ b/python/paddle/distributed/passes/pipeline_scheduler_pass.py @@ -19,7 +19,11 @@ from ..utils.log_utils import get_logger from .pass_base import PassContext, new_pass, register_pass -from .pass_utils import _program_for_fthenb_and_1f1b, split_program +from .pass_utils import ( + AutoParallelStreamType, + _program_for_fthenb_and_1f1b, + split_program, +) from .pipeline_pass_base import PipelinePassBase __not_shape_var_type__ = [ @@ -164,7 +168,9 @@ def _multistreaming_for_overlapping(self, programs): for program in programs: last_op = program.global_block().ops[-1] if self.is_comm_op(last_op) and last_op.attr("use_calc_stream"): - last_op.dist_attr.execution_stream = "allreduce_stream" + last_op.dist_attr.execution_stream = ( + AutoParallelStreamType.MP_STREAM.value + ) def _partial_programs(self, program): types = [LR, FORWARD, BACKWARD, OPT] From ee4cc584f44c81f8e0c1eedecc92fd37f5f6ab5f Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Tue, 5 Sep 2023 09:01:48 +0000 Subject: [PATCH 03/11] Add cost model --- .../passes/pipeline_scheduler_pass.py | 256 ++++++++++-------- 1 file changed, 139 insertions(+), 117 deletions(-) diff --git a/python/paddle/distributed/passes/pipeline_scheduler_pass.py b/python/paddle/distributed/passes/pipeline_scheduler_pass.py index 
1911d164f1d5f..e87e854223a49 100644 --- a/python/paddle/distributed/passes/pipeline_scheduler_pass.py +++ b/python/paddle/distributed/passes/pipeline_scheduler_pass.py @@ -15,6 +15,7 @@ import logging import os +from paddle.distributed.auto_parallel.static.cost import calc_time_by_cost_model from paddle.fluid import core from ..utils.log_utils import get_logger @@ -80,20 +81,115 @@ class Pipeline1F1BPass(PipelinePassBase): def __init__(self): super().__init__() self.jobs_in_stable_phase = [BACKWARD, FORWARD] - # Backward-forward overlapping splits and rearranges jobs for pattern Bi-Fj. - # For example: jobs = {..., BACKWARD-i, FORWARD-j, ...}, i < j - # BACKWARD-i: OP1 - AllReduce - OP3 - # FORWARD-j: OP4 - AllReduce - OP6 - # Timeline: - # ===OP1===AllReduce===OP2===OP3===AllReduce===OP4 - # - # After backward-forward overlapping: jobs = {..., OP1, AllReduce, OP3, OP2, AllReduce, OP4} - # Timeline: - # === OP1 === OP3 =====OP2===========OP4 - # \ / - # \ / - # ========= AllReduce == AllReduce - self.set_attr("num_comm_op_in_backward_forward_overlap", 0) + self.set_attr("enable_backward_forward_overlap", 0) + + # Backward-forward overlapping splits and rearranges jobs for pattern Bi-Fj. + # For example: jobs = {..., BACKWARD-i, FORWARD-j, ...}, i < j + # BACKWARD-i: Calc1 - Comm1 - Calc2 - Comm2 - Calc3 + # FORWARD-j: Calc4 - Comm3 - Calc5 - Comm4 - Calc6 + # Timeline: + # ===Calc1==Comm1==Calc2==Comm2==Calc3==Calc4==Comm3==Calc5==Comm4==Calc6=== + # + # After backward-forward overlapping: jobs = {Calc1, Comm1, Calc4, Comm3, Calc2, Comm2, Calc5, Comm4, Calc3, Calc6} + # Timeline: + # ===Calc1==Calc4==Calc2==Calc5==Calc3=Calc6=== + # \ / \ / + # \ / \ / + # ==========Comm1==Comm3==Comm2==Comm4========== + # + def _backward_forward_overlap(self, backward_program, forward_program): + logger.info("Backward forward overlap enabled in 1F1B.") + print(f"backward_program :: {backward_program}") + print(f"fowr_program :: {forward_program}") + # Split BACKWARD + valid_comm_op_ids = [ + op_id + for op_id, op in enumerate(backward_program.global_block().ops) + if self.is_comm_op_valid_to_overlap(op) + ] + # TODO(Ruibiao): Constrain the number of valid comm ops to resolve the potential memory explosion issue. + is_backward_split_point = ( + lambda program, op_id: op_id - 1 in valid_comm_op_ids + ) + ( + splitted_backward_job_types, + splitted_backward_programs, + ) = self._split_program_for_overlapping( + BACKWARD, backward_program, is_backward_split_point + ) + self._multistreaming_for_overlapping(splitted_backward_programs) + + # Split FORWARD + ops = forward_program.global_block().ops + num_ops = len(ops) + splitted_op_ids = [] + op_id = 0 + for splitted_backward_program in splitted_backward_programs: + backward_op_to_overlap = ( + splitted_backward_program.global_block().ops[-1] + ) + backward_cost_to_overlap = self._op_cost(backward_op_to_overlap) + + forward_cost_to_overlap = self._op_cost(ops[op_id]) + print( + f"backward_op_to_overlap : {backward_op_to_overlap}, cost = {backward_cost_to_overlap}" + ) + print( + f"forward_op_to_overlap : {ops[op_id]}, cost = {forward_cost_to_overlap}" + ) + + while ( + op_id < num_ops + and forward_cost_to_overlap <= backward_cost_to_overlap + ): + op_id += 1 + op = ops[op_id] + # Force split when meet comm op since it cannot overlap with comm op in backward. 
+ if op_id > 0 and self.is_comm_op_valid_to_overlap( + ops[op_id - 1] + ): + break + + print( + f"forward_op_to_overlap : {ops[op_id]}, cost = {self._op_cost(ops[op_id])}" + ) + forward_cost_to_overlap += self._op_cost(ops[op_id]) + + splitted_op_ids.append(op_id) + if op_id >= num_ops: + break + + is_forward_split_point = lambda program, op_id: op_id in splitted_op_ids + ( + splitted_forward_job_types, + splitted_forward_programs, + ) = self._split_program_for_overlapping( + FORWARD, forward_program, is_forward_split_point + ) + self._multistreaming_for_overlapping(splitted_forward_programs) + + # Rearrange splitted chunks for BACKWARD and FORWARD + self.jobs_in_stable_phase.clear() + num_splitted_forward_jobs = len(splitted_forward_job_types) + num_splitted_backward_jobs = len(splitted_backward_job_types) + for idx in range( + max(num_splitted_forward_jobs, num_splitted_backward_jobs) + ): + if idx < num_splitted_backward_jobs: + self.jobs_in_stable_phase.append( + splitted_backward_job_types[idx] + ) + if idx < num_splitted_forward_jobs: + self.jobs_in_stable_phase.append( + splitted_forward_job_types[idx] + ) + + return ( + splitted_backward_job_types, + splitted_backward_programs, + splitted_forward_job_types, + splitted_forward_programs, + ) def _create_job_list(self): num_micro_batches = self.get_attr("num_micro_batches") @@ -143,123 +239,45 @@ def _create_job_list(self): job_list.append(opt_job) return job_list - def _cost(self, op_type): - cost = { - "recv_v2": 0.229, - "c_allreduce_sum": float( - "INF" - ), # ONLY for Forward, set the cost of c_allreduce_sum as INF so all of them will be splitted to the end of a chunk. - "cast": 0.052, - "c_embedding": 0.061, - "lookup_table_v2": 0.047, - "elementwise_add": 0.051, - "layer_norm": 0.086, - "c_identity": 0.037, - "matmul_v2": 0.660, - "split": 0.070, - "transpose2": 0.030, - "scale": 0.019, - "fused_softmax_mask_upper_triangle": 0.284, - "gelu": 0.128, - } - return cost[op_type] if op_type in cost else 0.0 - def _multistreaming_for_overlapping(self, programs): + # TODO(Ruibiao): Add cross-program event dependency for multi-stream. for program in programs: last_op = program.global_block().ops[-1] - if self.is_comm_op(last_op) and last_op.attr("use_calc_stream"): + if self.is_comm_op_valid_to_overlap(last_op): last_op.dist_attr.execution_stream = ( AutoParallelStreamType.MP_STREAM.value ) + def _op_cost(self, op): + try: + return calc_time_by_cost_model(op) + except: + logger.info(f"The cost of {op} is unknown.") + return 0.0 + def _partial_programs(self, program): types = [LR, FORWARD, BACKWARD, OPT] sub_programs = _program_for_fthenb_and_1f1b(program) - num_comm_op_in_backward_forward_overlap = self.get_attr( - "num_comm_op_in_backward_forward_overlap" + enable_backward_forward_overlap = self.get_attr( + "enable_backward_forward_overlap" ) - assert ( - num_comm_op_in_backward_forward_overlap >= 0 - ), f"Get num_comm_op_in_backward_forward_overlap = {num_comm_op_in_backward_forward_overlap}, which should be >= 0." - - if num_comm_op_in_backward_forward_overlap > 0: - logger.info( - f"Backward forward overlap enabled in 1F1B, num_comm_op_in_backward_forward_overlap = {num_comm_op_in_backward_forward_overlap}." 
- ) - - # Split FORWARD - forward_program = sub_programs[1] - ops = forward_program.global_block().ops - num_ops = len(ops) - - costs = [self._cost(op.type) for op in ops] - prefix_cost = 0 - duration_for_overlap = 0.771 # cost of allreduce in BACKWARD - splitted_op_ids = [] - for op_id, op in enumerate(ops): - if prefix_cost > duration_for_overlap: - splitted_op_ids.append(op_id) - prefix_cost = 0 - if ( - len(splitted_op_ids) + 1 - >= num_comm_op_in_backward_forward_overlap - ): - break - - prefix_cost += self._cost(op.type) - - is_forward_split_point = ( - lambda program, op_id: op_id in splitted_op_ids - ) + if enable_backward_forward_overlap: + logger.info("Backward forward overlap enabled in 1F1B.") + forward_program, backward_program = sub_programs[1], sub_programs[2] ( + splitted_backward_job_types, + splitted_backward_programs, splitted_forward_job_types, splitted_forward_programs, - ) = self._split_program_for_overlapping( - FORWARD, forward_program, is_forward_split_point + ) = self._backward_forward_overlap( + backward_program, forward_program ) - self._multistreaming_for_overlapping(splitted_forward_programs) - types += splitted_forward_job_types - sub_programs += splitted_forward_programs - - # Split BACKWARD - backward_program = sub_programs[2] - comm_op_ids = [ - op_id - for op_id, op in enumerate(backward_program.global_block().ops) - if self.is_comm_op(op) - ] - is_backward_split_point = ( - lambda program, op_id: op_id - 1 in comm_op_ids - and len(comm_op_ids) - comm_op_ids.index(op_id - 1) - < num_comm_op_in_backward_forward_overlap + types += splitted_forward_job_types + splitted_backward_job_types + sub_programs += ( + splitted_forward_programs + splitted_backward_programs ) - ( - splitted_backward_job_types, - splitted_backward_programs, - ) = self._split_program_for_overlapping( - BACKWARD, backward_program, is_backward_split_point - ) - self._multistreaming_for_overlapping(splitted_backward_programs) - types += splitted_backward_job_types - sub_programs += splitted_backward_programs - - # Rearrange splitted chunks for BACKWARD and FORWARD - self.jobs_in_stable_phase.clear() - num_splitted_forward_jobs = len(splitted_forward_job_types) - num_splitted_backward_jobs = len(splitted_backward_job_types) - for idx in range( - max(num_splitted_forward_jobs, num_splitted_backward_jobs) - ): - if idx < num_splitted_backward_jobs: - self.jobs_in_stable_phase.append( - splitted_backward_job_types[idx] - ) - if idx < num_splitted_forward_jobs: - self.jobs_in_stable_phase.append( - splitted_forward_job_types[idx] - ) for i in range(len(types)): print(f"type = {types[i]}, sub_programs = {sub_programs[i]}\n") @@ -290,8 +308,12 @@ def _split_program_for_overlapping(self, job_type, program, is_split_point): return splitted_job_types, splitted_programs - def is_comm_op(self, op): - return op.type == "c_allreduce_sum" + def is_comm_op_valid_to_overlap(self, op): + return ( + op.type == "c_allreduce_sum" + and op.dist_attr.execution_stream + == AutoParallelStreamType.CALC_STREAM.value + ) def apply_pass(main_program, startup_program, pass_name, pass_attr={}): @@ -303,8 +325,8 @@ def apply_pass(main_program, startup_program, pass_name, pass_attr={}): ) if pass_name == "1F1B": - pass_attr["num_comm_op_in_backward_forward_overlap"] = int( - os.environ.get("FLAGS_num_comm_op_in_backward_forward_overlap", 0) + pass_attr["enable_backward_forward_overlap"] = int( + os.environ.get("FLAGS_1f1b_backward_forward_overlap", 0) ) pipeline_pass = new_pass("pipeline_scheduler_" + pass_name, pass_attr) 
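Editor's aside (not part of these patches): a minimal sketch of how the new flag is consumed, mirroring the wiring that apply_pass performs internally after this patch. Only FLAGS_1f1b_backward_forward_overlap and the enable_backward_forward_overlap attribute come from the patch; the other attribute values are illustrative.

import os

os.environ["FLAGS_1f1b_backward_forward_overlap"] = "1"  # opt in to backward-forward overlap

pass_attr = {"num_micro_batches": 8, "pp_stage": 0, "pp_degree": 2}  # illustrative values
pass_attr["enable_backward_forward_overlap"] = int(
    os.environ.get("FLAGS_1f1b_backward_forward_overlap", 0)
)
# apply_pass(main_program, startup_program, "1F1B", pass_attr) then builds
# pipeline_scheduler_1F1B, whose _partial_programs checks this attribute.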
From 57e09f87362e56c33492a9b3bf3bafa09f9d65f3 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Fri, 8 Sep 2023 06:27:01 +0000 Subject: [PATCH 04/11] Insert reshape for ColumnParallelLinearBackwardOverlappingPass --- ...mn_parallel_linear_backward_overlapping.py | 76 ++++++++++++++++++- 1 file changed, 72 insertions(+), 4 deletions(-) diff --git a/python/paddle/distributed/passes/column_parallel_linear_backward_overlapping.py b/python/paddle/distributed/passes/column_parallel_linear_backward_overlapping.py index 76f4ca628c9b9..aa5dbd7d267e1 100644 --- a/python/paddle/distributed/passes/column_parallel_linear_backward_overlapping.py +++ b/python/paddle/distributed/passes/column_parallel_linear_backward_overlapping.py @@ -68,6 +68,29 @@ def _get_all_matmul_grad_and_allreduce_pairs(self, block): matmul_grad_id_to_allreduce_id[i] = j return matmul_grad_id_to_allreduce_id + def _insert_reshape_op(self, block, index, x, shape, op_role, out=None): + var_x = block.var(x[0]) + if out is None: + out = block.create_var( + name=f"{x[0]}@reshape.out", + dtype=var_x.dtype, + persistable=False, + ) + x_shape = block.create_var( + name=f"{x[0]}@reshape.xshape", dtype=var_x.dtype + ) + + block._insert_op_without_sync( + index=index, + type="reshape2", + inputs={"X": x}, + outputs={"Out": out, "XShape": x_shape}, + attrs={"shape": shape, "op_role": op_role}, + ) + block._sync_with_cpp() + + return out + def _split_matmul_grad_and_multi_streaming_allreduce( self, block, matmul_grad_id_to_allreduce_id ): @@ -103,19 +126,64 @@ def _split_matmul_grad_and_multi_streaming_allreduce( # c_allreduce_sum(dX) and matmul(X^T, dOut) cannot be swapped. Otherwise, after buffer_shared_inplace_pass # adding share_buffer OP before c_allreduce_sum, c_allreduce_sum will synchronous with comp-stream, and then # the matmul op before it cannot be overlapped. + var_x = block.var(x[0]) + var_out_grad = block.var(out_grad[0]) + var_y_grad = block.var(y_grad[0]) + + x_dims = var_x.shape + out_grad_dims = var_out_grad.shape + y_grad_dims = var_y_grad.shape + + assert len(x_dims) == len( + out_grad_dims + ), f"The rank of x must be equal to that of out_grad, but got x rank = {len(x_dims)} and out_grad rank = {len(out_grad_dims)}." + if len(x_dims) > 2: + assert ( + x_dims[0:2] == out_grad_dims[0:2] + ), f"The first two dimensions of x must be equal to that of out_grad, but got x_dims:{x_dims} and out_grad_dims:{out_grad_dims}." + new_x_dims = [x_dims[0] * x_dims[1]] + list(x_dims[2:]) + new_out_grad_dims = [ + out_grad_dims[0] * out_grad_dims[1] + ] + list(out_grad_dims[2:]) + + # NOTE(Ruibiao): Why insert reshape op here? + # When the rank of input matrix is 3, MatmulGradKernel use reshape to fold the first two dimensions of x and out_grad (see FoldInitDims in matmul_grad_kernel_impl.h), and then calls blas.Matmul to calculate y_grad. + # If we directly append matmul op to calculate y_grad without FoldInitDims, blas.BatchedGEMM is actually called in MatmulKernel, which has a larger cost than using blas.Matmul after dimension folding. + # Therefore, we imitate MatmulGradKernel here by inserting reshape op before matmul. 
+ new_x = self._insert_reshape_op( + block, allreduce_id + 1, x, new_x_dims, op_role + ) + new_out_grad = self._insert_reshape_op( + block, allreduce_id + 2, out_grad, new_out_grad_dims, op_role + ) + new_y_grad = block.create_var( + name=f"{y_grad[0]}@reshape.out", + dtype=var_y_grad.dtype, + persistable=False, + ) block._insert_op_without_sync( - allreduce_id + 1, + index=allreduce_id + 3, type="matmul_v2", - inputs={"X": x, "Y": out_grad}, - outputs={"Out": y_grad}, + inputs={"X": new_x, "Y": new_out_grad}, + outputs={"Out": new_y_grad}, attrs={"trans_x": True, "trans_y": False, "op_role": op_role}, ) + self._insert_reshape_op( + block, + allreduce_id + 4, + [new_y_grad.name], + y_grad_dims, + op_role, + y_grad, + ) + block._insert_op_without_sync( - matmul_grad_id + 1, + index=matmul_grad_id + 1, type="matmul_v2", inputs={"X": out_grad, "Y": y}, outputs={"Out": x_grad}, attrs={"trans_x": False, "trans_y": True, "op_role": op_role}, ) + block._remove_op(matmul_grad_id) block._sync_with_cpp() From ab00cd7fd53a4a98d87796e8e22313306b620399 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Mon, 11 Sep 2023 07:43:05 +0000 Subject: [PATCH 05/11] Add cross-program event dependency --- .../operators/collective/c_identity_op.cc | 4 +++ .../paddle/distributed/passes/pass_utils.py | 16 +++++----- .../passes/pipeline_scheduler_pass.py | 31 +++++++++++++++++-- .../test_standalone_custom_event.py | 2 +- 4 files changed, 42 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/collective/c_identity_op.cc b/paddle/fluid/operators/collective/c_identity_op.cc index c067c061b8613..87cbea9f21548 100644 --- a/paddle/fluid/operators/collective/c_identity_op.cc +++ b/paddle/fluid/operators/collective/c_identity_op.cc @@ -78,6 +78,9 @@ class CIdentityOpGradMaker : public framework::SingleGradOpMaker { retv->SetAttrMap(this->Attrs()); } }; + +DECLARE_INPLACE_OP_INFERER(IdentityInplaceInferer, {"X", "Out"}); + } // namespace operators } // namespace paddle @@ -92,4 +95,5 @@ REGISTER_OPERATOR(c_identity, ops::CIdentityOpGradMaker, ops::CIdentityOpGradMaker, ops::CIdentityOpMaker, + ops::IdentityInplaceInferer, CIdentityShapeFunctor); diff --git a/python/paddle/distributed/passes/pass_utils.py b/python/paddle/distributed/passes/pass_utils.py index 1f62c6de84cba..311621e5e774c 100644 --- a/python/paddle/distributed/passes/pass_utils.py +++ b/python/paddle/distributed/passes/pass_utils.py @@ -550,18 +550,18 @@ def _add_ops_into_block(src_block, dst_block, ops): return [lr_prog, fwd_prog, bwd_prog, opt_prog] -def _add_event_dependency(recorder_op_desc, waiter_op_desc): +def _add_event_dependency(recorder_op, waiter_op): ''' Add the extra event dependcy of the two operators. This function mainly aims for the cross-programs in pipeline parallelism, especial for the 'send_v2' 'recv_v2' etc. ''' - if not recorder_op_desc.dist_attr.force_record_event: - recorder_op_desc.dist_attr.force_record_event = True - # NOTE(lizhiyu): Here is the copy of 'waiter_op_desc.dist_attr.events_to_wait' not the reference, + if not recorder_op.dist_attr.force_record_event: + recorder_op.dist_attr.force_record_event = True + # NOTE(lizhiyu): Here is the copy of 'waiter_op.dist_attr.events_to_wait' not the reference, # because the type of 'events_to_wait' is 'const vector&' while the type of # 'waiter_wait_list' is python list. 
- waiter_wait_list = waiter_op_desc.dist_attr.events_to_wait - if recorder_op_desc.dist_attr.event_to_record not in waiter_wait_list: - waiter_wait_list.append(recorder_op_desc.dist_attr.event_to_record) - waiter_op_desc.dist_attr.events_to_wait = waiter_wait_list + waiter_wait_list = waiter_op.dist_attr.events_to_wait + if recorder_op.dist_attr.event_to_record not in waiter_wait_list: + waiter_wait_list.append(recorder_op.dist_attr.event_to_record) + waiter_op.dist_attr.events_to_wait = waiter_wait_list diff --git a/python/paddle/distributed/passes/pipeline_scheduler_pass.py b/python/paddle/distributed/passes/pipeline_scheduler_pass.py index 9b6e53b84a570..168ecfb00a706 100644 --- a/python/paddle/distributed/passes/pipeline_scheduler_pass.py +++ b/python/paddle/distributed/passes/pipeline_scheduler_pass.py @@ -22,6 +22,7 @@ from .pass_base import PassContext, new_pass, register_pass from .pass_utils import ( AutoParallelStreamType, + _add_event_dependency, _program_for_fthenb_and_1f1b, split_program, ) @@ -240,16 +241,42 @@ def _create_job_list(self): return job_list def _multistreaming_for_overlapping(self, programs): - # TODO(Ruibiao): Add cross-program event dependency for multi-stream. - for program in programs: + num_programs = len(programs) + for program_id, program in enumerate(programs): last_op = program.global_block().ops[-1] if self.is_comm_op_valid_to_overlap(last_op): last_op.dist_attr.execution_stream = ( AutoParallelStreamType.MP_STREAM.value ) + # Add cross-program event dependency + prior_op_input_arg_names = last_op.input_arg_names + prior_op_output_arg_names = last_op.output_arg_names + for i in range(program_id + 1, num_programs): + posterior_ops = programs[i].global_block().ops + num_posterior_ops = len(posterior_ops) + for op_id in range(num_posterior_ops): + posterior_op = posterior_ops[op_id] + posterior_op_input_arg_names = ( + posterior_op.input_arg_names + ) + posterior_op_output_arg_names = ( + posterior_op.output_arg_names + ) + if ( + set(prior_op_input_arg_names) + & set(posterior_op_output_arg_names) + or set(prior_op_output_arg_names) + & set(posterior_op_input_arg_names) + or set(prior_op_output_arg_names) + & set(posterior_op_output_arg_names) + ): + _add_event_dependency(last_op, posterior_op) def _op_cost(self, op): try: + # TODO(Ruibiao): c_identity is redundant in auto parallel, it is temporarily set as inplace-op and do nothing in kernel, remove it later. 
+ if op.type == "c_identity": + return 0 return calc_time_by_cost_model(op) except: logger.info(f"The cost of {op} is unknown.") diff --git a/test/standalone_executor/test_standalone_custom_event.py b/test/standalone_executor/test_standalone_custom_event.py index e65ed021e7972..246e2395e4b22 100644 --- a/test/standalone_executor/test_standalone_custom_event.py +++ b/test/standalone_executor/test_standalone_custom_event.py @@ -102,7 +102,7 @@ def split_program(self, prog, apply_mannual_event=False): if apply_mannual_event: for waiter, recorders in waiter_recorder_events_map.items(): for recorder in recorders: - _add_event_dependency(ops[recorder].desc, ops[waiter].desc) + _add_event_dependency(ops[recorder], ops[waiter]) main_progs, _, _ = split_program(prog, [11]) return main_progs From 8d97c8fca8e87b79882529d630be654495d03dcd Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Tue, 12 Sep 2023 06:25:52 +0000 Subject: [PATCH 06/11] Refine split program in _backward_forward_overlap --- .../passes/pipeline_scheduler_pass.py | 122 +++++++++--------- 1 file changed, 63 insertions(+), 59 deletions(-) diff --git a/python/paddle/distributed/passes/pipeline_scheduler_pass.py b/python/paddle/distributed/passes/pipeline_scheduler_pass.py index 168ecfb00a706..0ca8391c16531 100644 --- a/python/paddle/distributed/passes/pipeline_scheduler_pass.py +++ b/python/paddle/distributed/passes/pipeline_scheduler_pass.py @@ -101,80 +101,92 @@ def __init__(self): def _backward_forward_overlap(self, backward_program, forward_program): logger.info("Backward forward overlap enabled in 1F1B.") print(f"backward_program :: {backward_program}") - print(f"fowr_program :: {forward_program}") - # Split BACKWARD - valid_comm_op_ids = [ - op_id - for op_id, op in enumerate(backward_program.global_block().ops) - if self.is_comm_op_valid_to_overlap(op) - ] - # TODO(Ruibiao): Constrain the number of valid comm ops to resolve the potential memory explosion issue. - is_backward_split_point = ( - lambda program, op_id: op_id - 1 in valid_comm_op_ids + print(f"forward_program :: {forward_program}") + # Split program + backward_ops, forward_ops = ( + backward_program.global_block().ops, + forward_program.global_block().ops, ) - ( - splitted_backward_job_types, - splitted_backward_programs, - ) = self._split_program_for_overlapping( - BACKWARD, backward_program, is_backward_split_point - ) - self._multistreaming_for_overlapping(splitted_backward_programs) + num_backward_ops, num_forward_ops = len(backward_ops), len(forward_ops) + backward_split_points, forward_split_points = [], [] + backward_op_id, forward_op_id = 0, 0 - # Split FORWARD - ops = forward_program.global_block().ops - num_ops = len(ops) - splitted_op_ids = [] - op_id = 0 - for splitted_backward_program in splitted_backward_programs: - backward_op_to_overlap = ( - splitted_backward_program.global_block().ops[-1] - ) + while ( + backward_op_id < num_backward_ops + and forward_op_id < num_forward_ops + ): + # TODO(Ruibiao): Constrain the number of valid comm ops to resolve the potential memory explosion issue. 
+ while ( + backward_op_id < num_backward_ops + and not self.is_comm_op_valid_to_overlap( + backward_ops[backward_op_id] + ) + ): + backward_op_id += 1 + + if backward_op_id >= num_backward_ops: + break + + backward_op_to_overlap = backward_ops[backward_op_id] backward_cost_to_overlap = self._op_cost(backward_op_to_overlap) + backward_op_id += 1 - forward_cost_to_overlap = self._op_cost(ops[op_id]) + forward_op_to_overlap = forward_ops[forward_op_id] + forward_cost_to_overlap = self._op_cost(forward_op_to_overlap) print( f"backward_op_to_overlap : {backward_op_to_overlap}, cost = {backward_cost_to_overlap}" ) print( - f"forward_op_to_overlap : {ops[op_id]}, cost = {forward_cost_to_overlap}" + f"forward_op_to_overlap : {forward_op_to_overlap}, cost = {forward_cost_to_overlap}" ) while ( - op_id < num_ops - and forward_cost_to_overlap <= backward_cost_to_overlap + forward_op_id < num_forward_ops + and backward_cost_to_overlap >= forward_cost_to_overlap + and ( + forward_op_id <= 0 + or not self.is_comm_op_valid_to_overlap( + forward_ops[forward_op_id - 1] + ) + ) ): - op_id += 1 - op = ops[op_id] - # Force split when meet comm op since it cannot overlap with comm op in backward. - if op_id > 0 and self.is_comm_op_valid_to_overlap( - ops[op_id - 1] - ): - break - + forward_op_id += 1 + forward_op_to_overlap = forward_ops[forward_op_id] + forward_cost_to_overlap += self._op_cost(forward_op_to_overlap) print( - f"forward_op_to_overlap : {ops[op_id]}, cost = {self._op_cost(ops[op_id])}" + f"forward_op_to_overlap : {forward_op_to_overlap}, cost = {self._op_cost(forward_op_to_overlap)}" ) - forward_cost_to_overlap += self._op_cost(ops[op_id]) - splitted_op_ids.append(op_id) - if op_id >= num_ops: - break + if ( + not forward_split_points + or forward_op_id > forward_split_points[-1] + ): + backward_split_points.append(backward_op_id) + forward_split_points.append(forward_op_id) - is_forward_split_point = lambda program, op_id: op_id in splitted_op_ids + ( + splitted_backward_job_types, + splitted_backward_programs, + ) = self._split_program_for_overlapping( + BACKWARD, backward_program, backward_split_points + ) ( splitted_forward_job_types, splitted_forward_programs, ) = self._split_program_for_overlapping( - FORWARD, forward_program, is_forward_split_point + FORWARD, forward_program, forward_split_points ) + + self._multistreaming_for_overlapping(splitted_backward_programs) self._multistreaming_for_overlapping(splitted_forward_programs) # Rearrange splitted chunks for BACKWARD and FORWARD self.jobs_in_stable_phase.clear() - num_splitted_forward_jobs = len(splitted_forward_job_types) - num_splitted_backward_jobs = len(splitted_backward_job_types) + num_splitted_backward_jobs, num_splitted_forward_jobs = len( + splitted_backward_job_types + ), len(splitted_forward_job_types) for idx in range( - max(num_splitted_forward_jobs, num_splitted_backward_jobs) + max(num_splitted_backward_jobs, num_splitted_forward_jobs) ): if idx < num_splitted_backward_jobs: self.jobs_in_stable_phase.append( @@ -275,7 +287,7 @@ def _multistreaming_for_overlapping(self, programs): def _op_cost(self, op): try: # TODO(Ruibiao): c_identity is redundant in auto parallel, it is temporarily set as inplace-op and do nothing in kernel, remove it later. 
- if op.type == "c_identity": + if op.type == "c_identity" or op.type == "recv_v2": return 0 return calc_time_by_cost_model(op) except: @@ -312,21 +324,13 @@ def _partial_programs(self, program): return types, sub_programs - def _split_program_for_overlapping(self, job_type, program, is_split_point): + def _split_program_for_overlapping(self, job_type, program, split_points): assert job_type in [ FORWARD, BACKWARD, ], f"job_type should be one of {[FORWARD, BACKWARD]}" - ops = program.global_block().ops - num_ops = len(ops) - - split_ids = [] - for op_id in range(1, num_ops): - if is_split_point(program, op_id): - split_ids.append(op_id) - - splitted_programs, __, __ = split_program(program, split_ids) + splitted_programs, __, __ = split_program(program, split_points) splitted_job_types = [] num_splitted_programs = len(splitted_programs) From 91d5e42fec7dab6212dd6c82f01676dc54f7c382 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Mon, 18 Sep 2023 14:34:12 +0800 Subject: [PATCH 07/11] Add empirical op cost --- .../auto_parallel/static/cluster.py | 34 ++++++- .../auto_parallel/static/cost/base_cost.py | 36 ++++++-- .../auto_parallel/static/parallelizer_v2.py | 16 ++++ .../passes/pipeline_scheduler_pass.py | 92 ++++++++++++++----- 4 files changed, 144 insertions(+), 34 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/static/cluster.py b/python/paddle/distributed/auto_parallel/static/cluster.py index 47d9711367765..e5eb7b25002e6 100644 --- a/python/paddle/distributed/auto_parallel/static/cluster.py +++ b/python/paddle/distributed/auto_parallel/static/cluster.py @@ -61,6 +61,8 @@ def __init__(self, global_id, local_id, machine): self._dp_gflops = None # Single precision GFLOPS self._sp_gflops = None + # Half precision GFLOPS + self._hp_gflops = None # Memory is stored by GB self._memory = None @@ -120,6 +122,14 @@ def sp_gflops(self): def sp_gflops(self, value): self._sp_gflops = value + @property + def hp_gflops(self): + return self._hp_gflops + + @hp_gflops.setter + def hp_gflops(self, value): + self._hp_gflops = value + @property def memory(self): return self._memory @@ -130,7 +140,7 @@ def memory(self, value): def __str__(self): str = "" - str += "global_id: {}, local_id: {}, machine_id: {}, type: {}, model: {}, dp_flops: {}, sp_flops: {}, memory: {}".format( + str += "global_id: {}, local_id: {}, machine_id: {}, type: {}, model: {}, dp_flops: {}, sp_flops: {}, hp_flops: {}, memory: {}".format( self.global_id, self.local_id, self.machine.id, @@ -138,6 +148,7 @@ def __str__(self): self.model, self.dp_gflops, self.sp_gflops, + self.hp_gflops, self.memory, ) return str @@ -443,6 +454,7 @@ def gen_default_config_cluster( intra_bandwidth=235, gpu_dp_gflops=7800, gpu_sp_gflops=15700, + gpu_hp_gflops=31400, cpu_dp_gflops=75, cpu_sp_gflops=150, ): @@ -524,8 +536,6 @@ def _convert_to_cpu_info(cpu_model): local_id += 1 type = _convert_to_type(gpu_model) model = _convert_to_model(gpu_model, gpu_memory) - dp_gflops = gpu_dp_gflops - sp_gflops = gpu_dp_gflops memory = gpu_memory device["global_id"] = global_id @@ -533,8 +543,9 @@ def _convert_to_cpu_info(cpu_model): device["type"] = type device["model"] = model device["memory"] = memory - device["sp_gflops"] = sp_gflops - device["dp_gflops"] = dp_gflops + device["sp_gflops"] = gpu_sp_gflops + device["dp_gflops"] = gpu_dp_gflops + device["hp_gflops"] = gpu_hp_gflops # hard code device["type"] = "GPU" global_id_to_device_type[global_id] = type @@ -694,6 +705,7 @@ def _build_from_dict(self, cluster_info): device.model = 
device_info.get("model", None) device.dp_gflops = float(device_info.get("dp_gflops", 0)) device.sp_gflops = float(device_info.get("sp_gflops", 0)) + device.hp_gflops = float(device_info.get("hp_gflops", 0)) device.memory = float(device_info.get("memory", 0)) self.add_device(device) self.add_machine(machine) @@ -909,10 +921,22 @@ def is_by_json_config(json_config): os.getenv("PADDLE_CURRENT_ENDPOINT", None), ) ) + + gflops_info = { + "V100": {"dp": 7800, "sp": 15700, "hp": 125000}, + "A100": {"dp": 9700, "sp": 19500, "hp": 624000}, + } + default_gflops = ( + gflops_info["A100"] if gpu_model == "A100" else gflops_info["V100"] + ) + cluster.gen_default_config_cluster( node_count=node_count, device_count=local_device_count, gpu_model=gpu_model, gpu_memory=memory, + gpu_dp_gflops=default_gflops["dp"], + gpu_sp_gflops=default_gflops["sp"], + gpu_hp_gflops=default_gflops["hp"], ) return cluster diff --git a/python/paddle/distributed/auto_parallel/static/cost/base_cost.py b/python/paddle/distributed/auto_parallel/static/cost/base_cost.py index 58ab301ad99f8..3c9d9b58beb83 100644 --- a/python/paddle/distributed/auto_parallel/static/cost/base_cost.py +++ b/python/paddle/distributed/auto_parallel/static/cost/base_cost.py @@ -17,9 +17,10 @@ import numpy as np import paddle +from paddle.base.core import VarDesc from paddle.utils.flops import flops -from ..cluster import LinkType, get_default_cluster +from ..cluster import DeviceType, LinkType, get_default_cluster from ..dist_tensor import DistributedTensor from ..process_group import get_process_group from ..utils import _get_comm_group, _get_idx_in_axis @@ -936,7 +937,13 @@ def calc_time_by_cost_model(op, cluster=None): ) if not cluster: cluster = get_default_cluster() - time = 0.0 + + assert cluster._gpu_model in [ + "V100", + "A100", + ], "Only A100 and V100 gpu has been supported currently." + + time = 0.0 # microsecond op_type = op.type # calc comp op time by flops if op_type not in NON_COMP_TYPE: @@ -958,15 +965,30 @@ def calc_time_by_cost_model(op, cluster=None): else: flops_count = flops(op_type, inputs, attrs) - if cluster._gpu_model == "V100": - time = flops_count * 2.9e-7 * 2.6 - elif cluster._gpu_model == "A100": - time = flops_count * 2.9e-7 + # FIXME(Ruibiao): Need a better way to get dtype + var_name = op.output_arg_names[0] + dtype = op.block._var_recursive(var_name).dtype + device = cluster.get_device(0) + assert ( + device.type == DeviceType.GPU + ), "Only GPU device is supported currently." + + gflops = 0.0 + if dtype == VarDesc.VarType.FP64: + gflops = device.dp_gflops + elif dtype == VarDesc.VarType.FP32: + gflops = device.sp_gflops + elif dtype == VarDesc.VarType.FP16 or dtype == VarDesc.VarType.BF16: + gflops = device.hp_gflops else: raise ValueError( - "Only A100 and V100 gpu has been supported currently." + f"Unsupported modeling compute time for dtype: {dtype}." 
) + print(f"flops_count = {flops_count}, gflops = {gflops}") + utilization_rate = 0.98 + time = flops_count / (utilization_rate * gflops) * 1e-3 + # calc comm op time by communication modeling formula elif op_type in COMM_OP_TYPE: op_cost = _g_op_cost_factory[op_type]( diff --git a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py index 37b1d9524b65e..7ffb42de093a4 100644 --- a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py @@ -438,3 +438,19 @@ def _apply_post_optimization( "pp_degree": len(self._dist_context.process_meshes), "pp_stage": get_pp_stage(self._dist_context, rank), } + + from paddle.distributed.auto_parallel.static.cost import ( + calc_time_by_cost_model, + ) + + for op in main_program.global_block().ops: + cost = 0.0 + try: + # TODO(Ruibiao): c_identity is redundant in auto parallel, it is temporarily set as inplace-op and do nothing in kernel, remove it later. + if op.type == "c_identity" or op.type == "recv_v2": + cost = 0.0 + cost = calc_time_by_cost_model(op) + except Exception as e: + print(f"The cost of {op} is unknown since {repr(e)}.") + + print(f"op : {op}, cost = {cost}") diff --git a/python/paddle/distributed/passes/pipeline_scheduler_pass.py b/python/paddle/distributed/passes/pipeline_scheduler_pass.py index 0ca8391c16531..be2e6ccbd3984 100644 --- a/python/paddle/distributed/passes/pipeline_scheduler_pass.py +++ b/python/paddle/distributed/passes/pipeline_scheduler_pass.py @@ -100,8 +100,6 @@ def __init__(self): # def _backward_forward_overlap(self, backward_program, forward_program): logger.info("Backward forward overlap enabled in 1F1B.") - print(f"backward_program :: {backward_program}") - print(f"forward_program :: {forward_program}") # Split program backward_ops, forward_ops = ( backward_program.global_block().ops, @@ -128,34 +126,39 @@ def _backward_forward_overlap(self, backward_program, forward_program): break backward_op_to_overlap = backward_ops[backward_op_id] - backward_cost_to_overlap = self._op_cost(backward_op_to_overlap) + backward_cost_to_overlap = 400 backward_op_id += 1 forward_op_to_overlap = forward_ops[forward_op_id] forward_cost_to_overlap = self._op_cost(forward_op_to_overlap) - print( + ''' + # Debug messages: + logger.info( f"backward_op_to_overlap : {backward_op_to_overlap}, cost = {backward_cost_to_overlap}" ) - print( + logger.info( f"forward_op_to_overlap : {forward_op_to_overlap}, cost = {forward_cost_to_overlap}" ) + ''' while ( forward_op_id < num_forward_ops and backward_cost_to_overlap >= forward_cost_to_overlap - and ( - forward_op_id <= 0 - or not self.is_comm_op_valid_to_overlap( - forward_ops[forward_op_id - 1] - ) - ) ): forward_op_id += 1 forward_op_to_overlap = forward_ops[forward_op_id] forward_cost_to_overlap += self._op_cost(forward_op_to_overlap) - print( + ''' + # Debug messages: + logger.info( f"forward_op_to_overlap : {forward_op_to_overlap}, cost = {self._op_cost(forward_op_to_overlap)}" ) + ''' + + if self.is_comm_op_valid_to_overlap( + forward_ops[forward_op_id - 1] + ): + break if ( not forward_split_points @@ -177,8 +180,10 @@ def _backward_forward_overlap(self, backward_program, forward_program): FORWARD, forward_program, forward_split_points ) - self._multistreaming_for_overlapping(splitted_backward_programs) - self._multistreaming_for_overlapping(splitted_forward_programs) + self._multistreaming_for_overlapping( + 
splitted_backward_programs, BACKWARD + ) + self._multistreaming_for_overlapping(splitted_forward_programs, FORWARD) # Rearrange splitted chunks for BACKWARD and FORWARD self.jobs_in_stable_phase.clear() @@ -252,14 +257,17 @@ def _create_job_list(self): job_list.append(opt_job) return job_list - def _multistreaming_for_overlapping(self, programs): + def _multistreaming_for_overlapping(self, programs, job_type): num_programs = len(programs) + higher_stream_priority = -1 for program_id, program in enumerate(programs): last_op = program.global_block().ops[-1] if self.is_comm_op_valid_to_overlap(last_op): + # TODO(Ruibiao): Assign different stream to FORWAD and BACKWARD CommOps, and set a lower priority for FORWARD Comm stream. It can reduce the impact of FORWARD Comm on BACKWARD Comp. Now the defalut stream prirotiy in standalone executor is already the lowest priority (correspongding to 0 in V100), cannot set a lower one. Maybe we need to support setting default stream for executor. last_op.dist_attr.execution_stream = ( AutoParallelStreamType.MP_STREAM.value ) + last_op.dist_attr.stream_priority = higher_stream_priority # Add cross-program event dependency prior_op_input_arg_names = last_op.input_arg_names prior_op_output_arg_names = last_op.output_arg_names @@ -284,14 +292,52 @@ def _multistreaming_for_overlapping(self, programs): ): _add_event_dependency(last_op, posterior_op) + # TODO(Ruibiao): The cost here is just the experience value for a specific task (GPT-3-6.7B-MP2-PP4). A more genereal cost estimation scheme is required. def _op_cost(self, op): + handwritten_cost_map = { + "c_allreduce_sum": 0, + "elementwise_add": 40, + "split": 76, + "transpose2": 40, + "fused_softmax_mask_upper_triangle": 94, + "layer_norm": 55, + "gelu": 180, + "dropout": 160, + "c_identity": 0, + "recv_v2": 0, + } + + op_type = op.type + if op_type in handwritten_cost_map.keys(): + return handwritten_cost_map[op_type] + + if op_type == "matmul_v2": + var_name = op.output_arg_names[0] + shape = op.block._var_recursive(var_name).shape + if shape == (1, 1024, 6144): + return 399 + elif shape == (1, 16, 1024, 1024): + return 112 + elif shape == (1, 16, 1024, 128): + return 95 + elif shape == (1, 1024, 4096): + return 244 + + if op_type == "scale": + var_name = op.output_arg_names[0] + shape = op.block._var_recursive(var_name).shape + if shape == (1, 16, 1024, 128): + return 20 + if shape == (1, 16, 1024, 1024): + return 90 + try: - # TODO(Ruibiao): c_identity is redundant in auto parallel, it is temporarily set as inplace-op and do nothing in kernel, remove it later. 
- if op.type == "c_identity" or op.type == "recv_v2": - return 0 - return calc_time_by_cost_model(op) - except: - logger.info(f"The cost of {op} is unknown.") + time = calc_time_by_cost_model(op) + if op.type == "c_allreduce_sum": + time *= 8 + return time + except Exception as e: + logger.info(f"The cost of {op} is unknown since {repr(e)}.") return 0.0 def _partial_programs(self, program): @@ -319,7 +365,9 @@ def _partial_programs(self, program): ) for i in range(len(types)): - print(f"type = {types[i]}, sub_programs = {sub_programs[i]}\n") + logger.info( + f"type = {types[i]}, sub_programs = {sub_programs[i]}\n" + ) logger.info(f"jobs_in_stable_phase = {self.jobs_in_stable_phase}") return types, sub_programs From 5e369a1aa1bb245c83c8287efea4f1c612bdaf8b Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Mon, 18 Sep 2023 14:39:53 +0800 Subject: [PATCH 08/11] Add NOTE --- python/paddle/distributed/passes/pipeline_scheduler_pass.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/distributed/passes/pipeline_scheduler_pass.py b/python/paddle/distributed/passes/pipeline_scheduler_pass.py index be2e6ccbd3984..e4a803d48f089 100644 --- a/python/paddle/distributed/passes/pipeline_scheduler_pass.py +++ b/python/paddle/distributed/passes/pipeline_scheduler_pass.py @@ -404,6 +404,7 @@ def apply_pass(main_program, startup_program, pass_name, pass_attr={}): ) if pass_name == "1F1B": + # TODO(Ruibiao): Move FLAGS_1f1b_backward_forward_overlap and FLAGS_mp_async_allreduce_in_backward to auto parallel Strategy after these two optimizations are available. pass_attr["enable_backward_forward_overlap"] = int( os.environ.get("FLAGS_1f1b_backward_forward_overlap", 0) ) From 50e60c91437936771cd4b48791fdd266ca8ab30d Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Mon, 18 Sep 2023 14:54:27 +0800 Subject: [PATCH 09/11] Remove some redundant codes --- .../new_executor/interpreter/stream_analyzer.cc | 3 +-- .../fluid/operators/collective/c_identity_op.cc | 3 --- .../auto_parallel/static/cost/base_cost.py | 1 - .../distributed/auto_parallel/static/engine.py | 13 ------------- .../auto_parallel/static/parallelizer_v2.py | 16 ---------------- python/paddle/distributed/passes/pass_utils.py | 2 +- .../passes/pipeline_scheduler_pass.py | 14 +++++++++++--- 7 files changed, 13 insertions(+), 39 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc index 1b1b7ee9eab51..3dc9175dbfd4b 100644 --- a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc +++ b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc @@ -608,8 +608,7 @@ void StreamAnalyzer::ShrinkEventInfo( platform::DeviceType StreamAnalyzer::GetWaiterType( const Instruction& instr) const { - if (instr.KernelType() == OpFuncType::kCpuSync || - instr.KernelType() == OpFuncType::kGpuSync) { + if (instr.KernelType() == OpFuncType::kCpuSync) { return platform::kCPU; } else { if (platform::is_xpu_place(place_)) { diff --git a/paddle/fluid/operators/collective/c_identity_op.cc b/paddle/fluid/operators/collective/c_identity_op.cc index 87cbea9f21548..daeb52e4f3a0f 100644 --- a/paddle/fluid/operators/collective/c_identity_op.cc +++ b/paddle/fluid/operators/collective/c_identity_op.cc @@ -79,8 +79,6 @@ class CIdentityOpGradMaker : public framework::SingleGradOpMaker { } }; -DECLARE_INPLACE_OP_INFERER(IdentityInplaceInferer, {"X", "Out"}); - } // namespace operators } // namespace paddle @@ -95,5 +93,4 @@ 
REGISTER_OPERATOR(c_identity, ops::CIdentityOpGradMaker, ops::CIdentityOpGradMaker, ops::CIdentityOpMaker, - ops::IdentityInplaceInferer, CIdentityShapeFunctor); diff --git a/python/paddle/distributed/auto_parallel/static/cost/base_cost.py b/python/paddle/distributed/auto_parallel/static/cost/base_cost.py index 3c9d9b58beb83..f89a03647cfcc 100644 --- a/python/paddle/distributed/auto_parallel/static/cost/base_cost.py +++ b/python/paddle/distributed/auto_parallel/static/cost/base_cost.py @@ -985,7 +985,6 @@ def calc_time_by_cost_model(op, cluster=None): f"Unsupported modeling compute time for dtype: {dtype}." ) - print(f"flops_count = {flops_count}, gflops = {gflops}") utilization_rate = 0.98 time = flops_count / (utilization_rate * gflops) * 1e-3 diff --git a/python/paddle/distributed/auto_parallel/static/engine.py b/python/paddle/distributed/auto_parallel/static/engine.py index 98152d60c9511..0354c6517a1a2 100644 --- a/python/paddle/distributed/auto_parallel/static/engine.py +++ b/python/paddle/distributed/auto_parallel/static/engine.py @@ -1032,19 +1032,6 @@ def fit( return_numpy=self._strategy.return_numpy, ) - print( - f"memory_allocated = {paddle.device.cuda.memory_allocated()/1024/1024}" - ) - print( - f"max_memory_allocated = {paddle.device.cuda.max_memory_allocated()/1024/1024}" - ) - print( - f"memory_reserved = {paddle.device.cuda.memory_reserved()/1024/1024}" - ) - print( - f"max_memory_reserved = {paddle.device.cuda.max_memory_reserved()/1024/1024}" - ) - lr = auto_utils.get_lr(self.optimizer) logs = self._prepare_logger( outs, diff --git a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py index 7ffb42de093a4..37b1d9524b65e 100644 --- a/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/static/parallelizer_v2.py @@ -438,19 +438,3 @@ def _apply_post_optimization( "pp_degree": len(self._dist_context.process_meshes), "pp_stage": get_pp_stage(self._dist_context, rank), } - - from paddle.distributed.auto_parallel.static.cost import ( - calc_time_by_cost_model, - ) - - for op in main_program.global_block().ops: - cost = 0.0 - try: - # TODO(Ruibiao): c_identity is redundant in auto parallel, it is temporarily set as inplace-op and do nothing in kernel, remove it later. - if op.type == "c_identity" or op.type == "recv_v2": - cost = 0.0 - cost = calc_time_by_cost_model(op) - except Exception as e: - print(f"The cost of {op} is unknown since {repr(e)}.") - - print(f"op : {op}, cost = {cost}") diff --git a/python/paddle/distributed/passes/pass_utils.py b/python/paddle/distributed/passes/pass_utils.py index 311621e5e774c..5f3d1876401dc 100644 --- a/python/paddle/distributed/passes/pass_utils.py +++ b/python/paddle/distributed/passes/pass_utils.py @@ -36,7 +36,7 @@ ] -# NOTE Here stream is just a presentation with different name, +# NOTE: Here stream is just a presentation with different name, # it is up to executor to create the exact streams given the name. 
 class AutoParallelStreamType(Enum):
     CALC_STREAM = "default"
 
diff --git a/python/paddle/distributed/passes/pipeline_scheduler_pass.py b/python/paddle/distributed/passes/pipeline_scheduler_pass.py
index e4a803d48f089..a473e7b095eaf 100644
--- a/python/paddle/distributed/passes/pipeline_scheduler_pass.py
+++ b/python/paddle/distributed/passes/pipeline_scheduler_pass.py
@@ -263,7 +263,12 @@ def _multistreaming_for_overlapping(self, programs, job_type):
         for program_id, program in enumerate(programs):
             last_op = program.global_block().ops[-1]
             if self.is_comm_op_valid_to_overlap(last_op):
-                # TODO(Ruibiao): Assign different stream to FORWAD and BACKWARD CommOps, and set a lower priority for FORWARD Comm stream. It can reduce the impact of FORWARD Comm on BACKWARD Comp. Now the defalut stream prirotiy in standalone executor is already the lowest priority (correspongding to 0 in V100), cannot set a lower one. Maybe we need to support setting default stream for executor.
+                # TODO(Ruibiao): Assign different streams to FORWARD and BACKWARD CommOps,
+                # and set a lower priority for the FORWARD Comm stream. It can reduce the
+                # impact of FORWARD Comm on BACKWARD Comp. Now the default stream priority
+                # in the standalone executor is already the lowest priority (corresponding
+                # to 0 on V100), so a lower one cannot be set. Maybe we need to support
+                # setting a default stream for the executor.
                 last_op.dist_attr.execution_stream = (
                     AutoParallelStreamType.MP_STREAM.value
                 )
@@ -292,7 +297,8 @@ def _multistreaming_for_overlapping(self, programs, job_type):
             ):
                 _add_event_dependency(last_op, posterior_op)
 
-    # TODO(Ruibiao): The cost here is just the experience value for a specific task (GPT-3-6.7B-MP2-PP4). A more genereal cost estimation scheme is required.
+    # TODO(Ruibiao): The cost here is just an empirical value for a specific task (GPT-3-6.7B-MP2-PP4).
+    # A more general cost estimation scheme is required.
     def _op_cost(self, op):
         handwritten_cost_map = {
             "c_allreduce_sum": 0,
@@ -404,7 +410,9 @@ def apply_pass(main_program, startup_program, pass_name, pass_attr={}):
         )
 
     if pass_name == "1F1B":
-        # TODO(Ruibiao): Move FLAGS_1f1b_backward_forward_overlap and FLAGS_mp_async_allreduce_in_backward to auto parallel Strategy after these two optimizations are available.
+        # TODO(Ruibiao): Move FLAGS_1f1b_backward_forward_overlap and
+        # FLAGS_mp_async_allreduce_in_backward to auto parallel Strategy
+        # after these two optimizations are available.
pass_attr["enable_backward_forward_overlap"] = int( os.environ.get("FLAGS_1f1b_backward_forward_overlap", 0) ) From 48295428604efbd3eec3ea904ae5c4a6de01d9db Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Mon, 18 Sep 2023 14:57:22 +0800 Subject: [PATCH 10/11] Remove some redundant codes --- paddle/fluid/operators/collective/c_identity_op.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/fluid/operators/collective/c_identity_op.cc b/paddle/fluid/operators/collective/c_identity_op.cc index daeb52e4f3a0f..c067c061b8613 100644 --- a/paddle/fluid/operators/collective/c_identity_op.cc +++ b/paddle/fluid/operators/collective/c_identity_op.cc @@ -78,7 +78,6 @@ class CIdentityOpGradMaker : public framework::SingleGradOpMaker { retv->SetAttrMap(this->Attrs()); } }; - } // namespace operators } // namespace paddle From 0d36e204cef9bb82bd04f117bac1ba52d5200007 Mon Sep 17 00:00:00 2001 From: chenruibiao Date: Mon, 18 Sep 2023 16:33:53 +0800 Subject: [PATCH 11/11] Fix UTs --- test/standalone_executor/test_standalone_custom_event.py | 9 +++++---- .../test_standalone_executor_multi_micro_batch.py | 5 +++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/test/standalone_executor/test_standalone_custom_event.py b/test/standalone_executor/test_standalone_custom_event.py index 246e2395e4b22..b87609841e6e4 100644 --- a/test/standalone_executor/test_standalone_custom_event.py +++ b/test/standalone_executor/test_standalone_custom_event.py @@ -19,7 +19,7 @@ from paddle.base.executor import _add_feed_fetch_ops, _StandaloneExecutor from paddle.distributed.passes.pass_utils import ( _add_event_dependency, - get_skip_gc_vars, + set_skip_gc_vars, split_program, ) @@ -112,7 +112,6 @@ def create_standalone_exe(self, main_progs, startup_progs, fetch_list): job_list = [] prog_num = len(main_progs) fetch_op_num = len(fetch_list) - skip_gc_vars = get_skip_gc_vars(main_progs) if prog_num == 1: # single prog main_progs[0] = _add_feed_fetch_ops( @@ -140,7 +139,6 @@ def create_standalone_exe(self, main_progs, startup_progs, fetch_list): # create jobs for program_id in range(prog_num): job = core.Job(f"prog_{program_id}") - job.set_skip_gc_vars(skip_gc_vars[program_id]) # Set col_attr info for fetch_op to fetch the correct data after running multiple micro batch if program_id == prog_num - 1: for i in range(fetch_op_num): @@ -152,8 +150,11 @@ def create_standalone_exe(self, main_progs, startup_progs, fetch_list): type_to_program = {} for program_id in range(prog_num): - type_to_program[f"prog_{program_id}"] = main_progs[program_id].desc + type_to_program[f"prog_{program_id}"] = main_progs[program_id] + set_skip_gc_vars(micro_batch_num, type_to_program, job_list) + for type in type_to_program.keys(): + type_to_program[type] = type_to_program[type].desc plan = core.Plan(job_list, type_to_program) scope = core.Scope() main_exe = _StandaloneExecutor(self.place, plan, scope) diff --git a/test/standalone_executor/test_standalone_executor_multi_micro_batch.py b/test/standalone_executor/test_standalone_executor_multi_micro_batch.py index c9f2ace1b0b09..b829a69fa7f1b 100644 --- a/test/standalone_executor/test_standalone_executor_multi_micro_batch.py +++ b/test/standalone_executor/test_standalone_executor_multi_micro_batch.py @@ -197,10 +197,11 @@ def run_train(self, split=False, micro_batch_num=1): type_to_program = {} for program_id in range(program_num): - type_to_program[f"P{program_id}"] = programs[program_id].desc - + type_to_program[f"P{program_id}"] = programs[program_id] set_skip_gc_vars(micro_batch_num, 
type_to_program, job_list) + for type in type_to_program.keys(): + type_to_program[type] = type_to_program[type].desc plan = Plan(job_list, type_to_program) main_exe = _StandaloneExecutor(self.place, plan, scope)
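Both updated tests follow the same wiring pattern: keep the per-type programs as Program objects while the job list is built, let set_skip_gc_vars fill in each job's skip-gc set, and only then lower the programs to their desc form for the Plan. The sketch below restates that calling pattern outside the tests; the helper name, the micro-batch job ordering, and the job.set_micro_batch_id call are illustrative assumptions rather than code taken from these patches.

# Hypothetical helper mirroring the test wiring above (not part of the patches).
from paddle.base import core
from paddle.base.executor import _StandaloneExecutor
from paddle.distributed.passes.pass_utils import set_skip_gc_vars


def build_executor(place, type_to_program, micro_batch_num):
    # One job per (micro_batch, type); the list order defines the execution order.
    job_list = []
    for micro_batch_id in range(micro_batch_num):
        for job_type in type_to_program:
            job = core.Job(job_type)
            job.set_micro_batch_id(micro_batch_id)  # assumed Job API
            job_list.append(job)

    # Decide which vars each job must keep alive for later jobs,
    # while the programs are still Program objects.
    set_skip_gc_vars(micro_batch_num, type_to_program, job_list)

    # Lower to program descs only after skip_gc_vars has been set.
    type_to_program_desc = {
        job_type: prog.desc for job_type, prog in type_to_program.items()
    }
    plan = core.Plan(job_list, type_to_program_desc)
    return _StandaloneExecutor(place, plan, core.Scope())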