Skip to content

Commit

Permalink
Merge branch 'main' into pre-commit-ci-update-config
Browse files Browse the repository at this point in the history
  • Loading branch information
manujosephv authored Nov 25, 2024
2 parents dcd4bc2 + f354b9c commit 42f71f7
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ repos:
- id: detect-private-key

- repo: https://github.com/PyCQA/docformatter
rev: v1.7.5
rev: 06907d0267368b49b9180eed423fae5697c1e909 # todo: fix for docformatter after last 1.7.5
hooks:
- id: docformatter
additional_dependencies: [tomli]
Expand Down
11 changes: 9 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,13 @@ loaded_model = TabularModel.load_model("examples/basic")
<sub><b>Luca Actis Grosso</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/snehilchatterjee">
<img src="https://avatars.githubusercontent.com/u/127598707?v=4" width="100;" alt="snehilchatterjee"/>
<br />
<sub><b>Snehil Chatterjee</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/sgbaird">
<img src="https://avatars.githubusercontent.com/u/45469701?v=4" width="100;" alt="sgbaird"/>
Expand All @@ -275,15 +282,15 @@ loaded_model = TabularModel.load_model("examples/basic")
<sub><b>Yinyu Nie</b></sub>
</a>
</td>
</tr>
<tr>
<td align="center">
<a href="https://github.com/YonyBresler">
<img src="https://avatars.githubusercontent.com/u/24940683?v=4" width="100;" alt="YonyBresler"/>
<br />
<sub><b>YonyBresler</b></sub>
</a>
</td>
</tr>
<tr>
<td align="center">
<a href="https://github.com/HernandoR">
<img src="https://avatars.githubusercontent.com/u/45709656?v=4" width="100;" alt="HernandoR"/>
Expand Down
2 changes: 1 addition & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ pandas >=1.1.5
scikit-learn >=1.3.0
pytorch-lightning >=2.0.0, <2.5.0
omegaconf >=2.3.0
torchmetrics >=0.10.0, <1.5.0
torchmetrics >=0.10.0, <1.6.0
tensorboard >2.2.0, !=2.5.0
protobuf >=3.20.0, <5.29.0
pytorch-tabnet ==4.1
Expand Down
2 changes: 1 addition & 1 deletion requirements/extra.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
wandb >=0.15.0, <0.17.0
wandb >=0.15.0, <0.19.0
plotly>=5.13.0, <5.25.0
kaleido >=0.2.0, <0.3.0
captum >=0.5.0, <0.8.0
2 changes: 2 additions & 0 deletions src/pytorch_tabular/categorical_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def transform(self, X):
not X[self.cols].isnull().any().any()
), "`handle_missing` = `error` and missing values found in columns to encode."
X_encoded = X.copy(deep=True)
category_cols = X_encoded.select_dtypes(include="category").columns
X_encoded[category_cols] = X_encoded[category_cols].astype("object")
for col, mapping in self._mapping.items():
X_encoded[col] = X_encoded[col].fillna(NAN_CATEGORY).map(mapping["value"])

Expand Down
4 changes: 4 additions & 0 deletions src/pytorch_tabular/tabular_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,14 @@ def _update_config(self, config) -> InferredConfig:
else:
raise ValueError(f"{config.task} is an unsupported task.")
if self.train is not None:
category_cols = self.train[config.categorical_cols].select_dtypes(include="category").columns
self.train[category_cols] = self.train[category_cols].astype("object")
categorical_cardinality = [
int(x) + 1 for x in list(self.train[config.categorical_cols].fillna("NA").nunique().values)
]
else:
category_cols = self.train_dataset.data[config.categorical_cols].select_dtypes(include="category").columns
self.train_dataset.data[category_cols] = self.train_dataset.data[category_cols].astype("object")
categorical_cardinality = [
int(x) + 1 for x in list(self.train_dataset.data[config.categorical_cols].nunique().values)
]
Expand Down

0 comments on commit 42f71f7

Please sign in to comment.