diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py
index f39cb1ab34675..d64bf66c51192 100644
--- a/python/paddle/distributed/auto_parallel/api.py
+++ b/python/paddle/distributed/auto_parallel/api.py
@@ -31,6 +31,7 @@
     EagerParamBase,
     Variable,
     default_main_program,
+    in_dygraph_mode,
     in_pir_mode,
     use_pir_api,
 )
@@ -1000,7 +1001,7 @@ def get_placement_with_sharding(param, sharding_mesh_axis):
 
 
 class _ShardOptimizer(Optimizer):
-    def __init__(self, optimizer, shard_fn=None):
+    def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
         assert (
             optimizer is not None
         ), "The argument `optimizer` cannot be empty."
@@ -1025,6 +1026,7 @@ def __init__(self, optimizer, shard_fn=None):
         self._shard_fn = shard_fn
         self._sharding_mesh_axis = None
         self._sharding_degree = None
+        self.gradient_accumulation_steps = gradient_accumulation_steps
         if isinstance(
             self._shard_fn, (ShardingStage1, ShardingStage2, ShardingStage3)
         ):
@@ -1246,6 +1248,21 @@ def state_dict(self):
         return self._inner_opt.state_dict()
 
     def _append_optimize_op(self, block, param_and_grad):
+        if (
+            in_auto_parallel_align_mode()  # In align mode, we use enable_delay_scale_loss by default
+            and in_dygraph_mode()
+            and param_and_grad[1].is_dist()
+        ):
+            placements = param_and_grad[1].placements
+            meshs = param_and_grad[1].process_mesh
+            grad = param_and_grad[1]
+
+            for i in range(len(placements) - 1, -1, -1):
+                if isinstance(placements[i], dist.Partial):
+                    placements[i] = dist.Replicate()
+            grad = dist.reshard(grad, meshs, placements)
+            grad /= self.gradient_accumulation_steps
+            param_and_grad = (param_and_grad[0], grad)
         return self._inner_opt._append_optimize_op(block, param_and_grad)
 
     def __getattr__(self, item):
@@ -1596,6 +1613,7 @@ def __call__(self, key: str, param: Tensor, accumulator: Tensor) -> Tensor:
 def shard_optimizer(
     optimizer: Optimizer,
     shard_fn: Callable[[str, Tensor, Tensor], Tensor] | None = None,
+    gradient_accumulation_steps: int = 1,
 ) -> _ShardOptimizer:
     """
 
@@ -1640,7 +1658,7 @@ def shard_fn(accumulator_name, param, accumulator) -> sharded_accumulator
             >>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py
 
     """
-    return _ShardOptimizer(optimizer, shard_fn)
+    return _ShardOptimizer(optimizer, shard_fn, gradient_accumulation_steps)
 
 
 def shard_scaler(scaler: GradScaler) -> GradScaler:
diff --git a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py
index f03a05789ae30..524832bcd1895 100644
--- a/python/paddle/distributed/passes/auto_parallel_gradient_merge.py
+++ b/python/paddle/distributed/passes/auto_parallel_gradient_merge.py
@@ -636,6 +636,94 @@ def parse_program(
     return grad_to_gradient_merge
 
 
+def _find_trival_optimizer_ops(block):
+    optimizer_ops = []
+    for op in block.ops:
+        if "adam" in op.name() or "sgd" in op.name():
+            optimizer_ops.append(op)
+    return optimizer_ops
+
+
+def _get_prev_op(block, optimizer_op):
+    found = False
+    for op in reversed(block.ops):
+        if found:
+            return op
+        if op.id == optimizer_op.id:
+            found = True
+    return None
+
+
+def _insert_scale_op_after(target_value, optimizer_op, scale, bias=0.0):
+    scaled_grad = paddle._C_ops.scale_(target_value, scale, bias, False)
+
+    scale_op = scaled_grad.get_defining_op()
+    scale_op.op_role = int(OpRole.Optimize)
+
+    full_op = scale_op.operand_source(1).get_defining_op()
+    assert (
+        full_op.name() == "pd_op.full"
+    ), f"The defining op of the scale value should be `pd_op.full`, but got {full_op.name()}"
+    full_op.op_role = int(OpRole.Optimize)
+
+    if "adam" in optimizer_op.name():
+        optimizer_op.operand(1).set_source(scaled_grad)
+    elif "sgd" in optimizer_op.name():
+        optimizer_op.operand(2).set_source(scaled_grad)
+
+
+def _append_scale_op_before_comm(block, new_params_to_grads, k_steps):
+    for op in reversed(block.ops):
+        if op.op_role == int(OpRole.Backward):
+            paddle.pir.set_insertion_point_after(op)
+            break
+    for _, new_grad in new_params_to_grads:
+        new_grad = paddle._C_ops.scale_(new_grad, 1.0 / k_steps, 0.0, False)
+
+        scale_op = new_grad.get_defining_op()
+        scale_op.op_role = int(OpRole.Optimize)
+
+        full_op = scale_op.operand_source(1).get_defining_op()
+        assert (
+            full_op.name() == "pd_op.full"
+        ), f"The defining op of the scale value should be `pd_op.full`, but got {full_op.name()}"
+        full_op.op_role = int(OpRole.Optimize)
+    paddle.pir.set_insertion_point_to_block_end(block)
+
+
+def _append_scale_op_after_comm(block, optimizer_ops, k_steps):
+    for optimizer_op in optimizer_ops:
+        target_value = None
+        if "adam" in optimizer_op.name():  # adam and adamw are included
+            target_value = optimizer_op.operand_source(1)
+        elif "sgd" in optimizer_op.name():
+            target_value = optimizer_op.operand_source(2)
+        else:
+            raise NotImplementedError(
+                f"Only adamw, adam and sgd are supported, but got {optimizer_op.name()}"
+            )
+        assert (
+            target_value is not None
+        ), "target_value is not expected to be None"
+        insertion_point = target_value.get_defining_op()
+        if insertion_point is None:
+            # target_value is a gradient_merge_var, which has no defining op,
+            # so we find the op preceding optimizer_op and insert the scale op after it.
+            insertion_point = _get_prev_op(block, optimizer_op)
+        paddle.pir.set_insertion_point_after(insertion_point)
+        _insert_scale_op_after(target_value, optimizer_op, 1.0 / k_steps)
+    paddle.pir.set_insertion_point_to_block_end(block)
+
+
+def _pir_append_scale_op(program, new_params_to_grads, k_steps):
+    block = program.global_block()
+    optimizer_ops = _find_trival_optimizer_ops(block)
+    if len(optimizer_ops) > 0:
+        _append_scale_op_after_comm(block, optimizer_ops, k_steps)
+    else:
+        _append_scale_op_before_comm(block, new_params_to_grads, k_steps)
+
+
 def _pir_parse_program(
     main_program,
     startup_program,
@@ -657,22 +745,7 @@ def _pir_parse_program(
 
     # step3: append scale op
     if avg:
-        main_block = main_program.global_block()
-        for op in reversed(main_block.ops):
-            if op.op_role == int(OpRole.Backward):
-                paddle.pir.set_insertion_point_after(op)
-                break
-        for _, new_grad in new_params_to_grads:
-            new_grad = paddle._C_ops.scale_(new_grad, 1.0 / k_steps, 0.0, False)
-
-            scale_op = new_grad.get_defining_op()
-            scale_op.op_role = int(OpRole.Optimize)
-
-            full_op = scale_op.operand_source(1).get_defining_op()
-            assert (
-                full_op.name() == "pd_op.full"
-            ), f"The defining op of the scale value should be `pd_op.full`, but got {full_op.name()}"
-            full_op.op_role = int(OpRole.Optimize)
+        _pir_append_scale_op(main_program, new_params_to_grads, k_steps)
 
 
 @register_pass("auto_parallel_gradient_merge_pass")
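
For reference, a minimal usage sketch (not part of the patch) of the extended `shard_optimizer` signature. The mesh shape, layer sizes, and script name are illustrative assumptions; note that the gradient rescaling added in `_ShardOptimizer._append_optimize_op` only takes effect when `in_auto_parallel_align_mode()` is active and the program runs in dygraph mode.

```python
# Sketch only: exercises the new `gradient_accumulation_steps` argument of
# paddle.distributed.shard_optimizer(). Mesh, sizes, and data are hypothetical;
# launch with e.g.  python -m paddle.distributed.launch --gpus=0,1 demo.py
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])

layer = paddle.nn.Linear(16, 16)
# Replicate the layer's parameters across the mesh (default shard_fn).
layer = dist.shard_layer(layer, mesh)

opt = paddle.optimizer.AdamW(learning_rate=1e-3, parameters=layer.parameters())
accumulate_steps = 4
# Pass the accumulation step count so the sharded optimizer can rescale
# gradients in the align-mode branch of _ShardOptimizer._append_optimize_op.
opt = dist.shard_optimizer(opt, gradient_accumulation_steps=accumulate_steps)

for _ in range(accumulate_steps):
    # Data-parallel split of the batch dimension across the mesh.
    x = dist.shard_tensor(paddle.randn([8, 16]), mesh, [dist.Shard(0)])
    loss = layer(x).mean()
    loss.backward()

opt.step()
opt.clear_grad()
```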