From e85349da0c7c2f5a5d2ae812d8b4c753ef8328f1 Mon Sep 17 00:00:00 2001
From: Kushashwa Ravi Shrimali
Date: Fri, 6 May 2022 16:25:25 +0530
Subject: [PATCH] Support for from_dict for Tabular Classification

---
NOTE(review): data.py uses TabularClassificationListInput as the default for ``input_cls``, but this
patch contains no hunk importing it into data.py. Confirm the name is in scope there.

 flash/tabular/classification/data.py  | 121 ++++++++++++++++++++++++++
 flash/tabular/classification/input.py |  31 +++++++
 2 files changed, 152 insertions(+)

diff --git a/flash/tabular/classification/data.py b/flash/tabular/classification/data.py
index 6a9da1dcad..c25707523c 100644
--- a/flash/tabular/classification/data.py
+++ b/flash/tabular/classification/data.py
@@ -311,3 +311,124 @@ def from_csv(
             transform_kwargs=transform_kwargs,
             **data_module_kwargs,
         )
+
+    @classmethod
+    def from_dict(
+        cls,
+        categorical_fields: Optional[Union[str, List[str]]] = None,
+        numerical_fields: Optional[Union[str, List[str]]] = None,
+        target_fields: Optional[Union[str, List[str]]] = None,
+        parameters: Optional[Dict[str, Any]] = None,
+        train_data: Optional[Dict[str, Union[Any, List[Any]]]] = None,
+        val_data: Optional[Dict[str, Union[Any, List[Any]]]] = None,
+        test_data: Optional[Dict[str, Union[Any, List[Any]]]] = None,
+        predict_data: Optional[Dict[str, Union[Any, List[Any]]]] = None,
+        target_formatter: Optional[TargetFormatter] = None,
+        input_cls: Type[Input] = TabularClassificationListInput,
+        transform: INPUT_TRANSFORM_TYPE = InputTransform,
+        transform_kwargs: Optional[Dict] = None,
+        **data_module_kwargs: Any,
+    ) -> "TabularClassificationData":
+        """Creates a :class:`~flash.tabular.classification.data.TabularClassificationData` object from the given
+        dictionaries.
+
+        .. note::
+
+            The ``categorical_fields``, ``numerical_fields``, and ``target_fields`` do not need to be provided if
+            ``parameters`` are passed instead. These can be obtained from the
+            :attr:`~flash.tabular.data.TabularData.parameters` attribute of the
+            :class:`~flash.tabular.data.TabularData` object that contains your training data.
+
+        The targets will be extracted from the ``target_fields`` in the dictionaries and can be in any of our
+        :ref:`supported classification target formats <formatting_classification_targets>`.
+        To learn how to customize the transforms applied for each stage, read our
+        :ref:`customizing transforms guide <customizing_transforms>`.
+
+        Args:
+            categorical_fields: The fields (keys / column names) in the dictionaries containing categorical data.
+            numerical_fields: The fields (keys / column names) in the dictionaries containing numerical data.
+            target_fields: The field (key / column name) or list of fields in the dictionaries containing the targets.
+            parameters: Parameters to use if ``categorical_fields``, ``numerical_fields``, and ``target_fields`` are not
+                provided (e.g. when loading data for inference or validation).
+            train_data: The dictionary to use when training.
+            val_data: The dictionary to use when validating.
+            test_data: The dictionary to use when testing.
+            predict_data: The dictionary to use when predicting.
+            target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to
+                control how targets are handled. See :ref:`formatting_classification_targets` for more details.
+            input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
+            transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use.
+            transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
+            data_module_kwargs: Additional keyword arguments to provide to the
+                :class:`~flash.core.data.data_module.DataModule` constructor.
+
+        Returns:
+            The constructed :class:`~flash.tabular.classification.data.TabularClassificationData`.
+
+        Examples
+        ________
+
+        .. testsetup::
+
+            >>> train_data = {
+            ...     "animal": ["cat", "dog", "cat"],
+            ...     "friendly": ["yes", "yes", "no"],
+            ...     "weight": [6, 10, 5],
+            ... }
+            >>> predict_data = {
+            ...     "friendly": ["yes", "no", "yes"],
+            ...     "weight": [7, 12, 5],
+            ... }
+
+        We have dictionaries ``train_data`` and ``predict_data``.
+
+        .. doctest::
+
+            >>> from flash import Trainer
+            >>> from flash.tabular import TabularClassifier, TabularClassificationData
+            >>> datamodule = TabularClassificationData.from_dict(
+            ...     "friendly",
+            ...     "weight",
+            ...     "animal",
+            ...     train_data=train_data,
+            ...     predict_data=predict_data,
+            ...     batch_size=4,
+            ... )
+            >>> datamodule.num_classes
+            2
+            >>> datamodule.labels
+            ['cat', 'dog']
+            >>> model = TabularClassifier.from_data(datamodule, backbone="tabnet")
+            >>> trainer = Trainer(fast_dev_run=True)
+            >>> trainer.fit(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+            Training...
+            >>> trainer.predict(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+            Predicting...
+
+        .. testcleanup::
+
+            >>> del train_data
+            >>> del predict_data
+        """
+        ds_kw = dict(
+            target_formatter=target_formatter,
+            categorical_fields=categorical_fields,
+            numerical_fields=numerical_fields,
+            target_fields=target_fields,
+            parameters=parameters,
+        )
+
+        train_input = input_cls(RunningStage.TRAINING, train_data, **ds_kw)
+        # Stages after training reuse the parameters and target formatter resolved by the training input.
+        ds_kw["parameters"] = train_input.parameters if train_input else parameters
+        ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)
+
+        return cls(
+            train_input,
+            input_cls(RunningStage.VALIDATING, val_data, **ds_kw),
+            input_cls(RunningStage.TESTING, test_data, **ds_kw),
+            input_cls(RunningStage.PREDICTING, predict_data, **ds_kw),
+            transform=transform,
+            transform_kwargs=transform_kwargs,
+            **data_module_kwargs,
+        )
diff --git a/flash/tabular/classification/input.py b/flash/tabular/classification/input.py
index a73490b366..feeb06f575 100644
--- a/flash/tabular/classification/input.py
+++ b/flash/tabular/classification/input.py
@@ -65,3 +65,34 @@ def load_data(
         return super().load_data(
             read_csv(file), categorical_fields, numerical_fields, target_fields, parameters, target_formatter
         )
+
+
+class TabularClassificationListInput(TabularDataFrameInput, ClassificationInputMixin):
+    """Input that loads tabular classification data from a dictionary of columns."""
+
+    def load_data(
+        self,
+        data: Dict[str, Union[Any, List[Any]]],
+        categorical_fields: Optional[Union[str, List[str]]] = None,
+        numerical_fields: Optional[Union[str, List[str]]] = None,
+        target_fields: Optional[Union[str, List[str]]] = None,
+        parameters: Optional[Dict[str, Any]] = None,
+        target_formatter: Optional[TargetFormatter] = None,
+    ):
+        """Convert ``data`` to a ``DataFrame`` and build the list of samples."""
+        # Convert the data (dict) to a Pandas DataFrame
+        data_frame = DataFrame.from_dict(data)
+
+        cat_vars, num_vars = self.preprocess(data_frame, categorical_fields, numerical_fields, parameters)
+
+        if not self.predicting:
+            targets = resolve_targets(data_frame, target_fields)
+            self.load_target_metadata(targets, target_formatter=target_formatter)
+            return [{DataKeys.INPUT: (c, n), DataKeys.TARGET: t} for c, n, t in zip(cat_vars, num_vars, targets)]
+        return [{DataKeys.INPUT: (c, n)} for c, n in zip(cat_vars, num_vars)]
+
+    def load_sample(self, sample: Dict[str, Any]) -> Any:
+        """Format the sample's target, if it has one, and return the sample."""
+        if DataKeys.TARGET in sample:
+            sample[DataKeys.TARGET] = self.format_target(sample[DataKeys.TARGET])
+        return sample