From 4872a8ea0afa22813f1a4446ef5fc0d608660283 Mon Sep 17 00:00:00 2001 From: Alexander Andreev Date: Fri, 9 Dec 2022 19:17:56 +0300 Subject: [PATCH] Fix for balanced class weight (#1080) * add balanced branch of _compute_class_weight * Remove extra computation of weights --- onedal/datatypes/validation.py | 11 ++++++++++- sklearnex/svm/nusvc.py | 6 +----- sklearnex/svm/svc.py | 6 +----- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/onedal/datatypes/validation.py b/onedal/datatypes/validation.py index 3c02144a27..503334826c 100644 --- a/onedal/datatypes/validation.py +++ b/onedal/datatypes/validation.py @@ -18,6 +18,7 @@ import warnings from scipy import sparse as sp from scipy.sparse import issparse, dok_matrix, lil_matrix +from sklearn.preprocessing import LabelEncoder from collections.abc import Sequence from numbers import Integral @@ -57,7 +58,15 @@ def _compute_class_weight(class_weight, classes, y): if class_weight is None or len(class_weight) == 0: weight = np.ones(classes.shape[0], dtype=np.float64, order='C') elif class_weight == 'balanced': - weight = None + y_ = _column_or_1d(y) + classes, _ = np.unique(y_, return_inverse=True) + + le = LabelEncoder() + y_ind = le.fit_transform(y_) + if not all(np.in1d(classes, le.classes_)): + raise ValueError("classes should have valid labels that are in y") + + weight = len(y_) / (len(le.classes_) * np.bincount(y_ind).astype(np.float64)) else: # user-defined dictionary weight = np.ones(classes.shape[0], dtype=np.float64, order='C') diff --git a/sklearnex/svm/nusvc.py b/sklearnex/svm/nusvc.py index ca6616701a..3a2f4ab914 100644 --- a/sklearnex/svm/nusvc.py +++ b/sklearnex/svm/nusvc.py @@ -195,11 +195,7 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None): self._onedal_estimator = onedal_NuSVC(**onedal_params) self._onedal_estimator.fit(X, y, sample_weight, queue=queue) - - if self.class_weight == 'balanced': - self.class_weight_ = self._compute_balanced_class_weight(y) - else: - self.class_weight_ = self._onedal_estimator.class_weight_ + self.class_weight_ = self._onedal_estimator.class_weight_ if self.probability: self._fit_proba(X, y, sample_weight, queue=queue) diff --git a/sklearnex/svm/svc.py b/sklearnex/svm/svc.py index 4cdd738b25..7016187bce 100644 --- a/sklearnex/svm/svc.py +++ b/sklearnex/svm/svc.py @@ -209,11 +209,7 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None): self._onedal_estimator = onedal_SVC(**onedal_params) self._onedal_estimator.fit(X, y, sample_weight, queue=queue) - - if self.class_weight == 'balanced': - self.class_weight_ = self._compute_balanced_class_weight(y) - else: - self.class_weight_ = self._onedal_estimator.class_weight_ + self.class_weight_ = self._onedal_estimator.class_weight_ if self.probability: self._fit_proba(X, y, sample_weight, queue=queue)