Skip to content

Commit

Permalink
Bug fix for "Categorical" dtype (#493)
Browse files: browse the repository at this point in the history
* Categorical bug fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Removed the dataloader kwargs part

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Manu Joseph V <manujosephv@gmail.com>
  • Loading branch information
3 people authored Nov 23, 2024
1 parent a3272ea commit cf1454a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/pytorch_tabular/categorical_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def transform(self, X):
not X[self.cols].isnull().any().any()
), "`handle_missing` = `error` and missing values found in columns to encode."
X_encoded = X.copy(deep=True)
category_cols = X_encoded.select_dtypes(include="category").columns
X_encoded[category_cols] = X_encoded[category_cols].astype("object")
for col, mapping in self._mapping.items():
X_encoded[col] = X_encoded[col].fillna(NAN_CATEGORY).map(mapping["value"])

Expand Down
4 changes: 4 additions & 0 deletions src/pytorch_tabular/tabular_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,14 @@ def _update_config(self, config) -> InferredConfig:
else:
raise ValueError(f"{config.task} is an unsupported task.")
if self.train is not None:
category_cols = self.train[config.categorical_cols].select_dtypes(include="category").columns
self.train[category_cols] = self.train[category_cols].astype("object")
categorical_cardinality = [
int(x) + 1 for x in list(self.train[config.categorical_cols].fillna("NA").nunique().values)
]
else:
category_cols = self.train_dataset.data[config.categorical_cols].select_dtypes(include="category").columns
self.train_dataset.data[category_cols] = self.train_dataset.data[category_cols].astype("object")
categorical_cardinality = [
int(x) + 1 for x in list(self.train_dataset.data[config.categorical_cols].nunique().values)
]
Expand Down

0 comments on commit cf1454a

Please sign in to comment.