diff --git a/benchmarks/benchmarks/api/bench_fused_sample_neighbors.py b/benchmarks/benchmarks/api/bench_fused_sample_neighbors.py new file mode 100644 index 000000000000..f695a5c575ba --- /dev/null +++ b/benchmarks/benchmarks/api/bench_fused_sample_neighbors.py @@ -0,0 +1,39 @@ +import time + +import dgl +import dgl.function as fn + +import numpy as np +import torch + +from .. import utils + + +@utils.benchmark("time") +@utils.parametrize_cpu("graph_name", ["livejournal", "reddit"]) +@utils.parametrize_gpu("graph_name", ["ogbn-arxiv", "reddit"]) +@utils.parametrize("format", ["csr", "csc"]) +@utils.parametrize("seed_nodes_num", [200, 5000, 20000]) +@utils.parametrize("fanout", [5, 20, 40]) +def track_time(graph_name, format, seed_nodes_num, fanout): + device = utils.get_bench_device() + graph = utils.get_graph(graph_name, format).to(device) + + edge_dir = "in" if format == "csc" else "out" + seed_nodes = np.random.randint(0, graph.num_nodes(), seed_nodes_num) + seed_nodes = torch.from_numpy(seed_nodes).to(device) + + # dry run + for i in range(3): + dgl.sampling.sample_neighbors_fused( + graph, seed_nodes, fanout, edge_dir=edge_dir + ) + + # timing + with utils.Timer() as t: + for i in range(50): + dgl.sampling.sample_neighbors_fused( + graph, seed_nodes, fanout, edge_dir=edge_dir + ) + + return t.elapsed_secs / 50 diff --git a/include/dgl/aten/csr.h b/include/dgl/aten/csr.h index cbed0e41cc5e..6886aa6f1976 100644 --- a/include/dgl/aten/csr.h +++ b/include/dgl/aten/csr.h @@ -572,6 +572,72 @@ COOMatrix CSRRowWiseSampling( CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask = NDArray(), bool replace = true); +/*! + * @brief Randomly select a fixed number of non-zero entries along each given + * row independently. + * + * The function performs random choices along each row independently. + * The picked indices are returned in the form of a CSR matrix, with + * additional IdArray that is an extended version of CSR's index pointers. + * + * With template parameter set to True rows are also saved as new seed nodes and + * mapped + * + * If replace is false and a row has fewer non-zero values than num_samples, + * all the values are picked. + * + * Examples: + * + * // csr.num_rows = 4; + * // csr.num_cols = 4; + * // csr.indptr = [0, 2, 3, 3, 5] + * // csr.indices = [0, 1, 1, 2, 3] + * // csr.data = [2, 3, 0, 1, 4] + * CSRMatrix csr = ...; + * IdArray rows = ... ; // [1, 3] + * IdArray seed_mapping = [-1, -1, -1, -1]; + * std::vector new_seed_nodes = {}; + * + * std::pair sampled = CSRRowWiseSamplingFused< + * typename IdType, True>( + * csr, rows, seed_mapping, + * new_seed_nodes, 2, + * FloatArray(), false); + * // possible sampled csr matrix: + * // sampled.first.num_rows = 2 + * // sampled.first.num_cols = 3 + * // sampled.first.indptr = [0, 1, 3] + * // sampled.first.indices = [1, 2, 3] + * // sampled.first.data = [0, 1, 4] + * // sampled.second = [0, 1, 1] + * // seed_mapping = [-1, 0, -1, 1]; + * // new_seed_nodes = {1, 3}; + * + * @tparam IdType Graph's index data type, can be int32_t or int64_t + * @tparam map_seed_nodes If set for true we map and copy rows to new_seed_nodes + * @param mat Input CSR matrix. + * @param rows Rows to sample from. + * @param seed_mapping Mapping array used if map_seed_nodes=true. If so each row + * from rows will be set to its position e.g. mapping[rows[i]] = i. + * @param new_seed_nodes Vector used if map_seed_nodes=true. If so it will + * contain rows. + * @param rows Rows to sample from. + * @param num_samples Number of samples + * @param prob_or_mask Unnormalized probability array or mask array. + * Should be of the same length as the data array. + * If an empty array is provided, assume uniform. + * @param replace True if sample with replacement + * @return A CSRMatrix storing the picked row, col and data indices, + * COO version of picked rows + * @note The edges of the entire graph must be ordered by their edge types, + * rows must be unique + */ +template +std::pair CSRRowWiseSamplingFused( + CSRMatrix mat, IdArray rows, IdArray seed_mapping, + std::vector* new_seed_nodes, int64_t num_samples, + NDArray prob_or_mask = NDArray(), bool replace = true); + /** * @brief Randomly select a fixed number of non-zero entries for each edge type * along each given row independently. diff --git a/include/dgl/sampling/neighbor.h b/include/dgl/sampling/neighbor.h index 7c17050777a2..1c3a4bb64eb3 100644 --- a/include/dgl/sampling/neighbor.h +++ b/include/dgl/sampling/neighbor.h @@ -9,6 +9,7 @@ #include #include +#include #include namespace dgl { @@ -47,6 +48,55 @@ HeteroSubgraph SampleNeighbors( const std::vector& probability, const std::vector& exclude_edges, bool replace = true); +/** + * @brief Sample from the neighbors of the given nodes and convert a graph into + * a bipartite-structured graph for message passing. + * + * Specifically, we create one node type \c ntype_l on the "left" side and + * another node type \c ntype_r on the "right" side for each node type \c ntype. + * The nodes of type \c ntype_r would contain the nodes designated by the + * caller, and node type \c ntype_l would contain the nodes that has an edge + * connecting to one of the designated nodes. + * + * The nodes of \c ntype_l would also contain the nodes in node type \c ntype_r. + * When sampling with replacement, the sampled subgraph could have parallel + * edges. + * + * For sampling without replace, if fanout > the number of neighbors, all the + * neighbors will be sampled. + * + * Non-deterministic algorithm, requires nodes parameter to store unique Node + * IDs. + * + * @tparam IdType Graph's index data type, can be int32_t or int64_t + * @param hg The input graph. + * @param nodes Node IDs of each type. The vector length must be equal to the + * number of node types. Empty array is allowed. + * @param mapping External parameter that should be set to a vector of IdArrays + * filled with -1, required for mapping of nodes in returned + * graph + * @param fanouts Number of sampled neighbors for each edge type. The vector + * length should be equal to the number of edge types, or one if they all have + * the same fanout. + * @param dir Edge direction. + * @param probability A vector of 1D float arrays, indicating the transition + * probability of each edge by edge type. An empty float array assumes uniform + * transition. + * @param exclude_edges Edges IDs of each type which will be excluded during + * sampling. The vector length must be equal to the number of edges types. Empty + * array is allowed. + * @param replace If true, sample with replacement. + * @return Sampled neighborhoods as a graph. The return graph has the same + * schema as the original one. + */ +template +std::tuple, std::vector> +SampleNeighborsFused( + const HeteroGraphPtr hg, const std::vector& nodes, + const std::vector& mapping, const std::vector& fanouts, + EdgeDir dir, const std::vector& prob_or_mask, + const std::vector& exclude_edges, bool replace = true); + /** * Select the neighbors with k-largest weights on the connecting edges for each * given node. diff --git a/python/dgl/dataloading/neighbor_sampler.py b/python/dgl/dataloading/neighbor_sampler.py index 573946eade9f..2a0c37f417b0 100644 --- a/python/dgl/dataloading/neighbor_sampler.py +++ b/python/dgl/dataloading/neighbor_sampler.py @@ -1,5 +1,7 @@ """Data loading components for neighbor sampling""" +from .. import backend as F from ..base import EID, NID +from ..heterograph import DGLGraph from ..transforms import to_block from .base import BlockSampler @@ -54,6 +56,9 @@ class NeighborSampler(BlockSampler): output_device : device, optional The device of the output subgraphs or MFGs. Default is the same as the minibatch of seed nodes. + fused : bool, default True + If True and device is CPU fused sample neighbors is invoked. This version + requires seed_nodes to be unique Examples -------- @@ -120,6 +125,7 @@ def __init__( prefetch_labels=None, prefetch_edge_feats=None, output_device=None, + fused=True, ): super().__init__( prefetch_node_feats=prefetch_node_feats, @@ -137,10 +143,43 @@ def __init__( ) self.prob = prob or mask self.replace = replace + self.fused = fused + self.mapping = {} + self.g = None def sample_blocks(self, g, seed_nodes, exclude_eids=None): output_nodes = seed_nodes blocks = [] + + if self.fused: + cpu = F.device_type(g.device) == "cpu" + if isinstance(seed_nodes, dict): + for ntype in list(seed_nodes.keys()): + if not cpu: + break + cpu = ( + cpu and F.device_type(seed_nodes[ntype].device) == "cpu" + ) + else: + cpu = cpu and F.device_type(seed_nodes.device) == "cpu" + if cpu and isinstance(g, DGLGraph) and F.backend_name == "pytorch": + if self.g != g: + self.mapping = {} + self.g = g + for fanout in reversed(self.fanouts): + block = g.sample_neighbors_fused( + seed_nodes, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + exclude_edges=exclude_eids, + mapping=self.mapping, + ) + seed_nodes = block.srcdata[NID] + blocks.insert(0, block) + return seed_nodes, output_nodes, blocks + for fanout in reversed(self.fanouts): frontier = g.sample_neighbors( seed_nodes, diff --git a/python/dgl/sampling/neighbor.py b/python/dgl/sampling/neighbor.py index 232fb3972744..66bc692d3d9e 100644 --- a/python/dgl/sampling/neighbor.py +++ b/python/dgl/sampling/neighbor.py @@ -1,14 +1,19 @@ """Neighbor sampling APIs""" +import os + +import torch + from .. import backend as F, ndarray as nd, utils from .._ffi.function import _init_api from ..base import DGLError, EID -from ..heterograph import DGLGraph +from ..heterograph import DGLBlock, DGLGraph from .utils import EidExcluder __all__ = [ "sample_etype_neighbors", "sample_neighbors", + "sample_neighbors_fused", "sample_neighbors_biased", "select_topk", ] @@ -379,6 +384,126 @@ def sample_neighbors( return frontier if output_device is None else frontier.to(output_device) +def sample_neighbors_fused( + g, + nodes, + fanout, + edge_dir="in", + prob=None, + replace=False, + copy_ndata=True, + copy_edata=True, + exclude_edges=None, + mapping=None, +): + """Sample neighboring edges of the given nodes and return the induced subgraph. + + For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges + will be randomly chosen. The graph returned will then contain all the nodes in the + original graph, but only the sampled edges. Nodes will be renumbered starting from id 0, + which would be new node id of first seed node. + + Parameters + ---------- + g : DGLGraph + The graph. Can be either on CPU or GPU. + nodes : tensor or dict + Node IDs to sample neighbors from. + + This argument can take a single ID tensor or a dictionary of node types and ID tensors. + If a single tensor is given, the graph must only have one type of nodes. + fanout : int or dict[etype, int] + The number of edges to be sampled for each node on each edge type. + + This argument can take a single int or a dictionary of edge types and ints. + If a single int is given, DGL will sample this number of edges for each node for + every edge type. + + If -1 is given for a single edge type, all the neighboring edges with that edge + type and non-zero probability will be selected. + edge_dir : str, optional + Determines whether to sample inbound or outbound edges. + + Can take either ``in`` for inbound edges or ``out`` for outbound edges. + prob : str, optional + Feature name used as the (unnormalized) probabilities associated with each + neighboring edge of a node. The feature must have only one element for each + edge. + + The features must be non-negative floats or boolean. Otherwise, the result + will be undefined. + exclude_edges: tensor or dict + Edge IDs to exclude during sampling neighbors for the seed nodes. + + This argument can take a single ID tensor or a dictionary of edge types and ID tensors. + If a single tensor is given, the graph must only have one type of nodes. + replace : bool, optional + If True, sample with replacement. + copy_ndata: bool, optional + If True, the node features of the new graph are copied from + the original graph. If False, the new graph will not have any + node features. + + (Default: True) + copy_edata: bool, optional + If True, the edge features of the new graph are copied from + the original graph. If False, the new graph will not have any + edge features. + + (Default: False) + + mapping : dictionary, optional + Used by fused version of NeighborSampler. To avoid constant data allocation + provide empty dictionary ({}) that will be allocated once with proper data and reused + by each function call + + (Default: None) + Returns + ------- + DGLGraph + A sampled subgraph containing only the sampled neighboring edges. + + Notes + ----- + If :attr:`copy_ndata` or :attr:`copy_edata` is True, same tensors are used as + the node or edge features of the original graph and the new graph. + As a result, users should avoid performing in-place operations + on the node features of the new graph to avoid feature corruption. + + """ + if not g.is_pinned(): + frontier = _sample_neighbors( + g, + nodes, + fanout, + edge_dir=edge_dir, + prob=prob, + replace=replace, + copy_ndata=copy_ndata, + copy_edata=copy_edata, + exclude_edges=exclude_edges, + fused=True, + mapping=mapping, + ) + else: + frontier = _sample_neighbors( + g, + nodes, + fanout, + edge_dir=edge_dir, + prob=prob, + replace=replace, + copy_ndata=copy_ndata, + copy_edata=copy_edata, + fused=True, + mapping=mapping, + ) + if exclude_edges is not None: + eid_excluder = EidExcluder(exclude_edges) + frontier = eid_excluder(frontier) + return frontier + + def _sample_neighbors( g, nodes, @@ -390,6 +515,8 @@ def _sample_neighbors( copy_edata=True, _dist_training=False, exclude_edges=None, + fused=False, + mapping=None, ): if not isinstance(nodes, dict): if len(g.ntypes) > 1: @@ -446,17 +573,64 @@ def _sample_neighbors( else: excluded_edges_all_t.append(nd.array([], ctx=ctx)) - subgidx = _CAPI_DGLSampleNeighbors( - g._graph, - nodes_all_types, - fanout_array, - edge_dir, - prob_arrays, - excluded_edges_all_t, - replace, - ) - induced_edges = subgidx.induced_edges - ret = DGLGraph(subgidx.graph, g.ntypes, g.etypes) + if fused: + if _dist_training: + raise DGLError( + "distributed training not supported in fused sampling" + ) + cpu = F.device_type(g.device) == "cpu" + if isinstance(nodes, dict): + for ntype in list(nodes.keys()): + if not cpu: + break + cpu = cpu and F.device_type(nodes[ntype].device) == "cpu" + else: + cpu = cpu and F.device_type(nodes.device) == "cpu" + if not cpu or F.backend_name != "pytorch": + raise DGLError( + "Only PyTorch backend and cpu is supported in fused sampling" + ) + + if mapping is None: + mapping = {} + mapping_name = "__mapping" + str(os.getpid()) + if mapping_name not in mapping.keys(): + mapping[mapping_name] = [ + torch.LongTensor(g.num_nodes(ntype)).fill_(-1) + for ntype in g.ntypes + ] + + subgidx, induced_nodes, induced_edges = _CAPI_DGLSampleNeighborsFused( + g._graph, + nodes_all_types, + [F.to_dgl_nd(m) for m in mapping[mapping_name]], + fanout_array, + edge_dir, + prob_arrays, + excluded_edges_all_t, + replace, + ) + for mapping_vector, src_nodes in zip( + mapping[mapping_name], induced_nodes + ): + mapping_vector[F.from_dgl_nd(src_nodes).type(F.int64)] = -1 + + new_ntypes = (g.ntypes, g.ntypes) + ret = DGLBlock(subgidx, new_ntypes, g.etypes) + assert ret.is_unibipartite + + else: + subgidx = _CAPI_DGLSampleNeighbors( + g._graph, + nodes_all_types, + fanout_array, + edge_dir, + prob_arrays, + excluded_edges_all_t, + replace, + ) + ret = DGLGraph(subgidx.graph, g.ntypes, g.etypes) + induced_edges = subgidx.induced_edges # handle features # (TODO) (BarclayII) DGL distributed fails with bus error, freezes, or other @@ -465,12 +639,31 @@ def _sample_neighbors( # only set the edge IDs. if not _dist_training: if copy_ndata: - node_frames = utils.extract_node_subframes(g, device) - utils.set_new_frames(ret, node_frames=node_frames) + if fused: + src_node_ids = [F.from_dgl_nd(src) for src in induced_nodes] + dst_node_ids = [ + utils.toindex( + nodes.get(ntype, []), g._idtype_str + ).tousertensor(ctx=F.to_backend_ctx(g._graph.ctx)) + for ntype in g.ntypes + ] + node_frames = utils.extract_node_subframes_for_block( + g, src_node_ids, dst_node_ids + ) + utils.set_new_frames(ret, node_frames=node_frames) + else: + node_frames = utils.extract_node_subframes(g, device) + utils.set_new_frames(ret, node_frames=node_frames) if copy_edata: - edge_frames = utils.extract_edge_subframes(g, induced_edges) - utils.set_new_frames(ret, edge_frames=edge_frames) + if fused: + edge_ids = [F.from_dgl_nd(eid) for eid in induced_edges] + edge_frames = utils.extract_edge_subframes(g, edge_ids) + utils.set_new_frames(ret, edge_frames=edge_frames) + else: + edge_frames = utils.extract_edge_subframes(g, induced_edges) + utils.set_new_frames(ret, edge_frames=edge_frames) + else: for i, etype in enumerate(ret.canonical_etypes): ret.edges[etype].data[EID] = induced_edges[i] @@ -479,6 +672,7 @@ def _sample_neighbors( DGLGraph.sample_neighbors = utils.alias_func(sample_neighbors) +DGLGraph.sample_neighbors_fused = utils.alias_func(sample_neighbors_fused) def sample_neighbors_biased( diff --git a/src/array/array.cc b/src/array/array.cc index 57a2af0761c8..7472a4c3e09f 100644 --- a/src/array/array.cc +++ b/src/array/array.cc @@ -597,6 +597,47 @@ COOMatrix CSRRowWiseSampling( return ret; } +template +std::pair CSRRowWiseSamplingFused( + CSRMatrix mat, IdArray rows, IdArray seed_mapping, + std::vector* new_seed_nodes, int64_t num_samples, + NDArray prob_or_mask, bool replace) { + std::pair ret; + if (IsNullArray(prob_or_mask)) { + ATEN_XPU_SWITCH( + rows->ctx.device_type, XPU, "CSRRowWiseSamplingUniformFused", { + ret = + impl::CSRRowWiseSamplingUniformFused( + mat, rows, seed_mapping, new_seed_nodes, num_samples, + replace); + }); + } else { + CHECK_VALID_CONTEXT(prob_or_mask, rows); + ATEN_XPU_SWITCH(rows->ctx.device_type, XPU, "CSRRowWiseSamplingFused", { + ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH( + prob_or_mask->dtype, FloatType, "probability or mask", { + ret = impl::CSRRowWiseSamplingFused< + XPU, IdType, FloatType, map_seed_nodes>( + mat, rows, seed_mapping, new_seed_nodes, num_samples, + prob_or_mask, replace); + }); + }); + } + return ret; +} + +template std::pair CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); + +template std::pair CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); + +template std::pair CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); + +template std::pair CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); + COOMatrix CSRRowWisePerEtypeSampling( CSRMatrix mat, IdArray rows, const std::vector& eid2etype_offset, const std::vector& num_samples, diff --git a/src/array/array_op.h b/src/array/array_op.h index 3b37db78fba2..ceb15492a07c 100644 --- a/src/array/array_op.h +++ b/src/array/array_op.h @@ -178,6 +178,14 @@ COOMatrix CSRRowWiseSampling( CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask, bool replace); +// FloatType is the type of probability data. +template < + DGLDeviceType XPU, typename IdxType, typename DType, bool map_seed_nodes> +std::pair CSRRowWiseSamplingFused( + CSRMatrix mat, IdArray rows, IdArray seed_mapping, + std::vector* new_seed_nodes, int64_t num_samples, + NDArray prob_or_mask, bool replace); + // FloatType is the type of probability data. template COOMatrix CSRRowWisePerEtypeSampling( @@ -190,6 +198,11 @@ template COOMatrix CSRRowWiseSamplingUniform( CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace); +template +std::pair CSRRowWiseSamplingUniformFused( + CSRMatrix mat, IdArray rows, IdArray seed_mapping, + std::vector* new_seed_nodes, int64_t num_samples, bool replace); + template COOMatrix CSRRowWisePerEtypeSamplingUniform( CSRMatrix mat, IdArray rows, const std::vector& eid2etype_offset, diff --git a/src/array/cpu/concurrent_id_hash_map.cc b/src/array/cpu/concurrent_id_hash_map.cc index 46e444eb9ebf..2f9bde7c5dce 100644 --- a/src/array/cpu/concurrent_id_hash_map.cc +++ b/src/array/cpu/concurrent_id_hash_map.cc @@ -223,5 +223,27 @@ ConcurrentIdHashMap::AttemptInsertAt(int64_t pos, IdType key) { template class ConcurrentIdHashMap; template class ConcurrentIdHashMap; +template +bool BoolCompareAndSwap(IdType* ptr) { +#ifdef _MSC_VER + if (sizeof(IdType) == 4) { + return _InterlockedCompareExchange(reinterpret_cast(ptr), 0, -1) == + -1; + } else if (sizeof(IdType) == 8) { + return _InterlockedCompareExchange64( + reinterpret_cast(ptr), 0, -1) == -1; + } else { + LOG(FATAL) << "ID can only be int32 or int64"; + } +#elif __GNUC__ // _MSC_VER + return __sync_bool_compare_and_swap(ptr, -1, 0); +#else // _MSC_VER +#error "CompareAndSwap is not supported on this platform." +#endif // _MSC_VER +} + +template bool BoolCompareAndSwap(int32_t*); +template bool BoolCompareAndSwap(int64_t*); + } // namespace aten } // namespace dgl diff --git a/src/array/cpu/concurrent_id_hash_map.h b/src/array/cpu/concurrent_id_hash_map.h index db09ee902130..aa3a32a2fdcb 100644 --- a/src/array/cpu/concurrent_id_hash_map.h +++ b/src/array/cpu/concurrent_id_hash_map.h @@ -195,6 +195,9 @@ class ConcurrentIdHashMap { IdType mask_; }; +template +bool BoolCompareAndSwap(IdType* ptr); + } // namespace aten } // namespace dgl diff --git a/src/array/cpu/rowwise_pick.h b/src/array/cpu/rowwise_pick.h index cb0b32298157..7e34976e46f0 100644 --- a/src/array/cpu/rowwise_pick.h +++ b/src/array/cpu/rowwise_pick.h @@ -14,6 +14,7 @@ #include #include #include +#include #include namespace dgl { @@ -94,6 +95,115 @@ using EtypeRangePickFn = std::function& et_idx, const std::vector& et_eid, const IdxType* eid, IdxType* out_idx)>; +template +std::pair CSRRowWisePickFused( + CSRMatrix mat, IdArray rows, IdArray seed_mapping, + std::vector* new_seed_nodes, int64_t num_picks, bool replace, + PickFn pick_fn, NumPicksFn num_picks_fn) { + using namespace aten; + + const IdxType* indptr = static_cast(mat.indptr->data); + const IdxType* indices = static_cast(mat.indices->data); + const IdxType* data = + CSRHasData(mat) ? static_cast(mat.data->data) : nullptr; + const IdxType* rows_data = static_cast(rows->data); + const int64_t num_rows = rows->shape[0]; + const auto& ctx = mat.indptr->ctx; + const auto& idtype = mat.indptr->dtype; + IdxType* seed_mapping_data = nullptr; + if (map_seed_nodes) seed_mapping_data = seed_mapping.Ptr(); + + const int num_threads = runtime::compute_num_threads(0, num_rows, 1); + std::vector global_prefix(num_threads + 1, 0); + + IdArray picked_col, picked_idx, picked_coo_rows; + + IdArray block_csr_indptr = IdArray::Empty({num_rows + 1}, idtype, ctx); + IdxType* block_csr_indptr_data = block_csr_indptr.Ptr(); + +#pragma omp parallel num_threads(num_threads) + { + const int thread_id = omp_get_thread_num(); + + const int64_t start_i = + thread_id * (num_rows / num_threads) + + std::min(static_cast(thread_id), num_rows % num_threads); + const int64_t end_i = + (thread_id + 1) * (num_rows / num_threads) + + std::min(static_cast(thread_id + 1), num_rows % num_threads); + assert(thread_id + 1 < num_threads || end_i == num_rows); + + const int64_t num_local = end_i - start_i; + + std::unique_ptr local_prefix(new int64_t[num_local + 1]); + local_prefix[0] = 0; + for (int64_t i = start_i; i < end_i; ++i) { + // build prefix-sum + const int64_t local_i = i - start_i; + const IdxType rid = rows_data[i]; + if (map_seed_nodes) seed_mapping_data[rid] = i; + + IdxType len = num_picks_fn( + rid, indptr[rid], indptr[rid + 1] - indptr[rid], indices, data); + local_prefix[local_i + 1] = local_prefix[local_i] + len; + } + global_prefix[thread_id + 1] = local_prefix[num_local]; + +#pragma omp barrier +#pragma omp master + { + for (int t = 0; t < num_threads; ++t) { + global_prefix[t + 1] += global_prefix[t]; + } + picked_col = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx); + picked_idx = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx); + picked_coo_rows = + IdArray::Empty({global_prefix[num_threads]}, idtype, ctx); + } + +#pragma omp barrier + IdxType* picked_cdata = picked_col.Ptr(); + IdxType* picked_idata = picked_idx.Ptr(); + IdxType* picked_rows = picked_coo_rows.Ptr(); + + const IdxType thread_offset = global_prefix[thread_id]; + + for (int64_t i = start_i; i < end_i; ++i) { + const IdxType rid = rows_data[i]; + const int64_t local_i = i - start_i; + block_csr_indptr_data[i] = local_prefix[local_i] + thread_offset; + + const IdxType off = indptr[rid]; + const IdxType len = indptr[rid + 1] - off; + if (len == 0) continue; + + const int64_t row_offset = local_prefix[local_i] + thread_offset; + const int64_t num_picks = + local_prefix[local_i + 1] + thread_offset - row_offset; + + pick_fn( + rid, off, len, num_picks, indices, data, picked_idata + row_offset); + for (int64_t j = 0; j < num_picks; ++j) { + const IdxType picked = picked_idata[row_offset + j]; + picked_cdata[row_offset + j] = indices[picked]; + picked_idata[row_offset + j] = data ? data[picked] : picked; + picked_rows[row_offset + j] = i; + } + } + } + block_csr_indptr_data[num_rows] = global_prefix.back(); + + const IdxType num_cols = picked_col->shape[0]; + if (map_seed_nodes) { + (*new_seed_nodes).resize(num_rows); + memcpy((*new_seed_nodes).data(), rows_data, sizeof(IdxType) * num_rows); + } + + return std::make_pair( + CSRMatrix(num_rows, num_cols, block_csr_indptr, picked_col, picked_idx), + picked_coo_rows); +} + // Template for picking non-zero values row-wise. The implementation utilizes // OpenMP parallelization on rows because each row performs computation // independently. diff --git a/src/array/cpu/rowwise_sampling.cc b/src/array/cpu/rowwise_sampling.cc index 56bb9a40bb4c..222a009341f9 100644 --- a/src/array/cpu/rowwise_sampling.cc +++ b/src/array/cpu/rowwise_sampling.cc @@ -225,6 +225,74 @@ template COOMatrix CSRRowWiseSampling( template COOMatrix CSRRowWiseSampling( CSRMatrix, IdArray, int64_t, NDArray, bool); +template < + DGLDeviceType XPU, typename IdxType, typename DType, bool map_seed_nodes> +std::pair CSRRowWiseSamplingFused( + CSRMatrix mat, IdArray rows, IdArray seed_mapping, + std::vector* new_seed_nodes, int64_t num_samples, + NDArray prob_or_mask, bool replace) { + // If num_samples is -1, select all neighbors without replacement. + replace = (replace && num_samples != -1); + CHECK(prob_or_mask.defined()); + auto num_picks_fn = + GetSamplingNumPicksFn(num_samples, prob_or_mask, replace); + auto pick_fn = + GetSamplingPickFn(num_samples, prob_or_mask, replace); + return CSRRowWisePickFused( + mat, rows, seed_mapping, new_seed_nodes, num_samples, replace, pick_fn, + num_picks_fn); +} + +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); + +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, NDArray, bool); + template COOMatrix CSRRowWisePerEtypeSampling( CSRMatrix mat, IdArray rows, const std::vector& eid2etype_offset, @@ -283,6 +351,33 @@ template COOMatrix CSRRowWiseSamplingUniform( template COOMatrix CSRRowWiseSamplingUniform( CSRMatrix, IdArray, int64_t, bool); +template +std::pair CSRRowWiseSamplingUniformFused( + CSRMatrix mat, IdArray rows, IdArray seed_mapping, + std::vector* new_seed_nodes, int64_t num_samples, bool replace) { + // If num_samples is -1, select all neighbors without replacement. + replace = (replace && num_samples != -1); + auto num_picks_fn = + GetSamplingUniformNumPicksFn(num_samples, replace); + auto pick_fn = GetSamplingUniformPickFn(num_samples, replace); + return CSRRowWisePickFused( + mat, rows, seed_mapping, new_seed_nodes, num_samples, replace, pick_fn, + num_picks_fn); +} + +template std::pair +CSRRowWiseSamplingUniformFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, bool); +template std::pair +CSRRowWiseSamplingUniformFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, bool); +template std::pair +CSRRowWiseSamplingUniformFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, bool); +template std::pair +CSRRowWiseSamplingUniformFused( + CSRMatrix, IdArray, IdArray, std::vector*, int64_t, bool); + template COOMatrix CSRRowWisePerEtypeSamplingUniform( CSRMatrix mat, IdArray rows, const std::vector& eid2etype_offset, diff --git a/src/graph/sampling/neighbor/neighbor.cc b/src/graph/sampling/neighbor/neighbor.cc index 5a161f744534..5393a200123d 100644 --- a/src/graph/sampling/neighbor/neighbor.cc +++ b/src/graph/sampling/neighbor/neighbor.cc @@ -6,13 +6,16 @@ #include #include +#include #include #include +#include #include #include #include +#include "../../../array/cpu/concurrent_id_hash_map.h" #include "../../../c_api_common.h" #include "../../unit_graph.h" @@ -22,6 +25,76 @@ using namespace dgl::aten; namespace dgl { namespace sampling { +template +void ExcludeCertainEdgesFused( + std::vector* sampled_graphs, std::vector* induced_edges, + std::vector* sampled_coo_rows, + const std::vector& exclude_edges, + std::vector* weights = nullptr) { + int etypes = (*sampled_graphs).size(); + std::vector remain_induced_edges(etypes); + std::vector remain_indptrs(etypes); + std::vector remain_indices(etypes); + std::vector remain_coo_rows(etypes); + std::vector remain_weights(etypes); + for (int etype = 0; etype < etypes; ++etype) { + if (exclude_edges[etype].GetSize() == 0 || + (*sampled_graphs)[etype].num_rows == 0) { + remain_induced_edges[etype] = (*induced_edges)[etype]; + if (weights) remain_weights[etype] = (*weights)[etype]; + continue; + } + const auto dtype = weights && (*weights)[etype]->shape[0] + ? (*weights)[etype]->dtype + : DGLDataType{kDGLFloat, 8 * sizeof(float), 1}; + ATEN_FLOAT_TYPE_SWITCH(dtype, FloatType, "weights", { + IdType* indptr = (*sampled_graphs)[etype].indptr.Ptr(); + IdType* indices = (*sampled_graphs)[etype].indices.Ptr(); + IdType* coo_rows = (*sampled_coo_rows)[etype].Ptr(); + IdType* induced_edges_data = (*induced_edges)[etype].Ptr(); + FloatType* weights_data = weights && (*weights)[etype]->shape[0] + ? (*weights)[etype].Ptr() + : nullptr; + const IdType exclude_edges_len = exclude_edges[etype]->shape[0]; + std::sort( + exclude_edges[etype].Ptr(), + exclude_edges[etype].Ptr() + exclude_edges_len); + const IdType* exclude_edges_data = exclude_edges[etype].Ptr(); + IdType outIndices = 0; + for (IdType row = 0; row < (*sampled_graphs)[etype].indptr->shape[0] - 1; + ++row) { + auto tmp_row = indptr[row]; + if (outIndices != indptr[row]) indptr[row] = outIndices; + for (IdType col = tmp_row; col < indptr[row + 1]; ++col) { + if (!std::binary_search( + exclude_edges_data, exclude_edges_data + exclude_edges_len, + induced_edges_data[col])) { + indices[outIndices] = indices[col]; + induced_edges_data[outIndices] = induced_edges_data[col]; + coo_rows[outIndices] = coo_rows[col]; + if (weights_data) weights_data[outIndices] = weights_data[col]; + ++outIndices; + } + } + } + indptr[(*sampled_graphs)[etype].indptr->shape[0] - 1] = outIndices; + remain_induced_edges[etype] = + aten::IndexSelect((*induced_edges)[etype], 0, outIndices); + remain_weights[etype] = + weights_data ? aten::IndexSelect((*weights)[etype], 0, outIndices) + : NullArray(); + remain_indices[etype] = + aten::IndexSelect((*sampled_graphs)[etype].indices, 0, outIndices); + (*sampled_coo_rows)[etype] = + aten::IndexSelect((*sampled_coo_rows)[etype], 0, outIndices); + (*sampled_graphs)[etype] = CSRMatrix( + (*sampled_graphs)[etype].num_rows, outIndices, + (*sampled_graphs)[etype].indptr, remain_indices[etype], + remain_induced_edges[etype]); + }); + } +} + std::pair> ExcludeCertainEdges( const HeteroSubgraph& sg, const std::vector& exclude_edges, const std::vector* weights = nullptr) { @@ -266,6 +339,242 @@ HeteroSubgraph SampleNeighbors( return ret; } +template +std::tuple, std::vector> +SampleNeighborsFused( + const HeteroGraphPtr hg, const std::vector& nodes, + const std::vector& mapping, const std::vector& fanouts, + EdgeDir dir, const std::vector& prob_or_mask, + const std::vector& exclude_edges, bool replace) { + CHECK_EQ(nodes.size(), hg->NumVertexTypes()) + << "Number of node ID tensors must match the number of node types."; + CHECK_EQ(fanouts.size(), hg->NumEdgeTypes()) + << "Number of fanout values must match the number of edge types."; + CHECK_EQ(prob_or_mask.size(), hg->NumEdgeTypes()) + << "Number of probability tensors must match the number of edge types."; + + DGLContext ctx = aten::GetContextOf(nodes); + + std::vector sampled_graphs; + std::vector sampled_coo_rows; + std::vector induced_edges; + std::vector induced_vertices; + std::vector num_nodes_per_type; + std::vector> new_nodes_vec(hg->NumVertexTypes()); + std::vector seed_nodes_mapped(hg->NumVertexTypes(), 0); + + for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) { + auto pair = hg->meta_graph()->FindEdge(etype); + const dgl_type_t src_vtype = pair.first; + const dgl_type_t dst_vtype = pair.second; + const dgl_type_t rhs_node_type = + (dir == EdgeDir::kOut) ? src_vtype : dst_vtype; + const IdArray nodes_ntype = nodes[rhs_node_type]; + const int64_t num_nodes = nodes_ntype->shape[0]; + + if (num_nodes == 0 || fanouts[etype] == 0) { + // Nothing to sample for this etype, create a placeholder + sampled_graphs.push_back(CSRMatrix()); + sampled_coo_rows.push_back(IdArray()); + induced_edges.push_back(aten::NullArray(hg->DataType(), ctx)); + } else { + bool map_seed_nodes = !seed_nodes_mapped[rhs_node_type]; + // sample from one relation graph + std::pair sampled_graph; + auto sampling_fn = map_seed_nodes + ? aten::CSRRowWiseSamplingFused + : aten::CSRRowWiseSamplingFused; + auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE; + auto avail_fmt = hg->SelectFormat(etype, req_fmt); + switch (avail_fmt) { + case SparseFormat::kCSR: + CHECK(dir == EdgeDir::kOut) + << "Cannot sample out edges on CSC matrix."; + // In heterographs nodes of two diffrent types can be connected + // therefore two diffrent mappings and node vectors are needed + sampled_graph = sampling_fn( + hg->GetCSRMatrix(etype), nodes_ntype, mapping[src_vtype], + &new_nodes_vec[src_vtype], fanouts[etype], prob_or_mask[etype], + replace); + break; + case SparseFormat::kCSC: + CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix."; + sampled_graph = sampling_fn( + hg->GetCSCMatrix(etype), nodes_ntype, mapping[dst_vtype], + &new_nodes_vec[dst_vtype], fanouts[etype], prob_or_mask[etype], + replace); + break; + default: + LOG(FATAL) << "Unsupported sparse format."; + } + seed_nodes_mapped[rhs_node_type]++; + sampled_graphs.push_back(sampled_graph.first); + if (sampled_graph.first.data.defined()) + induced_edges.push_back(sampled_graph.first.data); + else + induced_edges.push_back( + aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx)); + sampled_coo_rows.push_back(sampled_graph.second); + } + } + + if (!exclude_edges.empty()) { + ExcludeCertainEdgesFused( + &sampled_graphs, &induced_edges, &sampled_coo_rows, exclude_edges); + for (size_t i = 0; i < hg->NumEdgeTypes(); i++) { + if (sampled_graphs[i].data.defined()) + induced_edges[i] = std::move(sampled_graphs[i].data); + else + induced_edges[i] = + aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx); + } + } + + // map indices + for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) { + auto pair = hg->meta_graph()->FindEdge(etype); + const dgl_type_t src_vtype = pair.first; + const dgl_type_t dst_vtype = pair.second; + const dgl_type_t lhs_node_type = + (dir == EdgeDir::kIn) ? src_vtype : dst_vtype; + if (sampled_graphs[etype].num_cols != 0) { + auto num_cols = sampled_graphs[etype].num_cols; + int num_threads_col = runtime::compute_num_threads(0, num_cols, 1); + std::vector global_prefix_col(num_threads_col + 1, 0); + std::vector> src_nodes_local(num_threads_col); + IdType* mapping_data_dst = mapping[lhs_node_type].Ptr(); + IdType* cdata = sampled_graphs[etype].indices.Ptr(); +#pragma omp parallel num_threads(num_threads_col) + { + const int thread_id = omp_get_thread_num(); + num_threads_col = omp_get_num_threads(); + + const int64_t start_i = + thread_id * (num_cols / num_threads_col) + + std::min( + static_cast(thread_id), num_cols % num_threads_col); + const int64_t end_i = (thread_id + 1) * (num_cols / num_threads_col) + + std::min( + static_cast(thread_id + 1), + num_cols % num_threads_col); + assert(thread_id + 1 < num_threads_col || end_i == num_cols); + for (int64_t i = start_i; i < end_i; ++i) { + int64_t picked_idx = cdata[i]; + bool spot_claimed = + BoolCompareAndSwap(&mapping_data_dst[picked_idx]); + if (spot_claimed) src_nodes_local[thread_id].push_back(picked_idx); + } + global_prefix_col[thread_id + 1] = src_nodes_local[thread_id].size(); + +#pragma omp barrier +#pragma omp master + { + global_prefix_col[0] = new_nodes_vec[lhs_node_type].size(); + for (int t = 0; t < num_threads_col; ++t) { + global_prefix_col[t + 1] += global_prefix_col[t]; + } + } + +#pragma omp barrier + int64_t mapping_shift = global_prefix_col[thread_id]; + for (size_t i = 0; i < src_nodes_local[thread_id].size(); ++i) + mapping_data_dst[src_nodes_local[thread_id][i]] = mapping_shift + i; + +#pragma omp barrier + for (int64_t i = start_i; i < end_i; ++i) { + IdType picked_idx = cdata[i]; + IdType mapped_idx = mapping_data_dst[picked_idx]; + cdata[i] = mapped_idx; + } + } + IdType offset = new_nodes_vec[lhs_node_type].size(); + new_nodes_vec[lhs_node_type].resize(global_prefix_col.back()); + for (int thread_id = 0; thread_id < num_threads_col; ++thread_id) { + memcpy( + new_nodes_vec[lhs_node_type].data() + offset, + &src_nodes_local[thread_id][0], + src_nodes_local[thread_id].size() * sizeof(IdType)); + offset += src_nodes_local[thread_id].size(); + } + } + } + + // counting how many nodes of each ntype were sampled + num_nodes_per_type.resize(2 * hg->NumVertexTypes()); + for (size_t i = 0; i < hg->NumVertexTypes(); i++) { + num_nodes_per_type[i] = new_nodes_vec[i].size(); + num_nodes_per_type[hg->NumVertexTypes() + i] = nodes[i]->shape[0]; + induced_vertices.push_back( + VecToIdArray(new_nodes_vec[i], sizeof(IdType) * 8)); + } + + std::vector subrels(hg->NumEdgeTypes()); + for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) { + auto pair = hg->meta_graph()->FindEdge(etype); + const dgl_type_t src_vtype = pair.first; + const dgl_type_t dst_vtype = pair.second; + if (sampled_graphs[etype].num_rows == 0) { + subrels[etype] = UnitGraph::Empty( + 2, new_nodes_vec[src_vtype].size(), nodes[dst_vtype]->shape[0], + hg->DataType(), ctx); + } else { + CSRMatrix graph = sampled_graphs[etype]; + if (dir == EdgeDir::kOut) { + subrels[etype] = UnitGraph::CreateFromCSRAndCOO( + 2, + CSRMatrix( + nodes[src_vtype]->shape[0], new_nodes_vec[dst_vtype].size(), + graph.indptr, graph.indices, + Range( + 0, graph.indices->shape[0], graph.indices->dtype.bits, + ctx)), + COOMatrix( + nodes[src_vtype]->shape[0], new_nodes_vec[dst_vtype].size(), + sampled_coo_rows[etype], graph.indices), + ALL_CODE); + } else { + subrels[etype] = UnitGraph::CreateFromCSCAndCOO( + 2, + CSRMatrix( + nodes[dst_vtype]->shape[0], new_nodes_vec[src_vtype].size(), + graph.indptr, graph.indices, + Range( + 0, graph.indices->shape[0], graph.indices->dtype.bits, + ctx)), + COOMatrix( + new_nodes_vec[src_vtype].size(), nodes[dst_vtype]->shape[0], + graph.indices, sampled_coo_rows[etype]), + ALL_CODE); + } + } + } + + HeteroSubgraph ret; + + const auto meta_graph = hg->meta_graph(); + const EdgeArray etypes = meta_graph->Edges("eid"); + const IdArray new_dst = Add(etypes.dst, hg->NumVertexTypes()); + + const auto new_meta_graph = ImmutableGraph::CreateFromCOO( + hg->NumVertexTypes() * 2, etypes.src, new_dst); + + HeteroGraphPtr new_graph = + CreateHeteroGraph(new_meta_graph, subrels, num_nodes_per_type); + return std::make_tuple(new_graph, induced_edges, induced_vertices); +} + +template std::tuple, std::vector> +SampleNeighborsFused( + const HeteroGraphPtr, const std::vector&, + const std::vector&, const std::vector&, EdgeDir, + const std::vector&, const std::vector&, bool); + +template std::tuple, std::vector> +SampleNeighborsFused( + const HeteroGraphPtr, const std::vector&, + const std::vector&, const std::vector&, EdgeDir, + const std::vector&, const std::vector&, bool); + HeteroSubgraph SampleNeighborsEType( const HeteroGraphPtr hg, const IdArray nodes, const std::vector& eid2etype_offset, @@ -568,6 +877,47 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighbors") *rv = HeteroSubgraphRef(subg); }); +DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsFused") + .set_body([](DGLArgs args, DGLRetValue* rv) { + HeteroGraphRef hg = args[0]; + const auto& nodes = ListValueToVector(args[1]); + auto mapping = ListValueToVector(args[2]); + IdArray fanouts_array = args[3]; + const auto& fanouts = fanouts_array.ToVector(); + const std::string dir_str = args[4]; + const auto& prob_or_mask = ListValueToVector(args[5]); + const auto& exclude_edges = ListValueToVector(args[6]); + const bool replace = args[7]; + + CHECK(dir_str == "in" || dir_str == "out") + << "Invalid edge direction. Must be \"in\" or \"out\"."; + EdgeDir dir = (dir_str == "in") ? EdgeDir::kIn : EdgeDir::kOut; + + HeteroGraphPtr new_graph; + std::vector induced_edges; + std::vector induced_vertices; + + ATEN_ID_TYPE_SWITCH(hg->DataType(), IdType, { + std::tie(new_graph, induced_edges, induced_vertices) = + SampleNeighborsFused( + hg.sptr(), nodes, mapping, fanouts, dir, prob_or_mask, + exclude_edges, replace); + }); + + List lhs_nodes_ref; + for (IdArray& array : induced_vertices) + lhs_nodes_ref.push_back(Value(MakeValue(array))); + List induced_edges_ref; + for (IdArray& array : induced_edges) + induced_edges_ref.push_back(Value(MakeValue(array))); + List ret; + ret.push_back(HeteroGraphRef(new_graph)); + ret.push_back(lhs_nodes_ref); + ret.push_back(induced_edges_ref); + + *rv = ret; + }); + DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsTopk") .set_body([](DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; diff --git a/src/graph/unit_graph.cc b/src/graph/unit_graph.cc index 3c26cf92f82e..b4e29de3a39a 100644 --- a/src/graph/unit_graph.cc +++ b/src/graph/unit_graph.cc @@ -1218,6 +1218,21 @@ HeteroGraphPtr UnitGraph::CreateFromCSR( return HeteroGraphPtr(new UnitGraph(mg, nullptr, csr, nullptr, formats)); } +HeteroGraphPtr UnitGraph::CreateFromCSRAndCOO( + int64_t num_vtypes, const aten::CSRMatrix& csr, const aten::COOMatrix& coo, + dgl_format_code_t formats) { + CHECK(num_vtypes == 1 || num_vtypes == 2); + CHECK_EQ(coo.num_rows, csr.num_rows); + CHECK_EQ(coo.num_cols, csr.num_cols); + if (num_vtypes == 1) { + CHECK_EQ(csr.num_rows, csr.num_cols); + } + auto mg = CreateUnitGraphMetaGraph(num_vtypes); + CSRPtr csrPtr(new CSR(mg, csr)); + COOPtr cooPtr(new COO(mg, coo)); + return HeteroGraphPtr(new UnitGraph(mg, nullptr, csrPtr, cooPtr, formats)); +} + HeteroGraphPtr UnitGraph::CreateFromCSC( int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr, IdArray indices, IdArray edge_ids, dgl_format_code_t formats) { @@ -1237,6 +1252,21 @@ HeteroGraphPtr UnitGraph::CreateFromCSC( return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, formats)); } +HeteroGraphPtr UnitGraph::CreateFromCSCAndCOO( + int64_t num_vtypes, const aten::CSRMatrix& csc, const aten::COOMatrix& coo, + dgl_format_code_t formats) { + CHECK(num_vtypes == 1 || num_vtypes == 2); + CHECK_EQ(coo.num_rows, csc.num_cols); + CHECK_EQ(coo.num_cols, csc.num_rows); + if (num_vtypes == 1) { + CHECK_EQ(csc.num_rows, csc.num_cols); + } + auto mg = CreateUnitGraphMetaGraph(num_vtypes); + CSRPtr cscPtr(new CSR(mg, csc)); + COOPtr cooPtr(new COO(mg, coo)); + return HeteroGraphPtr(new UnitGraph(mg, cscPtr, nullptr, cooPtr, formats)); +} + HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) { if (g->NumBits() == bits) { return g; diff --git a/src/graph/unit_graph.h b/src/graph/unit_graph.h index 5951c5001f5c..244f922e906d 100644 --- a/src/graph/unit_graph.h +++ b/src/graph/unit_graph.h @@ -190,6 +190,12 @@ class UnitGraph : public BaseHeteroGraph { int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats = ALL_CODE); + /** @brief Create a graph from (out) CSR and COO arrays, both representing the + * same graph */ + static HeteroGraphPtr CreateFromCSRAndCOO( + int64_t num_vtypes, const aten::CSRMatrix& csr, + const aten::COOMatrix& coo, dgl_format_code_t formats = ALL_CODE); + /** @brief Create a graph from (in) CSC arrays */ static HeteroGraphPtr CreateFromCSC( int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr, @@ -199,6 +205,12 @@ class UnitGraph : public BaseHeteroGraph { int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats = ALL_CODE); + /** @brief Create a graph from (in) CSC and COO arrays, both representing the + * same graph */ + static HeteroGraphPtr CreateFromCSCAndCOO( + int64_t num_vtypes, const aten::CSRMatrix& csc, + const aten::COOMatrix& coo, dgl_format_code_t formats = ALL_CODE); + /** @brief Convert the graph to use the given number of bits for storage */ static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits); diff --git a/tests/python/common/sampling/test_sampling.py b/tests/python/common/sampling/test_sampling.py index f62ba858afc6..1964eb80c6d0 100644 --- a/tests/python/common/sampling/test_sampling.py +++ b/tests/python/common/sampling/test_sampling.py @@ -7,6 +7,11 @@ import numpy as np import pytest +sample_neighbors_fusing_mode = { + True: dgl.sampling.sample_neighbors_fused, + False: dgl.sampling.sample_neighbors, +} + def check_random_walk(g, metapath, traces, ntypes, prob=None, trace_eids=None): traces = F.asnumpy(traces) @@ -555,15 +560,18 @@ def _gen_neighbor_topk_test_graph(hypersparse, reverse): return g, hg -def _test_sample_neighbors(hypersparse, prob): +def _test_sample_neighbors(hypersparse, prob, fused): g, hg = _gen_neighbor_sampling_test_graph(hypersparse, False) def _test1(p, replace): - subg = dgl.sampling.sample_neighbors( + subg = sample_neighbors_fusing_mode[fused]( g, [0, 1], -1, prob=p, replace=replace ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() u, v = subg.edges() + if fused: + u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v] u_ans, v_ans, e_ans = g.in_edges([0, 1], form="all") if p is not None: emask = F.gather_row(g.edata[p], e_ans) @@ -576,12 +584,17 @@ def _test1(p, replace): assert uv == uv_ans for i in range(10): - subg = dgl.sampling.sample_neighbors( + subg = sample_neighbors_fusing_mode[fused]( g, [0, 1], 2, prob=p, replace=replace ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() + assert subg.num_edges() == 4 u, v = subg.edges() + if fused: + u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v] + assert set(F.asnumpy(F.unique(v))) == {0, 1} assert F.array_equal( F.astype(g.has_edges_between(u, v), F.int64), @@ -600,11 +613,14 @@ def _test1(p, replace): _test1(prob, False) # w/o replacement, uniform def _test2(p, replace): # fanout > #neighbors - subg = dgl.sampling.sample_neighbors( + subg = sample_neighbors_fusing_mode[fused]( g, [0, 2], -1, prob=p, replace=replace ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() u, v = subg.edges() + if fused: + u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v] u_ans, v_ans, e_ans = g.in_edges([0, 2], form="all") if p is not None: emask = F.gather_row(g.edata[p], e_ans) @@ -617,13 +633,16 @@ def _test2(p, replace): # fanout > #neighbors assert uv == uv_ans for i in range(10): - subg = dgl.sampling.sample_neighbors( + subg = sample_neighbors_fusing_mode[fused]( g, [0, 2], 2, prob=p, replace=replace ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() num_edges = 4 if replace else 3 assert subg.num_edges() == num_edges u, v = subg.edges() + if fused: + u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v] assert set(F.asnumpy(F.unique(v))) == {0, 2} assert F.array_equal( F.astype(g.has_edges_between(u, v), F.int64), @@ -641,10 +660,13 @@ def _test2(p, replace): # fanout > #neighbors _test2(prob, False) # w/o replacement, uniform def _test3(p, replace): - subg = dgl.sampling.sample_neighbors( + subg = sample_neighbors_fusing_mode[fused]( hg, {"user": [0, 1], "game": 0}, -1, prob=p, replace=replace ) - assert len(subg.ntypes) == 3 + if not fused: + assert len(subg.ntypes) == 3 + assert len(subg.srctypes) == 3 + assert len(subg.dsttypes) == 3 assert len(subg.etypes) == 4 assert subg["follow"].num_edges() == 6 if p is None else 4 assert subg["play"].num_edges() == 1 @@ -652,10 +674,13 @@ def _test3(p, replace): assert subg["flips"].num_edges() == 0 for i in range(10): - subg = dgl.sampling.sample_neighbors( + subg = sample_neighbors_fusing_mode[fused]( hg, {"user": [0, 1], "game": 0}, 2, prob=p, replace=replace ) - assert len(subg.ntypes) == 3 + if not fused: + assert len(subg.ntypes) == 3 + assert len(subg.srctypes) == 3 + assert len(subg.dsttypes) == 3 assert len(subg.etypes) == 4 assert subg["follow"].num_edges() == 4 assert subg["play"].num_edges() == 2 if replace else 1 @@ -667,13 +692,16 @@ def _test3(p, replace): # test different fanouts for different relations for i in range(10): - subg = dgl.sampling.sample_neighbors( + subg = sample_neighbors_fusing_mode[fused]( hg, {"user": [0, 1], "game": 0, "coin": 0}, {"follow": 1, "play": 2, "liked-by": 0, "flips": -1}, replace=True, ) - assert len(subg.ntypes) == 3 + if not fused: + assert len(subg.ntypes) == 3 + assert len(subg.srctypes) == 3 + assert len(subg.dsttypes) == 3 assert len(subg.etypes) == 4 assert subg["follow"].num_edges() == 2 assert subg["play"].num_edges() == 2 @@ -795,15 +823,19 @@ def _test3(p): assert subg["flips"].num_edges() == 4 -def _test_sample_neighbors_outedge(hypersparse): +def _test_sample_neighbors_outedge(hypersparse, fused): g, hg = _gen_neighbor_sampling_test_graph(hypersparse, True) def _test1(p, replace): - subg = dgl.sampling.sample_neighbors( + subg = sample_neighbors_fusing_mode[fused]( g, [0, 1], -1, prob=p, replace=replace, edge_dir="out" ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() + u, v = subg.edges() + if fused: + u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v] u_ans, v_ans, e_ans = g.out_edges([0, 1], form="all") if p is not None: emask = F.gather_row(g.edata[p], e_ans) @@ -816,12 +848,15 @@ def _test1(p, replace): assert uv == uv_ans for i in range(10): - subg = dgl.sampling.sample_neighbors( + subg = sample_neighbors_fusing_mode[fused]( g, [0, 1], 2, prob=p, replace=replace, edge_dir="out" ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() assert subg.num_edges() == 4 u, v = subg.edges() + if fused: + u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v] assert set(F.asnumpy(F.unique(u))) == {0, 1} assert F.array_equal( F.astype(g.has_edges_between(u, v), F.int64), @@ -842,11 +877,14 @@ def _test1(p, replace): _test1("prob", False) # w/o replacement def _test2(p, replace): # fanout > #neighbors - subg = dgl.sampling.sample_neighbors( + subg = sample_neighbors_fusing_mode[fused]( g, [0, 2], -1, prob=p, replace=replace, edge_dir="out" ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() u, v = subg.edges() + if fused: + u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v] u_ans, v_ans, e_ans = g.out_edges([0, 2], form="all") if p is not None: emask = F.gather_row(g.edata[p], e_ans) @@ -859,13 +897,17 @@ def _test2(p, replace): # fanout > #neighbors assert uv == uv_ans for i in range(10): - subg = dgl.sampling.sample_neighbors( + subg = sample_neighbors_fusing_mode[fused]( g, [0, 2], 2, prob=p, replace=replace, edge_dir="out" ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() num_edges = 4 if replace else 3 assert subg.num_edges() == num_edges u, v = subg.edges() + if fused: + u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v] + assert set(F.asnumpy(F.unique(u))) == {0, 2} assert F.array_equal( F.astype(g.has_edges_between(u, v), F.int64), @@ -885,7 +927,7 @@ def _test2(p, replace): # fanout > #neighbors _test2("prob", False) # w/o replacement def _test3(p, replace): - subg = dgl.sampling.sample_neighbors( + subg = sample_neighbors_fusing_mode[fused]( hg, {"user": [0, 1], "game": 0}, -1, @@ -893,7 +935,11 @@ def _test3(p, replace): replace=replace, edge_dir="out", ) - assert len(subg.ntypes) == 3 + + if not fused: + assert len(subg.ntypes) == 3 + assert len(subg.srctypes) == 3 + assert len(subg.dsttypes) == 3 assert len(subg.etypes) == 4 assert subg["follow"].num_edges() == 6 if p is None else 4 assert subg["play"].num_edges() == 1 @@ -901,7 +947,7 @@ def _test3(p, replace): assert subg["flips"].num_edges() == 0 for i in range(10): - subg = dgl.sampling.sample_neighbors( + subg = sample_neighbors_fusing_mode[fused]( hg, {"user": [0, 1], "game": 0}, 2, @@ -909,7 +955,10 @@ def _test3(p, replace): replace=replace, edge_dir="out", ) - assert len(subg.ntypes) == 3 + if not fused: + assert len(subg.ntypes) == 3 + assert len(subg.srctypes) == 3 + assert len(subg.dsttypes) == 3 assert len(subg.etypes) == 4 assert subg["follow"].num_edges() == 4 assert subg["play"].num_edges() == 2 if replace else 1 @@ -1077,7 +1126,9 @@ def _test3(): def test_sample_neighbors_noprob(): - _test_sample_neighbors(False, None) + _test_sample_neighbors(False, None, False) + if F._default_context_str != "gpu" and F.backend_name == "pytorch": + _test_sample_neighbors(False, None, True) # _test_sample_neighbors(True) @@ -1086,7 +1137,9 @@ def test_sample_labors_noprob(): def test_sample_neighbors_prob(): - _test_sample_neighbors(False, "prob") + _test_sample_neighbors(False, "prob", False) + if F._default_context_str != "gpu" and F.backend_name == "pytorch": + _test_sample_neighbors(False, "prob", True) # _test_sample_neighbors(True) @@ -1095,7 +1148,9 @@ def test_sample_labors_prob(): def test_sample_neighbors_outedge(): - _test_sample_neighbors_outedge(False) + _test_sample_neighbors_outedge(False, False) + if F._default_context_str != "gpu" and F.backend_name == "pytorch": + _test_sample_neighbors_outedge(False, True) # _test_sample_neighbors_outedge(True) @@ -1107,7 +1162,9 @@ def test_sample_neighbors_outedge(): reason="GPU sample neighbors with mask not implemented", ) def test_sample_neighbors_mask(): - _test_sample_neighbors(False, "mask") + _test_sample_neighbors(False, "mask", False) + if F._default_context_str != "gpu" and F.backend_name == "pytorch": + _test_sample_neighbors(False, "mask", True) @unittest.skipIf( @@ -1128,21 +1185,26 @@ def test_sample_neighbors_topk_outedge(): # _test_sample_neighbors_topk_outedge(True) -def test_sample_neighbors_with_0deg(): +@pytest.mark.parametrize("fused", [False, True]) +def test_sample_neighbors_with_0deg(fused): + if fused and ( + F._default_context_str == "gpu" or F.backend_name != "pytorch" + ): + pytest.skip("Fused sampling support CPU with backend PyTorch.") g = dgl.graph(([], []), num_nodes=5).to(F.ctx()) - sg = dgl.sampling.sample_neighbors( + sg = sample_neighbors_fusing_mode[fused]( g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir="in", replace=False ) assert sg.num_edges() == 0 - sg = dgl.sampling.sample_neighbors( + sg = sample_neighbors_fusing_mode[fused]( g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir="in", replace=True ) assert sg.num_edges() == 0 - sg = dgl.sampling.sample_neighbors( + sg = sample_neighbors_fusing_mode[fused]( g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir="out", replace=False ) assert sg.num_edges() == 0 - sg = dgl.sampling.sample_neighbors( + sg = sample_neighbors_fusing_mode[fused]( g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir="out", replace=True ) assert sg.num_edges() == 0 @@ -1274,7 +1336,7 @@ def check_num(nodes, tag): ) def test_sample_neighbors_biased_bipartite(): g = create_test_graph(100, 30, True) - num_dst = g.number_of_dst_nodes() + num_dst = g.num_dst_nodes() bias = F.tensor([0, 0.01, 10, 10], dtype=F.float32) def check_num(nodes, tag): @@ -1492,7 +1554,12 @@ def test_sample_neighbors_etype_sorted_homogeneous(format_, direction): @pytest.mark.parametrize("dtype", ["int32", "int64"]) -def test_sample_neighbors_exclude_edges_heteroG(dtype): +@pytest.mark.parametrize("fused", [False, True]) +def test_sample_neighbors_exclude_edges_heteroG(dtype, fused): + if fused and ( + F._default_context_str == "gpu" or F.backend_name != "pytorch" + ): + pytest.skip("Fused sampling support CPU with backend PyTorch.") d_i_d_u_nodes = F.zerocopy_from_numpy( np.unique(np.random.randint(300, size=100, dtype=dtype)) ) @@ -1565,7 +1632,7 @@ def test_sample_neighbors_exclude_edges_heteroG(dtype): ("drug", "treats", "disease"): excluded_d_t_d_edges, } - sg = dgl.sampling.sample_neighbors( + sg = sample_neighbors_fusing_mode[fused]( g, { "drug": sampled_drug_node, @@ -1576,37 +1643,84 @@ def test_sample_neighbors_exclude_edges_heteroG(dtype): exclude_edges=excluded_edges, ) - assert not np.any( - F.asnumpy( - sg.has_edges_between( - did_excluded_nodes_U, - did_excluded_nodes_V, - etype=("drug", "interacts", "drug"), + if fused: + + def contain_edge(g, sg, etype, u, v): + # set of subgraph graph edges deduced from original graph + org_edges = set( + map( + tuple, + np.stack( + g.find_edges(sg.edges[etype].data[dgl.EID], etype), + axis=1, + ), + ) ) + # set of excluded edges + excluded_edges = set(map(tuple, np.stack((u, v), axis=1))) + + diff_set = org_edges - excluded_edges + + return len(diff_set) != len(org_edges) + + assert not contain_edge( + g, + sg, + ("drug", "interacts", "drug"), + did_excluded_nodes_U, + did_excluded_nodes_V, ) - ) - assert not np.any( - F.asnumpy( - sg.has_edges_between( - dig_excluded_nodes_U, - dig_excluded_nodes_V, - etype=("drug", "interacts", "gene"), + assert not contain_edge( + g, + sg, + ("drug", "interacts", "gene"), + dig_excluded_nodes_U, + dig_excluded_nodes_V, + ) + assert not contain_edge( + g, + sg, + ("drug", "treats", "disease"), + dtd_excluded_nodes_U, + dtd_excluded_nodes_V, + ) + else: + assert not np.any( + F.asnumpy( + sg.has_edges_between( + did_excluded_nodes_U, + did_excluded_nodes_V, + etype=("drug", "interacts", "drug"), + ) ) ) - ) - assert not np.any( - F.asnumpy( - sg.has_edges_between( - dtd_excluded_nodes_U, - dtd_excluded_nodes_V, - etype=("drug", "treats", "disease"), + assert not np.any( + F.asnumpy( + sg.has_edges_between( + dig_excluded_nodes_U, + dig_excluded_nodes_V, + etype=("drug", "interacts", "gene"), + ) + ) + ) + assert not np.any( + F.asnumpy( + sg.has_edges_between( + dtd_excluded_nodes_U, + dtd_excluded_nodes_V, + etype=("drug", "treats", "disease"), + ) ) ) - ) @pytest.mark.parametrize("dtype", ["int32", "int64"]) -def test_sample_neighbors_exclude_edges_homoG(dtype): +@pytest.mark.parametrize("fused", [False, True]) +def test_sample_neighbors_exclude_edges_homoG(dtype, fused): + if fused and ( + F._default_context_str == "gpu" or F.backend_name != "pytorch" + ): + pytest.skip("Fused sampling support CPU with backend PyTorch.") u_nodes = F.zerocopy_from_numpy( np.unique(np.random.randint(300, size=100, dtype=dtype)) ) @@ -1629,13 +1743,33 @@ def test_sample_neighbors_exclude_edges_homoG(dtype): excluded_nodes_U = g_edges[U][b_idx:e_idx] excluded_nodes_V = g_edges[V][b_idx:e_idx] - sg = dgl.sampling.sample_neighbors( + sg = sample_neighbors_fusing_mode[fused]( g, sampled_node, sampled_amount, exclude_edges=excluded_edges ) + if fused: + + def contain_edge(g, sg, u, v): + # set of subgraph graph edges deduced from original graph + org_edges = set( + map( + tuple, + np.stack( + g.find_edges(sg.edges["_E"].data[dgl.EID]), axis=1 + ), + ) + ) + # set of excluded edges + excluded_edges = set(map(tuple, np.stack((u, v), axis=1))) - assert not np.any( - F.asnumpy(sg.has_edges_between(excluded_nodes_U, excluded_nodes_V)) - ) + diff_set = org_edges - excluded_edges + + return len(diff_set) != len(org_edges) + + assert not contain_edge(g, sg, excluded_nodes_U, excluded_nodes_V) + else: + assert not np.any( + F.asnumpy(sg.has_edges_between(excluded_nodes_U, excluded_nodes_V)) + ) @pytest.mark.parametrize("dtype", ["int32", "int64"])