Skip to content

Commit

Permalink
fix CUDNN_VERSION for backward of CudnnBatchNormLayer (#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
qingqing01 authored and reyoung committed Sep 12, 2016
1 parent 674d69c commit dcd87fd
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions paddle/gserver/layers/CudnnBatchNormLayer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,27 +114,30 @@ void CudnnBatchNormLayer::backward(const UpdateCallback& callback) {
} else {
create(tmpBiasGrad_, 1, channels_, &betaGrad);
}
#if CUDNN_VERSION < 5000

// because of the different api of cudnn v4 and v5.
if (weight_->getWGrad()) {
create(tmpWGrad_, 1, channels_, &gammaGrad);
}
if (biases_ && biases_->getWGrad()) {
create(tmpBiasGrad_, 1, channels_, &betaGrad);
if (hl_get_cudnn_lib_version() < 5000) {
if (weight_->getWGrad()) {
create(tmpWGrad_, 1, channels_, &gammaGrad);
}
if (biases_ && biases_->getWGrad()) {
create(tmpBiasGrad_, 1, channels_, &betaGrad);
}
}
#endif

hl_batch_norm_backward(ioDesc_, input, ioDesc_, outGrad,
ioDesc_, inGrad, bnParamDesc_,
gamma, gammaGrad, betaGrad,
EPS, savedMean, savedInvVar);

#if CUDNN_VERSION < 5000
// because of the different api of cudnn v4 and v5.
if (weight_->getWGrad() && biases_->getWGrad()) {
weight_->getWGrad()->add(*tmpWGrad_);
biases_->getWGrad()->add(*tmpBiasGrad_);
if (hl_get_cudnn_lib_version() < 5000) {
if (weight_->getWGrad() && biases_->getWGrad()) {
weight_->getWGrad()->add(*tmpWGrad_);
biases_->getWGrad()->add(*tmpBiasGrad_);
}
}
#endif

{
REGISTER_TIMER_INFO("WeightUpdate", getName().c_str());
biases_->getParameterPtr()->incUpdate(callback);
Expand Down

0 comments on commit dcd87fd

Please sign in to comment.