Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add fit tests for tabular tasks (#1332)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored May 6, 2022
1 parent 88227d8 commit 48a2500
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 228 deletions.
11 changes: 7 additions & 4 deletions flash/core/integrations/pytorch_tabular/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@


class PytorchTabularAdapter(Adapter):
def __init__(self, backbone):
def __init__(self, task_type, backbone):
    # task_type: the PyTorch Tabular task name (e.g. "classification" or
    # "regression"); stored so ``convert_batch`` can cast regression
    # targets to float. backbone: the wrapped PyTorch Tabular model.
    super().__init__()

    self.task_type = task_type
    self.backbone = backbone

@classmethod
Expand Down Expand Up @@ -52,21 +53,23 @@ def from_task(
"output_dim": output_dim,
}
adapter = cls(
task_type,
task.backbones.get(backbone)(
task_type=task_type, parameters=parameters, loss_fn=loss_fn, metrics=metrics, **backbone_kwargs
)
),
)

return adapter

@staticmethod
def convert_batch(batch):
def convert_batch(self, batch):
    """Translate a Flash batch dict into the dict format PyTorch Tabular expects.

    The Flash input tuple holds (categorical, continuous) tensors; targets,
    when present, are reshaped to a column vector. Regression targets are
    cast to float because ``self.task_type == "regression"`` implies a
    floating-point loss target.
    """
    categorical, continuous = batch[DataKeys.INPUT][0], batch[DataKeys.INPUT][1]
    converted = {"continuous": continuous, "categorical": categorical}
    if DataKeys.TARGET in batch:
        target = batch[DataKeys.TARGET].reshape(-1, 1)
        if self.task_type == "regression":
            target = target.float()
        converted["target"] = target
    return converted

def training_step(self, batch, batch_idx) -> Any:
Expand Down
7 changes: 7 additions & 0 deletions flash_examples/question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,15 @@
# limitations under the License.
from flash import Trainer
from flash.core.data.utils import download_data
from flash.core.utilities.imports import example_requires
from flash.text import QuestionAnsweringData, QuestionAnsweringTask

example_requires("text")

import nltk # noqa: E402

nltk.download("punkt")

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/squad_tiny.zip", "./data/")

Expand Down
4 changes: 2 additions & 2 deletions tests/helpers/task_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _copy_func(f):
return g


class _StaticDataset(Dataset):
class StaticDataset(Dataset):
def __init__(self, sample, length):
super().__init__()

Expand All @@ -60,7 +60,7 @@ def _test_forward(self):

def _test_fit(self, tmpdir, task_kwargs):
"""Tests that a single batch fit pass completes."""
dataset = _StaticDataset(self.example_train_sample, 4)
dataset = StaticDataset(self.example_train_sample, 4)

args = self.task_args
kwargs = dict(**self.task_kwargs)
Expand Down
128 changes: 49 additions & 79 deletions tests/tabular/classification/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,13 @@
import pandas as pd
import pytest
import torch
from pytorch_lightning import Trainer

import flash
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _SERVE_TESTING, _TABULAR_AVAILABLE, _TABULAR_TESTING
from flash.tabular.classification.data import TabularClassificationData
from flash.tabular.classification.model import TabularClassifier
from tests.helpers.task_tester import TaskTester

# ======== Mock functions ========


class DummyDataset(torch.utils.data.Dataset):
def __init__(self, num_num=16, num_cat=16):
super().__init__()
self.num_num = num_num
self.num_cat = num_cat

def __getitem__(self, index):
target = torch.randint(0, 10, size=(1,)).item()
cat_vars = torch.randint(0, 10, size=(self.num_cat,))
num_vars = torch.rand(self.num_num)
return {DataKeys.INPUT: (cat_vars, num_vars), DataKeys.TARGET: target}

def __len__(self) -> int:
return 100


# ==============================
from tests.helpers.task_tester import StaticDataset, TaskTester


class TestTabularClassifier(TaskTester):
Expand All @@ -66,6 +45,23 @@ class TestTabularClassifier(TaskTester):
scriptable = False
traceable = False

marks = {
"test_fit": [
pytest.mark.parametrize(
"task_kwargs",
[
{"backbone": "tabnet"},
{"backbone": "tabtransformer"},
{"backbone": "fttransformer"},
{"backbone": "autoint"},
{"backbone": "node"},
{"backbone": "category_embedding"},
],
)
],
"test_cli": [pytest.mark.parametrize("extra_args", ([],))],
}

@property
def example_forward_input(self):
return {
Expand All @@ -77,63 +73,37 @@ def check_forward_output(self, output: Any):
assert isinstance(output, torch.Tensor)
assert output.shape == torch.Size([1, 10])

@property
def example_train_sample(self):
    """A single training sample: 4 categorical and 4 continuous inputs with an int target."""
    categorical = torch.randint(0, 10, size=(4,))
    continuous = torch.rand(4)
    return {DataKeys.INPUT: (categorical, continuous), DataKeys.TARGET: 1}

@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.")
@pytest.mark.parametrize(
"backbone", ["tabnet", "tabtransformer", "fttransformer", "autoint", "node", "category_embedding"]
)
def test_init_train(backbone, tmpdir):
train_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=16)
data_properties = {
"parameters": {"categorical_fields": list(range(16))},
"embedding_sizes": [(10, 32) for _ in range(16)],
"cat_dims": [10 for _ in range(16)],
"num_features": 32,
"num_classes": 10,
"backbone": backbone,
}

model = TabularClassifier(**data_properties)
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.fit(model, train_dl)


@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.")
@pytest.mark.parametrize(
"backbone", ["tabnet", "tabtransformer", "fttransformer", "autoint", "node", "category_embedding"]
)
def test_init_train_no_num(backbone, tmpdir):
train_dl = torch.utils.data.DataLoader(DummyDataset(num_num=0), batch_size=16)
data_properties = {
"parameters": {"categorical_fields": list(range(16))},
"embedding_sizes": [(10, 32) for _ in range(16)],
"cat_dims": [10 for _ in range(16)],
"num_features": 16,
"num_classes": 10,
"backbone": backbone,
}

model = TabularClassifier(**data_properties)
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.fit(model, train_dl)


@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.")
@pytest.mark.parametrize("backbone", ["tabnet", "tabtransformer", "autoint", "node", "category_embedding"])
def test_init_train_no_cat(backbone, tmpdir):
train_dl = torch.utils.data.DataLoader(DummyDataset(num_cat=0), batch_size=16)
data_properties = {
"parameters": {"categorical_fields": []},
"embedding_sizes": [],
"cat_dims": [],
"num_features": 16,
"num_classes": 10,
"backbone": backbone,
}

model = TabularClassifier(**data_properties)
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.fit(model, train_dl)
@pytest.mark.parametrize(
    "backbone", ["tabnet", "tabtransformer", "fttransformer", "autoint", "node", "category_embedding"]
)
def test_init_train_no_num(self, backbone, tmpdir):
    """Tests that a fit pass works when the sample has no numerical (continuous) inputs.

    Fix: the parametrized ``backbone`` was previously ignored, so all six
    parametrized runs exercised whatever backbone ``self.task_kwargs``
    defaults to; it is now forwarded to the task.
    """
    # Empty continuous tensor: only the 4 categorical inputs are populated.
    no_num_sample = {DataKeys.INPUT: (torch.randint(0, 10, size=(4,)), torch.empty(0)), DataKeys.TARGET: 1}
    dataset = StaticDataset(no_num_sample, 4)

    args = self.task_args
    kwargs = dict(**self.task_kwargs)
    # Use the parametrized backbone and shrink num_features to the 4 categorical fields.
    kwargs.update(backbone=backbone, num_features=4)
    model = self.task(*args, **kwargs)

    trainer = flash.Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model, model.process_train_dataset(dataset, batch_size=4))

@pytest.mark.parametrize("backbone", ["tabnet", "tabtransformer", "autoint", "node", "category_embedding"])
def test_init_train_no_cat(self, backbone, tmpdir):
    """Tests that a fit pass works when the sample has no categorical inputs.

    Fix: the parametrized ``backbone`` was previously ignored, so every
    parametrized run used the default backbone from ``self.task_kwargs``;
    it is now forwarded to the task.
    """
    # Empty categorical tensor: only the 4 continuous inputs are populated.
    no_cat_sample = {DataKeys.INPUT: (torch.empty(0), torch.rand(4)), DataKeys.TARGET: 1}
    dataset = StaticDataset(no_cat_sample, 4)

    args = self.task_args
    kwargs = dict(**self.task_kwargs)
    # No categorical metadata; num_features covers only the 4 continuous fields.
    kwargs.update(
        backbone=backbone, parameters={"categorical_fields": []}, embedding_sizes=[], cat_dims=[], num_features=4
    )
    model = self.task(*args, **kwargs)

    trainer = flash.Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model, model.process_train_dataset(dataset, batch_size=4))


@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.")
Expand Down
96 changes: 30 additions & 66 deletions tests/tabular/forecasting/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,17 @@
import torch

import flash
from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_AVAILABLE, _TABULAR_TESTING
from flash.tabular.forecasting import TabularForecaster, TabularForecastingData
from tests.helpers.task_tester import TaskTester
from flash import DataKeys
from flash.core.utilities.imports import _TABULAR_AVAILABLE, _TABULAR_TESTING
from flash.tabular.forecasting import TabularForecaster
from tests.helpers.task_tester import StaticDataset, TaskTester

if _TABULAR_AVAILABLE:
from pytorch_forecasting.data import EncoderNormalizer, NaNLabelEncoder
from pytorch_forecasting.data.examples import generate_ar_data
else:
EncoderNormalizer = object
NaNLabelEncoder = object

if _PANDAS_AVAILABLE:
import pandas as pd


class TestTabularForecaster(TaskTester):

Expand Down Expand Up @@ -102,66 +99,33 @@ def check_forward_output(self, output: Any):
assert isinstance(output["prediction"], torch.Tensor)
assert output["prediction"].shape == torch.Size([2, 20])

@property
def example_train_sample(self):
    """One training sample in the pytorch-forecasting encoder/decoder dict format."""
    encoder_length = 60
    decoder_length = 20
    inputs = {
        "x_cat": torch.empty(encoder_length, 0, dtype=torch.int64),
        "x_cont": torch.zeros(encoder_length, 1),
        "encoder_target": torch.zeros(encoder_length),
        "encoder_length": encoder_length,
        "decoder_length": decoder_length,
        "encoder_time_idx_start": 1,
        "groups": torch.zeros(1),
        "target_scale": torch.zeros(2),
    }
    return {DataKeys.INPUT: inputs, DataKeys.TARGET: (torch.rand(decoder_length), None)}

@pytest.fixture
def sample_data():
data = generate_ar_data(seasonality=10.0, timesteps=100, n_series=2, seed=42)
data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D")
max_prediction_length = 20
training_cutoff = data["time_idx"].max() - max_prediction_length
return data, training_cutoff, max_prediction_length


@pytest.mark.skipif(not _TABULAR_TESTING, reason="Tabular libraries aren't installed.")
def test_fast_dev_run_smoke(sample_data):
"""Test that fast dev run works with the NBeats example data."""
data, training_cutoff, max_prediction_length = sample_data
datamodule = TabularForecastingData.from_data_frame(
time_idx="time_idx",
target="value",
categorical_encoders={"series": NaNLabelEncoder().fit(data.series)},
group_ids=["series"],
time_varying_unknown_reals=["value"],
max_encoder_length=60,
max_prediction_length=max_prediction_length,
train_data_frame=data[lambda x: x.time_idx <= training_cutoff],
val_data_frame=data,
batch_size=4,
)

model = TabularForecaster(
datamodule.parameters,
backbone="n_beats",
backbone_kwargs={"widths": [32, 512], "backcast_loss_ratio": 0.1},
)

trainer = flash.Trainer(max_epochs=1, fast_dev_run=True, gradient_clip_val=0.01)
trainer.fit(model, datamodule=datamodule)

def test_testing_raises(self, tmpdir):
    """Tests that ``NotImplementedError`` is raised when attempting to perform a test pass.

    NOTE(review): this span previously interleaved the removed (pre-commit,
    fixture-based) body with the new one; only the added side is kept here.
    """
    dataset = StaticDataset(self.example_train_sample, 4)

    args = self.task_args
    kwargs = dict(**self.task_kwargs)
    model = self.task(*args, **kwargs)

    trainer = flash.Trainer(default_root_dir=tmpdir, fast_dev_run=True)

    with pytest.raises(
        NotImplementedError, match="Backbones provided by PyTorch Forecasting don't support testing."
    ):
        trainer.test(model, model.process_test_dataset(dataset, batch_size=4))
Loading

0 comments on commit 48a2500

Please sign in to comment.