Skip to content

Commit

Permalink
finetune loading
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed May 4, 2015
1 parent e14a8a9 commit c9dab91
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 15 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@ example
*.log
ps-lite
dmlc-core
bin
rabit
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ ifneq ($(ADD_LDFLAGS), NONE)
endif

# specify tensor path
BIN = bin/cxxnet bin/im2rec tools/bin2rec
BIN = bin/cxxnet bin/im2rec bin/bin2rec
SLIB = wrapper/libcxxnetwrapper.so
OBJ = layer_cpu.o updater_cpu.o nnet_cpu.o main.o nnet_ps_server.o
OBJCXX11 = data.o
Expand Down Expand Up @@ -114,7 +114,7 @@ wrapper/libcxxnetwrapper.so: wrapper/cxxnet_wrapper.cpp $(OBJ) $(OBJCXX11) $(CUD
bin/cxxnet: src/local_main.cpp $(OBJ) $(OBJCXX11) $(LIB_DEP) $(CUDEP)
bin/cxxnet.ps: $(OBJ) $(OBJCXX11) $(CUDEP) $(LIB_DEP) $(PS_LIB)
bin/im2rec: tools/im2rec.cc $(DMLC_CORE)/libdmlc.a
tools/bin2rec: tools/bin2rec.cc $(DMLC_CORE)/libdmlc.a
bin/bin2rec: tools/bin2rec.cc $(DMLC_CORE)/libdmlc.a

$(BIN) :
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS)
Expand Down
20 changes: 15 additions & 5 deletions src/cxxnet_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,13 +422,15 @@ class CXXNetLearnTask {

inline void TaskTrain(void) {
bool is_root = true;
bool print_tracker = false;
#if MSHADOW_DIST_PS
is_root = ::ps::MyRank() == 0;
silent = !is_root;
#endif

#if MSHADOW_RABIT_PS
is_root = rabit::GetRank() == 0;
print_tracker = rabit::IsDistributed();
#endif
silent = !is_root;
time_t start = time(NULL);
Expand Down Expand Up @@ -468,10 +470,16 @@ class CXXNetLearnTask {
if (++ sample_counter % print_step == 0) {
elapsed = (long)(time(NULL) - start);
if (!silent) {
printf("\r \r");
printf("round %8d:[%8d] %ld sec elapsed", start_counter-1,
sample_counter, elapsed);
fflush(stdout);
std::ostringstream os;
os << "round " << std::setw(8) << start_counter - 1
<< ":[" << std::setw(8) << sample_counter << "] " << elapsed << " sec elapsed";
if (print_tracker) {
utils::TrackerPrint(os.str().c_str());
} else {
printf("\r \r");
printf("%s", os.str().c_str());
fflush(stdout);
}
}
}
}
Expand All @@ -490,7 +498,9 @@ class CXXNetLearnTask {
utils::TrackerPrint(os.str());
}
elapsed = (unsigned long)(time(NULL) - start);
this->SaveModel();
if (is_root) {
this->SaveModel();
}
}

if (!silent) {
Expand Down
2 changes: 1 addition & 1 deletion src/io/iter_augment_proc-inl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class AugmentIterator: public IIterator<DataInst> {
} else {
// substract mean image
if ((rand_mirror_ != 0 && rnd.NextDouble() < 0.5f) || mirror_ == 1) {
if (data.shape_ == meanimg_.shape_){
if (data.shape_ == meanimg_.shape_) {
img_ = mirror(crop((data - meanimg_) * contrast + illumination, img_[0].shape_, yy, xx)) * scale_;
} else {
img_ = (mirror(crop(data, img_[0].shape_, yy, xx) - meanimg_) * contrast + illumination) * scale_;
Expand Down
5 changes: 3 additions & 2 deletions src/io/iter_image_recordio-inl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class ImageRecordIOParser {
int maxthread;
#pragma omp parallel
{
maxthread = std::max(omp_get_num_procs() / 2 - 2, 1);
maxthread = std::max(omp_get_num_procs() / 2 - 1, 1);
}
nthread_ = std::min(maxthread, nthread_);
#pragma omp parallel num_threads(nthread_)
Expand Down Expand Up @@ -184,7 +184,7 @@ inline void ImageRecordIOParser::Init(void) {
(path_imgrec_.c_str(), dist_worker_rank_,
dist_num_worker_, "recordio");
// use 64 MB chunk when possible
source_->HintChunkSize(64 << 20UL);
source_->HintChunkSize(8 << 20UL);
}
inline void ImageRecordIOParser::
SetParam(const char *name, const char *val) {
Expand Down Expand Up @@ -275,6 +275,7 @@ class ImageRecordIOIterator : public IIterator<DataInst> {
}
virtual void Init(void) {
parser_.Init();
iter_.set_max_capacity(4);
iter_.Init([this](std::vector<InstVector> **dptr) {
if (*dptr == NULL) {
*dptr = new std::vector<InstVector>();
Expand Down
9 changes: 5 additions & 4 deletions src/nnet/neural_net-inl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ struct NeuralNet {
this->ConfigConntions();
for (size_t i = 0; i < connections.size(); ++i) {
if (this->cfg.layers[i].name != "") {
printf("Initializing layer: %s\n", this->cfg.layers[i].name.c_str());
utils::TrackerPrintf("Initializing layer: %s\n", this->cfg.layers[i].name.c_str());
} else {
printf("Initializing layer: %d\n", static_cast<int>(i));
utils::TrackerPrintf("Initializing layer: %d\n", static_cast<int>(i));
}
layer::Connection<xpu> &c = connections[i];
c.layer->InitConnection(c.nodes_in, c.nodes_out, &c.state);
Expand Down Expand Up @@ -225,8 +225,9 @@ struct NeuralNet {
for (size_t i = 0; i < nodes.size(); ++ i) {
mshadow::Shape<4> s = nodes[i].data.shape_;
nodes[i].AllocSpace();
printf("node[%s].shape: %u,%u,%u,%u\n", this->cfg.node_names[i].c_str(),
s[0], s[1], s[2], s[3]);
utils::TrackerPrintf("node[%s].shape: %u,%u,%u,%u\n",
this->cfg.node_names[i].c_str(),
s[0], s[1], s[2], s[3]);
}
}
private:
Expand Down
2 changes: 1 addition & 1 deletion src/updater/sgd_updater-inl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class SGDUpdater : public IUpdater<xpu> {
virtual ~SGDUpdater(void) {}
virtual void Init(void) {
if (param.silent == 0) {
printf("SGDUpdater: eta=%f, mom=%f\n", param.base_lr_, param.momentum);
utils::TrackerPrintf("SGDUpdater: eta=%f, mom=%f\n", param.base_lr_, param.momentum);
}
m_w.Resize(w.shape_, 0.0f);
}
Expand Down
12 changes: 12 additions & 0 deletions src/utils/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,18 @@ inline void TrackerPrint(const std::string msg, bool root_only = true) {
fflush(stderr);
#endif
}

/*! \brief portable version of snprintf */
inline void TrackerPrintf(const char *fmt, ...) {
const int kPrintBuffer = 1 << 10;
std::string msg(kPrintBuffer, '\0');
va_list args;
va_start(args, fmt);
vsnprintf(&msg[0], kPrintBuffer, fmt, args);
va_end(args);
msg.resize(strlen(msg.c_str()));
TrackerPrint(msg);
}
} // namespace utils
} // namespace cxxnet
#endif // CXXNET_UTILS_UTILS_H_

0 comments on commit c9dab91

Please sign in to comment.