
[litgpt benchmark] enable force_recompute_fp8_weight_in_bwd when torchao.float8 is used with FSDP2 #1528

Open · crcrpar wants to merge 1 commit into main from crpa/ao-recompute_fp8_weight_in_bwd

Conversation

@crcrpar (Collaborator) commented Dec 9, 2024

What does this PR do?

As per the title: enable torchao.float8's force_recompute_fp8_weight_in_bwd option when FSDP2 is used.

ref: pytorch/ao@e919558

Benchmarked on 8 H100s with pjnl-20241209. The command used:

torchrun --nproc-per-node 8 --local-ranks-filter 0 --role rank --tee 3 thunder/benchmarks/benchmark_litgpt.py --model_name <MODEL_NAME> --compile inductor --distributed_mode fsdp2 --shard_mode zero2 --use_torchao_fp8_linear true --use_torchao_fp8_allgather true --use_torchao_fp8_precompute_scale_for_fsdp true
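For context, here is a minimal sketch of how these flags plausibly map onto torchao.float8; it is not the benchmark script's actual code. The helper name `configure_float8_for_fsdp2` and its signature are assumptions, while `Float8LinearConfig`, `convert_to_float8_training`, and `precompute_float8_dynamic_scale_for_fsdp` are torchao.float8's public API.

```python
# Hypothetical sketch (not benchmark_litgpt.py itself): wiring the CLI flags
# above into torchao.float8. Helper name and signature are assumptions.
import torch
from torchao.float8 import (
    Float8LinearConfig,
    convert_to_float8_training,
    precompute_float8_dynamic_scale_for_fsdp,
)

def configure_float8_for_fsdp2(model: torch.nn.Module, fp8_allgather: bool) -> torch.nn.Module:
    config = Float8LinearConfig(
        # --use_torchao_fp8_allgather true: all-gather FSDP2 weight shards in fp8.
        enable_fsdp_float8_all_gather=fp8_allgather,
        # What this PR enables: recompute the fp8-cast weight in the backward
        # pass instead of saving it from forward, trading compute for memory.
        force_recompute_fp8_weight_in_bwd=True,
    )
    # --use_torchao_fp8_linear true: swap eligible nn.Linear modules for Float8Linear.
    return convert_to_float8_training(model, config=config)

# --use_torchao_fp8_precompute_scale_for_fsdp true: after each optimizer step,
# precompute the dynamic fp8 scales for the FSDP2-sharded weights, e.g.
#   optimizer.step()
#   precompute_float8_dynamic_scale_for_fsdp(model)
```

The memory savings reported below are consistent with no longer stashing the fp8-cast weight between forward and backward.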

Llama-2-7b-hf:

| branch | perf (tokens/s/GPU) | mem usage (GB) |
| --- | --- | --- |
| main | 13947.29 | 34.26 |
| this PR | 13995.80 | 27.69 |

Llama-3-8B:

| branch | perf (tokens/s/GPU) | mem usage (GB) |
| --- | --- | --- |
| main | 12404.18 | 58.65 |
| this PR | 12414.15 | 51.67 |

cc @crcrpar

ref: pytorch/ao@e919558
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
@crcrpar force-pushed the crpa/ao-recompute_fp8_weight_in_bwd branch from a824ae4 to 9ad3327 on December 17, 2024 at 14:02