From f48fe2f4338fadb978e9e3fcb06d6bcb4b28bbe6 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 22 Nov 2021 19:21:43 +0000 Subject: [PATCH] Support for new icevision version (0.11.0) (#989) --- CHANGELOG.md | 2 + .../core/integrations/icevision/transforms.py | 62 ++++++++++++++----- flash/core/utilities/imports.py | 1 + 3 files changed, 50 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d0069ed32..d26cdcd97c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where using image classification with DDP spawn would trigger an infinite recursion ([#969](https://github.com/PyTorchLightning/lightning-flash/pull/969)) +- Fixed a bug where Flash could not be used with IceVision 0.11.0 ([#989](https://github.com/PyTorchLightning/lightning-flash/pull/989)) + ### Removed - Removed `OutputMapping` ([#939](https://github.com/PyTorchLightning/lightning-flash/pull/939)) diff --git a/flash/core/integrations/icevision/transforms.py b/flash/core/integrations/icevision/transforms.py index e254ea4e80..8bc15745cc 100644 --- a/flash/core/integrations/icevision/transforms.py +++ b/flash/core/integrations/icevision/transforms.py @@ -16,7 +16,7 @@ from torch import nn from flash.core.data.io.input import DataKeys -from flash.core.utilities.imports import _ICEVISION_AVAILABLE, requires +from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _ICEVISION_GREATER_EQUAL_0_11_0, requires if _ICEVISION_AVAILABLE: from icevision.core import tasks @@ -31,12 +31,17 @@ ImageRecordComponent, InstancesLabelsRecordComponent, KeyPointsRecordComponent, - MasksRecordComponent, RecordIDRecordComponent, ) from icevision.data.prediction import Prediction from icevision.tfms import A +if _ICEVISION_AVAILABLE and _ICEVISION_GREATER_EQUAL_0_11_0: + from icevision.core.mask import MaskFile + from icevision.core.record_components import InstanceMasksRecordComponent +elif _ICEVISION_AVAILABLE: + from icevision.core.record_components import MasksRecordComponent + def to_icevision_record(sample: Dict[str, Any]): record = BaseRecord([]) @@ -65,11 +70,28 @@ def to_icevision_record(sample: Dict[str, Any]): component.set_bboxes(bboxes) record.add_component(component) - if "masks" in sample[DataKeys.TARGET]: - mask_array = MaskArray(sample[DataKeys.TARGET]["masks"]) - component = MasksRecordComponent() - component.set_masks(mask_array) - record.add_component(component) + if _ICEVISION_GREATER_EQUAL_0_11_0: + mask_array = sample[DataKeys.TARGET].get("mask_array", None) + masks = sample[DataKeys.TARGET].get("masks", None) + + if mask_array is not None or masks is not None: + component = InstanceMasksRecordComponent() + + if masks is not None: + masks = [MaskFile(mask) for mask in masks] + component.set_masks(masks) + + if mask_array is not None: + mask_array = MaskArray(mask_array) + component.set_mask_array(mask_array) + + record.add_component(component) + else: + mask_array = sample[DataKeys.TARGET].get("mask_array", None) + if mask_array is not None: + component = MasksRecordComponent() + component.set_masks(mask_array) + record.add_component(component) if "keypoints" in sample[DataKeys.TARGET]: keypoints = [] @@ -118,16 +140,26 @@ def from_icevision_detection(record: "BaseRecord"): for bbox in detection.bboxes ] - if hasattr(detection, "masks"): - masks = detection.masks - - if isinstance(masks, EncodedRLEs): - masks = masks.to_mask(record.height, record.width) + mask_array = ( + getattr(detection, "mask_array", None) if _ICEVISION_GREATER_EQUAL_0_11_0 else getattr(detection, "masks", None) + ) + if mask_array is not None: + if isinstance(mask_array, EncodedRLEs): + mask_array = mask_array.to_mask(record.height, record.width) - if isinstance(masks, MaskArray): - result["masks"] = masks.data + if isinstance(mask_array, MaskArray): + result["mask_array"] = mask_array.data else: - raise RuntimeError("Masks are expected to be a MaskArray or EncodedRLEs.") + raise RuntimeError("Mask arrays are expected to be a MaskArray or EncodedRLEs.") + + masks = getattr(detection, "masks", None) + if masks is not None and _ICEVISION_GREATER_EQUAL_0_11_0: + result["masks"] = [] + for mask in masks: + if isinstance(mask, MaskFile): + result["masks"].append(mask.filepath) + else: + raise RuntimeError("Masks are expected to be MaskFile objects.") if hasattr(detection, "keypoints"): keypoints = detection.keypoints diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index 1bd1c5db83..6e643b16c9 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -121,6 +121,7 @@ class Image: _PL_GREATER_EQUAL_1_4_3 = _compare_version("pytorch_lightning", operator.ge, "1.4.3") _PL_GREATER_EQUAL_1_5_0 = _compare_version("pytorch_lightning", operator.ge, "1.5.0") _PANDAS_GREATER_EQUAL_1_3_0 = _compare_version("pandas", operator.ge, "1.3.0") + _ICEVISION_GREATER_EQUAL_0_11_0 = _compare_version("icevision", operator.ge, "0.11.0") _TEXT_AVAILABLE = all( [