Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Improvements to icevision data loading #889

Closed
wants to merge 15 commits into from
2 changes: 1 addition & 1 deletion flash/__about__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.5.1rc0"
__version__ = "0.5.1rc1"
__author__ = "PyTorchLightning et al."
__author_email__ = "name@pytorchlightning.ai"
__license__ = "Apache-2.0"
Expand Down
2 changes: 1 addition & 1 deletion flash/core/data/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bo
Returns:
bool: True if the filename ends with one of given extensions
"""
return filename.lower().endswith(extensions)
return str(filename).lower().endswith(extensions)


# Credit to the PyTorchVision Team:
Expand Down
46 changes: 23 additions & 23 deletions flash/core/integrations/icevision/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,12 @@
import inspect
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type

import numpy as np

from flash.core.data.data_source import DefaultDataKeys, LabelsState
from flash.core.integrations.icevision.transforms import from_icevision_record
from flash.core.integrations.icevision.transforms import from_icevision_record, to_icevision_record
from flash.core.utilities.imports import _ICEVISION_AVAILABLE
from flash.image.data import ImagePathsDataSource

if _ICEVISION_AVAILABLE:
from icevision.core.record import BaseRecord
from icevision.core.record_components import ClassMapRecordComponent, FilepathRecordComponent, tasks
from icevision.data.data_splitter import SingleSplitSplitter
from icevision.parsers.parser import Parser


Expand All @@ -33,22 +28,9 @@ def predict_load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None
return super().predict_load_data(data, dataset)

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
record = sample[DefaultDataKeys.INPUT].load()
return from_icevision_record(record)

def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
if isinstance(sample[DefaultDataKeys.INPUT], BaseRecord):
# load the data via IceVision Base Record
return self.load_sample(sample)
# load the data using numpy
filepath = sample[DefaultDataKeys.INPUT]
sample = super().load_sample(sample)
image = np.array(sample[DefaultDataKeys.INPUT])

record = BaseRecord([FilepathRecordComponent()])
record.filepath = filepath
record.set_img(image)
record.add_component(ClassMapRecordComponent(task=tasks.detection))
record = to_icevision_record(sample)
record.autofix()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does this voodoo function do?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The autofix just has some simple logic to deal with invalid annotations. E.g. often you get bounding boxes that are out of bounds due to floating point wierdness, and this would clamp them to the width / height of the image.

return from_icevision_record(record)


Expand All @@ -68,12 +50,30 @@ def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Seq
raise ValueError("The parser must be a callable or an IceVision Parser type.")
dataset.num_classes = parser.class_map.num_classes
self.set_state(LabelsState([parser.class_map.get_by_id(i) for i in range(dataset.num_classes)]))
records = parser.parse(data_splitter=SingleSplitSplitter())
return [{DefaultDataKeys.INPUT: record} for record in records[0]]

return [{DefaultDataKeys.INPUT: sample, DefaultDataKeys.METADATA: {"parser": parser}} for sample in parser]
raise ValueError("The parser argument must be provided.")

def predict_load_data(self, data: Any, dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]:
result = super().predict_load_data(data, dataset)
if len(result) == 0:
result = self.load_data(data, dataset)
return result

def load_sample(self, sample: Dict[str, Any]):
parser = sample[DefaultDataKeys.METADATA]["parser"]
sample = sample[DefaultDataKeys.INPUT]

# Adapted from IceVision source code
parser.prepare(sample)
# TODO: Do we still need idmap?
true_record_id = parser.record_id(sample)
record_id = parser.idmap[true_record_id]

record = parser.create_record()
# HACK: fix record_id (needs to be transformed with idmap)
record.set_record_id(record_id)
is_new = True

parser.parse_fields(sample, record=record, is_new=is_new)
return super().load_sample(from_icevision_record(record))
27 changes: 12 additions & 15 deletions flash/core/integrations/icevision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Any, Callable, Dict, List, Tuple

from torch import nn
Expand Down Expand Up @@ -51,12 +52,12 @@ def to_icevision_record(sample: Dict[str, Any]):
component.set_class_map(metadata.get("class_map", None))
record.add_component(component)

if "labels" in sample[DefaultDataKeys.TARGET]:
if "labels" in sample.get(DefaultDataKeys.TARGET, {}):
labels_component = InstancesLabelsRecordComponent()
labels_component.add_labels_by_id(sample[DefaultDataKeys.TARGET]["labels"])
record.add_component(labels_component)

if "bboxes" in sample[DefaultDataKeys.TARGET]:
if "bboxes" in sample.get(DefaultDataKeys.TARGET, {}):
bboxes = [
BBox.from_xywh(bbox["xmin"], bbox["ymin"], bbox["width"], bbox["height"])
for bbox in sample[DefaultDataKeys.TARGET]["bboxes"]
Expand All @@ -65,13 +66,13 @@ def to_icevision_record(sample: Dict[str, Any]):
component.set_bboxes(bboxes)
record.add_component(component)

if "masks" in sample[DefaultDataKeys.TARGET]:
if "masks" in sample.get(DefaultDataKeys.TARGET, {}):
mask_array = MaskArray(sample[DefaultDataKeys.TARGET]["masks"])
component = MasksRecordComponent()
component.set_masks(mask_array)
record.add_component(component)

if "keypoints" in sample[DefaultDataKeys.TARGET]:
if "keypoints" in sample.get(DefaultDataKeys.TARGET, {}):
keypoints = []

for keypoints_list, keypoints_metadata in zip(
Expand All @@ -92,7 +93,7 @@ def to_icevision_record(sample: Dict[str, Any]):
else:
if "filepath" in metadata:
input_component = FilepathRecordComponent()
input_component.filepath = metadata["filepath"]
input_component.filepath = Path(metadata["filepath"])
else:
input_component = ImageRecordComponent()
input_component.composite = record
Expand Down Expand Up @@ -160,11 +161,10 @@ def from_icevision_detection(record: "BaseRecord"):


def from_icevision_record(record: "BaseRecord"):
sample = {
DefaultDataKeys.METADATA: {
"size": (record.height, record.width),
}
}
sample = {DefaultDataKeys.METADATA: {}}

if getattr(record, "height", None) is not None and getattr(record, "width", None) is not None:
sample[DefaultDataKeys.METADATA]["size"] = (record.height, record.width)

if getattr(record, "record_id", None) is not None:
sample[DefaultDataKeys.METADATA]["image_id"] = record.record_id
Expand All @@ -174,11 +174,8 @@ def from_icevision_record(record: "BaseRecord"):

if record.img is not None:
sample[DefaultDataKeys.INPUT] = record.img
filepath = getattr(record, "filepath", None)
if filepath is not None:
sample[DefaultDataKeys.METADATA]["filepath"] = filepath
elif record.filepath is not None:
sample[DefaultDataKeys.INPUT] = record.filepath
elif getattr(record, "filepath", None) is not None:
sample[DefaultDataKeys.INPUT] = str(record.filepath)

sample[DefaultDataKeys.TARGET] = from_icevision_detection(record)

Expand Down
4 changes: 2 additions & 2 deletions flash/image/detection/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources, FiftyOneDataSource
from flash.core.data.process import Preprocess
from flash.core.integrations.icevision.data import IceVisionParserDataSource, IceVisionPathsDataSource
from flash.core.integrations.icevision.transforms import default_transforms
from flash.core.integrations.icevision.transforms import default_transforms, from_icevision_record
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _ICEVISION_AVAILABLE, lazy_import, requires

SampleCollection = None
Expand Down Expand Up @@ -125,7 +125,7 @@ def load_data(self, data: SampleCollection, dataset: Optional[Any] = None) -> Se

parser = FiftyOneParser(data, class_map, self.label_field, self.iscrowd)
records = parser.parse(data_splitter=SingleSplitSplitter())
return [{DefaultDataKeys.INPUT: record} for record in records[0]]
return [from_icevision_record(record) for record in records[0]]

@staticmethod
@requires("fiftyone")
Expand Down