From 6b36a5056939efa56a857f709857c3b364c06ccb Mon Sep 17 00:00:00 2001 From: Anthony Mahanna Date: Wed, 4 Oct 2023 23:06:53 -0400 Subject: [PATCH] code cleanup --- adbdgl_adapter/abc.py | 22 +- adbdgl_adapter/adapter.py | 464 ++++++++++++++++++++++++-------------- 2 files changed, 292 insertions(+), 194 deletions(-) diff --git a/adbdgl_adapter/abc.py b/adbdgl_adapter/abc.py index 4d9139f..9f2b4e3 100644 --- a/adbdgl_adapter/abc.py +++ b/adbdgl_adapter/abc.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- from abc import ABC -from typing import Any, List, Set, Union +from typing import Any, Set, Union from arango.graph import Graph as ArangoDBGraph from dgl import DGLGraph, DGLHeteroGraph @@ -38,26 +38,6 @@ def dgl_to_arangodb( ) -> ArangoDBGraph: raise NotImplementedError # pragma: no cover - def etypes_to_edefinitions(self, edge_types: List[DGLCanonicalEType]) -> List[Json]: - raise NotImplementedError # pragma: no cover - - def ntypes_to_ocollections( - self, node_types: List[str], edge_types: List[DGLCanonicalEType] - ) -> List[str]: - raise NotImplementedError # pragma: no cover - - def __fetch_adb_docs(self) -> None: - raise NotImplementedError # pragma: no cover - - def __insert_adb_docs(self) -> None: - raise NotImplementedError # pragma: no cover - - def __build_tensor_from_dataframe(self) -> None: - raise NotImplementedError # pragma: no cover - - def __build_dataframe_from_tensor(self) -> None: - raise NotImplementedError # pragma: no cover - class Abstract_ADBDGL_Controller(ABC): def _prepare_dgl_node(self, dgl_node: Json, node_type: str) -> Json: diff --git a/adbdgl_adapter/adapter.py b/adbdgl_adapter/adapter.py index db39072..b222c89 100644 --- a/adbdgl_adapter/adapter.py +++ b/adbdgl_adapter/adapter.py @@ -3,7 +3,7 @@ import logging from collections import defaultdict from math import ceil -from typing import Any, DefaultDict, Dict, List, Optional, Set, Union +from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Union from arango.cursor import Cursor from arango.database import Database @@ -227,11 +227,6 @@ def udf_v1_x(v1_df): validate_adb_metagraph(metagraph) - is_homogeneous = ( - len(metagraph["vertexCollections"]) == 1 - and len(metagraph["edgeCollections"]) == 1 - ) - # Maps ArangoDB Vertex _keys to DGL Node ids adb_map: ADBMap = defaultdict(dict) @@ -246,6 +241,12 @@ def udf_v1_x(v1_df): # The edge data view for storing edge features edata: DGLData = defaultdict(lambda: defaultdict(Tensor)) + v_cols: List[str] = list(metagraph["vertexCollections"].keys()) + + ###################### + # Vertex Collections # + ###################### + for v_col, meta in metagraph["vertexCollections"].items(): logger.debug(f"Preparing '{v_col}' vertices") @@ -255,10 +256,12 @@ def udf_v1_x(v1_df): cursor_batch = len(cursor.batch()) # type: ignore df = DataFrame([cursor.pop() for _ in range(cursor_batch)]) + # 1. Map each ArangoDB _key to a DGL node id for adb_id in df["_key"]: adb_map[v_col][adb_id] = dgl_id dgl_id += 1 + # 2. Set the DGL Node Data self.__set_dgl_data(v_col, meta, ndata, df) if cursor.has_more(): @@ -266,9 +269,14 @@ def udf_v1_x(v1_df): df.drop(df.index, inplace=True) + #################### + # Edge Collections # + #################### + + # et = Edge Type et_df: DataFrame et_blacklist: List[DGLCanonicalEType] = [] # A list of skipped edge types - v_cols: List[str] = list(metagraph["vertexCollections"].keys()) + for e_col, meta in metagraph["edgeCollections"].items(): logger.debug(f"Preparing '{e_col}' edges") @@ -277,26 +285,32 @@ def udf_v1_x(v1_df): cursor_batch = len(cursor.batch()) # type: ignore df = DataFrame([cursor.pop() for _ in range(cursor_batch)]) + # 1. Split the ArangoDB _from & _to IDs into two columns df[["from_col", "from_key"]] = self.__split_adb_ids(df["_from"]) df[["to_col", "to_key"]] = self.__split_adb_ids(df["_to"]) + # 2. Iterate over each edge type for (from_col, to_col), count in ( df[["from_col", "to_col"]].value_counts().items() ): edge_type: DGLCanonicalEType = (from_col, e_col, to_col) + + # 3. Check for partial Edge Collection import if from_col not in v_cols or to_col not in v_cols: logger.debug(f"Skipping {edge_type}") et_blacklist.append(edge_type) - continue # partial edge collection import to dgl + continue logger.debug(f"Preparing {count} '{edge_type}' edges") - # Get the edge data corresponding to the current edge type + # 4. Get the edge data corresponding to the current edge type et_df = df[(df["from_col"] == from_col) & (df["to_col"] == to_col)] + # 5. Map each ArangoDB from/to _key to the corresponding DGL node id from_nodes = et_df["from_key"].map(adb_map[from_col]).tolist() to_nodes = et_df["to_key"].map(adb_map[to_col]).tolist() + # 6. Set/Update the DGL Edge Index if edge_type not in data_dict: data_dict[edge_type] = (tensor(from_nodes), tensor(to_nodes)) else: @@ -306,6 +320,7 @@ def udf_v1_x(v1_df): cat((previous_to_nodes, tensor(to_nodes))), ) + # 7. Set the DGL Edge Data self.__set_dgl_data(edge_type, meta, edata, df) if cursor.has_more(): @@ -321,20 +336,9 @@ def udf_v1_x(v1_df): """ raise ValueError(msg) - dgl_g: Union[DGLGraph, DGLHeteroGraph] - if is_homogeneous: - num_nodes = len(adb_map[v_col]) - data = list(data_dict.values())[0] - dgl_g = graph(data, num_nodes=num_nodes) - else: - num_nodes_dict = {v_col: len(adb_map[v_col]) for v_col in adb_map} - dgl_g = heterograph(data_dict, num_nodes_dict) - - has_one_ntype = len(dgl_g.ntypes) == 1 - has_one_etype = len(dgl_g.canonical_etypes) == 1 - - self.__copy_dgl_data(dgl_g.ndata, ndata, has_one_ntype) - self.__copy_dgl_data(dgl_g.edata, edata, has_one_etype) + dgl_g = self.__create_dgl_graph(data_dict, adb_map, metagraph) + self.__copy_dgl_data(dgl_g.ndata, ndata, len(dgl_g.ntypes) == 1) + self.__copy_dgl_data(dgl_g.edata, edata, len(dgl_g.canonical_etypes) == 1) logger.info(f"Created DGL '{name}' Graph") return dgl_g @@ -392,6 +396,156 @@ def arangodb_graph_to_dgl( return self.arangodb_collections_to_dgl(name, v_cols, e_cols, **query_options) + def __fetch_adb_docs( + self, + col: str, + meta: Union[Set[str], Dict[str, ADBMetagraphValues]], + query_options: Any, + ) -> Cursor: + """Fetches ArangoDB documents within a collection. Returns the + documents in a DataFrame. + + :param col: The ArangoDB collection. + :type col: str + :param meta: The MetaGraph associated to **col** + :type meta: Set[str] | Dict[str, adbdgl_adapter.typings.ADBMetagraphValues] + :param query_options: Keyword arguments to specify AQL query options + when fetching documents from the ArangoDB instance. + :type query_options: Any + :return: A DataFrame representing the ArangoDB documents. + :rtype: pandas.DataFrame + """ + + def get_aql_return_value( + meta: Union[Set[str], Dict[str, ADBMetagraphValues]] + ) -> str: + """Helper method to formulate the AQL `RETURN` value based on + the document attributes specified in **meta** + """ + attributes = [] + + if type(meta) is set: + attributes = list(meta) + + elif type(meta) is dict: + for value in meta.values(): + if type(value) is str: + attributes.append(value) + elif type(value) is dict: + attributes.extend(list(value.keys())) + elif callable(value): + # Cannot determine which attributes to extract if UDFs are used + # Therefore we just return the entire document + return "doc" + + return f""" + MERGE( + {{ _key: doc._key, _from: doc._from, _to: doc._to }}, + KEEP(doc, {list(attributes)}) + ) + """ + + with progress( + f"(ADB → DGL): {col}", + text_style="#319BF5", + spinner_style="#FCFDFC", + ) as p: + p.add_task("__fetch_adb_docs") + return self.__db.aql.execute( # type: ignore + f"FOR doc IN @@col RETURN {get_aql_return_value(meta)}", + bind_vars={"@col": col}, + **{**{"stream": True}, **query_options}, + ) + + def __set_dgl_data( + self, + data_type: DGLDataTypes, + meta: Union[Set[str], Dict[str, ADBMetagraphValues]], + dgl_data: DGLData, + df: DataFrame, + ) -> None: + """A helper method to build the DGL NodeSpace or EdgeSpace object + for the DGL graph. Is responsible for preparing the input **meta** such + that it becomes a dictionary, and building DGL-ready tensors from the + ArangoDB DataFrame **df**. + + :param data_type: The current node or edge type of the soon-to-be DGL graph. + :type data_type: str | tuple[str, str, str] + :param meta: The metagraph associated to the current ArangoDB vertex or + edge collection. e.g metagraph['vertexCollections']['Users'] + :type meta: Set[str] | Dict[str, adbdgl_adapter.typings.ADBMetagraphValues] + :param dgl_data: The (currently empty) DefaultDict object storing the node or + edge features of the soon-to-be DGL graph. + :type dgl_data: adbdgl_adapter.typings.DGLData + :param df: The DataFrame representing the ArangoDB collection data + :type df: pandas.DataFrame + """ + valid_meta: Dict[str, ADBMetagraphValues] + valid_meta = meta if type(meta) is dict else {m: m for m in meta} + + for k, v in valid_meta.items(): + t = self.__build_tensor_from_dataframe(df, k, v) + dgl_data[k][data_type] = cat((dgl_data[k][data_type], t)) + + def __split_adb_ids(self, s: Series) -> Series: + """Helper method to split the ArangoDB IDs within a Series into two columns""" + return s.str.split(pat="/", n=1, expand=True) + + def __create_dgl_graph( + self, data_dict: DGLDataDict, adb_map: ADBMap, metagraph: ADBMetagraph + ) -> Union[DGLGraph, DGLHeteroGraph]: + """Creates a DGL graph from the given DGL data. + + :param data_dict: The data for constructing a graph, + which takes the form of (U, V). + (U[i], V[i]) forms the edge with ID i in the graph. + :type data_dict: adbdgl_adapter.typings.DGLDataDict + :param adb_map: A mapping of ArangoDB IDs to DGL IDs. + :type adb_map: adbdgl_adapter.typings.ADBMap + :param metagraph: The ArangoDB metagraph. + :type metagraph: adbdgl_adapter.typings.ADBMetagraph + :return: A DGL Homogeneous or Heterogeneous graph object + :rtype: dgl.DGLGraph | dgl.DGLHeteroGraph + """ + is_homogeneous = ( + len(metagraph["vertexCollections"]) == 1 + and len(metagraph["edgeCollections"]) == 1 + ) + + if is_homogeneous: + v_col = next(iter(metagraph["vertexCollections"])) + data = next(iter(data_dict.values())) + + return graph(data, num_nodes=len(adb_map[v_col])) + + num_nodes_dict = {v_col: len(adb_map[v_col]) for v_col in adb_map} + return heterograph(data_dict, num_nodes_dict) + + def __copy_dgl_data( + self, + dgl_data: Union[HeteroNodeDataView, HeteroEdgeDataView], + dgl_data_temp: DGLData, + has_one_type: bool, + ) -> None: + """Copies **dgl_data_temp** into **dgl_data**. This method is (unfortunately) + required, since a dgl graph's `ndata` and `edata` properties can't be + manually set (i.e `g.ndata = ndata` is not possible). + + :param dgl_data: The (empty) ndata or edata instance attribute of a dgl graph, + which is about to receive **dgl_data_temp**. + :type dgl_data: Union[dgl.view.HeteroNodeDataView, dgl.view.HeteroEdgeDataView] + :param dgl_data_temp: A temporary place to store the ndata or edata features. + :type dgl_data_temp: adbdgl_adapter.typings.DGLData + :param has_one_type: Set to True if the DGL graph only has one + node type or edge type. + :type has_one_type: bool + """ + for feature_name, feature_map in dgl_data_temp.items(): + for data_type, dgl_tensor in feature_map.items(): + dgl_data[feature_name] = ( + dgl_tensor if has_one_type else {data_type: dgl_tensor} + ) + def dgl_to_arangodb( self, name: str, @@ -491,40 +645,22 @@ def y_tensor_to_2_column_dataframe(dgl_tensor): logger.debug(f"--dgl_to_arangodb('{name}')--") validate_dgl_metagraph(metagraph) - is_custom_controller = type(self.__cntrl) is not ADBDGL_Controller + is_custom_controller = type(self.__cntrl) is not ADBDGL_Controller has_one_ntype = len(dgl_g.ntypes) == 1 has_one_etype = len(dgl_g.canonical_etypes) == 1 - has_default_canonical_etypes = dgl_g.canonical_etypes == [("_N", "_E", "_N")] - node_types: List[str] - edge_types: List[DGLCanonicalEType] - explicit_metagraph = metagraph != {} and explicit_metagraph - if explicit_metagraph: - node_types = metagraph.get("nodeTypes", {}).keys() # type: ignore - edge_types = metagraph.get("edgeTypes", {}).keys() # type: ignore - - elif has_default_canonical_etypes: - n_type = name + "_N" - node_types = [n_type] - edge_types = [(n_type, name + "_E", n_type)] - - else: - node_types = dgl_g.ntypes - edge_types = dgl_g.canonical_etypes + node_types, edge_types = self.__get_node_and_edge_types( + name, dgl_g, metagraph, explicit_metagraph + ) - if overwrite_graph: - logger.debug("Overwrite graph flag is True. Deleting old graph.") - self.__db.delete_graph(name, ignore_missing=True) + adb_graph = self.__create_adb_graph( + name, overwrite_graph, node_types, edge_types + ) - if self.__db.has_graph(name): - adb_graph = self.__db.graph(name) - else: - edge_definitions = self.etypes_to_edefinitions(edge_types) - orphan_collections = self.ntypes_to_ocollections(node_types, edge_types) - adb_graph = self.__db.create_graph( - name, edge_definitions, orphan_collections - ) # type: ignore + ############## + # Node Types # + ############## n_meta = metagraph.get("nodeTypes", {}) for n_type in node_types: @@ -539,8 +675,12 @@ def y_tensor_to_2_column_dataframe(dgl_tensor): end_index = min(ndata_batch_size, ndata_size) batches = ceil(ndata_size / ndata_batch_size) + # For each batch of nodes for _ in range(batches): + # 1. Map each DGL node id to an ArangoDB _key adb_keys = [{"_key": str(i)} for i in range(start_index, end_index)] + + # 2. Set the ArangoDB Node Data df = self.__set_adb_data( DataFrame(adb_keys, index=range(start_index, end_index)), meta, @@ -551,15 +691,22 @@ def y_tensor_to_2_column_dataframe(dgl_tensor): explicit_metagraph, ) + # 3. Apply the ArangoDB Node Controller (if provided) if is_custom_controller: f = lambda n: self.__cntrl._prepare_dgl_node(n, n_type) df = df.apply(f, axis=1) + # 4. Insert the ArangoDB Node Documents self.__insert_adb_docs(n_type, df, import_options) + # 5. Update the batch indices start_index = end_index end_index = min(end_index + ndata_batch_size, ndata_size) + ############## + # Edge Types # + ############## + e_meta = metagraph.get("edgeTypes", {}) for e_type in edge_types: meta = e_meta.get(e_type, {}) @@ -576,7 +723,9 @@ def y_tensor_to_2_column_dataframe(dgl_tensor): from_nodes, to_nodes = dgl_g.edges(etype=e_key) + # For each batch of edges for _ in range(batches): + # 1. Map the DGL edges to ArangoDB _from & _to IDs data = zip( *( from_nodes[start_index:end_index].tolist(), @@ -584,6 +733,7 @@ def y_tensor_to_2_column_dataframe(dgl_tensor): ) ) + # 2. Set the ArangoDB Edge Data df = self.__set_adb_data( DataFrame( data, @@ -601,19 +751,107 @@ def y_tensor_to_2_column_dataframe(dgl_tensor): df["_from"] = from_col + "/" + df["_from"].astype(str) df["_to"] = to_col + "/" + df["_to"].astype(str) + # 3. Apply the ArangoDB Edge Controller (if provided) if is_custom_controller: f = lambda e: self.__cntrl._prepare_dgl_edge(e, e_type) df = df.apply(f, axis=1) + # 4. Insert the ArangoDB Edge Documents self.__insert_adb_docs(e_type, df, import_options) + # 5. Update the batch indices start_index = end_index end_index = min(end_index + edata_batch_size, edata_size) logger.info(f"Created ArangoDB '{name}' Graph") return adb_graph - def etypes_to_edefinitions(self, edge_types: List[DGLCanonicalEType]) -> List[Json]: + def __get_node_and_edge_types( + self, + name: str, + dgl_g: DGLGraph, + metagraph: DGLMetagraph, + explicit_metagraph: bool, + ) -> Tuple[List[str], List[DGLCanonicalEType]]: + """Returns the node & edge types of the DGL graph, based on the + metagraph and whether the graph has default canonical etypes. + + :param name: The DGL graph name. + :type name: str + :param dgl_g: The existing DGL graph. + :type dgl_g: dgl.DGLGraph + :param metagraph: The DGL Metagraph. + :type metagraph: adbdgl_adapter.typings.DGLMetagraph + :param explicit_metagraph: Whether to take the metagraph at face value or not. + If False, node & edge types OMITTED from the metagraph will be + brought over into ArangoDB. Also applies to node & edge attributes. + Defaults to True. + :type explicit_metagraph: bool + :return: The node & edge types of the DGL graph. + :rtype: Tuple[List[str], List[adbdgl_adapter.typings.DGLCanonicalEType]] + """ + node_types: List[str] + edge_types: List[DGLCanonicalEType] + + explicit_metagraph = metagraph != {} and explicit_metagraph + has_default_canonical_etypes = dgl_g.canonical_etypes == [("_N", "_E", "_N")] + + if explicit_metagraph: + node_types = metagraph.get("nodeTypes", {}).keys() # type: ignore + edge_types = metagraph.get("edgeTypes", {}).keys() # type: ignore + + elif has_default_canonical_etypes: + n_type = name + "_N" + node_types = [n_type] + edge_types = [(n_type, name + "_E", n_type)] + + else: + node_types = dgl_g.ntypes + edge_types = dgl_g.canonical_etypes + + return node_types, edge_types + + def __create_adb_graph( + self, + name: str, + overwrite_graph: bool, + node_types: List[str], + edge_types: List[DGLCanonicalEType], + ) -> ADBGraph: + """Creates an ArangoDB graph. + + :param name: The ArangoDB graph name. + :type name: str + :param overwrite_graph: Overwrites the graph if it already exists. + Does not drop associated collections. Defaults to False. + :type overwrite_graph: bool + :param node_types: A list of strings representing the DGL node types. + :type node_types: List[str] + :param edge_types: A list of string triplets (str, str, str) for + source node type, edge type and destination node type. + :type edge_types: List[adbdgl_adapter.typings.DGLCanonicalEType] + :return: The ArangoDB Graph API wrapper. + :rtype: arango.graph.Graph + """ + if overwrite_graph: + logger.debug("Overwrite graph flag is True. Deleting old graph.") + self.__db.delete_graph(name, ignore_missing=True) + + if self.__db.has_graph(name): + return self.__db.graph(name) + + edge_definitions = self.__etypes_to_edefinitions(edge_types) + orphan_collections = self.__ntypes_to_ocollections(node_types, edge_types) + + return self.__db.create_graph( # type: ignore[return-value] + name, + edge_definitions, + orphan_collections, + ) + + def __etypes_to_edefinitions( + self, edge_types: List[DGLCanonicalEType] + ) -> List[Json]: """Converts DGL canonical_etypes to ArangoDB edge_definitions :param edge_types: A list of string triplets (str, str, str) for @@ -657,7 +895,7 @@ def etypes_to_edefinitions(self, edge_types: List[DGLCanonicalEType]) -> List[Js return edge_definitions - def ntypes_to_ocollections( + def __ntypes_to_ocollections( self, node_types: List[str], edge_types: List[DGLCanonicalEType] ) -> List[str]: """Converts DGL node_types to ArangoDB orphan collections, if any. @@ -679,67 +917,6 @@ def ntypes_to_ocollections( orphan_collections = set(node_types) ^ non_orphan_collections return list(orphan_collections) - def __fetch_adb_docs( - self, - col: str, - meta: Union[Set[str], Dict[str, ADBMetagraphValues]], - query_options: Any, - ) -> Cursor: - """Fetches ArangoDB documents within a collection. Returns the - documents in a DataFrame. - - :param col: The ArangoDB collection. - :type col: str - :param meta: The MetaGraph associated to **col** - :type meta: Set[str] | Dict[str, adbdgl_adapter.typings.ADBMetagraphValues] - :param query_options: Keyword arguments to specify AQL query options - when fetching documents from the ArangoDB instance. - :type query_options: Any - :return: A DataFrame representing the ArangoDB documents. - :rtype: pandas.DataFrame - """ - - def get_aql_return_value( - meta: Union[Set[str], Dict[str, ADBMetagraphValues]] - ) -> str: - """Helper method to formulate the AQL `RETURN` value based on - the document attributes specified in **meta** - """ - attributes = [] - - if type(meta) is set: - attributes = list(meta) - - elif type(meta) is dict: - for value in meta.values(): - if type(value) is str: - attributes.append(value) - elif type(value) is dict: - attributes.extend(list(value.keys())) - elif callable(value): - # Cannot determine which attributes to extract if UDFs are used - # Therefore we just return the entire document - return "doc" - - return f""" - MERGE( - {{ _key: doc._key, _from: doc._from, _to: doc._to }}, - KEEP(doc, {list(attributes)}) - ) - """ - - with progress( - f"(ADB → DGL): {col}", - text_style="#319BF5", - spinner_style="#FCFDFC", - ) as p: - p.add_task("__fetch_adb_docs") - return self.__db.aql.execute( # type: ignore - f"FOR doc IN @@col RETURN {get_aql_return_value(meta)}", - bind_vars={"@col": col}, - **{**{"stream": True}, **query_options}, - ) - def __insert_adb_docs( self, doc_type: Union[str, DGLCanonicalEType], df: DataFrame, kwargs: Any ) -> None: @@ -767,65 +944,6 @@ def __insert_adb_docs( logger.debug(result) df.drop(df.index, inplace=True) - def __split_adb_ids(self, s: Series) -> Series: - """Helper method to split the ArangoDB IDs within a Series into two columns""" - return s.str.split(pat="/", n=1, expand=True) - - def __set_dgl_data( - self, - data_type: DGLDataTypes, - meta: Union[Set[str], Dict[str, ADBMetagraphValues]], - dgl_data: DGLData, - df: DataFrame, - ) -> None: - """A helper method to build the DGL NodeSpace or EdgeSpace object - for the DGL graph. Is responsible for preparing the input **meta** such - that it becomes a dictionary, and building DGL-ready tensors from the - ArangoDB DataFrame **df**. - - :param data_type: The current node or edge type of the soon-to-be DGL graph. - :type data_type: str | tuple[str, str, str] - :param meta: The metagraph associated to the current ArangoDB vertex or - edge collection. e.g metagraph['vertexCollections']['Users'] - :type meta: Set[str] | Dict[str, adbdgl_adapter.typings.ADBMetagraphValues] - :param dgl_data: The (currently empty) DefaultDict object storing the node or - edge features of the soon-to-be DGL graph. - :type dgl_data: adbdgl_adapter.typings.DGLData - :param df: The DataFrame representing the ArangoDB collection data - :type df: pandas.DataFrame - """ - valid_meta: Dict[str, ADBMetagraphValues] - valid_meta = meta if type(meta) is dict else {m: m for m in meta} - - for k, v in valid_meta.items(): - t = self.__build_tensor_from_dataframe(df, k, v) - dgl_data[k][data_type] = cat((dgl_data[k][data_type], t)) - - def __copy_dgl_data( - self, - dgl_data: Union[HeteroNodeDataView, HeteroEdgeDataView], - dgl_data_temp: DGLData, - has_one_type: bool, - ) -> None: - """Copies **dgl_data_temp** into **dgl_data**. This method is (unfortunately) - required, since a dgl graph's `ndata` and `edata` properties can't be - manually set (i.e `g.ndata = ndata` is not possible). - - :param dgl_data: The (empty) ndata or edata instance attribute of a dgl graph, - which is about to receive **dgl_data_temp**. - :type dgl_data: Union[dgl.view.HeteroNodeDataView, dgl.view.HeteroEdgeDataView] - :param dgl_data_temp: A temporary place to store the ndata or edata features. - :type dgl_data_temp: adbdgl_adapter.typings.DGLData - :param has_one_type: Set to True if the DGL graph only has one - node type or edge type. - :type has_one_type: bool - """ - for feature_name, feature_map in dgl_data_temp.items(): - for data_type, dgl_tensor in feature_map.items(): - dgl_data[feature_name] = ( - dgl_tensor if has_one_type else {data_type: dgl_tensor} - ) - def __set_adb_data( self, df: DataFrame,