
fix: deprecate flexible mlp heads #160

Merged
merged 31 commits into main from deprecate-flexible-mlp-heads on Nov 6, 2024

Conversation

rheasukthanker (Collaborator)

Reference Issues/PRs

Resolves #146

What does this implement/fix? Explain your changes.

Minimal Example / How should this PR be tested?

Any other comments?


By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the
terms of your choice.

@rheasukthanker changed the title from "Deprecate flexible mlp heads" to "fix: deprecate flexible mlp heads" on Nov 4, 2024
@rheasukthanker marked this pull request as draft on November 4, 2024 at 19:41
@rheasukthanker marked this pull request as ready for review on November 4, 2024 at 20:29

@aaronkl (Collaborator) left a comment:

Looks good, just some nitpicks

@@ -19,6 +19,7 @@ dependencies = [
"litgpt[all]==0.5.0",
"syne-tune[moo]>=0.13",
"torchvision>=0.18",
"tokenizers==0.20.0",

Collaborator:

Is this relevant for this PR? If not I would drop it to keep the changelogs clean

Collaborator (author):

Yes, an unrelated test case fails if this is not specified, hence I would keep it

Collaborator:

ok makes sense

print(
    f"Mini model {compute_flops(model=model, metric='macs')} macs"
)
print(f"Mini model {compute_flops(model=model, metric='flops')} flops")

Collaborator:

Also here if this is not relevant for this PR, I would drop it

Collaborator (author):

Ruff complains about the way this file is formatted; I am not sure how the checks missed this.

sub_network_intermediate_size: list,
sub_network_num_heads: list,
sub_network_intermediate_size: int,
sub_network_num_heads: int,
sub_network_n_layers: int,
sub_network_query_groups=None,

Collaborator:

Can we use type hints here?
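
For illustration, a minimal sketch of what a fully type-hinted signature could look like (the parameter names are taken from the diff above; the exact parameter set, defaults, and enclosing class are assumptions, not the merged code):

    # Sketch only -- hypothetical stand-in for the module being edited, not the actual whittle API.
    class GPT:
        def set_sub_network(
            self,
            sub_network_n_embd: int,
            sub_network_intermediate_size: int,
            sub_network_num_heads: int,
            sub_network_n_layers: int,
            sub_network_query_groups: int | None = None,
            sub_network_head_size: int | None = None,
        ) -> None:
            # Body omitted in this sketch; the real method configures the sub-network.
            ...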

@gabikadlecova (Collaborator) left a comment:

Overall, looks good to me.

I would suggest adding support for getting subnet head_size and n_query_groups from the supernet fields - this makes extraction very easy to do.

@@ -184,12 +184,29 @@ def set_sub_network(
        self.sub_network_n_layers = sub_network_n_layers
        self.transformer.wte.set_sub_network(self.sub_network_n_embd)
        self.transformer.ln_f.set_sub_network(self.sub_network_n_embd)
        if sub_network_query_groups is None:
            if self.config.n_query_groups == 1:
                sub_network_query_groups = 1

Collaborator:

Can we also set this as a field?
self.sub_network_query_groups = sub_network_query_groups

Then, we can easily get the value for the extract function without duplicating the query_group and head_size computation:

supernet = GPT(config)
supernet.set_sub_network(**subnet_args)  # sub_network_query_groups == None
# ... sub_network_query_groups gets computed here

subnet_config = ... # same as subnet_args
subnet_config.n_query_groups = supernet.sub_network_query_groups
subnet_config.head_size = supernet.sub_network_head_size
subnet_correct_sizes = extract_sub_network(supernet, subnet_config)

@rheasukthanker (Collaborator, author) commented on Nov 6, 2024:

@gabikadlecova I don't clearly understand the changes needed to the extract function and the changes needed in the tests. Could you perhaps create a separate PR for that after this PR is merged, or push to this branch directly?

Collaborator (author):

Both sub_network_query_groups and sub_network_head_size are supernet fields now

Collaborator:

@rheasukthanker I can push to this branch, it's a small change

Comment on lines +53 to +54
self.sub_network_query_groups = sub_network_query_groups
self.sub_network_head_size = sub_network_head_size

Collaborator:

Since we set it like this without checking for None, these should become positional args.

Collaborator (author):

The way this is initialized currently, they can never be None.

Collaborator:

Sorry, I meant to tag lines 48-49

      sub_network_query_groups=None,
      sub_network_head_size=None,

Since now they cannot be None, we might want to change it to

      sub_network_query_groups: int,
      sub_network_head_size: int,

Collaborator:

I think they should be either None or int no? Since they are None by default?

sub_network_query_groups: int | None,
sub_network_head_size: int | None,

Collaborator:

@aaronkl yes, but I think they should not be allowed to be None anymore - the computation is done in model.py and hence None is not a valid value here anymore

Collaborator:

We may want to add a test case where the default head_size and/or n_query_groups of the subnet become different, and we need to compute it / copy it from supernet.sub_network_head_size (see my comment on whittle/models/gpt/model.py).
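
For illustration, a rough sketch of what such a test case could look like, reusing the extraction pattern from the snippet earlier in this thread (the Config construction, subnet_args, and import locations are placeholders/assumptions, not the test that was actually added):

    # Rough sketch only -- assumed names and arguments, shown to illustrate the idea.
    from copy import deepcopy

    def test_extract_with_computed_head_size_and_query_groups():
        config = Config(...)  # hypothetical supernet config with n_query_groups > 1
        supernet = GPT(config)

        # Set a sub-network without passing head_size / query groups explicitly,
        # so set_sub_network has to compute them internally.
        supernet.set_sub_network(**subnet_args)  # subnet_args: placeholder dict

        # Copy the computed values from the supernet fields instead of recomputing them.
        subnet_config = deepcopy(config)
        subnet_config.n_query_groups = supernet.sub_network_query_groups
        subnet_config.head_size = supernet.sub_network_head_size

        subnet = extract_sub_network(supernet, subnet_config)
        assert subnet.config.n_query_groups == supernet.sub_network_query_groups
        assert subnet.config.head_size == supernet.sub_network_head_size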

Collaborator:

@rheasukthanker since you set it as supernet fields already, I only changed the test case to use it

@gabikadlecova (Collaborator) left a comment:

The subnet head size and query groups are fields now - the PR is ready to be merged.

I'd still change the subnet head_size/query groups to positional arguments and remove None from type hints (they cannot be None now and should not be None). But it's minor

Comment on lines +51 to +52
self.sub_network_head_size: int | None = self.config.head_size
self.sub_network_query_groups: int | None = self.config.n_query_groups

Collaborator:

Should it really be None now? It has to be present in the config and when setting the subnetwork, it will never be None

Collaborator:

The problem is that our linter complains since it does some automated type checking later. Maybe there is a more elegant solution?

Comment on lines +234 to +235
self.sub_network_head_size: int | None = self.config.head_size
self.sub_network_query_groups: int | None = self.config.n_query_groups
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, can it not be None?

@gabikadlecova self-requested a review on November 6, 2024 at 08:56

@gabikadlecova (Collaborator) left a comment:

Subnetwork head size and n_query_groups are fields now

@rheasukthanker merged commit a07ee5a into main on Nov 6, 2024
7 checks passed
@rheasukthanker deleted the deprecate-flexible-mlp-heads branch on November 6, 2024 at 12:17
Successfully merging this pull request may close these issues.

use fixed number of heads / intermediate size per layer