
add micro to pearsonr #16878

Merged
merged 9 commits into from
Dec 9, 2019
Changes from 6 commits
89 changes: 74 additions & 15 deletions python/mxnet/metric.py
@@ -153,7 +153,8 @@ def reset(self):
self.global_sum_metric = 0.0

def reset_local(self):
"""Resets the local portion of the internal evaluation results to initial state."""
"""Resets the local portion of the internal evaluation results
to initial state."""
self.num_inst = 0
self.sum_metric = 0.0

@@ -371,7 +372,8 @@ def reset(self):
pass

def reset_local(self):
"""Resets the local portion of the internal evaluation results to initial state."""
"""Resets the local portion of the internal evaluation results
to initial state."""
try:
for metric in self.metrics:
metric.reset_local()
@@ -607,7 +609,8 @@ def __init__(self):
self.global_true_negatives = 0

def update_binary_stats(self, label, pred):
"""Update various binary classification counts for a single (label, pred) pair.
"""Update various binary classification counts for a single (label, pred)
pair.

Parameters
----------
@@ -686,7 +689,8 @@ def global_fscore(self):
return 0.

def matthewscc(self, use_global=False):
"""Calculate the Matthew's Correlation Coefficent"""
"""Calculate the Matthew's Correlation Coefficent
"""
if use_global:
if not self.global_total_examples:
return 0.
@@ -1430,6 +1434,10 @@ class PearsonCorrelation(EvalMetric):
label_names : list of str, or None
Name of labels that should be used when updating with update_dict.
By default include all labels.
average : str, default 'macro'
Strategy to be used for aggregating across mini-batches.
"macro": average the pearsonr scores for each batch.
"micro": compute a single pearsonr score across all batches.

Examples
--------
@@ -1438,13 +1446,46 @@ class PearsonCorrelation(EvalMetric):
>>> pr = mx.metric.PearsonCorrelation()
>>> pr.update(labels, predicts)
>>> print pr.get()
('pearson-correlation', 0.42163704544016178)
('pearsonr', 0.42163704544016178)
"""
def __init__(self, name='pearsonr',
output_names=None, label_names=None):
output_names=None, label_names=None, average='macro'):
self.average = average
super(PearsonCorrelation, self).__init__(
name, output_names=output_names, label_names=label_names,
has_global_stats=True)
if self.average == 'micro':
self.reset_micro()
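Reviewer note: a minimal numpy sketch (illustration only, not part of this diff) of why 'macro' and 'micro' can give different numbers — 'macro' averages the per-batch correlations, while 'micro' computes one correlation over the pooled predictions and labels.

```python
import numpy as np

# Two toy mini-batches of flattened (prediction, label) values.
batch1 = (np.array([0.3, 0.7, 0.0, 1.0]), np.array([1., 0., 0., 1.]))
batch2 = (np.array([1., 2., 3., 2.]),     np.array([1., 0., 0., 1.]))

# 'macro': mean of the per-batch Pearson correlations.
macro = np.mean([np.corrcoef(p, l)[0, 1] for p, l in (batch1, batch2)])

# 'micro': one Pearson correlation over the concatenated batches.
preds  = np.concatenate([batch1[0], batch2[0]])
labels = np.concatenate([batch1[1], batch2[1]])
micro = np.corrcoef(preds, labels)[0, 1]

print(macro, micro)  # the two aggregation strategies generally disagree
```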

def reset_micro(self):
self._sse_p = 0
self._mean_p = 0
self._sse_l = 0
self._mean_l = 0
self._pred_nums = 0
self._label_nums = 0
self._conv = 0

def reset(self):
self.num_inst = 0
self.sum_metric = 0.0
self.global_num_inst = 0
self.global_sum_metric = 0.0
if self.average == 'micro':
self.reset_micro()

def update_variance(self, new_values, *aggregate):
# Welford's online algorithm for variance update
count, mean, m_2 = aggregate
count += len(new_values)
delta = new_values - mean
mean += numpy.sum(delta / count)
delta_2 = new_values - mean
m_2 += numpy.sum(delta * delta_2)
return count, mean, m_2
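For context (not part of the change): the update above is a batched form of Welford's online algorithm, so the running (count, mean, M2) after feeding data chunk by chunk matches statistics computed on the full array. A small standalone check, using a copy of the same update:

```python
import numpy as np

def update_variance(new_values, count, mean, m_2):
    # Batched Welford step: fold a chunk into the running count, mean and
    # sum of squared deviations (M2).
    count += len(new_values)
    delta = new_values - mean
    mean += np.sum(delta / count)
    delta_2 = new_values - mean
    m_2 += np.sum(delta * delta_2)
    return count, mean, m_2

data = np.random.rand(10)
count, mean, m_2 = 0, 0.0, 0.0
for chunk in np.array_split(data, 3):
    count, mean, m_2 = update_variance(chunk, count, mean, m_2)

assert np.isclose(mean, data.mean())
assert np.isclose(m_2, np.sum((data - data.mean()) ** 2))
```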

def update_cov(self, label, pred):
self._conv = self._conv + numpy.sum((label - self._mean_l) * (pred - self._mean_p))

def update(self, labels, preds):
"""Updates the internal evaluation result.
@@ -1457,17 +1498,34 @@ def update(self, labels, preds):
Predicted values.
"""
labels, preds = check_label_shapes(labels, preds, True)

for label, pred in zip(labels, preds):
check_label_shapes(label, pred, False, True)
label = label.asnumpy()
pred = pred.asnumpy()
pearson_corr = numpy.corrcoef(pred.ravel(), label.ravel())[0, 1]
self.sum_metric += pearson_corr
self.global_sum_metric += pearson_corr
self.num_inst += 1
self.global_num_inst += 1
label = label.asnumpy().ravel().astype(numpy.float64)
pred = pred.asnumpy().ravel().astype(numpy.float64)
if self.average == 'macro':
pearson_corr = numpy.corrcoef(pred, label)[0, 1]
self.sum_metric += pearson_corr
self.global_sum_metric += pearson_corr
self.num_inst += 1
self.global_num_inst += 1
else:
self.global_num_inst += 1
self.num_inst += 1
self._label_nums, self._mean_l, self._sse_l = \
self.update_variance(label, self._label_nums, self._mean_l, self._sse_l)
self.update_cov(label, pred)
self._pred_nums, self._mean_p, self._sse_p = \
self.update_variance(pred, self._pred_nums, self._mean_p, self._sse_p)

def get(self):
if self.num_inst == 0:
return (self.name, float('nan'))
if self.average == 'macro':
return (self.name, self.sum_metric / self.num_inst)
else:
n = self._label_nums
pearsonr = self._conv / ((n-1) * numpy.sqrt(self._sse_p / (n - 1)) * numpy.sqrt(self._sse_l / (n - 1)))
return (self.name, pearsonr)
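As a sanity check on the formula in get() (illustrative sketch, mirroring the update order above, not part of the diff): the (n-1) factors cancel, leaving cov / sqrt(sse_p * sse_l), and the streamed statistics recover numpy.corrcoef on the pooled data.

```python
import numpy as np

def streaming_pearson(chunks):
    # Mirror PearsonCorrelation(average='micro'): per chunk, update the label
    # stats first, then the co-moment (new label mean, old pred mean), then
    # the prediction stats; finally combine into a single Pearson r.
    n, mean_p, mean_l, sse_p, sse_l, cov = 0, 0.0, 0.0, 0.0, 0.0, 0.0
    for pred, label in chunks:
        n_new = n + len(label)
        mean_l_new = mean_l + np.sum((label - mean_l) / n_new)
        sse_l += np.sum((label - mean_l) * (label - mean_l_new))
        cov += np.sum((label - mean_l_new) * (pred - mean_p))
        mean_p_new = mean_p + np.sum((pred - mean_p) / n_new)
        sse_p += np.sum((pred - mean_p) * (pred - mean_p_new))
        n, mean_l, mean_p = n_new, mean_l_new, mean_p_new
    return cov / np.sqrt(sse_p * sse_l)

chunks = [(np.random.rand(4), np.random.rand(4)) for _ in range(3)]
pooled_p = np.concatenate([p for p, _ in chunks])
pooled_l = np.concatenate([l for _, l in chunks])
assert np.isclose(streaming_pearson(chunks),
                  np.corrcoef(pooled_p, pooled_l)[0, 1])
```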

@register
class PCC(EvalMetric):
@@ -1597,7 +1655,8 @@ def reset(self):
self.reset_local()

def reset_local(self):
"""Resets the local portion of the internal evaluation results to initial state."""
"""Resets the local portion of the internal evaluation results
to initial state."""
self.num_inst = 0.
self.lcm = numpy.zeros((self.k, self.k))

42 changes: 35 additions & 7 deletions tests/python/unittest/test_metric.py
@@ -17,6 +17,7 @@

import mxnet as mx
import numpy as np
import scipy.stats
import json
import math
from common import with_seed
@@ -263,13 +264,40 @@ def test_perplexity():
assert perplexity == perplexity_expected

def test_pearsonr():
pred = mx.nd.array([[0.7, 0.3], [0.1, 0.9], [1., 0]])
label = mx.nd.array([[0, 1], [1, 0], [1, 0]])
pearsonr_expected = np.corrcoef(pred.asnumpy().ravel(), label.asnumpy().ravel())[0, 1]
metric = mx.metric.create('pearsonr')
metric.update([label], [pred])
_, pearsonr = metric.get()
assert pearsonr == pearsonr_expected
pred1 = mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])
label1 = mx.nd.array([[1, 0], [0, 1], [0, 1]])
pearsonr_expected_np = np.corrcoef(pred1.asnumpy().ravel(), label1.asnumpy().ravel())[0, 1]
pearsonr_expected_scipy, _ = scipy.stats.pearsonr(pred1.asnumpy().ravel(), label1.asnumpy().ravel())
macro_pr = mx.metric.create('pearsonr', average='macro')
micro_pr = mx.metric.create('pearsonr', average='micro')

assert np.isnan(macro_pr.get()[1])
assert np.isnan(micro_pr.get()[1])

macro_pr.update([label1], [pred1])
micro_pr.update([label1], [pred1])

np.testing.assert_almost_equal(macro_pr.get()[1], pearsonr_expected_np)
np.testing.assert_almost_equal(macro_pr.get()[1], pearsonr_expected_scipy)
np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_np)
np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_scipy)

pred2 = mx.nd.array([[1, 2], [3, 2], [4, 6]])
label2 = mx.nd.array([[1, 0], [0, 1], [0, 1]])
# Note that pred12 and label12 are the concatenations of (pred1, pred2) and (label1, label2)
pred12 = mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6], [1, 2], [3, 2], [4, 6]])
label12 = mx.nd.array([[1, 0], [0, 1], [0, 1], [1, 0], [0, 1], [0, 1]])

pearsonr_expected_np = np.corrcoef(pred12.asnumpy().ravel(), label12.asnumpy().ravel())[0, 1]
pearsonr_expected_scipy, _ = scipy.stats.pearsonr(pred12.asnumpy().ravel(), label12.asnumpy().ravel())

macro_pr.reset()
micro_pr.update([label2], [pred2])
macro_pr.update([label12], [pred12])
np.testing.assert_almost_equal(macro_pr.get()[1], pearsonr_expected_np)
np.testing.assert_almost_equal(macro_pr.get()[1], pearsonr_expected_scipy)
np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_np)
np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_scipy)
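End-to-end usage is the same as for any other metric; a short sketch (toy data, not taken from the test above) of accumulating a single correlation across mini-batches:

```python
import mxnet as mx

# Two toy mini-batches of (prediction, label) NDArrays.
batches = [
    (mx.nd.array([[0.3, 0.7], [0.0, 1.0]]), mx.nd.array([[1, 0], [0, 1]])),
    (mx.nd.array([[1, 2], [3, 2]]),         mx.nd.array([[1, 0], [0, 1]])),
]

metric = mx.metric.create('pearsonr', average='micro')
for pred, label in batches:
    metric.update([label], [pred])
print(metric.get())  # ('pearsonr', <single correlation over both batches>)
```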

def cm_batch(cm):
# generate a batch yielding a given confusion matrix