From b931fc1ddaeb2ffc29c0dfeb7cc8d1b89ed76d1c Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 5 Jan 2022 20:51:42 +0000 Subject: [PATCH 1/3] Docstrings for `SpeechRecognitionData` (#1097) --- CHANGELOG.md | 6 + flash/audio/speech_recognition/cli.py | 4 +- flash/audio/speech_recognition/data.py | 418 ++++++++++++++++++++++-- flash/audio/speech_recognition/input.py | 38 ++- flash/core/data/utilities/samples.py | 26 +- flash_examples/speech_recognition.py | 4 +- 6 files changed, 458 insertions(+), 38 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 991bf5aef1..21fbb8ab07 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug with `AudioClassificationData.from_numpy` ([#1096](https://github.com/PyTorchLightning/lightning-flash/pull/1096)) +- Fixed a bug when using `SpeechRecognitionData.from_files` for training / validating / testing ([#1097](https://github.com/PyTorchLightning/lightning-flash/pull/1097)) + +- Fixed a bug when using `SpeechRecognitionData.from_csv` or `from_json` when predicting without targets ([#1097](https://github.com/PyTorchLightning/lightning-flash/pull/1097)) + +- Fixed a bug where `SpeechRecognitionData.from_datasets` did not work as expected ([#1097](https://github.com/PyTorchLightning/lightning-flash/pull/1097)) + ### Removed ## [0.6.0] - 2021-13-12 diff --git a/flash/audio/speech_recognition/cli.py b/flash/audio/speech_recognition/cli.py index a74a930d25..b44470feb8 100644 --- a/flash/audio/speech_recognition/cli.py +++ b/flash/audio/speech_recognition/cli.py @@ -28,8 +28,8 @@ def from_timit( """Downloads and loads the timit data set.""" download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data") return SpeechRecognitionData.from_json( - input_fields="file", - target_fields="text", + "file", + "text", train_file="data/timit/train.json", test_file="data/timit/test.json", val_split=val_split, diff --git 
a/flash/audio/speech_recognition/data.py b/flash/audio/speech_recognition/data.py index aef9454eba..0bce404ac6 100644 --- a/flash/audio/speech_recognition/data.py +++ b/flash/audio/speech_recognition/data.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Sequence, Type, Union +from typing import Any, Dict, Optional, Sequence, Type from torch.utils.data import Dataset @@ -27,11 +27,17 @@ from flash.core.data.io.input import Input from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE, InputTransform from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _AUDIO_AVAILABLE from flash.core.utilities.stages import RunningStage +# Skip doctests if requirements aren't available +if not _AUDIO_AVAILABLE: + __doctest_skip__ = ["SpeechRecognitionData", "SpeechRecognitionData.*"] + class SpeechRecognitionData(DataModule): - """Data Module for text classification tasks.""" + """The ``SpeechRecognitionData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of + classmethods for loading data for speech recognition.""" input_transform_cls = InputTransform output_transform_cls = SpeechRecognitionOutputTransform @@ -41,11 +47,11 @@ class SpeechRecognitionData(DataModule): def from_files( cls, train_files: Optional[Sequence[str]] = None, - train_targets: Optional[Sequence[Any]] = None, + train_targets: Optional[Sequence[str]] = None, val_files: Optional[Sequence[str]] = None, - val_targets: Optional[Sequence[Any]] = None, + val_targets: Optional[Sequence[str]] = None, test_files: Optional[Sequence[str]] = None, - test_targets: Optional[Sequence[Any]] = None, + test_targets: Optional[Sequence[str]] = None, predict_files: Optional[Sequence[str]] = None, sampling_rate: int = 16000, train_transform: INPUT_TRANSFORM_TYPE = 
InputTransform, @@ -56,6 +62,70 @@ def from_files( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SpeechRecognitionData": + """Load the :class:`~flash.audio.speech_recognition.data.SpeechRecognitionData` from lists of audio files + and corresponding lists of targets. + + The supported file extensions are: ``wav``, ``ogg``, ``flac``, ``mat``, and ``mp3``. + To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide <customizing_transforms>`. + + Args: + train_files: The list of audio files to use when training. + train_targets: The list of targets (ground truth speech transcripts) to use when training. + val_files: The list of audio files to use when validating. + val_targets: The list of targets (ground truth speech transcripts) to use when validating. + test_files: The list of audio files to use when testing. + test_targets: The list of targets (ground truth speech transcripts) to use when testing. + predict_files: The list of audio files to use when predicting. + sampling_rate: Sampling rate to use when loading the audio files. + train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. + val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. + test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. + predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when + predicting. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. + + Returns: + The constructed :class:`~flash.audio.speech_recognition.data.SpeechRecognitionData`. + + Examples + ________ + + .. 
testsetup:: + + >>> import numpy as np + >>> import soundfile as sf + >>> samplerate = 44100 + >>> data = np.random.uniform(-1, 1, size=(samplerate * 3, 2)) + >>> _ = [sf.write(f"speech_{i}.wav", data, samplerate, subtype='PCM_24') for i in range(1, 4)] + >>> _ = [sf.write(f"predict_speech_{i}.wav", data, samplerate, subtype='PCM_24') for i in range(1, 4)] + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.audio import SpeechRecognitionData, SpeechRecognition + >>> datamodule = SpeechRecognitionData.from_files( + ... train_files=["speech_1.wav", "speech_2.wav", "speech_3.wav"], + ... train_targets=["some speech", "some other speech", "some more speech"], + ... predict_files=["predict_speech_1.wav", "predict_speech_2.wav", "predict_speech_3.wav"], + ... batch_size=2, + ... ) + >>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust") + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + .. testcleanup:: + + >>> import os + >>> _ = [os.remove(f"speech_{i}.wav") for i in range(1, 4)] + >>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)] + """ ds_kw = dict( data_pipeline_state=DataPipelineState(), @@ -75,8 +145,8 @@ def from_files( @classmethod def from_csv( cls, - input_fields: Union[str, Sequence[str]], - target_fields: Optional[str] = None, + input_field: str, + target_field: Optional[str] = None, train_file: Optional[str] = None, val_file: Optional[str] = None, test_file: Optional[str] = None, @@ -90,20 +160,114 @@ def from_csv( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SpeechRecognitionData": + """Load the :class:`~flash.audio.speech_recognition.data.SpeechRecognitionData` from CSV files containing + audio file paths and their corresponding targets. 
+ + Input audio file paths will be extracted from the ``input_field`` column in the CSV files. + The supported file extensions are: ``wav``, ``ogg``, ``flac``, ``mat``, and ``mp3``. + The targets will be extracted from the ``target_field`` in the CSV files. + To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide <customizing_transforms>`. + + Args: + input_field: The field (column name) in the CSV files containing the audio file paths. + target_field: The field (column name) in the CSV files containing the targets. + train_file: The CSV file to use when training. + val_file: The CSV file to use when validating. + test_file: The CSV file to use when testing. + predict_file: The CSV file to use when predicting. + sampling_rate: Sampling rate to use when loading the audio files. + train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. + val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. + test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. + predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when + predicting. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. + + Returns: + The constructed :class:`~flash.audio.speech_recognition.data.SpeechRecognitionData`. + + Examples + ________ + + .. 
testsetup:: + + >>> import numpy as np + >>> from pandas import DataFrame + >>> import soundfile as sf + >>> samplerate = 44100 + >>> data = np.random.uniform(-1, 1, size=(samplerate * 3, 2)) + >>> _ = [sf.write(f"speech_{i}.wav", data, samplerate, subtype='PCM_24') for i in range(1, 4)] + >>> _ = [sf.write(f"predict_speech_{i}.wav", data, samplerate, subtype='PCM_24') for i in range(1, 4)] + >>> DataFrame.from_dict({ + ... "speech_files": ["speech_1.wav", "speech_2.wav", "speech_3.wav"], + ... "targets": ["some speech", "some other speech", "some more speech"], + ... }).to_csv("train_data.csv", index=False) + >>> DataFrame.from_dict({ + ... "speech_files": ["predict_speech_1.wav", "predict_speech_2.wav", "predict_speech_3.wav"], + ... }).to_csv("predict_data.csv", index=False) + + The file ``train_data.csv`` contains the following: + + .. code-block:: + + speech_files,targets + speech_1.wav,some speech + speech_2.wav,some other speech + speech_3.wav,some more speech + + The file ``predict_data.csv`` contains the following: + + .. code-block:: + + speech_files + predict_speech_1.wav + predict_speech_2.wav + predict_speech_3.wav + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.audio import SpeechRecognitionData, SpeechRecognition + >>> datamodule = SpeechRecognitionData.from_csv( + ... "speech_files", + ... "targets", + ... train_file="train_data.csv", + ... predict_file="predict_data.csv", + ... batch_size=2, + ... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Downloading... + >>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust") + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + .. 
testcleanup:: + + >>> import os + >>> _ = [os.remove(f"speech_{i}.wav") for i in range(1, 4)] + >>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)] + >>> os.remove("train_data.csv") + >>> os.remove("predict_data.csv") + """ ds_kw = dict( data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, - input_key=input_fields, - target_key=target_fields, + input_key=input_field, sampling_rate=sampling_rate, ) return cls( - input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_file, transform=test_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_file, transform=train_transform, target_key=target_field, **ds_kw), + input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, target_key=target_field, **ds_kw), + input_cls(RunningStage.TESTING, test_file, transform=test_transform, target_key=target_field, **ds_kw), input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw), **data_module_kwargs, ) @@ -111,8 +275,8 @@ def from_csv( @classmethod def from_json( cls, - input_fields: Union[str, Sequence[str]], - target_fields: Optional[str] = None, + input_field: str, + target_field: Optional[str] = None, train_file: Optional[str] = None, val_file: Optional[str] = None, test_file: Optional[str] = None, @@ -127,21 +291,114 @@ def from_json( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SpeechRecognitionData": + """Load the :class:`~flash.audio.speech_recognition.data.SpeechRecognitionData` from JSON files containing + audio file paths and their corresponding targets. + + Input audio file paths will be extracted from the ``input_field`` field in the JSON files. + The supported file extensions are: ``wav``, ``ogg``, ``flac``, ``mat``, and ``mp3``. 
+ The targets will be extracted from the ``target_field`` field in the JSON files. + To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide <customizing_transforms>`. + + Args: + input_field: The field in the JSON files containing the audio file paths. + target_field: The field in the JSON files containing the targets. + train_file: The JSON file to use when training. + val_file: The JSON file to use when validating. + test_file: The JSON file to use when testing. + predict_file: The JSON file to use when predicting. + sampling_rate: Sampling rate to use when loading the audio files. + field: The field that holds the data in the JSON file. + train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. + val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. + test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. + predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when + predicting. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. + + Returns: + The constructed :class:`~flash.audio.speech_recognition.data.SpeechRecognitionData`. + + Examples + ________ + + .. 
testsetup:: + + >>> import numpy as np + >>> from pandas import DataFrame + >>> import soundfile as sf + >>> samplerate = 44100 + >>> data = np.random.uniform(-1, 1, size=(samplerate * 3, 2)) + >>> _ = [sf.write(f"speech_{i}.wav", data, samplerate, subtype='PCM_24') for i in range(1, 4)] + >>> _ = [sf.write(f"predict_speech_{i}.wav", data, samplerate, subtype='PCM_24') for i in range(1, 4)] + >>> DataFrame.from_dict({ + ... "speech_files": ["speech_1.wav", "speech_2.wav", "speech_3.wav"], + ... "targets": ["some speech", "some other speech", "some more speech"], + ... }).to_json("train_data.json", orient="records", lines=True) + >>> DataFrame.from_dict({ + ... "speech_files": ["predict_speech_1.wav", "predict_speech_2.wav", "predict_speech_3.wav"], + ... }).to_json("predict_data.json", orient="records", lines=True) + + The file ``train_data.json`` contains the following: + + .. code-block:: + + {"speech_files":"speech_1.wav","targets":"some speech"} + {"speech_files":"speech_2.wav","targets":"some other speech"} + {"speech_files":"speech_3.wav","targets":"some more speech"} + + The file ``predict_data.json`` contains the following: + + .. code-block:: + + {"speech_files":"predict_speech_1.wav"} + {"speech_files":"predict_speech_2.wav"} + {"speech_files":"predict_speech_3.wav"} + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.audio import SpeechRecognitionData, SpeechRecognition + >>> datamodule = SpeechRecognitionData.from_json( + ... "speech_files", + ... "targets", + ... train_file="train_data.json", + ... predict_file="predict_data.json", + ... batch_size=2, + ... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Downloading... + >>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust") + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... 
+ >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + .. testcleanup:: + + >>> import os + >>> _ = [os.remove(f"speech_{i}.wav") for i in range(1, 4)] + >>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)] + >>> os.remove("train_data.json") + >>> os.remove("predict_data.json") + """ ds_kw = dict( data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, - input_key=input_fields, - target_key=target_fields, + input_key=input_field, sampling_rate=sampling_rate, field=field, ) return cls( - input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_file, transform=test_transform, **ds_kw), + input_cls(RunningStage.TRAINING, train_file, transform=train_transform, target_key=target_field, **ds_kw), + input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, target_key=target_field, **ds_kw), + input_cls(RunningStage.TESTING, test_file, transform=test_transform, target_key=target_field, **ds_kw), input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw), **data_module_kwargs, ) @@ -162,6 +419,129 @@ def from_datasets( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any, ) -> "SpeechRecognitionData": + """Load the :class:`~flash.audio.speech_recognition.data.SpeechRecognitionData` from PyTorch Dataset + objects. + + The Dataset objects should be one of the following: + + * A PyTorch Dataset where the ``__getitem__`` returns a tuple: ``(file_path, target)`` + * A PyTorch Dataset where the ``__getitem__`` returns a dict: ``{"input": file_path, "target": target}`` + + The supported file extensions are: ``wav``, ``ogg``, ``flac``, ``mat``, and ``mp3``. 
+ To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide <customizing_transforms>`. + + Args: + train_dataset: The Dataset to use when training. + val_dataset: The Dataset to use when validating. + test_dataset: The Dataset to use when testing. + predict_dataset: The Dataset to use when predicting. + sampling_rate: Sampling rate to use when loading the audio files. + train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. + val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. + test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. + predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when + predicting. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. + + Returns: + The constructed :class:`~flash.audio.speech_recognition.data.SpeechRecognitionData`. + + Examples + ________ + + .. testsetup:: + + >>> import numpy as np + >>> import soundfile as sf + >>> samplerate = 44100 + >>> data = np.random.uniform(-1, 1, size=(samplerate * 3, 2)) + >>> _ = [sf.write(f"speech_{i}.wav", data, samplerate, subtype='PCM_24') for i in range(1, 4)] + >>> _ = [sf.write(f"predict_speech_{i}.wav", data, samplerate, subtype='PCM_24') for i in range(1, 4)] + + A PyTorch Dataset where the ``__getitem__`` returns a tuple: ``(file_path, target)``: + + .. doctest:: + + >>> from torch.utils.data import Dataset + >>> from flash import Trainer + >>> from flash.audio import SpeechRecognitionData, SpeechRecognition + >>> + >>> class CustomDataset(Dataset): + ... 
def __init__(self, files, targets=None): + ... self.files = files + ... self.targets = targets + ... def __getitem__(self, index): + ... if self.targets is not None: + ... return self.files[index], self.targets[index] + ... return self.files[index] + ... def __len__(self): + ... return len(self.files) + ... + >>> + >>> datamodule = SpeechRecognitionData.from_datasets( + ... train_dataset=CustomDataset( + ... ["speech_1.wav", "speech_2.wav", "speech_3.wav"], + ... ["some speech", "some other speech", "some more speech"], + ... ), + ... predict_dataset=CustomDataset( + ... ["predict_speech_1.wav", "predict_speech_2.wav", "predict_speech_3.wav"], + ... ), + ... batch_size=2, + ... ) + >>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust") + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + A PyTorch Dataset where the ``__getitem__`` returns a dict: ``{"input": file_path, "target": target}``: + + .. doctest:: + + >>> from torch.utils.data import Dataset + >>> from flash import Trainer + >>> from flash.audio import SpeechRecognitionData, SpeechRecognition + >>> + >>> class CustomDataset(Dataset): + ... def __init__(self, files, targets=None): + ... self.files = files + ... self.targets = targets + ... def __getitem__(self, index): + ... if self.targets is not None: + ... return {"input": self.files[index], "target": self.targets[index]} + ... return {"input": self.files[index]} + ... def __len__(self): + ... return len(self.files) + ... + >>> + >>> datamodule = SpeechRecognitionData.from_datasets( + ... train_dataset=CustomDataset( + ... ["speech_1.wav", "speech_2.wav", "speech_3.wav"], + ... ["some speech", "some other speech", "some more speech"], + ... ), + ... predict_dataset=CustomDataset( + ... 
["predict_speech_1.wav", "predict_speech_2.wav", "predict_speech_3.wav"], + ... ), + ... batch_size=2, + ... ) + >>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust") + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + .. testcleanup:: + + >>> import os + >>> _ = [os.remove(f"speech_{i}.wav") for i in range(1, 4)] + >>> _ = [os.remove(f"predict_speech_{i}.wav") for i in range(1, 4)] + """ ds_kw = dict( data_pipeline_state=DataPipelineState(), diff --git a/flash/audio/speech_recognition/input.py b/flash/audio/speech_recognition/input.py index bc2d48fc13..31c1006532 100644 --- a/flash/audio/speech_recognition/input.py +++ b/flash/audio/speech_recognition/input.py @@ -22,15 +22,13 @@ import flash from flash.core.data.io.input import DataKeys, Input from flash.core.data.process import Deserializer -from flash.core.data.utilities.paths import list_valid_files +from flash.core.data.utilities.paths import filter_valid_files, list_valid_files +from flash.core.data.utilities.samples import to_sample, to_samples from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires if _AUDIO_AVAILABLE: import librosa - from datasets import Dataset as HFDataset from datasets import load_dataset -else: - HFDataset = object class SpeechRecognitionDeserializer(Deserializer): @@ -73,7 +71,7 @@ def load_data( self, file: str, input_key: str, - target_key: str, + target_key: Optional[str] = None, field: Optional[str] = None, sampling_rate: int = 16000, filetype: Optional[str] = None, @@ -88,13 +86,21 @@ def load_data( dataset = dataset_dict[stage] meta = {"root": os.path.dirname(file)} + if target_key is not None: + return [ + { + DataKeys.INPUT: input_file, + DataKeys.TARGET: target, + DataKeys.METADATA: meta, + } + for input_file, target in 
zip(dataset[input_key], dataset[target_key]) + ] return [ { DataKeys.INPUT: input_file, - DataKeys.TARGET: target, DataKeys.METADATA: meta, } - for input_file, target in zip(dataset[input_key], dataset[target_key]) + for input_file in dataset[input_key] ] def load_sample(self, sample: Dict[str, Any]) -> Any: @@ -107,7 +113,7 @@ def load_data( self, file: str, input_key: str, - target_key: str, + target_key: Optional[str] = None, sampling_rate: int = 16000, ): return super().load_data(file, input_key, target_key, sampling_rate=sampling_rate, filetype="csv") @@ -119,7 +125,7 @@ def load_data( self, file: str, input_key: str, - target_key: str, + target_key: Optional[str] = None, field: Optional[str] = None, sampling_rate: int = 16000, ): @@ -130,11 +136,10 @@ class SpeechRecognitionDatasetInput(BaseSpeechRecognition): @requires("audio") def load_data(self, dataset: Dataset, sampling_rate: int = 16000) -> Sequence[Mapping[str, Any]]: self.sampling_rate = sampling_rate - if isinstance(dataset, HFDataset): - dataset = list(zip(dataset["file"], dataset["text"])) return super().load_data(dataset) def load_sample(self, sample: Any) -> Any: + sample = to_sample(sample) if isinstance(sample[DataKeys.INPUT], (str, Path)): sample = super().load_sample(sample, self.sampling_rate) return sample @@ -142,9 +147,16 @@ def load_sample(self, sample: Any) -> Any: class SpeechRecognitionPathsInput(BaseSpeechRecognition): @requires("audio") - def load_data(self, paths: Union[str, List[str]], sampling_rate: int = 16000) -> Sequence: + def load_data( + self, + paths: Union[str, List[str]], + targets: Optional[List[str]] = None, + sampling_rate: int = 16000, + ) -> Sequence: self.sampling_rate = sampling_rate - return [{DataKeys.INPUT: file} for file in list_valid_files(paths, ("wav", "ogg", "flac", "mat", "mp3"))] + if targets is None: + return to_samples(list_valid_files(paths, ("wav", "ogg", "flac", "mat", "mp3"))) + return to_samples(*filter_valid_files(paths, targets, 
valid_extensions=("wav", "ogg", "flac", "mat", "mp3"))) def load_sample(self, sample: Dict[str, Any]) -> Any: return super().load_sample(sample, self.sampling_rate) diff --git a/flash/core/data/utilities/samples.py b/flash/core/data/utilities/samples.py index 4d3cfbe79e..57a27ce072 100644 --- a/flash/core/data/utilities/samples.py +++ b/flash/core/data/utilities/samples.py @@ -14,10 +14,32 @@ from typing import Any, Dict, List, Optional, TypeVar from flash.core.data.io.input import DataKeys +from flash.core.data.utilities.classification import _is_list_like T = TypeVar("T") +def to_sample(input: Any) -> Dict[str, Any]: + """Cast a single input to a sample dictionary. Uses the following rules: + + * If the input is a dictionary with an "input" key, it will be returned + * If the input is list-like and of length 2 then the first element will be treated as the input and the second + element will be treated as the target + * Else the whole input will be mapped by the input key in the returned sample + + Args: + input: The input to cast to a sample. + + Returns: + A sample dictionary. + """ + if isinstance(input, dict) and DataKeys.INPUT in input: + return input + if _is_list_like(input) and len(input) == 2: + return {DataKeys.INPUT: input[0], DataKeys.TARGET: input[1]} + return {DataKeys.INPUT: input} + + def to_samples(inputs: List[Any], targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]: """Package a list of inputs and, optionally, a list of targets in a list of dictionaries (samples). @@ -29,5 +51,5 @@ def to_samples(inputs: List[Any], targets: Optional[List[Any]] = None) -> List[D A list of sample dictionaries. 
""" if targets is None: - return [{DataKeys.INPUT: input} for input in inputs] - return [{DataKeys.INPUT: input, DataKeys.TARGET: target} for input, target in zip(inputs, targets)] + return [to_sample(input) for input in inputs] + return [to_sample(input) for input in zip(inputs, targets)] diff --git a/flash_examples/speech_recognition.py b/flash_examples/speech_recognition.py index 8da2e40aeb..b3fc8eba10 100644 --- a/flash_examples/speech_recognition.py +++ b/flash_examples/speech_recognition.py @@ -21,8 +21,8 @@ download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data") datamodule = SpeechRecognitionData.from_json( - input_fields="file", - target_fields="text", + "file", + "text", train_file="data/timit/train.json", test_file="data/timit/test.json", batch_size=4, From 0b8d27d46261526ba13064490b6250010c188977 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 5 Jan 2022 20:52:05 +0000 Subject: [PATCH 2/3] Remove dependency on private `_get_default_scheduler_config` (#1099) --- flash/core/model.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index b0d1445db4..7c15e4916a 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -27,7 +27,6 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import Callback from pytorch_lightning.callbacks.finetuning import BaseFinetuning -from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -939,12 +938,22 @@ def _get_lr_scheduler_class_from_registry(self, lr_scheduler_key: str) -> Dict[s return deepcopy(lr_scheduler_fn) def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]: + default_scheduler_config = { + "scheduler": None, + "name": 
None, + "interval": "epoch", + "frequency": 1, + "reduce_on_plateau": False, + "monitor": None, + "strict": True, + "opt_idx": None, + } if isinstance(self.lr_scheduler, str): lr_scheduler_data: Dict[str, Any] = self._get_lr_scheduler_class_from_registry(self.lr_scheduler) lr_scheduler_fn = lr_scheduler_data.pop("fn") lr_scheduler_metadata: Dict[str, Any] = lr_scheduler_data.pop("metadata", None) lr_scheduler_kwargs: Dict[str, Any] = {} - lr_scheduler_config = _get_default_scheduler_config() + lr_scheduler_config = default_scheduler_config for key, value in lr_scheduler_config.items(): lr_scheduler_config[key] = lr_scheduler_metadata.pop(key, None) or value @@ -953,7 +962,7 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]: lr_scheduler_fn = self.lr_scheduler lr_scheduler_metadata: Dict[str, Any] = None lr_scheduler_kwargs: Dict[str, Any] = {} - lr_scheduler_config = _get_default_scheduler_config() + lr_scheduler_config = default_scheduler_config elif isinstance(self.lr_scheduler, Tuple): if len(self.lr_scheduler) not in [2, 3]: @@ -964,7 +973,7 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]: f"2) Of length 3 with the first index containing a str from {self.available_lr_schedulers()} and" f" the second index containing the required keyword arguments to initialize the LR Scheduler and" f" the third index containing a Lightning scheduler configuration dictionary of the format" - f" {_get_default_scheduler_config()}. NOTE: Do not set the `scheduler` key in the" + f" {default_scheduler_config}. NOTE: Do not set the `scheduler` key in the" f" lr_scheduler_config, it will overridden with an instance of the provided scheduler key." 
) @@ -990,7 +999,7 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]: lr_scheduler_fn = lr_scheduler_data.pop("fn") lr_scheduler_metadata: Dict[str, Any] = lr_scheduler_data.pop("metadata", None) lr_scheduler_kwargs: Dict[str, Any] = self.lr_scheduler[1] - lr_scheduler_config = _get_default_scheduler_config() + lr_scheduler_config = default_scheduler_config for key, value in lr_scheduler_config.items(): lr_scheduler_config[key] = lr_scheduler_metadata.pop(key, None) or value if len(self.lr_scheduler) == 3: @@ -1023,11 +1032,11 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]: if not isinstance(lr_scheduler, (_LRScheduler, Dict)): raise MisconfigurationException( f"Please make sure that your custom configuration outputs either an LR Scheduler or a scheduler" - f" configuration with keys belonging to {list(_get_default_scheduler_config().keys())}." + f" configuration with keys belonging to {list(default_scheduler_config.keys())}." ) if isinstance(lr_scheduler, Dict): - dummy_config = _get_default_scheduler_config() + dummy_config = default_scheduler_config if not all(config_key in dummy_config.keys() for config_key in lr_scheduler.keys()): raise MisconfigurationException( f"Please make sure that your custom configuration outputs either an LR Scheduler or a scheduler" From 7f45fdf61519e295013877c5154e5ca4077547c6 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 5 Jan 2022 21:08:02 +0000 Subject: [PATCH 3/3] Docstrings for `StyleTransferData` (#1100) --- flash/image/style_transfer/data.py | 212 ++++++++++++++++++++++++++++- 1 file changed, 211 insertions(+), 1 deletion(-) diff --git a/flash/image/style_transfer/data.py b/flash/image/style_transfer/data.py index a42e37e096..acb3761d4a 100644 --- a/flash/image/style_transfer/data.py +++ b/flash/image/style_transfer/data.py @@ -19,16 +19,22 @@ from flash.core.data.data_module import DataModule from flash.core.data.data_pipeline import DataPipelineState from 
flash.core.data.io.input import Input +from flash.core.utilities.imports import _IMAGE_AVAILABLE from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE from flash.image.classification.input import ImageClassificationFilesInput, ImageClassificationFolderInput from flash.image.data import ImageNumpyInput, ImageTensorInput from flash.image.style_transfer.input_transform import StyleTransferInputTransform -__all__ = ["StyleTransferInputTransform", "StyleTransferData"] +# Skip doctests if requirements aren't available +if not _IMAGE_AVAILABLE: + __doctest_skip__ = ["StyleTransferData", "StyleTransferData.*"] class StyleTransferData(DataModule): + """The ``StyleTransferData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of + classmethods for loading data for image style transfer.""" + input_transform_cls = StyleTransferInputTransform @classmethod @@ -42,6 +48,60 @@ def from_files( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any ) -> "StyleTransferData": + """Load the :class:`~flash.image.style_transfer.data.StyleTransferData` from lists of image files. + + The supported file extensions are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``, + ``.tiff``, ``.webp``, and ``.npy``. + To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide <customizing_transforms>`. + + Args: + train_files: The list of image files to use when training. + predict_files: The list of image files to use when predicting. + train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. + predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when + predicting. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. 
+ data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. + + Returns: + The constructed :class:`~flash.image.style_transfer.data.StyleTransferData`. + + Examples + ________ + + .. testsetup:: + + >>> from PIL import Image + >>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) + >>> _ = [rand_image.save(f"image_{i}.png") for i in range(1, 4)] + >>> _ = [rand_image.save(f"predict_image_{i}.png") for i in range(1, 4)] + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.image import StyleTransfer, StyleTransferData + >>> datamodule = StyleTransferData.from_files( + ... train_files=["image_1.png", "image_2.png", "image_3.png"], + ... predict_files=["predict_image_1.png", "predict_image_2.png", "predict_image_3.png"], + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... ) + >>> model = StyleTransfer() + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + + .. testcleanup:: + + >>> import os + >>> _ = [os.remove(f"image_{i}.png") for i in range(1, 4)] + >>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)] + """ ds_kw = dict( data_pipeline_state=DataPipelineState(), @@ -66,6 +126,73 @@ def from_folders( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any ) -> "StyleTransferData": + """Load the :class:`~flash.image.style_transfer.data.StyleTransferData` from folders containing images. + + The supported file extensions are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``, + ``.tiff``, ``.webp``, and ``.npy``. + Here's the required folder structure: + + .. code-block:: + + train_folder + ├── image_1.png + ├── image_2.png + ├── image_3.png + ... 
+ + To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide <customizing_transforms>`. + + Args: + train_folder: The folder containing images to use when training. + predict_folder: The folder containing images to use when predicting. + train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. + predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when + predicting. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. + + Returns: + The constructed :class:`~flash.image.style_transfer.data.StyleTransferData`. + + Examples + ________ + + .. testsetup:: + + >>> import os + >>> from PIL import Image + >>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) + >>> os.makedirs("train_folder", exist_ok=True) + >>> os.makedirs("predict_folder", exist_ok=True) + >>> _ = [rand_image.save(os.path.join("train_folder", f"image_{i}.png")) for i in range(1, 4)] + >>> _ = [rand_image.save(os.path.join("predict_folder", f"predict_image_{i}.png")) for i in range(1, 4)] + + .. doctest:: + + >>> from flash import Trainer + >>> from flash.image import StyleTransfer, StyleTransferData + >>> datamodule = StyleTransferData.from_folders( + ... train_folder="train_folder", + ... predict_folder="predict_folder", + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... ) + >>> model = StyleTransfer() + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... + >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... 
+ + .. testcleanup:: + + >>> import shutil + >>> shutil.rmtree("train_folder") + >>> shutil.rmtree("predict_folder") + """ ds_kw = dict( data_pipeline_state=DataPipelineState(), @@ -90,6 +217,47 @@ def from_numpy( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any ) -> "StyleTransferData": + """Load the :class:`~flash.image.style_transfer.data.StyleTransferData` from numpy arrays (or lists of + arrays). + + To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide <customizing_transforms>`. + + Args: + train_data: The numpy array or list of arrays to use when training. + predict_data: The numpy array or list of arrays to use when predicting. + train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. + predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when + predicting. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. + + Returns: + The constructed :class:`~flash.image.style_transfer.data.StyleTransferData`. + + Examples + ________ + + .. doctest:: + + >>> import numpy as np + >>> from flash import Trainer + >>> from flash.image import StyleTransfer, StyleTransferData + >>> datamodule = StyleTransferData.from_numpy( + ... train_data=[np.random.rand(3, 64, 64), np.random.rand(3, 64, 64), np.random.rand(3, 64, 64)], + ... predict_data=[np.random.rand(3, 64, 64)], + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... ) + >>> model = StyleTransfer() + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... 
+ >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + """ ds_kw = dict( data_pipeline_state=DataPipelineState(), @@ -114,6 +282,48 @@ def from_tensors( transform_kwargs: Optional[Dict] = None, **data_module_kwargs: Any ) -> "StyleTransferData": + """Load the :class:`~flash.image.style_transfer.data.StyleTransferData` from torch tensors (or lists of + tensors). + + To learn how to customize the transforms applied for each stage, read our + :ref:`customizing transforms guide <customizing_transforms>`. + + Args: + train_data: The torch tensor or list of tensors to use when training. + predict_data: The torch tensor or list of tensors to use when predicting. + train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. + predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when + predicting. + input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. + transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. + data_module_kwargs: Additional keyword arguments to provide to the + :class:`~flash.core.data.data_module.DataModule` constructor. + + Returns: + The constructed :class:`~flash.image.style_transfer.data.StyleTransferData`. + + Examples + ________ + + .. doctest:: + + >>> import torch + >>> from flash import Trainer + >>> from flash.image import StyleTransfer, StyleTransferData + >>> datamodule = StyleTransferData.from_tensors( + ... train_data=[torch.rand(3, 64, 64), torch.rand(3, 64, 64), torch.rand(3, 64, 64)], + ... predict_data=[torch.rand(3, 64, 64)], + ... transform_kwargs=dict(image_size=(128, 128)), + ... batch_size=2, + ... ) + >>> model = StyleTransfer() + >>> trainer = Trainer(fast_dev_run=True) + >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Training... 
+ >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Predicting... + """ + ds_kw = dict( data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs,