diff --git a/CHANGELOG.md b/CHANGELOG.md index edf685f3c1..bcc17a414a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for loading `ObjectDetectionData` with `from_numpy`, `from_images`, and `from_tensors` ([#1372](https://github.com/PyTorchLightning/lightning-flash/pull/1372)) +- Added support for remote data loading with fsspec ([#1387](https://github.com/PyTorchLightning/lightning-flash/pull/1387)) + +- Added support for TSV files to `from_csv` methods ([#1387](https://github.com/PyTorchLightning/lightning-flash/pull/1387)) + +- Added support for more formats when loading audio files ([#1387](https://github.com/PyTorchLightning/lightning-flash/pull/1387)) + ### Changed - Changed the `ImageEmbedder` dependency on VISSL to optional ([#1276](https://github.com/PyTorchLightning/lightning-flash/pull/1276)) diff --git a/docs/source/api/audio.rst b/docs/source/api/audio.rst index 9b4cffe20b..20fe9e3fea 100644 --- a/docs/source/api/audio.rst +++ b/docs/source/api/audio.rst @@ -31,10 +31,9 @@ __________________ ~speech_recognition.data.SpeechRecognitionData ~speech_recognition.model.SpeechRecognition + speech_recognition.input.SpeechRecognitionInputBase speech_recognition.input.SpeechRecognitionCSVInput speech_recognition.input.SpeechRecognitionJSONInput - speech_recognition.input.BaseSpeechRecognition - speech_recognition.input.SpeechRecognitionFileInput speech_recognition.input.SpeechRecognitionPathsInput speech_recognition.input.SpeechRecognitionDatasetInput speech_recognition.input.SpeechRecognitionDeserializer diff --git a/docs/source/api/data.rst b/docs/source/api/data.rst index 11cf91434f..31468a85f6 100644 --- a/docs/source/api/data.rst +++ b/docs/source/api/data.rst @@ -77,6 +77,18 @@ _________________________________ ~flash.core.data.utilities.collate.wrap_collate ~flash.core.data.utilities.collate.default_collate +flash.core.data.utilities.loading +_________________________________ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + ~flash.core.data.utilities.loading.load_image + ~flash.core.data.utilities.loading.load_spectrogram + ~flash.core.data.utilities.loading.load_audio + ~flash.core.data.utilities.loading.load_data_frame + flash.core.data.properties __________________________ @@ -132,7 +144,7 @@ _____________________ ~flash.core.data.utils.download_data flash.core.data.io.input -___________________________ +________________________ .. autosummary:: :toctree: generated/ diff --git a/docs/source/general/remote_data_loading.rst b/docs/source/general/remote_data_loading.rst new file mode 100644 index 0000000000..69a1acbbd6 --- /dev/null +++ b/docs/source/general/remote_data_loading.rst @@ -0,0 +1,39 @@ +.. _remote_data_loading: + +******************* +Remote Data Loading +******************* + +Where possible, all file loading in Flash uses the `fsspec library `_. +As a result, file references can use any of the protocols returned by ``fsspec.available_protocols()``. + +For example, you can load :class:`~flash.tabular.classification.data.TabularClassificationData` from a URL to a CSV file: + +.. testcode:: tabular + + from flash.tabular import TabularClassificationData + + datamodule = TabularClassificationData.from_csv( + categorical_fields=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], + numerical_fields="Fare", + target_fields="Survived", + train_file="https://pl-flash-data.s3.amazonaws.com/titanic.csv", + val_split=0.1, + batch_size=8, + ) + +Here's another example, showing how you can load :class:`~flash.image.classification.data.ImageClassificationData` for prediction using images found on the web: + +.. testcode:: image + + from flash.image import ImageClassificationData + + datamodule = ImageClassificationData.from_files( + predict_files=[ + "https://pl-flash-data.s3.amazonaws.com/images/ant_1.jpg", + "https://pl-flash-data.s3.amazonaws.com/images/ant_2.jpg", + "https://pl-flash-data.s3.amazonaws.com/images/bee_1.jpg", + "https://pl-flash-data.s3.amazonaws.com/images/bee_2.jpg", + ], + batch_size=4, + ) diff --git a/docs/source/index.rst b/docs/source/index.rst index 5512441ca7..fb79dfde67 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -44,6 +44,7 @@ Lightning Flash general/optimization general/classification_targets general/customizing_transforms + general/remote_data_loading .. toctree:: :maxdepth: 1 diff --git a/docs/source/reference/tabular_classification.rst b/docs/source/reference/tabular_classification.rst index 558147343f..4c339e4e14 100644 --- a/docs/source/reference/tabular_classification.rst +++ b/docs/source/reference/tabular_classification.rst @@ -35,7 +35,7 @@ The data is provided in CSV files that look like this: 6,0,3,"Moran, Mr. James",male,,0,0,330877,8.4583,,Q ... -Once we've downloaded the data using :func:`~flash.core.data.download_data`, we can create the :class:`~flash.tabular.classification.data.TabularData` from our CSV files using the :func:`~flash.tabular.classification.data.TabularData.from_csv` method. +We can create the :class:`~flash.tabular.classification.data.TabularData` from our CSV files using the :func:`~flash.tabular.classification.data.TabularData.from_csv` method. From :meth:`the API reference `, we need to provide: * **cat_cols**- A list of the names of columns that contain categorical data (strings or integers). diff --git a/flash/audio/classification/data.py b/flash/audio/classification/data.py index 295d9fcd07..e440c96482 100644 --- a/flash/audio/classification/data.py +++ b/flash/audio/classification/data.py @@ -70,8 +70,9 @@ def from_files( The supported file extensions for precomputed spectrograms are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``, ``.tiff``, ``.webp``, and ``.npy``. - The supported file extensions for raw audio (where spectrograms will be computed automatically) are: ``.wav``, - ``.ogg``, ``.flac``, ``.mat``, and ``.mp3``. + The supported file extensions for raw audio (where spectrograms will be computed automatically) are: ``.aiff``, + ``.au``, ``.avr``, ``.caf``, ``.flac``, ``.mat``, ``.mat4``, ``.mat5``, ``.mpc2k``, ``.ogg``, ``.paf``, + ``.pvf``, ``.rf64``, ``.sd2``, ``.ircam``, ``.voc``, ``.w64``, ``.wav``, ``.nist``, and ``.wavex``. The targets can be in any of our :ref:`supported classification target formats `. To learn how to customize the transforms applied for each stage, read our @@ -181,8 +182,9 @@ def from_folders( The supported file extensions for precomputed spectrograms are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``, ``.tiff``, ``.webp``, and ``.npy``. - The supported file extensions for raw audio (where spectrograms will be computed automatically) are: ``.wav``, - ``.ogg``, ``.flac``, ``.mat``, and ``.mp3``. + The supported file extensions for raw audio (where spectrograms will be computed automatically) are: ``.aiff``, + ``.au``, ``.avr``, ``.caf``, ``.flac``, ``.mat``, ``.mat4``, ``.mat5``, ``.mpc2k``, ``.ogg``, ``.paf``, + ``.pvf``, ``.rf64``, ``.sd2``, ``.ircam``, ``.voc``, ``.w64``, ``.wav``, ``.nist``, and ``.wavex``. For train, test, and validation data, the folders are expected to contain a sub-folder for each class. Here's the required structure: @@ -501,8 +503,9 @@ def from_data_frame( Input spectrogram image paths will be extracted from the ``input_field`` in the DataFrame. The supported file extensions for precomputed spectrograms are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``, ``.tiff``, ``.webp``, and ``.npy``. - The supported file extensions for raw audio (where spectrograms will be computed automatically) are: ``.wav``, - ``.ogg``, ``.flac``, ``.mat``, and ``.mp3``. + The supported file extensions for raw audio (where spectrograms will be computed automatically) are: ``.aiff``, + ``.au``, ``.avr``, ``.caf``, ``.flac``, ``.mat``, ``.mat4``, ``.mat5``, ``.mpc2k``, ``.ogg``, ``.paf``, + ``.pvf``, ``.rf64``, ``.sd2``, ``.ircam``, ``.voc``, ``.w64``, ``.wav``, ``.nist``, and ``.wavex``. The targets will be extracted from the ``target_fields`` in the DataFrame and can be in any of our :ref:`supported classification target formats `. To learn how to customize the transforms applied for each stage, read our @@ -661,8 +664,9 @@ def from_csv( Input spectrogram images will be extracted from the ``input_field`` column in the CSV files. The supported file extensions for precomputed spectrograms are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``, ``.tiff``, ``.webp``, and ``.npy``. - The supported file extensions for raw audio (where spectrograms will be computed automatically) are: ``.wav``, - ``.ogg``, ``.flac``, ``.mat``, and ``.mp3``. + The supported file extensions for raw audio (where spectrograms will be computed automatically) are: ``.aiff``, + ``.au``, ``.avr``, ``.caf``, ``.flac``, ``.mat``, ``.mat4``, ``.mat5``, ``.mpc2k``, ``.ogg``, ``.paf``, + ``.pvf``, ``.rf64``, ``.sd2``, ``.ircam``, ``.voc``, ``.w64``, ``.wav``, ``.nist``, and ``.wavex``. The targets will be extracted from the ``target_fields`` in the CSV files and can be in any of our :ref:`supported classification target formats `. To learn how to customize the transforms applied for each stage, read our @@ -703,6 +707,8 @@ def from_csv( Examples ________ + The files can be in Comma Separated Values (CSV) format with either a ``.csv`` or ``.txt`` extension. + .. testsetup:: >>> import os @@ -774,6 +780,80 @@ def from_csv( >>> shutil.rmtree("predict_folder") >>> os.remove("train_data.csv") >>> os.remove("predict_data.csv") + + Alternatively, the files can be in Tab Separated Values (TSV) format with a ``.tsv`` extension. + + .. testsetup:: + + >>> import os + >>> from PIL import Image + >>> from pandas import DataFrame + >>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) + >>> os.makedirs("train_folder", exist_ok=True) + >>> os.makedirs("predict_folder", exist_ok=True) + >>> _ = [rand_image.save(os.path.join("train_folder", f"spectrogram_{i}.png")) for i in range(1, 4)] + >>> _ = [rand_image.save( + ... os.path.join("predict_folder", f"predict_spectrogram_{i}.png") + ... ) for i in range(1, 4)] + >>> DataFrame.from_dict({ + ... "images": ["spectrogram_1.png", "spectrogram_2.png", "spectrogram_3.png"], + ... "targets": ["meow", "bark", "meow"], + ... }).to_csv("train_data.tsv", sep="\\t", index=False) + >>> DataFrame.from_dict({ + ... "images": ["predict_spectrogram_1.png", "predict_spectrogram_2.png", "predict_spectrogram_3.png"], + ... }).to_csv("predict_data.tsv", sep="\\t", index=False) + + The file ``train_data.tsv`` contains the following: + + .. code-block:: + + images targets + spectrogram_1.png meow + spectrogram_2.png bark + spectrogram_3.png meow + + The file ``predict_data.tsv`` contains the following: + + .. code-block:: + + images + predict_spectrogram_1.png + predict_spectrogram_2.png + predict_spectrogram_3.png + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.audio import AudioClassificationData + >>> from flash.image import ImageClassifier + >>> datamodule = AudioClassificationData.from_csv( + ... "images", + ... "targets", + ... train_file="train_data.tsv", + ... train_images_root="train_folder", + ... predict_file="predict_data.tsv", + ... predict_images_root="predict_folder", + ... transform_kwargs=dict(spectrogram_size=(128, 128)), + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 2 + >>> datamodule.labels + ['bark', 'meow'] + >>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes) + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + .. testcleanup:: + + >>> import shutil + >>> shutil.rmtree("train_folder") + >>> shutil.rmtree("predict_folder") + >>> os.remove("train_data.tsv") + >>> os.remove("predict_data.tsv") """ ds_kw = dict( diff --git a/flash/audio/classification/input.py b/flash/audio/classification/input.py index fd2a9bc8a3..07eb0e1e67 100644 --- a/flash/audio/classification/input.py +++ b/flash/audio/classification/input.py @@ -16,34 +16,21 @@ import numpy as np import pandas as pd -import torch -from flash.audio.data import AUDIO_EXTENSIONS from flash.core.data.io.classification_input import ClassificationInputMixin from flash.core.data.io.input import DataKeys, Input from flash.core.data.utilities.classification import MultiBinaryTargetFormatter, TargetFormatter -from flash.core.data.utilities.data_frame import read_csv, resolve_files, resolve_targets -from flash.core.data.utilities.paths import filter_valid_files, has_file_allowed_extension, make_dataset, PATH_TYPE +from flash.core.data.utilities.data_frame import resolve_files, resolve_targets +from flash.core.data.utilities.loading import ( + AUDIO_EXTENSIONS, + IMG_EXTENSIONS, + load_data_frame, + load_spectrogram, + NP_EXTENSIONS, +) +from flash.core.data.utilities.paths import filter_valid_files, make_dataset, PATH_TYPE from flash.core.data.utilities.samples import to_samples -from flash.core.data.utils import image_default_loader -from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires -from flash.image.data import IMG_EXTENSIONS, NP_EXTENSIONS - -if _AUDIO_AVAILABLE: - import librosa - from torchaudio.transforms import Spectrogram - - -def spectrogram_loader(filepath: str, sampling_rate: int = 16000, n_fft: int = 400): - if has_file_allowed_extension(filepath, IMG_EXTENSIONS): - img = image_default_loader(filepath) - data = np.array(img) - elif has_file_allowed_extension(filepath, AUDIO_EXTENSIONS): - waveform, _ = librosa.load(filepath, sr=sampling_rate) - data = Spectrogram(n_fft, normalized=True)(torch.from_numpy(waveform).unsqueeze(0)).permute(1, 2, 0).numpy() - else: - data = np.load(filepath) - return data +from flash.core.utilities.imports import requires class AudioClassificationInput(Input, ClassificationInputMixin): @@ -84,7 +71,7 @@ def load_data( def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: filepath = sample[DataKeys.INPUT] - sample[DataKeys.INPUT] = spectrogram_loader(filepath, sampling_rate=self.sampling_rate, n_fft=self.n_fft) + sample[DataKeys.INPUT] = load_spectrogram(filepath, sampling_rate=self.sampling_rate, n_fft=self.n_fft) sample = super().load_sample(sample) sample[DataKeys.METADATA]["filepath"] = filepath return sample @@ -174,7 +161,7 @@ def load_data( n_fft: int = 400, target_formatter: Optional[TargetFormatter] = None, ) -> List[Dict[str, Any]]: - data_frame = read_csv(csv_file) + data_frame = load_data_frame(csv_file) if root is None: root = os.path.dirname(csv_file) return super().load_data( diff --git a/flash/audio/data.py b/flash/audio/data.py deleted file mode 100644 index 91e2650eb4..0000000000 --- a/flash/audio/data.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -AUDIO_EXTENSIONS = (".wav", ".ogg", ".flac", ".mat", ".mp3") diff --git a/flash/audio/speech_recognition/data.py b/flash/audio/speech_recognition/data.py index 19ca452ee8..85a993ebe0 100644 --- a/flash/audio/speech_recognition/data.py +++ b/flash/audio/speech_recognition/data.py @@ -59,7 +59,9 @@ def from_files( """Load the :class:`~flash.audio.speech_recognition.data.SpeechRecognitionData` from lists of audio files and corresponding lists of targets. - The supported file extensions are: ``.wav``, ``.ogg``, ``.flac``, ``.mat``, and ``.mp3``. + The supported file extensions are: ``.aiff``, ``.au``, ``.avr``, ``.caf``, ``.flac``, ``.mat``, ``.mat4``, + ``.mat5``, ``.mpc2k``, ``.ogg``, ``.paf``, ``.pvf``, ``.rf64``, ``.sd2``, ``.ircam``, ``.voc``, ``.w64``, + ``.wav``, ``.nist``, and ``.wavex``. To learn how to customize the transforms applied for each stage, read our :ref:`customizing transforms guide `. @@ -150,7 +152,9 @@ def from_csv( audio file paths and their corresponding targets. Input audio file paths will be extracted from the ``input_field`` column in the CSV files. - The supported file extensions are: ``.wav``, ``.ogg``, ``.flac``, ``.mat``, and ``.mp3``. + The supported file extensions are: ``.aiff``, ``.au``, ``.avr``, ``.caf``, ``.flac``, ``.mat``, ``.mat4``, + ``.mat5``, ``.mpc2k``, ``.ogg``, ``.paf``, ``.pvf``, ``.rf64``, ``.sd2``, ``.ircam``, ``.voc``, ``.w64``, + ``.wav``, ``.nist``, and ``.wavex``. The targets will be extracted from the ``target_field`` in the CSV files. To learn how to customize the transforms applied for each stage, read our :ref:`customizing transforms guide `. @@ -175,6 +179,8 @@ def from_csv( Examples ________ + The files can be in Comma Separated Values (CSV) format with either a ``.csv`` or ``.txt`` extension. + .. testsetup:: >>> import numpy as np @@ -220,8 +226,7 @@ def from_csv( ... train_file="train_data.csv", ... predict_file="predict_data.csv", ... batch_size=2, - ... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE - Downloading... + ... ) >>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust") >>> trainer = Trainer(fast_dev_run=True) >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE @@ -236,6 +241,69 @@ def from_csv( >>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)] >>> os.remove("train_data.csv") >>> os.remove("predict_data.csv") + + Alternatively, the files can be in Tab Separated Values (TSV) format with either a ``.tsv``. + + .. testsetup:: + + >>> import numpy as np + >>> from pandas import DataFrame + >>> import soundfile as sf + >>> samplerate = 1000 + >>> data = np.random.uniform(-1, 1, size=(samplerate, 2)) + >>> _ = [sf.write(f"speech_{i}.wav", data, samplerate, subtype='PCM_24') for i in range(1, 4)] + >>> _ = [sf.write(f"predict_speech_{i}.wav", data, samplerate, subtype='PCM_24') for i in range(1, 4)] + >>> DataFrame.from_dict({ + ... "speech_files": ["speech_1.wav", "speech_2.wav", "speech_3.wav"], + ... "targets": ["some speech", "some other speech", "some more speech"], + ... }).to_csv("train_data.tsv", sep="\\t", index=False) + >>> DataFrame.from_dict({ + ... "speech_files": ["predict_speech_1.wav", "predict_speech_2.wav", "predict_speech_3.wav"], + ... }).to_csv("predict_data.tsv", sep="\\t", index=False) + + The file ``train_data.tsv`` contains the following: + + .. code-block:: + + speech_files targets + speech_1.wav some speech + speech_2.wav some other speech + speech_3.wav some more speech + + The file ``predict_data.tsv`` contains the following: + + .. code-block:: + + speech_files + predict_speech_1.wav + predict_speech_2.wav + predict_speech_3.wav + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.audio import SpeechRecognitionData, SpeechRecognition + >>> datamodule = SpeechRecognitionData.from_csv( + ... "speech_files", + ... "targets", + ... train_file="train_data.tsv", + ... predict_file="predict_data.tsv", + ... batch_size=2, + ... ) + >>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust") + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + .. testcleanup:: + + >>> import os + >>> _ = [os.remove(f"speech_{i}.wav") for i in range(1, 4)] + >>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)] + >>> os.remove("train_data.tsv") + >>> os.remove("predict_data.tsv") """ ds_kw = dict( @@ -273,7 +341,9 @@ def from_json( audio file paths and their corresponding targets. Input audio file paths will be extracted from the ``input_field`` field in the JSON files. - The supported file extensions are: ``.wav``, ``.ogg``, ``.flac``, ``.mat``, and ``.mp3``. + The supported file extensions are: ``.aiff``, ``.au``, ``.avr``, ``.caf``, ``.flac``, ``.mat``, ``.mat4``, + ``.mat5``, ``.mpc2k``, ``.ogg``, ``.paf``, ``.pvf``, ``.rf64``, ``.sd2``, ``.ircam``, ``.voc``, ``.w64``, + ``.wav``, ``.nist``, and ``.wavex``. The targets will be extracted from the ``target_field`` field in the JSON files. To learn how to customize the transforms applied for each stage, read our :ref:`customizing transforms guide `. @@ -397,7 +467,9 @@ def from_datasets( * A PyTorch Dataset where the ``__getitem__`` returns a tuple: ``(file_path or , target)`` * A PyTorch Dataset where the ``__getitem__`` returns a dict: ``{"input": file_path, "target": target}`` - The supported file extensions are: ``.wav``, ``.ogg``, ``.flac``, ``.mat``, and ``.mp3``. + The supported file extensions are: ``.aiff``, ``.au``, ``.avr``, ``.caf``, ``.flac``, ``.mat``, ``.mat4``, + ``.mat5``, ``.mpc2k``, ``.ogg``, ``.paf``, ``.pvf``, ``.rf64``, ``.sd2``, ``.ircam``, ``.voc``, ``.w64``, + ``.wav``, ``.nist``, and ``.wavex``. To learn how to customize the transforms applied for each stage, read our :ref:`customizing transforms guide `. diff --git a/flash/audio/speech_recognition/input.py b/flash/audio/speech_recognition/input.py index 77651417d5..d98ec62580 100644 --- a/flash/audio/speech_recognition/input.py +++ b/flash/audio/speech_recognition/input.py @@ -20,15 +20,18 @@ from torch.utils.data import Dataset import flash -from flash.audio.data import AUDIO_EXTENSIONS from flash.core.data.io.input import DataKeys, Input, ServeInput +from flash.core.data.utilities.loading import AUDIO_EXTENSIONS, load_audio, load_data_frame from flash.core.data.utilities.paths import filter_valid_files, list_valid_files from flash.core.data.utilities.samples import to_sample, to_samples from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires if _AUDIO_AVAILABLE: import librosa + from datasets import Dataset as HFDataset from datasets import load_dataset +else: + HFDataset = object class SpeechRecognitionDeserializer(ServeInput): @@ -53,41 +56,22 @@ def example_input(self) -> str: return base64.b64encode(f.read()).decode("UTF-8") -class BaseSpeechRecognition(Input): - @staticmethod - def load_sample(sample: Dict[str, Any], sampling_rate: int = 16000) -> Any: - path = sample[DataKeys.INPUT] - if not os.path.isabs(path) and DataKeys.METADATA in sample and "root" in sample[DataKeys.METADATA]: - path = os.path.join(sample[DataKeys.METADATA]["root"], path) - speech_array, sampling_rate = librosa.load(path, sr=sampling_rate) - sample[DataKeys.INPUT] = speech_array - sample[DataKeys.METADATA] = {"sampling_rate": sampling_rate} - return sample - - -class SpeechRecognitionFileInput(BaseSpeechRecognition): +class SpeechRecognitionInputBase(Input): sampling_rate: int @requires("audio") def load_data( self, - file: str, + hf_dataset: HFDataset, + root: str, input_key: str, target_key: Optional[str] = None, - field: Optional[str] = None, sampling_rate: int = 16000, filetype: Optional[str] = None, ) -> Sequence[Mapping[str, Any]]: self.sampling_rate = sampling_rate - stage = self.running_stage.value - if filetype == "json" and field is not None: - dataset_dict = load_dataset(filetype, data_files={stage: str(file)}, field=field) - else: - dataset_dict = load_dataset(filetype, data_files={stage: str(file)}) - - dataset = dataset_dict[stage] - meta = {"root": os.path.dirname(file)} + meta = {"root": root} if target_key is not None: return [ { @@ -95,63 +79,81 @@ def load_data( DataKeys.TARGET: target, DataKeys.METADATA: meta, } - for input_file, target in zip(dataset[input_key], dataset[target_key]) + for input_file, target in zip(hf_dataset[input_key], hf_dataset[target_key]) ] return [ { DataKeys.INPUT: input_file, DataKeys.METADATA: meta, } - for input_file in dataset[input_key] + for input_file in hf_dataset[input_key] ] def load_sample(self, sample: Dict[str, Any]) -> Any: - return super().load_sample(sample, self.sampling_rate) + path = sample[DataKeys.INPUT] + if not os.path.isabs(path) and DataKeys.METADATA in sample and "root" in sample[DataKeys.METADATA]: + path = os.path.join(sample[DataKeys.METADATA]["root"], path) + speech_array = load_audio(path, sampling_rate=self.sampling_rate) + sample[DataKeys.INPUT] = speech_array + sample[DataKeys.METADATA] = {"sampling_rate": self.sampling_rate} + return sample -class SpeechRecognitionCSVInput(SpeechRecognitionFileInput): +class SpeechRecognitionCSVInput(SpeechRecognitionInputBase): @requires("audio") def load_data( self, - file: str, + csv_file: str, input_key: str, target_key: Optional[str] = None, sampling_rate: int = 16000, ): - return super().load_data(file, input_key, target_key, sampling_rate=sampling_rate, filetype="csv") + return super().load_data( + HFDataset.from_pandas(load_data_frame(csv_file)), + os.path.dirname(csv_file), + input_key, + target_key, + sampling_rate=sampling_rate, + ) -class SpeechRecognitionJSONInput(SpeechRecognitionFileInput): +class SpeechRecognitionJSONInput(SpeechRecognitionInputBase): @requires("audio") def load_data( self, - file: str, + json_file: str, input_key: str, target_key: Optional[str] = None, field: Optional[str] = None, sampling_rate: int = 16000, ): - return super().load_data(file, input_key, target_key, field, sampling_rate=sampling_rate, filetype="json") - - -class SpeechRecognitionDatasetInput(BaseSpeechRecognition): + dataset_dict = load_dataset("json", data_files={"data": str(json_file)}, field=field) + return super().load_data( + dataset_dict["data"], + os.path.dirname(json_file), + input_key, + target_key, + sampling_rate=sampling_rate, + filetype="json", + ) + + +class SpeechRecognitionDatasetInput(SpeechRecognitionInputBase): sampling_rate: int @requires("audio") def load_data(self, dataset: Dataset, sampling_rate: int = 16000) -> Sequence[Mapping[str, Any]]: self.sampling_rate = sampling_rate - return super().load_data(dataset) + return dataset def load_sample(self, sample: Any) -> Any: sample = to_sample(sample) if isinstance(sample[DataKeys.INPUT], (str, Path)): - sample = super().load_sample(sample, self.sampling_rate) + sample = super().load_sample(sample) return sample -class SpeechRecognitionPathsInput(BaseSpeechRecognition): - sampling_rate: int - +class SpeechRecognitionPathsInput(SpeechRecognitionInputBase): @requires("audio") def load_data( self, @@ -163,6 +165,3 @@ def load_data( if targets is None: return to_samples(list_valid_files(paths, AUDIO_EXTENSIONS)) return to_samples(*filter_valid_files(paths, targets, valid_extensions=AUDIO_EXTENSIONS)) - - def load_sample(self, sample: Dict[str, Any]) -> Any: - return super().load_sample(sample, self.sampling_rate) diff --git a/flash/core/data/utilities/data_frame.py b/flash/core/data/utilities/data_frame.py index c41903f626..d2f4d7fc8f 100644 --- a/flash/core/data/utilities/data_frame.py +++ b/flash/core/data/utilities/data_frame.py @@ -16,29 +16,8 @@ from typing import Any, Callable, List, Optional, Union import pandas as pd -from pytorch_lightning.utilities import rank_zero_warn from flash.core.data.utilities.paths import PATH_TYPE -from flash.core.utilities.imports import _PANDAS_GREATER_EQUAL_1_3_0 - - -def read_csv(file: PATH_TYPE) -> pd.DataFrame: - """A wrapper for ``pd.read_csv`` which tries to handle errors gracefully. - - Args: - file: The CSV file to read. - - Returns: - A ``DataFrame`` containing the contents of the file. - """ - try: - return pd.read_csv(file, encoding="utf-8") - except UnicodeDecodeError: - rank_zero_warn("A UnicodeDecodeError was raised when reading the CSV. This error will be ignored.") - if _PANDAS_GREATER_EQUAL_1_3_0: - return pd.read_csv(file, encoding="utf-8", encoding_errors="ignore") - else: - return pd.read_csv(file, encoding=None, engine="python") def _resolve_multi_target(target_keys: List[str], row: pd.Series) -> List[Any]: diff --git a/flash/core/data/utilities/loading.py b/flash/core/data/utilities/loading.py new file mode 100644 index 0000000000..a466add7eb --- /dev/null +++ b/flash/core/data/utilities/loading.py @@ -0,0 +1,193 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +from functools import partial + +import fsspec +import numpy as np +import pandas as pd +import torch + +from flash.core.data.utilities.paths import has_file_allowed_extension +from flash.core.utilities.imports import _AUDIO_AVAILABLE, _TORCHVISION_AVAILABLE, Image + +if _AUDIO_AVAILABLE: + from torchaudio.transforms import Spectrogram + +if _TORCHVISION_AVAILABLE: + from torchvision.datasets.folder import IMG_EXTENSIONS +else: + IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") + +NP_EXTENSIONS = (".npy",) + +AUDIO_EXTENSIONS = ( + ".aiff", + ".au", + ".avr", + ".caf", + ".flac", + ".mat", + ".mat4", + ".mat5", + ".mpc2k", + ".ogg", + ".paf", + ".pvf", + ".rf64", + ".sd2", + ".ircam", + ".voc", + ".w64", + ".wav", + ".nist", + ".wavex", +) + +CSV_EXTENSIONS = (".csv", ".txt") + +TSV_EXTENSIONS = (".tsv",) + + +def _load_image_from_image(file, drop_alpha: bool = True): + img = Image.open(file) + img.load() + + if img.mode == "RGBA" and drop_alpha: + img = img.convert("RGB") + return img + + +def _load_image_from_numpy(file): + return Image.fromarray(np.load(file).astype("uint8"), "RGB") + + +def _load_spectrogram_from_image(file): + img = _load_image_from_image(file, drop_alpha=False) + return np.array(img).astype("float32") + + +def _load_spectrogram_from_numpy(file): + return np.load(file).astype("float32") + + +def _load_spectrogram_from_audio(file, sampling_rate: int = 16000, n_fft: int = 400): + # Import locally to prevent import errors if system dependencies are not available. + import librosa + from soundfile import SoundFile + + sound_file = SoundFile(file) + waveform, _ = librosa.load(sound_file, sr=sampling_rate) + return Spectrogram(n_fft, normalized=True)(torch.from_numpy(waveform).unsqueeze(0)).permute(1, 2, 0).numpy() + + +def _load_audio_from_audio(file, sampling_rate: int = 16000): + # Import locally to prevent import errors if system dependencies are not available. + import librosa + + waveform, _ = librosa.load(file, sr=sampling_rate) + return waveform + + +def _load_data_frame_from_csv(file, encoding: str): + return pd.read_csv(file, encoding=encoding) + + +def _load_data_frame_from_tsv(file, encoding: str): + return pd.read_csv(file, sep="\t", encoding=encoding) + + +_image_loaders = { + IMG_EXTENSIONS: _load_image_from_image, + NP_EXTENSIONS: _load_image_from_numpy, +} + + +_spectrogram_loaders = { + IMG_EXTENSIONS: _load_spectrogram_from_image, + NP_EXTENSIONS: _load_spectrogram_from_numpy, + AUDIO_EXTENSIONS: _load_spectrogram_from_audio, +} + + +_audio_loaders = { + AUDIO_EXTENSIONS: _load_audio_from_audio, +} + + +_data_frame_loaders = { + CSV_EXTENSIONS: _load_data_frame_from_csv, + TSV_EXTENSIONS: _load_data_frame_from_tsv, +} + + +def _get_loader(file_path: str, loaders): + for extensions, loader in loaders.items(): + if has_file_allowed_extension(file_path, extensions): + return loader + raise ValueError( + f"File: {file_path} has an unsupported extension. Supported extensions: " f"{list(sum(loaders.keys(), ()))}." + ) + + +def load(file_path: str, loaders): + loader = _get_loader(file_path, loaders) + with fsspec.open(file_path) as file: + return loader(file) + + +def load_image(file_path: str): + """Load an image from a file. + + Args: + file_path: The image file to load. + """ + return load(file_path, _image_loaders) + + +def load_spectrogram(file_path: str, sampling_rate: int = 16000, n_fft: int = 400): + """Load a spectrogram from an image or audio file. + + Args: + file_path: The file to load. + sampling_rate: The sampling rate to resample to if loading from an audio file. + n_fft: The size of the FFT to use when creating a spectrogram from an audio file. + """ + loaders = copy.copy(_spectrogram_loaders) + loaders[AUDIO_EXTENSIONS] = partial(loaders[AUDIO_EXTENSIONS], sampling_rate=sampling_rate, n_fft=n_fft) + return load(file_path, loaders) + + +def load_audio(file_path: str, sampling_rate: int = 16000): + """Load a waveform from an audio file. + + Args: + file_path: The file to load. + sampling_rate: The sampling rate to resample to. + """ + loaders = { + extensions: partial(loader, sampling_rate=sampling_rate) for extensions, loader in _audio_loaders.items() + } + return load(file_path, loaders) + + +def load_data_frame(file_path: str, encoding: str = "utf-8"): + """Load a data frame from a CSV (or similar) file. + + Args: + file_path: The file to load. + encoding: The encoding to use when reading the file. + """ + loaders = {extensions: partial(loader, encoding=encoding) for extensions, loader in _data_frame_loaders.items()} + return load(file_path, loaders) diff --git a/flash/core/data/utilities/paths.py b/flash/core/data/utilities/paths.py index a333062703..3304046033 100644 --- a/flash/core/data/utilities/paths.py +++ b/flash/core/data/utilities/paths.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Any, Callable, cast, List, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, cast, List, Optional, Tuple, Union from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -21,8 +21,6 @@ PATH_TYPE = Union[str, bytes, os.PathLike] -T = TypeVar("T") - # adapted from torchvision: # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py#L10 diff --git a/flash/core/data/utils.py b/flash/core/data/utils.py index 41a713b2be..ea7fe10032 100644 --- a/flash/core/data/utils.py +++ b/flash/core/data/utils.py @@ -23,21 +23,13 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection from tqdm.auto import tqdm as tq -from flash.core.utilities.imports import _CORE_TESTING, _PIL_AVAILABLE, _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _CORE_TESTING from flash.core.utilities.stages import RunningStage # Skip doctests if requirements aren't available if not _CORE_TESTING: __doctest_skip__ = ["download_data"] -if _PIL_AVAILABLE: - from PIL.Image import Image -else: - Image = object - -if _TORCHVISION_AVAILABLE: - from torchvision.datasets.folder import default_loader - _STAGES_PREFIX = { RunningStage.TRAINING: "train", RunningStage.TESTING: "test", @@ -157,16 +149,3 @@ def convert_to_modules(transforms: Optional[Dict[str, Callable]]): transforms, Iterable, torch.nn.ModuleList, wrong_dtype=(torch.nn.ModuleList, torch.nn.ModuleDict) ) return transforms - - -def image_default_loader(file_path: str, drop_alpha: bool = True) -> Image: - """Default loader for images. - - Args: - file_path: The image file to load. - drop_alpha: If ``True`` (default) then any alpha channels will be silently removed. - """ - img = default_loader(file_path) - if img.mode == "RGBA" and drop_alpha: - img = img.convert("RGB") - return img diff --git a/flash/core/integrations/icevision/data.py b/flash/core/integrations/icevision/data.py index 2e6f8d8711..81ce73782d 100644 --- a/flash/core/integrations/icevision/data.py +++ b/flash/core/integrations/icevision/data.py @@ -17,6 +17,7 @@ import numpy as np from flash.core.data.io.input import DataKeys, Input +from flash.core.data.utilities.loading import IMG_EXTENSIONS, load_image, NP_EXTENSIONS from flash.core.data.utilities.paths import list_valid_files from flash.core.integrations.icevision.transforms import from_icevision_record from flash.core.utilities.imports import _ICEVISION_AVAILABLE @@ -57,8 +58,6 @@ def load_data( def predict_load_data( self, paths: Union[str, List[str]], parser: Optional[Type["Parser"]] = None ) -> List[Dict[str, Any]]: - from flash.image.data import IMG_EXTENSIONS, NP_EXTENSIONS # Import locally to prevent circular import - paths = list_valid_files(paths, valid_extensions=IMG_EXTENSIONS + NP_EXTENSIONS) return [{DataKeys.INPUT: path} for path in paths] @@ -67,12 +66,10 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: return from_icevision_record(record) def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: - from flash.image.data import image_loader # Import locally to prevent circular import - if isinstance(sample[DataKeys.INPUT], BaseRecord): return self.load_sample(sample) filepath = sample[DataKeys.INPUT] - image = np.array(image_loader(filepath)) + image = np.array(load_image(filepath)) record = BaseRecord([FilepathRecordComponent()]) record.filepath = filepath diff --git a/flash/core/integrations/labelstudio/input.py b/flash/core/integrations/labelstudio/input.py index 4d861c2779..c273187cd0 100644 --- a/flash/core/integrations/labelstudio/input.py +++ b/flash/core/integrations/labelstudio/input.py @@ -13,7 +13,7 @@ from flash.core.data.io.input import DataKeys, Input, IterableInput from flash.core.data.properties import Properties -from flash.core.data.utils import image_default_loader +from flash.core.data.utilities.loading import load_image from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE from flash.core.utilities.stages import RunningStage @@ -241,7 +241,7 @@ def load_sample(self, sample: Mapping[str, Any] = None) -> Any: """Load 1 sample from dataset.""" p = sample["file_upload"] # loading image - image = image_default_loader(p) + image = load_image(p) result = { DataKeys.INPUT: image, DataKeys.TARGET: _get_labels_from_sample(sample["label"], self.parameters.classes), diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index 37743ceb02..1910c88201 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -775,6 +775,8 @@ def from_csv( Examples ________ + The files can be in Comma Separated Values (CSV) format with either a ``.csv`` or ``.txt`` extension. + .. testsetup:: >>> import os @@ -843,6 +845,77 @@ def from_csv( >>> shutil.rmtree("predict_folder") >>> os.remove("train_data.csv") >>> os.remove("predict_data.csv") + + Alternatively, the files can be in Tab Separated Values (TSV) format with a ``.tsv`` extension. + + .. testsetup:: + + >>> import os + >>> from PIL import Image + >>> from pandas import DataFrame + >>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) + >>> os.makedirs("train_folder", exist_ok=True) + >>> os.makedirs("predict_folder", exist_ok=True) + >>> _ = [rand_image.save(os.path.join("train_folder", f"image_{i}.png")) for i in range(1, 4)] + >>> _ = [rand_image.save(os.path.join("predict_folder", f"predict_image_{i}.png")) for i in range(1, 4)] + >>> DataFrame.from_dict({ + ... "images": ["image_1.png", "image_2.png", "image_3.png"], + ... "targets": ["cat", "dog", "cat"], + ... }).to_csv("train_data.tsv", sep="\\t", index=False) + >>> DataFrame.from_dict({ + ... "images": ["predict_image_1.png", "predict_image_2.png", "predict_image_3.png"], + ... }).to_csv("predict_data.tsv", sep="\\t", index=False) + + The file ``train_data.tsv`` contains the following: + + .. code-block:: + + images targets + image_1.png cat + image_2.png dog + image_3.png cat + + The file ``predict_data.tsv`` contains the following: + + .. code-block:: + + images + predict_image_1.png + predict_image_2.png + predict_image_3.png + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.image import ImageClassifier, ImageClassificationData + >>> datamodule = ImageClassificationData.from_csv( + ... "images", + ... "targets", + ... train_file="train_data.tsv", + ... train_images_root="train_folder", + ... predict_file="predict_data.tsv", + ... predict_images_root="predict_folder", + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 2 + >>> datamodule.labels + ['cat', 'dog'] + >>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes) + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + .. testcleanup:: + + >>> import shutil + >>> shutil.rmtree("train_folder") + >>> shutil.rmtree("predict_folder") + >>> os.remove("train_data.tsv") + >>> os.remove("predict_data.tsv") """ ds_kw = dict( target_formatter=target_formatter, diff --git a/flash/image/classification/input.py b/flash/image/classification/input.py index 8898a7a6c5..71731945d8 100644 --- a/flash/image/classification/input.py +++ b/flash/image/classification/input.py @@ -19,7 +19,8 @@ from flash.core.data.io.classification_input import ClassificationInputMixin from flash.core.data.io.input import DataKeys from flash.core.data.utilities.classification import MultiBinaryTargetFormatter, TargetFormatter -from flash.core.data.utilities.data_frame import read_csv, resolve_files, resolve_targets +from flash.core.data.utilities.data_frame import resolve_files, resolve_targets +from flash.core.data.utilities.loading import load_data_frame from flash.core.data.utilities.paths import filter_valid_files, make_dataset, PATH_TYPE from flash.core.data.utilities.samples import to_samples from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities @@ -177,7 +178,7 @@ def load_data( resolver: Optional[Callable[[Optional[PATH_TYPE], Any], PATH_TYPE]] = None, target_formatter: Optional[TargetFormatter] = None, ) -> List[Dict[str, Any]]: - data_frame = read_csv(csv_file) + data_frame = load_data_frame(csv_file) if root is None: root = os.path.dirname(csv_file) return super().load_data(data_frame, input_key, target_keys, root, resolver, target_formatter=target_formatter) diff --git a/flash/image/data.py b/flash/image/data.py index f72062e75a..43602892a9 100644 --- a/flash/image/data.py +++ b/flash/image/data.py @@ -16,36 +16,17 @@ from pathlib import Path from typing import Any, Dict, List -import numpy as np import torch import flash from flash.core.data.io.input import DataKeys, Input, ServeInput -from flash.core.data.utilities.paths import filter_valid_files, has_file_allowed_extension, PATH_TYPE +from flash.core.data.utilities.loading import IMG_EXTENSIONS, load_image, NP_EXTENSIONS +from flash.core.data.utilities.paths import filter_valid_files, PATH_TYPE from flash.core.data.utilities.samples import to_samples -from flash.core.data.utils import image_default_loader from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, Image, requires if _TORCHVISION_AVAILABLE: - from torchvision.datasets.folder import IMG_EXTENSIONS from torchvision.transforms.functional import to_pil_image -else: - IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") - -NP_EXTENSIONS = (".npy",) - - -def image_loader(filepath: str): - if has_file_allowed_extension(filepath, IMG_EXTENSIONS): - img = image_default_loader(filepath) - elif has_file_allowed_extension(filepath, NP_EXTENSIONS): - img = Image.fromarray(np.load(filepath).astype("uint8"), "RGB") - else: - raise ValueError( - f"File: {filepath} has an unsupported extension. Supported extensions: " - f"{list(IMG_EXTENSIONS + NP_EXTENSIONS)}." - ) - return img class ImageDeserializer(ServeInput): @@ -82,7 +63,7 @@ def load_data(self, files: List[PATH_TYPE]) -> List[Dict[str, Any]]: def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: filepath = sample[DataKeys.INPUT] - sample[DataKeys.INPUT] = image_loader(filepath) + sample[DataKeys.INPUT] = load_image(filepath) sample = super().load_sample(sample) sample[DataKeys.METADATA]["filepath"] = filepath return sample diff --git a/flash/image/segmentation/input.py b/flash/image/segmentation/input.py index 6a3c66e6be..b72694fe66 100644 --- a/flash/image/segmentation/input.py +++ b/flash/image/segmentation/input.py @@ -17,11 +17,12 @@ import torch from flash.core.data.io.input import DataKeys, Input +from flash.core.data.utilities.loading import IMG_EXTENSIONS, load_image from flash.core.data.utilities.paths import filter_valid_files, PATH_TYPE from flash.core.data.utilities.samples import to_samples from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _TORCHVISION_AVAILABLE, lazy_import -from flash.image.data import image_loader, ImageDeserializer, IMG_EXTENSIONS +from flash.image.data import ImageDeserializer from flash.image.segmentation.output import SegmentationLabelsOutput if _FIFTYONE_AVAILABLE: @@ -104,9 +105,9 @@ def load_data( def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: filepath = sample[DataKeys.INPUT] - sample[DataKeys.INPUT] = to_tensor(image_loader(filepath)) + sample[DataKeys.INPUT] = to_tensor(load_image(filepath)) if DataKeys.TARGET in sample: - sample[DataKeys.TARGET] = (to_tensor(image_loader(sample[DataKeys.TARGET])) * 255).long()[0] + sample[DataKeys.TARGET] = (to_tensor(load_image(sample[DataKeys.TARGET])) * 255).long()[0] sample = super().load_sample(sample) sample[DataKeys.METADATA]["filepath"] = filepath return sample diff --git a/flash/tabular/classification/data.py b/flash/tabular/classification/data.py index 7a5277653e..d6a46e5a1e 100644 --- a/flash/tabular/classification/data.py +++ b/flash/tabular/classification/data.py @@ -235,6 +235,8 @@ def from_csv( Examples ________ + The files can be in Comma Separated Values (CSV) format with either a ``.csv`` or ``.txt`` extension. + .. testsetup:: >>> from pandas import DataFrame @@ -294,6 +296,68 @@ def from_csv( >>> import os >>> os.remove("train_data.csv") >>> os.remove("predict_data.csv") + + Alternatively, the files can be in Tab Separated Values (TSV) format with a ``.tsv`` extension. + + .. testsetup:: + + >>> from pandas import DataFrame + >>> DataFrame.from_dict({ + ... "animal": ["cat", "dog", "cat"], + ... "friendly": ["yes", "yes", "no"], + ... "weight": [6, 10, 5], + ... }).to_csv("train_data.tsv", sep="\\t") + >>> predict_data = DataFrame.from_dict({ + ... "friendly": ["yes", "no", "yes"], + ... "weight": [7, 12, 5], + ... }).to_csv("predict_data.tsv", sep="\\t") + + We have a ``train_data.tsv`` with the following contents: + + .. code-block:: + + animal friendly weight + cat yes 6 + dog yes 10 + cat no 5 + + and a ``predict_data.tsv`` with the following contents: + + .. code-block:: + + friendly weight + yes 7 + no 12 + yes 5 + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.tabular import TabularClassifier, TabularClassificationData + >>> datamodule = TabularClassificationData.from_csv( + ... "friendly", + ... "weight", + ... "animal", + ... train_file="train_data.tsv", + ... predict_file="predict_data.tsv", + ... batch_size=4, + ... ) + >>> datamodule.num_classes + 2 + >>> datamodule.labels + ['cat', 'dog'] + >>> model = TabularClassifier.from_data(datamodule, backbone="tabnet") + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + .. testcleanup:: + + >>> import os + >>> os.remove("train_data.tsv") + >>> os.remove("predict_data.tsv") """ ds_kw = dict( target_formatter=target_formatter, diff --git a/flash/tabular/classification/input.py b/flash/tabular/classification/input.py index 298b7171fa..a393f2f2a9 100644 --- a/flash/tabular/classification/input.py +++ b/flash/tabular/classification/input.py @@ -16,7 +16,8 @@ from flash.core.data.io.classification_input import ClassificationInputMixin from flash.core.data.io.input import DataKeys from flash.core.data.utilities.classification import TargetFormatter -from flash.core.data.utilities.data_frame import read_csv, resolve_targets +from flash.core.data.utilities.data_frame import resolve_targets +from flash.core.data.utilities.loading import load_data_frame from flash.core.utilities.imports import _PANDAS_AVAILABLE from flash.tabular.input import TabularDataFrameInput @@ -63,7 +64,7 @@ def load_data( ): if file is not None: return super().load_data( - read_csv(file), categorical_fields, numerical_fields, target_fields, parameters, target_formatter + load_data_frame(file), categorical_fields, numerical_fields, target_fields, parameters, target_formatter ) diff --git a/flash/tabular/input.py b/flash/tabular/input.py index df01df10ae..005c40eb48 100644 --- a/flash/tabular/input.py +++ b/flash/tabular/input.py @@ -18,7 +18,6 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from flash.core.data.io.input import DataKeys, Input, ServeInput -from flash.core.data.utilities.data_frame import read_csv from flash.core.utilities.imports import _PANDAS_AVAILABLE from flash.tabular.classification.utils import ( _compute_normalization, @@ -29,6 +28,7 @@ ) if _PANDAS_AVAILABLE: + import pandas as pd from pandas.core.frame import DataFrame else: DataFrame = object @@ -121,7 +121,7 @@ def __init__(self, *args, parameters: Optional[Dict[str, Any]] = None, **kwargs) def serve_load_sample(self, data: str) -> Any: parameters = self._parameters - df = read_csv(StringIO(data)) + df = pd.read_csv(StringIO(data)) df = _pre_transform( df, parameters["numerical_fields"], diff --git a/flash/tabular/regression/data.py b/flash/tabular/regression/data.py index 51cc4e9bba..c082b1c0c4 100644 --- a/flash/tabular/regression/data.py +++ b/flash/tabular/regression/data.py @@ -219,6 +219,8 @@ def from_csv( Examples ________ + The files can be in Comma Separated Values (CSV) format with either a ``.csv`` or ``.txt`` extension. + .. testsetup:: >>> from pandas import DataFrame @@ -274,6 +276,64 @@ def from_csv( >>> import os >>> os.remove("train_data.csv") >>> os.remove("predict_data.csv") + + Alternatively, the files can be in Tab Separated Values (TSV) format with a ``.tsv`` extension. + + .. testsetup:: + + >>> from pandas import DataFrame + >>> DataFrame.from_dict({ + ... "age": [2, 4, 1], + ... "animal": ["cat", "dog", "cat"], + ... "weight": [6, 10, 5], + ... }).to_csv("train_data.tsv", sep="\\t") + >>> DataFrame.from_dict({ + ... "animal": ["dog", "dog", "cat"], + ... "weight": [7, 12, 5], + ... }).to_csv("predict_data.tsv", sep="\\t") + + We have a ``train_data.tsv`` with the following contents: + + .. code-block:: + + age animal weight + 2 cat 6 + 4 dog 10 + 1 cat 5 + + and a ``predict_data.tsv`` with the following contents: + + .. code-block:: + + animal weight + dog 7 + dog 12 + cat 5 + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.tabular import TabularRegressor, TabularRegressionData + >>> datamodule = TabularRegressionData.from_csv( + ... "animal", + ... "weight", + ... "age", + ... train_file="train_data.tsv", + ... predict_file="predict_data.tsv", + ... batch_size=4, + ... ) + >>> model = TabularRegressor.from_data(datamodule, backbone="tabnet") + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + .. testcleanup:: + + >>> import os + >>> os.remove("train_data.tsv") + >>> os.remove("predict_data.tsv") """ ds_kw = dict( categorical_fields=categorical_fields, diff --git a/flash/tabular/regression/input.py b/flash/tabular/regression/input.py index 370e4ee1ea..a673f19ee4 100644 --- a/flash/tabular/regression/input.py +++ b/flash/tabular/regression/input.py @@ -16,7 +16,7 @@ import numpy as np from flash.core.data.io.input import DataKeys -from flash.core.data.utilities.data_frame import read_csv +from flash.core.data.utilities.loading import load_data_frame from flash.core.utilities.imports import _PANDAS_AVAILABLE from flash.tabular.input import TabularDataFrameInput @@ -54,7 +54,9 @@ def load_data( parameters: Dict[str, Any] = None, ): if file is not None: - return super().load_data(read_csv(file), categorical_fields, numerical_fields, target_field, parameters) + return super().load_data( + load_data_frame(file), categorical_fields, numerical_fields, target_field, parameters + ) class TabularRegressionDictInput(TabularRegressionDataFrameInput): diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index 60b9468c4d..2261d1e833 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -93,6 +93,8 @@ def from_csv( Examples ________ + The files can be in Comma Separated Values (CSV) format with either a ``.csv`` or ``.txt`` extension. + .. testsetup:: >>> import os @@ -133,8 +135,7 @@ def from_csv( ... train_file="train_data.csv", ... predict_file="predict_data.csv", ... batch_size=2, - ... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE - Downloading... + ... ) >>> datamodule.num_classes 3 >>> datamodule.labels @@ -150,6 +151,65 @@ def from_csv( >>> os.remove("train_data.csv") >>> os.remove("predict_data.csv") + + Alternatively, the files can be in Tab Separated Values (TSV) format with a ``.tsv`` extension. + + .. testsetup:: + + >>> import os + >>> from pandas import DataFrame + >>> DataFrame.from_dict({ + ... "reviews": ["Best movie ever!", "Not good", "Fine I guess"], + ... "targets": ["positive", "negative", "neutral"], + ... }).to_csv("train_data.tsv", sep="\\t", index=False) + >>> DataFrame.from_dict({ + ... "reviews": ["Worst movie ever!", "I didn't enjoy it", "It was ok"], + ... }).to_csv("predict_data.tsv", sep="\\t", index=False) + + The file ``train_data.tsv`` contains the following: + + .. code-block:: + + reviews targets + Best movie ever! positive + Not good negative + Fine I guess neutral + + The file ``predict_data.tsv`` contains the following: + + .. code-block:: + + reviews + Worst movie ever! + I didn't enjoy it + It was ok + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.text import TextClassifier, TextClassificationData + >>> datamodule = TextClassificationData.from_csv( + ... "reviews", + ... "targets", + ... train_file="train_data.tsv", + ... predict_file="predict_data.tsv", + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 3 + >>> datamodule.labels + ['negative', 'neutral', 'positive'] + >>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny") + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + .. testcleanup:: + + >>> os.remove("train_data.tsv") + >>> os.remove("predict_data.tsv") """ ds_kw = dict( target_formatter=target_formatter, diff --git a/flash/text/classification/input.py b/flash/text/classification/input.py index c4c96c8ffa..73f0a9f0c8 100644 --- a/flash/text/classification/input.py +++ b/flash/text/classification/input.py @@ -19,6 +19,7 @@ from flash.core.data.io.classification_input import ClassificationInputMixin from flash.core.data.io.input import DataKeys, Input from flash.core.data.utilities.classification import MultiBinaryTargetFormatter, TargetFormatter +from flash.core.data.utilities.loading import load_data_frame from flash.core.data.utilities.paths import PATH_TYPE from flash.core.utilities.imports import _TEXT_AVAILABLE, requires @@ -79,8 +80,9 @@ def load_data( target_keys: Optional[Union[str, List[str]]] = None, target_formatter: Optional[TargetFormatter] = None, ) -> Dataset: - dataset_dict = load_dataset("csv", data_files={"data": str(csv_file)}) - return super().load_data(dataset_dict["data"], input_key, target_keys, target_formatter=target_formatter) + return super().load_data( + Dataset.from_pandas(load_data_frame(csv_file)), input_key, target_keys, target_formatter=target_formatter + ) class TextClassificationJSONInput(TextClassificationInput): diff --git a/flash/text/question_answering/data.py b/flash/text/question_answering/data.py index 43e6158e60..f0d90202f5 100644 --- a/flash/text/question_answering/data.py +++ b/flash/text/question_answering/data.py @@ -80,41 +80,30 @@ def from_csv( Examples ________ + The files can be in Comma Separated Values (CSV) format with either a ``.csv`` or ``.txt`` extension. + .. testsetup:: >>> import os >>> from pandas import DataFrame >>> DataFrame.from_dict({ - ... "id": ["12345", "12346", "12347", "12348"], + ... "id": ["1", "2", "3"], ... "context": [ - ... "this is an answer one. this is a context one", - ... "this is an answer two. this is a context two", - ... "this is an answer three. this is a context three", - ... "this is an answer four. this is a context four", - ... ], - ... "question": [ - ... "this is a question one", - ... "this is a question two", - ... "this is a question three", - ... "this is a question four", - ... ], - ... "answer_text": [ - ... "this is an answer one", - ... "this is an answer two", - ... "this is an answer three", - ... "this is an answer four", + ... "I am three years old", + ... "I am six feet tall", + ... "I am eight years old", ... ], - ... "answer_start": [0, 0, 0, 0], + ... "question": ["How old are you?", "How tall are you?", "How old are you?"], + ... "answer_text": ["three", "six", "eight"], + ... "answer_start": [0, 0, 0], ... }).to_csv("train_data.csv", index=False) >>> DataFrame.from_dict({ - ... "id": ["12349", "12350"], + ... "id": ["4"], ... "context": [ - ... "this is an answer five. this is a context five", - ... "this is an answer six. this is a context six", + ... "I am five feet tall", ... ], ... "question": [ - ... "this is a question five", - ... "this is a question six", + ... "How tall are you?", ... ], ... }).to_csv("predict_data.csv", index=False) @@ -123,19 +112,16 @@ def from_csv( .. code-block:: id,context,question,answer_text,answer_start - 12345,this is an answer one. this is a context one,this is a question one,this is an answer one,0 - 12346,this is an answer two. this is a context two,this is a question two,this is an answer two,0 - 12347,this is an answer three. this is a context three,this is a question three,this is an answer three,0 - 12348,this is an answer four. this is a context four,this is a question four,this is an answer four,0 - + 1,I am three years old,How old are you?,three,0 + 2,I am six feet tall,How tall are you?,six,0 + 3,I am eight years old,How old are you?,eight,0 The file ``predict_data.csv`` contains the following: .. code-block:: id,context,question - 12349,this is an answer five. this is a context five,this is a question five - 12350,this is an answer six. this is a context six,this is a question six + 4,I am five feet tall,How tall are you? .. doctest:: @@ -145,8 +131,7 @@ def from_csv( ... train_file="train_data.csv", ... predict_file="predict_data.csv", ... batch_size=2, - ... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE - Downloading... + ... ) >>> model = QuestionAnsweringTask(max_source_length=32, max_target_length=32) >>> trainer = Trainer(fast_dev_run=True) >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE @@ -158,6 +143,70 @@ def from_csv( >>> os.remove("train_data.csv") >>> os.remove("predict_data.csv") + + Alternatively, the files can be in Tab Separated Values (TSV) format with a ``.tsv`` extension. + + .. testsetup:: + + >>> import os + >>> from pandas import DataFrame + >>> DataFrame.from_dict({ + ... "id": ["1", "2", "3"], + ... "context": [ + ... "I am three years old", + ... "I am six feet tall", + ... "I am eight years old", + ... ], + ... "question": ["How old are you?", "How tall are you?", "How old are you?"], + ... "answer_text": ["three", "six", "eight"], + ... "answer_start": [0, 0, 0], + ... }).to_csv("train_data.tsv", sep="\\t", index=False) + >>> DataFrame.from_dict({ + ... "id": ["4"], + ... "context": [ + ... "I am five feet tall", + ... ], + ... "question": [ + ... "How tall are you?", + ... ], + ... }).to_csv("predict_data.tsv", sep="\\t", index=False) + + The file ``train_data.tsv`` contains the following: + + .. code-block:: + + id context question answer_text answer_start + 1 I am three years old How old are you? three 0 + 2 I am six feet tall How tall are you? six 0 + 3 I am eight years old How old are you? eight 0 + + The file ``predict_data.tsv`` contains the following: + + .. code-block:: + + id context question + 4 I am five feet tall How tall are you? + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.text import QuestionAnsweringData, QuestionAnsweringTask + >>> datamodule = QuestionAnsweringData.from_csv( + ... train_file="train_data.tsv", + ... predict_file="predict_data.tsv", + ... batch_size=2, + ... ) + >>> model = QuestionAnsweringTask(max_source_length=32, max_target_length=32) + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + .. testcleanup:: + + >>> os.remove("train_data.tsv") + >>> os.remove("predict_data.tsv") """ ds_kw = dict( diff --git a/flash/text/question_answering/input.py b/flash/text/question_answering/input.py index b7498421d5..60302fe505 100644 --- a/flash/text/question_answering/input.py +++ b/flash/text/question_answering/input.py @@ -21,6 +21,7 @@ import flash from flash.core.data.io.input import Input +from flash.core.data.utilities.loading import load_data_frame from flash.core.data.utilities.paths import PATH_TYPE from flash.core.utilities.imports import _TEXT_AVAILABLE, requires @@ -91,9 +92,8 @@ def load_data( context_column_name: str = "context", answer_column_name: str = "answer", ) -> Dataset: - dataset_dict = load_dataset("csv", data_files={"data": str(csv_file)}) return super().load_data( - dataset_dict["data"], + Dataset.from_pandas(load_data_frame(csv_file)), question_column_name=question_column_name, context_column_name=context_column_name, answer_column_name=answer_column_name, diff --git a/flash/text/seq2seq/core/input.py b/flash/text/seq2seq/core/input.py index 88d14adde6..01421fa8c1 100644 --- a/flash/text/seq2seq/core/input.py +++ b/flash/text/seq2seq/core/input.py @@ -15,6 +15,7 @@ import flash from flash.core.data.io.input import DataKeys, Input +from flash.core.data.utilities.loading import load_data_frame from flash.core.data.utilities.paths import PATH_TYPE from flash.core.utilities.imports import _TEXT_AVAILABLE, requires @@ -57,9 +58,8 @@ def load_data( input_key: str, target_key: Optional[str] = None, ) -> Dataset: - dataset_dict = load_dataset("csv", data_files={"data": str(csv_file)}) return super().load_data( - dataset_dict["data"], + Dataset.from_pandas(load_data_frame(csv_file)), input_key, target_key, ) diff --git a/flash/text/seq2seq/summarization/data.py b/flash/text/seq2seq/summarization/data.py index 3496d3ef6f..4f13901b0e 100644 --- a/flash/text/seq2seq/summarization/data.py +++ b/flash/text/seq2seq/summarization/data.py @@ -79,6 +79,8 @@ def from_csv( Examples ________ + The files can be in Comma Separated Values (CSV) format with either a ``.csv`` or ``.txt`` extension. + .. testsetup:: >>> import os @@ -117,8 +119,7 @@ def from_csv( ... train_file="train_data.csv", ... predict_file="predict_data.csv", ... batch_size=2, - ... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE - Downloading... + ... ) >>> model = SummarizationTask() >>> trainer = Trainer(fast_dev_run=True) >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE @@ -130,6 +131,59 @@ def from_csv( >>> os.remove("train_data.csv") >>> os.remove("predict_data.csv") + + Alternatively, the files can be in Tab Separated Values (TSV) format with a ``.tsv`` extension. + + .. testsetup:: + + >>> import os + >>> from pandas import DataFrame + >>> DataFrame.from_dict({ + ... "texts": ["A long paragraph", "A news article"], + ... "summaries": ["A short paragraph", "A news headline"], + ... }).to_csv("train_data.tsv", sep="\\t", index=False) + >>> DataFrame.from_dict({ + ... "texts": ["A movie review", "A book chapter"], + ... }).to_csv("predict_data.tsv", sep="\\t", index=False) + + The file ``train_data.tsv`` contains the following: + + .. code-block:: + + texts summaries + A long paragraph A short paragraph + A news article A news headline + + The file ``predict_data.tsv`` contains the following: + + .. code-block:: + + texts + A movie review + A book chapter + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.text import SummarizationTask, SummarizationData + >>> datamodule = SummarizationData.from_csv( + ... "texts", + ... "summaries", + ... train_file="train_data.tsv", + ... predict_file="predict_data.tsv", + ... batch_size=2, + ... ) + >>> model = SummarizationTask() + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + .. testcleanup:: + + >>> os.remove("train_data.tsv") + >>> os.remove("predict_data.tsv") """ ds_kw = dict( diff --git a/flash/text/seq2seq/translation/data.py b/flash/text/seq2seq/translation/data.py index 2704f03fe2..91a012c3b4 100644 --- a/flash/text/seq2seq/translation/data.py +++ b/flash/text/seq2seq/translation/data.py @@ -79,6 +79,8 @@ def from_csv( Examples ________ + The files can be in Comma Separated Values (CSV) format with either a ``.csv`` or ``.txt`` extension. + .. testsetup:: >>> import os @@ -116,8 +118,7 @@ def from_csv( ... train_file="train_data.csv", ... predict_file="predict_data.csv", ... batch_size=2, - ... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE - Downloading... + ... ) >>> model = TranslationTask() >>> trainer = Trainer(fast_dev_run=True) >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE @@ -129,6 +130,58 @@ def from_csv( >>> os.remove("train_data.csv") >>> os.remove("predict_data.csv") + + Alternatively, the files can be in Tab Separated Values (TSV) format with a ``.tsv`` extension. + + .. testsetup:: + + >>> import os + >>> from pandas import DataFrame + >>> DataFrame.from_dict({ + ... "pig latin": ["ayay entencesay inyay igpay atinlay", "ellohay orldway"], + ... "english": ["a sentence in pig latin", "hello world"], + ... }).to_csv("train_data.tsv", sep="\\t", index=False) + >>> DataFrame.from_dict({ + ... "pig latin": ["ayay entencesay orfay edictionpray"], + ... }).to_csv("predict_data.tsv", sep="\\t", index=False) + + The file ``train_data.tsv`` contains the following: + + .. code-block:: + + pig latin english + ayay entencesay inyay igpay atinlay a sentence in pig latin + ellohay orldway hello world + + The file ``predict_data.tsv`` contains the following: + + .. code-block:: + + pig latin + ayay entencesay orfay edictionpray + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.text import TranslationTask, TranslationData + >>> datamodule = TranslationData.from_csv( + ... "pig latin", + ... "english", + ... train_file="train_data.tsv", + ... predict_file="predict_data.tsv", + ... batch_size=2, + ... ) + >>> model = TranslationTask() + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + .. testcleanup:: + + >>> os.remove("train_data.tsv") + >>> os.remove("predict_data.tsv") """ ds_kw = dict( diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py index f60b5039a5..1f833ed1f7 100644 --- a/flash/video/classification/data.py +++ b/flash/video/classification/data.py @@ -648,6 +648,8 @@ def from_csv( Examples ________ + The files can be in Comma Separated Values (CSV) format with either a ``.csv`` or ``.txt`` extension. + .. testsetup:: >>> import os @@ -723,6 +725,84 @@ def from_csv( >>> shutil.rmtree("predict_folder") >>> os.remove("train_data.csv") >>> os.remove("predict_data.csv") + + Alternatively, the files can be in Tab Separated Values (TSV) format with a ``.tsv`` extension. + + .. testsetup:: + + >>> import os + >>> import torch + >>> from torchvision import io + >>> from pandas import DataFrame + >>> data = torch.randint(255, (10, 64, 64, 3)) + >>> os.makedirs("train_folder", exist_ok=True) + >>> os.makedirs("predict_folder", exist_ok=True) + >>> _ = [io.write_video( + ... os.path.join("train_folder", f"video_{i}.mp4"), data, 5, "libx264rgb", {"crf": "0"} + ... ) for i in range(1, 4)] + >>> _ = [ + ... io.write_video( + ... os.path.join("predict_folder", f"predict_video_{i}.mp4"), data, 5, "libx264rgb", {"crf": "0"} + ... ) for i in range(1, 4) + ... ] + >>> DataFrame.from_dict({ + ... "videos": ["video_1.mp4", "video_2.mp4", "video_3.mp4"], + ... "targets": ["cat", "dog", "cat"], + ... }).to_csv("train_data.tsv", sep="\\t", index=False) + >>> DataFrame.from_dict({ + ... "videos": ["predict_video_1.mp4", "predict_video_2.mp4", "predict_video_3.mp4"], + ... }).to_csv("predict_data.tsv", sep="\\t", index=False) + + The file ``train_data.tsv`` contains the following: + + .. code-block:: + + videos targets + video_1.mp4 cat + video_2.mp4 dog + video_3.mp4 cat + + The file ``predict_data.tsv`` contains the following: + + .. code-block:: + + videos + predict_video_1.mp4 + predict_video_2.mp4 + predict_video_3.mp4 + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.video import VideoClassifier, VideoClassificationData + >>> datamodule = VideoClassificationData.from_csv( + ... "videos", + ... "targets", + ... train_file="train_data.tsv", + ... train_videos_root="train_folder", + ... predict_file="predict_data.tsv", + ... predict_videos_root="predict_folder", + ... transform_kwargs=dict(image_size=(244, 244)), + ... batch_size=2, + ... ) + >>> datamodule.num_classes + 2 + >>> datamodule.labels + ['cat', 'dog'] + >>> model = VideoClassifier(backbone="x3d_xs", num_classes=datamodule.num_classes) + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + .. testcleanup:: + + >>> import shutil + >>> shutil.rmtree("train_folder") + >>> shutil.rmtree("predict_folder") + >>> os.remove("train_data.tsv") + >>> os.remove("predict_data.tsv") """ ds_kw = dict( clip_sampler=clip_sampler, diff --git a/flash/video/classification/input.py b/flash/video/classification/input.py index 027b1393ab..b149587e2e 100644 --- a/flash/video/classification/input.py +++ b/flash/video/classification/input.py @@ -22,7 +22,8 @@ from flash.core.data.io.classification_input import ClassificationInputMixin from flash.core.data.io.input import DataKeys, Input, IterableInput from flash.core.data.utilities.classification import MultiBinaryTargetFormatter, TargetFormatter -from flash.core.data.utilities.data_frame import read_csv, resolve_files, resolve_targets +from flash.core.data.utilities.data_frame import resolve_files, resolve_targets +from flash.core.data.utilities.loading import load_data_frame from flash.core.data.utilities.paths import list_valid_files, make_dataset, PATH_TYPE from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _PYTORCHVIDEO_AVAILABLE, lazy_import, requires @@ -193,7 +194,7 @@ def load_data( decoder: str = "pyav", target_formatter: Optional[TargetFormatter] = None, ) -> "LabeledVideoDataset": - data_frame = read_csv(csv_file) + data_frame = load_data_frame(csv_file) if root is None: root = os.path.dirname(csv_file) return super().load_data( @@ -328,7 +329,7 @@ def predict_load_data( decode_audio: bool = False, decoder: str = "pyav", ) -> List[str]: - data_frame = read_csv(csv_file) + data_frame = load_data_frame(csv_file) if root is None: root = os.path.dirname(csv_file) return super().predict_load_data( diff --git a/flash_examples/image_classification.py b/flash_examples/image_classification.py index 9c4dfb37d0..82ee7cbba0 100644 --- a/flash_examples/image_classification.py +++ b/flash_examples/image_classification.py @@ -37,9 +37,9 @@ # 4. Predict what's on a few images! ants or bees? datamodule = ImageClassificationData.from_files( predict_files=[ - "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg", - "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg", - "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg", + "https://pl-flash-data.s3.amazonaws.com/images/ant_1.jpg", + "https://pl-flash-data.s3.amazonaws.com/images/ant_2.jpg", + "https://pl-flash-data.s3.amazonaws.com/images/bee_1.jpg", ], batch_size=3, ) diff --git a/flash_examples/tabular_classification.py b/flash_examples/tabular_classification.py index f587ee147e..244eb88455 100644 --- a/flash_examples/tabular_classification.py +++ b/flash_examples/tabular_classification.py @@ -14,17 +14,14 @@ import torch import flash -from flash.core.data.utils import download_data from flash.tabular import TabularClassificationData, TabularClassifier # 1. Create the DataModule -download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "./data") - datamodule = TabularClassificationData.from_csv( categorical_fields=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], numerical_fields="Fare", target_fields="Survived", - train_file="data/titanic/titanic.csv", + train_file="https://pl-flash-data.s3.amazonaws.com/titanic.csv", val_split=0.1, batch_size=8, ) @@ -38,7 +35,7 @@ # 4. Generate predictions from a CSV datamodule = TabularClassificationData.from_csv( - predict_file="data/titanic/titanic.csv", + predict_file="https://pl-flash-data.s3.amazonaws.com/titanic.csv", parameters=datamodule.parameters, batch_size=8, ) diff --git a/flash_examples/tabular_regression.py b/flash_examples/tabular_regression.py index b2a6643e99..a6c77b551b 100644 --- a/flash_examples/tabular_regression.py +++ b/flash_examples/tabular_regression.py @@ -14,12 +14,9 @@ import torch import flash -from flash.core.data.utils import download_data from flash.tabular import TabularRegressionData, TabularRegressor # 1. Create the DataModule -download_data("https://pl-flash-data.s3.amazonaws.com/SeoulBikeData.csv", "./data") - datamodule = TabularRegressionData.from_csv( categorical_fields=["Seasons", "Holiday", "Functioning Day"], numerical_fields=[ @@ -34,7 +31,7 @@ "Snowfall", ], target_field="Rented Bike Count", - train_file="data/SeoulBikeData.csv", + train_file="https://pl-flash-data.s3.amazonaws.com/SeoulBikeData.csv", val_split=0.1, batch_size=8, ) @@ -48,7 +45,7 @@ # 4. Generate predictions from a CSV datamodule = TabularRegressionData.from_csv( - predict_file="data/SeoulBikeData.csv", + predict_file="https://pl-flash-data.s3.amazonaws.com/SeoulBikeData.csv", parameters=datamodule.parameters, batch_size=8, ) diff --git a/tests/core/data/utilities/test_loading.py b/tests/core/data/utilities/test_loading.py new file mode 100644 index 0000000000..054c9fd092 --- /dev/null +++ b/tests/core/data/utilities/test_loading.py @@ -0,0 +1,167 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import numpy as np +import pytest + +from flash.core.data.utilities.loading import ( + AUDIO_EXTENSIONS, + CSV_EXTENSIONS, + IMG_EXTENSIONS, + load_audio, + load_data_frame, + load_image, + load_spectrogram, + NP_EXTENSIONS, + TSV_EXTENSIONS, +) +from flash.core.utilities.imports import ( + _AUDIO_AVAILABLE, + _AUDIO_TESTING, + _IMAGE_TESTING, + _PANDAS_AVAILABLE, + _TABULAR_TESTING, + Image, +) + +if _AUDIO_AVAILABLE: + import soundfile as sf + +if _PANDAS_AVAILABLE: + from pandas import DataFrame +else: + DataFrame = object + + +def write_image(file_path): + Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")).save(file_path) + + +def write_numpy(file_path): + np.save(file_path, np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) + + +def write_audio(file_path): + samplerate = 1000 + data = np.random.uniform(-1, 1, size=(samplerate, 2)) + subtype = "VORBIS" if "ogg" in file_path else "PCM_16" + format = "mat5" if "mat" in file_path else None + sf.write(file_path, data, samplerate, subtype=subtype, format=format) + + +def write_csv(file_path): + DataFrame.from_dict( + { + "animal": ["cat", "dog", "cat"], + "friendly": ["yes", "yes", "no"], + "weight": [6, 10, 5], + } + ).to_csv(file_path) + + +def write_tsv(file_path): + DataFrame.from_dict( + { + "animal": ["cat", "dog", "cat"], + "friendly": ["yes", "yes", "no"], + "weight": [6, 10, 5], + } + ).to_csv(file_path, sep="\t") + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.parametrize( + "extension,write", + [(extension, write_image) for extension in IMG_EXTENSIONS] + + [(extension, write_numpy) for extension in NP_EXTENSIONS], +) +def test_load_image(tmpdir, extension, write): + file_path = os.path.join(tmpdir, f"test{extension}") + write(file_path) + + image = load_image(file_path) + + assert isinstance(image, Image.Image) + assert image.mode == "RGB" + + +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@pytest.mark.parametrize( + "extension,write", + [(extension, write_image) for extension in IMG_EXTENSIONS] + + [(extension, write_numpy) for extension in NP_EXTENSIONS] + + [(extension, write_audio) for extension in AUDIO_EXTENSIONS], +) +def test_load_spectrogram(tmpdir, extension, write): + file_path = os.path.join(tmpdir, f"test{extension}") + write(file_path) + + spectrogram = load_spectrogram(file_path) + + assert isinstance(spectrogram, np.ndarray) + assert spectrogram.dtype == np.dtype("float32") + + +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@pytest.mark.parametrize("extension,write", [(extension, write_audio) for extension in AUDIO_EXTENSIONS]) +def test_load_audio(tmpdir, extension, write): + file_path = os.path.join(tmpdir, f"test{extension}") + write(file_path) + + audio = load_audio(file_path) + + assert isinstance(audio, np.ndarray) + assert audio.dtype == np.dtype("float32") + + +@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.") +@pytest.mark.parametrize( + "extension,write", + [(extension, write_csv) for extension in CSV_EXTENSIONS] + [(extension, write_tsv) for extension in TSV_EXTENSIONS], +) +def test_load_data_frame(tmpdir, extension, write): + file_path = os.path.join(tmpdir, f"test{extension}") + write(file_path) + + data_frame = load_data_frame(file_path) + + assert isinstance(data_frame, DataFrame) + + +@pytest.mark.parametrize( + "path, loader, target_type", + [ + pytest.param( + "https://pl-flash-data.s3.amazonaws.com/images/ant_1.jpg", + load_image, + Image.Image, + marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed."), + ), + pytest.param( + "https://pl-flash-data.s3.amazonaws.com/images/ant_1.jpg", + load_spectrogram, + np.ndarray, + marks=pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed."), + ), + pytest.param( + "https://pl-flash-data.s3.amazonaws.com/titanic.csv", + load_data_frame, + DataFrame, + marks=pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed."), + ), + ], +) +def test_load_remote(path, loader, target_type): + assert isinstance(loader(path), target_type) diff --git a/tests/core/data/utilities/test_paths.py b/tests/core/data/utilities/test_paths.py index 1ba01213a3..a7397a7cf2 100644 --- a/tests/core/data/utilities/test_paths.py +++ b/tests/core/data/utilities/test_paths.py @@ -19,9 +19,8 @@ import pytest from numpy import random -from flash.audio.data import AUDIO_EXTENSIONS +from flash.core.data.utilities.loading import AUDIO_EXTENSIONS, IMG_EXTENSIONS, NP_EXTENSIONS from flash.core.data.utilities.paths import filter_valid_files, PATH_TYPE -from flash.image.data import IMG_EXTENSIONS, NP_EXTENSIONS def _make_mock_dir(root, mock_files: List) -> List[PATH_TYPE]: diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py index 47a1aae2b3..8da468b980 100644 --- a/tests/image/classification/test_data.py +++ b/tests/image/classification/test_data.py @@ -43,10 +43,6 @@ import fiftyone as fo -def _dummy_image_loader(_): - return torch.rand(3, 196, 196) - - def _rand_image(size: Tuple[int, int] = None): if size is None: _size = np.random.choice([196, 244]) diff --git a/tests/image/classification/test_data_model_integration.py b/tests/image/classification/test_data_model_integration.py index b9206fbbfe..b9ea6c85e9 100644 --- a/tests/image/classification/test_data_model_integration.py +++ b/tests/image/classification/test_data_model_integration.py @@ -15,7 +15,6 @@ import numpy as np import pytest -import torch from flash import Trainer from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _IMAGE_TESTING, _PIL_AVAILABLE @@ -28,10 +27,6 @@ import fiftyone as fo -def _dummy_image_loader(_): - return torch.rand(3, 224, 224) - - def _rand_image(): return Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))