
mx.metric F1 is using numpy logic #9586

Open

szha opened this issue Jan 27, 2018 · 3 comments

@szha
Member

szha commented Jan 27, 2018

The metric module has been using numpy logic and is not benefiting from existing mxnet operators.

https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/metric.py#L535-L569
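
For reference, the F1 metric there converts everything to numpy and counts true/false positives in a pure-Python loop, roughly along these lines (a paraphrased sketch, not the exact code at the link):

import numpy

# Paraphrased sketch of the numpy-based update logic the issue refers to;
# every update copies the arrays off-device and loops over samples in Python.
def f1_update(labels, preds):
    for label, pred in zip(labels, preds):
        pred = pred.asnumpy()                     # device -> host copy
        label = label.asnumpy().astype('int32')   # device -> host copy
        pred_label = numpy.argmax(pred, axis=1)   # hard predictions

        true_positives, false_positives, false_negatives = 0., 0., 0.
        for y_pred, y_true in zip(pred_label, label):  # per-sample Python loop
            if y_pred == 1 and y_true == 1:
                true_positives += 1
            elif y_pred == 1 and y_true == 0:
                false_positives += 1
            elif y_pred == 0 and y_true == 1:
                false_negatives += 1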

@sxjscience
Member

sxjscience commented Feb 13, 2018

Bringing the discussion to the correct place: I implemented an ndarray version of the F1 score while running experiments, and I've included my nd_f1 below.

It may be useful if we want to accelerate the F1 computation in the future. Also, we can exploit the fact that micro F1 is equivalent to accuracy for single-label classification to speed up the computation.
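
As a quick illustrative check of that equivalence (sklearn only, separate from the nd_f1 snippet below):

import numpy as np
from sklearn.metrics import accuracy_score, f1_score

# For single-label multi-class predictions, micro-averaged F1
# equals plain accuracy.
y_true = np.random.randint(0, 50, size=(100000,))
y_pred = np.random.randint(0, 50, size=(100000,))
assert np.isclose(f1_score(y_true, y_pred, average="micro"),
                  accuracy_score(y_true, y_pred))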

import mxnet.ndarray as nd
from sklearn.metrics import f1_score
import numpy as np
import mxnet as mx
import time

def nd_f1(pred, label, num_class, average="micro"):
    """Evaluate F1 using mx.nd.NDArray

    Parameters
    ----------
    pred : nd.NDArray
        Shape (num, label_num) or (num,)
    label : nd.NDArray
        Shape (num, label_num) or (num,)
    num_class : int
    average : str
        Either 'micro' or 'macro'

    Returns
    -------
    f1 : float
    """
    if pred.dtype != np.float32:
        # Cast both inputs to float32 so the elementwise comparisons below match.
        pred = pred.astype(np.float32)
        label = label.astype(np.float32)
    assert num_class > 1
    assert pred.ndim == label.ndim
    if num_class == 2 and average == "micro":
        # Fast path for binary/multilabel micro F1: count tp/fp/fn directly,
        # without the one-hot expansion.
        tp = nd.sum((pred == 1) * (label == 1)).asscalar()
        fp = nd.sum((pred == 1) * (label == 0)).asscalar()
        fn = nd.sum((pred == 0) * (label == 1)).asscalar()
        precision = float(tp) / (tp + fp)
        recall = float(tp) / (tp + fn)
        f1 = 2 * (precision * recall) / (precision + recall)
    else:
        assert num_class is not None
        # One-hot encode predictions and labels so per-class tp/fp/fn
        # counts can be computed with elementwise products.
        pred_onehot = nd.one_hot(indices=pred, depth=num_class)
        label_onehot = nd.one_hot(indices=label, depth=num_class)
        tp = pred_onehot * label_onehot
        fp = pred_onehot * (1 - label_onehot)
        fn = (1 - pred_onehot) * label_onehot
        if average == "micro":
            tp = nd.sum(tp).asscalar()
            fp = nd.sum(fp).asscalar()
            fn = nd.sum(fn).asscalar()
            precision = float(tp) / (tp + fp)
            recall = float(tp) / (tp + fn)
            f1 = 2 * (precision * recall) / (precision + recall)
        elif average == "macro":
            # Reduce over the sample (and, if multilabel, label) axes,
            # keeping one tp/fp/fn count per class.
            if tp.ndim == 3:
                tp = nd.sum(tp, axis=(0, 1))
                fp = nd.sum(fp, axis=(0, 1))
                fn = nd.sum(fn, axis=(0, 1))
            else:
                tp = nd.sum(tp, axis=0)
                fp = nd.sum(fp, axis=0)
                fn = nd.sum(fn, axis=0)
            # Note: this averages precision and recall across classes before
            # combining them, rather than averaging the per-class F1 scores as
            # sklearn's macro average does, hence the small abs diffs below.
            precision = nd.mean(tp / (tp + fp)).asscalar()
            recall = nd.mean(tp / (tp + fn)).asscalar()
            f1 = 2 * (precision * recall) / (precision + recall)
        else:
            raise NotImplementedError
    return f1

# Two test cases: 50-class single-label, and 121-way multilabel binary.
for pred_npy, label_npy, num_class\
        in [(np.random.randint(0, 50, size=(100000,)),
             np.random.randint(0, 50, size=(100000,)),
             50),
            (np.random.randint(0, 2, size=(10000, 121)),
             np.random.randint(0, 2, size=(10000, 121)),
             2)]:
    # Test F1 score
    for average in ['micro', 'macro']:
        start = time.time()
        for _ in range(5):
            f1_npy = f1_score(y_true=label_npy, y_pred=pred_npy, average=average)
        end = time.time()
        print("Average=", average, "Npy Time Spent:", end - start)
        # Benchmark on GPU (swap in mx.cpu() if no GPU is available).
        pred_nd = nd.array(pred_npy, ctx=mx.gpu(), dtype=np.float32)
        label_nd = nd.array(label_npy, ctx=mx.gpu(), dtype=np.float32)
        nd.waitall()
        # Warm-up run so one-off initialization doesn't skew the timing.
        f1_nd = nd_f1(pred=pred_nd,
                      label=label_nd,
                      num_class=num_class,
                      average=average)
        nd.waitall()
        start = time.time()
        for _ in range(5):
            f1_nd = nd_f1(pred=pred_nd,
                          label=label_nd,
                          num_class=num_class,
                          average=average)
            nd.waitall()
        end = time.time()
        print("Average=", average, "NDArray Time Spent:", end - start, 'abs diff:', abs(f1_nd - f1_npy))

Result:

Average= micro Npy Time Spent: 0.1795516014099121
Average= micro NDArray Time Spent: 0.033802032470703125 abs diff: 0.0
Average= macro Npy Time Spent: 0.17911505699157715
Average= macro NDArray Time Spent: 0.07393026351928711 abs diff: 4.64383991273e-06
Average= micro Npy Time Spent: 0.6379575729370117
Average= micro NDArray Time Spent: 0.029665708541870117 abs diff: 0.0
Average= macro Npy Time Spent: 0.6377367973327637
Average= macro NDArray Time Spent: 0.034937143325805664 abs diff: 0.000381544355229

@szha
Member Author

szha commented Feb 13, 2018

@sxjscience awesome. Would you propose a PR after #9777 is merged? If/when you do, remember to report the benchmark test results from #9705.

@sxjscience
Member

OK, I'll open a PR after it's merged.
