Slightly incorrect documentation for VAE #543

Closed
jeremyadamsfisher opened this issue Jan 26, 2021 · 4 comments · Fixed by #557
Labels
documentation Improvements or additions to documentation model

Comments

@jeremyadamsfisher
Contributor

jeremyadamsfisher commented Jan 26, 2021

The usage example for VAE calls from_pretrained on the class itself, as if it were a static method:

https://github.com/PyTorchLightning/pytorch-lightning-bolts/blob/4e1c1502b70e4be59a9ac80878ec5ad5a212f87e/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py#L30

Indeed, running this line causes an error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-64-4c7d0e8ca436> in <module>
----> 1 vae = VAE.from_pretrained('cifar10-resnet18')

TypeError: from_pretrained() missing 1 required positional argument: 'checkpoint_name'

This is because the method is actually implemented as an instance method, so the checkpoint name passed on the class gets bound to self:

def from_pretrained(self, checkpoint_name):
    if checkpoint_name not in VAE.pretrained_urls:
        raise KeyError(str(checkpoint_name) + ' not present in pretrained weights.')

    return self.load_from_checkpoint(VAE.pretrained_urls[checkpoint_name], strict=False)
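The failure can be reproduced without Lightning. The sketch below uses a hypothetical `ToyVAE` class (not the bolts `VAE`) whose `from_pretrained` is an instance method shaped like the one above; its `load_from_checkpoint` is a stand-in that only returns the URL, and the URL itself is a placeholder. Calling the method on the class binds `'cifar10-resnet18'` to `self`, which leaves `checkpoint_name` unfilled.

```python
class ToyVAE:
    """Hypothetical stand-in for the bolts VAE; not the real class."""

    # Class attribute, like VAE.pretrained_urls above (placeholder URL).
    pretrained_urls = {"cifar10-resnet18": "https://example.com/cifar10-resnet18.ckpt"}

    def __init__(self, input_height):
        self.input_height = input_height

    def load_from_checkpoint(self, url, strict=True):
        # Stand-in only: the real Lightning method restores weights from the URL.
        return url

    def from_pretrained(self, checkpoint_name):
        # Instance method: `self` is bound first, then checkpoint_name.
        if checkpoint_name not in ToyVAE.pretrained_urls:
            raise KeyError(str(checkpoint_name) + " not present in pretrained weights.")
        return self.load_from_checkpoint(ToyVAE.pretrained_urls[checkpoint_name], strict=False)


def call_on_class():
    """Mimic the documented call; returns the TypeError message, or None if no error."""
    try:
        ToyVAE.from_pretrained("cifar10-resnet18")
    except TypeError as err:
        return str(err)
    return None


print(call_on_class())
# Calling on an instance, as in the documentation workaround, succeeds.
print(ToyVAE(32).from_pretrained("cifar10-resnet18"))
```

The exact wording of the TypeError varies slightly across Python versions (newer versions prefix the class name), but it always names the missing `checkpoint_name` argument.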

It seems to me that either:

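  • the documentation should be updated to reflect this (vae = VAE(32).from_pretrained(...))
  • the method should be rewritten as a classmethod
@classmethod
def from_pretrained(cls, checkpoint_name):
    # snip
    # Inside a classmethod there is no `self`; pretrained_urls is a class attribute,
    # so no throwaway instance is needed to look it up.
    return cls.load_from_checkpoint(cls.pretrained_urls[checkpoint_name], strict=False)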
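A minimal sketch of the classmethod option, again using a hypothetical `ToyVAE` rather than the bolts class. Lightning's `load_from_checkpoint` is itself a classmethod, so the stand-in here is one too; it only records the (placeholder) URL instead of restoring weights. Because `cls` is used throughout, calling through a subclass returns an instance of that subclass.

```python
class ToyVAE:
    """Hypothetical stand-in for the bolts VAE; not the real class."""

    # Placeholder URL; the real mapping lives in VAE.pretrained_urls.
    pretrained_urls = {"cifar10-resnet18": "https://example.com/cifar10-resnet18.ckpt"}

    def __init__(self, input_height=32):
        self.input_height = input_height
        self.loaded_from = None

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path, strict=True):
        # Stand-in for Lightning's classmethod of the same name, which would
        # rebuild the module from the checkpoint's hyperparameters and weights.
        model = cls()
        model.loaded_from = checkpoint_path
        return model

    @classmethod
    def from_pretrained(cls, checkpoint_name):
        # `cls` replaces `self`, so the method can be called on the class directly.
        if checkpoint_name not in cls.pretrained_urls:
            raise KeyError(str(checkpoint_name) + " not present in pretrained weights.")
        return cls.load_from_checkpoint(cls.pretrained_urls[checkpoint_name], strict=False)


vae = ToyVAE.from_pretrained("cifar10-resnet18")
print(type(vae).__name__, vae.loaded_from)
```

The trade-off is that the caller no longer chooses constructor arguments up front; the checkpoint alone determines how the module is built.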

Here is a minimal failing example on colab: https://colab.research.google.com/drive/1quvrQUyCIMC7Cq9QCjaJfip5hsi9jwSJ?usp=sharing

Happy to write a PR, if desired

@github-actions

Hi! Thanks for your contribution, great first issue!

@akihironitta added the model and documentation labels Jan 29, 2021
@akihironitta
Contributor

@jeremyadamsfisher Hi, thank you for reporting the issue! Following the suggestion by @ananyahjha93 in #200 (comment), let's just fix the docs for now.

  1. from_pretrained() needs to be an instance method and not a static method. In most cases, you will initialize the lightning module with specific params according to the weights being loaded.
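A minimal sketch of that reasoning, assuming a hypothetical `ToyVAE` whose `pretrained_configs` mapping (invented for illustration; it is not part of bolts) records the constructor arguments a checkpoint expects. Because `from_pretrained` runs on an instance, it can compare the parameters the caller already chose against the checkpoint before loading anything. This illustrates the design argument only and does not reproduce the bolts code.

```python
class ToyVAE:
    """Hypothetical module; illustrates the design argument, not bolts' code."""

    # Hypothetical: constructor arguments each checkpoint expects
    # (CIFAR-10 images are 32x32; the encoder is taken from the checkpoint name).
    pretrained_configs = {"cifar10-resnet18": {"input_height": 32, "enc_type": "resnet18"}}

    def __init__(self, input_height, enc_type="resnet18"):
        self.input_height = input_height
        self.enc_type = enc_type

    def from_pretrained(self, checkpoint_name):
        # Instance method: the caller has already chosen the architecture,
        # so we can check that it matches the weights before loading them.
        if checkpoint_name not in self.pretrained_configs:
            raise KeyError(str(checkpoint_name) + " not present in pretrained weights.")
        expected = self.pretrained_configs[checkpoint_name]
        actual = {"input_height": self.input_height, "enc_type": self.enc_type}
        if actual != expected:
            raise ValueError(f"{checkpoint_name} expects {expected}, got {actual}")
        # A real implementation would load the weights here; this sketch returns self.
        return self


vae = ToyVAE(input_height=32).from_pretrained("cifar10-resnet18")
```

With this shape, `ToyVAE(input_height=64).from_pretrained("cifar10-resnet18")` fails early with a clear mismatch message instead of a shape error deep inside weight loading.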

@akihironitta
Contributor

@jeremyadamsfisher Would you be interested in submitting a PR?

@jeremyadamsfisher
Contributor Author

jeremyadamsfisher commented Feb 6, 2021

Yeah, no problem

@Borda closed this as completed in #557 on Mar 6, 2021