Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable flexible emb init #220

Merged
merged 6 commits into from
Mar 10, 2023
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion examples/llm/src/models/param_init_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import math
import warnings
from collections.abc import Sequence
from functools import partial

import torch
Expand Down Expand Up @@ -68,7 +69,38 @@ def generic_param_init_fn_(module, cfg, init_fn_):

elif isinstance(module, nn.Embedding):
# Embedding
init_fn_(module.weight)
if cfg.get('emb_init_std') is not None:
std = cfg.get('emb_init_std')
if std == 0:
warnings.warn(f'Embedding layer initialized to 0.')
emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
if cfg.get('verbose', 0) > 1:
warnings.warn(
f'Embedding layer initialized using normal distribution with mean=0 and {std=}.'
)
elif cfg.get('emb_init_uniform_lim') is not None:
lim = cfg.get('emb_init_uniform_lim')
if isinstance(lim, Sequence):
if len(lim) > 2:
raise ValueError(
f'Uniform init requires a min and a max limit. User input: {lim}.'
)
if lim[0] == lim[1]:
warnings.warn(f'Embedding layer initialized to {lim[0]}.')
else:
if lim == 0:
warnings.warn(f'Embedding layer initialized to 0.')
lim = [-lim, lim]
a, b = lim
emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
if cfg.get('verbose', 0) > 1:
warnings.warn(
f'Embedding layer initialized using uniform distribution in range {lim}.'
)
else:
emb_init_fn_ = init_fn_

emb_init_fn_(module.weight)

elif isinstance(module, nn.LayerNorm):
# LayerNorm
Expand Down