Add Built-in Support for Model Stacking (#520)
* add stacking model & config

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor: Add StackingEmbeddingLayer so the `forward` override can be removed from StackingModel

* refactor: remove the use of eval to pass the ruff format check.

* fix typo

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add Stacking Model Documentation and Tutorial

- Updated API documentation to include `StackingModelConfig` and `StackingModel`.
- Added a new tutorial notebook demonstrating model stacking in PyTorch Tabular, covering setup, configuration, training, and evaluation.
- Enhanced existing documentation to explain the model stacking concept and its benefits.

This commit improves the usability and understanding of the stacking functionality in the library.

* Refactor: Remove GatedAdditiveTreeEnsembleConfig from model configuration

This commit removes the GatedAdditiveTreeEnsembleConfig lambda function from the get_model_configs function in the test_model_stacking.py file, streamlining the model configuration process. This change enhances code clarity and focuses on the relevant model configurations for stacking.

* Update mkdocs.yml to include new Model Stacking section in documentation

- Added a new entry for "Model Stacking" in the navigation structure.
- Included a link to the tutorial notebook "tutorials/16-Model Stacking.ipynb" for users to learn about model stacking.

This change enhances the documentation by providing users with direct access to resources related to model stacking.

* Refactor mkdocs.yml to streamline navigation structure

- Removed unnecessary indentation for the "Model Stacking" entry in the navigation.
- Maintained the link to the tutorial notebook "tutorials/16-Model Stacking.ipynb" for user access.

This change improves the clarity of the documentation structure without altering the content.

* Refactor StackingModelConfig to simplify model_configs type annotation

- Changed the type annotation of model_configs from list[ModelConfig] to list

* Refactor StackingBackbone forward method to remove type annotation

* Refactor StackingEmbeddingLayer to remove type annotation from forward method

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add model stacking diagram and enhance documentation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
taimo3810 and pre-commit-ci[bot] authored Dec 17, 2024
1 parent f04a05c commit e042201
Showing 11 changed files with 1,915 additions and 1 deletion.
7 changes: 6 additions & 1 deletion docs/apidocs_model.md
@@ -30,6 +30,9 @@
::: pytorch_tabular.models.TabTransformerConfig
    options:
        heading_level: 3
::: pytorch_tabular.models.StackingModelConfig
    options:
        heading_level: 3
::: pytorch_tabular.config.ModelConfig
    options:
        heading_level: 3
@@ -66,7 +69,9 @@
::: pytorch_tabular.models.TabTransformerModel
    options:
        heading_level: 3

::: pytorch_tabular.models.StackingModel
    options:
        heading_level: 3
## Base Model Class
::: pytorch_tabular.models.BaseModel
    options:
Binary file added docs/imgs/model_stacking_concept.png
24 changes: 24 additions & 0 deletions docs/models.md
@@ -253,6 +253,30 @@ All the parameters have been set to recommended values from the paper. Let's look at them:
**For a complete list of parameters refer to the API Docs**
[pytorch_tabular.models.DANetConfig][]

## Model Stacking

Model stacking is an ensemble learning technique that combines multiple base models to create a more powerful predictive model. Each base model processes the input features independently, and their outputs are concatenated before making the final prediction. This allows the model to leverage different learning patterns captured by each backbone architecture. You can use it by choosing `StackingModelConfig`.

The following diagram shows the concept of model stacking in PyTorch Tabular.
![Model Stacking](imgs/model_stacking_concept.png)

The following model architectures are supported for stacking:
- Category Embedding Model
- TabNet Model
- FTTransformer Model
- Gated Additive Tree Ensemble Model
- DANet Model
- AutoInt Model
- GANDALF Model
- Node Model

All the parameters have been set to provide flexibility while maintaining ease of use. Let's look at them:

- `model_configs`: List[ModelConfig]: List of configurations for each base model. Each config should be a valid PyTorch Tabular model config (e.g., NodeConfig, GANDALFConfig)

**For a complete list of parameters refer to the API Docs**
[pytorch_tabular.models.StackingModelConfig][]
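As a rough sketch of how a stacked model might be put together (the dataframe, column names, and choice of base models below are illustrative assumptions, not requirements):

```python
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import (
    CategoryEmbeddingModelConfig,
    GANDALFConfig,
    StackingModelConfig,
)

# Assumed example data: a dataframe `train` with two numeric columns,
# one categorical column, and a binary target column.
data_config = DataConfig(
    target=["target"],
    continuous_cols=["num_a", "num_b"],
    categorical_cols=["cat_a"],
)

# Each entry in model_configs is an ordinary PyTorch Tabular model config;
# the stacked model concatenates the backbones' outputs before the head.
model_config = StackingModelConfig(
    task="classification",
    model_configs=[
        CategoryEmbeddingModelConfig(task="classification"),
        GANDALFConfig(task="classification"),
    ],
)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=OptimizerConfig(),
    trainer_config=TrainerConfig(max_epochs=5),
)
tabular_model.fit(train=train)
```

Training, evaluation, and prediction then follow the usual `TabularModel` workflow.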

## Implementing New Architectures

PyTorch Tabular is very easy to extend and infinitely customizable. All the models implemented in PyTorch Tabular inherit from an abstract class, `BaseModel`, which is in fact a PyTorch Lightning model.
1,486 changes: 1,486 additions & 0 deletions docs/tutorials/16-Model Stacking.ipynb


1 change: 1 addition & 0 deletions mkdocs.yml
@@ -24,6 +24,7 @@ nav:
- SHAP, Deep LIFT and so on through Captum Integration: "tutorials/14-Explainability.ipynb"
- Custom PyTorch Models:
- Implementing New Supervised Architectures: "tutorials/04-Implementing New Architectures.ipynb"
- Model Stacking: "tutorials/16-Model Stacking.ipynb"
- Other Features:
- Using Neural Categorical Embeddings in Scikit-Learn Workflows: "tutorials/03-Neural Embedding in Scikit-Learn Workflows.ipynb"
- Self-Supervised Learning using Denoising Autoencoders: "tutorials/08-Self-Supervised Learning-DAE.ipynb"
4 changes: 4 additions & 0 deletions src/pytorch_tabular/models/__init__.py
@@ -19,6 +19,7 @@
from .gate import GatedAdditiveTreeEnsembleConfig, GatedAdditiveTreeEnsembleModel
from .mixture_density import MDNConfig, MDNModel
from .node import NodeConfig, NODEModel
from .stacking import StackingModel, StackingModelConfig
from .tab_transformer import TabTransformerConfig, TabTransformerModel
from .tabnet import TabNetModel, TabNetModelConfig

@@ -45,6 +46,8 @@
"GANDALFBackbone",
"DANetConfig",
"DANetModel",
"StackingModel",
"StackingModelConfig",
"category_embedding",
"node",
"mixture_density",
@@ -55,4 +58,5 @@
"gate",
"gandalf",
"danet",
"stacking",
]
4 changes: 4 additions & 0 deletions src/pytorch_tabular/models/stacking/__init__.py
@@ -0,0 +1,4 @@
from .config import StackingModelConfig
from .stacking_model import StackingBackbone, StackingModel

__all__ = ["StackingModel", "StackingModelConfig", "StackingBackbone"]
26 changes: 26 additions & 0 deletions src/pytorch_tabular/models/stacking/config.py
@@ -0,0 +1,26 @@
from dataclasses import dataclass, field

from pytorch_tabular.config import ModelConfig


@dataclass
class StackingModelConfig(ModelConfig):
"""StackingModelConfig is a configuration class for the StackingModel. It is used to stack multiple models
together. Now, CategoryEmbeddingModel, TabNetModel, FTTransformerModel, GatedAdditiveTreeEnsembleModel, DANetModel,
AutoIntModel, GANDALFModel, NodeModel are supported.
Args:
model_configs (list[ModelConfig]): List of model configs to stack.
"""

model_configs: list = field(default_factory=list, metadata={"help": "List of model configs to stack"})
_module_src: str = field(default="models.stacking")
_model_name: str = field(default="StackingModel")
_backbone_name: str = field(default="StackingBackbone")
_config_name: str = field(default="StackingConfig")


# if __name__ == "__main__":
# from pytorch_tabular.utils import generate_doc_dataclass
# print(generate_doc_dataclass(StackingModelConfig))
140 changes: 140 additions & 0 deletions src/pytorch_tabular/models/stacking/stacking_model.py
@@ -0,0 +1,140 @@
import inspect

import torch
import torch.nn as nn
from omegaconf import DictConfig

import pytorch_tabular.models as models
from pytorch_tabular.models import BaseModel
from pytorch_tabular.models.common.heads import blocks
from pytorch_tabular.models.gate import GatedAdditiveTreesBackbone
from pytorch_tabular.models.node import NODEBackbone


def instantiate_backbone(hparams, backbone_name):
    # Resolve the backbone class from its source module (e.g. models.tabnet) by name
    backbone_class = getattr(getattr(models, hparams._module_src.split(".")[-1]), backbone_name)
    class_args = list(inspect.signature(backbone_class).parameters.keys())
    if "config" in class_args:
        # Backbones that accept a single config object
        return backbone_class(config=hparams)
    else:
        # Backbones that take individual keyword arguments; block_activation is stored
        # as a string in the config and is instantiated from torch.nn here
        return backbone_class(
            **{
                arg: getattr(hparams, arg) if arg != "block_activation" else getattr(nn, getattr(hparams, arg))()
                for arg in class_args
            }
        )


class StackingEmbeddingLayer(nn.Module):
    def __init__(self, embedding_layers: nn.ModuleList):
        super().__init__()
        self.embedding_layers = embedding_layers

    def forward(self, x):
        outputs = []
        for embedding_layer in self.embedding_layers:
            em_output = embedding_layer(x)
            outputs.append(em_output)
        return outputs


class StackingBackbone(nn.Module):
    def __init__(self, config: DictConfig):
        super().__init__()
        self.hparams = config
        self._build_network()

    def _build_network(self):
        self._backbones = nn.ModuleList()
        self._heads = nn.ModuleList()
        self._backbone_output_dims = []
        assert len(self.hparams.model_configs) > 0, "Stacking requires at least one model config"
        for model_i in range(len(self.hparams.model_configs)):
            # move necessary params to each model config
            self.hparams.model_configs[model_i].embedded_cat_dim = self.hparams.embedded_cat_dim
            self.hparams.model_configs[model_i].continuous_dim = self.hparams.continuous_dim
            self.hparams.model_configs[model_i].n_continuous_features = self.hparams.continuous_dim

            self.hparams.model_configs[model_i].embedding_dims = self.hparams.embedding_dims
            self.hparams.model_configs[model_i].categorical_cardinality = self.hparams.categorical_cardinality
            self.hparams.model_configs[model_i].categorical_dim = self.hparams.categorical_dim
            self.hparams.model_configs[model_i].cat_embedding_dims = self.hparams.embedding_dims

            # if output_dim is not set, default it to 128
            if getattr(self.hparams.model_configs[model_i], "output_dim", None) is None:
                self.hparams.model_configs[model_i].output_dim = 128

            # propagate inferred_config to the model config when it is available
            if getattr(self.hparams, "inferred_config", None) is not None:
                self.hparams.model_configs[model_i].inferred_config = self.hparams.inferred_config

            # instantiate backbone
            _backbone = instantiate_backbone(
                self.hparams.model_configs[model_i], self.hparams.model_configs[model_i]._backbone_name
            )
            # set continuous_dim
            _backbone.continuous_dim = self.hparams.continuous_dim
            # if the backbone does not expose output_dim, take it from the model config
            if getattr(_backbone, "output_dim", None) is None:
                setattr(
                    _backbone,
                    "output_dim",
                    self.hparams.model_configs[model_i].output_dim,
                )
            self._backbones.append(_backbone)
            self._backbone_output_dims.append(_backbone.output_dim)

        self.output_dim = sum(self._backbone_output_dims)

    def _build_embedding_layer(self):
        assert getattr(self, "_backbones", None) is not None, "Backbones are not built"
        embedding_layers = nn.ModuleList()
        for backbone in self._backbones:
            if getattr(backbone, "_build_embedding_layer", None) is None:
                embedding_layers.append(nn.Identity())
            else:
                embedding_layers.append(backbone._build_embedding_layer())
        return StackingEmbeddingLayer(embedding_layers)

    def forward(self, x_list):
        outputs = []
        for i, backbone in enumerate(self._backbones):
            bb_output = backbone(x_list[i])
            # GATE and NODE backbones return 3D tensors; reduce them to (batch, features)
            if len(bb_output.shape) == 3 and isinstance(backbone, GatedAdditiveTreesBackbone):
                bb_output = bb_output.mean(dim=-1)
            elif len(bb_output.shape) == 3 and isinstance(backbone, NODEBackbone):
                bb_output = bb_output.mean(dim=1)
            outputs.append(bb_output)
        x = torch.cat(outputs, dim=1)
        return x


class StackingModel(BaseModel):
    def __init__(self, config: DictConfig, **kwargs):
        super().__init__(config, **kwargs)

    def _build_network(self):
        self._backbone = StackingBackbone(self.hparams)
        self._embedding_layer = self._backbone._build_embedding_layer()
        self.output_dim = self._backbone.output_dim
        self._head = self._get_head_from_config()

    def _get_head_from_config(self):
        _head_callable = getattr(blocks, self.hparams.head)
        return _head_callable(
            in_units=self.output_dim,
            output_dim=self.hparams.output_dim,
            config=_head_callable._config_template(**self.hparams.head_config),
        )

    @property
    def backbone(self):
        return self._backbone

    @property
    def embedding_layer(self):
        return self._embedding_layer

    @property
    def head(self):
        return self._head
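The central idea in `StackingBackbone.forward` above is that each backbone contributes a `(batch, output_dim_i)` tensor and the head receives their concatenation, so the head's `in_units` is the sum of the per-backbone output dims. A standalone toy illustration of that shape arithmetic (not library code, widths chosen arbitrarily):

```python
import torch

# Two hypothetical backbone outputs for a batch of 8 rows,
# with output_dim 64 and 32 respectively.
out_a = torch.randn(8, 64)
out_b = torch.randn(8, 32)

stacked = torch.cat([out_a, out_b], dim=1)
print(stacked.shape)  # torch.Size([8, 96]) -- the head sees 64 + 32 features
```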
1 change: 1 addition & 0 deletions src/pytorch_tabular/models/tabnet/config.py
@@ -129,6 +129,7 @@ class TabNetModelConfig(ModelConfig):
_module_src: str = field(default="models.tabnet")
_model_name: str = field(default="TabNetModel")
_config_name: str = field(default="TabNetModelConfig")
_backbone_name: str = field(default="TabNetBackbone")


# if __name__ == "__main__":