diff --git a/composer/algorithms/gated_linear_units/gated_linear_units.py b/composer/algorithms/gated_linear_units/gated_linear_units.py
index 502175e689..fc930b9ada 100644
--- a/composer/algorithms/gated_linear_units/gated_linear_units.py
+++ b/composer/algorithms/gated_linear_units/gated_linear_units.py
@@ -88,18 +88,31 @@ def apply_gated_linear_units(model: torch.nn.Module,
         raise TypeError('Gated Linear Units only has a surgery policy defined for instances of BERT models.')
 
     if act_fn is None:
-        # get the activation functions used
-        act_fns = {module.intermediate_act_fn for module in model.modules() if isinstance(module, BertIntermediate)}
-
-        if len(act_fns) != 1:
-            raise ValueError('The model has non-uniform activation functions, which is currently unsupported.')
+        intermediate_modules = {module for module in model.modules() if isinstance(module, BertIntermediate)}
+        if len(intermediate_modules) == 0:
+            warnings.warn(
+                NoEffectWarning('No instances of BertIntermediate were found so Gated Linear Units will be skipped '
+                                'as no modules can be replaced. This is likely because Gated Linear Units has already '
+                                'been applied to this model.'))
+            return
 
-        # since our set is of length-1, let's extract the only activation function remaining.
+        # get the activation functions used
+        act_fns = {module.intermediate_act_fn for module in intermediate_modules}
+        if len(act_fns) == 0:
+            raise ValueError('Tried to get the activation function from the model, but none were found. '
+                             'Please specify `act_fn` manually to use Gated Linear Units.')
+        elif len(act_fns) > 1:
+            raise ValueError('Tried to get the activation function from the model, but multiple different '
+                             'functions are used. This is currently unsupported with Gated Linear Units. '
+                             'Please either use one activation function in BertIntermediate modules or '
+                             'specify `act_fn` to manually override activation functions.')
+
+        # since our set is of 1, let's extract the only activation function remaining.
         (act_fn,) = act_fns
 
-    if act_fn is None:
-        raise ValueError(
-            'Could not find an existing activation function to use, and no custom activation function was provided.')
+        if act_fn is None:
+            raise ValueError(
+                'Found activation function was None. If this is an error, please manually specify `act_fn`.')
 
     # now that we know the act fn, bind a few parameters of the replacement function
     def from_bound_BertOutput(layer: torch.nn.Module, module_index: int) -> BERTGatedFFOutput:
@@ -117,8 +130,7 @@ def from_bound_BertOutput(layer: torch.nn.Module, module_index: int) -> BERTGate
     replaced_instances = module_surgery.replace_module_classes(module=model, optimizers=optimizers, policies=policy)
     if len(replaced_instances) == 0:
         warnings.warn(
-            NoEffectWarning(
-                'No instances of `torch.nn.LayerNorm` were found, and therefore, there were no modules to replace.'))
+            NoEffectWarning('No instances of BertIntermediate and BertOutput were found so no modules were replaced.'))
     log.info(
         f'Successfully replaced {len(replaced_instances)} of BertIntermediate and BertOutput with a GatedLinearUnit.')
 