Skip to content

Commit

Permalink
Implemented new attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
icedoom888 committed Dec 17, 2024
1 parent 3898f6f commit 7353f83
Showing 1 changed file with 71 additions and 0 deletions.
71 changes: 71 additions & 0 deletions src/anemoi/graphs/edges/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,74 @@ def post_process(self, values: np.ndarray) -> torch.Tensor:
values = 1 - values

return values


class BooleanBaseEdgeAttribute:
"""Base class for boolean edge attributes."""

def __init__(self) -> None:
pass

@abstractmethod
def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> np.ndarray: ...

def post_process(self, values: np.ndarray) -> torch.Tensor:
"""Post-process the values."""
return torch.tensor(values, dtype=torch.bool)

def compute(self, graph: HeteroData, edges_name: tuple[str, str, str], *args, **kwargs) -> torch.Tensor:
"""Compute the edge attributes."""
source_name, _, target_name = edges_name
assert (
source_name in graph.node_types
), f"Node \"{source_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}."
assert (
target_name in graph.node_types
), f"Node \"{target_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}."

values = self.get_raw_values(graph, source_name, target_name, *args, **kwargs)
return self.post_process(values)


class AttributeFromNode(BooleanBaseEdgeAttribute):
"""
Copy an attribute of either the source or destination node to the edge.
Accesses origin/target node attribute and propagates it to the edge.
Used for example to identify if an encoder edge originates from a LAM or global node.
Attributes
----------
node_attr_name : str
Name of the node attribute to propagate.
node_type : str
Pick the node to copy from. Options: "src, dst"
Methods
-------
get_raw_values(graph, source_name, target_name)
Computes the edge attribute from the source or destination node attribute.
"""

def __init__(self, node_attr_name: str, node_type: str) -> None:
self.node_attr_name = node_attr_name
assert node_type in ["src", "dst"]
self.node_type = node_type

def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) -> np.ndarray:

edge_index = graph[(source_name, "to", target_name)].edge_index

if self.node_type == "src":
name_to_copy = source_name
idx = 0

else:
name_to_copy = target_name
idx = 1

assert hasattr(graph[name_to_copy], self.node_attr_name)

val = getattr(graph[name_to_copy], self.node_attr_name).numpy()[edge_index[idx]]

return val

0 comments on commit 7353f83

Please sign in to comment.