From 8541d97fd15f079c56f2cef30b99519c32f4de06 Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Tue, 24 Sep 2024 12:58:59 +0200 Subject: [PATCH] Output a tensor for reference metric --- tests/unittests/segmentation/test_hausdorff_distance.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unittests/segmentation/test_hausdorff_distance.py b/tests/unittests/segmentation/test_hausdorff_distance.py index 34b720b5fec..db37a71e6ea 100644 --- a/tests/unittests/segmentation/test_hausdorff_distance.py +++ b/tests/unittests/segmentation/test_hausdorff_distance.py @@ -48,8 +48,9 @@ # Wrapper that converts to numpy to avoid Torch-to-numpy functional issues -def torch_skimage_hausdorff_distance(p: torch.Tensor, t: torch.Tensor) -> float: - return skimage_hausdorff_distance(p.numpy(), t.numpy()) +def torch_skimage_hausdorff_distance(p: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + out = skimage_hausdorff_distance(p.numpy(), t.numpy()) + return torch.tensor([out]) @pytest.mark.parametrize(