MP overlap for 1f1b #57446

Merged: 14 commits, Sep 19, 2023
34 changes: 29 additions & 5 deletions python/paddle/distributed/auto_parallel/static/cluster.py
@@ -61,6 +61,8 @@ def __init__(self, global_id, local_id, machine):
self._dp_gflops = None
# Single precision GFLOPS
self._sp_gflops = None
# Half precision GFLOPS
self._hp_gflops = None
# Memory is stored by GB
self._memory = None

@@ -120,6 +122,14 @@ def sp_gflops(self):
def sp_gflops(self, value):
self._sp_gflops = value

@property
def hp_gflops(self):
return self._hp_gflops

@hp_gflops.setter
def hp_gflops(self, value):
self._hp_gflops = value

@property
def memory(self):
return self._memory
@@ -130,14 +140,15 @@ def memory(self, value):

def __str__(self):
str = ""
str += "global_id: {}, local_id: {}, machine_id: {}, type: {}, model: {}, dp_flops: {}, sp_flops: {}, memory: {}".format(
str += "global_id: {}, local_id: {}, machine_id: {}, type: {}, model: {}, dp_flops: {}, sp_flops: {}, hp_flops: {}, memory: {}".format(
self.global_id,
self.local_id,
self.machine.id,
self.type.name,
self.model,
self.dp_gflops,
self.sp_gflops,
self.hp_gflops,
self.memory,
)
return str
@@ -443,6 +454,7 @@ def gen_default_config_cluster(
intra_bandwidth=235,
gpu_dp_gflops=7800,
gpu_sp_gflops=15700,
gpu_hp_gflops=31400,
cpu_dp_gflops=75,
cpu_sp_gflops=150,
):
@@ -524,17 +536,16 @@ def _convert_to_cpu_info(cpu_model):
local_id += 1
type = _convert_to_type(gpu_model)
model = _convert_to_model(gpu_model, gpu_memory)
dp_gflops = gpu_dp_gflops
sp_gflops = gpu_dp_gflops
memory = gpu_memory

device["global_id"] = global_id
device["local_id"] = local_id
device["type"] = type
device["model"] = model
device["memory"] = memory
device["sp_gflops"] = sp_gflops
device["dp_gflops"] = dp_gflops
device["sp_gflops"] = gpu_sp_gflops
device["dp_gflops"] = gpu_dp_gflops
device["hp_gflops"] = gpu_hp_gflops
# hard code
device["type"] = "GPU"
global_id_to_device_type[global_id] = type
@@ -694,6 +705,7 @@ def _build_from_dict(self, cluster_info):
device.model = device_info.get("model", None)
device.dp_gflops = float(device_info.get("dp_gflops", 0))
device.sp_gflops = float(device_info.get("sp_gflops", 0))
device.hp_gflops = float(device_info.get("hp_gflops", 0))
device.memory = float(device_info.get("memory", 0))
self.add_device(device)
self.add_machine(machine)
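
For reference, a minimal sketch of the device entry that _build_from_dict now parses. Field names follow the diff; the model string and numbers are illustrative assumptions taken from the A100 defaults introduced below:

    device_info = {
        "global_id": 0,
        "local_id": 0,
        "type": "GPU",
        "model": "A100-SXM4-40GB",  # hypothetical model string
        "dp_gflops": 9700,  # double-precision throughput
        "sp_gflops": 19500,  # single-precision throughput
        "hp_gflops": 624000,  # half-precision throughput, new in this PR
        "memory": 40,  # GB
    }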
@@ -909,10 +921,22 @@ def is_by_json_config(json_config):
os.getenv("PADDLE_CURRENT_ENDPOINT", None),
)
)

gflops_info = {
"V100": {"dp": 7800, "sp": 15700, "hp": 125000},
"A100": {"dp": 9700, "sp": 19500, "hp": 624000},
}
default_gflops = (
gflops_info["A100"] if gpu_model == "A100" else gflops_info["V100"]
)

cluster.gen_default_config_cluster(
node_count=node_count,
device_count=local_device_count,
gpu_model=gpu_model,
gpu_memory=memory,
gpu_dp_gflops=default_gflops["dp"],
gpu_sp_gflops=default_gflops["sp"],
gpu_hp_gflops=default_gflops["hp"],
)
return cluster
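
A minimal usage sketch of the result, assuming an A100 machine (the printed values come from the gflops_info table above):

    from paddle.distributed.auto_parallel.static.cluster import (
        get_default_cluster,
    )

    cluster = get_default_cluster()
    device = cluster.get_device(0)
    # dp/sp/hp are double/single/half precision; hp_gflops is the new field:
    # 624000 on A100, 125000 on V100.
    print(device.dp_gflops, device.sp_gflops, device.hp_gflops)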
35 changes: 28 additions & 7 deletions python/paddle/distributed/auto_parallel/static/cost/base_cost.py
@@ -17,9 +17,10 @@
import numpy as np

import paddle
from paddle.base.core import VarDesc
from paddle.utils.flops import flops

from ..cluster import LinkType, get_default_cluster
from ..cluster import DeviceType, LinkType, get_default_cluster
from ..dist_tensor import DistributedTensor
from ..process_group import get_process_group
from ..utils import _get_comm_group, _get_idx_in_axis
@@ -936,7 +937,13 @@ def calc_time_by_cost_model(op, cluster=None):
)
if not cluster:
cluster = get_default_cluster()
time = 0.0

assert cluster._gpu_model in [
"V100",
"A100",
], "Only A100 and V100 gpu has been supported currently."

time = 0.0 # microsecond
op_type = op.type
# calc comp op time by flops
if op_type not in NON_COMP_TYPE:
@@ -958,15 +965,29 @@
else:
flops_count = flops(op_type, inputs, attrs)

if cluster._gpu_model == "V100":
time = flops_count * 2.9e-7 * 2.6
elif cluster._gpu_model == "A100":
time = flops_count * 2.9e-7
# FIXME(Ruibiao): Need a better way to get dtype
var_name = op.output_arg_names[0]
dtype = op.block._var_recursive(var_name).dtype
device = cluster.get_device(0)
assert (
device.type == DeviceType.GPU
), "Only GPU device is supported currently."

gflops = 0.0
if dtype == VarDesc.VarType.FP64:
gflops = device.dp_gflops
elif dtype == VarDesc.VarType.FP32:
gflops = device.sp_gflops
elif dtype == VarDesc.VarType.FP16 or dtype == VarDesc.VarType.BF16:
gflops = device.hp_gflops
else:
raise ValueError(
"Only A100 and V100 gpu has been supported currently."
f"Unsupported modeling compute time for dtype: {dtype}."
)

utilization_rate = 0.98
time = flops_count / (utilization_rate * gflops) * 1e-3

# calc comm op time by communication modeling formula
elif op_type in COMM_OP_TYPE:
op_cost = _g_op_cost_factory[op_type](
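Since the compute branch above tracks time in microseconds and gflops is 1e9 FLOP/s, the formula reduces to time_us = FLOPs / (utilization_rate * GFLOPS) * 1e-3. A self-contained sketch of that formula with a worked example (the matmul size is hypothetical):

    def estimate_time_us(flops_count, gflops, utilization_rate=0.98):
        # FLOPs / (GFLOPS * 1e9) seconds equals FLOPs / GFLOPS * 1e-3 microseconds.
        return flops_count / (utilization_rate * gflops) * 1e-3

    # A 4096x4096x4096 FP16 matmul is ~2 * 4096**3, about 1.37e11 FLOPs; with the
    # A100 half-precision default of 624000 GFLOPS this models roughly 225 microseconds.
    print(estimate_time_us(2 * 4096**3, 624000))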
1 change: 1 addition & 0 deletions python/paddle/distributed/auto_parallel/static/engine.py
@@ -1031,6 +1031,7 @@ def fit(
use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy,
)

lr = auto_utils.get_lr(self.optimizer)
logs = self._prepare_logger(
outs,
12 changes: 12 additions & 0 deletions python/paddle/distributed/auto_parallel/static/parallelizer_v2.py
@@ -14,6 +14,7 @@

import copy
import logging
import os
import time

from paddle.distributed.passes import PassManager, new_pass
@@ -354,6 +355,17 @@ def _apply_post_optimization(
)
params_grads = self._pass_context.get_attr("params_grads")

mp_async_allreduce_in_backward = os.getenv(
"FLAGS_mp_async_allreduce_in_backward"
) in [1, "1", True, "True"]
if mp_async_allreduce_in_backward:
column_parallel_linear_backward_overlapping_pass = new_pass(
"column_parallel_linear_backward_overlapping", {}
)
column_parallel_linear_backward_overlapping_pass.apply(
[main_program], [startup_program], self._pass_context
)

Reviewer comment (Contributor): could use a config switch instead, e.g. config["use_sharding"] = self._strategy.sharding.enable.

if self.is_train:
# GradClip is train-only optimization
config = copy.deepcopy(self._strategy.sharding.to_dict())
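As the diff reads, the new pass is toggled by an environment variable rather than a strategy config (see the reviewer comment above); a usage sketch:

    import os

    # Any of 1, "1", True, "True" enables the pass per the check above; the
    # flag must be set before the engine runs _apply_post_optimization.
    os.environ["FLAGS_mp_async_allreduce_in_backward"] = "1"
    # ... then build and fit the auto-parallel Engine as usual; the
    # column_parallel_linear_backward_overlapping pass is applied to the
    # main program.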
7 changes: 5 additions & 2 deletions python/paddle/distributed/passes/__init__.py
@@ -13,7 +13,7 @@
# limitations under the License.

from .pass_base import new_pass, PassManager, PassContext
from .fuse_all_reduce import * # noqa: F403

from .auto_parallel_gradient_merge import * # noqa: F403
from .auto_parallel_sharding import * # noqa: F403
from .auto_parallel_amp import * # noqa: F403
@@ -24,11 +24,14 @@
from .auto_parallel_grad_clip import * # noqa: F403
from .auto_parallel_supplement_explicit_dependencies import * # noqa: F403
from .auto_parallel_pipeline import * # noqa: F403
from .pipeline_scheduler_pass import * # noqa: F403
from .column_parallel_linear_backward_overlapping import * # noqa: F403
from .cpp_pass import * # noqa: F403
from .fuse_all_reduce import * # noqa: F403
from .pipeline_scheduler_pass import * # noqa: F403
from .ps_trainer_pass import * # noqa: F403
from .ps_server_pass import * # noqa: F403


__all__ = [
'new_pass',
'PassManager',
7 changes: 3 additions & 4 deletions python/paddle/distributed/passes/auto_parallel_sharding.py
Expand Up @@ -43,6 +43,7 @@
from paddle.utils import unique_name

from .pass_base import PassBase, register_pass
from .pass_utils import AutoParallelStreamType

OpRole = core.op_proto_and_checker_maker.OpRole
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
@@ -748,13 +749,11 @@ def _fuse_overlap_parameter_comm_stage_two(self, sharding_info):
group = sharding_info.group
else:
group = new_process_group(ranks, force_new_group=True)
# NOTE here stream is just a presentation with different name,
# it is up to executor to create the exact streams given the name.
stream = f"sharding_param_comm_stream{i}"

self.param_comm_group_stream_pairs.append(
{
"comm_group": group,
"comm_stream": stream,
"comm_stream": AutoParallelStreamType.SHARDING_STREAM.value,
}
)
_logger.info(
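The per-group stream names are replaced by a single shared stream type imported from pass_utils, so all sharding parameter communication lands on one named stream that the executor can overlap with computation while keeping the communication ops ordered among themselves. A sketch of the enum this diff relies on; the member values are assumptions, not verified against pass_utils.py:

    from enum import Enum

    class AutoParallelStreamType(Enum):
        CALC_STREAM = "default"
        MP_STREAM = "auto_parallel_mp"
        SHARDING_STREAM = "auto_parallel_sharding"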