
Add micro averaging strategy to pearsonr metric (#16878)
        Strategy to be used for aggregating across mini-batches.
            "macro": average the per-batch pearsonr scores.
            "micro": compute a single pearsonr score across all batches.
zburning authored and leezu committed Dec 9, 2019
1 parent 251e6f6 commit d58f6cb
Showing 2 changed files with 102 additions and 19 deletions.
79 changes: 67 additions & 12 deletions python/mxnet/metric.py
@@ -590,8 +590,9 @@ def update(self, labels, preds):


class _BinaryClassificationMetrics(object):
"""Private container class for classification metric statistics. True/false positive and
true/false negative counts are sufficient statistics for various classification metrics.
"""Private container class for classification metric statistics.
True/false positive and true/false negative counts are sufficient statistics for various classification metrics.
This class provides the machinery to track those statistics across mini-batches of
(label, prediction) pairs.
"""
@@ -1430,6 +1431,10 @@ class PearsonCorrelation(EvalMetric):
    label_names : list of str, or None
        Name of labels that should be used when updating with update_dict.
        By default include all labels.
    average : str, default 'macro'
        Strategy to be used for aggregating across mini-batches.
        "macro": average the per-batch pearsonr scores.
        "micro": compute a single pearsonr score across all batches.

    Examples
    --------
@@ -1438,13 +1443,46 @@ class PearsonCorrelation(EvalMetric):
    >>> pr = mx.metric.PearsonCorrelation()
    >>> pr.update(labels, predicts)
    >>> print pr.get()
    ('pearsonr', 0.42163704544016178)
    """
    def __init__(self, name='pearsonr',
                 output_names=None, label_names=None, average='macro'):
        self.average = average
        super(PearsonCorrelation, self).__init__(
            name, output_names=output_names, label_names=label_names,
            has_global_stats=True)
        if self.average == 'micro':
            self.reset_micro()

    def reset_micro(self):
        # Reset running statistics for the streaming ("micro") computation:
        # *_nums hold element counts, _mean_* the running means, _sse_* the
        # sums of squared deviations (Welford's M2), and _conv the running
        # co-moment sum((label - mean_l) * (pred - mean_p)).
        self._sse_p = 0
        self._mean_p = 0
        self._sse_l = 0
        self._mean_l = 0
        self._pred_nums = 0
        self._label_nums = 0
        self._conv = 0

    def reset(self):
        self.num_inst = 0
        self.sum_metric = 0.0
        self.global_num_inst = 0
        self.global_sum_metric = 0.0
        if self.average == 'micro':
            self.reset_micro()

    def update_variance(self, new_values, *aggregate):
        # Batched form of Welford's online algorithm: 'aggregate' unpacks to
        # (count, mean, M2) and is advanced by a whole mini-batch at once.
        count, mean, m_2 = aggregate
        count += len(new_values)
        delta = new_values - mean
        mean += numpy.sum(delta / count)
        delta_2 = new_values - mean
        m_2 += numpy.sum(delta * delta_2)
        return count, mean, m_2

    def update_cov(self, label, pred):
        # Incremental co-moment update. When this runs, _mean_l has already been
        # advanced past the current batch while _mean_p has not; pairing the new
        # label mean with the old prediction mean keeps the recurrence exact.
        self._conv = self._conv + numpy.sum((label - self._mean_l) * (pred - self._mean_p))

    def update(self, labels, preds):
        """Updates the internal evaluation result.
@@ -1457,17 +1495,34 @@ def update(self, labels, preds):
            Predicted values.
        """
        labels, preds = check_label_shapes(labels, preds, True)

        for label, pred in zip(labels, preds):
            check_label_shapes(label, pred, False, True)
            label = label.asnumpy().ravel().astype(numpy.float64)
            pred = pred.asnumpy().ravel().astype(numpy.float64)
            if self.average == 'macro':
                pearson_corr = numpy.corrcoef(pred, label)[0, 1]
                self.sum_metric += pearson_corr
                self.global_sum_metric += pearson_corr
                self.num_inst += 1
                self.global_num_inst += 1
            else:
                self.global_num_inst += 1
                self.num_inst += 1
                # Order matters here: label statistics are updated first, then
                # the co-moment (new label mean, old prediction mean), then
                # prediction statistics.
                self._label_nums, self._mean_l, self._sse_l = \
                    self.update_variance(label, self._label_nums, self._mean_l, self._sse_l)
                self.update_cov(label, pred)
                self._pred_nums, self._mean_p, self._sse_p = \
                    self.update_variance(pred, self._pred_nums, self._mean_p, self._sse_p)

    def get(self):
        if self.num_inst == 0:
            return (self.name, float('nan'))
        if self.average == 'macro':
            return (self.name, self.sum_metric / self.num_inst)
        else:
            # The (n - 1) factors cancel, so this is the closed form
            # r = _conv / sqrt(_sse_p * _sse_l): the sample Pearson r over all
            # samples seen so far.
            n = self._label_nums
            pearsonr = self._conv / ((n - 1) * numpy.sqrt(self._sse_p / (n - 1)) * numpy.sqrt(self._sse_l / (n - 1)))
            return (self.name, pearsonr)
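
As a sanity check on the streaming ("micro") path above, here is a standalone sketch (random data and illustrative names, not part of this diff) that replays the same batched Welford and co-moment updates outside the class and compares the result with numpy's direct computation:

import numpy as np

rng = np.random.RandomState(0)
stream_l, stream_p = rng.randn(100), rng.randn(100)

count = mean_l = mean_p = sse_l = sse_p = cov = 0.0
for start in range(0, 100, 16):  # a deliberately uneven final batch
    l, p = stream_l[start:start + 16], stream_p[start:start + 16]
    count += len(l)
    # batched Welford update for the labels
    delta = l - mean_l
    mean_l += np.sum(delta / count)
    sse_l += np.sum(delta * (l - mean_l))
    # co-moment: new label mean paired with the old prediction mean
    cov += np.sum((l - mean_l) * (p - mean_p))
    # batched Welford update for the predictions
    delta = p - mean_p
    mean_p += np.sum(delta / count)
    sse_p += np.sum(delta * (p - mean_p))

assert np.isclose(cov / np.sqrt(sse_p * sse_l),
                  np.corrcoef(stream_p, stream_l)[0, 1])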

@register
class PCC(EvalMetric):
42 changes: 35 additions & 7 deletions tests/python/unittest/test_metric.py
@@ -17,6 +17,7 @@

import mxnet as mx
import numpy as np
import scipy.stats
import json
import math
from common import with_seed
@@ -263,13 +264,40 @@ def test_perplexity():
    assert perplexity == perplexity_expected

def test_pearsonr():
    pred1 = mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])
    label1 = mx.nd.array([[1, 0], [0, 1], [0, 1]])
    pearsonr_expected_np = np.corrcoef(pred1.asnumpy().ravel(), label1.asnumpy().ravel())[0, 1]
    pearsonr_expected_scipy, _ = scipy.stats.pearsonr(pred1.asnumpy().ravel(), label1.asnumpy().ravel())
    macro_pr = mx.metric.create('pearsonr', average='macro')
    micro_pr = mx.metric.create('pearsonr', average='micro')

    # with no updates yet, both metrics report nan
    assert np.isnan(macro_pr.get()[1])
    assert np.isnan(micro_pr.get()[1])

    macro_pr.update([label1], [pred1])
    micro_pr.update([label1], [pred1])

    # after a single batch the two strategies coincide
    np.testing.assert_almost_equal(macro_pr.get()[1], pearsonr_expected_np)
    np.testing.assert_almost_equal(macro_pr.get()[1], pearsonr_expected_scipy)
    np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_np)
    np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_scipy)

    pred2 = mx.nd.array([[1, 2], [3, 2], [4, 6]])
    label2 = mx.nd.array([[1, 0], [0, 1], [0, 1]])
    # pred12 is the concatenation of pred1 and pred2; likewise label12
    pred12 = mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6], [1, 2], [3, 2], [4, 6]])
    label12 = mx.nd.array([[1, 0], [0, 1], [0, 1], [1, 0], [0, 1], [0, 1]])

    pearsonr_expected_np = np.corrcoef(pred12.asnumpy().ravel(), label12.asnumpy().ravel())[0, 1]
    pearsonr_expected_scipy, _ = scipy.stats.pearsonr(pred12.asnumpy().ravel(), label12.asnumpy().ravel())

    # streaming the second batch into micro_pr must match macro_pr seeing the
    # concatenated data as one batch
    macro_pr.reset()
    micro_pr.update([label2], [pred2])
    macro_pr.update([label12], [pred12])
    np.testing.assert_almost_equal(macro_pr.get()[1], pearsonr_expected_np)
    np.testing.assert_almost_equal(macro_pr.get()[1], pearsonr_expected_scipy)
    np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_np)
    np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_scipy)

def cm_batch(cm):
    # generate a batch yielding a given confusion matrix
