Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton committed Oct 12, 2021
1 parent 6aed8f3 commit 64e376d
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 3 deletions.
44 changes: 42 additions & 2 deletions flash/core/data/preprocess_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def from_transform(
return cls(running_stage, {PreprocessTransformPlacement.PER_SAMPLE_TRANSFORM: transform})

if isinstance(transform, tuple) or isinstance(transform, (LightningEnum, str)):
enum, transform_kwargs = cls._sanetize_registry_transform(transform, transform_registry)
enum, transform_kwargs = cls._sanitize_registry_transform(transform, transform_registry)
transform_cls = transform_registry.get(enum)
return transform_cls(running_stage, transform=None, **transform_kwargs)

Expand All @@ -154,6 +154,46 @@ def from_transform(

raise MisconfigurationException(f"The format for the transform isn't correct. Found {transform}")

@classmethod
def from_train_transform(
cls,
transform: TRANSFORM_TYPE,
transform_registry: Optional[FlashRegistry] = None,
) -> Optional["PreprocessTransform"]:
return cls.from_transform(
transform=transform, running_stage=RunningStage.TRAINING, transform_registry=transform_registry
)

@classmethod
def from_val_transform(
cls,
transform: TRANSFORM_TYPE,
transform_registry: Optional[FlashRegistry] = None,
) -> Optional["PreprocessTransform"]:
return cls.from_transform(
transform=transform, running_stage=RunningStage.VALIDATING, transform_registry=transform_registry
)

@classmethod
def from_test_transform(
cls,
transform: TRANSFORM_TYPE,
transform_registry: Optional[FlashRegistry] = None,
) -> Optional["PreprocessTransform"]:
return cls.from_transform(
transform=transform, running_stage=RunningStage.TESTING, transform_registry=transform_registry
)

@classmethod
def from_predict_transform(
cls,
transform: TRANSFORM_TYPE,
transform_registry: Optional[FlashRegistry] = None,
) -> Optional["PreprocessTransform"]:
return cls.from_transform(
transform=transform, running_stage=RunningStage.PREDICTING, transform_registry=transform_registry
)

def _resolve_transforms(self, running_stage: RunningStage) -> Optional[Dict[str, Callable]]:
from flash.core.data.data_pipeline import DataPipeline

Expand Down Expand Up @@ -206,7 +246,7 @@ def _get_transform(self, transform: Dict[str, Callable]) -> Callable:
return self._identity

@classmethod
def _sanetize_registry_transform(
def _sanitize_registry_transform(
cls, transform: Tuple[Union[LightningEnum, str], Any], transform_registry: Optional[FlashRegistry]
) -> Tuple[Union[LightningEnum, str], Dict]:
msg = "The transform should be provided as a tuple with the following types (LightningEnum, Dict[str, Any]) "
Expand Down
11 changes: 10 additions & 1 deletion tests/core/data/test_preprocess_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,16 @@ def test_preprocess_transform():
def fn(x):
return x + 1

transform = PreprocessTransform.from_transform(running_stage=RunningStage.TRAINING, transform=fn)
transform = PreprocessTransform.from_train_transform(transform=fn)
assert transform.transform == {PreprocessTransformPlacement.PER_SAMPLE_TRANSFORM: fn}

transform = PreprocessTransform.from_val_transform(transform=fn)
assert transform.transform == {PreprocessTransformPlacement.PER_SAMPLE_TRANSFORM: fn}

transform = PreprocessTransform.from_test_transform(transform=fn)
assert transform.transform == {PreprocessTransformPlacement.PER_SAMPLE_TRANSFORM: fn}

transform = PreprocessTransform.from_predict_transform(transform=fn)
assert transform.transform == {PreprocessTransformPlacement.PER_SAMPLE_TRANSFORM: fn}

class MyPreprocessTransform(PreprocessTransform):
Expand Down

0 comments on commit 64e376d

Please sign in to comment.