Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

fix: set inputs as optional #109

Merged
merged 18 commits into from
Feb 13, 2021
21 changes: 15 additions & 6 deletions flash/tabular/classification/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ class TabularData(DataModule):
def __init__(
self,
train_df: DataFrame,
categorical_input: List,
numerical_input: List,
target: str,
categorical_input: Optional[List] = None,
numerical_input: Optional[List] = None,
valid_df: Optional[DataFrame] = None,
test_df: Optional[DataFrame] = None,
batch_size: int = 2,
Expand All @@ -82,6 +82,15 @@ def __init__(
dfs = [train_df]
self._test_df = None

if categorical_input is None and numerical_input is None:
raise TypeError('Both categorical_input and numerical_input are None!')
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
Borda marked this conversation as resolved.
Show resolved Hide resolved

if categorical_input is None:
categorical_input = []

if numerical_input is None:
aniketmaurya marked this conversation as resolved.
Show resolved Hide resolved
numerical_input = []

if valid_df is not None:
dfs.append(valid_df)

Expand Down Expand Up @@ -133,8 +142,8 @@ def from_df(
cls,
train_df: DataFrame,
target: str,
categorical_input: List,
numerical_input: List,
categorical_input: Optional[List] = None,
numerical_input: Optional[List] = None,
valid_df: Optional[DataFrame] = None,
test_df: Optional[DataFrame] = None,
batch_size: int = 8,
Expand Down Expand Up @@ -194,8 +203,8 @@ def from_csv(
cls,
train_csv: str,
target: str,
categorical_input: List,
numerical_input: List,
categorical_input: Optional[List] = None,
numerical_input: Optional[List] = None,
valid_csv: Optional[str] = None,
test_csv: Optional[str] = None,
batch_size: int = 8,
Expand Down