Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/shuffle reader #8991

Merged
merged 7 commits into from
Mar 14, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 2 additions & 18 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -445,15 +445,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
}

std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
Variable* var = scope_.FindVar(name);
if (var->IsType<ReaderHolder>()) {
return var->Get<ReaderHolder>().shapes();
} else {
PADDLE_THROW(
"Only ReaderHolder support 'GetRepeatedDims', but Variable %s's "
"type_id is %s.",
name, var->Type().name());
}
PADDLE_THROW("Only compile time support this method");
}

void SetDim(const std::string& name, const DDim& dim) override {
Expand All @@ -470,15 +462,7 @@ class RuntimeInferShapeContext : public InferShapeContext {

void SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) override {
Variable* var = scope_.FindVar(name);
if (var->IsType<ReaderHolder>()) {
var->GetMutable<ReaderHolder>()->set_shapes(dims);
} else {
PADDLE_THROW(
"Only ReaderHolder support 'SetRepeatedDims', but Variable %s's "
"type_id is %s.",
name, var->Type().name());
}
PADDLE_THROW("Only compile time support this method");
}

proto::VarType::Type GetVarType(const std::string& name) const override {
Expand Down
22 changes: 15 additions & 7 deletions paddle/fluid/framework/reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,22 @@

namespace paddle {
namespace framework {
ReaderBase::~ReaderBase() {}

DDim ReaderBase::shape(size_t idx) const {
PADDLE_ENFORCE_LT(
idx, shapes_.size(),
"Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx,
shapes_.size());
return shapes_[idx];
}
FileReader::FileReader(const std::vector<DDim> &dims) : dims_(dims) {}

void FileReader::ReadNext(std::vector<LoDTensor> *out) {
ReadNextImpl(out);
PADDLE_ENFORCE_EQ(out->size(), dims_.size());
for (size_t i = 0; i < dims_.size(); ++i) {
auto &actual = out->at(i).dims();
auto &expect = dims_[i];

PADDLE_ENFORCE_EQ(actual.size(), expect.size());
for (int j = 0; j < actual.size(); ++j) {
PADDLE_ENFORCE(actual[i] == expect[i] || expect[i] == -1);
}
}
}
} // namespace framework
} // namespace paddle
51 changes: 20 additions & 31 deletions paddle/fluid/framework/reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,40 +16,29 @@

#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/platform/place.h"

#include <memory>
#include <thread>
#include <vector>

namespace paddle {
namespace framework {

class ReaderBase {
public:
explicit ReaderBase(const std::vector<DDim>& shapes) : shapes_(shapes) {
PADDLE_ENFORCE(!shapes_.empty());
}
virtual void ReadNext(std::vector<LoDTensor>* out) = 0;

virtual void ReInit() = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry not related for this PR, just for discussion:

Do we need ReInit for ReaderBase? Maybe readers like network reader is a stream of data, can not be reinitialized back to the beginning.

What is the use case for ReInit? If we don't have a clear use case we probably should drop this method from the base class.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For a network reader, we can just throw an exception.

The reader module is in flux. I just discussed the design doc with @JiayiFeng yesterday. Whole design doc will be written soon, maybe within 1-2 days.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the prompt reply!


DDim shape(size_t idx) const;
std::vector<DDim> shapes() const { return shapes_; }
void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; }

virtual bool HasNext() const = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do you think of replacing HasNext with pre-defined error code returned from ReadNext (e.g., EOF indicates does not have next).


virtual ~ReaderBase() {}

protected:
std::vector<DDim> shapes_;
};

class FileReader : public ReaderBase {
public:
explicit FileReader(const std::vector<DDim>& shapes) : ReaderBase(shapes) {}
virtual ~ReaderBase();
};

class DecoratedReader : public ReaderBase {
public:
explicit DecoratedReader(ReaderBase* reader)
: ReaderBase(reader->shapes()), reader_(reader) {
explicit DecoratedReader(ReaderBase* reader) : ReaderBase(), reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_);
}

Expand All @@ -61,6 +50,19 @@ class DecoratedReader : public ReaderBase {
ReaderBase* reader_;
};

class FileReader : public ReaderBase {
public:
explicit FileReader(const std::vector<DDim>& dims);

void ReadNext(std::vector<LoDTensor>* out) override;

protected:
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;

private:
std::vector<DDim> dims_;
};

// The ReaderHolder is used as reader' unified wrapper,
// making it easier to access different type reader in Variables.
class ReaderHolder {
Expand All @@ -78,19 +80,6 @@ class ReaderHolder {
reader_->ReInit();
}

DDim shape(size_t idx) const {
PADDLE_ENFORCE_NOT_NULL(reader_);
return reader_->shape(idx);
}
std::vector<DDim> shapes() const {
PADDLE_ENFORCE_NOT_NULL(reader_);
return reader_->shapes();
}
void set_shapes(const std::vector<DDim>& shapes) {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->set_shapes(shapes);
}

bool HasNext() const { return reader_->HasNext(); }

private:
Expand Down
111 changes: 88 additions & 23 deletions paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,31 @@ static constexpr size_t kDoubleBufferSize = 2;

class DoubleBufferReader : public framework::DecoratedReader {
public:
explicit DoubleBufferReader(ReaderBase* reader)
: DecoratedReader(reader),
buffer_(framework::MakeChannel<std::vector<framework::LoDTensor>>(
kDoubleBufferSize)) {
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this);
struct Item {
Item() : ctx_(nullptr) {}

std::vector<framework::LoDTensor> payloads_;
platform::DeviceContext* ctx_;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that this ctx_ will only hold gpu_ctx. So gpu_ctx_ might be a better name.

};

explicit DoubleBufferReader(
ReaderBase* reader, platform::Place target_place = platform::CPUPlace())
: DecoratedReader(reader), place_(target_place) {
for (size_t i = 0; i < kDoubleBufferSize; ++i) {
if (platform::is_gpu_place(place_)) {
#ifdef PADDLE_WITH_CUDA
ctxs_.emplace_back(new platform::CUDADeviceContext(
boost::get<platform::CUDAPlace>(place_)));
#endif
}
}

start_thread();
}

void start_thread() {
buffer_ = framework::MakeChannel<Item>(kDoubleBufferSize);
std::thread prefetch([this] { PrefetchThreadFunc(); });
prefetch.detach();
}

Expand All @@ -42,7 +62,10 @@ class DoubleBufferReader : public framework::DecoratedReader {
private:
void PrefetchThreadFunc();

framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
framework::Channel<Item>* buffer_;
platform::Place place_;
std::vector<std::unique_ptr<platform::DeviceContext>> ctxs_;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The DoubleBufferReader only takes one Place and all ctx would be the same. So why do we need a vector here?

mutable Item local_buffer_;
};

class CreateDoubleBufferReaderOp : public framework::OperatorBase {
Expand All @@ -56,7 +79,20 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
->Get<framework::ReaderHolder>();
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new DoubleBufferReader(underlying_reader.Get()));

auto place_str = Attr<std::string>("place");
platform::Place place;
if (place_str == "CPU") {
place = platform::CPUPlace();
} else {
std::istringstream sin(place_str);
sin.seekg(std::string("CUDA:").size(), std::ios::beg);
size_t num;
sin >> num;
place = platform::CUDAPlace(static_cast<int>(num));
}

out->Reset(new DoubleBufferReader(underlying_reader.Get(), place));
}
};

Expand All @@ -71,44 +107,73 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
It launches another thread to execute the 'underlying reader' asynchronously,
which prevents reading process from blocking subsequent training.
)DOC");
std::unordered_set<std::string> enum_range;
constexpr size_t kMaxCUDADevs = 128;
for (size_t i = 0; i < kMaxCUDADevs; ++i) {
enum_range.insert(string::Sprintf("CUDA:%d", i));
}
enum_range.insert("CPU");
AddAttr<std::string>("place", "The double buffer place, default is CPU")
.SetDefault("CPU")
.InEnum({enum_range});
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't 'CPU' be in this enum?

}
};

void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
out->clear();
buffer_->Receive(out);
if (local_buffer_.payloads_.empty()) {
buffer_->Receive(&local_buffer_);
}

*out = local_buffer_.payloads_;
local_buffer_.payloads_.clear();
if (local_buffer_.ctx_) {
local_buffer_.ctx_->Wait();
}
}

void DoubleBufferReader::ReInit() {
reader_->ReInit();
buffer_->Close();
// The existing prefetch thread will terminate for the buffer_ is closed.
buffer_ = framework::MakeChannel<std::vector<framework::LoDTensor>>(
kDoubleBufferSize);
std::thread prefetch(&DoubleBufferReader::PrefetchThreadFunc, this);
prefetch.detach();
start_thread();
}

void DoubleBufferReader::PrefetchThreadFunc() {
VLOG(5) << "A new prefetch thread starts.";
while (true) {
std::vector<framework::LoDTensor> batch;
reader_->ReadNext(&batch);
if (batch.empty()) {
// EOF
buffer_->Close();
VLOG(5) << "Reached the end of the file. The prefetch thread terminates.";
break;
size_t gpu_ctx_offset = 0;
while (reader_->HasNext()) {
Item batch;
reader_->ReadNext(&batch.payloads_);
if (platform::is_gpu_place(place_)) {
std::vector<framework::LoDTensor> gpu_batch;
auto& gpu_ctx = this->ctxs_[gpu_ctx_offset++];
gpu_ctx_offset %= this->ctxs_.size();
gpu_batch.resize(batch.payloads_.size());
for (size_t i = 0; i < batch.payloads_.size(); ++i) {
framework::TensorCopy(batch.payloads_[i], place_, *gpu_ctx,
&gpu_batch[i]);
gpu_batch[i].set_lod(batch.payloads_[i].lod());
}
batch.ctx_ = gpu_ctx.get();
std::swap(gpu_batch, batch.payloads_);
}

if (!buffer_->Send(&batch)) {
VLOG(5) << "WARNING: The double buffer channel has been closed. The "
"prefetch thread terminates.";
break;
}
}
buffer_->Close();
}

bool DoubleBufferReader::HasNext() const { PADDLE_THROW("Not Implemented"); }
bool DoubleBufferReader::HasNext() const {
if (local_buffer_.payloads_.empty()) {
bool ok = buffer_->Receive(&local_buffer_);
return ok;
} else {
return true;
}
}

} // namespace reader
} // namespace operators
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ namespace operators {
namespace reader {

template <typename T>
class RandomDataGenerator : public framework::FileReader {
class RandomDataGenerator : public framework::ReaderBase {
public:
RandomDataGenerator(const std::vector<framework::DDim>& shapes, float min,
float max)
: FileReader(shapes), min_(min), max_(max) {
: framework::ReaderBase(), min_(min), max_(max), shapes_(shapes) {
PADDLE_ENFORCE_LE(
min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max);
unsigned int seed = std::random_device()();
Expand Down Expand Up @@ -59,6 +59,7 @@ class RandomDataGenerator : public framework::FileReader {
float max_;
std::minstd_rand engine_;
std::uniform_real_distribution<float> dist_;
std::vector<framework::DDim> shapes_;
};

template <typename T>
Expand Down
21 changes: 12 additions & 9 deletions paddle/fluid/operators/reader/create_recordio_file_reader_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,22 @@ namespace operators {
namespace reader {
class RecordIOFileReader : public framework::FileReader {
public:
RecordIOFileReader(const std::string& filename,
const std::vector<framework::DDim>& shapes)
: FileReader(shapes),
explicit RecordIOFileReader(const std::string& filename,
const std::vector<framework::DDim>& dims)
: FileReader(dims),
scanner_(filename),
dev_ctx_(*platform::DeviceContextPool::Instance().Get(
platform::CPUPlace())) {}

void ReadNext(std::vector<framework::LoDTensor>* out) override {
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
}

bool HasNext() const override { return scanner_.HasNext(); }

void ReInit() override { scanner_.Reset(); }

protected:
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
*out = framework::ReadFromRecordIO(scanner_, dev_ctx_);
}

private:
recordio::Scanner scanner_;
const platform::DeviceContext& dev_ctx_;
Expand All @@ -54,12 +55,12 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
int(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
std::string filename = Attr<std::string>("filename");

auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new RecordIOFileReader(filename, shapes));
out->Reset(
new RecordIOFileReader(filename, RestoreShapes(shape_concat, ranks)));
}
};

Expand All @@ -85,3 +86,5 @@ namespace reader = paddle::operators::reader;
REGISTER_FILE_READER_OPERATOR(create_recordio_file_reader,
reader::CreateRecordIOReaderOp,
reader::CreateRecordIOReaderOpMaker);

REGISTER_FILE_READER(recordio, reader::RecordIOFileReader);
Loading