From 6e28da802e483fc741056a1768c825737c840cca Mon Sep 17 00:00:00 2001 From: Li Qinbin Date: Mon, 6 May 2019 22:21:18 +0800 Subject: [PATCH] support balanced class_weight #141 --- python/README.md | 4 +-- python/thundersvm/thundersvm.py | 52 +++++++++++++++++++++++----- src/thundersvm/thundersvm-scikit.cpp | 2 +- 3 files changed, 47 insertions(+), 11 deletions(-) diff --git a/python/README.md b/python/README.md index b0051fa8..5680a028 100644 --- a/python/README.md +++ b/python/README.md @@ -69,8 +69,8 @@ The usage of thundersvm scikit interface is similar to sklearn.svm. *probability*: boolean, optional(default=False)\ whether to train a SVC or SVR model for probability estimates, True or False -*class_weight*: {dict}, optional(default=None)\ - set the parameter C of class i to weight*C, for C-SVC +*class_weight*: {dict, 'balanced'}, optional(default=None)\ + set the parameter C of class i to weight*C, for C-SVC. If not given, all classes are supposed to have weight one. The “balanced” mode uses the values of y to automatically adjust weights inversely proportional to class frequencies in the input data as ```n_samples / (n_classes * np.bincount(y))``` *shrinking*: boolean, optional (default=False, not supported yet for True)\ whether to use the shrinking heuristic. diff --git a/python/thundersvm/thundersvm.py b/python/thundersvm/thundersvm.py index 4c8afd8c..f404a1bb 100644 --- a/python/thundersvm/thundersvm.py +++ b/python/thundersvm/thundersvm.py @@ -185,12 +185,30 @@ def _dense_fit(self, X, y, solver_type, kernel): if self.class_weight is None: weight_size = 0 self.class_weight = dict() + weight_label = (c_int * weight_size)() + weight_label[:] = list(self.class_weight.keys()) + weight = (c_float * weight_size)() + weight[:] = list(self.class_weight.values()) + elif self.class_weight == 'balanced': + y_unique = np.unique(y) + y_count = np.bincount(y.astype(int)) + weight_label_list = [] + weight_list = [] + for n in range(0, len(y_count)): + if y_count[n] != 0: + weight_label_list.append(n) + weight_list.append(samples/(len(y_unique)*y_count[n])) + weight_size=len(weight_list) + weight_label = (c_int * weight_size)() + weight_label[:] = weight_label_list + weight = (c_float * weight_size)() + weight[:] = weight_list else: weight_size = len(self.class_weight) - weight_label = (c_int * weight_size)() - weight_label[:] = list(self.class_weight.keys()) - weight = (c_float * weight_size)() - weight[:] = list(self.class_weight.values()) + weight_label = (c_int * weight_size)() + weight_label[:] = list(self.class_weight.keys()) + weight = (c_float * weight_size)() + weight[:] = list(self.class_weight.values()) n_features = (c_int * 1)() n_classes = (c_int * 1)() @@ -228,12 +246,30 @@ def _sparse_fit(self, X, y, solver_type, kernel): if self.class_weight is None: weight_size = 0 self.class_weight = dict() + weight_label = (c_int * weight_size)() + weight_label[:] = list(self.class_weight.keys()) + weight = (c_float * weight_size)() + weight[:] = list(self.class_weight.values()) + elif self.class_weight == 'balanced': + y_unique = np.unique(y) + y_count = np.bincount(y.astype(int)) + weight_label_list = [] + weight_list = [] + for n in range(0, len(y_count)): + if y_count[n] != 0: + weight_label_list.append(n) + weight_list.append(X.shape[0]/(len(y_unique)*y_count[n])) + weight_size=len(weight_list) + weight_label = (c_int * weight_size)() + weight_label[:] = weight_label_list + weight = (c_float * weight_size)() + weight[:] = weight_list else: weight_size = len(self.class_weight) - weight_label = (c_int * weight_size)() - weight_label[:] = list(self.class_weight.keys()) - weight = (c_float * weight_size)() - weight[:] = list(self.class_weight.values()) + weight_label = (c_int * weight_size)() + weight_label[:] = list(self.class_weight.keys()) + weight = (c_float * weight_size)() + weight[:] = list(self.class_weight.values()) n_features = (c_int * 1)() n_classes = (c_int * 1)() diff --git a/src/thundersvm/thundersvm-scikit.cpp b/src/thundersvm/thundersvm-scikit.cpp index bdbd6bf2..0bc1fb55 100644 --- a/src/thundersvm/thundersvm-scikit.cpp +++ b/src/thundersvm/thundersvm-scikit.cpp @@ -221,7 +221,7 @@ extern "C" { param_cmd.max_mem_size = static_cast(max(max_mem_size, 0)) << 20; if(weight_size != 0) { param_cmd.nr_weight = weight_size; - param_cmd.weight = (float_type *) malloc(weight_size * sizeof(float_type)); + param_cmd.weight = (float_type *) malloc(weight_size * sizeof(double)); param_cmd.weight_label = (int *) malloc(weight_size * sizeof(int)); for (int i = 0; i < weight_size; i++) { param_cmd.weight[i] = weight[i];