
[BUG] Random forest classification accuracy gap #3764

Closed
RAMitchell opened this issue Apr 18, 2021 · 6 comments · Fixed by #3869 or #4191

Labels: bug (Something isn't working)

@RAMitchell (Contributor)

This is a specific diagnosis of the random forest classification accuracy issue described in #2518. The following script reproduces the accuracy gap between cuml and sklearn on a small synthetic dataset, and implements a hacky fix in Python that demonstrates how the gap can be closed.

from cuml import RandomForestClassifier as cuRF
from sklearn.ensemble import RandomForestClassifier as sklRF
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # select a non-interactive backend before importing pyplot
import matplotlib.pyplot as plt
import seaborn as sns
import json

sns.set()

def update_tree_statistics(X, y, tree):
    # Route each training instance to its leaf and accumulate its label
    # into the leaf's positive-class count ('true_count').
    for i in range(X.shape[0]):
        node = tree
        while True:
            if 'leaf_value' in node:
                if 'true_count' not in node:
                    node['true_count'] = 0
                node['true_count'] += y[i]
                break
            else:
                split_feature = node['split_feature']
                split_threshold = node['split_threshold']
                if X[i, split_feature] <= split_threshold + 1e-5:
                    node = node['children'][0]
                else:
                    node = node['children'][1]

def add_predict(X, tree, proba):
    # Add this tree's class-1 probability estimate (true_count / instance_count)
    # for each instance into the running totals in proba.
    for i in range(X.shape[0]):
        node = tree
        while True:
            if 'leaf_value' in node:
                proba[i, 1] += node['true_count'] / node['instance_count']
                break
            else:
                split_feature = node['split_feature']
                split_threshold = node['split_threshold']
                if X[i, split_feature] <= split_threshold + 1e-5:
                    node = node['children'][0]
                else:
                    node = node['children'][1]

def custom_rf_predict(X, y, cuml_clf, predict_proba=False):
    # Rebuild leaf class statistics from the training data, then soft-vote
    # across trees the way sklearn does.
    forest = json.loads(cuml_clf.get_json())
    proba = np.zeros((X.shape[0], 2))
    # First get leaf statistics
    for tree in forest:
        update_tree_statistics(X, y, tree)
    for tree in forest:
        add_predict(X, tree, proba)
    proba[:, 1] /= len(forest)
    proba[:, 0] = 1 - proba[:, 1]
    if predict_proba:
        return proba
    return np.argmax(proba, axis=1)
X, y = make_classification(n_samples=20, n_features=2,
                           n_informative=2, n_redundant=0,
                           random_state=0, shuffle=False, flip_y=0.0)
X = X.astype(np.float32)
y = y.astype(np.float32)
rs = np.random.RandomState(92)
rows = []  # collected into a DataFrame after the loop (DataFrame.append was removed in pandas 2.0)
depths = list(range(1, 4))
n_repeats = 100
bootstrap = False
max_samples = 1.0
max_features = 0.5
n_estimators = 10
n_bins = min(X.shape[0], 256)
worst_cuml = None
worst_skl = None
for d in depths:
    for _ in range(n_repeats):
        clf = sklRF(n_estimators=n_estimators, max_depth=d, random_state=rs,
                    max_features=max_features, bootstrap=bootstrap,
                    max_samples=max_samples if max_samples < 1.0 else None)
        clf.fit(X, y)
        pred = clf.predict(X)
        cu_clf = cuRF(n_estimators=n_estimators, max_depth=d, random_state=rs.randint(0, 1 << 32),
                      n_bins=n_bins, max_features=max_features, bootstrap=bootstrap,
                      max_samples=max_samples)
        cu_clf.fit(X, y)
        cu_pred = cu_clf.predict(X, predict_model='CPU')
        adjusted_cu_pred = custom_rf_predict(X, y, cu_clf)
        skl_accuracy = accuracy_score(y, pred)
        cu_accuracy = accuracy_score(y, cu_pred)
        adjusted_cu_accuracy = accuracy_score(y, adjusted_cu_pred)
        if cu_accuracy < 1.0:  # keep the most recent imperfect model for inspection
            worst_cuml = cu_clf
            worst_skl = clf
        rows.append({"algorithm": "cuml", "accuracy": cu_accuracy, "depth": d})
        rows.append({"algorithm": "adjusted_cuml", "accuracy": adjusted_cu_accuracy, "depth": d})
        rows.append({"algorithm": "sklearn", "accuracy": skl_accuracy, "depth": d})
df = pd.DataFrame(rows, columns=["algorithm", "accuracy", "depth"])
sns.lineplot(data=df, x="depth", y="accuracy", hue="algorithm")
plt.savefig("rf.png")

[image: rf.png — accuracy vs. depth for cuml, adjusted_cuml, and sklearn]

Diagnosis

Sklearn predicts labels in random forest classifiers by obtaining class probabilities from each component tree, averaging these probabilities over the ensemble members, and finally outputting the label with the highest average probability.

Cuml's rf instead generates a label prediction from each tree (rather than a probability) and then outputs the mode (the most frequently occurring label).

Consider an ensemble containing two decision stumps, making a binary (0-1) classification prediction for an instance x that ends up in the left leaf of both trees. The tree leaves store statistics indicating how many training instances were positive or negative:

Tree A -> num_positive = 11, num_negative = 9
Tree B -> num_positive = 1,  num_negative = 19

Note that Tree A predicts 1 with low confidence (p(1) = 11/20) while Tree B predicts 0 with high confidence (p(0) = 19/20).

Sklearn averages the estimates of both trees to obtain probabilities p(0) = 28/40, p(1) = 12/40, and so outputs label 0.

cuml instead obtains majority predictions, 1 from Tree A and 0 from Tree B, yielding probability estimates p(0) = 1/2, p(1) = 1/2; the output label is then 0 or 1 depending on the tie-breaking rule.

So cuml discards confidence information from individual trees, leading to less accurate predictions.
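The difference can be checked numerically. The following is a toy illustration of the two voting schemes for the two-stump example above (plain NumPy, not cuml's actual code):

```python
import numpy as np

# Leaf statistics for the two stumps from the example above
# (counts of positive / negative training instances at the reached leaf).
tree_a = {"pos": 11, "neg": 9}    # p(1) = 11/20, weak vote for 1
tree_b = {"pos": 1,  "neg": 19}   # p(0) = 19/20, strong vote for 0

# sklearn-style soft voting: average per-tree class probabilities.
probs = np.mean([[t["neg"] / (t["pos"] + t["neg"]),
                  t["pos"] / (t["pos"] + t["neg"])]
                 for t in (tree_a, tree_b)], axis=0)
soft_label = int(np.argmax(probs))   # probs = [28/40, 12/40] -> label 0

# cuml-style hard voting: each tree emits its argmax label first.
votes = [int(t["pos"] > t["neg"]) for t in (tree_a, tree_b)]  # [1, 0]
# A 1-1 tie: the output depends on the tie-breaking rule, and Tree B's
# high confidence has been thrown away.
```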

Fix

Class label statistics must be stored at the leaves in order to output the same probability scores as sklearn. In the case of multiclass classification, this means storing a vector of class counts at each leaf node. As cuml uses FIL for GPU predictions, FIL must also support predictions from vector leaves.
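As a sketch of the idea (the node layout and names here are illustrative, not cuml's or FIL's actual data structures): each leaf stores a per-class count vector, emits a normalized probability vector at prediction time, and the ensemble averages these vectors before taking the argmax, matching sklearn's soft-voting behaviour.

```python
import numpy as np

# Hypothetical vector leaf: one count per class instead of a single
# majority label (3-class example).
leaf_counts = np.array([3.0, 11.0, 6.0])   # class counts seen at this leaf

# At prediction time the leaf emits a probability vector ...
leaf_proba = leaf_counts / leaf_counts.sum()   # [0.15, 0.55, 0.30]

# ... and the ensemble averages these vectors across trees before
# taking the argmax, preserving each tree's confidence.
other_leaf_proba = np.array([0.5, 0.2, 0.3])   # second tree's leaf output
ensemble_proba = (leaf_proba + other_leaf_proba) / 2
label = int(np.argmax(ensemble_proba))
```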

@RAMitchell RAMitchell added ? - Needs Triage Need team to review and classify bug Something isn't working labels Apr 18, 2021
@viclafargue viclafargue removed the ? - Needs Triage Need team to review and classify label Apr 19, 2021
@teju85 (Member)

teju85 commented Apr 20, 2021

@RAMitchell Thank you for finding this! We certainly need to dig more into this one. Can you please post the same plot when you increase the depth, e.g. all the way to 16?

EDIT: on second thought, @venkywonka can you please repeat the above experiment at deeper depths and report back?

@venkywonka (Contributor)

> @RAMitchell Thank you for finding this! We certainly need to dig more into this one. Can you please post the same plot when you increase the depth, e.g. all the way to 16?
>
> EDIT: on second thought, @venkywonka can you please repeat the above experiment at deeper depths and report back?

sure @teju85

@venkywonka (Contributor)

venkywonka commented Apr 20, 2021

For the given dataset generated with make_classification parameters (n_samples=20, n_features=2, n_redundant=0), the models quickly start overfitting at depth 4, as seen below:

[image]

So I introduced some noise by adding redundant features, but the models still quickly overfit beyond depth 4.

I tried bumping up the samples from 20 to 20,000, with n_informative=3 and n_redundant=2, for deeper depths (at least until they start overfitting), and we get something like this:

[image]

EDIT: The following gives a better idea of how they fare for some variations of n_samples, n_features, and n_informative
(ignore the blank plots; they correspond to error-handled invalid make_classification parameter combinations in the sweep):

[image]


TLDR

  • There seems to be a marginal bias introduced by cuml's method of ensembling, compared to sklearn's, that persists through increasing depths, causing a slight lag (with respect to depth) in overfitting. For now I will update the docs to notify users of this, but it is definitely something to fix in the future 😅
  • I am also running this on real-world datasets like Higgs, but they take too long with sklearn, so I will update here once that gets done ✌🏾

@venkywonka (Contributor)

venkywonka commented Apr 22, 2021

In my previous plots I seem to have incorrectly added noise, apologies 😅 (thanks for pointing that out, @vinaydes).
I have re-run the same experiment, this time adding noise through the flip_y parameter of sklearn.datasets.make_classification, so that the models cannot trivially reach accuracy = 1.0 (which is not very informative).
The following plot uses @RAMitchell's initial (and very helpful) script, but with

X, y = make_classification(n_samples=2000, n_features=2000, flip_y=0.1, random_state=0)

This is so that it is more representative of real-world datasets. We see something like this:

[image]


TLDR

  • For some reason, cuml (and adjusted cuml) seem to fit better than sklearn 🤔
  • The accuracy issue (the gap between adjusted cuml and cuml) seems to shrink at higher depths, and the disparity at depth < 3 does not seem to be representative of the overall behaviour.

@RAMitchell (Contributor, Author)

If you are adding large amounts of noise to the dataset, you might consider regenerating the dataset inside the repetitions so that the error bars also include dataset variation.
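One way to do that, as a minimal sketch (sizes and repeat count reduced here for brevity; the model fitting and scoring from the script above are elided):

```python
import numpy as np
from sklearn.datasets import make_classification

rs = np.random.RandomState(92)
n_repeats = 3  # kept small here; the experiments above use 100

for rep in range(n_repeats):
    # Draw a fresh dataset each repetition so the error bars reflect
    # dataset variation as well as the randomness of the forests.
    X, y = make_classification(n_samples=200, n_features=20, flip_y=0.1,
                               random_state=rs.randint(0, 1 << 32))
    X = X.astype(np.float32)
    y = y.astype(np.float32)
    # ... fit and score the sklearn / cuml classifiers as in the script above ...
```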

rapids-bot bot pushed a commit that referenced this issue Apr 22, 2021
#3776)

This small PR adds details regarding accuracy issue detailed [here](#3764) as a known limitation for users of Random Forest Classifier.

Authors:
  - Venkat (https://github.com/venkywonka)

Approvers:
  - Philip Hyunsu Cho (https://github.com/hcho3)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #3776
@hcho3 (Contributor)

hcho3 commented May 17, 2021

@RAMitchell I re-ran your script with my prototype #3862 and now the accuracy gap is closed:
[image: rf.png — re-run with the prototype; the accuracy gap is closed]

rapids-bot bot pushed a commit that referenced this issue May 27, 2021
…ementaton) (#3869)

Alternative implementation of #3862 that does not depend on #3854
Closes #3764
Closes #2518

Authors:
  - Philip Hyunsu Cho (https://github.com/hcho3)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)
  - Vinay Deshpande (https://github.com/vinaydes)

URL: #3869
rapids-bot bot pushed a commit that referenced this issue Sep 17, 2021
Fixes #3764, #2518

To do:
- post charts confirming the improvement in accuracy
- address python tests
- benchmark

Authors:
  - Rory Mitchell (https://github.com/RAMitchell)

Approvers:
  - Vinay Deshpande (https://github.com/vinaydes)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #4191