[Auto Parallel] Improve the APIs #45776

Merged: 60 commits, Sep 15, 2022
Changes from all commits
Commits
60 commits
45328ab
[Auto Parallel] Use c++ dist attr in the completion process
aoyulong Aug 24, 2022
5bab411
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into new_dis…
aoyulong Aug 25, 2022
247d61e
[Auto Parallel] Add minor changes
aoyulong Aug 25, 2022
35cc6e3
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into new_dis…
aoyulong Aug 30, 2022
26f9b9d
[Auto Parallel] Use c++ dist attr in the completion process
aoyulong Aug 24, 2022
dafc2f3
[Auto Parallel] Add minor changes
aoyulong Aug 25, 2022
c37f036
[Auto Parallel] Add the serialization process for dist attrs
aoyulong Aug 30, 2022
76e193d
[Auto Parallel] Remove unnecessary comments
aoyulong Aug 30, 2022
e988924
[Auto Parallel] Fix some bugs
aoyulong Sep 1, 2022
4b757a5
Merge branch 'develop' into new_dist_attr
aoyulong Sep 1, 2022
1897263
[Auto Parallel] Fix the code style
aoyulong Sep 1, 2022
6215661
Merge branch 'new_dist_attr' into serialize_dist_attr
aoyulong Sep 1, 2022
a243b06
[Auto Parallel] Remove unnecessary impls
aoyulong Sep 1, 2022
aeb113a
[Auto Parallel] Fix the importing error
aoyulong Sep 1, 2022
929fd58
Merge branch 'new_dist_attr' into serialize_dist_attr
aoyulong Sep 1, 2022
8f10e71
[Auto Parallel] Fix the copy from bugs of op dist attr
aoyulong Sep 2, 2022
748b1d2
Merge branch 'new_dist_attr' into serialize_dist_attr
aoyulong Sep 2, 2022
d5198e7
[Auto Parallel] Replace the use of constexpr if
aoyulong Sep 2, 2022
f77dc9a
[Auto Parallel] Redesign the shard_tensor, shard_op and ProcessMesh
aoyulong Sep 5, 2022
020bf9e
[Auto Parallel] Change API of the completion unittest
aoyulong Sep 5, 2022
6660644
[Auto Parallel] Fix the bug when set_attr an int
aoyulong Sep 5, 2022
58be315
Merge branch 'develop' into serialize_dist_attr
aoyulong Sep 5, 2022
66e17d7
Merge branch 'serialize_dist_attr' into new_api
aoyulong Sep 5, 2022
7f9ddc8
[Auto Parallel] Add the unittest for the serialization
aoyulong Sep 6, 2022
b971647
Merge branch 'serialize_dist_attr' into new_api
aoyulong Sep 6, 2022
fb5d7f4
[Auto Parallel] Add some unit tests
aoyulong Sep 6, 2022
1572348
[Auto Parallel] Unify the strategy
aoyulong Sep 7, 2022
0e53739
[Auto Parallel] Improve the engine api
aoyulong Sep 8, 2022
5fbf40e
[Auto Parallel] Reset the changes made to the framework
aoyulong Sep 8, 2022
3d828a6
[Auto Parallel] Change the engine unittest
aoyulong Sep 8, 2022
7ff7caf
[Auto Parallel] Update API of the completion and partitioner
aoyulong Sep 8, 2022
3fa80fe
[Auto Parallel] Update unit tests using engine api
aoyulong Sep 9, 2022
0d00d5d
fix conflict
Sep 12, 2022
f668cf8
Merge branch 'new_api' of https://github.com/aoyulong/Paddle into new…
Sep 12, 2022
1665367
update shard annotation
Caozhou1995 Sep 9, 2022
91ce66b
Merge pull request #5 from Caozhou1995/new_api_1
Caozhou1995 Sep 13, 2022
291c533
Merge branch 'new_api' of https://github.com/aoyulong/Paddle into Aut…
Sep 13, 2022
0db3279
[Auto Parallel] Remove the modifications of other modules
aoyulong Sep 13, 2022
5a3d676
Merge branch 'new_api' of https://github.com/aoyulong/Paddle into new…
aoyulong Sep 13, 2022
ee0893c
[Auto Parallel] Add docs for APIs
aoyulong Sep 13, 2022
06174d1
Merge branch 'new_api' of https://github.com/aoyulong/Paddle into Aut…
Sep 13, 2022
a6bf0c2
add new strategy
Sep 13, 2022
4d92b53
[Auto Parallel] Replace the logger
aoyulong Sep 13, 2022
6b1fed1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
aoyulong Sep 13, 2022
23f6539
[Auto Parallel] Restore the test_program.py
aoyulong Sep 13, 2022
d148cb8
fix conflict
Sep 13, 2022
e94e263
Merge pull request #6 from zhaoyinglia/AutoParallel/new_api
aoyulong Sep 13, 2022
31c4d71
Merge branch 'develop' into new_api
aoyulong Sep 13, 2022
a0fe2c3
[Auto Parallel] Change the import rules
aoyulong Sep 14, 2022
2b54978
[Auto Parallel] Add the examples for Engine
aoyulong Sep 14, 2022
eac9880
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
aoyulong Sep 14, 2022
3c5b79c
[Auto Parallel] Do some minor changes
aoyulong Sep 14, 2022
fb435a0
[Auto Parallel] Remove yaml dependency
aoyulong Sep 14, 2022
14bce35
[Auto Parallel] Fix the unittests
aoyulong Sep 14, 2022
dd26467
add valid after train
Sep 14, 2022
28d6af8
Merge pull request #7 from zhaoyinglia/new_api_valid
aoyulong Sep 14, 2022
5b0603b
bug fix
Sep 14, 2022
c0add9e
Merge pull request #8 from zhaoyinglia/new_api_fix
aoyulong Sep 14, 2022
8e81432
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
aoyulong Sep 15, 2022
5093402
Merge branch 'new_api' of https://github.com/aoyulong/Paddle into new…
aoyulong Sep 15, 2022
10 changes: 6 additions & 4 deletions python/paddle/distributed/auto_parallel/__init__.py
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .interface import shard_tensor # noqa: F401
from .interface import shard_op # noqa: F401
from .strategy import Strategy
from .process_mesh import ProcessMesh
from .reshard import Resharder # noqa: F401
from .cost_model import estimate_cost
from .engine import Engine
from .interface import shard_tensor
from .interface import shard_op
from .interface import recompute
from .interface import fetch

__all__ = []
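
Taken together, the reorganized exports (Strategy, Engine, ProcessMesh, shard_tensor, shard_op, recompute, fetch) point at a workflow roughly like the sketch below. This is an illustration only: the paddle.distributed.fleet.auto alias and the exact Engine/Strategy/fit signatures are assumptions drawn from the commit messages and unit-test names in this PR, not from the hunk above.

# Illustrative sketch only; Engine/Strategy usage is assumed, not taken from this diff.
import paddle
import paddle.nn as nn
from paddle.distributed.fleet import auto  # assumed public entry point for these APIs


class MLP(nn.Layer):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(64, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        return self.fc2(self.fc1(x))


model = MLP()
loss = nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(parameters=model.parameters())

# Strategy is expected to mirror the category/field defaults registered in constants.py below.
strategy = auto.Strategy()
strategy.amp.enable = True

engine = auto.Engine(model, loss, optimizer, strategy=strategy)
# engine.fit(train_dataset, epochs=2, batch_size=64)  # run under paddle.distributed.launch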
118 changes: 118 additions & 0 deletions python/paddle/distributed/auto_parallel/constants.py
@@ -0,0 +1,118 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

from collections import defaultdict

# _g_default_config[category][field] = default_value
_g_default_config = defaultdict(dict)


def get_category_default_config(category):
    return _g_default_config[category]


def set_category_default_config(category, default_value):
    _g_default_config[category] = default_value


def get_field_default_config(category, field):
    return _g_default_config[category][field]


def set_field_default_config(category, field, default_value):
    _g_default_config[category][field] = default_value


NOT_FOUND = "not_found"

#########################################
# base configuration
#########################################
BASE = "base"
set_field_default_config(BASE, "auto_mode", "semi")
set_field_default_config(BASE, "gradient_scale", True)
set_field_default_config(BASE, "use_cache", True)
set_field_default_config(BASE, "return_numpy", True)
set_field_default_config(BASE, "all_ranks", False)
set_field_default_config(BASE, "split_data", False)
set_field_default_config(BASE, "seed", None)
set_field_default_config(BASE, "reinit", False) # Only for debug

#########################################
# recompute configuration
#########################################
RECOMPUTE = "recompute"
set_field_default_config(RECOMPUTE, "enable", False)
set_field_default_config(RECOMPUTE, "checkpoints", None)
set_field_default_config(RECOMPUTE, "enable_tuning", False)

#########################################
# AMP configuration
#########################################
AMP = "amp"
set_field_default_config(AMP, "enable", False)
set_field_default_config(AMP, "init_loss_scaling", 32768.0)
set_field_default_config(AMP, "incr_every_n_steps", 1000)
set_field_default_config(AMP, "decr_every_n_nan_or_inf", 2)
set_field_default_config(AMP, "incr_ratio", 2.0)
set_field_default_config(AMP, "decr_ratio", 0.8)
set_field_default_config(AMP, "use_dynamic_loss_scaling", True)
set_field_default_config(AMP, "custom_white_list", [])
set_field_default_config(AMP, "custom_black_list", [])
set_field_default_config(AMP, "custom_black_varnames", [])
set_field_default_config(AMP, "use_pure_fp16", False)
set_field_default_config(AMP, "use_fp16_guard", True)
set_field_default_config(AMP, "use_optimizer_fp16", False)

#########################################
# sharding configuration
#########################################
SHARDING = "sharding"
set_field_default_config(SHARDING, "enable", False)
set_field_default_config(SHARDING, "stage", 1)
set_field_default_config(SHARDING, "sharding_degree", 8)
set_field_default_config(SHARDING, "segment_broadcast_MB", 32.0)
set_field_default_config(SHARDING, "enable_tuning", False)
set_field_default_config(SHARDING, "tuning_range", [])

#########################################
# gradient merge configuration
#########################################
GRADIENT_MERGE = "gradient_merge"
set_field_default_config(GRADIENT_MERGE, "enable", False)
set_field_default_config(GRADIENT_MERGE, "k_steps", 1)
set_field_default_config(GRADIENT_MERGE, "avg", True)

#########################################
# quantization configuration
#########################################
QAT = "qat"
set_field_default_config(QAT, "enable", False)
set_field_default_config(QAT, "channel_wise_abs_max", True)
set_field_default_config(QAT, "weight_bits", 8)
set_field_default_config(QAT, "activation_bits", 8)
set_field_default_config(QAT, "not_quant_pattern", ['skip_quant'])
set_field_default_config(QAT, "algo", None)

# #########################################
# auto tuning configuration
# #########################################
TUNING = "tuning"
set_field_default_config(TUNING, "enable", False)
set_field_default_config(TUNING, "batch_size", 1)
set_field_default_config(TUNING, "dataset", None)
set_field_default_config(TUNING, "profile_start_step", 1)
set_field_default_config(TUNING, "profile_end_step", 1)
set_field_default_config(TUNING, "run_after_tuning", True)
set_field_default_config(TUNING, "verbose", True)
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/converter.py
@@ -16,7 +16,7 @@
import warnings
import logging
import numpy as np
from ..utils import get_logger
from .utils import get_logger


class Converter(object):
32 changes: 32 additions & 0 deletions python/paddle/distributed/auto_parallel/dist_attribute.py
@@ -173,6 +173,17 @@ def mark_annotated_as(self, dist_attr):
    def clear_annotated(self):
        self._is_annotated.clear()

    def __eq__(self, other):
        if not isinstance(other, TensorDistributedAttribute):
            return False
        if self.process_mesh != other.process_mesh:
            return False
        if self.dims_mapping != other.dims_mapping:
            return False
        if self._is_annotated != other._is_annotated:
            return False
        return True

    def __str__(self):
        str = "\n\ttensor_dist_attr = {"
        if self.is_annotated("process_mesh"):
@@ -486,6 +497,27 @@ def is_annotated_output_dims_mapping(self, name):
        else:
            return False

    def __eq__(self, other):
        if not isinstance(other, OperatorDistributedAttribute):
            return False
        if self.process_mesh != other.process_mesh:
            return False
        if self.op_type != other.op_type:
            return False
        if self.impl_type != other.impl_type:
            return False
        if self.impl_idx != other.impl_idx:
            return False
        if self._is_annotated != other._is_annotated:
            return False
        if self._is_recompute != other._is_recompute:
            return False
        if self.inputs_dist_attrs != other.inputs_dist_attrs:
            return False
        if self.outputs_dist_attrs != other.outputs_dist_attrs:
            return False
        return True

    def __str__(self):
        str = "\n\top_dist_attr = {"
        if self.is_annotated("process_mesh"):
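
A small sketch of the value semantics the new __eq__ gives tensor dist attrs. It assumes a no-argument TensorDistributedAttribute constructor and plain property setters, which are not shown in this hunk; the compared fields come straight from __eq__ above.

# Sketch only: the construction pattern here is an assumption.
from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh

mesh = ProcessMesh([0, 1])

a = TensorDistributedAttribute()
a.process_mesh = mesh
a.dims_mapping = [0, -1]  # shard dim 0 over the mesh, replicate dim 1

b = TensorDistributedAttribute()
b.process_mesh = mesh
b.dims_mapping = [0, -1]

assert a == b  # same mesh, same mapping, same (empty) annotation state

a.mark_annotated("dims_mapping")
assert a != b  # annotation bookkeeping participates in equality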
3 changes: 0 additions & 3 deletions python/paddle/distributed/auto_parallel/dist_context.py
@@ -126,9 +126,6 @@ def __init__(self,
        # A flag indicates whether the used parallelism is data parallel
        self._data_parallel = False

        # flag whether using `to_static`
        self._dygraph_mode = False

    @property
    def serial_main_program(self):
        return self._serial_main_program
100 changes: 92 additions & 8 deletions python/paddle/distributed/auto_parallel/dist_op.py
@@ -23,6 +23,7 @@
from .dist_attribute import append_op_output_suffix
from .dist_attribute import get_tensor_dist_attr_field_keys
from .dist_attribute import get_op_dist_attr_field_keys
from .utils import convert_to_shard_spec, verify_shard_spec


class DistributedOperator:
@@ -248,23 +249,106 @@ def __deepcopy__(self, memo):
        return result


class DistributedModule:
class DistributedOperatorHelper:

    def __init__(self, serial_module, dist_attr=None):
        self._serial_module = serial_module
        self._dist_attr = dist_attr
    def __init__(self, serial_op, process_mesh, in_dims_mappings,
                 out_dims_mappings):
        self._serial_op = serial_op
        self._process_mesh = process_mesh
        self._in_dims_mappings = in_dims_mappings
        self._out_dims_mappings = out_dims_mappings

    def __call__(self, *args, **kwargs):
        from .dist_context import get_default_distributed_context
        tensor_to_dims_mapping = {}
        index = 0
        if self._in_dims_mappings:
            assert len(args) + len(kwargs) == len(self._in_dims_mappings), \
                "The length of dims_mapping {} does not matching the length output {}.".format(len(self._in_dims_mappings), len(args) + len(kwargs))
        for arg in args:
            if isinstance(arg, Variable) and self._in_dims_mappings:
                tensor_to_dims_mapping[arg.name] = self._in_dims_mappings[index]
            index += 1
        for arg in kwargs.values() and self._in_dims_mappings:
            if isinstance(arg, Variable):
                tensor_to_dims_mapping[arg.name] = self._in_dims_mappings[index]
            index += 1

        default_prog = paddle.fluid.default_main_program()
        cur_block = default_prog.current_block()
        op_size = len(cur_block.ops)
        output = self._serial_module(*args, **kwargs)
        output = self._serial_op(*args, **kwargs)
        new_op_size = len(cur_block.ops)

        if isinstance(output, tuple) or isinstance(output, list):
            new_output = list(output)
        elif isinstance(output, Variable):
            new_output = [output]
        else:
            raise ValueError("Unrecognized outpout.")

        if self._out_dims_mappings:
            assert len(new_output) == len(self._out_dims_mappings), \
                "The length of dims_mapping {} does not matching the length output {}.".format(len(self._out_dims_mappings), len(new_output))
        for i, item in enumerate(new_output):
            if isinstance(item, Variable) and self._out_dims_mappings:
                tensor_to_dims_mapping[item.name] = self._out_dims_mappings[i]

        from .dist_context import get_default_distributed_context
        default_dist_ctx = get_default_distributed_context()
        for idx in range(op_size, new_op_size):
            op = cur_block.ops[idx]
            dist_op = DistributedOperator(op, self._dist_attr)
            dist_op.dist_attr.mark_annotated_as(self._dist_attr)
            dist_op = DistributedOperator(op)
            for name in dist_op.serial_op.input_arg_names:
                if name in tensor_to_dims_mapping.keys():
                    tensor = dist_op.get_serial_input(name)
                    tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(
                        name)
                    dims_mapping = tensor_to_dims_mapping[name]
                    if tensor is None:
                        tensor_shape = []
                    else:
                        if tensor.type == core.VarDesc.VarType.READER \
                                or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
                                or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
                            tensor_shape = []
                        else:
                            tensor_shape = tensor.shape
                    if dims_mapping is not None:
                        dims_mapping = tensor_to_dims_mapping[name]
                        shard_spec = convert_to_shard_spec(
                            dims_mapping, self._process_mesh)
                        assert verify_shard_spec(shard_spec, tensor_shape, self._process_mesh), \
                            "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
                                name, shard_spec, tensor_shape, self._process_mesh)
                        tensor_dist_attr.dims_mapping = dims_mapping
                        tensor_dist_attr.mark_annotated("dims_mapping")
            for name in dist_op.serial_op.output_arg_names:
                if name in tensor_to_dims_mapping.keys():
                    tensor = dist_op.get_serial_output(name)
                    tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(
                        name)
                    dims_mapping = tensor_to_dims_mapping[name]
                    if tensor is None:
                        tensor_shape = []
                    else:
                        if tensor.type == core.VarDesc.VarType.READER \
                                or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
                                or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
                            tensor_shape = []
                        else:
                            tensor_shape = tensor.shape
                    if dims_mapping is not None:
                        dims_mapping = tensor_to_dims_mapping[name]
                        shard_spec = convert_to_shard_spec(
                            dims_mapping, self._process_mesh)
                        assert verify_shard_spec(shard_spec, tensor_shape, self._process_mesh), \
                            "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
                                name, shard_spec, tensor_shape, self._process_mesh)
                        tensor_dist_attr.dims_mapping = dims_mapping
                        tensor_dist_attr.mark_annotated("dims_mapping")
            dist_op.dist_attr.process_mesh = self._process_mesh
            if self._process_mesh is not None:
                dist_op.dist_attr.mark_annotated("process_mesh")
            default_dist_ctx.add_dist_op_for_program(dist_op)

        return output
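
In user code the DistributedOperatorHelper above is not built directly; it is what the redesigned shard_op returns, and calling it runs the wrapped op and annotates the newly created ops as shown in __call__. A presumed end-to-end sketch follows; the shard_op signature (process_mesh, in_shard_specs, out_shard_specs) and the ProcessMesh dim_names argument are assumptions based on the "Redesign the shard_tensor, shard_op and ProcessMesh" commit, not on this hunk.

# Sketch only: shard_op/ProcessMesh signatures are assumed, not shown in this diff.
import paddle
import paddle.static as static
from paddle.distributed.auto_parallel.interface import shard_op
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh

paddle.enable_static()
mesh = ProcessMesh([0, 1], dim_names=["x"])

main_prog, startup_prog = static.Program(), static.Program()
with static.program_guard(main_prog, startup_prog):
    x = static.data(name="x", shape=[4, 8], dtype="float32")
    y = static.data(name="y", shape=[8, 16], dtype="float32")
    # shard_op wraps paddle.matmul in a DistributedOperatorHelper; the call runs
    # matmul as usual and annotates the newly created op with the process mesh
    # and the dims mappings derived from the shard specs.
    dist_matmul = shard_op(paddle.matmul, mesh,
                           in_shard_specs=[["x", None], [None, None]],
                           out_shard_specs=[["x", None]])
    out = dist_matmul(x, y)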