Matmul error when using output_all_encoded_layers = True, and pooler #451

MostHumble commented Jun 11, 2024

Hi,

First off, thanks for this great contribution!

There seems to be an issue with how the `encoder_outputs` are handled at the pooler level when passing `output_all_encoded_layers=True`:

```python
encoder_outputs = self.encoder(
    embedding_output,
    attention_mask,
    output_all_encoded_layers=output_all_encoded_layers,
    subset_mask=subset_mask)
if masked_tokens_mask is None:
    sequence_output = encoder_outputs[-1]
    pooled_output = self.pooler(
        sequence_output) if self.pooler is not None else None
else:
    # TD [2022-03-01]: the indexing here is very tricky.
    attention_mask_bool = attention_mask.bool()
```
When I pass `output_all_encoded_layers=True`, I get:

```
File ~/.conda/envs/mimibert/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/PatientTrajectoryForecasting/utils/bert_layers_mosa.py:567, in BertPooler.forward(self, hidden_states, pool)
    561 def forward(self,
    562             hidden_states: torch.Tensor,
    563             pool: Optional[bool] = True) -> torch.Tensor:
    564     # We "pool" the model by simply taking the hidden state corresponding
    565     # to the first token.
    566     first_token_tensor = hidden_states[:, 0] if pool else hidden_states
--> 567     pooled_output = self.dense(first_token_tensor)
    568     pooled_output = self.activation(pooled_output)
    569     return pooled_output

File ~/.conda/envs/mimibert/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.conda/envs/mimibert/lib/python3.11/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x54784 and 768x768)
```
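If I'm reading the encoder correctly, the per-layer tensors appended to `all_encoder_layers` are still in the unpadded layout `[total_nonpad_tokens, hidden]` rather than `[batch, seqlen, hidden]`, so `hidden_states[:, 0]` in the pooler selects the first feature of every token instead of the first token of every sequence, and the dense layer then receives a 1-D tensor it cannot multiply. A minimal standalone sketch of the shape problem (the batch/sequence sizes below are made up; only the token count matches the traceback):

```python
import torch

hidden = 768
total_nonpad_tokens = 54784  # the unpadded token count from the traceback above

# What the pooler expects: a padded output of shape [batch, seqlen, hidden].
padded = torch.randn(8, 128, hidden)
first_token = padded[:, 0]        # -> [8, 768], multiplies fine with a 768x768 Linear

# What it actually receives when the unpadded per-layer output is appended.
unpadded = torch.randn(total_nonpad_tokens, hidden)
first_col = unpadded[:, 0]        # -> [54784], the first *feature* of every token
dense = torch.nn.Linear(hidden, hidden)
dense(first_col)
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x54784 and 768x768)
```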

I believe the issue is that the padding function is not applied to the hidden states before they are appended to the list at the BERT encoder level:

```python
all_encoder_layers = []
if subset_mask is None:
    for layer_module in self.layer:
        hidden_states = layer_module(hidden_states,
                                     cu_seqlens,
                                     seqlen,
                                     None,
                                     indices,
                                     attn_mask=attention_mask,
                                     bias=alibi_attn_mask)
        if output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
    # Pad inputs and mask. It will insert back zero-padded tokens.
    # Assume ntokens is total number of tokens (padded and non-padded)
    # and ntokens_unpad is total number of non-padded tokens.
    # Then padding performs the following de-compression:
    #     hidden_states[ntokens_unpad,hidden] -> hidden_states[ntokens,hidden]
    hidden_states = bert_padding_module.pad_input(
        hidden_states, indices, batch, seqlen)
else:
```

(Edit: yes, replacing the append with the following works for me, but I haven't checked for side effects elsewhere:)

```python
all_encoder_layers.append(bert_padding_module.pad_input(
    hidden_states, indices, batch, seqlen))
```
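For context, here is roughly what the `subset_mask is None` branch would look like with that change applied. This is only a sketch reusing the variable names from the snippet above; the sole difference is the append line:

```python
all_encoder_layers = []
if subset_mask is None:
    for layer_module in self.layer:
        hidden_states = layer_module(hidden_states,
                                     cu_seqlens,
                                     seqlen,
                                     None,
                                     indices,
                                     attn_mask=attention_mask,
                                     bias=alibi_attn_mask)
        if output_all_encoded_layers:
            # Re-pad each per-layer output so every entry in all_encoder_layers
            # has shape [batch, seqlen, hidden], which is what the pooler expects.
            all_encoder_layers.append(
                bert_padding_module.pad_input(hidden_states, indices, batch,
                                              seqlen))
    # The final hidden_states is re-padded exactly as before.
    hidden_states = bert_padding_module.pad_input(
        hidden_states, indices, batch, seqlen)
```

Note that this re-pads once per layer, so it allocates an extra `[batch, seqlen, hidden]` tensor for every encoder layer when `output_all_encoded_layers=True`.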

The same thing should probably be done when `subset_mask` is not `None`...

Thanks again for your contribution to the community!
