
Fix missing position_ids argument when recompute_granularity == full #86

Open
wants to merge 1 commit into main

Conversation

xingyaoww
Contributor

When setting `--recompute_granularity full` for finetuning, we see a traceback like this:

  File "/workspace/Megatron-LLM/megatron/model/transformer.py", line 757, in forward
    attention_output, attention_bias = self.self_attention(layernorm_output,
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/Megatron-LLM/megatron/model/transformer.py", line 502, in forward
    query_layer, key_layer = apply_rotary_emb(query_layer, key_layer, self.freqs_cis, position_ids=position_ids)
  File "/workspace/Megatron-LLM/megatron/model/positional_embeddings.py", line 36, in apply_rotary_emb
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
  File "/workspace/Megatron-LLM/megatron/model/positional_embeddings.py", line 19, in reshape_for_broadcast
    assert freqs_cis.shape == (x.shape[0], x.shape[-1])
AssertionError

Tracing back, we found that `reshape_for_broadcast` is only called when `position_ids` is `None`, which means `position_ids` was NOT passed to each transformer layer when `--recompute_granularity full` was set (finetuning did work with `--recompute_granularity selective`).
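
For reference, the dispatch looks roughly like this. This is a minimal sketch reconstructed from the traceback (the assertion is copied verbatim from `positional_embeddings.py` line 19); the surrounding tensor math follows the standard Llama-style RoPE pattern and may differ slightly from the actual Megatron-LLM code:

```python
import torch

def reshape_for_broadcast(freqs_cis, x):
    # Assumes positions 0..seq_len-1; this is the assertion that fires
    # in the traceback above when position_ids is silently dropped.
    assert freqs_cis.shape == (x.shape[0], x.shape[-1])
    shape = [d if i in (0, x.ndim - 1) else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(xq, xk, freqs_cis, position_ids=None):
    # View the last dim of q/k as complex pairs for the rotary rotation.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    if position_ids is None:
        # Only reached when position_ids was not threaded through.
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    else:
        # With explicit position_ids, gather the per-token frequencies instead.
        freqs_cis = freqs_cis[position_ids]
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
```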

I chased the error further down into `megatron/model/transformer.py`: it turns out some arguments were missing when calling the checkpoint function through `_checkpointed_forward`, which this PR fixes (see the sketch below).
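
The shape of the fix is roughly the following. This is a hedged sketch, not the exact Megatron-LLM code: names like `tensor_parallel.checkpoint`, `self._get_layer`, `self.recompute_num_layers`, and `self.distribute_saved_activations` follow the upstream Megatron-LM pattern and may be spelled differently in this fork:

```python
# Assumed import; the fork may expose this as mpu.checkpoint instead.
from megatron.core import tensor_parallel

def _checkpointed_forward(self, hidden_states, attention_mask, position_ids):
    """Forward pass with full activation recompute (sketch)."""

    def custom(start, end):
        # Build a closure over a chunk of layers. Every tensor the layers
        # need must be threaded through *args, because the checkpoint
        # machinery replays this function during the backward pass.
        def custom_forward(*args):
            x_, mask_, pos_ids_ = args  # before the fix, pos_ids_ was never unpacked here
            for index in range(start, end):
                layer = self._get_layer(index)
                x_ = layer(x_, mask_, position_ids=pos_ids_)
            return x_
        return custom_forward

    layer_idx = 0
    while layer_idx < self.num_layers:
        hidden_states = tensor_parallel.checkpoint(
            custom(layer_idx, layer_idx + self.recompute_num_layers),
            self.distribute_saved_activations,
            # The fix: pass position_ids alongside hidden_states and the
            # attention mask so it reaches every recomputed layer.
            hidden_states, attention_mask, position_ids)
        layer_idx += self.recompute_num_layers
    return hidden_states
```

With `--recompute_granularity selective` this path is never taken, which is why finetuning worked there: the layers are called directly with `position_ids` rather than through the checkpointed closure.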
