Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fluid channels should match the semantics of Go Channels #9265

Merged
merged 9 commits into from
Mar 27, 2018
Merged
93 changes: 55 additions & 38 deletions paddle/fluid/framework/channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class Channel {
public:
virtual bool CanSend() = 0;
virtual bool CanReceive() = 0;
virtual bool Send(T*) = 0;
virtual void Send(T*) = 0;
virtual bool Receive(T*) = 0;
virtual size_t Cap() = 0;
virtual void Lock() = 0;
Expand Down Expand Up @@ -84,96 +84,113 @@ class ChannelHolder {
}

template <typename T>
bool Send(T* data) {
if (!IsInitialized()) return false;
PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T)));
void Send(T* data) {
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
PADDLE_ENFORCE_EQ(
holder_->Type(), std::type_index(typeid(T)),
"Channel type is not same as the type of the data being sent");
// Static cast should be safe because we have ensured that types are same
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
return channel != nullptr ? channel->Send(data) : false;
PADDLE_ENFORCE_EQ(channel != nullptr, true, "Channel should not be null.");
channel->Send(data);
}

template <typename T>
bool Receive(T* data) {
if (!IsInitialized()) return false;
PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T)));
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
PADDLE_ENFORCE_EQ(
holder_->Type(), std::type_index(typeid(T)),
"Channel type is not same as the type of the data being sent");
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
return channel != nullptr ? channel->Receive(data) : false;
PADDLE_ENFORCE_EQ(channel != nullptr, true, "Channel should not be null.");
return channel->Receive(data);
}

bool IsClosed() {
if (IsInitialized()) {
return holder_->IsClosed();
}
return false;
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
return holder_->IsClosed();
}

bool CanSend() {
if (IsInitialized()) {
return holder_->CanSend();
}
return false;
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
return holder_->CanSend();
}

bool CanReceive() {
if (IsInitialized()) {
return holder_->CanReceive();
}
return false;
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
return holder_->CanReceive();
}

void close() {
if (IsInitialized()) holder_->Close();
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
holder_->Close();
}

size_t Cap() {
if (IsInitialized()) return holder_->Cap();
return -1;
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
return holder_->Cap();
}

void Lock() {
if (IsInitialized()) holder_->Lock();
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
holder_->Lock();
}

void Unlock() {
if (IsInitialized()) holder_->Unlock();
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
holder_->Unlock();
}

template <typename T>
void AddToSendQ(const void* referrer, T* data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) {
if (IsInitialized()) {
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
if (channel != nullptr) {
channel->AddToSendQ(referrer, data, cond, cb);
}
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
if (channel != nullptr) {
channel->AddToSendQ(referrer, data, cond, cb);
}
}

template <typename T>
void AddToReceiveQ(const void* referrer, T* data,
std::shared_ptr<std::condition_variable_any> cond,
std::function<bool(ChannelAction)> cb) {
if (IsInitialized()) {
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
if (channel != nullptr) {
channel->AddToReceiveQ(referrer, data, cond, cb);
}
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
if (channel != nullptr) {
channel->AddToReceiveQ(referrer, data, cond, cb);
}
}

void RemoveFromSendQ(const void* referrer) {
if (IsInitialized()) holder_->RemoveFromSendQ(referrer);
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
holder_->RemoveFromSendQ(referrer);
}

void RemoveFromReceiveQ(const void* referrer) {
if (IsInitialized()) holder_->RemoveFromReceiveQ(referrer);
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
holder_->RemoveFromReceiveQ(referrer);
}

inline bool IsInitialized() const { return holder_ != nullptr; }

inline const std::type_index Type() {
PADDLE_ENFORCE_EQ(IsInitialized(), true);
PADDLE_ENFORCE_EQ(IsInitialized(), true,
"The Channel hasn't been initialized");
return holder_->Type();
}

Expand Down
35 changes: 22 additions & 13 deletions paddle/fluid/framework/channel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class ChannelImpl : public paddle::framework::Channel<T> {
public:
virtual bool CanSend();
virtual bool CanReceive();
virtual bool Send(T *);
virtual void Send(T *);
virtual bool Receive(T *);
virtual size_t Cap() { return cap_; }
virtual void Lock();
Expand Down Expand Up @@ -76,10 +76,9 @@ class ChannelImpl : public paddle::framework::Channel<T> {
}
};

bool send_return(bool value) {
void send_return() {
send_ctr--;
destructor_cond_.notify_all();
return value;
}

bool recv_return(bool value) {
Expand Down Expand Up @@ -118,15 +117,15 @@ bool ChannelImpl<T>::CanReceive() {
}

template <typename T>
bool ChannelImpl<T>::Send(T *item) {
void ChannelImpl<T>::Send(T *item) {
send_ctr++;
std::unique_lock<std::recursive_mutex> lock{mu_};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need to explicitly lock after constructor? lock->lock() ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No we don't need to do that. The unique lock constructor automatically does that.


// If channel is closed, do nothing
// If channel is closed, throw exception
if (closed_) {
lock.unlock();
// TODO(abhinavarora) Should panic on closed channel
return send_return(false);
send_return();
PADDLE_THROW("Cannot send on closed channel");
}

// If there is a receiver, directly pass the value we want
Expand All @@ -143,20 +142,25 @@ bool ChannelImpl<T>::Send(T *item) {
if (m->callback != nullptr) do_send = m->callback(ChannelAction::SEND);
if (do_send)
*(m->data) = std::move(*item);
else
else {
// We cannot do the data transfer because
// this QueueMessage was added by Select
// and some other case was executed.
// So call the Send function again.
// We do not care about notifying other
// because they would have been notified
// by the executed select case.
return send_return(Send(item));
lock.unlock();
Send(item);
send_return();
return;
}

// Wake up the blocked process and unlock
m->Notify();
lock.unlock();
return send_return(true);
send_return();
return;
}

// Unbuffered channel will always bypass this
Expand All @@ -167,16 +171,21 @@ bool ChannelImpl<T>::Send(T *item) {
buf_.push_back(std::move(*item));
// Release lock and return true
lock.unlock();
return send_return(true);
send_return();
return;
}

// Block on channel, because some receiver will complete
// the operation for us
auto m = std::make_shared<QueueMessage>(item);
sendq.push_back(m);
m->Wait(lock);
// TODO(abhinavarora) Should panic on closed channel
return send_return(!m->chan_closed);
if (m->chan_closed) {
lock.unlock();
send_return();
PADDLE_THROW("Cannot send on closed channel");
}
send_return();
}

template <typename T>
Expand Down
Loading