Fixed GMRAE implementation + Improved for DDP #2096

Merged (9 commits) on Jul 4, 2021
@@ -1,9 +1,11 @@
from typing import Tuple, Union, cast
from typing import List, Tuple, cast

import torch

import ignite.distributed as idist
from ignite.contrib.metrics.regression._base import _BaseRegression
from ignite.exceptions import NotComputableError
from ignite.metrics.metric import reinit__is_reduced


class GeometricMeanRelativeAbsoluteError(_BaseRegression):
@@ -12,7 +14,8 @@ class GeometricMeanRelativeAbsoluteError(_BaseRegression):
.. math::
\text{GMRAE} = \exp(\frac{1}{n}\sum_{j=1}^n \ln\frac{|A_j - P_j|}{|A_j - \bar{A}|})

where :math:`A_j` is the ground truth and :math:`P_j` is the predicted value.
where :math:`A_j` is the ground truth, :math:`P_j` is the predicted value
and :math:`\bar{A}` is the mean of the ground truth.
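
As a small worked example (values chosen purely for illustration): for :math:`A = (1, 2, 4)` and
:math:`P = (2, 1, 3)`, :math:`\bar{A} = 7/3`, the relative absolute errors are :math:`(3/4, 3, 3/5)`,
and GMRAE is their geometric mean, :math:`(27/20)^{1/3} \approx 1.105`.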

More details can be found in `Botchkarev 2018`__.

@@ -23,6 +26,18 @@ class GeometricMeanRelativeAbsoluteError(_BaseRegression):

Parameters are inherited from ``Metric.__init__``.

.. warning::

The current implementation of GMRAE stores all input data (output and target)
as tensors before computing the metric.
This can potentially lead to a memory error if the input data is larger than available RAM.

In distributed configurations, all stored data (output and target) is collected across all processes
using an all-gather collective operation. This can potentially lead to a memory error.

The compute method computes the metric on the rank 0 process only, and the final result is broadcast to
all processes.

Args:
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
@@ -34,23 +49,46 @@ class GeometricMeanRelativeAbsoluteError(_BaseRegression):
non-blocking. By default, CPU.
"""

@reinit__is_reduced
def reset(self) -> None:
self._sum_y = 0.0 # type: Union[float, torch.Tensor]
self._num_examples = 0
self._sum_of_errors = 0.0 # type: Union[float, torch.Tensor]
self._predictions = [] # type: List[torch.Tensor]
self._targets = [] # type: List[torch.Tensor]

def _update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
y_pred, y = output
self._sum_y += y.sum()
self._num_examples += y.shape[0]
y_mean = self._sum_y / self._num_examples
numerator = torch.abs(y.view_as(y_pred) - y_pred)
denominator = torch.abs(y.view_as(y_pred) - y_mean)
self._sum_of_errors += torch.log(numerator / denominator).sum()
y_pred, y = output[0].detach(), output[1].detach()

y_pred = y_pred.clone().to(self._device)
y = y.clone().to(self._device)

self._predictions.append(y_pred)
self._targets.append(y)

def compute(self) -> float:
if self._num_examples == 0:
if len(self._predictions) < 1 or len(self._targets) < 1:
raise NotComputableError(
"GeometricMeanRelativeAbsoluteError must have at least one example before it can be computed."
)
return torch.exp(torch.mean(cast(torch.Tensor, self._sum_of_errors) / self._num_examples)).item()

_prediction_tensor = torch.cat(self._predictions, dim=0)
_target_tensor = torch.cat(self._targets, dim=0)

ws = idist.get_world_size()

if ws > 1:
# All gather across all processes
_prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
_target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor))

result = 0.0
if idist.get_rank() == 0:
result = torch.exp(
torch.log(
torch.abs(_target_tensor - _prediction_tensor) / torch.abs(_target_tensor - _target_tensor.mean())
).mean()
).item()

if ws > 1:
# broadcast result to all processes
result = cast(float, idist.broadcast(result, src=0))

return result
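
For orientation, here is a minimal usage sketch of the updated metric attached to an Engine; the
evaluation step, data, and tensor shapes below are illustrative assumptions and not part of this PR:

import torch

from ignite.contrib.metrics.regression import GeometricMeanRelativeAbsoluteError
from ignite.engine import Engine

def eval_step(engine, batch):
    # batch is assumed to already be a (y_pred, y) pair of 1-d tensors
    y_pred, y = batch
    return y_pred, y

evaluator = Engine(eval_step)

gmrae = GeometricMeanRelativeAbsoluteError()
gmrae.attach(evaluator, "gmrae")

data = [(torch.rand(8), torch.rand(8)) for _ in range(4)]
state = evaluator.run(data, max_epochs=1)
print(state.metrics["gmrae"])

In a distributed run, each process feeds its own shard; compute gathers all stored tensors on rank 0
and broadcasts the result, as implemented above.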
@@ -1,7 +1,10 @@
import os

import numpy as np
import pytest
import torch

import ignite.distributed as idist
from ignite.contrib.metrics.regression import GeometricMeanRelativeAbsoluteError
from ignite.engine import Engine
from ignite.exceptions import NotComputableError
@@ -26,7 +29,7 @@ def test_wrong_input_shapes():
m.update((torch.rand(4, 1), torch.rand(4,)))


def test_geometric_mean_relative_absolute_error():
def test_compute():
size = 51
np_y_pred = np.random.rand(size,)
np_y = np.random.rand(size,)
@@ -42,79 +45,184 @@ def test_geometric_mean_relative_absolute_error():
assert np_gmrae == pytest.approx(m.compute())


def test_geometric_mean_relative_absolute_error_2():
def test_integration():

np.random.seed(1)
size = 105
np_y_pred = np.random.rand(size, 1)
np_y = np.random.rand(size, 1)
np.random.shuffle(np_y)
y_pred = torch.rand(size=(100,))
y = torch.rand(size=(100,))

np_y_sum = 0
num_examples = 0
num_sum_of_errors = 0
np_gmrae = 0
batch_size = 10

def update_fn(engine, batch):
idx = (engine.state.iteration - 1) * batch_size
y_true_batch = np_y[idx : idx + batch_size]
y_pred_batch = np_y_pred[idx : idx + batch_size]
return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

engine = Engine(update_fn)

m = GeometricMeanRelativeAbsoluteError()
y_pred = torch.from_numpy(np_y_pred)
y = torch.from_numpy(np_y)
m.attach(engine, "gmrae")

m.reset()
n_iters = 15
batch_size = size // n_iters
for i in range(n_iters + 1):
idx = i * batch_size
np_y_i = np_y[idx : idx + batch_size]
np_y_pred_i = np_y_pred[idx : idx + batch_size]
np_y = y.numpy().ravel()
np_y_pred = y_pred.numpy().ravel()

np_y_sum += np_y_i.sum()
num_examples += np_y_i.shape[0]
np_mean = np_y_sum / num_examples
data = list(range(y_pred.shape[0] // batch_size))
gmrae = engine.run(data, max_epochs=1).metrics["gmrae"]

np_gmrae += np.log(np.abs(np_y_i - np_y_pred_i) / np.abs(np_y_i - np_mean)).sum()
m.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
sum_errors = np.log(np.abs(np_y - np_y_pred) / np.abs(np_y - np_y.mean())).sum()
np_len = len(y_pred)
np_ans = np.exp(sum_errors / np_len)

assert np.exp(np_gmrae / num_examples) == pytest.approx(m.compute())
assert np_ans == pytest.approx(gmrae)


def test_integration_geometric_mean_relative_absolute_error_with_output_transform():
def _test_distrib_compute(device):

np.random.seed(1)
size = 105
np_y_pred = np.random.rand(size, 1)
np_y = np.random.rand(size, 1)
np.random.shuffle(np_y)
rank = idist.get_rank()
torch.manual_seed(12)

np_y_sum = 0
num_examples = 0
num_sum_of_errors = 0
np_gmrae = 0
def _test(metric_device):
metric_device = torch.device(metric_device)
m = GeometricMeanRelativeAbsoluteError(device=metric_device)
torch.manual_seed(10 + rank)

n_iters = 15
batch_size = size // n_iters
for i in range(n_iters + 1):
idx = i * batch_size
np_y_i = np_y[idx : idx + batch_size]
np_y_pred_i = np_y_pred[idx : idx + batch_size]
y_pred = torch.rand(size=(100,), device=device)
y = torch.rand(size=(100,), device=device)

np_y_sum += np_y_i.sum()
num_examples += np_y_i.shape[0]
np_mean = np_y_sum / num_examples
m.update((y_pred, y))

np_gmrae += np.log(np.abs(np_y_i - np_y_pred_i) / np.abs(np_y_i - np_mean)).sum()
y_pred = idist.all_gather(y_pred)
y = idist.all_gather(y)

def update_fn(engine, batch):
idx = (engine.state.iteration - 1) * batch_size
y_true_batch = np_y[idx : idx + batch_size]
y_pred_batch = np_y_pred[idx : idx + batch_size]
return idx, torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)
np_y = y.cpu().numpy()
np_y_pred = y_pred.cpu().numpy()

np_gmrae = np.exp(np.log(np.abs(np_y - np_y_pred) / np.abs(np_y - np_y.mean())).mean())

assert m.compute() == pytest.approx(np_gmrae, rel=1e-4)

for _ in range(3):
_test("cpu")
if device.type != "xla":
_test(idist.device())


def _test_distrib_integration(device):

rank = idist.get_rank()
torch.manual_seed(12)

def _test(n_epochs, metric_device):
metric_device = torch.device(metric_device)
n_iters = 80
s = 16
n_classes = 2

offset = n_iters * s
y_true = torch.rand(size=(offset * idist.get_world_size(),)).to(device)
y_preds = torch.rand(size=(offset * idist.get_world_size(),)).to(device)

def update(engine, i):
return (
y_preds[i * s + rank * offset : (i + 1) * s + rank * offset],
y_true[i * s + rank * offset : (i + 1) * s + rank * offset],
)

engine = Engine(update)

gmrae = GeometricMeanRelativeAbsoluteError(device=metric_device)
gmrae.attach(engine, "gmrae")

data = list(range(n_iters))
engine.run(data=data, max_epochs=n_epochs)

assert "gmrae" in engine.state.metrics

res = engine.state.metrics["gmrae"]

np_y = y_true.cpu().numpy()
np_y_pred = y_preds.cpu().numpy()

np_gmrae = np.exp(np.log(np.abs(np_y - np_y_pred) / np.abs(np_y - np_y.mean())).mean())

assert pytest.approx(res, rel=1e-4) == np_gmrae

metric_devices = ["cpu"]
if device.type != "xla":
metric_devices.append(idist.device())
for metric_device in metric_devices:
for _ in range(2):
_test(n_epochs=1, metric_device=metric_device)
_test(n_epochs=2, metric_device=metric_device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_distrib_nccl_gpu(distributed_context_single_node_nccl):

device = idist.device()
_test_distrib_compute(device)
_test_distrib_integration(device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):

device = idist.device()
_test_distrib_compute(device)
_test_distrib_integration(device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
def test_distrib_hvd(gloo_hvd_executor):
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()

gloo_hvd_executor(_test_distrib_compute, (device,), np=nproc, do_init=True)
gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):

device = idist.device()
_test_distrib_compute(device)
_test_distrib_integration(device)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):

device = idist.device()
_test_distrib_compute(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_single_device_xla():
device = idist.device()
_test_distrib_compute(device)
_test_distrib_integration(device)

engine = Engine(update_fn)

m = GeometricMeanRelativeAbsoluteError(output_transform=lambda x: (x[1], x[2]))
m.attach(engine, "geometric_mean_relative_absolute_error")
def _test_distrib_xla_nprocs(index):
device = idist.device()
_test_distrib_compute(device)
_test_distrib_integration(device)

data = list(range(size // batch_size))
gmrae = engine.run(data, max_epochs=1).metrics["geometric_mean_relative_absolute_error"]

assert np.exp(np_gmrae / num_examples) == pytest.approx(m.compute())
@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_xla_nprocs(xmp_executor):
n = int(os.environ["NUM_TPU_WORKERS"])
xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)