diff --git a/CHANGELOG.md b/CHANGELOG.md
index 44061b8a9f..e8980083ba 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -30,6 +30,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support to use any task as an embedder by calling `as_embedder` ([#1396](https://github.com/PyTorchLightning/lightning-flash/pull/1396))
 
+- Added support for normalization of images in `SemanticSegmentationData` ([#1399](https://github.com/PyTorchLightning/lightning-flash/pull/1399))
+
 ### Changed
 
 - Changed the `ImageEmbedder` dependency on VISSL to optional ([#1276](https://github.com/PyTorchLightning/lightning-flash/pull/1276))
diff --git a/flash/image/segmentation/input_transform.py b/flash/image/segmentation/input_transform.py
index 944e368f7a..e36d6a2c49 100644
--- a/flash/image/segmentation/input_transform.py
+++ b/flash/image/segmentation/input_transform.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, Tuple
+from typing import Any, Callable, Dict, Tuple, Union
 
 from flash.core.data.io.input import DataKeys
 from flash.core.data.io.input_transform import InputTransform
@@ -44,23 +44,43 @@ def remove_extra_dimensions(batch: Dict[str, Any]):
 class SemanticSegmentationInputTransform(InputTransform):
 
     image_size: Tuple[int, int] = (128, 128)
+    mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406)
+    std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225)
 
     def train_per_sample_transform(self) -> Callable:
-        return ApplyToKeys(
-            [DataKeys.INPUT, DataKeys.TARGET],
-            KorniaParallelTransforms(
-                K.geometry.Resize(self.image_size, interpolation="nearest"), K.augmentation.RandomHorizontalFlip(p=0.5)
-            ),
+        return T.Compose(
+            [
+                ApplyToKeys(
+                    [DataKeys.INPUT, DataKeys.TARGET],
+                    KorniaParallelTransforms(
+                        K.geometry.Resize(self.image_size, interpolation="nearest"),
+                        K.augmentation.RandomHorizontalFlip(p=0.5),
+                    ),
+                ),
+                ApplyToKeys([DataKeys.INPUT], K.augmentation.Normalize(mean=self.mean, std=self.std)),
+            ]
         )
 
     def per_sample_transform(self) -> Callable:
-        return ApplyToKeys(
-            [DataKeys.INPUT, DataKeys.TARGET],
-            KorniaParallelTransforms(K.geometry.Resize(self.image_size, interpolation="nearest")),
+        return T.Compose(
+            [
+                ApplyToKeys(
+                    [DataKeys.INPUT, DataKeys.TARGET],
+                    KorniaParallelTransforms(K.geometry.Resize(self.image_size, interpolation="nearest")),
+                ),
+                ApplyToKeys([DataKeys.INPUT], K.augmentation.Normalize(mean=self.mean, std=self.std)),
+            ]
         )
 
     def predict_per_sample_transform(self) -> Callable:
-        return ApplyToKeys(DataKeys.INPUT, K.geometry.Resize(self.image_size, interpolation="nearest"))
+        return ApplyToKeys(
+            DataKeys.INPUT,
+            K.geometry.Resize(
+                self.image_size,
+                interpolation="nearest",
+            ),
+            K.augmentation.Normalize(mean=self.mean, std=self.std),
+        )
 
     def collate(self) -> Callable:
         return kornia_collate