Skip to content

Commit

Permalink
learner adequacy check refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
JakaKokosar committed Mar 14, 2022
1 parent c31cda9 commit 1442671
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 23 deletions.
12 changes: 6 additions & 6 deletions Orange/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import Iterable
import re
import warnings
from typing import Callable, Dict
from typing import Callable, Dict, Tuple

import numpy as np
import scipy
Expand Down Expand Up @@ -86,7 +86,6 @@ class Learner(ReprableWithPreprocessors):
#: A sequence of data preprocessors to apply on data prior to
#: fitting the model
preprocessors = ()
learner_adequacy_err_msg = ''

def __init__(self, preprocessors=None):
self.use_default_preprocessors = False
Expand All @@ -106,8 +105,9 @@ def fit_storage(self, data):
return self.fit(X, Y, W)

def __call__(self, data, progress_callback=None):
if not self.check_learner_adequacy(data.domain):
raise ValueError(self.learner_adequacy_err_msg)
learner_is_adequate, err_msg = self.check_learner_adequacy(data.domain)
if not learner_is_adequate:
raise ValueError(err_msg)

origdomain = data.domain

Expand Down Expand Up @@ -173,8 +173,8 @@ def active_preprocessors(self):
self.preprocessors is not type(self).preprocessors):
yield from type(self).preprocessors

def check_learner_adequacy(self, _):
return True
def check_learner_adequacy(self, _) -> Tuple[bool, str]:
return True, ""

@property
def name(self):
Expand Down
10 changes: 4 additions & 6 deletions Orange/classification/base_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@
class LearnerClassification(Learner):

def check_learner_adequacy(self, domain):
is_adequate = True
err_msg = ""
if len(domain.class_vars) > 1:
is_adequate = False
self.learner_adequacy_err_msg = "Too many target variables."
err_msg = "Too many target variables."
elif not domain.has_discrete_class:
is_adequate = False
self.learner_adequacy_err_msg = "Categorical class variable expected."
return is_adequate
err_msg = "Categorical class variable expected."
return not err_msg, err_msg


class ModelClassification(Model):
Expand Down
6 changes: 4 additions & 2 deletions Orange/preprocess/impute.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,8 @@ def __call__(self, data, variable):
variable = data.domain[variable]
domain = domain_with_class_var(data.domain, variable)

if self.learner.check_learner_adequacy(domain):
learner_is_adequate, err_msg = self.learner.check_learner_adequacy(domain)
if learner_is_adequate:
data = data.transform(domain)
model = self.learner(data)
assert model.domain.class_var == variable
Expand All @@ -239,7 +240,8 @@ def copy(self):

def supports_variable(self, variable):
domain = Orange.data.Domain([], class_vars=variable)
return self.learner.check_learner_adequacy(domain)
learner_is_adequate, _ = self.learner.check_learner_adequacy(domain)
return learner_is_adequate


def domain_with_class_var(domain, class_var):
Expand Down
10 changes: 4 additions & 6 deletions Orange/regression/base_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@
class LearnerRegression(Learner):

def check_learner_adequacy(self, domain):
is_adequate = True
err_msg = ""
if len(domain.class_vars) > 1:
is_adequate = False
self.learner_adequacy_err_msg = "Too many target variables."
err_msg = "Too many target variables."
elif not domain.has_continuous_class:
is_adequate = False
self.learner_adequacy_err_msg = "Numeric class variable expected."
return is_adequate
err_msg = "Numeric class variable expected."
return not err_msg, err_msg


class ModelRegression(Model):
Expand Down
2 changes: 1 addition & 1 deletion Orange/tests/dummy_learners.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class DummyMulticlassLearner(SklLearner):
supports_multiclass = True

def check_learner_adequacy(self, domain):
return all(c.is_discrete for c in domain.class_vars)
return all(c.is_discrete for c in domain.class_vars), ''

def fit(self, X, Y, W):
rows, class_vars = Y.shape
Expand Down
5 changes: 3 additions & 2 deletions Orange/widgets/utils/owlearnerwidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,9 @@ def check_data(self):
self.Error.sparse_not_supported.clear()
if self.data is not None and self.learner is not None:
self.Error.data_error.clear()
if not self.learner.check_learner_adequacy(self.data.domain):
self.Error.data_error(self.learner.learner_adequacy_err_msg)
learner_is_adequate, err_msg = self.learner.check_learner_adequacy(self.data.domain)
if not learner_is_adequate:
self.Error.data_error(err_msg)
elif not len(self.data):
self.Error.data_error("Dataset is empty.")
elif len(ut.unique(self.data.Y)) < 2:
Expand Down

0 comments on commit 1442671

Please sign in to comment.