-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/shuffle reader #8991
Feature/shuffle reader #8991
Changes from all commits
46ae407
2ea4a5d
225efa6
a8c076e
f9974a4
164f238
127b371
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,40 +16,29 @@ | |
|
||
#include "paddle/fluid/framework/ddim.h" | ||
#include "paddle/fluid/framework/lod_tensor_array.h" | ||
#include "paddle/fluid/platform/place.h" | ||
|
||
#include <memory> | ||
#include <thread> | ||
#include <vector> | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
class ReaderBase { | ||
public: | ||
explicit ReaderBase(const std::vector<DDim>& shapes) : shapes_(shapes) { | ||
PADDLE_ENFORCE(!shapes_.empty()); | ||
} | ||
virtual void ReadNext(std::vector<LoDTensor>* out) = 0; | ||
|
||
virtual void ReInit() = 0; | ||
|
||
DDim shape(size_t idx) const; | ||
std::vector<DDim> shapes() const { return shapes_; } | ||
void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; } | ||
|
||
virtual bool HasNext() const = 0; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How do you think of replacing |
||
|
||
virtual ~ReaderBase() {} | ||
|
||
protected: | ||
std::vector<DDim> shapes_; | ||
}; | ||
|
||
class FileReader : public ReaderBase { | ||
public: | ||
explicit FileReader(const std::vector<DDim>& shapes) : ReaderBase(shapes) {} | ||
virtual ~ReaderBase(); | ||
}; | ||
|
||
class DecoratedReader : public ReaderBase { | ||
public: | ||
explicit DecoratedReader(ReaderBase* reader) | ||
: ReaderBase(reader->shapes()), reader_(reader) { | ||
explicit DecoratedReader(ReaderBase* reader) : ReaderBase(), reader_(reader) { | ||
PADDLE_ENFORCE_NOT_NULL(reader_); | ||
} | ||
|
||
|
@@ -61,6 +50,19 @@ class DecoratedReader : public ReaderBase { | |
ReaderBase* reader_; | ||
}; | ||
|
||
class FileReader : public ReaderBase { | ||
public: | ||
explicit FileReader(const std::vector<DDim>& dims); | ||
|
||
void ReadNext(std::vector<LoDTensor>* out) override; | ||
|
||
protected: | ||
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0; | ||
|
||
private: | ||
std::vector<DDim> dims_; | ||
}; | ||
|
||
// The ReaderHolder is used as reader' unified wrapper, | ||
// making it easier to access different type reader in Variables. | ||
class ReaderHolder { | ||
|
@@ -78,19 +80,6 @@ class ReaderHolder { | |
reader_->ReInit(); | ||
} | ||
|
||
DDim shape(size_t idx) const { | ||
PADDLE_ENFORCE_NOT_NULL(reader_); | ||
return reader_->shape(idx); | ||
} | ||
std::vector<DDim> shapes() const { | ||
PADDLE_ENFORCE_NOT_NULL(reader_); | ||
return reader_->shapes(); | ||
} | ||
void set_shapes(const std::vector<DDim>& shapes) { | ||
PADDLE_ENFORCE_NOT_NULL(reader_); | ||
reader_->set_shapes(shapes); | ||
} | ||
|
||
bool HasNext() const { return reader_->HasNext(); } | ||
|
||
private: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,11 +24,31 @@ static constexpr size_t kDoubleBufferSize = 2; | |
|
||
class DoubleBufferReader : public framework::DecoratedReader { | ||
public: | ||
explicit DoubleBufferReader(ReaderBase* reader) | ||
: DecoratedReader(reader), | ||
buffer_(framework::MakeChannel<std::vector<framework::LoDTensor>>( | ||
kDoubleBufferSize)) { | ||
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this); | ||
struct Item { | ||
Item() : ctx_(nullptr) {} | ||
|
||
std::vector<framework::LoDTensor> payloads_; | ||
platform::DeviceContext* ctx_; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems that this |
||
}; | ||
|
||
explicit DoubleBufferReader( | ||
ReaderBase* reader, platform::Place target_place = platform::CPUPlace()) | ||
: DecoratedReader(reader), place_(target_place) { | ||
for (size_t i = 0; i < kDoubleBufferSize; ++i) { | ||
if (platform::is_gpu_place(place_)) { | ||
#ifdef PADDLE_WITH_CUDA | ||
ctxs_.emplace_back(new platform::CUDADeviceContext( | ||
boost::get<platform::CUDAPlace>(place_))); | ||
#endif | ||
} | ||
} | ||
|
||
start_thread(); | ||
} | ||
|
||
void start_thread() { | ||
buffer_ = framework::MakeChannel<Item>(kDoubleBufferSize); | ||
std::thread prefetch([this] { PrefetchThreadFunc(); }); | ||
prefetch.detach(); | ||
} | ||
|
||
|
@@ -42,7 +62,10 @@ class DoubleBufferReader : public framework::DecoratedReader { | |
private: | ||
void PrefetchThreadFunc(); | ||
|
||
framework::Channel<std::vector<framework::LoDTensor>>* buffer_; | ||
framework::Channel<Item>* buffer_; | ||
platform::Place place_; | ||
std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
mutable Item local_buffer_; | ||
}; | ||
|
||
class CreateDoubleBufferReaderOp : public framework::OperatorBase { | ||
|
@@ -56,7 +79,20 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase { | |
->Get<framework::ReaderHolder>(); | ||
auto* out = scope.FindVar(Output("Out")) | ||
->template GetMutable<framework::ReaderHolder>(); | ||
out->Reset(new DoubleBufferReader(underlying_reader.Get())); | ||
|
||
auto place_str = Attr<std::string>("place"); | ||
platform::Place place; | ||
if (place_str == "CPU") { | ||
place = platform::CPUPlace(); | ||
} else { | ||
std::istringstream sin(place_str); | ||
sin.seekg(std::string("CUDA:").size(), std::ios::beg); | ||
size_t num; | ||
sin >> num; | ||
place = platform::CUDAPlace(static_cast<int>(num)); | ||
} | ||
|
||
out->Reset(new DoubleBufferReader(underlying_reader.Get(), place)); | ||
} | ||
}; | ||
|
||
|
@@ -71,44 +107,73 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase { | |
It launches another thread to execute the 'underlying reader' asynchronously, | ||
which prevents reading process from blocking subsequent training. | ||
)DOC"); | ||
std::unordered_set<std::string> enum_range; | ||
constexpr size_t kMaxCUDADevs = 128; | ||
for (size_t i = 0; i < kMaxCUDADevs; ++i) { | ||
enum_range.insert(string::Sprintf("CUDA:%d", i)); | ||
} | ||
enum_range.insert("CPU"); | ||
AddAttr<std::string>("place", "The double buffer place, default is CPU") | ||
.SetDefault("CPU") | ||
.InEnum({enum_range}); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't 'CPU' be in this enum? |
||
} | ||
}; | ||
|
||
void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) { | ||
out->clear(); | ||
buffer_->Receive(out); | ||
if (local_buffer_.payloads_.empty()) { | ||
buffer_->Receive(&local_buffer_); | ||
} | ||
|
||
*out = local_buffer_.payloads_; | ||
local_buffer_.payloads_.clear(); | ||
if (local_buffer_.ctx_) { | ||
local_buffer_.ctx_->Wait(); | ||
} | ||
} | ||
|
||
void DoubleBufferReader::ReInit() { | ||
reader_->ReInit(); | ||
buffer_->Close(); | ||
// The existing prefetch thread will terminate for the buffer_ is closed. | ||
buffer_ = framework::MakeChannel<std::vector<framework::LoDTensor>>( | ||
kDoubleBufferSize); | ||
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this); | ||
prefetch.detach(); | ||
start_thread(); | ||
} | ||
|
||
void DoubleBufferReader::PrefetchThreadFunc() { | ||
VLOG(5) << "A new prefetch thread starts."; | ||
while (true) { | ||
std::vector<framework::LoDTensor> batch; | ||
reader_->ReadNext(&batch); | ||
if (batch.empty()) { | ||
// EOF | ||
buffer_->Close(); | ||
VLOG(5) << "Reached the end of the file. The prefetch thread terminates."; | ||
break; | ||
size_t gpu_ctx_offset = 0; | ||
while (reader_->HasNext()) { | ||
Item batch; | ||
reader_->ReadNext(&batch.payloads_); | ||
if (platform::is_gpu_place(place_)) { | ||
std::vector<framework::LoDTensor> gpu_batch; | ||
auto& gpu_ctx = this->ctxs_[gpu_ctx_offset++]; | ||
gpu_ctx_offset %= this->ctxs_.size(); | ||
gpu_batch.resize(batch.payloads_.size()); | ||
for (size_t i = 0; i < batch.payloads_.size(); ++i) { | ||
framework::TensorCopy(batch.payloads_[i], place_, *gpu_ctx, | ||
&gpu_batch[i]); | ||
gpu_batch[i].set_lod(batch.payloads_[i].lod()); | ||
} | ||
batch.ctx_ = gpu_ctx.get(); | ||
std::swap(gpu_batch, batch.payloads_); | ||
} | ||
|
||
if (!buffer_->Send(&batch)) { | ||
VLOG(5) << "WARNING: The double buffer channel has been closed. The " | ||
"prefetch thread terminates."; | ||
break; | ||
} | ||
} | ||
buffer_->Close(); | ||
} | ||
|
||
bool DoubleBufferReader::HasNext() const { PADDLE_THROW("Not Implemented"); } | ||
bool DoubleBufferReader::HasNext() const { | ||
if (local_buffer_.payloads_.empty()) { | ||
bool ok = buffer_->Receive(&local_buffer_); | ||
return ok; | ||
} else { | ||
return true; | ||
} | ||
} | ||
|
||
} // namespace reader | ||
} // namespace operators | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry not related for this PR, just for discussion:
Do we need
ReInit
for ReaderBase? Maybe readers like network reader is a stream of data, can not be reinitialized back to the beginning.What is the use case for
ReInit
? If we don't have a clear use case we probably should drop this method from the base class.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For a network reader, we can just throw an exception.
The reader module is in flux. I just discussed the design doc with @JiayiFeng yesterday. Whole design doc will be written soon, maybe within 1-2 days.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the prompt reply!