diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index 8b8b954ba7..42fdcb299d 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -43,17 +43,16 @@ class FlashBaseFinetuning(BaseFinetuning): def __init__( self, - strategy_key: FinetuningStrategies, + strategy_key: Union[str, FinetuningStrategies], strategy_metadata: Optional[Union[int, Tuple[Tuple[int, int], int]]] = None, train_bn: bool = True, ): """ Args: strategy_key: The finetuning strategy to be used. See :meth:`~flash.core.trainer.Trainer.finetune` - for the available strategies. + for the available strategies. strategy_metadata: Data that accompanies certain finetuning strategies like epoch number or number of - layers. - attr_names: Name(s) of the module attributes of the model to be frozen. + layers. train_bn: Whether to train Batch Norm layer """ super().__init__() @@ -62,11 +61,11 @@ def __init__( self.strategy_metadata: Optional[Union[int, Tuple[Tuple[int, int], int]]] = strategy_metadata self.train_bn: bool = train_bn - if self.strategy == "freeze_unfreeze" and not isinstance(self.strategy_metadata, int): + if self.strategy == FinetuningStrategies.FREEZE_UNFREEZE and not isinstance(self.strategy_metadata, int): raise MisconfigurationException( "`freeze_unfreeze` stratgey only accepts one integer denoting the epoch number to switch." ) - if self.strategy == "unfreeze_milestones" and not ( + if self.strategy == FinetuningStrategies.UNFREEZE_MILESTONES and not ( isinstance(self.strategy_metadata, Tuple) and isinstance(self.strategy_metadata[0], Tuple) and isinstance(self.strategy_metadata[1], int) @@ -161,17 +160,17 @@ def finetune_function( # Used for properly verifying input and providing neat and helpful error messages for users. _DEFAULTS_FINETUNE_STRATEGIES = [ - "no_freeze", - "freeze", - "freeze_unfreeze", - "unfreeze_milestones", + FinetuningStrategies.NO_FREEZE.value, + FinetuningStrategies.FREEZE.value, + FinetuningStrategies.FREEZE_UNFREEZE.value, + FinetuningStrategies.UNFREEZE_MILESTONES.value, ] _FINETUNING_STRATEGIES_REGISTRY = FlashRegistry("finetuning_strategies") -for strategy in FinetuningStrategies: +for strategy in _DEFAULTS_FINETUNE_STRATEGIES: _FINETUNING_STRATEGIES_REGISTRY( - name=strategy.value, + name=strategy, fn=partial(FlashBaseFinetuning, strategy_key=strategy), )