Add auto completion module for auto parallel #34813
Conversation
Thanks for your contribution!
"spawn", | ||
|
||
__all__ = [ # noqa | ||
"spawn", |
why modify this file?
Typo. It will be corrected.
def set_process_mesh(self, process_mesh):
    self._process_mesh = process_mesh

def get_dims_mapping(self):
get_dim_mapping?
Lots of implementations rely on dims_mapping. Since it isn't exposed to users, I left it as is. Besides, I think we should try to convince @XiaoguangHu01 to adopt dims_mapping, as TensorFlow does.
self._owner_op = owner_op
if owner_context is None:
    self._owner_context = owner_context
else:
Is the if/else statement necessary?
No, it's unnecessary and will be removed.
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
    # No need to copy the owner tensor and context
copy the owner op?
It will be corrected in the next commit.
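For illustration, a minimal sketch of the deepcopy pattern in the diff above (the class and field names are reduced stand-ins, not the exact code in attributes.py): owner references are shared rather than deep-copied, so copying a distributed attribute does not clone the owning op/tensor or the whole context.

```python
import copy


class TensorDistributedAttribute:
    def __init__(self, owner_tensor=None, owner_context=None):
        self._owner_tensor = owner_tensor
        self._owner_context = owner_context
        self._dims_mapping = None

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            # No need to copy the owner tensor and context; share them instead.
            if k in ("_owner_tensor", "_owner_context"):
                setattr(result, k, v)
            else:
                setattr(result, k, copy.deepcopy(v, memo))
        return result
```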
op_dist_attr.mark_as_annotated("process_mesh")
for tensor_name in op.input_arg_names:
    # There may be a better way to find the tensor by name
    tensor = op.block._var_recursive(tensor_name)
try op.block.vars[tensor_name]?
This is to make sure the tensor can be found from the current block up to the top-level block (not strictly necessary for now).
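A rough illustration of why a recursive lookup helps (the helper below is hypothetical and only mimics what _var_recursive does; it is not Paddle's implementation): block.vars only holds variables declared in that block, so a name created in an enclosing block would be missed.

```python
def find_var_up_to_top(block, name):
    """Search the current block first, then walk up through enclosing blocks."""
    current = block
    while current is not None:
        var = current.vars.get(name)
        if var is not None:
            return var
        # 'parent' is a stand-in for Paddle's parent-block lookup.
        current = getattr(current, "parent", None)
    return None
```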
@@ -0,0 +1,97 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
"auto_parallel/operators/embedding.py" rename to "auto_parallel/operators/dist_embedding.py" ?
Done!
# RowParallel
class DistributedEmbeddingImpl0(DistributedOperatorImpl):
DistributedEmbeddingRowParallelImpl ?
Done!
// Giving each variable an identity can help us map related properties to it.
// For example, the identity can be used as a key for referring to its
// distributed attribute.
uint64_t Id() { return id_; }
Might this be named dist_attr_id, since for now it is only used to determine the dist_attr identity?
The dist_attr_id is obsolete in the new code because it cannot work well across different distributed contexts.
I also have this question. Maybe the comment could state that id_ is only used for determining the dist_attr identity in auto_parallel for now, to avoid confusing developers who read the code.
from .utils import append_distributed_attr_suffix

# There always exists a default context for user. And user can set it to another one.
DEFAULT_DISTRIBUTED_CONTEXT = None
What is the reason that there always exists a default context? Unlike paddle.framework.program, where we normally have just one startup program and one train program and therefore a default for each of them, the auto-search scenario is supposed to have multiple DistributedContext instances. What is the relationship between the default one and those created by auto search?
Yes, the default distributed context should be removed. For now, it is only used in the __str__ of Variable and Operator for debugging, because we don't pass a distributed context to those __str__ functions.
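For context, a simplified sketch of the default-context pattern being discussed (DistributedContext is reduced to a stub here; the real container lives in context.py):

```python
DEFAULT_DISTRIBUTED_CONTEXT = None


class DistributedContext:
    """Stub standing in for the real container of distributed attributes."""
    pass


def get_default_distributed_context():
    global DEFAULT_DISTRIBUTED_CONTEXT
    if DEFAULT_DISTRIBUTED_CONTEXT is None:
        DEFAULT_DISTRIBUTED_CONTEXT = DistributedContext()
    return DEFAULT_DISTRIBUTED_CONTEXT


def set_default_distributed_context(dist_context):
    # Mirrors "user can set it to another one" from the comment in the diff.
    global DEFAULT_DISTRIBUTED_CONTEXT
    DEFAULT_DISTRIBUTED_CONTEXT = dist_context
```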
dims_mapping = attr.get_dims_mapping()
process_mesh_shape = attr.get_process_mesh().topology
# If the dimension of tensor is less than the sharding dimension of process mesh,
# we just amend the dimension mapping to -1. (Is this really OK?)
Should it raise an error directly in this case? Or do we need to define an uneven sharding rule that allows only some processes to hold a real shard of the tensor while the others hold an empty shard?
Good point. This part should be implemented in a better way, but it is not an error; it is a design choice.
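One plausible reading of the amendment, as a hedged sketch (this is not the exact Paddle code; the condition just follows the comment in the diff): a tensor dimension that cannot actually be sharded over the mesh dimension it is mapped to falls back to -1, i.e. it is replicated.

```python
def amend_dims_mapping(dims_mapping, tensor_shape, process_mesh_shape):
    amended = list(dims_mapping)
    for i, mesh_dim in enumerate(dims_mapping):
        # -1 already means this tensor dimension is not sharded.
        if mesh_dim != -1 and tensor_shape[i] < process_mesh_shape[mesh_dim]:
            amended[i] = -1
    return amended


# Example: a [2, 8] tensor mapped as [0, 1] onto a 4x2 mesh becomes [-1, 1],
# because axis 0 (size 2) cannot be split across 4 processes.
print(amend_dims_mapping([0, 1], [2, 8], [4, 2]))  # [-1, 1]
```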
def __init__(self):
    self._name = None

def forward(self, serial_op):
Having the serial_op object as the only input might not be sufficient. At least one more input argument is needed to pass context information (such as the varname mapping in the program or graph).
Done!
register_distributed_operator_impl("reshape2",
                                   DistributedReshapeImpl0("add_one_dim_back"))
add_one_dim_back --> add_one_dim? The collapsed dimension might not be the last dimension in some cases. Will there be an implementation for each case (add_one_dim_second_to_last, add_one_dim_third_to_last, add_one_dim_front, ...)?
The distributed reshape here only takes care of the situation where the added dimension is the last one; it is just the reverse of another implementation. Besides, the naming of distributed operators and their implementations should be improved. Python doesn't have a preprocessor, so it cannot easily generate names automatically.
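A small numeric illustration of the "add_one_dim_back" case (the values are made up): the output dims_mapping is the input dims_mapping with a -1 appended for the new trailing axis, and the reverse implementation simply drops that entry.

```python
x_dims_mapping = [0, -1]                  # input sharded on axis 0 along mesh dim 0
out_dims_mapping = x_dims_mapping + [-1]  # the appended trailing axis is replicated
print(out_dims_mapping)                   # [0, -1, -1]
```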
register_distributed_operator("softmax", DistributedSoftmax("softmax")) | ||
|
||
|
||
class DistributedSoftmaxImpl(DistributedOperatorImpl): |
Just to confirm: we add a distributed op impl for each op here. Are these ops only used in static mode, not in dynamic mode?
The static and dynamic modes of auto parallel should converge at the dist ops, since they play the same role as ops do in the serial code and are the lowest level, and we will try to keep them using the same interface. In dynamic mode, users will use them to construct networks directly, and we can also provide dist layers or models to give users a higher-level abstraction. The distributed op interface and register mechanism may be improved and implemented in C++ once everything is stable.
@@ -0,0 +1,157 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2020 -> 2021
Done.
@@ -1199,6 +1199,14 @@ def _to_readable_code(self):
        if self.persistable:
            var_str = "persist " + var_str

        from paddle.distributed.auto_parallel.context import get_default_distributed_context
Why not import at the beginning of the file?
It should be corrected, but this may be removed in the future.
@@ -2359,6 +2367,13 @@ def _to_readable_code(self, skip_op_callstack=True):
        if i != len(attr_names) - 1:
            attrs_str += ", "

        from paddle.distributed.auto_parallel.context import get_default_distributed_context
Same as above.
* Fix bugs caused by shallow copy in attributes.py
* Improve amend_distributed_attr_for_program in context.py
* Other changes for weihang's comments
LGTM for framework.py, *_desc.h
# ColumnParallel
class DistributedMatmulImpl0(DistributedOperatorImpl):
Will this API be exposed publicly? Why not mark it with something like Column? Impl0, Impl1, Impl2 are hard to tell apart.
It won't be exposed publicly. Only the interfaces in the interface module are exposed to users; everything in this PR is internal implementation.
DistributedMatmulImpl2("replicate_parallel"))


class DistributedMatmulV2(DistributedOperator):
Will this API be exposed to users? Why does the name need a version number here (MatmulV2)?
The v2 is the implementation version number; it is not exposed to users and may be upgraded to v3 in the future.
Same as above; only the interfaces in interface.py are exposed to users.
self._name = name


register_distributed_operator("reshape2", DistributedReshape2("reshape2"))
The reshape2 here also exposes the internal implementation.
Same as above; only the interfaces in interface.py are exposed to users.
LGTM
PR types
New features
PR changes
Others
Describe
Note that this PR is not exposed to users and only covers the implementation of completing the distributed attributes. For the public user-facing APIs, please refer to the other PR (add the basic apis for auto_parallel #33804).
complete_annotation(program, dist_context=None): this function uses a data-flow analysis algorithm to complete all distributed attributes of the input program. During completion, it also takes the distributed operators' implementations into account.
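A hedged usage sketch (the import path follows this PR's file layout; ProcessMesh and shard_tensor come from the companion API PR and their exact names and signatures are assumptions here, as is the return value of complete_annotation):

```python
import paddle
import paddle.distributed as dist
from paddle.distributed.auto_parallel.completion import complete_annotation

paddle.enable_static()
mesh = dist.ProcessMesh([[0, 1], [2, 3]])  # assumed interface from PR #33804

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data(name="x", shape=[8, 1024], dtype="float32")
    w = paddle.static.create_parameter(shape=[1024, 4096], dtype="float32")
    # Annotate only part of the program; completion fills in the rest.
    dist.shard_tensor(w, mesh, dim_mapping=[0, 1])  # assumed signature
    y = paddle.matmul(x, w)

# Complete the distributed attributes of every tensor and operator.
completed_program = complete_annotation(main_program)
```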
DistributedContext: this class can be seen as a container that stores distributed information related to the program, such as the distributed attributes. One partially annotated program may have multiple distributed contexts, each of which represents a different parallel strategy for the whole program. There always exists a default distributed context. The reason we need this container is that the auto searcher will explore different strategies simultaneously, and each strategy should have its own container to store its corresponding distributed attributes before we select the best one.
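Continuing the sketch above, the relationship might look roughly like this (the search loop itself is hypothetical; the context module path matches the imports shown earlier in this PR):

```python
from paddle.distributed.auto_parallel.context import DistributedContext

# One partially annotated main_program, several candidate parallel strategies:
# each candidate gets its own context so its completed attributes do not
# overwrite the others before the best strategy is selected.
candidate_contexts = [DistributedContext() for _ in range(4)]
for ctx in candidate_contexts:
    complete_annotation(main_program, dist_context=ctx)
# An auto searcher would then evaluate each context and keep the best one.
```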
TensorDistributedAttribute and OperatorDistributedAttribute: these two classes are wrappers for organizing the raw distributed attributes from VarDesc and OpDesc. Since the distributed attributes are updated multiple times during completion, the two abstractions are much more convenient than manipulating the raw ones directly. Furthermore, a tensor or an operator may have different distributed attributes at the same time during the future auto-search process.
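A toy mock of the wrapper idea (the method names echo the diff snippets above; the per-input mapping layout and everything else here are simplifications, not the exact classes in attributes.py): annotated fields are pinned by the user, while unannotated ones can be rewritten many times during completion.

```python
class OperatorDistributedAttribute:
    def __init__(self, owner_op=None, owner_context=None):
        self._owner_op = owner_op
        self._owner_context = owner_context
        self._process_mesh = None
        self._inputs_dims_mapping = {}  # input arg name -> dims_mapping
        self._annotated = set()

    def set_process_mesh(self, process_mesh):
        self._process_mesh = process_mesh

    def set_input_dims_mapping(self, name, dims_mapping):
        # Rewritten repeatedly by the completion pass, unlike raw OpDesc attrs.
        self._inputs_dims_mapping[name] = dims_mapping

    def get_input_dims_mapping(self, name):
        return self._inputs_dims_mapping.get(name)

    def mark_as_annotated(self, attr_name):
        # User-annotated attributes are pinned; completion only fills the rest.
        self._annotated.add(attr_name)

    def is_annotated(self, attr_name):
        return attr_name in self._annotated
```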
Distributed Operators: the distributed operators are the underlying engine of auto parallel. Like operators and their kernels, each distributed operator has multiple implementations according to the parallel strategy. For example, the matmul operator has a corresponding distributed operator with row- and column-parallel implementations. A distributed operator takes care of the communication inside its implementations and can be used to construct networks directly for dynamic graphs in the future. Ideally, every operator would have its own general distributed operator with implementations for every parallel strategy. For now, there are only a few of them, and each has only a few implementations. Note that this PR does not include the implementations of the added distributed operators (they will be provided by @JZ-LIANG); only the compatibility rules are provided.
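A hedged sketch of the register mechanism referenced in the diffs above (register_distributed_operator and register_distributed_operator_impl appear in the diffs; the base classes and the compatibility hook are simplified here, and the impl name "column_parallel" is only an assumption):

```python
_g_distributed_operators = {}


class DistributedOperator:
    def __init__(self, name):
        self._name = name
        self._impls = []

    def register_impl(self, impl):
        self._impls.append(impl)


class DistributedOperatorImpl:
    def __init__(self, name):
        self._name = name

    def is_compatible(self, op_dist_attr):
        # Each implementation decides whether a given set of distributed
        # attributes (process mesh plus dims mappings) matches its scheme.
        raise NotImplementedError


def register_distributed_operator(op_type, dist_op):
    _g_distributed_operators[op_type] = dist_op


def register_distributed_operator_impl(op_type, dist_impl):
    _g_distributed_operators[op_type].register_impl(dist_impl)


# Mirrors the registrations shown in the diffs (compatibility rules omitted).
class DistributedMatmul(DistributedOperator):
    pass


class DistributedMatmulImpl0(DistributedOperatorImpl):
    def is_compatible(self, op_dist_attr):
        return True  # placeholder


register_distributed_operator("matmul", DistributedMatmul("matmul"))
register_distributed_operator_impl("matmul", DistributedMatmulImpl0("column_parallel"))
```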