Skip to content

Commit

Permalink
use scikit in test
Browse files Browse the repository at this point in the history
  • Loading branch information
sethah committed Feb 14, 2018
1 parent 3c86317 commit 797c01c
Showing 1 changed file with 34 additions and 27 deletions.
61 changes: 34 additions & 27 deletions tests/python/unittest/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import numpy as np
import json

from sklearn.metrics import f1_score as scikit_f1

def check_metric(metric, *args, **kwargs):
metric = mx.metric.create(metric, *args, **kwargs)
str_metric = json.dumps(metric.get_config())
Expand Down Expand Up @@ -55,22 +57,22 @@ def test_acc():
assert acc == expected_acc

def test_f1():
microF1 = mx.metric.create("f1", average="micro")
macroF1 = mx.metric.F1(average="macro")
micro_f1 = mx.metric.create("f1", average="micro")
macro_f1 = mx.metric.F1(average="macro")

assert np.isnan(macroF1.get()[1])
assert np.isnan(microF1.get()[1])
assert np.isnan(macro_f1.get()[1])
assert np.isnan(micro_f1.get()[1])

# check divide by zero
pred = mx.nd.array([[0.9, 0.1],
[0.8, 0.2]])
label = mx.nd.array([0, 0])
macroF1.update([label], [pred])
microF1.update([label], [pred])
assert macroF1.get()[1] == 0.0
assert microF1.get()[1] == 0.0
macroF1.reset()
microF1.reset()
macro_f1.update([label], [pred])
micro_f1.update([label], [pred])
assert macro_f1.get()[1] == 0.0
assert micro_f1.get()[1] == 0.0
macro_f1.reset()
micro_f1.reset()

pred11 = mx.nd.array([[0.1, 0.9],
[0.5, 0.5]])
Expand All @@ -83,23 +85,28 @@ def test_f1():
pred22 = mx.nd.array([[0.2, 0.8]])
label22 = mx.nd.array([1])

microF1.update([label11, label12], [pred11, pred12])
macroF1.update([label11, label12], [pred11, pred12])
assert microF1.num_inst == 4
assert macroF1.num_inst == 1
# f1 = 2 * tp / (2 * tp + fp + fn)
fscore1 = 2. * (1) / (2 * 1 + 1 + 0)
np.testing.assert_almost_equal(microF1.get()[1], fscore1)
np.testing.assert_almost_equal(macroF1.get()[1], fscore1)

microF1.update([label21, label22], [pred21, pred22])
macroF1.update([label21, label22], [pred21, pred22])
assert microF1.num_inst == 6
assert macroF1.num_inst == 2
fscore2 = 2. * (1) / (2 * 1 + 0 + 0)
fscore_total = 2. * (1 + 1) / (2 * (1 + 1) + (1 + 0) + (0 + 0))
np.testing.assert_almost_equal(microF1.get()[1], fscore_total)
np.testing.assert_almost_equal(macroF1.get()[1], (fscore1 + fscore2) / 2.)
micro_f1.update([label11, label12], [pred11, pred12])
macro_f1.update([label11, label12], [pred11, pred12])
assert micro_f1.num_inst == 4
assert macro_f1.num_inst == 1
np_pred1 = np.concatenate([mx.nd.argmax(pred11, axis=1).asnumpy(),
mx.nd.argmax(pred12, axis=1).asnumpy()])
np_label1 = np.concatenate([label11.asnumpy(), label12.asnumpy()])
np.testing.assert_almost_equal(micro_f1.get()[1], scikit_f1(np_label1, np_pred1))
np.testing.assert_almost_equal(macro_f1.get()[1], scikit_f1(np_label1, np_pred1))

micro_f1.update([label21, label22], [pred21, pred22])
macro_f1.update([label21, label22], [pred21, pred22])
assert micro_f1.num_inst == 6
assert macro_f1.num_inst == 2
np_pred2 = np.concatenate([mx.nd.argmax(pred21, axis=1).asnumpy(),
mx.nd.argmax(pred22, axis=1).asnumpy()])
np_pred_total = np.concatenate([np_pred1, np_pred2])
np_label2 = np.concatenate([label21.asnumpy(), label22.asnumpy()])
np_label_total = np.concatenate([np_label1, np_label2])
np.testing.assert_almost_equal(micro_f1.get()[1], scikit_f1(np_label_total, np_pred_total))
np.testing.assert_almost_equal(macro_f1.get()[1], (scikit_f1(np_label1, np_pred1) +
scikit_f1(np_label2, np_pred2)) / 2)

def test_perplexity():
pred = mx.nd.array([[0.8, 0.2], [0.2, 0.8], [0, 1.]])
Expand Down

0 comments on commit 797c01c

Please sign in to comment.