-
Notifications
You must be signed in to change notification settings - Fork 15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add from_pretrained with HFHub #9
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added:
- risoluzione di input dentro al checkpoint name;
- numero di canali dentro datasets_info;
- stringhe con " " ;
- fix filename checkpoint;
@fpaissan please check dependencies.
@@ -1,2 +1,9 @@ | |||
from .networks import PhiNet | |||
from .blocks import PhiNetConvBlock | |||
|
|||
|
|||
datasets_info = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe add a comment explaining why and where this is used in the code.
phinet/networks.py
Outdated
resolution, | ||
device=None, | ||
): | ||
def num_class_error(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we implement the comment on line 46, we no longer need this.
phinet/networks.py
Outdated
|
||
assert ( | ||
num_classes == phinet.datasets_info[dataset]["Nclasses"] | ||
), num_class_error() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For this, we can directly put a string; it should raise AssertionError by default -- I forgot about this.
Something like:
assert (
num_classes == phinet.datasets_info[dataset]["Nclasses"]
), "messaggio di errore"
-- see here alse
@matteobeltrami we should also handle the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added description to the datasets_info dictionary, modified the assertion, changed print("WARNING ...") with logging.warning("...").
@fpaissan please check dependencies
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added description to the datasets_info dictionary, modified the assertion, changed print("WARNING ...") with logging.warning("...").
Performed black and flake8 checks
@fpaissan please check dependencies
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added "classifier" (initialized as True) to the from_pretrained() method, rewritten the model initialization instruction as follows: model.load_state_dict(state_dict["state_dict"], strict=False). Notice that it differs from the previous version only in the "strict" parameter that determines whether to strictly enforce that the keys in state_dict match the keys returned by the state_dict() function, as explained in the pytorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict.
@fpaissan please check dependencies
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Still missing:
@fpaissan please check dependencies.