Skip to content

Commit

Permalink
Merge pull request #88 from thatguymike/cub_usage_fixes
Browse files Browse the repository at this point in the history
Dynamically reshape and adjust the pool. This is part 1 of the fix. @borisfom has some improvements in #89, but we have more debugging to do.
  • Loading branch information
thatguymike committed Dec 2, 2015
2 parents 3c378a8 + 1fdefe9 commit 082bf14
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 64 deletions.
148 changes: 85 additions & 63 deletions src/caffe/layers/cudnn_conv_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,19 @@ namespace caffe {
const Dtype* bottom_data = bottom[i]->gpu_data();
Dtype* top_data = top[i]->mutable_gpu_data();

// Test free space and force reshape if allocations have changed
size_t workspace_limit_bytes, total_memory;
gpu_memory::getInfo(&workspace_limit_bytes, &total_memory);
if (workspace_fwd_sizes_[i] > workspace_limit_bytes) {
this->Reshape(bottom, top);
}

// !!!! Not safe if group_ > 1 !!!!
workspace.reserve(workspace_fwd_sizes_[i]);

// Forward through cuDNN in parallel over groups.
for (int g = 0; g < this->group_; g++) {
workspace.reserve(workspace_fwd_sizes_[i]);
// Filters.
// Filters.
CUDNN_CHECK(cudnnConvFwd(Caffe::cudnn_handle(),
cudnn::dataType<Dtype>::one,
bottom_descs_[i],
Expand All @@ -47,8 +55,6 @@ namespace caffe {
top_descs_[i],
top_data + top_offset_ * g));

workspace.release();

// Bias.
if (this->bias_term_) {
const Dtype* bias_data = this->blobs_[1]->gpu_data();
Expand All @@ -62,6 +68,7 @@ namespace caffe {
}
}

workspace.release();
// Synchronize the work across groups, each of which went into its own
// stream, by launching an empty kernel into the default (null) stream.
// NOLINT_NEXT_LINE(whitespace/operators)
Expand All @@ -78,6 +85,7 @@ namespace caffe {
const Dtype* weight = NULL;
Dtype* weight_diff = NULL;


if (this->param_propagate_down_[0]) {
weight = this->blobs_[0]->gpu_data();
weight_diff = this->blobs_[0]->mutable_gpu_diff();
Expand All @@ -89,69 +97,83 @@ namespace caffe {
}

for (int i = 0; i < top.size(); ++i) {
const Dtype* top_diff = top[i]->gpu_diff();

// Backward through cuDNN in parallel over groups and gradients.
for (int g = 0; g < this->group_; g++) {
// Gradient w.r.t. bias.
if (this->bias_term_ && this->param_propagate_down_[1]) {
CUDNN_CHECK(cudnnConvBwdBias(Caffe::cudnn_handle(),
cudnn::dataType<Dtype>::one,
top_descs_[i],
top_diff + top_offset_ * g,
cudnn::dataType<Dtype>::one,
bias_desc_,
bias_diff + bias_offset_ * g));
const Dtype* top_diff = top[i]->gpu_diff();

// Test free space and force reshape if allocations have changed
size_t workspace_limit_bytes, total_memory;
gpu_memory::getInfo(&workspace_limit_bytes, &total_memory);
if (workspace_bwd_filter_sizes_[i] > workspace_limit_bytes ||
workspace_bwd_data_sizes_[i] > workspace_limit_bytes) {
this->Reshape(bottom, top);
}

// Gradient w.r.t. weights.
if (this->param_propagate_down_[0]) {
workspace.reserve(workspace_bwd_filter_sizes_[i]);
const Dtype* bottom_data = bottom[i]->gpu_data();
CUDNN_CHECK(cudnnConvBwdFilter(Caffe::cudnn_handle(),
cudnn::dataType<Dtype>::one,
bottom_descs_[i],
bottom_data + bottom_offset_ * g,
top_descs_[i],
top_diff + top_offset_ * g,
conv_descs_[i],
bwd_filter_algo_[i],
workspace.data(),
workspace.size(),
cudnn::dataType<Dtype>::one,
filter_desc_,
weight_diff + weight_offset_ * g));
workspace.release();
// To reduce pressure on the allocator, reserve a single workspace
// sized to the larger of the workspaces needed by the following steps.
size_t workspace_reserve = workspace_bwd_filter_sizes_[i] >
workspace_bwd_data_sizes_[i] ?
workspace_bwd_filter_sizes_[i] : workspace_bwd_data_sizes_[i];

// !!!! Not safe if group_ > 1 !!!!
workspace.reserve(workspace_reserve);

// Backward through cuDNN in parallel over groups and gradients.
for (int g = 0; g < this->group_; g++) {
// Gradient w.r.t. bias.
if (this->bias_term_ && this->param_propagate_down_[1]) {
CUDNN_CHECK(cudnnConvBwdBias(Caffe::cudnn_handle(),
cudnn::dataType<Dtype>::one,
top_descs_[i],
top_diff + top_offset_ * g,
cudnn::dataType<Dtype>::one,
bias_desc_,
bias_diff + bias_offset_ * g));
}

// Gradient w.r.t. weights.
if (this->param_propagate_down_[0]) {
const Dtype* bottom_data = bottom[i]->gpu_data();
CUDNN_CHECK(cudnnConvBwdFilter(Caffe::cudnn_handle(),
cudnn::dataType<Dtype>::one,
bottom_descs_[i],
bottom_data + bottom_offset_ * g,
top_descs_[i],
top_diff + top_offset_ * g,
conv_descs_[i],
bwd_filter_algo_[i],
workspace.data(),
workspace.size(),
cudnn::dataType<Dtype>::one,
filter_desc_,
weight_diff + weight_offset_ * g));
}

// Gradient w.r.t. bottom data.
if (propagate_down[i]) {
if (weight == NULL) {
weight = this->blobs_[0]->gpu_data();
}
Dtype* bottom_diff = bottom[i]->mutable_gpu_diff();
CUDNN_CHECK(cudnnConvBwdData(Caffe::cudnn_handle(),
cudnn::dataType<Dtype>::one,
filter_desc_,
weight + this->weight_offset_ * g,
top_descs_[i],
top_diff + top_offset_ * g,
conv_descs_[i],
bwd_data_algo_[i],
workspace.data(),
workspace.size(),
cudnn::dataType<Dtype>::zero,
bottom_descs_[i],
bottom_diff + bottom_offset_ * g));
}
}

// Gradient w.r.t. bottom data.
if (propagate_down[i]) {
if (weight == NULL) {
weight = this->blobs_[0]->gpu_data();
}
Dtype* bottom_diff = bottom[i]->mutable_gpu_diff();
workspace.reserve(workspace_bwd_data_sizes_[i]);
CUDNN_CHECK(cudnnConvBwdData(Caffe::cudnn_handle(),
cudnn::dataType<Dtype>::one,
filter_desc_,
weight + this->weight_offset_ * g,
top_descs_[i],
top_diff + top_offset_ * g,
conv_descs_[i],
bwd_data_algo_[i],
workspace.data(),
workspace.size(),
cudnn::dataType<Dtype>::zero,
bottom_descs_[i],
bottom_diff + bottom_offset_ * g));
workspace.release();
}
}

// Synchronize the work across groups, each of which went into its own
// stream, by launching an empty kernel into the default (null) stream.
// NOLINT_NEXT_LINE(whitespace/operators)
CUDA_CHECK(cudaStreamSynchronize(cudaStreamLegacy));
workspace.release();
// Synchronize the work across groups, each of which went into its own
// stream, by launching an empty kernel into the default (null) stream.
// NOLINT_NEXT_LINE(whitespace/operators)
CUDA_CHECK(cudaStreamSynchronize(cudaStreamLegacy));
}
}

Expand Down
6 changes: 5 additions & 1 deletion src/caffe/util/gpu_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,11 @@ namespace caffe {

cubAlloc = new cub::CachingDeviceAllocator( 2, // defaults
6,
16,
32, // largest
// cached
// allocation
// becomes
// 2^32 here
poolsize_,
false,
debug_);
Expand Down

0 comments on commit 082bf14

Please sign in to comment.