diff --git a/flash/__about__.py b/flash/__about__.py
index 9498434c6d..ba7d39163f 100644
--- a/flash/__about__.py
+++ b/flash/__about__.py
@@ -1,4 +1,4 @@
-__version__ = "0.5.1rc0"
+__version__ = "0.5.1rc1"
 __author__ = "PyTorchLightning et al."
 __author_email__ = "name@pytorchlightning.ai"
 __license__ = "Apache-2.0"
diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py
index db9e00aff7..d33a42f7f9 100644
--- a/flash/core/data/data_module.py
+++ b/flash/core/data/data_module.py
@@ -289,7 +289,7 @@ def _train_dataloader(self) -> DataLoader:
             drop_last = False
         else:
             drop_last = len(train_ds) > self.batch_size
-        pin_memory = True
+        pin_memory = False
         persistent_workers = self.num_workers > 0

         if self.sampler is None:
@@ -327,7 +327,7 @@ def _val_dataloader(self) -> DataLoader:
         """Configure the validation dataloader of the datamodule."""
         val_ds: Dataset = self._val_ds() if isinstance(self._val_ds, Callable) else self._val_ds
         collate_fn = self._resolve_collate_fn(val_ds, RunningStage.VALIDATING)
-        pin_memory = True
+        pin_memory = False
         persistent_workers = self.num_workers > 0

         if isinstance(getattr(self, "trainer", None), pl.Trainer):
@@ -353,7 +353,7 @@ def _test_dataloader(self) -> DataLoader:
         """Configure the test dataloader of the datamodule."""
         test_ds: Dataset = self._test_ds() if isinstance(self._test_ds, Callable) else self._test_ds
         collate_fn = self._resolve_collate_fn(test_ds, RunningStage.TESTING)
-        pin_memory = True
+        pin_memory = False
         persistent_workers = False

         if isinstance(getattr(self, "trainer", None), pl.Trainer):
@@ -385,7 +385,7 @@ def _predict_dataloader(self) -> DataLoader:
             batch_size = min(self.batch_size, len(predict_ds) if len(predict_ds) > 0 else 1)

         collate_fn = self._resolve_collate_fn(predict_ds, RunningStage.PREDICTING)
-        pin_memory = True
+        pin_memory = False
         persistent_workers = False

         if isinstance(getattr(self, "trainer", None), pl.Trainer):
@@ -401,7 +401,7 @@ def _predict_dataloader(self) -> DataLoader:
             predict_ds,
             batch_size=batch_size,
             num_workers=self.num_workers,
-            pin_memory=True,
+            pin_memory=False,
             collate_fn=collate_fn,
             persistent_workers=persistent_workers,
         )
diff --git a/flash/core/data/data_source.py b/flash/core/data/data_source.py
index cf279dcbba..88d4f2131c 100644
--- a/flash/core/data/data_source.py
+++ b/flash/core/data/data_source.py
@@ -71,7 +71,7 @@ def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bo
     Returns:
         bool: True if the filename ends with one of given extensions
     """
-    return filename.lower().endswith(extensions)
+    return str(filename).lower().endswith(extensions)


 # Credit to the PyTorchVision Team:
diff --git a/flash/core/data/new_data_module.py b/flash/core/data/new_data_module.py
index bd19b2057a..e8bc4cf9cc 100644
--- a/flash/core/data/new_data_module.py
+++ b/flash/core/data/new_data_module.py
@@ -76,7 +76,7 @@ def __init__(
         batch_size: Optional[int] = None,
         num_workers: int = 0,
         sampler: Optional[Type[Sampler]] = None,
-        pin_memory: bool = True,
+        pin_memory: bool = False,
         persistent_workers: bool = True,
     ) -> None:
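
Note on the `data_source.py` change above: `pathlib.Path` has no `.lower()`
method, so the old `filename.lower()` raised an `AttributeError` whenever a
`Path` was passed in. A minimal standalone sketch of the fixed behavior (the
sample paths are hypothetical):

    from pathlib import Path

    def has_file_allowed_extension(filename, extensions):
        # Coerce to str first so both str and pathlib.Path inputs work;
        # str.endswith accepts a tuple of suffixes.
        return str(filename).lower().endswith(extensions)

    assert has_file_allowed_extension(Path("photo.JPG"), (".jpg", ".png"))
    assert not has_file_allowed_extension("notes.txt", (".jpg", ".png"))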
diff --git a/flash/core/integrations/icevision/data.py b/flash/core/integrations/icevision/data.py
index 246ace7e13..0f6a04bbc7 100644
--- a/flash/core/integrations/icevision/data.py
+++ b/flash/core/integrations/icevision/data.py
@@ -12,19 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+from collections import defaultdict
 from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type

-import numpy as np
-
 from flash.core.data.data_source import DefaultDataKeys, LabelsState
 from flash.core.integrations.icevision.transforms import from_icevision_record
 from flash.core.utilities.imports import _ICEVISION_AVAILABLE
 from flash.image.data import ImagePathsDataSource

 if _ICEVISION_AVAILABLE:
-    from icevision.core.record import BaseRecord
-    from icevision.core.record_components import ClassMapRecordComponent, FilepathRecordComponent, tasks
-    from icevision.data.data_splitter import SingleSplitSplitter
     from icevision.parsers.parser import Parser


@@ -33,23 +29,11 @@ def predict_load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None
         return super().predict_load_data(data, dataset)

     def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
-        record = sample[DefaultDataKeys.INPUT].load()
-        return from_icevision_record(record)
-
-    def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
-        if isinstance(sample[DefaultDataKeys.INPUT], BaseRecord):
-            # load the data via IceVision Base Record
-            return self.load_sample(sample)
-        # load the data using numpy
-        filepath = sample[DefaultDataKeys.INPUT]
         sample = super().load_sample(sample)
-        image = np.array(sample[DefaultDataKeys.INPUT])
-
-        record = BaseRecord([FilepathRecordComponent()])
-        record.filepath = filepath
-        record.set_img(image)
-        record.add_component(ClassMapRecordComponent(task=tasks.detection))
-        return from_icevision_record(record)
+        # Hack to avoid a bug in IceVision
+        if not hasattr(sample[DefaultDataKeys.INPUT], "shape"):
+            sample[DefaultDataKeys.INPUT].shape = sample[DefaultDataKeys.INPUT].size
+        return sample


 class IceVisionParserDataSource(IceVisionPathsDataSource):
@@ -68,8 +52,20 @@ def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Seq
                 raise ValueError("The parser must be a callable or an IceVision Parser type.")
             dataset.num_classes = parser.class_map.num_classes
             self.set_state(LabelsState([parser.class_map.get_by_id(i) for i in range(dataset.num_classes)]))
-            records = parser.parse(data_splitter=SingleSplitSplitter())
-            return [{DefaultDataKeys.INPUT: record} for record in records[0]]
+
+            samples = defaultdict(list)
+
+            for sample in parser:
+                parser.prepare(sample)
+                true_record_id = parser.record_id(sample)
+                record_id = parser.idmap[true_record_id]
+
+                samples[record_id].append(sample)
+
+            return [
+                {DefaultDataKeys.INPUT: samples, DefaultDataKeys.METADATA: {"parser": parser}}
+                for samples in samples.values()
+            ]
         raise ValueError("The parser argument must be provided.")

     def predict_load_data(self, data: Any, dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]:
@@ -77,3 +73,27 @@ def predict_load_data(self, data: Any, dataset: Optional[Any] = None) -> Sequenc
         if len(result) == 0:
             result = self.load_data(data, dataset)
         return result
+
+    def load_sample(self, sample: Dict[str, Any]):
+        parser = sample[DefaultDataKeys.METADATA]["parser"]
+        samples = sample[DefaultDataKeys.INPUT]
+
+        record = None
+
+        for sample in samples:
+            # Adapted from IceVision source code
+            parser.prepare(sample)
+            # TODO: Do we still need idmap?
+            true_record_id = parser.record_id(sample)
+            record_id = parser.idmap[true_record_id]
+
+            is_new = False
+            if record is None:
+                record = parser.create_record()
+                # HACK: fix record_id (needs to be transformed with idmap)
+                record.set_record_id(record_id)
+                is_new = True
+
+            parser.parse_fields(sample, record=record, is_new=is_new)
+        record.autofix()
+        return super().load_sample(from_icevision_record(record))
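
Note on the `IceVisionParserDataSource` change above: `load_data` no longer
materializes every record up front via `parser.parse(...)`; it only groups
the parser's raw samples by record id, and the new `load_sample` builds each
record lazily, so memory stays proportional to a single record. A minimal
sketch of the grouping idiom, using hypothetical (record_id, annotation)
pairs in place of the parser's samples:

    from collections import defaultdict

    annotations = [("img1", "cat"), ("img1", "dog"), ("img2", "cat")]

    # Collect every annotation that belongs to the same record id.
    groups = defaultdict(list)
    for record_id, annotation in annotations:
        groups[record_id].append(annotation)

    print(dict(groups))  # {'img1': ['cat', 'dog'], 'img2': ['cat']}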
diff --git a/flash/core/integrations/icevision/transforms.py b/flash/core/integrations/icevision/transforms.py
index 5619dfd5af..0a9733e75b 100644
--- a/flash/core/integrations/icevision/transforms.py
+++ b/flash/core/integrations/icevision/transforms.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from pathlib import Path
 from typing import Any, Callable, Dict, List, Tuple

 from torch import nn
@@ -51,12 +52,12 @@ def to_icevision_record(sample: Dict[str, Any]):
         component.set_class_map(metadata.get("class_map", None))
         record.add_component(component)

-    if "labels" in sample[DefaultDataKeys.TARGET]:
+    if "labels" in sample.get(DefaultDataKeys.TARGET, {}):
         labels_component = InstancesLabelsRecordComponent()
         labels_component.add_labels_by_id(sample[DefaultDataKeys.TARGET]["labels"])
         record.add_component(labels_component)

-    if "bboxes" in sample[DefaultDataKeys.TARGET]:
+    if "bboxes" in sample.get(DefaultDataKeys.TARGET, {}):
         bboxes = [
             BBox.from_xywh(bbox["xmin"], bbox["ymin"], bbox["width"], bbox["height"])
             for bbox in sample[DefaultDataKeys.TARGET]["bboxes"]
@@ -65,13 +66,13 @@ def to_icevision_record(sample: Dict[str, Any]):
         component.set_bboxes(bboxes)
         record.add_component(component)

-    if "masks" in sample[DefaultDataKeys.TARGET]:
+    if "masks" in sample.get(DefaultDataKeys.TARGET, {}):
         mask_array = MaskArray(sample[DefaultDataKeys.TARGET]["masks"])
         component = MasksRecordComponent()
         component.set_masks(mask_array)
         record.add_component(component)

-    if "keypoints" in sample[DefaultDataKeys.TARGET]:
+    if "keypoints" in sample.get(DefaultDataKeys.TARGET, {}):
         keypoints = []

         for keypoints_list, keypoints_metadata in zip(
@@ -92,7 +93,7 @@ def to_icevision_record(sample: Dict[str, Any]):
     else:
         if "filepath" in metadata:
             input_component = FilepathRecordComponent()
-            input_component.filepath = metadata["filepath"]
+            input_component.filepath = Path(metadata["filepath"])
         else:
             input_component = ImageRecordComponent()
         input_component.composite = record
@@ -160,11 +161,10 @@ def from_icevision_detection(record: "BaseRecord"):


 def from_icevision_record(record: "BaseRecord"):
-    sample = {
-        DefaultDataKeys.METADATA: {
-            "size": (record.height, record.width),
-        }
-    }
+    sample = {DefaultDataKeys.METADATA: {}}
+
+    if getattr(record, "height", None) is not None and getattr(record, "width", None) is not None:
+        sample[DefaultDataKeys.METADATA]["size"] = (record.height, record.width)

     if getattr(record, "record_id", None) is not None:
         sample[DefaultDataKeys.METADATA]["image_id"] = record.record_id
@@ -174,11 +174,8 @@ def from_icevision_record(record: "BaseRecord"):

     if record.img is not None:
         sample[DefaultDataKeys.INPUT] = record.img
-
-    filepath = getattr(record, "filepath", None)
-    if filepath is not None:
-        sample[DefaultDataKeys.METADATA]["filepath"] = filepath
-    elif record.filepath is not None:
-        sample[DefaultDataKeys.INPUT] = record.filepath
+    elif getattr(record, "filepath", None) is not None:
+        sample[DefaultDataKeys.INPUT] = str(record.filepath)

     sample[DefaultDataKeys.TARGET] = from_icevision_detection(record)
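
Note on the `to_icevision_record` changes above: prediction samples carry no
target dict at all, so the old `sample[DefaultDataKeys.TARGET]` lookups raised
`KeyError`. Replacing them with `sample.get(DefaultDataKeys.TARGET, {})` makes
each membership test simply evaluate to False. A small sketch with a
hypothetical sample and string keys:

    sample = {"input": "image.png"}  # prediction sample: no "target" key

    if "labels" in sample.get("target", {}):
        print("attach labels component")
    else:
        print("no labels; skip")  # this branch runs for prediction samples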
diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py
index e420117e35..73081f0eed 100644
--- a/flash/core/utilities/imports.py
+++ b/flash/core/utilities/imports.py
@@ -86,6 +86,7 @@ def _compare_version(package: str, op, version) -> bool:
 _UVICORN_AVAILABLE = _module_available("uvicorn")
 _PIL_AVAILABLE = _module_available("PIL")
 _OPEN3D_AVAILABLE = _module_available("open3d")
+_OPENCV_AVAILABLE = _module_available("cv2")
 _SEGMENTATION_MODELS_AVAILABLE = _module_available("segmentation_models_pytorch")
 _FASTFACE_AVAILABLE = _module_available("fastface")
 _LIBROSA_AVAILABLE = _module_available("librosa")
diff --git a/flash/image/data.py b/flash/image/data.py
index 5d0eb9cbe5..31ac6d6700 100644
--- a/flash/image/data.py
+++ b/flash/image/data.py
@@ -30,13 +30,15 @@
     TensorDataSource,
 )
 from flash.core.data.process import Deserializer
-from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, Image, requires
+from flash.core.utilities.imports import _OPENCV_AVAILABLE, _TORCHVISION_AVAILABLE, Image, requires

 if _TORCHVISION_AVAILABLE:
     from torchvision.datasets.folder import default_loader, IMG_EXTENSIONS
     from torchvision.transforms.functional import to_pil_image
 else:
     IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")

+if _OPENCV_AVAILABLE:
+    import cv2

 NP_EXTENSIONS = (".npy",)
@@ -44,7 +46,14 @@

 def image_loader(filepath: str):
     if has_file_allowed_extension(filepath, IMG_EXTENSIONS):
-        img = default_loader(filepath)
+        try:
+            img = default_loader(filepath)
+        except Exception as ex:
+            if _OPENCV_AVAILABLE:
+                im = cv2.cvtColor(cv2.imread(str(filepath)), cv2.COLOR_BGR2RGB)
+                img = Image.fromarray(im)
+            else:
+                raise ex
     elif has_file_allowed_extension(filepath, NP_EXTENSIONS):
         img = Image.fromarray(np.load(filepath).astype("uint8"), "RGB")
     else:
diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py
index 9a7e5c31fa..2418fa1ce8 100644
--- a/flash/image/detection/data.py
+++ b/flash/image/detection/data.py
@@ -18,7 +18,7 @@
 from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources, FiftyOneDataSource
 from flash.core.data.process import Preprocess
 from flash.core.integrations.icevision.data import IceVisionParserDataSource, IceVisionPathsDataSource
-from flash.core.integrations.icevision.transforms import default_transforms
+from flash.core.integrations.icevision.transforms import default_transforms, from_icevision_record
 from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _ICEVISION_AVAILABLE, lazy_import, requires

 SampleCollection = None
@@ -125,7 +125,7 @@ def load_data(self, data: SampleCollection, dataset: Optional[Any] = None) -> Se
         parser = FiftyOneParser(data, class_map, self.label_field, self.iscrowd)
         records = parser.parse(data_splitter=SingleSplitSplitter())
-        return [{DefaultDataKeys.INPUT: record} for record in records[0]]
+        return [from_icevision_record(record) for record in records[0]]

     @staticmethod
     @requires("fiftyone")
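
Note on the `image_loader` change above: torchvision's `default_loader` can
fail on files PIL cannot decode, so the loader now falls back to OpenCV
(converting BGR to RGB) when `cv2` is importable, and re-raises otherwise. A
self-contained sketch of the same pattern; `load_image` is a hypothetical
helper, not the Flash API:

    import cv2  # assumed installed for the fallback branch
    from PIL import Image

    def load_image(filepath: str) -> Image.Image:
        try:
            return Image.open(filepath).convert("RGB")
        except Exception:
            # cv2.imread returns a BGR ndarray, or None if it also fails.
            array = cv2.imread(str(filepath))
            if array is None:
                raise
            return Image.fromarray(cv2.cvtColor(array, cv2.COLOR_BGR2RGB))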