Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Support for from_dict for Tabular Classification
Browse files Browse the repository at this point in the history
  • Loading branch information
krshrimali committed May 6, 2022
1 parent 07d63e3 commit e85349d
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 0 deletions.
108 changes: 108 additions & 0 deletions flash/tabular/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,3 +311,111 @@ def from_csv(
transform_kwargs=transform_kwargs,
**data_module_kwargs,
)

@classmethod
def from_dict(
    cls,
    categorical_fields: Optional[Union[str, List[str]]] = None,
    numerical_fields: Optional[Union[str, List[str]]] = None,
    target_fields: Optional[Union[str, List[str]]] = None,
    parameters: Optional[Dict[str, Any]] = None,
    train_data: Optional[Dict[str, List[Any]]] = None,
    val_data: Optional[Dict[str, List[Any]]] = None,
    test_data: Optional[Dict[str, List[Any]]] = None,
    predict_data: Optional[Dict[str, List[Any]]] = None,
    target_formatter: Optional[TargetFormatter] = None,
    input_cls: Type[Input] = TabularClassificationListInput,
    transform: INPUT_TRANSFORM_TYPE = InputTransform,
    transform_kwargs: Optional[Dict] = None,
    **data_module_kwargs: Any,
) -> "TabularClassificationData":
    """Creates a :class:`~flash.tabular.classification.data.TabularClassificationData` object from the given
    dictionaries.

    Each dictionary maps a field (column) name to the values of that column.

    .. note::

        The ``categorical_fields``, ``numerical_fields``, and ``target_fields`` do not need to be provided if
        ``parameters`` are passed instead. These can be obtained from the
        :attr:`~flash.tabular.data.TabularData.parameters` attribute of the
        :class:`~flash.tabular.data.TabularData` object that contains your training data.

    The targets will be extracted from the ``target_fields`` in the dictionaries and can be in any of our
    :ref:`supported classification target formats <formatting_classification_targets>`.
    To learn how to customize the transforms applied for each stage, read our
    :ref:`customizing transforms guide <customizing_transforms>`.

    Args:
        categorical_fields: The fields (column names) in the dictionaries containing categorical data.
        numerical_fields: The fields (column names) in the dictionaries containing numerical data.
        target_fields: The field (column name) or list of fields in the dictionaries containing the targets.
        parameters: Parameters to use if ``categorical_fields``, ``numerical_fields``, and ``target_fields`` are
            not provided (e.g. when loading data for inference or validation).
        train_data: The data to use when training.
        val_data: The data to use when validating.
        test_data: The data to use when testing.
        predict_data: The data to use when predicting.
        target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter`
            to control how targets are handled. See :ref:`formatting_classification_targets` for more details.
        input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
        transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use.
        transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
        data_module_kwargs: Additional keyword arguments to provide to the
            :class:`~flash.core.data.data_module.DataModule` constructor.

    Returns:
        The constructed :class:`~flash.tabular.classification.data.TabularClassificationData`.

    Examples
    ________

    .. testsetup::

        >>> train_data = {
        ...     "animal": ["cat", "dog", "cat"],
        ...     "friendly": ["yes", "yes", "no"],
        ...     "weight": [6, 10, 5],
        ... }
        >>> predict_data = {
        ...     "friendly": ["yes", "no", "yes"],
        ...     "weight": [7, 12, 5],
        ... }

    We have dictionaries ``train_data`` and ``predict_data``.

    .. doctest::

        >>> from flash import Trainer
        >>> from flash.tabular import TabularClassifier, TabularClassificationData
        >>> datamodule = TabularClassificationData.from_dict(
        ...     "friendly",
        ...     "weight",
        ...     "animal",
        ...     train_data=train_data,
        ...     predict_data=predict_data,
        ...     batch_size=4,
        ... )
        >>> datamodule.num_classes
        2
        >>> datamodule.labels
        ['cat', 'dog']
        >>> model = TabularClassifier.from_data(datamodule, backbone="tabnet")
        >>> trainer = Trainer(fast_dev_run=True)
        >>> trainer.fit(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
        Training...
        >>> trainer.predict(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
        Predicting...

    .. testcleanup::

        >>> del train_data
        >>> del predict_data
    """
    # Keyword arguments shared by the input objects of every stage.
    ds_kw = dict(
        target_formatter=target_formatter,
        categorical_fields=categorical_fields,
        numerical_fields=numerical_fields,
        target_fields=target_fields,
        parameters=parameters,
    )

    # The training input is built first so that the remaining stages can reuse
    # the parameters and target formatter it exposes.
    train_input = input_cls(RunningStage.TRAINING, train_data, **ds_kw)
    ds_kw["parameters"] = train_input.parameters if train_input else parameters
    # NOTE(review): when the training input has no ``target_formatter`` attribute this resets the
    # formatter to ``None`` for the other stages, even if one was passed in - confirm this is intended.
    ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

    return cls(
        train_input,
        input_cls(RunningStage.VALIDATING, val_data, **ds_kw),
        input_cls(RunningStage.TESTING, test_data, **ds_kw),
        input_cls(RunningStage.PREDICTING, predict_data, **ds_kw),
        transform=transform,
        transform_kwargs=transform_kwargs,
        **data_module_kwargs,
    )

29 changes: 29 additions & 0 deletions flash/tabular/classification/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,32 @@ def load_data(
return super().load_data(
read_csv(file), categorical_fields, numerical_fields, target_fields, parameters, target_formatter
)


class TabularClassificationListInput(TabularDataFrameInput, ClassificationInputMixin):
    """Classification input that loads tabular data supplied as a dictionary of columns."""

    def load_data(
        self,
        data: Dict[str, Union[Any, List[Any]]],
        categorical_fields: Optional[Union[str, List[str]]] = None,
        numerical_fields: Optional[Union[str, List[str]]] = None,
        target_fields: Optional[Union[str, List[str]]] = None,
        parameters: Dict[str, Any] = None,
        target_formatter: Optional[TargetFormatter] = None,
    ):
        """Build the list of samples from a dictionary of columns.

        Args:
            data: Dictionary that is converted to a pandas ``DataFrame`` with ``DataFrame.from_dict``.
            categorical_fields: Forwarded to ``self.preprocess``.
            numerical_fields: Forwarded to ``self.preprocess``.
            target_fields: Forwarded to ``resolve_targets`` when not predicting.
            parameters: Forwarded to ``self.preprocess``.
            target_formatter: Forwarded to ``self.load_target_metadata`` when not predicting.

        Returns:
            A list with one dict per row. Each dict holds the ``(categorical, numerical)`` pair under
            ``DataKeys.INPUT`` and, unless predicting, the row's target under ``DataKeys.TARGET``.
        """
        frame = DataFrame.from_dict(data)
        categorical, numerical = self.preprocess(frame, categorical_fields, numerical_fields, parameters)

        # Prediction samples carry no targets, so return them straight away.
        if self.predicting:
            return [{DataKeys.INPUT: (cat_row, num_row)} for cat_row, num_row in zip(categorical, numerical)]

        labels = resolve_targets(frame, target_fields)
        self.load_target_metadata(labels, target_formatter=target_formatter)
        return [
            {DataKeys.INPUT: (cat_row, num_row), DataKeys.TARGET: label}
            for cat_row, num_row, label in zip(categorical, numerical, labels)
        ]

    def load_sample(self, sample: Dict[str, Any]) -> Any:
        """Format the sample's target in place, if it has one, and return the sample."""
        if DataKeys.TARGET not in sample:
            return sample
        sample[DataKeys.TARGET] = self.format_target(sample[DataKeys.TARGET])
        return sample

0 comments on commit e85349d

Please sign in to comment.