// Patch summary (free-text preamble placed before the first `diff --git` header).
//
// Files touched by this patch:
//   * demo/semantic_role_labeling/.gitignore (new file): ten ignore patterns
//     for *.pyc, train.log, several data/ entries and the output directory.
//   * paddle/gserver/dataproviders/DataProvider.cpp, DoubleBuffer::asyncLoadBatch:
//     the `batchSize_ == 0` polling loop now also ends when `stopping_` is set,
//     and a second `if (stopping_) break;` follows the loop.
//   * paddle/gserver/dataproviders/PyDataProvider2.cpp:
//       - resetImpl: `exit_.store(true)` is moved out of the `if (loadThread_)`
//         branch, so it is set on every reset; `pullCV_.notify_one()` is called
//         right after `callingContexts_.clear()`; `mutexForReset_` is locked
//         before `dataPool_.clear()` and, being declared at function scope,
//         stays locked until resetImpl returns (including across
//         `callingContextCreated_.wait()`); `exit_ = false` is now assigned
//         inside the new thread's lambda before `loadThread()` and once more
//         after the "Reset done" log line.
//       - a new member `std::mutex mutexForReset_` is added.
//       - getNextBatchInternal: locks `mutexForReset_` on entry and returns 0
//         when `exit_` is set, before the `CHECK(poolPtr != nullptr)`.
//   * paddle/gserver/tests/test_PyDataProvider2.cpp: adds
//     TEST(PyDataProvider2, multiThread), which builds a "py2" provider with
//     `set_async_load_data(true)`, then calls reset(), getNextBatch(100, &batch),
//     reset() again, and finally releases the provider.
//
// NOTE(review): the hunk lines below are joined onto a few long physical lines;
//   presumably the line breaks were lost in extraction -- restore them from the
//   original commit before trying to apply this patch.
// NOTE(review): template argument lists look stripped (`std::lock_guard guard`,
//   `std::unique_ptr cache_`, `std::unique_ptr provider`, `std::deque& pool`);
//   presumably `<...>` was lost in extraction -- verify against the original.
// NOTE(review): `notify_one()` wakes a single waiter on `pullCV_`; confirm that
//   at most one thread can be waiting there during a reset.
diff --git a/demo/semantic_role_labeling/.gitignore b/demo/semantic_role_labeling/.gitignore new file mode 100644 index 0000000000000..cd90ca7bbe9be --- /dev/null +++ b/demo/semantic_role_labeling/.gitignore @@ -0,0 +1,10 @@ +*.pyc +train.log +data/feature +data/conll05st-release/ +data/src.dict +data/test.wsj.props +data/test.wsj.seq_pair +data/test.wsj.words +data/tgt.dict +output diff --git a/paddle/gserver/dataproviders/DataProvider.cpp b/paddle/gserver/dataproviders/DataProvider.cpp index 8cefbb30ada46..2cfb5a3a18c8a 100644 --- a/paddle/gserver/dataproviders/DataProvider.cpp +++ b/paddle/gserver/dataproviders/DataProvider.cpp @@ -131,9 +131,10 @@ void DoubleBuffer::asyncLoadBatch() { taskReadySem_.wait(); if (stopping_) break; - while (batchSize_ == 0) { + while (batchSize_ == 0 && !stopping_) { usleep(5); } + if (stopping_) break; do { DataBatch newBatch; diff --git a/paddle/gserver/dataproviders/PyDataProvider2.cpp b/paddle/gserver/dataproviders/PyDataProvider2.cpp index ca8b07af49ca0..90391a7c307d8 100644 --- a/paddle/gserver/dataproviders/PyDataProvider2.cpp +++ b/paddle/gserver/dataproviders/PyDataProvider2.cpp @@ -433,26 +433,34 @@ class PyDataProvider2 : public DataProvider { inline void resetImpl(bool startNewThread) { DBG << "Reseting " << startNewThread; + exit_.store(true); if (loadThread_) { // is loading. 
- exit_.store(true); loadThread_->join(); loadThread_.reset(); } { PyGuard g; callingContexts_.clear(); + this->pullCV_.notify_one(); + } + + std::lock_guard guard(mutexForReset_); + { + PyGuard g; dataPool_.clear(); } poolActualSize_ = 0; - exit_ = false; + if (startNewThread && cache_->reset()) { DBG << "Start new thread."; loadThread_.reset(new std::thread([this] { + exit_ = false; loadThread(); })); callingContextCreated_.wait(); } DBG << "Reset done"; + exit_ = false; } private: @@ -465,6 +473,8 @@ class PyDataProvider2 : public DataProvider { std::condition_variable pullCV_; std::mutex mtx_; + std::mutex mutexForReset_; + ThreadBarrier callingContextCreated_; std::unique_ptr cache_; @@ -529,6 +539,7 @@ class PyDataProvider2 : public DataProvider { * Loading a batch of data. */ int64_t getNextBatchInternal(int64_t size_, DataBatch *batch) { + std::lock_guard guard(mutexForReset_); REGISTER_TIMER("PyDP2.getNextBatchInternal") CHECK_GE(size_, 0); size_t size = (size_t) size_; @@ -554,6 +565,10 @@ class PyDataProvider2 : public DataProvider { } else { // loading from cache. poolPtr = this->cache_->load(); } + if (exit_) { + // PyDataProvider is destructing. 
+ return 0; + } CHECK(poolPtr != nullptr); std::deque& pool = *poolPtr; diff --git a/paddle/gserver/tests/test_PyDataProvider2.cpp b/paddle/gserver/tests/test_PyDataProvider2.cpp index 6bf1e32925121..b9867a728d9b4 100644 --- a/paddle/gserver/tests/test_PyDataProvider2.cpp +++ b/paddle/gserver/tests/test_PyDataProvider2.cpp @@ -353,6 +353,23 @@ TEST(PyDataProvider2, test_check) { } } +TEST(PyDataProvider2, multiThread) { + paddle::DataConfig config; + config.set_type("py2"); + config.set_files(FLAGS_train_list.c_str()); + config.set_load_data_module("test_PyDataProvider2"); + config.set_load_data_object("test_dense_no_seq"); + config.set_async_load_data(true); + + std::unique_ptr provider( + paddle::DataProvider::create(config, false)); + provider->reset(); + paddle::DataBatch batch; + provider->getNextBatch(100, &batch); + provider->reset(); + provider.reset(); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); paddle::initMain(argc, argv);