Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add from_lists to TextClassificationData (#805)
Browse files Browse the repository at this point in the history
  • Loading branch information
kingyiusuen authored Sep 29, 2021
1 parent 0a28672 commit f495d9d
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 7 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for `from_data_frame` to `TextClassificationData` ([#785](https://github.com/PyTorchLightning/lightning-flash/pull/785))

- Added support for `from_lists` to `TextClassificationData` ([#805](https://github.com/PyTorchLightning/lightning-flash/pull/805))

### Changed

- Changed the default `num_workers` on linux to `0` (matching the default for other OS) ([#759](https://github.com/PyTorchLightning/lightning-flash/pull/759))
Expand Down
3 changes: 3 additions & 0 deletions flash/core/data/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ class DefaultDataSources(LightningEnum):
JSON = "json"
DATASETS = "datasets"
FIFTYONE = "fiftyone"
DATAFRAME = "data_frame"
LISTS = "lists"
SENTENCES = "sentences"
LABELSTUDIO = "labelstudio"

# TODO: Create a FlashEnum class???
Expand Down
138 changes: 133 additions & 5 deletions flash/text/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def _multilabel_target(targets, element):

def load_data(
self,
data: Union[Tuple[pd.DataFrame, Union[str, List[str]], Union[str, List[str]]], Tuple[List[str], List[str]]],
data: Tuple[pd.DataFrame, Union[str, List[str]], Union[str, List[str]]],
dataset: Optional[Any] = None,
columns: Union[List[str], Tuple[str]] = ("input_ids", "attention_mask", "labels"),
) -> Union[Sequence[Mapping[str, Any]]]:
Expand Down Expand Up @@ -279,6 +279,55 @@ def __setstate__(self, state):
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class TextListDataSource(TextDataSource):
def load_data(
self,
data: Tuple[List[str], Union[List[Any], List[List[Any]]]],
dataset: Optional[Any] = None,
columns: Union[List[str], Tuple[str]] = ("input_ids", "attention_mask", "labels"),
) -> Union[Sequence[Mapping[str, Any]]]:
input, target = data
hf_dataset = Dataset.from_dict({"input": input, "labels": target})

if not self.predicting:
if isinstance(target[0], List):
# multi-target
dataset.multi_label = True
dataset.num_classes = len(target[0])
self.set_state(LabelsState(target))
else:
dataset.multi_label = False
if self.training:
labels = list(sorted(list(set(hf_dataset["labels"]))))
dataset.num_classes = len(labels)
self.set_state(LabelsState(labels))

labels = self.get_state(LabelsState)

# convert labels to ids
if labels is not None:
labels = labels.labels
label_to_class_mapping = {v: k for k, v in enumerate(labels)}
hf_dataset = hf_dataset.map(partial(self._transform_label, label_to_class_mapping, "labels"))

hf_dataset = hf_dataset.map(partial(self._tokenize_fn, input="input"), batched=True)
hf_dataset.set_format("torch", columns=columns)

return hf_dataset

def predict_load_data(self, data: Any, dataset: AutoDataset):
return self.load_data(data, dataset, columns=["input_ids", "attention_mask"])

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class TextSentencesDataSource(TextDataSource):
def __init__(self, backbone: str, max_length: int = 128):
super().__init__(backbone, max_length=max_length)
Expand Down Expand Up @@ -330,13 +379,14 @@ def __init__(
data_sources={
DefaultDataSources.CSV: TextCSVDataSource(self.backbone, max_length=max_length),
DefaultDataSources.JSON: TextJSONDataSource(self.backbone, max_length=max_length),
"data_frame": TextDataFrameDataSource(self.backbone, max_length=max_length),
"sentences": TextSentencesDataSource(self.backbone, max_length=max_length),
DefaultDataSources.DATAFRAME: TextDataFrameDataSource(self.backbone, max_length=max_length),
DefaultDataSources.LISTS: TextListDataSource(self.backbone, max_length=max_length),
DefaultDataSources.SENTENCES: TextSentencesDataSource(self.backbone, max_length=max_length),
DefaultDataSources.LABELSTUDIO: LabelStudioTextClassificationDataSource(
backbone=self.backbone, max_length=max_length
),
},
default_data_source="sentences",
default_data_source=DefaultDataSources.SENTENCES,
deserializer=TextDeserializer(backbone, max_length),
)

Expand Down Expand Up @@ -437,7 +487,7 @@ def from_data_frame(
The constructed data module.
"""
return cls.from_data_source(
"data_frame",
DefaultDataSources.DATAFRAME,
(train_data_frame, input_field, target_fields),
(val_data_frame, input_field, target_fields),
(test_data_frame, input_field, target_fields),
Expand All @@ -454,3 +504,81 @@ def from_data_frame(
sampler=sampler,
**preprocess_kwargs,
)

@classmethod
def from_lists(
cls,
train_data: Optional[List[str]] = None,
train_targets: Optional[Union[List[Any], List[List[Any]]]] = None,
val_data: Optional[List[str]] = None,
val_targets: Optional[Union[List[Any], List[List[Any]]]] = None,
test_data: Optional[List[str]] = None,
test_targets: Optional[Union[List[Any], List[List[Any]]]] = None,
predict_data: Optional[List[str]] = None,
train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
data_fetcher: Optional[BaseDataFetcher] = None,
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: int = 0,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
"""Creates a :class:`~flash.text.classification.data.TextClassificationData` object from the given Python
lists.
Args:
train_data: A list of sentences to use as the train inputs.
train_targets: A list of targets to use as the train targets. For multi-label classification, the targets
should be provided as a list of lists, where each inner list contains the targets for a sample.
val_data: A list of sentences to use as the validation inputs.
val_targets: A list of targets to use as the validation targets. For multi-label classification, the targets
should be provided as a list of lists, where each inner list contains the targets for a sample.
test_data: A list of sentences to use as the test inputs.
test_targets: A list of targets to use as the test targets. For multi-label classification, the targets
should be provided as a list of lists, where each inner list contains the targets for a sample.
predict_data: A list of sentences to use when predicting.
train_transform: The dictionary of transforms to use during training which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
val_transform: The dictionary of transforms to use during validation which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
test_transform: The dictionary of transforms to use during testing which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
predict_transform: The dictionary of transforms to use during predicting which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
:class:`~flash.core.data.data_module.DataModule`.
preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the
:class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
will be constructed and used.
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Returns:
The constructed data module.
"""
return cls.from_data_source(
DefaultDataSources.LISTS,
(train_data, train_targets),
(val_data, val_targets),
(test_data, test_targets),
predict_data,
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
predict_transform=predict_transform,
data_fetcher=data_fetcher,
preprocess=preprocess,
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
sampler=sampler,
**preprocess_kwargs,
)
48 changes: 46 additions & 2 deletions tests/text/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
TextDataSource,
TextFileDataSource,
TextJSONDataSource,
TextListDataSource,
TextSentencesDataSource,
)
from tests.helpers.utils import _TEXT_TESTING
Expand Down Expand Up @@ -54,10 +55,19 @@


TEST_DATA_FRAME_DATA = pd.DataFrame(
{"sentence": ["this is a sentence one", "this is a sentence two", "this is a sentence three"], "lab": [0, 1, 0]},
{
"sentence": ["this is a sentence one", "this is a sentence two", "this is a sentence three"],
"lab1": [0, 1, 0],
"lab2": [1, 0, 1],
},
)


TEST_LIST_DATA = ["this is a sentence one", "this is a sentence two", "this is a sentence three"]
TEST_LIST_TARGETS = [0, 1, 0]
TEST_LIST_TARGETS_MULTILABEL = [[0, 1], [1, 0], [0, 1]]


def csv_data(tmpdir):
path = Path(tmpdir) / "data.csv"
path.write_text(TEST_CSV_DATA)
Expand Down Expand Up @@ -134,13 +144,46 @@ def test_from_json_with_field(tmpdir):
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_data_frame():
dm = TextClassificationData.from_data_frame(
"sentence", "lab", backbone=TEST_BACKBONE, train_data_frame=TEST_DATA_FRAME_DATA, batch_size=1
"sentence", "lab1", backbone=TEST_BACKBONE, train_data_frame=TEST_DATA_FRAME_DATA, batch_size=1
)
batch = next(iter(dm.train_dataloader()))
assert batch["labels"].item() in [0, 1]
assert "input_ids" in batch


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_data_frame_multilabel():
dm = TextClassificationData.from_data_frame(
"sentence", ["lab1", "lab2"], backbone=TEST_BACKBONE, train_data_frame=TEST_DATA_FRAME_DATA, batch_size=1
)
batch = next(iter(dm.train_dataloader()))
assert all([label in [0, 1] for label in batch["labels"][0]])
assert "input_ids" in batch


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_lists():
dm = TextClassificationData.from_lists(
backbone=TEST_BACKBONE, train_data=TEST_LIST_DATA, train_targets=TEST_LIST_TARGETS, batch_size=1
)
batch = next(iter(dm.train_dataloader()))
assert batch["labels"].item() in [0, 1]
assert "input_ids" in batch


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_lists_multilabel():
dm = TextClassificationData.from_lists(
backbone=TEST_BACKBONE, train_data=TEST_LIST_DATA, train_targets=TEST_LIST_TARGETS_MULTILABEL, batch_size=1
)
batch = next(iter(dm.train_dataloader()))
assert all([label in [0, 1] for label in batch["labels"][0]])
assert "input_ids" in batch


@pytest.mark.skipif(_TEXT_AVAILABLE, reason="text libraries are installed.")
def test_text_module_not_found_error():
with pytest.raises(ModuleNotFoundError, match="[text]"):
Expand All @@ -157,6 +200,7 @@ def test_text_module_not_found_error():
(TextCSVDataSource, {}),
(TextJSONDataSource, {}),
(TextDataFrameDataSource, {}),
(TextListDataSource, {}),
(TextSentencesDataSource, {}),
],
)
Expand Down

0 comments on commit f495d9d

Please sign in to comment.