Skip to content

Commit

Permalink
Changed the way threads are started and stopped
Browse files Browse the repository at this point in the history
- Interrupt the thread before waiting on join
- Provide a method for looping threads to exit on demand
- CHECK if start and stop succeed instead of returning an error
  • Loading branch information
cypof committed May 19, 2015
1 parent 95dcc20 commit 1f987e1
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 27 deletions.
8 changes: 5 additions & 3 deletions include/caffe/internal_thread.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,11 @@ class InternalThread {
* Caffe's thread local state will be initialized using the current
* thread values, e.g. device id, solver index etc. The random seed
* is initialized using caffe_rng_rand.
* Will not return until the internal thread has exited.
*/
bool StartInternalThread();
void StartInternalThread();

/** Will not return until the internal thread has exited. */
bool WaitForInternalThreadToExit();
void StopInternalThread();

bool is_started() const;

Expand All @@ -39,6 +38,9 @@ class InternalThread {
with the code you want your thread to run. */
virtual void InternalThreadEntry() {}

/* Should be tested when running loops to exit when requested. */
bool must_stop();

private:
void entry();

Expand Down
33 changes: 19 additions & 14 deletions src/caffe/internal_thread.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <boost/thread.hpp>
#include <exception>

#include "caffe/internal_thread.hpp"
#include "caffe/util/math_functions.hpp"
Expand All @@ -13,17 +14,22 @@ InternalThread::InternalThread()
}

InternalThread::~InternalThread() {
WaitForInternalThreadToExit();
StopInternalThread();
}

bool InternalThread::is_started() const {
return thread_.get() != NULL && thread_->joinable();
return thread_ && thread_->joinable();
}

bool InternalThread::StartInternalThread() {
if (!WaitForInternalThreadToExit()) {
return false;
}
bool InternalThread::must_stop() {
return thread_ && thread_->interruption_requested();
}

void InternalThread::StartInternalThread() {
// TODO switch to failing once Caffe prefetch thread is persistent.
// Threads should not be started and stopped repeatedly.
// CHECK(!is_started());
StopInternalThread();

#ifndef CPU_ONLY
CUDA_CHECK(cudaGetDevice(&device_));
Expand All @@ -33,10 +39,9 @@ bool InternalThread::StartInternalThread() {

try {
thread_.reset(new boost::thread(&InternalThread::entry, this));
} catch (...) {
return false;
} catch (std::exception& e) {
CHECK(false) << e.what();
}
return true;
}

void InternalThread::entry() {
Expand All @@ -49,16 +54,16 @@ void InternalThread::entry() {
InternalThreadEntry();
}

/** Will not return until the internal thread has exited. */
bool InternalThread::WaitForInternalThreadToExit() {
void InternalThread::StopInternalThread() {
if (is_started()) {
thread_->interrupt();
try {
thread_->join();
} catch (...) {
return false;
} catch (boost::thread_interrupted&) {
} catch (std::exception& e) {
CHECK(false) << e.what();
}
}
return true;
}

} // namespace caffe
4 changes: 2 additions & 2 deletions src/caffe/layers/base_data_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ void BasePrefetchingDataLayer<Dtype>::LayerSetUp(
template <typename Dtype>
void BasePrefetchingDataLayer<Dtype>::CreatePrefetchThread() {
this->data_transformer_->InitRand();
CHECK(StartInternalThread()) << "Thread execution failed";
StartInternalThread();
}

template <typename Dtype>
void BasePrefetchingDataLayer<Dtype>::JoinPrefetchThread() {
CHECK(WaitForInternalThreadToExit()) << "Thread joining failed";
StopInternalThread();
}

template <typename Dtype>
Expand Down
16 changes: 8 additions & 8 deletions src/caffe/test/test_internal_thread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ class InternalThreadTest : public ::testing::Test {};
TEST_F(InternalThreadTest, TestStartAndExit) {
InternalThread thread;
EXPECT_FALSE(thread.is_started());
EXPECT_TRUE(thread.StartInternalThread());
thread.StartInternalThread();
EXPECT_TRUE(thread.is_started());
EXPECT_TRUE(thread.WaitForInternalThreadToExit());
thread.StopInternalThread();
EXPECT_FALSE(thread.is_started());
}

Expand All @@ -35,18 +35,18 @@ class TestThreadB : public InternalThread {
TEST_F(InternalThreadTest, TestRandomSeed) {
TestThreadA t1;
Caffe::set_random_seed(9658361);
EXPECT_TRUE(t1.StartInternalThread());
EXPECT_TRUE(t1.WaitForInternalThreadToExit());
t1.StartInternalThread();
t1.StopInternalThread();

TestThreadA t2;
Caffe::set_random_seed(9658361);
EXPECT_TRUE(t2.StartInternalThread());
EXPECT_TRUE(t2.WaitForInternalThreadToExit());
t2.StartInternalThread();
t2.StopInternalThread();

TestThreadB t3;
Caffe::set_random_seed(3435563);
EXPECT_TRUE(t3.StartInternalThread());
EXPECT_TRUE(t3.WaitForInternalThreadToExit());
t3.StartInternalThread();
t3.StopInternalThread();
}

} // namespace caffe
Expand Down

0 comments on commit 1f987e1

Please sign in to comment.