Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add collective async wait op #31463

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions paddle/fluid/operators/collective/c_wait_comm_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>

#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#endif

namespace paddle {
namespace operators {

class CWaitCommOp : public framework::OperatorBase {
public:
CWaitCommOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}

void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
platform::errors::PreconditionNotMet(
"wait_comm op can run on gpu place only for now."));

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
int ring_id = Attr<int>("ring_id");

auto compute_stream =
static_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place))
->stream();
auto comm_stream =
platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();

auto event =
platform::NCCLCommContext::Instance().Get(ring_id, place)->comm_event();

// comm_stream-->event-->compute_stream
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(event, comm_stream));
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamWaitEvent(compute_stream, event, 0));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, comm_stream));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent(compute_stream, event, 0));
#endif
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif
}
};

class CWaitCommOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) Dependency of the variable need to sync")
.AsDuplicable();
AddOutput("Out", "(Tensor) Dependency of the variable need to sync")
.AsDuplicable();
AddAttr<int>("ring_id", "(int default 0) ring id.").SetDefault(0);
AddComment(R"DOC(
CWaitComm Operator

Compute stream wait Comm Stream with async event.
)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(c_wait_comm, ops::CWaitCommOp, ops::CWaitCommOpMaker);
95 changes: 95 additions & 0 deletions paddle/fluid/operators/collective/c_wait_compute_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>

#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#endif

namespace paddle {
namespace operators {

class CWaitComputeOp : public framework::OperatorBase {
public:
CWaitComputeOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}

void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
PADDLE_ENFORCE_EQ(
is_gpu_place(place), true,
platform::errors::PreconditionNotMet(
"wait_compute op can run on gpu place only for now."));

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
int ring_id = Attr<int>("ring_id");

auto compute_stream =
static_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place))
->stream();
auto comm_stream =
platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();

auto event = platform::NCCLCommContext::Instance()
.Get(ring_id, place)
->compute_event();

// compute_stream-->event-->comm_stream
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_CUDA_SUCCESS(hipEventRecord(event, compute_stream));
PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamWaitEvent(comm_stream, event, 0));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(event, compute_stream));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamWaitEvent(comm_stream, event, 0));
#endif
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif
}
};

class CWaitComputeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) Dependency of the variable need to sync")
.AsDuplicable();
AddOutput("Out", "(Tensor) Dependency of the variable need to sync")
.AsDuplicable();
AddAttr<int>("ring_id", "(int default 0) ring id.").SetDefault(0);
AddComment(R"DOC(
CWaitCompute Operator

Comm stream wait Compute Stream with async event.
)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(c_wait_compute, ops::CWaitComputeOp,
ops::CWaitComputeOpMaker);
28 changes: 28 additions & 0 deletions paddle/fluid/platform/collective_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#include "paddle/fluid/platform/collective_helper.h"
#include <utility>

#include "paddle/fluid/platform/cuda_resource_pool.h"

namespace paddle {
namespace platform {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
Expand Down Expand Up @@ -43,12 +45,31 @@ class NCCLCommImpl : public NCCLComm {
}
CUDADeviceContext* dev_context() const override { return dev_ctx_.get(); }

gpuEvent_t compute_event() const override { return compute_event_.get(); }

gpuEvent_t comm_event() const override { return comm_event_.get(); }

void set_compute_event(
std::shared_ptr<platform::CudaEventObject>&& compute_event) {
compute_event_ = std::move(compute_event);
}

void set_comm_event(std::shared_ptr<platform::CudaEventObject>&& comm_event) {
comm_event_ = std::move(comm_event);
}

private:
int ring_id_;
int nranks_;
int rank_;
ncclComm_t comm_;
std::unique_ptr<CUDADeviceContext> dev_ctx_;

// used for comm wait compute, compute_stream-->event-->comm_stream
std::shared_ptr<platform::CudaEventObject> compute_event_;

// used for compute wait comm, comm_stream-->event-->compute_stream
std::shared_ptr<platform::CudaEventObject> comm_event_;
};

NCCLComm* NCCLCommContext::CreateNCCLComm(ncclUniqueId* nccl_id, int nranks,
Expand Down Expand Up @@ -124,12 +145,19 @@ NCCLComm* NCCLCommContext::AssignNCCLComm(ncclComm_t comm, int nranks, int rank,
std::unique_ptr<CUDADeviceContext> dev_ctx(
new CUDADeviceContext(CUDAPlace(dev_id)));

std::shared_ptr<platform::CudaEventObject> compute_event(
platform::CudaEventResourcePool::Instance().New(dev_id));
std::shared_ptr<platform::CudaEventObject> comm_event(
platform::CudaEventResourcePool::Instance().New(dev_id));

NCCLCommImpl* c = new NCCLCommImpl;
c->set_ring_id(ring_id);
c->set_nranks(nranks);
c->set_rank(rank);
c->set_comm(comm);
c->set_dev_ctx(std::move(dev_ctx));
c->set_compute_event(std::move(compute_event));
c->set_comm_event(std::move(comm_event));

comm_map_mutex_.lock();
if (comm_map_.count(ring_id) == 0) {
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/platform/collective_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class NCCLComm {
virtual int device_id() const = 0;
virtual ncclComm_t comm() const = 0;
virtual gpuStream_t stream() const = 0;
virtual gpuEvent_t compute_event() const = 0;
virtual gpuEvent_t comm_event() const = 0;
virtual CUDADeviceContext* dev_context() const = 0;
virtual ~NCCLComm() = default;
};
Expand Down
3 changes: 2 additions & 1 deletion python/paddle/fluid/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -2121,7 +2121,8 @@ class Operator(object):
'fl_listen_and_serv', 'ncclInit', 'select', 'checkpoint_notify',
'gen_bkcl_id', 'c_gen_bkcl_id', 'gen_nccl_id', 'c_gen_nccl_id',
'c_comm_init', 'c_sync_calc_stream', 'c_sync_comm_stream',
'queue_generator', 'dequeue', 'enqueue', 'heter_listen_and_serv'
'queue_generator', 'dequeue', 'enqueue', 'heter_listen_and_serv',
'c_wait_comm', 'c_wait_compute'
}

def __init__(self,
Expand Down
1 change: 1 addition & 0 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_collective_allreduce_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_broadcast_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_allgather_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_wait)
LIST(REMOVE_ITEM TEST_OPS test_memcpy_op)
endif()

Expand Down
114 changes: 114 additions & 0 deletions python/paddle/fluid/tests/unittests/collective_allreduce_op_wait.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_base import TestCollectiveRunnerBase, runtime_main

paddle.enable_static()


class TestCollectiveAllreduce(TestCollectiveRunnerBase):
def __init__(self):
self.global_ring_id = 0

def get_model(self, main_prog, startup_program):
ring_id = 0
with fluid.program_guard(main_prog, startup_program):
tindata = layers.data(
name="tindata", shape=[10, 1000], dtype='float32')
toutdata = main_prog.current_block().create_var(
name="outofallreduce",
dtype='float32',
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)

# tout = tin + tin - tin = tin
if True:
main_prog.global_block().append_op(
type="elementwise_add",
inputs={
'X': tindata,
'Y': tindata,
},
outputs={'Out': toutdata}, )
main_prog.global_block().append_op(
type="elementwise_sub",
inputs={
'X': toutdata,
'Y': tindata,
},
outputs={'Out': toutdata}, )

main_prog.global_block().append_op(
type='c_wait_compute',
inputs={'X': toutdata},
outputs={'Out': toutdata},
attrs={'ring_id': ring_id})

main_prog.global_block().append_op(
type="c_allreduce_sum",
inputs={'X': toutdata},
attrs={'ring_id': ring_id},
outputs={'Out': toutdata},
attr={'use_calc_stream': False})

main_prog.global_block().append_op(
type="c_wait_comm",
inputs={'X': toutdata},
outputs={'Out': toutdata},
attrs={'ring_id': ring_id})

# tout = tin + tout - tin = tout
if True:
main_prog.global_block().append_op(
type="elementwise_add",
inputs={
'X': tindata,
'Y': toutdata,
},
outputs={'Out': toutdata}, )
main_prog.global_block().append_op(
type="elementwise_sub",
inputs={
'X': toutdata,
'Y': tindata,
},
outputs={'Out': toutdata}, )

return toutdata


if __name__ == "__main__":
runtime_main(TestCollectiveAllreduce, "allreduce", 0)
Loading