SABlock parameters when using more heads #7661
KumoLiu added a commit that referenced this issue on May 8, 2024:
Fixes #7661.

### Description

The changes add a parameter (_dim_head_) to set the output dimension of all the heads in the Self-attention Block (SABlock). Currently the output dimension is fixed to _hidden_size_, and when the number of heads is increased this dimension is simply split equally among all heads.

### Example

The original implementation automatically determines **_equally_distributed_head_dim_** (qkv * num_heads * equally_distributed_head_dim = 3 * hidden_size; in this example 3 * 8 * 16 = 384):

```
block = SABlock(hidden_size=128, num_heads=8)
x = torch.zeros(1, 256, 128)
x = block.qkv(x)
print(x.shape)
x = block.input_rearrange(x)
print(x.shape)

> torch.Size([1, 256, 384])
> torch.Size([3, 1, 8, 256, 16])  # <- (qkv, batch, num_heads, sequence_length, equally_distributed_head_dim)
```

The proposed implementation fixes this by introducing the new argument **_dim_head_**:

```
block_new = SABlock(hidden_size=128, num_heads=8, dim_head=32)
x = torch.zeros(1, 256, 128)
x = block_new.qkv(x)
print(x.shape)
x = block_new.input_rearrange(x)
print(x.shape)

> torch.Size([1, 256, 768])
> torch.Size([3, 1, 8, 256, 32])  # <- (qkv, batch, num_heads, sequence_length, dim_head)
```

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: NabJa <nabil.jabareen@gmail.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
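As a quick sanity check of the new argument, the parameter count of the block can be compared with and without `dim_head`. A minimal sketch, assuming a MONAI version that includes this change and that `SABlock` is importable from `monai.networks.blocks`; the `count_params` helper is not part of MONAI:

```python
import torch
from monai.networks.blocks import SABlock  # import path assumed; adjust to your MONAI version

def count_params(module: torch.nn.Module) -> int:
    """Total number of trainable parameters in a module."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

# Default behaviour: the per-head dimension is hidden_size // num_heads, so the
# qkv projection (and the parameter count) does not grow with num_heads.
default_block = SABlock(hidden_size=128, num_heads=8)

# With the new dim_head argument the per-head width is fixed at 32, so the qkv
# projection maps hidden_size -> 3 * num_heads * dim_head and the count grows.
wide_block = SABlock(hidden_size=128, num_heads=8, dim_head=32)

print(count_params(default_block), count_params(wide_block))
```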
Describe the bug
The number of parameters in the SABlock should increase when the number of heads (num_heads) is increased. However, this is not the case, which limits comparability to well-known scaling configurations such as ViT-S or ViT-B.
To Reproduce
Steps to reproduce the behavior:
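The original report does not list explicit steps; a minimal sketch that reproduces the issue on the pre-fix implementation (import path assumed) would be:

```python
from monai.networks.blocks import SABlock  # import path assumed

def count_params(module):
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

# Before the fix, the per-head dimension is derived as hidden_size // num_heads,
# so every configuration below ends up with the same number of parameters.
for num_heads in (4, 8, 16):
    block = SABlock(hidden_size=128, num_heads=num_heads)
    print(num_heads, count_params(block))  # identical counts for all head settings
```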
Expected behavior
The number of trainable parameters should increase with the number of heads.
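For reference, the expected scaling with a fixed per-head dimension can be written down directly; the helper below is hypothetical and only illustrates the arithmetic, it is not part of MONAI:

```python
# Hypothetical helper illustrating the expected scaling of the qkv projection.
def qkv_param_count(hidden_size: int, num_heads: int, dim_head: int, bias: bool = False) -> int:
    out_features = 3 * num_heads * dim_head  # q, k and v for every head
    return hidden_size * out_features + (out_features if bias else 0)

print(qkv_param_count(128, 4, 32))  # 49152
print(qkv_param_count(128, 8, 32))  # 98304 -> doubles when the number of heads doubles
```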
Environment