From 768a85d6c67dd66cd0906f3f50dbd82669e504df Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 25 Oct 2021 19:21:11 +0100 Subject: [PATCH 01/11] Try fix --- flash/core/data/data_source.py | 2 +- flash/core/integrations/icevision/data.py | 27 ++----------------- .../core/integrations/icevision/transforms.py | 13 ++++----- flash/image/detection/data.py | 4 +-- flash_examples/object_detection.py | 2 +- 5 files changed, 11 insertions(+), 37 deletions(-) diff --git a/flash/core/data/data_source.py b/flash/core/data/data_source.py index cf279dcbba..88d4f2131c 100644 --- a/flash/core/data/data_source.py +++ b/flash/core/data/data_source.py @@ -71,7 +71,7 @@ def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bo Returns: bool: True if the filename ends with one of given extensions """ - return filename.lower().endswith(extensions) + return str(filename).lower().endswith(extensions) # Credit to the PyTorchVision Team: diff --git a/flash/core/integrations/icevision/data.py b/flash/core/integrations/icevision/data.py index 246ace7e13..c3014c208d 100644 --- a/flash/core/integrations/icevision/data.py +++ b/flash/core/integrations/icevision/data.py @@ -14,16 +14,12 @@ import inspect from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type -import numpy as np - -from flash.core.data.data_source import DefaultDataKeys, LabelsState +from flash.core.data.data_source import LabelsState from flash.core.integrations.icevision.transforms import from_icevision_record from flash.core.utilities.imports import _ICEVISION_AVAILABLE from flash.image.data import ImagePathsDataSource if _ICEVISION_AVAILABLE: - from icevision.core.record import BaseRecord - from icevision.core.record_components import ClassMapRecordComponent, FilepathRecordComponent, tasks from icevision.data.data_splitter import SingleSplitSplitter from icevision.parsers.parser import Parser @@ -32,25 +28,6 @@ class IceVisionPathsDataSource(ImagePathsDataSource): def predict_load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: return super().predict_load_data(data, dataset) - def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: - record = sample[DefaultDataKeys.INPUT].load() - return from_icevision_record(record) - - def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: - if isinstance(sample[DefaultDataKeys.INPUT], BaseRecord): - # load the data via IceVision Base Record - return self.load_sample(sample) - # load the data using numpy - filepath = sample[DefaultDataKeys.INPUT] - sample = super().load_sample(sample) - image = np.array(sample[DefaultDataKeys.INPUT]) - - record = BaseRecord([FilepathRecordComponent()]) - record.filepath = filepath - record.set_img(image) - record.add_component(ClassMapRecordComponent(task=tasks.detection)) - return from_icevision_record(record) - class IceVisionParserDataSource(IceVisionPathsDataSource): def __init__(self, parser: Optional[Type["Parser"]] = None): @@ -69,7 +46,7 @@ def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Seq dataset.num_classes = parser.class_map.num_classes self.set_state(LabelsState([parser.class_map.get_by_id(i) for i in range(dataset.num_classes)])) records = parser.parse(data_splitter=SingleSplitSplitter()) - return [{DefaultDataKeys.INPUT: record} for record in records[0]] + return [from_icevision_record(record) for record in records[0]] raise ValueError("The parser argument must be provided.") def predict_load_data(self, data: Any, dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: diff --git a/flash/core/integrations/icevision/transforms.py b/flash/core/integrations/icevision/transforms.py index 5619dfd5af..75b51f2eaf 100644 --- a/flash/core/integrations/icevision/transforms.py +++ b/flash/core/integrations/icevision/transforms.py @@ -51,12 +51,12 @@ def to_icevision_record(sample: Dict[str, Any]): component.set_class_map(metadata.get("class_map", None)) record.add_component(component) - if "labels" in sample[DefaultDataKeys.TARGET]: + if "labels" in sample.get(DefaultDataKeys.TARGET, {}): labels_component = InstancesLabelsRecordComponent() labels_component.add_labels_by_id(sample[DefaultDataKeys.TARGET]["labels"]) record.add_component(labels_component) - if "bboxes" in sample[DefaultDataKeys.TARGET]: + if "bboxes" in sample.get(DefaultDataKeys.TARGET, {}): bboxes = [ BBox.from_xywh(bbox["xmin"], bbox["ymin"], bbox["width"], bbox["height"]) for bbox in sample[DefaultDataKeys.TARGET]["bboxes"] @@ -65,13 +65,13 @@ def to_icevision_record(sample: Dict[str, Any]): component.set_bboxes(bboxes) record.add_component(component) - if "masks" in sample[DefaultDataKeys.TARGET]: + if "masks" in sample.get(DefaultDataKeys.TARGET, {}): mask_array = MaskArray(sample[DefaultDataKeys.TARGET]["masks"]) component = MasksRecordComponent() component.set_masks(mask_array) record.add_component(component) - if "keypoints" in sample[DefaultDataKeys.TARGET]: + if "keypoints" in sample.get(DefaultDataKeys.TARGET, {}): keypoints = [] for keypoints_list, keypoints_metadata in zip( @@ -174,10 +174,7 @@ def from_icevision_record(record: "BaseRecord"): if record.img is not None: sample[DefaultDataKeys.INPUT] = record.img - filepath = getattr(record, "filepath", None) - if filepath is not None: - sample[DefaultDataKeys.METADATA]["filepath"] = filepath - elif record.filepath is not None: + elif getattr(record, "filepath", None) is not None: sample[DefaultDataKeys.INPUT] = record.filepath sample[DefaultDataKeys.TARGET] = from_icevision_detection(record) diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py index 9a7e5c31fa..2418fa1ce8 100644 --- a/flash/image/detection/data.py +++ b/flash/image/detection/data.py @@ -18,7 +18,7 @@ from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources, FiftyOneDataSource from flash.core.data.process import Preprocess from flash.core.integrations.icevision.data import IceVisionParserDataSource, IceVisionPathsDataSource -from flash.core.integrations.icevision.transforms import default_transforms +from flash.core.integrations.icevision.transforms import default_transforms, from_icevision_record from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _ICEVISION_AVAILABLE, lazy_import, requires SampleCollection = None @@ -125,7 +125,7 @@ def load_data(self, data: SampleCollection, dataset: Optional[Any] = None) -> Se parser = FiftyOneParser(data, class_map, self.label_field, self.iscrowd) records = parser.parse(data_splitter=SingleSplitSplitter()) - return [{DefaultDataKeys.INPUT: record} for record in records[0]] + return [from_icevision_record(record) for record in records[0]] @staticmethod @requires("fiftyone") diff --git a/flash_examples/object_detection.py b/flash_examples/object_detection.py index 1a5dddbce9..8330492512 100644 --- a/flash_examples/object_detection.py +++ b/flash_examples/object_detection.py @@ -30,7 +30,7 @@ model = ObjectDetector(head="efficientdet", backbone="d0", num_classes=datamodule.num_classes, image_size=128) # 3. Create the trainer and finetune the model -trainer = flash.Trainer(max_epochs=1) +trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1) trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Detect objects in a few images! From 6e64b719619706a71f4deb4a90c24c4f732dbcba Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 25 Oct 2021 20:45:41 +0200 Subject: [PATCH 02/11] rc --- flash/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash/__about__.py b/flash/__about__.py index 9498434c6d..ba7d39163f 100644 --- a/flash/__about__.py +++ b/flash/__about__.py @@ -1,4 +1,4 @@ -__version__ = "0.5.1rc0" +__version__ = "0.5.1rc1" __author__ = "PyTorchLightning et al." __author_email__ = "name@pytorchlightning.ai" __license__ = "Apache-2.0" From 91a82471b61a056d86bd90a5a268b4a3fb8ef560 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 25 Oct 2021 19:47:42 +0100 Subject: [PATCH 03/11] Remove test code --- flash_examples/object_detection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_examples/object_detection.py b/flash_examples/object_detection.py index 8330492512..1a5dddbce9 100644 --- a/flash_examples/object_detection.py +++ b/flash_examples/object_detection.py @@ -30,7 +30,7 @@ model = ObjectDetector(head="efficientdet", backbone="d0", num_classes=datamodule.num_classes, image_size=128) # 3. Create the trainer and finetune the model -trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1) +trainer = flash.Trainer(max_epochs=1) trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 4. Detect objects in a few images! From c27e345cb5bc06e8be171a59db50086d7bcd7782 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 26 Oct 2021 11:18:11 +0100 Subject: [PATCH 04/11] Try something --- flash/core/integrations/icevision/data.py | 19 +++++++++++++++++-- .../core/integrations/icevision/transforms.py | 14 +++++++------- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/flash/core/integrations/icevision/data.py b/flash/core/integrations/icevision/data.py index c3014c208d..cb6e8968c5 100644 --- a/flash/core/integrations/icevision/data.py +++ b/flash/core/integrations/icevision/data.py @@ -15,25 +15,36 @@ from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type from flash.core.data.data_source import LabelsState -from flash.core.integrations.icevision.transforms import from_icevision_record +from flash.core.integrations.icevision.transforms import from_icevision_record, to_icevision_record from flash.core.utilities.imports import _ICEVISION_AVAILABLE from flash.image.data import ImagePathsDataSource if _ICEVISION_AVAILABLE: from icevision.data.data_splitter import SingleSplitSplitter from icevision.parsers.parser import Parser + from icevision.utils.imageio import ImgSize class IceVisionPathsDataSource(ImagePathsDataSource): def predict_load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: return super().predict_load_data(data, dataset) + def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: + sample = super().load_sample(sample) + record = to_icevision_record(sample) + record.autofix() + return from_icevision_record(record) + class IceVisionParserDataSource(IceVisionPathsDataSource): def __init__(self, parser: Optional[Type["Parser"]] = None): super().__init__() self.parser = parser + @staticmethod + def _mock_img_size(_) -> ImgSize: + return ImgSize(None, None) + def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: if self.parser is not None: if inspect.isclass(self.parser) and issubclass(self.parser, Parser): @@ -45,7 +56,11 @@ def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Seq raise ValueError("The parser must be a callable or an IceVision Parser type.") dataset.num_classes = parser.class_map.num_classes self.set_state(LabelsState([parser.class_map.get_by_id(i) for i in range(dataset.num_classes)])) - records = parser.parse(data_splitter=SingleSplitSplitter()) + + # Patch img_size to prevent image being loaded + parser.img_size = self._mock_img_size + + records = parser.parse(data_splitter=SingleSplitSplitter(), autofix=False) return [from_icevision_record(record) for record in records[0]] raise ValueError("The parser argument must be provided.") diff --git a/flash/core/integrations/icevision/transforms.py b/flash/core/integrations/icevision/transforms.py index 75b51f2eaf..0a9733e75b 100644 --- a/flash/core/integrations/icevision/transforms.py +++ b/flash/core/integrations/icevision/transforms.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path from typing import Any, Callable, Dict, List, Tuple from torch import nn @@ -92,7 +93,7 @@ def to_icevision_record(sample: Dict[str, Any]): else: if "filepath" in metadata: input_component = FilepathRecordComponent() - input_component.filepath = metadata["filepath"] + input_component.filepath = Path(metadata["filepath"]) else: input_component = ImageRecordComponent() input_component.composite = record @@ -160,11 +161,10 @@ def from_icevision_detection(record: "BaseRecord"): def from_icevision_record(record: "BaseRecord"): - sample = { - DefaultDataKeys.METADATA: { - "size": (record.height, record.width), - } - } + sample = {DefaultDataKeys.METADATA: {}} + + if getattr(record, "height", None) is not None and getattr(record, "width", None) is not None: + sample[DefaultDataKeys.METADATA]["size"] = (record.height, record.width) if getattr(record, "record_id", None) is not None: sample[DefaultDataKeys.METADATA]["image_id"] = record.record_id @@ -175,7 +175,7 @@ def from_icevision_record(record: "BaseRecord"): if record.img is not None: sample[DefaultDataKeys.INPUT] = record.img elif getattr(record, "filepath", None) is not None: - sample[DefaultDataKeys.INPUT] = record.filepath + sample[DefaultDataKeys.INPUT] = str(record.filepath) sample[DefaultDataKeys.TARGET] = from_icevision_detection(record) From a5270c440bc9dff30611ffdae249f628872864af Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 26 Oct 2021 12:32:48 +0100 Subject: [PATCH 05/11] Try something --- flash/core/integrations/icevision/data.py | 32 ++++++++++++++--------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/flash/core/integrations/icevision/data.py b/flash/core/integrations/icevision/data.py index cb6e8968c5..a721d805fa 100644 --- a/flash/core/integrations/icevision/data.py +++ b/flash/core/integrations/icevision/data.py @@ -14,15 +14,13 @@ import inspect from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type -from flash.core.data.data_source import LabelsState +from flash.core.data.data_source import DefaultDataKeys, LabelsState from flash.core.integrations.icevision.transforms import from_icevision_record, to_icevision_record from flash.core.utilities.imports import _ICEVISION_AVAILABLE from flash.image.data import ImagePathsDataSource if _ICEVISION_AVAILABLE: - from icevision.data.data_splitter import SingleSplitSplitter from icevision.parsers.parser import Parser - from icevision.utils.imageio import ImgSize class IceVisionPathsDataSource(ImagePathsDataSource): @@ -41,10 +39,6 @@ def __init__(self, parser: Optional[Type["Parser"]] = None): super().__init__() self.parser = parser - @staticmethod - def _mock_img_size(_) -> ImgSize: - return ImgSize(None, None) - def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: if self.parser is not None: if inspect.isclass(self.parser) and issubclass(self.parser, Parser): @@ -57,11 +51,7 @@ def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Seq dataset.num_classes = parser.class_map.num_classes self.set_state(LabelsState([parser.class_map.get_by_id(i) for i in range(dataset.num_classes)])) - # Patch img_size to prevent image being loaded - parser.img_size = self._mock_img_size - - records = parser.parse(data_splitter=SingleSplitSplitter(), autofix=False) - return [from_icevision_record(record) for record in records[0]] + return [{DefaultDataKeys.INPUT: sample, DefaultDataKeys.METADATA: {"parser": parser}} for sample in parser] raise ValueError("The parser argument must be provided.") def predict_load_data(self, data: Any, dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: @@ -69,3 +59,21 @@ def predict_load_data(self, data: Any, dataset: Optional[Any] = None) -> Sequenc if len(result) == 0: result = self.load_data(data, dataset) return result + + def load_sample(self, sample: Dict[str, Any]): + parser = sample[DefaultDataKeys.METADATA]["parser"] + sample = sample[DefaultDataKeys.INPUT] + + # Adapted from IceVision source code + parser.prepare(sample) + # TODO: Do we still need idmap? + true_record_id = parser.record_id(sample) + record_id = parser.idmap[true_record_id] + + record = parser.create_record() + # HACK: fix record_id (needs to be transformed with idmap) + record.set_record_id(record_id) + is_new = True + + parser.parse_fields(sample, record=record, is_new=is_new) + return super().load_sample(from_icevision_record(record)) From b3b1e952057d45ed5d2830a371d5f0529756b59b Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 26 Oct 2021 19:17:22 +0100 Subject: [PATCH 06/11] Try fix --- flash/core/integrations/icevision/data.py | 42 ++++++++++++++++------- 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/flash/core/integrations/icevision/data.py b/flash/core/integrations/icevision/data.py index a721d805fa..e84bf089a0 100644 --- a/flash/core/integrations/icevision/data.py +++ b/flash/core/integrations/icevision/data.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect +from collections import defaultdict from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type from flash.core.data.data_source import DefaultDataKeys, LabelsState @@ -51,7 +52,19 @@ def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Seq dataset.num_classes = parser.class_map.num_classes self.set_state(LabelsState([parser.class_map.get_by_id(i) for i in range(dataset.num_classes)])) - return [{DefaultDataKeys.INPUT: sample, DefaultDataKeys.METADATA: {"parser": parser}} for sample in parser] + samples = defaultdict(list) + + for sample in parser: + parser.prepare(sample) + true_record_id = parser.record_id(sample) + record_id = parser.idmap[true_record_id] + + samples[record_id].append(sample) + + return [ + {DefaultDataKeys.INPUT: samples, DefaultDataKeys.METADATA: {"parser": parser}} + for samples in samples.values() + ] raise ValueError("The parser argument must be provided.") def predict_load_data(self, data: Any, dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: @@ -62,18 +75,23 @@ def predict_load_data(self, data: Any, dataset: Optional[Any] = None) -> Sequenc def load_sample(self, sample: Dict[str, Any]): parser = sample[DefaultDataKeys.METADATA]["parser"] - sample = sample[DefaultDataKeys.INPUT] + samples = sample[DefaultDataKeys.INPUT] + + record = None - # Adapted from IceVision source code - parser.prepare(sample) - # TODO: Do we still need idmap? - true_record_id = parser.record_id(sample) - record_id = parser.idmap[true_record_id] + for sample in samples: + # Adapted from IceVision source code + parser.prepare(sample) + # TODO: Do we still need idmap? + true_record_id = parser.record_id(sample) + record_id = parser.idmap[true_record_id] - record = parser.create_record() - # HACK: fix record_id (needs to be transformed with idmap) - record.set_record_id(record_id) - is_new = True + is_new = False + if record is None: + record = parser.create_record() + # HACK: fix record_id (needs to be transformed with idmap) + record.set_record_id(record_id) + is_new = True - parser.parse_fields(sample, record=record, is_new=is_new) + parser.parse_fields(sample, record=record, is_new=is_new) return super().load_sample(from_icevision_record(record)) From 370346826dc829a181a2d25321bbb963b931599f Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 1 Nov 2021 11:21:42 +0000 Subject: [PATCH 07/11] Try fix --- flash/core/integrations/icevision/data.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/flash/core/integrations/icevision/data.py b/flash/core/integrations/icevision/data.py index e84bf089a0..0341848586 100644 --- a/flash/core/integrations/icevision/data.py +++ b/flash/core/integrations/icevision/data.py @@ -16,7 +16,7 @@ from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type from flash.core.data.data_source import DefaultDataKeys, LabelsState -from flash.core.integrations.icevision.transforms import from_icevision_record, to_icevision_record +from flash.core.integrations.icevision.transforms import from_icevision_record from flash.core.utilities.imports import _ICEVISION_AVAILABLE from flash.image.data import ImagePathsDataSource @@ -30,9 +30,10 @@ def predict_load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: sample = super().load_sample(sample) - record = to_icevision_record(sample) - record.autofix() - return from_icevision_record(record) + return sample + # record = to_icevision_record(sample) + # record.autofix() + # return from_icevision_record(record) class IceVisionParserDataSource(IceVisionPathsDataSource): @@ -94,4 +95,5 @@ def load_sample(self, sample: Dict[str, Any]): is_new = True parser.parse_fields(sample, record=record, is_new=is_new) + record.autofix() return super().load_sample(from_icevision_record(record)) From 65183f62847619e2493809c8ceb3cf74decab61d Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 1 Nov 2021 12:01:38 +0000 Subject: [PATCH 08/11] Try fix --- flash/core/integrations/icevision/data.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flash/core/integrations/icevision/data.py b/flash/core/integrations/icevision/data.py index 0341848586..e27a6f342f 100644 --- a/flash/core/integrations/icevision/data.py +++ b/flash/core/integrations/icevision/data.py @@ -30,6 +30,8 @@ def predict_load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: sample = super().load_sample(sample) + # Hack to avoid a bug in IceVision + sample[DefaultDataKeys.INPUT].shape = sample[DefaultDataKeys.INPUT].size return sample # record = to_icevision_record(sample) # record.autofix() From 19fb4a53fe414bd49a6ea7d3c924ce0669aaec33 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 1 Nov 2021 13:30:57 +0000 Subject: [PATCH 09/11] Fixes --- flash/core/integrations/icevision/data.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/flash/core/integrations/icevision/data.py b/flash/core/integrations/icevision/data.py index e27a6f342f..0f6a04bbc7 100644 --- a/flash/core/integrations/icevision/data.py +++ b/flash/core/integrations/icevision/data.py @@ -31,11 +31,9 @@ def predict_load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: sample = super().load_sample(sample) # Hack to avoid a bug in IceVision - sample[DefaultDataKeys.INPUT].shape = sample[DefaultDataKeys.INPUT].size + if not hasattr(sample[DefaultDataKeys.INPUT], "shape"): + sample[DefaultDataKeys.INPUT].shape = sample[DefaultDataKeys.INPUT].size return sample - # record = to_icevision_record(sample) - # record.autofix() - # return from_icevision_record(record) class IceVisionParserDataSource(IceVisionPathsDataSource): From e2c86c39192b5ce4f9ece0ba73608c495a0a9f69 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 2 Nov 2021 12:10:13 +0100 Subject: [PATCH 10/11] cv2 --- flash/core/utilities/imports.py | 1 + flash/image/data.py | 13 +++++++++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index e420117e35..73081f0eed 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -86,6 +86,7 @@ def _compare_version(package: str, op, version) -> bool: _UVICORN_AVAILABLE = _module_available("uvicorn") _PIL_AVAILABLE = _module_available("PIL") _OPEN3D_AVAILABLE = _module_available("open3d") +_OPENCV_AVAILABLE = _module_available("cv2") _SEGMENTATION_MODELS_AVAILABLE = _module_available("segmentation_models_pytorch") _FASTFACE_AVAILABLE = _module_available("fastface") _LIBROSA_AVAILABLE = _module_available("librosa") diff --git a/flash/image/data.py b/flash/image/data.py index 5d0eb9cbe5..31ac6d6700 100644 --- a/flash/image/data.py +++ b/flash/image/data.py @@ -30,13 +30,15 @@ TensorDataSource, ) from flash.core.data.process import Deserializer -from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, Image, requires +from flash.core.utilities.imports import _OPENCV_AVAILABLE, _TORCHVISION_AVAILABLE, Image, requires if _TORCHVISION_AVAILABLE: from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS from torchvision.transforms.functional import to_pil_image else: IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") +if _OPENCV_AVAILABLE: + import cv2 NP_EXTENSIONS = (".npy",) @@ -44,7 +46,14 @@ def image_loader(filepath: str): if has_file_allowed_extension(filepath, IMG_EXTENSIONS): - img = default_loader(filepath) + try: + img = default_loader(filepath) + except Exception as ex: + if _OPENCV_AVAILABLE: + im = cv2.cvtColor(cv2.imread(str(filepath)), cv2.COLOR_BGR2RGB) + img = Image.fromarray(im) + else: + raise ex elif has_file_allowed_extension(filepath, NP_EXTENSIONS): img = Image.fromarray(np.load(filepath).astype("uint8"), "RGB") else: From 76f6d5d40c75ccb1ed9ee2fb22d47975bd8f75c8 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 2 Nov 2021 15:14:03 +0100 Subject: [PATCH 11/11] pin_memory=False --- flash/core/data/data_module.py | 10 +++++----- flash/core/data/new_data_module.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index db9e00aff7..d33a42f7f9 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -289,7 +289,7 @@ def _train_dataloader(self) -> DataLoader: drop_last = False else: drop_last = len(train_ds) > self.batch_size - pin_memory = True + pin_memory = False persistent_workers = self.num_workers > 0 if self.sampler is None: @@ -327,7 +327,7 @@ def _val_dataloader(self) -> DataLoader: """Configure the validation dataloader of the datamodule.""" val_ds: Dataset = self._val_ds() if isinstance(self._val_ds, Callable) else self._val_ds collate_fn = self._resolve_collate_fn(val_ds, RunningStage.VALIDATING) - pin_memory = True + pin_memory = False persistent_workers = self.num_workers > 0 if isinstance(getattr(self, "trainer", None), pl.Trainer): @@ -353,7 +353,7 @@ def _test_dataloader(self) -> DataLoader: """Configure the test dataloader of the datamodule.""" test_ds: Dataset = self._test_ds() if isinstance(self._test_ds, Callable) else self._test_ds collate_fn = self._resolve_collate_fn(test_ds, RunningStage.TESTING) - pin_memory = True + pin_memory = False persistent_workers = False if isinstance(getattr(self, "trainer", None), pl.Trainer): @@ -385,7 +385,7 @@ def _predict_dataloader(self) -> DataLoader: batch_size = min(self.batch_size, len(predict_ds) if len(predict_ds) > 0 else 1) collate_fn = self._resolve_collate_fn(predict_ds, RunningStage.PREDICTING) - pin_memory = True + pin_memory = False persistent_workers = False if isinstance(getattr(self, "trainer", None), pl.Trainer): @@ -401,7 +401,7 @@ def _predict_dataloader(self) -> DataLoader: predict_ds, batch_size=batch_size, num_workers=self.num_workers, - pin_memory=True, + pin_memory=False, collate_fn=collate_fn, persistent_workers=persistent_workers, ) diff --git a/flash/core/data/new_data_module.py b/flash/core/data/new_data_module.py index bd19b2057a..e8bc4cf9cc 100644 --- a/flash/core/data/new_data_module.py +++ b/flash/core/data/new_data_module.py @@ -76,7 +76,7 @@ def __init__( batch_size: Optional[int] = None, num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, - pin_memory: bool = True, + pin_memory: bool = False, persistent_workers: bool = True, ) -> None: