This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit cc001f7: Merge branch 'master' into docs/semantic_segmentation_data

ethanwharris authored Jan 6, 2022
2 parents 6aa37a6 + 7f45fdf
Showing 8 changed files with 685 additions and 46 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -32,6 +32,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug with `AudioClassificationData.from_numpy` ([#1096](https://github.com/PyTorchLightning/lightning-flash/pull/1096))

+ - Fixed a bug when using `SpeechRecognitionData.from_files` for training / validating / testing ([#1097](https://github.com/PyTorchLightning/lightning-flash/pull/1097))
+
+ - Fixed a bug when using `SpeechRecognitionData.from_csv` or `from_json` when predicting without targets ([#1097](https://github.com/PyTorchLightning/lightning-flash/pull/1097))
+
+ - Fixed a bug where `SpeechRecognitionData.from_datasets` did not work as expected ([#1097](https://github.com/PyTorchLightning/lightning-flash/pull/1097))

- Fixed a bug where loading data for prediction with `SemanticSegmentationData.from_folders` raised an error ([#1101](https://github.com/PyTorchLightning/lightning-flash/pull/1101))

### Removed
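For context, the prediction-without-targets case named in the new #1097 entries can be exercised roughly as follows (a minimal sketch, assuming the `audio` extras are installed; the CSV path and the "file" column name are made up):

```python
from flash.audio import SpeechRecognitionData

# Predicting from a CSV that has an input column but no transcript column is
# one of the cases covered by #1097. Only the input field is passed here.
datamodule = SpeechRecognitionData.from_csv(
    "file",
    predict_file="data/asr/predict.csv",  # hypothetical path
    sampling_rate=16000,
)
```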
4 changes: 2 additions & 2 deletions flash/audio/speech_recognition/cli.py
@@ -28,8 +28,8 @@ def from_timit(
"""Downloads and loads the timit data set."""
download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data")
return SpeechRecognitionData.from_json(
- input_fields="file",
- target_fields="text",
+ "file",
+ "text",
train_file="data/timit/train.json",
test_file="data/timit/test.json",
val_split=val_split,
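The CLI helper above now passes the field names positionally. A standalone call looks roughly like this (a sketch reusing the TIMIT paths from the snippet; the `val_split` value is illustrative):

```python
from flash.audio import SpeechRecognitionData

# Mirrors the updated from_timit() helper: "file" and "text" are the input and
# target fields of the downloaded TIMIT JSON files.
datamodule = SpeechRecognitionData.from_json(
    "file",
    "text",
    train_file="data/timit/train.json",
    test_file="data/timit/test.json",
    val_split=0.1,  # illustrative value
)
```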
418 changes: 399 additions & 19 deletions flash/audio/speech_recognition/data.py

Large diffs are not rendered by default.
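Although the `data.py` diff is not rendered, it carries the `from_files` / `from_datasets` handling referenced in the CHANGELOG. A rough usage sketch (all file names and transcripts below are placeholders, not taken from this commit):

```python
from flash.audio import SpeechRecognitionData

# Training from audio files with transcript targets, plus prediction from
# files alone - the from_files paths that #1097 touches.
datamodule = SpeechRecognitionData.from_files(
    train_files=["clip_0001.wav", "clip_0002.wav"],
    train_targets=["hello world", "lightning flash"],
    predict_files=["clip_0003.wav"],
    sampling_rate=16000,
)
```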

38 changes: 25 additions & 13 deletions flash/audio/speech_recognition/input.py
@@ -22,15 +22,13 @@
import flash
from flash.core.data.io.input import DataKeys, Input
from flash.core.data.process import Deserializer
- from flash.core.data.utilities.paths import list_valid_files
+ from flash.core.data.utilities.paths import filter_valid_files, list_valid_files
+ from flash.core.data.utilities.samples import to_sample, to_samples
from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires

if _AUDIO_AVAILABLE:
import librosa
- from datasets import Dataset as HFDataset
from datasets import load_dataset
- else:
-     HFDataset = object


class SpeechRecognitionDeserializer(Deserializer):
@@ -73,7 +71,7 @@ def load_data(
self,
file: str,
input_key: str,
- target_key: str,
+ target_key: Optional[str] = None,
field: Optional[str] = None,
sampling_rate: int = 16000,
filetype: Optional[str] = None,
@@ -88,13 +86,21 @@ def load_data(

dataset = dataset_dict[stage]
meta = {"root": os.path.dirname(file)}
+ if target_key is not None:
+     return [
+         {
+             DataKeys.INPUT: input_file,
+             DataKeys.TARGET: target,
+             DataKeys.METADATA: meta,
+         }
+         for input_file, target in zip(dataset[input_key], dataset[target_key])
+     ]
return [
    {
        DataKeys.INPUT: input_file,
-       DataKeys.TARGET: target,
        DataKeys.METADATA: meta,
    }
-   for input_file, target in zip(dataset[input_key], dataset[target_key])
+   for input_file in dataset[input_key]
]

def load_sample(self, sample: Dict[str, Any]) -> Any:
@@ -107,7 +113,7 @@ def load_data(
self,
file: str,
input_key: str,
- target_key: str,
+ target_key: Optional[str] = None,
sampling_rate: int = 16000,
):
return super().load_data(file, input_key, target_key, sampling_rate=sampling_rate, filetype="csv")
@@ -119,7 +125,7 @@ def load_data(
self,
file: str,
input_key: str,
- target_key: str,
+ target_key: Optional[str] = None,
field: Optional[str] = None,
sampling_rate: int = 16000,
):
@@ -130,21 +136,27 @@ class SpeechRecognitionDatasetInput(BaseSpeechRecognition):
@requires("audio")
def load_data(self, dataset: Dataset, sampling_rate: int = 16000) -> Sequence[Mapping[str, Any]]:
self.sampling_rate = sampling_rate
- if isinstance(dataset, HFDataset):
-     dataset = list(zip(dataset["file"], dataset["text"]))
return super().load_data(dataset)

def load_sample(self, sample: Any) -> Any:
+ sample = to_sample(sample)
if isinstance(sample[DataKeys.INPUT], (str, Path)):
sample = super().load_sample(sample, self.sampling_rate)
return sample


class SpeechRecognitionPathsInput(BaseSpeechRecognition):
@requires("audio")
-   def load_data(self, paths: Union[str, List[str]], sampling_rate: int = 16000) -> Sequence:
+   def load_data(
+       self,
+       paths: Union[str, List[str]],
+       targets: Optional[List[str]] = None,
+       sampling_rate: int = 16000,
+   ) -> Sequence:
        self.sampling_rate = sampling_rate
-       return [{DataKeys.INPUT: file} for file in list_valid_files(paths, ("wav", "ogg", "flac", "mat", "mp3"))]
+       if targets is None:
+           return to_samples(list_valid_files(paths, ("wav", "ogg", "flac", "mat", "mp3")))
+       return to_samples(*filter_valid_files(paths, targets, valid_extensions=("wav", "ogg", "flac", "mat", "mp3")))

def load_sample(self, sample: Dict[str, Any]) -> Any:
return super().load_sample(sample, self.sampling_rate)
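A minimal sketch of what the new `SpeechRecognitionPathsInput.load_data` does with and without targets, built from the same utilities it calls (the file names and transcripts are invented):

```python
from flash.core.data.utilities.paths import filter_valid_files, list_valid_files
from flash.core.data.utilities.samples import to_samples

files = ["a.wav", "b.flac", "notes.txt"]  # "notes.txt" has no valid audio extension
targets = ["hello", "world", "ignored"]

# Prediction: no targets, so only the valid files become input-only samples.
predict_samples = to_samples(list_valid_files(files, ("wav", "ogg", "flac", "mat", "mp3")))

# Training: inputs and targets are filtered together so they stay aligned,
# then zipped into (input, target) samples by to_samples.
train_samples = to_samples(
    *filter_valid_files(files, targets, valid_extensions=("wav", "ogg", "flac", "mat", "mp3"))
)
```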
26 changes: 24 additions & 2 deletions flash/core/data/utilities/samples.py
@@ -14,10 +14,32 @@
from typing import Any, Dict, List, Optional, TypeVar

from flash.core.data.io.input import DataKeys
+ from flash.core.data.utilities.classification import _is_list_like

T = TypeVar("T")


+ def to_sample(input: Any) -> Dict[str, Any]:
+     """Cast a single input to a sample dictionary. Uses the following rules:
+     * If the input is a dictionary with an "input" key, it will be returned
+     * If the input is list-like and of length 2 then the first element will be treated as the input and the second
+       element will be treated as the target
+     * Else the whole input will be mapped by the input key in the returned sample
+     Args:
+         input: The input to cast to a sample.
+     Returns:
+         A sample dictionary.
+     """
+     if isinstance(input, dict) and DataKeys.INPUT in input:
+         return input
+     if _is_list_like(input) and len(input) == 2:
+         return {DataKeys.INPUT: input[0], DataKeys.TARGET: input[1]}
+     return {DataKeys.INPUT: input}


def to_samples(inputs: List[Any], targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
"""Package a list of inputs and, optionally, a list of targets in a list of dictionaries (samples).
@@ -29,5 +51,5 @@ def to_samples(inputs: List[Any], targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
A list of sample dictionaries.
"""
if targets is None:
-     return [{DataKeys.INPUT: input} for input in inputs]
-     return [{DataKeys.INPUT: input, DataKeys.TARGET: target} for input, target in zip(inputs, targets)]
+     return [to_sample(input) for input in inputs]
+     return [to_sample(input) for input in zip(inputs, targets)]
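The rules in the `to_sample` docstring above translate to behaviour roughly like this (a quick sketch; `DataKeys.INPUT` and `DataKeys.TARGET` are the sample keys used throughout Flash):

```python
from flash.core.data.io.input import DataKeys
from flash.core.data.utilities.samples import to_sample, to_samples

to_sample("a.wav")                    # -> {DataKeys.INPUT: "a.wav"}
to_sample(("a.wav", "hello"))         # length-2 list-like -> input + target
to_sample({DataKeys.INPUT: "a.wav"})  # already a sample dict -> returned as-is

to_samples(["a.wav", "b.wav"])                      # input-only samples
to_samples(["a.wav", "b.wav"], ["hello", "world"])  # zipped into (input, target) samples
```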
23 changes: 16 additions & 7 deletions flash/core/model.py
@@ -27,7 +27,6 @@
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.finetuning import BaseFinetuning
- from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.enums import LightningEnum
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -939,12 +938,22 @@ def _get_lr_scheduler_class_from_registry(self, lr_scheduler_key: str) -> Dict[str, Any]:
return deepcopy(lr_scheduler_fn)

def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]:
+ default_scheduler_config = {
+     "scheduler": None,
+     "name": None,
+     "interval": "epoch",
+     "frequency": 1,
+     "reduce_on_plateau": False,
+     "monitor": None,
+     "strict": True,
+     "opt_idx": None,
+ }
if isinstance(self.lr_scheduler, str):
lr_scheduler_data: Dict[str, Any] = self._get_lr_scheduler_class_from_registry(self.lr_scheduler)
lr_scheduler_fn = lr_scheduler_data.pop("fn")
lr_scheduler_metadata: Dict[str, Any] = lr_scheduler_data.pop("metadata", None)
lr_scheduler_kwargs: Dict[str, Any] = {}
- lr_scheduler_config = _get_default_scheduler_config()
+ lr_scheduler_config = default_scheduler_config
for key, value in lr_scheduler_config.items():
lr_scheduler_config[key] = lr_scheduler_metadata.pop(key, None) or value

@@ -953,7 +962,7 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]:
lr_scheduler_fn = self.lr_scheduler
lr_scheduler_metadata: Dict[str, Any] = None
lr_scheduler_kwargs: Dict[str, Any] = {}
- lr_scheduler_config = _get_default_scheduler_config()
+ lr_scheduler_config = default_scheduler_config

elif isinstance(self.lr_scheduler, Tuple):
if len(self.lr_scheduler) not in [2, 3]:
@@ -964,7 +973,7 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]:
f"2) Of length 3 with the first index containing a str from {self.available_lr_schedulers()} and"
f" the second index containing the required keyword arguments to initialize the LR Scheduler and"
f" the third index containing a Lightning scheduler configuration dictionary of the format"
f" {_get_default_scheduler_config()}. NOTE: Do not set the `scheduler` key in the"
f" {default_scheduler_config}. NOTE: Do not set the `scheduler` key in the"
f" lr_scheduler_config, it will overridden with an instance of the provided scheduler key."
)

@@ -990,7 +999,7 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]:
lr_scheduler_fn = lr_scheduler_data.pop("fn")
lr_scheduler_metadata: Dict[str, Any] = lr_scheduler_data.pop("metadata", None)
lr_scheduler_kwargs: Dict[str, Any] = self.lr_scheduler[1]
- lr_scheduler_config = _get_default_scheduler_config()
+ lr_scheduler_config = default_scheduler_config
for key, value in lr_scheduler_config.items():
lr_scheduler_config[key] = lr_scheduler_metadata.pop(key, None) or value
if len(self.lr_scheduler) == 3:
@@ -1023,11 +1032,11 @@ def _instantiate_lr_scheduler(self, optimizer: Optimizer) -> Dict[str, Any]:
if not isinstance(lr_scheduler, (_LRScheduler, Dict)):
raise MisconfigurationException(
f"Please make sure that your custom configuration outputs either an LR Scheduler or a scheduler"
f" configuration with keys belonging to {list(_get_default_scheduler_config().keys())}."
f" configuration with keys belonging to {list(default_scheduler_config.keys())}."
)

if isinstance(lr_scheduler, Dict):
- dummy_config = _get_default_scheduler_config()
+ dummy_config = default_scheduler_config
if not all(config_key in dummy_config.keys() for config_key in lr_scheduler.keys()):
raise MisconfigurationException(
f"Please make sure that your custom configuration outputs either an LR Scheduler or a scheduler"
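The behaviour is unchanged for users: the configuration keys now come from the inlined `default_scheduler_config` rather than the removed PyTorch Lightning helper. A hedged sketch of passing a scheduler-plus-config tuple to a Flash task (the task, scheduler registry key, and values are illustrative, not from this commit):

```python
from flash.image import ImageClassifier

# The third element of the tuple is a Lightning scheduler configuration whose
# keys must match the default_scheduler_config inlined above. "steplr" is an
# assumed registry key for torch.optim.lr_scheduler.StepLR.
model = ImageClassifier(
    num_classes=10,
    optimizer="Adam",
    lr_scheduler=("steplr", {"step_size": 10}, {"interval": "epoch", "frequency": 1}),
)
```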
(The diffs for the remaining two changed files were not loaded.)
