From e5b06fa208e33388feead20831130ea6859d67bd Mon Sep 17 00:00:00 2001 From: CJ Carey Date: Thu, 28 Sep 2023 14:21:14 -0400 Subject: [PATCH] Fix sklearn compat issues --- metric_learn/base_metric.py | 11 +++++++---- metric_learn/sdml.py | 4 ++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index d0ba1ef9..47efe4b7 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -2,7 +2,7 @@ Base module. """ -from sklearn.base import BaseEstimator +from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.utils.extmath import stable_cumsum from sklearn.utils.validation import _is_arraylike, check_is_fitted from sklearn.metrics import roc_auc_score, roc_curve, precision_recall_curve @@ -464,7 +464,7 @@ def get_mahalanobis_matrix(self): return self.components_.T.dot(self.components_) -class _PairsClassifierMixin(BaseMetricLearner): +class _PairsClassifierMixin(BaseMetricLearner, ClassifierMixin): """Base class for pairs learners. Attributes @@ -475,6 +475,7 @@ class _PairsClassifierMixin(BaseMetricLearner): classified as dissimilar. """ + classes_ = np.array([0, 1]) _tuple_size = 2 # number of points in a tuple, 2 for pairs def predict(self, pairs): @@ -752,11 +753,12 @@ def _validate_calibration_params(strategy='accuracy', min_rate=None, 'Got {} instead.'.format(type(beta))) -class _TripletsClassifierMixin(BaseMetricLearner): +class _TripletsClassifierMixin(BaseMetricLearner, ClassifierMixin): """ Base class for triplets learners. """ + classes_ = np.array([0, 1]) _tuple_size = 3 # number of points in a tuple, 3 for triplets def predict(self, triplets): @@ -837,11 +839,12 @@ def score(self, triplets): return self.predict(triplets).mean() / 2 + 0.5 -class _QuadrupletsClassifierMixin(BaseMetricLearner): +class _QuadrupletsClassifierMixin(BaseMetricLearner, ClassifierMixin): """ Base class for quadruplets learners. """ + classes_ = np.array([0, 1]) _tuple_size = 4 # number of points in a tuple, 4 for quadruplets def predict(self, quadruplets): diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py index 93f3f441..c76de99b 100644 --- a/metric_learn/sdml.py +++ b/metric_learn/sdml.py @@ -43,6 +43,9 @@ def _fit(self, pairs, y): print("SDML will use skggm's graphical lasso solver.") pairs, y = self._prepare_inputs(pairs, y, type_of_inputs='tuples') + n_features = pairs.shape[2] + if n_features < 2: + raise ValueError(f"Cannot fit SDML with {n_features} feature(s)") # set up (the inverse of) the prior M # if the prior is the default (None), we raise a warning @@ -83,6 +86,7 @@ def _fit(self, pairs, y): w_mahalanobis, _ = np.linalg.eigh(M) not_spd = any(w_mahalanobis < 0.) not_finite = not np.isfinite(M).all() + # TODO: Narrow this to the specific exceptions we expect. except Exception as e: raised_error = e not_spd = False # not_spd not applicable here so we set to False