diff --git a/benchmarks/benchmarks/api/bench_fused_sample_neighbors.py b/benchmarks/benchmarks/api/bench_fused_sample_neighbors.py new file mode 100644 index 000000000000..91fc6705d5eb --- /dev/null +++ b/benchmarks/benchmarks/api/bench_fused_sample_neighbors.py @@ -0,0 +1,39 @@ +import time + +import dgl +import dgl.function as fn + +import numpy as np +import torch + +from .. import utils + + +@utils.benchmark("time") +@utils.parametrize_cpu("graph_name", ["livejournal", "reddit"]) +@utils.parametrize_gpu("graph_name", ["ogbn-arxiv", "reddit"]) +@utils.parametrize("format", ["csr", "csc"]) +@utils.parametrize("seed_nodes_num", [200, 5000, 20000]) +@utils.parametrize("fanout", [5, 20, 40]) +def track_time(graph_name, format, seed_nodes_num, fanout): + device = utils.get_bench_device() + graph = utils.get_graph(graph_name, format).to(device) + + edge_dir = "in" if format == "csc" else "out" + seed_nodes = np.random.randint(0, graph.num_nodes(), seed_nodes_num) + seed_nodes = torch.from_numpy(seed_nodes).to(device) + + # dry run + for i in range(3): + dgl.sampling.sample_neighbors( + graph, seed_nodes, fanout, edge_dir=edge_dir, fused=True + ) + + # timing + with utils.Timer() as t: + for i in range(50): + dgl.sampling.sample_neighbors( + graph, seed_nodes, fanout, edge_dir=edge_dir, fused=True + ) + + return t.elapsed_secs / 50 diff --git a/include/dgl/aten/csr.h b/include/dgl/aten/csr.h index 110369c57daf..ebe8601444b8 100644 --- a/include/dgl/aten/csr.h +++ b/include/dgl/aten/csr.h @@ -570,6 +570,72 @@ COOMatrix CSRRowWiseSampling( CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask = NDArray(), bool replace = true); +/*! + * @brief Randomly select a fixed number of non-zero entries along each given + * row independently. + * + * The function performs random choices along each row independently. + * The picked indices are returned in the form of a CSR matrix, with + * additional IdArray that is an extended version of CSR's index pointers. + * + * With template parameter set to True rows are also saved as new seed nodes and + * mapped + * + * If replace is false and a row has fewer non-zero values than num_samples, + * all the values are picked. + * + * Examples: + * + * // csr.num_rows = 4; + * // csr.num_cols = 4; + * // csr.indptr = [0, 2, 3, 3, 5] + * // csr.indices = [0, 1, 1, 2, 3] + * // csr.data = [2, 3, 0, 1, 4] + * CSRMatrix csr = ...; + * IdArray rows = ... ; // [1, 3] + * IdArray seed_mapping = [-1, -1, -1, -1]; + * std::vector new_seed_nodes = {}; + * + * std::pair sampled = CSRRowWiseSamplingFused< + * typename IdType, True>( + * csr, rows, seed_mapping, + * new_seed_nodes, 2, + * FloatArray(), false); + * // possible sampled csr matrix: + * // sampled.first.num_rows = 2 + * // sampled.first.num_cols = 3 + * // sampled.first.indptr = [0, 1, 3] + * // sampled.first.indices = [1, 2, 3] + * // sampled.first.data = [0, 1, 4] + * // sampled.second = [0, 1, 1] + * // seed_mapping = [-1, 0, -1, 1]; + * // new_seed_nodes = {1, 3}; + * + * @tparam IdType Graph's index data type, can be int32_t or int64_t + * @tparam map_seed_nodes If set for true we map and copy rows to new_seed_nodes + * @param mat Input CSR matrix. + * @param rows Rows to sample from. + * @param seed_mapping Mapping array used if map_seed_nodes=true. If so each row + * from rows will be set to its position e.g. mapping[rows[i]] = i. + * @param new_seed_nodes Vector used if map_seed_nodes=true. If so it will + * contain rows. + * @param rows Rows to sample from. + * @param num_samples Number of samples + * @param prob_or_mask Unnormalized probability array or mask array. + * Should be of the same length as the data array. + * If an empty array is provided, assume uniform. + * @param replace True if sample with replacement + * @return A CSRMatrix storing the picked row, col and data indices, + * COO version of picked rows + * @note The edges of the entire graph must be ordered by their edge types, + * rows must be unique + */ +template +std::pair CSRRowWiseSamplingFused( + CSRMatrix mat, IdArray rows, IdArray seed_mapping, + std::vector& new_seed_nodes, int64_t num_samples, + NDArray prob_or_mask = NDArray(), bool replace = true); + /** * @brief Randomly select a fixed number of non-zero entries for each edge type * along each given row independently. diff --git a/include/dgl/sampling/neighbor.h b/include/dgl/sampling/neighbor.h index 7c17050777a2..ae7cf9e6d5d7 100644 --- a/include/dgl/sampling/neighbor.h +++ b/include/dgl/sampling/neighbor.h @@ -47,6 +47,55 @@ HeteroSubgraph SampleNeighbors( const std::vector& probability, const std::vector& exclude_edges, bool replace = true); +/** + * @brief Sample from the neighbors of the given nodes and convert a graph into + * a bipartite-structured graph for message passing. + * + * Specifically, we create one node type \c ntype_l on the "left" side and + * another node type \c ntype_r on the "right" side for each node type \c ntype. + * The nodes of type \c ntype_r would contain the nodes designated by the + * caller, and node type \c ntype_l would contain the nodes that has an edge + * connecting to one of the designated nodes. + * + * The nodes of \c ntype_l would also contain the nodes in node type \c ntype_r. + * When sampling with replacement, the sampled subgraph could have parallel + * edges. + * + * For sampling without replace, if fanout > the number of neighbors, all the + * neighbors will be sampled. + * + * Non-deterministic algorithm, requires nodes parameter to store unique Node + * IDs. + * + * @tparam IdType Graph's index data type, can be int32_t or int64_t + * @param hg The input graph. + * @param nodes Node IDs of each type. The vector length must be equal to the + * number of node types. Empty array is allowed. + * @param mapping External parameter that should be set to a vector of IdArrays + * filled with -1, required for mapping of nodes in returned + * graph + * @param fanouts Number of sampled neighbors for each edge type. The vector + * length should be equal to the number of edge types, or one if they all have + * the same fanout. + * @param dir Edge direction. + * @param probability A vector of 1D float arrays, indicating the transition + * probability of each edge by edge type. An empty float array assumes uniform + * transition. + * @param exclude_edges Edges IDs of each type which will be excluded during + * sampling. The vector length must be equal to the number of edges types. Empty + * array is allowed. + * @param replace If true, sample with replacement. + * @return Sampled neighborhoods as a graph. The return graph has the same + * schema as the original one. + */ +template +std::tuple, std::vector> +SampleNeighborsFused( + const HeteroGraphPtr hg, const std::vector& nodes, + std::vector& mapping, const std::vector& fanouts, + EdgeDir dir, const std::vector& prob_or_mask, + const std::vector& exclude_edges, bool replace = true); + /** * Select the neighbors with k-largest weights on the connecting edges for each * given node. diff --git a/python/dgl/dataloading/neighbor_sampler.py b/python/dgl/dataloading/neighbor_sampler.py index 573946eade9f..3243ee1d4dd6 100644 --- a/python/dgl/dataloading/neighbor_sampler.py +++ b/python/dgl/dataloading/neighbor_sampler.py @@ -1,4 +1,5 @@ """Data loading components for neighbor sampling""" +from .. import backend as F from ..base import EID, NID from ..transforms import to_block from .base import BlockSampler @@ -54,6 +55,9 @@ class NeighborSampler(BlockSampler): output_device : device, optional The device of the output subgraphs or MFGs. Default is the same as the minibatch of seed nodes. + fused : bool, default True + If True and device is CPU fused sample neighbors is invoked. This version + requires seed_nodes to be unique Examples -------- @@ -120,6 +124,7 @@ def __init__( prefetch_labels=None, prefetch_edge_feats=None, output_device=None, + fused=True, ): super().__init__( prefetch_node_feats=prefetch_node_feats, @@ -137,25 +142,47 @@ def __init__( ) self.prob = prob or mask self.replace = replace + self.fused = fused + self.mapping = {} + self.g = None def sample_blocks(self, g, seed_nodes, exclude_eids=None): output_nodes = seed_nodes blocks = [] - for fanout in reversed(self.fanouts): - frontier = g.sample_neighbors( - seed_nodes, - fanout, - edge_dir=self.edge_dir, - prob=self.prob, - replace=self.replace, - output_device=self.output_device, - exclude_edges=exclude_eids, - ) - eid = frontier.edata[EID] - block = to_block(frontier, seed_nodes) - block.edata[EID] = eid - seed_nodes = block.srcdata[NID] - blocks.insert(0, block) + if F.device_type(g.device) == "cpu" and self.fused: + if self.g != g: + self.mapping = {} + self.g = g + for fanout in reversed(self.fanouts): + block = g.sample_neighbors( + seed_nodes, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + fused=True, + exclude_edges=exclude_eids, + mapping=self.mapping, + ) + seed_nodes = block.srcdata[NID] + blocks.insert(0, block) + else: + for fanout in reversed(self.fanouts): + frontier = g.sample_neighbors( + seed_nodes, + fanout, + edge_dir=self.edge_dir, + prob=self.prob, + replace=self.replace, + output_device=self.output_device, + exclude_edges=exclude_eids, + ) + eid = frontier.edata[EID] + block = to_block(frontier, seed_nodes) + block.edata[EID] = eid + seed_nodes = block.srcdata[NID] + blocks.insert(0, block) return seed_nodes, output_nodes, blocks diff --git a/python/dgl/sampling/neighbor.py b/python/dgl/sampling/neighbor.py index 232fb3972744..d96372ca39f6 100644 --- a/python/dgl/sampling/neighbor.py +++ b/python/dgl/sampling/neighbor.py @@ -1,9 +1,13 @@ """Neighbor sampling APIs""" +import os + +import torch + from .. import backend as F, ndarray as nd, utils from .._ffi.function import _init_api -from ..base import DGLError, EID -from ..heterograph import DGLGraph +from ..base import DGLError, EID, NID +from ..heterograph import DGLBlock, DGLGraph from .utils import EidExcluder __all__ = [ @@ -214,6 +218,8 @@ def sample_neighbors( _dist_training=False, exclude_edges=None, output_device=None, + fused=False, + mapping=None, ): """Sample neighboring edges of the given nodes and return the induced subgraph. @@ -282,6 +288,18 @@ def sample_neighbors( output_device : Framework-specific device context object, optional The output device. Default is the same as the input graph. + fused : bool, optional + Enables faster version of NeighborSampler that is also compacting output graph, + returning a computational block. Requires nodes to be unique + + (Default: False) + + mapping : dictionary, optional + Used by fused version of sample_neighbors. To avoid constant data allocation + provide empty dictionary ({}) that will be allocated once with proper data and reused + by each function call + + (Default: None) Returns ------- DGLGraph @@ -361,6 +379,8 @@ def sample_neighbors( copy_ndata=copy_ndata, copy_edata=copy_edata, exclude_edges=exclude_edges, + fused=fused, + mapping=mapping, ) else: frontier = _sample_neighbors( @@ -372,6 +392,8 @@ def sample_neighbors( replace=replace, copy_ndata=copy_ndata, copy_edata=copy_edata, + fused=fused, + mapping=mapping, ) if exclude_edges is not None: eid_excluder = EidExcluder(exclude_edges) @@ -390,6 +412,8 @@ def _sample_neighbors( copy_edata=True, _dist_training=False, exclude_edges=None, + fused=False, + mapping=None, ): if not isinstance(nodes, dict): if len(g.ntypes) > 1: @@ -446,17 +470,54 @@ def _sample_neighbors( else: excluded_edges_all_t.append(nd.array([], ctx=ctx)) - subgidx = _CAPI_DGLSampleNeighbors( - g._graph, - nodes_all_types, - fanout_array, - edge_dir, - prob_arrays, - excluded_edges_all_t, - replace, - ) - induced_edges = subgidx.induced_edges - ret = DGLGraph(subgidx.graph, g.ntypes, g.etypes) + if fused: + if _dist_training: + raise DGLError( + "distributed training not supported in fused sampling" + ) + if F.device_type(g.device) != "cpu": + raise DGLError("Only cpu is supported in fused sampling") + + if mapping == None: + mapping = {} + mapping_name = "__mapping" + str(os.getpid()) + if mapping_name not in mapping.keys(): + mapping[mapping_name] = [ + torch.LongTensor(g.num_nodes(ntype)).fill_(-1) + for ntype in g.ntypes + ] + + subgidx, induced_nodes, induced_edges = _CAPI_DGLSampleNeighborsFused( + g._graph, + nodes_all_types, + [F.to_dgl_nd(m) for m in mapping[mapping_name]], + fanout_array, + edge_dir, + prob_arrays, + excluded_edges_all_t, + replace, + ) + for mapping_vector, src_nodes in zip( + mapping[mapping_name], induced_nodes + ): + mapping_vector[F.from_dgl_nd(src_nodes).type(torch.int64)] = -1 + + new_ntypes = (g.ntypes, g.ntypes) + ret = DGLBlock(subgidx, new_ntypes, g.etypes) + assert ret.is_unibipartite + + else: + subgidx = _CAPI_DGLSampleNeighbors( + g._graph, + nodes_all_types, + fanout_array, + edge_dir, + prob_arrays, + excluded_edges_all_t, + replace, + ) + ret = DGLGraph(subgidx.graph, g.ntypes, g.etypes) + induced_edges = subgidx.induced_edges # handle features # (TODO) (BarclayII) DGL distributed fails with bus error, freezes, or other @@ -465,12 +526,31 @@ def _sample_neighbors( # only set the edge IDs. if not _dist_training: if copy_ndata: - node_frames = utils.extract_node_subframes(g, device) - utils.set_new_frames(ret, node_frames=node_frames) + if fused: + src_node_ids = [F.from_dgl_nd(src) for src in induced_nodes] + dst_node_ids = [ + utils.toindex( + nodes.get(ntype, []), g._idtype_str + ).tousertensor(ctx=F.to_backend_ctx(g._graph.ctx)) + for ntype in g.ntypes + ] + node_frames = utils.extract_node_subframes_for_block( + g, src_node_ids, dst_node_ids + ) + utils.set_new_frames(ret, node_frames=node_frames) + else: + node_frames = utils.extract_node_subframes(g, device) + utils.set_new_frames(ret, node_frames=node_frames) if copy_edata: - edge_frames = utils.extract_edge_subframes(g, induced_edges) - utils.set_new_frames(ret, edge_frames=edge_frames) + if fused: + edge_ids = [F.from_dgl_nd(eid) for eid in induced_edges] + edge_frames = utils.extract_edge_subframes(g, edge_ids) + utils.set_new_frames(ret, edge_frames=edge_frames) + else: + edge_frames = utils.extract_edge_subframes(g, induced_edges) + utils.set_new_frames(ret, edge_frames=edge_frames) + else: for i, etype in enumerate(ret.canonical_etypes): ret.edges[etype].data[EID] = induced_edges[i] diff --git a/src/array/array.cc b/src/array/array.cc index 2bd2303c0376..b04a80fba2a5 100644 --- a/src/array/array.cc +++ b/src/array/array.cc @@ -595,6 +595,47 @@ COOMatrix CSRRowWiseSampling( return ret; } +template +std::pair CSRRowWiseSamplingFused( + CSRMatrix mat, IdArray rows, IdArray seed_mapping, + std::vector& new_seed_nodes, int64_t num_samples, + NDArray prob_or_mask, bool replace) { + std::pair ret; + if (IsNullArray(prob_or_mask)) { + ATEN_XPU_SWITCH( + rows->ctx.device_type, XPU, "CSRRowWiseSamplingUniformFused", { + ret = + impl::CSRRowWiseSamplingUniformFused( + mat, rows, seed_mapping, new_seed_nodes, num_samples, + replace); + }); + } else { + CHECK_VALID_CONTEXT(prob_or_mask, rows); + ATEN_XPU_SWITCH(rows->ctx.device_type, XPU, "CSRRowWiseSamplingFused", { + ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH( + prob_or_mask->dtype, FloatType, "probability or mask", { + ret = impl::CSRRowWiseSamplingFused< + XPU, IdType, FloatType, map_seed_nodes>( + mat, rows, seed_mapping, new_seed_nodes, num_samples, + prob_or_mask, replace); + }); + }); + } + return ret; +} + +template std::pair CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + +template std::pair CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + +template std::pair CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + +template std::pair CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + COOMatrix CSRRowWisePerEtypeSampling( CSRMatrix mat, IdArray rows, const std::vector& eid2etype_offset, const std::vector& num_samples, diff --git a/src/array/array_op.h b/src/array/array_op.h index ff31bdc18f7c..4d33d84fc40d 100644 --- a/src/array/array_op.h +++ b/src/array/array_op.h @@ -182,6 +182,14 @@ COOMatrix CSRRowWiseSampling( CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask, bool replace); +// FloatType is the type of probability data. +template < + DGLDeviceType XPU, typename IdxType, typename DType, bool map_seed_nodes> +std::pair CSRRowWiseSamplingFused( + CSRMatrix mat, IdArray rows, IdArray seed_mapping, + std::vector& new_seed_nodes, int64_t num_samples, + NDArray prob_or_mask, bool replace); + // FloatType is the type of probability data. template COOMatrix CSRRowWisePerEtypeSampling( @@ -194,6 +202,11 @@ template COOMatrix CSRRowWiseSamplingUniform( CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace); +template +std::pair CSRRowWiseSamplingUniformFused( + CSRMatrix mat, IdArray rows, IdArray seed_mapping, + std::vector& new_seed_nodes, int64_t num_samples, bool replace); + template COOMatrix CSRRowWisePerEtypeSamplingUniform( CSRMatrix mat, IdArray rows, const std::vector& eid2etype_offset, diff --git a/src/array/cpu/rowwise_pick.h b/src/array/cpu/rowwise_pick.h index cb0b32298157..d4872d2d0397 100644 --- a/src/array/cpu/rowwise_pick.h +++ b/src/array/cpu/rowwise_pick.h @@ -94,6 +94,115 @@ using EtypeRangePickFn = std::function& et_idx, const std::vector& et_eid, const IdxType* eid, IdxType* out_idx)>; +template +std::pair CSRRowWisePickFused( + CSRMatrix mat, IdArray rows, IdArray seed_mapping, + std::vector& new_seed_nodes, int64_t num_picks, bool replace, + PickFn pick_fn, NumPicksFn num_picks_fn) { + using namespace aten; + + const IdxType* indptr = static_cast(mat.indptr->data); + const IdxType* indices = static_cast(mat.indices->data); + const IdxType* data = + CSRHasData(mat) ? static_cast(mat.data->data) : nullptr; + const IdxType* rows_data = static_cast(rows->data); + const int64_t num_rows = rows->shape[0]; + const auto& ctx = mat.indptr->ctx; + const auto& idtype = mat.indptr->dtype; + IdxType* seed_mapping_data = nullptr; + if (map_seed_nodes) seed_mapping_data = seed_mapping.Ptr(); + + const int num_threads = runtime::compute_num_threads(0, num_rows, 1); + std::vector global_prefix(num_threads + 1, 0); + + IdArray picked_col, picked_idx, picked_coo_rows; + + IdArray block_csr_indptr = IdArray::Empty({num_rows + 1}, idtype, ctx); + IdxType* block_csr_indptr_data = block_csr_indptr.Ptr(); + +#pragma omp parallel num_threads(num_threads) + { + const int thread_id = omp_get_thread_num(); + + const int64_t start_i = + thread_id * (num_rows / num_threads) + + std::min(static_cast(thread_id), num_rows % num_threads); + const int64_t end_i = + (thread_id + 1) * (num_rows / num_threads) + + std::min(static_cast(thread_id + 1), num_rows % num_threads); + assert(thread_id + 1 < num_threads || end_i == num_rows); + + const int64_t num_local = end_i - start_i; + + std::unique_ptr local_prefix(new int64_t[num_local + 1]); + local_prefix[0] = 0; + for (int64_t i = start_i; i < end_i; ++i) { + // build prefix-sum + const int64_t local_i = i - start_i; + const IdxType rid = rows_data[i]; + if (map_seed_nodes) seed_mapping_data[rid] = i; + + IdxType len = num_picks_fn( + rid, indptr[rid], indptr[rid + 1] - indptr[rid], indices, data); + local_prefix[local_i + 1] = local_prefix[local_i] + len; + } + global_prefix[thread_id + 1] = local_prefix[num_local]; + +#pragma omp barrier +#pragma omp master + { + for (int t = 0; t < num_threads; ++t) { + global_prefix[t + 1] += global_prefix[t]; + } + picked_col = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx); + picked_idx = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx); + picked_coo_rows = + IdArray::Empty({global_prefix[num_threads]}, idtype, ctx); + } + +#pragma omp barrier + IdxType* picked_cdata = picked_col.Ptr(); + IdxType* picked_idata = picked_idx.Ptr(); + IdxType* picked_rows = picked_coo_rows.Ptr(); + + const IdxType thread_offset = global_prefix[thread_id]; + + for (int64_t i = start_i; i < end_i; ++i) { + const IdxType rid = rows_data[i]; + const int64_t local_i = i - start_i; + block_csr_indptr_data[i] = local_prefix[local_i] + thread_offset; + + const IdxType off = indptr[rid]; + const IdxType len = indptr[rid + 1] - off; + if (len == 0) continue; + + const int64_t row_offset = local_prefix[local_i] + thread_offset; + const int64_t num_picks = + local_prefix[local_i + 1] + thread_offset - row_offset; + + pick_fn( + rid, off, len, num_picks, indices, data, picked_idata + row_offset); + for (int64_t j = 0; j < num_picks; ++j) { + const IdxType picked = picked_idata[row_offset + j]; + picked_cdata[row_offset + j] = indices[picked]; + picked_idata[row_offset + j] = data ? data[picked] : picked; + picked_rows[row_offset + j] = i; + } + } + } + block_csr_indptr_data[num_rows] = global_prefix.back(); + + const IdxType num_cols = picked_col->shape[0]; + if (map_seed_nodes) { + new_seed_nodes.resize(num_rows); + memcpy(new_seed_nodes.data(), rows_data, sizeof(IdxType) * num_rows); + } + + return std::make_pair( + CSRMatrix(num_rows, num_cols, block_csr_indptr, picked_col, picked_idx), + picked_coo_rows); +} + // Template for picking non-zero values row-wise. The implementation utilizes // OpenMP parallelization on rows because each row performs computation // independently. diff --git a/src/array/cpu/rowwise_sampling.cc b/src/array/cpu/rowwise_sampling.cc index 56bb9a40bb4c..911f69b42371 100644 --- a/src/array/cpu/rowwise_sampling.cc +++ b/src/array/cpu/rowwise_sampling.cc @@ -225,6 +225,74 @@ template COOMatrix CSRRowWiseSampling( template COOMatrix CSRRowWiseSampling( CSRMatrix, IdArray, int64_t, NDArray, bool); +template < + DGLDeviceType XPU, typename IdxType, typename DType, bool map_seed_nodes> +std::pair CSRRowWiseSamplingFused( + CSRMatrix mat, IdArray rows, IdArray seed_mapping, + std::vector& new_seed_nodes, int64_t num_samples, + NDArray prob_or_mask, bool replace) { + // If num_samples is -1, select all neighbors without replacement. + replace = (replace && num_samples != -1); + CHECK(prob_or_mask.defined()); + auto num_picks_fn = + GetSamplingNumPicksFn(num_samples, prob_or_mask, replace); + auto pick_fn = + GetSamplingPickFn(num_samples, prob_or_mask, replace); + return CSRRowWisePickFused( + mat, rows, seed_mapping, new_seed_nodes, num_samples, replace, pick_fn, + num_picks_fn); +} + +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); +template std::pair +CSRRowWiseSamplingFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, NDArray, bool); + template COOMatrix CSRRowWisePerEtypeSampling( CSRMatrix mat, IdArray rows, const std::vector& eid2etype_offset, @@ -283,6 +351,33 @@ template COOMatrix CSRRowWiseSamplingUniform( template COOMatrix CSRRowWiseSamplingUniform( CSRMatrix, IdArray, int64_t, bool); +template +std::pair CSRRowWiseSamplingUniformFused( + CSRMatrix mat, IdArray rows, IdArray seed_mapping, + std::vector& new_seed_nodes, int64_t num_samples, bool replace) { + // If num_samples is -1, select all neighbors without replacement. + replace = (replace && num_samples != -1); + auto num_picks_fn = + GetSamplingUniformNumPicksFn(num_samples, replace); + auto pick_fn = GetSamplingUniformPickFn(num_samples, replace); + return CSRRowWisePickFused( + mat, rows, seed_mapping, new_seed_nodes, num_samples, replace, pick_fn, + num_picks_fn); +} + +template std::pair +CSRRowWiseSamplingUniformFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, bool); +template std::pair +CSRRowWiseSamplingUniformFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, bool); +template std::pair +CSRRowWiseSamplingUniformFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, bool); +template std::pair +CSRRowWiseSamplingUniformFused( + CSRMatrix, IdArray, IdArray, std::vector&, int64_t, bool); + template COOMatrix CSRRowWisePerEtypeSamplingUniform( CSRMatrix mat, IdArray rows, const std::vector& eid2etype_offset, diff --git a/src/graph/sampling/neighbor/neighbor.cc b/src/graph/sampling/neighbor/neighbor.cc index 5b31c14c8a3c..d364e25633b1 100644 --- a/src/graph/sampling/neighbor/neighbor.cc +++ b/src/graph/sampling/neighbor/neighbor.cc @@ -6,8 +6,10 @@ #include #include +#include #include #include +#include #include #include @@ -22,6 +24,76 @@ using namespace dgl::aten; namespace dgl { namespace sampling { +template +void ExcludeCertainEdgesFused( + std::vector& sampled_graphs, std::vector& induced_edges, + std::vector& sampled_coo_rows, + const std::vector& exclude_edges, + std::vector* weights = nullptr) { + int etypes = sampled_graphs.size(); + std::vector remain_induced_edges(etypes); + std::vector remain_indptrs(etypes); + std::vector remain_indices(etypes); + std::vector remain_coo_rows(etypes); + std::vector remain_weights(etypes); + for (int etype = 0; etype < etypes; ++etype) { + if (exclude_edges[etype].GetSize() == 0 || + sampled_graphs[etype].num_rows == 0) { + remain_induced_edges[etype] = induced_edges[etype]; + if (weights) remain_weights[etype] = (*weights)[etype]; + continue; + } + const auto dtype = weights && (*weights)[etype]->shape[0] + ? (*weights)[etype]->dtype + : DGLDataType{kDGLFloat, 8 * sizeof(float), 1}; + ATEN_FLOAT_TYPE_SWITCH(dtype, FloatType, "weights", { + IdType* indptr = sampled_graphs[etype].indptr.Ptr(); + IdType* indices = sampled_graphs[etype].indices.Ptr(); + IdType* coo_rows = sampled_coo_rows[etype].Ptr(); + IdType* induced_edges_data = induced_edges[etype].Ptr(); + FloatType* weights_data = weights && (*weights)[etype]->shape[0] + ? (*weights)[etype].Ptr() + : nullptr; + const IdType exclude_edges_len = exclude_edges[etype]->shape[0]; + std::sort( + exclude_edges[etype].Ptr(), + exclude_edges[etype].Ptr() + exclude_edges_len); + const IdType* exclude_edges_data = exclude_edges[etype].Ptr(); + IdType outIndices = 0; + for (IdType row = 0; row < sampled_graphs[etype].indptr->shape[0] - 1; + ++row) { + auto tmp_row = indptr[row]; + if (outIndices != indptr[row]) indptr[row] = outIndices; + for (IdType col = tmp_row; col < indptr[row + 1]; ++col) { + if (!std::binary_search( + exclude_edges_data, exclude_edges_data + exclude_edges_len, + induced_edges_data[col])) { + indices[outIndices] = indices[col]; + induced_edges_data[outIndices] = induced_edges_data[col]; + coo_rows[outIndices] = coo_rows[col]; + if (weights_data) weights_data[outIndices] = weights_data[col]; + ++outIndices; + } + } + } + indptr[sampled_graphs[etype].indptr->shape[0] - 1] = outIndices; + remain_induced_edges[etype] = + aten::IndexSelect(induced_edges[etype], 0, outIndices); + remain_weights[etype] = + weights_data ? aten::IndexSelect((*weights)[etype], 0, outIndices) + : NullArray(); + remain_indices[etype] = + aten::IndexSelect(sampled_graphs[etype].indices, 0, outIndices); + sampled_coo_rows[etype] = + aten::IndexSelect(sampled_coo_rows[etype], 0, outIndices); + sampled_graphs[etype] = CSRMatrix( + sampled_graphs[etype].num_rows, outIndices, + sampled_graphs[etype].indptr, remain_indices[etype], + remain_induced_edges[etype]); + }); + } +} + std::pair> ExcludeCertainEdges( const HeteroSubgraph& sg, const std::vector& exclude_edges, const std::vector* weights = nullptr) { @@ -267,6 +339,229 @@ HeteroSubgraph SampleNeighbors( return ret; } +template +std::tuple, std::vector> +SampleNeighborsFused( + const HeteroGraphPtr hg, const std::vector& nodes, + std::vector& mapping, const std::vector& fanouts, + EdgeDir dir, const std::vector& prob_or_mask, + const std::vector& exclude_edges, bool replace) { + CHECK_EQ(nodes.size(), hg->NumVertexTypes()) + << "Number of node ID tensors must match the number of node types."; + CHECK_EQ(fanouts.size(), hg->NumEdgeTypes()) + << "Number of fanout values must match the number of edge types."; + CHECK_EQ(prob_or_mask.size(), hg->NumEdgeTypes()) + << "Number of probability tensors must match the number of edge types."; + + DGLContext ctx = aten::GetContextOf(nodes); + + std::vector sampled_graphs; + std::vector sampled_coo_rows; + std::vector induced_edges; + std::vector induced_vertices; + std::vector num_nodes_per_type; + std::vector> new_nodes_vec(hg->NumVertexTypes()); + std::vector seed_nodes_mapped(hg->NumVertexTypes(), 0); + + for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) { + auto pair = hg->meta_graph()->FindEdge(etype); + const dgl_type_t src_vtype = pair.first; + const dgl_type_t dst_vtype = pair.second; + const dgl_type_t rhs_node_type = + (dir == EdgeDir::kOut) ? src_vtype : dst_vtype; + const IdArray nodes_ntype = nodes[rhs_node_type]; + const int64_t num_nodes = nodes_ntype->shape[0]; + + if (num_nodes == 0 || fanouts[etype] == 0) { + // Nothing to sample for this etype, create a placeholder + sampled_graphs.push_back(CSRMatrix()); + sampled_coo_rows.push_back(IdArray()); + induced_edges.push_back(aten::NullArray(hg->DataType(), ctx)); + } else { + bool map_seed_nodes = !seed_nodes_mapped[rhs_node_type]; + // sample from one relation graph + std::pair sampled_graph; + auto sampling_fn = map_seed_nodes + ? aten::CSRRowWiseSamplingFused + : aten::CSRRowWiseSamplingFused; + auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE; + auto avail_fmt = hg->SelectFormat(etype, req_fmt); + switch (avail_fmt) { + case SparseFormat::kCSR: + CHECK(dir == EdgeDir::kOut) + << "Cannot sample out edges on CSC matrix."; + // In heterographs nodes of two diffrent types can be connected + // therefore two diffrent mappings and node vectors are needed + sampled_graph = sampling_fn( + hg->GetCSRMatrix(etype), nodes_ntype, mapping[src_vtype], + new_nodes_vec[src_vtype], fanouts[etype], prob_or_mask[etype], + replace); + break; + case SparseFormat::kCSC: + CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix."; + sampled_graph = sampling_fn( + hg->GetCSCMatrix(etype), nodes_ntype, mapping[dst_vtype], + new_nodes_vec[dst_vtype], fanouts[etype], prob_or_mask[etype], + replace); + break; + default: + LOG(FATAL) << "Unsupported sparse format."; + } + seed_nodes_mapped[rhs_node_type]++; + sampled_graphs.push_back(sampled_graph.first); + if (sampled_graph.first.data.defined()) + induced_edges.push_back(sampled_graph.first.data); + else + induced_edges.push_back( + aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx)); + sampled_coo_rows.push_back(sampled_graph.second); + } + } + + if (!exclude_edges.empty()) { + ExcludeCertainEdgesFused( + sampled_graphs, induced_edges, sampled_coo_rows, exclude_edges); + for (size_t i = 0; i < hg->NumEdgeTypes(); i++) { + if (sampled_graphs[i].data.defined()) + induced_edges[i] = std::move(sampled_graphs[i].data); + else + induced_edges[i] = + aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx); + } + } + + // map indices + for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) { + auto pair = hg->meta_graph()->FindEdge(etype); + const dgl_type_t src_vtype = pair.first; + const dgl_type_t dst_vtype = pair.second; + const dgl_type_t lhs_node_type = + (dir == EdgeDir::kIn) ? src_vtype : dst_vtype; + if (sampled_graphs[etype].num_cols != 0) { + auto num_cols = sampled_graphs[etype].num_cols; + const int num_threads_col = runtime::compute_num_threads(0, num_cols, 1); + std::vector global_prefix_col(num_threads_col + 1, 0); + std::vector> src_nodes_local(num_threads_col); + IdType* mapping_data_dst = mapping[lhs_node_type].Ptr(); + IdType* cdata = sampled_graphs[etype].indices.Ptr(); +#pragma omp parallel num_threads(num_threads_col) + { + const int thread_id = omp_get_thread_num(); + + const int64_t start_i = + thread_id * (num_cols / num_threads_col) + + std::min( + static_cast(thread_id), num_cols % num_threads_col); + const int64_t end_i = (thread_id + 1) * (num_cols / num_threads_col) + + std::min( + static_cast(thread_id + 1), + num_cols % num_threads_col); + assert(thread_id + 1 < num_threads_col || end_i == num_cols); + for (int64_t i = start_i; i < end_i; ++i) { + int64_t picked_idx = cdata[i]; + bool spot_claimed = __sync_bool_compare_and_swap( + &mapping_data_dst[picked_idx], -1, 0); + if (spot_claimed) src_nodes_local[thread_id].push_back(picked_idx); + } + global_prefix_col[thread_id + 1] = src_nodes_local[thread_id].size(); + +#pragma omp barrier +#pragma omp master + { + global_prefix_col[0] = new_nodes_vec[lhs_node_type].size(); + for (int t = 0; t < num_threads_col; ++t) { + global_prefix_col[t + 1] += global_prefix_col[t]; + } + } + +#pragma omp barrier + int64_t mapping_shift = global_prefix_col[thread_id]; + for (size_t i = 0; i < src_nodes_local[thread_id].size(); ++i) + mapping_data_dst[src_nodes_local[thread_id][i]] = mapping_shift + i; + +#pragma omp barrier + for (int64_t i = start_i; i < end_i; ++i) { + IdType picked_idx = cdata[i]; + IdType mapped_idx = mapping_data_dst[picked_idx]; + cdata[i] = mapped_idx; + } + } + IdType offset = new_nodes_vec[lhs_node_type].size(); + new_nodes_vec[lhs_node_type].resize(global_prefix_col.back()); + for (int thread_id = 0; thread_id < num_threads_col; ++thread_id) { + memcpy( + new_nodes_vec[lhs_node_type].data() + offset, + &src_nodes_local[thread_id][0], + src_nodes_local[thread_id].size() * sizeof(IdType)); + offset += src_nodes_local[thread_id].size(); + } + } + } + + // counting how many nodes of each ntype were sampled + num_nodes_per_type.resize(2 * hg->NumVertexTypes()); + for (size_t i = 0; i < hg->NumVertexTypes(); i++) { + num_nodes_per_type[i] = new_nodes_vec[i].size(); + num_nodes_per_type[hg->NumVertexTypes() + i] = nodes[i]->shape[0]; + induced_vertices.push_back( + VecToIdArray(new_nodes_vec[i], sizeof(IdType) * 8)); + } + + std::vector subrels(hg->NumEdgeTypes()); + for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) { + auto pair = hg->meta_graph()->FindEdge(etype); + const dgl_type_t src_vtype = pair.first; + const dgl_type_t dst_vtype = pair.second; + if (sampled_graphs[etype].num_rows == 0) { + subrels[etype] = UnitGraph::Empty( + 2, new_nodes_vec[src_vtype].size(), nodes[dst_vtype]->shape[0], + hg->DataType(), ctx); + } else { + CSRMatrix graph = sampled_graphs[etype]; + if (dir == EdgeDir::kOut) { + subrels[etype] = UnitGraph::CreateFromCSRAndCOO( + 2, + CSRMatrix( + nodes[src_vtype]->shape[0], new_nodes_vec[dst_vtype].size(), + graph.indptr, graph.indices, + Range( + 0, graph.indices->shape[0], graph.indices->dtype.bits, + ctx)), + COOMatrix( + nodes[src_vtype]->shape[0], new_nodes_vec[dst_vtype].size(), + sampled_coo_rows[etype], graph.indices), + ALL_CODE); + } else { + subrels[etype] = UnitGraph::CreateFromCSCAndCOO( + 2, + CSRMatrix( + nodes[dst_vtype]->shape[0], new_nodes_vec[src_vtype].size(), + graph.indptr, graph.indices, + Range( + 0, graph.indices->shape[0], graph.indices->dtype.bits, + ctx)), + COOMatrix( + new_nodes_vec[src_vtype].size(), nodes[dst_vtype]->shape[0], + graph.indices, sampled_coo_rows[etype]), + ALL_CODE); + } + } + } + + HeteroSubgraph ret; + + const auto meta_graph = hg->meta_graph(); + const EdgeArray etypes = meta_graph->Edges("eid"); + const IdArray new_dst = Add(etypes.dst, hg->NumVertexTypes()); + + const auto new_meta_graph = ImmutableGraph::CreateFromCOO( + hg->NumVertexTypes() * 2, etypes.src, new_dst); + + HeteroGraphPtr new_graph = + CreateHeteroGraph(new_meta_graph, subrels, num_nodes_per_type); + return std::make_tuple(new_graph, induced_edges, induced_vertices); +} + HeteroSubgraph SampleNeighborsEType( const HeteroGraphPtr hg, const IdArray nodes, const std::vector& eid2etype_offset, @@ -568,6 +863,47 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighbors") *rv = HeteroSubgraphRef(subg); }); +DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsFused") + .set_body([](DGLArgs args, DGLRetValue* rv) { + HeteroGraphRef hg = args[0]; + const auto& nodes = ListValueToVector(args[1]); + auto mapping = ListValueToVector(args[2]); + IdArray fanouts_array = args[3]; + const auto& fanouts = fanouts_array.ToVector(); + const std::string dir_str = args[4]; + const auto& prob_or_mask = ListValueToVector(args[5]); + const auto& exclude_edges = ListValueToVector(args[6]); + const bool replace = args[7]; + + CHECK(dir_str == "in" || dir_str == "out") + << "Invalid edge direction. Must be \"in\" or \"out\"."; + EdgeDir dir = (dir_str == "in") ? EdgeDir::kIn : EdgeDir::kOut; + + HeteroGraphPtr new_graph; + std::vector induced_edges; + std::vector induced_vertices; + + ATEN_ID_TYPE_SWITCH(hg->DataType(), IdType, { + std::tie(new_graph, induced_edges, induced_vertices) = + SampleNeighborsFused( + hg.sptr(), nodes, mapping, fanouts, dir, prob_or_mask, + exclude_edges, replace); + }); + + List lhs_nodes_ref; + for (IdArray& array : induced_vertices) + lhs_nodes_ref.push_back(Value(MakeValue(array))); + List induced_edges_ref; + for (IdArray& array : induced_edges) + induced_edges_ref.push_back(Value(MakeValue(array))); + List ret; + ret.push_back(HeteroGraphRef(new_graph)); + ret.push_back(lhs_nodes_ref); + ret.push_back(induced_edges_ref); + + *rv = ret; + }); + DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsTopk") .set_body([](DGLArgs args, DGLRetValue* rv) { HeteroGraphRef hg = args[0]; diff --git a/src/graph/unit_graph.cc b/src/graph/unit_graph.cc index 3c26cf92f82e..b4e29de3a39a 100644 --- a/src/graph/unit_graph.cc +++ b/src/graph/unit_graph.cc @@ -1218,6 +1218,21 @@ HeteroGraphPtr UnitGraph::CreateFromCSR( return HeteroGraphPtr(new UnitGraph(mg, nullptr, csr, nullptr, formats)); } +HeteroGraphPtr UnitGraph::CreateFromCSRAndCOO( + int64_t num_vtypes, const aten::CSRMatrix& csr, const aten::COOMatrix& coo, + dgl_format_code_t formats) { + CHECK(num_vtypes == 1 || num_vtypes == 2); + CHECK_EQ(coo.num_rows, csr.num_rows); + CHECK_EQ(coo.num_cols, csr.num_cols); + if (num_vtypes == 1) { + CHECK_EQ(csr.num_rows, csr.num_cols); + } + auto mg = CreateUnitGraphMetaGraph(num_vtypes); + CSRPtr csrPtr(new CSR(mg, csr)); + COOPtr cooPtr(new COO(mg, coo)); + return HeteroGraphPtr(new UnitGraph(mg, nullptr, csrPtr, cooPtr, formats)); +} + HeteroGraphPtr UnitGraph::CreateFromCSC( int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr, IdArray indices, IdArray edge_ids, dgl_format_code_t formats) { @@ -1237,6 +1252,21 @@ HeteroGraphPtr UnitGraph::CreateFromCSC( return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, formats)); } +HeteroGraphPtr UnitGraph::CreateFromCSCAndCOO( + int64_t num_vtypes, const aten::CSRMatrix& csc, const aten::COOMatrix& coo, + dgl_format_code_t formats) { + CHECK(num_vtypes == 1 || num_vtypes == 2); + CHECK_EQ(coo.num_rows, csc.num_cols); + CHECK_EQ(coo.num_cols, csc.num_rows); + if (num_vtypes == 1) { + CHECK_EQ(csc.num_rows, csc.num_cols); + } + auto mg = CreateUnitGraphMetaGraph(num_vtypes); + CSRPtr cscPtr(new CSR(mg, csc)); + COOPtr cooPtr(new COO(mg, coo)); + return HeteroGraphPtr(new UnitGraph(mg, cscPtr, nullptr, cooPtr, formats)); +} + HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) { if (g->NumBits() == bits) { return g; diff --git a/src/graph/unit_graph.h b/src/graph/unit_graph.h index 5951c5001f5c..244f922e906d 100644 --- a/src/graph/unit_graph.h +++ b/src/graph/unit_graph.h @@ -190,6 +190,12 @@ class UnitGraph : public BaseHeteroGraph { int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats = ALL_CODE); + /** @brief Create a graph from (out) CSR and COO arrays, both representing the + * same graph */ + static HeteroGraphPtr CreateFromCSRAndCOO( + int64_t num_vtypes, const aten::CSRMatrix& csr, + const aten::COOMatrix& coo, dgl_format_code_t formats = ALL_CODE); + /** @brief Create a graph from (in) CSC arrays */ static HeteroGraphPtr CreateFromCSC( int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr, @@ -199,6 +205,12 @@ class UnitGraph : public BaseHeteroGraph { int64_t num_vtypes, const aten::CSRMatrix& mat, dgl_format_code_t formats = ALL_CODE); + /** @brief Create a graph from (in) CSC and COO arrays, both representing the + * same graph */ + static HeteroGraphPtr CreateFromCSCAndCOO( + int64_t num_vtypes, const aten::CSRMatrix& csc, + const aten::COOMatrix& coo, dgl_format_code_t formats = ALL_CODE); + /** @brief Convert the graph to use the given number of bits for storage */ static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits); diff --git a/tests/python/common/sampling/test_sampling.py b/tests/python/common/sampling/test_sampling.py index 4e2195a502fe..6d33fb9c0ed3 100644 --- a/tests/python/common/sampling/test_sampling.py +++ b/tests/python/common/sampling/test_sampling.py @@ -555,15 +555,18 @@ def _gen_neighbor_topk_test_graph(hypersparse, reverse): return g, hg -def _test_sample_neighbors(hypersparse, prob): +def _test_sample_neighbors(hypersparse, prob, fused): g, hg = _gen_neighbor_sampling_test_graph(hypersparse, False) def _test1(p, replace): subg = dgl.sampling.sample_neighbors( - g, [0, 1], -1, prob=p, replace=replace + g, [0, 1], -1, prob=p, replace=replace, fused=fused ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() u, v = subg.edges() + if fused: + u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v] u_ans, v_ans, e_ans = g.in_edges([0, 1], form="all") if p is not None: emask = F.gather_row(g.edata[p], e_ans) @@ -577,11 +580,16 @@ def _test1(p, replace): for i in range(10): subg = dgl.sampling.sample_neighbors( - g, [0, 1], 2, prob=p, replace=replace + g, [0, 1], 2, prob=p, replace=replace, fused=fused ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() + assert subg.num_edges() == 4 u, v = subg.edges() + if fused: + u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v] + assert set(F.asnumpy(F.unique(v))) == {0, 1} assert F.array_equal( F.astype(g.has_edges_between(u, v), F.int64), @@ -601,10 +609,13 @@ def _test1(p, replace): def _test2(p, replace): # fanout > #neighbors subg = dgl.sampling.sample_neighbors( - g, [0, 2], -1, prob=p, replace=replace + g, [0, 2], -1, prob=p, replace=replace, fused=fused ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() u, v = subg.edges() + if fused: + u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v] u_ans, v_ans, e_ans = g.in_edges([0, 2], form="all") if p is not None: emask = F.gather_row(g.edata[p], e_ans) @@ -618,12 +629,15 @@ def _test2(p, replace): # fanout > #neighbors for i in range(10): subg = dgl.sampling.sample_neighbors( - g, [0, 2], 2, prob=p, replace=replace + g, [0, 2], 2, prob=p, replace=replace, fused=fused ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() num_edges = 4 if replace else 3 assert subg.num_edges() == num_edges u, v = subg.edges() + if fused: + u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v] assert set(F.asnumpy(F.unique(v))) == {0, 2} assert F.array_equal( F.astype(g.has_edges_between(u, v), F.int64), @@ -642,9 +656,17 @@ def _test2(p, replace): # fanout > #neighbors def _test3(p, replace): subg = dgl.sampling.sample_neighbors( - hg, {"user": [0, 1], "game": 0}, -1, prob=p, replace=replace + hg, + {"user": [0, 1], "game": 0}, + -1, + prob=p, + replace=replace, + fused=fused, ) - assert len(subg.ntypes) == 3 + if not fused: + assert len(subg.ntypes) == 3 + assert len(subg.srctypes) == 3 + assert len(subg.dsttypes) == 3 assert len(subg.etypes) == 4 assert subg["follow"].num_edges() == 6 if p is None else 4 assert subg["play"].num_edges() == 1 @@ -653,9 +675,17 @@ def _test3(p, replace): for i in range(10): subg = dgl.sampling.sample_neighbors( - hg, {"user": [0, 1], "game": 0}, 2, prob=p, replace=replace + hg, + {"user": [0, 1], "game": 0}, + 2, + prob=p, + replace=replace, + fused=fused, ) - assert len(subg.ntypes) == 3 + if not fused: + assert len(subg.ntypes) == 3 + assert len(subg.srctypes) == 3 + assert len(subg.dsttypes) == 3 assert len(subg.etypes) == 4 assert subg["follow"].num_edges() == 4 assert subg["play"].num_edges() == 2 if replace else 1 @@ -672,8 +702,12 @@ def _test3(p, replace): {"user": [0, 1], "game": 0, "coin": 0}, {"follow": 1, "play": 2, "liked-by": 0, "flips": -1}, replace=True, + fused=fused, ) - assert len(subg.ntypes) == 3 + if not fused: + assert len(subg.ntypes) == 3 + assert len(subg.srctypes) == 3 + assert len(subg.dsttypes) == 3 assert len(subg.etypes) == 4 assert subg["follow"].num_edges() == 2 assert subg["play"].num_edges() == 2 @@ -681,15 +715,19 @@ def _test3(p, replace): assert subg["flips"].num_edges() == 4 -def _test_sample_neighbors_outedge(hypersparse): +def _test_sample_neighbors_outedge(hypersparse, fused): g, hg = _gen_neighbor_sampling_test_graph(hypersparse, True) def _test1(p, replace): subg = dgl.sampling.sample_neighbors( - g, [0, 1], -1, prob=p, replace=replace, edge_dir="out" + g, [0, 1], -1, prob=p, replace=replace, edge_dir="out", fused=fused ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() + u, v = subg.edges() + if fused: + u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v] u_ans, v_ans, e_ans = g.out_edges([0, 1], form="all") if p is not None: emask = F.gather_row(g.edata[p], e_ans) @@ -703,11 +741,20 @@ def _test1(p, replace): for i in range(10): subg = dgl.sampling.sample_neighbors( - g, [0, 1], 2, prob=p, replace=replace, edge_dir="out" + g, + [0, 1], + 2, + prob=p, + replace=replace, + edge_dir="out", + fused=fused, ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() assert subg.num_edges() == 4 u, v = subg.edges() + if fused: + u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v] assert set(F.asnumpy(F.unique(u))) == {0, 1} assert F.array_equal( F.astype(g.has_edges_between(u, v), F.int64), @@ -729,10 +776,13 @@ def _test1(p, replace): def _test2(p, replace): # fanout > #neighbors subg = dgl.sampling.sample_neighbors( - g, [0, 2], -1, prob=p, replace=replace, edge_dir="out" + g, [0, 2], -1, prob=p, replace=replace, edge_dir="out", fused=fused ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() u, v = subg.edges() + if fused: + u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v] u_ans, v_ans, e_ans = g.out_edges([0, 2], form="all") if p is not None: emask = F.gather_row(g.edata[p], e_ans) @@ -746,12 +796,22 @@ def _test2(p, replace): # fanout > #neighbors for i in range(10): subg = dgl.sampling.sample_neighbors( - g, [0, 2], 2, prob=p, replace=replace, edge_dir="out" + g, + [0, 2], + 2, + prob=p, + replace=replace, + edge_dir="out", + fused=fused, ) - assert subg.num_nodes() == g.num_nodes() + if not fused: + assert subg.num_nodes() == g.num_nodes() num_edges = 4 if replace else 3 assert subg.num_edges() == num_edges u, v = subg.edges() + if fused: + u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v] + assert set(F.asnumpy(F.unique(u))) == {0, 2} assert F.array_equal( F.astype(g.has_edges_between(u, v), F.int64), @@ -778,8 +838,13 @@ def _test3(p, replace): prob=p, replace=replace, edge_dir="out", + fused=fused, ) - assert len(subg.ntypes) == 3 + + if not fused: + assert len(subg.ntypes) == 3 + assert len(subg.srctypes) == 3 + assert len(subg.dsttypes) == 3 assert len(subg.etypes) == 4 assert subg["follow"].num_edges() == 6 if p is None else 4 assert subg["play"].num_edges() == 1 @@ -794,8 +859,12 @@ def _test3(p, replace): prob=p, replace=replace, edge_dir="out", + fused=fused, ) - assert len(subg.ntypes) == 3 + if not fused: + assert len(subg.ntypes) == 3 + assert len(subg.srctypes) == 3 + assert len(subg.dsttypes) == 3 assert len(subg.etypes) == 4 assert subg["follow"].num_edges() == 4 assert subg["play"].num_edges() == 2 if replace else 1 @@ -963,17 +1032,23 @@ def _test3(): def test_sample_neighbors_noprob(): - _test_sample_neighbors(False, None) + _test_sample_neighbors(False, None, False) + if F._default_context_str != "gpu": + _test_sample_neighbors(False, None, True) # _test_sample_neighbors(True) def test_sample_neighbors_prob(): - _test_sample_neighbors(False, "prob") + _test_sample_neighbors(False, "prob", False) + if F._default_context_str != "gpu": + _test_sample_neighbors(False, "prob", True) # _test_sample_neighbors(True) def test_sample_neighbors_outedge(): - _test_sample_neighbors_outedge(False) + _test_sample_neighbors_outedge(False, False) + if F._default_context_str != "gpu": + _test_sample_neighbors_outedge(False, True) # _test_sample_neighbors_outedge(True) @@ -985,7 +1060,8 @@ def test_sample_neighbors_outedge(): reason="GPU sample neighbors with mask not implemented", ) def test_sample_neighbors_mask(): - _test_sample_neighbors(False, "mask") + _test_sample_neighbors(False, "mask", False) + _test_sample_neighbors(False, "mask", True) @unittest.skipIf( @@ -1006,22 +1082,45 @@ def test_sample_neighbors_topk_outedge(): # _test_sample_neighbors_topk_outedge(True) -def test_sample_neighbors_with_0deg(): +@pytest.mark.parametrize("fused", [False, True]) +def test_sample_neighbors_with_0deg(fused): + if fused and F._default_context_str == "gpu": + pytest.skip("Fused sampling doesn't support GPU.") g = dgl.graph(([], []), num_nodes=5).to(F.ctx()) sg = dgl.sampling.sample_neighbors( - g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir="in", replace=False + g, + F.tensor([1, 2], dtype=F.int64), + 2, + edge_dir="in", + replace=False, + fused=fused, ) assert sg.num_edges() == 0 sg = dgl.sampling.sample_neighbors( - g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir="in", replace=True + g, + F.tensor([1, 2], dtype=F.int64), + 2, + edge_dir="in", + replace=True, + fused=fused, ) assert sg.num_edges() == 0 sg = dgl.sampling.sample_neighbors( - g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir="out", replace=False + g, + F.tensor([1, 2], dtype=F.int64), + 2, + edge_dir="out", + replace=False, + fused=fused, ) assert sg.num_edges() == 0 sg = dgl.sampling.sample_neighbors( - g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir="out", replace=True + g, + F.tensor([1, 2], dtype=F.int64), + 2, + edge_dir="out", + replace=True, + fused=fused, ) assert sg.num_edges() == 0 @@ -1152,7 +1251,7 @@ def check_num(nodes, tag): ) def test_sample_neighbors_biased_bipartite(): g = create_test_graph(100, 30, True) - num_dst = g.number_of_dst_nodes() + num_dst = g.num_dst_nodes() bias = F.tensor([0, 0.01, 10, 10], dtype=F.float32) def check_num(nodes, tag): @@ -1370,7 +1469,10 @@ def test_sample_neighbors_etype_sorted_homogeneous(format_, direction): @pytest.mark.parametrize("dtype", ["int32", "int64"]) -def test_sample_neighbors_exclude_edges_heteroG(dtype): +@pytest.mark.parametrize("fused", [False, True]) +def test_sample_neighbors_exclude_edges_heteroG(dtype, fused): + if fused and F._default_context_str == "gpu": + pytest.skip("Fused sampling doesn't support GPU.") d_i_d_u_nodes = F.zerocopy_from_numpy( np.unique(np.random.randint(300, size=100, dtype=dtype)) ) @@ -1452,39 +1554,85 @@ def test_sample_neighbors_exclude_edges_heteroG(dtype): }, sampled_amount, exclude_edges=excluded_edges, + fused=fused, ) - assert not np.any( - F.asnumpy( - sg.has_edges_between( - did_excluded_nodes_U, - did_excluded_nodes_V, - etype=("drug", "interacts", "drug"), + if fused: + + def contain_edge(g, sg, etype, u, v): + # set of subgraph graph edges deduced from original graph + org_edges = set( + map( + tuple, + np.stack( + g.find_edges(sg.edges[etype].data[dgl.EID], etype), + axis=1, + ), + ) ) + # set of excluded edges + excluded_edges = set(map(tuple, np.stack((u, v), axis=1))) + + diff_set = org_edges - excluded_edges + + return len(diff_set) != len(org_edges) + + assert not contain_edge( + g, + sg, + ("drug", "interacts", "drug"), + did_excluded_nodes_U, + did_excluded_nodes_V, ) - ) - assert not np.any( - F.asnumpy( - sg.has_edges_between( - dig_excluded_nodes_U, - dig_excluded_nodes_V, - etype=("drug", "interacts", "gene"), + assert not contain_edge( + g, + sg, + ("drug", "interacts", "gene"), + dig_excluded_nodes_U, + dig_excluded_nodes_V, + ) + assert not contain_edge( + g, + sg, + ("drug", "treats", "disease"), + dtd_excluded_nodes_U, + dtd_excluded_nodes_V, + ) + else: + assert not np.any( + F.asnumpy( + sg.has_edges_between( + did_excluded_nodes_U, + did_excluded_nodes_V, + etype=("drug", "interacts", "drug"), + ) ) ) - ) - assert not np.any( - F.asnumpy( - sg.has_edges_between( - dtd_excluded_nodes_U, - dtd_excluded_nodes_V, - etype=("drug", "treats", "disease"), + assert not np.any( + F.asnumpy( + sg.has_edges_between( + dig_excluded_nodes_U, + dig_excluded_nodes_V, + etype=("drug", "interacts", "gene"), + ) + ) + ) + assert not np.any( + F.asnumpy( + sg.has_edges_between( + dtd_excluded_nodes_U, + dtd_excluded_nodes_V, + etype=("drug", "treats", "disease"), + ) ) ) - ) @pytest.mark.parametrize("dtype", ["int32", "int64"]) -def test_sample_neighbors_exclude_edges_homoG(dtype): +@pytest.mark.parametrize("fused", [False, True]) +def test_sample_neighbors_exclude_edges_homoG(dtype, fused): + if fused and F._default_context_str == "gpu": + pytest.skip("Fused sampling doesn't support GPU.") u_nodes = F.zerocopy_from_numpy( np.unique(np.random.randint(300, size=100, dtype=dtype)) ) @@ -1508,12 +1656,36 @@ def test_sample_neighbors_exclude_edges_homoG(dtype): excluded_nodes_V = g_edges[V][b_idx:e_idx] sg = dgl.sampling.sample_neighbors( - g, sampled_node, sampled_amount, exclude_edges=excluded_edges + g, + sampled_node, + sampled_amount, + exclude_edges=excluded_edges, + fused=fused, ) + if fused: + + def contain_edge(g, sg, u, v): + # set of subgraph graph edges deduced from original graph + org_edges = set( + map( + tuple, + np.stack( + g.find_edges(sg.edges["_E"].data[dgl.EID]), axis=1 + ), + ) + ) + # set of excluded edges + excluded_edges = set(map(tuple, np.stack((u, v), axis=1))) - assert not np.any( - F.asnumpy(sg.has_edges_between(excluded_nodes_U, excluded_nodes_V)) - ) + diff_set = org_edges - excluded_edges + + return len(diff_set) != len(org_edges) + + assert not contain_edge(g, sg, excluded_nodes_U, excluded_nodes_V) + else: + assert not np.any( + F.asnumpy(sg.has_edges_between(excluded_nodes_U, excluded_nodes_V)) + ) @pytest.mark.parametrize("dtype", ["int32", "int64"])