
[Bug Fixed] fix batch norm when fix_gamma is True #18492

Closed
wants to merge 1 commit

Conversation

wkcn
Member

@wkcn wkcn commented Jun 5, 2020

Description

Fixes issues #16297 and #18475.
When fix_gamma is True, BatchNorm sets the grad_req of gamma to null.
CuDNNBatchNorm is invoked when cudnn_off=False and axis=1, and on that path the gradient of beta is accumulated by mistake.
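For context, a minimal sketch of how the misbehaviour could be reproduced from Gluon (not part of this PR; it assumes a CUDA-enabled build so that the CuDNN path with axis=1 and cudnn_off=False is taken):

import mxnet as mx
from mxnet.gluon import nn

# scale=False corresponds to fix_gamma=True, so gamma's grad_req is 'null'
# while beta keeps the default grad_req='write'.
ctx = mx.gpu(0)
block = nn.BatchNorm(scale=False)
block.initialize(ctx=ctx)

x = mx.nd.random.uniform(shape=(1, 3, 2, 2), ctx=ctx)
for i in range(2):
    with mx.autograd.record():
        y = block(x)
        loss = (y * y).sum()
    loss.backward()
    # With grad_req='write', beta.grad() should be identical on both
    # iterations; the buggy CuDNN path accumulates it instead.
    print(block.beta.grad())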

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • Check the API doc at https://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Fix the BatchNorm bug when fix_gamma=True.
  • Add a unit test for fix_gamma.

@mxnet-bot

Hey @wkcn, thanks for submitting the PR.
All tests are already queued to run once. If tests fail, you can trigger one or more tests again with the following commands:

  • To trigger all jobs: @mxnet-bot run ci [all]
  • To trigger specific jobs: @mxnet-bot run ci [job1, job2]

CI supported jobs: [unix-cpu, clang, unix-gpu, centos-cpu, miscellaneous, website, edge, centos-gpu, windows-cpu, windows-gpu, sanity]


Note:
Only the following 3 categories can trigger CI: PR Author, MXNet Committer, Jenkins Admin.
All CI tests must pass before the PR can be merged.

@wkcn
Member Author

wkcn commented Jun 5, 2020

Hi @sxjscience, could you please help review this PR?

@@ -228,7 +228,7 @@ class CuDNNBatchNormOp {
&a,
&b,
&a,
- req[cudnnbatchnorm::kGamma] == kWriteTo ? &b: &b_add,
+ req[cudnnbatchnorm::kGamma] == kAddTo ? &b_add : &b,
Member

Here, we only check req[cudnnbatchnorm::kGamma] == kAddTo. We won't be able to set the gradient accumulation appropriately if req[cudnnbatchnorm::kGamma] == kWriteTo && req[cudnnbatchnorm::kBeta] == kAddTo.
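For illustration, the combination in question could be produced with per-parameter grad_req settings like the following hypothetical snippet (it is unclear whether any real model configures BN this way):

import mxnet as mx
from mxnet.gluon import nn

block = nn.BatchNorm()
block.initialize()
# gamma overwritten on each backward pass, beta accumulated: the single
# check on req[kGamma] cannot express this combination for the cuDNN call.
block.gamma.grad_req = 'write'
block.beta.grad_req = 'add'

x = mx.nd.random.uniform(shape=(1, 3, 2, 2))
with mx.autograd.record():
    y = block(x)
(y * y).sum().backward()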

Member

@zhreshold I believe there are some grad_req = add issues in BN. Would you know if there are cases where gamma.grad_req = write and beta.grad_req = add? If not, we may just raise an error in the op if it happens.

Member

I can't think of any such use case; raising an exception is acceptable.

Member Author

@sxjscience @zhreshold Thanks for your review! I will add the grad_req check and a unit test for gradient accumulation.
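A rough sketch of such a gradient-accumulation check (not the actual test added in this PR; it assumes a GPU context so that the CuDNN path is exercised):

import mxnet as mx
from mxnet.gluon import nn


def bn_beta_grad(grad_req, num_passes, ctx=mx.gpu(0)):
    """Beta's gradient after `num_passes` identical backward passes."""
    block = nn.BatchNorm()
    block.initialize(ctx=ctx)
    block.collect_params().setattr('grad_req', grad_req)
    x = mx.nd.arange(12, ctx=ctx).reshape((1, 3, 2, 2))
    for _ in range(num_passes):
        with mx.autograd.record():
            y = block(x)
            loss = (y * y).sum()
        loss.backward()
    return block.beta.grad()


# Two accumulated passes should equal twice a single written pass.
single = bn_beta_grad('write', 1)
accumulated = bn_beta_grad('add', 2)
assert mx.nd.abs(accumulated - 2 * single).max().asscalar() < 1e-5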

Member Author

gamma.grad_req = kAddTo is not supported in the naive CPU implementation, the naive GPU implementation, or the MKLDNN implementation.

Member

I believe this points to a larger problem with the BN layer...

Member Author
@wkcn wkcn Jun 5, 2020

The gradients of the input, gamma, and beta are wrong on CPU when grad_req is 'add'. The gradient of the input is not accumulated, and the gradients of gamma and beta are both zero.

import mxnet as mx
from mxnet.gluon import nn

N = 1
C = 3
H = W = 2
block = nn.BatchNorm()
block.collect_params().initialize()
# Accumulate the gradients of gamma and beta across backward passes.
block.collect_params().setattr('grad_req', 'add')

x = mx.nd.arange(N*C*H*W).reshape((N, C, H, W))
x.attach_grad()
for i in range(2):
    with mx.autograd.record():
        y = block(x)
        loss = (y * y).sum()
    loss.backward()
# With grad_req='add' on the parameters, gamma and beta gradients should
# accumulate over the two passes; on CPU they come back as zero.
print(x.grad, block.gamma.grad(), block.beta.grad())

Member

I guess this is somewhat out of the scope of this PR. How about creating an issue, and we can have another PR to fix the problem?

Member Author
@wkcn wkcn Jun 5, 2020

I have created issue #18499.

@wkcn wkcn linked an issue Jun 6, 2020 that may be closed by this pull request
@wkcn
Member Author

wkcn commented Jun 6, 2020

Closing this PR to focus on PR #18500, which contains this PR's changes.

Successfully merging this pull request may close these issues.

incorrect grad of gluon.nn.BatchNorm when scale=False
4 participants