Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add fit tests for tabular forecasting
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed May 6, 2022
1 parent 6393080 commit 7b6e85f
Showing 1 changed file with 30 additions and 66 deletions.
96 changes: 30 additions & 66 deletions tests/tabular/forecasting/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,17 @@
import torch

import flash
from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_AVAILABLE, _TABULAR_TESTING
from flash.tabular.forecasting import TabularForecaster, TabularForecastingData
from tests.helpers.task_tester import TaskTester
from flash import DataKeys
from flash.core.utilities.imports import _TABULAR_AVAILABLE, _TABULAR_TESTING
from flash.tabular.forecasting import TabularForecaster
from tests.helpers.task_tester import StaticDataset, TaskTester

if _TABULAR_AVAILABLE:
from pytorch_forecasting.data import EncoderNormalizer, NaNLabelEncoder
from pytorch_forecasting.data.examples import generate_ar_data
else:
EncoderNormalizer = object
NaNLabelEncoder = object

if _PANDAS_AVAILABLE:
import pandas as pd


class TestTabularForecaster(TaskTester):

Expand Down Expand Up @@ -102,66 +99,33 @@ def check_forward_output(self, output: Any):
assert isinstance(output["prediction"], torch.Tensor)
assert output["prediction"].shape == torch.Size([2, 20])

@property
def example_train_sample(self):
return {
DataKeys.INPUT: {
"x_cat": torch.empty(60, 0, dtype=torch.int64),
"x_cont": torch.zeros(60, 1),
"encoder_target": torch.zeros(60),
"encoder_length": 60,
"decoder_length": 20,
"encoder_time_idx_start": 1,
"groups": torch.zeros(1),
"target_scale": torch.zeros(2),
},
DataKeys.TARGET: (torch.rand(20), None),
}

@pytest.fixture
def sample_data():
data = generate_ar_data(seasonality=10.0, timesteps=100, n_series=2, seed=42)
data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D")
max_prediction_length = 20
training_cutoff = data["time_idx"].max() - max_prediction_length
return data, training_cutoff, max_prediction_length


@pytest.mark.skipif(not _TABULAR_TESTING, reason="Tabular libraries aren't installed.")
def test_fast_dev_run_smoke(sample_data):
"""Test that fast dev run works with the NBeats example data."""
data, training_cutoff, max_prediction_length = sample_data
datamodule = TabularForecastingData.from_data_frame(
time_idx="time_idx",
target="value",
categorical_encoders={"series": NaNLabelEncoder().fit(data.series)},
group_ids=["series"],
time_varying_unknown_reals=["value"],
max_encoder_length=60,
max_prediction_length=max_prediction_length,
train_data_frame=data[lambda x: x.time_idx <= training_cutoff],
val_data_frame=data,
batch_size=4,
)

model = TabularForecaster(
datamodule.parameters,
backbone="n_beats",
backbone_kwargs={"widths": [32, 512], "backcast_loss_ratio": 0.1},
)

trainer = flash.Trainer(max_epochs=1, fast_dev_run=True, gradient_clip_val=0.01)
trainer.fit(model, datamodule=datamodule)

def test_testing_raises(self, tmpdir):
"""Tests that ``NotImplementedError`` is raised when attempting to perform a test pass."""
dataset = StaticDataset(self.example_train_sample, 4)

@pytest.mark.skipif(not _TABULAR_TESTING, reason="Tabular libraries aren't installed.")
def test_testing_raises(sample_data):
"""Tests that ``NotImplementedError`` is raised when attempting to perform a test pass."""
data, training_cutoff, max_prediction_length = sample_data
datamodule = TabularForecastingData.from_data_frame(
time_idx="time_idx",
target="value",
categorical_encoders={"series": NaNLabelEncoder().fit(data.series)},
group_ids=["series"],
time_varying_unknown_reals=["value"],
max_encoder_length=60,
max_prediction_length=max_prediction_length,
train_data_frame=data[lambda x: x.time_idx <= training_cutoff],
test_data_frame=data,
batch_size=4,
)
args = self.task_args
kwargs = dict(**self.task_kwargs)
model = self.task(*args, **kwargs)

model = TabularForecaster(
datamodule.parameters,
backbone="n_beats",
backbone_kwargs={"widths": [32, 512], "backcast_loss_ratio": 0.1},
)
trainer = flash.Trainer(max_epochs=1, fast_dev_run=True, gradient_clip_val=0.01)
trainer = flash.Trainer(default_root_dir=tmpdir, fast_dev_run=True)

with pytest.raises(NotImplementedError, match="Backbones provided by PyTorch Forecasting don't support testing."):
trainer.test(model, datamodule=datamodule)
with pytest.raises(
NotImplementedError, match="Backbones provided by PyTorch Forecasting don't support testing."
):
trainer.test(model, model.process_test_dataset(dataset, batch_size=4))

0 comments on commit 7b6e85f

Please sign in to comment.