Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated precision_recall_curve.py #2490

Merged
merged 30 commits into from
Mar 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
d01467a
updated precision_recall_curve.py
sayantan1410 Feb 24, 2022
c5e2757
autopep8 fix
sayantan1410 Feb 24, 2022
0ce730f
removed unused imports
sayantan1410 Feb 24, 2022
dc884ec
made some small changes
sayantan1410 Feb 24, 2022
d0f52e4
solved unused import issue
sayantan1410 Feb 24, 2022
0c60883
autopep8 fix
sayantan1410 Feb 24, 2022
6034269
reverted back some changes and changed epoch_metric.py
sayantan1410 Feb 27, 2022
52de161
autopep8 fix
sayantan1410 Feb 27, 2022
36544bc
re written compute function for precision_recall_curve.py
sayantan1410 Mar 3, 2022
a6a8996
Merge branch 'feature' of https://github.com/sayantan1410/ignite into…
sayantan1410 Mar 3, 2022
2d45c21
reverted back epoch_metric.py
sayantan1410 Mar 3, 2022
57439fd
reverted back unnecessary changes to doc string
sayantan1410 Mar 3, 2022
96de71a
reverted a line break that was added by mistake
sayantan1410 Mar 3, 2022
19b568d
autopep8 fix
sayantan1410 Mar 3, 2022
74cb48b
corrected function annotation
sayantan1410 Mar 3, 2022
89a69e9
fixed mypy issues
sayantan1410 Mar 3, 2022
777ec91
Added tests for GPU and TPU
sayantan1410 Mar 4, 2022
8ac2ecf
autopep8 fix
sayantan1410 Mar 4, 2022
a81cc31
fixed a few tests in precision_recall_curve
sayantan1410 Mar 5, 2022
7fbc0c9
autopep8 fix
sayantan1410 Mar 5, 2022
95d8deb
Merge branch 'master' into feature
sdesrozis Mar 5, 2022
9d118c4
fixed a few errors for the tests
sayantan1410 Mar 6, 2022
6855a92
autopep8 fix
sayantan1410 Mar 6, 2022
810d0f3
added tests for array shape
sayantan1410 Mar 6, 2022
203b9a1
autopep8 fix
sayantan1410 Mar 6, 2022
ca9f0a4
made some small changes
sayantan1410 Mar 7, 2022
74aa143
Fixed all the errors in the tests
sayantan1410 Mar 7, 2022
83b599c
fix distributed computation
Mar 7, 2022
743c752
converted tensors to numpy array
sayantan1410 Mar 8, 2022
e004667
checking for approx equal
sayantan1410 Mar 8, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 46 additions & 3 deletions ignite/contrib/metrics/precision_recall_curve.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Any, Callable, Tuple
from typing import Any, Callable, cast, Tuple, Union

import torch

import ignite.distributed as idist
from ignite.exceptions import NotComputableError
from ignite.metrics import EpochMetric


Expand Down Expand Up @@ -69,7 +71,48 @@ def sigmoid_output_transform(output):

"""

def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False) -> None:
def __init__(
self,
output_transform: Callable = lambda x: x,
check_compute_fn: bool = False,
device: Union[str, torch.device] = torch.device("cpu"),
) -> None:
super(PrecisionRecallCurve, self).__init__(
precision_recall_curve_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn
precision_recall_curve_compute_fn,
output_transform=output_transform,
check_compute_fn=check_compute_fn,
device=device,
)

def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
sdesrozis marked this conversation as resolved.
Show resolved Hide resolved
if len(self._predictions) < 1 or len(self._targets) < 1:
raise NotComputableError("EpochMetric must have at least one example before it can be computed.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

- "EpochMetric must have ..."
+ "PrecisionRecallCurve must have ..."

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixing it soon.


_prediction_tensor = torch.cat(self._predictions, dim=0)
_target_tensor = torch.cat(self._targets, dim=0)

ws = idist.get_world_size()
if ws > 1 and not self._is_reduced:
# All gather across all processes
_prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
_target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor))
self._is_reduced = True

if idist.get_rank() == 0:
# Run compute_fn on zero rank only
precision, recall, thresholds = self.compute_fn(_prediction_tensor, _target_tensor)
precision = torch.Tensor(precision)
recall = torch.Tensor(recall)
# thresholds can have negative strides, not compatible with torch tensors
# https://discuss.pytorch.org/t/negative-strides-in-tensor-error/134287/2
thresholds = torch.Tensor(thresholds.copy())
Comment on lines +104 to +108
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tensor creation should be done with torch.tensor and not torch.Tensor

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, will change it.

else:
precision, recall, thresholds = None, None, None

if ws > 1:
# broadcast result to all processes
precision = idist.broadcast(precision, src=0, safe_mode=True)
recall = idist.broadcast(recall, src=0, safe_mode=True)
thresholds = idist.broadcast(thresholds, src=0, safe_mode=True)

return precision, recall, thresholds
4 changes: 2 additions & 2 deletions ignite/metrics/epoch_metric.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Callable, cast, List, Tuple, Union
from typing import Any, Callable, cast, List, Tuple, Union

import torch

Expand Down Expand Up @@ -136,7 +136,7 @@ def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
except Exception as e:
warnings.warn(f"Probably, there can be a problem with `compute_fn`:\n {e}.", EpochMetricWarning)

def compute(self) -> float:
def compute(self) -> Any:
sdesrozis marked this conversation as resolved.
Show resolved Hide resolved
if len(self._predictions) < 1 or len(self._targets) < 1:
raise NotComputableError("EpochMetric must have at least one example before it can be computed.")

Expand Down
204 changes: 197 additions & 7 deletions tests/ignite/contrib/metrics/test_precision_recall_curve.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os
from typing import Tuple
from unittest.mock import patch

import numpy as np
Expand All @@ -6,6 +8,7 @@
import torch
from sklearn.metrics import precision_recall_curve

import ignite.distributed as idist
from ignite.contrib.metrics.precision_recall_curve import PrecisionRecallCurve
from ignite.engine import Engine
from ignite.metrics.epoch_metric import EpochMetricWarning
Expand Down Expand Up @@ -38,9 +41,12 @@ def test_precision_recall_curve():

precision_recall_curve_metric.update((y_pred, y))
precision, recall, thresholds = precision_recall_curve_metric.compute()
precision = precision.numpy()
recall = recall.numpy()
thresholds = thresholds.numpy()

assert np.array_equal(precision, sk_precision)
assert np.array_equal(recall, sk_recall)
assert pytest.approx(precision) == sk_precision
assert pytest.approx(recall) == sk_recall
# assert thresholds almost equal, due to numpy->torch->numpy conversion
np.testing.assert_array_almost_equal(thresholds, sk_thresholds)

Expand Down Expand Up @@ -70,9 +76,11 @@ def update_fn(engine, batch):

data = list(range(size // batch_size))
precision, recall, thresholds = engine.run(data, max_epochs=1).metrics["precision_recall_curve"]

assert np.array_equal(precision, sk_precision)
assert np.array_equal(recall, sk_recall)
precision = precision.numpy()
recall = recall.numpy()
thresholds = thresholds.numpy()
assert pytest.approx(precision) == sk_precision
assert pytest.approx(recall) == sk_recall
# assert thresholds almost equal, due to numpy->torch->numpy conversion
np.testing.assert_array_almost_equal(thresholds, sk_thresholds)

Expand Down Expand Up @@ -103,9 +111,12 @@ def update_fn(engine, batch):

data = list(range(size // batch_size))
precision, recall, thresholds = engine.run(data, max_epochs=1).metrics["precision_recall_curve"]
precision = precision.numpy()
recall = recall.numpy()
thresholds = thresholds.numpy()

assert np.array_equal(precision, sk_precision)
assert np.array_equal(recall, sk_recall)
assert pytest.approx(precision) == sk_precision
assert pytest.approx(recall) == sk_recall
# assert thresholds almost equal, due to numpy->torch->numpy conversion
np.testing.assert_array_almost_equal(thresholds, sk_thresholds)

Expand All @@ -124,3 +135,182 @@ def test_check_compute_fn():

em = PrecisionRecallCurve(check_compute_fn=False)
em.update(output)


def _test_distrib_compute(device):
    """Check ``PrecisionRecallCurve.compute`` against scikit-learn on gathered data.

    Each case feeds the metric (in one shot or in batches), gathers the local
    inputs from all processes with ``idist.all_gather`` and compares the three
    outputs of ``compute()`` with ``sklearn.metrics.precision_recall_curve``
    evaluated on the gathered arrays.

    Args:
        device: device of the current process; when ``device.type`` is not
            ``"xla"`` the metric is additionally exercised on ``idist.device()``.
    """
    rank = idist.get_rank()
    torch.manual_seed(12)

    def _test(y_pred, y, batch_size, metric_device):
        metric_device = torch.device(metric_device)
        prc = PrecisionRecallCurve(device=metric_device)

        # NOTE(review): presumably re-seeds per rank so that data drawn after
        # this call differs between processes -- confirm.
        torch.manual_seed(10 + rank)

        prc.reset()
        if batch_size > 1:
            # Ceiling division: the previous ``// batch_size + 1`` produced an
            # empty trailing batch whenever the sample count was an exact
            # multiple of ``batch_size``.
            n_iters = (y.shape[0] + batch_size - 1) // batch_size
            for i in range(n_iters):
                idx = i * batch_size
                prc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
        else:
            prc.update((y_pred, y))

        # gather y_pred, y from all processes to build the reference inputs
        y_pred = idist.all_gather(y_pred)
        y = idist.all_gather(y)

        np_y = y.cpu().numpy()
        np_y_pred = y_pred.cpu().numpy()

        res = prc.compute()

        assert isinstance(res, tuple)
        # Compute the scikit-learn reference once instead of once per output.
        expected = precision_recall_curve(np_y, np_y_pred)
        assert len(res) == len(expected)
        for sk_value, value in zip(expected, res):
            # Compare as numpy arrays, as the non-distributed tests do.
            assert sk_value == pytest.approx(value.cpu().numpy())

    def get_test_cases():
        test_cases = [
            # Binary input data of shape (N,) or (N, 1)
            (torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,)), 1),
            (torch.randint(0, 2, size=(10, 1)), torch.randint(0, 2, size=(10, 1)), 1),
            # updated batches
            (torch.randint(0, 2, size=(50,)), torch.randint(0, 2, size=(50,)), 16),
            (torch.randint(0, 2, size=(50, 1)), torch.randint(0, 2, size=(50, 1)), 16),
        ]
        return test_cases

    for _ in range(5):
        test_cases = get_test_cases()
        for y_pred, y, batch_size in test_cases:
            _test(y_pred, y, batch_size, "cpu")
            if device.type != "xla":
                _test(y_pred, y, batch_size, idist.device())


def _test_distrib_integration(device):

rank = idist.get_rank()
Copy link
Collaborator

@vfdev-5 vfdev-5 Mar 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test case does not use `rank`, and each distributed process generates its own random predictions and true values. The PrecisionRecallCurve implementation should gather the data from all processes, but the result is checked against a computation on the local process's data only:

        np_y_true = y_true.cpu().numpy().ravel()
        np_y_preds = y_preds.cpu().numpy().ravel()
        sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y_true, np_y_preds)

I would expect this to fail, but it looks like it is passing. @sayantan1410, can you check why that is?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, sure, I will check it!

Copy link
Contributor

@sdesrozis sdesrozis Mar 8, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, the error comes from the following test, which seems wrong (and needs a fix too).

@sayantan1410 Have a look at this correct implementation:

def _test(n_epochs, metric_device):

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sdesrozis Will check it soon !!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sdesrozis A small question,

def _test(n_epochs, metric_device):

Here, too, we don't have anything like idist.all_gather, so how is the data gathered from all the processes?

torch.manual_seed(12)

def _test(n_epochs, metric_device):
metric_device = torch.device(metric_device)
n_iters = 80
size = 151
y_true = torch.randint(0, 2, (size,)).to(device)
y_preds = torch.randint(0, 2, (size,)).to(device)

def update(engine, i):
return (
y_preds[i * size : (i + 1) * size],
y_true[i * size : (i + 1) * size],
)

engine = Engine(update)

prc = PrecisionRecallCurve(device=metric_device)
prc.attach(engine, "prc")

data = list(range(n_iters))
engine.run(data=data, max_epochs=n_epochs)

assert "prc" in engine.state.metrics

precision, recall, thresholds = engine.state.metrics["prc"]

np_y_true = y_true.cpu().numpy().ravel()
np_y_preds = y_preds.cpu().numpy().ravel()

sk_precision, sk_recall, sk_thresholds = precision_recall_curve(np_y_true, np_y_preds)

assert precision.shape == sk_precision.shape
assert recall.shape == sk_recall.shape
assert thresholds.shape == sk_thresholds.shape
assert pytest.approx(precision) == sk_precision
sdesrozis marked this conversation as resolved.
Show resolved Hide resolved
assert pytest.approx(recall) == sk_recall
assert pytest.approx(thresholds) == sk_thresholds

metric_devices = ["cpu"]
if device.type != "xla":
metric_devices.append(idist.device())
for metric_device in metric_devices:
for _ in range(2):
_test(n_epochs=1, metric_device=metric_device)
_test(n_epochs=2, metric_device=metric_device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
    """Run both distributed PrecisionRecallCurve checks on ``idist.device()``.

    Requests the ``distributed_context_single_node_nccl`` fixture; skipped
    without native distributed support or when no CUDA device is visible.
    """
    current_device = idist.device()
    for check in (_test_distrib_compute, _test_distrib_integration):
        check(current_device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
    """Run both distributed PrecisionRecallCurve checks on ``idist.device()``.

    Requests the ``distributed_context_single_node_gloo`` fixture; skipped
    without native distributed support.
    """
    current_device = idist.device()
    for check in (_test_distrib_compute, _test_distrib_integration):
        check(current_device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
def test_distrib_hvd(gloo_hvd_executor):
    """Launch both distributed PrecisionRecallCurve checks through ``gloo_hvd_executor``.

    Uses CUDA with one process per visible GPU when CUDA is available,
    otherwise the CPU with 4 processes.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
        nproc = torch.cuda.device_count()
    else:
        device = torch.device("cpu")
        nproc = 4

    for check in (_test_distrib_compute, _test_distrib_integration):
        gloo_hvd_executor(check, (device,), np=nproc, do_init=True)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
    """Run both distributed PrecisionRecallCurve checks on ``idist.device()``.

    Requests the ``distributed_context_multi_node_gloo`` fixture; skipped
    unless ``MULTINODE_DISTRIB`` is set in the environment.
    """
    current_device = idist.device()
    for check in (_test_distrib_compute, _test_distrib_integration):
        check(current_device)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
    """Run both distributed PrecisionRecallCurve checks on ``idist.device()``.

    Requests the ``distributed_context_multi_node_nccl`` fixture; skipped
    unless ``GPU_MULTINODE_DISTRIB`` is set in the environment.
    """
    current_device = idist.device()
    for check in (_test_distrib_compute, _test_distrib_integration):
        check(current_device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_single_device_xla():
    """Run both distributed PrecisionRecallCurve checks on ``idist.device()``.

    Skipped when ``NUM_TPU_WORKERS`` is set or XLA support is unavailable.
    """
    current_device = idist.device()
    for check in (_test_distrib_compute, _test_distrib_integration):
        check(current_device)


def _test_distrib_xla_nprocs(index):
    """Run both distributed PrecisionRecallCurve checks on ``idist.device()``.

    Args:
        index: not used in the body; presumably the process index supplied by
            the XLA spawner -- TODO confirm.
    """
    current_device = idist.device()
    for check in (_test_distrib_compute, _test_distrib_integration):
        check(current_device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_xla_nprocs(xmp_executor):
    """Launch ``_test_distrib_xla_nprocs`` through ``xmp_executor``.

    The process count is read from the ``NUM_TPU_WORKERS`` environment variable.
    """
    worker_count = int(os.environ["NUM_TPU_WORKERS"])
    xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=worker_count)