
Add auto completion module for auto parallel #34813

Merged
merged 99 commits into PaddlePaddle:develop from the auto_parallel_completion branch
Aug 24, 2021

Conversation

aoyulong
Contributor

@aoyulong aoyulong commented Aug 11, 2021

PR types

New features

PR changes

Others

Describe

  • Note that this PR won't be exposed to users and only involves the implementation of completing the distributed attributes. For the public API for users, please refer to the other PR (add the basic apis for auto_parallel  #33804)

  • complete_annotation(program, dist_context=None): this function uses a data-flow analysis algorithm to complete all distributed attributes for the input program. In the completion process, it also takes the distributed operators' implementations into consideration. (A usage sketch follows this list.)

  • DistributedContext: this class can be seen as a container to store distributed information related to the program, such as the distributed attributes. One partially annotated program may have multiple distributed contexts, each of which represents a different parallel strategy for the whole program. There always exists a default distributed context. The reason we need this container is that the auto searcher will explore different policies simultaneously, and each policy should have its own container to store its corresponding distributed attributes before we select the best one.

  • TensorDistributedAttribute and OperatorDistributedAttribute: these two classes are wrappers for organizing the raw distributed attributes from VarDesc and OpDesc. Since the distributed attributes will be updated multiple times during completion, these two abstractions are much easier to work with than manipulating the raw attributes directly. Furthermore, a tensor or an operator may have different distributed attributes at the same time in the future auto search process.

  • Distributed Operators: the distributed operators are the underlying engine of auto parallel. Like operators and their kernels, each distributed operator will have multiple implementations according to the parallel strategy. For example, the matmul operator has a corresponding distributed operator whose implementations include row and column parallelization. The distributed operators take care of the communication within their implementations and can be used to construct networks directly for the dynamic graph in the future. Ideally, every operator would have its own general distributed operator with implementations for every parallel strategy. For now, there are only a few of them, and each has only a few implementations. Note that this PR doesn't include the implementations of the added distributed operators (they will be provided by @JZ-LIANG); only the compatibility rules are provided.
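As a rough illustration of how these pieces fit together, here is a minimal usage sketch. The module paths follow the files added in this PR, but the toy FC program, the exact import locations, and the annotation step are illustrative assumptions only; the real user-facing entry points live in the interface PR (#33804).

```python
# Illustrative sketch only (not part of this PR's public API). It assumes
# complete_annotation and DistributedContext are importable from the modules
# added in this PR; the FC toy program is just an example.
import paddle
from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed.auto_parallel.completion import complete_annotation

paddle.enable_static()

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data(name="x", shape=[8, 1024], dtype="float32")
    out = paddle.static.nn.fc(x, size=4096)

# Each DistributedContext stores the distributed attributes for one parallel
# strategy; the auto searcher may keep several contexts for the same program.
dist_context = DistributedContext()

# Propagate the (partial) annotations to every tensor and operator.
completed_program = complete_annotation(main_program, dist_context)
```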

@paddle-bot-old

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

"spawn",

__all__ = [ # noqa
"spawn",


why modify this file?

Contributor Author

> why modify this file?
Typo. It will be corrected.

def set_process_mesh(self, process_mesh):
    self._process_mesh = process_mesh

def get_dims_mapping(self):


get_dim_mapping?

Contributor Author

@aoyulong aoyulong Aug 16, 2021

Lots of implementations rely on dims_mapping. Since this isn't exposed to users, I left it alone. Besides, I think we should try to convince @XiaoguangHu01 to adopt dims_mapping, as tf does.

self._owner_op = owner_op
if owner_context is None:
    self._owner_context = owner_context
else:


Is the if else statement necessary?

Contributor Author

> Is the if else statement necessary?

No, it's unnecessary and will be removed.

result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
    # No need to copy the owner tensor and context
Contributor

copy the owner op?

Contributor Author

It will be corrected in the next commit.
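For readers following this thread, the pattern under discussion is roughly the one below: a custom __deepcopy__ that deep-copies the attribute's own state but keeps the owner references shared. This is a simplified sketch; the exact field names (e.g. _owner_tensor) are inferred from the comments here and may differ from the final code.

```python
import copy

class TensorDistributedAttribute:
    # ... other fields and methods elided ...

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            # The owner tensor/op and the owning context are shared references,
            # not state of the attribute itself, so copy them shallowly.
            if k in ("_owner_tensor", "_owner_op", "_owner_context"):
                setattr(result, k, v)
            else:
                setattr(result, k, copy.deepcopy(v, memo))
        return result
```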

op_dist_attr.mark_as_annotated("process_mesh")
for tensor_name in op.input_arg_names:
    # There may be a better way to find the tensor by name
    tensor = op.block._var_recursive(tensor_name)
Contributor

try op.block.vars[tensor_name]?

Contributor Author

This is to make sure the tensor can be found from the current block level up to the top level (not strictly necessary for now).
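For context, the recursive lookup behaves roughly like the sketch below, walking from the operator's block up through parent blocks until the name is found. This is only an illustration of the behavior, not the framework's actual implementation; the helper name and sentinel check are assumptions.

```python
# Behavioral sketch of a "find the variable from the current block up to the
# top-level block" lookup.
def find_var_recursive(block, name):
    current = block
    while current is not None:
        if name in current.vars:
            return current.vars[name]
        # A negative parent_idx marks the top-level (global) block.
        parent_idx = current.parent_idx
        current = current.program.block(parent_idx) if parent_idx >= 0 else None
    raise ValueError("Cannot find variable {} in any enclosing block.".format(name))
```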

Shixiaowei02
Shixiaowei02 previously approved these changes Aug 16, 2021
@@ -0,0 +1,97 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Contributor

"auto_parallel/operators/embedding.py" rename to "auto_parallel/operators/dist_embedding.py" ?

Contributor Author

Done!



# RowParallel
class DistributedEmbeddingImpl0(DistributedOperatorImpl):
Contributor

DistributedEmbeddingRowParallelImpl ?

Contributor Author

Done!

// Giving each variable an identity can help us map related properties to it.
// For example, the identity can be used as a key for referring to its
// distributed attribute.
uint64_t Id() { return id_; }
Contributor

Might this be named dist_attr_id? For now it is only used for determining the dist_attr identity.

Contributor Author

The dist_attr_id has been dropped in the new code because it cannot work well across different distributed contexts.

Contributor

I also have this question. Maybe we can note in the comment that id_ is only used for determining the dist_attr identity in auto_parallel for now, to avoid confusing developers who read the code.

from .utils import append_distributed_attr_suffix

# There always exists a default context for user. And user can set it to another one.
DEFAULT_DISTRIBUTED_CONTEXT = None
Contributor

What is the reason that there always exists a default context?
Unlike paddle.framework.program, where we normally have just one startup and one train program (so we have a default for each of them), in the auto-search scenario there are supposed to be multiple DistributedContexts.
What is the relationship between the default one and those created by auto search?

Contributor Author

@aoyulong aoyulong Aug 17, 2021

Yes, the default distributed context should be removed. For now, it is just used in the __str__ of variable and operator for debugging because we don't pass a distributed context to these __str__ functions.
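For reference, the default-context machinery in the quoted snippet amounts to a lazy module-level singleton along the lines below. get_default_distributed_context matches the import used in framework.py in this PR; the setter and the function bodies are assumptions inferred from the comment "user can set it to another one", and DistributedContext is assumed to be defined in the same module (context.py).

```python
# Sketch of the default-context accessors (illustrative, not the PR's code).
DEFAULT_DISTRIBUTED_CONTEXT = None

def get_default_distributed_context():
    global DEFAULT_DISTRIBUTED_CONTEXT
    if DEFAULT_DISTRIBUTED_CONTEXT is None:
        DEFAULT_DISTRIBUTED_CONTEXT = DistributedContext()
    return DEFAULT_DISTRIBUTED_CONTEXT

def set_default_distributed_context(dist_context):
    global DEFAULT_DISTRIBUTED_CONTEXT
    DEFAULT_DISTRIBUTED_CONTEXT = dist_context
```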

dims_mapping = attr.get_dims_mapping()
process_mesh_shape = attr.get_process_mesh().topology
# If the dimension of tensor is less than the sharding dimension of process mesh,
# we just amend the dimension mapping to -1. (Is this really OK?)
Contributor

In this case, should it raise an error directly?
Or do we need to define an uneven sharding rule that allows only part of the processes to hold a real shard of the tensor while the rest hold a null shard?

Contributor Author

Good point. This part should be implemented in a better way. But this is not an error; it's a design choice.
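To make the rule in the quoted comment concrete: whenever a tensor dimension is smaller than the process-mesh dimension it is mapped to, that axis falls back to replication (-1). The sketch below is illustrative only; the helper name and exact condition are assumptions, with tensor_shape assumed to come from the VarDesc.

```python
def amend_dims_mapping(dims_mapping, tensor_shape, process_mesh_shape):
    """Illustrative sketch of the '-1 amendment' rule discussed above."""
    amended = list(dims_mapping)
    for i, mesh_dim in enumerate(amended):
        if mesh_dim == -1:
            continue  # this axis is already replicated
        # Tensor dim too small to be split across the mapped mesh dim:
        # fall back to replication instead of raising an error.
        if tensor_shape[i] < process_mesh_shape[mesh_dim]:
            amended[i] = -1
    return amended

# Example: a [2, 1024] tensor mapped onto a 4x2 mesh cannot be split 4 ways
# along axis 0, so that axis becomes replicated.
print(amend_dims_mapping([0, 1], [2, 1024], [4, 2]))  # -> [-1, 1]
```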

def __init__(self):
    self._name = None

def forward(self, serial_op):
Contributor

Is the serial_op object sufficient as the only input?
At least one more argument seems needed to pass context information (like the varname mapping in the program or graph).

Contributor Author

Done!



register_distributed_operator_impl("reshape2",
DistributedReshapeImpl0("add_one_dim_back"))
Contributor

add_one_dim_back --> add_one_dim?
The collapsed dimension might not be the last dimension in some cases.
Will there be an implementation for each case (add_one_dim_second_to_last, add_one_dim_third_to_last, add_one_dim_front, ...)?

Contributor Author

The distributed reshape here only takes care of the situation where the added dimension is the last one, and it is just the reverse of another implementation. Besides, the naming of distributed operators and their implementations should be improved. Python doesn't have a preprocessor, so there is no easy way to generate the names automatically.

fuyinno4
fuyinno4 previously approved these changes Aug 17, 2021
register_distributed_operator("softmax", DistributedSoftmax("softmax"))


class DistributedSoftmaxImpl(DistributedOperatorImpl):
Contributor

Just to confirm: we add the distributed op impl for each op here; are these ops only used for static mode, not for dynamic mode?

Contributor Author

The static and dynamic modes of auto parallel should converge at the dist ops, which play the same role as the ops in the serial code since they are the lowest level, and we will try to keep them using the same interface. In dynamic mode, users will use them to construct networks directly, and we can also provide dist layers or models to give users higher-level abstractions. The distributed op interface and register mechanism may be improved and implemented in C++ when everything is stable.
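For readers unfamiliar with the registry being discussed, the registration calls seen in this diff (register_distributed_operator / register_distributed_operator_impl) follow a pattern roughly like the sketch below. The register_impl method and the dictionary layout are assumptions; the real registry also tracks implementation indices and compatibility checks.

```python
# Simplified sketch of the distributed-operator registry pattern.
_g_distributed_operators = {}

def register_distributed_operator(op_type, dist_op):
    # One distributed operator object per serial op type, e.g. "matmul".
    _g_distributed_operators[op_type] = dist_op

def register_distributed_operator_impl(op_type, dist_impl):
    # Each distributed operator can own several implementations, e.g. the
    # row- and column-parallel variants of matmul; the completion pass later
    # picks a compatible one. register_impl is assumed here.
    _g_distributed_operators[op_type].register_impl(dist_impl)

def get_distributed_operator(op_type):
    return _g_distributed_operators.get(op_type, None)
```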

@@ -0,0 +1,157 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Contributor

2020 -> 2021

Contributor Author

Done.

@@ -1199,6 +1199,14 @@ def _to_readable_code(self):
if self.persistable:
    var_str = "persist " + var_str

from paddle.distributed.auto_parallel.context import get_default_distributed_context
Contributor

Why not import at the beginning of the file?

Contributor Author

Should be corrected. But this may be removed in the future.

@@ -2359,6 +2367,13 @@ def _to_readable_code(self, skip_op_callstack=True):
if i != len(attr_names) - 1:
    attrs_str += ", "

from paddle.distributed.auto_parallel.context import get_default_distributed_context
Contributor

Same as above.


* Fix bugs caused by shallow copy in attributes.py
* Improve amend_distributed_attr_for_program in context.py
* Other changes for weihang's comments
Contributor

@chenwhql chenwhql left a comment

LGTM for framework.py, *_desc.h



# ColumnParallel
class DistributedMatmulImpl0(DistributedOperatorImpl):
Contributor

Will this API be exposed publicly? Why not mark the implementations with something like Column? Impl0, Impl1, Impl2 are hard to tell apart.

Contributor Author

No, it won't be exposed publicly. Only the interfaces in the interface module are exposed to users; this PR is entirely internal implementation.

DistributedMatmulImpl2("replicate_parallel"))


class DistributedMatmulV2(DistributedOperator):
Contributor

Will this API be exposed to users? Why does a version number need to be added here (MatmulV2)?
v2 is the version of the implementation; it is not exposed to users and may be upgraded to v3 in the future.

Contributor Author

Same as above: only the interfaces in interface.py are exposed to users.

self._name = name


register_distributed_operator("reshape2", DistributedReshape2("reshape2"))
Contributor

The reshape2 here also exposes the internal implementation.

Contributor Author

Same as above: only the interfaces in interface.py are exposed to users.

Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

LGTM

@fuyinno4 fuyinno4 merged commit 93d862b into PaddlePaddle:develop Aug 24, 2021
@aoyulong aoyulong deleted the auto_parallel_completion branch December 10, 2021 03:03