Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improving various tests with parametrize #2552

Merged
merged 14 commits into from
Apr 19, 2022
Merged
Empty file added (N
Empty file.
260 changes: 121 additions & 139 deletions tests/ignite/metrics/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,57 +62,50 @@ def test_binary_wrong_inputs():
acc.update((torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10, 5, 6)).long()))


def test_binary_input():

@pytest.fixture(params=[item for item in range(12)])
nmcguire101 marked this conversation as resolved.
Show resolved Hide resolved
def test_data_binary(request):
return [
# Binary accuracy on input of shape (N, 1) or (N, )
(torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10,)).long(), 1),
(torch.randint(0, 2, size=(10, 1)).long(), torch.randint(0, 2, size=(10, 1)).long(), 1),
# updated batches
(torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 16),
(torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 16),
# Binary accuracy on input of shape (N, L)
(torch.randint(0, 2, size=(10, 5)).long(), torch.randint(0, 2, size=(10, 5)).long(), 1),
(torch.randint(0, 2, size=(10, 8)).long(), torch.randint(0, 2, size=(10, 8)).long(), 1),
# updated batches
(torch.randint(0, 2, size=(50, 5)).long(), torch.randint(0, 2, size=(50, 5)).long(), 16),
(torch.randint(0, 2, size=(50, 8)).long(), torch.randint(0, 2, size=(50, 8)).long(), 16),
# Binary accuracy on input of shape (N, H, W, ...)
(torch.randint(0, 2, size=(4, 1, 12, 10)).long(), torch.randint(0, 2, size=(4, 1, 12, 10)).long(), 1),
(torch.randint(0, 2, size=(15, 1, 20, 10)).long(), torch.randint(0, 2, size=(15, 1, 20, 10)).long(), 1),
# updated batches
(torch.randint(0, 2, size=(50, 1, 12, 10)).long(), torch.randint(0, 2, size=(50, 1, 12, 10)).long(), 16),
(torch.randint(0, 2, size=(50, 1, 20, 10)).long(), torch.randint(0, 2, size=(50, 1, 20, 10)).long(), 16),
][request.param]


@pytest.mark.parametrize("n_times", range(5))
def test_binary_input(n_times, test_data_binary):
acc = Accuracy()

def _test(y_pred, y, batch_size):
    """Update ``acc`` with ``(y_pred, y)`` and check it against ``accuracy_score``.

    ``acc`` comes from the enclosing scope. With ``batch_size`` greater
    than 1 the data is fed in consecutive slices of that size; otherwise it
    is fed in a single update. ``accuracy_score`` is presumably
    scikit-learn's (its import is not visible here).
    """
    acc.reset()
    if batch_size <= 1:
        acc.update((y_pred, y))
    else:
        # Same number of slices as the original ``shape[0] // batch_size + 1``.
        n_chunks = y.shape[0] // batch_size + 1
        for start in range(0, n_chunks * batch_size, batch_size):
            stop = start + batch_size
            acc.update((y_pred[start:stop], y[start:stop]))

    # Flatten both tensors so the reference sees one label per element.
    flat_target = y.numpy().ravel()
    flat_pred = y_pred.numpy().ravel()

    assert acc._type == "binary"
    assert isinstance(acc.compute(), float)
    assert accuracy_score(flat_target, flat_pred) == pytest.approx(acc.compute())

def get_test_cases():
    """Build a fresh list of random binary-accuracy cases.

    Each entry is ``(y_pred, y, batch_size)``; both tensors are int64, hold
    0/1 values and share the shape listed in the table below.
    """
    # (shape shared by y_pred and y, batch_size)
    shape_and_batch = (
        # Binary accuracy on input of shape (N, 1) or (N, )
        ((10,), 1),
        ((10, 1), 1),
        # updated batches
        ((50,), 16),
        ((50, 1), 16),
        # Binary accuracy on input of shape (N, L)
        ((10, 5), 1),
        ((10, 8), 1),
        # updated batches
        ((50, 5), 16),
        ((50, 8), 16),
        # Binary accuracy on input of shape (N, H, W, ...)
        ((4, 1, 12, 10), 1),
        ((15, 1, 20, 10), 1),
        # updated batches
        ((50, 1, 12, 10), 16),
        ((50, 1, 20, 10), 16),
    )
    # Tuple elements are evaluated left to right, so y_pred is drawn before y
    # for every case, exactly as in a hand-written list of literals.
    return [
        (torch.randint(0, 2, size=dims).long(), torch.randint(0, 2, size=dims).long(), batch)
        for dims, batch in shape_and_batch
    ]

for _ in range(5):
# check multiple random inputs as random exact occurrences are rare
test_cases = get_test_cases()
for y_pred, y, n_iters in test_cases:
_test(y_pred, y, n_iters)
y_pred, y, batch_size = test_data_binary
acc.reset()
if batch_size > 1:
n_iters = y.shape[0] // batch_size + 1
for i in range(n_iters):
idx = i * batch_size
acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
else:
acc.update((y_pred, y))

np_y = y.numpy().ravel()
np_y_pred = y_pred.numpy().ravel()

assert acc._type == "binary"
assert isinstance(acc.compute(), float)
assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())


def test_multiclass_wrong_inputs():
Expand All @@ -131,53 +124,48 @@ def test_multiclass_wrong_inputs():
acc.update((torch.rand(10), torch.randint(0, 5, size=(10, 5, 6)).long()))


def test_multiclass_input():
@pytest.fixture(params=[item for item in range(11)])
nmcguire101 marked this conversation as resolved.
Show resolved Hide resolved
def test_data_multiclass(request):
return [
# Multiclass input data of shape (N, ) and (N, C)
(torch.rand(10, 4), torch.randint(0, 4, size=(10,)).long(), 1),
(torch.rand(10, 10, 1), torch.randint(0, 18, size=(10, 1)).long(), 1),
(torch.rand(10, 18), torch.randint(0, 18, size=(10,)).long(), 1),
(torch.rand(4, 10), torch.randint(0, 10, size=(4,)).long(), 1),
# 2-classes
(torch.rand(4, 2), torch.randint(0, 2, size=(4,)).long(), 1),
(torch.rand(100, 5), torch.randint(0, 5, size=(100,)).long(), 16),
# Multiclass input data of shape (N, L) and (N, C, L)
(torch.rand(10, 4, 5), torch.randint(0, 4, size=(10, 5)).long(), 1),
(torch.rand(4, 10, 5), torch.randint(0, 10, size=(4, 5)).long(), 1),
(torch.rand(100, 9, 7), torch.randint(0, 9, size=(100, 7)).long(), 16),
# Multiclass input data of shape (N, H, W, ...) and (N, C, H, W, ...)
(torch.rand(4, 5, 12, 10), torch.randint(0, 5, size=(4, 12, 10)).long(), 1),
(torch.rand(100, 3, 8, 8), torch.randint(0, 3, size=(100, 8, 8)).long(), 16),
][request.param]


@pytest.mark.parametrize("n_times", range(5))
def test_multiclass_input(n_times, test_data_multiclass):
acc = Accuracy()

def _test(y_pred, y, batch_size):
    """Update ``acc`` with ``(y_pred, y)`` and check it against ``accuracy_score``.

    ``acc`` comes from the enclosing scope. With ``batch_size`` greater
    than 1 the data is fed in consecutive slices of that size; otherwise it
    is fed in a single update. The reference prediction is the argmax of
    ``y_pred`` over axis 1. ``accuracy_score`` is presumably scikit-learn's
    (its import is not visible here).
    """
    acc.reset()
    if batch_size <= 1:
        acc.update((y_pred, y))
    else:
        # Batched updates: same number of slices as the original
        # ``shape[0] // batch_size + 1``.
        n_chunks = y.shape[0] // batch_size + 1
        for start in range(0, n_chunks * batch_size, batch_size):
            stop = start + batch_size
            acc.update((y_pred[start:stop], y[start:stop]))

    # Reduce the class axis to predicted labels, then flatten both sides.
    predicted_labels = y_pred.numpy().argmax(axis=1).ravel()
    true_labels = y.numpy().ravel()

    assert acc._type == "multiclass"
    assert isinstance(acc.compute(), float)
    assert accuracy_score(true_labels, predicted_labels) == pytest.approx(acc.compute())

def get_test_cases():
    """Build a fresh list of random multiclass-accuracy cases.

    Each entry is ``(y_pred, y, batch_size)``: ``y_pred`` is a float tensor
    from ``torch.rand`` and ``y`` is an int64 label tensor, with the shapes
    and label bounds listed in the table below.
    """
    # (y_pred shape, exclusive upper bound for labels, y shape, batch_size)
    specs = (
        # Multiclass input data of shape (N, ) and (N, C)
        ((10, 4), 4, (10,), 1),
        ((10, 10, 1), 18, (10, 1), 1),
        ((10, 18), 18, (10,), 1),
        ((4, 10), 10, (4,), 1),
        # 2-classes
        ((4, 2), 2, (4,), 1),
        ((100, 5), 5, (100,), 16),
        # Multiclass input data of shape (N, L) and (N, C, L)
        ((10, 4, 5), 4, (10, 5), 1),
        ((4, 10, 5), 10, (4, 5), 1),
        ((100, 9, 7), 9, (100, 7), 16),
        # Multiclass input data of shape (N, H, W, ...) and (N, C, H, W, ...)
        ((4, 5, 12, 10), 5, (4, 12, 10), 1),
        ((100, 3, 8, 8), 3, (100, 8, 8), 16),
    )
    # y_pred is drawn before y for every case (left-to-right evaluation),
    # matching the order of a hand-written list of literals.
    return [
        (torch.rand(*pred_dims), torch.randint(0, bound, size=label_dims).long(), batch)
        for pred_dims, bound, label_dims, batch in specs
    ]

for _ in range(5):
# check multiple random inputs as random exact occurrences are rare
test_cases = get_test_cases()
for y_pred, y, batch_size in test_cases:
_test(y_pred, y, batch_size)
y_pred, y, batch_size = test_data_multiclass
acc.reset()
if batch_size > 1:
# Batched Updates
n_iters = y.shape[0] // batch_size + 1
for i in range(n_iters):
idx = i * batch_size
acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
else:
acc.update((y_pred, y))

np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
np_y = y.numpy().ravel()

assert acc._type == "multiclass"
assert isinstance(acc.compute(), float)
assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())


def to_numpy_multilabel(y):
Expand Down Expand Up @@ -208,55 +196,49 @@ def test_multilabel_wrong_inputs():
acc.update((torch.randint(0, 2, size=(10, 1)), torch.randint(0, 2, size=(10, 1)).long()))


def test_multilabel_input():
@pytest.fixture(params=[item for item in range(12)])
nmcguire101 marked this conversation as resolved.
Show resolved Hide resolved
def test_data_multilabel(request):
return [
# Multilabel input data of shape (N, C) and (N, C)
(torch.randint(0, 2, size=(10, 4)).long(), torch.randint(0, 2, size=(10, 4)).long(), 1),
(torch.randint(0, 2, size=(10, 7)).long(), torch.randint(0, 2, size=(10, 7)).long(), 1),
# updated batches
(torch.randint(0, 2, size=(50, 4)).long(), torch.randint(0, 2, size=(50, 4)).long(), 16),
(torch.randint(0, 2, size=(50, 7)).long(), torch.randint(0, 2, size=(50, 7)).long(), 16),
# Multilabel input data of shape (N, H, W)
(torch.randint(0, 2, size=(10, 5, 10)).long(), torch.randint(0, 2, size=(10, 5, 10)).long(), 1),
(torch.randint(0, 2, size=(10, 4, 10)).long(), torch.randint(0, 2, size=(10, 4, 10)).long(), 1),
# updated batches
(torch.randint(0, 2, size=(50, 5, 10)).long(), torch.randint(0, 2, size=(50, 5, 10)).long(), 16),
(torch.randint(0, 2, size=(50, 4, 10)).long(), torch.randint(0, 2, size=(50, 4, 10)).long(), 16),
# Multilabel input data of shape (N, C, H, W, ...) and (N, C, H, W, ...)
(torch.randint(0, 2, size=(4, 5, 12, 10)).long(), torch.randint(0, 2, size=(4, 5, 12, 10)).long(), 1),
(torch.randint(0, 2, size=(4, 10, 12, 8)).long(), torch.randint(0, 2, size=(4, 10, 12, 8)).long(), 1),
# updated batches
(torch.randint(0, 2, size=(50, 5, 12, 10)).long(), torch.randint(0, 2, size=(50, 5, 12, 10)).long(), 16),
(torch.randint(0, 2, size=(50, 10, 12, 8)).long(), torch.randint(0, 2, size=(50, 10, 12, 8)).long(), 16),
][request.param]


@pytest.mark.parametrize("n_times", range(5))
def test_multilabel_input(n_times, test_data_multilabel):
acc = Accuracy(is_multilabel=True)

def _test(y_pred, y, batch_size):
acc.reset()
if batch_size > 1:
n_iters = y.shape[0] // batch_size + 1
for i in range(n_iters):
idx = i * batch_size
acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
else:
acc.update((y_pred, y))
y_pred, y, batch_size = test_data_multilabel
if batch_size > 1:
n_iters = y.shape[0] // batch_size + 1
for i in range(n_iters):
idx = i * batch_size
acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
else:
acc.update((y_pred, y))

np_y_pred = to_numpy_multilabel(y_pred)
np_y = to_numpy_multilabel(y)
np_y_pred = to_numpy_multilabel(y_pred)
np_y = to_numpy_multilabel(y)

assert acc._type == "multilabel"
assert isinstance(acc.compute(), float)
assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())

def get_test_cases():
    """Build a fresh list of random multilabel-accuracy cases.

    Each entry is ``(y_pred, y, batch_size)``; both tensors are int64, hold
    0/1 values and share the shape listed in the table below.
    """
    # (shape shared by y_pred and y, batch_size)
    shape_and_batch = (
        # Multilabel input data of shape (N, C) and (N, C)
        ((10, 4), 1),
        ((10, 7), 1),
        # updated batches
        ((50, 4), 16),
        ((50, 7), 16),
        # Multilabel input data of shape (N, H, W)
        ((10, 5, 10), 1),
        ((10, 4, 10), 1),
        # updated batches
        ((50, 5, 10), 16),
        ((50, 4, 10), 16),
        # Multilabel input data of shape (N, C, H, W, ...) and (N, C, H, W, ...)
        ((4, 5, 12, 10), 1),
        ((4, 10, 12, 8), 1),
        # updated batches
        ((50, 5, 12, 10), 16),
        ((50, 10, 12, 8), 16),
    )
    # Tuple elements are evaluated left to right, so y_pred is drawn before y
    # for every case, exactly as in a hand-written list of literals.
    return [
        (torch.randint(0, 2, size=dims).long(), torch.randint(0, 2, size=dims).long(), batch)
        for dims, batch in shape_and_batch
    ]

for _ in range(5):
# check multiple random inputs as random exact occurrences are rare
test_cases = get_test_cases()
for y_pred, y, batch_size in test_cases:
_test(y_pred, y, batch_size)
assert acc._type == "multilabel"
assert isinstance(acc.compute(), float)
assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())


def test_incorrect_type():
Expand Down