Mambular is a Python library for tabular deep learning. It includes models that leverage the Mamba (State Space Model) architecture, as well as other popular models like TabTransformer, FTTransformer, TabM, and tabular ResNets. Check out our paper Mambular: A Sequential Model for Tabular Deep Learning, available here. Also check out our paper introducing TabulaRNN and analyzing the efficiency of NLP-inspired tabular models.
- 🏃 Quickstart
- 📖 Introduction
- 🤖 Models
- 📚 Documentation
- 🛠️ Installation
- 🚀 Usage
- 💻 Implement Your Own Model
- Custom Training
- 🏷️ Citation
- License
Like any scikit-learn model, Mambular models can be fit as easily as this:
```python
from mambular.models import MambularClassifier

# Initialize and fit your model
model = MambularClassifier()

# X can be a pd.DataFrame or anything that can easily be converted into one, e.g. a np.array
model.fit(X, y, max_epochs=150, lr=1e-04)
```
Mambular is a Python package that brings the power of advanced deep learning architectures to tabular data, offering a suite of models for regression, classification, and distributional regression tasks. Designed with ease of use in mind, Mambular models adhere to scikit-learn's `BaseEstimator` interface, making them highly compatible with the familiar scikit-learn ecosystem. This means you can fit, predict, and evaluate using Mambular models just as you would with any traditional scikit-learn model, but with the added performance and flexibility of deep learning.
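Because the estimators follow the `BaseEstimator` contract, they also plug into generic scikit-learn utilities. A minimal sketch (assuming a feature matrix `X`, labels `y`, and default model settings):

```python
from sklearn.model_selection import cross_val_score
from mambular.models import MambularClassifier

# Any scikit-learn helper that accepts an estimator works out of the box
scores = cross_val_score(MambularClassifier(), X, y, cv=3, scoring="accuracy")
print("CV accuracy:", scores.mean())
```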
| Model | Description |
|---|---|
| Mambular | A sequential model using Mamba blocks, specifically designed for various tabular data tasks, introduced here. |
| TabM | Batch ensembling for an MLP, as introduced by Gorishniy et al. |
| NODE | Neural Oblivious Decision Ensembles, as introduced by Popov et al. |
| FTTransformer | A model leveraging transformer encoders, as introduced by Gorishniy et al., for tabular data. |
| MLP | A classical Multi-Layer Perceptron (MLP) model for handling tabular data tasks. |
| ResNet | An adaptation of the ResNet architecture for tabular data applications. |
| TabTransformer | A transformer-based model for tabular data, introduced by Huang et al., enhancing feature learning capabilities. |
| MambaTab | A tabular model using a Mamba block on a joint input representation, described here. Not a sequential model. |
| TabulaRNN | A recurrent neural network for tabular data, introduced here. |
| MambAttention | A combination of Mamba and Transformers, also introduced here. |
| NDTF | A neural decision forest using soft decision trees. See Kontschieder et al. for inspiration. |
All models are available for regression, classification, and distributional regression (denoted by LSS). Hence, they are available as, e.g., `MambularRegressor`, `MambularClassifier`, or `MambularLSS`.
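For example, the task-specific variants are all importable from `mambular.models`:

```python
# Every architecture follows the same <Model><Task> naming pattern,
# so e.g. FTTransformerClassifier is available analogously
from mambular.models import MambularRegressor, MambularClassifier, MambularLSS
```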
You can find the Mamba-Tabular API documentation here.
Install Mambular using pip:
```sh
pip install mambular
```
If you want to use the original mamba and mamba2 implementations, additionally install mamba-ssm via:

```sh
pip install mamba-ssm
```

Be careful to use the correct torch and CUDA versions:

```sh
pip install torch==2.0.0+cu118 torchvision==0.15.0+cu118 torchaudio==2.0.0+cu118 -f https://download.pytorch.org/whl/cu118/torch_stable.html
pip install mamba-ssm
```
Mambular simplifies data preprocessing with a range of tools designed for easy transformation of tabular data.
- Ordinal & One-Hot Encoding: Automatically transforms categorical data into numerical formats using continuous ordinal encoding or one-hot encoding. Includes options for transforming outputs to `float` for compatibility with downstream models.
- Binning: Discretizes numerical features into bins, with support for both fixed binning strategies and optimal binning derived from decision tree models.
- MinMax: Scales numerical data to a specific range, such as [-1, 1], using Min-Max scaling or similar techniques.
- Standardization: Centers and scales numerical features to have a mean of zero and unit variance for better compatibility with certain models.
- Quantile Transformations: Normalizes numerical data to follow a uniform or normal distribution, handling distributional shifts effectively.
- Spline Transformations: Captures nonlinearity in numerical features using spline-based transformations, ideal for complex relationships.
- Piecewise Linear Encodings (PLE): Captures complex numerical patterns by applying piecewise linear encoding, suitable for data with periodic or nonlinear structures.
- Polynomial Features: Automatically generates polynomial and interaction terms for numerical features, enhancing the ability to capture higher-order relationships.
- Box-Cox & Yeo-Johnson Transformations: Performs power transformations to stabilize variance and normalize distributions.
- Custom Binning: Enables user-defined bin edges for precise discretization of numerical data.
```python
from mambular.models import MambularClassifier

# Initialize and fit your model
model = MambularClassifier(
    d_model=64,
    n_layers=4,
    numerical_preprocessing="ple",
    n_bins=50,
    d_conv=8,
)

# X can be a pd.DataFrame or anything that can easily be converted into one, e.g. a np.array
model.fit(X, y, max_epochs=150, lr=1e-04)
```
Predictions are also easily obtained:
```python
# Simple predictions
preds = model.predict(X)

# Predict probabilities
preds = model.predict_proba(X)
```
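Since `predict_proba` follows scikit-learn conventions, probability-based metrics can be applied directly; a sketch, assuming a binary target and an `(n_samples, n_classes)` probability array:

```python
from sklearn.metrics import roc_auc_score

proba = model.predict_proba(X)
# Column 1 holds the positive-class probability in the binary case
print("ROC AUC:", roc_auc_score(y, proba[:, 1]))
```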
Mambular models can also be tuned with standard scikit-learn utilities such as RandomizedSearchCV:

```python
from scipy.stats import randint, uniform
from sklearn.model_selection import RandomizedSearchCV

param_dist = {
    'd_model': randint(32, 128),
    'n_layers': randint(2, 10),
    'lr': uniform(1e-5, 1e-3),
}

random_search = RandomizedSearchCV(
    estimator=model,
    param_distributions=param_dist,
    n_iter=50,           # Number of parameter settings sampled
    cv=5,                # 5-fold cross-validation
    scoring='accuracy',  # Metric to optimize
    random_state=42,
)

fit_params = {"max_epochs": 5, "rebuild": False}

# Fit the model
random_search.fit(X, y, **fit_params)

# Best parameters and score
print("Best Parameters:", random_search.best_params_)
print("Best Score:", random_search.best_score_)
```
Note that with this approach you can also optimize the preprocessing. Just use the prefix `prepro__` when specifying the preprocessor arguments you want to optimize:
```python
param_dist = {
    'd_model': randint(32, 128),
    'n_layers': randint(2, 10),
    'lr': uniform(1e-5, 1e-3),
    "prepro__numerical_preprocessing": ["ple", "standardization", "box-cox"],
}
```
Since early stopping is integrated and the best model with respect to the validation loss is returned, setting max_epochs to a large number is sensible.
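For example, a generous epoch budget can be combined with early stopping (a sketch; `patience` is the same argument shown in the `MambularLSS` example below):

```python
# Early stopping restores the best checkpoint w.r.t. validation loss,
# so max_epochs mainly acts as an upper bound
model.fit(X, y, max_epochs=500, patience=10)
```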
Or use the built-in Bayesian HPO simply by running:
```python
best_params = model.optimize_hparams(X, y)
```
This automatically sets the search space based on the default config from `mambular.configs`. See the documentation for all parameters of `optimize_hparams()`. However, the preprocessor arguments are fixed and cannot be optimized here.
MambularLSS allows you to model the full distribution of a response variable, not just its mean. This is crucial when understanding variability, skewness, or kurtosis is important. All Mambular models are available as distributional models.
- Full Distribution Modeling: Predicts the entire distribution, not just a single value, providing richer insights.
- Customizable Distribution Types: Supports various distributions (e.g., Gaussian, Poisson, Binomial) for different data types.
- Location, Scale, Shape Parameters: Predicts key distributional parameters for deeper insights.
- Enhanced Predictive Uncertainty: Offers more robust predictions by modeling the entire distribution.
- normal: For continuous data with a symmetric distribution.
- poisson: For count data within a fixed interval.
- gamma: For skewed continuous data, often used for waiting times.
- beta: For data bounded between 0 and 1, like proportions.
- dirichlet: For multivariate data with correlated components.
- studentt: For data with heavier tails, useful with small samples.
- negativebinom: For over-dispersed count data.
- inversegamma: Often used as a prior in Bayesian inference.
- categorical: For data with more than two categories.
- Quantile: For quantile regression using the pinball loss.
These distribution classes make MambularLSS versatile in modeling various data types and distributions.
To integrate distributional regression into your workflow with `MambularLSS`, start by initializing the model with your desired configuration, similar to other Mambular models:
```python
from mambular.models import MambularLSS

# Initialize the MambularLSS model
model = MambularLSS(
    dropout=0.2,
    d_model=64,
    n_layers=8,
)

# Fit the model to your data
model.fit(
    X,
    y,
    max_epochs=150,
    lr=1e-04,
    patience=10,
    family="normal",  # define your distribution
)
```
Mambular allows users to easily integrate their custom models into the existing logic. This process is designed to be straightforward, making it simple to create a PyTorch model and define its forward pass. Instead of inheriting from `nn.Module`, you inherit from Mambular's `BaseModel`. Each Mambular model takes three main arguments: the number of classes (e.g., 1 for regression or 2 for binary classification), `cat_feature_info`, and `num_feature_info` for categorical and numerical feature information, respectively. Additionally, you can provide a config argument, which can either be a custom configuration or one of the provided default configs.
One of the key advantages of using Mambular is that the inputs to the forward passes are lists of tensors. While this might be unconventional, it is highly beneficial for models that treat different data types differently. For example, the TabTransformer model leverages this feature to handle categorical and numerical data separately, applying different transformations and processing steps to each type of data.
Here's how you can implement a custom model with Mambular:
- First, define your config:
  The configuration class allows you to specify hyperparameters and other settings for your model. This can be done using a simple dataclass.

  ```python
  from dataclasses import dataclass

  @dataclass
  class MyConfig:
      lr: float = 1e-04
      lr_patience: int = 10
      weight_decay: float = 1e-06
      lr_factor: float = 0.1
  ```
- Second, define your model:
  Define your custom model just as you would for an `nn.Module`. The main difference is that you will inherit from `BaseModel` and use the provided feature information to construct your layers. To integrate your model into the existing API, you only need to define the architecture and the forward pass.

  ```python
  import torch
  import torch.nn as nn

  from mambular.base_models import BaseModel
  from mambular.utils.get_feature_dimensions import get_feature_dimensions


  class MyCustomModel(BaseModel):
      def __init__(
          self,
          cat_feature_info,
          num_feature_info,
          num_classes: int = 1,
          config=None,
          **kwargs,
      ):
          super().__init__(**kwargs)
          self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])

          input_dim = get_feature_dimensions(num_feature_info, cat_feature_info)
          self.linear = nn.Linear(input_dim, num_classes)

      def forward(self, num_features, cat_features):
          # Concatenate all feature tensors into a single input tensor
          x = num_features + cat_features
          x = torch.cat(x, dim=1)

          # Pass through linear layer
          output = self.linear(x)
          return output
  ```
- Leverage the Mambular API:
  You can build a regression, classification, or distributional regression model that can leverage all of Mambular's built-in methods by using the following:

  ```python
  from mambular.models import SklearnBaseRegressor


  class MyRegressor(SklearnBaseRegressor):
      def __init__(self, **kwargs):
          super().__init__(model=MyCustomModel, config=MyConfig, **kwargs)
  ```
- Train and evaluate your model:
  You can now fit, evaluate, and predict with your custom model just like with any other Mambular model. For classification or distributional regression, inherit from `SklearnBaseClassifier` or `SklearnBaseLSS`, respectively.

  ```python
  regressor = MyRegressor(numerical_preprocessing="ple")
  regressor.fit(X_train, y_train, max_epochs=50)
  ```
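Once fitted, the custom regressor exposes the same interface as the built-in models; a minimal sketch (assuming held-out `X_test` and `y_test`):

```python
from sklearn.metrics import mean_squared_error

preds = regressor.predict(X_test)
print("Test MSE:", mean_squared_error(y_test, preds))
```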
If you prefer to set up custom training, preprocessing, and evaluation, you can simply use the `mambular.base_models`. Just be careful that all base models expect lists of features as inputs: more precisely, a list of tensors for the numerical features and a list of tensors for the categorical features. A custom training loop with random data could look like this:
```python
import torch
import torch.nn as nn
import torch.optim as optim

from mambular.base_models import Mambular
from mambular.configs import DefaultMambularConfig

# Dummy data and configuration
cat_feature_info = {
    "cat1": {
        "preprocessing": "imputer -> continuous_ordinal",
        "dimension": 1,
        "categories": 4,
    }
}  # Example categorical feature information
num_feature_info = {
    "num1": {"preprocessing": "imputer -> scaler", "dimension": 1, "categories": None}
}  # Example numerical feature information

num_classes = 1
config = DefaultMambularConfig()  # Use the desired configuration

# Initialize model, loss function, and optimizer
model = Mambular(cat_feature_info, num_feature_info, num_classes, config)
criterion = nn.MSELoss()  # Use MSE for regression; change as appropriate for your task
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Example training loop
for epoch in range(10):  # Number of epochs
    model.train()
    optimizer.zero_grad()

    # Dummy data: one tensor per feature, batch size 32
    num_features = [torch.randn(32, 1) for _ in num_feature_info]
    cat_features = [
        torch.randint(0, 4, (32,)) for _ in cat_feature_info
    ]  # values in [0, 4) to match the 4 declared categories
    labels = torch.randn(32, num_classes)

    # Forward pass
    outputs = model(num_features, cat_features)
    loss = criterion(outputs, labels)

    # Backward pass and optimization
    loss.backward()
    optimizer.step()

    # Print loss for monitoring
    print(f"Epoch [{epoch+1}/10], Loss: {loss.item():.4f}")
```
If you find this project useful in your research, please consider citing:
```bibtex
@article{thielmann2024mambular,
  title={Mambular: A Sequential Model for Tabular Deep Learning},
  author={Thielmann, Anton Frederik and Kumar, Manish and Weisser, Christoph and Reuter, Arik and S{\"a}fken, Benjamin and Samiee, Soheila},
  journal={arXiv preprint arXiv:2408.06291},
  year={2024}
}
```
If you use TabulaRNN, please consider citing:
```bibtex
@article{thielmann2024efficiency,
  title={On the Efficiency of NLP-Inspired Methods for Tabular Deep Learning},
  author={Thielmann, Anton Frederik and Samiee, Soheila},
  journal={arXiv preprint arXiv:2411.17207},
  year={2024}
}
```
The entire codebase is under the MIT license.