Skip to content

Commit

Permalink
Merge branch 'main' into dependabot-pip-numpy-gt-1.20.0-and-lt-3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
manujosephv authored Nov 25, 2024
2 parents e7c033f + f354b9c commit 5fc25d6
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 132 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ repos:
- id: detect-private-key

- repo: https://github.com/PyCQA/docformatter
rev: v1.7.5
rev: 06907d0267368b49b9180eed423fae5697c1e909 # todo: fix for docformatter after last 1.7.5
hooks:
- id: docformatter
additional_dependencies: [tomli]
Expand Down
279 changes: 150 additions & 129 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,135 +150,156 @@ loaded_model = TabularModel.load_model("examples/basic")

<!-- readme: contributors -start -->
<table>
<tr>
<td align="center">
<a href="https://github.com/manujosephv">
<img src="https://avatars.githubusercontent.com/u/10508493?v=4" width="100;" alt="manujosephv"/>
<br />
<sub><b>Manu Joseph</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/Borda">
<img src="https://avatars.githubusercontent.com/u/6035284?v=4" width="100;" alt="Borda"/>
<br />
<sub><b>Jirka Borovec</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/wsad1">
<img src="https://avatars.githubusercontent.com/u/13963626?v=4" width="100;" alt="wsad1"/>
<br />
<sub><b>Jinu Sunil</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/ProgramadorArtificial">
<img src="https://avatars.githubusercontent.com/u/130674366?v=4" width="100;" alt="ProgramadorArtificial"/>
<br />
<sub><b>Programador Artificial</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/sorenmacbeth">
<img src="https://avatars.githubusercontent.com/u/130043?v=4" width="100;" alt="sorenmacbeth"/>
<br />
<sub><b>Soren Macbeth</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/fonnesbeck">
<img src="https://avatars.githubusercontent.com/u/81476?v=4" width="100;" alt="fonnesbeck"/>
<br />
<sub><b>Chris Fonnesbeck</b></sub>
</a>
</td></tr>
<tr>
<td align="center">
<a href="https://github.com/jxtrbtk">
<img src="https://avatars.githubusercontent.com/u/40494970?v=4" width="100;" alt="jxtrbtk"/>
<br />
<sub><b>Null</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/abhisharsinha">
<img src="https://avatars.githubusercontent.com/u/24841841?v=4" width="100;" alt="abhisharsinha"/>
<br />
<sub><b>Abhishar Sinha</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/ndrsfel">
<img src="https://avatars.githubusercontent.com/u/21068727?v=4" width="100;" alt="ndrsfel"/>
<br />
<sub><b>Andreas</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/charitarthchugh">
<img src="https://avatars.githubusercontent.com/u/37895518?v=4" width="100;" alt="charitarthchugh"/>
<br />
<sub><b>Charitarth Chugh</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/EeyoreLee">
<img src="https://avatars.githubusercontent.com/u/49790022?v=4" width="100;" alt="EeyoreLee"/>
<br />
<sub><b>Earlee</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/JulianRein">
<img src="https://avatars.githubusercontent.com/u/35046938?v=4" width="100;" alt="JulianRein"/>
<br />
<sub><b>Null</b></sub>
</a>
</td></tr>
<tr>
<td align="center">
<a href="https://github.com/krshrimali">
<img src="https://avatars.githubusercontent.com/u/19997320?v=4" width="100;" alt="krshrimali"/>
<br />
<sub><b>Kushashwa Ravi Shrimali</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/Actis92">
<img src="https://avatars.githubusercontent.com/u/46601193?v=4" width="100;" alt="Actis92"/>
<br />
<sub><b>Luca Actis Grosso</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/sgbaird">
<img src="https://avatars.githubusercontent.com/u/45469701?v=4" width="100;" alt="sgbaird"/>
<br />
<sub><b>Sterling G. Baird</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/furyhawk">
<img src="https://avatars.githubusercontent.com/u/831682?v=4" width="100;" alt="furyhawk"/>
<br />
<sub><b>Teck Meng</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/yinyunie">
<img src="https://avatars.githubusercontent.com/u/25686434?v=4" width="100;" alt="yinyunie"/>
<br />
<sub><b>Yinyu Nie</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/HernandoR">
<img src="https://avatars.githubusercontent.com/u/45709656?v=4" width="100;" alt="HernandoR"/>
<br />
<sub><b>Liu Zhen</b></sub>
</a>
</td></tr>
<tbody>
<tr>
<td align="center">
<a href="https://github.com/manujosephv">
<img src="https://avatars.githubusercontent.com/u/10508493?v=4" width="100;" alt="manujosephv"/>
<br />
<sub><b>Manu Joseph</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/Borda">
<img src="https://avatars.githubusercontent.com/u/6035284?v=4" width="100;" alt="Borda"/>
<br />
<sub><b>Jirka Borovec</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/wsad1">
<img src="https://avatars.githubusercontent.com/u/13963626?v=4" width="100;" alt="wsad1"/>
<br />
<sub><b>Jinu Sunil</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/ProgramadorArtificial">
<img src="https://avatars.githubusercontent.com/u/130674366?v=4" width="100;" alt="ProgramadorArtificial"/>
<br />
<sub><b>Programador Artificial</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/sorenmacbeth">
<img src="https://avatars.githubusercontent.com/u/130043?v=4" width="100;" alt="sorenmacbeth"/>
<br />
<sub><b>Soren Macbeth</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/fonnesbeck">
<img src="https://avatars.githubusercontent.com/u/81476?v=4" width="100;" alt="fonnesbeck"/>
<br />
<sub><b>Chris Fonnesbeck</b></sub>
</a>
</td>
</tr>
<tr>
<td align="center">
<a href="https://github.com/jxtrbtk">
<img src="https://avatars.githubusercontent.com/u/40494970?v=4" width="100;" alt="jxtrbtk"/>
<br />
<sub><b>Null</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/abhisharsinha">
<img src="https://avatars.githubusercontent.com/u/24841841?v=4" width="100;" alt="abhisharsinha"/>
<br />
<sub><b>Abhishar Sinha</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/ndrsfel">
<img src="https://avatars.githubusercontent.com/u/21068727?v=4" width="100;" alt="ndrsfel"/>
<br />
<sub><b>Andreas</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/charitarthchugh">
<img src="https://avatars.githubusercontent.com/u/37895518?v=4" width="100;" alt="charitarthchugh"/>
<br />
<sub><b>Charitarth Chugh</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/EeyoreLee">
<img src="https://avatars.githubusercontent.com/u/49790022?v=4" width="100;" alt="EeyoreLee"/>
<br />
<sub><b>Earlee</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/JulianRein">
<img src="https://avatars.githubusercontent.com/u/35046938?v=4" width="100;" alt="JulianRein"/>
<br />
<sub><b>Null</b></sub>
</a>
</td>
</tr>
<tr>
<td align="center">
<a href="https://github.com/krshrimali">
<img src="https://avatars.githubusercontent.com/u/19997320?v=4" width="100;" alt="krshrimali"/>
<br />
<sub><b>Kushashwa Ravi Shrimali</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/Actis92">
<img src="https://avatars.githubusercontent.com/u/46601193?v=4" width="100;" alt="Actis92"/>
<br />
<sub><b>Luca Actis Grosso</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/snehilchatterjee">
<img src="https://avatars.githubusercontent.com/u/127598707?v=4" width="100;" alt="snehilchatterjee"/>
<br />
<sub><b>Snehil Chatterjee</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/sgbaird">
<img src="https://avatars.githubusercontent.com/u/45469701?v=4" width="100;" alt="sgbaird"/>
<br />
<sub><b>Sterling G. Baird</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/furyhawk">
<img src="https://avatars.githubusercontent.com/u/831682?v=4" width="100;" alt="furyhawk"/>
<br />
<sub><b>Teck Meng</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/yinyunie">
<img src="https://avatars.githubusercontent.com/u/25686434?v=4" width="100;" alt="yinyunie"/>
<br />
<sub><b>Yinyu Nie</b></sub>
</a>
</td>
</tr>
<tr>
<td align="center">
<a href="https://github.com/YonyBresler">
<img src="https://avatars.githubusercontent.com/u/24940683?v=4" width="100;" alt="YonyBresler"/>
<br />
<sub><b>YonyBresler</b></sub>
</a>
</td>
<td align="center">
<a href="https://github.com/HernandoR">
<img src="https://avatars.githubusercontent.com/u/45709656?v=4" width="100;" alt="HernandoR"/>
<br />
<sub><b>Liu Zhen</b></sub>
</a>
</td>
</tr>
<tbody>
</table>
<!-- readme: contributors -end -->

Expand Down
2 changes: 1 addition & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ pandas >=1.1.5
scikit-learn >=1.3.0
pytorch-lightning >=2.0.0, <2.5.0
omegaconf >=2.3.0
torchmetrics >=0.10.0, <1.5.0
torchmetrics >=0.10.0, <1.6.0
tensorboard >2.2.0, !=2.5.0
protobuf >=3.20.0, <5.29.0
pytorch-tabnet ==4.1
Expand Down
2 changes: 1 addition & 1 deletion requirements/extra.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
wandb >=0.15.0, <0.17.0
wandb >=0.15.0, <0.19.0
plotly>=5.13.0, <5.25.0
kaleido >=0.2.0, <0.3.0
captum >=0.5.0, <0.8.0
2 changes: 2 additions & 0 deletions src/pytorch_tabular/categorical_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def transform(self, X):
not X[self.cols].isnull().any().any()
), "`handle_missing` = `error` and missing values found in columns to encode."
X_encoded = X.copy(deep=True)
category_cols = X_encoded.select_dtypes(include="category").columns
X_encoded[category_cols] = X_encoded[category_cols].astype("object")
for col, mapping in self._mapping.items():
X_encoded[col] = X_encoded[col].fillna(NAN_CATEGORY).map(mapping["value"])

Expand Down
8 changes: 8 additions & 0 deletions src/pytorch_tabular/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ class DataConfig:
handle_missing_values (bool): Whether to handle missing values in categorical columns as
unknown
dataloader_kwargs (Dict[str, Any]): Additional kwargs to be passed to PyTorch DataLoader. See
https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
"""

target: Optional[List[str]] = field(
Expand Down Expand Up @@ -176,6 +179,11 @@ class DataConfig:
metadata={"help": "Whether or not to handle missing values in categorical columns as unknown"},
)

dataloader_kwargs: Dict[str, Any] = field(
default_factory=dict,
metadata={"help": "Additional kwargs to be passed to PyTorch DataLoader."},
)

def __post_init__(self):
assert (
len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0
Expand Down
7 changes: 7 additions & 0 deletions src/pytorch_tabular/tabular_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,14 @@ def _update_config(self, config) -> InferredConfig:
else:
raise ValueError(f"{config.task} is an unsupported task.")
if self.train is not None:
category_cols = self.train[config.categorical_cols].select_dtypes(include="category").columns
self.train[category_cols] = self.train[category_cols].astype("object")
categorical_cardinality = [
int(x) + 1 for x in list(self.train[config.categorical_cols].fillna("NA").nunique().values)
]
else:
category_cols = self.train_dataset.data[config.categorical_cols].select_dtypes(include="category").columns
self.train_dataset.data[category_cols] = self.train_dataset.data[category_cols].astype("object")
categorical_cardinality = [
int(x) + 1 for x in list(self.train_dataset.data[config.categorical_cols].nunique().values)
]
Expand Down Expand Up @@ -805,6 +809,7 @@ def train_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
num_workers=self.config.num_workers,
sampler=self.train_sampler,
pin_memory=self.config.pin_memory,
**self.config.dataloader_kwargs,
)

def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
Expand All @@ -823,6 +828,7 @@ def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
shuffle=False,
num_workers=self.config.num_workers,
pin_memory=self.config.pin_memory,
**self.config.dataloader_kwargs,
)

def _prepare_inference_data(self, df: DataFrame) -> DataFrame:
Expand Down Expand Up @@ -865,6 +871,7 @@ def prepare_inference_dataloader(
batch_size or self.batch_size,
shuffle=False,
num_workers=self.config.num_workers,
**self.config.dataloader_kwargs,
)

def save_dataloader(self, path: Union[str, Path]) -> None:
Expand Down

0 comments on commit 5fc25d6

Please sign in to comment.