Building a tabular binary classification neural network to predict Telco's Customer Churn from their publicly available dataset on Kaggle.

SpaceFrostDev/telco_kaggle

Customer Churn Prediction

To my knowledge, I developed the most accurate predictive model for Kaggle's Telco Customer Churn dataset, with a validation-set prediction accuracy of 91.30%. Below is an executive summary of how I tackle tabular binary classification problems.
[Image: perf]

Our Features

Features are displayed along with their corresponding possible unique values.
[Image: the_features]
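The notebook's print_unique() is not reproduced here. The following is a hypothetical pandas sketch of a helper that produces this kind of listing: it collects each column's unique values in order of first appearance and prints them, summarizing columns with many values as a count. The names unique_values and print_unique (as written here) and the max_shown cutoff are assumptions.

```python
import pandas as pd


def unique_values(df: pd.DataFrame) -> dict:
    """Map each column name to its unique values, in order of first appearance."""
    return {col: df[col].unique().tolist() for col in df.columns}


def print_unique(df: pd.DataFrame, max_shown: int = 10) -> None:
    """Hypothetical stand-in for the notebook's print_unique()."""
    for col, values in unique_values(df).items():
        if len(values) > max_shown:
            # Continuous or ID-like columns: report a count instead of every value.
            print(f"{col}: {len(values)} unique values")
        else:
            print(f"{col}: {values}")


# Toy frame for illustration only (not rows from the real dataset).
toy = pd.DataFrame({"Contract": ["Month-to-month", "One year", "Month-to-month"],
                    "tenure": [1, 12, 1]})
print_unique(toy)
```

Separating the pure unique_values() lookup from the printing keeps the logic easy to reuse, for instance to check cardinality before deciding between one-hot encoding and embeddings.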

The Target

In this dataset, we are predicting whether or not a customer will discontinue their current service contract with Telco.
[Image: the_targets]
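Before training, the target has to become a 0/1 label and the categorical features have to become numeric. A minimal pandas sketch, assuming the column names of the Kaggle CSV (Churn holding "Yes"/"No", an identifier column customerID, and TotalCharges stored as text with a few blank entries). The helper name split_features_target, one-hot encoding via pd.get_dummies, and filling blank TotalCharges with 0.0 are all assumed choices, not necessarily what the notebook does.

```python
import pandas as pd


def split_features_target(df: pd.DataFrame, target: str = "Churn"):
    """Hypothetical preprocessing: a 0/1 target plus one-hot encoded features."""
    # Assumed encoding: "Yes" = the customer churned (1.0), "No" = they stayed (0.0).
    y = (df[target] == "Yes").astype("float32")
    # customerID is a unique identifier, so it carries no predictive signal.
    X = df.drop(columns=[target, "customerID"], errors="ignore")
    # Assumption: TotalCharges arrives as text; blanks become NaN, then 0.0.
    if "TotalCharges" in X.columns:
        X["TotalCharges"] = pd.to_numeric(X["TotalCharges"], errors="coerce").fillna(0.0)
    # Remaining text columns have few unique values, so one-hot encoding stays small.
    X = pd.get_dummies(X, dtype=float)
    return X, y


# Toy rows for illustration only.
toy = pd.DataFrame({
    "customerID": ["a", "b", "c"],
    "Contract": ["Month-to-month", "One year", "Month-to-month"],
    "tenure": [1, 12, 0],
    "TotalCharges": ["29.85", "1889.5", " "],
    "Churn": ["No", "Yes", "No"],
})
X, y = split_features_target(toy)
```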

The libraries

To install the required libraries, run conda env create -f environment.yml from this repository's root directory. This creates the virtual environment and installs all required libraries.
The primary dependencies for this project are PyTorch, imblearn, sklearn, and pandas.

The Model

A simple dense neural network with three layers, halving in size after each. No embedding matrix is used, as there are no high-cardinality categorical variables to contend with.
[Image: deeper_model]
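A minimal PyTorch sketch of such a network. The hidden widths (64, 32, 16), the ReLU activations, reading "three layers" as three hidden layers, and n_features=45 are all assumptions chosen for illustration, not values taken from the notebook. The last layer emits a single raw logit with no sigmoid, because nn.BCEWithLogitsLoss applies the sigmoid internally.

```python
import torch
from torch import nn


class DenseNet(nn.Module):
    """Hypothetical dense network: three hidden layers, each half as wide as the last."""

    def __init__(self, n_features: int, hidden: int = 64):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(n_features, hidden),        # 64 units by default
            nn.ReLU(),
            nn.Linear(hidden, hidden // 2),       # halved: 32 units
            nn.ReLU(),
            nn.Linear(hidden // 2, hidden // 4),  # halved again: 16 units
            nn.ReLU(),
            nn.Linear(hidden // 4, 1),            # single raw logit, no sigmoid
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


# n_features=45 is an arbitrary stand-in for the width of the encoded feature matrix.
deeper_model = DenseNet(n_features=45)
logits = deeper_model(torch.randn(8, 45))
print(logits.shape)  # torch.Size([8, 1])
```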

The hyper-parameters

Loss function (criterion): nn.BCEWithLogitsLoss().
Optimizer and learning rate: optim.Adam(deeper_model.parameters(), lr=1e-2).
Validation set size: 20%.
Batch size: 827.
Instances of each class (after oversampling): 5174.
Epochs: 1000.
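The listed hyper-parameters determine how many optimizer steps a run takes. The arithmetic below is a sketch that assumes the 20% validation split is taken from the already-balanced 2 × 5174 rows and that the validation count is rounded up, as scikit-learn's train_test_split does; if the split happens before oversampling, the numbers change.

```python
import math

# Hyper-parameters as listed above.
per_class = 5174        # instances of each class after oversampling
val_fraction = 0.2      # validation set size
batch_size = 827
epochs = 1000

# Assumption: split the balanced data, rounding the validation count up.
total = 2 * per_class
n_val = math.ceil(total * val_fraction)
n_train = total - n_val
batches_per_epoch = math.ceil(n_train / batch_size)  # final batch is smaller
steps = batches_per_epoch * epochs

print(total, n_val, n_train, batches_per_epoch, steps)  # 10348 2070 8278 11 11000
```

Under these assumptions each epoch is ten full batches of 827 rows plus a final batch of 8 rows, so 1000 epochs amount to 11,000 Adam updates.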

Two custom functions

I wrote two custom functions, print_unique() and training_loop(), both of which can be found in my notebook. UPDATE: Since this repository was created, training_loop() has been significantly updated in my custom_deep_learning repository.
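The following is not the notebook's training_loop() (nor the updated version in custom_deep_learning). It is a hypothetical minimal sketch of a loop that wires together the listed hyper-parameters, with nn.BCEWithLogitsLoss, Adam at lr=1e-2, mini-batches of 827 and 1000 epochs as defaults, and reports validation accuracy after every epoch. The 0.5 probability threshold (logit > 0) and the returned history format are assumptions.

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset


def training_loop(model, X_train, y_train, X_val, y_val,
                  epochs=1000, batch_size=827, lr=1e-2):
    """Hypothetical sketch; returns (mean training loss, validation accuracy) per epoch."""
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loader = DataLoader(TensorDataset(X_train, y_train),
                        batch_size=batch_size, shuffle=True)
    history = []
    for _ in range(epochs):
        model.train()
        epoch_loss = 0.0
        for xb, yb in loader:
            optimizer.zero_grad()
            loss = criterion(model(xb).squeeze(1), yb)  # logits (B,) vs 0/1 targets (B,)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item() * len(xb)  # weight by batch size
        model.eval()
        with torch.no_grad():
            # logit > 0 is equivalent to sigmoid(logit) > 0.5
            preds = (model(X_val).squeeze(1) > 0).float()
            acc = (preds == y_val).float().mean().item()
        history.append((epoch_loss / len(X_train), acc))
    return history
```

For example, training_loop(deeper_model, X_train, y_train, X_val, y_val) with float32 feature tensors and 0/1 float targets returns one (loss, accuracy) pair per epoch, which is enough to plot learning curves or pick the best epoch.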

My 'Trick' to the most accurate model

If you've made it this far, take a peek inside my notebook. I'll give you a free cookie :) But if you're in a rush, here is the summary: the trick to achieving the highest accuracy of any model on Kaggle was to use Adam as my optimizer, which adapts per-parameter learning rates and often converges faster than plain vanilla optim.SGD; to synthetically oversample the minority class with SMOTE; and to use a neural network with sufficient capacity, which happened to be two layers. nn.BatchNorm1d, nn.Dropout and nn.Embedding are techniques that may be used to improve generalization in a model such as this, but I elected not to implement them here.
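The repository relies on imblearn for SMOTE (there it is from imblearn.over_sampling import SMOTE, then SMOTE().fit_resample(X, y)). To show what that step does, here is a from-scratch NumPy illustration of the core SMOTE idea, not imblearn's implementation: each synthetic row is placed at a random point on the segment between a minority sample and one of its k nearest minority-class neighbours. The function name smote_oversample, k=5 and the seeding are assumptions.

```python
import numpy as np


def smote_oversample(X, y, minority_label=1, k=5, seed=0):
    """Illustrative SMOTE-style oversampling until both classes are equally sized."""
    rng = np.random.default_rng(seed)
    X_min = X[y == minority_label]
    n_needed = int((y != minority_label).sum() - len(X_min))
    if n_needed <= 0:
        return X, y
    # Squared Euclidean distances between minority rows (same ordering as distances).
    sq = (X_min ** 2).sum(axis=1)
    d2 = sq[:, None] + sq[None, :] - 2 * X_min @ X_min.T
    np.fill_diagonal(d2, np.inf)  # a point is not its own neighbour
    k = min(k, len(X_min) - 1)
    neighbours = np.argsort(d2, axis=1)[:, :k]  # k nearest minority neighbours
    # Pick a random base row and one of its neighbours for each synthetic sample.
    base = rng.integers(0, len(X_min), size=n_needed)
    nbr = neighbours[base, rng.integers(0, k, size=n_needed)]
    gap = rng.random((n_needed, 1))  # interpolation factor in [0, 1)
    synthetic = X_min[base] + gap * (X_min[nbr] - X_min[base])
    X_res = np.vstack([X, synthetic])
    y_res = np.concatenate([y, np.full(n_needed, minority_label, dtype=y.dtype)])
    return X_res, y_res


rng = np.random.default_rng(1)
X = rng.normal(size=(100, 3))
y = np.array([0] * 80 + [1] * 20)  # toy 80/20 imbalance
X_res, y_res = smote_oversample(X, y)
```

Because synthetic rows are interpolations between real minority rows, they fill in the minority class's region of feature space instead of duplicating existing points, which is what distinguishes SMOTE from plain random oversampling.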
