[performance] module init w/ from_pretrained: skip storage allocation #12274
Labels: WIP
🚀 Feature request
pt-1.9.0 added `torch.nn.utils.skip_init()`, which (1) skips the module init and (2) doesn't allocate any memory: https://pytorch.org/tutorials/prototype/skip_param_init.html
Note: `torch.nn.utils.skip_init()` itself will be in 1.9.1, but the rest of the code should be in 1.9.0 (update: as 1.9.1 isn't planned, probably s/1.9.1/1.10/).

We already implemented part 1 (skipping the custom init) in #11471.
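For context, a minimal sketch of what `skip_init()` provides, based on the tutorial linked above (the `nn.Linear` example is just an illustration, not transformers code):

```python
from torch import nn

# skip_init() constructs the module on the meta device, so the constructor never
# touches real storage, then re-creates empty (uninitialized) parameters on the
# requested device -- the usual reset_parameters() init is skipped entirely.
m = nn.utils.skip_init(nn.Linear, 10, 5)
print(m.weight)  # uninitialized values; only safe to use after loading real weights

# roughly equivalent to:
#   m = nn.Linear(10, 5, device="meta").to_empty(device="cpu")
```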
We could further speed up the start-up time and reduce CPU memory usage by not allocating any storage during module init, since `load_state_dict` will already have allocated the `state_dict` from the pretrained weights (the sub-modules that don't have pre-trained weights will still have to go through normal init). See https://pytorch.org/tutorials/prototype/skip_param_init.html#implementation-details
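A rough sketch of the idea, following the implementation-details section of that tutorial (the tiny `nn.Sequential` model and the stand-in checkpoint are placeholders, not the actual `from_pretrained` code path):

```python
from torch import nn

def build_model():
    # requires modules to accept the `device` factory kwarg (available since pt-1.9)
    return nn.Sequential(
        nn.Linear(8, 8, device="meta"),
        nn.Linear(8, 2, device="meta"),
    )

# 1. Construct on the meta device: no storage is allocated and no init is computed.
model = build_model()
assert model[0].weight.is_meta

# 2. Materialize empty (uninitialized) storage only right before loading weights.
model.to_empty(device="cpu")

# 3. load_state_dict() copies the pretrained weights into that storage; any
#    sub-module not covered by the checkpoint still needs a normal init pass.
checkpoint = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2)).state_dict()  # stand-in
model.load_state_dict(checkpoint)
```

This way the only large CPU allocation before loading is the checkpoint's own `state_dict`.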
Another note: currently deepspeed needs to have the module storage pre-allocated for its `zero.Init` gather/scatter, but if the initial model's weights aren't allocated, then we can probably get rid of `zero.Init` altogether: #12273