Skip to content

Commit

Permalink
Merge pull request #283 from srahul1222/my-feature
Browse files Browse the repository at this point in the history
Add max_depth as a parameter to LifelongClassificationForest and UncertaintyForest
  • Loading branch information
levinwil authored Oct 7, 2020
2 parents c8af45d + c76fad7 commit c738804
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 1,271 deletions.
65 changes: 47 additions & 18 deletions proglearn/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,24 @@ class LifelongClassificationForest(ClassificationProgressiveLearner):
tree. The remainder of the data is used to fill in voting posteriors.
This is used if 'tree_construction_proportion' is not fed to add_task.
default_finite_sample_correction : bool, default=False
Boolean indicating whether this learner will have finite sample correction
Boolean indicating whether this learner will have finite sample correction.
This is used if 'finite_sample_correction' is not fed to add_task.
default_max_depth : int, default=30
The maximum depth of a tree in the Lifelong Classification Forest.
This is used if 'max_depth' is not fed to add_task.
Methods
---
add_task(X, y, task_id)
adds a task with id task_id, given input data matrix X
and output data matrix y, to the Lifelong Classification Forest
add_transformer(X, y, transformer_id)
adds a transformer with id transformer_id, trained on given input data matrix, X
and output data matrix, y, to the Lifelong Classification Forest. Also
trains the voters and deciders from new transformer to previous tasks, and will
add_task(X, y, task_id, tree_construction_proportion, finite_sample_correction, max_depth)
adds a task with id task_id, max tree depth max_depth, given input data matrix X
and output data matrix y, to the Lifelong Classification Forest. Also splits
data for training and voting based on tree_construction_proportion and uses the
value of finite_sample_correction to determine whether the learner will have
finite sample correction.
add_transformer(X, y, transformer_id, max_depth)
adds a transformer with id transformer_id and max tree depth max_depth, trained on
given input data matrix, X, and output data matrix, y, to the Lifelong Classification Forest.
Also trains the voters and deciders from new transformer to previous tasks, and will
train voters and deciders from this transformer to all new tasks.
predict(X, task_id)
predicts class labels under task_id for each example in input data X.
Expand All @@ -46,10 +52,12 @@ def __init__(
n_estimators=100,
default_tree_construction_proportion=0.67,
default_finite_sample_correction=False,
default_max_depth=30,
):
self.n_estimators = n_estimators
self.default_tree_construction_proportion = default_tree_construction_proportion
self.default_finite_sample_correction = default_finite_sample_correction
self.default_max_depth = default_max_depth
self.pl = ClassificationProgressiveLearner(
default_transformer_class=TreeClassificationTransformer,
default_transformer_kwargs={},
Expand All @@ -68,10 +76,14 @@ def add_task(
task_id=None,
tree_construction_proportion=None,
finite_sample_correction=None,
max_depth=None,
):
"""
adds a task with id task_id, given input data matrix X
and output data matrix y, to the Lifelong Classification Forest
adds a task with id task_id, max tree depth max_depth, given input data matrix X
and output data matrix y, to the Lifelong Classification Forest. Also splits
data for training and voting based on tree_construction_proportion and uses the
value of finite_sample_correction to determine whether the learner will have
finite sample correction.
Parameters
---
Expand All @@ -86,13 +98,18 @@ def add_task(
tree. The remainder of the data is used to fill in voting posteriors.
The default is used if 'None' is provided.
finite_sample_correction : bool, default=False
Boolean indicating whether this learner will have finite sample correction
Boolean indicating whether this learner will have finite sample correction.
The default is used if 'None' is provided.
max_depth : int, default=30
The maximum depth of a tree in the Lifelong Classification Forest.
The default is used if 'None' is provided.
"""
if tree_construction_proportion is None:
tree_construction_proportion = self.default_tree_construction_proportion
if finite_sample_correction is None:
finite_sample_correction = self.default_finite_sample_correction
if max_depth is None:
max_depth = self.default_max_depth

self.pl.add_task(
X,
Expand All @@ -104,6 +121,7 @@ def add_task(
0,
],
num_transformers=self.n_estimators,
transformer_kwargs={"kwargs": {"max_depth": max_depth}},
voter_kwargs={
"classes": np.unique(y),
"finite_sample_correction": finite_sample_correction,
Expand All @@ -112,11 +130,11 @@ def add_task(
)
return self

def add_transformer(self, X, y, transformer_id=None):
def add_transformer(self, X, y, transformer_id=None, max_depth=None):
"""
adds a transformer with id transformer_id, trained on given input data matrix, X
and output data matrix, y, to the Lifelong Classification Forest. Also
trains the voters and deciders from new transformer to previous tasks, and will
adds a transformer with id transformer_id and max tree depth max_depth, trained on
given input data matrix, X, and output data matrix, y, to the Lifelong Classification Forest.
Also trains the voters and deciders from new transformer to previous tasks, and will
train voters and deciders from this transformer to all new tasks.
Parameters
Expand All @@ -127,10 +145,17 @@ def add_transformer(self, X, y, transformer_id=None):
The output (response) data matrix.
transformer_id : obj, default=None
The id corresponding to the transformer being added.
max_depth : int, default=30
The maximum depth of a tree in the Lifelong Classification Forest.
The default is used if 'None' is provided.
"""
if max_depth is None:
max_depth = self.default_max_depth

self.pl.add_transformer(
X,
y,
transformer_kwargs={"kwargs": {"max_depth": max_depth}},
transformer_id=transformer_id,
num_transformers=self.n_estimators,
)
Expand Down Expand Up @@ -172,11 +197,13 @@ class UncertaintyForest:
---
lf : LifelongClassificationForest
A lifelong classification forest object
n_estimators : int
n_estimators : int, default=100
The number of trees in the UncertaintyForest
finite_sample_correction : bool
finite_sample_correction : bool, default=False
Boolean indicating whether this learner
will use finite sample correction
max_depth : int, default=30
The maximum depth of a tree in the UncertaintyForest
Methods
---
Expand All @@ -188,9 +215,10 @@ class UncertaintyForest:
estimates class posteriors for each example in input data X.
"""

def __init__(self, n_estimators=100, finite_sample_correction=False):
def __init__(self, n_estimators=100, finite_sample_correction=False, max_depth=30):
self.n_estimators = n_estimators
self.finite_sample_correction = finite_sample_correction
self.max_depth = max_depth

def fit(self, X, y):
"""
Expand All @@ -206,6 +234,7 @@ def fit(self, X, y):
self.lf = LifelongClassificationForest(
n_estimators=self.n_estimators,
default_finite_sample_correction=self.finite_sample_correction,
default_max_depth=self.max_depth,
)
self.lf.add_task(X, y, task_id=0)
return self
Expand Down
Loading

0 comments on commit c738804

Please sign in to comment.