Skip to content

Commit

Permalink
Merge pull request #2211 from nsubtil/fix-cudnn-algo
Browse files Browse the repository at this point in the history
Fallback to different cuDNN algorithm when under memory pressure; fix #2197
  • Loading branch information
shelhamer committed Mar 26, 2015
2 parents c308986 + add73fb commit 5c009d8
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/caffe/layers/cudnn_conv_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ void CuDNNConvolutionLayer<Dtype>::LayerSetUp(
// Initialize CUDA streams and cuDNN.
stream_ = new cudaStream_t[this->group_ * CUDNN_STREAMS_PER_GROUP];
handle_ = new cudnnHandle_t[this->group_ * CUDNN_STREAMS_PER_GROUP];
workspaceSizeInBytes = 0;
workspace = NULL;

for (int g = 0; g < this->group_ * CUDNN_STREAMS_PER_GROUP; g++) {
CUDA_CHECK(cudaStreamCreate(&stream_[g]));
Expand Down
19 changes: 15 additions & 4 deletions src/caffe/layers/cudnn_conv_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ void CuDNNConvolutionLayer<Dtype>::Forward_gpu(
Dtype* top_data = top[i]->mutable_gpu_data();
const Dtype* weight = this->blobs_[0]->gpu_data();

size_t workspace_limit_bytes = this->kernel_h_ *
this->kernel_w_ *
this->channels_ *
sizeof(int) + 1;

// Forward through cuDNN in parallel over groups.
for (int g = 0; g < this->group_; g++) {
cudnnConvolutionFwdAlgo_t algo;
Expand All @@ -32,8 +37,8 @@ void CuDNNConvolutionLayer<Dtype>::Forward_gpu(
filter_desc_,
conv_descs_[i],
top_descs_[i],
CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
0, // memoryLimitInBytes,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_limit_bytes, // memoryLimitInBytes,
&algo));

// get minimum size of the workspace needed for the desired algorithm
Expand All @@ -45,13 +50,19 @@ void CuDNNConvolutionLayer<Dtype>::Forward_gpu(
conv_descs_[i],
top_descs_[i],
algo,
&workspaceSizeInBytes));
&workspaceSizeInBytes_temp));

if (workspaceSizeInBytes_temp > workspaceSizeInBytes) {
workspaceSizeInBytes = workspaceSizeInBytes_temp;
// free the existing workspace and allocate a new (larger) one
cudaFree(this->workspace);
cudaMalloc(&(this->workspace), workspaceSizeInBytes);
cudaError_t err = cudaMalloc(&(this->workspace), workspaceSizeInBytes);
if (err != cudaSuccess) {
// force zero memory path
algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
workspace = NULL;
workspaceSizeInBytes = 0;
}
}

// Filters.
Expand Down

0 comments on commit 5c009d8

Please sign in to comment.