Make tcp store as a global instance #55956

Merged (2 commits) on Aug 7, 2023
4 changes: 4 additions & 0 deletions paddle/fluid/pybind/communication.cc
@@ -28,6 +28,7 @@ limitations under the License. */
#include <memory>
#include <string>

#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/store/tcp_store.h"

@@ -109,6 +110,9 @@ void BindTCPStore(py::module *m) {
py::arg("world_size"),
py::arg("timeout") = 900,
py::call_guard<py::gil_scoped_release>());

m->def("create_or_get_global_tcp_store",
&phi::distributed::CreateOrGetGlobalTCPStore);
}

} // namespace pybind
11 changes: 3 additions & 8 deletions paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
@@ -3,14 +3,8 @@ proto_library(auto_parallel_proto SRCS auto_parallel.proto)
set(DISTRIBUTED_SRCS "")

if(WITH_DISTRIBUTE)
-  list(
-    APPEND
-    DISTRIBUTED_SRCS
-    dist_tensor.cc
-    reshard_function.cc
-    reshard_split_functor.cc
-    reshard_utils.cc
-    r_to_s_reshard_function.cc)
+  list(APPEND DISTRIBUTED_SRCS dist_tensor.cc reshard_function.cc
+       reshard_split_functor.cc r_to_s_reshard_function.cc)
endif()

collect_srcs(
@@ -20,4 +14,5 @@ collect_srcs(
  process_mesh.cc
  dist_attr.cc
  dist_mapper.cc
+  reshard_utils.cc
  ${DISTRIBUTED_SRCS})
71 changes: 62 additions & 9 deletions paddle/phi/core/distributed/auto_parallel/reshard_utils.cc
@@ -20,6 +20,7 @@

namespace phi {
namespace distributed {
+using auto_parallel::str_split;

bool IsDimsMappingShard(const std::vector<int64_t>& dims_mapping) {
return std::any_of(dims_mapping.begin(),
@@ -33,15 +34,6 @@ bool IsDimsMappingReplicated(const std::vector<int64_t>& dims_mapping) {
[](int64_t value) { return value == -1; });
}

-int64_t GetCurGlobalRank() {
-  const char* cur_rank = std::getenv("PADDLE_TRAINER_ID");
-  PADDLE_ENFORCE_NOT_NULL(
-      cur_rank,
-      phi::errors::NotFound(
-          "The environment variable 'PADDLE_TRAINER_ID' cannot be found."));
-  return std::atoi(cur_rank);
-}

std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh) {
const auto& process_shape = process_mesh.shape();
const auto& process_ids = process_mesh.process_ids();
@@ -80,5 +72,66 @@ std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
return split_axis_to_mesh_axis;
}

+int64_t GetCurGlobalRank() {
+  const char* cur_rank = std::getenv("PADDLE_TRAINER_ID");
+  PADDLE_ENFORCE_NOT_NULL(
+      cur_rank,
+      phi::errors::NotFound(
+          "The environment variable 'PADDLE_TRAINER_ID' cannot be found."));
+  return std::atoi(cur_rank);
+}
+
+int64_t GetGlobalWorldSize() {
+  const char* world_size = std::getenv("PADDLE_TRAINERS_NUM");
+  PADDLE_ENFORCE_NOT_NULL(
+      world_size,
+      phi::errors::NotFound(
+          "The environment variable 'PADDLE_TRAINERS_NUM' cannot be found."));
+  return std::atoi(world_size);
+}
+
+namespace {
+std::string GetMasterEndpoint() {
+  const char* master_endpoint = std::getenv("PADDLE_MASTER");
+  if (!master_endpoint) {
+    const char* trainer_endpoints = std::getenv("PADDLE_TRAINER_ENDPOINTS");
+    PADDLE_ENFORCE_NOT_NULL(
+        trainer_endpoints,
+        phi::errors::NotFound("The environment variable "
+                              "'PADDLE_TRAINER_ENDPOINTS' cannot be found."));
+    return str_split(trainer_endpoints, ",")[0];
+  }
+
+  PADDLE_ENFORCE_NOT_NULL(
+      master_endpoint,
+      phi::errors::NotFound(
+          "The environment variable 'PADDLE_MASTER' cannot be found."));
+  return master_endpoint;
+}
+
+}  // namespace
+
+std::string GetMasterAddr() {
+  std::string master_endpoint = GetMasterEndpoint();
+  return str_split(master_endpoint, ":")[0];
+}
+
+uint16_t GetMasterPort() {
+  std::string master_endpoint = GetMasterEndpoint();
+  return std::stoi(str_split(master_endpoint, ":")[1]);
+}
+
+std::shared_ptr<TCPStore> CreateOrGetGlobalTCPStore() {
+  std::string host = GetMasterAddr();
+  uint16_t port = GetMasterPort();
+  int64_t cur_rank = GetCurGlobalRank();
+  int64_t world_size = GetGlobalWorldSize();
+  bool is_master = (cur_rank == 0);
+
+  static std::shared_ptr<TCPStore> store =
+      std::make_shared<TCPStore>(host, port, is_master, world_size);
+  return store;
+}

} // namespace distributed
} // namespace phi
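
The key mechanism in CreateOrGetGlobalTCPStore above is the function-local static: the TCPStore is constructed once, on the first call, from whatever PADDLE_* environment variables are set at that moment, and every later call returns the same shared_ptr. A minimal standalone sketch of that pattern (illustration only, not Paddle code; FakeStore and the hard-coded endpoint are made-up stand-ins):

#include <cassert>
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>

// Stand-in for phi::distributed::TCPStore; the constructor arguments mirror
// the (host, port, is_master, world_size) signature used above.
struct FakeStore {
  FakeStore(std::string host, uint16_t port, bool is_master, int64_t world_size)
      : host_(std::move(host)),
        port_(port),
        is_master_(is_master),
        world_size_(world_size) {
    std::cout << "store constructed once for " << host_ << ":" << port_ << "\n";
  }
  std::string host_;
  uint16_t port_;
  bool is_master_;
  int64_t world_size_;
};

std::shared_ptr<FakeStore> CreateOrGetGlobalStore() {
  // The static local is initialized on the first call only; subsequent calls
  // skip the constructor and hand back the same pointer.
  static std::shared_ptr<FakeStore> store = std::make_shared<FakeStore>(
      "127.0.0.1", 6170, /*is_master=*/true, /*world_size=*/1);
  return store;
}

int main() {
  auto a = CreateOrGetGlobalStore();
  auto b = CreateOrGetGlobalStore();
  assert(a.get() == b.get());  // both handles refer to the single instance
  std::cout << "use_count: " << a.use_count() << "\n";  // 3: static + a + b
  return 0;
}

Because the initialization is lazy, changing PADDLE_MASTER or PADDLE_TRAINER_ENDPOINTS after the first call has no effect on the already-created store.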
16 changes: 14 additions & 2 deletions paddle/phi/core/distributed/auto_parallel/reshard_utils.h
@@ -16,8 +16,12 @@

#include <cstdint>
#include <map>
+#include <memory>
+#include <string>
#include <vector>

#include "paddle/phi/core/distributed/store/tcp_store.h"

namespace phi {
namespace distributed {
namespace auto_parallel {
@@ -31,8 +35,6 @@ bool IsDimsMappingShard(const std::vector<int64_t>& dims_mapping);

bool IsDimsMappingReplicated(const std::vector<int64_t>& dims_mapping);

-int64_t GetCurGlobalRank();

// Get the coordinate of cur rank in process mesh. For example, the process mesh
// is [[0, 1], [2, 3], [4, 5], [6, 7]], if the current rank is 4, then will
// return [2, 0]; if the current rank is 3, then will return [1, 1].
@@ -46,5 +48,15 @@ std::vector<int64_t> GetCurRankCoordInMesh(const ProcessMesh& process_mesh);
std::map<int64_t, int64_t> GetSplitAxisWithDimsMapping(
const std::vector<int64_t>& dims_mapping);

+int64_t GetCurGlobalRank();
+
+std::string GetMasterAddr();
+
+int64_t GetGlobalWorldSize();
+
+uint16_t GetMasterPort();
+
+std::shared_ptr<TCPStore> CreateOrGetGlobalTCPStore();

} // namespace distributed
} // namespace phi
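
For context on the GetMasterAddr/GetMasterPort/GetGlobalWorldSize helpers declared above, here is a standalone toy sketch of the endpoint resolution they rely on (illustration only, assuming the convention from reshard_utils.cc: PADDLE_MASTER holds host:port, otherwise the first entry of the comma-separated PADDLE_TRAINER_ENDPOINTS is used; split is a stand-in for auto_parallel::str_split, the endpoint values are invented, and setenv is POSIX-only):

#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Stand-in for auto_parallel::str_split: split s on a single-character delimiter.
std::vector<std::string> split(const std::string& s, char delim) {
  std::vector<std::string> parts;
  std::stringstream ss(s);
  std::string item;
  while (std::getline(ss, item, delim)) parts.push_back(item);
  return parts;
}

// Mirrors GetMasterEndpoint: prefer PADDLE_MASTER, otherwise fall back to the
// first trainer endpoint, which by convention is where rank 0 hosts the store.
std::string MasterEndpoint() {
  if (const char* master = std::getenv("PADDLE_MASTER")) return master;
  if (const char* endpoints = std::getenv("PADDLE_TRAINER_ENDPOINTS"))
    return split(endpoints, ',')[0];
  throw std::runtime_error("no master endpoint configured");
}

int main() {
  // Example values only; a real launcher would export these for every worker.
  setenv("PADDLE_TRAINER_ENDPOINTS", "10.0.0.1:6170,10.0.0.2:6170", 1);
  const std::string endpoint = MasterEndpoint();
  const std::string host = split(endpoint, ':')[0];  // "10.0.0.1"
  const uint16_t port =
      static_cast<uint16_t>(std::stoi(split(endpoint, ':')[1]));  // 6170
  std::cout << host << ":" << port << "\n";
  return 0;
}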
43 changes: 14 additions & 29 deletions python/paddle/distributed/collective.py
@@ -13,7 +13,6 @@
# limitations under the License.

import datetime
-import os

import paddle

@@ -320,32 +319,18 @@ def is_available():


def _init_parallel_env(backend):
-    master_endpoint = os.getenv("PADDLE_MASTER", None)
-    if master_endpoint is None:
-        master_endpoint = os.getenv("PADDLE_TRAINER_ENDPOINTS").split(',')[0]
-    assert (
-        master_endpoint is not None
-    ), "Please set PADDLE_MASTER enviroment variable."
-    if master_endpoint:
-        master_addr = master_endpoint.split(":")[0]
-        master_port = int(master_endpoint.split(":")[1])
-        global_env = _get_global_env()
-        rank = global_env.rank
-        world_size = global_env.world_size
-        dev_id = global_env.device_id
-        is_master = rank == 0
-        store = core.TCPStore(
-            master_addr,
-            master_port,
-            is_master,
-            world_size,
-        )
+    store = core.create_or_get_global_tcp_store()
+    global_env = _get_global_env()
+    rank = global_env.rank
+    world_size = global_env.world_size
+    dev_id = global_env.device_id

-        if backend == "gloo":
-            core.CommContextManager.create_gloo_comm_context(
-                store, "0", rank, world_size
-            )
-        elif backend == "nccl":
-            core.CommContextManager.set_cuda_device_id(dev_id)
-            core.CommContextManager.create_nccl_comm_context(
-                store, "0", rank, world_size
-            )
+    if backend == "gloo":
+        core.CommContextManager.create_gloo_comm_context(
+            store, "0", rank, world_size
+        )
+    elif backend == "nccl":
+        core.CommContextManager.set_cuda_device_id(dev_id)
+        core.CommContextManager.create_nccl_comm_context(
+            store, "0", rank, world_size
+        )