diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index 6e2d66cb9d15..d1074c923337 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -590,8 +590,9 @@ def update(self, labels, preds):
 
 
 class _BinaryClassificationMetrics(object):
-    """Private container class for classification metric statistics. True/false positive and
-    true/false negative counts are sufficient statistics for various classification metrics.
+    """Private container class for classification metric statistics.
+
+    True/false positive and true/false negative counts are sufficient statistics for various classification metrics.
     This class provides the machinery to track those statistics across mini-batches of
     (label, prediction) pairs.
     """
@@ -1430,6 +1431,10 @@ class PearsonCorrelation(EvalMetric):
     label_names : list of str, or None
         Name of labels that should be used when updating with update_dict.
         By default include all labels.
+    average : str, default 'macro'
+        Strategy to be used for aggregating across mini-batches.
+            "macro": average the pearsonr scores for each batch.
+            "micro": compute a single pearsonr score across all batches.
 
     Examples
     --------
@@ -1438,13 +1443,46 @@ class PearsonCorrelation(EvalMetric):
     >>> pr = mx.metric.PearsonCorrelation()
     >>> pr.update(labels, predicts)
     >>> print pr.get()
-    ('pearson-correlation', 0.42163704544016178)
+    ('pearsonr', 0.42163704544016178)
     """
     def __init__(self, name='pearsonr',
-                 output_names=None, label_names=None):
+                 output_names=None, label_names=None, average='macro'):
+        self.average = average
         super(PearsonCorrelation, self).__init__(
             name, output_names=output_names, label_names=label_names,
             has_global_stats=True)
+        if self.average == 'micro':
+            self.reset_micro()
+
+    def reset_micro(self):
+        self._sse_p = 0
+        self._mean_p = 0
+        self._sse_l = 0
+        self._mean_l = 0
+        self._pred_nums = 0
+        self._label_nums = 0
+        self._conv = 0
+
+    def reset(self):
+        self.num_inst = 0
+        self.sum_metric = 0.0
+        self.global_num_inst = 0
+        self.global_sum_metric = 0.0
+        if self.average == 'micro':
+            self.reset_micro()
+
+    def update_variance(self, new_values, *aggregate):
+        # Welford's online algorithm for variance update
+        count, mean, m_2 = aggregate
+        count += len(new_values)
+        delta = new_values - mean
+        mean += numpy.sum(delta / count)
+        delta_2 = new_values - mean
+        m_2 += numpy.sum(delta * delta_2)
+        return count, mean, m_2
+
+    def update_cov(self, label, pred):
+        self._conv = self._conv + numpy.sum((label - self._mean_l) * (pred - self._mean_p))
 
     def update(self, labels, preds):
         """Updates the internal evaluation result.
@@ -1457,17 +1495,34 @@ def update(self, labels, preds):
             Predicted values.
""" labels, preds = check_label_shapes(labels, preds, True) - for label, pred in zip(labels, preds): check_label_shapes(label, pred, False, True) - label = label.asnumpy() - pred = pred.asnumpy() - pearson_corr = numpy.corrcoef(pred.ravel(), label.ravel())[0, 1] - self.sum_metric += pearson_corr - self.global_sum_metric += pearson_corr - self.num_inst += 1 - self.global_num_inst += 1 + label = label.asnumpy().ravel().astype(numpy.float64) + pred = pred.asnumpy().ravel().astype(numpy.float64) + if self.average == 'macro': + pearson_corr = numpy.corrcoef(pred, label)[0, 1] + self.sum_metric += pearson_corr + self.global_sum_metric += pearson_corr + self.num_inst += 1 + self.global_num_inst += 1 + else: + self.global_num_inst += 1 + self.num_inst += 1 + self._label_nums, self._mean_l, self._sse_l = \ + self.update_variance(label, self._label_nums, self._mean_l, self._sse_l) + self.update_cov(label, pred) + self._pred_nums, self._mean_p, self._sse_p = \ + self.update_variance(pred, self._pred_nums, self._mean_p, self._sse_p) + def get(self): + if self.num_inst == 0: + return (self.name, float('nan')) + if self.average == 'macro': + return (self.name, self.sum_metric / self.num_inst) + else: + n = self._label_nums + pearsonr = self._conv / ((n-1) * numpy.sqrt(self._sse_p / (n - 1)) * numpy.sqrt(self._sse_l / (n - 1))) + return (self.name, pearsonr) @register class PCC(EvalMetric): diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py index 0ae8aeaa697f..a1e5128d8ac6 100644 --- a/tests/python/unittest/test_metric.py +++ b/tests/python/unittest/test_metric.py @@ -17,6 +17,7 @@ import mxnet as mx import numpy as np +import scipy import json import math from common import with_seed @@ -263,13 +264,40 @@ def test_perplexity(): assert perplexity == perplexity_expected def test_pearsonr(): - pred = mx.nd.array([[0.7, 0.3], [0.1, 0.9], [1., 0]]) - label = mx.nd.array([[0, 1], [1, 0], [1, 0]]) - pearsonr_expected = np.corrcoef(pred.asnumpy().ravel(), label.asnumpy().ravel())[0, 1] - metric = mx.metric.create('pearsonr') - metric.update([label], [pred]) - _, pearsonr = metric.get() - assert pearsonr == pearsonr_expected + pred1 = mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]]) + label1 = mx.nd.array([[1, 0], [0, 1], [0, 1]]) + pearsonr_expected_np = np.corrcoef(pred1.asnumpy().ravel(), label1.asnumpy().ravel())[0, 1] + pearsonr_expected_scipy, _ = scipy.stats.pearsonr(pred1.asnumpy().ravel(), label1.asnumpy().ravel()) + macro_pr = mx.metric.create('pearsonr', average='macro') + micro_pr = mx.metric.create('pearsonr', average='micro') + + assert np.isnan(macro_pr.get()[1]) + assert np.isnan(micro_pr.get()[1]) + + macro_pr.update([label1], [pred1]) + micro_pr.update([label1], [pred1]) + + np.testing.assert_almost_equal(macro_pr.get()[1], pearsonr_expected_np) + np.testing.assert_almost_equal(macro_pr.get()[1], pearsonr_expected_scipy) + np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_np) + np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_scipy) + + pred2 = mx.nd.array([[1, 2], [3, 2], [4, 6]]) + label2 = mx.nd.array([[1, 0], [0, 1], [0, 1]]) + # Note that pred12 = pred1 + pred2; label12 = label1 + label2 + pred12 = mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6],[1, 2], [3, 2], [4, 6]]) + label12 = mx.nd.array([[1, 0], [0, 1], [0, 1], [1, 0], [0, 1], [0, 1]]) + + pearsonr_expected_np = np.corrcoef(pred12.asnumpy().ravel(), label12.asnumpy().ravel())[0, 1] + pearsonr_expected_scipy, _ = scipy.stats.pearsonr(pred12.asnumpy().ravel(), 
+
+    macro_pr.reset()
+    micro_pr.update([label2], [pred2])
+    macro_pr.update([label12], [pred12])
+    np.testing.assert_almost_equal(macro_pr.get()[1], pearsonr_expected_np)
+    np.testing.assert_almost_equal(macro_pr.get()[1], pearsonr_expected_scipy)
+    np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_np)
+    np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_scipy)
 
 def cm_batch(cm):
     # generate a batch yielding a given confusion matrix
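
Reviewer note: for sanity-checking the streaming math outside MXNet, here is a minimal standalone sketch (not part of the patch; it assumes only NumPy, and the helper name update_moments is illustrative) of the chunked Welford mean/variance updates plus the covariance update that the 'micro' strategy maintains, checked against a single numpy.corrcoef over the concatenated batches.

import numpy as np

def update_moments(chunk, count, mean, m2):
    # Chunked Welford update: fold `chunk` into the running
    # (count, mean, sum of squared deviations from the mean).
    count += len(chunk)
    delta = chunk - mean           # deviations from the old mean
    mean += np.sum(delta / count)
    delta2 = chunk - mean          # deviations from the new mean
    m2 += np.sum(delta * delta2)
    return count, mean, m2

rng = np.random.default_rng(0)
batches = [(rng.normal(size=5), rng.normal(size=5)) for _ in range(4)]

n_l = n_p = 0
mean_l = sse_l = mean_p = sse_p = cov = 0.0
for label, pred in batches:
    n_l, mean_l, sse_l = update_moments(label, n_l, mean_l, sse_l)
    # Co-moment update uses the new label mean and the old pred mean,
    # the same ordering as update_variance/update_cov in the patch.
    cov += np.sum((label - mean_l) * (pred - mean_p))
    n_p, mean_p, sse_p = update_moments(pred, n_p, mean_p, sse_p)

n = n_l
streamed = cov / ((n - 1) * np.sqrt(sse_p / (n - 1)) * np.sqrt(sse_l / (n - 1)))
all_label = np.concatenate([l for l, _ in batches])
all_pred = np.concatenate([p for _, p in batches])
assert np.isclose(streamed, np.corrcoef(all_pred, all_label)[0, 1])

The per-chunk updates are algebraically exact, so the streamed estimate matches the single-pass value up to floating-point error; that is also why the test above can compare micro_pr against np.corrcoef and scipy.stats.pearsonr computed on the concatenated pred12/label12.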