diff --git a/src/anemoi/graphs/generate/masks.py b/src/anemoi/graphs/generate/masks.py index e1a4128..ad75c5e 100644 --- a/src/anemoi/graphs/generate/masks.py +++ b/src/anemoi/graphs/generate/masks.py @@ -32,6 +32,8 @@ class KNNAreaMaskBuilder: """ def __init__(self, reference_node_name: str, margin_radius_km: float = 100, mask_attr_name: str = None): + assert isinstance(margin_radius_km, (int, float)), "The margin radius must be a number." + assert margin_radius_km > 0, "The margin radius must be positive." self.nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4) self.margin_radius_km = margin_radius_km @@ -40,9 +42,16 @@ def __init__(self, reference_node_name: str, margin_radius_km: float = 100, mask def fit(self, graph: HeteroData): """Fit the KNN model to the nodes of interest.""" + assert ( + self.reference_node_name in graph.node_types + ), f'Reference node "{self.reference_node_name}" not found in the graph.' reference_mask_str = self.reference_node_name + coords_rad = graph[self.reference_node_name].x.numpy() if self.mask_attr_name is not None: + assert ( + self.mask_attr_name in graph[self.reference_node_name].node_attrs() + ), f'Mask attribute "{self.mask_attr_name}" not found in the reference nodes.' mask = graph[self.reference_node_name][self.mask_attr_name].squeeze() coords_rad = coords_rad[mask] reference_mask_str += f" ({self.mask_attr_name})" diff --git a/tests/conftest.py b/tests/conftest.py index 2fcc824..23208c9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -50,6 +50,7 @@ def graph_with_nodes() -> HeteroData: coords = np.array([[lat, lon] for lat in lats for lon in lons]) graph = HeteroData() graph["test_nodes"].x = 2 * torch.pi * torch.tensor(coords) + graph["test_nodes"].mask = torch.tensor([True] * len(coords)) return graph @@ -59,6 +60,7 @@ def graph_nodes_and_edges() -> HeteroData: coords = np.array([[lat, lon] for lat in lats for lon in lons]) graph = HeteroData() graph["test_nodes"].x = 2 * torch.pi * torch.tensor(coords) + graph["test_nodes"].mask = torch.tensor([True] * len(coords)) graph[("test_nodes", "to", "test_nodes")].edge_index = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]]) return graph diff --git a/tests/generate/test_masks.py b/tests/generate/test_masks.py new file mode 100644 index 0000000..651bdb7 --- /dev/null +++ b/tests/generate/test_masks.py @@ -0,0 +1,48 @@ +import pytest +from sklearn.neighbors import NearestNeighbors +from torch_geometric.data import HeteroData + +from anemoi.graphs.generate.masks import KNNAreaMaskBuilder + + +def test_init(): + """Test KNNAreaMaskBuilder initialization.""" + mask_builder1 = KNNAreaMaskBuilder("nodes") + mask_builder2 = KNNAreaMaskBuilder("nodes", margin_radius_km=120) + mask_builder3 = KNNAreaMaskBuilder("nodes", mask_attr_name="mask") + mask_builder4 = KNNAreaMaskBuilder("nodes", margin_radius_km=120, mask_attr_name="mask") + + assert isinstance(mask_builder1, KNNAreaMaskBuilder) + assert isinstance(mask_builder2, KNNAreaMaskBuilder) + assert isinstance(mask_builder3, KNNAreaMaskBuilder) + assert isinstance(mask_builder4, KNNAreaMaskBuilder) + + assert isinstance(mask_builder1.nearest_neighbour, NearestNeighbors) + assert isinstance(mask_builder2.nearest_neighbour, NearestNeighbors) + assert isinstance(mask_builder3.nearest_neighbour, NearestNeighbors) + assert isinstance(mask_builder4.nearest_neighbour, NearestNeighbors) + + +@pytest.mark.parametrize("margin", [-1, "120", None]) +def test_fail_init_wrong_margin(margin: int): + """Test KNNAreaMaskBuilder initialization with invalid margin.""" + with pytest.raises(AssertionError): + KNNAreaMaskBuilder("nodes", margin_radius_km=margin) + + +@pytest.mark.parametrize("mask", [None, "mask"]) +def test_fit(graph_with_nodes: HeteroData, mask: str): + """Test KNNAreaMaskBuilder fit.""" + mask_builder = KNNAreaMaskBuilder("test_nodes", mask_attr_name=mask) + assert not hasattr(mask_builder.nearest_neighbour, "n_samples_fit_") + + mask_builder.fit(graph_with_nodes) + + assert mask_builder.nearest_neighbour.n_samples_fit_ == graph_with_nodes["test_nodes"].num_nodes + + +def test_fit_fail(graph_with_nodes): + """Test KNNAreaMaskBuilder fit with wrong graph.""" + mask_builder = KNNAreaMaskBuilder("wrong_nodes") + with pytest.raises(AssertionError): + mask_builder.fit(graph_with_nodes)