Skip to content

Commit

Permalink
Merge pull request #2093 from pavlin-policar/fix-fitter-preprocessors
Browse files Browse the repository at this point in the history
[FIX] Fitter: Properly delegate preprocessors
  • Loading branch information
astaric authored Mar 10, 2017
2 parents 81afdde + cdf77bb commit 8364ae8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
6 changes: 6 additions & 0 deletions Orange/modelling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ def _fit_model(self, data):
X, Y, W = data.X, data.Y, data.W if data.has_weights() else None
return learner.fit(X, Y, W)

def preprocess(self, data):
    """Preprocess `data` with the learner matching its class variable.

    When ``data.domain.has_discrete_class`` is true, the work is handed to
    the learner returned by ``get_learner(self.CLASSIFICATION)``; otherwise
    it is handed to the one returned by ``get_learner(self.REGRESSION)``.

    Returns whatever the selected learner's ``preprocess`` returns.
    """
    problem_type = (self.CLASSIFICATION if data.domain.has_discrete_class
                    else self.REGRESSION)
    learner = self.get_learner(problem_type)
    return learner.preprocess(data)

def get_learner(self, problem_type):
"""Get the learner for a given problem type.
Expand Down
26 changes: 24 additions & 2 deletions Orange/tests/test_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from unittest.mock import Mock, patch

from Orange.classification.base_classification import LearnerClassification
from Orange.data import Table
from Orange.data import Table, ContinuousVariable
from Orange.evaluation import CrossValidation
from Orange.modelling import Fitter
from Orange.preprocess import Randomize
from Orange.preprocess import Randomize, Discretize
from Orange.regression.base_regression import LearnerRegression


Expand Down Expand Up @@ -130,3 +130,25 @@ def test_correctly_sets_preprocessors_on_learner(self):
def test_n_jobs_fitting(self):
    """Cross-validate a ``DummyFitter`` with ``n_jobs=5``; must not raise.

    ``CrossValidation._MIN_NJOBS_X_SIZE`` is patched to 1 for the duration
    of the call.
    """
    # NOTE(review): presumably the threshold is lowered so that the
    # multi-job code path is actually taken on this data set -- confirm.
    threshold_attr = ('Orange.evaluation.testing.'
                      'CrossValidation._MIN_NJOBS_X_SIZE')
    with patch(threshold_attr, 1):
        learners = [DummyFitter()]
        CrossValidation(self.heart_disease, learners, k=5, n_jobs=5)

def test_properly_delegates_preprocessing(self):
    """Check that ``Fitter.preprocess`` runs the learner's preprocessors.

    A classification learner whose only preprocessor is ``Discretize`` is
    registered on a fitter. After ``fitter.preprocess`` the resulting
    domain must contain no ``ContinuousVariable``.
    """
    class DummyClassificationLearner(LearnerClassification):
        preprocessors = [Discretize()]

        def __init__(self, classification_param=1, **_):
            super().__init__()
            self.param = classification_param

    class DummyFitter(Fitter):
        # NOTE(review): DummyRegressionLearner is not defined in this
        # test; presumably it is a module-level helper -- confirm.
        __fits__ = {'classification': DummyClassificationLearner,
                    'regression': DummyRegressionLearner}

    def has_continuous(table):
        # True when at least one variable of the domain is continuous.
        return any(isinstance(var, ContinuousVariable)
                   for var in table.domain.variables)

    data = self.heart_disease
    fitter = DummyFitter()

    # Sanity check: without a continuous variable in the input the final
    # assertion would pass trivially.
    self.assertTrue(
        has_continuous(data),
        "input data must contain at least one continuous variable")

    # Preprocess the data and check that the discretization was applied.
    pp_data = fitter.preprocess(data)
    self.assertFalse(
        has_continuous(pp_data),
        "continuous variables remain after Fitter.preprocess")

0 comments on commit 8364ae8

Please sign in to comment.