
Distributed Sharded Layernorm/RMSNorm #12635

Merged · 23 commits into main from kpaigwar/distrbuted_sharded_LN · Sep 20, 2024

Conversation

@kpaigwar (Contributor) commented on Sep 13, 2024

Summary

This PR introduces sharded distributed LayerNorm for the pre- and post-all-gather stages, reusing the existing single-device sharded LayerNorm host code. The implementation includes new kernels and comprehensive test coverage.
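The PR text doesn't spell out the math, but the standard two-stage scheme for width-sharded normalization matches the split described here: each device first reduces its local shard to partial statistics, the small statistics tensors are all-gathered, and each device then normalizes its shard with the combined result. A minimal NumPy reconstruction (mine, not code from this PR; for RMSNorm only the sum-of-squares path is needed):

```python
import numpy as np

def pre_all_gather_stats(x_shard):
    # Stage 1, run independently on each device: reduce the local width
    # shard to partial statistics -- sum(x) and sum(x^2).
    s = x_shard.sum(axis=-1, keepdims=True)
    q = np.square(x_shard).sum(axis=-1, keepdims=True)
    return np.concatenate([s, q], axis=-1)  # shape [..., 2]

def post_all_gather_norm(x_shard, gathered, full_width, gamma, beta, eps=1e-5):
    # Stage 2, after the stats have been all-gathered: combine the
    # per-device partials into a global mean/variance, then normalize
    # the local shard with the local slice of the weights.
    mean = gathered[..., 0::2].sum(axis=-1, keepdims=True) / full_width
    var = gathered[..., 1::2].sum(axis=-1, keepdims=True) / full_width - mean**2
    return gamma * (x_shard - mean) / np.sqrt(var + eps) + beta

# Sanity check with the PR's activation shape, sharded over 4 "devices".
x = np.random.randn(1, 1, 32, 8192).astype(np.float32)
shards = np.split(x, 4, axis=-1)
gathered = np.concatenate([pre_all_gather_stats(s) for s in shards], axis=-1)
out = np.concatenate(
    [post_all_gather_norm(s, gathered, x.shape[-1], 1.0, 0.0) for s in shards], axis=-1
)
ref = (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + 1e-5)
assert np.allclose(out, ref, atol=1e-4)
```

The appeal of this split is that only the small per-device statistics tensor has to cross the interconnect, not the full activation.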

Key Changes

  1. Implemented sharded distributed LayerNorm for the pre/post-all-gather stages (a hypothetical usage sketch follows this list)

    • Reuses the existing single-device sharded LayerNorm host code
    • Ensures compatibility with the current implementation
  2. Performance and Accuracy

    • No impact on the performance or accuracy of the existing single-device sharded LayerNorm
  3. New Kernel Development

    • Added new kernels specifically for the pre/post-all-gather operations
  4. Test Coverage

    • Implemented test coverage for pre/post-all-gather on a single device
    • Added end-to-end (e2e) test coverage on Galaxy (TG)
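On the API side, here is a hypothetical sketch of how the new ops compose with the existing all-gather. The op names (`ttnn.rms_norm_pre_all_gather`, `ttnn.rms_norm_post_all_gather`), their arguments, and the `x_shard`/`gamma_shard` tensors are inferred from the PR summary rather than taken from the merged code; the tests added in this PR show the canonical call pattern:

```python
import ttnn

# Hypothetical call pattern; names and arguments are inferred from the
# PR description, not confirmed against the merged code.

# Stage 1 (compute, per device): reduce the local width shard to stats.
stats = ttnn.rms_norm_pre_all_gather(x_shard, dtype=ttnn.bfloat16)

# Communication: gather every device's statistics (ttnn.all_gather is a
# pre-existing op; gathering along the width dim is assumed here).
gathered_stats = ttnn.all_gather(stats, dim=3)

# Stage 2 (compute, per device): normalize the local shard using the
# combined statistics and this device's slice of the weights.
out = ttnn.rms_norm_post_all_gather(
    x_shard, gathered_stats, epsilon=1e-6, weight=gamma_shard
)
```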

Performance Benchmarks

Full Activation Shape: [1, 1, 32, 8192]
Activation and Weights dtype: ttnn.bfloat16

| Stage           | Norm Type | Num Distributed Devices | Num Cores | FW Duration (ns) |
|-----------------|-----------|-------------------------|-----------|------------------|
| pre_all_gather  | RMSNorm   | 4                       | 32        | 6855             |
| post_all_gather | RMSNorm   | 4                       | 32        | 6396             |
| pre_all_gather  | LayerNorm | 4                       | 32        | 9441             |
| post_all_gather | LayerNorm | 4                       | 32        | 7944             |

Future Steps

  1. Add normalization support for batch_size > 32
  2. Create helper functions for CB (circular buffer) creation to improve readability


@kpaigwar marked this pull request as ready for review on September 16, 2024 14:03
@kpaigwar force-pushed the kpaigwar/distrbuted_sharded_LN branch from c58601e to 943bb69 on September 16, 2024 14:59

@TT-BrianLiu (Contributor) left a comment:


Change copyright to 2024 for all new files.

@kpaigwar force-pushed the kpaigwar/distrbuted_sharded_LN branch from f646f4e to 4a8172f on September 20, 2024 17:58
@kpaigwar merged commit 96ea95f into main on Sep 20, 2024
6 checks passed
@kpaigwar deleted the kpaigwar/distrbuted_sharded_LN branch on September 20, 2024 18:02