This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Refactor file loading to use fsspec #1387

Merged: 18 commits, Jul 14, 2022
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -22,6 +22,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for loading `ObjectDetectionData` with `from_numpy`, `from_images`, and `from_tensors` ([#1372](https://github.com/PyTorchLightning/lightning-flash/pull/1372))

- Added support for remote data loading with fsspec ([#1387](https://github.com/PyTorchLightning/lightning-flash/pull/1387))

- Added support for TSV files to `from_csv` methods ([#1387](https://github.com/PyTorchLightning/lightning-flash/pull/1387))

- Added support for more formats when loading audio files ([#1387](https://github.com/PyTorchLightning/lightning-flash/pull/1387))

### Changed

- Changed the `ImageEmbedder` dependency on VISSL to optional ([#1276](https://github.com/PyTorchLightning/lightning-flash/pull/1276))
3 changes: 1 addition & 2 deletions docs/source/api/audio.rst
@@ -31,10 +31,9 @@ __________________
~speech_recognition.data.SpeechRecognitionData
~speech_recognition.model.SpeechRecognition

speech_recognition.input.SpeechRecognitionInputBase
speech_recognition.input.SpeechRecognitionCSVInput
speech_recognition.input.SpeechRecognitionJSONInput
speech_recognition.input.BaseSpeechRecognition
speech_recognition.input.SpeechRecognitionFileInput
speech_recognition.input.SpeechRecognitionPathsInput
speech_recognition.input.SpeechRecognitionDatasetInput
speech_recognition.input.SpeechRecognitionDeserializer
14 changes: 13 additions & 1 deletion docs/source/api/data.rst
@@ -77,6 +77,18 @@ _________________________________
~flash.core.data.utilities.collate.wrap_collate
~flash.core.data.utilities.collate.default_collate

flash.core.data.utilities.loading
_________________________________

.. autosummary::
    :toctree: generated/
    :nosignatures:

    ~flash.core.data.utilities.loading.load_image
    ~flash.core.data.utilities.loading.load_spectrogram
    ~flash.core.data.utilities.loading.load_audio
    ~flash.core.data.utilities.loading.load_data_frame

flash.core.data.properties
__________________________

@@ -132,7 +144,7 @@ _____________________
~flash.core.data.utils.download_data

flash.core.data.io.input
___________________________
________________________

.. autosummary::
    :toctree: generated/
39 changes: 39 additions & 0 deletions docs/source/general/remote_data_loading.rst
@@ -0,0 +1,39 @@
.. _remote_data_loading:

*******************
Remote Data Loading
*******************

Where possible, all file loading in Flash uses the `fsspec library <https://github.com/fsspec/filesystem_spec>`_.
As a result, file references can use any of the protocols returned by ``fsspec.available_protocols()``.
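
To check which protocols are known in a given environment, the list can be inspected directly. The following is a minimal sketch that assumes ``fsspec`` is installed; the ``memory://example.csv`` path and its contents are made up for illustration and are not part of Flash. Note that some listed protocols still need an optional package before they can be used (for example ``s3fs`` for ``s3://``).

```python
import fsspec

# List every protocol name fsspec knows about in this environment.
protocols = fsspec.available_protocols()
print("file" in protocols, "memory" in protocols)

# Round-trip a small CSV through fsspec's in-memory filesystem.
# "memory://example.csv" is an illustrative path, not a Flash default.
with fsspec.open("memory://example.csv", "w") as f:
    f.write("Fare,Survived\n7.25,0\n")

with fsspec.open("memory://example.csv", "r") as f:
    contents = f.read()
```

The same ``fsspec.open`` call accepts a local path or an ``https://`` URL, which is what lets one loading routine serve several storage backends.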

For example, you can load :class:`~flash.tabular.classification.data.TabularClassificationData` from a URL to a CSV file:

.. testcode:: tabular

    from flash.tabular import TabularClassificationData

    datamodule = TabularClassificationData.from_csv(
        categorical_fields=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
        numerical_fields="Fare",
        target_fields="Survived",
        train_file="https://pl-flash-data.s3.amazonaws.com/titanic.csv",
        val_split=0.1,
        batch_size=8,
    )

Here's another example, showing how you can load :class:`~flash.image.classification.data.ImageClassificationData` for prediction using images found on the web:

.. testcode:: image

    from flash.image import ImageClassificationData

    datamodule = ImageClassificationData.from_files(
        predict_files=[
            "https://pl-flash-data.s3.amazonaws.com/images/ant_1.jpg",
            "https://pl-flash-data.s3.amazonaws.com/images/ant_2.jpg",
            "https://pl-flash-data.s3.amazonaws.com/images/bee_1.jpg",
            "https://pl-flash-data.s3.amazonaws.com/images/bee_2.jpg",
        ],
        batch_size=4,
    )
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -44,6 +44,7 @@ Lightning Flash
general/optimization
general/classification_targets
general/customizing_transforms
general/remote_data_loading

.. toctree::
:maxdepth: 1
2 changes: 1 addition & 1 deletion docs/source/reference/tabular_classification.rst
@@ -35,7 +35,7 @@ The data is provided in CSV files that look like this:
6,0,3,"Moran, Mr. James",male,,0,0,330877,8.4583,,Q
...

Once we've downloaded the data using :func:`~flash.core.data.download_data`, we can create the :class:`~flash.tabular.classification.data.TabularData` from our CSV files using the :func:`~flash.tabular.classification.data.TabularData.from_csv` method.
We can create the :class:`~flash.tabular.classification.data.TabularData` from our CSV files using the :func:`~flash.tabular.classification.data.TabularData.from_csv` method.
From :meth:`the API reference <flash.tabular.classification.data.TabularData.from_csv>`, we need to provide:

* **cat_cols**- A list of the names of columns that contain categorical data (strings or integers).
96 changes: 88 additions & 8 deletions flash/audio/classification/data.py
@@ -70,8 +70,9 @@ def from_files(

The supported file extensions for precomputed spectrograms are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``,
``.bmp``, ``.pgm``, ``.tif``, ``.tiff``, ``.webp``, and ``.npy``.
The supported file extensions for raw audio (where spectrograms will be computed automatically) are: ``.wav``,
``.ogg``, ``.flac``, ``.mat``, and ``.mp3``.
The supported file extensions for raw audio (where spectrograms will be computed automatically) are: ``.aiff``,
``.au``, ``.avr``, ``.caf``, ``.flac``, ``.mat``, ``.mat4``, ``.mat5``, ``.mpc2k``, ``.ogg``, ``.paf``,
``.pvf``, ``.rf64``, ``.sd2``, ``.ircam``, ``.voc``, ``.w64``, ``.wav``, ``.nist``, and ``.wavex``.
The targets can be in any of our
:ref:`supported classification target formats <formatting_classification_targets>`.
To learn how to customize the transforms applied for each stage, read our
@@ -181,8 +182,9 @@ def from_folders(

The supported file extensions for precomputed spectrograms are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``,
``.bmp``, ``.pgm``, ``.tif``, ``.tiff``, ``.webp``, and ``.npy``.
The supported file extensions for raw audio (where spectrograms will be computed automatically) are: ``.wav``,
``.ogg``, ``.flac``, ``.mat``, and ``.mp3``.
The supported file extensions for raw audio (where spectrograms will be computed automatically) are: ``.aiff``,
``.au``, ``.avr``, ``.caf``, ``.flac``, ``.mat``, ``.mat4``, ``.mat5``, ``.mpc2k``, ``.ogg``, ``.paf``,
``.pvf``, ``.rf64``, ``.sd2``, ``.ircam``, ``.voc``, ``.w64``, ``.wav``, ``.nist``, and ``.wavex``.
For train, test, and validation data, the folders are expected to contain a sub-folder for each class.
Here's the required structure:

@@ -501,8 +503,9 @@ def from_data_frame(
Input spectrogram image paths will be extracted from the ``input_field`` in the DataFrame.
The supported file extensions for precomputed spectrograms are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``,
``.bmp``, ``.pgm``, ``.tif``, ``.tiff``, ``.webp``, and ``.npy``.
The supported file extensions for raw audio (where spectrograms will be computed automatically) are: ``.wav``,
``.ogg``, ``.flac``, ``.mat``, and ``.mp3``.
The supported file extensions for raw audio (where spectrograms will be computed automatically) are: ``.aiff``,
``.au``, ``.avr``, ``.caf``, ``.flac``, ``.mat``, ``.mat4``, ``.mat5``, ``.mpc2k``, ``.ogg``, ``.paf``,
``.pvf``, ``.rf64``, ``.sd2``, ``.ircam``, ``.voc``, ``.w64``, ``.wav``, ``.nist``, and ``.wavex``.
The targets will be extracted from the ``target_fields`` in the DataFrame and can be in any of our
:ref:`supported classification target formats <formatting_classification_targets>`.
To learn how to customize the transforms applied for each stage, read our
@@ -661,8 +664,9 @@ def from_csv(
Input spectrogram images will be extracted from the ``input_field`` column in the CSV files.
The supported file extensions for precomputed spectrograms are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``,
``.bmp``, ``.pgm``, ``.tif``, ``.tiff``, ``.webp``, and ``.npy``.
The supported file extensions for raw audio (where spectrograms will be computed automatically) are: ``.wav``,
``.ogg``, ``.flac``, ``.mat``, and ``.mp3``.
The supported file extensions for raw audio (where spectrograms will be computed automatically) are: ``.aiff``,
``.au``, ``.avr``, ``.caf``, ``.flac``, ``.mat``, ``.mat4``, ``.mat5``, ``.mpc2k``, ``.ogg``, ``.paf``,
``.pvf``, ``.rf64``, ``.sd2``, ``.ircam``, ``.voc``, ``.w64``, ``.wav``, ``.nist``, and ``.wavex``.
The targets will be extracted from the ``target_fields`` in the CSV files and can be in any of our
:ref:`supported classification target formats <formatting_classification_targets>`.
To learn how to customize the transforms applied for each stage, read our
@@ -703,6 +707,8 @@ def from_csv(
Examples
________

The files can be in Comma Separated Values (CSV) format with either a ``.csv`` or ``.txt`` extension.

.. testsetup::

>>> import os
@@ -774,6 +780,80 @@ def from_csv(
>>> shutil.rmtree("predict_folder")
>>> os.remove("train_data.csv")
>>> os.remove("predict_data.csv")

Alternatively, the files can be in Tab Separated Values (TSV) format with a ``.tsv`` extension.

.. testsetup::

>>> import os
>>> from PIL import Image
>>> from pandas import DataFrame
>>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))
>>> os.makedirs("train_folder", exist_ok=True)
>>> os.makedirs("predict_folder", exist_ok=True)
>>> _ = [rand_image.save(os.path.join("train_folder", f"spectrogram_{i}.png")) for i in range(1, 4)]
>>> _ = [rand_image.save(
... os.path.join("predict_folder", f"predict_spectrogram_{i}.png")
... ) for i in range(1, 4)]
>>> DataFrame.from_dict({
... "images": ["spectrogram_1.png", "spectrogram_2.png", "spectrogram_3.png"],
... "targets": ["meow", "bark", "meow"],
... }).to_csv("train_data.tsv", sep="\\t", index=False)
>>> DataFrame.from_dict({
... "images": ["predict_spectrogram_1.png", "predict_spectrogram_2.png", "predict_spectrogram_3.png"],
... }).to_csv("predict_data.tsv", sep="\\t", index=False)

The file ``train_data.tsv`` contains the following:

.. code-block::

images targets
spectrogram_1.png meow
spectrogram_2.png bark
spectrogram_3.png meow

The file ``predict_data.tsv`` contains the following:

.. code-block::

images
predict_spectrogram_1.png
predict_spectrogram_2.png
predict_spectrogram_3.png

.. doctest::

>>> from flash import Trainer
>>> from flash.audio import AudioClassificationData
>>> from flash.image import ImageClassifier
>>> datamodule = AudioClassificationData.from_csv(
... "images",
... "targets",
... train_file="train_data.tsv",
... train_images_root="train_folder",
... predict_file="predict_data.tsv",
... predict_images_root="predict_folder",
... transform_kwargs=dict(spectrogram_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['bark', 'meow']
>>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...

.. testcleanup::

>>> import shutil
>>> shutil.rmtree("train_folder")
>>> shutil.rmtree("predict_folder")
>>> os.remove("train_data.tsv")
>>> os.remove("predict_data.tsv")
"""

ds_kw = dict(
37 changes: 12 additions & 25 deletions flash/audio/classification/input.py
@@ -16,34 +16,21 @@

import numpy as np
import pandas as pd
import torch

from flash.audio.data import AUDIO_EXTENSIONS
from flash.core.data.io.classification_input import ClassificationInputMixin
from flash.core.data.io.input import DataKeys, Input
from flash.core.data.utilities.classification import MultiBinaryTargetFormatter, TargetFormatter
from flash.core.data.utilities.data_frame import read_csv, resolve_files, resolve_targets
from flash.core.data.utilities.paths import filter_valid_files, has_file_allowed_extension, make_dataset, PATH_TYPE
from flash.core.data.utilities.data_frame import resolve_files, resolve_targets
from flash.core.data.utilities.loading import (
    AUDIO_EXTENSIONS,
    IMG_EXTENSIONS,
    load_data_frame,
    load_spectrogram,
    NP_EXTENSIONS,
)
from flash.core.data.utilities.paths import filter_valid_files, make_dataset, PATH_TYPE
from flash.core.data.utilities.samples import to_samples
from flash.core.data.utils import image_default_loader
from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires
from flash.image.data import IMG_EXTENSIONS, NP_EXTENSIONS

if _AUDIO_AVAILABLE:
    import librosa
    from torchaudio.transforms import Spectrogram


def spectrogram_loader(filepath: str, sampling_rate: int = 16000, n_fft: int = 400):
    if has_file_allowed_extension(filepath, IMG_EXTENSIONS):
        img = image_default_loader(filepath)
        data = np.array(img)
    elif has_file_allowed_extension(filepath, AUDIO_EXTENSIONS):
        waveform, _ = librosa.load(filepath, sr=sampling_rate)
        data = Spectrogram(n_fft, normalized=True)(torch.from_numpy(waveform).unsqueeze(0)).permute(1, 2, 0).numpy()
    else:
        data = np.load(filepath)
    return data
from flash.core.utilities.imports import requires


class AudioClassificationInput(Input, ClassificationInputMixin):
@@ -84,7 +71,7 @@ def load_data(

    def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        filepath = sample[DataKeys.INPUT]
        sample[DataKeys.INPUT] = spectrogram_loader(filepath, sampling_rate=self.sampling_rate, n_fft=self.n_fft)
        sample[DataKeys.INPUT] = load_spectrogram(filepath, sampling_rate=self.sampling_rate, n_fft=self.n_fft)
        sample = super().load_sample(sample)
        sample[DataKeys.METADATA]["filepath"] = filepath
        return sample
@@ -174,7 +161,7 @@ def load_data(
        n_fft: int = 400,
        target_formatter: Optional[TargetFormatter] = None,
    ) -> List[Dict[str, Any]]:
        data_frame = read_csv(csv_file)
        data_frame = load_data_frame(csv_file)
        if root is None:
            root = os.path.dirname(csv_file)
        return super().load_data(
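
The `load_data_frame` call in this hunk replaces `read_csv`, and the changelog entry for this PR notes that the `from_csv` methods now accept TSV files. One plausible way to support both formats is to choose the delimiter from the file extension. The sketch below uses only the standard library and is not Flash's actual implementation; `pick_delimiter` and `read_rows` are hypothetical names introduced for illustration.

```python
import csv
import io
import os


def pick_delimiter(path: str) -> str:
    """Return a tab for ``.tsv`` files and a comma for everything else."""
    _, extension = os.path.splitext(path)
    return "\t" if extension.lower() == ".tsv" else ","


def read_rows(path: str, text: str) -> list:
    """Parse ``text`` as delimited data, using the delimiter implied by ``path``."""
    reader = csv.DictReader(io.StringIO(text), delimiter=pick_delimiter(path))
    return list(reader)


# The same logical table, once tab-separated and once comma-separated.
tsv_rows = read_rows("train_data.tsv", "images\ttargets\nspectrogram_1.png\tmeow\n")
csv_rows = read_rows("train_data.csv", "images,targets\nspectrogram_1.png,meow\n")
```

Both calls yield the same rows, so downstream code such as target resolution would not need to know which format the file used.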
15 changes: 0 additions & 15 deletions flash/audio/data.py

This file was deleted.
