diff --git a/tests/unittests/segmentation/test_hausdorff_distance.py b/tests/unittests/segmentation/test_hausdorff_distance.py index 34b720b5fec..db37a71e6ea 100644 --- a/tests/unittests/segmentation/test_hausdorff_distance.py +++ b/tests/unittests/segmentation/test_hausdorff_distance.py @@ -48,8 +48,9 @@ # Wrapper that converts to numpy to avoid Torch-to-numpy functional issues -def torch_skimage_hausdorff_distance(p: torch.Tensor, t: torch.Tensor) -> float: - return skimage_hausdorff_distance(p.numpy(), t.numpy()) +def torch_skimage_hausdorff_distance(p: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + out = skimage_hausdorff_distance(p.numpy(), t.numpy()) + return torch.tensor([out]) @pytest.mark.parametrize(