Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GLU support #38

Merged
merged 25 commits into from
Dec 1, 2023
Merged

Add GLU support #38

merged 25 commits into from
Dec 1, 2023

Conversation

sashaDoubov
Copy link
Collaborator

This change adds GLU blocks to megablocks (replacing vanilla MLPs), and does some refactoring around the mlp types, including an MLP_TYPE_REGISTRY.
Note, this is unoptimized at the moment:

  • no support for the memory optimized mlp version yet
  • w1 & v1 are not fused into one op that is split

@sashaDoubov sashaDoubov changed the title Add GLU support to megablocks Add GLU support Nov 8, 2023
megablocks/layers/dmoe.py Outdated Show resolved Hide resolved
megablocks/layers/dmoe_test.py Show resolved Hide resolved
megablocks/layers/dmoe_test.py Outdated Show resolved Hide resolved
megablocks/layers/glu.py Show resolved Hide resolved
megablocks/layers/glu.py Outdated Show resolved Hide resolved
megablocks/layers/glu.py Outdated Show resolved Hide resolved
megablocks/layers/glu.py Outdated Show resolved Hide resolved
megablocks/layers/glu.py Show resolved Hide resolved
megablocks/layers/glu.py Show resolved Hide resolved
@sashaDoubov sashaDoubov marked this pull request as ready for review November 10, 2023 17:59
@sashaDoubov sashaDoubov marked this pull request as draft November 10, 2023 17:59
@sashaDoubov sashaDoubov marked this pull request as ready for review November 10, 2023 23:00
@sashaDoubov
Copy link
Collaborator Author

I've updated the PR based on @tgale96 's feedback, mainly:

  • adding registry
  • isolated test file

I've left a few of the issues unresolved as I work through them, namelystk support for a cleaner implementation (as well as some code quality changes)

megablocks/layers/dmlp_registry.py Outdated Show resolved Hide resolved
megablocks/layers/glu_test.py Outdated Show resolved Hide resolved
megablocks/layers/glu_test.py Outdated Show resolved Hide resolved
@@ -38,8 +38,8 @@ class Arguments:

# Compute arguments.
memory_optimized_mlp : bool = False
mlp_type: str = 'mlp'
grouped_mlp: bool = False
mlp_type : str = 'mlp'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you mean to add these spaces? Looks like we're actually mixed on having them and not having them in this file...

Copy link
Collaborator Author

@sashaDoubov sashaDoubov Nov 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmmm yea I thought that the spaces would match the existing style of the file


MlpType = Union[mlp.SparseMLP, glu.SparseGLU]

class dMlpRegistry:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stylistic thing - can we remove this class and have get just be a function on the module? Then REGISTRY can be a private, global? i.e., _REGISTRY?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've refactored it to get rid of the class, let me know if it looks good now!

@tgale96
Copy link
Contributor

tgale96 commented Dec 1, 2023

LGTM! Ready to merge?

@sashaDoubov
Copy link
Collaborator Author

@tgale96 great, yes!

@tgale96 tgale96 merged commit 059ae20 into databricks:main Dec 1, 2023
@tgale96
Copy link
Contributor

tgale96 commented Dec 1, 2023

Thanks for the contribution Sasha! This is awesome.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants