GLU Fixes #1564

Merged: 4 commits, Sep 28, 2022
composer/algorithms/gated_linear_units/gated_linear_units.py (34 changes: 23 additions & 11 deletions)
@@ -88,18 +88,31 @@ def apply_gated_linear_units(model: torch.nn.Module,
         raise TypeError('Gated Linear Units only has a surgery policy defined for instances of BERT models.')
 
     if act_fn is None:
-        # get the activation functions used
-        act_fns = {module.intermediate_act_fn for module in model.modules() if isinstance(module, BertIntermediate)}
-
-        if len(act_fns) != 1:
-            raise ValueError('The model has non-uniform activation functions, which is currently unsupported.')
-        # since our set is of length-1, let's extract the only activation function remaining.
+        intermediate_modules = {module for module in model.modules() if isinstance(module, BertIntermediate)}
+        if len(intermediate_modules) == 0:
+            warnings.warn(
+                NoEffectWarning('No instances of BertIntermediate were found so Gated Linear Units will be skipped '
+                                'as no modules can be replaced. This is likely because Gated Linear Units has already '
+                                'been applied to this model.'))
+            return
+
+        # get the activation functions used
+        act_fns = {module.intermediate_act_fn for module in intermediate_modules}
+        if len(act_fns) == 0:
+            raise ValueError('Tried to get the activation function from the model, but none were found. '
+                             'Please specify `act_fn` manually to use Gated Linear Units.')
+        elif len(act_fns) > 1:
+            raise ValueError('Tried to get the activation function from the model, but multiple different '
+                             'functions are used. This is currently unsupported with Gated Linear Units. '
+                             'Please either use one activation function in BertIntermediate modules or '
+                             'specify `act_fn` to manually override activation functions.')
+
+        # since our set is of 1, let's extract the only activation function remaining.
         (act_fn,) = act_fns
 
-        if act_fn is None:
-            raise ValueError(
-                'Could not find an existing activation function to use, and no custom activation function was provided.')
+        if act_fn is None:
+            raise ValueError(
+                'Found activation function was None. If this is an error, please manually specify `act_fn`.')
 
     # now that we know the act fn, bind a few parameters of the replacement function
     def from_bound_BertOutput(layer: torch.nn.Module, module_index: int) -> BERTGatedFFOutput:
@@ -117,8 +130,7 @@ def from_bound_BertOutput(layer: torch.nn.Module, module_index: int) -> BERTGatedFFOutput:
     replaced_instances = module_surgery.replace_module_classes(module=model, optimizers=optimizers, policies=policy)
     if len(replaced_instances) == 0:
         warnings.warn(
-            NoEffectWarning(
-                'No instances of `torch.nn.LayerNorm` were found, and therefore, there were no modules to replace.'))
+            NoEffectWarning('No instances of BertIntermediate and BertOutput were found so no modules were replaced.'))
     log.info(
         f'Successfully replaced {len(replaced_instances)} of BertIntermediate and BertOutput with a GatedLinearUnit.')

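For context (not part of the PR): below is a minimal sketch of how the patched code path could be exercised end to end. It assumes the algorithm is exposed through Composer's functional API as composer.functional.apply_gated_linear_units and that a Hugging Face BERT model (transformers' BertForMaskedLM) satisfies the model-type check above this hunk; those names come from the surrounding packages, not from this diff, so treat the snippet as illustrative rather than authoritative.

# Illustrative sketch only; assumes composer and transformers are installed and
# that the functional entry point matches the signature shown in this file's diff.
import torch
import composer.functional as cf
from transformers import BertConfig, BertForMaskedLM

model = BertForMaskedLM(BertConfig())  # contains BertIntermediate / BertOutput modules
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# First application: BertIntermediate modules are found, the single activation
# function is inferred from them, and the modules are replaced with gated units.
cf.apply_gated_linear_units(model=model, optimizers=optimizer)

# Second application: after this PR, no BertIntermediate modules remain, so the
# call emits a NoEffectWarning and returns early instead of raising the old
# 'non-uniform activation functions' ValueError.
cf.apply_gated_linear_units(model=model, optimizers=optimizer)

The second call is the behavior this PR changes: previously an empty activation-function set fell through the len(act_fns) != 1 check and surfaced as a confusing ValueError, whereas now the no-op case warns and returns, and the 0-vs-many cases get distinct error messages.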