Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Clean-up to fix rtd build #1259

Merged
merged 1 commit into from
Mar 28, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions flash/core/finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,16 @@ class FlashBaseFinetuning(BaseFinetuning):

def __init__(
self,
strategy_key: FinetuningStrategies,
strategy_key: Union[str, FinetuningStrategies],
strategy_metadata: Optional[Union[int, Tuple[Tuple[int, int], int]]] = None,
train_bn: bool = True,
):
"""
Args:
strategy_key: The finetuning strategy to be used. See :meth:`~flash.core.trainer.Trainer.finetune`
for the available strategies.
for the available strategies.
strategy_metadata: Data that accompanies certain finetuning strategies like epoch number or number of
layers.
attr_names: Name(s) of the module attributes of the model to be frozen.
layers.
train_bn: Whether to train Batch Norm layer
"""
super().__init__()
Expand All @@ -62,11 +61,11 @@ def __init__(
self.strategy_metadata: Optional[Union[int, Tuple[Tuple[int, int], int]]] = strategy_metadata
self.train_bn: bool = train_bn

if self.strategy == "freeze_unfreeze" and not isinstance(self.strategy_metadata, int):
if self.strategy == FinetuningStrategies.FREEZE_UNFREEZE and not isinstance(self.strategy_metadata, int):
raise MisconfigurationException(
"`freeze_unfreeze` stratgey only accepts one integer denoting the epoch number to switch."
)
if self.strategy == "unfreeze_milestones" and not (
if self.strategy == FinetuningStrategies.UNFREEZE_MILESTONES and not (
isinstance(self.strategy_metadata, Tuple)
and isinstance(self.strategy_metadata[0], Tuple)
and isinstance(self.strategy_metadata[1], int)
Expand Down Expand Up @@ -161,17 +160,17 @@ def finetune_function(

# Used for properly verifying input and providing neat and helpful error messages for users.
_DEFAULTS_FINETUNE_STRATEGIES = [
"no_freeze",
"freeze",
"freeze_unfreeze",
"unfreeze_milestones",
FinetuningStrategies.NO_FREEZE.value,
FinetuningStrategies.FREEZE.value,
FinetuningStrategies.FREEZE_UNFREEZE.value,
FinetuningStrategies.UNFREEZE_MILESTONES.value,
]

_FINETUNING_STRATEGIES_REGISTRY = FlashRegistry("finetuning_strategies")

for strategy in FinetuningStrategies:
for strategy in _DEFAULTS_FINETUNE_STRATEGIES:
_FINETUNING_STRATEGIES_REGISTRY(
name=strategy.value,
name=strategy,
fn=partial(FlashBaseFinetuning, strategy_key=strategy),
)

Expand Down