From be16a012e25a38dae8c559de0c83e0ac4510d1db Mon Sep 17 00:00:00 2001 From: Cyprien Noel Date: Fri, 9 Sep 2016 12:36:51 -0700 Subject: [PATCH] Port multi-GPU to NCCL, add python support --- Makefile | 22 +- Makefile.config.example | 10 + include/caffe/blob.hpp | 1 + include/caffe/common.hpp | 14 +- include/caffe/data_reader.hpp | 82 --- include/caffe/internal_thread.hpp | 4 +- include/caffe/layer.hpp | 43 +- include/caffe/layers/base_data_layer.hpp | 5 +- include/caffe/layers/data_layer.hpp | 7 +- include/caffe/layers/python_layer.hpp | 4 +- include/caffe/net.hpp | 40 +- include/caffe/parallel.hpp | 91 ++-- include/caffe/solver.hpp | 41 +- include/caffe/util/math_functions.hpp | 5 + include/caffe/util/nccl.hpp | 37 ++ python/caffe/__init__.py | 4 +- python/caffe/_caffe.cpp | 95 +++- python/caffe/pycaffe.py | 2 +- python/train.py | 99 ++++ src/caffe/blob.cpp | 6 + src/caffe/common.cpp | 5 +- src/caffe/data_reader.cpp | 119 ----- src/caffe/internal_thread.cpp | 10 +- src/caffe/layer.cpp | 20 - src/caffe/layers/base_data_layer.cpp | 45 +- src/caffe/layers/base_data_layer.cu | 21 +- src/caffe/layers/data_layer.cpp | 80 ++- src/caffe/layers/hdf5_data_layer.cpp | 4 + src/caffe/layers/image_data_layer.cpp | 13 +- src/caffe/layers/window_data_layer.cpp | 8 +- src/caffe/net.cpp | 47 +- src/caffe/parallel.cpp | 484 ++++++++---------- src/caffe/proto/caffe.proto | 9 +- src/caffe/solver.cpp | 44 +- src/caffe/solvers/adagrad_solver.cpp | 1 - src/caffe/solvers/nesterov_solver.cpp | 1 - src/caffe/solvers/sgd_solver.cpp | 4 +- src/caffe/test/test_gradient_based_solver.cpp | 5 +- src/caffe/util/blocking_queue.cpp | 5 - src/caffe/util/db_lmdb.cpp | 2 +- src/caffe/util/math_functions.cu | 20 + tools/caffe.cpp | 5 +- 42 files changed, 766 insertions(+), 798 deletions(-) delete mode 100644 include/caffe/data_reader.hpp create mode 100644 include/caffe/util/nccl.hpp create mode 100644 python/train.py delete mode 100644 src/caffe/data_reader.cpp diff --git a/Makefile b/Makefile index 24894062a6c..bea3c4f186f 100644 --- a/Makefile +++ b/Makefile @@ -328,6 +328,12 @@ ifeq ($(USE_CUDNN), 1) COMMON_FLAGS += -DUSE_CUDNN endif +# NCCL acceleration configuration +ifeq ($(USE_NCCL), 1) + LIBRARIES += nccl + COMMON_FLAGS += -DUSE_NCCL +endif + # configure IO libraries ifeq ($(USE_OPENCV), 1) COMMON_FLAGS += -DUSE_OPENCV @@ -446,9 +452,20 @@ endif py mat py$(PROJECT) mat$(PROJECT) proto runtest \ superclean supercleanlist supercleanfiles warn everything +ifeq ($(CPU_ONLY), 1) +ifeq ($(USE_NCCL), 1) +checks: $(error Cannot define USE_NCCL with CPU_ONLY) +endif +ifeq ($(USE_CUDNN), 1) +checks: $(error Cannot define USE_CUDNN with CPU_ONLY) +endif +endif +checks: + all: lib tools examples -lib: $(STATIC_NAME) $(DYNAMIC_NAME) +lib: checks \ + $(STATIC_NAME) $(DYNAMIC_NAME) everything: $(EVERYTHING_TARGETS) @@ -495,7 +512,8 @@ examples: $(EXAMPLE_BINS) py$(PROJECT): py -py: $(PY$(PROJECT)_SO) $(PROTO_GEN_PY) +py: checks \ + $(PY$(PROJECT)_SO) $(PROTO_GEN_PY) $(PY$(PROJECT)_SO): $(PY$(PROJECT)_SRC) $(PY$(PROJECT)_HXX) | $(DYNAMIC_NAME) @ echo CXX/LD -o $@ $< diff --git a/Makefile.config.example b/Makefile.config.example index 07bed63ae40..13616d66129 100644 --- a/Makefile.config.example +++ b/Makefile.config.example @@ -94,6 +94,16 @@ LIBRARY_DIRS := $(PYTHON_LIB) /usr/local/lib /usr/lib # INCLUDE_DIRS += $(shell brew --prefix)/include # LIBRARY_DIRS += $(shell brew --prefix)/lib +# NCCL acceleration switch (uncomment to build with NCCL) +# E.g. 
setup: +# cd +# git clone https://github.com/NVIDIA/nccl +# cd nccl +# make -j +# USE_NCCL := 1 +# INCLUDE_DIRS += $(HOME)/nccl/src +# LIBRARY_DIRS += $(HOME)/nccl/build/lib + # Uncomment to use `pkg-config` to specify OpenCV library paths. # (Usually not necessary -- OpenCV libraries are normally installed in one of the above $LIBRARY_DIRS.) # USE_PKG_CONFIG := 1 diff --git a/include/caffe/blob.hpp b/include/caffe/blob.hpp index af360ac24bd..2f59471c29e 100644 --- a/include/caffe/blob.hpp +++ b/include/caffe/blob.hpp @@ -220,6 +220,7 @@ class Blob { void set_cpu_data(Dtype* data); const int* gpu_shape() const; const Dtype* gpu_data() const; + void set_gpu_data(Dtype* data); const Dtype* cpu_diff() const; const Dtype* gpu_diff() const; Dtype* mutable_cpu_data(); diff --git a/include/caffe/common.hpp b/include/caffe/common.hpp index 3c6a076ec2f..5156c22a8a4 100644 --- a/include/caffe/common.hpp +++ b/include/caffe/common.hpp @@ -158,11 +158,14 @@ class Caffe { // Search from start_id to the highest possible device ordinal, // return the ordinal of the first available device. static int FindDevice(const int start_id = 0); - // Parallel training info + // Parallel training inline static int solver_count() { return Get().solver_count_; } inline static void set_solver_count(int val) { Get().solver_count_ = val; } - inline static bool root_solver() { return Get().root_solver_; } - inline static void set_root_solver(bool val) { Get().root_solver_ = val; } + inline static int solver_rank() { return Get().solver_rank_; } + inline static void set_solver_rank(int val) { Get().solver_rank_ = val; } + inline static bool multi_process() { return Get().multi_process_; } + inline static void set_multi_process(bool val) { Get().multi_process_ = val; } + inline static bool root_solver() { return Get().solver_rank_ == 0; } protected: #ifndef CPU_ONLY @@ -172,8 +175,11 @@ class Caffe { shared_ptr random_generator_; Brew mode_; + + // Parallel training int solver_count_; - bool root_solver_; + int solver_rank_; + bool multi_process_; private: // The private constructor to avoid duplicate instantiation. diff --git a/include/caffe/data_reader.hpp b/include/caffe/data_reader.hpp deleted file mode 100644 index 8ed5542cb8d..00000000000 --- a/include/caffe/data_reader.hpp +++ /dev/null @@ -1,82 +0,0 @@ -#ifndef CAFFE_DATA_READER_HPP_ -#define CAFFE_DATA_READER_HPP_ - -#include -#include -#include - -#include "caffe/common.hpp" -#include "caffe/internal_thread.hpp" -#include "caffe/util/blocking_queue.hpp" -#include "caffe/util/db.hpp" - -namespace caffe { - -/** - * @brief Reads data from a source to queues available to data layers. - * A single reading thread is created per source, even if multiple solvers - * are running in parallel, e.g. for multi-GPU training. This makes sure - * databases are read sequentially, and that each solver accesses a different - * subset of the database. Data is distributed to solvers in a round-robin - * way to keep parallel training deterministic. 
- */ -class DataReader { - public: - explicit DataReader(const LayerParameter& param); - ~DataReader(); - - inline BlockingQueue& free() const { - return queue_pair_->free_; - } - inline BlockingQueue& full() const { - return queue_pair_->full_; - } - - protected: - // Queue pairs are shared between a body and its readers - class QueuePair { - public: - explicit QueuePair(int size); - ~QueuePair(); - - BlockingQueue free_; - BlockingQueue full_; - - DISABLE_COPY_AND_ASSIGN(QueuePair); - }; - - // A single body is created per source - class Body : public InternalThread { - public: - explicit Body(const LayerParameter& param); - virtual ~Body(); - - protected: - void InternalThreadEntry(); - void read_one(db::Cursor* cursor, QueuePair* qp); - - const LayerParameter param_; - BlockingQueue > new_queue_pairs_; - - friend class DataReader; - - DISABLE_COPY_AND_ASSIGN(Body); - }; - - // A source is uniquely identified by its layer name + path, in case - // the same database is read from two different locations in the net. - static inline string source_key(const LayerParameter& param) { - return param.name() + ":" + param.data_param().source(); - } - - const shared_ptr queue_pair_; - shared_ptr body_; - - static map > bodies_; - -DISABLE_COPY_AND_ASSIGN(DataReader); -}; - -} // namespace caffe - -#endif // CAFFE_DATA_READER_HPP_ diff --git a/include/caffe/internal_thread.hpp b/include/caffe/internal_thread.hpp index 6a8c5a02892..0af2699667d 100644 --- a/include/caffe/internal_thread.hpp +++ b/include/caffe/internal_thread.hpp @@ -42,8 +42,8 @@ class InternalThread { bool must_stop(); private: - void entry(int device, Caffe::Brew mode, int rand_seed, int solver_count, - bool root_solver); + void entry(int device, Caffe::Brew mode, int rand_seed, + int solver_count, int solver_rank, bool multi_process); shared_ptr thread_; }; diff --git a/include/caffe/layer.hpp b/include/caffe/layer.hpp index 10f353f94f9..30dbfd53758 100644 --- a/include/caffe/layer.hpp +++ b/include/caffe/layer.hpp @@ -38,7 +38,7 @@ class Layer { * layer. */ explicit Layer(const LayerParameter& param) - : layer_param_(param), is_shared_(false) { + : layer_param_(param) { // Set phase and copy blobs (if there are any). phase_ = param.phase(); if (layer_param_.blobs_size() > 0) { @@ -66,7 +66,6 @@ class Layer { */ void SetUp(const vector*>& bottom, const vector*>& top) { - InitMutex(); CheckBlobCounts(bottom, top); LayerSetUp(bottom, top); Reshape(bottom, top); @@ -92,30 +91,6 @@ class Layer { virtual void LayerSetUp(const vector*>& bottom, const vector*>& top) {} - /** - * @brief Whether a layer should be shared by multiple nets during data - * parallelism. By default, all layers except for data layers should - * not be shared. data layers should be shared to ensure each worker - * solver access data sequentially during data parallelism. - */ - virtual inline bool ShareInParallel() const { return false; } - - /** @brief Return whether this layer is actually shared by other nets. - * If ShareInParallel() is true and using more than one GPU and the - * net has TRAIN phase, then this function is expected return true. - */ - inline bool IsShared() const { return is_shared_; } - - /** @brief Set whether this layer is actually shared by other nets - * If ShareInParallel() is true and using more than one GPU and the - * net has TRAIN phase, then is_shared should be set true. 
- */ - inline void SetShared(bool is_shared) { - CHECK(ShareInParallel() || !is_shared) - << type() << "Layer does not support sharing."; - is_shared_ = is_shared; - } - /** * @brief Adjust the shapes of top blobs and internal buffers to accommodate * the shapes of the bottom blobs. @@ -428,19 +403,6 @@ class Layer { } private: - /** Whether this layer is actually shared by other nets*/ - bool is_shared_; - - /** The mutex for sequential forward if this layer is shared */ - shared_ptr forward_mutex_; - - /** Initialize forward_mutex_ */ - void InitMutex(); - /** Lock forward_mutex_ if this layer is shared */ - void Lock(); - /** Unlock forward_mutex_ if this layer is shared */ - void Unlock(); - DISABLE_COPY_AND_ASSIGN(Layer); }; // class Layer @@ -450,8 +412,6 @@ class Layer { template inline Dtype Layer::Forward(const vector*>& bottom, const vector*>& top) { - // Lock during forward to ensure sequential forward - Lock(); Dtype loss = 0; Reshape(bottom, top); switch (Caffe::mode()) { @@ -482,7 +442,6 @@ inline Dtype Layer::Forward(const vector*>& bottom, default: LOG(FATAL) << "Unknown caffe mode."; } - Unlock(); return loss; } diff --git a/include/caffe/layers/base_data_layer.hpp b/include/caffe/layers/base_data_layer.hpp index 2c49b73184b..925b019d460 100644 --- a/include/caffe/layers/base_data_layer.hpp +++ b/include/caffe/layers/base_data_layer.hpp @@ -68,15 +68,16 @@ class BasePrefetchingDataLayer : const vector*>& top); // Prefetches batches (asynchronously if to GPU memory) - static const int PREFETCH_COUNT = 3; + static const int PREFETCH_COUNT = 4; // same as proto protected: virtual void InternalThreadEntry(); virtual void load_batch(Batch* batch) = 0; - Batch prefetch_[PREFETCH_COUNT]; + vector > > prefetch_; BlockingQueue*> prefetch_free_; BlockingQueue*> prefetch_full_; + Batch* prefetch_current_; Blob transformed_data_; }; diff --git a/include/caffe/layers/data_layer.hpp b/include/caffe/layers/data_layer.hpp index 6c361791a0c..f64294e0b20 100644 --- a/include/caffe/layers/data_layer.hpp +++ b/include/caffe/layers/data_layer.hpp @@ -4,7 +4,6 @@ #include #include "caffe/blob.hpp" -#include "caffe/data_reader.hpp" #include "caffe/data_transformer.hpp" #include "caffe/internal_thread.hpp" #include "caffe/layer.hpp" @@ -29,9 +28,13 @@ class DataLayer : public BasePrefetchingDataLayer { virtual inline int MaxTopBlobs() const { return 2; } protected: + void Next(); + bool Skip(); virtual void load_batch(Batch* batch); - DataReader reader_; + shared_ptr db_; + shared_ptr cursor_; + uint64_t skip_counter_; }; } // namespace caffe diff --git a/include/caffe/layers/python_layer.hpp b/include/caffe/layers/python_layer.hpp index 66dbbdf13b8..eab277adc6c 100644 --- a/include/caffe/layers/python_layer.hpp +++ b/include/caffe/layers/python_layer.hpp @@ -21,8 +21,8 @@ class PythonLayer : public Layer { // Disallow PythonLayer in MultiGPU training stage, due to GIL issues // Details: https://github.com/BVLC/caffe/issues/2936 if (this->phase_ == TRAIN && Caffe::solver_count() > 1 - && !ShareInParallel()) { - LOG(FATAL) << "PythonLayer is not implemented in Multi-GPU training"; + && !Caffe::root_solver() && !Caffe::multi_process()) { + LOG(FATAL) << "PythonLayer does not support CLI Multi-GPU, use train.py"; } self_.attr("param_str") = bp::str( this->layer_param_.python_param().param_str()); diff --git a/include/caffe/net.hpp b/include/caffe/net.hpp index 493bdf294e2..d3c9306e9cf 100644 --- a/include/caffe/net.hpp +++ b/include/caffe/net.hpp @@ -23,10 +23,9 @@ namespace caffe { template class 
Net { public: - explicit Net(const NetParameter& param, const Net* root_net = NULL); + explicit Net(const NetParameter& param); explicit Net(const string& param_file, Phase phase, - const int level = 0, const vector* stages = NULL, - const Net* root_net = NULL); + const int level = 0, const vector* stages = NULL); virtual ~Net() {} /// @brief Initialize a network with a NetParameter. @@ -228,6 +227,31 @@ class Net { static bool StateMeetsRule(const NetState& state, const NetStateRule& rule, const string& layer_name); + // Invoked at specific points during an iteration + class Callback { + protected: + virtual void run(int layer) = 0; + + template + friend class Net; + }; + const vector& before_forward() const { return before_forward_; } + void add_before_forward(Callback* value) { + before_forward_.push_back(value); + } + const vector& after_forward() const { return after_forward_; } + void add_after_forward(Callback* value) { + after_forward_.push_back(value); + } + const vector& before_backward() const { return before_backward_; } + void add_before_backward(Callback* value) { + before_backward_.push_back(value); + } + const vector& after_backward() const { return after_backward_; } + void add_after_backward(Callback* value) { + after_backward_.push_back(value); + } + protected: // Helpers for Init. /// @brief Append a new top blob to the net. @@ -306,9 +330,13 @@ class Net { size_t memory_used_; /// Whether to compute and display debug info for the net. bool debug_info_; - /// The root net that actually holds the shared layers in data parallelism - const Net* const root_net_; - DISABLE_COPY_AND_ASSIGN(Net); + // Callbacks + vector before_forward_; + vector after_forward_; + vector before_backward_; + vector after_backward_; + +DISABLE_COPY_AND_ASSIGN(Net); }; diff --git a/include/caffe/parallel.hpp b/include/caffe/parallel.hpp index 6c496c884e3..13cef7e3d68 100644 --- a/include/caffe/parallel.hpp +++ b/include/caffe/parallel.hpp @@ -1,8 +1,9 @@ #ifndef CAFFE_PARALLEL_HPP_ #define CAFFE_PARALLEL_HPP_ -#include +#include +#include #include #include "caffe/blob.hpp" @@ -14,6 +15,12 @@ #include "caffe/syncedmem.hpp" #include "caffe/util/blocking_queue.hpp" +#ifdef USE_NCCL + +#include "caffe/util/nccl.hpp" + +#endif + namespace caffe { // Represents a net parameters. Once a net is created, its parameter buffers can @@ -59,58 +66,54 @@ class GPUParams : public Params { using Params::diff_; }; -class DevicePair { - public: - DevicePair(int parent, int device) - : parent_(parent), - device_(device) { - } - inline int parent() { - return parent_; - } - inline int device() { - return device_; - } - - // Group GPUs in pairs, by proximity depending on machine's topology - static void compute(const vector devices, vector* pairs); - - protected: - int parent_; - int device_; -}; - -// Synchronous data parallelism using map-reduce between local GPUs. template -class P2PSync : public GPUParams, public Solver::Callback, - public InternalThread { +class NCCL : public GPUParams, + public Solver::Callback, + public Net::Callback { public: - explicit P2PSync(shared_ptr > root_solver, - P2PSync* parent, const SolverParameter& param); - virtual ~P2PSync(); - - inline const shared_ptr >& solver() const { - return solver_; - } - - void Run(const vector& gpus); - void Prepare(const vector& gpus, - vector > >* syncs); - inline const int initial_iter() const { return initial_iter_; } + /** + * Single process version. 
*/
+  NCCL(shared_ptr<Solver<Dtype> > solver, boost::barrier* barrier);
+  /**
+   * In multi-process settings, first create an NCCL id (new_uid), then
+   * pass it to each process to create connected instances.
+   */
+  NCCL(shared_ptr<Solver<Dtype> > solver, const string& uid);
+  ~NCCL();
+
+  /**
+   * In single-process settings, create instances without uids and
+   * call this to connect them.
+   */
+  static void InitSingleProcess(vector<NCCL<Dtype>*>* nccls);
+
+  static string new_uid();
+
+  /**
+   * Broadcast weights from rank 0 to other solvers.
+   */
+  void Broadcast();
+
+  /**
+   * Single process multi-GPU.
+   */
+  static void Run(shared_ptr<Solver<Dtype> > solver, const vector<int>& gpus);

 protected:
+  void Init();
   void on_start();
+  void run(int layer);  // Net callback
   void on_gradients_ready();
-  void InternalThreadEntry();

+#ifdef USE_NCCL
+  ncclComm_t comm_;
+  cudaStream_t stream_;
+#endif
-  P2PSync<Dtype>* parent_;
-  vector<P2PSync<Dtype>*> children_;
-  BlockingQueue<P2PSync<Dtype>*> queue_;
-  const int initial_iter_;
-  Dtype* parent_grads_;
   shared_ptr<Solver<Dtype> > solver_;
-
+  // May not be necessary, https://github.com/NVIDIA/nccl/issues/37
+  boost::barrier* barrier_;
   using Params<Dtype>::size_;
   using Params<Dtype>::data_;
   using Params<Dtype>::diff_;
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index eafcee32904..6a2b8b9f856 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -6,6 +6,7 @@

 #include "caffe/net.hpp"
 #include "caffe/solver_factory.hpp"
+#include "caffe/util/benchmark.hpp"

 namespace caffe {

@@ -40,9 +41,8 @@ typedef boost::function<SolverAction::Enum()> ActionCallback;
 template <typename Dtype>
 class Solver {
  public:
-  explicit Solver(const SolverParameter& param,
-      const Solver* root_solver = NULL);
-  explicit Solver(const string& param_file, const Solver* root_solver = NULL);
+  explicit Solver(const SolverParameter& param);
+  explicit Solver(const string& param_file);
   void Init(const SolverParameter& param);
   void InitTrainNet();
   void InitTestNets();
@@ -72,7 +72,8 @@ class Solver {
   inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
     return test_nets_;
   }
-  int iter() { return iter_; }
+  int iter() const { return iter_; }
+  void set_iter(int value) { iter_ = value; }

   // Invoked at specific points during an iteration
   class Callback {
@@ -118,10 +119,6 @@ class Solver {
   vector<Dtype> losses_;
   Dtype smoothed_loss_;

-  // The root solver that holds root nets (actually containing shared layers)
-  // in data parallelism
-  const Solver* const root_solver_;
-
   // A function that can be set by a client of the Solver to provide indication
   // that it wants a snapshot saved and/or to exit early.
   ActionCallback action_request_function_;
@@ -129,31 +126,11 @@ class Solver {
   // True iff a request to stop early was received.
   bool requested_early_exit_;

-  DISABLE_COPY_AND_ASSIGN(Solver);
-};
+  // Timing information, handy to tune e.g. the number of GPUs
+  Timer iteration_timer_;
+  float iterations_last_;

-/**
- * @brief Solver that only computes gradients, used as worker
- * for multi-GPU training.
- */ -template -class WorkerSolver : public Solver { - public: - explicit WorkerSolver(const SolverParameter& param, - const Solver* root_solver = NULL) - : Solver(param, root_solver) {} - - protected: - void ApplyUpdate() {} - void SnapshotSolverState(const string& model_filename) { - LOG(FATAL) << "Should not be called on worker solver."; - } - void RestoreSolverStateFromBinaryProto(const string& state_file) { - LOG(FATAL) << "Should not be called on worker solver."; - } - void RestoreSolverStateFromHDF5(const string& state_file) { - LOG(FATAL) << "Should not be called on worker solver."; - } + DISABLE_COPY_AND_ASSIGN(Solver); }; } // namespace caffe diff --git a/include/caffe/util/math_functions.hpp b/include/caffe/util/math_functions.hpp index 6f6d3feeae2..51068fe2b80 100644 --- a/include/caffe/util/math_functions.hpp +++ b/include/caffe/util/math_functions.hpp @@ -185,6 +185,11 @@ void caffe_gpu_add_scalar(const int N, const Dtype alpha, Dtype *X); template void caffe_gpu_scal(const int N, const Dtype alpha, Dtype *X); +#ifndef CPU_ONLY +template +void caffe_gpu_scal(const int N, const Dtype alpha, Dtype* X, cudaStream_t str); +#endif + template void caffe_gpu_add(const int N, const Dtype* a, const Dtype* b, Dtype* y); diff --git a/include/caffe/util/nccl.hpp b/include/caffe/util/nccl.hpp new file mode 100644 index 00000000000..e01fb7451e8 --- /dev/null +++ b/include/caffe/util/nccl.hpp @@ -0,0 +1,37 @@ +#ifndef CAFFE_UTIL_NCCL_H_ +#define CAFFE_UTIL_NCCL_H_ +#ifdef USE_NCCL + +#include + +#include "caffe/common.hpp" + +#define NCCL_CHECK(condition) \ +{ \ + ncclResult_t result = condition; \ + CHECK_EQ(result, ncclSuccess) << " " \ + << ncclGetErrorString(result); \ +} + +namespace caffe { + +namespace nccl { + +template class dataType; + +template<> class dataType { + public: + static const ncclDataType_t type = ncclFloat; +}; +template<> class dataType { + public: + static const ncclDataType_t type = ncclDouble; +}; + +} // namespace nccl + +} // namespace caffe + +#endif // end USE_NCCL + +#endif // CAFFE_UTIL_NCCL_H_ diff --git a/python/caffe/__init__.py b/python/caffe/__init__.py index 35868a403a3..015262090d6 100644 --- a/python/caffe/__init__.py +++ b/python/caffe/__init__.py @@ -1,5 +1,5 @@ -from .pycaffe import Net, SGDSolver, NesterovSolver, AdaGradSolver, RMSPropSolver, AdaDeltaSolver, AdamSolver -from ._caffe import set_mode_cpu, set_mode_gpu, set_device, Layer, get_solver, layer_type_list, set_random_seed +from .pycaffe import Net, SGDSolver, NesterovSolver, AdaGradSolver, RMSPropSolver, AdaDeltaSolver, AdamSolver, NCCL, Timer +from ._caffe import init_log, log, set_mode_cpu, set_mode_gpu, set_device, Layer, get_solver, layer_type_list, solver_count, set_solver_count, solver_rank, set_solver_rank, Layer, get_solver, layer_type_list from ._caffe import __version__ from .proto.caffe_pb2 import TRAIN, TEST from .classifier import Classifier diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp index bdee75acd6c..6353c226617 100644 --- a/python/caffe/_caffe.cpp +++ b/python/caffe/_caffe.cpp @@ -51,7 +51,18 @@ const int NPY_DTYPE = NPY_FLOAT32; void set_mode_cpu() { Caffe::set_mode(Caffe::CPU); } void set_mode_gpu() { Caffe::set_mode(Caffe::GPU); } -void set_random_seed(unsigned int seed) { Caffe::set_random_seed(seed); } +void InitLog(int level) { + FLAGS_logtostderr = 1; + FLAGS_minloglevel = level; + ::google::InitGoogleLogging(""); + ::google::InstallFailureSignalHandler(); +} +void InitLogInfo() { + InitLog(google::INFO); +} +void Log(const string& s) { + 
LOG(INFO) << s; +} // For convenience, check that input files can be opened, and raise an // exception that boost will send to Python if not (caffe could still crash @@ -254,12 +265,12 @@ bp::object BlobVec_add_blob(bp::tuple args, bp::dict kwargs) { } template -class PythonCallback: public Solver::Callback { +class SolverCallback: public Solver::Callback { protected: bp::object on_start_, on_gradients_ready_; public: - PythonCallback(bp::object on_start, bp::object on_gradients_ready) + SolverCallback(bp::object on_start, bp::object on_gradients_ready) : on_start_(on_start), on_gradients_ready_(on_gradients_ready) { } virtual void on_gradients_ready() { on_gradients_ready_(); @@ -271,7 +282,37 @@ class PythonCallback: public Solver::Callback { template void Solver_add_callback(Solver * solver, bp::object on_start, bp::object on_gradients_ready) { - solver->add_callback(new PythonCallback(on_start, on_gradients_ready)); + solver->add_callback(new SolverCallback(on_start, on_gradients_ready)); +} +// Seems boost cannot call the base method directly +void Solver_add_nccl(SGDSolver* solver, NCCL* nccl) { + solver->add_callback(nccl); +} +template +class NetCallback: public Net::Callback { + public: + explicit NetCallback(bp::object run) : run_(run) {} + + protected: + virtual void run(int layer) { + run_(layer); + } + bp::object run_; +}; +void Net_before_forward(Net* net, bp::object run) { + net->add_before_forward(new NetCallback(run)); +} +void Net_after_forward(Net* net, bp::object run) { + net->add_after_forward(new NetCallback(run)); +} +void Net_before_backward(Net* net, bp::object run) { + net->add_before_backward(new NetCallback(run)); +} +void Net_after_backward(Net* net, bp::object run) { + net->add_after_backward(new NetCallback(run)); +} +void Net_add_nccl(Net* net, NCCL* nccl) { + net->add_after_backward(nccl); } BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(SolveOverloads, Solve, 0, 1); @@ -283,10 +324,16 @@ BOOST_PYTHON_MODULE(_caffe) { bp::scope().attr("__version__") = AS_STRING(CAFFE_VERSION); // Caffe utility functions + bp::def("init_log", &InitLog); + bp::def("init_log", &InitLogInfo); + bp::def("log", &Log); bp::def("set_mode_cpu", &set_mode_cpu); bp::def("set_mode_gpu", &set_mode_gpu); - bp::def("set_random_seed", &set_random_seed); bp::def("set_device", &Caffe::SetDevice); + bp::def("solver_count", &Caffe::solver_count); + bp::def("set_solver_count", &Caffe::set_solver_count); + bp::def("solver_rank", &Caffe::solver_rank); + bp::def("set_solver_rank", &Caffe::set_solver_rank); bp::def("layer_type_list", &LayerRegistry::LayerTypeList); @@ -330,7 +377,12 @@ BOOST_PYTHON_MODULE(_caffe) { bp::with_custodian_and_ward<1, 2, bp::with_custodian_and_ward<1, 3> >()) .def("save", &Net_Save) .def("save_hdf5", &Net_SaveHDF5) - .def("load_hdf5", &Net_LoadHDF5); + .def("load_hdf5", &Net_LoadHDF5) + .def("before_forward", &Net_before_forward) + .def("after_forward", &Net_after_forward) + .def("before_backward", &Net_before_backward) + .def("after_backward", &Net_after_backward) + .def("after_backward", &Net_add_nccl); BP_REGISTER_SHARED_PTR_TO_PYTHON(Net); bp::class_, shared_ptr >, boost::noncopyable>( @@ -359,10 +411,18 @@ BOOST_PYTHON_MODULE(_caffe) { bp::return_internal_reference<>())) .def("setup", &Layer::LayerSetUp) .def("reshape", &Layer::Reshape) - .add_property("type", bp::make_function(&Layer::type)); + .add_property("type", bp::make_function(&Layer::type)) + .add_property("layer_param", bp::make_function(&Layer::layer_param, + bp::return_value_policy())); 
BP_REGISTER_SHARED_PTR_TO_PYTHON(Layer); - bp::class_("LayerParameter", bp::no_init); + bp::class_("SolverParameter", bp::init<>()) + .add_property("max_iter", &SolverParameter::max_iter) + .add_property("display", &SolverParameter::display) + .add_property("layer_wise_reduce", &SolverParameter::layer_wise_reduce); + bp::class_("LayerParameter", bp::init<>()) + .add_property("name", bp::make_function(&LayerParameter::name, + bp::return_value_policy())); bp::class_, shared_ptr >, boost::noncopyable>( "Solver", bp::no_init) @@ -371,11 +431,14 @@ BOOST_PYTHON_MODULE(_caffe) { bp::return_internal_reference<>())) .add_property("iter", &Solver::iter) .def("add_callback", &Solver_add_callback) + .def("add_callback", &Solver_add_nccl) .def("solve", static_cast::*)(const char*)>( &Solver::Solve), SolveOverloads()) .def("step", &Solver::Step) .def("restore", &Solver::Restore) - .def("snapshot", &Solver::Snapshot); + .def("snapshot", &Solver::Snapshot) + .add_property("param", bp::make_function(&Solver::param, + bp::return_value_policy())); BP_REGISTER_SHARED_PTR_TO_PYTHON(Solver); bp::class_, bp::bases >, @@ -419,6 +482,20 @@ BOOST_PYTHON_MODULE(_caffe) { bp::class_ >("BoolVec") .def(bp::vector_indexing_suite >()); + bp::class_, shared_ptr >, + boost::noncopyable>("NCCL", + bp::init >, const string&>()) + .def("new_uid", &NCCL::new_uid).staticmethod("new_uid") + .def("bcast", &NCCL::Broadcast); + BP_REGISTER_SHARED_PTR_TO_PYTHON(NCCL); + + bp::class_, boost::noncopyable>( + "Timer", bp::init<>()) + .def("start", &Timer::Start) + .def("stop", &Timer::Stop) + .add_property("ms", &Timer::MilliSeconds); + BP_REGISTER_SHARED_PTR_TO_PYTHON(Timer); + // boost python expects a void (missing) return value, while import_array // returns NULL for python3. import_array1() forces a void return value. import_array1(); diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py index 5bae18d9a4d..18803818fef 100644 --- a/python/caffe/pycaffe.py +++ b/python/caffe/pycaffe.py @@ -11,7 +11,7 @@ import numpy as np from ._caffe import Net, SGDSolver, NesterovSolver, AdaGradSolver, \ - RMSPropSolver, AdaDeltaSolver, AdamSolver + RMSPropSolver, AdaDeltaSolver, AdamSolver, NCCL, Timer import caffe.io import six diff --git a/python/train.py b/python/train.py new file mode 100644 index 00000000000..730dbe70186 --- /dev/null +++ b/python/train.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python +""" +Trains a model using one or more GPUs. 
+""" +from multiprocessing import Process + +import caffe + + +def train( + solver, # solver proto definition + snapshot, # solver snapshot to restore + gpus, # list of device ids + timing=False, # show timing info for compute and communications +): + # NCCL uses a uid to identify a session + uid = caffe.NCCL.new_uid() + + caffe.init_log() + caffe.log('Using devices %s' % str(gpus)) + + procs = [] + for rank in range(len(gpus)): + p = Process(target=solve, + args=(solver, snapshot, gpus, timing, uid, rank)) + p.daemon = True + p.start() + procs.append(p) + for p in procs: + p.join() + + +def time(solver, nccl): + fprop = [] + bprop = [] + total = caffe.Timer() + allrd = caffe.Timer() + for _ in range(len(solver.net.layers)): + fprop.append(caffe.Timer()) + bprop.append(caffe.Timer()) + display = solver.param.display + + def show_time(): + if solver.iter % display == 0: + s = '\n' + for i in range(len(solver.net.layers)): + s += 'forw %3d %8s ' % (i, solver.net.layers[i].layer_param.name) + s += ': %.2f\n' % fprop[i].ms + for i in range(len(solver.net.layers) - 1, -1, -1): + s += 'back %3d %8s ' % (i, solver.net.layers[i].layer_param.name) + s += ': %.2f\n' % bprop[i].ms + s += 'solver total: %.2f\n' % total.ms + s += 'allreduce: %.2f\n' % allrd.ms + caffe.log(s) + + solver.net.before_forward(lambda layer: fprop[layer].start()) + solver.net.after_forward(lambda layer: fprop[layer].stop()) + solver.net.before_backward(lambda layer: bprop[layer].start()) + solver.net.after_backward(lambda layer: bprop[layer].stop()) + solver.add_callback(lambda: total.start(), lambda: (total.stop(), allrd.start())) + solver.add_callback(nccl) + solver.add_callback(lambda: '', lambda: (allrd.stop(), show_time())) + + +def solve(proto, snapshot, gpus, timing, uid, rank): + caffe.set_mode_gpu() + caffe.set_device(gpus[rank]) + caffe.set_solver_count(len(gpus)) + caffe.set_solver_rank(rank) + + solver = caffe.SGDSolver(proto) + if snapshot and len(snapshot) != 0: + solver.restore(snapshot) + + nccl = caffe.NCCL(solver, uid) + nccl.bcast() + + if timing and rank == 0: + time(solver, nccl) + else: + solver.add_callback(nccl) + + if solver.param.layer_wise_reduce: + solver.net.after_backward(nccl) + solver.step(solver.param.max_iter) + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + + parser.add_argument("--solver", required=True, help="Solver proto definition.") + parser.add_argument("--snapshot", help="Solver snapshot to restore.") + parser.add_argument("--gpus", type=int, nargs='+', default=[0], + help="List of device ids.") + parser.add_argument("--timing", action='store_true', help="Show timing info.") + args = parser.parse_args() + + train(args.solver, args.snapshot, args.gpus, args.timing) diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp index 4a34e4c5856..863d940c190 100644 --- a/src/caffe/blob.cpp +++ b/src/caffe/blob.cpp @@ -98,6 +98,12 @@ const Dtype* Blob::gpu_data() const { return (const Dtype*)data_->gpu_data(); } +template +void Blob::set_gpu_data(Dtype* data) { + CHECK(data); + data_->set_gpu_data(data); +} + template const Dtype* Blob::cpu_diff() const { CHECK(diff_); diff --git a/src/caffe/common.cpp b/src/caffe/common.cpp index dee681654aa..1372a9bc6d3 100644 --- a/src/caffe/common.cpp +++ b/src/caffe/common.cpp @@ -53,7 +53,7 @@ void GlobalInit(int* pargc, char*** pargv) { Caffe::Caffe() : random_generator_(), mode_(Caffe::CPU), - solver_count_(1), root_solver_(true) { } + solver_count_(1), solver_rank_(0), multi_process_(false) { } Caffe::~Caffe() { } 
@@ -106,7 +106,8 @@ void* Caffe::RNG::generator() { Caffe::Caffe() : cublas_handle_(NULL), curand_generator_(NULL), random_generator_(), - mode_(Caffe::CPU), solver_count_(1), root_solver_(true) { + mode_(Caffe::CPU), + solver_count_(1), solver_rank_(0), multi_process_(false) { // Try to create a cublas handler, and report an error if failed (but we will // keep the program running as one might just want to run CPU code). if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) { diff --git a/src/caffe/data_reader.cpp b/src/caffe/data_reader.cpp deleted file mode 100644 index 9f019bbfcb7..00000000000 --- a/src/caffe/data_reader.cpp +++ /dev/null @@ -1,119 +0,0 @@ -#include -#include -#include -#include - -#include "caffe/common.hpp" -#include "caffe/data_reader.hpp" -#include "caffe/layers/data_layer.hpp" -#include "caffe/proto/caffe.pb.h" - -namespace caffe { - -using boost::weak_ptr; - -map > DataReader::bodies_; -static boost::mutex bodies_mutex_; - -DataReader::DataReader(const LayerParameter& param) - : queue_pair_(new QueuePair( // - param.data_param().prefetch() * param.data_param().batch_size())) { - // Get or create a body - boost::mutex::scoped_lock lock(bodies_mutex_); - string key = source_key(param); - weak_ptr& weak = bodies_[key]; - body_ = weak.lock(); - if (!body_) { - body_.reset(new Body(param)); - bodies_[key] = weak_ptr(body_); - } - body_->new_queue_pairs_.push(queue_pair_); -} - -DataReader::~DataReader() { - string key = source_key(body_->param_); - body_.reset(); - boost::mutex::scoped_lock lock(bodies_mutex_); - if (bodies_[key].expired()) { - bodies_.erase(key); - } -} - -// - -DataReader::QueuePair::QueuePair(int size) { - // Initialize the free queue with requested number of datums - for (int i = 0; i < size; ++i) { - free_.push(new Datum()); - } -} - -DataReader::QueuePair::~QueuePair() { - Datum* datum; - while (free_.try_pop(&datum)) { - delete datum; - } - while (full_.try_pop(&datum)) { - delete datum; - } -} - -// - -DataReader::Body::Body(const LayerParameter& param) - : param_(param), - new_queue_pairs_() { - StartInternalThread(); -} - -DataReader::Body::~Body() { - StopInternalThread(); -} - -void DataReader::Body::InternalThreadEntry() { - shared_ptr db(db::GetDB(param_.data_param().backend())); - db->Open(param_.data_param().source(), db::READ); - shared_ptr cursor(db->NewCursor()); - vector > qps; - try { - int solver_count = param_.phase() == TRAIN ? Caffe::solver_count() : 1; - - // To ensure deterministic runs, only start running once all solvers - // are ready. But solvers need to peek on one item during initialization, - // so read one item, then wait for the next solver. - for (int i = 0; i < solver_count; ++i) { - shared_ptr qp(new_queue_pairs_.pop()); - read_one(cursor.get(), qp.get()); - qps.push_back(qp); - } - // Main loop - while (!must_stop()) { - for (int i = 0; i < solver_count; ++i) { - read_one(cursor.get(), qps[i].get()); - } - // Check no additional readers have been created. This can happen if - // more than one net is trained at a time per process, whether single - // or multi solver. It might also happen if two data layers have same - // name and same source. - CHECK_EQ(new_queue_pairs_.size(), 0); - } - } catch (boost::thread_interrupted&) { - // Interrupted exception is expected on shutdown - } -} - -void DataReader::Body::read_one(db::Cursor* cursor, QueuePair* qp) { - Datum* datum = qp->free_.pop(); - // TODO deserialize in-place instead of copy? 
- datum->ParseFromString(cursor->value()); - qp->full_.push(datum); - - // go to the next iter - cursor->Next(); - if (!cursor->valid()) { - DLOG(INFO) << "Restarting data prefetching from start."; - cursor->SeekToFirst(); - } -} - -} // namespace caffe diff --git a/src/caffe/internal_thread.cpp b/src/caffe/internal_thread.cpp index 104884e0295..26c09a34bda 100644 --- a/src/caffe/internal_thread.cpp +++ b/src/caffe/internal_thread.cpp @@ -28,25 +28,27 @@ void InternalThread::StartInternalThread() { Caffe::Brew mode = Caffe::mode(); int rand_seed = caffe_rng_rand(); int solver_count = Caffe::solver_count(); - bool root_solver = Caffe::root_solver(); + int solver_rank = Caffe::solver_rank(); + bool multi_process = Caffe::multi_process(); try { thread_.reset(new boost::thread(&InternalThread::entry, this, device, mode, - rand_seed, solver_count, root_solver)); + rand_seed, solver_count, solver_rank, multi_process)); } catch (std::exception& e) { LOG(FATAL) << "Thread exception: " << e.what(); } } void InternalThread::entry(int device, Caffe::Brew mode, int rand_seed, - int solver_count, bool root_solver) { + int solver_count, int solver_rank, bool multi_process) { #ifndef CPU_ONLY CUDA_CHECK(cudaSetDevice(device)); #endif Caffe::set_mode(mode); Caffe::set_random_seed(rand_seed); Caffe::set_solver_count(solver_count); - Caffe::set_root_solver(root_solver); + Caffe::set_solver_rank(solver_rank); + Caffe::set_multi_process(multi_process); InternalThreadEntry(); } diff --git a/src/caffe/layer.cpp b/src/caffe/layer.cpp index 3b9128986ae..684ae88bb49 100644 --- a/src/caffe/layer.cpp +++ b/src/caffe/layer.cpp @@ -1,27 +1,7 @@ -#include #include "caffe/layer.hpp" namespace caffe { -template -void Layer::InitMutex() { - forward_mutex_.reset(new boost::mutex()); -} - -template -void Layer::Lock() { - if (IsShared()) { - forward_mutex_->lock(); - } -} - -template -void Layer::Unlock() { - if (IsShared()) { - forward_mutex_->unlock(); - } -} - INSTANTIATE_CLASS(Layer); } // namespace caffe diff --git a/src/caffe/layers/base_data_layer.cpp b/src/caffe/layers/base_data_layer.cpp index 989319f1a07..9414f6f98b2 100644 --- a/src/caffe/layers/base_data_layer.cpp +++ b/src/caffe/layers/base_data_layer.cpp @@ -36,9 +36,12 @@ template BasePrefetchingDataLayer::BasePrefetchingDataLayer( const LayerParameter& param) : BaseDataLayer(param), - prefetch_free_(), prefetch_full_() { - for (int i = 0; i < PREFETCH_COUNT; ++i) { - prefetch_free_.push(&prefetch_[i]); + prefetch_(param.has_data_param() ? + param.data_param().prefetch() : PREFETCH_COUNT), + prefetch_free_(), prefetch_full_(), prefetch_current_() { + for (int i = 0; i < prefetch_.size(); ++i) { + prefetch_[i].reset(new Batch()); + prefetch_free_.push(prefetch_[i].get()); } } @@ -46,22 +49,23 @@ template void BasePrefetchingDataLayer::LayerSetUp( const vector*>& bottom, const vector*>& top) { BaseDataLayer::LayerSetUp(bottom, top); + // Before starting the prefetch thread, we make cpu_data and gpu_data // calls so that the prefetch thread does not accidentally make simultaneous // cudaMalloc calls when the main thread is running. In some GPUs this // seems to cause failures if we do not so. 
- for (int i = 0; i < PREFETCH_COUNT; ++i) { - prefetch_[i].data_.mutable_cpu_data(); + for (int i = 0; i < prefetch_.size(); ++i) { + prefetch_[i]->data_.mutable_cpu_data(); if (this->output_labels_) { - prefetch_[i].label_.mutable_cpu_data(); + prefetch_[i]->label_.mutable_cpu_data(); } } #ifndef CPU_ONLY if (Caffe::mode() == Caffe::GPU) { - for (int i = 0; i < PREFETCH_COUNT; ++i) { - prefetch_[i].data_.mutable_gpu_data(); + for (int i = 0; i < prefetch_.size(); ++i) { + prefetch_[i]->data_.mutable_gpu_data(); if (this->output_labels_) { - prefetch_[i].label_.mutable_gpu_data(); + prefetch_[i]->label_.mutable_gpu_data(); } } } @@ -88,6 +92,9 @@ void BasePrefetchingDataLayer::InternalThreadEntry() { #ifndef CPU_ONLY if (Caffe::mode() == Caffe::GPU) { batch->data_.data().get()->async_gpu_push(stream); + if (this->output_labels_) { + batch->label_.data().get()->async_gpu_push(stream); + } CUDA_CHECK(cudaStreamSynchronize(stream)); } #endif @@ -106,22 +113,18 @@ void BasePrefetchingDataLayer::InternalThreadEntry() { template void BasePrefetchingDataLayer::Forward_cpu( const vector*>& bottom, const vector*>& top) { - Batch* batch = prefetch_full_.pop("Data layer prefetch queue empty"); + if (prefetch_current_) { + prefetch_free_.push(prefetch_current_); + } + prefetch_current_ = prefetch_full_.pop("Waiting for data"); // Reshape to loaded data. - top[0]->ReshapeLike(batch->data_); - // Copy the data - caffe_copy(batch->data_.count(), batch->data_.cpu_data(), - top[0]->mutable_cpu_data()); - DLOG(INFO) << "Prefetch copied"; + top[0]->ReshapeLike(prefetch_current_->data_); + top[0]->set_cpu_data(prefetch_current_->data_.mutable_cpu_data()); if (this->output_labels_) { // Reshape to loaded labels. - top[1]->ReshapeLike(batch->label_); - // Copy the labels. - caffe_copy(batch->label_.count(), batch->label_.cpu_data(), - top[1]->mutable_cpu_data()); + top[1]->ReshapeLike(prefetch_current_->label_); + top[1]->set_cpu_data(prefetch_current_->label_.mutable_cpu_data()); } - - prefetch_free_.push(batch); } #ifdef CPU_ONLY diff --git a/src/caffe/layers/base_data_layer.cu b/src/caffe/layers/base_data_layer.cu index 4056d36a7b4..64c621a74f1 100644 --- a/src/caffe/layers/base_data_layer.cu +++ b/src/caffe/layers/base_data_layer.cu @@ -7,23 +7,18 @@ namespace caffe { template void BasePrefetchingDataLayer::Forward_gpu( const vector*>& bottom, const vector*>& top) { - Batch* batch = prefetch_full_.pop("Data layer prefetch queue empty"); + if (prefetch_current_) { + prefetch_free_.push(prefetch_current_); + } + prefetch_current_ = prefetch_full_.pop("Waiting for data"); // Reshape to loaded data. - top[0]->ReshapeLike(batch->data_); - // Copy the data - caffe_copy(batch->data_.count(), batch->data_.gpu_data(), - top[0]->mutable_gpu_data()); + top[0]->ReshapeLike(prefetch_current_->data_); + top[0]->set_gpu_data(prefetch_current_->data_.mutable_gpu_data()); if (this->output_labels_) { // Reshape to loaded labels. - top[1]->ReshapeLike(batch->label_); - // Copy the labels. - caffe_copy(batch->label_.count(), batch->label_.gpu_data(), - top[1]->mutable_gpu_data()); + top[1]->ReshapeLike(prefetch_current_->label_); + top[1]->set_gpu_data(prefetch_current_->label_.mutable_gpu_data()); } - // Ensure the copy is synchronous wrt the host, so that the next batch isn't - // copied in meanwhile. 
- CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault)); - prefetch_free_.push(batch); } INSTANTIATE_LAYER_GPU_FORWARD(BasePrefetchingDataLayer); diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp index 66e6301fd45..cea10ef603a 100644 --- a/src/caffe/layers/data_layer.cpp +++ b/src/caffe/layers/data_layer.cpp @@ -14,7 +14,10 @@ namespace caffe { template DataLayer::DataLayer(const LayerParameter& param) : BasePrefetchingDataLayer(param), - reader_(param) { + skip_counter_() { + db_.reset(db::GetDB(param.data_param().backend())); + db_->Open(param.data_param().source(), db::READ); + cursor_.reset(db_->NewCursor()); } template @@ -27,7 +30,8 @@ void DataLayer::DataLayerSetUp(const vector*>& bottom, const vector*>& top) { const int batch_size = this->layer_param_.data_param().batch_size(); // Read a data point, and use it to initialize the top blob. - Datum& datum = *(reader_.full().peek()); + Datum datum; + datum.ParseFromString(cursor_->value()); // Use data_transformer to infer the expected blob shape from datum. vector top_shape = this->data_transformer_->InferBlobShape(datum); @@ -35,22 +39,43 @@ void DataLayer::DataLayerSetUp(const vector*>& bottom, // Reshape top[0] and prefetch_data according to the batch_size. top_shape[0] = batch_size; top[0]->Reshape(top_shape); - for (int i = 0; i < this->PREFETCH_COUNT; ++i) { - this->prefetch_[i].data_.Reshape(top_shape); + for (int i = 0; i < this->prefetch_.size(); ++i) { + this->prefetch_[i]->data_.Reshape(top_shape); } - LOG(INFO) << "output data size: " << top[0]->num() << "," + LOG_IF(INFO, Caffe::root_solver()) + << "output data size: " << top[0]->num() << "," << top[0]->channels() << "," << top[0]->height() << "," << top[0]->width(); // label if (this->output_labels_) { vector label_shape(1, batch_size); top[1]->Reshape(label_shape); - for (int i = 0; i < this->PREFETCH_COUNT; ++i) { - this->prefetch_[i].label_.Reshape(label_shape); + for (int i = 0; i < this->prefetch_.size(); ++i) { + this->prefetch_[i]->label_.Reshape(label_shape); } } } +template +bool DataLayer::Skip() { + int size = Caffe::solver_count(); + int rank = Caffe::solver_rank(); + bool keep = (skip_counter_++ % size) == rank || + // In test mode, only rank 0 runs, so avoid skipping + this->layer_param_.phase() == TEST; + return !keep; +} + +template +void DataLayer::Next() { + cursor_->Next(); + if (!cursor_->valid()) { + LOG_IF(INFO, Caffe::root_solver()) + << "Restarting data prefetching from start."; + cursor_->SeekToFirst(); + } +} + // This function is called on prefetch thread template void DataLayer::load_batch(Batch* batch) { @@ -61,41 +86,42 @@ void DataLayer::load_batch(Batch* batch) { CPUTimer timer; CHECK(batch->data_.count()); CHECK(this->transformed_data_.count()); - - // Reshape according to the first datum of each batch - // on single input batches allows for inputs of varying dimension. const int batch_size = this->layer_param_.data_param().batch_size(); - Datum& datum = *(reader_.full().peek()); - // Use data_transformer to infer the expected blob shape from datum. - vector top_shape = this->data_transformer_->InferBlobShape(datum); - this->transformed_data_.Reshape(top_shape); - // Reshape batch according to the batch_size. 
- top_shape[0] = batch_size; - batch->data_.Reshape(top_shape); - - Dtype* top_data = batch->data_.mutable_cpu_data(); - Dtype* top_label = NULL; // suppress warnings about uninitialized variables - if (this->output_labels_) { - top_label = batch->label_.mutable_cpu_data(); - } + Datum datum; for (int item_id = 0; item_id < batch_size; ++item_id) { timer.Start(); - // get a datum - Datum& datum = *(reader_.full().pop("Waiting for data")); + while (Skip()) { + Next(); + } + datum.ParseFromString(cursor_->value()); read_time += timer.MicroSeconds(); - timer.Start(); + + if (item_id == 0) { + // Reshape according to the first datum of each batch + // on single input batches allows for inputs of varying dimension. + // Use data_transformer to infer the expected blob shape from datum. + vector top_shape = this->data_transformer_->InferBlobShape(datum); + this->transformed_data_.Reshape(top_shape); + // Reshape batch according to the batch_size. + top_shape[0] = batch_size; + batch->data_.Reshape(top_shape); + } + // Apply data transformations (mirror, scale, crop...) + timer.Start(); int offset = batch->data_.offset(item_id); + Dtype* top_data = batch->data_.mutable_cpu_data(); this->transformed_data_.set_cpu_data(top_data + offset); this->data_transformer_->Transform(datum, &(this->transformed_data_)); // Copy label. if (this->output_labels_) { + Dtype* top_label = batch->label_.mutable_cpu_data(); top_label[item_id] = datum.label(); } trans_time += timer.MicroSeconds(); - reader_.free().push(const_cast(&datum)); + Next(); } timer.Stop(); batch_timer.Stop(); diff --git a/src/caffe/layers/hdf5_data_layer.cpp b/src/caffe/layers/hdf5_data_layer.cpp index 2f13dc641df..699b7d771dd 100644 --- a/src/caffe/layers/hdf5_data_layer.cpp +++ b/src/caffe/layers/hdf5_data_layer.cpp @@ -104,6 +104,10 @@ void HDF5DataLayer::LayerSetUp(const vector*>& bottom, // Shuffle if needed. if (this->layer_param_.hdf5_data_param().shuffle()) { std::random_shuffle(file_permutation_.begin(), file_permutation_.end()); + } else { + if (this->phase_ == TRAIN && Caffe::solver_rank() > 0) { + LOG(WARNING) << "Shuffling data is recommended for multi-GPU training"; + } } // Load the first HDF5 file and initialize the line counter. 
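The Skip()/Next() pair introduced in data_layer.cpp above replaces the deleted DataReader: every solver walks the full database cursor but keeps only the records whose running index equals its rank modulo the solver count, so parallel solvers read disjoint subsets deterministically (the TEST phase keeps everything, since only rank 0 runs tests). A standalone Python model of that arithmetic, with hypothetical names, for illustration only:

def partition(records, rank, size, test_phase=False):
    # Mirrors DataLayer<Dtype>::Skip(): rank r keeps records whose
    # index satisfies index % size == r; the TEST phase keeps them all.
    for index, record in enumerate(records):
        if test_phase or index % size == rank:
            yield record

# With 4 solvers, rank 1 sees records 1, 5, 9, ...
assert list(partition(range(12), rank=1, size=4)) == [1, 5, 9]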
diff --git a/src/caffe/layers/image_data_layer.cpp b/src/caffe/layers/image_data_layer.cpp index 7ee7dc40714..ec0fc5b0383 100644 --- a/src/caffe/layers/image_data_layer.cpp +++ b/src/caffe/layers/image_data_layer.cpp @@ -54,6 +54,11 @@ void ImageDataLayer::DataLayerSetUp(const vector*>& bottom, const unsigned int prefetch_rng_seed = caffe_rng_rand(); prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed)); ShuffleImages(); + } else { + if (this->phase_ == TRAIN && Caffe::solver_rank() > 0 && + this->layer_param_.image_data_param().rand_skip() == 0) { + LOG(WARNING) << "Shuffling or skipping recommended for multi-GPU"; + } } LOG(INFO) << "A total of " << lines_.size() << " images."; @@ -77,8 +82,8 @@ void ImageDataLayer::DataLayerSetUp(const vector*>& bottom, const int batch_size = this->layer_param_.image_data_param().batch_size(); CHECK_GT(batch_size, 0) << "Positive batch size required"; top_shape[0] = batch_size; - for (int i = 0; i < this->PREFETCH_COUNT; ++i) { - this->prefetch_[i].data_.Reshape(top_shape); + for (int i = 0; i < this->prefetch_.size(); ++i) { + this->prefetch_[i]->data_.Reshape(top_shape); } top[0]->Reshape(top_shape); @@ -88,8 +93,8 @@ void ImageDataLayer::DataLayerSetUp(const vector*>& bottom, // label vector label_shape(1, batch_size); top[1]->Reshape(label_shape); - for (int i = 0; i < this->PREFETCH_COUNT; ++i) { - this->prefetch_[i].label_.Reshape(label_shape); + for (int i = 0; i < this->prefetch_.size(); ++i) { + this->prefetch_[i]->label_.Reshape(label_shape); } } diff --git a/src/caffe/layers/window_data_layer.cpp b/src/caffe/layers/window_data_layer.cpp index 103dd4b6af8..1bf3760e9fd 100644 --- a/src/caffe/layers/window_data_layer.cpp +++ b/src/caffe/layers/window_data_layer.cpp @@ -173,8 +173,8 @@ void WindowDataLayer::DataLayerSetUp(const vector*>& bottom, CHECK_GT(crop_size, 0); const int batch_size = this->layer_param_.window_data_param().batch_size(); top[0]->Reshape(batch_size, channels, crop_size, crop_size); - for (int i = 0; i < this->PREFETCH_COUNT; ++i) - this->prefetch_[i].data_.Reshape( + for (int i = 0; i < this->prefetch_.size(); ++i) + this->prefetch_[i]->data_.Reshape( batch_size, channels, crop_size, crop_size); LOG(INFO) << "output data size: " << top[0]->num() << "," @@ -183,8 +183,8 @@ void WindowDataLayer::DataLayerSetUp(const vector*>& bottom, // label vector label_shape(1, batch_size); top[1]->Reshape(label_shape); - for (int i = 0; i < this->PREFETCH_COUNT; ++i) { - this->prefetch_[i].label_.Reshape(label_shape); + for (int i = 0; i < this->prefetch_.size(); ++i) { + this->prefetch_[i]->label_.Reshape(label_shape); } // data mean diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp index 644cb7e97ee..aa9e8f2f386 100644 --- a/src/caffe/net.cpp +++ b/src/caffe/net.cpp @@ -22,16 +22,13 @@ namespace caffe { template -Net::Net(const NetParameter& param, const Net* root_net) - : root_net_(root_net) { +Net::Net(const NetParameter& param) { Init(param); } template Net::Net(const string& param_file, Phase phase, - const int level, const vector* stages, - const Net* root_net) - : root_net_(root_net) { + const int level, const vector* stages) { NetParameter param; ReadNetParamsFromTextFileOrDie(param_file, ¶m); // Set phase, stages and level @@ -47,8 +44,6 @@ Net::Net(const string& param_file, Phase phase, template void Net::Init(const NetParameter& in_param) { - CHECK(Caffe::root_solver() || root_net_) - << "root_net_ needs to be set for all non-root solvers"; // Set phase from the state. 
phase_ = in_param.state().phase(); // Filter layers based on their include/exclude rules and @@ -74,9 +69,6 @@ void Net::Init(const NetParameter& in_param) { top_id_vecs_.resize(param.layer_size()); bottom_need_backward_.resize(param.layer_size()); for (int layer_id = 0; layer_id < param.layer_size(); ++layer_id) { - // For non-root solvers, whether this layer is shared from root_net_. - bool share_from_root = !Caffe::root_solver() - && root_net_->layers_[layer_id]->ShareInParallel(); // Inherit phase from net if unset. if (!param.layer(layer_id).has_phase()) { param.mutable_layer(layer_id)->set_phase(phase_); @@ -89,13 +81,7 @@ void Net::Init(const NetParameter& in_param) { << "propagate_down param must be specified " << "either 0 or bottom_size times "; } - if (share_from_root) { - LOG(INFO) << "Sharing layer " << layer_param.name() << " from root net"; - layers_.push_back(root_net_->layers_[layer_id]); - layers_[layer_id]->SetShared(true); - } else { - layers_.push_back(LayerRegistry::CreateLayer(layer_param)); - } + layers_.push_back(LayerRegistry::CreateLayer(layer_param)); layer_names_.push_back(layer_param.name()); LOG_IF(INFO, Caffe::root_solver()) << "Creating Layer " << layer_param.name(); @@ -134,19 +120,7 @@ void Net::Init(const NetParameter& in_param) { } } // After this layer is connected, set it up. - if (share_from_root) { - // Set up size of top blobs using root_net_ - const vector*>& base_top = root_net_->top_vecs_[layer_id]; - const vector*>& this_top = this->top_vecs_[layer_id]; - for (int top_id = 0; top_id < base_top.size(); ++top_id) { - this_top[top_id]->ReshapeLike(*base_top[top_id]); - LOG(INFO) << "Created top blob " << top_id << " (shape: " - << this_top[top_id]->shape_string() << ") for shared layer " - << layer_param.name(); - } - } else { - layers_[layer_id]->SetUp(bottom_vecs_[layer_id], top_vecs_[layer_id]); - } + layers_[layer_id]->SetUp(bottom_vecs_[layer_id], top_vecs_[layer_id]); LOG_IF(INFO, Caffe::root_solver()) << "Setting up " << layer_names_[layer_id]; for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) { @@ -546,10 +520,15 @@ Dtype Net::ForwardFromTo(int start, int end) { CHECK_LT(end, layers_.size()); Dtype loss = 0; for (int i = start; i <= end; ++i) { - // LOG(ERROR) << "Forwarding " << layer_names_[i]; + for (int c = 0; c < before_forward_.size(); ++c) { + before_forward_[c]->run(i); + } Dtype layer_loss = layers_[i]->Forward(bottom_vecs_[i], top_vecs_[i]); loss += layer_loss; if (debug_info_) { ForwardDebugInfo(i); } + for (int c = 0; c < after_forward_.size(); ++c) { + after_forward_[c]->run(i); + } } return loss; } @@ -591,11 +570,17 @@ void Net::BackwardFromTo(int start, int end) { CHECK_GE(end, 0); CHECK_LT(start, layers_.size()); for (int i = start; i >= end; --i) { + for (int c = 0; c < before_backward_.size(); ++c) { + before_backward_[c]->run(i); + } if (layer_need_backward_[i]) { layers_[i]->Backward( top_vecs_[i], bottom_need_backward_[i], bottom_vecs_[i]); if (debug_info_) { BackwardDebugInfo(i); } } + for (int c = 0; c < after_backward_.size(); ++c) { + after_backward_[c]->run(i); + } } } diff --git a/src/caffe/parallel.cpp b/src/caffe/parallel.cpp index 5bc41c6a6e5..6a402d29475 100644 --- a/src/caffe/parallel.cpp +++ b/src/caffe/parallel.cpp @@ -1,6 +1,9 @@ #ifndef CPU_ONLY + #include + #endif + #include #include @@ -8,7 +11,12 @@ #include #include -#include "boost/thread.hpp" +#ifdef USE_NCCL + +#include + +#endif + #include "caffe/caffe.hpp" #include "caffe/parallel.hpp" @@ -68,14 +76,14 @@ static size_t 
diff --git a/src/caffe/parallel.cpp b/src/caffe/parallel.cpp
index 5bc41c6a6e5..6a402d29475 100644
--- a/src/caffe/parallel.cpp
+++ b/src/caffe/parallel.cpp
@@ -1,6 +1,9 @@
 #ifndef CPU_ONLY
+
 #include <cuda_runtime.h>
+
 #endif
+
 #include <glog/logging.h>
 #include <stdio.h>
@@ -8,7 +11,12 @@
 #include <string>
 #include <vector>
 
-#include "boost/thread.hpp"
+#ifdef USE_NCCL
+
+#include <nccl.h>
+
+#endif
+
 #include "caffe/caffe.hpp"
 #include "caffe/parallel.hpp"
@@ -68,14 +76,14 @@ static size_t total_size(const vector<Blob<Dtype>*>& params) {
 template<typename Dtype>
 Params<Dtype>::Params(shared_ptr<Solver<Dtype> > root_solver)
-  : size_(total_size<Dtype>(root_solver->net()->learnable_params())),
-    data_(),
-    diff_() {
+    : size_(total_size<Dtype>(root_solver->net()->learnable_params())),
+      data_(),
+      diff_() {
 }
 
 template<typename Dtype>
 GPUParams<Dtype>::GPUParams(shared_ptr<Solver<Dtype> > root_solver, int device)
-  : Params<Dtype>(root_solver) {
+    : Params<Dtype>(root_solver) {
 #ifndef CPU_ONLY
   int initial_device;
   CUDA_CHECK(cudaGetDevice(&initial_device));
@@ -86,7 +94,7 @@ GPUParams<Dtype>::GPUParams(shared_ptr<Solver<Dtype> > root_solver, int device)
 
   // Copy blob values
   const vector<Blob<Dtype>*>& net =
-    root_solver->net()->learnable_params();
+      root_solver->net()->learnable_params();
   apply_buffers(net, data_, size_, copy);
 
   CUDA_CHECK(cudaMalloc(&diff_, size_ * sizeof(Dtype)));
@@ -109,335 +117,259 @@ GPUParams<Dtype>::~GPUParams() {
 template<typename Dtype>
 void GPUParams<Dtype>::configure(Solver<Dtype>* solver) const {
   const vector<Blob<Dtype>*>& net =
-    solver->net()->learnable_params();
+      solver->net()->learnable_params();
   apply_buffers(net, data_, size_, replace_gpu);
   apply_buffers(net, diff_, size_, replace_gpu_diff);
 }
 
-void DevicePair::compute(const vector<int> devices, vector<DevicePair>* pairs) {
+static int getDevice() {
+  int device = 0;
 #ifndef CPU_ONLY
-  vector<int> remaining(devices);
-
-  // Depth for reduction tree
-  int remaining_depth = static_cast<int>(ceil(log2(remaining.size())));
-
-  // Group GPUs by board
-  for (int d = 0; d < remaining_depth; ++d) {
-    for (int i = 0; i < remaining.size(); ++i) {
-      for (int j = i + 1; j < remaining.size(); ++j) {
-        cudaDeviceProp a, b;
-        CUDA_CHECK(cudaGetDeviceProperties(&a, remaining[i]));
-        CUDA_CHECK(cudaGetDeviceProperties(&b, remaining[j]));
-        if (a.isMultiGpuBoard && b.isMultiGpuBoard) {
-          if (a.multiGpuBoardGroupID == b.multiGpuBoardGroupID) {
-            pairs->push_back(DevicePair(remaining[i], remaining[j]));
-            DLOG(INFO) << "GPU board: " << remaining[i] << ":" << remaining[j];
-            remaining.erase(remaining.begin() + j);
-            break;
-          }
-        }
-      }
-    }
-  }
-  ostringstream s;
-  for (int i = 0; i < remaining.size(); ++i) {
-    s << (i ? ", " : "") << remaining[i];
-  }
-  DLOG(INFO) << "GPUs paired by boards, remaining: " << s.str();
-
-  // Group by P2P accessibility
-  remaining_depth = ceil(log2(remaining.size()));
-  for (int d = 0; d < remaining_depth; ++d) {
-    for (int i = 0; i < remaining.size(); ++i) {
-      for (int j = i + 1; j < remaining.size(); ++j) {
-        int access;
-        CUDA_CHECK(
-            cudaDeviceCanAccessPeer(&access, remaining[i], remaining[j]));
-        if (access) {
-          pairs->push_back(DevicePair(remaining[i], remaining[j]));
-          DLOG(INFO) << "P2P pair: " << remaining[i] << ":" << remaining[j];
-          remaining.erase(remaining.begin() + j);
-          break;
-        }
-      }
-    }
-  }
-  s.str("");
-  for (int i = 0; i < remaining.size(); ++i) {
-    s << (i ? ", " : "") << remaining[i];
-  }
-  DLOG(INFO) << "GPUs paired by P2P access, remaining: " << s.str();
-
-  // Group remaining
-  remaining_depth = ceil(log2(remaining.size()));
-  for (int d = 0; d < remaining_depth; ++d) {
-    for (int i = 0; i < remaining.size(); ++i) {
-      pairs->push_back(DevicePair(remaining[i], remaining[i + 1]));
-      DLOG(INFO) << "Remaining pair: " << remaining[i] << ":"
-                 << remaining[i + 1];
-      remaining.erase(remaining.begin() + i + 1);
-    }
-  }
-
-  // Should only be the parent node remaining
-  CHECK_EQ(remaining.size(), 1);
-
-  pairs->insert(pairs->begin(), DevicePair(-1, remaining[0]));
-
-  CHECK(pairs->size() == devices.size());
-  for (int i = 0; i < pairs->size(); ++i) {
-    CHECK((*pairs)[i].parent() != (*pairs)[i].device());
-    for (int j = i + 1; j < pairs->size(); ++j) {
-      CHECK((*pairs)[i].device() != (*pairs)[j].device());
-    }
-  }
-#else
-  NO_GPU;
+  CUDA_CHECK(cudaGetDevice(&device));
 #endif
+  return device;
 }
 
-//
+template<typename Dtype>
+NCCL<Dtype>::NCCL(shared_ptr<Solver<Dtype> > solver, boost::barrier* barrier)
+    : GPUParams<Dtype>(solver, getDevice()),
+#ifdef USE_NCCL
+      comm_(),
+#endif
+      solver_(solver), barrier_(barrier) {
+  CHECK(barrier);
+  this->configure(solver.get());
+  Init();
+}
 
 template<typename Dtype>
-P2PSync<Dtype>::P2PSync(shared_ptr<Solver<Dtype> > root_solver,
-                        P2PSync<Dtype>* parent, const SolverParameter& param)
-    : GPUParams<Dtype>(root_solver, param.device_id()),
-      parent_(parent),
-      children_(),
-      queue_(),
-      initial_iter_(root_solver->iter()),
-      solver_() {
-#ifndef CPU_ONLY
-  int initial_device;
-  CUDA_CHECK(cudaGetDevice(&initial_device));
-  const int self = param.device_id();
-  CUDA_CHECK(cudaSetDevice(self));
+NCCL<Dtype>::NCCL(shared_ptr<Solver<Dtype> > solver, const string& uid)
+    : GPUParams<Dtype>(solver, getDevice()),
+      solver_(solver), barrier_() {
+  this->configure(solver.get());
+
+#ifdef USE_NCCL
+  ncclUniqueId nccl_uid;
+  memcpy(&nccl_uid, &uid[0], NCCL_UNIQUE_ID_BYTES);  // NOLINT(caffe/alt_fn)
+  NCCL_CHECK(ncclCommInitRank(&comm_,
+                              Caffe::solver_count(),
+                              nccl_uid,
+                              Caffe::solver_rank()));
+#endif
+  Init();
+}
 
-  if (parent == NULL) {
-    solver_ = root_solver;
-  } else {
-    Caffe::set_root_solver(false);
-    solver_.reset(new WorkerSolver<Dtype>(param, root_solver.get()));
-    Caffe::set_root_solver(true);
-  }
-  this->configure(solver_.get());
+template<typename Dtype>
+void NCCL<Dtype>::Init() {
+#ifdef USE_NCCL
   solver_->add_callback(this);
-
-  if (parent) {
-    // Enable p2p access between devices
-    const int peer = parent->solver_->param().device_id();
-    int access;
-    CUDA_CHECK(cudaDeviceCanAccessPeer(&access, self, peer));
-    if (access) {
-      CUDA_CHECK(cudaDeviceEnablePeerAccess(peer, 0));
-    } else {
-      LOG(INFO)<< "GPU " << self << " does not have p2p access to GPU " << peer;
-    }
-    // Allocate receiving buffer on parent
-    CUDA_CHECK(cudaSetDevice(peer));
-    CUDA_CHECK(cudaMalloc(&parent_grads_, size_ * sizeof(Dtype)));
-    CUDA_CHECK(cudaSetDevice(self));
+  if (solver_->param().layer_wise_reduce()) {
+    solver_->net()->add_after_backward(this);
+    CUDA_CHECK(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking));
   }
-
-  CUDA_CHECK(cudaSetDevice(initial_device));
 #else
-  NO_GPU;
+  LOG(FATAL) << "Multi-GPU execution not available - rebuild with USE_NCCL";
 #endif
 }
 
 template<typename Dtype>
-P2PSync<Dtype>::~P2PSync() {
-#ifndef CPU_ONLY
-  int initial_device;
-  CUDA_CHECK(cudaGetDevice(&initial_device));
-  const int self = solver_->param().device_id();
-  CUDA_CHECK(cudaSetDevice(self));
-
-  if (parent_) {
-    CUDA_CHECK(cudaFree(parent_grads_));
-    const int peer = parent_->solver_->param().device_id();
-    int access;
-    CUDA_CHECK(cudaDeviceCanAccessPeer(&access, self, peer));
-    if (access) {
-      CUDA_CHECK(cudaDeviceDisablePeerAccess(peer));
-    }
+NCCL<Dtype>::~NCCL() {
+#ifdef USE_NCCL
+  if (solver_->param().layer_wise_reduce()) {
+    CUDA_CHECK(cudaStreamDestroy(stream_));
+  }
+  if (comm_) {
+    ncclCommDestroy(comm_);
   }
-
-  CUDA_CHECK(cudaSetDevice(initial_device));
 #endif
 }
 
 template<typename Dtype>
-void P2PSync<Dtype>::InternalThreadEntry() {
-  Caffe::SetDevice(solver_->param().device_id());
-  CHECK(Caffe::root_solver());
-  Caffe::set_root_solver(false);
-  // See if there is a defined seed and reset random state if so
-  if (solver_->param().random_seed() >= 0) {
-    // Fetch random seed and modulate by device ID to make sure
-    // everyone doesn't have the same seed. We seem to have some
-    // solver instability if we have everyone with the same seed
-    Caffe::set_random_seed(
-        solver_->param().random_seed() + solver_->param().device_id());
+void NCCL<Dtype>::InitSingleProcess(vector<NCCL<Dtype>*>* nccls) {
+#ifdef USE_NCCL
+  ncclComm_t* comms = new ncclComm_t[nccls->size()];
+  int* gpu_list = new int[nccls->size()];
+  for (int i = 0; i < nccls->size(); ++i) {
+    gpu_list[i] = (*nccls)[i]->solver_->param().device_id();
   }
-  solver_->Step(solver_->param().max_iter() - initial_iter_);
+  NCCL_CHECK(ncclCommInitAll(comms, static_cast<int>(nccls->size()),
+                             gpu_list));
+  for (int i = 0; i < nccls->size(); ++i) {
+    (*nccls)[i]->comm_ = comms[i];
+  }
+#endif
 }
 
 template<typename Dtype>
-void P2PSync<Dtype>::on_start() {
-#ifndef CPU_ONLY
-#ifdef DEBUG
-  int device;
-  CUDA_CHECK(cudaGetDevice(&device));
-  CHECK(device == solver_->param().device_id());
-#else
-//  CHECK(false);
+string NCCL<Dtype>::new_uid() {
+  string uid;
+#ifdef USE_NCCL
+  uid.resize(NCCL_UNIQUE_ID_BYTES);
+  ncclUniqueId nccl_uid;
+  NCCL_CHECK(ncclGetUniqueId(&nccl_uid));
+  memcpy(&uid[0], &nccl_uid, NCCL_UNIQUE_ID_BYTES);  // NOLINT(caffe/alt_fn)
 #endif
+  return uid;
+}
 
-  // Wait for update from parent
-  if (parent_) {
-    P2PSync<Dtype> *parent = queue_.pop();
-    CHECK(parent == parent_);
-  }
-
-  // Update children
-  for (int i = children_.size() - 1; i >= 0; i--) {
-    Dtype* src = data_;
-    Dtype* dst = children_[i]->data_;
-
-#ifdef DEBUG
-    cudaPointerAttributes attributes;
-    CUDA_CHECK(cudaPointerGetAttributes(&attributes, src));
-    CHECK(attributes.device == device);
-    CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst));
-    CHECK(attributes.device == children_[i]->solver_->param().device_id());
+template<typename Dtype>
+void NCCL<Dtype>::Broadcast() {
+#ifdef USE_NCCL
+  NCCL_CHECK(ncclBcast(data_, static_cast<int>(size_),
+                       nccl::dataType<Dtype>::type, 0,
+                       comm_, cudaStreamDefault));
 #endif
+}
 
-    CUDA_CHECK(cudaMemcpyAsync(dst, src, size_ * sizeof(Dtype),
-        cudaMemcpyDeviceToDevice, cudaStreamDefault));
-    CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault));
-    children_[i]->queue_.push(this);
+template<typename Dtype>
+void NCCL<Dtype>::on_start() {
+  if (barrier_) {  // NULL in multi process case
+    barrier_->wait();
   }
-#endif
 }
 
 template<typename Dtype>
-void P2PSync<Dtype>::on_gradients_ready() {
-#ifndef CPU_ONLY
+void NCCL<Dtype>::run(int layer) {
+  CHECK(solver_->param().layer_wise_reduce());
+#ifdef USE_NCCL
+  vector<shared_ptr<Blob<Dtype> > >& blobs =
+      solver_->net()->layers()[layer]->blobs();
 #ifdef DEBUG
-  int device;
-  CUDA_CHECK(cudaGetDevice(&device));
-  CHECK(device == solver_->param().device_id());
+  // Assert blobs are contiguous to reduce in one step (e.g. bias often small)
+  for (int i = 1; i < blobs.size(); ++i) {
+    CHECK_EQ(blobs[i - 1]->gpu_diff() + blobs[i - 1]->count(),
+             blobs[i + 0]->gpu_diff());
+  }
 #endif
+  if (blobs.size() > 0) {
+    // Make sure the default stream is done computing gradients.
+    // Could be replaced by cudaEventRecord+cudaStreamWaitEvent
+    // to avoid blocking, but it's actually slower.
+    CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault));
 
-  // Sum children gradients as they appear in the queue
-  for (int i = 0; i < children_.size(); ++i) {
-    P2PSync<Dtype> *child = queue_.pop();
-    Dtype* src = child->parent_grads_;
-    Dtype* dst = diff_;
-
-#ifdef DEBUG
-    bool ok = false;
-    for (int j = 0; j < children_.size(); ++j) {
-      if (child == children_[j]) {
-        ok = true;
-      }
+    // Reduce asynchronously
+    int size = 0;
+    for (int i = 0; i < blobs.size(); ++i) {
+      size += blobs[i]->count();
     }
-    CHECK(ok);
-    cudaPointerAttributes attributes;
-    CUDA_CHECK(cudaPointerGetAttributes(&attributes, src));
-    CHECK(attributes.device == device);
-    CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst));
-    CHECK(attributes.device == device);
-#endif
-
-    caffe_gpu_add(size_, src, dst, dst);
+    NCCL_CHECK(ncclAllReduce(blobs[0]->mutable_gpu_diff(),
+                             blobs[0]->mutable_gpu_diff(),
+                             size,
+                             nccl::dataType<Dtype>::type,
+                             ncclSum, comm_, stream_));
+    caffe_gpu_scal(size, (Dtype) 1.0 / Caffe::solver_count(),
+                   blobs[0]->mutable_gpu_diff(), stream_);
   }
-
-  // Send gradients to parent
-  if (parent_) {
-    Dtype* src = diff_;
-    Dtype* dst = parent_grads_;
-
-#ifdef DEBUG
-    cudaPointerAttributes attributes;
-    CUDA_CHECK(cudaPointerGetAttributes(&attributes, src));
-    CHECK(attributes.device == device);
-    CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst));
-    CHECK(attributes.device == parent_->solver_->param().device_id());
 #endif
+}
 
-    CUDA_CHECK(cudaMemcpyAsync(dst, src, size_ * sizeof(Dtype),  //
-        cudaMemcpyDeviceToDevice, cudaStreamDefault));
-    CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault));
-    parent_->queue_.push(this);
+template<typename Dtype>
+void NCCL<Dtype>::on_gradients_ready() {
+#ifdef USE_NCCL
+  if (solver_->param().layer_wise_reduce()) {
+    CHECK_EQ(solver_->net()->params().size(),
+             solver_->net()->learnable_params().size())
+        << "Layer-wise reduce is not supported for nets with shared weights.";
+
+    // Make sure reduction is done before applying gradients
+    CUDA_CHECK(cudaStreamSynchronize(stream_));
   } else {
-    // Loss functions divide gradients by the batch size, so to compensate
-    // for split batch, the root solver divides by number of solvers.
-    caffe_gpu_scal(size_, Dtype(1.0 / Caffe::solver_count()), diff_);
+    NCCL_CHECK(ncclAllReduce(diff_, diff_, static_cast<int>(size_),
+                             nccl::dataType<Dtype>::type, ncclSum, comm_,
+                             cudaStreamDefault));
+    caffe_gpu_scal(static_cast<int>(size_),
+                   (Dtype) 1.0 / Caffe::solver_count(), diff_);
   }
 #endif
 }
 
+#ifdef USE_NCCL
 template<typename Dtype>
-void P2PSync<Dtype>::Prepare(const vector<int>& gpus,
-                             vector<shared_ptr<P2PSync<Dtype> > >* syncs) {
-  // Pair devices for map-reduce synchronization
-  vector<DevicePair> pairs;
-  DevicePair::compute(gpus, &pairs);
-  ostringstream s;
-  for (int i = 1; i < pairs.size(); ++i) {
-    s << (i == 1 ? "" : ", ") << pairs[i].parent() << ":" << pairs[i].device();
-  }
-  LOG(INFO)<< "GPUs pairs " << s.str();
-
-  SolverParameter param(solver_->param());
-
-  // Build the GPU tree by finding the parent for each solver
-  for (int attempts = 0; attempts < pairs.size(); ++attempts) {
-    for (int i = 1; i < pairs.size(); ++i) {
-      if (!syncs->at(i).get()) {
-        P2PSync<Dtype>* parent = NULL;
-        for (int j = 0; j < syncs->size(); ++j) {
-          P2PSync<Dtype>* sync = j == 0 ? this : syncs->at(j).get();
-          if (sync) {
-            const SolverParameter& p = sync->solver()->param();
-            if (p.device_id() == pairs[i].parent()) {
-              parent = sync;
-            }
-          }
-        }
-        if (parent) {
-          param.set_device_id(pairs[i].device());
-          syncs->at(i).reset(new P2PSync<Dtype>(solver_, parent, param));
-          parent->children_.push_back((P2PSync<Dtype>*) syncs->at(i).get());
-        }
-      }
+class Worker : public InternalThread {
+ public:
+  explicit Worker(SolverParameter params, int start_iter,
+                  boost::barrier* barrier, vector<NCCL<Dtype>*>* nccls)
+      : params_(params), start_iter_(start_iter),
+        barrier_(barrier), nccls_(nccls) {
+  }
+  virtual ~Worker() {}
+
+ protected:
+  void InternalThreadEntry() {
+    // Create solver and install callbacks
+    shared_ptr<Solver<Dtype> > s(SolverRegistry<Dtype>::CreateSolver(params_));
+    s->set_iter(start_iter_);
+    NCCL<Dtype> nccl(s, barrier_);
+    s->add_callback(&nccl);
+    if (s->param().layer_wise_reduce()) {
+      s->net()->add_after_backward(&nccl);
     }
+    (*nccls_)[Caffe::solver_rank()] = &nccl;
+
+    // Wait for other threads
+    barrier_->wait();
+    // Wait for NCCL init
+    barrier_->wait();
+    // Broadcast rank 0 weights
+    nccl.Broadcast();
+    // Solve
+    s->Step(params_.max_iter() - start_iter_);
   }
-}
 
-template<typename Dtype>
-void P2PSync<Dtype>::Run(const vector<int>& gpus) {
-  vector<shared_ptr<P2PSync<Dtype> > > syncs(gpus.size());
-  Prepare(gpus, &syncs);
+  SolverParameter params_;
+  int start_iter_;
+  boost::barrier* barrier_;
+  vector<NCCL<Dtype>*>* nccls_;
+};
+#endif
 
-  LOG(INFO)<< "Starting Optimization";
+template<typename Dtype>
+void NCCL<Dtype>::Run(shared_ptr<Solver<Dtype> > solver,
+                      const vector<int>& gpus) {
+#ifdef USE_NCCL
+  boost::barrier barrier(static_cast<int>(gpus.size()));
+  vector<NCCL<Dtype>*> nccls(gpus.size());
 
-  for (int i = 1; i < syncs.size(); ++i) {
-    syncs[i]->StartInternalThread();
+  // Create workers
+  vector<shared_ptr<Worker<Dtype> > > workers(gpus.size());
+  for (int i = 1; i < gpus.size(); ++i) {
+    CUDA_CHECK(cudaSetDevice(gpus[i]));
+    Caffe::set_solver_rank(i);
+    SolverParameter param(solver->param());
+    param.set_device_id(gpus[i]);
+    Worker<Dtype>* w = new Worker<Dtype>(param, solver->iter(),
+                                         &barrier, &nccls);
+    w->StartInternalThread();
+    workers[i].reset(w);
   }
+  CUDA_CHECK(cudaSetDevice(gpus[0]));
+  Caffe::set_solver_rank(0);
 
-  // Run root solver on current thread
-  solver_->Solve();
+  NCCL<Dtype> nccl(solver, &barrier);
+  solver->add_callback(&nccl);
+  if (solver->param().layer_wise_reduce()) {
+    solver->net()->add_after_backward(&nccl);
+  }
+  nccls[0] = &nccl;
+  // Wait for workers
+  barrier.wait();
+  // Init NCCL
+  InitSingleProcess(&nccls);
+  barrier.wait();
+  // Run first solver on current thread
+  nccl.Broadcast();
+  solver->Solve();
 
-  for (int i = 1; i < syncs.size(); ++i) {
-    syncs[i]->StopInternalThread();
+  // Wait for shutdown
+  for (int i = 1; i < gpus.size(); ++i) {
+    workers[i]->StopInternalThread();
   }
+#endif
 }
 
 INSTANTIATE_CLASS(Params);
 INSTANTIATE_CLASS(GPUParams);
-INSTANTIATE_CLASS(P2PSync);
+#ifdef USE_NCCL
+INSTANTIATE_CLASS(Worker);
+#endif
+INSTANTIATE_CLASS(NCCL);
 
 }  // namespace caffe
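Taken together, parallel.cpp now supports two deployment modes: single process, where NCCL<Dtype>::Run() spawns one Worker thread per extra GPU and drives rank 0 on the calling thread, and multi process, where each process builds its own NCCL instance from a shared uid. A minimal per-process sketch of the latter, assuming the rank and the uid from new_uid() are distributed out of band (for example by the new python/train.py or an MPI launcher); the helper name and launch details below are hypothetical:

#include <boost/shared_ptr.hpp>
#include <string>

#include "caffe/caffe.hpp"
#include "caffe/parallel.hpp"

// Hypothetical helper: run one solver rank in its own process. `uid` is
// the value of caffe::NCCL<float>::new_uid() generated once on rank 0.
void train_rank(caffe::SolverParameter param, const std::string& uid,
                int gpu, int rank, int ranks) {
  caffe::Caffe::set_mode(caffe::Caffe::GPU);
  caffe::Caffe::SetDevice(gpu);
  param.set_device_id(gpu);
  caffe::Caffe::set_solver_count(ranks);
  caffe::Caffe::set_solver_rank(rank);
  caffe::Caffe::set_multi_process(true);

  boost::shared_ptr<caffe::Solver<float> > solver(
      caffe::SolverRegistry<float>::CreateSolver(param));
  // Rank-based constructor: joins the NCCL communicator. Its Init() already
  // installs the solver callback (and, with layer_wise_reduce, the
  // per-layer after-backward hook), and barrier_ stays NULL, so on_start()
  // is a no-op and ranks synchronize through NCCL collectives only.
  caffe::NCCL<float> nccl(solver, uid);
  nccl.Broadcast();  // start every rank from rank 0's weights
  solver->Step(param.max_iter() - solver->iter());
}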
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 6940a705eb6..949a7f6d095 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -98,7 +98,7 @@ message NetParameter {
 // NOTE
 // Update the next available ID when you add a new SolverParameter field.
 //
-// SolverParameter next available ID: 41 (last added: type)
+// SolverParameter next available ID: 42 (last added: layer_wise_reduce)
 message SolverParameter {
   //////////////////////////////////////////////////////////////////////////
   // Specifying the train and test networks
@@ -239,6 +239,9 @@ message SolverParameter {
   }
   // DEPRECATED: use type instead of solver_type
   optional SolverType solver_type = 30 [default = SGD];
+
+  // Overlap compute and communication for data parallel training
+  optional bool layer_wise_reduce = 41 [default = true];
 }
 
 // A message that stores the solver snapshots
@@ -653,8 +656,8 @@ message DataParameter {
   optional bool mirror = 6 [default = false];
   // Force the encoded image to have 3 color channels
  optional bool force_encoded_color = 9 [default = false];
-  // Prefetch queue (Number of batches to prefetch to host memory, increase if
-  // data access bandwidth varies).
+  // Prefetch queue (Increase if data feeding bandwidth varies, within the
+  // limit of device memory for GPU training)
  optional uint32 prefetch = 10 [default = 4];
 }
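The new layer_wise_reduce field selects where the all-reduce happens: with the default (true), NCCL<Dtype>::run() reduces each layer's gradients on a side stream while backward continues; set it to false to fall back to a single whole-buffer all-reduce in on_gradients_ready(). In a solver prototxt this is the single line `layer_wise_reduce: false`; a sketch of flipping it programmatically through the generated protobuf setter (the file name below is a placeholder):

#include "caffe/proto/caffe.pb.h"
#include "caffe/util/upgrade_proto.hpp"  // ReadSolverParamsFromTextFileOrDie

caffe::SolverParameter LoadParams() {
  caffe::SolverParameter param;
  caffe::ReadSolverParamsFromTextFileOrDie("solver.prototxt", &param);
  // Disable compute/communication overlap, e.g. for a net that shares
  // weights across layers (the layer-wise path CHECKs against that case).
  param.set_layer_wise_reduce(false);
  return param;
}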
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index ece3913e88a..1c1a9e59565 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -26,16 +26,14 @@ SolverAction::Enum Solver<Dtype>::GetRequestedAction() {
 }
 
 template <typename Dtype>
-Solver<Dtype>::Solver(const SolverParameter& param, const Solver* root_solver)
-    : net_(), callbacks_(), root_solver_(root_solver),
-      requested_early_exit_(false) {
+Solver<Dtype>::Solver(const SolverParameter& param)
+    : net_(), callbacks_(), requested_early_exit_(false) {
   Init(param);
 }
 
 template <typename Dtype>
-Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
-    : net_(), callbacks_(), root_solver_(root_solver),
-      requested_early_exit_(false) {
+Solver<Dtype>::Solver(const string& param_file)
+    : net_(), callbacks_(), requested_early_exit_(false) {
   SolverParameter param;
   ReadSolverParamsFromTextFileOrDie(param_file, &param);
   Init(param);
@@ -43,15 +41,13 @@ Solver<Dtype>::Solver(const string& param_file, const Solver* root_solver)
 
 template <typename Dtype>
 void Solver<Dtype>::Init(const SolverParameter& param) {
-  CHECK(Caffe::root_solver() || root_solver_)
-      << "root_solver_ needs to be set for all non-root solvers";
   LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
     << std::endl << param.DebugString();
   param_ = param;
   CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
   CheckSnapshotWritePermissions();
-  if (Caffe::root_solver() && param_.random_seed() >= 0) {
-    Caffe::set_random_seed(param_.random_seed());
+  if (param_.random_seed() >= 0) {
+    Caffe::set_random_seed(param_.random_seed() + Caffe::solver_rank());
   }
   // Scaffolding code
   InitTrainNet();
@@ -101,11 +97,7 @@ void Solver<Dtype>::InitTrainNet() {
   net_state.MergeFrom(net_param.state());
   net_state.MergeFrom(param_.train_state());
   net_param.mutable_state()->CopyFrom(net_state);
-  if (Caffe::root_solver()) {
-    net_.reset(new Net<Dtype>(net_param));
-  } else {
-    net_.reset(new Net<Dtype>(net_param, root_solver_->net_.get()));
-  }
+  net_.reset(new Net<Dtype>(net_param));
 }
 
 template <typename Dtype>
@@ -180,12 +172,7 @@ void Solver<Dtype>::InitTestNets() {
     net_params[i].mutable_state()->CopyFrom(net_state);
     LOG(INFO)
         << "Creating test net (#" << i << ") specified by " << sources[i];
-    if (Caffe::root_solver()) {
-      test_nets_[i].reset(new Net<Dtype>(net_params[i]));
-    } else {
-      test_nets_[i].reset(new Net<Dtype>(net_params[i],
-          root_solver_->test_nets_[i].get()));
-    }
+    test_nets_[i].reset(new Net<Dtype>(net_params[i]));
     test_nets_[i]->set_debug_info(param_.debug_info());
   }
 }
@@ -197,14 +184,16 @@ void Solver<Dtype>::Step(int iters) {
   int average_loss = this->param_.average_loss();
   losses_.clear();
   smoothed_loss_ = 0;
+  iteration_timer_.Start();
 
   while (iter_ < stop_iter) {
     // zero-init the params
     net_->ClearParamDiffs();
     if (param_.test_interval() && iter_ % param_.test_interval() == 0
-        && (iter_ > 0 || param_.test_initialization())
-        && Caffe::root_solver()) {
-      TestAll();
+        && (iter_ > 0 || param_.test_initialization())) {
+      if (Caffe::root_solver()) {
+        TestAll();
+      }
       if (requested_early_exit_) {
         // Break out of the while loop because stop was requested while testing.
         break;
@@ -225,8 +214,13 @@ void Solver<Dtype>::Step(int iters) {
     // average the loss across iterations for smoothed reporting
     UpdateSmoothedLoss(loss, start_iter, average_loss);
     if (display) {
+      float lapse = iteration_timer_.Seconds();
+      float per_s = (iter_ - iterations_last_) / (lapse ? lapse : 1);
       LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
-          << ", loss = " << smoothed_loss_;
+          << " (" << per_s << " iter/s, " << lapse << "s/"
+          << param_.display() << " iters), loss = " << smoothed_loss_;
+      iteration_timer_.Start();
+      iterations_last_ = iter_;
       const vector<Blob<Dtype>*>& result = net_->output_blobs();
       int score_index = 0;
       for (int j = 0; j < result.size(); ++j) {
diff --git a/src/caffe/solvers/adagrad_solver.cpp b/src/caffe/solvers/adagrad_solver.cpp
index e78eadca141..d8107e1e623 100644
--- a/src/caffe/solvers/adagrad_solver.cpp
+++ b/src/caffe/solvers/adagrad_solver.cpp
@@ -12,7 +12,6 @@ void adagrad_update_gpu(int N, Dtype* g, Dtype* h, Dtype delta,
 
 template <typename Dtype>
 void AdaGradSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
-  CHECK(Caffe::root_solver());
   const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
   const vector<float>& net_params_lr = this->net_->params_lr();
   Dtype delta = this->param_.delta();
diff --git a/src/caffe/solvers/nesterov_solver.cpp b/src/caffe/solvers/nesterov_solver.cpp
index 23ab2d4369a..7c1fac1f884 100644
--- a/src/caffe/solvers/nesterov_solver.cpp
+++ b/src/caffe/solvers/nesterov_solver.cpp
@@ -12,7 +12,6 @@ void nesterov_update_gpu(int N, Dtype* g, Dtype* h, Dtype momentum,
 
 template <typename Dtype>
 void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
-  CHECK(Caffe::root_solver());
   const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
   const vector<float>& net_params_lr = this->net_->params_lr();
   Dtype momentum = this->param_.momentum();
diff --git a/src/caffe/solvers/sgd_solver.cpp b/src/caffe/solvers/sgd_solver.cpp
index f30f316d1a0..ad6abe54a0a 100644
--- a/src/caffe/solvers/sgd_solver.cpp
+++ b/src/caffe/solvers/sgd_solver.cpp
@@ -100,10 +100,10 @@ void SGDSolver<Dtype>::ClipGradients() {
 
 template <typename Dtype>
 void SGDSolver<Dtype>::ApplyUpdate() {
-  CHECK(Caffe::root_solver());
   Dtype rate = GetLearningRate();
   if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
-    LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
+    LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << this->iter_
+        << ", lr = " << rate;
   }
   ClipGradients();
   for (int param_id = 0; param_id < this->net_->learnable_params().size();
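The solver changes above hinge on the existing Solver<Dtype>::Callback protocol that NCCL implements: on_start() fires at the top of each Step() iteration (NCCL uses it as a thread barrier in the single-process case) and on_gradients_ready() fires after backward, just before ApplyUpdate(). A minimal custom callback, assuming the nested Callback interface declared in solver.hpp; the subclass is illustrative only:

#include "caffe/solver.hpp"

// Hypothetical example: trace the per-iteration protocol that NCCL hooks.
template <typename Dtype>
class IterationTracer : public caffe::Solver<Dtype>::Callback {
 protected:
  virtual void on_start() {
    LOG(INFO) << "iteration start (weights in sync across solvers)";
  }
  virtual void on_gradients_ready() {
    LOG(INFO) << "backward done; gradients ready for ApplyUpdate()";
  }
};

// Usage, given a caffe::Solver<float>* solver:
//   IterationTracer<float> tracer;
//   solver->add_callback(&tracer);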
diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp
index 975a8f0f88a..6af97deec64 100644
--- a/src/caffe/test/test_gradient_based_solver.cpp
+++ b/src/caffe/test/test_gradient_based_solver.cpp
@@ -36,7 +36,6 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
 
   string snapshot_prefix_;
   shared_ptr<SGDSolver<Dtype> > solver_;
-  shared_ptr<P2PSync<Dtype> > sync_;
   int seed_;
   // Dimensions are determined by generate_sample_data.py
   // TODO this is brittle and the hdf5 file should be checked instead.
@@ -202,9 +201,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
         gpus.push_back(i);
       }
       Caffe::set_solver_count(gpus.size());
-      this->sync_.reset(new P2PSync<Dtype>(
-          this->solver_, NULL, this->solver_->param()));
-      this->sync_->Run(gpus);
+      caffe::NCCL<Dtype>::Run(this->solver_, gpus);
       Caffe::set_solver_count(1);
     }
     if (snapshot) {
diff --git a/src/caffe/util/blocking_queue.cpp b/src/caffe/util/blocking_queue.cpp
index 058668fe28c..f69d210459c 100644
--- a/src/caffe/util/blocking_queue.cpp
+++ b/src/caffe/util/blocking_queue.cpp
@@ -1,7 +1,6 @@
 #include <boost/thread.hpp>
 #include <string>
 
-#include "caffe/data_reader.hpp"
 #include "caffe/layers/base_data_layer.hpp"
 #include "caffe/parallel.hpp"
 #include "caffe/util/blocking_queue.hpp"
@@ -88,9 +87,5 @@ size_t BlockingQueue<T>::size() const {
 
 template class BlockingQueue<Batch<float>*>;
 template class BlockingQueue<Batch<double>*>;
-template class BlockingQueue<Datum*>;
-template class BlockingQueue<shared_ptr<DataReader::QueuePair> >;
-template class BlockingQueue<P2PSync<float>*>;
-template class BlockingQueue<P2PSync<double>*>;
 
 }  // namespace caffe
diff --git a/src/caffe/util/db_lmdb.cpp b/src/caffe/util/db_lmdb.cpp
index fb1d4956aa1..491a9bd03a6 100644
--- a/src/caffe/util/db_lmdb.cpp
+++ b/src/caffe/util/db_lmdb.cpp
@@ -32,7 +32,7 @@ void LMDB::Open(const string& source, Mode mode) {
     MDB_CHECK(rc);
   }
 #endif
-  LOG(INFO) << "Opened lmdb " << source;
+  LOG_IF(INFO, Caffe::root_solver()) << "Opened lmdb " << source;
 }
 
 LMDBCursor* LMDB::NewCursor() {
diff --git a/src/caffe/util/math_functions.cu b/src/caffe/util/math_functions.cu
index 4c587537435..6d001026082 100644
--- a/src/caffe/util/math_functions.cu
+++ b/src/caffe/util/math_functions.cu
@@ -90,6 +90,26 @@ void caffe_gpu_scal<double>(const int N, const double alpha, double *X) {
   CUBLAS_CHECK(cublasDscal(Caffe::cublas_handle(), N, &alpha, X, 1));
 }
 
+template <>
+void caffe_gpu_scal<float>(const int N, const float alpha, float* X,
+                           cudaStream_t str) {
+  cudaStream_t initial_stream;
+  CUBLAS_CHECK(cublasGetStream(Caffe::cublas_handle(), &initial_stream));
+  CUBLAS_CHECK(cublasSetStream(Caffe::cublas_handle(), str));
+  CUBLAS_CHECK(cublasSscal(Caffe::cublas_handle(), N, &alpha, X, 1));
+  CUBLAS_CHECK(cublasSetStream(Caffe::cublas_handle(), initial_stream));
+}
+
+template <>
+void caffe_gpu_scal<double>(const int N, const double alpha, double* X,
+                            cudaStream_t str) {
+  cudaStream_t initial_stream;
+  CUBLAS_CHECK(cublasGetStream(Caffe::cublas_handle(), &initial_stream));
+  CUBLAS_CHECK(cublasSetStream(Caffe::cublas_handle(), str));
+  CUBLAS_CHECK(cublasDscal(Caffe::cublas_handle(), N, &alpha, X, 1));
+  CUBLAS_CHECK(cublasSetStream(Caffe::cublas_handle(), initial_stream));
+}
+
 template <>
 void caffe_gpu_axpby<float>(const int N, const float alpha, const float* X,
     const float beta, float* Y) {
diff --git a/tools/caffe.cpp b/tools/caffe.cpp
index 9bf4214ad93..7f9d95fc73c 100644
--- a/tools/caffe.cpp
+++ b/tools/caffe.cpp
@@ -244,11 +244,10 @@ int train() {
     CopyLayers(solver.get(), FLAGS_weights);
   }
 
+  LOG(INFO) << "Starting Optimization";
   if (gpus.size() > 1) {
-    caffe::P2PSync<float> sync(solver, NULL, solver->param());
-    sync.Run(gpus);
+    caffe::NCCL<float>::Run(solver, gpus);
   } else {
-    LOG(INFO) << "Starting Optimization";
     solver->Solve();
   }
   LOG(INFO) << "Optimization Done.";
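A closing note on the stream-aware caffe_gpu_scal() overloads added in math_functions.cu above: they temporarily point the shared cuBLAS handle at the caller's stream and restore it afterwards, which is how parallel.cpp scales freshly reduced gradients on the same non-blocking stream NCCL wrote them on. A small usage sketch under those assumptions (the function and variable names here are hypothetical):

#include <cuda_runtime.h>

#include "caffe/common.hpp"
#include "caffe/util/math_functions.hpp"

// Average `count` gradient values in place on `stream`, as the layer-wise
// reduce path does right after ncclAllReduce. `gpu_diff` must point to
// device memory.
void average_on_stream(int count, int solver_count, float* gpu_diff,
                       cudaStream_t stream) {
  caffe::caffe_gpu_scal(count, 1.0f / solver_count, gpu_diff, stream);
  // Synchronize once before the averaged values are consumed elsewhere.
  CUDA_CHECK(cudaStreamSynchronize(stream));
}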