Skip to content

Commit

Permalink
RDS-305: Raise explicit error for NaNs in training data.
Browse files Browse the repository at this point in the history
GitOrigin-RevId: b6e543531925cbac900b65b627aacd2e9529848a
  • Loading branch information
kboyd committed Jul 8, 2022
1 parent 19af35c commit 2ed9d31
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
22 changes: 22 additions & 0 deletions src/gretel_synthetics/timeseries_dgan/dgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ def train_numpy(
"First dimension of attributes and features must be the same length, i.e., the number of training examples."
)

_check_for_nans(attributes, features)

if not self.is_built:
attribute_outputs, feature_outputs = create_outputs_from_data(
attributes,
Expand Down Expand Up @@ -297,6 +299,8 @@ def train_dataframe(

attributes, features = self.data_frame_converter.convert(df)

_check_for_nans(attributes, features)

self.train_numpy(
attributes=attributes,
features=features,
Expand Down Expand Up @@ -915,6 +919,24 @@ def load(cls, file_name: str, **kwargs) -> DGAN:
return dgan


def _check_for_nans(attributes: Optional[np.ndarray], features: np.ndarray):
"""Helper function to raise an error if NaNs are found.
The DGAN model does not handle NaNs at this time, so we want to throw a
specific error instead of waiting for later steps to fail that are harder to
debug.
"""
if attributes is not None and np.any(np.isnan(attributes)):
raise ValueError(
"NaN found in attributes. DGAN does not support NaNs, please remove NaNs before training."
)

if np.any(np.isnan(features)):
raise ValueError(
"NaN found in features. DGAN does not support NANs, please remove NaNs before training."
)


class _DataFrameConverter(abc.ABC):
"""Abstract class for converting DGAN input to and from a DataFrame."""

Expand Down
30 changes: 30 additions & 0 deletions tests/timeseries_dgan/test_dgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,36 @@ def test_train_dataframe_long_no_attributes(config: DGANConfig):
assert list(synthetic_df.columns) == list(df.columns)


def test_train_numpy_nans(config: DGANConfig, feature_data):
    """train_numpy must raise a ValueError mentioning NaN for NaN features."""
    features, feature_types = feature_data
    # Work on a copy so the array handed out by the fixture is not modified
    # for any other test that may share it.
    features = features.copy()
    # Insert a NaN. np.nan is used because the np.NaN alias was removed in
    # NumPy 2.0.
    features[11, 3, 1] = np.nan

    dg = DGAN(config=config)

    with pytest.raises(ValueError, match="NaN"):
        dg.train_numpy(features=features, feature_types=feature_types)


def test_train_dataframe_nans(config: DGANConfig):
    """train_dataframe must raise a ValueError mentioning NaN for NaN input."""
    n = 50
    df = pd.DataFrame(
        {
            "2022-01-01": np.random.rand(n),
            # Whole column of NaNs. np.nan is used because the np.NaN alias
            # was removed in NumPy 2.0.
            "2022-02-01": np.nan,
            "2022-03-01": np.random.rand(n),
            "2022-04-01": np.random.rand(n),
        }
    )

    # The wide frame has four columns, so configure sequences of length 4.
    config.max_sequence_len = 4
    config.sample_len = 1

    dg = DGAN(config=config)
    with pytest.raises(ValueError, match="NaN"):
        dg.train_dataframe(df=df, df_style=DfStyle.WIDE)


@pytest.fixture
def df_wide() -> pd.DataFrame:
return pd.DataFrame(
Expand Down

0 comments on commit 2ed9d31

Please sign in to comment.