
Commit 4f6fe93
added Normalization transformation to per_sample transforms (#1399)
Co-authored-by: Nicola Occelli <nocc0001@hpda.ulb.ac.be>
Co-authored-by: Ethan Harris <ethanwharris@gmail.com>
3 people authored Jul 20, 2022
1 parent 04f1b8f commit 4f6fe93
Showing 2 changed files with 32 additions and 10 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -30,6 +30,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added support to use any task as an embedder by calling `as_embedder` ([#1396](https://github.com/PyTorchLightning/lightning-flash/pull/1396))

+- Added support for normalization of images in `SemanticSegmentationData` ([#1399](https://github.com/PyTorchLightning/lightning-flash/pull/1399))
+
 ### Changed

 - Changed the `ImageEmbedder` dependency on VISSL to optional ([#1276](https://github.com/PyTorchLightning/lightning-flash/pull/1276))
40 changes: 30 additions & 10 deletions flash/image/segmentation/input_transform.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, Tuple
+from typing import Any, Callable, Dict, Tuple, Union

 from flash.core.data.io.input import DataKeys
 from flash.core.data.io.input_transform import InputTransform
@@ -44,23 +44,43 @@ def remove_extra_dimensions(batch: Dict[str, Any]):
 class SemanticSegmentationInputTransform(InputTransform):

     image_size: Tuple[int, int] = (128, 128)
+    mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406)
+    std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225)

     def train_per_sample_transform(self) -> Callable:
-        return ApplyToKeys(
-            [DataKeys.INPUT, DataKeys.TARGET],
-            KorniaParallelTransforms(
-                K.geometry.Resize(self.image_size, interpolation="nearest"), K.augmentation.RandomHorizontalFlip(p=0.5)
-            ),
+        return T.Compose(
+            [
+                ApplyToKeys(
+                    [DataKeys.INPUT, DataKeys.TARGET],
+                    KorniaParallelTransforms(
+                        K.geometry.Resize(self.image_size, interpolation="nearest"),
+                        K.augmentation.RandomHorizontalFlip(p=0.5),
+                    ),
+                ),
+                ApplyToKeys([DataKeys.INPUT], K.augmentation.Normalize(mean=self.mean, std=self.std)),
+            ]
         )

     def per_sample_transform(self) -> Callable:
-        return ApplyToKeys(
-            [DataKeys.INPUT, DataKeys.TARGET],
-            KorniaParallelTransforms(K.geometry.Resize(self.image_size, interpolation="nearest")),
+        return T.Compose(
+            [
+                ApplyToKeys(
+                    [DataKeys.INPUT, DataKeys.TARGET],
+                    KorniaParallelTransforms(K.geometry.Resize(self.image_size, interpolation="nearest")),
+                ),
+                ApplyToKeys([DataKeys.INPUT], K.augmentation.Normalize(mean=self.mean, std=self.std)),
+            ]
         )

     def predict_per_sample_transform(self) -> Callable:
-        return ApplyToKeys(DataKeys.INPUT, K.geometry.Resize(self.image_size, interpolation="nearest"))
+        return ApplyToKeys(
+            DataKeys.INPUT,
+            K.geometry.Resize(
+                self.image_size,
+                interpolation="nearest",
+            ),
+            K.augmentation.Normalize(mean=self.mean, std=self.std),
+        )

     def collate(self) -> Callable:
         return kornia_collate
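Below is a rough usage sketch (not part of this commit) showing how the new `mean` and `std` fields could be supplied when building a datamodule. The folder paths, the class count, and the `train_transform` keyword on `from_folders` are illustrative assumptions; the exact argument names may differ between Flash releases.

    from flash.image import SemanticSegmentationData
    from flash.image.segmentation.input_transform import SemanticSegmentationInputTransform

    # Custom normalization statistics; the defaults added by this commit are the ImageNet values.
    transform = SemanticSegmentationInputTransform(
        image_size=(128, 128),
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    )

    datamodule = SemanticSegmentationData.from_folders(
        train_folder="data/images",         # hypothetical paths
        train_target_folder="data/masks",
        num_classes=21,                     # hypothetical class count
        train_transform=transform,          # assumed keyword; check the docs for this release
        batch_size=4,
    )

Note that the Normalize step is applied only to DataKeys.INPUT, so the integer segmentation mask under DataKeys.TARGET is resized but never normalized.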
