[ENH] Add sample weight for fit (#7)
* feat(fit): add sample_weight

* fix(pytest): remove unused operator

* Polish up docstrings

* Update CHANGELOG.rst

Co-authored-by: Yi-Xuan Xu <xuyx@lamda.nju.edu.cn>
tczhao and xuyxu authored Feb 3, 2021
1 parent e19358d commit 5e80183
Showing 6 changed files with 63 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -31,5 +31,6 @@ Version 0.1.*
 .. |Fix| replace:: :raw-html:`<span class="badge badge-danger">Fix</span>` :raw-latex:`{\small\sc [Fix]}`
 .. |API| replace:: :raw-html:`<span class="badge badge-warning">API Change</span>` :raw-latex:`{\small\sc [API Change]}`
 
+- |Feature| support sample weight in :meth:`fit` (`#7 <https://github.com/LAMDA-NJU/Deep-Forest/pull/7>`__) @tczhao
 - |Feature| configurable predictor parameter (`#9 <https://github.com/LAMDA-NJU/Deep-Forest/issues/10>`__) @tczhao
 - |Enhancement| add base class ``BaseEstimator`` and ``ClassifierMixin`` (`#8 <https://github.com/LAMDA-NJU/Deep-Forest/pull/8>`__) @pjgao
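
For context, a minimal usage sketch of the feature this commit adds. The dataset and weighting scheme below are invented for illustration; `CascadeForestClassifier` and its `fit`/`predict` methods are the package's public API:

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

from deepforest import CascadeForestClassifier

X_train, X_test, y_train, y_test = train_test_split(
    *load_iris(return_X_y=True), random_state=42
)

# Up-weight samples of class 0; all other samples keep unit weight.
sample_weight = np.where(y_train == 0, 2.0, 1.0)

model = CascadeForestClassifier()
model.fit(X_train, y_train, sample_weight=sample_weight)  # new in this commit
y_pred = model.predict(X_test)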
4 changes: 2 additions & 2 deletions deepforest/_estimator.py
@@ -61,8 +61,8 @@ def __init__(
     def oob_decision_function_(self):
         return self.estimator_.oob_decision_function_
 
-    def fit_transform(self, X, y):
-        self.estimator_.fit(X, y)
+    def fit_transform(self, X, y, sample_weight=None):
+        self.estimator_.fit(X, y, sample_weight)
         X_aug = self.estimator_.oob_decision_function_
 
         return X_aug
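
Note that `sample_weight` is forwarded positionally to the wrapped forest's `fit`. deepforest ships its own forest implementation, but the call pattern mirrors scikit-learn's `fit(X, y, sample_weight=None)`; a quick sketch against the scikit-learn equivalent, for illustration only:

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

X, y = load_iris(return_X_y=True)
w = [1.0] * len(y)

# bootstrap=True + oob_score=True makes oob_decision_function_ available,
# which is what fit_transform above returns as the augmented features.
clf = RandomForestClassifier(n_estimators=20, bootstrap=True, oob_score=True)
clf.fit(X, y, w)  # positional third argument == sample_weight=w
print(clf.oob_decision_function_.shape)  # (150, 3)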
11 changes: 7 additions & 4 deletions deepforest/_layer.py
@@ -20,15 +20,16 @@ def _build_estimator(
     oob_decision_function,
     partial_mode=True,
     buffer=None,
-    verbose=1
+    verbose=1,
+    sample_weight=None
 ):
     """Private function used to fit a single estimator."""
     if verbose > 1:
         msg = "{} - Fitting estimator = {:<5} in layer = {}"
         key = estimator_name + "_" + str(estimator_idx)
         print(msg.format(_utils.ctime(), key, layer_idx))
 
-    X_aug_train = estimator.fit_transform(X, y)
+    X_aug_train = estimator.fit_transform(X, y, sample_weight)
     oob_decision_function += estimator.oob_decision_function_
 
     if partial_mode:
@@ -107,7 +108,7 @@ def _validate_params(self):
             msg = "`n_trees` = {} should be strictly positive."
             raise ValueError(msg.format(self.n_trees))
 
-    def fit_transform(self, X, y):
+    def fit_transform(self, X, y, sample_weight=None):
 
         self._validate_params()
         n_samples, _ = X.shape
@@ -128,6 +129,7 @@ def fit_transform(self, X, y):
                 self.partial_mode,
                 self.buffer,
                 self.verbose,
+                sample_weight,
             )
             X_aug.append(X_aug_)
             key = "{}-{}-{}".format(self.layer_idx, estimator_idx, "rf")
@@ -145,6 +147,7 @@ def fit_transform(self, X, y):
                 self.partial_mode,
                 self.buffer,
                 self.verbose,
+                sample_weight,
             )
             X_aug.append(X_aug_)
             key = "{}-{}-{}".format(self.layer_idx, estimator_idx, "erf")
@@ -153,7 +156,7 @@ def fit_transform(self, X, y):
         # Set the OOB estimations and validation accuracy
         self.oob_decision_function_ = oob_decision_function / self.n_estimators
         y_pred = np.argmax(oob_decision_function, axis=1)
-        self.val_acc_ = accuracy_score(y, y_pred)
+        self.val_acc_ = accuracy_score(y, y_pred, sample_weight=sample_weight)
 
         X_aug = np.hstack(X_aug)
         return X_aug
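
Because the weights also reach `accuracy_score`, the per-layer validation accuracy becomes a weighted accuracy. A standalone illustration of scikit-learn's weighted accuracy, with toy arrays invented for this example:

import numpy as np
from sklearn.metrics import accuracy_score

y_true = np.array([0, 0, 1, 1])
y_pred = np.array([0, 1, 1, 1])

print(accuracy_score(y_true, y_pred))  # unweighted: 3 / 4 = 0.75

# Down-weighting the one misclassified sample raises the weighted score:
# (1.0 + 1.0 + 1.0) / (1.0 + 0.1 + 1.0 + 1.0) ~= 0.968
w = np.array([1.0, 0.1, 1.0, 1.0])
print(accuracy_score(y_true, y_pred, sample_weight=w))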
19 changes: 15 additions & 4 deletions deepforest/cascade.py
@@ -422,7 +422,8 @@ def predict(self, X):
     def n_aug_features_(self):
         return 2 * self.n_estimators * self.n_outputs_
 
-    def fit(self, X, y):
+    # flake8: noqa: E501
+    def fit(self, X, y, sample_weight=None):
         """
         Build a deep forest using the training data.
 
@@ -449,6 +450,8 @@ def fit(self, X, y):
             ``np.uint8``.
         y : :obj:`numpy.ndarray` of shape (n_samples,)
             The class labels of input samples.
+        sample_weight : :obj:`numpy.ndarray` of shape (n_samples,), default=None
+            Sample weights. If ``None``, then samples are equally weighted.
         """
         self._check_input(X, y)
         self._validate_params()
@@ -491,7 +494,7 @@ def fit(self, X, y):
             print("{} Fitting cascade layer = {:<2}".format(_utils.ctime(), 0))
 
         tic = time.time()
-        X_aug_train_ = layer_.fit_transform(X_train_, y)
+        X_aug_train_ = layer_.fit_transform(X_train_, y, sample_weight)
         toc = time.time()
         training_time = toc - tic
 
@@ -567,7 +570,11 @@ def fit(self, X, y):
                 print(msg.format(_utils.ctime(), layer_idx))
 
             tic = time.time()
-            X_aug_train_ = layer_.fit_transform(X_middle_train_, y)
+            X_aug_train_ = layer_.fit_transform(
+                X_middle_train_,
+                y,
+                sample_weight
+            )
             toc = time.time()
             training_time = toc - tic
 
@@ -667,7 +674,11 @@ def fit(self, X, y):
                 print(msg.format(_utils.ctime(), self.predictor_name))
 
             tic = time.time()
-            self.predictor_.fit(X_middle_train_, y)
+            self.predictor_.fit(
+                X_middle_train_,
+                y,
+                sample_weight=sample_weight
+            )
             toc = time.time()
 
             if self.verbose > 0:
7 changes: 6 additions & 1 deletion deepforest/forest.py
@@ -96,8 +96,9 @@ def _parallel_build_trees(
     X,
     y,
     n_samples_bootstrap,
+    sample_weight,
     out,
-    lock
+    lock,
 ):
     """
     Private function used to fit a single tree in parallel."""
@@ -107,8 +108,11 @@ def _parallel_build_trees(
                                          n_samples_bootstrap)
 
     # Fit the tree on the bootstrapped samples
+    if sample_weight is not None:
+        sample_weight = sample_weight[sample_mask]
     feature, threshold, children, value = tree.fit(X[sample_mask],
                                                    y[sample_mask],
+                                                   sample_weight=sample_weight,
                                                    check_input=False)
 
     if not children.flags["C_CONTIGUOUS"]:
@@ -422,6 +426,7 @@ def fit(self, X, y, sample_weight=None):
                     X,
                     y,
                     n_samples_bootstrap,
+                    sample_weight,
                     oob_decision_function,
                     lock)
                 for i, t in enumerate(trees))
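
The subtle point in `_parallel_build_trees` is that the weights must be subset with the same bootstrap mask as `X` and `y`, so each tree sees weights aligned row-for-row with its own bootstrap sample. A minimal standalone sketch of that alignment; the mask generation below is an invented stand-in for the private `_generate_sample_mask` helper:

import numpy as np

rng = np.random.RandomState(0)
n_samples, n_samples_bootstrap = 8, 8

X = rng.rand(n_samples, 3)
y = rng.randint(0, 2, n_samples)
sample_weight = rng.rand(n_samples)

# Stand-in bootstrap mask: indices sampled with replacement.
sample_mask = rng.randint(0, n_samples, n_samples_bootstrap)

# Indexing X, y, and sample_weight with the same mask keeps the i-th
# weight attached to the i-th bootstrapped row.
X_boot, y_boot, w_boot = X[sample_mask], y[sample_mask], sample_weight[sample_mask]
assert X_boot.shape[0] == y_boot.shape[0] == w_boot.shape[0]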
33 changes: 32 additions & 1 deletion tests/test_model.py
@@ -1,7 +1,8 @@
 import copy
 import pytest
 import shutil
-from numpy.testing import assert_array_equal
+import numpy as np
+from numpy.testing import assert_array_equal, assert_raises
 from sklearn.datasets import load_iris
 from sklearn.model_selection import train_test_split
 
@@ -126,6 +127,36 @@ def test_model_workflow_partial_mode():
     shutil.rmtree(save_dir)
 
 
+def test_model_sample_weight():
+    """Test the behavior of deep forest fitted with sample weights."""
+
+    case_kwargs = copy.deepcopy(kwargs)
+
+    # Training without sample_weight
+    model = CascadeForestClassifier(**case_kwargs)
+    model.fit(X_train, y_train)
+    y_pred_no_sample_weight = model.predict(X_test)
+
+    # Training with equal sample_weight
+    model = CascadeForestClassifier(**case_kwargs)
+    sample_weight = np.ones(y_train.size)
+    model.fit(X_train, y_train, sample_weight=sample_weight)
+    y_pred_equal_sample_weight = model.predict(X_test)
+
+    # Predictions should be identical with no and with equal sample_weight
+    assert_array_equal(y_pred_no_sample_weight, y_pred_equal_sample_weight)
+
+    model = CascadeForestClassifier(**case_kwargs)
+    sample_weight = np.where(y_train == 0, 0.1, y_train)
+    model.fit(X_train, y_train, sample_weight=sample_weight)
+    y_pred_skewed_sample_weight = model.predict(X_test)
+
+    # Predictions should differ between skewed and equal sample_weight
+    assert_raises(AssertionError, assert_array_equal, y_pred_skewed_sample_weight, y_pred_equal_sample_weight)
+
+    model.clean()  # clear the buffer
+
+
 def test_model_workflow_in_memory():
     """Run the workflow of deep forest with in-memory mode."""
 
