Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

PyTorch Tabular integrations #1559

Merged
merged 3 commits into from
May 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/flash/core/integrations/pytorch_tabular/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def from_task(
"categorical_dim": len(categorical_fields),
"continuous_dim": num_features - len(categorical_fields),
"output_dim": output_dim,
"embedded_cat_dim": sum([embd_dim for _, embd_dim in embedding_sizes]),
}
return cls(
task_type,
Expand Down
4 changes: 3 additions & 1 deletion src/flash/core/integrations/pytorch_tabular/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
AutoIntConfig,
CategoryEmbeddingModelConfig,
FTTransformerConfig,
GatedAdditiveTreeEnsembleConfig,
NodeConfig,
TabNetModelConfig,
TabTransformerConfig,
Expand Down Expand Up @@ -88,8 +89,9 @@ def load_pytorch_tabular(
AutoIntConfig,
NodeConfig,
CategoryEmbeddingModelConfig,
GatedAdditiveTreeEnsembleConfig,
],
["tabnet", "tabtransformer", "fttransformer", "autoint", "node", "category_embedding"],
["tabnet", "tabtransformer", "fttransformer", "autoint", "node", "category_embedding", "gate"],
):
PYTORCH_TABULAR_BACKBONES(
functools.partial(load_pytorch_tabular, model_config_class),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@
("fttransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("autoint", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("node", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
# ("category_embedding", # todo: seems to be bug in tabular
# {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("category_embedding", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
# No categorical / numerical fields
("tabnet", {"categorical_fields": ["category"]}),
("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}),
Expand Down
6 changes: 3 additions & 3 deletions tests/tabular/classification/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class TestTabularClassifier(TaskTester):
{"backbone": "fttransformer"},
{"backbone": "autoint"},
{"backbone": "node"},
# {"backbone": "category_embedding"}, # todo: seems to be bug in tabular
{"backbone": "category_embedding"},
],
)
],
Expand All @@ -68,7 +68,7 @@ class TestTabularClassifier(TaskTester):
{"backbone": "fttransformer"},
{"backbone": "autoint"},
{"backbone": "node"},
# {"backbone": "category_embedding"}, # todo: seems to be bug in tabular
{"backbone": "category_embedding"},
],
)
],
Expand All @@ -81,7 +81,7 @@ class TestTabularClassifier(TaskTester):
{"backbone": "fttransformer"},
{"backbone": "autoint"},
{"backbone": "node"},
# {"backbone": "category_embedding"}, # todo: seems to be bug in tabular
{"backbone": "category_embedding"},
],
)
],
Expand Down
9 changes: 3 additions & 6 deletions tests/tabular/regression/test_data_model_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@
("fttransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("autoint", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("node", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
# ("category_embedding", # todo: seems to be bug in tabular
# {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("category_embedding", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
# No categorical / numerical fields
("tabnet", {"categorical_fields": ["category"]}),
("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}),
Expand Down Expand Up @@ -82,8 +81,7 @@ def test_regression_data_frame(backbone, fields, tmpdir):
("fttransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("autoint", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("node", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
# ("category_embedding", # todo: seems to be bug in tabular
# {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("category_embedding", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
# No categorical / numerical fields
("tabnet", {"categorical_fields": ["category"]}),
("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}),
Expand Down Expand Up @@ -113,8 +111,7 @@ def test_regression_dicts(backbone, fields, tmpdir):
("fttransformer", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("autoint", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("node", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
# ("category_embedding", # todo: seems to be bug in tabular
# {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
("category_embedding", {"categorical_fields": ["category"], "numerical_fields": ["scalar_a", "scalar_b"]}),
# No categorical / numerical fields
("tabnet", {"categorical_fields": ["category"]}),
("tabnet", {"numerical_fields": ["scalar_a", "scalar_b"]}),
Expand Down
6 changes: 3 additions & 3 deletions tests/tabular/regression/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class TestTabularRegressor(TaskTester):
{"backbone": "fttransformer"},
{"backbone": "autoint"},
{"backbone": "node"},
# {"backbone": "category_embedding"}, # todo: seems to be bug in tabular
{"backbone": "category_embedding"},
],
)
],
Expand All @@ -66,7 +66,7 @@ class TestTabularRegressor(TaskTester):
{"backbone": "fttransformer"},
{"backbone": "autoint"},
{"backbone": "node"},
# {"backbone": "category_embedding"}, # todo: seems to be bug in tabular
{"backbone": "category_embedding"},
],
)
],
Expand All @@ -79,7 +79,7 @@ class TestTabularRegressor(TaskTester):
{"backbone": "fttransformer"},
{"backbone": "autoint"},
{"backbone": "node"},
# {"backbone": "category_embedding"}, # todo: seems to be bug in tabular
{"backbone": "category_embedding"},
],
)
],
Expand Down