This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Added Normalization transformation to per_sample transforms #1399

Merged

CHANGELOG.md (2 changes: 2 additions & 0 deletions)
@@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for loading `ObjectDetectionData` with `from_numpy`, `from_images`, and `from_tensors` ([#1372](https://github.com/PyTorchLightning/lightning-flash/pull/1372))
 
+- Added support for normalization of images in `SemanticSegmentationData` ([#1399](https://github.com/PyTorchLightning/lightning-flash/pull/1399))
+
 ### Changed
 
 - Changed the `ImageEmbedder` dependency on VISSL to optional ([#1276](https://github.com/PyTorchLightning/lightning-flash/pull/1276))
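
The new behaviour is exposed as two dataclass fields on `SemanticSegmentationInputTransform` (see the diff below). As a rough usage sketch, not taken from this PR, overriding the ImageNet defaults would look something like the following; exact constructor arguments may vary across Flash versions:

```python
# Hypothetical usage sketch: override the new normalization defaults.
# The keyword arguments mirror the dataclass fields added in this PR;
# how the transform is then handed to SemanticSegmentationData depends
# on the Flash version in use.
from flash.image.segmentation.input_transform import SemanticSegmentationInputTransform

transform = SemanticSegmentationInputTransform(
    image_size=(256, 256),
    mean=(0.5, 0.5, 0.5),  # per-channel means, or a single float
    std=(0.5, 0.5, 0.5),   # per-channel stds, or a single float
)
```
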
flash/image/segmentation/input_transform.py (40 changes: 30 additions & 10 deletions)
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, Tuple
+from typing import Any, Callable, Dict, Tuple, Union
 
 from flash.core.data.io.input import DataKeys
 from flash.core.data.io.input_transform import InputTransform
@@ -44,23 +44,43 @@ def remove_extra_dimensions(batch: Dict[str, Any]):
 class SemanticSegmentationInputTransform(InputTransform):
 
     image_size: Tuple[int, int] = (128, 128)
+    mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406)
+    std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225)
 
     def train_per_sample_transform(self) -> Callable:
-        return ApplyToKeys(
-            [DataKeys.INPUT, DataKeys.TARGET],
-            KorniaParallelTransforms(
-                K.geometry.Resize(self.image_size, interpolation="nearest"), K.augmentation.RandomHorizontalFlip(p=0.5)
-            ),
+        return T.Compose(
+            [
+                ApplyToKeys(
+                    [DataKeys.INPUT, DataKeys.TARGET],
+                    KorniaParallelTransforms(
+                        K.geometry.Resize(self.image_size, interpolation="nearest"),
+                        K.augmentation.RandomHorizontalFlip(p=0.5),
+                    ),
+                ),
+                ApplyToKeys([DataKeys.INPUT], K.augmentation.Normalize(mean=self.mean, std=self.std)),
+            ]
         )
 
     def per_sample_transform(self) -> Callable:
-        return ApplyToKeys(
-            [DataKeys.INPUT, DataKeys.TARGET],
-            KorniaParallelTransforms(K.geometry.Resize(self.image_size, interpolation="nearest")),
+        return T.Compose(
+            [
+                ApplyToKeys(
+                    [DataKeys.INPUT, DataKeys.TARGET],
+                    KorniaParallelTransforms(K.geometry.Resize(self.image_size, interpolation="nearest")),
+                ),
+                ApplyToKeys([DataKeys.INPUT], K.augmentation.Normalize(mean=self.mean, std=self.std)),
+            ]
         )
 
     def predict_per_sample_transform(self) -> Callable:
-        return ApplyToKeys(DataKeys.INPUT, K.geometry.Resize(self.image_size, interpolation="nearest"))
+        return ApplyToKeys(
+            DataKeys.INPUT,
+            K.geometry.Resize(
+                self.image_size,
+                interpolation="nearest",
+            ),
+            K.augmentation.Normalize(mean=self.mean, std=self.std),
+        )
 
     def collate(self) -> Callable:
         return kornia_collate
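
To see what the new evaluation-time pipeline does to a sample without the Flash plumbing, here is a minimal sketch (not part of the diff) that wires up the same Kornia ops by hand; only `torch` and `kornia` are assumed, and plain tensors stand in for the `DataKeys`-keyed sample dict:

```python
# Minimal sketch (not part of the PR): the pipeline per_sample_transform()
# now builds -- resize image and mask with nearest-neighbour interpolation,
# then normalize the image only.
import kornia as K
import torch

mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)  # ImageNet defaults

resize = K.geometry.Resize((128, 128), interpolation="nearest")
normalize = K.augmentation.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))

image = torch.rand(1, 3, 160, 160)                     # DataKeys.INPUT, values in [0, 1]
mask = torch.randint(0, 21, (1, 1, 160, 160)).float()  # DataKeys.TARGET

# KorniaParallelTransforms applies each op with shared parameters to both
# tensors; for the deterministic resize that is just calling it on each.
image, mask = resize(image), resize(mask)

# The new Normalize step touches only the image, never the mask.
image = normalize(image)

print(image.shape)  # torch.Size([1, 3, 128, 128]), channel-normalized
```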