Add eval_drop_last flag to fix TE eval bug #1247

Closed
j316chuck wants to merge 13 commits from the chuck/fix_te_eval_with_drop_last_flag branch

Conversation

@j316chuck (Contributor) commented Jun 3, 2024

Description

This PR introduces an eval_drop_last flag that enables the drop_last option on the ICL eval PyTorch DataLoaders, ensuring that every batch the eval dataloader yields contains exactly eval_batch_size examples. This is needed because TransformerEngine's fp8 execution requires input dimensions divisible by 8, so we must pass in batches of size 8. Previously, the eval dataloaders would return whatever examples remained in the dataset as a smaller final batch, which resulted in an error.

For example, if the dataset has 41 examples and the batch size is 8, the last batch has 41 % 8 = 1 example, which breaks TE. With eval_drop_last enabled, that final batch of size 1 is simply skipped.

Note: enabling this flag will change eval scores, since the dropped examples are excluded from the metrics.
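
For illustration, here is a minimal sketch of the intended behavior using a plain torch.utils.data.DataLoader; the dataset and sizes below are made up for the example, and the real flag is plumbed through the ICL eval dataloader config rather than constructed like this.

import torch
from torch.utils.data import DataLoader, TensorDataset

eval_batch_size = 8
dataset = TensorDataset(torch.arange(41))  # 41 examples, as in the example above

# drop_last=False (old behavior): the final batch has 41 % 8 = 1 example, which breaks TE fp8 shapes.
loader = DataLoader(dataset, batch_size=eval_batch_size, drop_last=False)
print([len(batch[0]) for batch in loader])  # [8, 8, 8, 8, 8, 1]

# drop_last=True (eval_drop_last enabled): the incomplete final batch is skipped entirely.
loader = DataLoader(dataset, batch_size=eval_batch_size, drop_last=True)
print([len(batch[0]) for batch in loader])  # [8, 8, 8, 8, 8]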

Testing

Unit Test: test_icl_task_tokenizer_and_dataloader

Integration Test:

  • Before: fp8-llama3-8b-metamath-4ep-4LEFPw 🔴

  • Error Traceback (the failing shape check is illustrated in a short sketch after these results):

[Eval batch=1/6] Eval on gsm8k/0-shot data
[Eval batch=2/6] Eval on gsm8k/0-shot data
[Eval batch=3/6] Eval on gsm8k/0-shot data
[Eval batch=4/6] Eval on gsm8k/0-shot data
[Eval batch=5/6] Eval on gsm8k/0-shot data
/usr/lib/python3/dist-packages/composer/core/data_spec.py:37: UserWarning: Cannot split tensor of length 2 into batches of size 8. As it is smaller, no splitting will be done. This may happen on the last batch of a dataset if it is a smaller size than the microbatch size.
 warnings.warn(
/usr/lib/python3/dist-packages/composer/core/data_spec.py:26: UserWarning: Cannot split list of length 2 into batches of size 8. As it is smaller, no splitting will be done. This may happen on the last batch of a dataset if it is a smaller size than the microbatch size.
...
 [rank6]:   File "/usr/lib/python3/dist-packages/transformer_engine/pytorch/utils.py", line 235, in assert_dim_for_fp8_exec
[rank6]:     tensor.dim() == 2
[rank6]: AssertionError: FP8 execution requires 2D input matrices with height divisible by 8 and width divisible by 16, but got tensor with dims=[1404, 4096]
  • After: fp8-llama3-8b-metamath-4ep-0uiOJb
[Eval batch=1/5] Eval on gsm8k/0-shot data
[Eval batch=2/5] Eval on gsm8k/0-shot data
[Eval batch=3/5] Eval on gsm8k/0-shot data
[Eval batch=4/5] Eval on gsm8k/0-shot data
[Eval batch=5/5] Eval on gsm8k/0-shot data:
Eval metrics/gsm8k/0-shot/InContextLearningGenerationExactMatchAccuracy: 0.6016
  • Reference run: llama3-8b-metamath-4ep-jaIcPX with no skipped batches
Eval metrics/gsm8k/0-shot/InContextLearningGenerationExactMatchAccuracy: 0.5807
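
As a sanity check on the traceback above, the fp8 shape requirement can be verified directly against the reported dims; this is only an illustrative calculation, not TransformerEngine's actual code.

height, width = 1404, 4096        # dims from the AssertionError in the traceback
print(height % 8, width % 16)     # 4 0 -> the height is not divisible by 8, so fp8 execution asserts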

Issues Fixed

https://databricks.atlassian.net/browse/RGENAI-165

@j316chuck force-pushed the chuck/fix_te_eval_with_drop_last_flag branch from 25e64ee to d821d42 on June 3, 2024 06:04
@j316chuck requested review from b-chu, dakinggg and irenedea on June 3, 2024 07:09
@j316chuck changed the title from "add drop last flag" to "Add eval_drop_last flag to fix TE eval bug" on Jun 3, 2024
@j316chuck marked this pull request as ready for review on June 3, 2024 07:25
@dakinggg (Collaborator) left a comment

hmm, this is not good because we should not be dropping evaluation data. Eval results should be exact.

@b-chu (Contributor) commented Jun 5, 2024

Agreed with Daniel here

@j316chuck requested a review from a team as a code owner on June 8, 2024 01:29
@dakinggg marked this pull request as draft on June 8, 2024 02:25
@snarayan21 (Contributor) commented

Can we disable TE layers just for eval if they have this batch size requirement? Or turn off fp8 temporarily?

@mvpatel2000 (Collaborator) left a comment

Can we doc for this? Agree dropping is bad

@j316chuck (Contributor, Author) commented

Closing as we are going with a different approach here

@j316chuck closed this Jul 11, 2024