fix: deprecate flexible mlp heads #160
Conversation
Looks good, just some nitpicks
@@ -19,6 +19,7 @@ dependencies = [
     "litgpt[all]==0.5.0",
     "syne-tune[moo]>=0.13",
     "torchvision>=0.18",
+    "tokenizers==0.20.0",
Is this relevant for this PR? If not, I would drop it to keep the changelogs clean.
Yes, an unrelated test case fails if this is not specified, hence I would keep it.
ok makes sense
print(
    f"Mini model {compute_flops(model=model, metric='macs')} macs"
)
print(f"Mini model {compute_flops(model=model, metric='flops')} flops")
Also here: if this is not relevant for this PR, I would drop it.
Ruff complains about the way this file is formatted; I am not sure how the checks missed this.
whittle/models/gpt/model.py (Outdated)
-    sub_network_intermediate_size: list,
-    sub_network_num_heads: list,
+    sub_network_intermediate_size: int,
+    sub_network_num_heads: int,
     sub_network_n_layers: int,
     sub_network_query_groups=None,
Can we use type hints here?
Overall, looks good to me.
I would suggest adding support for getting subnet head_size and n_query_groups from the supernet fields - this makes extraction very easy to do.
whittle/models/gpt/model.py (Outdated)
@@ -184,12 +184,29 @@ def set_sub_network(
         self.sub_network_n_layers = sub_network_n_layers
         self.transformer.wte.set_sub_network(self.sub_network_n_embd)
         self.transformer.ln_f.set_sub_network(self.sub_network_n_embd)
+        if sub_network_query_groups is None:
+            if self.config.n_query_groups == 1:
+                sub_network_query_groups = 1
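(The hunk header shows 17 lines added in total, so the excerpt above is truncated. A plausible continuation of the fallback, hypothetical beyond the three lines shown:)

# Hypothetical sketch, not the merged code: derive the sub-network's
# query-group count from the supernet config when the caller passes None.
if sub_network_query_groups is None:
    if self.config.n_query_groups == 1:
        # Multi-query attention: there is only one shared key/value group.
        sub_network_query_groups = 1
    else:
        # Otherwise fall back to the supernet's own grouping.
        sub_network_query_groups = self.config.n_query_groups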
Can we also set this as a field?

self.sub_network_query_groups = sub_network_query_groups

Then we can easily get the value for the extract function without duplicating the query_group and head_size computation:

supernet = GPT(config)
supernet.set_sub_network(**subnet_args)  # sub_network_query_groups == None
# ... sub_network_query_groups gets computed here

subnet_config = ...  # same as subnet_args
subnet_config.n_query_groups = supernet.sub_network_query_groups
subnet_config.head_size = supernet.sub_network_head_size

subnet_correct_sizes = extract_sub_network(supernet, subnet_config)
@gabikadlecova I don't clearly understand the changes needed to the extract function and the changes needed in the tests. Could you perhaps create a separate PR for that after this PR is merged, or push to this branch directly?
Both sub_network_query_groups and sub_network_head_size are supernet fields now
@rheasukthanker I can push to this branch, it's a small change
self.sub_network_query_groups = sub_network_query_groups
self.sub_network_head_size = sub_network_head_size
Since we set it like this without checking for None, these should become positional args.
The way this is initialized currently, they can never be None.
Sorry, I meant to tag lines 48-49:

sub_network_query_groups=None,
sub_network_head_size=None,

Since now they cannot be None, we might want to change it to:

sub_network_query_groups: int,
sub_network_head_size: int,
I think they should be either None or int, no? Since they are None by default:

sub_network_query_groups: int | None,
sub_network_head_size: int | None,
@aaronkl yes, but I think they should not be allowed to be None anymore - the computation is done in model.py, and hence None is not a valid value here anymore.
test/test_extract.py (Outdated)
We may want to add a test case where the default head_size and/or n_query_groups of the subnet become different, and we need to compute it/copy it from supernet.sub_network_head_size (see my comment above on whittle/models/gpt/model.py).
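A sketch of what such a test could look like (import paths, config fields, and the extract_sub_network entry point are assumptions based on this thread, not verified against the repo):

from copy import deepcopy

from litgpt import Config  # assumed: whittle builds on litgpt's Config
from whittle.models.gpt import GPT
from whittle.models.gpt.extract import extract_sub_network  # assumed path


def test_extract_uses_supernet_computed_fields():
    config = Config(
        block_size=128,
        padded_vocab_size=128,
        n_layer=4,
        n_head=8,
        n_embd=64,
        n_query_groups=4,
        head_size=8,
        intermediate_size=256,
    )
    supernet = GPT(config)
    # Leave query groups / head size unset so the supernet computes them.
    supernet.set_sub_network(
        sub_network_n_embd=32,
        sub_network_intermediate_size=128,
        sub_network_num_heads=4,
        sub_network_n_layers=2,
    )

    # Build the subnet config from the supernet config, copying the values
    # the supernet computed instead of re-deriving them here.
    subnet_config = deepcopy(config)
    subnet_config.n_embd = 32
    subnet_config.intermediate_size = 128
    subnet_config.n_head = 4
    subnet_config.n_layer = 2
    subnet_config.n_query_groups = supernet.sub_network_query_groups
    subnet_config.head_size = supernet.sub_network_head_size

    subnet = extract_sub_network(supernet, subnet_config)
    assert subnet.config.n_query_groups == supernet.sub_network_query_groups
    assert subnet.config.head_size == supernet.sub_network_head_size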
@rheasukthanker since you set these as supernet fields already, I only changed the test case to use them.
The subnet head size and query groups are fields now - the PR is ready to be merged.
I'd still change the subnet head_size/query groups to positional arguments and remove None from the type hints (they cannot be None now and should not be None). But it's minor.
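For concreteness, the end state this suggests would be a signature along these lines (a sketch assembled from the hunks in this thread, not the merged code):

def set_sub_network(
    self,
    sub_network_n_embd: int,
    sub_network_intermediate_size: int,
    sub_network_num_heads: int,
    sub_network_n_layers: int,
    sub_network_query_groups: int,
    sub_network_head_size: int,
) -> None:
    ...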
self.sub_network_head_size: int | None = self.config.head_size
self.sub_network_query_groups: int | None = self.config.n_query_groups
Should it really be None now? It has to be present in the config, and when setting the subnetwork it will never be None.
The problem is that our linter complains since it does some automated type checking later. Maybe there is a more elegant solution?
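(One possible workaround, sketched on the assumption that the config fields are annotated as Optional but always populated: narrow them once, then the attributes can be annotated as plain int.)

# Sketch only: explicitly narrow the Optional config fields before
# assignment so the attributes can stay typed as int.
assert self.config.head_size is not None
assert self.config.n_query_groups is not None
self.sub_network_head_size: int = self.config.head_size
self.sub_network_query_groups: int = self.config.n_query_groups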
self.sub_network_head_size: int | None = self.config.head_size
self.sub_network_query_groups: int | None = self.config.n_query_groups
Same here, can it not be None?
Subnetwork head size and n_query_groups are fields now
Reference Issues/PRs
Resolves #146
What does this implement/fix? Explain your changes.
Minimal Example / How should this PR be tested?
Any other comments?
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.