Skip to content

Commit

Permalink
test: new test for KNNAreaMaskBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed Aug 19, 2024
1 parent 6acfc32 commit 36aa407
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/anemoi/graphs/generate/masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class KNNAreaMaskBuilder:
"""

def __init__(self, reference_node_name: str, margin_radius_km: float = 100, mask_attr_name: str = None):
assert isinstance(margin_radius_km, (int, float)), "The margin radius must be a number."
assert margin_radius_km > 0, "The margin radius must be positive."

self.nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4)
self.margin_radius_km = margin_radius_km
Expand All @@ -40,9 +42,16 @@ def __init__(self, reference_node_name: str, margin_radius_km: float = 100, mask

def fit(self, graph: HeteroData):
"""Fit the KNN model to the nodes of interest."""
assert (
self.reference_node_name in graph.node_types
), f'Reference node "{self.reference_node_name}" not found in the graph.'
reference_mask_str = self.reference_node_name

coords_rad = graph[self.reference_node_name].x.numpy()
if self.mask_attr_name is not None:
assert (
self.mask_attr_name in graph[self.reference_node_name].node_attrs()
), f'Mask attribute "{self.mask_attr_name}" not found in the reference nodes.'
mask = graph[self.reference_node_name][self.mask_attr_name].squeeze()
coords_rad = coords_rad[mask]
reference_mask_str += f" ({self.mask_attr_name})"
Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def graph_with_nodes() -> HeteroData:
coords = np.array([[lat, lon] for lat in lats for lon in lons])
graph = HeteroData()
graph["test_nodes"].x = 2 * torch.pi * torch.tensor(coords)
graph["test_nodes"].mask = torch.tensor([True] * len(coords))
return graph


Expand All @@ -59,6 +60,7 @@ def graph_nodes_and_edges() -> HeteroData:
coords = np.array([[lat, lon] for lat in lats for lon in lons])
graph = HeteroData()
graph["test_nodes"].x = 2 * torch.pi * torch.tensor(coords)
graph["test_nodes"].mask = torch.tensor([True] * len(coords))
graph[("test_nodes", "to", "test_nodes")].edge_index = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]])
return graph

Expand Down
48 changes: 48 additions & 0 deletions tests/generate/test_masks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pytest
from sklearn.neighbors import NearestNeighbors
from torch_geometric.data import HeteroData

from anemoi.graphs.generate.masks import KNNAreaMaskBuilder


def test_init():
"""Test KNNAreaMaskBuilder initialization."""
mask_builder1 = KNNAreaMaskBuilder("nodes")
mask_builder2 = KNNAreaMaskBuilder("nodes", margin_radius_km=120)
mask_builder3 = KNNAreaMaskBuilder("nodes", mask_attr_name="mask")
mask_builder4 = KNNAreaMaskBuilder("nodes", margin_radius_km=120, mask_attr_name="mask")

assert isinstance(mask_builder1, KNNAreaMaskBuilder)
assert isinstance(mask_builder2, KNNAreaMaskBuilder)
assert isinstance(mask_builder3, KNNAreaMaskBuilder)
assert isinstance(mask_builder4, KNNAreaMaskBuilder)

assert isinstance(mask_builder1.nearest_neighbour, NearestNeighbors)
assert isinstance(mask_builder2.nearest_neighbour, NearestNeighbors)
assert isinstance(mask_builder3.nearest_neighbour, NearestNeighbors)
assert isinstance(mask_builder4.nearest_neighbour, NearestNeighbors)


@pytest.mark.parametrize("margin", [-1, "120", None])
def test_fail_init_wrong_margin(margin: int):
"""Test KNNAreaMaskBuilder initialization with invalid margin."""
with pytest.raises(AssertionError):
KNNAreaMaskBuilder("nodes", margin_radius_km=margin)


@pytest.mark.parametrize("mask", [None, "mask"])
def test_fit(graph_with_nodes: HeteroData, mask: str):
"""Test KNNAreaMaskBuilder fit."""
mask_builder = KNNAreaMaskBuilder("test_nodes", mask_attr_name=mask)
assert not hasattr(mask_builder.nearest_neighbour, "n_samples_fit_")

mask_builder.fit(graph_with_nodes)

assert mask_builder.nearest_neighbour.n_samples_fit_ == graph_with_nodes["test_nodes"].num_nodes


def test_fit_fail(graph_with_nodes):
"""Test KNNAreaMaskBuilder fit with wrong graph."""
mask_builder = KNNAreaMaskBuilder("wrong_nodes")
with pytest.raises(AssertionError):
mask_builder.fit(graph_with_nodes)

0 comments on commit 36aa407

Please sign in to comment.