MKL/non-MKL Reconciliation #165

Merged 24 commits on Mar 23, 2014
Commits
e4e93f4
compile caffe without MKL (dependency replaced by boost::random, Eigen3)
rodrigob Dec 8, 2013
04ca88a
Fixed uniform distribution upper bound to be inclusive
kloudkl Jan 11, 2014
d666bdc
Fixed FlattenLayer Backward_cpu/gpu have no return value
kloudkl Jan 11, 2014
38457e1
Fix test stochastic pooling stepsize/threshold to be same as max pooling
kloudkl Jan 11, 2014
788f070
Fix math funcs, add tests, change Eigen Map to unaligned for lrn_layer
kloudkl Jan 12, 2014
d37a995
relax precision of MultinomialLogisticLossLayer test
shelhamer Jan 9, 2014
2ae2683
nextafter templates off one type
Jan 22, 2014
b925739
mean_bound and sample_mean need referencing with this
Jan 22, 2014
93c9f15
make uniform distribution usage compatible with boost 1.46
jeffdonahue Jan 22, 2014
4b1fba7
use boost variate_generator to pass tests w/ boost 1.46 (Gaussian filler
jeffdonahue Jan 22, 2014
b3e4ac5
change all Rng's to use variate_generator for consistency
jeffdonahue Jan 22, 2014
6cbf9f1
add bernoulli rng test to demonstrate bug (generates all 0s unless p ==
jeffdonahue Jan 29, 2014
4f6b266
fix bernoulli generator bug
jeffdonahue Jan 29, 2014
1cf822e
Replace atlas with multithreaded OpenBLAS to speed-up on multi-core CPU
kloudkl Feb 7, 2014
a8c9b66
major refactoring allow coexistence of MKL and non-MKL cases
Feb 12, 2014
c028d09
rewrite MKL flag note, polish makefile
shelhamer Feb 15, 2014
f6cbe2c
make MKL switch surprise-proof
shelhamer Feb 18, 2014
ff27988
comment out stray mkl includes
shelhamer Feb 27, 2014
40aa12a
Fixed order of cblas and atlas linker flags
jamt9000 Mar 3, 2014
a9e772f
Added extern C wrapper to cblas.h include
jamt9000 Mar 3, 2014
453fcf9
clean up residual mkl comments and code
shelhamer Mar 21, 2014
aaa2646
lint
shelhamer Mar 21, 2014
19bcf2b
Hide boost rng behind facade for osx compatibility
shelhamer Mar 22, 2014
bece205
Set copyright to BVLC and contributors.
shelhamer Mar 22, 2014
24 changes: 17 additions & 7 deletions Makefile
@@ -86,27 +86,37 @@ CUDA_LIB_DIR := $(CUDA_DIR)/lib64 $(CUDA_DIR)/lib
 MKL_INCLUDE_DIR := $(MKL_DIR)/include
 MKL_LIB_DIR := $(MKL_DIR)/lib $(MKL_DIR)/lib/intel64
 
-INCLUDE_DIRS += ./src ./include $(CUDA_INCLUDE_DIR) $(MKL_INCLUDE_DIR)
-LIBRARY_DIRS += $(CUDA_LIB_DIR) $(MKL_LIB_DIR)
+INCLUDE_DIRS += ./src ./include $(CUDA_INCLUDE_DIR)
+LIBRARY_DIRS += $(CUDA_LIB_DIR)
 LIBRARIES := cudart cublas curand \
-    mkl_rt \
     pthread \
-    glog protobuf leveldb \
-    snappy \
+    glog protobuf leveldb snappy \
     boost_system \
     hdf5_hl hdf5 \
     opencv_core opencv_highgui opencv_imgproc
 PYTHON_LIBRARIES := boost_python python2.7
 WARNINGS := -Wall
 
-COMMON_FLAGS := -DNDEBUG -O2 $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
+COMMON_FLAGS := -DNDEBUG -O2
+
+# MKL switch (default = non-MKL)
+USE_MKL ?= 0
+ifeq ($(USE_MKL), 1)
+    LIBRARIES += mkl_rt
+    COMMON_FLAGS += -DUSE_MKL
+    INCLUDE_DIRS += $(MKL_INCLUDE_DIR)
+    LIBRARY_DIRS += $(MKL_LIB_DIR)
+else
+    LIBRARIES += cblas atlas
+endif
+
+COMMON_FLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
 CXXFLAGS += -pthread -fPIC $(COMMON_FLAGS)
 NVCCFLAGS := -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS)
 LDFLAGS += $(foreach librarydir,$(LIBRARY_DIRS),-L$(librarydir)) \
     $(foreach library,$(LIBRARIES),-l$(library))
 PYTHON_LDFLAGS := $(LDFLAGS) $(foreach library,$(PYTHON_LIBRARIES),-l$(library))
 
 
 ##############################
 # Define build targets
 ##############################
2 changes: 2 additions & 0 deletions Makefile.config.example
@@ -10,6 +10,8 @@ CUDA_ARCH := -gencode arch=compute_20,code=sm_20 \
     -gencode arch=compute_30,code=sm_30 \
     -gencode arch=compute_35,code=sm_35
 
+# MKL switch: set to 1 for MKL
+USE_MKL := 0
 # MKL directory contains include/ and lib/ directions that we need.
 MKL_DIR := /opt/intel/mkl
 
8 changes: 8 additions & 0 deletions include/caffe/blob.hpp
@@ -27,6 +27,14 @@ class Blob {
   inline int count() const {return count_; }
   inline int offset(const int n, const int c = 0, const int h = 0,
       const int w = 0) const {
+    CHECK_GE(n, 0);
+    CHECK_LE(n, num_);
+    CHECK_GE(c, 0);
+    CHECK_LE(c, channels_);
+    CHECK_GE(h, 0);
+    CHECK_LE(h, height_);
+    CHECK_GE(w, 0);
+    CHECK_LE(w, width_);
     return ((n * channels_ + c) * height_ + h) * width_ + w;
   }
   // Copy from source. If copy_diff is false, we copy the data; if copy_diff
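
The checks above guard Caffe's row-major layout. As a minimal standalone sketch of the same indexing (hypothetical helper names, not code from the PR):

#include <cassert>

// Row-major offset into an N x C x H x W volume, mirroring Blob::offset():
// w varies fastest, then h, then c, then n.
inline int blob_offset(int n, int c, int h, int w,
                       int num, int channels, int height, int width) {
  assert(n >= 0 && n <= num);        // same bounds the CHECK_* macros enforce
  assert(c >= 0 && c <= channels);
  assert(h >= 0 && h <= height);
  assert(w >= 0 && w <= width);
  return ((n * channels + c) * height + h) * width + w;
}

// For a 1 x 3 x 4 x 5 volume, blob_offset(0, 1, 2, 3, 1, 3, 4, 5)
// evaluates to (1 * 4 + 2) * 5 + 3 = 33.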
102 changes: 59 additions & 43 deletions include/caffe/common.hpp
@@ -1,4 +1,4 @@
-// Copyright 2013 Yangqing Jia
+// Copyright 2014 BVLC and contributors.
 
 #ifndef CAFFE_COMMON_HPP_
 #define CAFFE_COMMON_HPP_
@@ -7,28 +7,8 @@
 #include <cublas_v2.h>
 #include <cuda.h>
 #include <curand.h>
-// cuda driver types
-#include <driver_types.h>
+#include <driver_types.h>  // cuda driver types
 #include <glog/logging.h>
-#include <mkl_vsl.h>
-
-// various checks for different function calls.
-#define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
-#define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
-#define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
-#define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK)
-
-#define CUDA_KERNEL_LOOP(i, n) \
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
-       i < (n); \
-       i += blockDim.x * gridDim.x)
-
-// After a kernel is executed, this will check the error and if there is one,
-// exit loudly.
-#define CUDA_POST_KERNEL_CHECK \
-  if (cudaSuccess != cudaPeekAtLastError()) \
-    LOG(FATAL) << "Cuda kernel failed. Error: " \
-        << cudaGetErrorString(cudaPeekAtLastError())
 
 // Disable the copy and assignment operator for a class.
 #define DISABLE_COPY_AND_ASSIGN(classname) \
@@ -45,6 +25,23 @@ private:\
 // is executed we will see a fatal log.
 #define NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented Yet"
 
+// CUDA: various checks for different function calls.
+#define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
+#define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
+#define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
+
+// CUDA: grid stride looping
+#define CUDA_KERNEL_LOOP(i, n) \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+       i < (n); \
+       i += blockDim.x * gridDim.x)
+
+// CUDA: check for error after kernel execution and exit loudly if there is one.
+#define CUDA_POST_KERNEL_CHECK \
+  if (cudaSuccess != cudaPeekAtLastError()) \
+    LOG(FATAL) << "Cuda kernel failed. Error: " \
+        << cudaGetErrorString(cudaPeekAtLastError())
+
+
 namespace caffe {
 
@@ -53,20 +50,6 @@ namespace caffe {
 using boost::shared_ptr;
 
 
-// We will use 1024 threads per block, which requires cuda sm_2x or above.
-#if __CUDA_ARCH__ >= 200
-  const int CAFFE_CUDA_NUM_THREADS = 1024;
-#else
-  const int CAFFE_CUDA_NUM_THREADS = 512;
-#endif
-
-
-
-inline int CAFFE_GET_BLOCKS(const int N) {
-  return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
-}
-
-
 // A singleton class to hold common caffe stuff, such as the handler that
 // caffe is going to use for cublas, curand, etc.
 class Caffe {
@@ -81,15 +64,32 @@ class Caffe {
   enum Brew { CPU, GPU };
   enum Phase { TRAIN, TEST };
 
-  // The getters for the variables.
-  // Returns the cublas handle.
+
+  // This random number generator facade hides boost and CUDA rng
+  // implementation from one another (for cross-platform compatibility).
+  class RNG {
+   public:
+    RNG();
+    explicit RNG(unsigned int seed);
+    ~RNG();
+    RNG(const RNG&);
+    RNG& operator=(const RNG&);
+    const void* generator() const;
+    void* generator();
+   private:
+    class Generator;
+    Generator* generator_;
+  };
+
+  // Getters for boost rng, curand, and cublas handles
+  inline static RNG &rng_stream() {
+    return Get().random_generator_;
+  }
   inline static cublasHandle_t cublas_handle() { return Get().cublas_handle_; }
   // Returns the curand generator.
   inline static curandGenerator_t curand_generator() {
     return Get().curand_generator_;
   }
-  // Returns the MKL random stream.
-  inline static VSLStreamStatePtr vsl_stream() { return Get().vsl_stream_; }
 
   // Returns the mode: running on CPU or GPU.
   inline static Brew mode() { return Get().mode_; }
   // Returns the phase: TRAIN or TEST.
@@ -102,7 +102,7 @@ class Caffe {
   inline static void set_mode(Brew mode) { Get().mode_ = mode; }
   // Sets the phase.
   inline static void set_phase(Phase phase) { Get().phase_ = phase; }
-  // Sets the random seed of both MKL and curand
+  // Sets the random seed of both boost and curand
   static void set_random_seed(const unsigned int seed);
   // Sets the device. Since we have cublas and curand stuff, set device also
   // requires us to reset those values.
@@ -113,7 +113,8 @@ class Caffe {
  protected:
   cublasHandle_t cublas_handle_;
   curandGenerator_t curand_generator_;
-  VSLStreamStatePtr vsl_stream_;
+  RNG random_generator_;
+
   Brew mode_;
   Phase phase_;
   static shared_ptr<Caffe> singleton_;
@@ -126,6 +127,21 @@
 };
 
 
+// CUDA: thread number configuration.
+// Use 1024 threads per block, which requires cuda sm_2x or above,
+// or fall back to attempt compatibility (best of luck to you).
+#if __CUDA_ARCH__ >= 200
+  const int CAFFE_CUDA_NUM_THREADS = 1024;
+#else
+  const int CAFFE_CUDA_NUM_THREADS = 512;
+#endif
+
+// CUDA: number of blocks for threads.
+inline int CAFFE_GET_BLOCKS(const int N) {
+  return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
+}
+
+
 } // namespace caffe
 
 #endif  // CAFFE_COMMON_HPP_
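
The RNG class above is a pimpl facade: the header exposes only an opaque pointer, so boost headers never reach translation units compiled by nvcc. A minimal sketch of the same idea, with hypothetical names rather than the PR's actual implementation:

// rng_facade.hpp: no boost includes appear here.
class RngFacade {
 public:
  explicit RngFacade(unsigned int seed);
  ~RngFacade();
  void* generator();  // opaque handle; CPU-only code casts it back
 private:
  class Generator;    // defined only in the .cpp file
  Generator* generator_;
};

// rng_facade.cpp: boost stays behind this wall.
#include <boost/random/mersenne_twister.hpp>

class RngFacade::Generator : public boost::mt19937 {
 public:
  explicit Generator(unsigned int seed) : boost::mt19937(seed) {}
};

RngFacade::RngFacade(unsigned int seed) : generator_(new Generator(seed)) {}
RngFacade::~RngFacade() { delete generator_; }
void* RngFacade::generator() {
  return static_cast<boost::mt19937*>(generator_);
}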
1 change: 0 additions & 1 deletion include/caffe/filler.hpp
@@ -7,7 +7,6 @@
 #ifndef CAFFE_FILLER_HPP
 #define CAFFE_FILLER_HPP
 
-#include <mkl.h>
 #include <string>
 
 #include "caffe/common.hpp"
12 changes: 10 additions & 2 deletions include/caffe/util/math_functions.hpp
@@ -4,9 +4,11 @@
 #ifndef CAFFE_UTIL_MATH_FUNCTIONS_H_
 #define CAFFE_UTIL_MATH_FUNCTIONS_H_
 
-#include <mkl.h>
-
 #include <cublas_v2.h>
 
+#include "caffe/util/mkl_alternate.hpp"
+
 namespace caffe {
 
 // Decaf gemm provides a simpler interface to the gemm functions, with the
@@ -45,7 +47,7 @@ void caffe_gpu_axpy(const int N, const Dtype alpha, const Dtype* X,
     Dtype* Y);
 
 template <typename Dtype>
-void caffe_axpby(const int N, const Dtype alpha, const Dtype* X,
+void caffe_cpu_axpby(const int N, const Dtype alpha, const Dtype* X,
     const Dtype beta, Dtype* Y);
 
 template <typename Dtype>
@@ -85,13 +87,19 @@ void caffe_div(const int N, const Dtype* a, const Dtype* b, Dtype* y);
 template <typename Dtype>
 void caffe_powx(const int n, const Dtype* a, const Dtype b, Dtype* y);
 
+template <typename Dtype>
+Dtype caffe_nextafter(const Dtype b);
+
 template <typename Dtype>
 void caffe_vRngUniform(const int n, Dtype* r, const Dtype a, const Dtype b);
 
 template <typename Dtype>
 void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a,
     const Dtype sigma);
 
+template <typename Dtype>
+void caffe_vRngBernoulli(const int n, Dtype* r, const double p);
+
 template <typename Dtype>
 void caffe_exp(const int n, const Dtype* a, Dtype* y);
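
caffe_vRngUniform, caffe_vRngGaussian, and caffe_vRngBernoulli replace MKL's VSL samplers. A sketch of how the Bernoulli variant can be written with boost (an assumed shape, not the PR's exact code); the variate_generator wrapper is the same device the commit log uses for boost 1.46 compatibility:

#include <boost/random/bernoulli_distribution.hpp>
#include <boost/random/mersenne_twister.hpp>
#include <boost/random/variate_generator.hpp>

template <typename Dtype>
void bernoulli_sketch(const int n, Dtype* r, const double p,
                      boost::mt19937& rng) {
  boost::bernoulli_distribution<double> dist(p);
  boost::variate_generator<boost::mt19937&,
                           boost::bernoulli_distribution<double> >
      sampler(rng, dist);
  for (int i = 0; i < n; ++i) {
    r[i] = static_cast<Dtype>(sampler());  // 1 with probability p, else 0
  }
}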
97 changes: 97 additions & 0 deletions include/caffe/util/mkl_alternate.hpp
@@ -0,0 +1,97 @@
// Copyright 2013 Rowland Depp

#ifndef CAFFE_UTIL_MKL_ALTERNATE_H_
#define CAFFE_UTIL_MKL_ALTERNATE_H_

#ifdef USE_MKL

#include <mkl.h>

#else  // If use MKL, simply include the MKL header

extern "C" {
#include <cblas.h>
}
#include <math.h>

// Functions that caffe uses but are not present if MKL is not linked.

// A simple way to define the vsl unary functions. The operation should
// be in the form e.g. y[i] = sqrt(a[i])
#define DEFINE_VSL_UNARY_FUNC(name, operation) \
  template<typename Dtype> \
  void v##name(const int n, const Dtype* a, Dtype* y) { \
    CHECK_GT(n, 0); CHECK(a); CHECK(y); \
    for (int i = 0; i < n; ++i) { operation; } \
  } \
  inline void vs##name( \
      const int n, const float* a, float* y) { \
    v##name<float>(n, a, y); \
  } \
  inline void vd##name( \
      const int n, const double* a, double* y) { \
    v##name<double>(n, a, y); \
  }

DEFINE_VSL_UNARY_FUNC(Sqr, y[i] = a[i] * a[i]);
DEFINE_VSL_UNARY_FUNC(Exp, y[i] = exp(a[i]));

// A simple way to define the vsl unary functions with singular parameter b.
// The operation should be in the form e.g. y[i] = pow(a[i], b)
#define DEFINE_VSL_UNARY_FUNC_WITH_PARAM(name, operation) \
  template<typename Dtype> \
  void v##name(const int n, const Dtype* a, const Dtype b, Dtype* y) { \
    CHECK_GT(n, 0); CHECK(a); CHECK(y); \
    for (int i = 0; i < n; ++i) { operation; } \
  } \
  inline void vs##name( \
      const int n, const float* a, const float b, float* y) { \
    v##name<float>(n, a, b, y); \
  } \
  inline void vd##name( \
      const int n, const double* a, const double b, double* y) { \
    v##name<double>(n, a, b, y); \
  }

DEFINE_VSL_UNARY_FUNC_WITH_PARAM(Powx, y[i] = pow(a[i], b));

// A simple way to define the vsl binary functions. The operation should
// be in the form e.g. y[i] = a[i] + b[i]
#define DEFINE_VSL_BINARY_FUNC(name, operation) \
  template<typename Dtype> \
  void v##name(const int n, const Dtype* a, const Dtype* b, Dtype* y) { \
    CHECK_GT(n, 0); CHECK(a); CHECK(b); CHECK(y); \
    for (int i = 0; i < n; ++i) { operation; } \
  } \
  inline void vs##name( \
      const int n, const float* a, const float* b, float* y) { \
    v##name<float>(n, a, b, y); \
  } \
  inline void vd##name( \
      const int n, const double* a, const double* b, double* y) { \
    v##name<double>(n, a, b, y); \
  }

DEFINE_VSL_BINARY_FUNC(Add, y[i] = a[i] + b[i]);
DEFINE_VSL_BINARY_FUNC(Sub, y[i] = a[i] - b[i]);
DEFINE_VSL_BINARY_FUNC(Mul, y[i] = a[i] * b[i]);
DEFINE_VSL_BINARY_FUNC(Div, y[i] = a[i] / b[i]);

// In addition, MKL comes with an additional function axpby that is not present
// in standard blas. We will simply use a two-step (inefficient, of course) way
// to mimic that.
inline void cblas_saxpby(const int N, const float alpha, const float* X,
                         const int incX, const float beta, float* Y,
                         const int incY) {
  cblas_sscal(N, beta, Y, incY);
  cblas_saxpy(N, alpha, X, incX, Y, incY);
}
inline void cblas_daxpby(const int N, const double alpha, const double* X,
                         const int incX, const double beta, double* Y,
                         const int incY) {
  cblas_dscal(N, beta, Y, incY);
  cblas_daxpy(N, alpha, X, incX, Y, incY);
}

#endif  // USE_MKL
#endif  // CAFFE_UTIL_MKL_ALTERNATE_H_
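
A usage sketch for the non-MKL path, assuming the header is on the include path and glog plus a cblas implementation are linked (the loop bodies use CHECK): vsSqr dispatches to the templated loop above, and cblas_saxpby composes scal and axpy to compute Y = alpha*X + beta*Y.

#include <vector>

#include "caffe/util/mkl_alternate.hpp"  // built with USE_MKL unset

int main() {
  std::vector<float> x(4, 2.0f), y(4, 1.0f);
  vsSqr(4, &x[0], &y[0]);                           // y[i] = 2 * 2 = 4
  cblas_saxpby(4, 0.5f, &x[0], 1, 2.0f, &y[0], 1);  // y[i] = 0.5*2 + 2*4 = 9
  return 0;
}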
19 changes: 19 additions & 0 deletions include/caffe/util/rng.hpp
@@ -0,0 +1,19 @@
// Copyright 2014 BVLC and contributors.

#ifndef CAFFE_RNG_CPP_HPP_
#define CAFFE_RNG_CPP_HPP_

#include <boost/random/mersenne_twister.hpp>
#include "caffe/common.hpp"

namespace caffe {

typedef boost::mt19937 rng_t;
inline rng_t& caffe_rng() {
  Caffe::RNG &generator = Caffe::rng_stream();
  return *(caffe::rng_t*) generator.generator();
}

} // namespace caffe

#endif  // CAFFE_RNG_CPP_HPP_
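
A usage sketch, assumed rather than taken from the PR: caffe_rng() hands back the shared mt19937, and because boost's uniform_real samples the half-open range [a, b), nudging the upper bound by one ulp (as the PR's caffe_nextafter does) makes sampling effectively inclusive. That is the "uniform distribution upper bound" fix in the commit list.

#include <limits>

#include <boost/math/special_functions/next.hpp>
#include <boost/random/mersenne_twister.hpp>
#include <boost/random/uniform_real.hpp>
#include <boost/random/variate_generator.hpp>

float uniform_inclusive(boost::mt19937& rng, float a, float b) {
  // Widen b by one ulp so that b itself can be drawn.
  float b_incl = boost::math::nextafter(b, std::numeric_limits<float>::max());
  boost::uniform_real<float> dist(a, b_incl);
  boost::variate_generator<boost::mt19937&, boost::uniform_real<float> >
      sampler(rng, dist);
  return sampler();
}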