From 6439224de0b9de51c7f39cb7403bf6948974ec15 Mon Sep 17 00:00:00 2001 From: RunDevelopment Date: Wed, 27 Mar 2024 18:43:52 +0100 Subject: [PATCH] Improve util error messages and fixed snapshot updates in certain cases --- tests/util.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/tests/util.py b/tests/util.py index 1b682796..16980198 100644 --- a/tests/util.py +++ b/tests/util.py @@ -312,6 +312,11 @@ def assert_image_inference( assert expected_path.exists(), f"Expected {expected_path} to exist." expected = read_image(expected_path) + if expected.shape != output.shape and update_mode: + # update the snapshot + write_image(expected_path, output) + continue + # Assert that the images are the same within a certain tolerance # The CI for some reason has a bit of FP precision loss compared to my local machine # Therefore, a tolerance of 1 is fine enough. @@ -474,9 +479,15 @@ def test_size(width: int, height: int) -> None: with torch.no_grad(): output_tensor = model(input_tensor.to(device)) - assert output_tensor.shape[1] == model.output_channels, "Incorrect channels" - assert output_tensor.shape[2] == height * model.scale, "Incorrect height" - assert output_tensor.shape[3] == width * model.scale, "Incorrect width" + expected_shape = ( + 1, + model.output_channels, + height * model.scale, + width * model.scale, + ) + assert ( + output_tensor.shape == expected_shape + ), f"Expected {expected_shape}, but got {output_tensor.shape}" except Exception as e: raise AssertionError( f"Failed size requirement test for {width=} {height=}"