Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Clean-up to fix rtd build (#1259)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Mar 30, 2022
1 parent 1a40d3d commit 37e1a3c
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions flash/core/finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,16 @@ class FlashBaseFinetuning(BaseFinetuning):

def __init__(
self,
strategy_key: FinetuningStrategies,
strategy_key: Union[str, FinetuningStrategies],
strategy_metadata: Optional[Union[int, Tuple[Tuple[int, int], int]]] = None,
train_bn: bool = True,
):
"""
Args:
strategy_key: The finetuning strategy to be used. See :meth:`~flash.core.trainer.Trainer.finetune`
for the available strategies.
for the available strategies.
strategy_metadata: Data that accompanies certain finetuning strategies like epoch number or number of
layers.
attr_names: Name(s) of the module attributes of the model to be frozen.
layers.
train_bn: Whether to train Batch Norm layer
"""
super().__init__()
Expand All @@ -62,11 +61,11 @@ def __init__(
self.strategy_metadata: Optional[Union[int, Tuple[Tuple[int, int], int]]] = strategy_metadata
self.train_bn: bool = train_bn

if self.strategy == "freeze_unfreeze" and not isinstance(self.strategy_metadata, int):
if self.strategy == FinetuningStrategies.FREEZE_UNFREEZE and not isinstance(self.strategy_metadata, int):
raise MisconfigurationException(
"`freeze_unfreeze` stratgey only accepts one integer denoting the epoch number to switch."
)
if self.strategy == "unfreeze_milestones" and not (
if self.strategy == FinetuningStrategies.UNFREEZE_MILESTONES and not (
isinstance(self.strategy_metadata, Tuple)
and isinstance(self.strategy_metadata[0], Tuple)
and isinstance(self.strategy_metadata[1], int)
Expand Down Expand Up @@ -161,17 +160,17 @@ def finetune_function(

# Used for properly verifying input and providing neat and helpful error messages for users.
_DEFAULTS_FINETUNE_STRATEGIES = [
"no_freeze",
"freeze",
"freeze_unfreeze",
"unfreeze_milestones",
FinetuningStrategies.NO_FREEZE.value,
FinetuningStrategies.FREEZE.value,
FinetuningStrategies.FREEZE_UNFREEZE.value,
FinetuningStrategies.UNFREEZE_MILESTONES.value,
]

_FINETUNING_STRATEGIES_REGISTRY = FlashRegistry("finetuning_strategies")

for strategy in FinetuningStrategies:
for strategy in _DEFAULTS_FINETUNE_STRATEGIES:
_FINETUNING_STRATEGIES_REGISTRY(
name=strategy.value,
name=strategy,
fn=partial(FlashBaseFinetuning, strategy_key=strategy),
)

Expand Down

0 comments on commit 37e1a3c

Please sign in to comment.